Kevin Hu
committed on
Commit
·
f539fab
1
Parent(s):
1e02591
Add pagerank to KB. (#3809)
Browse files
### What problem does this PR solve?
#3794
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- api/apps/chunk_app.py +6 -0
- api/apps/kb_app.py +8 -0
- api/db/db_models.py +7 -0
- api/db/services/knowledgebase_service.py +2 -1
- api/db/services/llm_service.py +8 -5
- api/db/services/task_service.py +1 -0
- conf/infinity_mapping.json +2 -1
- rag/nlp/search.py +6 -3
- rag/svr/task_executor.py +3 -1
- rag/utils/es_conn.py +23 -15
- sdk/python/ragflow_sdk/modules/dataset.py +1 -0
api/apps/chunk_app.py
CHANGED
|
@@ -227,12 +227,18 @@ def create():
|
|
| 227 |
return get_data_error_result(message="Document not found!")
|
| 228 |
d["kb_id"] = [doc.kb_id]
|
| 229 |
d["docnm_kwd"] = doc.name
|
|
|
|
| 230 |
d["doc_id"] = doc.id
|
| 231 |
|
| 232 |
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
| 233 |
if not tenant_id:
|
| 234 |
return get_data_error_result(message="Tenant not found!")
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
| 237 |
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
| 238 |
|
|
|
|
| 227 |
return get_data_error_result(message="Document not found!")
|
| 228 |
d["kb_id"] = [doc.kb_id]
|
| 229 |
d["docnm_kwd"] = doc.name
|
| 230 |
+
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
|
| 231 |
d["doc_id"] = doc.id
|
| 232 |
|
| 233 |
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
| 234 |
if not tenant_id:
|
| 235 |
return get_data_error_result(message="Tenant not found!")
|
| 236 |
|
| 237 |
+
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
| 238 |
+
if not e:
|
| 239 |
+
return get_data_error_result(message="Knowledgebase not found!")
|
| 240 |
+
if kb.pagerank: d["pagerank_fea"] = kb.pagerank
|
| 241 |
+
|
| 242 |
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
| 243 |
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
| 244 |
|
api/apps/kb_app.py
CHANGED
|
@@ -102,6 +102,14 @@ def update():
|
|
| 102 |
if not KnowledgebaseService.update_by_id(kb.id, req):
|
| 103 |
return get_data_error_result()
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
e, kb = KnowledgebaseService.get_by_id(kb.id)
|
| 106 |
if not e:
|
| 107 |
return get_data_error_result(
|
|
|
|
| 102 |
if not KnowledgebaseService.update_by_id(kb.id, req):
|
| 103 |
return get_data_error_result()
|
| 104 |
|
| 105 |
+
if kb.pagerank != req.get("pagerank", 0):
|
| 106 |
+
if req.get("pagerank", 0) > 0:
|
| 107 |
+
settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]},
|
| 108 |
+
search.index_name(kb.tenant_id), kb.id)
|
| 109 |
+
else:
|
| 110 |
+
settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"},
|
| 111 |
+
search.index_name(kb.tenant_id), kb.id)
|
| 112 |
+
|
| 113 |
e, kb = KnowledgebaseService.get_by_id(kb.id)
|
| 114 |
if not e:
|
| 115 |
return get_data_error_result(
|
api/db/db_models.py
CHANGED
|
@@ -703,6 +703,7 @@ class Knowledgebase(DataBaseModel):
|
|
| 703 |
default=ParserType.NAIVE.value,
|
| 704 |
index=True)
|
| 705 |
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
|
|
|
| 706 |
status = CharField(
|
| 707 |
max_length=1,
|
| 708 |
null=True,
|
|
@@ -1076,4 +1077,10 @@ def migrate_db():
|
|
| 1076 |
)
|
| 1077 |
except Exception:
|
| 1078 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1079 |
|
|
|
|
| 703 |
default=ParserType.NAIVE.value,
|
| 704 |
index=True)
|
| 705 |
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
| 706 |
+
pagerank = IntegerField(default=0, index=False)
|
| 707 |
status = CharField(
|
| 708 |
max_length=1,
|
| 709 |
null=True,
|
|
|
|
| 1077 |
)
|
| 1078 |
except Exception:
|
| 1079 |
pass
|
| 1080 |
+
try:
|
| 1081 |
+
migrate(
|
| 1082 |
+
migrator.add_column("knowledgebase", "pagerank", IntegerField(default=0, index=False))
|
| 1083 |
+
)
|
| 1084 |
+
except Exception:
|
| 1085 |
+
pass
|
| 1086 |
|
api/db/services/knowledgebase_service.py
CHANGED
|
@@ -104,7 +104,8 @@ class KnowledgebaseService(CommonService):
|
|
| 104 |
cls.model.token_num,
|
| 105 |
cls.model.chunk_num,
|
| 106 |
cls.model.parser_id,
|
| 107 |
-
cls.model.parser_config
|
|
|
|
| 108 |
kbs = cls.model.select(*fields).join(Tenant, on=(
|
| 109 |
(Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
|
| 110 |
(cls.model.id == kb_id),
|
|
|
|
| 104 |
cls.model.token_num,
|
| 105 |
cls.model.chunk_num,
|
| 106 |
cls.model.parser_id,
|
| 107 |
+
cls.model.parser_config,
|
| 108 |
+
cls.model.pagerank]
|
| 109 |
kbs = cls.model.select(*fields).join(Tenant, on=(
|
| 110 |
(Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
|
| 111 |
(cls.model.id == kb_id),
|
api/db/services/llm_service.py
CHANGED
|
@@ -191,15 +191,18 @@ class TenantLLMService(CommonService):
|
|
| 191 |
|
| 192 |
num = 0
|
| 193 |
try:
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
tenant_llm = tenant_llms[0]
|
| 197 |
num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens)\
|
| 198 |
.where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
|
| 199 |
.execute()
|
| 200 |
-
else:
|
| 201 |
-
if not llm_factory: llm_factory = mdlnm
|
| 202 |
-
num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
|
| 203 |
except Exception:
|
| 204 |
logging.exception("TenantLLMService.increase_usage got exception")
|
| 205 |
return num
|
|
|
|
| 191 |
|
| 192 |
num = 0
|
| 193 |
try:
|
| 194 |
+
if llm_factory:
|
| 195 |
+
tenant_llms = cls.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory)
|
| 196 |
+
else:
|
| 197 |
+
tenant_llms = cls.query(tenant_id=tenant_id, llm_name=llm_name)
|
| 198 |
+
if not tenant_llms:
|
| 199 |
+
if not llm_factory: llm_factory = mdlnm
|
| 200 |
+
num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
|
| 201 |
+
else:
|
| 202 |
tenant_llm = tenant_llms[0]
|
| 203 |
num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens)\
|
| 204 |
.where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
|
| 205 |
.execute()
|
|
|
|
|
|
|
|
|
|
| 206 |
except Exception:
|
| 207 |
logging.exception("TenantLLMService.increase_usage got exception")
|
| 208 |
return num
|
api/db/services/task_service.py
CHANGED
|
@@ -53,6 +53,7 @@ class TaskService(CommonService):
|
|
| 53 |
Knowledgebase.tenant_id,
|
| 54 |
Knowledgebase.language,
|
| 55 |
Knowledgebase.embd_id,
|
|
|
|
| 56 |
Tenant.img2txt_id,
|
| 57 |
Tenant.asr_id,
|
| 58 |
Tenant.llm_id,
|
|
|
|
| 53 |
Knowledgebase.tenant_id,
|
| 54 |
Knowledgebase.language,
|
| 55 |
Knowledgebase.embd_id,
|
| 56 |
+
Knowledgebase.pagerank,
|
| 57 |
Tenant.img2txt_id,
|
| 58 |
Tenant.asr_id,
|
| 59 |
Tenant.llm_id,
|
conf/infinity_mapping.json
CHANGED
|
@@ -22,5 +22,6 @@
|
|
| 22 |
"rank_int": {"type": "integer", "default": 0},
|
| 23 |
"available_int": {"type": "integer", "default": 1},
|
| 24 |
"knowledge_graph_kwd": {"type": "varchar", "default": ""},
|
| 25 |
-
"entities_kwd": {"type": "varchar", "default": ""}
|
|
|
|
| 26 |
}
|
|
|
|
| 22 |
"rank_int": {"type": "integer", "default": 0},
|
| 23 |
"available_int": {"type": "integer", "default": 1},
|
| 24 |
"knowledge_graph_kwd": {"type": "varchar", "default": ""},
|
| 25 |
+
"entities_kwd": {"type": "varchar", "default": ""},
|
| 26 |
+
"pagerank_fea": {"type": "integer", "default": 0}
|
| 27 |
}
|
rag/nlp/search.py
CHANGED
|
@@ -75,7 +75,7 @@ class Dealer:
|
|
| 75 |
|
| 76 |
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
|
| 77 |
"doc_id", "position_list", "knowledge_graph_kwd",
|
| 78 |
-
"available_int", "content_with_weight"])
|
| 79 |
kwds = set([])
|
| 80 |
|
| 81 |
qst = req.get("question", "")
|
|
@@ -234,11 +234,13 @@ class Dealer:
|
|
| 234 |
vector_column = f"q_{vector_size}_vec"
|
| 235 |
zero_vector = [0.0] * vector_size
|
| 236 |
ins_embd = []
|
|
|
|
| 237 |
for chunk_id in sres.ids:
|
| 238 |
vector = sres.field[chunk_id].get(vector_column, zero_vector)
|
| 239 |
if isinstance(vector, str):
|
| 240 |
vector = [float(v) for v in vector.split("\t")]
|
| 241 |
ins_embd.append(vector)
|
|
|
|
| 242 |
if not ins_embd:
|
| 243 |
return [], [], []
|
| 244 |
|
|
@@ -257,7 +259,8 @@ class Dealer:
|
|
| 257 |
ins_embd,
|
| 258 |
keywords,
|
| 259 |
ins_tw, tkweight, vtweight)
|
| 260 |
-
|
|
|
|
| 261 |
|
| 262 |
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
|
| 263 |
vtweight=0.7, cfield="content_ltks"):
|
|
@@ -351,7 +354,7 @@ class Dealer:
|
|
| 351 |
"vector": chunk.get(vector_column, zero_vector),
|
| 352 |
"positions": json.loads(position_list)
|
| 353 |
}
|
| 354 |
-
if highlight:
|
| 355 |
if id in sres.highlight:
|
| 356 |
d["highlight"] = rmSpace(sres.highlight[id])
|
| 357 |
else:
|
|
|
|
| 75 |
|
| 76 |
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
|
| 77 |
"doc_id", "position_list", "knowledge_graph_kwd",
|
| 78 |
+
"available_int", "content_with_weight", "pagerank_fea"])
|
| 79 |
kwds = set([])
|
| 80 |
|
| 81 |
qst = req.get("question", "")
|
|
|
|
| 234 |
vector_column = f"q_{vector_size}_vec"
|
| 235 |
zero_vector = [0.0] * vector_size
|
| 236 |
ins_embd = []
|
| 237 |
+
pageranks = []
|
| 238 |
for chunk_id in sres.ids:
|
| 239 |
vector = sres.field[chunk_id].get(vector_column, zero_vector)
|
| 240 |
if isinstance(vector, str):
|
| 241 |
vector = [float(v) for v in vector.split("\t")]
|
| 242 |
ins_embd.append(vector)
|
| 243 |
+
pageranks.append(sres.field[chunk_id].get("pagerank_fea", 0))
|
| 244 |
if not ins_embd:
|
| 245 |
return [], [], []
|
| 246 |
|
|
|
|
| 259 |
ins_embd,
|
| 260 |
keywords,
|
| 261 |
ins_tw, tkweight, vtweight)
|
| 262 |
+
|
| 263 |
+
return sim+np.array(pageranks, dtype=float), tksim, vtsim
|
| 264 |
|
| 265 |
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
|
| 266 |
vtweight=0.7, cfield="content_ltks"):
|
|
|
|
| 354 |
"vector": chunk.get(vector_column, zero_vector),
|
| 355 |
"positions": json.loads(position_list)
|
| 356 |
}
|
| 357 |
+
if highlight and sres.highlight:
|
| 358 |
if id in sres.highlight:
|
| 359 |
d["highlight"] = rmSpace(sres.highlight[id])
|
| 360 |
else:
|
rag/svr/task_executor.py
CHANGED
|
@@ -201,6 +201,7 @@ def build_chunks(task, progress_callback):
|
|
| 201 |
"doc_id": task["doc_id"],
|
| 202 |
"kb_id": str(task["kb_id"])
|
| 203 |
}
|
|
|
|
| 204 |
el = 0
|
| 205 |
for ck in cks:
|
| 206 |
d = copy.deepcopy(doc)
|
|
@@ -339,6 +340,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
|
| 339 |
"docnm_kwd": row["name"],
|
| 340 |
"title_tks": rag_tokenizer.tokenize(row["name"])
|
| 341 |
}
|
|
|
|
| 342 |
res = []
|
| 343 |
tk_count = 0
|
| 344 |
for content, vctr in chunks[original_length:]:
|
|
@@ -431,7 +433,7 @@ def do_handle_task(task):
|
|
| 431 |
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
|
| 432 |
logging.info("Indexing {} elapsed: {:.2f}".format(task_document_name, timer() - start_ts))
|
| 433 |
if doc_store_result:
|
| 434 |
-
error_message = "Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
| 435 |
progress_callback(-1, msg=error_message)
|
| 436 |
settings.docStoreConn.delete({"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id)
|
| 437 |
logging.error(error_message)
|
|
|
|
| 201 |
"doc_id": task["doc_id"],
|
| 202 |
"kb_id": str(task["kb_id"])
|
| 203 |
}
|
| 204 |
+
if task["pagerank"]: doc["pagerank_fea"] = int(task["pagerank"])
|
| 205 |
el = 0
|
| 206 |
for ck in cks:
|
| 207 |
d = copy.deepcopy(doc)
|
|
|
|
| 340 |
"docnm_kwd": row["name"],
|
| 341 |
"title_tks": rag_tokenizer.tokenize(row["name"])
|
| 342 |
}
|
| 343 |
+
if row["pagerank"]: doc["pagerank_fea"] = int(row["pagerank"])
|
| 344 |
res = []
|
| 345 |
tk_count = 0
|
| 346 |
for content, vctr in chunks[original_length:]:
|
|
|
|
| 433 |
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
|
| 434 |
logging.info("Indexing {} elapsed: {:.2f}".format(task_document_name, timer() - start_ts))
|
| 435 |
if doc_store_result:
|
| 436 |
+
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
| 437 |
progress_callback(-1, msg=error_message)
|
| 438 |
settings.docStoreConn.delete({"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id)
|
| 439 |
logging.error(error_message)
|
rag/utils/es_conn.py
CHANGED
|
@@ -175,6 +175,7 @@ class ESConnection(DocStoreConnection):
|
|
| 175 |
)
|
| 176 |
|
| 177 |
if bqry:
|
|
|
|
| 178 |
s = s.query(bqry)
|
| 179 |
for field in highlightFields:
|
| 180 |
s = s.highlight(field)
|
|
@@ -283,12 +284,16 @@ class ESConnection(DocStoreConnection):
|
|
| 283 |
f"ESConnection.update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
|
| 284 |
if str(e).find("Timeout") > 0:
|
| 285 |
continue
|
|
|
|
| 286 |
else:
|
| 287 |
# update unspecific maybe-multiple documents
|
| 288 |
bqry = Q("bool")
|
| 289 |
for k, v in condition.items():
|
| 290 |
if not isinstance(k, str) or not v:
|
| 291 |
continue
|
|
|
|
|
|
|
|
|
|
| 292 |
if isinstance(v, list):
|
| 293 |
bqry.filter.append(Q("terms", **{k: v}))
|
| 294 |
elif isinstance(v, str) or isinstance(v, int):
|
|
@@ -298,6 +303,9 @@ class ESConnection(DocStoreConnection):
|
|
| 298 |
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
| 299 |
scripts = []
|
| 300 |
for k, v in newValue.items():
|
|
|
|
|
|
|
|
|
|
| 301 |
if (not isinstance(k, str) or not v) and k != "available_int":
|
| 302 |
continue
|
| 303 |
if isinstance(v, str):
|
|
@@ -307,21 +315,21 @@ class ESConnection(DocStoreConnection):
|
|
| 307 |
else:
|
| 308 |
raise Exception(
|
| 309 |
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
return False
|
| 326 |
|
| 327 |
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
|
|
|
| 175 |
)
|
| 176 |
|
| 177 |
if bqry:
|
| 178 |
+
bqry.should.append(Q("rank_feature", field="pagerank_fea", linear={}, boost=10))
|
| 179 |
s = s.query(bqry)
|
| 180 |
for field in highlightFields:
|
| 181 |
s = s.highlight(field)
|
|
|
|
| 284 |
f"ESConnection.update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
|
| 285 |
if str(e).find("Timeout") > 0:
|
| 286 |
continue
|
| 287 |
+
return False
|
| 288 |
else:
|
| 289 |
# update unspecific maybe-multiple documents
|
| 290 |
bqry = Q("bool")
|
| 291 |
for k, v in condition.items():
|
| 292 |
if not isinstance(k, str) or not v:
|
| 293 |
continue
|
| 294 |
+
if k == "exist":
|
| 295 |
+
bqry.filter.append(Q("exists", field=v))
|
| 296 |
+
continue
|
| 297 |
if isinstance(v, list):
|
| 298 |
bqry.filter.append(Q("terms", **{k: v}))
|
| 299 |
elif isinstance(v, str) or isinstance(v, int):
|
|
|
|
| 303 |
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
| 304 |
scripts = []
|
| 305 |
for k, v in newValue.items():
|
| 306 |
+
if k == "remove":
|
| 307 |
+
scripts.append(f"ctx._source.remove('{v}');")
|
| 308 |
+
continue
|
| 309 |
if (not isinstance(k, str) or not v) and k != "available_int":
|
| 310 |
continue
|
| 311 |
if isinstance(v, str):
|
|
|
|
| 315 |
else:
|
| 316 |
raise Exception(
|
| 317 |
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
|
| 318 |
+
ubq = UpdateByQuery(
|
| 319 |
+
index=indexName).using(
|
| 320 |
+
self.es).query(bqry)
|
| 321 |
+
ubq = ubq.script(source="; ".join(scripts))
|
| 322 |
+
ubq = ubq.params(refresh=True)
|
| 323 |
+
ubq = ubq.params(slices=5)
|
| 324 |
+
ubq = ubq.params(conflicts="proceed")
|
| 325 |
+
for i in range(3):
|
| 326 |
+
try:
|
| 327 |
+
_ = ubq.execute()
|
| 328 |
+
return True
|
| 329 |
+
except Exception as e:
|
| 330 |
+
logger.error("ESConnection.update got exception: " + str(e))
|
| 331 |
+
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
| 332 |
+
continue
|
| 333 |
return False
|
| 334 |
|
| 335 |
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
sdk/python/ragflow_sdk/modules/dataset.py
CHANGED
|
@@ -21,6 +21,7 @@ class DataSet(Base):
|
|
| 21 |
self.chunk_count = 0
|
| 22 |
self.chunk_method = "naive"
|
| 23 |
self.parser_config = None
|
|
|
|
| 24 |
for k in list(res_dict.keys()):
|
| 25 |
if k not in self.__dict__:
|
| 26 |
res_dict.pop(k)
|
|
|
|
| 21 |
self.chunk_count = 0
|
| 22 |
self.chunk_method = "naive"
|
| 23 |
self.parser_config = None
|
| 24 |
+
self.pagerank = 0
|
| 25 |
for k in list(res_dict.keys()):
|
| 26 |
if k not in self.__dict__:
|
| 27 |
res_dict.pop(k)
|