|
import os

# These variables are exported before the third-party imports below on purpose:
# the assignments must stay ahead of `import mteb` / `import torch` /
# `import numpy` so the libraries see them when they are first loaded.
# NOTE(review): HF_ENDPOINT is presumably read by huggingface_hub at import
# time and OPENBLAS_NUM_THREADS by OpenBLAS when numpy loads - confirm.
os.environ.update(
    {
        "HF_ENDPOINT": "https://hf-mirror.com",
        "OPENBLAS_NUM_THREADS": "32",
    }
)

import mteb
import torch
import numpy as np
from mteb.encoder_interface import PromptType
from sentence_transformers import SentenceTransformer
|
|
|
|
# Maps each benchmark task name to its task type.  The type decides which
# instruction prompt DeweyWrapper.encode puts in front of the input texts.
TASK_NAME2TYPE = {
    "ArguAna": "Retrieval",
    "ArXivHierarchicalClusteringP2P": "Clustering",
    "ArXivHierarchicalClusteringS2S": "Clustering",
    "AskUbuntuDupQuestions": "Reranking",
    "BIOSSES": "STS",
    "Banking77Classification": "Classification",
    "BiorxivClusteringP2P.v2": "Clustering",
    "CQADupstackGamingRetrieval": "Retrieval",
    "CQADupstackUnixRetrieval": "Retrieval",
    "ClimateFEVERHardNegatives": "Retrieval",
    "FEVERHardNegatives": "Retrieval",
    "FiQA2018": "Retrieval",
    "HotpotQAHardNegatives": "Retrieval",
    "ImdbClassification": "Classification",
    "MTOPDomainClassification": "Classification",
    "MassiveIntentClassification": "Classification",
    "MassiveScenarioClassification": "Classification",
    "MedrxivClusteringP2P.v2": "Clustering",
    "MedrxivClusteringS2S.v2": "Clustering",
    "MindSmallReranking": "Reranking",
    "SCIDOCS": "Retrieval",
    "SICK-R": "STS",
    "STS12": "STS",
    "STS13": "STS",
    "STS14": "STS",
    "STS15": "STS",
    "STSBenchmark": "STS",
    "SprintDuplicateQuestions": "PairClassification",
    "StackExchangeClustering.v2": "Clustering",
    "StackExchangeClusteringP2P.v2": "Clustering",
    "TRECCOVID": "Retrieval",
    "Touche2020Retrieval.v3": "Retrieval",
    "ToxicConversationsClassification": "Classification",
    "TweetSentimentExtractionClassification": "Classification",
    "TwentyNewsgroupsClustering.v2": "Clustering",
    "TwitterSemEval2015": "PairClassification",
    "TwitterURLCorpus": "PairClassification",
    "SummEvalSummarization.v2": "Summarization",
    "AmazonCounterfactualClassification": "Classification",
    "STS17": "STS",
    "STS22.v2": "STS",
}
|
|
|
|
# Instruction prefixes shared by whole task types (see DeweyWrapper.encode).

# Retrieval tasks, query side.
RETRIEVE_Q_PROMPT: str = "<|START_INSTRUCTION|>Answer the question<|END_INSTRUCTION|>"

# Retrieval tasks, every input that is not a query.
RETRIEVE_P_PROMPT: str = "<|START_INSTRUCTION|>Candidate document<|END_INSTRUCTION|>"

# STS, PairClassification, Summarization and Reranking tasks.
STS_PROMPT: str = "<|START_INSTRUCTION|>Generate semantically similar text<|END_INSTRUCTION|>"
|
|
|
|
# Per-task instruction prefixes.  DeweyWrapper.encode looks a task up here
# when its type is neither Retrieval nor one of the STS-style types, i.e. for
# the Classification and Clustering tasks listed in TASK_NAME2TYPE.
TASK_NAME2PROMPT = {
    # --- Classification tasks ---
    "Banking77Classification": (
        "<|START_INSTRUCTION|>Classify text into intents<|END_INSTRUCTION|>"
    ),
    "ImdbClassification": (
        "<|START_INSTRUCTION|>Classify text into sentiment<|END_INSTRUCTION|>"
    ),
    "MTOPDomainClassification": (
        "<|START_INSTRUCTION|>Classify text into intent domain<|END_INSTRUCTION|>"
    ),
    "MassiveIntentClassification": (
        "<|START_INSTRUCTION|>Classify text into user intents<|END_INSTRUCTION|>"
    ),
    "MassiveScenarioClassification": (
        "<|START_INSTRUCTION|>Classify text into user scenarios<|END_INSTRUCTION|>"
    ),
    "ToxicConversationsClassification": (
        "<|START_INSTRUCTION|>Classify text into toxic or not toxic<|END_INSTRUCTION|>"
    ),
    "TweetSentimentExtractionClassification": (
        "<|START_INSTRUCTION|>Classify text into positive, negative, or neutral sentiment<|END_INSTRUCTION|>"
    ),
    "AmazonCounterfactualClassification": (
        "<|START_INSTRUCTION|>Classify text into counterfactual or not-counterfactual<|END_INSTRUCTION|>"
    ),
    # --- Clustering tasks ---
    "ArXivHierarchicalClusteringP2P": (
        "<|START_INSTRUCTION|>Output main and secondary category of Arxiv papers based on the titles and abstracts<|END_INSTRUCTION|>"
    ),
    "ArXivHierarchicalClusteringS2S": (
        "<|START_INSTRUCTION|>Output main and secondary category of Arxiv papers based on the titles<|END_INSTRUCTION|>"
    ),
    "BiorxivClusteringP2P.v2": (
        "<|START_INSTRUCTION|>Output main category of Biorxiv papers based on the titles and abstracts<|END_INSTRUCTION|>"
    ),
    "MedrxivClusteringP2P.v2": (
        "<|START_INSTRUCTION|>Output main category of Medrxiv papers based on the titles and abstracts<|END_INSTRUCTION|>"
    ),
    "MedrxivClusteringS2S.v2": (
        "<|START_INSTRUCTION|>Output main category of Medrxiv papers based on the titles<|END_INSTRUCTION|>"
    ),
    "StackExchangeClustering.v2": (
        "<|START_INSTRUCTION|>Output topic or theme of StackExchange posts based on the titles<|END_INSTRUCTION|>"
    ),
    "StackExchangeClusteringP2P.v2": (
        "<|START_INSTRUCTION|>Output topic or theme of StackExchange posts based on the given paragraphs<|END_INSTRUCTION|>"
    ),
    "TwentyNewsgroupsClustering.v2": (
        "<|START_INSTRUCTION|>Output topic or theme of news articles<|END_INSTRUCTION|>"
    ),
}
|
|
|
|
|
|
class DeweyWrapper:
    """Adapter that exposes a SentenceTransformer checkpoint through ``encode``.

    The constructor loads the checkpoint in bfloat16 on CUDA and starts a
    sentence-transformers multi-process pool.  ``encode`` chooses an
    instruction prompt from the task name / prompt type and encodes through
    that pool.  Call ``close`` when finished so the pool is stopped.
    """

    def __init__(self, model_dir, max_seq_length: int = 1536, batch_size: int = 8):
        """Load the model and start the encoding pool.

        Args:
            model_dir: Model path or name handed to ``SentenceTransformer``.
            max_seq_length: Value assigned to the model's ``max_seq_length``.
            batch_size: Batch size used for every ``encode`` call.
        """
        self.model = (
            SentenceTransformer(
                model_dir,
                # The checkpoint's own modelling code is allowed to run.
                trust_remote_code=True,
                model_kwargs={
                    "torch_dtype": torch.bfloat16,
                    "attn_implementation": "flash_attention_2",
                },
                config_kwargs={"single_vector_type": "cls_add_mean"},
            )
            .cuda()
            .bfloat16()
            .eval()
        )
        self.model.max_seq_length = max_seq_length
        self.pool = self.model.start_multi_process_pool()
        self.batch_size = batch_size

    def close(self) -> None:
        """Stop the multi-process pool started in ``__init__``.

        Calling it more than once is harmless; after the first call
        ``self.pool`` is ``None`` and ``encode`` refuses to run.
        """
        pool = getattr(self, "pool", None)
        if pool is None:
            return
        # Clear the attribute first so a failure while stopping cannot lead
        # to a second stop attempt on the same pool.
        self.pool = None
        self.model.stop_multi_process_pool(pool)

    @staticmethod
    def _select_prompt(task_name: str, prompt_type) -> str:
        """Return the instruction prefix for *task_name* / *prompt_type*.

        Raises:
            KeyError: If the task is missing from ``TASK_NAME2TYPE``, or if
                it needs a per-task prompt and is missing from
                ``TASK_NAME2PROMPT``.
        """
        try:
            task_type = TASK_NAME2TYPE[task_name]
        except KeyError:
            raise KeyError(
                f"Unknown task {task_name!r}: add it to TASK_NAME2TYPE"
            ) from None

        if task_type == "Retrieval":
            # Accept a PromptType member, a plain string, or None.  Only an
            # explicit "query" gets the query prompt; everything else
            # (including a missing prompt type) is encoded as a document.
            role = getattr(prompt_type, "value", prompt_type)
            return RETRIEVE_Q_PROMPT if role == "query" else RETRIEVE_P_PROMPT

        if task_type in ("STS", "PairClassification", "Summarization", "Reranking"):
            return STS_PROMPT

        try:
            return TASK_NAME2PROMPT[task_name]
        except KeyError:
            raise KeyError(
                f"No prompt configured for task {task_name!r} "
                f"(type {task_type!r}): add it to TASK_NAME2PROMPT"
            ) from None

    def encode(
        self,
        sentences: list[str],
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs,
    ) -> np.ndarray:
        """Encode *sentences* with the prompt that matches the task.

        Prompt choice:
          * Retrieval tasks use ``RETRIEVE_Q_PROMPT`` when *prompt_type* is
            "query" and ``RETRIEVE_P_PROMPT`` otherwise (also when
            *prompt_type* is ``None``).
          * STS, PairClassification, Summarization and Reranking tasks use
            ``STS_PROMPT``.
          * Every other task uses its entry in ``TASK_NAME2PROMPT``.

        Args:
            sentences: Texts to embed.
            task_name: Task name; must be a key of ``TASK_NAME2TYPE``.
            prompt_type: Query/document marker, only consulted for
                Retrieval tasks.
            **kwargs: Accepted for interface compatibility and ignored; the
                batch size always comes from the constructor.

        Returns:
            The result of ``encode_multi_process`` called with
            ``normalize_embeddings=True`` and ``precision="float32"``.

        Raises:
            KeyError: If *task_name* has no type or no prompt configured.
            RuntimeError: If ``close`` has already been called.
        """
        prompt = self._select_prompt(task_name, prompt_type)
        if self.pool is None:
            raise RuntimeError("DeweyWrapper.encode() called after close()")
        return self.model.encode_multi_process(
            sentences=sentences,
            pool=self.pool,
            show_progress_bar=True,
            batch_size=self.batch_size,
            normalize_embeddings=True,
            prompt=prompt,
            precision="float32",
        )
|
|
|
|
|
|
def main() -> None:
    """Run the "MTEB(eng, v2)" benchmark for ``infgrad/dewey_en_beta``.

    Results are written under ``./mteb_eng_results/dewey_en_beta``; tasks
    that already have results there are not overwritten.
    """
    max_seq_length = 1536
    batch_size = 8
    model_dir_or_name = "infgrad/dewey_en_beta"
    output_folder = "./mteb_eng_results/dewey_en_beta"

    model = DeweyWrapper(
        model_dir_or_name,
        max_seq_length=max_seq_length,
        batch_size=batch_size,
    )
    try:
        tasks = list(mteb.get_benchmark("MTEB(eng, v2)"))
        evaluation = mteb.MTEB(tasks=tasks)
        evaluation.run(
            model,
            output_folder=output_folder,
            verbosity=2,
            overwrite_results=False,
        )
    finally:
        # Stop the encoding pool whether the run succeeded or raised, so the
        # worker processes do not outlive the evaluation.
        if model.pool is not None:
            model.model.stop_multi_process_pool(model.pool)


if __name__ == "__main__":
    main()
|
|
|