wwwlll committed · Commit 17f4221 · Parent(s): 7b030d6

Fix retrieval API error and add multi-kb search (#1928)
### What problem does this PR solve?

Fixes an error in the retrieval API caused by missing imports, and adds multi-KB search to the same endpoint.

### Type of change

- Bug Fix (import the necessary classes for the retrieval API)
- New Feature (add multi-KB search to the retrieval API)

Files changed:

- api/apps/api_app.py (+16 -16)
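With this change the retrieval endpoint can search several knowledge bases in one request by passing a list in the `kb_id` field. A minimal client sketch follows; the base URL, endpoint path, token value, and KB IDs are illustrative assumptions, while the request fields (`kb_id`, `question`, `doc_ids`, `page`, `top_k`, `rerank_id`, `keyword`) are the ones read by the `retrieval()` handler in the diff below.

```python
import requests

# Assumptions for illustration: base URL, endpoint path, and token value.
# Only the request fields mirror what retrieval() reads from request.json.
BASE_URL = "http://localhost:9380"            # hypothetical deployment address
URL = f"{BASE_URL}/api/v1/retrieval"          # hypothetical path to the retrieval endpoint
API_TOKEN = "YOUR_API_TOKEN"                  # an issued API token

payload = {
    "kb_id": ["kb_finance", "kb_legal"],      # NEW: a list of knowledge-base IDs (multi-KB search)
    "question": "What is the notice period in the supplier contract?",
    "doc_ids": [],                            # optionally restrict the search to specific documents
    "page": 1,
    "top_k": 1024,
    # "rerank_id": "...",                     # optional: rerank results with this model
    # "keyword": True,                        # optional: expand the question via LLM keyword extraction
}

resp = requests.post(
    URL,
    json=payload,
    headers={"Authorization": f"Bearer {API_TOKEN}"},  # the handler rejects requests with an invalid token
    timeout=60,
)
resp.raise_for_status()
result = resp.json()
# The handler strips "vector" from each chunk before returning; assuming the usual
# get_json_result() envelope, the chunks sit under result["data"]["chunks"].
chunks = (result.get("data") or {}).get("chunks", [])
```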
api/apps/api_app.py (CHANGED)
@@ -18,9 +18,10 @@ import os
 import re
 from datetime import datetime, timedelta
 from flask import request, Response
+from api.db.services.llm_service import TenantLLMService
 from flask_login import login_required, current_user
 
-from api.db import FileType, ParserType, FileSource
+from api.db import FileType, LLMType, ParserType, FileSource
 from api.db.db_models import APIToken, API4Conversation, Task, File
 from api.db.services import duplicate_name
 from api.db.services.api_service import APITokenService, API4ConversationService
@@ -37,6 +38,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge
 from itsdangerous import URLSafeTimedSerializer
 
 from api.utils.file_utils import filename_type, thumbnail
+from rag.nlp import keyword_extraction
 from rag.utils.minio_conn import MINIO
 
 from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
@@ -694,7 +696,7 @@ def retrieval():
            data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
 
     req = request.json
-    kb_id = req.get("kb_id")
+    kb_ids = req.get("kb_id",[])
     doc_ids = req.get("doc_ids", [])
     question = req.get("question")
     page = int(req.get("page", 1))
@@ -704,32 +706,30 @@
     top = int(req.get("top_k", 1024))
 
     try:
-        e, kb = KnowledgebaseService.get_by_id(kb_id)
-        if not e:
-            return get_data_error_result(retmsg="Knowledgebase not found!")
+        kbs = KnowledgebaseService.get_by_ids(kb_ids)
+        embd_nms = list(set([kb.embd_id for kb in kbs]))
+        if len(embd_nms) != 1:
+            return get_json_result(
+                data=False, retmsg='Knowledge bases use different embedding models or does not exist."', retcode=RetCode.AUTHENTICATION_ERROR)
 
         embd_mdl = TenantLLMService.model_instance(
-            kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
-
+            kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
         rerank_mdl = None
         if req.get("rerank_id"):
             rerank_mdl = TenantLLMService.model_instance(
-                kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
-
+                kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
         if req.get("keyword", False):
-            chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
+            chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
            question += keyword_extraction(chat_mdl, question)
-
-        ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
-                                      similarity_threshold, vector_similarity_weight, top,
-                                      doc_ids, rerank_mdl=rerank_mdl)
+        ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
+                                      similarity_threshold, vector_similarity_weight, top,
+                                      doc_ids, rerank_mdl=rerank_mdl)
         for c in ranks["chunks"]:
             if "vector" in c:
                 del c["vector"]
-
         return get_json_result(data=ranks)
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
                                    retcode=RetCode.DATA_ERROR)
-        return server_error_response(e)
+        return server_error_response(e)
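One detail worth highlighting in the new code: all selected knowledge bases must share a single embedding model (`embd_nms` must collapse to one name), because one embedding model has to encode the query for every KB being searched. Below is a small self-contained sketch of that guard, using hypothetical KB records rather than the real `KnowledgebaseService` models.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class KB:
    """Hypothetical stand-in for a knowledge-base record; only embd_id matters here."""
    id: str
    embd_id: str

def shared_embedding_model(kbs: List[KB]) -> Optional[str]:
    """Return the embedding model shared by all KBs, or None if they differ or none were found."""
    embd_nms = list(set(kb.embd_id for kb in kbs))  # same de-duplication as in retrieval()
    return embd_nms[0] if len(embd_nms) == 1 else None

# Same embedding model everywhere: cross-KB search can proceed with that model.
print(shared_embedding_model([KB("kb1", "BAAI/bge-large-zh-v1.5"),
                              KB("kb2", "BAAI/bge-large-zh-v1.5")]))  # BAAI/bge-large-zh-v1.5

# Mixed models (or an empty lookup) would be rejected, as in the handler's error branch.
print(shared_embedding_model([KB("kb1", "BAAI/bge-large-zh-v1.5"),
                              KB("kb2", "text-embedding-3-small")]))  # None
print(shared_embedding_model([]))                                     # None
```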