valdanito committed on
Commit
36537aa
·
1 Parent(s): 66857b9

API: retrieval api (#1763)

Browse files

### What problem does this PR solve?

Add a retrieval API for querying a specific knowledge base.


![ragflow](https://github.com/user-attachments/assets/dc30a4c3-03c5-4d34-bb7c-60b8830f1225)

https://github.com/infiniflow/ragflow/issues/1102

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Files changed (1) hide show
  1. api/apps/api_app.py +55 -1
api/apps/api_app.py CHANGED
@@ -20,7 +20,7 @@ from datetime import datetime, timedelta
20
  from flask import request, Response
21
  from flask_login import login_required, current_user
22
 
23
- from api.db import FileType, ParserType, FileSource
24
  from api.db.db_models import APIToken, API4Conversation, Task, File
25
  from api.db.services import duplicate_name
26
  from api.db.services.api_service import APITokenService, API4ConversationService
@@ -29,6 +29,7 @@ from api.db.services.document_service import DocumentService
29
  from api.db.services.file2document_service import File2DocumentService
30
  from api.db.services.file_service import FileService
31
  from api.db.services.knowledgebase_service import KnowledgebaseService
 
32
  from api.db.services.task_service import queue_tasks, TaskService
33
  from api.db.services.user_service import UserTenantService
34
  from api.settings import RetCode, retrievaler
@@ -37,6 +38,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge
37
  from itsdangerous import URLSafeTimedSerializer
38
 
39
  from api.utils.file_utils import filename_type, thumbnail
 
40
  from rag.utils.minio_conn import MINIO
41
 
42
 
@@ -587,3 +589,55 @@ def completion_faq():
587
 
588
  except Exception as e:
589
  return server_error_response(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  from flask import request, Response
21
  from flask_login import login_required, current_user
22
 
23
+ from api.db import FileType, ParserType, FileSource, LLMType
24
  from api.db.db_models import APIToken, API4Conversation, Task, File
25
  from api.db.services import duplicate_name
26
  from api.db.services.api_service import APITokenService, API4ConversationService
 
29
  from api.db.services.file2document_service import File2DocumentService
30
  from api.db.services.file_service import FileService
31
  from api.db.services.knowledgebase_service import KnowledgebaseService
32
+ from api.db.services.llm_service import TenantLLMService
33
  from api.db.services.task_service import queue_tasks, TaskService
34
  from api.db.services.user_service import UserTenantService
35
  from api.settings import RetCode, retrievaler
 
38
  from itsdangerous import URLSafeTimedSerializer
39
 
40
  from api.utils.file_utils import filename_type, thumbnail
41
+ from rag.nlp import keyword_extraction
42
  from rag.utils.minio_conn import MINIO
43
 
44
 
 
589
 
590
  except Exception as e:
591
  return server_error_response(e)
592
+
593
+
594
@manager.route('/retrieval', methods=['POST'])
@validate_request("kb_id", "question")
def retrieval():
    """Retrieve ranked chunks from a single knowledge base for a question.

    Authentication:
        The ``Authorization`` header must hold at least two
        whitespace-separated parts; the second part is the API token, which
        is looked up with ``APIToken.query``.

    JSON body:
        kb_id: id of the knowledge base to search (required).
        question: query text (required).
        doc_ids: optional list of document ids passed to the retriever
            (default ``[]``).
        page: page number (default 1).
        size: page size (default 30).
        similarity_threshold: default 0.2.
        vector_similarity_weight: default 0.3.
        top_k: default 1024.
        rerank_id: optional rerank model name; when set, a rerank model is
            instantiated and passed to the retriever.
        keyword: when truthy, keywords extracted from the question by a chat
            model are appended to the question before retrieval.

    Returns:
        A JSON result whose ``data`` is the retriever output with the
        ``"vector"`` entry removed from every chunk, or an error result when
        the token is invalid, a numeric parameter is malformed, the knowledge
        base does not exist, or retrieval fails.
    """
    # A missing or malformed header must yield an authentication error rather
    # than an AttributeError/IndexError from chained .get().split()[1].
    auth_parts = request.headers.get('Authorization', '').split()
    if len(auth_parts) < 2:
        return get_json_result(
            data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR)

    objs = APIToken.query(token=auth_parts[1])
    if not objs:
        return get_json_result(
            data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR)
    # NOTE(review): the token is only checked for existence; nothing here ties
    # it to the knowledge base's tenant. Confirm whether a token should be
    # allowed to query knowledge bases owned by other tenants.

    req = request.json
    kb_id = req.get("kb_id")
    doc_ids = req.get("doc_ids", [])
    question = req.get("question")

    # Reject malformed numeric parameters with a data error instead of letting
    # the conversion exception escape the handler.
    try:
        page = int(req.get("page", 1))
        size = int(req.get("size", 30))
        similarity_threshold = float(req.get("similarity_threshold", 0.2))
        vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
        top = int(req.get("top_k", 1024))
    except (TypeError, ValueError):
        return get_data_error_result(
            retmsg="`page`, `size` and `top_k` must be integers; "
                   "`similarity_threshold` and `vector_similarity_weight` must be numbers.")

    try:
        e, kb = KnowledgebaseService.get_by_id(kb_id)
        if not e:
            return get_data_error_result(retmsg="Knowledgebase not found!")

        embd_mdl = TenantLLMService.model_instance(
            kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

        rerank_mdl = None
        if req.get("rerank_id"):
            rerank_mdl = TenantLLMService.model_instance(
                kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])

        if req.get("keyword", False):
            # Use `.value` like the EMBEDDING/RERANK calls above.
            chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT.value)
            question += keyword_extraction(chat_mdl, question)

        ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
                                      similarity_threshold, vector_similarity_weight, top,
                                      doc_ids, rerank_mdl=rerank_mdl)
        # Drop the "vector" entry from each chunk before returning the payload.
        for c in ranks["chunks"]:
            c.pop("vector", None)

        return get_json_result(data=ranks)
    except Exception as e:
        # Substring test: `.find(...) > 0` would miss a match at index 0.
        if "not_found" in str(e):
            return get_json_result(data=False, retmsg='No chunk found! Check the chunk status please!',
                                   retcode=RetCode.DATA_ERROR)
        return server_error_response(e)
643
+