|
|
|
|
|
import torch |
|
import torchvision.transforms as T |
|
from PIL import Image |
|
import io |
|
import json |
|
|
|
|
|
# Class names returned next to the mask by infer(), listed in index order
# (entry i names index i).
# NOTE(review): presumably i is the class id that postprocess()'s argmax writes
# into the mask — confirm against the model's output head.
CLASS_LABELS = [
    "glove_outline",   # 0
    "webbing",         # 1
    "thumb",           # 2
    "palm_pocket",     # 3
    "hand",            # 4
    "glove_exterior",  # 5
]
|
|
|
|
|
|
|
|
|
def load_model(path="pytorch_model.bin", map_location="cpu", **load_kwargs):
    """Load a whole serialized model from *path* and put it in eval mode.

    Args:
        path: Checkpoint file to read. Defaults to ``"pytorch_model.bin"``,
            resolved relative to the current working directory.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so
            the checkpoint can be loaded on a machine without a GPU.
        **load_kwargs: Extra keyword arguments forwarded unchanged to
            ``torch.load`` (for example ``weights_only=False`` on torch
            versions whose default refuses to unpickle whole modules).

    Returns:
        The loaded object, after its ``eval()`` method has been called.

    Raises:
        TypeError: If the file does not contain an object with an ``eval()``
            method, e.g. a bare ``state_dict`` instead of a pickled module.
    """
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only point this at trusted checkpoints.
    model = torch.load(path, map_location=map_location, **load_kwargs)
    # A state_dict (plain dict of tensors) has no eval(); fail with a clear
    # message instead of an opaque AttributeError.
    if not callable(getattr(model, "eval", None)):
        raise TypeError(
            f"{path!r} contains a {type(model).__name__}, not a model with an "
            "eval() method; if it is a state_dict, build the model class and "
            "call load_state_dict() on it instead"
        )
    model.eval()
    return model
|
|
|
# Load the network once, at import time, so every infer() call reuses it.
model = load_model()

# Per-image preprocessing used by preprocess(): resize to a fixed (720, 1280)
# frame (torchvision reads this pair as (height, width)), then convert the PIL
# image to a tensor.
transform = T.Compose(
    transforms=[
        T.Resize(size=(720, 1280)),
        T.ToTensor(),
    ]
)
|
|
|
def preprocess(input_bytes):
    """Decode an encoded image and return it as a batch of one tensor.

    Args:
        input_bytes: Encoded image file contents in any format PIL can open.

    Returns:
        The module-level ``transform`` applied to the RGB image, with a new
        leading axis of size 1 prepended.
    """
    buffer = io.BytesIO(input_bytes)
    # Force a 3-channel RGB image regardless of the source file's mode.
    rgb_image = Image.open(buffer).convert("RGB")
    single = transform(rgb_image)
    return single.unsqueeze(dim=0)
|
|
|
|
|
|
|
|
|
class DummyInput:
    """Bundle a batched image tensor with empty placeholder fields for the model.

    Only ``images`` carries real data. Every prompt/mask field is an empty
    default: ``None`` entries or freshly allocated all-zero tensors.

    NOTE(review): ``original_size`` and ``target_size`` both come from the
    tensor itself, i.e. the size *after* preprocessing resized it, and
    ``flat_obj_to_img_idx`` is fixed at ``[[0]]`` whatever the batch size.
    Confirm that this matches what the model expects.
    """

    def __init__(self, image_tensor):
        # Unpacking into four names requires a 4-D tensor.
        batch, _channels, height, width = image_tensor.shape

        self.images = image_tensor
        self.num_frames = 1

        # Size bookkeeping: both entries are the tensor's own height/width.
        self.original_size = [(height, width)]
        self.target_size = [(height, width)]

        # No point or box prompts are supplied.
        self.point_coords = [None]
        self.point_labels = [None]
        self.boxes = [None]

        # Empty masks; each attribute gets its own separately allocated tensor.
        self.masks = [torch.zeros(batch, height, width, dtype=torch.bool)]
        self.mask_inputs = torch.zeros(batch, 1, height, width)
        self.video_mask = torch.zeros(batch, 1, height, width)

        self.flat_obj_to_img_idx = [[0]]
|
|
|
|
|
|
|
|
|
def postprocess(output_tensor):
    """Turn raw model output into a per-pixel class-index mask.

    Args:
        output_tensor: Either a logits tensor or a dict holding one under
            ``"masks"``. Class scores are taken along dim 1 and only item 0 of
            dim 0 is kept. This presumably means a (batch, classes, H, W)
            layout; TODO confirm against the model.

    Returns:
        Nested Python lists of ints: the argmax over dim 1 for item 0.

    Raises:
        ValueError: If a dict is given that has no ``"masks"`` entry.
    """
    if isinstance(output_tensor, dict):
        # A dict without "masks" used to reach torch.argmax and fail with an
        # unrelated TypeError; report the real problem instead.
        if "masks" not in output_tensor:
            raise ValueError(
                f"model output dict has no 'masks' entry (keys: {list(output_tensor)})"
            )
        logits = output_tensor["masks"]
    else:
        logits = output_tensor

    # Highest-scoring index along dim 1, then keep the first item of dim 0.
    pred = torch.argmax(logits, dim=1)[0]
    # Tensor.tolist() yields plain Python ints directly; no NumPy round-trip.
    return pred.tolist()
|
|
|
|
|
|
|
|
|
def infer(payload): |
|
if isinstance(payload, bytes): |
|
image_tensor = preprocess(payload) |
|
elif isinstance(payload, dict) and "inputs" in payload: |
|
from base64 import b64decode |
|
image_tensor = preprocess(b64decode(payload["inputs"])) |
|
else: |
|
raise ValueError("Unsupported input format") |
|
|
|
input_obj = DummyInput(image_tensor) |
|
|
|
with torch.no_grad(): |
|
output = model(input_obj) |
|
|
|
mask = postprocess(output) |
|
return { |
|
"mask": mask, |
|
"classes": CLASS_LABELS |
|
} |
|
|