liuhua
liuhua
commited on
Commit
·
fea9976
1
Parent(s):
11bef16
Add parameters for ask_chat and fix bugs in list_sessions (#4119)
Browse files### What problem does this PR solve?
Add parameters for ask_chat and fix bugs in list_sessions
#4105
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
Co-authored-by: liuhua <[email protected]>
api/apps/sdk/session.py
CHANGED
|
@@ -65,20 +65,24 @@ def create(tenant_id, chat_id):
|
|
| 65 |
@manager.route('/agents/<agent_id>/sessions', methods=['POST']) # noqa: F821
|
| 66 |
@token_required
|
| 67 |
def create_agent_session(tenant_id, agent_id):
|
|
|
|
| 68 |
e, cvs = UserCanvasService.get_by_id(agent_id)
|
| 69 |
if not e:
|
| 70 |
return get_error_data_result("Agent not found.")
|
| 71 |
|
|
|
|
|
|
|
|
|
|
| 72 |
if not isinstance(cvs.dsl, str):
|
| 73 |
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
| 74 |
|
| 75 |
canvas = Canvas(cvs.dsl, tenant_id)
|
| 76 |
if canvas.get_preset_param():
|
| 77 |
-
return get_error_data_result("The agent
|
| 78 |
conv = {
|
| 79 |
"id": get_uuid(),
|
| 80 |
"dialog_id": cvs.id,
|
| 81 |
-
"user_id":
|
| 82 |
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
| 83 |
"source": "agent",
|
| 84 |
"dsl": json.loads(cvs.dsl)
|
|
@@ -199,17 +203,15 @@ def list_session(tenant_id, chat_id):
|
|
| 199 |
chunks = conv["reference"][chunk_num]["chunks"]
|
| 200 |
for chunk in chunks:
|
| 201 |
new_chunk = {
|
| 202 |
-
"id": chunk
|
| 203 |
-
"content": chunk
|
| 204 |
-
"document_id": chunk
|
| 205 |
-
"document_name": chunk
|
| 206 |
-
"dataset_id": chunk
|
| 207 |
-
"image_id": chunk.get("image_id", ""),
|
| 208 |
-
"
|
| 209 |
-
"vector_similarity": chunk["vector_similarity"],
|
| 210 |
-
"term_similarity": chunk["term_similarity"],
|
| 211 |
-
"positions": chunk["positions"],
|
| 212 |
}
|
|
|
|
| 213 |
chunk_list.append(new_chunk)
|
| 214 |
chunk_num += 1
|
| 215 |
messages[message_num]["reference"] = chunk_list
|
|
@@ -254,16 +256,13 @@ def list_agent_session(tenant_id, agent_id):
|
|
| 254 |
chunks = conv["reference"][chunk_num]["chunks"]
|
| 255 |
for chunk in chunks:
|
| 256 |
new_chunk = {
|
| 257 |
-
"id": chunk
|
| 258 |
-
"content": chunk
|
| 259 |
-
"document_id": chunk
|
| 260 |
-
"document_name": chunk
|
| 261 |
-
"dataset_id": chunk
|
| 262 |
-
"image_id": chunk.get("image_id", ""),
|
| 263 |
-
"
|
| 264 |
-
"vector_similarity": chunk["vector_similarity"],
|
| 265 |
-
"term_similarity": chunk["term_similarity"],
|
| 266 |
-
"positions": chunk["positions"],
|
| 267 |
}
|
| 268 |
chunk_list.append(new_chunk)
|
| 269 |
chunk_num += 1
|
|
|
|
| 65 |
@manager.route('/agents/<agent_id>/sessions', methods=['POST']) # noqa: F821
|
| 66 |
@token_required
|
| 67 |
def create_agent_session(tenant_id, agent_id):
|
| 68 |
+
req = request.json
|
| 69 |
e, cvs = UserCanvasService.get_by_id(agent_id)
|
| 70 |
if not e:
|
| 71 |
return get_error_data_result("Agent not found.")
|
| 72 |
|
| 73 |
+
if not UserCanvasService.query(user_id=tenant_id,id=agent_id):
|
| 74 |
+
return get_error_data_result("You cannot access the agent.")
|
| 75 |
+
|
| 76 |
if not isinstance(cvs.dsl, str):
|
| 77 |
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
| 78 |
|
| 79 |
canvas = Canvas(cvs.dsl, tenant_id)
|
| 80 |
if canvas.get_preset_param():
|
| 81 |
+
return get_error_data_result("The agent cannot create a session directly")
|
| 82 |
conv = {
|
| 83 |
"id": get_uuid(),
|
| 84 |
"dialog_id": cvs.id,
|
| 85 |
+
"user_id": req.get("usr_id","") if isinstance(req, dict) else "",
|
| 86 |
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
| 87 |
"source": "agent",
|
| 88 |
"dsl": json.loads(cvs.dsl)
|
|
|
|
| 203 |
chunks = conv["reference"][chunk_num]["chunks"]
|
| 204 |
for chunk in chunks:
|
| 205 |
new_chunk = {
|
| 206 |
+
"id": chunk.get("chunk_id", chunk.get("id")),
|
| 207 |
+
"content": chunk.get("content_with_weight", chunk.get("content")),
|
| 208 |
+
"document_id": chunk.get("doc_id", chunk.get("document_id")),
|
| 209 |
+
"document_name": chunk.get("docnm_kwd", chunk.get("document_name")),
|
| 210 |
+
"dataset_id": chunk.get("kb_id", chunk.get("dataset_id")),
|
| 211 |
+
"image_id": chunk.get("image_id", chunk.get("img_id")),
|
| 212 |
+
"positions": chunk.get("positions", chunk.get("position_int")),
|
|
|
|
|
|
|
|
|
|
| 213 |
}
|
| 214 |
+
|
| 215 |
chunk_list.append(new_chunk)
|
| 216 |
chunk_num += 1
|
| 217 |
messages[message_num]["reference"] = chunk_list
|
|
|
|
| 256 |
chunks = conv["reference"][chunk_num]["chunks"]
|
| 257 |
for chunk in chunks:
|
| 258 |
new_chunk = {
|
| 259 |
+
"id": chunk.get("chunk_id", chunk.get("id")),
|
| 260 |
+
"content": chunk.get("content_with_weight", chunk.get("content")),
|
| 261 |
+
"document_id": chunk.get("doc_id", chunk.get("document_id")),
|
| 262 |
+
"document_name": chunk.get("docnm_kwd", chunk.get("document_name")),
|
| 263 |
+
"dataset_id": chunk.get("kb_id", chunk.get("dataset_id")),
|
| 264 |
+
"image_id": chunk.get("image_id", chunk.get("img_id")),
|
| 265 |
+
"positions": chunk.get("positions", chunk.get("position_int")),
|
|
|
|
|
|
|
|
|
|
| 266 |
}
|
| 267 |
chunk_list.append(new_chunk)
|
| 268 |
chunk_num += 1
|
sdk/python/ragflow_sdk/modules/session.py
CHANGED
|
@@ -17,11 +17,11 @@ class Session(Base):
|
|
| 17 |
self.__session_type = "agent"
|
| 18 |
super().__init__(rag, res_dict)
|
| 19 |
|
| 20 |
-
def ask(self, question,stream=True):
|
| 21 |
if self.__session_type == "agent":
|
| 22 |
res=self._ask_agent(question,stream)
|
| 23 |
elif self.__session_type == "chat":
|
| 24 |
-
res=self._ask_chat(question,stream)
|
| 25 |
for line in res.iter_lines():
|
| 26 |
line = line.decode("utf-8")
|
| 27 |
if line.startswith("{"):
|
|
@@ -45,9 +45,11 @@ class Session(Base):
|
|
| 45 |
yield message
|
| 46 |
|
| 47 |
|
| 48 |
-
def _ask_chat(self, question: str, stream: bool):
|
|
|
|
|
|
|
| 49 |
res = self.post(f"/chats/{self.chat_id}/completions",
|
| 50 |
-
|
| 51 |
return res
|
| 52 |
def _ask_agent(self,question:str,stream:bool):
|
| 53 |
res = self.post(f"/agents/{self.agent_id}/completions",
|
|
|
|
| 17 |
self.__session_type = "agent"
|
| 18 |
super().__init__(rag, res_dict)
|
| 19 |
|
| 20 |
+
def ask(self, question,stream=True,**kwargs):
|
| 21 |
if self.__session_type == "agent":
|
| 22 |
res=self._ask_agent(question,stream)
|
| 23 |
elif self.__session_type == "chat":
|
| 24 |
+
res=self._ask_chat(question,stream,**kwargs)
|
| 25 |
for line in res.iter_lines():
|
| 26 |
line = line.decode("utf-8")
|
| 27 |
if line.startswith("{"):
|
|
|
|
| 45 |
yield message
|
| 46 |
|
| 47 |
|
| 48 |
+
def _ask_chat(self, question: str, stream: bool,**kwargs):
|
| 49 |
+
json_data={"question": question, "stream": True,"session_id":self.id}
|
| 50 |
+
json_data.update(kwargs)
|
| 51 |
res = self.post(f"/chats/{self.chat_id}/completions",
|
| 52 |
+
json_data, stream=stream)
|
| 53 |
return res
|
| 54 |
def _ask_agent(self,question:str,stream:bool):
|
| 55 |
res = self.post(f"/agents/{self.agent_id}/completions",
|