KevinHuSh
		
	commited on
		
		
					Commit 
							
							·
						
						3198faf
	
1
								Parent(s):
							
							3079197
								
add alot of api (#23)
Browse files* clean rust version project
* clean rust version project
* build python version rag-flow
* add alot of api
- rag/llm/embedding_model.py +1 -1
- rag/nlp/huchunk.py +6 -3
- rag/nlp/search.py +1 -1
- rag/svr/parse_user_docs.py +35 -19
- rag/utils/__init__.py +19 -0
- rag/utils/es_conn.py +1 -0
- web_server/apps/document_app.py +47 -2
- web_server/apps/kb_app.py +14 -2
- web_server/apps/llm_app.py +95 -0
- web_server/apps/user_app.py +33 -5
- web_server/db/db_models.py +7 -4
- web_server/db/services/document_service.py +27 -13
- web_server/db/services/kb_service.py +33 -6
- web_server/db/services/llm_service.py +18 -0
- web_server/db/services/user_service.py +1 -1
- web_server/utils/file_utils.py +1 -1
    	
        rag/llm/embedding_model.py
    CHANGED
    
    | @@ -35,7 +35,7 @@ class Base(ABC): | |
| 35 |  | 
| 36 |  | 
| 37 | 
             
            class HuEmbedding(Base):
         | 
| 38 | 
            -
                def __init__(self):
         | 
| 39 | 
             
                    """
         | 
| 40 | 
             
                    If you have trouble downloading HuggingFace models, -_^ this might help!!
         | 
| 41 |  | 
|  | |
| 35 |  | 
| 36 |  | 
| 37 | 
             
            class HuEmbedding(Base):
         | 
| 38 | 
            +
                def __init__(self, key="", model_name=""):
         | 
| 39 | 
             
                    """
         | 
| 40 | 
             
                    If you have trouble downloading HuggingFace models, -_^ this might help!!
         | 
| 41 |  | 
    	
        rag/nlp/huchunk.py
    CHANGED
    
    | @@ -411,9 +411,12 @@ class TextChunker(HuChunker): | |
| 411 | 
             
                    flds = self.Fields()
         | 
| 412 | 
             
                    if self.is_binary_file(fnm):
         | 
| 413 | 
             
                        return flds
         | 
| 414 | 
            -
                     | 
| 415 | 
            -
             | 
| 416 | 
            -
                         | 
|  | |
|  | |
|  | |
| 417 | 
             
                    flds.table_chunks = []
         | 
| 418 | 
             
                    return flds
         | 
| 419 |  | 
|  | |
| 411 | 
             
                    flds = self.Fields()
         | 
| 412 | 
             
                    if self.is_binary_file(fnm):
         | 
| 413 | 
             
                        return flds
         | 
| 414 | 
            +
                    txt = ""
         | 
| 415 | 
            +
                    if isinstance(fnm, str):
         | 
| 416 | 
            +
                        with open(fnm, "r") as f:
         | 
| 417 | 
            +
                            txt = f.read()
         | 
| 418 | 
            +
                    else: txt = fnm.decode("utf-8")
         | 
| 419 | 
            +
                    flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
         | 
| 420 | 
             
                    flds.table_chunks = []
         | 
| 421 | 
             
                    return flds
         | 
| 422 |  | 
    	
        rag/nlp/search.py
    CHANGED
    
    | @@ -8,7 +8,7 @@ from rag.nlp import huqie, query | |
| 8 | 
             
            import numpy as np
         | 
| 9 |  | 
| 10 |  | 
| 11 | 
            -
            def index_name(uid): return f" | 
| 12 |  | 
| 13 |  | 
| 14 | 
             
            class Dealer:
         | 
|  | |
| 8 | 
             
            import numpy as np
         | 
| 9 |  | 
| 10 |  | 
| 11 | 
            +
            def index_name(uid): return f"ragflow_{uid}"
         | 
| 12 |  | 
| 13 |  | 
| 14 | 
             
            class Dealer:
         | 
    	
        rag/svr/parse_user_docs.py
    CHANGED
    
    | @@ -14,6 +14,7 @@ | |
| 14 | 
             
            #  limitations under the License.
         | 
| 15 | 
             
            #
         | 
| 16 | 
             
            import json
         | 
|  | |
| 17 | 
             
            import os
         | 
| 18 | 
             
            import hashlib
         | 
| 19 | 
             
            import copy
         | 
| @@ -24,9 +25,10 @@ from timeit import default_timer as timer | |
| 24 |  | 
| 25 | 
             
            from rag.llm import EmbeddingModel, CvModel
         | 
| 26 | 
             
            from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
         | 
| 27 | 
            -
            from rag.utils import ELASTICSEARCH | 
| 28 | 
             
            from rag.utils import MINIO
         | 
| 29 | 
            -
            from rag.utils import rmSpace,  | 
|  | |
| 30 | 
             
            from rag.nlp import huchunk, huqie, search
         | 
| 31 | 
             
            from io import BytesIO
         | 
| 32 | 
             
            import pandas as pd
         | 
| @@ -47,6 +49,7 @@ from rag.nlp.huchunk import ( | |
| 47 | 
             
            from web_server.db import LLMType
         | 
| 48 | 
             
            from web_server.db.services.document_service import DocumentService
         | 
| 49 | 
             
            from web_server.db.services.llm_service import TenantLLMService
         | 
|  | |
| 50 | 
             
            from web_server.utils import get_format_time
         | 
| 51 | 
             
            from web_server.utils.file_utils import get_project_base_directory
         | 
| 52 |  | 
| @@ -83,7 +86,7 @@ def collect(comm, mod, tm): | |
| 83 | 
             
                if len(docs) == 0:
         | 
| 84 | 
             
                    return pd.DataFrame()
         | 
| 85 | 
             
                docs = pd.DataFrame(docs)
         | 
| 86 | 
            -
                mtm =  | 
| 87 | 
             
                cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
         | 
| 88 | 
             
                return docs
         | 
| 89 |  | 
| @@ -99,11 +102,12 @@ def set_progress(docid, prog, msg="Processing...", begin=False): | |
| 99 | 
             
                    cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
         | 
| 100 |  | 
| 101 |  | 
| 102 | 
            -
            def build(row):
         | 
| 103 | 
             
                if row["size"] > DOC_MAXIMUM_SIZE:
         | 
| 104 | 
             
                    set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
         | 
| 105 | 
             
                                 (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
         | 
| 106 | 
             
                    return []
         | 
|  | |
| 107 | 
             
                res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
         | 
| 108 | 
             
                if ELASTICSEARCH.getTotal(res) > 0:
         | 
| 109 | 
             
                    ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
         | 
| @@ -120,7 +124,8 @@ def build(row): | |
| 120 | 
             
                set_progress(row["id"], random.randint(0, 20) /
         | 
| 121 | 
             
                             100., "Finished preparing! Start to slice file!", True)
         | 
| 122 | 
             
                try:
         | 
| 123 | 
            -
                     | 
|  | |
| 124 | 
             
                except Exception as e:
         | 
| 125 | 
             
                    if re.search("(No such file|not found)", str(e)):
         | 
| 126 | 
             
                        set_progress(
         | 
| @@ -131,6 +136,9 @@ def build(row): | |
| 131 | 
             
                            row["id"], -1, f"Internal server error: %s" %
         | 
| 132 | 
             
                            str(e).replace(
         | 
| 133 | 
             
                                "'", ""))
         | 
|  | |
|  | |
|  | |
| 134 | 
             
                    return []
         | 
| 135 |  | 
| 136 | 
             
                if not obj.text_chunks and not obj.table_chunks:
         | 
| @@ -144,7 +152,7 @@ def build(row): | |
| 144 | 
             
                             "Finished slicing files. Start to embedding the content.")
         | 
| 145 |  | 
| 146 | 
             
                doc = {
         | 
| 147 | 
            -
                    "doc_id": row[" | 
| 148 | 
             
                    "kb_id": [str(row["kb_id"])],
         | 
| 149 | 
             
                    "docnm_kwd": os.path.split(row["location"])[-1],
         | 
| 150 | 
             
                    "title_tks": huqie.qie(row["name"]),
         | 
| @@ -164,10 +172,10 @@ def build(row): | |
| 164 | 
             
                        docs.append(d)
         | 
| 165 | 
             
                        continue
         | 
| 166 |  | 
| 167 | 
            -
                    if isinstance(img,  | 
| 168 | 
            -
                        img.save(output_buffer, format='JPEG')
         | 
| 169 | 
            -
                    else:
         | 
| 170 | 
             
                        output_buffer = BytesIO(img)
         | 
|  | |
|  | |
| 171 |  | 
| 172 | 
             
                    MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
         | 
| 173 | 
             
                    d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
         | 
| @@ -215,15 +223,16 @@ def embedding(docs, mdl): | |
| 215 |  | 
| 216 |  | 
| 217 | 
             
            def model_instance(tenant_id, llm_type):
         | 
| 218 | 
            -
                model_config = TenantLLMService. | 
| 219 | 
            -
                if not model_config: | 
| 220 | 
            -
             | 
|  | |
| 221 | 
             
                if llm_type == LLMType.EMBEDDING:
         | 
| 222 | 
            -
                    if model_config | 
| 223 | 
            -
                    return EmbeddingModel[model_config | 
| 224 | 
             
                if llm_type == LLMType.IMAGE2TEXT:
         | 
| 225 | 
            -
                    if model_config | 
| 226 | 
            -
                    return CvModel[model_config.llm_factory](model_config | 
| 227 |  | 
| 228 |  | 
| 229 | 
             
            def main(comm, mod):
         | 
| @@ -231,7 +240,7 @@ def main(comm, mod): | |
| 231 | 
             
                from rag.llm import HuEmbedding
         | 
| 232 | 
             
                model = HuEmbedding()
         | 
| 233 | 
             
                tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
         | 
| 234 | 
            -
                tm =  | 
| 235 | 
             
                rows = collect(comm, mod, tm)
         | 
| 236 | 
             
                if len(rows) == 0:
         | 
| 237 | 
             
                    return
         | 
| @@ -247,7 +256,7 @@ def main(comm, mod): | |
| 247 | 
             
                    st_tm = timer()
         | 
| 248 | 
             
                    cks = build(r, cv_mdl)
         | 
| 249 | 
             
                    if not cks:
         | 
| 250 | 
            -
                        tmf.write(str(r[" | 
| 251 | 
             
                        continue
         | 
| 252 | 
             
                    # TODO: exception handler
         | 
| 253 | 
             
                    ## set_progress(r["did"], -1, "ERROR: ")
         | 
| @@ -268,12 +277,19 @@ def main(comm, mod): | |
| 268 | 
             
                        cron_logger.error(str(es_r))
         | 
| 269 | 
             
                    else:
         | 
| 270 | 
             
                        set_progress(r["id"], 1., "Done!")
         | 
| 271 | 
            -
                        DocumentService. | 
|  | |
|  | |
| 272 | 
             
                    tmf.write(str(r["update_time"]) + "\n")
         | 
| 273 | 
             
                tmf.close()
         | 
| 274 |  | 
| 275 |  | 
| 276 | 
             
            if __name__ == "__main__":
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 277 | 
             
                from mpi4py import MPI
         | 
| 278 | 
             
                comm = MPI.COMM_WORLD
         | 
| 279 | 
             
                main(comm.Get_size(), comm.Get_rank())
         | 
|  | |
| 14 | 
             
            #  limitations under the License.
         | 
| 15 | 
             
            #
         | 
| 16 | 
             
            import json
         | 
| 17 | 
            +
            import logging
         | 
| 18 | 
             
            import os
         | 
| 19 | 
             
            import hashlib
         | 
| 20 | 
             
            import copy
         | 
|  | |
| 25 |  | 
| 26 | 
             
            from rag.llm import EmbeddingModel, CvModel
         | 
| 27 | 
             
            from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
         | 
| 28 | 
            +
            from rag.utils import ELASTICSEARCH
         | 
| 29 | 
             
            from rag.utils import MINIO
         | 
| 30 | 
            +
            from rag.utils import rmSpace, findMaxTm
         | 
| 31 | 
            +
             | 
| 32 | 
             
            from rag.nlp import huchunk, huqie, search
         | 
| 33 | 
             
            from io import BytesIO
         | 
| 34 | 
             
            import pandas as pd
         | 
|  | |
| 49 | 
             
            from web_server.db import LLMType
         | 
| 50 | 
             
            from web_server.db.services.document_service import DocumentService
         | 
| 51 | 
             
            from web_server.db.services.llm_service import TenantLLMService
         | 
| 52 | 
            +
            from web_server.settings import database_logger
         | 
| 53 | 
             
            from web_server.utils import get_format_time
         | 
| 54 | 
             
            from web_server.utils.file_utils import get_project_base_directory
         | 
| 55 |  | 
|  | |
| 86 | 
             
                if len(docs) == 0:
         | 
| 87 | 
             
                    return pd.DataFrame()
         | 
| 88 | 
             
                docs = pd.DataFrame(docs)
         | 
| 89 | 
            +
                mtm = docs["update_time"].max()
         | 
| 90 | 
             
                cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
         | 
| 91 | 
             
                return docs
         | 
| 92 |  | 
|  | |
| 102 | 
             
                    cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
         | 
| 103 |  | 
| 104 |  | 
| 105 | 
            +
            def build(row, cvmdl):
         | 
| 106 | 
             
                if row["size"] > DOC_MAXIMUM_SIZE:
         | 
| 107 | 
             
                    set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
         | 
| 108 | 
             
                                 (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
         | 
| 109 | 
             
                    return []
         | 
| 110 | 
            +
             | 
| 111 | 
             
                res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
         | 
| 112 | 
             
                if ELASTICSEARCH.getTotal(res) > 0:
         | 
| 113 | 
             
                    ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
         | 
|  | |
| 124 | 
             
                set_progress(row["id"], random.randint(0, 20) /
         | 
| 125 | 
             
                             100., "Finished preparing! Start to slice file!", True)
         | 
| 126 | 
             
                try:
         | 
| 127 | 
            +
                    cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
         | 
| 128 | 
            +
                    obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl)
         | 
| 129 | 
             
                except Exception as e:
         | 
| 130 | 
             
                    if re.search("(No such file|not found)", str(e)):
         | 
| 131 | 
             
                        set_progress(
         | 
|  | |
| 136 | 
             
                            row["id"], -1, f"Internal server error: %s" %
         | 
| 137 | 
             
                            str(e).replace(
         | 
| 138 | 
             
                                "'", ""))
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
         | 
| 141 | 
            +
             | 
| 142 | 
             
                    return []
         | 
| 143 |  | 
| 144 | 
             
                if not obj.text_chunks and not obj.table_chunks:
         | 
|  | |
| 152 | 
             
                             "Finished slicing files. Start to embedding the content.")
         | 
| 153 |  | 
| 154 | 
             
                doc = {
         | 
| 155 | 
            +
                    "doc_id": row["id"],
         | 
| 156 | 
             
                    "kb_id": [str(row["kb_id"])],
         | 
| 157 | 
             
                    "docnm_kwd": os.path.split(row["location"])[-1],
         | 
| 158 | 
             
                    "title_tks": huqie.qie(row["name"]),
         | 
|  | |
| 172 | 
             
                        docs.append(d)
         | 
| 173 | 
             
                        continue
         | 
| 174 |  | 
| 175 | 
            +
                    if isinstance(img, bytes):
         | 
|  | |
|  | |
| 176 | 
             
                        output_buffer = BytesIO(img)
         | 
| 177 | 
            +
                    else:
         | 
| 178 | 
            +
                        img.save(output_buffer, format='JPEG')
         | 
| 179 |  | 
| 180 | 
             
                    MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
         | 
| 181 | 
             
                    d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
         | 
|  | |
| 223 |  | 
| 224 |  | 
| 225 | 
             
            def model_instance(tenant_id, llm_type):
         | 
| 226 | 
            +
                model_config = TenantLLMService.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
         | 
| 227 | 
            +
                if not model_config:
         | 
| 228 | 
            +
                    model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
         | 
| 229 | 
            +
                else: model_config = model_config[0].to_dict()
         | 
| 230 | 
             
                if llm_type == LLMType.EMBEDDING:
         | 
| 231 | 
            +
                    if model_config["llm_factory"] not in EmbeddingModel: return
         | 
| 232 | 
            +
                    return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
         | 
| 233 | 
             
                if llm_type == LLMType.IMAGE2TEXT:
         | 
| 234 | 
            +
                    if model_config["llm_factory"] not in CvModel: return
         | 
| 235 | 
            +
                    return CvModel[model_config.llm_factory](model_config["api_key"], model_config["llm_name"])
         | 
| 236 |  | 
| 237 |  | 
| 238 | 
             
            def main(comm, mod):
         | 
|  | |
| 240 | 
             
                from rag.llm import HuEmbedding
         | 
| 241 | 
             
                model = HuEmbedding()
         | 
| 242 | 
             
                tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
         | 
| 243 | 
            +
                tm = findMaxTm(tm_fnm)
         | 
| 244 | 
             
                rows = collect(comm, mod, tm)
         | 
| 245 | 
             
                if len(rows) == 0:
         | 
| 246 | 
             
                    return
         | 
|  | |
| 256 | 
             
                    st_tm = timer()
         | 
| 257 | 
             
                    cks = build(r, cv_mdl)
         | 
| 258 | 
             
                    if not cks:
         | 
| 259 | 
            +
                        tmf.write(str(r["update_time"]) + "\n")
         | 
| 260 | 
             
                        continue
         | 
| 261 | 
             
                    # TODO: exception handler
         | 
| 262 | 
             
                    ## set_progress(r["did"], -1, "ERROR: ")
         | 
|  | |
| 277 | 
             
                        cron_logger.error(str(es_r))
         | 
| 278 | 
             
                    else:
         | 
| 279 | 
             
                        set_progress(r["id"], 1., "Done!")
         | 
| 280 | 
            +
                        DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm)
         | 
| 281 | 
            +
                        cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
         | 
| 282 | 
            +
             | 
| 283 | 
             
                    tmf.write(str(r["update_time"]) + "\n")
         | 
| 284 | 
             
                tmf.close()
         | 
| 285 |  | 
| 286 |  | 
| 287 | 
             
            if __name__ == "__main__":
         | 
| 288 | 
            +
                peewee_logger = logging.getLogger('peewee')
         | 
| 289 | 
            +
                peewee_logger.propagate = False
         | 
| 290 | 
            +
                peewee_logger.addHandler(database_logger.handlers[0])
         | 
| 291 | 
            +
                peewee_logger.setLevel(database_logger.level)
         | 
| 292 | 
            +
             | 
| 293 | 
             
                from mpi4py import MPI
         | 
| 294 | 
             
                comm = MPI.COMM_WORLD
         | 
| 295 | 
             
                main(comm.Get_size(), comm.Get_rank())
         | 
    	
        rag/utils/__init__.py
    CHANGED
    
    | @@ -40,6 +40,25 @@ def findMaxDt(fnm): | |
| 40 | 
             
                    print("WARNING: can't find " + fnm)
         | 
| 41 | 
             
                return m
         | 
| 42 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 43 | 
             
            def num_tokens_from_string(string: str) -> int:
         | 
| 44 | 
             
                """Returns the number of tokens in a text string."""
         | 
| 45 | 
             
                encoding = tiktoken.get_encoding('cl100k_base')
         | 
|  | |
| 40 | 
             
                    print("WARNING: can't find " + fnm)
         | 
| 41 | 
             
                return m
         | 
| 42 |  | 
| 43 | 
            +
              
         | 
| 44 | 
            +
            def findMaxTm(fnm):
         | 
| 45 | 
            +
                m = 0
         | 
| 46 | 
            +
                try:
         | 
| 47 | 
            +
                    with open(fnm, "r") as f:
         | 
| 48 | 
            +
                        while True:
         | 
| 49 | 
            +
                            l = f.readline()
         | 
| 50 | 
            +
                            if not l:
         | 
| 51 | 
            +
                                break
         | 
| 52 | 
            +
                            l = l.strip("\n")
         | 
| 53 | 
            +
                            if l == 'nan':
         | 
| 54 | 
            +
                                continue
         | 
| 55 | 
            +
                            if int(l) > m:
         | 
| 56 | 
            +
                                m = int(l)
         | 
| 57 | 
            +
                except Exception as e:
         | 
| 58 | 
            +
                    print("WARNING: can't find " + fnm)
         | 
| 59 | 
            +
                return m
         | 
| 60 | 
            +
             | 
| 61 | 
            +
              
         | 
| 62 | 
             
            def num_tokens_from_string(string: str) -> int:
         | 
| 63 | 
             
                """Returns the number of tokens in a text string."""
         | 
| 64 | 
             
                encoding = tiktoken.get_encoding('cl100k_base')
         | 
    	
        rag/utils/es_conn.py
    CHANGED
    
    | @@ -294,6 +294,7 @@ class HuEs: | |
| 294 | 
             
                        except Exception as e:
         | 
| 295 | 
             
                            es_logger.error("ES updateByQuery deleteByQuery: " +
         | 
| 296 | 
             
                                            str(e) + "【Q】:" + str(query.to_dict()))
         | 
|  | |
| 297 | 
             
                            if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
         | 
| 298 | 
             
                                continue
         | 
| 299 |  | 
|  | |
| 294 | 
             
                        except Exception as e:
         | 
| 295 | 
             
                            es_logger.error("ES updateByQuery deleteByQuery: " +
         | 
| 296 | 
             
                                            str(e) + "【Q】:" + str(query.to_dict()))
         | 
| 297 | 
            +
                            if str(e).find("NotFoundError") > 0: return True
         | 
| 298 | 
             
                            if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
         | 
| 299 | 
             
                                continue
         | 
| 300 |  | 
    	
        web_server/apps/document_app.py
    CHANGED
    
    | @@ -13,6 +13,7 @@ | |
| 13 | 
             
            #  See the License for the specific language governing permissions and
         | 
| 14 | 
             
            #  limitations under the License.
         | 
| 15 | 
             
            #
         | 
|  | |
| 16 | 
             
            import pathlib
         | 
| 17 |  | 
| 18 | 
             
            from elasticsearch_dsl import Q
         | 
| @@ -195,11 +196,15 @@ def rm(): | |
| 195 | 
             
                    e, doc = DocumentService.get_by_id(req["doc_id"])
         | 
| 196 | 
             
                    if not e:
         | 
| 197 | 
             
                        return get_data_error_result(retmsg="Document not found!")
         | 
|  | |
|  | |
|  | |
|  | |
| 198 | 
             
                    if not DocumentService.delete_by_id(req["doc_id"]):
         | 
| 199 | 
             
                        return get_data_error_result(
         | 
| 200 | 
             
                            retmsg="Database error (Document removal)!")
         | 
| 201 | 
            -
             | 
| 202 | 
            -
                    MINIO.rm( | 
| 203 | 
             
                    return get_json_result(data=True)
         | 
| 204 | 
             
                except Exception as e:
         | 
| 205 | 
             
                    return server_error_response(e)
         | 
| @@ -233,3 +238,43 @@ def rename(): | |
| 233 | 
             
                    return get_json_result(data=True)
         | 
| 234 | 
             
                except Exception as e:
         | 
| 235 | 
             
                    return server_error_response(e)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 13 | 
             
            #  See the License for the specific language governing permissions and
         | 
| 14 | 
             
            #  limitations under the License.
         | 
| 15 | 
             
            #
         | 
| 16 | 
            +
            import base64
         | 
| 17 | 
             
            import pathlib
         | 
| 18 |  | 
| 19 | 
             
            from elasticsearch_dsl import Q
         | 
|  | |
| 196 | 
             
                    e, doc = DocumentService.get_by_id(req["doc_id"])
         | 
| 197 | 
             
                    if not e:
         | 
| 198 | 
             
                        return get_data_error_result(retmsg="Document not found!")
         | 
| 199 | 
            +
                    if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)):
         | 
| 200 | 
            +
                        return get_json_result(data=False, retmsg='Remove from ES failure"', retcode=RetCode.SERVER_ERROR)
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0)
         | 
| 203 | 
             
                    if not DocumentService.delete_by_id(req["doc_id"]):
         | 
| 204 | 
             
                        return get_data_error_result(
         | 
| 205 | 
             
                            retmsg="Database error (Document removal)!")
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                    MINIO.rm(doc.kb_id, doc.location)
         | 
| 208 | 
             
                    return get_json_result(data=True)
         | 
| 209 | 
             
                except Exception as e:
         | 
| 210 | 
             
                    return server_error_response(e)
         | 
|  | |
| 238 | 
             
                    return get_json_result(data=True)
         | 
| 239 | 
             
                except Exception as e:
         | 
| 240 | 
             
                    return server_error_response(e)
         | 
| 241 | 
            +
             | 
| 242 | 
            +
             | 
| 243 | 
            +
            @manager.route('/get', methods=['GET'])
         | 
| 244 | 
            +
            @login_required
         | 
| 245 | 
            +
            def get():
         | 
| 246 | 
            +
                doc_id = request.args["doc_id"]
         | 
| 247 | 
            +
                try:
         | 
| 248 | 
            +
                    e, doc = DocumentService.get_by_id(doc_id)
         | 
| 249 | 
            +
                    if not e:
         | 
| 250 | 
            +
                        return get_data_error_result(retmsg="Document not found!")
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                    blob = MINIO.get(doc.kb_id, doc.location)
         | 
| 253 | 
            +
                    return get_json_result(data={"base64": base64.b64decode(blob)})
         | 
| 254 | 
            +
                except Exception as e:
         | 
| 255 | 
            +
                    return server_error_response(e)
         | 
| 256 | 
            +
             | 
| 257 | 
            +
             | 
| 258 | 
            +
            @manager.route('/change_parser', methods=['POST'])
         | 
| 259 | 
            +
            @login_required
         | 
| 260 | 
            +
            @validate_request("doc_id", "parser_id")
         | 
| 261 | 
            +
            def change_parser():
         | 
| 262 | 
            +
                req = request.json
         | 
| 263 | 
            +
                try:
         | 
| 264 | 
            +
                    e, doc = DocumentService.get_by_id(req["doc_id"])
         | 
| 265 | 
            +
                    if not e:
         | 
| 266 | 
            +
                        return get_data_error_result(retmsg="Document not found!")
         | 
| 267 | 
            +
                    if doc.parser_id.lower() == req["parser_id"].lower():
         | 
| 268 | 
            +
                        return get_json_result(data=True)
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                    e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""})
         | 
| 271 | 
            +
                    if not e:
         | 
| 272 | 
            +
                        return get_data_error_result(retmsg="Document not found!")
         | 
| 273 | 
            +
                    e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, doc.process_duation*-1)
         | 
| 274 | 
            +
                    if not e:
         | 
| 275 | 
            +
                        return get_data_error_result(retmsg="Document not found!")
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                    return get_json_result(data=True)
         | 
| 278 | 
            +
                except Exception as e:
         | 
| 279 | 
            +
                    return server_error_response(e)
         | 
| 280 | 
            +
             | 
    	
        web_server/apps/kb_app.py
    CHANGED
    
    | @@ -29,7 +29,7 @@ from web_server.utils.api_utils import get_json_result | |
| 29 |  | 
| 30 | 
             
            @manager.route('/create', methods=['post'])
         | 
| 31 | 
             
            @login_required
         | 
| 32 | 
            -
            @validate_request("name", "description", "permission", " | 
| 33 | 
             
            def create():
         | 
| 34 | 
             
                req = request.json
         | 
| 35 | 
             
                req["name"] = req["name"].strip()
         | 
| @@ -46,7 +46,7 @@ def create(): | |
| 46 |  | 
| 47 | 
             
            @manager.route('/update', methods=['post'])
         | 
| 48 | 
             
            @login_required
         | 
| 49 | 
            -
            @validate_request("kb_id", "name", "description", "permission", " | 
| 50 | 
             
            def update():
         | 
| 51 | 
             
                req = request.json
         | 
| 52 | 
             
                req["name"] = req["name"].strip()
         | 
| @@ -72,6 +72,18 @@ def update(): | |
| 72 | 
             
                    return server_error_response(e)
         | 
| 73 |  | 
| 74 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 75 | 
             
            @manager.route('/list', methods=['GET'])
         | 
| 76 | 
             
            @login_required
         | 
| 77 | 
             
            def list():
         | 
|  | |
| 29 |  | 
| 30 | 
             
            @manager.route('/create', methods=['post'])
         | 
| 31 | 
             
            @login_required
         | 
| 32 | 
            +
            @validate_request("name", "description", "permission", "parser_id")
         | 
| 33 | 
             
            def create():
         | 
| 34 | 
             
                req = request.json
         | 
| 35 | 
             
                req["name"] = req["name"].strip()
         | 
|  | |
| 46 |  | 
| 47 | 
             
            @manager.route('/update', methods=['post'])
         | 
| 48 | 
             
            @login_required
         | 
| 49 | 
            +
            @validate_request("kb_id", "name", "description", "permission", "parser_id")
         | 
| 50 | 
             
            def update():
         | 
| 51 | 
             
                req = request.json
         | 
| 52 | 
             
                req["name"] = req["name"].strip()
         | 
|  | |
| 72 | 
             
                    return server_error_response(e)
         | 
| 73 |  | 
| 74 |  | 
| 75 | 
            +
            @manager.route('/detail', methods=['GET'])
         | 
| 76 | 
            +
            @login_required
         | 
| 77 | 
            +
            def detail():
         | 
| 78 | 
            +
                kb_id = request.args["kb_id"]
         | 
| 79 | 
            +
                try:
         | 
| 80 | 
            +
                    kb = KnowledgebaseService.get_detail(kb_id)
         | 
| 81 | 
            +
                    if not kb: return get_data_error_result(retmsg="Can't find this knowledgebase!")
         | 
| 82 | 
            +
                    return get_json_result(data=kb)
         | 
| 83 | 
            +
                except Exception as e:
         | 
| 84 | 
            +
                    return server_error_response(e)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 
| 87 | 
             
            @manager.route('/list', methods=['GET'])
         | 
| 88 | 
             
            @login_required
         | 
| 89 | 
             
            def list():
         | 
    	
        web_server/apps/llm_app.py
    ADDED
    
    | @@ -0,0 +1,95 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #
         | 
| 2 | 
            +
            #  Copyright 2019 The FATE Authors. All Rights Reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            #  Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            +
            #  you may not use this file except in compliance with the License.
         | 
| 6 | 
            +
            #  You may obtain a copy of the License at
         | 
| 7 | 
            +
            #
         | 
| 8 | 
            +
            #      http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            +
            #
         | 
| 10 | 
            +
            #  Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            +
            #  distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            +
            #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            +
            #  See the License for the specific language governing permissions and
         | 
| 14 | 
            +
            #  limitations under the License.
         | 
| 15 | 
            +
            #
         | 
| 16 | 
            +
            from flask import request
         | 
| 17 | 
            +
            from flask_login import login_required, current_user
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            from web_server.db.services import duplicate_name
         | 
| 20 | 
            +
            from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
         | 
| 21 | 
            +
            from web_server.db.services.user_service import TenantService, UserTenantService
         | 
| 22 | 
            +
            from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
         | 
| 23 | 
            +
            from web_server.utils import get_uuid, get_format_time
         | 
| 24 | 
            +
            from web_server.db import StatusEnum, UserTenantRole
         | 
| 25 | 
            +
            from web_server.db.services.kb_service import KnowledgebaseService
         | 
| 26 | 
            +
            from web_server.db.db_models import Knowledgebase, TenantLLM
         | 
| 27 | 
            +
            from web_server.settings import stat_logger, RetCode
         | 
| 28 | 
            +
            from web_server.utils.api_utils import get_json_result
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            @manager.route('/factories', methods=['GET'])
         | 
| 32 | 
            +
            @login_required
         | 
| 33 | 
            +
            def factories():
         | 
| 34 | 
            +
                try:
         | 
| 35 | 
            +
                    fac = LLMFactoriesService.get_all()
         | 
| 36 | 
            +
                    return get_json_result(data=fac.to_json())
         | 
| 37 | 
            +
                except Exception as e:
         | 
| 38 | 
            +
                    return server_error_response(e)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            @manager.route('/set_api_key', methods=['POST'])
         | 
| 42 | 
            +
            @login_required
         | 
| 43 | 
            +
            @validate_request("llm_factory", "api_key")
         | 
| 44 | 
            +
            def set_api_key():
         | 
| 45 | 
            +
                req = request.json
         | 
| 46 | 
            +
                llm = {
         | 
| 47 | 
            +
                    "tenant_id": current_user.id,
         | 
| 48 | 
            +
                    "llm_factory": req["llm_factory"],
         | 
| 49 | 
            +
                    "api_key": req["api_key"]
         | 
| 50 | 
            +
                }
         | 
| 51 | 
            +
                # TODO: Test api_key
         | 
| 52 | 
            +
                for n in ["model_type", "llm_name"]:
         | 
| 53 | 
            +
                    if n in req: llm[n] = req[n]
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                TenantLLM.insert(**llm).on_conflict("replace").execute()
         | 
| 56 | 
            +
                return get_json_result(data=True)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
             | 
| 59 | 
            +
            @manager.route('/my_llms', methods=['GET'])
         | 
| 60 | 
            +
            @login_required
         | 
| 61 | 
            +
            def my_llms():
         | 
| 62 | 
            +
                try:
         | 
| 63 | 
            +
                    objs = TenantLLMService.query(tenant_id=current_user.id)
         | 
| 64 | 
            +
                    objs = [o.to_dict() for o in objs]
         | 
| 65 | 
            +
                    for o in objs: del o["api_key"]
         | 
| 66 | 
            +
                    return get_json_result(data=objs)
         | 
| 67 | 
            +
                except Exception as e:
         | 
| 68 | 
            +
                    return server_error_response(e)
         | 
| 69 | 
            +
             | 
| 70 | 
            +
             | 
| 71 | 
            +
            @manager.route('/list', methods=['GET'])
         | 
| 72 | 
            +
            @login_required
         | 
| 73 | 
            +
            def list():
         | 
| 74 | 
            +
                try:
         | 
| 75 | 
            +
                    objs = TenantLLMService.query(tenant_id=current_user.id)
         | 
| 76 | 
            +
                    objs = [o.to_dict() for o in objs if o.api_key]
         | 
| 77 | 
            +
                    fct = {}
         | 
| 78 | 
            +
                    for o in objs:
         | 
| 79 | 
            +
                        if o["llm_factory"] not in fct: fct[o["llm_factory"]] = []
         | 
| 80 | 
            +
                        if o["llm_name"]: fct[o["llm_factory"]].append(o["llm_name"])
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    llms = LLMService.get_all()
         | 
| 83 | 
            +
                    llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
         | 
| 84 | 
            +
                    for m in llms:
         | 
| 85 | 
            +
                        m["available"] = False
         | 
| 86 | 
            +
                        if m["fid"] in fct and (not fct[m["fid"]] or m["llm_name"] in fct[m["fid"]]):
         | 
| 87 | 
            +
                            m["available"] = True
         | 
| 88 | 
            +
                    res = {}
         | 
| 89 | 
            +
                    for m in llms:
         | 
| 90 | 
            +
                        if m["fid"] not in res: res[m["fid"]] = []
         | 
| 91 | 
            +
                        res[m["fid"]].append(m)
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                    return get_json_result(data=res)
         | 
| 94 | 
            +
                except Exception as e:
         | 
| 95 | 
            +
                    return server_error_response(e)
         | 
    	
        web_server/apps/user_app.py
    CHANGED
    
    | @@ -16,9 +16,12 @@ | |
| 16 | 
             
            from flask import request, session, redirect, url_for
         | 
| 17 | 
             
            from werkzeug.security import generate_password_hash, check_password_hash
         | 
| 18 | 
             
            from flask_login import login_required, current_user, login_user, logout_user
         | 
|  | |
|  | |
|  | |
| 19 | 
             
            from web_server.utils.api_utils import server_error_response, validate_request
         | 
| 20 | 
             
            from web_server.utils import get_uuid, get_format_time, decrypt, download_img
         | 
| 21 | 
            -
            from web_server.db import UserTenantRole
         | 
| 22 | 
             
            from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
         | 
| 23 | 
             
            from web_server.db.services.user_service import UserService, TenantService, UserTenantService
         | 
| 24 | 
             
            from web_server.settings import stat_logger
         | 
| @@ -47,8 +50,9 @@ def login(): | |
| 47 | 
             
                        avatar = download_img(userinfo["avatar_url"])
         | 
| 48 | 
             
                    except Exception as e:
         | 
| 49 | 
             
                        stat_logger.exception(e)
         | 
|  | |
| 50 | 
             
                    try:
         | 
| 51 | 
            -
                        users = user_register({
         | 
| 52 | 
             
                            "access_token": session["access_token"],
         | 
| 53 | 
             
                            "email": userinfo["email"],
         | 
| 54 | 
             
                            "avatar": avatar,
         | 
| @@ -63,6 +67,7 @@ def login(): | |
| 63 | 
             
                        login_user(user)
         | 
| 64 | 
             
                        return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
         | 
| 65 | 
             
                    except Exception as e:
         | 
|  | |
| 66 | 
             
                        stat_logger.exception(e)
         | 
| 67 | 
             
                        return server_error_response(e)
         | 
| 68 | 
             
                elif not request.json:
         | 
| @@ -162,7 +167,25 @@ def user_info(): | |
| 162 | 
             
                return get_json_result(data=current_user.to_dict())
         | 
| 163 |  | 
| 164 |  | 
| 165 | 
            -
            def  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 166 | 
             
                user_id = get_uuid()
         | 
| 167 | 
             
                user["id"] = user_id
         | 
| 168 | 
             
                tenant = {
         | 
| @@ -180,10 +203,12 @@ def user_register(user): | |
| 180 | 
             
                    "invited_by": user_id,
         | 
| 181 | 
             
                    "role": UserTenantRole.OWNER
         | 
| 182 | 
             
                }
         | 
|  | |
| 183 |  | 
| 184 | 
             
                if not UserService.save(**user):return
         | 
| 185 | 
             
                TenantService.save(**tenant)
         | 
| 186 | 
             
                UserTenantService.save(**usr_tenant)
         | 
|  | |
| 187 | 
             
                return UserService.query(email=user["email"])
         | 
| 188 |  | 
| 189 |  | 
| @@ -203,14 +228,17 @@ def user_add(): | |
| 203 | 
             
                    "last_login_time": get_format_time(),
         | 
| 204 | 
             
                    "is_superuser": False,
         | 
| 205 | 
             
                }
         | 
|  | |
|  | |
| 206 | 
             
                try:
         | 
| 207 | 
            -
                    users = user_register(user_dict)
         | 
| 208 | 
             
                    if not users: raise Exception('Register user failure.')
         | 
| 209 | 
             
                    if len(users) > 1: raise Exception('Same E-mail exist!')
         | 
| 210 | 
             
                    user = users[0]
         | 
| 211 | 
             
                    login_user(user)
         | 
| 212 | 
             
                    return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
         | 
| 213 | 
             
                except Exception as e:
         | 
|  | |
| 214 | 
             
                    stat_logger.exception(e)
         | 
| 215 | 
             
                    return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
         | 
| 216 |  | 
| @@ -220,7 +248,7 @@ def user_add(): | |
| 220 | 
             
            @login_required
         | 
| 221 | 
             
            def tenant_info():
         | 
| 222 | 
             
                try:
         | 
| 223 | 
            -
                    tenants = TenantService.get_by_user_id(current_user.id)
         | 
| 224 | 
             
                    return get_json_result(data=tenants)
         | 
| 225 | 
             
                except Exception as e:
         | 
| 226 | 
             
                    return server_error_response(e)
         | 
|  | |
| 16 | 
             
            from flask import request, session, redirect, url_for
         | 
| 17 | 
             
            from werkzeug.security import generate_password_hash, check_password_hash
         | 
| 18 | 
             
            from flask_login import login_required, current_user, login_user, logout_user
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            from web_server.db.db_models import TenantLLM
         | 
| 21 | 
            +
            from web_server.db.services.llm_service import TenantLLMService
         | 
| 22 | 
             
            from web_server.utils.api_utils import server_error_response, validate_request
         | 
| 23 | 
             
            from web_server.utils import get_uuid, get_format_time, decrypt, download_img
         | 
| 24 | 
            +
            from web_server.db import UserTenantRole, LLMType
         | 
| 25 | 
             
            from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
         | 
| 26 | 
             
            from web_server.db.services.user_service import UserService, TenantService, UserTenantService
         | 
| 27 | 
             
            from web_server.settings import stat_logger
         | 
|  | |
| 50 | 
             
                        avatar = download_img(userinfo["avatar_url"])
         | 
| 51 | 
             
                    except Exception as e:
         | 
| 52 | 
             
                        stat_logger.exception(e)
         | 
| 53 | 
            +
                    user_id = get_uuid()
         | 
| 54 | 
             
                    try:
         | 
| 55 | 
            +
                        users = user_register(user_id, {
         | 
| 56 | 
             
                            "access_token": session["access_token"],
         | 
| 57 | 
             
                            "email": userinfo["email"],
         | 
| 58 | 
             
                            "avatar": avatar,
         | 
|  | |
| 67 | 
             
                        login_user(user)
         | 
| 68 | 
             
                        return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
         | 
| 69 | 
             
                    except Exception as e:
         | 
| 70 | 
            +
                        rollback_user_registration(user_id)
         | 
| 71 | 
             
                        stat_logger.exception(e)
         | 
| 72 | 
             
                        return server_error_response(e)
         | 
| 73 | 
             
                elif not request.json:
         | 
|  | |
| 167 | 
             
                return get_json_result(data=current_user.to_dict())
         | 
| 168 |  | 
| 169 |  | 
| 170 | 
            +
            def rollback_user_registration(user_id):
         | 
| 171 | 
            +
                try:
         | 
| 172 | 
            +
                    TenantService.delete_by_id(user_id)
         | 
| 173 | 
            +
                except Exception as e:
         | 
| 174 | 
            +
                    pass
         | 
| 175 | 
            +
                try:
         | 
| 176 | 
            +
                    u = UserTenantService.query(tenant_id=user_id)
         | 
| 177 | 
            +
                    if u:
         | 
| 178 | 
            +
                        UserTenantService.delete_by_id(u[0].id)
         | 
| 179 | 
            +
                except Exception as e:
         | 
| 180 | 
            +
                    pass
         | 
| 181 | 
            +
                try:
         | 
| 182 | 
            +
                    TenantLLM.delete().where(TenantLLM.tenant_id==user_id).excute()
         | 
| 183 | 
            +
                except Exception as e:
         | 
| 184 | 
            +
                    pass
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                  
         | 
| 187 | 
            +
            def user_register(user_id, user):
         | 
| 188 | 
            +
             | 
| 189 | 
             
                user_id = get_uuid()
         | 
| 190 | 
             
                user["id"] = user_id
         | 
| 191 | 
             
                tenant = {
         | 
|  | |
| 203 | 
             
                    "invited_by": user_id,
         | 
| 204 | 
             
                    "role": UserTenantRole.OWNER
         | 
| 205 | 
             
                }
         | 
| 206 | 
            +
                tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"}
         | 
| 207 |  | 
| 208 | 
             
                if not UserService.save(**user):return
         | 
| 209 | 
             
                TenantService.save(**tenant)
         | 
| 210 | 
             
                UserTenantService.save(**usr_tenant)
         | 
| 211 | 
            +
                TenantLLMService.save(**tenant_llm)
         | 
| 212 | 
             
                return UserService.query(email=user["email"])
         | 
| 213 |  | 
| 214 |  | 
|  | |
| 228 | 
             
                    "last_login_time": get_format_time(),
         | 
| 229 | 
             
                    "is_superuser": False,
         | 
| 230 | 
             
                }
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                user_id = get_uuid()
         | 
| 233 | 
             
                try:
         | 
| 234 | 
            +
                    users = user_register(user_id, user_dict)
         | 
| 235 | 
             
                    if not users: raise Exception('Register user failure.')
         | 
| 236 | 
             
                    if len(users) > 1: raise Exception('Same E-mail exist!')
         | 
| 237 | 
             
                    user = users[0]
         | 
| 238 | 
             
                    login_user(user)
         | 
| 239 | 
             
                    return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
         | 
| 240 | 
             
                except Exception as e:
         | 
| 241 | 
            +
                    rollback_user_registration(user_id)
         | 
| 242 | 
             
                    stat_logger.exception(e)
         | 
| 243 | 
             
                    return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
         | 
| 244 |  | 
|  | |
| 248 | 
             
            @login_required
         | 
| 249 | 
             
            def tenant_info():
         | 
| 250 | 
             
                try:
         | 
| 251 | 
            +
                    tenants = TenantService.get_by_user_id(current_user.id)[0]
         | 
| 252 | 
             
                    return get_json_result(data=tenants)
         | 
| 253 | 
             
                except Exception as e:
         | 
| 254 | 
             
                    return server_error_response(e)
         | 
    	
        web_server/db/db_models.py
    CHANGED
    
    | @@ -428,6 +428,7 @@ class LLMFactories(DataBaseModel): | |
| 428 | 
             
            class LLM(DataBaseModel):
         | 
| 429 | 
             
                # defautlt LLMs for every users
         | 
| 430 | 
             
                llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
         | 
|  | |
| 431 | 
             
                fid = CharField(max_length=128, null=False, help_text="LLM factory id")
         | 
| 432 | 
             
                tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
         | 
| 433 | 
             
                status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
         | 
| @@ -442,8 +443,8 @@ class LLM(DataBaseModel): | |
| 442 | 
             
            class TenantLLM(DataBaseModel):
         | 
| 443 | 
             
                tenant_id = CharField(max_length=32, null=False)
         | 
| 444 | 
             
                llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
         | 
| 445 | 
            -
                model_type = CharField(max_length=128, null= | 
| 446 | 
            -
                llm_name = CharField(max_length=128, null= | 
| 447 | 
             
                api_key = CharField(max_length=255, null=True, help_text="API KEY")
         | 
| 448 | 
             
                api_base = CharField(max_length=255, null=True, help_text="API Base")
         | 
| 449 |  | 
| @@ -452,7 +453,7 @@ class TenantLLM(DataBaseModel): | |
| 452 |  | 
| 453 | 
             
                class Meta:
         | 
| 454 | 
             
                    db_table = "tenant_llm"
         | 
| 455 | 
            -
                    primary_key = CompositeKey('tenant_id', 'llm_factory')
         | 
| 456 |  | 
| 457 |  | 
| 458 | 
             
            class Knowledgebase(DataBaseModel):
         | 
| @@ -464,7 +465,9 @@ class Knowledgebase(DataBaseModel): | |
| 464 | 
             
                permission = CharField(max_length=16, null=False, help_text="me|team")
         | 
| 465 | 
             
                created_by = CharField(max_length=32, null=False)
         | 
| 466 | 
             
                doc_num = IntegerField(default=0)
         | 
| 467 | 
            -
                 | 
|  | |
|  | |
| 468 | 
             
                parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
         | 
| 469 | 
             
                status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
         | 
| 470 |  | 
|  | |
| 428 | 
             
            class LLM(DataBaseModel):
         | 
| 429 | 
             
                # defautlt LLMs for every users
         | 
| 430 | 
             
                llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
         | 
| 431 | 
            +
                model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
         | 
| 432 | 
             
                fid = CharField(max_length=128, null=False, help_text="LLM factory id")
         | 
| 433 | 
             
                tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
         | 
| 434 | 
             
                status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
         | 
|  | |
| 443 | 
             
            class TenantLLM(DataBaseModel):
         | 
| 444 | 
             
                tenant_id = CharField(max_length=32, null=False)
         | 
| 445 | 
             
                llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
         | 
| 446 | 
            +
                model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR")
         | 
| 447 | 
            +
                llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
         | 
| 448 | 
             
                api_key = CharField(max_length=255, null=True, help_text="API KEY")
         | 
| 449 | 
             
                api_base = CharField(max_length=255, null=True, help_text="API Base")
         | 
| 450 |  | 
|  | |
| 453 |  | 
| 454 | 
             
                class Meta:
         | 
| 455 | 
             
                    db_table = "tenant_llm"
         | 
| 456 | 
            +
                    primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
         | 
| 457 |  | 
| 458 |  | 
| 459 | 
             
            class Knowledgebase(DataBaseModel):
         | 
|  | |
| 465 | 
             
                permission = CharField(max_length=16, null=False, help_text="me|team")
         | 
| 466 | 
             
                created_by = CharField(max_length=32, null=False)
         | 
| 467 | 
             
                doc_num = IntegerField(default=0)
         | 
| 468 | 
            +
                token_num = IntegerField(default=0)
         | 
| 469 | 
            +
                chunk_num = IntegerField(default=0)
         | 
| 470 | 
            +
             | 
| 471 | 
             
                parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
         | 
| 472 | 
             
                status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
         | 
| 473 |  | 
    	
        web_server/db/services/document_service.py
    CHANGED
    
    | @@ -13,12 +13,13 @@ | |
| 13 | 
             
            #  See the License for the specific language governing permissions and
         | 
| 14 | 
             
            #  limitations under the License.
         | 
| 15 | 
             
            #
         | 
|  | |
|  | |
| 16 | 
             
            from web_server.db import TenantPermission, FileType
         | 
| 17 | 
            -
            from web_server.db.db_models import DB, Knowledgebase
         | 
| 18 | 
             
            from web_server.db.db_models import Document
         | 
| 19 | 
             
            from web_server.db.services.common_service import CommonService
         | 
| 20 | 
             
            from web_server.db.services.kb_service import KnowledgebaseService
         | 
| 21 | 
            -
            from web_server.utils import get_uuid, get_format_time
         | 
| 22 | 
             
            from web_server.db.db_utils import StatusEnum
         | 
| 23 |  | 
| 24 |  | 
| @@ -61,15 +62,28 @@ class DocumentService(CommonService): | |
| 61 | 
             
                @classmethod
         | 
| 62 | 
             
                @DB.connection_context()
         | 
| 63 | 
             
                def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
         | 
| 64 | 
            -
                    fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, Knowledgebase.tenant_id]
         | 
| 65 | 
            -
                    docs = cls.model.select(fields) | 
| 66 | 
            -
                        cls.model. | 
| 67 | 
            -
                         | 
| 68 | 
            -
                         | 
| 69 | 
            -
             | 
| 70 | 
            -
             | 
| 71 | 
            -
             | 
| 72 | 
            -
             | 
| 73 | 
            -
             | 
| 74 | 
            -
                         | 
|  | |
| 75 | 
             
                    return list(docs.dicts())
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 13 | 
             
            #  See the License for the specific language governing permissions and
         | 
| 14 | 
             
            #  limitations under the License.
         | 
| 15 | 
             
            #
         | 
| 16 | 
            +
            from peewee import Expression
         | 
| 17 | 
            +
             | 
| 18 | 
             
            from web_server.db import TenantPermission, FileType
         | 
| 19 | 
            +
            from web_server.db.db_models import DB, Knowledgebase, Tenant
         | 
| 20 | 
             
            from web_server.db.db_models import Document
         | 
| 21 | 
             
            from web_server.db.services.common_service import CommonService
         | 
| 22 | 
             
            from web_server.db.services.kb_service import KnowledgebaseService
         | 
|  | |
| 23 | 
             
            from web_server.db.db_utils import StatusEnum
         | 
| 24 |  | 
| 25 |  | 
|  | |
| 62 | 
             
                @classmethod
         | 
| 63 | 
             
                @DB.connection_context()
         | 
| 64 | 
             
                def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
         | 
| 65 | 
            +
                    fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time]
         | 
| 66 | 
            +
                    docs = cls.model.select(*fields) \
         | 
| 67 | 
            +
                        .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
         | 
| 68 | 
            +
                        .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
         | 
| 69 | 
            +
                        .where(
         | 
| 70 | 
            +
                            cls.model.status == StatusEnum.VALID.value,
         | 
| 71 | 
            +
                            ~(cls.model.type == FileType.VIRTUAL.value),
         | 
| 72 | 
            +
                            cls.model.progress == 0,
         | 
| 73 | 
            +
                            cls.model.update_time >= tm,
         | 
| 74 | 
            +
                            (Expression(cls.model.create_time, "%%", comm) == mod))\
         | 
| 75 | 
            +
                        .order_by(cls.model.update_time.asc())\
         | 
| 76 | 
            +
                        .paginate(1, items_per_page)
         | 
| 77 | 
             
                    return list(docs.dicts())
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                @classmethod
         | 
| 80 | 
            +
                @DB.connection_context()
         | 
| 81 | 
            +
                def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
         | 
| 82 | 
            +
                    num = cls.model.update(token_num=cls.model.token_num + token_num,
         | 
| 83 | 
            +
                                               chunk_num=cls.model.chunk_num + chunk_num,
         | 
| 84 | 
            +
                                               process_duation=cls.model.process_duation+duation).where(
         | 
| 85 | 
            +
                        cls.model.id == doc_id).execute()
         | 
| 86 | 
            +
                    if num == 0:raise LookupError("Document not found which is supposed to be there")
         | 
| 87 | 
            +
                    num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute()
         | 
| 88 | 
            +
                    return num
         | 
| 89 | 
            +
             | 
    	
        web_server/db/services/kb_service.py
    CHANGED
    
    | @@ -17,7 +17,7 @@ import peewee | |
| 17 | 
             
            from werkzeug.security import generate_password_hash, check_password_hash
         | 
| 18 |  | 
| 19 | 
             
            from web_server.db import TenantPermission
         | 
| 20 | 
            -
            from web_server.db.db_models import DB, UserTenant
         | 
| 21 | 
             
            from web_server.db.db_models import Knowledgebase
         | 
| 22 | 
             
            from web_server.db.services.common_service import CommonService
         | 
| 23 | 
             
            from web_server.utils import get_uuid, get_format_time
         | 
| @@ -29,15 +29,42 @@ class KnowledgebaseService(CommonService): | |
| 29 |  | 
| 30 | 
             
                @classmethod
         | 
| 31 | 
             
                @DB.connection_context()
         | 
| 32 | 
            -
                def get_by_tenant_ids(cls, joined_tenant_ids, user_id, | 
|  | |
| 33 | 
             
                    kbs = cls.model.select().where(
         | 
| 34 | 
            -
                        ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == | 
| 35 | 
            -
             | 
|  | |
| 36 | 
             
                    )
         | 
| 37 | 
            -
                    if desc: | 
| 38 | 
            -
             | 
|  | |
|  | |
| 39 |  | 
| 40 | 
             
                    kbs = kbs.paginate(page_number, items_per_page)
         | 
| 41 |  | 
| 42 | 
             
                    return list(kbs.dicts())
         | 
| 43 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 17 | 
             
            from werkzeug.security import generate_password_hash, check_password_hash
         | 
| 18 |  | 
| 19 | 
             
            from web_server.db import TenantPermission
         | 
| 20 | 
            +
            from web_server.db.db_models import DB, UserTenant, Tenant
         | 
| 21 | 
             
            from web_server.db.db_models import Knowledgebase
         | 
| 22 | 
             
            from web_server.db.services.common_service import CommonService
         | 
| 23 | 
             
            from web_server.utils import get_uuid, get_format_time
         | 
|  | |
| 29 |  | 
| 30 | 
             
                @classmethod
         | 
| 31 | 
             
                @DB.connection_context()
         | 
| 32 | 
            +
                def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
         | 
| 33 | 
            +
                                      page_number, items_per_page, orderby, desc):
         | 
| 34 | 
             
                    kbs = cls.model.select().where(
         | 
| 35 | 
            +
                        ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
         | 
| 36 | 
            +
                         TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
         | 
| 37 | 
            +
                        & (cls.model.status == StatusEnum.VALID.value)
         | 
| 38 | 
             
                    )
         | 
| 39 | 
            +
                    if desc:
         | 
| 40 | 
            +
                        kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
         | 
| 41 | 
            +
                    else:
         | 
| 42 | 
            +
                        kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
         | 
| 43 |  | 
| 44 | 
             
                    kbs = kbs.paginate(page_number, items_per_page)
         | 
| 45 |  | 
| 46 | 
             
                    return list(kbs.dicts())
         | 
| 47 |  | 
| 48 | 
            +
                @classmethod
         | 
| 49 | 
            +
                @DB.connection_context()
         | 
| 50 | 
            +
                def get_detail(cls, kb_id):
         | 
| 51 | 
            +
                    fields = [
         | 
| 52 | 
            +
                        cls.model.id,
         | 
| 53 | 
            +
                        Tenant.embd_id,
         | 
| 54 | 
            +
                        cls.model.avatar,
         | 
| 55 | 
            +
                        cls.model.name,
         | 
| 56 | 
            +
                        cls.model.description,
         | 
| 57 | 
            +
                        cls.model.permission,
         | 
| 58 | 
            +
                        cls.model.doc_num,
         | 
| 59 | 
            +
                        cls.model.token_num,
         | 
| 60 | 
            +
                        cls.model.chunk_num,
         | 
| 61 | 
            +
                        cls.model.parser_id]
         | 
| 62 | 
            +
                    kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where(
         | 
| 63 | 
            +
                        (cls.model.id == kb_id),
         | 
| 64 | 
            +
                        (cls.model.status == StatusEnum.VALID.value)
         | 
| 65 | 
            +
                    )
         | 
| 66 | 
            +
                    if not kbs:
         | 
| 67 | 
            +
                        return
         | 
| 68 | 
            +
                    d = kbs[0].to_dict()
         | 
| 69 | 
            +
                    d["embd_id"] = kbs[0].tenant.embd_id
         | 
| 70 | 
            +
                    return d
         | 
    	
        web_server/db/services/llm_service.py
    CHANGED
    
    | @@ -33,3 +33,21 @@ class LLMService(CommonService): | |
| 33 |  | 
| 34 | 
             
            class TenantLLMService(CommonService):
         | 
| 35 | 
             
                model = TenantLLM
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 33 |  | 
| 34 | 
             
            class TenantLLMService(CommonService):
         | 
| 35 | 
             
                model = TenantLLM
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                @classmethod
         | 
| 38 | 
            +
                @DB.connection_context()
         | 
| 39 | 
            +
                def get_api_key(cls, tenant_id, model_type):
         | 
| 40 | 
            +
                    objs = cls.query(tenant_id=tenant_id, model_type=model_type)
         | 
| 41 | 
            +
                    if objs and len(objs)>0 and objs[0].llm_name:
         | 
| 42 | 
            +
                        return objs[0]
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
         | 
| 45 | 
            +
                    objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
         | 
| 46 | 
            +
                        (cls.model.tenant_id == tenant_id),
         | 
| 47 | 
            +
                        (cls.model.model_type == model_type),
         | 
| 48 | 
            +
                        (LLM.status == StatusEnum.VALID)
         | 
| 49 | 
            +
                    )
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                    if not objs:return
         | 
| 52 | 
            +
                    return objs[0]
         | 
| 53 | 
            +
             | 
    	
        web_server/db/services/user_service.py
    CHANGED
    
    | @@ -79,7 +79,7 @@ class TenantService(CommonService): | |
| 79 | 
             
                @classmethod
         | 
| 80 | 
             
                @DB.connection_context()
         | 
| 81 | 
             
                def get_by_user_id(cls, user_id):
         | 
| 82 | 
            -
                    fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role]
         | 
| 83 | 
             
                    return list(cls.model.select(*fields)\
         | 
| 84 | 
             
                        .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
         | 
| 85 | 
             
                        .where(cls.model.status == StatusEnum.VALID.value).dicts())
         | 
|  | |
| 79 | 
             
                @classmethod
         | 
| 80 | 
             
                @DB.connection_context()
         | 
| 81 | 
             
                def get_by_user_id(cls, user_id):
         | 
| 82 | 
            +
                    fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role]
         | 
| 83 | 
             
                    return list(cls.model.select(*fields)\
         | 
| 84 | 
             
                        .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
         | 
| 85 | 
             
                        .where(cls.model.status == StatusEnum.VALID.value).dicts())
         | 
    	
        web_server/utils/file_utils.py
    CHANGED
    
    | @@ -143,7 +143,7 @@ def filename_type(filename): | |
| 143 | 
             
                if re.match(r".*\.pdf$", filename):
         | 
| 144 | 
             
                    return FileType.PDF.value
         | 
| 145 |  | 
| 146 | 
            -
                if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$", filename):
         | 
| 147 | 
             
                    return FileType.DOC.value
         | 
| 148 |  | 
| 149 | 
             
                if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
         | 
|  | |
| 143 | 
             
                if re.match(r".*\.pdf$", filename):
         | 
| 144 | 
             
                    return FileType.PDF.value
         | 
| 145 |  | 
| 146 | 
            +
                if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
         | 
| 147 | 
             
                    return FileType.DOC.value
         | 
| 148 |  | 
| 149 | 
             
                if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
         |