Kevin Hu committed
Commit · 1f1194f
1 Parent(s): cfb71b4

fix benchmark issue (#3324)
### What problem does this PR solve?

Fixes the retrieval benchmark in `rag/benchmark.py`: the retrieval call now passes the dataset index name and paging arguments, the indexed chunks carry the fields retrieval expects (`kb_id`, `docnm_kwd`, `doc_id`), and the `ranx` evaluation calls wrap their inputs in `Qrels`/`Run`.
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- rag/benchmark.py +18 −9

rag/benchmark.py CHANGED
@@ -30,6 +30,7 @@ from rag.utils.es_conn import ELASTICSEARCH
 from ranx import evaluate
 import pandas as pd
 from tqdm import tqdm
+from ranx import Qrels, Run
 
 
 class Benchmark:
@@ -50,8 +51,8 @@ class Benchmark:
         query_list = list(qrels.keys())
         for query in query_list:
 
-            ranks = retrievaler.retrieval(query, self.embd_mdl,
-                                          [self.kb.id],
+            ranks = retrievaler.retrieval(query, self.embd_mdl,
+                                          dataset_idxnm, [self.kb.id], 1, 30,
                                           0.0, self.vector_similarity_weight)
             for c in ranks["chunks"]:
                 if "vector" in c:
@@ -105,7 +106,9 @@ class Benchmark:
             for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
                 d = {
                     "id": get_uuid(),
-                    "kb_id": self.kb.id
+                    "kb_id": self.kb.id,
+                    "docnm_kwd": "xxxxx",
+                    "doc_id": "ksksks"
                 }
                 tokenize(d, text, "english")
                 docs.append(d)
@@ -137,7 +140,10 @@ class Benchmark:
             for rel, text in zip(data.iloc[i]["search_results"]['rank'],
                                  data.iloc[i]["search_results"]['search_context']):
                 d = {
-                    "id": get_uuid()
+                    "id": get_uuid(),
+                    "kb_id": self.kb.id,
+                    "docnm_kwd": "xxxxx",
+                    "doc_id": "ksksks"
                 }
                 tokenize(d, text, "english")
                 docs.append(d)
@@ -182,7 +188,10 @@ class Benchmark:
             text = corpus_total[tmp_data.iloc[i]['docid']]
             rel = tmp_data.iloc[i]['relevance']
             d = {
-                "id": get_uuid()
+                "id": get_uuid(),
+                "kb_id": self.kb.id,
+                "docnm_kwd": "xxxxx",
+                "doc_id": "ksksks"
             }
             tokenize(d, text, 'english')
             docs.append(d)
@@ -204,7 +213,7 @@ class Benchmark:
         for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"):
             key = run_keys[run_i]
             keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
-                                'ndcg@10': evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")})
+                                'ndcg@10': evaluate(Qrels({key: qrels[key]}), Run({key: run[key]}), "ndcg@10")})
         keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
         with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f:
             f.write('## Score For Every Query\n')
@@ -222,12 +231,12 @@ class Benchmark:
         if dataset == "ms_marco_v1.1":
             qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
             run = self._get_retrieval(qrels, "benchmark_ms_marco_v1.1")
-            print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
+            print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
             self.save_results(qrels, run, texts, dataset, file_path)
         if dataset == "trivia_qa":
             qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
             run = self._get_retrieval(qrels, "benchmark_trivia_qa")
-            print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
+            print(dataset, evaluate((qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
             self.save_results(qrels, run, texts, dataset, file_path)
         if dataset == "miracl":
             for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
@@ -248,7 +257,7 @@ class Benchmark:
                     os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
                     "benchmark_miracl_" + lang)
                 run = self._get_retrieval(qrels, "benchmark_miracl_" + lang)
-                print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
+                print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
                 self.save_results(qrels, run, texts, dataset, file_path)
 
 
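For readers unfamiliar with the `ranx` calls this commit switches to: `evaluate` is given `Qrels` and `Run` objects built from plain nested dicts (query id → doc id → relevance or score). Below is a minimal, self-contained sketch of that usage with made-up query and document ids; it only illustrates the library API and is not code from this PR.

```python
# Minimal sketch of the ranx usage adopted above; all ids and scores are made up.
from ranx import Qrels, Run, evaluate

# Ground truth: query id -> {doc id: graded relevance}
qrels_dict = {
    "q1": {"d1": 1, "d3": 2},
    "q2": {"d2": 1},
}

# Retrieval output: query id -> {doc id: retrieval score}
run_dict = {
    "q1": {"d1": 0.9, "d2": 0.5, "d3": 0.4},
    "q2": {"d1": 0.3, "d2": 0.8},
}

# Per-query score, like the "Score For Every Query" loop: wrap a
# single-query slice of each dict, ask for one metric, get back a float.
single = evaluate(Qrels({"q1": qrels_dict["q1"]}),
                  Run({"q1": run_dict["q1"]}), "ndcg@10")

# Whole-run scores, like the per-dataset print: a list of metrics
# returns a dict mapping metric name to value.
overall = evaluate(Qrels(qrels_dict), Run(run_dict), ["ndcg@10", "map@5", "mrr"])

print(single)
print(overall)
```

In `benchmark.py` both patterns appear: one call per query to fill the markdown report, and one call over the full query set for the printed summary.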