Commit
·
6101699
1
Parent(s):
fc803e8
Move settings initialization after module init phase (#3438)
Browse files### What problem does this PR solve?
1. Module init won't connect database any more.
2. Config in settings need to be used with settings.CONFIG_NAME
### Type of change
- [x] Refactoring
Signed-off-by: jinhai <[email protected]>
- agent/component/generate.py +13 -9
- agent/component/retrieval.py +2 -2
- api/apps/__init__.py +4 -6
- api/apps/api_app.py +33 -31
- api/apps/canvas_app.py +12 -10
- api/apps/chunk_app.py +17 -17
- api/apps/conversation_app.py +7 -6
- api/apps/dialog_app.py +2 -2
- api/apps/document_app.py +31 -30
- api/apps/file2document_app.py +2 -2
- api/apps/file_app.py +5 -5
- api/apps/kb_app.py +7 -8
- api/apps/llm_app.py +2 -2
- api/apps/sdk/chat.py +3 -3
- api/apps/sdk/dataset.py +3 -3
- api/apps/sdk/dify_retrieval.py +6 -6
- api/apps/sdk/doc.py +25 -25
- api/apps/system_app.py +4 -5
- api/apps/user_app.py +29 -42
- api/db/db_models.py +7 -7
- api/db/init_data.py +12 -11
- api/db/services/dialog_service.py +4 -4
- api/db/services/document_service.py +5 -5
- api/ragflow_server.py +5 -6
- api/settings.py +144 -101
- api/utils/api_utils.py +24 -26
- deepdoc/parser/pdf_parser.py +2 -2
- graphrag/claim_extractor.py +2 -2
- graphrag/smoke.py +2 -2
- rag/benchmark.py +11 -11
- rag/llm/embedding_model.py +4 -4
- rag/llm/rerank_model.py +3 -3
- rag/svr/task_executor.py +22 -15
agent/component/generate.py
CHANGED
|
@@ -19,7 +19,7 @@ import pandas as pd
|
|
| 19 |
from api.db import LLMType
|
| 20 |
from api.db.services.dialog_service import message_fit_in
|
| 21 |
from api.db.services.llm_service import LLMBundle
|
| 22 |
-
from api
|
| 23 |
from agent.component.base import ComponentBase, ComponentParamBase
|
| 24 |
|
| 25 |
|
|
@@ -63,18 +63,20 @@ class Generate(ComponentBase):
|
|
| 63 |
component_name = "Generate"
|
| 64 |
|
| 65 |
def get_dependent_components(self):
|
| 66 |
-
cpnts = [para["component_id"] for para in self._param.parameters if
|
|
|
|
| 67 |
return cpnts
|
| 68 |
|
| 69 |
def set_cite(self, retrieval_res, answer):
|
| 70 |
retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
|
| 71 |
if "empty_response" in retrieval_res.columns:
|
| 72 |
retrieval_res["empty_response"].fillna("", inplace=True)
|
| 73 |
-
answer, idx = retrievaler.insert_citations(answer,
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
|
|
|
| 78 |
doc_ids = set([])
|
| 79 |
recall_docs = []
|
| 80 |
for i in idx:
|
|
@@ -127,12 +129,14 @@ class Generate(ComponentBase):
|
|
| 127 |
else:
|
| 128 |
if cpn.component_name.lower() == "retrieval":
|
| 129 |
retrieval_res.append(out)
|
| 130 |
-
kwargs[para["key"]] = " - "+"\n - ".join(
|
|
|
|
| 131 |
self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})
|
| 132 |
|
| 133 |
if retrieval_res:
|
| 134 |
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
|
| 135 |
-
else:
|
|
|
|
| 136 |
|
| 137 |
for n, v in kwargs.items():
|
| 138 |
prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
|
|
|
|
| 19 |
from api.db import LLMType
|
| 20 |
from api.db.services.dialog_service import message_fit_in
|
| 21 |
from api.db.services.llm_service import LLMBundle
|
| 22 |
+
from api import settings
|
| 23 |
from agent.component.base import ComponentBase, ComponentParamBase
|
| 24 |
|
| 25 |
|
|
|
|
| 63 |
component_name = "Generate"
|
| 64 |
|
| 65 |
def get_dependent_components(self):
|
| 66 |
+
cpnts = [para["component_id"] for para in self._param.parameters if
|
| 67 |
+
para.get("component_id") and para["component_id"].lower().find("answer") < 0]
|
| 68 |
return cpnts
|
| 69 |
|
| 70 |
def set_cite(self, retrieval_res, answer):
|
| 71 |
retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
|
| 72 |
if "empty_response" in retrieval_res.columns:
|
| 73 |
retrieval_res["empty_response"].fillna("", inplace=True)
|
| 74 |
+
answer, idx = settings.retrievaler.insert_citations(answer,
|
| 75 |
+
[ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
|
| 76 |
+
[ck["vector"] for _, ck in retrieval_res.iterrows()],
|
| 77 |
+
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
|
| 78 |
+
self._canvas.get_embedding_model()), tkweight=0.7,
|
| 79 |
+
vtweight=0.3)
|
| 80 |
doc_ids = set([])
|
| 81 |
recall_docs = []
|
| 82 |
for i in idx:
|
|
|
|
| 129 |
else:
|
| 130 |
if cpn.component_name.lower() == "retrieval":
|
| 131 |
retrieval_res.append(out)
|
| 132 |
+
kwargs[para["key"]] = " - " + "\n - ".join(
|
| 133 |
+
[o if isinstance(o, str) else str(o) for o in out["content"]])
|
| 134 |
self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})
|
| 135 |
|
| 136 |
if retrieval_res:
|
| 137 |
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
|
| 138 |
+
else:
|
| 139 |
+
retrieval_res = pd.DataFrame([])
|
| 140 |
|
| 141 |
for n, v in kwargs.items():
|
| 142 |
prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
|
agent/component/retrieval.py
CHANGED
|
@@ -21,7 +21,7 @@ import pandas as pd
|
|
| 21 |
from api.db import LLMType
|
| 22 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 23 |
from api.db.services.llm_service import LLMBundle
|
| 24 |
-
from api
|
| 25 |
from agent.component.base import ComponentBase, ComponentParamBase
|
| 26 |
|
| 27 |
|
|
@@ -67,7 +67,7 @@ class Retrieval(ComponentBase, ABC):
|
|
| 67 |
if self._param.rerank_id:
|
| 68 |
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
|
| 69 |
|
| 70 |
-
kbinfos = retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
|
| 71 |
1, self._param.top_n,
|
| 72 |
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
|
| 73 |
aggs=False, rerank_mdl=rerank_mdl)
|
|
|
|
| 21 |
from api.db import LLMType
|
| 22 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 23 |
from api.db.services.llm_service import LLMBundle
|
| 24 |
+
from api import settings
|
| 25 |
from agent.component.base import ComponentBase, ComponentParamBase
|
| 26 |
|
| 27 |
|
|
|
|
| 67 |
if self._param.rerank_id:
|
| 68 |
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
|
| 69 |
|
| 70 |
+
kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
|
| 71 |
1, self._param.top_n,
|
| 72 |
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
|
| 73 |
aggs=False, rerank_mdl=rerank_mdl)
|
api/apps/__init__.py
CHANGED
|
@@ -30,8 +30,7 @@ from api.utils import CustomJSONEncoder, commands
|
|
| 30 |
|
| 31 |
from flask_session import Session
|
| 32 |
from flask_login import LoginManager
|
| 33 |
-
from api
|
| 34 |
-
from api.settings import API_VERSION
|
| 35 |
from api.utils.api_utils import server_error_response
|
| 36 |
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
| 37 |
|
|
@@ -78,7 +77,6 @@ app.url_map.strict_slashes = False
|
|
| 78 |
app.json_encoder = CustomJSONEncoder
|
| 79 |
app.errorhandler(Exception)(server_error_response)
|
| 80 |
|
| 81 |
-
|
| 82 |
## convince for dev and debug
|
| 83 |
# app.config["LOGIN_DISABLED"] = True
|
| 84 |
app.config["SESSION_PERMANENT"] = False
|
|
@@ -110,7 +108,7 @@ def register_page(page_path):
|
|
| 110 |
|
| 111 |
page_name = page_path.stem.rstrip("_app")
|
| 112 |
module_name = ".".join(
|
| 113 |
-
page_path.parts[page_path.parts.index("api")
|
| 114 |
)
|
| 115 |
|
| 116 |
spec = spec_from_file_location(module_name, page_path)
|
|
@@ -121,7 +119,7 @@ def register_page(page_path):
|
|
| 121 |
spec.loader.exec_module(page)
|
| 122 |
page_name = getattr(page, "page_name", page_name)
|
| 123 |
url_prefix = (
|
| 124 |
-
f"/api/{API_VERSION}" if "/sdk/" in path else f"/{API_VERSION}/{page_name}"
|
| 125 |
)
|
| 126 |
|
| 127 |
app.register_blueprint(page.manager, url_prefix=url_prefix)
|
|
@@ -141,7 +139,7 @@ client_urls_prefix = [
|
|
| 141 |
|
| 142 |
@login_manager.request_loader
|
| 143 |
def load_user(web_request):
|
| 144 |
-
jwt = Serializer(secret_key=SECRET_KEY)
|
| 145 |
authorization = web_request.headers.get("Authorization")
|
| 146 |
if authorization:
|
| 147 |
try:
|
|
|
|
| 30 |
|
| 31 |
from flask_session import Session
|
| 32 |
from flask_login import LoginManager
|
| 33 |
+
from api import settings
|
|
|
|
| 34 |
from api.utils.api_utils import server_error_response
|
| 35 |
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
| 36 |
|
|
|
|
| 77 |
app.json_encoder = CustomJSONEncoder
|
| 78 |
app.errorhandler(Exception)(server_error_response)
|
| 79 |
|
|
|
|
| 80 |
## convince for dev and debug
|
| 81 |
# app.config["LOGIN_DISABLED"] = True
|
| 82 |
app.config["SESSION_PERMANENT"] = False
|
|
|
|
| 108 |
|
| 109 |
page_name = page_path.stem.rstrip("_app")
|
| 110 |
module_name = ".".join(
|
| 111 |
+
page_path.parts[page_path.parts.index("api"): -1] + (page_name,)
|
| 112 |
)
|
| 113 |
|
| 114 |
spec = spec_from_file_location(module_name, page_path)
|
|
|
|
| 119 |
spec.loader.exec_module(page)
|
| 120 |
page_name = getattr(page, "page_name", page_name)
|
| 121 |
url_prefix = (
|
| 122 |
+
f"/api/{settings.API_VERSION}" if "/sdk/" in path else f"/{settings.API_VERSION}/{page_name}"
|
| 123 |
)
|
| 124 |
|
| 125 |
app.register_blueprint(page.manager, url_prefix=url_prefix)
|
|
|
|
| 139 |
|
| 140 |
@login_manager.request_loader
|
| 141 |
def load_user(web_request):
|
| 142 |
+
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
| 143 |
authorization = web_request.headers.get("Authorization")
|
| 144 |
if authorization:
|
| 145 |
try:
|
api/apps/api_app.py
CHANGED
|
@@ -32,7 +32,7 @@ from api.db.services.file_service import FileService
|
|
| 32 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 33 |
from api.db.services.task_service import queue_tasks, TaskService
|
| 34 |
from api.db.services.user_service import UserTenantService
|
| 35 |
-
from api
|
| 36 |
from api.utils import get_uuid, current_timestamp, datetime_format
|
| 37 |
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
|
| 38 |
generate_confirmation_token
|
|
@@ -141,7 +141,7 @@ def set_conversation():
|
|
| 141 |
objs = APIToken.query(token=token)
|
| 142 |
if not objs:
|
| 143 |
return get_json_result(
|
| 144 |
-
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
| 145 |
req = request.json
|
| 146 |
try:
|
| 147 |
if objs[0].source == "agent":
|
|
@@ -183,7 +183,7 @@ def completion():
|
|
| 183 |
objs = APIToken.query(token=token)
|
| 184 |
if not objs:
|
| 185 |
return get_json_result(
|
| 186 |
-
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
| 187 |
req = request.json
|
| 188 |
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
| 189 |
if not e:
|
|
@@ -290,8 +290,8 @@ def completion():
|
|
| 290 |
API4ConversationService.append_message(conv.id, conv.to_dict())
|
| 291 |
rename_field(result)
|
| 292 |
return get_json_result(data=result)
|
| 293 |
-
|
| 294 |
-
|
| 295 |
conv.message.append(msg[-1])
|
| 296 |
e, dia = DialogService.get_by_id(conv.dialog_id)
|
| 297 |
if not e:
|
|
@@ -326,7 +326,7 @@ def completion():
|
|
| 326 |
resp.headers.add_header("X-Accel-Buffering", "no")
|
| 327 |
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
| 328 |
return resp
|
| 329 |
-
|
| 330 |
answer = None
|
| 331 |
for ans in chat(dia, msg, **req):
|
| 332 |
answer = ans
|
|
@@ -347,8 +347,8 @@ def get(conversation_id):
|
|
| 347 |
objs = APIToken.query(token=token)
|
| 348 |
if not objs:
|
| 349 |
return get_json_result(
|
| 350 |
-
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
| 351 |
-
|
| 352 |
try:
|
| 353 |
e, conv = API4ConversationService.get_by_id(conversation_id)
|
| 354 |
if not e:
|
|
@@ -357,8 +357,8 @@ def get(conversation_id):
|
|
| 357 |
conv = conv.to_dict()
|
| 358 |
if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
|
| 359 |
return get_json_result(data=False, message='Token is not valid for this conversation_id!"',
|
| 360 |
-
code=RetCode.AUTHENTICATION_ERROR)
|
| 361 |
-
|
| 362 |
for referenct_i in conv['reference']:
|
| 363 |
if referenct_i is None or len(referenct_i) == 0:
|
| 364 |
continue
|
|
@@ -378,7 +378,7 @@ def upload():
|
|
| 378 |
objs = APIToken.query(token=token)
|
| 379 |
if not objs:
|
| 380 |
return get_json_result(
|
| 381 |
-
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
| 382 |
|
| 383 |
kb_name = request.form.get("kb_name").strip()
|
| 384 |
tenant_id = objs[0].tenant_id
|
|
@@ -394,12 +394,12 @@ def upload():
|
|
| 394 |
|
| 395 |
if 'file' not in request.files:
|
| 396 |
return get_json_result(
|
| 397 |
-
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
| 398 |
|
| 399 |
file = request.files['file']
|
| 400 |
if file.filename == '':
|
| 401 |
return get_json_result(
|
| 402 |
-
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
| 403 |
|
| 404 |
root_folder = FileService.get_root_folder(tenant_id)
|
| 405 |
pf_id = root_folder["id"]
|
|
@@ -490,17 +490,17 @@ def upload_parse():
|
|
| 490 |
objs = APIToken.query(token=token)
|
| 491 |
if not objs:
|
| 492 |
return get_json_result(
|
| 493 |
-
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
| 494 |
|
| 495 |
if 'file' not in request.files:
|
| 496 |
return get_json_result(
|
| 497 |
-
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
| 498 |
|
| 499 |
file_objs = request.files.getlist('file')
|
| 500 |
for file_obj in file_objs:
|
| 501 |
if file_obj.filename == '':
|
| 502 |
return get_json_result(
|
| 503 |
-
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
| 504 |
|
| 505 |
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
|
| 506 |
return get_json_result(data=doc_ids)
|
|
@@ -513,7 +513,7 @@ def list_chunks():
|
|
| 513 |
objs = APIToken.query(token=token)
|
| 514 |
if not objs:
|
| 515 |
return get_json_result(
|
| 516 |
-
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
| 517 |
|
| 518 |
req = request.json
|
| 519 |
|
|
@@ -531,7 +531,7 @@ def list_chunks():
|
|
| 531 |
)
|
| 532 |
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
| 533 |
|
| 534 |
-
res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
|
| 535 |
res = [
|
| 536 |
{
|
| 537 |
"content": res_item["content_with_weight"],
|
|
@@ -553,7 +553,7 @@ def list_kb_docs():
|
|
| 553 |
objs = APIToken.query(token=token)
|
| 554 |
if not objs:
|
| 555 |
return get_json_result(
|
| 556 |
-
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
| 557 |
|
| 558 |
req = request.json
|
| 559 |
tenant_id = objs[0].tenant_id
|
|
@@ -585,6 +585,7 @@ def list_kb_docs():
|
|
| 585 |
except Exception as e:
|
| 586 |
return server_error_response(e)
|
| 587 |
|
|
|
|
| 588 |
@manager.route('/document/infos', methods=['POST'])
|
| 589 |
@validate_request("doc_ids")
|
| 590 |
def docinfos():
|
|
@@ -592,7 +593,7 @@ def docinfos():
|
|
| 592 |
objs = APIToken.query(token=token)
|
| 593 |
if not objs:
|
| 594 |
return get_json_result(
|
| 595 |
-
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
| 596 |
req = request.json
|
| 597 |
doc_ids = req["doc_ids"]
|
| 598 |
docs = DocumentService.get_by_ids(doc_ids)
|
|
@@ -606,7 +607,7 @@ def document_rm():
|
|
| 606 |
objs = APIToken.query(token=token)
|
| 607 |
if not objs:
|
| 608 |
return get_json_result(
|
| 609 |
-
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
| 610 |
|
| 611 |
tenant_id = objs[0].tenant_id
|
| 612 |
req = request.json
|
|
@@ -653,7 +654,7 @@ def document_rm():
|
|
| 653 |
errors += str(e)
|
| 654 |
|
| 655 |
if errors:
|
| 656 |
-
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
|
| 657 |
|
| 658 |
return get_json_result(data=True)
|
| 659 |
|
|
@@ -668,7 +669,7 @@ def completion_faq():
|
|
| 668 |
objs = APIToken.query(token=token)
|
| 669 |
if not objs:
|
| 670 |
return get_json_result(
|
| 671 |
-
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
| 672 |
|
| 673 |
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
| 674 |
if not e:
|
|
@@ -805,10 +806,10 @@ def retrieval():
|
|
| 805 |
objs = APIToken.query(token=token)
|
| 806 |
if not objs:
|
| 807 |
return get_json_result(
|
| 808 |
-
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
| 809 |
|
| 810 |
req = request.json
|
| 811 |
-
kb_ids = req.get("kb_id",[])
|
| 812 |
doc_ids = req.get("doc_ids", [])
|
| 813 |
question = req.get("question")
|
| 814 |
page = int(req.get("page", 1))
|
|
@@ -822,20 +823,21 @@ def retrieval():
|
|
| 822 |
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
| 823 |
if len(embd_nms) != 1:
|
| 824 |
return get_json_result(
|
| 825 |
-
data=False, message='Knowledge bases use different embedding models or does not exist."',
|
|
|
|
| 826 |
|
| 827 |
embd_mdl = TenantLLMService.model_instance(
|
| 828 |
kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
|
| 829 |
rerank_mdl = None
|
| 830 |
if req.get("rerank_id"):
|
| 831 |
rerank_mdl = TenantLLMService.model_instance(
|
| 832 |
-
|
| 833 |
if req.get("keyword", False):
|
| 834 |
chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
|
| 835 |
question += keyword_extraction(chat_mdl, question)
|
| 836 |
-
ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
|
| 837 |
-
|
| 838 |
-
|
| 839 |
for c in ranks["chunks"]:
|
| 840 |
if "vector" in c:
|
| 841 |
del c["vector"]
|
|
@@ -843,5 +845,5 @@ def retrieval():
|
|
| 843 |
except Exception as e:
|
| 844 |
if str(e).find("not_found") > 0:
|
| 845 |
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
| 846 |
-
code=RetCode.DATA_ERROR)
|
| 847 |
return server_error_response(e)
|
|
|
|
| 32 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 33 |
from api.db.services.task_service import queue_tasks, TaskService
|
| 34 |
from api.db.services.user_service import UserTenantService
|
| 35 |
+
from api import settings
|
| 36 |
from api.utils import get_uuid, current_timestamp, datetime_format
|
| 37 |
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
|
| 38 |
generate_confirmation_token
|
|
|
|
| 141 |
objs = APIToken.query(token=token)
|
| 142 |
if not objs:
|
| 143 |
return get_json_result(
|
| 144 |
+
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
| 145 |
req = request.json
|
| 146 |
try:
|
| 147 |
if objs[0].source == "agent":
|
|
|
|
| 183 |
objs = APIToken.query(token=token)
|
| 184 |
if not objs:
|
| 185 |
return get_json_result(
|
| 186 |
+
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
| 187 |
req = request.json
|
| 188 |
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
| 189 |
if not e:
|
|
|
|
| 290 |
API4ConversationService.append_message(conv.id, conv.to_dict())
|
| 291 |
rename_field(result)
|
| 292 |
return get_json_result(data=result)
|
| 293 |
+
|
| 294 |
+
# ******************For dialog******************
|
| 295 |
conv.message.append(msg[-1])
|
| 296 |
e, dia = DialogService.get_by_id(conv.dialog_id)
|
| 297 |
if not e:
|
|
|
|
| 326 |
resp.headers.add_header("X-Accel-Buffering", "no")
|
| 327 |
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
| 328 |
return resp
|
| 329 |
+
|
| 330 |
answer = None
|
| 331 |
for ans in chat(dia, msg, **req):
|
| 332 |
answer = ans
|
|
|
|
| 347 |
objs = APIToken.query(token=token)
|
| 348 |
if not objs:
|
| 349 |
return get_json_result(
|
| 350 |
+
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
| 351 |
+
|
| 352 |
try:
|
| 353 |
e, conv = API4ConversationService.get_by_id(conversation_id)
|
| 354 |
if not e:
|
|
|
|
| 357 |
conv = conv.to_dict()
|
| 358 |
if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
|
| 359 |
return get_json_result(data=False, message='Token is not valid for this conversation_id!"',
|
| 360 |
+
code=settings.RetCode.AUTHENTICATION_ERROR)
|
| 361 |
+
|
| 362 |
for referenct_i in conv['reference']:
|
| 363 |
if referenct_i is None or len(referenct_i) == 0:
|
| 364 |
continue
|
|
|
|
| 378 |
objs = APIToken.query(token=token)
|
| 379 |
if not objs:
|
| 380 |
return get_json_result(
|
| 381 |
+
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
| 382 |
|
| 383 |
kb_name = request.form.get("kb_name").strip()
|
| 384 |
tenant_id = objs[0].tenant_id
|
|
|
|
| 394 |
|
| 395 |
if 'file' not in request.files:
|
| 396 |
return get_json_result(
|
| 397 |
+
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
| 398 |
|
| 399 |
file = request.files['file']
|
| 400 |
if file.filename == '':
|
| 401 |
return get_json_result(
|
| 402 |
+
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
| 403 |
|
| 404 |
root_folder = FileService.get_root_folder(tenant_id)
|
| 405 |
pf_id = root_folder["id"]
|
|
|
|
| 490 |
objs = APIToken.query(token=token)
|
| 491 |
if not objs:
|
| 492 |
return get_json_result(
|
| 493 |
+
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
| 494 |
|
| 495 |
if 'file' not in request.files:
|
| 496 |
return get_json_result(
|
| 497 |
+
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
| 498 |
|
| 499 |
file_objs = request.files.getlist('file')
|
| 500 |
for file_obj in file_objs:
|
| 501 |
if file_obj.filename == '':
|
| 502 |
return get_json_result(
|
| 503 |
+
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
| 504 |
|
| 505 |
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
|
| 506 |
return get_json_result(data=doc_ids)
|
|
|
|
| 513 |
objs = APIToken.query(token=token)
|
| 514 |
if not objs:
|
| 515 |
return get_json_result(
|
| 516 |
+
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
| 517 |
|
| 518 |
req = request.json
|
| 519 |
|
|
|
|
| 531 |
)
|
| 532 |
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
| 533 |
|
| 534 |
+
res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
|
| 535 |
res = [
|
| 536 |
{
|
| 537 |
"content": res_item["content_with_weight"],
|
|
|
|
| 553 |
objs = APIToken.query(token=token)
|
| 554 |
if not objs:
|
| 555 |
return get_json_result(
|
| 556 |
+
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
| 557 |
|
| 558 |
req = request.json
|
| 559 |
tenant_id = objs[0].tenant_id
|
|
|
|
| 585 |
except Exception as e:
|
| 586 |
return server_error_response(e)
|
| 587 |
|
| 588 |
+
|
| 589 |
@manager.route('/document/infos', methods=['POST'])
|
| 590 |
@validate_request("doc_ids")
|
| 591 |
def docinfos():
|
|
|
|
| 593 |
objs = APIToken.query(token=token)
|
| 594 |
if not objs:
|
| 595 |
return get_json_result(
|
| 596 |
+
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
| 597 |
req = request.json
|
| 598 |
doc_ids = req["doc_ids"]
|
| 599 |
docs = DocumentService.get_by_ids(doc_ids)
|
|
|
|
| 607 |
objs = APIToken.query(token=token)
|
| 608 |
if not objs:
|
| 609 |
return get_json_result(
|
| 610 |
+
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
| 611 |
|
| 612 |
tenant_id = objs[0].tenant_id
|
| 613 |
req = request.json
|
|
|
|
| 654 |
errors += str(e)
|
| 655 |
|
| 656 |
if errors:
|
| 657 |
+
return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
|
| 658 |
|
| 659 |
return get_json_result(data=True)
|
| 660 |
|
|
|
|
| 669 |
objs = APIToken.query(token=token)
|
| 670 |
if not objs:
|
| 671 |
return get_json_result(
|
| 672 |
+
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
| 673 |
|
| 674 |
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
| 675 |
if not e:
|
|
|
|
| 806 |
objs = APIToken.query(token=token)
|
| 807 |
if not objs:
|
| 808 |
return get_json_result(
|
| 809 |
+
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
| 810 |
|
| 811 |
req = request.json
|
| 812 |
+
kb_ids = req.get("kb_id", [])
|
| 813 |
doc_ids = req.get("doc_ids", [])
|
| 814 |
question = req.get("question")
|
| 815 |
page = int(req.get("page", 1))
|
|
|
|
| 823 |
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
| 824 |
if len(embd_nms) != 1:
|
| 825 |
return get_json_result(
|
| 826 |
+
data=False, message='Knowledge bases use different embedding models or does not exist."',
|
| 827 |
+
code=settings.RetCode.AUTHENTICATION_ERROR)
|
| 828 |
|
| 829 |
embd_mdl = TenantLLMService.model_instance(
|
| 830 |
kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
|
| 831 |
rerank_mdl = None
|
| 832 |
if req.get("rerank_id"):
|
| 833 |
rerank_mdl = TenantLLMService.model_instance(
|
| 834 |
+
kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
|
| 835 |
if req.get("keyword", False):
|
| 836 |
chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
|
| 837 |
question += keyword_extraction(chat_mdl, question)
|
| 838 |
+
ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
|
| 839 |
+
similarity_threshold, vector_similarity_weight, top,
|
| 840 |
+
doc_ids, rerank_mdl=rerank_mdl)
|
| 841 |
for c in ranks["chunks"]:
|
| 842 |
if "vector" in c:
|
| 843 |
del c["vector"]
|
|
|
|
| 845 |
except Exception as e:
|
| 846 |
if str(e).find("not_found") > 0:
|
| 847 |
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
| 848 |
+
code=settings.RetCode.DATA_ERROR)
|
| 849 |
return server_error_response(e)
|
api/apps/canvas_app.py
CHANGED
|
@@ -19,7 +19,7 @@ from functools import partial
|
|
| 19 |
from flask import request, Response
|
| 20 |
from flask_login import login_required, current_user
|
| 21 |
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
|
| 22 |
-
from api
|
| 23 |
from api.utils import get_uuid
|
| 24 |
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
|
| 25 |
from agent.canvas import Canvas
|
|
@@ -36,7 +36,8 @@ def templates():
|
|
| 36 |
@login_required
|
| 37 |
def canvas_list():
|
| 38 |
return get_json_result(data=sorted([c.to_dict() for c in \
|
| 39 |
-
|
|
|
|
| 40 |
)
|
| 41 |
|
| 42 |
|
|
@@ -45,10 +46,10 @@ def canvas_list():
|
|
| 45 |
@login_required
|
| 46 |
def rm():
|
| 47 |
for i in request.json["canvas_ids"]:
|
| 48 |
-
if not UserCanvasService.query(user_id=current_user.id,id=i):
|
| 49 |
return get_json_result(
|
| 50 |
data=False, message='Only owner of canvas authorized for this operation.',
|
| 51 |
-
code=RetCode.OPERATING_ERROR)
|
| 52 |
UserCanvasService.delete_by_id(i)
|
| 53 |
return get_json_result(data=True)
|
| 54 |
|
|
@@ -72,7 +73,7 @@ def save():
|
|
| 72 |
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
|
| 73 |
return get_json_result(
|
| 74 |
data=False, message='Only owner of canvas authorized for this operation.',
|
| 75 |
-
code=RetCode.OPERATING_ERROR)
|
| 76 |
UserCanvasService.update_by_id(req["id"], req)
|
| 77 |
return get_json_result(data=req)
|
| 78 |
|
|
@@ -98,7 +99,7 @@ def run():
|
|
| 98 |
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
|
| 99 |
return get_json_result(
|
| 100 |
data=False, message='Only owner of canvas authorized for this operation.',
|
| 101 |
-
code=RetCode.OPERATING_ERROR)
|
| 102 |
|
| 103 |
if not isinstance(cvs.dsl, str):
|
| 104 |
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
|
@@ -110,8 +111,8 @@ def run():
|
|
| 110 |
if "message" in req:
|
| 111 |
canvas.messages.append({"role": "user", "content": req["message"], "id": message_id})
|
| 112 |
if len([m for m in canvas.messages if m["role"] == "user"]) > 1:
|
| 113 |
-
#ten = TenantService.get_info_by(current_user.id)[0]
|
| 114 |
-
#req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
|
| 115 |
pass
|
| 116 |
canvas.add_user_input(req["message"])
|
| 117 |
answer = canvas.run(stream=stream)
|
|
@@ -122,7 +123,8 @@ def run():
|
|
| 122 |
assert answer is not None, "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
|
| 123 |
|
| 124 |
if stream:
|
| 125 |
-
assert isinstance(answer,
|
|
|
|
| 126 |
|
| 127 |
def sse():
|
| 128 |
nonlocal answer, cvs
|
|
@@ -173,7 +175,7 @@ def reset():
|
|
| 173 |
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
|
| 174 |
return get_json_result(
|
| 175 |
data=False, message='Only owner of canvas authorized for this operation.',
|
| 176 |
-
code=RetCode.OPERATING_ERROR)
|
| 177 |
|
| 178 |
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
|
| 179 |
canvas.reset()
|
|
|
|
| 19 |
from flask import request, Response
|
| 20 |
from flask_login import login_required, current_user
|
| 21 |
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
|
| 22 |
+
from api import settings
|
| 23 |
from api.utils import get_uuid
|
| 24 |
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
|
| 25 |
from agent.canvas import Canvas
|
|
|
|
| 36 |
@login_required
|
| 37 |
def canvas_list():
|
| 38 |
return get_json_result(data=sorted([c.to_dict() for c in \
|
| 39 |
+
UserCanvasService.query(user_id=current_user.id)],
|
| 40 |
+
key=lambda x: x["update_time"] * -1)
|
| 41 |
)
|
| 42 |
|
| 43 |
|
|
|
|
| 46 |
@login_required
|
| 47 |
def rm():
|
| 48 |
for i in request.json["canvas_ids"]:
|
| 49 |
+
if not UserCanvasService.query(user_id=current_user.id, id=i):
|
| 50 |
return get_json_result(
|
| 51 |
data=False, message='Only owner of canvas authorized for this operation.',
|
| 52 |
+
code=settings.RetCode.OPERATING_ERROR)
|
| 53 |
UserCanvasService.delete_by_id(i)
|
| 54 |
return get_json_result(data=True)
|
| 55 |
|
|
|
|
| 73 |
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
|
| 74 |
return get_json_result(
|
| 75 |
data=False, message='Only owner of canvas authorized for this operation.',
|
| 76 |
+
code=settings.RetCode.OPERATING_ERROR)
|
| 77 |
UserCanvasService.update_by_id(req["id"], req)
|
| 78 |
return get_json_result(data=req)
|
| 79 |
|
|
|
|
| 99 |
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
|
| 100 |
return get_json_result(
|
| 101 |
data=False, message='Only owner of canvas authorized for this operation.',
|
| 102 |
+
code=settings.RetCode.OPERATING_ERROR)
|
| 103 |
|
| 104 |
if not isinstance(cvs.dsl, str):
|
| 105 |
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
|
|
|
| 111 |
if "message" in req:
|
| 112 |
canvas.messages.append({"role": "user", "content": req["message"], "id": message_id})
|
| 113 |
if len([m for m in canvas.messages if m["role"] == "user"]) > 1:
|
| 114 |
+
# ten = TenantService.get_info_by(current_user.id)[0]
|
| 115 |
+
# req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
|
| 116 |
pass
|
| 117 |
canvas.add_user_input(req["message"])
|
| 118 |
answer = canvas.run(stream=stream)
|
|
|
|
| 123 |
assert answer is not None, "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
|
| 124 |
|
| 125 |
if stream:
|
| 126 |
+
assert isinstance(answer,
|
| 127 |
+
partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
|
| 128 |
|
| 129 |
def sse():
|
| 130 |
nonlocal answer, cvs
|
|
|
|
| 175 |
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
|
| 176 |
return get_json_result(
|
| 177 |
data=False, message='Only owner of canvas authorized for this operation.',
|
| 178 |
+
code=settings.RetCode.OPERATING_ERROR)
|
| 179 |
|
| 180 |
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
|
| 181 |
canvas.reset()
|
api/apps/chunk_app.py
CHANGED
|
@@ -29,11 +29,12 @@ from api.db.services.llm_service import LLMBundle
|
|
| 29 |
from api.db.services.user_service import UserTenantService
|
| 30 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
| 31 |
from api.db.services.document_service import DocumentService
|
| 32 |
-
from api
|
| 33 |
from api.utils.api_utils import get_json_result
|
| 34 |
import hashlib
|
| 35 |
import re
|
| 36 |
|
|
|
|
| 37 |
@manager.route('/list', methods=['POST'])
|
| 38 |
@login_required
|
| 39 |
@validate_request("doc_id")
|
|
@@ -56,7 +57,7 @@ def list_chunk():
|
|
| 56 |
}
|
| 57 |
if "available_int" in req:
|
| 58 |
query["available_int"] = int(req["available_int"])
|
| 59 |
-
sres = retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
|
| 60 |
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
| 61 |
for id in sres.ids:
|
| 62 |
d = {
|
|
@@ -72,13 +73,13 @@ def list_chunk():
|
|
| 72 |
"positions": json.loads(sres.field[id].get("position_list", "[]")),
|
| 73 |
}
|
| 74 |
assert isinstance(d["positions"], list)
|
| 75 |
-
assert len(d["positions"])==0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
|
| 76 |
res["chunks"].append(d)
|
| 77 |
return get_json_result(data=res)
|
| 78 |
except Exception as e:
|
| 79 |
if str(e).find("not_found") > 0:
|
| 80 |
return get_json_result(data=False, message='No chunk found!',
|
| 81 |
-
code=RetCode.DATA_ERROR)
|
| 82 |
return server_error_response(e)
|
| 83 |
|
| 84 |
|
|
@@ -93,7 +94,7 @@ def get():
|
|
| 93 |
tenant_id = tenants[0].tenant_id
|
| 94 |
|
| 95 |
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
| 96 |
-
chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
|
| 97 |
if chunk is None:
|
| 98 |
return server_error_response("Chunk not found")
|
| 99 |
k = []
|
|
@@ -107,7 +108,7 @@ def get():
|
|
| 107 |
except Exception as e:
|
| 108 |
if str(e).find("NotFoundError") >= 0:
|
| 109 |
return get_json_result(data=False, message='Chunk not found!',
|
| 110 |
-
code=RetCode.DATA_ERROR)
|
| 111 |
return server_error_response(e)
|
| 112 |
|
| 113 |
|
|
@@ -154,7 +155,7 @@ def set():
|
|
| 154 |
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
| 155 |
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
| 156 |
d["q_%d_vec" % len(v)] = v.tolist()
|
| 157 |
-
docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
| 158 |
return get_json_result(data=True)
|
| 159 |
except Exception as e:
|
| 160 |
return server_error_response(e)
|
|
@@ -169,8 +170,8 @@ def switch():
|
|
| 169 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
| 170 |
if not e:
|
| 171 |
return get_data_error_result(message="Document not found!")
|
| 172 |
-
if not docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
|
| 173 |
-
|
| 174 |
return get_data_error_result(message="Index updating failure")
|
| 175 |
return get_json_result(data=True)
|
| 176 |
except Exception as e:
|
|
@@ -186,7 +187,7 @@ def rm():
|
|
| 186 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
| 187 |
if not e:
|
| 188 |
return get_data_error_result(message="Document not found!")
|
| 189 |
-
if not docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
|
| 190 |
return get_data_error_result(message="Index updating failure")
|
| 191 |
deleted_chunk_ids = req["chunk_ids"]
|
| 192 |
chunk_number = len(deleted_chunk_ids)
|
|
@@ -230,7 +231,7 @@ def create():
|
|
| 230 |
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
| 231 |
v = 0.1 * v[0] + 0.9 * v[1]
|
| 232 |
d["q_%d_vec" % len(v)] = v.tolist()
|
| 233 |
-
docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
| 234 |
|
| 235 |
DocumentService.increment_chunk_num(
|
| 236 |
doc.id, doc.kb_id, c, 1, 0)
|
|
@@ -265,7 +266,7 @@ def retrieval_test():
|
|
| 265 |
else:
|
| 266 |
return get_json_result(
|
| 267 |
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
| 268 |
-
code=RetCode.OPERATING_ERROR)
|
| 269 |
|
| 270 |
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
| 271 |
if not e:
|
|
@@ -281,7 +282,7 @@ def retrieval_test():
|
|
| 281 |
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
| 282 |
question += keyword_extraction(chat_mdl, question)
|
| 283 |
|
| 284 |
-
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
|
| 285 |
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
|
| 286 |
similarity_threshold, vector_similarity_weight, top,
|
| 287 |
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
|
|
@@ -293,7 +294,7 @@ def retrieval_test():
|
|
| 293 |
except Exception as e:
|
| 294 |
if str(e).find("not_found") > 0:
|
| 295 |
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
| 296 |
-
code=RetCode.DATA_ERROR)
|
| 297 |
return server_error_response(e)
|
| 298 |
|
| 299 |
|
|
@@ -304,10 +305,10 @@ def knowledge_graph():
|
|
| 304 |
tenant_id = DocumentService.get_tenant_id(doc_id)
|
| 305 |
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
| 306 |
req = {
|
| 307 |
-
"doc_ids":[doc_id],
|
| 308 |
"knowledge_graph_kwd": ["graph", "mind_map"]
|
| 309 |
}
|
| 310 |
-
sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids)
|
| 311 |
obj = {"graph": {}, "mind_map": {}}
|
| 312 |
for id in sres.ids[:2]:
|
| 313 |
ty = sres.field[id]["knowledge_graph_kwd"]
|
|
@@ -336,4 +337,3 @@ def knowledge_graph():
|
|
| 336 |
obj[ty] = content_json
|
| 337 |
|
| 338 |
return get_json_result(data=obj)
|
| 339 |
-
|
|
|
|
| 29 |
from api.db.services.user_service import UserTenantService
|
| 30 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
| 31 |
from api.db.services.document_service import DocumentService
|
| 32 |
+
from api import settings
|
| 33 |
from api.utils.api_utils import get_json_result
|
| 34 |
import hashlib
|
| 35 |
import re
|
| 36 |
|
| 37 |
+
|
| 38 |
@manager.route('/list', methods=['POST'])
|
| 39 |
@login_required
|
| 40 |
@validate_request("doc_id")
|
|
|
|
| 57 |
}
|
| 58 |
if "available_int" in req:
|
| 59 |
query["available_int"] = int(req["available_int"])
|
| 60 |
+
sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
|
| 61 |
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
| 62 |
for id in sres.ids:
|
| 63 |
d = {
|
|
|
|
| 73 |
"positions": json.loads(sres.field[id].get("position_list", "[]")),
|
| 74 |
}
|
| 75 |
assert isinstance(d["positions"], list)
|
| 76 |
+
assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
|
| 77 |
res["chunks"].append(d)
|
| 78 |
return get_json_result(data=res)
|
| 79 |
except Exception as e:
|
| 80 |
if str(e).find("not_found") > 0:
|
| 81 |
return get_json_result(data=False, message='No chunk found!',
|
| 82 |
+
code=settings.RetCode.DATA_ERROR)
|
| 83 |
return server_error_response(e)
|
| 84 |
|
| 85 |
|
|
|
|
| 94 |
tenant_id = tenants[0].tenant_id
|
| 95 |
|
| 96 |
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
| 97 |
+
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
|
| 98 |
if chunk is None:
|
| 99 |
return server_error_response("Chunk not found")
|
| 100 |
k = []
|
|
|
|
| 108 |
except Exception as e:
|
| 109 |
if str(e).find("NotFoundError") >= 0:
|
| 110 |
return get_json_result(data=False, message='Chunk not found!',
|
| 111 |
+
code=settings.RetCode.DATA_ERROR)
|
| 112 |
return server_error_response(e)
|
| 113 |
|
| 114 |
|
|
|
|
| 155 |
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
| 156 |
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
| 157 |
d["q_%d_vec" % len(v)] = v.tolist()
|
| 158 |
+
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
| 159 |
return get_json_result(data=True)
|
| 160 |
except Exception as e:
|
| 161 |
return server_error_response(e)
|
|
|
|
| 170 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
| 171 |
if not e:
|
| 172 |
return get_data_error_result(message="Document not found!")
|
| 173 |
+
if not settings.docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
|
| 174 |
+
search.index_name(doc.tenant_id), doc.kb_id):
|
| 175 |
return get_data_error_result(message="Index updating failure")
|
| 176 |
return get_json_result(data=True)
|
| 177 |
except Exception as e:
|
|
|
|
| 187 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
| 188 |
if not e:
|
| 189 |
return get_data_error_result(message="Document not found!")
|
| 190 |
+
if not settings.docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
|
| 191 |
return get_data_error_result(message="Index updating failure")
|
| 192 |
deleted_chunk_ids = req["chunk_ids"]
|
| 193 |
chunk_number = len(deleted_chunk_ids)
|
|
|
|
| 231 |
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
| 232 |
v = 0.1 * v[0] + 0.9 * v[1]
|
| 233 |
d["q_%d_vec" % len(v)] = v.tolist()
|
| 234 |
+
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
| 235 |
|
| 236 |
DocumentService.increment_chunk_num(
|
| 237 |
doc.id, doc.kb_id, c, 1, 0)
|
|
|
|
| 266 |
else:
|
| 267 |
return get_json_result(
|
| 268 |
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
| 269 |
+
code=settings.RetCode.OPERATING_ERROR)
|
| 270 |
|
| 271 |
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
| 272 |
if not e:
|
|
|
|
| 282 |
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
| 283 |
question += keyword_extraction(chat_mdl, question)
|
| 284 |
|
| 285 |
+
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
|
| 286 |
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
|
| 287 |
similarity_threshold, vector_similarity_weight, top,
|
| 288 |
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
|
|
|
|
| 294 |
except Exception as e:
|
| 295 |
if str(e).find("not_found") > 0:
|
| 296 |
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
| 297 |
+
code=settings.RetCode.DATA_ERROR)
|
| 298 |
return server_error_response(e)
|
| 299 |
|
| 300 |
|
|
|
|
| 305 |
tenant_id = DocumentService.get_tenant_id(doc_id)
|
| 306 |
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
| 307 |
req = {
|
| 308 |
+
"doc_ids": [doc_id],
|
| 309 |
"knowledge_graph_kwd": ["graph", "mind_map"]
|
| 310 |
}
|
| 311 |
+
sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
|
| 312 |
obj = {"graph": {}, "mind_map": {}}
|
| 313 |
for id in sres.ids[:2]:
|
| 314 |
ty = sres.field[id]["knowledge_graph_kwd"]
|
|
|
|
| 337 |
obj[ty] = content_json
|
| 338 |
|
| 339 |
return get_json_result(data=obj)
|
|
|
api/apps/conversation_app.py
CHANGED
|
@@ -25,7 +25,7 @@ from api.db import LLMType
|
|
| 25 |
from api.db.services.dialog_service import DialogService, ConversationService, chat, ask
|
| 26 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 27 |
from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
|
| 28 |
-
from api
|
| 29 |
from api.utils.api_utils import get_json_result
|
| 30 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
| 31 |
from graphrag.mind_map_extractor import MindMapExtractor
|
|
@@ -87,7 +87,7 @@ def get():
|
|
| 87 |
else:
|
| 88 |
return get_json_result(
|
| 89 |
data=False, message='Only owner of conversation authorized for this operation.',
|
| 90 |
-
code=RetCode.OPERATING_ERROR)
|
| 91 |
conv = conv.to_dict()
|
| 92 |
return get_json_result(data=conv)
|
| 93 |
except Exception as e:
|
|
@@ -110,7 +110,7 @@ def rm():
|
|
| 110 |
else:
|
| 111 |
return get_json_result(
|
| 112 |
data=False, message='Only owner of conversation authorized for this operation.',
|
| 113 |
-
code=RetCode.OPERATING_ERROR)
|
| 114 |
ConversationService.delete_by_id(cid)
|
| 115 |
return get_json_result(data=True)
|
| 116 |
except Exception as e:
|
|
@@ -125,7 +125,7 @@ def list_convsersation():
|
|
| 125 |
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
|
| 126 |
return get_json_result(
|
| 127 |
data=False, message='Only owner of dialog authorized for this operation.',
|
| 128 |
-
code=RetCode.OPERATING_ERROR)
|
| 129 |
convs = ConversationService.query(
|
| 130 |
dialog_id=dialog_id,
|
| 131 |
order_by=ConversationService.model.create_time,
|
|
@@ -297,6 +297,7 @@ def thumbup():
|
|
| 297 |
def ask_about():
|
| 298 |
req = request.json
|
| 299 |
uid = current_user.id
|
|
|
|
| 300 |
def stream():
|
| 301 |
nonlocal req, uid
|
| 302 |
try:
|
|
@@ -329,8 +330,8 @@ def mindmap():
|
|
| 329 |
embd_mdl = TenantLLMService.model_instance(
|
| 330 |
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
| 331 |
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
|
| 332 |
-
ranks = retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
|
| 333 |
-
|
| 334 |
mindmap = MindMapExtractor(chat_mdl)
|
| 335 |
mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
|
| 336 |
if "error" in mind_map:
|
|
|
|
| 25 |
from api.db.services.dialog_service import DialogService, ConversationService, chat, ask
|
| 26 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 27 |
from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
|
| 28 |
+
from api import settings
|
| 29 |
from api.utils.api_utils import get_json_result
|
| 30 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
| 31 |
from graphrag.mind_map_extractor import MindMapExtractor
|
|
|
|
| 87 |
else:
|
| 88 |
return get_json_result(
|
| 89 |
data=False, message='Only owner of conversation authorized for this operation.',
|
| 90 |
+
code=settings.RetCode.OPERATING_ERROR)
|
| 91 |
conv = conv.to_dict()
|
| 92 |
return get_json_result(data=conv)
|
| 93 |
except Exception as e:
|
|
|
|
| 110 |
else:
|
| 111 |
return get_json_result(
|
| 112 |
data=False, message='Only owner of conversation authorized for this operation.',
|
| 113 |
+
code=settings.RetCode.OPERATING_ERROR)
|
| 114 |
ConversationService.delete_by_id(cid)
|
| 115 |
return get_json_result(data=True)
|
| 116 |
except Exception as e:
|
|
|
|
| 125 |
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
|
| 126 |
return get_json_result(
|
| 127 |
data=False, message='Only owner of dialog authorized for this operation.',
|
| 128 |
+
code=settings.RetCode.OPERATING_ERROR)
|
| 129 |
convs = ConversationService.query(
|
| 130 |
dialog_id=dialog_id,
|
| 131 |
order_by=ConversationService.model.create_time,
|
|
|
|
| 297 |
def ask_about():
|
| 298 |
req = request.json
|
| 299 |
uid = current_user.id
|
| 300 |
+
|
| 301 |
def stream():
|
| 302 |
nonlocal req, uid
|
| 303 |
try:
|
|
|
|
| 330 |
embd_mdl = TenantLLMService.model_instance(
|
| 331 |
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
| 332 |
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
|
| 333 |
+
ranks = settings.retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
|
| 334 |
+
0.3, 0.3, aggs=False)
|
| 335 |
mindmap = MindMapExtractor(chat_mdl)
|
| 336 |
mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
|
| 337 |
if "error" in mind_map:
|
api/apps/dialog_app.py
CHANGED
|
@@ -20,7 +20,7 @@ from api.db.services.dialog_service import DialogService
|
|
| 20 |
from api.db import StatusEnum
|
| 21 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 22 |
from api.db.services.user_service import TenantService, UserTenantService
|
| 23 |
-
from api
|
| 24 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
| 25 |
from api.utils import get_uuid
|
| 26 |
from api.utils.api_utils import get_json_result
|
|
@@ -175,7 +175,7 @@ def rm():
|
|
| 175 |
else:
|
| 176 |
return get_json_result(
|
| 177 |
data=False, message='Only owner of dialog authorized for this operation.',
|
| 178 |
-
code=RetCode.OPERATING_ERROR)
|
| 179 |
dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
|
| 180 |
DialogService.update_many_by_id(dialog_list)
|
| 181 |
return get_json_result(data=True)
|
|
|
|
| 20 |
from api.db import StatusEnum
|
| 21 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 22 |
from api.db.services.user_service import TenantService, UserTenantService
|
| 23 |
+
from api import settings
|
| 24 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
| 25 |
from api.utils import get_uuid
|
| 26 |
from api.utils.api_utils import get_json_result
|
|
|
|
| 175 |
else:
|
| 176 |
return get_json_result(
|
| 177 |
data=False, message='Only owner of dialog authorized for this operation.',
|
| 178 |
+
code=settings.RetCode.OPERATING_ERROR)
|
| 179 |
dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
|
| 180 |
DialogService.update_many_by_id(dialog_list)
|
| 181 |
return get_json_result(data=True)
|
api/apps/document_app.py
CHANGED
|
@@ -34,7 +34,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
|
|
| 34 |
from api.utils import get_uuid
|
| 35 |
from api.db import FileType, TaskStatus, ParserType, FileSource
|
| 36 |
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
| 37 |
-
from api
|
| 38 |
from api.utils.api_utils import get_json_result
|
| 39 |
from rag.utils.storage_factory import STORAGE_IMPL
|
| 40 |
from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory
|
|
@@ -49,16 +49,16 @@ def upload():
|
|
| 49 |
kb_id = request.form.get("kb_id")
|
| 50 |
if not kb_id:
|
| 51 |
return get_json_result(
|
| 52 |
-
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
| 53 |
if 'file' not in request.files:
|
| 54 |
return get_json_result(
|
| 55 |
-
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
| 56 |
|
| 57 |
file_objs = request.files.getlist('file')
|
| 58 |
for file_obj in file_objs:
|
| 59 |
if file_obj.filename == '':
|
| 60 |
return get_json_result(
|
| 61 |
-
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
| 62 |
|
| 63 |
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
| 64 |
if not e:
|
|
@@ -67,7 +67,7 @@ def upload():
|
|
| 67 |
err, _ = FileService.upload_document(kb, file_objs, current_user.id)
|
| 68 |
if err:
|
| 69 |
return get_json_result(
|
| 70 |
-
data=False, message="\n".join(err), code=RetCode.SERVER_ERROR)
|
| 71 |
return get_json_result(data=True)
|
| 72 |
|
| 73 |
|
|
@@ -78,12 +78,12 @@ def web_crawl():
|
|
| 78 |
kb_id = request.form.get("kb_id")
|
| 79 |
if not kb_id:
|
| 80 |
return get_json_result(
|
| 81 |
-
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
| 82 |
name = request.form.get("name")
|
| 83 |
url = request.form.get("url")
|
| 84 |
if not is_valid_url(url):
|
| 85 |
return get_json_result(
|
| 86 |
-
data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR)
|
| 87 |
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
| 88 |
if not e:
|
| 89 |
raise LookupError("Can't find this knowledgebase!")
|
|
@@ -145,7 +145,7 @@ def create():
|
|
| 145 |
kb_id = req["kb_id"]
|
| 146 |
if not kb_id:
|
| 147 |
return get_json_result(
|
| 148 |
-
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
| 149 |
|
| 150 |
try:
|
| 151 |
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
|
@@ -179,7 +179,7 @@ def list_docs():
|
|
| 179 |
kb_id = request.args.get("kb_id")
|
| 180 |
if not kb_id:
|
| 181 |
return get_json_result(
|
| 182 |
-
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
| 183 |
tenants = UserTenantService.query(user_id=current_user.id)
|
| 184 |
for tenant in tenants:
|
| 185 |
if KnowledgebaseService.query(
|
|
@@ -188,7 +188,7 @@ def list_docs():
|
|
| 188 |
else:
|
| 189 |
return get_json_result(
|
| 190 |
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
| 191 |
-
code=RetCode.OPERATING_ERROR)
|
| 192 |
keywords = request.args.get("keywords", "")
|
| 193 |
|
| 194 |
page_number = int(request.args.get("page", 1))
|
|
@@ -218,19 +218,19 @@ def docinfos():
|
|
| 218 |
return get_json_result(
|
| 219 |
data=False,
|
| 220 |
message='No authorization.',
|
| 221 |
-
code=RetCode.AUTHENTICATION_ERROR
|
| 222 |
)
|
| 223 |
docs = DocumentService.get_by_ids(doc_ids)
|
| 224 |
return get_json_result(data=list(docs.dicts()))
|
| 225 |
|
| 226 |
|
| 227 |
@manager.route('/thumbnails', methods=['GET'])
|
| 228 |
-
|
| 229 |
def thumbnails():
|
| 230 |
doc_ids = request.args.get("doc_ids").split(",")
|
| 231 |
if not doc_ids:
|
| 232 |
return get_json_result(
|
| 233 |
-
data=False, message='Lack of "Document ID"', code=RetCode.ARGUMENT_ERROR)
|
| 234 |
|
| 235 |
try:
|
| 236 |
docs = DocumentService.get_thumbnails(doc_ids)
|
|
@@ -253,13 +253,13 @@ def change_status():
|
|
| 253 |
return get_json_result(
|
| 254 |
data=False,
|
| 255 |
message='"Status" must be either 0 or 1!',
|
| 256 |
-
code=RetCode.ARGUMENT_ERROR)
|
| 257 |
|
| 258 |
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
| 259 |
return get_json_result(
|
| 260 |
data=False,
|
| 261 |
message='No authorization.',
|
| 262 |
-
code=RetCode.AUTHENTICATION_ERROR)
|
| 263 |
|
| 264 |
try:
|
| 265 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
|
@@ -276,7 +276,8 @@ def change_status():
|
|
| 276 |
message="Database error (Document update)!")
|
| 277 |
|
| 278 |
status = int(req["status"])
|
| 279 |
-
docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status},
|
|
|
|
| 280 |
return get_json_result(data=True)
|
| 281 |
except Exception as e:
|
| 282 |
return server_error_response(e)
|
|
@@ -295,7 +296,7 @@ def rm():
|
|
| 295 |
return get_json_result(
|
| 296 |
data=False,
|
| 297 |
message='No authorization.',
|
| 298 |
-
code=RetCode.AUTHENTICATION_ERROR
|
| 299 |
)
|
| 300 |
|
| 301 |
root_folder = FileService.get_root_folder(current_user.id)
|
|
@@ -326,7 +327,7 @@ def rm():
|
|
| 326 |
errors += str(e)
|
| 327 |
|
| 328 |
if errors:
|
| 329 |
-
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
|
| 330 |
|
| 331 |
return get_json_result(data=True)
|
| 332 |
|
|
@@ -341,7 +342,7 @@ def run():
|
|
| 341 |
return get_json_result(
|
| 342 |
data=False,
|
| 343 |
message='No authorization.',
|
| 344 |
-
code=RetCode.AUTHENTICATION_ERROR
|
| 345 |
)
|
| 346 |
try:
|
| 347 |
for id in req["doc_ids"]:
|
|
@@ -358,8 +359,8 @@ def run():
|
|
| 358 |
e, doc = DocumentService.get_by_id(id)
|
| 359 |
if not e:
|
| 360 |
return get_data_error_result(message="Document not found!")
|
| 361 |
-
if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
| 362 |
-
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
| 363 |
|
| 364 |
if str(req["run"]) == TaskStatus.RUNNING.value:
|
| 365 |
TaskService.filter_delete([Task.doc_id == id])
|
|
@@ -383,7 +384,7 @@ def rename():
|
|
| 383 |
return get_json_result(
|
| 384 |
data=False,
|
| 385 |
message='No authorization.',
|
| 386 |
-
code=RetCode.AUTHENTICATION_ERROR
|
| 387 |
)
|
| 388 |
try:
|
| 389 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
|
@@ -394,7 +395,7 @@ def rename():
|
|
| 394 |
return get_json_result(
|
| 395 |
data=False,
|
| 396 |
message="The extension of file can't be changed",
|
| 397 |
-
code=RetCode.ARGUMENT_ERROR)
|
| 398 |
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
| 399 |
if d.name == req["name"]:
|
| 400 |
return get_data_error_result(
|
|
@@ -450,7 +451,7 @@ def change_parser():
|
|
| 450 |
return get_json_result(
|
| 451 |
data=False,
|
| 452 |
message='No authorization.',
|
| 453 |
-
code=RetCode.AUTHENTICATION_ERROR
|
| 454 |
)
|
| 455 |
try:
|
| 456 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
|
@@ -483,8 +484,8 @@ def change_parser():
|
|
| 483 |
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
| 484 |
if not tenant_id:
|
| 485 |
return get_data_error_result(message="Tenant not found!")
|
| 486 |
-
if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
| 487 |
-
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
| 488 |
|
| 489 |
return get_json_result(data=True)
|
| 490 |
except Exception as e:
|
|
@@ -509,13 +510,13 @@ def get_image(image_id):
|
|
| 509 |
def upload_and_parse():
|
| 510 |
if 'file' not in request.files:
|
| 511 |
return get_json_result(
|
| 512 |
-
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
| 513 |
|
| 514 |
file_objs = request.files.getlist('file')
|
| 515 |
for file_obj in file_objs:
|
| 516 |
if file_obj.filename == '':
|
| 517 |
return get_json_result(
|
| 518 |
-
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
| 519 |
|
| 520 |
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
|
| 521 |
|
|
@@ -529,7 +530,7 @@ def parse():
|
|
| 529 |
if url:
|
| 530 |
if not is_valid_url(url):
|
| 531 |
return get_json_result(
|
| 532 |
-
data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR)
|
| 533 |
download_path = os.path.join(get_project_base_directory(), "logs/downloads")
|
| 534 |
os.makedirs(download_path, exist_ok=True)
|
| 535 |
from selenium.webdriver import Chrome, ChromeOptions
|
|
@@ -553,7 +554,7 @@ def parse():
|
|
| 553 |
|
| 554 |
if 'file' not in request.files:
|
| 555 |
return get_json_result(
|
| 556 |
-
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
| 557 |
|
| 558 |
file_objs = request.files.getlist('file')
|
| 559 |
txt = FileService.parse_docs(file_objs, current_user.id)
|
|
|
|
| 34 |
from api.utils import get_uuid
|
| 35 |
from api.db import FileType, TaskStatus, ParserType, FileSource
|
| 36 |
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
| 37 |
+
from api import settings
|
| 38 |
from api.utils.api_utils import get_json_result
|
| 39 |
from rag.utils.storage_factory import STORAGE_IMPL
|
| 40 |
from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory
|
|
|
|
| 49 |
kb_id = request.form.get("kb_id")
|
| 50 |
if not kb_id:
|
| 51 |
return get_json_result(
|
| 52 |
+
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
| 53 |
if 'file' not in request.files:
|
| 54 |
return get_json_result(
|
| 55 |
+
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
| 56 |
|
| 57 |
file_objs = request.files.getlist('file')
|
| 58 |
for file_obj in file_objs:
|
| 59 |
if file_obj.filename == '':
|
| 60 |
return get_json_result(
|
| 61 |
+
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
| 62 |
|
| 63 |
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
| 64 |
if not e:
|
|
|
|
| 67 |
err, _ = FileService.upload_document(kb, file_objs, current_user.id)
|
| 68 |
if err:
|
| 69 |
return get_json_result(
|
| 70 |
+
data=False, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
|
| 71 |
return get_json_result(data=True)
|
| 72 |
|
| 73 |
|
|
|
|
| 78 |
kb_id = request.form.get("kb_id")
|
| 79 |
if not kb_id:
|
| 80 |
return get_json_result(
|
| 81 |
+
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
| 82 |
name = request.form.get("name")
|
| 83 |
url = request.form.get("url")
|
| 84 |
if not is_valid_url(url):
|
| 85 |
return get_json_result(
|
| 86 |
+
data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
|
| 87 |
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
| 88 |
if not e:
|
| 89 |
raise LookupError("Can't find this knowledgebase!")
|
|
|
|
| 145 |
kb_id = req["kb_id"]
|
| 146 |
if not kb_id:
|
| 147 |
return get_json_result(
|
| 148 |
+
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
| 149 |
|
| 150 |
try:
|
| 151 |
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
|
|
|
| 179 |
kb_id = request.args.get("kb_id")
|
| 180 |
if not kb_id:
|
| 181 |
return get_json_result(
|
| 182 |
+
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
| 183 |
tenants = UserTenantService.query(user_id=current_user.id)
|
| 184 |
for tenant in tenants:
|
| 185 |
if KnowledgebaseService.query(
|
|
|
|
| 188 |
else:
|
| 189 |
return get_json_result(
|
| 190 |
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
| 191 |
+
code=settings.RetCode.OPERATING_ERROR)
|
| 192 |
keywords = request.args.get("keywords", "")
|
| 193 |
|
| 194 |
page_number = int(request.args.get("page", 1))
|
|
|
|
| 218 |
return get_json_result(
|
| 219 |
data=False,
|
| 220 |
message='No authorization.',
|
| 221 |
+
code=settings.RetCode.AUTHENTICATION_ERROR
|
| 222 |
)
|
| 223 |
docs = DocumentService.get_by_ids(doc_ids)
|
| 224 |
return get_json_result(data=list(docs.dicts()))
|
| 225 |
|
| 226 |
|
| 227 |
@manager.route('/thumbnails', methods=['GET'])
|
| 228 |
+
# @login_required
|
| 229 |
def thumbnails():
|
| 230 |
doc_ids = request.args.get("doc_ids").split(",")
|
| 231 |
if not doc_ids:
|
| 232 |
return get_json_result(
|
| 233 |
+
data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
| 234 |
|
| 235 |
try:
|
| 236 |
docs = DocumentService.get_thumbnails(doc_ids)
|
|
|
|
| 253 |
return get_json_result(
|
| 254 |
data=False,
|
| 255 |
message='"Status" must be either 0 or 1!',
|
| 256 |
+
code=settings.RetCode.ARGUMENT_ERROR)
|
| 257 |
|
| 258 |
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
| 259 |
return get_json_result(
|
| 260 |
data=False,
|
| 261 |
message='No authorization.',
|
| 262 |
+
code=settings.RetCode.AUTHENTICATION_ERROR)
|
| 263 |
|
| 264 |
try:
|
| 265 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
|
|
|
| 276 |
message="Database error (Document update)!")
|
| 277 |
|
| 278 |
status = int(req["status"])
|
| 279 |
+
settings.docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status},
|
| 280 |
+
search.index_name(kb.tenant_id), doc.kb_id)
|
| 281 |
return get_json_result(data=True)
|
| 282 |
except Exception as e:
|
| 283 |
return server_error_response(e)
|
|
|
|
| 296 |
return get_json_result(
|
| 297 |
data=False,
|
| 298 |
message='No authorization.',
|
| 299 |
+
code=settings.RetCode.AUTHENTICATION_ERROR
|
| 300 |
)
|
| 301 |
|
| 302 |
root_folder = FileService.get_root_folder(current_user.id)
|
|
|
|
| 327 |
errors += str(e)
|
| 328 |
|
| 329 |
if errors:
|
| 330 |
+
return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
|
| 331 |
|
| 332 |
return get_json_result(data=True)
|
| 333 |
|
|
|
|
| 342 |
return get_json_result(
|
| 343 |
data=False,
|
| 344 |
message='No authorization.',
|
| 345 |
+
code=settings.RetCode.AUTHENTICATION_ERROR
|
| 346 |
)
|
| 347 |
try:
|
| 348 |
for id in req["doc_ids"]:
|
|
|
|
| 359 |
e, doc = DocumentService.get_by_id(id)
|
| 360 |
if not e:
|
| 361 |
return get_data_error_result(message="Document not found!")
|
| 362 |
+
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
| 363 |
+
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
| 364 |
|
| 365 |
if str(req["run"]) == TaskStatus.RUNNING.value:
|
| 366 |
TaskService.filter_delete([Task.doc_id == id])
|
|
|
|
| 384 |
return get_json_result(
|
| 385 |
data=False,
|
| 386 |
message='No authorization.',
|
| 387 |
+
code=settings.RetCode.AUTHENTICATION_ERROR
|
| 388 |
)
|
| 389 |
try:
|
| 390 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
|
|
|
| 395 |
return get_json_result(
|
| 396 |
data=False,
|
| 397 |
message="The extension of file can't be changed",
|
| 398 |
+
code=settings.RetCode.ARGUMENT_ERROR)
|
| 399 |
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
| 400 |
if d.name == req["name"]:
|
| 401 |
return get_data_error_result(
|
|
|
|
| 451 |
return get_json_result(
|
| 452 |
data=False,
|
| 453 |
message='No authorization.',
|
| 454 |
+
code=settings.RetCode.AUTHENTICATION_ERROR
|
| 455 |
)
|
| 456 |
try:
|
| 457 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
|
|
|
| 484 |
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
| 485 |
if not tenant_id:
|
| 486 |
return get_data_error_result(message="Tenant not found!")
|
| 487 |
+
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
| 488 |
+
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
| 489 |
|
| 490 |
return get_json_result(data=True)
|
| 491 |
except Exception as e:
|
|
|
|
| 510 |
def upload_and_parse():
|
| 511 |
if 'file' not in request.files:
|
| 512 |
return get_json_result(
|
| 513 |
+
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
| 514 |
|
| 515 |
file_objs = request.files.getlist('file')
|
| 516 |
for file_obj in file_objs:
|
| 517 |
if file_obj.filename == '':
|
| 518 |
return get_json_result(
|
| 519 |
+
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
| 520 |
|
| 521 |
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
|
| 522 |
|
|
|
|
| 530 |
if url:
|
| 531 |
if not is_valid_url(url):
|
| 532 |
return get_json_result(
|
| 533 |
+
data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
|
| 534 |
download_path = os.path.join(get_project_base_directory(), "logs/downloads")
|
| 535 |
os.makedirs(download_path, exist_ok=True)
|
| 536 |
from selenium.webdriver import Chrome, ChromeOptions
|
|
|
|
| 554 |
|
| 555 |
if 'file' not in request.files:
|
| 556 |
return get_json_result(
|
| 557 |
+
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
| 558 |
|
| 559 |
file_objs = request.files.getlist('file')
|
| 560 |
txt = FileService.parse_docs(file_objs, current_user.id)
|
api/apps/file2document_app.py
CHANGED
|
@@ -24,7 +24,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
|
|
| 24 |
from api.utils import get_uuid
|
| 25 |
from api.db import FileType
|
| 26 |
from api.db.services.document_service import DocumentService
|
| 27 |
-
from api
|
| 28 |
from api.utils.api_utils import get_json_result
|
| 29 |
|
| 30 |
|
|
@@ -100,7 +100,7 @@ def rm():
|
|
| 100 |
file_ids = req["file_ids"]
|
| 101 |
if not file_ids:
|
| 102 |
return get_json_result(
|
| 103 |
-
data=False, message='Lack of "Files ID"', code=RetCode.ARGUMENT_ERROR)
|
| 104 |
try:
|
| 105 |
for file_id in file_ids:
|
| 106 |
informs = File2DocumentService.get_by_file_id(file_id)
|
|
|
|
| 24 |
from api.utils import get_uuid
|
| 25 |
from api.db import FileType
|
| 26 |
from api.db.services.document_service import DocumentService
|
| 27 |
+
from api import settings
|
| 28 |
from api.utils.api_utils import get_json_result
|
| 29 |
|
| 30 |
|
|
|
|
| 100 |
file_ids = req["file_ids"]
|
| 101 |
if not file_ids:
|
| 102 |
return get_json_result(
|
| 103 |
+
data=False, message='Lack of "Files ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
| 104 |
try:
|
| 105 |
for file_id in file_ids:
|
| 106 |
informs = File2DocumentService.get_by_file_id(file_id)
|
api/apps/file_app.py
CHANGED
|
@@ -28,7 +28,7 @@ from api.utils import get_uuid
|
|
| 28 |
from api.db import FileType, FileSource
|
| 29 |
from api.db.services import duplicate_name
|
| 30 |
from api.db.services.file_service import FileService
|
| 31 |
-
from api
|
| 32 |
from api.utils.api_utils import get_json_result
|
| 33 |
from api.utils.file_utils import filename_type
|
| 34 |
from rag.utils.storage_factory import STORAGE_IMPL
|
|
@@ -46,13 +46,13 @@ def upload():
|
|
| 46 |
|
| 47 |
if 'file' not in request.files:
|
| 48 |
return get_json_result(
|
| 49 |
-
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
| 50 |
file_objs = request.files.getlist('file')
|
| 51 |
|
| 52 |
for file_obj in file_objs:
|
| 53 |
if file_obj.filename == '':
|
| 54 |
return get_json_result(
|
| 55 |
-
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
| 56 |
file_res = []
|
| 57 |
try:
|
| 58 |
for file_obj in file_objs:
|
|
@@ -134,7 +134,7 @@ def create():
|
|
| 134 |
try:
|
| 135 |
if not FileService.is_parent_folder_exist(pf_id):
|
| 136 |
return get_json_result(
|
| 137 |
-
data=False, message="Parent Folder Doesn't Exist!", code=RetCode.OPERATING_ERROR)
|
| 138 |
if FileService.query(name=req["name"], parent_id=pf_id):
|
| 139 |
return get_data_error_result(
|
| 140 |
message="Duplicated folder name in the same folder.")
|
|
@@ -299,7 +299,7 @@ def rename():
|
|
| 299 |
return get_json_result(
|
| 300 |
data=False,
|
| 301 |
message="The extension of file can't be changed",
|
| 302 |
-
code=RetCode.ARGUMENT_ERROR)
|
| 303 |
for file in FileService.query(name=req["name"], pf_id=file.parent_id):
|
| 304 |
if file.name == req["name"]:
|
| 305 |
return get_data_error_result(
|
|
|
|
| 28 |
from api.db import FileType, FileSource
|
| 29 |
from api.db.services import duplicate_name
|
| 30 |
from api.db.services.file_service import FileService
|
| 31 |
+
from api import settings
|
| 32 |
from api.utils.api_utils import get_json_result
|
| 33 |
from api.utils.file_utils import filename_type
|
| 34 |
from rag.utils.storage_factory import STORAGE_IMPL
|
|
|
|
| 46 |
|
| 47 |
if 'file' not in request.files:
|
| 48 |
return get_json_result(
|
| 49 |
+
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
| 50 |
file_objs = request.files.getlist('file')
|
| 51 |
|
| 52 |
for file_obj in file_objs:
|
| 53 |
if file_obj.filename == '':
|
| 54 |
return get_json_result(
|
| 55 |
+
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
| 56 |
file_res = []
|
| 57 |
try:
|
| 58 |
for file_obj in file_objs:
|
|
|
|
| 134 |
try:
|
| 135 |
if not FileService.is_parent_folder_exist(pf_id):
|
| 136 |
return get_json_result(
|
| 137 |
+
data=False, message="Parent Folder Doesn't Exist!", code=settings.RetCode.OPERATING_ERROR)
|
| 138 |
if FileService.query(name=req["name"], parent_id=pf_id):
|
| 139 |
return get_data_error_result(
|
| 140 |
message="Duplicated folder name in the same folder.")
|
|
|
|
| 299 |
return get_json_result(
|
| 300 |
data=False,
|
| 301 |
message="The extension of file can't be changed",
|
| 302 |
+
code=settings.RetCode.ARGUMENT_ERROR)
|
| 303 |
for file in FileService.query(name=req["name"], pf_id=file.parent_id):
|
| 304 |
if file.name == req["name"]:
|
| 305 |
return get_data_error_result(
|
api/apps/kb_app.py
CHANGED
|
@@ -26,9 +26,8 @@ from api.utils import get_uuid
|
|
| 26 |
from api.db import StatusEnum, FileSource
|
| 27 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 28 |
from api.db.db_models import File
|
| 29 |
-
from api.settings import RetCode
|
| 30 |
from api.utils.api_utils import get_json_result
|
| 31 |
-
from api
|
| 32 |
from rag.nlp import search
|
| 33 |
|
| 34 |
|
|
@@ -68,13 +67,13 @@ def update():
|
|
| 68 |
return get_json_result(
|
| 69 |
data=False,
|
| 70 |
message='No authorization.',
|
| 71 |
-
code=RetCode.AUTHENTICATION_ERROR
|
| 72 |
)
|
| 73 |
try:
|
| 74 |
if not KnowledgebaseService.query(
|
| 75 |
created_by=current_user.id, id=req["kb_id"]):
|
| 76 |
return get_json_result(
|
| 77 |
-
data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR)
|
| 78 |
|
| 79 |
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
|
| 80 |
if not e:
|
|
@@ -113,7 +112,7 @@ def detail():
|
|
| 113 |
else:
|
| 114 |
return get_json_result(
|
| 115 |
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
| 116 |
-
code=RetCode.OPERATING_ERROR)
|
| 117 |
kb = KnowledgebaseService.get_detail(kb_id)
|
| 118 |
if not kb:
|
| 119 |
return get_data_error_result(
|
|
@@ -148,14 +147,14 @@ def rm():
|
|
| 148 |
return get_json_result(
|
| 149 |
data=False,
|
| 150 |
message='No authorization.',
|
| 151 |
-
code=RetCode.AUTHENTICATION_ERROR
|
| 152 |
)
|
| 153 |
try:
|
| 154 |
kbs = KnowledgebaseService.query(
|
| 155 |
created_by=current_user.id, id=req["kb_id"])
|
| 156 |
if not kbs:
|
| 157 |
return get_json_result(
|
| 158 |
-
data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR)
|
| 159 |
|
| 160 |
for doc in DocumentService.query(kb_id=req["kb_id"]):
|
| 161 |
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
|
|
@@ -170,7 +169,7 @@ def rm():
|
|
| 170 |
message="Database error (Knowledgebase removal)!")
|
| 171 |
tenants = UserTenantService.query(user_id=current_user.id)
|
| 172 |
for tenant in tenants:
|
| 173 |
-
docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
|
| 174 |
return get_json_result(data=True)
|
| 175 |
except Exception as e:
|
| 176 |
return server_error_response(e)
|
|
|
|
| 26 |
from api.db import StatusEnum, FileSource
|
| 27 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 28 |
from api.db.db_models import File
|
|
|
|
| 29 |
from api.utils.api_utils import get_json_result
|
| 30 |
+
from api import settings
|
| 31 |
from rag.nlp import search
|
| 32 |
|
| 33 |
|
|
|
|
| 67 |
return get_json_result(
|
| 68 |
data=False,
|
| 69 |
message='No authorization.',
|
| 70 |
+
code=settings.RetCode.AUTHENTICATION_ERROR
|
| 71 |
)
|
| 72 |
try:
|
| 73 |
if not KnowledgebaseService.query(
|
| 74 |
created_by=current_user.id, id=req["kb_id"]):
|
| 75 |
return get_json_result(
|
| 76 |
+
data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR)
|
| 77 |
|
| 78 |
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
|
| 79 |
if not e:
|
|
|
|
| 112 |
else:
|
| 113 |
return get_json_result(
|
| 114 |
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
| 115 |
+
code=settings.RetCode.OPERATING_ERROR)
|
| 116 |
kb = KnowledgebaseService.get_detail(kb_id)
|
| 117 |
if not kb:
|
| 118 |
return get_data_error_result(
|
|
|
|
| 147 |
return get_json_result(
|
| 148 |
data=False,
|
| 149 |
message='No authorization.',
|
| 150 |
+
code=settings.RetCode.AUTHENTICATION_ERROR
|
| 151 |
)
|
| 152 |
try:
|
| 153 |
kbs = KnowledgebaseService.query(
|
| 154 |
created_by=current_user.id, id=req["kb_id"])
|
| 155 |
if not kbs:
|
| 156 |
return get_json_result(
|
| 157 |
+
data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR)
|
| 158 |
|
| 159 |
for doc in DocumentService.query(kb_id=req["kb_id"]):
|
| 160 |
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
|
|
|
|
| 169 |
message="Database error (Knowledgebase removal)!")
|
| 170 |
tenants = UserTenantService.query(user_id=current_user.id)
|
| 171 |
for tenant in tenants:
|
| 172 |
+
settings.docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
|
| 173 |
return get_json_result(data=True)
|
| 174 |
except Exception as e:
|
| 175 |
return server_error_response(e)
|
api/apps/llm_app.py
CHANGED
|
@@ -19,7 +19,7 @@ import json
|
|
| 19 |
from flask import request
|
| 20 |
from flask_login import login_required, current_user
|
| 21 |
from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
|
| 22 |
-
from api
|
| 23 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
| 24 |
from api.db import StatusEnum, LLMType
|
| 25 |
from api.db.db_models import TenantLLM
|
|
@@ -333,7 +333,7 @@ def my_llms():
|
|
| 333 |
@login_required
|
| 334 |
def list_app():
|
| 335 |
self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
|
| 336 |
-
weighted = ["Youdao","FastEmbed", "BAAI"] if LIGHTEN != 0 else []
|
| 337 |
model_type = request.args.get("model_type")
|
| 338 |
try:
|
| 339 |
objs = TenantLLMService.query(tenant_id=current_user.id)
|
|
|
|
| 19 |
from flask import request
|
| 20 |
from flask_login import login_required, current_user
|
| 21 |
from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
|
| 22 |
+
from api import settings
|
| 23 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
| 24 |
from api.db import StatusEnum, LLMType
|
| 25 |
from api.db.db_models import TenantLLM
|
|
|
|
| 333 |
@login_required
|
| 334 |
def list_app():
|
| 335 |
self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
|
| 336 |
+
weighted = ["Youdao","FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
|
| 337 |
model_type = request.args.get("model_type")
|
| 338 |
try:
|
| 339 |
objs = TenantLLMService.query(tenant_id=current_user.id)
|
api/apps/sdk/chat.py
CHANGED
|
@@ -14,7 +14,7 @@
|
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
| 16 |
from flask import request
|
| 17 |
-
from api
|
| 18 |
from api.db import StatusEnum
|
| 19 |
from api.db.services.dialog_service import DialogService
|
| 20 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
|
@@ -44,7 +44,7 @@ def create(tenant_id):
|
|
| 44 |
kbs = KnowledgebaseService.get_by_ids(ids)
|
| 45 |
embd_count = list(set([kb.embd_id for kb in kbs]))
|
| 46 |
if len(embd_count) != 1:
|
| 47 |
-
return get_result(message='Datasets use different embedding models."',code=RetCode.AUTHENTICATION_ERROR)
|
| 48 |
req["kb_ids"] = ids
|
| 49 |
# llm
|
| 50 |
llm = req.get("llm")
|
|
@@ -173,7 +173,7 @@ def update(tenant_id,chat_id):
|
|
| 173 |
if len(embd_count) != 1 :
|
| 174 |
return get_result(
|
| 175 |
message='Datasets use different embedding models."',
|
| 176 |
-
code=RetCode.AUTHENTICATION_ERROR)
|
| 177 |
req["kb_ids"] = ids
|
| 178 |
llm = req.get("llm")
|
| 179 |
if llm:
|
|
|
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
| 16 |
from flask import request
|
| 17 |
+
from api import settings
|
| 18 |
from api.db import StatusEnum
|
| 19 |
from api.db.services.dialog_service import DialogService
|
| 20 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
|
|
|
| 44 |
kbs = KnowledgebaseService.get_by_ids(ids)
|
| 45 |
embd_count = list(set([kb.embd_id for kb in kbs]))
|
| 46 |
if len(embd_count) != 1:
|
| 47 |
+
return get_result(message='Datasets use different embedding models."',code=settings.RetCode.AUTHENTICATION_ERROR)
|
| 48 |
req["kb_ids"] = ids
|
| 49 |
# llm
|
| 50 |
llm = req.get("llm")
|
|
|
|
| 173 |
if len(embd_count) != 1 :
|
| 174 |
return get_result(
|
| 175 |
message='Datasets use different embedding models."',
|
| 176 |
+
code=settings.RetCode.AUTHENTICATION_ERROR)
|
| 177 |
req["kb_ids"] = ids
|
| 178 |
llm = req.get("llm")
|
| 179 |
if llm:
|
api/apps/sdk/dataset.py
CHANGED
|
@@ -23,7 +23,7 @@ from api.db.services.file_service import FileService
|
|
| 23 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 24 |
from api.db.services.llm_service import TenantLLMService, LLMService
|
| 25 |
from api.db.services.user_service import TenantService
|
| 26 |
-
from api
|
| 27 |
from api.utils import get_uuid
|
| 28 |
from api.utils.api_utils import (
|
| 29 |
get_result,
|
|
@@ -255,7 +255,7 @@ def delete(tenant_id):
|
|
| 255 |
File2DocumentService.delete_by_document_id(doc.id)
|
| 256 |
if not KnowledgebaseService.delete_by_id(id):
|
| 257 |
return get_error_data_result(message="Delete dataset error.(Database error)")
|
| 258 |
-
return get_result(code=RetCode.SUCCESS)
|
| 259 |
|
| 260 |
|
| 261 |
@manager.route("/datasets/<dataset_id>", methods=["PUT"])
|
|
@@ -424,7 +424,7 @@ def update(tenant_id, dataset_id):
|
|
| 424 |
)
|
| 425 |
if not KnowledgebaseService.update_by_id(kb.id, req):
|
| 426 |
return get_error_data_result(message="Update dataset error.(Database error)")
|
| 427 |
-
return get_result(code=RetCode.SUCCESS)
|
| 428 |
|
| 429 |
|
| 430 |
@manager.route("/datasets", methods=["GET"])
|
|
|
|
| 23 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 24 |
from api.db.services.llm_service import TenantLLMService, LLMService
|
| 25 |
from api.db.services.user_service import TenantService
|
| 26 |
+
from api import settings
|
| 27 |
from api.utils import get_uuid
|
| 28 |
from api.utils.api_utils import (
|
| 29 |
get_result,
|
|
|
|
| 255 |
File2DocumentService.delete_by_document_id(doc.id)
|
| 256 |
if not KnowledgebaseService.delete_by_id(id):
|
| 257 |
return get_error_data_result(message="Delete dataset error.(Database error)")
|
| 258 |
+
return get_result(code=settings.RetCode.SUCCESS)
|
| 259 |
|
| 260 |
|
| 261 |
@manager.route("/datasets/<dataset_id>", methods=["PUT"])
|
|
|
|
| 424 |
)
|
| 425 |
if not KnowledgebaseService.update_by_id(kb.id, req):
|
| 426 |
return get_error_data_result(message="Update dataset error.(Database error)")
|
| 427 |
+
return get_result(code=settings.RetCode.SUCCESS)
|
| 428 |
|
| 429 |
|
| 430 |
@manager.route("/datasets", methods=["GET"])
|
api/apps/sdk/dify_retrieval.py
CHANGED
|
@@ -18,7 +18,7 @@ from flask import request, jsonify
|
|
| 18 |
from api.db import LLMType, ParserType
|
| 19 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 20 |
from api.db.services.llm_service import LLMBundle
|
| 21 |
-
from api
|
| 22 |
from api.utils.api_utils import validate_request, build_error_result, apikey_required
|
| 23 |
|
| 24 |
|
|
@@ -37,14 +37,14 @@ def retrieval(tenant_id):
|
|
| 37 |
|
| 38 |
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
| 39 |
if not e:
|
| 40 |
-
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
|
| 41 |
|
| 42 |
if kb.tenant_id != tenant_id:
|
| 43 |
-
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
|
| 44 |
|
| 45 |
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
| 46 |
|
| 47 |
-
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
|
| 48 |
ranks = retr.retrieval(
|
| 49 |
question,
|
| 50 |
embd_mdl,
|
|
@@ -72,6 +72,6 @@ def retrieval(tenant_id):
|
|
| 72 |
if str(e).find("not_found") > 0:
|
| 73 |
return build_error_result(
|
| 74 |
message='No chunk found! Check the chunk status please!',
|
| 75 |
-
code=RetCode.NOT_FOUND
|
| 76 |
)
|
| 77 |
-
return build_error_result(message=str(e), code=RetCode.SERVER_ERROR)
|
|
|
|
| 18 |
from api.db import LLMType, ParserType
|
| 19 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 20 |
from api.db.services.llm_service import LLMBundle
|
| 21 |
+
from api import settings
|
| 22 |
from api.utils.api_utils import validate_request, build_error_result, apikey_required
|
| 23 |
|
| 24 |
|
|
|
|
| 37 |
|
| 38 |
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
| 39 |
if not e:
|
| 40 |
+
return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
|
| 41 |
|
| 42 |
if kb.tenant_id != tenant_id:
|
| 43 |
+
return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
|
| 44 |
|
| 45 |
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
| 46 |
|
| 47 |
+
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
|
| 48 |
ranks = retr.retrieval(
|
| 49 |
question,
|
| 50 |
embd_mdl,
|
|
|
|
| 72 |
if str(e).find("not_found") > 0:
|
| 73 |
return build_error_result(
|
| 74 |
message='No chunk found! Check the chunk status please!',
|
| 75 |
+
code=settings.RetCode.NOT_FOUND
|
| 76 |
)
|
| 77 |
+
return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR)
|
api/apps/sdk/doc.py
CHANGED
|
@@ -21,7 +21,7 @@ from rag.app.qa import rmPrefix, beAdoc
|
|
| 21 |
from rag.nlp import rag_tokenizer
|
| 22 |
from api.db import LLMType, ParserType
|
| 23 |
from api.db.services.llm_service import TenantLLMService
|
| 24 |
-
from api
|
| 25 |
import hashlib
|
| 26 |
import re
|
| 27 |
from api.utils.api_utils import token_required
|
|
@@ -37,11 +37,10 @@ from api.db.services.document_service import DocumentService
|
|
| 37 |
from api.db.services.file2document_service import File2DocumentService
|
| 38 |
from api.db.services.file_service import FileService
|
| 39 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 40 |
-
from api
|
| 41 |
from api.utils.api_utils import construct_json_result, get_parser_config
|
| 42 |
from rag.nlp import search
|
| 43 |
from rag.utils import rmSpace
|
| 44 |
-
from api.settings import docStoreConn
|
| 45 |
from rag.utils.storage_factory import STORAGE_IMPL
|
| 46 |
import os
|
| 47 |
|
|
@@ -109,13 +108,13 @@ def upload(dataset_id, tenant_id):
|
|
| 109 |
"""
|
| 110 |
if "file" not in request.files:
|
| 111 |
return get_error_data_result(
|
| 112 |
-
message="No file part!", code=RetCode.ARGUMENT_ERROR
|
| 113 |
)
|
| 114 |
file_objs = request.files.getlist("file")
|
| 115 |
for file_obj in file_objs:
|
| 116 |
if file_obj.filename == "":
|
| 117 |
return get_result(
|
| 118 |
-
message="No file selected!", code=RetCode.ARGUMENT_ERROR
|
| 119 |
)
|
| 120 |
# total size
|
| 121 |
total_size = 0
|
|
@@ -127,14 +126,14 @@ def upload(dataset_id, tenant_id):
|
|
| 127 |
if total_size > MAX_TOTAL_FILE_SIZE:
|
| 128 |
return get_result(
|
| 129 |
message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)",
|
| 130 |
-
code=RetCode.ARGUMENT_ERROR,
|
| 131 |
)
|
| 132 |
e, kb = KnowledgebaseService.get_by_id(dataset_id)
|
| 133 |
if not e:
|
| 134 |
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
|
| 135 |
err, files = FileService.upload_document(kb, file_objs, tenant_id)
|
| 136 |
if err:
|
| 137 |
-
return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR)
|
| 138 |
# rename key's name
|
| 139 |
renamed_doc_list = []
|
| 140 |
for file in files:
|
|
@@ -221,12 +220,12 @@ def update_doc(tenant_id, dataset_id, document_id):
|
|
| 221 |
|
| 222 |
if "name" in req and req["name"] != doc.name:
|
| 223 |
if (
|
| 224 |
-
|
| 225 |
-
|
| 226 |
):
|
| 227 |
return get_result(
|
| 228 |
message="The extension of file can't be changed",
|
| 229 |
-
code=RetCode.ARGUMENT_ERROR,
|
| 230 |
)
|
| 231 |
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
| 232 |
if d.name == req["name"]:
|
|
@@ -292,7 +291,7 @@ def update_doc(tenant_id, dataset_id, document_id):
|
|
| 292 |
)
|
| 293 |
if not e:
|
| 294 |
return get_error_data_result(message="Document not found!")
|
| 295 |
-
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
|
| 296 |
|
| 297 |
return get_result()
|
| 298 |
|
|
@@ -349,7 +348,7 @@ def download(tenant_id, dataset_id, document_id):
|
|
| 349 |
file_stream = STORAGE_IMPL.get(doc_id, doc_location)
|
| 350 |
if not file_stream:
|
| 351 |
return construct_json_result(
|
| 352 |
-
message="This file is empty.", code=RetCode.DATA_ERROR
|
| 353 |
)
|
| 354 |
file = BytesIO(file_stream)
|
| 355 |
# Use send_file with a proper filename and MIME type
|
|
@@ -582,7 +581,7 @@ def delete(tenant_id, dataset_id):
|
|
| 582 |
errors += str(e)
|
| 583 |
|
| 584 |
if errors:
|
| 585 |
-
return get_result(message=errors, code=RetCode.SERVER_ERROR)
|
| 586 |
|
| 587 |
return get_result()
|
| 588 |
|
|
@@ -644,7 +643,7 @@ def parse(tenant_id, dataset_id):
|
|
| 644 |
info["chunk_num"] = 0
|
| 645 |
info["token_num"] = 0
|
| 646 |
DocumentService.update_by_id(id, info)
|
| 647 |
-
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
|
| 648 |
TaskService.filter_delete([Task.doc_id == id])
|
| 649 |
e, doc = DocumentService.get_by_id(id)
|
| 650 |
doc = doc.to_dict()
|
|
@@ -708,7 +707,7 @@ def stop_parsing(tenant_id, dataset_id):
|
|
| 708 |
)
|
| 709 |
info = {"run": "2", "progress": 0, "chunk_num": 0}
|
| 710 |
DocumentService.update_by_id(id, info)
|
| 711 |
-
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
|
| 712 |
return get_result()
|
| 713 |
|
| 714 |
|
|
@@ -828,8 +827,9 @@ def list_chunks(tenant_id, dataset_id, document_id):
|
|
| 828 |
|
| 829 |
res = {"total": 0, "chunks": [], "doc": renamed_doc}
|
| 830 |
origin_chunks = []
|
| 831 |
-
if docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
|
| 832 |
-
sres = retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None,
|
|
|
|
| 833 |
res["total"] = sres.total
|
| 834 |
sign = 0
|
| 835 |
for id in sres.ids:
|
|
@@ -1003,7 +1003,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
|
|
| 1003 |
v, c = embd_mdl.encode([doc.name, req["content"]])
|
| 1004 |
v = 0.1 * v[0] + 0.9 * v[1]
|
| 1005 |
d["q_%d_vec" % len(v)] = v.tolist()
|
| 1006 |
-
docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
|
| 1007 |
|
| 1008 |
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
|
| 1009 |
# rename keys
|
|
@@ -1078,7 +1078,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
|
|
| 1078 |
condition = {"doc_id": document_id}
|
| 1079 |
if "chunk_ids" in req:
|
| 1080 |
condition["id"] = req["chunk_ids"]
|
| 1081 |
-
chunk_number = docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
|
| 1082 |
if chunk_number != 0:
|
| 1083 |
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
|
| 1084 |
if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
|
|
@@ -1143,7 +1143,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
|
| 1143 |
schema:
|
| 1144 |
type: object
|
| 1145 |
"""
|
| 1146 |
-
chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
|
| 1147 |
if chunk is None:
|
| 1148 |
return get_error_data_result(f"Can't find this chunk {chunk_id}")
|
| 1149 |
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
|
@@ -1187,7 +1187,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
|
| 1187 |
v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
|
| 1188 |
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
| 1189 |
d["q_%d_vec" % len(v)] = v.tolist()
|
| 1190 |
-
docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
|
| 1191 |
return get_result()
|
| 1192 |
|
| 1193 |
|
|
@@ -1285,7 +1285,7 @@ def retrieval_test(tenant_id):
|
|
| 1285 |
if len(embd_nms) != 1:
|
| 1286 |
return get_result(
|
| 1287 |
message='Datasets use different embedding models."',
|
| 1288 |
-
code=RetCode.AUTHENTICATION_ERROR,
|
| 1289 |
)
|
| 1290 |
if "question" not in req:
|
| 1291 |
return get_error_data_result("`question` is required.")
|
|
@@ -1326,7 +1326,7 @@ def retrieval_test(tenant_id):
|
|
| 1326 |
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
|
| 1327 |
question += keyword_extraction(chat_mdl, question)
|
| 1328 |
|
| 1329 |
-
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
|
| 1330 |
ranks = retr.retrieval(
|
| 1331 |
question,
|
| 1332 |
embd_mdl,
|
|
@@ -1366,6 +1366,6 @@ def retrieval_test(tenant_id):
|
|
| 1366 |
if str(e).find("not_found") > 0:
|
| 1367 |
return get_result(
|
| 1368 |
message="No chunk found! Check the chunk status please!",
|
| 1369 |
-
code=RetCode.DATA_ERROR,
|
| 1370 |
)
|
| 1371 |
-
return server_error_response(e)
|
|
|
|
| 21 |
from rag.nlp import rag_tokenizer
|
| 22 |
from api.db import LLMType, ParserType
|
| 23 |
from api.db.services.llm_service import TenantLLMService
|
| 24 |
+
from api import settings
|
| 25 |
import hashlib
|
| 26 |
import re
|
| 27 |
from api.utils.api_utils import token_required
|
|
|
|
| 37 |
from api.db.services.file2document_service import File2DocumentService
|
| 38 |
from api.db.services.file_service import FileService
|
| 39 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 40 |
+
from api import settings
|
| 41 |
from api.utils.api_utils import construct_json_result, get_parser_config
|
| 42 |
from rag.nlp import search
|
| 43 |
from rag.utils import rmSpace
|
|
|
|
| 44 |
from rag.utils.storage_factory import STORAGE_IMPL
|
| 45 |
import os
|
| 46 |
|
|
|
|
| 108 |
"""
|
| 109 |
if "file" not in request.files:
|
| 110 |
return get_error_data_result(
|
| 111 |
+
message="No file part!", code=settings.RetCode.ARGUMENT_ERROR
|
| 112 |
)
|
| 113 |
file_objs = request.files.getlist("file")
|
| 114 |
for file_obj in file_objs:
|
| 115 |
if file_obj.filename == "":
|
| 116 |
return get_result(
|
| 117 |
+
message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR
|
| 118 |
)
|
| 119 |
# total size
|
| 120 |
total_size = 0
|
|
|
|
| 126 |
if total_size > MAX_TOTAL_FILE_SIZE:
|
| 127 |
return get_result(
|
| 128 |
message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)",
|
| 129 |
+
code=settings.RetCode.ARGUMENT_ERROR,
|
| 130 |
)
|
| 131 |
e, kb = KnowledgebaseService.get_by_id(dataset_id)
|
| 132 |
if not e:
|
| 133 |
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
|
| 134 |
err, files = FileService.upload_document(kb, file_objs, tenant_id)
|
| 135 |
if err:
|
| 136 |
+
return get_result(message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
|
| 137 |
# rename key's name
|
| 138 |
renamed_doc_list = []
|
| 139 |
for file in files:
|
|
|
|
| 220 |
|
| 221 |
if "name" in req and req["name"] != doc.name:
|
| 222 |
if (
|
| 223 |
+
pathlib.Path(req["name"].lower()).suffix
|
| 224 |
+
!= pathlib.Path(doc.name.lower()).suffix
|
| 225 |
):
|
| 226 |
return get_result(
|
| 227 |
message="The extension of file can't be changed",
|
| 228 |
+
code=settings.RetCode.ARGUMENT_ERROR,
|
| 229 |
)
|
| 230 |
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
| 231 |
if d.name == req["name"]:
|
|
|
|
| 291 |
)
|
| 292 |
if not e:
|
| 293 |
return get_error_data_result(message="Document not found!")
|
| 294 |
+
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
|
| 295 |
|
| 296 |
return get_result()
|
| 297 |
|
|
|
|
| 348 |
file_stream = STORAGE_IMPL.get(doc_id, doc_location)
|
| 349 |
if not file_stream:
|
| 350 |
return construct_json_result(
|
| 351 |
+
message="This file is empty.", code=settings.RetCode.DATA_ERROR
|
| 352 |
)
|
| 353 |
file = BytesIO(file_stream)
|
| 354 |
# Use send_file with a proper filename and MIME type
|
|
|
|
| 581 |
errors += str(e)
|
| 582 |
|
| 583 |
if errors:
|
| 584 |
+
return get_result(message=errors, code=settings.RetCode.SERVER_ERROR)
|
| 585 |
|
| 586 |
return get_result()
|
| 587 |
|
|
|
|
| 643 |
info["chunk_num"] = 0
|
| 644 |
info["token_num"] = 0
|
| 645 |
DocumentService.update_by_id(id, info)
|
| 646 |
+
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
|
| 647 |
TaskService.filter_delete([Task.doc_id == id])
|
| 648 |
e, doc = DocumentService.get_by_id(id)
|
| 649 |
doc = doc.to_dict()
|
|
|
|
| 707 |
)
|
| 708 |
info = {"run": "2", "progress": 0, "chunk_num": 0}
|
| 709 |
DocumentService.update_by_id(id, info)
|
| 710 |
+
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
|
| 711 |
return get_result()
|
| 712 |
|
| 713 |
|
|
|
|
| 827 |
|
| 828 |
res = {"total": 0, "chunks": [], "doc": renamed_doc}
|
| 829 |
origin_chunks = []
|
| 830 |
+
if settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
|
| 831 |
+
sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None,
|
| 832 |
+
highlight=True)
|
| 833 |
res["total"] = sres.total
|
| 834 |
sign = 0
|
| 835 |
for id in sres.ids:
|
|
|
|
| 1003 |
v, c = embd_mdl.encode([doc.name, req["content"]])
|
| 1004 |
v = 0.1 * v[0] + 0.9 * v[1]
|
| 1005 |
d["q_%d_vec" % len(v)] = v.tolist()
|
| 1006 |
+
settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
|
| 1007 |
|
| 1008 |
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
|
| 1009 |
# rename keys
|
|
|
|
| 1078 |
condition = {"doc_id": document_id}
|
| 1079 |
if "chunk_ids" in req:
|
| 1080 |
condition["id"] = req["chunk_ids"]
|
| 1081 |
+
chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
|
| 1082 |
if chunk_number != 0:
|
| 1083 |
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
|
| 1084 |
if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
|
|
|
|
| 1143 |
schema:
|
| 1144 |
type: object
|
| 1145 |
"""
|
| 1146 |
+
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
|
| 1147 |
if chunk is None:
|
| 1148 |
return get_error_data_result(f"Can't find this chunk {chunk_id}")
|
| 1149 |
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
|
|
|
| 1187 |
v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
|
| 1188 |
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
| 1189 |
d["q_%d_vec" % len(v)] = v.tolist()
|
| 1190 |
+
settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
|
| 1191 |
return get_result()
|
| 1192 |
|
| 1193 |
|
|
|
|
| 1285 |
if len(embd_nms) != 1:
|
| 1286 |
return get_result(
|
| 1287 |
message='Datasets use different embedding models."',
|
| 1288 |
+
code=settings.RetCode.AUTHENTICATION_ERROR,
|
| 1289 |
)
|
| 1290 |
if "question" not in req:
|
| 1291 |
return get_error_data_result("`question` is required.")
|
|
|
|
| 1326 |
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
|
| 1327 |
question += keyword_extraction(chat_mdl, question)
|
| 1328 |
|
| 1329 |
+
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
|
| 1330 |
ranks = retr.retrieval(
|
| 1331 |
question,
|
| 1332 |
embd_mdl,
|
|
|
|
| 1366 |
if str(e).find("not_found") > 0:
|
| 1367 |
return get_result(
|
| 1368 |
message="No chunk found! Check the chunk status please!",
|
| 1369 |
+
code=settings.RetCode.DATA_ERROR,
|
| 1370 |
)
|
| 1371 |
+
return server_error_response(e)
|
api/apps/system_app.py
CHANGED
|
@@ -22,7 +22,7 @@ from api.db.db_models import APIToken
|
|
| 22 |
from api.db.services.api_service import APITokenService
|
| 23 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 24 |
from api.db.services.user_service import UserTenantService
|
| 25 |
-
from api
|
| 26 |
from api.utils import current_timestamp, datetime_format
|
| 27 |
from api.utils.api_utils import (
|
| 28 |
get_json_result,
|
|
@@ -31,7 +31,6 @@ from api.utils.api_utils import (
|
|
| 31 |
generate_confirmation_token,
|
| 32 |
)
|
| 33 |
from api.versions import get_ragflow_version
|
| 34 |
-
from api.settings import docStoreConn
|
| 35 |
from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
|
| 36 |
from timeit import default_timer as timer
|
| 37 |
|
|
@@ -98,7 +97,7 @@ def status():
|
|
| 98 |
res = {}
|
| 99 |
st = timer()
|
| 100 |
try:
|
| 101 |
-
res["doc_store"] = docStoreConn.health()
|
| 102 |
res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
|
| 103 |
except Exception as e:
|
| 104 |
res["doc_store"] = {
|
|
@@ -128,13 +127,13 @@ def status():
|
|
| 128 |
try:
|
| 129 |
KnowledgebaseService.get_by_id("x")
|
| 130 |
res["database"] = {
|
| 131 |
-
"database": DATABASE_TYPE.lower(),
|
| 132 |
"status": "green",
|
| 133 |
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
| 134 |
}
|
| 135 |
except Exception as e:
|
| 136 |
res["database"] = {
|
| 137 |
-
"database": DATABASE_TYPE.lower(),
|
| 138 |
"status": "red",
|
| 139 |
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
| 140 |
"error": str(e),
|
|
|
|
| 22 |
from api.db.services.api_service import APITokenService
|
| 23 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 24 |
from api.db.services.user_service import UserTenantService
|
| 25 |
+
from api import settings
|
| 26 |
from api.utils import current_timestamp, datetime_format
|
| 27 |
from api.utils.api_utils import (
|
| 28 |
get_json_result,
|
|
|
|
| 31 |
generate_confirmation_token,
|
| 32 |
)
|
| 33 |
from api.versions import get_ragflow_version
|
|
|
|
| 34 |
from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
|
| 35 |
from timeit import default_timer as timer
|
| 36 |
|
|
|
|
| 97 |
res = {}
|
| 98 |
st = timer()
|
| 99 |
try:
|
| 100 |
+
res["doc_store"] = settings.docStoreConn.health()
|
| 101 |
res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
|
| 102 |
except Exception as e:
|
| 103 |
res["doc_store"] = {
|
|
|
|
| 127 |
try:
|
| 128 |
KnowledgebaseService.get_by_id("x")
|
| 129 |
res["database"] = {
|
| 130 |
+
"database": settings.DATABASE_TYPE.lower(),
|
| 131 |
"status": "green",
|
| 132 |
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
| 133 |
}
|
| 134 |
except Exception as e:
|
| 135 |
res["database"] = {
|
| 136 |
+
"database": settings.DATABASE_TYPE.lower(),
|
| 137 |
"status": "red",
|
| 138 |
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
| 139 |
"error": str(e),
|
api/apps/user_app.py
CHANGED
|
@@ -38,20 +38,7 @@ from api.utils import (
|
|
| 38 |
datetime_format,
|
| 39 |
)
|
| 40 |
from api.db import UserTenantRole, FileType
|
| 41 |
-
from api
|
| 42 |
-
RetCode,
|
| 43 |
-
GITHUB_OAUTH,
|
| 44 |
-
FEISHU_OAUTH,
|
| 45 |
-
CHAT_MDL,
|
| 46 |
-
EMBEDDING_MDL,
|
| 47 |
-
ASR_MDL,
|
| 48 |
-
IMAGE2TEXT_MDL,
|
| 49 |
-
PARSERS,
|
| 50 |
-
API_KEY,
|
| 51 |
-
LLM_FACTORY,
|
| 52 |
-
LLM_BASE_URL,
|
| 53 |
-
RERANK_MDL,
|
| 54 |
-
)
|
| 55 |
from api.db.services.user_service import UserService, TenantService, UserTenantService
|
| 56 |
from api.db.services.file_service import FileService
|
| 57 |
from api.utils.api_utils import get_json_result, construct_response
|
|
@@ -90,7 +77,7 @@ def login():
|
|
| 90 |
"""
|
| 91 |
if not request.json:
|
| 92 |
return get_json_result(
|
| 93 |
-
data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
|
| 94 |
)
|
| 95 |
|
| 96 |
email = request.json.get("email", "")
|
|
@@ -98,7 +85,7 @@ def login():
|
|
| 98 |
if not users:
|
| 99 |
return get_json_result(
|
| 100 |
data=False,
|
| 101 |
-
code=RetCode.AUTHENTICATION_ERROR,
|
| 102 |
message=f"Email: {email} is not registered!",
|
| 103 |
)
|
| 104 |
|
|
@@ -107,7 +94,7 @@ def login():
|
|
| 107 |
password = decrypt(password)
|
| 108 |
except BaseException:
|
| 109 |
return get_json_result(
|
| 110 |
-
data=False, code=RetCode.SERVER_ERROR, message="Fail to crypt password"
|
| 111 |
)
|
| 112 |
|
| 113 |
user = UserService.query_user(email, password)
|
|
@@ -123,7 +110,7 @@ def login():
|
|
| 123 |
else:
|
| 124 |
return get_json_result(
|
| 125 |
data=False,
|
| 126 |
-
code=RetCode.AUTHENTICATION_ERROR,
|
| 127 |
message="Email and password do not match!",
|
| 128 |
)
|
| 129 |
|
|
@@ -150,10 +137,10 @@ def github_callback():
|
|
| 150 |
import requests
|
| 151 |
|
| 152 |
res = requests.post(
|
| 153 |
-
GITHUB_OAUTH.get("url"),
|
| 154 |
data={
|
| 155 |
-
"client_id": GITHUB_OAUTH.get("client_id"),
|
| 156 |
-
"client_secret": GITHUB_OAUTH.get("secret_key"),
|
| 157 |
"code": request.args.get("code"),
|
| 158 |
},
|
| 159 |
headers={"Accept": "application/json"},
|
|
@@ -235,11 +222,11 @@ def feishu_callback():
|
|
| 235 |
import requests
|
| 236 |
|
| 237 |
app_access_token_res = requests.post(
|
| 238 |
-
FEISHU_OAUTH.get("app_access_token_url"),
|
| 239 |
data=json.dumps(
|
| 240 |
{
|
| 241 |
-
"app_id": FEISHU_OAUTH.get("app_id"),
|
| 242 |
-
"app_secret": FEISHU_OAUTH.get("app_secret"),
|
| 243 |
}
|
| 244 |
),
|
| 245 |
headers={"Content-Type": "application/json; charset=utf-8"},
|
|
@@ -249,10 +236,10 @@ def feishu_callback():
|
|
| 249 |
return redirect("/?error=%s" % app_access_token_res)
|
| 250 |
|
| 251 |
res = requests.post(
|
| 252 |
-
FEISHU_OAUTH.get("user_access_token_url"),
|
| 253 |
data=json.dumps(
|
| 254 |
{
|
| 255 |
-
"grant_type": FEISHU_OAUTH.get("grant_type"),
|
| 256 |
"code": request.args.get("code"),
|
| 257 |
}
|
| 258 |
),
|
|
@@ -405,11 +392,11 @@ def setting_user():
|
|
| 405 |
if request_data.get("password"):
|
| 406 |
new_password = request_data.get("new_password")
|
| 407 |
if not check_password_hash(
|
| 408 |
-
|
| 409 |
):
|
| 410 |
return get_json_result(
|
| 411 |
data=False,
|
| 412 |
-
code=RetCode.AUTHENTICATION_ERROR,
|
| 413 |
message="Password error!",
|
| 414 |
)
|
| 415 |
|
|
@@ -438,7 +425,7 @@ def setting_user():
|
|
| 438 |
except Exception as e:
|
| 439 |
logging.exception(e)
|
| 440 |
return get_json_result(
|
| 441 |
-
data=False, message="Update failure!", code=RetCode.EXCEPTION_ERROR
|
| 442 |
)
|
| 443 |
|
| 444 |
|
|
@@ -497,12 +484,12 @@ def user_register(user_id, user):
|
|
| 497 |
tenant = {
|
| 498 |
"id": user_id,
|
| 499 |
"name": user["nickname"] + "‘s Kingdom",
|
| 500 |
-
"llm_id": CHAT_MDL,
|
| 501 |
-
"embd_id": EMBEDDING_MDL,
|
| 502 |
-
"asr_id": ASR_MDL,
|
| 503 |
-
"parser_ids": PARSERS,
|
| 504 |
-
"img2txt_id": IMAGE2TEXT_MDL,
|
| 505 |
-
"rerank_id": RERANK_MDL,
|
| 506 |
}
|
| 507 |
usr_tenant = {
|
| 508 |
"tenant_id": user_id,
|
|
@@ -522,15 +509,15 @@ def user_register(user_id, user):
|
|
| 522 |
"location": "",
|
| 523 |
}
|
| 524 |
tenant_llm = []
|
| 525 |
-
for llm in LLMService.query(fid=LLM_FACTORY):
|
| 526 |
tenant_llm.append(
|
| 527 |
{
|
| 528 |
"tenant_id": user_id,
|
| 529 |
-
"llm_factory": LLM_FACTORY,
|
| 530 |
"llm_name": llm.llm_name,
|
| 531 |
"model_type": llm.model_type,
|
| 532 |
-
"api_key": API_KEY,
|
| 533 |
-
"api_base": LLM_BASE_URL,
|
| 534 |
}
|
| 535 |
)
|
| 536 |
|
|
@@ -582,7 +569,7 @@ def user_add():
|
|
| 582 |
return get_json_result(
|
| 583 |
data=False,
|
| 584 |
message=f"Invalid email address: {email_address}!",
|
| 585 |
-
code=RetCode.OPERATING_ERROR,
|
| 586 |
)
|
| 587 |
|
| 588 |
# Check if the email address is already used
|
|
@@ -590,7 +577,7 @@ def user_add():
|
|
| 590 |
return get_json_result(
|
| 591 |
data=False,
|
| 592 |
message=f"Email: {email_address} has already registered!",
|
| 593 |
-
code=RetCode.OPERATING_ERROR,
|
| 594 |
)
|
| 595 |
|
| 596 |
# Construct user info data
|
|
@@ -625,7 +612,7 @@ def user_add():
|
|
| 625 |
return get_json_result(
|
| 626 |
data=False,
|
| 627 |
message=f"User registration failure, error: {str(e)}",
|
| 628 |
-
code=RetCode.EXCEPTION_ERROR,
|
| 629 |
)
|
| 630 |
|
| 631 |
|
|
|
|
| 38 |
datetime_format,
|
| 39 |
)
|
| 40 |
from api.db import UserTenantRole, FileType
|
| 41 |
+
from api import settings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
from api.db.services.user_service import UserService, TenantService, UserTenantService
|
| 43 |
from api.db.services.file_service import FileService
|
| 44 |
from api.utils.api_utils import get_json_result, construct_response
|
|
|
|
| 77 |
"""
|
| 78 |
if not request.json:
|
| 79 |
return get_json_result(
|
| 80 |
+
data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
|
| 81 |
)
|
| 82 |
|
| 83 |
email = request.json.get("email", "")
|
|
|
|
| 85 |
if not users:
|
| 86 |
return get_json_result(
|
| 87 |
data=False,
|
| 88 |
+
code=settings.RetCode.AUTHENTICATION_ERROR,
|
| 89 |
message=f"Email: {email} is not registered!",
|
| 90 |
)
|
| 91 |
|
|
|
|
| 94 |
password = decrypt(password)
|
| 95 |
except BaseException:
|
| 96 |
return get_json_result(
|
| 97 |
+
data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password"
|
| 98 |
)
|
| 99 |
|
| 100 |
user = UserService.query_user(email, password)
|
|
|
|
| 110 |
else:
|
| 111 |
return get_json_result(
|
| 112 |
data=False,
|
| 113 |
+
code=settings.RetCode.AUTHENTICATION_ERROR,
|
| 114 |
message="Email and password do not match!",
|
| 115 |
)
|
| 116 |
|
|
|
|
| 137 |
import requests
|
| 138 |
|
| 139 |
res = requests.post(
|
| 140 |
+
settings.GITHUB_OAUTH.get("url"),
|
| 141 |
data={
|
| 142 |
+
"client_id": settings.GITHUB_OAUTH.get("client_id"),
|
| 143 |
+
"client_secret": settings.GITHUB_OAUTH.get("secret_key"),
|
| 144 |
"code": request.args.get("code"),
|
| 145 |
},
|
| 146 |
headers={"Accept": "application/json"},
|
|
|
|
| 222 |
import requests
|
| 223 |
|
| 224 |
app_access_token_res = requests.post(
|
| 225 |
+
settings.FEISHU_OAUTH.get("app_access_token_url"),
|
| 226 |
data=json.dumps(
|
| 227 |
{
|
| 228 |
+
"app_id": settings.FEISHU_OAUTH.get("app_id"),
|
| 229 |
+
"app_secret": settings.FEISHU_OAUTH.get("app_secret"),
|
| 230 |
}
|
| 231 |
),
|
| 232 |
headers={"Content-Type": "application/json; charset=utf-8"},
|
|
|
|
| 236 |
return redirect("/?error=%s" % app_access_token_res)
|
| 237 |
|
| 238 |
res = requests.post(
|
| 239 |
+
settings.FEISHU_OAUTH.get("user_access_token_url"),
|
| 240 |
data=json.dumps(
|
| 241 |
{
|
| 242 |
+
"grant_type": settings.FEISHU_OAUTH.get("grant_type"),
|
| 243 |
"code": request.args.get("code"),
|
| 244 |
}
|
| 245 |
),
|
|
|
|
| 392 |
if request_data.get("password"):
|
| 393 |
new_password = request_data.get("new_password")
|
| 394 |
if not check_password_hash(
|
| 395 |
+
current_user.password, decrypt(request_data["password"])
|
| 396 |
):
|
| 397 |
return get_json_result(
|
| 398 |
data=False,
|
| 399 |
+
code=settings.RetCode.AUTHENTICATION_ERROR,
|
| 400 |
message="Password error!",
|
| 401 |
)
|
| 402 |
|
|
|
|
| 425 |
except Exception as e:
|
| 426 |
logging.exception(e)
|
| 427 |
return get_json_result(
|
| 428 |
+
data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR
|
| 429 |
)
|
| 430 |
|
| 431 |
|
|
|
|
| 484 |
tenant = {
|
| 485 |
"id": user_id,
|
| 486 |
"name": user["nickname"] + "‘s Kingdom",
|
| 487 |
+
"llm_id": settings.CHAT_MDL,
|
| 488 |
+
"embd_id": settings.EMBEDDING_MDL,
|
| 489 |
+
"asr_id": settings.ASR_MDL,
|
| 490 |
+
"parser_ids": settings.PARSERS,
|
| 491 |
+
"img2txt_id": settings.IMAGE2TEXT_MDL,
|
| 492 |
+
"rerank_id": settings.RERANK_MDL,
|
| 493 |
}
|
| 494 |
usr_tenant = {
|
| 495 |
"tenant_id": user_id,
|
|
|
|
| 509 |
"location": "",
|
| 510 |
}
|
| 511 |
tenant_llm = []
|
| 512 |
+
for llm in LLMService.query(fid=settings.LLM_FACTORY):
|
| 513 |
tenant_llm.append(
|
| 514 |
{
|
| 515 |
"tenant_id": user_id,
|
| 516 |
+
"llm_factory": settings.LLM_FACTORY,
|
| 517 |
"llm_name": llm.llm_name,
|
| 518 |
"model_type": llm.model_type,
|
| 519 |
+
"api_key": settings.API_KEY,
|
| 520 |
+
"api_base": settings.LLM_BASE_URL,
|
| 521 |
}
|
| 522 |
)
|
| 523 |
|
|
|
|
| 569 |
return get_json_result(
|
| 570 |
data=False,
|
| 571 |
message=f"Invalid email address: {email_address}!",
|
| 572 |
+
code=settings.RetCode.OPERATING_ERROR,
|
| 573 |
)
|
| 574 |
|
| 575 |
# Check if the email address is already used
|
|
|
|
| 577 |
return get_json_result(
|
| 578 |
data=False,
|
| 579 |
message=f"Email: {email_address} has already registered!",
|
| 580 |
+
code=settings.RetCode.OPERATING_ERROR,
|
| 581 |
)
|
| 582 |
|
| 583 |
# Construct user info data
|
|
|
|
| 612 |
return get_json_result(
|
| 613 |
data=False,
|
| 614 |
message=f"User registration failure, error: {str(e)}",
|
| 615 |
+
code=settings.RetCode.EXCEPTION_ERROR,
|
| 616 |
)
|
| 617 |
|
| 618 |
|
api/db/db_models.py
CHANGED
|
@@ -31,7 +31,7 @@ from peewee import (
|
|
| 31 |
)
|
| 32 |
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
| 33 |
from api.db import SerializedType, ParserType
|
| 34 |
-
from api
|
| 35 |
from api import utils
|
| 36 |
|
| 37 |
def singleton(cls, *args, **kw):
|
|
@@ -62,7 +62,7 @@ class TextFieldType(Enum):
|
|
| 62 |
|
| 63 |
|
| 64 |
class LongTextField(TextField):
|
| 65 |
-
field_type = TextFieldType[DATABASE_TYPE.upper()].value
|
| 66 |
|
| 67 |
|
| 68 |
class JSONField(LongTextField):
|
|
@@ -282,9 +282,9 @@ class DatabaseMigrator(Enum):
|
|
| 282 |
@singleton
|
| 283 |
class BaseDataBase:
|
| 284 |
def __init__(self):
|
| 285 |
-
database_config = DATABASE.copy()
|
| 286 |
db_name = database_config.pop("name")
|
| 287 |
-
self.database_connection = PooledDatabase[DATABASE_TYPE.upper()].value(db_name, **database_config)
|
| 288 |
logging.info('init database on cluster mode successfully')
|
| 289 |
|
| 290 |
class PostgresDatabaseLock:
|
|
@@ -385,7 +385,7 @@ class DatabaseLock(Enum):
|
|
| 385 |
|
| 386 |
|
| 387 |
DB = BaseDataBase().database_connection
|
| 388 |
-
DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value
|
| 389 |
|
| 390 |
|
| 391 |
def close_connection():
|
|
@@ -476,7 +476,7 @@ class User(DataBaseModel, UserMixin):
|
|
| 476 |
return self.email
|
| 477 |
|
| 478 |
def get_id(self):
|
| 479 |
-
jwt = Serializer(secret_key=SECRET_KEY)
|
| 480 |
return jwt.dumps(str(self.access_token))
|
| 481 |
|
| 482 |
class Meta:
|
|
@@ -977,7 +977,7 @@ class CanvasTemplate(DataBaseModel):
|
|
| 977 |
|
| 978 |
def migrate_db():
|
| 979 |
with DB.transaction():
|
| 980 |
-
migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB)
|
| 981 |
try:
|
| 982 |
migrate(
|
| 983 |
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
|
|
|
|
| 31 |
)
|
| 32 |
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
| 33 |
from api.db import SerializedType, ParserType
|
| 34 |
+
from api import settings
|
| 35 |
from api import utils
|
| 36 |
|
| 37 |
def singleton(cls, *args, **kw):
|
|
|
|
| 62 |
|
| 63 |
|
| 64 |
class LongTextField(TextField):
|
| 65 |
+
field_type = TextFieldType[settings.DATABASE_TYPE.upper()].value
|
| 66 |
|
| 67 |
|
| 68 |
class JSONField(LongTextField):
|
|
|
|
| 282 |
@singleton
|
| 283 |
class BaseDataBase:
|
| 284 |
def __init__(self):
|
| 285 |
+
database_config = settings.DATABASE.copy()
|
| 286 |
db_name = database_config.pop("name")
|
| 287 |
+
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
|
| 288 |
logging.info('init database on cluster mode successfully')
|
| 289 |
|
| 290 |
class PostgresDatabaseLock:
|
|
|
|
| 385 |
|
| 386 |
|
| 387 |
DB = BaseDataBase().database_connection
|
| 388 |
+
DB.lock = DatabaseLock[settings.DATABASE_TYPE.upper()].value
|
| 389 |
|
| 390 |
|
| 391 |
def close_connection():
|
|
|
|
| 476 |
return self.email
|
| 477 |
|
| 478 |
def get_id(self):
|
| 479 |
+
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
| 480 |
return jwt.dumps(str(self.access_token))
|
| 481 |
|
| 482 |
class Meta:
|
|
|
|
| 977 |
|
| 978 |
def migrate_db():
|
| 979 |
with DB.transaction():
|
| 980 |
+
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
|
| 981 |
try:
|
| 982 |
migrate(
|
| 983 |
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
|
api/db/init_data.py
CHANGED
|
@@ -29,7 +29,7 @@ from api.db.services.document_service import DocumentService
|
|
| 29 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 30 |
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
|
| 31 |
from api.db.services.user_service import TenantService, UserTenantService
|
| 32 |
-
from api
|
| 33 |
from api.utils.file_utils import get_project_base_directory
|
| 34 |
|
| 35 |
|
|
@@ -51,11 +51,11 @@ def init_superuser():
|
|
| 51 |
tenant = {
|
| 52 |
"id": user_info["id"],
|
| 53 |
"name": user_info["nickname"] + "‘s Kingdom",
|
| 54 |
-
"llm_id": CHAT_MDL,
|
| 55 |
-
"embd_id": EMBEDDING_MDL,
|
| 56 |
-
"asr_id": ASR_MDL,
|
| 57 |
-
"parser_ids": PARSERS,
|
| 58 |
-
"img2txt_id": IMAGE2TEXT_MDL
|
| 59 |
}
|
| 60 |
usr_tenant = {
|
| 61 |
"tenant_id": user_info["id"],
|
|
@@ -64,10 +64,11 @@ def init_superuser():
|
|
| 64 |
"role": UserTenantRole.OWNER
|
| 65 |
}
|
| 66 |
tenant_llm = []
|
| 67 |
-
for llm in LLMService.query(fid=LLM_FACTORY):
|
| 68 |
tenant_llm.append(
|
| 69 |
-
{"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name,
|
| 70 |
-
"
|
|
|
|
| 71 |
|
| 72 |
if not UserService.save(**user_info):
|
| 73 |
logging.error("can't init admin.")
|
|
@@ -80,7 +81,7 @@ def init_superuser():
|
|
| 80 |
|
| 81 |
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
| 82 |
msg = chat_mdl.chat(system="", history=[
|
| 83 |
-
|
| 84 |
if msg.find("ERROR: ") == 0:
|
| 85 |
logging.error(
|
| 86 |
"'{}' dosen't work. {}".format(
|
|
@@ -179,7 +180,7 @@ def init_web_data():
|
|
| 179 |
start_time = time.time()
|
| 180 |
|
| 181 |
init_llm_factory()
|
| 182 |
-
#if not UserService.get_all().count():
|
| 183 |
# init_superuser()
|
| 184 |
|
| 185 |
add_graph_templates()
|
|
|
|
| 29 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 30 |
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
|
| 31 |
from api.db.services.user_service import TenantService, UserTenantService
|
| 32 |
+
from api import settings
|
| 33 |
from api.utils.file_utils import get_project_base_directory
|
| 34 |
|
| 35 |
|
|
|
|
| 51 |
tenant = {
|
| 52 |
"id": user_info["id"],
|
| 53 |
"name": user_info["nickname"] + "‘s Kingdom",
|
| 54 |
+
"llm_id": settings.CHAT_MDL,
|
| 55 |
+
"embd_id": settings.EMBEDDING_MDL,
|
| 56 |
+
"asr_id": settings.ASR_MDL,
|
| 57 |
+
"parser_ids": settings.PARSERS,
|
| 58 |
+
"img2txt_id": settings.IMAGE2TEXT_MDL
|
| 59 |
}
|
| 60 |
usr_tenant = {
|
| 61 |
"tenant_id": user_info["id"],
|
|
|
|
| 64 |
"role": UserTenantRole.OWNER
|
| 65 |
}
|
| 66 |
tenant_llm = []
|
| 67 |
+
for llm in LLMService.query(fid=settings.LLM_FACTORY):
|
| 68 |
tenant_llm.append(
|
| 69 |
+
{"tenant_id": user_info["id"], "llm_factory": settings.LLM_FACTORY, "llm_name": llm.llm_name,
|
| 70 |
+
"model_type": llm.model_type,
|
| 71 |
+
"api_key": settings.API_KEY, "api_base": settings.LLM_BASE_URL})
|
| 72 |
|
| 73 |
if not UserService.save(**user_info):
|
| 74 |
logging.error("can't init admin.")
|
|
|
|
| 81 |
|
| 82 |
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
| 83 |
msg = chat_mdl.chat(system="", history=[
|
| 84 |
+
{"role": "user", "content": "Hello!"}], gen_conf={})
|
| 85 |
if msg.find("ERROR: ") == 0:
|
| 86 |
logging.error(
|
| 87 |
"'{}' dosen't work. {}".format(
|
|
|
|
| 180 |
start_time = time.time()
|
| 181 |
|
| 182 |
init_llm_factory()
|
| 183 |
+
# if not UserService.get_all().count():
|
| 184 |
# init_superuser()
|
| 185 |
|
| 186 |
add_graph_templates()
|
api/db/services/dialog_service.py
CHANGED
|
@@ -27,7 +27,7 @@ from api.db.db_models import Dialog, Conversation,DB
|
|
| 27 |
from api.db.services.common_service import CommonService
|
| 28 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 29 |
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
|
| 30 |
-
from api
|
| 31 |
from rag.app.resume import forbidden_select_fields4resume
|
| 32 |
from rag.nlp.search import index_name
|
| 33 |
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
|
@@ -152,7 +152,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
| 152 |
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
|
| 153 |
|
| 154 |
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
| 155 |
-
retr = retrievaler if not is_kg else kg_retrievaler
|
| 156 |
|
| 157 |
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
| 158 |
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
|
|
@@ -342,7 +342,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
|
| 342 |
|
| 343 |
logging.debug(f"{question} get SQL(refined): {sql}")
|
| 344 |
tried_times += 1
|
| 345 |
-
return retrievaler.sql_retrieval(sql, format="json"), sql
|
| 346 |
|
| 347 |
tbl, sql = get_table()
|
| 348 |
if tbl is None:
|
|
@@ -596,7 +596,7 @@ def ask(question, kb_ids, tenant_id):
|
|
| 596 |
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
| 597 |
|
| 598 |
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
| 599 |
-
retr = retrievaler if not is_kg else kg_retrievaler
|
| 600 |
|
| 601 |
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
| 602 |
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
|
|
|
| 27 |
from api.db.services.common_service import CommonService
|
| 28 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 29 |
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
|
| 30 |
+
from api import settings
|
| 31 |
from rag.app.resume import forbidden_select_fields4resume
|
| 32 |
from rag.nlp.search import index_name
|
| 33 |
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
|
|
|
| 152 |
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
|
| 153 |
|
| 154 |
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
| 155 |
+
retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
|
| 156 |
|
| 157 |
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
| 158 |
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
|
|
|
|
| 342 |
|
| 343 |
logging.debug(f"{question} get SQL(refined): {sql}")
|
| 344 |
tried_times += 1
|
| 345 |
+
return settings.retrievaler.sql_retrieval(sql, format="json"), sql
|
| 346 |
|
| 347 |
tbl, sql = get_table()
|
| 348 |
if tbl is None:
|
|
|
|
| 596 |
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
| 597 |
|
| 598 |
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
| 599 |
+
retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
|
| 600 |
|
| 601 |
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
| 602 |
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
api/db/services/document_service.py
CHANGED
|
@@ -26,7 +26,7 @@ from io import BytesIO
|
|
| 26 |
from peewee import fn
|
| 27 |
|
| 28 |
from api.db.db_utils import bulk_insert_into_db
|
| 29 |
-
from api
|
| 30 |
from api.utils import current_timestamp, get_format_time, get_uuid
|
| 31 |
from graphrag.mind_map_extractor import MindMapExtractor
|
| 32 |
from rag.settings import SVR_QUEUE_NAME
|
|
@@ -108,7 +108,7 @@ class DocumentService(CommonService):
|
|
| 108 |
@classmethod
|
| 109 |
@DB.connection_context()
|
| 110 |
def remove_document(cls, doc, tenant_id):
|
| 111 |
-
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
| 112 |
cls.clear_chunk_num(doc.id)
|
| 113 |
return cls.delete_by_id(doc.id)
|
| 114 |
|
|
@@ -553,10 +553,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|
| 553 |
d["q_%d_vec" % len(v)] = v
|
| 554 |
for b in range(0, len(cks), es_bulk_size):
|
| 555 |
if try_create_idx:
|
| 556 |
-
if not docStoreConn.indexExist(idxnm, kb_id):
|
| 557 |
-
docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
| 558 |
try_create_idx = False
|
| 559 |
-
docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
| 560 |
|
| 561 |
DocumentService.increment_chunk_num(
|
| 562 |
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
|
|
|
| 26 |
from peewee import fn
|
| 27 |
|
| 28 |
from api.db.db_utils import bulk_insert_into_db
|
| 29 |
+
from api import settings
|
| 30 |
from api.utils import current_timestamp, get_format_time, get_uuid
|
| 31 |
from graphrag.mind_map_extractor import MindMapExtractor
|
| 32 |
from rag.settings import SVR_QUEUE_NAME
|
|
|
|
| 108 |
@classmethod
|
| 109 |
@DB.connection_context()
|
| 110 |
def remove_document(cls, doc, tenant_id):
|
| 111 |
+
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
| 112 |
cls.clear_chunk_num(doc.id)
|
| 113 |
return cls.delete_by_id(doc.id)
|
| 114 |
|
|
|
|
| 553 |
d["q_%d_vec" % len(v)] = v
|
| 554 |
for b in range(0, len(cks), es_bulk_size):
|
| 555 |
if try_create_idx:
|
| 556 |
+
if not settings.docStoreConn.indexExist(idxnm, kb_id):
|
| 557 |
+
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
| 558 |
try_create_idx = False
|
| 559 |
+
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
| 560 |
|
| 561 |
DocumentService.increment_chunk_num(
|
| 562 |
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
api/ragflow_server.py
CHANGED
|
@@ -33,12 +33,10 @@ import traceback
|
|
| 33 |
from concurrent.futures import ThreadPoolExecutor
|
| 34 |
|
| 35 |
from werkzeug.serving import run_simple
|
|
|
|
| 36 |
from api.apps import app
|
| 37 |
from api.db.runtime_config import RuntimeConfig
|
| 38 |
from api.db.services.document_service import DocumentService
|
| 39 |
-
from api.settings import (
|
| 40 |
-
HOST, HTTP_PORT
|
| 41 |
-
)
|
| 42 |
from api import utils
|
| 43 |
|
| 44 |
from api.db.db_models import init_database_tables as init_web_db
|
|
@@ -72,6 +70,7 @@ if __name__ == '__main__':
|
|
| 72 |
f'project base: {utils.file_utils.get_project_base_directory()}'
|
| 73 |
)
|
| 74 |
show_configs()
|
|
|
|
| 75 |
|
| 76 |
# init db
|
| 77 |
init_web_db()
|
|
@@ -96,7 +95,7 @@ if __name__ == '__main__':
|
|
| 96 |
logging.info("run on debug mode")
|
| 97 |
|
| 98 |
RuntimeConfig.init_env()
|
| 99 |
-
RuntimeConfig.init_config(JOB_SERVER_HOST=
|
| 100 |
|
| 101 |
thread = ThreadPoolExecutor(max_workers=1)
|
| 102 |
thread.submit(update_progress)
|
|
@@ -105,8 +104,8 @@ if __name__ == '__main__':
|
|
| 105 |
try:
|
| 106 |
logging.info("RAGFlow HTTP server start...")
|
| 107 |
run_simple(
|
| 108 |
-
hostname=
|
| 109 |
-
port=
|
| 110 |
application=app,
|
| 111 |
threaded=True,
|
| 112 |
use_reloader=RuntimeConfig.DEBUG,
|
|
|
|
| 33 |
from concurrent.futures import ThreadPoolExecutor
|
| 34 |
|
| 35 |
from werkzeug.serving import run_simple
|
| 36 |
+
from api import settings
|
| 37 |
from api.apps import app
|
| 38 |
from api.db.runtime_config import RuntimeConfig
|
| 39 |
from api.db.services.document_service import DocumentService
|
|
|
|
|
|
|
|
|
|
| 40 |
from api import utils
|
| 41 |
|
| 42 |
from api.db.db_models import init_database_tables as init_web_db
|
|
|
|
| 70 |
f'project base: {utils.file_utils.get_project_base_directory()}'
|
| 71 |
)
|
| 72 |
show_configs()
|
| 73 |
+
settings.init_settings()
|
| 74 |
|
| 75 |
# init db
|
| 76 |
init_web_db()
|
|
|
|
| 95 |
logging.info("run on debug mode")
|
| 96 |
|
| 97 |
RuntimeConfig.init_env()
|
| 98 |
+
RuntimeConfig.init_config(JOB_SERVER_HOST=settings.HOST_IP, HTTP_PORT=settings.HOST_PORT)
|
| 99 |
|
| 100 |
thread = ThreadPoolExecutor(max_workers=1)
|
| 101 |
thread.submit(update_progress)
|
|
|
|
| 104 |
try:
|
| 105 |
logging.info("RAGFlow HTTP server start...")
|
| 106 |
run_simple(
|
| 107 |
+
hostname=settings.HOST_IP,
|
| 108 |
+
port=settings.HOST_PORT,
|
| 109 |
application=app,
|
| 110 |
threaded=True,
|
| 111 |
use_reloader=RuntimeConfig.DEBUG,
|
api/settings.py
CHANGED
|
@@ -30,114 +30,157 @@ LIGHTEN = int(os.environ.get('LIGHTEN', "0"))
|
|
| 30 |
|
| 31 |
REQUEST_WAIT_SEC = 2
|
| 32 |
REQUEST_MAX_WAIT_SEC = 300
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
},
|
| 47 |
-
"OpenAI": {
|
| 48 |
-
"chat_model": "gpt-3.5-turbo",
|
| 49 |
-
"embedding_model": "text-embedding-ada-002",
|
| 50 |
-
"image2text_model": "gpt-4-vision-preview",
|
| 51 |
-
"asr_model": "whisper-1",
|
| 52 |
-
},
|
| 53 |
-
"Azure-OpenAI": {
|
| 54 |
-
"chat_model": "gpt-35-turbo",
|
| 55 |
-
"embedding_model": "text-embedding-ada-002",
|
| 56 |
-
"image2text_model": "gpt-4-vision-preview",
|
| 57 |
-
"asr_model": "whisper-1",
|
| 58 |
-
},
|
| 59 |
-
"ZHIPU-AI": {
|
| 60 |
-
"chat_model": "glm-3-turbo",
|
| 61 |
-
"embedding_model": "embedding-2",
|
| 62 |
-
"image2text_model": "glm-4v",
|
| 63 |
-
"asr_model": "",
|
| 64 |
-
},
|
| 65 |
-
"Ollama": {
|
| 66 |
-
"chat_model": "qwen-14B-chat",
|
| 67 |
-
"embedding_model": "flag-embedding",
|
| 68 |
-
"image2text_model": "",
|
| 69 |
-
"asr_model": "",
|
| 70 |
-
},
|
| 71 |
-
"Moonshot": {
|
| 72 |
-
"chat_model": "moonshot-v1-8k",
|
| 73 |
-
"embedding_model": "",
|
| 74 |
-
"image2text_model": "",
|
| 75 |
-
"asr_model": "",
|
| 76 |
-
},
|
| 77 |
-
"DeepSeek": {
|
| 78 |
-
"chat_model": "deepseek-chat",
|
| 79 |
-
"embedding_model": "",
|
| 80 |
-
"image2text_model": "",
|
| 81 |
-
"asr_model": "",
|
| 82 |
-
},
|
| 83 |
-
"VolcEngine": {
|
| 84 |
-
"chat_model": "",
|
| 85 |
-
"embedding_model": "",
|
| 86 |
-
"image2text_model": "",
|
| 87 |
-
"asr_model": "",
|
| 88 |
-
},
|
| 89 |
-
"BAAI": {
|
| 90 |
-
"chat_model": "",
|
| 91 |
-
"embedding_model": "BAAI/bge-large-zh-v1.5",
|
| 92 |
-
"image2text_model": "",
|
| 93 |
-
"asr_model": "",
|
| 94 |
-
"rerank_model": "BAAI/bge-reranker-v2-m3",
|
| 95 |
-
}
|
| 96 |
-
}
|
| 97 |
-
|
| 98 |
-
if LLM_FACTORY:
|
| 99 |
-
CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] + f"@{LLM_FACTORY}"
|
| 100 |
-
ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] + f"@{LLM_FACTORY}"
|
| 101 |
-
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] + f"@{LLM_FACTORY}"
|
| 102 |
-
EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI"
|
| 103 |
-
RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI"
|
| 104 |
-
|
| 105 |
-
API_KEY = LLM.get("api_key", "")
|
| 106 |
-
PARSERS = LLM.get(
|
| 107 |
-
"parsers",
|
| 108 |
-
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")
|
| 109 |
-
|
| 110 |
-
HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
|
| 111 |
-
HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
|
| 112 |
-
|
| 113 |
-
SECRET_KEY = get_base_config(
|
| 114 |
-
RAG_FLOW_SERVICE_NAME,
|
| 115 |
-
{}).get("secret_key", str(date.today()))
|
| 116 |
|
| 117 |
DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
|
| 118 |
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
|
| 119 |
|
| 120 |
# authentication
|
| 121 |
-
AUTHENTICATION_CONF =
|
| 122 |
|
| 123 |
# client
|
| 124 |
-
CLIENT_AUTHENTICATION =
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
|
| 143 |
class CustomEnum(Enum):
|
|
|
|
| 30 |
|
| 31 |
REQUEST_WAIT_SEC = 2
|
| 32 |
REQUEST_MAX_WAIT_SEC = 300
|
| 33 |
+
LLM = None
|
| 34 |
+
LLM_FACTORY = None
|
| 35 |
+
LLM_BASE_URL = None
|
| 36 |
+
CHAT_MDL = ""
|
| 37 |
+
EMBEDDING_MDL = ""
|
| 38 |
+
RERANK_MDL = ""
|
| 39 |
+
ASR_MDL = ""
|
| 40 |
+
IMAGE2TEXT_MDL = ""
|
| 41 |
+
API_KEY = None
|
| 42 |
+
PARSERS = None
|
| 43 |
+
HOST_IP = None
|
| 44 |
+
HOST_PORT = None
|
| 45 |
+
SECRET_KEY = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
|
| 48 |
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
|
| 49 |
|
| 50 |
# authentication
|
| 51 |
+
AUTHENTICATION_CONF = None
|
| 52 |
|
| 53 |
# client
|
| 54 |
+
CLIENT_AUTHENTICATION = None
|
| 55 |
+
HTTP_APP_KEY = None
|
| 56 |
+
GITHUB_OAUTH = None
|
| 57 |
+
FEISHU_OAUTH = None
|
| 58 |
+
|
| 59 |
+
DOC_ENGINE = None
|
| 60 |
+
docStoreConn = None
|
| 61 |
+
|
| 62 |
+
retrievaler = None
|
| 63 |
+
kg_retrievaler = None
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def init_settings():
|
| 67 |
+
global LLM, LLM_FACTORY, LLM_BASE_URL
|
| 68 |
+
LLM = get_base_config("user_default_llm", {})
|
| 69 |
+
LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
|
| 70 |
+
LLM_BASE_URL = LLM.get("base_url")
|
| 71 |
+
|
| 72 |
+
global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
|
| 73 |
+
if not LIGHTEN:
|
| 74 |
+
default_llm = {
|
| 75 |
+
"Tongyi-Qianwen": {
|
| 76 |
+
"chat_model": "qwen-plus",
|
| 77 |
+
"embedding_model": "text-embedding-v2",
|
| 78 |
+
"image2text_model": "qwen-vl-max",
|
| 79 |
+
"asr_model": "paraformer-realtime-8k-v1",
|
| 80 |
+
},
|
| 81 |
+
"OpenAI": {
|
| 82 |
+
"chat_model": "gpt-3.5-turbo",
|
| 83 |
+
"embedding_model": "text-embedding-ada-002",
|
| 84 |
+
"image2text_model": "gpt-4-vision-preview",
|
| 85 |
+
"asr_model": "whisper-1",
|
| 86 |
+
},
|
| 87 |
+
"Azure-OpenAI": {
|
| 88 |
+
"chat_model": "gpt-35-turbo",
|
| 89 |
+
"embedding_model": "text-embedding-ada-002",
|
| 90 |
+
"image2text_model": "gpt-4-vision-preview",
|
| 91 |
+
"asr_model": "whisper-1",
|
| 92 |
+
},
|
| 93 |
+
"ZHIPU-AI": {
|
| 94 |
+
"chat_model": "glm-3-turbo",
|
| 95 |
+
"embedding_model": "embedding-2",
|
| 96 |
+
"image2text_model": "glm-4v",
|
| 97 |
+
"asr_model": "",
|
| 98 |
+
},
|
| 99 |
+
"Ollama": {
|
| 100 |
+
"chat_model": "qwen-14B-chat",
|
| 101 |
+
"embedding_model": "flag-embedding",
|
| 102 |
+
"image2text_model": "",
|
| 103 |
+
"asr_model": "",
|
| 104 |
+
},
|
| 105 |
+
"Moonshot": {
|
| 106 |
+
"chat_model": "moonshot-v1-8k",
|
| 107 |
+
"embedding_model": "",
|
| 108 |
+
"image2text_model": "",
|
| 109 |
+
"asr_model": "",
|
| 110 |
+
},
|
| 111 |
+
"DeepSeek": {
|
| 112 |
+
"chat_model": "deepseek-chat",
|
| 113 |
+
"embedding_model": "",
|
| 114 |
+
"image2text_model": "",
|
| 115 |
+
"asr_model": "",
|
| 116 |
+
},
|
| 117 |
+
"VolcEngine": {
|
| 118 |
+
"chat_model": "",
|
| 119 |
+
"embedding_model": "",
|
| 120 |
+
"image2text_model": "",
|
| 121 |
+
"asr_model": "",
|
| 122 |
+
},
|
| 123 |
+
"BAAI": {
|
| 124 |
+
"chat_model": "",
|
| 125 |
+
"embedding_model": "BAAI/bge-large-zh-v1.5",
|
| 126 |
+
"image2text_model": "",
|
| 127 |
+
"asr_model": "",
|
| 128 |
+
"rerank_model": "BAAI/bge-reranker-v2-m3",
|
| 129 |
+
}
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
if LLM_FACTORY:
|
| 133 |
+
CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] + f"@{LLM_FACTORY}"
|
| 134 |
+
ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] + f"@{LLM_FACTORY}"
|
| 135 |
+
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] + f"@{LLM_FACTORY}"
|
| 136 |
+
EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI"
|
| 137 |
+
RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI"
|
| 138 |
+
|
| 139 |
+
global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
|
| 140 |
+
API_KEY = LLM.get("api_key", "")
|
| 141 |
+
PARSERS = LLM.get(
|
| 142 |
+
"parsers",
|
| 143 |
+
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")
|
| 144 |
+
|
| 145 |
+
HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
|
| 146 |
+
HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
|
| 147 |
+
|
| 148 |
+
SECRET_KEY = get_base_config(
|
| 149 |
+
RAG_FLOW_SERVICE_NAME,
|
| 150 |
+
{}).get("secret_key", str(date.today()))
|
| 151 |
+
|
| 152 |
+
global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH
|
| 153 |
+
# authentication
|
| 154 |
+
AUTHENTICATION_CONF = get_base_config("authentication", {})
|
| 155 |
+
|
| 156 |
+
# client
|
| 157 |
+
CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get(
|
| 158 |
+
"client", {}).get(
|
| 159 |
+
"switch", False)
|
| 160 |
+
HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
|
| 161 |
+
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
|
| 162 |
+
FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
|
| 163 |
+
|
| 164 |
+
global DOC_ENGINE, docStoreConn, retrievaler, kg_retrievaler
|
| 165 |
+
DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch")
|
| 166 |
+
if DOC_ENGINE == "elasticsearch":
|
| 167 |
+
docStoreConn = rag.utils.es_conn.ESConnection()
|
| 168 |
+
elif DOC_ENGINE == "infinity":
|
| 169 |
+
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
|
| 170 |
+
else:
|
| 171 |
+
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
|
| 172 |
+
|
| 173 |
+
retrievaler = search.Dealer(docStoreConn)
|
| 174 |
+
kg_retrievaler = kg_search.KGSearch(docStoreConn)
|
| 175 |
+
|
| 176 |
+
def get_host_ip():
|
| 177 |
+
global HOST_IP
|
| 178 |
+
return HOST_IP
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def get_host_port():
|
| 182 |
+
global HOST_PORT
|
| 183 |
+
return HOST_PORT
|
| 184 |
|
| 185 |
|
| 186 |
class CustomEnum(Enum):
|
api/utils/api_utils.py
CHANGED
|
@@ -34,11 +34,9 @@ from itsdangerous import URLSafeTimedSerializer
|
|
| 34 |
from werkzeug.http import HTTP_STATUS_CODES
|
| 35 |
|
| 36 |
from api.db.db_models import APIToken
|
| 37 |
-
from api
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
)
|
| 41 |
-
from api.settings import RetCode
|
| 42 |
from api.utils import CustomJSONEncoder, get_uuid
|
| 43 |
from api.utils import json_dumps
|
| 44 |
|
|
@@ -59,13 +57,13 @@ def request(**kwargs):
|
|
| 59 |
{}).items()}
|
| 60 |
prepped = requests.Request(**kwargs).prepare()
|
| 61 |
|
| 62 |
-
if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
|
| 63 |
timestamp = str(round(time() * 1000))
|
| 64 |
nonce = str(uuid1())
|
| 65 |
-
signature = b64encode(HMAC(SECRET_KEY.encode('ascii'), b'\n'.join([
|
| 66 |
timestamp.encode('ascii'),
|
| 67 |
nonce.encode('ascii'),
|
| 68 |
-
HTTP_APP_KEY.encode('ascii'),
|
| 69 |
prepped.path_url.encode('ascii'),
|
| 70 |
prepped.body if kwargs.get('json') else b'',
|
| 71 |
urlencode(
|
|
@@ -79,7 +77,7 @@ def request(**kwargs):
|
|
| 79 |
prepped.headers.update({
|
| 80 |
'TIMESTAMP': timestamp,
|
| 81 |
'NONCE': nonce,
|
| 82 |
-
'APP-KEY': HTTP_APP_KEY,
|
| 83 |
'SIGNATURE': signature,
|
| 84 |
})
|
| 85 |
|
|
@@ -89,7 +87,7 @@ def request(**kwargs):
|
|
| 89 |
def get_exponential_backoff_interval(retries, full_jitter=False):
|
| 90 |
"""Calculate the exponential backoff wait time."""
|
| 91 |
# Will be zero if factor equals 0
|
| 92 |
-
countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
|
| 93 |
# Full jitter according to
|
| 94 |
# https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
|
| 95 |
if full_jitter:
|
|
@@ -98,7 +96,7 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
|
|
| 98 |
return max(0, countdown)
|
| 99 |
|
| 100 |
|
| 101 |
-
def get_data_error_result(code=RetCode.DATA_ERROR,
|
| 102 |
message='Sorry! Data missing!'):
|
| 103 |
import re
|
| 104 |
result_dict = {
|
|
@@ -126,8 +124,8 @@ def server_error_response(e):
|
|
| 126 |
pass
|
| 127 |
if len(e.args) > 1:
|
| 128 |
return get_json_result(
|
| 129 |
-
code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
| 130 |
-
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
|
| 131 |
|
| 132 |
|
| 133 |
def error_response(response_code, message=None):
|
|
@@ -168,7 +166,7 @@ def validate_request(*args, **kwargs):
|
|
| 168 |
error_string += "required argument values: {}".format(
|
| 169 |
",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
| 170 |
return get_json_result(
|
| 171 |
-
code=RetCode.ARGUMENT_ERROR, message=error_string)
|
| 172 |
return func(*_args, **_kwargs)
|
| 173 |
|
| 174 |
return decorated_function
|
|
@@ -193,7 +191,7 @@ def send_file_in_mem(data, filename):
|
|
| 193 |
return send_file(f, as_attachment=True, attachment_filename=filename)
|
| 194 |
|
| 195 |
|
| 196 |
-
def get_json_result(code=RetCode.SUCCESS, message='success', data=None):
|
| 197 |
response = {"code": code, "message": message, "data": data}
|
| 198 |
return jsonify(response)
|
| 199 |
|
|
@@ -204,7 +202,7 @@ def apikey_required(func):
|
|
| 204 |
objs = APIToken.query(token=token)
|
| 205 |
if not objs:
|
| 206 |
return build_error_result(
|
| 207 |
-
message='API-KEY is invalid!', code=RetCode.FORBIDDEN
|
| 208 |
)
|
| 209 |
kwargs['tenant_id'] = objs[0].tenant_id
|
| 210 |
return func(*args, **kwargs)
|
|
@@ -212,14 +210,14 @@ def apikey_required(func):
|
|
| 212 |
return decorated_function
|
| 213 |
|
| 214 |
|
| 215 |
-
def build_error_result(code=RetCode.FORBIDDEN, message='success'):
|
| 216 |
response = {"code": code, "message": message}
|
| 217 |
response = jsonify(response)
|
| 218 |
response.status_code = code
|
| 219 |
return response
|
| 220 |
|
| 221 |
|
| 222 |
-
def construct_response(code=RetCode.SUCCESS,
|
| 223 |
message='success', data=None, auth=None):
|
| 224 |
result_dict = {"code": code, "message": message, "data": data}
|
| 225 |
response_dict = {}
|
|
@@ -239,7 +237,7 @@ def construct_response(code=RetCode.SUCCESS,
|
|
| 239 |
return response
|
| 240 |
|
| 241 |
|
| 242 |
-
def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
|
| 243 |
import re
|
| 244 |
result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
|
| 245 |
response = {}
|
|
@@ -251,7 +249,7 @@ def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
|
|
| 251 |
return jsonify(response)
|
| 252 |
|
| 253 |
|
| 254 |
-
def construct_json_result(code=RetCode.SUCCESS, message='success', data=None):
|
| 255 |
if data is None:
|
| 256 |
return jsonify({"code": code, "message": message})
|
| 257 |
else:
|
|
@@ -262,12 +260,12 @@ def construct_error_response(e):
|
|
| 262 |
logging.exception(e)
|
| 263 |
try:
|
| 264 |
if e.code == 401:
|
| 265 |
-
return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))
|
| 266 |
except BaseException:
|
| 267 |
pass
|
| 268 |
if len(e.args) > 1:
|
| 269 |
-
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
| 270 |
-
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
|
| 271 |
|
| 272 |
|
| 273 |
def token_required(func):
|
|
@@ -280,7 +278,7 @@ def token_required(func):
|
|
| 280 |
objs = APIToken.query(token=token)
|
| 281 |
if not objs:
|
| 282 |
return get_json_result(
|
| 283 |
-
data=False, message='Token is not valid!', code=RetCode.AUTHENTICATION_ERROR
|
| 284 |
)
|
| 285 |
kwargs['tenant_id'] = objs[0].tenant_id
|
| 286 |
return func(*args, **kwargs)
|
|
@@ -288,7 +286,7 @@ def token_required(func):
|
|
| 288 |
return decorated_function
|
| 289 |
|
| 290 |
|
| 291 |
-
def get_result(code=RetCode.SUCCESS, message="", data=None):
|
| 292 |
if code == 0:
|
| 293 |
if data is not None:
|
| 294 |
response = {"code": code, "data": data}
|
|
@@ -299,7 +297,7 @@ def get_result(code=RetCode.SUCCESS, message="", data=None):
|
|
| 299 |
return jsonify(response)
|
| 300 |
|
| 301 |
|
| 302 |
-
def get_error_data_result(message='Sorry! Data missing!', code=RetCode.DATA_ERROR,
|
| 303 |
):
|
| 304 |
import re
|
| 305 |
result_dict = {
|
|
|
|
| 34 |
from werkzeug.http import HTTP_STATUS_CODES
|
| 35 |
|
| 36 |
from api.db.db_models import APIToken
|
| 37 |
+
from api import settings
|
| 38 |
+
|
| 39 |
+
from api import settings
|
|
|
|
|
|
|
| 40 |
from api.utils import CustomJSONEncoder, get_uuid
|
| 41 |
from api.utils import json_dumps
|
| 42 |
|
|
|
|
| 57 |
{}).items()}
|
| 58 |
prepped = requests.Request(**kwargs).prepare()
|
| 59 |
|
| 60 |
+
if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
|
| 61 |
timestamp = str(round(time() * 1000))
|
| 62 |
nonce = str(uuid1())
|
| 63 |
+
signature = b64encode(HMAC(settings.SECRET_KEY.encode('ascii'), b'\n'.join([
|
| 64 |
timestamp.encode('ascii'),
|
| 65 |
nonce.encode('ascii'),
|
| 66 |
+
settings.HTTP_APP_KEY.encode('ascii'),
|
| 67 |
prepped.path_url.encode('ascii'),
|
| 68 |
prepped.body if kwargs.get('json') else b'',
|
| 69 |
urlencode(
|
|
|
|
| 77 |
prepped.headers.update({
|
| 78 |
'TIMESTAMP': timestamp,
|
| 79 |
'NONCE': nonce,
|
| 80 |
+
'APP-KEY': settings.HTTP_APP_KEY,
|
| 81 |
'SIGNATURE': signature,
|
| 82 |
})
|
| 83 |
|
|
|
|
| 87 |
def get_exponential_backoff_interval(retries, full_jitter=False):
|
| 88 |
"""Calculate the exponential backoff wait time."""
|
| 89 |
# Will be zero if factor equals 0
|
| 90 |
+
countdown = min(settings.REQUEST_MAX_WAIT_SEC, settings.REQUEST_WAIT_SEC * (2 ** retries))
|
| 91 |
# Full jitter according to
|
| 92 |
# https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
|
| 93 |
if full_jitter:
|
|
|
|
| 96 |
return max(0, countdown)
|
| 97 |
|
| 98 |
|
| 99 |
+
def get_data_error_result(code=settings.RetCode.DATA_ERROR,
|
| 100 |
message='Sorry! Data missing!'):
|
| 101 |
import re
|
| 102 |
result_dict = {
|
|
|
|
| 124 |
pass
|
| 125 |
if len(e.args) > 1:
|
| 126 |
return get_json_result(
|
| 127 |
+
code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
| 128 |
+
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
|
| 129 |
|
| 130 |
|
| 131 |
def error_response(response_code, message=None):
|
|
|
|
| 166 |
error_string += "required argument values: {}".format(
|
| 167 |
",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
| 168 |
return get_json_result(
|
| 169 |
+
code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
|
| 170 |
return func(*_args, **_kwargs)
|
| 171 |
|
| 172 |
return decorated_function
|
|
|
|
| 191 |
return send_file(f, as_attachment=True, attachment_filename=filename)
|
| 192 |
|
| 193 |
|
| 194 |
+
def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
|
| 195 |
response = {"code": code, "message": message, "data": data}
|
| 196 |
return jsonify(response)
|
| 197 |
|
|
|
|
| 202 |
objs = APIToken.query(token=token)
|
| 203 |
if not objs:
|
| 204 |
return build_error_result(
|
| 205 |
+
message='API-KEY is invalid!', code=settings.RetCode.FORBIDDEN
|
| 206 |
)
|
| 207 |
kwargs['tenant_id'] = objs[0].tenant_id
|
| 208 |
return func(*args, **kwargs)
|
|
|
|
| 210 |
return decorated_function
|
| 211 |
|
| 212 |
|
| 213 |
+
def build_error_result(code=settings.RetCode.FORBIDDEN, message='success'):
|
| 214 |
response = {"code": code, "message": message}
|
| 215 |
response = jsonify(response)
|
| 216 |
response.status_code = code
|
| 217 |
return response
|
| 218 |
|
| 219 |
|
| 220 |
+
def construct_response(code=settings.RetCode.SUCCESS,
|
| 221 |
message='success', data=None, auth=None):
|
| 222 |
result_dict = {"code": code, "message": message, "data": data}
|
| 223 |
response_dict = {}
|
|
|
|
| 237 |
return response
|
| 238 |
|
| 239 |
|
| 240 |
+
def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing'):
|
| 241 |
import re
|
| 242 |
result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
|
| 243 |
response = {}
|
|
|
|
| 249 |
return jsonify(response)
|
| 250 |
|
| 251 |
|
| 252 |
+
def construct_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
|
| 253 |
if data is None:
|
| 254 |
return jsonify({"code": code, "message": message})
|
| 255 |
else:
|
|
|
|
| 260 |
logging.exception(e)
|
| 261 |
try:
|
| 262 |
if e.code == 401:
|
| 263 |
+
return construct_json_result(code=settings.RetCode.UNAUTHORIZED, message=repr(e))
|
| 264 |
except BaseException:
|
| 265 |
pass
|
| 266 |
if len(e.args) > 1:
|
| 267 |
+
return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
| 268 |
+
return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
|
| 269 |
|
| 270 |
|
| 271 |
def token_required(func):
|
|
|
|
| 278 |
objs = APIToken.query(token=token)
|
| 279 |
if not objs:
|
| 280 |
return get_json_result(
|
| 281 |
+
data=False, message='Token is not valid!', code=settings.RetCode.AUTHENTICATION_ERROR
|
| 282 |
)
|
| 283 |
kwargs['tenant_id'] = objs[0].tenant_id
|
| 284 |
return func(*args, **kwargs)
|
|
|
|
| 286 |
return decorated_function
|
| 287 |
|
| 288 |
|
| 289 |
+
def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
|
| 290 |
if code == 0:
|
| 291 |
if data is not None:
|
| 292 |
response = {"code": code, "data": data}
|
|
|
|
| 297 |
return jsonify(response)
|
| 298 |
|
| 299 |
|
| 300 |
+
def get_error_data_result(message='Sorry! Data missing!', code=settings.RetCode.DATA_ERROR,
|
| 301 |
):
|
| 302 |
import re
|
| 303 |
result_dict = {
|
deepdoc/parser/pdf_parser.py
CHANGED
|
@@ -24,7 +24,7 @@ import numpy as np
|
|
| 24 |
from timeit import default_timer as timer
|
| 25 |
from pypdf import PdfReader as pdf2_read
|
| 26 |
|
| 27 |
-
from api
|
| 28 |
from api.utils.file_utils import get_project_base_directory
|
| 29 |
from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
|
| 30 |
from rag.nlp import rag_tokenizer
|
|
@@ -41,7 +41,7 @@ class RAGFlowPdfParser:
|
|
| 41 |
self.tbl_det = TableStructureRecognizer()
|
| 42 |
|
| 43 |
self.updown_cnt_mdl = xgb.Booster()
|
| 44 |
-
if not LIGHTEN:
|
| 45 |
try:
|
| 46 |
import torch
|
| 47 |
if torch.cuda.is_available():
|
|
|
|
| 24 |
from timeit import default_timer as timer
|
| 25 |
from pypdf import PdfReader as pdf2_read
|
| 26 |
|
| 27 |
+
from api import settings
|
| 28 |
from api.utils.file_utils import get_project_base_directory
|
| 29 |
from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
|
| 30 |
from rag.nlp import rag_tokenizer
|
|
|
|
| 41 |
self.tbl_det = TableStructureRecognizer()
|
| 42 |
|
| 43 |
self.updown_cnt_mdl = xgb.Booster()
|
| 44 |
+
if not settings.LIGHTEN:
|
| 45 |
try:
|
| 46 |
import torch
|
| 47 |
if torch.cuda.is_available():
|
graphrag/claim_extractor.py
CHANGED
|
@@ -252,13 +252,13 @@ if __name__ == "__main__":
|
|
| 252 |
|
| 253 |
from api.db import LLMType
|
| 254 |
from api.db.services.llm_service import LLMBundle
|
| 255 |
-
from api
|
| 256 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 257 |
|
| 258 |
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
| 259 |
|
| 260 |
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
| 261 |
-
docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
|
| 262 |
info = {
|
| 263 |
"input_text": docs,
|
| 264 |
"entity_specs": "organization, person",
|
|
|
|
| 252 |
|
| 253 |
from api.db import LLMType
|
| 254 |
from api.db.services.llm_service import LLMBundle
|
| 255 |
+
from api import settings
|
| 256 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 257 |
|
| 258 |
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
| 259 |
|
| 260 |
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
| 261 |
+
docs = [d["content_with_weight"] for d in settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
|
| 262 |
info = {
|
| 263 |
"input_text": docs,
|
| 264 |
"entity_specs": "organization, person",
|
graphrag/smoke.py
CHANGED
|
@@ -30,14 +30,14 @@ if __name__ == "__main__":
|
|
| 30 |
|
| 31 |
from api.db import LLMType
|
| 32 |
from api.db.services.llm_service import LLMBundle
|
| 33 |
-
from api
|
| 34 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 35 |
|
| 36 |
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
| 37 |
|
| 38 |
ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
| 39 |
docs = [d["content_with_weight"] for d in
|
| 40 |
-
retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
|
| 41 |
graph = ex(docs)
|
| 42 |
|
| 43 |
er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))
|
|
|
|
| 30 |
|
| 31 |
from api.db import LLMType
|
| 32 |
from api.db.services.llm_service import LLMBundle
|
| 33 |
+
from api import settings
|
| 34 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 35 |
|
| 36 |
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
| 37 |
|
| 38 |
ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
| 39 |
docs = [d["content_with_weight"] for d in
|
| 40 |
+
settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
|
| 41 |
graph = ex(docs)
|
| 42 |
|
| 43 |
er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))
|
rag/benchmark.py
CHANGED
|
@@ -23,7 +23,7 @@ from collections import defaultdict
|
|
| 23 |
from api.db import LLMType
|
| 24 |
from api.db.services.llm_service import LLMBundle
|
| 25 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 26 |
-
from api
|
| 27 |
from api.utils import get_uuid
|
| 28 |
from rag.nlp import tokenize, search
|
| 29 |
from ranx import evaluate
|
|
@@ -52,7 +52,7 @@ class Benchmark:
|
|
| 52 |
run = defaultdict(dict)
|
| 53 |
query_list = list(qrels.keys())
|
| 54 |
for query in query_list:
|
| 55 |
-
ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
|
| 56 |
0.0, self.vector_similarity_weight)
|
| 57 |
if len(ranks["chunks"]) == 0:
|
| 58 |
print(f"deleted query: {query}")
|
|
@@ -81,9 +81,9 @@ class Benchmark:
|
|
| 81 |
def init_index(self, vector_size: int):
|
| 82 |
if self.initialized_index:
|
| 83 |
return
|
| 84 |
-
if docStoreConn.indexExist(self.index_name, self.kb_id):
|
| 85 |
-
docStoreConn.deleteIdx(self.index_name, self.kb_id)
|
| 86 |
-
docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
|
| 87 |
self.initialized_index = True
|
| 88 |
|
| 89 |
def ms_marco_index(self, file_path, index_name):
|
|
@@ -118,13 +118,13 @@ class Benchmark:
|
|
| 118 |
docs_count += len(docs)
|
| 119 |
docs, vector_size = self.embedding(docs)
|
| 120 |
self.init_index(vector_size)
|
| 121 |
-
docStoreConn.insert(docs, self.index_name, self.kb_id)
|
| 122 |
docs = []
|
| 123 |
|
| 124 |
if docs:
|
| 125 |
docs, vector_size = self.embedding(docs)
|
| 126 |
self.init_index(vector_size)
|
| 127 |
-
docStoreConn.insert(docs, self.index_name, self.kb_id)
|
| 128 |
return qrels, texts
|
| 129 |
|
| 130 |
def trivia_qa_index(self, file_path, index_name):
|
|
@@ -159,12 +159,12 @@ class Benchmark:
|
|
| 159 |
docs_count += len(docs)
|
| 160 |
docs, vector_size = self.embedding(docs)
|
| 161 |
self.init_index(vector_size)
|
| 162 |
-
docStoreConn.insert(docs,self.index_name)
|
| 163 |
docs = []
|
| 164 |
|
| 165 |
docs, vector_size = self.embedding(docs)
|
| 166 |
self.init_index(vector_size)
|
| 167 |
-
docStoreConn.insert(docs, self.index_name)
|
| 168 |
return qrels, texts
|
| 169 |
|
| 170 |
def miracl_index(self, file_path, corpus_path, index_name):
|
|
@@ -214,12 +214,12 @@ class Benchmark:
|
|
| 214 |
docs_count += len(docs)
|
| 215 |
docs, vector_size = self.embedding(docs)
|
| 216 |
self.init_index(vector_size)
|
| 217 |
-
docStoreConn.insert(docs, self.index_name)
|
| 218 |
docs = []
|
| 219 |
|
| 220 |
docs, vector_size = self.embedding(docs)
|
| 221 |
self.init_index(vector_size)
|
| 222 |
-
docStoreConn.insert(docs, self.index_name)
|
| 223 |
return qrels, texts
|
| 224 |
|
| 225 |
def save_results(self, qrels, run, texts, dataset, file_path):
|
|
|
|
| 23 |
from api.db import LLMType
|
| 24 |
from api.db.services.llm_service import LLMBundle
|
| 25 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 26 |
+
from api import settings
|
| 27 |
from api.utils import get_uuid
|
| 28 |
from rag.nlp import tokenize, search
|
| 29 |
from ranx import evaluate
|
|
|
|
| 52 |
run = defaultdict(dict)
|
| 53 |
query_list = list(qrels.keys())
|
| 54 |
for query in query_list:
|
| 55 |
+
ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
|
| 56 |
0.0, self.vector_similarity_weight)
|
| 57 |
if len(ranks["chunks"]) == 0:
|
| 58 |
print(f"deleted query: {query}")
|
|
|
|
| 81 |
def init_index(self, vector_size: int):
|
| 82 |
if self.initialized_index:
|
| 83 |
return
|
| 84 |
+
if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
|
| 85 |
+
settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
|
| 86 |
+
settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
|
| 87 |
self.initialized_index = True
|
| 88 |
|
| 89 |
def ms_marco_index(self, file_path, index_name):
|
|
|
|
| 118 |
docs_count += len(docs)
|
| 119 |
docs, vector_size = self.embedding(docs)
|
| 120 |
self.init_index(vector_size)
|
| 121 |
+
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
|
| 122 |
docs = []
|
| 123 |
|
| 124 |
if docs:
|
| 125 |
docs, vector_size = self.embedding(docs)
|
| 126 |
self.init_index(vector_size)
|
| 127 |
+
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
|
| 128 |
return qrels, texts
|
| 129 |
|
| 130 |
def trivia_qa_index(self, file_path, index_name):
|
|
|
|
| 159 |
docs_count += len(docs)
|
| 160 |
docs, vector_size = self.embedding(docs)
|
| 161 |
self.init_index(vector_size)
|
| 162 |
+
settings.docStoreConn.insert(docs,self.index_name)
|
| 163 |
docs = []
|
| 164 |
|
| 165 |
docs, vector_size = self.embedding(docs)
|
| 166 |
self.init_index(vector_size)
|
| 167 |
+
settings.docStoreConn.insert(docs, self.index_name)
|
| 168 |
return qrels, texts
|
| 169 |
|
| 170 |
def miracl_index(self, file_path, corpus_path, index_name):
|
|
|
|
| 214 |
docs_count += len(docs)
|
| 215 |
docs, vector_size = self.embedding(docs)
|
| 216 |
self.init_index(vector_size)
|
| 217 |
+
settings.docStoreConn.insert(docs, self.index_name)
|
| 218 |
docs = []
|
| 219 |
|
| 220 |
docs, vector_size = self.embedding(docs)
|
| 221 |
self.init_index(vector_size)
|
| 222 |
+
settings.docStoreConn.insert(docs, self.index_name)
|
| 223 |
return qrels, texts
|
| 224 |
|
| 225 |
def save_results(self, qrels, run, texts, dataset, file_path):
|
rag/llm/embedding_model.py
CHANGED
|
@@ -28,7 +28,7 @@ from openai import OpenAI
|
|
| 28 |
import numpy as np
|
| 29 |
import asyncio
|
| 30 |
|
| 31 |
-
from api
|
| 32 |
from api.utils.file_utils import get_home_cache_dir
|
| 33 |
from rag.utils import num_tokens_from_string, truncate
|
| 34 |
import google.generativeai as genai
|
|
@@ -60,7 +60,7 @@ class DefaultEmbedding(Base):
|
|
| 60 |
^_-
|
| 61 |
|
| 62 |
"""
|
| 63 |
-
if not LIGHTEN and not DefaultEmbedding._model:
|
| 64 |
with DefaultEmbedding._model_lock:
|
| 65 |
from FlagEmbedding import FlagModel
|
| 66 |
import torch
|
|
@@ -248,7 +248,7 @@ class FastEmbed(Base):
|
|
| 248 |
threads: Optional[int] = None,
|
| 249 |
**kwargs,
|
| 250 |
):
|
| 251 |
-
if not LIGHTEN and not FastEmbed._model:
|
| 252 |
from fastembed import TextEmbedding
|
| 253 |
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
| 254 |
|
|
@@ -294,7 +294,7 @@ class YoudaoEmbed(Base):
|
|
| 294 |
_client = None
|
| 295 |
|
| 296 |
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
|
| 297 |
-
if not LIGHTEN and not YoudaoEmbed._client:
|
| 298 |
from BCEmbedding import EmbeddingModel as qanthing
|
| 299 |
try:
|
| 300 |
logging.info("LOADING BCE...")
|
|
|
|
| 28 |
import numpy as np
|
| 29 |
import asyncio
|
| 30 |
|
| 31 |
+
from api import settings
|
| 32 |
from api.utils.file_utils import get_home_cache_dir
|
| 33 |
from rag.utils import num_tokens_from_string, truncate
|
| 34 |
import google.generativeai as genai
|
|
|
|
| 60 |
^_-
|
| 61 |
|
| 62 |
"""
|
| 63 |
+
if not settings.LIGHTEN and not DefaultEmbedding._model:
|
| 64 |
with DefaultEmbedding._model_lock:
|
| 65 |
from FlagEmbedding import FlagModel
|
| 66 |
import torch
|
|
|
|
| 248 |
threads: Optional[int] = None,
|
| 249 |
**kwargs,
|
| 250 |
):
|
| 251 |
+
if not settings.LIGHTEN and not FastEmbed._model:
|
| 252 |
from fastembed import TextEmbedding
|
| 253 |
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
| 254 |
|
|
|
|
| 294 |
_client = None
|
| 295 |
|
| 296 |
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
|
| 297 |
+
if not settings.LIGHTEN and not YoudaoEmbed._client:
|
| 298 |
from BCEmbedding import EmbeddingModel as qanthing
|
| 299 |
try:
|
| 300 |
logging.info("LOADING BCE...")
|
rag/llm/rerank_model.py
CHANGED
|
@@ -23,7 +23,7 @@ import os
|
|
| 23 |
from abc import ABC
|
| 24 |
import numpy as np
|
| 25 |
|
| 26 |
-
from api
|
| 27 |
from api.utils.file_utils import get_home_cache_dir
|
| 28 |
from rag.utils import num_tokens_from_string, truncate
|
| 29 |
import json
|
|
@@ -57,7 +57,7 @@ class DefaultRerank(Base):
|
|
| 57 |
^_-
|
| 58 |
|
| 59 |
"""
|
| 60 |
-
if not LIGHTEN and not DefaultRerank._model:
|
| 61 |
import torch
|
| 62 |
from FlagEmbedding import FlagReranker
|
| 63 |
with DefaultRerank._model_lock:
|
|
@@ -121,7 +121,7 @@ class YoudaoRerank(DefaultRerank):
|
|
| 121 |
_model_lock = threading.Lock()
|
| 122 |
|
| 123 |
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
|
| 124 |
-
if not LIGHTEN and not YoudaoRerank._model:
|
| 125 |
from BCEmbedding import RerankerModel
|
| 126 |
with YoudaoRerank._model_lock:
|
| 127 |
if not YoudaoRerank._model:
|
|
|
|
| 23 |
from abc import ABC
|
| 24 |
import numpy as np
|
| 25 |
|
| 26 |
+
from api import settings
|
| 27 |
from api.utils.file_utils import get_home_cache_dir
|
| 28 |
from rag.utils import num_tokens_from_string, truncate
|
| 29 |
import json
|
|
|
|
| 57 |
^_-
|
| 58 |
|
| 59 |
"""
|
| 60 |
+
if not settings.LIGHTEN and not DefaultRerank._model:
|
| 61 |
import torch
|
| 62 |
from FlagEmbedding import FlagReranker
|
| 63 |
with DefaultRerank._model_lock:
|
|
|
|
| 121 |
_model_lock = threading.Lock()
|
| 122 |
|
| 123 |
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
|
| 124 |
+
if not settings.LIGHTEN and not YoudaoRerank._model:
|
| 125 |
from BCEmbedding import RerankerModel
|
| 126 |
with YoudaoRerank._model_lock:
|
| 127 |
if not YoudaoRerank._model:
|
rag/svr/task_executor.py
CHANGED
|
@@ -16,6 +16,7 @@
|
|
| 16 |
import logging
|
| 17 |
import sys
|
| 18 |
from api.utils.log_utils import initRootLogger
|
|
|
|
| 19 |
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
|
| 20 |
initRootLogger(f"task_executor_{CONSUMER_NO}")
|
| 21 |
for module in ["pdfminer"]:
|
|
@@ -49,9 +50,10 @@ from api.db.services.document_service import DocumentService
|
|
| 49 |
from api.db.services.llm_service import LLMBundle
|
| 50 |
from api.db.services.task_service import TaskService
|
| 51 |
from api.db.services.file2document_service import File2DocumentService
|
| 52 |
-
from api
|
| 53 |
from api.db.db_models import close_connection
|
| 54 |
-
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio,
|
|
|
|
| 55 |
from rag.nlp import search, rag_tokenizer
|
| 56 |
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
|
| 57 |
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME
|
|
@@ -88,6 +90,7 @@ PENDING_TASKS = 0
|
|
| 88 |
HEAD_CREATED_AT = ""
|
| 89 |
HEAD_DETAIL = ""
|
| 90 |
|
|
|
|
| 91 |
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
|
| 92 |
global PAYLOAD
|
| 93 |
if prog is not None and prog < 0:
|
|
@@ -171,7 +174,8 @@ def build(row):
|
|
| 171 |
"From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
|
| 172 |
except TimeoutError:
|
| 173 |
callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
|
| 174 |
-
logging.exception(
|
|
|
|
| 175 |
return
|
| 176 |
except Exception as e:
|
| 177 |
if re.search("(No such file|not found)", str(e)):
|
|
@@ -188,7 +192,7 @@ def build(row):
|
|
| 188 |
logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"]))
|
| 189 |
except Exception as e:
|
| 190 |
callback(-1, "Internal server error while chunking: %s" %
|
| 191 |
-
|
| 192 |
logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
|
| 193 |
return
|
| 194 |
|
|
@@ -226,7 +230,8 @@ def build(row):
|
|
| 226 |
STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
|
| 227 |
el += timer() - st
|
| 228 |
except Exception:
|
| 229 |
-
logging.exception(
|
|
|
|
| 230 |
|
| 231 |
d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
|
| 232 |
del d["image"]
|
|
@@ -241,7 +246,7 @@ def build(row):
|
|
| 241 |
d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"],
|
| 242 |
row["parser_config"]["auto_keywords"]).split(",")
|
| 243 |
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
|
| 244 |
-
callback(msg="Keywords generation completed in {:.2f}s".format(timer()-st))
|
| 245 |
|
| 246 |
if row["parser_config"].get("auto_questions", 0):
|
| 247 |
st = timer()
|
|
@@ -255,14 +260,14 @@ def build(row):
|
|
| 255 |
d["content_ltks"] += " " + qst
|
| 256 |
if "content_sm_ltks" in d:
|
| 257 |
d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst)
|
| 258 |
-
callback(msg="Question generation completed in {:.2f}s".format(timer()-st))
|
| 259 |
|
| 260 |
return docs
|
| 261 |
|
| 262 |
|
| 263 |
def init_kb(row, vector_size: int):
|
| 264 |
idxnm = search.index_name(row["tenant_id"])
|
| 265 |
-
return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
|
| 266 |
|
| 267 |
|
| 268 |
def embedding(docs, mdl, parser_config=None, callback=None):
|
|
@@ -313,7 +318,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
|
| 313 |
vector_size = len(vts[0])
|
| 314 |
vctr_nm = "q_%d_vec" % vector_size
|
| 315 |
chunks = []
|
| 316 |
-
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
|
|
|
|
| 317 |
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
|
| 318 |
|
| 319 |
raptor = Raptor(
|
|
@@ -384,7 +390,8 @@ def main():
|
|
| 384 |
# TODO: exception handler
|
| 385 |
## set_progress(r["did"], -1, "ERROR: ")
|
| 386 |
callback(
|
| 387 |
-
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks),
|
|
|
|
| 388 |
)
|
| 389 |
st = timer()
|
| 390 |
try:
|
|
@@ -403,18 +410,18 @@ def main():
|
|
| 403 |
es_r = ""
|
| 404 |
es_bulk_size = 4
|
| 405 |
for b in range(0, len(cks), es_bulk_size):
|
| 406 |
-
es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
|
| 407 |
if b % 128 == 0:
|
| 408 |
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
|
| 409 |
|
| 410 |
logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
| 411 |
if es_r:
|
| 412 |
callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
|
| 413 |
-
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
| 414 |
logging.error('Insert chunk error: ' + str(es_r))
|
| 415 |
else:
|
| 416 |
if TaskService.do_cancel(r["id"]):
|
| 417 |
-
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
| 418 |
continue
|
| 419 |
callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
|
| 420 |
callback(1., "Done!")
|
|
@@ -435,7 +442,7 @@ def report_status():
|
|
| 435 |
if PENDING_TASKS > 0:
|
| 436 |
head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME)
|
| 437 |
if head_info is not None:
|
| 438 |
-
seconds = int(head_info[0].split("-")[0])/1000
|
| 439 |
HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat()
|
| 440 |
HEAD_DETAIL = head_info[1]
|
| 441 |
|
|
@@ -452,7 +459,7 @@ def report_status():
|
|
| 452 |
REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
|
| 453 |
logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
|
| 454 |
|
| 455 |
-
expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60*30)
|
| 456 |
if expired > 0:
|
| 457 |
REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
|
| 458 |
except Exception:
|
|
|
|
| 16 |
import logging
|
| 17 |
import sys
|
| 18 |
from api.utils.log_utils import initRootLogger
|
| 19 |
+
|
| 20 |
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
|
| 21 |
initRootLogger(f"task_executor_{CONSUMER_NO}")
|
| 22 |
for module in ["pdfminer"]:
|
|
|
|
| 50 |
from api.db.services.llm_service import LLMBundle
|
| 51 |
from api.db.services.task_service import TaskService
|
| 52 |
from api.db.services.file2document_service import File2DocumentService
|
| 53 |
+
from api import settings
|
| 54 |
from api.db.db_models import close_connection
|
| 55 |
+
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
|
| 56 |
+
knowledge_graph, email
|
| 57 |
from rag.nlp import search, rag_tokenizer
|
| 58 |
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
|
| 59 |
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME
|
|
|
|
| 90 |
HEAD_CREATED_AT = ""
|
| 91 |
HEAD_DETAIL = ""
|
| 92 |
|
| 93 |
+
|
| 94 |
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
|
| 95 |
global PAYLOAD
|
| 96 |
if prog is not None and prog < 0:
|
|
|
|
| 174 |
"From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
|
| 175 |
except TimeoutError:
|
| 176 |
callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
|
| 177 |
+
logging.exception(
|
| 178 |
+
"Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
|
| 179 |
return
|
| 180 |
except Exception as e:
|
| 181 |
if re.search("(No such file|not found)", str(e)):
|
|
|
|
| 192 |
logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"]))
|
| 193 |
except Exception as e:
|
| 194 |
callback(-1, "Internal server error while chunking: %s" %
|
| 195 |
+
str(e).replace("'", ""))
|
| 196 |
logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
|
| 197 |
return
|
| 198 |
|
|
|
|
| 230 |
STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
|
| 231 |
el += timer() - st
|
| 232 |
except Exception:
|
| 233 |
+
logging.exception(
|
| 234 |
+
"Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
|
| 235 |
|
| 236 |
d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
|
| 237 |
del d["image"]
|
|
|
|
| 246 |
d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"],
|
| 247 |
row["parser_config"]["auto_keywords"]).split(",")
|
| 248 |
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
|
| 249 |
+
callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))
|
| 250 |
|
| 251 |
if row["parser_config"].get("auto_questions", 0):
|
| 252 |
st = timer()
|
|
|
|
| 260 |
d["content_ltks"] += " " + qst
|
| 261 |
if "content_sm_ltks" in d:
|
| 262 |
d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst)
|
| 263 |
+
callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
|
| 264 |
|
| 265 |
return docs
|
| 266 |
|
| 267 |
|
| 268 |
def init_kb(row, vector_size: int):
|
| 269 |
idxnm = search.index_name(row["tenant_id"])
|
| 270 |
+
return settings.docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
|
| 271 |
|
| 272 |
|
| 273 |
def embedding(docs, mdl, parser_config=None, callback=None):
|
|
|
|
| 318 |
vector_size = len(vts[0])
|
| 319 |
vctr_nm = "q_%d_vec" % vector_size
|
| 320 |
chunks = []
|
| 321 |
+
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
|
| 322 |
+
fields=["content_with_weight", vctr_nm]):
|
| 323 |
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
|
| 324 |
|
| 325 |
raptor = Raptor(
|
|
|
|
| 390 |
# TODO: exception handler
|
| 391 |
## set_progress(r["did"], -1, "ERROR: ")
|
| 392 |
callback(
|
| 393 |
+
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks),
|
| 394 |
+
timer() - st)
|
| 395 |
)
|
| 396 |
st = timer()
|
| 397 |
try:
|
|
|
|
| 410 |
es_r = ""
|
| 411 |
es_bulk_size = 4
|
| 412 |
for b in range(0, len(cks), es_bulk_size):
|
| 413 |
+
es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
|
| 414 |
if b % 128 == 0:
|
| 415 |
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
|
| 416 |
|
| 417 |
logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
| 418 |
if es_r:
|
| 419 |
callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
|
| 420 |
+
settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
| 421 |
logging.error('Insert chunk error: ' + str(es_r))
|
| 422 |
else:
|
| 423 |
if TaskService.do_cancel(r["id"]):
|
| 424 |
+
settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
| 425 |
continue
|
| 426 |
callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
|
| 427 |
callback(1., "Done!")
|
|
|
|
| 442 |
if PENDING_TASKS > 0:
|
| 443 |
head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME)
|
| 444 |
if head_info is not None:
|
| 445 |
+
seconds = int(head_info[0].split("-")[0]) / 1000
|
| 446 |
HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat()
|
| 447 |
HEAD_DETAIL = head_info[1]
|
| 448 |
|
|
|
|
| 459 |
REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
|
| 460 |
logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
|
| 461 |
|
| 462 |
+
expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
|
| 463 |
if expired > 0:
|
| 464 |
REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
|
| 465 |
except Exception:
|