KevinHuSh
		
	commited on
		
		
					Commit 
							
							·
						
						83a0020
	
1
								Parent(s):
							
							87a2c48
								
refactor (#1124)
Browse files### What problem does this PR solve?
### Type of change
- [x] Refactoring
- api/apps/__init__.py +0 -1
- api/apps/document_app.py +63 -66
- api/db/services/dialog_service.py +4 -3
- api/utils/api_utils.py +2 -1
- api/utils/log_utils.py +0 -5
- api/utils/web_utils.py +0 -2
- rag/llm/embedding_model.py +4 -5
    	
        api/apps/__init__.py
    CHANGED
    
    | @@ -85,7 +85,6 @@ def register_page(page_path): | |
| 85 | 
             
                url_prefix = f'/api/{API_VERSION}/{page_name}' if "_api" in path else f'/{API_VERSION}/{page_name}'
         | 
| 86 |  | 
| 87 | 
             
                app.register_blueprint(page.manager, url_prefix=url_prefix)
         | 
| 88 | 
            -
                print(f'API file: {page_path}, URL: {url_prefix}')
         | 
| 89 | 
             
                return url_prefix
         | 
| 90 |  | 
| 91 |  | 
|  | |
| 85 | 
             
                url_prefix = f'/api/{API_VERSION}/{page_name}' if "_api" in path else f'/{API_VERSION}/{page_name}'
         | 
| 86 |  | 
| 87 | 
             
                app.register_blueprint(page.manager, url_prefix=url_prefix)
         | 
|  | |
| 88 | 
             
                return url_prefix
         | 
| 89 |  | 
| 90 |  | 
    	
        api/apps/document_app.py
    CHANGED
    
    | @@ -40,6 +40,7 @@ from api.utils.api_utils import get_json_result | |
| 40 | 
             
            from rag.utils.minio_conn import MINIO
         | 
| 41 | 
             
            from api.utils.file_utils import filename_type, thumbnail
         | 
| 42 | 
             
            from api.utils.web_utils import html2pdf, is_valid_url
         | 
|  | |
| 43 |  | 
| 44 |  | 
| 45 | 
             
            @manager.route('/upload', methods=['POST'])
         | 
| @@ -117,6 +118,68 @@ def upload(): | |
| 117 | 
             
                return get_json_result(data=True)
         | 
| 118 |  | 
| 119 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 120 | 
             
            @manager.route('/create', methods=['POST'])
         | 
| 121 | 
             
            @login_required
         | 
| 122 | 
             
            @validate_request("name", "kb_id")
         | 
| @@ -417,69 +480,3 @@ def get_image(image_id): | |
| 417 | 
             
                    return response
         | 
| 418 | 
             
                except Exception as e:
         | 
| 419 | 
             
                    return server_error_response(e)
         | 
| 420 | 
            -
             | 
| 421 | 
            -
             | 
| 422 | 
            -
            @manager.route('/web_crawl', methods=['POST'])
         | 
| 423 | 
            -
            @login_required
         | 
| 424 | 
            -
            def web_crawl():
         | 
| 425 | 
            -
                kb_id = request.form.get("kb_id")
         | 
| 426 | 
            -
                if not kb_id:
         | 
| 427 | 
            -
                    return get_json_result(
         | 
| 428 | 
            -
                        data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
         | 
| 429 | 
            -
                name = request.form.get("name")
         | 
| 430 | 
            -
                url = request.form.get("url")
         | 
| 431 | 
            -
                if not name:
         | 
| 432 | 
            -
                    return get_json_result(
         | 
| 433 | 
            -
                        data=False, retmsg='Lack of "name"', retcode=RetCode.ARGUMENT_ERROR)
         | 
| 434 | 
            -
                if not url:
         | 
| 435 | 
            -
                    return get_json_result(
         | 
| 436 | 
            -
                        data=False, retmsg='Lack of "url"', retcode=RetCode.ARGUMENT_ERROR)
         | 
| 437 | 
            -
                if not is_valid_url(url):
         | 
| 438 | 
            -
                    return get_json_result(
         | 
| 439 | 
            -
                        data=False, retmsg='The URL format is invalid', retcode=RetCode.ARGUMENT_ERROR)
         | 
| 440 | 
            -
                e, kb = KnowledgebaseService.get_by_id(kb_id)
         | 
| 441 | 
            -
                if not e:
         | 
| 442 | 
            -
                    raise LookupError("Can't find this knowledgebase!")
         | 
| 443 | 
            -
             | 
| 444 | 
            -
                root_folder = FileService.get_root_folder(current_user.id)
         | 
| 445 | 
            -
                pf_id = root_folder["id"]
         | 
| 446 | 
            -
                FileService.init_knowledgebase_docs(pf_id, current_user.id)
         | 
| 447 | 
            -
                kb_root_folder = FileService.get_kb_folder(current_user.id)
         | 
| 448 | 
            -
                kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
         | 
| 449 | 
            -
             | 
| 450 | 
            -
                try:
         | 
| 451 | 
            -
                    filename = duplicate_name(
         | 
| 452 | 
            -
                        DocumentService.query,
         | 
| 453 | 
            -
                        name=name+".pdf",
         | 
| 454 | 
            -
                        kb_id=kb.id)
         | 
| 455 | 
            -
                    filetype = filename_type(filename)
         | 
| 456 | 
            -
                    if filetype == FileType.OTHER.value:
         | 
| 457 | 
            -
                        raise RuntimeError("This type of file has not been supported yet!")
         | 
| 458 | 
            -
             | 
| 459 | 
            -
                    location = filename
         | 
| 460 | 
            -
                    while MINIO.obj_exist(kb_id, location):
         | 
| 461 | 
            -
                        location += "_"
         | 
| 462 | 
            -
                    blob = html2pdf(url)
         | 
| 463 | 
            -
                    MINIO.put(kb_id, location, blob)
         | 
| 464 | 
            -
                    doc = {
         | 
| 465 | 
            -
                        "id": get_uuid(),
         | 
| 466 | 
            -
                        "kb_id": kb.id,
         | 
| 467 | 
            -
                        "parser_id": kb.parser_id,
         | 
| 468 | 
            -
                        "parser_config": kb.parser_config,
         | 
| 469 | 
            -
                        "created_by": current_user.id,
         | 
| 470 | 
            -
                        "type": filetype,
         | 
| 471 | 
            -
                        "name": filename,
         | 
| 472 | 
            -
                        "location": location,
         | 
| 473 | 
            -
                        "size": len(blob),
         | 
| 474 | 
            -
                        "thumbnail": thumbnail(filename, blob)
         | 
| 475 | 
            -
                    }
         | 
| 476 | 
            -
                    if doc["type"] == FileType.VISUAL:
         | 
| 477 | 
            -
                        doc["parser_id"] = ParserType.PICTURE.value
         | 
| 478 | 
            -
                    if re.search(r"\.(ppt|pptx|pages)$", filename):
         | 
| 479 | 
            -
                        doc["parser_id"] = ParserType.PRESENTATION.value
         | 
| 480 | 
            -
                    DocumentService.insert(doc)
         | 
| 481 | 
            -
                    FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
         | 
| 482 | 
            -
                except Exception as e:
         | 
| 483 | 
            -
                    return get_json_result(
         | 
| 484 | 
            -
                        data=False, retmsg=e, retcode=RetCode.SERVER_ERROR)
         | 
| 485 | 
            -
                return get_json_result(data=True)
         | 
|  | |
| 40 | 
             
            from rag.utils.minio_conn import MINIO
         | 
| 41 | 
             
            from api.utils.file_utils import filename_type, thumbnail
         | 
| 42 | 
             
            from api.utils.web_utils import html2pdf, is_valid_url
         | 
| 43 | 
            +
            from api.utils.web_utils import html2pdf, is_valid_url
         | 
| 44 |  | 
| 45 |  | 
| 46 | 
             
            @manager.route('/upload', methods=['POST'])
         | 
|  | |
| 118 | 
             
                return get_json_result(data=True)
         | 
| 119 |  | 
| 120 |  | 
| 121 | 
            +
            @manager.route('/web_crawl', methods=['POST'])
         | 
| 122 | 
            +
            @login_required
         | 
| 123 | 
            +
            @validate_request("kb_id", "name", "url")
         | 
| 124 | 
            +
            def web_crawl():
         | 
| 125 | 
            +
                kb_id = request.form.get("kb_id")
         | 
| 126 | 
            +
                if not kb_id:
         | 
| 127 | 
            +
                    return get_json_result(
         | 
| 128 | 
            +
                        data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
         | 
| 129 | 
            +
                name = request.form.get("name")
         | 
| 130 | 
            +
                url = request.form.get("url")
         | 
| 131 | 
            +
                if not is_valid_url(url):
         | 
| 132 | 
            +
                    return get_json_result(
         | 
| 133 | 
            +
                        data=False, retmsg='The URL format is invalid', retcode=RetCode.ARGUMENT_ERROR)
         | 
| 134 | 
            +
                e, kb = KnowledgebaseService.get_by_id(kb_id)
         | 
| 135 | 
            +
                if not e:
         | 
| 136 | 
            +
                    raise LookupError("Can't find this knowledgebase!")
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                blob = html2pdf(url)
         | 
| 139 | 
            +
                if not blob: return server_error_response(ValueError("Download failure."))
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                root_folder = FileService.get_root_folder(current_user.id)
         | 
| 142 | 
            +
                pf_id = root_folder["id"]
         | 
| 143 | 
            +
                FileService.init_knowledgebase_docs(pf_id, current_user.id)
         | 
| 144 | 
            +
                kb_root_folder = FileService.get_kb_folder(current_user.id)
         | 
| 145 | 
            +
                kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                try:
         | 
| 148 | 
            +
                    filename = duplicate_name(
         | 
| 149 | 
            +
                        DocumentService.query,
         | 
| 150 | 
            +
                        name=name+".pdf",
         | 
| 151 | 
            +
                        kb_id=kb.id)
         | 
| 152 | 
            +
                    filetype = filename_type(filename)
         | 
| 153 | 
            +
                    if filetype == FileType.OTHER.value:
         | 
| 154 | 
            +
                        raise RuntimeError("This type of file has not been supported yet!")
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                    location = filename
         | 
| 157 | 
            +
                    while MINIO.obj_exist(kb_id, location):
         | 
| 158 | 
            +
                        location += "_"
         | 
| 159 | 
            +
                    MINIO.put(kb_id, location, blob)
         | 
| 160 | 
            +
                    doc = {
         | 
| 161 | 
            +
                        "id": get_uuid(),
         | 
| 162 | 
            +
                        "kb_id": kb.id,
         | 
| 163 | 
            +
                        "parser_id": kb.parser_id,
         | 
| 164 | 
            +
                        "parser_config": kb.parser_config,
         | 
| 165 | 
            +
                        "created_by": current_user.id,
         | 
| 166 | 
            +
                        "type": filetype,
         | 
| 167 | 
            +
                        "name": filename,
         | 
| 168 | 
            +
                        "location": location,
         | 
| 169 | 
            +
                        "size": len(blob),
         | 
| 170 | 
            +
                        "thumbnail": thumbnail(filename, blob)
         | 
| 171 | 
            +
                    }
         | 
| 172 | 
            +
                    if doc["type"] == FileType.VISUAL:
         | 
| 173 | 
            +
                        doc["parser_id"] = ParserType.PICTURE.value
         | 
| 174 | 
            +
                    if re.search(r"\.(ppt|pptx|pages)$", filename):
         | 
| 175 | 
            +
                        doc["parser_id"] = ParserType.PRESENTATION.value
         | 
| 176 | 
            +
                    DocumentService.insert(doc)
         | 
| 177 | 
            +
                    FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
         | 
| 178 | 
            +
                except Exception as e:
         | 
| 179 | 
            +
                    return server_error_response(e)
         | 
| 180 | 
            +
                return get_json_result(data=True)
         | 
| 181 | 
            +
             | 
| 182 | 
            +
             | 
| 183 | 
             
            @manager.route('/create', methods=['POST'])
         | 
| 184 | 
             
            @login_required
         | 
| 185 | 
             
            @validate_request("name", "kb_id")
         | 
|  | |
| 480 | 
             
                    return response
         | 
| 481 | 
             
                except Exception as e:
         | 
| 482 | 
             
                    return server_error_response(e)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        api/db/services/dialog_service.py
    CHANGED
    
    | @@ -112,14 +112,15 @@ def chat(dialog, messages, stream=True, **kwargs): | |
| 112 | 
             
                        prompt_config["system"] = prompt_config["system"].replace(
         | 
| 113 | 
             
                            "{%s}" % p["key"], " ")
         | 
| 114 |  | 
|  | |
|  | |
|  | |
|  | |
| 115 | 
             
                for _ in range(len(questions) // 2):
         | 
| 116 | 
             
                    questions.append(questions[-1])
         | 
| 117 | 
             
                if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
         | 
| 118 | 
             
                    kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
         | 
| 119 | 
             
                else:
         | 
| 120 | 
            -
                    rerank_mdl = None
         | 
| 121 | 
            -
                    if dialog.rerank_id:
         | 
| 122 | 
            -
                        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
         | 
| 123 | 
             
                    kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
         | 
| 124 | 
             
                                                    dialog.similarity_threshold,
         | 
| 125 | 
             
                                                    dialog.vector_similarity_weight,
         | 
|  | |
| 112 | 
             
                        prompt_config["system"] = prompt_config["system"].replace(
         | 
| 113 | 
             
                            "{%s}" % p["key"], " ")
         | 
| 114 |  | 
| 115 | 
            +
                rerank_mdl = None
         | 
| 116 | 
            +
                if dialog.rerank_id:
         | 
| 117 | 
            +
                    rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
         | 
| 118 | 
            +
             | 
| 119 | 
             
                for _ in range(len(questions) // 2):
         | 
| 120 | 
             
                    questions.append(questions[-1])
         | 
| 121 | 
             
                if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
         | 
| 122 | 
             
                    kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
         | 
| 123 | 
             
                else:
         | 
|  | |
|  | |
|  | |
| 124 | 
             
                    kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
         | 
| 125 | 
             
                                                    dialog.similarity_threshold,
         | 
| 126 | 
             
                                                    dialog.vector_similarity_weight,
         | 
    	
        api/utils/api_utils.py
    CHANGED
    
    | @@ -248,11 +248,12 @@ def construct_result(code=RetCode.DATA_ERROR, message='data is missing'): | |
| 248 |  | 
| 249 |  | 
| 250 | 
             
            def construct_json_result(code=RetCode.SUCCESS, message='success', data=None):
         | 
| 251 | 
            -
                if data  | 
| 252 | 
             
                    return jsonify({"code": code, "message": message})
         | 
| 253 | 
             
                else:
         | 
| 254 | 
             
                    return jsonify({"code": code, "message": message, "data": data})
         | 
| 255 |  | 
|  | |
| 256 | 
             
            def construct_error_response(e):
         | 
| 257 | 
             
                stat_logger.exception(e)
         | 
| 258 | 
             
                try:
         | 
|  | |
| 248 |  | 
| 249 |  | 
| 250 | 
             
            def construct_json_result(code=RetCode.SUCCESS, message='success', data=None):
         | 
| 251 | 
            +
                if data is None:
         | 
| 252 | 
             
                    return jsonify({"code": code, "message": message})
         | 
| 253 | 
             
                else:
         | 
| 254 | 
             
                    return jsonify({"code": code, "message": message, "data": data})
         | 
| 255 |  | 
| 256 | 
            +
             | 
| 257 | 
             
            def construct_error_response(e):
         | 
| 258 | 
             
                stat_logger.exception(e)
         | 
| 259 | 
             
                try:
         | 
    	
        api/utils/log_utils.py
    CHANGED
    
    | @@ -154,11 +154,6 @@ class LoggerFactory(object): | |
| 154 | 
             
                                                           delay=True)
         | 
| 155 | 
             
                    if level:
         | 
| 156 | 
             
                        handler.level = level
         | 
| 157 | 
            -
                    else:
         | 
| 158 | 
            -
                        handler.level = LoggerFactory.LEVEL
         | 
| 159 | 
            -
             | 
| 160 | 
            -
                    formatter = logging.Formatter(LoggerFactory.LOG_FORMAT)
         | 
| 161 | 
            -
                    handler.setFormatter(formatter)
         | 
| 162 |  | 
| 163 | 
             
                    return handler
         | 
| 164 |  | 
|  | |
| 154 | 
             
                                                           delay=True)
         | 
| 155 | 
             
                    if level:
         | 
| 156 | 
             
                        handler.level = level
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 157 |  | 
| 158 | 
             
                    return handler
         | 
| 159 |  | 
    	
        api/utils/web_utils.py
    CHANGED
    
    | @@ -78,5 +78,3 @@ def __get_pdf_from_html( | |
| 78 |  | 
| 79 | 
             
            def is_valid_url(url: str) -> bool:
         | 
| 80 | 
             
                return bool(re.match(r"(https?|ftp|file)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]", url))
         | 
| 81 | 
            -
             | 
| 82 | 
            -
             | 
|  | |
| 78 |  | 
| 79 | 
             
            def is_valid_url(url: str) -> bool:
         | 
| 80 | 
             
                return bool(re.match(r"(https?|ftp|file)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]", url))
         | 
|  | |
|  | 
    	
        rag/llm/embedding_model.py
    CHANGED
    
    | @@ -26,9 +26,8 @@ import dashscope | |
| 26 | 
             
            from openai import OpenAI
         | 
| 27 | 
             
            from FlagEmbedding import FlagModel
         | 
| 28 | 
             
            import torch
         | 
| 29 | 
            -
            import asyncio
         | 
| 30 | 
             
            import numpy as np
         | 
| 31 | 
            -
             | 
| 32 | 
             
            from api.utils.file_utils import get_home_cache_dir
         | 
| 33 | 
             
            from rag.utils import num_tokens_from_string, truncate
         | 
| 34 |  | 
| @@ -317,12 +316,12 @@ class InfinityEmbed(Base): | |
| 317 | 
             
                        engine_kwargs: dict = {},
         | 
| 318 | 
             
                        key = None,
         | 
| 319 | 
             
                ):
         | 
| 320 | 
            -
             | 
| 321 | 
             
                    from infinity_emb import EngineArgs
         | 
| 322 | 
             
                    from infinity_emb.engine import AsyncEngineArray
         | 
| 323 | 
            -
             | 
| 324 | 
             
                    self._default_model = model_names[0]
         | 
| 325 | 
            -
                    self.engine_array = AsyncEngineArray.from_args([EngineArgs(model_name_or_path = model_name, **engine_kwargs) for model_name in model_names]) | 
| 326 |  | 
| 327 | 
             
                async def _embed(self, sentences: list[str], model_name: str = ""):
         | 
| 328 | 
             
                    if not model_name:
         | 
|  | |
| 26 | 
             
            from openai import OpenAI
         | 
| 27 | 
             
            from FlagEmbedding import FlagModel
         | 
| 28 | 
             
            import torch
         | 
|  | |
| 29 | 
             
            import numpy as np
         | 
| 30 | 
            +
            import asyncio
         | 
| 31 | 
             
            from api.utils.file_utils import get_home_cache_dir
         | 
| 32 | 
             
            from rag.utils import num_tokens_from_string, truncate
         | 
| 33 |  | 
|  | |
| 316 | 
             
                        engine_kwargs: dict = {},
         | 
| 317 | 
             
                        key = None,
         | 
| 318 | 
             
                ):
         | 
| 319 | 
            +
             | 
| 320 | 
             
                    from infinity_emb import EngineArgs
         | 
| 321 | 
             
                    from infinity_emb.engine import AsyncEngineArray
         | 
| 322 | 
            +
             | 
| 323 | 
             
                    self._default_model = model_names[0]
         | 
| 324 | 
            +
                    self.engine_array = AsyncEngineArray.from_args([EngineArgs(model_name_or_path = model_name, **engine_kwargs) for model_name in model_names])
         | 
| 325 |  | 
| 326 | 
             
                async def _embed(self, sentences: list[str], model_name: str = ""):
         | 
| 327 | 
             
                    if not model_name:
         |