Katsumata420 commited on
Commit
2e4bc3b
·
verified ·
1 Parent(s): dbc3ffe

Upload 6 files

Browse files
mldr/models/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from .default import PROMPT as default_prompt
2
+ from .retrieva import PROMPT as retrieva_prompt
3
+ from .retrieva_en import PROMPT as retrieva_en_prompt
4
+
5
+
6
+ PROMPTS = {
7
+ "default": default_prompt,
8
+ "retrieva": retrieva_prompt,
9
+ "retrieva-en": retrieva_en_prompt,
10
+ }
mldr/models/default.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ PROMPT = {
2
+ "query": "query: ",
3
+ "passage": "passage: ",
4
+ }
mldr/models/retrieva.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PROMPT = {
2
+ "STS": "同じ意味の文を探すために次の文を表現して\n",
3
+ "Summarization": "次の記事またはタイトルを表現して\n",
4
+ "BitextMining": "次の文を表現して\n",
5
+ "Classification": "同じクラスに属する文を探すために次の文を表現して\n",
6
+ "Clustering": "類似した文を探すために次の文を表現して\n",
7
+ "Reranking-query": "関連した文書を探すために次の文を表現して\n",
8
+ "Reranking-passage": "次の文章を表現して\n",
9
+ "Retrieval-query": "関連した文書を探すために次の文を表現して\n",
10
+ "Retrieval-passage": "次の文章を表現して\n",
11
+ "InstructionRetrieval": "",
12
+ "PairClassification": "同じ意味の文を探すために次の文を表現して\n",
13
+ }
mldr/models/retrieva_en.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PROMPT = {
2
+ "STS": "Represent the sentence for retrieving duplicate sentences:\n",
3
+ "Summarization": "Represent the news article or news title for retrieval:\n",
4
+ "BitextMining": "Represent the sentence\n",
5
+ "Classification": "Represent the sentence for retrieving the sentence belonging to the same category:\n",
6
+ "Clustering": "Represent the sentence to find similar sentences:\n",
7
+ "Reranking-query": "Represent the question:\n",
8
+ "Reranking-passage": "Represent the following text:\n",
9
+ "Retrieval-query": "Represent the question:\n",
10
+ "Retrieval-passage": "Represent the following text:\n",
11
+ "InstructionRetrieval": "Retrieve text based on user query:\n",
12
+ "PairClassification": "Represent the sentence for retrieving duplicate sentences:\n",
13
+ "MultilabelClassification": "Represent the sentence for retrieving the sentence belonging to the same category:\n",
14
+ "Speed": "",
15
+ }
mldr/mteb_eval.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Evaluate AMBER models"""
2
+
3
+ import argparse
4
+
5
+ import mteb
6
+
7
+ from models import PROMPTS
8
+
9
+
10
+ def get_args() -> argparse.Namespace:
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument("--model_type", type=str, required=True, help="Model name", choices=PROMPTS.keys())
13
+ parser.add_argument("--model_name_or_path", type=str, required=True)
14
+ parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
15
+ parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
16
+ parser.add_argument("--corpus_chunk_size", type=int, default=50000)
17
+ parser.add_argument("--convert_to_tensor", action="store_true")
18
+ return parser.parse_args()
19
+
20
+
21
+ def main():
22
+ args = get_args()
23
+ prompt = PROMPTS[args.model_type]
24
+ model = mteb.get_model(args.model_name_or_path, model_prompts=prompt)
25
+
26
+ tasks = [mteb.get_task("MultiLongDocRetrieval", languages=["jpn"])]
27
+ evaluation = mteb.MTEB(tasks=tasks)
28
+
29
+ encode_kwargs = {
30
+ "batch_size": args.batch_size,
31
+ "convert_to_tensor": args.convert_to_tensor,
32
+ }
33
+
34
+ evaluation.run(
35
+ model,
36
+ output_folder=args.output_dir,
37
+ encode_kwargs=encode_kwargs,
38
+ corpus_chunk_size=args.corpus_chunk_size,
39
+ )
40
+
41
+
42
+ if __name__ == "__main__":
43
+ main()
mldr/results/MultiLongDocRetrieval.json ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "d67138e705d963e346253a80e59676ddb418810a",
3
+ "task_name": "MultiLongDocRetrieval",
4
+ "mteb_version": "1.36.1",
5
+ "scores": {
6
+ "dev": [
7
+ {
8
+ "ndcg_at_1": 0.29,
9
+ "ndcg_at_3": 0.35547,
10
+ "ndcg_at_5": 0.36193,
11
+ "ndcg_at_10": 0.37975,
12
+ "ndcg_at_20": 0.39722,
13
+ "ndcg_at_100": 0.42003,
14
+ "ndcg_at_1000": 0.4448,
15
+ "map_at_1": 0.29,
16
+ "map_at_3": 0.34,
17
+ "map_at_5": 0.34375,
18
+ "map_at_10": 0.35114,
19
+ "map_at_20": 0.3558,
20
+ "map_at_100": 0.35884,
21
+ "map_at_1000": 0.35967,
22
+ "recall_at_1": 0.29,
23
+ "recall_at_3": 0.4,
24
+ "recall_at_5": 0.415,
25
+ "recall_at_10": 0.47,
26
+ "recall_at_20": 0.54,
27
+ "recall_at_100": 0.665,
28
+ "recall_at_1000": 0.865,
29
+ "precision_at_1": 0.29,
30
+ "precision_at_3": 0.13333,
31
+ "precision_at_5": 0.083,
32
+ "precision_at_10": 0.047,
33
+ "precision_at_20": 0.027,
34
+ "precision_at_100": 0.00665,
35
+ "precision_at_1000": 0.00087,
36
+ "mrr_at_1": 0.29,
37
+ "mrr_at_3": 0.34,
38
+ "mrr_at_5": 0.34375,
39
+ "mrr_at_10": 0.351137,
40
+ "mrr_at_20": 0.355805,
41
+ "mrr_at_100": 0.358836,
42
+ "mrr_at_1000": 0.359674,
43
+ "nauc_ndcg_at_1_max": 0.505659,
44
+ "nauc_ndcg_at_1_std": 0.015906,
45
+ "nauc_ndcg_at_1_diff1": 0.596287,
46
+ "nauc_ndcg_at_3_max": 0.533264,
47
+ "nauc_ndcg_at_3_std": 0.032754,
48
+ "nauc_ndcg_at_3_diff1": 0.586672,
49
+ "nauc_ndcg_at_5_max": 0.531308,
50
+ "nauc_ndcg_at_5_std": 0.036854,
51
+ "nauc_ndcg_at_5_diff1": 0.586102,
52
+ "nauc_ndcg_at_10_max": 0.531042,
53
+ "nauc_ndcg_at_10_std": 0.061878,
54
+ "nauc_ndcg_at_10_diff1": 0.566107,
55
+ "nauc_ndcg_at_20_max": 0.539226,
56
+ "nauc_ndcg_at_20_std": 0.067577,
57
+ "nauc_ndcg_at_20_diff1": 0.554971,
58
+ "nauc_ndcg_at_100_max": 0.536018,
59
+ "nauc_ndcg_at_100_std": 0.071992,
60
+ "nauc_ndcg_at_100_diff1": 0.531947,
61
+ "nauc_ndcg_at_1000_max": 0.542956,
62
+ "nauc_ndcg_at_1000_std": 0.088587,
63
+ "nauc_ndcg_at_1000_diff1": 0.547443,
64
+ "nauc_map_at_1_max": 0.505659,
65
+ "nauc_map_at_1_std": 0.015906,
66
+ "nauc_map_at_1_diff1": 0.596287,
67
+ "nauc_map_at_3_max": 0.529127,
68
+ "nauc_map_at_3_std": 0.027918,
69
+ "nauc_map_at_3_diff1": 0.588492,
70
+ "nauc_map_at_5_max": 0.527978,
71
+ "nauc_map_at_5_std": 0.030311,
72
+ "nauc_map_at_5_diff1": 0.58816,
73
+ "nauc_map_at_10_max": 0.527822,
74
+ "nauc_map_at_10_std": 0.041362,
75
+ "nauc_map_at_10_diff1": 0.579864,
76
+ "nauc_map_at_20_max": 0.530149,
77
+ "nauc_map_at_20_std": 0.042933,
78
+ "nauc_map_at_20_diff1": 0.576511,
79
+ "nauc_map_at_100_max": 0.530358,
80
+ "nauc_map_at_100_std": 0.044082,
81
+ "nauc_map_at_100_diff1": 0.57324,
82
+ "nauc_map_at_1000_max": 0.530766,
83
+ "nauc_map_at_1000_std": 0.045058,
84
+ "nauc_map_at_1000_diff1": 0.57385,
85
+ "nauc_recall_at_1_max": 0.505659,
86
+ "nauc_recall_at_1_std": 0.015906,
87
+ "nauc_recall_at_1_diff1": 0.596287,
88
+ "nauc_recall_at_3_max": 0.543872,
89
+ "nauc_recall_at_3_std": 0.046595,
90
+ "nauc_recall_at_3_diff1": 0.581771,
91
+ "nauc_recall_at_5_max": 0.539487,
92
+ "nauc_recall_at_5_std": 0.055974,
93
+ "nauc_recall_at_5_diff1": 0.580452,
94
+ "nauc_recall_at_10_max": 0.538971,
95
+ "nauc_recall_at_10_std": 0.129465,
96
+ "nauc_recall_at_10_diff1": 0.519139,
97
+ "nauc_recall_at_20_max": 0.572192,
98
+ "nauc_recall_at_20_std": 0.155817,
99
+ "nauc_recall_at_20_diff1": 0.47407,
100
+ "nauc_recall_at_100_max": 0.547704,
101
+ "nauc_recall_at_100_std": 0.193272,
102
+ "nauc_recall_at_100_diff1": 0.311819,
103
+ "nauc_recall_at_1000_max": 0.666392,
104
+ "nauc_recall_at_1000_std": 0.580073,
105
+ "nauc_recall_at_1000_diff1": 0.351308,
106
+ "nauc_precision_at_1_max": 0.505659,
107
+ "nauc_precision_at_1_std": 0.015906,
108
+ "nauc_precision_at_1_diff1": 0.596287,
109
+ "nauc_precision_at_3_max": 0.543872,
110
+ "nauc_precision_at_3_std": 0.046595,
111
+ "nauc_precision_at_3_diff1": 0.581771,
112
+ "nauc_precision_at_5_max": 0.539487,
113
+ "nauc_precision_at_5_std": 0.055974,
114
+ "nauc_precision_at_5_diff1": 0.580452,
115
+ "nauc_precision_at_10_max": 0.538971,
116
+ "nauc_precision_at_10_std": 0.129465,
117
+ "nauc_precision_at_10_diff1": 0.519139,
118
+ "nauc_precision_at_20_max": 0.572192,
119
+ "nauc_precision_at_20_std": 0.155817,
120
+ "nauc_precision_at_20_diff1": 0.47407,
121
+ "nauc_precision_at_100_max": 0.547704,
122
+ "nauc_precision_at_100_std": 0.193272,
123
+ "nauc_precision_at_100_diff1": 0.311819,
124
+ "nauc_precision_at_1000_max": 0.666392,
125
+ "nauc_precision_at_1000_std": 0.580073,
126
+ "nauc_precision_at_1000_diff1": 0.351308,
127
+ "nauc_mrr_at_1_max": 0.505659,
128
+ "nauc_mrr_at_1_std": 0.015906,
129
+ "nauc_mrr_at_1_diff1": 0.596287,
130
+ "nauc_mrr_at_3_max": 0.529127,
131
+ "nauc_mrr_at_3_std": 0.027918,
132
+ "nauc_mrr_at_3_diff1": 0.588492,
133
+ "nauc_mrr_at_5_max": 0.527978,
134
+ "nauc_mrr_at_5_std": 0.030311,
135
+ "nauc_mrr_at_5_diff1": 0.58816,
136
+ "nauc_mrr_at_10_max": 0.527822,
137
+ "nauc_mrr_at_10_std": 0.041362,
138
+ "nauc_mrr_at_10_diff1": 0.579864,
139
+ "nauc_mrr_at_20_max": 0.530149,
140
+ "nauc_mrr_at_20_std": 0.042933,
141
+ "nauc_mrr_at_20_diff1": 0.576511,
142
+ "nauc_mrr_at_100_max": 0.530358,
143
+ "nauc_mrr_at_100_std": 0.044082,
144
+ "nauc_mrr_at_100_diff1": 0.57324,
145
+ "nauc_mrr_at_1000_max": 0.530766,
146
+ "nauc_mrr_at_1000_std": 0.045058,
147
+ "nauc_mrr_at_1000_diff1": 0.57385,
148
+ "main_score": 0.37975,
149
+ "hf_subset": "ja",
150
+ "languages": [
151
+ "jpn-Jpan"
152
+ ]
153
+ }
154
+ ],
155
+ "test": [
156
+ {
157
+ "ndcg_at_1": 0.245,
158
+ "ndcg_at_3": 0.30928,
159
+ "ndcg_at_5": 0.33166,
160
+ "ndcg_at_10": 0.34569,
161
+ "ndcg_at_20": 0.35817,
162
+ "ndcg_at_100": 0.38436,
163
+ "ndcg_at_1000": 0.40594,
164
+ "map_at_1": 0.245,
165
+ "map_at_3": 0.295,
166
+ "map_at_5": 0.30725,
167
+ "map_at_10": 0.31275,
168
+ "map_at_20": 0.31608,
169
+ "map_at_100": 0.31973,
170
+ "map_at_1000": 0.32053,
171
+ "recall_at_1": 0.245,
172
+ "recall_at_3": 0.35,
173
+ "recall_at_5": 0.405,
174
+ "recall_at_10": 0.45,
175
+ "recall_at_20": 0.5,
176
+ "recall_at_100": 0.64,
177
+ "recall_at_1000": 0.81,
178
+ "precision_at_1": 0.245,
179
+ "precision_at_3": 0.11667,
180
+ "precision_at_5": 0.081,
181
+ "precision_at_10": 0.045,
182
+ "precision_at_20": 0.025,
183
+ "precision_at_100": 0.0064,
184
+ "precision_at_1000": 0.00081,
185
+ "mrr_at_1": 0.245,
186
+ "mrr_at_3": 0.295,
187
+ "mrr_at_5": 0.30725,
188
+ "mrr_at_10": 0.312748,
189
+ "mrr_at_20": 0.316079,
190
+ "mrr_at_100": 0.319726,
191
+ "mrr_at_1000": 0.320528,
192
+ "nauc_ndcg_at_1_max": 0.406893,
193
+ "nauc_ndcg_at_1_std": -0.009559,
194
+ "nauc_ndcg_at_1_diff1": 0.554901,
195
+ "nauc_ndcg_at_3_max": 0.444372,
196
+ "nauc_ndcg_at_3_std": -0.02926,
197
+ "nauc_ndcg_at_3_diff1": 0.509425,
198
+ "nauc_ndcg_at_5_max": 0.425091,
199
+ "nauc_ndcg_at_5_std": -0.031815,
200
+ "nauc_ndcg_at_5_diff1": 0.469611,
201
+ "nauc_ndcg_at_10_max": 0.447755,
202
+ "nauc_ndcg_at_10_std": -0.015871,
203
+ "nauc_ndcg_at_10_diff1": 0.462957,
204
+ "nauc_ndcg_at_20_max": 0.46053,
205
+ "nauc_ndcg_at_20_std": 0.005444,
206
+ "nauc_ndcg_at_20_diff1": 0.466256,
207
+ "nauc_ndcg_at_100_max": 0.461105,
208
+ "nauc_ndcg_at_100_std": 0.024618,
209
+ "nauc_ndcg_at_100_diff1": 0.453195,
210
+ "nauc_ndcg_at_1000_max": 0.465154,
211
+ "nauc_ndcg_at_1000_std": 0.038943,
212
+ "nauc_ndcg_at_1000_diff1": 0.448247,
213
+ "nauc_map_at_1_max": 0.406893,
214
+ "nauc_map_at_1_std": -0.009559,
215
+ "nauc_map_at_1_diff1": 0.554901,
216
+ "nauc_map_at_3_max": 0.435427,
217
+ "nauc_map_at_3_std": -0.025665,
218
+ "nauc_map_at_3_diff1": 0.517996,
219
+ "nauc_map_at_5_max": 0.424549,
220
+ "nauc_map_at_5_std": -0.027697,
221
+ "nauc_map_at_5_diff1": 0.495027,
222
+ "nauc_map_at_10_max": 0.433903,
223
+ "nauc_map_at_10_std": -0.021661,
224
+ "nauc_map_at_10_diff1": 0.492772,
225
+ "nauc_map_at_20_max": 0.437387,
226
+ "nauc_map_at_20_std": -0.015962,
227
+ "nauc_map_at_20_diff1": 0.493878,
228
+ "nauc_map_at_100_max": 0.437411,
229
+ "nauc_map_at_100_std": -0.013427,
230
+ "nauc_map_at_100_diff1": 0.49183,
231
+ "nauc_map_at_1000_max": 0.437574,
232
+ "nauc_map_at_1000_std": -0.012967,
233
+ "nauc_map_at_1000_diff1": 0.49167,
234
+ "nauc_recall_at_1_max": 0.406893,
235
+ "nauc_recall_at_1_std": -0.009559,
236
+ "nauc_recall_at_1_diff1": 0.554901,
237
+ "nauc_recall_at_3_max": 0.468877,
238
+ "nauc_recall_at_3_std": -0.038697,
239
+ "nauc_recall_at_3_diff1": 0.486789,
240
+ "nauc_recall_at_5_max": 0.424115,
241
+ "nauc_recall_at_5_std": -0.042187,
242
+ "nauc_recall_at_5_diff1": 0.397345,
243
+ "nauc_recall_at_10_max": 0.491711,
244
+ "nauc_recall_at_10_std": 0.008171,
245
+ "nauc_recall_at_10_diff1": 0.373993,
246
+ "nauc_recall_at_20_max": 0.541845,
247
+ "nauc_recall_at_20_std": 0.091805,
248
+ "nauc_recall_at_20_diff1": 0.383971,
249
+ "nauc_recall_at_100_max": 0.5606,
250
+ "nauc_recall_at_100_std": 0.225021,
251
+ "nauc_recall_at_100_diff1": 0.294953,
252
+ "nauc_recall_at_1000_max": 0.674171,
253
+ "nauc_recall_at_1000_std": 0.550841,
254
+ "nauc_recall_at_1000_diff1": 0.131908,
255
+ "nauc_precision_at_1_max": 0.406893,
256
+ "nauc_precision_at_1_std": -0.009559,
257
+ "nauc_precision_at_1_diff1": 0.554901,
258
+ "nauc_precision_at_3_max": 0.468877,
259
+ "nauc_precision_at_3_std": -0.038697,
260
+ "nauc_precision_at_3_diff1": 0.486789,
261
+ "nauc_precision_at_5_max": 0.424115,
262
+ "nauc_precision_at_5_std": -0.042187,
263
+ "nauc_precision_at_5_diff1": 0.397345,
264
+ "nauc_precision_at_10_max": 0.491711,
265
+ "nauc_precision_at_10_std": 0.008171,
266
+ "nauc_precision_at_10_diff1": 0.373993,
267
+ "nauc_precision_at_20_max": 0.541845,
268
+ "nauc_precision_at_20_std": 0.091805,
269
+ "nauc_precision_at_20_diff1": 0.383971,
270
+ "nauc_precision_at_100_max": 0.5606,
271
+ "nauc_precision_at_100_std": 0.225021,
272
+ "nauc_precision_at_100_diff1": 0.294953,
273
+ "nauc_precision_at_1000_max": 0.674171,
274
+ "nauc_precision_at_1000_std": 0.550841,
275
+ "nauc_precision_at_1000_diff1": 0.131908,
276
+ "nauc_mrr_at_1_max": 0.406893,
277
+ "nauc_mrr_at_1_std": -0.009559,
278
+ "nauc_mrr_at_1_diff1": 0.554901,
279
+ "nauc_mrr_at_3_max": 0.435427,
280
+ "nauc_mrr_at_3_std": -0.025665,
281
+ "nauc_mrr_at_3_diff1": 0.517996,
282
+ "nauc_mrr_at_5_max": 0.424549,
283
+ "nauc_mrr_at_5_std": -0.027697,
284
+ "nauc_mrr_at_5_diff1": 0.495027,
285
+ "nauc_mrr_at_10_max": 0.433903,
286
+ "nauc_mrr_at_10_std": -0.021661,
287
+ "nauc_mrr_at_10_diff1": 0.492772,
288
+ "nauc_mrr_at_20_max": 0.437387,
289
+ "nauc_mrr_at_20_std": -0.015962,
290
+ "nauc_mrr_at_20_diff1": 0.493878,
291
+ "nauc_mrr_at_100_max": 0.437411,
292
+ "nauc_mrr_at_100_std": -0.013427,
293
+ "nauc_mrr_at_100_diff1": 0.49183,
294
+ "nauc_mrr_at_1000_max": 0.437574,
295
+ "nauc_mrr_at_1000_std": -0.012967,
296
+ "nauc_mrr_at_1000_diff1": 0.49167,
297
+ "main_score": 0.34569,
298
+ "hf_subset": "ja",
299
+ "languages": [
300
+ "jpn-Jpan"
301
+ ]
302
+ }
303
+ ]
304
+ },
305
+ "evaluation_time": 297.0003197193146,
306
+ "kg_co2_emissions": null
307
+ }