KevinHuSh committed on
Commit
4a858d3
·
1 Parent(s): fad2ec7

add conversation API (#35)

Browse files
api/apps/chunk_app.py CHANGED
@@ -13,17 +13,13 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
- import hashlib
17
- import re
18
 
19
- import numpy as np
20
  from flask import request
21
  from flask_login import login_required, current_user
22
-
23
- from rag.nlp import search, huqie
24
  from rag.utils import ELASTICSEARCH, rmSpace
25
  from api.db import LLMType
26
- from api.db.services import duplicate_name
27
  from api.db.services.kb_service import KnowledgebaseService
28
  from api.db.services.llm_service import TenantLLMService
29
  from api.db.services.user_service import UserTenantService
@@ -31,8 +27,9 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
31
  from api.db.services.document_service import DocumentService
32
  from api.settings import RetCode
33
  from api.utils.api_utils import get_json_result
 
 
34
 
35
- retrival = search.Dealer(ELASTICSEARCH)
36
 
37
  @manager.route('/list', methods=['POST'])
38
  @login_required
@@ -45,12 +42,14 @@ def list():
45
  question = req.get("keywords", "")
46
  try:
47
  tenant_id = DocumentService.get_tenant_id(req["doc_id"])
48
- if not tenant_id: return get_data_error_result(retmsg="Tenant not found!")
 
49
  query = {
50
  "doc_ids": [doc_id], "page": page, "size": size, "question": question
51
  }
52
- if "available_int" in req: query["available_int"] = int(req["available_int"])
53
- sres = retrival.search(query, search.index_name(tenant_id))
 
54
  res = {"total": sres.total, "chunks": []}
55
  for id in sres.ids:
56
  d = {
@@ -67,7 +66,7 @@ def list():
67
  except Exception as e:
68
  if str(e).find("not_found") > 0:
69
  return get_json_result(data=False, retmsg=f'Index not found!',
70
- retcode=RetCode.DATA_ERROR)
71
  return server_error_response(e)
72
 
73
 
@@ -79,8 +78,11 @@ def get():
79
  tenants = UserTenantService.query(user_id=current_user.id)
80
  if not tenants:
81
  return get_data_error_result(retmsg="Tenant not found!")
82
- res = ELASTICSEARCH.get(chunk_id, search.index_name(tenants[0].tenant_id))
83
- if not res.get("found"):return server_error_response("Chunk not found")
 
 
 
84
  id = res["_id"]
85
  res = res["_source"]
86
  res["chunk_id"] = id
@@ -90,7 +92,8 @@ def get():
90
  k.append(n)
91
  if re.search(r"(_tks|_ltks)", n):
92
  res[n] = rmSpace(res[n])
93
- for n in k: del res[n]
 
94
 
95
  return get_json_result(data=res)
96
  except Exception as e:
@@ -102,7 +105,8 @@ def get():
102
 
103
  @manager.route('/set', methods=['POST'])
104
  @login_required
105
- @validate_request("doc_id", "chunk_id", "content_ltks", "important_kwd", "docnm_kwd")
 
106
  def set():
107
  req = request.json
108
  d = {"id": req["chunk_id"]}
@@ -110,15 +114,21 @@ def set():
110
  d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
111
  d["important_kwd"] = req["important_kwd"]
112
  d["important_tks"] = huqie.qie(" ".join(req["important_kwd"]))
113
- if "available_int" in req: d["available_int"] = req["available_int"]
 
114
 
115
  try:
116
  tenant_id = DocumentService.get_tenant_id(req["doc_id"])
117
- if not tenant_id: return get_data_error_result(retmsg="Tenant not found!")
118
- embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value)
119
- v, c = embd_mdl.encode([req["docnm_kwd"], req["content_ltks"]])
 
 
 
 
 
120
  v = 0.1 * v[0] + 0.9 * v[1]
121
- d["q_%d_vec"%len(v)] = v.tolist()
122
  ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
123
  return get_json_result(data=True)
124
  except Exception as e:
@@ -132,19 +142,32 @@ def switch():
132
  req = request.json
133
  try:
134
  tenant_id = DocumentService.get_tenant_id(req["doc_id"])
135
- if not tenant_id: return get_data_error_result(retmsg="Tenant not found!")
 
136
  if not ELASTICSEARCH.upsert([{"id": i, "available_int": int(req["available_int"])} for i in req["chunk_ids"]],
137
- search.index_name(tenant_id)):
138
  return get_data_error_result(retmsg="Index updating failure")
139
  return get_json_result(data=True)
140
  except Exception as e:
141
  return server_error_response(e)
142
 
143
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  @manager.route('/create', methods=['POST'])
146
  @login_required
147
- @validate_request("doc_id", "content_ltks", "important_kwd")
148
  def create():
149
  req = request.json
150
  md5 = hashlib.md5()
@@ -152,24 +175,27 @@ def create():
152
  chunck_id = md5.hexdigest()
153
  d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_ltks"])}
154
  d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
155
- d["important_kwd"] = req["important_kwd"]
156
- d["important_tks"] = huqie.qie(" ".join(req["important_kwd"]))
157
 
158
  try:
159
  e, doc = DocumentService.get_by_id(req["doc_id"])
160
- if not e: return get_data_error_result(retmsg="Document not found!")
 
161
  d["kb_id"] = [doc.kb_id]
162
  d["docnm_kwd"] = doc.name
163
  d["doc_id"] = doc.id
164
 
165
  tenant_id = DocumentService.get_tenant_id(req["doc_id"])
166
- if not tenant_id: return get_data_error_result(retmsg="Tenant not found!")
 
167
 
168
- embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value)
 
169
  v, c = embd_mdl.encode([doc.name, req["content_ltks"]])
170
  DocumentService.increment_chunk_num(req["doc_id"], doc.kb_id, c, 1, 0)
171
  v = 0.1 * v[0] + 0.9 * v[1]
172
- d["q_%d_vec"%len(v)] = v.tolist()
173
  ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
174
  return get_json_result(data={"chunk_id": chunck_id})
175
  except Exception as e:
@@ -194,44 +220,15 @@ def retrieval_test():
194
  if not e:
195
  return get_data_error_result(retmsg="Knowledgebase not found!")
196
 
197
- embd_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.EMBEDDING.value)
198
- sres = retrival.search({"kb_ids": [kb_id], "doc_ids": doc_ids, "size": top,
199
- "question": question, "vector": True,
200
- "similarity": similarity_threshold},
201
- search.index_name(kb.tenant_id),
202
- embd_mdl)
203
-
204
- sim, tsim, vsim = retrival.rerank(sres, question, 1-vector_similarity_weight, vector_similarity_weight)
205
- idx = np.argsort(sim*-1)
206
- ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
207
- start_idx = (page-1)*size
208
- for i in idx:
209
- ranks["total"] += 1
210
- if sim[i] < similarity_threshold: break
211
- start_idx -= 1
212
- if start_idx >= 0:continue
213
- if len(ranks["chunks"]) == size:continue
214
- id = sres.ids[i]
215
- dnm = sres.field[id]["docnm_kwd"]
216
- d = {
217
- "chunk_id": id,
218
- "content_ltks": sres.field[id]["content_ltks"],
219
- "doc_id": sres.field[id]["doc_id"],
220
- "docnm_kwd": dnm,
221
- "kb_id": sres.field[id]["kb_id"],
222
- "important_kwd": sres.field[id].get("important_kwd", []),
223
- "img_id": sres.field[id].get("img_id", ""),
224
- "similarity": sim[i],
225
- "vector_similarity": vsim[i],
226
- "term_similarity": tsim[i]
227
- }
228
- ranks["chunks"].append(d)
229
- if dnm not in ranks["doc_aggs"]:ranks["doc_aggs"][dnm] = 0
230
- ranks["doc_aggs"][dnm] += 1
231
 
232
  return get_json_result(data=ranks)
233
  except Exception as e:
234
  if str(e).find("not_found") > 0:
235
  return get_json_result(data=False, retmsg=f'Index not found!',
236
- retcode=RetCode.DATA_ERROR)
237
- return server_error_response(e)
 
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
 
16
 
 
17
  from flask import request
18
  from flask_login import login_required, current_user
19
+ from elasticsearch_dsl import Q
20
+ from rag.nlp import search, huqie, retrievaler
21
  from rag.utils import ELASTICSEARCH, rmSpace
22
  from api.db import LLMType
 
23
  from api.db.services.kb_service import KnowledgebaseService
24
  from api.db.services.llm_service import TenantLLMService
25
  from api.db.services.user_service import UserTenantService
 
27
  from api.db.services.document_service import DocumentService
28
  from api.settings import RetCode
29
  from api.utils.api_utils import get_json_result
30
+ import hashlib
31
+ import re
32
 
 
33
 
34
  @manager.route('/list', methods=['POST'])
35
  @login_required
 
42
  question = req.get("keywords", "")
43
  try:
44
  tenant_id = DocumentService.get_tenant_id(req["doc_id"])
45
+ if not tenant_id:
46
+ return get_data_error_result(retmsg="Tenant not found!")
47
  query = {
48
  "doc_ids": [doc_id], "page": page, "size": size, "question": question
49
  }
50
+ if "available_int" in req:
51
+ query["available_int"] = int(req["available_int"])
52
+ sres = retrievaler.search(query, search.index_name(tenant_id))
53
  res = {"total": sres.total, "chunks": []}
54
  for id in sres.ids:
55
  d = {
 
66
  except Exception as e:
67
  if str(e).find("not_found") > 0:
68
  return get_json_result(data=False, retmsg=f'Index not found!',
69
+ retcode=RetCode.DATA_ERROR)
70
  return server_error_response(e)
71
 
72
 
 
78
  tenants = UserTenantService.query(user_id=current_user.id)
79
  if not tenants:
80
  return get_data_error_result(retmsg="Tenant not found!")
81
+ res = ELASTICSEARCH.get(
82
+ chunk_id, search.index_name(
83
+ tenants[0].tenant_id))
84
+ if not res.get("found"):
85
+ return server_error_response("Chunk not found")
86
  id = res["_id"]
87
  res = res["_source"]
88
  res["chunk_id"] = id
 
92
  k.append(n)
93
  if re.search(r"(_tks|_ltks)", n):
94
  res[n] = rmSpace(res[n])
95
+ for n in k:
96
+ del res[n]
97
 
98
  return get_json_result(data=res)
99
  except Exception as e:
 
105
 
106
  @manager.route('/set', methods=['POST'])
107
  @login_required
108
+ @validate_request("doc_id", "chunk_id", "content_ltks",
109
+ "important_kwd")
110
  def set():
111
  req = request.json
112
  d = {"id": req["chunk_id"]}
 
114
  d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
115
  d["important_kwd"] = req["important_kwd"]
116
  d["important_tks"] = huqie.qie(" ".join(req["important_kwd"]))
117
+ if "available_int" in req:
118
+ d["available_int"] = req["available_int"]
119
 
120
  try:
121
  tenant_id = DocumentService.get_tenant_id(req["doc_id"])
122
+ if not tenant_id:
123
+ return get_data_error_result(retmsg="Tenant not found!")
124
+ embd_mdl = TenantLLMService.model_instance(
125
+ tenant_id, LLMType.EMBEDDING.value)
126
+ e, doc = DocumentService.get_by_id(req["doc_id"])
127
+ if not e:
128
+ return get_data_error_result(retmsg="Document not found!")
129
+ v, c = embd_mdl.encode([doc.name, req["content_ltks"]])
130
  v = 0.1 * v[0] + 0.9 * v[1]
131
+ d["q_%d_vec" % len(v)] = v.tolist()
132
  ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
133
  return get_json_result(data=True)
134
  except Exception as e:
 
142
  req = request.json
143
  try:
144
  tenant_id = DocumentService.get_tenant_id(req["doc_id"])
145
+ if not tenant_id:
146
+ return get_data_error_result(retmsg="Tenant not found!")
147
  if not ELASTICSEARCH.upsert([{"id": i, "available_int": int(req["available_int"])} for i in req["chunk_ids"]],
148
+ search.index_name(tenant_id)):
149
  return get_data_error_result(retmsg="Index updating failure")
150
  return get_json_result(data=True)
151
  except Exception as e:
152
  return server_error_response(e)
153
 
154
 
155
+ @manager.route('/rm', methods=['POST'])
156
+ @login_required
157
+ @validate_request("chunk_ids")
158
+ def rm():
159
+ req = request.json
160
+ try:
161
+ if not ELASTICSEARCH.deleteByQuery(Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
162
+ return get_data_error_result(retmsg="Index updating failure")
163
+ return get_json_result(data=True)
164
+ except Exception as e:
165
+ return server_error_response(e)
166
+
167
 
168
  @manager.route('/create', methods=['POST'])
169
  @login_required
170
+ @validate_request("doc_id", "content_ltks")
171
  def create():
172
  req = request.json
173
  md5 = hashlib.md5()
 
175
  chunck_id = md5.hexdigest()
176
  d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_ltks"])}
177
  d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
178
+ d["important_kwd"] = req.get("important_kwd", [])
179
+ d["important_tks"] = huqie.qie(" ".join(req.get("important_kwd", [])))
180
 
181
  try:
182
  e, doc = DocumentService.get_by_id(req["doc_id"])
183
+ if not e:
184
+ return get_data_error_result(retmsg="Document not found!")
185
  d["kb_id"] = [doc.kb_id]
186
  d["docnm_kwd"] = doc.name
187
  d["doc_id"] = doc.id
188
 
189
  tenant_id = DocumentService.get_tenant_id(req["doc_id"])
190
+ if not tenant_id:
191
+ return get_data_error_result(retmsg="Tenant not found!")
192
 
193
+ embd_mdl = TenantLLMService.model_instance(
194
+ tenant_id, LLMType.EMBEDDING.value)
195
  v, c = embd_mdl.encode([doc.name, req["content_ltks"]])
196
  DocumentService.increment_chunk_num(req["doc_id"], doc.kb_id, c, 1, 0)
197
  v = 0.1 * v[0] + 0.9 * v[1]
198
+ d["q_%d_vec" % len(v)] = v.tolist()
199
  ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
200
  return get_json_result(data={"chunk_id": chunck_id})
201
  except Exception as e:
 
220
  if not e:
221
  return get_data_error_result(retmsg="Knowledgebase not found!")
222
 
223
+ embd_mdl = TenantLLMService.model_instance(
224
+ kb.tenant_id, LLMType.EMBEDDING.value)
225
+ ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
226
+ vector_similarity_weight, top, doc_ids)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  return get_json_result(data=ranks)
229
  except Exception as e:
230
  if str(e).find("not_found") > 0:
231
  return get_json_result(data=False, retmsg=f'Index not found!',
232
+ retcode=RetCode.DATA_ERROR)
233
+ return server_error_response(e)
234
+
api/apps/conversation_app.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import re
17
+
18
+ import tiktoken
19
+ from flask import request
20
+ from flask_login import login_required, current_user
21
+ from api.db.services.dialog_service import DialogService, ConversationService
22
+ from api.db import StatusEnum, LLMType
23
+ from api.db.services.kb_service import KnowledgebaseService
24
+ from api.db.services.llm_service import LLMService, TenantLLMService
25
+ from api.db.services.user_service import TenantService
26
+ from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
27
+ from api.utils import get_uuid
28
+ from api.utils.api_utils import get_json_result
29
+ from rag.llm import ChatModel
30
+ from rag.nlp import retrievaler
31
+ from rag.nlp.query import EsQueryer
32
+ from rag.utils import num_tokens_from_string, encoder
33
+
34
+
35
+ @manager.route('/set', methods=['POST'])
36
+ @login_required
37
+ @validate_request("dialog_id")
38
+ def set():
39
+ req = request.json
40
+ conv_id = req.get("conversation_id")
41
+ if conv_id:
42
+ del req["conversation_id"]
43
+ try:
44
+ if not ConversationService.update_by_id(conv_id, req):
45
+ return get_data_error_result(retmsg="Conversation not found!")
46
+ e, conv = ConversationService.get_by_id(conv_id)
47
+ if not e:
48
+ return get_data_error_result(
49
+ retmsg="Fail to update a conversation!")
50
+ conv = conv.to_dict()
51
+ return get_json_result(data=conv)
52
+ except Exception as e:
53
+ return server_error_response(e)
54
+
55
+ try:
56
+ e, dia = DialogService.get_by_id(req["dialog_id"])
57
+ if not e:
58
+ return get_data_error_result(retmsg="Dialog not found")
59
+ conv = {
60
+ "id": get_uuid(),
61
+ "dialog_id": req["dialog_id"],
62
+ "name": "New conversation",
63
+ "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
64
+ }
65
+ ConversationService.save(**conv)
66
+ e, conv = ConversationService.get_by_id(conv["id"])
67
+ if not e:
68
+ return get_data_error_result(retmsg="Fail to new a conversation!")
69
+ conv = conv.to_dict()
70
+ return get_json_result(data=conv)
71
+ except Exception as e:
72
+ return server_error_response(e)
73
+
74
+
75
+ @manager.route('/get', methods=['GET'])
76
+ @login_required
77
+ def get():
78
+ conv_id = request.args["conversation_id"]
79
+ try:
80
+ e, conv = ConversationService.get_by_id(conv_id)
81
+ if not e:
82
+ return get_data_error_result(retmsg="Conversation not found!")
83
+ conv = conv.to_dict()
84
+ return get_json_result(data=conv)
85
+ except Exception as e:
86
+ return server_error_response(e)
87
+
88
+
89
+ @manager.route('/rm', methods=['POST'])
90
+ @login_required
91
+ def rm():
92
+ conv_ids = request.json["conversation_ids"]
93
+ try:
94
+ for cid in conv_ids:
95
+ ConversationService.delete_by_id(cid)
96
+ return get_json_result(data=True)
97
+ except Exception as e:
98
+ return server_error_response(e)
99
+
100
+ @manager.route('/list', methods=['GET'])
101
+ @login_required
102
+ def list():
103
+ dialog_id = request.args["dialog_id"]
104
+ try:
105
+ convs = ConversationService.query(dialog_id=dialog_id)
106
+ convs = [d.to_dict() for d in convs]
107
+ return get_json_result(data=convs)
108
+ except Exception as e:
109
+ return server_error_response(e)
110
+
111
+
112
+ def message_fit_in(msg, max_length=4000):
113
+ def count():
114
+ nonlocal msg
115
+ tks_cnts = []
116
+ for m in msg:tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
117
+ total = 0
118
+ for m in tks_cnts: total += m["count"]
119
+ return total
120
+
121
+ c = count()
122
+ if c < max_length: return c, msg
123
+ msg = [m for m in msg if m.role in ["system", "user"]]
124
+ c = count()
125
+ if c < max_length:return c, msg
126
+ msg_ = [m for m in msg[:-1] if m.role == "system"]
127
+ msg_.append(msg[-1])
128
+ msg = msg_
129
+ c = count()
130
+ if c < max_length:return c, msg
131
+ ll = num_tokens_from_string(msg_[0].content)
132
+ l = num_tokens_from_string(msg_[-1].content)
133
+ if ll/(ll + l) > 0.8:
134
+ m = msg_[0].content
135
+ m = encoder.decode(encoder.encode(m)[:max_length-l])
136
+ msg[0].content = m
137
+ return max_length, msg
138
+
139
+ m = msg_[1].content
140
+ m = encoder.decode(encoder.encode(m)[:max_length-l])
141
+ msg[1].content = m
142
+ return max_length, msg
143
+
144
+
145
+ def chat(dialog, messages, **kwargs):
146
+ assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
147
+ llm = LLMService.query(llm_name=dialog.llm_id)
148
+ if not llm:
149
+ raise LookupError("LLM(%s) not found"%dialog.llm_id)
150
+ llm = llm[0]
151
+ prompt_config = dialog.prompt_config
152
+ for p in prompt_config["parameters"]:
153
+ if p["key"] == "knowledge":continue
154
+ if p["key"] not in kwargs and not p["optional"]:raise KeyError("Miss parameter: " + p["key"])
155
+ if p["key"] not in kwargs:
156
+ prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")
157
+
158
+ model_config = TenantLLMService.get_api_key(dialog.tenant_id, LLMType.CHAT.value, dialog.llm_id)
159
+ if not model_config: raise LookupError("LLM(%s) API key not found"%dialog.llm_id)
160
+
161
+ question = messages[-1]["content"]
162
+ embd_mdl = TenantLLMService.model_instance(
163
+ dialog.tenant_id, LLMType.EMBEDDING.value)
164
+ kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold,
165
+ dialog.vector_similarity_weight, top=1024, aggs=False)
166
+ knowledges = [ck["content_ltks"] for ck in kbinfos["chunks"]]
167
+
168
+ if not knowledges and prompt_config["empty_response"]:
169
+ return {"answer": prompt_config["empty_response"], "retrieval": kbinfos}
170
+
171
+ kwargs["knowledge"] = "\n".join(knowledges)
172
+ gen_conf = dialog.llm_setting[dialog.llm_setting_type]
173
+ msg = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"]
174
+ used_token_count = message_fit_in(msg, int(llm.max_tokens * 0.97))
175
+ if "max_tokens" in gen_conf:
176
+ gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
177
+ mdl = ChatModel[model_config.llm_factory](model_config["api_key"], dialog.llm_id)
178
+ answer = mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
179
+
180
+ answer = retrievaler.insert_citations(answer,
181
+ [ck["content_ltks"] for ck in kbinfos["chunks"]],
182
+ [ck["vector"] for ck in kbinfos["chunks"]],
183
+ embd_mdl,
184
+ tkweight=1-dialog.vector_similarity_weight,
185
+ vtweight=dialog.vector_similarity_weight)
186
+ return {"answer": answer, "retrieval": kbinfos}
187
+
188
+
189
+ @manager.route('/completion', methods=['POST'])
190
+ @login_required
191
+ @validate_request("dialog_id", "messages")
192
+ def completion():
193
+ req = request.json
194
+ msg = []
195
+ for m in req["messages"]:
196
+ if m["role"] == "system":continue
197
+ if m["role"] == "assistant" and not msg:continue
198
+ msg.append({"role": m["role"], "content": m["content"]})
199
+ try:
200
+ e, dia = DialogService.get_by_id(req["dialog_id"])
201
+ if not e:
202
+ return get_data_error_result(retmsg="Dialog not found!")
203
+ del req["dialog_id"]
204
+ del req["messages"]
205
+ return get_json_result(data=chat(dia, msg, **req))
206
+ except Exception as e:
207
+ return server_error_response(e)
api/apps/dialog_app.py CHANGED
@@ -13,28 +13,16 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
- import hashlib
17
- import re
18
 
19
- import numpy as np
20
  from flask import request
21
  from flask_login import login_required, current_user
22
-
23
  from api.db.services.dialog_service import DialogService
24
- from rag.nlp import search, huqie
25
- from rag.utils import ELASTICSEARCH, rmSpace
26
- from api.db import LLMType, StatusEnum
27
- from api.db.services import duplicate_name
28
  from api.db.services.kb_service import KnowledgebaseService
29
- from api.db.services.llm_service import TenantLLMService
30
- from api.db.services.user_service import UserTenantService, TenantService
31
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
32
  from api.utils import get_uuid
33
- from api.db.services.document_service import DocumentService
34
- from api.settings import RetCode, stat_logger
35
  from api.utils.api_utils import get_json_result
36
- from rag.utils.minio_conn import MINIO
37
- from api.utils.file_utils import filename_type
38
 
39
 
40
  @manager.route('/set', methods=['POST'])
@@ -128,6 +116,7 @@ def set():
128
  except Exception as e:
129
  return server_error_response(e)
130
 
 
131
  @manager.route('/get', methods=['GET'])
132
  @login_required
133
  def get():
@@ -159,5 +148,18 @@ def list():
159
  for d in diags:
160
  d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
161
  return get_json_result(data=diags)
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  except Exception as e:
163
  return server_error_response(e)
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
 
16
 
 
17
  from flask import request
18
  from flask_login import login_required, current_user
 
19
  from api.db.services.dialog_service import DialogService
20
+ from api.db import StatusEnum
 
 
 
21
  from api.db.services.kb_service import KnowledgebaseService
22
+ from api.db.services.user_service import TenantService
 
23
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
24
  from api.utils import get_uuid
 
 
25
  from api.utils.api_utils import get_json_result
 
 
26
 
27
 
28
  @manager.route('/set', methods=['POST'])
 
116
  except Exception as e:
117
  return server_error_response(e)
118
 
119
+
120
  @manager.route('/get', methods=['GET'])
121
  @login_required
122
  def get():
 
148
  for d in diags:
149
  d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
150
  return get_json_result(data=diags)
151
+ except Exception as e:
152
+ return server_error_response(e)
153
+
154
+
155
+ @manager.route('/rm', methods=['POST'])
156
+ @login_required
157
+ @validate_request("dialog_id")
158
+ def rm():
159
+ req = request.json
160
+ try:
161
+ if not DialogService.update_by_id(req["dialog_id"], {"status": StatusEnum.INVALID.value}):
162
+ return get_data_error_result(retmsg="Dialog not found!")
163
+ return get_json_result(data=True)
164
  except Exception as e:
165
  return server_error_response(e)
api/apps/document_app.py CHANGED
@@ -271,7 +271,7 @@ def change_parser():
271
 
272
 
273
  @manager.route('/image/<image_id>', methods=['GET'])
274
- @login_required
275
  def get_image(image_id):
276
  try:
277
  bkt, nm = image_id.split("-")
 
271
 
272
 
273
  @manager.route('/image/<image_id>', methods=['GET'])
274
+ #@login_required
275
  def get_image(image_id):
276
  try:
277
  bkt, nm = image_id.split("-")
api/apps/kb_app.py CHANGED
@@ -108,7 +108,7 @@ def rm():
108
  if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
109
  return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
110
 
111
- if not KnowledgebaseService.update_by_id(req["kb_id"], {"status": StatusEnum.IN_VALID.value}): return get_data_error_result(retmsg="Database error (Knowledgebase removal)!")
112
  return get_json_result(data=True)
113
  except Exception as e:
114
  return server_error_response(e)
 
108
  if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
109
  return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
110
 
111
+ if not KnowledgebaseService.update_by_id(req["kb_id"], {"status": StatusEnum.INVALID.value}): return get_data_error_result(retmsg="Database error (Knowledgebase removal)!")
112
  return get_json_result(data=True)
113
  except Exception as e:
114
  return server_error_response(e)
api/db/__init__.py CHANGED
@@ -20,7 +20,7 @@ from strenum import StrEnum
20
 
21
  class StatusEnum(Enum):
22
  VALID = "1"
23
- IN_VALID = "0"
24
 
25
 
26
  class UserTenantRole(StrEnum):
 
20
 
21
  class StatusEnum(Enum):
22
  VALID = "1"
23
+ INVALID = "0"
24
 
25
 
26
  class UserTenantRole(StrEnum):
api/db/db_models.py CHANGED
@@ -430,6 +430,7 @@ class LLM(DataBaseModel):
430
  llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
431
  model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
432
  fid = CharField(max_length=128, null=False, help_text="LLM factory id")
 
433
  tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
434
  status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
435
 
@@ -467,8 +468,8 @@ class Knowledgebase(DataBaseModel):
467
  doc_num = IntegerField(default=0)
468
  token_num = IntegerField(default=0)
469
  chunk_num = IntegerField(default=0)
470
- similarity_threshold = FloatField(default=0.4)
471
- vector_similarity_weight = FloatField(default=0.3)
472
 
473
  parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
474
  status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
@@ -518,6 +519,11 @@ class Dialog(DataBaseModel):
518
  prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced")
519
  prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
520
  "parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
 
 
 
 
 
521
  kb_ids = JSONField(null=False, default=[])
522
  status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
523
 
 
430
  llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
431
  model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
432
  fid = CharField(max_length=128, null=False, help_text="LLM factory id")
433
+ max_tokens = IntegerField(default=0)
434
  tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
435
  status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
436
 
 
468
  doc_num = IntegerField(default=0)
469
  token_num = IntegerField(default=0)
470
  chunk_num = IntegerField(default=0)
471
+ #similarity_threshold = FloatField(default=0.4)
472
+ #vector_similarity_weight = FloatField(default=0.3)
473
 
474
  parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
475
  status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
 
519
  prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced")
520
  prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
521
  "parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
522
+
523
+ similarity_threshold = FloatField(default=0.4)
524
+ vector_similarity_weight = FloatField(default=0.3)
525
+ top_n = IntegerField(default=6)
526
+
527
  kb_ids = JSONField(null=False, default=[])
528
  status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
529
 
api/db/init_data.py CHANGED
@@ -62,61 +62,73 @@ def init_llm_factory():
62
  "fid": factory_infos[0]["name"],
63
  "llm_name": "gpt-3.5-turbo",
64
  "tags": "LLM,CHAT,4K",
 
65
  "model_type": LLMType.CHAT.value
66
  },{
67
  "fid": factory_infos[0]["name"],
68
  "llm_name": "gpt-3.5-turbo-16k-0613",
69
  "tags": "LLM,CHAT,16k",
 
70
  "model_type": LLMType.CHAT.value
71
  },{
72
  "fid": factory_infos[0]["name"],
73
  "llm_name": "text-embedding-ada-002",
74
  "tags": "TEXT EMBEDDING,8K",
 
75
  "model_type": LLMType.EMBEDDING.value
76
  },{
77
  "fid": factory_infos[0]["name"],
78
  "llm_name": "whisper-1",
79
  "tags": "SPEECH2TEXT",
 
80
  "model_type": LLMType.SPEECH2TEXT.value
81
  },{
82
  "fid": factory_infos[0]["name"],
83
  "llm_name": "gpt-4",
84
  "tags": "LLM,CHAT,8K",
 
85
  "model_type": LLMType.CHAT.value
86
  },{
87
  "fid": factory_infos[0]["name"],
88
  "llm_name": "gpt-4-32k",
89
  "tags": "LLM,CHAT,32K",
 
90
  "model_type": LLMType.CHAT.value
91
  },{
92
  "fid": factory_infos[0]["name"],
93
  "llm_name": "gpt-4-vision-preview",
94
  "tags": "LLM,CHAT,IMAGE2TEXT",
 
95
  "model_type": LLMType.IMAGE2TEXT.value
96
  },{
97
  "fid": factory_infos[1]["name"],
98
  "llm_name": "qwen-turbo",
99
  "tags": "LLM,CHAT,8K",
 
100
  "model_type": LLMType.CHAT.value
101
  },{
102
  "fid": factory_infos[1]["name"],
103
  "llm_name": "qwen-plus",
104
  "tags": "LLM,CHAT,32K",
 
105
  "model_type": LLMType.CHAT.value
106
  },{
107
  "fid": factory_infos[1]["name"],
108
  "llm_name": "text-embedding-v2",
109
  "tags": "TEXT EMBEDDING,2K",
 
110
  "model_type": LLMType.EMBEDDING.value
111
  },{
112
  "fid": factory_infos[1]["name"],
113
  "llm_name": "paraformer-realtime-8k-v1",
114
  "tags": "SPEECH2TEXT",
 
115
  "model_type": LLMType.SPEECH2TEXT.value
116
  },{
117
  "fid": factory_infos[1]["name"],
118
  "llm_name": "qwen_vl_chat_v1",
119
  "tags": "LLM,CHAT,IMAGE2TEXT",
 
120
  "model_type": LLMType.IMAGE2TEXT.value
121
  },
122
  ]
 
62
  "fid": factory_infos[0]["name"],
63
  "llm_name": "gpt-3.5-turbo",
64
  "tags": "LLM,CHAT,4K",
65
+ "max_tokens": 4096,
66
  "model_type": LLMType.CHAT.value
67
  },{
68
  "fid": factory_infos[0]["name"],
69
  "llm_name": "gpt-3.5-turbo-16k-0613",
70
  "tags": "LLM,CHAT,16k",
71
+ "max_tokens": 16385,
72
  "model_type": LLMType.CHAT.value
73
  },{
74
  "fid": factory_infos[0]["name"],
75
  "llm_name": "text-embedding-ada-002",
76
  "tags": "TEXT EMBEDDING,8K",
77
+ "max_tokens": 8191,
78
  "model_type": LLMType.EMBEDDING.value
79
  },{
80
  "fid": factory_infos[0]["name"],
81
  "llm_name": "whisper-1",
82
  "tags": "SPEECH2TEXT",
83
+ "max_tokens": 25*1024*1024,
84
  "model_type": LLMType.SPEECH2TEXT.value
85
  },{
86
  "fid": factory_infos[0]["name"],
87
  "llm_name": "gpt-4",
88
  "tags": "LLM,CHAT,8K",
89
+ "max_tokens": 8191,
90
  "model_type": LLMType.CHAT.value
91
  },{
92
  "fid": factory_infos[0]["name"],
93
  "llm_name": "gpt-4-32k",
94
  "tags": "LLM,CHAT,32K",
95
+ "max_tokens": 32768,
96
  "model_type": LLMType.CHAT.value
97
  },{
98
  "fid": factory_infos[0]["name"],
99
  "llm_name": "gpt-4-vision-preview",
100
  "tags": "LLM,CHAT,IMAGE2TEXT",
101
+ "max_tokens": 765,
102
  "model_type": LLMType.IMAGE2TEXT.value
103
  },{
104
  "fid": factory_infos[1]["name"],
105
  "llm_name": "qwen-turbo",
106
  "tags": "LLM,CHAT,8K",
107
+ "max_tokens": 8191,
108
  "model_type": LLMType.CHAT.value
109
  },{
110
  "fid": factory_infos[1]["name"],
111
  "llm_name": "qwen-plus",
112
  "tags": "LLM,CHAT,32K",
113
+ "max_tokens": 32768,
114
  "model_type": LLMType.CHAT.value
115
  },{
116
  "fid": factory_infos[1]["name"],
117
  "llm_name": "text-embedding-v2",
118
  "tags": "TEXT EMBEDDING,2K",
119
+ "max_tokens": 2048,
120
  "model_type": LLMType.EMBEDDING.value
121
  },{
122
  "fid": factory_infos[1]["name"],
123
  "llm_name": "paraformer-realtime-8k-v1",
124
  "tags": "SPEECH2TEXT",
125
+ "max_tokens": 25*1024*1024,
126
  "model_type": LLMType.SPEECH2TEXT.value
127
  },{
128
  "fid": factory_infos[1]["name"],
129
  "llm_name": "qwen_vl_chat_v1",
130
  "tags": "LLM,CHAT,IMAGE2TEXT",
131
+ "max_tokens": 765,
132
  "model_type": LLMType.IMAGE2TEXT.value
133
  },
134
  ]
api/db/services/llm_service.py CHANGED
@@ -34,7 +34,7 @@ class TenantLLMService(CommonService):
34
 
35
  @classmethod
36
  @DB.connection_context()
37
- def get_api_key(cls, tenant_id, model_type):
38
  objs = cls.query(tenant_id=tenant_id, model_type=model_type)
39
  if objs and len(objs)>0 and objs[0].llm_name:
40
  return objs[0]
@@ -42,7 +42,7 @@ class TenantLLMService(CommonService):
42
  fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
43
  objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
44
  (cls.model.tenant_id == tenant_id),
45
- (cls.model.model_type == model_type),
46
  (LLM.status == StatusEnum.VALID)
47
  )
48
 
@@ -60,7 +60,7 @@ class TenantLLMService(CommonService):
60
  @classmethod
61
  @DB.connection_context()
62
  def model_instance(cls, tenant_id, llm_type):
63
- model_config = cls.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
64
  if not model_config:
65
  model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
66
  else:
 
34
 
35
  @classmethod
36
  @DB.connection_context()
37
+ def get_api_key(cls, tenant_id, model_type, model_name=""):
38
  objs = cls.query(tenant_id=tenant_id, model_type=model_type)
39
  if objs and len(objs)>0 and objs[0].llm_name:
40
  return objs[0]
 
42
  fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
43
  objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
44
  (cls.model.tenant_id == tenant_id),
45
+ ((cls.model.model_type == model_type) | (cls.model.llm_name == model_name)),
46
  (LLM.status == StatusEnum.VALID)
47
  )
48
 
 
60
  @classmethod
61
  @DB.connection_context()
62
  def model_instance(cls, tenant_id, llm_type):
63
+ model_config = cls.get_api_key(tenant_id, model_type=LLMType.EMBEDDING.value)
64
  if not model_config:
65
  model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
66
  else:
rag/llm/__init__.py CHANGED
@@ -30,3 +30,9 @@ CvModel = {
30
  "通义千问": QWenCV,
31
  }
32
 
 
 
 
 
 
 
 
30
  "通义千问": QWenCV,
31
  }
32
 
33
+
34
+ ChatModel = {
35
+ "OpenAI": GptTurbo,
36
+ "通义千问": QWenChat,
37
+ }
38
+
rag/nlp/__init__.py CHANGED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from . import search
2
+ from rag.utils import ELASTICSEARCH
3
+
4
+ retrievaler = search.Dealer(ELASTICSEARCH)
rag/nlp/search.py CHANGED
@@ -2,7 +2,7 @@
2
  import json
3
  import re
4
  from elasticsearch_dsl import Q, Search, A
5
- from typing import List, Optional, Tuple, Dict, Union
6
  from dataclasses import dataclass
7
 
8
  from rag.settings import es_logger
@@ -20,6 +20,8 @@ class Dealer:
20
  self.qryr.flds = [
21
  "title_tks^10",
22
  "title_sm_tks^5",
 
 
23
  "content_ltks^2",
24
  "content_sm_ltks"]
25
  self.es = es
@@ -38,10 +40,10 @@ class Dealer:
38
  def _vector(self, txt, emb_mdl, sim=0.8, topk=10):
39
  qv, c = emb_mdl.encode_queries(txt)
40
  return {
41
- "field": "q_%d_vec"%len(qv),
42
  "k": topk,
43
  "similarity": sim,
44
- "num_candidates": topk*2,
45
  "query_vector": qv
46
  }
47
 
@@ -53,16 +55,18 @@ class Dealer:
53
  if req.get("doc_ids"):
54
  bqry.filter.append(Q("terms", doc_id=req["doc_ids"]))
55
  if "available_int" in req:
56
- if req["available_int"] == 0: bqry.filter.append(Q("range", available_int={"lt": 1}))
57
- else: bqry.filter.append(Q("bool", must_not=Q("range", available_int={"lt": 1})))
 
 
58
  bqry.boost = 0.05
59
 
60
  s = Search()
61
  pg = int(req.get("page", 1)) - 1
62
  ps = int(req.get("size", 1000))
63
- src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id","img_id",
64
- "image_id", "doc_id", "q_512_vec", "q_768_vec",
65
- "q_1024_vec", "q_1536_vec", "available_int"])
66
 
67
  s = s.query(bqry)[pg * ps:(pg + 1) * ps]
68
  s = s.highlight("content_ltks")
@@ -171,74 +175,106 @@ class Dealer:
171
  def trans2floats(txt):
172
  return [float(t) for t in txt.split("\t")]
173
 
174
- def insert_citations(self, ans, top_idx, sres, emb_mdl,
175
- vfield="q_vec", cfield="content_ltks"):
 
 
 
 
 
 
 
 
 
 
 
176
 
177
- ins_embd = [Dealer.trans2floats(
178
- sres.field[sres.ids[i]][vfield]) for i in top_idx]
179
- ins_tw = [sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
180
- s = 0
181
- e = 0
182
- res = ""
183
 
184
- def citeit():
185
- nonlocal s, e, ans, res, emb_mdl
186
- if not ins_embd:
187
- return
188
- embd = emb_mdl.encode(ans[s: e])
189
- sim = self.qryr.hybrid_similarity(embd,
190
- ins_embd,
191
- huqie.qie(ans[s:e]).split(" "),
192
- ins_tw)
193
  mx = np.max(sim) * 0.99
194
- if mx < 0.55:
195
- return
196
- cita = list(set([top_idx[i]
197
- for i in range(len(ins_embd)) if sim[i] > mx]))[:4]
198
- for i in cita:
199
- res += f"@?{i}?@"
200
-
201
- return cita
202
-
203
- punct = set(";。?!!")
204
- if not self.qryr.isChinese(ans):
205
- punct.add("?")
206
- punct.add(".")
207
- while e < len(ans):
208
- if e - s < 12 or ans[e] not in punct:
209
- e += 1
210
- continue
211
- if ans[e] == "." and e + \
212
- 1 < len(ans) and re.match(r"[0-9]", ans[e + 1]):
213
- e += 1
214
- continue
215
- if ans[e] == "." and e - 2 >= 0 and ans[e - 2] == "\n":
216
- e += 1
217
- continue
218
- res += ans[s: e]
219
- citeit()
220
- res += ans[e]
221
- e += 1
222
- s = e
223
 
224
- if s < len(ans):
225
- res += ans[s:]
226
- citeit()
 
 
 
227
 
228
  return res
229
 
230
  def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks"):
231
  ins_embd = [
232
  Dealer.trans2floats(
233
- sres.field[i]["q_%d_vec"%len(sres.query_vector)]) for i in sres.ids]
234
  if not ins_embd:
235
  return []
236
- ins_tw = [sres.field[i][cfield].split(" ") for i in sres.ids]
237
  sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
238
- ins_embd,
239
- huqie.qie(query).split(" "),
240
- ins_tw, tkweight, vtweight)
241
  return sim, tksim, vtsim
242
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
 
 
2
  import json
3
  import re
4
  from elasticsearch_dsl import Q, Search, A
5
+ from typing import List, Optional, Dict, Union
6
  from dataclasses import dataclass
7
 
8
  from rag.settings import es_logger
 
20
  self.qryr.flds = [
21
  "title_tks^10",
22
  "title_sm_tks^5",
23
+ "important_kwd^30",
24
+ "important_tks^20",
25
  "content_ltks^2",
26
  "content_sm_ltks"]
27
  self.es = es
 
40
  def _vector(self, txt, emb_mdl, sim=0.8, topk=10):
41
  qv, c = emb_mdl.encode_queries(txt)
42
  return {
43
+ "field": "q_%d_vec" % len(qv),
44
  "k": topk,
45
  "similarity": sim,
46
+ "num_candidates": topk * 2,
47
  "query_vector": qv
48
  }
49
 
 
55
  if req.get("doc_ids"):
56
  bqry.filter.append(Q("terms", doc_id=req["doc_ids"]))
57
  if "available_int" in req:
58
+ if req["available_int"] == 0:
59
+ bqry.filter.append(Q("range", available_int={"lt": 1}))
60
+ else:
61
+ bqry.filter.append(Q("bool", must_not=Q("range", available_int={"lt": 1})))
62
  bqry.boost = 0.05
63
 
64
  s = Search()
65
  pg = int(req.get("page", 1)) - 1
66
  ps = int(req.get("size", 1000))
67
+ src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id",
68
+ "image_id", "doc_id", "q_512_vec", "q_768_vec",
69
+ "q_1024_vec", "q_1536_vec", "available_int"])
70
 
71
  s = s.query(bqry)[pg * ps:(pg + 1) * ps]
72
  s = s.highlight("content_ltks")
 
175
  def trans2floats(txt):
176
  return [float(t) for t in txt.split("\t")]
177
 
178
+ def insert_citations(self, answer, chunks, chunk_v, embd_mdl, tkweight=0.3, vtweight=0.7):
179
+ pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer)
180
+ for i in range(1, len(pieces)):
181
+ if re.match(r"[a-z][.?;!][ \n]", pieces[i]):
182
+ pieces[i - 1] += pieces[i][0]
183
+ pieces[i] = pieces[i][1:]
184
+ idx = []
185
+ pieces_ = []
186
+ for i, t in enumerate(pieces):
187
+ if len(t) < 5: continue
188
+ idx.append(i)
189
+ pieces_.append(t)
190
+ if not pieces_: return answer
191
 
192
+ ans_v = embd_mdl.encode(pieces_)
193
+ assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
194
+ len(ans_v[0]), len(chunk_v[0]))
 
 
 
195
 
196
+ chunks_tks = [huqie.qie(ck).split(" ") for ck in chunks]
197
+ cites = {}
198
+ for i,a in enumerate(pieces_):
199
+ sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
200
+ chunk_v,
201
+ huqie.qie(pieces_[i]).split(" "),
202
+ chunks_tks,
203
+ tkweight, vtweight)
 
204
  mx = np.max(sim) * 0.99
205
+ if mx < 0.55: continue
206
+ cites[idx[i]] = list(set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
+ res = ""
209
+ for i,p in enumerate(pieces):
210
+ res += p
211
+ if i not in idx:continue
212
+ if i not in cites:continue
213
+ res += "##%s$$"%"$".join(cites[i])
214
 
215
  return res
216
 
217
  def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks"):
218
  ins_embd = [
219
  Dealer.trans2floats(
220
+ sres.field[i]["q_%d_vec" % len(sres.query_vector)]) for i in sres.ids]
221
  if not ins_embd:
222
  return []
223
+ ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ") for i in sres.ids]
224
  sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
225
+ ins_embd,
226
+ huqie.qie(query).split(" "),
227
+ ins_tw, tkweight, vtweight)
228
  return sim, tksim, vtsim
229
 
230
+ def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
231
+ return self.qryr.hybrid_similarity(ans_embd,
232
+ ins_embd,
233
+ huqie.qie(ans).split(" "),
234
+ huqie.qie(inst).split(" "))
235
+
236
+ def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
237
+ vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True):
238
+ req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": top,
239
+ "question": question, "vector": True,
240
+ "similarity": similarity_threshold}
241
+ sres = self.search(req, index_name(tenant_id), embd_mdl)
242
 
243
+ sim, tsim, vsim = self.rerank(
244
+ sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
245
+ idx = np.argsort(sim * -1)
246
+ ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
247
+ dim = len(sres.query_vector)
248
+ start_idx = (page - 1) * page_size
249
+ for i in idx:
250
+ ranks["total"] += 1
251
+ if sim[i] < similarity_threshold:
252
+ break
253
+ start_idx -= 1
254
+ if start_idx >= 0:
255
+ continue
256
+ if len(ranks["chunks"]) == page_size:
257
+ if aggs:
258
+ continue
259
+ break
260
+ id = sres.ids[i]
261
+ dnm = sres.field[id]["docnm_kwd"]
262
+ d = {
263
+ "chunk_id": id,
264
+ "content_ltks": sres.field[id]["content_ltks"],
265
+ "doc_id": sres.field[id]["doc_id"],
266
+ "docnm_kwd": dnm,
267
+ "kb_id": sres.field[id]["kb_id"],
268
+ "important_kwd": sres.field[id].get("important_kwd", []),
269
+ "img_id": sres.field[id].get("img_id", ""),
270
+ "similarity": sim[i],
271
+ "vector_similarity": vsim[i],
272
+ "term_similarity": tsim[i],
273
+ "vector": self.trans2floats(sres.field[id].get("q_%d_vec" % dim, "\t".join(["0"] * dim)))
274
+ }
275
+ ranks["chunks"].append(d)
276
+ if dnm not in ranks["doc_aggs"]:
277
+ ranks["doc_aggs"][dnm] = 0
278
+ ranks["doc_aggs"][dnm] += 1
279
 
280
+ return ranks
rag/utils/__init__.py CHANGED
@@ -58,9 +58,11 @@ def findMaxTm(fnm):
58
  print("WARNING: can't find " + fnm)
59
  return m
60
 
61
-
 
 
62
  def num_tokens_from_string(string: str) -> int:
63
  """Returns the number of tokens in a text string."""
64
- encoding = tiktoken.get_encoding('cl100k_base')
65
- num_tokens = len(encoding.encode(string))
66
- return num_tokens
 
58
  print("WARNING: can't find " + fnm)
59
  return m
60
 
61
+
62
+ encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
63
+
64
  def num_tokens_from_string(string: str) -> int:
65
  """Returns the number of tokens in a text string."""
66
+ num_tokens = len(encoder.encode(string))
67
+ return num_tokens
68
+