File size: 5,190 Bytes
a7642c6
 
6054f54
 
 
 
 
 
 
 
 
22fe41e
6054f54
 
 
 
 
 
 
6d4f792
 
22fe41e
6054f54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d4f792
6054f54
6d4f792
6054f54
 
 
6d4f792
 
6054f54
 
 
 
 
 
 
 
 
 
 
 
6d4f792
6054f54
eb6e194
6d4f792
6054f54
 
bf4c34e
 
22fe41e
6054f54
 
 
 
 
 
 
 
 
 
 
22fe41e
6054f54
 
 
 
 
 
6d4f792
 
6054f54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb6e194
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
 - [graphrag](https://github.com/microsoft/graphrag)
"""

import json
import re
import traceback
from dataclasses import dataclass
from typing import List, Callable
import networkx as nx
import pandas as pd
from graphrag import leiden
from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.leiden import add_community_info2graph
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
from rag.utils import num_tokens_from_string
from timeit import default_timer as timer
from api.utils.log_utils import logger


@dataclass
class CommunityReportsResult:
    """Reports produced by ``CommunityReportsExtractor``, in two parallel forms.

    The two lists are index-aligned: ``output[i]`` is the Markdown-style text
    rendering of the parsed report ``structured_output[i]``.
    """

    # Rendered text ("# title", summary, "## finding" sections) of each report.
    output: list[str]
    # Parsed JSON report dict for each community that produced a valid report.
    structured_output: list[dict]


class CommunityReportsExtractor:
    """Write LLM-generated reports for the communities detected in a graph.

    ``__call__`` partitions the graph with ``leiden.run``, renders each
    community's entities and incident relations as CSV tables, substitutes
    them into the extraction prompt, asks the chat model for a JSON report,
    validates it, and collects both the parsed report and a text rendering.
    """

    _llm: CompletionLLM
    _extraction_prompt: str
    _output_formatter_prompt: str  # declared only; never assigned in this class
    _on_error: ErrorHandlerFn
    _max_report_length: int  # stored by __init__ but not read in this class

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        extraction_prompt: str | None = None,
        on_error: ErrorHandlerFn | None = None,
        max_report_length: int | None = None,
    ):
        """Store the chat model and extraction settings.

        Args:
            llm_invoker: Chat model; ``chat(prompt, history, gen_conf)`` is
                called once per community.
            extraction_prompt: Prompt template; defaults to
                ``COMMUNITY_REPORT_PROMPT``.
            on_error: Called as ``on_error(exception, traceback_text, None)``
                when producing a report raises; defaults to a no-op.
            max_report_length: Stored as ``_max_report_length``; defaults to 1500.
        """
        self._llm = llm_invoker
        self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)
        self._max_report_length = max_report_length or 1500

    def __call__(self, graph: nx.Graph, callback: Callable | None = None):
        """Generate one report per Leiden community of ``graph``.

        Args:
            graph: Entity graph. Node attributes and edge attributes are
                copied into the CSV tables given to the model.
            callback: Optional progress hook, called as ``callback(msg=...)``
                after each report that is successfully produced.

        Returns:
            CommunityReportsResult whose ``output`` (rendered text) and
            ``structured_output`` (parsed report dicts, with ``weight`` and
            ``entities`` added) are index-aligned.

        Communities whose model call or JSON parsing raises are passed to
        ``on_error`` and skipped. Communities whose report lacks a required
        field are logged and skipped.
        """
        communities: dict[str, dict[str, List]] = leiden.run(graph, {})
        total = sum(len(comm) for comm in communities.values())
        relations_df = pd.DataFrame(
            [{"source": s, "target": t, **attr} for s, t, attr in graph.edges(data=True)]
        )
        if relations_df.empty:
            # An edgeless graph yields a frame with no columns at all; give it the
            # columns the per-community filter indexes on so it does not KeyError.
            relations_df = pd.DataFrame(columns=["source", "target"])
        res_str = []
        res_dict = []
        over, token_count = 0, 0
        st = timer()
        for level, comm in communities.items():
            for cm_id, community in comm.items():
                weight = community["weight"]
                ents = community["nodes"]
                ent_df = pd.DataFrame([{"entity": n, **graph.nodes[n]} for n in ents])
                # Relations with at least one endpoint inside this community.
                rela_df = relations_df[
                    relations_df["source"].isin(ents) | relations_df["target"].isin(ents)
                ].reset_index(drop=True)

                prompt_variables = {
                    "entity_df": ent_df.to_csv(index_label="id"),
                    "relation_df": rela_df.to_csv(index_label="id"),
                }
                text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
                # Built per call rather than hoisted: the model's chat() may mutate it.
                gen_conf = {"temperature": 0.3}
                try:
                    response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
                    token_count += num_tokens_from_string(text + response)
                    response = self._extract_json_text(response)
                    logger.info(response)
                    response = json.loads(response)
                    self._normalize_rating(response)
                    if not dict_has_keys_with_types(response, [
                        ("title", str),
                        ("summary", str),
                        ("findings", list),
                        ("rating", float),
                        ("rating_explanation", str),
                    ]):
                        logger.warning(
                            f"CommunityReportsExtractor: report for community {cm_id} "
                            f"(level {level}) lacks required fields, skipped"
                        )
                        continue
                    response["weight"] = weight
                    response["entities"] = ents
                except Exception as e:
                    logger.exception("CommunityReportsExtractor got exception")
                    self._on_error(e, traceback.format_exc(), None)
                    continue

                add_community_info2graph(graph, ents, response["title"])
                res_str.append(self._get_text_output(response))
                res_dict.append(response)
                over += 1
                if callback:
                    callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")

        return CommunityReportsResult(
            structured_output=res_dict,
            output=res_str,
        )

    @staticmethod
    def _extract_json_text(raw: str) -> str:
        """Trim model output to its ``{...}`` span and collapse doubled braces.

        Drops everything before the first ``{`` and after the last ``}``, then
        replaces every ``{{`` with ``{`` and every ``}}`` with ``}``.
        NOTE(review): the ``}}`` collapse also merges legitimate nested closers,
        e.g. ``{"a": {"b": 1}}`` becomes invalid JSON; confirm this is intended.
        """
        raw = re.sub(r"^[^\{]*", "", raw)
        raw = re.sub(r"[^\}]*$", "", raw)
        raw = re.sub(r"\{\{", "{", raw)
        return re.sub(r"\}\}", "}", raw)

    @staticmethod
    def _normalize_rating(report) -> None:
        """Convert an integer ``rating`` in ``report`` to float, in place.

        ``json.loads`` decodes a JSON number like ``8`` as ``int``. Converting
        it lets the ``("rating", float)`` check accept an otherwise valid report.
        """
        if not isinstance(report, dict):
            return
        rating = report.get("rating")
        # bool is an int subclass; leave it for the type check to reject.
        if isinstance(rating, int) and not isinstance(rating, bool):
            report["rating"] = float(rating)

    def _get_text_output(self, parsed_output: dict) -> str:
        """Render a parsed report as ``# title``, summary, and ``## finding`` sections.

        A finding may be a plain string (used as its heading, with no body) or
        a dict with ``summary``/``explanation`` keys. Missing keys render as
        empty text. Any other value is rendered with ``str()`` so one malformed
        finding cannot abort the run.
        """
        title = parsed_output.get("title", "Report")
        summary = parsed_output.get("summary", "")
        findings = parsed_output.get("findings", [])

        def finding_summary(finding) -> str:
            if isinstance(finding, dict):
                value = finding.get("summary")
                return "" if value is None else str(value)
            return finding if isinstance(finding, str) else str(finding)

        def finding_explanation(finding) -> str:
            if isinstance(finding, dict):
                value = finding.get("explanation")
                return "" if value is None else str(value)
            return ""

        report_sections = "\n\n".join(
            f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
        )
        return f"# {title}\n\n{summary}\n\n{report_sections}"