Integration with Infinity (#2894)
### What problem does this PR solve?

Integration with Infinity:

- Replaced the global ELASTICSEARCH client with the docStoreConn abstraction
- Renamed deleteByQuery to delete
- Renamed bulk to upsertBulk
- Added getHighlight and getAggregation
- Fixed KGSearch.search
- Moved Dealer.sql_retrieval to es_conn.py

### Type of change
- [x] Refactoring
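
The rename list above maps onto a backend-neutral document-store interface that both the Elasticsearch and Infinity connectors implement. The sketch below is inferred from the calls made throughout this diff (insert, delete, update, get, indexExist, createIdx, deleteIdx, health); it approximates what lives in rag/utils/doc_store_conn.py rather than quoting it, and the exact signatures there may differ.

```python
from abc import ABC, abstractmethod


class DocStoreConnection(ABC):
    """Backend-agnostic chunk store, implemented by ESConnection and InfinityConnection."""

    @abstractmethod
    def createIdx(self, index_name: str, kb_id: str, vector_size: int):
        """Create the index/table backing one knowledge base (vector_size = embedding dimension)."""

    @abstractmethod
    def deleteIdx(self, index_name: str, kb_id: str):
        """Drop the index/table of one knowledge base."""

    @abstractmethod
    def indexExist(self, index_name: str, kb_id: str) -> bool:
        """Return True if the index/table already exists."""

    @abstractmethod
    def insert(self, rows: list[dict], index_name: str, kb_id: str):
        """Upsert a batch of chunk rows (replaces the old ELASTICSEARCH.bulk)."""

    @abstractmethod
    def update(self, condition: dict, new_value: dict, index_name: str, kb_id: str) -> bool:
        """Update every row matching the condition dict."""

    @abstractmethod
    def delete(self, condition: dict, index_name: str, kb_id: str) -> int:
        """Delete rows matching the condition dict (replaces deleteByQuery); returns the count removed."""

    @abstractmethod
    def get(self, chunk_id: str, index_name: str, kb_ids: list[str]) -> dict | None:
        """Fetch a single chunk by id, or None when it does not exist."""

    @abstractmethod
    def health(self) -> dict:
        """Backend type and status, surfaced by /system/status."""
```

Callers obtain the concrete instance from `api.settings.docStoreConn` and always pass the tenant index name plus the knowledge-base id, so each knowledge base can live in its own index or table.
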
- .github/workflows/tests.yml +1 -1
- README.md +2 -2
- README_ja.md +1 -1
- README_ko.md +1 -1
- README_zh.md +1 -1
- api/apps/api_app.py +3 -2
- api/apps/chunk_app.py +37 -41
- api/apps/document_app.py +10 -19
- api/apps/kb_app.py +5 -0
- api/apps/sdk/doc.py +70 -103
- api/apps/system_app.py +5 -4
- api/db/db_models.py +20 -20
- api/db/services/document_service.py +22 -16
- api/db/services/knowledgebase_service.py +10 -0
- api/settings.py +9 -3
- api/utils/api_utils.py +1 -9
- conf/infinity_mapping.json +26 -0
- conf/mapping.json +193 -190
- docker/.env +5 -0
- docker/docker-compose-base.yml +37 -4
- docker/docker-compose.yml +5 -4
- docs/guides/develop/launch_ragflow_from_source.md +1 -1
- docs/references/http_api_reference.md +1 -1
- docs/references/python_api_reference.md +1 -1
- graphrag/claim_extractor.py +4 -1
- graphrag/search.py +58 -63
- graphrag/smoke.py +4 -1
- poetry.lock +0 -0
- pyproject.toml +8 -6
- rag/app/presentation.py +8 -7
- rag/app/table.py +3 -3
- rag/benchmark.py +310 -280
- rag/nlp/__init__.py +13 -9
- rag/nlp/query.py +66 -46
- rag/nlp/search.py +123 -225
- rag/settings.py +4 -3
- rag/svr/task_executor.py +29 -29
- rag/utils/doc_store_conn.py +251 -0
- rag/utils/es_conn.py +342 -366
- rag/utils/infinity_conn.py +436 -0
- sdk/python/ragflow_sdk/modules/document.py +2 -2
- sdk/python/test/t_chunk.py +6 -1
.github/workflows/tests.yml (CHANGED)

@@ -78,7 +78,7 @@ jobs:
         echo "Waiting for service to be available..."
         sleep 5
       done
-      cd sdk/python && poetry install && source .venv/bin/activate && cd test && pytest t_dataset.py t_chat.py t_session.py t_document.py t_chunk.py
+      cd sdk/python && poetry install && source .venv/bin/activate && cd test && pytest --tb=short t_dataset.py t_chat.py t_session.py t_document.py t_chunk.py

   - name: Stop ragflow:dev
     if: always() # always run this step even if previous steps failed
README.md (CHANGED)

@@ -285,7 +285,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
    git clone https://github.com/infiniflow/ragflow.git
    cd ragflow/
    export POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true
-   ~/.local/bin/poetry install --sync --no-root # install RAGFlow dependent python modules
+   ~/.local/bin/poetry install --sync --no-root --with=full # install RAGFlow dependent python modules
    ```

 3. Launch the dependent services (MinIO, Elasticsearch, Redis, and MySQL) using Docker Compose:

@@ -295,7 +295,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .

    Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
    ```
-   127.0.0.1 es01 mysql minio redis
+   127.0.0.1 es01 infinity mysql minio redis
    ```
    In **docker/service_conf.yaml**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**.
README_ja.md (CHANGED)

@@ -250,7 +250,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .

    `/etc/hosts` に以下の行を追加して、**docker/service_conf.yaml** に指定されたすべてのホストを `127.0.0.1` に解決します:
    ```
-   127.0.0.1 es01 mysql minio redis
+   127.0.0.1 es01 infinity mysql minio redis
    ```
    **docker/service_conf.yaml** で mysql のポートを `5455` に、es のポートを `1200` に更新します(**docker/.env** に指定された通り).
README_ko.md (CHANGED)

@@ -254,7 +254,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .

    `/etc/hosts` 에 다음 줄을 추가하여 **docker/service_conf.yaml** 에 지정된 모든 호스트를 `127.0.0.1` 로 해결합니다:
    ```
-   127.0.0.1 es01 mysql minio redis
+   127.0.0.1 es01 infinity mysql minio redis
    ```
    **docker/service_conf.yaml** 에서 mysql 포트를 `5455` 로, es 포트를 `1200` 으로 업데이트합니다( **docker/.env** 에 지정된 대로).
README_zh.md (CHANGED)

@@ -252,7 +252,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .

    在 `/etc/hosts` 中添加以下代码,将 **docker/service_conf.yaml** 文件中的所有 host 地址都解析为 `127.0.0.1`:
    ```
-   127.0.0.1 es01 mysql minio redis
+   127.0.0.1 es01 infinity mysql minio redis
    ```
    在文件 **docker/service_conf.yaml** 中,对照 **docker/.env** 的配置将 mysql 端口更新为 `5455`,es 端口更新为 `1200`。
api/apps/api_app.py (CHANGED)

@@ -529,13 +529,14 @@ def list_chunks():
         return get_json_result(
             data=False, message="Can't find doc_name or doc_id"
         )
+    kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)

-    res = retrievaler.chunk_list(doc_id
+    res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
     res = [
         {
             "content": res_item["content_with_weight"],
             "doc_name": res_item["docnm_kwd"],
-            "
+            "image_id": res_item["img_id"]
         } for res_item in res
     ]
api/apps/chunk_app.py (CHANGED)

@@ -18,12 +18,10 @@ import json

 from flask import request
 from flask_login import login_required, current_user
-from elasticsearch_dsl import Q

 from api.db.services.dialog_service import keyword_extraction
 from rag.app.qa import rmPrefix, beAdoc
 from rag.nlp import search, rag_tokenizer
-from rag.utils.es_conn import ELASTICSEARCH
 from rag.utils import rmSpace
 from api.db import LLMType, ParserType
 from api.db.services.knowledgebase_service import KnowledgebaseService

@@ -31,12 +29,11 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import UserTenantService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db.services.document_service import DocumentService
-from api.settings import RetCode, retrievaler, kg_retrievaler
+from api.settings import RetCode, retrievaler, kg_retrievaler, docStoreConn
 from api.utils.api_utils import get_json_result
 import hashlib
 import re

-
 @manager.route('/list', methods=['POST'])
 @login_required
 @validate_request("doc_id")

@@ -53,12 +50,13 @@ def list_chunk():
         e, doc = DocumentService.get_by_id(doc_id)
         if not e:
             return get_data_error_result(message="Document not found!")
+        kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
         query = {
             "doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
         }
         if "available_int" in req:
             query["available_int"] = int(req["available_int"])
-        sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
+        sres = retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
         res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
         for id in sres.ids:
             d = {

@@ -69,16 +67,12 @@ def list_chunk():
                 "doc_id": sres.field[id]["doc_id"],
                 "docnm_kwd": sres.field[id]["docnm_kwd"],
                 "important_kwd": sres.field[id].get("important_kwd", []),
-                "
+                "image_id": sres.field[id].get("img_id", ""),
                 "available_int": sres.field[id].get("available_int", 1),
-                "positions": sres.field[id].get("
+                "positions": json.loads(sres.field[id].get("position_list", "[]")),
             }
-            for i in range(0, len(d["positions"]), 5):
-                poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
-                             float(d["positions"][i + 3]), float(d["positions"][i + 4])])
-            d["positions"] = poss
+            assert isinstance(d["positions"], list)
+            assert len(d["positions"])==0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
             res["chunks"].append(d)
         return get_json_result(data=res)
     except Exception as e:

@@ -96,22 +90,20 @@ def get():
         tenants = UserTenantService.query(user_id=current_user.id)
         if not tenants:
             return get_data_error_result(message="Tenant not found!")
+        tenant_id = tenants[0].tenant_id
+
+        kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
+        chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
+        if chunk is None:
             return server_error_response("Chunk not found")
-        id = res["_id"]
-        res = res["_source"]
-        res["chunk_id"] = id
         k = []
-        for n in
+        for n in chunk.keys():
             if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
                 k.append(n)
         for n in k:
-            del
+            del chunk[n]

-        return get_json_result(data=
+        return get_json_result(data=chunk)
     except Exception as e:
         if str(e).find("NotFoundError") >= 0:
             return get_json_result(data=False, message='Chunk not found!',

@@ -162,7 +154,7 @@ def set():
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
         v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
-
+        docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)

@@ -174,11 +166,11 @@ def set():
 def switch():
     req = request.json
     try:
-        if not
-            return get_data_error_result(message="
-        if not
-                search.index_name(tenant_id)):
+        e, doc = DocumentService.get_by_id(req["doc_id"])
+        if not e:
+            return get_data_error_result(message="Document not found!")
+        if not docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
+                                   search.index_name(doc.tenant_id), doc.kb_id):
             return get_data_error_result(message="Index updating failure")
         return get_json_result(data=True)
     except Exception as e:

@@ -191,12 +183,11 @@ def switch():
 def rm():
     req = request.json
     try:
-        if not ELASTICSEARCH.deleteByQuery(
-                Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
-            return get_data_error_result(message="Index updating failure")
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
             return get_data_error_result(message="Document not found!")
+        if not docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
+            return get_data_error_result(message="Index updating failure")
         deleted_chunk_ids = req["chunk_ids"]
         chunk_number = len(deleted_chunk_ids)
         DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)

@@ -239,7 +230,7 @@ def create():
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
         v = 0.1 * v[0] + 0.9 * v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
-
+        docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)

         DocumentService.increment_chunk_num(
             doc.id, doc.kb_id, c, 1, 0)

@@ -256,8 +247,9 @@ def retrieval_test():
     page = int(req.get("page", 1))
     size = int(req.get("size", 30))
     question = req["question"]
-
-    if isinstance(
+    kb_ids = req["kb_id"]
+    if isinstance(kb_ids, str):
+        kb_ids = [kb_ids]
     doc_ids = req.get("doc_ids", [])
     similarity_threshold = float(req.get("similarity_threshold", 0.0))
     vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))

@@ -265,17 +257,17 @@ def retrieval_test():

     try:
         tenants = UserTenantService.query(user_id=current_user.id)
-        for
+        for kb_id in kb_ids:
             for tenant in tenants:
                 if KnowledgebaseService.query(
-                        tenant_id=tenant.tenant_id, id=
+                        tenant_id=tenant.tenant_id, id=kb_id):
                     break
             else:
                 return get_json_result(
                     data=False, message='Only owner of knowledgebase authorized for this operation.',
                     code=RetCode.OPERATING_ERROR)

-        e, kb = KnowledgebaseService.get_by_id(
+        e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
         if not e:
             return get_data_error_result(message="Knowledgebase not found!")

@@ -290,7 +282,7 @@ def retrieval_test():
             question += keyword_extraction(chat_mdl, question)

         retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
-        ranks = retr.retrieval(question, embd_mdl, kb.tenant_id,
+        ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
                                similarity_threshold, vector_similarity_weight, top,
                                doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
         for c in ranks["chunks"]:

@@ -309,12 +301,16 @@
 @login_required
 def knowledge_graph():
     doc_id = request.args["doc_id"]
+    e, doc = DocumentService.get_by_id(doc_id)
+    if not e:
+        return get_data_error_result(message="Document not found!")
+    tenant_id = DocumentService.get_tenant_id(doc_id)
+    kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
     req = {
         "doc_ids":[doc_id],
         "knowledge_graph_kwd": ["graph", "mind_map"]
     }
-
-    sres = retrievaler.search(req, search.index_name(tenant_id))
+    sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids, doc.kb_id)
     obj = {"graph": {}, "mind_map": {}}
     for id in sres.ids[:2]:
         ty = sres.field[id]["knowledge_graph_kwd"]
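
The pattern in chunk_app.py repeats across the other apps below: elasticsearch_dsl query objects against a global client are replaced by plain condition dictionaries passed to the shared connection, with the knowledge-base id as an explicit scope. A hedged before/after sketch of that migration (the helper name remove_chunks is illustrative, not taken from the diff):

```python
# Before: Elasticsearch-specific query objects against a global client.
# from elasticsearch_dsl import Q
# ELASTICSEARCH.deleteByQuery(Q("ids", values=chunk_ids), search.index_name(tenant_id))

# After: a backend-neutral condition dict, scoped to the tenant index and knowledge base.
from api.settings import docStoreConn
from rag.nlp import search


def remove_chunks(tenant_id: str, kb_id: str, chunk_ids: list[str]) -> bool:
    # {"id": [...]} matches chunks by id; {"doc_id": ...} would match a whole document.
    deleted = docStoreConn.delete({"id": chunk_ids}, search.index_name(tenant_id), kb_id)
    return deleted > 0
```
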
api/apps/document_app.py (CHANGED)

@@ -17,7 +17,6 @@ import pathlib
 import re

 import flask
-from elasticsearch_dsl import Q
 from flask import request
 from flask_login import login_required, current_user

@@ -27,14 +26,13 @@ from api.db.services.file_service import FileService
 from api.db.services.task_service import TaskService, queue_tasks
 from api.db.services.user_service import UserTenantService
 from rag.nlp import search
-from rag.utils.es_conn import ELASTICSEARCH
 from api.db.services import duplicate_name
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.db import FileType, TaskStatus, ParserType, FileSource
 from api.db.services.document_service import DocumentService, doc_upload_and_parse
-from api.settings import RetCode
+from api.settings import RetCode, docStoreConn
 from api.utils.api_utils import get_json_result
 from rag.utils.storage_factory import STORAGE_IMPL
 from api.utils.file_utils import filename_type, thumbnail

@@ -275,18 +273,8 @@ def change_status():
             return get_data_error_result(
                 message="Database error (Document update)!")

-                scripts="ctx._source.available_int=0;",
-                idxnm=search.index_name(
-                    kb.tenant_id)
-            )
-        else:
-            ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
-                scripts="ctx._source.available_int=1;",
-                idxnm=search.index_name(
-                    kb.tenant_id)
-            )
+        status = int(req["status"])
+        docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)

@@ -365,8 +353,11 @@ def run():
         tenant_id = DocumentService.get_tenant_id(id)
         if not tenant_id:
             return get_data_error_result(message="Tenant not found!")
-
-
+        e, doc = DocumentService.get_by_id(id)
+        if not e:
+            return get_data_error_result(message="Document not found!")
+        if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+            docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)

         if str(req["run"]) == TaskStatus.RUNNING.value:
             TaskService.filter_delete([Task.doc_id == id])

@@ -490,8 +481,8 @@ def change_parser():
         tenant_id = DocumentService.get_tenant_id(req["doc_id"])
         if not tenant_id:
             return get_data_error_result(message="Tenant not found!")
-
-
+        if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+            docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)

         return get_json_result(data=True)
     except Exception as e:
api/apps/kb_app.py (CHANGED)

@@ -28,6 +28,8 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.db_models import File
 from api.settings import RetCode
 from api.utils.api_utils import get_json_result
+from api.settings import docStoreConn
+from rag.nlp import search


 @manager.route('/create', methods=['post'])

@@ -166,6 +168,9 @@ def rm():
         if not KnowledgebaseService.delete_by_id(req["kb_id"]):
             return get_data_error_result(
                 message="Database error (Knowledgebase removal)!")
+        tenants = UserTenantService.query(user_id=current_user.id)
+        for tenant in tenants:
+            docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
api/apps/sdk/doc.py (CHANGED)

@@ -30,7 +30,6 @@ from api.db.services.task_service import TaskService, queue_tasks
 from api.utils.api_utils import server_error_response
 from api.utils.api_utils import get_result, get_error_data_result
 from io import BytesIO
-from elasticsearch_dsl import Q
 from flask import request, send_file
 from api.db import FileSource, TaskStatus, FileType
 from api.db.db_models import File

@@ -42,7 +41,7 @@ from api.settings import RetCode, retrievaler
 from api.utils.api_utils import construct_json_result, get_parser_config
 from rag.nlp import search
 from rag.utils import rmSpace
-from
+from api.settings import docStoreConn
 from rag.utils.storage_factory import STORAGE_IMPL
 import os

@@ -293,9 +292,7 @@ def update_doc(tenant_id, dataset_id, document_id):
     )
     if not e:
         return get_error_data_result(message="Document not found!")
-        Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id)
-    )
+    docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)

     return get_result()

@@ -647,9 +644,7 @@ def parse(tenant_id, dataset_id):
         info["chunk_num"] = 0
         info["token_num"] = 0
         DocumentService.update_by_id(id, info)
-            Q("match", doc_id=id), idxnm=search.index_name(tenant_id)
-        )
+        docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
         TaskService.filter_delete([Task.doc_id == id])
         e, doc = DocumentService.get_by_id(id)
         doc = doc.to_dict()

@@ -713,9 +708,7 @@ def stop_parsing(tenant_id, dataset_id):
         )
         info = {"run": "2", "progress": 0, "chunk_num": 0}
         DocumentService.update_by_id(id, info)
-            Q("match", doc_id=id), idxnm=search.index_name(tenant_id)
-        )
+        docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
     return get_result()

@@ -812,7 +805,6 @@ def list_chunks(tenant_id, dataset_id, document_id):
         "question": question,
         "sort": True,
     }
-    sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
     key_mapping = {
         "chunk_num": "chunk_count",
         "kb_id": "dataset_id",

@@ -833,51 +825,56 @@ def list_chunks(tenant_id, dataset_id, document_id):
             renamed_doc[new_key] = value
         if key == "run":
             renamed_doc["run"] = run_mapping.get(str(value))
+
+    res = {"total": 0, "chunks": [], "doc": renamed_doc}
     origin_chunks = []
+    if docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
+        sres = retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
+        res["total"] = sres.total
+        sign = 0
+        for id in sres.ids:
+            d = {
+                "id": id,
+                "content_with_weight": (
+                    rmSpace(sres.highlight[id])
+                    if question and id in sres.highlight
+                    else sres.field[id].get("content_with_weight", "")
+                ),
+                "doc_id": sres.field[id]["doc_id"],
+                "docnm_kwd": sres.field[id]["docnm_kwd"],
+                "important_kwd": sres.field[id].get("important_kwd", []),
+                "img_id": sres.field[id].get("img_id", ""),
+                "available_int": sres.field[id].get("available_int", 1),
+                "positions": sres.field[id].get("position_int", "").split("\t"),
+            }
+            if len(d["positions"]) % 5 == 0:
+                poss = []
+                for i in range(0, len(d["positions"]), 5):
+                    poss.append(
+                        [
+                            float(d["positions"][i]),
+                            float(d["positions"][i + 1]),
+                            float(d["positions"][i + 2]),
+                            float(d["positions"][i + 3]),
+                            float(d["positions"][i + 4]),
+                        ]
+                    )
+                d["positions"] = poss

+            origin_chunks.append(d)
+            if req.get("id"):
+                if req.get("id") == id:
+                    origin_chunks.clear()
+                    origin_chunks.append(d)
+                    sign = 1
+                    break
         if req.get("id"):
-            if
-                sign = 1
-                break
-        if req.get("id"):
-            if sign == 0:
-                return get_error_data_result(f"Can't find this chunk {req.get('id')}")
+            if sign == 0:
+                return get_error_data_result(f"Can't find this chunk {req.get('id')}")
+
     for chunk in origin_chunks:
         key_mapping = {
-            "
+            "id": "id",
             "content_with_weight": "content",
             "doc_id": "document_id",
             "important_kwd": "important_keywords",

@@ -996,9 +993,9 @@ def add_chunk(tenant_id, dataset_id, document_id):
     )
     d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
     d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
-    d["kb_id"] =
+    d["kb_id"] = dataset_id
     d["docnm_kwd"] = doc.name
-    d["doc_id"] =
+    d["doc_id"] = document_id
     embd_id = DocumentService.get_embd_id(document_id)
     embd_mdl = TenantLLMService.model_instance(
         tenant_id, LLMType.EMBEDDING.value, embd_id

@@ -1006,14 +1003,12 @@ def add_chunk(tenant_id, dataset_id, document_id):
     v, c = embd_mdl.encode([doc.name, req["content"]])
     v = 0.1 * v[0] + 0.9 * v[1]
     d["q_%d_vec" % len(v)] = v.tolist()
-
+    docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)

     DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
-    d["chunk_id"] = chunk_id
-    d["kb_id"] = doc.kb_id
     # rename keys
     key_mapping = {
-        "
+        "id": "id",
         "content_with_weight": "content",
         "doc_id": "document_id",
         "important_kwd": "important_keywords",

@@ -1079,36 +1074,16 @@ def rm_chunk(tenant_id, dataset_id, document_id):
     """
     if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
         return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
-    doc = DocumentService.query(id=document_id, kb_id=dataset_id)
-    if not doc:
-        return get_error_data_result(
-            message=f"You don't own the document {document_id}."
-        )
-    doc = doc[0]
     req = request.json
-        chunk_list = sres.ids
-    else:
-        chunk_list = chunk_ids
-    for chunk_id in chunk_list:
-        if chunk_id not in sres.ids:
-            return get_error_data_result(f"Chunk {chunk_id} not found")
-    if not ELASTICSEARCH.deleteByQuery(
-            Q("ids", values=chunk_list), search.index_name(tenant_id)
-    ):
-        return get_error_data_result(message="Index updating failure")
-    deleted_chunk_ids = chunk_list
-    chunk_number = len(deleted_chunk_ids)
-    DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
-    return get_result()
+    condition = {"doc_id": document_id}
+    if "chunk_ids" in req:
+        condition["id"] = req["chunk_ids"]
+    chunk_number = docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
+    if chunk_number != 0:
+        DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
+    if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
+        return get_error_data_result(message=f"rm_chunk deleted chunks {chunk_number}, expect {len(req["chunk_ids"])}")
+    return get_result(message=f"deleted {chunk_number} chunks")


 @manager.route(

@@ -1168,9 +1143,8 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
         schema:
           type: object
     """
-    except Exception:
+    chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
+    if chunk is None:
         return get_error_data_result(f"Can't find this chunk {chunk_id}")
     if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
         return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")

@@ -1180,19 +1154,12 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
             message=f"You don't own the document {document_id}."
         )
     doc = doc[0]
-    query = {
-        "doc_ids": [document_id],
-        "page": 1,
-        "size": 1024,
-        "question": "",
-        "sort": True,
-    }
-    sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
-    if chunk_id not in sres.ids:
-        return get_error_data_result(f"You don't own the chunk {chunk_id}")
     req = request.json
-    content
+    if "content" in req:
+        content = req["content"]
+    else:
+        content = chunk.get("content_with_weight", "")
+    d = {"id": chunk_id, "content_with_weight": content}
     d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"])
     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
     if "important_keywords" in req:

@@ -1220,7 +1187,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
     v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
     v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
     d["q_%d_vec" % len(v)] = v.tolist()
-
+    docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
     return get_result()
api/apps/system_app.py (CHANGED)

@@ -31,7 +31,7 @@ from api.utils.api_utils import (
     generate_confirmation_token,
 )
 from api.versions import get_rag_version
-from
+from api.settings import docStoreConn
 from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
 from timeit import default_timer as timer

@@ -98,10 +98,11 @@ def status():
     res = {}
     st = timer()
     try:
-        res["
-        res["
+        res["doc_store"] = docStoreConn.health()
+        res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
     except Exception as e:
-        res["
+        res["doc_store"] = {
+            "type": "unknown",
             "status": "red",
             "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
             "error": str(e),
api/db/db_models.py (CHANGED)

@@ -470,7 +470,7 @@ class User(DataBaseModel, UserMixin):
     status = CharField(
         max_length=1,
         null=True,
-        help_text="is it validate(0: wasted
+        help_text="is it validate(0: wasted, 1: validate)",
         default="1",
         index=True)
     is_superuser = BooleanField(null=True, help_text="is root", default=False, index=True)

@@ -525,7 +525,7 @@ class Tenant(DataBaseModel):
     status = CharField(
         max_length=1,
         null=True,
-        help_text="is it validate(0: wasted
+        help_text="is it validate(0: wasted, 1: validate)",
         default="1",
         index=True)

@@ -542,7 +542,7 @@ class UserTenant(DataBaseModel):
     status = CharField(
         max_length=1,
         null=True,
-        help_text="is it validate(0: wasted
+        help_text="is it validate(0: wasted, 1: validate)",
         default="1",
         index=True)

@@ -559,7 +559,7 @@ class InvitationCode(DataBaseModel):
     status = CharField(
         max_length=1,
         null=True,
-        help_text="is it validate(0: wasted
+        help_text="is it validate(0: wasted, 1: validate)",
         default="1",
         index=True)

@@ -582,7 +582,7 @@ class LLMFactories(DataBaseModel):
     status = CharField(
         max_length=1,
         null=True,
-        help_text="is it validate(0: wasted
+        help_text="is it validate(0: wasted, 1: validate)",
         default="1",
         index=True)

@@ -616,7 +616,7 @@ class LLM(DataBaseModel):
     status = CharField(
         max_length=1,
         null=True,
-        help_text="is it validate(0: wasted
+        help_text="is it validate(0: wasted, 1: validate)",
         default="1",
         index=True)

@@ -703,7 +703,7 @@ class Knowledgebase(DataBaseModel):
     status = CharField(
         max_length=1,
         null=True,
-        help_text="is it validate(0: wasted
+        help_text="is it validate(0: wasted, 1: validate)",
         default="1",
         index=True)

@@ -767,7 +767,7 @@ class Document(DataBaseModel):
     status = CharField(
         max_length=1,
         null=True,
-        help_text="is it validate(0: wasted
+        help_text="is it validate(0: wasted, 1: validate)",
         default="1",
         index=True)

@@ -904,7 +904,7 @@ class Dialog(DataBaseModel):
     status = CharField(
         max_length=1,
         null=True,
-        help_text="is it validate(0: wasted
+        help_text="is it validate(0: wasted, 1: validate)",
         default="1",
         index=True)

@@ -987,7 +987,7 @@ def migrate_db():
                 help_text="where dose this document come from",
                 index=True))
         )
-    except Exception
+    except Exception:
         pass
     try:
         migrate(

@@ -996,7 +996,7 @@ def migrate_db():
                 help_text="default rerank model ID"))

         )
-    except Exception
+    except Exception:
         pass
     try:
         migrate(

@@ -1004,59 +1004,59 @@ def migrate_db():
                 help_text="default rerank model ID"))

         )
-    except Exception
+    except Exception:
         pass
     try:
         migrate(
             migrator.add_column('dialog', 'top_k', IntegerField(default=1024))

         )
-    except Exception
+    except Exception:
         pass
     try:
         migrate(
             migrator.alter_column_type('tenant_llm', 'api_key',
                 CharField(max_length=1024, null=True, help_text="API KEY", index=True))
         )
-    except Exception
+    except Exception:
         pass
     try:
         migrate(
             migrator.add_column('api_token', 'source',
                 CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
         )
-    except Exception
+    except Exception:
         pass
     try:
         migrate(
             migrator.add_column("tenant","tts_id",
                 CharField(max_length=256,null=True,help_text="default tts model ID",index=True))
         )
-    except Exception
+    except Exception:
         pass
     try:
         migrate(
             migrator.add_column('api_4_conversation', 'source',
                 CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
         )
-    except Exception
+    except Exception:
         pass
     try:
         DB.execute_sql('ALTER TABLE llm DROP PRIMARY KEY;')
         DB.execute_sql('ALTER TABLE llm ADD PRIMARY KEY (llm_name,fid);')
-    except Exception
+    except Exception:
         pass
     try:
         migrate(
             migrator.add_column('task', 'retry_count', IntegerField(default=0))
         )
-    except Exception
+    except Exception:
         pass
     try:
         migrate(
             migrator.alter_column_type('api_token', 'dialog_id',
                 CharField(max_length=32, null=True, index=True))
         )
-    except Exception
+    except Exception:
         pass
api/db/services/document_service.py
CHANGED
@@ -15,7 +15,6 @@
|
|
15 |
#
|
16 |
import hashlib
|
17 |
import json
|
18 |
-
import os
|
19 |
import random
|
20 |
import re
|
21 |
import traceback
|
@@ -24,16 +23,13 @@ from copy import deepcopy
|
|
24 |
from datetime import datetime
|
25 |
from io import BytesIO
|
26 |
|
27 |
-
from elasticsearch_dsl import Q
|
28 |
from peewee import fn
|
29 |
|
30 |
from api.db.db_utils import bulk_insert_into_db
|
31 |
-
from api.settings import stat_logger
|
32 |
from api.utils import current_timestamp, get_format_time, get_uuid
|
33 |
-
from api.utils.file_utils import get_project_base_directory
|
34 |
from graphrag.mind_map_extractor import MindMapExtractor
|
35 |
from rag.settings import SVR_QUEUE_NAME
|
36 |
-
from rag.utils.es_conn import ELASTICSEARCH
|
37 |
from rag.utils.storage_factory import STORAGE_IMPL
|
38 |
from rag.nlp import search, rag_tokenizer
|
39 |
|
@@ -112,8 +108,7 @@ class DocumentService(CommonService):
|
|
112 |
@classmethod
|
113 |
@DB.connection_context()
|
114 |
def remove_document(cls, doc, tenant_id):
|
115 |
-
|
116 |
-
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
|
117 |
cls.clear_chunk_num(doc.id)
|
118 |
return cls.delete_by_id(doc.id)
|
119 |
|
@@ -225,6 +220,15 @@ class DocumentService(CommonService):
|
|
225 |
return
|
226 |
return docs[0]["tenant_id"]
|
227 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
@classmethod
|
229 |
@DB.connection_context()
|
230 |
def get_tenant_id_by_name(cls, name):
|
@@ -438,11 +442,6 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|
438 |
if not e:
|
439 |
raise LookupError("Can't find this knowledgebase!")
|
440 |
|
441 |
-
idxnm = search.index_name(kb.tenant_id)
|
442 |
-
if not ELASTICSEARCH.indexExist(idxnm):
|
443 |
-
ELASTICSEARCH.createIdx(idxnm, json.load(
|
444 |
-
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
|
445 |
-
|
446 |
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
|
447 |
|
448 |
err, files = FileService.upload_document(kb, file_objs, user_id)
|
@@ -486,7 +485,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|
486 |
md5 = hashlib.md5()
|
487 |
md5.update((ck["content_with_weight"] +
|
488 |
str(d["doc_id"])).encode("utf-8"))
|
489 |
-
d["
|
490 |
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
|
491 |
d["create_timestamp_flt"] = datetime.now().timestamp()
|
492 |
if not d.get("image"):
|
@@ -499,8 +498,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|
499 |
else:
|
500 |
d["image"].save(output_buffer, format='JPEG')
|
501 |
|
502 |
-
STORAGE_IMPL.put(kb.id, d["
|
503 |
-
d["img_id"] = "{}-{}".format(kb.id, d["
|
504 |
del d["image"]
|
505 |
docs.append(d)
|
506 |
|
@@ -520,6 +519,9 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|
520 |
token_counts[doc_id] += c
|
521 |
return vects
|
522 |
|
|
|
|
|
|
|
523 |
_, tenant = TenantService.get_by_id(kb.tenant_id)
|
524 |
llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
|
525 |
for doc_id in docids:
|
@@ -550,7 +552,11 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|
550 |
v = vects[i]
|
551 |
d["q_%d_vec" % len(v)] = v
|
552 |
for b in range(0, len(cks), es_bulk_size):
|
553 |
-
|
|
|
|
|
|
|
|
|
554 |
|
555 |
DocumentService.increment_chunk_num(
|
556 |
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
|
|
15 |
#
|
16 |
import hashlib
|
17 |
import json
|
|
|
18 |
import random
|
19 |
import re
|
20 |
import traceback
|
|
|
23 |
from datetime import datetime
|
24 |
from io import BytesIO
|
25 |
|
|
|
26 |
from peewee import fn
|
27 |
|
28 |
from api.db.db_utils import bulk_insert_into_db
|
29 |
+
from api.settings import stat_logger, docStoreConn
|
30 |
from api.utils import current_timestamp, get_format_time, get_uuid
|
|
|
31 |
from graphrag.mind_map_extractor import MindMapExtractor
|
32 |
from rag.settings import SVR_QUEUE_NAME
|
|
|
33 |
from rag.utils.storage_factory import STORAGE_IMPL
|
34 |
from rag.nlp import search, rag_tokenizer
|
35 |
|
|
|
108 |
@classmethod
|
109 |
@DB.connection_context()
|
110 |
def remove_document(cls, doc, tenant_id):
|
111 |
+
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
|
|
112 |
cls.clear_chunk_num(doc.id)
|
113 |
return cls.delete_by_id(doc.id)
|
114 |
|
|
|
220 |
return
|
221 |
return docs[0]["tenant_id"]
|
222 |
|
223 |
+
@classmethod
|
224 |
+
@DB.connection_context()
|
225 |
+
def get_knowledgebase_id(cls, doc_id):
|
226 |
+
docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
|
227 |
+
docs = docs.dicts()
|
228 |
+
if not docs:
|
229 |
+
return
|
230 |
+
return docs[0]["kb_id"]
|
231 |
+
|
232 |
@classmethod
|
233 |
@DB.connection_context()
|
234 |
def get_tenant_id_by_name(cls, name):
|
|
|
442 |
if not e:
|
443 |
raise LookupError("Can't find this knowledgebase!")
|
444 |
|
|
|
|
|
|
|
|
|
|
|
445 |
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
|
446 |
|
447 |
err, files = FileService.upload_document(kb, file_objs, user_id)
|
|
|
485 |
md5 = hashlib.md5()
|
486 |
md5.update((ck["content_with_weight"] +
|
487 |
str(d["doc_id"])).encode("utf-8"))
|
488 |
+
d["id"] = md5.hexdigest()
|
489 |
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
|
490 |
d["create_timestamp_flt"] = datetime.now().timestamp()
|
491 |
if not d.get("image"):
|
|
|
498 |
else:
|
499 |
d["image"].save(output_buffer, format='JPEG')
|
500 |
|
501 |
+
STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
|
502 |
+
d["img_id"] = "{}-{}".format(kb.id, d["id"])
|
503 |
del d["image"]
|
504 |
docs.append(d)
|
505 |
|
|
|
519 |
token_counts[doc_id] += c
|
520 |
return vects
|
521 |
|
522 |
+
idxnm = search.index_name(kb.tenant_id)
|
523 |
+
try_create_idx = True
|
524 |
+
|
525 |
_, tenant = TenantService.get_by_id(kb.tenant_id)
|
526 |
llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
|
527 |
for doc_id in docids:
|
|
|
552 |
v = vects[i]
|
553 |
d["q_%d_vec" % len(v)] = v
|
554 |
for b in range(0, len(cks), es_bulk_size):
|
555 |
+
if try_create_idx:
|
556 |
+
if not docStoreConn.indexExist(idxnm, kb_id):
|
557 |
+
docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
558 |
+
try_create_idx = False
|
559 |
+
docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
560 |
|
561 |
DocumentService.increment_chunk_num(
|
562 |
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
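The hunks above swap the old Elasticsearch-specific bulk call for the backend-neutral `docStoreConn` and create the per-knowledge-base index lazily, sized to the embedding width of the first chunk. A minimal sketch of that pattern, assuming the module paths introduced by this PR; `upsert_chunks` itself is a hypothetical helper name, not code from the diff:

```python
# Sketch only: lazy index creation + batched insert via the new doc-store layer.
from api.settings import docStoreConn
from rag.nlp import search

def upsert_chunks(chunks, tenant_id, kb_id, es_bulk_size=32):
    """Insert chunks in batches, creating the per-tenant index on first use."""
    idxnm = search.index_name(tenant_id)
    try_create_idx = True
    for b in range(0, len(chunks), es_bulk_size):
        if try_create_idx:
            # Vector column width comes from the first chunk's embedding field,
            # mirroring len(vects[0]) in doc_upload_and_parse above.
            vec_field = next(k for k in chunks[0] if k.endswith("_vec"))
            if not docStoreConn.indexExist(idxnm, kb_id):
                docStoreConn.createIdx(idxnm, kb_id, len(chunks[0][vec_field]))
            try_create_idx = False
        docStoreConn.insert(chunks[b:b + es_bulk_size], idxnm, kb_id)
```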
api/db/services/knowledgebase_service.py
CHANGED
@@ -66,6 +66,16 @@ class KnowledgebaseService(CommonService):
|
|
66 |
|
67 |
return list(kbs.dicts())
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
@classmethod
|
70 |
@DB.connection_context()
|
71 |
def get_detail(cls, kb_id):
|
|
|
66 |
|
67 |
return list(kbs.dicts())
|
68 |
|
69 |
+
@classmethod
|
70 |
+
@DB.connection_context()
|
71 |
+
def get_kb_ids(cls, tenant_id):
|
72 |
+
fields = [
|
73 |
+
cls.model.id,
|
74 |
+
]
|
75 |
+
kbs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
|
76 |
+
kb_ids = [kb["id"] for kb in kbs]
|
77 |
+
return kb_ids
|
78 |
+
|
79 |
@classmethod
|
80 |
@DB.connection_context()
|
81 |
def get_detail(cls, kb_id):
|
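`get_kb_ids` exists so callers that only know a tenant (the graphrag scripts further down, for instance) can scope retrieval to every knowledge base of that tenant. A hedged usage sketch; the tenant and document ids are placeholders:

```python
from api.settings import retrievaler
from api.db.services.knowledgebase_service import KnowledgebaseService

tenant_id = "tenant-xyz"   # placeholder
doc_id = "doc-123"         # placeholder

kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
# chunk_list now takes kb_ids as its third positional argument
# (see the graphrag/claim_extractor.py and graphrag/smoke.py hunks below).
docs = [d["content_with_weight"]
        for d in retrievaler.chunk_list(doc_id, tenant_id, kb_ids,
                                        max_count=12,
                                        fields=["content_with_weight"])]
```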
api/settings.py
CHANGED
@@ -18,6 +18,8 @@ from datetime import date
|
|
18 |
from enum import IntEnum, Enum
|
19 |
from api.utils.file_utils import get_project_base_directory
|
20 |
from api.utils.log_utils import LoggerFactory, getLogger
|
|
|
|
|
21 |
|
22 |
# Logger
|
23 |
LoggerFactory.set_directory(
|
@@ -33,7 +35,7 @@ access_logger = getLogger("access")
|
|
33 |
database_logger = getLogger("database")
|
34 |
chat_logger = getLogger("chat")
|
35 |
|
36 |
-
|
37 |
from rag.nlp import search
|
38 |
from graphrag import search as kg_search
|
39 |
from api.utils import get_base_config, decrypt_database_config
|
@@ -206,8 +208,12 @@ AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s
|
|
206 |
PRIVILEGE_COMMAND_WHITELIST = []
|
207 |
CHECK_NODES_IDENTITY = False
|
208 |
|
209 |
-
|
210 |
-
|
|
|
|
|
|
|
|
|
211 |
|
212 |
|
213 |
class CustomEnum(Enum):
|
|
|
18 |
from enum import IntEnum, Enum
|
19 |
from api.utils.file_utils import get_project_base_directory
|
20 |
from api.utils.log_utils import LoggerFactory, getLogger
|
21 |
+
import rag.utils.es_conn
|
22 |
+
import rag.utils.infinity_conn
|
23 |
|
24 |
# Logger
|
25 |
LoggerFactory.set_directory(
|
|
|
35 |
database_logger = getLogger("database")
|
36 |
chat_logger = getLogger("chat")
|
37 |
|
38 |
+
import rag.utils
|
39 |
from rag.nlp import search
|
40 |
from graphrag import search as kg_search
|
41 |
from api.utils import get_base_config, decrypt_database_config
|
|
|
208 |
PRIVILEGE_COMMAND_WHITELIST = []
|
209 |
CHECK_NODES_IDENTITY = False
|
210 |
|
211 |
+
if 'username' in get_base_config("es", {}):
|
212 |
+
docStoreConn = rag.utils.es_conn.ESConnection()
|
213 |
+
else:
|
214 |
+
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
|
215 |
+
retrievaler = search.Dealer(docStoreConn)
|
216 |
+
kg_retrievaler = kg_search.KGSearch(docStoreConn)
|
217 |
|
218 |
|
219 |
class CustomEnum(Enum):
|
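With the selection logic above, everything downstream receives either an `ESConnection` or an `InfinityConnection` behind the single `docStoreConn` name, so calls like the ones below (signatures taken from other hunks in this PR) work regardless of which engine is configured. The index name, kb id, document id and vector size are placeholders:

```python
# Sketch of the backend-agnostic surface shared by both connections.
from api.settings import docStoreConn
from rag.nlp import search

idxnm = search.index_name("tenant-xyz")   # placeholder tenant id
kb_id = "kb-abc"                          # placeholder knowledge base id

if not docStoreConn.indexExist(idxnm, kb_id):
    docStoreConn.createIdx(idxnm, kb_id, 1024)               # 1024-dim embedding column
docStoreConn.delete({"doc_id": "doc-123"}, idxnm, kb_id)     # same call shape as remove_document
```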
api/utils/api_utils.py
CHANGED
@@ -126,10 +126,6 @@ def server_error_response(e):
|
|
126 |
if len(e.args) > 1:
|
127 |
return get_json_result(
|
128 |
code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
129 |
-
if repr(e).find("index_not_found_exception") >= 0:
|
130 |
-
return get_json_result(code=RetCode.EXCEPTION_ERROR,
|
131 |
-
message="No chunk found, please upload file and parse it.")
|
132 |
-
|
133 |
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
|
134 |
|
135 |
|
@@ -270,10 +266,6 @@ def construct_error_response(e):
|
|
270 |
pass
|
271 |
if len(e.args) > 1:
|
272 |
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
273 |
-
if repr(e).find("index_not_found_exception") >= 0:
|
274 |
-
return construct_json_result(code=RetCode.EXCEPTION_ERROR,
|
275 |
-
message="No chunk found, please upload file and parse it.")
|
276 |
-
|
277 |
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
|
278 |
|
279 |
|
@@ -295,7 +287,7 @@ def token_required(func):
|
|
295 |
return decorated_function
|
296 |
|
297 |
|
298 |
-
def get_result(code=RetCode.SUCCESS, message=
|
299 |
if code == 0:
|
300 |
if data is not None:
|
301 |
response = {"code": code, "data": data}
|
|
|
126 |
if len(e.args) > 1:
|
127 |
return get_json_result(
|
128 |
code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
|
|
|
|
|
|
|
|
129 |
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
|
130 |
|
131 |
|
|
|
266 |
pass
|
267 |
if len(e.args) > 1:
|
268 |
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
|
|
|
|
|
|
|
|
269 |
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
|
270 |
|
271 |
|
|
|
287 |
return decorated_function
|
288 |
|
289 |
|
290 |
+
def get_result(code=RetCode.SUCCESS, message="", data=None):
|
291 |
if code == 0:
|
292 |
if data is not None:
|
293 |
response = {"code": code, "data": data}
|
conf/infinity_mapping.json
ADDED
@@ -0,0 +1,26 @@
|
1 |
+
{
|
2 |
+
"id": {"type": "varchar", "default": ""},
|
3 |
+
"doc_id": {"type": "varchar", "default": ""},
|
4 |
+
"kb_id": {"type": "varchar", "default": ""},
|
5 |
+
"create_time": {"type": "varchar", "default": ""},
|
6 |
+
"create_timestamp_flt": {"type": "float", "default": 0.0},
|
7 |
+
"img_id": {"type": "varchar", "default": ""},
|
8 |
+
"docnm_kwd": {"type": "varchar", "default": ""},
|
9 |
+
"title_tks": {"type": "varchar", "default": ""},
|
10 |
+
"title_sm_tks": {"type": "varchar", "default": ""},
|
11 |
+
"name_kwd": {"type": "varchar", "default": ""},
|
12 |
+
"important_kwd": {"type": "varchar", "default": ""},
|
13 |
+
"important_tks": {"type": "varchar", "default": ""},
|
14 |
+
"content_with_weight": {"type": "varchar", "default": ""},
|
15 |
+
"content_ltks": {"type": "varchar", "default": ""},
|
16 |
+
"content_sm_ltks": {"type": "varchar", "default": ""},
|
17 |
+
"page_num_list": {"type": "varchar", "default": ""},
|
18 |
+
"top_list": {"type": "varchar", "default": ""},
|
19 |
+
"position_list": {"type": "varchar", "default": ""},
|
20 |
+
"weight_int": {"type": "integer", "default": 0},
|
21 |
+
"weight_flt": {"type": "float", "default": 0.0},
|
22 |
+
"rank_int": {"type": "integer", "default": 0},
|
23 |
+
"available_int": {"type": "integer", "default": 1},
|
24 |
+
"knowledge_graph_kwd": {"type": "varchar", "default": ""},
|
25 |
+
"entities_kwd": {"type": "varchar", "default": ""}
|
26 |
+
}
|
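`conf/infinity_mapping.json` declares only the fixed columns; the embedding column is added at runtime with a dimension-dependent name (`q_<dims>_vec`, as in the chunk code above). A small sketch of how a connector might assemble the full schema; the `vector,<dims>,float` type string follows Infinity's column-type convention and is an assumption here, not something stated in this diff:

```python
import json

def build_infinity_schema(mapping_path="conf/infinity_mapping.json", vector_size=1024):
    """Static columns from the mapping file plus one dynamically sized vector column."""
    with open(mapping_path) as f:
        schema = json.load(f)
    # Assumed Infinity column-type syntax for a float vector of `vector_size` dims.
    schema[f"q_{vector_size}_vec"] = {"type": f"vector,{vector_size},float"}
    return schema

if __name__ == "__main__":
    print(sorted(build_infinity_schema(vector_size=768))[:6])
```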
conf/mapping.json
CHANGED
@@ -1,200 +1,203 @@
|
|
1 |
-
|
2 |
"settings": {
|
3 |
"index": {
|
4 |
"number_of_shards": 2,
|
5 |
"number_of_replicas": 0,
|
6 |
-
"refresh_interval"
|
7 |
},
|
8 |
"similarity": {
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
}
|
14 |
}
|
|
|
15 |
}
|
16 |
},
|
17 |
"mappings": {
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
}
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
"settings": {
|
3 |
"index": {
|
4 |
"number_of_shards": 2,
|
5 |
"number_of_replicas": 0,
|
6 |
+
"refresh_interval": "1000ms"
|
7 |
},
|
8 |
"similarity": {
|
9 |
+
"scripted_sim": {
|
10 |
+
"type": "scripted",
|
11 |
+
"script": {
|
12 |
+
"source": "double idf = Math.log(1+(field.docCount-term.docFreq+0.5)/(term.docFreq + 0.5))/Math.log(1+((field.docCount-0.5)/1.5)); return query.boost * idf * Math.min(doc.freq, 1);"
|
|
|
13 |
}
|
14 |
+
}
|
15 |
}
|
16 |
},
|
17 |
"mappings": {
|
18 |
+
"properties": {
|
19 |
+
"lat_lon": {
|
20 |
+
"type": "geo_point",
|
21 |
+
"store": "true"
|
22 |
+
}
|
23 |
+
},
|
24 |
+
"date_detection": "true",
|
25 |
+
"dynamic_templates": [
|
26 |
+
{
|
27 |
+
"int": {
|
28 |
+
"match": "*_int",
|
29 |
+
"mapping": {
|
30 |
+
"type": "integer",
|
31 |
+
"store": "true"
|
32 |
+
}
|
33 |
+
}
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"ulong": {
|
37 |
+
"match": "*_ulong",
|
38 |
+
"mapping": {
|
39 |
+
"type": "unsigned_long",
|
40 |
+
"store": "true"
|
41 |
+
}
|
42 |
+
}
|
43 |
+
},
|
44 |
+
{
|
45 |
+
"long": {
|
46 |
+
"match": "*_long",
|
47 |
+
"mapping": {
|
48 |
+
"type": "long",
|
49 |
+
"store": "true"
|
50 |
+
}
|
51 |
+
}
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"short": {
|
55 |
+
"match": "*_short",
|
56 |
+
"mapping": {
|
57 |
+
"type": "short",
|
58 |
+
"store": "true"
|
59 |
+
}
|
60 |
+
}
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"numeric": {
|
64 |
+
"match": "*_flt",
|
65 |
+
"mapping": {
|
66 |
+
"type": "float",
|
67 |
+
"store": true
|
68 |
+
}
|
69 |
+
}
|
70 |
+
},
|
71 |
+
{
|
72 |
+
"tks": {
|
73 |
+
"match": "*_tks",
|
74 |
+
"mapping": {
|
75 |
+
"type": "text",
|
76 |
+
"similarity": "scripted_sim",
|
77 |
+
"analyzer": "whitespace",
|
78 |
+
"store": true
|
79 |
+
}
|
80 |
+
}
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"ltks": {
|
84 |
+
"match": "*_ltks",
|
85 |
+
"mapping": {
|
86 |
+
"type": "text",
|
87 |
+
"analyzer": "whitespace",
|
88 |
+
"store": true
|
89 |
+
}
|
90 |
+
}
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"kwd": {
|
94 |
+
"match_pattern": "regex",
|
95 |
+
"match": "^(.*_(kwd|id|ids|uid|uids)|uid)$",
|
96 |
+
"mapping": {
|
97 |
+
"type": "keyword",
|
98 |
+
"similarity": "boolean",
|
99 |
+
"store": true
|
100 |
+
}
|
101 |
+
}
|
102 |
+
},
|
103 |
+
{
|
104 |
+
"dt": {
|
105 |
+
"match_pattern": "regex",
|
106 |
+
"match": "^.*(_dt|_time|_at)$",
|
107 |
+
"mapping": {
|
108 |
+
"type": "date",
|
109 |
+
"format": "yyyy-MM-dd HH:mm:ss||yyyy-MM-dd||yyyy-MM-dd_HH:mm:ss",
|
110 |
+
"store": true
|
111 |
+
}
|
112 |
+
}
|
113 |
+
},
|
114 |
+
{
|
115 |
+
"nested": {
|
116 |
+
"match": "*_nst",
|
117 |
+
"mapping": {
|
118 |
+
"type": "nested"
|
119 |
+
}
|
120 |
+
}
|
121 |
+
},
|
122 |
+
{
|
123 |
+
"object": {
|
124 |
+
"match": "*_obj",
|
125 |
+
"mapping": {
|
126 |
+
"type": "object",
|
127 |
+
"dynamic": "true"
|
128 |
+
}
|
129 |
+
}
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"string": {
|
133 |
+
"match": "*_(with_weight|list)$",
|
134 |
+
"mapping": {
|
135 |
+
"type": "text",
|
136 |
+
"index": "false",
|
137 |
+
"store": true
|
138 |
+
}
|
139 |
+
}
|
140 |
+
},
|
141 |
+
{
|
142 |
+
"string": {
|
143 |
+
"match": "*_fea",
|
144 |
+
"mapping": {
|
145 |
+
"type": "rank_feature"
|
146 |
+
}
|
147 |
+
}
|
148 |
+
},
|
149 |
+
{
|
150 |
+
"dense_vector": {
|
151 |
+
"match": "*_512_vec",
|
152 |
+
"mapping": {
|
153 |
+
"type": "dense_vector",
|
154 |
+
"index": true,
|
155 |
+
"similarity": "cosine",
|
156 |
+
"dims": 512
|
157 |
+
}
|
158 |
+
}
|
159 |
+
},
|
160 |
+
{
|
161 |
+
"dense_vector": {
|
162 |
+
"match": "*_768_vec",
|
163 |
+
"mapping": {
|
164 |
+
"type": "dense_vector",
|
165 |
+
"index": true,
|
166 |
+
"similarity": "cosine",
|
167 |
+
"dims": 768
|
168 |
+
}
|
169 |
+
}
|
170 |
+
},
|
171 |
+
{
|
172 |
+
"dense_vector": {
|
173 |
+
"match": "*_1024_vec",
|
174 |
+
"mapping": {
|
175 |
+
"type": "dense_vector",
|
176 |
+
"index": true,
|
177 |
+
"similarity": "cosine",
|
178 |
+
"dims": 1024
|
179 |
+
}
|
180 |
+
}
|
181 |
+
},
|
182 |
+
{
|
183 |
+
"dense_vector": {
|
184 |
+
"match": "*_1536_vec",
|
185 |
+
"mapping": {
|
186 |
+
"type": "dense_vector",
|
187 |
+
"index": true,
|
188 |
+
"similarity": "cosine",
|
189 |
+
"dims": 1536
|
190 |
+
}
|
191 |
+
}
|
192 |
+
},
|
193 |
+
{
|
194 |
+
"binary": {
|
195 |
+
"match": "*_bin",
|
196 |
+
"mapping": {
|
197 |
+
"type": "binary"
|
198 |
+
}
|
199 |
+
}
|
200 |
+
}
|
201 |
+
]
|
202 |
+
}
|
203 |
+
}
|
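The rewritten Elasticsearch mapping types fields purely by name: suffix-based dynamic templates decide whether a field becomes an integer, analyzed text, keyword, date or dense vector. The snippet below only illustrates that routing with plain regexes that mirror (and simplify) the templates; it is not an Elasticsearch API call:

```python
import re

# Simplified mirror of the dynamic templates above, in template order.
SUFFIX_TO_TYPE = [
    (r".*_int$", "integer"),
    (r".*_flt$", "float"),
    (r".*_tks$", "text (whitespace analyzer, scripted_sim)"),
    (r".*_ltks$", "text (whitespace analyzer)"),
    (r"^(.*_(kwd|id|ids|uid|uids)|uid)$", "keyword"),
    (r"^.*(_dt|_time|_at)$", "date"),
    (r".*_(with_weight|list)$", "text (stored, not indexed)"),
    (r".*_(512|768|1024|1536)_vec$", "dense_vector (cosine)"),
]

def es_type_for(field: str) -> str:
    for pattern, es_type in SUFFIX_TO_TYPE:
        if re.match(pattern, field):
            return es_type
    return "dynamic (ES default)"

print(es_type_for("content_ltks"), es_type_for("q_1024_vec"), es_type_for("create_time"))
```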
docker/.env
CHANGED
@@ -19,6 +19,11 @@ KIBANA_PASSWORD=infini_rag_flow
|
|
19 |
# Update it according to the available memory in the host machine.
|
20 |
MEM_LIMIT=8073741824
|
21 |
|
|
|
|
|
|
|
|
|
|
|
22 |
# The password for MySQL.
|
23 |
# When updated, you must revise the `mysql.password` entry in service_conf.yaml.
|
24 |
MYSQL_PASSWORD=infini_rag_flow
|
|
|
19 |
# Update it according to the available memory in the host machine.
|
20 |
MEM_LIMIT=8073741824
|
21 |
|
22 |
+
# Port to expose Infinity API to the host
|
23 |
+
INFINITY_THRIFT_PORT=23817
|
24 |
+
INFINITY_HTTP_PORT=23820
|
25 |
+
INFINITY_PSQL_PORT=5432
|
26 |
+
|
27 |
# The password for MySQL.
|
28 |
# When updated, you must revise the `mysql.password` entry in service_conf.yaml.
|
29 |
MYSQL_PASSWORD=infini_rag_flow
|
docker/docker-compose-base.yml
CHANGED
@@ -6,6 +6,7 @@ services:
|
|
6 |
- esdata01:/usr/share/elasticsearch/data
|
7 |
ports:
|
8 |
- ${ES_PORT}:9200
|
|
|
9 |
environment:
|
10 |
- node.name=es01
|
11 |
- ELASTIC_PASSWORD=${ELASTIC_PASSWORD}
|
@@ -27,12 +28,40 @@ services:
|
|
27 |
retries: 120
|
28 |
networks:
|
29 |
- ragflow
|
30 |
-
restart:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
mysql:
|
33 |
# mysql:5.7 linux/arm64 image is unavailable.
|
34 |
image: mysql:8.0.39
|
35 |
container_name: ragflow-mysql
|
|
|
36 |
environment:
|
37 |
- MYSQL_ROOT_PASSWORD=${MYSQL_PASSWORD}
|
38 |
- TZ=${TIMEZONE}
|
@@ -55,7 +84,7 @@ services:
|
|
55 |
interval: 10s
|
56 |
timeout: 10s
|
57 |
retries: 3
|
58 |
-
restart:
|
59 |
|
60 |
minio:
|
61 |
image: quay.io/minio/minio:RELEASE.2023-12-20T01-00-02Z
|
@@ -64,6 +93,7 @@ services:
|
|
64 |
ports:
|
65 |
- ${MINIO_PORT}:9000
|
66 |
- ${MINIO_CONSOLE_PORT}:9001
|
|
|
67 |
environment:
|
68 |
- MINIO_ROOT_USER=${MINIO_USER}
|
69 |
- MINIO_ROOT_PASSWORD=${MINIO_PASSWORD}
|
@@ -72,25 +102,28 @@ services:
|
|
72 |
- minio_data:/data
|
73 |
networks:
|
74 |
- ragflow
|
75 |
-
restart:
|
76 |
|
77 |
redis:
|
78 |
image: valkey/valkey:8
|
79 |
container_name: ragflow-redis
|
80 |
command: redis-server --requirepass ${REDIS_PASSWORD} --maxmemory 128mb --maxmemory-policy allkeys-lru
|
|
|
81 |
ports:
|
82 |
- ${REDIS_PORT}:6379
|
83 |
volumes:
|
84 |
- redis_data:/data
|
85 |
networks:
|
86 |
- ragflow
|
87 |
-
restart:
|
88 |
|
89 |
|
90 |
|
91 |
volumes:
|
92 |
esdata01:
|
93 |
driver: local
|
|
|
|
|
94 |
mysql_data:
|
95 |
driver: local
|
96 |
minio_data:
|
|
|
6 |
- esdata01:/usr/share/elasticsearch/data
|
7 |
ports:
|
8 |
- ${ES_PORT}:9200
|
9 |
+
env_file: .env
|
10 |
environment:
|
11 |
- node.name=es01
|
12 |
- ELASTIC_PASSWORD=${ELASTIC_PASSWORD}
|
|
|
28 |
retries: 120
|
29 |
networks:
|
30 |
- ragflow
|
31 |
+
restart: on-failure
|
32 |
+
|
33 |
+
# infinity:
|
34 |
+
# container_name: ragflow-infinity
|
35 |
+
# image: infiniflow/infinity:v0.5.0-dev2
|
36 |
+
# volumes:
|
37 |
+
# - infinity_data:/var/infinity
|
38 |
+
# ports:
|
39 |
+
# - ${INFINITY_THRIFT_PORT}:23817
|
40 |
+
# - ${INFINITY_HTTP_PORT}:23820
|
41 |
+
# - ${INFINITY_PSQL_PORT}:5432
|
42 |
+
# env_file: .env
|
43 |
+
# environment:
|
44 |
+
# - TZ=${TIMEZONE}
|
45 |
+
# mem_limit: ${MEM_LIMIT}
|
46 |
+
# ulimits:
|
47 |
+
# nofile:
|
48 |
+
# soft: 500000
|
49 |
+
# hard: 500000
|
50 |
+
# networks:
|
51 |
+
# - ragflow
|
52 |
+
# healthcheck:
|
53 |
+
# test: ["CMD", "curl", "http://localhost:23820/admin/node/current"]
|
54 |
+
# interval: 10s
|
55 |
+
# timeout: 10s
|
56 |
+
# retries: 120
|
57 |
+
# restart: on-failure
|
58 |
+
|
59 |
|
60 |
mysql:
|
61 |
# mysql:5.7 linux/arm64 image is unavailable.
|
62 |
image: mysql:8.0.39
|
63 |
container_name: ragflow-mysql
|
64 |
+
env_file: .env
|
65 |
environment:
|
66 |
- MYSQL_ROOT_PASSWORD=${MYSQL_PASSWORD}
|
67 |
- TZ=${TIMEZONE}
|
|
|
84 |
interval: 10s
|
85 |
timeout: 10s
|
86 |
retries: 3
|
87 |
+
restart: on-failure
|
88 |
|
89 |
minio:
|
90 |
image: quay.io/minio/minio:RELEASE.2023-12-20T01-00-02Z
|
|
|
93 |
ports:
|
94 |
- ${MINIO_PORT}:9000
|
95 |
- ${MINIO_CONSOLE_PORT}:9001
|
96 |
+
env_file: .env
|
97 |
environment:
|
98 |
- MINIO_ROOT_USER=${MINIO_USER}
|
99 |
- MINIO_ROOT_PASSWORD=${MINIO_PASSWORD}
|
|
|
102 |
- minio_data:/data
|
103 |
networks:
|
104 |
- ragflow
|
105 |
+
restart: on-failure
|
106 |
|
107 |
redis:
|
108 |
image: valkey/valkey:8
|
109 |
container_name: ragflow-redis
|
110 |
command: redis-server --requirepass ${REDIS_PASSWORD} --maxmemory 128mb --maxmemory-policy allkeys-lru
|
111 |
+
env_file: .env
|
112 |
ports:
|
113 |
- ${REDIS_PORT}:6379
|
114 |
volumes:
|
115 |
- redis_data:/data
|
116 |
networks:
|
117 |
- ragflow
|
118 |
+
restart: on-failure
|
119 |
|
120 |
|
121 |
|
122 |
volumes:
|
123 |
esdata01:
|
124 |
driver: local
|
125 |
+
infinity_data:
|
126 |
+
driver: local
|
127 |
mysql_data:
|
128 |
driver: local
|
129 |
minio_data:
|
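The Infinity service block ships commented out, but its healthcheck endpoint is handy when bringing the stack up by hand. A minimal probe, assuming the default HTTP port from docker/.env (23820) and the `/admin/node/current` path used in the commented-out healthcheck:

```python
import urllib.request

def infinity_is_up(host="localhost", http_port=23820, timeout=3) -> bool:
    """Return True if Infinity answers on its admin endpoint, mirroring the compose healthcheck."""
    url = f"http://{host}:{http_port}/admin/node/current"
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status == 200
    except Exception:
        return False

if __name__ == "__main__":
    print("Infinity reachable:", infinity_is_up())
```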
docker/docker-compose.yml
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
include:
|
2 |
-
-
|
3 |
-
env_file: ./.env
|
4 |
|
5 |
services:
|
6 |
ragflow:
|
@@ -15,19 +14,21 @@ services:
|
|
15 |
- ${SVR_HTTP_PORT}:9380
|
16 |
- 80:80
|
17 |
- 443:443
|
18 |
-
- 5678:5678
|
19 |
volumes:
|
20 |
- ./service_conf.yaml:/ragflow/conf/service_conf.yaml
|
21 |
- ./ragflow-logs:/ragflow/logs
|
22 |
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
|
23 |
- ./nginx/proxy.conf:/etc/nginx/proxy.conf
|
24 |
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
|
|
|
25 |
environment:
|
26 |
- TZ=${TIMEZONE}
|
27 |
- HF_ENDPOINT=${HF_ENDPOINT}
|
28 |
- MACOS=${MACOS}
|
29 |
networks:
|
30 |
- ragflow
|
31 |
-
restart:
|
|
|
|
|
32 |
extra_hosts:
|
33 |
- "host.docker.internal:host-gateway"
|
|
|
1 |
include:
|
2 |
+
- ./docker-compose-base.yml
|
|
|
3 |
|
4 |
services:
|
5 |
ragflow:
|
|
|
14 |
- ${SVR_HTTP_PORT}:9380
|
15 |
- 80:80
|
16 |
- 443:443
|
|
|
17 |
volumes:
|
18 |
- ./service_conf.yaml:/ragflow/conf/service_conf.yaml
|
19 |
- ./ragflow-logs:/ragflow/logs
|
20 |
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
|
21 |
- ./nginx/proxy.conf:/etc/nginx/proxy.conf
|
22 |
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
|
23 |
+
env_file: .env
|
24 |
environment:
|
25 |
- TZ=${TIMEZONE}
|
26 |
- HF_ENDPOINT=${HF_ENDPOINT}
|
27 |
- MACOS=${MACOS}
|
28 |
networks:
|
29 |
- ragflow
|
30 |
+
restart: on-failure
|
31 |
+
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
32 |
+
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
|
33 |
extra_hosts:
|
34 |
- "host.docker.internal:host-gateway"
|
docs/guides/develop/launch_ragflow_from_source.md
CHANGED
@@ -67,7 +67,7 @@ docker compose -f docker/docker-compose-base.yml up -d
|
|
67 |
1. Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
|
68 |
|
69 |
```
|
70 |
-
127.0.0.1 es01 mysql minio redis
|
71 |
```
|
72 |
|
73 |
2. In **docker/service_conf.yaml**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**.
|
|
|
67 |
1. Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
|
68 |
|
69 |
```
|
70 |
+
127.0.0.1 es01 infinity mysql minio redis
|
71 |
```
|
72 |
|
73 |
2. In **docker/service_conf.yaml**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**.
|
docs/references/http_api_reference.md
CHANGED
@@ -1280,7 +1280,7 @@ Success:
|
|
1280 |
"document_keyword": "1.txt",
|
1281 |
"highlight": "<em>ragflow</em> content",
|
1282 |
"id": "d78435d142bd5cf6704da62c778795c5",
|
1283 |
-
"
|
1284 |
"important_keywords": [
|
1285 |
""
|
1286 |
],
|
|
|
1280 |
"document_keyword": "1.txt",
|
1281 |
"highlight": "<em>ragflow</em> content",
|
1282 |
"id": "d78435d142bd5cf6704da62c778795c5",
|
1283 |
+
"image_id": "",
|
1284 |
"important_keywords": [
|
1285 |
""
|
1286 |
],
|
docs/references/python_api_reference.md
CHANGED
@@ -1351,7 +1351,7 @@ A list of `Chunk` objects representing references to the message, each containin
|
|
1351 |
The chunk ID.
|
1352 |
- `content` `str`
|
1353 |
The content of the chunk.
|
1354 |
-
- `
|
1355 |
The ID of the snapshot of the chunk. Applicable only when the source of the chunk is an image, PPT, PPTX, or PDF file.
|
1356 |
- `document_id` `str`
|
1357 |
The ID of the referenced document.
|
|
|
1351 |
The chunk ID.
|
1352 |
- `content` `str`
|
1353 |
The content of the chunk.
|
1354 |
+
- `img_id` `str`
|
1355 |
The ID of the snapshot of the chunk. Applicable only when the source of the chunk is an image, PPT, PPTX, or PDF file.
|
1356 |
- `document_id` `str`
|
1357 |
The ID of the referenced document.
|
graphrag/claim_extractor.py
CHANGED
@@ -254,9 +254,12 @@ if __name__ == "__main__":
|
|
254 |
from api.db import LLMType
|
255 |
from api.db.services.llm_service import LLMBundle
|
256 |
from api.settings import retrievaler
|
|
|
|
|
|
|
257 |
|
258 |
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
259 |
-
docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, max_count=12, fields=["content_with_weight"])]
|
260 |
info = {
|
261 |
"input_text": docs,
|
262 |
"entity_specs": "organization, person",
|
|
|
254 |
from api.db import LLMType
|
255 |
from api.db.services.llm_service import LLMBundle
|
256 |
from api.settings import retrievaler
|
257 |
+
from api.db.services.knowledgebase_service import KnowledgebaseService
|
258 |
+
|
259 |
+
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
260 |
|
261 |
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
262 |
+
docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
|
263 |
info = {
|
264 |
"input_text": docs,
|
265 |
"entity_specs": "organization, person",
|
graphrag/search.py
CHANGED
@@ -15,95 +15,90 @@
|
|
15 |
#
|
16 |
import json
|
17 |
from copy import deepcopy
|
|
|
18 |
|
19 |
import pandas as pd
|
20 |
-
from
|
21 |
|
22 |
from rag.nlp.search import Dealer
|
23 |
|
24 |
|
25 |
class KGSearch(Dealer):
|
26 |
-
def search(self, req, idxnm, emb_mdl
|
27 |
-
def merge_into_first(sres, title=""):
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
30 |
try:
|
31 |
-
df.append(json.loads(d["
|
32 |
-
except Exception
|
33 |
-
texts.append(d["
|
34 |
-
pass
|
35 |
-
if not df and not texts: return False
|
36 |
if df:
|
37 |
-
|
38 |
-
sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + pd.DataFrame(df).to_csv()
|
39 |
-
except Exception as e:
|
40 |
-
pass
|
41 |
else:
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
|
|
|
|
|
|
|
|
|
|
45 |
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
|
46 |
-
"
|
47 |
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight",
|
48 |
"weight_int", "weight_flt", "rank_int"
|
49 |
])
|
50 |
|
51 |
-
|
52 |
-
binary_query, keywords = self.qryr.question(qst, min_match="5%")
|
53 |
-
binary_query = self._add_filters(binary_query, req)
|
54 |
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
q_vec = []
|
63 |
-
if req.get("vector"):
|
64 |
-
assert emb_mdl, "No embedding model selected"
|
65 |
-
s["knn"] = self._vector(
|
66 |
-
qst, emb_mdl, req.get(
|
67 |
-
"similarity", 0.1), 1024)
|
68 |
-
s["knn"]["filter"] = bqry.to_dict()
|
69 |
-
q_vec = s["knn"]["query_vector"]
|
70 |
-
|
71 |
-
ent_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src)
|
72 |
-
entities = [d["name_kwd"] for d in self.es.getSource(ent_res)]
|
73 |
-
ent_ids = self.es.getDocIds(ent_res)
|
74 |
-
if merge_into_first(ent_res, "-Entities-"):
|
75 |
-
ent_ids = ent_ids[0:1]
|
76 |
|
77 |
## Community retrieval
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
if merge_into_first(comm_res, "-Community Report-"):
|
87 |
-
comm_ids = comm_ids[0:1]
|
88 |
|
89 |
## Text content retrieval
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
txt_ids = txt_ids[0:1]
|
99 |
|
100 |
return self.SearchResult(
|
101 |
total=len(ent_ids) + len(comm_ids) + len(txt_ids),
|
102 |
ids=[*ent_ids, *comm_ids, *txt_ids],
|
103 |
query_vector=q_vec,
|
104 |
-
aggregation=None,
|
105 |
highlight=None,
|
106 |
-
field={**
|
107 |
keywords=[]
|
108 |
)
|
109 |
-
|
|
|
15 |
#
|
16 |
import json
|
17 |
from copy import deepcopy
|
18 |
+
from typing import Dict
|
19 |
|
20 |
import pandas as pd
|
21 |
+
from rag.utils.doc_store_conn import OrderByExpr, FusionExpr
|
22 |
|
23 |
from rag.nlp.search import Dealer
|
24 |
|
25 |
|
26 |
class KGSearch(Dealer):
|
27 |
+
def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False):
|
28 |
+
def merge_into_first(sres, title="") -> Dict[str, str]:
|
29 |
+
if not sres:
|
30 |
+
return {}
|
31 |
+
content_with_weight = ""
|
32 |
+
df, texts = [],[]
|
33 |
+
for d in sres.values():
|
34 |
try:
|
35 |
+
df.append(json.loads(d["content_with_weight"]))
|
36 |
+
except Exception:
|
37 |
+
texts.append(d["content_with_weight"])
|
|
|
|
|
38 |
if df:
|
39 |
+
content_with_weight = title + "\n" + pd.DataFrame(df).to_csv()
|
|
|
|
|
|
|
40 |
else:
|
41 |
+
content_with_weight = title + "\n" + "\n".join(texts)
|
42 |
+
first_id = ""
|
43 |
+
first_source = {}
|
44 |
+
for k, v in sres.items():
|
45 |
+
first_id = id
|
46 |
+
first_source = deepcopy(v)
|
47 |
+
break
|
48 |
+
first_source["content_with_weight"] = content_with_weight
|
49 |
+
first_id = next(iter(sres))
|
50 |
+
return {first_id: first_source}
|
51 |
+
|
52 |
+
qst = req.get("question", "")
|
53 |
+
matchText, keywords = self.qryr.question(qst, min_match=0.05)
|
54 |
+
condition = self.get_filters(req)
|
55 |
|
56 |
+
## Entity retrieval
|
57 |
+
condition.update({"knowledge_graph_kwd": ["entity"]})
|
58 |
+
assert emb_mdl, "No embedding model selected"
|
59 |
+
matchDense = self.get_vector(qst, emb_mdl, 1024, req.get("similarity", 0.1))
|
60 |
+
q_vec = matchDense.embedding_data
|
61 |
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
|
62 |
+
"doc_id", f"q_{len(q_vec)}_vec", "position_list", "name_kwd",
|
63 |
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight",
|
64 |
"weight_int", "weight_flt", "rank_int"
|
65 |
])
|
66 |
|
67 |
+
fusionExpr = FusionExpr("weighted_sum", 32, {"weights": "0.5, 0.5"})
|
|
|
|
|
68 |
|
69 |
+
ent_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
|
70 |
+
ent_res_fields = self.dataStore.getFields(ent_res, src)
|
71 |
+
entities = [d["name_kwd"] for d in ent_res_fields.values()]
|
72 |
+
ent_ids = self.dataStore.getChunkIds(ent_res)
|
73 |
+
ent_content = merge_into_first(ent_res_fields, "-Entities-")
|
74 |
+
if ent_content:
|
75 |
+
ent_ids = list(ent_content.keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
## Community retrieval
|
78 |
+
condition = self.get_filters(req)
|
79 |
+
condition.update({"entities_kwd": entities, "knowledge_graph_kwd": ["community_report"]})
|
80 |
+
comm_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
|
81 |
+
comm_res_fields = self.dataStore.getFields(comm_res, src)
|
82 |
+
comm_ids = self.dataStore.getChunkIds(comm_res)
|
83 |
+
comm_content = merge_into_first(comm_res_fields, "-Community Report-")
|
84 |
+
if comm_content:
|
85 |
+
comm_ids = list(comm_content.keys())
|
|
|
|
|
86 |
|
87 |
## Text content retrieval
|
88 |
+
condition = self.get_filters(req)
|
89 |
+
condition.update({"knowledge_graph_kwd": ["text"]})
|
90 |
+
txt_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 6, idxnm, kb_ids)
|
91 |
+
txt_res_fields = self.dataStore.getFields(txt_res, src)
|
92 |
+
txt_ids = self.dataStore.getChunkIds(txt_res)
|
93 |
+
txt_content = merge_into_first(txt_res_fields, "-Original Content-")
|
94 |
+
if txt_content:
|
95 |
+
txt_ids = list(txt_content.keys())
|
|
|
96 |
|
97 |
return self.SearchResult(
|
98 |
total=len(ent_ids) + len(comm_ids) + len(txt_ids),
|
99 |
ids=[*ent_ids, *comm_ids, *txt_ids],
|
100 |
query_vector=q_vec,
|
|
|
101 |
highlight=None,
|
102 |
+
field={**ent_content, **comm_content, **txt_content},
|
103 |
keywords=[]
|
104 |
)
|
|
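`KGSearch.search` now runs the same store-level query three times with different `knowledge_graph_kwd` filters (entity, community_report, text), fusing a full-text match and a dense match with equal weights. A condensed sketch of one such stage, with the argument order copied from the `dataStore.search` calls above; `kg_stage` and `dealer` are illustrative names, not code from the diff:

```python
# One retrieval stage of the knowledge-graph search, written as a standalone helper.
from rag.utils.doc_store_conn import OrderByExpr, FusionExpr

def kg_stage(dealer, question, emb_mdl, condition, idxnm, kb_ids, limit=32):
    matchText, _keywords = dealer.qryr.question(question, min_match=0.05)
    matchDense = dealer.get_vector(question, emb_mdl, 1024, 0.1)
    fusion = FusionExpr("weighted_sum", limit, {"weights": "0.5, 0.5"})
    src = ["content_with_weight", "name_kwd", "kb_id"]   # fields to return
    res = dealer.dataStore.search(
        src,
        list(),                 # no highlight fields
        condition,              # e.g. {"knowledge_graph_kwd": ["entity"]}
        [matchText, matchDense, fusion],
        OrderByExpr(), 0, limit, idxnm, kb_ids)
    return dealer.dataStore.getFields(res, src), dealer.dataStore.getChunkIds(res)
```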
graphrag/smoke.py
CHANGED
@@ -31,10 +31,13 @@ if __name__ == "__main__":
|
|
31 |
from api.db import LLMType
|
32 |
from api.db.services.llm_service import LLMBundle
|
33 |
from api.settings import retrievaler
|
|
|
|
|
|
|
34 |
|
35 |
ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
36 |
docs = [d["content_with_weight"] for d in
|
37 |
-
retrievaler.chunk_list(args.doc_id, args.tenant_id, max_count=6, fields=["content_with_weight"])]
|
38 |
graph = ex(docs)
|
39 |
|
40 |
er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))
|
|
|
31 |
from api.db import LLMType
|
32 |
from api.db.services.llm_service import LLMBundle
|
33 |
from api.settings import retrievaler
|
34 |
+
from api.db.services.knowledgebase_service import KnowledgebaseService
|
35 |
+
|
36 |
+
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
37 |
|
38 |
ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
39 |
docs = [d["content_with_weight"] for d in
|
40 |
+
retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
|
41 |
graph = ex(docs)
|
42 |
|
43 |
er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))
|
poetry.lock
CHANGED
The diff for this file is too large to render.
|
|
pyproject.toml
CHANGED
@@ -46,22 +46,23 @@ hanziconv = "0.3.2"
|
|
46 |
html-text = "0.6.2"
|
47 |
httpx = "0.27.0"
|
48 |
huggingface-hub = "^0.25.0"
|
49 |
-
infinity-
|
|
|
50 |
itsdangerous = "2.1.2"
|
51 |
markdown = "3.6"
|
52 |
markdown-to-json = "2.1.1"
|
53 |
minio = "7.2.4"
|
54 |
mistralai = "0.4.2"
|
55 |
nltk = "3.9.1"
|
56 |
-
numpy = "1.26.
|
57 |
ollama = "0.2.1"
|
58 |
onnxruntime = "1.19.2"
|
59 |
openai = "1.45.0"
|
60 |
opencv-python = "4.10.0.84"
|
61 |
opencv-python-headless = "4.10.0.84"
|
62 |
-
openpyxl = "3.1.
|
63 |
ormsgpack = "1.5.0"
|
64 |
-
pandas = "2.2.
|
65 |
pdfplumber = "0.10.4"
|
66 |
peewee = "3.17.1"
|
67 |
pillow = "10.4.0"
|
@@ -70,7 +71,7 @@ psycopg2-binary = "2.9.9"
|
|
70 |
pyclipper = "1.3.0.post5"
|
71 |
pycryptodomex = "3.20.0"
|
72 |
pypdf = "^5.0.0"
|
73 |
-
pytest = "8.
|
74 |
python-dotenv = "1.0.1"
|
75 |
python-dateutil = "2.8.2"
|
76 |
python-pptx = "^1.0.2"
|
@@ -86,7 +87,7 @@ ruamel-base = "1.0.0"
|
|
86 |
scholarly = "1.7.11"
|
87 |
scikit-learn = "1.5.0"
|
88 |
selenium = "4.22.0"
|
89 |
-
setuptools = "
|
90 |
shapely = "2.0.5"
|
91 |
six = "1.16.0"
|
92 |
strenum = "0.4.15"
|
@@ -115,6 +116,7 @@ pymysql = "^1.1.1"
|
|
115 |
mini-racer = "^0.12.4"
|
116 |
pyicu = "^2.13.1"
|
117 |
flasgger = "^0.9.7.1"
|
|
|
118 |
|
119 |
|
120 |
[tool.poetry.group.full]
|
|
|
46 |
html-text = "0.6.2"
|
47 |
httpx = "0.27.0"
|
48 |
huggingface-hub = "^0.25.0"
|
49 |
+
infinity-sdk = "0.5.0.dev2"
|
50 |
+
infinity-emb = "^0.0.66"
|
51 |
itsdangerous = "2.1.2"
|
52 |
markdown = "3.6"
|
53 |
markdown-to-json = "2.1.1"
|
54 |
minio = "7.2.4"
|
55 |
mistralai = "0.4.2"
|
56 |
nltk = "3.9.1"
|
57 |
+
numpy = "^1.26.0"
|
58 |
ollama = "0.2.1"
|
59 |
onnxruntime = "1.19.2"
|
60 |
openai = "1.45.0"
|
61 |
opencv-python = "4.10.0.84"
|
62 |
opencv-python-headless = "4.10.0.84"
|
63 |
+
openpyxl = "^3.1.0"
|
64 |
ormsgpack = "1.5.0"
|
65 |
+
pandas = "^2.2.0"
|
66 |
pdfplumber = "0.10.4"
|
67 |
peewee = "3.17.1"
|
68 |
pillow = "10.4.0"
|
|
|
71 |
pyclipper = "1.3.0.post5"
|
72 |
pycryptodomex = "3.20.0"
|
73 |
pypdf = "^5.0.0"
|
74 |
+
pytest = "^8.3.0"
|
75 |
python-dotenv = "1.0.1"
|
76 |
python-dateutil = "2.8.2"
|
77 |
python-pptx = "^1.0.2"
|
|
|
87 |
scholarly = "1.7.11"
|
88 |
scikit-learn = "1.5.0"
|
89 |
selenium = "4.22.0"
|
90 |
+
setuptools = "^75.2.0"
|
91 |
shapely = "2.0.5"
|
92 |
six = "1.16.0"
|
93 |
strenum = "0.4.15"
|
|
|
116 |
mini-racer = "^0.12.4"
|
117 |
pyicu = "^2.13.1"
|
118 |
flasgger = "^0.9.7.1"
|
119 |
+
polars = "^1.9.0"
|
120 |
|
121 |
|
122 |
[tool.poetry.group.full]
|
rag/app/presentation.py
CHANGED
@@ -20,6 +20,7 @@ from rag.nlp import tokenize, is_english
|
|
20 |
from rag.nlp import rag_tokenizer
|
21 |
from deepdoc.parser import PdfParser, PptParser, PlainParser
|
22 |
from PyPDF2 import PdfReader as pdf2_read
|
|
|
23 |
|
24 |
|
25 |
class Ppt(PptParser):
|
@@ -107,9 +108,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|
107 |
d = copy.deepcopy(doc)
|
108 |
pn += from_page
|
109 |
d["image"] = img
|
110 |
-
d["
|
111 |
-
d["
|
112 |
-
d["
|
113 |
tokenize(d, txt, eng)
|
114 |
res.append(d)
|
115 |
return res
|
@@ -123,10 +124,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|
123 |
pn += from_page
|
124 |
if img:
|
125 |
d["image"] = img
|
126 |
-
d["
|
127 |
-
d["
|
128 |
-
d["
|
129 |
-
(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)]
|
130 |
tokenize(d, txt, eng)
|
131 |
res.append(d)
|
132 |
return res
|
|
|
20 |
from rag.nlp import rag_tokenizer
|
21 |
from deepdoc.parser import PdfParser, PptParser, PlainParser
|
22 |
from PyPDF2 import PdfReader as pdf2_read
|
23 |
+
import json
|
24 |
|
25 |
|
26 |
class Ppt(PptParser):
|
|
|
108 |
d = copy.deepcopy(doc)
|
109 |
pn += from_page
|
110 |
d["image"] = img
|
111 |
+
d["page_num_list"] = json.dumps([pn + 1])
|
112 |
+
d["top_list"] = json.dumps([0])
|
113 |
+
d["position_list"] = json.dumps([(pn + 1, 0, img.size[0], 0, img.size[1])])
|
114 |
tokenize(d, txt, eng)
|
115 |
res.append(d)
|
116 |
return res
|
|
|
124 |
pn += from_page
|
125 |
if img:
|
126 |
d["image"] = img
|
127 |
+
d["page_num_list"] = json.dumps([pn + 1])
|
128 |
+
d["top_list"] = json.dumps([0])
|
129 |
+
d["position_list"] = json.dumps([
|
130 |
+
(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)])
|
131 |
tokenize(d, txt, eng)
|
132 |
res.append(d)
|
133 |
return res
|
rag/app/table.py
CHANGED
@@ -74,7 +74,7 @@ class Excel(ExcelParser):
|
|
74 |
def trans_datatime(s):
|
75 |
try:
|
76 |
return datetime_parse(s.strip()).strftime("%Y-%m-%d %H:%M:%S")
|
77 |
-
except Exception
|
78 |
pass
|
79 |
|
80 |
|
@@ -112,7 +112,7 @@ def column_data_type(arr):
|
|
112 |
continue
|
113 |
try:
|
114 |
arr[i] = trans[ty](str(arr[i]))
|
115 |
-
except Exception
|
116 |
arr[i] = None
|
117 |
# if ty == "text":
|
118 |
# if len(arr) > 128 and uni / len(arr) < 0.1:
|
@@ -182,7 +182,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000,
|
|
182 |
"datetime": "_dt",
|
183 |
"bool": "_kwd"}
|
184 |
for df in dfs:
|
185 |
-
for n in ["id", "
|
186 |
if n in df.columns:
|
187 |
del df[n]
|
188 |
clmns = df.columns.values
|
|
|
74 |
def trans_datatime(s):
|
75 |
try:
|
76 |
return datetime_parse(s.strip()).strftime("%Y-%m-%d %H:%M:%S")
|
77 |
+
except Exception:
|
78 |
pass
|
79 |
|
80 |
|
|
|
112 |
continue
|
113 |
try:
|
114 |
arr[i] = trans[ty](str(arr[i]))
|
115 |
+
except Exception:
|
116 |
arr[i] = None
|
117 |
# if ty == "text":
|
118 |
# if len(arr) > 128 and uni / len(arr) < 0.1:
|
|
|
182 |
"datetime": "_dt",
|
183 |
"bool": "_kwd"}
|
184 |
for df in dfs:
|
185 |
+
for n in ["id", "index", "idx"]:
|
186 |
if n in df.columns:
|
187 |
del df[n]
|
188 |
clmns = df.columns.values
|
rag/benchmark.py
CHANGED
@@ -1,280 +1,310 @@
|
|
1 |
-
#
|
2 |
-
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
#
|
16 |
-
import json
|
17 |
-
import os
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
from api.db
|
24 |
-
from api.db.services.
|
25 |
-
from api.
|
26 |
-
from api.
|
27 |
-
from api.utils
|
28 |
-
from rag.nlp import tokenize, search
|
29 |
-
from
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
class Benchmark:
|
37 |
-
def __init__(self, kb_id):
|
38 |
-
|
39 |
-
self.
|
40 |
-
self.
|
41 |
-
self.
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
run = defaultdict(dict)
|
51 |
-
query_list = list(qrels.keys())
|
52 |
-
for query in query_list:
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
for
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
"
|
109 |
-
"
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
docs = []
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
docs
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
for
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
if
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
import json
|
17 |
+
import os
|
18 |
+
import sys
|
19 |
+
import time
|
20 |
+
import argparse
|
21 |
+
from collections import defaultdict
|
22 |
+
|
23 |
+
from api.db import LLMType
|
24 |
+
from api.db.services.llm_service import LLMBundle
|
25 |
+
from api.db.services.knowledgebase_service import KnowledgebaseService
|
26 |
+
from api.settings import retrievaler, docStoreConn
|
27 |
+
from api.utils import get_uuid
|
28 |
+
from rag.nlp import tokenize, search
|
29 |
+
from ranx import evaluate
|
30 |
+
import pandas as pd
|
31 |
+
from tqdm import tqdm
|
32 |
+
|
33 |
+
global max_docs
|
34 |
+
max_docs = sys.maxsize
|
35 |
+
|
36 |
+
class Benchmark:
|
37 |
+
def __init__(self, kb_id):
|
38 |
+
self.kb_id = kb_id
|
39 |
+
e, self.kb = KnowledgebaseService.get_by_id(kb_id)
|
40 |
+
self.similarity_threshold = self.kb.similarity_threshold
|
41 |
+
self.vector_similarity_weight = self.kb.vector_similarity_weight
|
42 |
+
self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
|
43 |
+
self.tenant_id = ''
|
44 |
+
self.index_name = ''
|
45 |
+
self.initialized_index = False
|
46 |
+
|
47 |
+
def _get_retrieval(self, qrels):
|
48 |
+
# Need to wait for the ES and Infinity index to be ready
|
49 |
+
time.sleep(20)
|
50 |
+
run = defaultdict(dict)
|
51 |
+
query_list = list(qrels.keys())
|
52 |
+
for query in query_list:
|
53 |
+
ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
|
54 |
+
0.0, self.vector_similarity_weight)
|
55 |
+
if len(ranks["chunks"]) == 0:
|
56 |
+
print(f"deleted query: {query}")
|
57 |
+
del qrels[query]
|
58 |
+
continue
|
59 |
+
for c in ranks["chunks"]:
|
60 |
+
if "vector" in c:
|
61 |
+
del c["vector"]
|
62 |
+
run[query][c["chunk_id"]] = c["similarity"]
|
63 |
+
return run
|
64 |
+
|
65 |
+
def embedding(self, docs, batch_size=16):
|
66 |
+
vects = []
|
67 |
+
cnts = [d["content_with_weight"] for d in docs]
|
68 |
+
for i in range(0, len(cnts), batch_size):
|
69 |
+
vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
|
70 |
+
vects.extend(vts.tolist())
|
71 |
+
assert len(docs) == len(vects)
|
72 |
+
vector_size = 0
|
73 |
+
for i, d in enumerate(docs):
|
74 |
+
v = vects[i]
|
75 |
+
vector_size = len(v)
|
76 |
+
d["q_%d_vec" % len(v)] = v
|
77 |
+
return docs, vector_size
|
78 |
+
|
79 |
+
def init_index(self, vector_size: int):
|
80 |
+
if self.initialized_index:
|
81 |
+
return
|
82 |
+
if docStoreConn.indexExist(self.index_name, self.kb_id):
|
83 |
+
docStoreConn.deleteIdx(self.index_name, self.kb_id)
|
84 |
+
docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
|
85 |
+
self.initialized_index = True
|
86 |
+
|
87 |
+
def ms_marco_index(self, file_path, index_name):
|
88 |
+
qrels = defaultdict(dict)
|
89 |
+
texts = defaultdict(dict)
|
90 |
+
docs_count = 0
|
91 |
+
docs = []
|
92 |
+
filelist = sorted(os.listdir(file_path))
|
93 |
+
|
94 |
+
for fn in filelist:
|
95 |
+
if docs_count >= max_docs:
|
96 |
+
break
|
97 |
+
if not fn.endswith(".parquet"):
|
98 |
+
continue
|
99 |
+
data = pd.read_parquet(os.path.join(file_path, fn))
|
100 |
+
for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + fn):
|
101 |
+
if docs_count >= max_docs:
|
102 |
+
break
|
103 |
+
query = data.iloc[i]['query']
|
104 |
+
for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
|
105 |
+
d = {
|
106 |
+
"id": get_uuid(),
|
107 |
+
"kb_id": self.kb.id,
|
108 |
+
"docnm_kwd": "xxxxx",
|
109 |
+
"doc_id": "ksksks"
|
110 |
+
}
|
111 |
+
tokenize(d, text, "english")
|
112 |
+
docs.append(d)
|
113 |
+
texts[d["id"]] = text
|
114 |
+
qrels[query][d["id"]] = int(rel)
|
115 |
+
if len(docs) >= 32:
|
116 |
+
docs_count += len(docs)
|
117 |
+
docs, vector_size = self.embedding(docs)
|
118 |
+
self.init_index(vector_size)
|
119 |
+
docStoreConn.insert(docs, self.index_name, self.kb_id)
|
120 |
+
docs = []
|
121 |
+
|
122 |
+
if docs:
|
123 |
+
docs, vector_size = self.embedding(docs)
|
124 |
+
self.init_index(vector_size)
|
125 |
+
docStoreConn.insert(docs, self.index_name, self.kb_id)
|
126 |
+
return qrels, texts
|
127 |
+
|
128 |
+
def trivia_qa_index(self, file_path, index_name):
|
129 |
+
qrels = defaultdict(dict)
|
130 |
+
texts = defaultdict(dict)
|
131 |
+
docs_count = 0
|
132 |
+
docs = []
|
133 |
+
filelist = sorted(os.listdir(file_path))
|
134 |
+
for fn in filelist:
|
135 |
+
if docs_count >= max_docs:
|
136 |
+
break
|
137 |
+
if not fn.endswith(".parquet"):
|
138 |
+
continue
|
139 |
+
data = pd.read_parquet(os.path.join(file_path, fn))
|
140 |
+
for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + fn):
|
141 |
+
if docs_count >= max_docs:
|
142 |
+
break
|
143 |
+
query = data.iloc[i]['question']
|
144 |
+
for rel, text in zip(data.iloc[i]["search_results"]['rank'],
|
145 |
+
data.iloc[i]["search_results"]['search_context']):
|
146 |
+
d = {
|
147 |
+
"id": get_uuid(),
|
148 |
+
"kb_id": self.kb.id,
|
149 |
+
"docnm_kwd": "xxxxx",
|
150 |
+
"doc_id": "ksksks"
|
151 |
+
}
|
152 |
+
tokenize(d, text, "english")
|
153 |
+
docs.append(d)
|
154 |
+
texts[d["id"]] = text
|
155 |
+
qrels[query][d["id"]] = int(rel)
|
156 |
+
if len(docs) >= 32:
|
157 |
+
docs_count += len(docs)
|
158 |
+
docs, vector_size = self.embedding(docs)
|
159 |
+
self.init_index(vector_size)
|
160 |
+
docStoreConn.insert(docs,self.index_name)
|
161 |
+
docs = []
|
162 |
+
|
163 |
+
docs, vector_size = self.embedding(docs)
|
164 |
+
self.init_index(vector_size)
|
165 |
+
docStoreConn.insert(docs, self.index_name)
|
166 |
+
return qrels, texts
|
167 |
+
|
168 |
+
def miracl_index(self, file_path, corpus_path, index_name):
|
169 |
+
corpus_total = {}
|
170 |
+
for corpus_file in os.listdir(corpus_path):
|
171 |
+
tmp_data = pd.read_json(os.path.join(corpus_path, corpus_file), lines=True)
|
172 |
+
for index, i in tmp_data.iterrows():
|
173 |
+
corpus_total[i['docid']] = i['text']
|
174 |
+
|
175 |
+
topics_total = {}
|
176 |
+
for topics_file in os.listdir(os.path.join(file_path, 'topics')):
|
177 |
+
if 'test' in topics_file:
|
178 |
+
continue
|
179 |
+
tmp_data = pd.read_csv(os.path.join(file_path, 'topics', topics_file), sep='\t', names=['qid', 'query'])
|
180 |
+
for index, i in tmp_data.iterrows():
|
181 |
+
topics_total[i['qid']] = i['query']
|
182 |
+
|
183 |
+
qrels = defaultdict(dict)
|
184 |
+
texts = defaultdict(dict)
|
185 |
+
docs_count = 0
|
186 |
+
docs = []
|
187 |
+
for qrels_file in os.listdir(os.path.join(file_path, 'qrels')):
|
188 |
+
if 'test' in qrels_file:
|
189 |
+
continue
|
190 |
+
if docs_count >= max_docs:
|
191 |
+
break
|
192 |
+
|
193 |
+
tmp_data = pd.read_csv(os.path.join(file_path, 'qrels', qrels_file), sep='\t',
|
194 |
+
names=['qid', 'Q0', 'docid', 'relevance'])
|
195 |
+
for i in tqdm(range(len(tmp_data)), colour="green", desc="Indexing:" + qrels_file):
|
196 |
+
if docs_count >= max_docs:
|
197 |
+
break
|
198 |
+
query = topics_total[tmp_data.iloc[i]['qid']]
|
199 |
+
text = corpus_total[tmp_data.iloc[i]['docid']]
|
200 |
+
rel = tmp_data.iloc[i]['relevance']
|
201 |
+
d = {
|
202 |
+
"id": get_uuid(),
|
203 |
+
"kb_id": self.kb.id,
|
204 |
+
"docnm_kwd": "xxxxx",
|
205 |
+
"doc_id": "ksksks"
|
206 |
+
}
|
207 |
+
tokenize(d, text, 'english')
|
208 |
+
docs.append(d)
|
209 |
+
texts[d["id"]] = text
|
210 |
+
qrels[query][d["id"]] = int(rel)
|
211 |
+
if len(docs) >= 32:
|
212 |
+
docs_count += len(docs)
|
213 |
+
docs, vector_size = self.embedding(docs)
|
214 |
+
self.init_index(vector_size)
|
215 |
+
docStoreConn.insert(docs, self.index_name)
|
216 |
+
docs = []
|
217 |
+
|
218 |
+
docs, vector_size = self.embedding(docs)
|
219 |
+
self.init_index(vector_size)
|
220 |
+
docStoreConn.insert(docs, self.index_name)
|
221 |
+
return qrels, texts
|
222 |
+
|
223 |
+
def save_results(self, qrels, run, texts, dataset, file_path):
|
224 |
+
keep_result = []
|
225 |
+
run_keys = list(run.keys())
|
226 |
+
for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"):
|
227 |
+
key = run_keys[run_i]
|
228 |
+
keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
|
229 |
+
'ndcg@10': evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")})
|
230 |
+
keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
|
231 |
+
with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f:
|
232 |
+
f.write('## Score For Every Query\n')
|
233 |
+
for keep_result_i in keep_result:
|
234 |
+
f.write('### query: ' + keep_result_i['query'] + ' ndcg@10:' + str(keep_result_i['ndcg@10']) + '\n')
|
235 |
+
scores = [[i[0], i[1]] for i in keep_result_i['run'].items()]
|
236 |
+
scores = sorted(scores, key=lambda kk: kk[1])
|
237 |
+
for score in scores[:10]:
|
238 |
+
f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n')
|
239 |
+
json.dump(qrels, open(os.path.join(file_path, dataset + '.qrels.json'), "w+"), indent=2)
|
240 |
+
json.dump(run, open(os.path.join(file_path, dataset + '.run.json'), "w+"), indent=2)
|
241 |
+
print(os.path.join(file_path, dataset + '_result.md'), 'Saved!')
|
242 |
+
|
243 |
+
def __call__(self, dataset, file_path, miracl_corpus=''):
|
244 |
+
if dataset == "ms_marco_v1.1":
|
245 |
+
self.tenant_id = "benchmark_ms_marco_v11"
|
246 |
+
self.index_name = search.index_name(self.tenant_id)
|
247 |
+
qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
|
248 |
+
run = self._get_retrieval(qrels)
|
249 |
+
print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
|
250 |
+
self.save_results(qrels, run, texts, dataset, file_path)
|
251 |
+
if dataset == "trivia_qa":
|
252 |
+
self.tenant_id = "benchmark_trivia_qa"
|
253 |
+
self.index_name = search.index_name(self.tenant_id)
|
254 |
+
qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
|
255 |
+
run = self._get_retrieval(qrels)
|
256 |
+
print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
|
257 |
+
self.save_results(qrels, run, texts, dataset, file_path)
|
258 |
+
if dataset == "miracl":
|
259 |
+
for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
|
260 |
+
'yo', 'zh']:
|
261 |
+
if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang)):
|
262 |
+
print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang) + ' not found!')
|
263 |
+
continue
|
264 |
+
if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels')):
|
265 |
+
print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels') + 'not found!')
|
266 |
+
continue
|
267 |
+
if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics')):
|
268 |
+
print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics') + 'not found!')
|
269 |
+
continue
|
270 |
+
if not os.path.isdir(os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang)):
|
271 |
+
print('Directory: ' + os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang) + ' not found!')
|
272 |
+
continue
|
273 |
+
self.tenant_id = "benchmark_miracl_" + lang
|
274 |
+
self.index_name = search.index_name(self.tenant_id)
|
275 |
+
self.initialized_index = False
|
276 |
+
qrels, texts = self.miracl_index(os.path.join(file_path, 'miracl-v1.0-' + lang),
|
277 |
+
os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
|
278 |
+
"benchmark_miracl_" + lang)
|
279 |
+
run = self._get_retrieval(qrels)
|
280 |
+
print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
|
281 |
+
self.save_results(qrels, run, texts, dataset, file_path)
|
282 |
+
|
283 |
+
|
284 |
+
if __name__ == '__main__':
|
285 |
+
print('*****************RAGFlow Benchmark*****************')
|
286 |
+
parser = argparse.ArgumentParser(usage="benchmark.py <max_docs> <kb_id> <dataset> <dataset_path> [<miracl_corpus_path>])", description='RAGFlow Benchmark')
|
287 |
+
parser.add_argument('max_docs', metavar='max_docs', type=int, help='max docs to evaluate')
|
288 |
+
parser.add_argument('kb_id', metavar='kb_id', help='knowledgebase id')
|
289 |
+
parser.add_argument('dataset', metavar='dataset', help='dataset name, shall be one of ms_marco_v1.1(https://huggingface.co/datasets/microsoft/ms_marco), trivia_qa(https://huggingface.co/datasets/mandarjoshi/trivia_qa>), miracl(https://huggingface.co/datasets/miracl/miracl')
|
290 |
+
parser.add_argument('dataset_path', metavar='dataset_path', help='dataset path')
|
291 |
+
parser.add_argument('miracl_corpus_path', metavar='miracl_corpus_path', nargs='?', default="", help='miracl corpus path. Only needed when dataset is miracl')
|
292 |
+
|
293 |
+
args = parser.parse_args()
|
294 |
+
max_docs = args.max_docs
|
295 |
+
kb_id = args.kb_id
|
296 |
+
ex = Benchmark(kb_id)
|
297 |
+
|
298 |
+
dataset = args.dataset
|
299 |
+
dataset_path = args.dataset_path
|
300 |
+
|
301 |
+
if dataset == "ms_marco_v1.1" or dataset == "trivia_qa":
|
302 |
+
ex(dataset, dataset_path)
|
303 |
+
elif dataset == "miracl":
|
304 |
+
if len(args) < 5:
|
305 |
+
print('Please input the correct parameters!')
|
306 |
+
exit(1)
|
307 |
+
miracl_corpus_path = args[4]
|
308 |
+
ex(dataset, dataset_path, miracl_corpus=args.miracl_corpus_path)
|
309 |
+
else:
|
310 |
+
print("Dataset: ", dataset, "not supported!")
|
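The benchmark hands ranx plain nested dictionaries: `qrels` maps each query to chunk-id → graded relevance, and `run` maps the same query to chunk-id → retrieval score. A tiny self-contained illustration of that shape, with made-up ids and scores:

```python
from ranx import evaluate

qrels = {"what is ragflow": {"chunk-a": 1, "chunk-b": 0}}
run = {"what is ragflow": {"chunk-a": 0.92, "chunk-b": 0.31}}

# Same metric list the benchmark prints; returns a dict of metric -> value.
print(evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
```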
rag/nlp/__init__.py
CHANGED
@@ -25,6 +25,7 @@ import roman_numbers as r
|
|
25 |
from word2number import w2n
|
26 |
from cn2an import cn2an
|
27 |
from PIL import Image
|
|
|
28 |
|
29 |
all_codecs = [
|
30 |
'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
|
@@ -51,12 +52,12 @@ def find_codec(blob):
|
|
51 |
try:
|
52 |
blob[:1024].decode(c)
|
53 |
return c
|
54 |
-
except Exception
|
55 |
pass
|
56 |
try:
|
57 |
blob.decode(c)
|
58 |
return c
|
59 |
-
except Exception
|
60 |
pass
|
61 |
|
62 |
return "utf-8"
|
@@ -241,7 +242,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
|
|
241 |
d["image"], poss = pdf_parser.crop(ck, need_position=True)
|
242 |
add_positions(d, poss)
|
243 |
ck = pdf_parser.remove_tag(ck)
|
244 |
-
except NotImplementedError
|
245 |
pass
|
246 |
tokenize(d, ck, eng)
|
247 |
res.append(d)
|
@@ -289,13 +290,16 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
|
|
289 |
def add_positions(d, poss):
|
290 |
if not poss:
|
291 |
return
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
for pn, left, right, top, bottom in poss:
|
296 |
-
|
297 |
-
|
298 |
-
|
|
|
|
|
|
|
299 |
|
300 |
|
301 |
def remove_contents_table(sections, eng=False):
|
|
|
25 |
from word2number import w2n
|
26 |
from cn2an import cn2an
|
27 |
from PIL import Image
|
28 |
+
import json
|
29 |
|
30 |
all_codecs = [
|
31 |
'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
|
|
|
52 |
try:
|
53 |
blob[:1024].decode(c)
|
54 |
return c
|
55 |
+
except Exception:
|
56 |
pass
|
57 |
try:
|
58 |
blob.decode(c)
|
59 |
return c
|
60 |
+
except Exception:
|
61 |
pass
|
62 |
|
63 |
return "utf-8"
|
|
|
242 |
d["image"], poss = pdf_parser.crop(ck, need_position=True)
|
243 |
add_positions(d, poss)
|
244 |
ck = pdf_parser.remove_tag(ck)
|
245 |
+
except NotImplementedError:
|
246 |
pass
|
247 |
tokenize(d, ck, eng)
|
248 |
res.append(d)
|
|
|
290 |
def add_positions(d, poss):
|
291 |
if not poss:
|
292 |
return
|
293 |
+
page_num_list = []
|
294 |
+
position_list = []
|
295 |
+
top_list = []
|
296 |
for pn, left, right, top, bottom in poss:
|
297 |
+
page_num_list.append(int(pn + 1))
|
298 |
+
top_list.append(int(top))
|
299 |
+
position_list.append((int(pn + 1), int(left), int(right), int(top), int(bottom)))
|
300 |
+
d["page_num_list"] = json.dumps(page_num_list)
|
301 |
+
d["position_list"] = json.dumps(position_list)
|
302 |
+
d["top_list"] = json.dumps(top_list)
|
303 |
|
304 |
|
305 |
def remove_contents_table(sections, eng=False):
|
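`add_positions` now stores page numbers, tops and position tuples as JSON strings, which keeps those fields plain varchar for the Infinity schema while remaining readable from Elasticsearch. A round-trip illustration with made-up coordinates:

```python
import json

# (page, left, right, top, bottom) tuples, values invented for the example.
poss = [(0, 10, 200, 30, 60), (1, 12, 198, 28, 58)]

d = {}
d["page_num_list"] = json.dumps([pn + 1 for pn, *_ in poss])
d["position_list"] = json.dumps([(pn + 1, l, r, t, b) for pn, l, r, t, b in poss])
d["top_list"] = json.dumps([t for _, _, _, t, _ in poss])

# Readers decode with json.loads, e.g. when rebuilding highlight boxes.
print(json.loads(d["position_list"])[0])  # -> [1, 10, 200, 30, 60]
```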
rag/nlp/query.py
CHANGED
@@ -15,20 +15,25 @@
|
|
15 |
#
|
16 |
|
17 |
import json
|
18 |
-
import math
|
19 |
import re
|
20 |
import logging
|
21 |
-
import
|
22 |
-
from elasticsearch_dsl import Q
|
23 |
|
24 |
from rag.nlp import rag_tokenizer, term_weight, synonym
|
25 |
|
26 |
-
|
27 |
-
|
|
|
28 |
self.tw = term_weight.Dealer()
|
29 |
-
self.es = es
|
30 |
self.syn = synonym.Dealer()
|
31 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
@staticmethod
|
34 |
def subSpecialChar(line):
|
@@ -43,12 +48,15 @@ class EsQueryer:
|
|
43 |
for t in arr:
|
44 |
if not re.match(r"[a-zA-Z]+$", t):
|
45 |
e += 1
|
46 |
-
return e * 1. / len(arr) >= 0.7
|
47 |
|
48 |
@staticmethod
|
49 |
def rmWWW(txt):
|
50 |
patts = [
|
51 |
-
(
|
|
|
|
|
|
|
52 |
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
53 |
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of) ", " ")
|
54 |
]
|
@@ -56,16 +64,16 @@ class EsQueryer:
|
|
56 |
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
|
57 |
return txt
|
58 |
|
59 |
-
def question(self, txt, tbl="qa", min_match=
|
60 |
txt = re.sub(
|
61 |
r"[ :\r\n\t,,。??/`!!&\^%%]+",
|
62 |
" ",
|
63 |
-
rag_tokenizer.tradi2simp(
|
64 |
-
|
65 |
-
|
66 |
|
67 |
if not self.isChinese(txt):
|
68 |
-
txt =
|
69 |
tks = rag_tokenizer.tokenize(txt).split(" ")
|
70 |
tks_w = self.tw.weights(tks)
|
71 |
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
|
@@ -73,14 +81,20 @@ class EsQueryer:
|
|
73 |
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
|
74 |
q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
|
75 |
for i in range(1, len(tks_w)):
|
76 |
-
q.append(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
if not q:
|
78 |
q.append(txt)
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
), list(set([t for t in txt.split(" ") if t]))
|
84 |
|
85 |
def need_fine_grained_tokenize(tk):
|
86 |
if len(tk) < 3:
|
@@ -89,7 +103,7 @@ class EsQueryer:
|
|
89 |
return False
|
90 |
return True
|
91 |
|
92 |
-
txt =
|
93 |
qs, keywords = [], []
|
94 |
for tt in self.tw.split(txt)[:256]: # .split(" "):
|
95 |
if not tt:
|
@@ -101,65 +115,71 @@ class EsQueryer:
|
|
101 |
logging.info(json.dumps(twts, ensure_ascii=False))
|
102 |
tms = []
|
103 |
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
|
104 |
-
sm =
|
|
|
|
|
|
|
|
|
105 |
sm = [
|
106 |
re.sub(
|
107 |
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
|
108 |
"",
|
109 |
-
m
|
110 |
-
|
|
|
|
|
|
|
111 |
sm = [m for m in sm if len(m) > 1]
|
112 |
|
113 |
keywords.append(re.sub(r"[ \\\"']+", "", tk))
|
114 |
keywords.extend(sm)
|
115 |
-
if len(keywords) >= 12:
|
|
|
116 |
|
117 |
tk_syns = self.syn.lookup(tk)
|
118 |
-
tk =
|
119 |
if tk.find(" ") > 0:
|
120 |
-
tk = "
|
121 |
if tk_syns:
|
122 |
tk = f"({tk} %s)" % " ".join(tk_syns)
|
123 |
if sm:
|
124 |
-
tk = f
|
125 |
-
" ".join(sm), " ".join(sm))
|
126 |
if tk.strip():
|
127 |
tms.append((tk, w))
|
128 |
|
129 |
tms = " ".join([f"({t})^{w}" for t, w in tms])
|
130 |
|
131 |
if len(twts) > 1:
|
132 |
-
tms +=
|
133 |
if re.match(r"[0-9a-z ]+$", tt):
|
134 |
-
tms = f
|
135 |
|
136 |
syns = " OR ".join(
|
137 |
-
[
|
|
|
|
|
|
|
|
|
|
|
138 |
if syns:
|
139 |
tms = f"({tms})^5 OR ({syns})^0.7"
|
140 |
|
141 |
qs.append(tms)
|
142 |
|
143 |
-
flds = copy.deepcopy(self.flds)
|
144 |
-
mst = []
|
145 |
if qs:
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
)
|
150 |
-
|
151 |
-
return Q("bool",
|
152 |
-
must=mst,
|
153 |
-
), list(set(keywords))
|
154 |
|
155 |
-
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3,
|
156 |
-
vtweight=0.7):
|
157 |
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
158 |
import numpy as np
|
|
|
159 |
sims = CosineSimilarity([avec], bvecs)
|
160 |
tksim = self.token_similarity(atks, btkss)
|
161 |
-
return np.array(sims[0]) * vtweight +
|
162 |
-
np.array(tksim) * tkweight, tksim, sims[0]
|
163 |
|
164 |
def token_similarity(self, atks, btkss):
|
165 |
def toDict(tks):
|
|
|
15 |
#
|
16 |
|
17 |
import json
|
|
|
18 |
import re
|
19 |
import logging
|
20 |
+
from rag.utils.doc_store_conn import MatchTextExpr
|
|
|
21 |
|
22 |
from rag.nlp import rag_tokenizer, term_weight, synonym
|
23 |
|
24 |
+
|
25 |
+
class FulltextQueryer:
|
26 |
+
def __init__(self):
|
27 |
self.tw = term_weight.Dealer()
|
|
|
28 |
self.syn = synonym.Dealer()
|
29 |
+
self.query_fields = [
|
30 |
+
"title_tks^10",
|
31 |
+
"title_sm_tks^5",
|
32 |
+
"important_kwd^30",
|
33 |
+
"important_tks^20",
|
34 |
+
"content_ltks^2",
|
35 |
+
"content_sm_ltks",
|
36 |
+
]
|
37 |
|
38 |
@staticmethod
|
39 |
def subSpecialChar(line):
|
|
|
48 |
for t in arr:
|
49 |
if not re.match(r"[a-zA-Z]+$", t):
|
50 |
e += 1
|
51 |
+
return e * 1.0 / len(arr) >= 0.7
|
52 |
|
53 |
@staticmethod
|
54 |
def rmWWW(txt):
|
55 |
patts = [
|
56 |
+
(
|
57 |
+
r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*",
|
58 |
+
"",
|
59 |
+
),
|
60 |
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
61 |
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of) ", " ")
|
62 |
]
|
|
|
64 |
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
|
65 |
return txt
|
66 |
|
67 |
+
def question(self, txt, tbl="qa", min_match: float = 0.6):
|
68 |
txt = re.sub(
|
69 |
r"[ :\r\n\t,,。??/`!!&\^%%]+",
|
70 |
" ",
|
71 |
+
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
|
72 |
+
).strip()
|
73 |
+
txt = FulltextQueryer.rmWWW(txt)
|
74 |
|
75 |
if not self.isChinese(txt):
|
76 |
+
txt = FulltextQueryer.rmWWW(txt)
|
77 |
tks = rag_tokenizer.tokenize(txt).split(" ")
|
78 |
tks_w = self.tw.weights(tks)
|
79 |
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
|
|
|
81 |
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
|
82 |
q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
|
83 |
for i in range(1, len(tks_w)):
|
84 |
+
q.append(
|
85 |
+
'"%s %s"^%.4f'
|
86 |
+
% (
|
87 |
+
tks_w[i - 1][0],
|
88 |
+
tks_w[i][0],
|
89 |
+
max(tks_w[i - 1][1], tks_w[i][1]) * 2,
|
90 |
+
)
|
91 |
+
)
|
92 |
if not q:
|
93 |
q.append(txt)
|
94 |
+
query = " ".join(q)
|
95 |
+
return MatchTextExpr(
|
96 |
+
self.query_fields, query, 100
|
97 |
+
), tks
|
|
|
98 |
|
99 |
def need_fine_grained_tokenize(tk):
|
100 |
if len(tk) < 3:
|
|
|
103 |
return False
|
104 |
return True
|
105 |
|
106 |
+
txt = FulltextQueryer.rmWWW(txt)
|
107 |
qs, keywords = [], []
|
108 |
for tt in self.tw.split(txt)[:256]: # .split(" "):
|
109 |
if not tt:
|
|
|
115 |
logging.info(json.dumps(twts, ensure_ascii=False))
|
116 |
tms = []
|
117 |
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
|
118 |
+
sm = (
|
119 |
+
rag_tokenizer.fine_grained_tokenize(tk).split(" ")
|
120 |
+
if need_fine_grained_tokenize(tk)
|
121 |
+
else []
|
122 |
+
)
|
123 |
sm = [
|
124 |
re.sub(
|
125 |
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
|
126 |
"",
|
127 |
+
m,
|
128 |
+
)
|
129 |
+
for m in sm
|
130 |
+
]
|
131 |
+
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
|
132 |
sm = [m for m in sm if len(m) > 1]
|
133 |
|
134 |
keywords.append(re.sub(r"[ \\\"']+", "", tk))
|
135 |
keywords.extend(sm)
|
136 |
+
if len(keywords) >= 12:
|
137 |
+
break
|
138 |
|
139 |
tk_syns = self.syn.lookup(tk)
|
140 |
+
tk = FulltextQueryer.subSpecialChar(tk)
|
141 |
if tk.find(" ") > 0:
|
142 |
+
tk = '"%s"' % tk
|
143 |
if tk_syns:
|
144 |
tk = f"({tk} %s)" % " ".join(tk_syns)
|
145 |
if sm:
|
146 |
+
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
|
|
|
147 |
if tk.strip():
|
148 |
tms.append((tk, w))
|
149 |
|
150 |
tms = " ".join([f"({t})^{w}" for t, w in tms])
|
151 |
|
152 |
if len(twts) > 1:
|
153 |
+
tms += ' ("%s"~4)^1.5' % (" ".join([t for t, _ in twts]))
|
154 |
if re.match(r"[0-9a-z ]+$", tt):
|
155 |
+
tms = f'("{tt}" OR "%s")' % rag_tokenizer.tokenize(tt)
|
156 |
|
157 |
syns = " OR ".join(
|
158 |
+
[
|
159 |
+
'"%s"^0.7'
|
160 |
+
% FulltextQueryer.subSpecialChar(rag_tokenizer.tokenize(s))
|
161 |
+
for s in syns
|
162 |
+
]
|
163 |
+
)
|
164 |
if syns:
|
165 |
tms = f"({tms})^5 OR ({syns})^0.7"
|
166 |
|
167 |
qs.append(tms)
|
168 |
|
|
|
|
|
169 |
if qs:
|
170 |
+
query = " OR ".join([f"({t})" for t in qs if t])
|
171 |
+
return MatchTextExpr(
|
172 |
+
self.query_fields, query, 100, {"minimum_should_match": min_match}
|
173 |
+
), keywords
|
174 |
+
return None, keywords
|
|
|
|
|
|
|
175 |
|
176 |
+
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
|
|
|
177 |
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
178 |
import numpy as np
|
179 |
+
|
180 |
sims = CosineSimilarity([avec], bvecs)
|
181 |
tksim = self.token_similarity(atks, btkss)
|
182 |
+
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
|
|
|
183 |
|
184 |
def token_similarity(self, atks, btkss):
|
185 |
def toDict(tks):
|
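FulltextQueryer.question() now returns a MatchTextExpr instead of an elasticsearch_dsl Q object. The core of its non-Chinese branch is the adjacent-pair phrase boost; below is a hedged, stdlib-only reduction of that loop (the token weights are invented for the example):

def pair_boosted_query(tokens_with_weights):
    # Each token keeps its own boost; every adjacent pair is added as a quoted
    # phrase whose boost is twice the larger of the two token weights, matching
    # the loop in question() above.
    parts = ["{}^{:.4f}".format(tk, w) for tk, w in tokens_with_weights]
    for i in range(1, len(tokens_with_weights)):
        prev_tk, prev_w = tokens_with_weights[i - 1]
        tk, w = tokens_with_weights[i]
        parts.append('"%s %s"^%.4f' % (prev_tk, tk, max(prev_w, w) * 2))
    return " ".join(parts)

print(pair_boosted_query([("deep", 0.4), ("learning", 0.6)]))
# deep^0.4000 learning^0.6000 "deep learning"^1.2000

The resulting string is what gets wrapped in MatchTextExpr(self.query_fields, query, 100) and handed to the data store.
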
rag/nlp/search.py
CHANGED
@@ -14,34 +14,25 @@
|
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
|
17 |
-
import json
|
18 |
import re
|
19 |
-
|
20 |
-
|
21 |
-
from elasticsearch_dsl import Q, Search
|
22 |
from typing import List, Optional, Dict, Union
|
23 |
from dataclasses import dataclass
|
24 |
|
25 |
-
from rag.settings import
|
26 |
from rag.utils import rmSpace
|
27 |
-
from rag.nlp import rag_tokenizer, query
|
28 |
import numpy as np
|
|
|
29 |
|
30 |
|
31 |
def index_name(uid): return f"ragflow_{uid}"
|
32 |
|
33 |
|
34 |
class Dealer:
|
35 |
-
def __init__(self,
|
36 |
-
self.qryr = query.
|
37 |
-
self.
|
38 |
-
"title_tks^10",
|
39 |
-
"title_sm_tks^5",
|
40 |
-
"important_kwd^30",
|
41 |
-
"important_tks^20",
|
42 |
-
"content_ltks^2",
|
43 |
-
"content_sm_ltks"]
|
44 |
-
self.es = es
|
45 |
|
46 |
@dataclass
|
47 |
class SearchResult:
|
@@ -54,170 +45,99 @@ class Dealer:
|
|
54 |
keywords: Optional[List[str]] = None
|
55 |
group_docs: List[List] = None
|
56 |
|
57 |
-
def
|
58 |
-
qv,
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
}
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
else:
|
78 |
-
bqry.filter.append(
|
79 |
-
Q("bool", must_not=Q("range", available_int={"lt": 1})))
|
80 |
-
return bqry
|
81 |
-
|
82 |
-
def search(self, req, idxnms, emb_mdl=None, highlight=False):
|
83 |
-
qst = req.get("question", "")
|
84 |
-
bqry, keywords = self.qryr.question(qst, min_match="30%")
|
85 |
-
bqry = self._add_filters(bqry, req)
|
86 |
-
bqry.boost = 0.05
|
87 |
|
88 |
-
s = Search()
|
89 |
pg = int(req.get("page", 1)) - 1
|
90 |
topk = int(req.get("topk", 1024))
|
91 |
ps = int(req.get("size", topk))
|
|
|
|
|
92 |
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
|
93 |
-
"
|
94 |
-
"
|
|
|
95 |
|
96 |
-
|
97 |
-
|
98 |
-
s = s.highlight("title_ltks")
|
99 |
if not qst:
|
100 |
-
if
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
else:
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
)
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
bqry = Q("bool", must=[])
|
143 |
-
bqry = self._add_filters(bqry, req)
|
144 |
-
s["query"] = bqry.to_dict()
|
145 |
-
s["knn"]["filter"] = bqry.to_dict()
|
146 |
-
s["knn"]["similarity"] = 0.17
|
147 |
-
res = self.es.search(s, idxnms=idxnms, timeout="600s", src=src)
|
148 |
-
es_logger.info("【Q】: {}".format(json.dumps(s)))
|
149 |
-
|
150 |
-
kwds = set([])
|
151 |
-
for k in keywords:
|
152 |
-
kwds.add(k)
|
153 |
-
for kk in rag_tokenizer.fine_grained_tokenize(k).split(" "):
|
154 |
-
if len(kk) < 2:
|
155 |
-
continue
|
156 |
-
if kk in kwds:
|
157 |
-
continue
|
158 |
-
kwds.add(kk)
|
159 |
-
|
160 |
-
aggs = self.getAggregation(res, "docnm_kwd")
|
161 |
-
|
162 |
return self.SearchResult(
|
163 |
-
total=
|
164 |
-
ids=
|
165 |
query_vector=q_vec,
|
166 |
aggregation=aggs,
|
167 |
-
highlight=
|
168 |
-
field=self.getFields(res, src),
|
169 |
-
keywords=
|
170 |
)
|
171 |
|
172 |
-
def getAggregation(self, res, g):
|
173 |
-
if not "aggregations" in res or "aggs_" + g not in res["aggregations"]:
|
174 |
-
return
|
175 |
-
bkts = res["aggregations"]["aggs_" + g]["buckets"]
|
176 |
-
return [(b["key"], b["doc_count"]) for b in bkts]
|
177 |
-
|
178 |
-
def getHighlight(self, res, keywords, fieldnm):
|
179 |
-
ans = {}
|
180 |
-
for d in res["hits"]["hits"]:
|
181 |
-
hlts = d.get("highlight")
|
182 |
-
if not hlts:
|
183 |
-
continue
|
184 |
-
txt = "...".join([a for a in list(hlts.items())[0][1]])
|
185 |
-
if not is_english(txt.split(" ")):
|
186 |
-
ans[d["_id"]] = txt
|
187 |
-
continue
|
188 |
-
|
189 |
-
txt = d["_source"][fieldnm]
|
190 |
-
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
|
191 |
-
txts = []
|
192 |
-
for t in re.split(r"[.?!;\n]", txt):
|
193 |
-
for w in keywords:
|
194 |
-
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", t, flags=re.IGNORECASE|re.MULTILINE)
|
195 |
-
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE): continue
|
196 |
-
txts.append(t)
|
197 |
-
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
|
198 |
-
|
199 |
-
return ans
|
200 |
-
|
201 |
-
def getFields(self, sres, flds):
|
202 |
-
res = {}
|
203 |
-
if not flds:
|
204 |
-
return {}
|
205 |
-
for d in self.es.getSource(sres):
|
206 |
-
m = {n: d.get(n) for n in flds if d.get(n) is not None}
|
207 |
-
for n, v in m.items():
|
208 |
-
if isinstance(v, type([])):
|
209 |
-
m[n] = "\t".join([str(vv) if not isinstance(
|
210 |
-
vv, list) else "\t".join([str(vvv) for vvv in vv]) for vv in v])
|
211 |
-
continue
|
212 |
-
if not isinstance(v, type("")):
|
213 |
-
m[n] = str(m[n])
|
214 |
-
#if n.find("tks") > 0:
|
215 |
-
# m[n] = rmSpace(m[n])
|
216 |
-
|
217 |
-
if m:
|
218 |
-
res[d["id"]] = m
|
219 |
-
return res
|
220 |
-
|
221 |
@staticmethod
|
222 |
def trans2floats(txt):
|
223 |
return [float(t) for t in txt.split("\t")]
|
@@ -260,7 +180,7 @@ class Dealer:
|
|
260 |
continue
|
261 |
idx.append(i)
|
262 |
pieces_.append(t)
|
263 |
-
|
264 |
if not pieces_:
|
265 |
return answer, set([])
|
266 |
|
@@ -281,7 +201,7 @@ class Dealer:
|
|
281 |
chunks_tks,
|
282 |
tkweight, vtweight)
|
283 |
mx = np.max(sim) * 0.99
|
284 |
-
|
285 |
if mx < thr:
|
286 |
continue
|
287 |
cites[idx[i]] = list(
|
@@ -309,9 +229,15 @@ class Dealer:
|
|
309 |
def rerank(self, sres, query, tkweight=0.3,
|
310 |
vtweight=0.7, cfield="content_ltks"):
|
311 |
_, keywords = self.qryr.question(query)
|
312 |
-
|
313 |
-
|
314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
315 |
if not ins_embd:
|
316 |
return [], [], []
|
317 |
|
@@ -377,7 +303,7 @@ class Dealer:
|
|
377 |
if isinstance(tenant_ids, str):
|
378 |
tenant_ids = tenant_ids.split(",")
|
379 |
|
380 |
-
sres = self.search(req, [index_name(tid) for tid in tenant_ids], embd_mdl, highlight)
|
381 |
ranks["total"] = sres.total
|
382 |
|
383 |
if page <= RERANK_PAGE_LIMIT:
|
@@ -393,6 +319,8 @@ class Dealer:
|
|
393 |
idx = list(range(len(sres.ids)))
|
394 |
|
395 |
dim = len(sres.query_vector)
|
|
|
|
|
396 |
for i in idx:
|
397 |
if sim[i] < similarity_threshold:
|
398 |
break
|
@@ -401,34 +329,32 @@ class Dealer:
|
|
401 |
continue
|
402 |
break
|
403 |
id = sres.ids[i]
|
404 |
-
|
405 |
-
|
|
|
|
|
|
|
|
|
406 |
d = {
|
407 |
"chunk_id": id,
|
408 |
-
"content_ltks":
|
409 |
-
"content_with_weight":
|
410 |
-
"doc_id":
|
411 |
"docnm_kwd": dnm,
|
412 |
-
"kb_id":
|
413 |
-
"important_kwd":
|
414 |
-
"
|
415 |
"similarity": sim[i],
|
416 |
"vector_similarity": vsim[i],
|
417 |
"term_similarity": tsim[i],
|
418 |
-
"vector":
|
419 |
-
"positions":
|
420 |
}
|
421 |
if highlight:
|
422 |
if id in sres.highlight:
|
423 |
d["highlight"] = rmSpace(sres.highlight[id])
|
424 |
else:
|
425 |
d["highlight"] = d["content_with_weight"]
|
426 |
-
if len(d["positions"]) % 5 == 0:
|
427 |
-
poss = []
|
428 |
-
for i in range(0, len(d["positions"]), 5):
|
429 |
-
poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
|
430 |
-
float(d["positions"][i + 3]), float(d["positions"][i + 4])])
|
431 |
-
d["positions"] = poss
|
432 |
ranks["chunks"].append(d)
|
433 |
if dnm not in ranks["doc_aggs"]:
|
434 |
ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
|
@@ -442,39 +368,11 @@ class Dealer:
|
|
442 |
return ranks
|
443 |
|
444 |
def sql_retrieval(self, sql, fetch_size=128, format="json"):
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
|
454 |
-
replaces.append(
|
455 |
-
("{}{}'{}'".format(
|
456 |
-
r.group(1),
|
457 |
-
r.group(2),
|
458 |
-
r.group(3)),
|
459 |
-
match))
|
460 |
-
|
461 |
-
for p, r in replaces:
|
462 |
-
sql = sql.replace(p, r, 1)
|
463 |
-
chat_logger.info(f"To es: {sql}")
|
464 |
-
|
465 |
-
try:
|
466 |
-
tbl = self.es.sql(sql, fetch_size, format)
|
467 |
-
return tbl
|
468 |
-
except Exception as e:
|
469 |
-
chat_logger.error(f"SQL failure: {sql} =>" + str(e))
|
470 |
-
return {"error": str(e)}
|
471 |
-
|
472 |
-
def chunk_list(self, doc_id, tenant_id, max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]):
|
473 |
-
s = Search()
|
474 |
-
s = s.query(Q("match", doc_id=doc_id))[0:max_count]
|
475 |
-
s = s.to_dict()
|
476 |
-
es_res = self.es.search(s, idxnms=index_name(tenant_id), timeout="600s", src=fields)
|
477 |
-
res = []
|
478 |
-
for index, chunk in enumerate(es_res['hits']['hits']):
|
479 |
-
res.append({fld: chunk['_source'].get(fld) for fld in fields})
|
480 |
-
return res
|
|
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
|
|
|
17 |
import re
|
18 |
+
import json
|
|
|
|
|
19 |
from typing import List, Optional, Dict, Union
|
20 |
from dataclasses import dataclass
|
21 |
|
22 |
+
from rag.settings import doc_store_logger
|
23 |
from rag.utils import rmSpace
|
24 |
+
from rag.nlp import rag_tokenizer, query
|
25 |
import numpy as np
|
26 |
+
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
|
27 |
|
28 |
|
29 |
def index_name(uid): return f"ragflow_{uid}"
|
30 |
|
31 |
|
32 |
class Dealer:
|
33 |
+
def __init__(self, dataStore: DocStoreConnection):
|
34 |
+
self.qryr = query.FulltextQueryer()
|
35 |
+
self.dataStore = dataStore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
@dataclass
|
38 |
class SearchResult:
|
|
|
45 |
keywords: Optional[List[str]] = None
|
46 |
group_docs: List[List] = None
|
47 |
|
48 |
+
def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
|
49 |
+
qv, _ = emb_mdl.encode_queries(txt)
|
50 |
+
embedding_data = [float(v) for v in qv]
|
51 |
+
vector_column_name = f"q_{len(embedding_data)}_vec"
|
52 |
+
return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
|
53 |
+
|
54 |
+
def get_filters(self, req):
|
55 |
+
condition = dict()
|
56 |
+
for key, field in {"kb_ids": "kb_id", "doc_ids": "doc_id"}.items():
|
57 |
+
if key in req and req[key] is not None:
|
58 |
+
condition[field] = req[key]
|
59 |
+
# TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
|
60 |
+
for key in ["knowledge_graph_kwd"]:
|
61 |
+
if key in req and req[key] is not None:
|
62 |
+
condition[key] = req[key]
|
63 |
+
return condition
|
64 |
+
|
65 |
+
def search(self, req, idx_names: list[str], kb_ids: list[str], emb_mdl=None, highlight = False):
|
66 |
+
filters = self.get_filters(req)
|
67 |
+
orderBy = OrderByExpr()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
|
|
69 |
pg = int(req.get("page", 1)) - 1
|
70 |
topk = int(req.get("topk", 1024))
|
71 |
ps = int(req.get("size", topk))
|
72 |
+
offset, limit = pg * ps, (pg + 1) * ps
|
73 |
+
|
74 |
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
|
75 |
+
"doc_id", "position_list", "knowledge_graph_kwd",
|
76 |
+
"available_int", "content_with_weight"])
|
77 |
+
kwds = set([])
|
78 |
|
79 |
+
qst = req.get("question", "")
|
80 |
+
q_vec = []
|
|
|
81 |
if not qst:
|
82 |
+
if req.get("sort"):
|
83 |
+
orderBy.desc("create_timestamp_flt")
|
84 |
+
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
|
85 |
+
total=self.dataStore.getTotal(res)
|
86 |
+
doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
|
87 |
+
else:
|
88 |
+
highlightFields = ["content_ltks", "title_tks"] if highlight else []
|
89 |
+
matchText, keywords = self.qryr.question(qst, min_match=0.3)
|
90 |
+
if emb_mdl is None:
|
91 |
+
matchExprs = [matchText]
|
92 |
+
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
|
93 |
+
total=self.dataStore.getTotal(res)
|
94 |
+
doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
|
95 |
else:
|
96 |
+
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
|
97 |
+
q_vec = matchDense.embedding_data
|
98 |
+
src.append(f"q_{len(q_vec)}_vec")
|
99 |
+
|
100 |
+
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
|
101 |
+
matchExprs = [matchText, matchDense, fusionExpr]
|
102 |
+
|
103 |
+
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
|
104 |
+
total=self.dataStore.getTotal(res)
|
105 |
+
doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
|
106 |
+
|
107 |
+
# If result is empty, try again with lower min_match
|
108 |
+
if total == 0:
|
109 |
+
matchText, _ = self.qryr.question(qst, min_match=0.1)
|
110 |
+
if "doc_ids" in filters:
|
111 |
+
del filters["doc_ids"]
|
112 |
+
matchDense.extra_options["similarity"] = 0.17
|
113 |
+
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idx_names, kb_ids)
|
114 |
+
total=self.dataStore.getTotal(res)
|
115 |
+
doc_store_logger.info("Dealer.search 2 TOTAL: {}".format(total))
|
116 |
+
|
117 |
+
for k in keywords:
|
118 |
+
kwds.add(k)
|
119 |
+
for kk in rag_tokenizer.fine_grained_tokenize(k).split(" "):
|
120 |
+
if len(kk) < 2:
|
121 |
+
continue
|
122 |
+
if kk in kwds:
|
123 |
+
continue
|
124 |
+
kwds.add(kk)
|
125 |
+
|
126 |
+
doc_store_logger.info(f"TOTAL: {total}")
|
127 |
+
ids=self.dataStore.getChunkIds(res)
|
128 |
+
keywords=list(kwds)
|
129 |
+
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
|
130 |
+
aggs = self.dataStore.getAggregation(res, "docnm_kwd")
|
|
|
|
|
|
|
|
|
|
131 |
return self.SearchResult(
|
132 |
+
total=total,
|
133 |
+
ids=ids,
|
134 |
query_vector=q_vec,
|
135 |
aggregation=aggs,
|
136 |
+
highlight=highlight,
|
137 |
+
field=self.dataStore.getFields(res, src),
|
138 |
+
keywords=keywords
|
139 |
)
|
140 |
|
|
|
|
|
|
|
|
|
|
|
141 |
@staticmethod
|
142 |
def trans2floats(txt):
|
143 |
return [float(t) for t in txt.split("\t")]
|
|
|
180 |
continue
|
181 |
idx.append(i)
|
182 |
pieces_.append(t)
|
183 |
+
doc_store_logger.info("{} => {}".format(answer, pieces_))
|
184 |
if not pieces_:
|
185 |
return answer, set([])
|
186 |
|
|
|
201 |
chunks_tks,
|
202 |
tkweight, vtweight)
|
203 |
mx = np.max(sim) * 0.99
|
204 |
+
doc_store_logger.info("{} SIM: {}".format(pieces_[i], mx))
|
205 |
if mx < thr:
|
206 |
continue
|
207 |
cites[idx[i]] = list(
|
|
|
229 |
def rerank(self, sres, query, tkweight=0.3,
|
230 |
vtweight=0.7, cfield="content_ltks"):
|
231 |
_, keywords = self.qryr.question(query)
|
232 |
+
vector_size = len(sres.query_vector)
|
233 |
+
vector_column = f"q_{vector_size}_vec"
|
234 |
+
zero_vector = [0.0] * vector_size
|
235 |
+
ins_embd = []
|
236 |
+
for chunk_id in sres.ids:
|
237 |
+
vector = sres.field[chunk_id].get(vector_column, zero_vector)
|
238 |
+
if isinstance(vector, str):
|
239 |
+
vector = [float(v) for v in vector.split("\t")]
|
240 |
+
ins_embd.append(vector)
|
241 |
if not ins_embd:
|
242 |
return [], [], []
|
243 |
|
|
|
303 |
if isinstance(tenant_ids, str):
|
304 |
tenant_ids = tenant_ids.split(",")
|
305 |
|
306 |
+
sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight)
|
307 |
ranks["total"] = sres.total
|
308 |
|
309 |
if page <= RERANK_PAGE_LIMIT:
|
|
|
319 |
idx = list(range(len(sres.ids)))
|
320 |
|
321 |
dim = len(sres.query_vector)
|
322 |
+
vector_column = f"q_{dim}_vec"
|
323 |
+
zero_vector = [0.0] * dim
|
324 |
for i in idx:
|
325 |
if sim[i] < similarity_threshold:
|
326 |
break
|
|
|
329 |
continue
|
330 |
break
|
331 |
id = sres.ids[i]
|
332 |
+
chunk = sres.field[id]
|
333 |
+
dnm = chunk["docnm_kwd"]
|
334 |
+
did = chunk["doc_id"]
|
335 |
+
position_list = chunk.get("position_list", "[]")
|
336 |
+
if not position_list:
|
337 |
+
position_list = "[]"
|
338 |
d = {
|
339 |
"chunk_id": id,
|
340 |
+
"content_ltks": chunk["content_ltks"],
|
341 |
+
"content_with_weight": chunk["content_with_weight"],
|
342 |
+
"doc_id": chunk["doc_id"],
|
343 |
"docnm_kwd": dnm,
|
344 |
+
"kb_id": chunk["kb_id"],
|
345 |
+
"important_kwd": chunk.get("important_kwd", []),
|
346 |
+
"image_id": chunk.get("img_id", ""),
|
347 |
"similarity": sim[i],
|
348 |
"vector_similarity": vsim[i],
|
349 |
"term_similarity": tsim[i],
|
350 |
+
"vector": chunk.get(vector_column, zero_vector),
|
351 |
+
"positions": json.loads(position_list)
|
352 |
}
|
353 |
if highlight:
|
354 |
if id in sres.highlight:
|
355 |
d["highlight"] = rmSpace(sres.highlight[id])
|
356 |
else:
|
357 |
d["highlight"] = d["content_with_weight"]
|
|
|
|
|
|
|
|
|
|
|
|
|
358 |
ranks["chunks"].append(d)
|
359 |
if dnm not in ranks["doc_aggs"]:
|
360 |
ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
|
|
|
368 |
return ranks
|
369 |
|
370 |
def sql_retrieval(self, sql, fetch_size=128, format="json"):
|
371 |
+
tbl = self.dataStore.sql(sql, fetch_size, format)
|
372 |
+
return tbl
|
373 |
+
|
374 |
+
def chunk_list(self, doc_id: str, tenant_id: str, kb_ids: list[str], max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]):
|
375 |
+
condition = {"doc_id": doc_id}
|
376 |
+
res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), 0, max_count, index_name(tenant_id), kb_ids)
|
377 |
+
dict_chunks = self.dataStore.getFields(res, fields)
|
378 |
+
return dict_chunks.values()
|
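Dealer.search now delegates everything to dataStore.search with a list of match expressions. A hedged sketch of how the hybrid query is assembled, using the classes from rag/utils/doc_store_conn.py introduced below (assumes the ragflow package and its numpy/polars dependencies are importable; the zero query vector stands in for emb_mdl.encode_queries, and the question text is made up):

from rag.utils.doc_store_conn import (
    MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr,
)

topk = 1024
q_vec = [0.0] * 1024                    # stand-in for the embedded question
vector_column = f"q_{len(q_vec)}_vec"   # column name is derived from the vector size

matchText = MatchTextExpr(["title_tks^10", "content_ltks^2"],
                          "integration with infinity", topk,
                          {"minimum_should_match": 0.3})
matchDense = MatchDenseExpr(vector_column, q_vec, "float", "cosine",
                            topk, {"similarity": 0.1})
fusion = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})

matchExprs = [matchText, matchDense, fusion]
orderBy = OrderByExpr().desc("create_timestamp_flt")
# self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy,
#                       offset, limit, idx_names, kb_ids) runs all three in one call.
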
|
|
|
|
|
|
|
rag/settings.py
CHANGED
@@ -25,12 +25,13 @@ RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
|
|
25 |
SUBPROCESS_STD_LOG_NAME = "std.log"
|
26 |
|
27 |
ES = get_base_config("es", {})
|
|
|
28 |
AZURE = get_base_config("azure", {})
|
29 |
S3 = get_base_config("s3", {})
|
30 |
MINIO = decrypt_database_config(name="minio")
|
31 |
try:
|
32 |
REDIS = decrypt_database_config(name="redis")
|
33 |
-
except Exception
|
34 |
REDIS = {}
|
35 |
pass
|
36 |
DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
|
@@ -44,7 +45,7 @@ LoggerFactory.set_directory(
|
|
44 |
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
|
45 |
LoggerFactory.LEVEL = 30
|
46 |
|
47 |
-
|
48 |
minio_logger = getLogger("minio")
|
49 |
s3_logger = getLogger("s3")
|
50 |
azure_logger = getLogger("azure")
|
@@ -53,7 +54,7 @@ chunk_logger = getLogger("chunk_logger")
|
|
53 |
database_logger = getLogger("database")
|
54 |
|
55 |
formatter = logging.Formatter("%(asctime)-15s %(levelname)-8s (%(process)d) %(message)s")
|
56 |
-
for logger in [
|
57 |
logger.setLevel(logging.INFO)
|
58 |
for handler in logger.handlers:
|
59 |
handler.setFormatter(fmt=formatter)
|
|
|
25 |
SUBPROCESS_STD_LOG_NAME = "std.log"
|
26 |
|
27 |
ES = get_base_config("es", {})
|
28 |
+
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
|
29 |
AZURE = get_base_config("azure", {})
|
30 |
S3 = get_base_config("s3", {})
|
31 |
MINIO = decrypt_database_config(name="minio")
|
32 |
try:
|
33 |
REDIS = decrypt_database_config(name="redis")
|
34 |
+
except Exception:
|
35 |
REDIS = {}
|
36 |
pass
|
37 |
DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
|
|
|
45 |
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
|
46 |
LoggerFactory.LEVEL = 30
|
47 |
|
48 |
+
doc_store_logger = getLogger("doc_store")
|
49 |
minio_logger = getLogger("minio")
|
50 |
s3_logger = getLogger("s3")
|
51 |
azure_logger = getLogger("azure")
|
|
|
54 |
database_logger = getLogger("database")
|
55 |
|
56 |
formatter = logging.Formatter("%(asctime)-15s %(levelname)-8s (%(process)d) %(message)s")
|
57 |
+
for logger in [doc_store_logger, minio_logger, s3_logger, azure_logger, cron_logger, chunk_logger, database_logger]:
|
58 |
logger.setLevel(logging.INFO)
|
59 |
for handler in logger.handlers:
|
60 |
handler.setFormatter(fmt=formatter)
|
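A stdlib-only sketch of the logger wiring above: the new doc_store_logger joins the same formatter loop as the minio/s3/azure loggers (the stream handler here is only for the demo; in the project the handlers come from LoggerFactory):

import logging

formatter = logging.Formatter("%(asctime)-15s %(levelname)-8s (%(process)d) %(message)s")
doc_store_logger = logging.getLogger("doc_store")
doc_store_logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(fmt=formatter)
doc_store_logger.addHandler(handler)
doc_store_logger.info("Dealer.search TOTAL: %d", 42)
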
rag/svr/task_executor.py
CHANGED
@@ -31,7 +31,6 @@ from timeit import default_timer as timer
|
|
31 |
|
32 |
import numpy as np
|
33 |
import pandas as pd
|
34 |
-
from elasticsearch_dsl import Q
|
35 |
|
36 |
from api.db import LLMType, ParserType
|
37 |
from api.db.services.dialog_service import keyword_extraction, question_proposal
|
@@ -39,8 +38,7 @@ from api.db.services.document_service import DocumentService
|
|
39 |
from api.db.services.llm_service import LLMBundle
|
40 |
from api.db.services.task_service import TaskService
|
41 |
from api.db.services.file2document_service import File2DocumentService
|
42 |
-
from api.settings import retrievaler
|
43 |
-
from api.utils.file_utils import get_project_base_directory
|
44 |
from api.db.db_models import close_connection
|
45 |
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
|
46 |
from rag.nlp import search, rag_tokenizer
|
@@ -48,7 +46,6 @@ from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as
|
|
48 |
from rag.settings import database_logger, SVR_QUEUE_NAME
|
49 |
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
50 |
from rag.utils import rmSpace, num_tokens_from_string
|
51 |
-
from rag.utils.es_conn import ELASTICSEARCH
|
52 |
from rag.utils.redis_conn import REDIS_CONN, Payload
|
53 |
from rag.utils.storage_factory import STORAGE_IMPL
|
54 |
|
@@ -126,7 +123,7 @@ def collect():
|
|
126 |
return pd.DataFrame()
|
127 |
tasks = TaskService.get_tasks(msg["id"])
|
128 |
if not tasks:
|
129 |
-
cron_logger.
|
130 |
return []
|
131 |
|
132 |
tasks = pd.DataFrame(tasks)
|
@@ -187,7 +184,7 @@ def build(row):
|
|
187 |
docs = []
|
188 |
doc = {
|
189 |
"doc_id": row["doc_id"],
|
190 |
-
"kb_id":
|
191 |
}
|
192 |
el = 0
|
193 |
for ck in cks:
|
@@ -196,10 +193,14 @@ def build(row):
|
|
196 |
md5 = hashlib.md5()
|
197 |
md5.update((ck["content_with_weight"] +
|
198 |
str(d["doc_id"])).encode("utf-8"))
|
199 |
-
d["
|
200 |
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
201 |
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
202 |
if not d.get("image"):
|
|
|
|
|
|
|
|
|
203 |
docs.append(d)
|
204 |
continue
|
205 |
|
@@ -211,13 +212,13 @@ def build(row):
|
|
211 |
d["image"].save(output_buffer, format='JPEG')
|
212 |
|
213 |
st = timer()
|
214 |
-
STORAGE_IMPL.put(row["kb_id"], d["
|
215 |
el += timer() - st
|
216 |
except Exception as e:
|
217 |
cron_logger.error(str(e))
|
218 |
traceback.print_exc()
|
219 |
|
220 |
-
d["img_id"] = "{}-{}".format(row["kb_id"], d["
|
221 |
del d["image"]
|
222 |
docs.append(d)
|
223 |
cron_logger.info("MINIO PUT({}):{}".format(row["name"], el))
|
@@ -245,12 +246,9 @@ def build(row):
|
|
245 |
return docs
|
246 |
|
247 |
|
248 |
-
def init_kb(row):
|
249 |
idxnm = search.index_name(row["tenant_id"])
|
250 |
-
|
251 |
-
return
|
252 |
-
return ELASTICSEARCH.createIdx(idxnm, json.load(
|
253 |
-
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
|
254 |
|
255 |
|
256 |
def embedding(docs, mdl, parser_config=None, callback=None):
|
@@ -288,17 +286,20 @@ def embedding(docs, mdl, parser_config=None, callback=None):
|
|
288 |
cnts) if len(tts) == len(cnts) else cnts
|
289 |
|
290 |
assert len(vects) == len(docs)
|
|
|
291 |
for i, d in enumerate(docs):
|
292 |
v = vects[i].tolist()
|
|
|
293 |
d["q_%d_vec" % len(v)] = v
|
294 |
-
return tk_count
|
295 |
|
296 |
|
297 |
def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
298 |
vts, _ = embd_mdl.encode(["ok"])
|
299 |
-
|
|
|
300 |
chunks = []
|
301 |
-
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
|
302 |
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
|
303 |
|
304 |
raptor = Raptor(
|
@@ -323,7 +324,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
|
323 |
d = copy.deepcopy(doc)
|
324 |
md5 = hashlib.md5()
|
325 |
md5.update((content + str(d["doc_id"])).encode("utf-8"))
|
326 |
-
d["
|
327 |
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
328 |
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
329 |
d[vctr_nm] = vctr.tolist()
|
@@ -332,7 +333,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
|
332 |
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
333 |
res.append(d)
|
334 |
tk_count += num_tokens_from_string(content)
|
335 |
-
return res, tk_count
|
336 |
|
337 |
|
338 |
def main():
|
@@ -352,7 +353,7 @@ def main():
|
|
352 |
if r.get("task_type", "") == "raptor":
|
353 |
try:
|
354 |
chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
|
355 |
-
cks, tk_count = run_raptor(r, chat_mdl, embd_mdl, callback)
|
356 |
except Exception as e:
|
357 |
callback(-1, msg=str(e))
|
358 |
cron_logger.error(str(e))
|
@@ -373,7 +374,7 @@ def main():
|
|
373 |
len(cks))
|
374 |
st = timer()
|
375 |
try:
|
376 |
-
tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
|
377 |
except Exception as e:
|
378 |
callback(-1, "Embedding error:{}".format(str(e)))
|
379 |
cron_logger.error(str(e))
|
@@ -381,26 +382,25 @@ def main():
|
|
381 |
cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
382 |
callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
|
383 |
|
384 |
-
init_kb(r)
|
385 |
-
|
|
|
386 |
st = timer()
|
387 |
es_r = ""
|
388 |
es_bulk_size = 4
|
389 |
for b in range(0, len(cks), es_bulk_size):
|
390 |
-
es_r =
|
391 |
if b % 128 == 0:
|
392 |
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
|
393 |
|
394 |
cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
395 |
if es_r:
|
396 |
callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
|
397 |
-
|
398 |
-
|
399 |
-
cron_logger.error(str(es_r))
|
400 |
else:
|
401 |
if TaskService.do_cancel(r["id"]):
|
402 |
-
|
403 |
-
Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
|
404 |
continue
|
405 |
callback(1., "Done!")
|
406 |
DocumentService.increment_chunk_num(
|
|
|
31 |
|
32 |
import numpy as np
|
33 |
import pandas as pd
|
|
|
34 |
|
35 |
from api.db import LLMType, ParserType
|
36 |
from api.db.services.dialog_service import keyword_extraction, question_proposal
|
|
|
38 |
from api.db.services.llm_service import LLMBundle
|
39 |
from api.db.services.task_service import TaskService
|
40 |
from api.db.services.file2document_service import File2DocumentService
|
41 |
+
from api.settings import retrievaler, docStoreConn
|
|
|
42 |
from api.db.db_models import close_connection
|
43 |
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
|
44 |
from rag.nlp import search, rag_tokenizer
|
|
|
46 |
from rag.settings import database_logger, SVR_QUEUE_NAME
|
47 |
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
48 |
from rag.utils import rmSpace, num_tokens_from_string
|
|
|
49 |
from rag.utils.redis_conn import REDIS_CONN, Payload
|
50 |
from rag.utils.storage_factory import STORAGE_IMPL
|
51 |
|
|
|
123 |
return pd.DataFrame()
|
124 |
tasks = TaskService.get_tasks(msg["id"])
|
125 |
if not tasks:
|
126 |
+
cron_logger.warning("{} empty task!".format(msg["id"]))
|
127 |
return []
|
128 |
|
129 |
tasks = pd.DataFrame(tasks)
|
|
|
184 |
docs = []
|
185 |
doc = {
|
186 |
"doc_id": row["doc_id"],
|
187 |
+
"kb_id": str(row["kb_id"])
|
188 |
}
|
189 |
el = 0
|
190 |
for ck in cks:
|
|
|
193 |
md5 = hashlib.md5()
|
194 |
md5.update((ck["content_with_weight"] +
|
195 |
str(d["doc_id"])).encode("utf-8"))
|
196 |
+
d["id"] = md5.hexdigest()
|
197 |
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
198 |
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
199 |
if not d.get("image"):
|
200 |
+
d["img_id"] = ""
|
201 |
+
d["page_num_list"] = json.dumps([])
|
202 |
+
d["position_list"] = json.dumps([])
|
203 |
+
d["top_list"] = json.dumps([])
|
204 |
docs.append(d)
|
205 |
continue
|
206 |
|
|
|
212 |
d["image"].save(output_buffer, format='JPEG')
|
213 |
|
214 |
st = timer()
|
215 |
+
STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
|
216 |
el += timer() - st
|
217 |
except Exception as e:
|
218 |
cron_logger.error(str(e))
|
219 |
traceback.print_exc()
|
220 |
|
221 |
+
d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
|
222 |
del d["image"]
|
223 |
docs.append(d)
|
224 |
cron_logger.info("MINIO PUT({}):{}".format(row["name"], el))
|
|
|
246 |
return docs
|
247 |
|
248 |
|
249 |
+
def init_kb(row, vector_size: int):
|
250 |
idxnm = search.index_name(row["tenant_id"])
|
251 |
+
return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
|
|
|
|
|
|
|
252 |
|
253 |
|
254 |
def embedding(docs, mdl, parser_config=None, callback=None):
|
|
|
286 |
cnts) if len(tts) == len(cnts) else cnts
|
287 |
|
288 |
assert len(vects) == len(docs)
|
289 |
+
vector_size = 0
|
290 |
for i, d in enumerate(docs):
|
291 |
v = vects[i].tolist()
|
292 |
+
vector_size = len(v)
|
293 |
d["q_%d_vec" % len(v)] = v
|
294 |
+
return tk_count, vector_size
|
295 |
|
296 |
|
297 |
def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
298 |
vts, _ = embd_mdl.encode(["ok"])
|
299 |
+
vector_size = len(vts[0])
|
300 |
+
vctr_nm = "q_%d_vec" % vector_size
|
301 |
chunks = []
|
302 |
+
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]):
|
303 |
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
|
304 |
|
305 |
raptor = Raptor(
|
|
|
324 |
d = copy.deepcopy(doc)
|
325 |
md5 = hashlib.md5()
|
326 |
md5.update((content + str(d["doc_id"])).encode("utf-8"))
|
327 |
+
d["id"] = md5.hexdigest()
|
328 |
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
329 |
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
330 |
d[vctr_nm] = vctr.tolist()
|
|
|
333 |
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
334 |
res.append(d)
|
335 |
tk_count += num_tokens_from_string(content)
|
336 |
+
return res, tk_count, vector_size
|
337 |
|
338 |
|
339 |
def main():
|
|
|
353 |
if r.get("task_type", "") == "raptor":
|
354 |
try:
|
355 |
chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
|
356 |
+
cks, tk_count, vector_size = run_raptor(r, chat_mdl, embd_mdl, callback)
|
357 |
except Exception as e:
|
358 |
callback(-1, msg=str(e))
|
359 |
cron_logger.error(str(e))
|
|
|
374 |
len(cks))
|
375 |
st = timer()
|
376 |
try:
|
377 |
+
tk_count, vector_size = embedding(cks, embd_mdl, r["parser_config"], callback)
|
378 |
except Exception as e:
|
379 |
callback(-1, "Embedding error:{}".format(str(e)))
|
380 |
cron_logger.error(str(e))
|
|
|
382 |
cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
383 |
callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
|
384 |
|
385 |
+
# cron_logger.info(f"task_executor init_kb index {search.index_name(r["tenant_id"])} embd_mdl {embd_mdl.llm_name} vector length {vector_size}")
|
386 |
+
init_kb(r, vector_size)
|
387 |
+
chunk_count = len(set([c["id"] for c in cks]))
|
388 |
st = timer()
|
389 |
es_r = ""
|
390 |
es_bulk_size = 4
|
391 |
for b in range(0, len(cks), es_bulk_size):
|
392 |
+
es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
|
393 |
if b % 128 == 0:
|
394 |
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
|
395 |
|
396 |
cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
397 |
if es_r:
|
398 |
callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
|
399 |
+
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
400 |
+
cron_logger.error('Insert chunk error: ' + str(es_r))
|
|
|
401 |
else:
|
402 |
if TaskService.do_cancel(r["id"]):
|
403 |
+
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
|
|
404 |
continue
|
405 |
callback(1., "Done!")
|
406 |
DocumentService.increment_chunk_num(
|
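The indexing loop in main() now goes through docStoreConn.insert in fixed-size batches and deletes the document's chunks if any batch reports an error. A minimal sketch with a fake connection standing in for the real one (the md5-based id scheme and the batch size mirror the diff; the index and kb names are invented):

import hashlib

class FakeDocStore:
    # Stand-in for the real docStoreConn; insert() returns a list of error
    # strings, and the caller treats a non-empty result as failure.
    def insert(self, rows, index_name, kb_id):
        print("insert {} rows into {}/{}".format(len(rows), index_name, kb_id))
        return []

def chunk_id(content, doc_id):
    md5 = hashlib.md5()
    md5.update((content + str(doc_id)).encode("utf-8"))
    return md5.hexdigest()

docStoreConn = FakeDocStore()
cks = [{"id": chunk_id("chunk-%d" % i, "doc-1"), "doc_id": "doc-1"} for i in range(10)]
es_bulk_size = 4
for b in range(0, len(cks), es_bulk_size):
    es_r = docStoreConn.insert(cks[b:b + es_bulk_size], "ragflow_tenant", "kb-1")
    if es_r:
        # the real code calls docStoreConn.delete({"doc_id": ...}, ...) and logs here
        break
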
rag/utils/doc_store_conn.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import Optional, Union
|
3 |
+
from dataclasses import dataclass
|
4 |
+
import numpy as np
|
5 |
+
import polars as pl
|
6 |
+
from typing import List, Dict
|
7 |
+
|
8 |
+
DEFAULT_MATCH_VECTOR_TOPN = 10
|
9 |
+
DEFAULT_MATCH_SPARSE_TOPN = 10
|
10 |
+
VEC = Union[list, np.ndarray]
|
11 |
+
|
12 |
+
|
13 |
+
@dataclass
|
14 |
+
class SparseVector:
|
15 |
+
indices: list[int]
|
16 |
+
values: Union[list[float], list[int], None] = None
|
17 |
+
|
18 |
+
def __post_init__(self):
|
19 |
+
assert (self.values is None) or (len(self.indices) == len(self.values))
|
20 |
+
|
21 |
+
def to_dict_old(self):
|
22 |
+
d = {"indices": self.indices}
|
23 |
+
if self.values is not None:
|
24 |
+
d["values"] = self.values
|
25 |
+
return d
|
26 |
+
|
27 |
+
def to_dict(self):
|
28 |
+
if self.values is None:
|
29 |
+
raise ValueError("SparseVector.values is None")
|
30 |
+
result = {}
|
31 |
+
for i, v in zip(self.indices, self.values):
|
32 |
+
result[str(i)] = v
|
33 |
+
return result
|
34 |
+
|
35 |
+
@staticmethod
|
36 |
+
def from_dict(d):
|
37 |
+
return SparseVector(d["indices"], d.get("values"))
|
38 |
+
|
39 |
+
def __str__(self):
|
40 |
+
return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"
|
41 |
+
|
42 |
+
def __repr__(self):
|
43 |
+
return str(self)
|
44 |
+
|
45 |
+
|
46 |
+
class MatchTextExpr(ABC):
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
fields: str,
|
50 |
+
matching_text: str,
|
51 |
+
topn: int,
|
52 |
+
extra_options: dict = dict(),
|
53 |
+
):
|
54 |
+
self.fields = fields
|
55 |
+
self.matching_text = matching_text
|
56 |
+
self.topn = topn
|
57 |
+
self.extra_options = extra_options
|
58 |
+
|
59 |
+
|
60 |
+
class MatchDenseExpr(ABC):
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
vector_column_name: str,
|
64 |
+
embedding_data: VEC,
|
65 |
+
embedding_data_type: str,
|
66 |
+
distance_type: str,
|
67 |
+
topn: int = DEFAULT_MATCH_VECTOR_TOPN,
|
68 |
+
extra_options: dict = dict(),
|
69 |
+
):
|
70 |
+
self.vector_column_name = vector_column_name
|
71 |
+
self.embedding_data = embedding_data
|
72 |
+
self.embedding_data_type = embedding_data_type
|
73 |
+
self.distance_type = distance_type
|
74 |
+
self.topn = topn
|
75 |
+
self.extra_options = extra_options
|
76 |
+
|
77 |
+
|
78 |
+
class MatchSparseExpr(ABC):
|
79 |
+
def __init__(
|
80 |
+
self,
|
81 |
+
vector_column_name: str,
|
82 |
+
sparse_data: SparseVector | dict,
|
83 |
+
distance_type: str,
|
84 |
+
topn: int,
|
85 |
+
opt_params: Optional[dict] = None,
|
86 |
+
):
|
87 |
+
self.vector_column_name = vector_column_name
|
88 |
+
self.sparse_data = sparse_data
|
89 |
+
self.distance_type = distance_type
|
90 |
+
self.topn = topn
|
91 |
+
self.opt_params = opt_params
|
92 |
+
|
93 |
+
|
94 |
+
class MatchTensorExpr(ABC):
|
95 |
+
def __init__(
|
96 |
+
self,
|
97 |
+
column_name: str,
|
98 |
+
query_data: VEC,
|
99 |
+
query_data_type: str,
|
100 |
+
topn: int,
|
101 |
+
extra_option: Optional[dict] = None,
|
102 |
+
):
|
103 |
+
self.column_name = column_name
|
104 |
+
self.query_data = query_data
|
105 |
+
self.query_data_type = query_data_type
|
106 |
+
self.topn = topn
|
107 |
+
self.extra_option = extra_option
|
108 |
+
|
109 |
+
|
110 |
+
class FusionExpr(ABC):
|
111 |
+
def __init__(self, method: str, topn: int, fusion_params: Optional[dict] = None):
|
112 |
+
self.method = method
|
113 |
+
self.topn = topn
|
114 |
+
self.fusion_params = fusion_params
|
115 |
+
|
116 |
+
|
117 |
+
MatchExpr = Union[
|
118 |
+
MatchTextExpr, MatchDenseExpr, MatchSparseExpr, MatchTensorExpr, FusionExpr
|
119 |
+
]
|
120 |
+
|
121 |
+
|
122 |
+
class OrderByExpr(ABC):
|
123 |
+
def __init__(self):
|
124 |
+
self.fields = list()
|
125 |
+
def asc(self, field: str):
|
126 |
+
self.fields.append((field, 0))
|
127 |
+
return self
|
128 |
+
def desc(self, field: str):
|
129 |
+
self.fields.append((field, 1))
|
130 |
+
return self
|
131 |
+
def fields(self):
|
132 |
+
return self.fields
|
133 |
+
|
134 |
+
class DocStoreConnection(ABC):
|
135 |
+
"""
|
136 |
+
Database operations
|
137 |
+
"""
|
138 |
+
|
139 |
+
@abstractmethod
|
140 |
+
def dbType(self) -> str:
|
141 |
+
"""
|
142 |
+
Return the type of the database.
|
143 |
+
"""
|
144 |
+
raise NotImplementedError("Not implemented")
|
145 |
+
|
146 |
+
@abstractmethod
|
147 |
+
def health(self) -> dict:
|
148 |
+
"""
|
149 |
+
Return the health status of the database.
|
150 |
+
"""
|
151 |
+
raise NotImplementedError("Not implemented")
|
152 |
+
|
153 |
+
"""
|
154 |
+
Table operations
|
155 |
+
"""
|
156 |
+
|
157 |
+
@abstractmethod
|
158 |
+
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
159 |
+
"""
|
160 |
+
Create an index with given name
|
161 |
+
"""
|
162 |
+
raise NotImplementedError("Not implemented")
|
163 |
+
|
164 |
+
@abstractmethod
|
165 |
+
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
166 |
+
"""
|
167 |
+
Delete an index with given name
|
168 |
+
"""
|
169 |
+
raise NotImplementedError("Not implemented")
|
170 |
+
|
171 |
+
@abstractmethod
|
172 |
+
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
|
173 |
+
"""
|
174 |
+
Check if an index with given name exists
|
175 |
+
"""
|
176 |
+
raise NotImplementedError("Not implemented")
|
177 |
+
|
178 |
+
"""
|
179 |
+
CRUD operations
|
180 |
+
"""
|
181 |
+
|
182 |
+
@abstractmethod
|
183 |
+
def search(
|
184 |
+
self, selectFields: list[str], highlight: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]
|
185 |
+
) -> list[dict] | pl.DataFrame:
|
186 |
+
"""
|
187 |
+
Search with given conjunctive equivalent filtering condition and return all fields of matched documents
|
188 |
+
"""
|
189 |
+
raise NotImplementedError("Not implemented")
|
190 |
+
|
191 |
+
@abstractmethod
|
192 |
+
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
|
193 |
+
"""
|
194 |
+
Get single chunk with given id
|
195 |
+
"""
|
196 |
+
raise NotImplementedError("Not implemented")
|
197 |
+
|
198 |
+
@abstractmethod
|
199 |
+
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
|
200 |
+
"""
|
201 |
+
Update or insert a bulk of rows
|
202 |
+
"""
|
203 |
+
raise NotImplementedError("Not implemented")
|
204 |
+
|
205 |
+
@abstractmethod
|
206 |
+
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
|
207 |
+
"""
|
208 |
+
Update rows with given conjunctive equivalent filtering condition
|
209 |
+
"""
|
210 |
+
raise NotImplementedError("Not implemented")
|
211 |
+
|
212 |
+
@abstractmethod
|
213 |
+
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
214 |
+
"""
|
215 |
+
Delete rows with given conjunctive equivalent filtering condition
|
216 |
+
"""
|
217 |
+
raise NotImplementedError("Not implemented")
|
218 |
+
|
219 |
+
"""
|
220 |
+
Helper functions for search result
|
221 |
+
"""
|
222 |
+
|
223 |
+
@abstractmethod
|
224 |
+
def getTotal(self, res):
|
225 |
+
raise NotImplementedError("Not implemented")
|
226 |
+
|
227 |
+
@abstractmethod
|
228 |
+
def getChunkIds(self, res):
|
229 |
+
raise NotImplementedError("Not implemented")
|
230 |
+
|
231 |
+
@abstractmethod
|
232 |
+
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
|
233 |
+
raise NotImplementedError("Not implemented")
|
234 |
+
|
235 |
+
@abstractmethod
|
236 |
+
def getHighlight(self, res, keywords: List[str], fieldnm: str):
|
237 |
+
raise NotImplementedError("Not implemented")
|
238 |
+
|
239 |
+
@abstractmethod
|
240 |
+
def getAggregation(self, res, fieldnm: str):
|
241 |
+
raise NotImplementedError("Not implemented")
|
242 |
+
|
243 |
+
"""
|
244 |
+
SQL
|
245 |
+
"""
|
246 |
+
@abstractmethod
|
247 |
+
def sql(self, sql: str, fetch_size: int, format: str):
|
248 |
+
"""
|
249 |
+
Run the sql generated by text-to-sql
|
250 |
+
"""
|
251 |
+
raise NotImplementedError("Not implemented")
|
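A quick tour of the small value objects defined in the new module, purely to show the shapes the connectors consume (assumes the ragflow package and its numpy/polars dependencies are importable; nothing here talks to a real store):

from rag.utils.doc_store_conn import OrderByExpr, SparseVector, MatchDenseExpr

# OrderByExpr records (field, direction) pairs; asc/desc return self so calls chain.
order = OrderByExpr().desc("create_timestamp_flt").asc("create_time")
print(order.fields)                 # [('create_timestamp_flt', 1), ('create_time', 0)]

# SparseVector.to_dict() maps stringified indices to values.
sv = SparseVector(indices=[3, 17], values=[0.5, 0.25])
print(sv.to_dict())                 # {'3': 0.5, '17': 0.25}

# MatchDenseExpr carries the vector column name, the data, and search options.
dense = MatchDenseExpr("q_4_vec", [0.1, 0.2, 0.3, 0.4], "float", "cosine", topn=5)
print(dense.vector_column_name, dense.topn)
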
rag/utils/es_conn.py
CHANGED
@@ -1,29 +1,29 @@
|
|
1 |
import re
|
2 |
import json
|
3 |
import time
|
4 |
-
import
|
|
|
5 |
|
6 |
import elasticsearch
|
7 |
-
|
8 |
from elasticsearch import Elasticsearch
|
9 |
-
from elasticsearch_dsl import UpdateByQuery, Search, Index
|
10 |
-
from
|
|
|
11 |
from rag import settings
|
12 |
from rag.utils import singleton
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
|
15 |
|
16 |
|
17 |
@singleton
|
18 |
-
class ESConnection:
|
19 |
def __init__(self):
|
20 |
self.info = {}
|
21 |
-
self.conn()
|
22 |
-
self.idxnm = settings.ES.get("index_name", "")
|
23 |
-
if not self.es.ping():
|
24 |
-
raise Exception("Can't connect to ES cluster")
|
25 |
-
|
26 |
-
def conn(self):
|
27 |
for _ in range(10):
|
28 |
try:
|
29 |
self.es = Elasticsearch(
|
@@ -34,390 +34,317 @@ class ESConnection:
|
|
34 |
)
|
35 |
if self.es:
|
36 |
self.info = self.es.info()
|
37 |
-
|
38 |
break
|
39 |
except Exception as e:
|
40 |
-
|
41 |
time.sleep(1)
|
42 |
-
|
43 |
-
|
44 |
v = self.info.get("version", {"number": "5.6"})
|
45 |
v = v["number"].split(".")[0]
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
refresh=True,
|
68 |
-
retry_on_conflict=100)
|
69 |
-
else:
|
70 |
-
r = self.es.update(
|
71 |
-
index=(
|
72 |
-
self.idxnm if not idxnm else idxnm),
|
73 |
-
body=d,
|
74 |
-
id=id,
|
75 |
-
refresh=True,
|
76 |
-
retry_on_conflict=100)
|
77 |
-
es_logger.info("Successfully upsert: %s" % id)
|
78 |
-
T = True
|
79 |
-
break
|
80 |
-
except Exception as e:
|
81 |
-
es_logger.warning("Fail to index: " +
|
82 |
-
json.dumps(d, ensure_ascii=False) + str(e))
|
83 |
-
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
84 |
-
time.sleep(3)
|
85 |
-
continue
|
86 |
-
self.conn()
|
87 |
-
T = False
|
88 |
-
|
89 |
-
if not T:
|
90 |
-
res.append(d)
|
91 |
-
es_logger.error(
|
92 |
-
"Fail to index: " +
|
93 |
-
re.sub(
|
94 |
-
"[\r\n]",
|
95 |
-
"",
|
96 |
-
json.dumps(
|
97 |
-
d,
|
98 |
-
ensure_ascii=False)))
|
99 |
-
d["id"] = id
|
100 |
-
d["_index"] = self.idxnm
|
101 |
-
|
102 |
-
if not res:
|
103 |
return True
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
ids[id]["_index"] = self.idxnm if not idx_nm else idx_nm
|
112 |
-
if "id" in d:
|
113 |
-
del d["id"]
|
114 |
-
if "_id" in d:
|
115 |
-
del d["_id"]
|
116 |
-
acts.append(
|
117 |
-
{"update": {"_id": id, "_index": ids[id]["_index"]}, "retry_on_conflict": 100})
|
118 |
-
acts.append({"doc": d, "doc_as_upsert": "true"})
|
119 |
-
|
120 |
-
res = []
|
121 |
-
for _ in range(100):
|
122 |
-
try:
|
123 |
-
if elasticsearch.__version__[0] < 8:
|
124 |
-
r = self.es.bulk(
|
125 |
-
index=(
|
126 |
-
self.idxnm if not idx_nm else idx_nm),
|
127 |
-
body=acts,
|
128 |
-
refresh=False,
|
129 |
-
timeout="600s")
|
130 |
-
else:
|
131 |
-
r = self.es.bulk(index=(self.idxnm if not idx_nm else
|
132 |
-
idx_nm), operations=acts,
|
133 |
-
refresh=False, timeout="600s")
|
134 |
-
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
|
135 |
-
return res
|
136 |
-
|
137 |
-
for it in r["items"]:
|
138 |
-
if "error" in it["update"]:
|
139 |
-
res.append(str(it["update"]["_id"]) +
|
140 |
-
":" + str(it["update"]["error"]))
|
141 |
-
|
142 |
-
return res
|
143 |
-
except Exception as e:
|
144 |
-
es_logger.warn("Fail to bulk: " + str(e))
|
145 |
-
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
146 |
-
time.sleep(3)
|
147 |
-
continue
|
148 |
-
self.conn()
|
149 |
-
|
150 |
-
return res
|
151 |
-
|
152 |
-
def bulk4script(self, df):
|
153 |
-
ids, acts = {}, []
|
154 |
-
for d in df:
|
155 |
-
id = d["id"]
|
156 |
-
ids[id] = copy.deepcopy(d["raw"])
|
157 |
-
acts.append({"update": {"_id": id, "_index": self.idxnm}})
|
158 |
-
acts.append(d["script"])
|
159 |
-
es_logger.info("bulk upsert: %s" % id)
|
160 |
-
|
161 |
-
res = []
|
162 |
-
for _ in range(10):
|
163 |
-
try:
|
164 |
-
if not self.version():
|
165 |
-
r = self.es.bulk(
|
166 |
-
index=self.idxnm,
|
167 |
-
body=acts,
|
168 |
-
refresh=False,
|
169 |
-
timeout="600s",
|
170 |
-
doc_type="doc")
|
171 |
-
else:
|
172 |
-
r = self.es.bulk(
|
173 |
-
index=self.idxnm,
|
174 |
-
body=acts,
|
175 |
-
refresh=False,
|
176 |
-
timeout="600s")
|
177 |
-
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
|
178 |
-
return res
|
179 |
-
|
180 |
-
for it in r["items"]:
|
181 |
-
if "error" in it["update"]:
|
182 |
-
res.append(str(it["update"]["_id"]))
|
183 |
-
|
184 |
-
return res
|
185 |
-
except Exception as e:
|
186 |
-
es_logger.warning("Fail to bulk: " + str(e))
|
187 |
-
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
188 |
-
time.sleep(3)
|
189 |
-
continue
|
190 |
-
self.conn()
|
191 |
|
192 |
-
|
|
|
|
|
|
|
|
|
193 |
|
194 |
-
def
|
195 |
-
|
|
|
196 |
try:
|
197 |
-
|
198 |
-
r = self.es.delete(
|
199 |
-
index=self.idxnm,
|
200 |
-
id=d["id"],
|
201 |
-
doc_type="doc",
|
202 |
-
refresh=True)
|
203 |
-
else:
|
204 |
-
r = self.es.delete(
|
205 |
-
index=self.idxnm,
|
206 |
-
id=d["id"],
|
207 |
-
refresh=True,
|
208 |
-
doc_type="_doc")
|
209 |
-
es_logger.info("Remove %s" % d["id"])
|
210 |
-
return True
|
211 |
except Exception as e:
|
212 |
-
|
213 |
-
if
|
214 |
-
time.sleep(3)
|
215 |
continue
|
216 |
-
if re.search(r"(not_found)", str(e), re.IGNORECASE):
|
217 |
-
return True
|
218 |
-
self.conn()
|
219 |
-
|
220 |
-
es_logger.error("Fail to delete: " + str(d))
|
221 |
-
|
222 |
return False
|
223 |
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229          for i in range(3):
230              try:
231 -                res = self.es.search(index=…,
232                                       body=q,
233 -                                      timeout=…,
234                                       # search_type="dfs_query_then_fetch",
235                                       track_total_hits=True,
236 -                                      _source=…)
237                  if str(res.get("timed_out", "")).lower() == "true":
238                      raise Exception("Es Timeout.")
239                  return res
240              except Exception as e:
241 -                …
242                      "ES search exception: " +
243                      str(e) +
244 -                    "…
245                      str(q))
246                  if str(e).find("Timeout") > 0:
247                      continue
248                  raise e
249 -
250          raise Exception("ES search timeout.")
251
252 |
-
def
|
253 |
-
for i in range(3):
|
254 |
-
try:
|
255 |
-
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout=timeout)
|
256 |
-
return res
|
257 |
-
except ConnectionTimeout as e:
|
258 |
-
es_logger.error("Timeout【Q】:" + sql)
|
259 |
-
continue
|
260 |
-
except Exception as e:
|
261 |
-
raise e
|
262 |
-
es_logger.error("ES search timeout for 3 times!")
|
263 |
-
raise ConnectionTimeout()
|
264 |
-
|
265 |
-
|
266 |
-
def get(self, doc_id, idxnm=None):
|
267 |
for i in range(3):
|
268 |
try:
|
269 |
-
res = self.es.get(index=(
|
270 |
-
|
271 |
if str(res.get("timed_out", "")).lower() == "true":
|
272 |
raise Exception("Es Timeout.")
|
273 |
-
274 |
except Exception as e:
|
275 |
-
|
276 |
"ES get exception: " +
|
277 |
str(e) +
|
278 |
-
"
|
279 |
-
|
280 |
if str(e).find("Timeout") > 0:
|
281 |
continue
|
282 |
raise e
|
283 |
-
|
284 |
raise Exception("ES search timeout.")
|
285 |
|
286 |
-
def
|
287 |
-
|
288 |
-
|
289 |
-
for
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
return True
|
299 |
-
except Exception as e:
|
300 |
-
es_logger.error("ES updateByQuery exception: " +
|
301 |
-
str(e) + "【Q】:" + str(q.to_dict()))
|
302 |
-
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
303 |
-
continue
|
304 |
-
self.conn()
|
305 |
-
|
306 |
-
return False
|
307 |
|
308 |
-
|
309 |
-
|
310 |
-
index=self.idxnm if not idxnm else idxnm).using(
|
311 |
-
self.es).query(q)
|
312 |
-
ubq = ubq.script(source=scripts)
|
313 |
-
ubq = ubq.params(refresh=True)
|
314 |
-
ubq = ubq.params(slices=5)
|
315 |
-
ubq = ubq.params(conflicts="proceed")
|
316 |
-
for i in range(3):
|
317 |
try:
|
318 |
-
r =
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
str(e) + "【Q】:" + str(q.to_dict()))
|
323 |
-
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
324 |
-
continue
|
325 |
-
self.conn()
|
326 |
-
|
327 |
-
return False
|
328 |
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
refresh = True,
|
335 |
-
body=Search().query(query).to_dict())
|
336 |
-
return True
|
337 |
except Exception as e:
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
342 |
continue
|
343 |
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
"ES update exception: " + str(e) + " id:" + str(id) + ", version:" + str(self.version()) +
|
366 |
-
json.dumps(script, ensure_ascii=False))
|
367 |
-
if str(e).find("Timeout") > 0:
|
368 |
continue
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
es_logger.error("ES updateByQuery indexExist: " + str(e))
|
379 |
-
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
380 |
continue
|
381 |
-
|
382 |
return False
|
383 |
|
384 |
-
def
|
385 |
-
|
386 |
try:
|
387 |
-
|
388 |
-
|
389 |
except Exception as e:
|
390 |
-
|
391 |
-
if
|
392 |
continue
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
try:
|
397 |
-
if elasticsearch.__version__[0] < 8:
|
398 |
-
return self.es.indices.create(idxnm, body=mapping)
|
399 |
-
from elasticsearch.client import IndicesClient
|
400 |
-
return IndicesClient(self.es).create(index=idxnm,
|
401 |
-
settings=mapping["settings"],
|
402 |
-
mappings=mapping["mappings"])
|
403 |
-
except Exception as e:
|
404 |
-
es_logger.error("ES create index error %s ----%s" % (idxnm, str(e)))
|
405 |
|
406 |
-
def deleteIdx(self, idxnm):
|
407 |
-
try:
|
408 |
-
return self.es.indices.delete(idxnm, allow_no_indices=True)
|
409 |
-
except Exception as e:
|
410 |
-
es_logger.error("ES delete index error %s ----%s" % (idxnm, str(e)))
|
411 |
|
|
|
|
|
|
|
412 |
def getTotal(self, res):
|
413 |
if isinstance(res["hits"]["total"], type({})):
|
414 |
return res["hits"]["total"]["value"]
|
415 |
return res["hits"]["total"]
|
416 |
|
417 |
-
def
|
418 |
return [d["_id"] for d in res["hits"]["hits"]]
|
419 |
|
420 |
-
def
|
421 |
rr = []
|
422 |
for d in res["hits"]["hits"]:
|
423 |
d["_source"]["id"] = d["_id"]
|
@@ -425,40 +352,89 @@ class ESConnection:
|
|
425 |
rr.append(d["_source"])
|
426 |
return rr
|
427 |
|
428 |
-
def
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
)
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
time.sleep(3)
|
443 |
-
|
444 |
-
sid = page['_scroll_id']
|
445 |
-
scroll_size = page['hits']['total']["value"]
|
446 |
-
es_logger.info("[TOTAL]%d" % scroll_size)
|
447 |
-
# Start scrolling
|
448 |
-
while scroll_size > 0:
|
449 |
-
yield page["hits"]["hits"]
|
450 |
-
for _ in range(100):
|
451 |
-
try:
|
452 |
-
page = self.es.scroll(scroll_id=sid, scroll=scroll_time)
|
453 |
-
break
|
454 |
-
except Exception as e:
|
455 |
-
es_logger.error("ES scrolling fail. " + str(e))
|
456 |
-
time.sleep(3)
|
457 |
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
scroll_size = len(page['hits']['hits'])
|
462 |
|
|
463 |
|
464 |
-
|
1 |
import re
|
2 |
import json
|
3 |
import time
|
4 |
+
import os
|
5 |
+
from typing import List, Dict
|
6 |
|
7 |
import elasticsearch
|
8 |
+
import copy
|
9 |
from elasticsearch import Elasticsearch
|
10 |
+
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
|
11 |
+
from elastic_transport import ConnectionTimeout
|
12 |
+
from rag.settings import doc_store_logger
|
13 |
from rag import settings
|
14 |
from rag.utils import singleton
|
15 |
+
from api.utils.file_utils import get_project_base_directory
|
16 |
+
import polars as pl
|
17 |
+
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
|
18 |
+
from rag.nlp import is_english, rag_tokenizer
|
19 |
|
20 |
+
doc_store_logger.info("Elasticsearch sdk version: "+str(elasticsearch.__version__))
|
21 |
|
22 |
|
23 |
@singleton
|
24 |
+
class ESConnection(DocStoreConnection):
|
25 |
def __init__(self):
|
26 |
self.info = {}
|
27 |
for _ in range(10):
|
28 |
try:
|
29 |
self.es = Elasticsearch(
|
34 |
)
|
35 |
if self.es:
|
36 |
self.info = self.es.info()
|
37 |
+
doc_store_logger.info("Connect to es.")
|
38 |
break
|
39 |
except Exception as e:
|
40 |
+
doc_store_logger.error("Fail to connect to es: " + str(e))
|
41 |
time.sleep(1)
|
42 |
+
if not self.es.ping():
|
43 |
+
raise Exception("Can't connect to ES cluster")
|
44 |
v = self.info.get("version", {"number": "5.6"})
|
45 |
v = v["number"].split(".")[0]
|
46 |
+
if int(v) < 8:
|
47 |
+
raise Exception(f"ES version must be greater than or equal to 8, current version: {v}")
|
48 |
+
fp_mapping = os.path.join(get_project_base_directory(), "conf", "mapping.json")
|
49 |
+
if not os.path.exists(fp_mapping):
|
50 |
+
raise Exception(f"Mapping file not found at {fp_mapping}")
|
51 |
+
self.mapping = json.load(open(fp_mapping, "r"))
|
52 |
+
|
53 |
+
"""
|
54 |
+
Database operations
|
55 |
+
"""
|
56 |
+
def dbType(self) -> str:
|
57 |
+
return "elasticsearch"
|
58 |
+
|
59 |
+
def health(self) -> dict:
|
60 |
+
return {**dict(self.es.cluster.health()), "type": "elasticsearch"}
|
61 |
+
|
62 |
+
"""
|
63 |
+
Table operations
|
64 |
+
"""
|
65 |
+
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
66 |
+
if self.indexExist(indexName, knowledgebaseId):
|
67 |
return True
|
68 |
+
try:
|
69 |
+
from elasticsearch.client import IndicesClient
|
70 |
+
return IndicesClient(self.es).create(index=indexName,
|
71 |
+
settings=self.mapping["settings"],
|
72 |
+
mappings=self.mapping["mappings"])
|
73 |
+
except Exception as e:
|
74 |
+
doc_store_logger.error("ES create index error %s ----%s" % (indexName, str(e)))
|
75 |
|
76 |
+
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
77 |
+
try:
|
78 |
+
return self.es.indices.delete(indexName, allow_no_indices=True)
|
79 |
+
except Exception as e:
|
80 |
+
doc_store_logger.error("ES delete index error %s ----%s" % (indexName, str(e)))
|
81 |
|
82 |
+
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
|
83 |
+
s = Index(indexName, self.es)
|
84 |
+
for i in range(3):
|
85 |
try:
|
86 |
+
return s.exists()
|
87 |
except Exception as e:
|
88 |
+
doc_store_logger.error("ES indexExist: " + str(e))
|
89 |
+
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
90 |
continue
|
91 |
return False
|
92 |
|
93 |
+
"""
|
94 |
+
CRUD operations
|
95 |
+
"""
|
96 |
+
def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame:
|
97 |
+
"""
|
98 |
+
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
|
99 |
+
"""
|
100 |
+
if isinstance(indexNames, str):
|
101 |
+
indexNames = indexNames.split(",")
|
102 |
+
assert isinstance(indexNames, list) and len(indexNames) > 0
|
103 |
+
assert "_id" not in condition
|
104 |
+
s = Search()
|
105 |
+
bqry = None
|
106 |
+
vector_similarity_weight = 0.5
|
107 |
+
for m in matchExprs:
|
108 |
+
if isinstance(m, FusionExpr) and m.method=="weighted_sum" and "weights" in m.fusion_params:
|
109 |
+
assert len(matchExprs)==3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1], MatchDenseExpr) and isinstance(matchExprs[2], FusionExpr)
|
110 |
+
weights = m.fusion_params["weights"]
|
111 |
+
vector_similarity_weight = float(weights.split(",")[1])
|
112 |
+
for m in matchExprs:
|
113 |
+
if isinstance(m, MatchTextExpr):
|
114 |
+
minimum_should_match = "0%"
|
115 |
+
if "minimum_should_match" in m.extra_options:
|
116 |
+
minimum_should_match = str(int(m.extra_options["minimum_should_match"] * 100)) + "%"
|
117 |
+
bqry = Q("bool",
|
118 |
+
must=Q("query_string", fields=m.fields,
|
119 |
+
type="best_fields", query=m.matching_text,
|
120 |
+
minimum_should_match = minimum_should_match,
|
121 |
+
boost=1),
|
122 |
+
boost = 1.0 - vector_similarity_weight,
|
123 |
+
)
|
124 |
+
if condition:
|
125 |
+
for k, v in condition.items():
|
126 |
+
if not isinstance(k, str) or not v:
|
127 |
+
continue
|
128 |
+
if isinstance(v, list):
|
129 |
+
bqry.filter.append(Q("terms", **{k: v}))
|
130 |
+
elif isinstance(v, str) or isinstance(v, int):
|
131 |
+
bqry.filter.append(Q("term", **{k: v}))
|
132 |
+
else:
|
133 |
+
raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
134 |
+
elif isinstance(m, MatchDenseExpr):
|
135 |
+
assert(bqry is not None)
|
136 |
+
similarity = 0.0
|
137 |
+
if "similarity" in m.extra_options:
|
138 |
+
similarity = m.extra_options["similarity"]
|
139 |
+
s = s.knn(m.vector_column_name,
|
140 |
+
m.topn,
|
141 |
+
m.topn * 2,
|
142 |
+
query_vector = list(m.embedding_data),
|
143 |
+
filter = bqry.to_dict(),
|
144 |
+
similarity = similarity,
|
145 |
+
)
|
146 |
+
if matchExprs:
|
147 |
+
s.query = bqry
|
148 |
+
for field in highlightFields:
|
149 |
+
s = s.highlight(field)
|
150 |
+
|
151 |
+
if orderBy:
|
152 |
+
orders = list()
|
153 |
+
for field, order in orderBy.fields:
|
154 |
+
order = "asc" if order == 0 else "desc"
|
155 |
+
orders.append({field: {"order": order, "unmapped_type": "float",
|
156 |
+
"mode": "avg", "numeric_type": "double"}})
|
157 |
+
s = s.sort(*orders)
|
158 |
+
|
159 |
+
if limit > 0:
|
160 |
+
s = s[offset:limit]
|
161 |
+
q = s.to_dict()
|
162 |
+
doc_store_logger.info("ESConnection.search [Q]: " + json.dumps(q))
|
163 |
+
|
164 |
for i in range(3):
|
165 |
try:
|
166 |
+
res = self.es.search(index=indexNames,
|
167 |
body=q,
|
168 |
+
timeout="600s",
|
169 |
# search_type="dfs_query_then_fetch",
|
170 |
track_total_hits=True,
|
171 |
+
_source=True)
|
172 |
if str(res.get("timed_out", "")).lower() == "true":
|
173 |
raise Exception("Es Timeout.")
|
174 |
+
doc_store_logger.info("ESConnection.search res: " + str(res))
|
175 |
return res
|
176 |
except Exception as e:
|
177 |
+
doc_store_logger.error(
|
178 |
"ES search exception: " +
|
179 |
str(e) +
|
180 |
+
"\n[Q]: " +
|
181 |
str(q))
|
182 |
if str(e).find("Timeout") > 0:
|
183 |
continue
|
184 |
raise e
|
185 |
+
doc_store_logger.error("ES search timeout for 3 times!")
|
186 |
raise Exception("ES search timeout.")
|
187 |
|
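For reference, the `weighted_sum` branch above splits a single fusion weight string into the text and vector shares. A minimal sketch of that bookkeeping (the `"0.7,0.3"` value is illustrative, not taken from the PR):

```python
# Sketch only: how ESConnection.search() derives its boosts from a FusionExpr
# whose fusion_params carry weights as "text,vector" (values here are made up).
weights = "0.7,0.3"
vector_similarity_weight = float(weights.split(",")[1])  # 0.3 goes to the kNN clause
text_boost = 1.0 - vector_similarity_weight              # 0.7 boosts the query_string part
```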
188 |
+
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
|
189 |
for i in range(3):
|
190 |
try:
|
191 |
+
res = self.es.get(index=(indexName),
|
192 |
+
id=chunkId, source=True,)
|
193 |
if str(res.get("timed_out", "")).lower() == "true":
|
194 |
raise Exception("Es Timeout.")
|
195 |
+
if not res.get("found"):
|
196 |
+
return None
|
197 |
+
chunk = res["_source"]
|
198 |
+
chunk["id"] = chunkId
|
199 |
+
return chunk
|
200 |
except Exception as e:
|
201 |
+
doc_store_logger.error(
|
202 |
"ES get exception: " +
|
203 |
str(e) +
|
204 |
+
"[Q]: " +
|
205 |
+
chunkId)
|
206 |
if str(e).find("Timeout") > 0:
|
207 |
continue
|
208 |
raise e
|
209 |
+
doc_store_logger.error("ES search timeout for 3 times!")
|
210 |
raise Exception("ES search timeout.")
|
211 |
|
212 |
+
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
|
213 |
+
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
|
214 |
+
operations = []
|
215 |
+
for d in documents:
|
216 |
+
assert "_id" not in d
|
217 |
+
assert "id" in d
|
218 |
+
d_copy = copy.deepcopy(d)
|
219 |
+
meta_id = d_copy["id"]
|
220 |
+
del d_copy["id"]
|
221 |
+
operations.append(
|
222 |
+
{"index": {"_index": indexName, "_id": meta_id}})
|
223 |
+
operations.append(d_copy)
|
224 |
|
225 |
+
res = []
|
226 |
+
for _ in range(100):
|
227 |
try:
|
228 |
+
r = self.es.bulk(index=(indexName), operations=operations,
|
229 |
+
refresh=False, timeout="600s")
|
230 |
+
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
|
231 |
+
return res
|
232 |
|
233 |
+
for item in r["items"]:
|
234 |
+
for action in ["create", "delete", "index", "update"]:
|
235 |
+
if action in item and "error" in item[action]:
|
236 |
+
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
|
237 |
+
return res
|
238 |
except Exception as e:
|
239 |
+
doc_store_logger.warning("Fail to bulk: " + str(e))
|
240 |
+
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
241 |
+
time.sleep(3)
|
242 |
continue
|
243 |
+
return res
|
244 |
|
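The bulk body `insert()` assembles alternates an action line and a document line per chunk. A sketch of its shape, with made-up index name, ids and fields:

```python
# Illustrative bulk payload built by insert() for two chunks; the "id" of each
# document is moved into the Elasticsearch metadata line.
operations = [
    {"index": {"_index": "ragflow_tenant_idx", "_id": "chunk-0001"}},
    {"content_with_weight": "hello world", "kb_id": "kb-0001"},
    {"index": {"_index": "ragflow_tenant_idx", "_id": "chunk-0002"}},
    {"content_with_weight": "goodbye world", "kb_id": "kb-0001"},
]
```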
245 |
+
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
|
246 |
+
doc = copy.deepcopy(newValue)
|
247 |
+
del doc['id']
|
248 |
+
if "id" in condition and isinstance(condition["id"], str):
|
249 |
+
# update specific single document
|
250 |
+
chunkId = condition["id"]
|
251 |
+
for i in range(3):
|
252 |
+
try:
|
253 |
+
self.es.update(index=indexName, id=chunkId, doc=doc)
|
254 |
+
return True
|
255 |
+
except Exception as e:
|
256 |
+
doc_store_logger.error(
|
257 |
+
"ES update exception: " + str(e) + " id:" + str(id) +
|
258 |
+
json.dumps(newValue, ensure_ascii=False))
|
259 |
+
if str(e).find("Timeout") > 0:
|
260 |
+
continue
|
261 |
+
else:
|
262 |
+
# update unspecific maybe-multiple documents
|
263 |
+
bqry = Q("bool")
|
264 |
+
for k, v in condition.items():
|
265 |
+
if not isinstance(k, str) or not v:
|
266 |
continue
|
267 |
+
if isinstance(v, list):
|
268 |
+
bqry.filter.append(Q("terms", **{k: v}))
|
269 |
+
elif isinstance(v, str) or isinstance(v, int):
|
270 |
+
bqry.filter.append(Q("term", **{k: v}))
|
271 |
+
else:
|
272 |
+
raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
273 |
+
scripts = []
|
274 |
+
for k, v in newValue.items():
|
275 |
+
if not isinstance(k, str) or not v:
|
276 |
continue
|
277 |
+
if isinstance(v, str):
|
278 |
+
scripts.append(f"ctx._source.{k} = '{v}'")
|
279 |
+
elif isinstance(v, int):
|
280 |
+
scripts.append(f"ctx._source.{k} = {v}")
|
281 |
+
else:
|
282 |
+
raise Exception(f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
|
283 |
+
ubq = UpdateByQuery(
|
284 |
+
index=indexName).using(
|
285 |
+
self.es).query(bqry)
|
286 |
+
ubq = ubq.script(source="; ".join(scripts))
|
287 |
+
ubq = ubq.params(refresh=True)
|
288 |
+
ubq = ubq.params(slices=5)
|
289 |
+
ubq = ubq.params(conflicts="proceed")
|
290 |
+
for i in range(3):
|
291 |
+
try:
|
292 |
+
_ = ubq.execute()
|
293 |
+
return True
|
294 |
+
except Exception as e:
|
295 |
+
doc_store_logger.error("ES update exception: " +
|
296 |
+
str(e) + "[Q]:" + str(bqry.to_dict()))
|
297 |
+
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
298 |
+
continue
|
299 |
return False
|
300 |
|
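When the condition does not pin a single id, `update()` falls back to an update-by-query with a generated painless script. A sketch of the strings it would build for a hypothetical condition/newValue pair (note that falsy values such as `0` are skipped by the `if not ... v` guard):

```python
# Hypothetical inputs for the update-by-query branch above.
condition = {"doc_id": "doc-0001"}
newValue = {"available_int": 1, "status_kwd": "enabled"}
# Generated script source:
#   "ctx._source.available_int = 1; ctx._source.status_kwd = 'enabled'"
```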
301 |
+
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
302 |
+
qry = None
|
303 |
+
assert "_id" not in condition
|
304 |
+
if "id" in condition:
|
305 |
+
chunk_ids = condition["id"]
|
306 |
+
if not isinstance(chunk_ids, list):
|
307 |
+
chunk_ids = [chunk_ids]
|
308 |
+
qry = Q("ids", values=chunk_ids)
|
309 |
+
else:
|
310 |
+
qry = Q("bool")
|
311 |
+
for k, v in condition.items():
|
312 |
+
if isinstance(v, list):
|
313 |
+
qry.must.append(Q("terms", **{k: v}))
|
314 |
+
elif isinstance(v, str) or isinstance(v, int):
|
315 |
+
qry.must.append(Q("term", **{k: v}))
|
316 |
+
else:
|
317 |
+
raise Exception("Condition value must be int, str or list.")
|
318 |
+
doc_store_logger.info("ESConnection.delete [Q]: " + json.dumps(qry.to_dict()))
|
319 |
+
for _ in range(10):
|
320 |
try:
|
321 |
+
res = self.es.delete_by_query(
|
322 |
+
index=indexName,
|
323 |
+
body = Search().query(qry).to_dict(),
|
324 |
+
refresh=True)
|
325 |
+
return res["deleted"]
|
326 |
except Exception as e:
|
327 |
+
doc_store_logger.warning("Fail to delete: " + str(filter) + str(e))
|
328 |
+
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
329 |
+
time.sleep(3)
|
330 |
continue
|
331 |
+
if re.search(r"(not_found)", str(e), re.IGNORECASE):
|
332 |
+
return 0
|
333 |
+
return 0
|
334 |
|
335 |
|
336 |
+
"""
|
337 |
+
Helper functions for search result
|
338 |
+
"""
|
339 |
def getTotal(self, res):
|
340 |
if isinstance(res["hits"]["total"], type({})):
|
341 |
return res["hits"]["total"]["value"]
|
342 |
return res["hits"]["total"]
|
343 |
|
344 |
+
def getChunkIds(self, res):
|
345 |
return [d["_id"] for d in res["hits"]["hits"]]
|
346 |
|
347 |
+
def __getSource(self, res):
|
348 |
rr = []
|
349 |
for d in res["hits"]["hits"]:
|
350 |
d["_source"]["id"] = d["_id"]
|
352 |
rr.append(d["_source"])
|
353 |
return rr
|
354 |
|
355 |
+
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
|
356 |
+
res_fields = {}
|
357 |
+
if not fields:
|
358 |
+
return {}
|
359 |
+
for d in self.__getSource(res):
|
360 |
+
m = {n: d.get(n) for n in fields if d.get(n) is not None}
|
361 |
+
for n, v in m.items():
|
362 |
+
if isinstance(v, list):
|
363 |
+
m[n] = v
|
364 |
+
continue
|
365 |
+
if not isinstance(v, str):
|
366 |
+
m[n] = str(m[n])
|
367 |
+
# if n.find("tks") > 0:
|
368 |
+
# m[n] = rmSpace(m[n])
|
369 |
|
370 |
+
if m:
|
371 |
+
res_fields[d["id"]] = m
|
372 |
+
return res_fields
|
373 |
|
374 |
+
def getHighlight(self, res, keywords: List[str], fieldnm: str):
|
375 |
+
ans = {}
|
376 |
+
for d in res["hits"]["hits"]:
|
377 |
+
hlts = d.get("highlight")
|
378 |
+
if not hlts:
|
379 |
+
continue
|
380 |
+
txt = "...".join([a for a in list(hlts.items())[0][1]])
|
381 |
+
if not is_english(txt.split(" ")):
|
382 |
+
ans[d["_id"]] = txt
|
383 |
+
continue
|
384 |
+
|
385 |
+
txt = d["_source"][fieldnm]
|
386 |
+
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
|
387 |
+
txts = []
|
388 |
+
for t in re.split(r"[.?!;\n]", txt):
|
389 |
+
for w in keywords:
|
390 |
+
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", t, flags=re.IGNORECASE|re.MULTILINE)
|
391 |
+
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE):
|
392 |
+
continue
|
393 |
+
txts.append(t)
|
394 |
+
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
|
395 |
+
|
396 |
+
return ans
|
397 |
+
|
398 |
+
def getAggregation(self, res, fieldnm: str):
|
399 |
+
agg_field = "aggs_" + fieldnm
|
400 |
+
if "aggregations" not in res or agg_field not in res["aggregations"]:
|
401 |
+
return list()
|
402 |
+
bkts = res["aggregations"][agg_field]["buckets"]
|
403 |
+
return [(b["key"], b["doc_count"]) for b in bkts]
|
404 |
+
|
405 |
+
|
406 |
+
"""
|
407 |
+
SQL
|
408 |
+
"""
|
409 |
+
def sql(self, sql: str, fetch_size: int, format: str):
|
410 |
+
doc_store_logger.info(f"ESConnection.sql get sql: {sql}")
|
411 |
+
sql = re.sub(r"[ `]+", " ", sql)
|
412 |
+
sql = sql.replace("%", "")
|
413 |
+
replaces = []
|
414 |
+
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
|
415 |
+
fld, v = r.group(1), r.group(3)
|
416 |
+
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
|
417 |
+
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
|
418 |
+
replaces.append(
|
419 |
+
("{}{}'{}'".format(
|
420 |
+
r.group(1),
|
421 |
+
r.group(2),
|
422 |
+
r.group(3)),
|
423 |
+
match))
|
424 |
+
|
425 |
+
for p, r in replaces:
|
426 |
+
sql = sql.replace(p, r, 1)
|
427 |
+
doc_store_logger.info(f"ESConnection.sql to es: {sql}")
|
428 |
|
429 |
+
for i in range(3):
|
430 |
+
try:
|
431 |
+
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout="2s")
|
432 |
+
return res
|
433 |
+
except ConnectionTimeout:
|
434 |
+
doc_store_logger.error("ESConnection.sql timeout [Q]: " + sql)
|
435 |
+
continue
|
436 |
+
except Exception as e:
|
437 |
+
doc_store_logger.error(f"ESConnection.sql failure: {sql} => " + str(e))
|
438 |
+
return None
|
439 |
+
doc_store_logger.error("ESConnection.sql timeout for 3 times!")
|
440 |
+
return None
|
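Taken together, the rewritten ESConnection is now driven through the DocStoreConnection interface instead of ad-hoc ES queries. Below is a minimal sketch of one plausible call sequence; the keyword arguments of `MatchTextExpr`/`MatchDenseExpr`/`FusionExpr` are assumptions inferred from how `search()` reads them (`rag/utils/doc_store_conn.py` is not shown here), and the index name, kb ids and the 1024-dim vector are placeholders.

```python
# A minimal sketch (not from the PR) of driving the refactored ESConnection via
# the DocStoreConnection interface; names marked as placeholders are invented.
from rag.utils.es_conn import ESConnection
from rag.utils.doc_store_conn import (
    MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr,
)

store = ESConnection()

text_expr = MatchTextExpr(
    fields=["title_tks", "content_ltks"],
    matching_text="what is rag",
    topn=64,
    extra_options={"minimum_should_match": 0.3},   # becomes "30%" in the query
)
dense_expr = MatchDenseExpr(
    vector_column_name="q_1024_vec",               # follows the q_{size}_vec convention
    embedding_data=[0.0] * 1024,
    embedding_data_type="float",
    distance_type="cosine",
    topn=64,
    extra_options={"similarity": 0.1},
)
fusion_expr = FusionExpr(method="weighted_sum", topn=64,
                         fusion_params={"weights": "0.7,0.3"})  # text,vector

res = store.search(
    selectFields=["id", "docnm_kwd", "content_with_weight"],
    highlightFields=["content_ltks"],
    condition={"kb_id": ["kb-0001"]},
    matchExprs=[text_expr, dense_expr, fusion_expr],
    orderBy=OrderByExpr(), offset=0, limit=30,
    indexNames="ragflow_tenant", knowledgebaseIds=["kb-0001"],
)
print(store.getTotal(res), store.getChunkIds(res)[:3])
```

Note that the `MatchTextExpr` must precede the `MatchDenseExpr` in `matchExprs`, since the kNN clause filters on the bool query built for the text match.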
rag/utils/infinity_conn.py
ADDED
@@ -0,0 +1,436 @@
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import json
|
4 |
+
from typing import List, Dict
|
5 |
+
import infinity
|
6 |
+
from infinity.common import ConflictType, InfinityException
|
7 |
+
from infinity.index import IndexInfo, IndexType
|
8 |
+
from infinity.connection_pool import ConnectionPool
|
9 |
+
from rag import settings
|
10 |
+
from rag.settings import doc_store_logger
|
11 |
+
from rag.utils import singleton
|
12 |
+
import polars as pl
|
13 |
+
from polars.series.series import Series
|
14 |
+
from api.utils.file_utils import get_project_base_directory
|
15 |
+
|
16 |
+
from rag.utils.doc_store_conn import (
|
17 |
+
DocStoreConnection,
|
18 |
+
MatchExpr,
|
19 |
+
MatchTextExpr,
|
20 |
+
MatchDenseExpr,
|
21 |
+
FusionExpr,
|
22 |
+
OrderByExpr,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
def equivalent_condition_to_str(condition: dict) -> str:
|
27 |
+
assert "_id" not in condition
|
28 |
+
cond = list()
|
29 |
+
for k, v in condition.items():
|
30 |
+
if not isinstance(k, str) or not v:
|
31 |
+
continue
|
32 |
+
if isinstance(v, list):
|
33 |
+
inCond = list()
|
34 |
+
for item in v:
|
35 |
+
if isinstance(item, str):
|
36 |
+
inCond.append(f"'{item}'")
|
37 |
+
else:
|
38 |
+
inCond.append(str(item))
|
39 |
+
if inCond:
|
40 |
+
strInCond = ", ".join(inCond)
|
41 |
+
strInCond = f"{k} IN ({strInCond})"
|
42 |
+
cond.append(strInCond)
|
43 |
+
elif isinstance(v, str):
|
44 |
+
cond.append(f"{k}='{v}'")
|
45 |
+
else:
|
46 |
+
cond.append(f"{k}={str(v)}")
|
47 |
+
return " AND ".join(cond)
|
48 |
+
|
49 |
+
|
50 |
+
@singleton
|
51 |
+
class InfinityConnection(DocStoreConnection):
|
52 |
+
def __init__(self):
|
53 |
+
self.dbName = settings.INFINITY.get("db_name", "default_db")
|
54 |
+
infinity_uri = settings.INFINITY["uri"]
|
55 |
+
if ":" in infinity_uri:
|
56 |
+
host, port = infinity_uri.split(":")
|
57 |
+
infinity_uri = infinity.common.NetworkAddress(host, int(port))
|
58 |
+
self.connPool = ConnectionPool(infinity_uri)
|
59 |
+
doc_store_logger.info(f"Connected to infinity {infinity_uri}.")
|
60 |
+
|
61 |
+
"""
|
62 |
+
Database operations
|
63 |
+
"""
|
64 |
+
|
65 |
+
def dbType(self) -> str:
|
66 |
+
return "infinity"
|
67 |
+
|
68 |
+
def health(self) -> dict:
|
69 |
+
"""
|
70 |
+
Return the health status of the database.
|
71 |
+
TODO: Infinity-sdk provides health() to wrap `show global variables` and `show tables`
|
72 |
+
"""
|
73 |
+
inf_conn = self.connPool.get_conn()
|
74 |
+
res = inf_conn.show_current_node()
|
75 |
+
self.connPool.release_conn(inf_conn)
|
76 |
+
color = "green" if res.error_code == 0 else "red"
|
77 |
+
res2 = {
|
78 |
+
"type": "infinity",
|
79 |
+
"status": f"{res.role} {color}",
|
80 |
+
"error": res.error_msg,
|
81 |
+
}
|
82 |
+
return res2
|
83 |
+
|
84 |
+
"""
|
85 |
+
Table operations
|
86 |
+
"""
|
87 |
+
|
88 |
+
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
89 |
+
table_name = f"{indexName}_{knowledgebaseId}"
|
90 |
+
inf_conn = self.connPool.get_conn()
|
91 |
+
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
|
92 |
+
|
93 |
+
fp_mapping = os.path.join(
|
94 |
+
get_project_base_directory(), "conf", "infinity_mapping.json"
|
95 |
+
)
|
96 |
+
if not os.path.exists(fp_mapping):
|
97 |
+
raise Exception(f"Mapping file not found at {fp_mapping}")
|
98 |
+
schema = json.load(open(fp_mapping))
|
99 |
+
vector_name = f"q_{vectorSize}_vec"
|
100 |
+
schema[vector_name] = {"type": f"vector,{vectorSize},float"}
|
101 |
+
inf_table = inf_db.create_table(
|
102 |
+
table_name,
|
103 |
+
schema,
|
104 |
+
ConflictType.Ignore,
|
105 |
+
)
|
106 |
+
inf_table.create_index(
|
107 |
+
"q_vec_idx",
|
108 |
+
IndexInfo(
|
109 |
+
vector_name,
|
110 |
+
IndexType.Hnsw,
|
111 |
+
{
|
112 |
+
"M": "16",
|
113 |
+
"ef_construction": "50",
|
114 |
+
"metric": "cosine",
|
115 |
+
"encode": "lvq",
|
116 |
+
},
|
117 |
+
),
|
118 |
+
ConflictType.Ignore,
|
119 |
+
)
|
120 |
+
text_suffix = ["_tks", "_ltks", "_kwd"]
|
121 |
+
for field_name, field_info in schema.items():
|
122 |
+
if field_info["type"] != "varchar":
|
123 |
+
continue
|
124 |
+
for suffix in text_suffix:
|
125 |
+
if field_name.endswith(suffix):
|
126 |
+
inf_table.create_index(
|
127 |
+
f"text_idx_{field_name}",
|
128 |
+
IndexInfo(
|
129 |
+
field_name, IndexType.FullText, {"ANALYZER": "standard"}
|
130 |
+
),
|
131 |
+
ConflictType.Ignore,
|
132 |
+
)
|
133 |
+
break
|
134 |
+
self.connPool.release_conn(inf_conn)
|
135 |
+
doc_store_logger.info(
|
136 |
+
f"INFINITY created table {table_name}, vector size {vectorSize}"
|
137 |
+
)
|
138 |
+
|
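`createIdx()` derives the embedding column name from the vector size and full-text-indexes the tokenized varchar columns. A small sketch of the schema patch it applies on top of `conf/infinity_mapping.json` (768 is an illustrative embedding size):

```python
# Illustrative: the extra column createIdx() adds before creating the per-KB table.
vector_size = 768
vector_name = f"q_{vector_size}_vec"
schema_patch = {vector_name: {"type": f"vector,{vector_size},float"}}
# -> column "q_768_vec" of type "vector,768,float", indexed with HNSW (cosine, lvq).
```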
139 |
+
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
140 |
+
table_name = f"{indexName}_{knowledgebaseId}"
|
141 |
+
inf_conn = self.connPool.get_conn()
|
142 |
+
db_instance = inf_conn.get_database(self.dbName)
|
143 |
+
db_instance.drop_table(table_name, ConflictType.Ignore)
|
144 |
+
self.connPool.release_conn(inf_conn)
|
145 |
+
doc_store_logger.info(f"INFINITY dropped table {table_name}")
|
146 |
+
|
147 |
+
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
|
148 |
+
table_name = f"{indexName}_{knowledgebaseId}"
|
149 |
+
try:
|
150 |
+
inf_conn = self.connPool.get_conn()
|
151 |
+
db_instance = inf_conn.get_database(self.dbName)
|
152 |
+
_ = db_instance.get_table(table_name)
|
153 |
+
self.connPool.release_conn(inf_conn)
|
154 |
+
return True
|
155 |
+
except Exception as e:
|
156 |
+
doc_store_logger.error("INFINITY indexExist: " + str(e))
|
157 |
+
return False
|
158 |
+
|
159 |
+
"""
|
160 |
+
CRUD operations
|
161 |
+
"""
|
162 |
+
|
163 |
+
def search(
|
164 |
+
self,
|
165 |
+
selectFields: list[str],
|
166 |
+
highlightFields: list[str],
|
167 |
+
condition: dict,
|
168 |
+
matchExprs: list[MatchExpr],
|
169 |
+
orderBy: OrderByExpr,
|
170 |
+
offset: int,
|
171 |
+
limit: int,
|
172 |
+
indexNames: str|list[str],
|
173 |
+
knowledgebaseIds: list[str],
|
174 |
+
) -> list[dict] | pl.DataFrame:
|
175 |
+
"""
|
176 |
+
TODO: Infinity doesn't provide highlight
|
177 |
+
"""
|
178 |
+
if isinstance(indexNames, str):
|
179 |
+
indexNames = indexNames.split(",")
|
180 |
+
assert isinstance(indexNames, list) and len(indexNames) > 0
|
181 |
+
inf_conn = self.connPool.get_conn()
|
182 |
+
db_instance = inf_conn.get_database(self.dbName)
|
183 |
+
df_list = list()
|
184 |
+
table_list = list()
|
185 |
+
if "id" not in selectFields:
|
186 |
+
selectFields.append("id")
|
187 |
+
|
188 |
+
# Prepare expressions common to all tables
|
189 |
+
filter_cond = ""
|
190 |
+
filter_fulltext = ""
|
191 |
+
if condition:
|
192 |
+
filter_cond = equivalent_condition_to_str(condition)
|
193 |
+
for matchExpr in matchExprs:
|
194 |
+
if isinstance(matchExpr, MatchTextExpr):
|
195 |
+
if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
|
196 |
+
matchExpr.extra_options.update({"filter": filter_cond})
|
197 |
+
fields = ",".join(matchExpr.fields)
|
198 |
+
filter_fulltext = (
|
199 |
+
f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
|
200 |
+
)
|
201 |
+
if len(filter_cond) != 0:
|
202 |
+
filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
|
203 |
+
# doc_store_logger.info(f"filter_fulltext: {filter_fulltext}")
|
204 |
+
minimum_should_match = "0%"
|
205 |
+
if "minimum_should_match" in matchExpr.extra_options:
|
206 |
+
minimum_should_match = (
|
207 |
+
str(int(matchExpr.extra_options["minimum_should_match"] * 100))
|
208 |
+
+ "%"
|
209 |
+
)
|
210 |
+
matchExpr.extra_options.update(
|
211 |
+
{"minimum_should_match": minimum_should_match}
|
212 |
+
)
|
213 |
+
for k, v in matchExpr.extra_options.items():
|
214 |
+
if not isinstance(v, str):
|
215 |
+
matchExpr.extra_options[k] = str(v)
|
216 |
+
elif isinstance(matchExpr, MatchDenseExpr):
|
217 |
+
if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
|
218 |
+
matchExpr.extra_options.update({"filter": filter_fulltext})
|
219 |
+
for k, v in matchExpr.extra_options.items():
|
220 |
+
if not isinstance(v, str):
|
221 |
+
matchExpr.extra_options[k] = str(v)
|
222 |
+
if orderBy.fields:
|
223 |
+
order_by_expr_list = list()
|
224 |
+
for order_field in orderBy.fields:
|
225 |
+
order_by_expr_list.append((order_field[0], order_field[1] == 0))
|
226 |
+
|
227 |
+
# Scatter search tables and gather the results
|
228 |
+
for indexName in indexNames:
|
229 |
+
for knowledgebaseId in knowledgebaseIds:
|
230 |
+
table_name = f"{indexName}_{knowledgebaseId}"
|
231 |
+
try:
|
232 |
+
table_instance = db_instance.get_table(table_name)
|
233 |
+
except Exception:
|
234 |
+
continue
|
235 |
+
table_list.append(table_name)
|
236 |
+
builder = table_instance.output(selectFields)
|
237 |
+
for matchExpr in matchExprs:
|
238 |
+
if isinstance(matchExpr, MatchTextExpr):
|
239 |
+
fields = ",".join(matchExpr.fields)
|
240 |
+
builder = builder.match_text(
|
241 |
+
fields,
|
242 |
+
matchExpr.matching_text,
|
243 |
+
matchExpr.topn,
|
244 |
+
matchExpr.extra_options,
|
245 |
+
)
|
246 |
+
elif isinstance(matchExpr, MatchDenseExpr):
|
247 |
+
builder = builder.match_dense(
|
248 |
+
matchExpr.vector_column_name,
|
249 |
+
matchExpr.embedding_data,
|
250 |
+
matchExpr.embedding_data_type,
|
251 |
+
matchExpr.distance_type,
|
252 |
+
matchExpr.topn,
|
253 |
+
matchExpr.extra_options,
|
254 |
+
)
|
255 |
+
elif isinstance(matchExpr, FusionExpr):
|
256 |
+
builder = builder.fusion(
|
257 |
+
matchExpr.method, matchExpr.topn, matchExpr.fusion_params
|
258 |
+
)
|
259 |
+
if orderBy.fields:
|
260 |
+
builder.sort(order_by_expr_list)
|
261 |
+
builder.offset(offset).limit(limit)
|
262 |
+
kb_res = builder.to_pl()
|
263 |
+
df_list.append(kb_res)
|
264 |
+
self.connPool.release_conn(inf_conn)
|
265 |
+
res = pl.concat(df_list)
|
266 |
+
doc_store_logger.info("INFINITY search tables: " + str(table_list))
|
267 |
+
return res
|
268 |
+
|
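Because Infinity takes SQL-like filter strings rather than a bool query, `search()` folds the structured condition and the full-text predicate into a single string. A sketch with placeholder field names and values:

```python
# Illustrative strings produced for a MatchTextExpr when a kb_id condition is present.
filter_cond = "kb_id IN ('kb-0001')"
filter_fulltext = "filter_fulltext('content_ltks,title_tks', 'what is rag')"
filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
# -> "(kb_id IN ('kb-0001')) AND filter_fulltext('content_ltks,title_tks', 'what is rag')"
```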
269 |
+
def get(
|
270 |
+
self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
|
271 |
+
) -> dict | None:
|
272 |
+
inf_conn = self.connPool.get_conn()
|
273 |
+
db_instance = inf_conn.get_database(self.dbName)
|
274 |
+
df_list = list()
|
275 |
+
assert isinstance(knowledgebaseIds, list)
|
276 |
+
for knowledgebaseId in knowledgebaseIds:
|
277 |
+
table_name = f"{indexName}_{knowledgebaseId}"
|
278 |
+
table_instance = db_instance.get_table(table_name)
|
279 |
+
kb_res = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
|
280 |
+
df_list.append(kb_res)
|
281 |
+
self.connPool.release_conn(inf_conn)
|
282 |
+
res = pl.concat(df_list)
|
283 |
+
res_fields = self.getFields(res, res.columns)
|
284 |
+
return res_fields.get(chunkId, None)
|
285 |
+
|
286 |
+
def insert(
|
287 |
+
self, documents: list[dict], indexName: str, knowledgebaseId: str
|
288 |
+
) -> list[str]:
|
289 |
+
inf_conn = self.connPool.get_conn()
|
290 |
+
db_instance = inf_conn.get_database(self.dbName)
|
291 |
+
table_name = f"{indexName}_{knowledgebaseId}"
|
292 |
+
try:
|
293 |
+
table_instance = db_instance.get_table(table_name)
|
294 |
+
except InfinityException as e:
|
295 |
+
# src/common/status.cppm, kTableNotExist = 3022
|
296 |
+
if e.error_code != 3022:
|
297 |
+
raise
|
298 |
+
vector_size = 0
|
299 |
+
patt = re.compile(r"q_(?P<vector_size>\d+)_vec")
|
300 |
+
for k in documents[0].keys():
|
301 |
+
m = patt.match(k)
|
302 |
+
if m:
|
303 |
+
vector_size = int(m.group("vector_size"))
|
304 |
+
break
|
305 |
+
if vector_size == 0:
|
306 |
+
raise ValueError("Cannot infer vector size from documents")
|
307 |
+
self.createIdx(indexName, knowledgebaseId, vector_size)
|
308 |
+
table_instance = db_instance.get_table(table_name)
|
309 |
+
|
310 |
+
for d in documents:
|
311 |
+
assert "_id" not in d
|
312 |
+
assert "id" in d
|
313 |
+
for k, v in d.items():
|
314 |
+
if k.endswith("_kwd") and isinstance(v, list):
|
315 |
+
d[k] = " ".join(v)
|
316 |
+
ids = [f"'{d["id"]}'" for d in documents]
|
317 |
+
str_ids = ", ".join(ids)
|
318 |
+
str_filter = f"id IN ({str_ids})"
|
319 |
+
table_instance.delete(str_filter)
|
320 |
+
# for doc in documents:
|
321 |
+
# doc_store_logger.info(f"insert position_list: {doc['position_list']}")
|
322 |
+
# doc_store_logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
|
323 |
+
table_instance.insert(documents)
|
324 |
+
self.connPool.release_conn(inf_conn)
|
325 |
+
doc_store_logger.info(f"inserted into {table_name} {str_ids}.")
|
326 |
+
return []
|
327 |
+
|
328 |
+
def update(
|
329 |
+
self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
|
330 |
+
) -> bool:
|
331 |
+
# if 'position_list' in newValue:
|
332 |
+
# doc_store_logger.info(f"update position_list: {newValue['position_list']}")
|
333 |
+
inf_conn = self.connPool.get_conn()
|
334 |
+
db_instance = inf_conn.get_database(self.dbName)
|
335 |
+
table_name = f"{indexName}_{knowledgebaseId}"
|
336 |
+
table_instance = db_instance.get_table(table_name)
|
337 |
+
filter = equivalent_condition_to_str(condition)
|
338 |
+
for k, v in newValue.items():
|
339 |
+
if k.endswith("_kwd") and isinstance(v, list):
|
340 |
+
newValue[k] = " ".join(v)
|
341 |
+
table_instance.update(filter, newValue)
|
342 |
+
self.connPool.release_conn(inf_conn)
|
343 |
+
return True
|
344 |
+
|
345 |
+
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
346 |
+
inf_conn = self.connPool.get_conn()
|
347 |
+
db_instance = inf_conn.get_database(self.dbName)
|
348 |
+
table_name = f"{indexName}_{knowledgebaseId}"
|
349 |
+
filter = equivalent_condition_to_str(condition)
|
350 |
+
try:
|
351 |
+
table_instance = db_instance.get_table(table_name)
|
352 |
+
except Exception:
|
353 |
+
doc_store_logger.warning(
|
354 |
+
f"Skipped deleting `{filter}` from table {table_name} since the table doesn't exist."
|
355 |
+
)
|
356 |
+
return 0
|
357 |
+
res = table_instance.delete(filter)
|
358 |
+
self.connPool.release_conn(inf_conn)
|
359 |
+
return res.deleted_rows
|
360 |
+
|
361 |
+
"""
|
362 |
+
Helper functions for search result
|
363 |
+
"""
|
364 |
+
|
365 |
+
def getTotal(self, res):
|
366 |
+
return len(res)
|
367 |
+
|
368 |
+
def getChunkIds(self, res):
|
369 |
+
return list(res["id"])
|
370 |
+
|
371 |
+
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
|
372 |
+
res_fields = {}
|
373 |
+
if not fields:
|
374 |
+
return {}
|
375 |
+
num_rows = len(res)
|
376 |
+
column_id = res["id"]
|
377 |
+
for i in range(num_rows):
|
378 |
+
id = column_id[i]
|
379 |
+
m = {"id": id}
|
380 |
+
for fieldnm in fields:
|
381 |
+
if fieldnm not in res:
|
382 |
+
m[fieldnm] = None
|
383 |
+
continue
|
384 |
+
v = res[fieldnm][i]
|
385 |
+
if isinstance(v, Series):
|
386 |
+
v = list(v)
|
387 |
+
elif fieldnm == "important_kwd":
|
388 |
+
assert isinstance(v, str)
|
389 |
+
v = v.split(" ")
|
390 |
+
else:
|
391 |
+
if not isinstance(v, str):
|
392 |
+
v = str(v)
|
393 |
+
# if fieldnm.endswith("_tks"):
|
394 |
+
# v = rmSpace(v)
|
395 |
+
m[fieldnm] = v
|
396 |
+
res_fields[id] = m
|
397 |
+
return res_fields
|
398 |
+
|
399 |
+
def getHighlight(self, res, keywords: List[str], fieldnm: str):
|
400 |
+
ans = {}
|
401 |
+
num_rows = len(res)
|
402 |
+
column_id = res["id"]
|
403 |
+
for i in range(num_rows):
|
404 |
+
id = column_id[i]
|
405 |
+
txt = res[fieldnm][i]
|
406 |
+
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
|
407 |
+
txts = []
|
408 |
+
for t in re.split(r"[.?!;\n]", txt):
|
409 |
+
for w in keywords:
|
410 |
+
t = re.sub(
|
411 |
+
r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"
|
412 |
+
% re.escape(w),
|
413 |
+
r"\1<em>\2</em>\3",
|
414 |
+
t,
|
415 |
+
flags=re.IGNORECASE | re.MULTILINE,
|
416 |
+
)
|
417 |
+
if not re.search(
|
418 |
+
r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
|
419 |
+
):
|
420 |
+
continue
|
421 |
+
txts.append(t)
|
422 |
+
ans[id] = "...".join(txts)
|
423 |
+
return ans
|
424 |
+
|
425 |
+
def getAggregation(self, res, fieldnm: str):
|
426 |
+
"""
|
427 |
+
TODO: Infinity doesn't provide aggregation
|
428 |
+
"""
|
429 |
+
return list()
|
430 |
+
|
431 |
+
"""
|
432 |
+
SQL
|
433 |
+
"""
|
434 |
+
|
435 |
+
def sql(self, sql: str, fetch_size: int, format: str):
|
436 |
+
raise NotImplementedError("Not implemented")
|
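Before moving on to the SDK changes, a quick note on the filter helper above: `equivalent_condition_to_str()` is what both `update()` and `delete()` feed to Infinity's filter syntax. A short, self-contained illustration (values are made up):

```python
from rag.utils.infinity_conn import equivalent_condition_to_str

# String values become quoted equality tests, lists become IN (...), and the
# clauses are joined with AND.
print(equivalent_condition_to_str({"doc_id": "doc-0001", "available_int": 1}))
# doc_id='doc-0001' AND available_int=1
print(equivalent_condition_to_str({"kb_id": ["kb-0001", "kb-0002"]}))
# kb_id IN ('kb-0001', 'kb-0002')
```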
sdk/python/ragflow_sdk/modules/document.py
CHANGED
@@ -50,8 +50,8 @@ class Document(Base):
|
|
50 |
return res.content
|
51 |
|
52 |
|
53 |
-
def list_chunks(self,page=1, page_size=30, keywords=""
|
54 |
-
data={"keywords": keywords,"page":page,"page_size":page_size
|
55 |
res = self.get(f'/datasets/{self.dataset_id}/documents/{self.id}/chunks', data)
|
56 |
res = res.json()
|
57 |
if res.get("code") == 0:
|
50 |
return res.content
|
51 |
|
52 |
|
53 |
+
def list_chunks(self,page=1, page_size=30, keywords=""):
|
54 |
+
data={"keywords": keywords,"page":page,"page_size":page_size}
|
55 |
res = self.get(f'/datasets/{self.dataset_id}/documents/{self.id}/chunks', data)
|
56 |
res = res.json()
|
57 |
if res.get("code") == 0:
|
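With the signature repaired, `list_chunks()` can be called as below. This is a sketch: the api_key, base_url and dataset name are placeholders.

```python
from ragflow_sdk import RAGFlow

rag = RAGFlow(api_key="<YOUR_API_KEY>", base_url="http://127.0.0.1:9380")
ds = rag.list_datasets(name="demo")[0]
doc = ds.list_documents()[0]
chunks = doc.list_chunks(page=1, page_size=30, keywords="")
print(len(chunks))
```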
sdk/python/test/t_chunk.py
CHANGED
@@ -126,6 +126,7 @@ def test_delete_chunk_with_success(get_api_key_fixture):
|
|
126 |
docs = ds.upload_documents(documents)
|
127 |
doc = docs[0]
|
128 |
chunk = doc.add_chunk(content="This is a chunk addition test")
|
129 |
doc.delete_chunks([chunk.id])
|
130 |
|
131 |
|
@@ -146,6 +147,8 @@ def test_update_chunk_content(get_api_key_fixture):
|
|
146 |
docs = ds.upload_documents(documents)
|
147 |
doc = docs[0]
|
148 |
chunk = doc.add_chunk(content="This is a chunk addition test")
|
149 |
chunk.update({"content":"This is a updated content"})
|
150 |
|
151 |
def test_update_chunk_available(get_api_key_fixture):
|
@@ -165,7 +168,9 @@ def test_update_chunk_available(get_api_key_fixture):
|
|
165 |
docs = ds.upload_documents(documents)
|
166 |
doc = docs[0]
|
167 |
chunk = doc.add_chunk(content="This is a chunk addition test")
|
168 |
-
chunk
|
169 |
|
170 |
|
171 |
def test_retrieve_chunks(get_api_key_fixture):
|
126 |
docs = ds.upload_documents(documents)
|
127 |
doc = docs[0]
|
128 |
chunk = doc.add_chunk(content="This is a chunk addition test")
|
129 |
+
sleep(5)
|
130 |
doc.delete_chunks([chunk.id])
|
131 |
|
132 |
|
147 |
docs = ds.upload_documents(documents)
|
148 |
doc = docs[0]
|
149 |
chunk = doc.add_chunk(content="This is a chunk addition test")
|
150 |
+
# For Elasticsearch, the chunk is not searchable for a short time (~2s) after indexing.
|
151 |
+
sleep(3)
|
152 |
chunk.update({"content":"This is a updated content"})
|
153 |
|
154 |
def test_update_chunk_available(get_api_key_fixture):
|
168 |
docs = ds.upload_documents(documents)
|
169 |
doc = docs[0]
|
170 |
chunk = doc.add_chunk(content="This is a chunk addition test")
|
171 |
+
# For Elasticsearch, the chunk is not searchable for a short time (~2s) after indexing.
|
172 |
+
sleep(3)
|
173 |
+
chunk.update({"available":0})
|
174 |
|
175 |
|
176 |
def test_retrieve_chunks(get_api_key_fixture):
|
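The fixed sleeps paper over Elasticsearch's near-real-time refresh: a freshly added chunk only becomes searchable after the next index refresh (about one second by default). If these tests ever flake, a small polling helper along these lines (hypothetical, not part of the PR) is a more robust alternative:

```python
import time

def wait_until(predicate, timeout=10.0, interval=0.5):
    """Poll `predicate` until it returns True or `timeout` seconds elapse."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        if predicate():
            return True
        time.sleep(interval)
    return False
```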