Kevin Hu
committed on
Commit
·
6cdee07
1
Parent(s):
28c23a3
make task resumable (#2132)
Browse files
### What problem does this PR solve?
### Type of change
- [x] Performance Improvement
- api/db/services/dialog_service.py +1 -1
- docker/entrypoint.sh +2 -2
- rag/svr/task_executor.py +24 -11
- rag/utils/redis_conn.py +16 -2
api/db/services/dialog_service.py
CHANGED
@@ -217,7 +217,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
217 |
answer = ""
|
218 |
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
|
219 |
answer = ans
|
220 |
-
yield {"answer": answer, "reference": {}
|
221 |
yield decorate_answer(answer)
|
222 |
else:
|
223 |
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
|
|
|
217 |
answer = ""
|
218 |
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
|
219 |
answer = ans
|
220 |
+
yield {"answer": answer, "reference": {}}
|
221 |
yield decorate_answer(answer)
|
222 |
else:
|
223 |
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
|
docker/entrypoint.sh
CHANGED
@@ -11,13 +11,13 @@ fi
|
|
11 |
|
12 |
function task_exe(){
|
13 |
while [ 1 -eq 1 ];do
|
14 |
-
$PY rag/svr/task_executor.py ;
|
15 |
done
|
16 |
}
|
17 |
|
18 |
for ((i=0;i<WS;i++))
|
19 |
do
|
20 |
-
task_exe &
|
21 |
done
|
22 |
|
23 |
while [ 1 -eq 1 ];do
|
|
|
11 |
|
12 |
function task_exe(){
|
13 |
while [ 1 -eq 1 ];do
|
14 |
+
$PY rag/svr/task_executor.py $1;
|
15 |
done
|
16 |
}
|
17 |
|
18 |
for ((i=0;i<WS;i++))
|
19 |
do
|
20 |
+
task_exe $i &
|
21 |
done
|
22 |
|
23 |
while [ 1 -eq 1 ];do
|
rag/svr/task_executor.py
CHANGED
@@ -74,9 +74,12 @@ FACTORY = {
|
|
74 |
ParserType.KG.value: knowledge_graph
|
75 |
}
|
76 |
|
|
|
|
|
77 |
|
78 |
def set_progress(task_id, from_page=0, to_page=-1,
|
79 |
prog=None, msg="Processing..."):
|
|
|
80 |
if prog is not None and prog < 0:
|
81 |
msg = "[ERROR]" + msg
|
82 |
cancel = TaskService.do_cancel(task_id)
|
@@ -97,22 +100,28 @@ def set_progress(task_id, from_page=0, to_page=-1,
|
|
97 |
|
98 |
close_connection()
|
99 |
if cancel:
|
100 |
-
|
|
|
|
|
|
|
101 |
|
102 |
|
103 |
def collect():
|
|
|
104 |
try:
|
105 |
-
|
106 |
-
if not
|
|
|
|
|
107 |
time.sleep(1)
|
108 |
return pd.DataFrame()
|
109 |
except Exception as e:
|
110 |
cron_logger.error("Get task event from queue exception:" + str(e))
|
111 |
return pd.DataFrame()
|
112 |
|
113 |
-
msg =
|
114 |
-
|
115 |
-
|
116 |
|
117 |
if TaskService.do_cancel(msg["id"]):
|
118 |
cron_logger.info("Task {} has been canceled.".format(msg["id"]))
|
@@ -378,20 +387,21 @@ def main():
|
|
378 |
|
379 |
|
380 |
def report_status():
|
381 |
-
|
382 |
while True:
|
383 |
try:
|
384 |
obj = REDIS_CONN.get("TASKEXE")
|
385 |
if not obj: obj = {}
|
386 |
-
else: obj = json.
|
387 |
-
if
|
388 |
-
obj[
|
389 |
-
obj[
|
390 |
REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
|
391 |
except Exception as e:
|
392 |
print("[Exception]:", str(e))
|
393 |
time.sleep(60)
|
394 |
|
|
|
395 |
if __name__ == "__main__":
|
396 |
peewee_logger = logging.getLogger('peewee')
|
397 |
peewee_logger.propagate = False
|
@@ -403,3 +413,6 @@ if __name__ == "__main__":
|
|
403 |
|
404 |
while True:
|
405 |
main()
|
|
|
|
|
|
|
|
74 |
ParserType.KG.value: knowledge_graph
|
75 |
}
|
76 |
|
77 |
+
CONSUMEER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
|
78 |
+
PAYLOAD = None
|
79 |
|
80 |
def set_progress(task_id, from_page=0, to_page=-1,
|
81 |
prog=None, msg="Processing..."):
|
82 |
+
global PAYLOAD
|
83 |
if prog is not None and prog < 0:
|
84 |
msg = "[ERROR]" + msg
|
85 |
cancel = TaskService.do_cancel(task_id)
|
|
|
100 |
|
101 |
close_connection()
|
102 |
if cancel:
|
103 |
+
if PAYLOAD:
|
104 |
+
PAYLOAD.ack()
|
105 |
+
PAYLOAD = None
|
106 |
+
os._exit(0)
|
107 |
|
108 |
|
109 |
def collect():
|
110 |
+
global CONSUMEER_NAME, PAYLOAD
|
111 |
try:
|
112 |
+
PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMEER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
|
113 |
+
if not PAYLOAD:
|
114 |
+
PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMEER_NAME)
|
115 |
+
if not PAYLOAD:
|
116 |
time.sleep(1)
|
117 |
return pd.DataFrame()
|
118 |
except Exception as e:
|
119 |
cron_logger.error("Get task event from queue exception:" + str(e))
|
120 |
return pd.DataFrame()
|
121 |
|
122 |
+
msg = PAYLOAD.get_message()
|
123 |
+
if not msg:
|
124 |
+
return pd.DataFrame()
|
125 |
|
126 |
if TaskService.do_cancel(msg["id"]):
|
127 |
cron_logger.info("Task {} has been canceled.".format(msg["id"]))
|
|
|
387 |
|
388 |
|
389 |
def report_status():
|
390 |
+
global CONSUMEER_NAME
|
391 |
while True:
|
392 |
try:
|
393 |
obj = REDIS_CONN.get("TASKEXE")
|
394 |
if not obj: obj = {}
|
395 |
+
else: obj = json.loads(obj)
|
396 |
+
if CONSUMEER_NAME not in obj: obj[CONSUMEER_NAME] = []
|
397 |
+
obj[CONSUMEER_NAME].append(timer())
|
398 |
+
obj[CONSUMEER_NAME] = obj[CONSUMEER_NAME][-60:]
|
399 |
REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
|
400 |
except Exception as e:
|
401 |
print("[Exception]:", str(e))
|
402 |
time.sleep(60)
|
403 |
|
404 |
+
|
405 |
if __name__ == "__main__":
|
406 |
peewee_logger = logging.getLogger('peewee')
|
407 |
peewee_logger.propagate = False
|
|
|
413 |
|
414 |
while True:
|
415 |
main()
|
416 |
+
if PAYLOAD:
|
417 |
+
PAYLOAD.ack()
|
418 |
+
PAYLOAD = None
|
rag/utils/redis_conn.py
CHANGED
@@ -107,7 +107,7 @@ class RedisDB:
|
|
107 |
payload = {"message": json.dumps(message)}
|
108 |
pipeline = self.REDIS.pipeline()
|
109 |
pipeline.xadd(queue, payload)
|
110 |
-
pipeline.expire(queue, exp)
|
111 |
pipeline.execute()
|
112 |
return True
|
113 |
except Exception as e:
|
@@ -143,8 +143,22 @@ class RedisDB:
|
|
143 |
if 'key' in str(e):
|
144 |
pass
|
145 |
else:
|
146 |
-
logging.warning("[EXCEPTION]consumer" + str(queue_name) + "||" + str(e))
|
147 |
return None
|
148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
REDIS_CONN = RedisDB()
|
|
|
107 |
payload = {"message": json.dumps(message)}
|
108 |
pipeline = self.REDIS.pipeline()
|
109 |
pipeline.xadd(queue, payload)
|
110 |
+
#pipeline.expire(queue, exp)
|
111 |
pipeline.execute()
|
112 |
return True
|
113 |
except Exception as e:
|
|
|
143 |
if 'key' in str(e):
|
144 |
pass
|
145 |
else:
|
146 |
+
logging.warning("[EXCEPTION]consumer: " + str(queue_name) + "||" + str(e))
|
147 |
return None
|
148 |
|
149 |
+
def get_unacked_for(self, consumer_name, queue_name, group_name):
|
150 |
+
try:
|
151 |
+
group_info = self.REDIS.xinfo_groups(queue_name)
|
152 |
+
if not any(e["name"] == group_name for e in group_info):
|
153 |
+
return
|
154 |
+
pendings = self.REDIS.xpending_range(queue_name, group_name, min=0, max=10000000000000, count=1, consumername=consumer_name)
|
155 |
+
if not pendings: return
|
156 |
+
msg_id = pendings[0]["message_id"]
|
157 |
+
msg = self.REDIS.xrange(queue_name, min=msg_id, count=1)
|
158 |
+
_, payload = msg[0]
|
159 |
+
return Payload(self.REDIS, queue_name, group_name, msg_id, payload)
|
160 |
+
except Exception as e:
|
161 |
+
logging.warning("[EXCEPTION]xpending_range" + consumer_name + "||" + str(e))
|
162 |
+
self.__open__()
|
163 |
|
164 |
REDIS_CONN = RedisDB()
|