yqkcn committed
Commit · f9dd38e
Parent(s): 080762f

handle nits in task_executor (#2637)
### What problem does this PR solve?

- fix typo
- fix string format
- format import

### Type of change

- [x] Refactoring

rag/svr/task_executor.py (+36 -38)

@@ -25,34 +25,31 @@ import time
 import traceback
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-
-from api.db.services.file2document_service import File2DocumentService
-from api.settings import retrievaler
-from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
-from rag.utils.storage_factory import STORAGE_IMPL
-from api.db.db_models import close_connection
-from rag.settings import database_logger, SVR_QUEUE_NAME
-from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
-from multiprocessing import Pool
-import numpy as np
-from elasticsearch_dsl import Q, Search
+from io import BytesIO
 from multiprocessing.context import TimeoutError
-from api.db.services.task_service import TaskService
-from rag.utils.es_conn import ELASTICSEARCH
 from timeit import default_timer as timer
-from rag.utils import rmSpace, findMaxTm, num_tokens_from_string

-
-from io import BytesIO
+import numpy as np
 import pandas as pd
-
-from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
+from elasticsearch_dsl import Q

 from api.db import LLMType, ParserType
 from api.db.services.document_service import DocumentService
 from api.db.services.llm_service import LLMBundle
+from api.db.services.task_service import TaskService
+from api.db.services.file2document_service import File2DocumentService
+from api.settings import retrievaler
 from api.utils.file_utils import get_project_base_directory
-from
+from api.db.db_models import close_connection
+from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
+from rag.nlp import search, rag_tokenizer
+from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
+from rag.settings import database_logger, SVR_QUEUE_NAME
+from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
+from rag.utils import rmSpace, num_tokens_from_string
+from rag.utils.es_conn import ELASTICSEARCH
+from rag.utils.redis_conn import REDIS_CONN, Payload
+from rag.utils.storage_factory import STORAGE_IMPL

 BATCH_SIZE = 64

@@ -74,11 +71,11 @@ FACTORY = {
     ParserType.KG.value: knowledge_graph
 }

-
-PAYLOAD = None
+CONSUMER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
+PAYLOAD: Payload | None = None
+

-def set_progress(task_id, from_page=0, to_page=-1,
-                 prog=None, msg="Processing..."):
+def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
     global PAYLOAD
     if prog is not None and prog < 0:
         msg = "[ERROR]" + msg
@@ -107,11 +104,11 @@ def set_progress(task_id, from_page=0, to_page=-1,


 def collect():
-    global
+    global CONSUMER_NAME, PAYLOAD
     try:
-        PAYLOAD = REDIS_CONN.get_unacked_for(
+        PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
         if not PAYLOAD:
-            PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker",
+            PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
         if not PAYLOAD:
             time.sleep(1)
             return pd.DataFrame()
@@ -159,8 +156,8 @@ def build(row):
         binary = get_storage_binary(bucket, name)
         cron_logger.info(
             "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
-    except TimeoutError
-        callback(-1,
+    except TimeoutError:
+        callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
         cron_logger.error(
             "Minio {}/{}: Fetch file from minio timeout.".format(row["location"], row["name"]))
         return
@@ -168,8 +165,7 @@ def build(row):
         if re.search("(No such file|not found)", str(e)):
             callback(-1, "Can not find file <%s> from minio. Could you try it again?" % row["name"])
         else:
-            callback(-1,
-                str(e).replace("'", ""))
+            callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
         traceback.print_exc()
         return

@@ -180,7 +176,7 @@ def build(row):
         cron_logger.info(
             "Chunking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
     except Exception as e:
-        callback(-1,
+        callback(-1, "Internal server error while chunking: %s" %
                      str(e).replace("'", ""))
         cron_logger.error(
             "Chunking {}/{}: {}".format(row["location"], row["name"], str(e)))
@@ -236,7 +232,9 @@ def init_kb(row):
         open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))


-def embedding(docs, mdl, parser_config={}, callback=None):
+def embedding(docs, mdl, parser_config=None, callback=None):
+    if parser_config is None:
+        parser_config = {}
     batch_size = 32
     tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [
         re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", d["content_with_weight"]) for d in docs]
@@ -277,7 +275,7 @@ def embedding(docs, mdl, parser_config={}, callback=None):

 def run_raptor(row, chat_mdl, embd_mdl, callback=None):
     vts, _ = embd_mdl.encode(["ok"])
-    vctr_nm = "q_%d_vec"%len(vts[0])
+    vctr_nm = "q_%d_vec" % len(vts[0])
     chunks = []
     for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
         chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
@@ -374,7 +372,7 @@ def main():

         cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
         if es_r:
-            callback(-1,
+            callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
             ELASTICSEARCH.deleteByQuery(
                 Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
             cron_logger.error(str(es_r))
@@ -392,15 +390,15 @@ def main():


 def report_status():
-    global
+    global CONSUMER_NAME
     while True:
         try:
             obj = REDIS_CONN.get("TASKEXE")
             if not obj: obj = {}
             else: obj = json.loads(obj)
-            if
-            obj[
-            obj[
+            if CONSUMER_NAME not in obj: obj[CONSUMER_NAME] = []
+            obj[CONSUMER_NAME].append(timer())
+            obj[CONSUMER_NAME] = obj[CONSUMER_NAME][-60:]
             REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
         except Exception as e:
             print("[Exception]:", str(e))
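A note on the two new module-level names in the `@@ -74,11 +71,11 @@` hunk: `CONSUMER_NAME` is derived from the process's first CLI argument, so launching several executors as, e.g., `python rag/svr/task_executor.py 1` gives each one its own queue identity (`task_consumer_1`, defaulting to `task_consumer_0`), and the annotation `PAYLOAD: Payload | None = None` makes explicit that the slot holds either a queue payload or nothing.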
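In the reworked `collect()` (hunk `@@ -107,11 +104,11 @@`), the consumer first re-reads its own unacknowledged messages via `get_unacked_for` and only then blocks for a new one via `queue_consumer`, so a task survives an executor crash. Those two methods are ragflow's own Redis wrappers; purely as a sketch of the underlying Redis Streams consumer-group pattern (the stream name, the helper, and the exact layout below are assumptions, not ragflow's API):

```python
# Sketch only: the pending-first read that collect() appears to wrap.
# Assumes redis-py >= 4 and that the stream and group already exist
# (XGROUP CREATE ... MKSTREAM); all names here are illustrative.
import redis

r = redis.Redis()
STREAM, GROUP = "rag_flow_svr_queue", "rag_flow_svr_task_broker"
CONSUMER = "task_consumer_0"

def fetch_one():
    # ID "0" re-delivers entries this consumer received but never ACKed
    # (e.g. after a crash), so no task is silently dropped.
    batch = r.xreadgroup(GROUP, CONSUMER, {STREAM: "0"}, count=1)
    if not batch or not batch[0][1]:
        # ">" asks for a brand-new entry, blocking up to one second.
        batch = r.xreadgroup(GROUP, CONSUMER, {STREAM: ">"}, count=1, block=1000)
    if not batch or not batch[0][1]:
        return None  # caller sleeps and retries, like collect()
    msg_id, fields = batch[0][1][0]
    return msg_id, fields  # ACK after processing: r.xack(STREAM, GROUP, msg_id)
```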
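The `embedding()` change (hunk `@@ -236,7 +232,9 @@`) swaps the mutable default `parser_config={}` for the `None`-sentinel idiom. Python evaluates a default value once, at `def` time, so a default dict is shared by every call that omits the argument, and any mutation leaks into later calls. A self-contained illustration (hypothetical functions, not ragflow code):

```python
def counter_bad(cfg={}):      # one shared dict, created when `def` runs
    cfg["hits"] = cfg.get("hits", 0) + 1
    return cfg["hits"]

def counter_good(cfg=None):   # the idiom the PR adopts
    if cfg is None:
        cfg = {}              # fresh dict on every call
    cfg["hits"] = cfg.get("hits", 0) + 1
    return cfg["hits"]

print(counter_bad(), counter_bad())    # 1 2  -- state leaks across calls
print(counter_good(), counter_good())  # 1 1  -- each call starts clean
```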
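Finally, `report_status()` now keeps a per-consumer heartbeat under the `TASKEXE` key: each loop appends a `timer()` sample (`timeit.default_timer`) to `obj[CONSUMER_NAME]`, trims the list to the last 60 entries, and `set_obj(..., 60*2)` gives the record a two-minute TTL. A hedged sketch of how a monitor might read that record; the plain redis-py client, the polling window, and the liveness rule are assumptions, not part of this PR:

```python
# Sketch: detect stalled executors from the TASKEXE heartbeat record.
# Assumes report_status() beats more often than the 35 s window used here
# (the hunk above does not show the loop's sleep interval).
import json
import time
import redis

r = redis.Redis()

def snapshot():
    raw = r.get("TASKEXE")
    return json.loads(raw) if raw else {}

before = snapshot()
time.sleep(35)
after = snapshot()

for name, beats in after.items():
    # A live consumer appended at least one new timer() sample, so its
    # (last-60) list differs between the two reads.
    alive = bool(beats) and beats != before.get(name)
    print(name, "alive" if alive else "stalled")
```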
|