Kevin Hu
commited on
Commit
·
692cc99
1
Parent(s):
0b587a0
fix: term weight issue (#3294)
Browse files### What problem does this PR solve?
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- rag/benchmark.py +34 -5
- rag/nlp/term_weight.py +1 -1
rag/benchmark.py
CHANGED
@@ -16,11 +16,15 @@
|
|
16 |
import json
|
17 |
import os
|
18 |
from collections import defaultdict
|
|
|
|
|
|
|
19 |
from api.db import LLMType
|
20 |
from api.db.services.llm_service import LLMBundle
|
21 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
22 |
from api.settings import retrievaler
|
23 |
from api.utils import get_uuid
|
|
|
24 |
from rag.nlp import tokenize, search
|
25 |
from rag.utils.es_conn import ELASTICSEARCH
|
26 |
from ranx import evaluate
|
@@ -63,14 +67,34 @@ class Benchmark:
|
|
63 |
d["q_%d_vec" % len(v)] = v
|
64 |
return docs
|
65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
def ms_marco_index(self, file_path, index_name):
|
67 |
qrels = defaultdict(dict)
|
68 |
texts = defaultdict(dict)
|
69 |
docs = []
|
70 |
filelist = os.listdir(file_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
for dir in filelist:
|
72 |
data = pd.read_parquet(os.path.join(file_path, dir))
|
73 |
-
for i in tqdm(range(len(data)), colour="green", desc="
|
74 |
|
75 |
query = data.iloc[i]['query']
|
76 |
for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
|
@@ -82,12 +106,17 @@ class Benchmark:
|
|
82 |
texts[d["id"]] = text
|
83 |
qrels[query][d["id"]] = int(rel)
|
84 |
if len(docs) >= 32:
|
85 |
-
|
86 |
-
|
87 |
docs = []
|
88 |
|
89 |
-
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
91 |
return qrels, texts
|
92 |
|
93 |
def trivia_qa_index(self, file_path, index_name):
|
|
|
16 |
import json
|
17 |
import os
|
18 |
from collections import defaultdict
|
19 |
+
from concurrent.futures import ThreadPoolExecutor
|
20 |
+
from copy import deepcopy
|
21 |
+
|
22 |
from api.db import LLMType
|
23 |
from api.db.services.llm_service import LLMBundle
|
24 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
25 |
from api.settings import retrievaler
|
26 |
from api.utils import get_uuid
|
27 |
+
from api.utils.file_utils import get_project_base_directory
|
28 |
from rag.nlp import tokenize, search
|
29 |
from rag.utils.es_conn import ELASTICSEARCH
|
30 |
from ranx import evaluate
|
|
|
67 |
d["q_%d_vec" % len(v)] = v
|
68 |
return docs
|
69 |
|
70 |
+
@staticmethod
|
71 |
+
def init_kb(index_name):
|
72 |
+
idxnm = search.index_name(index_name)
|
73 |
+
if ELASTICSEARCH.indexExist(idxnm):
|
74 |
+
ELASTICSEARCH.deleteIdx(search.index_name(index_name))
|
75 |
+
|
76 |
+
return ELASTICSEARCH.createIdx(idxnm, json.load(
|
77 |
+
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
|
78 |
+
|
79 |
def ms_marco_index(self, file_path, index_name):
|
80 |
qrels = defaultdict(dict)
|
81 |
texts = defaultdict(dict)
|
82 |
docs = []
|
83 |
filelist = os.listdir(file_path)
|
84 |
+
self.init_kb(index_name)
|
85 |
+
|
86 |
+
max_workers = int(os.environ.get('MAX_WORKERS', 3))
|
87 |
+
exe = ThreadPoolExecutor(max_workers=max_workers)
|
88 |
+
threads = []
|
89 |
+
|
90 |
+
def slow_actions(es_docs, idx_nm):
|
91 |
+
es_docs = self.embedding(es_docs)
|
92 |
+
ELASTICSEARCH.bulk(es_docs, idx_nm)
|
93 |
+
return True
|
94 |
+
|
95 |
for dir in filelist:
|
96 |
data = pd.read_parquet(os.path.join(file_path, dir))
|
97 |
+
for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + dir):
|
98 |
|
99 |
query = data.iloc[i]['query']
|
100 |
for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
|
|
|
106 |
texts[d["id"]] = text
|
107 |
qrels[query][d["id"]] = int(rel)
|
108 |
if len(docs) >= 32:
|
109 |
+
threads.append(
|
110 |
+
exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
|
111 |
docs = []
|
112 |
|
113 |
+
threads.append(
|
114 |
+
exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
|
115 |
+
|
116 |
+
for i in tqdm(range(len(threads)), colour="red", desc="Indexing:" + dir):
|
117 |
+
if not threads[i].result().output:
|
118 |
+
print("Indexing error...")
|
119 |
+
|
120 |
return qrels, texts
|
121 |
|
122 |
def trivia_qa_index(self, file_path, index_name):
|
rag/nlp/term_weight.py
CHANGED
@@ -227,7 +227,7 @@ class Dealer:
|
|
227 |
idf2 = np.array([idf(df(t), 1000000000) for t in tks])
|
228 |
wts = (0.3 * idf1 + 0.7 * idf2) * \
|
229 |
np.array([ner(t) * postag(t) for t in tks])
|
230 |
-
tw = zip(tks, wts)
|
231 |
else:
|
232 |
for tk in tks:
|
233 |
tt = self.tokenMerge(self.pretoken(tk, True))
|
|
|
227 |
idf2 = np.array([idf(df(t), 1000000000) for t in tks])
|
228 |
wts = (0.3 * idf1 + 0.7 * idf2) * \
|
229 |
np.array([ner(t) * postag(t) for t in tks])
|
230 |
+
tw = list(zip(tks, wts))
|
231 |
else:
|
232 |
for tk in tks:
|
233 |
tt = self.tokenMerge(self.pretoken(tk, True))
|