Kevin Hu
committed on
Commit
·
c786948
1
Parent(s):
ba2f307
make highlight friendly to English (#2417)
Browse files
### What problem does this PR solve?
#2415
### Type of change
- [x] Performance Improvement
- rag/nlp/__init__.py +1 -1
- rag/nlp/search.py +19 -17
- rag/utils/__init__.py +2 -2
rag/nlp/__init__.py
CHANGED
|
@@ -214,7 +214,7 @@ def is_english(texts):
|
|
| 214 |
eng = 0
|
| 215 |
if not texts: return False
|
| 216 |
for t in texts:
|
| 217 |
-
if re.match(r"[a-zA-Z]
|
| 218 |
eng += 1
|
| 219 |
if eng / len(texts) > 0.8:
|
| 220 |
return True
|
|
|
|
| 214 |
eng = 0
|
| 215 |
if not texts: return False
|
| 216 |
for t in texts:
|
| 217 |
+
if re.match(r"[ `a-zA-Z.,':;/\"?<>!\(\)-]", t.strip()):
|
| 218 |
eng += 1
|
| 219 |
if eng / len(texts) > 0.8:
|
| 220 |
return True
|
rag/nlp/search.py
CHANGED
|
@@ -24,7 +24,7 @@ from dataclasses import dataclass
|
|
| 24 |
|
| 25 |
from rag.settings import es_logger
|
| 26 |
from rag.utils import rmSpace
|
| 27 |
-
from rag.nlp import rag_tokenizer, query
|
| 28 |
import numpy as np
|
| 29 |
|
| 30 |
|
|
@@ -164,7 +164,7 @@ class Dealer:
|
|
| 164 |
ids=self.es.getDocIds(res),
|
| 165 |
query_vector=q_vec,
|
| 166 |
aggregation=aggs,
|
| 167 |
-
highlight=self.getHighlight(res),
|
| 168 |
field=self.getFields(res, src),
|
| 169 |
keywords=list(kwds)
|
| 170 |
)
|
|
@@ -175,26 +175,28 @@ class Dealer:
|
|
| 175 |
bkts = res["aggregations"]["aggs_" + g]["buckets"]
|
| 176 |
return [(b["key"], b["doc_count"]) for b in bkts]
|
| 177 |
|
| 178 |
-
def getHighlight(self, res):
|
| 179 |
-
def rmspace(line):
|
| 180 |
-
eng = set(list("qwertyuioplkjhgfdsazxcvbnm"))
|
| 181 |
-
r = []
|
| 182 |
-
for t in line.split(" "):
|
| 183 |
-
if not t:
|
| 184 |
-
continue
|
| 185 |
-
if len(r) > 0 and len(
|
| 186 |
-
t) > 0 and r[-1][-1] in eng and t[0] in eng:
|
| 187 |
-
r.append(" ")
|
| 188 |
-
r.append(t)
|
| 189 |
-
r = "".join(r)
|
| 190 |
-
return r
|
| 191 |
-
|
| 192 |
ans = {}
|
| 193 |
for d in res["hits"]["hits"]:
|
| 194 |
hlts = d.get("highlight")
|
| 195 |
if not hlts:
|
| 196 |
continue
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
return ans
|
| 199 |
|
| 200 |
def getFields(self, sres, flds):
|
|
|
|
| 24 |
|
| 25 |
from rag.settings import es_logger
|
| 26 |
from rag.utils import rmSpace
|
| 27 |
+
from rag.nlp import rag_tokenizer, query, is_english
|
| 28 |
import numpy as np
|
| 29 |
|
| 30 |
|
|
|
|
| 164 |
ids=self.es.getDocIds(res),
|
| 165 |
query_vector=q_vec,
|
| 166 |
aggregation=aggs,
|
| 167 |
+
highlight=self.getHighlight(res, keywords, "content_with_weight"),
|
| 168 |
field=self.getFields(res, src),
|
| 169 |
keywords=list(kwds)
|
| 170 |
)
|
|
|
|
| 175 |
bkts = res["aggregations"]["aggs_" + g]["buckets"]
|
| 176 |
return [(b["key"], b["doc_count"]) for b in bkts]
|
| 177 |
|
| 178 |
+
def getHighlight(self, res, keywords, fieldnm):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
ans = {}
|
| 180 |
for d in res["hits"]["hits"]:
|
| 181 |
hlts = d.get("highlight")
|
| 182 |
if not hlts:
|
| 183 |
continue
|
| 184 |
+
txt = "...".join([a for a in list(hlts.items())[0][1]])
|
| 185 |
+
if not is_english(txt.split(" ")):
|
| 186 |
+
ans[d["_id"]] = txt
|
| 187 |
+
continue
|
| 188 |
+
|
| 189 |
+
txt = d["_source"][fieldnm]
|
| 190 |
+
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
|
| 191 |
+
txts = []
|
| 192 |
+
for w in keywords:
|
| 193 |
+
txt = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", txt, flags=re.IGNORECASE|re.MULTILINE)
|
| 194 |
+
|
| 195 |
+
for t in re.split(r"[.?!;\n]", txt):
|
| 196 |
+
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE): continue
|
| 197 |
+
txts.append(t)
|
| 198 |
+
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
|
| 199 |
+
|
| 200 |
return ans
|
| 201 |
|
| 202 |
def getFields(self, sres, flds):
|
rag/utils/__init__.py
CHANGED
|
@@ -32,8 +32,8 @@ def singleton(cls, *args, **kw):
|
|
| 32 |
|
| 33 |
|
| 34 |
def rmSpace(txt):
|
| 35 |
-
txt = re.sub(r"([^a-z0-9
|
| 36 |
-
return re.sub(r"([^ ]) +([^a-z0-9
|
| 37 |
|
| 38 |
|
| 39 |
def findMaxDt(fnm):
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
def rmSpace(txt):
|
| 35 |
+
txt = re.sub(r"([^a-z0-9.,\)>]) +([^ ])", r"\1\2", txt, flags=re.IGNORECASE)
|
| 36 |
+
return re.sub(r"([^ ]) +([^a-z0-9.,\(<])", r"\1\2", txt, flags=re.IGNORECASE)
|
| 37 |
|
| 38 |
|
| 39 |
def findMaxDt(fnm):
|