KevinHuSh committed · Commit 50de567 · 1 Parent(s): 631128d

conversation API backend update (#360)

Browse files

### What problem does this PR solve?


Issue link: #345

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

api/apps/__init__.py CHANGED
@@ -14,11 +14,11 @@
 # limitations under the License.
 #
 import logging
-import sys
 import os
+import sys
 from importlib.util import module_from_spec, spec_from_file_location
 from pathlib import Path
-from flask import Blueprint, Flask, request
+from flask import Blueprint, Flask
 from werkzeug.wrappers.request import Request
 from flask_cors import CORS
 
@@ -29,9 +29,9 @@ from api.utils import CustomJSONEncoder
 
 from flask_session import Session
 from flask_login import LoginManager
-from api.settings import RetCode, SECRET_KEY, stat_logger
-from api.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger
-from api.utils.api_utils import get_json_result, server_error_response
+from api.settings import SECRET_KEY, stat_logger
+from api.settings import API_VERSION, access_logger
+from api.utils.api_utils import server_error_response
 from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
 
 __all__ = ['app']
@@ -54,8 +54,8 @@ app.errorhandler(Exception)(server_error_response)
 #app.config["LOGIN_DISABLED"] = True
 app.config["SESSION_PERMANENT"] = False
 app.config["SESSION_TYPE"] = "filesystem"
-#app.config['MAX_CONTENT_LENGTH'] = 128 * 1024 * 1024
 app.config['MAX_CONTENT_LENGTH'] = os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024)
+
 Session(app)
 login_manager = LoginManager()
 login_manager.init_app(app)
@@ -117,4 +117,4 @@ def load_user(web_request):
 
 @app.teardown_request
 def _db_close(exc):
-    close_connection()
+    close_connection()
api/apps/api_app.py ADDED
@@ -0,0 +1,196 @@
+#
+#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+from datetime import datetime, timedelta
+from flask import request
+from flask_login import login_required, current_user
+from api.db.db_models import APIToken, API4Conversation
+from api.db.services.api_service import APITokenService, API4ConversationService
+from api.db.services.dialog_service import DialogService, chat
+from api.db.services.user_service import UserTenantService
+from api.settings import RetCode
+from api.utils import get_uuid, current_timestamp, datetime_format
+from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request
+from itsdangerous import URLSafeTimedSerializer
+
+
+def generate_confirmation_token(tenant_id):
+    serializer = URLSafeTimedSerializer(tenant_id)
+    return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34]
+
+
+@manager.route('/new_token', methods=['POST'])
+@validate_request("dialog_id")
+@login_required
+def new_token():
+    req = request.json
+    try:
+        tenants = UserTenantService.query(user_id=current_user.id)
+        if not tenants:
+            return get_data_error_result(retmsg="Tenant not found!")
+
+        tenant_id = tenants[0].tenant_id
+        obj = {"tenant_id": tenant_id, "token": generate_confirmation_token(tenant_id),
+               "dialog_id": req["dialog_id"],
+               "create_time": current_timestamp(),
+               "create_date": datetime_format(datetime.now()),
+               "update_time": None,
+               "update_date": None
+               }
+        if not APITokenService.save(**obj):
+            return get_data_error_result(retmsg="Fail to create a new token!")
+
+        return get_json_result(data=obj)
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/token_list', methods=['GET'])
+@login_required
+def token_list():
+    try:
+        tenants = UserTenantService.query(user_id=current_user.id)
+        if not tenants:
+            return get_data_error_result(retmsg="Tenant not found!")
+
+        objs = APITokenService.query(tenant_id=tenants[0].tenant_id, dialog_id=request.args["dialog_id"])
+        return get_json_result(data=[o.to_dict() for o in objs])
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/rm', methods=['POST'])
+@validate_request("tokens", "tenant_id")
+@login_required
+def rm():
+    req = request.json
+    try:
+        for token in req["tokens"]:
+            APITokenService.filter_delete(
+                [APIToken.tenant_id == req["tenant_id"], APIToken.token == token])
+        return get_json_result(data=True)
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/stats', methods=['GET'])
+@login_required
+def stats():
+    try:
+        tenants = UserTenantService.query(user_id=current_user.id)
+        if not tenants:
+            return get_data_error_result(retmsg="Tenant not found!")
+        objs = API4ConversationService.stats(
+            tenants[0].tenant_id,
+            request.args.get(
+                "from_date",
+                (datetime.now() -
+                 timedelta(
+                     days=7)).strftime("%Y-%m-%d 24:00:00")),
+            request.args.get(
+                "to_date",
+                datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
+        res = {
+            "pv": [(o["dt"], o["pv"]) for o in objs],
+            "uv": [(o["dt"], o["uv"]) for o in objs],
+            "speed": [(o["dt"], o["tokens"]/o["duration"]) for o in objs],
+            "tokens": [(o["dt"], o["tokens"]/1000.) for o in objs],
+            "round": [(o["dt"], o["round"]) for o in objs],
+            "thumb_up": [(o["dt"], o["thumb_up"]) for o in objs]
+        }
+        return get_json_result(data=res)
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/new_conversation', methods=['POST'])
+@validate_request("user_id")
+def set_conversation():
+    token = request.headers.get('Authorization').split()[1]
+    objs = APIToken.query(token=token)
+    if not objs:
+        return get_json_result(
+            data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR)
+    req = request.json
+    try:
+        e, dia = DialogService.get_by_id(objs[0].dialog_id)
+        if not e:
+            return get_data_error_result(retmsg="Dialog not found")
+        conv = {
+            "id": get_uuid(),
+            "dialog_id": dia.id,
+            "user_id": req["user_id"],
+            "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
+        }
+        API4ConversationService.save(**conv)
+        e, conv = API4ConversationService.get_by_id(conv["id"])
+        if not e:
+            return get_data_error_result(retmsg="Fail to create a new conversation!")
+        conv = conv.to_dict()
+        return get_json_result(data=conv)
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/completion', methods=['POST'])
+@validate_request("conversation_id", "messages")
+def completion():
+    token = request.headers.get('Authorization').split()[1]
+    if not APIToken.query(token=token):
+        return get_json_result(
+            data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR)
+    req = request.json
+    e, conv = API4ConversationService.get_by_id(req["conversation_id"])
+    if not e:
+        return get_data_error_result(retmsg="Conversation not found!")
+
+    msg = []
+    for m in req["messages"]:
+        if m["role"] == "system":
+            continue
+        if m["role"] == "assistant" and not msg:
+            continue
+        msg.append({"role": m["role"], "content": m["content"]})
+
+    try:
+        conv.message.append(msg[-1])
+        e, dia = DialogService.get_by_id(conv.dialog_id)
+        if not e:
+            return get_data_error_result(retmsg="Dialog not found!")
+        del req["conversation_id"]
+        del req["messages"]
+        ans = chat(dia, msg, **req)
+        if not conv.reference:
+            conv.reference = []
+        conv.reference.append(ans["reference"])
+        conv.message.append({"role": "assistant", "content": ans["answer"]})
+        API4ConversationService.append_message(conv.id, conv.to_dict())
+        APITokenService.used(token)
+        return get_json_result(data=ans)
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/conversation/<conversation_id>', methods=['GET'])
+# @login_required
+def get(conversation_id):
+    try:
+        e, conv = API4ConversationService.get_by_id(conversation_id)
+        if not e:
+            return get_data_error_result(retmsg="Conversation not found!")
+
+        return get_json_result(data=conv.to_dict())
+    except Exception as e:
+        return server_error_response(e)
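
As a rough illustration of how these endpoints fit together (not part of the commit; the base URL, cookie handling, and all values shown are assumptions): an authenticated console user mints a token for a dialog via `/api/new_token`, and that token then serves as the Bearer credential for `/api/new_conversation` and `/api/completion`.

```python
import requests

BASE_URL = "https://demo.ragflow.io/v1"            # assumed deployment URL
DIALOG_ID = "4f0a2e4cb9af11ee9ba20aef05f5e94f"     # an existing dialog id (assumed)

# A session carrying the login cookie of a console user; obtaining the
# cookie (web UI login) is outside the scope of this sketch.
session = requests.Session()
session.cookies.set("session", "<console-login-cookie>")

# Mint an API token bound to the dialog.
resp = session.post(f"{BASE_URL}/api/new_token", json={"dialog_id": DIALOG_ID})
token = resp.json()["data"]["token"]               # e.g. "ragflow-..."

# List the dialog's tokens, or inspect its usage statistics.
print(session.get(f"{BASE_URL}/api/token_list", params={"dialog_id": DIALOG_ID}).json())
print(session.get(f"{BASE_URL}/api/stats").json())
```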
api/apps/chunk_app.py CHANGED
@@ -60,7 +60,7 @@ def list():
     for id in sres.ids:
         d = {
             "chunk_id": id,
-            "content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[id].get(
+            "content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[id].get(
                 "content_with_weight", ""),
             "doc_id": sres.field[id]["doc_id"],
             "docnm_kwd": sres.field[id]["docnm_kwd"],
api/apps/conversation_app.py CHANGED
@@ -13,21 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import re
-
 from flask import request
 from flask_login import login_required
-from api.db.services.dialog_service import DialogService, ConversationService
-from api.db import LLMType
-from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.db.services.llm_service import LLMService, LLMBundle, TenantLLMService
-from api.settings import access_logger, stat_logger, retrievaler, chat_logger
+from api.db.services.dialog_service import DialogService, ConversationService, chat
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
-from rag.app.resume import forbidden_select_fields4resume
-from rag.nlp.search import index_name
-from rag.utils import num_tokens_from_string, encoder, rmSpace
 
 
 @manager.route('/set', methods=['POST'])
@@ -110,43 +101,6 @@ def list_convsersation():
     return server_error_response(e)
 
 
-def message_fit_in(msg, max_length=4000):
-    ...  # body removed here; moved verbatim into api/db/services/dialog_service.py (see that diff below)
-
-
 @manager.route('/completion', methods=['POST'])
 @login_required
 @validate_request("conversation_id", "messages")
@@ -179,209 +133,3 @@ def completion():
     except Exception as e:
         return server_error_response(e)
 
-
-def chat(dialog, messages, **kwargs):
-    ...  # body removed here; moved verbatim into api/db/services/dialog_service.py (see that diff below)
-
-
-def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
-    ...  # body removed here; moved verbatim into api/db/services/dialog_service.py (see that diff below)
api/apps/user_app.py CHANGED
@@ -15,7 +15,7 @@
 #
 import re
 
-from flask import request, session, redirect, url_for
+from flask import request, session, redirect
 from werkzeug.security import generate_password_hash, check_password_hash
 from flask_login import login_required, current_user, login_user, logout_user
 
api/db/db_models.py CHANGED
@@ -728,15 +728,6 @@ class Dialog(DataBaseModel):
         db_table = "dialog"
 
 
-# class DialogKb(DataBaseModel):
-#     dialog_id = CharField(max_length=32, null=False, index=True)
-#     kb_id = CharField(max_length=32, null=False)
-#
-#     class Meta:
-#         db_table = "dialog_kb"
-#         primary_key = CompositeKey('dialog_id', 'kb_id')
-
-
 class Conversation(DataBaseModel):
     id = CharField(max_length=32, primary_key=True)
     dialog_id = CharField(max_length=32, null=False, index=True)
@@ -748,13 +739,26 @@ class Conversation(DataBaseModel):
         db_table = "conversation"
 
 
-"""
+class APIToken(DataBaseModel):
+    tenant_id = CharField(max_length=32, null=False)
+    token = CharField(max_length=255, null=False)
+    dialog_id = CharField(max_length=32, null=False, index=True)
+
     class Meta:
-        db_table = 't_pipeline_component_meta'
-        indexes = (
-            (('f_model_id', 'f_model_version', 'f_role', 'f_party_id', 'f_component_name'), True),
-        )
+        db_table = "api_token"
+        primary_key = CompositeKey('tenant_id', 'token')
 
 
-"""
+class API4Conversation(DataBaseModel):
+    id = CharField(max_length=32, primary_key=True)
+    dialog_id = CharField(max_length=32, null=False, index=True)
+    user_id = CharField(max_length=255, null=False, help_text="user_id")
+    message = JSONField(null=True)
+    reference = JSONField(null=True, default=[])
+    tokens = IntegerField(default=0)
+    duration = FloatField(default=0)
+    round = IntegerField(default=0)
+    thumb_up = IntegerField(default=0)
+
+    class Meta:
+        db_table = "api_4_conversation"
api/db/services/api_service.py ADDED
@@ -0,0 +1,66 @@
+#
+#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+from datetime import datetime
+import peewee
+from api.db.db_models import DB, API4Conversation, APIToken, Dialog
+from api.db.services.common_service import CommonService
+from api.utils import current_timestamp, datetime_format
+
+
+class APITokenService(CommonService):
+    model = APIToken
+
+    @classmethod
+    @DB.connection_context()
+    def used(cls, token):
+        return cls.model.update({
+            "update_time": current_timestamp(),
+            "update_date": datetime_format(datetime.now()),
+        }).where(
+            cls.model.token == token
+        ).execute()
+
+
+class API4ConversationService(CommonService):
+    model = API4Conversation
+
+    @classmethod
+    @DB.connection_context()
+    def append_message(cls, id, conversation):
+        cls.update_by_id(id, conversation)
+        return cls.model.update(round=cls.model.round + 1).where(cls.model.id == id).execute()
+
+    @classmethod
+    @DB.connection_context()
+    def stats(cls, tenant_id, from_date, to_date):
+        return cls.model.select(
+            cls.model.create_date.truncate("day").alias("dt"),
+            peewee.fn.COUNT(
+                cls.model.id).alias("pv"),
+            peewee.fn.COUNT(
+                cls.model.user_id.distinct()).alias("uv"),
+            peewee.fn.SUM(
+                cls.model.tokens).alias("tokens"),
+            peewee.fn.SUM(
+                cls.model.duration).alias("duration"),
+            peewee.fn.AVG(
+                cls.model.round).alias("round"),
+            peewee.fn.SUM(
+                cls.model.thumb_up).alias("thumb_up")
+        ).join(Dialog, on=((cls.model.dialog_id == Dialog.id) & (Dialog.tenant_id == tenant_id))).where(
+            cls.model.create_date >= from_date,
+            cls.model.create_date <= to_date
+        ).group_by(cls.model.create_date.truncate("day")).dicts()
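
The `stats` query yields one aggregated row per day, which `/api/stats` then fans out into per-metric time series. A self-contained sketch of the same grouped-aggregate pattern in peewee, using a toy SQLite model and `fn.DATE()` standing in for `create_date.truncate("day")` (all names here are illustrative, not the project's):

```python
import datetime
import peewee

db = peewee.SqliteDatabase(":memory:")

class Visit(peewee.Model):
    user_id = peewee.CharField()
    tokens = peewee.IntegerField(default=0)
    create_date = peewee.DateTimeField(default=datetime.datetime.now)

    class Meta:
        database = db

db.create_tables([Visit])
Visit.create(user_id="u1", tokens=120)
Visit.create(user_id="u2", tokens=80)

day = peewee.fn.DATE(Visit.create_date)
rows = (Visit
        .select(day.alias("dt"),
                peewee.fn.COUNT(Visit.id).alias("pv"),                  # page views
                peewee.fn.COUNT(Visit.user_id.distinct()).alias("uv"),  # unique users
                peewee.fn.SUM(Visit.tokens).alias("tokens"))
        .group_by(day)
        .dicts())
print(list(rows))  # e.g. [{'dt': '2024-04-12', 'pv': 2, 'uv': 2, 'tokens': 200}]
```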
api/db/services/dialog_service.py CHANGED
@@ -13,8 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import re
+
+from api.db import LLMType
 from api.db.db_models import Dialog, Conversation
 from api.db.services.common_service import CommonService
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
+from api.settings import chat_logger, retrievaler
+from rag.app.resume import forbidden_select_fields4resume
+from rag.nlp.search import index_name
+from rag.utils import rmSpace, num_tokens_from_string, encoder
 
 
 class DialogService(CommonService):
@@ -23,3 +32,247 @@ class DialogService(CommonService):
 
 class ConversationService(CommonService):
     model = Conversation
+
+
+def message_fit_in(msg, max_length=4000):
+    def count():
+        nonlocal msg
+        tks_cnts = []
+        for m in msg:
+            tks_cnts.append(
+                {"role": m["role"], "count": num_tokens_from_string(m["content"])})
+        total = 0
+        for m in tks_cnts:
+            total += m["count"]
+        return total
+
+    c = count()
+    if c < max_length:
+        return c, msg
+
+    msg_ = [m for m in msg[:-1] if m["role"] == "system"]
+    msg_.append(msg[-1])
+    msg = msg_
+    c = count()
+    if c < max_length:
+        return c, msg
+
+    ll = num_tokens_from_string(msg_[0]["content"])
+    l = num_tokens_from_string(msg_[-1]["content"])
+    if ll / (ll + l) > 0.8:
+        m = msg_[0]["content"]
+        m = encoder.decode(encoder.encode(m)[:max_length - l])
+        msg[0]["content"] = m
+        return max_length, msg
+
+    m = msg_[1]["content"]
+    m = encoder.decode(encoder.encode(m)[:max_length - l])
+    msg[1]["content"] = m
+    return max_length, msg
+
+
+def chat(dialog, messages, **kwargs):
+    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
+    llm = LLMService.query(llm_name=dialog.llm_id)
+    if not llm:
+        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id)
+        if not llm:
+            raise LookupError("LLM(%s) not found" % dialog.llm_id)
+        max_tokens = 1024
+    else: max_tokens = llm[0].max_tokens
+    questions = [m["content"] for m in messages if m["role"] == "user"]
+    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
+    chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
+
+    prompt_config = dialog.prompt_config
+    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
+    # try to use sql if field mapping is good to go
+    if field_map:
+        chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
+        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
+        if ans: return ans
+
+    for p in prompt_config["parameters"]:
+        if p["key"] == "knowledge":
+            continue
+        if p["key"] not in kwargs and not p["optional"]:
+            raise KeyError("Miss parameter: " + p["key"])
+        if p["key"] not in kwargs:
+            prompt_config["system"] = prompt_config["system"].replace(
+                "{%s}" % p["key"], " ")
+
+    for _ in range(len(questions) // 2):
+        questions.append(questions[-1])
+    if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
+        kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
+    else:
+        kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
+                                        dialog.similarity_threshold,
+                                        dialog.vector_similarity_weight, top=1024, aggs=False)
+    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
+    chat_logger.info(
+        "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
+
+    if not knowledges and prompt_config.get("empty_response"):
+        return {
+            "answer": prompt_config["empty_response"], "reference": kbinfos}
+
+    kwargs["knowledge"] = "\n".join(knowledges)
+    gen_conf = dialog.llm_setting
+    msg = [{"role": m["role"], "content": m["content"]}
+           for m in messages if m["role"] != "system"]
+    used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
+    if "max_tokens" in gen_conf:
+        gen_conf["max_tokens"] = min(
+            gen_conf["max_tokens"],
+            max_tokens - used_token_count)
+    answer = chat_mdl.chat(
+        prompt_config["system"].format(
+            **kwargs), msg, gen_conf)
+    chat_logger.info("User: {}|Assistant: {}".format(
+        msg[-1]["content"], answer))
+
+    if knowledges and prompt_config.get("quote", True):
+        answer, idx = retrievaler.insert_citations(answer,
+                                                   [ck["content_ltks"]
+                                                    for ck in kbinfos["chunks"]],
+                                                   [ck["vector"]
+                                                    for ck in kbinfos["chunks"]],
+                                                   embd_mdl,
+                                                   tkweight=1 - dialog.vector_similarity_weight,
+                                                   vtweight=dialog.vector_similarity_weight)
+        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
+        recall_docs = [
+            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
+        if not recall_docs: recall_docs = kbinfos["doc_aggs"]
+        kbinfos["doc_aggs"] = recall_docs
+
+    for c in kbinfos["chunks"]:
+        if c.get("vector"):
+            del c["vector"]
+    if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
+        answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
+    return {"answer": answer, "reference": kbinfos}
+
+
+def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
+    # Prompts below are in Chinese: "You are a DBA. Given the field structure of the
+    # tables below and the user's question list, write the SQL for the last question."
+    sys_prompt = "你是一个DBA。你需要针对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。"
+    user_prompt = """
+表名:{};
+数据库表字段说明如下:
+{}
+
+问题如下:
+{}
+请写出SQL, 且只要SQL,不要有其他说明及文字。
+""".format(
+        index_name(tenant_id),
+        "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
+        question
+    )
+    tried_times = 0
+
+    def get_table():
+        nonlocal sys_prompt, user_prompt, question, tried_times
+        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
+            "temperature": 0.06})
+        print(user_prompt, sql)
+        chat_logger.info(f"“{question}”==>{user_prompt} get SQL: {sql}")
+        sql = re.sub(r"[\r\n]+", " ", sql.lower())
+        sql = re.sub(r".*select ", "select ", sql.lower())
+        sql = re.sub(r" +", " ", sql)
+        sql = re.sub(r"([;;]|```).*", "", sql)
+        if sql[:len("select ")] != "select ":
+            return None, None
+        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
+            if sql[:len("select *")] != "select *":
+                sql = "select doc_id,docnm_kwd," + sql[6:]
+            else:
+                flds = []
+                for k in field_map.keys():
+                    if k in forbidden_select_fields4resume:
+                        continue
+                    if len(flds) > 11:
+                        break
+                    flds.append(k)
+                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
+
+        print(f"“{question}” get SQL(refined): {sql}")
+
+        chat_logger.info(f"“{question}” get SQL(refined): {sql}")
+        tried_times += 1
+        return retrievaler.sql_retrieval(sql, format="json"), sql
+
+    tbl, sql = get_table()
+    if tbl is None:
+        return None
+    if tbl.get("error") and tried_times <= 2:
+        # Retry prompt (Chinese): same schema and question, plus the previous
+        # wrong SQL and the backend error; asks for corrected SQL only.
+        user_prompt = """
+表名:{};
+数据库表字段说明如下:
+{}
+
+问题如下:
+{}
+
+你上一次给出的错误SQL如下:
+{}
+
+后台报错如下:
+{}
+
+请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
+""".format(
+            index_name(tenant_id),
+            "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
+            question, sql, tbl["error"]
+        )
+        tbl, sql = get_table()
+        chat_logger.info("TRY it again: {}".format(sql))
+
+    chat_logger.info("GET table: {}".format(tbl))
+    print(tbl)
+    if tbl.get("error") or len(tbl["rows"]) == 0:
+        return None
+
+    docid_idx = set([ii for ii, c in enumerate(
+        tbl["columns"]) if c["name"] == "doc_id"])
+    docnm_idx = set([ii for ii, c in enumerate(
+        tbl["columns"]) if c["name"] == "docnm_kwd"])
+    clmn_idx = [ii for ii in range(
+        len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
+
+    # compose markdown table
+    clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
+                                                                          tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docnm_idx else "|")
+
+    line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
+           ("|------|" if docid_idx and docnm_idx else "")
+
+    rows = ["|" +
+            "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
+            "|" for r in tbl["rows"]]
+    if quota:
+        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
+    else:
+        rows = "\n".join(rows)
+    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
+
+    if not docid_idx or not docnm_idx:
+        chat_logger.warning("SQL missing field: " + sql)
+        return {
+            "answer": "\n".join([clmns, line, rows]),
+            "reference": {"chunks": [], "doc_aggs": []}
+        }
+
+    docid_idx = list(docid_idx)[0]
+    docnm_idx = list(docnm_idx)[0]
+    doc_aggs = {}
+    for r in tbl["rows"]:
+        if r[docid_idx] not in doc_aggs:
+            doc_aggs[r[docid_idx]] = {"doc_name": r[docnm_idx], "count": 0}
+        doc_aggs[r[docid_idx]]["count"] += 1
+    return {
+        "answer": "\n".join([clmns, line, rows]),
+        "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
+                      "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
+    }
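
To make the trimming strategy of `message_fit_in` concrete, here is a self-contained toy run in which character counts stand in for the real `num_tokens_from_string`/`encoder` pair (an assumption for illustration only; the real code counts and truncates BPE tokens):

```python
def toy_fit_in(msg, max_length=40):
    """Same strategy as message_fit_in, with len() as the 'token' count."""
    count = lambda ms: sum(len(m["content"]) for m in ms)
    if count(msg) < max_length:
        return count(msg), msg
    # Drop everything except system messages and the final (user) message.
    msg = [m for m in msg[:-1] if m["role"] == "system"] + [msg[-1]]
    if count(msg) < max_length:
        return count(msg), msg
    # Truncate the system prompt when it dominates the budget, else the other side.
    ll, l = len(msg[0]["content"]), len(msg[-1]["content"])
    i = 0 if ll / (ll + l) > 0.8 else 1
    msg[i]["content"] = msg[i]["content"][:max_length - l]
    return max_length, msg

history = [
    {"role": "system", "content": "You answer from the knowledge base. " * 5},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "What's the ViT score for GPT-4?"},
]
print(toy_fit_in(history))  # long system prompt is cut back; the latest question survives
```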
api/db/services/document_service.py CHANGED
@@ -15,7 +15,7 @@
 #
 from peewee import Expression
 
-from api.db import TenantPermission, FileType, TaskStatus
+from api.db import FileType, TaskStatus
 from api.db.db_models import DB, Knowledgebase, Tenant
 from api.db.db_models import Document
 from api.db.services.common_service import CommonService
docs/conversation_api.md ADDED
@@ -0,0 +1,303 @@
+# Conversation API Instruction
+
+## Base URL
+```
+https://demo.ragflow.io/v1/
+```
+
+## Authorization
+
+All of the APIs are authorized with an API key. Please keep it safe and private, and don't reveal it in any way from the front end.
+The API key should be put in the request header:
+```
+Authorization: Bearer {API_KEY}
+```
+
+## Start a conversation
+
+This should be called whenever a new user comes to chat.
+### Path: /api/new_conversation
+### Method: POST
+### Parameter:
+
+| name | type | optional | description|
+|------|-------|----|----|
+| user_id| string | No | It's for identifying the user in order to search and calculate statistics.|
+
+### Response
+```json
+{
+    "data": {
+        "create_date": "Fri, 12 Apr 2024 17:26:21 GMT",
+        "create_time": 1712913981857,
+        "dialog_id": "4f0a2e4cb9af11ee9ba20aef05f5e94f",
+        "duration": 0.0,
+        "id": "b9b2e098f8ae11ee9f45fa163e197198",
+        "message": [
+            {
+                "content": "Hi, I'm your assistant, can I help you?",
+                "role": "assistant"
+            }
+        ],
+        "reference": [],
+        "tokens": 0,
+        "update_date": "Fri, 12 Apr 2024 17:26:21 GMT",
+        "update_time": 1712913981857,
+        "user_id": "kevinhu"
+    },
+    "retcode": 0,
+    "retmsg": "success"
+}
+```
+> data['id'] in the response should be stored; it will be used in every following round of the conversation.
+
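+For illustration only (not part of this commit), a minimal Python client for this step might look like the following; the `requests` library, the base URL, and the placeholder token are assumptions:

```python
import requests

BASE_URL = "https://demo.ragflow.io/v1"                 # assumed deployment URL
API_KEY = "ragflow-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"    # token from /api/new_token
headers = {"Authorization": f"Bearer {API_KEY}"}

# Start a conversation for a given end user; the opening assistant
# message comes from the dialog's configured prologue.
resp = requests.post(f"{BASE_URL}/api/new_conversation",
                     headers=headers,
                     json={"user_id": "kevinhu"})
resp.raise_for_status()
conversation_id = resp.json()["data"]["id"]  # keep this for later rounds
print(conversation_id)
```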
54
+ ## Get history of a conversation
55
+
56
+ ### Path: /api/conversation/\<id\>
57
+ ### Method: GET
58
+ ### Response
59
+ ```json
60
+ {
61
+ "data": {
62
+ "create_date": "Mon, 01 Apr 2024 09:28:42 GMT",
63
+ "create_time": 1711934922220,
64
+ "dialog_id": "df4a4916d7bd11eeaa650242ac180006",
65
+ "id": "2cae30fcefc711ee94140242ac180006",
66
+ "message": [
67
+ {
68
+ "content": "Hi! I'm your assistant, what can I do for you?",
69
+ "role": "assistant"
70
+ },
71
+ {
72
+ "content": "What's the vit score for GPT-4?",
73
+ "role": "user"
74
+ },
75
+ {
76
+ "content": "The ViT Score for GPT-4 in the zero-shot scenario is 0.5058, and in the few-shot scenario, it is 0.6480. ##0$$",
77
+ "role": "assistant"
78
+ },
79
+ {
80
+ "content": "How is the nvlink topology like?",
81
+ "role": "user"
82
+ },
83
+ {
84
+ "content": "NVLink topology refers to the arrangement of connections between GPUs using NVIDIA's NVLink technology. Correct NVLink topology for NVIDIA A100 cards involves connecting one GPU to another through a series of NVLink bridges ##0$$. Each of the three attached bridges spans two PCIe slots, and for optimal performance and balanced bridge topology, all three NVLink bridges should be used when connecting two adjacent A100 cards.\n\nHere's a summary of the correct and incorrect topologies:\n\n- **Correct**: Both GPUs are connected via all three NVLink bridges, ensuring full bandwidth and proper communication.\n- **Incorrect**: Not using all three bridges or having an uneven connection configuration would result in suboptimal performance.\n\nIt's also important to note that for multi-CPU systems, both A100 cards in a bridged pair should be within the same CPU domain, unless each CPU has a single A100 PCIe card, in which case they can be bridged together.",
85
+ "role": "assistant"
86
+ }
87
+ ],
88
+ "user_id": "user name",
89
+ "reference": [
90
+ {
91
+ "chunks": [
92
+ {
93
+ "chunk_id": "d0bc7892c3ec4aeac071544fd56730a8",
94
+ "content_ltks": "tabl 1:openagi task-solv perform under differ set for three closed-sourc llm . boldfac denot the highest score under each learn schema . metric gpt-3.5-turbo claude-2 gpt-4 zero few zero few zero few clip score 0.0 0.0 0.0 0.2543 0.0 0.3055 bert score 0.1914 0.3820 0.2111 0.5038 0.2076 0.6307 vit score 0.2437 0.7497 0.4082 0.5416 0.5058 0.6480 overal 0.1450 0.3772 0.2064 0.4332 0.2378 0.5281",
95
+ "content_with_weight": "<table><caption>Table 1: OpenAGI task-solving performances under different settings for three closed-source LLMs. Boldface denotes the highest score under each learning schema.</caption>\n<tr><th rowspan=2 >Metrics</th><th >GPT-3.5-turbo</th><th></th><th >Claude-2</th><th >GPT-4</th></tr>\n<tr><th >Zero</th><th >Few</th><th >Zero Few</th><th >Zero Few</th></tr>\n<tr><td >CLIP Score</td><td >0.0</td><td >0.0</td><td >0.0 0.2543</td><td >0.0 0.3055</td></tr>\n<tr><td >BERT Score</td><td >0.1914</td><td >0.3820</td><td >0.2111 0.5038</td><td >0.2076 0.6307</td></tr>\n<tr><td >ViT Score</td><td >0.2437</td><td >0.7497</td><td >0.4082 0.5416</td><td >0.5058 0.6480</td></tr>\n<tr><td >Overall</td><td >0.1450</td><td >0.3772</td><td >0.2064 0.4332</td><td >0.2378 0.5281</td></tr>\n</table>",
96
+ "doc_id": "c790da40ea8911ee928e0242ac180005",
97
+ "docnm_kwd": "OpenAGI When LLM Meets Domain Experts.pdf",
98
+ "img_id": "afab9fdad6e511eebdb20242ac180006-d0bc7892c3ec4aeac071544fd56730a8",
99
+ "important_kwd": [],
100
+ "kb_id": "afab9fdad6e511eebdb20242ac180006",
101
+ "positions": [
102
+ [
103
+ 9.0,
104
+ 159.9383341471354,
105
+ 472.1773274739583,
106
+ 223.58013916015625,
107
+ 307.86692301432294
108
+ ]
109
+ ],
110
+ "similarity": 0.7310340654129031,
111
+ "term_similarity": 0.7671974387781668,
112
+ "vector_similarity": 0.40556370512552886
113
+ },
114
+ {
115
+ "chunk_id": "7e2345d440383b756670e1b0f43a7007",
116
+ "content_ltks": "5.5 experiment analysi the main experiment result are tabul in tab . 1 and 2 , showcas the result for closed-sourc and open-sourc llm , respect . the overal perform is calcul a the averag of cllp 8 bert and vit score . here , onli the task descript of the benchmark task are fed into llm(addit inform , such a the input prompt and llm\u2019output , is provid in fig . a.4 and a.5 in supplementari). broadli speak , closed-sourc llm demonstr superior perform on openagi task , with gpt-4 lead the pack under both zero-and few-shot scenario . in the open-sourc categori , llama-2-13b take the lead , consist post top result across variou learn schema--the perform possibl influenc by it larger model size . notabl , open-sourc llm significantli benefit from the tune method , particularli fine-tun and\u2019rltf . these method mark notic enhanc for flan-t5-larg , vicuna-7b , and llama-2-13b when compar with zero-shot and few-shot learn schema . in fact , each of these open-sourc model hit it pinnacl under the rltf approach . conclus , with rltf tune , the perform of llama-2-13b approach that of gpt-3.5 , illustr it potenti .",
117
+ "content_with_weight": "5.5 Experimental Analysis\nThe main experimental results are tabulated in Tab. 1 and 2, showcasing the results for closed-source and open-source LLMs, respectively. The overall performance is calculated as the average of CLlP\n8\nBERT and ViT scores. Here, only the task descriptions of the benchmark tasks are fed into LLMs (additional information, such as the input prompt and LLMs\u2019 outputs, is provided in Fig. A.4 and A.5 in supplementary). Broadly speaking, closed-source LLMs demonstrate superior performance on OpenAGI tasks, with GPT-4 leading the pack under both zero- and few-shot scenarios. In the open-source category, LLaMA-2-13B takes the lead, consistently posting top results across various learning schema--the performance possibly influenced by its larger model size. Notably, open-source LLMs significantly benefit from the tuning methods, particularly Fine-tuning and\u2019 RLTF. These methods mark noticeable enhancements for Flan-T5-Large, Vicuna-7B, and LLaMA-2-13B when compared with zero-shot and few-shot learning schema. In fact, each of these open-source models hits its pinnacle under the RLTF approach. Conclusively, with RLTF tuning, the performance of LLaMA-2-13B approaches that of GPT-3.5, illustrating its potential.",
118
+ "doc_id": "c790da40ea8911ee928e0242ac180005",
119
+ "docnm_kwd": "OpenAGI When LLM Meets Domain Experts.pdf",
120
+ "img_id": "afab9fdad6e511eebdb20242ac180006-7e2345d440383b756670e1b0f43a7007",
121
+ "important_kwd": [],
122
+ "kb_id": "afab9fdad6e511eebdb20242ac180006",
123
+ "positions": [
124
+ [
125
+ 8.0,
126
+ 107.3,
127
+ 508.90000000000003,
128
+ 686.3,
129
+ 697.0
130
+ ],
131
+ ],
132
+ "similarity": 0.6691508616357027,
133
+ "term_similarity": 0.6999011754270821,
134
+ "vector_similarity": 0.39239803751328806
135
+ },
136
+ ],
137
+ "doc_aggs": [
138
+ {
139
+ "count": 8,
140
+ "doc_id": "c790da40ea8911ee928e0242ac180005",
141
+ "doc_name": "OpenAGI When LLM Meets Domain Experts.pdf"
142
+ }
143
+ ],
144
+ "total": 8
145
+ },
146
+ {
147
+ "chunks": [
148
+ {
149
+ "chunk_id": "8c11a1edddb21ad2ae0c43b4a5dcfa62",
150
+ "content_ltks": "nvlink bridg support nvidia\u00aenvlink\u00aei a high-spe point-to-point peer transfer connect , where one gpu can transfer data to and receiv data from one other gpu . the nvidia a100 card support nvlink bridg connect with a singl adjac a100 card . each of the three attach bridg span two pcie slot . to function correctli a well a to provid peak bridg bandwidth , bridg connect with an adjac a100 card must incorpor all three nvlink bridg . wherev an adjac pair of a100 card exist in the server , for best bridg perform and balanc bridg topolog , the a100 pair should be bridg . figur 4 illustr correct and incorrect a100 nvlink connect topolog . nvlink topolog\u2013top view figur 4. correct incorrect correct incorrect for system that featur multipl cpu , both a100 card of a bridg card pair should be within the same cpu domain\u2014that is , under the same cpu\u2019s topolog . ensur thi benefit workload applic perform . the onli except is for dual cpu system wherein each cpu ha a singl a100 pcie card under it;in that case , the two a100 pcie card in the system may be bridg togeth . a100 nvlink speed and bandwidth are given in the follow tabl . tabl 5. a100 nvlink speed and bandwidth paramet valu total nvlink bridg support by nvidia a100 3 total nvlink rx and tx lane support 96 data rate per nvidia a100 nvlink lane(each direct)50 gbp total maximum nvlink bandwidth 600 gbyte per second pb-10137-001_v03|8 nvidia a100 40gb pcie gpu acceler",
151
+ "content_with_weight": "NVLink Bridge Support\nNVIDIA\u00aeNVLink\u00aeis a high-speed point-to-point peer transfer connection, where one GPU can transfer data to and receive data from one other GPU. The NVIDIA A100 card supports NVLink bridge connection with a single adjacent A100 card.\nEach of the three attached bridges spans two PCIe slots. To function correctly as well as to provide peak bridge bandwidth, bridge connection with an adjacent A100 card must incorporate all three NVLink bridges. Wherever an adjacent pair of A100 cards exists in the server, for best bridging performance and balanced bridge topology, the A100 pair should be bridged. Figure 4 illustrates correct and incorrect A100 NVLink connection topologies.\nNVLink Topology \u2013Top Views \nFigure 4. \nCORRECT \nINCORRECT \nCORRECT \nINCORRECT \nFor systems that feature multiple CPUs, both A100 cards of a bridged card pair should be within the same CPU domain\u2014that is, under the same CPU\u2019s topology. Ensuring this benefits workload application performance. The only exception is for dual CPU systems wherein each CPU has a single A100 PCIe card under it; in that case, the two A100 PCIe cards in the system may be bridged together.\nA100 NVLink speed and bandwidth are given in the following table.\n<table><caption>Table 5. A100 NVLink Speed and Bandwidth </caption>\n<tr><th >Parameter </th><th >Value </th></tr>\n<tr><td >Total NVLink bridges supported by NVIDIA A100 </td><td >3 </td></tr>\n<tr><td >Total NVLink Rx and Tx lanes supported </td><td >96 </td></tr>\n<tr><td >Data rate per NVIDIA A100 NVLink lane (each direction)</td><td >50 Gbps </td></tr>\n<tr><td >Total maximum NVLink bandwidth</td><td >600 Gbytes per second </td></tr>\n</table>\nPB-10137-001_v03 |8\nNVIDIA A100 40GB PCIe GPU Accelerator",
152
+ "doc_id": "806d1ed0ea9311ee860a0242ac180005",
153
+ "docnm_kwd": "A100-PCIE-Prduct-Brief.pdf",
154
+ "img_id": "afab9fdad6e511eebdb20242ac180006-8c11a1edddb21ad2ae0c43b4a5dcfa62",
155
+ "important_kwd": [],
156
+ "kb_id": "afab9fdad6e511eebdb20242ac180006",
157
+ "positions": [
158
+ [
159
+ 12.0,
160
+ 84.0,
161
+ 541.3,
162
+ 76.7,
163
+ 96.7
164
+ ],
165
+ ],
166
+ "similarity": 0.3200748779905588,
167
+ "term_similarity": 0.3082244010114718,
168
+ "vector_similarity": 0.42672917080234146
169
+ },
170
+ ],
171
+ "doc_aggs": [
172
+ {
173
+ "count": 1,
174
+ "doc_id": "806d1ed0ea9311ee860a0242ac180005",
175
+ "doc_name": "A100-PCIE-Prduct-Brief.pdf"
176
+ }
177
+ ],
178
+ "total": 3
179
+ }
180
+ ],
181
+ "update_date": "Tue, 02 Apr 2024 09:07:49 GMT",
182
+ "update_time": 1712020069421
183
+ },
184
+ "retcode": 0,
185
+ "retmsg": "success"
186
+ }
187
+ ```
188
+
189
+ - **message**: All the chat history in it.
190
+ - role: user or assistant
191
+ - content: the text content of user or assistant. The citations are in format like: ##0$$. The number in the middle indicate which part in data.reference.chunks it refers to.
192
+
193
+ - **user_id**: This is set by the caller.
194
+ - **reference**: Every item in it refer to the corresponding message in data.message whose role is assistant.
195
+ - chunks
196
+ - content_with_weight: The content of chunk.
197
+ - docnm_kwd: the document name.
198
+ - img_id: the image id of the chunk. It is an optional field only for PDF/pptx/picture. And accessed by 'GET' /document/get/\<id\>.
199
+ - positions: [page_number, [upleft corner(x, y)], [right bottom(x, y)]], the chunk position, only for PDF.
200
+ - similarity: the hybrid similarity.
201
+ - term_similarity: keyword simimlarity
202
+ - vector_similarity: embedding similarity
203
+ - doc_aggs:
204
+ - doc_id: the document can be accessed by 'GET' /document/get/\<id\>
205
+ - doc_name: the file name
206
+ - count: the chunk number hit in this document.
207
+
208
+ ## Chat
209
+
210
+ This will be called to get the answer to users' questions.
211
+
212
+ ### Path: /api/completion
213
+ ### Method: POST
214
+ ### Parameter:
215
+
216
+ | name | type | optional | description|
217
+ |------|-------|----|----|
218
+ | conversation_id| string | No | This is from calling /new_conversation.|
219
+ | messages| json | No | All the conversation history stored here including the latest user's question.|
220
+
221
+ ### Response
222
+ ```json
223
+ {
224
+ "data": {
225
+ "answer": "The ViT Score for GPT-4 in the zero-shot scenario is 0.5058, and in the few-shot scenario, it is 0.6480. ##0$$",
226
+ "reference": {
227
+ "chunks": [
228
+ {
229
+ "chunk_id": "d0bc7892c3ec4aeac071544fd56730a8",
230
+ "content_ltks": "tabl 1:openagi task-solv perform under differ set for three closed-sourc llm . boldfac denot the highest score under each learn schema . metric gpt-3.5-turbo claude-2 gpt-4 zero few zero few zero few clip score 0.0 0.0 0.0 0.2543 0.0 0.3055 bert score 0.1914 0.3820 0.2111 0.5038 0.2076 0.6307 vit score 0.2437 0.7497 0.4082 0.5416 0.5058 0.6480 overal 0.1450 0.3772 0.2064 0.4332 0.2378 0.5281",
231
+ "content_with_weight": "<table><caption>Table 1: OpenAGI task-solving performances under different settings for three closed-source LLMs. Boldface denotes the highest score under each learning schema.</caption>\n<tr><th rowspan=2 >Metrics</th><th >GPT-3.5-turbo</th><th></th><th >Claude-2</th><th >GPT-4</th></tr>\n<tr><th >Zero</th><th >Few</th><th >Zero Few</th><th >Zero Few</th></tr>\n<tr><td >CLIP Score</td><td >0.0</td><td >0.0</td><td >0.0 0.2543</td><td >0.0 0.3055</td></tr>\n<tr><td >BERT Score</td><td >0.1914</td><td >0.3820</td><td >0.2111 0.5038</td><td >0.2076 0.6307</td></tr>\n<tr><td >ViT Score</td><td >0.2437</td><td >0.7497</td><td >0.4082 0.5416</td><td >0.5058 0.6480</td></tr>\n<tr><td >Overall</td><td >0.1450</td><td >0.3772</td><td >0.2064 0.4332</td><td >0.2378 0.5281</td></tr>\n</table>",
232
+ "doc_id": "c790da40ea8911ee928e0242ac180005",
233
+ "docnm_kwd": "OpenAGI When LLM Meets Domain Experts.pdf",
234
+ "img_id": "afab9fdad6e511eebdb20242ac180006-d0bc7892c3ec4aeac071544fd56730a8",
235
+ "important_kwd": [],
236
+ "kb_id": "afab9fdad6e511eebdb20242ac180006",
237
+ "positions": [
238
+ [
239
+ 9.0,
240
+ 159.9383341471354,
241
+ 472.1773274739583,
242
+ 223.58013916015625,
243
+ 307.86692301432294
244
+ ]
245
+ ],
246
+ "similarity": 0.7310340654129031,
247
+ "term_similarity": 0.7671974387781668,
248
+ "vector_similarity": 0.40556370512552886
249
+ },
250
+ {
251
+ "chunk_id": "7e2345d440383b756670e1b0f43a7007",
252
+ "content_ltks": "5.5 experiment analysi the main experiment result are tabul in tab . 1 and 2 , showcas the result for closed-sourc and open-sourc llm , respect . the overal perform is calcul a the averag of cllp 8 bert and vit score . here , onli the task descript of the benchmark task are fed into llm(addit inform , such a the input prompt and llm\u2019output , is provid in fig . a.4 and a.5 in supplementari). broadli speak , closed-sourc llm demonstr superior perform on openagi task , with gpt-4 lead the pack under both zero-and few-shot scenario . in the open-sourc categori , llama-2-13b take the lead , consist post top result across variou learn schema--the perform possibl influenc by it larger model size . notabl , open-sourc llm significantli benefit from the tune method , particularli fine-tun and\u2019rltf . these method mark notic enhanc for flan-t5-larg , vicuna-7b , and llama-2-13b when compar with zero-shot and few-shot learn schema . in fact , each of these open-sourc model hit it pinnacl under the rltf approach . conclus , with rltf tune , the perform of llama-2-13b approach that of gpt-3.5 , illustr it potenti .",
253
+ "content_with_weight": "5.5 Experimental Analysis\nThe main experimental results are tabulated in Tab. 1 and 2, showcasing the results for closed-source and open-source LLMs, respectively. The overall performance is calculated as the average of CLlP\n8\nBERT and ViT scores. Here, only the task descriptions of the benchmark tasks are fed into LLMs (additional information, such as the input prompt and LLMs\u2019 outputs, is provided in Fig. A.4 and A.5 in supplementary). Broadly speaking, closed-source LLMs demonstrate superior performance on OpenAGI tasks, with GPT-4 leading the pack under both zero- and few-shot scenarios. In the open-source category, LLaMA-2-13B takes the lead, consistently posting top results across various learning schema--the performance possibly influenced by its larger model size. Notably, open-source LLMs significantly benefit from the tuning methods, particularly Fine-tuning and\u2019 RLTF. These methods mark noticeable enhancements for Flan-T5-Large, Vicuna-7B, and LLaMA-2-13B when compared with zero-shot and few-shot learning schema. In fact, each of these open-source models hits its pinnacle under the RLTF approach. Conclusively, with RLTF tuning, the performance of LLaMA-2-13B approaches that of GPT-3.5, illustrating its potential.",
254
+ "doc_id": "c790da40ea8911ee928e0242ac180005",
255
+ "docnm_kwd": "OpenAGI When LLM Meets Domain Experts.pdf",
256
+ "img_id": "afab9fdad6e511eebdb20242ac180006-7e2345d440383b756670e1b0f43a7007",
257
+ "important_kwd": [],
258
+ "kb_id": "afab9fdad6e511eebdb20242ac180006",
259
+ "positions": [
260
+ [
261
+ 8.0,
262
+ 107.3,
263
+ 508.90000000000003,
264
+ 686.3,
265
+ 697.0
266
+ ]
267
+ ],
268
+ "similarity": 0.6691508616357027,
269
+ "term_similarity": 0.6999011754270821,
270
+ "vector_similarity": 0.39239803751328806
271
+ }
272
+ ],
273
+ "doc_aggs": {
274
+ "OpenAGI When LLM Meets Domain Experts.pdf": 4
275
+ },
276
+ "total": 8
277
+ }
278
+ },
279
+ "retcode": 0,
280
+ "retmsg": "success"
281
+ }
282
+ ```
283
+
284
+ - **answer**: The reply from the chat bot.
285
+ - **reference**:
286
+     - chunks: Each item corresponds to one of the citation markers (e.g. ##0$$) in the answer.
287
+         - content_with_weight: The content of the chunk.
288
+         - docnm_kwd: The document name.
289
+         - img_id: The image ID of the chunk. An optional field, present only for PDF, PPTX and picture files; the image can be fetched via 'GET' /document/get/\<id\>.
290
+         - positions: The chunk's position, as [page_number, [upper-left corner (x, y)], [bottom-right corner (x, y)]]; only for PDF.
291
+         - similarity: The hybrid similarity score.
292
+         - term_similarity: The keyword (term-based) similarity.
293
+         - vector_similarity: The embedding (vector-based) similarity.
294
+     - doc_aggs:
295
+         - doc_id: The document ID; the document can be fetched via 'GET' /document/get/\<id\>.
296
+         - doc_name: The file name.
297
+         - count: The number of chunks hit in this document.
298
+
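+ As a usage illustration, the sketch below posts a question and resolves each ##N$$ citation marker in the answer back to the document it came from. This is only a hedged example: the base URL, the Authorization header, and the /api/completion path are assumptions about the deployment, not guarantees of this document.
+
+ ```python
+ import re
+ import requests
+
+ BASE_URL = "http://127.0.0.1:9380/api"            # assumption: your server address
+ HEADERS = {"Authorization": "Bearer <API_KEY>"}   # assumption: your API token
+
+ def ask(conversation_id, question):
+     body = {
+         "conversation_id": conversation_id,
+         "messages": [{"role": "user", "content": question}],
+     }
+     r = requests.post(f"{BASE_URL}/completion", headers=HEADERS, json=body)
+     data = r.json()["data"]
+     answer, chunks = data["answer"], data["reference"]["chunks"]
+
+     def cite(m):
+         # ##N$$ cites chunks[N]; replace the marker with the source file name.
+         i = int(m.group(1))
+         return f" [{chunks[i]['docnm_kwd']}]" if i < len(chunks) else ""
+
+     return re.sub(r"##(\d+)\$\$", cite, answer)
+ ```
+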
299
+ ## Get document content or image
300
+
301
+ This is typically used to display the content of a citation.
302
+ ### Path: /document/get/\<id\>
303
+ ### Method: GET
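+
+ For example, the img_id returned with a chunk can be fetched and saved for display. A minimal sketch, reusing the BASE_URL and HEADERS assumptions from the previous example (whether the /api prefix applies to this path is likewise an assumption):
+
+ ```python
+ def fetch_citation_image(img_id, out_path="citation.png"):
+     # img_id comes from a chunk in the completion response.
+     resp = requests.get(f"{BASE_URL}/document/get/{img_id}", headers=HEADERS)
+     resp.raise_for_status()
+     with open(out_path, "wb") as f:
+         f.write(resp.content)   # raw image bytes (or document content)
+     return out_path
+ ```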
rag/llm/chat_model.py CHANGED
@@ -49,7 +49,7 @@ class GptTurbo(Base):
             if response.choices[0].finish_reason == "length":
                 ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                     [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
-            return ans, response.usage.completion_tokens
+            return ans, response.usage.total_tokens
         except openai.APIError as e:
             return "**ERROR**: " + str(e), 0
 
@@ -73,7 +73,7 @@ class MoonshotChat(GptTurbo):
             if response.choices[0].finish_reason == "length":
                 ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                     [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
-            return ans, response.usage.completion_tokens
+            return ans, response.usage.total_tokens
         except openai.APIError as e:
             return "**ERROR**: " + str(e), 0
 
@@ -127,7 +127,7 @@ class ZhipuChat(Base):
             if response.choices[0].finish_reason == "length":
                 ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                     [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
-            return ans, response.usage.completion_tokens
+            return ans, response.usage.total_tokens
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
@@ -153,7 +153,7 @@ class OllamaChat(Base):
                 options=options
             )
             ans = response["message"]["content"].strip()
-            return ans, response["eval_count"]
+            return ans, response["eval_count"] + response["prompt_eval_count"]
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
@@ -175,50 +175,7 @@ class XinferenceChat(Base):
             if response.choices[0].finish_reason == "length":
                 ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                     [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
-            return ans, response.usage.completion_tokens
+            return ans, response.usage.total_tokens
         except openai.APIError as e:
             return "**ERROR**: " + str(e), 0
 
-
-class LocalLLM(Base):
-    class RPCProxy:
-        def __init__(self, host, port):
-            self.host = host
-            self.port = int(port)
-            self.__conn()
-
-        def __conn(self):
-            from multiprocessing.connection import Client
-            self._connection = Client(
-                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')
-
-        def __getattr__(self, name):
-            import pickle
-
-            def do_rpc(*args, **kwargs):
-                for _ in range(3):
-                    try:
-                        self._connection.send(
-                            pickle.dumps((name, args, kwargs)))
-                        return pickle.loads(self._connection.recv())
-                    except Exception as e:
-                        self.__conn()
-                raise Exception("RPC connection lost!")
-
-            return do_rpc
-
-    def __init__(self, *args, **kwargs):
-        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
-
-    def chat(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        try:
-            ans = self.client.chat(
-                history,
-                gen_conf
-            )
-            return ans, num_tokens_from_string(ans)
-        except Exception as e:
-            return "**ERROR**: " + str(e), 0
-