yqkcn committed
Commit · f9dd38e · Parent: 080762f

handle nits in task_executor (#2637)
### What problem does this PR solve?

- fix typo
- fix string format
- format import

### Type of change

- [x] Refactoring
rag/svr/task_executor.py (CHANGED, +36 -38)
@@ -25,34 +25,31 @@ import time
 import traceback
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-
-from api.db.services.file2document_service import File2DocumentService
-from api.settings import retrievaler
-from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
-from rag.utils.storage_factory import STORAGE_IMPL
-from api.db.db_models import close_connection
-from rag.settings import database_logger, SVR_QUEUE_NAME
-from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
-from multiprocessing import Pool
-import numpy as np
-from elasticsearch_dsl import Q, Search
+from io import BytesIO
 from multiprocessing.context import TimeoutError
-from api.db.services.task_service import TaskService
-from rag.utils.es_conn import ELASTICSEARCH
 from timeit import default_timer as timer
-from rag.utils import rmSpace, findMaxTm, num_tokens_from_string
 
-
-from io import BytesIO
+import numpy as np
 import pandas as pd
-
-from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
+from elasticsearch_dsl import Q
 
 from api.db import LLMType, ParserType
 from api.db.services.document_service import DocumentService
 from api.db.services.llm_service import LLMBundle
+from api.db.services.task_service import TaskService
+from api.db.services.file2document_service import File2DocumentService
+from api.settings import retrievaler
 from api.utils.file_utils import get_project_base_directory
-from rag.utils.redis_conn import REDIS_CONN, Payload
-from rag.nlp import search, rag_tokenizer
+from api.db.db_models import close_connection
+from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
+from rag.nlp import search, rag_tokenizer
+from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
+from rag.settings import database_logger, SVR_QUEUE_NAME
+from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
+from rag.utils import rmSpace, num_tokens_from_string
+from rag.utils.es_conn import ELASTICSEARCH
+from rag.utils.redis_conn import REDIS_CONN, Payload
+from rag.utils.storage_factory import STORAGE_IMPL
 
 BATCH_SIZE = 64
 
@@ -74,11 +71,11 @@ FACTORY = {
     ParserType.KG.value: knowledge_graph
 }
 
-
-PAYLOAD = None
+CONSUMER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
+PAYLOAD: Payload | None = None
+
 
-def set_progress(task_id, from_page=0, to_page=-1,
-                 prog=None, msg="Processing..."):
+def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
     global PAYLOAD
     if prog is not None and prog < 0:
         msg = "[ERROR]" + msg
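Note on the new `CONSUMER_NAME`: each executor process now registers with Redis under its own consumer name, derived from the first CLI argument. A minimal sketch of the naming rule (the launch command shown is an assumption, not something this diff states):

```python
import sys

# Mirrors the line added above: the worker index comes from argv[1],
# so several task executors can join the same consumer group
# under distinct names.
CONSUMER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])

# Hypothetical launch: `python rag/svr/task_executor.py 3`
# gives CONSUMER_NAME == "task_consumer_3"; no argument gives "task_consumer_0".
print(CONSUMER_NAME)
```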
@@ -107,11 +104,11 @@ def set_progress(task_id, from_page=0, to_page=-1,
 
 
 def collect():
-    global PAYLOAD
+    global CONSUMER_NAME, PAYLOAD
     try:
-        PAYLOAD = REDIS_CONN.get_unacked_for(
+        PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
         if not PAYLOAD:
-            PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker",
+            PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
         if not PAYLOAD:
             time.sleep(1)
             return pd.DataFrame()
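For context on what `collect()` does here: it first tries to reclaim a message this consumer already fetched but never acknowledged (e.g. after a crash or restart), and only then consumes a fresh one; failing both, it sleeps and returns an empty DataFrame. `get_unacked_for` and `queue_consumer` are ragflow's own Redis wrappers; the sketch below reproduces the same reclaim-then-consume pattern directly on Redis streams with redis-py, purely as an illustration. The stream key value and the exact wrapper semantics are assumptions beyond what this diff shows.

```python
import redis

r = redis.Redis()
STREAM = "rag_flow_svr_queue"          # assumed value of SVR_QUEUE_NAME
GROUP = "rag_flow_svr_task_broker"
CONSUMER = "task_consumer_0"

def collect_one():
    # 1) Reclaim a message this consumer read earlier but never XACKed.
    pending = r.xpending_range(STREAM, GROUP, min="-", max="+",
                               count=1, consumername=CONSUMER)
    if pending:
        claimed = r.xclaim(STREAM, GROUP, CONSUMER, min_idle_time=0,
                           message_ids=[pending[0]["message_id"]])
        if claimed:
            return claimed[0]          # (message_id, {field: value})
    # 2) Otherwise ask the group for a never-delivered message (">").
    resp = r.xreadgroup(GROUP, CONSUMER, {STREAM: ">"}, count=1, block=1000)
    return resp[0][1][0] if resp else None
```

Acknowledging with `XACK` once the task completes is what makes step 1 a safety net rather than a source of duplicate deliveries.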
@@ -159,8 +156,8 @@ def build(row):
         binary = get_storage_binary(bucket, name)
         cron_logger.info(
             "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
-    except TimeoutError as e:
-        callback(-1,
+    except TimeoutError:
+        callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
         cron_logger.error(
             "Minio {}/{}: Fetch file from minio timeout.".format(row["location"], row["name"]))
         return
@@ -168,8 +165,7 @@ def build(row):
         if re.search("(No such file|not found)", str(e)):
             callback(-1, "Can not find file <%s> from minio. Could you try it again?" % row["name"])
         else:
-            callback(-1,
-                     str(e).replace("'", ""))
+            callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
         traceback.print_exc()
         return
 
@@ -180,7 +176,7 @@ def build(row):
         cron_logger.info(
             "Chunking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
     except Exception as e:
-        callback(-1,
+        callback(-1, "Internal server error while chunking: %s" %
                  str(e).replace("'", ""))
         cron_logger.error(
             "Chunking {}/{}: {}".format(row["location"], row["name"], str(e)))
@@ -236,7 +232,9 @@ def init_kb(row):
         open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
 
 
-def embedding(docs, mdl, parser_config={}, callback=None):
+def embedding(docs, mdl, parser_config=None, callback=None):
+    if parser_config is None:
+        parser_config = {}
     batch_size = 32
     tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [
         re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", d["content_with_weight"]) for d in docs]
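The signature change above is the standard fix for Python's mutable-default-argument pitfall: a `dict` default is created once, at function definition time, and shared by every call, so any mutation leaks into later calls. A self-contained demonstration (hypothetical functions, not ragflow code):

```python
def risky(item, bucket=[]):        # one list object, created at def time
    bucket.append(item)
    return bucket

def safe(item, bucket=None):       # the None-sentinel idiom used in the diff
    if bucket is None:
        bucket = []                # fresh list on every call
    bucket.append(item)
    return bucket

risky(1); print(risky(2))          # [1, 2]  <- state leaked across calls
safe(1);  print(safe(2))           # [2]
```

Even if `embedding()` only ever reads `parser_config`, the idiom removes the trap for future edits.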
@@ -277,7 +275,7 @@ def embedding(docs, mdl, parser_config={}, callback=None):
 
 def run_raptor(row, chat_mdl, embd_mdl, callback=None):
     vts, _ = embd_mdl.encode(["ok"])
-    vctr_nm = "q_%d_vec"%len(vts[0])
+    vctr_nm = "q_%d_vec" % len(vts[0])
     chunks = []
     for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
         chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
@@ -374,7 +372,7 @@ def main():
 
         cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
         if es_r:
-            callback(-1,
+            callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
             ELASTICSEARCH.deleteByQuery(
                 Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
             cron_logger.error(str(es_r))
@@ -392,15 +390,15 @@ def main():
 
 
 def report_status():
-    global
+    global CONSUMER_NAME
     while True:
         try:
             obj = REDIS_CONN.get("TASKEXE")
             if not obj: obj = {}
             else: obj = json.loads(obj)
-            if
-            obj[
-            obj[
+            if CONSUMER_NAME not in obj: obj[CONSUMER_NAME] = []
+            obj[CONSUMER_NAME].append(timer())
+            obj[CONSUMER_NAME] = obj[CONSUMER_NAME][-60:]
             REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
         except Exception as e:
             print("[Exception]:", str(e))
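The `report_status()` change keys the heartbeat by `CONSUMER_NAME`, appends a timestamp on each pass, and keeps only the 60 most recent entries before writing the object back with a 2-minute TTL, so the shared "TASKEXE" record stays bounded per consumer. The trimming logic in isolation (plain dict in place of Redis; the helper name is hypothetical):

```python
from timeit import default_timer as timer

def heartbeat(obj: dict, consumer: str, keep: int = 60) -> None:
    # Same shape as the diff: one bounded list of recent timestamps
    # per consumer inside the shared "TASKEXE" object.
    if consumer not in obj:
        obj[consumer] = []
    obj[consumer].append(timer())
    obj[consumer] = obj[consumer][-keep:]

status = {}
for _ in range(65):
    heartbeat(status, "task_consumer_0")
print(len(status["task_consumer_0"]))  # 60 -> older stamps are dropped
```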