rvo committed
Commit c41b585 · verified · 1 Parent(s): 8321e9a

Upload evaluate_models.ipynb

Files changed (1): evaluate_models.ipynb (+172 -70)
evaluate_models.ipynb CHANGED
@@ -2,92 +2,122 @@
  "cells": [
   {
    "cell_type": "code",
+   "execution_count": null,
    "id": "initial_id",
    "metadata": {
     "collapsed": true
    },
+   "outputs": [],
    "source": [
     "import os\n",
+    "from typing import Dict, List\n",
     "\n",
-    "IS_COLAB = True if 'GOOGLE_CLOUD_PROJECT' in os.environ else False\n",
+    "IS_COLAB = True if \"GOOGLE_CLOUD_PROJECT\" in os.environ else False\n",
     "if IS_COLAB:\n",
     "    # this needs to run before all other imports\n",
-    "    os.environ['HF_HOME'] = '/content/cache/' # to avoid running out of disk space\n",
+    "    os.environ[\"HF_HOME\"] = \"/content/cache/\" # to avoid running out of disk space\n",
     "\n",
     "import mteb\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "from mteb.encoder_interface import PromptType\n",
     "from sentence_transformers import SentenceTransformer"
-   ],
-   "outputs": [],
-   "execution_count": null
+   ]
   },
   {
+   "cell_type": "markdown",
+   "id": "5325acfb",
    "metadata": {},
+   "source": [
+    "### Notebook Configuration"
+   ]
+  },
+  {
    "cell_type": "code",
+   "execution_count": null,
+   "id": "f0189ff1e7814a5a",
+   "metadata": {},
+   "outputs": [],
    "source": [
     "MODELS = {\n",
-    "    'ir-prod': {\n",
-    "        'name': 'MongoDB/mdbr-leaf-ir',\n",
-    "        'revision': '2e46f5aac796e621d51f678c306a66ede4712ecb'\n",
+    "    \"ir-prod\": {\n",
+    "        \"name\": \"MongoDB/mdbr-leaf-ir\",\n",
+    "        \"revision\": \"2e46f5aac796e621d51f678c306a66ede4712ecb\",\n",
+    "        \"teacher\": \"Snowflake/snowflake-arctic-embed-m-v1.5\",\n",
     "    },\n",
-    "    'ir-paper': {\n",
-    "        'name': 'MongoDB/mdbr-leaf-ir',\n",
-    "        'revision': 'ea98995e96beac21b820aa8ad9afaa6fd29b243d'\n",
+    "    \"ir-paper\": {\n",
+    "        \"name\": \"MongoDB/mdbr-leaf-ir\",\n",
+    "        \"revision\": \"ea98995e96beac21b820aa8ad9afaa6fd29b243d\",\n",
+    "        \"teacher\": \"Snowflake/snowflake-arctic-embed-m-v1.5\",\n",
     "    },\n",
-    "    'mt-prod': {\n",
-    "        'name': 'MongoDB/mdbr-leaf-mt',\n",
-    "        'revision': '66c47ba6d753efc208d54412b5af6c744a39a4df'\n",
+    "    \"mt-prod\": {\n",
+    "        \"name\": \"MongoDB/mdbr-leaf-mt\",\n",
+    "        \"revision\": \"66c47ba6d753efc208d54412b5af6c744a39a4df\",\n",
+    "        \"teacher\": \"mixedbread-ai/mxbai-embed-large-v1\",\n",
+    "    },\n",
+    "    \"mt-paper\": {\n",
+    "        \"name\": \"MongoDB/mdbr-leaf-mt\",\n",
+    "        \"revision\": \"c342f945a6855346bd5f48d5ee8b7e39120b0ce9\",\n",
+    "        \"teacher\": \"mixedbread-ai/mxbai-embed-large-v1\",\n",
     "    },\n",
-    "    'mt-paper': {\n",
-    "        'name': 'MongoDB/mdbr-leaf-mt',\n",
-    "        'revision': 'c342f945a6855346bd5f48d5ee8b7e39120b0ce9',\n",
-    "    }\n",
     "}"
-   ],
-   "id": "f0189ff1e7814a5a",
-   "outputs": [],
-   "execution_count": null
+   ]
  },
  {
-   "metadata": {},
    "cell_type": "markdown",
+   "id": "371c6122efdf476a",
+   "metadata": {},
    "source": [
-    "**Notebook configuration**:\n",
+    "In the cell below:\n",
     "* set the output folder and\n",
     "* select one of the models defined above\n",
     "* desired benchmark"
-   ],
-   "id": "371c6122efdf476a"
+   ]
  },
  {
-   "metadata": {},
    "cell_type": "code",
+   "execution_count": null,
+   "id": "58d52a330febb9ac",
+   "metadata": {},
+   "outputs": [],
    "source": [
-    "output_folder = f\"../../data/results/publish/\"\n",
+    "# output_folder = f\"../../data/results/publish/\"\n",
+    "output_folder = f\"/content/data/results/publish/\"\n",
     "\n",
-    "model_selection = MODELS['ir-prod']\n",
+    "model_selection = MODELS[\"ir-prod\"]\n",
     "benchmark_name = \"BEIR\"\n",
     "\n",
     "# model_selection = MODELS['mt-prod']\n",
     "# benchmark_name = \"MTEB(eng, v2)\""
-   ],
-   "id": "58d52a330febb9ac",
-   "outputs": [],
-   "execution_count": null
+   ]
  },
  {
-   "metadata": {},
    "cell_type": "markdown",
-   "source": "Load the model and run the evals",
-   "id": "1b4367afc1278e"
+   "id": "1b4367afc1278e",
+   "metadata": {},
+   "source": [
+    "### Run Evals"
+   ]
  },
  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c716c6344f9cd939",
    "metadata": {},
+   "outputs": [],
+   "source": [
+    "benchmark = mteb.get_benchmark(benchmark_name)\n",
+    "evaluation = mteb.MTEB(tasks=benchmark)"
+   ]
+  },
+  {
    "cell_type": "code",
+   "execution_count": null,
+   "id": "d6f13945a94f7a85",
+   "metadata": {},
+   "outputs": [],
    "source": [
-    "model = SentenceTransformer(\n",
-    "    model_selection['name'],\n",
-    "    revision=model_selection['revision']\n",
-    ")\n",
+    "model = SentenceTransformer(model_selection[\"name\"], revision=model_selection[\"revision\"])\n",
     "\n",
     "# alternative:\n",
     "# meta = mteb.get_model_meta(\n",
@@ -95,25 +125,14 @@
     "#     revision=model_selection['revision']\n",
     "# )\n",
     "# model = meta.load_model()"
-   ],
-   "id": "d6f13945a94f7a85",
-   "outputs": [],
-   "execution_count": null
  },
  {
-   "metadata": {},
    "cell_type": "code",
-   "source": [
-    "benchmark = mteb.get_benchmark(benchmark_name)\n",
-    "evaluation = mteb.MTEB(tasks=benchmark)"
-   ],
-   "id": "c716c6344f9cd939",
-   "outputs": [],
-   "execution_count": null
-  },
-  {
+   "execution_count": null,
+   "id": "9bd44e88fc360663",
    "metadata": {},
-   "cell_type": "code",
+   "outputs": [],
    "source": [
     "%%time\n",
     "results = evaluation.run(\n",
@@ -122,28 +141,32 @@
     "    output_folder=output_folder,\n",
     "    overwrite_results=True,\n",
     ")"
-   ],
-   "id": "9bd44e88fc360663",
-   "outputs": [],
-   "execution_count": null
+   ]
  },
  {
-   "metadata": {},
    "cell_type": "markdown",
-   "source": "Evaluate Quora",
-   "id": "733e52ca41cf92a7"
+   "id": "733e52ca41cf92a7",
+   "metadata": {},
+   "source": [
+    "Evaluate Quora"
+   ]
  },
  {
-   "metadata": {},
    "cell_type": "code",
+   "execution_count": null,
+   "id": "61aea9a04468202f",
+   "metadata": {},
+   "outputs": [],
    "source": [
-    "if model_selection['name'].endswith('ir'):\n",
+    "if model_selection[\"name\"].endswith(\"ir\"):\n",
     "    # quora is closer to a sentence similarity task than a retrieval one, as queries aren't proper user queries\n",
     "    # we thus embed them without the typical query prompt\n",
     "    model.prompts = {}\n",
-    "    tasks = mteb.get_tasks(tasks=[\n",
-    "        \"QuoraRetrieval\",\n",
-    "    ])\n",
+    "    tasks = mteb.get_tasks(\n",
+    "        tasks=[\n",
+    "            \"QuoraRetrieval\",\n",
+    "        ]\n",
+    "    )\n",
     "\n",
     "    evaluation = mteb.MTEB(tasks=tasks)\n",
     "    results = evaluation.run(\n",
@@ -152,10 +175,89 @@
     "        output_folder=output_folder,\n",
     "        overwrite_results=True,\n",
     "    )"
-   ],
-   "id": "61aea9a04468202f",
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6a6c164e",
+   "metadata": {},
+   "source": [
+    "### Asymmetric Mode\n",
+    "\n",
+    "Compute asymmetric mode scores: queries encoded by `leaf`, documents by the original teacher model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "487ba349",
+   "metadata": {},
    "outputs": [],
-   "execution_count": null
+   "source": [
+    "class AsymmetricModel:\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        doc_model: SentenceTransformer,\n",
+    "        query_model: SentenceTransformer,\n",
+    "    ) -> None:\n",
+    "        self.doc_model = doc_model\n",
+    "        self.query_model = query_model\n",
+    "\n",
+    "    def encode(self, sentences: List[str], **kwargs) -> np.ndarray | torch.Tensor:\n",
+    "        if \"prompt_type\" not in kwargs:\n",
+    "            kwargs[\"prompt_type\"] = None\n",
+    "\n",
+    "        match kwargs[\"prompt_type\"]:\n",
+    "            case PromptType.query:\n",
+    "                out = self.query_model.encode(sentences, prompt_name=\"query\", **kwargs)\n",
+    "\n",
+    "            case PromptType.document:\n",
+    "                out = self.doc_model.encode(sentences, **kwargs)\n",
+    "\n",
+    "            case None:\n",
+    "                print(\"No prompt type: using query (leaf) model for encoding\")\n",
+    "                out = self.query_model.encode(sentences, **kwargs)\n",
+    "            case _:\n",
+    "                raise ValueError(f\"Encoding unknown type: {kwargs['prompt_type']}\")\n",
+    "\n",
+    "        if not isinstance(out, torch.Tensor):\n",
+    "            out = torch.from_numpy(out)\n",
+    "\n",
+    "        out = out.to(\"cpu\")\n",
+    "        return out"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4162af7f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "leaf = SentenceTransformer(model_selection[\"name\"], revision=model_selection[\"revision\"])\n",
+    "teacher = SentenceTransformer(model_selection[\"teacher\"])\n",
+    "\n",
+    "asymm_model = AsymmetricModel(\n",
+    "    query_model=leaf,\n",
+    "    doc_model=teacher,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "848d8a5f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%time\n",
+    "results = evaluation.run(\n",
+    "    model=asymm_model,\n",
+    "    verbosity=1,\n",
+    "    output_folder=output_folder,\n",
+    "    overwrite_results=True,\n",
+    ")"
+   ]
   }
  ],
  "metadata": {