KevinHuSh
committed on
Commit
·
75a07ce
1
Parent(s):
78dc980
fix raptor bugs (#928)
Browse files

### What problem does this PR solve?
#922
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- api/apps/api_app.py +74 -0
- api/apps/chunk_app.py +3 -0
- api/db/services/document_service.py +1 -1
- deepdoc/vision/postprocess.py +0 -1
- rag/llm/chat_model.py +1 -1
- rag/raptor.py +1 -0
- requirements.txt +1 -0
- requirements_dev.txt +1 -0
api/apps/api_app.py
CHANGED
|
@@ -488,3 +488,77 @@ def document_rm():
|
|
| 488 |
return get_json_result(data=False, retmsg=errors, retcode=RetCode.SERVER_ERROR)
|
| 489 |
|
| 490 |
return get_json_result(data=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
return get_json_result(data=False, retmsg=errors, retcode=RetCode.SERVER_ERROR)
|
| 489 |
|
| 490 |
return get_json_result(data=True)
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
@manager.route('/completion_aibotk', methods=['POST'])
|
| 494 |
+
@validate_request("Authorization", "conversation_id", "word")
|
| 495 |
+
def completion_faq():
|
| 496 |
+
import base64
|
| 497 |
+
req = request.json
|
| 498 |
+
|
| 499 |
+
token = req["Authorization"]
|
| 500 |
+
objs = APIToken.query(token=token)
|
| 501 |
+
if not objs:
|
| 502 |
+
return get_json_result(
|
| 503 |
+
data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
|
| 504 |
+
|
| 505 |
+
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
| 506 |
+
if not e:
|
| 507 |
+
return get_data_error_result(retmsg="Conversation not found!")
|
| 508 |
+
if "quote" not in req: req["quote"] = True
|
| 509 |
+
|
| 510 |
+
msg = []
|
| 511 |
+
msg.append({"role": "user", "content": req["word"]})
|
| 512 |
+
|
| 513 |
+
try:
|
| 514 |
+
conv.message.append(msg[-1])
|
| 515 |
+
e, dia = DialogService.get_by_id(conv.dialog_id)
|
| 516 |
+
if not e:
|
| 517 |
+
return get_data_error_result(retmsg="Dialog not found!")
|
| 518 |
+
del req["conversation_id"]
|
| 519 |
+
|
| 520 |
+
if not conv.reference:
|
| 521 |
+
conv.reference = []
|
| 522 |
+
conv.message.append({"role": "assistant", "content": ""})
|
| 523 |
+
conv.reference.append({"chunks": [], "doc_aggs": []})
|
| 524 |
+
|
| 525 |
+
def fillin_conv(ans):
|
| 526 |
+
nonlocal conv
|
| 527 |
+
if not conv.reference:
|
| 528 |
+
conv.reference.append(ans["reference"])
|
| 529 |
+
else: conv.reference[-1] = ans["reference"]
|
| 530 |
+
conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
|
| 531 |
+
|
| 532 |
+
data_type_picture = {
|
| 533 |
+
"type": 3,
|
| 534 |
+
"url": "base64 content"
|
| 535 |
+
}
|
| 536 |
+
data = [
|
| 537 |
+
{
|
| 538 |
+
"type": 1,
|
| 539 |
+
"content": ""
|
| 540 |
+
}
|
| 541 |
+
]
|
| 542 |
+
for ans in chat(dia, msg, stream=False, **req):
|
| 543 |
+
# answer = ans
|
| 544 |
+
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
|
| 545 |
+
fillin_conv(ans)
|
| 546 |
+
API4ConversationService.append_message(conv.id, conv.to_dict())
|
| 547 |
+
|
| 548 |
+
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
|
| 549 |
+
for chunk_idx in chunk_idxs[:1]:
|
| 550 |
+
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
|
| 551 |
+
try:
|
| 552 |
+
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
|
| 553 |
+
response = MINIO.get(bkt, nm)
|
| 554 |
+
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
|
| 555 |
+
data.append(data_type_picture)
|
| 556 |
+
except Exception as e:
|
| 557 |
+
return server_error_response(e)
|
| 558 |
+
break
|
| 559 |
+
|
| 560 |
+
response = {"code": 200, "msg": "success", "data": data}
|
| 561 |
+
return response
|
| 562 |
+
|
| 563 |
+
except Exception as e:
|
| 564 |
+
return server_error_response(e)
|
api/apps/chunk_app.py
CHANGED
|
@@ -229,6 +229,9 @@ def create():
|
|
| 229 |
v = 0.1 * v[0] + 0.9 * v[1]
|
| 230 |
d["q_%d_vec" % len(v)] = v.tolist()
|
| 231 |
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
|
|
|
|
|
|
|
|
|
|
| 232 |
return get_json_result(data={"chunk_id": chunck_id})
|
| 233 |
except Exception as e:
|
| 234 |
return server_error_response(e)
|
|
|
|
| 229 |
v = 0.1 * v[0] + 0.9 * v[1]
|
| 230 |
d["q_%d_vec" % len(v)] = v.tolist()
|
| 231 |
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
|
| 232 |
+
|
| 233 |
+
DocumentService.increment_chunk_num(
|
| 234 |
+
doc.id, doc.kb_id, c, 1, 0)
|
| 235 |
return get_json_result(data={"chunk_id": chunck_id})
|
| 236 |
except Exception as e:
|
| 237 |
return server_error_response(e)
|
api/db/services/document_service.py
CHANGED
|
@@ -263,7 +263,7 @@ class DocumentService(CommonService):
|
|
| 263 |
prg = -1
|
| 264 |
status = TaskStatus.FAIL.value
|
| 265 |
elif finished:
|
| 266 |
-
if d["parser_config"].get("raptor") and d["progress_msg"].lower().find(" raptor")<0:
|
| 267 |
queue_raptor_tasks(d)
|
| 268 |
prg *= 0.98
|
| 269 |
msg.append("------ RAPTOR -------")
|
|
|
|
| 263 |
prg = -1
|
| 264 |
status = TaskStatus.FAIL.value
|
| 265 |
elif finished:
|
| 266 |
+
if d["parser_config"].get("raptor", {}).get("use_raptor") and d["progress_msg"].lower().find(" raptor")<0:
|
| 267 |
queue_raptor_tasks(d)
|
| 268 |
prg *= 0.98
|
| 269 |
msg.append("------ RAPTOR -------")
|
deepdoc/vision/postprocess.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
import copy
|
| 2 |
import re
|
| 3 |
-
|
| 4 |
import numpy as np
|
| 5 |
import cv2
|
| 6 |
from shapely.geometry import Polygon
|
|
|
|
| 1 |
import copy
|
| 2 |
import re
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import cv2
|
| 5 |
from shapely.geometry import Polygon
|
rag/llm/chat_model.py
CHANGED
|
@@ -359,7 +359,6 @@ class VolcEngineChat(Base):
|
|
| 359 |
if system:
|
| 360 |
history.insert(0, {"role": "system", "content": system})
|
| 361 |
ans = ""
|
| 362 |
-
tk_count = 0
|
| 363 |
try:
|
| 364 |
req = {
|
| 365 |
"parameters": {
|
|
@@ -380,6 +379,7 @@ class VolcEngineChat(Base):
|
|
| 380 |
if resp.choices[0].finish_reason == "stop":
|
| 381 |
tk_count = resp.usage.total_tokens
|
| 382 |
yield ans
|
|
|
|
| 383 |
except Exception as e:
|
| 384 |
yield ans + "\n**ERROR**: " + str(e)
|
| 385 |
yield tk_count
|
|
|
|
| 359 |
if system:
|
| 360 |
history.insert(0, {"role": "system", "content": system})
|
| 361 |
ans = ""
|
|
|
|
| 362 |
try:
|
| 363 |
req = {
|
| 364 |
"parameters": {
|
|
|
|
| 379 |
if resp.choices[0].finish_reason == "stop":
|
| 380 |
tk_count = resp.usage.total_tokens
|
| 381 |
yield ans
|
| 382 |
+
|
| 383 |
except Exception as e:
|
| 384 |
yield ans + "\n**ERROR**: " + str(e)
|
| 385 |
yield tk_count
|
rag/raptor.py
CHANGED
|
@@ -95,6 +95,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
|
| 95 |
gm.fit(reduced_embeddings)
|
| 96 |
probs = gm.predict_proba(reduced_embeddings)
|
| 97 |
lbls = [np.where(prob > self._threshold)[0] for prob in probs]
|
|
|
|
| 98 |
lock = Lock()
|
| 99 |
with ThreadPoolExecutor(max_workers=12) as executor:
|
| 100 |
threads = []
|
|
|
|
| 95 |
gm.fit(reduced_embeddings)
|
| 96 |
probs = gm.predict_proba(reduced_embeddings)
|
| 97 |
lbls = [np.where(prob > self._threshold)[0] for prob in probs]
|
| 98 |
+
lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
|
| 99 |
lock = Lock()
|
| 100 |
with ThreadPoolExecutor(max_workers=12) as executor:
|
| 101 |
threads = []
|
requirements.txt
CHANGED
|
@@ -134,4 +134,5 @@ yarl==1.9.4
|
|
| 134 |
zhipuai==2.0.1
|
| 135 |
BCEmbedding
|
| 136 |
loguru==0.7.2
|
|
|
|
| 137 |
fasttext==0.9.2
|
|
|
|
| 134 |
zhipuai==2.0.1
|
| 135 |
BCEmbedding
|
| 136 |
loguru==0.7.2
|
| 137 |
+
umap-learn
|
| 138 |
fasttext==0.9.2
|
requirements_dev.txt
CHANGED
|
@@ -123,3 +123,4 @@ loguru==0.7.2
|
|
| 123 |
ollama==0.1.8
|
| 124 |
redis==5.0.4
|
| 125 |
fasttext==0.9.2
|
|
|
|
|
|
| 123 |
ollama==0.1.8
|
| 124 |
redis==5.0.4
|
| 125 |
fasttext==0.9.2
|
| 126 |
+
umap-learn
|