Commit 75c7829
Parent: 2a6f834

Refactor ask decorator (#4116)

### What problem does this PR solve?
Refactor ask decorator
### Type of change
- [x] Refactoring
---------
Signed-off-by: jinhai <[email protected]>
Co-authored-by: Kevin Hu <[email protected]>
- api/db/services/dialog_service.py +120 -85
- api/db/services/llm_service.py +18 -14

api/db/services/dialog_service.py CHANGED

```diff
@@ -23,7 +23,7 @@ from copy import deepcopy
 from timeit import default_timer as timer
 import datetime
 from datetime import timedelta
+from api.db import LLMType, ParserType, StatusEnum
 from api.db.db_models import Dialog, DB
 from api.db.services.common_service import CommonService
 from api.db.services.knowledgebase_service import KnowledgebaseService
```
```diff
@@ -41,14 +41,14 @@ class DialogService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_list(cls, tenant_id,
+                 page_number, items_per_page, orderby, desc, id, name):
         chats = cls.model.select()
         if id:
             chats = chats.where(cls.model.id == id)
         if name:
             chats = chats.where(cls.model.name == name)
         chats = chats.where(
+            (cls.model.tenant_id == tenant_id)
             & (cls.model.status == StatusEnum.VALID.value)
         )
         if desc:
```
```diff
@@ -137,25 +137,37 @@ def kb_prompt(kbinfos, max_tokens):
 
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
+
+    chat_start_ts = timer()
+
+    # Get llm model name and model provider name
+    llm_id, model_provider = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
+
+    # Get llm model instance by model and provide name
+    llm = LLMService.query(llm_name=llm_id) if not model_provider else LLMService.query(llm_name=llm_id, fid=model_provider)
+
     if not llm:
+        # Model name is provided by tenant, but not system built-in
+        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not model_provider else \
+            TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=model_provider)
         if not llm:
             raise LookupError("LLM(%s) not found" % dialog.llm_id)
         max_tokens = 8192
     else:
         max_tokens = llm[0].max_tokens
+
+    check_llm_ts = timer()
+
     kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
+    embedding_list = list(set([kb.embd_id for kb in kbs]))
+    if len(embedding_list) != 1:
         yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
         return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
 
+    embedding_model_name = embedding_list[0]
+
+    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
+    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
 
     questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
     attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
```
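The new `embedding_list` guard above makes `chat()` refuse to mix knowledge bases indexed with different embedding models, since similarity scores computed across incompatible vector spaces are meaningless. A minimal sketch of the same check in isolation (the `KB` tuple and sample rows are hypothetical stand-ins for the Knowledgebase ORM objects):

```python
from collections import namedtuple

# Hypothetical stand-in for the Knowledgebase ORM rows queried above.
KB = namedtuple("KB", ["id", "embd_id"])

def pick_embedding_model(kbs):
    # Every KB must have been indexed by the same embedding model,
    # otherwise their vectors are not comparable.
    embedding_list = list(set(kb.embd_id for kb in kbs))
    if len(embedding_list) != 1:
        raise ValueError("Knowledge bases use different embedding models.")
    return embedding_list[0]

kbs = [KB("kb1", "BAAI/bge-large-zh-v1.5"), KB("kb2", "BAAI/bge-large-zh-v1.5")]
print(pick_embedding_model(kbs))  # BAAI/bge-large-zh-v1.5
```

Collapsing the `embd_id` values through a set keeps the check order-independent; the same trick is used just below to pick the knowledge-graph retriever only when every KB uses the KG parser.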
```diff
@@ -165,15 +177,21 @@ def chat(dialog, messages, stream=True, **kwargs):
         if "doc_ids" in m:
             attachments.extend(m["doc_ids"])
 
+    create_retriever_ts = timer()
+
+    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
     if not embd_mdl:
+        raise LookupError("Embedding model(%s) not found" % embedding_model_name)
+
+    bind_embedding_ts = timer()
 
     if llm_id2llm_type(dialog.llm_id) == "image2text":
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
 
+    bind_llm_ts = timer()
+
     prompt_config = dialog.prompt_config
     field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
     tts_mdl = None
```
```diff
@@ -200,32 +218,35 @@ def chat(dialog, messages, stream=True, **kwargs):
         questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
     else:
         questions = questions[-1:]
+
+    refine_question_ts = timer()
 
     rerank_mdl = None
     if dialog.rerank_id:
         rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
 
+    bind_reranker_ts = timer()
+    generate_keyword_ts = bind_reranker_ts
+
     if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
         kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
     else:
         if prompt_config.get("keyword", False):
             questions[-1] += keyword_extraction(chat_mdl, questions[-1])
+            generate_keyword_ts = timer()
 
         tenant_ids = list(set([kb.tenant_id for kb in kbs]))
+        kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
+                                      dialog.similarity_threshold,
+                                      dialog.vector_similarity_weight,
+                                      doc_ids=attachments,
+                                      top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
+
+    retrieval_ts = timer()
+
     knowledges = kb_prompt(kbinfos, max_tokens)
     logging.debug(
         "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
 
     if not knowledges and prompt_config.get("empty_response"):
         empty_res = prompt_config["empty_response"]
```
```diff
@@ -249,17 +270,20 @@ def chat(dialog, messages, stream=True, **kwargs):
             max_tokens - used_token_count)
 
     def decorate_answer(answer):
+        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts
+
+        finish_chat_ts = timer()
+
         refs = []
         if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
+            answer, idx = retriever.insert_citations(answer,
+                                                     [ck["content_ltks"]
+                                                      for ck in kbinfos["chunks"]],
+                                                     [ck["vector"]
+                                                      for ck in kbinfos["chunks"]],
+                                                     embd_mdl,
+                                                     tkweight=1 - dialog.vector_similarity_weight,
+                                                     vtweight=dialog.vector_similarity_weight)
             idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
             recall_docs = [
                 d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
```
```diff
@@ -274,10 +298,20 @@ def chat(dialog, messages, stream=True, **kwargs):
 
         if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
             answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
+        finish_chat_ts = timer()
+
+        total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
+        check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
+        create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
+        bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
+        bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
+        refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
+        bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
+        generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
+        retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
+        generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
+
+        prompt = f"{prompt} ### Elapsed\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
         return {"answer": answer, "reference": refs, "prompt": prompt}
 
     if stream:
```
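The bulk of this hunk is profiling: `chat()` now records a `timer()` checkpoint after each stage and appends a per-stage breakdown to the `prompt` field it returns, so slow requests can be diagnosed from the response itself. The pattern, reduced to a runnable sketch (stage names and sleeps are illustrative, not the real stages):

```python
from timeit import default_timer as timer
import time

# (stage name, timestamp) checkpoints, mirroring the *_ts variables in chat().
checkpoints = [("start", timer())]

def mark(stage):
    checkpoints.append((stage, timer()))

time.sleep(0.01); mark("check LLM")
time.sleep(0.02); mark("retrieval")
time.sleep(0.005); mark("generate answer")

# Each stage's cost is the delta to the previous checkpoint, in milliseconds.
report = "\n".join(f" - {stage}: {(ts - prev) * 1000:.1f}ms"
                   for (_, prev), (stage, ts) in zip(checkpoints, checkpoints[1:]))
total = (checkpoints[-1][1] - checkpoints[0][1]) * 1000
print(f"### Elapsed\n - Total: {total:.1f}ms\n{report}")
```

Note the pre-seeding of `generate_keyword_ts = bind_reranker_ts` above: when keyword extraction is skipped, its stage simply reports ~0ms instead of leaving the variable undefined.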
```diff
@@ -304,15 +338,15 @@ def chat(dialog, messages, stream=True, **kwargs):
 
 
 def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
+    sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
+    user_prompt = """
+Table name: {};
+Table of database fields are as follows:
 {}
 
+Question are as follows:
 {}
+Please write the SQL, only SQL, without any other explanations or text.
 """.format(
         index_name(tenant_id),
         "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
```
```diff
@@ -321,10 +355,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
     tried_times = 0
 
     def get_table():
+        nonlocal sys_prompt, user_prompt, question, tried_times
+        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
             "temperature": 0.06})
+        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
         sql = re.sub(r"[\r\n]+", " ", sql.lower())
         sql = re.sub(r".*select ", "select ", sql.lower())
         sql = re.sub(r" +", " ", sql)
```
```diff
@@ -352,21 +386,23 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
     if tbl is None:
         return None
     if tbl.get("error") and tried_times <= 2:
+        user_prompt = """
+        Table name: {};
+        Table of database fields are as follows:
         {}
+
+        Question are as follows:
         {}
+        Please write the SQL, only SQL, without any other explanations or text.
+
 
+        The SQL error you provided last time is as follows:
         {}
 
+        Error issued by database as follows:
         {}
 
+        Please correct the error and write SQL again, only SQL, without any other explanations or text.
         """.format(
             index_name(tenant_id),
             "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
```
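`get_table()` is attempted up to three times; when the database rejects the generated SQL, the corrective `user_prompt` above quotes both the failing statement and the database error so the model can repair its own output. The control flow, reduced to a sketch (`ask_llm` and `run_sql` are hypothetical stand-ins for `chat_mdl.chat` and the SQL executor):

```python
def text_to_sql_with_retry(question, ask_llm, run_sql, max_tries=3):
    """Generate SQL with an LLM; on a DB error, feed the error back and retry."""
    sql = ask_llm(question)
    for _ in range(max_tries):
        tbl = run_sql(sql)  # assumed to return {"rows": ...} or {"error": ...}
        if not tbl.get("error"):
            return tbl
        # Quote the failing SQL and the database error, as the corrective
        # user_prompt above does, so the model can fix its own mistake.
        sql = ask_llm(f"{question}\n\nThe SQL error you provided last time:\n{sql}\n\n"
                      f"Error issued by database:\n{tbl['error']}")
    return None
```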
```diff
@@ -381,21 +417,21 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
 
     docid_idx = set([ii for ii, c in enumerate(
         tbl["columns"]) if c["name"] == "doc_id"])
+    doc_name_idx = set([ii for ii, c in enumerate(
         tbl["columns"]) if c["name"] == "docnm_kwd"])
+    column_idx = [ii for ii in range(
+        len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
 
+    # compose Markdown table
+    columns = "|" + "|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"],
+                                                                          tbl["columns"][i]["name"])) for i in
+                              column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
 
+    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \
         ("|------|" if docid_idx and docid_idx else "")
 
     rows = ["|" +
+            "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") +
             "|" for r in tbl["rows"]]
     rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
     if quota:
```
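The result set is rendered as a Markdown table: a header row built from the `field_map` labels, a `|------|` separator, then one row per SQL result, with the `doc_id`/`docnm_kwd` columns split off into a trailing "Source" column for citations. A self-contained sketch of that three-part composition (headers and rows are made-up sample data):

```python
def to_markdown_table(headers, rows):
    # Header row, |------| separator, then one line per result row --
    # the same three-part shape use_sql() assembles above.
    head = "|" + "|".join(headers) + "|"
    sep = "|" + "|".join(["------"] * len(headers)) + "|"
    body = ["|" + "|".join(" " if v is None else str(v) for v in row) + "|"
            for row in rows]
    return "\n".join([head, sep] + body)

print(to_markdown_table(["Name", "Amount"], [["widget", 3], ["gadget", None]]))
```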
```diff
@@ -404,24 +440,24 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
         rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
     rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
 
+    if not docid_idx or not doc_name_idx:
         logging.warning("SQL missing field: " + sql)
         return {
+            "answer": "\n".join([columns, line, rows]),
             "reference": {"chunks": [], "doc_aggs": []},
             "prompt": sys_prompt
         }
 
     docid_idx = list(docid_idx)[0]
+    doc_name_idx = list(doc_name_idx)[0]
     doc_aggs = {}
     for r in tbl["rows"]:
         if r[docid_idx] not in doc_aggs:
+            doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
         doc_aggs[r[docid_idx]]["count"] += 1
     return {
+        "answer": "\n".join([columns, line, rows]),
+        "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
                       "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
                                    doc_aggs.items()]},
         "prompt": sys_prompt
```
```diff
@@ -492,7 +528,7 @@ Requirements:
     kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
     if isinstance(kwd, tuple):
         kwd = kwd[0]
+    if kwd.find("**ERROR**") >= 0:
         return ""
     return kwd
 
```
```diff
@@ -605,16 +641,16 @@ def tts(tts_mdl, text):
 
 def ask(question, kb_ids, tenant_id):
     kbs = KnowledgebaseService.get_by_ids(kb_ids)
+    embedding_list = list(set([kb.embd_id for kb in kbs]))
 
+    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
+    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
 
+    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
     chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
     max_tokens = chat_mdl.max_length
     tenant_ids = list(set([kb.tenant_id for kb in kbs]))
+    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
     knowledges = kb_prompt(kbinfos, max_tokens)
     prompt = """
 Role: You're a smart assistant. Your name is Miss R.
```
```diff
@@ -636,14 +672,14 @@ def ask(question, kb_ids, tenant_id):
 
     def decorate_answer(answer):
         nonlocal knowledges, kbinfos, prompt
+        answer, idx = retriever.insert_citations(answer,
+                                                 [ck["content_ltks"]
+                                                  for ck in kbinfos["chunks"]],
+                                                 [ck["vector"]
+                                                  for ck in kbinfos["chunks"]],
+                                                 embd_mdl,
+                                                 tkweight=0.7,
+                                                 vtweight=0.3)
         idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
         recall_docs = [
             d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
```
```diff
@@ -664,4 +700,3 @@ def ask(question, kb_ids, tenant_id):
         answer = ans
         yield {"answer": answer, "reference": {}}
     yield decorate_answer(answer)
```
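In both `chat()` and `ask()`, `retriever.insert_citations()` matches answer fragments back to retrieved chunks using a mix of token overlap (`tkweight`) and embedding similarity (`vtweight`); `chat()` derives the weights from the dialog's `vector_similarity_weight`, while `ask()` hard-codes 0.7/0.3. A simplified sketch of such a blend, assuming the two scores combine linearly (the similarity functions are crude stand-ins, not RAGFlow's implementation):

```python
import math

def token_sim(a, b):
    # Crude token-overlap (Jaccard) score standing in for content_ltks matching.
    sa, sb = set(a.split()), set(b.split())
    return len(sa & sb) / len(sa | sb) if sa | sb else 0.0

def vector_sim(va, vb):
    # Cosine similarity between two embedding vectors.
    dot = sum(x * y for x, y in zip(va, vb))
    na = math.sqrt(sum(x * x for x in va))
    nb = math.sqrt(sum(x * x for x in vb))
    return dot / (na * nb) if na and nb else 0.0

def blended_sim(text_a, text_b, vec_a, vec_b, tkweight=0.7, vtweight=0.3):
    return tkweight * token_sim(text_a, text_b) + vtweight * vector_sim(vec_a, vec_b)

print(blended_sim("your name is miss r", "my name is miss r", [0.1, 0.9], [0.2, 0.8]))
```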

api/db/services/llm_service.py CHANGED

```diff
@@ -72,10 +72,12 @@ class TenantLLMService(CommonService):
             return model_name, None
         if len(arr) > 2:
             return "@".join(arr[0:-1]), arr[-1]
+
+        # model name must be xxx@yyy
         try:
+            model_factories = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
+            model_providers = set([f["name"] for f in model_factories])
+            if arr[-1] not in model_providers:
                 return model_name, None
             return arr[0], arr[-1]
         except Exception as e:
```
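`split_model_name_and_factory` treats the text after the last `@` as a provider name only when it matches a factory declared in `conf/llm_factories.json`; otherwise the `@` stays part of the model name itself. The rule in isolation (the provider set below is an illustrative subset, not the full factory list):

```python
def split_model_name_and_factory(model_name, known_providers):
    arr = model_name.split("@")
    if len(arr) < 2:
        return model_name, None
    if len(arr) > 2:
        return "@".join(arr[0:-1]), arr[-1]
    # Exactly one '@': it only separates a provider if the suffix is a known factory.
    if arr[-1] not in known_providers:
        return model_name, None
    return arr[0], arr[-1]

providers = {"OpenAI", "Tongyi-Qianwen", "Ollama"}  # illustrative subset
print(split_model_name_and_factory("gpt-4o@OpenAI", providers))  # ('gpt-4o', 'OpenAI')
print(split_model_name_and_factory("llama3@local", providers))   # ('llama3@local', None)
```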
```diff
@@ -113,11 +115,11 @@ class TenantLLMService(CommonService):
         if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
             llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
             if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
+                model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
         if not model_config:
             if mdlnm == "flag-embedding":
                 model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
+                                "llm_name": llm_name, "api_base": ""}
             else:
                 if not mdlnm:
                     raise LookupError(f"Type of {llm_type} model is not set.")
```
```diff
@@ -200,8 +202,8 @@ class TenantLLMService(CommonService):
                 return num
             else:
                 tenant_llm = tenant_llms[0]
+                num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens) \
+                    .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name) \
                     .execute()
         except Exception:
             logging.exception("TenantLLMService.increase_usage got exception")
```
```diff
@@ -231,7 +233,7 @@ class LLMBundle(object):
         for lm in LLMService.query(llm_name=llm_name):
             self.max_length = lm.max_tokens
             break
+
     def encode(self, texts: list):
         embeddings, used_tokens = self.mdl.encode(texts)
         if not TenantLLMService.increase_usage(
```
```diff
@@ -274,11 +276,11 @@ class LLMBundle(object):
 
     def tts(self, text):
         for chunk in self.mdl.tts(text):
+            if isinstance(chunk, int):
                 if not TenantLLMService.increase_usage(
+                        self.tenant_id, self.llm_type, chunk, self.llm_name):
+                    logging.error(
+                        "LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
                 return
             yield chunk
 
```
```diff
@@ -287,7 +289,8 @@ class LLMBundle(object):
         if isinstance(txt, int) and not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens, self.llm_name):
             logging.error(
+                "LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name,
+                                                                                                           used_tokens))
         return txt
 
     def chat_streamly(self, system, history, gen_conf):
```
```diff
@@ -296,6 +299,7 @@ class LLMBundle(object):
             if not TenantLLMService.increase_usage(
                     self.tenant_id, self.llm_type, txt, self.llm_name):
                 logging.error(
+                    "LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name,
+                                                                                                                    txt))
                 return
             yield txt
```
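Every `LLMBundle` call path now follows one convention: when the underlying model reports its token count (an `int` in an otherwise-text stream), `TenantLLMService.increase_usage` is called, and a failed update is logged rather than raised, so an accounting problem can never break a user-facing stream. A condensed sketch of the pattern (`record_usage` is a hypothetical stand-in for `increase_usage`):

```python
import logging

def record_usage(tenant_id, llm_type, used_tokens, llm_name):
    # Stand-in for TenantLLMService.increase_usage; returns rows updated (0 = failure).
    return 1

def stream_with_accounting(chunks, tenant_id, llm_type, llm_name):
    """Yield text chunks; a final integer chunk is the token count to record."""
    for chunk in chunks:
        if isinstance(chunk, int):
            if not record_usage(tenant_id, llm_type, chunk, llm_name):
                # Log and swallow: usage bookkeeping must not kill the stream.
                logging.error("can't update token usage for %s/%s", tenant_id, llm_type)
            return
        yield chunk

print(list(stream_with_accounting(["Hel", "lo", 5], "tenant1", "CHAT", "gpt-4o")))
```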