Kevin Hu committed on
Commit
692cc99
·
1 Parent(s): 0b587a0

fix: term weight issue (#3294)

Browse files

### What problem does this PR solve?



### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Files changed (2) hide show
  1. rag/benchmark.py +34 -5
  2. rag/nlp/term_weight.py +1 -1
rag/benchmark.py CHANGED
@@ -16,11 +16,15 @@
16
  import json
17
  import os
18
  from collections import defaultdict
 
 
 
19
  from api.db import LLMType
20
  from api.db.services.llm_service import LLMBundle
21
  from api.db.services.knowledgebase_service import KnowledgebaseService
22
  from api.settings import retrievaler
23
  from api.utils import get_uuid
 
24
  from rag.nlp import tokenize, search
25
  from rag.utils.es_conn import ELASTICSEARCH
26
  from ranx import evaluate
@@ -63,14 +67,34 @@ class Benchmark:
63
  d["q_%d_vec" % len(v)] = v
64
  return docs
65
 
 
 
 
 
 
 
 
 
 
66
  def ms_marco_index(self, file_path, index_name):
67
  qrels = defaultdict(dict)
68
  texts = defaultdict(dict)
69
  docs = []
70
  filelist = os.listdir(file_path)
 
 
 
 
 
 
 
 
 
 
 
71
  for dir in filelist:
72
  data = pd.read_parquet(os.path.join(file_path, dir))
73
- for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + dir):
74
 
75
  query = data.iloc[i]['query']
76
  for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
@@ -82,12 +106,17 @@ class Benchmark:
82
  texts[d["id"]] = text
83
  qrels[query][d["id"]] = int(rel)
84
  if len(docs) >= 32:
85
- docs = self.embedding(docs)
86
- ELASTICSEARCH.bulk(docs, search.index_name(index_name))
87
  docs = []
88
 
89
- docs = self.embedding(docs)
90
- ELASTICSEARCH.bulk(docs, search.index_name(index_name))
 
 
 
 
 
91
  return qrels, texts
92
 
93
  def trivia_qa_index(self, file_path, index_name):
 
16
  import json
17
  import os
18
  from collections import defaultdict
19
+ from concurrent.futures import ThreadPoolExecutor
20
+ from copy import deepcopy
21
+
22
  from api.db import LLMType
23
  from api.db.services.llm_service import LLMBundle
24
  from api.db.services.knowledgebase_service import KnowledgebaseService
25
  from api.settings import retrievaler
26
  from api.utils import get_uuid
27
+ from api.utils.file_utils import get_project_base_directory
28
  from rag.nlp import tokenize, search
29
  from rag.utils.es_conn import ELASTICSEARCH
30
  from ranx import evaluate
 
67
  d["q_%d_vec" % len(v)] = v
68
  return docs
69
 
70
+ @staticmethod
71
+ def init_kb(index_name):
72
+ idxnm = search.index_name(index_name)
73
+ if ELASTICSEARCH.indexExist(idxnm):
74
+ ELASTICSEARCH.deleteIdx(search.index_name(index_name))
75
+
76
+ return ELASTICSEARCH.createIdx(idxnm, json.load(
77
+ open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
78
+
79
  def ms_marco_index(self, file_path, index_name):
80
  qrels = defaultdict(dict)
81
  texts = defaultdict(dict)
82
  docs = []
83
  filelist = os.listdir(file_path)
84
+ self.init_kb(index_name)
85
+
86
+ max_workers = int(os.environ.get('MAX_WORKERS', 3))
87
+ exe = ThreadPoolExecutor(max_workers=max_workers)
88
+ threads = []
89
+
90
+ def slow_actions(es_docs, idx_nm):
91
+ es_docs = self.embedding(es_docs)
92
+ ELASTICSEARCH.bulk(es_docs, idx_nm)
93
+ return True
94
+
95
  for dir in filelist:
96
  data = pd.read_parquet(os.path.join(file_path, dir))
97
+ for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + dir):
98
 
99
  query = data.iloc[i]['query']
100
  for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
 
106
  texts[d["id"]] = text
107
  qrels[query][d["id"]] = int(rel)
108
  if len(docs) >= 32:
109
+ threads.append(
110
+ exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
111
  docs = []
112
 
113
+ threads.append(
114
+ exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
115
+
116
+ for i in tqdm(range(len(threads)), colour="red", desc="Indexing:" + dir):
117
+ if not threads[i].result().output:
118
+ print("Indexing error...")
119
+
120
  return qrels, texts
121
 
122
  def trivia_qa_index(self, file_path, index_name):
rag/nlp/term_weight.py CHANGED
@@ -227,7 +227,7 @@ class Dealer:
227
  idf2 = np.array([idf(df(t), 1000000000) for t in tks])
228
  wts = (0.3 * idf1 + 0.7 * idf2) * \
229
  np.array([ner(t) * postag(t) for t in tks])
230
- tw = zip(tks, wts)
231
  else:
232
  for tk in tks:
233
  tt = self.tokenMerge(self.pretoken(tk, True))
 
227
  idf2 = np.array([idf(df(t), 1000000000) for t in tks])
228
  wts = (0.3 * idf1 + 0.7 * idf2) * \
229
  np.array([ner(t) * postag(t) for t in tks])
230
+ tw = list(zip(tks, wts))
231
  else:
232
  for tk in tks:
233
  tt = self.tokenMerge(self.pretoken(tk, True))