Kevin Hu commited on
Commit
6cdee07
·
1 Parent(s): 28c23a3

make task resumable (#2132)

Browse files

### What problem does this PR solve?

### Type of change


- [x] Performance Improvement

api/db/services/dialog_service.py CHANGED
@@ -217,7 +217,7 @@ def chat(dialog, messages, stream=True, **kwargs):
217
  answer = ""
218
  for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
219
  answer = ans
220
- yield {"answer": answer, "reference": {}, "prompt": prompt}
221
  yield decorate_answer(answer)
222
  else:
223
  answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
 
217
  answer = ""
218
  for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
219
  answer = ans
220
+ yield {"answer": answer, "reference": {}}
221
  yield decorate_answer(answer)
222
  else:
223
  answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
docker/entrypoint.sh CHANGED
@@ -11,13 +11,13 @@ fi
11
 
12
  function task_exe(){
13
  while [ 1 -eq 1 ];do
14
- $PY rag/svr/task_executor.py ;
15
  done
16
  }
17
 
18
  for ((i=0;i<WS;i++))
19
  do
20
- task_exe &
21
  done
22
 
23
  while [ 1 -eq 1 ];do
 
11
 
12
  function task_exe(){
13
  while [ 1 -eq 1 ];do
14
+ $PY rag/svr/task_executor.py $1;
15
  done
16
  }
17
 
18
  for ((i=0;i<WS;i++))
19
  do
20
+ task_exe $i &
21
  done
22
 
23
  while [ 1 -eq 1 ];do
rag/svr/task_executor.py CHANGED
@@ -74,9 +74,12 @@ FACTORY = {
74
  ParserType.KG.value: knowledge_graph
75
  }
76
 
 
 
77
 
78
  def set_progress(task_id, from_page=0, to_page=-1,
79
  prog=None, msg="Processing..."):
 
80
  if prog is not None and prog < 0:
81
  msg = "[ERROR]" + msg
82
  cancel = TaskService.do_cancel(task_id)
@@ -97,22 +100,28 @@ def set_progress(task_id, from_page=0, to_page=-1,
97
 
98
  close_connection()
99
  if cancel:
100
- sys.exit()
 
 
 
101
 
102
 
103
  def collect():
 
104
  try:
105
- payload = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", "rag_flow_svr_task_consumer")
106
- if not payload:
 
 
107
  time.sleep(1)
108
  return pd.DataFrame()
109
  except Exception as e:
110
  cron_logger.error("Get task event from queue exception:" + str(e))
111
  return pd.DataFrame()
112
 
113
- msg = payload.get_message()
114
- payload.ack()
115
- if not msg: return pd.DataFrame()
116
 
117
  if TaskService.do_cancel(msg["id"]):
118
  cron_logger.info("Task {} has been canceled.".format(msg["id"]))
@@ -378,20 +387,21 @@ def main():
378
 
379
 
380
  def report_status():
381
- id = "0" if len(sys.argv) < 2 else sys.argv[1]
382
  while True:
383
  try:
384
  obj = REDIS_CONN.get("TASKEXE")
385
  if not obj: obj = {}
386
- else: obj = json.load(obj)
387
- if id not in obj: obj[id] = []
388
- obj[id].append(timer()*1000)
389
- obj[id] = obj[id][-60:]
390
  REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
391
  except Exception as e:
392
  print("[Exception]:", str(e))
393
  time.sleep(60)
394
 
 
395
  if __name__ == "__main__":
396
  peewee_logger = logging.getLogger('peewee')
397
  peewee_logger.propagate = False
@@ -403,3 +413,6 @@ if __name__ == "__main__":
403
 
404
  while True:
405
  main()
 
 
 
 
74
  ParserType.KG.value: knowledge_graph
75
  }
76
 
77
+ CONSUMEER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
78
+ PAYLOAD = None
79
 
80
  def set_progress(task_id, from_page=0, to_page=-1,
81
  prog=None, msg="Processing..."):
82
+ global PAYLOAD
83
  if prog is not None and prog < 0:
84
  msg = "[ERROR]" + msg
85
  cancel = TaskService.do_cancel(task_id)
 
100
 
101
  close_connection()
102
  if cancel:
103
+ if PAYLOAD:
104
+ PAYLOAD.ack()
105
+ PAYLOAD = None
106
+ os._exit(0)
107
 
108
 
109
  def collect():
110
+ global CONSUMEER_NAME, PAYLOAD
111
  try:
112
+ PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMEER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
113
+ if not PAYLOAD:
114
+ PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMEER_NAME)
115
+ if not PAYLOAD:
116
  time.sleep(1)
117
  return pd.DataFrame()
118
  except Exception as e:
119
  cron_logger.error("Get task event from queue exception:" + str(e))
120
  return pd.DataFrame()
121
 
122
+ msg = PAYLOAD.get_message()
123
+ if not msg:
124
+ return pd.DataFrame()
125
 
126
  if TaskService.do_cancel(msg["id"]):
127
  cron_logger.info("Task {} has been canceled.".format(msg["id"]))
 
387
 
388
 
389
  def report_status():
390
+ global CONSUMEER_NAME
391
  while True:
392
  try:
393
  obj = REDIS_CONN.get("TASKEXE")
394
  if not obj: obj = {}
395
+ else: obj = json.loads(obj)
396
+ if CONSUMEER_NAME not in obj: obj[CONSUMEER_NAME] = []
397
+ obj[CONSUMEER_NAME].append(timer())
398
+ obj[CONSUMEER_NAME] = obj[CONSUMEER_NAME][-60:]
399
  REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
400
  except Exception as e:
401
  print("[Exception]:", str(e))
402
  time.sleep(60)
403
 
404
+
405
  if __name__ == "__main__":
406
  peewee_logger = logging.getLogger('peewee')
407
  peewee_logger.propagate = False
 
413
 
414
  while True:
415
  main()
416
+ if PAYLOAD:
417
+ PAYLOAD.ack()
418
+ PAYLOAD = None
rag/utils/redis_conn.py CHANGED
@@ -107,7 +107,7 @@ class RedisDB:
107
  payload = {"message": json.dumps(message)}
108
  pipeline = self.REDIS.pipeline()
109
  pipeline.xadd(queue, payload)
110
- pipeline.expire(queue, exp)
111
  pipeline.execute()
112
  return True
113
  except Exception as e:
@@ -143,8 +143,22 @@ class RedisDB:
143
  if 'key' in str(e):
144
  pass
145
  else:
146
- logging.warning("[EXCEPTION]consumer" + str(queue_name) + "||" + str(e))
147
  return None
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  REDIS_CONN = RedisDB()
 
107
  payload = {"message": json.dumps(message)}
108
  pipeline = self.REDIS.pipeline()
109
  pipeline.xadd(queue, payload)
110
+ #pipeline.expire(queue, exp)
111
  pipeline.execute()
112
  return True
113
  except Exception as e:
 
143
  if 'key' in str(e):
144
  pass
145
  else:
146
+ logging.warning("[EXCEPTION]consumer: " + str(queue_name) + "||" + str(e))
147
  return None
148
 
149
+ def get_unacked_for(self, consumer_name, queue_name, group_name):
150
+ try:
151
+ group_info = self.REDIS.xinfo_groups(queue_name)
152
+ if not any(e["name"] == group_name for e in group_info):
153
+ return
154
+ pendings = self.REDIS.xpending_range(queue_name, group_name, min=0, max=10000000000000, count=1, consumername=consumer_name)
155
+ if not pendings: return
156
+ msg_id = pendings[0]["message_id"]
157
+ msg = self.REDIS.xrange(queue_name, min=msg_id, count=1)
158
+ _, payload = msg[0]
159
+ return Payload(self.REDIS, queue_name, group_name, msg_id, payload)
160
+ except Exception as e:
161
+ logging.warning("[EXCEPTION]xpending_range" + consumer_name + "||" + str(e))
162
+ self.__open__()
163
 
164
  REDIS_CONN = RedisDB()