import torch
from typing import Optional
from pydantic import BaseModel
from sentence_transformers.models import Transformer as BaseTransformer


class TextSpan(BaseModel):
    """A half-open token span [s, e) routed to the pooling head named by module_name."""
    s: int
    e: int
    module_name: str
    text: Optional[str] = None


class DeweyTransformer(BaseTransformer):
    def __init__(
            self,
            model_name_or_path: str,
            **kwargs,
    ):
        # Pooling strategy, read from config_args: "cls", "mean", or "cls_add_mean".
        self.single_vector_type = kwargs.get("config_args", {}).get("single_vector_type", "mean")
        super().__init__(model_name_or_path, **kwargs)

    def forward(
            self, features: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        prompt_length = features.get("prompt_length", 0)
        if prompt_length > 0:
            # In ModernBERT the text is surrounded by [CLS] and [SEP], so the
            # leading [CLS] is already counted in prompt_length; drop one token.
            prompt_length -= 1
        batch_text_spans = []
        # Build one span list per sequence; data_len is the unpadded token count.
        for data_len in features["attention_mask"].sum(dim=1):
            if self.single_vector_type == "cls":
                batch_text_spans.append(
                    [
                        TextSpan(s=0, e=1, module_name="cls_linear")
                    ]
                )
            elif self.single_vector_type == "mean":
                batch_text_spans.append(
                    [
                        TextSpan(s=1 + prompt_length, e=data_len - 1, module_name="chunk_linear")
                    ]
                )
            elif self.single_vector_type == "cls_add_mean":
                batch_text_spans.append(
                    [
                        TextSpan(s=0, e=1, module_name="cls_linear"),
                        TextSpan(s=1 + prompt_length, e=data_len - 1, module_name="chunk_linear")
                    ]
                )
            else:
                raise ValueError(
                    f"single_vector_type must be one of 'cls', 'mean' or 'cls_add_mean', "
                    f"got {self.single_vector_type!r}"
                )

        trans_features = {
            "input_ids": features["input_ids"],
            "attention_mask": features["attention_mask"],
            "batch_text_spans": batch_text_spans,
            "normalize_embeddings": self.single_vector_type == "cls_add_mean",
        }
        vectors_list = self.auto_model(**trans_features, **kwargs)
        # Mean-pool each sequence's span vectors into a single sentence embedding.
        sentence_embedding = torch.cat(
            [vecs.mean(dim=0, keepdim=True) for vecs in vectors_list],
            dim=0
        )
        features.update({"sentence_embedding": sentence_embedding})
        return features