File size: 11,425 Bytes
c71eb7c
914c8c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c71eb7c
914c8c8
 
c71eb7c
914c8c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
---

language:
- ru
- en

pipeline_tag: sentence-similarity

tags:
- russian
- pretraining
- embeddings
- feature-extraction
- sentence-similarity
- sentence-transformers
- transformers

datasets:
- IlyaGusev/gazeta
- zloelias/lenta-ru
- HuggingFaceFW/fineweb-2
- HuggingFaceFW/fineweb

license: mit
base_model: sergeyzh/LaBSE-ru-turbo

---


## BERTA

Модель для расчетов эмбеддингов предложений на русском и английском языках получена методом дистилляции эмбеддингов [ai-forever/FRIDA](https://huggingface.co/ai-forever/FRIDA) (размер эмбеддингов - 1536, слоёв - 24) в [sergeyzh/LaBSE-ru-turbo](https://huggingface.co/sergeyzh/LaBSE-ru-turbo) (размер эмбеддингов - 768, слоёв - 12). Основной режим использования FRIDA - CLS pooling заменен на mean pooling. Каких-либо других  изменений поведения модели не производилось. Дистиляция выполнена в максимально возможном объеме - эмбеддинги русских и английских предложений, работа префиксов. 

Размер контекста модели соответствует FRIDA - 512 токенов.

## Префиксы
Все префиксы унаследованы от FRIDA. 
Оптимальный (обеспечивающий средние результаты) префикс для большинства задач - "categorize_entailment: " прописан по умолчанию в [config_sentence_transformers.json](https://huggingface.co/sergeyzh/BERTA/blob/main/config_sentence_transformers.json)



Перечень используемых префиксов и их влияние на оценки модели в [encodechka](https://github.com/avidale/encodechka):



| Префикс                | STS       | PI        | NLI       | SA        | TI        |

|:-----------------------|:---------:|:---------:|:---------:|:---------:|:---------:|

| -                      |   0,842   |   0,757   |   0,463   | **0,830** |   0,985   |

| search_query:          |   0,853   |   0,767   |   0,479   |   0,825   |   0,987   |
| search_document:       |   0,831   |   0,749   |   0,463   |   0,817   |   0,986   |

| paraphrase:            |   0,847   | **0,778** |   0,446   |   0,825   |   0,986   |

| categorize:            | **0,857** |   0,765   |   0,501   |   0,829   | **0,988** |

| categorize_sentiment:  |   0,589   |   0,535   |   0,417   |   0,805   |   0,982   |
| categorize_topic:      |   0,740   |   0,521   |   0,396   |   0,770   |   0,982   |

| categorize_entailment: |   0,841   |   0,762   | **0,571** |   0,827   |   0,986   |


**Задачи:**

- Semantic text similarity (**STS**);
- Paraphrase identification (**PI**);
- Natural language inference (**NLI**);
- Sentiment analysis (**SA**);
- Toxicity identification (**TI**).



# Метрики
Оценки модели на бенчмарке [ruMTEB](https://habr.com/ru/companies/sberdevices/articles/831150/):

|Model Name                      | Metric              | FRIDA     | BERTA     | [rubert-mini-frida](https://huggingface.co/sergeyzh/rubert-mini-frida)   | multilingual-e5-large-instruct | multilingual-e5-large |
|:-------------------------------|:--------------------|----------:|----------:|--------------------:|---------------------:|----------------------:|
|CEDRClassification              | Accuracy            | **0.646** |   0.622   |        0.552        |        0.500         |         0.448         |
|GeoreviewClassification         | Accuracy            | **0.577** |   0.548   |        0.464        |        0.559         |         0.497         |
|GeoreviewClusteringP2P          | V-measure           | **0.783** |   0.738   |        0.698        |        0.743         |         0.605         |
|HeadlineClassification          | Accuracy            |   0.890   | **0.891** |        0.880        |        0.862         |         0.758         |
|InappropriatenessClassification | Accuracy            | **0.783** |   0.748   |        0.698        |        0.655         |         0.616         |
|KinopoiskClassification         | Accuracy            | **0.705** |   0.678   |        0.595        |        0.661         |         0.566         |
|RiaNewsRetrieval                | NDCG@10             | **0.868** |   0.816   |        0.721        |        0.824         |         0.807         |
|RuBQReranking                   | MAP@10              | **0.771** |   0.752   |        0.711        |        0.717         |         0.756         |
|RuBQRetrieval                   | NDCG@10             |   0.724   |   0.710   |        0.654        |        0.692         |       **0.741**       |
|RuReviewsClassification         | Accuracy            | **0.751** |   0.723   |        0.658        |        0.686         |         0.653         |
|RuSTSBenchmarkSTS               | Pearson correlation |   0.814   |   0.822   |        0.803        |      **0.840**       |         0.831         |
|RuSciBenchGRNTIClassification   | Accuracy            | **0.699** |   0.690   |        0.625        |        0.651         |         0.582         |
|RuSciBenchGRNTIClusteringP2P    | V-measure           | **0.670** |   0.650   |        0.586        |        0.622         |         0.520         |
|RuSciBenchOECDClassification    | Accuracy            |   0.546   | **0.555** |        0.493        |        0.502         |         0.445         |
|RuSciBenchOECDClusteringP2P     | V-measure           | **0.566** |   0.556   |        0.507        |        0.528         |         0.450         |
|SensitiveTopicsClassification   | Accuracy            |   0.398   | **0.399** |        0.373        |        0.323         |         0.257         |
|TERRaClassification             | Average Precision   | **0.665** |   0.657   |        0.606        |        0.639         |         0.584         |
								 

|Model Name                      | Metric              | FRIDA     | BERTA     | rubert-mini-frida   | multilingual-e5-large-instruct | multilingual-e5-large |

|:-------------------------------|:--------------------|----------:|----------:|--------------------:|----------------------:|---------------------:|

|Classification                  | Accuracy            | **0.707** |   0.698   |        0.631        |        0.654          |        0.588         |

|Clustering                      | V-measure           | **0.673** |   0.648   |        0.597        |        0.631          |        0.525         |

|MultiLabelClassification        | Accuracy            | **0.522** |   0.510   |        0.463        |        0.412          |        0.353         |

|PairClassification              | Average Precision   | **0.665** |   0.657   |        0.606        |        0.639          |        0.584         |

|Reranking                       | MAP@10              | **0.771** |   0.752   |        0.711        |        0.717          |        0.756         |

|Retrieval                       | NDCG@10             | **0.796** |   0.763   |        0.687        |        0.758          |        0.774         |

|STS                             | Pearson correlation |   0.814   |   0.822   |        0.803        |      **0.840**        |        0.831         |

|Average                         | Average             | **0.707** |   0.693   |        0.643        |        0.664          |        0.630         |




## Использование модели с библиотекой `transformers`:

```python

import torch

import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel





def pool(hidden_state, mask, pooling_method="mean"):

    if pooling_method == "mean":

        s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)

        d = mask.sum(axis=1, keepdim=True).float()

        return s / d

    elif pooling_method == "cls":

        return hidden_state[:, 0]



inputs = [

    # 

    "paraphrase: В Ярославской области разрешили работу бань, но без посетителей",

    "categorize_entailment: Женщину доставили в больницу, за ее жизнь сейчас борются врачи.",

    "search_query: Сколько программистов нужно, чтобы вкрутить лампочку?",

    # 

    "paraphrase: Ярославским баням разрешили работать без посетителей",

    "categorize_entailment: Женщину спасают врачи.",

    "search_document: Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование."

]



tokenizer = AutoTokenizer.from_pretrained("sergeyzh/BERTA")

model = AutoModel.from_pretrained("sergeyzh/BERTA")



tokenized_inputs = tokenizer(inputs, max_length=512, padding=True, truncation=True, return_tensors="pt")



with torch.no_grad():

    outputs = model(**tokenized_inputs)

    

embeddings = pool(

    outputs.last_hidden_state, 

    tokenized_inputs["attention_mask"],

    pooling_method="mean"

)



embeddings = F.normalize(embeddings, p=2, dim=1)

sim_scores = embeddings[:3] @ embeddings[3:].T

print(sim_scores.diag().tolist())

# [0.9530372023582458, 0.866746723651886,  0.7839133143424988]

# [0.9360030293464661, 0.8591322302818298, 0.728583037853241] - FRIDA

```

## Использование с `sentence_transformers` (sentence-transformers>=2.4.0):



```python

from sentence_transformers import SentenceTransformer

# loads model with mean pooling
model = SentenceTransformer("sergeyzh/BERTA")

paraphrase = model.encode(["В Ярославской области разрешили работу бань, но без посетителей", "Ярославским баням разрешили работать без посетителей"], prompt="paraphrase: ")
print(paraphrase[0] @ paraphrase[1].T) 
# 0.9530372
# 0.9360032 - FRIDA

categorize_entailment = model.encode(["Женщину доставили в больницу, за ее жизнь сейчас борются врачи.", "Женщину спасают врачи."], prompt="categorize_entailment: ")
print(categorize_entailment[0] @ categorize_entailment[1].T) 
# 0.8667469
# 0.8591322 - FRIDA

query_embedding = model.encode("Сколько программистов нужно, чтобы вкрутить лампочку?", prompt="search_query: ")
document_embedding = model.encode("Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование.", prompt="search_document: ")
print(query_embedding @ document_embedding.T) 
# 0.7839136
# 0.7285831 - FRIDA
```