Kevin Hu
commited on
Commit
·
c786948
1
Parent(s):
ba2f307
make highlight friendly to English (#2417)
Browse files### What problem does this PR solve?
#2415
### Type of change
- [x] Performance Improvement
- rag/nlp/__init__.py +1 -1
- rag/nlp/search.py +19 -17
- rag/utils/__init__.py +2 -2
rag/nlp/__init__.py
CHANGED
@@ -214,7 +214,7 @@ def is_english(texts):
|
|
214 |
eng = 0
|
215 |
if not texts: return False
|
216 |
for t in texts:
|
217 |
-
if re.match(r"[a-zA-Z]
|
218 |
eng += 1
|
219 |
if eng / len(texts) > 0.8:
|
220 |
return True
|
|
|
214 |
eng = 0
|
215 |
if not texts: return False
|
216 |
for t in texts:
|
217 |
+
if re.match(r"[ `a-zA-Z.,':;/\"?<>!\(\)-]", t.strip()):
|
218 |
eng += 1
|
219 |
if eng / len(texts) > 0.8:
|
220 |
return True
|
rag/nlp/search.py
CHANGED
@@ -24,7 +24,7 @@ from dataclasses import dataclass
|
|
24 |
|
25 |
from rag.settings import es_logger
|
26 |
from rag.utils import rmSpace
|
27 |
-
from rag.nlp import rag_tokenizer, query
|
28 |
import numpy as np
|
29 |
|
30 |
|
@@ -164,7 +164,7 @@ class Dealer:
|
|
164 |
ids=self.es.getDocIds(res),
|
165 |
query_vector=q_vec,
|
166 |
aggregation=aggs,
|
167 |
-
highlight=self.getHighlight(res),
|
168 |
field=self.getFields(res, src),
|
169 |
keywords=list(kwds)
|
170 |
)
|
@@ -175,26 +175,28 @@ class Dealer:
|
|
175 |
bkts = res["aggregations"]["aggs_" + g]["buckets"]
|
176 |
return [(b["key"], b["doc_count"]) for b in bkts]
|
177 |
|
178 |
-
def getHighlight(self, res):
|
179 |
-
def rmspace(line):
|
180 |
-
eng = set(list("qwertyuioplkjhgfdsazxcvbnm"))
|
181 |
-
r = []
|
182 |
-
for t in line.split(" "):
|
183 |
-
if not t:
|
184 |
-
continue
|
185 |
-
if len(r) > 0 and len(
|
186 |
-
t) > 0 and r[-1][-1] in eng and t[0] in eng:
|
187 |
-
r.append(" ")
|
188 |
-
r.append(t)
|
189 |
-
r = "".join(r)
|
190 |
-
return r
|
191 |
-
|
192 |
ans = {}
|
193 |
for d in res["hits"]["hits"]:
|
194 |
hlts = d.get("highlight")
|
195 |
if not hlts:
|
196 |
continue
|
197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
return ans
|
199 |
|
200 |
def getFields(self, sres, flds):
|
|
|
24 |
|
25 |
from rag.settings import es_logger
|
26 |
from rag.utils import rmSpace
|
27 |
+
from rag.nlp import rag_tokenizer, query, is_english
|
28 |
import numpy as np
|
29 |
|
30 |
|
|
|
164 |
ids=self.es.getDocIds(res),
|
165 |
query_vector=q_vec,
|
166 |
aggregation=aggs,
|
167 |
+
highlight=self.getHighlight(res, keywords, "content_with_weight"),
|
168 |
field=self.getFields(res, src),
|
169 |
keywords=list(kwds)
|
170 |
)
|
|
|
175 |
bkts = res["aggregations"]["aggs_" + g]["buckets"]
|
176 |
return [(b["key"], b["doc_count"]) for b in bkts]
|
177 |
|
178 |
+
def getHighlight(self, res, keywords, fieldnm):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
ans = {}
|
180 |
for d in res["hits"]["hits"]:
|
181 |
hlts = d.get("highlight")
|
182 |
if not hlts:
|
183 |
continue
|
184 |
+
txt = "...".join([a for a in list(hlts.items())[0][1]])
|
185 |
+
if not is_english(txt.split(" ")):
|
186 |
+
ans[d["_id"]] = txt
|
187 |
+
continue
|
188 |
+
|
189 |
+
txt = d["_source"][fieldnm]
|
190 |
+
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
|
191 |
+
txts = []
|
192 |
+
for w in keywords:
|
193 |
+
txt = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", txt, flags=re.IGNORECASE|re.MULTILINE)
|
194 |
+
|
195 |
+
for t in re.split(r"[.?!;\n]", txt):
|
196 |
+
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE): continue
|
197 |
+
txts.append(t)
|
198 |
+
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
|
199 |
+
|
200 |
return ans
|
201 |
|
202 |
def getFields(self, sres, flds):
|
rag/utils/__init__.py
CHANGED
@@ -32,8 +32,8 @@ def singleton(cls, *args, **kw):
|
|
32 |
|
33 |
|
34 |
def rmSpace(txt):
|
35 |
-
txt = re.sub(r"([^a-z0-9
|
36 |
-
return re.sub(r"([^ ]) +([^a-z0-9
|
37 |
|
38 |
|
39 |
def findMaxDt(fnm):
|
|
|
32 |
|
33 |
|
34 |
def rmSpace(txt):
|
35 |
+
txt = re.sub(r"([^a-z0-9.,\)>]) +([^ ])", r"\1\2", txt, flags=re.IGNORECASE)
|
36 |
+
return re.sub(r"([^ ]) +([^a-z0-9.,\(<])", r"\1\2", txt, flags=re.IGNORECASE)
|
37 |
|
38 |
|
39 |
def findMaxDt(fnm):
|