KevinHuSh committed
Commit: 79ada0b
Parent(s): 85b269d

apply pep8 formalize (#155)

This view is limited to 50 files because it contains too many changes. See raw diff.
- api/apps/chunk_app.py +14 -5
- api/apps/conversation_app.py +81 -43
- api/apps/dialog_app.py +30 -13
- api/apps/document_app.py +21 -8
- api/apps/kb_app.py +39 -17
- api/apps/llm_app.py +26 -12
- api/apps/user_app.py +57 -29
- api/db/db_models.py +243 -64
- api/db/db_utils.py +10 -5
- api/db/init_data.py +55 -45
- api/db/operatioins.py +1 -1
- api/db/reload_config_base.py +3 -2
- api/db/runtime_config.py +1 -1
- api/db/services/common_service.py +26 -12
- api/db/services/dialog_service.py +0 -1
- api/db/services/document_service.py +49 -12
- api/db/services/knowledgebase_service.py +8 -7
- api/db/services/llm_service.py +22 -11
- api/db/services/user_service.py +32 -13
- api/settings.py +38 -16
- api/utils/__init__.py +39 -14
- api/utils/api_utils.py +43 -17
- api/utils/file_utils.py +13 -10
- api/utils/log_utils.py +39 -21
- api/utils/t_crypt.py +10 -5
- deepdoc/parser/__init__.py +0 -2
- deepdoc/parser/docx_parser.py +6 -3
- deepdoc/parser/excel_parser.py +12 -7
- deepdoc/parser/pdf_parser.py +105 -53
- deepdoc/parser/ppt_parser.py +13 -7
- deepdoc/vision/layout_recognizer.py +29 -24
- deepdoc/vision/operators.py +2 -1
- deepdoc/vision/t_ocr.py +21 -13
- deepdoc/vision/t_recognizer.py +33 -18
- deepdoc/vision/table_structure_recognizer.py +13 -10
- rag/app/book.py +36 -18
- rag/app/laws.py +29 -15
- rag/app/manual.py +30 -19
- rag/app/naive.py +21 -12
- rag/app/one.py +16 -10
- rag/app/paper.py +26 -14
- rag/app/presentation.py +42 -22
- rag/app/resume.py +17 -8
- rag/app/table.py +30 -12
- rag/llm/chat_model.py +24 -13
- rag/llm/cv_model.py +8 -9
- rag/llm/embedding_model.py +16 -13
- rag/llm/rpc_server.py +20 -7
- rag/nlp/huchunk.py +13 -6
- rag/nlp/query.py +6 -5
api/apps/chunk_app.py
CHANGED
@@ -121,7 +121,9 @@ def get():
                   "important_kwd")
 def set():
     req = request.json
+    d = {
+        "id": req["chunk_id"],
+        "content_with_weight": req["content_with_weight"]}
     d["content_ltks"] = huqie.qie(req["content_with_weight"])
     d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
     d["important_kwd"] = req["important_kwd"]
@@ -140,10 +142,16 @@ def set():
             return get_data_error_result(retmsg="Document not found!")

         if doc.parser_id == ParserType.QA:
+            arr = [
+                t for t in re.split(
+                    r"[\n\t]",
+                    req["content_with_weight"]) if len(t) > 1]
+            if len(arr) != 2:
+                return get_data_error_result(
+                    retmsg="Q&A must be separated by TAB/ENTER key.")
             q, a = rmPrefix(arr[0]), rmPrefix[arr[1]]
+            d = beAdoc(d, arr[0], arr[1], not any(
+                [huqie.is_chinese(t) for t in q + a]))

         v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
         v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
@@ -177,7 +185,8 @@ def switch():
 def rm():
     req = request.json
     try:
+        if not ELASTICSEARCH.deleteByQuery(
+                Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
             return get_data_error_result(retmsg="Index updating failure")
         return get_json_result(data=True)
     except Exception as e:
api/apps/conversation_app.py
CHANGED
@@ -100,7 +100,10 @@ def rm():
 def list_convsersation():
     dialog_id = request.args["dialog_id"]
     try:
+        convs = ConversationService.query(
+            dialog_id=dialog_id,
+            order_by=ConversationService.model.create_time,
+            reverse=True)
         convs = [d.to_dict() for d in convs]
         return get_json_result(data=convs)
     except Exception as e:
@@ -111,19 +114,24 @@ def message_fit_in(msg, max_length=4000):
     def count():
         nonlocal msg
         tks_cnts = []
+        for m in msg:
+            tks_cnts.append(
+                {"role": m["role"], "count": num_tokens_from_string(m["content"])})
         total = 0
+        for m in tks_cnts:
+            total += m["count"]
         return total

     c = count()
+    if c < max_length:
+        return c, msg

     msg_ = [m for m in msg[:-1] if m.role == "system"]
     msg_.append(msg[-1])
     msg = msg_
     c = count()
+    if c < max_length:
+        return c, msg

     ll = num_tokens_from_string(msg_[0].content)
     l = num_tokens_from_string(msg_[-1].content)
@@ -146,8 +154,10 @@ def completion():
     req = request.json
     msg = []
     for m in req["messages"]:
+        if m["role"] == "system":
+            continue
+        if m["role"] == "assistant" and not msg:
+            continue
         msg.append({"role": m["role"], "content": m["content"]})
     try:
         e, conv = ConversationService.get_by_id(req["conversation_id"])
@@ -160,7 +170,8 @@ def completion():
         del req["conversation_id"]
         del req["messages"]
         ans = chat(dia, msg, **req)
+        if not conv.reference:
+            conv.reference = []
         conv.reference.append(ans["reference"])
         conv.message.append({"role": "assistant", "content": ans["answer"]})
         ConversationService.update_by_id(conv.id, conv.to_dict())
@@ -180,52 +191,67 @@ def chat(dialog, messages, **kwargs):
     chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

     field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
+    # try to use sql if field mapping is good to go
     if field_map:
         chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
         return use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)

     prompt_config = dialog.prompt_config
     for p in prompt_config["parameters"]:
+        if p["key"] == "knowledge":
+            continue
+        if p["key"] not in kwargs and not p["optional"]:
+            raise KeyError("Miss parameter: " + p["key"])
         if p["key"] not in kwargs:
+            prompt_config["system"] = prompt_config["system"].replace(
+                "{%s}" % p["key"], " ")

+    for _ in range(len(questions) // 2):
         questions.append(questions[-1])
     if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
+        kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
     else:
         kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
+                                        dialog.similarity_threshold,
+                                        dialog.vector_similarity_weight, top=1024, aggs=False)
     knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
+    chat_logger.info(
+        "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

     if not knowledges and prompt_config.get("empty_response"):
+        return {
+            "answer": prompt_config["empty_response"], "reference": kbinfos}

     kwargs["knowledge"] = "\n".join(knowledges)
     gen_conf = dialog.llm_setting
+    msg = [{"role": m["role"], "content": m["content"]}
+           for m in messages if m["role"] != "system"]
     used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
     if "max_tokens" in gen_conf:
+        gen_conf["max_tokens"] = min(
+            gen_conf["max_tokens"],
+            llm.max_tokens - used_token_count)
+    answer = chat_mdl.chat(
+        prompt_config["system"].format(
+            **kwargs), msg, gen_conf)
+    chat_logger.info("User: {}|Assistant: {}".format(
+        msg[-1]["content"], answer))

     if knowledges:
         answer, idx = retrievaler.insert_citations(answer,
+                                                   [ck["content_ltks"]
+                                                       for ck in kbinfos["chunks"]],
+                                                   [ck["vector"]
+                                                       for ck in kbinfos["chunks"]],
+                                                   embd_mdl,
+                                                   tkweight=1 - dialog.vector_similarity_weight,
+                                                   vtweight=dialog.vector_similarity_weight)
         idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
+        kbinfos["doc_aggs"] = [
+            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
     for c in kbinfos["chunks"]:
+        if c.get("vector"):
+            del c["vector"]
     return {"answer": answer, "reference": kbinfos}
@@ -245,9 +271,11 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
         question
     )
     tried_times = 0
+
     def get_table():
         nonlocal sys_prompt, user_promt, question, tried_times
+        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
+            "temperature": 0.06})
         print(user_promt, sql)
         chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
         sql = re.sub(r"[\r\n]+", " ", sql.lower())
@@ -262,8 +290,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
     else:
         flds = []
         for k in field_map.keys():
+            if k in forbidden_select_fields4resume:
+                continue
+            if len(flds) > 11:
+                break
             flds.append(k)
         sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
@@ -284,13 +314,13 @@ def use_sql(question, field_map, tenant_id, chat_mdl):

 问题如下:
 {}
+
 你上一次给出的错误SQL如下:
 {}
+
 后台报错如下:
 {}
+
 请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
 """.format(
     index_name(tenant_id),
@@ -302,16 +332,24 @@ def use_sql(question, field_map, tenant_id, chat_mdl):

     chat_logger.info("GET table: {}".format(tbl))
     print(tbl)
+    if tbl.get("error") or len(tbl["rows"]) == 0:
+        return None, None

+    docid_idx = set([ii for ii, c in enumerate(
+        tbl["columns"]) if c["name"] == "doc_id"])
+    docnm_idx = set([ii for ii, c in enumerate(
+        tbl["columns"]) if c["name"] == "docnm_kwd"])
+    clmn_idx = [ii for ii in range(
+        len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]

     # compose markdown table
+    clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
+                                                                        tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
+    line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
+        ("|------|" if docid_idx and docid_idx else "")
+    rows = ["|" +
+            "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
+            "|" for r in tbl["rows"]]
     if not docid_idx or not docnm_idx:
         chat_logger.warning("SQL missing field: " + sql)
         return "\n".join([clmns, line, "\n".join(rows)]), []
@@ -328,5 +366,5 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
     return {
         "answer": "\n".join([clmns, line, rows]),
         "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
+                      "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
     }
api/apps/dialog_app.py
CHANGED
@@ -55,7 +55,8 @@ def set_dialog():
     }
     prompt_config = req.get("prompt_config", default_prompt)

+    if not prompt_config["system"]:
+        prompt_config["system"] = default_prompt["system"]
     # if len(prompt_config["parameters"]) < 1:
     #     prompt_config["parameters"] = default_prompt["parameters"]
     # for p in prompt_config["parameters"]:
@@ -63,16 +64,21 @@ def set_dialog():
     # else: prompt_config["parameters"].append(default_prompt["parameters"][0])

     for p in prompt_config["parameters"]:
+        if p["optional"]:
+            continue
         if prompt_config["system"].find("{%s}" % p["key"]) < 0:
+            return get_data_error_result(
+                retmsg="Parameter '{}' is not used".format(p["key"]))

     try:
         e, tenant = TenantService.get_by_id(current_user.id)
+        if not e:
+            return get_data_error_result(retmsg="Tenant not found!")
         llm_id = req.get("llm_id", tenant.llm_id)
         if not dialog_id:
+            if not req.get("kb_ids"):
+                return get_data_error_result(
+                    retmsg="Fail! Please select knowledgebase!")
             dia = {
                 "id": get_uuid(),
                 "tenant_id": current_user.id,
@@ -86,17 +92,21 @@ def set_dialog():
                 "similarity_threshold": similarity_threshold,
                 "vector_similarity_weight": vector_similarity_weight
             }
+            if not DialogService.save(**dia):
+                return get_data_error_result(retmsg="Fail to new a dialog!")
             e, dia = DialogService.get_by_id(dia["id"])
+            if not e:
+                return get_data_error_result(retmsg="Fail to new a dialog!")
             return get_json_result(data=dia.to_json())
         else:
             del req["dialog_id"]
+            if "kb_names" in req:
+                del req["kb_names"]
             if not DialogService.update_by_id(dialog_id, req):
                 return get_data_error_result(retmsg="Dialog not found!")
             e, dia = DialogService.get_by_id(dialog_id)
+            if not e:
+                return get_data_error_result(retmsg="Fail to update a dialog!")
             dia = dia.to_dict()
             dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
             return get_json_result(data=dia)
@@ -110,7 +120,8 @@ def get():
     dialog_id = request.args["dialog_id"]
     try:
         e, dia = DialogService.get_by_id(dialog_id)
+        if not e:
+            return get_data_error_result(retmsg="Dialog not found!")
         dia = dia.to_dict()
         dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
         return get_json_result(data=dia)
@@ -122,7 +133,8 @@ def get_kb_names(kb_ids):
     ids, nms = [], []
     for kid in kb_ids:
         e, kb = KnowledgebaseService.get_by_id(kid)
+        if not e or kb.status != StatusEnum.VALID.value:
+            continue
         ids.append(kid)
         nms.append(kb.name)
     return ids, nms
@@ -132,7 +144,11 @@ def get_kb_names(kb_ids):
 @login_required
 def list():
     try:
+        diags = DialogService.query(
+            tenant_id=current_user.id,
+            status=StatusEnum.VALID.value,
+            reverse=True,
+            order_by=DialogService.model.create_time)
         diags = [d.to_dict() for d in diags]
         for d in diags:
             d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
@@ -147,7 +163,8 @@ def list():
 def rm():
     req = request.json
     try:
+        DialogService.update_many_by_id(
+            [{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
api/apps/document_app.py
CHANGED
@@ -57,6 +57,9 @@ def upload():
     if not e:
         return get_data_error_result(
             retmsg="Can't find this knowledgebase!")
+    if DocumentService.get_doc_count(kb.tenant_id) >= 128:
+        return get_data_error_result(
+            retmsg="Exceed the maximum file number of a free user!")

     filename = duplicate_name(
         DocumentService.query,
@@ -215,9 +218,11 @@ def rm():
         tenant_id = DocumentService.get_tenant_id(req["doc_id"])
         if not tenant_id:
             return get_data_error_result(retmsg="Tenant not found!")
+        ELASTICSEARCH.deleteByQuery(
+            Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))

+        DocumentService.increment_chunk_num(
+            doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, 0)
         if not DocumentService.delete(doc):
             return get_data_error_result(
                 retmsg="Database error (Document removal)!")
@@ -245,7 +250,8 @@ def run():
             tenant_id = DocumentService.get_tenant_id(id)
             if not tenant_id:
                 return get_data_error_result(retmsg="Tenant not found!")
+            ELASTICSEARCH.deleteByQuery(
+                Q("match", doc_id=id), idxnm=search.index_name(tenant_id))

         return get_json_result(data=True)
     except Exception as e:
@@ -261,7 +267,8 @@ def rename():
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
             return get_data_error_result(retmsg="Document not found!")
+        if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
+                doc.name.lower()).suffix:
             return get_json_result(
                 data=False,
                 retmsg="The extension of file can't be changed",
@@ -294,7 +301,10 @@ def get(doc_id):
         if doc.type == FileType.VISUAL.value:
             response.headers.set('Content-Type', 'image/%s' % ext.group(1))
         else:
+            response.headers.set(
+                'Content-Type',
+                'application/%s' %
+                ext.group(1))
         return response
     except Exception as e:
         return server_error_response(e)
@@ -313,9 +323,11 @@ def change_parser():
         if "parser_config" in req:
             if req["parser_config"] == doc.parser_config:
                 return get_json_result(data=True)
+        else:
+            return get_json_result(data=True)

+        if doc.type == FileType.VISUAL or re.search(
+                r"\.(ppt|pptx|pages)$", doc.name):
             return get_data_error_result(retmsg="Not supported yet!")

         e = DocumentService.update_by_id(doc.id,
@@ -332,7 +344,8 @@ def change_parser():
             tenant_id = DocumentService.get_tenant_id(req["doc_id"])
             if not tenant_id:
                 return get_data_error_result(retmsg="Tenant not found!")
+            ELASTICSEARCH.deleteByQuery(
+                Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))

         return get_json_result(data=True)
     except Exception as e:
api/apps/kb_app.py
CHANGED
@@ -33,15 +33,21 @@ from api.utils.api_utils import get_json_result
 def create():
     req = request.json
     req["name"] = req["name"].strip()
+    req["name"] = duplicate_name(
+        KnowledgebaseService.query,
+        name=req["name"],
+        tenant_id=current_user.id,
+        status=StatusEnum.VALID.value)
     try:
         req["id"] = get_uuid()
         req["tenant_id"] = current_user.id
         req["created_by"] = current_user.id
         e, t = TenantService.get_by_id(current_user.id)
+        if not e:
+            return get_data_error_result(retmsg="Tenant not found.")
         req["embd_id"] = t.embd_id
+        if not KnowledgebaseService.save(**req):
+            return get_data_error_result()
         return get_json_result(data={"kb_id": req["id"]})
     except Exception as e:
         return server_error_response(e)
@@ -54,21 +60,29 @@ def update():
     req = request.json
     req["name"] = req["name"].strip()
     try:
+        if not KnowledgebaseService.query(
+                created_by=current_user.id, id=req["kb_id"]):
+            return get_json_result(
+                data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)

         e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
+        if not e:
+            return get_data_error_result(
+                retmsg="Can't find this knowledgebase!")

         if req["name"].lower() != kb.name.lower() \
+                and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) > 1:
+            return get_data_error_result(
+                retmsg="Duplicated knowledgebase name.")

         del req["kb_id"]
+        if not KnowledgebaseService.update_by_id(kb.id, req):
+            return get_data_error_result()

         e, kb = KnowledgebaseService.get_by_id(kb.id)
+        if not e:
+            return get_data_error_result(
+                retmsg="Database error (Knowledgebase rename)!")

         return get_json_result(data=kb.to_json())
     except Exception as e:
@@ -81,7 +95,9 @@ def detail():
     kb_id = request.args["kb_id"]
     try:
         kb = KnowledgebaseService.get_detail(kb_id)
+        if not kb:
+            return get_data_error_result(
+                retmsg="Can't find this knowledgebase!")
         return get_json_result(data=kb)
     except Exception as e:
         return server_error_response(e)
@@ -96,7 +112,8 @@ def list():
     desc = request.args.get("desc", True)
     try:
         tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
+        kbs = KnowledgebaseService.get_by_tenant_ids(
+            [m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc)
         return get_json_result(data=kbs)
     except Exception as e:
         return server_error_response(e)
@@ -108,10 +125,15 @@ def rm():
 def rm():
     req = request.json
     try:
+        if not KnowledgebaseService.query(
+                created_by=current_user.id, id=req["kb_id"]):
+            return get_json_result(
+                data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
+
+        if not KnowledgebaseService.update_by_id(
+                req["kb_id"], {"status": StatusEnum.INVALID.value}):
+            return get_data_error_result(
+                retmsg="Database error (Knowledgebase removal)!")
         return get_json_result(data=True)
     except Exception as e:
+        return server_error_response(e)
api/apps/llm_app.py
CHANGED
@@ -48,30 +48,42 @@ def set_api_key():
                 req["api_key"], llm.llm_name)
             try:
                 arr, tc = mdl.encode(["Test if the api key is available"])
+                if len(arr[0]) == 0 or tc == 0:
+                    raise Exception("Fail")
             except Exception as e:
                 msg += f"\nFail to access embedding model({llm.llm_name}) using this api key."
         elif not chat_passed and llm.model_type == LLMType.CHAT.value:
             mdl = ChatModel[factory](
                 req["api_key"], llm.llm_name)
             try:
+                m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
+                    "temperature": 0.9})
+                if not tc:
+                    raise Exception(m)
                 chat_passed = True
             except Exception as e:
+                msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
+                    e)

+    if msg:
+        return get_data_error_result(retmsg=msg)

     llm = {
         "api_key": req["api_key"]
     }
     for n in ["model_type", "llm_name"]:
+        if n in req:
+            llm[n] = req[n]

+    if not TenantLLMService.filter_update(
+            [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory], llm):
         for llm in LLMService.query(fid=factory):
+            TenantLLMService.save(
+                tenant_id=current_user.id,
+                llm_factory=factory,
+                llm_name=llm.llm_name,
+                model_type=llm.model_type,
+                api_key=req["api_key"])

     return get_json_result(data=True)
@@ -105,17 +117,19 @@ def list():
         objs = TenantLLMService.query(tenant_id=current_user.id)
         facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
         llms = LLMService.get_all()
+        llms = [m.to_dict()
+                for m in llms if m.status == StatusEnum.VALID.value]
         for m in llms:
             m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"

         res = {}
         for m in llms:
+            if model_type and m["model_type"] != model_type:
+                continue
+            if m["fid"] not in res:
+                res[m["fid"]] = []
             res[m["fid"]].append(m)

         return get_json_result(data=res)
     except Exception as e:
         return server_error_response(e)
api/apps/user_app.py
CHANGED
@@ -40,13 +40,16 @@ def login():

     email = request.json.get('email', "")
     users = UserService.query(email=email)
+    if not users:
+        return get_json_result(
+            data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')

     password = request.json.get('password')
     try:
         password = decrypt(password)
+    except BaseException:
+        return get_json_result(
+            data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password')

     user = UserService.query_user(email, password)
     if user:
@@ -57,7 +60,8 @@ def login():
         msg = "Welcome back!"
         return cors_reponse(data=response_data, auth=user.get_id(), retmsg=msg)
     else:
+        return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
+                               retmsg='Email and Password do not match!')


 @manager.route('/github_callback', methods=['GET'])
@@ -65,7 +69,7 @@ def github_callback():
     import requests
     res = requests.post(GITHUB_OAUTH.get("url"), data={
         "client_id": GITHUB_OAUTH.get("client_id"),
+        "client_secret": GITHUB_OAUTH.get("secret_key"),
         "code": request.args.get('code')
     }, headers={"Accept": "application/json"})
     res = res.json()
@@ -96,15 +100,17 @@ def github_callback():
                 "last_login_time": get_format_time(),
                 "is_superuser": False,
             })
+            if not users:
+                raise Exception('Register user failure.')
+            if len(users) > 1:
+                raise Exception('Same E-mail exist!')
             user = users[0]
             login_user(user)
+            return redirect("/?auth=%s" % user.get_id())
         except Exception as e:
             rollback_user_registration(user_id)
             stat_logger.exception(e)
+            return redirect("/?error=%s" % str(e))
     user = users[0]
     user.access_token = get_uuid()
     login_user(user)
@@ -114,11 +120,18 @@ def github_callback():

 def user_info_from_github(access_token):
     import requests
+    headers = {"Accept": "application/json",
+               'Authorization': f"token {access_token}"}
+    res = requests.get(
+        f"https://api.github.com/user?access_token={access_token}",
+        headers=headers)
     user_info = res.json()
+    email_info = requests.get(
+        f"https://api.github.com/user/emails?access_token={access_token}",
+        headers=headers).json()
+    user_info["email"] = next(
+        (email for email in email_info if email['primary'] == True),
+        None)["email"]
     return user_info


@@ -138,13 +151,18 @@ def setting_user():
     request_data = request.json
     if request_data.get("password"):
         new_password = request_data.get("new_password")
+        if not check_password_hash(
+                current_user.password, decrypt(request_data["password"])):
+            return get_json_result(
+                data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')

+        if new_password:
+            update_dict["password"] = generate_password_hash(
+                decrypt(new_password))

     for k in request_data.keys():
+        if k in ["password", "new_password"]:
+            continue
         update_dict[k] = request_data[k]

     try:
@@ -152,7 +170,8 @@ def setting_user():
         return get_json_result(data=True)
     except Exception as e:
         stat_logger.exception(e)
+        return get_json_result(
+            data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)


 @manager.route("/info", methods=["GET"])
@@ -173,11 +192,11 @@ def rollback_user_registration(user_id):
     except Exception as e:
         pass
     try:
+        TenantLLM.delete().where(TenantLLM.tenant_id == user_id).excute()
     except Exception as e:
         pass

+
 def user_register(user_id, user):
     user["id"] = user_id
     tenant = {
@@ -197,9 +216,14 @@ def user_register(user_id, user):
     }
     tenant_llm = []
     for llm in LLMService.query(fid=LLM_FACTORY):
+        tenant_llm.append({"tenant_id": user_id,
+                           "llm_factory": LLM_FACTORY,
+                           "llm_name": llm.llm_name,
+                           "model_type": llm.model_type,
+                           "api_key": API_KEY})
+
+    if not UserService.save(**user):
+        return
     TenantService.insert(**tenant)
     UserTenantService.insert(**usr_tenant)
     TenantLLMService.insert_many(tenant_llm)
@@ -211,7 +235,8 @@ def user_register(user_id, user):
 def user_add():
     req = request.json
     if UserService.query(email=req["email"]):
+        return get_json_result(
+            data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
     if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,4}$", req["email"]):
         return get_json_result(data=False, retmsg=f'Invaliad e-mail: {req["email"]}!',
                                retcode=RetCode.OPERATING_ERROR)
@@ -229,16 +254,19 @@ def user_add():
     user_id = get_uuid()
     try:
         users = user_register(user_id, user_dict)
+        if not users:
+            raise Exception('Register user failure.')
+        if len(users) > 1:
+            raise Exception('Same E-mail exist!')
         user = users[0]
         login_user(user)
+        return cors_reponse(data=user.to_json(),
+                            auth=user.get_id(), retmsg="Welcome aboard!")
     except Exception as e:
         rollback_user_registration(user_id)
         stat_logger.exception(e)
+        return get_json_result(
+            data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)


 @manager.route("/tenant_info", methods=["GET"])
api/db/db_models.py
CHANGED
@@ -50,7 +50,13 @@ def singleton(cls, *args, **kw):


 CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
+AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {
+    "create",
+    "start",
+    "end",
+    "update",
+    "read_access",
+    "write_access"}


 class LongTextField(TextField):
@@ -73,7 +79,8 @@ class JSONField(LongTextField):
     def python_value(self, value):
         if not value:
             return self.default_value
+        return utils.json_loads(
+            value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)


 class ListField(JSONField):
@@ -81,7 +88,8 @@ class ListField(JSONField):


 class SerializedField(LongTextField):
+    def __init__(self, serialized_type=SerializedType.PICKLE,
+                 object_hook=None, object_pairs_hook=None, **kwargs):
         self._serialized_type = serialized_type
         self._object_hook = object_hook
         self._object_pairs_hook = object_pairs_hook
@@ -95,7 +103,8 @@ class SerializedField(LongTextField):
             return None
         return utils.json_dumps(value, with_type=True)
     else:
+        raise ValueError(
+            f"the serialized type {self._serialized_type} is not supported")
108 |
|
109 |
def python_value(self, value):
|
110 |
if self._serialized_type == SerializedType.PICKLE:
|
|
|
112 |
elif self._serialized_type == SerializedType.JSON:
|
113 |
if value is None:
|
114 |
return {}
|
115 |
+
return utils.json_loads(
|
116 |
+
value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
117 |
else:
|
118 |
+
raise ValueError(
|
119 |
+
f"the serialized type {self._serialized_type} is not supported")
|
120 |
|
121 |
|
122 |
def is_continuous_field(cls: typing.Type) -> bool:
|
|
|
161 |
model_dict = self.__dict__['__data__']
|
162 |
|
163 |
if not only_primary_with:
|
164 |
+
return {remove_field_name_prefix(
|
165 |
+
k): v for k, v in model_dict.items()}
|
166 |
|
167 |
human_model_dict = {}
|
168 |
for k in self._meta.primary_key.field_names:
|
|
|
196 |
if is_continuous_field(type(getattr(cls, attr_name))):
|
197 |
if len(f_v) == 2:
|
198 |
for i, v in enumerate(f_v):
|
199 |
+
if isinstance(
|
200 |
+
v, str) and f_n in auto_date_timestamp_field():
|
201 |
# time type: %Y-%m-%d %H:%M:%S
|
202 |
f_v[i] = utils.date_string_to_timestamp(v)
|
203 |
lt_value = f_v[0]
|
204 |
gt_value = f_v[1]
|
205 |
if lt_value is not None and gt_value is not None:
|
206 |
+
filters.append(
|
207 |
+
cls.getter_by(attr_name).between(
|
208 |
+
lt_value, gt_value))
|
209 |
elif lt_value is not None:
|
210 |
+
filters.append(
|
211 |
+
operator.attrgetter(attr_name)(cls) >= lt_value)
|
212 |
elif gt_value is not None:
|
213 |
+
filters.append(
|
214 |
+
operator.attrgetter(attr_name)(cls) <= gt_value)
|
215 |
else:
|
216 |
filters.append(operator.attrgetter(attr_name)(cls) << f_v)
|
217 |
else:
|
|
|
222 |
if not order_by or not hasattr(cls, f"{order_by}"):
|
223 |
order_by = "create_time"
|
224 |
if reverse is True:
|
225 |
+
query_records = query_records.order_by(
|
226 |
+
cls.getter_by(f"{order_by}").desc())
|
227 |
elif reverse is False:
|
228 |
+
query_records = query_records.order_by(
|
229 |
+
cls.getter_by(f"{order_by}").asc())
|
230 |
return [query_record for query_record in query_records]
|
231 |
else:
|
232 |
return []
|
|
|
234 |
@classmethod
|
235 |
def insert(cls, __data=None, **insert):
|
236 |
if isinstance(__data, dict) and __data:
|
237 |
+
__data[cls._meta.combined["create_time"]
|
238 |
+
] = utils.current_timestamp()
|
239 |
if insert:
|
240 |
insert["create_time"] = utils.current_timestamp()
|
241 |
|
|
|
248 |
if not normalized:
|
249 |
return {}
|
250 |
|
251 |
+
normalized[cls._meta.combined["update_time"]
|
252 |
+
] = utils.current_timestamp()
|
253 |
|
254 |
for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
|
255 |
if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
|
|
|
262 |
|
263 |
|
264 |
class JsonSerializedField(SerializedField):
|
265 |
+
def __init__(self, object_hook=utils.from_dict_hook,
|
266 |
+
object_pairs_hook=None, **kwargs):
|
267 |
super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
|
268 |
object_pairs_hook=object_pairs_hook, **kwargs)
|
269 |
|
|
|
273 |
def __init__(self):
|
274 |
database_config = DATABASE.copy()
|
275 |
db_name = database_config.pop("name")
|
276 |
+
self.database_connection = PooledMySQLDatabase(
|
277 |
+
db_name, **database_config)
|
278 |
stat_logger.info('init mysql database on cluster mode successfully')
|
279 |
|
280 |
|
|
|
286 |
|
287 |
def lock(self):
|
288 |
# SQL parameters only support %s format placeholders
|
289 |
+
cursor = self.db.execute_sql(
|
290 |
+
"SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
|
291 |
ret = cursor.fetchone()
|
292 |
if ret[0] == 0:
|
293 |
raise Exception(f'acquire mysql lock {self.lock_name} timeout')
|
|
|
297 |
raise Exception(f'failed to acquire lock {self.lock_name}')
|
298 |
|
299 |
def unlock(self):
|
300 |
+
cursor = self.db.execute_sql(
|
301 |
+
"SELECT RELEASE_LOCK(%s)", (self.lock_name,))
|
302 |
ret = cursor.fetchone()
|
303 |
if ret[0] == 0:
|
304 |
+
raise Exception(
|
305 |
+
f'mysql lock {self.lock_name} was not established by this thread')
|
306 |
elif ret[0] == 1:
|
307 |
return True
|
308 |
else:
|
|
|
376 |
access_token = CharField(max_length=255, null=True)
|
377 |
nickname = CharField(max_length=100, null=False, help_text="nicky name")
|
378 |
password = CharField(max_length=255, null=True, help_text="password")
|
379 |
+
email = CharField(
|
380 |
+
max_length=255,
|
381 |
+
null=False,
|
382 |
+
help_text="email",
|
383 |
+
index=True)
|
384 |
avatar = TextField(null=True, help_text="avatar base64 string")
|
385 |
+
language = CharField(
|
386 |
+
max_length=32,
|
387 |
+
null=True,
|
388 |
+
help_text="English|Chinese",
|
389 |
+
default="Chinese")
|
390 |
+
color_schema = CharField(
|
391 |
+
max_length=32,
|
392 |
+
null=True,
|
393 |
+
help_text="Bright|Dark",
|
394 |
+
default="Bright")
|
395 |
+
timezone = CharField(
|
396 |
+
max_length=64,
|
397 |
+
null=True,
|
398 |
+
help_text="Timezone",
|
399 |
+
default="UTC+8\tAsia/Shanghai")
|
400 |
last_login_time = DateTimeField(null=True)
|
401 |
is_authenticated = CharField(max_length=1, null=False, default="1")
|
402 |
is_active = CharField(max_length=1, null=False, default="1")
|
403 |
is_anonymous = CharField(max_length=1, null=False, default="0")
|
404 |
login_channel = CharField(null=True, help_text="from which user login")
|
405 |
+
status = CharField(
|
406 |
+
max_length=1,
|
407 |
+
null=True,
|
408 |
+
help_text="is it validate(0: wasted,1: validate)",
|
409 |
+
default="1")
|
410 |
is_superuser = BooleanField(null=True, help_text="is root", default=False)
|
411 |
|
412 |
def __str__(self):
|
|
|
425 |
name = CharField(max_length=100, null=True, help_text="Tenant name")
|
426 |
public_key = CharField(max_length=255, null=True)
|
427 |
llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
|
428 |
+
embd_id = CharField(
|
429 |
+
max_length=128,
|
430 |
+
null=False,
|
431 |
+
help_text="default embedding model ID")
|
432 |
+
asr_id = CharField(
|
433 |
+
max_length=128,
|
434 |
+
null=False,
|
435 |
+
help_text="default ASR model ID")
|
436 |
+
img2txt_id = CharField(
|
437 |
+
max_length=128,
|
438 |
+
null=False,
|
439 |
+
help_text="default image to text model ID")
|
440 |
+
parser_ids = CharField(
|
441 |
+
max_length=256,
|
442 |
+
null=False,
|
443 |
+
help_text="document processors")
|
444 |
credit = IntegerField(default=512)
|
445 |
+
status = CharField(
|
446 |
+
max_length=1,
|
447 |
+
null=True,
|
448 |
+
help_text="is it validate(0: wasted,1: validate)",
|
449 |
+
default="1")
|
450 |
|
451 |
class Meta:
|
452 |
db_table = "tenant"
|
|
|
458 |
tenant_id = CharField(max_length=32, null=False)
|
459 |
role = CharField(max_length=32, null=False, help_text="UserTenantRole")
|
460 |
invited_by = CharField(max_length=32, null=False)
|
461 |
+
status = CharField(
|
462 |
+
max_length=1,
|
463 |
+
null=True,
|
464 |
+
help_text="is it validate(0: wasted,1: validate)",
|
465 |
+
default="1")
|
466 |
|
467 |
class Meta:
|
468 |
db_table = "user_tenant"
|
|
|
474 |
visit_time = DateTimeField(null=True)
|
475 |
user_id = CharField(max_length=32, null=True)
|
476 |
tenant_id = CharField(max_length=32, null=True)
|
477 |
+
status = CharField(
|
478 |
+
max_length=1,
|
479 |
+
null=True,
|
480 |
+
help_text="is it validate(0: wasted,1: validate)",
|
481 |
+
default="1")
|
482 |
|
483 |
class Meta:
|
484 |
db_table = "invitation_code"
|
485 |
|
486 |
|
487 |
class LLMFactories(DataBaseModel):
|
488 |
+
name = CharField(
|
489 |
+
max_length=128,
|
490 |
+
null=False,
|
491 |
+
help_text="LLM factory name",
|
492 |
+
primary_key=True)
|
493 |
logo = TextField(null=True, help_text="llm logo base64")
|
494 |
+
tags = CharField(
|
495 |
+
max_length=255,
|
496 |
+
null=False,
|
497 |
+
help_text="LLM, Text Embedding, Image2Text, ASR")
|
498 |
+
status = CharField(
|
499 |
+
max_length=1,
|
500 |
+
null=True,
|
501 |
+
help_text="is it validate(0: wasted,1: validate)",
|
502 |
+
default="1")
|
503 |
|
504 |
def __str__(self):
|
505 |
return self.name
|
|
|
510 |
|
511 |
class LLM(DataBaseModel):
|
512 |
# LLMs dictionary
|
513 |
+
llm_name = CharField(
|
514 |
+
max_length=128,
|
515 |
+
null=False,
|
516 |
+
help_text="LLM name",
|
517 |
+
index=True,
|
518 |
+
primary_key=True)
|
519 |
+
model_type = CharField(
|
520 |
+
max_length=128,
|
521 |
+
null=False,
|
522 |
+
help_text="LLM, Text Embedding, Image2Text, ASR")
|
523 |
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
|
524 |
max_tokens = IntegerField(default=0)
|
525 |
+
tags = CharField(
|
526 |
+
max_length=255,
|
527 |
+
null=False,
|
528 |
+
help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
|
529 |
+
status = CharField(
|
530 |
+
max_length=1,
|
531 |
+
null=True,
|
532 |
+
help_text="is it validate(0: wasted,1: validate)",
|
533 |
+
default="1")
|
534 |
|
535 |
def __str__(self):
|
536 |
return self.llm_name
|
|
|
541 |
|
542 |
class TenantLLM(DataBaseModel):
|
543 |
tenant_id = CharField(max_length=32, null=False)
|
544 |
+
llm_factory = CharField(
|
545 |
+
max_length=128,
|
546 |
+
null=False,
|
547 |
+
help_text="LLM factory name")
|
548 |
+
model_type = CharField(
|
549 |
+
max_length=128,
|
550 |
+
null=True,
|
551 |
+
help_text="LLM, Text Embedding, Image2Text, ASR")
|
552 |
+
llm_name = CharField(
|
553 |
+
max_length=128,
|
554 |
+
null=True,
|
555 |
+
help_text="LLM name",
|
556 |
+
default="")
|
557 |
api_key = CharField(max_length=255, null=True, help_text="API KEY")
|
558 |
api_base = CharField(max_length=255, null=True, help_text="API Base")
|
559 |
used_tokens = IntegerField(default=0)
|
|
|
570 |
id = CharField(max_length=32, primary_key=True)
|
571 |
avatar = TextField(null=True, help_text="avatar base64 string")
|
572 |
tenant_id = CharField(max_length=32, null=False)
|
573 |
+
name = CharField(
|
574 |
+
max_length=128,
|
575 |
+
null=False,
|
576 |
+
help_text="KB name",
|
577 |
+
index=True)
|
578 |
+
language = CharField(
|
579 |
+
max_length=32,
|
580 |
+
null=True,
|
581 |
+
default="Chinese",
|
582 |
+
help_text="English|Chinese")
|
583 |
description = TextField(null=True, help_text="KB description")
|
584 |
+
embd_id = CharField(
|
585 |
+
max_length=128,
|
586 |
+
null=False,
|
587 |
+
help_text="default embedding model ID")
|
588 |
+
permission = CharField(
|
589 |
+
max_length=16,
|
590 |
+
null=False,
|
591 |
+
help_text="me|team",
|
592 |
+
default="me")
|
593 |
created_by = CharField(max_length=32, null=False)
|
594 |
doc_num = IntegerField(default=0)
|
595 |
token_num = IntegerField(default=0)
|
|
|
597 |
similarity_threshold = FloatField(default=0.2)
|
598 |
vector_similarity_weight = FloatField(default=0.3)
|
599 |
|
600 |
+
parser_id = CharField(
|
601 |
+
max_length=32,
|
602 |
+
null=False,
|
603 |
+
help_text="default parser ID",
|
604 |
+
default=ParserType.NAIVE.value)
|
605 |
+
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
606 |
+
status = CharField(
|
607 |
+
max_length=1,
|
608 |
+
null=True,
|
609 |
+
help_text="is it validate(0: wasted,1: validate)",
|
610 |
+
default="1")
|
611 |
|
612 |
def __str__(self):
|
613 |
return self.name
|
|
|
620 |
id = CharField(max_length=32, primary_key=True)
|
621 |
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
|
622 |
kb_id = CharField(max_length=256, null=False, index=True)
|
623 |
+
parser_id = CharField(
|
624 |
+
max_length=32,
|
625 |
+
null=False,
|
626 |
+
help_text="default parser ID")
|
627 |
+
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
628 |
+
source_type = CharField(
|
629 |
+
max_length=128,
|
630 |
+
null=False,
|
631 |
+
default="local",
|
632 |
+
help_text="where dose this document from")
|
633 |
type = CharField(max_length=32, null=False, help_text="file extension")
|
634 |
+
created_by = CharField(
|
635 |
+
max_length=32,
|
636 |
+
null=False,
|
637 |
+
help_text="who created it")
|
638 |
+
name = CharField(
|
639 |
+
max_length=255,
|
640 |
+
null=True,
|
641 |
+
help_text="file name",
|
642 |
+
index=True)
|
643 |
+
location = CharField(
|
644 |
+
max_length=255,
|
645 |
+
null=True,
|
646 |
+
help_text="where dose it store")
|
647 |
size = IntegerField(default=0)
|
648 |
token_num = IntegerField(default=0)
|
649 |
chunk_num = IntegerField(default=0)
|
650 |
progress = FloatField(default=0)
|
651 |
+
progress_msg = TextField(
|
652 |
+
null=True,
|
653 |
+
help_text="process message",
|
654 |
+
default="")
|
655 |
process_begin_at = DateTimeField(null=True)
|
656 |
process_duation = FloatField(default=0)
|
657 |
+
run = CharField(
|
658 |
+
max_length=1,
|
659 |
+
null=True,
|
660 |
+
help_text="start to run processing or cancel.(1: run it; 2: cancel)",
|
661 |
+
default="0")
|
662 |
+
status = CharField(
|
663 |
+
max_length=1,
|
664 |
+
null=True,
|
665 |
+
help_text="is it validate(0: wasted,1: validate)",
|
666 |
+
default="1")
|
667 |
|
668 |
class Meta:
|
669 |
db_table = "document"
|
|
|
677 |
begin_at = DateTimeField(null=True)
|
678 |
process_duation = FloatField(default=0)
|
679 |
progress = FloatField(default=0)
|
680 |
+
progress_msg = TextField(
|
681 |
+
null=True,
|
682 |
+
help_text="process message",
|
683 |
+
default="")
|
684 |
|
685 |
|
686 |
class Dialog(DataBaseModel):
|
687 |
id = CharField(max_length=32, primary_key=True)
|
688 |
tenant_id = CharField(max_length=32, null=False)
|
689 |
+
name = CharField(
|
690 |
+
max_length=255,
|
691 |
+
null=True,
|
692 |
+
help_text="dialog application name")
|
693 |
description = TextField(null=True, help_text="Dialog description")
|
694 |
icon = TextField(null=True, help_text="icon base64 string")
|
695 |
+
language = CharField(
|
696 |
+
max_length=32,
|
697 |
+
null=True,
|
698 |
+
default="Chinese",
|
699 |
+
help_text="English|Chinese")
|
700 |
llm_id = CharField(max_length=32, null=False, help_text="default llm ID")
|
701 |
llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
|
702 |
"presence_penalty": 0.4, "max_tokens": 215})
|
703 |
+
prompt_type = CharField(
|
704 |
+
max_length=16,
|
705 |
+
null=False,
|
706 |
+
default="simple",
|
707 |
+
help_text="simple|advanced")
|
708 |
prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
|
709 |
"parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
|
710 |
|
711 |
similarity_threshold = FloatField(default=0.2)
|
712 |
vector_similarity_weight = FloatField(default=0.3)
|
713 |
top_n = IntegerField(default=6)
|
714 |
+
do_refer = CharField(
|
715 |
+
max_length=1,
|
716 |
+
null=False,
|
717 |
+
help_text="it needs to insert reference index into answer or not",
|
718 |
+
default="1")
|
719 |
|
720 |
kb_ids = JSONField(null=False, default=[])
|
721 |
+
status = CharField(
|
722 |
+
max_length=1,
|
723 |
+
null=True,
|
724 |
+
help_text="is it validate(0: wasted,1: validate)",
|
725 |
+
default="1")
|
726 |
|
727 |
class Meta:
|
728 |
db_table = "dialog"
|
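Aside (not part of the commit): a minimal sketch of how the peewee models above are typically queried; the field names mirror the definitions in this file, while the tenant id and printed columns are placeholders.

# Illustrative only: querying the Knowledgebase model with peewee.
from api.db.db_models import DB, Knowledgebase

with DB.connection_context():
    kbs = (Knowledgebase
           .select(Knowledgebase.id, Knowledgebase.name, Knowledgebase.parser_id)
           .where((Knowledgebase.tenant_id == "some_tenant_id") &   # placeholder id
                  (Knowledgebase.status == "1"))                    # "1" = valid, per the defaults above
           .dicts())
    for kb in kbs:
        print(kb["id"], kb["name"], kb["parser_id"])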
api/db/db_utils.py
CHANGED
@@ -32,8 +32,7 @@ LOGGER = getLogger()
 def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
     DB.create_tables([model])

-
-    for i,data in enumerate(data_source):
+    for i, data in enumerate(data_source):
         current_time = current_timestamp() + i
         current_date = timestamp_to_date(current_time)
         if 'create_time' not in data:
@@ -55,7 +54,8 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False):


 def get_dynamic_db_model(base, job_id):
+    return type(base.model(
+        table_index=get_dynamic_tracking_table_index(job_id=job_id)))


 def get_dynamic_tracking_table_index(job_id):
@@ -86,7 +86,9 @@ supported_operators = {
     '~': operator.inv,
 }

+
+def query_dict2expression(
+        model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
     expression = []

     for field, value in query.items():
@@ -95,7 +97,10 @@ def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[boo
         op, *val = value

         field = getattr(model, f'f_{field}')
+        value = supported_operators[op](
+            field, val[0]) if op in supported_operators else getattr(
+            field, op)(
+            *val)
         expression.append(value)

     return reduce(operator.iand, expression)
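Aside (not part of the commit): a sketch of how query_dict2expression can be driven. JobModel and its f_status column are hypothetical; the ("in_", [...]) tuple takes the getattr(field, op)(*val) branch shown in the hunk above, and the f_ prefix comes from the getattr(model, f'f_{field}') lookup.

# Illustrative only: building a combined peewee filter from a query dict.
query = {"status": ("in_", ["running", "queued"])}
expression = query_dict2expression(JobModel, query)   # -> JobModel.f_status.in_([...])
rows = JobModel.select().where(expression)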
api/db/init_data.py
CHANGED
@@ -61,45 +61,54 @@ def init_superuser():
     TenantService.insert(**tenant)
     UserTenantService.insert(**usr_tenant)
     TenantLLMService.insert_many(tenant_llm)
+    print(
+        "【INFO】Super user initialized. \033[93memail: [email protected], password: admin\033[0m. Changing the password after logining is strongly recomanded.")

     chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
+    msg = chat_mdl.chat(system="", history=[
+        {"role": "user", "content": "Hello!"}], gen_conf={})
     if msg.find("ERROR: ") == 0:
+        print(
+            "\33[91m【ERROR】\33[0m: ",
+            "'{}' dosen't work. {}".format(
+                tenant["llm_id"],
+                msg))
     embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
     v, c = embd_mdl.encode(["Hello!"])
     if c == 0:
+        print(
+            "\33[91m【ERROR】\33[0m:",
+            " '{}' dosen't work!".format(
+                tenant["embd_id"]))


 factory_infos = [{
+    "name": "OpenAI",
+    "logo": "",
+    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+    "status": "1",
+}, {
+    "name": "Tongyi-Qianwen",
+    "logo": "",
+    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+    "status": "1",
+}, {
+    "name": "ZHIPU-AI",
+    "logo": "",
+    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+    "status": "1",
+},
+    {
+        "name": "Local",
+        "logo": "",
+        "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
         "status": "1",
+    }, {
         "name": "Moonshot",
+        "logo": "",
+        "tags": "LLM,TEXT EMBEDDING",
+        "status": "1",
+    }
     # {
     #     "name": "文心一言",
     #     "logo": "",
@@ -107,6 +116,8 @@ factory_infos = [{
     #     "status": "1",
     # },
 ]
+
+
 def init_llm_factory():
     llm_infos = [
         # ---------------------- OpenAI ------------------------
@@ -116,37 +127,37 @@ def init_llm_factory():
             "tags": "LLM,CHAT,4K",
             "max_tokens": 4096,
             "model_type": LLMType.CHAT.value
-        },{
+        }, {
             "fid": factory_infos[0]["name"],
             "llm_name": "gpt-3.5-turbo-16k-0613",
             "tags": "LLM,CHAT,16k",
             "max_tokens": 16385,
             "model_type": LLMType.CHAT.value
-        },{
+        }, {
             "fid": factory_infos[0]["name"],
             "llm_name": "text-embedding-ada-002",
             "tags": "TEXT EMBEDDING,8K",
             "max_tokens": 8191,
             "model_type": LLMType.EMBEDDING.value
-        },{
+        }, {
             "fid": factory_infos[0]["name"],
             "llm_name": "whisper-1",
             "tags": "SPEECH2TEXT",
-            "max_tokens": 25*1024*1024,
+            "max_tokens": 25 * 1024 * 1024,
             "model_type": LLMType.SPEECH2TEXT.value
-        },{
+        }, {
             "fid": factory_infos[0]["name"],
             "llm_name": "gpt-4",
             "tags": "LLM,CHAT,8K",
             "max_tokens": 8191,
             "model_type": LLMType.CHAT.value
-        },{
+        }, {
             "fid": factory_infos[0]["name"],
             "llm_name": "gpt-4-32k",
             "tags": "LLM,CHAT,32K",
             "max_tokens": 32768,
             "model_type": LLMType.CHAT.value
-        },{
+        }, {
             "fid": factory_infos[0]["name"],
             "llm_name": "gpt-4-vision-preview",
             "tags": "LLM,CHAT,IMAGE2TEXT",
@@ -160,31 +171,31 @@ def init_llm_factory():
             "tags": "LLM,CHAT,8K",
             "max_tokens": 8191,
             "model_type": LLMType.CHAT.value
-        },{
+        }, {
             "fid": factory_infos[1]["name"],
             "llm_name": "qwen-plus",
             "tags": "LLM,CHAT,32K",
             "max_tokens": 32768,
             "model_type": LLMType.CHAT.value
-        },{
+        }, {
             "fid": factory_infos[1]["name"],
             "llm_name": "qwen-max-1201",
             "tags": "LLM,CHAT,6K",
             "max_tokens": 5899,
             "model_type": LLMType.CHAT.value
-        },{
+        }, {
             "fid": factory_infos[1]["name"],
             "llm_name": "text-embedding-v2",
             "tags": "TEXT EMBEDDING,2K",
             "max_tokens": 2048,
             "model_type": LLMType.EMBEDDING.value
-        },{
+        }, {
             "fid": factory_infos[1]["name"],
             "llm_name": "paraformer-realtime-8k-v1",
             "tags": "SPEECH2TEXT",
-            "max_tokens": 25*1024*1024,
+            "max_tokens": 25 * 1024 * 1024,
             "model_type": LLMType.SPEECH2TEXT.value
-        },{
+        }, {
             "fid": factory_infos[1]["name"],
             "llm_name": "qwen-vl-max",
             "tags": "LLM,CHAT,IMAGE2TEXT",
@@ -245,13 +256,13 @@ def init_llm_factory():
             "tags": "TEXT EMBEDDING,",
             "max_tokens": 128 * 1000,
             "model_type": LLMType.EMBEDDING.value
-        },{
+        }, {
             "fid": factory_infos[4]["name"],
             "llm_name": "moonshot-v1-32k",
             "tags": "LLM,CHAT,",
             "max_tokens": 32768,
             "model_type": LLMType.CHAT.value
-        },{
+        }, {
             "fid": factory_infos[4]["name"],
             "llm_name": "moonshot-v1-128k",
             "tags": "LLM,CHAT",
@@ -294,7 +305,6 @@ def init_web_data():
     print("init web data success:{}".format(time.time() - start_time))


-
 if __name__ == '__main__':
     init_web_db()
-    init_web_data()
+    init_web_data()
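Aside (not part of the commit): the dict shape used by factory_infos and llm_infos above maps one-to-one onto the LLMFactories and LLM models from api/db/db_models.py; the entries below are made-up examples of that shape, not part of the seed data.

# Illustrative only: a hypothetical extra factory and model entry.
extra_factory = {
    "name": "MyLocalServer",          # hypothetical factory name
    "logo": "",
    "tags": "LLM,TEXT EMBEDDING",
    "status": "1",
}
extra_llm = {
    "fid": extra_factory["name"],     # LLM.fid references the factory name
    "llm_name": "my-7b-chat",         # hypothetical model id
    "tags": "LLM,CHAT,8K",
    "max_tokens": 8192,
    "model_type": LLMType.CHAT.value,
}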
api/db/operatioins.py
CHANGED
@@ -18,4 +18,4 @@ import operator
 import time
 import typing
 from api.utils.log_utils import sql_logger
-import peewee
+import peewee
api/db/reload_config_base.py
CHANGED
@@ -18,10 +18,11 @@ class ReloadConfigBase:
     def get_all(cls):
         configs = {}
         for k, v in cls.__dict__.items():
+            if not callable(getattr(cls, k)) and not k.startswith(
+                    "__") and not k.startswith("_"):
                 configs[k] = v
         return configs

     @classmethod
     def get(cls, config_name):
-        return getattr(cls, config_name) if hasattr(cls, config_name) else None
+        return getattr(cls, config_name) if hasattr(cls, config_name) else None
api/db/runtime_config.py
CHANGED
@@ -51,4 +51,4 @@ class RuntimeConfig(ReloadConfigBase):

     @classmethod
     def set_service_db(cls, service_db):
-        cls.SERVICE_DB = service_db
+        cls.SERVICE_DB = service_db
api/db/services/common_service.py
CHANGED
@@ -27,7 +27,8 @@ class CommonService:
     @classmethod
     @DB.connection_context()
     def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
+        return cls.model.query(cols=cols, reverse=reverse,
+                               order_by=order_by, **kwargs)

     @classmethod
     @DB.connection_context()
@@ -40,9 +41,11 @@ class CommonService:
         if not order_by or not hasattr(cls, order_by):
             order_by = "create_time"
         if reverse is True:
+            query_records = query_records.order_by(
+                cls.model.getter_by(order_by).desc())
         elif reverse is False:
+            query_records = query_records.order_by(
+                cls.model.getter_by(order_by).asc())
         return query_records

     @classmethod
@@ -61,7 +64,7 @@ class CommonService:
     @classmethod
     @DB.connection_context()
     def save(cls, **kwargs):
-        #if "id" not in kwargs:
+        # if "id" not in kwargs:
         #     kwargs["id"] = get_uuid()
         sample_obj = cls.model(**kwargs).save(force_insert=True)
         return sample_obj
@@ -95,7 +98,8 @@ class CommonService:
         for data in data_list:
             data["update_time"] = current_timestamp()
             data["update_date"] = datetime_format(datetime.now())
+            cls.model.update(data).where(
+                cls.model.id == data["id"]).execute()

     @classmethod
     @DB.connection_context()
@@ -128,7 +132,6 @@ class CommonService:
     def delete_by_id(cls, pid):
         return cls.model.delete().where(cls.model.id == pid).execute()

-
     @classmethod
     @DB.connection_context()
     def filter_delete(cls, filters):
@@ -151,19 +154,30 @@ class CommonService:

     @classmethod
     @DB.connection_context()
+    def filter_scope_list(cls, in_key, in_filters_list,
+                          filters=None, cols=None):
         in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
         if not filters:
             filters = []
         res_list = []
         if cols:
             for i in in_filters_tuple_list:
+                query_records = cls.model.select(
+                    *
+                    cols).where(
+                    getattr(
+                        cls.model,
+                        in_key).in_(i),
+                    *
+                    filters)
                 if query_records:
+                    res_list.extend(
+                        [query_record for query_record in query_records])
         else:
             for i in in_filters_tuple_list:
+                query_records = cls.model.select().where(
+                    getattr(cls.model, in_key).in_(i), *filters)
                 if query_records:
+                    res_list.extend(
+                        [query_record for query_record in query_records])
+        return res_list
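Aside (not part of the commit): the service modules that follow all use the same pattern of binding one peewee model to the generic CRUD helpers of CommonService; Tag and TagService below are hypothetical names used only to show the shape.

# Illustrative only: a hypothetical model/service pair reusing CommonService helpers.
class TagService(CommonService):
    model = Tag   # a DataBaseModel subclass, not defined in this commit

# Typical calls then go through the shared helpers shown above:
# TagService.save(id=get_uuid(), name="demo")
# TagService.query(name="demo", reverse=True, order_by="create_time")
# TagService.update_by_id(some_id, {"name": "renamed"})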
api/db/services/dialog_service.py
CHANGED
@@ -21,6 +21,5 @@ class DialogService(CommonService):
     model = Dialog


-
 class ConversationService(CommonService):
     model = Conversation
api/db/services/document_service.py
CHANGED
@@ -72,7 +72,20 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_newly_uploaded(cls, tm, mod=0, comm=1, items_per_page=64):
+        fields = [
+            cls.model.id,
+            cls.model.kb_id,
+            cls.model.parser_id,
+            cls.model.parser_config,
+            cls.model.name,
+            cls.model.type,
+            cls.model.location,
+            cls.model.size,
+            Knowledgebase.tenant_id,
+            Tenant.embd_id,
+            Tenant.img2txt_id,
+            Tenant.asr_id,
+            cls.model.update_time]
         docs = cls.model.select(*fields) \
             .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
             .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
@@ -103,40 +116,64 @@ class DocumentService(CommonService):
     @DB.connection_context()
     def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
         num = cls.model.update(token_num=cls.model.token_num + token_num,
+                               chunk_num=cls.model.chunk_num + chunk_num,
+                               process_duation=cls.model.process_duation + duation).where(
             cls.model.id == doc_id).execute()
+        if num == 0:
+            raise LookupError(
+                "Document not found which is supposed to be there")
+        num = Knowledgebase.update(
+            token_num=Knowledgebase.token_num +
+            token_num,
+            chunk_num=Knowledgebase.chunk_num +
+            chunk_num).where(
+            Knowledgebase.id == kb_id).execute()
         return num

     @classmethod
     @DB.connection_context()
     def get_tenant_id(cls, doc_id):
+        docs = cls.model.select(
+            Knowledgebase.tenant_id).join(
+            Knowledgebase, on=(
+                Knowledgebase.id == cls.model.kb_id)).where(
+            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
         docs = docs.dicts()
+        if not docs:
+            return
         return docs[0]["tenant_id"]

     @classmethod
     @DB.connection_context()
     def get_thumbnails(cls, docids):
         fields = [cls.model.id, cls.model.thumbnail]
+        return list(cls.model.select(
+            *fields).where(cls.model.id.in_(docids)).dicts())

     @classmethod
     @DB.connection_context()
     def update_parser_config(cls, id, config):
         e, d = cls.get_by_id(id)
+        if not e:
+            raise LookupError(f"Document({id}) not found.")
+
         def dfs_update(old, new):
+            for k, v in new.items():
                 if k not in old:
                     old[k] = v
                     continue
                 if isinstance(v, dict):
                     assert isinstance(old[k], dict)
                     dfs_update(old[k], v)
+                else:
+                    old[k] = v
         dfs_update(d.parser_config, config)
+        cls.update_by_id(id, {"parser_config": d.parser_config})
+
+    @classmethod
+    @DB.connection_context()
+    def get_doc_count(cls, tenant_id):
+        docs = cls.model.select(cls.model.id).join(Knowledgebase,
+                                                   on=(Knowledgebase.id == cls.model.kb_id)).where(
+            Knowledgebase.tenant_id == tenant_id)
+        return len(docs)
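Aside (not part of the commit): the merge semantics of dfs_update above, traced on a made-up parser_config. Only "pages" matches the model default; the other keys are hypothetical and shown purely to illustrate the recursive merge.

# Illustrative only: dfs_update performs a deep merge of new into old.
old = {"pages": [[1, 1000000]], "layout": {"ocr": False, "lang": "en"}}
new = {"layout": {"ocr": True}, "chunk_token_num": 128}
# After dfs_update(old, new):
#   old == {"pages": [[1, 1000000]],
#           "layout": {"ocr": True, "lang": "en"},   # nested dicts merged key by key
#           "chunk_token_num": 128}                  # unknown keys are added as-is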
api/db/services/knowledgebase_service.py
CHANGED
@@ -55,7 +55,7 @@ class KnowledgebaseService(CommonService):
             cls.model.chunk_num,
             cls.model.parser_id,
             cls.model.parser_config]
-        kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where(
+        kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
             (cls.model.id == kb_id),
             (cls.model.status == StatusEnum.VALID.value)
         )
@@ -69,9 +69,11 @@ class KnowledgebaseService(CommonService):
     @DB.connection_context()
     def update_parser_config(cls, id, config):
         e, m = cls.get_by_id(id)
+        if not e:
+            raise LookupError(f"knowledgebase({id}) not found.")
+
         def dfs_update(old, new):
-            for k,v in new.items():
+            for k, v in new.items():
                 if k not in old:
                     old[k] = v
                     continue
@@ -80,12 +82,12 @@ class KnowledgebaseService(CommonService):
                     dfs_update(old[k], v)
                 elif isinstance(v, list):
                     assert isinstance(old[k], list)
-                    old[k] = list(set(old[k]+v))
-                else:
+                    old[k] = list(set(old[k] + v))
+                else:
+                    old[k] = v
         dfs_update(m.parser_config, config)
         cls.update_by_id(id, {"parser_config": m.parser_config})

-
     @classmethod
     @DB.connection_context()
     def get_field_map(cls, ids):
@@ -94,4 +96,3 @@ class KnowledgebaseService(CommonService):
             if k.parser_config and "field_map" in k.parser_config:
                 conf.update(k.parser_config["field_map"])
         return conf
-
api/db/services/llm_service.py
CHANGED
@@ -59,7 +59,8 @@ class TenantLLMService(CommonService):

     @classmethod
     @DB.connection_context()
+    def model_instance(cls, tenant_id, llm_type,
+                       llm_name=None, lang="Chinese"):
         e, tenant = TenantService.get_by_id(tenant_id)
         if not e:
             raise LookupError("Tenant not found")
@@ -126,29 +127,39 @@ class LLMBundle(object):
         self.tenant_id = tenant_id
         self.llm_type = llm_type
         self.llm_name = llm_name
+        self.mdl = TenantLLMService.model_instance(
+            tenant_id, llm_type, llm_name, lang=lang)
+        assert self.mdl, "Can't find mole for {}/{}/{}".format(
+            tenant_id, llm_type, llm_name)

     def encode(self, texts: list, batch_size=32):
         emd, used_tokens = self.mdl.encode(texts, batch_size)
+        if TenantLLMService.increase_usage(
+                self.tenant_id, self.llm_type, used_tokens):
+            database_logger.error(
+                "Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
         return emd, used_tokens

     def encode_queries(self, query: str):
         emd, used_tokens = self.mdl.encode_queries(query)
+        if TenantLLMService.increase_usage(
+                self.tenant_id, self.llm_type, used_tokens):
+            database_logger.error(
+                "Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
         return emd, used_tokens

     def describe(self, image, max_tokens=300):
         txt, used_tokens = self.mdl.describe(image, max_tokens)
+        if not TenantLLMService.increase_usage(
+                self.tenant_id, self.llm_type, used_tokens):
+            database_logger.error(
+                "Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id))
         return txt

     def chat(self, system, history, gen_conf):
         txt, used_tokens = self.mdl.chat(system, history, gen_conf)
+        if TenantLLMService.increase_usage(
+                self.tenant_id, self.llm_type, used_tokens, self.llm_name):
+            database_logger.error(
+                "Can't update token usage for {}/CHAT".format(self.tenant_id))
         return txt
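Aside (not part of the commit): roughly how LLMBundle is used elsewhere in the code base (compare init_data.py above); token accounting happens inside each wrapper via TenantLLMService.increase_usage. tenant_id, embd_id and llm_id stand in for real values.

# Illustrative only: typical LLMBundle calls.
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
vectors, tokens = embd_mdl.encode(["hello", "world"])

chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
answer = chat_mdl.chat(system="You are terse.",
                       history=[{"role": "user", "content": "Hi"}],
                       gen_conf={"temperature": 0.1})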
api/db/services/user_service.py
CHANGED
@@ -54,7 +54,8 @@ class UserService(CommonService):
         if "id" not in kwargs:
             kwargs["id"] = get_uuid()
         if "password" in kwargs:
+            kwargs["password"] = generate_password_hash(
+                str(kwargs["password"]))

         kwargs["create_time"] = current_timestamp()
         kwargs["create_date"] = datetime_format(datetime.now())
@@ -63,12 +64,12 @@ class UserService(CommonService):
         obj = cls.model(**kwargs).save(force_insert=True)
         return obj

-
     @classmethod
     @DB.connection_context()
     def delete_user(cls, user_ids, update_user_dict):
         with DB.atomic():
+            cls.model.update({"status": 0}).where(
+                cls.model.id.in_(user_ids)).execute()

     @classmethod
     @DB.connection_context()
@@ -77,7 +78,8 @@ class UserService(CommonService):
         if user_dict:
             user_dict["update_time"] = current_timestamp()
             user_dict["update_date"] = datetime_format(datetime.now())
+            cls.model.update(user_dict).where(
+                cls.model.id == user_id).execute()


 class TenantService(CommonService):
@@ -86,25 +88,42 @@ class TenantService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_by_user_id(cls, user_id):
+        fields = [
+            cls.model.id.alias("tenant_id"),
+            cls.model.name,
+            cls.model.llm_id,
+            cls.model.embd_id,
+            cls.model.asr_id,
+            cls.model.img2txt_id,
+            cls.model.parser_ids,
+            UserTenant.role]
+        return list(cls.model.select(*fields)
+                    .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
+                    .where(cls.model.status == StatusEnum.VALID.value).dicts())

     @classmethod
     @DB.connection_context()
     def get_joined_tenants_by_user_id(cls, user_id):
+        fields = [
+            cls.model.id.alias("tenant_id"),
+            cls.model.name,
+            cls.model.llm_id,
+            cls.model.embd_id,
+            cls.model.asr_id,
+            cls.model.img2txt_id,
+            UserTenant.role]
+        return list(cls.model.select(*fields)
+                    .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role == UserTenantRole.NORMAL.value)))
+                    .where(cls.model.status == StatusEnum.VALID.value).dicts())

     @classmethod
     @DB.connection_context()
     def decrease(cls, user_id, num):
         num = cls.model.update(credit=cls.model.credit - num).where(
             cls.model.id == user_id).execute()
+        if num == 0:
+            raise LookupError("Tenant not found which is supposed to be there")


 class UserTenantService(CommonService):
     model = UserTenant
api/settings.py
CHANGED
@@ -13,16 +13,22 @@
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
|
|
|
|
16 |
import os
|
17 |
|
18 |
from enum import IntEnum, Enum
|
19 |
|
20 |
-
from api.utils import get_base_config,decrypt_database_config
|
21 |
from api.utils.file_utils import get_project_base_directory
|
22 |
from api.utils.log_utils import LoggerFactory, getLogger
|
23 |
|
24 |
# Logger
|
25 |
-
LoggerFactory.set_directory(
|
|
|
|
|
|
|
|
|
26 |
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
|
27 |
LoggerFactory.LEVEL = 10
|
28 |
|
@@ -86,7 +92,9 @@ default_llm = {
|
|
86 |
LLM = get_base_config("user_default_llm", {})
|
87 |
LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
|
88 |
if LLM_FACTORY not in default_llm:
|
89 |
-
print(
|
|
|
|
|
90 |
LLM_FACTORY = "Tongyi-Qianwen"
|
91 |
CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
|
92 |
EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"]
|
@@ -94,7 +102,9 @@ ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
|
|
94 |
api/settings.py (continued)
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from rag.utils import ELASTICSEARCH
+from rag.nlp import search
 import os

 from enum import IntEnum, Enum

+from api.utils import get_base_config, decrypt_database_config
 from api.utils.file_utils import get_project_base_directory
 from api.utils.log_utils import LoggerFactory, getLogger

 # Logger
+LoggerFactory.set_directory(
+    os.path.join(
+        get_project_base_directory(),
+        "logs",
+        "api"))
 # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
 LoggerFactory.LEVEL = 10

 LLM = get_base_config("user_default_llm", {})
 LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
 if LLM_FACTORY not in default_llm:
+    print(
+        "\33[91m【ERROR】\33[0m:",
+        f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.")
     LLM_FACTORY = "Tongyi-Qianwen"
 CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
 EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"]
 IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]

 API_KEY = LLM.get("api_key", "")
+PARSERS = LLM.get(
+    "parsers",
+    "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One")

 # distribution
 DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)

@@ -103,13 +113,25 @@ RAG_FLOW_UPDATE_CHECK = False
 HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
 HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")

+SECRET_KEY = get_base_config(
+    RAG_FLOW_SERVICE_NAME,
+    {}).get(
+    "secret_key",
+    "infiniflow")
+TOKEN_EXPIRE_IN = get_base_config(
+    RAG_FLOW_SERVICE_NAME, {}).get(
+    "token_expires_in", 3600)
+
+NGINX_HOST = get_base_config(
+    RAG_FLOW_SERVICE_NAME, {}).get(
+    "nginx", {}).get("host") or HOST
+NGINX_HTTP_PORT = get_base_config(
+    RAG_FLOW_SERVICE_NAME, {}).get(
+    "nginx", {}).get("http_port") or HTTP_PORT
+
+RANDOM_INSTANCE_ID = get_base_config(
+    RAG_FLOW_SERVICE_NAME, {}).get(
+    "random_instance_id", False)

 PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
 PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")

@@ -124,7 +146,9 @@ UPLOAD_DATA_FROM_CLIENT = True
 AUTHENTICATION_CONF = get_base_config("authentication", {})

 # client
+CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get(
+    "client", {}).get(
+    "switch", False)
 HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
 GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
 WECHAT_OAUTH = get_base_config("oauth", {}).get("wechat")

@@ -147,12 +171,10 @@ USE_AUTHENTICATION = False
 USE_DATA_AUTHENTICATION = False
 AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True
 USE_DEFAULT_TIMEOUT = False
-AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60
+AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60  # s
 PRIVILEGE_COMMAND_WHITELIST = []
 CHECK_NODES_IDENTITY = False

-from rag.nlp import search
-from rag.utils import ELASTICSEARCH
 retrievaler = search.Dealer(ELASTICSEARCH)

@@ -162,7 +184,7 @@ class CustomEnum(Enum):
     try:
         cls(value)
         return True
-    except:
+    except BaseException:
         return False

     @classmethod
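Aside (not part of the commit): the rewrapped assignments above are ordinary nested-dict lookups with defaults. A minimal sketch of the same fallback chain, using a literal dict in place of get_base_config; the dict contents are hypothetical:

# how SECRET_KEY / NGINX_HOST style lookups fall back to their defaults
service_conf = {"ragflow": {"host": "0.0.0.0", "http_port": 9380}}  # hypothetical config

conf = service_conf.get("ragflow", {})              # get_base_config(RAG_FLOW_SERVICE_NAME, {})
secret_key = conf.get("secret_key", "infiniflow")   # -> "infiniflow" when the key is absent
nginx_host = conf.get("nginx", {}).get("host") or conf.get("host", "127.0.0.1")
print(secret_key, nginx_host)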
api/utils/__init__.py
CHANGED
@@ -34,10 +34,12 @@ from . import file_utils
 SERVICE_CONF = "service_conf.yaml"

+
 def conf_realpath(conf_name):
     conf_path = f"conf/{conf_name}"
     return os.path.join(file_utils.get_project_base_directory(), conf_path)

+
 def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict:
     local_config = {}
     local_path = conf_realpath(f'local.{conf_name}')

@@ -62,7 +64,8 @@ def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict:
     return config.get(key, default) if key is not None else config


+use_deserialize_safe_module = get_base_config(
+    'use_deserialize_safe_module', False)


 class CoordinationCommunicationProtocol(object):

@@ -93,7 +96,8 @@ class BaseType:
                 data[_k] = _dict(vv)
             else:
                 data = obj
+            return {"type": obj.__class__.__name__,
+                    "data": data, "module": module}
         return _dict(self)

@@ -129,7 +133,8 @@ def rag_uuid():
 def string_to_bytes(string):
+    return string if isinstance(
+        string, bytes) else string.encode(encoding="utf-8")


 def bytes_to_string(byte):

@@ -137,7 +142,11 @@ def bytes_to_string(byte):
 def json_dumps(src, byte=False, indent=None, with_type=False):
+    dest = json.dumps(
+        src,
+        indent=indent,
+        cls=CustomJSONEncoder,
+        with_type=with_type)
     if byte:
         dest = string_to_bytes(dest)
     return dest

@@ -146,7 +155,8 @@ def json_dumps(src, byte=False, indent=None, with_type=False):
 def json_loads(src, object_hook=None, object_pairs_hook=None):
     if isinstance(src, bytes):
         src = bytes_to_string(src)
+    return json.loads(src, object_hook=object_hook,
+                      object_pairs_hook=object_pairs_hook)


 def current_timestamp():

@@ -177,7 +187,9 @@ def serialize_b64(src, to_str=False):
 def deserialize_b64(src):
+    src = base64.b64decode(
+        string_to_bytes(src) if isinstance(
+            src, str) else src)
     if use_deserialize_safe_module:
         return restricted_loads(src)
     return pickle.loads(src)

@@ -237,12 +249,14 @@ def get_lan_ip():
         pass
     return ip or ''

+
 def from_dict_hook(in_dict: dict):
     if "type" in in_dict and "data" in in_dict:
         if in_dict["module"] is None:
             return in_dict["data"]
         else:
+            return getattr(importlib.import_module(
+                in_dict["module"]), in_dict["type"])(**in_dict["data"])
     else:
         return in_dict

@@ -259,12 +273,16 @@ def decrypt_database_password(password):
         raise ValueError("No private key")

     module_fun = encrypt_module.split("#")
+    pwdecrypt_fun = getattr(
+        importlib.import_module(
+            module_fun[0]),
+        module_fun[1])

     return pwdecrypt_fun(private_key, password)


+def decrypt_database_config(
+        database=None, passwd_key="password", name="database"):
     if not database:
         database = get_base_config(name, {})

@@ -275,7 +293,8 @@ def decrypt_database_config(database=None, passwd_key="password", name="database
 def update_config(key, value, conf_name=SERVICE_CONF):
     conf_path = conf_realpath(conf_name=conf_name)
     if not os.path.isabs(conf_path):
+        conf_path = os.path.join(
+            file_utils.get_project_base_directory(), conf_path)

     with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
         config = file_utils.load_yaml_conf(conf_path=conf_path) or {}

@@ -288,7 +307,8 @@ def get_uuid():
 def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
+    return datetime.datetime(date_time.year, date_time.month, date_time.day,
+                             date_time.hour, date_time.minute, date_time.second)


 def get_format_time() -> datetime.datetime:

@@ -307,14 +327,19 @@ def elapsed2time(elapsed):
 def decrypt(line):
+    file_path = os.path.join(
+        file_utils.get_project_base_directory(),
+        "conf",
+        "private.pem")
     rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
     cipher = Cipher_pkcs1_v1_5.new(rsa_key)
+    return cipher.decrypt(base64.b64decode(
+        line), "Fail to decrypt password!").decode('utf-8')


 def download_img(url):
+    if not url:
+        return ""
     response = requests.get(url)
     return "data:" + \
            response.headers.get('Content-Type', 'image/jpg') + ";" + \
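Aside (not part of the commit): get_base_config reads the base service_conf.yaml and lets a local.service_conf.yaml override it, then returns one top-level key with a default. A simplified, self-contained sketch of that pattern; the shallow dict merge and file names here are assumptions for illustration:

import os
import yaml

def load_service_conf(key, default=None, base_dir="conf", conf_name="service_conf.yaml"):
    # Load the base file, overlay an optional local.<conf_name>, then look up the key.
    config = {}
    for name in (conf_name, f"local.{conf_name}"):
        path = os.path.join(base_dir, name)
        if os.path.exists(path):
            with open(path) as f:
                config.update(yaml.safe_load(f) or {})
    return config.get(key, default) if key is not None else config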
api/utils/api_utils.py
CHANGED
@@ -19,7 +19,7 @@ import time
 from functools import wraps
 from io import BytesIO
 from flask import (
-    Response, jsonify, send_file,make_response,
+    Response, jsonify, send_file, make_response,
     request as flask_request,
 )
 from werkzeug.http import HTTP_STATUS_CODES

@@ -29,7 +29,7 @@ from api.versions import get_rag_version
 from api.settings import RetCode
 from api.settings import (
     REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
-    stat_logger,CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
+    stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
 )
 import requests
 import functools

@@ -40,14 +40,21 @@ from hmac import HMAC
 from urllib.parse import quote, urlencode


+requests.models.complexjson.dumps = functools.partial(
+    json.dumps, cls=CustomJSONEncoder)


 def request(**kwargs):
     sess = requests.Session()
     stream = kwargs.pop('stream', sess.stream)
     timeout = kwargs.pop('timeout', None)
+    kwargs['headers'] = {
+        k.replace(
+            '_',
+            '-').upper(): v for k,
+        v in kwargs.get(
+            'headers',
+            {}).items()}
     prepped = requests.Request(**kwargs).prepare()

     if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:

@@ -59,7 +66,11 @@ def request(**kwargs):
             HTTP_APP_KEY.encode('ascii'),
             prepped.path_url.encode('ascii'),
             prepped.body if kwargs.get('json') else b'',
+            urlencode(
+                sorted(
+                    kwargs['data'].items()),
+                quote_via=quote,
+                safe='-._~').encode('ascii')
             if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'',
         ]), 'sha1').digest()).decode('ascii')

@@ -88,11 +99,12 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
     return max(0, countdown)


+def get_json_result(retcode=RetCode.SUCCESS, retmsg='success',
+                    data=None, job_id=None, meta=None):
     import re
     result_dict = {
         "retcode": retcode,
-        "retmsg":retmsg,
+        "retmsg": retmsg,
         # "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE),
         "data": data,
         "jobId": job_id,

@@ -107,9 +119,17 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id
             response[key] = value
     return jsonify(response)


+def get_data_error_result(retcode=RetCode.DATA_ERROR,
+                          retmsg='Sorry! Data missing!'):
     import re
+    result_dict = {
+        "retcode": retcode,
+        "retmsg": re.sub(
+            r"rag",
+            "seceum",
+            retmsg,
+            flags=re.IGNORECASE)}
     response = {}
     for key, value in result_dict.items():
         if value is None and key != "retcode":

@@ -118,15 +138,17 @@ def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! Data missin
             response[key] = value
     return jsonify(response)


 def server_error_response(e):
     stat_logger.exception(e)
     try:
-        if e.code==401:
+        if e.code == 401:
             return get_json_result(retcode=401, retmsg=repr(e))
-    except:
+    except BaseException:
         pass
     if len(e.args) > 1:
+        return get_json_result(
+            retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
     return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))

@@ -162,10 +184,13 @@ def validate_request(*args, **kwargs):
             if no_arguments or error_arguments:
                 error_string = ""
                 if no_arguments:
+                    error_string += "required argument are missing: {}; ".format(
+                        ",".join(no_arguments))
                 if error_arguments:
+                    error_string += "required argument values: {}".format(
+                        ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
+                return get_json_result(
+                    retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
             return func(*_args, **_kwargs)
         return decorated_function
     return wrapper

@@ -193,7 +218,8 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None):
     return jsonify(response)


+def cors_reponse(retcode=RetCode.SUCCESS,
+                 retmsg='success', data=None, auth=None):
     result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data}
     response_dict = {}
     for key, value in result_dict.items():

@@ -209,4 +235,4 @@ def cors_reponse(retcode=RetCode.SUCCESS, retmsg='success', data=None, auth=None
     response.headers["Access-Control-Allow-Headers"] = "*"
     response.headers["Access-Control-Allow-Headers"] = "*"
     response.headers["Access-Control-Expose-Headers"] = "Authorization"
-    return response
+    return response
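Aside (not part of the commit): the request() hunk above signs outgoing requests with an HMAC over a newline-joined list of fields. A self-contained sketch of that signing scheme; the exact field list, their order, and the HMAC key are assumptions, since the hunk only shows the app key, the request path, the JSON body, and the url-encoded form data:

import base64
import hmac
from urllib.parse import quote, urlencode

def sign_request(secret_key, app_key, path, body=b"", data=None):
    # Join the signed fields with newlines, HMAC-SHA1 them, then base64-encode the digest.
    fields = [
        app_key.encode("ascii"),
        path.encode("ascii"),
        body or b"",
        urlencode(sorted(data.items()), quote_via=quote,
                  safe="-._~").encode("ascii") if data else b"",
    ]
    digest = hmac.new(secret_key.encode("ascii"), b"\n".join(fields), "sha1").digest()
    return base64.b64encode(digest).decode("ascii")

print(sign_request("secret", "app-key", "/v1/ping", data={"a": 1}))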
api/utils/file_utils.py
CHANGED
@@ -29,6 +29,7 @@ from api.db import FileType
 PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
 RAG_BASE = os.getenv("RAG_BASE")

+
 def get_project_base_directory(*args):
     global PROJECT_BASE
     if PROJECT_BASE is None:

@@ -65,7 +66,6 @@ def get_rag_python_directory(*args):
     return get_rag_directory("python", *args)

-
 @cached(cache=LRUCache(maxsize=10))
 def load_json_conf(conf_path):
     if os.path.isabs(conf_path):

@@ -146,10 +146,12 @@ def filename_type(filename):
     if re.match(r".*\.pdf$", filename):
         return FileType.PDF.value

+    if re.match(
+            r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md)$", filename):
         return FileType.DOC.value

+    if re.match(
+            r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
         return FileType.AURAL.value

     if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):

@@ -164,14 +166,16 @@ def thumbnail(filename, blob):
         buffered = BytesIO()
         Image.frombytes("RGB", [pix.width, pix.height],
                         pix.samples).save(buffered, format="png")
+        return "data:image/png;base64," + \
+            base64.b64encode(buffered.getvalue()).decode("utf-8")

     if re.match(r".*\.(jpg|jpeg|png|tif|gif|icon|ico|webp)$", filename):
         image = Image.open(BytesIO(blob))
         image.thumbnail((30, 30))
         buffered = BytesIO()
         image.save(buffered, format="png")
+        return "data:image/png;base64," + \
+            base64.b64encode(buffered.getvalue()).decode("utf-8")

     if re.match(r".*\.(ppt|pptx)$", filename):
         import aspose.slides as slides

@@ -179,8 +183,10 @@ def thumbnail(filename, blob):
         try:
             with slides.Presentation(BytesIO(blob)) as presentation:
                 buffered = BytesIO()
+                presentation.slides[0].get_thumbnail(0.03, 0.03).save(
+                    buffered, drawing.imaging.ImageFormat.png)
+                return "data:image/png;base64," + \
+                    base64.b64encode(buffered.getvalue()).decode("utf-8")
         except Exception as e:
             pass

@@ -190,6 +196,3 @@ def traversal_files(base):
         for f in fs:
             fullname = os.path.join(root, f)
             yield fullname
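Aside (not part of the commit): the thumbnail() branches above all follow the same pattern, shrink an image in memory and inline it as a base64 data URI. A self-contained sketch of that pattern with Pillow; the throwaway image and helper name are for illustration only:

import base64
from io import BytesIO
from PIL import Image

def image_data_uri(blob, size=(30, 30)):
    # Shrink the image and inline it as "data:image/png;base64,...".
    image = Image.open(BytesIO(blob))
    image.thumbnail(size)
    buffered = BytesIO()
    image.save(buffered, format="png")
    return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")

# build a throwaway PNG in memory so the sketch is runnable as-is
buf = BytesIO()
Image.new("RGB", (200, 120), "white").save(buf, format="png")
print(image_data_uri(buf.getvalue())[:60], "...")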
api/utils/log_utils.py
CHANGED
@@ -23,6 +23,7 @@ from threading import RLock

 from api.utils import file_utils

+
 class LoggerFactory(object):
     TYPE = "FILE"
     LOG_FORMAT = "[%(levelname)s] [%(asctime)s] [jobId] [%(process)s:%(thread)s] - [%(module)s.%(funcName)s] [line:%(lineno)d]: %(message)s"

@@ -49,7 +50,8 @@ class LoggerFactory(object):
     schedule_logger_dict = {}

     @staticmethod
+    def set_directory(directory=None, parent_log_dir=None,
+                      append_to_parent_log=None, force=False):
         if parent_log_dir:
             LoggerFactory.PARENT_LOG_DIR = parent_log_dir
         if append_to_parent_log:

@@ -66,11 +68,13 @@ class LoggerFactory(object):
         else:
             os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True)
             for loggerName, ghandler in LoggerFactory.global_handler_dict.items():
+                for className, (logger,
+                                handler) in LoggerFactory.logger_dict.items():
                     logger.removeHandler(ghandler)
                 ghandler.close()
             LoggerFactory.global_handler_dict = {}
+            for className, (logger,
+                            handler) in LoggerFactory.logger_dict.items():
                 logger.removeHandler(handler)
                 _handler = None
                 if handler:

@@ -111,19 +115,23 @@ class LoggerFactory(object):
         if logger_name_key not in LoggerFactory.global_handler_dict:
             with LoggerFactory.lock:
                 if logger_name_key not in LoggerFactory.global_handler_dict:
+                    handler = LoggerFactory.get_handler(
+                        logger_name, level, log_dir)
                     LoggerFactory.global_handler_dict[logger_name_key] = handler
         return LoggerFactory.global_handler_dict[logger_name_key]

     @staticmethod
+    def get_handler(class_name, level=None, log_dir=None,
+                    log_type=None, job_id=None):
         if not log_type:
             if not LoggerFactory.LOG_DIR or not class_name:
                 return logging.StreamHandler()
             # return Diy_StreamHandler()

             if not log_dir:
+                log_file = os.path.join(
+                    LoggerFactory.LOG_DIR,
+                    "{}.log".format(class_name))
             else:
                 log_file = os.path.join(log_dir, "{}.log".format(class_name))
         else:

@@ -133,16 +141,16 @@ class LoggerFactory(object):
         os.makedirs(os.path.dirname(log_file), exist_ok=True)
         if LoggerFactory.log_share:
             handler = ROpenHandler(log_file,
+                                   when='D',
+                                   interval=1,
+                                   backupCount=14,
+                                   delay=True)
         else:
             handler = TimedRotatingFileHandler(log_file,
+                                               when='D',
+                                               interval=1,
+                                               backupCount=14,
+                                               delay=True)
         if level:
             handler.level = level

@@ -170,7 +178,9 @@ class LoggerFactory(object):
         for level in LoggerFactory.levels:
             if level >= LoggerFactory.LEVEL:
                 level_logger_name = logging._levelToName[level]
+                logger.addHandler(
+                    LoggerFactory.get_global_handler(
+                        level_logger_name, level))
         if LoggerFactory.append_to_parent_log and LoggerFactory.PARENT_LOG_DIR:
             for level in LoggerFactory.levels:
                 if level >= LoggerFactory.LEVEL:

@@ -224,22 +234,26 @@ def start_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
     return f"{prefix}start to {msg}{suffix}"


+def successful_log(msg, job=None, task=None, role=None,
+                   party_id=None, detail=None):
     prefix, suffix = base_msg(job, task, role, party_id, detail)
     return f"{prefix}{msg} successfully{suffix}"


+def warning_log(msg, job=None, task=None, role=None,
+                party_id=None, detail=None):
     prefix, suffix = base_msg(job, task, role, party_id, detail)
     return f"{prefix}{msg} is not effective{suffix}"


+def failed_log(msg, job=None, task=None, role=None,
+               party_id=None, detail=None):
     prefix, suffix = base_msg(job, task, role, party_id, detail)
     return f"{prefix}failed to {msg}{suffix}"


+def base_msg(job=None, task=None, role: str = None,
+             party_id: typing.Union[str, int] = None, detail=None):
     if detail:
         detail_msg = f" detail: \n{detail}"
     else:

@@ -285,10 +299,14 @@ def get_job_logger(job_id, log_type):
     for job_log_dir in log_dirs:
         handler = LoggerFactory.get_handler(class_name=None, level=LoggerFactory.LEVEL,
                                             log_dir=job_log_dir, log_type=log_type, job_id=job_id)
+        error_handler = LoggerFactory.get_handler(
+            class_name=None,
+            level=logging.ERROR,
+            log_dir=job_log_dir,
+            log_type=log_type,
+            job_id=job_id)
         logger.addHandler(handler)
         logger.addHandler(error_handler)
     with LoggerFactory.lock:
         LoggerFactory.schedule_logger_dict[job_id + log_type] = logger
     return logger
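Aside (not part of the commit): the rotation parameters spelled out in get_handler() above are plain stdlib TimedRotatingFileHandler settings. A minimal, self-contained sketch of the same setup; the log path and logger name are hypothetical:

import logging
import os
from logging.handlers import TimedRotatingFileHandler

log_dir = "logs/api"                     # hypothetical; LoggerFactory derives it from LOG_DIR
os.makedirs(log_dir, exist_ok=True)

handler = TimedRotatingFileHandler(
    os.path.join(log_dir, "api.log"),
    when="D",          # rotate once a day
    interval=1,
    backupCount=14,    # keep two weeks of rotated files
    delay=True)        # open the file lazily, on the first record
handler.setFormatter(logging.Formatter(
    "[%(levelname)s] [%(asctime)s] [%(module)s.%(funcName)s] [line:%(lineno)d]: %(message)s"))

logger = logging.getLogger("api")
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)
logger.info("logging initialised")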
api/utils/t_crypt.py
CHANGED
@@ -1,18 +1,23 @@
-import base64
+import base64
+import os
+import sys
 from Cryptodome.PublicKey import RSA
 from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
 from api.utils import decrypt, file_utils


 def crypt(line):
+    file_path = os.path.join(
+        file_utils.get_project_base_directory(),
+        "conf",
+        "public.pem")
     rsa_key = RSA.importKey(open(file_path).read())
     cipher = Cipher_pkcs1_v1_5.new(rsa_key)
+    return base64.b64encode(cipher.encrypt(
+        line.encode('utf-8'))).decode("utf-8")


 if __name__ == "__main__":
     pswd = crypt(sys.argv[1])
     print(pswd)
     print(decrypt(pswd))
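Aside (not part of the commit): crypt() and decrypt() together form an RSA PKCS#1 v1.5 plus base64 round trip keyed by conf/public.pem and conf/private.pem. A self-contained sketch of the same round trip that generates a throwaway key pair instead of reading those files (the key generation is an assumption for illustration):

import base64
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5

key = RSA.generate(2048)   # stand-in for conf/public.pem + conf/private.pem

encrypted = base64.b64encode(
    Cipher_pkcs1_v1_5.new(key.publickey()).encrypt("my-password".encode("utf-8"))).decode("utf-8")

decrypted = Cipher_pkcs1_v1_5.new(key).decrypt(
    base64.b64decode(encrypted), b"Fail to decrypt password!").decode("utf-8")
assert decrypted == "my-password"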
deepdoc/parser/__init__.py
CHANGED
@@ -4,5 +4,3 @@ from .pdf_parser import HuParser as PdfParser, PlainParser
 from .docx_parser import HuDocxParser as DocxParser
 from .excel_parser import HuExcelParser as ExcelParser
 from .ppt_parser import HuPptParser as PptParser
-
-
deepdoc/parser/docx_parser.py
CHANGED
@@ -99,12 +99,15 @@ class HuDocxParser:
         return ["\n".join(lines)]

     def __call__(self, fnm, from_page=0, to_page=100000):
+        self.doc = Document(fnm) if isinstance(
+            fnm, str) else Document(BytesIO(fnm))
         pn = 0
         secs = []
         for p in self.doc.paragraphs:
+            if pn > to_page:
+                break
+            if from_page <= pn < to_page and p.text.strip():
+                secs.append((p.text, p.style.name))
             for run in p.runs:
                 if 'lastRenderedPageBreak' in run._element.xml:
                     pn += 1
deepdoc/parser/excel_parser.py
CHANGED
@@ -15,13 +15,16 @@ class HuExcelParser:
             ws = wb[sheetname]
             rows = list(ws.rows)
             tb += f"<table><caption>{sheetname}</caption><tr>"
+            for t in list(rows[0]):
+                tb += f"<th>{t.value}</th>"
             tb += "</tr>"
             for r in list(rows[1:]):
                 tb += "<tr>"
+                for i, c in enumerate(r):
+                    if c.value is None:
+                        tb += "<td></td>"
+                    else:
+                        tb += f"<td>{c.value}</td>"
                 tb += "</tr>"
             tb += "</table>\n"
         return tb

@@ -38,13 +41,15 @@ class HuExcelParser:
             ti = list(rows[0])
             for r in list(rows[1:]):
                 l = []
+                for i, c in enumerate(r):
+                    if not c.value:
+                        continue
                     t = str(ti[i].value) if i < len(ti) else ""
                     t += (":" if t else "") + str(c.value)
                     l.append(t)
                 l = "; ".join(l)
+                if sheetname.lower().find("sheet") < 0:
+                    l += " ——" + sheetname
                 res.append(l)
         return res
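Aside (not part of the commit): the second hunk above flattens each spreadsheet row into "header:value; header:value ——SheetName" text. A self-contained sketch of that flattening with openpyxl; the in-memory workbook and its contents are hypothetical:

from openpyxl import Workbook

wb = Workbook()
ws = wb.active
ws.title = "People"
ws.append(["name", "age"])
ws.append(["Alice", 30])

rows = list(ws.rows)
headers = rows[0]
for r in rows[1:]:
    fields = [f"{headers[i].value}:{c.value}" for i, c in enumerate(r) if c.value is not None]
    line = "; ".join(fields)
    if ws.title.lower().find("sheet") < 0:   # only tag rows from meaningfully named sheets
        line += " ——" + ws.title
    print(line)   # -> name:Alice; age:30 ——People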
deepdoc/parser/pdf_parser.py
CHANGED
@@ -43,9 +43,11 @@ class HuParser:
                 "rag/res/deepdoc"),
             local_files_only=True)
         except Exception as e:
+            model_dir = snapshot_download(
+                repo_id="InfiniFlow/text_concat_xgb_v1.0")

+        self.updown_cnt_mdl.load_model(os.path.join(
+            model_dir, "updown_concat_xgb.model"))
         self.page_from = 0
         """
         If you have trouble downloading HuggingFace models, -_^ this might help!!

@@ -72,7 +74,7 @@ class HuParser:
     def _y_dis(
             self, a, b):
         return (
+            b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2

     def _match_proj(self, b):
         proj_patt = [

@@ -95,9 +97,9 @@ class HuParser:
         tks_down = huqie.qie(down["text"][:LEN]).split(" ")
         tks_up = huqie.qie(up["text"][-LEN:]).split(" ")
         tks_all = up["text"][-LEN:].strip() \
+            + (" " if re.match(r"[a-zA-Z0-9]+",
+                               up["text"][-1] + down["text"][0]) else "") \
+            + down["text"][:LEN].strip()
         tks_all = huqie.qie(tks_all).split(" ")
         fea = [
             up.get("R", -1) == down.get("R", -1),

@@ -119,7 +121,7 @@ class HuParser:
             True if re.search(r"[,,][^。.]+$", up["text"]) else False,
             True if re.search(r"[,,][^。.]+$", up["text"]) else False,
             True if re.search(r"[\((][^\))]+$", up["text"])
+            and re.search(r"[\))]", down["text"]) else False,
             self._match_proj(down),
             True if re.match(r"[A-Z]", down["text"]) else False,
             True if re.match(r"[A-Z]", up["text"][-1]) else False,

@@ -181,7 +183,7 @@ class HuParser:
                 continue
             for tb in tbls:  # for table
                 left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \
+                    tb["x1"] + MARGIN, tb["bottom"] + MARGIN
                 left *= ZM
                 top *= ZM
                 right *= ZM

@@ -235,7 +237,8 @@ class HuParser:
                 b["R_top"] = rows[ii]["top"]
                 b["R_bott"] = rows[ii]["bottom"]

+            ii = Recognizer.find_overlapped_with_threashold(
+                b, headers, thr=0.3)
             if ii is not None:
                 b["H_top"] = headers[ii]["top"]
                 b["H_bott"] = headers[ii]["bottom"]

@@ -272,7 +275,8 @@ class HuParser:
         )

         # merge chars in the same rect
+        for c in Recognizer.sort_X_firstly(
+                chars, self.mean_width[pagenum - 1] // 4):
             ii = Recognizer.find_overlapped(c, bxs)
             if ii is None:
                 self.lefted_chars.append(c)

@@ -283,13 +287,15 @@ class HuParser:
                 self.lefted_chars.append(c)
                 continue
             if c["text"] == " " and bxs[ii]["text"]:
+                if re.match(r"[0-9a-zA-Z,.?;:!%%]", bxs[ii]["text"][-1]):
+                    bxs[ii]["text"] += " "
             else:
                 bxs[ii]["text"] += c["text"]

         for b in bxs:
             if not b["text"]:
+                left, right, top, bott = b["x0"] * ZM, b["x1"] * \
+                    ZM, b["top"] * ZM, b["bottom"] * ZM
                 b["text"] = self.ocr.recognize(np.array(img),
                                                np.array([[left, top], [right, top], [right, bott], [left, bott]],
                                                         dtype=np.float32))

@@ -302,7 +308,8 @@ class HuParser:
     def _layouts_rec(self, ZM, drop=True):
         assert len(self.page_images) == len(self.boxes)
+        self.boxes, self.page_layout = self.layouter(
+            self.page_images, self.boxes, ZM, drop=drop)
         # cumlative Y
         for i in range(len(self.boxes)):
             self.boxes[i]["top"] += \

@@ -332,7 +339,8 @@ class HuParser:
                              "equation"]:
                 i += 1
                 continue
+            if abs(self._y_dis(b, b_)
+                   ) < self.mean_height[bxs[i]["page_number"] - 1] / 3:
                 # merge
                 bxs[i]["x1"] = b_["x1"]
                 bxs[i]["top"] = (b["top"] + b_["top"]) / 2

@@ -366,12 +374,15 @@ class HuParser:
         self.boxes = bxs

     def _naive_vertical_merge(self):
+        bxs = Recognizer.sort_Y_firstly(
+            self.boxes, np.median(
+                self.mean_height) / 3)
         i = 0
         while i + 1 < len(bxs):
             b = bxs[i]
             b_ = bxs[i + 1]
+            if b["page_number"] < b_["page_number"] and re.match(
+                    r"[0-9  •一—-]+$", b["text"]):
                 bxs.pop(i)
                 continue
             if not b["text"].strip():

@@ -379,7 +390,8 @@ class HuParser:
                 continue
             concatting_feats = [
                 b["text"].strip()[-1] in ",;:'\",、‘“;:-",
+                len(b["text"].strip()) > 1 and b["text"].strip(
+                )[-2] in ",;:'\",‘“、;:",
                 b["text"].strip()[0] in "。;?!?”)),,、:",
             ]
             # features for not concating

@@ -387,7 +399,7 @@ class HuParser:
                 b.get("layoutno", 0) != b.get("layoutno", 0),
                 b["text"].strip()[-1] in "。?!?",
                 self.is_english and b["text"].strip()[-1] in ".!?",
+                b["page_number"] == b_["page_number"] and b_["top"] -
                 b["bottom"] > self.mean_height[b["page_number"] - 1] * 1.5,
                 b["page_number"] < b_["page_number"] and abs(
                     b["x0"] - b_["x0"]) > self.mean_width[b["page_number"] - 1] * 4,

@@ -396,7 +408,12 @@ class HuParser:
             detach_feats = [b["x1"] < b_["x0"],
                             b["x0"] > b_["x1"]]
             if (any(feats) and not any(concatting_feats)) or any(detach_feats):
+                print(
+                    b["text"],
+                    b_["text"],
+                    any(feats),
+                    any(concatting_feats),
+                    any(detach_feats))
                 i += 1
                 continue
             # merge up and down

@@ -526,31 +543,39 @@ class HuParser:
                 i += 1
                 continue
             findit = True
+            eng = re.match(
+                r"[0-9a-zA-Z :'.-]{5,}",
+                self.boxes[i]["text"].strip())
             self.boxes.pop(i)
+            if i >= len(self.boxes):
+                break
             prefix = self.boxes[i]["text"].strip()[:3] if not eng else " ".join(
                 self.boxes[i]["text"].strip().split(" ")[:2])
             while not prefix:
                 self.boxes.pop(i)
+                if i >= len(self.boxes):
+                    break
                 prefix = self.boxes[i]["text"].strip()[:3] if not eng else " ".join(
                     self.boxes[i]["text"].strip().split(" ")[:2])
             self.boxes.pop(i)
+            if i >= len(self.boxes) or not prefix:
+                break
             for j in range(i, min(i + 128, len(self.boxes))):
                 if not re.match(prefix, self.boxes[j]["text"]):
                     continue
+                for k in range(i, j):
+                    self.boxes.pop(i)
                 break
+            if findit:
+                return

         page_dirty = [0] * len(self.page_images)
         for b in self.boxes:
             if re.search(r"(··|··|··)", b["text"]):
                 page_dirty[b["page_number"] - 1] += 1
         page_dirty = set([i + 1 for i, t in enumerate(page_dirty) if t > 3])
+        if not page_dirty:
+            return
         i = 0
         while i < len(self.boxes):
             if self.boxes[i]["page_number"] in page_dirty:

@@ -582,7 +607,8 @@ class HuParser:
                 b_["top"] = b["top"]
             self.boxes.pop(i)

+    def _extract_table_figure(self, need_image, ZM,
+                              return_html, need_position):
         tables = {}
         figures = {}
         # extract figure and table boxes

@@ -594,7 +620,7 @@ class HuParser:
                 i += 1
                 continue
             lout_no = str(self.boxes[i]["page_number"]) + \
+                "-" + str(self.boxes[i]["layoutno"])
             if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption",
                                                                                                       "title",
                                                                                                       "figure caption",

@@ -761,7 +787,8 @@ class HuParser:
         for k, bxs in tables.items():
             if not bxs:
                 continue
+            bxs = Recognizer.sort_Y_firstly(bxs, np.mean(
+                [(b["bottom"] - b["top"]) / 2 for b in bxs]))
             poss = []
             res.append((cropout(bxs, "table", poss),
                         self.tbl_det.construct_table(bxs, html=return_html, is_english=self.is_english)))

@@ -769,7 +796,8 @@ class HuParser:
         assert len(positions) == len(res)

+        if need_position:
+            return list(zip(res, positions))
         return res

     def proj_match(self, line):

@@ -873,7 +901,8 @@ class HuParser:
                 boxes.pop(0)
             mw = np.mean(widths)
             if mj or mw / pw >= 0.35 or mw > 200:
+                res.append(
+                    "\n".join([c["text"] + self._line_tag(c, ZM) for c in lines]))
             else:
                 logging.debug("REMOVED: " +
                               "<<".join([c["text"] for c in lines]))

@@ -883,13 +912,16 @@ class HuParser:
     @staticmethod
     def total_page_number(fnm, binary=None):
         try:
+            pdf = pdfplumber.open(
+                fnm) if not binary else pdfplumber.open(BytesIO(binary))
             return len(pdf.pages)
         except Exception as e:
+            pdf = fitz.open(fnm) if not binary else fitz.open(
+                stream=fnm, filetype="pdf")
             return len(pdf)

+    def __images__(self, fnm, zoomin=3, page_from=0,
+                   page_to=299, callback=None):
         self.lefted_chars = []
         self.mean_height = []
         self.mean_width = []

@@ -899,21 +931,26 @@ class HuParser:
         self.page_layout = []
         self.page_from = page_from
         try:
+            self.pdf = pdfplumber.open(fnm) if isinstance(
+                fnm, str) else pdfplumber.open(BytesIO(fnm))
             self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
                                 enumerate(self.pdf.pages[page_from:page_to])]
             self.page_chars = [[c for c in page.chars if self._has_color(c)] for page in
                                self.pdf.pages[page_from:page_to]]
             self.total_page = len(self.pdf.pages)
         except Exception as e:
+            self.pdf = fitz.open(fnm) if isinstance(
+                fnm, str) else fitz.open(
+                stream=fnm, filetype="pdf")
             self.page_images = []
             self.page_chars = []
             mat = fitz.Matrix(zoomin, zoomin)
             self.total_page = len(self.pdf)
             for i, page in enumerate(self.pdf):
+                if i < page_from:
+                    continue
+                if i >= page_to:
+                    break
                 pix = page.get_pixmap(matrix=mat)
                 img = Image.frombytes("RGB", [pix.width, pix.height],
                                       pix.samples)

@@ -930,7 +967,7 @@ class HuParser:
                     if isinstance(a, dict):
                         self.outlines.append((a["/Title"], depth))
                         continue
+                    dfs(a, depth + 1)
             dfs(outlines, 0)
         except Exception as e:
             logging.warning(f"Outlines exception: {e}")

@@ -940,8 +977,9 @@ class HuParser:
         logging.info("Images converted.")
         self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join(
             random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in
+            range(len(self.page_chars))]
+        if sum([1 if e else 0 for e in self.is_english]) > len(
+                self.page_images) / 2:
             self.is_english = True
         else:
             self.is_english = False

@@ -970,9 +1008,11 @@ class HuParser:
             # self.page_cum_height.append(
             #    np.max([c["bottom"] for c in chars]))
             self.__ocr(i + 1, img, chars, zoomin)
+            if callback:
+                callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")

+        if not self.is_english and not any(
+                [c for c in self.page_chars]) and self.boxes:
             bxes = [b for bxs in self.boxes for b in bxs]
             self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}",
                                         "".join([b["text"] for b in random.choices(bxes, k=min(30, len(bxes)))]))

@@ -989,7 +1029,8 @@ class HuParser:
         self._text_merge()
         self._concat_downward()
         self._filter_forpages()
+        tbls = self._extract_table_figure(
+            need_image, zoomin, return_html, False)
         return self.__filterout_scraps(deepcopy(self.boxes), zoomin), tbls

     def remove_tag(self, txt):

@@ -1003,15 +1044,19 @@ class HuParser:
                 "#").strip("@").split("\t")
             left, right, top, bottom = float(left), float(
                 right), float(top), float(bottom)
+            poss.append(([int(p) - 1 for p in pn.split("-")],
+                         left, right, top, bottom))
         if not poss:
+            if need_position:
+                return None, None
             return

+        max_width = max(
+            np.max([right - left for (_, left, right, _, _) in poss]), 6)
         GAP = 6
         pos = poss[0]
+        poss.insert(0, ([pos[0][0]], pos[1], pos[2], max(
+            0, pos[3] - 120), max(pos[3] - GAP, 0)))
         pos = poss[-1]
         poss.append(([pos[0][-1]], pos[1], pos[2], min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + GAP),
                      min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + 120)))

@@ -1026,7 +1071,7 @@ class HuParser:
                 self.page_images[pns[0]].crop((left * ZM, top * ZM,
                                                right *
                                                ZM, min(
+                                                   bottom, self.page_images[pns[0]].size[1])
                                                ))
             )
             if 0 < ii < len(poss) - 1:

@@ -1047,7 +1092,8 @@ class HuParser:
                     bottom -= self.page_images[pn].size[1]

         if not imgs:
+            if need_position:
+                return None, None
             return
         height = 0
         for img in imgs:

@@ -1076,12 +1122,14 @@ class HuParser:
             pn = bx["page_number"]
             top = bx["top"] - self.page_cum_height[pn - 1]
             bott = bx["bottom"] - self.page_cum_height[pn - 1]
+            poss.append((pn, bx["x0"], bx["x1"], top, min(
+                bott, self.page_images[pn - 1].size[1] / ZM)))
             while bott * ZM > self.page_images[pn - 1].size[1]:
                 bott -= self.page_images[pn - 1].size[1] / ZM
|
1128 |
bott -= self.page_images[pn - 1].size[1] / ZM
|
1129 |
top = 0
|
1130 |
pn += 1
|
1131 |
+
poss.append((pn, bx["x0"], bx["x1"], top, min(
|
1132 |
+
bott, self.page_images[pn - 1].size[1] / ZM)))
|
1133 |
return poss
|
1134 |
|
1135 |
|
|
|
1138 |
self.outlines = []
|
1139 |
lines = []
|
1140 |
try:
|
1141 |
+
self.pdf = pdf2_read(
|
1142 |
+
filename if isinstance(
|
1143 |
+
filename, str) else BytesIO(filename))
|
1144 |
for page in self.pdf.pages[from_page:to_page]:
|
1145 |
lines.extend([t for t in page.extract_text().split("\n")])
|
1146 |
|
1147 |
outlines = self.pdf.outline
|
1148 |
+
|
1149 |
def dfs(arr, depth):
|
1150 |
for a in arr:
|
1151 |
if isinstance(a, dict):
|
|
|
1168 |
def remove_tag(txt):
|
1169 |
raise NotImplementedError
|
1170 |
|
1171 |
+
|
1172 |
if __name__ == "__main__":
|
1173 |
pass
|
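A minimal usage sketch of the parser above (an editorial aid, not part of the commit). total_page_number is the static helper reformatted in this hunk; the file name and the PdfParser export name are assumptions based on how the rag/app modules below import it.

    from deepdoc.parser import PdfParser  # assumed export of the parser class above

    # Count pages either from a path or from raw bytes; mirrors the pdfplumber/PyMuPDF fallback above.
    print(PdfParser.total_page_number("manual.pdf"))
    with open("manual.pdf", "rb") as f:
        print(PdfParser.total_page_number("manual.pdf", binary=f.read()))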
deepdoc/parser/ppt_parser.py
CHANGED
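Before the hunks, a rough usage sketch (not part of the diff). The import path and the optional callback follow the signature shown below; the file name is only an example.

    from deepdoc.parser.ppt_parser import HuPptParser

    parser = HuPptParser()
    slide_texts = parser("slides.pptx", from_page=0, to_page=10)  # one merged text block per slide
    for i, txt in enumerate(slide_texts):
        print(f"--- slide {i} ---\n{txt}")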
@@ -23,7 +23,8 @@ class HuPptParser(object):
            tb = shape.table
            rows = []
            for i in range(1, len(tb.rows)):
                rows.append("; ".join([tb.cell(
                    0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)]))
            return "\n".join(rows)

        if shape.has_text_frame:
@@ -31,9 +32,10 @@ class HuPptParser(object):

        if shape.shape_type == 6:
            texts = []
            for p in sorted(shape.shapes, key=lambda x: (x.top // 10, x.left)):
                t = self.__extract(p)
                if t:
                    texts.append(t)
            return "\n".join(texts)

    def __call__(self, fnm, from_page, to_page, callback=None):
@@ -43,12 +45,16 @@ class HuPptParser(object):
        txts = []
        self.total_page = len(ppt.slides)
        for i, slide in enumerate(ppt.slides):
            if i < from_page:
                continue
            if i >= to_page:
                break
            texts = []
            for shape in sorted(
                    slide.shapes, key=lambda x: (x.top // 10, x.left)):
                txt = self.__extract(shape)
                if txt:
                    texts.append(txt)
            txts.append("\n".join(texts))

        return txts
deepdoc/vision/layout_recognizer.py
CHANGED
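A short sketch of how the recognizer below is typically driven (a hedged example, not from the diff): page_images would come from the PDF parser and ocr_boxes would be the per-page OCR output; both variable names are placeholders.

    from deepdoc.vision import LayoutRecognizer

    layouter = LayoutRecognizer("layout")   # "layout" selects the layout model domain
    ocr_boxes, page_layouts = layouter(page_images, ocr_boxes,
                                       scale_factor=3, thr=0.2, batch_size=16, drop=True)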
@@ -24,18 +24,19 @@ from deepdoc.vision import Recognizer

class LayoutRecognizer(Recognizer):
    labels = [
        "_background_",
        "Text",
        "Title",
        "Figure",
        "Figure caption",
        "Table",
        "Table caption",
        "Header",
        "Footer",
        "Reference",
        "Equation",
    ]

    def __init__(self, domain):
        try:
            model_dir = snapshot_download(
@@ -47,10 +48,12 @@ class LayoutRecognizer(Recognizer):
        except Exception as e:
            model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")

        # os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
        super().__init__(self.labels, domain, model_dir)
        self.garbage_layouts = ["footer", "header", "reference"]

    def __call__(self, image_list, ocr_res, scale_factor=3,
                 thr=0.2, batch_size=16, drop=True):
        def __is_garbage(b):
            patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
                    r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
@@ -75,7 +78,8 @@ class LayoutRecognizer(Recognizer):
                    "top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor,
                    "page_number": pn,
                    } for b in lts]
            lts = self.sort_Y_firstly(lts, np.mean(
                [l["bottom"] - l["top"] for l in lts]) / 2)
            lts = self.layouts_cleanup(bxs, lts)
            page_layout.append(lts)
@@ -93,17 +97,20 @@ class LayoutRecognizer(Recognizer):
                continue

            ii = self.find_overlapped_with_threashold(bxs[i], lts_,
                                                      thr=0.4)
            if ii is None:  # belong to nothing
                bxs[i]["layout_type"] = ""
                i += 1
                continue
            lts_[ii]["visited"] = True
            keep_feats = [
                lts_[
                    ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1] * 0.9 / scale_factor,
                lts_[
                    ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1] * 0.1 / scale_factor,
            ]
            if drop and lts_[
                    ii]["type"] in self.garbage_layouts and not any(keep_feats):
                if lts_[ii]["type"] not in garbages:
                    garbages[lts_[ii]["type"]] = []
                garbages[lts_[ii]["type"]].append(bxs[i]["text"])
@@ -111,7 +118,8 @@ class LayoutRecognizer(Recognizer):
                continue

            bxs[i]["layoutno"] = f"{ty}-{ii}"
            bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[
                ii]["type"] != "equation" else "figure"
            i += 1

        for lt in ["footer", "header", "reference", "figure caption",
@@ -120,7 +128,7 @@ class LayoutRecognizer(Recognizer):

        # add box to figure layouts which has not text box
        for i, lt in enumerate(
                [lt for lt in lts if lt["type"] in ["figure", "equation"]]):
            if lt.get("visited"):
                continue
            lt = deepcopy(lt)
@@ -143,6 +151,3 @@ class LayoutRecognizer(Recognizer):

        ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
        return ocr_res, page_layout
deepdoc/vision/operators.py
CHANGED
@@ -63,6 +63,7 @@ class DecodeImage(object):
        data['image'] = img
        return data


class StandardizeImage(object):
    """normalize image
    Args:
@@ -707,4 +708,4 @@ def preprocess(im, preprocess_ops):
    im, im_info = decode_image(im, im_info)
    for operator in preprocess_ops:
        im, im_info = operator(im, im_info)
    return im, im_info
deepdoc/vision/t_ocr.py
CHANGED
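The hunks below mostly reorder imports and rewrap the OCR test script. As a usage note (paths are only examples), it is driven from the repository root roughly like this:

    python deepdoc/vision/t_ocr.py --inputs ./some_docs --output_dir ./ocr_outputs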
@@ -11,12 +11,20 @@
# limitations under the License.
#

from deepdoc.vision.seeit import draw_box
from deepdoc.vision import OCR, init_in_out
import argparse
import numpy as np
import os
import sys
sys.path.insert(
    0,
    os.path.abspath(
        os.path.join(
            os.path.dirname(
                os.path.abspath(__file__)),
            '../../')))


def main(args):
    ocr = OCR()
@@ -26,14 +34,14 @@ def main(args):
        bxs = ocr(np.array(img))
        bxs = [(line[0], line[1][0]) for line in bxs]
        bxs = [{
            "text": t,
            "bbox": [b[0][0], b[0][1], b[1][0], b[-1][1]],
            "type": "ocr",
            "score": 1} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]]
        img = draw_box(images[i], bxs, ["ocr"], 1.)
        img.save(outputs[i], quality=95)
        with open(outputs[i] + ".txt", "w+") as f:
            f.write("\n".join([o["text"] for o in bxs]))


if __name__ == "__main__":
@@ -42,6 +50,6 @@ if __name__ == "__main__":
                        help="Directory where to store images or PDFs, or a file path to a single image or PDF",
                        required=True)
    parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './ocr_outputs'",
                        default="./ocr_outputs")
    args = parser.parse_args()
    main(args)
deepdoc/vision/t_recognizer.py
CHANGED
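As with t_ocr.py, this is a small CLI test script. An example invocation using the options visible in the hunks below (values are illustrative):

    python deepdoc/vision/t_recognizer.py --inputs ./some_docs --output_dir ./layouts_outputs --mode layout --threshold 0.5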
@@ -11,24 +11,35 @@
# limitations under the License.
#

from deepdoc.vision.seeit import draw_box
from deepdoc.vision import Recognizer, LayoutRecognizer, TableStructureRecognizer, OCR, init_in_out
from api.utils.file_utils import get_project_base_directory
import argparse
import os
import sys
import re

import numpy as np

sys.path.insert(
    0,
    os.path.abspath(
        os.path.join(
            os.path.dirname(
                os.path.abspath(__file__)),
            '../../')))


def main(args):
    images, outputs = init_in_out(args)
    if args.mode.lower() == "layout":
        labels = LayoutRecognizer.labels
        detr = Recognizer(
            labels,
            "layout",
            os.path.join(
                get_project_base_directory(),
                "rag/res/deepdoc/"))
    if args.mode.lower() == "tsr":
        labels = TableStructureRecognizer.labels
        detr = TableStructureRecognizer()
@@ -39,7 +50,8 @@ def main(args):
        if args.mode.lower() == "tsr":
            #lyt = [t for t in lyt if t["type"] == "table column"]
            html = get_table_html(images[i], lyt, ocr)
            with open(outputs[i] + ".html", "w+") as f:
                f.write(html)
        lyt = [{
            "type": t["label"],
            "bbox": [t["x0"], t["top"], t["x1"], t["bottom"]],
@@ -58,7 +70,7 @@ def get_table_html(img, tb_cpns, ocr):
         "bottom": b[-1][1],
         "layout_type": "table",
         "page_number": 0} for b, t in boxes if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]],
        np.mean([b[-1][1] - b[0][1] for b, _ in boxes]) / 3
    )

    def gather(kwd, fzy=10, ption=0.6):
@@ -117,7 +129,7 @@ def get_table_html(img, tb_cpns, ocr):
      margin-bottom: 50px;
      border: 1px solid #e1e1e1;
    }

    caption {
      color: #6ac1ca;
      font-size: 20px;
@@ -126,25 +138,25 @@ def get_table_html(img, tb_cpns, ocr):
      font-weight: 600;
      margin-bottom: 10px;
    }

    ._table_1nkzy_11 table {
      width: 100%%;
      border-collapse: collapse;
    }

    th {
      color: #fff;
      background-color: #6ac1ca;
    }

    td:hover {
      background: #c1e8e8;
    }

    tr:nth-child(even) {
      background-color: #f2f2f2;
    }

    ._table_1nkzy_11 th,
    ._table_1nkzy_11 td {
      text-align: center;
@@ -157,7 +169,7 @@ def get_table_html(img, tb_cpns, ocr):
    %s
    </body>
    </html>
    """ % TableStructureRecognizer.construct_table(boxes, html=True)
    return html


@@ -168,7 +180,10 @@ if __name__ == "__main__":
                        required=True)
    parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './layouts_outputs'",
                        default="./layouts_outputs")
    parser.add_argument(
        '--threshold',
        help="A threshold to filter out detections. Default: 0.5",
        default=0.5)
    parser.add_argument('--mode', help="Task mode: layout recognition or table structure recognition", choices=["layout", "tsr"],
                        default="layout")
    args = parser.parse_args()
deepdoc/vision/table_structure_recognizer.py
CHANGED
@@ -44,7 +44,8 @@ class TableStructureRecognizer(Recognizer):
        except Exception as e:
            model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")

        # os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
        super().__init__(self.labels, "tsr", model_dir)

    def __call__(self, images, thr=0.2):
        tbls = super().__call__(images, thr)
@@ -138,7 +139,8 @@ class TableStructureRecognizer(Recognizer):
        i = 0
        while i < len(boxes):
            if TableStructureRecognizer.is_caption(boxes[i]):
                if is_english:
                    cap + " "
                cap += boxes[i]["text"]
                boxes.pop(i)
                i -= 1
@@ -164,7 +166,7 @@ class TableStructureRecognizer(Recognizer):
            lst_r = rows[-1]
            if lst_r[-1].get("R", "") != b.get("R", "") \
                    or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
                        ):  # new row
                btm = b["bottom"]
                b["rn"] += 1
                rows.append([b])
@@ -214,9 +216,9 @@ class TableStructureRecognizer(Recognizer):
                j += 1
                continue
            f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
                 [j - 1][0].get("text")) or j == 0
            ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
                  [j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
            if f and ff:
                j += 1
                continue
@@ -277,9 +279,9 @@ class TableStructureRecognizer(Recognizer):
                i += 1
                continue
            f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
                 [jj][0].get("text")) or i == 0
            ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
                  [jj][0].get("text")) or i + 1 >= len(tbl)
            if f and ff:
                i += 1
                continue
@@ -366,7 +368,8 @@ class TableStructureRecognizer(Recognizer):
                continue
            txt = ""
            if arr:
                h = min(np.min([c["bottom"] - c["top"]
                                for c in arr]) / 2, 10)
                txt = " ".join([c["text"]
                                for c in Recognizer.sort_Y_firstly(arr, h)])
            txts.append(txt)
@@ -438,8 +441,8 @@ class TableStructureRecognizer(Recognizer):
                          else "") + headers[j - 1][k]
                else:
                    headers[j][k] = headers[j - 1][k] \
                        + (de if headers[j - 1][k] else "") \
                        + headers[j][k]

        logging.debug(
            f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
rag/app/book.py
CHANGED
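Before the hunks, a minimal sketch of calling this chunker directly; it mirrors the __main__ block at the bottom of the file, and the file name plus parser_config values are only examples.

    from rag.app import book

    def progress(prog=None, msg=""):
        print(prog, msg)

    chunks = book.chunk("some_book.pdf", from_page=0, to_page=10, lang="Chinese",
                        callback=progress, parser_config={"layout_recognize": True})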
@@ -48,10 +48,12 @@ class Pdf(PdfParser):

        callback(0.8, "Text extraction finished")

        return [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", ""))
                for b in self.boxes], tbls


def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, **kwargs):
    """
    Supported file formats are docx, pdf, txt.
    Since a book is long and not all the parts are useful, if it's a PDF,
@@ -63,48 +65,63 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
    }
    doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
    pdf_parser = None
    sections, tbls = [], []
    if re.search(r"\.docx?$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        doc_parser = DocxParser()
        # TODO: table of contents need to be removed
        sections, tbls = doc_parser(
            binary if binary else filename, from_page=from_page, to_page=to_page)
        remove_contents_table(sections, eng=is_english(
            random_choices([t for t, _ in sections], k=200)))
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
        pdf_parser = Pdf() if kwargs.get(
            "parser_config", {}).get(
            "layout_recognize", True) else PlainParser()
        sections, tbls = pdf_parser(filename if not binary else binary,
                                    from_page=from_page, to_page=to_page, callback=callback)

    elif re.search(r"\.txt$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = ""
        if binary:
            txt = binary.decode("utf-8")
        else:
            with open(filename, "r") as f:
                while True:
                    l = f.readline()
                    if not l:
                        break
                    txt += l
        sections = txt.split("\n")
        sections = [(l, "") for l in sections if l]
        remove_contents_table(sections, eng=is_english(
            random_choices([t for t, _ in sections], k=200)))
        callback(0.8, "Finish parsing.")

    else:
        raise NotImplementedError(
            "file type not supported yet(docx, pdf, txt supported)")

    make_colon_as_title(sections)
    bull = bullets_category(
        [t for t in random_choices([t for t, _ in sections], k=100)])
    if bull >= 0:
        chunks = ["\n".join(ck)
                  for ck in hierarchical_merge(bull, sections, 3)]
    else:
        sections = [s.split("@") for s, _ in sections]
        sections = [(pr[0], "@" + pr[1]) for pr in sections if len(pr) == 2]
        chunks = naive_merge(
            sections, kwargs.get(
                "chunk_token_num", 256), kwargs.get(
                "delimer", "\n。;!?"))

    # is it English
    # is_english(random_choices([t for t, _ in sections], k=218))
    eng = lang.lower() == "english"

    res = tokenize_table(tbls, doc, eng)
    res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
@@ -114,6 +131,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca

if __name__ == "__main__":
    import sys

    def dummy(prog=None, msg=""):
        pass
    chunk(sys.argv[1], from_page=1, to_page=10, callback=dummy)
rag/app/laws.py
CHANGED
@@ -35,8 +35,10 @@ class Docx(DocxParser):
        pn = 0
        lines = []
        for p in self.doc.paragraphs:
            if pn > to_page:
                break
            if from_page <= pn < to_page and p.text.strip():
                lines.append(self.__clean(p.text))
            for run in p.runs:
                if 'lastRenderedPageBreak' in run._element.xml:
                    pn += 1
@@ -63,15 +65,18 @@ class Pdf(PdfParser):
        start = timer()
        self._layouts_rec(zoomin)
        callback(0.67, "Layout analysis finished")
        cron_logger.info("paddle layouts:".format(
            (timer() - start) / (self.total_page + 0.1)))
        self._naive_vertical_merge()

        callback(0.8, "Text extraction finished")

        return [(b["text"], self._line_tag(b, zoomin))
                for b in self.boxes], None


def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, **kwargs):
    """
    Supported file formats are docx, pdf, txt.
    """
@@ -89,41 +94,50 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
        pdf_parser = Pdf() if kwargs.get(
            "parser_config", {}).get(
            "layout_recognize", True) else PlainParser()
        for txt, poss in pdf_parser(filename if not binary else binary,
                                    from_page=from_page, to_page=to_page, callback=callback)[0]:
            sections.append(txt + poss)

    elif re.search(r"\.txt$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = ""
        if binary:
            txt = binary.decode("utf-8")
        else:
            with open(filename, "r") as f:
                while True:
                    l = f.readline()
                    if not l:
                        break
                    txt += l
        sections = txt.split("\n")
        sections = [l for l in sections if l]
        callback(0.8, "Finish parsing.")
    else:
        raise NotImplementedError(
            "file type not supported yet(docx, pdf, txt supported)")

    # is it English
    eng = lang.lower() == "english"  # is_english(sections)
    # Remove 'Contents' part
    remove_contents_table(sections, eng)

    make_colon_as_title(sections)
    bull = bullets_category(sections)
    chunks = hierarchical_merge(bull, sections, 3)
    if not chunks:
        callback(0.99, "No chunk parsed out.")

    return tokenize_chunks(["\n".join(ck)
                            for ck in chunks], doc, eng, pdf_parser)


if __name__ == "__main__":
    import sys

    def dummy(prog=None, msg=""):
        pass
    chunk(sys.argv[1], callback=dummy)
rag/app/manual.py
CHANGED
@@ -25,10 +25,10 @@ class Pdf(PdfParser):
            callback
        )
        callback(msg="OCR finished.")
        # for bb in self.boxes:
        #     for b in bb:
        #         print(b)
        print("OCR:", timer() - start)

        self._layouts_rec(zoomin)
        callback(0.65, "Layout analysis finished.")
@@ -45,30 +45,35 @@ class Pdf(PdfParser):
        for b in self.boxes:
            b["text"] = re.sub(r"([\t ]|\u3000){2,}", " ", b["text"].strip())

        return [(b["text"], b.get("layout_no", ""), self.get_position(b, zoomin))
                for i, b in enumerate(self.boxes)], tbls


def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, **kwargs):
    """
    Only pdf is supported.
    """
    pdf_parser = None

    if re.search(r"\.pdf$", filename, re.IGNORECASE):
        pdf_parser = Pdf() if kwargs.get(
            "parser_config", {}).get(
            "layout_recognize", True) else PlainParser()
        sections, tbls = pdf_parser(filename if not binary else binary,
                                    from_page=from_page, to_page=to_page, callback=callback)
        if sections and len(sections[0]) < 3:
            sections = [(t, l, [[0] * 5]) for t, l in sections]

    else:
        raise NotImplementedError("file type not supported yet(pdf supported)")
    doc = {
        "docnm_kwd": filename
    }
    doc["title_tks"] = huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"]))
    doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
    # is it English
    eng = lang.lower() == "english"  # pdf_parser.is_english

    # set pivot using the most frequent type of title,
    # then merge between 2 pivot
@@ -79,7 +84,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
    for txt, _, _ in sections:
        for t, lvl in pdf_parser.outlines:
            tks = set([t[i] + t[i + 1] for i in range(len(t) - 1)])
            tks_ = set([txt[i] + txt[i + 1]
                        for i in range(min(len(t), len(txt) - 1))])
            if len(set(tks & tks_)) / max([len(tks), len(tks_), 1]) > 0.8:
                levels.append(lvl)
                break
@@ -87,24 +93,27 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
        levels.append(max_lvl + 1)

    else:
        bull = bullets_category([txt for txt, _, _ in sections])
        most_level, levels = title_frequency(
            bull, [(txt, l) for txt, l, poss in sections])

    assert len(sections) == len(levels)
    sec_ids = []
    sid = 0
    for i, lvl in enumerate(levels):
        if lvl <= most_level and i > 0 and lvl != levels[i - 1]:
            sid += 1
        sec_ids.append(sid)
        # print(lvl, self.boxes[i]["text"], most_level, sid)

    sections = [(txt, sec_ids[i], poss)
                for i, (txt, _, poss) in enumerate(sections)]
    for (img, rows), poss in tbls:
        sections.append((rows if isinstance(rows, str) else rows[0], -1,
                         [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))

    def tag(pn, left, right, top, bottom):
        if pn + left + right + top + bottom == 0:
            return ""
        return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
            .format(pn, left, right, top, bottom)
@@ -112,7 +121,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
    chunks = []
    last_sid = -2
    tk_cnt = 0
    for txt, sec_id, poss in sorted(sections, key=lambda x: (
            x[-1][0][0], x[-1][0][3], x[-1][0][1])):
        poss = "\t".join([tag(*pos) for pos in poss])
        if tk_cnt < 2048 and (sec_id == last_sid or sec_id == -1):
            if chunks:
@@ -121,16 +131,17 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
            continue
        chunks.append(txt + poss)
        tk_cnt = num_tokens_from_string(txt)
        if sec_id > -1:
            last_sid = sec_id

    res = tokenize_table(tbls, doc, eng)
    res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
    return res


if __name__ == "__main__":
    import sys

    def dummy(prog=None, msg=""):
        pass
    chunk(sys.argv[1], callback=dummy)
rag/app/naive.py
CHANGED
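A small sketch (not part of the commit) showing the parser_config keys this module reads, with the same defaults it falls back to below; the file name is an example.

    from rag.app import naive

    def progress(prog=None, msg=""):
        pass

    res = naive.chunk("report.pdf", callback=progress,
                      parser_config={"chunk_token_num": 128,
                                     "delimiter": "\n!?。;!?",
                                     "layout_recognize": True})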
@@ -44,11 +44,14 @@ class Pdf(PdfParser):
        tbls = self._extract_table_figure(True, zoomin, True, True)
        self._naive_vertical_merge()

        cron_logger.info("paddle layouts:".format(
            (timer() - start) / (self.total_page + 0.1)))
        return [(b["text"], self._line_tag(b, zoomin))
                for b in self.boxes], tbls


def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, **kwargs):
    """
    Supported file formats are docx, pdf, excel, txt.
    This method apply the naive ways to chunk files.
@@ -56,8 +59,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
    Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
    """

    eng = lang.lower() == "english"  # is_english(cks)
    parser_config = kwargs.get(
        "parser_config", {
            "chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": True})
    doc = {
        "docnm_kwd": filename,
        "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
@@ -73,9 +78,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
        pdf_parser = Pdf(
        ) if parser_config["layout_recognize"] else PlainParser()
        sections, tbls = pdf_parser(filename if not binary else binary,
                                    from_page=from_page, to_page=to_page, callback=callback)
        res = tokenize_table(tbls, doc, eng)

    elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):
@@ -92,16 +98,21 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
        with open(filename, "r") as f:
            while True:
                l = f.readline()
                if not l:
                    break
                txt += l
        sections = txt.split("\n")
        sections = [(l, "") for l in sections if l]
        callback(0.8, "Finish parsing.")

    else:
        raise NotImplementedError(
            "file type not supported yet(docx, pdf, txt supported)")

    chunks = naive_merge(
        sections, parser_config.get(
            "chunk_token_num", 128), parser_config.get(
            "delimiter", "\n!?。;!?"))

    res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
    return res
@@ -110,9 +121,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if __name__ == "__main__":
    import sys

    def dummy(prog=None, msg=""):
        pass

    chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
rag/app/one.py
CHANGED
@@ -41,20 +41,23 @@ class Pdf(PdfParser):
        tbls = self._extract_table_figure(True, zoomin, True, True)
        self._concat_downward()

        sections = [(b["text"], self.get_position(b, zoomin))
                    for i, b in enumerate(self.boxes)]
        for (img, rows), poss in tbls:
            sections.append((rows if isinstance(rows, str) else rows[0],
                             [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
        return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (
            x[-1][0][0], x[-1][0][3], x[-1][0][1]))], None


def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, **kwargs):
    """
    Supported file formats are docx, pdf, excel, txt.
    One file forms a chunk which maintains original text order.
    """

    eng = lang.lower() == "english"  # is_english(cks)

    if re.search(r"\.docx?$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
@@ -62,8 +65,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
        pdf_parser = Pdf() if kwargs.get(
            "parser_config", {}).get(
            "layout_recognize", True) else PlainParser()
        sections, _ = pdf_parser(
            filename if not binary else binary, to_page=to_page, callback=callback)
        sections = [s for s, _ in sections if s]

    elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):
@@ -80,14 +86,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
        with open(filename, "r") as f:
            while True:
                l = f.readline()
                if not l:
                    break
                txt += l
        sections = txt.split("\n")
        sections = [s for s in sections if s]
        callback(0.8, "Finish parsing.")

    else:
        raise NotImplementedError(
            "file type not supported yet(docx, pdf, txt supported)")

    doc = {
        "docnm_kwd": filename,
@@ -101,9 +109,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if __name__ == "__main__":
    import sys

    def dummy(prog=None, msg=""):
        pass

    chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
rag/app/paper.py
CHANGED
@@ -67,11 +67,11 @@ class Pdf(PdfParser):
|
|
67 |
|
68 |
if from_page > 0:
|
69 |
return {
|
70 |
-
"title":"",
|
71 |
"authors": "",
|
72 |
"abstract": "",
|
73 |
"sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes if
|
74 |
-
|
75 |
"tables": tbls
|
76 |
}
|
77 |
# get title and authors
|
@@ -87,7 +87,8 @@ class Pdf(PdfParser):
|
|
87 |
title = ""
|
88 |
break
|
89 |
for j in range(3):
|
90 |
-
if _begin(self.boxes[i + j]["text"]):
|
|
|
91 |
authors.append(self.boxes[i + j]["text"])
|
92 |
break
|
93 |
break
|
@@ -107,10 +108,15 @@ class Pdf(PdfParser):
|
|
107 |
abstr = txt + self._line_tag(self.boxes[i], zoomin)
|
108 |
i += 1
|
109 |
break
|
110 |
-
if not abstr:
|
|
|
111 |
|
112 |
-
callback(
|
113 |
-
|
|
|
|
|
|
|
|
|
114 |
print(tbls)
|
115 |
|
116 |
return {
|
@@ -118,19 +124,20 @@ class Pdf(PdfParser):
|
|
118 |
"authors": " ".join(authors),
|
119 |
"abstract": abstr,
|
120 |
"sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if
|
121 |
-
|
122 |
"tables": tbls
|
123 |
}
|
124 |
|
125 |
|
126 |
-
def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|
|
127 |
"""
|
128 |
Only pdf is supported.
|
129 |
The abstract of the paper will be sliced as an entire chunk, and will not be sliced partly.
|
130 |
"""
|
131 |
pdf_parser = None
|
132 |
if re.search(r"\.pdf$", filename, re.IGNORECASE):
|
133 |
-
if not kwargs.get("parser_config",{}).get("layout_recognize", True):
|
134 |
pdf_parser = PlainParser()
|
135 |
paper = {
|
136 |
"title": filename,
|
@@ -143,14 +150,15 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
|
143 |
pdf_parser = Pdf()
|
144 |
paper = pdf_parser(filename if not binary else binary,
|
145 |
from_page=from_page, to_page=to_page, callback=callback)
|
146 |
-
else:
|
|
|
147 |
|
148 |
doc = {"docnm_kwd": filename, "authors_tks": huqie.qie(paper["authors"]),
|
149 |
"title_tks": huqie.qie(paper["title"] if paper["title"] else filename)}
|
150 |
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
|
151 |
doc["authors_sm_tks"] = huqie.qieqie(doc["authors_tks"])
|
152 |
# is it English
|
153 |
-
eng = lang.lower() == "english"#pdf_parser.is_english
|
154 |
print("It's English.....", eng)
|
155 |
|
156 |
res = tokenize_table(paper["tables"], doc, eng)
|
@@ -160,7 +168,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
|
160 |
txt = pdf_parser.remove_tag(paper["abstract"])
|
161 |
d["important_kwd"] = ["abstract", "总结", "概括", "summary", "summarize"]
|
162 |
d["important_tks"] = " ".join(d["important_kwd"])
|
163 |
-
d["image"], poss = pdf_parser.crop(
|
|
|
164 |
add_positions(d, poss)
|
165 |
tokenize(d, txt, eng)
|
166 |
res.append(d)
|
@@ -174,7 +183,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
|
174 |
sec_ids = []
|
175 |
sid = 0
|
176 |
for i, lvl in enumerate(levels):
|
177 |
-
if lvl <= most_level and i > 0 and lvl != levels[i-1]:
|
|
|
178 |
sec_ids.append(sid)
|
179 |
print(lvl, sorted_sections[i][0], most_level, sid)
|
180 |
|
@@ -190,6 +200,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
|
190 |
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
|
191 |
return res
|
192 |
|
|
|
193 |
"""
|
194 |
readed = [0] * len(paper["lines"])
|
195 |
# find colon firstly
|
@@ -212,7 +223,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
|
212 |
for k in range(j, i): readed[k] = True
|
213 |
txt = txt[::-1]
|
214 |
if eng:
|
215 |
-
r = re.search(r"(.*?) ([
|
216 |
txt = r.group(1)[::-1] if r else txt[::-1]
|
217 |
else:
|
218 |
r = re.search(r"(.*?) ([。?;!]|$)", txt)
|
@@ -270,6 +281,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
|
270 |
|
271 |
if __name__ == "__main__":
|
272 |
import sys
|
|
|
273 |
def dummy(prog=None, msg=""):
|
274 |
pass
|
275 |
chunk(sys.argv[1], callback=dummy)
|
|
|
67 |
|
68 |
if from_page > 0:
|
69 |
return {
|
70 |
+
"title": "",
|
71 |
"authors": "",
|
72 |
"abstract": "",
|
73 |
"sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes if
|
74 |
+
re.match(r"(text|title)", b.get("layoutno", "text"))],
|
75 |
"tables": tbls
|
76 |
}
|
77 |
# get title and authors
|
|
|
87 |
title = ""
|
88 |
break
|
89 |
for j in range(3):
|
90 |
+
if _begin(self.boxes[i + j]["text"]):
|
91 |
+
break
|
92 |
authors.append(self.boxes[i + j]["text"])
|
93 |
break
|
94 |
break
|
|
|
108 |
abstr = txt + self._line_tag(self.boxes[i], zoomin)
|
109 |
i += 1
|
110 |
break
|
111 |
+
if not abstr:
|
112 |
+
i = 0
|
113 |
|
114 |
+
callback(
|
115 |
+
0.8, "Page {}~{}: Text merging finished".format(
|
116 |
+
from_page, min(
|
117 |
+
to_page, self.total_page)))
|
118 |
+
for b in self.boxes:
|
119 |
+
print(b["text"], b.get("layoutno"))
|
120 |
print(tbls)
|
121 |
|
122 |
return {
|
|
|
124 |
"authors": " ".join(authors),
|
125 |
"abstract": abstr,
|
126 |
"sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if
|
127 |
+
re.match(r"(text|title)", b.get("layoutno", "text"))],
|
128 |
"tables": tbls
|
129 |
}
|
130 |
|
131 |
|
132 |
+
def chunk(filename, binary=None, from_page=0, to_page=100000,
|
133 |
+
lang="Chinese", callback=None, **kwargs):
|
134 |
"""
|
135 |
Only pdf is supported.
|
136 |
The abstract of the paper will be sliced as an entire chunk, and will not be sliced partly.
|
137 |
"""
|
138 |
pdf_parser = None
|
139 |
if re.search(r"\.pdf$", filename, re.IGNORECASE):
|
140 |
+
if not kwargs.get("parser_config", {}).get("layout_recognize", True):
|
141 |
pdf_parser = PlainParser()
|
142 |
paper = {
|
143 |
"title": filename,
|
|
|
150 |
pdf_parser = Pdf()
|
151 |
paper = pdf_parser(filename if not binary else binary,
|
152 |
from_page=from_page, to_page=to_page, callback=callback)
|
153 |
+
else:
|
154 |
+
raise NotImplementedError("file type not supported yet(pdf supported)")
|
155 |
|
156 |
doc = {"docnm_kwd": filename, "authors_tks": huqie.qie(paper["authors"]),
|
157 |
"title_tks": huqie.qie(paper["title"] if paper["title"] else filename)}
|
158 |
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
|
159 |
doc["authors_sm_tks"] = huqie.qieqie(doc["authors_tks"])
|
160 |
# is it English
|
161 |
+
eng = lang.lower() == "english" # pdf_parser.is_english
|
162 |
print("It's English.....", eng)
|
163 |
|
164 |
res = tokenize_table(paper["tables"], doc, eng)
|
|
|
168 |
txt = pdf_parser.remove_tag(paper["abstract"])
|
169 |
d["important_kwd"] = ["abstract", "总结", "概括", "summary", "summarize"]
|
170 |
d["important_tks"] = " ".join(d["important_kwd"])
|
171 |
+
d["image"], poss = pdf_parser.crop(
|
172 |
+
paper["abstract"], need_position=True)
|
173 |
add_positions(d, poss)
|
174 |
tokenize(d, txt, eng)
|
175 |
res.append(d)
|
|
|
183 |
sec_ids = []
|
184 |
sid = 0
|
185 |
for i, lvl in enumerate(levels):
|
186 |
+
if lvl <= most_level and i > 0 and lvl != levels[i - 1]:
|
187 |
+
sid += 1
|
188 |
sec_ids.append(sid)
|
189 |
print(lvl, sorted_sections[i][0], most_level, sid)
|
190 |
|
|
|
200 |
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
|
201 |
return res
|
202 |
|
203 |
+
|
204 |
"""
|
205 |
readed = [0] * len(paper["lines"])
|
206 |
# find colon firstly
|
|
|
223 |
for k in range(j, i): readed[k] = True
|
224 |
txt = txt[::-1]
|
225 |
if eng:
|
226 |
+
r = re.search(r"(.*?) ([\\.;?!]|$)", txt)
|
227 |
txt = r.group(1)[::-1] if r else txt[::-1]
|
228 |
else:
|
229 |
r = re.search(r"(.*?) ([。?;!]|$)", txt)
|
|
|
281 |
|
282 |
if __name__ == "__main__":
|
283 |
import sys
|
284 |
+
|
285 |
def dummy(prog=None, msg=""):
|
286 |
pass
|
287 |
chunk(sys.argv[1], callback=dummy)
|
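Note: a minimal sketch of driving the paper chunker above; the file name, page range, and the parser_config dict are illustrative, and only the "layout_recognize" key is taken from the diff.

from rag.app import paper

def progress(prog=None, msg=""):
    # the chunker reports OCR, layout and merging progress through this hook
    print(prog, msg)

chunks = paper.chunk(
    "example.pdf",                             # only .pdf is supported by this chunker
    from_page=0, to_page=16,
    lang="English",                            # switches the `eng` tokenization branch
    callback=progress,
    parser_config={"layout_recognize": True},  # False falls back to PlainParser
)
for d in chunks[:3]:
    print(d.get("docnm_kwd"), sorted(d.keys()))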
rag/app/presentation.py
CHANGED
@@ -33,9 +33,12 @@ class Ppt(PptParser):
|
|
33 |
with slides.Presentation(BytesIO(fnm)) as presentation:
|
34 |
for i, slide in enumerate(presentation.slides[from_page: to_page]):
|
35 |
buffered = BytesIO()
|
36 |
-
slide.get_thumbnail(
|
|
|
|
|
37 |
imgs.append(Image.open(buffered))
|
38 |
-
assert len(imgs) == len(
|
|
|
39 |
callback(0.9, "Image extraction finished")
|
40 |
self.is_english = is_english(txts)
|
41 |
return [(txts[i], imgs[i]) for i in range(len(txts))]
|
@@ -47,25 +50,34 @@ class Pdf(PdfParser):
|
|
47 |
|
48 |
def __garbage(self, txt):
|
49 |
txt = txt.lower().strip()
|
50 |
-
if re.match(r"[0-9\.,%/-]+$", txt):
|
51 |
-
|
|
|
|
|
52 |
return False
|
53 |
|
54 |
-
def __call__(self, filename, binary=None, from_page=0,
|
|
|
55 |
callback(msg="OCR is running...")
|
56 |
-
self.__images__(filename if not binary else binary,
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
59 |
res = []
|
60 |
for i in range(len(self.boxes)):
|
61 |
-
lines = "\n".join([b["text"] for b in self.boxes[i]
|
|
|
62 |
res.append((lines, self.page_images[i]))
|
63 |
-
callback(0.9, "Page {}~{}: Parsing finished".format(
|
|
|
64 |
return res
|
65 |
|
66 |
|
67 |
class PlainPdf(PlainParser):
|
68 |
-
def __call__(self, filename, binary=None, from_page=0,
|
|
|
69 |
self.pdf = pdf2_read(filename if not binary else BytesIO(binary))
|
70 |
page_txt = []
|
71 |
for page in self.pdf.pages[from_page: to_page]:
|
@@ -74,7 +86,8 @@ class PlainPdf(PlainParser):
|
|
74 |
return [(txt, None) for txt in page_txt]
|
75 |
|
76 |
|
77 |
-
def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|
|
78 |
"""
|
79 |
The supported file formats are pdf, pptx.
|
80 |
Every page will be treated as a chunk. And the thumbnail of every page will be stored.
|
@@ -89,35 +102,42 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
|
89 |
res = []
|
90 |
if re.search(r"\.pptx?$", filename, re.IGNORECASE):
|
91 |
ppt_parser = Ppt()
|
92 |
-
for pn, (txt,img) in enumerate(ppt_parser(
|
|
|
93 |
d = copy.deepcopy(doc)
|
94 |
pn += from_page
|
95 |
d["image"] = img
|
96 |
-
d["page_num_int"] = [pn+1]
|
97 |
d["top_int"] = [0]
|
98 |
d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
|
99 |
tokenize(d, txt, eng)
|
100 |
res.append(d)
|
101 |
return res
|
102 |
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
103 |
-
pdf_parser = Pdf() if kwargs.get(
|
104 |
-
|
|
|
|
|
|
|
105 |
d = copy.deepcopy(doc)
|
106 |
pn += from_page
|
107 |
-
if img:
|
108 |
-
|
|
|
109 |
d["top_int"] = [0]
|
110 |
-
d["position_int"] = [
|
|
|
111 |
tokenize(d, txt, eng)
|
112 |
res.append(d)
|
113 |
return res
|
114 |
|
115 |
-
raise NotImplementedError(
|
|
|
116 |
|
117 |
|
118 |
-
if __name__== "__main__":
|
119 |
import sys
|
|
|
120 |
def dummy(a, b):
|
121 |
pass
|
122 |
chunk(sys.argv[1], callback=dummy)
|
123 |
-
|
|
|
33 |
with slides.Presentation(BytesIO(fnm)) as presentation:
|
34 |
for i, slide in enumerate(presentation.slides[from_page: to_page]):
|
35 |
buffered = BytesIO()
|
36 |
+
slide.get_thumbnail(
|
37 |
+
0.5, 0.5).save(
|
38 |
+
buffered, drawing.imaging.ImageFormat.jpeg)
|
39 |
imgs.append(Image.open(buffered))
|
40 |
+
assert len(imgs) == len(
|
41 |
+
txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
|
42 |
callback(0.9, "Image extraction finished")
|
43 |
self.is_english = is_english(txts)
|
44 |
return [(txts[i], imgs[i]) for i in range(len(txts))]
|
|
|
50 |
|
51 |
def __garbage(self, txt):
|
52 |
txt = txt.lower().strip()
|
53 |
+
if re.match(r"[0-9\.,%/-]+$", txt):
|
54 |
+
return True
|
55 |
+
if len(txt) < 3:
|
56 |
+
return True
|
57 |
return False
|
58 |
|
59 |
+
def __call__(self, filename, binary=None, from_page=0,
|
60 |
+
to_page=100000, zoomin=3, callback=None):
|
61 |
callback(msg="OCR is running...")
|
62 |
+
self.__images__(filename if not binary else binary,
|
63 |
+
zoomin, from_page, to_page, callback)
|
64 |
+
callback(0.8, "Page {}~{}: OCR finished".format(
|
65 |
+
from_page, min(to_page, self.total_page)))
|
66 |
+
assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(
|
67 |
+
len(self.boxes), len(self.page_images))
|
68 |
res = []
|
69 |
for i in range(len(self.boxes)):
|
70 |
+
lines = "\n".join([b["text"] for b in self.boxes[i]
|
71 |
+
if not self.__garbage(b["text"])])
|
72 |
res.append((lines, self.page_images[i]))
|
73 |
+
callback(0.9, "Page {}~{}: Parsing finished".format(
|
74 |
+
from_page, min(to_page, self.total_page)))
|
75 |
return res
|
76 |
|
77 |
|
78 |
class PlainPdf(PlainParser):
|
79 |
+
def __call__(self, filename, binary=None, from_page=0,
|
80 |
+
to_page=100000, callback=None, **kwargs):
|
81 |
self.pdf = pdf2_read(filename if not binary else BytesIO(binary))
|
82 |
page_txt = []
|
83 |
for page in self.pdf.pages[from_page: to_page]:
|
|
|
86 |
return [(txt, None) for txt in page_txt]
|
87 |
|
88 |
|
89 |
+
def chunk(filename, binary=None, from_page=0, to_page=100000,
|
90 |
+
lang="Chinese", callback=None, **kwargs):
|
91 |
"""
|
92 |
The supported file formats are pdf, pptx.
|
93 |
Every page will be treated as a chunk. And the thumbnail of every page will be stored.
|
|
|
102 |
res = []
|
103 |
if re.search(r"\.pptx?$", filename, re.IGNORECASE):
|
104 |
ppt_parser = Ppt()
|
105 |
+
for pn, (txt, img) in enumerate(ppt_parser(
|
106 |
+
filename if not binary else binary, from_page, 1000000, callback)):
|
107 |
d = copy.deepcopy(doc)
|
108 |
pn += from_page
|
109 |
d["image"] = img
|
110 |
+
d["page_num_int"] = [pn + 1]
|
111 |
d["top_int"] = [0]
|
112 |
d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
|
113 |
tokenize(d, txt, eng)
|
114 |
res.append(d)
|
115 |
return res
|
116 |
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
117 |
+
pdf_parser = Pdf() if kwargs.get(
|
118 |
+
"parser_config", {}).get(
|
119 |
+
"layout_recognize", True) else PlainPdf()
|
120 |
+
for pn, (txt, img) in enumerate(pdf_parser(filename, binary,
|
121 |
+
from_page=from_page, to_page=to_page, callback=callback)):
|
122 |
d = copy.deepcopy(doc)
|
123 |
pn += from_page
|
124 |
+
if img:
|
125 |
+
d["image"] = img
|
126 |
+
d["page_num_int"] = [pn + 1]
|
127 |
d["top_int"] = [0]
|
128 |
+
d["position_int"] = [
|
129 |
+
(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)]
|
130 |
tokenize(d, txt, eng)
|
131 |
res.append(d)
|
132 |
return res
|
133 |
|
134 |
+
raise NotImplementedError(
|
135 |
+
"file type not supported yet(pptx, pdf supported)")
|
136 |
|
137 |
|
138 |
+
if __name__ == "__main__":
|
139 |
import sys
|
140 |
+
|
141 |
def dummy(a, b):
|
142 |
pass
|
143 |
chunk(sys.argv[1], callback=dummy)
|
|
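Note: a sketch of what one chunk from rag/app/presentation.py carries; every page or slide becomes a chunk, and the keys below mirror the ones set in the diff, while the file name and callback are placeholders.

from rag.app import presentation

def progress(prog=None, msg=""):
    print(prog, msg)

res = presentation.chunk("slides.pptx", lang="Chinese", callback=progress)
d = res[0]
# position_int encodes (page, left, right, top, bottom) for the stored thumbnail
print(d["page_num_int"], d["top_int"], d["position_int"])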
rag/app/resume.py
CHANGED
@@ -27,6 +27,8 @@ from rag.utils import rmSpace
|
|
27 |
forbidden_select_fields4resume = [
|
28 |
"name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd"
|
29 |
]
|
|
|
|
|
30 |
def remote_call(filename, binary):
|
31 |
q = {
|
32 |
"header": {
|
@@ -48,18 +50,22 @@ def remote_call(filename, binary):
|
|
48 |
}
|
49 |
for _ in range(3):
|
50 |
try:
|
51 |
-
resume = requests.post(
|
|
|
|
|
52 |
resume = resume.json()["response"]["results"]
|
53 |
resume = refactor(resume)
|
54 |
-
for k in ["education", "work", "project",
|
55 |
-
|
|
|
|
|
56 |
|
57 |
resume = step_one.refactor(pd.DataFrame([{"resume_content": json.dumps(resume), "tob_resume_id": "x",
|
58 |
-
|
59 |
resume = step_two.parse(resume)
|
60 |
return resume
|
61 |
except Exception as e:
|
62 |
-
cron_logger.error("Resume parser error: "+str(e))
|
63 |
return {}
|
64 |
|
65 |
|
@@ -144,10 +150,13 @@ def chunk(filename, binary=None, callback=None, **kwargs):
|
|
144 |
doc["content_ltks"] = huqie.qie(doc["content_with_weight"])
|
145 |
doc["content_sm_ltks"] = huqie.qieqie(doc["content_ltks"])
|
146 |
for n, _ in field_map.items():
|
147 |
-
if n not in resume:
|
148 |
-
|
|
|
|
|
149 |
resume[n] = resume[n][0]
|
150 |
-
if n.find("_tks")>0:
|
|
|
151 |
doc[n] = resume[n]
|
152 |
|
153 |
print(doc)
|
|
|
27 |
forbidden_select_fields4resume = [
|
28 |
"name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd"
|
29 |
]
|
30 |
+
|
31 |
+
|
32 |
def remote_call(filename, binary):
|
33 |
q = {
|
34 |
"header": {
|
|
|
50 |
}
|
51 |
for _ in range(3):
|
52 |
try:
|
53 |
+
resume = requests.post(
|
54 |
+
"http://127.0.0.1:61670/tog",
|
55 |
+
data=json.dumps(q))
|
56 |
resume = resume.json()["response"]["results"]
|
57 |
resume = refactor(resume)
|
58 |
+
for k in ["education", "work", "project",
|
59 |
+
"training", "skill", "certificate", "language"]:
|
60 |
+
if not resume.get(k) and k in resume:
|
61 |
+
del resume[k]
|
62 |
|
63 |
resume = step_one.refactor(pd.DataFrame([{"resume_content": json.dumps(resume), "tob_resume_id": "x",
|
64 |
+
"updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]))
|
65 |
resume = step_two.parse(resume)
|
66 |
return resume
|
67 |
except Exception as e:
|
68 |
+
cron_logger.error("Resume parser error: " + str(e))
|
69 |
return {}
|
70 |
|
71 |
|
|
|
150 |
doc["content_ltks"] = huqie.qie(doc["content_with_weight"])
|
151 |
doc["content_sm_ltks"] = huqie.qieqie(doc["content_ltks"])
|
152 |
for n, _ in field_map.items():
|
153 |
+
if n not in resume:
|
154 |
+
continue
|
155 |
+
if isinstance(resume[n], list) and (
|
156 |
+
len(resume[n]) == 1 or n not in forbidden_select_fields4resume):
|
157 |
resume[n] = resume[n][0]
|
158 |
+
if n.find("_tks") > 0:
|
159 |
+
resume[n] = huqie.qieqie(resume[n])
|
160 |
doc[n] = resume[n]
|
161 |
|
162 |
print(doc)
|
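Note: a self-contained sketch of the field-mapping rule added above in rag/app/resume.py: list-valued fields collapse to their first element unless they appear in forbidden_select_fields4resume (the *_tks re-tokenization step is omitted here). The sample resume dict and the shortened forbidden list are invented for illustration.

forbidden = ["degree_kwd", "sch_rank_kwd"]     # subset of forbidden_select_fields4resume
resume = {"degree_kwd": ["master", "bachelor"], "corp_nm_kwd": ["ACME"]}
doc = {}
for n in ("degree_kwd", "corp_nm_kwd"):
    if n not in resume:
        continue
    if isinstance(resume[n], list) and (len(resume[n]) == 1 or n not in forbidden):
        resume[n] = resume[n][0]               # keep only the first value
    doc[n] = resume[n]
print(doc)  # {'degree_kwd': ['master', 'bachelor'], 'corp_nm_kwd': 'ACME'}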
rag/app/table.py
CHANGED
@@ -25,7 +25,8 @@ from deepdoc.parser import ExcelParser
|
|
25 |
|
26 |
|
27 |
class Excel(ExcelParser):
|
28 |
-
def __call__(self, fnm, binary=None, from_page=0,
|
|
|
29 |
if not binary:
|
30 |
wb = load_workbook(fnm)
|
31 |
else:
|
@@ -48,8 +49,10 @@ class Excel(ExcelParser):
|
|
48 |
data = []
|
49 |
for i, r in enumerate(rows[1:]):
|
50 |
rn += 1
|
51 |
-
if rn-1 < from_page:
|
52 |
-
|
|
|
|
|
53 |
row = [
|
54 |
cell.value for ii,
|
55 |
cell in enumerate(r) if ii not in missed]
|
@@ -60,7 +63,7 @@ class Excel(ExcelParser):
|
|
60 |
done += 1
|
61 |
res.append(pd.DataFrame(np.array(data), columns=headers))
|
62 |
|
63 |
-
callback(0.3, ("Extract records: {}~{}".format(from_page+1, min(to_page, from_page+rn)) + (
|
64 |
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
65 |
return res
|
66 |
|
@@ -73,7 +76,8 @@ def trans_datatime(s):
|
|
73 |
|
74 |
|
75 |
def trans_bool(s):
|
76 |
-
if re.match(r"(true|yes|是|\*|✓|✔|☑|✅|√)$",
|
|
|
77 |
return "yes"
|
78 |
if re.match(r"(false|no|否|⍻|×)$", str(s).strip(), flags=re.IGNORECASE):
|
79 |
return "no"
|
@@ -107,13 +111,14 @@ def column_data_type(arr):
|
|
107 |
arr[i] = trans[ty](str(arr[i]))
|
108 |
except Exception as e:
|
109 |
arr[i] = None
|
110 |
-
#if ty == "text":
|
111 |
# if len(arr) > 128 and uni / len(arr) < 0.1:
|
112 |
# ty = "keyword"
|
113 |
return arr, ty
|
114 |
|
115 |
|
116 |
-
def chunk(filename, binary=None, from_page=0, to_page=10000000000,
|
|
|
117 |
"""
|
118 |
Excel and csv(txt) format files are supported.
|
119 |
For csv or txt file, the delimiter between columns is TAB.
|
@@ -131,7 +136,12 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
|
|
131 |
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
|
132 |
callback(0.1, "Start to parse.")
|
133 |
excel_parser = Excel()
|
134 |
-
dfs = excel_parser(
|
|
|
|
|
|
|
|
|
|
|
135 |
elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
|
136 |
callback(0.1, "Start to parse.")
|
137 |
txt = ""
|
@@ -149,8 +159,10 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
|
|
149 |
headers = lines[0].split(kwargs.get("delimiter", "\t"))
|
150 |
rows = []
|
151 |
for i, line in enumerate(lines[1:]):
|
152 |
-
if i < from_page:
|
153 |
-
|
|
|
|
|
154 |
row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
|
155 |
if len(row) != len(headers):
|
156 |
fails.append(str(i))
|
@@ -181,7 +193,13 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
|
|
181 |
del df[n]
|
182 |
clmns = df.columns.values
|
183 |
txts = list(copy.deepcopy(clmns))
|
184 |
-
py_clmns = [
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
clmn_tys = []
|
186 |
for j in range(len(clmns)):
|
187 |
cln, ty = column_data_type(df[clmns[j]])
|
@@ -192,7 +210,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
|
|
192 |
clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], clmns[i].replace("_", " "))
|
193 |
for i in range(len(clmns))]
|
194 |
|
195 |
-
eng = lang.lower() == "english"#is_english(txts)
|
196 |
for ii, row in df.iterrows():
|
197 |
d = {
|
198 |
"docnm_kwd": filename,
|
|
|
25 |
|
26 |
|
27 |
class Excel(ExcelParser):
|
28 |
+
def __call__(self, fnm, binary=None, from_page=0,
|
29 |
+
to_page=10000000000, callback=None):
|
30 |
if not binary:
|
31 |
wb = load_workbook(fnm)
|
32 |
else:
|
|
|
49 |
data = []
|
50 |
for i, r in enumerate(rows[1:]):
|
51 |
rn += 1
|
52 |
+
if rn - 1 < from_page:
|
53 |
+
continue
|
54 |
+
if rn - 1 >= to_page:
|
55 |
+
break
|
56 |
row = [
|
57 |
cell.value for ii,
|
58 |
cell in enumerate(r) if ii not in missed]
|
|
|
63 |
done += 1
|
64 |
res.append(pd.DataFrame(np.array(data), columns=headers))
|
65 |
|
66 |
+
callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + (
|
67 |
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
68 |
return res
|
69 |
|
|
|
76 |
|
77 |
|
78 |
def trans_bool(s):
|
79 |
+
if re.match(r"(true|yes|是|\*|✓|✔|☑|✅|√)$",
|
80 |
+
str(s).strip(), flags=re.IGNORECASE):
|
81 |
return "yes"
|
82 |
if re.match(r"(false|no|否|⍻|×)$", str(s).strip(), flags=re.IGNORECASE):
|
83 |
return "no"
|
|
|
111 |
arr[i] = trans[ty](str(arr[i]))
|
112 |
except Exception as e:
|
113 |
arr[i] = None
|
114 |
+
# if ty == "text":
|
115 |
# if len(arr) > 128 and uni / len(arr) < 0.1:
|
116 |
# ty = "keyword"
|
117 |
return arr, ty
|
118 |
|
119 |
|
120 |
+
def chunk(filename, binary=None, from_page=0, to_page=10000000000,
|
121 |
+
lang="Chinese", callback=None, **kwargs):
|
122 |
"""
|
123 |
Excel and csv(txt) format files are supported.
|
124 |
For csv or txt file, the delimiter between columns is TAB.
|
|
|
136 |
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
|
137 |
callback(0.1, "Start to parse.")
|
138 |
excel_parser = Excel()
|
139 |
+
dfs = excel_parser(
|
140 |
+
filename,
|
141 |
+
binary,
|
142 |
+
from_page=from_page,
|
143 |
+
to_page=to_page,
|
144 |
+
callback=callback)
|
145 |
elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
|
146 |
callback(0.1, "Start to parse.")
|
147 |
txt = ""
|
|
|
159 |
headers = lines[0].split(kwargs.get("delimiter", "\t"))
|
160 |
rows = []
|
161 |
for i, line in enumerate(lines[1:]):
|
162 |
+
if i < from_page:
|
163 |
+
continue
|
164 |
+
if i >= to_page:
|
165 |
+
break
|
166 |
row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
|
167 |
if len(row) != len(headers):
|
168 |
fails.append(str(i))
|
|
|
193 |
del df[n]
|
194 |
clmns = df.columns.values
|
195 |
txts = list(copy.deepcopy(clmns))
|
196 |
+
py_clmns = [
|
197 |
+
PY.get_pinyins(
|
198 |
+
re.sub(
|
199 |
+
r"(/.*|([^()]+?)|\([^()]+?\))",
|
200 |
+
"",
|
201 |
+
n),
|
202 |
+
'_')[0] for n in clmns]
|
203 |
clmn_tys = []
|
204 |
for j in range(len(clmns)):
|
205 |
cln, ty = column_data_type(df[clmns[j]])
|
|
|
210 |
clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], clmns[i].replace("_", " "))
|
211 |
for i in range(len(clmns))]
|
212 |
|
213 |
+
eng = lang.lower() == "english" # is_english(txts)
|
214 |
for ii, row in df.iterrows():
|
215 |
d = {
|
216 |
"docnm_kwd": filename,
|
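Note: a sketch of calling the table chunker with the row-paging window introduced in this diff; the file name is a placeholder, and the delimiter kwarg only applies to csv/txt input.

from rag.app import table

def progress(prog=None, msg=""):
    print(prog, msg)

res = table.chunk(
    "orders.xlsx",
    from_page=0, to_page=128,    # keep at most 128 data rows per call
    lang="English",
    callback=progress,
)
print(len(res), "row chunks")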
rag/llm/chat_model.py
CHANGED
@@ -13,6 +13,8 @@
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
|
|
|
|
16 |
from abc import ABC
|
17 |
from openai import OpenAI
|
18 |
import openai
|
@@ -34,7 +36,8 @@ class GptTurbo(Base):
|
|
34 |
self.model_name = model_name
|
35 |
|
36 |
def chat(self, system, history, gen_conf):
|
37 |
-
if system:
|
|
|
38 |
try:
|
39 |
response = self.client.chat.completions.create(
|
40 |
model=self.model_name,
|
@@ -46,16 +49,18 @@ class GptTurbo(Base):
|
|
46 |
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
47 |
return ans, response.usage.completion_tokens
|
48 |
except openai.APIError as e:
|
49 |
-
return "**ERROR**: "+str(e), 0
|
50 |
|
51 |
|
52 |
class MoonshotChat(GptTurbo):
|
53 |
def __init__(self, key, model_name="moonshot-v1-8k"):
|
54 |
-
self.client = OpenAI(
|
|
|
55 |
self.model_name = model_name
|
56 |
|
57 |
def chat(self, system, history, gen_conf):
|
58 |
-
if system:
|
|
|
59 |
try:
|
60 |
response = self.client.chat.completions.create(
|
61 |
model=self.model_name,
|
@@ -67,10 +72,9 @@ class MoonshotChat(GptTurbo):
|
|
67 |
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
68 |
return ans, response.usage.completion_tokens
|
69 |
except openai.APIError as e:
|
70 |
-
return "**ERROR**: "+str(e), 0
|
71 |
|
72 |
|
73 |
-
from dashscope import Generation
|
74 |
class QWenChat(Base):
|
75 |
def __init__(self, key, model_name=Generation.Models.qwen_turbo):
|
76 |
import dashscope
|
@@ -79,7 +83,8 @@ class QWenChat(Base):
|
|
79 |
|
80 |
def chat(self, system, history, gen_conf):
|
81 |
from http import HTTPStatus
|
82 |
-
if system:
|
|
|
83 |
response = Generation.call(
|
84 |
self.model_name,
|
85 |
messages=history,
|
@@ -92,20 +97,21 @@ class QWenChat(Base):
|
|
92 |
ans += response.output.choices[0]['message']['content']
|
93 |
tk_count += response.usage.output_tokens
|
94 |
if response.output.choices[0].get("finish_reason", "") == "length":
|
95 |
-
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
|
|
96 |
return ans, tk_count
|
97 |
|
98 |
return "**ERROR**: " + response.message, tk_count
|
99 |
|
100 |
|
101 |
-
from zhipuai import ZhipuAI
|
102 |
class ZhipuChat(Base):
|
103 |
def __init__(self, key, model_name="glm-3-turbo"):
|
104 |
self.client = ZhipuAI(api_key=key)
|
105 |
self.model_name = model_name
|
106 |
|
107 |
def chat(self, system, history, gen_conf):
|
108 |
-
if system:
|
|
|
109 |
try:
|
110 |
response = self.client.chat.completions.create(
|
111 |
self.model_name,
|
@@ -120,6 +126,7 @@ class ZhipuChat(Base):
|
|
120 |
except Exception as e:
|
121 |
return "**ERROR**: " + str(e), 0
|
122 |
|
|
|
123 |
class LocalLLM(Base):
|
124 |
class RPCProxy:
|
125 |
def __init__(self, host, port):
|
@@ -129,14 +136,17 @@ class LocalLLM(Base):
|
|
129 |
|
130 |
def __conn(self):
|
131 |
from multiprocessing.connection import Client
|
132 |
-
self._connection = Client(
|
|
|
133 |
|
134 |
def __getattr__(self, name):
|
135 |
import pickle
|
|
|
136 |
def do_rpc(*args, **kwargs):
|
137 |
for _ in range(3):
|
138 |
try:
|
139 |
-
self._connection.send(
|
|
|
140 |
return pickle.loads(self._connection.recv())
|
141 |
except Exception as e:
|
142 |
self.__conn()
|
@@ -148,7 +158,8 @@ class LocalLLM(Base):
|
|
148 |
self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
|
149 |
|
150 |
def chat(self, system, history, gen_conf):
|
151 |
-
if system:
|
|
|
152 |
try:
|
153 |
ans = self.client.chat(
|
154 |
history,
|
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
+
from zhipuai import ZhipuAI
|
17 |
+
from dashscope import Generation
|
18 |
from abc import ABC
|
19 |
from openai import OpenAI
|
20 |
import openai
|
|
|
36 |
self.model_name = model_name
|
37 |
|
38 |
def chat(self, system, history, gen_conf):
|
39 |
+
if system:
|
40 |
+
history.insert(0, {"role": "system", "content": system})
|
41 |
try:
|
42 |
response = self.client.chat.completions.create(
|
43 |
model=self.model_name,
|
|
|
49 |
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
50 |
return ans, response.usage.completion_tokens
|
51 |
except openai.APIError as e:
|
52 |
+
return "**ERROR**: " + str(e), 0
|
53 |
|
54 |
|
55 |
class MoonshotChat(GptTurbo):
|
56 |
def __init__(self, key, model_name="moonshot-v1-8k"):
|
57 |
+
self.client = OpenAI(
|
58 |
+
api_key=key, base_url="https://api.moonshot.cn/v1",)
|
59 |
self.model_name = model_name
|
60 |
|
61 |
def chat(self, system, history, gen_conf):
|
62 |
+
if system:
|
63 |
+
history.insert(0, {"role": "system", "content": system})
|
64 |
try:
|
65 |
response = self.client.chat.completions.create(
|
66 |
model=self.model_name,
|
|
|
72 |
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
73 |
return ans, response.usage.completion_tokens
|
74 |
except openai.APIError as e:
|
75 |
+
return "**ERROR**: " + str(e), 0
|
76 |
|
77 |
|
|
|
78 |
class QWenChat(Base):
|
79 |
def __init__(self, key, model_name=Generation.Models.qwen_turbo):
|
80 |
import dashscope
|
|
|
83 |
|
84 |
def chat(self, system, history, gen_conf):
|
85 |
from http import HTTPStatus
|
86 |
+
if system:
|
87 |
+
history.insert(0, {"role": "system", "content": system})
|
88 |
response = Generation.call(
|
89 |
self.model_name,
|
90 |
messages=history,
|
|
|
97 |
ans += response.output.choices[0]['message']['content']
|
98 |
tk_count += response.usage.output_tokens
|
99 |
if response.output.choices[0].get("finish_reason", "") == "length":
|
100 |
+
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
101 |
+
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
102 |
return ans, tk_count
|
103 |
|
104 |
return "**ERROR**: " + response.message, tk_count
|
105 |
|
106 |
|
|
|
107 |
class ZhipuChat(Base):
|
108 |
def __init__(self, key, model_name="glm-3-turbo"):
|
109 |
self.client = ZhipuAI(api_key=key)
|
110 |
self.model_name = model_name
|
111 |
|
112 |
def chat(self, system, history, gen_conf):
|
113 |
+
if system:
|
114 |
+
history.insert(0, {"role": "system", "content": system})
|
115 |
try:
|
116 |
response = self.client.chat.completions.create(
|
117 |
self.model_name,
|
|
|
126 |
except Exception as e:
|
127 |
return "**ERROR**: " + str(e), 0
|
128 |
|
129 |
+
|
130 |
class LocalLLM(Base):
|
131 |
class RPCProxy:
|
132 |
def __init__(self, host, port):
|
|
|
136 |
|
137 |
def __conn(self):
|
138 |
from multiprocessing.connection import Client
|
139 |
+
self._connection = Client(
|
140 |
+
(self.host, self.port), authkey=b'infiniflow-token4kevinhu')
|
141 |
|
142 |
def __getattr__(self, name):
|
143 |
import pickle
|
144 |
+
|
145 |
def do_rpc(*args, **kwargs):
|
146 |
for _ in range(3):
|
147 |
try:
|
148 |
+
self._connection.send(
|
149 |
+
pickle.dumps((name, args, kwargs)))
|
150 |
return pickle.loads(self._connection.recv())
|
151 |
except Exception as e:
|
152 |
self.__conn()
|
|
|
158 |
self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
|
159 |
|
160 |
def chat(self, system, history, gen_conf):
|
161 |
+
if system:
|
162 |
+
history.insert(0, {"role": "system", "content": system})
|
163 |
try:
|
164 |
ans = self.client.chat(
|
165 |
history,
|
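Note: the wrappers above share one chat() contract: an optional system prompt is inserted at the head of history, and the call returns (answer_text, completion_token_count). A minimal sketch with GptTurbo; the API key and model name are placeholders.

from rag.llm.chat_model import GptTurbo

mdl = GptTurbo(key="sk-...", model_name="gpt-3.5-turbo")
ans, used_tokens = mdl.chat(
    system="You are a concise assistant.",
    history=[{"role": "user", "content": "Summarize RAG in one sentence."}],
    gen_conf={"temperature": 0.1, "max_tokens": 256},
)
print(used_tokens, ans)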
rag/llm/cv_model.py
CHANGED
@@ -13,6 +13,7 @@
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
|
|
16 |
import io
|
17 |
from abc import ABC
|
18 |
|
@@ -57,8 +58,8 @@ class Base(ABC):
|
|
57 |
},
|
58 |
},
|
59 |
{
|
60 |
-
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
|
61 |
-
|
62 |
},
|
63 |
],
|
64 |
}
|
@@ -92,8 +93,9 @@ class QWenCV(Base):
|
|
92 |
def prompt(self, binary):
|
93 |
# stupid as hell
|
94 |
tmp_dir = get_project_base_directory("tmp")
|
95 |
-
if not os.path.exists(tmp_dir):
|
96 |
-
|
|
|
97 |
Image.open(io.BytesIO(binary)).save(path)
|
98 |
return [
|
99 |
{
|
@@ -103,8 +105,8 @@ class QWenCV(Base):
|
|
103 |
"image": f"file://{path}"
|
104 |
},
|
105 |
{
|
106 |
-
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
|
107 |
-
|
108 |
},
|
109 |
],
|
110 |
}
|
@@ -120,9 +122,6 @@ class QWenCV(Base):
|
|
120 |
return response.message, 0
|
121 |
|
122 |
|
123 |
-
from zhipuai import ZhipuAI
|
124 |
-
|
125 |
-
|
126 |
class Zhipu4V(Base):
|
127 |
def __init__(self, key, model_name="glm-4v", lang="Chinese"):
|
128 |
self.client = ZhipuAI(api_key=key)
|
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
+
from zhipuai import ZhipuAI
|
17 |
import io
|
18 |
from abc import ABC
|
19 |
|
|
|
58 |
},
|
59 |
},
|
60 |
{
|
61 |
+
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
|
62 |
+
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
|
63 |
},
|
64 |
],
|
65 |
}
|
|
|
93 |
def prompt(self, binary):
|
94 |
# stupid as hell
|
95 |
tmp_dir = get_project_base_directory("tmp")
|
96 |
+
if not os.path.exists(tmp_dir):
|
97 |
+
os.mkdir(tmp_dir)
|
98 |
+
path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
|
99 |
Image.open(io.BytesIO(binary)).save(path)
|
100 |
return [
|
101 |
{
|
|
|
105 |
"image": f"file://{path}"
|
106 |
},
|
107 |
{
|
108 |
+
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
|
109 |
+
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
|
110 |
},
|
111 |
],
|
112 |
}
|
|
|
122 |
return response.message, 0
|
123 |
|
124 |
|
|
|
|
|
|
|
125 |
class Zhipu4V(Base):
|
126 |
def __init__(self, key, model_name="glm-4v", lang="Chinese"):
|
127 |
self.client = ZhipuAI(api_key=key)
|
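Note: a sketch of describing an image with one of the vision wrappers above. Only prompt() and the constructors appear in the diff; the describe() entry point, its max_tokens argument, and the key/file name are assumptions.

from rag.llm.cv_model import Zhipu4V

with open("figure.png", "rb") as f:       # placeholder image
    binary = f.read()

mdl = Zhipu4V(key="zhipu-key", model_name="glm-4v", lang="English")
text, tokens = mdl.describe(binary, max_tokens=300)   # assumed entry point, see note above
print(tokens, text)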
rag/llm/embedding_model.py
CHANGED
@@ -13,6 +13,7 @@
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
|
|
16 |
import os
|
17 |
from abc import ABC
|
18 |
|
@@ -40,11 +41,11 @@ flag_model = FlagModel(model_dir,
|
|
40 |
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
41 |
use_fp16=torch.cuda.is_available())
|
42 |
|
|
|
43 |
class Base(ABC):
|
44 |
def __init__(self, key, model_name):
|
45 |
pass
|
46 |
|
47 |
-
|
48 |
def encode(self, texts: list, batch_size=32):
|
49 |
raise NotImplementedError("Please implement encode method!")
|
50 |
|
@@ -67,11 +68,11 @@ class HuEmbedding(Base):
|
|
67 |
"""
|
68 |
self.model = flag_model
|
69 |
|
70 |
-
|
71 |
def encode(self, texts: list, batch_size=32):
|
72 |
texts = [t[:2000] for t in texts]
|
73 |
token_count = 0
|
74 |
-
for t in texts:
|
|
|
75 |
res = []
|
76 |
for i in range(0, len(texts), batch_size):
|
77 |
res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
|
@@ -90,7 +91,8 @@ class OpenAIEmbed(Base):
|
|
90 |
def encode(self, texts: list, batch_size=32):
|
91 |
res = self.client.embeddings.create(input=texts,
|
92 |
model=self.model_name)
|
93 |
-
return np.array([d.embedding for d in res.data]
|
|
|
94 |
|
95 |
def encode_queries(self, text):
|
96 |
res = self.client.embeddings.create(input=[text],
|
@@ -111,7 +113,7 @@ class QWenEmbed(Base):
|
|
111 |
for i in range(0, len(texts), batch_size):
|
112 |
resp = dashscope.TextEmbedding.call(
|
113 |
model=self.model_name,
|
114 |
-
input=texts[i:i+batch_size],
|
115 |
text_type="document"
|
116 |
)
|
117 |
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
|
@@ -123,14 +125,14 @@ class QWenEmbed(Base):
|
|
123 |
|
124 |
def encode_queries(self, text):
|
125 |
resp = dashscope.TextEmbedding.call(
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
return np.array(resp["output"]["embeddings"][0]
|
|
|
131 |
|
132 |
|
133 |
-
from zhipuai import ZhipuAI
|
134 |
class ZhipuEmbed(Base):
|
135 |
def __init__(self, key, model_name="embedding-2"):
|
136 |
self.client = ZhipuAI(api_key=key)
|
@@ -139,9 +141,10 @@ class ZhipuEmbed(Base):
|
|
139 |
def encode(self, texts: list, batch_size=32):
|
140 |
res = self.client.embeddings.create(input=texts,
|
141 |
model=self.model_name)
|
142 |
-
return np.array([d.embedding for d in res.data]
|
|
|
143 |
|
144 |
def encode_queries(self, text):
|
145 |
res = self.client.embeddings.create(input=text,
|
146 |
model=self.model_name)
|
147 |
-
return np.array(res["data"][0]["embedding"]), res.usage.total_tokens
|
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
+
from zhipuai import ZhipuAI
|
17 |
import os
|
18 |
from abc import ABC
|
19 |
|
|
|
41 |
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
42 |
use_fp16=torch.cuda.is_available())
|
43 |
|
44 |
+
|
45 |
class Base(ABC):
|
46 |
def __init__(self, key, model_name):
|
47 |
pass
|
48 |
|
|
|
49 |
def encode(self, texts: list, batch_size=32):
|
50 |
raise NotImplementedError("Please implement encode method!")
|
51 |
|
|
|
68 |
"""
|
69 |
self.model = flag_model
|
70 |
|
|
|
71 |
def encode(self, texts: list, batch_size=32):
|
72 |
texts = [t[:2000] for t in texts]
|
73 |
token_count = 0
|
74 |
+
for t in texts:
|
75 |
+
token_count += num_tokens_from_string(t)
|
76 |
res = []
|
77 |
for i in range(0, len(texts), batch_size):
|
78 |
res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
|
|
|
91 |
def encode(self, texts: list, batch_size=32):
|
92 |
res = self.client.embeddings.create(input=texts,
|
93 |
model=self.model_name)
|
94 |
+
return np.array([d.embedding for d in res.data]
|
95 |
+
), res.usage.total_tokens
|
96 |
|
97 |
def encode_queries(self, text):
|
98 |
res = self.client.embeddings.create(input=[text],
|
|
|
113 |
for i in range(0, len(texts), batch_size):
|
114 |
resp = dashscope.TextEmbedding.call(
|
115 |
model=self.model_name,
|
116 |
+
input=texts[i:i + batch_size],
|
117 |
text_type="document"
|
118 |
)
|
119 |
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
|
|
|
125 |
|
126 |
def encode_queries(self, text):
|
127 |
resp = dashscope.TextEmbedding.call(
|
128 |
+
model=self.model_name,
|
129 |
+
input=text[:2048],
|
130 |
+
text_type="query"
|
131 |
+
)
|
132 |
+
return np.array(resp["output"]["embeddings"][0]
|
133 |
+
["embedding"]), resp["usage"]["total_tokens"]
|
134 |
|
135 |
|
|
|
136 |
class ZhipuEmbed(Base):
|
137 |
def __init__(self, key, model_name="embedding-2"):
|
138 |
self.client = ZhipuAI(api_key=key)
|
|
|
141 |
def encode(self, texts: list, batch_size=32):
|
142 |
res = self.client.embeddings.create(input=texts,
|
143 |
model=self.model_name)
|
144 |
+
return np.array([d.embedding for d in res.data]
|
145 |
+
), res.usage.total_tokens
|
146 |
|
147 |
def encode_queries(self, text):
|
148 |
res = self.client.embeddings.create(input=text,
|
149 |
model=self.model_name)
|
150 |
+
return np.array(res["data"][0]["embedding"]), res.usage.total_tokens
|
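Note: a sketch of the embedding contract shown above: encode() takes a list of texts and returns (vector ndarray, total token count), and encode_queries() embeds a single query string. The constructor arguments and model name are assumptions; only encode/encode_queries appear in the diff.

from rag.llm.embedding_model import OpenAIEmbed

mdl = OpenAIEmbed(key="sk-...", model_name="text-embedding-ada-002")
vecs, total_tokens = mdl.encode(["RAGFlow chunks documents.", "BGE embeds them."])
qvec, qtokens = mdl.encode_queries("how are chunks embedded?")
print(vecs.shape, total_tokens)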
rag/llm/rpc_server.py
CHANGED
@@ -9,7 +9,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
9 |
|
10 |
class RPCHandler:
|
11 |
def __init__(self):
|
12 |
-
self._functions = {
|
13 |
|
14 |
def register_function(self, func):
|
15 |
self._functions[func.__name__] = func
|
@@ -21,12 +21,12 @@ class RPCHandler:
|
|
21 |
func_name, args, kwargs = pickle.loads(connection.recv())
|
22 |
# Run the RPC and send a response
|
23 |
try:
|
24 |
-
r = self._functions[func_name](*args
|
25 |
connection.send(pickle.dumps(r))
|
26 |
except Exception as e:
|
27 |
connection.send(pickle.dumps(e))
|
28 |
except EOFError:
|
29 |
-
|
30 |
|
31 |
|
32 |
def rpc_server(hdlr, address, authkey):
|
@@ -44,11 +44,17 @@ def rpc_server(hdlr, address, authkey):
|
|
44 |
models = []
|
45 |
tokenizer = None
|
46 |
|
|
|
47 |
def chat(messages, gen_conf):
|
48 |
global tokenizer
|
49 |
model = Model()
|
50 |
try:
|
51 |
-
conf = {
|
|
|
|
|
|
|
|
|
|
|
52 |
print(messages, conf)
|
53 |
text = tokenizer.apply_chat_template(
|
54 |
messages,
|
@@ -65,7 +71,8 @@ def chat(messages, gen_conf):
|
|
65 |
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
66 |
]
|
67 |
|
68 |
-
return tokenizer.batch_decode(
|
|
|
69 |
except Exception as e:
|
70 |
return str(e)
|
71 |
|
@@ -75,10 +82,15 @@ def Model():
|
|
75 |
random.seed(time.time())
|
76 |
return random.choice(models)
|
77 |
|
|
|
78 |
if __name__ == "__main__":
|
79 |
parser = argparse.ArgumentParser()
|
80 |
parser.add_argument("--model_name", type=str, help="Model name")
|
81 |
-
parser.add_argument(
|
|
|
|
|
|
|
|
|
82 |
args = parser.parse_args()
|
83 |
|
84 |
handler = RPCHandler()
|
@@ -93,4 +105,5 @@ if __name__ == "__main__":
|
|
93 |
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
94 |
|
95 |
# Run the server
|
96 |
-
rpc_server(handler, ('0.0.0.0', args.port),
|
|
|
|
9 |
|
10 |
class RPCHandler:
|
11 |
def __init__(self):
|
12 |
+
self._functions = {}
|
13 |
|
14 |
def register_function(self, func):
|
15 |
self._functions[func.__name__] = func
|
|
|
21 |
func_name, args, kwargs = pickle.loads(connection.recv())
|
22 |
# Run the RPC and send a response
|
23 |
try:
|
24 |
+
r = self._functions[func_name](*args, **kwargs)
|
25 |
connection.send(pickle.dumps(r))
|
26 |
except Exception as e:
|
27 |
connection.send(pickle.dumps(e))
|
28 |
except EOFError:
|
29 |
+
pass
|
30 |
|
31 |
|
32 |
def rpc_server(hdlr, address, authkey):
|
|
|
44 |
models = []
|
45 |
tokenizer = None
|
46 |
|
47 |
+
|
48 |
def chat(messages, gen_conf):
|
49 |
global tokenizer
|
50 |
model = Model()
|
51 |
try:
|
52 |
+
conf = {
|
53 |
+
"max_new_tokens": int(
|
54 |
+
gen_conf.get(
|
55 |
+
"max_tokens", 256)), "temperature": float(
|
56 |
+
gen_conf.get(
|
57 |
+
"temperature", 0.1))}
|
58 |
print(messages, conf)
|
59 |
text = tokenizer.apply_chat_template(
|
60 |
messages,
|
|
|
71 |
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
72 |
]
|
73 |
|
74 |
+
return tokenizer.batch_decode(
|
75 |
+
generated_ids, skip_special_tokens=True)[0]
|
76 |
except Exception as e:
|
77 |
return str(e)
|
78 |
|
|
|
82 |
random.seed(time.time())
|
83 |
return random.choice(models)
|
84 |
|
85 |
+
|
86 |
if __name__ == "__main__":
|
87 |
parser = argparse.ArgumentParser()
|
88 |
parser.add_argument("--model_name", type=str, help="Model name")
|
89 |
+
parser.add_argument(
|
90 |
+
"--port",
|
91 |
+
default=7860,
|
92 |
+
type=int,
|
93 |
+
help="RPC serving port")
|
94 |
args = parser.parse_args()
|
95 |
|
96 |
handler = RPCHandler()
|
|
|
105 |
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
106 |
|
107 |
# Run the server
|
108 |
+
rpc_server(handler, ('0.0.0.0', args.port),
|
109 |
+
authkey=b'infiniflow-token4kevinhu')
|
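Note: a sketch of running the RPC server and reaching it through LocalLLM's RPCProxy. The --model_name/--port flags and the 127.0.0.1:7860 target come from the diff; the Hugging Face model id and the LocalLLM constructor arguments are placeholders.

# start the server in a separate process:
#   python rag/llm/rpc_server.py --model_name Qwen/Qwen1.5-7B-Chat --port 7860

from rag.llm.chat_model import LocalLLM

mdl = LocalLLM(key=None, model_name="")    # proxies to 127.0.0.1:7860
ans, tokens = mdl.chat(
    system="You are a helpful assistant.",
    history=[{"role": "user", "content": "hello"}],
    gen_conf={"max_tokens": 128, "temperature": 0.1},
)
print(ans)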
rag/nlp/huchunk.py
CHANGED
@@ -372,7 +372,8 @@ class PptChunker(HuChunker):
|
|
372 |
tb = shape.table
|
373 |
rows = []
|
374 |
for i in range(1, len(tb.rows)):
|
375 |
-
rows.append("; ".join([tb.cell(
|
|
|
376 |
return "\n".join(rows)
|
377 |
|
378 |
if shape.has_text_frame:
|
@@ -382,7 +383,8 @@ class PptChunker(HuChunker):
|
|
382 |
texts = []
|
383 |
for p in shape.shapes:
|
384 |
t = self.__extract(p)
|
385 |
-
if t:
|
|
|
386 |
return "\n".join(texts)
|
387 |
|
388 |
def __call__(self, fnm):
|
@@ -395,7 +397,8 @@ class PptChunker(HuChunker):
|
|
395 |
texts = []
|
396 |
for shape in slide.shapes:
|
397 |
txt = self.__extract(shape)
|
398 |
-
if txt:
|
|
|
399 |
txts.append("\n".join(texts))
|
400 |
|
401 |
import aspose.slides as slides
|
@@ -404,9 +407,12 @@ class PptChunker(HuChunker):
|
|
404 |
with slides.Presentation(BytesIO(fnm)) as presentation:
|
405 |
for slide in presentation.slides:
|
406 |
buffered = BytesIO()
|
407 |
-
slide.get_thumbnail(
|
|
|
|
|
408 |
imgs.append(buffered.getvalue())
|
409 |
-
assert len(imgs) == len(
|
|
|
410 |
|
411 |
flds = self.Fields()
|
412 |
flds.text_chunks = [(txts[i], imgs[i]) for i in range(len(txts))]
|
@@ -445,7 +451,8 @@ class TextChunker(HuChunker):
|
|
445 |
if isinstance(fnm, str):
|
446 |
with open(fnm, "r") as f:
|
447 |
txt = f.read()
|
448 |
-
else:
|
|
|
449 |
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
|
450 |
flds.table_chunks = []
|
451 |
return flds
|
|
|
372 |
tb = shape.table
|
373 |
rows = []
|
374 |
for i in range(1, len(tb.rows)):
|
375 |
+
rows.append("; ".join([tb.cell(
|
376 |
+
0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)]))
|
377 |
return "\n".join(rows)
|
378 |
|
379 |
if shape.has_text_frame:
|
|
|
383 |
texts = []
|
384 |
for p in shape.shapes:
|
385 |
t = self.__extract(p)
|
386 |
+
if t:
|
387 |
+
texts.append(t)
|
388 |
return "\n".join(texts)
|
389 |
|
390 |
def __call__(self, fnm):
|
|
|
397 |
texts = []
|
398 |
for shape in slide.shapes:
|
399 |
txt = self.__extract(shape)
|
400 |
+
if txt:
|
401 |
+
texts.append(txt)
|
402 |
txts.append("\n".join(texts))
|
403 |
|
404 |
import aspose.slides as slides
|
|
|
407 |
with slides.Presentation(BytesIO(fnm)) as presentation:
|
408 |
for slide in presentation.slides:
|
409 |
buffered = BytesIO()
|
410 |
+
slide.get_thumbnail(
|
411 |
+
0.5, 0.5).save(
|
412 |
+
buffered, drawing.imaging.ImageFormat.jpeg)
|
413 |
imgs.append(buffered.getvalue())
|
414 |
+
assert len(imgs) == len(
|
415 |
+
txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
|
416 |
|
417 |
flds = self.Fields()
|
418 |
flds.text_chunks = [(txts[i], imgs[i]) for i in range(len(txts))]
|
|
|
451 |
if isinstance(fnm, str):
|
452 |
with open(fnm, "r") as f:
|
453 |
txt = f.read()
|
454 |
+
else:
|
455 |
+
txt = fnm.decode("utf-8")
|
456 |
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
|
457 |
flds.table_chunks = []
|
458 |
return flds
|
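Note: a sketch of running the PptChunker fixed above; the Fields result carries (text, image) pairs per slide, as set in the diff. The no-argument constructor and the file name are assumptions.

from rag.nlp.huchunk import PptChunker

chunker = PptChunker()
flds = chunker("deck.pptx")
for txt, img in flds.text_chunks[:2]:
    print(len(txt), "chars of slide text; thumbnail present:", img is not None)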
rag/nlp/query.py
CHANGED
@@ -149,7 +149,8 @@ class EsQueryer:
|
|
149 |
atks = toDict(atks)
|
150 |
btkss = [toDict(tks) for tks in btkss]
|
151 |
tksim = [self.similarity(atks, btks) for btks in btkss]
|
152 |
-
return np.array(sims[0]) * vtweight +
|
|
|
153 |
|
154 |
def similarity(self, qtwt, dtwt):
|
155 |
if isinstance(dtwt, type("")):
|
@@ -159,11 +160,11 @@ class EsQueryer:
|
|
159 |
s = 1e-9
|
160 |
for k, v in qtwt.items():
|
161 |
if k in dtwt:
|
162 |
-
s += v# * dtwt[k]
|
163 |
q = 1e-9
|
164 |
for k, v in qtwt.items():
|
165 |
-
q += v
|
166 |
#d = 1e-9
|
167 |
-
#for k, v in dtwt.items():
|
168 |
# d += v * v
|
169 |
-
return s / q
|
|
|
149 |
atks = toDict(atks)
|
150 |
btkss = [toDict(tks) for tks in btkss]
|
151 |
tksim = [self.similarity(atks, btks) for btks in btkss]
|
152 |
+
return np.array(sims[0]) * vtweight + \
|
153 |
+
np.array(tksim) * tkweight, tksim, sims[0]
|
154 |
|
155 |
def similarity(self, qtwt, dtwt):
|
156 |
if isinstance(dtwt, type("")):
|
|
|
160 |
s = 1e-9
|
161 |
for k, v in qtwt.items():
|
162 |
if k in dtwt:
|
163 |
+
s += v # * dtwt[k]
|
164 |
q = 1e-9
|
165 |
for k, v in qtwt.items():
|
166 |
+
q += v # * v
|
167 |
#d = 1e-9
|
168 |
+
# for k, v in dtwt.items():
|
169 |
# d += v * v
|
170 |
+
return s / q # math.sqrt(q) / math.sqrt(d)
|
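Note: a self-contained sketch of the token-weight similarity kept above: the score is the summed weight of query terms that also occur in the document, divided by the total query weight (the vector part is blended in by the caller shown above). The weights below are invented.

def similarity(qtwt, dtwt):
    s = 1e-9
    for k, v in qtwt.items():
        if k in dtwt:
            s += v
    q = 1e-9
    for k, v in qtwt.items():
        q += v
    return s / q

query_weights = {"rag": 0.6, "chunk": 0.3, "pdf": 0.1}
doc_weights = {"rag": 0.5, "pdf": 0.2}
print(round(similarity(query_weights, doc_weights), 3))   # -> 0.7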