jinhai-2012 Kevin Hu commited on
Commit
75c7829
·
1 Parent(s): 2a6f834

Refactor ask decorator (#4116)

Browse files

### What problem does this PR solve?

Refactor ask decorator

### Type of change

- [x] Refactoring

---------

Signed-off-by: jinhai <[email protected]>
Co-authored-by: Kevin Hu <[email protected]>

api/db/services/dialog_service.py CHANGED
@@ -23,7 +23,7 @@ from copy import deepcopy
23
  from timeit import default_timer as timer
24
  import datetime
25
  from datetime import timedelta
26
- from api.db import LLMType, ParserType,StatusEnum
27
  from api.db.db_models import Dialog, DB
28
  from api.db.services.common_service import CommonService
29
  from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -41,14 +41,14 @@ class DialogService(CommonService):
41
  @classmethod
42
  @DB.connection_context()
43
  def get_list(cls, tenant_id,
44
- page_number, items_per_page, orderby, desc, id , name):
45
  chats = cls.model.select()
46
  if id:
47
  chats = chats.where(cls.model.id == id)
48
  if name:
49
  chats = chats.where(cls.model.name == name)
50
  chats = chats.where(
51
- (cls.model.tenant_id == tenant_id)
52
  & (cls.model.status == StatusEnum.VALID.value)
53
  )
54
  if desc:
@@ -137,25 +137,37 @@ def kb_prompt(kbinfos, max_tokens):
137
 
138
  def chat(dialog, messages, stream=True, **kwargs):
139
  assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
140
- st = timer()
141
- llm_id, fid = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
142
- llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
 
 
 
 
 
 
143
  if not llm:
144
- llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
145
- TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=fid)
 
146
  if not llm:
147
  raise LookupError("LLM(%s) not found" % dialog.llm_id)
148
  max_tokens = 8192
149
  else:
150
  max_tokens = llm[0].max_tokens
 
 
 
151
  kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
152
- embd_nms = list(set([kb.embd_id for kb in kbs]))
153
- if len(embd_nms) != 1:
154
  yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
155
  return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
156
 
157
- is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
158
- retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
 
 
159
 
160
  questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
161
  attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
@@ -165,15 +177,21 @@ def chat(dialog, messages, stream=True, **kwargs):
165
  if "doc_ids" in m:
166
  attachments.extend(m["doc_ids"])
167
 
168
- embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
 
 
169
  if not embd_mdl:
170
- raise LookupError("Embedding model(%s) not found" % embd_nms[0])
 
 
171
 
172
  if llm_id2llm_type(dialog.llm_id) == "image2text":
173
  chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
174
  else:
175
  chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
176
 
 
 
177
  prompt_config = dialog.prompt_config
178
  field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
179
  tts_mdl = None
@@ -200,32 +218,35 @@ def chat(dialog, messages, stream=True, **kwargs):
200
  questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
201
  else:
202
  questions = questions[-1:]
203
- refineQ_tm = timer()
204
- keyword_tm = timer()
205
 
206
  rerank_mdl = None
207
  if dialog.rerank_id:
208
  rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
209
 
210
- for _ in range(len(questions) // 2):
211
- questions.append(questions[-1])
 
212
  if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
213
  kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
214
  else:
215
  if prompt_config.get("keyword", False):
216
  questions[-1] += keyword_extraction(chat_mdl, questions[-1])
217
- keyword_tm = timer()
218
 
219
  tenant_ids = list(set([kb.tenant_id for kb in kbs]))
220
- kbinfos = retr.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
221
- dialog.similarity_threshold,
222
- dialog.vector_similarity_weight,
223
- doc_ids=attachments,
224
- top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
 
 
 
225
  knowledges = kb_prompt(kbinfos, max_tokens)
226
  logging.debug(
227
  "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
228
- retrieval_tm = timer()
229
 
230
  if not knowledges and prompt_config.get("empty_response"):
231
  empty_res = prompt_config["empty_response"]
@@ -249,17 +270,20 @@ def chat(dialog, messages, stream=True, **kwargs):
249
  max_tokens - used_token_count)
250
 
251
  def decorate_answer(answer):
252
- nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_tm
 
 
 
253
  refs = []
254
  if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
255
- answer, idx = retr.insert_citations(answer,
256
- [ck["content_ltks"]
257
- for ck in kbinfos["chunks"]],
258
- [ck["vector"]
259
- for ck in kbinfos["chunks"]],
260
- embd_mdl,
261
- tkweight=1 - dialog.vector_similarity_weight,
262
- vtweight=dialog.vector_similarity_weight)
263
  idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
264
  recall_docs = [
265
  d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
@@ -274,10 +298,20 @@ def chat(dialog, messages, stream=True, **kwargs):
274
 
275
  if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
276
  answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
277
- done_tm = timer()
278
- prompt += "\n\n### Elapsed\n - Refine Question: %.1f ms\n - Keywords: %.1f ms\n - Retrieval: %.1f ms\n - LLM: %.1f ms" % (
279
- (refineQ_tm - st) * 1000, (keyword_tm - refineQ_tm) * 1000, (retrieval_tm - keyword_tm) * 1000,
280
- (done_tm - retrieval_tm) * 1000)
 
 
 
 
 
 
 
 
 
 
281
  return {"answer": answer, "reference": refs, "prompt": prompt}
282
 
283
  if stream:
@@ -304,15 +338,15 @@ def chat(dialog, messages, stream=True, **kwargs):
304
 
305
 
306
  def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
307
- sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL"
308
- user_promt = """
309
- 表名:{}
310
- 数据库表字段说明如下:
311
  {}
312
 
313
- 问题如下:
314
  {}
315
- 请写出SQL, 且只要SQL,不要有其他说明及文字。
316
  """.format(
317
  index_name(tenant_id),
318
  "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
@@ -321,10 +355,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
321
  tried_times = 0
322
 
323
  def get_table():
324
- nonlocal sys_prompt, user_promt, question, tried_times
325
- sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
326
  "temperature": 0.06})
327
- logging.debug(f"{question} ==> {user_promt} get SQL: {sql}")
328
  sql = re.sub(r"[\r\n]+", " ", sql.lower())
329
  sql = re.sub(r".*select ", "select ", sql.lower())
330
  sql = re.sub(r" +", " ", sql)
@@ -352,21 +386,23 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
352
  if tbl is None:
353
  return None
354
  if tbl.get("error") and tried_times <= 2:
355
- user_promt = """
356
- 表名:{}
357
- 数据库表字段说明如下:
358
  {}
359
-
360
- 问题如下:
361
  {}
 
 
362
 
363
- 你上一次给出的错误SQL如下:
364
  {}
365
 
366
- 后台报错如下:
367
  {}
368
 
369
- 请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
370
  """.format(
371
  index_name(tenant_id),
372
  "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
@@ -381,21 +417,21 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
381
 
382
  docid_idx = set([ii for ii, c in enumerate(
383
  tbl["columns"]) if c["name"] == "doc_id"])
384
- docnm_idx = set([ii for ii, c in enumerate(
385
  tbl["columns"]) if c["name"] == "docnm_kwd"])
386
- clmn_idx = [ii for ii in range(
387
- len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
388
 
389
- # compose markdown table
390
- clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
391
- tbl["columns"][i]["name"])) for i in
392
- clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
393
 
394
- line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
395
  ("|------|" if docid_idx and docid_idx else "")
396
 
397
  rows = ["|" +
398
- "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
399
  "|" for r in tbl["rows"]]
400
  rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
401
  if quota:
@@ -404,24 +440,24 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
404
  rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
405
  rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
406
 
407
- if not docid_idx or not docnm_idx:
408
  logging.warning("SQL missing field: " + sql)
409
  return {
410
- "answer": "\n".join([clmns, line, rows]),
411
  "reference": {"chunks": [], "doc_aggs": []},
412
  "prompt": sys_prompt
413
  }
414
 
415
  docid_idx = list(docid_idx)[0]
416
- docnm_idx = list(docnm_idx)[0]
417
  doc_aggs = {}
418
  for r in tbl["rows"]:
419
  if r[docid_idx] not in doc_aggs:
420
- doc_aggs[r[docid_idx]] = {"doc_name": r[docnm_idx], "count": 0}
421
  doc_aggs[r[docid_idx]]["count"] += 1
422
  return {
423
- "answer": "\n".join([clmns, line, rows]),
424
- "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
425
  "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
426
  doc_aggs.items()]},
427
  "prompt": sys_prompt
@@ -492,7 +528,7 @@ Requirements:
492
  kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
493
  if isinstance(kwd, tuple):
494
  kwd = kwd[0]
495
- if kwd.find("**ERROR**") >=0:
496
  return ""
497
  return kwd
498
 
@@ -605,16 +641,16 @@ def tts(tts_mdl, text):
605
 
606
  def ask(question, kb_ids, tenant_id):
607
  kbs = KnowledgebaseService.get_by_ids(kb_ids)
608
- embd_nms = list(set([kb.embd_id for kb in kbs]))
609
 
610
- is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
611
- retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
612
 
613
- embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
614
  chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
615
  max_tokens = chat_mdl.max_length
616
  tenant_ids = list(set([kb.tenant_id for kb in kbs]))
617
- kbinfos = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
618
  knowledges = kb_prompt(kbinfos, max_tokens)
619
  prompt = """
620
  Role: You're a smart assistant. Your name is Miss R.
@@ -636,14 +672,14 @@ def ask(question, kb_ids, tenant_id):
636
 
637
  def decorate_answer(answer):
638
  nonlocal knowledges, kbinfos, prompt
639
- answer, idx = retr.insert_citations(answer,
640
- [ck["content_ltks"]
641
- for ck in kbinfos["chunks"]],
642
- [ck["vector"]
643
- for ck in kbinfos["chunks"]],
644
- embd_mdl,
645
- tkweight=0.7,
646
- vtweight=0.3)
647
  idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
648
  recall_docs = [
649
  d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
@@ -664,4 +700,3 @@ def ask(question, kb_ids, tenant_id):
664
  answer = ans
665
  yield {"answer": answer, "reference": {}}
666
  yield decorate_answer(answer)
667
-
 
23
  from timeit import default_timer as timer
24
  import datetime
25
  from datetime import timedelta
26
+ from api.db import LLMType, ParserType, StatusEnum
27
  from api.db.db_models import Dialog, DB
28
  from api.db.services.common_service import CommonService
29
  from api.db.services.knowledgebase_service import KnowledgebaseService
 
41
  @classmethod
42
  @DB.connection_context()
43
  def get_list(cls, tenant_id,
44
+ page_number, items_per_page, orderby, desc, id, name):
45
  chats = cls.model.select()
46
  if id:
47
  chats = chats.where(cls.model.id == id)
48
  if name:
49
  chats = chats.where(cls.model.name == name)
50
  chats = chats.where(
51
+ (cls.model.tenant_id == tenant_id)
52
  & (cls.model.status == StatusEnum.VALID.value)
53
  )
54
  if desc:
 
137
 
138
  def chat(dialog, messages, stream=True, **kwargs):
139
  assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
140
+
141
+ chat_start_ts = timer()
142
+
143
+ # Get llm model name and model provider name
144
+ llm_id, model_provider = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
145
+
146
+ # Get llm model instance by model and provide name
147
+ llm = LLMService.query(llm_name=llm_id) if not model_provider else LLMService.query(llm_name=llm_id, fid=model_provider)
148
+
149
  if not llm:
150
+ # Model name is provided by tenant, but not system built-in
151
+ llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not model_provider else \
152
+ TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=model_provider)
153
  if not llm:
154
  raise LookupError("LLM(%s) not found" % dialog.llm_id)
155
  max_tokens = 8192
156
  else:
157
  max_tokens = llm[0].max_tokens
158
+
159
+ check_llm_ts = timer()
160
+
161
  kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
162
+ embedding_list = list(set([kb.embd_id for kb in kbs]))
163
+ if len(embedding_list) != 1:
164
  yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
165
  return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
166
 
167
+ embedding_model_name = embedding_list[0]
168
+
169
+ is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
170
+ retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
171
 
172
  questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
173
  attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
 
177
  if "doc_ids" in m:
178
  attachments.extend(m["doc_ids"])
179
 
180
+ create_retriever_ts = timer()
181
+
182
+ embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
183
  if not embd_mdl:
184
+ raise LookupError("Embedding model(%s) not found" % embedding_model_name)
185
+
186
+ bind_embedding_ts = timer()
187
 
188
  if llm_id2llm_type(dialog.llm_id) == "image2text":
189
  chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
190
  else:
191
  chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
192
 
193
+ bind_llm_ts = timer()
194
+
195
  prompt_config = dialog.prompt_config
196
  field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
197
  tts_mdl = None
 
218
  questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
219
  else:
220
  questions = questions[-1:]
221
+
222
+ refine_question_ts = timer()
223
 
224
  rerank_mdl = None
225
  if dialog.rerank_id:
226
  rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
227
 
228
+ bind_reranker_ts = timer()
229
+ generate_keyword_ts = bind_reranker_ts
230
+
231
  if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
232
  kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
233
  else:
234
  if prompt_config.get("keyword", False):
235
  questions[-1] += keyword_extraction(chat_mdl, questions[-1])
236
+ generate_keyword_ts = timer()
237
 
238
  tenant_ids = list(set([kb.tenant_id for kb in kbs]))
239
+ kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
240
+ dialog.similarity_threshold,
241
+ dialog.vector_similarity_weight,
242
+ doc_ids=attachments,
243
+ top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
244
+
245
+ retrieval_ts = timer()
246
+
247
  knowledges = kb_prompt(kbinfos, max_tokens)
248
  logging.debug(
249
  "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
 
250
 
251
  if not knowledges and prompt_config.get("empty_response"):
252
  empty_res = prompt_config["empty_response"]
 
270
  max_tokens - used_token_count)
271
 
272
  def decorate_answer(answer):
273
+ nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts
274
+
275
+ finish_chat_ts = timer()
276
+
277
  refs = []
278
  if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
279
+ answer, idx = retriever.insert_citations(answer,
280
+ [ck["content_ltks"]
281
+ for ck in kbinfos["chunks"]],
282
+ [ck["vector"]
283
+ for ck in kbinfos["chunks"]],
284
+ embd_mdl,
285
+ tkweight=1 - dialog.vector_similarity_weight,
286
+ vtweight=dialog.vector_similarity_weight)
287
  idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
288
  recall_docs = [
289
  d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
 
298
 
299
  if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
300
  answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
301
+ finish_chat_ts = timer()
302
+
303
+ total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
304
+ check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
305
+ create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
306
+ bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
307
+ bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
308
+ refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
309
+ bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
310
+ generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
311
+ retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
312
+ generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
313
+
314
+ prompt = f"{prompt} ### Elapsed\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
315
  return {"answer": answer, "reference": refs, "prompt": prompt}
316
 
317
  if stream:
 
338
 
339
 
340
  def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
341
+ sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
342
+ user_prompt = """
343
+ Table name: {};
344
+ Table of database fields are as follows:
345
  {}
346
 
347
+ Question are as follows:
348
  {}
349
+ Please write the SQL, only SQL, without any other explanations or text.
350
  """.format(
351
  index_name(tenant_id),
352
  "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
 
355
  tried_times = 0
356
 
357
  def get_table():
358
+ nonlocal sys_prompt, user_prompt, question, tried_times
359
+ sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
360
  "temperature": 0.06})
361
+ logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
362
  sql = re.sub(r"[\r\n]+", " ", sql.lower())
363
  sql = re.sub(r".*select ", "select ", sql.lower())
364
  sql = re.sub(r" +", " ", sql)
 
386
  if tbl is None:
387
  return None
388
  if tbl.get("error") and tried_times <= 2:
389
+ user_prompt = """
390
+ Table name: {};
391
+ Table of database fields are as follows:
392
  {}
393
+
394
+ Question are as follows:
395
  {}
396
+ Please write the SQL, only SQL, without any other explanations or text.
397
+
398
 
399
+ The SQL error you provided last time is as follows:
400
  {}
401
 
402
+ Error issued by database as follows:
403
  {}
404
 
405
+ Please correct the error and write SQL again, only SQL, without any other explanations or text.
406
  """.format(
407
  index_name(tenant_id),
408
  "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
 
417
 
418
  docid_idx = set([ii for ii, c in enumerate(
419
  tbl["columns"]) if c["name"] == "doc_id"])
420
+ doc_name_idx = set([ii for ii, c in enumerate(
421
  tbl["columns"]) if c["name"] == "docnm_kwd"])
422
+ column_idx = [ii for ii in range(
423
+ len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
424
 
425
+ # compose Markdown table
426
+ columns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
427
+ tbl["columns"][i]["name"])) for i in
428
+ column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
429
 
430
+ line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \
431
  ("|------|" if docid_idx and docid_idx else "")
432
 
433
  rows = ["|" +
434
+ "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") +
435
  "|" for r in tbl["rows"]]
436
  rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
437
  if quota:
 
440
  rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
441
  rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
442
 
443
+ if not docid_idx or not doc_name_idx:
444
  logging.warning("SQL missing field: " + sql)
445
  return {
446
+ "answer": "\n".join([columns, line, rows]),
447
  "reference": {"chunks": [], "doc_aggs": []},
448
  "prompt": sys_prompt
449
  }
450
 
451
  docid_idx = list(docid_idx)[0]
452
+ doc_name_idx = list(doc_name_idx)[0]
453
  doc_aggs = {}
454
  for r in tbl["rows"]:
455
  if r[docid_idx] not in doc_aggs:
456
+ doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
457
  doc_aggs[r[docid_idx]]["count"] += 1
458
  return {
459
+ "answer": "\n".join([columns, line, rows]),
460
+ "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
461
  "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
462
  doc_aggs.items()]},
463
  "prompt": sys_prompt
 
528
  kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
529
  if isinstance(kwd, tuple):
530
  kwd = kwd[0]
531
+ if kwd.find("**ERROR**") >= 0:
532
  return ""
533
  return kwd
534
 
 
641
 
642
  def ask(question, kb_ids, tenant_id):
643
  kbs = KnowledgebaseService.get_by_ids(kb_ids)
644
+ embedding_list = list(set([kb.embd_id for kb in kbs]))
645
 
646
+ is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
647
+ retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
648
 
649
+ embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
650
  chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
651
  max_tokens = chat_mdl.max_length
652
  tenant_ids = list(set([kb.tenant_id for kb in kbs]))
653
+ kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
654
  knowledges = kb_prompt(kbinfos, max_tokens)
655
  prompt = """
656
  Role: You're a smart assistant. Your name is Miss R.
 
672
 
673
  def decorate_answer(answer):
674
  nonlocal knowledges, kbinfos, prompt
675
+ answer, idx = retriever.insert_citations(answer,
676
+ [ck["content_ltks"]
677
+ for ck in kbinfos["chunks"]],
678
+ [ck["vector"]
679
+ for ck in kbinfos["chunks"]],
680
+ embd_mdl,
681
+ tkweight=0.7,
682
+ vtweight=0.3)
683
  idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
684
  recall_docs = [
685
  d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
 
700
  answer = ans
701
  yield {"answer": answer, "reference": {}}
702
  yield decorate_answer(answer)
 
api/db/services/llm_service.py CHANGED
@@ -72,10 +72,12 @@ class TenantLLMService(CommonService):
72
  return model_name, None
73
  if len(arr) > 2:
74
  return "@".join(arr[0:-1]), arr[-1]
 
 
75
  try:
76
- fact = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
77
- fact = set([f["name"] for f in fact])
78
- if arr[-1] not in fact:
79
  return model_name, None
80
  return arr[0], arr[-1]
81
  except Exception as e:
@@ -113,11 +115,11 @@ class TenantLLMService(CommonService):
113
  if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
114
  llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
115
  if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
116
- model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": mdlnm, "api_base": ""}
117
  if not model_config:
118
  if mdlnm == "flag-embedding":
119
  model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
120
- "llm_name": llm_name, "api_base": ""}
121
  else:
122
  if not mdlnm:
123
  raise LookupError(f"Type of {llm_type} model is not set.")
@@ -200,8 +202,8 @@ class TenantLLMService(CommonService):
200
  return num
201
  else:
202
  tenant_llm = tenant_llms[0]
203
- num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens)\
204
- .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
205
  .execute()
206
  except Exception:
207
  logging.exception("TenantLLMService.increase_usage got exception")
@@ -231,7 +233,7 @@ class LLMBundle(object):
231
  for lm in LLMService.query(llm_name=llm_name):
232
  self.max_length = lm.max_tokens
233
  break
234
-
235
  def encode(self, texts: list):
236
  embeddings, used_tokens = self.mdl.encode(texts)
237
  if not TenantLLMService.increase_usage(
@@ -274,11 +276,11 @@ class LLMBundle(object):
274
 
275
  def tts(self, text):
276
  for chunk in self.mdl.tts(text):
277
- if isinstance(chunk,int):
278
  if not TenantLLMService.increase_usage(
279
- self.tenant_id, self.llm_type, chunk, self.llm_name):
280
- logging.error(
281
- "LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
282
  return
283
  yield chunk
284
 
@@ -287,7 +289,8 @@ class LLMBundle(object):
287
  if isinstance(txt, int) and not TenantLLMService.increase_usage(
288
  self.tenant_id, self.llm_type, used_tokens, self.llm_name):
289
  logging.error(
290
- "LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
 
291
  return txt
292
 
293
  def chat_streamly(self, system, history, gen_conf):
@@ -296,6 +299,7 @@ class LLMBundle(object):
296
  if not TenantLLMService.increase_usage(
297
  self.tenant_id, self.llm_type, txt, self.llm_name):
298
  logging.error(
299
- "LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
 
300
  return
301
  yield txt
 
72
  return model_name, None
73
  if len(arr) > 2:
74
  return "@".join(arr[0:-1]), arr[-1]
75
+
76
+ # model name must be xxx@yyy
77
  try:
78
+ model_factories = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
79
+ model_providers = set([f["name"] for f in model_factories])
80
+ if arr[-1] not in model_providers:
81
  return model_name, None
82
  return arr[0], arr[-1]
83
  except Exception as e:
 
115
  if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
116
  llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
117
  if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
118
+ model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
119
  if not model_config:
120
  if mdlnm == "flag-embedding":
121
  model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
122
+ "llm_name": llm_name, "api_base": ""}
123
  else:
124
  if not mdlnm:
125
  raise LookupError(f"Type of {llm_type} model is not set.")
 
202
  return num
203
  else:
204
  tenant_llm = tenant_llms[0]
205
+ num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens) \
206
+ .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name) \
207
  .execute()
208
  except Exception:
209
  logging.exception("TenantLLMService.increase_usage got exception")
 
233
  for lm in LLMService.query(llm_name=llm_name):
234
  self.max_length = lm.max_tokens
235
  break
236
+
237
  def encode(self, texts: list):
238
  embeddings, used_tokens = self.mdl.encode(texts)
239
  if not TenantLLMService.increase_usage(
 
276
 
277
  def tts(self, text):
278
  for chunk in self.mdl.tts(text):
279
+ if isinstance(chunk, int):
280
  if not TenantLLMService.increase_usage(
281
+ self.tenant_id, self.llm_type, chunk, self.llm_name):
282
+ logging.error(
283
+ "LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
284
  return
285
  yield chunk
286
 
 
289
  if isinstance(txt, int) and not TenantLLMService.increase_usage(
290
  self.tenant_id, self.llm_type, used_tokens, self.llm_name):
291
  logging.error(
292
+ "LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name,
293
+ used_tokens))
294
  return txt
295
 
296
  def chat_streamly(self, system, history, gen_conf):
 
299
  if not TenantLLMService.increase_usage(
300
  self.tenant_id, self.llm_type, txt, self.llm_name):
301
  logging.error(
302
+ "LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name,
303
+ txt))
304
  return
305
  yield txt