diff --git a/graph/README.md b/agent/README.md similarity index 100% rename from graph/README.md rename to agent/README.md diff --git a/graph/README_zh.md b/agent/README_zh.md similarity index 100% rename from graph/README_zh.md rename to agent/README_zh.md diff --git a/graph/__init__.py b/agent/__init__.py similarity index 100% rename from graph/__init__.py rename to agent/__init__.py diff --git a/graph/canvas.py b/agent/canvas.py similarity index 98% rename from graph/canvas.py rename to agent/canvas.py index 2d61915b3395d057c1cee0a6fcd6a44a94f82fdb..566a2bf311c44db5222daef759b37136f6a1019c 100644 --- a/graph/canvas.py +++ b/agent/canvas.py @@ -22,9 +22,9 @@ from functools import partial import pandas as pd -from graph.component import component_class -from graph.component.base import ComponentBase -from graph.settings import flow_logger, DEBUG +from agent.component import component_class +from agent.component.base import ComponentBase +from agent.settings import flow_logger, DEBUG class Canvas(ABC): diff --git a/graph/component/__init__.py b/agent/component/__init__.py similarity index 100% rename from graph/component/__init__.py rename to agent/component/__init__.py diff --git a/graph/component/answer.py b/agent/component/answer.py similarity index 97% rename from graph/component/answer.py rename to agent/component/answer.py index 3973111d780d4b9f7460c7c2e4c838db87b6c073..aedc0dd979332009697f109164517be88206353b 100644 --- a/graph/component/answer.py +++ b/agent/component/answer.py @@ -19,7 +19,7 @@ from functools import partial import pandas as pd -from graph.component.base import ComponentBase, ComponentParamBase +from agent.component.base import ComponentBase, ComponentParamBase class AnswerParam(ComponentParamBase): diff --git a/graph/component/arxiv.py b/agent/component/arxiv.py similarity index 94% rename from graph/component/arxiv.py rename to agent/component/arxiv.py index 7d485a43bde36e8c1d3e0f463f848464c1edc4fb..6b47ded9d15e872a49d2d524e4f343a7a9e7e000 100644 --- a/graph/component/arxiv.py +++ b/agent/component/arxiv.py @@ -13,13 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import random from abc import ABC -from functools import partial import arxiv import pandas as pd -from graph.settings import DEBUG -from graph.component.base import ComponentBase, ComponentParamBase +from agent.settings import DEBUG +from agent.component.base import ComponentBase, ComponentParamBase class ArXivParam(ComponentParamBase): diff --git a/graph/component/baidu.py b/agent/component/baidu.py similarity index 93% rename from graph/component/baidu.py rename to agent/component/baidu.py index 024917cc08fa29c3033a36f4a5525fc6c6d50bd7..0a866aab0c4faa2006ec0568fdcc5fefa54f9583 100644 --- a/graph/component/baidu.py +++ b/agent/component/baidu.py @@ -19,8 +19,8 @@ from functools import partial import pandas as pd import requests import re -from graph.settings import DEBUG -from graph.component.base import ComponentBase, ComponentParamBase +from agent.settings import DEBUG +from agent.component.base import ComponentBase, ComponentParamBase class BaiduParam(ComponentParamBase): diff --git a/graph/component/base.py b/agent/component/base.py similarity index 99% rename from graph/component/base.py rename to agent/component/base.py index 06574e32d7f896613939666908ce3ff234d9dfd2..ec26dc5b5728db0c2bb6c15b8e162e09f239d02d 100644 --- a/graph/component/base.py +++ b/agent/component/base.py @@ -23,8 +23,8 @@ from typing import List, Dict, Tuple, Union import pandas as pd -from graph import settings -from graph.settings import flow_logger, DEBUG +from agent import settings +from agent.settings import flow_logger, DEBUG _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params" _DEPRECATED_PARAMS = "_deprecated_params" diff --git a/graph/component/begin.py b/agent/component/begin.py similarity index 94% rename from graph/component/begin.py rename to agent/component/begin.py index 024e4fed3e8bd337f7dfdf35ff7ac6b0ab77dfbf..037a8a057c202f254de748e1399c397d41ad1961 100644 --- a/graph/component/begin.py +++ b/agent/component/begin.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import json from functools import partial - import pandas as pd -from graph.component.base import ComponentBase, ComponentParamBase +from agent.component.base import ComponentBase, ComponentParamBase + class BeginParam(ComponentParamBase): diff --git a/graph/component/bing.py b/agent/component/bing.py similarity index 95% rename from graph/component/bing.py rename to agent/component/bing.py index 128358816a25dc2215c1e073e57c872501c29d2c..14fce05559e63db60d7cef3be496fd3b35b10670 100644 --- a/graph/component/bing.py +++ b/agent/component/bing.py @@ -16,8 +16,8 @@ from abc import ABC import requests import pandas as pd -from graph.settings import DEBUG -from graph.component.base import ComponentBase, ComponentParamBase +from agent.settings import DEBUG +from agent.component.base import ComponentBase, ComponentParamBase class BingParam(ComponentParamBase): diff --git a/graph/component/categorize.py b/agent/component/categorize.py similarity index 96% rename from graph/component/categorize.py rename to agent/component/categorize.py index c9ef19164e187b6c9e3f3ee8d4d7d17c99bf127d..ee398e493915a2b50c1d63d45f25a60f8d5e5af3 100644 --- a/graph/component/categorize.py +++ b/agent/component/categorize.py @@ -14,13 +14,10 @@ # limitations under the License. 
# from abc import ABC - -import pandas as pd - from api.db import LLMType from api.db.services.llm_service import LLMBundle -from graph.component import GenerateParam, Generate -from graph.settings import DEBUG +from agent.component import GenerateParam, Generate +from agent.settings import DEBUG class CategorizeParam(GenerateParam): diff --git a/graph/component/cite.py b/agent/component/cite.py similarity index 97% rename from graph/component/cite.py rename to agent/component/cite.py index 4d38b4c4737b7386d5cb308549b16cf611cdb257..f50bc4e81394558bb4cb84884379b8210e886a5f 100644 --- a/graph/component/cite.py +++ b/agent/component/cite.py @@ -21,7 +21,7 @@ from api.db import LLMType from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api.settings import retrievaler -from graph.component.base import ComponentBase, ComponentParamBase +from agent.component.base import ComponentBase, ComponentParamBase class CiteParam(ComponentParamBase): diff --git a/graph/component/duckduckgo.py b/agent/component/duckduckgo.py similarity index 94% rename from graph/component/duckduckgo.py rename to agent/component/duckduckgo.py index e796b4af6d28ff7ae5b74256f30db25c52c5c0de..2ee011369dbb73c21d73054ae2bd8a075aeb8f7a 100644 --- a/graph/component/duckduckgo.py +++ b/agent/component/duckduckgo.py @@ -13,13 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import random from abc import ABC -from functools import partial from duckduckgo_search import DDGS import pandas as pd -from graph.settings import DEBUG -from graph.component.base import ComponentBase, ComponentParamBase +from agent.settings import DEBUG +from agent.component.base import ComponentBase, ComponentParamBase class DuckDuckGoParam(ComponentParamBase): diff --git a/graph/component/generate.py b/agent/component/generate.py similarity index 98% rename from graph/component/generate.py rename to agent/component/generate.py index 924af903f7c225eb5bafa4721735ff9baff4e60f..c3f50f4de28a17a5284fae5314f07cd5d01b394d 100644 --- a/graph/component/generate.py +++ b/agent/component/generate.py @@ -15,13 +15,11 @@ # import re from functools import partial - import pandas as pd - from api.db import LLMType from api.db.services.llm_service import LLMBundle from api.settings import retrievaler -from graph.component.base import ComponentBase, ComponentParamBase +from agent.component.base import ComponentBase, ComponentParamBase class GenerateParam(ComponentParamBase): diff --git a/graph/component/google.py b/agent/component/google.py similarity index 96% rename from graph/component/google.py rename to agent/component/google.py index eb7cd50de7736f128632efe9d957f4f1908e6241..3ac477040168053675ce7cd44cbbf9c3e6393622 100644 --- a/graph/component/google.py +++ b/agent/component/google.py @@ -16,8 +16,8 @@ from abc import ABC from serpapi import GoogleSearch import pandas as pd -from graph.settings import DEBUG -from graph.component.base import ComponentBase, ComponentParamBase +from agent.settings import DEBUG +from agent.component.base import ComponentBase, ComponentParamBase class GoogleParam(ComponentParamBase): diff --git a/graph/component/googlescholar.py b/agent/component/googlescholar.py similarity index 93% rename from graph/component/googlescholar.py rename to agent/component/googlescholar.py index 8da7ba55fdeb1edaed84cbeef47285215190e25e..f895c6cf1fe0f271369141882cdb429c39232fdb 100644 --- a/graph/component/googlescholar.py 
+++ b/agent/component/googlescholar.py @@ -15,8 +15,8 @@ # from abc import ABC import pandas as pd -from graph.settings import DEBUG -from graph.component.base import ComponentBase, ComponentParamBase +from agent.settings import DEBUG +from agent.component.base import ComponentBase, ComponentParamBase from scholarly import scholarly diff --git a/graph/component/keyword.py b/agent/component/keyword.py similarity index 95% rename from graph/component/keyword.py rename to agent/component/keyword.py index 7a326cbfe4b1e734e57a2e48b6cd64aeb36ca772..c5083efae2d510ba0335763156e598b91af82d99 100644 --- a/graph/component/keyword.py +++ b/agent/component/keyword.py @@ -17,8 +17,8 @@ import re from abc import ABC from api.db import LLMType from api.db.services.llm_service import LLMBundle -from graph.component import GenerateParam, Generate -from graph.settings import DEBUG +from agent.component import GenerateParam, Generate +from agent.settings import DEBUG class KeywordExtractParam(GenerateParam): diff --git a/graph/component/message.py b/agent/component/message.py similarity index 94% rename from graph/component/message.py rename to agent/component/message.py index 87d290fef35a835971a34a8a272874d77a817bcc..a193dd122ba259b43ec22a43444c8da18a396d15 100644 --- a/graph/component/message.py +++ b/agent/component/message.py @@ -16,10 +16,7 @@ import random from abc import ABC from functools import partial - -import pandas as pd - -from graph.component.base import ComponentBase, ComponentParamBase +from agent.component.base import ComponentBase, ComponentParamBase class MessageParam(ComponentParamBase): diff --git a/graph/component/pubmed.py b/agent/component/pubmed.py similarity index 94% rename from graph/component/pubmed.py rename to agent/component/pubmed.py index 1c76cf93c87ec078e9ad8cf1d1783f934259121f..b586d67e70639a096222e1006abb2e715bebcb0b 100644 --- a/graph/component/pubmed.py +++ b/agent/component/pubmed.py @@ -13,14 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import random from abc import ABC -from functools import partial from Bio import Entrez import pandas as pd import xml.etree.ElementTree as ET -from graph.settings import DEBUG -from graph.component.base import ComponentBase, ComponentParamBase +from agent.settings import DEBUG +from agent.component.base import ComponentBase, ComponentParamBase class PubMedParam(ComponentParamBase): diff --git a/graph/component/relevant.py b/agent/component/relevant.py similarity index 98% rename from graph/component/relevant.py rename to agent/component/relevant.py index ab2fa318f1fecca50308ad062317b9c8c58df853..8f246f3d27f33e6ead6094aad9a853df7f92c338 100644 --- a/graph/component/relevant.py +++ b/agent/component/relevant.py @@ -16,7 +16,7 @@ from abc import ABC from api.db import LLMType from api.db.services.llm_service import LLMBundle -from graph.component import GenerateParam, Generate +from agent.component import GenerateParam, Generate from rag.utils import num_tokens_from_string, encoder diff --git a/graph/component/retrieval.py b/agent/component/retrieval.py similarity index 98% rename from graph/component/retrieval.py rename to agent/component/retrieval.py index a765556cefd67148f25a161d71a1701c353815d9..84893d185270fe38694cf6b0f2041c5d74f18d5c 100644 --- a/graph/component/retrieval.py +++ b/agent/component/retrieval.py @@ -21,7 +21,7 @@ from api.db import LLMType from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api.settings import retrievaler -from graph.component.base import ComponentBase, ComponentParamBase +from agent.component.base import ComponentBase, ComponentParamBase class RetrievalParam(ComponentParamBase): diff --git a/graph/component/rewrite.py b/agent/component/rewrite.py similarity index 98% rename from graph/component/rewrite.py rename to agent/component/rewrite.py index 11a374660b411e3e5432e04d632f94b7408ca8c6..614168291d95280c2e2f2a5a7f47b07b93e0a3c5 100644 --- a/graph/component/rewrite.py +++ b/agent/component/rewrite.py @@ -16,7 +16,7 @@ from abc import ABC from api.db import LLMType from api.db.services.llm_service import LLMBundle -from graph.component import GenerateParam, Generate +from agent.component import GenerateParam, Generate class RewriteQuestionParam(GenerateParam): diff --git a/graph/component/switch.py b/agent/component/switch.py similarity index 90% rename from graph/component/switch.py rename to agent/component/switch.py index a431143d3140498acfa6f63f303088654f3d4924..21b78efd0e98b588cbb7f66c9f5a3101c1db4911 100644 --- a/graph/component/switch.py +++ b/agent/component/switch.py @@ -16,12 +16,7 @@ from abc import ABC import pandas as pd - -from api.db import LLMType -from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.llm_service import LLMBundle -from api.settings import retrievaler -from graph.component.base import ComponentBase, ComponentParamBase +from agent.component.base import ComponentBase, ComponentParamBase class SwitchParam(ComponentParamBase): diff --git a/graph/component/wikipedia.py b/agent/component/wikipedia.py similarity index 96% rename from graph/component/wikipedia.py rename to agent/component/wikipedia.py index 6a4b015d7aa1311cb8e4cfd7af46dc043936887f..811ac8443396b53221072414ae7ed500a4d75a70 100644 --- a/graph/component/wikipedia.py +++ b/agent/component/wikipedia.py @@ -18,8 +18,8 @@ from abc import ABC from functools import partial import wikipedia import pandas as pd -from graph.settings import DEBUG -from 
graph.component.base import ComponentBase, ComponentParamBase +from agent.settings import DEBUG +from agent.component.base import ComponentBase, ComponentParamBase class WikipediaParam(ComponentParamBase): diff --git a/graph/settings.py b/agent/settings.py similarity index 100% rename from graph/settings.py rename to agent/settings.py diff --git a/graph/templates/HR_callout_zh.json b/agent/templates/HR_callout_zh.json similarity index 100% rename from graph/templates/HR_callout_zh.json rename to agent/templates/HR_callout_zh.json diff --git a/graph/templates/customer_service.json b/agent/templates/customer_service.json similarity index 100% rename from graph/templates/customer_service.json rename to agent/templates/customer_service.json diff --git a/graph/templates/general_chat_bot.json b/agent/templates/general_chat_bot.json similarity index 100% rename from graph/templates/general_chat_bot.json rename to agent/templates/general_chat_bot.json diff --git a/graph/templates/interpreter.json b/agent/templates/interpreter.json similarity index 100% rename from graph/templates/interpreter.json rename to agent/templates/interpreter.json diff --git a/graph/templates/websearch_assistant.json b/agent/templates/websearch_assistant.json similarity index 100% rename from graph/templates/websearch_assistant.json rename to agent/templates/websearch_assistant.json diff --git a/graph/test/client.py b/agent/test/client.py similarity index 95% rename from graph/test/client.py rename to agent/test/client.py index 3682a51721b082fd0953ab67d682ce775f5c97d4..be9115290cf45c8d2d73152b745fb978c92793d3 100644 --- a/graph/test/client.py +++ b/agent/test/client.py @@ -16,9 +16,8 @@ import argparse import os from functools import partial -import readline -from graph.canvas import Canvas -from graph.settings import DEBUG +from agent.canvas import Canvas +from agent.settings import DEBUG if __name__ == '__main__': parser = argparse.ArgumentParser() diff --git a/graph/test/dsl_examples/categorize.json b/agent/test/dsl_examples/categorize.json similarity index 100% rename from graph/test/dsl_examples/categorize.json rename to agent/test/dsl_examples/categorize.json diff --git a/graph/test/dsl_examples/customer_service.json b/agent/test/dsl_examples/customer_service.json similarity index 100% rename from graph/test/dsl_examples/customer_service.json rename to agent/test/dsl_examples/customer_service.json diff --git a/graph/test/dsl_examples/headhunter_zh.json b/agent/test/dsl_examples/headhunter_zh.json similarity index 100% rename from graph/test/dsl_examples/headhunter_zh.json rename to agent/test/dsl_examples/headhunter_zh.json diff --git a/graph/test/dsl_examples/intergreper.json b/agent/test/dsl_examples/intergreper.json similarity index 100% rename from graph/test/dsl_examples/intergreper.json rename to agent/test/dsl_examples/intergreper.json diff --git a/graph/test/dsl_examples/interpreter.json b/agent/test/dsl_examples/interpreter.json similarity index 100% rename from graph/test/dsl_examples/interpreter.json rename to agent/test/dsl_examples/interpreter.json diff --git a/graph/test/dsl_examples/keyword_wikipedia_and_generate.json b/agent/test/dsl_examples/keyword_wikipedia_and_generate.json similarity index 100% rename from graph/test/dsl_examples/keyword_wikipedia_and_generate.json rename to agent/test/dsl_examples/keyword_wikipedia_and_generate.json diff --git a/graph/test/dsl_examples/retrieval_and_generate.json b/agent/test/dsl_examples/retrieval_and_generate.json similarity index 100% rename from 
graph/test/dsl_examples/retrieval_and_generate.json rename to agent/test/dsl_examples/retrieval_and_generate.json diff --git a/graph/test/dsl_examples/retrieval_categorize_and_generate.json b/agent/test/dsl_examples/retrieval_categorize_and_generate.json similarity index 100% rename from graph/test/dsl_examples/retrieval_categorize_and_generate.json rename to agent/test/dsl_examples/retrieval_categorize_and_generate.json diff --git a/graph/test/dsl_examples/retrieval_relevant_and_generate.json b/agent/test/dsl_examples/retrieval_relevant_and_generate.json similarity index 100% rename from graph/test/dsl_examples/retrieval_relevant_and_generate.json rename to agent/test/dsl_examples/retrieval_relevant_and_generate.json diff --git a/graph/test/dsl_examples/retrieval_relevant_keyword_baidu_and_generate.json b/agent/test/dsl_examples/retrieval_relevant_keyword_baidu_and_generate.json similarity index 100% rename from graph/test/dsl_examples/retrieval_relevant_keyword_baidu_and_generate.json rename to agent/test/dsl_examples/retrieval_relevant_keyword_baidu_and_generate.json diff --git a/graph/test/dsl_examples/retrieval_relevant_rewrite_and_generate.json b/agent/test/dsl_examples/retrieval_relevant_rewrite_and_generate.json similarity index 100% rename from graph/test/dsl_examples/retrieval_relevant_rewrite_and_generate.json rename to agent/test/dsl_examples/retrieval_relevant_rewrite_and_generate.json diff --git a/api/apps/api_app.py b/api/apps/api_app.py index bae0527d537a31ed9e6582a9c46b24d35f145154..8d7f00e8bc19b11ac5de53fa3132b3d9006292d9 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -20,7 +20,7 @@ from datetime import datetime, timedelta from flask import request, Response from flask_login import login_required, current_user -from api.db import FileType, ParserType, FileSource, LLMType +from api.db import FileType, ParserType, FileSource from api.db.db_models import APIToken, API4Conversation, Task, File from api.db.services import duplicate_name from api.db.services.api_service import APITokenService, API4ConversationService @@ -29,7 +29,6 @@ from api.db.services.document_service import DocumentService from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.llm_service import TenantLLMService from api.db.services.task_service import queue_tasks, TaskService from api.db.services.user_service import UserTenantService from api.settings import RetCode, retrievaler @@ -38,7 +37,6 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge from itsdangerous import URLSafeTimedSerializer from api.utils.file_utils import filename_type, thumbnail -from rag.nlp import keyword_extraction from rag.utils.minio_conn import MINIO diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index 40562c3fe32640ed81990fbc732dede187d71d05..0f17b8ebf2e9364b299a68056b69439e46162604 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -15,15 +15,12 @@ # import json from functools import partial - from flask import request, Response from flask_login import login_required, current_user - -from api.db.db_models import UserCanvas from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService from api.utils import get_uuid from api.utils.api_utils import get_json_result, server_error_response, validate_request -from graph.canvas import Canvas +from agent.canvas import Canvas 
@manager.route('/templates', methods=['GET']) diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index c42ec4a980f1ef6df762a8e2a9b46ed2a70a3384..f65c53b396c9613ad82d36c7d62c1696091fe269 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -14,6 +14,8 @@ # limitations under the License. # import datetime +import json +import traceback from flask import request from flask_login import login_required, current_user @@ -29,7 +31,7 @@ from api.db.services.llm_service import TenantLLMService from api.db.services.user_service import UserTenantService from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.db.services.document_service import DocumentService -from api.settings import RetCode, retrievaler +from api.settings import RetCode, retrievaler, kg_retrievaler from api.utils.api_utils import get_json_result import hashlib import re @@ -61,7 +63,8 @@ def list_chunk(): for id in sres.ids: d = { "chunk_id": id, - "content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[id].get( + "content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[ + id].get( "content_with_weight", ""), "doc_id": sres.field[id]["doc_id"], "docnm_kwd": sres.field[id]["docnm_kwd"], @@ -136,11 +139,11 @@ def set(): tenant_id = DocumentService.get_tenant_id(req["doc_id"]) if not tenant_id: return get_data_error_result(retmsg="Tenant not found!") - + embd_id = DocumentService.get_embd_id(req["doc_id"]) embd_mdl = TenantLLMService.model_instance( tenant_id, LLMType.EMBEDDING.value, embd_id) - + e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: return get_data_error_result(retmsg="Document not found!") @@ -185,7 +188,7 @@ def switch(): @manager.route('/rm', methods=['POST']) @login_required -@validate_request("chunk_ids","doc_id") +@validate_request("chunk_ids", "doc_id") def rm(): req = request.json try: @@ -230,11 +233,11 @@ def create(): tenant_id = DocumentService.get_tenant_id(req["doc_id"]) if not tenant_id: return get_data_error_result(retmsg="Tenant not found!") - + embd_id = DocumentService.get_embd_id(req["doc_id"]) embd_mdl = TenantLLMService.model_instance( tenant_id, LLMType.EMBEDDING.value, embd_id) - + v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) v = 0.1 * v[0] + 0.9 * v[1] d["q_%d_vec" % len(v)] = v.tolist() @@ -277,9 +280,10 @@ def retrieval_test(): chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT) question += keyword_extraction(chat_mdl, question) - ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, - similarity_threshold, vector_similarity_weight, top, - doc_ids, rerank_mdl=rerank_mdl) + retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler + ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, + similarity_threshold, vector_similarity_weight, top, + doc_ids, rerank_mdl=rerank_mdl) for c in ranks["chunks"]: if "vector" in c: del c["vector"] @@ -290,3 +294,25 @@ def retrieval_test(): return get_json_result(data=False, retmsg=f'No chunk found! 
Check the chunk status please!', retcode=RetCode.DATA_ERROR) return server_error_response(e) + + +@manager.route('/knowledge_graph', methods=['GET']) +@login_required +def knowledge_graph(): + doc_id = request.args["doc_id"] + req = { + "doc_ids":[doc_id], + "knowledge_graph_kwd": ["graph", "mind_map"] + } + tenant_id = DocumentService.get_tenant_id(doc_id) + sres = retrievaler.search(req, search.index_name(tenant_id)) + obj = {"graph": {}, "mind_map": {}} + for id in sres.ids[:2]: + ty = sres.field[id]["knowledge_graph_kwd"] + try: + obj[ty] = json.loads(sres.field[id]["content_with_weight"]) + except Exception as e: + print(traceback.format_exc(), flush=True) + + return get_json_result(data=obj) + diff --git a/api/apps/dataset_api.py b/api/apps/dataset_api.py index 111cf2ec2caf318c1d9550a52826ca3fefa7bb35..c0d1c86d3a0fe6f3872718e7d8694fb36861fe71 100644 --- a/api/apps/dataset_api.py +++ b/api/apps/dataset_api.py @@ -623,7 +623,7 @@ def doc_parse_callback(doc_id, prog=None, msg=""): if cancel: raise Exception("The parsing process has been cancelled!") - +""" def doc_parse(binary, doc_name, parser_name, tenant_id, doc_id): match parser_name: case "book": @@ -656,6 +656,7 @@ def doc_parse(binary, doc_name, parser_name, tenant_id, doc_id): return False return True + """ @manager.route("/<dataset_id>/documents/<document_id>/status", methods=["POST"]) diff --git a/api/db/__init__.py b/api/db/__init__.py index 81d79c1a7956ef4b98e005e15c4b088733f37962..505d603b560762a001ed978f4324b90e521455b4 100644 --- a/api/db/__init__.py +++ b/api/db/__init__.py @@ -85,6 +85,7 @@ class ParserType(StrEnum): PICTURE = "picture" ONE = "one" AUDIO = "audio" + KG = "knowledge_graph" class FileSource(StrEnum): diff --git a/api/db/init_data.py b/api/db/init_data.py index a042f4500eb86f6cce17f86d0c1ec21124d4d1ea..7ced350c73c3570c0301c621dc04b87f5a4757be 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -122,7 +122,7 @@ def init_llm_factory(): LLMService.filter_delete([LLMService.model.fid == "QAnything"]) TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"}) TenantService.filter_update([1 == 1], { - "parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio"}) + "parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph"}) ## insert openai two embedding models to the current openai user.
print("Start to insert 2 OpenAI embedding models...") tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()]) @@ -145,7 +145,7 @@ def init_llm_factory(): """ drop table llm; drop table llm_factories; - update tenant set parser_ids='naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio'; + update tenant set parser_ids='naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph'; alter table knowledgebase modify avatar longtext; alter table user modify avatar longtext; alter table dialog modify icon longtext; @@ -153,7 +153,7 @@ def init_llm_factory(): def add_graph_templates(): - dir = os.path.join(get_project_base_directory(), "graph", "templates") + dir = os.path.join(get_project_base_directory(), "agent", "templates") for fnm in os.listdir(dir): try: cnvs = json.load(open(os.path.join(dir, fnm), "r")) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 20cb27cfdc4d209b461abb6006af6ea5e4d726d2..ab43b944cbb2361bfec5fcaebc75657199faf553 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -18,12 +18,12 @@ import json import re from copy import deepcopy -from api.db import LLMType +from api.db import LLMType, ParserType from api.db.db_models import Dialog, Conversation from api.db.services.common_service import CommonService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle -from api.settings import chat_logger, retrievaler +from api.settings import chat_logger, retrievaler, kg_retrievaler from rag.app.resume import forbidden_select_fields4resume from rag.nlp import keyword_extraction from rag.nlp.search import index_name @@ -101,6 +101,9 @@ def chat(dialog, messages, stream=True, **kwargs): yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []} return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []} + is_kg = all([kb.parser_id == ParserType.KG for kb in kbs]) + retr = retrievaler if not is_kg else kg_retrievaler + questions = [m["content"] for m in messages if m["role"] == "user"] embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0]) if llm_id2llm_type(dialog.llm_id) == "image2text": @@ -138,7 +141,7 @@ def chat(dialog, messages, stream=True, **kwargs): else: if prompt_config.get("keyword", False): questions[-1] += keyword_extraction(chat_mdl, questions[-1]) - kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, + kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold, dialog.vector_similarity_weight, doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None, @@ -147,7 +150,7 @@ def chat(dialog, messages, stream=True, **kwargs): #self-rag if dialog.prompt_config.get("self_rag") and not relevant(dialog.tenant_id, dialog.llm_id, questions[-1], knowledges): questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1]) - kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, + kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, 
dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold, dialog.vector_similarity_weight, doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None, @@ -179,7 +182,7 @@ def chat(dialog, messages, stream=True, **kwargs): nonlocal prompt_config, knowledges, kwargs, kbinfos refs = [] if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): - answer, idx = retrievaler.insert_citations(answer, + answer, idx = retr.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index 942207c067d34d5984237feed470b7f74a18a7ad..33565bcb87a61e37eef8280adfc407fe61246a7e 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -139,6 +139,8 @@ def queue_tasks(doc, bucket, name): page_size = doc["parser_config"].get("task_page_size", 22) if doc["parser_id"] == "one": page_size = 1000000000 + if doc["parser_id"] == "knowledge_graph": + page_size = 1000000000 if not do_layout: page_size = 1000000000 page_ranges = doc["parser_config"].get("pages") diff --git a/api/settings.py b/api/settings.py index 1adbd5014a3703cc344316c23819996cfa2d80b3..90efd13a0e8580f8d1b35b01b1ccc0cd019791c8 100644 --- a/api/settings.py +++ b/api/settings.py @@ -34,6 +34,7 @@ chat_logger = getLogger("chat") from rag.utils.es_conn import ELASTICSEARCH from rag.nlp import search +from graphrag import search as kg_search from api.utils import get_base_config, decrypt_database_config API_VERSION = "v1" @@ -131,7 +132,7 @@ IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] API_KEY = LLM.get("api_key", "") PARSERS = LLM.get( "parsers", - "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio") + "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph") # distribution DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False) @@ -204,6 +205,7 @@ PRIVILEGE_COMMAND_WHITELIST = [] CHECK_NODES_IDENTITY = False retrievaler = search.Dealer(ELASTICSEARCH) +kg_retrievaler = kg_search.KGSearch(ELASTICSEARCH) class CustomEnum(Enum): diff --git a/graphrag/claim_extractor.py b/graphrag/claim_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..ee9e7f323700794d1e3aa64f716a7f4229b8b4ab --- /dev/null +++ b/graphrag/claim_extractor.py @@ -0,0 +1,278 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +""" +Reference: + - [graphrag](https://github.com/microsoft/graphrag) +""" +import argparse +import json +import logging +import re +import traceback +from dataclasses import dataclass +from typing import Any + +import tiktoken + +from graphrag.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT +from rag.llm.chat_model import Base as CompletionLLM +from graphrag.utils import ErrorHandlerFn, perform_variable_replacements + +DEFAULT_TUPLE_DELIMITER = "<|>" +DEFAULT_RECORD_DELIMITER = "##" +DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>" +CLAIM_MAX_GLEANINGS = 1 +log = logging.getLogger(__name__) + + +@dataclass +class ClaimExtractorResult: + """Claim extractor result class definition.""" + + output: list[dict] + source_docs: dict[str, Any] + + +class ClaimExtractor: + """Claim extractor class definition.""" + + _llm: CompletionLLM + _extraction_prompt: str + _summary_prompt: str + _output_formatter_prompt: str + _input_text_key: str + _input_entity_spec_key: str + _input_claim_description_key: str + _tuple_delimiter_key: str + _record_delimiter_key: str + _completion_delimiter_key: str + _max_gleanings: int + _on_error: ErrorHandlerFn + + def __init__( + self, + llm_invoker: CompletionLLM, + extraction_prompt: str | None = None, + input_text_key: str | None = None, + input_entity_spec_key: str | None = None, + input_claim_description_key: str | None = None, + input_resolved_entities_key: str | None = None, + tuple_delimiter_key: str | None = None, + record_delimiter_key: str | None = None, + completion_delimiter_key: str | None = None, + encoding_model: str | None = None, + max_gleanings: int | None = None, + on_error: ErrorHandlerFn | None = None, + ): + """Init method definition.""" + self._llm = llm_invoker + self._extraction_prompt = extraction_prompt or CLAIM_EXTRACTION_PROMPT + self._input_text_key = input_text_key or "input_text" + self._input_entity_spec_key = input_entity_spec_key or "entity_specs" + self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter" + self._record_delimiter_key = record_delimiter_key or "record_delimiter" + self._completion_delimiter_key = ( + completion_delimiter_key or "completion_delimiter" + ) + self._input_claim_description_key = ( + input_claim_description_key or "claim_description" + ) + self._input_resolved_entities_key = ( + input_resolved_entities_key or "resolved_entities" + ) + self._max_gleanings = ( + max_gleanings if max_gleanings is not None else CLAIM_MAX_GLEANINGS + ) + self._on_error = on_error or (lambda _e, _s, _d: None) + + # Construct the looping arguments + encoding = tiktoken.get_encoding(encoding_model or "cl100k_base") + yes = encoding.encode("YES") + no = encoding.encode("NO") + self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1} + + def __call__( + self, inputs: dict[str, Any], prompt_variables: dict | None = None + ) -> ClaimExtractorResult: + """Call method definition.""" + if prompt_variables is None: + prompt_variables = {} + texts = inputs[self._input_text_key] + entity_spec = str(inputs[self._input_entity_spec_key]) + claim_description = inputs[self._input_claim_description_key] + resolved_entities = inputs.get(self._input_resolved_entities_key, {}) + source_doc_map = {} + + prompt_args = { + self._input_entity_spec_key: entity_spec, + self._input_claim_description_key: claim_description, + self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key) + or DEFAULT_TUPLE_DELIMITER, + self._record_delimiter_key: 
prompt_variables.get(self._record_delimiter_key) + or DEFAULT_RECORD_DELIMITER, + self._completion_delimiter_key: prompt_variables.get( + self._completion_delimiter_key + ) + or DEFAULT_COMPLETION_DELIMITER, + } + + all_claims: list[dict] = [] + for doc_index, text in enumerate(texts): + document_id = f"d{doc_index}" + try: + claims = self._process_document(prompt_args, text, doc_index) + all_claims += [ + self._clean_claim(c, document_id, resolved_entities) for c in claims + ] + source_doc_map[document_id] = text + except Exception as e: + log.exception("error extracting claim") + self._on_error( + e, + traceback.format_exc(), + {"doc_index": doc_index, "text": text}, + ) + continue + + return ClaimExtractorResult( + output=all_claims, + source_docs=source_doc_map, + ) + + def _clean_claim( + self, claim: dict, document_id: str, resolved_entities: dict + ) -> dict: + # clean the parsed claims to remove any claims with status = False + obj = claim.get("object_id", claim.get("object")) + subject = claim.get("subject_id", claim.get("subject")) + + # If subject or object in resolved entities, then replace with resolved entity + obj = resolved_entities.get(obj, obj) + subject = resolved_entities.get(subject, subject) + claim["object_id"] = obj + claim["subject_id"] = subject + claim["doc_id"] = document_id + return claim + + def _process_document( + self, prompt_args: dict, doc, doc_index: int + ) -> list[dict]: + record_delimiter = prompt_args.get( + self._record_delimiter_key, DEFAULT_RECORD_DELIMITER + ) + completion_delimiter = prompt_args.get( + self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER + ) + variables = { + self._input_text_key: doc, + **prompt_args, + } + text = perform_variable_replacements(self._extraction_prompt, variables=variables) + gen_conf = {"temperature": 0.5} + results = self._llm.chat(text, [], gen_conf) + claims = results.strip().removesuffix(completion_delimiter) + history = [{"role": "system", "content": text}, {"role": "assistant", "content": results}] + + # Repeat to ensure we maximize entity count + for i in range(self._max_gleanings): + text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables) + history.append({"role": "user", "content": text}) + extension = self._llm.chat("", history, gen_conf) + claims += record_delimiter + extension.strip().removesuffix( + completion_delimiter + ) + + # If this isn't the last loop, check to see if we should continue + if i >= self._max_gleanings - 1: + break + + history.append({"role": "assistant", "content": extension}) + history.append({"role": "user", "content": LOOP_PROMPT}) + continuation = self._llm.chat("", history, self._loop_args) + if continuation != "YES": + break + + result = self._parse_claim_tuples(claims, prompt_args) + for r in result: + r["doc_id"] = f"{doc_index}" + return result + + def _parse_claim_tuples( + self, claims: str, prompt_variables: dict + ) -> list[dict[str, Any]]: + """Parse claim tuples.""" + record_delimiter = prompt_variables.get( + self._record_delimiter_key, DEFAULT_RECORD_DELIMITER + ) + completion_delimiter = prompt_variables.get( + self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER + ) + tuple_delimiter = prompt_variables.get( + self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER + ) + + def pull_field(index: int, fields: list[str]) -> str | None: + return fields[index].strip() if len(fields) > index else None + + result: list[dict[str, Any]] = [] + claims_values = ( + 
claims.strip().removesuffix(completion_delimiter).split(record_delimiter) + ) + for claim in claims_values: + claim = claim.strip().removeprefix("(").removesuffix(")") + claim = re.sub(r".*Output:", "", claim) + + # Ignore the completion delimiter + if claim == completion_delimiter: + continue + + claim_fields = claim.split(tuple_delimiter) + o = { + "subject_id": pull_field(0, claim_fields), + "object_id": pull_field(1, claim_fields), + "type": pull_field(2, claim_fields), + "status": pull_field(3, claim_fields), + "start_date": pull_field(4, claim_fields), + "end_date": pull_field(5, claim_fields), + "description": pull_field(6, claim_fields), + "source_text": pull_field(7, claim_fields), + "doc_id": pull_field(8, claim_fields), + } + if any([not o["subject_id"], not o["object_id"], o["subject_id"].lower() == "none", o["object_id"] == "none"]): + continue + result.append(o) + return result + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True) + parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True) + args = parser.parse_args() + + from api.db import LLMType + from api.db.services.llm_service import LLMBundle + from api.settings import retrievaler + + ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT)) + docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, max_count=12, fields=["content_with_weight"])] + info = { + "input_text": docs, + "entity_specs": "organization, person", + "claim_description": "" + } + claim = ex(info) + print(json.dumps(claim.output, ensure_ascii=False, indent=2)) diff --git a/graphrag/claim_prompt.py b/graphrag/claim_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..5678ffd1348ddf96686fe08a5a121214f09aa81e --- /dev/null +++ b/graphrag/claim_prompt.py @@ -0,0 +1,84 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Reference: + - [graphrag](https://github.com/microsoft/graphrag) +""" + +CLAIM_EXTRACTION_PROMPT = """ +################ +-Target activity- +################ +You are an intelligent assistant that helps a human analyst to analyze claims against certain entities presented in a text document. + +################ +-Goal- +################ +Given a text document that is potentially relevant to this activity, an entity specification, and a claim description, extract all entities that match the entity specification and all claims against those entities. + +################ +-Steps- +################ + - 1. Extract all named entities that match the predefined entity specification. Entity specification can either be a list of entity names or a list of entity types. + - 2. For each entity identified in step 1, extract all claims associated with the entity. 
Claims need to match the specified claim description, and the entity should be the subject of the claim. + For each claim, extract the following information: + - Subject: name of the entity that is subject of the claim, capitalized. The subject entity is one that committed the action described in the claim. Subject needs to be one of the named entities identified in step 1. + - Object: name of the entity that is object of the claim, capitalized. The object entity is one that either reports/handles or is affected by the action described in the claim. If object entity is unknown, use **NONE**. + - Claim Type: overall category of the claim, capitalized. Name it in a way that can be repeated across multiple text inputs, so that similar claims share the same claim type + - Claim Status: **TRUE**, **FALSE**, or **SUSPECTED**. TRUE means the claim is confirmed, FALSE means the claim is found to be False, SUSPECTED means the claim is not verified. + - Claim Description: Detailed description explaining the reasoning behind the claim, together with all the related evidence and references. + - Claim Date: Period (start_date, end_date) when the claim was made. Both start_date and end_date should be in ISO-8601 format. If the claim was made on a single date rather than a date range, set the same date for both start_date and end_date. If date is unknown, return **NONE**. + - Claim Source Text: List of **all** quotes from the original text that are relevant to the claim. + + - 3. Format each claim as (<subject_entity>{tuple_delimiter}<object_entity>{tuple_delimiter}<claim_type>{tuple_delimiter}<claim_status>{tuple_delimiter}<claim_start_date>{tuple_delimiter}<claim_end_date>{tuple_delimiter}<claim_description>{tuple_delimiter}<claim_source>) + - 4. Return output in language of the 'Text' as a single list of all the claims identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter. + - 5. If nothing satisfies the above requirements, just keep the output empty. + - 6. When finished, output {completion_delimiter} + +################ +-Examples- +################ +Example 1: +Entity specification: organization +Claim description: red flags associated with an entity +Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015. +Output: +(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.) +{completion_delimiter} + +########################### +Example 2: +Entity specification: Company A, Person C +Claim description: red flags associated with an entity +Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015. 
+Output: +(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.) +{record_delimiter} +(PERSON C{tuple_delimiter}NONE{tuple_delimiter}CORRUPTION{tuple_delimiter}SUSPECTED{tuple_delimiter}2015-01-01T00:00:00{tuple_delimiter}2015-12-30T00:00:00{tuple_delimiter}Person C was suspected of engaging in corruption activities in 2015{tuple_delimiter}The company is owned by Person C who was suspected of engaging in corruption activities in 2015) +{completion_delimiter} + +################ +-Real Data- +################ +Use the following input for your answer. +Entity specification: {entity_specs} +Claim description: {claim_description} +Text: {input_text} +Output:""" + + +CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format(see 'Steps', start with the 'Output').\nOutput: " +LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES {tuple_delimiter} NO if there are still entities that need to be added.\n" \ No newline at end of file diff --git a/graphrag/community_report_prompt.py b/graphrag/community_report_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..3b042d08b4fab6367e9e7e74b2bb9cde2d8f4292 --- /dev/null +++ b/graphrag/community_report_prompt.py @@ -0,0 +1,171 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Reference: + - [graphrag](https://github.com/microsoft/graphrag) +""" + +COMMUNITY_REPORT_PROMPT = """ +You are an AI assistant that helps a human analyst to perform general information discovery. Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network. + +# Goal +Write a comprehensive report of a community, given a list of entities that belong to the community as well as their relationships and optional associated claims. The report will be used to inform decision-makers about information associated with the community and their potential impact. The content of this report includes an overview of the community's key entities, their legal compliance, technical capabilities, reputation, and noteworthy claims. + +# Report Structure + +The report should include the following sections: + +- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title. 
+- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities. +- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community. +- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating. +- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive. + +Return output as a well-formed JSON-formatted string with the following format(in language of 'Text' content): + {{ + "title": <report_title>, + "summary": <executive_summary>, + "rating": <impact_severity_rating>, + "rating_explanation": <rating_explanation>, + "findings": [ + {{ + "summary":<insight_1_summary>, + "explanation": <insight_1_explanation> + }}, + {{ + "summary":<insight_2_summary>, + "explanation": <insight_2_explanation> + }} + ] + }} + +# Grounding Rules + +Points supported by data should list their data references as follows: + +"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]." + +Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]." + +where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +# Example Input +----------- +Text: + +-Entities- + +id,entity,description +5,VERDANT OASIS PLAZA,Verdant Oasis Plaza is the location of the Unity March +6,HARMONY ASSEMBLY,Harmony Assembly is an organization that is holding a march at Verdant Oasis Plaza + +-Relationships- + +id,source,target,description +37,VERDANT OASIS PLAZA,UNITY MARCH,Verdant Oasis Plaza is the location of the Unity March +38,VERDANT OASIS PLAZA,HARMONY ASSEMBLY,Harmony Assembly is holding a march at Verdant Oasis Plaza +39,VERDANT OASIS PLAZA,UNITY MARCH,The Unity March is taking place at Verdant Oasis Plaza +40,VERDANT OASIS PLAZA,TRIBUNE SPOTLIGHT,Tribune Spotlight is reporting on the Unity march taking place at Verdant Oasis Plaza +41,VERDANT OASIS PLAZA,BAILEY ASADI,Bailey Asadi is speaking at Verdant Oasis Plaza about the march +43,HARMONY ASSEMBLY,UNITY MARCH,Harmony Assembly is organizing the Unity March + +Output: +{{ + "title": "Verdant Oasis Plaza and Unity March", + "summary": "The community revolves around the Verdant Oasis Plaza, which is the location of the Unity March. The plaza has relationships with the Harmony Assembly, Unity March, and Tribune Spotlight, all of which are associated with the march event.", + "rating": 5.0, + "rating_explanation": "The impact severity rating is moderate due to the potential for unrest or conflict during the Unity March.", + "findings": [ + {{ + "summary": "Verdant Oasis Plaza as the central location", + "explanation": "Verdant Oasis Plaza is the central entity in this community, serving as the location for the Unity March. This plaza is the common link between all other entities, suggesting its significance in the community. 
The plaza's association with the march could potentially lead to issues such as public disorder or conflict, depending on the nature of the march and the reactions it provokes. [Data: Entities (5), Relationships (37, 38, 39, 40, 41,+more)]" + }}, + {{ + "summary": "Harmony Assembly's role in the community", + "explanation": "Harmony Assembly is another key entity in this community, being the organizer of the march at Verdant Oasis Plaza. The nature of Harmony Assembly and its march could be a potential source of threat, depending on their objectives and the reactions they provoke. The relationship between Harmony Assembly and the plaza is crucial in understanding the dynamics of this community. [Data: Entities(6), Relationships (38, 43)]" + }}, + {{ + "summary": "Unity March as a significant event", + "explanation": "The Unity March is a significant event taking place at Verdant Oasis Plaza. This event is a key factor in the community's dynamics and could be a potential source of threat, depending on the nature of the march and the reactions it provokes. The relationship between the march and the plaza is crucial in understanding the dynamics of this community. [Data: Relationships (39)]" + }}, + {{ + "summary": "Role of Tribune Spotlight", + "explanation": "Tribune Spotlight is reporting on the Unity March taking place in Verdant Oasis Plaza. This suggests that the event has attracted media attention, which could amplify its impact on the community. The role of Tribune Spotlight could be significant in shaping public perception of the event and the entities involved. [Data: Relationships (40)]" + }} + ] +}} + + +# Real Data + +Use the following text for your answer. Do not make anything up in your answer. + +Text: + +-Entities- +{entity_df} + +-Relationships- +{relation_df} + +The report should include the following sections: + +- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title. +- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities. +- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community. +- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating. +- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive. + +Return output as a well-formed JSON-formatted string with the following format(in language of 'Text' content): + {{ + "title": <report_title>, + "summary": <executive_summary>, + "rating": <impact_severity_rating>, + "rating_explanation": <rating_explanation>, + "findings": [ + {{ + "summary":<insight_1_summary>, + "explanation": <insight_1_explanation> + }}, + {{ + "summary":<insight_2_summary>, + "explanation": <insight_2_explanation> + }} + ] + }} + +# Grounding Rules + +Points supported by data should list their data references as follows: + +"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]." + +Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]." 
+ +where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + +Output:""" \ No newline at end of file diff --git a/graphrag/community_reports_extractor.py b/graphrag/community_reports_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..cdc0c2e5d52cc11bad19b3dff8e6c4785a6b3c0f --- /dev/null +++ b/graphrag/community_reports_extractor.py @@ -0,0 +1,135 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Reference: + - [graphrag](https://github.com/microsoft/graphrag) +""" + +import json +import logging +import re +import traceback +from dataclasses import dataclass +from typing import Any, List + +import networkx as nx +import pandas as pd + +from graphrag import leiden +from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT +from graphrag.leiden import add_community_info2graph +from rag.llm.chat_model import Base as CompletionLLM +from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types + +log = logging.getLogger(__name__) + + +@dataclass +class CommunityReportsResult: + """Community reports result class definition.""" + + output: List[str] + structured_output: List[dict] + + +class CommunityReportsExtractor: + """Community reports extractor class definition.""" + + _llm: CompletionLLM + _extraction_prompt: str + _output_formatter_prompt: str + _on_error: ErrorHandlerFn + _max_report_length: int + + def __init__( + self, + llm_invoker: CompletionLLM, + extraction_prompt: str | None = None, + on_error: ErrorHandlerFn | None = None, + max_report_length: int | None = None, + ): + """Init method definition.""" + self._llm = llm_invoker + self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT + self._on_error = on_error or (lambda _e, _s, _d: None) + self._max_report_length = max_report_length or 1500 + + def __call__(self, graph: nx.Graph): + communities: dict[str, dict[str, List]] = leiden.run(graph, {}) + relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)]) + res_str = [] + res_dict = [] + for level, comm in communities.items(): + for cm_id, ents in comm.items(): + weight = ents["weight"] + ents = ents["nodes"] + ent_df = pd.DataFrame([{"entity": n, **graph.nodes[n]} for n in ents]) + rela_df = relations_df[(relations_df["source"].isin(ents)) | (relations_df["target"].isin(ents))].reset_index(drop=True) + + prompt_variables = { + "entity_df": ent_df.to_csv(index_label="id"), + "relation_df": rela_df.to_csv(index_label="id") + } + text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables) + gen_conf = {"temperature": 0.5} + try: + response = self._llm.chat(text, [], gen_conf) + response = re.sub(r"^[^\{]*", "", response) + response = re.sub(r"[^\}]*$", "", response) + print(response) + response = json.loads(response) + 
if not dict_has_keys_with_types(response, [ + ("title", str), + ("summary", str), + ("findings", list), + ("rating", float), + ("rating_explanation", str), + ]): continue + response["weight"] = weight + response["entities"] = ents + except Exception as e: + print("ERROR: ", traceback.format_exc()) + self._on_error(e, traceback.format_exc(), None) + continue + + add_community_info2graph(graph, ents, response["title"]) + res_str.append(self._get_text_output(response)) + res_dict.append(response) + + return CommunityReportsResult( + structured_output=res_dict, + output=res_str, + ) + + def _get_text_output(self, parsed_output: dict) -> str: + title = parsed_output.get("title", "Report") + summary = parsed_output.get("summary", "") + findings = parsed_output.get("findings", []) + + def finding_summary(finding: dict): + if isinstance(finding, str): + return finding + return finding.get("summary") + + def finding_explanation(finding: dict): + if isinstance(finding, str): + return "" + return finding.get("explanation") + + report_sections = "\n\n".join( + f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings + ) + return f"# {title}\n\n{summary}\n\n{report_sections}" \ No newline at end of file diff --git a/graphrag/description_summary.py b/graphrag/description_summary.py new file mode 100644 index 0000000000000000000000000000000000000000..b49ecb022bb5ca3c1d763215e03f6f444d8bb95a --- /dev/null +++ b/graphrag/description_summary.py @@ -0,0 +1,167 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Reference: + - [graphrag](https://github.com/microsoft/graphrag) +""" + +import argparse +import html +import json +import logging +import numbers +import re +import traceback +from collections.abc import Callable +from dataclasses import dataclass + +from graphrag.utils import ErrorHandlerFn, perform_variable_replacements +from rag.llm.chat_model import Base as CompletionLLM +import networkx as nx + +from rag.utils import num_tokens_from_string + +SUMMARIZE_PROMPT = """ +You are a helpful assistant responsible for generating a comprehensive summary of the data provided below. +Given one or two entities, and a list of descriptions, all related to the same entity or group of entities. +Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions. +If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary. +Make sure it is written in third person, and include the entity names so we the have full context. 
+
+#######
+-Data-
+Entities: {entity_name}
+Description List: {description_list}
+#######
+Output:
+"""
+
+# Max token size for input prompts
+DEFAULT_MAX_INPUT_TOKENS = 4_000
+# Max token count for LLM answers
+DEFAULT_MAX_SUMMARY_LENGTH = 128
+
+
+@dataclass
+class SummarizationResult:
+    """Unipartite graph extraction result class definition."""
+
+    items: str | tuple[str, str]
+    description: str
+
+
+class SummarizeExtractor:
+    """Unipartite graph extractor class definition."""
+
+    _llm: CompletionLLM
+    _entity_name_key: str
+    _input_descriptions_key: str
+    _summarization_prompt: str
+    _on_error: ErrorHandlerFn
+    _max_summary_length: int
+    _max_input_tokens: int
+
+    def __init__(
+        self,
+        llm_invoker: CompletionLLM,
+        entity_name_key: str | None = None,
+        input_descriptions_key: str | None = None,
+        summarization_prompt: str | None = None,
+        on_error: ErrorHandlerFn | None = None,
+        max_summary_length: int | None = None,
+        max_input_tokens: int | None = None,
+    ):
+        """Init method definition."""
+        # TODO: streamline construction
+        self._llm = llm_invoker
+        self._entity_name_key = entity_name_key or "entity_name"
+        self._input_descriptions_key = input_descriptions_key or "description_list"
+
+        self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
+        self._on_error = on_error or (lambda _e, _s, _d: None)
+        self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
+        self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS
+
+    def __call__(
+        self,
+        items: str | tuple[str, str],
+        descriptions: list[str],
+    ) -> SummarizationResult:
+        """Call method definition."""
+        result = ""
+        if len(descriptions) == 0:
+            result = ""
+        elif len(descriptions) == 1:
+            result = descriptions[0]
+        else:
+            result = self._summarize_descriptions(items, descriptions)
+
+        return SummarizationResult(
+            items=items,
+            description=result or "",
+        )
+
+    def _summarize_descriptions(
+        self, items: str | tuple[str, str], descriptions: list[str]
+    ) -> str:
+        """Summarize descriptions into a single description."""
+        sorted_items = sorted(items) if isinstance(items, list) else items
+
+        # Safety check, should always be a list
+        if not isinstance(descriptions, list):
+            descriptions = [descriptions]
+
+        # Iterate over descriptions, adding all until the max input tokens is reached
+        usable_tokens = self._max_input_tokens - num_tokens_from_string(
+            self._summarization_prompt
+        )
+        descriptions_collected = []
+        result = ""
+
+        for i, description in enumerate(descriptions):
+            usable_tokens -= num_tokens_from_string(description)
+            descriptions_collected.append(description)
+
+            # If buffer is full, or all descriptions have been added, summarize
+            if (usable_tokens < 0 and len(descriptions_collected) > 1) or (
+                i == len(descriptions) - 1
+            ):
+                # Calculate result (final or partial)
+                result = self._summarize_descriptions_with_llm(
+                    sorted_items, descriptions_collected
+                )
+
+                # If we go for another loop, reset values to new
+                if i != len(descriptions) - 1:
+                    descriptions_collected = [result]
+                    usable_tokens = (
+                        self._max_input_tokens
+                        - num_tokens_from_string(self._summarization_prompt)
+                        - num_tokens_from_string(result)
+                    )
+
+        return result
+
+    def _summarize_descriptions_with_llm(
+        self, items: str | tuple[str, str] | list[str], descriptions: list[str]
+    ):
+        """Summarize descriptions using the LLM."""
+        variables = {
+            self._entity_name_key: json.dumps(items),
+            self._input_descriptions_key: json.dumps(sorted(descriptions)),
+        }
+        text = perform_variable_replacements(self._summarization_prompt, variables=variables)
+        return self._llm.chat("", [{"role": "user", "content": text}])
diff --git a/graphrag/entity_embedding.py b/graphrag/entity_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e247a8d186682b17275775f1ea71d3da6c0b278
--- /dev/null
+++ b/graphrag/entity_embedding.py
@@ -0,0 +1,78 @@
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Reference:
+ - [graphrag](https://github.com/microsoft/graphrag)
+"""
+
+from dataclasses import dataclass
+from typing import Any
+
+import numpy as np
+import networkx as nx
+import graspologic as gc
+from graphrag.leiden import stable_largest_connected_component
+
+
+@dataclass
+class NodeEmbeddings:
+    """Node embeddings class definition."""
+
+    nodes: list[str]
+    embeddings: np.ndarray
+
+
+def embed_nod2vec(
+    graph: nx.Graph | nx.DiGraph,
+    dimensions: int = 1536,
+    num_walks: int = 10,
+    walk_length: int = 40,
+    window_size: int = 2,
+    iterations: int = 3,
+    random_seed: int = 86,
+) -> NodeEmbeddings:
+    """Generate node embeddings using Node2Vec."""
+    # generate embedding
+    lcc_tensors = gc.embed.node2vec_embed(  # type: ignore
+        graph=graph,
+        dimensions=dimensions,
+        window_size=window_size,
+        iterations=iterations,
+        num_walks=num_walks,
+        walk_length=walk_length,
+        random_seed=random_seed,
+    )
+    return NodeEmbeddings(embeddings=lcc_tensors[0], nodes=lcc_tensors[1])
+
+
+def run(graph: nx.Graph, args: dict[str, Any]) -> dict[str, list[float]]:
+    """Run method definition."""
+    if args.get("use_lcc", True):
+        graph = stable_largest_connected_component(graph)
+
+    # create graph embedding using node2vec
+    embeddings = embed_nod2vec(
+        graph=graph,
+        dimensions=args.get("dimensions", 1536),
+        num_walks=args.get("num_walks", 10),
+        walk_length=args.get("walk_length", 40),
+        window_size=args.get("window_size", 2),
+        iterations=args.get("iterations", 3),
+        random_seed=args.get("random_seed", 86),
+    )
+
+    # Return a stable mapping from node name to its embedding vector
+    pairs = zip(embeddings.nodes, embeddings.embeddings.tolist(), strict=True)
+    sorted_pairs = sorted(pairs, key=lambda x: x[0])
+
+    return dict(sorted_pairs)
\ No newline at end of file
diff --git a/graphrag/entity_resolution.py b/graphrag/entity_resolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..909207682dfd0baba9b5ac94175bb177360f34a3
--- /dev/null
+++ b/graphrag/entity_resolution.py
@@ -0,0 +1,212 @@
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
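As a quick illustration of the embedding helper above, the following minimal sketch (not part of the patch) exercises graphrag.entity_embedding.run() on a toy graph; it assumes graspologic is installed and that the graphrag package imports cleanly, and the node names are only illustrative (borrowed from the plaza example used in the prompts):

import networkx as nx

from graphrag.entity_embedding import run

toy = nx.Graph()
toy.add_edge("VERDANT OASIS PLAZA", "HARMONY ASSEMBLY")
toy.add_edge("HARMONY ASSEMBLY", "UNITY MARCH")

# run() keeps the largest connected component by default and returns a dict
# mapping each node name to its embedding vector (a plain list of floats).
node_vectors = run(toy, {"dimensions": 64, "num_walks": 5})
for name, vector in node_vectors.items():
    print(name, len(vector))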
+# +import logging +import re +import traceback +from dataclasses import dataclass +from typing import Any + +import networkx as nx +from rag.nlp import is_english +import editdistance +from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT +from rag.llm.chat_model import Base as CompletionLLM +from graphrag.utils import ErrorHandlerFn, perform_variable_replacements + +DEFAULT_RECORD_DELIMITER = "##" +DEFAULT_ENTITY_INDEX_DELIMITER = "<|>" +DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&" + + +@dataclass +class EntityResolutionResult: + """Entity resolution result class definition.""" + + output: nx.Graph + + +class EntityResolution: + """Entity resolution class definition.""" + + _llm: CompletionLLM + _resolution_prompt: str + _output_formatter_prompt: str + _on_error: ErrorHandlerFn + _record_delimiter_key: str + _entity_index_delimiter_key: str + _resolution_result_delimiter_key: str + + def __init__( + self, + llm_invoker: CompletionLLM, + resolution_prompt: str | None = None, + on_error: ErrorHandlerFn | None = None, + record_delimiter_key: str | None = None, + entity_index_delimiter_key: str | None = None, + resolution_result_delimiter_key: str | None = None, + input_text_key: str | None = None + ): + """Init method definition.""" + self._llm = llm_invoker + self._resolution_prompt = resolution_prompt or ENTITY_RESOLUTION_PROMPT + self._on_error = on_error or (lambda _e, _s, _d: None) + self._record_delimiter_key = record_delimiter_key or "record_delimiter" + self._entity_index_dilimiter_key = entity_index_delimiter_key or "entity_index_delimiter" + self._resolution_result_delimiter_key = resolution_result_delimiter_key or "resolution_result_delimiter" + self._input_text_key = input_text_key or "input_text" + + def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult: + """Call method definition.""" + if prompt_variables is None: + prompt_variables = {} + + # Wire defaults into the prompt variables + prompt_variables = { + **prompt_variables, + self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key) + or DEFAULT_RECORD_DELIMITER, + self._entity_index_dilimiter_key: prompt_variables.get(self._entity_index_dilimiter_key) + or DEFAULT_ENTITY_INDEX_DELIMITER, + self._resolution_result_delimiter_key: prompt_variables.get(self._resolution_result_delimiter_key) + or DEFAULT_RESOLUTION_RESULT_DELIMITER, + } + + nodes = graph.nodes + entity_types = list(set(graph.nodes[node]['entity_type'] for node in nodes)) + node_clusters = {entity_type: [] for entity_type in entity_types} + + for node in nodes: + node_clusters[graph.nodes[node]['entity_type']].append(node) + + candidate_resolution = {entity_type: [] for entity_type in entity_types} + for node_cluster in node_clusters.items(): + candidate_resolution_tmp = [] + for a in node_cluster[1]: + for b in node_cluster[1]: + if a == b: + continue + if self.is_similarity(a, b) and (b, a) not in candidate_resolution_tmp: + candidate_resolution_tmp.append((a, b)) + if candidate_resolution_tmp: + candidate_resolution[node_cluster[0]] = candidate_resolution_tmp + + gen_conf = {"temperature": 0.5} + resolution_result = set() + for candidate_resolution_i in candidate_resolution.items(): + if candidate_resolution_i[1]: + try: + pair_txt = [ + f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n'] + for index, candidate in enumerate(candidate_resolution_i[1]): + pair_txt.append( + 
f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}') + sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions' + pair_txt.append( + f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)') + pair_prompt = '\n'.join(pair_txt) + + variables = { + **prompt_variables, + self._input_text_key: pair_prompt + } + text = perform_variable_replacements(self._resolution_prompt, variables=variables) + + response = self._llm.chat(text, [], gen_conf) + result = self._process_results(len(candidate_resolution_i[1]), response, + prompt_variables.get(self._record_delimiter_key, + DEFAULT_RECORD_DELIMITER), + prompt_variables.get(self._entity_index_dilimiter_key, + DEFAULT_ENTITY_INDEX_DELIMITER), + prompt_variables.get(self._resolution_result_delimiter_key, + DEFAULT_RESOLUTION_RESULT_DELIMITER)) + for result_i in result: + resolution_result.add(candidate_resolution_i[1][result_i[0] - 1]) + except Exception as e: + logging.exception("error entity resolution") + self._on_error(e, traceback.format_exc(), None) + + connect_graph = nx.Graph() + connect_graph.add_edges_from(resolution_result) + for sub_connect_graph in nx.connected_components(connect_graph): + sub_connect_graph = connect_graph.subgraph(sub_connect_graph) + remove_nodes = list(sub_connect_graph.nodes) + keep_node = remove_nodes.pop() + for remove_node in remove_nodes: + remove_node_neighbors = graph[remove_node] + graph.nodes[keep_node]['description'] += graph.nodes[remove_node]['description'] + graph.nodes[keep_node]['weight'] += graph.nodes[remove_node]['weight'] + remove_node_neighbors = list(remove_node_neighbors) + for remove_node_neighbor in remove_node_neighbors: + if remove_node_neighbor == keep_node: + graph.remove_edge(keep_node, remove_node) + continue + if graph.has_edge(keep_node, remove_node_neighbor): + graph[keep_node][remove_node_neighbor]['weight'] += graph[remove_node][remove_node_neighbor][ + 'weight'] + graph[keep_node][remove_node_neighbor]['description'] += \ + graph[remove_node][remove_node_neighbor]['description'] + graph.remove_edge(remove_node, remove_node_neighbor) + else: + graph.add_edge(keep_node, remove_node_neighbor, + weight=graph[remove_node][remove_node_neighbor]['weight'], + description=graph[remove_node][remove_node_neighbor]['description'], + source_id="") + graph.remove_edge(remove_node, remove_node_neighbor) + graph.remove_node(remove_node) + + for node_degree in graph.degree: + graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1]) + + return EntityResolutionResult( + output=graph, + ) + + def _process_results( + self, + records_length: int, + results: str, + record_delimiter: str, + entity_index_delimiter: str, + resolution_result_delimiter: str + ) -> list: + ans_list = [] + records = [r.strip() for r in results.split(record_delimiter)] + for record in records: + pattern_int = f"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}" + match_int = re.search(pattern_int, record) + res_int = int(str(match_int.group(1) if match_int else '0')) + if res_int > records_length: + continue + + pattern_bool = 
f"{re.escape(resolution_result_delimiter)}([a-zA-Z]+){re.escape(resolution_result_delimiter)}" + match_bool = re.search(pattern_bool, record) + res_bool = str(match_bool.group(1) if match_bool else '') + + if res_int and res_bool: + if res_bool.lower() == 'yes': + ans_list.append((res_int, "yes")) + + return ans_list + + def is_similarity(self, a, b): + if is_english(a) and is_english(b): + if editdistance.eval(a, b) <= min(len(a), len(b)) // 2: + return True + + if len(set(a) & set(b)) > 0: + return True + + return False diff --git a/graphrag/entity_resolution_prompt.py b/graphrag/entity_resolution_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a360dd5da1b1036a74541469fdbbe9c615670b --- /dev/null +++ b/graphrag/entity_resolution_prompt.py @@ -0,0 +1,74 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +ENTITY_RESOLUTION_PROMPT = """ +-Goal- +Please answer the following Question as required + +-Steps- +1. Identify each line of questioning as required + +2. Return output in English as a single list of each line answer in steps 1. Use **{record_delimiter}** as the list delimiter. + +###################### +-Examples- +###################### +Example 1: + +Question: +When determining whether two Products are the same, you should only focus on critical properties and overlook noisy factors. + +Demonstration 1: name of Product A is : "computer", name of Product B is :"phone" No, Product A and Product B are different products. +Question 1: name of Product A is : "television", name of Product B is :"TV" +Question 2: name of Product A is : "cup", name of Product B is :"mug" +Question 3: name of Product A is : "soccer", name of Product B is :"football" +Question 4: name of Product A is : "pen", name of Product B is :"eraser" + +Use domain knowledge of Products to help understand the text and answer the above 4 questions in the format: For Question i, Yes, Product A and Product B are the same product. or No, Product A and Product B are different products. 
For Question i+1, (repeat the above procedures) +################ +Output: +(For question {entity_index_delimiter}1{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, Product A and Product B are different products.){record_delimiter} +(For question {entity_index_delimiter}2{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, Product A and Product B are different products.){record_delimiter} +(For question {entity_index_delimiter}3{entity_index_delimiter}, {resolution_result_delimiter}yes{resolution_result_delimiter}, Product A and Product B are the same product.){record_delimiter} +(For question {entity_index_delimiter}4{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, Product A and Product B are different products.){record_delimiter} +############################# + +Example 2: + +Question: +When determining whether two toponym are the same, you should only focus on critical properties and overlook noisy factors. + +Demonstration 1: name of toponym A is : "nanjing", name of toponym B is :"nanjing city" No, toponym A and toponym B are same toponym. +Question 1: name of toponym A is : "Chicago", name of toponym B is :"ChiTown" +Question 2: name of toponym A is : "Shanghai", name of toponym B is :"Zhengzhou" +Question 3: name of toponym A is : "Beijing", name of toponym B is :"Peking" +Question 4: name of toponym A is : "Los Angeles", name of toponym B is :"Cleveland" + +Use domain knowledge of toponym to help understand the text and answer the above 4 questions in the format: For Question i, Yes, toponym A and toponym B are the same toponym. or No, toponym A and toponym B are different toponym. For Question i+1, (repeat the above procedures) +################ +Output: +(For question {entity_index_delimiter}1{entity_index_delimiter}, {resolution_result_delimiter}yes{resolution_result_delimiter}, toponym A and toponym B are same toponym.){record_delimiter} +(For question {entity_index_delimiter}2{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, toponym A and toponym B are different toponym.){record_delimiter} +(For question {entity_index_delimiter}3{entity_index_delimiter}, {resolution_result_delimiter}yes{resolution_result_delimiter}, toponym A and toponym B are the same toponym.){record_delimiter} +(For question {entity_index_delimiter}4{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, toponym A and toponym B are different toponym.){record_delimiter} +############################# + +-Real Data- +###################### +Question:{input_text} +###################### +Output: +""" diff --git a/graphrag/graph_extractor.py b/graphrag/graph_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..e3ffaf2cc02b9830e23d2851b6ae1414b033181e --- /dev/null +++ b/graphrag/graph_extractor.py @@ -0,0 +1,319 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
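To make the delimiter conventions concrete (record "##", question index "<|>", verdict "&&"), the standalone sketch below (not part of the patch) parses an answer written in the ENTITY_RESOLUTION_PROMPT output format the same way EntityResolution._process_results does:

import re

RECORD, INDEX, RESULT = "##", "<|>", "&&"
answer = (
    "(For question <|>1<|>, &&yes&&, toponym A and toponym B are same toponym.)##"
    "(For question <|>2<|>, &&no&&, toponym A and toponym B are different toponym.)"
)
for record in answer.split(RECORD):
    idx = re.search(rf"{re.escape(INDEX)}(\d+){re.escape(INDEX)}", record)
    verdict = re.search(rf"{re.escape(RESULT)}([a-zA-Z]+){re.escape(RESULT)}", record)
    if idx and verdict and verdict.group(1).lower() == "yes":
        print(f"Question {idx.group(1)}: merge this candidate pair")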
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Reference: + - [graphrag](https://github.com/microsoft/graphrag) +""" +import logging +import numbers +import re +import traceback +from dataclasses import dataclass +from typing import Any, Mapping +import tiktoken +from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT +from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str +from rag.llm.chat_model import Base as CompletionLLM +import networkx as nx +from rag.utils import num_tokens_from_string + +DEFAULT_TUPLE_DELIMITER = "<|>" +DEFAULT_RECORD_DELIMITER = "##" +DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>" +DEFAULT_ENTITY_TYPES = ["organization", "person", "location", "event", "time"] +ENTITY_EXTRACTION_MAX_GLEANINGS = 1 + + +@dataclass +class GraphExtractionResult: + """Unipartite graph extraction result class definition.""" + + output: nx.Graph + source_docs: dict[Any, Any] + + +class GraphExtractor: + """Unipartite graph extractor class definition.""" + + _llm: CompletionLLM + _join_descriptions: bool + _tuple_delimiter_key: str + _record_delimiter_key: str + _entity_types_key: str + _input_text_key: str + _completion_delimiter_key: str + _entity_name_key: str + _input_descriptions_key: str + _extraction_prompt: str + _summarization_prompt: str + _loop_args: dict[str, Any] + _max_gleanings: int + _on_error: ErrorHandlerFn + + def __init__( + self, + llm_invoker: CompletionLLM, + prompt: str | None = None, + tuple_delimiter_key: str | None = None, + record_delimiter_key: str | None = None, + input_text_key: str | None = None, + entity_types_key: str | None = None, + completion_delimiter_key: str | None = None, + join_descriptions=True, + encoding_model: str | None = None, + max_gleanings: int | None = None, + on_error: ErrorHandlerFn | None = None, + ): + """Init method definition.""" + # TODO: streamline construction + self._llm = llm_invoker + self._join_descriptions = join_descriptions + self._input_text_key = input_text_key or "input_text" + self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter" + self._record_delimiter_key = record_delimiter_key or "record_delimiter" + self._completion_delimiter_key = ( + completion_delimiter_key or "completion_delimiter" + ) + self._entity_types_key = entity_types_key or "entity_types" + self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT + self._max_gleanings = ( + max_gleanings + if max_gleanings is not None + else ENTITY_EXTRACTION_MAX_GLEANINGS + ) + self._on_error = on_error or (lambda _e, _s, _d: None) + self.prompt_token_count = num_tokens_from_string(self._extraction_prompt) + + # Construct the looping arguments + encoding = tiktoken.get_encoding(encoding_model or "cl100k_base") + yes = encoding.encode("YES") + no = encoding.encode("NO") + self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1} + + def __call__( + self, texts: list[str], prompt_variables: dict[str, Any] | None = None + ) -> GraphExtractionResult: + """Call method definition.""" + if prompt_variables is None: + prompt_variables = {} + all_records: dict[int, str] = {} + source_doc_map: dict[int, str] = {} + + # Wire defaults into the prompt variables + prompt_variables = { + **prompt_variables, + self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key) + or DEFAULT_TUPLE_DELIMITER, + self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key) + or DEFAULT_RECORD_DELIMITER, + 
self._completion_delimiter_key: prompt_variables.get( + self._completion_delimiter_key + ) + or DEFAULT_COMPLETION_DELIMITER, + self._entity_types_key: ",".join( + prompt_variables.get(self._entity_types_key) or DEFAULT_ENTITY_TYPES + ), + } + + for doc_index, text in enumerate(texts): + try: + # Invoke the entity extraction + result = self._process_document(text, prompt_variables) + source_doc_map[doc_index] = text + all_records[doc_index] = result + except Exception as e: + logging.exception("error extracting graph") + self._on_error( + e, + traceback.format_exc(), + { + "doc_index": doc_index, + "text": text, + }, + ) + + output = self._process_results( + all_records, + prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER), + prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER), + ) + + return GraphExtractionResult( + output=output, + source_docs=source_doc_map, + ) + + def _process_document( + self, text: str, prompt_variables: dict[str, str] + ) -> str: + variables = { + **prompt_variables, + self._input_text_key: text, + } + text = perform_variable_replacements(self._extraction_prompt, variables=variables) + gen_conf = {"temperature": 0.5} + response = self._llm.chat(text, [], gen_conf) + + results = response or "" + history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}] + + # Repeat to ensure we maximize entity count + for i in range(self._max_gleanings): + text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables) + history.append({"role": "user", "content": text}) + response = self._llm.chat("", history, gen_conf) + results += response or "" + + # if this is the final glean, don't bother updating the continuation flag + if i >= self._max_gleanings - 1: + break + history.append({"role": "assistant", "content": response}) + history.append({"role": "user", "content": LOOP_PROMPT}) + continuation = self._llm.chat("", history, self._loop_args) + if continuation != "YES": + break + + return results + + def _process_results( + self, + results: dict[int, str], + tuple_delimiter: str, + record_delimiter: str, + ) -> nx.Graph: + """Parse the result string to create an undirected unipartite graph. 
+ + Args: + - results - dict of results from the extraction chain + - tuple_delimiter - delimiter between tuples in an output record, default is '<|>' + - record_delimiter - delimiter between records, default is '##' + Returns: + - output - unipartite graph in graphML format + """ + graph = nx.Graph() + for source_doc_id, extracted_data in results.items(): + records = [r.strip() for r in extracted_data.split(record_delimiter)] + + for record in records: + record = re.sub(r"^\(|\)$", "", record.strip()) + record_attributes = record.split(tuple_delimiter) + + if record_attributes[0] == '"entity"' and len(record_attributes) >= 4: + # add this record as a node in the G + entity_name = clean_str(record_attributes[1].upper()) + entity_type = clean_str(record_attributes[2].upper()) + entity_description = clean_str(record_attributes[3]) + + if entity_name in graph.nodes(): + node = graph.nodes[entity_name] + if self._join_descriptions: + node["description"] = "\n".join( + list({ + *_unpack_descriptions(node), + entity_description, + }) + ) + else: + if len(entity_description) > len(node["description"]): + node["description"] = entity_description + node["source_id"] = ", ".join( + list({ + *_unpack_source_ids(node), + str(source_doc_id), + }) + ) + node["entity_type"] = ( + entity_type if entity_type != "" else node["entity_type"] + ) + else: + graph.add_node( + entity_name, + entity_type=entity_type, + description=entity_description, + source_id=str(source_doc_id), + weight=1 + ) + + if ( + record_attributes[0] == '"relationship"' + and len(record_attributes) >= 5 + ): + # add this record as edge + source = clean_str(record_attributes[1].upper()) + target = clean_str(record_attributes[2].upper()) + edge_description = clean_str(record_attributes[3]) + edge_source_id = clean_str(str(source_doc_id)) + weight = ( + float(record_attributes[-1]) + if isinstance(record_attributes[-1], numbers.Number) + else 1.0 + ) + if source not in graph.nodes(): + graph.add_node( + source, + entity_type="", + description="", + source_id=edge_source_id, + weight=1 + ) + if target not in graph.nodes(): + graph.add_node( + target, + entity_type="", + description="", + source_id=edge_source_id, + weight=1 + ) + if graph.has_edge(source, target): + edge_data = graph.get_edge_data(source, target) + if edge_data is not None: + weight += edge_data["weight"] + if self._join_descriptions: + edge_description = "\n".join( + list({ + *_unpack_descriptions(edge_data), + edge_description, + }) + ) + edge_source_id = ", ".join( + list({ + *_unpack_source_ids(edge_data), + str(source_doc_id), + }) + ) + graph.add_edge( + source, + target, + weight=weight, + description=edge_description, + source_id=edge_source_id, + ) + + for node_degree in graph.degree: + graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1]) + return graph + + +def _unpack_descriptions(data: Mapping) -> list[str]: + value = data.get("description", None) + return [] if value is None else value.split("\n") + + +def _unpack_source_ids(data: Mapping) -> list[str]: + value = data.get("source_id", None) + return [] if value is None else value.split(", ") + + + diff --git a/graphrag/graph_prompt.py b/graphrag/graph_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..11f3a6aee2625a12fe470eab9d5bc80419f73551 --- /dev/null +++ b/graphrag/graph_prompt.py @@ -0,0 +1,121 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. 
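For reference, the sketch below (not part of the patch) shows how a single record in the extraction output format is broken apart by GraphExtractor._process_results with the default delimiters; the record text itself is illustrative only:

import re

record = '("entity"<|>"VERDANT OASIS PLAZA"<|>"location"<|>"Verdant Oasis Plaza is the location of the Unity March")'
record = re.sub(r"^\(|\)$", "", record.strip())
attrs = record.split("<|>")
# attrs[0] is '"entity"' and len(attrs) >= 4, so _process_results would add a
# node named VERDANT OASIS PLAZA with type LOCATION and this description.
print(attrs)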
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Reference: + - [graphrag](https://github.com/microsoft/graphrag) +""" +GRAPH_EXTRACTION_PROMPT = """ +-Goal- +Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. + +-Steps- +1. Identify all entities. For each identified entity, extract the following information: +- entity_name: Name of the entity, capitalized +- entity_type: One of the following types: [{entity_types}] +- entity_description: Comprehensive description of the entity's attributes and activities +Format each entity as ("entity"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter} + +2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. +For each pair of related entities, extract the following information: +- source_entity: name of the source entity, as identified in step 1 +- target_entity: name of the target entity, as identified in step 1 +- relationship_description: explanation as to why you think the source entity and the target entity are related to each other +- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity + Format each relationship as ("relationship"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}) + +3. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter. + +4. When finished, output {completion_delimiter} + +###################### +-Examples- +###################### +Example 1: + +Entity_types: [person, technology, mission, organization, location] +Text: +while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order. + +Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.” + +The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce. + +It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. 
They had all been brought here by different paths +################ +Output: +("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter} +("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter} +("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter} +("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter} +("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter} +("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}7){record_delimiter} +("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}6){record_delimiter} +("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}8){record_delimiter} +("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}5){record_delimiter} +("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}9){completion_delimiter} +############################# +Example 2: + +Entity_types: [person, technology, mission, organization, location] +Text: +They were no longer mere operatives; they had become guardians of a threshold, keepers of a message from a realm beyond stars and stripes. This elevation in their mission could not be shackled by regulations and established protocols—it demanded a new perspective, a new resolve. + +Tension threaded through the dialogue of beeps and static as communications with Washington buzzed in the background. The team stood, a portentous air enveloping them. It was clear that the decisions they made in the ensuing hours could redefine humanity's place in the cosmos or condemn them to ignorance and potential peril. + +Their connection to the stars solidified, the group moved to address the crystallizing warning, shifting from passive recipients to active participants. Mercer's latter instincts gained precedence— the team's mandate had evolved, no longer solely to observe and report but to interact and prepare. 
A metamorphosis had begun, and Operation: Dulce hummed with the newfound frequency of their daring, a tone set not by the earthly +############# +Output: +("entity"{tuple_delimiter}"Washington"{tuple_delimiter}"location"{tuple_delimiter}"Washington is a location where communications are being received, indicating its importance in the decision-making process."){record_delimiter} +("entity"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"mission"{tuple_delimiter}"Operation: Dulce is described as a mission that has evolved to interact and prepare, indicating a significant shift in objectives and activities."){record_delimiter} +("entity"{tuple_delimiter}"The team"{tuple_delimiter}"organization"{tuple_delimiter}"The team is portrayed as a group of individuals who have transitioned from passive observers to active participants in a mission, showing a dynamic change in their role."){record_delimiter} +("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}7){record_delimiter} +("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}9){completion_delimiter} +############################# +Example 3: + +Entity_types: [person, role, technology, organization, event, location, concept] +Text: +their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data. + +"It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives talking to strangers' a whole new meaning." + +Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back." + +Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history. 
+ +The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation +############# +Output: +("entity"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"person"{tuple_delimiter}"Sam Rivera is a member of a team working on communicating with an unknown intelligence, showing a mix of awe and anxiety."){record_delimiter} +("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is the leader of a team attempting first contact with an unknown intelligence, acknowledging the significance of their task."){record_delimiter} +("entity"{tuple_delimiter}"Control"{tuple_delimiter}"concept"{tuple_delimiter}"Control refers to the ability to manage or govern, which is challenged by an intelligence that writes its own rules."){record_delimiter} +("entity"{tuple_delimiter}"Intelligence"{tuple_delimiter}"concept"{tuple_delimiter}"Intelligence here refers to an unknown entity capable of writing its own rules and learning to communicate."){record_delimiter} +("entity"{tuple_delimiter}"First Contact"{tuple_delimiter}"event"{tuple_delimiter}"First Contact is the potential initial communication between humanity and an unknown intelligence."){record_delimiter} +("entity"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"event"{tuple_delimiter}"Humanity's Response is the collective action taken by Alex's team in response to a message from an unknown intelligence."){record_delimiter} +("relationship"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"Intelligence"{tuple_delimiter}"Sam Rivera is directly involved in the process of learning to communicate with the unknown intelligence."{tuple_delimiter}9){record_delimiter} +("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"First Contact"{tuple_delimiter}"Alex leads the team that might be making the First Contact with the unknown intelligence."{tuple_delimiter}10){record_delimiter} +("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}8){record_delimiter} +("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}7){completion_delimiter} +############################# +-Real Data- +###################### +Entity_types: {entity_types} +Text: {input_text} +###################### +Output:""" + +CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format:\n" +LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.\n" \ No newline at end of file diff --git a/graphrag/index.py b/graphrag/index.py new file mode 100644 index 0000000000000000000000000000000000000000..a608347198f11de133f1f9c9cfbe6f37da9b41ff --- /dev/null +++ b/graphrag/index.py @@ -0,0 +1,160 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import re
+from concurrent.futures import ThreadPoolExecutor
+import json
+from functools import reduce
+from typing import List
+import networkx as nx
+from api.db import LLMType
+from api.db.services.llm_service import LLMBundle
+from graphrag.community_reports_extractor import CommunityReportsExtractor
+from graphrag.entity_resolution import EntityResolution
+from graphrag.graph_extractor import GraphExtractor
+from graphrag.mind_map_extractor import MindMapExtractor
+from rag.nlp import rag_tokenizer
+from rag.utils import num_tokens_from_string
+
+
+def be_children(obj: dict):
+    arr = []
+    for k, v in obj.items():
+        k = re.sub(r"\*+", "", k)
+        if not k: continue
+        arr.append({
+            "id": k,
+            "children": be_children(v) if isinstance(v, dict) else []
+        })
+    return arr
+
+
+def graph_merge(g1, g2):
+    g = g2.copy()
+    for n, attr in g1.nodes(data=True):
+        if n not in g2.nodes():
+            # add the node to the merged copy, not to the original g2
+            g.add_node(n, **attr)
+            continue
+
+        g.nodes[n]["weight"] += 1
+        if g.nodes[n]["description"].lower().find(attr["description"][:32].lower()) < 0:
+            g.nodes[n]["description"] += "\n" + attr["description"]
+
+    for source, target, attr in g1.edges(data=True):
+        if g.has_edge(source, target):
+            g[source][target].update({"weight": attr["weight"]+1})
+            continue
+        g.add_edge(source, target, **attr)
+
+    for node_degree in g.degree:
+        g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
+    return g
+
+
+def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, entity_types=["organization", "person", "location", "event", "time"]):
+    llm_bdl = LLMBundle(tenant_id, LLMType.CHAT)
+    ext = GraphExtractor(llm_bdl)
+    left_token_count = llm_bdl.max_length - ext.prompt_token_count - 1024
+    left_token_count = llm_bdl.max_length * 0.4
+
+    assert left_token_count > 0, f"The LLM context length({llm_bdl.max_length}) is smaller than prompt({ext.prompt_token_count})"
+
+    texts, graphs = [], []
+    cnt = 0
+    threads = []
+    exe = ThreadPoolExecutor(max_workers=12)
+    for i in range(len(chunks[:512])):
+        tkn_cnt = num_tokens_from_string(chunks[i])
+        if cnt+tkn_cnt >= left_token_count and texts:
+            threads.append(exe.submit(ext, texts, {"entity_types": entity_types}))
+            texts = []
+            cnt = 0
+        texts.append(chunks[i])
+        cnt += tkn_cnt
+    if texts:
+        threads.append(exe.submit(ext, texts, {"entity_types": entity_types}))
+
+    callback(0.5, "Extracting entities.")
+    graphs = []
+    for i, _ in enumerate(threads):
+        graphs.append(_.result().output)
+        callback(0.5 + 0.1*i/len(threads))
+
+    graph = reduce(graph_merge, graphs)
+    er = EntityResolution(llm_bdl)
+    graph = er(graph).output
+
+    _chunks = chunks
+    chunks = []
+    for n, attr in graph.nodes(data=True):
+        if attr.get("rank", 0) == 0:
+            print(f"Ignore entity: {n}")
+            continue
+        chunk = {
+            "name_kwd": n,
+            "important_kwd": [n],
+            "title_tks": rag_tokenizer.tokenize(n),
+            "content_with_weight": json.dumps({"name": n, **attr}, ensure_ascii=False),
+            "content_ltks": rag_tokenizer.tokenize(attr["description"]),
+            "knowledge_graph_kwd": "entity",
+            "rank_int": attr["rank"],
+            "weight_int": attr["weight"]
+        }
+        chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
+        chunks.append(chunk)
+
+    callback(0.6, "Extracting community reports.")
+    cr = CommunityReportsExtractor(llm_bdl)
+    cr = cr(graph)
+    for community, desc in zip(cr.structured_output, cr.output):
+        chunk = {
+            "title_tks": rag_tokenizer.tokenize(community["title"]),
+            "content_with_weight": desc,
+            "content_ltks":
rag_tokenizer.tokenize(desc), + "knowledge_graph_kwd": "community_report", + "weight_flt": community["weight"], + "entities_kwd": community["entities"], + "important_kwd": community["entities"] + } + chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"]) + chunks.append(chunk) + + chunks.append( + { + "content_with_weight": json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2), + "knowledge_graph_kwd": "graph" + }) + + callback(0.75, "Extracting mind graph.") + mindmap = MindMapExtractor(llm_bdl) + mg = mindmap(_chunks).output + if not len(mg.keys()): return chunks + + if len(mg.keys()) > 1: md_map = {"id": "root", "children": [{"id": re.sub(r"\*+", "", k), "children": be_children(v)} for k,v in mg.items() if isinstance(v, dict) and re.sub(r"\*+", "", k)]} + else: md_map = {"id": re.sub(r"\*+", "", list(mg.keys())[0]), "children": be_children(list(mg.items())[1])} + print(json.dumps(md_map, ensure_ascii=False, indent=2)) + chunks.append( + { + "content_with_weight": json.dumps(md_map, ensure_ascii=False, indent=2), + "knowledge_graph_kwd": "mind_map" + }) + + return chunks + + + + + + diff --git a/graphrag/leiden.py b/graphrag/leiden.py new file mode 100644 index 0000000000000000000000000000000000000000..945ef585e15522bb18100c1adf9766ca0b7ca839 --- /dev/null +++ b/graphrag/leiden.py @@ -0,0 +1,160 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
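To show the tree shape that ends up in the "mind_map" chunk, the short sketch below (not part of the patch) feeds a hypothetical mind-map dict through be_children() from graphrag/index.py above; importing graphrag.index assumes the surrounding RAGFlow environment is available:

from graphrag.index import be_children

# Hypothetical nested dict, shaped like MindMapExtractor output (illustrative only).
mind_map = {"Machine Learning": {"Supervised": {"SVM": {}}, "Unsupervised": {}}}

tree = [{"id": root, "children": be_children(children)} for root, children in mind_map.items()]
print(tree)
# [{'id': 'Machine Learning', 'children': [
#     {'id': 'Supervised', 'children': [{'id': 'SVM', 'children': []}]},
#     {'id': 'Unsupervised', 'children': []}]}]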
+# +""" +Reference: + - [graphrag](https://github.com/microsoft/graphrag) +""" + +import logging +from typing import Any, cast, List +import html +from graspologic.partition import hierarchical_leiden +from graspologic.utils import largest_connected_component + +import networkx as nx + +log = logging.getLogger(__name__) + + +def _stabilize_graph(graph: nx.Graph) -> nx.Graph: + """Ensure an undirected graph with the same relationships will always be read the same way.""" + fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph() + + sorted_nodes = graph.nodes(data=True) + sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0]) + + fixed_graph.add_nodes_from(sorted_nodes) + edges = list(graph.edges(data=True)) + + # If the graph is undirected, we create the edges in a stable way, so we get the same results + # for example: + # A -> B + # in graph theory is the same as + # B -> A + # in an undirected graph + # however, this can lead to downstream issues because sometimes + # consumers read graph.nodes() which ends up being [A, B] and sometimes it's [B, A] + # but they base some of their logic on the order of the nodes, so the order ends up being important + # so we sort the nodes in the edge in a stable way, so that we always get the same order + if not graph.is_directed(): + + def _sort_source_target(edge): + source, target, edge_data = edge + if source > target: + temp = source + source = target + target = temp + return source, target, edge_data + + edges = [_sort_source_target(edge) for edge in edges] + + def _get_edge_key(source: Any, target: Any) -> str: + return f"{source} -> {target}" + + edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1])) + + fixed_graph.add_edges_from(edges) + return fixed_graph + + +def normalize_node_names(graph: nx.Graph | nx.DiGraph) -> nx.Graph | nx.DiGraph: + """Normalize node names.""" + node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore + return nx.relabel_nodes(graph, node_mapping) + + +def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: + """Return the largest connected component of the graph, with nodes and edges sorted in a stable way.""" + graph = graph.copy() + graph = cast(nx.Graph, largest_connected_component(graph)) + graph = normalize_node_names(graph) + return _stabilize_graph(graph) + + +def _compute_leiden_communities( + graph: nx.Graph | nx.DiGraph, + max_cluster_size: int, + use_lcc: bool, + seed=0xDEADBEEF, +) -> dict[int, dict[str, int]]: + """Return Leiden root communities.""" + if use_lcc: + graph = stable_largest_connected_component(graph) + + community_mapping = hierarchical_leiden( + graph, max_cluster_size=max_cluster_size, random_seed=seed + ) + results: dict[int, dict[str, int]] = {} + for partition in community_mapping: + results[partition.level] = results.get(partition.level, {}) + results[partition.level][partition.node] = partition.cluster + + return results + + +def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]: + """Run method definition.""" + max_cluster_size = args.get("max_cluster_size", 12) + use_lcc = args.get("use_lcc", True) + if args.get("verbose", False): + log.info( + "Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc + ) + if not graph.nodes(): return {} + + node_id_to_community_map = _compute_leiden_communities( + graph=graph, + max_cluster_size=max_cluster_size, + use_lcc=use_lcc, + seed=args.get("seed", 0xDEADBEEF), + ) + levels = args.get("levels") + + # If they don't pass in 
levels, use them all + if levels is None: + levels = sorted(node_id_to_community_map.keys()) + + results_by_level: dict[int, dict[str, dict]] = {} + for level in levels: + result = {} + results_by_level[level] = result + for node_id, raw_community_id in node_id_to_community_map[level].items(): + community_id = str(raw_community_id) + if community_id not in result: + result[community_id] = {"weight": 0, "nodes": []} + result[community_id]["nodes"].append(node_id) + result[community_id]["weight"] += graph.nodes[node_id].get("rank", 0) * graph.nodes[node_id].get("weight", 1) + weights = [comm["weight"] for _, comm in result.items()] + if not weights: continue + max_weight = max(weights) + for _, comm in result.items(): comm["weight"] /= max_weight + + return results_by_level + + +def add_community_levels2graph(graph: nx.Graph, commu_info: dict[int, dict[str, dict]]):  # attach the per-level community id of every node, as produced by run() + for lev, cluster_info in commu_info.items(): + for cid, nodes in cluster_info.items(): + for n in nodes["nodes"]: + if "community" not in graph.nodes[n]: graph.nodes[n]["community"] = {} + graph.nodes[n]["community"].update({lev: cid}) + + +def add_community_info2graph(graph: nx.Graph, nodes: List[str], community_title): + for n in nodes: + if "communities" not in graph.nodes[n]: + graph.nodes[n]["communities"] = [] + graph.nodes[n]["communities"].append(community_title) diff --git a/graphrag/mind_map_extractor.py b/graphrag/mind_map_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..e4daae5ebd6b4cfece7d46bac11a05bd97147cd6 --- /dev/null +++ b/graphrag/mind_map_extractor.py @@ -0,0 +1,137 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
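A minimal usage sketch for leiden.run() above (not part of the patch). It assumes graspologic is installed and sets the optional "rank"/"weight" node attributes so per-community weights are non-zero before normalization:

import networkx as nx

from graphrag import leiden

g = nx.Graph()
g.add_edges_from([("A", "B"), ("B", "C"), ("C", "A"), ("C", "D"), ("D", "E")])
for n in g.nodes:
    g.nodes[n]["rank"] = 1      # used for community weighting; defaults to 0
    g.nodes[n]["weight"] = 1    # defaults to 1

# communities[level][community_id] == {"weight": <normalized>, "nodes": [node, ...]}
communities = leiden.run(g, {"max_cluster_size": 12, "use_lcc": True})
for level, clusters in communities.items():
    for cid, info in clusters.items():
        print(level, cid, info["nodes"], round(info["weight"], 3))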
+# + +import logging +import traceback +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import Any + +from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT +from graphrag.utils import ErrorHandlerFn, perform_variable_replacements +from rag.llm.chat_model import Base as CompletionLLM +import markdown_to_json +from functools import reduce +from rag.utils import num_tokens_from_string + + +@dataclass +class MindMapResult: + """Unipartite Mind Graph result class definition.""" + output: dict + + +class MindMapExtractor: + + _llm: CompletionLLM + _input_text_key: str + _mind_map_prompt: str + _on_error: ErrorHandlerFn + + def __init__( + self, + llm_invoker: CompletionLLM, + prompt: str | None = None, + input_text_key: str | None = None, + on_error: ErrorHandlerFn | None = None, + ): + """Init method definition.""" + # TODO: streamline construction + self._llm = llm_invoker + self._input_text_key = input_text_key or "input_text" + self._mind_map_prompt = prompt or MIND_MAP_EXTRACTION_PROMPT + self._on_error = on_error or (lambda _e, _s, _d: None) + + def __call__( + self, sections: list[str], prompt_variables: dict[str, Any] | None = None + ) -> MindMapResult: + """Call method definition.""" + merge_json = {}  # keep a defined result even if extraction fails below + if prompt_variables is None: + prompt_variables = {} + + try: + exe = ThreadPoolExecutor(max_workers=12) + threads = [] + token_count = self._llm.max_length * 0.7 + texts = [] + res = [] + cnt = 0 + for i in range(len(sections)): + section_cnt = num_tokens_from_string(sections[i]) + if cnt + section_cnt >= token_count and texts: + threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables)) + texts = [] + cnt = 0 + texts.append(sections[i]) + cnt += section_cnt + if texts: + threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables)) + + for i, _ in enumerate(threads): + res.append(_.result()) + + merge_json = reduce(self._merge, res) + merge_json = self._list_to_kv(merge_json) + except Exception as e: + logging.exception("error mind graph") + self._on_error( + e, + traceback.format_exc(), None + ) + + return MindMapResult(output=merge_json) + + def _merge(self, d1, d2): + for k in d1: + if k in d2: + if isinstance(d1[k], dict) and isinstance(d2[k], dict): + self._merge(d1[k], d2[k]) + elif isinstance(d1[k], list) and isinstance(d2[k], list): + d2[k].extend(d1[k]) + else: + d2[k] = d1[k] + else: + d2[k] = d1[k] + + return d2 + + def _list_to_kv(self, data): + for key, value in data.items(): + if isinstance(value, dict): + self._list_to_kv(value) + elif isinstance(value, list): + new_value = {} + for i in range(len(value)): + if isinstance(value[i], list): + new_value[value[i - 1]] = value[i][0] + data[key] = new_value + else: + continue + return data + + def _process_document( + self, text: str, prompt_variables: dict[str, str] + ) -> dict: + variables = { + **prompt_variables, + self._input_text_key: text, + } + text = perform_variable_replacements(self._mind_map_prompt, variables=variables) + gen_conf = {"temperature": 0.5} + response = self._llm.chat(text, [], gen_conf) + logging.debug(response) + logging.debug(markdown_to_json.dictify(response)) + return dict(markdown_to_json.dictify(response)) diff --git a/graphrag/mind_map_prompt.py b/graphrag/mind_map_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..fac8f8beac6a79e4854c823495259bf1af5acf1c --- /dev/null +++ b/graphrag/mind_map_prompt.py @@ -0,0 +1,42 @@ +# +# Copyright
2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +MIND_MAP_EXTRACTION_PROMPT = """ + - Role: You're a talented text processor. + + - Steps of task: + 1. Generate a title for the user's 'TEXT'. + 2. Classify the 'TEXT' into sections as you see fit. + 3. If the subject matter is really complex, split it into sub-sections. + + - Output requirement: + - Always try to maximize the number of sub-sections. + - In the same language as the 'TEXT'. + - MUST BE IN MARKDOWN FORMAT + +Output: +## + <Section Name> + <Section Name> + <Subsection Name> + <Subsection Name> + <Section Name> + <Subsection Name> + +-TEXT- +{input_text} + +Output: +""" \ No newline at end of file diff --git a/graphrag/search.py b/graphrag/search.py new file mode 100644 index 0000000000000000000000000000000000000000..9a95a20a49a475eb713e152cb55a08aa9ac4a043 --- /dev/null +++ b/graphrag/search.py @@ -0,0 +1,109 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
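A usage sketch for MindMapExtractor with the prompt above (not part of the patch); the LLMBundle wiring mirrors graphrag/smoke.py further down, and the tenant id is a placeholder:

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from graphrag.mind_map_extractor import MindMapExtractor

llm_bdl = LLMBundle("<tenant_id>", LLMType.CHAT)
sections = ["Chapter 1: ...", "Chapter 2: ..."]

# Returns a nested dict parsed from the model's markdown outline, roughly
# {"<Title>": {"<Section>": {"<Subsection>": "...", ...}, ...}}
mind_map = MindMapExtractor(llm_bdl)(sections).output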
+# +import json +from copy import deepcopy + +import pandas as pd +from elasticsearch_dsl import Q, Search + +from rag.nlp.search import Dealer + + +class KGSearch(Dealer): + def search(self, req, idxnm, emb_mdl=None): + def merge_into_first(sres, title=""): + df,texts = [],[] + for d in sres["hits"]["hits"]: + try: + df.append(json.loads(d["_source"]["content_with_weight"])) + except Exception as e: + texts.append(d["_source"]["content_with_weight"]) + pass + if not df and not texts: return False + if df: + try: + sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + pd.DataFrame(df).to_csv() + except Exception as e: + pass + else: + sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + "\n".join(texts) + return True + + src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", + "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "name_kwd", + "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight", + "weight_int", "weight_flt", "rank_int" + ]) + + qst = req.get("question", "") + binary_query, keywords = self.qryr.question(qst, min_match="5%") + binary_query = self._add_filters(binary_query, req) + + ## Entity retrieval + bqry = deepcopy(binary_query) + bqry.filter.append(Q("terms", knowledge_graph_kwd=["entity"])) + s = Search() + s = s.query(bqry)[0: 32] + + s = s.to_dict() + q_vec = [] + if req.get("vector"): + assert emb_mdl, "No embedding model selected" + s["knn"] = self._vector( + qst, emb_mdl, req.get( + "similarity", 0.1), 1024) + s["knn"]["filter"] = bqry.to_dict() + q_vec = s["knn"]["query_vector"] + + ent_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src) + entities = [d["name_kwd"] for d in self.es.getSource(ent_res)] + ent_ids = self.es.getDocIds(ent_res) + if merge_into_first(ent_res, "-Entities-"): + ent_ids = ent_ids[0:1] + + ## Community retrieval + bqry = deepcopy(binary_query) + bqry.filter.append(Q("terms", entities_kwd=entities)) + bqry.filter.append(Q("terms", knowledge_graph_kwd=["community_report"])) + s = Search() + s = s.query(bqry)[0: 32] + s = s.to_dict() + comm_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src) + comm_ids = self.es.getDocIds(comm_res) + if merge_into_first(comm_res, "-Community Report-"): + comm_ids = comm_ids[0:1] + + ## Text content retrieval + bqry = deepcopy(binary_query) + bqry.filter.append(Q("terms", knowledge_graph_kwd=["text"])) + s = Search() + s = s.query(bqry)[0: 6] + s = s.to_dict() + txt_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src) + txt_ids = self.es.getDocIds(txt_res) + if merge_into_first(txt_res, "-Original Content-"): + txt_ids = txt_ids[0:1] + + return self.SearchResult( + total=len(ent_ids) + len(comm_ids) + len(txt_ids), + ids=[*ent_ids, *comm_ids, *txt_ids], + query_vector=q_vec, + aggregation=None, + highlight=None, + field={**self.getFields(ent_res, src), **self.getFields(comm_res, src), **self.getFields(txt_res, src)}, + keywords=[] + ) + diff --git a/graphrag/smoke.py b/graphrag/smoke.py new file mode 100644 index 0000000000000000000000000000000000000000..b2efbc91fd22a8b80ade600fcba6e4ffbb21b422 --- /dev/null +++ b/graphrag/smoke.py @@ -0,0 +1,52 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
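A sketch of how the KGSearch retriever above would be called (not part of the patch). Its construction is assumed to mirror rag.nlp.search.Dealer; the ES connection, knowledge-base id and index name below are placeholders, not values defined in this diff:

from graphrag.search import KGSearch

kg_retriever = KGSearch(es_connection)    # hypothetical: same ES dependency Dealer is built with
req = {
    "question": "Who acquired company X?",
    "kb_ids": ["<kb_id>"],
    # "vector": True, "similarity": 0.1,  # enable KNN only when an emb_mdl is passed
}
res = kg_retriever.search(req, idxnm="<index_name>", emb_mdl=None)
# res.ids holds one merged "-Entities-" hit, one merged "-Community Report-" hit
# and the top original-text hits; res.field maps each id to its stored fields.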
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse +import json +from graphrag import leiden +from graphrag.community_reports_extractor import CommunityReportsExtractor +from graphrag.entity_resolution import EntityResolution +from graphrag.graph_extractor import GraphExtractor +from graphrag.leiden import add_community_levels2graph + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True) + parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True) + args = parser.parse_args() + + from api.db import LLMType + from api.db.services.llm_service import LLMBundle + from api.settings import retrievaler + + ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT)) + docs = [d["content_with_weight"] for d in + retrievaler.chunk_list(args.doc_id, args.tenant_id, max_count=6, fields=["content_with_weight"])] + graph = ex(docs) + + er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT)) + graph = er(graph.output) + + comm = leiden.run(graph.output, {}) + add_community_levels2graph(graph.output, comm) + + # print(json.dumps(nx.node_link_data(graph.output), ensure_ascii=False,indent=2)) + print(json.dumps(comm, ensure_ascii=False, indent=2)) + + cr = CommunityReportsExtractor(LLMBundle(args.tenant_id, LLMType.CHAT)) + cr = cr(graph.output) + print("------------------ COMMUNITY REPORT ----------------------\n", cr.output) + print(json.dumps(cr.structured_output, ensure_ascii=False, indent=2)) \ No newline at end of file diff --git a/graphrag/utils.py b/graphrag/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a4524362cd0529385a86838ca0af571d9b3b88c2 --- /dev/null +++ b/graphrag/utils.py @@ -0,0 +1,74 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
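The smoke script above gives an end-to-end check once a document has been parsed into chunks. A typical invocation (ids are placeholders; running it as a module assumes the repository root is on PYTHONPATH):

    python -m graphrag.smoke -t <tenant_id> -d <doc_id>

It pulls up to six chunks of the document, extracts a graph, resolves duplicate entities, runs Leiden clustering and prints the resulting communities and community report.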
+# +""" +Reference: + - [graphrag](https://github.com/microsoft/graphrag) +""" + +import html +import re +from collections.abc import Callable +from typing import Any + +ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None] + + +def perform_variable_replacements( + input: str, history: list[dict]=[], variables: dict | None ={} +) -> str: + """Perform variable replacements on the input string and in a chat log.""" + result = input + + def replace_all(input: str) -> str: + result = input + if variables: + for entry in variables: + result = result.replace(f"{{{entry}}}", variables[entry]) + return result + + result = replace_all(result) + for i in range(len(history)): + entry = history[i] + if entry.get("role") == "system": + history[i]["content"] = replace_all(entry.get("content") or "") + + return result + + +def clean_str(input: Any) -> str: + """Clean an input string by removing HTML escapes, control characters, and other unwanted characters.""" + # If we get non-string input, just give it back + if not isinstance(input, str): + return input + + result = html.unescape(input.strip()) + # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python + return re.sub(r"[\"\x00-\x1f\x7f-\x9f]", "", result) + + +def dict_has_keys_with_types( + data: dict, expected_fields: list[tuple[str, type]] +) -> bool: + """Return True if the given dictionary has the given keys with the given types.""" + for field, field_type in expected_fields: + if field not in data: + return False + + value = data[field] + if not isinstance(value, field_type): + return False + return True + diff --git a/rag/app/knowledge_graph.py b/rag/app/knowledge_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..a8775f9cc89435ddcf39fd2f2762800ed71c4992 --- /dev/null +++ b/rag/app/knowledge_graph.py @@ -0,0 +1,30 @@ +import re + +from graphrag.index import build_knowlege_graph_chunks +from rag.app import naive +from rag.nlp import rag_tokenizer, tokenize_chunks + + +def chunk(filename, binary, tenant_id, from_page=0, to_page=100000, + lang="Chinese", callback=None, **kwargs): + parser_config = kwargs.get( + "parser_config", { + "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": False}) + eng = lang.lower() == "english" + + parser_config["layout_recognize"] = False + sections = naive.chunk(filename, binary, from_page=from_page, to_page=to_page, section_only=True, parser_config=parser_config) + chunks = build_knowlege_graph_chunks(tenant_id, sections, callback, + parser_config.get("entity_types", ["organization", "person", "location", "event", "time"]) + ) + for c in chunks: c["docnm_kwd"] = filename + + doc = { + "docnm_kwd": filename, + "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)), + "knowledge_graph_kwd": "text" + } + doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"]) + chunks.extend(tokenize_chunks(sections, doc, eng)) + + return chunks \ No newline at end of file diff --git a/rag/app/naive.py b/rag/app/naive.py index a2ba6993c8e8b7a3a629a143f093aff4ad073d38..6c39954c594fd8030255d39e2e1e8ea556491c68 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -273,6 +273,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, raise NotImplementedError( "file type not supported yet(pdf, xlsx, doc, docx, txt supported)") + if kwargs.get("section_only", False): + return [t for t, _ in sections] + st = timer() chunks = naive_merge( sections, int(parser_config.get( diff 
--git a/rag/nlp/__init__.py b/rag/nlp/__init__.py index cd2a8e86ce847a7c7d6267f748043cd02c1f6e13..af954f9627e791d0781b115b9e5115d5d1a5b8b0 100644 --- a/rag/nlp/__init__.py +++ b/rag/nlp/__init__.py @@ -228,7 +228,7 @@ def tokenize(d, t, eng): d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) -def tokenize_chunks(chunks, doc, eng, pdf_parser): +def tokenize_chunks(chunks, doc, eng, pdf_parser=None): res = [] # wrap up as es documents for ck in chunks: diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 7720330bef2ce2ce8d80e34bb928167ff1aaf58c..981e81f9e206d646dad0c96ce3418a02f646aa2f 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -64,24 +64,25 @@ class Dealer: "query_vector": [float(v) for v in qv] } + def _add_filters(self, bqry, req): + if req.get("kb_ids"): + bqry.filter.append(Q("terms", kb_id=req["kb_ids"])) + if req.get("doc_ids"): + bqry.filter.append(Q("terms", doc_id=req["doc_ids"])) + if req.get("knowledge_graph_kwd"): + bqry.filter.append(Q("terms", knowledge_graph_kwd=req["knowledge_graph_kwd"])) + if "available_int" in req: + if req["available_int"] == 0: + bqry.filter.append(Q("range", available_int={"lt": 1})) + else: + bqry.filter.append( + Q("bool", must_not=Q("range", available_int={"lt": 1}))) + return bqry + def search(self, req, idxnm, emb_mdl=None): qst = req.get("question", "") bqry, keywords = self.qryr.question(qst) - def add_filters(bqry): - nonlocal req - if req.get("kb_ids"): - bqry.filter.append(Q("terms", kb_id=req["kb_ids"])) - if req.get("doc_ids"): - bqry.filter.append(Q("terms", doc_id=req["doc_ids"])) - if "available_int" in req: - if req["available_int"] == 0: - bqry.filter.append(Q("range", available_int={"lt": 1})) - else: - bqry.filter.append( - Q("bool", must_not=Q("range", available_int={"lt": 1}))) - return bqry - - bqry = add_filters(bqry) + bqry = self._add_filters(bqry, req) bqry.boost = 0.05 s = Search() @@ -89,7 +90,7 @@ class Dealer: topk = int(req.get("topk", 1024)) ps = int(req.get("size", topk)) src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", - "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", + "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "knowledge_graph_kwd", "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"]) s = s.query(bqry)[pg * ps:(pg + 1) * ps] @@ -137,7 +138,7 @@ class Dealer: es_logger.info("TOTAL: {}".format(self.es.getTotal(res))) if self.es.getTotal(res) == 0 and "knn" in s: bqry, _ = self.qryr.question(qst, min_match="10%") - bqry = add_filters(bqry) + bqry = self._add_filters(bqry, req) s["query"] = bqry.to_dict() s["knn"]["filter"] = bqry.to_dict() s["knn"]["similarity"] = 0.17 diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 323b689a1e5be346b141f828d853a16b640e8fa2..e391d3a5837de6c42810c587e9cab88067ea2337 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -45,7 +45,7 @@ from rag.nlp import search, rag_tokenizer from io import BytesIO import pandas as pd -from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio +from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph from api.db import LLMType, ParserType from api.db.services.document_service import DocumentService @@ -68,7 +68,8 @@ FACTORY = { ParserType.RESUME.value: resume, ParserType.PICTURE.value: picture, ParserType.ONE.value: one, - ParserType.AUDIO.value: audio +
ParserType.AUDIO.value: audio, + ParserType.KG.value: knowledge_graph }
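With ParserType.KG registered in FACTORY, the task executor routes documents to the new chunker; it can also be driven directly. A minimal sketch (file name, tenant id and progress callback are placeholders; parser_config falls back to the defaults shown in rag/app/knowledge_graph.py):

from rag.app import knowledge_graph

def progress(prog=None, msg=""):
    # Mirrors the callback(percentage, message) convention used by the chunkers.
    print(prog, msg)

with open("report.txt", "rb") as f:
    chunks = knowledge_graph.chunk(
        "report.txt", f.read(), "<tenant_id>",
        lang="English",
        callback=progress,
        parser_config={"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": False},
    )

# The result mixes knowledge-graph chunks ("entity", "community_report", "graph",
# "mind_map") with ordinary "text" chunks for the same document.
print(len(chunks), {c.get("knowledge_graph_kwd") for c in chunks})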