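"""Knowledge-graph construction pipeline for a knowledge base (KB).

`Dealer` extracts entities and relations from document chunks and merges them
into the KB-level graph; `WithResolution` de-duplicates entities on an existing
graph; `WithCommunity` extracts community reports and indexes them as chunks.
"""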
import json
import logging
from functools import partial, reduce

import networkx as nx

from api import settings
from graphrag.entity_resolution import EntityResolution
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
from graphrag.general.extractor import Extractor
from graphrag.general.graph_extractor import DEFAULT_ENTITY_TYPES
from graphrag.utils import (
    chunk_id,
    get_entity,
    get_graph,
    get_relation,
    graph_merge,
    set_entity,
    set_graph,
    set_relation,
    update_nodes_pagerank_nhop_neighbour,
)
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import RedisDistributedLock


class Dealer:
    def __init__(self,
                 extractor: Extractor,
                 tenant_id: str,
                 kb_id: str,
                 llm_bdl,
                 chunks: list[tuple[str, str]],
                 language,
                 entity_types=DEFAULT_ENTITY_TYPES,
                 embed_bdl=None,
                 callback=None
                 ):
        # Deduplicate the document ids carried by the incoming (doc_id, text) chunks.
        docids = list(set([docid for docid, _ in chunks]))
        self.llm_bdl = llm_bdl
        self.embed_bdl = embed_bdl
        ext = extractor(self.llm_bdl, language=language,
                        entity_types=entity_types,
                        get_entity=partial(get_entity, tenant_id, kb_id),
                        set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
                        get_relation=partial(get_relation, tenant_id, kb_id),
                        set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
        ents, rels = ext(chunks, callback)

        # Entities become nodes and relations become weighted edges of an
        # undirected graph.
        self.graph = nx.Graph()
        for en in ents:
            self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])

        for rel in rels:
            self.graph.add_edge(
                rel["src_id"],
                rel["tgt_id"],
                weight=rel["weight"],
            )

        # A per-KB distributed lock (held for up to an hour) serializes graph
        # updates so concurrent ingestions don't overwrite each other's merges.
        with RedisDistributedLock(kb_id, 60 * 60):
            old_graph, old_doc_ids = get_graph(tenant_id, kb_id)
            if old_graph is not None:
                logging.info("Merging with the existing graph...")
                self.graph = reduce(graph_merge, [old_graph, self.graph])
            update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
            if old_doc_ids:
                docids.extend(old_doc_ids)
                docids = list(set(docids))
            set_graph(tenant_id, kb_id, self.graph, docids)
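
# A minimal usage sketch (hedged): `GraphExtractor` is one extractor this class
# can be wired with; the bundle objects and chunk payloads below are
# illustrative assumptions, not part of this module.
#
#   from graphrag.general.graph_extractor import GraphExtractor
#   dealer = Dealer(GraphExtractor, "tenant-1", "kb-1", llm_bdl,
#                   chunks=[("doc-1", "chunk text")], language="English",
#                   embed_bdl=embed_bdl)
#   graph = dealer.graph  # merged networkx graph for the KB
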
class WithResolution(Dealer):
    def __init__(self,
                 tenant_id: str,
                 kb_id: str,
                 llm_bdl,
                 embed_bdl=None,
                 callback=None
                 ):
        self.llm_bdl = llm_bdl
        self.embed_bdl = embed_bdl

        with RedisDistributedLock(kb_id, 60 * 60):
            self.graph, doc_ids = get_graph(tenant_id, kb_id)
            if not self.graph:
                logging.error(f"Failed to fetch the graph. tenant_id:{tenant_id}, kb_id:{kb_id}")
                if callback:
                    callback(-1, msg="Failed to fetch the graph.")
                return

            if callback:
                callback(msg="Fetched the existing graph.")
            er = EntityResolution(self.llm_bdl,
                                  get_entity=partial(get_entity, tenant_id, kb_id),
                                  set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
                                  get_relation=partial(get_relation, tenant_id, kb_id),
                                  set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
            reso = er(self.graph)
            self.graph = reso.graph
            logging.info("Graph resolution is done. Removed {} nodes.".format(len(reso.removed_entities)))
            if callback:
                callback(msg="Graph resolution is done. Removed {} nodes.".format(len(reso.removed_entities)))
            update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
            set_graph(tenant_id, kb_id, self.graph, doc_ids)

        # Purge index entries that reference the entities merged away during
        # resolution: relations pointing from or to them, and the entity
        # records themselves.
        settings.docStoreConn.delete({
            "knowledge_graph_kwd": "relation",
            "kb_id": kb_id,
            "from_entity_kwd": reso.removed_entities
        }, search.index_name(tenant_id), kb_id)
        settings.docStoreConn.delete({
            "knowledge_graph_kwd": "relation",
            "kb_id": kb_id,
            "to_entity_kwd": reso.removed_entities
        }, search.index_name(tenant_id), kb_id)
        settings.docStoreConn.delete({
            "knowledge_graph_kwd": "entity",
            "kb_id": kb_id,
            "entity_kwd": reso.removed_entities
        }, search.index_name(tenant_id), kb_id)
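
# A minimal usage sketch (hedged; the bundle objects are assumptions):
#
#   wr = WithResolution("tenant-1", "kb-1", llm_bdl, embed_bdl=embed_bdl)
#   # wr.graph now holds the KB graph with duplicate entities merged.
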
class WithCommunity(Dealer):
    def __init__(self,
                 tenant_id: str,
                 kb_id: str,
                 llm_bdl,
                 embed_bdl=None,
                 callback=None
                 ):
        self.community_structure = None
        self.community_reports = None
        self.llm_bdl = llm_bdl
        self.embed_bdl = embed_bdl

        with RedisDistributedLock(kb_id, 60 * 60):
            self.graph, doc_ids = get_graph(tenant_id, kb_id)
            if not self.graph:
                logging.error(f"Failed to fetch the graph. tenant_id:{tenant_id}, kb_id:{kb_id}")
                if callback:
                    callback(-1, msg="Failed to fetch the graph.")
                return
            if callback:
                callback(msg="Fetched the existing graph.")

            cr = CommunityReportsExtractor(self.llm_bdl,
                                           get_entity=partial(get_entity, tenant_id, kb_id),
                                           set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
                                           get_relation=partial(get_relation, tenant_id, kb_id),
                                           set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
            cr = cr(self.graph, callback=callback)
            self.community_structure = cr.structured_output
            self.community_reports = cr.output
            set_graph(tenant_id, kb_id, self.graph, doc_ids)

        if callback:
            callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))

        # Drop previously indexed community reports for this KB before
        # re-indexing the fresh ones.
        settings.docStoreConn.delete({
            "knowledge_graph_kwd": "community_report",
            "kb_id": kb_id
        }, search.index_name(tenant_id), kb_id)

        # Index each community report as a chunk; findings are flattened into
        # an "evidences" text block alongside the narrative report.
        for stru, rep in zip(self.community_structure, self.community_reports):
            obj = {
                "report": rep,
                "evidences": "\n".join([f["explanation"] for f in stru["findings"]])
            }
            chunk = {
                "docnm_kwd": stru["title"],
                "title_tks": rag_tokenizer.tokenize(stru["title"]),
                "content_with_weight": json.dumps(obj, ensure_ascii=False),
                "content_ltks": rag_tokenizer.tokenize(obj["report"] + " " + obj["evidences"]),
                "knowledge_graph_kwd": "community_report",
                "weight_flt": stru["weight"],
                "entities_kwd": stru["entities"],
                "important_kwd": stru["entities"],
                "kb_id": kb_id,
                "source_id": doc_ids,
                "available_int": 0
            }
            chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
            settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
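
# A minimal usage sketch (hedged; the bundle objects are assumptions):
#
#   wc = WithCommunity("tenant-1", "kb-1", llm_bdl, embed_bdl=embed_bdl)
#   # wc.community_structure / wc.community_reports then hold the detected
#   # communities and their reports, which have also been indexed above as
#   # "community_report" chunks.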