KevinHuSh commited on
Commit
79ada0b
·
1 Parent(s): 85b269d

apply pep8 formalize (#155)

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. api/apps/chunk_app.py +14 -5
  2. api/apps/conversation_app.py +81 -43
  3. api/apps/dialog_app.py +30 -13
  4. api/apps/document_app.py +21 -8
  5. api/apps/kb_app.py +39 -17
  6. api/apps/llm_app.py +26 -12
  7. api/apps/user_app.py +57 -29
  8. api/db/db_models.py +243 -64
  9. api/db/db_utils.py +10 -5
  10. api/db/init_data.py +55 -45
  11. api/db/operatioins.py +1 -1
  12. api/db/reload_config_base.py +3 -2
  13. api/db/runtime_config.py +1 -1
  14. api/db/services/common_service.py +26 -12
  15. api/db/services/dialog_service.py +0 -1
  16. api/db/services/document_service.py +49 -12
  17. api/db/services/knowledgebase_service.py +8 -7
  18. api/db/services/llm_service.py +22 -11
  19. api/db/services/user_service.py +32 -13
  20. api/settings.py +38 -16
  21. api/utils/__init__.py +39 -14
  22. api/utils/api_utils.py +43 -17
  23. api/utils/file_utils.py +13 -10
  24. api/utils/log_utils.py +39 -21
  25. api/utils/t_crypt.py +10 -5
  26. deepdoc/parser/__init__.py +0 -2
  27. deepdoc/parser/docx_parser.py +6 -3
  28. deepdoc/parser/excel_parser.py +12 -7
  29. deepdoc/parser/pdf_parser.py +105 -53
  30. deepdoc/parser/ppt_parser.py +13 -7
  31. deepdoc/vision/layout_recognizer.py +29 -24
  32. deepdoc/vision/operators.py +2 -1
  33. deepdoc/vision/t_ocr.py +21 -13
  34. deepdoc/vision/t_recognizer.py +33 -18
  35. deepdoc/vision/table_structure_recognizer.py +13 -10
  36. rag/app/book.py +36 -18
  37. rag/app/laws.py +29 -15
  38. rag/app/manual.py +30 -19
  39. rag/app/naive.py +21 -12
  40. rag/app/one.py +16 -10
  41. rag/app/paper.py +26 -14
  42. rag/app/presentation.py +42 -22
  43. rag/app/resume.py +17 -8
  44. rag/app/table.py +30 -12
  45. rag/llm/chat_model.py +24 -13
  46. rag/llm/cv_model.py +8 -9
  47. rag/llm/embedding_model.py +16 -13
  48. rag/llm/rpc_server.py +20 -7
  49. rag/nlp/huchunk.py +13 -6
  50. rag/nlp/query.py +6 -5
api/apps/chunk_app.py CHANGED
@@ -121,7 +121,9 @@ def get():
121
  "important_kwd")
122
  def set():
123
  req = request.json
124
- d = {"id": req["chunk_id"], "content_with_weight": req["content_with_weight"]}
 
 
125
  d["content_ltks"] = huqie.qie(req["content_with_weight"])
126
  d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
127
  d["important_kwd"] = req["important_kwd"]
@@ -140,10 +142,16 @@ def set():
140
  return get_data_error_result(retmsg="Document not found!")
141
 
142
  if doc.parser_id == ParserType.QA:
143
- arr = [t for t in re.split(r"[\n\t]", req["content_with_weight"]) if len(t) > 1]
144
- if len(arr) != 2: return get_data_error_result(retmsg="Q&A must be separated by TAB/ENTER key.")
 
 
 
 
 
145
  q, a = rmPrefix(arr[0]), rmPrefix[arr[1]]
146
- d = beAdoc(d, arr[0], arr[1], not any([huqie.is_chinese(t) for t in q + a]))
 
147
 
148
  v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
149
  v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
@@ -177,7 +185,8 @@ def switch():
177
  def rm():
178
  req = request.json
179
  try:
180
- if not ELASTICSEARCH.deleteByQuery(Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
 
181
  return get_data_error_result(retmsg="Index updating failure")
182
  return get_json_result(data=True)
183
  except Exception as e:
 
121
  "important_kwd")
122
  def set():
123
  req = request.json
124
+ d = {
125
+ "id": req["chunk_id"],
126
+ "content_with_weight": req["content_with_weight"]}
127
  d["content_ltks"] = huqie.qie(req["content_with_weight"])
128
  d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
129
  d["important_kwd"] = req["important_kwd"]
 
142
  return get_data_error_result(retmsg="Document not found!")
143
 
144
  if doc.parser_id == ParserType.QA:
145
+ arr = [
146
+ t for t in re.split(
147
+ r"[\n\t]",
148
+ req["content_with_weight"]) if len(t) > 1]
149
+ if len(arr) != 2:
150
+ return get_data_error_result(
151
+ retmsg="Q&A must be separated by TAB/ENTER key.")
152
  q, a = rmPrefix(arr[0]), rmPrefix[arr[1]]
153
+ d = beAdoc(d, arr[0], arr[1], not any(
154
+ [huqie.is_chinese(t) for t in q + a]))
155
 
156
  v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
157
  v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
 
185
  def rm():
186
  req = request.json
187
  try:
188
+ if not ELASTICSEARCH.deleteByQuery(
189
+ Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
190
  return get_data_error_result(retmsg="Index updating failure")
191
  return get_json_result(data=True)
192
  except Exception as e:
api/apps/conversation_app.py CHANGED
@@ -100,7 +100,10 @@ def rm():
100
  def list_convsersation():
101
  dialog_id = request.args["dialog_id"]
102
  try:
103
- convs = ConversationService.query(dialog_id=dialog_id, order_by=ConversationService.model.create_time, reverse=True)
 
 
 
104
  convs = [d.to_dict() for d in convs]
105
  return get_json_result(data=convs)
106
  except Exception as e:
@@ -111,19 +114,24 @@ def message_fit_in(msg, max_length=4000):
111
  def count():
112
  nonlocal msg
113
  tks_cnts = []
114
- for m in msg: tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
 
 
115
  total = 0
116
- for m in tks_cnts: total += m["count"]
 
117
  return total
118
 
119
  c = count()
120
- if c < max_length: return c, msg
 
121
 
122
  msg_ = [m for m in msg[:-1] if m.role == "system"]
123
  msg_.append(msg[-1])
124
  msg = msg_
125
  c = count()
126
- if c < max_length: return c, msg
 
127
 
128
  ll = num_tokens_from_string(msg_[0].content)
129
  l = num_tokens_from_string(msg_[-1].content)
@@ -146,8 +154,10 @@ def completion():
146
  req = request.json
147
  msg = []
148
  for m in req["messages"]:
149
- if m["role"] == "system": continue
150
- if m["role"] == "assistant" and not msg: continue
 
 
151
  msg.append({"role": m["role"], "content": m["content"]})
152
  try:
153
  e, conv = ConversationService.get_by_id(req["conversation_id"])
@@ -160,7 +170,8 @@ def completion():
160
  del req["conversation_id"]
161
  del req["messages"]
162
  ans = chat(dia, msg, **req)
163
- if not conv.reference: conv.reference = []
 
164
  conv.reference.append(ans["reference"])
165
  conv.message.append({"role": "assistant", "content": ans["answer"]})
166
  ConversationService.update_by_id(conv.id, conv.to_dict())
@@ -180,52 +191,67 @@ def chat(dialog, messages, **kwargs):
180
  chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
181
 
182
  field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
183
- ## try to use sql if field mapping is good to go
184
  if field_map:
185
  chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
186
  return use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)
187
 
188
  prompt_config = dialog.prompt_config
189
  for p in prompt_config["parameters"]:
190
- if p["key"] == "knowledge": continue
191
- if p["key"] not in kwargs and not p["optional"]: raise KeyError("Miss parameter: " + p["key"])
 
 
192
  if p["key"] not in kwargs:
193
- prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
 
194
 
195
- for _ in range(len(questions)//2):
196
  questions.append(questions[-1])
197
  if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
198
- kbinfos = {"total":0, "chunks":[],"doc_aggs":[]}
199
  else:
200
  kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
201
- dialog.similarity_threshold,
202
- dialog.vector_similarity_weight, top=1024, aggs=False)
203
  knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
204
- chat_logger.info("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
 
205
 
206
  if not knowledges and prompt_config.get("empty_response"):
207
- return {"answer": prompt_config["empty_response"], "reference": kbinfos}
 
208
 
209
  kwargs["knowledge"] = "\n".join(knowledges)
210
  gen_conf = dialog.llm_setting
211
- msg = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"]
 
212
  used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
213
  if "max_tokens" in gen_conf:
214
- gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
215
- answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
216
- chat_logger.info("User: {}|Assistant: {}".format(msg[-1]["content"], answer))
 
 
 
 
 
217
 
218
  if knowledges:
219
  answer, idx = retrievaler.insert_citations(answer,
220
- [ck["content_ltks"] for ck in kbinfos["chunks"]],
221
- [ck["vector"] for ck in kbinfos["chunks"]],
222
- embd_mdl,
223
- tkweight=1 - dialog.vector_similarity_weight,
224
- vtweight=dialog.vector_similarity_weight)
 
 
225
  idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
226
- kbinfos["doc_aggs"] = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
 
227
  for c in kbinfos["chunks"]:
228
- if c.get("vector"): del c["vector"]
 
229
  return {"answer": answer, "reference": kbinfos}
230
 
231
 
@@ -245,9 +271,11 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
245
  question
246
  )
247
  tried_times = 0
 
248
  def get_table():
249
  nonlocal sys_prompt, user_promt, question, tried_times
250
- sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
 
251
  print(user_promt, sql)
252
  chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
253
  sql = re.sub(r"[\r\n]+", " ", sql.lower())
@@ -262,8 +290,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
262
  else:
263
  flds = []
264
  for k in field_map.keys():
265
- if k in forbidden_select_fields4resume:continue
266
- if len(flds) > 11:break
 
 
267
  flds.append(k)
268
  sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
269
 
@@ -284,13 +314,13 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
284
 
285
  问题如下:
286
  {}
287
-
288
  你上一次给出的错误SQL如下:
289
  {}
290
-
291
  后台报错如下:
292
  {}
293
-
294
  请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
295
  """.format(
296
  index_name(tenant_id),
@@ -302,16 +332,24 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
302
 
303
  chat_logger.info("GET table: {}".format(tbl))
304
  print(tbl)
305
- if tbl.get("error") or len(tbl["rows"]) == 0: return None, None
 
306
 
307
- docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
308
- docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
309
- clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
 
 
 
310
 
311
  # compose markdown table
312
- clmns = "|"+"|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
313
- line = "|"+"|".join(["------" for _ in range(len(clmn_idx))]) + ("|------|" if docid_idx and docid_idx else "")
314
- rows = ["|"+"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
 
 
 
 
315
  if not docid_idx or not docnm_idx:
316
  chat_logger.warning("SQL missing field: " + sql)
317
  return "\n".join([clmns, line, "\n".join(rows)]), []
@@ -328,5 +366,5 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
328
  return {
329
  "answer": "\n".join([clmns, line, rows]),
330
  "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
331
- "doc_aggs":[{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
332
  }
 
100
  def list_convsersation():
101
  dialog_id = request.args["dialog_id"]
102
  try:
103
+ convs = ConversationService.query(
104
+ dialog_id=dialog_id,
105
+ order_by=ConversationService.model.create_time,
106
+ reverse=True)
107
  convs = [d.to_dict() for d in convs]
108
  return get_json_result(data=convs)
109
  except Exception as e:
 
114
  def count():
115
  nonlocal msg
116
  tks_cnts = []
117
+ for m in msg:
118
+ tks_cnts.append(
119
+ {"role": m["role"], "count": num_tokens_from_string(m["content"])})
120
  total = 0
121
+ for m in tks_cnts:
122
+ total += m["count"]
123
  return total
124
 
125
  c = count()
126
+ if c < max_length:
127
+ return c, msg
128
 
129
  msg_ = [m for m in msg[:-1] if m.role == "system"]
130
  msg_.append(msg[-1])
131
  msg = msg_
132
  c = count()
133
+ if c < max_length:
134
+ return c, msg
135
 
136
  ll = num_tokens_from_string(msg_[0].content)
137
  l = num_tokens_from_string(msg_[-1].content)
 
154
  req = request.json
155
  msg = []
156
  for m in req["messages"]:
157
+ if m["role"] == "system":
158
+ continue
159
+ if m["role"] == "assistant" and not msg:
160
+ continue
161
  msg.append({"role": m["role"], "content": m["content"]})
162
  try:
163
  e, conv = ConversationService.get_by_id(req["conversation_id"])
 
170
  del req["conversation_id"]
171
  del req["messages"]
172
  ans = chat(dia, msg, **req)
173
+ if not conv.reference:
174
+ conv.reference = []
175
  conv.reference.append(ans["reference"])
176
  conv.message.append({"role": "assistant", "content": ans["answer"]})
177
  ConversationService.update_by_id(conv.id, conv.to_dict())
 
191
  chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
192
 
193
  field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
194
+ # try to use sql if field mapping is good to go
195
  if field_map:
196
  chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
197
  return use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)
198
 
199
  prompt_config = dialog.prompt_config
200
  for p in prompt_config["parameters"]:
201
+ if p["key"] == "knowledge":
202
+ continue
203
+ if p["key"] not in kwargs and not p["optional"]:
204
+ raise KeyError("Miss parameter: " + p["key"])
205
  if p["key"] not in kwargs:
206
+ prompt_config["system"] = prompt_config["system"].replace(
207
+ "{%s}" % p["key"], " ")
208
 
209
+ for _ in range(len(questions) // 2):
210
  questions.append(questions[-1])
211
  if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
212
+ kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
213
  else:
214
  kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
215
+ dialog.similarity_threshold,
216
+ dialog.vector_similarity_weight, top=1024, aggs=False)
217
  knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
218
+ chat_logger.info(
219
+ "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
220
 
221
  if not knowledges and prompt_config.get("empty_response"):
222
+ return {
223
+ "answer": prompt_config["empty_response"], "reference": kbinfos}
224
 
225
  kwargs["knowledge"] = "\n".join(knowledges)
226
  gen_conf = dialog.llm_setting
227
+ msg = [{"role": m["role"], "content": m["content"]}
228
+ for m in messages if m["role"] != "system"]
229
  used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
230
  if "max_tokens" in gen_conf:
231
+ gen_conf["max_tokens"] = min(
232
+ gen_conf["max_tokens"],
233
+ llm.max_tokens - used_token_count)
234
+ answer = chat_mdl.chat(
235
+ prompt_config["system"].format(
236
+ **kwargs), msg, gen_conf)
237
+ chat_logger.info("User: {}|Assistant: {}".format(
238
+ msg[-1]["content"], answer))
239
 
240
  if knowledges:
241
  answer, idx = retrievaler.insert_citations(answer,
242
+ [ck["content_ltks"]
243
+ for ck in kbinfos["chunks"]],
244
+ [ck["vector"]
245
+ for ck in kbinfos["chunks"]],
246
+ embd_mdl,
247
+ tkweight=1 - dialog.vector_similarity_weight,
248
+ vtweight=dialog.vector_similarity_weight)
249
  idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
250
+ kbinfos["doc_aggs"] = [
251
+ d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
252
  for c in kbinfos["chunks"]:
253
+ if c.get("vector"):
254
+ del c["vector"]
255
  return {"answer": answer, "reference": kbinfos}
256
 
257
 
 
271
  question
272
  )
273
  tried_times = 0
274
+
275
  def get_table():
276
  nonlocal sys_prompt, user_promt, question, tried_times
277
+ sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
278
+ "temperature": 0.06})
279
  print(user_promt, sql)
280
  chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
281
  sql = re.sub(r"[\r\n]+", " ", sql.lower())
 
290
  else:
291
  flds = []
292
  for k in field_map.keys():
293
+ if k in forbidden_select_fields4resume:
294
+ continue
295
+ if len(flds) > 11:
296
+ break
297
  flds.append(k)
298
  sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
299
 
 
314
 
315
  问题如下:
316
  {}
317
+
318
  你上一次给出的错误SQL如下:
319
  {}
320
+
321
  后台报错如下:
322
  {}
323
+
324
  请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
325
  """.format(
326
  index_name(tenant_id),
 
332
 
333
  chat_logger.info("GET table: {}".format(tbl))
334
  print(tbl)
335
+ if tbl.get("error") or len(tbl["rows"]) == 0:
336
+ return None, None
337
 
338
+ docid_idx = set([ii for ii, c in enumerate(
339
+ tbl["columns"]) if c["name"] == "doc_id"])
340
+ docnm_idx = set([ii for ii, c in enumerate(
341
+ tbl["columns"]) if c["name"] == "docnm_kwd"])
342
+ clmn_idx = [ii for ii in range(
343
+ len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
344
 
345
  # compose markdown table
346
+ clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
347
+ tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
348
+ line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
349
+ ("|------|" if docid_idx and docid_idx else "")
350
+ rows = ["|" +
351
+ "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
352
+ "|" for r in tbl["rows"]]
353
  if not docid_idx or not docnm_idx:
354
  chat_logger.warning("SQL missing field: " + sql)
355
  return "\n".join([clmns, line, "\n".join(rows)]), []
 
366
  return {
367
  "answer": "\n".join([clmns, line, rows]),
368
  "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
369
+ "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
370
  }
api/apps/dialog_app.py CHANGED
@@ -55,7 +55,8 @@ def set_dialog():
55
  }
56
  prompt_config = req.get("prompt_config", default_prompt)
57
 
58
- if not prompt_config["system"]: prompt_config["system"] = default_prompt["system"]
 
59
  # if len(prompt_config["parameters"]) < 1:
60
  # prompt_config["parameters"] = default_prompt["parameters"]
61
  # for p in prompt_config["parameters"]:
@@ -63,16 +64,21 @@ def set_dialog():
63
  # else: prompt_config["parameters"].append(default_prompt["parameters"][0])
64
 
65
  for p in prompt_config["parameters"]:
66
- if p["optional"]: continue
 
67
  if prompt_config["system"].find("{%s}" % p["key"]) < 0:
68
- return get_data_error_result(retmsg="Parameter '{}' is not used".format(p["key"]))
 
69
 
70
  try:
71
  e, tenant = TenantService.get_by_id(current_user.id)
72
- if not e: return get_data_error_result(retmsg="Tenant not found!")
 
73
  llm_id = req.get("llm_id", tenant.llm_id)
74
  if not dialog_id:
75
- if not req.get("kb_ids"):return get_data_error_result(retmsg="Fail! Please select knowledgebase!")
 
 
76
  dia = {
77
  "id": get_uuid(),
78
  "tenant_id": current_user.id,
@@ -86,17 +92,21 @@ def set_dialog():
86
  "similarity_threshold": similarity_threshold,
87
  "vector_similarity_weight": vector_similarity_weight
88
  }
89
- if not DialogService.save(**dia): return get_data_error_result(retmsg="Fail to new a dialog!")
 
90
  e, dia = DialogService.get_by_id(dia["id"])
91
- if not e: return get_data_error_result(retmsg="Fail to new a dialog!")
 
92
  return get_json_result(data=dia.to_json())
93
  else:
94
  del req["dialog_id"]
95
- if "kb_names" in req: del req["kb_names"]
 
96
  if not DialogService.update_by_id(dialog_id, req):
97
  return get_data_error_result(retmsg="Dialog not found!")
98
  e, dia = DialogService.get_by_id(dialog_id)
99
- if not e: return get_data_error_result(retmsg="Fail to update a dialog!")
 
100
  dia = dia.to_dict()
101
  dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
102
  return get_json_result(data=dia)
@@ -110,7 +120,8 @@ def get():
110
  dialog_id = request.args["dialog_id"]
111
  try:
112
  e, dia = DialogService.get_by_id(dialog_id)
113
- if not e: return get_data_error_result(retmsg="Dialog not found!")
 
114
  dia = dia.to_dict()
115
  dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
116
  return get_json_result(data=dia)
@@ -122,7 +133,8 @@ def get_kb_names(kb_ids):
122
  ids, nms = [], []
123
  for kid in kb_ids:
124
  e, kb = KnowledgebaseService.get_by_id(kid)
125
- if not e or kb.status != StatusEnum.VALID.value: continue
 
126
  ids.append(kid)
127
  nms.append(kb.name)
128
  return ids, nms
@@ -132,7 +144,11 @@ def get_kb_names(kb_ids):
132
  @login_required
133
  def list():
134
  try:
135
- diags = DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value, reverse=True, order_by=DialogService.model.create_time)
 
 
 
 
136
  diags = [d.to_dict() for d in diags]
137
  for d in diags:
138
  d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
@@ -147,7 +163,8 @@ def list():
147
  def rm():
148
  req = request.json
149
  try:
150
- DialogService.update_many_by_id([{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
 
151
  return get_json_result(data=True)
152
  except Exception as e:
153
  return server_error_response(e)
 
55
  }
56
  prompt_config = req.get("prompt_config", default_prompt)
57
 
58
+ if not prompt_config["system"]:
59
+ prompt_config["system"] = default_prompt["system"]
60
  # if len(prompt_config["parameters"]) < 1:
61
  # prompt_config["parameters"] = default_prompt["parameters"]
62
  # for p in prompt_config["parameters"]:
 
64
  # else: prompt_config["parameters"].append(default_prompt["parameters"][0])
65
 
66
  for p in prompt_config["parameters"]:
67
+ if p["optional"]:
68
+ continue
69
  if prompt_config["system"].find("{%s}" % p["key"]) < 0:
70
+ return get_data_error_result(
71
+ retmsg="Parameter '{}' is not used".format(p["key"]))
72
 
73
  try:
74
  e, tenant = TenantService.get_by_id(current_user.id)
75
+ if not e:
76
+ return get_data_error_result(retmsg="Tenant not found!")
77
  llm_id = req.get("llm_id", tenant.llm_id)
78
  if not dialog_id:
79
+ if not req.get("kb_ids"):
80
+ return get_data_error_result(
81
+ retmsg="Fail! Please select knowledgebase!")
82
  dia = {
83
  "id": get_uuid(),
84
  "tenant_id": current_user.id,
 
92
  "similarity_threshold": similarity_threshold,
93
  "vector_similarity_weight": vector_similarity_weight
94
  }
95
+ if not DialogService.save(**dia):
96
+ return get_data_error_result(retmsg="Fail to new a dialog!")
97
  e, dia = DialogService.get_by_id(dia["id"])
98
+ if not e:
99
+ return get_data_error_result(retmsg="Fail to new a dialog!")
100
  return get_json_result(data=dia.to_json())
101
  else:
102
  del req["dialog_id"]
103
+ if "kb_names" in req:
104
+ del req["kb_names"]
105
  if not DialogService.update_by_id(dialog_id, req):
106
  return get_data_error_result(retmsg="Dialog not found!")
107
  e, dia = DialogService.get_by_id(dialog_id)
108
+ if not e:
109
+ return get_data_error_result(retmsg="Fail to update a dialog!")
110
  dia = dia.to_dict()
111
  dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
112
  return get_json_result(data=dia)
 
120
  dialog_id = request.args["dialog_id"]
121
  try:
122
  e, dia = DialogService.get_by_id(dialog_id)
123
+ if not e:
124
+ return get_data_error_result(retmsg="Dialog not found!")
125
  dia = dia.to_dict()
126
  dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
127
  return get_json_result(data=dia)
 
133
  ids, nms = [], []
134
  for kid in kb_ids:
135
  e, kb = KnowledgebaseService.get_by_id(kid)
136
+ if not e or kb.status != StatusEnum.VALID.value:
137
+ continue
138
  ids.append(kid)
139
  nms.append(kb.name)
140
  return ids, nms
 
144
  @login_required
145
  def list():
146
  try:
147
+ diags = DialogService.query(
148
+ tenant_id=current_user.id,
149
+ status=StatusEnum.VALID.value,
150
+ reverse=True,
151
+ order_by=DialogService.model.create_time)
152
  diags = [d.to_dict() for d in diags]
153
  for d in diags:
154
  d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
 
163
  def rm():
164
  req = request.json
165
  try:
166
+ DialogService.update_many_by_id(
167
+ [{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
168
  return get_json_result(data=True)
169
  except Exception as e:
170
  return server_error_response(e)
api/apps/document_app.py CHANGED
@@ -57,6 +57,9 @@ def upload():
57
  if not e:
58
  return get_data_error_result(
59
  retmsg="Can't find this knowledgebase!")
 
 
 
60
 
61
  filename = duplicate_name(
62
  DocumentService.query,
@@ -215,9 +218,11 @@ def rm():
215
  tenant_id = DocumentService.get_tenant_id(req["doc_id"])
216
  if not tenant_id:
217
  return get_data_error_result(retmsg="Tenant not found!")
218
- ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
 
219
 
220
- DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, 0)
 
221
  if not DocumentService.delete(doc):
222
  return get_data_error_result(
223
  retmsg="Database error (Document removal)!")
@@ -245,7 +250,8 @@ def run():
245
  tenant_id = DocumentService.get_tenant_id(id)
246
  if not tenant_id:
247
  return get_data_error_result(retmsg="Tenant not found!")
248
- ELASTICSEARCH.deleteByQuery(Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
 
249
 
250
  return get_json_result(data=True)
251
  except Exception as e:
@@ -261,7 +267,8 @@ def rename():
261
  e, doc = DocumentService.get_by_id(req["doc_id"])
262
  if not e:
263
  return get_data_error_result(retmsg="Document not found!")
264
- if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
 
265
  return get_json_result(
266
  data=False,
267
  retmsg="The extension of file can't be changed",
@@ -294,7 +301,10 @@ def get(doc_id):
294
  if doc.type == FileType.VISUAL.value:
295
  response.headers.set('Content-Type', 'image/%s' % ext.group(1))
296
  else:
297
- response.headers.set('Content-Type', 'application/%s' % ext.group(1))
 
 
 
298
  return response
299
  except Exception as e:
300
  return server_error_response(e)
@@ -313,9 +323,11 @@ def change_parser():
313
  if "parser_config" in req:
314
  if req["parser_config"] == doc.parser_config:
315
  return get_json_result(data=True)
316
- else: return get_json_result(data=True)
 
317
 
318
- if doc.type == FileType.VISUAL or re.search(r"\.(ppt|pptx|pages)$", doc.name):
 
319
  return get_data_error_result(retmsg="Not supported yet!")
320
 
321
  e = DocumentService.update_by_id(doc.id,
@@ -332,7 +344,8 @@ def change_parser():
332
  tenant_id = DocumentService.get_tenant_id(req["doc_id"])
333
  if not tenant_id:
334
  return get_data_error_result(retmsg="Tenant not found!")
335
- ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
 
336
 
337
  return get_json_result(data=True)
338
  except Exception as e:
 
57
  if not e:
58
  return get_data_error_result(
59
  retmsg="Can't find this knowledgebase!")
60
+ if DocumentService.get_doc_count(kb.tenant_id) >= 128:
61
+ return get_data_error_result(
62
+ retmsg="Exceed the maximum file number of a free user!")
63
 
64
  filename = duplicate_name(
65
  DocumentService.query,
 
218
  tenant_id = DocumentService.get_tenant_id(req["doc_id"])
219
  if not tenant_id:
220
  return get_data_error_result(retmsg="Tenant not found!")
221
+ ELASTICSEARCH.deleteByQuery(
222
+ Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
223
 
224
+ DocumentService.increment_chunk_num(
225
+ doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, 0)
226
  if not DocumentService.delete(doc):
227
  return get_data_error_result(
228
  retmsg="Database error (Document removal)!")
 
250
  tenant_id = DocumentService.get_tenant_id(id)
251
  if not tenant_id:
252
  return get_data_error_result(retmsg="Tenant not found!")
253
+ ELASTICSEARCH.deleteByQuery(
254
+ Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
255
 
256
  return get_json_result(data=True)
257
  except Exception as e:
 
267
  e, doc = DocumentService.get_by_id(req["doc_id"])
268
  if not e:
269
  return get_data_error_result(retmsg="Document not found!")
270
+ if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
271
+ doc.name.lower()).suffix:
272
  return get_json_result(
273
  data=False,
274
  retmsg="The extension of file can't be changed",
 
301
  if doc.type == FileType.VISUAL.value:
302
  response.headers.set('Content-Type', 'image/%s' % ext.group(1))
303
  else:
304
+ response.headers.set(
305
+ 'Content-Type',
306
+ 'application/%s' %
307
+ ext.group(1))
308
  return response
309
  except Exception as e:
310
  return server_error_response(e)
 
323
  if "parser_config" in req:
324
  if req["parser_config"] == doc.parser_config:
325
  return get_json_result(data=True)
326
+ else:
327
+ return get_json_result(data=True)
328
 
329
+ if doc.type == FileType.VISUAL or re.search(
330
+ r"\.(ppt|pptx|pages)$", doc.name):
331
  return get_data_error_result(retmsg="Not supported yet!")
332
 
333
  e = DocumentService.update_by_id(doc.id,
 
344
  tenant_id = DocumentService.get_tenant_id(req["doc_id"])
345
  if not tenant_id:
346
  return get_data_error_result(retmsg="Tenant not found!")
347
+ ELASTICSEARCH.deleteByQuery(
348
+ Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
349
 
350
  return get_json_result(data=True)
351
  except Exception as e:
api/apps/kb_app.py CHANGED
@@ -33,15 +33,21 @@ from api.utils.api_utils import get_json_result
33
  def create():
34
  req = request.json
35
  req["name"] = req["name"].strip()
36
- req["name"] = duplicate_name(KnowledgebaseService.query, name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)
 
 
 
 
37
  try:
38
  req["id"] = get_uuid()
39
  req["tenant_id"] = current_user.id
40
  req["created_by"] = current_user.id
41
  e, t = TenantService.get_by_id(current_user.id)
42
- if not e: return get_data_error_result(retmsg="Tenant not found.")
 
43
  req["embd_id"] = t.embd_id
44
- if not KnowledgebaseService.save(**req): return get_data_error_result()
 
45
  return get_json_result(data={"kb_id": req["id"]})
46
  except Exception as e:
47
  return server_error_response(e)
@@ -54,21 +60,29 @@ def update():
54
  req = request.json
55
  req["name"] = req["name"].strip()
56
  try:
57
- if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
58
- return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
 
 
59
 
60
  e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
61
- if not e: return get_data_error_result(retmsg="Can't find this knowledgebase!")
 
 
62
 
63
  if req["name"].lower() != kb.name.lower() \
64
- and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value))>1:
65
- return get_data_error_result(retmsg="Duplicated knowledgebase name.")
 
66
 
67
  del req["kb_id"]
68
- if not KnowledgebaseService.update_by_id(kb.id, req): return get_data_error_result()
 
69
 
70
  e, kb = KnowledgebaseService.get_by_id(kb.id)
71
- if not e: return get_data_error_result(retmsg="Database error (Knowledgebase rename)!")
 
 
72
 
73
  return get_json_result(data=kb.to_json())
74
  except Exception as e:
@@ -81,7 +95,9 @@ def detail():
81
  kb_id = request.args["kb_id"]
82
  try:
83
  kb = KnowledgebaseService.get_detail(kb_id)
84
- if not kb: return get_data_error_result(retmsg="Can't find this knowledgebase!")
 
 
85
  return get_json_result(data=kb)
86
  except Exception as e:
87
  return server_error_response(e)
@@ -96,7 +112,8 @@ def list():
96
  desc = request.args.get("desc", True)
97
  try:
98
  tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
99
- kbs = KnowledgebaseService.get_by_tenant_ids([m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc)
 
100
  return get_json_result(data=kbs)
101
  except Exception as e:
102
  return server_error_response(e)
@@ -108,10 +125,15 @@ def list():
108
  def rm():
109
  req = request.json
110
  try:
111
- if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
112
- return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
113
-
114
- if not KnowledgebaseService.update_by_id(req["kb_id"], {"status": StatusEnum.INVALID.value}): return get_data_error_result(retmsg="Database error (Knowledgebase removal)!")
 
 
 
 
 
115
  return get_json_result(data=True)
116
  except Exception as e:
117
- return server_error_response(e)
 
33
  def create():
34
  req = request.json
35
  req["name"] = req["name"].strip()
36
+ req["name"] = duplicate_name(
37
+ KnowledgebaseService.query,
38
+ name=req["name"],
39
+ tenant_id=current_user.id,
40
+ status=StatusEnum.VALID.value)
41
  try:
42
  req["id"] = get_uuid()
43
  req["tenant_id"] = current_user.id
44
  req["created_by"] = current_user.id
45
  e, t = TenantService.get_by_id(current_user.id)
46
+ if not e:
47
+ return get_data_error_result(retmsg="Tenant not found.")
48
  req["embd_id"] = t.embd_id
49
+ if not KnowledgebaseService.save(**req):
50
+ return get_data_error_result()
51
  return get_json_result(data={"kb_id": req["id"]})
52
  except Exception as e:
53
  return server_error_response(e)
 
60
  req = request.json
61
  req["name"] = req["name"].strip()
62
  try:
63
+ if not KnowledgebaseService.query(
64
+ created_by=current_user.id, id=req["kb_id"]):
65
+ return get_json_result(
66
+ data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
67
 
68
  e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
69
+ if not e:
70
+ return get_data_error_result(
71
+ retmsg="Can't find this knowledgebase!")
72
 
73
  if req["name"].lower() != kb.name.lower() \
74
+ and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) > 1:
75
+ return get_data_error_result(
76
+ retmsg="Duplicated knowledgebase name.")
77
 
78
  del req["kb_id"]
79
+ if not KnowledgebaseService.update_by_id(kb.id, req):
80
+ return get_data_error_result()
81
 
82
  e, kb = KnowledgebaseService.get_by_id(kb.id)
83
+ if not e:
84
+ return get_data_error_result(
85
+ retmsg="Database error (Knowledgebase rename)!")
86
 
87
  return get_json_result(data=kb.to_json())
88
  except Exception as e:
 
95
  kb_id = request.args["kb_id"]
96
  try:
97
  kb = KnowledgebaseService.get_detail(kb_id)
98
+ if not kb:
99
+ return get_data_error_result(
100
+ retmsg="Can't find this knowledgebase!")
101
  return get_json_result(data=kb)
102
  except Exception as e:
103
  return server_error_response(e)
 
112
  desc = request.args.get("desc", True)
113
  try:
114
  tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
115
+ kbs = KnowledgebaseService.get_by_tenant_ids(
116
+ [m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc)
117
  return get_json_result(data=kbs)
118
  except Exception as e:
119
  return server_error_response(e)
 
125
  def rm():
126
  req = request.json
127
  try:
128
+ if not KnowledgebaseService.query(
129
+ created_by=current_user.id, id=req["kb_id"]):
130
+ return get_json_result(
131
+ data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
132
+
133
+ if not KnowledgebaseService.update_by_id(
134
+ req["kb_id"], {"status": StatusEnum.INVALID.value}):
135
+ return get_data_error_result(
136
+ retmsg="Database error (Knowledgebase removal)!")
137
  return get_json_result(data=True)
138
  except Exception as e:
139
+ return server_error_response(e)
api/apps/llm_app.py CHANGED
@@ -48,30 +48,42 @@ def set_api_key():
48
  req["api_key"], llm.llm_name)
49
  try:
50
  arr, tc = mdl.encode(["Test if the api key is available"])
51
- if len(arr[0]) == 0 or tc ==0: raise Exception("Fail")
 
52
  except Exception as e:
53
  msg += f"\nFail to access embedding model({llm.llm_name}) using this api key."
54
  elif not chat_passed and llm.model_type == LLMType.CHAT.value:
55
  mdl = ChatModel[factory](
56
  req["api_key"], llm.llm_name)
57
  try:
58
- m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
59
- if not tc: raise Exception(m)
 
 
60
  chat_passed = True
61
  except Exception as e:
62
- msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(e)
 
63
 
64
- if msg: return get_data_error_result(retmsg=msg)
 
65
 
66
  llm = {
67
  "api_key": req["api_key"]
68
  }
69
  for n in ["model_type", "llm_name"]:
70
- if n in req: llm[n] = req[n]
 
71
 
72
- if not TenantLLMService.filter_update([TenantLLM.tenant_id==current_user.id, TenantLLM.llm_factory==factory], llm):
 
73
  for llm in LLMService.query(fid=factory):
74
- TenantLLMService.save(tenant_id=current_user.id, llm_factory=factory, llm_name=llm.llm_name, model_type=llm.model_type, api_key=req["api_key"])
 
 
 
 
 
75
 
76
  return get_json_result(data=True)
77
 
@@ -105,17 +117,19 @@ def list():
105
  objs = TenantLLMService.query(tenant_id=current_user.id)
106
  facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
107
  llms = LLMService.get_all()
108
- llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
 
109
  for m in llms:
110
  m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"
111
 
112
  res = {}
113
  for m in llms:
114
- if model_type and m["model_type"] != model_type: continue
115
- if m["fid"] not in res: res[m["fid"]] = []
 
 
116
  res[m["fid"]].append(m)
117
 
118
  return get_json_result(data=res)
119
  except Exception as e:
120
  return server_error_response(e)
121
-
 
48
  req["api_key"], llm.llm_name)
49
  try:
50
  arr, tc = mdl.encode(["Test if the api key is available"])
51
+ if len(arr[0]) == 0 or tc == 0:
52
+ raise Exception("Fail")
53
  except Exception as e:
54
  msg += f"\nFail to access embedding model({llm.llm_name}) using this api key."
55
  elif not chat_passed and llm.model_type == LLMType.CHAT.value:
56
  mdl = ChatModel[factory](
57
  req["api_key"], llm.llm_name)
58
  try:
59
+ m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
60
+ "temperature": 0.9})
61
+ if not tc:
62
+ raise Exception(m)
63
  chat_passed = True
64
  except Exception as e:
65
+ msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
66
+ e)
67
 
68
+ if msg:
69
+ return get_data_error_result(retmsg=msg)
70
 
71
  llm = {
72
  "api_key": req["api_key"]
73
  }
74
  for n in ["model_type", "llm_name"]:
75
+ if n in req:
76
+ llm[n] = req[n]
77
 
78
+ if not TenantLLMService.filter_update(
79
+ [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory], llm):
80
  for llm in LLMService.query(fid=factory):
81
+ TenantLLMService.save(
82
+ tenant_id=current_user.id,
83
+ llm_factory=factory,
84
+ llm_name=llm.llm_name,
85
+ model_type=llm.model_type,
86
+ api_key=req["api_key"])
87
 
88
  return get_json_result(data=True)
89
 
 
117
  objs = TenantLLMService.query(tenant_id=current_user.id)
118
  facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
119
  llms = LLMService.get_all()
120
+ llms = [m.to_dict()
121
+ for m in llms if m.status == StatusEnum.VALID.value]
122
  for m in llms:
123
  m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"
124
 
125
  res = {}
126
  for m in llms:
127
+ if model_type and m["model_type"] != model_type:
128
+ continue
129
+ if m["fid"] not in res:
130
+ res[m["fid"]] = []
131
  res[m["fid"]].append(m)
132
 
133
  return get_json_result(data=res)
134
  except Exception as e:
135
  return server_error_response(e)
 
api/apps/user_app.py CHANGED
@@ -40,13 +40,16 @@ def login():
40
 
41
  email = request.json.get('email', "")
42
  users = UserService.query(email=email)
43
- if not users: return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
 
 
44
 
45
  password = request.json.get('password')
46
  try:
47
  password = decrypt(password)
48
- except:
49
- return get_json_result(data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password')
 
50
 
51
  user = UserService.query_user(email, password)
52
  if user:
@@ -57,7 +60,8 @@ def login():
57
  msg = "Welcome back!"
58
  return cors_reponse(data=response_data, auth=user.get_id(), retmsg=msg)
59
  else:
60
- return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Email and Password do not match!')
 
61
 
62
 
63
  @manager.route('/github_callback', methods=['GET'])
@@ -65,7 +69,7 @@ def github_callback():
65
  import requests
66
  res = requests.post(GITHUB_OAUTH.get("url"), data={
67
  "client_id": GITHUB_OAUTH.get("client_id"),
68
- "client_secret": GITHUB_OAUTH.get("secret_key"),
69
  "code": request.args.get('code')
70
  }, headers={"Accept": "application/json"})
71
  res = res.json()
@@ -96,15 +100,17 @@ def github_callback():
96
  "last_login_time": get_format_time(),
97
  "is_superuser": False,
98
  })
99
- if not users: raise Exception('Register user failure.')
100
- if len(users) > 1: raise Exception('Same E-mail exist!')
 
 
101
  user = users[0]
102
  login_user(user)
103
- return redirect("/?auth=%s"%user.get_id())
104
  except Exception as e:
105
  rollback_user_registration(user_id)
106
  stat_logger.exception(e)
107
- return redirect("/?error=%s"%str(e))
108
  user = users[0]
109
  user.access_token = get_uuid()
110
  login_user(user)
@@ -114,11 +120,18 @@ def github_callback():
114
 
115
  def user_info_from_github(access_token):
116
  import requests
117
- headers = {"Accept": "application/json", 'Authorization': f"token {access_token}"}
118
- res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
 
 
 
119
  user_info = res.json()
120
- email_info = requests.get(f"https://api.github.com/user/emails?access_token={access_token}", headers=headers).json()
121
- user_info["email"] = next((email for email in email_info if email['primary'] == True), None)["email"]
 
 
 
 
122
  return user_info
123
 
124
 
@@ -138,13 +151,18 @@ def setting_user():
138
  request_data = request.json
139
  if request_data.get("password"):
140
  new_password = request_data.get("new_password")
141
- if not check_password_hash(current_user.password, decrypt(request_data["password"])):
142
- return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')
 
 
143
 
144
- if new_password: update_dict["password"] = generate_password_hash(decrypt(new_password))
 
 
145
 
146
  for k in request_data.keys():
147
- if k in ["password", "new_password"]:continue
 
148
  update_dict[k] = request_data[k]
149
 
150
  try:
@@ -152,7 +170,8 @@ def setting_user():
152
  return get_json_result(data=True)
153
  except Exception as e:
154
  stat_logger.exception(e)
155
- return get_json_result(data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)
 
156
 
157
 
158
  @manager.route("/info", methods=["GET"])
@@ -173,11 +192,11 @@ def rollback_user_registration(user_id):
173
  except Exception as e:
174
  pass
175
  try:
176
- TenantLLM.delete().where(TenantLLM.tenant_id==user_id).excute()
177
  except Exception as e:
178
  pass
179
 
180
-
181
  def user_register(user_id, user):
182
  user["id"] = user_id
183
  tenant = {
@@ -197,9 +216,14 @@ def user_register(user_id, user):
197
  }
198
  tenant_llm = []
199
  for llm in LLMService.query(fid=LLM_FACTORY):
200
- tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
201
-
202
- if not UserService.save(**user):return
 
 
 
 
 
203
  TenantService.insert(**tenant)
204
  UserTenantService.insert(**usr_tenant)
205
  TenantLLMService.insert_many(tenant_llm)
@@ -211,7 +235,8 @@ def user_register(user_id, user):
211
  def user_add():
212
  req = request.json
213
  if UserService.query(email=req["email"]):
214
- return get_json_result(data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
 
215
  if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,4}$", req["email"]):
216
  return get_json_result(data=False, retmsg=f'Invaliad e-mail: {req["email"]}!',
217
  retcode=RetCode.OPERATING_ERROR)
@@ -229,16 +254,19 @@ def user_add():
229
  user_id = get_uuid()
230
  try:
231
  users = user_register(user_id, user_dict)
232
- if not users: raise Exception('Register user failure.')
233
- if len(users) > 1: raise Exception('Same E-mail exist!')
 
 
234
  user = users[0]
235
  login_user(user)
236
- return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
 
237
  except Exception as e:
238
  rollback_user_registration(user_id)
239
  stat_logger.exception(e)
240
- return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
241
-
242
 
243
 
244
  @manager.route("/tenant_info", methods=["GET"])
 
40
 
41
  email = request.json.get('email', "")
42
  users = UserService.query(email=email)
43
+ if not users:
44
+ return get_json_result(
45
+ data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
46
 
47
  password = request.json.get('password')
48
  try:
49
  password = decrypt(password)
50
+ except BaseException:
51
+ return get_json_result(
52
+ data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password')
53
 
54
  user = UserService.query_user(email, password)
55
  if user:
 
60
  msg = "Welcome back!"
61
  return cors_reponse(data=response_data, auth=user.get_id(), retmsg=msg)
62
  else:
63
+ return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
64
+ retmsg='Email and Password do not match!')
65
 
66
 
67
  @manager.route('/github_callback', methods=['GET'])
 
69
  import requests
70
  res = requests.post(GITHUB_OAUTH.get("url"), data={
71
  "client_id": GITHUB_OAUTH.get("client_id"),
72
+ "client_secret": GITHUB_OAUTH.get("secret_key"),
73
  "code": request.args.get('code')
74
  }, headers={"Accept": "application/json"})
75
  res = res.json()
 
100
  "last_login_time": get_format_time(),
101
  "is_superuser": False,
102
  })
103
+ if not users:
104
+ raise Exception('Register user failure.')
105
+ if len(users) > 1:
106
+ raise Exception('Same E-mail exist!')
107
  user = users[0]
108
  login_user(user)
109
+ return redirect("/?auth=%s" % user.get_id())
110
  except Exception as e:
111
  rollback_user_registration(user_id)
112
  stat_logger.exception(e)
113
+ return redirect("/?error=%s" % str(e))
114
  user = users[0]
115
  user.access_token = get_uuid()
116
  login_user(user)
 
120
 
121
  def user_info_from_github(access_token):
122
  import requests
123
+ headers = {"Accept": "application/json",
124
+ 'Authorization': f"token {access_token}"}
125
+ res = requests.get(
126
+ f"https://api.github.com/user?access_token={access_token}",
127
+ headers=headers)
128
  user_info = res.json()
129
+ email_info = requests.get(
130
+ f"https://api.github.com/user/emails?access_token={access_token}",
131
+ headers=headers).json()
132
+ user_info["email"] = next(
133
+ (email for email in email_info if email['primary'] == True),
134
+ None)["email"]
135
  return user_info
136
 
137
 
 
151
  request_data = request.json
152
  if request_data.get("password"):
153
  new_password = request_data.get("new_password")
154
+ if not check_password_hash(
155
+ current_user.password, decrypt(request_data["password"])):
156
+ return get_json_result(
157
+ data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')
158
 
159
+ if new_password:
160
+ update_dict["password"] = generate_password_hash(
161
+ decrypt(new_password))
162
 
163
  for k in request_data.keys():
164
+ if k in ["password", "new_password"]:
165
+ continue
166
  update_dict[k] = request_data[k]
167
 
168
  try:
 
170
  return get_json_result(data=True)
171
  except Exception as e:
172
  stat_logger.exception(e)
173
+ return get_json_result(
174
+ data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)
175
 
176
 
177
  @manager.route("/info", methods=["GET"])
 
192
  except Exception as e:
193
  pass
194
  try:
195
+ TenantLLM.delete().where(TenantLLM.tenant_id == user_id).excute()
196
  except Exception as e:
197
  pass
198
 
199
+
200
  def user_register(user_id, user):
201
  user["id"] = user_id
202
  tenant = {
 
216
  }
217
  tenant_llm = []
218
  for llm in LLMService.query(fid=LLM_FACTORY):
219
+ tenant_llm.append({"tenant_id": user_id,
220
+ "llm_factory": LLM_FACTORY,
221
+ "llm_name": llm.llm_name,
222
+ "model_type": llm.model_type,
223
+ "api_key": API_KEY})
224
+
225
+ if not UserService.save(**user):
226
+ return
227
  TenantService.insert(**tenant)
228
  UserTenantService.insert(**usr_tenant)
229
  TenantLLMService.insert_many(tenant_llm)
 
235
  def user_add():
236
  req = request.json
237
  if UserService.query(email=req["email"]):
238
+ return get_json_result(
239
+ data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
240
  if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,4}$", req["email"]):
241
  return get_json_result(data=False, retmsg=f'Invaliad e-mail: {req["email"]}!',
242
  retcode=RetCode.OPERATING_ERROR)
 
254
  user_id = get_uuid()
255
  try:
256
  users = user_register(user_id, user_dict)
257
+ if not users:
258
+ raise Exception('Register user failure.')
259
+ if len(users) > 1:
260
+ raise Exception('Same E-mail exist!')
261
  user = users[0]
262
  login_user(user)
263
+ return cors_reponse(data=user.to_json(),
264
+ auth=user.get_id(), retmsg="Welcome aboard!")
265
  except Exception as e:
266
  rollback_user_registration(user_id)
267
  stat_logger.exception(e)
268
+ return get_json_result(
269
+ data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
270
 
271
 
272
  @manager.route("/tenant_info", methods=["GET"])
api/db/db_models.py CHANGED
@@ -50,7 +50,13 @@ def singleton(cls, *args, **kw):
50
 
51
 
52
  CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
53
- AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_access", "write_access"}
 
 
 
 
 
 
54
 
55
 
56
  class LongTextField(TextField):
@@ -73,7 +79,8 @@ class JSONField(LongTextField):
73
  def python_value(self, value):
74
  if not value:
75
  return self.default_value
76
- return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
 
77
 
78
 
79
  class ListField(JSONField):
@@ -81,7 +88,8 @@ class ListField(JSONField):
81
 
82
 
83
  class SerializedField(LongTextField):
84
- def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs):
 
85
  self._serialized_type = serialized_type
86
  self._object_hook = object_hook
87
  self._object_pairs_hook = object_pairs_hook
@@ -95,7 +103,8 @@ class SerializedField(LongTextField):
95
  return None
96
  return utils.json_dumps(value, with_type=True)
97
  else:
98
- raise ValueError(f"the serialized type {self._serialized_type} is not supported")
 
99
 
100
  def python_value(self, value):
101
  if self._serialized_type == SerializedType.PICKLE:
@@ -103,9 +112,11 @@ class SerializedField(LongTextField):
103
  elif self._serialized_type == SerializedType.JSON:
104
  if value is None:
105
  return {}
106
- return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
 
107
  else:
108
- raise ValueError(f"the serialized type {self._serialized_type} is not supported")
 
109
 
110
 
111
  def is_continuous_field(cls: typing.Type) -> bool:
@@ -150,7 +161,8 @@ class BaseModel(Model):
150
  model_dict = self.__dict__['__data__']
151
 
152
  if not only_primary_with:
153
- return {remove_field_name_prefix(k): v for k, v in model_dict.items()}
 
154
 
155
  human_model_dict = {}
156
  for k in self._meta.primary_key.field_names:
@@ -184,17 +196,22 @@ class BaseModel(Model):
184
  if is_continuous_field(type(getattr(cls, attr_name))):
185
  if len(f_v) == 2:
186
  for i, v in enumerate(f_v):
187
- if isinstance(v, str) and f_n in auto_date_timestamp_field():
 
188
  # time type: %Y-%m-%d %H:%M:%S
189
  f_v[i] = utils.date_string_to_timestamp(v)
190
  lt_value = f_v[0]
191
  gt_value = f_v[1]
192
  if lt_value is not None and gt_value is not None:
193
- filters.append(cls.getter_by(attr_name).between(lt_value, gt_value))
 
 
194
  elif lt_value is not None:
195
- filters.append(operator.attrgetter(attr_name)(cls) >= lt_value)
 
196
  elif gt_value is not None:
197
- filters.append(operator.attrgetter(attr_name)(cls) <= gt_value)
 
198
  else:
199
  filters.append(operator.attrgetter(attr_name)(cls) << f_v)
200
  else:
@@ -205,9 +222,11 @@ class BaseModel(Model):
205
  if not order_by or not hasattr(cls, f"{order_by}"):
206
  order_by = "create_time"
207
  if reverse is True:
208
- query_records = query_records.order_by(cls.getter_by(f"{order_by}").desc())
 
209
  elif reverse is False:
210
- query_records = query_records.order_by(cls.getter_by(f"{order_by}").asc())
 
211
  return [query_record for query_record in query_records]
212
  else:
213
  return []
@@ -215,7 +234,8 @@ class BaseModel(Model):
215
  @classmethod
216
  def insert(cls, __data=None, **insert):
217
  if isinstance(__data, dict) and __data:
218
- __data[cls._meta.combined["create_time"]] = utils.current_timestamp()
 
219
  if insert:
220
  insert["create_time"] = utils.current_timestamp()
221
 
@@ -228,7 +248,8 @@ class BaseModel(Model):
228
  if not normalized:
229
  return {}
230
 
231
- normalized[cls._meta.combined["update_time"]] = utils.current_timestamp()
 
232
 
233
  for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
234
  if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
@@ -241,7 +262,8 @@ class BaseModel(Model):
241
 
242
 
243
  class JsonSerializedField(SerializedField):
244
- def __init__(self, object_hook=utils.from_dict_hook, object_pairs_hook=None, **kwargs):
 
245
  super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
246
  object_pairs_hook=object_pairs_hook, **kwargs)
247
 
@@ -251,7 +273,8 @@ class BaseDataBase:
251
  def __init__(self):
252
  database_config = DATABASE.copy()
253
  db_name = database_config.pop("name")
254
- self.database_connection = PooledMySQLDatabase(db_name, **database_config)
 
255
  stat_logger.info('init mysql database on cluster mode successfully')
256
 
257
 
@@ -263,7 +286,8 @@ class DatabaseLock:
263
 
264
  def lock(self):
265
  # SQL parameters only support %s format placeholders
266
- cursor = self.db.execute_sql("SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
 
267
  ret = cursor.fetchone()
268
  if ret[0] == 0:
269
  raise Exception(f'acquire mysql lock {self.lock_name} timeout')
@@ -273,10 +297,12 @@ class DatabaseLock:
273
  raise Exception(f'failed to acquire lock {self.lock_name}')
274
 
275
  def unlock(self):
276
- cursor = self.db.execute_sql("SELECT RELEASE_LOCK(%s)", (self.lock_name,))
 
277
  ret = cursor.fetchone()
278
  if ret[0] == 0:
279
- raise Exception(f'mysql lock {self.lock_name} was not established by this thread')
 
280
  elif ret[0] == 1:
281
  return True
282
  else:
@@ -350,17 +376,37 @@ class User(DataBaseModel, UserMixin):
350
  access_token = CharField(max_length=255, null=True)
351
  nickname = CharField(max_length=100, null=False, help_text="nicky name")
352
  password = CharField(max_length=255, null=True, help_text="password")
353
- email = CharField(max_length=255, null=False, help_text="email", index=True)
 
 
 
 
354
  avatar = TextField(null=True, help_text="avatar base64 string")
355
- language = CharField(max_length=32, null=True, help_text="English|Chinese", default="Chinese")
356
- color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Bright")
357
- timezone = CharField(max_length=64, null=True, help_text="Timezone", default="UTC+8\tAsia/Shanghai")
 
 
 
 
 
 
 
 
 
 
 
 
358
  last_login_time = DateTimeField(null=True)
359
  is_authenticated = CharField(max_length=1, null=False, default="1")
360
  is_active = CharField(max_length=1, null=False, default="1")
361
  is_anonymous = CharField(max_length=1, null=False, default="0")
362
  login_channel = CharField(null=True, help_text="from which user login")
363
- status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
 
364
  is_superuser = BooleanField(null=True, help_text="is root", default=False)
365
 
366
  def __str__(self):
@@ -379,12 +425,28 @@ class Tenant(DataBaseModel):
379
  name = CharField(max_length=100, null=True, help_text="Tenant name")
380
  public_key = CharField(max_length=255, null=True)
381
  llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
382
- embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
383
- asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID")
384
- img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID")
385
- parser_ids = CharField(max_length=256, null=False, help_text="document processors")
 
386
  credit = IntegerField(default=512)
387
- status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
 
388
 
389
  class Meta:
390
  db_table = "tenant"
@@ -396,7 +458,11 @@ class UserTenant(DataBaseModel):
396
  tenant_id = CharField(max_length=32, null=False)
397
  role = CharField(max_length=32, null=False, help_text="UserTenantRole")
398
  invited_by = CharField(max_length=32, null=False)
399
- status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
 
400
 
401
  class Meta:
402
  db_table = "user_tenant"
@@ -408,17 +474,32 @@ class InvitationCode(DataBaseModel):
408
  visit_time = DateTimeField(null=True)
409
  user_id = CharField(max_length=32, null=True)
410
  tenant_id = CharField(max_length=32, null=True)
411
- status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
 
412
 
413
  class Meta:
414
  db_table = "invitation_code"
415
 
416
 
417
  class LLMFactories(DataBaseModel):
418
- name = CharField(max_length=128, null=False, help_text="LLM factory name", primary_key=True)
 
419
  logo = TextField(null=True, help_text="llm logo base64")
420
- tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
421
- status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
 
422
 
423
  def __str__(self):
424
  return self.name
@@ -429,12 +510,27 @@ class LLMFactories(DataBaseModel):
429
 
430
  class LLM(DataBaseModel):
431
  # LLMs dictionary
432
- llm_name = CharField(max_length=128, null=False, help_text="LLM name", index=True, primary_key=True)
433
- model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
 
434
  fid = CharField(max_length=128, null=False, help_text="LLM factory id")
435
  max_tokens = IntegerField(default=0)
436
- tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
437
- status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
 
438
 
439
  def __str__(self):
440
  return self.llm_name
@@ -445,9 +541,19 @@ class LLM(DataBaseModel):
445
 
446
  class TenantLLM(DataBaseModel):
447
  tenant_id = CharField(max_length=32, null=False)
448
- llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
449
- model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR")
450
- llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
 
451
  api_key = CharField(max_length=255, null=True, help_text="API KEY")
452
  api_base = CharField(max_length=255, null=True, help_text="API Base")
453
  used_tokens = IntegerField(default=0)
@@ -464,11 +570,26 @@ class Knowledgebase(DataBaseModel):
464
  id = CharField(max_length=32, primary_key=True)
465
  avatar = TextField(null=True, help_text="avatar base64 string")
466
  tenant_id = CharField(max_length=32, null=False)
467
- name = CharField(max_length=128, null=False, help_text="KB name", index=True)
468
- language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese")
 
469
  description = TextField(null=True, help_text="KB description")
470
- embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
471
- permission = CharField(max_length=16, null=False, help_text="me|team", default="me")
 
472
  created_by = CharField(max_length=32, null=False)
473
  doc_num = IntegerField(default=0)
474
  token_num = IntegerField(default=0)
@@ -476,9 +597,17 @@ class Knowledgebase(DataBaseModel):
476
  similarity_threshold = FloatField(default=0.2)
477
  vector_similarity_weight = FloatField(default=0.3)
478
 
479
- parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value)
480
- parser_config = JSONField(null=False, default={"pages":[[1,1000000]]})
481
- status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
 
482
 
483
  def __str__(self):
484
  return self.name
@@ -491,22 +620,50 @@ class Document(DataBaseModel):
491
  id = CharField(max_length=32, primary_key=True)
492
  thumbnail = TextField(null=True, help_text="thumbnail base64 string")
493
  kb_id = CharField(max_length=256, null=False, index=True)
494
- parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
495
- parser_config = JSONField(null=False, default={"pages":[[1,1000000]]})
496
- source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document from")
 
497
  type = CharField(max_length=32, null=False, help_text="file extension")
498
- created_by = CharField(max_length=32, null=False, help_text="who created it")
499
- name = CharField(max_length=255, null=True, help_text="file name", index=True)
500
- location = CharField(max_length=255, null=True, help_text="where dose it store")
 
501
  size = IntegerField(default=0)
502
  token_num = IntegerField(default=0)
503
  chunk_num = IntegerField(default=0)
504
  progress = FloatField(default=0)
505
- progress_msg = TextField(null=True, help_text="process message", default="")
 
506
  process_begin_at = DateTimeField(null=True)
507
  process_duation = FloatField(default=0)
508
- run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0")
509
- status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
 
510
 
511
  class Meta:
512
  db_table = "document"
@@ -520,30 +677,52 @@ class Task(DataBaseModel):
520
  begin_at = DateTimeField(null=True)
521
  process_duation = FloatField(default=0)
522
  progress = FloatField(default=0)
523
- progress_msg = TextField(null=True, help_text="process message", default="")
 
524
 
525
 
526
  class Dialog(DataBaseModel):
527
  id = CharField(max_length=32, primary_key=True)
528
  tenant_id = CharField(max_length=32, null=False)
529
- name = CharField(max_length=255, null=True, help_text="dialog application name")
 
530
  description = TextField(null=True, help_text="Dialog description")
531
  icon = TextField(null=True, help_text="icon base64 string")
532
- language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese")
 
533
  llm_id = CharField(max_length=32, null=False, help_text="default llm ID")
534
  llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
535
  "presence_penalty": 0.4, "max_tokens": 215})
536
- prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced")
 
537
  prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
538
  "parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
539
 
540
  similarity_threshold = FloatField(default=0.2)
541
  vector_similarity_weight = FloatField(default=0.3)
542
  top_n = IntegerField(default=6)
543
- do_refer = CharField(max_length=1, null=False, help_text="it needs to insert reference index into answer or not", default="1")
 
544
 
545
  kb_ids = JSONField(null=False, default=[])
546
- status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
 
547
 
548
  class Meta:
549
  db_table = "dialog"
 
50
 
51
 
52
  CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
53
+ AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {
54
+ "create",
55
+ "start",
56
+ "end",
57
+ "update",
58
+ "read_access",
59
+ "write_access"}
60
 
61
 
62
  class LongTextField(TextField):
 
79
  def python_value(self, value):
80
  if not value:
81
  return self.default_value
82
+ return utils.json_loads(
83
+ value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
84
 
85
 
86
  class ListField(JSONField):
 
88
 
89
 
90
  class SerializedField(LongTextField):
91
+ def __init__(self, serialized_type=SerializedType.PICKLE,
92
+ object_hook=None, object_pairs_hook=None, **kwargs):
93
  self._serialized_type = serialized_type
94
  self._object_hook = object_hook
95
  self._object_pairs_hook = object_pairs_hook
 
103
  return None
104
  return utils.json_dumps(value, with_type=True)
105
  else:
106
+ raise ValueError(
107
+ f"the serialized type {self._serialized_type} is not supported")
108
 
109
  def python_value(self, value):
110
  if self._serialized_type == SerializedType.PICKLE:
 
112
  elif self._serialized_type == SerializedType.JSON:
113
  if value is None:
114
  return {}
115
+ return utils.json_loads(
116
+ value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
117
  else:
118
+ raise ValueError(
119
+ f"the serialized type {self._serialized_type} is not supported")
120
 
121
 
122
  def is_continuous_field(cls: typing.Type) -> bool:
 
161
  model_dict = self.__dict__['__data__']
162
 
163
  if not only_primary_with:
164
+ return {remove_field_name_prefix(
165
+ k): v for k, v in model_dict.items()}
166
 
167
  human_model_dict = {}
168
  for k in self._meta.primary_key.field_names:
 
196
  if is_continuous_field(type(getattr(cls, attr_name))):
197
  if len(f_v) == 2:
198
  for i, v in enumerate(f_v):
199
+ if isinstance(
200
+ v, str) and f_n in auto_date_timestamp_field():
201
  # time type: %Y-%m-%d %H:%M:%S
202
  f_v[i] = utils.date_string_to_timestamp(v)
203
  lt_value = f_v[0]
204
  gt_value = f_v[1]
205
  if lt_value is not None and gt_value is not None:
206
+ filters.append(
207
+ cls.getter_by(attr_name).between(
208
+ lt_value, gt_value))
209
  elif lt_value is not None:
210
+ filters.append(
211
+ operator.attrgetter(attr_name)(cls) >= lt_value)
212
  elif gt_value is not None:
213
+ filters.append(
214
+ operator.attrgetter(attr_name)(cls) <= gt_value)
215
  else:
216
  filters.append(operator.attrgetter(attr_name)(cls) << f_v)
217
  else:
 
222
  if not order_by or not hasattr(cls, f"{order_by}"):
223
  order_by = "create_time"
224
  if reverse is True:
225
+ query_records = query_records.order_by(
226
+ cls.getter_by(f"{order_by}").desc())
227
  elif reverse is False:
228
+ query_records = query_records.order_by(
229
+ cls.getter_by(f"{order_by}").asc())
230
  return [query_record for query_record in query_records]
231
  else:
232
  return []
 
234
  @classmethod
235
  def insert(cls, __data=None, **insert):
236
  if isinstance(__data, dict) and __data:
237
+ __data[cls._meta.combined["create_time"]
238
+ ] = utils.current_timestamp()
239
  if insert:
240
  insert["create_time"] = utils.current_timestamp()
241
 
 
248
  if not normalized:
249
  return {}
250
 
251
+ normalized[cls._meta.combined["update_time"]
252
+ ] = utils.current_timestamp()
253
 
254
  for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
255
  if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
 
262
 
263
 
264
  class JsonSerializedField(SerializedField):
265
+ def __init__(self, object_hook=utils.from_dict_hook,
266
+ object_pairs_hook=None, **kwargs):
267
  super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
268
  object_pairs_hook=object_pairs_hook, **kwargs)
269
 
 
273
  def __init__(self):
274
  database_config = DATABASE.copy()
275
  db_name = database_config.pop("name")
276
+ self.database_connection = PooledMySQLDatabase(
277
+ db_name, **database_config)
278
  stat_logger.info('init mysql database on cluster mode successfully')
279
 
280
 
 
286
 
287
  def lock(self):
288
  # SQL parameters only support %s format placeholders
289
+ cursor = self.db.execute_sql(
290
+ "SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
291
  ret = cursor.fetchone()
292
  if ret[0] == 0:
293
  raise Exception(f'acquire mysql lock {self.lock_name} timeout')
 
297
  raise Exception(f'failed to acquire lock {self.lock_name}')
298
 
299
  def unlock(self):
300
+ cursor = self.db.execute_sql(
301
+ "SELECT RELEASE_LOCK(%s)", (self.lock_name,))
302
  ret = cursor.fetchone()
303
  if ret[0] == 0:
304
+ raise Exception(
305
+ f'mysql lock {self.lock_name} was not established by this thread')
306
  elif ret[0] == 1:
307
  return True
308
  else:
 
376
  access_token = CharField(max_length=255, null=True)
377
  nickname = CharField(max_length=100, null=False, help_text="nicky name")
378
  password = CharField(max_length=255, null=True, help_text="password")
379
+ email = CharField(
380
+ max_length=255,
381
+ null=False,
382
+ help_text="email",
383
+ index=True)
384
  avatar = TextField(null=True, help_text="avatar base64 string")
385
+ language = CharField(
386
+ max_length=32,
387
+ null=True,
388
+ help_text="English|Chinese",
389
+ default="Chinese")
390
+ color_schema = CharField(
391
+ max_length=32,
392
+ null=True,
393
+ help_text="Bright|Dark",
394
+ default="Bright")
395
+ timezone = CharField(
396
+ max_length=64,
397
+ null=True,
398
+ help_text="Timezone",
399
+ default="UTC+8\tAsia/Shanghai")
400
  last_login_time = DateTimeField(null=True)
401
  is_authenticated = CharField(max_length=1, null=False, default="1")
402
  is_active = CharField(max_length=1, null=False, default="1")
403
  is_anonymous = CharField(max_length=1, null=False, default="0")
404
  login_channel = CharField(null=True, help_text="from which user login")
405
+ status = CharField(
406
+ max_length=1,
407
+ null=True,
408
+ help_text="is it validate(0: wasted,1: validate)",
409
+ default="1")
410
  is_superuser = BooleanField(null=True, help_text="is root", default=False)
411
 
412
  def __str__(self):
 
425
  name = CharField(max_length=100, null=True, help_text="Tenant name")
426
  public_key = CharField(max_length=255, null=True)
427
  llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
428
+ embd_id = CharField(
429
+ max_length=128,
430
+ null=False,
431
+ help_text="default embedding model ID")
432
+ asr_id = CharField(
433
+ max_length=128,
434
+ null=False,
435
+ help_text="default ASR model ID")
436
+ img2txt_id = CharField(
437
+ max_length=128,
438
+ null=False,
439
+ help_text="default image to text model ID")
440
+ parser_ids = CharField(
441
+ max_length=256,
442
+ null=False,
443
+ help_text="document processors")
444
  credit = IntegerField(default=512)
445
+ status = CharField(
446
+ max_length=1,
447
+ null=True,
448
+ help_text="is it validate(0: wasted,1: validate)",
449
+ default="1")
450
 
451
  class Meta:
452
  db_table = "tenant"
 
458
  tenant_id = CharField(max_length=32, null=False)
459
  role = CharField(max_length=32, null=False, help_text="UserTenantRole")
460
  invited_by = CharField(max_length=32, null=False)
461
+ status = CharField(
462
+ max_length=1,
463
+ null=True,
464
+ help_text="is it validate(0: wasted,1: validate)",
465
+ default="1")
466
 
467
  class Meta:
468
  db_table = "user_tenant"
 
474
  visit_time = DateTimeField(null=True)
475
  user_id = CharField(max_length=32, null=True)
476
  tenant_id = CharField(max_length=32, null=True)
477
+ status = CharField(
478
+ max_length=1,
479
+ null=True,
480
+ help_text="is it validate(0: wasted,1: validate)",
481
+ default="1")
482
 
483
  class Meta:
484
  db_table = "invitation_code"
485
 
486
 
487
  class LLMFactories(DataBaseModel):
488
+ name = CharField(
489
+ max_length=128,
490
+ null=False,
491
+ help_text="LLM factory name",
492
+ primary_key=True)
493
  logo = TextField(null=True, help_text="llm logo base64")
494
+ tags = CharField(
495
+ max_length=255,
496
+ null=False,
497
+ help_text="LLM, Text Embedding, Image2Text, ASR")
498
+ status = CharField(
499
+ max_length=1,
500
+ null=True,
501
+ help_text="is it validate(0: wasted,1: validate)",
502
+ default="1")
503
 
504
  def __str__(self):
505
  return self.name
 
510
 
511
  class LLM(DataBaseModel):
512
  # LLMs dictionary
513
+ llm_name = CharField(
514
+ max_length=128,
515
+ null=False,
516
+ help_text="LLM name",
517
+ index=True,
518
+ primary_key=True)
519
+ model_type = CharField(
520
+ max_length=128,
521
+ null=False,
522
+ help_text="LLM, Text Embedding, Image2Text, ASR")
523
  fid = CharField(max_length=128, null=False, help_text="LLM factory id")
524
  max_tokens = IntegerField(default=0)
525
+ tags = CharField(
526
+ max_length=255,
527
+ null=False,
528
+ help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
529
+ status = CharField(
530
+ max_length=1,
531
+ null=True,
532
+ help_text="is it validate(0: wasted,1: validate)",
533
+ default="1")
534
 
535
  def __str__(self):
536
  return self.llm_name
 
541
 
542
  class TenantLLM(DataBaseModel):
543
  tenant_id = CharField(max_length=32, null=False)
544
+ llm_factory = CharField(
545
+ max_length=128,
546
+ null=False,
547
+ help_text="LLM factory name")
548
+ model_type = CharField(
549
+ max_length=128,
550
+ null=True,
551
+ help_text="LLM, Text Embedding, Image2Text, ASR")
552
+ llm_name = CharField(
553
+ max_length=128,
554
+ null=True,
555
+ help_text="LLM name",
556
+ default="")
557
  api_key = CharField(max_length=255, null=True, help_text="API KEY")
558
  api_base = CharField(max_length=255, null=True, help_text="API Base")
559
  used_tokens = IntegerField(default=0)
 
570
  id = CharField(max_length=32, primary_key=True)
571
  avatar = TextField(null=True, help_text="avatar base64 string")
572
  tenant_id = CharField(max_length=32, null=False)
573
+ name = CharField(
574
+ max_length=128,
575
+ null=False,
576
+ help_text="KB name",
577
+ index=True)
578
+ language = CharField(
579
+ max_length=32,
580
+ null=True,
581
+ default="Chinese",
582
+ help_text="English|Chinese")
583
  description = TextField(null=True, help_text="KB description")
584
+ embd_id = CharField(
585
+ max_length=128,
586
+ null=False,
587
+ help_text="default embedding model ID")
588
+ permission = CharField(
589
+ max_length=16,
590
+ null=False,
591
+ help_text="me|team",
592
+ default="me")
593
  created_by = CharField(max_length=32, null=False)
594
  doc_num = IntegerField(default=0)
595
  token_num = IntegerField(default=0)
 
597
  similarity_threshold = FloatField(default=0.2)
598
  vector_similarity_weight = FloatField(default=0.3)
599
 
600
+ parser_id = CharField(
601
+ max_length=32,
602
+ null=False,
603
+ help_text="default parser ID",
604
+ default=ParserType.NAIVE.value)
605
+ parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
606
+ status = CharField(
607
+ max_length=1,
608
+ null=True,
609
+ help_text="is it validate(0: wasted,1: validate)",
610
+ default="1")
611
 
612
  def __str__(self):
613
  return self.name
 
620
  id = CharField(max_length=32, primary_key=True)
621
  thumbnail = TextField(null=True, help_text="thumbnail base64 string")
622
  kb_id = CharField(max_length=256, null=False, index=True)
623
+ parser_id = CharField(
624
+ max_length=32,
625
+ null=False,
626
+ help_text="default parser ID")
627
+ parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
628
+ source_type = CharField(
629
+ max_length=128,
630
+ null=False,
631
+ default="local",
632
+ help_text="where dose this document from")
633
  type = CharField(max_length=32, null=False, help_text="file extension")
634
+ created_by = CharField(
635
+ max_length=32,
636
+ null=False,
637
+ help_text="who created it")
638
+ name = CharField(
639
+ max_length=255,
640
+ null=True,
641
+ help_text="file name",
642
+ index=True)
643
+ location = CharField(
644
+ max_length=255,
645
+ null=True,
646
+ help_text="where dose it store")
647
  size = IntegerField(default=0)
648
  token_num = IntegerField(default=0)
649
  chunk_num = IntegerField(default=0)
650
  progress = FloatField(default=0)
651
+ progress_msg = TextField(
652
+ null=True,
653
+ help_text="process message",
654
+ default="")
655
  process_begin_at = DateTimeField(null=True)
656
  process_duation = FloatField(default=0)
657
+ run = CharField(
658
+ max_length=1,
659
+ null=True,
660
+ help_text="start to run processing or cancel.(1: run it; 2: cancel)",
661
+ default="0")
662
+ status = CharField(
663
+ max_length=1,
664
+ null=True,
665
+ help_text="is it validate(0: wasted,1: validate)",
666
+ default="1")
667
 
668
  class Meta:
669
  db_table = "document"
 
677
  begin_at = DateTimeField(null=True)
678
  process_duation = FloatField(default=0)
679
  progress = FloatField(default=0)
680
+ progress_msg = TextField(
681
+ null=True,
682
+ help_text="process message",
683
+ default="")
684
 
685
 
686
  class Dialog(DataBaseModel):
687
  id = CharField(max_length=32, primary_key=True)
688
  tenant_id = CharField(max_length=32, null=False)
689
+ name = CharField(
690
+ max_length=255,
691
+ null=True,
692
+ help_text="dialog application name")
693
  description = TextField(null=True, help_text="Dialog description")
694
  icon = TextField(null=True, help_text="icon base64 string")
695
+ language = CharField(
696
+ max_length=32,
697
+ null=True,
698
+ default="Chinese",
699
+ help_text="English|Chinese")
700
  llm_id = CharField(max_length=32, null=False, help_text="default llm ID")
701
  llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
702
  "presence_penalty": 0.4, "max_tokens": 215})
703
+ prompt_type = CharField(
704
+ max_length=16,
705
+ null=False,
706
+ default="simple",
707
+ help_text="simple|advanced")
708
  prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
709
  "parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
710
 
711
  similarity_threshold = FloatField(default=0.2)
712
  vector_similarity_weight = FloatField(default=0.3)
713
  top_n = IntegerField(default=6)
714
+ do_refer = CharField(
715
+ max_length=1,
716
+ null=False,
717
+ help_text="it needs to insert reference index into answer or not",
718
+ default="1")
719
 
720
  kb_ids = JSONField(null=False, default=[])
721
+ status = CharField(
722
+ max_length=1,
723
+ null=True,
724
+ help_text="is it validate(0: wasted,1: validate)",
725
+ default="1")
726
 
727
  class Meta:
728
  db_table = "dialog"
api/db/db_utils.py CHANGED
@@ -32,8 +32,7 @@ LOGGER = getLogger()
32
  def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
33
  DB.create_tables([model])
34
 
35
-
36
- for i,data in enumerate(data_source):
37
  current_time = current_timestamp() + i
38
  current_date = timestamp_to_date(current_time)
39
  if 'create_time' not in data:
@@ -55,7 +54,8 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
55
 
56
 
57
  def get_dynamic_db_model(base, job_id):
58
- return type(base.model(table_index=get_dynamic_tracking_table_index(job_id=job_id)))
 
59
 
60
 
61
  def get_dynamic_tracking_table_index(job_id):
@@ -86,7 +86,9 @@ supported_operators = {
86
  '~': operator.inv,
87
  }
88
 
89
- def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
 
 
90
  expression = []
91
 
92
  for field, value in query.items():
@@ -95,7 +97,10 @@ def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[boo
95
  op, *val = value
96
 
97
  field = getattr(model, f'f_{field}')
98
- value = supported_operators[op](field, val[0]) if op in supported_operators else getattr(field, op)(*val)
 
99
  expression.append(value)
100
 
101
  return reduce(operator.iand, expression)
 
32
  def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
33
  DB.create_tables([model])
34
 
35
+ for i, data in enumerate(data_source):
 
36
  current_time = current_timestamp() + i
37
  current_date = timestamp_to_date(current_time)
38
  if 'create_time' not in data:
 
54
 
55
 
56
  def get_dynamic_db_model(base, job_id):
57
+ return type(base.model(
58
+ table_index=get_dynamic_tracking_table_index(job_id=job_id)))
59
 
60
 
61
  def get_dynamic_tracking_table_index(job_id):
 
86
  '~': operator.inv,
87
  }
88
 
89
+
90
+ def query_dict2expression(
91
+ model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
92
  expression = []
93
 
94
  for field, value in query.items():
 
97
  op, *val = value
98
 
99
  field = getattr(model, f'f_{field}')
100
+ value = supported_operators[op](
101
+ field, val[0]) if op in supported_operators else getattr(
102
+ field, op)(
103
+ *val)
104
  expression.append(value)
105
 
106
  return reduce(operator.iand, expression)
api/db/init_data.py CHANGED
@@ -61,45 +61,54 @@ def init_superuser():
61
  TenantService.insert(**tenant)
62
  UserTenantService.insert(**usr_tenant)
63
  TenantLLMService.insert_many(tenant_llm)
64
- print("【INFO】Super user initialized. \033[93memail: [email protected], password: admin\033[0m. Changing the password after logining is strongly recomanded.")
 
65
 
66
  chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
67
- msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
 
68
  if msg.find("ERROR: ") == 0:
69
- print("\33[91m【ERROR】\33[0m: ", "'{}' dosen't work. {}".format(tenant["llm_id"], msg))
 
70
  embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
71
  v, c = embd_mdl.encode(["Hello!"])
72
  if c == 0:
73
- print("\33[91m【ERROR】\33[0m:", " '{}' dosen't work!".format(tenant["embd_id"]))
 
74
 
75
 
76
  factory_infos = [{
77
- "name": "OpenAI",
78
- "logo": "",
79
- "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
 
 
 
80
  "status": "1",
81
- },{
82
- "name": "Tongyi-Qianwen",
83
- "logo": "",
84
- "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
85
- "status": "1",
86
- },{
87
- "name": "ZHIPU-AI",
88
- "logo": "",
89
- "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
90
- "status": "1",
91
- },
92
- {
93
- "name": "Local",
94
- "logo": "",
95
- "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
96
- "status": "1",
97
- },{
98
  "name": "Moonshot",
99
- "logo": "",
100
- "tags": "LLM,TEXT EMBEDDING",
101
- "status": "1",
102
- }
103
  # {
104
  # "name": "文心一言",
105
  # "logo": "",
@@ -107,6 +116,8 @@ factory_infos = [{
107
  # "status": "1",
108
  # },
109
  ]
 
 
110
  def init_llm_factory():
111
  llm_infos = [
112
  # ---------------------- OpenAI ------------------------
@@ -116,37 +127,37 @@ def init_llm_factory():
116
  "tags": "LLM,CHAT,4K",
117
  "max_tokens": 4096,
118
  "model_type": LLMType.CHAT.value
119
- },{
120
  "fid": factory_infos[0]["name"],
121
  "llm_name": "gpt-3.5-turbo-16k-0613",
122
  "tags": "LLM,CHAT,16k",
123
  "max_tokens": 16385,
124
  "model_type": LLMType.CHAT.value
125
- },{
126
  "fid": factory_infos[0]["name"],
127
  "llm_name": "text-embedding-ada-002",
128
  "tags": "TEXT EMBEDDING,8K",
129
  "max_tokens": 8191,
130
  "model_type": LLMType.EMBEDDING.value
131
- },{
132
  "fid": factory_infos[0]["name"],
133
  "llm_name": "whisper-1",
134
  "tags": "SPEECH2TEXT",
135
- "max_tokens": 25*1024*1024,
136
  "model_type": LLMType.SPEECH2TEXT.value
137
- },{
138
  "fid": factory_infos[0]["name"],
139
  "llm_name": "gpt-4",
140
  "tags": "LLM,CHAT,8K",
141
  "max_tokens": 8191,
142
  "model_type": LLMType.CHAT.value
143
- },{
144
  "fid": factory_infos[0]["name"],
145
  "llm_name": "gpt-4-32k",
146
  "tags": "LLM,CHAT,32K",
147
  "max_tokens": 32768,
148
  "model_type": LLMType.CHAT.value
149
- },{
150
  "fid": factory_infos[0]["name"],
151
  "llm_name": "gpt-4-vision-preview",
152
  "tags": "LLM,CHAT,IMAGE2TEXT",
@@ -160,31 +171,31 @@ def init_llm_factory():
160
  "tags": "LLM,CHAT,8K",
161
  "max_tokens": 8191,
162
  "model_type": LLMType.CHAT.value
163
- },{
164
  "fid": factory_infos[1]["name"],
165
  "llm_name": "qwen-plus",
166
  "tags": "LLM,CHAT,32K",
167
  "max_tokens": 32768,
168
  "model_type": LLMType.CHAT.value
169
- },{
170
  "fid": factory_infos[1]["name"],
171
  "llm_name": "qwen-max-1201",
172
  "tags": "LLM,CHAT,6K",
173
  "max_tokens": 5899,
174
  "model_type": LLMType.CHAT.value
175
- },{
176
  "fid": factory_infos[1]["name"],
177
  "llm_name": "text-embedding-v2",
178
  "tags": "TEXT EMBEDDING,2K",
179
  "max_tokens": 2048,
180
  "model_type": LLMType.EMBEDDING.value
181
- },{
182
  "fid": factory_infos[1]["name"],
183
  "llm_name": "paraformer-realtime-8k-v1",
184
  "tags": "SPEECH2TEXT",
185
- "max_tokens": 25*1024*1024,
186
  "model_type": LLMType.SPEECH2TEXT.value
187
- },{
188
  "fid": factory_infos[1]["name"],
189
  "llm_name": "qwen-vl-max",
190
  "tags": "LLM,CHAT,IMAGE2TEXT",
@@ -245,13 +256,13 @@ def init_llm_factory():
245
  "tags": "TEXT EMBEDDING,",
246
  "max_tokens": 128 * 1000,
247
  "model_type": LLMType.EMBEDDING.value
248
- },{
249
  "fid": factory_infos[4]["name"],
250
  "llm_name": "moonshot-v1-32k",
251
  "tags": "LLM,CHAT,",
252
  "max_tokens": 32768,
253
  "model_type": LLMType.CHAT.value
254
- },{
255
  "fid": factory_infos[4]["name"],
256
  "llm_name": "moonshot-v1-128k",
257
  "tags": "LLM,CHAT",
@@ -294,7 +305,6 @@ def init_web_data():
294
  print("init web data success:{}".format(time.time() - start_time))
295
 
296
 
297
-
298
  if __name__ == '__main__':
299
  init_web_db()
300
- init_web_data()
 
61
  TenantService.insert(**tenant)
62
  UserTenantService.insert(**usr_tenant)
63
  TenantLLMService.insert_many(tenant_llm)
64
+ print(
65
+ "【INFO】Super user initialized. \033[93memail: [email protected], password: admin\033[0m. Changing the password after logining is strongly recomanded.")
66
 
67
  chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
68
+ msg = chat_mdl.chat(system="", history=[
69
+ {"role": "user", "content": "Hello!"}], gen_conf={})
70
  if msg.find("ERROR: ") == 0:
71
+ print(
72
+ "\33[91m【ERROR】\33[0m: ",
73
+ "'{}' dosen't work. {}".format(
74
+ tenant["llm_id"],
75
+ msg))
76
  embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
77
  v, c = embd_mdl.encode(["Hello!"])
78
  if c == 0:
79
+ print(
80
+ "\33[91m【ERROR】\33[0m:",
81
+ " '{}' dosen't work!".format(
82
+ tenant["embd_id"]))
83
 
84
 
85
  factory_infos = [{
86
+ "name": "OpenAI",
87
+ "logo": "",
88
+ "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
89
+ "status": "1",
90
+ }, {
91
+ "name": "Tongyi-Qianwen",
92
+ "logo": "",
93
+ "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
94
+ "status": "1",
95
+ }, {
96
+ "name": "ZHIPU-AI",
97
+ "logo": "",
98
+ "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
99
+ "status": "1",
100
+ },
101
+ {
102
+ "name": "Local",
103
+ "logo": "",
104
+ "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
105
  "status": "1",
106
+ }, {
 
 
 
107
  "name": "Moonshot",
108
+ "logo": "",
109
+ "tags": "LLM,TEXT EMBEDDING",
110
+ "status": "1",
111
+ }
112
  # {
113
  # "name": "文心一言",
114
  # "logo": "",
 
116
  # "status": "1",
117
  # },
118
  ]
119
+
120
+
121
  def init_llm_factory():
122
  llm_infos = [
123
  # ---------------------- OpenAI ------------------------
 
127
  "tags": "LLM,CHAT,4K",
128
  "max_tokens": 4096,
129
  "model_type": LLMType.CHAT.value
130
+ }, {
131
  "fid": factory_infos[0]["name"],
132
  "llm_name": "gpt-3.5-turbo-16k-0613",
133
  "tags": "LLM,CHAT,16k",
134
  "max_tokens": 16385,
135
  "model_type": LLMType.CHAT.value
136
+ }, {
137
  "fid": factory_infos[0]["name"],
138
  "llm_name": "text-embedding-ada-002",
139
  "tags": "TEXT EMBEDDING,8K",
140
  "max_tokens": 8191,
141
  "model_type": LLMType.EMBEDDING.value
142
+ }, {
143
  "fid": factory_infos[0]["name"],
144
  "llm_name": "whisper-1",
145
  "tags": "SPEECH2TEXT",
146
+ "max_tokens": 25 * 1024 * 1024,
147
  "model_type": LLMType.SPEECH2TEXT.value
148
+ }, {
149
  "fid": factory_infos[0]["name"],
150
  "llm_name": "gpt-4",
151
  "tags": "LLM,CHAT,8K",
152
  "max_tokens": 8191,
153
  "model_type": LLMType.CHAT.value
154
+ }, {
155
  "fid": factory_infos[0]["name"],
156
  "llm_name": "gpt-4-32k",
157
  "tags": "LLM,CHAT,32K",
158
  "max_tokens": 32768,
159
  "model_type": LLMType.CHAT.value
160
+ }, {
161
  "fid": factory_infos[0]["name"],
162
  "llm_name": "gpt-4-vision-preview",
163
  "tags": "LLM,CHAT,IMAGE2TEXT",
 
171
  "tags": "LLM,CHAT,8K",
172
  "max_tokens": 8191,
173
  "model_type": LLMType.CHAT.value
174
+ }, {
175
  "fid": factory_infos[1]["name"],
176
  "llm_name": "qwen-plus",
177
  "tags": "LLM,CHAT,32K",
178
  "max_tokens": 32768,
179
  "model_type": LLMType.CHAT.value
180
+ }, {
181
  "fid": factory_infos[1]["name"],
182
  "llm_name": "qwen-max-1201",
183
  "tags": "LLM,CHAT,6K",
184
  "max_tokens": 5899,
185
  "model_type": LLMType.CHAT.value
186
+ }, {
187
  "fid": factory_infos[1]["name"],
188
  "llm_name": "text-embedding-v2",
189
  "tags": "TEXT EMBEDDING,2K",
190
  "max_tokens": 2048,
191
  "model_type": LLMType.EMBEDDING.value
192
+ }, {
193
  "fid": factory_infos[1]["name"],
194
  "llm_name": "paraformer-realtime-8k-v1",
195
  "tags": "SPEECH2TEXT",
196
+ "max_tokens": 25 * 1024 * 1024,
197
  "model_type": LLMType.SPEECH2TEXT.value
198
+ }, {
199
  "fid": factory_infos[1]["name"],
200
  "llm_name": "qwen-vl-max",
201
  "tags": "LLM,CHAT,IMAGE2TEXT",
 
256
  "tags": "TEXT EMBEDDING,",
257
  "max_tokens": 128 * 1000,
258
  "model_type": LLMType.EMBEDDING.value
259
+ }, {
260
  "fid": factory_infos[4]["name"],
261
  "llm_name": "moonshot-v1-32k",
262
  "tags": "LLM,CHAT,",
263
  "max_tokens": 32768,
264
  "model_type": LLMType.CHAT.value
265
+ }, {
266
  "fid": factory_infos[4]["name"],
267
  "llm_name": "moonshot-v1-128k",
268
  "tags": "LLM,CHAT",
 
305
  print("init web data success:{}".format(time.time() - start_time))
306
 
307
 
 
308
  if __name__ == '__main__':
309
  init_web_db()
310
+ init_web_data()
api/db/operatioins.py CHANGED
@@ -18,4 +18,4 @@ import operator
18
  import time
19
  import typing
20
  from api.utils.log_utils import sql_logger
21
- import peewee
 
18
  import time
19
  import typing
20
  from api.utils.log_utils import sql_logger
21
+ import peewee
api/db/reload_config_base.py CHANGED
@@ -18,10 +18,11 @@ class ReloadConfigBase:
18
  def get_all(cls):
19
  configs = {}
20
  for k, v in cls.__dict__.items():
21
- if not callable(getattr(cls, k)) and not k.startswith("__") and not k.startswith("_"):
 
22
  configs[k] = v
23
  return configs
24
 
25
  @classmethod
26
  def get(cls, config_name):
27
- return getattr(cls, config_name) if hasattr(cls, config_name) else None
 
18
  def get_all(cls):
19
  configs = {}
20
  for k, v in cls.__dict__.items():
21
+ if not callable(getattr(cls, k)) and not k.startswith(
22
+ "__") and not k.startswith("_"):
23
  configs[k] = v
24
  return configs
25
 
26
  @classmethod
27
  def get(cls, config_name):
28
+ return getattr(cls, config_name) if hasattr(cls, config_name) else None
api/db/runtime_config.py CHANGED
@@ -51,4 +51,4 @@ class RuntimeConfig(ReloadConfigBase):
51
 
52
  @classmethod
53
  def set_service_db(cls, service_db):
54
- cls.SERVICE_DB = service_db
 
51
 
52
  @classmethod
53
  def set_service_db(cls, service_db):
54
+ cls.SERVICE_DB = service_db
api/db/services/common_service.py CHANGED
@@ -27,7 +27,8 @@ class CommonService:
27
  @classmethod
28
  @DB.connection_context()
29
  def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
30
- return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
 
31
 
32
  @classmethod
33
  @DB.connection_context()
@@ -40,9 +41,11 @@ class CommonService:
40
  if not order_by or not hasattr(cls, order_by):
41
  order_by = "create_time"
42
  if reverse is True:
43
- query_records = query_records.order_by(cls.model.getter_by(order_by).desc())
 
44
  elif reverse is False:
45
- query_records = query_records.order_by(cls.model.getter_by(order_by).asc())
 
46
  return query_records
47
 
48
  @classmethod
@@ -61,7 +64,7 @@ class CommonService:
61
  @classmethod
62
  @DB.connection_context()
63
  def save(cls, **kwargs):
64
- #if "id" not in kwargs:
65
  # kwargs["id"] = get_uuid()
66
  sample_obj = cls.model(**kwargs).save(force_insert=True)
67
  return sample_obj
@@ -95,7 +98,8 @@ class CommonService:
95
  for data in data_list:
96
  data["update_time"] = current_timestamp()
97
  data["update_date"] = datetime_format(datetime.now())
98
- cls.model.update(data).where(cls.model.id == data["id"]).execute()
 
99
 
100
  @classmethod
101
  @DB.connection_context()
@@ -128,7 +132,6 @@ class CommonService:
128
  def delete_by_id(cls, pid):
129
  return cls.model.delete().where(cls.model.id == pid).execute()
130
 
131
-
132
  @classmethod
133
  @DB.connection_context()
134
  def filter_delete(cls, filters):
@@ -151,19 +154,30 @@ class CommonService:
151
 
152
  @classmethod
153
  @DB.connection_context()
154
- def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None):
 
155
  in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
156
  if not filters:
157
  filters = []
158
  res_list = []
159
  if cols:
160
  for i in in_filters_tuple_list:
161
- query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters)
 
 
162
  if query_records:
163
- res_list.extend([query_record for query_record in query_records])
 
164
  else:
165
  for i in in_filters_tuple_list:
166
- query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters)
 
167
  if query_records:
168
- res_list.extend([query_record for query_record in query_records])
169
- return res_list
 
 
27
  @classmethod
28
  @DB.connection_context()
29
  def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
30
+ return cls.model.query(cols=cols, reverse=reverse,
31
+ order_by=order_by, **kwargs)
32
 
33
  @classmethod
34
  @DB.connection_context()
 
41
  if not order_by or not hasattr(cls, order_by):
42
  order_by = "create_time"
43
  if reverse is True:
44
+ query_records = query_records.order_by(
45
+ cls.model.getter_by(order_by).desc())
46
  elif reverse is False:
47
+ query_records = query_records.order_by(
48
+ cls.model.getter_by(order_by).asc())
49
  return query_records
50
 
51
  @classmethod
 
64
  @classmethod
65
  @DB.connection_context()
66
  def save(cls, **kwargs):
67
+ # if "id" not in kwargs:
68
  # kwargs["id"] = get_uuid()
69
  sample_obj = cls.model(**kwargs).save(force_insert=True)
70
  return sample_obj
 
98
  for data in data_list:
99
  data["update_time"] = current_timestamp()
100
  data["update_date"] = datetime_format(datetime.now())
101
+ cls.model.update(data).where(
102
+ cls.model.id == data["id"]).execute()
103
 
104
  @classmethod
105
  @DB.connection_context()
 
132
  def delete_by_id(cls, pid):
133
  return cls.model.delete().where(cls.model.id == pid).execute()
134
 
 
135
  @classmethod
136
  @DB.connection_context()
137
  def filter_delete(cls, filters):
 
154
 
155
  @classmethod
156
  @DB.connection_context()
157
+ def filter_scope_list(cls, in_key, in_filters_list,
158
+ filters=None, cols=None):
159
  in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
160
  if not filters:
161
  filters = []
162
  res_list = []
163
  if cols:
164
  for i in in_filters_tuple_list:
165
+ query_records = cls.model.select(
166
+ *
167
+ cols).where(
168
+ getattr(
169
+ cls.model,
170
+ in_key).in_(i),
171
+ *
172
+ filters)
173
  if query_records:
174
+ res_list.extend(
175
+ [query_record for query_record in query_records])
176
  else:
177
  for i in in_filters_tuple_list:
178
+ query_records = cls.model.select().where(
179
+ getattr(cls.model, in_key).in_(i), *filters)
180
  if query_records:
181
+ res_list.extend(
182
+ [query_record for query_record in query_records])
183
+ return res_list
api/db/services/dialog_service.py CHANGED
@@ -21,6 +21,5 @@ class DialogService(CommonService):
21
  model = Dialog
22
 
23
 
24
-
25
  class ConversationService(CommonService):
26
  model = Conversation
 
21
  model = Dialog
22
 
23
 
 
24
  class ConversationService(CommonService):
25
  model = Conversation
api/db/services/document_service.py CHANGED
@@ -72,7 +72,20 @@ class DocumentService(CommonService):
72
  @classmethod
73
  @DB.connection_context()
74
  def get_newly_uploaded(cls, tm, mod=0, comm=1, items_per_page=64):
75
- fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.parser_config, cls.model.name, cls.model.type, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
 
 
76
  docs = cls.model.select(*fields) \
77
  .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
78
  .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
@@ -103,40 +116,64 @@ class DocumentService(CommonService):
103
  @DB.connection_context()
104
  def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
105
  num = cls.model.update(token_num=cls.model.token_num + token_num,
106
- chunk_num=cls.model.chunk_num + chunk_num,
107
- process_duation=cls.model.process_duation+duation).where(
108
  cls.model.id == doc_id).execute()
109
- if num == 0:raise LookupError("Document not found which is supposed to be there")
110
- num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute()
 
 
111
  return num
112
 
113
  @classmethod
114
  @DB.connection_context()
115
  def get_tenant_id(cls, doc_id):
116
- docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status==StatusEnum.VALID.value)
 
 
117
  docs = docs.dicts()
118
- if not docs:return
 
119
  return docs[0]["tenant_id"]
120
 
121
  @classmethod
122
  @DB.connection_context()
123
  def get_thumbnails(cls, docids):
124
  fields = [cls.model.id, cls.model.thumbnail]
125
- return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts())
 
126
 
127
  @classmethod
128
  @DB.connection_context()
129
  def update_parser_config(cls, id, config):
130
  e, d = cls.get_by_id(id)
131
- if not e:raise LookupError(f"Document({id}) not found.")
 
 
132
  def dfs_update(old, new):
133
- for k,v in new.items():
134
  if k not in old:
135
  old[k] = v
136
  continue
137
  if isinstance(v, dict):
138
  assert isinstance(old[k], dict)
139
  dfs_update(old[k], v)
140
- else: old[k] = v
 
141
  dfs_update(d.parser_config, config)
142
- cls.update_by_id(id, {"parser_config": d.parser_config})
 
 
72
  @classmethod
73
  @DB.connection_context()
74
  def get_newly_uploaded(cls, tm, mod=0, comm=1, items_per_page=64):
75
+ fields = [
76
+ cls.model.id,
77
+ cls.model.kb_id,
78
+ cls.model.parser_id,
79
+ cls.model.parser_config,
80
+ cls.model.name,
81
+ cls.model.type,
82
+ cls.model.location,
83
+ cls.model.size,
84
+ Knowledgebase.tenant_id,
85
+ Tenant.embd_id,
86
+ Tenant.img2txt_id,
87
+ Tenant.asr_id,
88
+ cls.model.update_time]
89
  docs = cls.model.select(*fields) \
90
  .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
91
  .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
 
116
  @DB.connection_context()
117
  def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
118
  num = cls.model.update(token_num=cls.model.token_num + token_num,
119
+ chunk_num=cls.model.chunk_num + chunk_num,
120
+ process_duation=cls.model.process_duation + duation).where(
121
  cls.model.id == doc_id).execute()
122
+ if num == 0:
123
+ raise LookupError(
124
+ "Document not found which is supposed to be there")
125
+ num = Knowledgebase.update(
126
+ token_num=Knowledgebase.token_num +
127
+ token_num,
128
+ chunk_num=Knowledgebase.chunk_num +
129
+ chunk_num).where(
130
+ Knowledgebase.id == kb_id).execute()
131
  return num
132
 
133
  @classmethod
134
  @DB.connection_context()
135
  def get_tenant_id(cls, doc_id):
136
+ docs = cls.model.select(
137
+ Knowledgebase.tenant_id).join(
138
+ Knowledgebase, on=(
139
+ Knowledgebase.id == cls.model.kb_id)).where(
140
+ cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
141
  docs = docs.dicts()
142
+ if not docs:
143
+ return
144
  return docs[0]["tenant_id"]
145
 
146
  @classmethod
147
  @DB.connection_context()
148
  def get_thumbnails(cls, docids):
149
  fields = [cls.model.id, cls.model.thumbnail]
150
+ return list(cls.model.select(
151
+ *fields).where(cls.model.id.in_(docids)).dicts())
152
 
153
  @classmethod
154
  @DB.connection_context()
155
  def update_parser_config(cls, id, config):
156
  e, d = cls.get_by_id(id)
157
+ if not e:
158
+ raise LookupError(f"Document({id}) not found.")
159
+
160
  def dfs_update(old, new):
161
+ for k, v in new.items():
162
  if k not in old:
163
  old[k] = v
164
  continue
165
  if isinstance(v, dict):
166
  assert isinstance(old[k], dict)
167
  dfs_update(old[k], v)
168
+ else:
169
+ old[k] = v
170
  dfs_update(d.parser_config, config)
171
+ cls.update_by_id(id, {"parser_config": d.parser_config})
172
+
173
+ @classmethod
174
+ @DB.connection_context()
175
+ def get_doc_count(cls, tenant_id):
176
+ docs = cls.model.select(cls.model.id).join(Knowledgebase,
177
+ on=(Knowledgebase.id == cls.model.kb_id)).where(
178
+ Knowledgebase.tenant_id == tenant_id)
179
+ return len(docs)
api/db/services/knowledgebase_service.py CHANGED
@@ -55,7 +55,7 @@ class KnowledgebaseService(CommonService):
55
  cls.model.chunk_num,
56
  cls.model.parser_id,
57
  cls.model.parser_config]
58
- kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where(
59
  (cls.model.id == kb_id),
60
  (cls.model.status == StatusEnum.VALID.value)
61
  )
@@ -69,9 +69,11 @@ class KnowledgebaseService(CommonService):
69
  @DB.connection_context()
70
  def update_parser_config(cls, id, config):
71
  e, m = cls.get_by_id(id)
72
- if not e:raise LookupError(f"knowledgebase({id}) not found.")
 
 
73
  def dfs_update(old, new):
74
- for k,v in new.items():
75
  if k not in old:
76
  old[k] = v
77
  continue
@@ -80,12 +82,12 @@ class KnowledgebaseService(CommonService):
80
  dfs_update(old[k], v)
81
  elif isinstance(v, list):
82
  assert isinstance(old[k], list)
83
- old[k] = list(set(old[k]+v))
84
- else: old[k] = v
 
85
  dfs_update(m.parser_config, config)
86
  cls.update_by_id(id, {"parser_config": m.parser_config})
87
 
88
-
89
  @classmethod
90
  @DB.connection_context()
91
  def get_field_map(cls, ids):
@@ -94,4 +96,3 @@ class KnowledgebaseService(CommonService):
94
  if k.parser_config and "field_map" in k.parser_config:
95
  conf.update(k.parser_config["field_map"])
96
  return conf
97
-
 
55
  cls.model.chunk_num,
56
  cls.model.parser_id,
57
  cls.model.parser_config]
58
+ kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
59
  (cls.model.id == kb_id),
60
  (cls.model.status == StatusEnum.VALID.value)
61
  )
 
69
  @DB.connection_context()
70
  def update_parser_config(cls, id, config):
71
  e, m = cls.get_by_id(id)
72
+ if not e:
73
+ raise LookupError(f"knowledgebase({id}) not found.")
74
+
75
  def dfs_update(old, new):
76
+ for k, v in new.items():
77
  if k not in old:
78
  old[k] = v
79
  continue
 
82
  dfs_update(old[k], v)
83
  elif isinstance(v, list):
84
  assert isinstance(old[k], list)
85
+ old[k] = list(set(old[k] + v))
86
+ else:
87
+ old[k] = v
88
  dfs_update(m.parser_config, config)
89
  cls.update_by_id(id, {"parser_config": m.parser_config})
90
 
 
91
  @classmethod
92
  @DB.connection_context()
93
  def get_field_map(cls, ids):
 
96
  if k.parser_config and "field_map" in k.parser_config:
97
  conf.update(k.parser_config["field_map"])
98
  return conf
 
api/db/services/llm_service.py CHANGED
@@ -59,7 +59,8 @@ class TenantLLMService(CommonService):
59
 
60
  @classmethod
61
  @DB.connection_context()
62
- def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese"):
 
63
  e, tenant = TenantService.get_by_id(tenant_id)
64
  if not e:
65
  raise LookupError("Tenant not found")
@@ -126,29 +127,39 @@ class LLMBundle(object):
126
  self.tenant_id = tenant_id
127
  self.llm_type = llm_type
128
  self.llm_name = llm_name
129
- self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang)
130
- assert self.mdl, "Can't find mole for {}/{}/{}".format(tenant_id, llm_type, llm_name)
 
 
131
 
132
  def encode(self, texts: list, batch_size=32):
133
  emd, used_tokens = self.mdl.encode(texts, batch_size)
134
- if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
135
- database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
 
 
136
  return emd, used_tokens
137
 
138
  def encode_queries(self, query: str):
139
  emd, used_tokens = self.mdl.encode_queries(query)
140
- if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
141
- database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
 
 
142
  return emd, used_tokens
143
 
144
  def describe(self, image, max_tokens=300):
145
  txt, used_tokens = self.mdl.describe(image, max_tokens)
146
- if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
147
- database_logger.error("Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id))
 
 
148
  return txt
149
 
150
  def chat(self, system, history, gen_conf):
151
  txt, used_tokens = self.mdl.chat(system, history, gen_conf)
152
- if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
153
- database_logger.error("Can't update token usage for {}/CHAT".format(self.tenant_id))
 
 
154
  return txt
 
59
 
60
  @classmethod
61
  @DB.connection_context()
62
+ def model_instance(cls, tenant_id, llm_type,
63
+ llm_name=None, lang="Chinese"):
64
  e, tenant = TenantService.get_by_id(tenant_id)
65
  if not e:
66
  raise LookupError("Tenant not found")
 
127
  self.tenant_id = tenant_id
128
  self.llm_type = llm_type
129
  self.llm_name = llm_name
130
+ self.mdl = TenantLLMService.model_instance(
131
+ tenant_id, llm_type, llm_name, lang=lang)
132
+ assert self.mdl, "Can't find mole for {}/{}/{}".format(
133
+ tenant_id, llm_type, llm_name)
134
 
135
  def encode(self, texts: list, batch_size=32):
136
  emd, used_tokens = self.mdl.encode(texts, batch_size)
137
+ if TenantLLMService.increase_usage(
138
+ self.tenant_id, self.llm_type, used_tokens):
139
+ database_logger.error(
140
+ "Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
141
  return emd, used_tokens
142
 
143
  def encode_queries(self, query: str):
144
  emd, used_tokens = self.mdl.encode_queries(query)
145
+ if TenantLLMService.increase_usage(
146
+ self.tenant_id, self.llm_type, used_tokens):
147
+ database_logger.error(
148
+ "Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
149
  return emd, used_tokens
150
 
151
  def describe(self, image, max_tokens=300):
152
  txt, used_tokens = self.mdl.describe(image, max_tokens)
153
+ if not TenantLLMService.increase_usage(
154
+ self.tenant_id, self.llm_type, used_tokens):
155
+ database_logger.error(
156
+ "Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id))
157
  return txt
158
 
159
  def chat(self, system, history, gen_conf):
160
  txt, used_tokens = self.mdl.chat(system, history, gen_conf)
161
+ if TenantLLMService.increase_usage(
162
+ self.tenant_id, self.llm_type, used_tokens, self.llm_name):
163
+ database_logger.error(
164
+ "Can't update token usage for {}/CHAT".format(self.tenant_id))
165
  return txt
api/db/services/user_service.py CHANGED
@@ -54,7 +54,8 @@ class UserService(CommonService):
54
  if "id" not in kwargs:
55
  kwargs["id"] = get_uuid()
56
  if "password" in kwargs:
57
- kwargs["password"] = generate_password_hash(str(kwargs["password"]))
 
58
 
59
  kwargs["create_time"] = current_timestamp()
60
  kwargs["create_date"] = datetime_format(datetime.now())
@@ -63,12 +64,12 @@ class UserService(CommonService):
63
  obj = cls.model(**kwargs).save(force_insert=True)
64
  return obj
65
 
66
-
67
  @classmethod
68
  @DB.connection_context()
69
  def delete_user(cls, user_ids, update_user_dict):
70
  with DB.atomic():
71
- cls.model.update({"status": 0}).where(cls.model.id.in_(user_ids)).execute()
 
72
 
73
  @classmethod
74
  @DB.connection_context()
@@ -77,7 +78,8 @@ class UserService(CommonService):
77
  if user_dict:
78
  user_dict["update_time"] = current_timestamp()
79
  user_dict["update_date"] = datetime_format(datetime.now())
80
- cls.model.update(user_dict).where(cls.model.id == user_id).execute()
 
81
 
82
 
83
  class TenantService(CommonService):
@@ -86,25 +88,42 @@ class TenantService(CommonService):
86
  @classmethod
87
  @DB.connection_context()
88
  def get_by_user_id(cls, user_id):
89
- fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role]
90
- return list(cls.model.select(*fields)\
91
- .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
92
- .where(cls.model.status == StatusEnum.VALID.value).dicts())
 
 
 
93
 
94
  @classmethod
95
  @DB.connection_context()
96
  def get_joined_tenants_by_user_id(cls, user_id):
97
- fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role]
98
- return list(cls.model.select(*fields)\
99
- .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role==UserTenantRole.NORMAL.value)))\
100
- .where(cls.model.status == StatusEnum.VALID.value).dicts())
 
 
101
 
102
  @classmethod
103
  @DB.connection_context()
104
  def decrease(cls, user_id, num):
105
  num = cls.model.update(credit=cls.model.credit - num).where(
106
  cls.model.id == user_id).execute()
107
- if num == 0: raise LookupError("Tenant not found which is supposed to be there")
 
 
108
 
109
  class UserTenantService(CommonService):
110
  model = UserTenant
 
54
  if "id" not in kwargs:
55
  kwargs["id"] = get_uuid()
56
  if "password" in kwargs:
57
+ kwargs["password"] = generate_password_hash(
58
+ str(kwargs["password"]))
59
 
60
  kwargs["create_time"] = current_timestamp()
61
  kwargs["create_date"] = datetime_format(datetime.now())
 
64
  obj = cls.model(**kwargs).save(force_insert=True)
65
  return obj
66
 
 
67
  @classmethod
68
  @DB.connection_context()
69
  def delete_user(cls, user_ids, update_user_dict):
70
  with DB.atomic():
71
+ cls.model.update({"status": 0}).where(
72
+ cls.model.id.in_(user_ids)).execute()
73
 
74
  @classmethod
75
  @DB.connection_context()
 
78
  if user_dict:
79
  user_dict["update_time"] = current_timestamp()
80
  user_dict["update_date"] = datetime_format(datetime.now())
81
+ cls.model.update(user_dict).where(
82
+ cls.model.id == user_id).execute()
83
 
84
 
85
  class TenantService(CommonService):
 
88
  @classmethod
89
  @DB.connection_context()
90
  def get_by_user_id(cls, user_id):
91
+ fields = [
92
+ cls.model.id.alias("tenant_id"),
93
+ cls.model.name,
94
+ cls.model.llm_id,
95
+ cls.model.embd_id,
96
+ cls.model.asr_id,
97
+ cls.model.img2txt_id,
98
+ cls.model.parser_ids,
99
+ UserTenant.role]
100
+ return list(cls.model.select(*fields)
101
+ .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
102
+ .where(cls.model.status == StatusEnum.VALID.value).dicts())
103
 
104
  @classmethod
105
  @DB.connection_context()
106
  def get_joined_tenants_by_user_id(cls, user_id):
107
+ fields = [
108
+ cls.model.id.alias("tenant_id"),
109
+ cls.model.name,
110
+ cls.model.llm_id,
111
+ cls.model.embd_id,
112
+ cls.model.asr_id,
113
+ cls.model.img2txt_id,
114
+ UserTenant.role]
115
+ return list(cls.model.select(*fields)
116
+ .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role == UserTenantRole.NORMAL.value)))
117
+ .where(cls.model.status == StatusEnum.VALID.value).dicts())
118
 
119
  @classmethod
120
  @DB.connection_context()
121
  def decrease(cls, user_id, num):
122
  num = cls.model.update(credit=cls.model.credit - num).where(
123
  cls.model.id == user_id).execute()
124
+ if num == 0:
125
+ raise LookupError("Tenant not found which is supposed to be there")
126
+
127
 
128
  class UserTenantService(CommonService):
129
  model = UserTenant
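
Reviewer note on UserService/TenantService: the re-wrapped queries above keep the original peewee semantics (soft delete via a status flag, tenant lookup via a join on UserTenant); only the line breaks change. A minimal, self-contained sketch of that pattern, assuming peewee and toy models (the real models live in api/db/db_models.py; field names below are illustrative):

from peewee import (SqliteDatabase, Model, CharField, IntegerField,
                    ForeignKeyField)

db = SqliteDatabase(":memory:")

class Tenant(Model):
    name = CharField()
    status = IntegerField(default=1)
    class Meta:
        database = db

class UserTenant(Model):
    user_id = CharField()
    tenant = ForeignKeyField(Tenant)
    role = CharField(default="normal")
    class Meta:
        database = db

db.create_tables([Tenant, UserTenant])
t = Tenant.create(name="demo")
UserTenant.create(user_id="u1", tenant=t)

# soft delete, as in UserService.delete_user: flip the status flag instead of dropping rows
Tenant.update({"status": 0}).where(Tenant.id.in_([t.id])).execute()

# join + .dicts(), as in TenantService.get_by_user_id
rows = list(Tenant.select(Tenant.id.alias("tenant_id"), Tenant.name, UserTenant.role)
            .join(UserTenant, on=(Tenant.id == UserTenant.tenant))
            .where(UserTenant.user_id == "u1")
            .dicts())
print(rows)
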
api/settings.py CHANGED
@@ -13,16 +13,22 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
 
16
  import os
17
 
18
  from enum import IntEnum, Enum
19
 
20
- from api.utils import get_base_config,decrypt_database_config
21
  from api.utils.file_utils import get_project_base_directory
22
  from api.utils.log_utils import LoggerFactory, getLogger
23
 
24
  # Logger
25
- LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "api"))
26
  # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
27
  LoggerFactory.LEVEL = 10
28
 
@@ -86,7 +92,9 @@ default_llm = {
86
  LLM = get_base_config("user_default_llm", {})
87
  LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
88
  if LLM_FACTORY not in default_llm:
89
- print("\33[91m【ERROR】\33[0m:", f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.")
90
  LLM_FACTORY = "Tongyi-Qianwen"
91
  CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
92
  EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"]
@@ -94,7 +102,9 @@ ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
94
  IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
95
 
96
  API_KEY = LLM.get("api_key", "")
97
- PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One")
98
 
99
  # distribution
100
  DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
@@ -103,13 +113,25 @@ RAG_FLOW_UPDATE_CHECK = False
103
  HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
104
  HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
105
 
106
- SECRET_KEY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key", "infiniflow")
107
- TOKEN_EXPIRE_IN = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("token_expires_in", 3600)
108
-
109
- NGINX_HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("host") or HOST
110
- NGINX_HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("http_port") or HTTP_PORT
111
-
112
- RANDOM_INSTANCE_ID = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("random_instance_id", False)
113
 
114
  PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
115
  PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")
@@ -124,7 +146,9 @@ UPLOAD_DATA_FROM_CLIENT = True
124
  AUTHENTICATION_CONF = get_base_config("authentication", {})
125
 
126
  # client
127
- CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get("client", {}).get("switch", False)
128
  HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
129
  GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
130
  WECHAT_OAUTH = get_base_config("oauth", {}).get("wechat")
@@ -147,12 +171,10 @@ USE_AUTHENTICATION = False
147
  USE_DATA_AUTHENTICATION = False
148
  AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True
149
  USE_DEFAULT_TIMEOUT = False
150
- AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s
151
  PRIVILEGE_COMMAND_WHITELIST = []
152
  CHECK_NODES_IDENTITY = False
153
 
154
- from rag.nlp import search
155
- from rag.utils import ELASTICSEARCH
156
  retrievaler = search.Dealer(ELASTICSEARCH)
157
 
158
 
@@ -162,7 +184,7 @@ class CustomEnum(Enum):
162
  try:
163
  cls(value)
164
  return True
165
- except:
166
  return False
167
 
168
  @classmethod
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
+ from rag.utils import ELASTICSEARCH
17
+ from rag.nlp import search
18
  import os
19
 
20
  from enum import IntEnum, Enum
21
 
22
+ from api.utils import get_base_config, decrypt_database_config
23
  from api.utils.file_utils import get_project_base_directory
24
  from api.utils.log_utils import LoggerFactory, getLogger
25
 
26
  # Logger
27
+ LoggerFactory.set_directory(
28
+ os.path.join(
29
+ get_project_base_directory(),
30
+ "logs",
31
+ "api"))
32
  # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
33
  LoggerFactory.LEVEL = 10
34
 
 
92
  LLM = get_base_config("user_default_llm", {})
93
  LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
94
  if LLM_FACTORY not in default_llm:
95
+ print(
96
+ "\33[91m【ERROR】\33[0m:",
97
+ f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.")
98
  LLM_FACTORY = "Tongyi-Qianwen"
99
  CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
100
  EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"]
 
102
  IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
103
 
104
  API_KEY = LLM.get("api_key", "")
105
+ PARSERS = LLM.get(
106
+ "parsers",
107
+ "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One")
108
 
109
  # distribution
110
  DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
 
113
  HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
114
  HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
115
 
116
+ SECRET_KEY = get_base_config(
117
+ RAG_FLOW_SERVICE_NAME,
118
+ {}).get(
119
+ "secret_key",
120
+ "infiniflow")
121
+ TOKEN_EXPIRE_IN = get_base_config(
122
+ RAG_FLOW_SERVICE_NAME, {}).get(
123
+ "token_expires_in", 3600)
124
+
125
+ NGINX_HOST = get_base_config(
126
+ RAG_FLOW_SERVICE_NAME, {}).get(
127
+ "nginx", {}).get("host") or HOST
128
+ NGINX_HTTP_PORT = get_base_config(
129
+ RAG_FLOW_SERVICE_NAME, {}).get(
130
+ "nginx", {}).get("http_port") or HTTP_PORT
131
+
132
+ RANDOM_INSTANCE_ID = get_base_config(
133
+ RAG_FLOW_SERVICE_NAME, {}).get(
134
+ "random_instance_id", False)
135
 
136
  PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
137
  PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")
 
146
  AUTHENTICATION_CONF = get_base_config("authentication", {})
147
 
148
  # client
149
+ CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get(
150
+ "client", {}).get(
151
+ "switch", False)
152
  HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
153
  GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
154
  WECHAT_OAUTH = get_base_config("oauth", {}).get("wechat")
 
171
  USE_DATA_AUTHENTICATION = False
172
  AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True
173
  USE_DEFAULT_TIMEOUT = False
174
+ AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s
175
  PRIVILEGE_COMMAND_WHITELIST = []
176
  CHECK_NODES_IDENTITY = False
177
 
 
 
178
  retrievaler = search.Dealer(ELASTICSEARCH)
179
 
180
 
 
184
  try:
185
  cls(value)
186
  return True
187
+ except BaseException:
188
  return False
189
 
190
  @classmethod
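
Reviewer note on the settings re-wrap: the SECRET_KEY / TOKEN_EXPIRE_IN / NGINX_* chains above are unchanged in behaviour; each one is get_base_config(section, {}) followed by .get() with a default. A minimal sketch of how that fallback resolves, using plain dicts in place of the YAML-backed helper (section and key names here are illustrative):

service_conf = {"ragflow": {"host": "127.0.0.1", "http_port": 9380, "nginx": {}}}

def get_base_config(key, default=None, conf=service_conf):
    # stand-in for api.utils.get_base_config, which reads service_conf.yaml
    return conf.get(key, default)

HOST = get_base_config("ragflow", {}).get("host", "127.0.0.1")
HTTP_PORT = get_base_config("ragflow", {}).get("http_port")
SECRET_KEY = get_base_config("ragflow", {}).get("secret_key", "infiniflow")
# the empty "nginx" section yields None, so the expression falls back to HOST / HTTP_PORT
NGINX_HOST = get_base_config("ragflow", {}).get("nginx", {}).get("host") or HOST
NGINX_HTTP_PORT = get_base_config("ragflow", {}).get("nginx", {}).get("http_port") or HTTP_PORT
assert (NGINX_HOST, NGINX_HTTP_PORT) == ("127.0.0.1", 9380)
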
api/utils/__init__.py CHANGED
@@ -34,10 +34,12 @@ from . import file_utils
34
 
35
  SERVICE_CONF = "service_conf.yaml"
36
 
 
37
  def conf_realpath(conf_name):
38
  conf_path = f"conf/{conf_name}"
39
  return os.path.join(file_utils.get_project_base_directory(), conf_path)
40
 
 
41
  def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict:
42
  local_config = {}
43
  local_path = conf_realpath(f'local.{conf_name}')
@@ -62,7 +64,8 @@ def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict:
62
  return config.get(key, default) if key is not None else config
63
 
64
 
65
- use_deserialize_safe_module = get_base_config('use_deserialize_safe_module', False)
 
66
 
67
 
68
  class CoordinationCommunicationProtocol(object):
@@ -93,7 +96,8 @@ class BaseType:
93
  data[_k] = _dict(vv)
94
  else:
95
  data = obj
96
- return {"type": obj.__class__.__name__, "data": data, "module": module}
 
97
  return _dict(self)
98
 
99
 
@@ -129,7 +133,8 @@ def rag_uuid():
129
 
130
 
131
  def string_to_bytes(string):
132
- return string if isinstance(string, bytes) else string.encode(encoding="utf-8")
 
133
 
134
 
135
  def bytes_to_string(byte):
@@ -137,7 +142,11 @@ def bytes_to_string(byte):
137
 
138
 
139
  def json_dumps(src, byte=False, indent=None, with_type=False):
140
- dest = json.dumps(src, indent=indent, cls=CustomJSONEncoder, with_type=with_type)
141
  if byte:
142
  dest = string_to_bytes(dest)
143
  return dest
@@ -146,7 +155,8 @@ def json_dumps(src, byte=False, indent=None, with_type=False):
146
  def json_loads(src, object_hook=None, object_pairs_hook=None):
147
  if isinstance(src, bytes):
148
  src = bytes_to_string(src)
149
- return json.loads(src, object_hook=object_hook, object_pairs_hook=object_pairs_hook)
 
150
 
151
 
152
  def current_timestamp():
@@ -177,7 +187,9 @@ def serialize_b64(src, to_str=False):
177
 
178
 
179
  def deserialize_b64(src):
180
- src = base64.b64decode(string_to_bytes(src) if isinstance(src, str) else src)
  if use_deserialize_safe_module:
182
  return restricted_loads(src)
183
  return pickle.loads(src)
@@ -237,12 +249,14 @@ def get_lan_ip():
237
  pass
238
  return ip or ''
239
 
 
240
  def from_dict_hook(in_dict: dict):
241
  if "type" in in_dict and "data" in in_dict:
242
  if in_dict["module"] is None:
243
  return in_dict["data"]
244
  else:
245
- return getattr(importlib.import_module(in_dict["module"]), in_dict["type"])(**in_dict["data"])
 
246
  else:
247
  return in_dict
248
 
@@ -259,12 +273,16 @@ def decrypt_database_password(password):
259
  raise ValueError("No private key")
260
 
261
  module_fun = encrypt_module.split("#")
262
- pwdecrypt_fun = getattr(importlib.import_module(module_fun[0]), module_fun[1])
 
 
 
263
 
264
  return pwdecrypt_fun(private_key, password)
265
 
266
 
267
- def decrypt_database_config(database=None, passwd_key="password", name="database"):
 
268
  if not database:
269
  database = get_base_config(name, {})
270
 
@@ -275,7 +293,8 @@ def decrypt_database_config(database=None, passwd_key="password", name="database
275
  def update_config(key, value, conf_name=SERVICE_CONF):
276
  conf_path = conf_realpath(conf_name=conf_name)
277
  if not os.path.isabs(conf_path):
278
- conf_path = os.path.join(file_utils.get_project_base_directory(), conf_path)
 
279
 
280
  with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
281
  config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
@@ -288,7 +307,8 @@ def get_uuid():
288
 
289
 
290
  def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
291
- return datetime.datetime(date_time.year, date_time.month, date_time.day, date_time.hour, date_time.minute, date_time.second)
 
292
 
293
 
294
  def get_format_time() -> datetime.datetime:
@@ -307,14 +327,19 @@ def elapsed2time(elapsed):
307
 
308
 
309
  def decrypt(line):
310
- file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
311
  rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
312
  cipher = Cipher_pkcs1_v1_5.new(rsa_key)
313
- return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8')
 
314
 
315
 
316
  def download_img(url):
317
- if not url: return ""
 
318
  response = requests.get(url)
319
  return "data:" + \
320
  response.headers.get('Content-Type', 'image/jpg') + ";" + \
 
34
 
35
  SERVICE_CONF = "service_conf.yaml"
36
 
37
+
38
  def conf_realpath(conf_name):
39
  conf_path = f"conf/{conf_name}"
40
  return os.path.join(file_utils.get_project_base_directory(), conf_path)
41
 
42
+
43
  def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict:
44
  local_config = {}
45
  local_path = conf_realpath(f'local.{conf_name}')
 
64
  return config.get(key, default) if key is not None else config
65
 
66
 
67
+ use_deserialize_safe_module = get_base_config(
68
+ 'use_deserialize_safe_module', False)
69
 
70
 
71
  class CoordinationCommunicationProtocol(object):
 
96
  data[_k] = _dict(vv)
97
  else:
98
  data = obj
99
+ return {"type": obj.__class__.__name__,
100
+ "data": data, "module": module}
101
  return _dict(self)
102
 
103
 
 
133
 
134
 
135
  def string_to_bytes(string):
136
+ return string if isinstance(
137
+ string, bytes) else string.encode(encoding="utf-8")
138
 
139
 
140
  def bytes_to_string(byte):
 
142
 
143
 
144
  def json_dumps(src, byte=False, indent=None, with_type=False):
145
+ dest = json.dumps(
146
+ src,
147
+ indent=indent,
148
+ cls=CustomJSONEncoder,
149
+ with_type=with_type)
150
  if byte:
151
  dest = string_to_bytes(dest)
152
  return dest
 
155
  def json_loads(src, object_hook=None, object_pairs_hook=None):
156
  if isinstance(src, bytes):
157
  src = bytes_to_string(src)
158
+ return json.loads(src, object_hook=object_hook,
159
+ object_pairs_hook=object_pairs_hook)
160
 
161
 
162
  def current_timestamp():
 
187
 
188
 
189
  def deserialize_b64(src):
190
+ src = base64.b64decode(
191
+ string_to_bytes(src) if isinstance(
192
+ src, str) else src)
193
  if use_deserialize_safe_module:
194
  return restricted_loads(src)
195
  return pickle.loads(src)
 
249
  pass
250
  return ip or ''
251
 
252
+
253
  def from_dict_hook(in_dict: dict):
254
  if "type" in in_dict and "data" in in_dict:
255
  if in_dict["module"] is None:
256
  return in_dict["data"]
257
  else:
258
+ return getattr(importlib.import_module(
259
+ in_dict["module"]), in_dict["type"])(**in_dict["data"])
260
  else:
261
  return in_dict
262
 
 
273
  raise ValueError("No private key")
274
 
275
  module_fun = encrypt_module.split("#")
276
+ pwdecrypt_fun = getattr(
277
+ importlib.import_module(
278
+ module_fun[0]),
279
+ module_fun[1])
280
 
281
  return pwdecrypt_fun(private_key, password)
282
 
283
 
284
+ def decrypt_database_config(
285
+ database=None, passwd_key="password", name="database"):
286
  if not database:
287
  database = get_base_config(name, {})
288
 
 
293
  def update_config(key, value, conf_name=SERVICE_CONF):
294
  conf_path = conf_realpath(conf_name=conf_name)
295
  if not os.path.isabs(conf_path):
296
+ conf_path = os.path.join(
297
+ file_utils.get_project_base_directory(), conf_path)
298
 
299
  with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
300
  config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
 
307
 
308
 
309
  def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
310
+ return datetime.datetime(date_time.year, date_time.month, date_time.day,
311
+ date_time.hour, date_time.minute, date_time.second)
312
 
313
 
314
  def get_format_time() -> datetime.datetime:
 
327
 
328
 
329
  def decrypt(line):
330
+ file_path = os.path.join(
331
+ file_utils.get_project_base_directory(),
332
+ "conf",
333
+ "private.pem")
334
  rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
335
  cipher = Cipher_pkcs1_v1_5.new(rsa_key)
336
+ return cipher.decrypt(base64.b64decode(
337
+ line), "Fail to decrypt password!").decode('utf-8')
338
 
339
 
340
  def download_img(url):
341
+ if not url:
342
+ return ""
343
  response = requests.get(url)
344
  return "data:" + \
345
  response.headers.get('Content-Type', 'image/jpg') + ";" + \
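
Reviewer note on api/utils: the deserialize_b64 hunk above only re-wraps the base64.b64decode call. For reference, a simplified, standard-library stand-in for the serialize_b64 / deserialize_b64 round trip this module provides (the real deserialize_b64 optionally routes through restricted_loads when use_deserialize_safe_module is enabled):

import base64
import pickle

def serialize_b64(obj, to_str=False):
    # pickle the object, then base64-encode it for safe transport/storage
    dest = base64.b64encode(pickle.dumps(obj))
    return dest.decode("utf-8") if to_str else dest

def deserialize_b64(src):
    raw = base64.b64decode(src.encode("utf-8") if isinstance(src, str) else src)
    return pickle.loads(raw)

payload = {"id": "u1", "roles": ["owner"]}
assert deserialize_b64(serialize_b64(payload, to_str=True)) == payload
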
api/utils/api_utils.py CHANGED
@@ -19,7 +19,7 @@ import time
19
  from functools import wraps
20
  from io import BytesIO
21
  from flask import (
22
- Response, jsonify, send_file,make_response,
23
  request as flask_request,
24
  )
25
  from werkzeug.http import HTTP_STATUS_CODES
@@ -29,7 +29,7 @@ from api.versions import get_rag_version
29
  from api.settings import RetCode
30
  from api.settings import (
31
  REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
32
- stat_logger,CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
33
  )
34
  import requests
35
  import functools
@@ -40,14 +40,21 @@ from hmac import HMAC
40
  from urllib.parse import quote, urlencode
41
 
42
 
43
- requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
 
44
 
45
 
46
  def request(**kwargs):
47
  sess = requests.Session()
48
  stream = kwargs.pop('stream', sess.stream)
49
  timeout = kwargs.pop('timeout', None)
50
- kwargs['headers'] = {k.replace('_', '-').upper(): v for k, v in kwargs.get('headers', {}).items()}
51
  prepped = requests.Request(**kwargs).prepare()
52
 
53
  if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
@@ -59,7 +66,11 @@ def request(**kwargs):
59
  HTTP_APP_KEY.encode('ascii'),
60
  prepped.path_url.encode('ascii'),
61
  prepped.body if kwargs.get('json') else b'',
62
- urlencode(sorted(kwargs['data'].items()), quote_via=quote, safe='-._~').encode('ascii')
63
  if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'',
64
  ]), 'sha1').digest()).decode('ascii')
65
 
@@ -88,11 +99,12 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
88
  return max(0, countdown)
89
 
90
 
91
- def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id=None, meta=None):
 
92
  import re
93
  result_dict = {
94
  "retcode": retcode,
95
- "retmsg":retmsg,
96
  # "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE),
97
  "data": data,
98
  "jobId": job_id,
@@ -107,9 +119,17 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id
107
  response[key] = value
108
  return jsonify(response)
109
 
110
- def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! Data missing!'):
 
 
111
  import re
112
- result_dict = {"retcode": retcode, "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE)}
113
  response = {}
114
  for key, value in result_dict.items():
115
  if value is None and key != "retcode":
@@ -118,15 +138,17 @@ def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! Data missin
118
  response[key] = value
119
  return jsonify(response)
120
 
 
121
  def server_error_response(e):
122
  stat_logger.exception(e)
123
  try:
124
- if e.code==401:
125
  return get_json_result(retcode=401, retmsg=repr(e))
126
- except:
127
  pass
128
  if len(e.args) > 1:
129
- return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
 
130
  return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))
131
 
132
 
@@ -162,10 +184,13 @@ def validate_request(*args, **kwargs):
162
  if no_arguments or error_arguments:
163
  error_string = ""
164
  if no_arguments:
165
- error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
 
166
  if error_arguments:
167
- error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
168
- return get_json_result(retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
 
 
169
  return func(*_args, **_kwargs)
170
  return decorated_function
171
  return wrapper
@@ -193,7 +218,8 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None):
193
  return jsonify(response)
194
 
195
 
196
- def cors_reponse(retcode=RetCode.SUCCESS, retmsg='success', data=None, auth=None):
 
197
  result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data}
198
  response_dict = {}
199
  for key, value in result_dict.items():
@@ -209,4 +235,4 @@ def cors_reponse(retcode=RetCode.SUCCESS, retmsg='success', data=None, auth=None
209
  response.headers["Access-Control-Allow-Headers"] = "*"
210
  response.headers["Access-Control-Allow-Headers"] = "*"
211
  response.headers["Access-Control-Expose-Headers"] = "Authorization"
212
- return response
 
19
  from functools import wraps
20
  from io import BytesIO
21
  from flask import (
22
+ Response, jsonify, send_file, make_response,
23
  request as flask_request,
24
  )
25
  from werkzeug.http import HTTP_STATUS_CODES
 
29
  from api.settings import RetCode
30
  from api.settings import (
31
  REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
32
+ stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
33
  )
34
  import requests
35
  import functools
 
40
  from urllib.parse import quote, urlencode
41
 
42
 
43
+ requests.models.complexjson.dumps = functools.partial(
44
+ json.dumps, cls=CustomJSONEncoder)
45
 
46
 
47
  def request(**kwargs):
48
  sess = requests.Session()
49
  stream = kwargs.pop('stream', sess.stream)
50
  timeout = kwargs.pop('timeout', None)
51
+ kwargs['headers'] = {
52
+ k.replace(
53
+ '_',
54
+ '-').upper(): v for k,
55
+ v in kwargs.get(
56
+ 'headers',
57
+ {}).items()}
58
  prepped = requests.Request(**kwargs).prepare()
59
 
60
  if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
 
66
  HTTP_APP_KEY.encode('ascii'),
67
  prepped.path_url.encode('ascii'),
68
  prepped.body if kwargs.get('json') else b'',
69
+ urlencode(
70
+ sorted(
71
+ kwargs['data'].items()),
72
+ quote_via=quote,
73
+ safe='-._~').encode('ascii')
74
  if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'',
75
  ]), 'sha1').digest()).decode('ascii')
76
 
 
99
  return max(0, countdown)
100
 
101
 
102
+ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success',
103
+ data=None, job_id=None, meta=None):
104
  import re
105
  result_dict = {
106
  "retcode": retcode,
107
+ "retmsg": retmsg,
108
  # "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE),
109
  "data": data,
110
  "jobId": job_id,
 
119
  response[key] = value
120
  return jsonify(response)
121
 
122
+
123
+ def get_data_error_result(retcode=RetCode.DATA_ERROR,
124
+ retmsg='Sorry! Data missing!'):
125
  import re
126
+ result_dict = {
127
+ "retcode": retcode,
128
+ "retmsg": re.sub(
129
+ r"rag",
130
+ "seceum",
131
+ retmsg,
132
+ flags=re.IGNORECASE)}
133
  response = {}
134
  for key, value in result_dict.items():
135
  if value is None and key != "retcode":
 
138
  response[key] = value
139
  return jsonify(response)
140
 
141
+
142
  def server_error_response(e):
143
  stat_logger.exception(e)
144
  try:
145
+ if e.code == 401:
146
  return get_json_result(retcode=401, retmsg=repr(e))
147
+ except BaseException:
148
  pass
149
  if len(e.args) > 1:
150
+ return get_json_result(
151
+ retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
152
  return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))
153
 
154
 
 
184
  if no_arguments or error_arguments:
185
  error_string = ""
186
  if no_arguments:
187
+ error_string += "required argument are missing: {}; ".format(
188
+ ",".join(no_arguments))
189
  if error_arguments:
190
+ error_string += "required argument values: {}".format(
191
+ ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
192
+ return get_json_result(
193
+ retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
194
  return func(*_args, **_kwargs)
195
  return decorated_function
196
  return wrapper
 
218
  return jsonify(response)
219
 
220
 
221
+ def cors_reponse(retcode=RetCode.SUCCESS,
222
+ retmsg='success', data=None, auth=None):
223
  result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data}
224
  response_dict = {}
225
  for key, value in result_dict.items():
 
235
  response.headers["Access-Control-Allow-Headers"] = "*"
236
  response.headers["Access-Control-Allow-Headers"] = "*"
237
  response.headers["Access-Control-Expose-Headers"] = "Authorization"
238
+ return response
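
Reviewer note on api_utils.request(): the signature block above base64-encodes an HMAC-SHA1 over newline-joined request fields (app key, request path, body, sorted form data, plus fields outside this hunk). A minimal sketch of that signing scheme, with an illustrative field list and throwaway inputs:

import base64
import hashlib
from hmac import HMAC
from urllib.parse import quote, urlencode

def sign(secret_key: str, app_key: str, path: str, body: bytes, form: dict) -> str:
    # join the signed fields with "\n", HMAC-SHA1 them, then base64-encode the digest
    parts = [
        app_key.encode("ascii"),
        path.encode("ascii"),
        body,
        urlencode(sorted(form.items()), quote_via=quote, safe="-._~").encode("ascii"),
    ]
    digest = HMAC(secret_key.encode("ascii"), b"\n".join(parts), hashlib.sha1).digest()
    return base64.b64encode(digest).decode("ascii")

print(sign("infiniflow", "demo-app-key", "/v1/kb/list", b"", {"page": "1"}))
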
api/utils/file_utils.py CHANGED
@@ -29,6 +29,7 @@ from api.db import FileType
29
  PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
30
  RAG_BASE = os.getenv("RAG_BASE")
31
 
 
32
  def get_project_base_directory(*args):
33
  global PROJECT_BASE
34
  if PROJECT_BASE is None:
@@ -65,7 +66,6 @@ def get_rag_python_directory(*args):
65
  return get_rag_directory("python", *args)
66
 
67
 
68
-
69
  @cached(cache=LRUCache(maxsize=10))
70
  def load_json_conf(conf_path):
71
  if os.path.isabs(conf_path):
@@ -146,10 +146,12 @@ def filename_type(filename):
146
  if re.match(r".*\.pdf$", filename):
147
  return FileType.PDF.value
148
 
149
- if re.match(r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md)$", filename):
 
150
  return FileType.DOC.value
151
 
152
- if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
 
153
  return FileType.AURAL.value
154
 
155
  if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
@@ -164,14 +166,16 @@ def thumbnail(filename, blob):
164
  buffered = BytesIO()
165
  Image.frombytes("RGB", [pix.width, pix.height],
166
  pix.samples).save(buffered, format="png")
167
- return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")
 
168
 
169
  if re.match(r".*\.(jpg|jpeg|png|tif|gif|icon|ico|webp)$", filename):
170
  image = Image.open(BytesIO(blob))
171
  image.thumbnail((30, 30))
172
  buffered = BytesIO()
173
  image.save(buffered, format="png")
174
- return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")
 
175
 
176
  if re.match(r".*\.(ppt|pptx)$", filename):
177
  import aspose.slides as slides
@@ -179,8 +183,10 @@ def thumbnail(filename, blob):
179
  try:
180
  with slides.Presentation(BytesIO(blob)) as presentation:
181
  buffered = BytesIO()
182
- presentation.slides[0].get_thumbnail(0.03, 0.03).save(buffered, drawing.imaging.ImageFormat.png)
183
- return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")
 
 
184
  except Exception as e:
185
  pass
186
 
@@ -190,6 +196,3 @@ def traversal_files(base):
190
  for f in fs:
191
  fullname = os.path.join(root, f)
192
  yield fullname
193
-
194
-
195
-
 
29
  PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
30
  RAG_BASE = os.getenv("RAG_BASE")
31
 
32
+
33
  def get_project_base_directory(*args):
34
  global PROJECT_BASE
35
  if PROJECT_BASE is None:
 
66
  return get_rag_directory("python", *args)
67
 
68
 
 
69
  @cached(cache=LRUCache(maxsize=10))
70
  def load_json_conf(conf_path):
71
  if os.path.isabs(conf_path):
 
146
  if re.match(r".*\.pdf$", filename):
147
  return FileType.PDF.value
148
 
149
+ if re.match(
150
+ r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md)$", filename):
151
  return FileType.DOC.value
152
 
153
+ if re.match(
154
+ r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
155
  return FileType.AURAL.value
156
 
157
  if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
 
166
  buffered = BytesIO()
167
  Image.frombytes("RGB", [pix.width, pix.height],
168
  pix.samples).save(buffered, format="png")
169
+ return "data:image/png;base64," + \
170
+ base64.b64encode(buffered.getvalue()).decode("utf-8")
171
 
172
  if re.match(r".*\.(jpg|jpeg|png|tif|gif|icon|ico|webp)$", filename):
173
  image = Image.open(BytesIO(blob))
174
  image.thumbnail((30, 30))
175
  buffered = BytesIO()
176
  image.save(buffered, format="png")
177
+ return "data:image/png;base64," + \
178
+ base64.b64encode(buffered.getvalue()).decode("utf-8")
179
 
180
  if re.match(r".*\.(ppt|pptx)$", filename):
181
  import aspose.slides as slides
 
183
  try:
184
  with slides.Presentation(BytesIO(blob)) as presentation:
185
  buffered = BytesIO()
186
+ presentation.slides[0].get_thumbnail(0.03, 0.03).save(
187
+ buffered, drawing.imaging.ImageFormat.png)
188
+ return "data:image/png;base64," + \
189
+ base64.b64encode(buffered.getvalue()).decode("utf-8")
190
  except Exception as e:
191
  pass
192
 
 
196
  for f in fs:
197
  fullname = os.path.join(root, f)
198
  yield fullname
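
Reviewer note on file_utils.thumbnail(): the re-wrapped returns above all produce the same base64 PNG data URI. A minimal sketch of the image branch on its own, assuming Pillow (the PDF branch additionally needs PyMuPDF and the PPT branch aspose.slides, both omitted here):

import base64
from io import BytesIO
from PIL import Image

def image_thumbnail_data_uri(blob: bytes) -> str:
    image = Image.open(BytesIO(blob))
    image.thumbnail((30, 30))   # same 30x30 bound used above
    buffered = BytesIO()
    image.save(buffered, format="png")
    return "data:image/png;base64," + \
        base64.b64encode(buffered.getvalue()).decode("utf-8")

# usage: image_thumbnail_data_uri(open("cover.jpg", "rb").read())  # path is a placeholder
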
 
 
 
api/utils/log_utils.py CHANGED
@@ -23,6 +23,7 @@ from threading import RLock
23
 
24
  from api.utils import file_utils
25
 
 
26
  class LoggerFactory(object):
27
  TYPE = "FILE"
28
  LOG_FORMAT = "[%(levelname)s] [%(asctime)s] [jobId] [%(process)s:%(thread)s] - [%(module)s.%(funcName)s] [line:%(lineno)d]: %(message)s"
@@ -49,7 +50,8 @@ class LoggerFactory(object):
49
  schedule_logger_dict = {}
50
 
51
  @staticmethod
52
- def set_directory(directory=None, parent_log_dir=None, append_to_parent_log=None, force=False):
 
53
  if parent_log_dir:
54
  LoggerFactory.PARENT_LOG_DIR = parent_log_dir
55
  if append_to_parent_log:
@@ -66,11 +68,13 @@ class LoggerFactory(object):
66
  else:
67
  os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True)
68
  for loggerName, ghandler in LoggerFactory.global_handler_dict.items():
69
- for className, (logger, handler) in LoggerFactory.logger_dict.items():
 
70
  logger.removeHandler(ghandler)
71
  ghandler.close()
72
  LoggerFactory.global_handler_dict = {}
73
- for className, (logger, handler) in LoggerFactory.logger_dict.items():
 
74
  logger.removeHandler(handler)
75
  _handler = None
76
  if handler:
@@ -111,19 +115,23 @@ class LoggerFactory(object):
111
  if logger_name_key not in LoggerFactory.global_handler_dict:
112
  with LoggerFactory.lock:
113
  if logger_name_key not in LoggerFactory.global_handler_dict:
114
- handler = LoggerFactory.get_handler(logger_name, level, log_dir)
 
115
  LoggerFactory.global_handler_dict[logger_name_key] = handler
116
  return LoggerFactory.global_handler_dict[logger_name_key]
117
 
118
  @staticmethod
119
- def get_handler(class_name, level=None, log_dir=None, log_type=None, job_id=None):
 
120
  if not log_type:
121
  if not LoggerFactory.LOG_DIR or not class_name:
122
  return logging.StreamHandler()
123
  # return Diy_StreamHandler()
124
 
125
  if not log_dir:
126
- log_file = os.path.join(LoggerFactory.LOG_DIR, "{}.log".format(class_name))
 
 
127
  else:
128
  log_file = os.path.join(log_dir, "{}.log".format(class_name))
129
  else:
@@ -133,16 +141,16 @@ class LoggerFactory(object):
133
  os.makedirs(os.path.dirname(log_file), exist_ok=True)
134
  if LoggerFactory.log_share:
135
  handler = ROpenHandler(log_file,
136
- when='D',
137
- interval=1,
138
- backupCount=14,
139
- delay=True)
140
  else:
141
  handler = TimedRotatingFileHandler(log_file,
142
- when='D',
143
- interval=1,
144
- backupCount=14,
145
- delay=True)
146
  if level:
147
  handler.level = level
148
 
@@ -170,7 +178,9 @@ class LoggerFactory(object):
170
  for level in LoggerFactory.levels:
171
  if level >= LoggerFactory.LEVEL:
172
  level_logger_name = logging._levelToName[level]
173
- logger.addHandler(LoggerFactory.get_global_handler(level_logger_name, level))
 
 
174
  if LoggerFactory.append_to_parent_log and LoggerFactory.PARENT_LOG_DIR:
175
  for level in LoggerFactory.levels:
176
  if level >= LoggerFactory.LEVEL:
@@ -224,22 +234,26 @@ def start_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
224
  return f"{prefix}start to {msg}{suffix}"
225
 
226
 
227
- def successful_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
 
228
  prefix, suffix = base_msg(job, task, role, party_id, detail)
229
  return f"{prefix}{msg} successfully{suffix}"
230
 
231
 
232
- def warning_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
 
233
  prefix, suffix = base_msg(job, task, role, party_id, detail)
234
  return f"{prefix}{msg} is not effective{suffix}"
235
 
236
 
237
- def failed_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
 
238
  prefix, suffix = base_msg(job, task, role, party_id, detail)
239
  return f"{prefix}failed to {msg}{suffix}"
240
 
241
 
242
- def base_msg(job=None, task=None, role: str = None, party_id: typing.Union[str, int] = None, detail=None):
 
243
  if detail:
244
  detail_msg = f" detail: \n{detail}"
245
  else:
@@ -285,10 +299,14 @@ def get_job_logger(job_id, log_type):
285
  for job_log_dir in log_dirs:
286
  handler = LoggerFactory.get_handler(class_name=None, level=LoggerFactory.LEVEL,
287
  log_dir=job_log_dir, log_type=log_type, job_id=job_id)
288
- error_handler = LoggerFactory.get_handler(class_name=None, level=logging.ERROR, log_dir=job_log_dir, log_type=log_type, job_id=job_id)
289
  logger.addHandler(handler)
290
  logger.addHandler(error_handler)
291
  with LoggerFactory.lock:
292
  LoggerFactory.schedule_logger_dict[job_id + log_type] = logger
293
  return logger
294
-
 
23
 
24
  from api.utils import file_utils
25
 
26
+
27
  class LoggerFactory(object):
28
  TYPE = "FILE"
29
  LOG_FORMAT = "[%(levelname)s] [%(asctime)s] [jobId] [%(process)s:%(thread)s] - [%(module)s.%(funcName)s] [line:%(lineno)d]: %(message)s"
 
50
  schedule_logger_dict = {}
51
 
52
  @staticmethod
53
+ def set_directory(directory=None, parent_log_dir=None,
54
+ append_to_parent_log=None, force=False):
55
  if parent_log_dir:
56
  LoggerFactory.PARENT_LOG_DIR = parent_log_dir
57
  if append_to_parent_log:
 
68
  else:
69
  os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True)
70
  for loggerName, ghandler in LoggerFactory.global_handler_dict.items():
71
+ for className, (logger,
72
+ handler) in LoggerFactory.logger_dict.items():
73
  logger.removeHandler(ghandler)
74
  ghandler.close()
75
  LoggerFactory.global_handler_dict = {}
76
+ for className, (logger,
77
+ handler) in LoggerFactory.logger_dict.items():
78
  logger.removeHandler(handler)
79
  _handler = None
80
  if handler:
 
115
  if logger_name_key not in LoggerFactory.global_handler_dict:
116
  with LoggerFactory.lock:
117
  if logger_name_key not in LoggerFactory.global_handler_dict:
118
+ handler = LoggerFactory.get_handler(
119
+ logger_name, level, log_dir)
120
  LoggerFactory.global_handler_dict[logger_name_key] = handler
121
  return LoggerFactory.global_handler_dict[logger_name_key]
122
 
123
  @staticmethod
124
+ def get_handler(class_name, level=None, log_dir=None,
125
+ log_type=None, job_id=None):
126
  if not log_type:
127
  if not LoggerFactory.LOG_DIR or not class_name:
128
  return logging.StreamHandler()
129
  # return Diy_StreamHandler()
130
 
131
  if not log_dir:
132
+ log_file = os.path.join(
133
+ LoggerFactory.LOG_DIR,
134
+ "{}.log".format(class_name))
135
  else:
136
  log_file = os.path.join(log_dir, "{}.log".format(class_name))
137
  else:
 
141
  os.makedirs(os.path.dirname(log_file), exist_ok=True)
142
  if LoggerFactory.log_share:
143
  handler = ROpenHandler(log_file,
144
+ when='D',
145
+ interval=1,
146
+ backupCount=14,
147
+ delay=True)
148
  else:
149
  handler = TimedRotatingFileHandler(log_file,
150
+ when='D',
151
+ interval=1,
152
+ backupCount=14,
153
+ delay=True)
154
  if level:
155
  handler.level = level
156
 
 
178
  for level in LoggerFactory.levels:
179
  if level >= LoggerFactory.LEVEL:
180
  level_logger_name = logging._levelToName[level]
181
+ logger.addHandler(
182
+ LoggerFactory.get_global_handler(
183
+ level_logger_name, level))
184
  if LoggerFactory.append_to_parent_log and LoggerFactory.PARENT_LOG_DIR:
185
  for level in LoggerFactory.levels:
186
  if level >= LoggerFactory.LEVEL:
 
234
  return f"{prefix}start to {msg}{suffix}"
235
 
236
 
237
+ def successful_log(msg, job=None, task=None, role=None,
238
+ party_id=None, detail=None):
239
  prefix, suffix = base_msg(job, task, role, party_id, detail)
240
  return f"{prefix}{msg} successfully{suffix}"
241
 
242
 
243
+ def warning_log(msg, job=None, task=None, role=None,
244
+ party_id=None, detail=None):
245
  prefix, suffix = base_msg(job, task, role, party_id, detail)
246
  return f"{prefix}{msg} is not effective{suffix}"
247
 
248
 
249
+ def failed_log(msg, job=None, task=None, role=None,
250
+ party_id=None, detail=None):
251
  prefix, suffix = base_msg(job, task, role, party_id, detail)
252
  return f"{prefix}failed to {msg}{suffix}"
253
 
254
 
255
+ def base_msg(job=None, task=None, role: str = None,
256
+ party_id: typing.Union[str, int] = None, detail=None):
257
  if detail:
258
  detail_msg = f" detail: \n{detail}"
259
  else:
 
299
  for job_log_dir in log_dirs:
300
  handler = LoggerFactory.get_handler(class_name=None, level=LoggerFactory.LEVEL,
301
  log_dir=job_log_dir, log_type=log_type, job_id=job_id)
302
+ error_handler = LoggerFactory.get_handler(
303
+ class_name=None,
304
+ level=logging.ERROR,
305
+ log_dir=job_log_dir,
306
+ log_type=log_type,
307
+ job_id=job_id)
308
  logger.addHandler(handler)
309
  logger.addHandler(error_handler)
310
  with LoggerFactory.lock:
311
  LoggerFactory.schedule_logger_dict[job_id + log_type] = logger
312
  return logger
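
Reviewer note on log_utils: the handler construction above is indentation-only; rotation stays daily with 14 backups and lazy file creation. A standard-library sketch of an equivalent TimedRotatingFileHandler setup (logger name, file name, and format string are illustrative):

import logging
from logging.handlers import TimedRotatingFileHandler

def make_handler(log_file: str, level: int = logging.INFO) -> TimedRotatingFileHandler:
    handler = TimedRotatingFileHandler(log_file,
                                       when="D",          # rotate daily
                                       interval=1,
                                       backupCount=14,
                                       delay=True)        # open the file on first emit
    handler.setLevel(level)
    handler.setFormatter(logging.Formatter(
        "[%(levelname)s] [%(asctime)s] [%(module)s.%(funcName)s] [line:%(lineno)d]: %(message)s"))
    return handler

logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)
logger.addHandler(make_handler("demo.log"))
logger.info("rotating file handler configured")
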
 
api/utils/t_crypt.py CHANGED
@@ -1,18 +1,23 @@
1
- import base64, os, sys
 
 
2
  from Cryptodome.PublicKey import RSA
3
  from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
4
  from api.utils import decrypt, file_utils
5
 
 
6
  def crypt(line):
7
- file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")
 
 
 
8
  rsa_key = RSA.importKey(open(file_path).read())
9
  cipher = Cipher_pkcs1_v1_5.new(rsa_key)
10
- return base64.b64encode(cipher.encrypt(line.encode('utf-8'))).decode("utf-8")
11
-
12
 
13
 
14
  if __name__ == "__main__":
15
  pswd = crypt(sys.argv[1])
16
  print(pswd)
17
  print(decrypt(pswd))
18
-
 
1
+ import base64
2
+ import os
3
+ import sys
4
  from Cryptodome.PublicKey import RSA
5
  from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
6
  from api.utils import decrypt, file_utils
7
 
8
+
9
  def crypt(line):
10
+ file_path = os.path.join(
11
+ file_utils.get_project_base_directory(),
12
+ "conf",
13
+ "public.pem")
14
  rsa_key = RSA.importKey(open(file_path).read())
15
  cipher = Cipher_pkcs1_v1_5.new(rsa_key)
16
+ return base64.b64encode(cipher.encrypt(
17
+ line.encode('utf-8'))).decode("utf-8")
18
 
19
 
20
  if __name__ == "__main__":
21
  pswd = crypt(sys.argv[1])
22
  print(pswd)
23
  print(decrypt(pswd))
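
Reviewer note on t_crypt: crypt() RSA-encrypts its argument with conf/public.pem and api.utils.decrypt() reverses it with conf/private.pem. A self-contained round trip of the same PKCS#1 v1.5 scheme, assuming pycryptodomex and using a throwaway key pair instead of the project's PEM files:

import base64
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5

key = RSA.generate(2048)          # stand-in for the public.pem / private.pem pair

def crypt(line: str) -> str:
    cipher = Cipher_pkcs1_v1_5.new(key.publickey())
    return base64.b64encode(cipher.encrypt(line.encode("utf-8"))).decode("utf-8")

def decrypt(line: str) -> str:
    cipher = Cipher_pkcs1_v1_5.new(key)
    return cipher.decrypt(base64.b64decode(line), b"decrypt failed").decode("utf-8")

assert decrypt(crypt("s3cret")) == "s3cret"
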
 
deepdoc/parser/__init__.py CHANGED
@@ -4,5 +4,3 @@ from .pdf_parser import HuParser as PdfParser, PlainParser
4
  from .docx_parser import HuDocxParser as DocxParser
5
  from .excel_parser import HuExcelParser as ExcelParser
6
  from .ppt_parser import HuPptParser as PptParser
7
-
8
-
 
4
  from .docx_parser import HuDocxParser as DocxParser
5
  from .excel_parser import HuExcelParser as ExcelParser
6
  from .ppt_parser import HuPptParser as PptParser
 
 
deepdoc/parser/docx_parser.py CHANGED
@@ -99,12 +99,15 @@ class HuDocxParser:
99
  return ["\n".join(lines)]
100
 
101
  def __call__(self, fnm, from_page=0, to_page=100000):
102
- self.doc = Document(fnm) if isinstance(fnm, str) else Document(BytesIO(fnm))
 
103
  pn = 0
104
  secs = []
105
  for p in self.doc.paragraphs:
106
- if pn > to_page: break
107
- if from_page <= pn < to_page and p.text.strip(): secs.append((p.text, p.style.name))
 
 
108
  for run in p.runs:
109
  if 'lastRenderedPageBreak' in run._element.xml:
110
  pn += 1
 
99
  return ["\n".join(lines)]
100
 
101
  def __call__(self, fnm, from_page=0, to_page=100000):
102
+ self.doc = Document(fnm) if isinstance(
103
+ fnm, str) else Document(BytesIO(fnm))
104
  pn = 0
105
  secs = []
106
  for p in self.doc.paragraphs:
107
+ if pn > to_page:
108
+ break
109
+ if from_page <= pn < to_page and p.text.strip():
110
+ secs.append((p.text, p.style.name))
111
  for run in p.runs:
112
  if 'lastRenderedPageBreak' in run._element.xml:
113
  pn += 1
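
Reviewer note on docx_parser.__call__: the reformat above keeps the page-window logic intact - paragraphs are collected only while from_page <= pn < to_page, and pn advances whenever a run carries Word's lastRenderedPageBreak marker. A minimal sketch of that traversal with python-docx (the file path is a placeholder):

from docx import Document

def paragraphs_in_page_range(path, from_page=0, to_page=100000):
    doc = Document(path)
    pn, secs = 0, []
    for p in doc.paragraphs:
        if pn > to_page:
            break
        if from_page <= pn < to_page and p.text.strip():
            secs.append((p.text, p.style.name))
        for run in p.runs:
            # soft page breaks recorded the last time Word rendered the document
            if "lastRenderedPageBreak" in run._element.xml:
                pn += 1
    return secs

# usage: paragraphs_in_page_range("sample.docx", 0, 2)
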
deepdoc/parser/excel_parser.py CHANGED
@@ -15,13 +15,16 @@ class HuExcelParser:
15
  ws = wb[sheetname]
16
  rows = list(ws.rows)
17
  tb += f"<table><caption>{sheetname}</caption><tr>"
18
- for t in list(rows[0]): tb += f"<th>{t.value}</th>"
 
19
  tb += "</tr>"
20
  for r in list(rows[1:]):
21
  tb += "<tr>"
22
- for i,c in enumerate(r):
23
- if c.value is None: tb += "<td></td>"
24
- else: tb += f"<td>{c.value}</td>"
 
 
25
  tb += "</tr>"
26
  tb += "</table>\n"
27
  return tb
@@ -38,13 +41,15 @@ class HuExcelParser:
38
  ti = list(rows[0])
39
  for r in list(rows[1:]):
40
  l = []
41
- for i,c in enumerate(r):
42
- if not c.value:continue
 
43
  t = str(ti[i].value) if i < len(ti) else ""
44
  t += (":" if t else "") + str(c.value)
45
  l.append(t)
46
  l = "; ".join(l)
47
- if sheetname.lower().find("sheet") <0: l += " ——"+sheetname
 
48
  res.append(l)
49
  return res
50
 
 
15
  ws = wb[sheetname]
16
  rows = list(ws.rows)
17
  tb += f"<table><caption>{sheetname}</caption><tr>"
18
+ for t in list(rows[0]):
19
+ tb += f"<th>{t.value}</th>"
20
  tb += "</tr>"
21
  for r in list(rows[1:]):
22
  tb += "<tr>"
23
+ for i, c in enumerate(r):
24
+ if c.value is None:
25
+ tb += "<td></td>"
26
+ else:
27
+ tb += f"<td>{c.value}</td>"
28
  tb += "</tr>"
29
  tb += "</table>\n"
30
  return tb
 
41
  ti = list(rows[0])
42
  for r in list(rows[1:]):
43
  l = []
44
+ for i, c in enumerate(r):
45
+ if not c.value:
46
+ continue
47
  t = str(ti[i].value) if i < len(ti) else ""
48
  t += (":" if t else "") + str(c.value)
49
  l.append(t)
50
  l = "; ".join(l)
51
+ if sheetname.lower().find("sheet") < 0:
52
+ l += " ——" + sheetname
53
  res.append(l)
54
  return res
55
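
Reviewer note on excel_parser: the hunk above is a straight PEP 8 re-wrap of the row loop; each data row still becomes one "header:value; ..." string, with the sheet name appended when it is not a default "SheetN". A minimal sketch of that loop with openpyxl (the workbook path is a placeholder):

from openpyxl import load_workbook

def rows_as_text(path):
    res = []
    wb = load_workbook(path)
    for sheetname in wb.sheetnames:
        rows = list(wb[sheetname].rows)
        if not rows:
            continue
        header = rows[0]
        for r in rows[1:]:
            fields = []
            for i, c in enumerate(r):
                if not c.value:
                    continue
                title = str(header[i].value) if i < len(header) else ""
                fields.append(title + (":" if title else "") + str(c.value))
            line = "; ".join(fields)
            if sheetname.lower().find("sheet") < 0:
                line += " ——" + sheetname
            res.append(line)
    return res

# usage: rows_as_text("prices.xlsx")
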
 
deepdoc/parser/pdf_parser.py CHANGED
@@ -43,9 +43,11 @@ class HuParser:
43
  "rag/res/deepdoc"),
44
  local_files_only=True)
45
  except Exception as e:
46
- model_dir = snapshot_download(repo_id="InfiniFlow/text_concat_xgb_v1.0")
 
47
 
48
- self.updown_cnt_mdl.load_model(os.path.join(model_dir, "updown_concat_xgb.model"))
 
49
  self.page_from = 0
50
  """
51
  If you have trouble downloading HuggingFace models, -_^ this might help!!
@@ -72,7 +74,7 @@ class HuParser:
72
  def _y_dis(
73
  self, a, b):
74
  return (
75
- b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
76
 
77
  def _match_proj(self, b):
78
  proj_patt = [
@@ -95,9 +97,9 @@ class HuParser:
95
  tks_down = huqie.qie(down["text"][:LEN]).split(" ")
96
  tks_up = huqie.qie(up["text"][-LEN:]).split(" ")
97
  tks_all = up["text"][-LEN:].strip() \
98
- + (" " if re.match(r"[a-zA-Z0-9]+",
99
- up["text"][-1] + down["text"][0]) else "") \
100
- + down["text"][:LEN].strip()
101
  tks_all = huqie.qie(tks_all).split(" ")
102
  fea = [
103
  up.get("R", -1) == down.get("R", -1),
@@ -119,7 +121,7 @@ class HuParser:
119
  True if re.search(r"[,,][^。.]+$", up["text"]) else False,
120
  True if re.search(r"[,,][^。.]+$", up["text"]) else False,
121
  True if re.search(r"[\((][^\))]+$", up["text"])
122
- and re.search(r"[\))]", down["text"]) else False,
123
  self._match_proj(down),
124
  True if re.match(r"[A-Z]", down["text"]) else False,
125
  True if re.match(r"[A-Z]", up["text"][-1]) else False,
@@ -181,7 +183,7 @@ class HuParser:
181
  continue
182
  for tb in tbls: # for table
183
  left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \
184
- tb["x1"] + MARGIN, tb["bottom"] + MARGIN
185
  left *= ZM
186
  top *= ZM
187
  right *= ZM
@@ -235,7 +237,8 @@ class HuParser:
235
  b["R_top"] = rows[ii]["top"]
236
  b["R_bott"] = rows[ii]["bottom"]
237
 
238
- ii = Recognizer.find_overlapped_with_threashold(b, headers, thr=0.3)
 
239
  if ii is not None:
240
  b["H_top"] = headers[ii]["top"]
241
  b["H_bott"] = headers[ii]["bottom"]
@@ -272,7 +275,8 @@ class HuParser:
272
  )
273
 
274
  # merge chars in the same rect
275
- for c in Recognizer.sort_X_firstly(chars, self.mean_width[pagenum - 1] // 4):
 
276
  ii = Recognizer.find_overlapped(c, bxs)
277
  if ii is None:
278
  self.lefted_chars.append(c)
@@ -283,13 +287,15 @@ class HuParser:
283
  self.lefted_chars.append(c)
284
  continue
285
  if c["text"] == " " and bxs[ii]["text"]:
286
- if re.match(r"[0-9a-zA-Z,.?;:!%%]", bxs[ii]["text"][-1]): bxs[ii]["text"] += " "
 
287
  else:
288
  bxs[ii]["text"] += c["text"]
289
 
290
  for b in bxs:
291
  if not b["text"]:
292
- left, right, top, bott = b["x0"] * ZM, b["x1"] * ZM, b["top"] * ZM, b["bottom"] * ZM
 
293
  b["text"] = self.ocr.recognize(np.array(img),
294
  np.array([[left, top], [right, top], [right, bott], [left, bott]],
295
  dtype=np.float32))
@@ -302,7 +308,8 @@ class HuParser:
302
 
303
  def _layouts_rec(self, ZM, drop=True):
304
  assert len(self.page_images) == len(self.boxes)
305
- self.boxes, self.page_layout = self.layouter(self.page_images, self.boxes, ZM, drop=drop)
 
306
  # cumlative Y
307
  for i in range(len(self.boxes)):
308
  self.boxes[i]["top"] += \
@@ -332,7 +339,8 @@ class HuParser:
332
  "equation"]:
333
  i += 1
334
  continue
335
- if abs(self._y_dis(b, b_)) < self.mean_height[bxs[i]["page_number"] - 1] / 3:
 
336
  # merge
337
  bxs[i]["x1"] = b_["x1"]
338
  bxs[i]["top"] = (b["top"] + b_["top"]) / 2
@@ -366,12 +374,15 @@ class HuParser:
366
  self.boxes = bxs
367
 
368
  def _naive_vertical_merge(self):
369
- bxs = Recognizer.sort_Y_firstly(self.boxes, np.median(self.mean_height) / 3)
 
 
370
  i = 0
371
  while i + 1 < len(bxs):
372
  b = bxs[i]
373
  b_ = bxs[i + 1]
374
- if b["page_number"] < b_["page_number"] and re.match(r"[0-9 •一—-]+$", b["text"]):
 
375
  bxs.pop(i)
376
  continue
377
  if not b["text"].strip():
@@ -379,7 +390,8 @@ class HuParser:
379
  continue
380
  concatting_feats = [
381
  b["text"].strip()[-1] in ",;:'\",、‘“;:-",
382
- len(b["text"].strip()) > 1 and b["text"].strip()[-2] in ",;:'\",‘“、;:",
 
383
  b["text"].strip()[0] in "。;?!?”)),,、:",
384
  ]
385
  # features for not concating
@@ -387,7 +399,7 @@ class HuParser:
387
  b.get("layoutno", 0) != b.get("layoutno", 0),
388
  b["text"].strip()[-1] in "。?!?",
389
  self.is_english and b["text"].strip()[-1] in ".!?",
390
- b["page_number"] == b_["page_number"] and b_["top"] - \
391
  b["bottom"] > self.mean_height[b["page_number"] - 1] * 1.5,
392
  b["page_number"] < b_["page_number"] and abs(
393
  b["x0"] - b_["x0"]) > self.mean_width[b["page_number"] - 1] * 4,
@@ -396,7 +408,12 @@ class HuParser:
396
  detach_feats = [b["x1"] < b_["x0"],
397
  b["x0"] > b_["x1"]]
398
  if (any(feats) and not any(concatting_feats)) or any(detach_feats):
399
- print(b["text"], b_["text"], any(feats), any(concatting_feats), any(detach_feats))
 
 
 
 
 
400
  i += 1
401
  continue
402
  # merge up and down
@@ -526,31 +543,39 @@ class HuParser:
526
  i += 1
527
  continue
528
  findit = True
529
- eng = re.match(r"[0-9a-zA-Z :'.-]{5,}", self.boxes[i]["text"].strip())
 
 
530
  self.boxes.pop(i)
531
- if i >= len(self.boxes): break
 
532
  prefix = self.boxes[i]["text"].strip()[:3] if not eng else " ".join(
533
  self.boxes[i]["text"].strip().split(" ")[:2])
534
  while not prefix:
535
  self.boxes.pop(i)
536
- if i >= len(self.boxes): break
 
537
  prefix = self.boxes[i]["text"].strip()[:3] if not eng else " ".join(
538
  self.boxes[i]["text"].strip().split(" ")[:2])
539
  self.boxes.pop(i)
540
- if i >= len(self.boxes) or not prefix: break
 
541
  for j in range(i, min(i + 128, len(self.boxes))):
542
  if not re.match(prefix, self.boxes[j]["text"]):
543
  continue
544
- for k in range(i, j): self.boxes.pop(i)
 
545
  break
546
- if findit: return
 
547
 
548
  page_dirty = [0] * len(self.page_images)
549
  for b in self.boxes:
550
  if re.search(r"(··|··|··)", b["text"]):
551
  page_dirty[b["page_number"] - 1] += 1
552
  page_dirty = set([i + 1 for i, t in enumerate(page_dirty) if t > 3])
553
- if not page_dirty: return
 
554
  i = 0
555
  while i < len(self.boxes):
556
  if self.boxes[i]["page_number"] in page_dirty:
@@ -582,7 +607,8 @@ class HuParser:
582
  b_["top"] = b["top"]
583
  self.boxes.pop(i)
584
 
585
- def _extract_table_figure(self, need_image, ZM, return_html, need_position):
 
586
  tables = {}
587
  figures = {}
588
  # extract figure and table boxes
@@ -594,7 +620,7 @@ class HuParser:
594
  i += 1
595
  continue
596
  lout_no = str(self.boxes[i]["page_number"]) + \
597
- "-" + str(self.boxes[i]["layoutno"])
598
  if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption",
599
  "title",
600
  "figure caption",
@@ -761,7 +787,8 @@ class HuParser:
761
  for k, bxs in tables.items():
762
  if not bxs:
763
  continue
764
- bxs = Recognizer.sort_Y_firstly(bxs, np.mean([(b["bottom"] - b["top"]) / 2 for b in bxs]))
 
765
  poss = []
766
  res.append((cropout(bxs, "table", poss),
767
  self.tbl_det.construct_table(bxs, html=return_html, is_english=self.is_english)))
@@ -769,7 +796,8 @@ class HuParser:
769
 
770
  assert len(positions) == len(res)
771
 
772
- if need_position: return list(zip(res, positions))
 
773
  return res
774
 
775
  def proj_match(self, line):
@@ -873,7 +901,8 @@ class HuParser:
873
  boxes.pop(0)
874
  mw = np.mean(widths)
875
  if mj or mw / pw >= 0.35 or mw > 200:
876
- res.append("\n".join([c["text"] + self._line_tag(c, ZM) for c in lines]))
 
877
  else:
878
  logging.debug("REMOVED: " +
879
  "<<".join([c["text"] for c in lines]))
@@ -883,13 +912,16 @@ class HuParser:
883
  @staticmethod
884
  def total_page_number(fnm, binary=None):
885
  try:
886
- pdf = pdfplumber.open(fnm) if not binary else pdfplumber.open(BytesIO(binary))
 
887
  return len(pdf.pages)
888
  except Exception as e:
889
- pdf = fitz.open(fnm) if not binary else fitz.open(stream=fnm, filetype="pdf")
 
890
  return len(pdf)
891
 
892
- def __images__(self, fnm, zoomin=3, page_from=0, page_to=299, callback=None):
 
893
  self.lefted_chars = []
894
  self.mean_height = []
895
  self.mean_width = []
@@ -899,21 +931,26 @@ class HuParser:
899
  self.page_layout = []
900
  self.page_from = page_from
901
  try:
902
- self.pdf = pdfplumber.open(fnm) if isinstance(fnm, str) else pdfplumber.open(BytesIO(fnm))
 
903
  self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
904
  enumerate(self.pdf.pages[page_from:page_to])]
905
  self.page_chars = [[c for c in page.chars if self._has_color(c)] for page in
906
  self.pdf.pages[page_from:page_to]]
907
  self.total_page = len(self.pdf.pages)
908
  except Exception as e:
909
- self.pdf = fitz.open(fnm) if isinstance(fnm, str) else fitz.open(stream=fnm, filetype="pdf")
 
 
910
  self.page_images = []
911
  self.page_chars = []
912
  mat = fitz.Matrix(zoomin, zoomin)
913
  self.total_page = len(self.pdf)
914
  for i, page in enumerate(self.pdf):
915
- if i < page_from: continue
916
- if i >= page_to: break
 
 
917
  pix = page.get_pixmap(matrix=mat)
918
  img = Image.frombytes("RGB", [pix.width, pix.height],
919
  pix.samples)
@@ -930,7 +967,7 @@ class HuParser:
930
  if isinstance(a, dict):
931
  self.outlines.append((a["/Title"], depth))
932
  continue
933
- dfs(a, depth+1)
934
  dfs(outlines, 0)
935
  except Exception as e:
936
  logging.warning(f"Outlines exception: {e}")
@@ -940,8 +977,9 @@ class HuParser:
940
  logging.info("Images converted.")
941
  self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join(
942
  random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in
943
- range(len(self.page_chars))]
944
- if sum([1 if e else 0 for e in self.is_english]) > len(self.page_images) / 2:
 
945
  self.is_english = True
946
  else:
947
  self.is_english = False
@@ -970,9 +1008,11 @@ class HuParser:
970
  # self.page_cum_height.append(
971
  # np.max([c["bottom"] for c in chars]))
972
  self.__ocr(i + 1, img, chars, zoomin)
973
- if callback: callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
 
974
 
975
- if not self.is_english and not any([c for c in self.page_chars]) and self.boxes:
 
976
  bxes = [b for bxs in self.boxes for b in bxs]
977
  self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}",
978
  "".join([b["text"] for b in random.choices(bxes, k=min(30, len(bxes)))]))
@@ -989,7 +1029,8 @@ class HuParser:
989
  self._text_merge()
990
  self._concat_downward()
991
  self._filter_forpages()
992
- tbls = self._extract_table_figure(need_image, zoomin, return_html, False)
 
993
  return self.__filterout_scraps(deepcopy(self.boxes), zoomin), tbls
994
 
995
  def remove_tag(self, txt):
@@ -1003,15 +1044,19 @@ class HuParser:
1003
  "#").strip("@").split("\t")
1004
  left, right, top, bottom = float(left), float(
1005
  right), float(top), float(bottom)
1006
- poss.append(([int(p) - 1 for p in pn.split("-")], left, right, top, bottom))
 
1007
  if not poss:
1008
- if need_position: return None, None
 
1009
  return
1010
 
1011
- max_width = max(np.max([right - left for (_, left, right, _, _) in poss]), 6)
 
1012
  GAP = 6
1013
  pos = poss[0]
1014
- poss.insert(0, ([pos[0][0]], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0)))
 
1015
  pos = poss[-1]
1016
  poss.append(([pos[0][-1]], pos[1], pos[2], min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + GAP),
1017
  min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + 120)))
@@ -1026,7 +1071,7 @@ class HuParser:
1026
  self.page_images[pns[0]].crop((left * ZM, top * ZM,
1027
  right *
1028
  ZM, min(
1029
- bottom, self.page_images[pns[0]].size[1])
1030
  ))
1031
  )
1032
  if 0 < ii < len(poss) - 1:
@@ -1047,7 +1092,8 @@ class HuParser:
1047
  bottom -= self.page_images[pn].size[1]
1048
 
1049
  if not imgs:
1050
- if need_position: return None, None
 
1051
  return
1052
  height = 0
1053
  for img in imgs:
@@ -1076,12 +1122,14 @@ class HuParser:
1076
  pn = bx["page_number"]
1077
  top = bx["top"] - self.page_cum_height[pn - 1]
1078
  bott = bx["bottom"] - self.page_cum_height[pn - 1]
1079
- poss.append((pn, bx["x0"], bx["x1"], top, min(bott, self.page_images[pn - 1].size[1] / ZM)))
 
1080
  while bott * ZM > self.page_images[pn - 1].size[1]:
1081
  bott -= self.page_images[pn - 1].size[1] / ZM
1082
  top = 0
1083
  pn += 1
1084
- poss.append((pn, bx["x0"], bx["x1"], top, min(bott, self.page_images[pn - 1].size[1] / ZM)))
 
1085
  return poss
1086
 
1087
 
@@ -1090,11 +1138,14 @@ class PlainParser(object):
1090
  self.outlines = []
1091
  lines = []
1092
  try:
1093
- self.pdf = pdf2_read(filename if isinstance(filename, str) else BytesIO(filename))
 
 
1094
  for page in self.pdf.pages[from_page:to_page]:
1095
  lines.extend([t for t in page.extract_text().split("\n")])
1096
 
1097
  outlines = self.pdf.outline
 
1098
  def dfs(arr, depth):
1099
  for a in arr:
1100
  if isinstance(a, dict):
@@ -1117,5 +1168,6 @@ class PlainParser(object):
1117
  def remove_tag(txt):
1118
  raise NotImplementedError
1119
 
 
1120
  if __name__ == "__main__":
1121
  pass
 
43
  "rag/res/deepdoc"),
44
  local_files_only=True)
45
  except Exception as e:
46
+ model_dir = snapshot_download(
47
+ repo_id="InfiniFlow/text_concat_xgb_v1.0")
48
 
49
+ self.updown_cnt_mdl.load_model(os.path.join(
50
+ model_dir, "updown_concat_xgb.model"))
51
  self.page_from = 0
52
  """
53
  If you have trouble downloading HuggingFace models, -_^ this might help!!
 
74
  def _y_dis(
75
  self, a, b):
76
  return (
77
+ b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
78
 
79
  def _match_proj(self, b):
80
  proj_patt = [
 
97
  tks_down = huqie.qie(down["text"][:LEN]).split(" ")
98
  tks_up = huqie.qie(up["text"][-LEN:]).split(" ")
99
  tks_all = up["text"][-LEN:].strip() \
100
+ + (" " if re.match(r"[a-zA-Z0-9]+",
101
+ up["text"][-1] + down["text"][0]) else "") \
102
+ + down["text"][:LEN].strip()
103
  tks_all = huqie.qie(tks_all).split(" ")
104
  fea = [
105
  up.get("R", -1) == down.get("R", -1),
 
121
  True if re.search(r"[,,][^。.]+$", up["text"]) else False,
122
  True if re.search(r"[,,][^。.]+$", up["text"]) else False,
123
  True if re.search(r"[\((][^\))]+$", up["text"])
124
+ and re.search(r"[\))]", down["text"]) else False,
125
  self._match_proj(down),
126
  True if re.match(r"[A-Z]", down["text"]) else False,
127
  True if re.match(r"[A-Z]", up["text"][-1]) else False,
 
183
  continue
184
  for tb in tbls: # for table
185
  left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \
186
+ tb["x1"] + MARGIN, tb["bottom"] + MARGIN
187
  left *= ZM
188
  top *= ZM
189
  right *= ZM
 
237
  b["R_top"] = rows[ii]["top"]
238
  b["R_bott"] = rows[ii]["bottom"]
239
 
240
+ ii = Recognizer.find_overlapped_with_threashold(
241
+ b, headers, thr=0.3)
242
  if ii is not None:
243
  b["H_top"] = headers[ii]["top"]
244
  b["H_bott"] = headers[ii]["bottom"]
 
275
  )
276
 
277
  # merge chars in the same rect
278
+ for c in Recognizer.sort_X_firstly(
279
+ chars, self.mean_width[pagenum - 1] // 4):
280
  ii = Recognizer.find_overlapped(c, bxs)
281
  if ii is None:
282
  self.lefted_chars.append(c)
 
287
  self.lefted_chars.append(c)
288
  continue
289
  if c["text"] == " " and bxs[ii]["text"]:
290
+ if re.match(r"[0-9a-zA-Z,.?;:!%%]", bxs[ii]["text"][-1]):
291
+ bxs[ii]["text"] += " "
292
  else:
293
  bxs[ii]["text"] += c["text"]
294
 
295
  for b in bxs:
296
  if not b["text"]:
297
+ left, right, top, bott = b["x0"] * ZM, b["x1"] * \
298
+ ZM, b["top"] * ZM, b["bottom"] * ZM
299
  b["text"] = self.ocr.recognize(np.array(img),
300
  np.array([[left, top], [right, top], [right, bott], [left, bott]],
301
  dtype=np.float32))
 
308
 
309
  def _layouts_rec(self, ZM, drop=True):
310
  assert len(self.page_images) == len(self.boxes)
311
+ self.boxes, self.page_layout = self.layouter(
312
+ self.page_images, self.boxes, ZM, drop=drop)
313
  # cumlative Y
314
  for i in range(len(self.boxes)):
315
  self.boxes[i]["top"] += \
 
339
  "equation"]:
340
  i += 1
341
  continue
342
+ if abs(self._y_dis(b, b_)
343
+ ) < self.mean_height[bxs[i]["page_number"] - 1] / 3:
344
  # merge
345
  bxs[i]["x1"] = b_["x1"]
346
  bxs[i]["top"] = (b["top"] + b_["top"]) / 2
 
374
  self.boxes = bxs
375
 
376
  def _naive_vertical_merge(self):
377
+ bxs = Recognizer.sort_Y_firstly(
378
+ self.boxes, np.median(
379
+ self.mean_height) / 3)
380
  i = 0
381
  while i + 1 < len(bxs):
382
  b = bxs[i]
383
  b_ = bxs[i + 1]
384
+ if b["page_number"] < b_["page_number"] and re.match(
385
+ r"[0-9 •一—-]+$", b["text"]):
386
  bxs.pop(i)
387
  continue
388
  if not b["text"].strip():
 
390
  continue
391
  concatting_feats = [
392
  b["text"].strip()[-1] in ",;:'\",、‘“;:-",
393
+ len(b["text"].strip()) > 1 and b["text"].strip(
394
+ )[-2] in ",;:'\",‘“、;:",
395
  b["text"].strip()[0] in "。;?!?”)),,、:",
396
  ]
397
  # features for not concating
 
399
  b.get("layoutno", 0) != b.get("layoutno", 0),
400
  b["text"].strip()[-1] in "。?!?",
401
  self.is_english and b["text"].strip()[-1] in ".!?",
402
+ b["page_number"] == b_["page_number"] and b_["top"] -
403
  b["bottom"] > self.mean_height[b["page_number"] - 1] * 1.5,
404
  b["page_number"] < b_["page_number"] and abs(
405
  b["x0"] - b_["x0"]) > self.mean_width[b["page_number"] - 1] * 4,
 
408
  detach_feats = [b["x1"] < b_["x0"],
409
  b["x0"] > b_["x1"]]
410
  if (any(feats) and not any(concatting_feats)) or any(detach_feats):
411
+ print(
412
+ b["text"],
413
+ b_["text"],
414
+ any(feats),
415
+ any(concatting_feats),
416
+ any(detach_feats))
417
  i += 1
418
  continue
419
  # merge up and down
 
543
  i += 1
544
  continue
545
  findit = True
546
+ eng = re.match(
547
+ r"[0-9a-zA-Z :'.-]{5,}",
548
+ self.boxes[i]["text"].strip())
549
  self.boxes.pop(i)
550
+ if i >= len(self.boxes):
551
+ break
552
  prefix = self.boxes[i]["text"].strip()[:3] if not eng else " ".join(
553
  self.boxes[i]["text"].strip().split(" ")[:2])
554
  while not prefix:
555
  self.boxes.pop(i)
556
+ if i >= len(self.boxes):
557
+ break
558
  prefix = self.boxes[i]["text"].strip()[:3] if not eng else " ".join(
559
  self.boxes[i]["text"].strip().split(" ")[:2])
560
  self.boxes.pop(i)
561
+ if i >= len(self.boxes) or not prefix:
562
+ break
563
  for j in range(i, min(i + 128, len(self.boxes))):
564
  if not re.match(prefix, self.boxes[j]["text"]):
565
  continue
566
+ for k in range(i, j):
567
+ self.boxes.pop(i)
568
  break
569
+ if findit:
570
+ return
571
 
572
  page_dirty = [0] * len(self.page_images)
573
  for b in self.boxes:
574
  if re.search(r"(··|··|··)", b["text"]):
575
  page_dirty[b["page_number"] - 1] += 1
576
  page_dirty = set([i + 1 for i, t in enumerate(page_dirty) if t > 3])
577
+ if not page_dirty:
578
+ return
579
  i = 0
580
  while i < len(self.boxes):
581
  if self.boxes[i]["page_number"] in page_dirty:
 
607
  b_["top"] = b["top"]
608
  self.boxes.pop(i)
609
 
610
+ def _extract_table_figure(self, need_image, ZM,
611
+ return_html, need_position):
612
  tables = {}
613
  figures = {}
614
  # extract figure and table boxes
 
620
  i += 1
621
  continue
622
  lout_no = str(self.boxes[i]["page_number"]) + \
623
+ "-" + str(self.boxes[i]["layoutno"])
624
  if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption",
625
  "title",
626
  "figure caption",
 
787
  for k, bxs in tables.items():
788
  if not bxs:
789
  continue
790
+ bxs = Recognizer.sort_Y_firstly(bxs, np.mean(
791
+ [(b["bottom"] - b["top"]) / 2 for b in bxs]))
792
  poss = []
793
  res.append((cropout(bxs, "table", poss),
794
  self.tbl_det.construct_table(bxs, html=return_html, is_english=self.is_english)))
 
796
 
797
  assert len(positions) == len(res)
798
 
799
+ if need_position:
800
+ return list(zip(res, positions))
801
  return res
802
 
803
  def proj_match(self, line):
 
901
  boxes.pop(0)
902
  mw = np.mean(widths)
903
  if mj or mw / pw >= 0.35 or mw > 200:
904
+ res.append(
905
+ "\n".join([c["text"] + self._line_tag(c, ZM) for c in lines]))
906
  else:
907
  logging.debug("REMOVED: " +
908
  "<<".join([c["text"] for c in lines]))
 
912
  @staticmethod
913
  def total_page_number(fnm, binary=None):
914
  try:
915
+ pdf = pdfplumber.open(
916
+ fnm) if not binary else pdfplumber.open(BytesIO(binary))
917
  return len(pdf.pages)
918
  except Exception as e:
919
+ pdf = fitz.open(fnm) if not binary else fitz.open(
920
+ stream=fnm, filetype="pdf")
921
  return len(pdf)
922
 
923
+ def __images__(self, fnm, zoomin=3, page_from=0,
924
+ page_to=299, callback=None):
925
  self.lefted_chars = []
926
  self.mean_height = []
927
  self.mean_width = []
 
931
  self.page_layout = []
932
  self.page_from = page_from
933
  try:
934
+ self.pdf = pdfplumber.open(fnm) if isinstance(
935
+ fnm, str) else pdfplumber.open(BytesIO(fnm))
936
  self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
937
  enumerate(self.pdf.pages[page_from:page_to])]
938
  self.page_chars = [[c for c in page.chars if self._has_color(c)] for page in
939
  self.pdf.pages[page_from:page_to]]
940
  self.total_page = len(self.pdf.pages)
941
  except Exception as e:
942
+ self.pdf = fitz.open(fnm) if isinstance(
943
+ fnm, str) else fitz.open(
944
+ stream=fnm, filetype="pdf")
945
  self.page_images = []
946
  self.page_chars = []
947
  mat = fitz.Matrix(zoomin, zoomin)
948
  self.total_page = len(self.pdf)
949
  for i, page in enumerate(self.pdf):
950
+ if i < page_from:
951
+ continue
952
+ if i >= page_to:
953
+ break
954
  pix = page.get_pixmap(matrix=mat)
955
  img = Image.frombytes("RGB", [pix.width, pix.height],
956
  pix.samples)
 
967
  if isinstance(a, dict):
968
  self.outlines.append((a["/Title"], depth))
969
  continue
970
+ dfs(a, depth + 1)
971
  dfs(outlines, 0)
972
  except Exception as e:
973
  logging.warning(f"Outlines exception: {e}")
 
977
  logging.info("Images converted.")
978
  self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join(
979
  random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in
980
+ range(len(self.page_chars))]
981
+ if sum([1 if e else 0 for e in self.is_english]) > len(
982
+ self.page_images) / 2:
983
  self.is_english = True
984
  else:
985
  self.is_english = False
 
1008
  # self.page_cum_height.append(
1009
  # np.max([c["bottom"] for c in chars]))
1010
  self.__ocr(i + 1, img, chars, zoomin)
1011
+ if callback:
1012
+ callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
1013
 
1014
+ if not self.is_english and not any(
1015
+ [c for c in self.page_chars]) and self.boxes:
1016
  bxes = [b for bxs in self.boxes for b in bxs]
1017
  self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}",
1018
  "".join([b["text"] for b in random.choices(bxes, k=min(30, len(bxes)))]))
 
1029
  self._text_merge()
1030
  self._concat_downward()
1031
  self._filter_forpages()
1032
+ tbls = self._extract_table_figure(
1033
+ need_image, zoomin, return_html, False)
1034
  return self.__filterout_scraps(deepcopy(self.boxes), zoomin), tbls
1035
 
1036
  def remove_tag(self, txt):
 
1044
  "#").strip("@").split("\t")
1045
  left, right, top, bottom = float(left), float(
1046
  right), float(top), float(bottom)
1047
+ poss.append(([int(p) - 1 for p in pn.split("-")],
1048
+ left, right, top, bottom))
1049
  if not poss:
1050
+ if need_position:
1051
+ return None, None
1052
  return
1053
 
1054
+ max_width = max(
1055
+ np.max([right - left for (_, left, right, _, _) in poss]), 6)
1056
  GAP = 6
1057
  pos = poss[0]
1058
+ poss.insert(0, ([pos[0][0]], pos[1], pos[2], max(
1059
+ 0, pos[3] - 120), max(pos[3] - GAP, 0)))
1060
  pos = poss[-1]
1061
  poss.append(([pos[0][-1]], pos[1], pos[2], min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + GAP),
1062
  min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + 120)))
 
1071
  self.page_images[pns[0]].crop((left * ZM, top * ZM,
1072
  right *
1073
  ZM, min(
1074
+ bottom, self.page_images[pns[0]].size[1])
1075
  ))
1076
  )
1077
  if 0 < ii < len(poss) - 1:
 
1092
  bottom -= self.page_images[pn].size[1]
1093
 
1094
  if not imgs:
1095
+ if need_position:
1096
+ return None, None
1097
  return
1098
  height = 0
1099
  for img in imgs:
 
1122
  pn = bx["page_number"]
1123
  top = bx["top"] - self.page_cum_height[pn - 1]
1124
  bott = bx["bottom"] - self.page_cum_height[pn - 1]
1125
+ poss.append((pn, bx["x0"], bx["x1"], top, min(
1126
+ bott, self.page_images[pn - 1].size[1] / ZM)))
1127
  while bott * ZM > self.page_images[pn - 1].size[1]:
1128
  bott -= self.page_images[pn - 1].size[1] / ZM
1129
  top = 0
1130
  pn += 1
1131
+ poss.append((pn, bx["x0"], bx["x1"], top, min(
1132
+ bott, self.page_images[pn - 1].size[1] / ZM)))
1133
  return poss
1134
 
1135
 
 
1138
  self.outlines = []
1139
  lines = []
1140
  try:
1141
+ self.pdf = pdf2_read(
1142
+ filename if isinstance(
1143
+ filename, str) else BytesIO(filename))
1144
  for page in self.pdf.pages[from_page:to_page]:
1145
  lines.extend([t for t in page.extract_text().split("\n")])
1146
 
1147
  outlines = self.pdf.outline
1148
+
1149
  def dfs(arr, depth):
1150
  for a in arr:
1151
  if isinstance(a, dict):
 
1168
  def remove_tag(txt):
1169
  raise NotImplementedError
1170
 
1171
+
1172
  if __name__ == "__main__":
1173
  pass
deepdoc/parser/ppt_parser.py CHANGED
@@ -23,7 +23,8 @@ class HuPptParser(object):
23
  tb = shape.table
24
  rows = []
25
  for i in range(1, len(tb.rows)):
26
- rows.append("; ".join([tb.cell(0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)]))
 
27
  return "\n".join(rows)
28
 
29
  if shape.has_text_frame:
@@ -31,9 +32,10 @@ class HuPptParser(object):
31
 
32
  if shape.shape_type == 6:
33
  texts = []
34
- for p in sorted(shape.shapes, key=lambda x: (x.top//10, x.left)):
35
  t = self.__extract(p)
36
- if t: texts.append(t)
 
37
  return "\n".join(texts)
38
 
39
  def __call__(self, fnm, from_page, to_page, callback=None):
@@ -43,12 +45,16 @@ class HuPptParser(object):
43
  txts = []
44
  self.total_page = len(ppt.slides)
45
  for i, slide in enumerate(ppt.slides):
46
- if i < from_page: continue
47
- if i >= to_page:break
 
 
48
  texts = []
49
- for shape in sorted(slide.shapes, key=lambda x: (x.top//10, x.left)):
 
50
  txt = self.__extract(shape)
51
- if txt: texts.append(txt)
 
52
  txts.append("\n".join(texts))
53
 
54
  return txts
 
23
  tb = shape.table
24
  rows = []
25
  for i in range(1, len(tb.rows)):
26
+ rows.append("; ".join([tb.cell(
27
+ 0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)]))
28
  return "\n".join(rows)
29
 
30
  if shape.has_text_frame:
 
32
 
33
  if shape.shape_type == 6:
34
  texts = []
35
+ for p in sorted(shape.shapes, key=lambda x: (x.top // 10, x.left)):
36
  t = self.__extract(p)
37
+ if t:
38
+ texts.append(t)
39
  return "\n".join(texts)
40
 
41
  def __call__(self, fnm, from_page, to_page, callback=None):
 
45
  txts = []
46
  self.total_page = len(ppt.slides)
47
  for i, slide in enumerate(ppt.slides):
48
+ if i < from_page:
49
+ continue
50
+ if i >= to_page:
51
+ break
52
  texts = []
53
+ for shape in sorted(
54
+ slide.shapes, key=lambda x: (x.top // 10, x.left)):
55
  txt = self.__extract(shape)
56
+ if txt:
57
+ texts.append(txt)
58
  txts.append("\n".join(texts))
59
 
60
  return txts
deepdoc/vision/layout_recognizer.py CHANGED
@@ -24,18 +24,19 @@ from deepdoc.vision import Recognizer
24
 
25
  class LayoutRecognizer(Recognizer):
26
  labels = [
27
- "_background_",
28
- "Text",
29
- "Title",
30
- "Figure",
31
- "Figure caption",
32
- "Table",
33
- "Table caption",
34
- "Header",
35
- "Footer",
36
- "Reference",
37
- "Equation",
38
- ]
 
39
  def __init__(self, domain):
40
  try:
41
  model_dir = snapshot_download(
@@ -47,10 +48,12 @@ class LayoutRecognizer(Recognizer):
47
  except Exception as e:
48
  model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
49
 
50
- super().__init__(self.labels, domain, model_dir)#os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
 
51
  self.garbage_layouts = ["footer", "header", "reference"]
52
 
53
- def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True):
 
54
  def __is_garbage(b):
55
  patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
56
  r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
@@ -75,7 +78,8 @@ class LayoutRecognizer(Recognizer):
75
  "top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor,
76
  "page_number": pn,
77
  } for b in lts]
78
- lts = self.sort_Y_firstly(lts, np.mean([l["bottom"]-l["top"] for l in lts]) / 2)
 
79
  lts = self.layouts_cleanup(bxs, lts)
80
  page_layout.append(lts)
81
 
@@ -93,17 +97,20 @@ class LayoutRecognizer(Recognizer):
93
  continue
94
 
95
  ii = self.find_overlapped_with_threashold(bxs[i], lts_,
96
- thr=0.4)
97
  if ii is None: # belong to nothing
98
  bxs[i]["layout_type"] = ""
99
  i += 1
100
  continue
101
  lts_[ii]["visited"] = True
102
  keep_feats = [
103
- lts_[ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1]*0.9/scale_factor,
104
- lts_[ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1]*0.1/scale_factor,
 
 
105
  ]
106
- if drop and lts_[ii]["type"] in self.garbage_layouts and not any(keep_feats):
 
107
  if lts_[ii]["type"] not in garbages:
108
  garbages[lts_[ii]["type"]] = []
109
  garbages[lts_[ii]["type"]].append(bxs[i]["text"])
@@ -111,7 +118,8 @@ class LayoutRecognizer(Recognizer):
111
  continue
112
 
113
  bxs[i]["layoutno"] = f"{ty}-{ii}"
114
- bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[ii]["type"]!="equation" else "figure"
 
115
  i += 1
116
 
117
  for lt in ["footer", "header", "reference", "figure caption",
@@ -120,7 +128,7 @@ class LayoutRecognizer(Recognizer):
120
 
121
  # add box to figure layouts which has not text box
122
  for i, lt in enumerate(
123
- [lt for lt in lts if lt["type"] in ["figure","equation"]]):
124
  if lt.get("visited"):
125
  continue
126
  lt = deepcopy(lt)
@@ -143,6 +151,3 @@ class LayoutRecognizer(Recognizer):
143
 
144
  ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
145
  return ocr_res, page_layout
146
-
147
-
148
-
 
24
 
25
  class LayoutRecognizer(Recognizer):
26
  labels = [
27
+ "_background_",
28
+ "Text",
29
+ "Title",
30
+ "Figure",
31
+ "Figure caption",
32
+ "Table",
33
+ "Table caption",
34
+ "Header",
35
+ "Footer",
36
+ "Reference",
37
+ "Equation",
38
+ ]
39
+
40
  def __init__(self, domain):
41
  try:
42
  model_dir = snapshot_download(
 
48
  except Exception as e:
49
  model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
50
 
51
+ # os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
52
+ super().__init__(self.labels, domain, model_dir)
53
  self.garbage_layouts = ["footer", "header", "reference"]
54
 
55
+ def __call__(self, image_list, ocr_res, scale_factor=3,
56
+ thr=0.2, batch_size=16, drop=True):
57
  def __is_garbage(b):
58
  patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
59
  r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
 
78
  "top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor,
79
  "page_number": pn,
80
  } for b in lts]
81
+ lts = self.sort_Y_firstly(lts, np.mean(
82
+ [l["bottom"] - l["top"] for l in lts]) / 2)
83
  lts = self.layouts_cleanup(bxs, lts)
84
  page_layout.append(lts)
85
 
 
97
  continue
98
 
99
  ii = self.find_overlapped_with_threashold(bxs[i], lts_,
100
+ thr=0.4)
101
  if ii is None: # belong to nothing
102
  bxs[i]["layout_type"] = ""
103
  i += 1
104
  continue
105
  lts_[ii]["visited"] = True
106
  keep_feats = [
107
+ lts_[
108
+ ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1] * 0.9 / scale_factor,
109
+ lts_[
110
+ ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1] * 0.1 / scale_factor,
111
  ]
112
+ if drop and lts_[
113
+ ii]["type"] in self.garbage_layouts and not any(keep_feats):
114
  if lts_[ii]["type"] not in garbages:
115
  garbages[lts_[ii]["type"]] = []
116
  garbages[lts_[ii]["type"]].append(bxs[i]["text"])
 
118
  continue
119
 
120
  bxs[i]["layoutno"] = f"{ty}-{ii}"
121
+ bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[
122
+ ii]["type"] != "equation" else "figure"
123
  i += 1
124
 
125
  for lt in ["footer", "header", "reference", "figure caption",
 
128
 
129
  # add box to figure layouts which has not text box
130
  for i, lt in enumerate(
131
+ [lt for lt in lts if lt["type"] in ["figure", "equation"]]):
132
  if lt.get("visited"):
133
  continue
134
  lt = deepcopy(lt)
 
151
 
152
  ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
153
  return ocr_res, page_layout
 
 
 
deepdoc/vision/operators.py CHANGED
@@ -63,6 +63,7 @@ class DecodeImage(object):
63
  data['image'] = img
64
  return data
65
 
 
66
  class StandardizeImage(object):
67
  """normalize image
68
  Args:
@@ -707,4 +708,4 @@ def preprocess(im, preprocess_ops):
707
  im, im_info = decode_image(im, im_info)
708
  for operator in preprocess_ops:
709
  im, im_info = operator(im, im_info)
710
- return im, im_info
 
63
  data['image'] = img
64
  return data
65
 
66
+
67
  class StandardizeImage(object):
68
  """normalize image
69
  Args:
 
708
  im, im_info = decode_image(im, im_info)
709
  for operator in preprocess_ops:
710
  im, im_info = operator(im, im_info)
711
+ return im, im_info
deepdoc/vision/t_ocr.py CHANGED
@@ -11,12 +11,20 @@
11
  # limitations under the License.
12
  #
13
 
14
- import os, sys
15
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')))
16
- import numpy as np
17
- import argparse
18
- from deepdoc.vision import OCR, init_in_out
19
  from deepdoc.vision.seeit import draw_box
20
 
21
  def main(args):
22
  ocr = OCR()
@@ -26,14 +34,14 @@ def main(args):
26
  bxs = ocr(np.array(img))
27
  bxs = [(line[0], line[1][0]) for line in bxs]
28
  bxs = [{
29
- "text": t,
30
- "bbox": [b[0][0], b[0][1], b[1][0], b[-1][1]],
31
- "type": "ocr",
32
- "score": 1} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]]
33
  img = draw_box(images[i], bxs, ["ocr"], 1.)
34
  img.save(outputs[i], quality=95)
35
- with open(outputs[i] + ".txt", "w+") as f: f.write("\n".join([o["text"] for o in bxs]))
36
-
37
 
38
 
39
  if __name__ == "__main__":
@@ -42,6 +50,6 @@ if __name__ == "__main__":
42
  help="Directory where to store images or PDFs, or a file path to a single image or PDF",
43
  required=True)
44
  parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './ocr_outputs'",
45
- default="./ocr_outputs")
46
  args = parser.parse_args()
47
- main(args)
 
11
  # limitations under the License.
12
  #
13
 
 
 
 
 
 
14
  from deepdoc.vision.seeit import draw_box
15
+ from deepdoc.vision import OCR, init_in_out
16
+ import argparse
17
+ import numpy as np
18
+ import os
19
+ import sys
20
+ sys.path.insert(
21
+ 0,
22
+ os.path.abspath(
23
+ os.path.join(
24
+ os.path.dirname(
25
+ os.path.abspath(__file__)),
26
+ '../../')))
27
+
28
 
29
  def main(args):
30
  ocr = OCR()
 
34
  bxs = ocr(np.array(img))
35
  bxs = [(line[0], line[1][0]) for line in bxs]
36
  bxs = [{
37
+ "text": t,
38
+ "bbox": [b[0][0], b[0][1], b[1][0], b[-1][1]],
39
+ "type": "ocr",
40
+ "score": 1} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]]
41
  img = draw_box(images[i], bxs, ["ocr"], 1.)
42
  img.save(outputs[i], quality=95)
43
+ with open(outputs[i] + ".txt", "w+") as f:
44
+ f.write("\n".join([o["text"] for o in bxs]))
45
 
46
 
47
  if __name__ == "__main__":
 
50
  help="Directory where to store images or PDFs, or a file path to a single image or PDF",
51
  required=True)
52
  parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './ocr_outputs'",
53
+ default="./ocr_outputs")
54
  args = parser.parse_args()
55
+ main(args)
deepdoc/vision/t_recognizer.py CHANGED
@@ -11,24 +11,35 @@
11
  # limitations under the License.
12
  #
13
 
14
- import os, sys
15
  import re
16
 
17
  import numpy as np
18
 
19
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')))
20
-
21
- import argparse
22
- from api.utils.file_utils import get_project_base_directory
23
- from deepdoc.vision import Recognizer, LayoutRecognizer, TableStructureRecognizer, OCR, init_in_out
24
- from deepdoc.vision.seeit import draw_box
 
25
 
26
 
27
  def main(args):
28
  images, outputs = init_in_out(args)
29
  if args.mode.lower() == "layout":
30
  labels = LayoutRecognizer.labels
31
- detr = Recognizer(labels, "layout", os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
32
  if args.mode.lower() == "tsr":
33
  labels = TableStructureRecognizer.labels
34
  detr = TableStructureRecognizer()
@@ -39,7 +50,8 @@ def main(args):
39
  if args.mode.lower() == "tsr":
40
  #lyt = [t for t in lyt if t["type"] == "table column"]
41
  html = get_table_html(images[i], lyt, ocr)
42
- with open(outputs[i]+".html", "w+") as f: f.write(html)
 
43
  lyt = [{
44
  "type": t["label"],
45
  "bbox": [t["x0"], t["top"], t["x1"], t["bottom"]],
@@ -58,7 +70,7 @@ def get_table_html(img, tb_cpns, ocr):
58
  "bottom": b[-1][1],
59
  "layout_type": "table",
60
  "page_number": 0} for b, t in boxes if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]],
61
- np.mean([b[-1][1]-b[0][1] for b,_ in boxes]) / 3
62
  )
63
 
64
  def gather(kwd, fzy=10, ption=0.6):
@@ -117,7 +129,7 @@ def get_table_html(img, tb_cpns, ocr):
117
  margin-bottom: 50px;
118
  border: 1px solid #e1e1e1;
119
  }
120
-
121
  caption {
122
  color: #6ac1ca;
123
  font-size: 20px;
@@ -126,25 +138,25 @@ def get_table_html(img, tb_cpns, ocr):
126
  font-weight: 600;
127
  margin-bottom: 10px;
128
  }
129
-
130
  ._table_1nkzy_11 table {
131
  width: 100%%;
132
  border-collapse: collapse;
133
  }
134
-
135
  th {
136
  color: #fff;
137
  background-color: #6ac1ca;
138
  }
139
-
140
  td:hover {
141
  background: #c1e8e8;
142
  }
143
-
144
  tr:nth-child(even) {
145
  background-color: #f2f2f2;
146
  }
147
-
148
  ._table_1nkzy_11 th,
149
  ._table_1nkzy_11 td {
150
  text-align: center;
@@ -157,7 +169,7 @@ def get_table_html(img, tb_cpns, ocr):
157
  %s
158
  </body>
159
  </html>
160
- """% TableStructureRecognizer.construct_table(boxes, html=True)
161
  return html
162
 
163
 
@@ -168,7 +180,10 @@ if __name__ == "__main__":
168
  required=True)
169
  parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './layouts_outputs'",
170
  default="./layouts_outputs")
171
- parser.add_argument('--threshold', help="A threshold to filter out detections. Default: 0.5", default=0.5)
172
  parser.add_argument('--mode', help="Task mode: layout recognition or table structure recognition", choices=["layout", "tsr"],
173
  default="layout")
174
  args = parser.parse_args()
 
11
  # limitations under the License.
12
  #
13
 
14
+ from deepdoc.vision.seeit import draw_box
15
+ from deepdoc.vision import Recognizer, LayoutRecognizer, TableStructureRecognizer, OCR, init_in_out
16
+ from api.utils.file_utils import get_project_base_directory
17
+ import argparse
18
+ import os
19
+ import sys
20
  import re
21
 
22
  import numpy as np
23
 
24
+ sys.path.insert(
25
+ 0,
26
+ os.path.abspath(
27
+ os.path.join(
28
+ os.path.dirname(
29
+ os.path.abspath(__file__)),
30
+ '../../')))
31
 
32
 
33
  def main(args):
34
  images, outputs = init_in_out(args)
35
  if args.mode.lower() == "layout":
36
  labels = LayoutRecognizer.labels
37
+ detr = Recognizer(
38
+ labels,
39
+ "layout",
40
+ os.path.join(
41
+ get_project_base_directory(),
42
+ "rag/res/deepdoc/"))
43
  if args.mode.lower() == "tsr":
44
  labels = TableStructureRecognizer.labels
45
  detr = TableStructureRecognizer()
 
50
  if args.mode.lower() == "tsr":
51
  #lyt = [t for t in lyt if t["type"] == "table column"]
52
  html = get_table_html(images[i], lyt, ocr)
53
+ with open(outputs[i] + ".html", "w+") as f:
54
+ f.write(html)
55
  lyt = [{
56
  "type": t["label"],
57
  "bbox": [t["x0"], t["top"], t["x1"], t["bottom"]],
 
70
  "bottom": b[-1][1],
71
  "layout_type": "table",
72
  "page_number": 0} for b, t in boxes if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]],
73
+ np.mean([b[-1][1] - b[0][1] for b, _ in boxes]) / 3
74
  )
75
 
76
  def gather(kwd, fzy=10, ption=0.6):
 
129
  margin-bottom: 50px;
130
  border: 1px solid #e1e1e1;
131
  }
132
+
133
  caption {
134
  color: #6ac1ca;
135
  font-size: 20px;
 
138
  font-weight: 600;
139
  margin-bottom: 10px;
140
  }
141
+
142
  ._table_1nkzy_11 table {
143
  width: 100%%;
144
  border-collapse: collapse;
145
  }
146
+
147
  th {
148
  color: #fff;
149
  background-color: #6ac1ca;
150
  }
151
+
152
  td:hover {
153
  background: #c1e8e8;
154
  }
155
+
156
  tr:nth-child(even) {
157
  background-color: #f2f2f2;
158
  }
159
+
160
  ._table_1nkzy_11 th,
161
  ._table_1nkzy_11 td {
162
  text-align: center;
 
169
  %s
170
  </body>
171
  </html>
172
+ """ % TableStructureRecognizer.construct_table(boxes, html=True)
173
  return html
174
 
175
 
 
180
  required=True)
181
  parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './layouts_outputs'",
182
  default="./layouts_outputs")
183
+ parser.add_argument(
184
+ '--threshold',
185
+ help="A threshold to filter out detections. Default: 0.5",
186
+ default=0.5)
187
  parser.add_argument('--mode', help="Task mode: layout recognition or table structure recognition", choices=["layout", "tsr"],
188
  default="layout")
189
  args = parser.parse_args()
deepdoc/vision/table_structure_recognizer.py CHANGED
@@ -44,7 +44,8 @@ class TableStructureRecognizer(Recognizer):
44
  except Exception as e:
45
  model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
46
 
47
- super().__init__(self.labels, "tsr", model_dir)#os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
 
48
 
49
  def __call__(self, images, thr=0.2):
50
  tbls = super().__call__(images, thr)
@@ -138,7 +139,8 @@ class TableStructureRecognizer(Recognizer):
138
  i = 0
139
  while i < len(boxes):
140
  if TableStructureRecognizer.is_caption(boxes[i]):
141
- if is_english: cap + " "
 
142
  cap += boxes[i]["text"]
143
  boxes.pop(i)
144
  i -= 1
@@ -164,7 +166,7 @@ class TableStructureRecognizer(Recognizer):
164
  lst_r = rows[-1]
165
  if lst_r[-1].get("R", "") != b.get("R", "") \
166
  or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
167
- ): # new row
168
  btm = b["bottom"]
169
  b["rn"] += 1
170
  rows.append([b])
@@ -214,9 +216,9 @@ class TableStructureRecognizer(Recognizer):
214
  j += 1
215
  continue
216
  f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
217
- [j - 1][0].get("text")) or j == 0
218
  ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
219
- [j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
220
  if f and ff:
221
  j += 1
222
  continue
@@ -277,9 +279,9 @@ class TableStructureRecognizer(Recognizer):
277
  i += 1
278
  continue
279
  f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
280
- [jj][0].get("text")) or i == 0
281
  ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
282
- [jj][0].get("text")) or i + 1 >= len(tbl)
283
  if f and ff:
284
  i += 1
285
  continue
@@ -366,7 +368,8 @@ class TableStructureRecognizer(Recognizer):
366
  continue
367
  txt = ""
368
  if arr:
369
- h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10)
 
370
  txt = " ".join([c["text"]
371
  for c in Recognizer.sort_Y_firstly(arr, h)])
372
  txts.append(txt)
@@ -438,8 +441,8 @@ class TableStructureRecognizer(Recognizer):
438
  else "") + headers[j - 1][k]
439
  else:
440
  headers[j][k] = headers[j - 1][k] \
441
- + (de if headers[j - 1][k] else "") \
442
- + headers[j][k]
443
 
444
  logging.debug(
445
  f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
 
44
  except Exception as e:
45
  model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
46
 
47
+ # os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
48
+ super().__init__(self.labels, "tsr", model_dir)
49
 
50
  def __call__(self, images, thr=0.2):
51
  tbls = super().__call__(images, thr)
 
139
  i = 0
140
  while i < len(boxes):
141
  if TableStructureRecognizer.is_caption(boxes[i]):
142
+ if is_english:
143
+ cap + " "
144
  cap += boxes[i]["text"]
145
  boxes.pop(i)
146
  i -= 1
 
166
  lst_r = rows[-1]
167
  if lst_r[-1].get("R", "") != b.get("R", "") \
168
  or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
169
+ ): # new row
170
  btm = b["bottom"]
171
  b["rn"] += 1
172
  rows.append([b])
 
216
  j += 1
217
  continue
218
  f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
219
+ [j - 1][0].get("text")) or j == 0
220
  ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
221
+ [j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
222
  if f and ff:
223
  j += 1
224
  continue
 
279
  i += 1
280
  continue
281
  f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
282
+ [jj][0].get("text")) or i == 0
283
  ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
284
+ [jj][0].get("text")) or i + 1 >= len(tbl)
285
  if f and ff:
286
  i += 1
287
  continue
 
368
  continue
369
  txt = ""
370
  if arr:
371
+ h = min(np.min([c["bottom"] - c["top"]
372
+ for c in arr]) / 2, 10)
373
  txt = " ".join([c["text"]
374
  for c in Recognizer.sort_Y_firstly(arr, h)])
375
  txts.append(txt)
 
441
  else "") + headers[j - 1][k]
442
  else:
443
  headers[j][k] = headers[j - 1][k] \
444
+ + (de if headers[j - 1][k] else "") \
445
+ + headers[j][k]
446
 
447
  logging.debug(
448
  f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
rag/app/book.py CHANGED
@@ -48,10 +48,12 @@ class Pdf(PdfParser):
48
 
49
  callback(0.8, "Text extraction finished")
50
 
51
- return [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno","")) for b in self.boxes], tbls
 
52
 
53
 
54
- def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
 
55
  """
56
  Supported file formats are docx, pdf, txt.
57
  Since a book is long and not all the parts are useful, if it's a PDF,
@@ -63,48 +65,63 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
63
  }
64
  doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
65
  pdf_parser = None
66
- sections,tbls = [], []
67
  if re.search(r"\.docx?$", filename, re.IGNORECASE):
68
  callback(0.1, "Start to parse.")
69
  doc_parser = DocxParser()
70
  # TODO: table of contents need to be removed
71
- sections, tbls = doc_parser(binary if binary else filename, from_page=from_page, to_page=to_page)
72
- remove_contents_table(sections, eng=is_english(random_choices([t for t,_ in sections], k=200)))
 
 
73
  callback(0.8, "Finish parsing.")
74
 
75
  elif re.search(r"\.pdf$", filename, re.IGNORECASE):
76
- pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser()
 
 
77
  sections, tbls = pdf_parser(filename if not binary else binary,
78
- from_page=from_page, to_page=to_page, callback=callback)
79
 
80
  elif re.search(r"\.txt$", filename, re.IGNORECASE):
81
  callback(0.1, "Start to parse.")
82
  txt = ""
83
- if binary:txt = binary.decode("utf-8")
 
84
  else:
85
  with open(filename, "r") as f:
86
  while True:
87
  l = f.readline()
88
- if not l:break
 
89
  txt += l
90
  sections = txt.split("\n")
91
- sections = [(l,"") for l in sections if l]
92
- remove_contents_table(sections, eng = is_english(random_choices([t for t,_ in sections], k=200)))
 
93
  callback(0.8, "Finish parsing.")
94
 
95
- else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
 
 
96
 
97
  make_colon_as_title(sections)
98
- bull = bullets_category([t for t in random_choices([t for t,_ in sections], k=100)])
 
99
  if bull >= 0:
100
- chunks = ["\n".join(ck) for ck in hierarchical_merge(bull, sections, 3)]
 
101
  else:
102
- sections = [s.split("@") for s,_ in sections]
103
- sections = [(pr[0], "@"+pr[1]) for pr in sections if len(pr)==2]
104
- chunks = naive_merge(sections, kwargs.get("chunk_token_num", 256), kwargs.get("delimer", "\n。;!?"))
105
 
106
  # is it English
107
- eng = lang.lower() == "english"#is_english(random_choices([t for t, _ in sections], k=218))
 
108
 
109
  res = tokenize_table(tbls, doc, eng)
110
  res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
@@ -114,6 +131,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
114
 
115
  if __name__ == "__main__":
116
  import sys
 
117
  def dummy(prog=None, msg=""):
118
  pass
119
  chunk(sys.argv[1], from_page=1, to_page=10, callback=dummy)
 
48
 
49
  callback(0.8, "Text extraction finished")
50
 
51
+ return [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", ""))
52
+ for b in self.boxes], tbls
53
 
54
 
55
+ def chunk(filename, binary=None, from_page=0, to_page=100000,
56
+ lang="Chinese", callback=None, **kwargs):
57
  """
58
  Supported file formats are docx, pdf, txt.
59
  Since a book is long and not all the parts are useful, if it's a PDF,
 
65
  }
66
  doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
67
  pdf_parser = None
68
+ sections, tbls = [], []
69
  if re.search(r"\.docx?$", filename, re.IGNORECASE):
70
  callback(0.1, "Start to parse.")
71
  doc_parser = DocxParser()
72
  # TODO: table of contents need to be removed
73
+ sections, tbls = doc_parser(
74
+ binary if binary else filename, from_page=from_page, to_page=to_page)
75
+ remove_contents_table(sections, eng=is_english(
76
+ random_choices([t for t, _ in sections], k=200)))
77
  callback(0.8, "Finish parsing.")
78
 
79
  elif re.search(r"\.pdf$", filename, re.IGNORECASE):
80
+ pdf_parser = Pdf() if kwargs.get(
81
+ "parser_config", {}).get(
82
+ "layout_recognize", True) else PlainParser()
83
  sections, tbls = pdf_parser(filename if not binary else binary,
84
+ from_page=from_page, to_page=to_page, callback=callback)
85
 
86
  elif re.search(r"\.txt$", filename, re.IGNORECASE):
87
  callback(0.1, "Start to parse.")
88
  txt = ""
89
+ if binary:
90
+ txt = binary.decode("utf-8")
91
  else:
92
  with open(filename, "r") as f:
93
  while True:
94
  l = f.readline()
95
+ if not l:
96
+ break
97
  txt += l
98
  sections = txt.split("\n")
99
+ sections = [(l, "") for l in sections if l]
100
+ remove_contents_table(sections, eng=is_english(
101
+ random_choices([t for t, _ in sections], k=200)))
102
  callback(0.8, "Finish parsing.")
103
 
104
+ else:
105
+ raise NotImplementedError(
106
+ "file type not supported yet(docx, pdf, txt supported)")
107
 
108
  make_colon_as_title(sections)
109
+ bull = bullets_category(
110
+ [t for t in random_choices([t for t, _ in sections], k=100)])
111
  if bull >= 0:
112
+ chunks = ["\n".join(ck)
113
+ for ck in hierarchical_merge(bull, sections, 3)]
114
  else:
115
+ sections = [s.split("@") for s, _ in sections]
116
+ sections = [(pr[0], "@" + pr[1]) for pr in sections if len(pr) == 2]
117
+ chunks = naive_merge(
118
+ sections, kwargs.get(
119
+ "chunk_token_num", 256), kwargs.get(
120
+ "delimer", "\n。;!?"))
121
 
122
  # is it English
123
+ # is_english(random_choices([t for t, _ in sections], k=218))
124
+ eng = lang.lower() == "english"
125
 
126
  res = tokenize_table(tbls, doc, eng)
127
  res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
 
131
 
132
  if __name__ == "__main__":
133
  import sys
134
+
135
  def dummy(prog=None, msg=""):
136
  pass
137
  chunk(sys.argv[1], from_page=1, to_page=10, callback=dummy)
rag/app/laws.py CHANGED
@@ -35,8 +35,10 @@ class Docx(DocxParser):
35
  pn = 0
36
  lines = []
37
  for p in self.doc.paragraphs:
38
- if pn > to_page:break
39
- if from_page <= pn < to_page and p.text.strip(): lines.append(self.__clean(p.text))
 
 
40
  for run in p.runs:
41
  if 'lastRenderedPageBreak' in run._element.xml:
42
  pn += 1
@@ -63,15 +65,18 @@ class Pdf(PdfParser):
63
  start = timer()
64
  self._layouts_rec(zoomin)
65
  callback(0.67, "Layout analysis finished")
66
- cron_logger.info("paddle layouts:".format((timer()-start)/(self.total_page+0.1)))
 
67
  self._naive_vertical_merge()
68
 
69
  callback(0.8, "Text extraction finished")
70
 
71
- return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], None
 
72
 
73
 
74
- def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
 
75
  """
76
  Supported file formats are docx, pdf, txt.
77
  """
@@ -89,41 +94,50 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
89
  callback(0.8, "Finish parsing.")
90
 
91
  elif re.search(r"\.pdf$", filename, re.IGNORECASE):
92
- pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser()
93
- for txt, poss in pdf_parser(filename if not binary else binary,
94
- from_page=from_page, to_page=to_page, callback=callback)[0]:
95
- sections.append(txt + poss)
 
 
96
 
97
  elif re.search(r"\.txt$", filename, re.IGNORECASE):
98
  callback(0.1, "Start to parse.")
99
  txt = ""
100
- if binary:txt = binary.decode("utf-8")
 
101
  else:
102
  with open(filename, "r") as f:
103
  while True:
104
  l = f.readline()
105
- if not l:break
 
106
  txt += l
107
  sections = txt.split("\n")
108
  sections = [l for l in sections if l]
109
  callback(0.8, "Finish parsing.")
110
- else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
 
 
111
 
112
  # is it English
113
- eng = lang.lower() == "english"#is_english(sections)
114
  # Remove 'Contents' part
115
  remove_contents_table(sections, eng)
116
 
117
  make_colon_as_title(sections)
118
  bull = bullets_category(sections)
119
  chunks = hierarchical_merge(bull, sections, 3)
120
- if not chunks: callback(0.99, "No chunk parsed out.")
 
121
 
122
- return tokenize_chunks(["\n".join(ck) for ck in chunks], doc, eng, pdf_parser)
 
123
 
124
 
125
  if __name__ == "__main__":
126
  import sys
 
127
  def dummy(prog=None, msg=""):
128
  pass
129
  chunk(sys.argv[1], callback=dummy)
 
35
  pn = 0
36
  lines = []
37
  for p in self.doc.paragraphs:
38
+ if pn > to_page:
39
+ break
40
+ if from_page <= pn < to_page and p.text.strip():
41
+ lines.append(self.__clean(p.text))
42
  for run in p.runs:
43
  if 'lastRenderedPageBreak' in run._element.xml:
44
  pn += 1
 
65
  start = timer()
66
  self._layouts_rec(zoomin)
67
  callback(0.67, "Layout analysis finished")
68
+ cron_logger.info("paddle layouts:".format(
69
+ (timer() - start) / (self.total_page + 0.1)))
70
  self._naive_vertical_merge()
71
 
72
  callback(0.8, "Text extraction finished")
73
 
74
+ return [(b["text"], self._line_tag(b, zoomin))
75
+ for b in self.boxes], None
76
 
77
 
78
+ def chunk(filename, binary=None, from_page=0, to_page=100000,
79
+ lang="Chinese", callback=None, **kwargs):
80
  """
81
  Supported file formats are docx, pdf, txt.
82
  """
 
94
  callback(0.8, "Finish parsing.")
95
 
96
  elif re.search(r"\.pdf$", filename, re.IGNORECASE):
97
+ pdf_parser = Pdf() if kwargs.get(
98
+ "parser_config", {}).get(
99
+ "layout_recognize", True) else PlainParser()
100
+ for txt, poss in pdf_parser(filename if not binary else binary,
101
+ from_page=from_page, to_page=to_page, callback=callback)[0]:
102
+ sections.append(txt + poss)
103
 
104
  elif re.search(r"\.txt$", filename, re.IGNORECASE):
105
  callback(0.1, "Start to parse.")
106
  txt = ""
107
+ if binary:
108
+ txt = binary.decode("utf-8")
109
  else:
110
  with open(filename, "r") as f:
111
  while True:
112
  l = f.readline()
113
+ if not l:
114
+ break
115
  txt += l
116
  sections = txt.split("\n")
117
  sections = [l for l in sections if l]
118
  callback(0.8, "Finish parsing.")
119
+ else:
120
+ raise NotImplementedError(
121
+ "file type not supported yet(docx, pdf, txt supported)")
122
 
123
  # is it English
124
+ eng = lang.lower() == "english" # is_english(sections)
125
  # Remove 'Contents' part
126
  remove_contents_table(sections, eng)
127
 
128
  make_colon_as_title(sections)
129
  bull = bullets_category(sections)
130
  chunks = hierarchical_merge(bull, sections, 3)
131
+ if not chunks:
132
+ callback(0.99, "No chunk parsed out.")
133
 
134
+ return tokenize_chunks(["\n".join(ck)
135
+ for ck in chunks], doc, eng, pdf_parser)
136
 
137
 
138
  if __name__ == "__main__":
139
  import sys
140
+
141
  def dummy(prog=None, msg=""):
142
  pass
143
  chunk(sys.argv[1], callback=dummy)
rag/app/manual.py CHANGED
@@ -25,10 +25,10 @@ class Pdf(PdfParser):
25
  callback
26
  )
27
  callback(msg="OCR finished.")
28
- #for bb in self.boxes:
29
  # for b in bb:
30
  # print(b)
31
- print("OCR:", timer()-start)
32
 
33
  self._layouts_rec(zoomin)
34
  callback(0.65, "Layout analysis finished.")
@@ -45,30 +45,35 @@ class Pdf(PdfParser):
45
  for b in self.boxes:
46
  b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip())
47
 
48
- return [(b["text"], b.get("layout_no", ""), self.get_position(b, zoomin)) for i, b in enumerate(self.boxes)], tbls
 
49
 
50
 
51
-
52
- def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
53
  """
54
  Only pdf is supported.
55
  """
56
  pdf_parser = None
57
 
58
  if re.search(r"\.pdf$", filename, re.IGNORECASE):
59
- pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser()
 
 
60
  sections, tbls = pdf_parser(filename if not binary else binary,
61
- from_page=from_page, to_page=to_page, callback=callback)
62
- if sections and len(sections[0])<3: sections = [(t, l, [[0]*5]) for t, l in sections]
 
63
 
64
- else: raise NotImplementedError("file type not supported yet(pdf supported)")
 
65
  doc = {
66
  "docnm_kwd": filename
67
  }
68
  doc["title_tks"] = huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"]))
69
  doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
70
  # is it English
71
- eng = lang.lower() == "english"#pdf_parser.is_english
72
 
73
  # set pivot using the most frequent type of title,
74
  # then merge between 2 pivot
@@ -79,7 +84,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
79
  for txt, _, _ in sections:
80
  for t, lvl in pdf_parser.outlines:
81
  tks = set([t[i] + t[i + 1] for i in range(len(t) - 1)])
82
- tks_ = set([txt[i] + txt[i + 1] for i in range(min(len(t), len(txt) - 1))])
 
83
  if len(set(tks & tks_)) / max([len(tks), len(tks_), 1]) > 0.8:
84
  levels.append(lvl)
85
  break
@@ -87,24 +93,27 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
87
  levels.append(max_lvl + 1)
88
 
89
  else:
90
- bull = bullets_category([txt for txt,_,_ in sections])
91
- most_level, levels = title_frequency(bull, [(txt, l) for txt, l, poss in sections])
 
92
 
93
  assert len(sections) == len(levels)
94
  sec_ids = []
95
  sid = 0
96
  for i, lvl in enumerate(levels):
97
- if lvl <= most_level and i > 0 and lvl != levels[i - 1]: sid += 1
 
98
  sec_ids.append(sid)
99
  # print(lvl, self.boxes[i]["text"], most_level, sid)
100
 
101
- sections = [(txt, sec_ids[i], poss) for i, (txt, _, poss) in enumerate(sections)]
 
102
  for (img, rows), poss in tbls:
103
  sections.append((rows if isinstance(rows, str) else rows[0], -1,
104
  [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
105
 
106
  def tag(pn, left, right, top, bottom):
107
- if pn+left+right+top+bottom == 0:
108
  return ""
109
  return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
110
  .format(pn, left, right, top, bottom)
@@ -112,7 +121,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
112
  chunks = []
113
  last_sid = -2
114
  tk_cnt = 0
115
- for txt, sec_id, poss in sorted(sections, key=lambda x: (x[-1][0][0], x[-1][0][3], x[-1][0][1])):
 
116
  poss = "\t".join([tag(*pos) for pos in poss])
117
  if tk_cnt < 2048 and (sec_id == last_sid or sec_id == -1):
118
  if chunks:
@@ -121,16 +131,17 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
121
  continue
122
  chunks.append(txt + poss)
123
  tk_cnt = num_tokens_from_string(txt)
124
- if sec_id > -1: last_sid = sec_id
 
125
 
126
  res = tokenize_table(tbls, doc, eng)
127
  res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
128
  return res
129
 
130
 
131
-
132
  if __name__ == "__main__":
133
  import sys
 
134
  def dummy(prog=None, msg=""):
135
  pass
136
  chunk(sys.argv[1], callback=dummy)
 
25
  callback
26
  )
27
  callback(msg="OCR finished.")
28
+ # for bb in self.boxes:
29
  # for b in bb:
30
  # print(b)
31
+ print("OCR:", timer() - start)
32
 
33
  self._layouts_rec(zoomin)
34
  callback(0.65, "Layout analysis finished.")
 
45
  for b in self.boxes:
46
  b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip())
47
 
48
+ return [(b["text"], b.get("layout_no", ""), self.get_position(b, zoomin))
49
+ for i, b in enumerate(self.boxes)], tbls
50
 
51
 
52
+ def chunk(filename, binary=None, from_page=0, to_page=100000,
53
+ lang="Chinese", callback=None, **kwargs):
54
  """
55
  Only pdf is supported.
56
  """
57
  pdf_parser = None
58
 
59
  if re.search(r"\.pdf$", filename, re.IGNORECASE):
60
+ pdf_parser = Pdf() if kwargs.get(
61
+ "parser_config", {}).get(
62
+ "layout_recognize", True) else PlainParser()
63
  sections, tbls = pdf_parser(filename if not binary else binary,
64
+ from_page=from_page, to_page=to_page, callback=callback)
65
+ if sections and len(sections[0]) < 3:
66
+ sections = [(t, l, [[0] * 5]) for t, l in sections]
67
 
68
+ else:
69
+ raise NotImplementedError("file type not supported yet(pdf supported)")
70
  doc = {
71
  "docnm_kwd": filename
72
  }
73
  doc["title_tks"] = huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"]))
74
  doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
75
  # is it English
76
+ eng = lang.lower() == "english" # pdf_parser.is_english
77
 
78
  # set pivot using the most frequent type of title,
79
  # then merge between 2 pivot
 
84
  for txt, _, _ in sections:
85
  for t, lvl in pdf_parser.outlines:
86
  tks = set([t[i] + t[i + 1] for i in range(len(t) - 1)])
87
+ tks_ = set([txt[i] + txt[i + 1]
88
+ for i in range(min(len(t), len(txt) - 1))])
89
  if len(set(tks & tks_)) / max([len(tks), len(tks_), 1]) > 0.8:
90
  levels.append(lvl)
91
  break
 
93
  levels.append(max_lvl + 1)
94
 
95
  else:
96
+ bull = bullets_category([txt for txt, _, _ in sections])
97
+ most_level, levels = title_frequency(
98
+ bull, [(txt, l) for txt, l, poss in sections])
99
 
100
  assert len(sections) == len(levels)
101
  sec_ids = []
102
  sid = 0
103
  for i, lvl in enumerate(levels):
104
+ if lvl <= most_level and i > 0 and lvl != levels[i - 1]:
105
+ sid += 1
106
  sec_ids.append(sid)
107
  # print(lvl, self.boxes[i]["text"], most_level, sid)
108
 
109
+ sections = [(txt, sec_ids[i], poss)
110
+ for i, (txt, _, poss) in enumerate(sections)]
111
  for (img, rows), poss in tbls:
112
  sections.append((rows if isinstance(rows, str) else rows[0], -1,
113
  [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
114
 
115
  def tag(pn, left, right, top, bottom):
116
+ if pn + left + right + top + bottom == 0:
117
  return ""
118
  return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
119
  .format(pn, left, right, top, bottom)
 
121
  chunks = []
122
  last_sid = -2
123
  tk_cnt = 0
124
+ for txt, sec_id, poss in sorted(sections, key=lambda x: (
125
+ x[-1][0][0], x[-1][0][3], x[-1][0][1])):
126
  poss = "\t".join([tag(*pos) for pos in poss])
127
  if tk_cnt < 2048 and (sec_id == last_sid or sec_id == -1):
128
  if chunks:
 
131
  continue
132
  chunks.append(txt + poss)
133
  tk_cnt = num_tokens_from_string(txt)
134
+ if sec_id > -1:
135
+ last_sid = sec_id
136
 
137
  res = tokenize_table(tbls, doc, eng)
138
  res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
139
  return res
140
 
141
 
 
142
  if __name__ == "__main__":
143
  import sys
144
+
145
  def dummy(prog=None, msg=""):
146
  pass
147
  chunk(sys.argv[1], callback=dummy)
rag/app/naive.py CHANGED
@@ -44,11 +44,14 @@ class Pdf(PdfParser):
44
  tbls = self._extract_table_figure(True, zoomin, True, True)
45
  self._naive_vertical_merge()
46
 
47
- cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1)))
48
- return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls
 
 
49
 
50
 
51
- def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
 
52
  """
53
  Supported file formats are docx, pdf, excel, txt.
54
  This method apply the naive ways to chunk files.
@@ -56,8 +59,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
56
  Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
57
  """
58
 
59
- eng = lang.lower() == "english"#is_english(cks)
60
- parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": True})
 
 
61
  doc = {
62
  "docnm_kwd": filename,
63
  "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
@@ -73,9 +78,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
73
  callback(0.8, "Finish parsing.")
74
 
75
  elif re.search(r"\.pdf$", filename, re.IGNORECASE):
76
- pdf_parser = Pdf() if parser_config["layout_recognize"] else PlainParser()
 
77
  sections, tbls = pdf_parser(filename if not binary else binary,
78
- from_page=from_page, to_page=to_page, callback=callback)
79
  res = tokenize_table(tbls, doc, eng)
80
 
81
  elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):
@@ -92,16 +98,21 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
92
  with open(filename, "r") as f:
93
  while True:
94
  l = f.readline()
95
- if not l: break
 
96
  txt += l
97
  sections = txt.split("\n")
98
  sections = [(l, "") for l in sections if l]
99
  callback(0.8, "Finish parsing.")
100
 
101
  else:
102
- raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
 
103
 
104
- chunks = naive_merge(sections, parser_config.get("chunk_token_num", 128), parser_config.get("delimiter", "\n!?。;!?"))
105
 
106
  res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
107
  return res
@@ -110,9 +121,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
110
  if __name__ == "__main__":
111
  import sys
112
 
113
-
114
  def dummy(prog=None, msg=""):
115
  pass
116
 
117
-
118
  chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
 
44
  tbls = self._extract_table_figure(True, zoomin, True, True)
45
  self._naive_vertical_merge()
46
 
47
+ cron_logger.info("paddle layouts:".format(
48
+ (timer() - start) / (self.total_page + 0.1)))
49
+ return [(b["text"], self._line_tag(b, zoomin))
50
+ for b in self.boxes], tbls
51
 
52
 
53
+ def chunk(filename, binary=None, from_page=0, to_page=100000,
54
+ lang="Chinese", callback=None, **kwargs):
55
  """
56
  Supported file formats are docx, pdf, excel, txt.
57
  This method apply the naive ways to chunk files.
 
59
  Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
60
  """
61
 
62
+ eng = lang.lower() == "english" # is_english(cks)
63
+ parser_config = kwargs.get(
64
+ "parser_config", {
65
+ "chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": True})
66
  doc = {
67
  "docnm_kwd": filename,
68
  "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
 
78
  callback(0.8, "Finish parsing.")
79
 
80
  elif re.search(r"\.pdf$", filename, re.IGNORECASE):
81
+ pdf_parser = Pdf(
82
+ ) if parser_config["layout_recognize"] else PlainParser()
83
  sections, tbls = pdf_parser(filename if not binary else binary,
84
+ from_page=from_page, to_page=to_page, callback=callback)
85
  res = tokenize_table(tbls, doc, eng)
86
 
87
  elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):
 
98
  with open(filename, "r") as f:
99
  while True:
100
  l = f.readline()
101
+ if not l:
102
+ break
103
  txt += l
104
  sections = txt.split("\n")
105
  sections = [(l, "") for l in sections if l]
106
  callback(0.8, "Finish parsing.")
107
 
108
  else:
109
+ raise NotImplementedError(
110
+ "file type not supported yet(docx, pdf, txt supported)")
111
 
112
+ chunks = naive_merge(
113
+ sections, parser_config.get(
114
+ "chunk_token_num", 128), parser_config.get(
115
+ "delimiter", "\n!?。;!?"))
116
 
117
  res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
118
  return res
 
121
  if __name__ == "__main__":
122
  import sys
123
 
 
124
  def dummy(prog=None, msg=""):
125
  pass
126
 
 
127
  chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
rag/app/one.py CHANGED
@@ -41,20 +41,23 @@ class Pdf(PdfParser):
41
  tbls = self._extract_table_figure(True, zoomin, True, True)
42
  self._concat_downward()
43
 
44
- sections = [(b["text"], self.get_position(b, zoomin)) for i, b in enumerate(self.boxes)]
 
45
  for (img, rows), poss in tbls:
46
  sections.append((rows if isinstance(rows, str) else rows[0],
47
  [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
48
- return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (x[-1][0][0], x[-1][0][3], x[-1][0][1]))], None
 
49
 
50
 
51
- def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
 
52
  """
53
  Supported file formats are docx, pdf, excel, txt.
54
  One file forms a chunk which maintains original text order.
55
  """
56
 
57
- eng = lang.lower() == "english"#is_english(cks)
58
 
59
  if re.search(r"\.docx?$", filename, re.IGNORECASE):
60
  callback(0.1, "Start to parse.")
@@ -62,8 +65,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
62
  callback(0.8, "Finish parsing.")
63
 
64
  elif re.search(r"\.pdf$", filename, re.IGNORECASE):
65
- pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser()
66
- sections, _ = pdf_parser(filename if not binary else binary, to_page=to_page, callback=callback)
 
 
 
67
  sections = [s for s, _ in sections if s]
68
 
69
  elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):
@@ -80,14 +86,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
80
  with open(filename, "r") as f:
81
  while True:
82
  l = f.readline()
83
- if not l: break
 
84
  txt += l
85
  sections = txt.split("\n")
86
  sections = [s for s in sections if s]
87
  callback(0.8, "Finish parsing.")
88
 
89
  else:
90
- raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
 
91
 
92
  doc = {
93
  "docnm_kwd": filename,
@@ -101,9 +109,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
101
  if __name__ == "__main__":
102
  import sys
103
 
104
-
105
  def dummy(prog=None, msg=""):
106
  pass
107
 
108
-
109
  chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
 
41
  tbls = self._extract_table_figure(True, zoomin, True, True)
42
  self._concat_downward()
43
 
44
+ sections = [(b["text"], self.get_position(b, zoomin))
45
+ for i, b in enumerate(self.boxes)]
46
  for (img, rows), poss in tbls:
47
  sections.append((rows if isinstance(rows, str) else rows[0],
48
  [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
49
+ return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (
50
+ x[-1][0][0], x[-1][0][3], x[-1][0][1]))], None
51
 
52
 
53
+ def chunk(filename, binary=None, from_page=0, to_page=100000,
54
+ lang="Chinese", callback=None, **kwargs):
55
  """
56
  Supported file formats are docx, pdf, excel, txt.
57
  One file forms a chunk which maintains original text order.
58
  """
59
 
60
+ eng = lang.lower() == "english" # is_english(cks)
61
 
62
  if re.search(r"\.docx?$", filename, re.IGNORECASE):
63
  callback(0.1, "Start to parse.")
 
65
  callback(0.8, "Finish parsing.")
66
 
67
  elif re.search(r"\.pdf$", filename, re.IGNORECASE):
68
+ pdf_parser = Pdf() if kwargs.get(
69
+ "parser_config", {}).get(
70
+ "layout_recognize", True) else PlainParser()
71
+ sections, _ = pdf_parser(
72
+ filename if not binary else binary, to_page=to_page, callback=callback)
73
  sections = [s for s, _ in sections if s]
74
 
75
  elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):
 
86
  with open(filename, "r") as f:
87
  while True:
88
  l = f.readline()
89
+ if not l:
90
+ break
91
  txt += l
92
  sections = txt.split("\n")
93
  sections = [s for s in sections if s]
94
  callback(0.8, "Finish parsing.")
95
 
96
  else:
97
+ raise NotImplementedError(
98
+ "file type not supported yet(docx, pdf, txt supported)")
99
 
100
  doc = {
101
  "docnm_kwd": filename,
 
109
  if __name__ == "__main__":
110
  import sys
111
 
 
112
  def dummy(prog=None, msg=""):
113
  pass
114
 
 
115
  chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
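
A similar sketch for this file: turning layout_recognize off in parser_config makes chunk() fall back to PlainParser, as the branch above shows. The module path and file name are assumptions.

from rag.app import one

def progress(prog=None, msg=""):
    pass

chunks = one.chunk(
    "report.pdf",                               # hypothetical input file
    from_page=0,
    to_page=10,
    lang="English",
    callback=progress,
    parser_config={"layout_recognize": False},  # use PlainParser instead of layout recognition
)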
rag/app/paper.py CHANGED
@@ -67,11 +67,11 @@ class Pdf(PdfParser):
67
 
68
  if from_page > 0:
69
  return {
70
- "title":"",
71
  "authors": "",
72
  "abstract": "",
73
  "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes if
74
- re.match(r"(text|title)", b.get("layoutno", "text"))],
75
  "tables": tbls
76
  }
77
  # get title and authors
@@ -87,7 +87,8 @@ class Pdf(PdfParser):
87
  title = ""
88
  break
89
  for j in range(3):
90
- if _begin(self.boxes[i + j]["text"]): break
 
91
  authors.append(self.boxes[i + j]["text"])
92
  break
93
  break
@@ -107,10 +108,15 @@ class Pdf(PdfParser):
107
  abstr = txt + self._line_tag(self.boxes[i], zoomin)
108
  i += 1
109
  break
110
- if not abstr: i = 0
 
111
 
112
- callback(0.8, "Page {}~{}: Text merging finished".format(from_page, min(to_page, self.total_page)))
113
- for b in self.boxes: print(b["text"], b.get("layoutno"))
 
 
 
 
114
  print(tbls)
115
 
116
  return {
@@ -118,19 +124,20 @@ class Pdf(PdfParser):
118
  "authors": " ".join(authors),
119
  "abstract": abstr,
120
  "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if
121
- re.match(r"(text|title)", b.get("layoutno", "text"))],
122
  "tables": tbls
123
  }
124
 
125
 
126
- def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
 
127
  """
128
  Only pdf is supported.
129
  The abstract of the paper will be sliced as an entire chunk, and will not be sliced partly.
130
  """
131
  pdf_parser = None
132
  if re.search(r"\.pdf$", filename, re.IGNORECASE):
133
- if not kwargs.get("parser_config",{}).get("layout_recognize", True):
134
  pdf_parser = PlainParser()
135
  paper = {
136
  "title": filename,
@@ -143,14 +150,15 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
143
  pdf_parser = Pdf()
144
  paper = pdf_parser(filename if not binary else binary,
145
  from_page=from_page, to_page=to_page, callback=callback)
146
- else: raise NotImplementedError("file type not supported yet(pdf supported)")
 
147
 
148
  doc = {"docnm_kwd": filename, "authors_tks": huqie.qie(paper["authors"]),
149
  "title_tks": huqie.qie(paper["title"] if paper["title"] else filename)}
150
  doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
151
  doc["authors_sm_tks"] = huqie.qieqie(doc["authors_tks"])
152
  # is it English
153
- eng = lang.lower() == "english"#pdf_parser.is_english
154
  print("It's English.....", eng)
155
 
156
  res = tokenize_table(paper["tables"], doc, eng)
@@ -160,7 +168,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
160
  txt = pdf_parser.remove_tag(paper["abstract"])
161
  d["important_kwd"] = ["abstract", "总结", "概括", "summary", "summarize"]
162
  d["important_tks"] = " ".join(d["important_kwd"])
163
- d["image"], poss = pdf_parser.crop(paper["abstract"], need_position=True)
 
164
  add_positions(d, poss)
165
  tokenize(d, txt, eng)
166
  res.append(d)
@@ -174,7 +183,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
174
  sec_ids = []
175
  sid = 0
176
  for i, lvl in enumerate(levels):
177
- if lvl <= most_level and i > 0 and lvl != levels[i-1]: sid += 1
 
178
  sec_ids.append(sid)
179
  print(lvl, sorted_sections[i][0], most_level, sid)
180
 
@@ -190,6 +200,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
190
  res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
191
  return res
192
 
 
193
  """
194
  readed = [0] * len(paper["lines"])
195
  # find colon firstly
@@ -212,7 +223,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
212
  for k in range(j, i): readed[k] = True
213
  txt = txt[::-1]
214
  if eng:
215
- r = re.search(r"(.*?) ([\.;?!]|$)", txt)
216
  txt = r.group(1)[::-1] if r else txt[::-1]
217
  else:
218
  r = re.search(r"(.*?) ([。?;!]|$)", txt)
@@ -270,6 +281,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
270
 
271
  if __name__ == "__main__":
272
  import sys
 
273
  def dummy(prog=None, msg=""):
274
  pass
275
  chunk(sys.argv[1], callback=dummy)
 
67
 
68
  if from_page > 0:
69
  return {
70
+ "title": "",
71
  "authors": "",
72
  "abstract": "",
73
  "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes if
74
+ re.match(r"(text|title)", b.get("layoutno", "text"))],
75
  "tables": tbls
76
  }
77
  # get title and authors
 
87
  title = ""
88
  break
89
  for j in range(3):
90
+ if _begin(self.boxes[i + j]["text"]):
91
+ break
92
  authors.append(self.boxes[i + j]["text"])
93
  break
94
  break
 
108
  abstr = txt + self._line_tag(self.boxes[i], zoomin)
109
  i += 1
110
  break
111
+ if not abstr:
112
+ i = 0
113
 
114
+ callback(
115
+ 0.8, "Page {}~{}: Text merging finished".format(
116
+ from_page, min(
117
+ to_page, self.total_page)))
118
+ for b in self.boxes:
119
+ print(b["text"], b.get("layoutno"))
120
  print(tbls)
121
 
122
  return {
 
124
  "authors": " ".join(authors),
125
  "abstract": abstr,
126
  "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if
127
+ re.match(r"(text|title)", b.get("layoutno", "text"))],
128
  "tables": tbls
129
  }
130
 
131
 
132
+ def chunk(filename, binary=None, from_page=0, to_page=100000,
133
+ lang="Chinese", callback=None, **kwargs):
134
  """
135
  Only pdf is supported.
136
  The abstract of the paper will be sliced as an entire chunk, and will not be sliced partly.
137
  """
138
  pdf_parser = None
139
  if re.search(r"\.pdf$", filename, re.IGNORECASE):
140
+ if not kwargs.get("parser_config", {}).get("layout_recognize", True):
141
  pdf_parser = PlainParser()
142
  paper = {
143
  "title": filename,
 
150
  pdf_parser = Pdf()
151
  paper = pdf_parser(filename if not binary else binary,
152
  from_page=from_page, to_page=to_page, callback=callback)
153
+ else:
154
+ raise NotImplementedError("file type not supported yet(pdf supported)")
155
 
156
  doc = {"docnm_kwd": filename, "authors_tks": huqie.qie(paper["authors"]),
157
  "title_tks": huqie.qie(paper["title"] if paper["title"] else filename)}
158
  doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
159
  doc["authors_sm_tks"] = huqie.qieqie(doc["authors_tks"])
160
  # is it English
161
+ eng = lang.lower() == "english" # pdf_parser.is_english
162
  print("It's English.....", eng)
163
 
164
  res = tokenize_table(paper["tables"], doc, eng)
 
168
  txt = pdf_parser.remove_tag(paper["abstract"])
169
  d["important_kwd"] = ["abstract", "总结", "概括", "summary", "summarize"]
170
  d["important_tks"] = " ".join(d["important_kwd"])
171
+ d["image"], poss = pdf_parser.crop(
172
+ paper["abstract"], need_position=True)
173
  add_positions(d, poss)
174
  tokenize(d, txt, eng)
175
  res.append(d)
 
183
  sec_ids = []
184
  sid = 0
185
  for i, lvl in enumerate(levels):
186
+ if lvl <= most_level and i > 0 and lvl != levels[i - 1]:
187
+ sid += 1
188
  sec_ids.append(sid)
189
  print(lvl, sorted_sections[i][0], most_level, sid)
190
 
 
200
  res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
201
  return res
202
 
203
+
204
  """
205
  readed = [0] * len(paper["lines"])
206
  # find colon firstly
 
223
  for k in range(j, i): readed[k] = True
224
  txt = txt[::-1]
225
  if eng:
226
+ r = re.search(r"(.*?) ([\\.;?!]|$)", txt)
227
  txt = r.group(1)[::-1] if r else txt[::-1]
228
  else:
229
  r = re.search(r"(.*?) ([。?;!]|$)", txt)
 
281
 
282
  if __name__ == "__main__":
283
  import sys
284
+
285
  def dummy(prog=None, msg=""):
286
  pass
287
  chunk(sys.argv[1], callback=dummy)
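
The sec_ids loop above is easy to misread, so here is a standalone toy run (not repo code) of the same rule: a new section id starts whenever a level that is <= most_level differs from the previous level.

levels = [1, 2, 2, 1, 3, 1]   # hypothetical outline levels, smaller = higher heading
most_level = 2
sec_ids, sid = [], 0
for i, lvl in enumerate(levels):
    if lvl <= most_level and i > 0 and lvl != levels[i - 1]:
        sid += 1
    sec_ids.append(sid)
print(sec_ids)   # [0, 1, 1, 2, 2, 3]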
rag/app/presentation.py CHANGED
@@ -33,9 +33,12 @@ class Ppt(PptParser):
33
  with slides.Presentation(BytesIO(fnm)) as presentation:
34
  for i, slide in enumerate(presentation.slides[from_page: to_page]):
35
  buffered = BytesIO()
36
- slide.get_thumbnail(0.5, 0.5).save(buffered, drawing.imaging.ImageFormat.jpeg)
 
 
37
  imgs.append(Image.open(buffered))
38
- assert len(imgs) == len(txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
 
39
  callback(0.9, "Image extraction finished")
40
  self.is_english = is_english(txts)
41
  return [(txts[i], imgs[i]) for i in range(len(txts))]
@@ -47,25 +50,34 @@ class Pdf(PdfParser):
47
 
48
  def __garbage(self, txt):
49
  txt = txt.lower().strip()
50
- if re.match(r"[0-9\.,%/-]+$", txt): return True
51
- if len(txt) < 3:return True
 
 
52
  return False
53
 
54
- def __call__(self, filename, binary=None, from_page=0, to_page=100000, zoomin=3, callback=None):
 
55
  callback(msg="OCR is running...")
56
- self.__images__(filename if not binary else binary, zoomin, from_page, to_page, callback)
57
- callback(0.8, "Page {}~{}: OCR finished".format(from_page, min(to_page, self.total_page)))
58
- assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(len(self.boxes), len(self.page_images))
 
 
 
59
  res = []
60
  for i in range(len(self.boxes)):
61
- lines = "\n".join([b["text"] for b in self.boxes[i] if not self.__garbage(b["text"])])
 
62
  res.append((lines, self.page_images[i]))
63
- callback(0.9, "Page {}~{}: Parsing finished".format(from_page, min(to_page, self.total_page)))
 
64
  return res
65
 
66
 
67
  class PlainPdf(PlainParser):
68
- def __call__(self, filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
 
69
  self.pdf = pdf2_read(filename if not binary else BytesIO(binary))
70
  page_txt = []
71
  for page in self.pdf.pages[from_page: to_page]:
@@ -74,7 +86,8 @@ class PlainPdf(PlainParser):
74
  return [(txt, None) for txt in page_txt]
75
 
76
 
77
- def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
 
78
  """
79
  The supported file formats are pdf, pptx.
80
  Every page will be treated as a chunk. And the thumbnail of every page will be stored.
@@ -89,35 +102,42 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
89
  res = []
90
  if re.search(r"\.pptx?$", filename, re.IGNORECASE):
91
  ppt_parser = Ppt()
92
- for pn, (txt,img) in enumerate(ppt_parser(filename if not binary else binary, from_page, 1000000, callback)):
 
93
  d = copy.deepcopy(doc)
94
  pn += from_page
95
  d["image"] = img
96
- d["page_num_int"] = [pn+1]
97
  d["top_int"] = [0]
98
  d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
99
  tokenize(d, txt, eng)
100
  res.append(d)
101
  return res
102
  elif re.search(r"\.pdf$", filename, re.IGNORECASE):
103
- pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainPdf()
104
- for pn, (txt,img) in enumerate(pdf_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback)):
 
 
 
105
  d = copy.deepcopy(doc)
106
  pn += from_page
107
- if img: d["image"] = img
108
- d["page_num_int"] = [pn+1]
 
109
  d["top_int"] = [0]
110
- d["position_int"] = [(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)]
 
111
  tokenize(d, txt, eng)
112
  res.append(d)
113
  return res
114
 
115
- raise NotImplementedError("file type not supported yet(pptx, pdf supported)")
 
116
 
117
 
118
- if __name__== "__main__":
119
  import sys
 
120
  def dummy(a, b):
121
  pass
122
  chunk(sys.argv[1], callback=dummy)
123
-
 
33
  with slides.Presentation(BytesIO(fnm)) as presentation:
34
  for i, slide in enumerate(presentation.slides[from_page: to_page]):
35
  buffered = BytesIO()
36
+ slide.get_thumbnail(
37
+ 0.5, 0.5).save(
38
+ buffered, drawing.imaging.ImageFormat.jpeg)
39
  imgs.append(Image.open(buffered))
40
+ assert len(imgs) == len(
41
+ txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
42
  callback(0.9, "Image extraction finished")
43
  self.is_english = is_english(txts)
44
  return [(txts[i], imgs[i]) for i in range(len(txts))]
 
50
 
51
  def __garbage(self, txt):
52
  txt = txt.lower().strip()
53
+ if re.match(r"[0-9\.,%/-]+$", txt):
54
+ return True
55
+ if len(txt) < 3:
56
+ return True
57
  return False
58
 
59
+ def __call__(self, filename, binary=None, from_page=0,
60
+ to_page=100000, zoomin=3, callback=None):
61
  callback(msg="OCR is running...")
62
+ self.__images__(filename if not binary else binary,
63
+ zoomin, from_page, to_page, callback)
64
+ callback(0.8, "Page {}~{}: OCR finished".format(
65
+ from_page, min(to_page, self.total_page)))
66
+ assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(
67
+ len(self.boxes), len(self.page_images))
68
  res = []
69
  for i in range(len(self.boxes)):
70
+ lines = "\n".join([b["text"] for b in self.boxes[i]
71
+ if not self.__garbage(b["text"])])
72
  res.append((lines, self.page_images[i]))
73
+ callback(0.9, "Page {}~{}: Parsing finished".format(
74
+ from_page, min(to_page, self.total_page)))
75
  return res
76
 
77
 
78
  class PlainPdf(PlainParser):
79
+ def __call__(self, filename, binary=None, from_page=0,
80
+ to_page=100000, callback=None, **kwargs):
81
  self.pdf = pdf2_read(filename if not binary else BytesIO(binary))
82
  page_txt = []
83
  for page in self.pdf.pages[from_page: to_page]:
 
86
  return [(txt, None) for txt in page_txt]
87
 
88
 
89
+ def chunk(filename, binary=None, from_page=0, to_page=100000,
90
+ lang="Chinese", callback=None, **kwargs):
91
  """
92
  The supported file formats are pdf, pptx.
93
  Every page will be treated as a chunk. And the thumbnail of every page will be stored.
 
102
  res = []
103
  if re.search(r"\.pptx?$", filename, re.IGNORECASE):
104
  ppt_parser = Ppt()
105
+ for pn, (txt, img) in enumerate(ppt_parser(
106
+ filename if not binary else binary, from_page, 1000000, callback)):
107
  d = copy.deepcopy(doc)
108
  pn += from_page
109
  d["image"] = img
110
+ d["page_num_int"] = [pn + 1]
111
  d["top_int"] = [0]
112
  d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
113
  tokenize(d, txt, eng)
114
  res.append(d)
115
  return res
116
  elif re.search(r"\.pdf$", filename, re.IGNORECASE):
117
+ pdf_parser = Pdf() if kwargs.get(
118
+ "parser_config", {}).get(
119
+ "layout_recognize", True) else PlainPdf()
120
+ for pn, (txt, img) in enumerate(pdf_parser(filename, binary,
121
+ from_page=from_page, to_page=to_page, callback=callback)):
122
  d = copy.deepcopy(doc)
123
  pn += from_page
124
+ if img:
125
+ d["image"] = img
126
+ d["page_num_int"] = [pn + 1]
127
  d["top_int"] = [0]
128
+ d["position_int"] = [
129
+ (pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)]
130
  tokenize(d, txt, eng)
131
  res.append(d)
132
  return res
133
 
134
+ raise NotImplementedError(
135
+ "file type not supported yet(pptx, pdf supported)")
136
 
137
 
138
+ if __name__ == "__main__":
139
  import sys
140
+
141
  def dummy(a, b):
142
  pass
143
  chunk(sys.argv[1], callback=dummy)
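
A short sketch of what the loops above produce: one chunk per slide or page, with the page number and thumbnail geometry stored on the chunk. The module path and file name are assumptions; pptx parsing needs the aspose.slides dependency used by Ppt.

from rag.app import presentation

def progress(prog=None, msg=""):
    pass

for d in presentation.chunk("deck.pptx", lang="English", callback=progress):
    # page_num_int, top_int and position_int are filled in per page above
    print(d["page_num_int"], d["position_int"])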
 
rag/app/resume.py CHANGED
@@ -27,6 +27,8 @@ from rag.utils import rmSpace
27
  forbidden_select_fields4resume = [
28
  "name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd"
29
  ]
 
 
30
  def remote_call(filename, binary):
31
  q = {
32
  "header": {
@@ -48,18 +50,22 @@ def remote_call(filename, binary):
48
  }
49
  for _ in range(3):
50
  try:
51
- resume = requests.post("http://127.0.0.1:61670/tog", data=json.dumps(q))
 
 
52
  resume = resume.json()["response"]["results"]
53
  resume = refactor(resume)
54
- for k in ["education", "work", "project", "training", "skill", "certificate", "language"]:
55
- if not resume.get(k) and k in resume: del resume[k]
 
 
56
 
57
  resume = step_one.refactor(pd.DataFrame([{"resume_content": json.dumps(resume), "tob_resume_id": "x",
58
- "updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]))
59
  resume = step_two.parse(resume)
60
  return resume
61
  except Exception as e:
62
- cron_logger.error("Resume parser error: "+str(e))
63
  return {}
64
 
65
 
@@ -144,10 +150,13 @@ def chunk(filename, binary=None, callback=None, **kwargs):
144
  doc["content_ltks"] = huqie.qie(doc["content_with_weight"])
145
  doc["content_sm_ltks"] = huqie.qieqie(doc["content_ltks"])
146
  for n, _ in field_map.items():
147
- if n not in resume:continue
148
- if isinstance(resume[n], list) and (len(resume[n]) == 1 or n not in forbidden_select_fields4resume):
 
 
149
  resume[n] = resume[n][0]
150
- if n.find("_tks")>0: resume[n] = huqie.qieqie(resume[n])
 
151
  doc[n] = resume[n]
152
 
153
  print(doc)
 
27
  forbidden_select_fields4resume = [
28
  "name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd"
29
  ]
30
+
31
+
32
  def remote_call(filename, binary):
33
  q = {
34
  "header": {
 
50
  }
51
  for _ in range(3):
52
  try:
53
+ resume = requests.post(
54
+ "http://127.0.0.1:61670/tog",
55
+ data=json.dumps(q))
56
  resume = resume.json()["response"]["results"]
57
  resume = refactor(resume)
58
+ for k in ["education", "work", "project",
59
+ "training", "skill", "certificate", "language"]:
60
+ if not resume.get(k) and k in resume:
61
+ del resume[k]
62
 
63
  resume = step_one.refactor(pd.DataFrame([{"resume_content": json.dumps(resume), "tob_resume_id": "x",
64
+ "updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]))
65
  resume = step_two.parse(resume)
66
  return resume
67
  except Exception as e:
68
+ cron_logger.error("Resume parser error: " + str(e))
69
  return {}
70
 
71
 
 
150
  doc["content_ltks"] = huqie.qie(doc["content_with_weight"])
151
  doc["content_sm_ltks"] = huqie.qieqie(doc["content_ltks"])
152
  for n, _ in field_map.items():
153
+ if n not in resume:
154
+ continue
155
+ if isinstance(resume[n], list) and (
156
+ len(resume[n]) == 1 or n not in forbidden_select_fields4resume):
157
  resume[n] = resume[n][0]
158
+ if n.find("_tks") > 0:
159
+ resume[n] = huqie.qieqie(resume[n])
160
  doc[n] = resume[n]
161
 
162
  print(doc)
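
A toy run (plain dict, not repo code) of the field-flattening rule above: single-value lists are unwrapped, and multi-value lists are unwrapped too (keeping only the first element) unless the field is listed in forbidden_select_fields4resume. The huqie tokenisation of *_tks fields is skipped here.

forbidden_select_fields4resume = [
    "name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd"
]
resume = {"degree_kwd": ["master", "bachelor"], "name_kwd": ["Alice"]}
for n in list(resume.keys()):
    if isinstance(resume[n], list) and (
            len(resume[n]) == 1 or n not in forbidden_select_fields4resume):
        resume[n] = resume[n][0]
print(resume)   # {'degree_kwd': ['master', 'bachelor'], 'name_kwd': 'Alice'}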
rag/app/table.py CHANGED
@@ -25,7 +25,8 @@ from deepdoc.parser import ExcelParser
25
 
26
 
27
  class Excel(ExcelParser):
28
- def __call__(self, fnm, binary=None, from_page=0, to_page=10000000000, callback=None):
 
29
  if not binary:
30
  wb = load_workbook(fnm)
31
  else:
@@ -48,8 +49,10 @@ class Excel(ExcelParser):
48
  data = []
49
  for i, r in enumerate(rows[1:]):
50
  rn += 1
51
- if rn-1 < from_page:continue
52
- if rn -1>=to_page: break
 
 
53
  row = [
54
  cell.value for ii,
55
  cell in enumerate(r) if ii not in missed]
@@ -60,7 +63,7 @@ class Excel(ExcelParser):
60
  done += 1
61
  res.append(pd.DataFrame(np.array(data), columns=headers))
62
 
63
- callback(0.3, ("Extract records: {}~{}".format(from_page+1, min(to_page, from_page+rn)) + (
64
  f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
65
  return res
66
 
@@ -73,7 +76,8 @@ def trans_datatime(s):
73
 
74
 
75
  def trans_bool(s):
76
- if re.match(r"(true|yes|是|\*|✓|✔|☑|✅|√)$", str(s).strip(), flags=re.IGNORECASE):
 
77
  return "yes"
78
  if re.match(r"(false|no|否|⍻|×)$", str(s).strip(), flags=re.IGNORECASE):
79
  return "no"
@@ -107,13 +111,14 @@ def column_data_type(arr):
107
  arr[i] = trans[ty](str(arr[i]))
108
  except Exception as e:
109
  arr[i] = None
110
- #if ty == "text":
111
  # if len(arr) > 128 and uni / len(arr) < 0.1:
112
  # ty = "keyword"
113
  return arr, ty
114
 
115
 
116
- def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese", callback=None, **kwargs):
 
117
  """
118
  Excel and csv(txt) format files are supported.
119
  For csv or txt file, the delimiter between columns is TAB.
@@ -131,7 +136,12 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
131
  if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
132
  callback(0.1, "Start to parse.")
133
  excel_parser = Excel()
134
- dfs = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback)
 
 
 
 
 
135
  elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
136
  callback(0.1, "Start to parse.")
137
  txt = ""
@@ -149,8 +159,10 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
149
  headers = lines[0].split(kwargs.get("delimiter", "\t"))
150
  rows = []
151
  for i, line in enumerate(lines[1:]):
152
- if i < from_page:continue
153
- if i >= to_page: break
 
 
154
  row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
155
  if len(row) != len(headers):
156
  fails.append(str(i))
@@ -181,7 +193,13 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
181
  del df[n]
182
  clmns = df.columns.values
183
  txts = list(copy.deepcopy(clmns))
184
- py_clmns = [PY.get_pinyins(re.sub(r"(/.*|([^()]+?)|\([^()]+?\))", "", n), '_')[0] for n in clmns]
 
 
 
 
 
 
185
  clmn_tys = []
186
  for j in range(len(clmns)):
187
  cln, ty = column_data_type(df[clmns[j]])
@@ -192,7 +210,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
192
  clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], clmns[i].replace("_", " "))
193
  for i in range(len(clmns))]
194
 
195
- eng = lang.lower() == "english"#is_english(txts)
196
  for ii, row in df.iterrows():
197
  d = {
198
  "docnm_kwd": filename,
 
25
 
26
 
27
  class Excel(ExcelParser):
28
+ def __call__(self, fnm, binary=None, from_page=0,
29
+ to_page=10000000000, callback=None):
30
  if not binary:
31
  wb = load_workbook(fnm)
32
  else:
 
49
  data = []
50
  for i, r in enumerate(rows[1:]):
51
  rn += 1
52
+ if rn - 1 < from_page:
53
+ continue
54
+ if rn - 1 >= to_page:
55
+ break
56
  row = [
57
  cell.value for ii,
58
  cell in enumerate(r) if ii not in missed]
 
63
  done += 1
64
  res.append(pd.DataFrame(np.array(data), columns=headers))
65
 
66
+ callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + (
67
  f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
68
  return res
69
 
 
76
 
77
 
78
  def trans_bool(s):
79
+ if re.match(r"(true|yes|是|\*|✓|✔|☑|✅|√)$",
80
+ str(s).strip(), flags=re.IGNORECASE):
81
  return "yes"
82
  if re.match(r"(false|no|否|⍻|×)$", str(s).strip(), flags=re.IGNORECASE):
83
  return "no"
 
111
  arr[i] = trans[ty](str(arr[i]))
112
  except Exception as e:
113
  arr[i] = None
114
+ # if ty == "text":
115
  # if len(arr) > 128 and uni / len(arr) < 0.1:
116
  # ty = "keyword"
117
  return arr, ty
118
 
119
 
120
+ def chunk(filename, binary=None, from_page=0, to_page=10000000000,
121
+ lang="Chinese", callback=None, **kwargs):
122
  """
123
  Excel and csv(txt) format files are supported.
124
  For csv or txt file, the delimiter between columns is TAB.
 
136
  if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
137
  callback(0.1, "Start to parse.")
138
  excel_parser = Excel()
139
+ dfs = excel_parser(
140
+ filename,
141
+ binary,
142
+ from_page=from_page,
143
+ to_page=to_page,
144
+ callback=callback)
145
  elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
146
  callback(0.1, "Start to parse.")
147
  txt = ""
 
159
  headers = lines[0].split(kwargs.get("delimiter", "\t"))
160
  rows = []
161
  for i, line in enumerate(lines[1:]):
162
+ if i < from_page:
163
+ continue
164
+ if i >= to_page:
165
+ break
166
  row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
167
  if len(row) != len(headers):
168
  fails.append(str(i))
 
193
  del df[n]
194
  clmns = df.columns.values
195
  txts = list(copy.deepcopy(clmns))
196
+ py_clmns = [
197
+ PY.get_pinyins(
198
+ re.sub(
199
+ r"(/.*|([^()]+?)|\([^()]+?\))",
200
+ "",
201
+ n),
202
+ '_')[0] for n in clmns]
203
  clmn_tys = []
204
  for j in range(len(clmns)):
205
  cln, ty = column_data_type(df[clmns[j]])
 
210
  clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], clmns[i].replace("_", " "))
211
  for i in range(len(clmns))]
212
 
213
+ eng = lang.lower() == "english" # is_english(txts)
214
  for ii, row in df.iterrows():
215
  d = {
216
  "docnm_kwd": filename,
rag/llm/chat_model.py CHANGED
@@ -13,6 +13,8 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
 
16
  from abc import ABC
17
  from openai import OpenAI
18
  import openai
@@ -34,7 +36,8 @@ class GptTurbo(Base):
34
  self.model_name = model_name
35
 
36
  def chat(self, system, history, gen_conf):
37
- if system: history.insert(0, {"role": "system", "content": system})
 
38
  try:
39
  response = self.client.chat.completions.create(
40
  model=self.model_name,
@@ -46,16 +49,18 @@ class GptTurbo(Base):
46
  [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
47
  return ans, response.usage.completion_tokens
48
  except openai.APIError as e:
49
- return "**ERROR**: "+str(e), 0
50
 
51
 
52
  class MoonshotChat(GptTurbo):
53
  def __init__(self, key, model_name="moonshot-v1-8k"):
54
- self.client = OpenAI(api_key=key, base_url="https://api.moonshot.cn/v1",)
 
55
  self.model_name = model_name
56
 
57
  def chat(self, system, history, gen_conf):
58
- if system: history.insert(0, {"role": "system", "content": system})
 
59
  try:
60
  response = self.client.chat.completions.create(
61
  model=self.model_name,
@@ -67,10 +72,9 @@ class MoonshotChat(GptTurbo):
67
  [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
68
  return ans, response.usage.completion_tokens
69
  except openai.APIError as e:
70
- return "**ERROR**: "+str(e), 0
71
 
72
 
73
- from dashscope import Generation
74
  class QWenChat(Base):
75
  def __init__(self, key, model_name=Generation.Models.qwen_turbo):
76
  import dashscope
@@ -79,7 +83,8 @@ class QWenChat(Base):
79
 
80
  def chat(self, system, history, gen_conf):
81
  from http import HTTPStatus
82
- if system: history.insert(0, {"role": "system", "content": system})
 
83
  response = Generation.call(
84
  self.model_name,
85
  messages=history,
@@ -92,20 +97,21 @@ class QWenChat(Base):
92
  ans += response.output.choices[0]['message']['content']
93
  tk_count += response.usage.output_tokens
94
  if response.output.choices[0].get("finish_reason", "") == "length":
95
- ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
 
96
  return ans, tk_count
97
 
98
  return "**ERROR**: " + response.message, tk_count
99
 
100
 
101
- from zhipuai import ZhipuAI
102
  class ZhipuChat(Base):
103
  def __init__(self, key, model_name="glm-3-turbo"):
104
  self.client = ZhipuAI(api_key=key)
105
  self.model_name = model_name
106
 
107
  def chat(self, system, history, gen_conf):
108
- if system: history.insert(0, {"role": "system", "content": system})
 
109
  try:
110
  response = self.client.chat.completions.create(
111
  self.model_name,
@@ -120,6 +126,7 @@ class ZhipuChat(Base):
120
  except Exception as e:
121
  return "**ERROR**: " + str(e), 0
122
 
 
123
  class LocalLLM(Base):
124
  class RPCProxy:
125
  def __init__(self, host, port):
@@ -129,14 +136,17 @@ class LocalLLM(Base):
129
 
130
  def __conn(self):
131
  from multiprocessing.connection import Client
132
- self._connection = Client((self.host, self.port), authkey=b'infiniflow-token4kevinhu')
 
133
 
134
  def __getattr__(self, name):
135
  import pickle
 
136
  def do_rpc(*args, **kwargs):
137
  for _ in range(3):
138
  try:
139
- self._connection.send(pickle.dumps((name, args, kwargs)))
 
140
  return pickle.loads(self._connection.recv())
141
  except Exception as e:
142
  self.__conn()
@@ -148,7 +158,8 @@ class LocalLLM(Base):
148
  self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
149
 
150
  def chat(self, system, history, gen_conf):
151
- if system: history.insert(0, {"role": "system", "content": system})
 
152
  try:
153
  ans = self.client.chat(
154
  history,
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
+ from zhipuai import ZhipuAI
17
+ from dashscope import Generation
18
  from abc import ABC
19
  from openai import OpenAI
20
  import openai
 
36
  self.model_name = model_name
37
 
38
  def chat(self, system, history, gen_conf):
39
+ if system:
40
+ history.insert(0, {"role": "system", "content": system})
41
  try:
42
  response = self.client.chat.completions.create(
43
  model=self.model_name,
 
49
  [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
50
  return ans, response.usage.completion_tokens
51
  except openai.APIError as e:
52
+ return "**ERROR**: " + str(e), 0
53
 
54
 
55
  class MoonshotChat(GptTurbo):
56
  def __init__(self, key, model_name="moonshot-v1-8k"):
57
+ self.client = OpenAI(
58
+ api_key=key, base_url="https://api.moonshot.cn/v1",)
59
  self.model_name = model_name
60
 
61
  def chat(self, system, history, gen_conf):
62
+ if system:
63
+ history.insert(0, {"role": "system", "content": system})
64
  try:
65
  response = self.client.chat.completions.create(
66
  model=self.model_name,
 
72
  [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
73
  return ans, response.usage.completion_tokens
74
  except openai.APIError as e:
75
+ return "**ERROR**: " + str(e), 0
76
 
77
 
 
78
  class QWenChat(Base):
79
  def __init__(self, key, model_name=Generation.Models.qwen_turbo):
80
  import dashscope
 
83
 
84
  def chat(self, system, history, gen_conf):
85
  from http import HTTPStatus
86
+ if system:
87
+ history.insert(0, {"role": "system", "content": system})
88
  response = Generation.call(
89
  self.model_name,
90
  messages=history,
 
97
  ans += response.output.choices[0]['message']['content']
98
  tk_count += response.usage.output_tokens
99
  if response.output.choices[0].get("finish_reason", "") == "length":
100
+ ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
101
+ [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
102
  return ans, tk_count
103
 
104
  return "**ERROR**: " + response.message, tk_count
105
 
106
 
 
107
  class ZhipuChat(Base):
108
  def __init__(self, key, model_name="glm-3-turbo"):
109
  self.client = ZhipuAI(api_key=key)
110
  self.model_name = model_name
111
 
112
  def chat(self, system, history, gen_conf):
113
+ if system:
114
+ history.insert(0, {"role": "system", "content": system})
115
  try:
116
  response = self.client.chat.completions.create(
117
  self.model_name,
 
126
  except Exception as e:
127
  return "**ERROR**: " + str(e), 0
128
 
129
+
130
  class LocalLLM(Base):
131
  class RPCProxy:
132
  def __init__(self, host, port):
 
136
 
137
  def __conn(self):
138
  from multiprocessing.connection import Client
139
+ self._connection = Client(
140
+ (self.host, self.port), authkey=b'infiniflow-token4kevinhu')
141
 
142
  def __getattr__(self, name):
143
  import pickle
144
+
145
  def do_rpc(*args, **kwargs):
146
  for _ in range(3):
147
  try:
148
+ self._connection.send(
149
+ pickle.dumps((name, args, kwargs)))
150
  return pickle.loads(self._connection.recv())
151
  except Exception as e:
152
  self.__conn()
 
158
  self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
159
 
160
  def chat(self, system, history, gen_conf):
161
+ if system:
162
+ history.insert(0, {"role": "system", "content": system})
163
  try:
164
  ans = self.client.chat(
165
  history,
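
All the wrappers in this file share the same chat(system, history, gen_conf) -> (answer, token_count) contract, so a call looks roughly like the sketch below. The constructor arguments are assumed to mirror MoonshotChat's above, and the key and model name are placeholders. Note that every implementation inserts the system prompt into the caller's history list in place.

from rag.llm.chat_model import GptTurbo

mdl = GptTurbo(key="sk-...", model_name="gpt-3.5-turbo")   # placeholder credentials/model
ans, used_tokens = mdl.chat(
    system="You are a concise assistant.",
    history=[{"role": "user", "content": "Summarise RAG in one sentence."}],
    gen_conf={"temperature": 0.1, "max_tokens": 256},
)
print(used_tokens, ans)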
rag/llm/cv_model.py CHANGED
@@ -13,6 +13,7 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
16
  import io
17
  from abc import ABC
18
 
@@ -57,8 +58,8 @@ class Base(ABC):
57
  },
58
  },
59
  {
60
- "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
61
- "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
62
  },
63
  ],
64
  }
@@ -92,8 +93,9 @@ class QWenCV(Base):
92
  def prompt(self, binary):
93
  # stupid as hell
94
  tmp_dir = get_project_base_directory("tmp")
95
- if not os.path.exists(tmp_dir): os.mkdir(tmp_dir)
96
- path = os.path.join(tmp_dir, "%s.jpg"%get_uuid())
 
97
  Image.open(io.BytesIO(binary)).save(path)
98
  return [
99
  {
@@ -103,8 +105,8 @@ class QWenCV(Base):
103
  "image": f"file://{path}"
104
  },
105
  {
106
- "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
107
- "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
108
  },
109
  ],
110
  }
@@ -120,9 +122,6 @@ class QWenCV(Base):
120
  return response.message, 0
121
 
122
 
123
- from zhipuai import ZhipuAI
124
-
125
-
126
  class Zhipu4V(Base):
127
  def __init__(self, key, model_name="glm-4v", lang="Chinese"):
128
  self.client = ZhipuAI(api_key=key)
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
+ from zhipuai import ZhipuAI
17
  import io
18
  from abc import ABC
19
 
 
58
  },
59
  },
60
  {
61
+ "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
62
+ "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
63
  },
64
  ],
65
  }
 
93
  def prompt(self, binary):
94
  # stupid as hell
95
  tmp_dir = get_project_base_directory("tmp")
96
+ if not os.path.exists(tmp_dir):
97
+ os.mkdir(tmp_dir)
98
+ path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
99
  Image.open(io.BytesIO(binary)).save(path)
100
  return [
101
  {
 
105
  "image": f"file://{path}"
106
  },
107
  {
108
+ "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
109
+ "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
110
  },
111
  ],
112
  }
 
122
  return response.message, 0
123
 
124
 
 
 
 
125
  class Zhipu4V(Base):
126
  def __init__(self, key, model_name="glm-4v", lang="Chinese"):
127
  self.client = ZhipuAI(api_key=key)
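
QWenCV.prompt above works around DashScope's need for a local path or URL by dumping the image bytes into a tmp directory and passing a file:// URL. A self-contained rewrite of just that step, with get_uuid() and the project tmp dir replaced by stand-ins:

import io
import os
from uuid import uuid4

from PIL import Image

def to_file_url(binary: bytes, tmp_dir: str = "tmp") -> str:
    # assumes the bytes decode to an image PIL can re-encode as JPEG,
    # which is what the original helper assumes as well
    if not os.path.exists(tmp_dir):
        os.mkdir(tmp_dir)
    path = os.path.join(tmp_dir, "%s.jpg" % uuid4().hex)
    Image.open(io.BytesIO(binary)).save(path)
    return f"file://{path}"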
rag/llm/embedding_model.py CHANGED
@@ -13,6 +13,7 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
16
  import os
17
  from abc import ABC
18
 
@@ -40,11 +41,11 @@ flag_model = FlagModel(model_dir,
40
  query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
41
  use_fp16=torch.cuda.is_available())
42
 
 
43
  class Base(ABC):
44
  def __init__(self, key, model_name):
45
  pass
46
 
47
-
48
  def encode(self, texts: list, batch_size=32):
49
  raise NotImplementedError("Please implement encode method!")
50
 
@@ -67,11 +68,11 @@ class HuEmbedding(Base):
67
  """
68
  self.model = flag_model
69
 
70
-
71
  def encode(self, texts: list, batch_size=32):
72
  texts = [t[:2000] for t in texts]
73
  token_count = 0
74
- for t in texts: token_count += num_tokens_from_string(t)
 
75
  res = []
76
  for i in range(0, len(texts), batch_size):
77
  res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
@@ -90,7 +91,8 @@ class OpenAIEmbed(Base):
90
  def encode(self, texts: list, batch_size=32):
91
  res = self.client.embeddings.create(input=texts,
92
  model=self.model_name)
93
- return np.array([d.embedding for d in res.data]), res.usage.total_tokens
 
94
 
95
  def encode_queries(self, text):
96
  res = self.client.embeddings.create(input=[text],
@@ -111,7 +113,7 @@ class QWenEmbed(Base):
111
  for i in range(0, len(texts), batch_size):
112
  resp = dashscope.TextEmbedding.call(
113
  model=self.model_name,
114
- input=texts[i:i+batch_size],
115
  text_type="document"
116
  )
117
  embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
@@ -123,14 +125,14 @@ class QWenEmbed(Base):
123
 
124
  def encode_queries(self, text):
125
  resp = dashscope.TextEmbedding.call(
126
- model=self.model_name,
127
- input=text[:2048],
128
- text_type="query"
129
- )
130
- return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["total_tokens"]
 
131
 
132
 
133
- from zhipuai import ZhipuAI
134
  class ZhipuEmbed(Base):
135
  def __init__(self, key, model_name="embedding-2"):
136
  self.client = ZhipuAI(api_key=key)
@@ -139,9 +141,10 @@ class ZhipuEmbed(Base):
139
  def encode(self, texts: list, batch_size=32):
140
  res = self.client.embeddings.create(input=texts,
141
  model=self.model_name)
142
- return np.array([d.embedding for d in res.data]), res.usage.total_tokens
 
143
 
144
  def encode_queries(self, text):
145
  res = self.client.embeddings.create(input=text,
146
  model=self.model_name)
147
- return np.array(res["data"][0]["embedding"]), res.usage.total_tokens
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
+ from zhipuai import ZhipuAI
17
  import os
18
  from abc import ABC
19
 
 
41
  query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
42
  use_fp16=torch.cuda.is_available())
43
 
44
+
45
  class Base(ABC):
46
  def __init__(self, key, model_name):
47
  pass
48
 
 
49
  def encode(self, texts: list, batch_size=32):
50
  raise NotImplementedError("Please implement encode method!")
51
 
 
68
  """
69
  self.model = flag_model
70
 
 
71
  def encode(self, texts: list, batch_size=32):
72
  texts = [t[:2000] for t in texts]
73
  token_count = 0
74
+ for t in texts:
75
+ token_count += num_tokens_from_string(t)
76
  res = []
77
  for i in range(0, len(texts), batch_size):
78
  res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
 
91
  def encode(self, texts: list, batch_size=32):
92
  res = self.client.embeddings.create(input=texts,
93
  model=self.model_name)
94
+ return np.array([d.embedding for d in res.data]
95
+ ), res.usage.total_tokens
96
 
97
  def encode_queries(self, text):
98
  res = self.client.embeddings.create(input=[text],
 
113
  for i in range(0, len(texts), batch_size):
114
  resp = dashscope.TextEmbedding.call(
115
  model=self.model_name,
116
+ input=texts[i:i + batch_size],
117
  text_type="document"
118
  )
119
  embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
 
125
 
126
  def encode_queries(self, text):
127
  resp = dashscope.TextEmbedding.call(
128
+ model=self.model_name,
129
+ input=text[:2048],
130
+ text_type="query"
131
+ )
132
+ return np.array(resp["output"]["embeddings"][0]
133
+ ["embedding"]), resp["usage"]["total_tokens"]
134
 
135
 
 
136
  class ZhipuEmbed(Base):
137
  def __init__(self, key, model_name="embedding-2"):
138
  self.client = ZhipuAI(api_key=key)
 
141
  def encode(self, texts: list, batch_size=32):
142
  res = self.client.embeddings.create(input=texts,
143
  model=self.model_name)
144
+ return np.array([d.embedding for d in res.data]
145
+ ), res.usage.total_tokens
146
 
147
  def encode_queries(self, text):
148
  res = self.client.embeddings.create(input=text,
149
  model=self.model_name)
150
+ return np.array(res["data"][0]["embedding"]), res.usage.total_tokens
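
The encoders above all return (vectors, token_count) from encode() and a single vector from encode_queries(). A sketch of the local model path, assuming HuEmbedding keeps the Base(key, model_name) constructor and that the arguments are ignored for the bundled FlagEmbedding model:

import numpy as np

from rag.llm.embedding_model import HuEmbedding

mdl = HuEmbedding("", "")                  # assumed signature, arguments unused locally
vecs, tokens = mdl.encode(["hello world", "你好,世界"])
print(np.array(vecs).shape, tokens)        # e.g. (2, 1024) for a bge-large model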
rag/llm/rpc_server.py CHANGED
@@ -9,7 +9,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
9
 
10
  class RPCHandler:
11
  def __init__(self):
12
- self._functions = { }
13
 
14
  def register_function(self, func):
15
  self._functions[func.__name__] = func
@@ -21,12 +21,12 @@ class RPCHandler:
21
  func_name, args, kwargs = pickle.loads(connection.recv())
22
  # Run the RPC and send a response
23
  try:
24
- r = self._functions[func_name](*args,**kwargs)
25
  connection.send(pickle.dumps(r))
26
  except Exception as e:
27
  connection.send(pickle.dumps(e))
28
  except EOFError:
29
- pass
30
 
31
 
32
  def rpc_server(hdlr, address, authkey):
@@ -44,11 +44,17 @@ def rpc_server(hdlr, address, authkey):
44
  models = []
45
  tokenizer = None
46
 
 
47
  def chat(messages, gen_conf):
48
  global tokenizer
49
  model = Model()
50
  try:
51
- conf = {"max_new_tokens": int(gen_conf.get("max_tokens", 256)), "temperature": float(gen_conf.get("temperature", 0.1))}
 
 
 
 
 
52
  print(messages, conf)
53
  text = tokenizer.apply_chat_template(
54
  messages,
@@ -65,7 +71,8 @@ def chat(messages, gen_conf):
65
  output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
66
  ]
67
 
68
- return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
69
  except Exception as e:
70
  return str(e)
71
 
@@ -75,10 +82,15 @@ def Model():
75
  random.seed(time.time())
76
  return random.choice(models)
77
 
 
78
  if __name__ == "__main__":
79
  parser = argparse.ArgumentParser()
80
  parser.add_argument("--model_name", type=str, help="Model name")
81
- parser.add_argument("--port", default=7860, type=int, help="RPC serving port")
 
 
 
 
82
  args = parser.parse_args()
83
 
84
  handler = RPCHandler()
@@ -93,4 +105,5 @@ if __name__ == "__main__":
93
  tokenizer = AutoTokenizer.from_pretrained(args.model_name)
94
 
95
  # Run the server
96
- rpc_server(handler, ('0.0.0.0', args.port), authkey=b'infiniflow-token4kevinhu')
 
 
9
 
10
  class RPCHandler:
11
  def __init__(self):
12
+ self._functions = {}
13
 
14
  def register_function(self, func):
15
  self._functions[func.__name__] = func
 
21
  func_name, args, kwargs = pickle.loads(connection.recv())
22
  # Run the RPC and send a response
23
  try:
24
+ r = self._functions[func_name](*args, **kwargs)
25
  connection.send(pickle.dumps(r))
26
  except Exception as e:
27
  connection.send(pickle.dumps(e))
28
  except EOFError:
29
+ pass
30
 
31
 
32
  def rpc_server(hdlr, address, authkey):
 
44
  models = []
45
  tokenizer = None
46
 
47
+
48
  def chat(messages, gen_conf):
49
  global tokenizer
50
  model = Model()
51
  try:
52
+ conf = {
53
+ "max_new_tokens": int(
54
+ gen_conf.get(
55
+ "max_tokens", 256)), "temperature": float(
56
+ gen_conf.get(
57
+ "temperature", 0.1))}
58
  print(messages, conf)
59
  text = tokenizer.apply_chat_template(
60
  messages,
 
71
  output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
72
  ]
73
 
74
+ return tokenizer.batch_decode(
75
+ generated_ids, skip_special_tokens=True)[0]
76
  except Exception as e:
77
  return str(e)
78
 
 
82
  random.seed(time.time())
83
  return random.choice(models)
84
 
85
+
86
  if __name__ == "__main__":
87
  parser = argparse.ArgumentParser()
88
  parser.add_argument("--model_name", type=str, help="Model name")
89
+ parser.add_argument(
90
+ "--port",
91
+ default=7860,
92
+ type=int,
93
+ help="RPC serving port")
94
  args = parser.parse_args()
95
 
96
  handler = RPCHandler()
 
105
  tokenizer = AutoTokenizer.from_pretrained(args.model_name)
106
 
107
  # Run the server
108
+ rpc_server(handler, ('0.0.0.0', args.port),
109
+ authkey=b'infiniflow-token4kevinhu')
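
The matching client side of this pickle-over-Connection protocol already exists as LocalLLM.RPCProxy in chat_model.py; stripped to its essentials it looks like the sketch below. It assumes chat() has been registered on the handler and that the server is listening on the defaults shown above.

import pickle
from multiprocessing.connection import Client

conn = Client(("127.0.0.1", 7860), authkey=b"infiniflow-token4kevinhu")
# the server unpickles (func_name, args, kwargs) and pickles back the return value
conn.send(pickle.dumps(("chat",
                        ([{"role": "user", "content": "hi"}], {"max_tokens": 64}),
                        {})))
print(pickle.loads(conn.recv()))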
rag/nlp/huchunk.py CHANGED
@@ -372,7 +372,8 @@ class PptChunker(HuChunker):
372
  tb = shape.table
373
  rows = []
374
  for i in range(1, len(tb.rows)):
375
- rows.append("; ".join([tb.cell(0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)]))
 
376
  return "\n".join(rows)
377
 
378
  if shape.has_text_frame:
@@ -382,7 +383,8 @@ class PptChunker(HuChunker):
382
  texts = []
383
  for p in shape.shapes:
384
  t = self.__extract(p)
385
- if t: texts.append(t)
 
386
  return "\n".join(texts)
387
 
388
  def __call__(self, fnm):
@@ -395,7 +397,8 @@ class PptChunker(HuChunker):
395
  texts = []
396
  for shape in slide.shapes:
397
  txt = self.__extract(shape)
398
- if txt: texts.append(txt)
 
399
  txts.append("\n".join(texts))
400
 
401
  import aspose.slides as slides
@@ -404,9 +407,12 @@ class PptChunker(HuChunker):
404
  with slides.Presentation(BytesIO(fnm)) as presentation:
405
  for slide in presentation.slides:
406
  buffered = BytesIO()
407
- slide.get_thumbnail(0.5, 0.5).save(buffered, drawing.imaging.ImageFormat.jpeg)
 
 
408
  imgs.append(buffered.getvalue())
409
- assert len(imgs) == len(txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
 
410
 
411
  flds = self.Fields()
412
  flds.text_chunks = [(txts[i], imgs[i]) for i in range(len(txts))]
@@ -445,7 +451,8 @@ class TextChunker(HuChunker):
445
  if isinstance(fnm, str):
446
  with open(fnm, "r") as f:
447
  txt = f.read()
448
- else: txt = fnm.decode("utf-8")
 
449
  flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
450
  flds.table_chunks = []
451
  return flds
 
372
  tb = shape.table
373
  rows = []
374
  for i in range(1, len(tb.rows)):
375
+ rows.append("; ".join([tb.cell(
376
+ 0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)]))
377
  return "\n".join(rows)
378
 
379
  if shape.has_text_frame:
 
383
  texts = []
384
  for p in shape.shapes:
385
  t = self.__extract(p)
386
+ if t:
387
+ texts.append(t)
388
  return "\n".join(texts)
389
 
390
  def __call__(self, fnm):
 
397
  texts = []
398
  for shape in slide.shapes:
399
  txt = self.__extract(shape)
400
+ if txt:
401
+ texts.append(txt)
402
  txts.append("\n".join(texts))
403
 
404
  import aspose.slides as slides
 
407
  with slides.Presentation(BytesIO(fnm)) as presentation:
408
  for slide in presentation.slides:
409
  buffered = BytesIO()
410
+ slide.get_thumbnail(
411
+ 0.5, 0.5).save(
412
+ buffered, drawing.imaging.ImageFormat.jpeg)
413
  imgs.append(buffered.getvalue())
414
+ assert len(imgs) == len(
415
+ txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
416
 
417
  flds = self.Fields()
418
  flds.text_chunks = [(txts[i], imgs[i]) for i in range(len(txts))]
 
451
  if isinstance(fnm, str):
452
  with open(fnm, "r") as f:
453
  txt = f.read()
454
+ else:
455
+ txt = fnm.decode("utf-8")
456
  flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
457
  flds.table_chunks = []
458
  return flds
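
A toy run (plain lists instead of a python-pptx table) of the row flattening used in PptChunker above: every data row is rendered as "header: value; ..." pairs.

table = [["Name", "Role"], ["Alice", "Engineer"], ["Bob", "Designer"]]
rows = []
for i in range(1, len(table)):
    rows.append("; ".join(
        table[0][j] + ": " + table[i][j] for j in range(len(table[0]))))
print("\n".join(rows))
# Name: Alice; Role: Engineer
# Name: Bob; Role: Designer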
rag/nlp/query.py CHANGED
@@ -149,7 +149,8 @@ class EsQueryer:
149
  atks = toDict(atks)
150
  btkss = [toDict(tks) for tks in btkss]
151
  tksim = [self.similarity(atks, btks) for btks in btkss]
152
- return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
 
153
 
154
  def similarity(self, qtwt, dtwt):
155
  if isinstance(dtwt, type("")):
@@ -159,11 +160,11 @@ class EsQueryer:
159
  s = 1e-9
160
  for k, v in qtwt.items():
161
  if k in dtwt:
162
- s += v# * dtwt[k]
163
  q = 1e-9
164
  for k, v in qtwt.items():
165
- q += v #* v
166
  #d = 1e-9
167
- #for k, v in dtwt.items():
168
  # d += v * v
169
- return s / q #math.sqrt(q) / math.sqrt(d)
 
149
  atks = toDict(atks)
150
  btkss = [toDict(tks) for tks in btkss]
151
  tksim = [self.similarity(atks, btks) for btks in btkss]
152
+ return np.array(sims[0]) * vtweight + \
153
+ np.array(tksim) * tkweight, tksim, sims[0]
154
 
155
  def similarity(self, qtwt, dtwt):
156
  if isinstance(dtwt, type("")):
 
160
  s = 1e-9
161
  for k, v in qtwt.items():
162
  if k in dtwt:
163
+ s += v # * dtwt[k]
164
  q = 1e-9
165
  for k, v in qtwt.items():
166
+ q += v # * v
167
  #d = 1e-9
168
+ # for k, v in dtwt.items():
169
  # d += v * v
170
+ return s / q # math.sqrt(q) / math.sqrt(d)
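
The blended score returned above is just a weighted sum of the vector similarity and the token-overlap similarity per candidate. A toy run with made-up scores and illustrative weights:

import numpy as np

vector_sim = np.array([0.82, 0.40, 0.65])   # hypothetical cosine scores
token_sim = np.array([0.50, 0.90, 0.10])    # hypothetical keyword-overlap scores
tkweight, vtweight = 0.3, 0.7
print(vector_sim * vtweight + token_sim * tkweight)   # [0.724 0.55  0.485]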