Kevin Hu committed on
Commit 95863fc · 1 Parent(s): 7fcd041

search between multiple indices for team function (#3079)


### What problem does this PR solve?

#2834
### Type of change

- [x] New Feature (non-breaking change which adds functionality)

agent/component/__init__.py CHANGED
```diff
@@ -29,6 +29,7 @@ from .jin10 import Jin10, Jin10Param
 from .tushare import TuShare, TuShareParam
 from .akshare import AkShare, AkShareParam
 from .crawler import Crawler, CrawlerParam
+from .invoke import Invoke, InvokeParam


 def component_class(class_name):
```
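
With the new import in place, the Invoke tool can be looked up by name alongside the other components. Below is a minimal sketch of how a name-based lookup like `component_class` could work — the actual body of `component_class` is not shown in this diff, so the resolution mechanism here is an assumption:

```python
# Hypothetical sketch: resolve a component class by name from the package
# namespace. The real component_class in agent/component/__init__.py may
# resolve differently; this only illustrates why the import matters.
def component_class(class_name):
    return globals()[class_name]  # raises KeyError until the class is imported

# After `from .invoke import Invoke, InvokeParam`, both names resolve:
# component_class("Invoke") -> <class 'Invoke'>
# component_class("InvokeParam") -> <class 'InvokeParam'>
```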
agent/component/generate.py CHANGED
```diff
@@ -17,6 +17,7 @@ import re
 from functools import partial
 import pandas as pd
 from api.db import LLMType
+from api.db.services.dialog_service import message_fit_in
 from api.db.services.llm_service import LLMBundle
 from api.settings import retrievaler
 from agent.component.base import ComponentBase, ComponentParamBase
@@ -112,7 +113,7 @@ class Generate(ComponentBase):

         kwargs["input"] = input
         for n, v in kwargs.items():
-            prompt = re.sub(r"\{%s\}" % re.escape(n), str(v), prompt)
+            prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)

         downstreams = self._canvas.get_component(self._id)["downstream"]
         if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
@@ -124,8 +125,10 @@
                 retrieval_res["empty_response"]) else "Nothing found in knowledgebase!", "reference": []}
             return pd.DataFrame([res])

-        ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size),
-                            self._param.gen_conf())
+        msg = self._canvas.get_history(self._param.message_history_window_size)
+        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
+        ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
+
         if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
             res = self.set_cite(retrieval_res, ans)
             return pd.DataFrame([res])
@@ -141,9 +144,10 @@
             self.set_output(res)
             return

+        msg = self._canvas.get_history(self._param.message_history_window_size)
+        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
         answer = ""
-        for ans in chat_mdl.chat_streamly(prompt, self._canvas.get_history(self._param.message_history_window_size),
-                                          self._param.gen_conf()):
+        for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
             res = {"content": ans, "reference": []}
             answer = ans
             yield res
```
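
The substantive change here is that both the blocking and streaming paths now pass the conversation through `message_fit_in` so the system prompt plus history fits within roughly 97% of the model's context window. The implementation of `message_fit_in` is not part of this diff; the sketch below shows one plausible shape for such a helper, with a crude whitespace token count standing in for the real tokenizer:

```python
# Hedged sketch of a message_fit_in-style helper. Assumptions: it returns
# (token_count, messages), always keeps the system prompt, and drops the
# oldest non-system turns first. The real helper in
# api/db/services/dialog_service.py may trim differently.
def message_fit_in_sketch(msgs, max_tokens):
    def count(m):
        return len(m["content"].split())  # placeholder for a real tokenizer

    total = sum(count(m) for m in msgs)
    if total <= max_tokens:
        return total, msgs

    kept_tail = []
    budget = max_tokens - count(msgs[0])  # reserve room for the system prompt
    for m in reversed(msgs[1:]):          # keep the most recent turns
        if count(m) > budget:
            break
        budget -= count(m)
        kept_tail.append(m)
    fitted = [msgs[0], *reversed(kept_tail)]
    return sum(count(m) for m in fitted), fitted
```

After fitting, `msg[0]["content"]` is the system prompt and `msg[1:]` is the trimmed history, which matches how the new `chat` and `chat_streamly` calls unpack the list.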
agent/component/invoke.py CHANGED
```diff
@@ -14,10 +14,10 @@
 # limitations under the License.
 #
 import json
+import re
 from abc import ABC
-
 import requests
-
+from deepdoc.parser import HtmlParser
 from agent.component.base import ComponentBase, ComponentParamBase


@@ -34,11 +34,13 @@ class InvokeParam(ComponentParamBase):
         self.variables = []
         self.url = ""
         self.timeout = 60
+        self.clean_html = False

     def check(self):
         self.check_valid_value(self.method.lower(), "Type of content from the crawler", ['get', 'post', 'put'])
         self.check_empty(self.url, "End point URL")
         self.check_positive_integer(self.timeout, "Timeout time in second")
+        self.check_boolean(self.clean_html, "Clean HTML")


 class Invoke(ComponentBase, ABC):
@@ -63,7 +65,7 @@ class Invoke(ComponentBase, ABC):
         if self._param.headers:
             headers = json.loads(self._param.headers)
         proxies = None
-        if self._param.proxy:
+        if re.sub(r"https?:?/?/?", "", self._param.proxy):
            proxies = {"http": self._param.proxy, "https": self._param.proxy}

         if method == 'get':
@@ -72,6 +74,10 @@
                                     headers=headers,
                                     proxies=proxies,
                                     timeout=self._param.timeout)
+            if self._param.clean_html:
+                sections = HtmlParser()(None, response.content)
+                return Invoke.be_output("\n".join(sections))
+
             return Invoke.be_output(response.text)

         if method == 'put':
@@ -80,5 +86,18 @@
                                     headers=headers,
                                     proxies=proxies,
                                     timeout=self._param.timeout)
+            if self._param.clean_html:
+                sections = HtmlParser()(None, response.content)
+                return Invoke.be_output("\n".join(sections))
+            return Invoke.be_output(response.text)

+        if method == 'post':
+            response = requests.post(url=url,
+                                     json=args,
+                                     headers=headers,
+                                     proxies=proxies,
+                                     timeout=self._param.timeout)
+            if self._param.clean_html:
+                sections = HtmlParser()(None, response.content)
+                return Invoke.be_output("\n".join(sections))
             return Invoke.be_output(response.text)
```
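
Two behavioral changes ride along with the new `post` branch: the optional `clean_html` flag runs the response body through deepdoc's `HtmlParser` to extract readable text, and the proxy guard now strips a bare scheme prefix so that a value like `"http://"` no longer counts as a configured proxy. A quick illustration of the new guard:

```python
import re

# The new proxy check: a value that is nothing but a scheme prefix
# reduces to the empty string and is treated as "no proxy configured".
for proxy in ["", "http://", "https://", "http://127.0.0.1:7890"]:
    configured = bool(re.sub(r"https?:?/?/?", "", proxy))
    print(repr(proxy), "->", configured)
# ''                      -> False
# 'http://'               -> False
# 'https://'              -> False
# 'http://127.0.0.1:7890' -> True
```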
api/db/services/dialog_service.py CHANGED
```diff
@@ -205,7 +205,9 @@ def chat(dialog, messages, stream=True, **kwargs):
     else:
         if prompt_config.get("keyword", False):
             questions[-1] += keyword_extraction(chat_mdl, questions[-1])
-        kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
+
+        tenant_ids = list(set([kb.tenant_id for kb in kbs]))
+        kbinfos = retr.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
                                  dialog.similarity_threshold,
                                  dialog.vector_similarity_weight,
                                  doc_ids=attachments,
```
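
This is the core of the team feature on the chat path: a dialog's knowledge bases may now be owned by different tenants, so retrieval collects every owning tenant's id instead of assuming everything lives under `dialog.tenant_id`. A minimal sketch of the deduplication, assuming `kbs` is the list of knowledge-base records already loaded earlier in `chat` (the record type below is a hypothetical stand-in):

```python
from dataclasses import dataclass

@dataclass
class KnowledgebaseStub:  # hypothetical stand-in for the real KB model
    id: str
    tenant_id: str

kbs = [
    KnowledgebaseStub("kb1", "tenant_a"),
    KnowledgebaseStub("kb2", "tenant_b"),
    KnowledgebaseStub("kb3", "tenant_a"),  # same owner as kb1
]
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
# -> ["tenant_a", "tenant_b"] (order unspecified, duplicates removed)
```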
deepdoc/parser/html_parser.py CHANGED
```diff
@@ -16,11 +16,13 @@ import readability
 import html_text
 import chardet

+
 def get_encoding(file):
     with open(file,'rb') as f:
         tmp = chardet.detect(f.read())
     return tmp['encoding']

+
 class RAGFlowHtmlParser:
     def __call__(self, fnm, binary=None):
         txt = ""
```
rag/nlp/search.py CHANGED
```diff
@@ -79,7 +79,7 @@ class Dealer:
                     Q("bool", must_not=Q("range", available_int={"lt": 1})))
         return bqry

-    def search(self, req, idxnm, emb_mdl=None, highlight=False):
+    def search(self, req, idxnms, emb_mdl=None, highlight=False):
         qst = req.get("question", "")
         bqry, keywords = self.qryr.question(qst, min_match="30%")
         bqry = self._add_filters(bqry, req)
@@ -134,7 +134,7 @@
             del s["highlight"]
         q_vec = s["knn"]["query_vector"]
         es_logger.info("【Q】: {}".format(json.dumps(s)))
-        res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
+        res = self.es.search(deepcopy(s), idxnms=idxnms, timeout="600s", src=src)
         es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
         if self.es.getTotal(res) == 0 and "knn" in s:
             bqry, _ = self.qryr.question(qst, min_match="10%")
@@ -144,7 +144,7 @@
             s["query"] = bqry.to_dict()
             s["knn"]["filter"] = bqry.to_dict()
             s["knn"]["similarity"] = 0.17
-            res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
+            res = self.es.search(s, idxnms=idxnms, timeout="600s", src=src)
             es_logger.info("【Q】: {}".format(json.dumps(s)))

         kwds = set([])
@@ -358,20 +358,26 @@
                                         rag_tokenizer.tokenize(ans).split(" "),
                                         rag_tokenizer.tokenize(inst).split(" "))

-    def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
+    def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
                   vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None, highlight=False):
         ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
         if not question:
             return ranks
+
         RERANK_PAGE_LIMIT = 3
         req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size*RERANK_PAGE_LIMIT, 128),
                "question": question, "vector": True, "topk": top,
                "similarity": similarity_threshold,
                "available_int": 1}
+
         if page > RERANK_PAGE_LIMIT:
             req["page"] = page
             req["size"] = page_size
-        sres = self.search(req, index_name(tenant_id), embd_mdl, highlight)
+
+        if isinstance(tenant_ids, str):
+            tenant_ids = tenant_ids.split(",")
+
+        sres = self.search(req, [index_name(tid) for tid in tenant_ids], embd_mdl, highlight)
         ranks["total"] = sres.total

         if page <= RERANK_PAGE_LIMIT:
@@ -467,7 +473,7 @@
         s = Search()
         s = s.query(Q("match", doc_id=doc_id))[0:max_count]
         s = s.to_dict()
-        es_res = self.es.search(s, idxnm=index_name(tenant_id), timeout="600s", src=fields)
+        es_res = self.es.search(s, idxnms=index_name(tenant_id), timeout="600s", src=fields)
         res = []
         for index, chunk in enumerate(es_res['hits']['hits']):
             res.append({fld: chunk['_source'].get(fld) for fld in fields})
```
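
`retrieval` now accepts either a list of tenant ids or a comma-separated string and fans the query out over one index per tenant. A sketch of the fan-out, assuming `index_name` maps a tenant id to its per-tenant Elasticsearch index (the exact naming scheme is defined elsewhere in the codebase; the format below is illustrative):

```python
def index_name(tenant_id: str) -> str:
    # Assumed per-tenant index naming, for illustration only.
    return f"ragflow_{tenant_id}"

tenant_ids = "tenant_a,tenant_b"          # callers may still pass a string
if isinstance(tenant_ids, str):
    tenant_ids = tenant_ids.split(",")

idxnms = [index_name(tid) for tid in tenant_ids]
# -> ["ragflow_tenant_a", "ragflow_tenant_b"], searched as one multi-index query
```

Note that the last hunk still passes a single index-name string; that remains valid because the updated `ESConnection.search` normalizes a string (with or without commas) into a list.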
rag/utils/es_conn.py CHANGED
```diff
@@ -221,12 +221,14 @@ class ESConnection:

         return False

-    def search(self, q, idxnm=None, src=False, timeout="2s"):
+    def search(self, q, idxnms=None, src=False, timeout="2s"):
         if not isinstance(q, dict):
             q = Search().query(q).to_dict()
+        if isinstance(idxnms, str):
+            idxnms = idxnms.split(",")
         for i in range(3):
             try:
-                res = self.es.search(index=(self.idxnm if not idxnm else idxnm),
+                res = self.es.search(index=(self.idxnm if not idxnms else idxnms),
                                      body=q,
                                      timeout=timeout,
                                      # search_type="dfs_query_then_fetch",
```
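
This works because the official `elasticsearch` Python client accepts either a single index name or a list of names for the `index` argument, which Elasticsearch executes as a multi-index search. A self-contained sketch against a local cluster (host and index names are placeholders):

```python
from elasticsearch import Elasticsearch

# Multi-index search: `index` may be a name or a list of names.
es = Elasticsearch(["http://localhost:9200"])
res = es.search(index=["ragflow_tenant_a", "ragflow_tenant_b"],
                body={"query": {"match_all": {}}},
                timeout="2s")
print(res["hits"]["total"])
```

One caveat worth noting: normalizing a comma-separated string with `split(",")` also covers the single-index case, since a string without commas splits into a one-element list.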