File size: 1,186 Bytes
643b8b8
 
 
5280e8d
643b8b8
 
 
 
 
 
 
5280e8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643b8b8
 
 
 
 
 
 
 
5280e8d
643b8b8
 
5280e8d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from pylate import models
from transformers import AutoTokenizer
import torch
import numpy as np

class EndpointHandler:
    """Encode text with a PyLate ColBERT model and return JSON-serialisable output.

    NOTE(review): the class name and ``__init__(self, path="")`` signature look
    like the Hugging Face Inference Endpoints custom-handler convention — confirm
    with the deployment.

    Accepted request payloads:
      - ``{"inputs": <str | list[str]>, "parameters": {...}}``
      - ``{"text": <str | list[str]>, "parameters": {...}}``  (used when "inputs" is absent/None)
      - a bare ``str`` or ``list[str]``

    Optional ``parameters`` keys:
      - ``is_query`` (bool, default True): forwarded to ``model.encode(is_query=...)``.
      - ``batch_size`` (positive int, default 32): forwarded to ``model.encode(batch_size=...)``.
    """

    DEFAULT_IS_QUERY = True
    DEFAULT_BATCH_SIZE = 32

    def __init__(self, path=""):
        """Load the tokenizer and ColBERT model from ``path`` and switch the model to eval mode."""
        # NOTE(review): the tokenizer is never used by this class; it is kept only so
        # external code reading ``self.tokenizer`` keeps working.
        self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
        self.model = models.ColBERT(model_name_or_path=path)
        self.model.eval()

    def _to_list(self, emb):
        """
        Make the output JSON-serialisable:
        – torch.Tensor    ➜  emb.detach().cpu().tolist()
        – np.ndarray      ➜  emb.tolist()
        – NumPy scalar    ➜  emb.item()   (np.float32 etc. are not JSON-serialisable)
        – list / tuple    ➜  recurse, always returning a list
        – anything else   ➜  returned unchanged (assumed to be plain Python already)
        """
        if isinstance(emb, torch.Tensor):
            return emb.detach().cpu().tolist()
        if isinstance(emb, np.ndarray):
            return emb.tolist()
        if isinstance(emb, np.generic):
            return emb.item()
        if isinstance(emb, (list, tuple)):
            return [self._to_list(e) for e in emb]
        return emb  # already plain Python

    def _parse_request(self, data):
        """Split a request payload into ``(texts, is_query, batch_size)``.

        Raises:
            ValueError: a dict payload has neither "inputs" nor "text", or
                ``batch_size`` is not a positive int.
            TypeError: the texts are not a string / list of strings, ``parameters``
                is not a dict, or ``is_query`` is not a bool.
        """
        params = {}
        if isinstance(data, dict):
            # Use explicit None checks so an intentionally empty "inputs" is honoured
            # instead of falling through to "text" or to the whole request dict.
            texts = data.get("inputs")
            if texts is None:
                texts = data.get("text")
            if texts is None:
                raise ValueError("Request payload must contain an 'inputs' or 'text' field")
            params = data.get("parameters")
            if params is None:
                params = {}
            if not isinstance(params, dict):
                raise TypeError("'parameters' must be a JSON object")
        else:
            # Bare string / list payloads are accepted as the texts themselves.
            texts = data

        if isinstance(texts, str):
            texts = [texts]
        elif isinstance(texts, tuple):
            texts = list(texts)
        if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
            raise TypeError("'inputs' must be a string or a list of strings")

        is_query = params.get("is_query", self.DEFAULT_IS_QUERY)
        if not isinstance(is_query, bool):
            # Reject e.g. the string "false", which would otherwise be truthy.
            raise TypeError("'is_query' must be a boolean")

        batch_size = params.get("batch_size", self.DEFAULT_BATCH_SIZE)
        if isinstance(batch_size, bool) or not isinstance(batch_size, int) or batch_size < 1:
            raise ValueError("'batch_size' must be a positive integer")

        return texts, is_query, batch_size

    def __call__(self, data):
        """Encode the texts in ``data`` and return the embeddings as nested Python lists.

        Returns one entry per input text (a single string is treated as a
        one-element list); an empty input list returns ``[]`` without calling
        the model.
        """
        texts, is_query, batch_size = self._parse_request(data)
        if not texts:
            return []

        with torch.no_grad():
            emb = self.model.encode(
                texts,
                is_query=is_query,
                batch_size=batch_size,
            )

        return self._to_list(emb)