LDanielBlueway's picture
Update handler.py
6c3ce49 verified
raw
history blame
1.4 kB
from typing import Dict, List, Any
from io import BytesIO
import base64
import logging
from PIL import Image
import numpy as np
from transformers import AutoModel
class EndpointHandler:
    """Inference Endpoints handler returning jina-clip-v1 text/image embeddings.

    A request carries a text and/or a base64-encoded image; the response holds
    the corresponding embeddings converted to JSON-serializable values.
    """

    # Repository the model (and its remote code) is loaded from by default.
    MODEL_ID = 'Blueway/Inference-endpoint-for-jina-clip-v1'

    def __init__(self, path="", model_id=None):
        """Load the embedding model.

        Args:
            path: Repository path supplied by the endpoint runtime. Currently
                unused: the model is loaded from ``model_id`` instead.
            model_id: Hub repository id or local directory to load the model
                from. Falls back to ``MODEL_ID`` when not given.
        """
        # NOTE(review): ``path`` is ignored, so weights are resolved from
        # ``model_id`` (a Hub id by default) rather than the local repo copy.
        self.model = AutoModel.from_pretrained(
            model_id or self.MODEL_ID, trust_remote_code=True
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Compute text and/or image embeddings for one request.

        Args:
            data: Request payload. Inputs are read from ``data["inputs"]`` when
                that key exists, otherwise from ``data`` itself. Recognised keys:

                * ``text``: passed unchanged to ``model.encode_text``.
                * ``image``: a base64-encoded image, optionally written as a
                  ``data:<mime>;base64,<payload>`` URI.

        Returns:
            A dict with ``text_embedding`` and/or ``image_embedding`` entries.
            Values exposing ``tolist()`` (arrays, tensors) are converted to
            plain lists. Entries whose input is absent are omitted, so an empty
            dict is returned when neither key is present.

        Raises:
            ValueError: if the inputs are not a dict.
        """
        # ``get`` instead of ``pop``: do not mutate the caller's payload.
        inputs_request = data.get("inputs", data)
        # Without this guard a string payload turns ``'text' in ...`` into a
        # substring test and then fails on ``str['text']``.
        if not isinstance(inputs_request, dict):
            raise ValueError(
                "Expected 'inputs' to be an object with 'text' and/or 'image' "
                f"keys, got {type(inputs_request).__name__}"
            )

        serializable_results = {}
        if 'text' in inputs_request:
            text_embedding = self.model.encode_text(inputs_request['text'])
            serializable_results['text_embedding'] = self._to_serializable(text_embedding)
        if 'image' in inputs_request:
            image = self._load_image(inputs_request['image'])
            image_embedding = self.model.encode_image(image)
            serializable_results['image_embedding'] = self._to_serializable(image_embedding)
        return serializable_results

    @staticmethod
    def _decode_base64(payload):
        """Decode base64 text/bytes, accepting an optional ``data:...,`` URI prefix."""
        if isinstance(payload, str) and payload.startswith("data:"):
            # In a data URI the encoded body follows the first comma.
            payload = payload.partition(",")[2]
        return base64.b64decode(payload)

    @classmethod
    def _load_image(cls, payload):
        """Decode a base64 image payload into a fully loaded RGB PIL image."""
        with Image.open(BytesIO(cls._decode_base64(payload))) as img:
            # ``convert`` yields a new, fully loaded image, so closing the
            # source afterwards is safe; it also normalises palette/RGBA/
            # grayscale uploads to 3-channel RGB.
            return img.convert("RGB")

    @staticmethod
    def _to_serializable(value):
        """Return ``value.tolist()`` when available (arrays/tensors), else ``value``."""
        return value.tolist() if hasattr(value, "tolist") else value