caball21 committed
Commit c29e78b · verified · 1 Parent(s): e8a6953

Update handler.py

Files changed (1):
  1. handler.py +29 -14
handler.py CHANGED

@@ -6,9 +6,7 @@ from PIL import Image
 import io
 import json
 
-from sam2_model_stub import SAM2Hierarchical  # 👈 stub class we define separately
-
-# Define class labels (same order as training)
+# Define class labels (must match training order)
 CLASS_LABELS = [
     "glove_outline",
     "webbing",
@@ -19,17 +17,10 @@ CLASS_LABELS = [
 ]
 
 # ----------------------------
-# Load model weights + class
+# Load model directly from full .bin
 # ----------------------------
 def load_model():
-    model = SAM2Hierarchical(
-        num_classes=len(CLASS_LABELS),
-        in_channels=3,
-        backbone="vit_b",  # <-- match your config.yaml
-        freeze_backbone=True,
-        use_cls_head=True
-    )
-    model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
+    model = torch.load("pytorch_model.bin", map_location="cpu")
     model.eval()
     return model
 
@@ -48,11 +39,33 @@ def preprocess(input_bytes):
     tensor = transform(image).unsqueeze(0)  # [1, 3, H, W]
     return tensor
 
+# ----------------------------
+# Dummy input wrapper
+# ----------------------------
+class DummyInput:
+    def __init__(self, image_tensor):
+        B, C, H, W = image_tensor.shape
+        self.images = image_tensor
+        self.masks = [torch.zeros(B, H, W, dtype=torch.bool)]
+        self.num_frames = 1
+        self.original_size = [(H, W)]
+        self.target_size = [(H, W)]
+        self.point_coords = [None]
+        self.point_labels = [None]
+        self.boxes = [None]
+        self.mask_inputs = torch.zeros(B, 1, H, W)
+        self.video_mask = torch.zeros(B, 1, H, W)
+        self.flat_obj_to_img_idx = [[0]]
+
 # ----------------------------
 # Postprocessing
 # ----------------------------
 def postprocess(output_tensor):
-    pred = torch.argmax(output_tensor, dim=1)[0].cpu().numpy()
+    if isinstance(output_tensor, dict) and "masks" in output_tensor:
+        logits = output_tensor["masks"]
+    else:
+        logits = output_tensor
+    pred = torch.argmax(logits, dim=1)[0].cpu().numpy()
     return pred.tolist()
 
 # ----------------------------
@@ -67,8 +80,10 @@ def infer(payload):
     else:
         raise ValueError("Unsupported input format")
 
+    input_obj = DummyInput(image_tensor)
+
     with torch.no_grad():
-        output = model(image_tensor)
+        output = model(input_obj)
 
     mask = postprocess(output)
     return {
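A note on the new load path: `torch.load("pytorch_model.bin")` deserializes a fully pickled model object rather than a state dict, so the class that defined the saved model must still be importable when the file is unpickled, and recent PyTorch releases default to `weights_only=True`, which rejects full-model pickles. A minimal smoke test under those assumptions (the explicit `weights_only=False` is an addition for newer PyTorch, not part of this commit):

```python
# Sketch: verify the full-model .bin loads outside the handler.
# Assumes pytorch_model.bin was written with torch.save(model) and that the
# model's defining class is importable here; weights_only=False is needed on
# PyTorch >= 2.6, where the default flipped to True.
import torch

model = torch.load("pytorch_model.bin", map_location="cpu", weights_only=False)
model.eval()
print(type(model).__name__)  # should print the saved model's class name
```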
 
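For reference, a hypothetical driver for the updated handler. The diff does not show how the module-level `model` is bound, the start of `infer()`, or the keys of the returned dict, so the import wiring, the payload key `"image"`, and the response key `"mask"` below are all guesses for illustration:

```python
# Hypothetical driver; the names flagged above are assumptions, not part of
# this commit.
import base64

import torch
from handler import DummyInput, infer  # assumes handler.py is importable as-is

# Shape check of the DummyInput adapter: it supplies the batched prompt fields
# (empty mask, no points/boxes) that a SAM2-style forward pass expects.
inp = DummyInput(torch.randn(1, 3, 512, 512))  # arbitrary example size
assert inp.mask_inputs.shape == (1, 1, 512, 512)
assert inp.masks[0].dtype == torch.bool

with open("sample_glove.jpg", "rb") as f:  # any RGB test image
    payload = {"image": base64.b64encode(f.read()).decode("utf-8")}  # assumed key

result = infer(payload)         # wraps the tensor in DummyInput internally
mask = result["mask"]           # assumed key; postprocess() returns a nested list
print(len(mask), len(mask[0]))  # rows x cols of per-pixel class indices
```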