This is a cross-encoder model trained to predict semantic equivalence of two Russian sentences.
It classifies text pairs as paraphrases (class 1) or non-paraphrases (class 0). Its scores can be used as a metric of content preservation for paraphrasing or text style transfer.
It is a sberbank-ai/ruRoberta-large model fine-tuned on a union of 3 datasets:
- RuPAWS: https://github.com/ivkrotova/rupaws_dataset, based on Quora and QQP;
- ru_paraphraser: https://huggingface.co/merionum/ru_paraphraser;
- Results of the manual check of content preservation for the RUSSE-2022 text detoxification dataset collection (content_5.tsv).
The task was formulated as binary classification: whether the two sentences have the same meaning (1) or different (0).
The table shows the training dataset size after duplication (joining text1 + text2 and text2 + text1 pairs):
source \ label | 0 | 1 |
---|---|---|
detox | 1412 | 3843 |
paraphraser | 5539 | 1688 |
rupaws_qqp | 1112 | 792 |
rupaws_wiki | 3526 | 2166 |
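The duplication mentioned above the table, where every pair is kept in both argument orders, can be sketched as follows. This is a minimal illustration, not the original preprocessing script: the helper name `duplicate_pairs` and the `(text1, text2, label)` row layout are assumptions, and the toy rows reuse the sentences from the usage example with illustrative labels.

```python
def duplicate_pairs(rows):
    """Return every (text1, text2, label) row in both argument orders."""
    augmented = []
    for text1, text2, label in rows:
        augmented.append((text1, text2, label))
        augmented.append((text2, text1, label))  # swapped order, same label
    return augmented


# Toy rows for illustration only; they are not taken from the training data.
rows = [
    ("Я тебя люблю", "Ты мне нравишься", 1),
    ("Я тебя люблю", "Я тебя ненавижу", 0),
]
augmented = duplicate_pairs(rows)
print(len(rows), "->", len(augmented))
```

Since paraphrase is a symmetric relation, showing both orders to a cross-encoder is a natural way to discourage it from depending on which sentence comes first.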
The model was trained with the Adam optimizer and the following hyperparameters:
- learning_rate = 1e-5
- batch_size = 8
- gradient_accumulation_steps = 4
- n_epochs = 3
- max_grad_norm = 1.0
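With gradient accumulation, gradients from several consecutive batches are summed before each optimizer update, so the effective batch size is the product of the two settings. The arithmetic below is a rough sketch under stated assumptions: that the sizes in the training table cover the whole training set, and that the last incomplete batch of each epoch is kept. The training script itself is not shown in this card, so the step counts are estimates only.

```python
import math

batch_size = 8
gradient_accumulation_steps = 4
n_epochs = 3

# Training-set sizes from the table above: (label 0, label 1) per source.
train_sizes = {
    "detox": (1412, 3843),
    "paraphraser": (5539, 1688),
    "rupaws_qqp": (1112, 792),
    "rupaws_wiki": (3526, 2166),
}

effective_batch_size = batch_size * gradient_accumulation_steps
n_examples = sum(neg + pos for neg, pos in train_sizes.values())

# Assumption: the last, incomplete batch of each epoch is kept, not dropped.
micro_batches_per_epoch = math.ceil(n_examples / batch_size)
optimizer_steps_per_epoch = math.ceil(
    micro_batches_per_epoch / gradient_accumulation_steps
)

print("effective batch size:", effective_batch_size)
print("training pairs:", n_examples)
print("estimated optimizer steps:", optimizer_steps_per_epoch * n_epochs)
```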
After training, the model had the following ROC AUC scores on the test sets:
set | ROC AUC |
---|---|
detox | 0.857112 |
paraphraser | 0.858465 |
rupaws_qqp | 0.859195 |
rupaws_wiki | 0.906121 |
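ROC AUC does not depend on a decision threshold: it equals the probability that a randomly chosen paraphrase pair receives a higher score than a randomly chosen non-paraphrase pair, with ties counted as one half. The sketch below computes it directly from that definition. The function name `roc_auc` and the toy labels and scores are illustrative and are not the evaluation data; the quadratic loop is for clarity, and in practice `sklearn.metrics.roc_auc_score` is the usual choice.

```python
def roc_auc(labels, scores):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("ROC AUC needs at least one example of each class")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0   # positive ranked above negative
            elif p == n:
                wins += 0.5   # ties count as half
    return wins / (len(positives) * len(negatives))


# Toy data for illustration only.
labels = [1, 1, 0, 0, 1]
scores = [0.98, 0.40, 0.35, 0.60, 0.75]
auc = roc_auc(labels, scores)
print(round(auc, 4))
```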
Example usage:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained('SkolkovoInstitute/ruRoberta-large-paraphrase-v1')
tokenizer = AutoTokenizer.from_pretrained('SkolkovoInstitute/ruRoberta-large-paraphrase-v1')

def get_similarity(text1, text2):
    """ Predict the probability that two Russian sentences are paraphrases of each other. """
    with torch.inference_mode():
        batch = tokenizer(
            text1, text2,
            truncation=True, max_length=model.config.max_position_embeddings, return_tensors='pt',
        ).to(model.device)
        proba = torch.softmax(model(**batch).logits, -1)
    return proba[0][1].item()

print(get_similarity('Я тебя люблю', 'Ты мне нравишься'))  # 0.9798
print(get_similarity('Я тебя люблю', 'Я тебя ненавижу'))  # 0.0008
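A cross-encoder reads the two sentences in a fixed order, so swapping the arguments may change the score slightly. One possible way to obtain an order-independent content preservation score is to average both orders and, if a hard label is needed, compare the result with a threshold. The sketch below is an assumption rather than a procedure described in this card: `symmetric_similarity`, `is_paraphrase`, and `toy_score` are hypothetical names, and `toy_score` is a word-overlap stand-in so that the example runs without downloading the model. In real use, pass `get_similarity` as `score_fn`.

```python
def symmetric_similarity(score_fn, text1, text2):
    """Average a pairwise score over both argument orders."""
    return (score_fn(text1, text2) + score_fn(text2, text1)) / 2


def is_paraphrase(score_fn, text1, text2, threshold=0.5):
    """Turn the averaged score into a class label (threshold is an assumption)."""
    return symmetric_similarity(score_fn, text1, text2) >= threshold


def toy_score(text1, text2):
    """Stand-in scorer: share of words of text1 that also occur in text2."""
    words1, words2 = text1.split(), set(text2.split())
    return sum(w in words2 for w in words1) / len(words1)


score = symmetric_similarity(toy_score, "я тебя люблю", "я тебя очень люблю")
print(score, is_paraphrase(toy_score, "я тебя люблю", "я тебя очень люблю"))
```

For a single call to `get_similarity`, a threshold of 0.5 on the class 1 probability matches taking the more probable of the two classes; other thresholds trade precision against recall.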