LDanielBlueway's picture
Update handler.py
9639d94 verified
raw
history blame
1.83 kB
import base64
import logging
from io import BytesIO
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import AutoProcessor, OmDetTurboForObjectDetection
class EndpointHandler:
    """Inference handler for OmDet-Turbo open-vocabulary object detection.

    Decodes a base64-encoded image, runs zero-shot detection against a
    caller-supplied list of candidate class names, and returns the
    post-processed detections as JSON-serializable Python lists.
    """

    # Hub repository the processor and model weights are loaded from by default.
    MODEL_ID = "Blueway/inference-endpoint-for-omdet-turbo-swin-tiny-hf"
    # Default post-processing thresholds; a request may override either one.
    DEFAULT_SCORE_THRESHOLD = 0.3
    DEFAULT_NMS_THRESHOLD = 0.3

    def __init__(self, path: str = "", model_id: str = MODEL_ID):
        """Load the processor and the detection model.

        Args:
            path: Model directory supplied by the caller (presumably the
                Inference Endpoints toolkit). Accepted for interface
                compatibility but not used; loading always uses ``model_id``.
            model_id: Hub repository id (or local directory) from which both
                the processor and the model are loaded.
        """
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = OmDetTurboForObjectDetection.from_pretrained(model_id)

    @staticmethod
    def _decode_image_bytes(payload):
        """Return the raw bytes of a base64 image payload.

        Accepts bare base64 text/bytes or a ``data:<mime>;base64,<data>`` URI.
        The URI header is stripped first: ``b64decode`` silently discards its
        ``:``, ``/``, ``;`` and ``,`` characters but decodes its letters as
        data, which would corrupt the image bytes.
        """
        if isinstance(payload, str) and payload.startswith("data:"):
            _, sep, encoded = payload.partition(",")
            if sep:
                payload = encoded
        return base64.b64decode(payload)

    @staticmethod
    def _normalize_candidates(candidates):
        """Return the candidate labels as a list.

        A bare string is wrapped in a list so it is handled as one label,
        matching the list form the request format documents.
        """
        if isinstance(candidates, str):
            return [candidates]
        return list(candidates)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run open-vocabulary detection on a single image.

        Args:
            data: Request payload. Fields are read from ``data["inputs"]``
                when that key is present, otherwise from ``data`` itself:

                - ``image`` (str): base64-encoded image, optionally given as
                  a ``data:`` URI.
                - ``candidates`` (list[str] | str): class names to detect.
                - ``score_threshold`` (float, optional): minimum detection
                  score; defaults to 0.3.
                - ``nms_threshold`` (float, optional): threshold passed to the
                  post-processor's NMS step; defaults to 0.3.

        Returns:
            A dict with three parallel lists:

            - ``boxes``: one coordinate list per detection, presumably
              ``[x_min, y_min, x_max, y_max]`` in pixels. Confirm against the
              transformers post-processor.
            - ``scores``: one float per detection.
            - ``candidates``: the matched label per detection, taken from the
              post-processor's ``classes`` entry.

        Raises:
            KeyError: if ``image`` or ``candidates`` is missing from the request.
        """
        # Read without popping so the caller's payload is left unmodified.
        request = data.get("inputs", data)

        image_bytes = self._decode_image_bytes(request["image"])
        # Force 3-channel input; RGBA/grayscale/palette uploads would otherwise
        # reach the image processor, which presumably expects RGB.
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        candidates = self._normalize_candidates(request["candidates"])

        inputs = self.processor(image, text=candidates, return_tensors="pt")
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = self.model(**inputs)

        results = self.processor.post_process_grounded_object_detection(
            outputs,
            classes=candidates,
            # PIL reports (width, height); reversed here to (height, width).
            target_sizes=[image.size[::-1]],
            score_threshold=request.get(
                "score_threshold", self.DEFAULT_SCORE_THRESHOLD
            ),
            nms_threshold=request.get("nms_threshold", self.DEFAULT_NMS_THRESHOLD),
        )[0]

        # Tensors are not JSON-serializable; convert them to nested lists.
        return {
            "boxes": results["boxes"].tolist(),
            "scores": results["scores"].tolist(),
            "candidates": results["classes"],  # already serializable per original author
        }