LiuHua Feiue Kevin Hu commited on
Commit
98d8f14
·
1 Parent(s): e07a35a

Add Authorization checks (#2221)

Browse files

### What problem does this PR solve?

Add Authorization checks
#2203

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Feiue <[email protected]>
Co-authored-by: Kevin Hu <[email protected]>

api/apps/canvas_app.py CHANGED
@@ -18,6 +18,7 @@ from functools import partial
18
  from flask import request, Response
19
  from flask_login import login_required, current_user
20
  from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
 
21
  from api.utils import get_uuid
22
  from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
23
  from agent.canvas import Canvas
@@ -43,6 +44,10 @@ def canvas_list():
43
  @login_required
44
  def rm():
45
  for i in request.json["canvas_ids"]:
 
 
 
 
46
  UserCanvasService.delete_by_id(i)
47
  return get_json_result(data=True)
48
 
 
18
  from flask import request, Response
19
  from flask_login import login_required, current_user
20
  from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
21
+ from api.settings import RetCode
22
  from api.utils import get_uuid
23
  from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
24
  from agent.canvas import Canvas
 
44
  @login_required
45
  def rm():
46
  for i in request.json["canvas_ids"]:
47
+ if not UserCanvasService.query(user_id=current_user.id,id=i):
48
+ return get_json_result(
49
+ data=False, retmsg=f'Only owner of canvas authorized for this operation.',
50
+ retcode=RetCode.OPERATING_ERROR)
51
  UserCanvasService.delete_by_id(i)
52
  return get_json_result(data=True)
53
 
api/apps/conversation_app.py CHANGED
@@ -13,16 +13,20 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
16
  from copy import deepcopy
 
 
17
  from flask import request, Response
18
- from flask_login import login_required,current_user
 
 
19
  from api.db.services.dialog_service import DialogService, ConversationService, chat
20
  from api.db.services.llm_service import LLMBundle, TenantService
21
- from api.db import LLMType
22
- from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
23
  from api.utils import get_uuid
24
  from api.utils.api_utils import get_json_result
25
- import json
26
 
27
 
28
  @manager.route('/set', methods=['POST'])
@@ -72,6 +76,14 @@ def get():
72
  e, conv = ConversationService.get_by_id(conv_id)
73
  if not e:
74
  return get_data_error_result(retmsg="Conversation not found!")
 
 
 
 
 
 
 
 
75
  conv = conv.to_dict()
76
  return get_json_result(data=conv)
77
  except Exception as e:
@@ -84,6 +96,17 @@ def rm():
84
  conv_ids = request.json["conversation_ids"]
85
  try:
86
  for cid in conv_ids:
 
 
 
 
 
 
 
 
 
 
 
87
  ConversationService.delete_by_id(cid)
88
  return get_json_result(data=True)
89
  except Exception as e:
@@ -95,6 +118,10 @@ def rm():
95
  def list_convsersation():
96
  dialog_id = request.args["dialog_id"]
97
  try:
 
 
 
 
98
  convs = ConversationService.query(
99
  dialog_id=dialog_id,
100
  order_by=ConversationService.model.create_time,
@@ -107,12 +134,12 @@ def list_convsersation():
107
 
108
  @manager.route('/completion', methods=['POST'])
109
  @login_required
110
- #@validate_request("conversation_id", "messages")
111
  def completion():
112
  req = request.json
113
- #req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
114
  # {"role": "user", "content": "上海有吗?"}
115
- #]}
116
  msg = []
117
  for m in req["messages"]:
118
  if m["role"] == "system":
@@ -141,7 +168,8 @@ def completion():
141
  nonlocal conv, message_id
142
  if not conv.reference:
143
  conv.reference.append(ans["reference"])
144
- else: conv.reference[-1] = ans["reference"]
 
145
  conv.message[-1] = {"role": "assistant", "content": ans["answer"],
146
  "id": message_id, "prompt": ans.get("prompt", "")}
147
  ans["id"] = message_id
@@ -151,13 +179,13 @@ def completion():
151
  try:
152
  for ans in chat(dia, msg, True, **req):
153
  fillin_conv(ans)
154
- yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
155
  ConversationService.update_by_id(conv.id, conv.to_dict())
156
  except Exception as e:
157
  yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
158
- "data": {"answer": "**ERROR**: "+str(e), "reference": []}},
159
  ensure_ascii=False) + "\n\n"
160
- yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
161
 
162
  if req.get("stream", True):
163
  resp = Response(stream(), mimetype="text/event-stream")
@@ -184,33 +212,34 @@ def completion():
184
  def tts():
185
  req = request.json
186
  text = req["text"]
187
-
188
  tenants = TenantService.get_by_user_id(current_user.id)
189
  if not tenants:
190
  return get_data_error_result(retmsg="Tenant not found!")
191
-
192
  tts_id = tenants[0]["tts_id"]
193
  if not tts_id:
194
  return get_data_error_result(retmsg="No default TTS model is set")
195
-
196
  tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
 
197
  def stream_audio():
198
  try:
199
  for chunk in tts_mdl.tts(text):
200
  yield chunk
201
  except Exception as e:
202
  yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e),
203
- "data": {"answer": "**ERROR**: "+str(e)}},
204
- ensure_ascii=False)).encode('utf-8')
205
 
206
- resp = Response(stream_audio(), mimetype="audio/mpeg")
207
  resp.headers.add_header("Cache-Control", "no-cache")
208
  resp.headers.add_header("Connection", "keep-alive")
209
  resp.headers.add_header("X-Accel-Buffering", "no")
210
-
211
  return resp
212
 
213
-
214
  @manager.route('/delete_msg', methods=['POST'])
215
  @login_required
216
  @validate_request("conversation_id", "message_id")
@@ -224,10 +253,10 @@ def delete_msg():
224
  for i, msg in enumerate(conv["message"]):
225
  if req["message_id"] != msg.get("id", ""):
226
  continue
227
- assert conv["message"][i+1]["id"] == req["message_id"]
228
  conv["message"].pop(i)
229
  conv["message"].pop(i)
230
- conv["reference"].pop(max(0, i//2-1))
231
  break
232
 
233
  ConversationService.update_by_id(conv["id"], conv)
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
+ import json
17
  from copy import deepcopy
18
+
19
+ from db.services.user_service import UserTenantService
20
  from flask import request, Response
21
+ from flask_login import login_required, current_user
22
+
23
+ from api.db import LLMType
24
  from api.db.services.dialog_service import DialogService, ConversationService, chat
25
  from api.db.services.llm_service import LLMBundle, TenantService
26
+ from api.settings import RetCode
 
27
  from api.utils import get_uuid
28
  from api.utils.api_utils import get_json_result
29
+ from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
30
 
31
 
32
  @manager.route('/set', methods=['POST'])
 
76
  e, conv = ConversationService.get_by_id(conv_id)
77
  if not e:
78
  return get_data_error_result(retmsg="Conversation not found!")
79
+ tenants = UserTenantService.query(user_id=current_user.id)
80
+ for tenant in tenants:
81
+ if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
82
+ break
83
+ else:
84
+ return get_json_result(
85
+ data=False, retmsg=f'Only owner of conversation authorized for this operation.',
86
+ retcode=RetCode.OPERATING_ERROR)
87
  conv = conv.to_dict()
88
  return get_json_result(data=conv)
89
  except Exception as e:
 
96
  conv_ids = request.json["conversation_ids"]
97
  try:
98
  for cid in conv_ids:
99
+ exist, conv = ConversationService.get_by_id(cid)
100
+ if not exist:
101
+ return get_data_error_result(retmsg="Conversation not found!")
102
+ tenants = UserTenantService.query(user_id=current_user.id)
103
+ for tenant in tenants:
104
+ if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
105
+ break
106
+ else:
107
+ return get_json_result(
108
+ data=False, retmsg=f'Only owner of conversation authorized for this operation.',
109
+ retcode=RetCode.OPERATING_ERROR)
110
  ConversationService.delete_by_id(cid)
111
  return get_json_result(data=True)
112
  except Exception as e:
 
118
  def list_convsersation():
119
  dialog_id = request.args["dialog_id"]
120
  try:
121
+ if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
122
+ return get_json_result(
123
+ data=False, retmsg=f'Only owner of dialog authorized for this operation.',
124
+ retcode=RetCode.OPERATING_ERROR)
125
  convs = ConversationService.query(
126
  dialog_id=dialog_id,
127
  order_by=ConversationService.model.create_time,
 
134
 
135
  @manager.route('/completion', methods=['POST'])
136
  @login_required
137
+ @validate_request("conversation_id", "messages")
138
  def completion():
139
  req = request.json
140
+ # req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
141
  # {"role": "user", "content": "上海有吗?"}
142
+ # ]}
143
  msg = []
144
  for m in req["messages"]:
145
  if m["role"] == "system":
 
168
  nonlocal conv, message_id
169
  if not conv.reference:
170
  conv.reference.append(ans["reference"])
171
+ else:
172
+ conv.reference[-1] = ans["reference"]
173
  conv.message[-1] = {"role": "assistant", "content": ans["answer"],
174
  "id": message_id, "prompt": ans.get("prompt", "")}
175
  ans["id"] = message_id
 
179
  try:
180
  for ans in chat(dia, msg, True, **req):
181
  fillin_conv(ans)
182
+ yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
183
  ConversationService.update_by_id(conv.id, conv.to_dict())
184
  except Exception as e:
185
  yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
186
+ "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
187
  ensure_ascii=False) + "\n\n"
188
+ yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
189
 
190
  if req.get("stream", True):
191
  resp = Response(stream(), mimetype="text/event-stream")
 
212
  def tts():
213
  req = request.json
214
  text = req["text"]
215
+
216
  tenants = TenantService.get_by_user_id(current_user.id)
217
  if not tenants:
218
  return get_data_error_result(retmsg="Tenant not found!")
219
+
220
  tts_id = tenants[0]["tts_id"]
221
  if not tts_id:
222
  return get_data_error_result(retmsg="No default TTS model is set")
223
+
224
  tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
225
+
226
  def stream_audio():
227
  try:
228
  for chunk in tts_mdl.tts(text):
229
  yield chunk
230
  except Exception as e:
231
  yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e),
232
+ "data": {"answer": "**ERROR**: " + str(e)}},
233
+ ensure_ascii=False)).encode('utf-8')
234
 
235
+ resp = Response(stream_audio(), mimetype="audio/mpeg")
236
  resp.headers.add_header("Cache-Control", "no-cache")
237
  resp.headers.add_header("Connection", "keep-alive")
238
  resp.headers.add_header("X-Accel-Buffering", "no")
239
+
240
  return resp
241
 
242
+
243
  @manager.route('/delete_msg', methods=['POST'])
244
  @login_required
245
  @validate_request("conversation_id", "message_id")
 
253
  for i, msg in enumerate(conv["message"]):
254
  if req["message_id"] != msg.get("id", ""):
255
  continue
256
+ assert conv["message"][i + 1]["id"] == req["message_id"]
257
  conv["message"].pop(i)
258
  conv["message"].pop(i)
259
+ conv["reference"].pop(max(0, i // 2 - 1))
260
  break
261
 
262
  ConversationService.update_by_id(conv["id"], conv)
api/apps/dialog_app.py CHANGED
@@ -19,7 +19,8 @@ from flask_login import login_required, current_user
19
  from api.db.services.dialog_service import DialogService
20
  from api.db import StatusEnum
21
  from api.db.services.knowledgebase_service import KnowledgebaseService
22
- from api.db.services.user_service import TenantService
 
23
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
24
  from api.utils import get_uuid
25
  from api.utils.api_utils import get_json_result
@@ -164,9 +165,19 @@ def list_dialogs():
164
  @validate_request("dialog_ids")
165
  def rm():
166
  req = request.json
 
 
167
  try:
168
- DialogService.update_many_by_id(
169
- [{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
 
 
 
 
 
 
 
 
170
  return get_json_result(data=True)
171
  except Exception as e:
172
  return server_error_response(e)
 
19
  from api.db.services.dialog_service import DialogService
20
  from api.db import StatusEnum
21
  from api.db.services.knowledgebase_service import KnowledgebaseService
22
+ from api.db.services.user_service import TenantService, UserTenantService
23
+ from api.settings import RetCode
24
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
25
  from api.utils import get_uuid
26
  from api.utils.api_utils import get_json_result
 
165
  @validate_request("dialog_ids")
166
  def rm():
167
  req = request.json
168
+ dialog_list=[]
169
+ tenants = UserTenantService.query(user_id=current_user.id)
170
  try:
171
+ for id in req["dialog_ids"]:
172
+ for tenant in tenants:
173
+ if DialogService.query(tenant_id=tenant.tenant_id, id=id):
174
+ break
175
+ else:
176
+ return get_json_result(
177
+ data=False, retmsg=f'Only owner of dialog authorized for this operation.',
178
+ retcode=RetCode.OPERATING_ERROR)
179
+ dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
180
+ DialogService.update_many_by_id(dialog_list)
181
  return get_json_result(data=True)
182
  except Exception as e:
183
  return server_error_response(e)
api/apps/document_app.py CHANGED
@@ -35,7 +35,7 @@ from api.db.services.file2document_service import File2DocumentService
35
  from api.db.services.file_service import FileService
36
  from api.db.services.llm_service import LLMBundle
37
  from api.db.services.task_service import TaskService, queue_tasks
38
- from api.db.services.user_service import TenantService
39
  from graphrag.mind_map_extractor import MindMapExtractor
40
  from rag.app import naive
41
  from rag.nlp import search
@@ -189,6 +189,15 @@ def list_docs():
189
  if not kb_id:
190
  return get_json_result(
191
  data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
 
 
 
 
 
 
 
 
 
192
  keywords = request.args.get("keywords", "")
193
 
194
  page_number = int(request.args.get("page", 1))
 
35
  from api.db.services.file_service import FileService
36
  from api.db.services.llm_service import LLMBundle
37
  from api.db.services.task_service import TaskService, queue_tasks
38
+ from api.db.services.user_service import TenantService, UserTenantService
39
  from graphrag.mind_map_extractor import MindMapExtractor
40
  from rag.app import naive
41
  from rag.nlp import search
 
189
  if not kb_id:
190
  return get_json_result(
191
  data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
192
+ tenants = UserTenantService.query(user_id=current_user.id)
193
+ for tenant in tenants:
194
+ if KnowledgebaseService.query(
195
+ tenant_id=tenant.tenant_id, id=kb_id):
196
+ break
197
+ else:
198
+ return get_json_result(
199
+ data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
200
+ retcode=RetCode.OPERATING_ERROR)
201
  keywords = request.args.get("keywords", "")
202
 
203
  page_number = int(request.args.get("page", 1))