GYH
commited on
Commit
·
99ebece
1
Parent(s):
3c8f464
Add api/list_kb_docs function and modify api/list_chunks (#874)
Browse files### What problem does this PR solve?
#717
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- api/apps/api_app.py +49 -14
- api/db/services/document_service.py +11 -0
- rag/nlp/search.py +10 -0
api/apps/api_app.py
CHANGED
|
@@ -31,7 +31,7 @@ from api.db.services.file_service import FileService
|
|
| 31 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 32 |
from api.db.services.task_service import queue_tasks, TaskService
|
| 33 |
from api.db.services.user_service import UserTenantService
|
| 34 |
-
from api.settings import RetCode
|
| 35 |
from api.utils import get_uuid, current_timestamp, datetime_format
|
| 36 |
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request
|
| 37 |
from itsdangerous import URLSafeTimedSerializer
|
|
@@ -39,9 +39,6 @@ from itsdangerous import URLSafeTimedSerializer
|
|
| 39 |
from api.utils.file_utils import filename_type, thumbnail
|
| 40 |
from rag.utils.minio_conn import MINIO
|
| 41 |
|
| 42 |
-
from rag.utils.es_conn import ELASTICSEARCH
|
| 43 |
-
from rag.nlp import search
|
| 44 |
-
from elasticsearch_dsl import Q
|
| 45 |
|
| 46 |
def generate_confirmation_token(tenent_id):
|
| 47 |
serializer = URLSafeTimedSerializer(tenent_id)
|
|
@@ -369,27 +366,65 @@ def list_chunks():
|
|
| 369 |
try:
|
| 370 |
if "doc_name" in form_data.keys():
|
| 371 |
tenant_id = DocumentService.get_tenant_id_by_name(form_data['doc_name'])
|
| 372 |
-
|
| 373 |
|
| 374 |
elif "doc_id" in form_data.keys():
|
| 375 |
tenant_id = DocumentService.get_tenant_id(form_data['doc_id'])
|
| 376 |
-
|
| 377 |
else:
|
| 378 |
return get_json_result(
|
| 379 |
data=False,retmsg="Can't find doc_name or doc_id"
|
| 380 |
)
|
| 381 |
|
| 382 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
|
| 384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
|
| 392 |
except Exception as e:
|
| 393 |
return server_error_response(e)
|
| 394 |
|
| 395 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 32 |
from api.db.services.task_service import queue_tasks, TaskService
|
| 33 |
from api.db.services.user_service import UserTenantService
|
| 34 |
+
from api.settings import RetCode, retrievaler
|
| 35 |
from api.utils import get_uuid, current_timestamp, datetime_format
|
| 36 |
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request
|
| 37 |
from itsdangerous import URLSafeTimedSerializer
|
|
|
|
| 39 |
from api.utils.file_utils import filename_type, thumbnail
|
| 40 |
from rag.utils.minio_conn import MINIO
|
| 41 |
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
def generate_confirmation_token(tenent_id):
|
| 44 |
serializer = URLSafeTimedSerializer(tenent_id)
|
|
|
|
| 366 |
try:
|
| 367 |
if "doc_name" in form_data.keys():
|
| 368 |
tenant_id = DocumentService.get_tenant_id_by_name(form_data['doc_name'])
|
| 369 |
+
doc_id = DocumentService.get_doc_id_by_doc_name(form_data['doc_name'])
|
| 370 |
|
| 371 |
elif "doc_id" in form_data.keys():
|
| 372 |
tenant_id = DocumentService.get_tenant_id(form_data['doc_id'])
|
| 373 |
+
doc_id = form_data['doc_id']
|
| 374 |
else:
|
| 375 |
return get_json_result(
|
| 376 |
data=False,retmsg="Can't find doc_name or doc_id"
|
| 377 |
)
|
| 378 |
|
| 379 |
+
res = retrievaler.chunk_list(doc_id=doc_id, tenant_id=tenant_id)
|
| 380 |
+
res = [
|
| 381 |
+
{
|
| 382 |
+
"content": res_item["content_with_weight"],
|
| 383 |
+
"doc_name": res_item["docnm_kwd"],
|
| 384 |
+
"img_id": res_item["img_id"]
|
| 385 |
+
} for res_item in res
|
| 386 |
+
]
|
| 387 |
|
| 388 |
+
except Exception as e:
|
| 389 |
+
return server_error_response(e)
|
| 390 |
+
|
| 391 |
+
return get_json_result(data=res)
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
@manager.route('/list_kb_docs', methods=['POST'])
|
| 395 |
+
# @login_required
|
| 396 |
+
def list_kb_docs():
|
| 397 |
+
token = request.headers.get('Authorization').split()[1]
|
| 398 |
+
objs = APIToken.query(token=token)
|
| 399 |
+
if not objs:
|
| 400 |
+
return get_json_result(
|
| 401 |
+
data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
|
| 402 |
|
| 403 |
+
tenant_id = objs[0].tenant_id
|
| 404 |
+
kb_name = request.form.get("kb_name").strip()
|
| 405 |
+
|
| 406 |
+
try:
|
| 407 |
+
e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
|
| 408 |
+
if not e:
|
| 409 |
+
return get_data_error_result(
|
| 410 |
+
retmsg="Can't find this knowledgebase!")
|
| 411 |
+
kb_id = kb.id
|
| 412 |
|
| 413 |
except Exception as e:
|
| 414 |
return server_error_response(e)
|
| 415 |
|
| 416 |
+
page_number = int(request.form.get("page", 1))
|
| 417 |
+
items_per_page = int(request.form.get("page_size", 15))
|
| 418 |
+
orderby = request.form.get("orderby", "create_time")
|
| 419 |
+
desc = request.form.get("desc", True)
|
| 420 |
+
keywords = request.form.get("keywords", "")
|
| 421 |
+
|
| 422 |
+
try:
|
| 423 |
+
docs, tol = DocumentService.get_by_kb_id(
|
| 424 |
+
kb_id, page_number, items_per_page, orderby, desc, keywords)
|
| 425 |
+
docs = [{"doc_id": doc['id'], "doc_name": doc['name']} for doc in docs]
|
| 426 |
+
|
| 427 |
+
return get_json_result(data={"total": tol, "docs": docs})
|
| 428 |
+
|
| 429 |
+
except Exception as e:
|
| 430 |
+
return server_error_response(e)
|
api/db/services/document_service.py
CHANGED
|
@@ -179,6 +179,17 @@ class DocumentService(CommonService):
|
|
| 179 |
return
|
| 180 |
return docs[0]["tenant_id"]
|
| 181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
@classmethod
|
| 183 |
@DB.connection_context()
|
| 184 |
def get_thumbnails(cls, docids):
|
|
|
|
| 179 |
return
|
| 180 |
return docs[0]["tenant_id"]
|
| 181 |
|
| 182 |
+
@classmethod
|
| 183 |
+
@DB.connection_context()
|
| 184 |
+
def get_doc_id_by_doc_name(cls, doc_name):
|
| 185 |
+
fields = [cls.model.id]
|
| 186 |
+
doc_id = cls.model.select(*fields) \
|
| 187 |
+
.where(cls.model.name == doc_name)
|
| 188 |
+
doc_id = doc_id.dicts()
|
| 189 |
+
if not doc_id:
|
| 190 |
+
return
|
| 191 |
+
return doc_id[0]["id"]
|
| 192 |
+
|
| 193 |
@classmethod
|
| 194 |
@DB.connection_context()
|
| 195 |
def get_thumbnails(cls, docids):
|
rag/nlp/search.py
CHANGED
|
@@ -407,3 +407,13 @@ class Dealer:
|
|
| 407 |
except Exception as e:
|
| 408 |
chat_logger.error(f"SQL failure: {sql} =>" + str(e))
|
| 409 |
return {"error": str(e)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
except Exception as e:
|
| 408 |
chat_logger.error(f"SQL failure: {sql} =>" + str(e))
|
| 409 |
return {"error": str(e)}
|
| 410 |
+
|
| 411 |
+
def chunk_list(self, doc_id, tenant_id, max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]):
|
| 412 |
+
s = Search()
|
| 413 |
+
s = s.query(Q("match", doc_id=doc_id))[0:max_count]
|
| 414 |
+
s = s.to_dict()
|
| 415 |
+
es_res = self.es.search(s, idxnm=index_name(tenant_id), timeout="600s", src=fields)
|
| 416 |
+
res = []
|
| 417 |
+
for index, chunk in enumerate(es_res['hits']['hits']):
|
| 418 |
+
res.append({fld: chunk['_source'].get(fld) for fld in fields})
|
| 419 |
+
return res
|