Kevin Hu committed on
Commit
1f1194f
·
1 Parent(s): cfb71b4

fix benchmark issue (#3324)

Browse files

### What problem does this PR solve?



### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Files changed (1) hide show
  1. rag/benchmark.py +18 -9
rag/benchmark.py CHANGED
@@ -30,6 +30,7 @@ from rag.utils.es_conn import ELASTICSEARCH
30
  from ranx import evaluate
31
  import pandas as pd
32
  from tqdm import tqdm
 
33
 
34
 
35
  class Benchmark:
@@ -50,8 +51,8 @@ class Benchmark:
50
  query_list = list(qrels.keys())
51
  for query in query_list:
52
 
53
- ranks = retrievaler.retrieval(query, self.embd_mdl, dataset_idxnm.replace("ragflow_", ""),
54
- [self.kb.id], 0, 30,
55
  0.0, self.vector_similarity_weight)
56
  for c in ranks["chunks"]:
57
  if "vector" in c:
@@ -105,7 +106,9 @@ class Benchmark:
105
  for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
106
  d = {
107
  "id": get_uuid(),
108
- "kb_id": self.kb.id
 
 
109
  }
110
  tokenize(d, text, "english")
111
  docs.append(d)
@@ -137,7 +140,10 @@ class Benchmark:
137
  for rel, text in zip(data.iloc[i]["search_results"]['rank'],
138
  data.iloc[i]["search_results"]['search_context']):
139
  d = {
140
- "id": get_uuid()
 
 
 
141
  }
142
  tokenize(d, text, "english")
143
  docs.append(d)
@@ -182,7 +188,10 @@ class Benchmark:
182
  text = corpus_total[tmp_data.iloc[i]['docid']]
183
  rel = tmp_data.iloc[i]['relevance']
184
  d = {
185
- "id": get_uuid()
 
 
 
186
  }
187
  tokenize(d, text, 'english')
188
  docs.append(d)
@@ -204,7 +213,7 @@ class Benchmark:
204
  for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"):
205
  key = run_keys[run_i]
206
  keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
207
- 'ndcg@10': evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")})
208
  keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
209
  with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f:
210
  f.write('## Score For Every Query\n')
@@ -222,12 +231,12 @@ class Benchmark:
222
  if dataset == "ms_marco_v1.1":
223
  qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
224
  run = self._get_retrieval(qrels, "benchmark_ms_marco_v1.1")
225
- print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
226
  self.save_results(qrels, run, texts, dataset, file_path)
227
  if dataset == "trivia_qa":
228
  qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
229
  run = self._get_retrieval(qrels, "benchmark_trivia_qa")
230
- print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
231
  self.save_results(qrels, run, texts, dataset, file_path)
232
  if dataset == "miracl":
233
  for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
@@ -248,7 +257,7 @@ class Benchmark:
248
  os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
249
  "benchmark_miracl_" + lang)
250
  run = self._get_retrieval(qrels, "benchmark_miracl_" + lang)
251
- print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
252
  self.save_results(qrels, run, texts, dataset, file_path)
253
 
254
 
 
30
  from ranx import evaluate
31
  import pandas as pd
32
  from tqdm import tqdm
33
+ from ranx import Qrels, Run
34
 
35
 
36
  class Benchmark:
 
51
  query_list = list(qrels.keys())
52
  for query in query_list:
53
 
54
+ ranks = retrievaler.retrieval(query, self.embd_mdl,
55
+ dataset_idxnm, [self.kb.id], 1, 30,
56
  0.0, self.vector_similarity_weight)
57
  for c in ranks["chunks"]:
58
  if "vector" in c:
 
106
  for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
107
  d = {
108
  "id": get_uuid(),
109
+ "kb_id": self.kb.id,
110
+ "docnm_kwd": "xxxxx",
111
+ "doc_id": "ksksks"
112
  }
113
  tokenize(d, text, "english")
114
  docs.append(d)
 
140
  for rel, text in zip(data.iloc[i]["search_results"]['rank'],
141
  data.iloc[i]["search_results"]['search_context']):
142
  d = {
143
+ "id": get_uuid(),
144
+ "kb_id": self.kb.id,
145
+ "docnm_kwd": "xxxxx",
146
+ "doc_id": "ksksks"
147
  }
148
  tokenize(d, text, "english")
149
  docs.append(d)
 
188
  text = corpus_total[tmp_data.iloc[i]['docid']]
189
  rel = tmp_data.iloc[i]['relevance']
190
  d = {
191
+ "id": get_uuid(),
192
+ "kb_id": self.kb.id,
193
+ "docnm_kwd": "xxxxx",
194
+ "doc_id": "ksksks"
195
  }
196
  tokenize(d, text, 'english')
197
  docs.append(d)
 
213
  for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"):
214
  key = run_keys[run_i]
215
  keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
216
+ 'ndcg@10': evaluate(Qrels({key: qrels[key]}), Run({key: run[key]}), "ndcg@10")})
217
  keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
218
  with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f:
219
  f.write('## Score For Every Query\n')
 
231
  if dataset == "ms_marco_v1.1":
232
  qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
233
  run = self._get_retrieval(qrels, "benchmark_ms_marco_v1.1")
234
+ print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
235
  self.save_results(qrels, run, texts, dataset, file_path)
236
  if dataset == "trivia_qa":
237
  qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
238
  run = self._get_retrieval(qrels, "benchmark_trivia_qa")
239
+ print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
240
  self.save_results(qrels, run, texts, dataset, file_path)
241
  if dataset == "miracl":
242
  for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
 
257
  os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
258
  "benchmark_miracl_" + lang)
259
  run = self._get_retrieval(qrels, "benchmark_miracl_" + lang)
260
+ print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
261
  self.save_results(qrels, run, texts, dataset, file_path)
262
 
263