KevinHuSh
commited on
Commit
·
2436df2
1
Parent(s):
368b624
add raptor (#899)
Browse files### What problem does this PR solve?
#882
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- api/apps/system_app.py +2 -1
- api/db/services/document_service.py +27 -3
- api/db/services/llm_service.py +4 -0
- api/db/services/task_service.py +2 -1
- rag/llm/chat_model.py +2 -3
- rag/raptor.py +114 -0
- rag/svr/task_executor.py +82 -27
- rag/utils/redis_conn.py +11 -9
api/apps/system_app.py
CHANGED
|
@@ -60,7 +60,8 @@ def status():
|
|
| 60 |
st = timer()
|
| 61 |
try:
|
| 62 |
qinfo = REDIS_CONN.health(SVR_QUEUE_NAME)
|
| 63 |
-
res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.),
|
|
|
|
| 64 |
except Exception as e:
|
| 65 |
res["redis"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
|
| 66 |
|
|
|
|
| 60 |
st = timer()
|
| 61 |
try:
|
| 62 |
qinfo = REDIS_CONN.health(SVR_QUEUE_NAME)
|
| 63 |
+
res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.),
|
| 64 |
+
"pending": qinfo.get("pending", 0)}
|
| 65 |
except Exception as e:
|
| 66 |
res["redis"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
|
| 67 |
|
api/db/services/document_service.py
CHANGED
|
@@ -18,8 +18,10 @@ from datetime import datetime
|
|
| 18 |
from elasticsearch_dsl import Q
|
| 19 |
from peewee import fn
|
| 20 |
|
|
|
|
| 21 |
from api.settings import stat_logger
|
| 22 |
-
from api.utils import current_timestamp, get_format_time
|
|
|
|
| 23 |
from rag.utils.es_conn import ELASTICSEARCH
|
| 24 |
from rag.utils.minio_conn import MINIO
|
| 25 |
from rag.nlp import search
|
|
@@ -30,6 +32,7 @@ from api.db.db_models import Document
|
|
| 30 |
from api.db.services.common_service import CommonService
|
| 31 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 32 |
from api.db import StatusEnum
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
class DocumentService(CommonService):
|
|
@@ -110,7 +113,7 @@ class DocumentService(CommonService):
|
|
| 110 |
@classmethod
|
| 111 |
@DB.connection_context()
|
| 112 |
def get_unfinished_docs(cls):
|
| 113 |
-
fields = [cls.model.id, cls.model.process_begin_at]
|
| 114 |
docs = cls.model.select(*fields) \
|
| 115 |
.where(
|
| 116 |
cls.model.status == StatusEnum.VALID.value,
|
|
@@ -260,7 +263,12 @@ class DocumentService(CommonService):
|
|
| 260 |
prg = -1
|
| 261 |
status = TaskStatus.FAIL.value
|
| 262 |
elif finished:
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
msg = "\n".join(msg)
|
| 266 |
info = {
|
|
@@ -282,3 +290,19 @@ class DocumentService(CommonService):
|
|
| 282 |
return len(cls.model.select(cls.model.id).where(
|
| 283 |
cls.model.kb_id == kb_id).dicts())
|
| 284 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
from elasticsearch_dsl import Q
|
| 19 |
from peewee import fn
|
| 20 |
|
| 21 |
+
from api.db.db_utils import bulk_insert_into_db
|
| 22 |
from api.settings import stat_logger
|
| 23 |
+
from api.utils import current_timestamp, get_format_time, get_uuid
|
| 24 |
+
from rag.settings import SVR_QUEUE_NAME
|
| 25 |
from rag.utils.es_conn import ELASTICSEARCH
|
| 26 |
from rag.utils.minio_conn import MINIO
|
| 27 |
from rag.nlp import search
|
|
|
|
| 32 |
from api.db.services.common_service import CommonService
|
| 33 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 34 |
from api.db import StatusEnum
|
| 35 |
+
from rag.utils.redis_conn import REDIS_CONN
|
| 36 |
|
| 37 |
|
| 38 |
class DocumentService(CommonService):
|
|
|
|
| 113 |
@classmethod
|
| 114 |
@DB.connection_context()
|
| 115 |
def get_unfinished_docs(cls):
|
| 116 |
+
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg]
|
| 117 |
docs = cls.model.select(*fields) \
|
| 118 |
.where(
|
| 119 |
cls.model.status == StatusEnum.VALID.value,
|
|
|
|
| 263 |
prg = -1
|
| 264 |
status = TaskStatus.FAIL.value
|
| 265 |
elif finished:
|
| 266 |
+
if d["parser_config"].get("raptor") and d["progress_msg"].lower().find(" raptor")<0:
|
| 267 |
+
queue_raptor_tasks(d)
|
| 268 |
+
prg *= 0.98
|
| 269 |
+
msg.append("------ RAPTOR -------")
|
| 270 |
+
else:
|
| 271 |
+
status = TaskStatus.DONE.value
|
| 272 |
|
| 273 |
msg = "\n".join(msg)
|
| 274 |
info = {
|
|
|
|
| 290 |
return len(cls.model.select(cls.model.id).where(
|
| 291 |
cls.model.kb_id == kb_id).dicts())
|
| 292 |
|
| 293 |
+
|
| 294 |
+
def queue_raptor_tasks(doc):
|
| 295 |
+
def new_task():
|
| 296 |
+
nonlocal doc
|
| 297 |
+
return {
|
| 298 |
+
"id": get_uuid(),
|
| 299 |
+
"doc_id": doc["id"],
|
| 300 |
+
"from_page": 0,
|
| 301 |
+
"to_page": -1,
|
| 302 |
+
"progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing For Tree-Organized Retrieval)."
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
task = new_task()
|
| 306 |
+
bulk_insert_into_db(Task, [task], True)
|
| 307 |
+
task["type"] = "raptor"
|
| 308 |
+
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
|
api/db/services/llm_service.py
CHANGED
|
@@ -155,6 +155,10 @@ class LLMBundle(object):
|
|
| 155 |
tenant_id, llm_type, llm_name, lang=lang)
|
| 156 |
assert self.mdl, "Can't find mole for {}/{}/{}".format(
|
| 157 |
tenant_id, llm_type, llm_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
def encode(self, texts: list, batch_size=32):
|
| 160 |
emd, used_tokens = self.mdl.encode(texts, batch_size)
|
|
|
|
| 155 |
tenant_id, llm_type, llm_name, lang=lang)
|
| 156 |
assert self.mdl, "Can't find mole for {}/{}/{}".format(
|
| 157 |
tenant_id, llm_type, llm_name)
|
| 158 |
+
self.max_length = 512
|
| 159 |
+
for lm in LLMService.query(llm_name=llm_name):
|
| 160 |
+
self.max_length = lm.max_tokens
|
| 161 |
+
break
|
| 162 |
|
| 163 |
def encode(self, texts: list, batch_size=32):
|
| 164 |
emd, used_tokens = self.mdl.encode(texts, batch_size)
|
api/db/services/task_service.py
CHANGED
|
@@ -53,6 +53,7 @@ class TaskService(CommonService):
|
|
| 53 |
Knowledgebase.embd_id,
|
| 54 |
Tenant.img2txt_id,
|
| 55 |
Tenant.asr_id,
|
|
|
|
| 56 |
cls.model.update_time]
|
| 57 |
docs = cls.model.select(*fields) \
|
| 58 |
.join(Document, on=(cls.model.doc_id == Document.id)) \
|
|
@@ -159,4 +160,4 @@ def queue_tasks(doc, bucket, name):
|
|
| 159 |
DocumentService.begin2parse(doc["id"])
|
| 160 |
|
| 161 |
for t in tsks:
|
| 162 |
-
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis' status."
|
|
|
|
| 53 |
Knowledgebase.embd_id,
|
| 54 |
Tenant.img2txt_id,
|
| 55 |
Tenant.asr_id,
|
| 56 |
+
Tenant.llm_id,
|
| 57 |
cls.model.update_time]
|
| 58 |
docs = cls.model.select(*fields) \
|
| 59 |
.join(Document, on=(cls.model.doc_id == Document.id)) \
|
|
|
|
| 160 |
DocumentService.begin2parse(doc["id"])
|
| 161 |
|
| 162 |
for t in tsks:
|
| 163 |
+
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis' status."
|
rag/llm/chat_model.py
CHANGED
|
@@ -57,8 +57,7 @@ class Base(ABC):
|
|
| 57 |
stream=True,
|
| 58 |
**gen_conf)
|
| 59 |
for resp in response:
|
| 60 |
-
if
|
| 61 |
-
if not resp.choices[0].delta.content:continue
|
| 62 |
ans += resp.choices[0].delta.content
|
| 63 |
total_tokens += 1
|
| 64 |
if resp.choices[0].finish_reason == "length":
|
|
@@ -379,7 +378,7 @@ class VolcEngineChat(Base):
|
|
| 379 |
ans += resp.choices[0].message.content
|
| 380 |
yield ans
|
| 381 |
if resp.choices[0].finish_reason == "stop":
|
| 382 |
-
|
| 383 |
|
| 384 |
except Exception as e:
|
| 385 |
yield ans + "\n**ERROR**: " + str(e)
|
|
|
|
| 57 |
stream=True,
|
| 58 |
**gen_conf)
|
| 59 |
for resp in response:
|
| 60 |
+
if not resp.choices or not resp.choices[0].delta.content:continue
|
|
|
|
| 61 |
ans += resp.choices[0].delta.content
|
| 62 |
total_tokens += 1
|
| 63 |
if resp.choices[0].finish_reason == "length":
|
|
|
|
| 378 |
ans += resp.choices[0].message.content
|
| 379 |
yield ans
|
| 380 |
if resp.choices[0].finish_reason == "stop":
|
| 381 |
+
yield resp.usage.total_tokens
|
| 382 |
|
| 383 |
except Exception as e:
|
| 384 |
yield ans + "\n**ERROR**: " + str(e)
|
rag/raptor.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
#
|
| 16 |
+
import re
|
| 17 |
+
import traceback
|
| 18 |
+
from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait
|
| 19 |
+
from threading import Lock
|
| 20 |
+
from typing import Tuple
|
| 21 |
+
import umap
|
| 22 |
+
import numpy as np
|
| 23 |
+
from sklearn.mixture import GaussianMixture
|
| 24 |
+
|
| 25 |
+
from rag.utils import num_tokens_from_string, truncate
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
| 29 |
+
def __init__(self, max_cluster, llm_model, embd_model, prompt, max_token=256, threshold=0.1):
|
| 30 |
+
self._max_cluster = max_cluster
|
| 31 |
+
self._llm_model = llm_model
|
| 32 |
+
self._embd_model = embd_model
|
| 33 |
+
self._threshold = threshold
|
| 34 |
+
self._prompt = prompt
|
| 35 |
+
self._max_token = max_token
|
| 36 |
+
|
| 37 |
+
def _get_optimal_clusters(self, embeddings: np.ndarray, random_state:int):
|
| 38 |
+
max_clusters = min(self._max_cluster, len(embeddings))
|
| 39 |
+
n_clusters = np.arange(1, max_clusters)
|
| 40 |
+
bics = []
|
| 41 |
+
for n in n_clusters:
|
| 42 |
+
gm = GaussianMixture(n_components=n, random_state=random_state)
|
| 43 |
+
gm.fit(embeddings)
|
| 44 |
+
bics.append(gm.bic(embeddings))
|
| 45 |
+
optimal_clusters = n_clusters[np.argmin(bics)]
|
| 46 |
+
return optimal_clusters
|
| 47 |
+
|
| 48 |
+
def __call__(self, chunks: Tuple[str, np.ndarray], random_state, callback=None):
|
| 49 |
+
layers = [(0, len(chunks))]
|
| 50 |
+
start, end = 0, len(chunks)
|
| 51 |
+
if len(chunks) <= 1: return
|
| 52 |
+
|
| 53 |
+
def summarize(ck_idx, lock):
|
| 54 |
+
nonlocal chunks
|
| 55 |
+
try:
|
| 56 |
+
texts = [chunks[i][0] for i in ck_idx]
|
| 57 |
+
len_per_chunk = int((self._llm_model.max_length - self._max_token)/len(texts))
|
| 58 |
+
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
|
| 59 |
+
cnt = self._llm_model.chat("You're a helpful assistant.",
|
| 60 |
+
[{"role": "user", "content": self._prompt.format(cluster_content=cluster_content)}],
|
| 61 |
+
{"temperature": 0.3, "max_tokens": self._max_token}
|
| 62 |
+
)
|
| 63 |
+
cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "", cnt)
|
| 64 |
+
print("SUM:", cnt)
|
| 65 |
+
embds, _ = self._embd_model.encode([cnt])
|
| 66 |
+
with lock:
|
| 67 |
+
chunks.append((cnt, embds[0]))
|
| 68 |
+
except Exception as e:
|
| 69 |
+
print(e, flush=True)
|
| 70 |
+
traceback.print_stack(e)
|
| 71 |
+
return e
|
| 72 |
+
|
| 73 |
+
labels = []
|
| 74 |
+
while end - start > 1:
|
| 75 |
+
embeddings = [embd for _, embd in chunks[start: end]]
|
| 76 |
+
if len(embeddings) == 2:
|
| 77 |
+
summarize([start, start+1], Lock())
|
| 78 |
+
if callback:
|
| 79 |
+
callback(msg="Cluster one layer: {} -> {}".format(end-start, len(chunks)-end))
|
| 80 |
+
labels.extend([0,0])
|
| 81 |
+
layers.append((end, len(chunks)))
|
| 82 |
+
start = end
|
| 83 |
+
end = len(chunks)
|
| 84 |
+
continue
|
| 85 |
+
|
| 86 |
+
n_neighbors = int((len(embeddings) - 1) ** 0.8)
|
| 87 |
+
reduced_embeddings = umap.UMAP(
|
| 88 |
+
n_neighbors=max(2, n_neighbors), n_components=min(12, len(embeddings)-2), metric="cosine"
|
| 89 |
+
).fit_transform(embeddings)
|
| 90 |
+
n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
|
| 91 |
+
if n_clusters == 1:
|
| 92 |
+
lbls = [0 for _ in range(len(reduced_embeddings))]
|
| 93 |
+
else:
|
| 94 |
+
gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
|
| 95 |
+
gm.fit(reduced_embeddings)
|
| 96 |
+
probs = gm.predict_proba(reduced_embeddings)
|
| 97 |
+
lbls = [np.where(prob > self._threshold)[0] for prob in probs]
|
| 98 |
+
lock = Lock()
|
| 99 |
+
with ThreadPoolExecutor(max_workers=12) as executor:
|
| 100 |
+
threads = []
|
| 101 |
+
for c in range(n_clusters):
|
| 102 |
+
ck_idx = [i+start for i in range(len(lbls)) if lbls[i] == c]
|
| 103 |
+
threads.append(executor.submit(summarize, ck_idx, lock))
|
| 104 |
+
wait(threads, return_when=ALL_COMPLETED)
|
| 105 |
+
print([t.result() for t in threads])
|
| 106 |
+
|
| 107 |
+
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
|
| 108 |
+
labels.extend(lbls)
|
| 109 |
+
layers.append((end, len(chunks)))
|
| 110 |
+
if callback:
|
| 111 |
+
callback(msg="Cluster one layer: {} -> {}".format(end-start, len(chunks)-end))
|
| 112 |
+
start = end
|
| 113 |
+
end = len(chunks)
|
| 114 |
+
|
rag/svr/task_executor.py
CHANGED
|
@@ -26,20 +26,22 @@ import traceback
|
|
| 26 |
from functools import partial
|
| 27 |
|
| 28 |
from api.db.services.file2document_service import File2DocumentService
|
|
|
|
|
|
|
| 29 |
from rag.utils.minio_conn import MINIO
|
| 30 |
from api.db.db_models import close_connection
|
| 31 |
from rag.settings import database_logger, SVR_QUEUE_NAME
|
| 32 |
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
| 33 |
from multiprocessing import Pool
|
| 34 |
import numpy as np
|
| 35 |
-
from elasticsearch_dsl import Q
|
| 36 |
from multiprocessing.context import TimeoutError
|
| 37 |
from api.db.services.task_service import TaskService
|
| 38 |
from rag.utils.es_conn import ELASTICSEARCH
|
| 39 |
from timeit import default_timer as timer
|
| 40 |
-
from rag.utils import rmSpace, findMaxTm
|
| 41 |
|
| 42 |
-
from rag.nlp import search
|
| 43 |
from io import BytesIO
|
| 44 |
import pandas as pd
|
| 45 |
|
|
@@ -114,6 +116,8 @@ def collect():
|
|
| 114 |
tasks = TaskService.get_tasks(msg["id"])
|
| 115 |
assert tasks, "{} empty task!".format(msg["id"])
|
| 116 |
tasks = pd.DataFrame(tasks)
|
|
|
|
|
|
|
| 117 |
return tasks
|
| 118 |
|
| 119 |
|
|
@@ -245,6 +249,47 @@ def embedding(docs, mdl, parser_config={}, callback=None):
|
|
| 245 |
return tk_count
|
| 246 |
|
| 247 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
def main():
|
| 249 |
rows = collect()
|
| 250 |
if len(rows) == 0:
|
|
@@ -259,35 +304,45 @@ def main():
|
|
| 259 |
cron_logger.error(str(e))
|
| 260 |
continue
|
| 261 |
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
|
| 284 |
-
callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
|
| 285 |
init_kb(r)
|
| 286 |
chunk_count = len(set([c["_id"] for c in cks]))
|
| 287 |
st = timer()
|
| 288 |
es_r = ""
|
| 289 |
-
|
| 290 |
-
|
|
|
|
| 291 |
if b % 128 == 0:
|
| 292 |
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
|
| 293 |
|
|
|
|
| 26 |
from functools import partial
|
| 27 |
|
| 28 |
from api.db.services.file2document_service import File2DocumentService
|
| 29 |
+
from api.settings import retrievaler
|
| 30 |
+
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
|
| 31 |
from rag.utils.minio_conn import MINIO
|
| 32 |
from api.db.db_models import close_connection
|
| 33 |
from rag.settings import database_logger, SVR_QUEUE_NAME
|
| 34 |
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
| 35 |
from multiprocessing import Pool
|
| 36 |
import numpy as np
|
| 37 |
+
from elasticsearch_dsl import Q, Search
|
| 38 |
from multiprocessing.context import TimeoutError
|
| 39 |
from api.db.services.task_service import TaskService
|
| 40 |
from rag.utils.es_conn import ELASTICSEARCH
|
| 41 |
from timeit import default_timer as timer
|
| 42 |
+
from rag.utils import rmSpace, findMaxTm, num_tokens_from_string
|
| 43 |
|
| 44 |
+
from rag.nlp import search, rag_tokenizer
|
| 45 |
from io import BytesIO
|
| 46 |
import pandas as pd
|
| 47 |
|
|
|
|
| 116 |
tasks = TaskService.get_tasks(msg["id"])
|
| 117 |
assert tasks, "{} empty task!".format(msg["id"])
|
| 118 |
tasks = pd.DataFrame(tasks)
|
| 119 |
+
if msg.get("type", "") == "raptor":
|
| 120 |
+
tasks["task_type"] = "raptor"
|
| 121 |
return tasks
|
| 122 |
|
| 123 |
|
|
|
|
| 249 |
return tk_count
|
| 250 |
|
| 251 |
|
| 252 |
+
def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
| 253 |
+
vts, _ = embd_mdl.encode(["ok"])
|
| 254 |
+
vctr_nm = "q_%d_vec"%len(vts[0])
|
| 255 |
+
chunks = []
|
| 256 |
+
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
|
| 257 |
+
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
|
| 258 |
+
|
| 259 |
+
raptor = Raptor(
|
| 260 |
+
row["parser_config"]["raptor"].get("max_cluster", 64),
|
| 261 |
+
chat_mdl,
|
| 262 |
+
embd_mdl,
|
| 263 |
+
row["parser_config"]["raptor"]["prompt"],
|
| 264 |
+
row["parser_config"]["raptor"]["max_token"],
|
| 265 |
+
row["parser_config"]["raptor"]["threshold"]
|
| 266 |
+
)
|
| 267 |
+
original_length = len(chunks)
|
| 268 |
+
raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
|
| 269 |
+
doc = {
|
| 270 |
+
"doc_id": row["doc_id"],
|
| 271 |
+
"kb_id": [str(row["kb_id"])],
|
| 272 |
+
"docnm_kwd": row["name"],
|
| 273 |
+
"title_tks": rag_tokenizer.tokenize(row["name"])
|
| 274 |
+
}
|
| 275 |
+
res = []
|
| 276 |
+
tk_count = 0
|
| 277 |
+
for content, vctr in chunks[original_length:]:
|
| 278 |
+
d = copy.deepcopy(doc)
|
| 279 |
+
md5 = hashlib.md5()
|
| 280 |
+
md5.update((content + str(d["doc_id"])).encode("utf-8"))
|
| 281 |
+
d["_id"] = md5.hexdigest()
|
| 282 |
+
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
| 283 |
+
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
| 284 |
+
d[vctr_nm] = vctr.tolist()
|
| 285 |
+
d["content_with_weight"] = content
|
| 286 |
+
d["content_ltks"] = rag_tokenizer.tokenize(content)
|
| 287 |
+
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
| 288 |
+
res.append(d)
|
| 289 |
+
tk_count += num_tokens_from_string(content)
|
| 290 |
+
return res, tk_count
|
| 291 |
+
|
| 292 |
+
|
| 293 |
def main():
|
| 294 |
rows = collect()
|
| 295 |
if len(rows) == 0:
|
|
|
|
| 304 |
cron_logger.error(str(e))
|
| 305 |
continue
|
| 306 |
|
| 307 |
+
if r.get("task_type", "") == "raptor":
|
| 308 |
+
try:
|
| 309 |
+
chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
|
| 310 |
+
cks, tk_count = run_raptor(r, chat_mdl, embd_mdl, callback)
|
| 311 |
+
except Exception as e:
|
| 312 |
+
callback(-1, msg=str(e))
|
| 313 |
+
cron_logger.error(str(e))
|
| 314 |
+
continue
|
| 315 |
+
else:
|
| 316 |
+
st = timer()
|
| 317 |
+
cks = build(r)
|
| 318 |
+
cron_logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
|
| 319 |
+
if cks is None:
|
| 320 |
+
continue
|
| 321 |
+
if not cks:
|
| 322 |
+
callback(1., "No chunk! Done!")
|
| 323 |
+
continue
|
| 324 |
+
# TODO: exception handler
|
| 325 |
+
## set_progress(r["did"], -1, "ERROR: ")
|
| 326 |
+
callback(
|
| 327 |
+
msg="Finished slicing files(%d). Start to embedding the content." %
|
| 328 |
+
len(cks))
|
| 329 |
+
st = timer()
|
| 330 |
+
try:
|
| 331 |
+
tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
|
| 332 |
+
except Exception as e:
|
| 333 |
+
callback(-1, "Embedding error:{}".format(str(e)))
|
| 334 |
+
cron_logger.error(str(e))
|
| 335 |
+
tk_count = 0
|
| 336 |
+
cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
| 337 |
+
callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
|
| 338 |
|
|
|
|
| 339 |
init_kb(r)
|
| 340 |
chunk_count = len(set([c["_id"] for c in cks]))
|
| 341 |
st = timer()
|
| 342 |
es_r = ""
|
| 343 |
+
es_bulk_size = 16
|
| 344 |
+
for b in range(0, len(cks), es_bulk_size):
|
| 345 |
+
es_r = ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]))
|
| 346 |
if b % 128 == 0:
|
| 347 |
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
|
| 348 |
|
rag/utils/redis_conn.py
CHANGED
|
@@ -97,15 +97,17 @@ class RedisDB:
|
|
| 97 |
return False
|
| 98 |
|
| 99 |
def queue_product(self, queue, message, exp=settings.SVR_QUEUE_RETENTION) -> bool:
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
| 109 |
return False
|
| 110 |
|
| 111 |
def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> Payload:
|
|
|
|
| 97 |
return False
|
| 98 |
|
| 99 |
def queue_product(self, queue, message, exp=settings.SVR_QUEUE_RETENTION) -> bool:
|
| 100 |
+
for _ in range(3):
|
| 101 |
+
try:
|
| 102 |
+
payload = {"message": json.dumps(message)}
|
| 103 |
+
pipeline = self.REDIS.pipeline()
|
| 104 |
+
pipeline.xadd(queue, payload)
|
| 105 |
+
pipeline.expire(queue, exp)
|
| 106 |
+
pipeline.execute()
|
| 107 |
+
return True
|
| 108 |
+
except Exception as e:
|
| 109 |
+
print(e)
|
| 110 |
+
logging.warning("[EXCEPTION]producer" + str(queue) + "||" + str(e))
|
| 111 |
return False
|
| 112 |
|
| 113 |
def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> Payload:
|