LiuHua
Feiue
Kevin Hu
commited on
Commit
·
172caf6
1
Parent(s):
0dad3f5
SDK for session (#2312)
Browse files### What problem does this PR solve?
SDK for session
#1102
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
---------
Co-authored-by: Feiue <[email protected]>
Co-authored-by: Kevin Hu <[email protected]>
- api/apps/sdk/assistant.py +27 -16
- api/apps/sdk/dataset.py +5 -1
- api/apps/sdk/session.py +168 -0
- sdk/python/ragflow/modules/chat_assistant.py +18 -3
- sdk/python/ragflow/modules/session.py +64 -0
- sdk/python/ragflow/ragflow.py +2 -3
- sdk/python/test/t_assistant.py +14 -12
- sdk/python/test/t_session.py +27 -0
api/apps/sdk/assistant.py
CHANGED
@@ -16,9 +16,10 @@
|
|
16 |
from flask import request
|
17 |
|
18 |
from api.db import StatusEnum
|
|
|
19 |
from api.db.services.dialog_service import DialogService
|
20 |
-
from api.db.services.document_service import DocumentService
|
21 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
|
|
22 |
from api.db.services.user_service import TenantService
|
23 |
from api.settings import RetCode
|
24 |
from api.utils import get_uuid
|
@@ -30,7 +31,6 @@ from api.utils.api_utils import get_json_result
|
|
30 |
@token_required
|
31 |
def save(tenant_id):
|
32 |
req = request.json
|
33 |
-
id = req.get("id")
|
34 |
# dataset
|
35 |
if req.get("knowledgebases") == []:
|
36 |
return get_data_error_result(retmsg="knowledgebases can not be empty list")
|
@@ -41,8 +41,8 @@ def save(tenant_id):
|
|
41 |
return get_data_error_result(retmsg="knowledgebase needs id")
|
42 |
if not KnowledgebaseService.query(id=kb["id"], tenant_id=tenant_id):
|
43 |
return get_data_error_result(retmsg="you do not own the knowledgebase")
|
44 |
-
if not DocumentService.query(kb_id=kb["id"]):
|
45 |
-
|
46 |
kb_list.append(kb["id"])
|
47 |
req["kb_ids"] = kb_list
|
48 |
# llm
|
@@ -72,10 +72,10 @@ def save(tenant_id):
|
|
72 |
req[key] = prompt.pop(key)
|
73 |
req["prompt_config"] = req.pop("prompt")
|
74 |
# create
|
75 |
-
if not
|
76 |
# dataset
|
77 |
if not kb_list:
|
78 |
-
return get_data_error_result(retmsg="
|
79 |
# init
|
80 |
req["id"] = get_uuid()
|
81 |
req["description"] = req.get("description", "A helpful Assistant")
|
@@ -83,7 +83,11 @@ def save(tenant_id):
|
|
83 |
req["top_n"] = req.get("top_n", 6)
|
84 |
req["top_k"] = req.get("top_k", 1024)
|
85 |
req["rerank_id"] = req.get("rerank_id", "")
|
86 |
-
|
|
|
|
|
|
|
|
|
87 |
if not req.get("name"):
|
88 |
return get_data_error_result(retmsg="name is required.")
|
89 |
if DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
@@ -149,14 +153,20 @@ def save(tenant_id):
|
|
149 |
if not DialogService.query(tenant_id=tenant_id, id=req["id"], status=StatusEnum.VALID.value):
|
150 |
return get_json_result(data=False, retmsg='You do not own the assistant', retcode=RetCode.OPERATING_ERROR)
|
151 |
# prompt
|
|
|
|
|
152 |
e, res = DialogService.get_by_id(req["id"])
|
153 |
res = res.to_json()
|
|
|
|
|
|
|
154 |
if "name" in req:
|
155 |
if not req.get("name"):
|
156 |
return get_data_error_result(retmsg="name is not empty.")
|
157 |
if req["name"].lower() != res["name"].lower() \
|
158 |
-
and len(
|
159 |
-
|
|
|
160 |
if "prompt_config" in req:
|
161 |
res["prompt_config"].update(req["prompt_config"])
|
162 |
for p in res["prompt_config"]["parameters"]:
|
@@ -186,7 +196,7 @@ def delete(tenant_id):
|
|
186 |
if "id" not in req:
|
187 |
return get_data_error_result(retmsg="id is required")
|
188 |
id = req['id']
|
189 |
-
if not DialogService.query(tenant_id=tenant_id, id=id,status=StatusEnum.VALID.value):
|
190 |
return get_json_result(data=False, retmsg='you do not own the assistant.', retcode=RetCode.OPERATING_ERROR)
|
191 |
|
192 |
temp_dict = {"status": StatusEnum.INVALID.value}
|
@@ -200,21 +210,22 @@ def get(tenant_id):
|
|
200 |
req = request.args
|
201 |
if "id" in req:
|
202 |
id = req["id"]
|
203 |
-
ass = DialogService.query(tenant_id=tenant_id, id=id,status=StatusEnum.VALID.value)
|
204 |
if not ass:
|
205 |
return get_json_result(data=False, retmsg='You do not own the assistant.', retcode=RetCode.OPERATING_ERROR)
|
206 |
if "name" in req:
|
207 |
name = req["name"]
|
208 |
if ass[0].name != name:
|
209 |
return get_json_result(data=False, retmsg='name does not match id.', retcode=RetCode.OPERATING_ERROR)
|
210 |
-
res=ass[0].to_json()
|
211 |
else:
|
212 |
if "name" in req:
|
213 |
name = req["name"]
|
214 |
-
ass = DialogService.query(name=name, tenant_id=tenant_id,status=StatusEnum.VALID.value)
|
215 |
if not ass:
|
216 |
-
return get_json_result(data=False, retmsg='You do not own the
|
217 |
-
|
|
|
218 |
else:
|
219 |
return get_data_error_result(retmsg="At least one of `id` or `name` must be provided.")
|
220 |
renamed_dict = {}
|
@@ -258,7 +269,7 @@ def list_assistants(tenant_id):
|
|
258 |
reverse=True,
|
259 |
order_by=DialogService.model.create_time)
|
260 |
assts = [d.to_dict() for d in assts]
|
261 |
-
list_assts=[]
|
262 |
renamed_dict = {}
|
263 |
key_mapping = {"parameters": "variables",
|
264 |
"prologue": "opener",
|
|
|
16 |
from flask import request
|
17 |
|
18 |
from api.db import StatusEnum
|
19 |
+
from api.db.db_models import TenantLLM
|
20 |
from api.db.services.dialog_service import DialogService
|
|
|
21 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
22 |
+
from api.db.services.llm_service import LLMService, TenantLLMService
|
23 |
from api.db.services.user_service import TenantService
|
24 |
from api.settings import RetCode
|
25 |
from api.utils import get_uuid
|
|
|
31 |
@token_required
|
32 |
def save(tenant_id):
|
33 |
req = request.json
|
|
|
34 |
# dataset
|
35 |
if req.get("knowledgebases") == []:
|
36 |
return get_data_error_result(retmsg="knowledgebases can not be empty list")
|
|
|
41 |
return get_data_error_result(retmsg="knowledgebase needs id")
|
42 |
if not KnowledgebaseService.query(id=kb["id"], tenant_id=tenant_id):
|
43 |
return get_data_error_result(retmsg="you do not own the knowledgebase")
|
44 |
+
# if not DocumentService.query(kb_id=kb["id"]):
|
45 |
+
# return get_data_error_result(retmsg="There is a invalid knowledgebase")
|
46 |
kb_list.append(kb["id"])
|
47 |
req["kb_ids"] = kb_list
|
48 |
# llm
|
|
|
72 |
req[key] = prompt.pop(key)
|
73 |
req["prompt_config"] = req.pop("prompt")
|
74 |
# create
|
75 |
+
if "id" not in req:
|
76 |
# dataset
|
77 |
if not kb_list:
|
78 |
+
return get_data_error_result(retmsg="knowledgebases are required!")
|
79 |
# init
|
80 |
req["id"] = get_uuid()
|
81 |
req["description"] = req.get("description", "A helpful Assistant")
|
|
|
83 |
req["top_n"] = req.get("top_n", 6)
|
84 |
req["top_k"] = req.get("top_k", 1024)
|
85 |
req["rerank_id"] = req.get("rerank_id", "")
|
86 |
+
if req.get("llm_id"):
|
87 |
+
if not TenantLLMService.query(llm_name=req["llm_id"]):
|
88 |
+
return get_data_error_result(retmsg="the model_name does not exist.")
|
89 |
+
else:
|
90 |
+
req["llm_id"] = tenant.llm_id
|
91 |
if not req.get("name"):
|
92 |
return get_data_error_result(retmsg="name is required.")
|
93 |
if DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
|
|
153 |
if not DialogService.query(tenant_id=tenant_id, id=req["id"], status=StatusEnum.VALID.value):
|
154 |
return get_json_result(data=False, retmsg='You do not own the assistant', retcode=RetCode.OPERATING_ERROR)
|
155 |
# prompt
|
156 |
+
if not req["id"]:
|
157 |
+
return get_data_error_result(retmsg="id can not be empty")
|
158 |
e, res = DialogService.get_by_id(req["id"])
|
159 |
res = res.to_json()
|
160 |
+
if "llm_id" in req:
|
161 |
+
if not TenantLLMService.query(llm_name=req["llm_id"]):
|
162 |
+
return get_data_error_result(retmsg="the model_name does not exist.")
|
163 |
if "name" in req:
|
164 |
if not req.get("name"):
|
165 |
return get_data_error_result(retmsg="name is not empty.")
|
166 |
if req["name"].lower() != res["name"].lower() \
|
167 |
+
and len(
|
168 |
+
DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0:
|
169 |
+
return get_data_error_result(retmsg="Duplicated assistant name in updating dataset.")
|
170 |
if "prompt_config" in req:
|
171 |
res["prompt_config"].update(req["prompt_config"])
|
172 |
for p in res["prompt_config"]["parameters"]:
|
|
|
196 |
if "id" not in req:
|
197 |
return get_data_error_result(retmsg="id is required")
|
198 |
id = req['id']
|
199 |
+
if not DialogService.query(tenant_id=tenant_id, id=id, status=StatusEnum.VALID.value):
|
200 |
return get_json_result(data=False, retmsg='you do not own the assistant.', retcode=RetCode.OPERATING_ERROR)
|
201 |
|
202 |
temp_dict = {"status": StatusEnum.INVALID.value}
|
|
|
210 |
req = request.args
|
211 |
if "id" in req:
|
212 |
id = req["id"]
|
213 |
+
ass = DialogService.query(tenant_id=tenant_id, id=id, status=StatusEnum.VALID.value)
|
214 |
if not ass:
|
215 |
return get_json_result(data=False, retmsg='You do not own the assistant.', retcode=RetCode.OPERATING_ERROR)
|
216 |
if "name" in req:
|
217 |
name = req["name"]
|
218 |
if ass[0].name != name:
|
219 |
return get_json_result(data=False, retmsg='name does not match id.', retcode=RetCode.OPERATING_ERROR)
|
220 |
+
res = ass[0].to_json()
|
221 |
else:
|
222 |
if "name" in req:
|
223 |
name = req["name"]
|
224 |
+
ass = DialogService.query(name=name, tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
225 |
if not ass:
|
226 |
+
return get_json_result(data=False, retmsg='You do not own the assistant.',
|
227 |
+
retcode=RetCode.OPERATING_ERROR)
|
228 |
+
res = ass[0].to_json()
|
229 |
else:
|
230 |
return get_data_error_result(retmsg="At least one of `id` or `name` must be provided.")
|
231 |
renamed_dict = {}
|
|
|
269 |
reverse=True,
|
270 |
order_by=DialogService.model.create_time)
|
271 |
assts = [d.to_dict() for d in assts]
|
272 |
+
list_assts = []
|
273 |
renamed_dict = {}
|
274 |
key_mapping = {"parameters": "variables",
|
275 |
"prologue": "opener",
|
api/apps/sdk/dataset.py
CHANGED
@@ -60,7 +60,7 @@ def save(tenant_id):
|
|
60 |
req.update(mapped_keys)
|
61 |
if not KnowledgebaseService.save(**req):
|
62 |
return get_data_error_result(retmsg="Create dataset error.(Database error)")
|
63 |
-
renamed_data={}
|
64 |
e, k = KnowledgebaseService.get_by_id(req["id"])
|
65 |
for key, value in k.to_dict().items():
|
66 |
new_key = key_mapping.get(key, key)
|
@@ -88,6 +88,9 @@ def save(tenant_id):
|
|
88 |
data=False, retmsg='You do not own the dataset.',
|
89 |
retcode=RetCode.OPERATING_ERROR)
|
90 |
|
|
|
|
|
|
|
91 |
e, kb = KnowledgebaseService.get_by_id(req["id"])
|
92 |
|
93 |
if "chunk_count" in req:
|
@@ -108,6 +111,7 @@ def save(tenant_id):
|
|
108 |
retmsg="If chunk count is not 0, parse method is not changable.")
|
109 |
req['parser_id'] = req.pop('parse_method')
|
110 |
if "name" in req:
|
|
|
111 |
if req["name"].lower() != kb.name.lower() \
|
112 |
and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
|
113 |
status=StatusEnum.VALID.value)) > 0:
|
|
|
60 |
req.update(mapped_keys)
|
61 |
if not KnowledgebaseService.save(**req):
|
62 |
return get_data_error_result(retmsg="Create dataset error.(Database error)")
|
63 |
+
renamed_data = {}
|
64 |
e, k = KnowledgebaseService.get_by_id(req["id"])
|
65 |
for key, value in k.to_dict().items():
|
66 |
new_key = key_mapping.get(key, key)
|
|
|
88 |
data=False, retmsg='You do not own the dataset.',
|
89 |
retcode=RetCode.OPERATING_ERROR)
|
90 |
|
91 |
+
if not req["id"]:
|
92 |
+
return get_data_error_result(
|
93 |
+
retmsg="id can not be empty.")
|
94 |
e, kb = KnowledgebaseService.get_by_id(req["id"])
|
95 |
|
96 |
if "chunk_count" in req:
|
|
|
111 |
retmsg="If chunk count is not 0, parse method is not changable.")
|
112 |
req['parser_id'] = req.pop('parse_method')
|
113 |
if "name" in req:
|
114 |
+
req["name"] = req["name"].strip()
|
115 |
if req["name"].lower() != kb.name.lower() \
|
116 |
and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
|
117 |
status=StatusEnum.VALID.value)) > 0:
|
api/apps/sdk/session.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
import json
|
17 |
+
from copy import deepcopy
|
18 |
+
from uuid import uuid4
|
19 |
+
|
20 |
+
from flask import request, Response
|
21 |
+
|
22 |
+
from api.db import StatusEnum
|
23 |
+
from api.db.services.dialog_service import DialogService, ConversationService, chat
|
24 |
+
from api.utils import get_uuid
|
25 |
+
from api.utils.api_utils import get_data_error_result
|
26 |
+
from api.utils.api_utils import get_json_result, token_required
|
27 |
+
|
28 |
+
|
29 |
+
@manager.route('/save', methods=['POST'])
|
30 |
+
@token_required
|
31 |
+
def set_conversation(tenant_id):
|
32 |
+
req = request.json
|
33 |
+
conv_id = req.get("id")
|
34 |
+
if "messages" in req:
|
35 |
+
req["message"] = req.pop("messages")
|
36 |
+
if req["message"]:
|
37 |
+
for message in req["message"]:
|
38 |
+
if "reference" in message:
|
39 |
+
req["reference"] = message.pop("reference")
|
40 |
+
if "assistant_id" in req:
|
41 |
+
req["dialog_id"] = req.pop("assistant_id")
|
42 |
+
if "id" in req:
|
43 |
+
del req["id"]
|
44 |
+
conv = ConversationService.query(id=conv_id)
|
45 |
+
if not conv:
|
46 |
+
return get_data_error_result(retmsg="Session does not exist")
|
47 |
+
if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
48 |
+
return get_data_error_result(retmsg="You do not own the session")
|
49 |
+
if req.get("dialog_id"):
|
50 |
+
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
|
51 |
+
if not dia:
|
52 |
+
return get_data_error_result(retmsg="You do not own the assistant")
|
53 |
+
if "dialog_id" in req and not req.get("dialog_id"):
|
54 |
+
return get_data_error_result(retmsg="assistant_id can not be empty.")
|
55 |
+
if "name" in req and not req.get("name"):
|
56 |
+
return get_data_error_result(retmsg="name can not be empty.")
|
57 |
+
if "message" in req and not req.get("message"):
|
58 |
+
return get_data_error_result(retmsg="messages can not be empty")
|
59 |
+
if not ConversationService.update_by_id(conv_id, req):
|
60 |
+
return get_data_error_result(retmsg="Session updates error")
|
61 |
+
return get_json_result(data=True)
|
62 |
+
|
63 |
+
if not req.get("dialog_id"):
|
64 |
+
return get_data_error_result(retmsg="assistant_id is required.")
|
65 |
+
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
|
66 |
+
if not dia:
|
67 |
+
return get_data_error_result(retmsg="You do not own the assistant")
|
68 |
+
conv = {
|
69 |
+
"id": get_uuid(),
|
70 |
+
"dialog_id": req["dialog_id"],
|
71 |
+
"name": req.get("name", "New session"),
|
72 |
+
"message": req.get("message", [{"role": "assistant", "content": dia[0].prompt_config["prologue"]}]),
|
73 |
+
"reference": req.get("reference", [])
|
74 |
+
}
|
75 |
+
if not conv.get("name"):
|
76 |
+
return get_data_error_result(retmsg="name can not be empty.")
|
77 |
+
if not conv.get("message"):
|
78 |
+
return get_data_error_result(retmsg="messages can not be empty")
|
79 |
+
ConversationService.save(**conv)
|
80 |
+
e, conv = ConversationService.get_by_id(conv["id"])
|
81 |
+
if not e:
|
82 |
+
return get_data_error_result(retmsg="Fail to new session!")
|
83 |
+
conv = conv.to_dict()
|
84 |
+
conv["messages"] = conv.pop("message")
|
85 |
+
conv["assistant_id"] = conv.pop("dialog_id")
|
86 |
+
for message in conv["messages"]:
|
87 |
+
message["reference"] = conv.get("reference")
|
88 |
+
del conv["reference"]
|
89 |
+
return get_json_result(data=conv)
|
90 |
+
|
91 |
+
|
92 |
+
@manager.route('/completion', methods=['POST'])
|
93 |
+
@token_required
|
94 |
+
def completion(tenant_id):
|
95 |
+
req = request.json
|
96 |
+
# req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
|
97 |
+
# {"role": "user", "content": "上海有吗?"}
|
98 |
+
# ]}
|
99 |
+
msg = []
|
100 |
+
question = {
|
101 |
+
"content": req.get("question"),
|
102 |
+
"role": "user",
|
103 |
+
"id": str(uuid4())
|
104 |
+
}
|
105 |
+
req["messages"].append(question)
|
106 |
+
for m in req["messages"]:
|
107 |
+
if m["role"] == "system": continue
|
108 |
+
if m["role"] == "assistant" and not msg: continue
|
109 |
+
m["id"] = m.get("id", str(uuid4()))
|
110 |
+
msg.append(m)
|
111 |
+
message_id = msg[-1].get("id")
|
112 |
+
conv = ConversationService.query(id=req["id"])
|
113 |
+
conv = conv[0]
|
114 |
+
if not conv:
|
115 |
+
return get_data_error_result(retmsg="Session does not exist")
|
116 |
+
if not DialogService.query(id=conv.dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
117 |
+
return get_data_error_result(retmsg="You do not own the session")
|
118 |
+
conv.message = deepcopy(req["messages"])
|
119 |
+
e, dia = DialogService.get_by_id(conv.dialog_id)
|
120 |
+
if not e:
|
121 |
+
return get_data_error_result(retmsg="Dialog not found!")
|
122 |
+
del req["id"]
|
123 |
+
del req["messages"]
|
124 |
+
|
125 |
+
if not conv.reference:
|
126 |
+
conv.reference = []
|
127 |
+
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
128 |
+
conv.reference.append({"chunks": [], "doc_aggs": []})
|
129 |
+
|
130 |
+
def fillin_conv(ans):
|
131 |
+
nonlocal conv, message_id
|
132 |
+
if not conv.reference:
|
133 |
+
conv.reference.append(ans["reference"])
|
134 |
+
else:
|
135 |
+
conv.reference[-1] = ans["reference"]
|
136 |
+
conv.message[-1] = {"role": "assistant", "content": ans["answer"],
|
137 |
+
"id": message_id, "prompt": ans.get("prompt", "")}
|
138 |
+
ans["id"] = message_id
|
139 |
+
|
140 |
+
def stream():
|
141 |
+
nonlocal dia, msg, req, conv
|
142 |
+
try:
|
143 |
+
for ans in chat(dia, msg, **req):
|
144 |
+
fillin_conv(ans)
|
145 |
+
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
146 |
+
ConversationService.update_by_id(conv.id, conv.to_dict())
|
147 |
+
except Exception as e:
|
148 |
+
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
|
149 |
+
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
150 |
+
ensure_ascii=False) + "\n\n"
|
151 |
+
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
|
152 |
+
|
153 |
+
if req.get("stream", True):
|
154 |
+
resp = Response(stream(), mimetype="text/event-stream")
|
155 |
+
resp.headers.add_header("Cache-control", "no-cache")
|
156 |
+
resp.headers.add_header("Connection", "keep-alive")
|
157 |
+
resp.headers.add_header("X-Accel-Buffering", "no")
|
158 |
+
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
159 |
+
return resp
|
160 |
+
|
161 |
+
else:
|
162 |
+
answer = None
|
163 |
+
for ans in chat(dia, msg, **req):
|
164 |
+
answer = ans
|
165 |
+
fillin_conv(ans)
|
166 |
+
ConversationService.update_by_id(conv.id, conv.to_dict())
|
167 |
+
break
|
168 |
+
return get_json_result(data=answer)
|
sdk/python/ragflow/modules/chat_assistant.py
CHANGED
@@ -1,9 +1,12 @@
|
|
|
|
|
|
1 |
from .base import Base
|
|
|
2 |
|
3 |
|
4 |
class Assistant(Base):
|
5 |
def __init__(self, rag, res_dict):
|
6 |
-
self.id=""
|
7 |
self.name = "assistant"
|
8 |
self.avatar = "path/to/avatar"
|
9 |
self.knowledgebases = ["kb1"]
|
@@ -41,8 +44,8 @@ class Assistant(Base):
|
|
41 |
|
42 |
def save(self) -> bool:
|
43 |
res = self.post('/assistant/save',
|
44 |
-
{"id": self.id, "name": self.name, "avatar": self.avatar, "knowledgebases":self.knowledgebases,
|
45 |
-
"llm":self.llm.to_json(),"prompt":self.prompt.to_json()
|
46 |
})
|
47 |
res = res.json()
|
48 |
if res.get("retmsg") == "success": return True
|
@@ -54,3 +57,15 @@ class Assistant(Base):
|
|
54 |
res = res.json()
|
55 |
if res.get("retmsg") == "success": return True
|
56 |
raise Exception(res["retmsg"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
from .base import Base
|
4 |
+
from .session import Session, Message
|
5 |
|
6 |
|
7 |
class Assistant(Base):
|
8 |
def __init__(self, rag, res_dict):
|
9 |
+
self.id = ""
|
10 |
self.name = "assistant"
|
11 |
self.avatar = "path/to/avatar"
|
12 |
self.knowledgebases = ["kb1"]
|
|
|
44 |
|
45 |
def save(self) -> bool:
|
46 |
res = self.post('/assistant/save',
|
47 |
+
{"id": self.id, "name": self.name, "avatar": self.avatar, "knowledgebases": self.knowledgebases,
|
48 |
+
"llm": self.llm.to_json(), "prompt": self.prompt.to_json()
|
49 |
})
|
50 |
res = res.json()
|
51 |
if res.get("retmsg") == "success": return True
|
|
|
57 |
res = res.json()
|
58 |
if res.get("retmsg") == "success": return True
|
59 |
raise Exception(res["retmsg"])
|
60 |
+
|
61 |
+
def create_session(self, name: str = "New session", messages: List[Message] = [
|
62 |
+
{"role": "assistant", "reference": [],
|
63 |
+
"content": "您好,我是您的助手小樱,长得可爱又善良,can I help you?"}]) -> Session:
|
64 |
+
res = self.post("/session/save", {"name": name, "messages": messages, "assistant_id": self.id, })
|
65 |
+
res = res.json()
|
66 |
+
if res.get("retmsg") == "success":
|
67 |
+
return Session(self.rag, res['data'])
|
68 |
+
raise Exception(res["retmsg"])
|
69 |
+
|
70 |
+
def get_prologue(self):
|
71 |
+
return self.prompt.opener
|
sdk/python/ragflow/modules/session.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
from .base import Base
|
4 |
+
|
5 |
+
|
6 |
+
class Session(Base):
|
7 |
+
def __init__(self, rag, res_dict):
|
8 |
+
self.id = None
|
9 |
+
self.name = "New session"
|
10 |
+
self.messages = [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}]
|
11 |
+
|
12 |
+
self.assistant_id = None
|
13 |
+
super().__init__(rag, res_dict)
|
14 |
+
|
15 |
+
def chat(self, question: str, stream: bool = False):
|
16 |
+
res = self.post("/session/completion",
|
17 |
+
{"id": self.id, "question": question, "stream": stream, "messages": self.messages})
|
18 |
+
res = res.text
|
19 |
+
response_lines = res.splitlines()
|
20 |
+
message_list = []
|
21 |
+
for line in response_lines:
|
22 |
+
if line.startswith("data:"):
|
23 |
+
json_data = json.loads(line[5:])
|
24 |
+
if json_data["data"] != True:
|
25 |
+
answer = json_data["data"]["answer"]
|
26 |
+
reference = json_data["data"]["reference"]
|
27 |
+
temp_dict = {
|
28 |
+
"content": answer,
|
29 |
+
"role": "assistant",
|
30 |
+
"reference": reference
|
31 |
+
}
|
32 |
+
message = Message(self.rag, temp_dict)
|
33 |
+
message_list.append(message)
|
34 |
+
return message_list
|
35 |
+
|
36 |
+
def save(self):
|
37 |
+
res = self.post("/session/save",
|
38 |
+
{"id": self.id, "dialog_id": self.assistant_id, "name": self.name, "message": self.messages})
|
39 |
+
res = res.json()
|
40 |
+
if res.get("retmsg") == "success": return True
|
41 |
+
raise Exception(res.get("retmsg"))
|
42 |
+
|
43 |
+
class Message(Base):
|
44 |
+
def __init__(self, rag, res_dict):
|
45 |
+
self.content = "您好,我是您的助手小樱,长得可爱又善良,can I help you?"
|
46 |
+
self.reference = []
|
47 |
+
self.role = "assistant"
|
48 |
+
self.prompt=None
|
49 |
+
super().__init__(rag, res_dict)
|
50 |
+
|
51 |
+
|
52 |
+
class Chunk(Base):
|
53 |
+
def __init__(self, rag, res_dict):
|
54 |
+
self.id = None
|
55 |
+
self.content = None
|
56 |
+
self.document_id = None
|
57 |
+
self.document_name = None
|
58 |
+
self.knowledgebase_id = None
|
59 |
+
self.image_id = None
|
60 |
+
self.similarity = None
|
61 |
+
self.vector_similarity = None
|
62 |
+
self.term_similarity = None
|
63 |
+
self.positions = None
|
64 |
+
super().__init__(rag, res_dict)
|
sdk/python/ragflow/ragflow.py
CHANGED
@@ -17,7 +17,6 @@ from typing import List
|
|
17 |
|
18 |
import requests
|
19 |
|
20 |
-
|
21 |
from .modules.chat_assistant import Assistant
|
22 |
from .modules.dataset import DataSet
|
23 |
|
@@ -88,7 +87,7 @@ class RAGFlow:
|
|
88 |
datasets.append(dataset.to_json())
|
89 |
|
90 |
if llm is None:
|
91 |
-
llm = Assistant.LLM(self, {"model_name":
|
92 |
"temperature": 0.1,
|
93 |
"top_p": 0.3,
|
94 |
"presence_penalty": 0.4,
|
@@ -142,4 +141,4 @@ class RAGFlow:
|
|
142 |
for data in res['data']:
|
143 |
result_list.append(Assistant(self, data))
|
144 |
return result_list
|
145 |
-
raise Exception(res["retmsg"])
|
|
|
17 |
|
18 |
import requests
|
19 |
|
|
|
20 |
from .modules.chat_assistant import Assistant
|
21 |
from .modules.dataset import DataSet
|
22 |
|
|
|
87 |
datasets.append(dataset.to_json())
|
88 |
|
89 |
if llm is None:
|
90 |
+
llm = Assistant.LLM(self, {"model_name": None,
|
91 |
"temperature": 0.1,
|
92 |
"top_p": 0.3,
|
93 |
"presence_penalty": 0.4,
|
|
|
141 |
for data in res['data']:
|
142 |
result_list.append(Assistant(self, data))
|
143 |
return result_list
|
144 |
+
raise Exception(res["retmsg"])
|
sdk/python/test/t_assistant.py
CHANGED
@@ -10,10 +10,10 @@ class TestAssistant(TestSdk):
|
|
10 |
Test creating an assistant with success
|
11 |
"""
|
12 |
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
13 |
-
kb = rag.
|
14 |
-
assistant = rag.create_assistant("
|
15 |
if isinstance(assistant, Assistant):
|
16 |
-
assert assistant.name == "
|
17 |
else:
|
18 |
assert False, f"Failed to create assistant, error: {assistant}"
|
19 |
|
@@ -22,11 +22,11 @@ class TestAssistant(TestSdk):
|
|
22 |
Test updating an assistant with success.
|
23 |
"""
|
24 |
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
25 |
-
kb = rag.
|
26 |
-
assistant = rag.create_assistant("
|
27 |
if isinstance(assistant, Assistant):
|
28 |
-
assert assistant.name == "
|
29 |
-
assistant.name = '
|
30 |
res = assistant.save()
|
31 |
assert res is True, f"Failed to update assistant, error: {res}"
|
32 |
else:
|
@@ -37,10 +37,10 @@ class TestAssistant(TestSdk):
|
|
37 |
Test deleting an assistant with success
|
38 |
"""
|
39 |
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
40 |
-
kb = rag.
|
41 |
-
assistant = rag.create_assistant("
|
42 |
if isinstance(assistant, Assistant):
|
43 |
-
assert assistant.name == "
|
44 |
res = assistant.delete()
|
45 |
assert res is True, f"Failed to delete assistant, error: {res}"
|
46 |
else:
|
@@ -61,6 +61,8 @@ class TestAssistant(TestSdk):
|
|
61 |
Test getting an assistant's detail with success
|
62 |
"""
|
63 |
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
64 |
-
|
|
|
|
|
65 |
assert isinstance(assistant, Assistant), f"Failed to get assistant, error: {assistant}."
|
66 |
-
assert assistant.name == "
|
|
|
10 |
Test creating an assistant with success
|
11 |
"""
|
12 |
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
13 |
+
kb = rag.create_dataset(name="test_create_assistant")
|
14 |
+
assistant = rag.create_assistant("test_create", knowledgebases=[kb])
|
15 |
if isinstance(assistant, Assistant):
|
16 |
+
assert assistant.name == "test_create", "Name does not match."
|
17 |
else:
|
18 |
assert False, f"Failed to create assistant, error: {assistant}"
|
19 |
|
|
|
22 |
Test updating an assistant with success.
|
23 |
"""
|
24 |
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
25 |
+
kb = rag.create_dataset(name="test_update_assistant")
|
26 |
+
assistant = rag.create_assistant("test_update", knowledgebases=[kb])
|
27 |
if isinstance(assistant, Assistant):
|
28 |
+
assert assistant.name == "test_update", "Name does not match."
|
29 |
+
assistant.name = 'new_assistant'
|
30 |
res = assistant.save()
|
31 |
assert res is True, f"Failed to update assistant, error: {res}"
|
32 |
else:
|
|
|
37 |
Test deleting an assistant with success
|
38 |
"""
|
39 |
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
40 |
+
kb = rag.create_dataset(name="test_delete_assistant")
|
41 |
+
assistant = rag.create_assistant("test_delete", knowledgebases=[kb])
|
42 |
if isinstance(assistant, Assistant):
|
43 |
+
assert assistant.name == "test_delete", "Name does not match."
|
44 |
res = assistant.delete()
|
45 |
assert res is True, f"Failed to delete assistant, error: {res}"
|
46 |
else:
|
|
|
61 |
Test getting an assistant's detail with success
|
62 |
"""
|
63 |
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
64 |
+
kb = rag.create_dataset(name="test_get_assistant")
|
65 |
+
rag.create_assistant("test_get_assistant", knowledgebases=[kb])
|
66 |
+
assistant = rag.get_assistant(name="test_get_assistant")
|
67 |
assert isinstance(assistant, Assistant), f"Failed to get assistant, error: {assistant}."
|
68 |
+
assert assistant.name == "test_get_assistant", "Name does not match"
|
sdk/python/test/t_session.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ragflow import RAGFlow
|
2 |
+
|
3 |
+
from common import API_KEY, HOST_ADDRESS
|
4 |
+
|
5 |
+
|
6 |
+
class TestChatSession:
|
7 |
+
def test_create_session(self):
|
8 |
+
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
9 |
+
kb = rag.create_dataset(name="test_create_session")
|
10 |
+
assistant = rag.create_assistant(name="test_create_session", knowledgebases=[kb])
|
11 |
+
session = assistant.create_session()
|
12 |
+
assert assistant is not None, "Failed to get the assistant."
|
13 |
+
assert session is not None, "Failed to create a session."
|
14 |
+
|
15 |
+
def test_create_chat_with_success(self):
|
16 |
+
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
17 |
+
kb = rag.create_dataset(name="test_create_chat")
|
18 |
+
assistant = rag.create_assistant(name="test_create_chat", knowledgebases=[kb])
|
19 |
+
session = assistant.create_session()
|
20 |
+
assert session is not None, "Failed to create a session."
|
21 |
+
prologue = assistant.get_prologue()
|
22 |
+
assert isinstance(prologue, str), "Prologue is not a string."
|
23 |
+
assert len(prologue) > 0, "Prologue is empty."
|
24 |
+
question = "What is AI"
|
25 |
+
ans = session.chat(question, stream=True)
|
26 |
+
response = ans[-1].content
|
27 |
+
assert len(response) > 0, "Assistant did not return any response."
|