Kevin Hu committed
Commit · 41012b3
1 Parent(s): 63ae668

add elapsed time of conversation (#2316)

### What problem does this PR solve?

#2315

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
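
Before the diff itself, a minimal self-contained sketch (not code from this PR; the `time.sleep` calls are invented stand-ins for retrieval and generation) of the timing pattern the change introduces: take a start timestamp at the top of `chat()`, a second one once retrieval is logged, a third inside `decorate_answer()`, then append both durations to the prompt returned alongside the answer. Note the LLM figure is computed as `done_tm - st`, so it reports total elapsed time from the start of the call, not generation time alone.

```python
from timeit import default_timer as timer
import time

st = timer()            # top of chat(): before retrieval
time.sleep(0.05)        # stand-in for retr.retrieval(...)
retrieval_tm = timer()  # taken once retrieval results are logged
time.sleep(0.10)        # stand-in for the streaming LLM call
done_tm = timer()       # taken inside decorate_answer()

prompt = "<rendered system prompt>"
# Same format string as the diff below uses.
prompt += "\n### Elapsed\n - Retrieval: %.1f ms\n - LLM: %.1f ms" % (
    (retrieval_tm - st) * 1000, (done_tm - st) * 1000)
print(prompt)
```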
api/db/services/dialog_service.py CHANGED
```diff
@@ -18,7 +18,7 @@ import os
 import json
 import re
 from copy import deepcopy
-
+from timeit import default_timer as timer
 from api.db import LLMType, ParserType
 from api.db.db_models import Dialog, Conversation
 from api.db.services.common_service import CommonService
@@ -88,6 +88,7 @@ def llm_id2llm_type(llm_id):
 
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
+    st = timer()
     llm = LLMService.query(llm_name=dialog.llm_id)
     if not llm:
         llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id)
@@ -158,25 +159,16 @@ def chat(dialog, messages, stream=True, **kwargs):
                                  doc_ids=attachments,
                                  top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
         knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
-        #self-rag
-        if dialog.prompt_config.get("self_rag") and not relevant(dialog.tenant_id, dialog.llm_id, questions[-1], knowledges):
-            questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1])
-            kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
-                                     dialog.similarity_threshold,
-                                     dialog.vector_similarity_weight,
-                                     doc_ids=attachments,
-                                     top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
-            knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
-
     chat_logger.info(
         "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
+    retrieval_tm = timer()
 
     if not knowledges and prompt_config.get("empty_response"):
         empty_res = prompt_config["empty_response"]
         yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
         return {"answer": prompt_config["empty_response"], "reference": kbinfos}
 
-    kwargs["knowledge"] = "\n".join(knowledges)
+    kwargs["knowledge"] = "\n------\n".join(knowledges)
    gen_conf = dialog.llm_setting
 
     msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
@@ -192,7 +184,7 @@ def chat(dialog, messages, stream=True, **kwargs):
                              max_tokens - used_token_count)
 
     def decorate_answer(answer):
-        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt
+        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_tm
         refs = []
         if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
             answer, idx = retr.insert_citations(answer,
@@ -216,7 +208,9 @@ def chat(dialog, messages, stream=True, **kwargs):
 
         if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
             answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
-        return {"answer": answer, "reference": refs, "prompt": prompt}
+        done_tm = timer()
+        prompt += "\n### Elapsed\n - Retrieval: %.1f ms\n - LLM: %.1f ms"%((retrieval_tm-st)*1000, (done_tm-st)*1000)
+        return {"answer": answer, "reference": refs, "prompt": re.sub(r"\n", "<br/>", prompt)}
 
     if stream:
         last_ans = ""
```
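
The last hunk above also changes the returned `prompt` to swap newlines for `<br/>`, so the appended Elapsed section renders as line breaks when the front end displays the prompt as HTML. A toy illustration with invented timings:

```python
import re

prompt = "### Elapsed\n - Retrieval: 12.3 ms\n - LLM: 456.7 ms"  # invented values
print(re.sub(r"\n", "<br/>", prompt))
# -> ### Elapsed<br/> - Retrieval: 12.3 ms<br/> - LLM: 456.7 ms
```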
```diff
@@ -415,4 +409,75 @@ def tts(tts_mdl, text):
     bin = b""
     for chunk in tts_mdl.tts(text):
         bin += chunk
-    return binascii.hexlify(bin).decode("utf-8")
+    return binascii.hexlify(bin).decode("utf-8")
+
+
+def ask(question, kb_ids, tenant_id):
+    kbs = KnowledgebaseService.get_by_ids(kb_ids)
+    embd_nms = list(set([kb.embd_id for kb in kbs]))
+
+    is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
+    retr = retrievaler if not is_kg else kg_retrievaler
+
+    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
+    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
+    max_tokens = chat_mdl.max_length
+
+    kbinfos = retr.retrieval(question, embd_mdl, tenant_id, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
+    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
+
+    used_token_count = 0
+    for i, c in enumerate(knowledges):
+        used_token_count += num_tokens_from_string(c)
+        if max_tokens * 0.97 < used_token_count:
+            knowledges = knowledges[:i]
+            break
+
+    prompt = """
+    Role: You're a smart assistant. Your name is Miss R.
+    Task: Summarize the information from knowledge bases and answer user's question.
+    Requirements and restriction:
+      - DO NOT make things up, especially for numbers.
+      - If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided.
+      - Answer with markdown format text.
+      - Answer in language of user's question.
+      - DO NOT make things up, especially for numbers.
+
+    ### Information from knowledge bases
+    %s
+
+    The above is information from knowledge bases.
+
+    """%"\n".join(knowledges)
+    msg = [{"role": "user", "content": question}]
+
+    def decorate_answer(answer):
+        nonlocal knowledges, kbinfos, prompt
+        answer, idx = retr.insert_citations(answer,
+                                            [ck["content_ltks"]
+                                             for ck in kbinfos["chunks"]],
+                                            [ck["vector"]
+                                             for ck in kbinfos["chunks"]],
+                                            embd_mdl,
+                                            tkweight=0.7,
+                                            vtweight=0.3)
+        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
+        recall_docs = [
+            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
+        if not recall_docs: recall_docs = kbinfos["doc_aggs"]
+        kbinfos["doc_aggs"] = recall_docs
+        refs = deepcopy(kbinfos)
+        for c in refs["chunks"]:
+            if c.get("vector"):
+                del c["vector"]
+
+        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
+            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
+        return {"answer": answer, "reference": refs}
+
+    answer = ""
+    for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
+        answer = ans
+        yield {"answer": answer, "reference": {}}
+    yield decorate_answer(answer)
```
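
For context, a hypothetical caller of the new `ask` generator (the question, knowledge-base id, and tenant id below are invented). While streaming, each yield carries the progressively longer answer with an empty `reference`; the final item, produced by `decorate_answer()`, carries the citation-decorated references.

```python
# Hypothetical usage sketch; ask() is the generator added above.
final = None
for delta in ask("What does the dialog service do?", ["kb-id-1"], "tenant-id-1"):
    print(delta["answer"])  # grows as chat_mdl.chat_streamly streams
    final = delta
print(final["reference"].get("doc_aggs", []))  # populated only on the final yield
```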