KevinHuSh
		
	commited on
		
		
					Commit 
							
							·
						
						4c52eb9
	
1
								Parent(s):
							
							3772f42
								
refine admin initialization (#75)
Browse files- api/apps/chunk_app.py +2 -2
- api/apps/conversation_app.py +1 -3
- api/db/init_data.py +42 -4
- api/settings.py +5 -1
- deepdoc/parser/pdf_parser.py +1 -1
- deepdoc/vision/layout_recognizer.py +1 -1
- deepdoc/vision/postprocess.py +2 -3
- deepdoc/vision/recognizer.py +12 -0
- deepdoc/vision/t_recognizer.py +3 -1
- deepdoc/vision/table_structure_recognizer.py +5 -5
- rag/llm/chat_model.py +11 -8
- rag/nlp/__init__.py +2 -3
- rag/nlp/search.py +4 -2
    	
        api/apps/chunk_app.py
    CHANGED
    
    | @@ -20,7 +20,7 @@ from flask_login import login_required, current_user | |
| 20 | 
             
            from elasticsearch_dsl import Q
         | 
| 21 |  | 
| 22 | 
             
            from rag.app.qa import rmPrefix, beAdoc
         | 
| 23 | 
            -
            from rag.nlp import search, huqie | 
| 24 | 
             
            from rag.utils import ELASTICSEARCH, rmSpace
         | 
| 25 | 
             
            from api.db import LLMType, ParserType
         | 
| 26 | 
             
            from api.db.services.knowledgebase_service import KnowledgebaseService
         | 
| @@ -28,7 +28,7 @@ from api.db.services.llm_service import TenantLLMService | |
| 28 | 
             
            from api.db.services.user_service import UserTenantService
         | 
| 29 | 
             
            from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
         | 
| 30 | 
             
            from api.db.services.document_service import DocumentService
         | 
| 31 | 
            -
            from api.settings import RetCode
         | 
| 32 | 
             
            from api.utils.api_utils import get_json_result
         | 
| 33 | 
             
            import hashlib
         | 
| 34 | 
             
            import re
         | 
|  | |
| 20 | 
             
            from elasticsearch_dsl import Q
         | 
| 21 |  | 
| 22 | 
             
            from rag.app.qa import rmPrefix, beAdoc
         | 
| 23 | 
            +
            from rag.nlp import search, huqie
         | 
| 24 | 
             
            from rag.utils import ELASTICSEARCH, rmSpace
         | 
| 25 | 
             
            from api.db import LLMType, ParserType
         | 
| 26 | 
             
            from api.db.services.knowledgebase_service import KnowledgebaseService
         | 
|  | |
| 28 | 
             
            from api.db.services.user_service import UserTenantService
         | 
| 29 | 
             
            from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
         | 
| 30 | 
             
            from api.db.services.document_service import DocumentService
         | 
| 31 | 
            +
            from api.settings import RetCode, retrievaler
         | 
| 32 | 
             
            from api.utils.api_utils import get_json_result
         | 
| 33 | 
             
            import hashlib
         | 
| 34 | 
             
            import re
         | 
    	
        api/apps/conversation_app.py
    CHANGED
    
    | @@ -21,13 +21,11 @@ from api.db.services.dialog_service import DialogService, ConversationService | |
| 21 | 
             
            from api.db import LLMType
         | 
| 22 | 
             
            from api.db.services.knowledgebase_service import KnowledgebaseService
         | 
| 23 | 
             
            from api.db.services.llm_service import LLMService, LLMBundle
         | 
| 24 | 
            -
            from api.settings import access_logger, stat_logger
         | 
| 25 | 
             
            from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
         | 
| 26 | 
             
            from api.utils import get_uuid
         | 
| 27 | 
             
            from api.utils.api_utils import get_json_result
         | 
| 28 | 
             
            from rag.app.resume import forbidden_select_fields4resume
         | 
| 29 | 
            -
            from rag.llm import ChatModel
         | 
| 30 | 
            -
            from rag.nlp import retrievaler
         | 
| 31 | 
             
            from rag.nlp.search import index_name
         | 
| 32 | 
             
            from rag.utils import num_tokens_from_string, encoder, rmSpace
         | 
| 33 |  | 
|  | |
| 21 | 
             
            from api.db import LLMType
         | 
| 22 | 
             
            from api.db.services.knowledgebase_service import KnowledgebaseService
         | 
| 23 | 
             
            from api.db.services.llm_service import LLMService, LLMBundle
         | 
| 24 | 
            +
            from api.settings import access_logger, stat_logger, retrievaler
         | 
| 25 | 
             
            from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
         | 
| 26 | 
             
            from api.utils import get_uuid
         | 
| 27 | 
             
            from api.utils.api_utils import get_json_result
         | 
| 28 | 
             
            from rag.app.resume import forbidden_select_fields4resume
         | 
|  | |
|  | |
| 29 | 
             
            from rag.nlp.search import index_name
         | 
| 30 | 
             
            from rag.utils import num_tokens_from_string, encoder, rmSpace
         | 
| 31 |  | 
    	
        api/db/init_data.py
    CHANGED
    
    | @@ -16,10 +16,12 @@ | |
| 16 | 
             
            import time
         | 
| 17 | 
             
            import uuid
         | 
| 18 |  | 
| 19 | 
            -
            from api.db import LLMType
         | 
| 20 | 
             
            from api.db.db_models import init_database_tables as init_web_db
         | 
| 21 | 
             
            from api.db.services import UserService
         | 
| 22 | 
            -
            from api.db.services.llm_service import LLMFactoriesService, LLMService
         | 
|  | |
|  | |
| 23 |  | 
| 24 |  | 
| 25 | 
             
            def init_superuser():
         | 
| @@ -32,8 +34,44 @@ def init_superuser(): | |
| 32 | 
             
                    "creator": "system",
         | 
| 33 | 
             
                    "status": "1",
         | 
| 34 | 
             
                }
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 35 | 
             
                UserService.save(**user_info)
         | 
| 36 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 37 |  | 
| 38 | 
             
            def init_llm_factory():
         | 
| 39 | 
             
                factory_infos = [{
         | 
| @@ -171,10 +209,10 @@ def init_llm_factory(): | |
| 171 |  | 
| 172 | 
             
            def init_web_data():
         | 
| 173 | 
             
                start_time = time.time()
         | 
| 174 | 
            -
                if not UserService.get_all().count():
         | 
| 175 | 
            -
                    init_superuser()
         | 
| 176 |  | 
| 177 | 
             
                if not LLMService.get_all().count():init_llm_factory()
         | 
|  | |
|  | |
| 178 |  | 
| 179 | 
             
                print("init web data success:{}".format(time.time() - start_time))
         | 
| 180 |  | 
|  | |
| 16 | 
             
            import time
         | 
| 17 | 
             
            import uuid
         | 
| 18 |  | 
| 19 | 
            +
            from api.db import LLMType, UserTenantRole
         | 
| 20 | 
             
            from api.db.db_models import init_database_tables as init_web_db
         | 
| 21 | 
             
            from api.db.services import UserService
         | 
| 22 | 
            +
            from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
         | 
| 23 | 
            +
            from api.db.services.user_service import TenantService, UserTenantService
         | 
| 24 | 
            +
            from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY
         | 
| 25 |  | 
| 26 |  | 
| 27 | 
             
            def init_superuser():
         | 
|  | |
| 34 | 
             
                    "creator": "system",
         | 
| 35 | 
             
                    "status": "1",
         | 
| 36 | 
             
                }
         | 
| 37 | 
            +
                tenant = {
         | 
| 38 | 
            +
                    "id": user_info["id"],
         | 
| 39 | 
            +
                    "name": user_info["nickname"] + "‘s Kingdom",
         | 
| 40 | 
            +
                    "llm_id": CHAT_MDL,
         | 
| 41 | 
            +
                    "embd_id": EMBEDDING_MDL,
         | 
| 42 | 
            +
                    "asr_id": ASR_MDL,
         | 
| 43 | 
            +
                    "parser_ids": PARSERS,
         | 
| 44 | 
            +
                    "img2txt_id": IMAGE2TEXT_MDL
         | 
| 45 | 
            +
                }
         | 
| 46 | 
            +
                usr_tenant = {
         | 
| 47 | 
            +
                    "tenant_id": user_info["id"],
         | 
| 48 | 
            +
                    "user_id": user_info["id"],
         | 
| 49 | 
            +
                    "invited_by": user_info["id"],
         | 
| 50 | 
            +
                    "role": UserTenantRole.OWNER
         | 
| 51 | 
            +
                }
         | 
| 52 | 
            +
                tenant_llm = []
         | 
| 53 | 
            +
                for llm in LLMService.query(fid=LLM_FACTORY):
         | 
| 54 | 
            +
                    tenant_llm.append(
         | 
| 55 | 
            +
                        {"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
         | 
| 56 | 
            +
                         "api_key": API_KEY})
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                if not UserService.save(**user_info):
         | 
| 59 | 
            +
                    print("【ERROR】can't init admin.")
         | 
| 60 | 
            +
                    return
         | 
| 61 | 
            +
                TenantService.save(**tenant)
         | 
| 62 | 
            +
                UserTenantService.save(**usr_tenant)
         | 
| 63 | 
            +
                TenantLLMService.insert_many(tenant_llm)
         | 
| 64 | 
             
                UserService.save(**user_info)
         | 
| 65 |  | 
| 66 | 
            +
                chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
         | 
| 67 | 
            +
                msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
         | 
| 68 | 
            +
                if msg.find("ERROR: ") == 0:
         | 
| 69 | 
            +
                    print("【ERROR】: '{}' dosen't work. {}".format(tenant["llm_id"]), msg)
         | 
| 70 | 
            +
                embd_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["embd_id"])
         | 
| 71 | 
            +
                v,c = embd_mdl.encode(["Hello!"])
         | 
| 72 | 
            +
                if c == 0:
         | 
| 73 | 
            +
                    print("【ERROR】: '{}' dosen't work...".format(tenant["embd_id"]))
         | 
| 74 | 
            +
             | 
| 75 |  | 
| 76 | 
             
            def init_llm_factory():
         | 
| 77 | 
             
                factory_infos = [{
         | 
|  | |
| 209 |  | 
| 210 | 
             
            def init_web_data():
         | 
| 211 | 
             
                start_time = time.time()
         | 
|  | |
|  | |
| 212 |  | 
| 213 | 
             
                if not LLMService.get_all().count():init_llm_factory()
         | 
| 214 | 
            +
                if not UserService.get_all().count():
         | 
| 215 | 
            +
                    init_superuser()
         | 
| 216 |  | 
| 217 | 
             
                print("init web data success:{}".format(time.time() - start_time))
         | 
| 218 |  | 
    	
        api/settings.py
    CHANGED
    
    | @@ -21,8 +21,10 @@ from api.utils import get_base_config,decrypt_database_config | |
| 21 | 
             
            from api.utils.file_utils import get_project_base_directory
         | 
| 22 | 
             
            from api.utils.log_utils import LoggerFactory, getLogger
         | 
| 23 |  | 
|  | |
|  | |
|  | |
| 24 |  | 
| 25 | 
            -
            # Server
         | 
| 26 | 
             
            API_VERSION = "v1"
         | 
| 27 | 
             
            RAG_FLOW_SERVICE_NAME = "ragflow"
         | 
| 28 | 
             
            SERVER_MODULE = "rag_flow_server.py"
         | 
| @@ -116,6 +118,8 @@ AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60 # s | |
| 116 | 
             
            PRIVILEGE_COMMAND_WHITELIST = []
         | 
| 117 | 
             
            CHECK_NODES_IDENTITY = False
         | 
| 118 |  | 
|  | |
|  | |
| 119 | 
             
            class CustomEnum(Enum):
         | 
| 120 | 
             
                @classmethod
         | 
| 121 | 
             
                def valid(cls, value):
         | 
|  | |
| 21 | 
             
            from api.utils.file_utils import get_project_base_directory
         | 
| 22 | 
             
            from api.utils.log_utils import LoggerFactory, getLogger
         | 
| 23 |  | 
| 24 | 
            +
            from rag.nlp import search
         | 
| 25 | 
            +
            from rag.utils import ELASTICSEARCH
         | 
| 26 | 
            +
             | 
| 27 |  | 
|  | |
| 28 | 
             
            API_VERSION = "v1"
         | 
| 29 | 
             
            RAG_FLOW_SERVICE_NAME = "ragflow"
         | 
| 30 | 
             
            SERVER_MODULE = "rag_flow_server.py"
         | 
|  | |
| 118 | 
             
            PRIVILEGE_COMMAND_WHITELIST = []
         | 
| 119 | 
             
            CHECK_NODES_IDENTITY = False
         | 
| 120 |  | 
| 121 | 
            +
            retrievaler = search.Dealer(ELASTICSEARCH)
         | 
| 122 | 
            +
             | 
| 123 | 
             
            class CustomEnum(Enum):
         | 
| 124 | 
             
                @classmethod
         | 
| 125 | 
             
                def valid(cls, value):
         | 
    	
        deepdoc/parser/pdf_parser.py
    CHANGED
    
    | @@ -230,7 +230,7 @@ class HuParser: | |
| 230 | 
             
                            b["H_right"] = headers[ii]["x1"]
         | 
| 231 | 
             
                            b["H"] = ii
         | 
| 232 |  | 
| 233 | 
            -
                        ii = Recognizer. | 
| 234 | 
             
                        if ii is not None:
         | 
| 235 | 
             
                            b["C"] = ii
         | 
| 236 | 
             
                            b["C_left"] = clmns[ii]["x0"]
         | 
|  | |
| 230 | 
             
                            b["H_right"] = headers[ii]["x1"]
         | 
| 231 | 
             
                            b["H"] = ii
         | 
| 232 |  | 
| 233 | 
            +
                        ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
         | 
| 234 | 
             
                        if ii is not None:
         | 
| 235 | 
             
                            b["C"] = ii
         | 
| 236 | 
             
                            b["C_left"] = clmns[ii]["x0"]
         | 
    	
        deepdoc/vision/layout_recognizer.py
    CHANGED
    
    | @@ -37,7 +37,7 @@ class LayoutRecognizer(Recognizer): | |
| 37 | 
             
                    super().__init__(self.labels, domain,
         | 
| 38 | 
             
                                     os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
         | 
| 39 |  | 
| 40 | 
            -
                def __call__(self, image_list, ocr_res, scale_factor=3, thr=0. | 
| 41 | 
             
                    def __is_garbage(b):
         | 
| 42 | 
             
                        patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
         | 
| 43 | 
             
                                r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
         | 
|  | |
| 37 | 
             
                    super().__init__(self.labels, domain,
         | 
| 38 | 
             
                                     os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
         | 
| 39 |  | 
| 40 | 
            +
                def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
         | 
| 41 | 
             
                    def __is_garbage(b):
         | 
| 42 | 
             
                        patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
         | 
| 43 | 
             
                                r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
         | 
    	
        deepdoc/vision/postprocess.py
    CHANGED
    
    | @@ -2,7 +2,6 @@ import copy | |
| 2 |  | 
| 3 | 
             
            import numpy as np
         | 
| 4 | 
             
            import cv2
         | 
| 5 | 
            -
            import paddle
         | 
| 6 | 
             
            from shapely.geometry import Polygon
         | 
| 7 | 
             
            import pyclipper
         | 
| 8 |  | 
| @@ -215,7 +214,7 @@ class DBPostProcess(object): | |
| 215 |  | 
| 216 | 
             
                def __call__(self, outs_dict, shape_list):
         | 
| 217 | 
             
                    pred = outs_dict['maps']
         | 
| 218 | 
            -
                    if isinstance(pred,  | 
| 219 | 
             
                        pred = pred.numpy()
         | 
| 220 | 
             
                    pred = pred[:, 0, :, :]
         | 
| 221 | 
             
                    segmentation = pred > self.thresh
         | 
| @@ -339,7 +338,7 @@ class CTCLabelDecode(BaseRecLabelDecode): | |
| 339 | 
             
                def __call__(self, preds, label=None, *args, **kwargs):
         | 
| 340 | 
             
                    if isinstance(preds, tuple) or isinstance(preds, list):
         | 
| 341 | 
             
                        preds = preds[-1]
         | 
| 342 | 
            -
                    if isinstance(preds,  | 
| 343 | 
             
                        preds = preds.numpy()
         | 
| 344 | 
             
                    preds_idx = preds.argmax(axis=2)
         | 
| 345 | 
             
                    preds_prob = preds.max(axis=2)
         | 
|  | |
| 2 |  | 
| 3 | 
             
            import numpy as np
         | 
| 4 | 
             
            import cv2
         | 
|  | |
| 5 | 
             
            from shapely.geometry import Polygon
         | 
| 6 | 
             
            import pyclipper
         | 
| 7 |  | 
|  | |
| 214 |  | 
| 215 | 
             
                def __call__(self, outs_dict, shape_list):
         | 
| 216 | 
             
                    pred = outs_dict['maps']
         | 
| 217 | 
            +
                    if not isinstance(pred, np.ndarray):
         | 
| 218 | 
             
                        pred = pred.numpy()
         | 
| 219 | 
             
                    pred = pred[:, 0, :, :]
         | 
| 220 | 
             
                    segmentation = pred > self.thresh
         | 
|  | |
| 338 | 
             
                def __call__(self, preds, label=None, *args, **kwargs):
         | 
| 339 | 
             
                    if isinstance(preds, tuple) or isinstance(preds, list):
         | 
| 340 | 
             
                        preds = preds[-1]
         | 
| 341 | 
            +
                    if not isinstance(preds, np.ndarray):
         | 
| 342 | 
             
                        preds = preds.numpy()
         | 
| 343 | 
             
                    preds_idx = preds.argmax(axis=2)
         | 
| 344 | 
             
                    preds_prob = preds.max(axis=2)
         | 
    	
        deepdoc/vision/recognizer.py
    CHANGED
    
    | @@ -259,6 +259,18 @@ class Recognizer(object): | |
| 259 |  | 
| 260 | 
             
                    return max_overlaped_i
         | 
| 261 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 262 | 
             
                @staticmethod
         | 
| 263 | 
             
                def find_overlapped_with_threashold(box, boxes, thr=0.3):
         | 
| 264 | 
             
                    if not boxes:
         | 
|  | |
| 259 |  | 
| 260 | 
             
                    return max_overlaped_i
         | 
| 261 |  | 
| 262 | 
            +
                @staticmethod
         | 
| 263 | 
            +
                def find_horizontally_tightest_fit(box, boxes):
         | 
| 264 | 
            +
                    if not boxes:
         | 
| 265 | 
            +
                        return
         | 
| 266 | 
            +
                    min_dis, min_i = 1000000, None
         | 
| 267 | 
            +
                    for i,b in enumerate(boxes):
         | 
| 268 | 
            +
                        dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"]+box["x1"] - b["x1"] - b["x0"])/2)
         | 
| 269 | 
            +
                        if dis < min_dis:
         | 
| 270 | 
            +
                            min_i = i
         | 
| 271 | 
            +
                            min_dis = dis
         | 
| 272 | 
            +
                    return min_i
         | 
| 273 | 
            +
             | 
| 274 | 
             
                @staticmethod
         | 
| 275 | 
             
                def find_overlapped_with_threashold(box, boxes, thr=0.3):
         | 
| 276 | 
             
                    if not boxes:
         | 
    	
        deepdoc/vision/t_recognizer.py
    CHANGED
    
    | @@ -74,6 +74,7 @@ def get_table_html(img, tb_cpns, ocr): | |
| 74 | 
             
                clmns = sorted([r for r in tb_cpns if re.match(
         | 
| 75 | 
             
                    r"table column$", r["label"])], key=lambda x: x["x0"])
         | 
| 76 | 
             
                clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)
         | 
|  | |
| 77 | 
             
                for b in boxes:
         | 
| 78 | 
             
                    ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
         | 
| 79 | 
             
                    if ii is not None:
         | 
| @@ -89,7 +90,7 @@ def get_table_html(img, tb_cpns, ocr): | |
| 89 | 
             
                        b["H_right"] = headers[ii]["x1"]
         | 
| 90 | 
             
                        b["H"] = ii
         | 
| 91 |  | 
| 92 | 
            -
                    ii = Recognizer. | 
| 93 | 
             
                    if ii is not None:
         | 
| 94 | 
             
                        b["C"] = ii
         | 
| 95 | 
             
                        b["C_left"] = clmns[ii]["x0"]
         | 
| @@ -102,6 +103,7 @@ def get_table_html(img, tb_cpns, ocr): | |
| 102 | 
             
                        b["H_left"] = spans[ii]["x0"]
         | 
| 103 | 
             
                        b["H_right"] = spans[ii]["x1"]
         | 
| 104 | 
             
                        b["SP"] = ii
         | 
|  | |
| 105 | 
             
                html = """
         | 
| 106 | 
             
                <html>
         | 
| 107 | 
             
                <head>
         | 
|  | |
| 74 | 
             
                clmns = sorted([r for r in tb_cpns if re.match(
         | 
| 75 | 
             
                    r"table column$", r["label"])], key=lambda x: x["x0"])
         | 
| 76 | 
             
                clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)
         | 
| 77 | 
            +
             | 
| 78 | 
             
                for b in boxes:
         | 
| 79 | 
             
                    ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
         | 
| 80 | 
             
                    if ii is not None:
         | 
|  | |
| 90 | 
             
                        b["H_right"] = headers[ii]["x1"]
         | 
| 91 | 
             
                        b["H"] = ii
         | 
| 92 |  | 
| 93 | 
            +
                    ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
         | 
| 94 | 
             
                    if ii is not None:
         | 
| 95 | 
             
                        b["C"] = ii
         | 
| 96 | 
             
                        b["C_left"] = clmns[ii]["x0"]
         | 
|  | |
| 103 | 
             
                        b["H_left"] = spans[ii]["x0"]
         | 
| 104 | 
             
                        b["H_right"] = spans[ii]["x1"]
         | 
| 105 | 
             
                        b["SP"] = ii
         | 
| 106 | 
            +
             | 
| 107 | 
             
                html = """
         | 
| 108 | 
             
                <html>
         | 
| 109 | 
             
                <head>
         | 
    	
        deepdoc/vision/table_structure_recognizer.py
    CHANGED
    
    | @@ -14,7 +14,6 @@ import logging | |
| 14 | 
             
            import os
         | 
| 15 | 
             
            import re
         | 
| 16 | 
             
            from collections import Counter
         | 
| 17 | 
            -
            from copy import deepcopy
         | 
| 18 |  | 
| 19 | 
             
            import numpy as np
         | 
| 20 |  | 
| @@ -37,7 +36,7 @@ class TableStructureRecognizer(Recognizer): | |
| 37 | 
             
                    super().__init__(self.labels, "tsr",
         | 
| 38 | 
             
                                     os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
         | 
| 39 |  | 
| 40 | 
            -
                def __call__(self, images, thr=0. | 
| 41 | 
             
                    tbls = super().__call__(images, thr)
         | 
| 42 | 
             
                    res = []
         | 
| 43 | 
             
                    # align left&right for rows, align top&bottom for columns
         | 
| @@ -56,8 +55,8 @@ class TableStructureRecognizer(Recognizer): | |
| 56 | 
             
                            "row") > 0 or b["label"].find("header") > 0]
         | 
| 57 | 
             
                        if not left:
         | 
| 58 | 
             
                            continue
         | 
| 59 | 
            -
                        left = np. | 
| 60 | 
            -
                        right = np. | 
| 61 | 
             
                        for b in lts:
         | 
| 62 | 
             
                            if b["label"].find("row") > 0 or b["label"].find("header") > 0:
         | 
| 63 | 
             
                                if b["x0"] > left:
         | 
| @@ -129,6 +128,7 @@ class TableStructureRecognizer(Recognizer): | |
| 129 | 
             
                    i = 0
         | 
| 130 | 
             
                    while i < len(boxes):
         | 
| 131 | 
             
                        if TableStructureRecognizer.is_caption(boxes[i]):
         | 
|  | |
| 132 | 
             
                            cap += boxes[i]["text"]
         | 
| 133 | 
             
                            boxes.pop(i)
         | 
| 134 | 
             
                            i -= 1
         | 
| @@ -398,7 +398,7 @@ class TableStructureRecognizer(Recognizer): | |
| 398 | 
             
                        for i in range(clmno):
         | 
| 399 | 
             
                            if not tbl[r][i]:
         | 
| 400 | 
             
                                continue
         | 
| 401 | 
            -
                            txt = "".join([a["text"].strip() for a in tbl[r][i]])
         | 
| 402 | 
             
                            headers[r][i] = txt
         | 
| 403 | 
             
                            hdrset.add(txt)
         | 
| 404 | 
             
                        if all([not t for t in headers[r]]):
         | 
|  | |
| 14 | 
             
            import os
         | 
| 15 | 
             
            import re
         | 
| 16 | 
             
            from collections import Counter
         | 
|  | |
| 17 |  | 
| 18 | 
             
            import numpy as np
         | 
| 19 |  | 
|  | |
| 36 | 
             
                    super().__init__(self.labels, "tsr",
         | 
| 37 | 
             
                                     os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
         | 
| 38 |  | 
| 39 | 
            +
                def __call__(self, images, thr=0.2):
         | 
| 40 | 
             
                    tbls = super().__call__(images, thr)
         | 
| 41 | 
             
                    res = []
         | 
| 42 | 
             
                    # align left&right for rows, align top&bottom for columns
         | 
|  | |
| 55 | 
             
                            "row") > 0 or b["label"].find("header") > 0]
         | 
| 56 | 
             
                        if not left:
         | 
| 57 | 
             
                            continue
         | 
| 58 | 
            +
                        left = np.mean(left) if len(left) > 4 else np.min(left)
         | 
| 59 | 
            +
                        right = np.mean(right) if len(right) > 4 else np.max(right)
         | 
| 60 | 
             
                        for b in lts:
         | 
| 61 | 
             
                            if b["label"].find("row") > 0 or b["label"].find("header") > 0:
         | 
| 62 | 
             
                                if b["x0"] > left:
         | 
|  | |
| 128 | 
             
                    i = 0
         | 
| 129 | 
             
                    while i < len(boxes):
         | 
| 130 | 
             
                        if TableStructureRecognizer.is_caption(boxes[i]):
         | 
| 131 | 
            +
                            if is_english: cap + " "
         | 
| 132 | 
             
                            cap += boxes[i]["text"]
         | 
| 133 | 
             
                            boxes.pop(i)
         | 
| 134 | 
             
                            i -= 1
         | 
|  | |
| 398 | 
             
                        for i in range(clmno):
         | 
| 399 | 
             
                            if not tbl[r][i]:
         | 
| 400 | 
             
                                continue
         | 
| 401 | 
            +
                            txt = " ".join([a["text"].strip() for a in tbl[r][i]])
         | 
| 402 | 
             
                            headers[r][i] = txt
         | 
| 403 | 
             
                            hdrset.add(txt)
         | 
| 404 | 
             
                        if all([not t for t in headers[r]]):
         | 
    	
        rag/llm/chat_model.py
    CHANGED
    
    | @@ -15,7 +15,7 @@ | |
| 15 | 
             
            #
         | 
| 16 | 
             
            from abc import ABC
         | 
| 17 | 
             
            from openai import OpenAI
         | 
| 18 | 
            -
            import  | 
| 19 |  | 
| 20 |  | 
| 21 | 
             
            class Base(ABC):
         | 
| @@ -33,11 +33,14 @@ class GptTurbo(Base): | |
| 33 |  | 
| 34 | 
             
                def chat(self, system, history, gen_conf):
         | 
| 35 | 
             
                    if system: history.insert(0, {"role": "system", "content": system})
         | 
| 36 | 
            -
                     | 
| 37 | 
            -
                         | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
|  | |
|  | |
|  | |
| 41 |  | 
| 42 |  | 
| 43 | 
             
            from dashscope import Generation
         | 
| @@ -58,7 +61,7 @@ class QWenChat(Base): | |
| 58 | 
             
                    )
         | 
| 59 | 
             
                    if response.status_code == HTTPStatus.OK:
         | 
| 60 | 
             
                        return response.output.choices[0]['message']['content'], response.usage.output_tokens
         | 
| 61 | 
            -
                    return response.message, 0
         | 
| 62 |  | 
| 63 |  | 
| 64 | 
             
            from zhipuai import ZhipuAI
         | 
| @@ -77,4 +80,4 @@ class ZhipuChat(Base): | |
| 77 | 
             
                    )
         | 
| 78 | 
             
                    if response.status_code == HTTPStatus.OK:
         | 
| 79 | 
             
                        return response.output.choices[0]['message']['content'], response.usage.completion_tokens
         | 
| 80 | 
            -
                    return response.message, 0
         | 
|  | |
| 15 | 
             
            #
         | 
| 16 | 
             
            from abc import ABC
         | 
| 17 | 
             
            from openai import OpenAI
         | 
| 18 | 
            +
            import openai
         | 
| 19 |  | 
| 20 |  | 
| 21 | 
             
            class Base(ABC):
         | 
|  | |
| 33 |  | 
| 34 | 
             
                def chat(self, system, history, gen_conf):
         | 
| 35 | 
             
                    if system: history.insert(0, {"role": "system", "content": system})
         | 
| 36 | 
            +
                    try:
         | 
| 37 | 
            +
                        res = self.client.chat.completions.create(
         | 
| 38 | 
            +
                            model=self.model_name,
         | 
| 39 | 
            +
                            messages=history,
         | 
| 40 | 
            +
                            **gen_conf)
         | 
| 41 | 
            +
                        return res.choices[0].message.content.strip(), res.usage.completion_tokens
         | 
| 42 | 
            +
                    except openai.APIError as e:
         | 
| 43 | 
            +
                        return "ERROR: "+str(e), 0
         | 
| 44 |  | 
| 45 |  | 
| 46 | 
             
            from dashscope import Generation
         | 
|  | |
| 61 | 
             
                    )
         | 
| 62 | 
             
                    if response.status_code == HTTPStatus.OK:
         | 
| 63 | 
             
                        return response.output.choices[0]['message']['content'], response.usage.output_tokens
         | 
| 64 | 
            +
                    return "ERROR: " + response.message, 0
         | 
| 65 |  | 
| 66 |  | 
| 67 | 
             
            from zhipuai import ZhipuAI
         | 
|  | |
| 80 | 
             
                    )
         | 
| 81 | 
             
                    if response.status_code == HTTPStatus.OK:
         | 
| 82 | 
             
                        return response.output.choices[0]['message']['content'], response.usage.completion_tokens
         | 
| 83 | 
            +
                    return "ERROR: " + response.message, 0
         | 
    	
        rag/nlp/__init__.py
    CHANGED
    
    | @@ -1,7 +1,4 @@ | |
| 1 | 
            -
            from . import search
         | 
| 2 | 
            -
            from rag.utils import ELASTICSEARCH
         | 
| 3 |  | 
| 4 | 
            -
            retrievaler = search.Dealer(ELASTICSEARCH)
         | 
| 5 |  | 
| 6 | 
             
            from nltk.stem import PorterStemmer
         | 
| 7 | 
             
            stemmer = PorterStemmer()
         | 
| @@ -39,10 +36,12 @@ BULLET_PATTERN = [[ | |
| 39 | 
             
            ]
         | 
| 40 | 
             
            ]
         | 
| 41 |  | 
|  | |
| 42 | 
             
            def random_choices(arr, k):
         | 
| 43 | 
             
                k = min(len(arr), k)
         | 
| 44 | 
             
                return random.choices(arr, k=k)
         | 
| 45 |  | 
|  | |
| 46 | 
             
            def bullets_category(sections):
         | 
| 47 | 
             
                global BULLET_PATTERN
         | 
| 48 | 
             
                hits = [0] * len(BULLET_PATTERN)
         | 
|  | |
|  | |
|  | |
| 1 |  | 
|  | |
| 2 |  | 
| 3 | 
             
            from nltk.stem import PorterStemmer
         | 
| 4 | 
             
            stemmer = PorterStemmer()
         | 
|  | |
| 36 | 
             
            ]
         | 
| 37 | 
             
            ]
         | 
| 38 |  | 
| 39 | 
            +
             | 
| 40 | 
             
            def random_choices(arr, k):
         | 
| 41 | 
             
                k = min(len(arr), k)
         | 
| 42 | 
             
                return random.choices(arr, k=k)
         | 
| 43 |  | 
| 44 | 
            +
             | 
| 45 | 
             
            def bullets_category(sections):
         | 
| 46 | 
             
                global BULLET_PATTERN
         | 
| 47 | 
             
                hits = [0] * len(BULLET_PATTERN)
         | 
    	
        rag/nlp/search.py
    CHANGED
    
    | @@ -1,7 +1,7 @@ | |
| 1 | 
             
            # -*- coding: utf-8 -*-
         | 
| 2 | 
             
            import json
         | 
| 3 | 
             
            import re
         | 
| 4 | 
            -
            from elasticsearch_dsl import Q, Search | 
| 5 | 
             
            from typing import List, Optional, Dict, Union
         | 
| 6 | 
             
            from dataclasses import dataclass
         | 
| 7 |  | 
| @@ -183,6 +183,7 @@ class Dealer: | |
| 183 |  | 
| 184 | 
             
                def insert_citations(self, answer, chunks, chunk_v,
         | 
| 185 | 
             
                                     embd_mdl, tkweight=0.3, vtweight=0.7):
         | 
|  | |
| 186 | 
             
                    pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer)
         | 
| 187 | 
             
                    for i in range(1, len(pieces)):
         | 
| 188 | 
             
                        if re.match(r"[a-z][.?;!][ \n]", pieces[i]):
         | 
| @@ -216,7 +217,7 @@ class Dealer: | |
| 216 | 
             
                        if mx < 0.55:
         | 
| 217 | 
             
                            continue
         | 
| 218 | 
             
                        cites[idx[i]] = list(
         | 
| 219 | 
            -
                            set([str( | 
| 220 |  | 
| 221 | 
             
                    res = ""
         | 
| 222 | 
             
                    for i, p in enumerate(pieces):
         | 
| @@ -225,6 +226,7 @@ class Dealer: | |
| 225 | 
             
                            continue
         | 
| 226 | 
             
                        if i not in cites:
         | 
| 227 | 
             
                            continue
         | 
|  | |
| 228 | 
             
                        res += "##%s$$" % "$".join(cites[i])
         | 
| 229 |  | 
| 230 | 
             
                    return res
         | 
|  | |
| 1 | 
             
            # -*- coding: utf-8 -*-
         | 
| 2 | 
             
            import json
         | 
| 3 | 
             
            import re
         | 
| 4 | 
            +
            from elasticsearch_dsl import Q, Search
         | 
| 5 | 
             
            from typing import List, Optional, Dict, Union
         | 
| 6 | 
             
            from dataclasses import dataclass
         | 
| 7 |  | 
|  | |
| 183 |  | 
| 184 | 
             
                def insert_citations(self, answer, chunks, chunk_v,
         | 
| 185 | 
             
                                     embd_mdl, tkweight=0.3, vtweight=0.7):
         | 
| 186 | 
            +
                    assert len(chunks) == len(chunk_v)
         | 
| 187 | 
             
                    pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer)
         | 
| 188 | 
             
                    for i in range(1, len(pieces)):
         | 
| 189 | 
             
                        if re.match(r"[a-z][.?;!][ \n]", pieces[i]):
         | 
|  | |
| 217 | 
             
                        if mx < 0.55:
         | 
| 218 | 
             
                            continue
         | 
| 219 | 
             
                        cites[idx[i]] = list(
         | 
| 220 | 
            +
                            set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
         | 
| 221 |  | 
| 222 | 
             
                    res = ""
         | 
| 223 | 
             
                    for i, p in enumerate(pieces):
         | 
|  | |
| 226 | 
             
                            continue
         | 
| 227 | 
             
                        if i not in cites:
         | 
| 228 | 
             
                            continue
         | 
| 229 | 
            +
                        assert int(cites[i]) < len(chunk_v)
         | 
| 230 | 
             
                        res += "##%s$$" % "$".join(cites[i])
         | 
| 231 |  | 
| 232 | 
             
                    return res
         |