zhichyu committed
Commit b691127 · 1 Parent(s): ad5587c

Integration with Infinity (#2894)


### What problem does this PR solve?

Integration with Infinity

- Replaced `ELASTICSEARCH` with `docStoreConn`
- Renamed `deleteByQuery` to `delete`
- Renamed `bulk` to `upsertBulk`
- Added `getHighlight` and `getAggregation`
- Fixed `KGSearch.search`
- Moved `Dealer.sql_retrieval` to `es_conn.py`

A sketch of the `docStoreConn` interface these changes imply appears after the checklist below.


### Type of change

- [x] Refactoring
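
For orientation, here is a minimal sketch of the document-store abstraction this PR swaps in for the old `ELASTICSEARCH` singleton, reconstructed from the call sites visible in the diff below. The class name `DocStoreConnection`, the return types, and the docstrings are assumptions; only the method names and argument order are taken from the diff.

```python
# Hedged sketch (not the actual source) of the interface behind api.settings.docStoreConn,
# inferred from how this diff calls it.
from abc import ABC, abstractmethod
from typing import Optional


class DocStoreConnection(ABC):
    @abstractmethod
    def health(self) -> dict:
        """Backend health info; system_app.py stores it under res["doc_store"]."""

    @abstractmethod
    def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
        """True if the per-tenant index exists for the given knowledge base."""

    @abstractmethod
    def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int) -> None:
        """Create the index; vectorSize comes from the embedding dimension."""

    @abstractmethod
    def deleteIdx(self, indexName: str, knowledgebaseId: str) -> None:
        """Drop the index for one knowledge base (see kb_app.py rm())."""

    @abstractmethod
    def get(self, chunkId: str, indexName: str, knowledgebaseIds: list) -> Optional[dict]:
        """Fetch one chunk by id, or None if it does not exist."""

    @abstractmethod
    def insert(self, documents: list, indexName: str, knowledgebaseId: str):
        """Insert/upsert a batch of chunk dicts (replaces ELASTICSEARCH.upsert / bulk)."""

    @abstractmethod
    def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
        """Update fields on chunks matching condition, e.g. {"doc_id": ...} or {"id": [...]}."""

    @abstractmethod
    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        """Delete chunks matching condition; the diff treats the return value as a count."""
```

Both `rag.utils.es_conn.ESConnection` and `rag.utils.infinity_conn.InfinityConnection` are expected to expose this surface; as shown in the `api/settings.py` hunk below, one of them is instantiated as `docStoreConn` depending on whether the `es` section of the service config carries a `username`, and it is then passed to `search.Dealer` and `kg_search.KGSearch`.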

.github/workflows/tests.yml CHANGED
@@ -78,7 +78,7 @@ jobs:
             echo "Waiting for service to be available..."
             sleep 5
           done
-          cd sdk/python && poetry install && source .venv/bin/activate && cd test && pytest t_dataset.py t_chat.py t_session.py t_document.py t_chunk.py
+          cd sdk/python && poetry install && source .venv/bin/activate && cd test && pytest --tb=short t_dataset.py t_chat.py t_session.py t_document.py t_chunk.py

       - name: Stop ragflow:dev
         if: always() # always run this step even if previous steps failed
README.md CHANGED
@@ -285,7 +285,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
    git clone https://github.com/infiniflow/ragflow.git
    cd ragflow/
    export POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true
-   ~/.local/bin/poetry install --sync --no-root # install RAGFlow dependent python modules
+   ~/.local/bin/poetry install --sync --no-root --with=full # install RAGFlow dependent python modules
    ```

 3. Launch the dependent services (MinIO, Elasticsearch, Redis, and MySQL) using Docker Compose:
@@ -295,7 +295,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .

    Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
    ```
-   127.0.0.1 es01 mysql minio redis
+   127.0.0.1 es01 infinity mysql minio redis
    ```
    In **docker/service_conf.yaml**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**.

README_ja.md CHANGED
@@ -250,7 +250,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .

    `/etc/hosts` に以下の行を追加して、**docker/service_conf.yaml** に指定されたすべてのホストを `127.0.0.1` に解決します:
    ```
-   127.0.0.1 es01 mysql minio redis
+   127.0.0.1 es01 infinity mysql minio redis
    ```
    **docker/service_conf.yaml** で mysql のポートを `5455` に、es のポートを `1200` に更新します(**docker/.env** に指定された通り).

README_ko.md CHANGED
@@ -254,7 +254,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .

    `/etc/hosts` 에 다음 줄을 추가하여 **docker/service_conf.yaml** 에 지정된 모든 호스트를 `127.0.0.1` 로 해결합니다:
    ```
-   127.0.0.1 es01 mysql minio redis
+   127.0.0.1 es01 infinity mysql minio redis
    ```
    **docker/service_conf.yaml** 에서 mysql 포트를 `5455` 로, es 포트를 `1200` 으로 업데이트합니다( **docker/.env** 에 지정된 대로).

README_zh.md CHANGED
@@ -252,7 +252,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .

    在 `/etc/hosts` 中添加以下代码,将 **docker/service_conf.yaml** 文件中的所有 host 地址都解析为 `127.0.0.1`:
    ```
-   127.0.0.1 es01 mysql minio redis
+   127.0.0.1 es01 infinity mysql minio redis
    ```
    在文件 **docker/service_conf.yaml** 中,对照 **docker/.env** 的配置将 mysql 端口更新为 `5455`,es 端口更新为 `1200`。

api/apps/api_app.py CHANGED
@@ -529,13 +529,14 @@ def list_chunks():
            return get_json_result(
                data=False, message="Can't find doc_name or doc_id"
            )
+        kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)

-        res = retrievaler.chunk_list(doc_id=doc_id, tenant_id=tenant_id)
+        res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
         res = [
             {
                 "content": res_item["content_with_weight"],
                 "doc_name": res_item["docnm_kwd"],
-                "img_id": res_item["img_id"]
+                "image_id": res_item["img_id"]
             } for res_item in res
         ]

api/apps/chunk_app.py CHANGED
@@ -18,12 +18,10 @@ import json
18
 
19
  from flask import request
20
  from flask_login import login_required, current_user
21
- from elasticsearch_dsl import Q
22
 
23
  from api.db.services.dialog_service import keyword_extraction
24
  from rag.app.qa import rmPrefix, beAdoc
25
  from rag.nlp import search, rag_tokenizer
26
- from rag.utils.es_conn import ELASTICSEARCH
27
  from rag.utils import rmSpace
28
  from api.db import LLMType, ParserType
29
  from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -31,12 +29,11 @@ from api.db.services.llm_service import LLMBundle
31
  from api.db.services.user_service import UserTenantService
32
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
33
  from api.db.services.document_service import DocumentService
34
- from api.settings import RetCode, retrievaler, kg_retrievaler
35
  from api.utils.api_utils import get_json_result
36
  import hashlib
37
  import re
38
 
39
-
40
  @manager.route('/list', methods=['POST'])
41
  @login_required
42
  @validate_request("doc_id")
@@ -53,12 +50,13 @@ def list_chunk():
53
  e, doc = DocumentService.get_by_id(doc_id)
54
  if not e:
55
  return get_data_error_result(message="Document not found!")
 
56
  query = {
57
  "doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
58
  }
59
  if "available_int" in req:
60
  query["available_int"] = int(req["available_int"])
61
- sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
62
  res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
63
  for id in sres.ids:
64
  d = {
@@ -69,16 +67,12 @@ def list_chunk():
69
  "doc_id": sres.field[id]["doc_id"],
70
  "docnm_kwd": sres.field[id]["docnm_kwd"],
71
  "important_kwd": sres.field[id].get("important_kwd", []),
72
- "img_id": sres.field[id].get("img_id", ""),
73
  "available_int": sres.field[id].get("available_int", 1),
74
- "positions": sres.field[id].get("position_int", "").split("\t")
75
  }
76
- if len(d["positions"]) % 5 == 0:
77
- poss = []
78
- for i in range(0, len(d["positions"]), 5):
79
- poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
80
- float(d["positions"][i + 3]), float(d["positions"][i + 4])])
81
- d["positions"] = poss
82
  res["chunks"].append(d)
83
  return get_json_result(data=res)
84
  except Exception as e:
@@ -96,22 +90,20 @@ def get():
96
  tenants = UserTenantService.query(user_id=current_user.id)
97
  if not tenants:
98
  return get_data_error_result(message="Tenant not found!")
99
- res = ELASTICSEARCH.get(
100
- chunk_id, search.index_name(
101
- tenants[0].tenant_id))
102
- if not res.get("found"):
 
103
  return server_error_response("Chunk not found")
104
- id = res["_id"]
105
- res = res["_source"]
106
- res["chunk_id"] = id
107
  k = []
108
- for n in res.keys():
109
  if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
110
  k.append(n)
111
  for n in k:
112
- del res[n]
113
 
114
- return get_json_result(data=res)
115
  except Exception as e:
116
  if str(e).find("NotFoundError") >= 0:
117
  return get_json_result(data=False, message='Chunk not found!',
@@ -162,7 +154,7 @@ def set():
162
  v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
163
  v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
164
  d["q_%d_vec" % len(v)] = v.tolist()
165
- ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
166
  return get_json_result(data=True)
167
  except Exception as e:
168
  return server_error_response(e)
@@ -174,11 +166,11 @@ def set():
174
  def switch():
175
  req = request.json
176
  try:
177
- tenant_id = DocumentService.get_tenant_id(req["doc_id"])
178
- if not tenant_id:
179
- return get_data_error_result(message="Tenant not found!")
180
- if not ELASTICSEARCH.upsert([{"id": i, "available_int": int(req["available_int"])} for i in req["chunk_ids"]],
181
- search.index_name(tenant_id)):
182
  return get_data_error_result(message="Index updating failure")
183
  return get_json_result(data=True)
184
  except Exception as e:
@@ -191,12 +183,11 @@ def switch():
191
  def rm():
192
  req = request.json
193
  try:
194
- if not ELASTICSEARCH.deleteByQuery(
195
- Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
196
- return get_data_error_result(message="Index updating failure")
197
  e, doc = DocumentService.get_by_id(req["doc_id"])
198
  if not e:
199
  return get_data_error_result(message="Document not found!")
 
 
200
  deleted_chunk_ids = req["chunk_ids"]
201
  chunk_number = len(deleted_chunk_ids)
202
  DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
@@ -239,7 +230,7 @@ def create():
239
  v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
240
  v = 0.1 * v[0] + 0.9 * v[1]
241
  d["q_%d_vec" % len(v)] = v.tolist()
242
- ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
243
 
244
  DocumentService.increment_chunk_num(
245
  doc.id, doc.kb_id, c, 1, 0)
@@ -256,8 +247,9 @@ def retrieval_test():
256
  page = int(req.get("page", 1))
257
  size = int(req.get("size", 30))
258
  question = req["question"]
259
- kb_id = req["kb_id"]
260
- if isinstance(kb_id, str): kb_id = [kb_id]
 
261
  doc_ids = req.get("doc_ids", [])
262
  similarity_threshold = float(req.get("similarity_threshold", 0.0))
263
  vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
@@ -265,17 +257,17 @@ def retrieval_test():
265
 
266
  try:
267
  tenants = UserTenantService.query(user_id=current_user.id)
268
- for kid in kb_id:
269
  for tenant in tenants:
270
  if KnowledgebaseService.query(
271
- tenant_id=tenant.tenant_id, id=kid):
272
  break
273
  else:
274
  return get_json_result(
275
  data=False, message='Only owner of knowledgebase authorized for this operation.',
276
  code=RetCode.OPERATING_ERROR)
277
 
278
- e, kb = KnowledgebaseService.get_by_id(kb_id[0])
279
  if not e:
280
  return get_data_error_result(message="Knowledgebase not found!")
281
 
@@ -290,7 +282,7 @@ def retrieval_test():
290
  question += keyword_extraction(chat_mdl, question)
291
 
292
  retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
293
- ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_id, page, size,
294
  similarity_threshold, vector_similarity_weight, top,
295
  doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
296
  for c in ranks["chunks"]:
@@ -309,12 +301,16 @@ def retrieval_test():
309
  @login_required
310
  def knowledge_graph():
311
  doc_id = request.args["doc_id"]
 
 
 
 
 
312
  req = {
313
  "doc_ids":[doc_id],
314
  "knowledge_graph_kwd": ["graph", "mind_map"]
315
  }
316
- tenant_id = DocumentService.get_tenant_id(doc_id)
317
- sres = retrievaler.search(req, search.index_name(tenant_id))
318
  obj = {"graph": {}, "mind_map": {}}
319
  for id in sres.ids[:2]:
320
  ty = sres.field[id]["knowledge_graph_kwd"]
 
18
 
19
  from flask import request
20
  from flask_login import login_required, current_user
 
21
 
22
  from api.db.services.dialog_service import keyword_extraction
23
  from rag.app.qa import rmPrefix, beAdoc
24
  from rag.nlp import search, rag_tokenizer
 
25
  from rag.utils import rmSpace
26
  from api.db import LLMType, ParserType
27
  from api.db.services.knowledgebase_service import KnowledgebaseService
 
29
  from api.db.services.user_service import UserTenantService
30
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
31
  from api.db.services.document_service import DocumentService
32
+ from api.settings import RetCode, retrievaler, kg_retrievaler, docStoreConn
33
  from api.utils.api_utils import get_json_result
34
  import hashlib
35
  import re
36
 
 
37
  @manager.route('/list', methods=['POST'])
38
  @login_required
39
  @validate_request("doc_id")
 
50
  e, doc = DocumentService.get_by_id(doc_id)
51
  if not e:
52
  return get_data_error_result(message="Document not found!")
53
+ kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
54
  query = {
55
  "doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
56
  }
57
  if "available_int" in req:
58
  query["available_int"] = int(req["available_int"])
59
+ sres = retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
60
  res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
61
  for id in sres.ids:
62
  d = {
 
67
  "doc_id": sres.field[id]["doc_id"],
68
  "docnm_kwd": sres.field[id]["docnm_kwd"],
69
  "important_kwd": sres.field[id].get("important_kwd", []),
70
+ "image_id": sres.field[id].get("img_id", ""),
71
  "available_int": sres.field[id].get("available_int", 1),
72
+ "positions": json.loads(sres.field[id].get("position_list", "[]")),
73
  }
74
+ assert isinstance(d["positions"], list)
75
+ assert len(d["positions"])==0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
 
 
 
 
76
  res["chunks"].append(d)
77
  return get_json_result(data=res)
78
  except Exception as e:
 
90
  tenants = UserTenantService.query(user_id=current_user.id)
91
  if not tenants:
92
  return get_data_error_result(message="Tenant not found!")
93
+ tenant_id = tenants[0].tenant_id
94
+
95
+ kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
96
+ chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
97
+ if chunk is None:
98
  return server_error_response("Chunk not found")
 
 
 
99
  k = []
100
+ for n in chunk.keys():
101
  if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
102
  k.append(n)
103
  for n in k:
104
+ del chunk[n]
105
 
106
+ return get_json_result(data=chunk)
107
  except Exception as e:
108
  if str(e).find("NotFoundError") >= 0:
109
  return get_json_result(data=False, message='Chunk not found!',
 
154
  v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
155
  v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
156
  d["q_%d_vec" % len(v)] = v.tolist()
157
+ docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
158
  return get_json_result(data=True)
159
  except Exception as e:
160
  return server_error_response(e)
 
166
  def switch():
167
  req = request.json
168
  try:
169
+ e, doc = DocumentService.get_by_id(req["doc_id"])
170
+ if not e:
171
+ return get_data_error_result(message="Document not found!")
172
+ if not docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
173
+ search.index_name(doc.tenant_id), doc.kb_id):
174
  return get_data_error_result(message="Index updating failure")
175
  return get_json_result(data=True)
176
  except Exception as e:
 
183
  def rm():
184
  req = request.json
185
  try:
 
 
 
186
  e, doc = DocumentService.get_by_id(req["doc_id"])
187
  if not e:
188
  return get_data_error_result(message="Document not found!")
189
+ if not docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
190
+ return get_data_error_result(message="Index updating failure")
191
  deleted_chunk_ids = req["chunk_ids"]
192
  chunk_number = len(deleted_chunk_ids)
193
  DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
 
230
  v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
231
  v = 0.1 * v[0] + 0.9 * v[1]
232
  d["q_%d_vec" % len(v)] = v.tolist()
233
+ docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
234
 
235
  DocumentService.increment_chunk_num(
236
  doc.id, doc.kb_id, c, 1, 0)
 
247
  page = int(req.get("page", 1))
248
  size = int(req.get("size", 30))
249
  question = req["question"]
250
+ kb_ids = req["kb_id"]
251
+ if isinstance(kb_ids, str):
252
+ kb_ids = [kb_ids]
253
  doc_ids = req.get("doc_ids", [])
254
  similarity_threshold = float(req.get("similarity_threshold", 0.0))
255
  vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
 
257
 
258
  try:
259
  tenants = UserTenantService.query(user_id=current_user.id)
260
+ for kb_id in kb_ids:
261
  for tenant in tenants:
262
  if KnowledgebaseService.query(
263
+ tenant_id=tenant.tenant_id, id=kb_id):
264
  break
265
  else:
266
  return get_json_result(
267
  data=False, message='Only owner of knowledgebase authorized for this operation.',
268
  code=RetCode.OPERATING_ERROR)
269
 
270
+ e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
271
  if not e:
272
  return get_data_error_result(message="Knowledgebase not found!")
273
 
 
282
  question += keyword_extraction(chat_mdl, question)
283
 
284
  retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
285
+ ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
286
  similarity_threshold, vector_similarity_weight, top,
287
  doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
288
  for c in ranks["chunks"]:
 
301
  @login_required
302
  def knowledge_graph():
303
  doc_id = request.args["doc_id"]
304
+ e, doc = DocumentService.get_by_id(doc_id)
305
+ if not e:
306
+ return get_data_error_result(message="Document not found!")
307
+ tenant_id = DocumentService.get_tenant_id(doc_id)
308
+ kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
309
  req = {
310
  "doc_ids":[doc_id],
311
  "knowledge_graph_kwd": ["graph", "mind_map"]
312
  }
313
+ sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids, doc.kb_id)
 
314
  obj = {"graph": {}, "mind_map": {}}
315
  for id in sres.ids[:2]:
316
  ty = sres.field[id]["knowledge_graph_kwd"]
api/apps/document_app.py CHANGED
@@ -17,7 +17,6 @@ import pathlib
17
  import re
18
 
19
  import flask
20
- from elasticsearch_dsl import Q
21
  from flask import request
22
  from flask_login import login_required, current_user
23
 
@@ -27,14 +26,13 @@ from api.db.services.file_service import FileService
27
  from api.db.services.task_service import TaskService, queue_tasks
28
  from api.db.services.user_service import UserTenantService
29
  from rag.nlp import search
30
- from rag.utils.es_conn import ELASTICSEARCH
31
  from api.db.services import duplicate_name
32
  from api.db.services.knowledgebase_service import KnowledgebaseService
33
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
34
  from api.utils import get_uuid
35
  from api.db import FileType, TaskStatus, ParserType, FileSource
36
  from api.db.services.document_service import DocumentService, doc_upload_and_parse
37
- from api.settings import RetCode
38
  from api.utils.api_utils import get_json_result
39
  from rag.utils.storage_factory import STORAGE_IMPL
40
  from api.utils.file_utils import filename_type, thumbnail
@@ -275,18 +273,8 @@ def change_status():
275
  return get_data_error_result(
276
  message="Database error (Document update)!")
277
 
278
- if str(req["status"]) == "0":
279
- ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
280
- scripts="ctx._source.available_int=0;",
281
- idxnm=search.index_name(
282
- kb.tenant_id)
283
- )
284
- else:
285
- ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
286
- scripts="ctx._source.available_int=1;",
287
- idxnm=search.index_name(
288
- kb.tenant_id)
289
- )
290
  return get_json_result(data=True)
291
  except Exception as e:
292
  return server_error_response(e)
@@ -365,8 +353,11 @@ def run():
365
  tenant_id = DocumentService.get_tenant_id(id)
366
  if not tenant_id:
367
  return get_data_error_result(message="Tenant not found!")
368
- ELASTICSEARCH.deleteByQuery(
369
- Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
 
 
 
370
 
371
  if str(req["run"]) == TaskStatus.RUNNING.value:
372
  TaskService.filter_delete([Task.doc_id == id])
@@ -490,8 +481,8 @@ def change_parser():
490
  tenant_id = DocumentService.get_tenant_id(req["doc_id"])
491
  if not tenant_id:
492
  return get_data_error_result(message="Tenant not found!")
493
- ELASTICSEARCH.deleteByQuery(
494
- Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
495
 
496
  return get_json_result(data=True)
497
  except Exception as e:
 
17
  import re
18
 
19
  import flask
 
20
  from flask import request
21
  from flask_login import login_required, current_user
22
 
 
26
  from api.db.services.task_service import TaskService, queue_tasks
27
  from api.db.services.user_service import UserTenantService
28
  from rag.nlp import search
 
29
  from api.db.services import duplicate_name
30
  from api.db.services.knowledgebase_service import KnowledgebaseService
31
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
32
  from api.utils import get_uuid
33
  from api.db import FileType, TaskStatus, ParserType, FileSource
34
  from api.db.services.document_service import DocumentService, doc_upload_and_parse
35
+ from api.settings import RetCode, docStoreConn
36
  from api.utils.api_utils import get_json_result
37
  from rag.utils.storage_factory import STORAGE_IMPL
38
  from api.utils.file_utils import filename_type, thumbnail
 
273
  return get_data_error_result(
274
  message="Database error (Document update)!")
275
 
276
+ status = int(req["status"])
277
+ docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
 
 
 
 
 
 
 
 
 
 
278
  return get_json_result(data=True)
279
  except Exception as e:
280
  return server_error_response(e)
 
353
  tenant_id = DocumentService.get_tenant_id(id)
354
  if not tenant_id:
355
  return get_data_error_result(message="Tenant not found!")
356
+ e, doc = DocumentService.get_by_id(id)
357
+ if not e:
358
+ return get_data_error_result(message="Document not found!")
359
+ if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
360
+ docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
361
 
362
  if str(req["run"]) == TaskStatus.RUNNING.value:
363
  TaskService.filter_delete([Task.doc_id == id])
 
481
  tenant_id = DocumentService.get_tenant_id(req["doc_id"])
482
  if not tenant_id:
483
  return get_data_error_result(message="Tenant not found!")
484
+ if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
485
+ docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
486
 
487
  return get_json_result(data=True)
488
  except Exception as e:
api/apps/kb_app.py CHANGED
@@ -28,6 +28,8 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.db_models import File
 from api.settings import RetCode
 from api.utils.api_utils import get_json_result
+from api.settings import docStoreConn
+from rag.nlp import search


 @manager.route('/create', methods=['post'])
@@ -166,6 +168,9 @@ def rm():
         if not KnowledgebaseService.delete_by_id(req["kb_id"]):
             return get_data_error_result(
                 message="Database error (Knowledgebase removal)!")
+        tenants = UserTenantService.query(user_id=current_user.id)
+        for tenant in tenants:
+            docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
api/apps/sdk/doc.py CHANGED
@@ -30,7 +30,6 @@ from api.db.services.task_service import TaskService, queue_tasks
30
  from api.utils.api_utils import server_error_response
31
  from api.utils.api_utils import get_result, get_error_data_result
32
  from io import BytesIO
33
- from elasticsearch_dsl import Q
34
  from flask import request, send_file
35
  from api.db import FileSource, TaskStatus, FileType
36
  from api.db.db_models import File
@@ -42,7 +41,7 @@ from api.settings import RetCode, retrievaler
42
  from api.utils.api_utils import construct_json_result, get_parser_config
43
  from rag.nlp import search
44
  from rag.utils import rmSpace
45
- from rag.utils.es_conn import ELASTICSEARCH
46
  from rag.utils.storage_factory import STORAGE_IMPL
47
  import os
48
 
@@ -293,9 +292,7 @@ def update_doc(tenant_id, dataset_id, document_id):
293
  )
294
  if not e:
295
  return get_error_data_result(message="Document not found!")
296
- ELASTICSEARCH.deleteByQuery(
297
- Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id)
298
- )
299
 
300
  return get_result()
301
 
@@ -647,9 +644,7 @@ def parse(tenant_id, dataset_id):
647
  info["chunk_num"] = 0
648
  info["token_num"] = 0
649
  DocumentService.update_by_id(id, info)
650
- ELASTICSEARCH.deleteByQuery(
651
- Q("match", doc_id=id), idxnm=search.index_name(tenant_id)
652
- )
653
  TaskService.filter_delete([Task.doc_id == id])
654
  e, doc = DocumentService.get_by_id(id)
655
  doc = doc.to_dict()
@@ -713,9 +708,7 @@ def stop_parsing(tenant_id, dataset_id):
713
  )
714
  info = {"run": "2", "progress": 0, "chunk_num": 0}
715
  DocumentService.update_by_id(id, info)
716
- ELASTICSEARCH.deleteByQuery(
717
- Q("match", doc_id=id), idxnm=search.index_name(tenant_id)
718
- )
719
  return get_result()
720
 
721
 
@@ -812,7 +805,6 @@ def list_chunks(tenant_id, dataset_id, document_id):
812
  "question": question,
813
  "sort": True,
814
  }
815
- sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
816
  key_mapping = {
817
  "chunk_num": "chunk_count",
818
  "kb_id": "dataset_id",
@@ -833,51 +825,56 @@ def list_chunks(tenant_id, dataset_id, document_id):
833
  renamed_doc[new_key] = value
834
  if key == "run":
835
  renamed_doc["run"] = run_mapping.get(str(value))
836
- res = {"total": sres.total, "chunks": [], "doc": renamed_doc}
 
837
  origin_chunks = []
838
- sign = 0
839
- for id in sres.ids:
840
- d = {
841
- "chunk_id": id,
842
- "content_with_weight": (
843
- rmSpace(sres.highlight[id])
844
- if question and id in sres.highlight
845
- else sres.field[id].get("content_with_weight", "")
846
- ),
847
- "doc_id": sres.field[id]["doc_id"],
848
- "docnm_kwd": sres.field[id]["docnm_kwd"],
849
- "important_kwd": sres.field[id].get("important_kwd", []),
850
- "img_id": sres.field[id].get("img_id", ""),
851
- "available_int": sres.field[id].get("available_int", 1),
852
- "positions": sres.field[id].get("position_int", "").split("\t"),
853
- }
854
- if len(d["positions"]) % 5 == 0:
855
- poss = []
856
- for i in range(0, len(d["positions"]), 5):
857
- poss.append(
858
- [
859
- float(d["positions"][i]),
860
- float(d["positions"][i + 1]),
861
- float(d["positions"][i + 2]),
862
- float(d["positions"][i + 3]),
863
- float(d["positions"][i + 4]),
864
- ]
865
- )
866
- d["positions"] = poss
 
 
 
867
 
868
- origin_chunks.append(d)
 
 
 
 
 
 
869
  if req.get("id"):
870
- if req.get("id") == id:
871
- origin_chunks.clear()
872
- origin_chunks.append(d)
873
- sign = 1
874
- break
875
- if req.get("id"):
876
- if sign == 0:
877
- return get_error_data_result(f"Can't find this chunk {req.get('id')}")
878
  for chunk in origin_chunks:
879
  key_mapping = {
880
- "chunk_id": "id",
881
  "content_with_weight": "content",
882
  "doc_id": "document_id",
883
  "important_kwd": "important_keywords",
@@ -996,9 +993,9 @@ def add_chunk(tenant_id, dataset_id, document_id):
996
  )
997
  d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
998
  d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
999
- d["kb_id"] = [doc.kb_id]
1000
  d["docnm_kwd"] = doc.name
1001
- d["doc_id"] = doc.id
1002
  embd_id = DocumentService.get_embd_id(document_id)
1003
  embd_mdl = TenantLLMService.model_instance(
1004
  tenant_id, LLMType.EMBEDDING.value, embd_id
@@ -1006,14 +1003,12 @@ def add_chunk(tenant_id, dataset_id, document_id):
1006
  v, c = embd_mdl.encode([doc.name, req["content"]])
1007
  v = 0.1 * v[0] + 0.9 * v[1]
1008
  d["q_%d_vec" % len(v)] = v.tolist()
1009
- ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
1010
 
1011
  DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
1012
- d["chunk_id"] = chunk_id
1013
- d["kb_id"] = doc.kb_id
1014
  # rename keys
1015
  key_mapping = {
1016
- "chunk_id": "id",
1017
  "content_with_weight": "content",
1018
  "doc_id": "document_id",
1019
  "important_kwd": "important_keywords",
@@ -1079,36 +1074,16 @@ def rm_chunk(tenant_id, dataset_id, document_id):
1079
  """
1080
  if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
1081
  return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
1082
- doc = DocumentService.query(id=document_id, kb_id=dataset_id)
1083
- if not doc:
1084
- return get_error_data_result(
1085
- message=f"You don't own the document {document_id}."
1086
- )
1087
- doc = doc[0]
1088
  req = request.json
1089
- if not req.get("chunk_ids"):
1090
- return get_error_data_result("`chunk_ids` is required")
1091
- query = {"doc_ids": [doc.id], "page": 1, "size": 1024, "question": "", "sort": True}
1092
- sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
1093
- if not req:
1094
- chunk_ids = None
1095
- else:
1096
- chunk_ids = req.get("chunk_ids")
1097
- if not chunk_ids:
1098
- chunk_list = sres.ids
1099
- else:
1100
- chunk_list = chunk_ids
1101
- for chunk_id in chunk_list:
1102
- if chunk_id not in sres.ids:
1103
- return get_error_data_result(f"Chunk {chunk_id} not found")
1104
- if not ELASTICSEARCH.deleteByQuery(
1105
- Q("ids", values=chunk_list), search.index_name(tenant_id)
1106
- ):
1107
- return get_error_data_result(message="Index updating failure")
1108
- deleted_chunk_ids = chunk_list
1109
- chunk_number = len(deleted_chunk_ids)
1110
- DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
1111
- return get_result()
1112
 
1113
 
1114
  @manager.route(
@@ -1168,9 +1143,8 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
1168
  schema:
1169
  type: object
1170
  """
1171
- try:
1172
- res = ELASTICSEARCH.get(chunk_id, search.index_name(tenant_id))
1173
- except Exception:
1174
  return get_error_data_result(f"Can't find this chunk {chunk_id}")
1175
  if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
1176
  return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
@@ -1180,19 +1154,12 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
1180
  message=f"You don't own the document {document_id}."
1181
  )
1182
  doc = doc[0]
1183
- query = {
1184
- "doc_ids": [document_id],
1185
- "page": 1,
1186
- "size": 1024,
1187
- "question": "",
1188
- "sort": True,
1189
- }
1190
- sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
1191
- if chunk_id not in sres.ids:
1192
- return get_error_data_result(f"You don't own the chunk {chunk_id}")
1193
  req = request.json
1194
- content = res["_source"].get("content_with_weight")
1195
- d = {"id": chunk_id, "content_with_weight": req.get("content", content)}
 
 
 
1196
  d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"])
1197
  d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
1198
  if "important_keywords" in req:
@@ -1220,7 +1187,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
1220
  v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
1221
  v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
1222
  d["q_%d_vec" % len(v)] = v.tolist()
1223
- ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
1224
  return get_result()
1225
 
1226
 
 
30
  from api.utils.api_utils import server_error_response
31
  from api.utils.api_utils import get_result, get_error_data_result
32
  from io import BytesIO
 
33
  from flask import request, send_file
34
  from api.db import FileSource, TaskStatus, FileType
35
  from api.db.db_models import File
 
41
  from api.utils.api_utils import construct_json_result, get_parser_config
42
  from rag.nlp import search
43
  from rag.utils import rmSpace
44
+ from api.settings import docStoreConn
45
  from rag.utils.storage_factory import STORAGE_IMPL
46
  import os
47
 
 
292
  )
293
  if not e:
294
  return get_error_data_result(message="Document not found!")
295
+ docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
 
 
296
 
297
  return get_result()
298
 
 
644
  info["chunk_num"] = 0
645
  info["token_num"] = 0
646
  DocumentService.update_by_id(id, info)
647
+ docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
 
 
648
  TaskService.filter_delete([Task.doc_id == id])
649
  e, doc = DocumentService.get_by_id(id)
650
  doc = doc.to_dict()
 
708
  )
709
  info = {"run": "2", "progress": 0, "chunk_num": 0}
710
  DocumentService.update_by_id(id, info)
711
+ docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
 
 
712
  return get_result()
713
 
714
 
 
805
  "question": question,
806
  "sort": True,
807
  }
 
808
  key_mapping = {
809
  "chunk_num": "chunk_count",
810
  "kb_id": "dataset_id",
 
825
  renamed_doc[new_key] = value
826
  if key == "run":
827
  renamed_doc["run"] = run_mapping.get(str(value))
828
+
829
+ res = {"total": 0, "chunks": [], "doc": renamed_doc}
830
  origin_chunks = []
831
+ if docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
832
+ sres = retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
833
+ res["total"] = sres.total
834
+ sign = 0
835
+ for id in sres.ids:
836
+ d = {
837
+ "id": id,
838
+ "content_with_weight": (
839
+ rmSpace(sres.highlight[id])
840
+ if question and id in sres.highlight
841
+ else sres.field[id].get("content_with_weight", "")
842
+ ),
843
+ "doc_id": sres.field[id]["doc_id"],
844
+ "docnm_kwd": sres.field[id]["docnm_kwd"],
845
+ "important_kwd": sres.field[id].get("important_kwd", []),
846
+ "img_id": sres.field[id].get("img_id", ""),
847
+ "available_int": sres.field[id].get("available_int", 1),
848
+ "positions": sres.field[id].get("position_int", "").split("\t"),
849
+ }
850
+ if len(d["positions"]) % 5 == 0:
851
+ poss = []
852
+ for i in range(0, len(d["positions"]), 5):
853
+ poss.append(
854
+ [
855
+ float(d["positions"][i]),
856
+ float(d["positions"][i + 1]),
857
+ float(d["positions"][i + 2]),
858
+ float(d["positions"][i + 3]),
859
+ float(d["positions"][i + 4]),
860
+ ]
861
+ )
862
+ d["positions"] = poss
863
 
864
+ origin_chunks.append(d)
865
+ if req.get("id"):
866
+ if req.get("id") == id:
867
+ origin_chunks.clear()
868
+ origin_chunks.append(d)
869
+ sign = 1
870
+ break
871
  if req.get("id"):
872
+ if sign == 0:
873
+ return get_error_data_result(f"Can't find this chunk {req.get('id')}")
874
+
 
 
 
 
 
875
  for chunk in origin_chunks:
876
  key_mapping = {
877
+ "id": "id",
878
  "content_with_weight": "content",
879
  "doc_id": "document_id",
880
  "important_kwd": "important_keywords",
 
993
  )
994
  d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
995
  d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
996
+ d["kb_id"] = dataset_id
997
  d["docnm_kwd"] = doc.name
998
+ d["doc_id"] = document_id
999
  embd_id = DocumentService.get_embd_id(document_id)
1000
  embd_mdl = TenantLLMService.model_instance(
1001
  tenant_id, LLMType.EMBEDDING.value, embd_id
 
1003
  v, c = embd_mdl.encode([doc.name, req["content"]])
1004
  v = 0.1 * v[0] + 0.9 * v[1]
1005
  d["q_%d_vec" % len(v)] = v.tolist()
1006
+ docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
1007
 
1008
  DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
 
 
1009
  # rename keys
1010
  key_mapping = {
1011
+ "id": "id",
1012
  "content_with_weight": "content",
1013
  "doc_id": "document_id",
1014
  "important_kwd": "important_keywords",
 
1074
  """
1075
  if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
1076
  return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
 
 
 
 
 
 
1077
  req = request.json
1078
+ condition = {"doc_id": document_id}
1079
+ if "chunk_ids" in req:
1080
+ condition["id"] = req["chunk_ids"]
1081
+ chunk_number = docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
1082
+ if chunk_number != 0:
1083
+ DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
1084
+ if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
1085
+ return get_error_data_result(message=f"rm_chunk deleted chunks {chunk_number}, expect {len(req["chunk_ids"])}")
1086
+ return get_result(message=f"deleted {chunk_number} chunks")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1087
 
1088
 
1089
  @manager.route(
 
1143
  schema:
1144
  type: object
1145
  """
1146
+ chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
1147
+ if chunk is None:
 
1148
  return get_error_data_result(f"Can't find this chunk {chunk_id}")
1149
  if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
1150
  return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
 
1154
  message=f"You don't own the document {document_id}."
1155
  )
1156
  doc = doc[0]
 
 
 
 
 
 
 
 
 
 
1157
  req = request.json
1158
+ if "content" in req:
1159
+ content = req["content"]
1160
+ else:
1161
+ content = chunk.get("content_with_weight", "")
1162
+ d = {"id": chunk_id, "content_with_weight": content}
1163
  d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"])
1164
  d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
1165
  if "important_keywords" in req:
 
1187
  v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
1188
  v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
1189
  d["q_%d_vec" % len(v)] = v.tolist()
1190
+ docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
1191
  return get_result()
1192
 
1193
 
api/apps/system_app.py CHANGED
@@ -31,7 +31,7 @@ from api.utils.api_utils import (
     generate_confirmation_token,
 )
 from api.versions import get_rag_version
-from rag.utils.es_conn import ELASTICSEARCH
+from api.settings import docStoreConn
 from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
 from timeit import default_timer as timer

@@ -98,10 +98,11 @@ def status():
     res = {}
     st = timer()
     try:
-        res["es"] = ELASTICSEARCH.health()
-        res["es"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
+        res["doc_store"] = docStoreConn.health()
+        res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
     except Exception as e:
-        res["es"] = {
+        res["doc_store"] = {
+            "type": "unknown",
             "status": "red",
             "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
             "error": str(e),
api/db/db_models.py CHANGED
@@ -470,7 +470,7 @@ class User(DataBaseModel, UserMixin):
470
  status = CharField(
471
  max_length=1,
472
  null=True,
473
- help_text="is it validate(0: wasted1: validate)",
474
  default="1",
475
  index=True)
476
  is_superuser = BooleanField(null=True, help_text="is root", default=False, index=True)
@@ -525,7 +525,7 @@ class Tenant(DataBaseModel):
525
  status = CharField(
526
  max_length=1,
527
  null=True,
528
- help_text="is it validate(0: wasted1: validate)",
529
  default="1",
530
  index=True)
531
 
@@ -542,7 +542,7 @@ class UserTenant(DataBaseModel):
542
  status = CharField(
543
  max_length=1,
544
  null=True,
545
- help_text="is it validate(0: wasted1: validate)",
546
  default="1",
547
  index=True)
548
 
@@ -559,7 +559,7 @@ class InvitationCode(DataBaseModel):
559
  status = CharField(
560
  max_length=1,
561
  null=True,
562
- help_text="is it validate(0: wasted1: validate)",
563
  default="1",
564
  index=True)
565
 
@@ -582,7 +582,7 @@ class LLMFactories(DataBaseModel):
582
  status = CharField(
583
  max_length=1,
584
  null=True,
585
- help_text="is it validate(0: wasted1: validate)",
586
  default="1",
587
  index=True)
588
 
@@ -616,7 +616,7 @@ class LLM(DataBaseModel):
616
  status = CharField(
617
  max_length=1,
618
  null=True,
619
- help_text="is it validate(0: wasted1: validate)",
620
  default="1",
621
  index=True)
622
 
@@ -703,7 +703,7 @@ class Knowledgebase(DataBaseModel):
703
  status = CharField(
704
  max_length=1,
705
  null=True,
706
- help_text="is it validate(0: wasted1: validate)",
707
  default="1",
708
  index=True)
709
 
@@ -767,7 +767,7 @@ class Document(DataBaseModel):
767
  status = CharField(
768
  max_length=1,
769
  null=True,
770
- help_text="is it validate(0: wasted1: validate)",
771
  default="1",
772
  index=True)
773
 
@@ -904,7 +904,7 @@ class Dialog(DataBaseModel):
904
  status = CharField(
905
  max_length=1,
906
  null=True,
907
- help_text="is it validate(0: wasted1: validate)",
908
  default="1",
909
  index=True)
910
 
@@ -987,7 +987,7 @@ def migrate_db():
987
  help_text="where dose this document come from",
988
  index=True))
989
  )
990
- except Exception as e:
991
  pass
992
  try:
993
  migrate(
@@ -996,7 +996,7 @@ def migrate_db():
996
  help_text="default rerank model ID"))
997
 
998
  )
999
- except Exception as e:
1000
  pass
1001
  try:
1002
  migrate(
@@ -1004,59 +1004,59 @@ def migrate_db():
1004
  help_text="default rerank model ID"))
1005
 
1006
  )
1007
- except Exception as e:
1008
  pass
1009
  try:
1010
  migrate(
1011
  migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
1012
 
1013
  )
1014
- except Exception as e:
1015
  pass
1016
  try:
1017
  migrate(
1018
  migrator.alter_column_type('tenant_llm', 'api_key',
1019
  CharField(max_length=1024, null=True, help_text="API KEY", index=True))
1020
  )
1021
- except Exception as e:
1022
  pass
1023
  try:
1024
  migrate(
1025
  migrator.add_column('api_token', 'source',
1026
  CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
1027
  )
1028
- except Exception as e:
1029
  pass
1030
  try:
1031
  migrate(
1032
  migrator.add_column("tenant","tts_id",
1033
  CharField(max_length=256,null=True,help_text="default tts model ID",index=True))
1034
  )
1035
- except Exception as e:
1036
  pass
1037
  try:
1038
  migrate(
1039
  migrator.add_column('api_4_conversation', 'source',
1040
  CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
1041
  )
1042
- except Exception as e:
1043
  pass
1044
  try:
1045
  DB.execute_sql('ALTER TABLE llm DROP PRIMARY KEY;')
1046
  DB.execute_sql('ALTER TABLE llm ADD PRIMARY KEY (llm_name,fid);')
1047
- except Exception as e:
1048
  pass
1049
  try:
1050
  migrate(
1051
  migrator.add_column('task', 'retry_count', IntegerField(default=0))
1052
  )
1053
- except Exception as e:
1054
  pass
1055
  try:
1056
  migrate(
1057
  migrator.alter_column_type('api_token', 'dialog_id',
1058
  CharField(max_length=32, null=True, index=True))
1059
  )
1060
- except Exception as e:
1061
  pass
1062
 
 
470
  status = CharField(
471
  max_length=1,
472
  null=True,
473
+ help_text="is it validate(0: wasted, 1: validate)",
474
  default="1",
475
  index=True)
476
  is_superuser = BooleanField(null=True, help_text="is root", default=False, index=True)
 
525
  status = CharField(
526
  max_length=1,
527
  null=True,
528
+ help_text="is it validate(0: wasted, 1: validate)",
529
  default="1",
530
  index=True)
531
 
 
542
  status = CharField(
543
  max_length=1,
544
  null=True,
545
+ help_text="is it validate(0: wasted, 1: validate)",
546
  default="1",
547
  index=True)
548
 
 
559
  status = CharField(
560
  max_length=1,
561
  null=True,
562
+ help_text="is it validate(0: wasted, 1: validate)",
563
  default="1",
564
  index=True)
565
 
 
582
  status = CharField(
583
  max_length=1,
584
  null=True,
585
+ help_text="is it validate(0: wasted, 1: validate)",
586
  default="1",
587
  index=True)
588
 
 
616
  status = CharField(
617
  max_length=1,
618
  null=True,
619
+ help_text="is it validate(0: wasted, 1: validate)",
620
  default="1",
621
  index=True)
622
 
 
703
  status = CharField(
704
  max_length=1,
705
  null=True,
706
+ help_text="is it validate(0: wasted, 1: validate)",
707
  default="1",
708
  index=True)
709
 
 
767
  status = CharField(
768
  max_length=1,
769
  null=True,
770
+ help_text="is it validate(0: wasted, 1: validate)",
771
  default="1",
772
  index=True)
773
 
 
904
  status = CharField(
905
  max_length=1,
906
  null=True,
907
+ help_text="is it validate(0: wasted, 1: validate)",
908
  default="1",
909
  index=True)
910
 
 
987
  help_text="where dose this document come from",
988
  index=True))
989
  )
990
+ except Exception:
991
  pass
992
  try:
993
  migrate(
 
996
  help_text="default rerank model ID"))
997
 
998
  )
999
+ except Exception:
1000
  pass
1001
  try:
1002
  migrate(
 
1004
  help_text="default rerank model ID"))
1005
 
1006
  )
1007
+ except Exception:
1008
  pass
1009
  try:
1010
  migrate(
1011
  migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
1012
 
1013
  )
1014
+ except Exception:
1015
  pass
1016
  try:
1017
  migrate(
1018
  migrator.alter_column_type('tenant_llm', 'api_key',
1019
  CharField(max_length=1024, null=True, help_text="API KEY", index=True))
1020
  )
1021
+ except Exception:
1022
  pass
1023
  try:
1024
  migrate(
1025
  migrator.add_column('api_token', 'source',
1026
  CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
1027
  )
1028
+ except Exception:
1029
  pass
1030
  try:
1031
  migrate(
1032
  migrator.add_column("tenant","tts_id",
1033
  CharField(max_length=256,null=True,help_text="default tts model ID",index=True))
1034
  )
1035
+ except Exception:
1036
  pass
1037
  try:
1038
  migrate(
1039
  migrator.add_column('api_4_conversation', 'source',
1040
  CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
1041
  )
1042
+ except Exception:
1043
  pass
1044
  try:
1045
  DB.execute_sql('ALTER TABLE llm DROP PRIMARY KEY;')
1046
  DB.execute_sql('ALTER TABLE llm ADD PRIMARY KEY (llm_name,fid);')
1047
+ except Exception:
1048
  pass
1049
  try:
1050
  migrate(
1051
  migrator.add_column('task', 'retry_count', IntegerField(default=0))
1052
  )
1053
+ except Exception:
1054
  pass
1055
  try:
1056
  migrate(
1057
  migrator.alter_column_type('api_token', 'dialog_id',
1058
  CharField(max_length=32, null=True, index=True))
1059
  )
1060
+ except Exception:
1061
  pass
1062
 
api/db/services/document_service.py CHANGED
@@ -15,7 +15,6 @@
15
  #
16
  import hashlib
17
  import json
18
- import os
19
  import random
20
  import re
21
  import traceback
@@ -24,16 +23,13 @@ from copy import deepcopy
24
  from datetime import datetime
25
  from io import BytesIO
26
 
27
- from elasticsearch_dsl import Q
28
  from peewee import fn
29
 
30
  from api.db.db_utils import bulk_insert_into_db
31
- from api.settings import stat_logger
32
  from api.utils import current_timestamp, get_format_time, get_uuid
33
- from api.utils.file_utils import get_project_base_directory
34
  from graphrag.mind_map_extractor import MindMapExtractor
35
  from rag.settings import SVR_QUEUE_NAME
36
- from rag.utils.es_conn import ELASTICSEARCH
37
  from rag.utils.storage_factory import STORAGE_IMPL
38
  from rag.nlp import search, rag_tokenizer
39
 
@@ -112,8 +108,7 @@ class DocumentService(CommonService):
112
  @classmethod
113
  @DB.connection_context()
114
  def remove_document(cls, doc, tenant_id):
115
- ELASTICSEARCH.deleteByQuery(
116
- Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
117
  cls.clear_chunk_num(doc.id)
118
  return cls.delete_by_id(doc.id)
119
 
@@ -225,6 +220,15 @@ class DocumentService(CommonService):
225
  return
226
  return docs[0]["tenant_id"]
227
 
 
 
 
 
 
 
 
 
 
228
  @classmethod
229
  @DB.connection_context()
230
  def get_tenant_id_by_name(cls, name):
@@ -438,11 +442,6 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
438
  if not e:
439
  raise LookupError("Can't find this knowledgebase!")
440
 
441
- idxnm = search.index_name(kb.tenant_id)
442
- if not ELASTICSEARCH.indexExist(idxnm):
443
- ELASTICSEARCH.createIdx(idxnm, json.load(
444
- open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
445
-
446
  embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
447
 
448
  err, files = FileService.upload_document(kb, file_objs, user_id)
@@ -486,7 +485,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
486
  md5 = hashlib.md5()
487
  md5.update((ck["content_with_weight"] +
488
  str(d["doc_id"])).encode("utf-8"))
489
- d["_id"] = md5.hexdigest()
490
  d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
491
  d["create_timestamp_flt"] = datetime.now().timestamp()
492
  if not d.get("image"):
@@ -499,8 +498,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
499
  else:
500
  d["image"].save(output_buffer, format='JPEG')
501
 
502
- STORAGE_IMPL.put(kb.id, d["_id"], output_buffer.getvalue())
503
- d["img_id"] = "{}-{}".format(kb.id, d["_id"])
504
  del d["image"]
505
  docs.append(d)
506
 
@@ -520,6 +519,9 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
520
  token_counts[doc_id] += c
521
  return vects
522
 
 
 
 
523
  _, tenant = TenantService.get_by_id(kb.tenant_id)
524
  llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
525
  for doc_id in docids:
@@ -550,7 +552,11 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
550
  v = vects[i]
551
  d["q_%d_vec" % len(v)] = v
552
  for b in range(0, len(cks), es_bulk_size):
553
- ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm)
 
 
 
 
554
 
555
  DocumentService.increment_chunk_num(
556
  doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
 
15
  #
16
  import hashlib
17
  import json
 
18
  import random
19
  import re
20
  import traceback
 
23
  from datetime import datetime
24
  from io import BytesIO
25
 
 
26
  from peewee import fn
27
 
28
  from api.db.db_utils import bulk_insert_into_db
29
+ from api.settings import stat_logger, docStoreConn
30
  from api.utils import current_timestamp, get_format_time, get_uuid
 
31
  from graphrag.mind_map_extractor import MindMapExtractor
32
  from rag.settings import SVR_QUEUE_NAME
 
33
  from rag.utils.storage_factory import STORAGE_IMPL
34
  from rag.nlp import search, rag_tokenizer
35
 
 
108
  @classmethod
109
  @DB.connection_context()
110
  def remove_document(cls, doc, tenant_id):
111
+ docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
 
112
  cls.clear_chunk_num(doc.id)
113
  return cls.delete_by_id(doc.id)
114
 
 
220
  return
221
  return docs[0]["tenant_id"]
222
 
223
+ @classmethod
224
+ @DB.connection_context()
225
+ def get_knowledgebase_id(cls, doc_id):
226
+ docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
227
+ docs = docs.dicts()
228
+ if not docs:
229
+ return
230
+ return docs[0]["kb_id"]
231
+
232
  @classmethod
233
  @DB.connection_context()
234
  def get_tenant_id_by_name(cls, name):
 
442
  if not e:
443
  raise LookupError("Can't find this knowledgebase!")
444
 
 
 
 
 
 
445
  embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
446
 
447
  err, files = FileService.upload_document(kb, file_objs, user_id)
 
485
  md5 = hashlib.md5()
486
  md5.update((ck["content_with_weight"] +
487
  str(d["doc_id"])).encode("utf-8"))
488
+ d["id"] = md5.hexdigest()
489
  d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
490
  d["create_timestamp_flt"] = datetime.now().timestamp()
491
  if not d.get("image"):
 
498
  else:
499
  d["image"].save(output_buffer, format='JPEG')
500
 
501
+ STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
502
+ d["img_id"] = "{}-{}".format(kb.id, d["id"])
503
  del d["image"]
504
  docs.append(d)
505
 
 
519
  token_counts[doc_id] += c
520
  return vects
521
 
522
+ idxnm = search.index_name(kb.tenant_id)
523
+ try_create_idx = True
524
+
525
  _, tenant = TenantService.get_by_id(kb.tenant_id)
526
  llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
527
  for doc_id in docids:
 
552
  v = vects[i]
553
  d["q_%d_vec" % len(v)] = v
554
  for b in range(0, len(cks), es_bulk_size):
555
+ if try_create_idx:
556
+ if not docStoreConn.indexExist(idxnm, kb_id):
557
+ docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
558
+ try_create_idx = False
559
+ docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
560
 
561
  DocumentService.increment_chunk_num(
562
  doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
api/db/services/knowledgebase_service.py CHANGED
@@ -66,6 +66,16 @@ class KnowledgebaseService(CommonService):

         return list(kbs.dicts())

+    @classmethod
+    @DB.connection_context()
+    def get_kb_ids(cls, tenant_id):
+        fields = [
+            cls.model.id,
+        ]
+        kbs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
+        kb_ids = [kb["id"] for kb in kbs]
+        return kb_ids
+
     @classmethod
     @DB.connection_context()
     def get_detail(cls, kb_id):
api/settings.py CHANGED
@@ -18,6 +18,8 @@ from datetime import date
18
  from enum import IntEnum, Enum
19
  from api.utils.file_utils import get_project_base_directory
20
  from api.utils.log_utils import LoggerFactory, getLogger
 
 
21
 
22
  # Logger
23
  LoggerFactory.set_directory(
@@ -33,7 +35,7 @@ access_logger = getLogger("access")
33
  database_logger = getLogger("database")
34
  chat_logger = getLogger("chat")
35
 
36
- from rag.utils.es_conn import ELASTICSEARCH
37
  from rag.nlp import search
38
  from graphrag import search as kg_search
39
  from api.utils import get_base_config, decrypt_database_config
@@ -206,8 +208,12 @@ AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s
206
  PRIVILEGE_COMMAND_WHITELIST = []
207
  CHECK_NODES_IDENTITY = False
208
 
209
- retrievaler = search.Dealer(ELASTICSEARCH)
210
- kg_retrievaler = kg_search.KGSearch(ELASTICSEARCH)
 
 
 
 
211
 
212
 
213
  class CustomEnum(Enum):
 
18
  from enum import IntEnum, Enum
19
  from api.utils.file_utils import get_project_base_directory
20
  from api.utils.log_utils import LoggerFactory, getLogger
21
+ import rag.utils.es_conn
22
+ import rag.utils.infinity_conn
23
 
24
  # Logger
25
  LoggerFactory.set_directory(
 
35
  database_logger = getLogger("database")
36
  chat_logger = getLogger("chat")
37
 
38
+ import rag.utils
39
  from rag.nlp import search
40
  from graphrag import search as kg_search
41
  from api.utils import get_base_config, decrypt_database_config
 
208
  PRIVILEGE_COMMAND_WHITELIST = []
209
  CHECK_NODES_IDENTITY = False
210
 
211
+ if 'username' in get_base_config("es", {}):
212
+ docStoreConn = rag.utils.es_conn.ESConnection()
213
+ else:
214
+ docStoreConn = rag.utils.infinity_conn.InfinityConnection()
215
+ retrievaler = search.Dealer(docStoreConn)
216
+ kg_retrievaler = kg_search.KGSearch(docStoreConn)
217
 
218
 
219
  class CustomEnum(Enum):
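
With this wiring, the concrete store is chosen once at startup: Elasticsearch when an `es.username` entry exists in the base config, Infinity otherwise, and everything downstream talks to the shared `docStoreConn` object. A rough sketch of that selection pattern, with the config lookup and connection classes stubbed out (not RAGFlow's actual helpers):

```python
# Hypothetical stand-ins for the two connection classes wired up in api/settings.py.
class ESConnection:
    def name(self): return "elasticsearch"

class InfinityConnection:
    def name(self): return "infinity"

def pick_doc_store(base_config: dict):
    # Mirrors the diff: the presence of an ES username selects Elasticsearch,
    # otherwise the Infinity backend is used.
    if "username" in base_config.get("es", {}):
        return ESConnection()
    return InfinityConnection()

print(pick_doc_store({"es": {"username": "elastic"}}).name())  # elasticsearch
print(pick_doc_store({}).name())                               # infinity
```
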
api/utils/api_utils.py CHANGED
@@ -126,10 +126,6 @@ def server_error_response(e):
126
  if len(e.args) > 1:
127
  return get_json_result(
128
  code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
129
- if repr(e).find("index_not_found_exception") >= 0:
130
- return get_json_result(code=RetCode.EXCEPTION_ERROR,
131
- message="No chunk found, please upload file and parse it.")
132
-
133
  return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
134
 
135
 
@@ -270,10 +266,6 @@ def construct_error_response(e):
270
  pass
271
  if len(e.args) > 1:
272
  return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
273
- if repr(e).find("index_not_found_exception") >= 0:
274
- return construct_json_result(code=RetCode.EXCEPTION_ERROR,
275
- message="No chunk found, please upload file and parse it.")
276
-
277
  return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
278
 
279
 
@@ -295,7 +287,7 @@ def token_required(func):
295
  return decorated_function
296
 
297
 
298
- def get_result(code=RetCode.SUCCESS, message='error', data=None):
299
  if code == 0:
300
  if data is not None:
301
  response = {"code": code, "data": data}
 
126
  if len(e.args) > 1:
127
  return get_json_result(
128
  code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
 
 
 
 
129
  return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
130
 
131
 
 
266
  pass
267
  if len(e.args) > 1:
268
  return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
 
 
 
 
269
  return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
270
 
271
 
 
287
  return decorated_function
288
 
289
 
290
+ def get_result(code=RetCode.SUCCESS, message="", data=None):
291
  if code == 0:
292
  if data is not None:
293
  response = {"code": code, "data": data}
conf/infinity_mapping.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "id": {"type": "varchar", "default": ""},
3
+ "doc_id": {"type": "varchar", "default": ""},
4
+ "kb_id": {"type": "varchar", "default": ""},
5
+ "create_time": {"type": "varchar", "default": ""},
6
+ "create_timestamp_flt": {"type": "float", "default": 0.0},
7
+ "img_id": {"type": "varchar", "default": ""},
8
+ "docnm_kwd": {"type": "varchar", "default": ""},
9
+ "title_tks": {"type": "varchar", "default": ""},
10
+ "title_sm_tks": {"type": "varchar", "default": ""},
11
+ "name_kwd": {"type": "varchar", "default": ""},
12
+ "important_kwd": {"type": "varchar", "default": ""},
13
+ "important_tks": {"type": "varchar", "default": ""},
14
+ "content_with_weight": {"type": "varchar", "default": ""},
15
+ "content_ltks": {"type": "varchar", "default": ""},
16
+ "content_sm_ltks": {"type": "varchar", "default": ""},
17
+ "page_num_list": {"type": "varchar", "default": ""},
18
+ "top_list": {"type": "varchar", "default": ""},
19
+ "position_list": {"type": "varchar", "default": ""},
20
+ "weight_int": {"type": "integer", "default": 0},
21
+ "weight_flt": {"type": "float", "default": 0.0},
22
+ "rank_int": {"type": "integer", "default": 0},
23
+ "available_int": {"type": "integer", "default": 1},
24
+ "knowledge_graph_kwd": {"type": "varchar", "default": ""},
25
+ "entities_kwd": {"type": "varchar", "default": ""}
26
+ }
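
Unlike the Elasticsearch mapping, the Infinity schema is a flat field-to-type table, so it can be read and applied directly when a table is created. A small sketch of loading this file and listing the column definitions; the relative path is illustrative and the printout is not the project's actual loader:

```python
import json

# Illustrative path; the file ships as conf/infinity_mapping.json in the repo.
with open("conf/infinity_mapping.json", "r") as f:
    schema = json.load(f)

# Each entry maps a column name to its Infinity type and default value.
for name, spec in schema.items():
    print(f"{name}: type={spec['type']}, default={spec['default']!r}")
```
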
conf/mapping.json CHANGED
@@ -1,200 +1,203 @@
1
- {
2
  "settings": {
3
  "index": {
4
  "number_of_shards": 2,
5
  "number_of_replicas": 0,
6
- "refresh_interval" : "1000ms"
7
  },
8
  "similarity": {
9
- "scripted_sim": {
10
- "type": "scripted",
11
- "script": {
12
- "source": "double idf = Math.log(1+(field.docCount-term.docFreq+0.5)/(term.docFreq + 0.5))/Math.log(1+((field.docCount-0.5)/1.5)); return query.boost * idf * Math.min(doc.freq, 1);"
13
- }
14
  }
 
15
  }
16
  },
17
  "mappings": {
18
- "properties": {
19
- "lat_lon": {"type": "geo_point", "store":"true"}
20
- },
21
- "date_detection": "true",
22
- "dynamic_templates": [
23
- {
24
- "int": {
25
- "match": "*_int",
26
- "mapping": {
27
- "type": "integer",
28
- "store": "true"
29
- }
30
- }
31
- },
32
- {
33
- "ulong": {
34
- "match": "*_ulong",
35
- "mapping": {
36
- "type": "unsigned_long",
37
- "store": "true"
38
- }
39
- }
40
- },
41
- {
42
- "long": {
43
- "match": "*_long",
44
- "mapping": {
45
- "type": "long",
46
- "store": "true"
47
- }
48
- }
49
- },
50
- {
51
- "short": {
52
- "match": "*_short",
53
- "mapping": {
54
- "type": "short",
55
- "store": "true"
56
- }
57
- }
58
- },
59
- {
60
- "numeric": {
61
- "match": "*_flt",
62
- "mapping": {
63
- "type": "float",
64
- "store": true
65
- }
66
- }
67
- },
68
- {
69
- "tks": {
70
- "match": "*_tks",
71
- "mapping": {
72
- "type": "text",
73
- "similarity": "scripted_sim",
74
- "analyzer": "whitespace",
75
- "store": true
76
- }
77
- }
78
- },
79
- {
80
- "ltks":{
81
- "match": "*_ltks",
82
- "mapping": {
83
- "type": "text",
84
- "analyzer": "whitespace",
85
- "store": true
86
- }
87
- }
88
- },
89
- {
90
- "kwd": {
91
- "match_pattern": "regex",
92
- "match": "^(.*_(kwd|id|ids|uid|uids)|uid)$",
93
- "mapping": {
94
- "type": "keyword",
95
- "similarity": "boolean",
96
- "store": true
97
- }
98
- }
99
- },
100
- {
101
- "dt": {
102
- "match_pattern": "regex",
103
- "match": "^.*(_dt|_time|_at)$",
104
- "mapping": {
105
- "type": "date",
106
- "format": "yyyy-MM-dd HH:mm:ss||yyyy-MM-dd||yyyy-MM-dd_HH:mm:ss",
107
- "store": true
108
- }
109
- }
110
- },
111
- {
112
- "nested": {
113
- "match": "*_nst",
114
- "mapping": {
115
- "type": "nested"
116
- }
117
- }
118
- },
119
- {
120
- "object": {
121
- "match": "*_obj",
122
- "mapping": {
123
- "type": "object",
124
- "dynamic": "true"
125
- }
126
- }
127
- },
128
- {
129
- "string": {
130
- "match": "*_with_weight",
131
- "mapping": {
132
- "type": "text",
133
- "index": "false",
134
- "store": true
135
- }
136
- }
137
- },
138
- {
139
- "string": {
140
- "match": "*_fea",
141
- "mapping": {
142
- "type": "rank_feature"
143
- }
144
- }
145
- },
146
- {
147
- "dense_vector": {
148
- "match": "*_512_vec",
149
- "mapping": {
150
- "type": "dense_vector",
151
- "index": true,
152
- "similarity": "cosine",
153
- "dims": 512
154
- }
155
- }
156
- },
157
- {
158
- "dense_vector": {
159
- "match": "*_768_vec",
160
- "mapping": {
161
- "type": "dense_vector",
162
- "index": true,
163
- "similarity": "cosine",
164
- "dims": 768
165
- }
166
- }
167
- },
168
- {
169
- "dense_vector": {
170
- "match": "*_1024_vec",
171
- "mapping": {
172
- "type": "dense_vector",
173
- "index": true,
174
- "similarity": "cosine",
175
- "dims": 1024
176
- }
177
- }
178
- },
179
- {
180
- "dense_vector": {
181
- "match": "*_1536_vec",
182
- "mapping": {
183
- "type": "dense_vector",
184
- "index": true,
185
- "similarity": "cosine",
186
- "dims": 1536
187
- }
188
- }
189
- },
190
- {
191
- "binary": {
192
- "match": "*_bin",
193
- "mapping": {
194
- "type": "binary"
195
- }
196
- }
197
- }
198
- ]
199
- }
200
- }
 
 
 
 
1
+ {
2
  "settings": {
3
  "index": {
4
  "number_of_shards": 2,
5
  "number_of_replicas": 0,
6
+ "refresh_interval": "1000ms"
7
  },
8
  "similarity": {
9
+ "scripted_sim": {
10
+ "type": "scripted",
11
+ "script": {
12
+ "source": "double idf = Math.log(1+(field.docCount-term.docFreq+0.5)/(term.docFreq + 0.5))/Math.log(1+((field.docCount-0.5)/1.5)); return query.boost * idf * Math.min(doc.freq, 1);"
 
13
  }
14
+ }
15
  }
16
  },
17
  "mappings": {
18
+ "properties": {
19
+ "lat_lon": {
20
+ "type": "geo_point",
21
+ "store": "true"
22
+ }
23
+ },
24
+ "date_detection": "true",
25
+ "dynamic_templates": [
26
+ {
27
+ "int": {
28
+ "match": "*_int",
29
+ "mapping": {
30
+ "type": "integer",
31
+ "store": "true"
32
+ }
33
+ }
34
+ },
35
+ {
36
+ "ulong": {
37
+ "match": "*_ulong",
38
+ "mapping": {
39
+ "type": "unsigned_long",
40
+ "store": "true"
41
+ }
42
+ }
43
+ },
44
+ {
45
+ "long": {
46
+ "match": "*_long",
47
+ "mapping": {
48
+ "type": "long",
49
+ "store": "true"
50
+ }
51
+ }
52
+ },
53
+ {
54
+ "short": {
55
+ "match": "*_short",
56
+ "mapping": {
57
+ "type": "short",
58
+ "store": "true"
59
+ }
60
+ }
61
+ },
62
+ {
63
+ "numeric": {
64
+ "match": "*_flt",
65
+ "mapping": {
66
+ "type": "float",
67
+ "store": true
68
+ }
69
+ }
70
+ },
71
+ {
72
+ "tks": {
73
+ "match": "*_tks",
74
+ "mapping": {
75
+ "type": "text",
76
+ "similarity": "scripted_sim",
77
+ "analyzer": "whitespace",
78
+ "store": true
79
+ }
80
+ }
81
+ },
82
+ {
83
+ "ltks": {
84
+ "match": "*_ltks",
85
+ "mapping": {
86
+ "type": "text",
87
+ "analyzer": "whitespace",
88
+ "store": true
89
+ }
90
+ }
91
+ },
92
+ {
93
+ "kwd": {
94
+ "match_pattern": "regex",
95
+ "match": "^(.*_(kwd|id|ids|uid|uids)|uid)$",
96
+ "mapping": {
97
+ "type": "keyword",
98
+ "similarity": "boolean",
99
+ "store": true
100
+ }
101
+ }
102
+ },
103
+ {
104
+ "dt": {
105
+ "match_pattern": "regex",
106
+ "match": "^.*(_dt|_time|_at)$",
107
+ "mapping": {
108
+ "type": "date",
109
+ "format": "yyyy-MM-dd HH:mm:ss||yyyy-MM-dd||yyyy-MM-dd_HH:mm:ss",
110
+ "store": true
111
+ }
112
+ }
113
+ },
114
+ {
115
+ "nested": {
116
+ "match": "*_nst",
117
+ "mapping": {
118
+ "type": "nested"
119
+ }
120
+ }
121
+ },
122
+ {
123
+ "object": {
124
+ "match": "*_obj",
125
+ "mapping": {
126
+ "type": "object",
127
+ "dynamic": "true"
128
+ }
129
+ }
130
+ },
131
+ {
132
+ "string": {
133
+ "match": "*_(with_weight|list)$",
134
+ "mapping": {
135
+ "type": "text",
136
+ "index": "false",
137
+ "store": true
138
+ }
139
+ }
140
+ },
141
+ {
142
+ "string": {
143
+ "match": "*_fea",
144
+ "mapping": {
145
+ "type": "rank_feature"
146
+ }
147
+ }
148
+ },
149
+ {
150
+ "dense_vector": {
151
+ "match": "*_512_vec",
152
+ "mapping": {
153
+ "type": "dense_vector",
154
+ "index": true,
155
+ "similarity": "cosine",
156
+ "dims": 512
157
+ }
158
+ }
159
+ },
160
+ {
161
+ "dense_vector": {
162
+ "match": "*_768_vec",
163
+ "mapping": {
164
+ "type": "dense_vector",
165
+ "index": true,
166
+ "similarity": "cosine",
167
+ "dims": 768
168
+ }
169
+ }
170
+ },
171
+ {
172
+ "dense_vector": {
173
+ "match": "*_1024_vec",
174
+ "mapping": {
175
+ "type": "dense_vector",
176
+ "index": true,
177
+ "similarity": "cosine",
178
+ "dims": 1024
179
+ }
180
+ }
181
+ },
182
+ {
183
+ "dense_vector": {
184
+ "match": "*_1536_vec",
185
+ "mapping": {
186
+ "type": "dense_vector",
187
+ "index": true,
188
+ "similarity": "cosine",
189
+ "dims": 1536
190
+ }
191
+ }
192
+ },
193
+ {
194
+ "binary": {
195
+ "match": "*_bin",
196
+ "mapping": {
197
+ "type": "binary"
198
+ }
199
+ }
200
+ }
201
+ ]
202
+ }
203
+ }
docker/.env CHANGED
@@ -19,6 +19,11 @@ KIBANA_PASSWORD=infini_rag_flow
19
  # Update it according to the available memory in the host machine.
20
  MEM_LIMIT=8073741824
21
 
 
 
 
 
 
22
  # The password for MySQL.
23
  # When updated, you must revise the `mysql.password` entry in service_conf.yaml.
24
  MYSQL_PASSWORD=infini_rag_flow
 
19
  # Update it according to the available memory in the host machine.
20
  MEM_LIMIT=8073741824
21
 
22
+ # Port to expose Infinity API to the host
23
+ INFINITY_THRIFT_PORT=23817
24
+ INFINITY_HTTP_PORT=23820
25
+ INFINITY_PSQL_PORT=5432
26
+
27
  # The password for MySQL.
28
  # When updated, you must revise the `mysql.password` entry in service_conf.yaml.
29
  MYSQL_PASSWORD=infini_rag_flow
docker/docker-compose-base.yml CHANGED
@@ -6,6 +6,7 @@ services:
6
  - esdata01:/usr/share/elasticsearch/data
7
  ports:
8
  - ${ES_PORT}:9200
 
9
  environment:
10
  - node.name=es01
11
  - ELASTIC_PASSWORD=${ELASTIC_PASSWORD}
@@ -27,12 +28,40 @@ services:
27
  retries: 120
28
  networks:
29
  - ragflow
30
- restart: always
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  mysql:
33
  # mysql:5.7 linux/arm64 image is unavailable.
34
  image: mysql:8.0.39
35
  container_name: ragflow-mysql
 
36
  environment:
37
  - MYSQL_ROOT_PASSWORD=${MYSQL_PASSWORD}
38
  - TZ=${TIMEZONE}
@@ -55,7 +84,7 @@ services:
55
  interval: 10s
56
  timeout: 10s
57
  retries: 3
58
- restart: always
59
 
60
  minio:
61
  image: quay.io/minio/minio:RELEASE.2023-12-20T01-00-02Z
@@ -64,6 +93,7 @@ services:
64
  ports:
65
  - ${MINIO_PORT}:9000
66
  - ${MINIO_CONSOLE_PORT}:9001
 
67
  environment:
68
  - MINIO_ROOT_USER=${MINIO_USER}
69
  - MINIO_ROOT_PASSWORD=${MINIO_PASSWORD}
@@ -72,25 +102,28 @@ services:
72
  - minio_data:/data
73
  networks:
74
  - ragflow
75
- restart: always
76
 
77
  redis:
78
  image: valkey/valkey:8
79
  container_name: ragflow-redis
80
  command: redis-server --requirepass ${REDIS_PASSWORD} --maxmemory 128mb --maxmemory-policy allkeys-lru
 
81
  ports:
82
  - ${REDIS_PORT}:6379
83
  volumes:
84
  - redis_data:/data
85
  networks:
86
  - ragflow
87
- restart: always
88
 
89
 
90
 
91
  volumes:
92
  esdata01:
93
  driver: local
 
 
94
  mysql_data:
95
  driver: local
96
  minio_data:
 
6
  - esdata01:/usr/share/elasticsearch/data
7
  ports:
8
  - ${ES_PORT}:9200
9
+ env_file: .env
10
  environment:
11
  - node.name=es01
12
  - ELASTIC_PASSWORD=${ELASTIC_PASSWORD}
 
28
  retries: 120
29
  networks:
30
  - ragflow
31
+ restart: on-failure
32
+
33
+ # infinity:
34
+ # container_name: ragflow-infinity
35
+ # image: infiniflow/infinity:v0.5.0-dev2
36
+ # volumes:
37
+ # - infinity_data:/var/infinity
38
+ # ports:
39
+ # - ${INFINITY_THRIFT_PORT}:23817
40
+ # - ${INFINITY_HTTP_PORT}:23820
41
+ # - ${INFINITY_PSQL_PORT}:5432
42
+ # env_file: .env
43
+ # environment:
44
+ # - TZ=${TIMEZONE}
45
+ # mem_limit: ${MEM_LIMIT}
46
+ # ulimits:
47
+ # nofile:
48
+ # soft: 500000
49
+ # hard: 500000
50
+ # networks:
51
+ # - ragflow
52
+ # healthcheck:
53
+ # test: ["CMD", "curl", "http://localhost:23820/admin/node/current"]
54
+ # interval: 10s
55
+ # timeout: 10s
56
+ # retries: 120
57
+ # restart: on-failure
58
+
59
 
60
  mysql:
61
  # mysql:5.7 linux/arm64 image is unavailable.
62
  image: mysql:8.0.39
63
  container_name: ragflow-mysql
64
+ env_file: .env
65
  environment:
66
  - MYSQL_ROOT_PASSWORD=${MYSQL_PASSWORD}
67
  - TZ=${TIMEZONE}
 
84
  interval: 10s
85
  timeout: 10s
86
  retries: 3
87
+ restart: on-failure
88
 
89
  minio:
90
  image: quay.io/minio/minio:RELEASE.2023-12-20T01-00-02Z
 
93
  ports:
94
  - ${MINIO_PORT}:9000
95
  - ${MINIO_CONSOLE_PORT}:9001
96
+ env_file: .env
97
  environment:
98
  - MINIO_ROOT_USER=${MINIO_USER}
99
  - MINIO_ROOT_PASSWORD=${MINIO_PASSWORD}
 
102
  - minio_data:/data
103
  networks:
104
  - ragflow
105
+ restart: on-failure
106
 
107
  redis:
108
  image: valkey/valkey:8
109
  container_name: ragflow-redis
110
  command: redis-server --requirepass ${REDIS_PASSWORD} --maxmemory 128mb --maxmemory-policy allkeys-lru
111
+ env_file: .env
112
  ports:
113
  - ${REDIS_PORT}:6379
114
  volumes:
115
  - redis_data:/data
116
  networks:
117
  - ragflow
118
+ restart: on-failure
119
 
120
 
121
 
122
  volumes:
123
  esdata01:
124
  driver: local
125
+ infinity_data:
126
+ driver: local
127
  mysql_data:
128
  driver: local
129
  minio_data:
docker/docker-compose.yml CHANGED
@@ -1,6 +1,5 @@
1
  include:
2
- - path: ./docker-compose-base.yml
3
- env_file: ./.env
4
 
5
  services:
6
  ragflow:
@@ -15,19 +14,21 @@ services:
15
  - ${SVR_HTTP_PORT}:9380
16
  - 80:80
17
  - 443:443
18
- - 5678:5678
19
  volumes:
20
  - ./service_conf.yaml:/ragflow/conf/service_conf.yaml
21
  - ./ragflow-logs:/ragflow/logs
22
  - ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
23
  - ./nginx/proxy.conf:/etc/nginx/proxy.conf
24
  - ./nginx/nginx.conf:/etc/nginx/nginx.conf
 
25
  environment:
26
  - TZ=${TIMEZONE}
27
  - HF_ENDPOINT=${HF_ENDPOINT}
28
  - MACOS=${MACOS}
29
  networks:
30
  - ragflow
31
- restart: always
 
 
32
  extra_hosts:
33
  - "host.docker.internal:host-gateway"
 
1
  include:
2
+ - ./docker-compose-base.yml
 
3
 
4
  services:
5
  ragflow:
 
14
  - ${SVR_HTTP_PORT}:9380
15
  - 80:80
16
  - 443:443
 
17
  volumes:
18
  - ./service_conf.yaml:/ragflow/conf/service_conf.yaml
19
  - ./ragflow-logs:/ragflow/logs
20
  - ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
21
  - ./nginx/proxy.conf:/etc/nginx/proxy.conf
22
  - ./nginx/nginx.conf:/etc/nginx/nginx.conf
23
+ env_file: .env
24
  environment:
25
  - TZ=${TIMEZONE}
26
  - HF_ENDPOINT=${HF_ENDPOINT}
27
  - MACOS=${MACOS}
28
  networks:
29
  - ragflow
30
+ restart: on-failure
31
+ # https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
32
+ # If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
33
  extra_hosts:
34
  - "host.docker.internal:host-gateway"
docs/guides/develop/launch_ragflow_from_source.md CHANGED
@@ -67,7 +67,7 @@ docker compose -f docker/docker-compose-base.yml up -d
67
  1. Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
68
 
69
  ```
70
- 127.0.0.1 es01 mysql minio redis
71
  ```
72
 
73
  2. In **docker/service_conf.yaml**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**.
 
67
  1. Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
68
 
69
  ```
70
+ 127.0.0.1 es01 infinity mysql minio redis
71
  ```
72
 
73
  2. In **docker/service_conf.yaml**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**.
docs/references/http_api_reference.md CHANGED
@@ -1280,7 +1280,7 @@ Success:
1280
  "document_keyword": "1.txt",
1281
  "highlight": "<em>ragflow</em> content",
1282
  "id": "d78435d142bd5cf6704da62c778795c5",
1283
- "img_id": "",
1284
  "important_keywords": [
1285
  ""
1286
  ],
 
1280
  "document_keyword": "1.txt",
1281
  "highlight": "<em>ragflow</em> content",
1282
  "id": "d78435d142bd5cf6704da62c778795c5",
1283
+ "image_id": "",
1284
  "important_keywords": [
1285
  ""
1286
  ],
docs/references/python_api_reference.md CHANGED
@@ -1351,7 +1351,7 @@ A list of `Chunk` objects representing references to the message, each containin
1351
  The chunk ID.
1352
  - `content` `str`
1353
  The content of the chunk.
1354
- - `image_id` `str`
1355
  The ID of the snapshot of the chunk. Applicable only when the source of the chunk is an image, PPT, PPTX, or PDF file.
1356
  - `document_id` `str`
1357
  The ID of the referenced document.
 
1351
  The chunk ID.
1352
  - `content` `str`
1353
  The content of the chunk.
1354
+ - `img_id` `str`
1355
  The ID of the snapshot of the chunk. Applicable only when the source of the chunk is an image, PPT, PPTX, or PDF file.
1356
  - `document_id` `str`
1357
  The ID of the referenced document.
graphrag/claim_extractor.py CHANGED
@@ -254,9 +254,12 @@ if __name__ == "__main__":
254
  from api.db import LLMType
255
  from api.db.services.llm_service import LLMBundle
256
  from api.settings import retrievaler
 
 
 
257
 
258
  ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
259
- docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, max_count=12, fields=["content_with_weight"])]
260
  info = {
261
  "input_text": docs,
262
  "entity_specs": "organization, person",
 
254
  from api.db import LLMType
255
  from api.db.services.llm_service import LLMBundle
256
  from api.settings import retrievaler
257
+ from api.db.services.knowledgebase_service import KnowledgebaseService
258
+
259
+ kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
260
 
261
  ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
262
+ docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
263
  info = {
264
  "input_text": docs,
265
  "entity_specs": "organization, person",
graphrag/search.py CHANGED
@@ -15,95 +15,90 @@
15
  #
16
  import json
17
  from copy import deepcopy
 
18
 
19
  import pandas as pd
20
- from elasticsearch_dsl import Q, Search
21
 
22
  from rag.nlp.search import Dealer
23
 
24
 
25
  class KGSearch(Dealer):
26
- def search(self, req, idxnm, emb_mdl=None, highlight=False):
27
- def merge_into_first(sres, title=""):
28
- df,texts = [],[]
29
- for d in sres["hits"]["hits"]:
 
 
 
30
  try:
31
- df.append(json.loads(d["_source"]["content_with_weight"]))
32
- except Exception as e:
33
- texts.append(d["_source"]["content_with_weight"])
34
- pass
35
- if not df and not texts: return False
36
  if df:
37
- try:
38
- sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + pd.DataFrame(df).to_csv()
39
- except Exception as e:
40
- pass
41
  else:
42
- sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + "\n".join(texts)
43
- return True
 
 
 
 
 
 
 
 
 
 
 
 
44
 
 
 
 
 
 
45
  src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
46
- "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "name_kwd",
47
  "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight",
48
  "weight_int", "weight_flt", "rank_int"
49
  ])
50
 
51
- qst = req.get("question", "")
52
- binary_query, keywords = self.qryr.question(qst, min_match="5%")
53
- binary_query = self._add_filters(binary_query, req)
54
 
55
- ## Entity retrieval
56
- bqry = deepcopy(binary_query)
57
- bqry.filter.append(Q("terms", knowledge_graph_kwd=["entity"]))
58
- s = Search()
59
- s = s.query(bqry)[0: 32]
60
-
61
- s = s.to_dict()
62
- q_vec = []
63
- if req.get("vector"):
64
- assert emb_mdl, "No embedding model selected"
65
- s["knn"] = self._vector(
66
- qst, emb_mdl, req.get(
67
- "similarity", 0.1), 1024)
68
- s["knn"]["filter"] = bqry.to_dict()
69
- q_vec = s["knn"]["query_vector"]
70
-
71
- ent_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src)
72
- entities = [d["name_kwd"] for d in self.es.getSource(ent_res)]
73
- ent_ids = self.es.getDocIds(ent_res)
74
- if merge_into_first(ent_res, "-Entities-"):
75
- ent_ids = ent_ids[0:1]
76
 
77
  ## Community retrieval
78
- bqry = deepcopy(binary_query)
79
- bqry.filter.append(Q("terms", entities_kwd=entities))
80
- bqry.filter.append(Q("terms", knowledge_graph_kwd=["community_report"]))
81
- s = Search()
82
- s = s.query(bqry)[0: 32]
83
- s = s.to_dict()
84
- comm_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src)
85
- comm_ids = self.es.getDocIds(comm_res)
86
- if merge_into_first(comm_res, "-Community Report-"):
87
- comm_ids = comm_ids[0:1]
88
 
89
  ## Text content retrieval
90
- bqry = deepcopy(binary_query)
91
- bqry.filter.append(Q("terms", knowledge_graph_kwd=["text"]))
92
- s = Search()
93
- s = s.query(bqry)[0: 6]
94
- s = s.to_dict()
95
- txt_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src)
96
- txt_ids = self.es.getDocIds(txt_res)
97
- if merge_into_first(txt_res, "-Original Content-"):
98
- txt_ids = txt_ids[0:1]
99
 
100
  return self.SearchResult(
101
  total=len(ent_ids) + len(comm_ids) + len(txt_ids),
102
  ids=[*ent_ids, *comm_ids, *txt_ids],
103
  query_vector=q_vec,
104
- aggregation=None,
105
  highlight=None,
106
- field={**self.getFields(ent_res, src), **self.getFields(comm_res, src), **self.getFields(txt_res, src)},
107
  keywords=[]
108
  )
109
-
 
15
  #
16
  import json
17
  from copy import deepcopy
18
+ from typing import Dict
19
 
20
  import pandas as pd
21
+ from rag.utils.doc_store_conn import OrderByExpr, FusionExpr
22
 
23
  from rag.nlp.search import Dealer
24
 
25
 
26
  class KGSearch(Dealer):
27
+ def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False):
28
+ def merge_into_first(sres, title="") -> Dict[str, str]:
29
+ if not sres:
30
+ return {}
31
+ content_with_weight = ""
32
+ df, texts = [], []
33
+ for d in sres.values():
34
  try:
35
+ df.append(json.loads(d["content_with_weight"]))
36
+ except Exception:
37
+ texts.append(d["content_with_weight"])
 
 
38
  if df:
39
+ content_with_weight = title + "\n" + pd.DataFrame(df).to_csv()
 
 
 
40
  else:
41
+ content_with_weight = title + "\n" + "\n".join(texts)
42
+ first_id = ""
43
+ first_source = {}
44
+ for k, v in sres.items():
45
+ first_id = k
46
+ first_source = deepcopy(v)
47
+ break
48
+ first_source["content_with_weight"] = content_with_weight
49
+ first_id = next(iter(sres))
50
+ return {first_id: first_source}
51
+
52
+ qst = req.get("question", "")
53
+ matchText, keywords = self.qryr.question(qst, min_match=0.05)
54
+ condition = self.get_filters(req)
55
 
56
+ ## Entity retrieval
57
+ condition.update({"knowledge_graph_kwd": ["entity"]})
58
+ assert emb_mdl, "No embedding model selected"
59
+ matchDense = self.get_vector(qst, emb_mdl, 1024, req.get("similarity", 0.1))
60
+ q_vec = matchDense.embedding_data
61
  src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
62
+ "doc_id", f"q_{len(q_vec)}_vec", "position_list", "name_kwd",
63
  "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight",
64
  "weight_int", "weight_flt", "rank_int"
65
  ])
66
 
67
+ fusionExpr = FusionExpr("weighted_sum", 32, {"weights": "0.5, 0.5"})
 
 
68
 
69
+ ent_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
70
+ ent_res_fields = self.dataStore.getFields(ent_res, src)
71
+ entities = [d["name_kwd"] for d in ent_res_fields.values()]
72
+ ent_ids = self.dataStore.getChunkIds(ent_res)
73
+ ent_content = merge_into_first(ent_res_fields, "-Entities-")
74
+ if ent_content:
75
+ ent_ids = list(ent_content.keys())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  ## Community retrieval
78
+ condition = self.get_filters(req)
79
+ condition.update({"entities_kwd": entities, "knowledge_graph_kwd": ["community_report"]})
80
+ comm_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
81
+ comm_res_fields = self.dataStore.getFields(comm_res, src)
82
+ comm_ids = self.dataStore.getChunkIds(comm_res)
83
+ comm_content = merge_into_first(comm_res_fields, "-Community Report-")
84
+ if comm_content:
85
+ comm_ids = list(comm_content.keys())
 
 
86
 
87
  ## Text content retrieval
88
+ condition = self.get_filters(req)
89
+ condition.update({"knowledge_graph_kwd": ["text"]})
90
+ txt_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 6, idxnm, kb_ids)
91
+ txt_res_fields = self.dataStore.getFields(txt_res, src)
92
+ txt_ids = self.dataStore.getChunkIds(txt_res)
93
+ txt_content = merge_into_first(txt_res_fields, "-Original Content-")
94
+ if txt_content:
95
+ txt_ids = list(txt_content.keys())
 
96
 
97
  return self.SearchResult(
98
  total=len(ent_ids) + len(comm_ids) + len(txt_ids),
99
  ids=[*ent_ids, *comm_ids, *txt_ids],
100
  query_vector=q_vec,
 
101
  highlight=None,
102
+ field={**ent_content, **comm_content, **txt_content},
103
  keywords=[]
104
  )
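
`merge_into_first` collapses every hit in a section into the first one, replacing its `content_with_weight` with either a CSV built from the JSON rows or the concatenated raw texts. A self-contained toy run of that merging logic, with plain dicts standing in for the store's field map:

```python
import json
from copy import deepcopy
import pandas as pd

def merge_into_first(sres: dict, title: str = "") -> dict:
    # Same idea as KGSearch.merge_into_first: fold every hit into the first one.
    if not sres:
        return {}
    df, texts = [], []
    for d in sres.values():
        try:
            df.append(json.loads(d["content_with_weight"]))
        except Exception:
            texts.append(d["content_with_weight"])
    merged = title + "\n" + (pd.DataFrame(df).to_csv() if df else "\n".join(texts))
    first_id = next(iter(sres))
    first_source = deepcopy(sres[first_id])
    first_source["content_with_weight"] = merged
    return {first_id: first_source}

hits = {
    "c1": {"content_with_weight": json.dumps({"entity": "RAGFlow", "weight": 3})},
    "c2": {"content_with_weight": json.dumps({"entity": "Infinity", "weight": 5})},
}
print(merge_into_first(hits, "-Entities-"))
```
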
 
graphrag/smoke.py CHANGED
@@ -31,10 +31,13 @@ if __name__ == "__main__":
31
  from api.db import LLMType
32
  from api.db.services.llm_service import LLMBundle
33
  from api.settings import retrievaler
 
 
 
34
 
35
  ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
36
  docs = [d["content_with_weight"] for d in
37
- retrievaler.chunk_list(args.doc_id, args.tenant_id, max_count=6, fields=["content_with_weight"])]
38
  graph = ex(docs)
39
 
40
  er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))
 
31
  from api.db import LLMType
32
  from api.db.services.llm_service import LLMBundle
33
  from api.settings import retrievaler
34
+ from api.db.services.knowledgebase_service import KnowledgebaseService
35
+
36
+ kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
37
 
38
  ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
39
  docs = [d["content_with_weight"] for d in
40
+ retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
41
  graph = ex(docs)
42
 
43
  er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))
poetry.lock CHANGED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -46,22 +46,23 @@ hanziconv = "0.3.2"
46
  html-text = "0.6.2"
47
  httpx = "0.27.0"
48
  huggingface-hub = "^0.25.0"
49
- infinity-emb = "0.0.51"
 
50
  itsdangerous = "2.1.2"
51
  markdown = "3.6"
52
  markdown-to-json = "2.1.1"
53
  minio = "7.2.4"
54
  mistralai = "0.4.2"
55
  nltk = "3.9.1"
56
- numpy = "1.26.4"
57
  ollama = "0.2.1"
58
  onnxruntime = "1.19.2"
59
  openai = "1.45.0"
60
  opencv-python = "4.10.0.84"
61
  opencv-python-headless = "4.10.0.84"
62
- openpyxl = "3.1.2"
63
  ormsgpack = "1.5.0"
64
- pandas = "2.2.2"
65
  pdfplumber = "0.10.4"
66
  peewee = "3.17.1"
67
  pillow = "10.4.0"
@@ -70,7 +71,7 @@ psycopg2-binary = "2.9.9"
70
  pyclipper = "1.3.0.post5"
71
  pycryptodomex = "3.20.0"
72
  pypdf = "^5.0.0"
73
- pytest = "8.2.2"
74
  python-dotenv = "1.0.1"
75
  python-dateutil = "2.8.2"
76
  python-pptx = "^1.0.2"
@@ -86,7 +87,7 @@ ruamel-base = "1.0.0"
86
  scholarly = "1.7.11"
87
  scikit-learn = "1.5.0"
88
  selenium = "4.22.0"
89
- setuptools = "70.0.0"
90
  shapely = "2.0.5"
91
  six = "1.16.0"
92
  strenum = "0.4.15"
@@ -115,6 +116,7 @@ pymysql = "^1.1.1"
115
  mini-racer = "^0.12.4"
116
  pyicu = "^2.13.1"
117
  flasgger = "^0.9.7.1"
 
118
 
119
 
120
  [tool.poetry.group.full]
 
46
  html-text = "0.6.2"
47
  httpx = "0.27.0"
48
  huggingface-hub = "^0.25.0"
49
+ infinity-sdk = "0.5.0.dev2"
50
+ infinity-emb = "^0.0.66"
51
  itsdangerous = "2.1.2"
52
  markdown = "3.6"
53
  markdown-to-json = "2.1.1"
54
  minio = "7.2.4"
55
  mistralai = "0.4.2"
56
  nltk = "3.9.1"
57
+ numpy = "^1.26.0"
58
  ollama = "0.2.1"
59
  onnxruntime = "1.19.2"
60
  openai = "1.45.0"
61
  opencv-python = "4.10.0.84"
62
  opencv-python-headless = "4.10.0.84"
63
+ openpyxl = "^3.1.0"
64
  ormsgpack = "1.5.0"
65
+ pandas = "^2.2.0"
66
  pdfplumber = "0.10.4"
67
  peewee = "3.17.1"
68
  pillow = "10.4.0"
 
71
  pyclipper = "1.3.0.post5"
72
  pycryptodomex = "3.20.0"
73
  pypdf = "^5.0.0"
74
+ pytest = "^8.3.0"
75
  python-dotenv = "1.0.1"
76
  python-dateutil = "2.8.2"
77
  python-pptx = "^1.0.2"
 
87
  scholarly = "1.7.11"
88
  scikit-learn = "1.5.0"
89
  selenium = "4.22.0"
90
+ setuptools = "^75.2.0"
91
  shapely = "2.0.5"
92
  six = "1.16.0"
93
  strenum = "0.4.15"
 
116
  mini-racer = "^0.12.4"
117
  pyicu = "^2.13.1"
118
  flasgger = "^0.9.7.1"
119
+ polars = "^1.9.0"
120
 
121
 
122
  [tool.poetry.group.full]
rag/app/presentation.py CHANGED
@@ -20,6 +20,7 @@ from rag.nlp import tokenize, is_english
20
  from rag.nlp import rag_tokenizer
21
  from deepdoc.parser import PdfParser, PptParser, PlainParser
22
  from PyPDF2 import PdfReader as pdf2_read
 
23
 
24
 
25
  class Ppt(PptParser):
@@ -107,9 +108,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
107
  d = copy.deepcopy(doc)
108
  pn += from_page
109
  d["image"] = img
110
- d["page_num_int"] = [pn + 1]
111
- d["top_int"] = [0]
112
- d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
113
  tokenize(d, txt, eng)
114
  res.append(d)
115
  return res
@@ -123,10 +124,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
123
  pn += from_page
124
  if img:
125
  d["image"] = img
126
- d["page_num_int"] = [pn + 1]
127
- d["top_int"] = [0]
128
- d["position_int"] = [
129
- (pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)]
130
  tokenize(d, txt, eng)
131
  res.append(d)
132
  return res
 
20
  from rag.nlp import rag_tokenizer
21
  from deepdoc.parser import PdfParser, PptParser, PlainParser
22
  from PyPDF2 import PdfReader as pdf2_read
23
+ import json
24
 
25
 
26
  class Ppt(PptParser):
 
108
  d = copy.deepcopy(doc)
109
  pn += from_page
110
  d["image"] = img
111
+ d["page_num_list"] = json.dumps([pn + 1])
112
+ d["top_list"] = json.dumps([0])
113
+ d["position_list"] = json.dumps([(pn + 1, 0, img.size[0], 0, img.size[1])])
114
  tokenize(d, txt, eng)
115
  res.append(d)
116
  return res
 
124
  pn += from_page
125
  if img:
126
  d["image"] = img
127
+ d["page_num_list"] = json.dumps([pn + 1])
128
+ d["top_list"] = json.dumps([0])
129
+ d["position_list"] = json.dumps([
130
+ (pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)])
131
  tokenize(d, txt, eng)
132
  res.append(d)
133
  return res
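
Position metadata that used to be stored as native integer lists (`page_num_int`, `top_int`, `position_int`) is now JSON-serialized into varchar columns (`page_num_list`, `top_list`, `position_list`) so the same row fits both backends. A tiny round-trip sketch of that encoding:

```python
import json

# Encode positions the way the parser now does (tuples become JSON arrays).
positions = [(1, 0, 960, 0, 720)]
position_list = json.dumps(positions)

# Decode on the way back out; JSON has no tuples, so rows come back as lists.
decoded = [tuple(p) for p in json.loads(position_list)]
assert decoded == positions
print(position_list)  # [[1, 0, 960, 0, 720]]
```
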
rag/app/table.py CHANGED
@@ -74,7 +74,7 @@ class Excel(ExcelParser):
74
  def trans_datatime(s):
75
  try:
76
  return datetime_parse(s.strip()).strftime("%Y-%m-%d %H:%M:%S")
77
- except Exception as e:
78
  pass
79
 
80
 
@@ -112,7 +112,7 @@ def column_data_type(arr):
112
  continue
113
  try:
114
  arr[i] = trans[ty](str(arr[i]))
115
- except Exception as e:
116
  arr[i] = None
117
  # if ty == "text":
118
  # if len(arr) > 128 and uni / len(arr) < 0.1:
@@ -182,7 +182,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000,
182
  "datetime": "_dt",
183
  "bool": "_kwd"}
184
  for df in dfs:
185
- for n in ["id", "_id", "index", "idx"]:
186
  if n in df.columns:
187
  del df[n]
188
  clmns = df.columns.values
 
74
  def trans_datatime(s):
75
  try:
76
  return datetime_parse(s.strip()).strftime("%Y-%m-%d %H:%M:%S")
77
+ except Exception:
78
  pass
79
 
80
 
 
112
  continue
113
  try:
114
  arr[i] = trans[ty](str(arr[i]))
115
+ except Exception:
116
  arr[i] = None
117
  # if ty == "text":
118
  # if len(arr) > 128 and uni / len(arr) < 0.1:
 
182
  "datetime": "_dt",
183
  "bool": "_kwd"}
184
  for df in dfs:
185
+ for n in ["id", "index", "idx"]:
186
  if n in df.columns:
187
  del df[n]
188
  clmns = df.columns.values
rag/benchmark.py CHANGED
@@ -1,280 +1,310 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- import json
17
- import os
18
- from collections import defaultdict
19
- from concurrent.futures import ThreadPoolExecutor
20
- from copy import deepcopy
21
-
22
- from api.db import LLMType
23
- from api.db.services.llm_service import LLMBundle
24
- from api.db.services.knowledgebase_service import KnowledgebaseService
25
- from api.settings import retrievaler
26
- from api.utils import get_uuid
27
- from api.utils.file_utils import get_project_base_directory
28
- from rag.nlp import tokenize, search
29
- from rag.utils.es_conn import ELASTICSEARCH
30
- from ranx import evaluate
31
- import pandas as pd
32
- from tqdm import tqdm
33
- from ranx import Qrels, Run
34
-
35
-
36
- class Benchmark:
37
- def __init__(self, kb_id):
38
- e, self.kb = KnowledgebaseService.get_by_id(kb_id)
39
- self.similarity_threshold = self.kb.similarity_threshold
40
- self.vector_similarity_weight = self.kb.vector_similarity_weight
41
- self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
42
-
43
- def _get_benchmarks(self, query, dataset_idxnm, count=16):
44
-
45
- req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold}
46
- sres = retrievaler.search(req, search.index_name(dataset_idxnm), self.embd_mdl)
47
- return sres
48
-
49
- def _get_retrieval(self, qrels, dataset_idxnm):
50
- run = defaultdict(dict)
51
- query_list = list(qrels.keys())
52
- for query in query_list:
53
-
54
- ranks = retrievaler.retrieval(query, self.embd_mdl,
55
- dataset_idxnm, [self.kb.id], 1, 30,
56
- 0.0, self.vector_similarity_weight)
57
- for c in ranks["chunks"]:
58
- if "vector" in c:
59
- del c["vector"]
60
- run[query][c["chunk_id"]] = c["similarity"]
61
-
62
- return run
63
-
64
- def embedding(self, docs, batch_size=16):
65
- vects = []
66
- cnts = [d["content_with_weight"] for d in docs]
67
- for i in range(0, len(cnts), batch_size):
68
- vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
69
- vects.extend(vts.tolist())
70
- assert len(docs) == len(vects)
71
- for i, d in enumerate(docs):
72
- v = vects[i]
73
- d["q_%d_vec" % len(v)] = v
74
- return docs
75
-
76
- @staticmethod
77
- def init_kb(index_name):
78
- idxnm = search.index_name(index_name)
79
- if ELASTICSEARCH.indexExist(idxnm):
80
- ELASTICSEARCH.deleteIdx(search.index_name(index_name))
81
-
82
- return ELASTICSEARCH.createIdx(idxnm, json.load(
83
- open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
84
-
85
- def ms_marco_index(self, file_path, index_name):
86
- qrels = defaultdict(dict)
87
- texts = defaultdict(dict)
88
- docs = []
89
- filelist = os.listdir(file_path)
90
- self.init_kb(index_name)
91
-
92
- max_workers = int(os.environ.get('MAX_WORKERS', 3))
93
- exe = ThreadPoolExecutor(max_workers=max_workers)
94
- threads = []
95
-
96
- def slow_actions(es_docs, idx_nm):
97
- es_docs = self.embedding(es_docs)
98
- ELASTICSEARCH.bulk(es_docs, idx_nm)
99
- return True
100
-
101
- for dir in filelist:
102
- data = pd.read_parquet(os.path.join(file_path, dir))
103
- for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + dir):
104
-
105
- query = data.iloc[i]['query']
106
- for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
107
- d = {
108
- "id": get_uuid(),
109
- "kb_id": self.kb.id,
110
- "docnm_kwd": "xxxxx",
111
- "doc_id": "ksksks"
112
- }
113
- tokenize(d, text, "english")
114
- docs.append(d)
115
- texts[d["id"]] = text
116
- qrels[query][d["id"]] = int(rel)
117
- if len(docs) >= 32:
118
- threads.append(
119
- exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
120
- docs = []
121
-
122
- threads.append(
123
- exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
124
-
125
- for i in tqdm(range(len(threads)), colour="red", desc="Indexing:" + dir):
126
- if not threads[i].result().output:
127
- print("Indexing error...")
128
-
129
- return qrels, texts
130
-
131
- def trivia_qa_index(self, file_path, index_name):
132
- qrels = defaultdict(dict)
133
- texts = defaultdict(dict)
134
- docs = []
135
- filelist = os.listdir(file_path)
136
- for dir in filelist:
137
- data = pd.read_parquet(os.path.join(file_path, dir))
138
- for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + dir):
139
- query = data.iloc[i]['question']
140
- for rel, text in zip(data.iloc[i]["search_results"]['rank'],
141
- data.iloc[i]["search_results"]['search_context']):
142
- d = {
143
- "id": get_uuid(),
144
- "kb_id": self.kb.id,
145
- "docnm_kwd": "xxxxx",
146
- "doc_id": "ksksks"
147
- }
148
- tokenize(d, text, "english")
149
- docs.append(d)
150
- texts[d["id"]] = text
151
- qrels[query][d["id"]] = int(rel)
152
- if len(docs) >= 32:
153
- docs = self.embedding(docs)
154
- ELASTICSEARCH.bulk(docs, search.index_name(index_name))
155
- docs = []
156
-
157
- docs = self.embedding(docs)
158
- ELASTICSEARCH.bulk(docs, search.index_name(index_name))
159
- return qrels, texts
160
-
161
- def miracl_index(self, file_path, corpus_path, index_name):
162
-
163
- corpus_total = {}
164
- for corpus_file in os.listdir(corpus_path):
165
- tmp_data = pd.read_json(os.path.join(corpus_path, corpus_file), lines=True)
166
- for index, i in tmp_data.iterrows():
167
- corpus_total[i['docid']] = i['text']
168
-
169
- topics_total = {}
170
- for topics_file in os.listdir(os.path.join(file_path, 'topics')):
171
- if 'test' in topics_file:
172
- continue
173
- tmp_data = pd.read_csv(os.path.join(file_path, 'topics', topics_file), sep='\t', names=['qid', 'query'])
174
- for index, i in tmp_data.iterrows():
175
- topics_total[i['qid']] = i['query']
176
-
177
- qrels = defaultdict(dict)
178
- texts = defaultdict(dict)
179
- docs = []
180
- for qrels_file in os.listdir(os.path.join(file_path, 'qrels')):
181
- if 'test' in qrels_file:
182
- continue
183
-
184
- tmp_data = pd.read_csv(os.path.join(file_path, 'qrels', qrels_file), sep='\t',
185
- names=['qid', 'Q0', 'docid', 'relevance'])
186
- for i in tqdm(range(len(tmp_data)), colour="green", desc="Indexing:" + qrels_file):
187
- query = topics_total[tmp_data.iloc[i]['qid']]
188
- text = corpus_total[tmp_data.iloc[i]['docid']]
189
- rel = tmp_data.iloc[i]['relevance']
190
- d = {
191
- "id": get_uuid(),
192
- "kb_id": self.kb.id,
193
- "docnm_kwd": "xxxxx",
194
- "doc_id": "ksksks"
195
- }
196
- tokenize(d, text, 'english')
197
- docs.append(d)
198
- texts[d["id"]] = text
199
- qrels[query][d["id"]] = int(rel)
200
- if len(docs) >= 32:
201
- docs = self.embedding(docs)
202
- ELASTICSEARCH.bulk(docs, search.index_name(index_name))
203
- docs = []
204
-
205
- docs = self.embedding(docs)
206
- ELASTICSEARCH.bulk(docs, search.index_name(index_name))
207
-
208
- return qrels, texts
209
-
210
- def save_results(self, qrels, run, texts, dataset, file_path):
211
- keep_result = []
212
- run_keys = list(run.keys())
213
- for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"):
214
- key = run_keys[run_i]
215
- keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
216
- 'ndcg@10': evaluate(Qrels({key: qrels[key]}), Run({key: run[key]}), "ndcg@10")})
217
- keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
218
- with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f:
219
- f.write('## Score For Every Query\n')
220
- for keep_result_i in keep_result:
221
- f.write('### query: ' + keep_result_i['query'] + ' ndcg@10:' + str(keep_result_i['ndcg@10']) + '\n')
222
- scores = [[i[0], i[1]] for i in keep_result_i['run'].items()]
223
- scores = sorted(scores, key=lambda kk: kk[1])
224
- for score in scores[:10]:
225
- f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n')
226
- json.dump(qrels, open(os.path.join(file_path, dataset + '.qrels.json'), "w+"), indent=2)
227
- json.dump(run, open(os.path.join(file_path, dataset + '.run.json'), "w+"), indent=2)
228
- print(os.path.join(file_path, dataset + '_result.md'), 'Saved!')
229
-
230
- def __call__(self, dataset, file_path, miracl_corpus=''):
231
- if dataset == "ms_marco_v1.1":
232
- qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
233
- run = self._get_retrieval(qrels, "benchmark_ms_marco_v1.1")
234
- print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
235
- self.save_results(qrels, run, texts, dataset, file_path)
236
- if dataset == "trivia_qa":
237
- qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
238
- run = self._get_retrieval(qrels, "benchmark_trivia_qa")
239
- print(dataset, evaluate((qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
240
- self.save_results(qrels, run, texts, dataset, file_path)
241
- if dataset == "miracl":
242
- for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
243
- 'yo', 'zh']:
244
- if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang)):
245
- print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang) + ' not found!')
246
- continue
247
- if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels')):
248
- print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels') + 'not found!')
249
- continue
250
- if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics')):
251
- print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics') + 'not found!')
252
- continue
253
- if not os.path.isdir(os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang)):
254
- print('Directory: ' + os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang) + ' not found!')
255
- continue
256
- qrels, texts = self.miracl_index(os.path.join(file_path, 'miracl-v1.0-' + lang),
257
- os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
258
- "benchmark_miracl_" + lang)
259
- run = self._get_retrieval(qrels, "benchmark_miracl_" + lang)
260
- print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
261
- self.save_results(qrels, run, texts, dataset, file_path)
262
-
263
-
264
- if __name__ == '__main__':
265
- print('*****************RAGFlow Benchmark*****************')
266
- kb_id = input('Please input kb_id:\n')
267
- ex = Benchmark(kb_id)
268
- dataset = input(
269
- 'RAGFlow Benchmark Support:\n\tms_marco_v1.1:<https://huggingface.co/datasets/microsoft/ms_marco>\n\ttrivia_qa:<https://huggingface.co/datasets/mandarjoshi/trivia_qa>\n\tmiracl:<https://huggingface.co/datasets/miracl/miracl>\nPlease input dataset choice:\n')
270
- if dataset in ['ms_marco_v1.1', 'trivia_qa']:
271
- if dataset == "ms_marco_v1.1":
272
- print("Notice: Please provide the ms_marco_v1.1 dataset only. ms_marco_v2.1 is not supported!")
273
- dataset_path = input('Please input ' + dataset + ' dataset path:\n')
274
- ex(dataset, dataset_path)
275
- elif dataset == 'miracl':
276
- dataset_path = input('Please input ' + dataset + ' dataset path:\n')
277
- corpus_path = input('Please input ' + dataset + '-corpus dataset path:\n')
278
- ex(dataset, dataset_path, miracl_corpus=corpus_path)
279
- else:
280
- print("Dataset: ", dataset, "not supported!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import json
17
+ import os
18
+ import sys
19
+ import time
20
+ import argparse
21
+ from collections import defaultdict
22
+
23
+ from api.db import LLMType
24
+ from api.db.services.llm_service import LLMBundle
25
+ from api.db.services.knowledgebase_service import KnowledgebaseService
26
+ from api.settings import retrievaler, docStoreConn
27
+ from api.utils import get_uuid
28
+ from rag.nlp import tokenize, search
29
+ from ranx import evaluate
30
+ import pandas as pd
31
+ from tqdm import tqdm
32
+
33
+ global max_docs
34
+ max_docs = sys.maxsize
35
+
36
+ class Benchmark:
37
+ def __init__(self, kb_id):
38
+ self.kb_id = kb_id
39
+ e, self.kb = KnowledgebaseService.get_by_id(kb_id)
40
+ self.similarity_threshold = self.kb.similarity_threshold
41
+ self.vector_similarity_weight = self.kb.vector_similarity_weight
42
+ self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
43
+ self.tenant_id = ''
44
+ self.index_name = ''
45
+ self.initialized_index = False
46
+
47
+ def _get_retrieval(self, qrels):
48
+ # Need to wait for the ES and Infinity index to be ready
49
+ time.sleep(20)
50
+ run = defaultdict(dict)
51
+ query_list = list(qrels.keys())
52
+ for query in query_list:
53
+ ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
54
+ 0.0, self.vector_similarity_weight)
55
+ if len(ranks["chunks"]) == 0:
56
+ print(f"deleted query: {query}")
57
+ del qrels[query]
58
+ continue
59
+ for c in ranks["chunks"]:
60
+ if "vector" in c:
61
+ del c["vector"]
62
+ run[query][c["chunk_id"]] = c["similarity"]
63
+ return run
64
+
65
+ def embedding(self, docs, batch_size=16):
66
+ vects = []
67
+ cnts = [d["content_with_weight"] for d in docs]
68
+ for i in range(0, len(cnts), batch_size):
69
+ vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
70
+ vects.extend(vts.tolist())
71
+ assert len(docs) == len(vects)
72
+ vector_size = 0
73
+ for i, d in enumerate(docs):
74
+ v = vects[i]
75
+ vector_size = len(v)
76
+ d["q_%d_vec" % len(v)] = v
77
+ return docs, vector_size
78
+
79
+ def init_index(self, vector_size: int):
80
+ if self.initialized_index:
81
+ return
82
+ if docStoreConn.indexExist(self.index_name, self.kb_id):
83
+ docStoreConn.deleteIdx(self.index_name, self.kb_id)
84
+ docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
85
+ self.initialized_index = True
86
+
87
+ def ms_marco_index(self, file_path, index_name):
88
+ qrels = defaultdict(dict)
89
+ texts = defaultdict(dict)
90
+ docs_count = 0
91
+ docs = []
92
+ filelist = sorted(os.listdir(file_path))
93
+
94
+ for fn in filelist:
95
+ if docs_count >= max_docs:
96
+ break
97
+ if not fn.endswith(".parquet"):
98
+ continue
99
+ data = pd.read_parquet(os.path.join(file_path, fn))
100
+ for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + fn):
101
+ if docs_count >= max_docs:
102
+ break
103
+ query = data.iloc[i]['query']
104
+ for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
105
+ d = {
106
+ "id": get_uuid(),
107
+ "kb_id": self.kb.id,
108
+ "docnm_kwd": "xxxxx",
109
+ "doc_id": "ksksks"
110
+ }
111
+ tokenize(d, text, "english")
112
+ docs.append(d)
113
+ texts[d["id"]] = text
114
+ qrels[query][d["id"]] = int(rel)
115
+ if len(docs) >= 32:
116
+ docs_count += len(docs)
117
+ docs, vector_size = self.embedding(docs)
118
+ self.init_index(vector_size)
119
+ docStoreConn.insert(docs, self.index_name, self.kb_id)
120
+ docs = []
121
+
122
+ if docs:
123
+ docs, vector_size = self.embedding(docs)
124
+ self.init_index(vector_size)
125
+ docStoreConn.insert(docs, self.index_name, self.kb_id)
126
+ return qrels, texts
127
+
128
+ def trivia_qa_index(self, file_path, index_name):
129
+ qrels = defaultdict(dict)
130
+ texts = defaultdict(dict)
131
+ docs_count = 0
132
+ docs = []
133
+ filelist = sorted(os.listdir(file_path))
134
+ for fn in filelist:
135
+ if docs_count >= max_docs:
136
+ break
137
+ if not fn.endswith(".parquet"):
138
+ continue
139
+ data = pd.read_parquet(os.path.join(file_path, fn))
140
+ for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + fn):
141
+ if docs_count >= max_docs:
142
+ break
143
+ query = data.iloc[i]['question']
144
+ for rel, text in zip(data.iloc[i]["search_results"]['rank'],
145
+ data.iloc[i]["search_results"]['search_context']):
146
+ d = {
147
+ "id": get_uuid(),
148
+ "kb_id": self.kb.id,
149
+ "docnm_kwd": "xxxxx",
150
+ "doc_id": "ksksks"
151
+ }
152
+ tokenize(d, text, "english")
153
+ docs.append(d)
154
+ texts[d["id"]] = text
155
+ qrels[query][d["id"]] = int(rel)
156
+ if len(docs) >= 32:
157
+ docs_count += len(docs)
158
+ docs, vector_size = self.embedding(docs)
159
+ self.init_index(vector_size)
160
+ docStoreConn.insert(docs, self.index_name, self.kb_id)
161
+ docs = []
162
+
163
+ docs, vector_size = self.embedding(docs)
164
+ self.init_index(vector_size)
165
+ docStoreConn.insert(docs, self.index_name, self.kb_id)
166
+ return qrels, texts
167
+
168
+ def miracl_index(self, file_path, corpus_path, index_name):
169
+ corpus_total = {}
170
+ for corpus_file in os.listdir(corpus_path):
171
+ tmp_data = pd.read_json(os.path.join(corpus_path, corpus_file), lines=True)
172
+ for index, i in tmp_data.iterrows():
173
+ corpus_total[i['docid']] = i['text']
174
+
175
+ topics_total = {}
176
+ for topics_file in os.listdir(os.path.join(file_path, 'topics')):
177
+ if 'test' in topics_file:
178
+ continue
179
+ tmp_data = pd.read_csv(os.path.join(file_path, 'topics', topics_file), sep='\t', names=['qid', 'query'])
180
+ for index, i in tmp_data.iterrows():
181
+ topics_total[i['qid']] = i['query']
182
+
183
+ qrels = defaultdict(dict)
184
+ texts = defaultdict(dict)
185
+ docs_count = 0
186
+ docs = []
187
+ for qrels_file in os.listdir(os.path.join(file_path, 'qrels')):
188
+ if 'test' in qrels_file:
189
+ continue
190
+ if docs_count >= max_docs:
191
+ break
192
+
193
+ tmp_data = pd.read_csv(os.path.join(file_path, 'qrels', qrels_file), sep='\t',
194
+ names=['qid', 'Q0', 'docid', 'relevance'])
195
+ for i in tqdm(range(len(tmp_data)), colour="green", desc="Indexing:" + qrels_file):
196
+ if docs_count >= max_docs:
197
+ break
198
+ query = topics_total[tmp_data.iloc[i]['qid']]
199
+ text = corpus_total[tmp_data.iloc[i]['docid']]
200
+ rel = tmp_data.iloc[i]['relevance']
201
+ d = {
202
+ "id": get_uuid(),
203
+ "kb_id": self.kb.id,
204
+ "docnm_kwd": "xxxxx",
205
+ "doc_id": "ksksks"
206
+ }
207
+ tokenize(d, text, 'english')
208
+ docs.append(d)
209
+ texts[d["id"]] = text
210
+ qrels[query][d["id"]] = int(rel)
211
+ if len(docs) >= 32:
212
+ docs_count += len(docs)
213
+ docs, vector_size = self.embedding(docs)
214
+ self.init_index(vector_size)
215
+ docStoreConn.insert(docs, self.index_name, self.kb_id)
216
+ docs = []
217
+
218
+ docs, vector_size = self.embedding(docs)
219
+ self.init_index(vector_size)
220
+ docStoreConn.insert(docs, self.index_name, self.kb_id)
221
+ return qrels, texts
222
+
223
+ def save_results(self, qrels, run, texts, dataset, file_path):
224
+ keep_result = []
225
+ run_keys = list(run.keys())
226
+ for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"):
227
+ key = run_keys[run_i]
228
+ keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
229
+ 'ndcg@10': evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")})
230
+ keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
231
+ with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f:
232
+ f.write('## Score For Every Query\n')
233
+ for keep_result_i in keep_result:
234
+ f.write('### query: ' + keep_result_i['query'] + ' ndcg@10:' + str(keep_result_i['ndcg@10']) + '\n')
235
+ scores = [[i[0], i[1]] for i in keep_result_i['run'].items()]
236
+ scores = sorted(scores, key=lambda kk: kk[1])
237
+ for score in scores[:10]:
238
+ f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n')
239
+ json.dump(qrels, open(os.path.join(file_path, dataset + '.qrels.json'), "w+"), indent=2)
240
+ json.dump(run, open(os.path.join(file_path, dataset + '.run.json'), "w+"), indent=2)
241
+ print(os.path.join(file_path, dataset + '_result.md'), 'Saved!')
242
+
243
+ def __call__(self, dataset, file_path, miracl_corpus=''):
244
+ if dataset == "ms_marco_v1.1":
245
+ self.tenant_id = "benchmark_ms_marco_v11"
246
+ self.index_name = search.index_name(self.tenant_id)
247
+ qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
248
+ run = self._get_retrieval(qrels)
249
+ print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
250
+ self.save_results(qrels, run, texts, dataset, file_path)
251
+ if dataset == "trivia_qa":
252
+ self.tenant_id = "benchmark_trivia_qa"
253
+ self.index_name = search.index_name(self.tenant_id)
254
+ qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
255
+ run = self._get_retrieval(qrels)
256
+ print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
257
+ self.save_results(qrels, run, texts, dataset, file_path)
258
+ if dataset == "miracl":
259
+ for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
260
+ 'yo', 'zh']:
261
+ if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang)):
262
+ print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang) + ' not found!')
263
+ continue
264
+ if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels')):
265
+ print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels') + 'not found!')
266
+ continue
267
+ if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics')):
268
+ print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics') + 'not found!')
269
+ continue
270
+ if not os.path.isdir(os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang)):
271
+ print('Directory: ' + os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang) + ' not found!')
272
+ continue
273
+ self.tenant_id = "benchmark_miracl_" + lang
274
+ self.index_name = search.index_name(self.tenant_id)
275
+ self.initialized_index = False
276
+ qrels, texts = self.miracl_index(os.path.join(file_path, 'miracl-v1.0-' + lang),
277
+ os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
278
+ "benchmark_miracl_" + lang)
279
+ run = self._get_retrieval(qrels)
280
+ print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
281
+ self.save_results(qrels, run, texts, dataset, file_path)
282
+
283
+
284
+ if __name__ == '__main__':
285
+ print('*****************RAGFlow Benchmark*****************')
286
+ parser = argparse.ArgumentParser(usage="benchmark.py <max_docs> <kb_id> <dataset> <dataset_path> [<miracl_corpus_path>]", description='RAGFlow Benchmark')
287
+ parser.add_argument('max_docs', metavar='max_docs', type=int, help='max docs to evaluate')
288
+ parser.add_argument('kb_id', metavar='kb_id', help='knowledgebase id')
289
+ parser.add_argument('dataset', metavar='dataset', help='dataset name, shall be one of ms_marco_v1.1(https://huggingface.co/datasets/microsoft/ms_marco), trivia_qa(https://huggingface.co/datasets/mandarjoshi/trivia_qa), miracl(https://huggingface.co/datasets/miracl/miracl)')
290
+ parser.add_argument('dataset_path', metavar='dataset_path', help='dataset path')
291
+ parser.add_argument('miracl_corpus_path', metavar='miracl_corpus_path', nargs='?', default="", help='miracl corpus path. Only needed when dataset is miracl')
292
+
293
+ args = parser.parse_args()
294
+ max_docs = args.max_docs
295
+ kb_id = args.kb_id
296
+ ex = Benchmark(kb_id)
297
+
298
+ dataset = args.dataset
299
+ dataset_path = args.dataset_path
300
+
301
+ if dataset == "ms_marco_v1.1" or dataset == "trivia_qa":
302
+ ex(dataset, dataset_path)
303
+ elif dataset == "miracl":
304
+ if not args.miracl_corpus_path:
305
+ print('Please provide miracl_corpus_path when dataset is miracl!')
306
+ exit(1)
307
+ miracl_corpus_path = args.miracl_corpus_path
308
+ ex(dataset, dataset_path, miracl_corpus=miracl_corpus_path)
309
+ else:
310
+ print("Dataset: ", dataset, "not supported!")
rag/nlp/__init__.py CHANGED
@@ -25,6 +25,7 @@ import roman_numbers as r
25
  from word2number import w2n
26
  from cn2an import cn2an
27
  from PIL import Image
 
28
 
29
  all_codecs = [
30
  'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
@@ -51,12 +52,12 @@ def find_codec(blob):
51
  try:
52
  blob[:1024].decode(c)
53
  return c
54
- except Exception as e:
55
  pass
56
  try:
57
  blob.decode(c)
58
  return c
59
- except Exception as e:
60
  pass
61
 
62
  return "utf-8"
@@ -241,7 +242,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
241
  d["image"], poss = pdf_parser.crop(ck, need_position=True)
242
  add_positions(d, poss)
243
  ck = pdf_parser.remove_tag(ck)
244
- except NotImplementedError as e:
245
  pass
246
  tokenize(d, ck, eng)
247
  res.append(d)
@@ -289,13 +290,16 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
289
  def add_positions(d, poss):
290
  if not poss:
291
  return
292
- d["page_num_int"] = []
293
- d["position_int"] = []
294
- d["top_int"] = []
295
  for pn, left, right, top, bottom in poss:
296
- d["page_num_int"].append(int(pn + 1))
297
- d["top_int"].append(int(top))
298
- d["position_int"].append((int(pn + 1), int(left), int(right), int(top), int(bottom)))
 
 
 
299
 
300
 
301
  def remove_contents_table(sections, eng=False):
 
25
  from word2number import w2n
26
  from cn2an import cn2an
27
  from PIL import Image
28
+ import json
29
 
30
  all_codecs = [
31
  'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
 
52
  try:
53
  blob[:1024].decode(c)
54
  return c
55
+ except Exception:
56
  pass
57
  try:
58
  blob.decode(c)
59
  return c
60
+ except Exception:
61
  pass
62
 
63
  return "utf-8"
 
242
  d["image"], poss = pdf_parser.crop(ck, need_position=True)
243
  add_positions(d, poss)
244
  ck = pdf_parser.remove_tag(ck)
245
+ except NotImplementedError:
246
  pass
247
  tokenize(d, ck, eng)
248
  res.append(d)
 
290
  def add_positions(d, poss):
291
  if not poss:
292
  return
293
+ page_num_list = []
294
+ position_list = []
295
+ top_list = []
296
  for pn, left, right, top, bottom in poss:
297
+ page_num_list.append(int(pn + 1))
298
+ top_list.append(int(top))
299
+ position_list.append((int(pn + 1), int(left), int(right), int(top), int(bottom)))
300
+ d["page_num_list"] = json.dumps(page_num_list)
301
+ d["position_list"] = json.dumps(position_list)
302
+ d["top_list"] = json.dumps(top_list)
303
 
304
 
305
  def remove_contents_table(sections, eng=False):
rag/nlp/query.py CHANGED
@@ -15,20 +15,25 @@
15
  #
16
 
17
  import json
18
- import math
19
  import re
20
  import logging
21
- import copy
22
- from elasticsearch_dsl import Q
23
 
24
  from rag.nlp import rag_tokenizer, term_weight, synonym
25
 
26
- class EsQueryer:
27
- def __init__(self, es):
 
28
  self.tw = term_weight.Dealer()
29
- self.es = es
30
  self.syn = synonym.Dealer()
31
- self.flds = ["ask_tks^10", "ask_small_tks"]
 
 
 
 
 
 
 
32
 
33
  @staticmethod
34
  def subSpecialChar(line):
@@ -43,12 +48,15 @@ class EsQueryer:
43
  for t in arr:
44
  if not re.match(r"[a-zA-Z]+$", t):
45
  e += 1
46
- return e * 1. / len(arr) >= 0.7
47
 
48
  @staticmethod
49
  def rmWWW(txt):
50
  patts = [
51
- (r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
 
 
 
52
  (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
53
  (r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of) ", " ")
54
  ]
@@ -56,16 +64,16 @@ class EsQueryer:
56
  txt = re.sub(r, p, txt, flags=re.IGNORECASE)
57
  return txt
58
 
59
- def question(self, txt, tbl="qa", min_match="60%"):
60
  txt = re.sub(
61
  r"[ :\r\n\t,,。??/`!!&\^%%]+",
62
  " ",
63
- rag_tokenizer.tradi2simp(
64
- rag_tokenizer.strQ2B(
65
- txt.lower()))).strip()
66
 
67
  if not self.isChinese(txt):
68
- txt = EsQueryer.rmWWW(txt)
69
  tks = rag_tokenizer.tokenize(txt).split(" ")
70
  tks_w = self.tw.weights(tks)
71
  tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
@@ -73,14 +81,20 @@ class EsQueryer:
73
  tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
74
  q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
75
  for i in range(1, len(tks_w)):
76
- q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
 
 
 
 
 
 
 
77
  if not q:
78
  q.append(txt)
79
- return Q("bool",
80
- must=Q("query_string", fields=self.flds,
81
- type="best_fields", query=" ".join(q),
82
- boost=1)#, minimum_should_match=min_match)
83
- ), list(set([t for t in txt.split(" ") if t]))
84
 
85
  def need_fine_grained_tokenize(tk):
86
  if len(tk) < 3:
@@ -89,7 +103,7 @@ class EsQueryer:
89
  return False
90
  return True
91
 
92
- txt = EsQueryer.rmWWW(txt)
93
  qs, keywords = [], []
94
  for tt in self.tw.split(txt)[:256]: # .split(" "):
95
  if not tt:
@@ -101,65 +115,71 @@ class EsQueryer:
101
  logging.info(json.dumps(twts, ensure_ascii=False))
102
  tms = []
103
  for tk, w in sorted(twts, key=lambda x: x[1] * -1):
104
- sm = rag_tokenizer.fine_grained_tokenize(tk).split(" ") if need_fine_grained_tokenize(tk) else []
 
 
 
 
105
  sm = [
106
  re.sub(
107
  r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
108
  "",
109
- m) for m in sm]
110
- sm = [EsQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
 
 
 
111
  sm = [m for m in sm if len(m) > 1]
112
 
113
  keywords.append(re.sub(r"[ \\\"']+", "", tk))
114
  keywords.extend(sm)
115
- if len(keywords) >= 12: break
 
116
 
117
  tk_syns = self.syn.lookup(tk)
118
- tk = EsQueryer.subSpecialChar(tk)
119
  if tk.find(" ") > 0:
120
- tk = "\"%s\"" % tk
121
  if tk_syns:
122
  tk = f"({tk} %s)" % " ".join(tk_syns)
123
  if sm:
124
- tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % (
125
- " ".join(sm), " ".join(sm))
126
  if tk.strip():
127
  tms.append((tk, w))
128
 
129
  tms = " ".join([f"({t})^{w}" for t, w in tms])
130
 
131
  if len(twts) > 1:
132
- tms += f" (\"%s\"~4)^1.5" % (" ".join([t for t, _ in twts]))
133
  if re.match(r"[0-9a-z ]+$", tt):
134
- tms = f"(\"{tt}\" OR \"%s\")" % rag_tokenizer.tokenize(tt)
135
 
136
  syns = " OR ".join(
137
- ["\"%s\"^0.7" % EsQueryer.subSpecialChar(rag_tokenizer.tokenize(s)) for s in syns])
 
 
 
 
 
138
  if syns:
139
  tms = f"({tms})^5 OR ({syns})^0.7"
140
 
141
  qs.append(tms)
142
 
143
- flds = copy.deepcopy(self.flds)
144
- mst = []
145
  if qs:
146
- mst.append(
147
- Q("query_string", fields=flds, type="best_fields",
148
- query=" OR ".join([f"({t})" for t in qs if t]), boost=1, minimum_should_match=min_match)
149
- )
150
-
151
- return Q("bool",
152
- must=mst,
153
- ), list(set(keywords))
154
 
155
- def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3,
156
- vtweight=0.7):
157
  from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
158
  import numpy as np
 
159
  sims = CosineSimilarity([avec], bvecs)
160
  tksim = self.token_similarity(atks, btkss)
161
- return np.array(sims[0]) * vtweight + \
162
- np.array(tksim) * tkweight, tksim, sims[0]
163
 
164
  def token_similarity(self, atks, btkss):
165
  def toDict(tks):
 
15
  #
16
 
17
  import json
 
18
  import re
19
  import logging
20
+ from rag.utils.doc_store_conn import MatchTextExpr
 
21
 
22
  from rag.nlp import rag_tokenizer, term_weight, synonym
23
 
24
+
25
+ class FulltextQueryer:
26
+ def __init__(self):
27
  self.tw = term_weight.Dealer()
 
28
  self.syn = synonym.Dealer()
29
+ self.query_fields = [
30
+ "title_tks^10",
31
+ "title_sm_tks^5",
32
+ "important_kwd^30",
33
+ "important_tks^20",
34
+ "content_ltks^2",
35
+ "content_sm_ltks",
36
+ ]
37
 
38
  @staticmethod
39
  def subSpecialChar(line):
 
48
  for t in arr:
49
  if not re.match(r"[a-zA-Z]+$", t):
50
  e += 1
51
+ return e * 1.0 / len(arr) >= 0.7
52
 
53
  @staticmethod
54
  def rmWWW(txt):
55
  patts = [
56
+ (
57
+ r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*",
58
+ "",
59
+ ),
60
  (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
61
  (r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of) ", " ")
62
  ]
 
64
  txt = re.sub(r, p, txt, flags=re.IGNORECASE)
65
  return txt
66
 
67
+ def question(self, txt, tbl="qa", min_match:float=0.6):
68
  txt = re.sub(
69
  r"[ :\r\n\t,,。??/`!!&\^%%]+",
70
  " ",
71
+ rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
72
+ ).strip()
73
+ txt = FulltextQueryer.rmWWW(txt)
74
 
75
  if not self.isChinese(txt):
76
+ txt = FulltextQueryer.rmWWW(txt)
77
  tks = rag_tokenizer.tokenize(txt).split(" ")
78
  tks_w = self.tw.weights(tks)
79
  tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
 
81
  tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
82
  q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
83
  for i in range(1, len(tks_w)):
84
+ q.append(
85
+ '"%s %s"^%.4f'
86
+ % (
87
+ tks_w[i - 1][0],
88
+ tks_w[i][0],
89
+ max(tks_w[i - 1][1], tks_w[i][1]) * 2,
90
+ )
91
+ )
92
  if not q:
93
  q.append(txt)
94
+ query = " ".join(q)
95
+ return MatchTextExpr(
96
+ self.query_fields, query, 100
97
+ ), tks
 
98
 
99
  def need_fine_grained_tokenize(tk):
100
  if len(tk) < 3:
 
103
  return False
104
  return True
105
 
106
+ txt = FulltextQueryer.rmWWW(txt)
107
  qs, keywords = [], []
108
  for tt in self.tw.split(txt)[:256]: # .split(" "):
109
  if not tt:
 
115
  logging.info(json.dumps(twts, ensure_ascii=False))
116
  tms = []
117
  for tk, w in sorted(twts, key=lambda x: x[1] * -1):
118
+ sm = (
119
+ rag_tokenizer.fine_grained_tokenize(tk).split(" ")
120
+ if need_fine_grained_tokenize(tk)
121
+ else []
122
+ )
123
  sm = [
124
  re.sub(
125
  r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
126
  "",
127
+ m,
128
+ )
129
+ for m in sm
130
+ ]
131
+ sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
132
  sm = [m for m in sm if len(m) > 1]
133
 
134
  keywords.append(re.sub(r"[ \\\"']+", "", tk))
135
  keywords.extend(sm)
136
+ if len(keywords) >= 12:
137
+ break
138
 
139
  tk_syns = self.syn.lookup(tk)
140
+ tk = FulltextQueryer.subSpecialChar(tk)
141
  if tk.find(" ") > 0:
142
+ tk = '"%s"' % tk
143
  if tk_syns:
144
  tk = f"({tk} %s)" % " ".join(tk_syns)
145
  if sm:
146
+ tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
 
147
  if tk.strip():
148
  tms.append((tk, w))
149
 
150
  tms = " ".join([f"({t})^{w}" for t, w in tms])
151
 
152
  if len(twts) > 1:
153
+ tms += ' ("%s"~4)^1.5' % (" ".join([t for t, _ in twts]))
154
  if re.match(r"[0-9a-z ]+$", tt):
155
+ tms = f'("{tt}" OR "%s")' % rag_tokenizer.tokenize(tt)
156
 
157
  syns = " OR ".join(
158
+ [
159
+ '"%s"^0.7'
160
+ % FulltextQueryer.subSpecialChar(rag_tokenizer.tokenize(s))
161
+ for s in syns
162
+ ]
163
+ )
164
  if syns:
165
  tms = f"({tms})^5 OR ({syns})^0.7"
166
 
167
  qs.append(tms)
168
 
 
 
169
  if qs:
170
+ query = " OR ".join([f"({t})" for t in qs if t])
171
+ return MatchTextExpr(
172
+ self.query_fields, query, 100, {"minimum_should_match": min_match}
173
+ ), keywords
174
+ return None, keywords
 
 
 
175
 
176
+ def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
 
177
  from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
178
  import numpy as np
179
+
180
  sims = CosineSimilarity([avec], bvecs)
181
  tksim = self.token_similarity(atks, btkss)
182
+ return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
 
183
 
184
  def token_similarity(self, atks, btkss):
185
  def toDict(tks):
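
A short usage sketch of the new FulltextQueryer (the question text is only an example; constructing the class still loads the term-weight and synonym dealers):

```python
from rag.nlp.query import FulltextQueryer

qryr = FulltextQueryer()
match_expr, keywords = qryr.question("what is retrieval augmented generation", min_match=0.3)

if match_expr is not None:
    # MatchTextExpr carries the boosted query_fields, the weighted query string and topn=100;
    # for Chinese questions it additionally carries {"minimum_should_match": min_match}.
    print(match_expr.fields)
    print(match_expr.matching_text)
    print(match_expr.extra_options)
print(keywords)
```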
rag/nlp/search.py CHANGED
@@ -14,34 +14,25 @@
14
  # limitations under the License.
15
  #
16
 
17
- import json
18
  import re
19
- from copy import deepcopy
20
-
21
- from elasticsearch_dsl import Q, Search
22
  from typing import List, Optional, Dict, Union
23
  from dataclasses import dataclass
24
 
25
- from rag.settings import es_logger
26
  from rag.utils import rmSpace
27
- from rag.nlp import rag_tokenizer, query, is_english
28
  import numpy as np
 
29
 
30
 
31
  def index_name(uid): return f"ragflow_{uid}"
32
 
33
 
34
  class Dealer:
35
- def __init__(self, es):
36
- self.qryr = query.EsQueryer(es)
37
- self.qryr.flds = [
38
- "title_tks^10",
39
- "title_sm_tks^5",
40
- "important_kwd^30",
41
- "important_tks^20",
42
- "content_ltks^2",
43
- "content_sm_ltks"]
44
- self.es = es
45
 
46
  @dataclass
47
  class SearchResult:
@@ -54,170 +45,99 @@ class Dealer:
54
  keywords: Optional[List[str]] = None
55
  group_docs: List[List] = None
56
 
57
- def _vector(self, txt, emb_mdl, sim=0.8, topk=10):
58
- qv, c = emb_mdl.encode_queries(txt)
59
- return {
60
- "field": "q_%d_vec" % len(qv),
61
- "k": topk,
62
- "similarity": sim,
63
- "num_candidates": topk * 2,
64
- "query_vector": [float(v) for v in qv]
65
- }
66
-
67
- def _add_filters(self, bqry, req):
68
- if req.get("kb_ids"):
69
- bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
70
- if req.get("doc_ids"):
71
- bqry.filter.append(Q("terms", doc_id=req["doc_ids"]))
72
- if req.get("knowledge_graph_kwd"):
73
- bqry.filter.append(Q("terms", knowledge_graph_kwd=req["knowledge_graph_kwd"]))
74
- if "available_int" in req:
75
- if req["available_int"] == 0:
76
- bqry.filter.append(Q("range", available_int={"lt": 1}))
77
- else:
78
- bqry.filter.append(
79
- Q("bool", must_not=Q("range", available_int={"lt": 1})))
80
- return bqry
81
-
82
- def search(self, req, idxnms, emb_mdl=None, highlight=False):
83
- qst = req.get("question", "")
84
- bqry, keywords = self.qryr.question(qst, min_match="30%")
85
- bqry = self._add_filters(bqry, req)
86
- bqry.boost = 0.05
87
 
88
- s = Search()
89
  pg = int(req.get("page", 1)) - 1
90
  topk = int(req.get("topk", 1024))
91
  ps = int(req.get("size", topk))
 
 
92
  src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
93
- "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "knowledge_graph_kwd",
94
- "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])
 
95
 
96
- s = s.query(bqry)[pg * ps:(pg + 1) * ps]
97
- s = s.highlight("content_ltks")
98
- s = s.highlight("title_ltks")
99
  if not qst:
100
- if not req.get("sort"):
101
- s = s.sort(
102
- #{"create_time": {"order": "desc", "unmapped_type": "date"}},
103
- {"create_timestamp_flt": {
104
- "order": "desc", "unmapped_type": "float"}}
105
- )
 
 
 
 
 
 
 
106
  else:
107
- s = s.sort(
108
- {"page_num_int": {"order": "asc", "unmapped_type": "float",
109
- "mode": "avg", "numeric_type": "double"}},
110
- {"top_int": {"order": "asc", "unmapped_type": "float",
111
- "mode": "avg", "numeric_type": "double"}},
112
- #{"create_time": {"order": "desc", "unmapped_type": "date"}},
113
- {"create_timestamp_flt": {
114
- "order": "desc", "unmapped_type": "float"}}
115
- )
116
-
117
- if qst:
118
- s = s.highlight_options(
119
- fragment_size=120,
120
- number_of_fragments=5,
121
- boundary_scanner_locale="zh-CN",
122
- boundary_scanner="SENTENCE",
123
- boundary_chars=",./;:\\!(),。?:!……()——、"
124
- )
125
- s = s.to_dict()
126
- q_vec = []
127
- if req.get("vector"):
128
- assert emb_mdl, "No embedding model selected"
129
- s["knn"] = self._vector(
130
- qst, emb_mdl, req.get(
131
- "similarity", 0.1), topk)
132
- s["knn"]["filter"] = bqry.to_dict()
133
- if not highlight and "highlight" in s:
134
- del s["highlight"]
135
- q_vec = s["knn"]["query_vector"]
136
- es_logger.info("【Q】: {}".format(json.dumps(s)))
137
- res = self.es.search(deepcopy(s), idxnms=idxnms, timeout="600s", src=src)
138
- es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
139
- if self.es.getTotal(res) == 0 and "knn" in s:
140
- bqry, _ = self.qryr.question(qst, min_match="10%")
141
- if req.get("doc_ids"):
142
- bqry = Q("bool", must=[])
143
- bqry = self._add_filters(bqry, req)
144
- s["query"] = bqry.to_dict()
145
- s["knn"]["filter"] = bqry.to_dict()
146
- s["knn"]["similarity"] = 0.17
147
- res = self.es.search(s, idxnms=idxnms, timeout="600s", src=src)
148
- es_logger.info("【Q】: {}".format(json.dumps(s)))
149
-
150
- kwds = set([])
151
- for k in keywords:
152
- kwds.add(k)
153
- for kk in rag_tokenizer.fine_grained_tokenize(k).split(" "):
154
- if len(kk) < 2:
155
- continue
156
- if kk in kwds:
157
- continue
158
- kwds.add(kk)
159
-
160
- aggs = self.getAggregation(res, "docnm_kwd")
161
-
162
  return self.SearchResult(
163
- total=self.es.getTotal(res),
164
- ids=self.es.getDocIds(res),
165
  query_vector=q_vec,
166
  aggregation=aggs,
167
- highlight=self.getHighlight(res, keywords, "content_with_weight"),
168
- field=self.getFields(res, src),
169
- keywords=list(kwds)
170
  )
171
 
172
- def getAggregation(self, res, g):
173
- if not "aggregations" in res or "aggs_" + g not in res["aggregations"]:
174
- return
175
- bkts = res["aggregations"]["aggs_" + g]["buckets"]
176
- return [(b["key"], b["doc_count"]) for b in bkts]
177
-
178
- def getHighlight(self, res, keywords, fieldnm):
179
- ans = {}
180
- for d in res["hits"]["hits"]:
181
- hlts = d.get("highlight")
182
- if not hlts:
183
- continue
184
- txt = "...".join([a for a in list(hlts.items())[0][1]])
185
- if not is_english(txt.split(" ")):
186
- ans[d["_id"]] = txt
187
- continue
188
-
189
- txt = d["_source"][fieldnm]
190
- txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
191
- txts = []
192
- for t in re.split(r"[.?!;\n]", txt):
193
- for w in keywords:
194
- t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", t, flags=re.IGNORECASE|re.MULTILINE)
195
- if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE): continue
196
- txts.append(t)
197
- ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
198
-
199
- return ans
200
-
201
- def getFields(self, sres, flds):
202
- res = {}
203
- if not flds:
204
- return {}
205
- for d in self.es.getSource(sres):
206
- m = {n: d.get(n) for n in flds if d.get(n) is not None}
207
- for n, v in m.items():
208
- if isinstance(v, type([])):
209
- m[n] = "\t".join([str(vv) if not isinstance(
210
- vv, list) else "\t".join([str(vvv) for vvv in vv]) for vv in v])
211
- continue
212
- if not isinstance(v, type("")):
213
- m[n] = str(m[n])
214
- #if n.find("tks") > 0:
215
- # m[n] = rmSpace(m[n])
216
-
217
- if m:
218
- res[d["id"]] = m
219
- return res
220
-
221
  @staticmethod
222
  def trans2floats(txt):
223
  return [float(t) for t in txt.split("\t")]
@@ -260,7 +180,7 @@ class Dealer:
260
  continue
261
  idx.append(i)
262
  pieces_.append(t)
263
- es_logger.info("{} => {}".format(answer, pieces_))
264
  if not pieces_:
265
  return answer, set([])
266
 
@@ -281,7 +201,7 @@ class Dealer:
281
  chunks_tks,
282
  tkweight, vtweight)
283
  mx = np.max(sim) * 0.99
284
- es_logger.info("{} SIM: {}".format(pieces_[i], mx))
285
  if mx < thr:
286
  continue
287
  cites[idx[i]] = list(
@@ -309,9 +229,15 @@ class Dealer:
309
  def rerank(self, sres, query, tkweight=0.3,
310
  vtweight=0.7, cfield="content_ltks"):
311
  _, keywords = self.qryr.question(query)
312
- ins_embd = [
313
- Dealer.trans2floats(
314
- sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids]
 
 
 
 
 
 
315
  if not ins_embd:
316
  return [], [], []
317
 
@@ -377,7 +303,7 @@ class Dealer:
377
  if isinstance(tenant_ids, str):
378
  tenant_ids = tenant_ids.split(",")
379
 
380
- sres = self.search(req, [index_name(tid) for tid in tenant_ids], embd_mdl, highlight)
381
  ranks["total"] = sres.total
382
 
383
  if page <= RERANK_PAGE_LIMIT:
@@ -393,6 +319,8 @@ class Dealer:
393
  idx = list(range(len(sres.ids)))
394
 
395
  dim = len(sres.query_vector)
 
 
396
  for i in idx:
397
  if sim[i] < similarity_threshold:
398
  break
@@ -401,34 +329,32 @@ class Dealer:
401
  continue
402
  break
403
  id = sres.ids[i]
404
- dnm = sres.field[id]["docnm_kwd"]
405
- did = sres.field[id]["doc_id"]
 
 
 
 
406
  d = {
407
  "chunk_id": id,
408
- "content_ltks": sres.field[id]["content_ltks"],
409
- "content_with_weight": sres.field[id]["content_with_weight"],
410
- "doc_id": sres.field[id]["doc_id"],
411
  "docnm_kwd": dnm,
412
- "kb_id": sres.field[id]["kb_id"],
413
- "important_kwd": sres.field[id].get("important_kwd", []),
414
- "img_id": sres.field[id].get("img_id", ""),
415
  "similarity": sim[i],
416
  "vector_similarity": vsim[i],
417
  "term_similarity": tsim[i],
418
- "vector": self.trans2floats(sres.field[id].get("q_%d_vec" % dim, "\t".join(["0"] * dim))),
419
- "positions": sres.field[id].get("position_int", "").split("\t")
420
  }
421
  if highlight:
422
  if id in sres.highlight:
423
  d["highlight"] = rmSpace(sres.highlight[id])
424
  else:
425
  d["highlight"] = d["content_with_weight"]
426
- if len(d["positions"]) % 5 == 0:
427
- poss = []
428
- for i in range(0, len(d["positions"]), 5):
429
- poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
430
- float(d["positions"][i + 3]), float(d["positions"][i + 4])])
431
- d["positions"] = poss
432
  ranks["chunks"].append(d)
433
  if dnm not in ranks["doc_aggs"]:
434
  ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
@@ -442,39 +368,11 @@ class Dealer:
442
  return ranks
443
 
444
  def sql_retrieval(self, sql, fetch_size=128, format="json"):
445
- from api.settings import chat_logger
446
- sql = re.sub(r"[ `]+", " ", sql)
447
- sql = sql.replace("%", "")
448
- es_logger.info(f"Get es sql: {sql}")
449
- replaces = []
450
- for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
451
- fld, v = r.group(1), r.group(3)
452
- match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
453
- fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
454
- replaces.append(
455
- ("{}{}'{}'".format(
456
- r.group(1),
457
- r.group(2),
458
- r.group(3)),
459
- match))
460
-
461
- for p, r in replaces:
462
- sql = sql.replace(p, r, 1)
463
- chat_logger.info(f"To es: {sql}")
464
-
465
- try:
466
- tbl = self.es.sql(sql, fetch_size, format)
467
- return tbl
468
- except Exception as e:
469
- chat_logger.error(f"SQL failure: {sql} =>" + str(e))
470
- return {"error": str(e)}
471
-
472
- def chunk_list(self, doc_id, tenant_id, max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]):
473
- s = Search()
474
- s = s.query(Q("match", doc_id=doc_id))[0:max_count]
475
- s = s.to_dict()
476
- es_res = self.es.search(s, idxnms=index_name(tenant_id), timeout="600s", src=fields)
477
- res = []
478
- for index, chunk in enumerate(es_res['hits']['hits']):
479
- res.append({fld: chunk['_source'].get(fld) for fld in fields})
480
- return res
 
14
  # limitations under the License.
15
  #
16
 
 
17
  import re
18
+ import json
 
 
19
  from typing import List, Optional, Dict, Union
20
  from dataclasses import dataclass
21
 
22
+ from rag.settings import doc_store_logger
23
  from rag.utils import rmSpace
24
+ from rag.nlp import rag_tokenizer, query
25
  import numpy as np
26
+ from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
27
 
28
 
29
  def index_name(uid): return f"ragflow_{uid}"
30
 
31
 
32
  class Dealer:
33
+ def __init__(self, dataStore: DocStoreConnection):
34
+ self.qryr = query.FulltextQueryer()
35
+ self.dataStore = dataStore
 
 
 
 
 
 
 
36
 
37
  @dataclass
38
  class SearchResult:
 
45
  keywords: Optional[List[str]] = None
46
  group_docs: List[List] = None
47
 
48
+ def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
49
+ qv, _ = emb_mdl.encode_queries(txt)
50
+ embedding_data = [float(v) for v in qv]
51
+ vector_column_name = f"q_{len(embedding_data)}_vec"
52
+ return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
53
+
54
+ def get_filters(self, req):
55
+ condition = dict()
56
+ for key, field in {"kb_ids": "kb_id", "doc_ids": "doc_id"}.items():
57
+ if key in req and req[key] is not None:
58
+ condition[field] = req[key]
59
+ # TODO(yzc): `available_int` is nullable; however, Infinity doesn't support nullable columns.
60
+ for key in ["knowledge_graph_kwd"]:
61
+ if key in req and req[key] is not None:
62
+ condition[key] = req[key]
63
+ return condition
64
+
65
+ def search(self, req, idx_names: list[str], kb_ids: list[str], emb_mdl=None, highlight=False):
66
+ filters = self.get_filters(req)
67
+ orderBy = OrderByExpr()
 
 
 
 
 
 
 
 
 
 
68
 
 
69
  pg = int(req.get("page", 1)) - 1
70
  topk = int(req.get("topk", 1024))
71
  ps = int(req.get("size", topk))
72
+ offset, limit = pg * ps, (pg + 1) * ps
73
+
74
  src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
75
+ "doc_id", "position_list", "knowledge_graph_kwd",
76
+ "available_int", "content_with_weight"])
77
+ kwds = set([])
78
 
79
+ qst = req.get("question", "")
80
+ q_vec = []
 
81
  if not qst:
82
+ if req.get("sort"):
83
+ orderBy.desc("create_timestamp_flt")
84
+ res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
85
+ total = self.dataStore.getTotal(res)
86
+ doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
87
+ else:
88
+ highlightFields = ["content_ltks", "title_tks"] if highlight else []
89
+ matchText, keywords = self.qryr.question(qst, min_match=0.3)
90
+ if emb_mdl is None:
91
+ matchExprs = [matchText]
92
+ res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
93
+ total = self.dataStore.getTotal(res)
94
+ doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
95
  else:
96
+ matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
97
+ q_vec = matchDense.embedding_data
98
+ src.append(f"q_{len(q_vec)}_vec")
99
+
100
+ fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
101
+ matchExprs = [matchText, matchDense, fusionExpr]
102
+
103
+ res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
104
+ total = self.dataStore.getTotal(res)
105
+ doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
106
+
107
+ # If result is empty, try again with lower min_match
108
+ if total == 0:
109
+ matchText, _ = self.qryr.question(qst, min_match=0.1)
110
+ if "doc_ids" in filters:
111
+ del filters["doc_ids"]
112
+ matchDense.extra_options["similarity"] = 0.17
113
+ res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idx_names, kb_ids)
114
+ total = self.dataStore.getTotal(res)
115
+ doc_store_logger.info("Dealer.search 2 TOTAL: {}".format(total))
116
+
117
+ for k in keywords:
118
+ kwds.add(k)
119
+ for kk in rag_tokenizer.fine_grained_tokenize(k).split(" "):
120
+ if len(kk) < 2:
121
+ continue
122
+ if kk in kwds:
123
+ continue
124
+ kwds.add(kk)
125
+
126
+ doc_store_logger.info(f"TOTAL: {total}")
127
+ ids = self.dataStore.getChunkIds(res)
128
+ keywords = list(kwds)
129
+ highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
130
+ aggs = self.dataStore.getAggregation(res, "docnm_kwd")
 
 
 
 
 
131
  return self.SearchResult(
132
+ total=total,
133
+ ids=ids,
134
  query_vector=q_vec,
135
  aggregation=aggs,
136
+ highlight=highlight,
137
+ field=self.dataStore.getFields(res, src),
138
+ keywords=keywords
139
  )
140
 
 
 
 
 
 
 
141
  @staticmethod
142
  def trans2floats(txt):
143
  return [float(t) for t in txt.split("\t")]
 
180
  continue
181
  idx.append(i)
182
  pieces_.append(t)
183
+ doc_store_logger.info("{} => {}".format(answer, pieces_))
184
  if not pieces_:
185
  return answer, set([])
186
 
 
201
  chunks_tks,
202
  tkweight, vtweight)
203
  mx = np.max(sim) * 0.99
204
+ doc_store_logger.info("{} SIM: {}".format(pieces_[i], mx))
205
  if mx < thr:
206
  continue
207
  cites[idx[i]] = list(
 
229
  def rerank(self, sres, query, tkweight=0.3,
230
  vtweight=0.7, cfield="content_ltks"):
231
  _, keywords = self.qryr.question(query)
232
+ vector_size = len(sres.query_vector)
233
+ vector_column = f"q_{vector_size}_vec"
234
+ zero_vector = [0.0] * vector_size
235
+ ins_embd = []
236
+ for chunk_id in sres.ids:
237
+ vector = sres.field[chunk_id].get(vector_column, zero_vector)
238
+ if isinstance(vector, str):
239
+ vector = [float(v) for v in vector.split("\t")]
240
+ ins_embd.append(vector)
241
  if not ins_embd:
242
  return [], [], []
243
 
 
303
  if isinstance(tenant_ids, str):
304
  tenant_ids = tenant_ids.split(",")
305
 
306
+ sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight)
307
  ranks["total"] = sres.total
308
 
309
  if page <= RERANK_PAGE_LIMIT:
 
319
  idx = list(range(len(sres.ids)))
320
 
321
  dim = len(sres.query_vector)
322
+ vector_column = f"q_{dim}_vec"
323
+ zero_vector = [0.0] * dim
324
  for i in idx:
325
  if sim[i] < similarity_threshold:
326
  break
 
329
  continue
330
  break
331
  id = sres.ids[i]
332
+ chunk = sres.field[id]
333
+ dnm = chunk["docnm_kwd"]
334
+ did = chunk["doc_id"]
335
+ position_list = chunk.get("position_list", "[]")
336
+ if not position_list:
337
+ position_list = "[]"
338
  d = {
339
  "chunk_id": id,
340
+ "content_ltks": chunk["content_ltks"],
341
+ "content_with_weight": chunk["content_with_weight"],
342
+ "doc_id": chunk["doc_id"],
343
  "docnm_kwd": dnm,
344
+ "kb_id": chunk["kb_id"],
345
+ "important_kwd": chunk.get("important_kwd", []),
346
+ "image_id": chunk.get("img_id", ""),
347
  "similarity": sim[i],
348
  "vector_similarity": vsim[i],
349
  "term_similarity": tsim[i],
350
+ "vector": chunk.get(vector_column, zero_vector),
351
+ "positions": json.loads(position_list)
352
  }
353
  if highlight:
354
  if id in sres.highlight:
355
  d["highlight"] = rmSpace(sres.highlight[id])
356
  else:
357
  d["highlight"] = d["content_with_weight"]
 
 
 
 
 
 
358
  ranks["chunks"].append(d)
359
  if dnm not in ranks["doc_aggs"]:
360
  ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
 
368
  return ranks
369
 
370
  def sql_retrieval(self, sql, fetch_size=128, format="json"):
371
+ tbl = self.dataStore.sql(sql, fetch_size, format)
372
+ return tbl
373
+
374
+ def chunk_list(self, doc_id: str, tenant_id: str, kb_ids: list[str], max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]):
375
+ condition = {"doc_id": doc_id}
376
+ res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), 0, max_count, index_name(tenant_id), kb_ids)
377
+ dict_chunks = self.dataStore.getFields(res, fields)
378
+ return dict_chunks.values()
 
 
 
 
 
 
 
rag/settings.py CHANGED
@@ -25,12 +25,13 @@ RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
25
  SUBPROCESS_STD_LOG_NAME = "std.log"
26
 
27
  ES = get_base_config("es", {})
 
28
  AZURE = get_base_config("azure", {})
29
  S3 = get_base_config("s3", {})
30
  MINIO = decrypt_database_config(name="minio")
31
  try:
32
  REDIS = decrypt_database_config(name="redis")
33
- except Exception as e:
34
  REDIS = {}
35
  pass
36
  DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
@@ -44,7 +45,7 @@ LoggerFactory.set_directory(
44
  # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
45
  LoggerFactory.LEVEL = 30
46
 
47
- es_logger = getLogger("es")
48
  minio_logger = getLogger("minio")
49
  s3_logger = getLogger("s3")
50
  azure_logger = getLogger("azure")
@@ -53,7 +54,7 @@ chunk_logger = getLogger("chunk_logger")
53
  database_logger = getLogger("database")
54
 
55
  formatter = logging.Formatter("%(asctime)-15s %(levelname)-8s (%(process)d) %(message)s")
56
- for logger in [es_logger, minio_logger, s3_logger, azure_logger, cron_logger, chunk_logger, database_logger]:
57
  logger.setLevel(logging.INFO)
58
  for handler in logger.handlers:
59
  handler.setFormatter(fmt=formatter)
 
25
  SUBPROCESS_STD_LOG_NAME = "std.log"
26
 
27
  ES = get_base_config("es", {})
28
+ INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
29
  AZURE = get_base_config("azure", {})
30
  S3 = get_base_config("s3", {})
31
  MINIO = decrypt_database_config(name="minio")
32
  try:
33
  REDIS = decrypt_database_config(name="redis")
34
+ except Exception:
35
  REDIS = {}
36
  pass
37
  DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
 
45
  # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
46
  LoggerFactory.LEVEL = 30
47
 
48
+ doc_store_logger = getLogger("doc_store")
49
  minio_logger = getLogger("minio")
50
  s3_logger = getLogger("s3")
51
  azure_logger = getLogger("azure")
 
54
  database_logger = getLogger("database")
55
 
56
  formatter = logging.Formatter("%(asctime)-15s %(levelname)-8s (%(process)d) %(message)s")
57
+ for logger in [doc_store_logger, minio_logger, s3_logger, azure_logger, cron_logger, chunk_logger, database_logger]:
58
  logger.setLevel(logging.INFO)
59
  for handler in logger.handlers:
60
  handler.setFormatter(fmt=formatter)
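
The INFINITY entry above only adds connection settings; how a concrete backend gets chosen is not part of this hunk. Purely as an illustration, a hypothetical selection helper (the DOC_ENGINE variable, get_doc_store() and the infinity_conn module name are assumptions, not code from this PR):

```python
import os


def get_doc_store():
    """Hypothetical factory: pick the document-store backend from an env var."""
    engine = os.environ.get("DOC_ENGINE", "elasticsearch").lower()  # assumed knob
    if engine == "infinity":
        from rag.utils.infinity_conn import InfinityConnection  # assumed module/class name
        return InfinityConnection()
    from rag.utils.es_conn import ESConnection  # shown in es_conn.py below
    return ESConnection()
```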
rag/svr/task_executor.py CHANGED
@@ -31,7 +31,6 @@ from timeit import default_timer as timer
31
 
32
  import numpy as np
33
  import pandas as pd
34
- from elasticsearch_dsl import Q
35
 
36
  from api.db import LLMType, ParserType
37
  from api.db.services.dialog_service import keyword_extraction, question_proposal
@@ -39,8 +38,7 @@ from api.db.services.document_service import DocumentService
39
  from api.db.services.llm_service import LLMBundle
40
  from api.db.services.task_service import TaskService
41
  from api.db.services.file2document_service import File2DocumentService
42
- from api.settings import retrievaler
43
- from api.utils.file_utils import get_project_base_directory
44
  from api.db.db_models import close_connection
45
  from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
46
  from rag.nlp import search, rag_tokenizer
@@ -48,7 +46,6 @@ from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as
48
  from rag.settings import database_logger, SVR_QUEUE_NAME
49
  from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
50
  from rag.utils import rmSpace, num_tokens_from_string
51
- from rag.utils.es_conn import ELASTICSEARCH
52
  from rag.utils.redis_conn import REDIS_CONN, Payload
53
  from rag.utils.storage_factory import STORAGE_IMPL
54
 
@@ -126,7 +123,7 @@ def collect():
126
  return pd.DataFrame()
127
  tasks = TaskService.get_tasks(msg["id"])
128
  if not tasks:
129
- cron_logger.warn("{} empty task!".format(msg["id"]))
130
  return []
131
 
132
  tasks = pd.DataFrame(tasks)
@@ -187,7 +184,7 @@ def build(row):
187
  docs = []
188
  doc = {
189
  "doc_id": row["doc_id"],
190
- "kb_id": [str(row["kb_id"])]
191
  }
192
  el = 0
193
  for ck in cks:
@@ -196,10 +193,14 @@ def build(row):
196
  md5 = hashlib.md5()
197
  md5.update((ck["content_with_weight"] +
198
  str(d["doc_id"])).encode("utf-8"))
199
- d["_id"] = md5.hexdigest()
200
  d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
201
  d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
202
  if not d.get("image"):
 
 
 
 
203
  docs.append(d)
204
  continue
205
 
@@ -211,13 +212,13 @@ def build(row):
211
  d["image"].save(output_buffer, format='JPEG')
212
 
213
  st = timer()
214
- STORAGE_IMPL.put(row["kb_id"], d["_id"], output_buffer.getvalue())
215
  el += timer() - st
216
  except Exception as e:
217
  cron_logger.error(str(e))
218
  traceback.print_exc()
219
 
220
- d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
221
  del d["image"]
222
  docs.append(d)
223
  cron_logger.info("MINIO PUT({}):{}".format(row["name"], el))
@@ -245,12 +246,9 @@ def build(row):
245
  return docs
246
 
247
 
248
- def init_kb(row):
249
  idxnm = search.index_name(row["tenant_id"])
250
- if ELASTICSEARCH.indexExist(idxnm):
251
- return
252
- return ELASTICSEARCH.createIdx(idxnm, json.load(
253
- open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
254
 
255
 
256
  def embedding(docs, mdl, parser_config=None, callback=None):
@@ -288,17 +286,20 @@ def embedding(docs, mdl, parser_config=None, callback=None):
288
  cnts) if len(tts) == len(cnts) else cnts
289
 
290
  assert len(vects) == len(docs)
 
291
  for i, d in enumerate(docs):
292
  v = vects[i].tolist()
 
293
  d["q_%d_vec" % len(v)] = v
294
- return tk_count
295
 
296
 
297
  def run_raptor(row, chat_mdl, embd_mdl, callback=None):
298
  vts, _ = embd_mdl.encode(["ok"])
299
- vctr_nm = "q_%d_vec" % len(vts[0])
 
300
  chunks = []
301
- for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
302
  chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
303
 
304
  raptor = Raptor(
@@ -323,7 +324,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
323
  d = copy.deepcopy(doc)
324
  md5 = hashlib.md5()
325
  md5.update((content + str(d["doc_id"])).encode("utf-8"))
326
- d["_id"] = md5.hexdigest()
327
  d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
328
  d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
329
  d[vctr_nm] = vctr.tolist()
@@ -332,7 +333,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
332
  d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
333
  res.append(d)
334
  tk_count += num_tokens_from_string(content)
335
- return res, tk_count
336
 
337
 
338
  def main():
@@ -352,7 +353,7 @@ def main():
352
  if r.get("task_type", "") == "raptor":
353
  try:
354
  chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
355
- cks, tk_count = run_raptor(r, chat_mdl, embd_mdl, callback)
356
  except Exception as e:
357
  callback(-1, msg=str(e))
358
  cron_logger.error(str(e))
@@ -373,7 +374,7 @@ def main():
373
  len(cks))
374
  st = timer()
375
  try:
376
- tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
377
  except Exception as e:
378
  callback(-1, "Embedding error:{}".format(str(e)))
379
  cron_logger.error(str(e))
@@ -381,26 +382,25 @@ def main():
381
  cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
382
  callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
383
 
384
- init_kb(r)
385
- chunk_count = len(set([c["_id"] for c in cks]))
 
386
  st = timer()
387
  es_r = ""
388
  es_bulk_size = 4
389
  for b in range(0, len(cks), es_bulk_size):
390
- es_r = ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]))
391
  if b % 128 == 0:
392
  callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
393
 
394
  cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
395
  if es_r:
396
  callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
397
- ELASTICSEARCH.deleteByQuery(
398
- Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
399
- cron_logger.error(str(es_r))
400
  else:
401
  if TaskService.do_cancel(r["id"]):
402
- ELASTICSEARCH.deleteByQuery(
403
- Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
404
  continue
405
  callback(1., "Done!")
406
  DocumentService.increment_chunk_num(
 
31
 
32
  import numpy as np
33
  import pandas as pd
 
34
 
35
  from api.db import LLMType, ParserType
36
  from api.db.services.dialog_service import keyword_extraction, question_proposal
 
38
  from api.db.services.llm_service import LLMBundle
39
  from api.db.services.task_service import TaskService
40
  from api.db.services.file2document_service import File2DocumentService
41
+ from api.settings import retrievaler, docStoreConn
 
42
  from api.db.db_models import close_connection
43
  from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
44
  from rag.nlp import search, rag_tokenizer
 
46
  from rag.settings import database_logger, SVR_QUEUE_NAME
47
  from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
48
  from rag.utils import rmSpace, num_tokens_from_string
 
49
  from rag.utils.redis_conn import REDIS_CONN, Payload
50
  from rag.utils.storage_factory import STORAGE_IMPL
51
 
 
123
  return pd.DataFrame()
124
  tasks = TaskService.get_tasks(msg["id"])
125
  if not tasks:
126
+ cron_logger.warning("{} empty task!".format(msg["id"]))
127
  return []
128
 
129
  tasks = pd.DataFrame(tasks)
 
184
  docs = []
185
  doc = {
186
  "doc_id": row["doc_id"],
187
+ "kb_id": str(row["kb_id"])
188
  }
189
  el = 0
190
  for ck in cks:
 
193
  md5 = hashlib.md5()
194
  md5.update((ck["content_with_weight"] +
195
  str(d["doc_id"])).encode("utf-8"))
196
+ d["id"] = md5.hexdigest()
197
  d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
198
  d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
199
  if not d.get("image"):
200
+ d["img_id"] = ""
201
+ d["page_num_list"] = json.dumps([])
202
+ d["position_list"] = json.dumps([])
203
+ d["top_list"] = json.dumps([])
204
  docs.append(d)
205
  continue
206
 
 
212
  d["image"].save(output_buffer, format='JPEG')
213
 
214
  st = timer()
215
+ STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
216
  el += timer() - st
217
  except Exception as e:
218
  cron_logger.error(str(e))
219
  traceback.print_exc()
220
 
221
+ d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
222
  del d["image"]
223
  docs.append(d)
224
  cron_logger.info("MINIO PUT({}):{}".format(row["name"], el))
 
246
  return docs
247
 
248
 
249
+ def init_kb(row, vector_size: int):
250
  idxnm = search.index_name(row["tenant_id"])
251
+ return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
 
 
 
252
 
253
 
254
  def embedding(docs, mdl, parser_config=None, callback=None):
 
286
  cnts) if len(tts) == len(cnts) else cnts
287
 
288
  assert len(vects) == len(docs)
289
+ vector_size = 0
290
  for i, d in enumerate(docs):
291
  v = vects[i].tolist()
292
+ vector_size = len(v)
293
  d["q_%d_vec" % len(v)] = v
294
+ return tk_count, vector_size
295
 
296
 
297
  def run_raptor(row, chat_mdl, embd_mdl, callback=None):
298
  vts, _ = embd_mdl.encode(["ok"])
299
+ vector_size = len(vts[0])
300
+ vctr_nm = "q_%d_vec" % vector_size
301
  chunks = []
302
+ for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]):
303
  chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
304
 
305
  raptor = Raptor(
 
324
  d = copy.deepcopy(doc)
325
  md5 = hashlib.md5()
326
  md5.update((content + str(d["doc_id"])).encode("utf-8"))
327
+ d["id"] = md5.hexdigest()
328
  d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
329
  d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
330
  d[vctr_nm] = vctr.tolist()
 
333
  d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
334
  res.append(d)
335
  tk_count += num_tokens_from_string(content)
336
+ return res, tk_count, vector_size
337
 
338
 
339
  def main():
 
353
  if r.get("task_type", "") == "raptor":
354
  try:
355
  chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
356
+ cks, tk_count, vector_size = run_raptor(r, chat_mdl, embd_mdl, callback)
357
  except Exception as e:
358
  callback(-1, msg=str(e))
359
  cron_logger.error(str(e))
 
374
  len(cks))
375
  st = timer()
376
  try:
377
+ tk_count, vector_size = embedding(cks, embd_mdl, r["parser_config"], callback)
378
  except Exception as e:
379
  callback(-1, "Embedding error:{}".format(str(e)))
380
  cron_logger.error(str(e))
 
382
  cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
383
  callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
384
 
385
+ # cron_logger.info(f"task_executor init_kb index {search.index_name(r["tenant_id"])} embd_mdl {embd_mdl.llm_name} vector length {vector_size}")
386
+ init_kb(r, vector_size)
387
+ chunk_count = len(set([c["id"] for c in cks]))
388
  st = timer()
389
  es_r = ""
390
  es_bulk_size = 4
391
  for b in range(0, len(cks), es_bulk_size):
392
+ es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
393
  if b % 128 == 0:
394
  callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
395
 
396
  cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
397
  if es_r:
398
  callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
399
+ docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
400
+ cron_logger.error('Insert chunk error: ' + str(es_r))
 
401
  else:
402
  if TaskService.do_cancel(r["id"]):
403
+ docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
 
404
  continue
405
  callback(1., "Done!")
406
  DocumentService.increment_chunk_num(
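
Condensed, the indexing path in main() now reads as below; this is only a restatement of the hunk above with logging and progress callbacks stripped, and the variables (r, cks, embd_mdl, callback, docStoreConn) keep the meaning they have in the executor:

```python
# Sketch of the new flow: embed, size the index, bulk-insert, roll back on error/cancel.
idxnm = search.index_name(r["tenant_id"])
tk_count, vector_size = embedding(cks, embd_mdl, r["parser_config"], callback)
init_kb(r, vector_size)                      # docStoreConn.createIdx(idxnm, kb_id, vector_size)

err = ""
for b in range(0, len(cks), 4):              # same bulk size of 4 used above
    err = docStoreConn.insert(cks[b:b + 4], idxnm, r["kb_id"])

if err or TaskService.do_cancel(r["id"]):
    # remove anything already inserted for this document
    docStoreConn.delete({"doc_id": r["doc_id"]}, idxnm, r["kb_id"])
```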
rag/utils/doc_store_conn.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Optional, Union
3
+ from dataclasses import dataclass
4
+ import numpy as np
5
+ import polars as pl
6
+ from typing import List, Dict
7
+
8
+ DEFAULT_MATCH_VECTOR_TOPN = 10
9
+ DEFAULT_MATCH_SPARSE_TOPN = 10
10
+ VEC = Union[list, np.ndarray]
11
+
12
+
13
+ @dataclass
14
+ class SparseVector:
15
+ indices: list[int]
16
+ values: Union[list[float], list[int], None] = None
17
+
18
+ def __post_init__(self):
19
+ assert (self.values is None) or (len(self.indices) == len(self.values))
20
+
21
+ def to_dict_old(self):
22
+ d = {"indices": self.indices}
23
+ if self.values is not None:
24
+ d["values"] = self.values
25
+ return d
26
+
27
+ def to_dict(self):
28
+ if self.values is None:
29
+ raise ValueError("SparseVector.values is None")
30
+ result = {}
31
+ for i, v in zip(self.indices, self.values):
32
+ result[str(i)] = v
33
+ return result
34
+
35
+ @staticmethod
36
+ def from_dict(d):
37
+ return SparseVector(d["indices"], d.get("values"))
38
+
39
+ def __str__(self):
40
+ return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"
41
+
42
+ def __repr__(self):
43
+ return str(self)
44
+
45
+
46
+ class MatchTextExpr(ABC):
47
+ def __init__(
48
+ self,
49
+ fields: str,
50
+ matching_text: str,
51
+ topn: int,
52
+ extra_options: dict = dict(),
53
+ ):
54
+ self.fields = fields
55
+ self.matching_text = matching_text
56
+ self.topn = topn
57
+ self.extra_options = extra_options
58
+
59
+
60
+ class MatchDenseExpr(ABC):
61
+ def __init__(
62
+ self,
63
+ vector_column_name: str,
64
+ embedding_data: VEC,
65
+ embedding_data_type: str,
66
+ distance_type: str,
67
+ topn: int = DEFAULT_MATCH_VECTOR_TOPN,
68
+ extra_options: dict = dict(),
69
+ ):
70
+ self.vector_column_name = vector_column_name
71
+ self.embedding_data = embedding_data
72
+ self.embedding_data_type = embedding_data_type
73
+ self.distance_type = distance_type
74
+ self.topn = topn
75
+ self.extra_options = extra_options
76
+
77
+
78
+ class MatchSparseExpr(ABC):
79
+ def __init__(
80
+ self,
81
+ vector_column_name: str,
82
+ sparse_data: SparseVector | dict,
83
+ distance_type: str,
84
+ topn: int,
85
+ opt_params: Optional[dict] = None,
86
+ ):
87
+ self.vector_column_name = vector_column_name
88
+ self.sparse_data = sparse_data
89
+ self.distance_type = distance_type
90
+ self.topn = topn
91
+ self.opt_params = opt_params
92
+
93
+
94
+ class MatchTensorExpr(ABC):
95
+ def __init__(
96
+ self,
97
+ column_name: str,
98
+ query_data: VEC,
99
+ query_data_type: str,
100
+ topn: int,
101
+ extra_option: Optional[dict] = None,
102
+ ):
103
+ self.column_name = column_name
104
+ self.query_data = query_data
105
+ self.query_data_type = query_data_type
106
+ self.topn = topn
107
+ self.extra_option = extra_option
108
+
109
+
110
+ class FusionExpr(ABC):
111
+ def __init__(self, method: str, topn: int, fusion_params: Optional[dict] = None):
112
+ self.method = method
113
+ self.topn = topn
114
+ self.fusion_params = fusion_params
115
+
116
+
117
+ MatchExpr = Union[
118
+ MatchTextExpr, MatchDenseExpr, MatchSparseExpr, MatchTensorExpr, FusionExpr
119
+ ]
120
+
121
+
122
+ class OrderByExpr(ABC):
123
+ def __init__(self):
124
+ self.fields = list()
125
+ def asc(self, field: str):
126
+ self.fields.append((field, 0))
127
+ return self
128
+ def desc(self, field: str):
129
+ self.fields.append((field, 1))
130
+ return self
131
+ def fields(self):
132
+ return self.fields
133
+
134
+ class DocStoreConnection(ABC):
135
+ """
136
+ Database operations
137
+ """
138
+
139
+ @abstractmethod
140
+ def dbType(self) -> str:
141
+ """
142
+ Return the type of the database.
143
+ """
144
+ raise NotImplementedError("Not implemented")
145
+
146
+ @abstractmethod
147
+ def health(self) -> dict:
148
+ """
149
+ Return the health status of the database.
150
+ """
151
+ raise NotImplementedError("Not implemented")
152
+
153
+ """
154
+ Table operations
155
+ """
156
+
157
+ @abstractmethod
158
+ def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
159
+ """
160
+ Create an index with given name
161
+ """
162
+ raise NotImplementedError("Not implemented")
163
+
164
+ @abstractmethod
165
+ def deleteIdx(self, indexName: str, knowledgebaseId: str):
166
+ """
167
+ Delete an index with given name
168
+ """
169
+ raise NotImplementedError("Not implemented")
170
+
171
+ @abstractmethod
172
+ def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
173
+ """
174
+ Check if an index with given name exists
175
+ """
176
+ raise NotImplementedError("Not implemented")
177
+
178
+ """
179
+ CRUD operations
180
+ """
181
+
182
+ @abstractmethod
183
+ def search(
184
+ self, selectFields: list[str], highlight: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]
185
+ ) -> list[dict] | pl.DataFrame:
186
+ """
187
+ Search with given conjunctive equivalent filtering condition and return all fields of matched documents
188
+ """
189
+ raise NotImplementedError("Not implemented")
190
+
191
+ @abstractmethod
192
+ def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
193
+ """
194
+ Get single chunk with given id
195
+ """
196
+ raise NotImplementedError("Not implemented")
197
+
198
+ @abstractmethod
199
+ def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
200
+ """
201
+ Update or insert a bulk of rows
202
+ """
203
+ raise NotImplementedError("Not implemented")
204
+
205
+ @abstractmethod
206
+ def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
207
+ """
208
+ Update rows with given conjunctive equivalent filtering condition
209
+ """
210
+ raise NotImplementedError("Not implemented")
211
+
212
+ @abstractmethod
213
+ def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
214
+ """
215
+ Delete rows with given conjunctive equivalent filtering condition
216
+ """
217
+ raise NotImplementedError("Not implemented")
218
+
219
+ """
220
+ Helper functions for search result
221
+ """
222
+
223
+ @abstractmethod
224
+ def getTotal(self, res):
225
+ raise NotImplementedError("Not implemented")
226
+
227
+ @abstractmethod
228
+ def getChunkIds(self, res):
229
+ raise NotImplementedError("Not implemented")
230
+
231
+ @abstractmethod
232
+ def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
233
+ raise NotImplementedError("Not implemented")
234
+
235
+ @abstractmethod
236
+ def getHighlight(self, res, keywords: List[str], fieldnm: str):
237
+ raise NotImplementedError("Not implemented")
238
+
239
+ @abstractmethod
240
+ def getAggregation(self, res, fieldnm: str):
241
+ raise NotImplementedError("Not implemented")
242
+
243
+ """
244
+ SQL
245
+ """
246
+ @abstractmethod
247
+ def sql(self, sql: str, fetch_size: int, format: str):
248
+ """
249
+ Run the sql generated by text-to-sql
250
+ """
251
+ raise NotImplementedError("Not implemented")
rag/utils/es_conn.py CHANGED
@@ -1,29 +1,29 @@
1
  import re
2
  import json
3
  import time
4
- import copy
 
5
 
6
  import elasticsearch
7
- from elastic_transport import ConnectionTimeout
8
  from elasticsearch import Elasticsearch
9
- from elasticsearch_dsl import UpdateByQuery, Search, Index
10
- from rag.settings import es_logger
 
11
  from rag import settings
12
  from rag.utils import singleton
 
 
 
 
13
 
14
- es_logger.info("Elasticsearch version: "+str(elasticsearch.__version__))
15
 
16
 
17
  @singleton
18
- class ESConnection:
19
  def __init__(self):
20
  self.info = {}
21
- self.conn()
22
- self.idxnm = settings.ES.get("index_name", "")
23
- if not self.es.ping():
24
- raise Exception("Can't connect to ES cluster")
25
-
26
- def conn(self):
27
  for _ in range(10):
28
  try:
29
  self.es = Elasticsearch(
@@ -34,390 +34,317 @@ class ESConnection:
34
  )
35
  if self.es:
36
  self.info = self.es.info()
37
- es_logger.info("Connect to es.")
38
  break
39
  except Exception as e:
40
- es_logger.error("Fail to connect to es: " + str(e))
41
  time.sleep(1)
42
-
43
- def version(self):
44
  v = self.info.get("version", {"number": "5.6"})
45
  v = v["number"].split(".")[0]
46
- return int(v) >= 7
47
-
48
- def health(self):
49
- return dict(self.es.cluster.health())
50
-
51
- def upsert(self, df, idxnm=""):
52
- res = []
53
- for d in df:
54
- id = d["id"]
55
- del d["id"]
56
- d = {"doc": d, "doc_as_upsert": "true"}
57
- T = False
58
- for _ in range(10):
59
- try:
60
- if not self.version():
61
- r = self.es.update(
62
- index=(
63
- self.idxnm if not idxnm else idxnm),
64
- body=d,
65
- id=id,
66
- doc_type="doc",
67
- refresh=True,
68
- retry_on_conflict=100)
69
- else:
70
- r = self.es.update(
71
- index=(
72
- self.idxnm if not idxnm else idxnm),
73
- body=d,
74
- id=id,
75
- refresh=True,
76
- retry_on_conflict=100)
77
- es_logger.info("Successfully upsert: %s" % id)
78
- T = True
79
- break
80
- except Exception as e:
81
- es_logger.warning("Fail to index: " +
82
- json.dumps(d, ensure_ascii=False) + str(e))
83
- if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
84
- time.sleep(3)
85
- continue
86
- self.conn()
87
- T = False
88
-
89
- if not T:
90
- res.append(d)
91
- es_logger.error(
92
- "Fail to index: " +
93
- re.sub(
94
- "[\r\n]",
95
- "",
96
- json.dumps(
97
- d,
98
- ensure_ascii=False)))
99
- d["id"] = id
100
- d["_index"] = self.idxnm
101
-
102
- if not res:
103
  return True
104
- return False
105
-
106
- def bulk(self, df, idx_nm=None):
107
- ids, acts = {}, []
108
- for d in df:
109
- id = d["id"] if "id" in d else d["_id"]
110
- ids[id] = copy.deepcopy(d)
111
- ids[id]["_index"] = self.idxnm if not idx_nm else idx_nm
112
- if "id" in d:
113
- del d["id"]
114
- if "_id" in d:
115
- del d["_id"]
116
- acts.append(
117
- {"update": {"_id": id, "_index": ids[id]["_index"]}, "retry_on_conflict": 100})
118
- acts.append({"doc": d, "doc_as_upsert": "true"})
119
-
120
- res = []
121
- for _ in range(100):
122
- try:
123
- if elasticsearch.__version__[0] < 8:
124
- r = self.es.bulk(
125
- index=(
126
- self.idxnm if not idx_nm else idx_nm),
127
- body=acts,
128
- refresh=False,
129
- timeout="600s")
130
- else:
131
- r = self.es.bulk(index=(self.idxnm if not idx_nm else
132
- idx_nm), operations=acts,
133
- refresh=False, timeout="600s")
134
- if re.search(r"False", str(r["errors"]), re.IGNORECASE):
135
- return res
136
-
137
- for it in r["items"]:
138
- if "error" in it["update"]:
139
- res.append(str(it["update"]["_id"]) +
140
- ":" + str(it["update"]["error"]))
141
-
142
- return res
143
- except Exception as e:
144
- es_logger.warn("Fail to bulk: " + str(e))
145
- if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
146
- time.sleep(3)
147
- continue
148
- self.conn()
149
-
150
- return res
151
-
152
- def bulk4script(self, df):
153
- ids, acts = {}, []
154
- for d in df:
155
- id = d["id"]
156
- ids[id] = copy.deepcopy(d["raw"])
157
- acts.append({"update": {"_id": id, "_index": self.idxnm}})
158
- acts.append(d["script"])
159
- es_logger.info("bulk upsert: %s" % id)
160
-
161
- res = []
162
- for _ in range(10):
163
- try:
164
- if not self.version():
165
- r = self.es.bulk(
166
- index=self.idxnm,
167
- body=acts,
168
- refresh=False,
169
- timeout="600s",
170
- doc_type="doc")
171
- else:
172
- r = self.es.bulk(
173
- index=self.idxnm,
174
- body=acts,
175
- refresh=False,
176
- timeout="600s")
177
- if re.search(r"False", str(r["errors"]), re.IGNORECASE):
178
- return res
179
-
180
- for it in r["items"]:
181
- if "error" in it["update"]:
182
- res.append(str(it["update"]["_id"]))
183
-
184
- return res
185
- except Exception as e:
186
- es_logger.warning("Fail to bulk: " + str(e))
187
- if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
188
- time.sleep(3)
189
- continue
190
- self.conn()
191
 
192
- return res
 
 
 
 
193
 
194
- def rm(self, d):
195
- for _ in range(10):
 
196
  try:
197
- if not self.version():
198
- r = self.es.delete(
199
- index=self.idxnm,
200
- id=d["id"],
201
- doc_type="doc",
202
- refresh=True)
203
- else:
204
- r = self.es.delete(
205
- index=self.idxnm,
206
- id=d["id"],
207
- refresh=True,
208
- doc_type="_doc")
209
- es_logger.info("Remove %s" % d["id"])
210
- return True
211
  except Exception as e:
212
- es_logger.warn("Fail to delete: " + str(d) + str(e))
213
- if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
214
- time.sleep(3)
215
  continue
216
- if re.search(r"(not_found)", str(e), re.IGNORECASE):
217
- return True
218
- self.conn()
219
-
220
- es_logger.error("Fail to delete: " + str(d))
221
-
222
  return False
223
 
224
- def search(self, q, idxnms=None, src=False, timeout="2s"):
225
- if not isinstance(q, dict):
226
- q = Search().query(q).to_dict()
227
- if isinstance(idxnms, str):
228
- idxnms = idxnms.split(",")
229
  for i in range(3):
230
  try:
231
- res = self.es.search(index=(self.idxnm if not idxnms else idxnms),
232
  body=q,
233
- timeout=timeout,
234
  # search_type="dfs_query_then_fetch",
235
  track_total_hits=True,
236
- _source=src)
237
  if str(res.get("timed_out", "")).lower() == "true":
238
  raise Exception("Es Timeout.")
 
239
  return res
240
  except Exception as e:
241
- es_logger.error(
242
  "ES search exception: " +
243
  str(e) +
244
- "Q】:" +
245
  str(q))
246
  if str(e).find("Timeout") > 0:
247
  continue
248
  raise e
249
- es_logger.error("ES search timeout for 3 times!")
250
  raise Exception("ES search timeout.")
251
 
252
- def sql(self, sql, fetch_size=128, format="json", timeout="2s"):
253
- for i in range(3):
254
- try:
255
- res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout=timeout)
256
- return res
257
- except ConnectionTimeout as e:
258
- es_logger.error("Timeout【Q】:" + sql)
259
- continue
260
- except Exception as e:
261
- raise e
262
- es_logger.error("ES search timeout for 3 times!")
263
- raise ConnectionTimeout()
264
-
265
-
266
- def get(self, doc_id, idxnm=None):
267
  for i in range(3):
268
  try:
269
- res = self.es.get(index=(self.idxnm if not idxnm else idxnm),
270
- id=doc_id)
271
  if str(res.get("timed_out", "")).lower() == "true":
272
  raise Exception("Es Timeout.")
273
- return res
 
 
 
 
274
  except Exception as e:
275
- es_logger.error(
276
  "ES get exception: " +
277
  str(e) +
278
- "Q】:" +
279
- doc_id)
280
  if str(e).find("Timeout") > 0:
281
  continue
282
  raise e
283
- es_logger.error("ES search timeout for 3 times!")
284
  raise Exception("ES search timeout.")
285
 
286
- def updateByQuery(self, q, d):
287
- ubq = UpdateByQuery(index=self.idxnm).using(self.es).query(q)
288
- scripts = ""
289
- for k, v in d.items():
290
- scripts += "ctx._source.%s = params.%s;" % (str(k), str(k))
291
- ubq = ubq.script(source=scripts, params=d)
292
- ubq = ubq.params(refresh=False)
293
- ubq = ubq.params(slices=5)
294
- ubq = ubq.params(conflicts="proceed")
295
- for i in range(3):
296
- try:
297
- r = ubq.execute()
298
- return True
299
- except Exception as e:
300
- es_logger.error("ES updateByQuery exception: " +
301
- str(e) + "【Q】:" + str(q.to_dict()))
302
- if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
303
- continue
304
- self.conn()
305
-
306
- return False
307
 
308
- def updateScriptByQuery(self, q, scripts, idxnm=None):
309
- ubq = UpdateByQuery(
310
- index=self.idxnm if not idxnm else idxnm).using(
311
- self.es).query(q)
312
- ubq = ubq.script(source=scripts)
313
- ubq = ubq.params(refresh=True)
314
- ubq = ubq.params(slices=5)
315
- ubq = ubq.params(conflicts="proceed")
316
- for i in range(3):
317
  try:
318
- r = ubq.execute()
319
- return True
320
- except Exception as e:
321
- es_logger.error("ES updateByQuery exception: " +
322
- str(e) + "【Q】:" + str(q.to_dict()))
323
- if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
324
- continue
325
- self.conn()
326
-
327
- return False
328
 
329
- def deleteByQuery(self, query, idxnm=""):
330
- for i in range(3):
331
- try:
332
- r = self.es.delete_by_query(
333
- index=idxnm if idxnm else self.idxnm,
334
- refresh = True,
335
- body=Search().query(query).to_dict())
336
- return True
337
  except Exception as e:
338
- es_logger.error("ES updateByQuery deleteByQuery: " +
339
- str(e) + "【Q】:" + str(query.to_dict()))
340
- if str(e).find("NotFoundError") > 0: return True
341
- if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
342
  continue
 
343
 
344
- return False
345
-
346
- def update(self, id, script, routing=None):
347
- for i in range(3):
348
- try:
349
- if not self.version():
350
- r = self.es.update(
351
- index=self.idxnm,
352
- id=id,
353
- body=json.dumps(
354
- script,
355
- ensure_ascii=False),
356
- doc_type="doc",
357
- routing=routing,
358
- refresh=False)
359
- else:
360
- r = self.es.update(index=self.idxnm, id=id, body=json.dumps(script, ensure_ascii=False),
361
- routing=routing, refresh=False) # , doc_type="_doc")
362
- return True
363
- except Exception as e:
364
- es_logger.error(
365
- "ES update exception: " + str(e) + " id:" + str(id) + ", version:" + str(self.version()) +
366
- json.dumps(script, ensure_ascii=False))
367
- if str(e).find("Timeout") > 0:
368
  continue
369
-
370
- return False
371
-
372
- def indexExist(self, idxnm):
373
- s = Index(idxnm if idxnm else self.idxnm, self.es)
374
- for i in range(3):
375
- try:
376
- return s.exists()
377
- except Exception as e:
378
- es_logger.error("ES updateByQuery indexExist: " + str(e))
379
- if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
380
  continue
381
-
382
  return False
383
 
384
- def docExist(self, docid, idxnm=None):
385
- for i in range(3):
386
  try:
387
- return self.es.exists(index=(idxnm if idxnm else self.idxnm),
388
- id=docid)
 
 
 
389
  except Exception as e:
390
- es_logger.error("ES Doc Exist: " + str(e))
391
- if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
 
392
  continue
393
- return False
394
-
395
- def createIdx(self, idxnm, mapping):
396
- try:
397
- if elasticsearch.__version__[0] < 8:
398
- return self.es.indices.create(idxnm, body=mapping)
399
- from elasticsearch.client import IndicesClient
400
- return IndicesClient(self.es).create(index=idxnm,
401
- settings=mapping["settings"],
402
- mappings=mapping["mappings"])
403
- except Exception as e:
404
- es_logger.error("ES create index error %s ----%s" % (idxnm, str(e)))
405
 
406
- def deleteIdx(self, idxnm):
407
- try:
408
- return self.es.indices.delete(idxnm, allow_no_indices=True)
409
- except Exception as e:
410
- es_logger.error("ES delete index error %s ----%s" % (idxnm, str(e)))
411
 
 
 
 
412
  def getTotal(self, res):
413
  if isinstance(res["hits"]["total"], type({})):
414
  return res["hits"]["total"]["value"]
415
  return res["hits"]["total"]
416
 
417
- def getDocIds(self, res):
418
  return [d["_id"] for d in res["hits"]["hits"]]
419
 
420
- def getSource(self, res):
421
  rr = []
422
  for d in res["hits"]["hits"]:
423
  d["_source"]["id"] = d["_id"]
@@ -425,40 +352,89 @@ class ESConnection:
425
  rr.append(d["_source"])
426
  return rr
427
 
428
- def scrollIter(self, pagesize=100, scroll_time='2m', q={
429
- "query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
430
- for _ in range(100):
431
- try:
432
- page = self.es.search(
433
- index=self.idxnm,
434
- scroll=scroll_time,
435
- size=pagesize,
436
- body=q,
437
- _source=None
438
- )
439
- break
440
- except Exception as e:
441
- es_logger.error("ES scrolling fail. " + str(e))
442
- time.sleep(3)
443
-
444
- sid = page['_scroll_id']
445
- scroll_size = page['hits']['total']["value"]
446
- es_logger.info("[TOTAL]%d" % scroll_size)
447
- # Start scrolling
448
- while scroll_size > 0:
449
- yield page["hits"]["hits"]
450
- for _ in range(100):
451
- try:
452
- page = self.es.scroll(scroll_id=sid, scroll=scroll_time)
453
- break
454
- except Exception as e:
455
- es_logger.error("ES scrolling fail. " + str(e))
456
- time.sleep(3)
457
 
458
- # Update the scroll ID
459
- sid = page['_scroll_id']
460
- # Get the number of results that we returned in the last scroll
461
- scroll_size = len(page['hits']['hits'])
462
 
463
 
464
- ELASTICSEARCH = ESConnection()
1
  import re
2
  import json
3
  import time
4
+ import os
5
+ from typing import List, Dict
6
 
7
  import elasticsearch
8
+ import copy
9
  from elasticsearch import Elasticsearch
10
+ from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
11
+ from elastic_transport import ConnectionTimeout
12
+ from rag.settings import doc_store_logger
13
  from rag import settings
14
  from rag.utils import singleton
15
+ from api.utils.file_utils import get_project_base_directory
16
+ import polars as pl
17
+ from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
18
+ from rag.nlp import is_english, rag_tokenizer
19
 
20
+ doc_store_logger.info("Elasticsearch sdk version: "+str(elasticsearch.__version__))
21
 
22
 
23
  @singleton
24
+ class ESConnection(DocStoreConnection):
25
  def __init__(self):
26
  self.info = {}
27
  for _ in range(10):
28
  try:
29
  self.es = Elasticsearch(
 
34
  )
35
  if self.es:
36
  self.info = self.es.info()
37
+ doc_store_logger.info("Connect to es.")
38
  break
39
  except Exception as e:
40
+ doc_store_logger.error("Fail to connect to es: " + str(e))
41
  time.sleep(1)
42
+ if not self.es.ping():
43
+ raise Exception("Can't connect to ES cluster")
44
  v = self.info.get("version", {"number": "5.6"})
45
  v = v["number"].split(".")[0]
46
+ if int(v) < 8:
47
+ raise Exception(f"ES version must be greater than or equal to 8, current version: {v}")
48
+ fp_mapping = os.path.join(get_project_base_directory(), "conf", "mapping.json")
49
+ if not os.path.exists(fp_mapping):
50
+ raise Exception(f"Mapping file not found at {fp_mapping}")
51
+ self.mapping = json.load(open(fp_mapping, "r"))
52
+
53
+ """
54
+ Database operations
55
+ """
56
+ def dbType(self) -> str:
57
+ return "elasticsearch"
58
+
59
+ def health(self) -> dict:
60
+ return dict(self.es.cluster.health()) | {"type": "elasticsearch"}
61
+
62
+ """
63
+ Table operations
64
+ """
65
+ def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
66
+ if self.indexExist(indexName, knowledgebaseId):
67
  return True
68
+ try:
69
+ from elasticsearch.client import IndicesClient
70
+ return IndicesClient(self.es).create(index=indexName,
71
+ settings=self.mapping["settings"],
72
+ mappings=self.mapping["mappings"])
73
+ except Exception as e:
74
+ doc_store_logger.error("ES create index error %s ----%s" % (indexName, str(e)))
75
 
76
+ def deleteIdx(self, indexName: str, knowledgebaseId: str):
77
+ try:
78
+ return self.es.indices.delete(indexName, allow_no_indices=True)
79
+ except Exception as e:
80
+ doc_store_logger.error("ES delete index error %s ----%s" % (indexName, str(e)))
81
 
82
+ def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
83
+ s = Index(indexName, self.es)
84
+ for i in range(3):
85
  try:
86
+ return s.exists()
87
  except Exception as e:
88
+ doc_store_logger.error("ES indexExist: " + str(e))
89
+ if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
 
90
  continue
 
 
 
 
 
 
91
  return False
92
 
93
+ """
94
+ CRUD operations
95
+ """
96
+ def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame:
97
+ """
98
+ Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
99
+ """
100
+ if isinstance(indexNames, str):
101
+ indexNames = indexNames.split(",")
102
+ assert isinstance(indexNames, list) and len(indexNames) > 0
103
+ assert "_id" not in condition
104
+ s = Search()
105
+ bqry = None
106
+ vector_similarity_weight = 0.5
107
+ for m in matchExprs:
108
+ if isinstance(m, FusionExpr) and m.method=="weighted_sum" and "weights" in m.fusion_params:
109
+ assert len(matchExprs)==3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1], MatchDenseExpr) and isinstance(matchExprs[2], FusionExpr)
110
+ weights = m.fusion_params["weights"]
111
+ vector_similarity_weight = float(weights.split(",")[1])
112
+ for m in matchExprs:
113
+ if isinstance(m, MatchTextExpr):
114
+ minimum_should_match = "0%"
115
+ if "minimum_should_match" in m.extra_options:
116
+ minimum_should_match = str(int(m.extra_options["minimum_should_match"] * 100)) + "%"
117
+ bqry = Q("bool",
118
+ must=Q("query_string", fields=m.fields,
119
+ type="best_fields", query=m.matching_text,
120
+ minimum_should_match = minimum_should_match,
121
+ boost=1),
122
+ boost = 1.0 - vector_similarity_weight,
123
+ )
124
+ if condition:
125
+ for k, v in condition.items():
126
+ if not isinstance(k, str) or not v:
127
+ continue
128
+ if isinstance(v, list):
129
+ bqry.filter.append(Q("terms", **{k: v}))
130
+ elif isinstance(v, str) or isinstance(v, int):
131
+ bqry.filter.append(Q("term", **{k: v}))
132
+ else:
133
+ raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
134
+ elif isinstance(m, MatchDenseExpr):
135
+ assert(bqry is not None)
136
+ similarity = 0.0
137
+ if "similarity" in m.extra_options:
138
+ similarity = m.extra_options["similarity"]
139
+ s = s.knn(m.vector_column_name,
140
+ m.topn,
141
+ m.topn * 2,
142
+ query_vector = list(m.embedding_data),
143
+ filter = bqry.to_dict(),
144
+ similarity = similarity,
145
+ )
146
+ if matchExprs:
147
+ s.query = bqry
148
+ for field in highlightFields:
149
+ s = s.highlight(field)
150
+
151
+ if orderBy:
152
+ orders = list()
153
+ for field, order in orderBy.fields:
154
+ order = "asc" if order == 0 else "desc"
155
+ orders.append({field: {"order": order, "unmapped_type": "float",
156
+ "mode": "avg", "numeric_type": "double"}})
157
+ s = s.sort(*orders)
158
+
159
+ if limit > 0:
160
+ s = s[offset:limit]
161
+ q = s.to_dict()
162
+ doc_store_logger.info("ESConnection.search [Q]: " + json.dumps(q))
163
+
164
  for i in range(3):
165
  try:
166
+ res = self.es.search(index=indexNames,
167
  body=q,
168
+ timeout="600s",
169
  # search_type="dfs_query_then_fetch",
170
  track_total_hits=True,
171
+ _source=True)
172
  if str(res.get("timed_out", "")).lower() == "true":
173
  raise Exception("Es Timeout.")
174
+ doc_store_logger.info("ESConnection.search res: " + str(res))
175
  return res
176
  except Exception as e:
177
+ doc_store_logger.error(
178
  "ES search exception: " +
179
  str(e) +
180
+ "\n[Q]: " +
181
  str(q))
182
  if str(e).find("Timeout") > 0:
183
  continue
184
  raise e
185
+ doc_store_logger.error("ES search timeout for 3 times!")
186
  raise Exception("ES search timeout.")
187
 
188
+ def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
189
  for i in range(3):
190
  try:
191
+ res = self.es.get(index=(indexName),
192
+ id=chunkId, source=True,)
193
  if str(res.get("timed_out", "")).lower() == "true":
194
  raise Exception("Es Timeout.")
195
+ if not res.get("found"):
196
+ return None
197
+ chunk = res["_source"]
198
+ chunk["id"] = chunkId
199
+ return chunk
200
  except Exception as e:
201
+ doc_store_logger.error(
202
  "ES get exception: " +
203
  str(e) +
204
+ "[Q]: " +
205
+ chunkId)
206
  if str(e).find("Timeout") > 0:
207
  continue
208
  raise e
209
+ doc_store_logger.error("ES search timeout for 3 times!")
210
  raise Exception("ES search timeout.")
211
 
212
+ def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
213
+ # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
214
+ operations = []
215
+ for d in documents:
216
+ assert "_id" not in d
217
+ assert "id" in d
218
+ d_copy = copy.deepcopy(d)
219
+ meta_id = d_copy["id"]
220
+ del d_copy["id"]
221
+ operations.append(
222
+ {"index": {"_index": indexName, "_id": meta_id}})
223
+ operations.append(d_copy)
224
 
225
+ res = []
226
+ for _ in range(100):
227
  try:
228
+ r = self.es.bulk(index=(indexName), operations=operations,
229
+ refresh=False, timeout="600s")
230
+ if re.search(r"False", str(r["errors"]), re.IGNORECASE):
231
+ return res
232
 
233
+ for item in r["items"]:
234
+ for action in ["create", "delete", "index", "update"]:
235
+ if action in item and "error" in item[action]:
236
+ res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
237
+ return res
 
 
 
238
  except Exception as e:
239
+ doc_store_logger.warning("Fail to bulk: " + str(e))
240
+ if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
241
+ time.sleep(3)
 
242
  continue
243
+ return res
244
 
245
+ def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
246
+ doc = copy.deepcopy(newValue)
247
+ del doc['id']
248
+ if "id" in condition and isinstance(condition["id"], str):
249
+ # update specific single document
250
+ chunkId = condition["id"]
251
+ for i in range(3):
252
+ try:
253
+ self.es.update(index=indexName, id=chunkId, doc=doc)
254
+ return True
255
+ except Exception as e:
256
+ doc_store_logger.error(
257
+ "ES update exception: " + str(e) + " id:" + str(chunkId) +
258
+ json.dumps(newValue, ensure_ascii=False))
259
+ if str(e).find("Timeout") > 0:
260
+ continue
261
+ else:
262
+ # update unspecific maybe-multiple documents
263
+ bqry = Q("bool")
264
+ for k, v in condition.items():
265
+ if not isinstance(k, str) or not v:
 
 
 
266
  continue
267
+ if isinstance(v, list):
268
+ bqry.filter.append(Q("terms", **{k: v}))
269
+ elif isinstance(v, str) or isinstance(v, int):
270
+ bqry.filter.append(Q("term", **{k: v}))
271
+ else:
272
+ raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
273
+ scripts = []
274
+ for k, v in newValue.items():
275
+ if not isinstance(k, str) or not v:
 
 
276
  continue
277
+ if isinstance(v, str):
278
+ scripts.append(f"ctx._source.{k} = '{v}'")
279
+ elif isinstance(v, int):
280
+ scripts.append(f"ctx._source.{k} = {v}")
281
+ else:
282
+ raise Exception(f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
283
+ ubq = UpdateByQuery(
284
+ index=indexName).using(
285
+ self.es).query(bqry)
286
+ ubq = ubq.script(source="; ".join(scripts))
287
+ ubq = ubq.params(refresh=True)
288
+ ubq = ubq.params(slices=5)
289
+ ubq = ubq.params(conflicts="proceed")
290
+ for i in range(3):
291
+ try:
292
+ _ = ubq.execute()
293
+ return True
294
+ except Exception as e:
295
+ doc_store_logger.error("ES update exception: " +
296
+ str(e) + "[Q]:" + str(bqry.to_dict()))
297
+ if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
298
+ continue
299
  return False
300
 
301
+ def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
302
+ qry = None
303
+ assert "_id" not in condition
304
+ if "id" in condition:
305
+ chunk_ids = condition["id"]
306
+ if not isinstance(chunk_ids, list):
307
+ chunk_ids = [chunk_ids]
308
+ qry = Q("ids", values=chunk_ids)
309
+ else:
310
+ qry = Q("bool")
311
+ for k, v in condition.items():
312
+ if isinstance(v, list):
313
+ qry.must.append(Q("terms", **{k: v}))
314
+ elif isinstance(v, str) or isinstance(v, int):
315
+ qry.must.append(Q("term", **{k: v}))
316
+ else:
317
+ raise Exception("Condition value must be int, str or list.")
318
+ doc_store_logger.info("ESConnection.delete [Q]: " + json.dumps(qry.to_dict()))
319
+ for _ in range(10):
320
  try:
321
+ res = self.es.delete_by_query(
322
+ index=indexName,
323
+ body = Search().query(qry).to_dict(),
324
+ refresh=True)
325
+ return res["deleted"]
326
  except Exception as e:
327
+ doc_store_logger.warning("Fail to delete: " + str(condition) + str(e))
328
+ if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
329
+ time.sleep(3)
330
  continue
331
+ if re.search(r"(not_found)", str(e), re.IGNORECASE):
332
+ return 0
333
+ return 0
334
 
 
 
 
 
 
335
 
336
+ """
337
+ Helper functions for search result
338
+ """
339
  def getTotal(self, res):
340
  if isinstance(res["hits"]["total"], type({})):
341
  return res["hits"]["total"]["value"]
342
  return res["hits"]["total"]
343
 
344
+ def getChunkIds(self, res):
345
  return [d["_id"] for d in res["hits"]["hits"]]
346
 
347
+ def __getSource(self, res):
348
  rr = []
349
  for d in res["hits"]["hits"]:
350
  d["_source"]["id"] = d["_id"]
 
352
  rr.append(d["_source"])
353
  return rr
354
 
355
+ def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
356
+ res_fields = {}
357
+ if not fields:
358
+ return {}
359
+ for d in self.__getSource(res):
360
+ m = {n: d.get(n) for n in fields if d.get(n) is not None}
361
+ for n, v in m.items():
362
+ if isinstance(v, list):
363
+ m[n] = v
364
+ continue
365
+ if not isinstance(v, str):
366
+ m[n] = str(m[n])
367
+ # if n.find("tks") > 0:
368
+ # m[n] = rmSpace(m[n])
369
 
370
+ if m:
371
+ res_fields[d["id"]] = m
372
+ return res_fields
 
373
 
374
+ def getHighlight(self, res, keywords: List[str], fieldnm: str):
375
+ ans = {}
376
+ for d in res["hits"]["hits"]:
377
+ hlts = d.get("highlight")
378
+ if not hlts:
379
+ continue
380
+ txt = "...".join([a for a in list(hlts.items())[0][1]])
381
+ if not is_english(txt.split(" ")):
382
+ ans[d["_id"]] = txt
383
+ continue
384
+
385
+ txt = d["_source"][fieldnm]
386
+ txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
387
+ txts = []
388
+ for t in re.split(r"[.?!;\n]", txt):
389
+ for w in keywords:
390
+ t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", t, flags=re.IGNORECASE|re.MULTILINE)
391
+ if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE):
392
+ continue
393
+ txts.append(t)
394
+ ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
395
+
396
+ return ans
397
+
398
+ def getAggregation(self, res, fieldnm: str):
399
+ agg_field = "aggs_" + fieldnm
400
+ if "aggregations" not in res or agg_field not in res["aggregations"]:
401
+ return list()
402
+ bkts = res["aggregations"][agg_field]["buckets"]
403
+ return [(b["key"], b["doc_count"]) for b in bkts]
404
+
405
+
406
+ """
407
+ SQL
408
+ """
409
+ def sql(self, sql: str, fetch_size: int, format: str):
410
+ doc_store_logger.info(f"ESConnection.sql get sql: {sql}")
411
+ sql = re.sub(r"[ `]+", " ", sql)
412
+ sql = sql.replace("%", "")
413
+ replaces = []
414
+ for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
415
+ fld, v = r.group(1), r.group(3)
416
+ match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
417
+ fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
418
+ replaces.append(
419
+ ("{}{}'{}'".format(
420
+ r.group(1),
421
+ r.group(2),
422
+ r.group(3)),
423
+ match))
424
+
425
+ for p, r in replaces:
426
+ sql = sql.replace(p, r, 1)
427
+ doc_store_logger.info(f"ESConnection.sql to es: {sql}")
428
 
429
+ for i in range(3):
430
+ try:
431
+ res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout="2s")
432
+ return res
433
+ except ConnectionTimeout:
434
+ doc_store_logger.error("ESConnection.sql timeout [Q]: " + sql)
435
+ continue
436
+ except Exception as e:
437
+ doc_store_logger.error(f"ESConnection.sql failure: {sql} => " + str(e))
438
+ return None
439
+ doc_store_logger.error("ESConnection.sql timeout for 3 times!")
440
+ return None
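Throughout `ESConnection`, a `condition` dict is turned into a `bool` query in which list values become `terms` filters and scalar values become `term` filters, with falsy values skipped. The standalone snippet below only mirrors the loop in `search()`/`update()` so the mapping can be eyeballed; it is not an additional API in this PR, and the field names are invented.

```python
import json
from elasticsearch_dsl import Q, Search

condition = {"doc_id": "abc123", "available_int": 1, "important_kwd": ["rag", "infinity"]}

bqry = Q("bool")
for k, v in condition.items():
    if not isinstance(k, str) or not v:
        continue                                    # the connector skips falsy values
    if isinstance(v, list):
        bqry.filter.append(Q("terms", **{k: v}))    # list value -> terms filter
    elif isinstance(v, (str, int)):
        bqry.filter.append(Q("term", **{k: v}))     # scalar value -> term filter
    else:
        raise Exception(f"Condition `{k}={v}` has unsupported type {type(v)}")

print(json.dumps(Search().query(bqry).to_dict(), indent=2))
```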
rag/utils/infinity_conn.py ADDED
@@ -0,0 +1,436 @@
1
+ import os
2
+ import re
3
+ import json
4
+ from typing import List, Dict
5
+ import infinity
6
+ from infinity.common import ConflictType, InfinityException
7
+ from infinity.index import IndexInfo, IndexType
8
+ from infinity.connection_pool import ConnectionPool
9
+ from rag import settings
10
+ from rag.settings import doc_store_logger
11
+ from rag.utils import singleton
12
+ import polars as pl
13
+ from polars.series.series import Series
14
+ from api.utils.file_utils import get_project_base_directory
15
+
16
+ from rag.utils.doc_store_conn import (
17
+ DocStoreConnection,
18
+ MatchExpr,
19
+ MatchTextExpr,
20
+ MatchDenseExpr,
21
+ FusionExpr,
22
+ OrderByExpr,
23
+ )
24
+
25
+
26
+ def equivalent_condition_to_str(condition: dict) -> str:
27
+ assert "_id" not in condition
28
+ cond = list()
29
+ for k, v in condition.items():
30
+ if not isinstance(k, str) or not v:
31
+ continue
32
+ if isinstance(v, list):
33
+ inCond = list()
34
+ for item in v:
35
+ if isinstance(item, str):
36
+ inCond.append(f"'{item}'")
37
+ else:
38
+ inCond.append(str(item))
39
+ if inCond:
40
+ strInCond = ", ".join(inCond)
41
+ strInCond = f"{k} IN ({strInCond})"
42
+ cond.append(strInCond)
43
+ elif isinstance(v, str):
44
+ cond.append(f"{k}='{v}'")
45
+ else:
46
+ cond.append(f"{k}={str(v)}")
47
+ return " AND ".join(cond)
48
+
49
+
50
+ @singleton
51
+ class InfinityConnection(DocStoreConnection):
52
+ def __init__(self):
53
+ self.dbName = settings.INFINITY.get("db_name", "default_db")
54
+ infinity_uri = settings.INFINITY["uri"]
55
+ if ":" in infinity_uri:
56
+ host, port = infinity_uri.split(":")
57
+ infinity_uri = infinity.common.NetworkAddress(host, int(port))
58
+ self.connPool = ConnectionPool(infinity_uri)
59
+ doc_store_logger.info(f"Connected to infinity {infinity_uri}.")
60
+
61
+ """
62
+ Database operations
63
+ """
64
+
65
+ def dbType(self) -> str:
66
+ return "infinity"
67
+
68
+ def health(self) -> dict:
69
+ """
70
+ Return the health status of the database.
71
+ TODO: infinity-sdk should expose a health() wrapper around `show global variables` and `show tables`
72
+ """
73
+ inf_conn = self.connPool.get_conn()
74
+ res = inf_conn.show_current_node()
75
+ self.connPool.release_conn(inf_conn)
76
+ color = "green" if res.error_code == 0 else "red"
77
+ res2 = {
78
+ "type": "infinity",
79
+ "status": f"{res.role} {color}",
80
+ "error": res.error_msg,
81
+ }
82
+ return res2
83
+
84
+ """
85
+ Table operations
86
+ """
87
+
88
+ def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
89
+ table_name = f"{indexName}_{knowledgebaseId}"
90
+ inf_conn = self.connPool.get_conn()
91
+ inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
92
+
93
+ fp_mapping = os.path.join(
94
+ get_project_base_directory(), "conf", "infinity_mapping.json"
95
+ )
96
+ if not os.path.exists(fp_mapping):
97
+ raise Exception(f"Mapping file not found at {fp_mapping}")
98
+ schema = json.load(open(fp_mapping))
99
+ vector_name = f"q_{vectorSize}_vec"
100
+ schema[vector_name] = {"type": f"vector,{vectorSize},float"}
101
+ inf_table = inf_db.create_table(
102
+ table_name,
103
+ schema,
104
+ ConflictType.Ignore,
105
+ )
106
+ inf_table.create_index(
107
+ "q_vec_idx",
108
+ IndexInfo(
109
+ vector_name,
110
+ IndexType.Hnsw,
111
+ {
112
+ "M": "16",
113
+ "ef_construction": "50",
114
+ "metric": "cosine",
115
+ "encode": "lvq",
116
+ },
117
+ ),
118
+ ConflictType.Ignore,
119
+ )
120
+ text_suffix = ["_tks", "_ltks", "_kwd"]
121
+ for field_name, field_info in schema.items():
122
+ if field_info["type"] != "varchar":
123
+ continue
124
+ for suffix in text_suffix:
125
+ if field_name.endswith(suffix):
126
+ inf_table.create_index(
127
+ f"text_idx_{field_name}",
128
+ IndexInfo(
129
+ field_name, IndexType.FullText, {"ANALYZER": "standard"}
130
+ ),
131
+ ConflictType.Ignore,
132
+ )
133
+ break
134
+ self.connPool.release_conn(inf_conn)
135
+ doc_store_logger.info(
136
+ f"INFINITY created table {table_name}, vector size {vectorSize}"
137
+ )
138
+
139
+ def deleteIdx(self, indexName: str, knowledgebaseId: str):
140
+ table_name = f"{indexName}_{knowledgebaseId}"
141
+ inf_conn = self.connPool.get_conn()
142
+ db_instance = inf_conn.get_database(self.dbName)
143
+ db_instance.drop_table(table_name, ConflictType.Ignore)
144
+ self.connPool.release_conn(inf_conn)
145
+ doc_store_logger.info(f"INFINITY dropped table {table_name}")
146
+
147
+ def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
148
+ table_name = f"{indexName}_{knowledgebaseId}"
149
+ try:
150
+ inf_conn = self.connPool.get_conn()
151
+ db_instance = inf_conn.get_database(self.dbName)
152
+ _ = db_instance.get_table(table_name)
153
+ self.connPool.release_conn(inf_conn)
154
+ return True
155
+ except Exception as e:
156
+ doc_store_logger.error("INFINITY indexExist: " + str(e))
157
+ return False
158
+
159
+ """
160
+ CRUD operations
161
+ """
162
+
163
+ def search(
164
+ self,
165
+ selectFields: list[str],
166
+ highlightFields: list[str],
167
+ condition: dict,
168
+ matchExprs: list[MatchExpr],
169
+ orderBy: OrderByExpr,
170
+ offset: int,
171
+ limit: int,
172
+ indexNames: str|list[str],
173
+ knowledgebaseIds: list[str],
174
+ ) -> list[dict] | pl.DataFrame:
175
+ """
176
+ TODO: Infinity doesn't provide highlight
177
+ """
178
+ if isinstance(indexNames, str):
179
+ indexNames = indexNames.split(",")
180
+ assert isinstance(indexNames, list) and len(indexNames) > 0
181
+ inf_conn = self.connPool.get_conn()
182
+ db_instance = inf_conn.get_database(self.dbName)
183
+ df_list = list()
184
+ table_list = list()
185
+ if "id" not in selectFields:
186
+ selectFields.append("id")
187
+
188
+ # Prepare expressions common to all tables
189
+ filter_cond = ""
190
+ filter_fulltext = ""
191
+ if condition:
192
+ filter_cond = equivalent_condition_to_str(condition)
193
+ for matchExpr in matchExprs:
194
+ if isinstance(matchExpr, MatchTextExpr):
195
+ if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
196
+ matchExpr.extra_options.update({"filter": filter_cond})
197
+ fields = ",".join(matchExpr.fields)
198
+ filter_fulltext = (
199
+ f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
200
+ )
201
+ if len(filter_cond) != 0:
202
+ filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
203
+ # doc_store_logger.info(f"filter_fulltext: {filter_fulltext}")
204
+ minimum_should_match = "0%"
205
+ if "minimum_should_match" in matchExpr.extra_options:
206
+ minimum_should_match = (
207
+ str(int(matchExpr.extra_options["minimum_should_match"] * 100))
208
+ + "%"
209
+ )
210
+ matchExpr.extra_options.update(
211
+ {"minimum_should_match": minimum_should_match}
212
+ )
213
+ for k, v in matchExpr.extra_options.items():
214
+ if not isinstance(v, str):
215
+ matchExpr.extra_options[k] = str(v)
216
+ elif isinstance(matchExpr, MatchDenseExpr):
217
+ if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
218
+ matchExpr.extra_options.update({"filter": filter_fulltext})
219
+ for k, v in matchExpr.extra_options.items():
220
+ if not isinstance(v, str):
221
+ matchExpr.extra_options[k] = str(v)
222
+ if orderBy.fields:
223
+ order_by_expr_list = list()
224
+ for order_field in orderBy.fields:
225
+ order_by_expr_list.append((order_field[0], order_field[1] == 0))
226
+
227
+ # Scatter search tables and gather the results
228
+ for indexName in indexNames:
229
+ for knowledgebaseId in knowledgebaseIds:
230
+ table_name = f"{indexName}_{knowledgebaseId}"
231
+ try:
232
+ table_instance = db_instance.get_table(table_name)
233
+ except Exception:
234
+ continue
235
+ table_list.append(table_name)
236
+ builder = table_instance.output(selectFields)
237
+ for matchExpr in matchExprs:
238
+ if isinstance(matchExpr, MatchTextExpr):
239
+ fields = ",".join(matchExpr.fields)
240
+ builder = builder.match_text(
241
+ fields,
242
+ matchExpr.matching_text,
243
+ matchExpr.topn,
244
+ matchExpr.extra_options,
245
+ )
246
+ elif isinstance(matchExpr, MatchDenseExpr):
247
+ builder = builder.match_dense(
248
+ matchExpr.vector_column_name,
249
+ matchExpr.embedding_data,
250
+ matchExpr.embedding_data_type,
251
+ matchExpr.distance_type,
252
+ matchExpr.topn,
253
+ matchExpr.extra_options,
254
+ )
255
+ elif isinstance(matchExpr, FusionExpr):
256
+ builder = builder.fusion(
257
+ matchExpr.method, matchExpr.topn, matchExpr.fusion_params
258
+ )
259
+ if orderBy.fields:
260
+ builder.sort(order_by_expr_list)
261
+ builder.offset(offset).limit(limit)
262
+ kb_res = builder.to_pl()
263
+ df_list.append(kb_res)
264
+ self.connPool.release_conn(inf_conn)
265
+ res = pl.concat(df_list)
266
+ doc_store_logger.info("INFINITY search tables: " + str(table_list))
267
+ return res
268
+
269
+ def get(
270
+ self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
271
+ ) -> dict | None:
272
+ inf_conn = self.connPool.get_conn()
273
+ db_instance = inf_conn.get_database(self.dbName)
274
+ df_list = list()
275
+ assert isinstance(knowledgebaseIds, list)
276
+ for knowledgebaseId in knowledgebaseIds:
277
+ table_name = f"{indexName}_{knowledgebaseId}"
278
+ table_instance = db_instance.get_table(table_name)
279
+ kb_res = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
280
+ df_list.append(kb_res)
281
+ self.connPool.release_conn(inf_conn)
282
+ res = pl.concat(df_list)
283
+ res_fields = self.getFields(res, res.columns)
284
+ return res_fields.get(chunkId, None)
285
+
286
+ def insert(
287
+ self, documents: list[dict], indexName: str, knowledgebaseId: str
288
+ ) -> list[str]:
289
+ inf_conn = self.connPool.get_conn()
290
+ db_instance = inf_conn.get_database(self.dbName)
291
+ table_name = f"{indexName}_{knowledgebaseId}"
292
+ try:
293
+ table_instance = db_instance.get_table(table_name)
294
+ except InfinityException as e:
295
+ # src/common/status.cppm, kTableNotExist = 3022
296
+ if e.error_code != 3022:
297
+ raise
298
+ vector_size = 0
299
+ patt = re.compile(r"q_(?P<vector_size>\d+)_vec")
300
+ for k in documents[0].keys():
301
+ m = patt.match(k)
302
+ if m:
303
+ vector_size = int(m.group("vector_size"))
304
+ break
305
+ if vector_size == 0:
306
+ raise ValueError("Cannot infer vector size from documents")
307
+ self.createIdx(indexName, knowledgebaseId, vector_size)
308
+ table_instance = db_instance.get_table(table_name)
309
+
310
+ for d in documents:
311
+ assert "_id" not in d
312
+ assert "id" in d
313
+ for k, v in d.items():
314
+ if k.endswith("_kwd") and isinstance(v, list):
315
+ d[k] = " ".join(v)
316
+ ids = [f"'{d['id']}'" for d in documents]
317
+ str_ids = ", ".join(ids)
318
+ str_filter = f"id IN ({str_ids})"
319
+ table_instance.delete(str_filter)
320
+ # for doc in documents:
321
+ # doc_store_logger.info(f"insert position_list: {doc['position_list']}")
322
+ # doc_store_logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
323
+ table_instance.insert(documents)
324
+ self.connPool.release_conn(inf_conn)
325
+ doc_store_logger.info(f"inserted into {table_name} {str_ids}.")
326
+ return []
327
+
328
+ def update(
329
+ self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
330
+ ) -> bool:
331
+ # if 'position_list' in newValue:
332
+ # doc_store_logger.info(f"update position_list: {newValue['position_list']}")
333
+ inf_conn = self.connPool.get_conn()
334
+ db_instance = inf_conn.get_database(self.dbName)
335
+ table_name = f"{indexName}_{knowledgebaseId}"
336
+ table_instance = db_instance.get_table(table_name)
337
+ filter = equivalent_condition_to_str(condition)
338
+ for k, v in newValue.items():
339
+ if k.endswith("_kwd") and isinstance(v, list):
340
+ newValue[k] = " ".join(v)
341
+ table_instance.update(filter, newValue)
342
+ self.connPool.release_conn(inf_conn)
343
+ return True
344
+
345
+ def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
346
+ inf_conn = self.connPool.get_conn()
347
+ db_instance = inf_conn.get_database(self.dbName)
348
+ table_name = f"{indexName}_{knowledgebaseId}"
349
+ filter = equivalent_condition_to_str(condition)
350
+ try:
351
+ table_instance = db_instance.get_table(table_name)
352
+ except Exception:
353
+ doc_store_logger.warning(
354
+ f"Skipped deleting `{filter}` from table {table_name} since the table doesn't exist."
355
+ )
356
+ return 0
357
+ res = table_instance.delete(filter)
358
+ self.connPool.release_conn(inf_conn)
359
+ return res.deleted_rows
360
+
361
+ """
362
+ Helper functions for search result
363
+ """
364
+
365
+ def getTotal(self, res):
366
+ return len(res)
367
+
368
+ def getChunkIds(self, res):
369
+ return list(res["id"])
370
+
371
+ def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
372
+ res_fields = {}
373
+ if not fields:
374
+ return {}
375
+ num_rows = len(res)
376
+ column_id = res["id"]
377
+ for i in range(num_rows):
378
+ id = column_id[i]
379
+ m = {"id": id}
380
+ for fieldnm in fields:
381
+ if fieldnm not in res:
382
+ m[fieldnm] = None
383
+ continue
384
+ v = res[fieldnm][i]
385
+ if isinstance(v, Series):
386
+ v = list(v)
387
+ elif fieldnm == "important_kwd":
388
+ assert isinstance(v, str)
389
+ v = v.split(" ")
390
+ else:
391
+ if not isinstance(v, str):
392
+ v = str(v)
393
+ # if fieldnm.endswith("_tks"):
394
+ # v = rmSpace(v)
395
+ m[fieldnm] = v
396
+ res_fields[id] = m
397
+ return res_fields
398
+
399
+ def getHighlight(self, res, keywords: List[str], fieldnm: str):
400
+ ans = {}
401
+ num_rows = len(res)
402
+ column_id = res["id"]
403
+ for i in range(num_rows):
404
+ id = column_id[i]
405
+ txt = res[fieldnm][i]
406
+ txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
407
+ txts = []
408
+ for t in re.split(r"[.?!;\n]", txt):
409
+ for w in keywords:
410
+ t = re.sub(
411
+ r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"
412
+ % re.escape(w),
413
+ r"\1<em>\2</em>\3",
414
+ t,
415
+ flags=re.IGNORECASE | re.MULTILINE,
416
+ )
417
+ if not re.search(
418
+ r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
419
+ ):
420
+ continue
421
+ txts.append(t)
422
+ ans[id] = "...".join(txts)
423
+ return ans
424
+
425
+ def getAggregation(self, res, fieldnm: str):
426
+ """
427
+ TODO: Infinity doesn't provide aggregation
428
+ """
429
+ return list()
430
+
431
+ """
432
+ SQL
433
+ """
434
+
435
+ def sql(self, sql: str, fetch_size: int, format: str):
436
+ raise NotImplementedError("Not implemented")
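On the Infinity side, the same `condition` dicts are rendered into a SQL-style filter string by `equivalent_condition_to_str()` defined at the top of this file. A quick illustration of the expected output (column names are invented); note that falsy values (0, "", []) are silently skipped by the helper.

```python
from rag.utils.infinity_conn import equivalent_condition_to_str

cond = {"kb_id": "kb_42", "available_int": 1, "doc_id": ["d1", "d2"]}
print(equivalent_condition_to_str(cond))
# kb_id='kb_42' AND available_int=1 AND doc_id IN ('d1', 'd2')
```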
sdk/python/ragflow_sdk/modules/document.py CHANGED
@@ -50,8 +50,8 @@ class Document(Base):
50
  return res.content
51
 
52
 
53
- def list_chunks(self,page=1, page_size=30, keywords="", id:str=None):
54
- data={"keywords": keywords,"page":page,"page_size":page_size,"id":id}
55
  res = self.get(f'/datasets/{self.dataset_id}/documents/{self.id}/chunks', data)
56
  res = res.json()
57
  if res.get("code") == 0:
 
50
  return res.content
51
 
52
 
53
+ def list_chunks(self,page=1, page_size=30, keywords=""):
54
+ data={"keywords": keywords,"page":page,"page_size":page_size}
55
  res = self.get(f'/datasets/{self.dataset_id}/documents/{self.id}/chunks', data)
56
  res = res.json()
57
  if res.get("code") == 0:
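Since `list_chunks()` no longer accepts an `id` argument, callers that used it to fetch a single chunk can page through and match locally. A minimal sketch, under the assumption (not guaranteed by this diff) that `list_chunks()` returns an empty list once past the last page:

```python
def find_chunk(doc, chunk_id: str, page_size: int = 30):
    """Locate a chunk by id via paging, now that list_chunks() has no `id` filter."""
    page = 1
    while True:
        chunks = doc.list_chunks(page=page, page_size=page_size)
        if not chunks:
            return None
        for chunk in chunks:
            if chunk.id == chunk_id:
                return chunk
        page += 1
```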
sdk/python/test/t_chunk.py CHANGED
@@ -126,6 +126,7 @@ def test_delete_chunk_with_success(get_api_key_fixture):
126
  docs = ds.upload_documents(documents)
127
  doc = docs[0]
128
  chunk = doc.add_chunk(content="This is a chunk addition test")
 
129
  doc.delete_chunks([chunk.id])
130
 
131
 
@@ -146,6 +147,8 @@ def test_update_chunk_content(get_api_key_fixture):
146
  docs = ds.upload_documents(documents)
147
  doc = docs[0]
148
  chunk = doc.add_chunk(content="This is a chunk addition test")
 
 
149
  chunk.update({"content":"This is a updated content"})
150
 
151
  def test_update_chunk_available(get_api_key_fixture):
@@ -165,7 +168,9 @@ def test_update_chunk_available(get_api_key_fixture):
165
  docs = ds.upload_documents(documents)
166
  doc = docs[0]
167
  chunk = doc.add_chunk(content="This is a chunk addition test")
168
- chunk.update({"available":False})
 
 
169
 
170
 
171
  def test_retrieve_chunks(get_api_key_fixture):
 
126
  docs = ds.upload_documents(documents)
127
  doc = docs[0]
128
  chunk = doc.add_chunk(content="This is a chunk addition test")
129
+ sleep(5)
130
  doc.delete_chunks([chunk.id])
131
 
132
 
 
147
  docs = ds.upload_documents(documents)
148
  doc = docs[0]
149
  chunk = doc.add_chunk(content="This is a chunk addition test")
150
+ # For Elasticsearch, a newly added chunk is not searchable for a short time (~2s).
151
+ sleep(3)
152
  chunk.update({"content":"This is a updated content"})
153
 
154
  def test_update_chunk_available(get_api_key_fixture):
 
168
  docs = ds.upload_documents(documents)
169
  doc = docs[0]
170
  chunk = doc.add_chunk(content="This is a chunk addition test")
171
+ # For Elasticsearch, a newly added chunk is not searchable for a short time (~2s).
172
+ sleep(3)
173
+ chunk.update({"available":0})
174
 
175
 
176
  def test_retrieve_chunks(get_api_key_fixture):
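The `sleep(3)` calls added above work around Elasticsearch's near-real-time refresh: a newly indexed chunk only becomes searchable after a moment. If the fixed delay ever proves flaky, one alternative, sketched here and not part of this PR, is to poll until the chunk is visible:

```python
import time

def wait_until_searchable(doc, chunk_id: str, timeout: float = 10.0, interval: float = 0.5) -> bool:
    """Poll list_chunks() until chunk_id appears or the timeout expires (sketch only)."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        if any(c.id == chunk_id for c in doc.list_chunks()):
            return True
        time.sleep(interval)
    return False
```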