yqkcn committed on
Commit f9dd38e · 1 Parent(s): 080762f

handle nits in task_executor (#2637)


### What problem does this PR solve?

- fix typo: rename the misspelled `CONSUMEER_NAME` to `CONSUMER_NAME`
- fix string format: drop stray `f` prefixes from `%`-formatted strings (see the sketch below)
- format import: group the imports into stdlib, third-party, and project sections
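To make the string-format nit concrete: several messages were written as `f"... %s" % ...`, where the `f` prefix is a no-op (the template has no `{}` placeholders) and misleadingly mixes two formatting styles. A minimal before/after sketch:

```python
e = ValueError("boom")

# Before: the f prefix does nothing; the string is %-formatted anyway.
msg = f"Get file from minio: %s" % str(e).replace("'", "")

# After: plain %-formatting, no misleading f prefix.
msg = "Get file from minio: %s" % str(e).replace("'", "")
```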

### Type of change

- [x] Refactoring

Files changed (1)
  1. rag/svr/task_executor.py +36 -38
rag/svr/task_executor.py CHANGED
@@ -25,34 +25,31 @@ import time
 import traceback
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-
-from api.db.services.file2document_service import File2DocumentService
-from api.settings import retrievaler
-from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
-from rag.utils.storage_factory import STORAGE_IMPL
-from api.db.db_models import close_connection
-from rag.settings import database_logger, SVR_QUEUE_NAME
-from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
-from multiprocessing import Pool
-import numpy as np
-from elasticsearch_dsl import Q, Search
+from io import BytesIO
 from multiprocessing.context import TimeoutError
-from api.db.services.task_service import TaskService
-from rag.utils.es_conn import ELASTICSEARCH
 from timeit import default_timer as timer
-from rag.utils import rmSpace, findMaxTm, num_tokens_from_string
 
-from rag.nlp import search, rag_tokenizer
-from io import BytesIO
+import numpy as np
 import pandas as pd
-
-from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
+from elasticsearch_dsl import Q
 
 from api.db import LLMType, ParserType
 from api.db.services.document_service import DocumentService
 from api.db.services.llm_service import LLMBundle
+from api.db.services.task_service import TaskService
+from api.db.services.file2document_service import File2DocumentService
+from api.settings import retrievaler
 from api.utils.file_utils import get_project_base_directory
-from rag.utils.redis_conn import REDIS_CONN
+from api.db.db_models import close_connection
+from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
+from rag.nlp import search, rag_tokenizer
+from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
+from rag.settings import database_logger, SVR_QUEUE_NAME
+from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
+from rag.utils import rmSpace, num_tokens_from_string
+from rag.utils.es_conn import ELASTICSEARCH
+from rag.utils.redis_conn import REDIS_CONN, Payload
+from rag.utils.storage_factory import STORAGE_IMPL
 
 BATCH_SIZE = 64
 
@@ -74,11 +71,11 @@ FACTORY = {
     ParserType.KG.value: knowledge_graph
 }
 
-CONSUMEER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
-PAYLOAD = None
+CONSUMER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
+PAYLOAD: Payload | None = None
+
 
-def set_progress(task_id, from_page=0, to_page=-1,
-                 prog=None, msg="Processing..."):
+def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
     global PAYLOAD
     if prog is not None and prog < 0:
         msg = "[ERROR]" + msg
@@ -107,11 +104,11 @@ def set_progress(task_id, from_page=0, to_page=-1,
 
 
 def collect():
-    global CONSUMEER_NAME, PAYLOAD
+    global CONSUMER_NAME, PAYLOAD
     try:
-        PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMEER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
+        PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
         if not PAYLOAD:
-            PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMEER_NAME)
+            PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
         if not PAYLOAD:
             time.sleep(1)
             return pd.DataFrame()
@@ -159,8 +156,8 @@ def build(row):
         binary = get_storage_binary(bucket, name)
         cron_logger.info(
             "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
-    except TimeoutError as e:
-        callback(-1, f"Internal server error: Fetch file from minio timeout. Could you try it again.")
+    except TimeoutError:
+        callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
         cron_logger.error(
             "Minio {}/{}: Fetch file from minio timeout.".format(row["location"], row["name"]))
         return
@@ -168,8 +165,7 @@ def build(row):
         if re.search("(No such file|not found)", str(e)):
             callback(-1, "Can not find file <%s> from minio. Could you try it again?" % row["name"])
         else:
-            callback(-1, f"Get file from minio: %s" %
-                     str(e).replace("'", ""))
+            callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
         traceback.print_exc()
         return
 
@@ -180,7 +176,7 @@ def build(row):
         cron_logger.info(
             "Chunking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
     except Exception as e:
-        callback(-1, f"Internal server error while chunking: %s" %
+        callback(-1, "Internal server error while chunking: %s" %
                  str(e).replace("'", ""))
         cron_logger.error(
             "Chunking {}/{}: {}".format(row["location"], row["name"], str(e)))
@@ -236,7 +232,9 @@ def init_kb(row):
         open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
 
 
-def embedding(docs, mdl, parser_config={}, callback=None):
+def embedding(docs, mdl, parser_config=None, callback=None):
+    if parser_config is None:
+        parser_config = {}
     batch_size = 32
     tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [
         re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", d["content_with_weight"]) for d in docs]
@@ -277,7 +275,7 @@ def embedding(docs, mdl, parser_config={}, callback=None):
 
 def run_raptor(row, chat_mdl, embd_mdl, callback=None):
     vts, _ = embd_mdl.encode(["ok"])
-    vctr_nm = "q_%d_vec"%len(vts[0])
+    vctr_nm = "q_%d_vec" % len(vts[0])
     chunks = []
     for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
         chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
@@ -374,7 +372,7 @@ def main():
 
         cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
         if es_r:
-            callback(-1, f"Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
+            callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
             ELASTICSEARCH.deleteByQuery(
                 Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
             cron_logger.error(str(es_r))
@@ -392,15 +390,15 @@ def main():
 
 
 def report_status():
-    global CONSUMEER_NAME
+    global CONSUMER_NAME
     while True:
         try:
            obj = REDIS_CONN.get("TASKEXE")
            if not obj: obj = {}
            else: obj = json.loads(obj)
-           if CONSUMEER_NAME not in obj: obj[CONSUMEER_NAME] = []
-           obj[CONSUMEER_NAME].append(timer())
-           obj[CONSUMEER_NAME] = obj[CONSUMEER_NAME][-60:]
+           if CONSUMER_NAME not in obj: obj[CONSUMER_NAME] = []
+           obj[CONSUMER_NAME].append(timer())
+           obj[CONSUMER_NAME] = obj[CONSUMER_NAME][-60:]
            REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
         except Exception as e:
            print("[Exception]:", str(e))