Kevin Hu committed
Commit · 95863fc
Parent(s): 7fcd041

search between multiple indices for team function (#3079)

### What problem does this PR solve?

#2834

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- agent/component/__init__.py +1 -0
- agent/component/generate.py +9 -5
- agent/component/invoke.py +22 -3
- api/db/services/dialog_service.py +3 -1
- deepdoc/parser/html_parser.py +2 -0
- rag/nlp/search.py +12 -6
- rag/utils/es_conn.py +4 -2
agent/component/__init__.py
CHANGED

```diff
@@ -29,6 +29,7 @@ from .jin10 import Jin10, Jin10Param
 from .tushare import TuShare, TuShareParam
 from .akshare import AkShare, AkShareParam
 from .crawler import Crawler, CrawlerParam
+from .invoke import Invoke, InvokeParam


 def component_class(class_name):
```
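The one-line import matters because the component factory resolves classes by name from the package namespace. A minimal sketch of that registry pattern (the `globals()` lookup is an assumption about how `component_class` works, not RAGFlow's verified implementation):

```python
# Minimal sketch of a name-keyed component registry, assuming
# component_class resolves names from the module namespace; the
# actual RAGFlow implementation may differ.
class Invoke:  # stand-in for agent.component.invoke.Invoke
    pass

def component_class(class_name: str):
    # Without "from .invoke import Invoke" in __init__.py, the name
    # would not be present in the namespace and lookup would fail.
    return globals()[class_name]

assert component_class("Invoke") is Invoke
```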
agent/component/generate.py
CHANGED

```diff
@@ -17,6 +17,7 @@ import re
 from functools import partial
 import pandas as pd
 from api.db import LLMType
+from api.db.services.dialog_service import message_fit_in
 from api.db.services.llm_service import LLMBundle
 from api.settings import retrievaler
 from agent.component.base import ComponentBase, ComponentParamBase
@@ -112,7 +113,7 @@ class Generate(ComponentBase):

         kwargs["input"] = input
         for n, v in kwargs.items():
-            prompt = re.sub(r"\{%s\}" % re.escape(n), str(v), prompt)
+            prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)

         downstreams = self._canvas.get_component(self._id)["downstream"]
         if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
@@ -124,8 +125,10 @@
                    retrieval_res["empty_response"]) else "Nothing found in knowledgebase!", "reference": []}
             return pd.DataFrame([res])

-        ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size),
-                            self._param.gen_conf())
+        msg = self._canvas.get_history(self._param.message_history_window_size)
+        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
+        ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
+
         if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
             res = self.set_cite(retrieval_res, ans)
             return pd.DataFrame([res])
@@ -141,9 +144,10 @@
             self.set_output(res)
             return

+        msg = self._canvas.get_history(self._param.message_history_window_size)
+        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
         answer = ""
-        for ans in chat_mdl.chat_streamly(prompt, self._canvas.get_history(self._param.message_history_window_size),
-                                          self._param.gen_conf()):
+        for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
             res = {"content": ans, "reference": []}
             answer = ans
             yield res
```
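Two things change here. Wrapping the substituted value in `re.escape(str(v))` stops backslashes and group references in user-supplied values from being interpreted by `re.sub`. And both the blocking and streaming paths now prepend the rendered prompt as a system message to the canvas history and call `message_fit_in` to trim the conversation to roughly 97% of the model's context window. A self-contained sketch of that fit-in idea, with a crude word-count tokenizer standing in for the real one (hypothetical names; the actual helper lives in `api/db/services/dialog_service.py` and may count and trim differently):

```python
# Hedged sketch of history trimming in the spirit of message_fit_in:
# keep the system prompt and drop the oldest turns until the estimated
# token count fits within the budget.
def count_tokens(msg: dict) -> int:
    return len(msg["content"].split())  # crude stand-in for a tokenizer

def fit_in(msgs: list, max_length: int):
    system, history = msgs[0], msgs[1:]
    while history and sum(count_tokens(m) for m in [system, *history]) > max_length:
        history.pop(0)  # the oldest turn is dropped first
    fitted = [system, *history]
    return sum(count_tokens(m) for m in fitted), fitted

prompt = "You are a helpful assistant."
history = [{"role": "user", "content": "hi"},
           {"role": "assistant", "content": "hello, how can I help?"}]
_, msg = fit_in([{"role": "system", "content": prompt}, *history], int(4096 * 0.97))
```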
agent/component/invoke.py
CHANGED

```diff
@@ -14,10 +14,10 @@
 # limitations under the License.
 #
 import json
+import re
 from abc import ABC
-
 import requests
-
+from deepdoc.parser import HtmlParser
 from agent.component.base import ComponentBase, ComponentParamBase


@@ -34,11 +34,13 @@ class InvokeParam(ComponentParamBase):
         self.variables = []
         self.url = ""
         self.timeout = 60
+        self.clean_html = False

     def check(self):
         self.check_valid_value(self.method.lower(), "Type of content from the crawler", ['get', 'post', 'put'])
         self.check_empty(self.url, "End point URL")
         self.check_positive_integer(self.timeout, "Timeout time in second")
+        self.check_boolean(self.clean_html, "Clean HTML")


 class Invoke(ComponentBase, ABC):
@@ -63,7 +65,7 @@ class Invoke(ComponentBase, ABC):
         if self._param.headers:
             headers = json.loads(self._param.headers)
         proxies = None
-        if self._param.proxy:
+        if re.sub(r"https?:?/?/?", "", self._param.proxy):
             proxies = {"http": self._param.proxy, "https": self._param.proxy}

         if method == 'get':
@@ -72,6 +74,10 @@
                                     headers=headers,
                                     proxies=proxies,
                                     timeout=self._param.timeout)
+            if self._param.clean_html:
+                sections = HtmlParser()(None, response.content)
+                return Invoke.be_output("\n".join(sections))
+
             return Invoke.be_output(response.text)

         if method == 'put':
@@ -80,5 +86,18 @@
                                     headers=headers,
                                     proxies=proxies,
                                     timeout=self._param.timeout)
+            if self._param.clean_html:
+                sections = HtmlParser()(None, response.content)
+                return Invoke.be_output("\n".join(sections))
+            return Invoke.be_output(response.text)

+        if method == 'post':
+            response = requests.post(url=url,
+                                     json=args,
+                                     headers=headers,
+                                     proxies=proxies,
+                                     timeout=self._param.timeout)
+            if self._param.clean_html:
+                sections = HtmlParser()(None, response.content)
+                return Invoke.be_output("\n".join(sections))
         return Invoke.be_output(response.text)
```
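With the new `clean_html` flag (plus the added `post` branch), the component can strip a fetched page down to readable text instead of returning raw markup. A standalone sketch of that behaviour using only the standard library and `requests` (RAGFlow's `HtmlParser` in `deepdoc/parser/html_parser.py` additionally does chardet-based encoding detection and readability extraction, so this is an approximation):

```python
# Self-contained sketch of the clean_html behaviour: fetch a URL and
# reduce the HTML body to its visible text, one chunk per line.
from html.parser import HTMLParser
import requests

class TextExtractor(HTMLParser):
    def __init__(self):
        super().__init__()
        self.chunks = []

    def handle_data(self, data):
        # Collect only non-empty text nodes, skipping pure whitespace.
        if data.strip():
            self.chunks.append(data.strip())

def fetch_clean(url: str, timeout: int = 60) -> str:
    response = requests.get(url=url, timeout=timeout)
    extractor = TextExtractor()
    extractor.feed(response.text)
    return "\n".join(extractor.chunks)
```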
api/db/services/dialog_service.py
CHANGED

```diff
@@ -205,7 +205,9 @@ def chat(dialog, messages, stream=True, **kwargs):
     else:
         if prompt_config.get("keyword", False):
             questions[-1] += keyword_extraction(chat_mdl, questions[-1])
-        kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
+
+        tenant_ids = list(set([kb.tenant_id for kb in kbs]))
+        kbinfos = retr.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
                                  dialog.similarity_threshold,
                                  dialog.vector_similarity_weight,
                                  doc_ids=attachments,
```
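This is the core of the team feature: instead of searching only the dialog owner's index, `chat` now collects the distinct tenants that own the dialog's knowledge bases and hands all of them to retrieval. A small sketch of the dedup step (`KB` is a hypothetical stand-in for the knowledgebase records loaded earlier in `chat`):

```python
# Sketch of the tenant-id fan-out: deduplicate the owning tenants of
# the dialog's knowledge bases before retrieval.
from dataclasses import dataclass

@dataclass
class KB:
    tenant_id: str

kbs = [KB("tenant-a"), KB("tenant-b"), KB("tenant-a")]
tenant_ids = list(set(kb.tenant_id for kb in kbs))  # order is not guaranteed
assert sorted(tenant_ids) == ["tenant-a", "tenant-b"]
```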
deepdoc/parser/html_parser.py
CHANGED

```diff
@@ -16,11 +16,13 @@ import readability
 import html_text
 import chardet

+
 def get_encoding(file):
     with open(file,'rb') as f:
         tmp = chardet.detect(f.read())
     return tmp['encoding']

+
 class RAGFlowHtmlParser:
     def __call__(self, fnm, binary=None):
         txt = ""
```
rag/nlp/search.py
CHANGED

```diff
@@ -79,7 +79,7 @@ class Dealer:
                              Q("bool", must_not=Q("range", available_int={"lt": 1})))
         return bqry

-    def search(self, req, idxnm, emb_mdl=None, highlight=False):
+    def search(self, req, idxnms, emb_mdl=None, highlight=False):
         qst = req.get("question", "")
         bqry, keywords = self.qryr.question(qst, min_match="30%")
         bqry = self._add_filters(bqry, req)
@@ -134,7 +134,7 @@
             del s["highlight"]
             q_vec = s["knn"]["query_vector"]
         es_logger.info("【Q】: {}".format(json.dumps(s)))
-        res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
+        res = self.es.search(deepcopy(s), idxnms=idxnms, timeout="600s", src=src)
         es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
         if self.es.getTotal(res) == 0 and "knn" in s:
             bqry, _ = self.qryr.question(qst, min_match="10%")
@@ -144,7 +144,7 @@
             s["query"] = bqry.to_dict()
             s["knn"]["filter"] = bqry.to_dict()
             s["knn"]["similarity"] = 0.17
-            res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
+            res = self.es.search(s, idxnms=idxnms, timeout="600s", src=src)
             es_logger.info("【Q】: {}".format(json.dumps(s)))

         kwds = set([])
@@ -358,20 +358,26 @@
                                             rag_tokenizer.tokenize(ans).split(" "),
                                             rag_tokenizer.tokenize(inst).split(" "))

-    def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
+    def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
                   vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None, highlight=False):
         ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
         if not question:
             return ranks
+
         RERANK_PAGE_LIMIT = 3
         req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size*RERANK_PAGE_LIMIT, 128),
                "question": question, "vector": True, "topk": top,
                "similarity": similarity_threshold,
                "available_int": 1}
+
         if page > RERANK_PAGE_LIMIT:
             req["page"] = page
             req["size"] = page_size
-        sres = self.search(req, index_name(tenant_id), embd_mdl, highlight)
+
+        if isinstance(tenant_ids, str):
+            tenant_ids = tenant_ids.split(",")
+
+        sres = self.search(req, [index_name(tid) for tid in tenant_ids], embd_mdl, highlight)
         ranks["total"] = sres.total

         if page <= RERANK_PAGE_LIMIT:
@@ -467,7 +473,7 @@
         s = Search()
         s = s.query(Q("match", doc_id=doc_id))[0:max_count]
         s = s.to_dict()
-        es_res = self.es.search(s, idxnm=index_name(tenant_id), timeout="600s", src=fields)
+        es_res = self.es.search(s, idxnms=index_name(tenant_id), timeout="600s", src=fields)
         res = []
         for index, chunk in enumerate(es_res['hits']['hits']):
             res.append({fld: chunk['_source'].get(fld) for fld in fields})
```
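`Dealer.retrieval` now tolerates either a list of tenant ids or a comma-separated string and expands each id to one Elasticsearch index name before searching. A sketch of that expansion (the `ragflow_{tenant_id}` naming pattern is an assumption for illustration; see `index_name` in the RAGFlow codebase for the real mapping):

```python
# Sketch of the multi-index expansion in Dealer.retrieval: normalize the
# tenant ids, then derive one ES index name per tenant.
def index_name(tenant_id: str) -> str:
    return f"ragflow_{tenant_id}"  # assumed naming convention

def expand_indices(tenant_ids) -> list:
    if isinstance(tenant_ids, str):
        tenant_ids = tenant_ids.split(",")
    return [index_name(tid) for tid in tenant_ids]

assert expand_indices("a,b") == ["ragflow_a", "ragflow_b"]
assert expand_indices(["a"]) == ["ragflow_a"]
```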
rag/utils/es_conn.py
CHANGED

```diff
@@ -221,12 +221,14 @@ class ESConnection:

         return False

-    def search(self, q, idxnm=None, src=False, timeout="2s"):
+    def search(self, q, idxnms=None, src=False, timeout="2s"):
         if not isinstance(q, dict):
             q = Search().query(q).to_dict()
+        if isinstance(idxnms, str):
+            idxnms = idxnms.split(",")
         for i in range(3):
             try:
-                res = self.es.search(index=(self.idxnm if not idxnm else idxnm),
+                res = self.es.search(index=(self.idxnm if not idxnms else idxnms),
                                      body=q,
                                      timeout=timeout,
                                      # search_type="dfs_query_then_fetch",
```