Kevin Hu committed on
Commit
82bdd9f
·
1 Parent(s): c39b5d3

add search TAB backend api (#2375)

Browse files

### What problem does this PR solve?
#2247

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

api/apps/chunk_app.py CHANGED
@@ -58,7 +58,7 @@ def list_chunk():
58
  }
59
  if "available_int" in req:
60
  query["available_int"] = int(req["available_int"])
61
- sres = retrievaler.search(query, search.index_name(tenant_id))
62
  res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
63
  for id in sres.ids:
64
  d = {
@@ -259,12 +259,25 @@ def retrieval_test():
259
  size = int(req.get("size", 30))
260
  question = req["question"]
261
  kb_id = req["kb_id"]
 
262
  doc_ids = req.get("doc_ids", [])
263
  similarity_threshold = float(req.get("similarity_threshold", 0.2))
264
  vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
265
  top = int(req.get("top_k", 1024))
 
266
  try:
267
- e, kb = KnowledgebaseService.get_by_id(kb_id)
 
 
 
 
 
 
 
 
 
 
 
268
  if not e:
269
  return get_data_error_result(retmsg="Knowledgebase not found!")
270
 
@@ -281,9 +294,9 @@ def retrieval_test():
281
  question += keyword_extraction(chat_mdl, question)
282
 
283
  retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
284
- ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
285
  similarity_threshold, vector_similarity_weight, top,
286
- doc_ids, rerank_mdl=rerank_mdl)
287
  for c in ranks["chunks"]:
288
  if "vector" in c:
289
  del c["vector"]
 
58
  }
59
  if "available_int" in req:
60
  query["available_int"] = int(req["available_int"])
61
+ sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
62
  res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
63
  for id in sres.ids:
64
  d = {
 
259
  size = int(req.get("size", 30))
260
  question = req["question"]
261
  kb_id = req["kb_id"]
262
+ if isinstance(kb_id, str): kb_id = [kb_id]
263
  doc_ids = req.get("doc_ids", [])
264
  similarity_threshold = float(req.get("similarity_threshold", 0.2))
265
  vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
266
  top = int(req.get("top_k", 1024))
267
+
268
  try:
269
+ tenants = UserTenantService.query(user_id=current_user.id)
270
+ for kid in kb_id:
271
+ for tenant in tenants:
272
+ if KnowledgebaseService.query(
273
+ tenant_id=tenant.tenant_id, id=kid):
274
+ break
275
+ else:
276
+ return get_json_result(
277
+ data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
278
+ retcode=RetCode.OPERATING_ERROR)
279
+
280
+ e, kb = KnowledgebaseService.get_by_id(kb_id[0])
281
  if not e:
282
  return get_data_error_result(retmsg="Knowledgebase not found!")
283
 
 
294
  question += keyword_extraction(chat_mdl, question)
295
 
296
  retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
297
+ ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_id, page, size,
298
  similarity_threshold, vector_similarity_weight, top,
299
+ doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
300
  for c in ranks["chunks"]:
301
  if "vector" in c:
302
  del c["vector"]
api/apps/conversation_app.py CHANGED
@@ -14,19 +14,22 @@
14
  # limitations under the License.
15
  #
16
  import json
 
17
  from copy import deepcopy
18
 
19
- from db.services.user_service import UserTenantService
20
  from flask import request, Response
21
  from flask_login import login_required, current_user
22
 
23
  from api.db import LLMType
24
- from api.db.services.dialog_service import DialogService, ConversationService, chat
25
- from api.db.services.llm_service import LLMBundle, TenantService
26
- from api.settings import RetCode
 
27
  from api.utils import get_uuid
28
  from api.utils.api_utils import get_json_result
29
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 
30
 
31
 
32
  @manager.route('/set', methods=['POST'])
@@ -286,3 +289,86 @@ def thumbup():
286
 
287
  ConversationService.update_by_id(conv["id"], conv)
288
  return get_json_result(data=conv)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # limitations under the License.
15
  #
16
  import json
17
+ import re
18
  from copy import deepcopy
19
 
20
+ from api.db.services.user_service import UserTenantService
21
  from flask import request, Response
22
  from flask_login import login_required, current_user
23
 
24
  from api.db import LLMType
25
+ from api.db.services.dialog_service import DialogService, ConversationService, chat, ask
26
+ from api.db.services.knowledgebase_service import KnowledgebaseService
27
+ from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
28
+ from api.settings import RetCode, retrievaler
29
  from api.utils import get_uuid
30
  from api.utils.api_utils import get_json_result
31
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
32
+ from graphrag.mind_map_extractor import MindMapExtractor
33
 
34
 
35
  @manager.route('/set', methods=['POST'])
 
289
 
290
  ConversationService.update_by_id(conv["id"], conv)
291
  return get_json_result(data=conv)
292
+
293
+
294
+ @manager.route('/ask', methods=['POST'])
295
+ @login_required
296
+ @validate_request("question", "kb_ids")
297
+ def ask_about():
298
+ req = request.json
299
+ uid = current_user.id
300
+ def stream():
301
+ nonlocal req, uid
302
+ try:
303
+ for ans in ask(req["question"], req["kb_ids"], uid):
304
+ yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
305
+ except Exception as e:
306
+ yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
307
+ "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
308
+ ensure_ascii=False) + "\n\n"
309
+ yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
310
+
311
+ resp = Response(stream(), mimetype="text/event-stream")
312
+ resp.headers.add_header("Cache-control", "no-cache")
313
+ resp.headers.add_header("Connection", "keep-alive")
314
+ resp.headers.add_header("X-Accel-Buffering", "no")
315
+ resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
316
+ return resp
317
+
318
+
319
+ @manager.route('/mindmap', methods=['POST'])
320
+ @login_required
321
+ @validate_request("question", "kb_ids")
322
+ def mindmap():
323
+ req = request.json
324
+ kb_ids = req["kb_ids"]
325
+ e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
326
+ if not e:
327
+ return get_data_error_result(retmsg="Knowledgebase not found!")
328
+
329
+ embd_mdl = TenantLLMService.model_instance(
330
+ kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
331
+ chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
332
+ ranks = retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
333
+ 0.3, 0.3, aggs=False)
334
+ mindmap = MindMapExtractor(chat_mdl)
335
+ mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
336
+ return get_json_result(data=mind_map)
337
+
338
+
339
+ @manager.route('/related_questions', methods=['POST'])
340
+ @login_required
341
+ @validate_request("question")
342
+ def related_questions():
343
+ req = request.json
344
+ question = req["question"]
345
+ chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
346
+ prompt = """
347
+ Objective: To generate search terms related to the user's search keywords, helping users find more valuable information.
348
+ Instructions:
349
+ - Based on the keywords provided by the user, generate 5-10 related search terms.
350
+ - Each search term should be directly or indirectly related to the keyword, guiding the user to find more valuable information.
351
+ - Use common, general terms as much as possible, avoiding obscure words or technical jargon.
352
+ - Keep the term length between 2-4 words, concise and clear.
353
+ - DO NOT translate, use the language of the original keywords.
354
+
355
+ ### Example:
356
+ Keywords: Chinese football
357
+ Related search terms:
358
+ 1. Current status of Chinese football
359
+ 2. Reform of Chinese football
360
+ 3. Youth training of Chinese football
361
+ 4. Chinese football in the Asian Cup
362
+ 5. Chinese football in the World Cup
363
+
364
+ Reason:
365
+ - When searching, users often only use one or two keywords, making it difficult to fully express their information needs.
366
+ - Generating related search terms can help users dig deeper into relevant information and improve search efficiency.
367
+ - At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
368
+
369
+ """
370
+ ans = chat_mdl.chat(prompt, [{"role": "user", "content": f"""
371
+ Keywords: {question}
372
+ Related search terms:
373
+ """}], {"temperature": 0.9})
374
+ return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
api/db/services/dialog_service.py CHANGED
@@ -210,7 +210,7 @@ def chat(dialog, messages, stream=True, **kwargs):
210
  answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
211
  done_tm = timer()
212
  prompt += "\n### Elapsed\n - Retrieval: %.1f ms\n - LLM: %.1f ms"%((retrieval_tm-st)*1000, (done_tm-st)*1000)
213
- return {"answer": answer, "reference": refs, "prompt": re.sub(r"\n", "<br/>", prompt)}
214
 
215
  if stream:
216
  last_ans = ""
 
210
  answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
211
  done_tm = timer()
212
  prompt += "\n### Elapsed\n - Retrieval: %.1f ms\n - LLM: %.1f ms"%((retrieval_tm-st)*1000, (done_tm-st)*1000)
213
+ return {"answer": answer, "reference": refs, "prompt": prompt}
214
 
215
  if stream:
216
  last_ans = ""
api/db/services/llm_service.py CHANGED
@@ -190,7 +190,7 @@ class LLMBundle(object):
190
  tenant_id, llm_type, llm_name, lang=lang)
191
  assert self.mdl, "Can't find mole for {}/{}/{}".format(
192
  tenant_id, llm_type, llm_name)
193
- self.max_length = 512
194
  for lm in LLMService.query(llm_name=llm_name):
195
  self.max_length = lm.max_tokens
196
  break
 
190
  tenant_id, llm_type, llm_name, lang=lang)
191
  assert self.mdl, "Can't find mole for {}/{}/{}".format(
192
  tenant_id, llm_type, llm_name)
193
+ self.max_length = 8192
194
  for lm in LLMService.query(llm_name=llm_name):
195
  self.max_length = lm.max_tokens
196
  break
graphrag/search.py CHANGED
@@ -23,7 +23,7 @@ from rag.nlp.search import Dealer
23
 
24
 
25
  class KGSearch(Dealer):
26
- def search(self, req, idxnm, emb_mdl=None):
27
  def merge_into_first(sres, title=""):
28
  df,texts = [],[]
29
  for d in sres["hits"]["hits"]:
 
23
 
24
 
25
  class KGSearch(Dealer):
26
+ def search(self, req, idxnm, emb_mdl=None, highlight=False):
27
  def merge_into_first(sres, title=""):
28
  df,texts = [],[]
29
  for d in sres["hits"]["hits"]:
rag/nlp/search.py CHANGED
@@ -79,9 +79,9 @@ class Dealer:
79
  Q("bool", must_not=Q("range", available_int={"lt": 1})))
80
  return bqry
81
 
82
- def search(self, req, idxnm, emb_mdl=None):
83
  qst = req.get("question", "")
84
- bqry, keywords = self.qryr.question(qst)
85
  bqry = self._add_filters(bqry, req)
86
  bqry.boost = 0.05
87
 
@@ -130,7 +130,7 @@ class Dealer:
130
  qst, emb_mdl, req.get(
131
  "similarity", 0.1), topk)
132
  s["knn"]["filter"] = bqry.to_dict()
133
- if "highlight" in s:
134
  del s["highlight"]
135
  q_vec = s["knn"]["query_vector"]
136
  es_logger.info("【Q】: {}".format(json.dumps(s)))
@@ -356,7 +356,7 @@ class Dealer:
356
  rag_tokenizer.tokenize(inst).split(" "))
357
 
358
  def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
359
- vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None):
360
  ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
361
  if not question:
362
  return ranks
@@ -364,7 +364,7 @@ class Dealer:
364
  "question": question, "vector": True, "topk": top,
365
  "similarity": similarity_threshold,
366
  "available_int": 1}
367
- sres = self.search(req, index_name(tenant_id), embd_mdl)
368
 
369
  if rerank_mdl:
370
  sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
@@ -405,6 +405,8 @@ class Dealer:
405
  "vector": self.trans2floats(sres.field[id].get("q_%d_vec" % dim, "\t".join(["0"] * dim))),
406
  "positions": sres.field[id].get("position_int", "").split("\t")
407
  }
 
 
408
  if len(d["positions"]) % 5 == 0:
409
  poss = []
410
  for i in range(0, len(d["positions"]), 5):
 
79
  Q("bool", must_not=Q("range", available_int={"lt": 1})))
80
  return bqry
81
 
82
+ def search(self, req, idxnm, emb_mdl=None, highlight=False):
83
  qst = req.get("question", "")
84
+ bqry, keywords = self.qryr.question(qst, min_match="30%")
85
  bqry = self._add_filters(bqry, req)
86
  bqry.boost = 0.05
87
 
 
130
  qst, emb_mdl, req.get(
131
  "similarity", 0.1), topk)
132
  s["knn"]["filter"] = bqry.to_dict()
133
+ if not highlight and "highlight" in s:
134
  del s["highlight"]
135
  q_vec = s["knn"]["query_vector"]
136
  es_logger.info("【Q】: {}".format(json.dumps(s)))
 
356
  rag_tokenizer.tokenize(inst).split(" "))
357
 
358
  def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
359
+ vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None, highlight=False):
360
  ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
361
  if not question:
362
  return ranks
 
364
  "question": question, "vector": True, "topk": top,
365
  "similarity": similarity_threshold,
366
  "available_int": 1}
367
+ sres = self.search(req, index_name(tenant_id), embd_mdl, highlight)
368
 
369
  if rerank_mdl:
370
  sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
 
405
  "vector": self.trans2floats(sres.field[id].get("q_%d_vec" % dim, "\t".join(["0"] * dim))),
406
  "positions": sres.field[id].get("position_int", "").split("\t")
407
  }
408
+ if highlight:
409
+ d["highlight"] = rmSpace(sres.highlight[id])
410
  if len(d["positions"]) % 5 == 0:
411
  poss = []
412
  for i in range(0, len(d["positions"]), 5):