caball21 committed on
Commit
a21ff88
·
verified ·
1 Parent(s): 755ff1c

Update README.md

Files changed (1)
  1. README.md +58 -2
README.md CHANGED
@@ -4,9 +4,10 @@ license: mit
  tags:
  - pytorch
  - image-segmentation
- - sam
+ - sam2
  - glove
  - baseball
+ - sports-analytics
  - computer-vision
  - custom-model
  library_name: pytorch
@@ -15,8 +16,63 @@ datasets:
  metrics:
  - dice
  - iou
- inference: false
+ inference: true
  widget: []
  model-index:
  - name: glove_labelling
  results: []
+ ---
+
+ # Glove Labelling Model (SAM2 fine-tuned)
+
+ This repository contains a fine-tuned [SAM2](https://github.com/facebookresearch/sam2) hierarchical image segmentation model adapted for high-precision baseball glove segmentation.
+
+ ### 💡 What it does
+
+ Given a frame from a pitching video, this model outputs per-pixel segmentations for:
+
+ - `glove_outline`
+ - `webbing`
+ - `thumb`
+ - `palm_pocket`
+ - `hand`
+ - `glove_exterior`
+
+ Trained on individual pitch frame sequences using COCO format masks.
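
For reference, a COCO-format annotation file for these six parts might be laid out roughly as in the sketch below. The category ids, the file name `frame_0001.jpg`, the image id, and the polygon coordinates are illustrative assumptions, not values taken from the training data, and optional COCO fields such as `area` are omitted.

```python
import json

# Part names from the list above; the 1-based ids are an assumed ordering.
PART_NAMES = ["glove_outline", "webbing", "thumb", "palm_pocket", "hand", "glove_exterior"]

def build_coco_skeleton(file_name, width, height):
    """Return a minimal COCO-style dict with one image and no annotations yet."""
    return {
        "images": [{"id": 1, "file_name": file_name, "width": width, "height": height}],
        "categories": [{"id": i, "name": n} for i, n in enumerate(PART_NAMES, start=1)],
        "annotations": [],
    }

def add_polygon(coco, category_name, polygon):
    """Append one polygon mask (flat [x1, y1, x2, y2, ...] list) for the first image."""
    cat_id = next(c["id"] for c in coco["categories"] if c["name"] == category_name)
    xs, ys = polygon[0::2], polygon[1::2]
    x0, y0 = min(xs), min(ys)
    coco["annotations"].append({
        "id": len(coco["annotations"]) + 1,
        "image_id": coco["images"][0]["id"],
        "category_id": cat_id,
        "segmentation": [polygon],
        "bbox": [x0, y0, max(xs) - x0, max(ys) - y0],  # COCO bbox is [x, y, width, height]
        "iscrowd": 0,
    })
    return coco

coco = build_coco_skeleton("frame_0001.jpg", 1280, 720)  # hypothetical frame name
add_polygon(coco, "webbing", [600, 300, 660, 300, 660, 360, 600, 360])
coco_json = json.dumps(coco, indent=2)
```

Each labelled part in a frame becomes one entry in `annotations`, linked to its frame through `image_id` and to its part name through `category_id`.
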
+
+ ---
+
+ ### 🏗 Architecture
+
+ - Base Model: `SAM2Hierarchical`
+ - Framework: PyTorch
+ - Input shape: `[1, 3, 720, 1280]` RGB frame
+ - Output: Segmentation logits across 6 glove-related classes
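
To make the output format concrete, the sketch below decodes class logits into a label map by taking the per-pixel argmax. It assumes the logits have shape `[6, H, W]` once the batch dimension is dropped and that the channel order matches the class list in this README; neither is confirmed by the model card. Plain Python lists stand in for tensors so the snippet runs without PyTorch, and the helper names are hypothetical.

```python
# Assumed channel order: same as the class list in this README.
CLASS_NAMES = ["glove_outline", "webbing", "thumb", "palm_pocket", "hand", "glove_exterior"]

def decode_logits(logits):
    """Per-pixel argmax over the class axis of a [C][H][W] nested list."""
    num_classes = len(logits)
    height, width = len(logits[0]), len(logits[0][0])
    return [
        [max(range(num_classes), key=lambda c: logits[c][y][x]) for x in range(width)]
        for y in range(height)
    ]

def class_pixel_counts(label_map):
    """Count how many pixels were assigned to each class name."""
    counts = {name: 0 for name in CLASS_NAMES}
    for row in label_map:
        for idx in row:
            counts[CLASS_NAMES[idx]] += 1
    return counts

# Toy 6 x 2 x 2 logits: every class scores 0.0 except where set below.
toy = [[[0.0, 0.0], [0.0, 0.0]] for _ in CLASS_NAMES]
toy[1][0][0] = 2.5   # webbing wins at (y=0, x=0)
toy[4][1][1] = 1.0   # hand wins at (y=1, x=1)

label_map = decode_logits(toy)
counts = class_pixel_counts(label_map)
```

With a real tensor of shape `[1, 6, H, W]`, the same decoding step is `output.argmax(dim=1)`. Note that an argmax over six part channels assigns every pixel to some part, so a separate background handling step may be needed if the model has no background channel.
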
+
+ ---
+
+ ### 🔧 Usage
+
+ To use the model for inference:
+
+ ```python
+ import torch
+ from PIL import Image
+ import torchvision.transforms as T
+
+ # Assumes pytorch_model.bin stores the whole pickled nn.Module, not just a state_dict.
+ # weights_only=False is needed to unpickle a full module on PyTorch >= 2.6; use it only with trusted files.
+ model = torch.load("pytorch_model.bin", map_location="cpu", weights_only=False)
+ model.eval()
+
+ # Resize to the expected 720x1280 input and convert to a [0, 1] float tensor
+ transform = T.Compose([
+     T.Resize((720, 1280)),
+     T.ToTensor(),
+ ])
+
+ img = Image.open("example.jpg").convert("RGB")
+ x = transform(img).unsqueeze(0)  # add batch dimension: [1, 3, 720, 1280]
+
+ with torch.no_grad():
+     output = model(x)
+
+ # Convert logits to class labels
+ pred_mask = output.argmax(dim=1).squeeze().cpu().numpy()