# Repository page header captured with the file (not part of the program):
# glove_labelling / handler.py
# caball21's picture
# Update handler.py
# c29e78b verified
# raw
# history blame
# 2.42 kB
# handler.py
import torch
import torchvision.transforms as T
from PIL import Image
import io
import json
# Segmentation class names. The list position of each name must match the
# label order used during training.
CLASS_LABELS = [
    "glove_outline",
    "webbing",
    "thumb",
    "palm_pocket",
    "hand",
    "glove_exterior",
]
# ----------------------------
# Load model directly from full .bin
# ----------------------------
def load_model(path="pytorch_model.bin"):
    """Load the fully pickled model from disk and put it in evaluation mode.

    Args:
        path: Location of the serialized model file. Defaults to
            "pytorch_model.bin", resolved against the current working
            directory.

    Returns:
        The deserialized module, with its storages mapped to CPU and
        ``eval()`` already applied.

    Raises:
        TypeError: If the file does not contain a ``torch.nn.Module`` (for
            example, when it only holds a state dict).
    """
    import inspect

    # SECURITY: torch.load unpickles arbitrary Python objects, so `path` must
    # point to a trusted file.
    load_kwargs = {"map_location": "cpu"}
    # The file stores a whole pickled module rather than a state dict.
    # Newer PyTorch releases default to weights_only=True, which refuses such
    # files, so opt out explicitly wherever the parameter exists. Older
    # releases have no such parameter and must not receive it.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False

    loaded = torch.load(path, **load_kwargs)
    if not isinstance(loaded, torch.nn.Module):
        raise TypeError(
            f"{path!r} must contain a full pickled torch.nn.Module, "
            f"got {type(loaded).__name__}"
        )
    loaded.eval()
    return loaded
# Load the model once at import time; infer() reuses this module-level instance.
model = load_model()
# ----------------------------
# Preprocessing
# ----------------------------
# Preprocessing pipeline applied to every input image: resize to the fixed
# (720, 1280) size, then convert the PIL image to a tensor.
transform = T.Compose(
    [
        T.Resize((720, 1280)),
        T.ToTensor(),
    ]
)
def preprocess(input_bytes):
    """Decode encoded image bytes into a batched tensor for the model.

    Args:
        input_bytes: Raw bytes of an encoded image file that PIL can open.

    Returns:
        The output of the module-level ``transform`` for the RGB image, with
        a leading batch dimension of size 1, i.e. [1, 3, H, W].
    """
    buffer = io.BytesIO(input_bytes)
    # Force three channels regardless of the source image mode.
    rgb_image = Image.open(buffer).convert("RGB")
    single = transform(rgb_image)
    # unsqueeze(0) prepends the batch axis.
    return single.unsqueeze(0)
# ----------------------------
# Dummy input wrapper
# ----------------------------
class DummyInput:
    """Minimal batch-like container handed to the model instead of a real batch.

    Holds one image tensor and fills every other field with placeholder
    values (all-zero masks, ``None`` prompts, a single frame).

    NOTE(review): the attribute set presumably mirrors what the model's
    forward pass reads -- confirm against the model definition.
    """

    def __init__(self, image_tensor):
        """Populate the placeholder fields, sized from a 4-D image tensor.

        Args:
            image_tensor: Tensor whose shape unpacks as
                (batch, channels, height, width).

        Raises:
            ValueError: If the shape does not have exactly four dimensions.
        """
        batch, _channels, height, width = image_tensor.shape
        frame_size = (height, width)
        prompt_shape = (batch, 1, height, width)

        self.images = image_tensor
        self.num_frames = 1
        self.flat_obj_to_img_idx = [[0]]

        # Original and target sizes are both taken from the tensor itself.
        self.original_size = [frame_size]
        self.target_size = [frame_size]

        # No point or box prompts are supplied.
        self.point_coords = [None]
        self.point_labels = [None]
        self.boxes = [None]

        # Empty masks: one boolean mask per batch item plus two float
        # single-channel placeholders (kept as separate tensors).
        self.masks = [torch.zeros((batch, height, width), dtype=torch.bool)]
        self.mask_inputs = torch.zeros(prompt_shape)
        self.video_mask = torch.zeros(prompt_shape)
# ----------------------------
# Postprocessing
# ----------------------------
def postprocess(output_tensor):
    """Convert raw model output into a class-index mask for the first image.

    Args:
        output_tensor: Either a tensor of per-class scores, or a dict that
            holds that tensor under the "masks" key. The argmax is taken
            over dimension 1.
            NOTE(review): assumes dimension 1 is the class axis -- confirm
            against the model's output layout.

    Returns:
        Nested Python lists of ints holding, for the first batch item only,
        the index of the highest-scoring entry along dimension 1.

    Raises:
        TypeError: If the resolved output is not a ``torch.Tensor``.
    """
    if isinstance(output_tensor, dict) and "masks" in output_tensor:
        logits = output_tensor["masks"]
    else:
        logits = output_tensor

    # Fail with a readable message instead of torch.argmax's generic one,
    # e.g. when a dict without a "masks" entry falls through to here.
    if not isinstance(logits, torch.Tensor):
        raise TypeError(
            "postprocess expected a torch.Tensor or a dict with a 'masks' "
            f"tensor, got {type(logits).__name__}"
        )

    # Tensor.tolist() already yields plain Python ints, so no numpy round
    # trip is needed.
    return torch.argmax(logits, dim=1)[0].tolist()
# ----------------------------
# Inference Entry Point
# ----------------------------
def infer(payload):
    """Run the model on one image and return its mask with the class names.

    Args:
        payload: Either the raw encoded image as ``bytes``, or a dict whose
            "inputs" entry is the base64-encoded image.

    Returns:
        A dict with two entries: "mask", the result of ``postprocess`` on the
        model output, and "classes", the module-level ``CLASS_LABELS`` list.

    Raises:
        ValueError: If ``payload`` is neither ``bytes`` nor a dict containing
            an "inputs" key.
    """
    # Normalise both accepted payload shapes to raw image bytes first.
    if isinstance(payload, bytes):
        raw_image = payload
    elif isinstance(payload, dict) and "inputs" in payload:
        import base64

        raw_image = base64.b64decode(payload["inputs"])
    else:
        raise ValueError("Unsupported input format")

    model_input = DummyInput(preprocess(raw_image))

    # Inference only: no gradients are needed.
    with torch.no_grad():
        prediction = model(model_input)

    return {
        "mask": postprocess(prediction),
        "classes": CLASS_LABELS,
    }