Upload 5 files
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+imgs/inference_architecture.png filter=lfs diff=lfs merge=lfs -text
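The added attribute line is the standard entry Git LFS writes when a path is tracked. A hypothetical way to reproduce the same change from Python, kept in the repository's language (assumes git and git-lfs are installed and the working directory is the repository root; not part of this commit):

import subprocess

# Ask Git LFS to track the new image; this appends the filter line shown above to .gitattributes.
subprocess.run(["git", "lfs", "track", "imgs/inference_architecture.png"], check=True)
# Stage both the updated .gitattributes and the image itself.
subprocess.run(["git", "add", ".gitattributes", "imgs/inference_architecture.png"], check=True)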
imgs/inference_architecture.png
ADDED
[binary image file, stored via Git LFS]
scripts/evaluate/run_evaluate_long_embed.py
ADDED
@@ -0,0 +1,161 @@
import os

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["OPENBLAS_NUM_THREADS"] = "32"
import numpy as np
import torch
import mteb
from mteb.encoder_interface import PromptType
from sentence_transformers import SentenceTransformer
from mteb.models.wrapper import Wrapper
from typing import Any, Sequence
from transformers import AutoTokenizer, AutoModel


class DeweySingleVectorWrapper:
    def __init__(self, model_dir, batch_size: int = 8):
        self.model = SentenceTransformer(
            model_dir,
            trust_remote_code=True,
            model_kwargs={
                "torch_dtype": torch.bfloat16,  # fp16 easily produces NaN, so use bf16
                "attn_implementation": "flash_attention_2"
            },
            config_kwargs={"single_vector_type": "mean"}
        ).cuda().bfloat16().eval()
        # max_seq_length is the module-level value set in __main__ before instantiation
        self.model.max_seq_length = max_seq_length
        self.pool = self.model.start_multi_process_pool()
        self.batch_size = batch_size

    def encode(
        self,
        sentences: list[str],
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs,
    ) -> np.ndarray:
        if prompt_type.value == "query":
            prompt = RETRIEVE_Q_PROMPT
        else:
            prompt = RETRIEVE_P_PROMPT
        vectors = self.model.encode_multi_process(
            sentences=sentences,
            pool=self.pool,
            show_progress_bar=True,
            batch_size=self.batch_size,
            normalize_embeddings=True,
            prompt=prompt,
            precision="float32"
        )
        return vectors


class DeweyMultiVectorWrapper(Wrapper):
    def __init__(
        self,
        model_dir: str,
        batch_size: int = 8,
        *args,
        **kwargs,
    ) -> None:
        self.model = AutoModel.from_pretrained(
            model_dir,
            trust_remote_code=True,
            attn_implementation="flash_attention_2"
        ).cuda().bfloat16()
        self.batch_size = batch_size
        self.model.tokenizer = AutoTokenizer.from_pretrained(model_dir)

    def encode(
        self,
        sentences: Sequence[str],
        *,
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs: Any,
    ) -> np.ndarray:
        if prompt_type.value == "query":
            prompt = RETRIEVE_Q_PROMPT
        else:
            prompt = RETRIEVE_P_PROMPT
        if prompt_type.value == "query":
            pred = self.model.encode(
                sentences=list(sentences),
                use_cuda=True,
                show_progress_bar=True,
                chunk_size=-1,
                chunk_overlap=32,
                convert_to_tensor=True,
                max_seq_length=max_seq_length,
                batch_size=self.batch_size,
                normalize_embeddings=True,
                prompt=prompt,
                fast_chunk=False,
            )[0]
            # queries do not need multiple vectors; keep only the mean vector (row 1) as the single query vector
            pred = [vecs[1:2, :] for vecs in pred]
        else:
            pred = self.model.encode(
                sentences=list(sentences),
                use_cuda=True,
                show_progress_bar=True,
                chunk_size=256,
                chunk_overlap=32,
                convert_to_tensor=True,
                max_seq_length=max_seq_length,
                batch_size=self.batch_size,
                normalize_embeddings=True,
                prompt=prompt,
                fast_chunk=True,
            )[0]

        pred = torch.nn.utils.rnn.pad_sequence(pred, batch_first=True, padding_value=0)
        return pred.cpu().numpy()

    def similarity(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        if not isinstance(a, torch.Tensor):
            a = torch.tensor(a, dtype=torch.float32)

        if not isinstance(b, torch.Tensor):
            b = torch.tensor(b, dtype=torch.float32)

        if len(a.shape) == 2:
            a = a.unsqueeze(0)

        if len(b.shape) == 2:
            b = b.unsqueeze(0)

        scores = torch.einsum(
            "ash,bth->abst",
            a,
            b,
        )

        return scores.max(axis=-1).values.sum(axis=-1)


RETRIEVE_Q_PROMPT = "<|START_INSTRUCTION|>Answer the question<|END_INSTRUCTION|>"
RETRIEVE_P_PROMPT = "<|START_INSTRUCTION|>Candidate document<|END_INSTRUCTION|>"

if __name__ == "__main__":
    ################# evaluate single vector #################
    # batch_size = 4
    # max_seq_length = 128 * 1024
    # model = DeweySingleVectorWrapper("infgrad/dewey_en_beta", batch_size=batch_size)
    # output_folder = f"./long_embed_benchmark/dewey_en_beta_single_vector_128k"
    # tasks = list(mteb.get_benchmark("LongEmbed"))
    # evaluation = mteb.MTEB(tasks=tasks)
    # evaluation.run(model, output_folder=output_folder, verbosity=2, overwrite_results=False)

    ################# evaluate multi vectors #################
    batch_size = 4
    max_seq_length = 128 * 1024
    model = DeweyMultiVectorWrapper("infgrad/dewey_en_beta", batch_size=batch_size)
    output_folder = f"./long_embed_benchmark/dewey_en_beta_multi_vectors"

    tasks = list(mteb.get_benchmark("LongEmbed"))
    evaluation = mteb.MTEB(tasks=tasks)
    evaluation.run(model, output_folder=output_folder, verbosity=2, overwrite_results=False)
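The similarity method above is ColBERT-style late interaction (MaxSim): every query vector is matched against its best document vector, and those best matches are summed. A minimal, self-contained sketch of the same scoring on toy tensors (illustration only, not part of the script):

import torch

# 2 queries with 1 vector each (the mean vector kept in encode), 3 documents with 5 chunk vectors each
a = torch.randn(2, 1, 8)
b = torch.randn(3, 5, 8)

# pairwise dot products: (num_queries, num_docs, query_vectors, doc_vectors)
scores = torch.einsum("ash,bth->abst", a, b)
# MaxSim: best document vector per query vector, then sum over query vectors -> shape (2, 3)
scores = scores.max(dim=-1).values.sum(dim=-1)
print(scores.shape)  # torch.Size([2, 3])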
scripts/evaluate/run_evaluate_mteb_dewey_en_beta.py
ADDED
@@ -0,0 +1,112 @@
import os

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["OPENBLAS_NUM_THREADS"] = "32"
import mteb
import torch
import numpy as np
from mteb.encoder_interface import PromptType
from sentence_transformers import SentenceTransformer

TASK_NAME2TYPE = {
    'ArguAna': 'Retrieval', 'ArXivHierarchicalClusteringP2P': 'Clustering',
    'ArXivHierarchicalClusteringS2S': 'Clustering', 'AskUbuntuDupQuestions': 'Reranking',
    'BIOSSES': 'STS', 'Banking77Classification': 'Classification',
    'BiorxivClusteringP2P.v2': 'Clustering', 'CQADupstackGamingRetrieval': 'Retrieval',
    'CQADupstackUnixRetrieval': 'Retrieval', 'ClimateFEVERHardNegatives': 'Retrieval',
    'FEVERHardNegatives': 'Retrieval', 'FiQA2018': 'Retrieval', 'HotpotQAHardNegatives': 'Retrieval',
    'ImdbClassification': 'Classification', 'MTOPDomainClassification': 'Classification',
    'MassiveIntentClassification': 'Classification', 'MassiveScenarioClassification': 'Classification',
    'MedrxivClusteringP2P.v2': 'Clustering', 'MedrxivClusteringS2S.v2': 'Clustering',
    'MindSmallReranking': 'Reranking', 'SCIDOCS': 'Retrieval', 'SICK-R': 'STS', 'STS12': 'STS',
    'STS13': 'STS', 'STS14': 'STS', 'STS15': 'STS', 'STSBenchmark': 'STS',
    'SprintDuplicateQuestions': 'PairClassification', 'StackExchangeClustering.v2': 'Clustering',
    'StackExchangeClusteringP2P.v2': 'Clustering', 'TRECCOVID': 'Retrieval',
    'Touche2020Retrieval.v3': 'Retrieval', 'ToxicConversationsClassification': 'Classification',
    'TweetSentimentExtractionClassification': 'Classification',
    'TwentyNewsgroupsClustering.v2': 'Clustering', 'TwitterSemEval2015': 'PairClassification',
    'TwitterURLCorpus': 'PairClassification', 'SummEvalSummarization.v2': 'Summarization',
    'AmazonCounterfactualClassification': 'Classification', 'STS17': 'STS', 'STS22.v2': 'STS'
}

RETRIEVE_Q_PROMPT = "<|START_INSTRUCTION|>Answer the question<|END_INSTRUCTION|>"
RETRIEVE_P_PROMPT = "<|START_INSTRUCTION|>Candidate document<|END_INSTRUCTION|>"
STS_PROMPT = "<|START_INSTRUCTION|>Generate semantically similar text<|END_INSTRUCTION|>"

TASK_NAME2PROMPT = {
    # Classification
    "Banking77Classification": "<|START_INSTRUCTION|>Classify text into intents<|END_INSTRUCTION|>",
    "ImdbClassification": "<|START_INSTRUCTION|>Classify text into sentiment<|END_INSTRUCTION|>",
    "MTOPDomainClassification": "<|START_INSTRUCTION|>Classify text into intent domain<|END_INSTRUCTION|>",
    "MassiveIntentClassification": "<|START_INSTRUCTION|>Classify text into user intents<|END_INSTRUCTION|>",
    "MassiveScenarioClassification": "<|START_INSTRUCTION|>Classify text into user scenarios<|END_INSTRUCTION|>",
    "ToxicConversationsClassification": "<|START_INSTRUCTION|>Classify text into toxic or not toxic<|END_INSTRUCTION|>",
    "TweetSentimentExtractionClassification": "<|START_INSTRUCTION|>Classify text into positive, negative, or neutral sentiment<|END_INSTRUCTION|>",
    "AmazonCounterfactualClassification": "<|START_INSTRUCTION|>Classify text into counterfactual or not-counterfactual<|END_INSTRUCTION|>",

    # Clustering
    "ArXivHierarchicalClusteringP2P": "<|START_INSTRUCTION|>Output main and secondary category of Arxiv papers based on the titles and abstracts<|END_INSTRUCTION|>",
    "ArXivHierarchicalClusteringS2S": "<|START_INSTRUCTION|>Output main and secondary category of Arxiv papers based on the titles<|END_INSTRUCTION|>",
    "BiorxivClusteringP2P.v2": "<|START_INSTRUCTION|>Output main category of Biorxiv papers based on the titles and abstracts<|END_INSTRUCTION|>",
    "MedrxivClusteringP2P.v2": "<|START_INSTRUCTION|>Output main category of Medrxiv papers based on the titles and abstracts<|END_INSTRUCTION|>",
    "MedrxivClusteringS2S.v2": "<|START_INSTRUCTION|>Output main category of Medrxiv papers based on the titles<|END_INSTRUCTION|>",
    "StackExchangeClustering.v2": "<|START_INSTRUCTION|>Output topic or theme of StackExchange posts based on the titles<|END_INSTRUCTION|>",
    "StackExchangeClusteringP2P.v2": "<|START_INSTRUCTION|>Output topic or theme of StackExchange posts based on the given paragraphs<|END_INSTRUCTION|>",
    "TwentyNewsgroupsClustering.v2": "<|START_INSTRUCTION|>Output topic or theme of news articles<|END_INSTRUCTION|>",
}


class DeweyWrapper:
    def __init__(self, model_dir, max_seq_length: int = 1536, batch_size: int = 8):
        self.model = SentenceTransformer(
            model_dir,
            trust_remote_code=True,
            model_kwargs={
                "torch_dtype": torch.bfloat16,  # fp16 easily produces NaN, so use bf16
                "attn_implementation": "flash_attention_2"
            },
            config_kwargs={"single_vector_type": "cls_add_mean"}
        ).cuda().bfloat16().eval()
        self.model.max_seq_length = max_seq_length
        self.pool = self.model.start_multi_process_pool()
        self.batch_size = batch_size

    def encode(
        self,
        sentences: list[str],
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs,
    ) -> np.ndarray:
        task_type = TASK_NAME2TYPE[task_name]
        if task_type == "Retrieval":
            if prompt_type.value == "query":
                prompt = RETRIEVE_Q_PROMPT
            else:
                prompt = RETRIEVE_P_PROMPT
        elif task_type in ["STS", "PairClassification", "Summarization", "Reranking"]:
            prompt = STS_PROMPT
        else:
            prompt = TASK_NAME2PROMPT[task_name]
        vectors = self.model.encode_multi_process(
            sentences=sentences,
            pool=self.pool,
            show_progress_bar=True,
            batch_size=self.batch_size,
            normalize_embeddings=True,
            prompt=prompt,
            precision="float32"
        )
        return vectors


if __name__ == "__main__":
    max_seq_length = 1536
    batch_size = 8
    model_dir_or_name = "infgrad/dewey_en_beta"
    output_folder = f"./mteb_eng_results/dewey_en_beta"
    model = DeweyWrapper(model_dir_or_name, max_seq_length=max_seq_length, batch_size=batch_size)

    tasks = list(mteb.get_benchmark("MTEB(eng, v2)"))
    evaluation = mteb.MTEB(tasks=tasks)
    evaluation.run(model, output_folder=output_folder, verbosity=2, overwrite_results=False)
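Before launching the full MTEB(eng, v2) run, it can be useful to smoke-test the wrapper on a single task. A hypothetical snippet (assuming DeweyWrapper and its prompt tables from the script above are in scope; the chosen task and output folder are illustrative, not part of the commit):

# quick sanity check on one retrieval task before the full benchmark
model = DeweyWrapper("infgrad/dewey_en_beta", max_seq_length=1536, batch_size=8)
tasks = mteb.get_tasks(tasks=["SCIDOCS"])
evaluation = mteb.MTEB(tasks=tasks)
evaluation.run(model, output_folder="./mteb_eng_results/dewey_en_beta_smoke", verbosity=2)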