LDanielBlueway's picture
Update handler.py
0611e70 verified
raw
history blame
1.63 kB
from typing import Dict, List, Any
from io import BytesIO
import base64
import logging
import uform
from PIL import Image
import numpy as np
class EndpointHandler:
    """Request handler that returns uform image, text and joint embeddings.

    Appears to follow the Hugging Face Inference Endpoints custom ``handler.py``
    convention (constructed with ``path``, called with the request dict).
    Each request carries one base64-encoded image and one text.
    """

    def __init__(self, path=""):
        """Load the uform multilingual vision-language model and its processor.

        Args:
            path: Path supplied by the endpoint runtime. Unused: the model is
                always loaded by the fixed id below, not from ``path``.
        """
        self.model, self.processor = uform.get_model('unum-cloud/uform-vl-multilingual-v2')

    @staticmethod
    def _extract_payload(data: Dict[str, Any]):
        """Return ``(image_bytes, text)`` from the request without mutating it."""
        # Payload may be wrapped in "inputs"; otherwise the body itself is the payload.
        # ``get`` (not ``pop``) so the caller's dict is left intact.
        payload = data.get("inputs", data)
        # Decode the image before reading the text (same order of failure as before).
        image_bytes = base64.b64decode(payload['image'])
        return image_bytes, payload['text']

    @staticmethod
    def _to_serializable(value: Any) -> Any:
        """Convert an array-like embedding to nested Python lists; pass others through.

        NumPy arrays expose ``tolist``. Tensor-like objects (e.g. PyTorch
        tensors) are first detached and moved to CPU when they expose
        ``detach``/``cpu``. A plain ``isinstance(value, np.ndarray)`` check would
        let such tensors through unconverted, and they are not JSON-serializable.
        """
        if hasattr(value, "detach"):
            value = value.detach()
        if hasattr(value, "cpu"):
            value = value.cpu()
        if hasattr(value, "tolist"):
            return value.tolist()
        return value

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed one image/text pair.

        Args:
            data: Request body. The payload is read from ``data["inputs"]`` when
                that key is present, otherwise from ``data`` itself. The payload
                must contain:
                  * ``"image"`` -- the image file bytes, base64-encoded.
                  * ``"text"`` -- the text to embed, passed as-is to the processor.

        Returns:
            A dict with ``"joint_embedding"``, ``"text_embedding"`` and
            ``"image_embedding"``. Each value is converted to nested Python lists
            when the model returns an array/tensor.

        Raises:
            KeyError: if ``"image"`` or ``"text"`` is missing from the payload.
            binascii.Error: if ``"image"`` is malformed base64 (e.g. bad padding).
            PIL.UnidentifiedImageError: if the decoded bytes are not an image
                PIL can open.
        """
        image_bytes, text = self._extract_payload(data)
        image = Image.open(BytesIO(image_bytes))

        image_data = self.processor.preprocess_image(image)
        text_data = self.processor.preprocess_text(text)

        # NOTE(review): assumes this uform release returns a (features, embedding)
        # pair from encode_image/encode_text; confirm against the installed version.
        _, image_embedding = self.model.encode_image(image_data)
        _, text_embedding = self.model.encode_text(text_data)
        joint_embedding = self.model.encode_multimodal(image=image_data, text=text_data)

        return {
            'joint_embedding': self._to_serializable(joint_embedding),
            'text_embedding': self._to_serializable(text_embedding),
            'image_embedding': self._to_serializable(image_embedding),
        }