# handler.py, downloaded from the Hugging Face Hub.
# Last commit: c167b01 "Update handler.py" (verified), by DanielBlueway; file size 1.45 kB.
from typing import Dict, List, Any
from io import BytesIO
import base64
import logging
from PIL import Image
import numpy as np
from transformers import AutoModel
class EndpointHandler():
    """Inference Endpoint handler that returns text and/or image embeddings.

    The model is loaded from the 'Blueway/Inference-endpoint-for-jina-clip-v1'
    Hub repository. Each request may carry a text string and/or a base64-encoded
    image. Every field that is present is embedded, and the results are returned
    as JSON-serializable values.
    """

    def __init__(self, path=""):
        """Load the embedding model.

        Args:
            path: Path supplied by the endpoint runtime. It is currently unused:
                the model is always loaded from the hard-coded Hub repository.
        """
        self.model = AutoModel.from_pretrained(
            'Blueway/Inference-endpoint-for-jina-clip-v1', trust_remote_code=True
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed the text and/or image found in a request payload.

        Args:
            data: Request payload. The fields are read from ``data["inputs"]``
                when that key exists (it is popped from ``data``), otherwise
                from ``data`` itself. Recognised fields:
                  - ``text``: passed as-is to ``model.encode_text``.
                  - ``image`` (str): base64-encoded image bytes, decoded and
                    opened with PIL before ``model.encode_image``.
                Both fields are optional.

        Returns:
            A dict with keys ``'text_embedding'`` and ``'image_embedding'``.
            Each value is the embedding (NumPy arrays are converted with
            ``tolist()``), or ``[]`` when the corresponding input was absent.
        """
        inputs_request = data.pop("inputs", data)

        # Use .get so that either field may be omitted. Indexing with [] would
        # raise KeyError and make the None-checks below unreachable.
        text = inputs_request.get('text')
        encoded_image = inputs_request.get('image')

        # Initialise both results so the output dict never references an
        # unbound name when an input is missing.
        text_embedding = None
        if text is not None:
            text_embedding = self.model.encode_text(text)

        image_embedding = None
        if encoded_image is not None:
            # Decode the base64 payload into a PIL image.
            image = Image.open(BytesIO(base64.b64decode(encoded_image)))
            image_embedding = self.model.encode_image(image)

        return {
            'text_embedding': self._to_serializable(text_embedding),
            'image_embedding': self._to_serializable(image_embedding),
        }

    @staticmethod
    def _to_serializable(embedding):
        """Return ``[]`` for None, ``tolist()`` for ndarrays, else the value unchanged."""
        if embedding is None:
            return []
        if isinstance(embedding, np.ndarray):
            return embedding.tolist()
        return embedding