infgrad committed
Commit fbc1304 (verified)
Parent: a41c6a1

Upload 5 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+imgs/inference_architecture.png filter=lfs diff=lfs merge=lfs -text
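
(The added line is exactly what git lfs track "imgs/inference_architecture.png" would append, so the image below is stored as a Git LFS pointer rather than a raw blob.)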
imgs/inference_architecture.png ADDED

Git LFS Details

  • SHA256: b921665746ae629386f81c2161ed498bbdfe4e6a0b1bef7fd31fe4627ec49706
  • Pointer size: 131 Bytes
  • Size of remote file: 208 kB
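
For reference, the 131-byte pointer that Git stores in place of the image follows the standard Git LFS v1 format; a sketch using the values above (the exact byte count is not shown on this page, only the rounded 208 kB):

    version https://git-lfs.github.com/spec/v1
    oid sha256:b921665746ae629386f81c2161ed498bbdfe4e6a0b1bef7fd31fe4627ec49706
    size <exact byte count, roughly 208 kB>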
scripts/evaluate/run_evaluate_long_embed.py ADDED
@@ -0,0 +1,161 @@
import os

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["OPENBLAS_NUM_THREADS"] = "32"
import numpy as np
import torch
import mteb
from mteb.encoder_interface import PromptType
from sentence_transformers import SentenceTransformer
from mteb.models.wrapper import Wrapper
from typing import Any, Sequence
from transformers import AutoTokenizer, AutoModel


class DeweySingleVectorWrapper:
    def __init__(self, model_dir: str, max_seq_length: int = 128 * 1024, batch_size: int = 8):
        self.model = SentenceTransformer(
            model_dir,
            trust_remote_code=True,
            model_kwargs={
                "torch_dtype": torch.bfloat16,  # fp16 easily produces NaN
                "attn_implementation": "flash_attention_2"
            },
            config_kwargs={"single_vector_type": "mean"}
        ).cuda().bfloat16().eval()
        self.model.max_seq_length = max_seq_length
        self.pool = self.model.start_multi_process_pool()
        self.batch_size = batch_size

    def encode(
        self,
        sentences: list[str],
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs,
    ) -> np.ndarray:
        if prompt_type == PromptType.query:
            prompt = RETRIEVE_Q_PROMPT
        else:
            prompt = RETRIEVE_P_PROMPT
        vectors = self.model.encode_multi_process(
            sentences=sentences,
            pool=self.pool,
            show_progress_bar=True,
            batch_size=self.batch_size,
            normalize_embeddings=True,
            prompt=prompt,
            precision="float32"
        )
        return vectors

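# start_multi_process_pool() spawns one worker per visible CUDA device by default
# (select GPUs with CUDA_VISIBLE_DEVICES); encode_multi_process() shards the inputs
# across those workers and, with precision="float32", returns float32 embeddings
# even though the model weights are bfloat16.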

class DeweyMultiVectorWrapper(Wrapper):
    def __init__(
        self,
        model_dir: str,
        max_seq_length: int = 128 * 1024,
        batch_size: int = 8,
        *args,
        **kwargs,
    ) -> None:
        self.model = AutoModel.from_pretrained(
            model_dir,
            trust_remote_code=True,
            attn_implementation="flash_attention_2"
        ).cuda().bfloat16()
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        self.model.tokenizer = AutoTokenizer.from_pretrained(model_dir)

    def encode(
        self,
        sentences: Sequence[str],
        *,
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs: Any,
    ) -> np.ndarray:
        if prompt_type == PromptType.query:
            # Encode each query as a single chunk (chunk_size=-1), then keep only
            # the mean vector (index 1): queries do not need multiple vectors.
            pred = self.model.encode(
                sentences=list(sentences),
                use_cuda=True,
                show_progress_bar=True,
                chunk_size=-1,
                chunk_overlap=32,
                convert_to_tensor=True,
                max_seq_length=self.max_seq_length,
                batch_size=self.batch_size,
                normalize_embeddings=True,
                prompt=RETRIEVE_Q_PROMPT,
                fast_chunk=False,
            )[0]
            pred = [vecs[1:2, :] for vecs in pred]
        else:
            # Split each document into 256-token chunks with a 32-token overlap,
            # producing one vector per chunk.
            pred = self.model.encode(
                sentences=list(sentences),
                use_cuda=True,
                show_progress_bar=True,
                chunk_size=256,
                chunk_overlap=32,
                convert_to_tensor=True,
                max_seq_length=self.max_seq_length,
                batch_size=self.batch_size,
                normalize_embeddings=True,
                prompt=RETRIEVE_P_PROMPT,
                fast_chunk=True,
            )[0]

        # Pad every text to the same number of vectors so the batch stacks into one array.
        pred = torch.nn.utils.rnn.pad_sequence(pred, batch_first=True, padding_value=0)
        return pred.cpu().numpy()

    def similarity(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        if not isinstance(a, torch.Tensor):
            a = torch.tensor(a, dtype=torch.float32)
        if not isinstance(b, torch.Tensor):
            b = torch.tensor(b, dtype=torch.float32)
        if len(a.shape) == 2:
            a = a.unsqueeze(0)
        if len(b.shape) == 2:
            b = b.unsqueeze(0)
        # Token-level dot products: (A, s, h) x (B, t, h) -> (A, B, s, t).
        scores = torch.einsum("ash,bth->abst", a, b)
        # MaxSim: best document vector per query vector, summed over query vectors.
        return scores.max(axis=-1).values.sum(axis=-1)


RETRIEVE_Q_PROMPT = "<|START_INSTRUCTION|>Answer the question<|END_INSTRUCTION|>"
RETRIEVE_P_PROMPT = "<|START_INSTRUCTION|>Candidate document<|END_INSTRUCTION|>"

if __name__ == "__main__":
    ################# evaluate single vector #################
    # batch_size = 4
    # max_seq_length = 128 * 1024
    # model = DeweySingleVectorWrapper("infgrad/dewey_en_beta", max_seq_length=max_seq_length, batch_size=batch_size)
    # output_folder = "./long_embed_benchmark/dewey_en_beta_single_vector_128k"
    # tasks = list(mteb.get_benchmark("LongEmbed"))
    # evaluation = mteb.MTEB(tasks=tasks)
    # evaluation.run(model, output_folder=output_folder, verbosity=2, overwrite_results=False)

    ################# evaluate multi vectors #################
    batch_size = 4
    max_seq_length = 128 * 1024
    model = DeweyMultiVectorWrapper("infgrad/dewey_en_beta", max_seq_length=max_seq_length, batch_size=batch_size)
    output_folder = "./long_embed_benchmark/dewey_en_beta_multi_vectors"

    tasks = list(mteb.get_benchmark("LongEmbed"))
    evaluation = mteb.MTEB(tasks=tasks)
    evaluation.run(model, output_folder=output_folder, verbosity=2, overwrite_results=False)
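
A note on the multi-vector scoring above: DeweyMultiVectorWrapper.similarity implements ColBERT-style MaxSim scoring. A minimal self-contained sketch of the same computation on toy tensors (all shapes and values invented for illustration):

    import torch

    a = torch.randn(2, 1, 8)  # (A, s, h): 2 queries, 1 vector each
    b = torch.randn(3, 4, 8)  # (B, t, h): 3 documents, 4 chunk vectors each

    # Dot product between every query vector and every document vector: (A, B, s, t)
    scores = torch.einsum("ash,bth->abst", a, b)

    # Best-matching document vector per query vector, summed over query vectors: (A, B)
    maxsim = scores.max(dim=-1).values.sum(dim=-1)
    print(maxsim.shape)  # torch.Size([2, 3])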
scripts/evaluate/run_evaluate_mteb_dewey_en_beta.py ADDED
@@ -0,0 +1,112 @@
import os

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["OPENBLAS_NUM_THREADS"] = "32"
import mteb
import torch
import numpy as np
from mteb.encoder_interface import PromptType
from sentence_transformers import SentenceTransformer

TASK_NAME2TYPE = {
    'ArguAna': 'Retrieval', 'ArXivHierarchicalClusteringP2P': 'Clustering',
    'ArXivHierarchicalClusteringS2S': 'Clustering', 'AskUbuntuDupQuestions': 'Reranking',
    'BIOSSES': 'STS', 'Banking77Classification': 'Classification',
    'BiorxivClusteringP2P.v2': 'Clustering', 'CQADupstackGamingRetrieval': 'Retrieval',
    'CQADupstackUnixRetrieval': 'Retrieval', 'ClimateFEVERHardNegatives': 'Retrieval',
    'FEVERHardNegatives': 'Retrieval', 'FiQA2018': 'Retrieval', 'HotpotQAHardNegatives': 'Retrieval',
    'ImdbClassification': 'Classification', 'MTOPDomainClassification': 'Classification',
    'MassiveIntentClassification': 'Classification', 'MassiveScenarioClassification': 'Classification',
    'MedrxivClusteringP2P.v2': 'Clustering', 'MedrxivClusteringS2S.v2': 'Clustering',
    'MindSmallReranking': 'Reranking', 'SCIDOCS': 'Retrieval', 'SICK-R': 'STS', 'STS12': 'STS',
    'STS13': 'STS', 'STS14': 'STS', 'STS15': 'STS', 'STSBenchmark': 'STS',
    'SprintDuplicateQuestions': 'PairClassification', 'StackExchangeClustering.v2': 'Clustering',
    'StackExchangeClusteringP2P.v2': 'Clustering', 'TRECCOVID': 'Retrieval',
    'Touche2020Retrieval.v3': 'Retrieval', 'ToxicConversationsClassification': 'Classification',
    'TweetSentimentExtractionClassification': 'Classification',
    'TwentyNewsgroupsClustering.v2': 'Clustering', 'TwitterSemEval2015': 'PairClassification',
    'TwitterURLCorpus': 'PairClassification', 'SummEvalSummarization.v2': 'Summarization',
    'AmazonCounterfactualClassification': 'Classification', 'STS17': 'STS', 'STS22.v2': 'STS'
}

RETRIEVE_Q_PROMPT = "<|START_INSTRUCTION|>Answer the question<|END_INSTRUCTION|>"
RETRIEVE_P_PROMPT = "<|START_INSTRUCTION|>Candidate document<|END_INSTRUCTION|>"
STS_PROMPT = "<|START_INSTRUCTION|>Generate semantically similar text<|END_INSTRUCTION|>"

TASK_NAME2PROMPT = {
    # Classification
    "Banking77Classification": "<|START_INSTRUCTION|>Classify text into intents<|END_INSTRUCTION|>",
    "ImdbClassification": "<|START_INSTRUCTION|>Classify text into sentiment<|END_INSTRUCTION|>",
    "MTOPDomainClassification": "<|START_INSTRUCTION|>Classify text into intent domain<|END_INSTRUCTION|>",
    "MassiveIntentClassification": "<|START_INSTRUCTION|>Classify text into user intents<|END_INSTRUCTION|>",
    "MassiveScenarioClassification": "<|START_INSTRUCTION|>Classify text into user scenarios<|END_INSTRUCTION|>",
    "ToxicConversationsClassification": "<|START_INSTRUCTION|>Classify text into toxic or not toxic<|END_INSTRUCTION|>",
    "TweetSentimentExtractionClassification": "<|START_INSTRUCTION|>Classify text into positive, negative, or neutral sentiment<|END_INSTRUCTION|>",
    "AmazonCounterfactualClassification": "<|START_INSTRUCTION|>Classify text into counterfactual or not-counterfactual<|END_INSTRUCTION|>",

    # Clustering
    "ArXivHierarchicalClusteringP2P": "<|START_INSTRUCTION|>Output main and secondary category of Arxiv papers based on the titles and abstracts<|END_INSTRUCTION|>",
    "ArXivHierarchicalClusteringS2S": "<|START_INSTRUCTION|>Output main and secondary category of Arxiv papers based on the titles<|END_INSTRUCTION|>",
    "BiorxivClusteringP2P.v2": "<|START_INSTRUCTION|>Output main category of Biorxiv papers based on the titles and abstracts<|END_INSTRUCTION|>",
    "MedrxivClusteringP2P.v2": "<|START_INSTRUCTION|>Output main category of Medrxiv papers based on the titles and abstracts<|END_INSTRUCTION|>",
    "MedrxivClusteringS2S.v2": "<|START_INSTRUCTION|>Output main category of Medrxiv papers based on the titles<|END_INSTRUCTION|>",
    "StackExchangeClustering.v2": "<|START_INSTRUCTION|>Output topic or theme of StackExchange posts based on the titles<|END_INSTRUCTION|>",
    "StackExchangeClusteringP2P.v2": "<|START_INSTRUCTION|>Output topic or theme of StackExchange posts based on the given paragraphs<|END_INSTRUCTION|>",
    "TwentyNewsgroupsClustering.v2": "<|START_INSTRUCTION|>Output topic or theme of news articles<|END_INSTRUCTION|>",
}

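# Prompt routing used below: Retrieval tasks take the query/document retrieval
# prompts, STS / PairClassification / Summarization / Reranking tasks share
# STS_PROMPT, and Classification / Clustering tasks each get a task-specific
# prompt from TASK_NAME2PROMPT.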


class DeweyWrapper:
    def __init__(self, model_dir, max_seq_length: int = 1536, batch_size: int = 8):
        self.model = SentenceTransformer(
            model_dir,
            trust_remote_code=True,
            model_kwargs={
                "torch_dtype": torch.bfloat16,  # fp16 easily produces NaN
                "attn_implementation": "flash_attention_2"
            },
            config_kwargs={"single_vector_type": "cls_add_mean"}
        ).cuda().bfloat16().eval()
        self.model.max_seq_length = max_seq_length
        self.pool = self.model.start_multi_process_pool()
        self.batch_size = batch_size

    def encode(
        self,
        sentences: list[str],
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs,
    ) -> np.ndarray:
        task_type = TASK_NAME2TYPE[task_name]
        if task_type == "Retrieval":
            if prompt_type == PromptType.query:
                prompt = RETRIEVE_Q_PROMPT
            else:
                prompt = RETRIEVE_P_PROMPT
        elif task_type in ["STS", "PairClassification", "Summarization", "Reranking"]:
            prompt = STS_PROMPT
        else:
            prompt = TASK_NAME2PROMPT[task_name]
        vectors = self.model.encode_multi_process(
            sentences=sentences,
            pool=self.pool,
            show_progress_bar=True,
            batch_size=self.batch_size,
            normalize_embeddings=True,
            prompt=prompt,
            precision="float32"
        )
        return vectors


if __name__ == "__main__":
    max_seq_length = 1536
    batch_size = 8
    model_dir_or_name = "infgrad/dewey_en_beta"
    output_folder = "./mteb_eng_results/dewey_en_beta"
    model = DeweyWrapper(model_dir_or_name, max_seq_length=max_seq_length, batch_size=batch_size)

    tasks = list(mteb.get_benchmark("MTEB(eng, v2)"))
    evaluation = mteb.MTEB(tasks=tasks)
    evaluation.run(model, output_folder=output_folder, verbosity=2, overwrite_results=False)
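
Note: with overwrite_results=False, mteb skips any task whose results file already exists under output_folder, so an interrupted benchmark run can be resumed by simply re-running the script.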