metadata
license: mit
language:
- ru
- en
tags:
- PyTorch
- Transformers
ru-en RoBERTa large model for Sentence Embeddings in Russian and English.
The model is described in this article.
Russian MTEB metrics
For better quality, use CLS token embeddings. Also, prepend the following prefixes to the input texts, depending on the task:
- Asymmetric retrieval tasks such as search or question answering: "search_query: " for queries and "search_document: " for documents.
- NLI, NLU, and paraphrasing tasks: "classification: ".
- Title-to-body/abstract tasks and clustering: "clustering: ".
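The prefix rules above can be applied with a small helper before tokenization. This is a minimal sketch: `TASK_PREFIXES` and `add_prefix` are hypothetical names introduced here, and the dictionary keys are labels chosen for this example; only the prefix strings themselves come from this card. Note that each prefix ends with a space after the colon.

```python
# Hypothetical mapping from task labels (chosen for this sketch) to the
# prefix strings recommended in this model card.
TASK_PREFIXES = {
    "query": "search_query: ",
    "document": "search_document: ",
    "classification": "classification: ",
    "clustering": "clustering: ",
}


def add_prefix(texts, task):
    """Return a new list with the prefix for `task` prepended to each text."""
    if task not in TASK_PREFIXES:
        raise ValueError(f"unknown task {task!r}; expected one of {sorted(TASK_PREFIXES)}")
    prefix = TASK_PREFIXES[task]
    return [prefix + text for text in texts]


# Queries and documents get different prefixes in the asymmetric retrieval setup.
queries = add_prefix(["How old is the Earth?"], "query")
documents = add_prefix(["The Earth is about 4.5 billion years old."], "document")
print(queries[0])  # search_query: How old is the Earth?
```

The prefixed lists can then be passed to the tokenizer in place of raw sentences.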
Usage (HuggingFace Models Repository)
You can use the model directly from the model repository to compute sentence embeddings:
from transformers import AutoTokenizer, AutoModel
import torch
# There are two ways to build sentence embeddings from the model output:
# the CLS token embedding or mean pooling.
# Choose the pooling that gives the best quality on your downstream task.

# Mean pooling: take the attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask
# Sentences we want sentence embeddings for
# ("Hi! How are you?", "Is it true that 42 is your favorite number?")
sentences = ['Привет! Как твои дела?',
             'А правда, что 42 твое любимое число?']

# Load the tokenizer and model from the Hugging Face model repository
tokenizer = AutoTokenizer.from_pretrained("ai-forever/ru-en-RoSBERTa")
model = AutoModel.from_pretrained("ai-forever/ru-en-RoSBERTa")
#Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=512, return_tensors='pt')
#Compute token embeddings
with torch.no_grad():
    model_output = model(**encoded_input)
# Option 1: mean pooling
sentence_mean_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

# Option 2: CLS "pooling" (the embedding of the first token in each sequence)
last_hidden_states = model_output[0]
sentence_cls_embeddings = last_hidden_states[:, 0]
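Once sentence embeddings are computed, a common next step in the search_query/search_document setup is to rank documents by cosine similarity. The sketch below uses small random stand-in tensors instead of real model outputs, so it runs without downloading the model (it still requires PyTorch). `mean_pooling` mirrors the function above, while `cls_pooling` and `cosine_scores` are hypothetical helper names. L2-normalizing the embeddings and taking dot products is one standard way to compute cosine similarity; it is not necessarily the exact scoring used by the model authors.

```python
import torch
import torch.nn.functional as F


def mean_pooling(model_output, attention_mask):
    # Same as above: average token embeddings, ignoring padded positions.
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)


def cls_pooling(model_output):
    # Hypothetical helper: embedding of the first token of each sequence
    # (what this card calls the CLS token embedding).
    return model_output[0][:, 0]


def cosine_scores(query_emb, doc_emb):
    # Hypothetical helper: after L2 normalization, a matrix product gives
    # result[i, j] = cosine similarity of query i and document j.
    return F.normalize(query_emb, p=2, dim=1) @ F.normalize(doc_emb, p=2, dim=1).T


# Stand-in for model(**encoded_input): a tuple whose first element has shape
# (batch, seq_len, hidden). Real model outputs have a much larger hidden size.
torch.manual_seed(0)
hidden = torch.randn(2, 4, 8)
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])  # 2nd sentence is padded
model_output = (hidden,)

mean_emb = mean_pooling(model_output, attention_mask)
cls_emb = cls_pooling(model_output)

# Padding is excluded: the 2nd sentence's mean uses only its first 2 tokens.
assert torch.allclose(mean_emb[1], hidden[1, :2].mean(dim=0))

scores = cosine_scores(cls_emb[:1], cls_emb)  # query = sentence 0, docs = both sentences
print(scores)  # shape (1, 2); scores[0, 0] is ~1.0 (a vector compared with itself)
```

With real embeddings, `cosine_scores` would take the pooled query embeddings and the pooled document embeddings, and `scores.argsort(dim=1, descending=True)` gives the document ranking for each query.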