# dewey_en_beta/scripts/evaluate/run_evaluate_mteb_dewey_en_beta.py
# Uploaded by infgrad ("Upload 5 files")
# Commit fbc1304 (verified)
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["OPENBLAS_NUM_THREADS"] = "32"
import mteb
import torch
import numpy as np
from mteb.encoder_interface import PromptType
from sentence_transformers import SentenceTransformer
# Maps each benchmark task name to its task type. DeweyWrapper.encode looks the
# task type up here to decide which instruction prompt to prepend.
TASK_NAME2TYPE = {
    "ArguAna": "Retrieval",
    "ArXivHierarchicalClusteringP2P": "Clustering",
    "ArXivHierarchicalClusteringS2S": "Clustering",
    "AskUbuntuDupQuestions": "Reranking",
    "BIOSSES": "STS",
    "Banking77Classification": "Classification",
    "BiorxivClusteringP2P.v2": "Clustering",
    "CQADupstackGamingRetrieval": "Retrieval",
    "CQADupstackUnixRetrieval": "Retrieval",
    "ClimateFEVERHardNegatives": "Retrieval",
    "FEVERHardNegatives": "Retrieval",
    "FiQA2018": "Retrieval",
    "HotpotQAHardNegatives": "Retrieval",
    "ImdbClassification": "Classification",
    "MTOPDomainClassification": "Classification",
    "MassiveIntentClassification": "Classification",
    "MassiveScenarioClassification": "Classification",
    "MedrxivClusteringP2P.v2": "Clustering",
    "MedrxivClusteringS2S.v2": "Clustering",
    "MindSmallReranking": "Reranking",
    "SCIDOCS": "Retrieval",
    "SICK-R": "STS",
    "STS12": "STS",
    "STS13": "STS",
    "STS14": "STS",
    "STS15": "STS",
    "STSBenchmark": "STS",
    "SprintDuplicateQuestions": "PairClassification",
    "StackExchangeClustering.v2": "Clustering",
    "StackExchangeClusteringP2P.v2": "Clustering",
    "TRECCOVID": "Retrieval",
    "Touche2020Retrieval.v3": "Retrieval",
    "ToxicConversationsClassification": "Classification",
    "TweetSentimentExtractionClassification": "Classification",
    "TwentyNewsgroupsClustering.v2": "Clustering",
    "TwitterSemEval2015": "PairClassification",
    "TwitterURLCorpus": "PairClassification",
    "SummEvalSummarization.v2": "Summarization",
    "AmazonCounterfactualClassification": "Classification",
    "STS17": "STS",
    "STS22.v2": "STS",
}
# Instruction prompts shared by whole task types (see DeweyWrapper.encode).
# Used for Retrieval inputs whose prompt type is "query".
RETRIEVE_Q_PROMPT = "<|START_INSTRUCTION|>Answer the question<|END_INSTRUCTION|>"
# Used for every other Retrieval input.
RETRIEVE_P_PROMPT = "<|START_INSTRUCTION|>Candidate document<|END_INSTRUCTION|>"
# Used for STS, PairClassification, Summarization and Reranking tasks.
STS_PROMPT = "<|START_INSTRUCTION|>Generate semantically similar text<|END_INSTRUCTION|>"
# Per-task instruction prompts. DeweyWrapper.encode falls back to this table
# for any task whose type has no shared prompt (the Classification and
# Clustering tasks listed below).
TASK_NAME2PROMPT = {
    # --- Classification ---
    "Banking77Classification":
        "<|START_INSTRUCTION|>Classify text into intents<|END_INSTRUCTION|>",
    "ImdbClassification":
        "<|START_INSTRUCTION|>Classify text into sentiment<|END_INSTRUCTION|>",
    "MTOPDomainClassification":
        "<|START_INSTRUCTION|>Classify text into intent domain<|END_INSTRUCTION|>",
    "MassiveIntentClassification":
        "<|START_INSTRUCTION|>Classify text into user intents<|END_INSTRUCTION|>",
    "MassiveScenarioClassification":
        "<|START_INSTRUCTION|>Classify text into user scenarios<|END_INSTRUCTION|>",
    "ToxicConversationsClassification":
        "<|START_INSTRUCTION|>Classify text into toxic or not toxic<|END_INSTRUCTION|>",
    "TweetSentimentExtractionClassification":
        "<|START_INSTRUCTION|>Classify text into positive, negative, or neutral sentiment<|END_INSTRUCTION|>",
    "AmazonCounterfactualClassification":
        "<|START_INSTRUCTION|>Classify text into counterfactual or not-counterfactual<|END_INSTRUCTION|>",
    # --- Clustering ---
    "ArXivHierarchicalClusteringP2P":
        "<|START_INSTRUCTION|>Output main and secondary category of Arxiv papers based on the titles and abstracts<|END_INSTRUCTION|>",
    "ArXivHierarchicalClusteringS2S":
        "<|START_INSTRUCTION|>Output main and secondary category of Arxiv papers based on the titles<|END_INSTRUCTION|>",
    "BiorxivClusteringP2P.v2":
        "<|START_INSTRUCTION|>Output main category of Biorxiv papers based on the titles and abstracts<|END_INSTRUCTION|>",
    "MedrxivClusteringP2P.v2":
        "<|START_INSTRUCTION|>Output main category of Medrxiv papers based on the titles and abstracts<|END_INSTRUCTION|>",
    "MedrxivClusteringS2S.v2":
        "<|START_INSTRUCTION|>Output main category of Medrxiv papers based on the titles<|END_INSTRUCTION|>",
    "StackExchangeClustering.v2":
        "<|START_INSTRUCTION|>Output topic or theme of StackExchange posts based on the titles<|END_INSTRUCTION|>",
    "StackExchangeClusteringP2P.v2":
        "<|START_INSTRUCTION|>Output topic or theme of StackExchange posts based on the given paragraphs<|END_INSTRUCTION|>",
    "TwentyNewsgroupsClustering.v2":
        "<|START_INSTRUCTION|>Output topic or theme of news articles<|END_INSTRUCTION|>",
}
class DeweyWrapper:
    """Encoder wrapper that feeds a Dewey SentenceTransformer model to MTEB.

    The wrapper loads the model in bfloat16 on CUDA, starts a
    sentence-transformers multi-process pool, and prepends a task-dependent
    instruction prompt to every encoded batch.

    Call :meth:`close` when finished to stop the worker pool.
    """

    def __init__(self, model_dir, max_seq_length: int = 1536, batch_size: int = 8):
        """Load the model and start the multi-process encoding pool.

        Args:
            model_dir: Local directory or hub name passed to
                ``SentenceTransformer``.
            max_seq_length: Value assigned to the model's ``max_seq_length``.
            batch_size: Batch size forwarded to ``encode_multi_process``.
        """
        self.model = SentenceTransformer(
            model_dir,
            trust_remote_code=True,
            model_kwargs={
                "torch_dtype": torch.bfloat16,  # fp16 is prone to producing NaN
                "attn_implementation": "flash_attention_2",
            },
            config_kwargs={"single_vector_type": "cls_add_mean"},
        ).cuda().bfloat16().eval()
        self.model.max_seq_length = max_seq_length
        self.pool = self.model.start_multi_process_pool()
        self.batch_size = batch_size

    def close(self) -> None:
        """Stop the worker pool started in ``__init__``; safe to call twice.

        After closing, :meth:`encode` can no longer be used on this instance.
        """
        pool = getattr(self, "pool", None)
        self.pool = None
        if pool is not None:
            self.model.stop_multi_process_pool(pool)

    @staticmethod
    def _select_prompt(task_name: str, prompt_type) -> str:
        """Return the instruction prompt for ``task_name`` / ``prompt_type``."""
        try:
            task_type = TASK_NAME2TYPE[task_name]
        except KeyError:
            raise KeyError(
                f"Unknown task {task_name!r}: add it to TASK_NAME2TYPE"
            ) from None

        if task_type == "Retrieval":
            # ``prompt_type`` is declared optional, so do not dereference it
            # blindly: accept an enum member, a plain string, or None. Anything
            # that is not explicitly a query gets the document prompt.
            kind = getattr(prompt_type, "value", prompt_type)
            return RETRIEVE_Q_PROMPT if kind == "query" else RETRIEVE_P_PROMPT
        if task_type in ("STS", "PairClassification", "Summarization", "Reranking"):
            return STS_PROMPT
        # Remaining task types (Classification, Clustering) use per-task prompts.
        return TASK_NAME2PROMPT[task_name]

    def encode(
        self,
        sentences: list[str],
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs,
    ) -> np.ndarray:
        """Encode ``sentences`` with the instruction prompt matching the task.

        Args:
            sentences: Texts to embed.
            task_name: Benchmark task name; must be a key of
                ``TASK_NAME2TYPE``.
            prompt_type: For Retrieval tasks, selects the query prompt when its
                value is ``"query"``; any other value (including ``None``)
                selects the document prompt. Ignored for other task types.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The embeddings returned by ``encode_multi_process``, requested with
            ``normalize_embeddings=True`` and ``precision="float32"``.

        Raises:
            KeyError: If ``task_name`` is not in ``TASK_NAME2TYPE``, or it is a
                Classification/Clustering task missing from
                ``TASK_NAME2PROMPT``.
        """
        prompt = self._select_prompt(task_name, prompt_type)
        vectors = self.model.encode_multi_process(
            sentences=sentences,
            pool=self.pool,
            show_progress_bar=True,
            batch_size=self.batch_size,
            normalize_embeddings=True,
            prompt=prompt,
            precision="float32",
        )
        return vectors
if __name__ == "__main__":
    # Run the "MTEB(eng, v2)" benchmark for the Dewey model and write the
    # results under ``output_folder``.
    max_seq_length = 1536
    batch_size = 8
    model_dir_or_name = "infgrad/dewey_en_beta"
    output_folder = "./mteb_eng_results/dewey_en_beta"
    model = DeweyWrapper(model_dir_or_name, max_seq_length=max_seq_length, batch_size=batch_size)
    try:
        tasks = list(mteb.get_benchmark("MTEB(eng, v2)"))
        evaluation = mteb.MTEB(tasks=tasks)
        evaluation.run(model, output_folder=output_folder, verbosity=2, overwrite_results=False)
    finally:
        # Always stop the encoder worker processes, even if evaluation fails.
        model.model.stop_multi_process_pool(model.pool)