|
|
from pylate import models |
|
|
from transformers import AutoTokenizer |
|
|
import torch |
|
|
import numpy as np |
|
|
|
|
|
class EndpointHandler:
    """Custom inference handler that embeds text with a PyLate ColBERT model.

    A request payload is turned into a list of texts, encoded with
    ``self.model.encode`` and returned as nested Python lists so the result
    can be JSON-serialised.
    """

    # Defaults forwarded to ``self.model.encode``; a request may override them
    # through an optional ``data["parameters"]`` dict.
    DEFAULT_IS_QUERY = True
    DEFAULT_BATCH_SIZE = 32

    def __init__(self, path=""):
        """Load the tokenizer and the ColBERT model from ``path``.

        Args:
            path: Model directory or hub id passed to both loaders.
        """
        # NOTE(review): the tokenizer is not used anywhere in this class; it is
        # kept as an attribute in case external code reads ``handler.tokenizer``.
        self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
        self.model = models.ColBERT(model_name_or_path=path)
        # Switch the model to evaluation mode before serving requests.
        self.model.eval()

    def _to_list(self, emb):
        """Make model output JSON-serialisable.

        Conversions:
          - torch.Tensor          -> emb.detach().cpu().tolist()
          - np.ndarray / np.generic -> emb.tolist() (numpy scalars become
            plain Python numbers)
          - list / tuple          -> list, converting each element recursively
          - dict                  -> dict, converting each value recursively
          - anything else         -> returned unchanged
        """
        if isinstance(emb, torch.Tensor):
            return emb.detach().cpu().tolist()
        if isinstance(emb, (np.ndarray, np.generic)):
            return emb.tolist()
        if isinstance(emb, (list, tuple)):
            return [self._to_list(e) for e in emb]
        if isinstance(emb, dict):
            return {k: self._to_list(v) for k, v in emb.items()}
        return emb

    @staticmethod
    def _extract_texts(data):
        """Return the texts to encode from a request payload.

        Args:
            data: A dict with an ``"inputs"`` (preferred) or ``"text"`` entry,
                or the raw text / list of texts itself.

        Returns:
            A list of texts (a single string is wrapped in a list, a tuple is
            converted to a list). ``None`` input is returned as-is.

        Raises:
            ValueError: If a dict payload has neither ``"inputs"`` nor ``"text"``.
        """
        if isinstance(data, dict):
            # Compare against None explicitly so an empty "inputs" value is not
            # mistaken for a missing key (which would send the dict to the model).
            texts = data.get("inputs")
            if texts is None:
                texts = data.get("text")
            if texts is None:
                raise ValueError(
                    "Request payload must contain an 'inputs' or 'text' field."
                )
        else:
            texts = data
        if isinstance(texts, str):
            return [texts]
        if isinstance(texts, tuple):
            return list(texts)
        return texts

    @classmethod
    def _encode_options(cls, data):
        """Build keyword arguments for ``self.model.encode`` from the request.

        Only ``is_query`` and ``batch_size`` are read from an optional
        ``data["parameters"]`` dict; every other key there is ignored.
        """
        options = {
            "is_query": cls.DEFAULT_IS_QUERY,
            "batch_size": cls.DEFAULT_BATCH_SIZE,
        }
        params = data.get("parameters") if isinstance(data, dict) else None
        if isinstance(params, dict):
            for key in options:
                if key in params:
                    options[key] = params[key]
        return options

    def __call__(self, data):
        """Encode the texts in ``data`` and return JSON-serialisable embeddings.

        Args:
            data: Either a dict with ``"inputs"`` (or ``"text"``) and an optional
                ``"parameters"`` dict (``is_query``, ``batch_size``), or the raw
                text / list of texts.

        Returns:
            The embeddings converted by ``_to_list``; ``[]`` when there is
            nothing to encode.

        Raises:
            ValueError: If a dict payload has neither ``"inputs"`` nor ``"text"``.
        """
        texts = self._extract_texts(data)
        if not texts:
            # Nothing to encode; avoid calling the model with an empty batch.
            return []
        options = self._encode_options(data)
        with torch.no_grad():
            emb = self.model.encode(texts, **options)
        return self._to_list(emb)
|
|
|