Kevin Hu commited on
Commit
e023933
·
1 Parent(s): 5661fd5

fix: synonym bug (#3423)

Browse files

### What problem does this PR solve?


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

agent/component/generate.py CHANGED
@@ -104,6 +104,7 @@ class Generate(ComponentBase):
104
  retrieval_res = []
105
  self._param.inputs = []
106
  for para in self._param.parameters:
 
107
  if para["component_id"].split("@")[0].lower().find("begin") > 0:
108
  cpn_id, key = para["component_id"].split("@")
109
  for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
 
104
  retrieval_res = []
105
  self._param.inputs = []
106
  for para in self._param.parameters:
107
+ if not para.get("component_id"): continue
108
  if para["component_id"].split("@")[0].lower().find("begin") > 0:
109
  cpn_id, key = para["component_id"].split("@")
110
  for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
rag/benchmark.py CHANGED
@@ -27,6 +27,7 @@ from api.settings import retrievaler, docStoreConn
27
  from api.utils import get_uuid
28
  from rag.nlp import tokenize, search
29
  from ranx import evaluate
 
30
  import pandas as pd
31
  from tqdm import tqdm
32
 
@@ -247,14 +248,14 @@ class Benchmark:
247
  self.index_name = search.index_name(self.tenant_id)
248
  qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
249
  run = self._get_retrieval(qrels)
250
- print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
251
  self.save_results(qrels, run, texts, dataset, file_path)
252
  if dataset == "trivia_qa":
253
  self.tenant_id = "benchmark_trivia_qa"
254
  self.index_name = search.index_name(self.tenant_id)
255
  qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
256
  run = self._get_retrieval(qrels)
257
- print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
258
  self.save_results(qrels, run, texts, dataset, file_path)
259
  if dataset == "miracl":
260
  for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
@@ -278,7 +279,7 @@ class Benchmark:
278
  os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
279
  "benchmark_miracl_" + lang)
280
  run = self._get_retrieval(qrels)
281
- print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
282
  self.save_results(qrels, run, texts, dataset, file_path)
283
 
284
 
 
27
  from api.utils import get_uuid
28
  from rag.nlp import tokenize, search
29
  from ranx import evaluate
30
+ from ranx import Qrels, Run
31
  import pandas as pd
32
  from tqdm import tqdm
33
 
 
248
  self.index_name = search.index_name(self.tenant_id)
249
  qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
250
  run = self._get_retrieval(qrels)
251
+ print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"]))
252
  self.save_results(qrels, run, texts, dataset, file_path)
253
  if dataset == "trivia_qa":
254
  self.tenant_id = "benchmark_trivia_qa"
255
  self.index_name = search.index_name(self.tenant_id)
256
  qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
257
  run = self._get_retrieval(qrels)
258
+ print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"]))
259
  self.save_results(qrels, run, texts, dataset, file_path)
260
  if dataset == "miracl":
261
  for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
 
279
  os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
280
  "benchmark_miracl_" + lang)
281
  run = self._get_retrieval(qrels)
282
+ print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"]))
283
  self.save_results(qrels, run, texts, dataset, file_path)
284
 
285
 
rag/nlp/query.py CHANGED
@@ -88,7 +88,7 @@ class FulltextQueryer:
88
  syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn]
89
  syns.append(" ".join(syn))
90
 
91
- q = ["({}^{:.4f}".format(tk, w) + " %s)".format() for (tk, w), syn in zip(tks_w, syns)]
92
  for i in range(1, len(tks_w)):
93
  q.append(
94
  '"%s %s"^%.4f'
 
88
  syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn]
89
  syns.append(" ".join(syn))
90
 
91
+ q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns)]
92
  for i in range(1, len(tks_w)):
93
  q.append(
94
  '"%s %s"^%.4f'