# handler.py
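#
# Inference handler: loads a fully pickled segmentation model from
# pytorch_model.bin, preprocesses incoming images, and returns a per-pixel
# class-index mask over the glove-part labels defined below.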

import io
from base64 import b64decode

import torch
import torchvision.transforms as T
from PIL import Image

# Define class labels (must match training order)
CLASS_LABELS = [
    "glove_outline",
    "webbing",
    "thumb",
    "palm_pocket",
    "hand",
    "glove_exterior"
]

# ----------------------------
# Load model directly from full .bin
# ----------------------------
def load_model():
    # The .bin is assumed to contain a fully pickled nn.Module rather than
    # a state_dict; weights_only=False is needed on PyTorch >= 2.6, where
    # torch.load defaults to weights-only loading.
    model = torch.load("pytorch_model.bin", map_location="cpu", weights_only=False)
    model.eval()
    return model

# Load the model once at module import so repeated infer() calls reuse it.
model = load_model()

# ----------------------------
# Preprocessing
# ----------------------------
transform = T.Compose([
    T.Resize((720, 1280)),  # (height, width)
    T.ToTensor()            # float32 in [0, 1], shape [3, H, W]
])

def preprocess(input_bytes):
    image = Image.open(io.BytesIO(input_bytes)).convert("RGB")
    tensor = transform(image).unsqueeze(0)  # [1, 3, H, W]
    return tensor

# ----------------------------
# Dummy input wrapper
# ----------------------------
class DummyInput:
    """Batched input wrapper exposing the attributes the model's forward()
    expects. Only the image tensor carries real data; every prompt field is
    an empty placeholder. (Field names mirror the training pipeline's input
    object and may need adjusting for a different model.)"""

    def __init__(self, image_tensor):
        B, C, H, W = image_tensor.shape
        self.images = image_tensor
        self.masks = [torch.zeros(B, H, W, dtype=torch.bool)]
        self.num_frames = 1
        self.original_size = [(H, W)]
        self.target_size = [(H, W)]
        # No point, box, or mask prompts for whole-image inference.
        self.point_coords = [None]
        self.point_labels = [None]
        self.boxes = [None]
        self.mask_inputs = torch.zeros(B, 1, H, W)
        self.video_mask = torch.zeros(B, 1, H, W)
        self.flat_obj_to_img_idx = [[0]]

# ----------------------------
# Postprocessing
# ----------------------------
def postprocess(output_tensor):
    # The model may return a dict with a "masks" entry or a raw logits
    # tensor of shape [B, num_classes, H, W].
    if isinstance(output_tensor, dict) and "masks" in output_tensor:
        logits = output_tensor["masks"]
    else:
        logits = output_tensor
    # Per-pixel argmax over the class dimension -> [H, W] class-index map.
    pred = torch.argmax(logits, dim=1)[0].cpu().numpy()
    return pred.tolist()

# ----------------------------
# Inference Entry Point
# ----------------------------
def infer(payload):
    # Accept either raw image bytes or a dict carrying a base64-encoded
    # image under "inputs".
    if isinstance(payload, bytes):
        image_tensor = preprocess(payload)
    elif isinstance(payload, dict) and "inputs" in payload:
        image_tensor = preprocess(b64decode(payload["inputs"]))
    else:
        raise ValueError("Unsupported input format: expected raw bytes or a dict with an 'inputs' key")

    input_obj = DummyInput(image_tensor)

    with torch.no_grad():
        output = model(input_obj)

    mask = postprocess(output)
    return {
        "mask": mask,
        "classes": CLASS_LABELS
    }
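
# ----------------------------
# Local smoke test (sketch)
# ----------------------------
# A minimal usage sketch, assuming an example image "sample.jpg" next to
# this file ("sample.jpg" is a placeholder filename, not part of the
# handler's contract). It exercises infer() end to end on raw bytes.
if __name__ == "__main__":
    with open("sample.jpg", "rb") as f:
        result = infer(f.read())

    mask = result["mask"]  # nested list, [H][W] of class indices
    print("classes:", result["classes"])
    print("mask size: %d x %d" % (len(mask), len(mask[0])))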