Kevin Hu committed on
Commit
f539fab
·
1 Parent(s): 1e02591

Add pagerank to KB. (#3809)

Browse files

### What problem does this PR solve?

#3794

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

api/apps/chunk_app.py CHANGED
@@ -227,12 +227,18 @@ def create():
227
  return get_data_error_result(message="Document not found!")
228
  d["kb_id"] = [doc.kb_id]
229
  d["docnm_kwd"] = doc.name
 
230
  d["doc_id"] = doc.id
231
 
232
  tenant_id = DocumentService.get_tenant_id(req["doc_id"])
233
  if not tenant_id:
234
  return get_data_error_result(message="Tenant not found!")
235
 
 
 
 
 
 
236
  embd_id = DocumentService.get_embd_id(req["doc_id"])
237
  embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
238
 
 
227
  return get_data_error_result(message="Document not found!")
228
  d["kb_id"] = [doc.kb_id]
229
  d["docnm_kwd"] = doc.name
230
+ d["title_tks"] = rag_tokenizer.tokenize(doc.name)
231
  d["doc_id"] = doc.id
232
 
233
  tenant_id = DocumentService.get_tenant_id(req["doc_id"])
234
  if not tenant_id:
235
  return get_data_error_result(message="Tenant not found!")
236
 
237
+ e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
238
+ if not e:
239
+ return get_data_error_result(message="Knowledgebase not found!")
240
+ if kb.pagerank: d["pagerank_fea"] = kb.pagerank
241
+
242
  embd_id = DocumentService.get_embd_id(req["doc_id"])
243
  embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
244
 
api/apps/kb_app.py CHANGED
@@ -102,6 +102,14 @@ def update():
102
  if not KnowledgebaseService.update_by_id(kb.id, req):
103
  return get_data_error_result()
104
 
 
 
 
 
 
 
 
 
105
  e, kb = KnowledgebaseService.get_by_id(kb.id)
106
  if not e:
107
  return get_data_error_result(
 
102
  if not KnowledgebaseService.update_by_id(kb.id, req):
103
  return get_data_error_result()
104
 
105
+ if kb.pagerank != req.get("pagerank", 0):
106
+ if req.get("pagerank", 0) > 0:
107
+ settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]},
108
+ search.index_name(kb.tenant_id), kb.id)
109
+ else:
110
+ settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"},
111
+ search.index_name(kb.tenant_id), kb.id)
112
+
113
  e, kb = KnowledgebaseService.get_by_id(kb.id)
114
  if not e:
115
  return get_data_error_result(
api/db/db_models.py CHANGED
@@ -703,6 +703,7 @@ class Knowledgebase(DataBaseModel):
703
  default=ParserType.NAIVE.value,
704
  index=True)
705
  parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
 
706
  status = CharField(
707
  max_length=1,
708
  null=True,
@@ -1076,4 +1077,10 @@ def migrate_db():
1076
  )
1077
  except Exception:
1078
  pass
 
 
 
 
 
 
1079
 
 
703
  default=ParserType.NAIVE.value,
704
  index=True)
705
  parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
706
+ pagerank = IntegerField(default=0, index=False)
707
  status = CharField(
708
  max_length=1,
709
  null=True,
 
1077
  )
1078
  except Exception:
1079
  pass
1080
+ try:
1081
+ migrate(
1082
+ migrator.add_column("knowledgebase", "pagerank", IntegerField(default=0, index=False))
1083
+ )
1084
+ except Exception:
1085
+ pass
1086
 
api/db/services/knowledgebase_service.py CHANGED
@@ -104,7 +104,8 @@ class KnowledgebaseService(CommonService):
104
  cls.model.token_num,
105
  cls.model.chunk_num,
106
  cls.model.parser_id,
107
- cls.model.parser_config]
 
108
  kbs = cls.model.select(*fields).join(Tenant, on=(
109
  (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
110
  (cls.model.id == kb_id),
 
104
  cls.model.token_num,
105
  cls.model.chunk_num,
106
  cls.model.parser_id,
107
+ cls.model.parser_config,
108
+ cls.model.pagerank]
109
  kbs = cls.model.select(*fields).join(Tenant, on=(
110
  (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
111
  (cls.model.id == kb_id),
api/db/services/llm_service.py CHANGED
@@ -191,15 +191,18 @@ class TenantLLMService(CommonService):
191
 
192
  num = 0
193
  try:
194
- tenant_llms = cls.query(tenant_id=tenant_id, llm_name=llm_name)
195
- if tenant_llms:
 
 
 
 
 
 
196
  tenant_llm = tenant_llms[0]
197
  num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens)\
198
  .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
199
  .execute()
200
- else:
201
- if not llm_factory: llm_factory = mdlnm
202
- num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
203
  except Exception:
204
  logging.exception("TenantLLMService.increase_usage got exception")
205
  return num
 
191
 
192
  num = 0
193
  try:
194
+ if llm_factory:
195
+ tenant_llms = cls.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory)
196
+ else:
197
+ tenant_llms = cls.query(tenant_id=tenant_id, llm_name=llm_name)
198
+ if not tenant_llms:
199
+ if not llm_factory: llm_factory = mdlnm
200
+ num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
201
+ else:
202
  tenant_llm = tenant_llms[0]
203
  num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens)\
204
  .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
205
  .execute()
 
 
 
206
  except Exception:
207
  logging.exception("TenantLLMService.increase_usage got exception")
208
  return num
api/db/services/task_service.py CHANGED
@@ -53,6 +53,7 @@ class TaskService(CommonService):
53
  Knowledgebase.tenant_id,
54
  Knowledgebase.language,
55
  Knowledgebase.embd_id,
 
56
  Tenant.img2txt_id,
57
  Tenant.asr_id,
58
  Tenant.llm_id,
 
53
  Knowledgebase.tenant_id,
54
  Knowledgebase.language,
55
  Knowledgebase.embd_id,
56
+ Knowledgebase.pagerank,
57
  Tenant.img2txt_id,
58
  Tenant.asr_id,
59
  Tenant.llm_id,
conf/infinity_mapping.json CHANGED
@@ -22,5 +22,6 @@
22
  "rank_int": {"type": "integer", "default": 0},
23
  "available_int": {"type": "integer", "default": 1},
24
  "knowledge_graph_kwd": {"type": "varchar", "default": ""},
25
- "entities_kwd": {"type": "varchar", "default": ""}
 
26
  }
 
22
  "rank_int": {"type": "integer", "default": 0},
23
  "available_int": {"type": "integer", "default": 1},
24
  "knowledge_graph_kwd": {"type": "varchar", "default": ""},
25
+ "entities_kwd": {"type": "varchar", "default": ""},
26
+ "pagerank_fea": {"type": "integer", "default": 0}
27
  }
rag/nlp/search.py CHANGED
@@ -75,7 +75,7 @@ class Dealer:
75
 
76
  src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
77
  "doc_id", "position_list", "knowledge_graph_kwd",
78
- "available_int", "content_with_weight"])
79
  kwds = set([])
80
 
81
  qst = req.get("question", "")
@@ -234,11 +234,13 @@ class Dealer:
234
  vector_column = f"q_{vector_size}_vec"
235
  zero_vector = [0.0] * vector_size
236
  ins_embd = []
 
237
  for chunk_id in sres.ids:
238
  vector = sres.field[chunk_id].get(vector_column, zero_vector)
239
  if isinstance(vector, str):
240
  vector = [float(v) for v in vector.split("\t")]
241
  ins_embd.append(vector)
 
242
  if not ins_embd:
243
  return [], [], []
244
 
@@ -257,7 +259,8 @@ class Dealer:
257
  ins_embd,
258
  keywords,
259
  ins_tw, tkweight, vtweight)
260
- return sim, tksim, vtsim
 
261
 
262
  def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
263
  vtweight=0.7, cfield="content_ltks"):
@@ -351,7 +354,7 @@ class Dealer:
351
  "vector": chunk.get(vector_column, zero_vector),
352
  "positions": json.loads(position_list)
353
  }
354
- if highlight:
355
  if id in sres.highlight:
356
  d["highlight"] = rmSpace(sres.highlight[id])
357
  else:
 
75
 
76
  src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
77
  "doc_id", "position_list", "knowledge_graph_kwd",
78
+ "available_int", "content_with_weight", "pagerank_fea"])
79
  kwds = set([])
80
 
81
  qst = req.get("question", "")
 
234
  vector_column = f"q_{vector_size}_vec"
235
  zero_vector = [0.0] * vector_size
236
  ins_embd = []
237
+ pageranks = []
238
  for chunk_id in sres.ids:
239
  vector = sres.field[chunk_id].get(vector_column, zero_vector)
240
  if isinstance(vector, str):
241
  vector = [float(v) for v in vector.split("\t")]
242
  ins_embd.append(vector)
243
+ pageranks.append(sres.field[chunk_id].get("pagerank_fea", 0))
244
  if not ins_embd:
245
  return [], [], []
246
 
 
259
  ins_embd,
260
  keywords,
261
  ins_tw, tkweight, vtweight)
262
+
263
+ return sim+np.array(pageranks, dtype=float), tksim, vtsim
264
 
265
  def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
266
  vtweight=0.7, cfield="content_ltks"):
 
354
  "vector": chunk.get(vector_column, zero_vector),
355
  "positions": json.loads(position_list)
356
  }
357
+ if highlight and sres.highlight:
358
  if id in sres.highlight:
359
  d["highlight"] = rmSpace(sres.highlight[id])
360
  else:
rag/svr/task_executor.py CHANGED
@@ -201,6 +201,7 @@ def build_chunks(task, progress_callback):
201
  "doc_id": task["doc_id"],
202
  "kb_id": str(task["kb_id"])
203
  }
 
204
  el = 0
205
  for ck in cks:
206
  d = copy.deepcopy(doc)
@@ -339,6 +340,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
339
  "docnm_kwd": row["name"],
340
  "title_tks": rag_tokenizer.tokenize(row["name"])
341
  }
 
342
  res = []
343
  tk_count = 0
344
  for content, vctr in chunks[original_length:]:
@@ -431,7 +433,7 @@ def do_handle_task(task):
431
  progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
432
  logging.info("Indexing {} elapsed: {:.2f}".format(task_document_name, timer() - start_ts))
433
  if doc_store_result:
434
- error_message = "Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
435
  progress_callback(-1, msg=error_message)
436
  settings.docStoreConn.delete({"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id)
437
  logging.error(error_message)
 
201
  "doc_id": task["doc_id"],
202
  "kb_id": str(task["kb_id"])
203
  }
204
+ if task["pagerank"]: doc["pagerank_fea"] = int(task["pagerank"])
205
  el = 0
206
  for ck in cks:
207
  d = copy.deepcopy(doc)
 
340
  "docnm_kwd": row["name"],
341
  "title_tks": rag_tokenizer.tokenize(row["name"])
342
  }
343
+ if row["pagerank"]: doc["pagerank_fea"] = int(row["pagerank"])
344
  res = []
345
  tk_count = 0
346
  for content, vctr in chunks[original_length:]:
 
433
  progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
434
  logging.info("Indexing {} elapsed: {:.2f}".format(task_document_name, timer() - start_ts))
435
  if doc_store_result:
436
+ error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
437
  progress_callback(-1, msg=error_message)
438
  settings.docStoreConn.delete({"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id)
439
  logging.error(error_message)
rag/utils/es_conn.py CHANGED
@@ -175,6 +175,7 @@ class ESConnection(DocStoreConnection):
175
  )
176
 
177
  if bqry:
 
178
  s = s.query(bqry)
179
  for field in highlightFields:
180
  s = s.highlight(field)
@@ -283,12 +284,16 @@ class ESConnection(DocStoreConnection):
283
  f"ESConnection.update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
284
  if str(e).find("Timeout") > 0:
285
  continue
 
286
  else:
287
  # update unspecific maybe-multiple documents
288
  bqry = Q("bool")
289
  for k, v in condition.items():
290
  if not isinstance(k, str) or not v:
291
  continue
 
 
 
292
  if isinstance(v, list):
293
  bqry.filter.append(Q("terms", **{k: v}))
294
  elif isinstance(v, str) or isinstance(v, int):
@@ -298,6 +303,9 @@ class ESConnection(DocStoreConnection):
298
  f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
299
  scripts = []
300
  for k, v in newValue.items():
 
 
 
301
  if (not isinstance(k, str) or not v) and k != "available_int":
302
  continue
303
  if isinstance(v, str):
@@ -307,21 +315,21 @@ class ESConnection(DocStoreConnection):
307
  else:
308
  raise Exception(
309
  f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
310
- ubq = UpdateByQuery(
311
- index=indexName).using(
312
- self.es).query(bqry)
313
- ubq = ubq.script(source="; ".join(scripts))
314
- ubq = ubq.params(refresh=True)
315
- ubq = ubq.params(slices=5)
316
- ubq = ubq.params(conflicts="proceed")
317
- for i in range(3):
318
- try:
319
- _ = ubq.execute()
320
- return True
321
- except Exception as e:
322
- logger.error("ESConnection.update got exception: " + str(e))
323
- if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
324
- continue
325
  return False
326
 
327
  def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
 
175
  )
176
 
177
  if bqry:
178
+ bqry.should.append(Q("rank_feature", field="pagerank_fea", linear={}, boost=10))
179
  s = s.query(bqry)
180
  for field in highlightFields:
181
  s = s.highlight(field)
 
284
  f"ESConnection.update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
285
  if str(e).find("Timeout") > 0:
286
  continue
287
+ return False
288
  else:
289
  # update unspecific maybe-multiple documents
290
  bqry = Q("bool")
291
  for k, v in condition.items():
292
  if not isinstance(k, str) or not v:
293
  continue
294
+ if k == "exist":
295
+ bqry.filter.append(Q("exists", field=v))
296
+ continue
297
  if isinstance(v, list):
298
  bqry.filter.append(Q("terms", **{k: v}))
299
  elif isinstance(v, str) or isinstance(v, int):
 
303
  f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
304
  scripts = []
305
  for k, v in newValue.items():
306
+ if k == "remove":
307
+ scripts.append(f"ctx._source.remove('{v}');")
308
+ continue
309
  if (not isinstance(k, str) or not v) and k != "available_int":
310
  continue
311
  if isinstance(v, str):
 
315
  else:
316
  raise Exception(
317
  f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
318
+ ubq = UpdateByQuery(
319
+ index=indexName).using(
320
+ self.es).query(bqry)
321
+ ubq = ubq.script(source="; ".join(scripts))
322
+ ubq = ubq.params(refresh=True)
323
+ ubq = ubq.params(slices=5)
324
+ ubq = ubq.params(conflicts="proceed")
325
+ for i in range(3):
326
+ try:
327
+ _ = ubq.execute()
328
+ return True
329
+ except Exception as e:
330
+ logger.error("ESConnection.update got exception: " + str(e))
331
+ if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
332
+ continue
333
  return False
334
 
335
  def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
sdk/python/ragflow_sdk/modules/dataset.py CHANGED
@@ -21,6 +21,7 @@ class DataSet(Base):
21
  self.chunk_count = 0
22
  self.chunk_method = "naive"
23
  self.parser_config = None
 
24
  for k in list(res_dict.keys()):
25
  if k not in self.__dict__:
26
  res_dict.pop(k)
 
21
  self.chunk_count = 0
22
  self.chunk_method = "naive"
23
  self.parser_config = None
24
+ self.pagerank = 0
25
  for k in list(res_dict.keys()):
26
  if k not in self.__dict__:
27
  res_dict.pop(k)