API: retrieval api (#1763)
Browse files### What problem does this PR solve?
Add retrieval api on a specific knowledge base

https://github.com/infiniflow/ragflow/issues/1102
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- api/apps/api_app.py +55 -1
api/apps/api_app.py
CHANGED
@@ -20,7 +20,7 @@ from datetime import datetime, timedelta
|
|
20 |
from flask import request, Response
|
21 |
from flask_login import login_required, current_user
|
22 |
|
23 |
-
from api.db import FileType, ParserType, FileSource
|
24 |
from api.db.db_models import APIToken, API4Conversation, Task, File
|
25 |
from api.db.services import duplicate_name
|
26 |
from api.db.services.api_service import APITokenService, API4ConversationService
|
@@ -29,6 +29,7 @@ from api.db.services.document_service import DocumentService
|
|
29 |
from api.db.services.file2document_service import File2DocumentService
|
30 |
from api.db.services.file_service import FileService
|
31 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
|
|
32 |
from api.db.services.task_service import queue_tasks, TaskService
|
33 |
from api.db.services.user_service import UserTenantService
|
34 |
from api.settings import RetCode, retrievaler
|
@@ -37,6 +38,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge
|
|
37 |
from itsdangerous import URLSafeTimedSerializer
|
38 |
|
39 |
from api.utils.file_utils import filename_type, thumbnail
|
|
|
40 |
from rag.utils.minio_conn import MINIO
|
41 |
|
42 |
|
@@ -587,3 +589,55 @@ def completion_faq():
|
|
587 |
|
588 |
except Exception as e:
|
589 |
return server_error_response(e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
from flask import request, Response
|
21 |
from flask_login import login_required, current_user
|
22 |
|
23 |
+
from api.db import FileType, ParserType, FileSource, LLMType
|
24 |
from api.db.db_models import APIToken, API4Conversation, Task, File
|
25 |
from api.db.services import duplicate_name
|
26 |
from api.db.services.api_service import APITokenService, API4ConversationService
|
|
|
29 |
from api.db.services.file2document_service import File2DocumentService
|
30 |
from api.db.services.file_service import FileService
|
31 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
32 |
+
from api.db.services.llm_service import TenantLLMService
|
33 |
from api.db.services.task_service import queue_tasks, TaskService
|
34 |
from api.db.services.user_service import UserTenantService
|
35 |
from api.settings import RetCode, retrievaler
|
|
|
38 |
from itsdangerous import URLSafeTimedSerializer
|
39 |
|
40 |
from api.utils.file_utils import filename_type, thumbnail
|
41 |
+
from rag.nlp import keyword_extraction
|
42 |
from rag.utils.minio_conn import MINIO
|
43 |
|
44 |
|
|
|
589 |
|
590 |
except Exception as e:
|
591 |
return server_error_response(e)
|
592 |
+
|
593 |
+
|
594 |
+
@manager.route('/retrieval', methods=['POST'])
|
595 |
+
@validate_request("kb_id", "question")
|
596 |
+
def retrieval():
|
597 |
+
token = request.headers.get('Authorization').split()[1]
|
598 |
+
objs = APIToken.query(token=token)
|
599 |
+
if not objs:
|
600 |
+
return get_json_result(
|
601 |
+
data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
|
602 |
+
|
603 |
+
req = request.json
|
604 |
+
kb_id = req.get("kb_id")
|
605 |
+
doc_ids = req.get("doc_ids", [])
|
606 |
+
question = req.get("question")
|
607 |
+
page = int(req.get("page", 1))
|
608 |
+
size = int(req.get("size", 30))
|
609 |
+
similarity_threshold = float(req.get("similarity_threshold", 0.2))
|
610 |
+
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
611 |
+
top = int(req.get("top_k", 1024))
|
612 |
+
|
613 |
+
try:
|
614 |
+
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
615 |
+
if not e:
|
616 |
+
return get_data_error_result(retmsg="Knowledgebase not found!")
|
617 |
+
|
618 |
+
embd_mdl = TenantLLMService.model_instance(
|
619 |
+
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
620 |
+
|
621 |
+
rerank_mdl = None
|
622 |
+
if req.get("rerank_id"):
|
623 |
+
rerank_mdl = TenantLLMService.model_instance(
|
624 |
+
kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
|
625 |
+
|
626 |
+
if req.get("keyword", False):
|
627 |
+
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
|
628 |
+
question += keyword_extraction(chat_mdl, question)
|
629 |
+
|
630 |
+
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
|
631 |
+
similarity_threshold, vector_similarity_weight, top,
|
632 |
+
doc_ids, rerank_mdl=rerank_mdl)
|
633 |
+
for c in ranks["chunks"]:
|
634 |
+
if "vector" in c:
|
635 |
+
del c["vector"]
|
636 |
+
|
637 |
+
return get_json_result(data=ranks)
|
638 |
+
except Exception as e:
|
639 |
+
if str(e).find("not_found") > 0:
|
640 |
+
return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
|
641 |
+
retcode=RetCode.DATA_ERROR)
|
642 |
+
return server_error_response(e)
|
643 |
+
|