liuhua
liuhua
commited on
Commit
·
f335329
1
Parent(s):
9bf66a3
Fix the bug that the agent could not find the context (#3795)
Browse files### What problem does this PR solve?
Fix the bug that the agent could not find the context
#3682
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
Co-authored-by: liuhua <[email protected]>
- api/apps/sdk/session.py +28 -28
- api/db/db_models.py +8 -1
api/apps/sdk/session.py
CHANGED
@@ -35,7 +35,7 @@ from api.db.services.llm_service import LLMBundle
|
|
35 |
|
36 |
@manager.route('/chats/<chat_id>/sessions', methods=['POST'])
|
37 |
@token_required
|
38 |
-
def create(tenant_id,
|
39 |
req = request.json
|
40 |
req["dialog_id"] = chat_id
|
41 |
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
|
@@ -77,9 +77,10 @@ def create_agent_session(tenant_id, agent_id):
|
|
77 |
conv = {
|
78 |
"id": get_uuid(),
|
79 |
"dialog_id": cvs.id,
|
80 |
-
"user_id": req.get("usr_id",
|
81 |
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
82 |
-
"source": "agent"
|
|
|
83 |
}
|
84 |
API4ConversationService.save(**conv)
|
85 |
conv["agent_id"] = conv.pop("dialog_id")
|
@@ -88,11 +89,11 @@ def create_agent_session(tenant_id, agent_id):
|
|
88 |
|
89 |
@manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT'])
|
90 |
@token_required
|
91 |
-
def update(tenant_id,
|
92 |
req = request.json
|
93 |
req["dialog_id"] = chat_id
|
94 |
conv_id = session_id
|
95 |
-
conv = ConversationService.query(id=conv_id,
|
96 |
if not conv:
|
97 |
return get_error_data_result(message="Session does not exist")
|
98 |
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
@@ -123,12 +124,12 @@ def completion(tenant_id, chat_id):
|
|
123 |
return get_error_data_result(message="`name` can not be empty.")
|
124 |
ConversationService.save(**conv)
|
125 |
e, conv = ConversationService.get_by_id(conv["id"])
|
126 |
-
session_id
|
127 |
else:
|
128 |
session_id = req.get("session_id")
|
129 |
if not req.get("question"):
|
130 |
return get_error_data_result(message="Please input your question.")
|
131 |
-
conv = ConversationService.query(id=session_id,
|
132 |
if not conv:
|
133 |
return get_error_data_result(message="Session does not exist")
|
134 |
conv = conv[0]
|
@@ -182,18 +183,18 @@ def completion(tenant_id, chat_id):
|
|
182 |
chunk_list.append(new_chunk)
|
183 |
reference["chunks"] = chunk_list
|
184 |
ans["id"] = message_id
|
185 |
-
ans["session_id"]
|
186 |
|
187 |
def stream():
|
188 |
nonlocal dia, msg, req, conv
|
189 |
try:
|
190 |
for ans in chat(dia, msg, **req):
|
191 |
fillin_conv(ans)
|
192 |
-
yield "data:" + json.dumps({"code": 0,
|
193 |
ConversationService.update_by_id(conv.id, conv.to_dict())
|
194 |
except Exception as e:
|
195 |
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
196 |
-
"data": {"answer": "**ERROR**: " + str(e),
|
197 |
ensure_ascii=False) + "\n\n"
|
198 |
yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"
|
199 |
|
@@ -237,7 +238,8 @@ def agent_completion(tenant_id, agent_id):
|
|
237 |
"dialog_id": cvs.id,
|
238 |
"user_id": req.get("user_id", ""),
|
239 |
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
240 |
-
"source": "agent"
|
|
|
241 |
}
|
242 |
API4ConversationService.save(**conv)
|
243 |
conv = API4Conversation(**conv)
|
@@ -246,6 +248,7 @@ def agent_completion(tenant_id, agent_id):
|
|
246 |
e, conv = API4ConversationService.get_by_id(req["session_id"])
|
247 |
if not e:
|
248 |
return get_error_data_result(message="Session not found!")
|
|
|
249 |
|
250 |
messages = conv.message
|
251 |
question = req.get("question")
|
@@ -267,11 +270,11 @@ def agent_completion(tenant_id, agent_id):
|
|
267 |
if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
|
268 |
message_id = msg[-1]["id"]
|
269 |
|
270 |
-
if "quote" not in req: req["quote"] = False
|
271 |
stream = req.get("stream", True)
|
272 |
|
273 |
def fillin_conv(ans):
|
274 |
reference = ans["reference"]
|
|
|
275 |
temp_reference = deepcopy(ans["reference"])
|
276 |
nonlocal conv, message_id
|
277 |
if not conv.reference:
|
@@ -322,7 +325,7 @@ def agent_completion(tenant_id, agent_id):
|
|
322 |
def sse():
|
323 |
nonlocal answer, cvs
|
324 |
try:
|
325 |
-
for ans in canvas.run(stream=
|
326 |
if ans.get("running_status"):
|
327 |
yield "data:" + json.dumps({"code": 0, "message": "",
|
328 |
"data": {"answer": ans["content"],
|
@@ -341,10 +344,10 @@ def agent_completion(tenant_id, agent_id):
|
|
341 |
canvas.history.append(("assistant", final_ans["content"]))
|
342 |
if final_ans.get("reference"):
|
343 |
canvas.reference.append(final_ans["reference"])
|
344 |
-
|
345 |
API4ConversationService.append_message(conv.id, conv.to_dict())
|
346 |
except Exception as e:
|
347 |
-
|
348 |
API4ConversationService.append_message(conv.id, conv.to_dict())
|
349 |
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
350 |
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
@@ -364,7 +367,7 @@ def agent_completion(tenant_id, agent_id):
|
|
364 |
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
365 |
if final_ans.get("reference"):
|
366 |
canvas.reference.append(final_ans["reference"])
|
367 |
-
|
368 |
|
369 |
result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
|
370 |
fillin_conv(result)
|
@@ -372,10 +375,9 @@ def agent_completion(tenant_id, agent_id):
|
|
372 |
rename_field(result)
|
373 |
return get_result(data=result)
|
374 |
|
375 |
-
|
376 |
@manager.route('/chats/<chat_id>/sessions', methods=['GET'])
|
377 |
@token_required
|
378 |
-
def list_session(chat_id,
|
379 |
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
380 |
return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
|
381 |
id = request.args.get("id")
|
@@ -387,7 +389,7 @@ def list_session(chat_id, tenant_id):
|
|
387 |
desc = False
|
388 |
else:
|
389 |
desc = True
|
390 |
-
convs = ConversationService.get_list(chat_id,
|
391 |
if not convs:
|
392 |
return get_result(data=[])
|
393 |
for conv in convs:
|
@@ -429,7 +431,7 @@ def list_session(chat_id, tenant_id):
|
|
429 |
|
430 |
@manager.route('/chats/<chat_id>/sessions', methods=["DELETE"])
|
431 |
@token_required
|
432 |
-
def delete(tenant_id,
|
433 |
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
434 |
return get_error_data_result(message="You don't own the chat")
|
435 |
req = request.json
|
@@ -437,22 +439,21 @@ def delete(tenant_id, chat_id):
|
|
437 |
if not req:
|
438 |
ids = None
|
439 |
else:
|
440 |
-
ids
|
441 |
|
442 |
if not ids:
|
443 |
conv_list = []
|
444 |
for conv in convs:
|
445 |
conv_list.append(conv.id)
|
446 |
else:
|
447 |
-
conv_list
|
448 |
for id in conv_list:
|
449 |
-
conv = ConversationService.query(id=id,
|
450 |
if not conv:
|
451 |
return get_error_data_result(message="The chat doesn't own the session")
|
452 |
ConversationService.delete_by_id(id)
|
453 |
return get_result()
|
454 |
|
455 |
-
|
456 |
@manager.route('/sessions/ask', methods=['POST'])
|
457 |
@token_required
|
458 |
def ask_about(tenant_id):
|
@@ -461,18 +462,17 @@ def ask_about(tenant_id):
|
|
461 |
return get_error_data_result("`question` is required.")
|
462 |
if not req.get("dataset_ids"):
|
463 |
return get_error_data_result("`dataset_ids` is required.")
|
464 |
-
if not isinstance(req.get("dataset_ids"),
|
465 |
return get_error_data_result("`dataset_ids` should be a list.")
|
466 |
-
req["kb_ids"]
|
467 |
for kb_id in req["kb_ids"]:
|
468 |
-
if not KnowledgebaseService.accessible(kb_id,
|
469 |
return get_error_data_result(f"You don't own the dataset {kb_id}.")
|
470 |
kbs = KnowledgebaseService.query(id=kb_id)
|
471 |
kb = kbs[0]
|
472 |
if kb.chunk_num == 0:
|
473 |
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
474 |
uid = tenant_id
|
475 |
-
|
476 |
def stream():
|
477 |
nonlocal req, uid
|
478 |
try:
|
|
|
35 |
|
36 |
@manager.route('/chats/<chat_id>/sessions', methods=['POST'])
|
37 |
@token_required
|
38 |
+
def create(tenant_id,chat_id):
|
39 |
req = request.json
|
40 |
req["dialog_id"] = chat_id
|
41 |
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
|
|
|
77 |
conv = {
|
78 |
"id": get_uuid(),
|
79 |
"dialog_id": cvs.id,
|
80 |
+
"user_id": req.get("usr_id","") if isinstance(req, dict) else "",
|
81 |
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
82 |
+
"source": "agent",
|
83 |
+
"dsl":json.loads(cvs.dsl)
|
84 |
}
|
85 |
API4ConversationService.save(**conv)
|
86 |
conv["agent_id"] = conv.pop("dialog_id")
|
|
|
89 |
|
90 |
@manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT'])
|
91 |
@token_required
|
92 |
+
def update(tenant_id,chat_id,session_id):
|
93 |
req = request.json
|
94 |
req["dialog_id"] = chat_id
|
95 |
conv_id = session_id
|
96 |
+
conv = ConversationService.query(id=conv_id,dialog_id=chat_id)
|
97 |
if not conv:
|
98 |
return get_error_data_result(message="Session does not exist")
|
99 |
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
|
|
124 |
return get_error_data_result(message="`name` can not be empty.")
|
125 |
ConversationService.save(**conv)
|
126 |
e, conv = ConversationService.get_by_id(conv["id"])
|
127 |
+
session_id=conv.id
|
128 |
else:
|
129 |
session_id = req.get("session_id")
|
130 |
if not req.get("question"):
|
131 |
return get_error_data_result(message="Please input your question.")
|
132 |
+
conv = ConversationService.query(id=session_id,dialog_id=chat_id)
|
133 |
if not conv:
|
134 |
return get_error_data_result(message="Session does not exist")
|
135 |
conv = conv[0]
|
|
|
183 |
chunk_list.append(new_chunk)
|
184 |
reference["chunks"] = chunk_list
|
185 |
ans["id"] = message_id
|
186 |
+
ans["session_id"]=session_id
|
187 |
|
188 |
def stream():
|
189 |
nonlocal dia, msg, req, conv
|
190 |
try:
|
191 |
for ans in chat(dia, msg, **req):
|
192 |
fillin_conv(ans)
|
193 |
+
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
|
194 |
ConversationService.update_by_id(conv.id, conv.to_dict())
|
195 |
except Exception as e:
|
196 |
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
197 |
+
"data": {"answer": "**ERROR**: " + str(e),"reference": []}},
|
198 |
ensure_ascii=False) + "\n\n"
|
199 |
yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"
|
200 |
|
|
|
238 |
"dialog_id": cvs.id,
|
239 |
"user_id": req.get("user_id", ""),
|
240 |
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
241 |
+
"source": "agent",
|
242 |
+
"dsl": json.loads(cvs.dsl)
|
243 |
}
|
244 |
API4ConversationService.save(**conv)
|
245 |
conv = API4Conversation(**conv)
|
|
|
248 |
e, conv = API4ConversationService.get_by_id(req["session_id"])
|
249 |
if not e:
|
250 |
return get_error_data_result(message="Session not found!")
|
251 |
+
canvas = Canvas(json.dumps(conv.dsl), tenant_id)
|
252 |
|
253 |
messages = conv.message
|
254 |
question = req.get("question")
|
|
|
270 |
if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
|
271 |
message_id = msg[-1]["id"]
|
272 |
|
|
|
273 |
stream = req.get("stream", True)
|
274 |
|
275 |
def fillin_conv(ans):
|
276 |
reference = ans["reference"]
|
277 |
+
print(reference,flush=True)
|
278 |
temp_reference = deepcopy(ans["reference"])
|
279 |
nonlocal conv, message_id
|
280 |
if not conv.reference:
|
|
|
325 |
def sse():
|
326 |
nonlocal answer, cvs
|
327 |
try:
|
328 |
+
for ans in canvas.run(stream=stream):
|
329 |
if ans.get("running_status"):
|
330 |
yield "data:" + json.dumps({"code": 0, "message": "",
|
331 |
"data": {"answer": ans["content"],
|
|
|
344 |
canvas.history.append(("assistant", final_ans["content"]))
|
345 |
if final_ans.get("reference"):
|
346 |
canvas.reference.append(final_ans["reference"])
|
347 |
+
conv.dsl = json.loads(str(canvas))
|
348 |
API4ConversationService.append_message(conv.id, conv.to_dict())
|
349 |
except Exception as e:
|
350 |
+
conv.dsl = json.loads(str(canvas))
|
351 |
API4ConversationService.append_message(conv.id, conv.to_dict())
|
352 |
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
353 |
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
|
|
367 |
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
368 |
if final_ans.get("reference"):
|
369 |
canvas.reference.append(final_ans["reference"])
|
370 |
+
conv.dsl = json.loads(str(canvas))
|
371 |
|
372 |
result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
|
373 |
fillin_conv(result)
|
|
|
375 |
rename_field(result)
|
376 |
return get_result(data=result)
|
377 |
|
|
|
378 |
@manager.route('/chats/<chat_id>/sessions', methods=['GET'])
|
379 |
@token_required
|
380 |
+
def list_session(chat_id,tenant_id):
|
381 |
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
382 |
return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
|
383 |
id = request.args.get("id")
|
|
|
389 |
desc = False
|
390 |
else:
|
391 |
desc = True
|
392 |
+
convs = ConversationService.get_list(chat_id,page_number,items_per_page,orderby,desc,id,name)
|
393 |
if not convs:
|
394 |
return get_result(data=[])
|
395 |
for conv in convs:
|
|
|
431 |
|
432 |
@manager.route('/chats/<chat_id>/sessions', methods=["DELETE"])
|
433 |
@token_required
|
434 |
+
def delete(tenant_id,chat_id):
|
435 |
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
436 |
return get_error_data_result(message="You don't own the chat")
|
437 |
req = request.json
|
|
|
439 |
if not req:
|
440 |
ids = None
|
441 |
else:
|
442 |
+
ids=req.get("ids")
|
443 |
|
444 |
if not ids:
|
445 |
conv_list = []
|
446 |
for conv in convs:
|
447 |
conv_list.append(conv.id)
|
448 |
else:
|
449 |
+
conv_list=ids
|
450 |
for id in conv_list:
|
451 |
+
conv = ConversationService.query(id=id,dialog_id=chat_id)
|
452 |
if not conv:
|
453 |
return get_error_data_result(message="The chat doesn't own the session")
|
454 |
ConversationService.delete_by_id(id)
|
455 |
return get_result()
|
456 |
|
|
|
457 |
@manager.route('/sessions/ask', methods=['POST'])
|
458 |
@token_required
|
459 |
def ask_about(tenant_id):
|
|
|
462 |
return get_error_data_result("`question` is required.")
|
463 |
if not req.get("dataset_ids"):
|
464 |
return get_error_data_result("`dataset_ids` is required.")
|
465 |
+
if not isinstance(req.get("dataset_ids"),list):
|
466 |
return get_error_data_result("`dataset_ids` should be a list.")
|
467 |
+
req["kb_ids"]=req.pop("dataset_ids")
|
468 |
for kb_id in req["kb_ids"]:
|
469 |
+
if not KnowledgebaseService.accessible(kb_id,tenant_id):
|
470 |
return get_error_data_result(f"You don't own the dataset {kb_id}.")
|
471 |
kbs = KnowledgebaseService.query(id=kb_id)
|
472 |
kb = kbs[0]
|
473 |
if kb.chunk_num == 0:
|
474 |
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
475 |
uid = tenant_id
|
|
|
476 |
def stream():
|
477 |
nonlocal req, uid
|
478 |
try:
|
api/db/db_models.py
CHANGED
@@ -947,7 +947,7 @@ class API4Conversation(DataBaseModel):
|
|
947 |
reference = JSONField(null=True, default=[])
|
948 |
tokens = IntegerField(default=0)
|
949 |
source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
|
950 |
-
|
951 |
duration = FloatField(default=0, index=True)
|
952 |
round = IntegerField(default=0, index=True)
|
953 |
thumb_up = IntegerField(default=0, index=True)
|
@@ -1070,3 +1070,10 @@ def migrate_db():
|
|
1070 |
)
|
1071 |
except Exception:
|
1072 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
947 |
reference = JSONField(null=True, default=[])
|
948 |
tokens = IntegerField(default=0)
|
949 |
source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
|
950 |
+
dsl = JSONField(null=True, default={})
|
951 |
duration = FloatField(default=0, index=True)
|
952 |
round = IntegerField(default=0, index=True)
|
953 |
thumb_up = IntegerField(default=0, index=True)
|
|
|
1070 |
)
|
1071 |
except Exception:
|
1072 |
pass
|
1073 |
+
try:
|
1074 |
+
migrate(
|
1075 |
+
migrator.add_column("api_4_conversation","dsl",JSONField(null=True, default={}))
|
1076 |
+
)
|
1077 |
+
except Exception:
|
1078 |
+
pass
|
1079 |
+
|