KevinHuSh
commited on
Commit
·
1550520
1
Parent(s):
2fd3125
add local llm implementation (#119)
Browse files- Dockerfile +1 -1
- README.md +3 -3
- api/apps/__init__.py +1 -1
- api/apps/llm_app.py +1 -1
- api/db/db_models.py +1 -1
- api/db/services/knowledgebase_service.py +1 -0
- api/settings.py +8 -2
- deepdoc/parser/excel_parser.py +15 -1
- docker/nginx/nginx.conf +1 -1
- rag/app/table.py +10 -7
- rag/llm/__init__.py +6 -3
- rag/llm/chat_model.py +40 -2
- rag/llm/cv_model.py +8 -0
- rag/llm/rpc_server.py +90 -0
- rag/settings.py +1 -1
- rag/svr/task_broker.py +8 -0
- rag/svr/task_executor.py +1 -1
Dockerfile
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
FROM infiniflow/ragflow-base:v1.0
|
2 |
USER root
|
3 |
|
4 |
WORKDIR /ragflow
|
|
|
1 |
+
FROM swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow-base:v1.0
|
2 |
USER root
|
3 |
|
4 |
WORKDIR /ragflow
|
README.md
CHANGED
@@ -21,7 +21,7 @@
|
|
21 |
</a>
|
22 |
</p>
|
23 |
|
24 |
-
[
|
25 |
with reasoned and well-founded answers to your question. Clone this repository, you can deploy your own knowledge management
|
26 |
platform to empower your business with AI.
|
27 |
|
@@ -29,12 +29,12 @@ platform to empower your business with AI.
|
|
29 |
<img src="https://github.com/infiniflow/ragflow/assets/12318111/b24a7a5f-4d1d-4a30-90b1-7b0ec558b79d" width="1000"/>
|
30 |
</div>
|
31 |
|
32 |
-
# Features
|
33 |
- **Custom-build document understanding engine.** Our deep learning engine is made according to the needs of analyzing and searching various type of documents in different domain.
|
34 |
- For documents from different domain for different purpose, the engine applys different analyzing and search strategy.
|
35 |
- Easily intervene and manipulate the data proccessing procedure when things goes beyond expectation.
|
36 |
- Multi-media document understanding is supported using OCR and multi-modal LLM.
|
37 |
-
- **State-of-the-art table structure and layout recognition.** Precisely extract and understand the document including table content. [README](./deepdoc/README.md)
|
38 |
- For PDF files, layout and table structures including row, column and span of them are recognized.
|
39 |
- Put the table accrossing the pages together.
|
40 |
- Reconstruct the table structure components into html table.
|
|
|
21 |
</a>
|
22 |
</p>
|
23 |
|
24 |
+
[RagFlow](http://ragflow.io) is a knowledge management platform built on custom-build document understanding engine and LLM,
|
25 |
with reasoned and well-founded answers to your question. Clone this repository, you can deploy your own knowledge management
|
26 |
platform to empower your business with AI.
|
27 |
|
|
|
29 |
<img src="https://github.com/infiniflow/ragflow/assets/12318111/b24a7a5f-4d1d-4a30-90b1-7b0ec558b79d" width="1000"/>
|
30 |
</div>
|
31 |
|
32 |
+
# Key Features
|
33 |
- **Custom-build document understanding engine.** Our deep learning engine is made according to the needs of analyzing and searching various type of documents in different domain.
|
34 |
- For documents from different domain for different purpose, the engine applys different analyzing and search strategy.
|
35 |
- Easily intervene and manipulate the data proccessing procedure when things goes beyond expectation.
|
36 |
- Multi-media document understanding is supported using OCR and multi-modal LLM.
|
37 |
+
- **State-of-the-art table structure and layout recognition.** Precisely extract and understand the document including table content. See [README.](./deepdoc/README.md)
|
38 |
- For PDF files, layout and table structures including row, column and span of them are recognized.
|
39 |
- Put the table accrossing the pages together.
|
40 |
- Reconstruct the table structure components into html table.
|
api/apps/__init__.py
CHANGED
@@ -52,7 +52,7 @@ app.errorhandler(Exception)(server_error_response)
|
|
52 |
#app.config["LOGIN_DISABLED"] = True
|
53 |
app.config["SESSION_PERMANENT"] = False
|
54 |
app.config["SESSION_TYPE"] = "filesystem"
|
55 |
-
app.config['MAX_CONTENT_LENGTH'] =
|
56 |
|
57 |
Session(app)
|
58 |
login_manager = LoginManager()
|
|
|
52 |
#app.config["LOGIN_DISABLED"] = True
|
53 |
app.config["SESSION_PERMANENT"] = False
|
54 |
app.config["SESSION_TYPE"] = "filesystem"
|
55 |
+
app.config['MAX_CONTENT_LENGTH'] = 128 * 1024 * 1024
|
56 |
|
57 |
Session(app)
|
58 |
login_manager = LoginManager()
|
api/apps/llm_app.py
CHANGED
@@ -85,7 +85,7 @@ def my_llms():
|
|
85 |
}
|
86 |
res[o["llm_factory"]]["llm"].append({
|
87 |
"type": o["model_type"],
|
88 |
-
"name": o["
|
89 |
"used_token": o["used_tokens"]
|
90 |
})
|
91 |
return get_json_result(data=res)
|
|
|
85 |
}
|
86 |
res[o["llm_factory"]]["llm"].append({
|
87 |
"type": o["model_type"],
|
88 |
+
"name": o["llm_name"],
|
89 |
"used_token": o["used_tokens"]
|
90 |
})
|
91 |
return get_json_result(data=res)
|
api/db/db_models.py
CHANGED
@@ -520,7 +520,7 @@ class Task(DataBaseModel):
|
|
520 |
begin_at = DateTimeField(null=True)
|
521 |
process_duation = FloatField(default=0)
|
522 |
progress = FloatField(default=0)
|
523 |
-
progress_msg =
|
524 |
|
525 |
|
526 |
class Dialog(DataBaseModel):
|
|
|
520 |
begin_at = DateTimeField(null=True)
|
521 |
process_duation = FloatField(default=0)
|
522 |
progress = FloatField(default=0)
|
523 |
+
progress_msg = TextField(max_length=4096, null=True, help_text="process message", default="")
|
524 |
|
525 |
|
526 |
class Dialog(DataBaseModel):
|
api/db/services/knowledgebase_service.py
CHANGED
@@ -47,6 +47,7 @@ class KnowledgebaseService(CommonService):
|
|
47 |
Tenant.embd_id,
|
48 |
cls.model.avatar,
|
49 |
cls.model.name,
|
|
|
50 |
cls.model.description,
|
51 |
cls.model.permission,
|
52 |
cls.model.doc_num,
|
|
|
47 |
Tenant.embd_id,
|
48 |
cls.model.avatar,
|
49 |
cls.model.name,
|
50 |
+
cls.model.language,
|
51 |
cls.model.description,
|
52 |
cls.model.permission,
|
53 |
cls.model.doc_num,
|
api/settings.py
CHANGED
@@ -42,7 +42,7 @@ ERROR_REPORT = True
|
|
42 |
ERROR_REPORT_WITH_PATH = False
|
43 |
|
44 |
MAX_TIMESTAMP_INTERVAL = 60
|
45 |
-
SESSION_VALID_PERIOD = 7 * 24 * 60 * 60
|
46 |
|
47 |
REQUEST_TRY_TIMES = 3
|
48 |
REQUEST_WAIT_SEC = 2
|
@@ -69,6 +69,12 @@ default_llm = {
|
|
69 |
"image2text_model": "glm-4v",
|
70 |
"asr_model": "",
|
71 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
}
|
73 |
LLM = get_base_config("user_default_llm", {})
|
74 |
LLM_FACTORY = LLM.get("factory", "通义千问")
|
@@ -134,7 +140,7 @@ USE_AUTHENTICATION = False
|
|
134 |
USE_DATA_AUTHENTICATION = False
|
135 |
AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True
|
136 |
USE_DEFAULT_TIMEOUT = False
|
137 |
-
AUTHENTICATION_DEFAULT_TIMEOUT =
|
138 |
PRIVILEGE_COMMAND_WHITELIST = []
|
139 |
CHECK_NODES_IDENTITY = False
|
140 |
|
|
|
42 |
ERROR_REPORT_WITH_PATH = False
|
43 |
|
44 |
MAX_TIMESTAMP_INTERVAL = 60
|
45 |
+
SESSION_VALID_PERIOD = 7 * 24 * 60 * 60
|
46 |
|
47 |
REQUEST_TRY_TIMES = 3
|
48 |
REQUEST_WAIT_SEC = 2
|
|
|
69 |
"image2text_model": "glm-4v",
|
70 |
"asr_model": "",
|
71 |
},
|
72 |
+
"local": {
|
73 |
+
"chat_model": "",
|
74 |
+
"embedding_model": "",
|
75 |
+
"image2text_model": "",
|
76 |
+
"asr_model": "",
|
77 |
+
}
|
78 |
}
|
79 |
LLM = get_base_config("user_default_llm", {})
|
80 |
LLM_FACTORY = LLM.get("factory", "通义千问")
|
|
|
140 |
USE_DATA_AUTHENTICATION = False
|
141 |
AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True
|
142 |
USE_DEFAULT_TIMEOUT = False
|
143 |
+
AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s
|
144 |
PRIVILEGE_COMMAND_WHITELIST = []
|
145 |
CHECK_NODES_IDENTITY = False
|
146 |
|
deepdoc/parser/excel_parser.py
CHANGED
@@ -20,13 +20,27 @@ class HuExcelParser:
|
|
20 |
for i,c in enumerate(r):
|
21 |
if not c.value:continue
|
22 |
t = str(ti[i].value) if i < len(ti) else ""
|
23 |
-
t += (":"
|
24 |
l.append(t)
|
25 |
l = "; ".join(l)
|
26 |
if sheetname.lower().find("sheet") <0: l += " ——"+sheetname
|
27 |
res.append(l)
|
28 |
return res
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
if __name__ == "__main__":
|
32 |
psr = HuExcelParser()
|
|
|
20 |
for i,c in enumerate(r):
|
21 |
if not c.value:continue
|
22 |
t = str(ti[i].value) if i < len(ti) else ""
|
23 |
+
t += (":" if t else "") + str(c.value)
|
24 |
l.append(t)
|
25 |
l = "; ".join(l)
|
26 |
if sheetname.lower().find("sheet") <0: l += " ——"+sheetname
|
27 |
res.append(l)
|
28 |
return res
|
29 |
|
30 |
+
@staticmethod
|
31 |
+
def row_number(fnm, binary):
|
32 |
+
if fnm.split(".")[-1].lower().find("xls") >= 0:
|
33 |
+
wb = load_workbook(BytesIO(binary))
|
34 |
+
total = 0
|
35 |
+
for sheetname in wb.sheetnames:
|
36 |
+
ws = wb[sheetname]
|
37 |
+
total += len(ws.rows)
|
38 |
+
return total
|
39 |
+
|
40 |
+
if fnm.split(".")[-1].lower() in ["csv", "txt"]:
|
41 |
+
txt = binary.decode("utf-8")
|
42 |
+
return len(txt.split("\n"))
|
43 |
+
|
44 |
|
45 |
if __name__ == "__main__":
|
46 |
psr = HuExcelParser()
|
docker/nginx/nginx.conf
CHANGED
@@ -26,7 +26,7 @@ http {
|
|
26 |
keepalive_timeout 65;
|
27 |
|
28 |
#gzip on;
|
29 |
-
client_max_body_size
|
30 |
|
31 |
include /etc/nginx/conf.d/ragflow.conf;
|
32 |
}
|
|
|
26 |
keepalive_timeout 65;
|
27 |
|
28 |
#gzip on;
|
29 |
+
client_max_body_size 128M;
|
30 |
|
31 |
include /etc/nginx/conf.d/ragflow.conf;
|
32 |
}
|
rag/app/table.py
CHANGED
@@ -25,7 +25,7 @@ from deepdoc.parser import ExcelParser
|
|
25 |
|
26 |
|
27 |
class Excel(ExcelParser):
|
28 |
-
def __call__(self, fnm, binary=None, callback=None):
|
29 |
if not binary:
|
30 |
wb = load_workbook(fnm)
|
31 |
else:
|
@@ -35,6 +35,7 @@ class Excel(ExcelParser):
|
|
35 |
total += len(list(wb[sheetname].rows))
|
36 |
|
37 |
res, fails, done = [], [], 0
|
|
|
38 |
for sheetname in wb.sheetnames:
|
39 |
ws = wb[sheetname]
|
40 |
rows = list(ws.rows)
|
@@ -46,6 +47,9 @@ class Excel(ExcelParser):
|
|
46 |
rows[0]) if i not in missed]
|
47 |
data = []
|
48 |
for i, r in enumerate(rows[1:]):
|
|
|
|
|
|
|
49 |
row = [
|
50 |
cell.value for ii,
|
51 |
cell in enumerate(r) if ii not in missed]
|
@@ -111,7 +115,7 @@ def column_data_type(arr):
|
|
111 |
return arr, ty
|
112 |
|
113 |
|
114 |
-
def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
115 |
"""
|
116 |
Excel and csv(txt) format files are supported.
|
117 |
For csv or txt file, the delimiter between columns is TAB.
|
@@ -147,16 +151,15 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
|
147 |
headers = lines[0].split(kwargs.get("delimiter", "\t"))
|
148 |
rows = []
|
149 |
for i, line in enumerate(lines[1:]):
|
|
|
|
|
150 |
row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
|
151 |
if len(row) != len(headers):
|
152 |
fails.append(str(i))
|
153 |
continue
|
154 |
rows.append(row)
|
155 |
-
if len(rows) % 999 == 0:
|
156 |
-
callback(len(rows) * 0.6 / len(lines), ("Extract records: {}".format(len(rows)) + (
|
157 |
-
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
158 |
|
159 |
-
callback(0.
|
160 |
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
161 |
|
162 |
dfs = [pd.DataFrame(np.array(rows), columns=headers)]
|
@@ -209,7 +212,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
|
209 |
|
210 |
KnowledgebaseService.update_parser_config(
|
211 |
kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
|
212 |
-
callback(0.
|
213 |
|
214 |
return res
|
215 |
|
|
|
25 |
|
26 |
|
27 |
class Excel(ExcelParser):
|
28 |
+
def __call__(self, fnm, binary=None, from_page=0, to_page=10000000000, callback=None):
|
29 |
if not binary:
|
30 |
wb = load_workbook(fnm)
|
31 |
else:
|
|
|
35 |
total += len(list(wb[sheetname].rows))
|
36 |
|
37 |
res, fails, done = [], [], 0
|
38 |
+
rn = 0
|
39 |
for sheetname in wb.sheetnames:
|
40 |
ws = wb[sheetname]
|
41 |
rows = list(ws.rows)
|
|
|
47 |
rows[0]) if i not in missed]
|
48 |
data = []
|
49 |
for i, r in enumerate(rows[1:]):
|
50 |
+
rn += 1
|
51 |
+
if rn-1 < from_page:continue
|
52 |
+
if rn -1>=to_page: break
|
53 |
row = [
|
54 |
cell.value for ii,
|
55 |
cell in enumerate(r) if ii not in missed]
|
|
|
115 |
return arr, ty
|
116 |
|
117 |
|
118 |
+
def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese", callback=None, **kwargs):
|
119 |
"""
|
120 |
Excel and csv(txt) format files are supported.
|
121 |
For csv or txt file, the delimiter between columns is TAB.
|
|
|
151 |
headers = lines[0].split(kwargs.get("delimiter", "\t"))
|
152 |
rows = []
|
153 |
for i, line in enumerate(lines[1:]):
|
154 |
+
if from_page < from_page:continue
|
155 |
+
if i >= to_page: break
|
156 |
row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
|
157 |
if len(row) != len(headers):
|
158 |
fails.append(str(i))
|
159 |
continue
|
160 |
rows.append(row)
|
|
|
|
|
|
|
161 |
|
162 |
+
callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (
|
163 |
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
164 |
|
165 |
dfs = [pd.DataFrame(np.array(rows), columns=headers)]
|
|
|
212 |
|
213 |
KnowledgebaseService.update_parser_config(
|
214 |
kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
|
215 |
+
callback(0.35, "")
|
216 |
|
217 |
return res
|
218 |
|
rag/llm/__init__.py
CHANGED
@@ -19,22 +19,25 @@ from .cv_model import *
|
|
19 |
|
20 |
|
21 |
EmbeddingModel = {
|
22 |
-
"
|
23 |
"OpenAI": OpenAIEmbed,
|
24 |
"通义千问": HuEmbedding, #QWenEmbed,
|
|
|
25 |
}
|
26 |
|
27 |
|
28 |
CvModel = {
|
29 |
"OpenAI": GptV4,
|
30 |
-
"
|
31 |
"通义千问": QWenCV,
|
|
|
32 |
}
|
33 |
|
34 |
|
35 |
ChatModel = {
|
36 |
"OpenAI": GptTurbo,
|
37 |
-
"
|
38 |
"通义千问": QWenChat,
|
|
|
39 |
}
|
40 |
|
|
|
19 |
|
20 |
|
21 |
EmbeddingModel = {
|
22 |
+
"local": HuEmbedding,
|
23 |
"OpenAI": OpenAIEmbed,
|
24 |
"通义千问": HuEmbedding, #QWenEmbed,
|
25 |
+
"智谱AI": ZhipuEmbed
|
26 |
}
|
27 |
|
28 |
|
29 |
CvModel = {
|
30 |
"OpenAI": GptV4,
|
31 |
+
"local": LocalCV,
|
32 |
"通义千问": QWenCV,
|
33 |
+
"智谱AI": Zhipu4V
|
34 |
}
|
35 |
|
36 |
|
37 |
ChatModel = {
|
38 |
"OpenAI": GptTurbo,
|
39 |
+
"智谱AI": ZhipuChat,
|
40 |
"通义千问": QWenChat,
|
41 |
+
"local": LocalLLM
|
42 |
}
|
43 |
|
rag/llm/chat_model.py
CHANGED
@@ -20,6 +20,7 @@ from openai import OpenAI
|
|
20 |
import openai
|
21 |
|
22 |
from rag.nlp import is_english
|
|
|
23 |
|
24 |
|
25 |
class Base(ABC):
|
@@ -86,7 +87,6 @@ class ZhipuChat(Base):
|
|
86 |
self.model_name = model_name
|
87 |
|
88 |
def chat(self, system, history, gen_conf):
|
89 |
-
from http import HTTPStatus
|
90 |
if system: history.insert(0, {"role": "system", "content": system})
|
91 |
try:
|
92 |
response = self.client.chat.completions.create(
|
@@ -100,4 +100,42 @@ class ZhipuChat(Base):
|
|
100 |
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
101 |
return ans, response.usage.completion_tokens
|
102 |
except Exception as e:
|
103 |
-
return "**ERROR**: " + str(e), 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
import openai
|
21 |
|
22 |
from rag.nlp import is_english
|
23 |
+
from rag.utils import num_tokens_from_string
|
24 |
|
25 |
|
26 |
class Base(ABC):
|
|
|
87 |
self.model_name = model_name
|
88 |
|
89 |
def chat(self, system, history, gen_conf):
|
|
|
90 |
if system: history.insert(0, {"role": "system", "content": system})
|
91 |
try:
|
92 |
response = self.client.chat.completions.create(
|
|
|
100 |
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
101 |
return ans, response.usage.completion_tokens
|
102 |
except Exception as e:
|
103 |
+
return "**ERROR**: " + str(e), 0
|
104 |
+
|
105 |
+
class LocalLLM(Base):
|
106 |
+
class RPCProxy:
|
107 |
+
def __init__(self, host, port):
|
108 |
+
self.host = host
|
109 |
+
self.port = int(port)
|
110 |
+
self.__conn()
|
111 |
+
|
112 |
+
def __conn(self):
|
113 |
+
from multiprocessing.connection import Client
|
114 |
+
self._connection = Client((self.host, self.port), authkey=b'infiniflow-token4kevinhu')
|
115 |
+
|
116 |
+
def __getattr__(self, name):
|
117 |
+
import pickle
|
118 |
+
def do_rpc(*args, **kwargs):
|
119 |
+
for _ in range(3):
|
120 |
+
try:
|
121 |
+
self._connection.send(pickle.dumps((name, args, kwargs)))
|
122 |
+
return pickle.loads(self._connection.recv())
|
123 |
+
except Exception as e:
|
124 |
+
self.__conn()
|
125 |
+
raise Exception("RPC connection lost!")
|
126 |
+
|
127 |
+
return do_rpc
|
128 |
+
|
129 |
+
def __init__(self, key, model_name="glm-3-turbo"):
|
130 |
+
self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
|
131 |
+
|
132 |
+
def chat(self, system, history, gen_conf):
|
133 |
+
if system: history.insert(0, {"role": "system", "content": system})
|
134 |
+
try:
|
135 |
+
ans = self.client.chat(
|
136 |
+
history,
|
137 |
+
gen_conf
|
138 |
+
)
|
139 |
+
return ans, num_tokens_from_string(ans)
|
140 |
+
except Exception as e:
|
141 |
+
return "**ERROR**: " + str(e), 0
|
rag/llm/cv_model.py
CHANGED
@@ -138,3 +138,11 @@ class Zhipu4V(Base):
|
|
138 |
max_tokens=max_tokens,
|
139 |
)
|
140 |
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
max_tokens=max_tokens,
|
139 |
)
|
140 |
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
141 |
+
|
142 |
+
|
143 |
+
class LocalCV(Base):
|
144 |
+
def __init__(self, key, model_name="glm-4v", lang="Chinese"):
|
145 |
+
pass
|
146 |
+
|
147 |
+
def describe(self, image, max_tokens=1024):
|
148 |
+
return "", 0
|
rag/llm/rpc_server.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import pickle
|
3 |
+
import random
|
4 |
+
import time
|
5 |
+
from multiprocessing.connection import Listener
|
6 |
+
from threading import Thread
|
7 |
+
import torch
|
8 |
+
|
9 |
+
|
10 |
+
class RPCHandler:
|
11 |
+
def __init__(self):
|
12 |
+
self._functions = { }
|
13 |
+
|
14 |
+
def register_function(self, func):
|
15 |
+
self._functions[func.__name__] = func
|
16 |
+
|
17 |
+
def handle_connection(self, connection):
|
18 |
+
try:
|
19 |
+
while True:
|
20 |
+
# Receive a message
|
21 |
+
func_name, args, kwargs = pickle.loads(connection.recv())
|
22 |
+
# Run the RPC and send a response
|
23 |
+
try:
|
24 |
+
r = self._functions[func_name](*args,**kwargs)
|
25 |
+
connection.send(pickle.dumps(r))
|
26 |
+
except Exception as e:
|
27 |
+
connection.send(pickle.dumps(e))
|
28 |
+
except EOFError:
|
29 |
+
pass
|
30 |
+
|
31 |
+
|
32 |
+
def rpc_server(hdlr, address, authkey):
|
33 |
+
sock = Listener(address, authkey=authkey)
|
34 |
+
while True:
|
35 |
+
try:
|
36 |
+
client = sock.accept()
|
37 |
+
t = Thread(target=hdlr.handle_connection, args=(client,))
|
38 |
+
t.daemon = True
|
39 |
+
t.start()
|
40 |
+
except Exception as e:
|
41 |
+
print("【EXCEPTION】:", str(e))
|
42 |
+
|
43 |
+
|
44 |
+
models = []
|
45 |
+
tokenizer = None
|
46 |
+
|
47 |
+
def chat(messages, gen_conf):
|
48 |
+
global tokenizer
|
49 |
+
model = Model()
|
50 |
+
roles = {"system":"System", "user": "User", "assistant": "Assistant"}
|
51 |
+
line = ["{}: {}".format(roles[m["role"].lower()], m["content"]) for m in messages]
|
52 |
+
line = "\n".join(line) + "\nAssistant: "
|
53 |
+
tokens = tokenizer([line], return_tensors='pt')
|
54 |
+
tokens = {k: tokens[k].to(model.device) if isinstance(tokens[k], torch.Tensor) else tokens[k] for k in
|
55 |
+
tokens.keys()}
|
56 |
+
res = [tokenizer.decode(t) for t in model.generate(**tokens, **gen_conf)][0]
|
57 |
+
return res.split("Assistant: ")[-1]
|
58 |
+
|
59 |
+
|
60 |
+
def Model():
|
61 |
+
global models
|
62 |
+
random.seed(time.time())
|
63 |
+
return random.choice(models)
|
64 |
+
|
65 |
+
if __name__ == "__main__":
|
66 |
+
parser = argparse.ArgumentParser()
|
67 |
+
parser.add_argument("--model_name", type=str, help="Model name")
|
68 |
+
parser.add_argument("--port", default=7860, type=int, help="RPC serving port")
|
69 |
+
args = parser.parse_args()
|
70 |
+
|
71 |
+
handler = RPCHandler()
|
72 |
+
handler.register_function(chat)
|
73 |
+
|
74 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
75 |
+
from transformers.generation.utils import GenerationConfig
|
76 |
+
|
77 |
+
models = []
|
78 |
+
for _ in range(2):
|
79 |
+
m = AutoModelForCausalLM.from_pretrained(args.model_name,
|
80 |
+
device_map="auto",
|
81 |
+
torch_dtype='auto',
|
82 |
+
trust_remote_code=True)
|
83 |
+
m.generation_config = GenerationConfig.from_pretrained(args.model_name)
|
84 |
+
m.generation_config.pad_token_id = m.generation_config.eos_token_id
|
85 |
+
models.append(m)
|
86 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False,
|
87 |
+
trust_remote_code=True)
|
88 |
+
|
89 |
+
# Run the server
|
90 |
+
rpc_server(handler, ('0.0.0.0', args.port), authkey=b'infiniflow-token4kevinhu')
|
rag/settings.py
CHANGED
@@ -25,7 +25,7 @@ SUBPROCESS_STD_LOG_NAME = "std.log"
|
|
25 |
|
26 |
ES = get_base_config("es", {})
|
27 |
MINIO = decrypt_database_config(name="minio")
|
28 |
-
DOC_MAXIMUM_SIZE =
|
29 |
|
30 |
# Logger
|
31 |
LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "rag"))
|
|
|
25 |
|
26 |
ES = get_base_config("es", {})
|
27 |
MINIO = decrypt_database_config(name="minio")
|
28 |
+
DOC_MAXIMUM_SIZE = 128 * 1024 * 1024
|
29 |
|
30 |
# Logger
|
31 |
LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "rag"))
|
rag/svr/task_broker.py
CHANGED
@@ -22,6 +22,7 @@ from api.db.db_models import Task
|
|
22 |
from api.db.db_utils import bulk_insert_into_db
|
23 |
from api.db.services.task_service import TaskService
|
24 |
from deepdoc.parser import PdfParser
|
|
|
25 |
from rag.settings import cron_logger
|
26 |
from rag.utils import MINIO
|
27 |
from rag.utils import findMaxTm
|
@@ -88,6 +89,13 @@ def dispatch():
|
|
88 |
task["from_page"] = p
|
89 |
task["to_page"] = min(p + 5, e)
|
90 |
tsks.append(task)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
else:
|
92 |
tsks.append(new_task())
|
93 |
|
|
|
22 |
from api.db.db_utils import bulk_insert_into_db
|
23 |
from api.db.services.task_service import TaskService
|
24 |
from deepdoc.parser import PdfParser
|
25 |
+
from deepdoc.parser.excel_parser import HuExcelParser
|
26 |
from rag.settings import cron_logger
|
27 |
from rag.utils import MINIO
|
28 |
from rag.utils import findMaxTm
|
|
|
89 |
task["from_page"] = p
|
90 |
task["to_page"] = min(p + 5, e)
|
91 |
tsks.append(task)
|
92 |
+
elif r["parser_id"] == "table":
|
93 |
+
rn = HuExcelParser.row_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
|
94 |
+
for i in range(0, rn, 1000):
|
95 |
+
task = new_task()
|
96 |
+
task["from_page"] = i
|
97 |
+
task["to_page"] = min(i + 1000, rn)
|
98 |
+
tsks.append(task)
|
99 |
else:
|
100 |
tsks.append(new_task())
|
101 |
|
rag/svr/task_executor.py
CHANGED
@@ -184,7 +184,7 @@ def embedding(docs, mdl, parser_config={}, callback=None):
|
|
184 |
if len(cnts_) == 0: cnts_ = vts
|
185 |
else: cnts_ = np.concatenate((cnts_, vts), axis=0)
|
186 |
tk_count += c
|
187 |
-
callback(msg="")
|
188 |
cnts = cnts_
|
189 |
|
190 |
title_w = float(parser_config.get("filename_embd_weight", 0.1))
|
|
|
184 |
if len(cnts_) == 0: cnts_ = vts
|
185 |
else: cnts_ = np.concatenate((cnts_, vts), axis=0)
|
186 |
tk_count += c
|
187 |
+
callback(prog=0.7+0.2*(i+1)/len(cnts), msg="")
|
188 |
cnts = cnts_
|
189 |
|
190 |
title_w = float(parser_config.get("filename_embd_weight", 0.1))
|