File size: 1,186 Bytes
643b8b8 5280e8d 643b8b8 5280e8d 643b8b8 5280e8d 643b8b8 5280e8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
from pylate import models
from transformers import AutoTokenizer
import torch
import numpy as np
class EndpointHandler:
    """Callable handler that embeds text with a ColBERT model.

    The handler loads a tokenizer and a ``pylate`` ColBERT model from *path*
    once at construction time.  Calling the instance with a request payload
    encodes the contained text(s) and returns the embeddings as nested plain
    Python lists so they can be JSON-serialised.
    """

    def __init__(self, path=""):
        """Load the tokenizer and model stored at *path*.

        Args:
            path: Model name or directory passed to both
                ``AutoTokenizer.from_pretrained`` and ``models.ColBERT``.
        """
        # The tokenizer is not used by this class itself; it is kept as an
        # attribute because external code may rely on it being present.
        self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
        self.model = models.ColBERT(model_name_or_path=path)
        self.model.eval()

    def _to_list(self, emb):
        """Recursively convert *emb* into JSON-serialisable Python objects.

        Conversion rules:
            - torch.Tensor           -> emb.detach().cpu().tolist()
            - np.ndarray / np scalar -> emb.tolist()
            - list / tuple           -> list with every element converted
            - anything else          -> returned unchanged
        """
        if isinstance(emb, torch.Tensor):
            # detach() is a no-op under no_grad but keeps the helper safe
            # when it is handed a tensor that still tracks gradients.
            return emb.detach().cpu().tolist()
        if isinstance(emb, (np.ndarray, np.generic)):
            return emb.tolist()
        if isinstance(emb, (list, tuple)):
            return [self._to_list(item) for item in emb]
        return emb  # already plain Python

    @staticmethod
    def _extract_texts(data):
        """Return the text payload carried by *data*.

        A dict payload is searched for the keys ``"inputs"`` and then
        ``"text"``; the first one whose value is not ``None`` wins.  Any
        non-dict payload (a bare string or a sequence of strings) is taken
        to be the texts themselves.

        Raises:
            ValueError: if *data* is a dict containing neither key.
        """
        if not isinstance(data, dict):
            return data
        for key in ("inputs", "text"):
            value = data.get(key)
            # Compare against None rather than relying on truthiness so an
            # empty string or empty list is not mistaken for a missing key.
            if value is not None:
                return value
        raise ValueError('Request payload must contain "inputs" or "text".')

    def __call__(self, data):
        """Encode the text(s) in *data* and return their embeddings.

        Args:
            data: Either a dict with an ``"inputs"`` or ``"text"`` entry, or
                the text(s) directly (a string or a list of strings).

        Returns:
            A list with one entry per input text, each converted to nested
            Python lists by ``_to_list``.  An empty batch yields ``[]``.

        Raises:
            ValueError: if *data* is a dict with neither ``"inputs"`` nor
                ``"text"``.
        """
        texts = self._extract_texts(data)
        if isinstance(texts, str):
            texts = [texts]
        elif isinstance(texts, tuple):
            texts = list(texts)

        # Nothing to encode: skip the model call entirely.
        if isinstance(texts, list) and not texts:
            return []

        with torch.no_grad():
            # Inputs are always encoded in query mode (is_query=True).
            emb = self.model.encode(
                texts,
                is_query=True,
                batch_size=32,
            )
        return self._to_list(emb)
|