liuhua liuhua commited on
Commit
f335329
·
1 Parent(s): 9bf66a3

Fix the bug that the agent could not find the context (#3795)

Browse files

### What problem does this PR solve?

Fix the bug that the agent could not find the context
#3682
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: liuhua <[email protected]>

Files changed (2) hide show
  1. api/apps/sdk/session.py +28 -28
  2. api/db/db_models.py +8 -1
api/apps/sdk/session.py CHANGED
@@ -35,7 +35,7 @@ from api.db.services.llm_service import LLMBundle
35
 
36
  @manager.route('/chats/<chat_id>/sessions', methods=['POST'])
37
  @token_required
38
- def create(tenant_id, chat_id):
39
  req = request.json
40
  req["dialog_id"] = chat_id
41
  dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
@@ -77,9 +77,10 @@ def create_agent_session(tenant_id, agent_id):
77
  conv = {
78
  "id": get_uuid(),
79
  "dialog_id": cvs.id,
80
- "user_id": req.get("usr_id", "") if isinstance(req, dict) else "",
81
  "message": [{"role": "assistant", "content": canvas.get_prologue()}],
82
- "source": "agent"
 
83
  }
84
  API4ConversationService.save(**conv)
85
  conv["agent_id"] = conv.pop("dialog_id")
@@ -88,11 +89,11 @@ def create_agent_session(tenant_id, agent_id):
88
 
89
  @manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT'])
90
  @token_required
91
- def update(tenant_id, chat_id, session_id):
92
  req = request.json
93
  req["dialog_id"] = chat_id
94
  conv_id = session_id
95
- conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
96
  if not conv:
97
  return get_error_data_result(message="Session does not exist")
98
  if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
@@ -123,12 +124,12 @@ def completion(tenant_id, chat_id):
123
  return get_error_data_result(message="`name` can not be empty.")
124
  ConversationService.save(**conv)
125
  e, conv = ConversationService.get_by_id(conv["id"])
126
- session_id = conv.id
127
  else:
128
  session_id = req.get("session_id")
129
  if not req.get("question"):
130
  return get_error_data_result(message="Please input your question.")
131
- conv = ConversationService.query(id=session_id, dialog_id=chat_id)
132
  if not conv:
133
  return get_error_data_result(message="Session does not exist")
134
  conv = conv[0]
@@ -182,18 +183,18 @@ def completion(tenant_id, chat_id):
182
  chunk_list.append(new_chunk)
183
  reference["chunks"] = chunk_list
184
  ans["id"] = message_id
185
- ans["session_id"] = session_id
186
 
187
  def stream():
188
  nonlocal dia, msg, req, conv
189
  try:
190
  for ans in chat(dia, msg, **req):
191
  fillin_conv(ans)
192
- yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
193
  ConversationService.update_by_id(conv.id, conv.to_dict())
194
  except Exception as e:
195
  yield "data:" + json.dumps({"code": 500, "message": str(e),
196
- "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
197
  ensure_ascii=False) + "\n\n"
198
  yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"
199
 
@@ -237,7 +238,8 @@ def agent_completion(tenant_id, agent_id):
237
  "dialog_id": cvs.id,
238
  "user_id": req.get("user_id", ""),
239
  "message": [{"role": "assistant", "content": canvas.get_prologue()}],
240
- "source": "agent"
 
241
  }
242
  API4ConversationService.save(**conv)
243
  conv = API4Conversation(**conv)
@@ -246,6 +248,7 @@ def agent_completion(tenant_id, agent_id):
246
  e, conv = API4ConversationService.get_by_id(req["session_id"])
247
  if not e:
248
  return get_error_data_result(message="Session not found!")
 
249
 
250
  messages = conv.message
251
  question = req.get("question")
@@ -267,11 +270,11 @@ def agent_completion(tenant_id, agent_id):
267
  if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
268
  message_id = msg[-1]["id"]
269
 
270
- if "quote" not in req: req["quote"] = False
271
  stream = req.get("stream", True)
272
 
273
  def fillin_conv(ans):
274
  reference = ans["reference"]
 
275
  temp_reference = deepcopy(ans["reference"])
276
  nonlocal conv, message_id
277
  if not conv.reference:
@@ -322,7 +325,7 @@ def agent_completion(tenant_id, agent_id):
322
  def sse():
323
  nonlocal answer, cvs
324
  try:
325
- for ans in canvas.run(stream=True):
326
  if ans.get("running_status"):
327
  yield "data:" + json.dumps({"code": 0, "message": "",
328
  "data": {"answer": ans["content"],
@@ -341,10 +344,10 @@ def agent_completion(tenant_id, agent_id):
341
  canvas.history.append(("assistant", final_ans["content"]))
342
  if final_ans.get("reference"):
343
  canvas.reference.append(final_ans["reference"])
344
- cvs.dsl = json.loads(str(canvas))
345
  API4ConversationService.append_message(conv.id, conv.to_dict())
346
  except Exception as e:
347
- cvs.dsl = json.loads(str(canvas))
348
  API4ConversationService.append_message(conv.id, conv.to_dict())
349
  yield "data:" + json.dumps({"code": 500, "message": str(e),
350
  "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
@@ -364,7 +367,7 @@ def agent_completion(tenant_id, agent_id):
364
  canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
365
  if final_ans.get("reference"):
366
  canvas.reference.append(final_ans["reference"])
367
- cvs.dsl = json.loads(str(canvas))
368
 
369
  result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
370
  fillin_conv(result)
@@ -372,10 +375,9 @@ def agent_completion(tenant_id, agent_id):
372
  rename_field(result)
373
  return get_result(data=result)
374
 
375
-
376
  @manager.route('/chats/<chat_id>/sessions', methods=['GET'])
377
  @token_required
378
- def list_session(chat_id, tenant_id):
379
  if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
380
  return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
381
  id = request.args.get("id")
@@ -387,7 +389,7 @@ def list_session(chat_id, tenant_id):
387
  desc = False
388
  else:
389
  desc = True
390
- convs = ConversationService.get_list(chat_id, page_number, items_per_page, orderby, desc, id, name)
391
  if not convs:
392
  return get_result(data=[])
393
  for conv in convs:
@@ -429,7 +431,7 @@ def list_session(chat_id, tenant_id):
429
 
430
  @manager.route('/chats/<chat_id>/sessions', methods=["DELETE"])
431
  @token_required
432
- def delete(tenant_id, chat_id):
433
  if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
434
  return get_error_data_result(message="You don't own the chat")
435
  req = request.json
@@ -437,22 +439,21 @@ def delete(tenant_id, chat_id):
437
  if not req:
438
  ids = None
439
  else:
440
- ids = req.get("ids")
441
 
442
  if not ids:
443
  conv_list = []
444
  for conv in convs:
445
  conv_list.append(conv.id)
446
  else:
447
- conv_list = ids
448
  for id in conv_list:
449
- conv = ConversationService.query(id=id, dialog_id=chat_id)
450
  if not conv:
451
  return get_error_data_result(message="The chat doesn't own the session")
452
  ConversationService.delete_by_id(id)
453
  return get_result()
454
 
455
-
456
  @manager.route('/sessions/ask', methods=['POST'])
457
  @token_required
458
  def ask_about(tenant_id):
@@ -461,18 +462,17 @@ def ask_about(tenant_id):
461
  return get_error_data_result("`question` is required.")
462
  if not req.get("dataset_ids"):
463
  return get_error_data_result("`dataset_ids` is required.")
464
- if not isinstance(req.get("dataset_ids"), list):
465
  return get_error_data_result("`dataset_ids` should be a list.")
466
- req["kb_ids"] = req.pop("dataset_ids")
467
  for kb_id in req["kb_ids"]:
468
- if not KnowledgebaseService.accessible(kb_id, tenant_id):
469
  return get_error_data_result(f"You don't own the dataset {kb_id}.")
470
  kbs = KnowledgebaseService.query(id=kb_id)
471
  kb = kbs[0]
472
  if kb.chunk_num == 0:
473
  return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
474
  uid = tenant_id
475
-
476
  def stream():
477
  nonlocal req, uid
478
  try:
 
35
 
36
  @manager.route('/chats/<chat_id>/sessions', methods=['POST'])
37
  @token_required
38
+ def create(tenant_id,chat_id):
39
  req = request.json
40
  req["dialog_id"] = chat_id
41
  dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
 
77
  conv = {
78
  "id": get_uuid(),
79
  "dialog_id": cvs.id,
80
+ "user_id": req.get("usr_id","") if isinstance(req, dict) else "",
81
  "message": [{"role": "assistant", "content": canvas.get_prologue()}],
82
+ "source": "agent",
83
+ "dsl":json.loads(cvs.dsl)
84
  }
85
  API4ConversationService.save(**conv)
86
  conv["agent_id"] = conv.pop("dialog_id")
 
89
 
90
  @manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT'])
91
  @token_required
92
+ def update(tenant_id,chat_id,session_id):
93
  req = request.json
94
  req["dialog_id"] = chat_id
95
  conv_id = session_id
96
+ conv = ConversationService.query(id=conv_id,dialog_id=chat_id)
97
  if not conv:
98
  return get_error_data_result(message="Session does not exist")
99
  if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
 
124
  return get_error_data_result(message="`name` can not be empty.")
125
  ConversationService.save(**conv)
126
  e, conv = ConversationService.get_by_id(conv["id"])
127
+ session_id=conv.id
128
  else:
129
  session_id = req.get("session_id")
130
  if not req.get("question"):
131
  return get_error_data_result(message="Please input your question.")
132
+ conv = ConversationService.query(id=session_id,dialog_id=chat_id)
133
  if not conv:
134
  return get_error_data_result(message="Session does not exist")
135
  conv = conv[0]
 
183
  chunk_list.append(new_chunk)
184
  reference["chunks"] = chunk_list
185
  ans["id"] = message_id
186
+ ans["session_id"]=session_id
187
 
188
  def stream():
189
  nonlocal dia, msg, req, conv
190
  try:
191
  for ans in chat(dia, msg, **req):
192
  fillin_conv(ans)
193
+ yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
194
  ConversationService.update_by_id(conv.id, conv.to_dict())
195
  except Exception as e:
196
  yield "data:" + json.dumps({"code": 500, "message": str(e),
197
+ "data": {"answer": "**ERROR**: " + str(e),"reference": []}},
198
  ensure_ascii=False) + "\n\n"
199
  yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"
200
 
 
238
  "dialog_id": cvs.id,
239
  "user_id": req.get("user_id", ""),
240
  "message": [{"role": "assistant", "content": canvas.get_prologue()}],
241
+ "source": "agent",
242
+ "dsl": json.loads(cvs.dsl)
243
  }
244
  API4ConversationService.save(**conv)
245
  conv = API4Conversation(**conv)
 
248
  e, conv = API4ConversationService.get_by_id(req["session_id"])
249
  if not e:
250
  return get_error_data_result(message="Session not found!")
251
+ canvas = Canvas(json.dumps(conv.dsl), tenant_id)
252
 
253
  messages = conv.message
254
  question = req.get("question")
 
270
  if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
271
  message_id = msg[-1]["id"]
272
 
 
273
  stream = req.get("stream", True)
274
 
275
  def fillin_conv(ans):
276
  reference = ans["reference"]
277
+ print(reference,flush=True)
278
  temp_reference = deepcopy(ans["reference"])
279
  nonlocal conv, message_id
280
  if not conv.reference:
 
325
  def sse():
326
  nonlocal answer, cvs
327
  try:
328
+ for ans in canvas.run(stream=stream):
329
  if ans.get("running_status"):
330
  yield "data:" + json.dumps({"code": 0, "message": "",
331
  "data": {"answer": ans["content"],
 
344
  canvas.history.append(("assistant", final_ans["content"]))
345
  if final_ans.get("reference"):
346
  canvas.reference.append(final_ans["reference"])
347
+ conv.dsl = json.loads(str(canvas))
348
  API4ConversationService.append_message(conv.id, conv.to_dict())
349
  except Exception as e:
350
+ conv.dsl = json.loads(str(canvas))
351
  API4ConversationService.append_message(conv.id, conv.to_dict())
352
  yield "data:" + json.dumps({"code": 500, "message": str(e),
353
  "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
 
367
  canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
368
  if final_ans.get("reference"):
369
  canvas.reference.append(final_ans["reference"])
370
+ conv.dsl = json.loads(str(canvas))
371
 
372
  result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
373
  fillin_conv(result)
 
375
  rename_field(result)
376
  return get_result(data=result)
377
 
 
378
  @manager.route('/chats/<chat_id>/sessions', methods=['GET'])
379
  @token_required
380
+ def list_session(chat_id,tenant_id):
381
  if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
382
  return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
383
  id = request.args.get("id")
 
389
  desc = False
390
  else:
391
  desc = True
392
+ convs = ConversationService.get_list(chat_id,page_number,items_per_page,orderby,desc,id,name)
393
  if not convs:
394
  return get_result(data=[])
395
  for conv in convs:
 
431
 
432
  @manager.route('/chats/<chat_id>/sessions', methods=["DELETE"])
433
  @token_required
434
+ def delete(tenant_id,chat_id):
435
  if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
436
  return get_error_data_result(message="You don't own the chat")
437
  req = request.json
 
439
  if not req:
440
  ids = None
441
  else:
442
+ ids=req.get("ids")
443
 
444
  if not ids:
445
  conv_list = []
446
  for conv in convs:
447
  conv_list.append(conv.id)
448
  else:
449
+ conv_list=ids
450
  for id in conv_list:
451
+ conv = ConversationService.query(id=id,dialog_id=chat_id)
452
  if not conv:
453
  return get_error_data_result(message="The chat doesn't own the session")
454
  ConversationService.delete_by_id(id)
455
  return get_result()
456
 
 
457
  @manager.route('/sessions/ask', methods=['POST'])
458
  @token_required
459
  def ask_about(tenant_id):
 
462
  return get_error_data_result("`question` is required.")
463
  if not req.get("dataset_ids"):
464
  return get_error_data_result("`dataset_ids` is required.")
465
+ if not isinstance(req.get("dataset_ids"),list):
466
  return get_error_data_result("`dataset_ids` should be a list.")
467
+ req["kb_ids"]=req.pop("dataset_ids")
468
  for kb_id in req["kb_ids"]:
469
+ if not KnowledgebaseService.accessible(kb_id,tenant_id):
470
  return get_error_data_result(f"You don't own the dataset {kb_id}.")
471
  kbs = KnowledgebaseService.query(id=kb_id)
472
  kb = kbs[0]
473
  if kb.chunk_num == 0:
474
  return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
475
  uid = tenant_id
 
476
  def stream():
477
  nonlocal req, uid
478
  try:
api/db/db_models.py CHANGED
@@ -947,7 +947,7 @@ class API4Conversation(DataBaseModel):
947
  reference = JSONField(null=True, default=[])
948
  tokens = IntegerField(default=0)
949
  source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
950
-
951
  duration = FloatField(default=0, index=True)
952
  round = IntegerField(default=0, index=True)
953
  thumb_up = IntegerField(default=0, index=True)
@@ -1070,3 +1070,10 @@ def migrate_db():
1070
  )
1071
  except Exception:
1072
  pass
 
 
 
 
 
 
 
 
947
  reference = JSONField(null=True, default=[])
948
  tokens = IntegerField(default=0)
949
  source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
950
+ dsl = JSONField(null=True, default={})
951
  duration = FloatField(default=0, index=True)
952
  round = IntegerField(default=0, index=True)
953
  thumb_up = IntegerField(default=0, index=True)
 
1070
  )
1071
  except Exception:
1072
  pass
1073
+ try:
1074
+ migrate(
1075
+ migrator.add_column("api_4_conversation","dsl",JSONField(null=True, default={}))
1076
+ )
1077
+ except Exception:
1078
+ pass
1079
+