zhichyu committed on
Commit 362b09b · 1 Parent(s): a1e29f7

Rework task executor heartbeat (#3430)


### What problem does this PR solve?

Rework the task executor heartbeat, and print it to the console.

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):
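
With this change every task executor pushes a JSON heartbeat snapshot into Redis and logs the same snapshot, so a line like the following appears in the console every 30 seconds (the values shown here are illustrative; the keys match the heartbeat built in the diff below):

```
task_consumer_0 reported heartbeat: {"name": "task_consumer_0", "now": "2024-11-18T10:30:00", "boot_at": "2024-11-18T09:00:00", "done": 12, "retry": 1, "pending": 3, "head_created_at": "2024-11-18T10:29:40", "head_detail": "..."}
```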

api/ragflow_server.py CHANGED
```diff
@@ -15,10 +15,8 @@
 #
 
 import logging
-import inspect
 from api.utils.log_utils import initRootLogger
-
-initRootLogger(inspect.getfile(inspect.currentframe()))
+initRootLogger("ragflow_server")
 for module in ["pdfminer"]:
     module_logger = logging.getLogger(module)
     module_logger.setLevel(logging.WARNING)
```
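
Both entry points now hand initRootLogger a stable service name instead of a path derived from inspect.getfile(...), so each process gets a predictably named logger and log file. The sketch below only illustrates that idea; it is not the actual implementation in api/utils/log_utils.py:

```python
import logging
import os


def init_root_logger_sketch(service_name: str) -> None:
    # Illustrative only: the real initRootLogger in api/utils/log_utils.py
    # may choose a different directory, format, and handler set.
    log_path = os.path.abspath(f"{service_name}.log")  # hypothetical location
    formatter = logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s")
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    for handler in (logging.FileHandler(log_path), logging.StreamHandler()):
        handler.setFormatter(formatter)
        root.addHandler(handler)  # write to the file and echo to the console
```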
rag/svr/task_executor.py CHANGED
```diff
@@ -14,9 +14,10 @@
 # limitations under the License.
 #
 import logging
-import inspect
+import sys
 from api.utils.log_utils import initRootLogger
-initRootLogger(inspect.getfile(inspect.currentframe()))
+CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
+initRootLogger(f"task_executor_{CONSUMER_NO}")
 for module in ["pdfminer"]:
     module_logger = logging.getLogger(module)
     module_logger.setLevel(logging.WARNING)
@@ -25,7 +26,7 @@ for module in ["peewee"]:
     module_logger.handlers.clear()
     module_logger.propagate = True
 
-import datetime
+from datetime import datetime
 import json
 import os
 import hashlib
@@ -33,7 +34,7 @@ import copy
 import re
 import sys
 import time
-from concurrent.futures import ThreadPoolExecutor
+import threading
 from functools import partial
 from io import BytesIO
 from multiprocessing.context import TimeoutError
@@ -78,9 +79,14 @@ FACTORY = {
     ParserType.KG.value: knowledge_graph
 }
 
-CONSUMER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
+CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
 PAYLOAD: Payload | None = None
-
+BOOT_AT = datetime.now().isoformat()
+DONE_TASKS = 0
+RETRY_TASKS = 0
+PENDING_TASKS = 0
+HEAD_CREATED_AT = ""
+HEAD_DETAIL = ""
 
 def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
     global PAYLOAD
@@ -199,8 +205,8 @@ def build(row):
         md5.update((ck["content_with_weight"] +
                     str(d["doc_id"])).encode("utf-8"))
         d["id"] = md5.hexdigest()
-        d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
-        d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
+        d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
+        d["create_timestamp_flt"] = datetime.now().timestamp()
         if not d.get("image"):
             d["img_id"] = ""
             d["page_num_list"] = json.dumps([])
@@ -333,8 +339,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
         md5 = hashlib.md5()
         md5.update((content + str(d["doc_id"])).encode("utf-8"))
         d["id"] = md5.hexdigest()
-        d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
-        d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
+        d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
+        d["create_timestamp_flt"] = datetime.now().timestamp()
         d[vctr_nm] = vctr.tolist()
         d["content_with_weight"] = content
         d["content_ltks"] = rag_tokenizer.tokenize(content)
@@ -403,7 +409,7 @@ def main():
 
         logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
         if es_r:
-            callback(-1, f"Insert chunk error, detail info please check {LOG_FILE}. Please also check ES status!")
+            callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
            docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
             logging.error('Insert chunk error: ' + str(es_r))
         else:
@@ -420,24 +426,44 @@ def main():
 
 
 def report_status():
-    global CONSUMER_NAME
+    global CONSUMER_NAME, BOOT_AT, DONE_TASKS, RETRY_TASKS, PENDING_TASKS, HEAD_CREATED_AT, HEAD_DETAIL
+    REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
     while True:
         try:
-            obj = REDIS_CONN.get("TASKEXE")
-            if not obj: obj = {}
-            else: obj = json.loads(obj)
-            if CONSUMER_NAME not in obj: obj[CONSUMER_NAME] = []
-            obj[CONSUMER_NAME].append(timer())
-            obj[CONSUMER_NAME] = obj[CONSUMER_NAME][-60:]
-            REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
+            now = datetime.now()
+            PENDING_TASKS = REDIS_CONN.queue_length(SVR_QUEUE_NAME)
+            if PENDING_TASKS > 0:
+                head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME)
+                if head_info is not None:
+                    seconds = int(head_info[0].split("-")[0])/1000
+                    HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat()
+                    HEAD_DETAIL = head_info[1]
+
+            heartbeat = json.dumps({
+                "name": CONSUMER_NAME,
+                "now": now.isoformat(),
+                "boot_at": BOOT_AT,
+                "done": DONE_TASKS,
+                "retry": RETRY_TASKS,
+                "pending": PENDING_TASKS,
+                "head_created_at": HEAD_CREATED_AT,
+                "head_detail": HEAD_DETAIL,
+            })
+            REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
+            logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
+
+            expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60*30)
+            if expired > 0:
+                REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
         except Exception:
             logging.exception("report_status got exception")
         time.sleep(30)
 
 
 if __name__ == "__main__":
-    exe = ThreadPoolExecutor(max_workers=1)
-    exe.submit(report_status)
+    background_thread = threading.Thread(target=report_status)
+    background_thread.daemon = True
+    background_thread.start()
 
     while True:
         main()
```
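
The heartbeat history lives in one sorted set per consumer: the serialized snapshot is the member and the report time is the score, which makes pruning entries older than 30 minutes a zcount/zpopmin pair. A minimal standalone sketch of the same pattern against a plain redis-py client (connection parameters and the key name are illustrative):

```python
import json
import time

import redis  # redis-py

r = redis.Redis(host="localhost", port=6379, decode_responses=True)
KEY = "task_consumer_0"  # one sorted set per consumer, as in the diff above


def push_heartbeat(snapshot: dict) -> None:
    now = time.time()
    # Member is the serialized snapshot, score is the wall-clock report time.
    r.zadd(KEY, {json.dumps(snapshot): now})
    # Count and pop everything scored more than 30 minutes in the past.
    expired = r.zcount(KEY, 0, now - 60 * 30)
    if expired > 0:
        r.zpopmin(KEY, expired)
```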
rag/utils/redis_conn.py CHANGED
```diff
@@ -90,6 +90,69 @@ class RedisDB:
             self.__open__()
         return False
 
+    def sadd(self, key: str, member: str):
+        try:
+            self.REDIS.sadd(key, member)
+            return True
+        except Exception as e:
+            logging.warning("[EXCEPTION]sadd" + str(key) + "||" + str(e))
+            self.__open__()
+        return False
+
+    def srem(self, key: str, member: str):
+        try:
+            self.REDIS.srem(key, member)
+            return True
+        except Exception as e:
+            logging.warning("[EXCEPTION]srem" + str(key) + "||" + str(e))
+            self.__open__()
+        return False
+
+    def smembers(self, key: str):
+        try:
+            res = self.REDIS.smembers(key)
+            return res
+        except Exception as e:
+            logging.warning("[EXCEPTION]smembers" + str(key) + "||" + str(e))
+            self.__open__()
+        return None
+
+    def zadd(self, key: str, member: str, score: float):
+        try:
+            self.REDIS.zadd(key, {member: score})
+            return True
+        except Exception as e:
+            logging.warning("[EXCEPTION]zadd" + str(key) + "||" + str(e))
+            self.__open__()
+        return False
+
+    def zcount(self, key: str, min: float, max: float):
+        try:
+            res = self.REDIS.zcount(key, min, max)
+            return res
+        except Exception as e:
+            logging.warning("[EXCEPTION]zcount" + str(key) + "||" + str(e))
+            self.__open__()
+        return 0
+
+    def zpopmin(self, key: str, count: int):
+        try:
+            res = self.REDIS.zpopmin(key, count)
+            return res
+        except Exception as e:
+            logging.warning("[EXCEPTION]zpopmin" + str(key) + "||" + str(e))
+            self.__open__()
+        return None
+
+    def zrangebyscore(self, key: str, min: float, max: float):
+        try:
+            res = self.REDIS.zrangebyscore(key, min, max)
+            return res
+        except Exception as e:
+            logging.warning("[EXCEPTION]zrangebyscore" + str(key) + "||" + str(e))
+            self.__open__()
+        return None
+
     def transaction(self, key, value, exp=3600):
         try:
             pipeline = self.REDIS.pipeline(transaction=True)
@@ -162,4 +225,22 @@ class RedisDB:
             logging.exception("xpending_range: " + consumer_name + " got exception")
             self.__open__()
 
+    def queue_length(self, queue) -> int:
+        for _ in range(3):
+            try:
+                num = self.REDIS.xlen(queue)
+                return num
+            except Exception:
+                logging.exception("queue_length" + str(queue) + " got exception")
+        return 0
+
+    def queue_head(self, queue):
+        for _ in range(3):
+            try:
+                ent = self.REDIS.xrange(queue, count=1)
+                return ent[0]
+            except Exception:
+                logging.exception("queue_head" + str(queue) + " got exception")
+        return None
+
 REDIS_CONN = RedisDB()
```
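
Since each consumer registers itself in the "TASKEXE" set and stores its heartbeats under its own name, a monitor can read everything back with the new smembers/zrangebyscore helpers. A hypothetical read-side helper (the function name and the 30-minute window are assumptions for illustration, not part of this PR):

```python
import json
import time

from rag.utils.redis_conn import REDIS_CONN


def recent_heartbeats(window_seconds: int = 60 * 30) -> dict:
    # Collect each registered executor's heartbeats from the last window.
    cutoff = time.time() - window_seconds
    beats = {}
    for consumer in REDIS_CONN.smembers("TASKEXE") or []:
        name = consumer.decode() if isinstance(consumer, bytes) else consumer
        entries = REDIS_CONN.zrangebyscore(name, cutoff, time.time()) or []
        beats[name] = [json.loads(e) for e in entries]
    return beats
```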