Upload evaluate_models.ipynb
Browse files- evaluate_models.ipynb +172 -70
evaluate_models.ipynb
CHANGED
@@ -2,92 +2,122 @@
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
|
|
5 |
"id": "initial_id",
|
6 |
"metadata": {
|
7 |
"collapsed": true
|
8 |
},
|
|
|
9 |
"source": [
|
10 |
"import os\n",
|
|
|
11 |
"\n",
|
12 |
-
"IS_COLAB = True if
|
13 |
"if IS_COLAB:\n",
|
14 |
" # this needs to run before all other imports\n",
|
15 |
-
" os.environ[
|
16 |
"\n",
|
17 |
"import mteb\n",
|
|
|
|
|
|
|
18 |
"from sentence_transformers import SentenceTransformer"
|
19 |
-
]
|
20 |
-
"outputs": [],
|
21 |
-
"execution_count": null
|
22 |
},
|
23 |
{
|
|
|
|
|
24 |
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
25 |
"cell_type": "code",
|
|
|
|
|
|
|
|
|
26 |
"source": [
|
27 |
"MODELS = {\n",
|
28 |
-
"
|
29 |
-
"
|
30 |
-
"
|
|
|
31 |
" },\n",
|
32 |
-
"
|
33 |
-
"
|
34 |
-
"
|
|
|
35 |
" },\n",
|
36 |
-
"
|
37 |
-
"
|
38 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
" },\n",
|
40 |
-
" 'mt-paper': {\n",
|
41 |
-
" 'name': 'MongoDB/mdbr-leaf-mt',\n",
|
42 |
-
" 'revision': 'c342f945a6855346bd5f48d5ee8b7e39120b0ce9',\n",
|
43 |
-
" }\n",
|
44 |
"}"
|
45 |
-
]
|
46 |
-
"id": "f0189ff1e7814a5a",
|
47 |
-
"outputs": [],
|
48 |
-
"execution_count": null
|
49 |
},
|
50 |
{
|
51 |
-
"metadata": {},
|
52 |
"cell_type": "markdown",
|
|
|
|
|
53 |
"source": [
|
54 |
-
"
|
55 |
"* set the output folder and\n",
|
56 |
"* select one of the models defined above\n",
|
57 |
"* desired benchmark"
|
58 |
-
]
|
59 |
-
"id": "371c6122efdf476a"
|
60 |
},
|
61 |
{
|
62 |
-
"metadata": {},
|
63 |
"cell_type": "code",
|
|
|
|
|
|
|
|
|
64 |
"source": [
|
65 |
-
"output_folder = f\"../../data/results/publish/\"\n",
|
|
|
66 |
"\n",
|
67 |
-
"model_selection = MODELS[
|
68 |
"benchmark_name = \"BEIR\"\n",
|
69 |
"\n",
|
70 |
"# model_selection = MODELS['mt-prod']\n",
|
71 |
"# benchmark_name = \"MTEB(eng, v2)\""
|
72 |
-
]
|
73 |
-
"id": "58d52a330febb9ac",
|
74 |
-
"outputs": [],
|
75 |
-
"execution_count": null
|
76 |
},
|
77 |
{
|
78 |
-
"metadata": {},
|
79 |
"cell_type": "markdown",
|
80 |
-
"
|
81 |
-
"
|
|
|
|
|
|
|
82 |
},
|
83 |
{
|
|
|
|
|
|
|
84 |
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
"cell_type": "code",
|
|
|
|
|
|
|
|
|
86 |
"source": [
|
87 |
-
"model = SentenceTransformer(\n",
|
88 |
-
" model_selection['name'],\n",
|
89 |
-
" revision=model_selection['revision']\n",
|
90 |
-
")\n",
|
91 |
"\n",
|
92 |
"# alternative:\n",
|
93 |
"# meta = mteb.get_model_meta(\n",
|
@@ -95,25 +125,14 @@
|
|
95 |
"# revision=model_selection['revision']\n",
|
96 |
"# )\n",
|
97 |
"# model = meta.load_model()"
|
98 |
-
]
|
99 |
-
"id": "d6f13945a94f7a85",
|
100 |
-
"outputs": [],
|
101 |
-
"execution_count": null
|
102 |
},
|
103 |
{
|
104 |
-
"metadata": {},
|
105 |
"cell_type": "code",
|
106 |
-
"
|
107 |
-
|
108 |
-
"evaluation = mteb.MTEB(tasks=benchmark)"
|
109 |
-
],
|
110 |
-
"id": "c716c6344f9cd939",
|
111 |
-
"outputs": [],
|
112 |
-
"execution_count": null
|
113 |
-
},
|
114 |
-
{
|
115 |
"metadata": {},
|
116 |
-
"
|
117 |
"source": [
|
118 |
"%%time\n",
|
119 |
"results = evaluation.run(\n",
|
@@ -122,28 +141,32 @@
|
|
122 |
" output_folder=output_folder,\n",
|
123 |
" overwrite_results=True,\n",
|
124 |
")"
|
125 |
-
]
|
126 |
-
"id": "9bd44e88fc360663",
|
127 |
-
"outputs": [],
|
128 |
-
"execution_count": null
|
129 |
},
|
130 |
{
|
131 |
-
"metadata": {},
|
132 |
"cell_type": "markdown",
|
133 |
-
"
|
134 |
-
"
|
|
|
|
|
|
|
135 |
},
|
136 |
{
|
137 |
-
"metadata": {},
|
138 |
"cell_type": "code",
|
|
|
|
|
|
|
|
|
139 |
"source": [
|
140 |
-
"if model_selection[
|
141 |
" # quora is closer to a sentence similarity task than a retrieval one, as queries aren't proper user queries\n",
|
142 |
" # we thus embed them without the typical query prompt\n",
|
143 |
" model.prompts = {}\n",
|
144 |
-
" tasks = mteb.get_tasks(
|
145 |
-
" \
|
146 |
-
"
|
|
|
|
|
147 |
"\n",
|
148 |
" evaluation = mteb.MTEB(tasks=tasks)\n",
|
149 |
" results = evaluation.run(\n",
|
@@ -152,10 +175,89 @@
|
|
152 |
" output_folder=output_folder,\n",
|
153 |
" overwrite_results=True,\n",
|
154 |
" )"
|
155 |
-
]
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
"outputs": [],
|
158 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
}
|
160 |
],
|
161 |
"metadata": {
|
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
"id": "initial_id",
|
7 |
"metadata": {
|
8 |
"collapsed": true
|
9 |
},
|
10 |
+
"outputs": [],
|
11 |
"source": [
|
12 |
"import os\n",
|
13 |
+
"from typing import Dict, List\n",
|
14 |
"\n",
|
15 |
+
"IS_COLAB = True if \"GOOGLE_CLOUD_PROJECT\" in os.environ else False\n",
|
16 |
"if IS_COLAB:\n",
|
17 |
" # this needs to run before all other imports\n",
|
18 |
+
" os.environ[\"HF_HOME\"] = \"/content/cache/\" # to avoid running out of disk space\n",
|
19 |
"\n",
|
20 |
"import mteb\n",
|
21 |
+
"import numpy as np\n",
|
22 |
+
"import torch\n",
|
23 |
+
"from mteb.encoder_interface import PromptType\n",
|
24 |
"from sentence_transformers import SentenceTransformer"
|
25 |
+
]
|
|
|
|
|
26 |
},
|
27 |
{
|
28 |
+
"cell_type": "markdown",
|
29 |
+
"id": "5325acfb",
|
30 |
"metadata": {},
|
31 |
+
"source": [
|
32 |
+
"### Notebook Configuration"
|
33 |
+
]
|
34 |
+
},
|
35 |
+
{
|
36 |
"cell_type": "code",
|
37 |
+
"execution_count": null,
|
38 |
+
"id": "f0189ff1e7814a5a",
|
39 |
+
"metadata": {},
|
40 |
+
"outputs": [],
|
41 |
"source": [
|
42 |
"MODELS = {\n",
|
43 |
+
" \"ir-prod\": {\n",
|
44 |
+
" \"name\": \"MongoDB/mdbr-leaf-ir\",\n",
|
45 |
+
" \"revision\": \"2e46f5aac796e621d51f678c306a66ede4712ecb\",\n",
|
46 |
+
" \"teacher\": \"Snowflake/snowflake-arctic-embed-m-v1.5\",\n",
|
47 |
" },\n",
|
48 |
+
" \"ir-paper\": {\n",
|
49 |
+
" \"name\": \"MongoDB/mdbr-leaf-ir\",\n",
|
50 |
+
" \"revision\": \"ea98995e96beac21b820aa8ad9afaa6fd29b243d\",\n",
|
51 |
+
" \"teacher\": \"Snowflake/snowflake-arctic-embed-m-v1.5\",\n",
|
52 |
" },\n",
|
53 |
+
" \"mt-prod\": {\n",
|
54 |
+
" \"name\": \"MongoDB/mdbr-leaf-mt\",\n",
|
55 |
+
" \"revision\": \"66c47ba6d753efc208d54412b5af6c744a39a4df\",\n",
|
56 |
+
" \"teacher\": \"mixedbread-ai/mxbai-embed-large-v1\",\n",
|
57 |
+
" },\n",
|
58 |
+
" \"mt-paper\": {\n",
|
59 |
+
" \"name\": \"MongoDB/mdbr-leaf-mt\",\n",
|
60 |
+
" \"revision\": \"c342f945a6855346bd5f48d5ee8b7e39120b0ce9\",\n",
|
61 |
+
" \"teacher\": \"mixedbread-ai/mxbai-embed-large-v1\",\n",
|
62 |
" },\n",
|
|
|
|
|
|
|
|
|
63 |
"}"
|
64 |
+
]
|
|
|
|
|
|
|
65 |
},
|
66 |
{
|
|
|
67 |
"cell_type": "markdown",
|
68 |
+
"id": "371c6122efdf476a",
|
69 |
+
"metadata": {},
|
70 |
"source": [
|
71 |
+
"In the cell below:\n",
|
72 |
"* set the output folder and\n",
|
73 |
"* select one of the models defined above\n",
|
74 |
"* desired benchmark"
|
75 |
+
]
|
|
|
76 |
},
|
77 |
{
|
|
|
78 |
"cell_type": "code",
|
79 |
+
"execution_count": null,
|
80 |
+
"id": "58d52a330febb9ac",
|
81 |
+
"metadata": {},
|
82 |
+
"outputs": [],
|
83 |
"source": [
|
84 |
+
"# output_folder = f\"../../data/results/publish/\"\n",
|
85 |
+
"output_folder = f\"/content/data/results/publish/\"\n",
|
86 |
"\n",
|
87 |
+
"model_selection = MODELS[\"ir-prod\"]\n",
|
88 |
"benchmark_name = \"BEIR\"\n",
|
89 |
"\n",
|
90 |
"# model_selection = MODELS['mt-prod']\n",
|
91 |
"# benchmark_name = \"MTEB(eng, v2)\""
|
92 |
+
]
|
|
|
|
|
|
|
93 |
},
|
94 |
{
|
|
|
95 |
"cell_type": "markdown",
|
96 |
+
"id": "1b4367afc1278e",
|
97 |
+
"metadata": {},
|
98 |
+
"source": [
|
99 |
+
"### Run Evals"
|
100 |
+
]
|
101 |
},
|
102 |
{
|
103 |
+
"cell_type": "code",
|
104 |
+
"execution_count": null,
|
105 |
+
"id": "c716c6344f9cd939",
|
106 |
"metadata": {},
|
107 |
+
"outputs": [],
|
108 |
+
"source": [
|
109 |
+
"benchmark = mteb.get_benchmark(benchmark_name)\n",
|
110 |
+
"evaluation = mteb.MTEB(tasks=benchmark)"
|
111 |
+
]
|
112 |
+
},
|
113 |
+
{
|
114 |
"cell_type": "code",
|
115 |
+
"execution_count": null,
|
116 |
+
"id": "d6f13945a94f7a85",
|
117 |
+
"metadata": {},
|
118 |
+
"outputs": [],
|
119 |
"source": [
|
120 |
+
"model = SentenceTransformer(model_selection[\"name\"], revision=model_selection[\"revision\"])\n",
|
|
|
|
|
|
|
121 |
"\n",
|
122 |
"# alternative:\n",
|
123 |
"# meta = mteb.get_model_meta(\n",
|
|
|
125 |
"# revision=model_selection['revision']\n",
|
126 |
"# )\n",
|
127 |
"# model = meta.load_model()"
|
128 |
+
]
|
|
|
|
|
|
|
129 |
},
|
130 |
{
|
|
|
131 |
"cell_type": "code",
|
132 |
+
"execution_count": null,
|
133 |
+
"id": "9bd44e88fc360663",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
"metadata": {},
|
135 |
+
"outputs": [],
|
136 |
"source": [
|
137 |
"%%time\n",
|
138 |
"results = evaluation.run(\n",
|
|
|
141 |
" output_folder=output_folder,\n",
|
142 |
" overwrite_results=True,\n",
|
143 |
")"
|
144 |
+
]
|
|
|
|
|
|
|
145 |
},
|
146 |
{
|
|
|
147 |
"cell_type": "markdown",
|
148 |
+
"id": "733e52ca41cf92a7",
|
149 |
+
"metadata": {},
|
150 |
+
"source": [
|
151 |
+
"Evaluate Quora"
|
152 |
+
]
|
153 |
},
|
154 |
{
|
|
|
155 |
"cell_type": "code",
|
156 |
+
"execution_count": null,
|
157 |
+
"id": "61aea9a04468202f",
|
158 |
+
"metadata": {},
|
159 |
+
"outputs": [],
|
160 |
"source": [
|
161 |
+
"if model_selection[\"name\"].endswith(\"ir\"):\n",
|
162 |
" # quora is closer to a sentence similarity task than a retrieval one, as queries aren't proper user queries\n",
|
163 |
" # we thus embed them without the typical query prompt\n",
|
164 |
" model.prompts = {}\n",
|
165 |
+
" tasks = mteb.get_tasks(\n",
|
166 |
+
" tasks=[\n",
|
167 |
+
" \"QuoraRetrieval\",\n",
|
168 |
+
" ]\n",
|
169 |
+
" )\n",
|
170 |
"\n",
|
171 |
" evaluation = mteb.MTEB(tasks=tasks)\n",
|
172 |
" results = evaluation.run(\n",
|
|
|
175 |
" output_folder=output_folder,\n",
|
176 |
" overwrite_results=True,\n",
|
177 |
" )"
|
178 |
+
]
|
179 |
+
},
|
180 |
+
{
|
181 |
+
"cell_type": "markdown",
|
182 |
+
"id": "6a6c164e",
|
183 |
+
"metadata": {},
|
184 |
+
"source": [
|
185 |
+
"### Asymmetric Mode\n",
|
186 |
+
"\n",
|
187 |
+
"Compute asymmetric mode scores: queries encoded by `leaf`, documents by the original teacher model."
|
188 |
+
]
|
189 |
+
},
|
190 |
+
{
|
191 |
+
"cell_type": "code",
|
192 |
+
"execution_count": null,
|
193 |
+
"id": "487ba349",
|
194 |
+
"metadata": {},
|
195 |
"outputs": [],
|
196 |
+
"source": [
|
197 |
+
"class AsymmetricModel:\n",
|
198 |
+
" def __init__(\n",
|
199 |
+
" self,\n",
|
200 |
+
" doc_model: SentenceTransformer,\n",
|
201 |
+
" query_model: SentenceTransformer,\n",
|
202 |
+
" ) -> None:\n",
|
203 |
+
" self.doc_model = doc_model\n",
|
204 |
+
" self.query_model = query_model\n",
|
205 |
+
"\n",
|
206 |
+
" def encode(self, sentences: List[str], **kwargs) -> np.ndarray | torch.Tensor:\n",
|
207 |
+
" if \"prompt_type\" not in kwargs:\n",
|
208 |
+
" kwargs[\"prompt_type\"] = None\n",
|
209 |
+
"\n",
|
210 |
+
" match kwargs[\"prompt_type\"]:\n",
|
211 |
+
" case PromptType.query:\n",
|
212 |
+
" out = self.query_model.encode(sentences, prompt_name=\"query\", **kwargs)\n",
|
213 |
+
"\n",
|
214 |
+
" case PromptType.document:\n",
|
215 |
+
" out = self.doc_model.encode(sentences, **kwargs)\n",
|
216 |
+
"\n",
|
217 |
+
" case None:\n",
|
218 |
+
" print(\"No prompt type: using query (leaf) model for encoding\")\n",
|
219 |
+
" out = self.query_model.encode(sentences, **kwargs)\n",
|
220 |
+
" case _:\n",
|
221 |
+
" raise ValueError(f\"Encoding unknown type: {kwargs['prompt_type']}\")\n",
|
222 |
+
"\n",
|
223 |
+
" if not isinstance(out, torch.Tensor):\n",
|
224 |
+
" out = torch.from_numpy(out)\n",
|
225 |
+
"\n",
|
226 |
+
" out = out.to(\"cpu\")\n",
|
227 |
+
" return out"
|
228 |
+
]
|
229 |
+
},
|
230 |
+
{
|
231 |
+
"cell_type": "code",
|
232 |
+
"execution_count": null,
|
233 |
+
"id": "4162af7f",
|
234 |
+
"metadata": {},
|
235 |
+
"outputs": [],
|
236 |
+
"source": [
|
237 |
+
"leaf = SentenceTransformer(model_selection[\"name\"], revision=model_selection[\"revision\"])\n",
|
238 |
+
"teacher = SentenceTransformer(model_selection[\"teacher\"])\n",
|
239 |
+
"\n",
|
240 |
+
"asymm_model = AsymmetricModel(\n",
|
241 |
+
" query_model=leaf,\n",
|
242 |
+
" doc_model=teacher,\n",
|
243 |
+
")"
|
244 |
+
]
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"cell_type": "code",
|
248 |
+
"execution_count": null,
|
249 |
+
"id": "848d8a5f",
|
250 |
+
"metadata": {},
|
251 |
+
"outputs": [],
|
252 |
+
"source": [
|
253 |
+
"%%time\n",
|
254 |
+
"results = evaluation.run(\n",
|
255 |
+
" model=asymm_model,\n",
|
256 |
+
" verbosity=1,\n",
|
257 |
+
" output_folder=output_folder,\n",
|
258 |
+
" overwrite_results=True,\n",
|
259 |
+
")"
|
260 |
+
]
|
261 |
}
|
262 |
],
|
263 |
"metadata": {
|