Kevin Hu committed
Commit · 6d4f792 · 1 parent: b6ce919

refine loginfo about graprag progress (#1823)

### What problem does this PR solve?

Refines progress logging for GraphRAG: entity extraction and community-report extraction now report per-document and per-community progress, elapsed time, and token usage through the task progress callback, and duplicate progress messages are no longer appended to the document log.

### Type of change

- [x] Refactoring
- api/db/services/document_service.py +2 -1
- graphrag/community_reports_extractor.py +11 -5
- graphrag/graph_extractor.py +15 -5
- graphrag/index.py +3 -3
- rag/nlp/search.py +1 -1
api/db/services/document_service.py
CHANGED

```diff
@@ -317,7 +317,8 @@ class DocumentService(CommonService):
             if 0 <= t.progress < 1:
                 finished = False
             prg += t.progress if t.progress >= 0 else 0
-            msg.append(t.progress_msg)
+            if t.progress_msg not in msg:
+                msg.append(t.progress_msg)
             if t.progress == -1:
                 bad += 1
         prg /= len(tsks)
```
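The new guard is worth a second look: without it, tasks that report identical progress messages each append another copy to the aggregated document log. A minimal standalone sketch of the aggregation loop (the `_Task` dataclass is a hypothetical stand-in for the real task rows):

```python
from dataclasses import dataclass

@dataclass
class _Task:
    progress: float      # -1 on failure, 1.0 when finished
    progress_msg: str

def aggregate(tsks: list[_Task]):
    prg, bad, msg = 0.0, 0, []
    for t in tsks:
        prg += t.progress if t.progress >= 0 else 0
        if t.progress_msg not in msg:  # the added guard: append each message once
            msg.append(t.progress_msg)
        if t.progress == -1:
            bad += 1
    return prg / len(tsks), bad, msg

# Two tasks carry the same message; it now shows up once in the log.
print(aggregate([_Task(0.25, "Page(1~10): chunking"), _Task(0.75, "Page(1~10): chunking")]))
# (0.5, 0, ['Page(1~10): chunking'])
```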
graphrag/community_reports_extractor.py
CHANGED

```diff
@@ -23,16 +23,16 @@ import logging
 import re
 import traceback
 from dataclasses import dataclass
-from typing import Any, List
-
+from typing import Any, List, Callable
 import networkx as nx
 import pandas as pd
-
 from graphrag import leiden
 from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
 from graphrag.leiden import add_community_info2graph
 from rag.llm.chat_model import Base as CompletionLLM
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
+from rag.utils import num_tokens_from_string
+from timeit import default_timer as timer
 
 log = logging.getLogger(__name__)
 
@@ -67,11 +67,14 @@ class CommunityReportsExtractor:
         self._on_error = on_error or (lambda _e, _s, _d: None)
         self._max_report_length = max_report_length or 1500
 
-    def __call__(self, graph: nx.Graph):
+    def __call__(self, graph: nx.Graph, callback: Callable | None = None):
         communities: dict[str, dict[str, List]] = leiden.run(graph, {})
+        total = sum([len(comm.items()) for _, comm in communities.items()])
         relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
         res_str = []
         res_dict = []
+        over, token_count = 0, 0
+        st = timer()
         for level, comm in communities.items():
             for cm_id, ents in comm.items():
                 weight = ents["weight"]
@@ -84,9 +87,10 @@ class CommunityReportsExtractor:
                     "relation_df": rela_df.to_csv(index_label="id")
                 }
                 text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
-                gen_conf = {"temperature": 0.
+                gen_conf = {"temperature": 0.3}
                 try:
                     response = self._llm.chat(text, [], gen_conf)
+                    token_count += num_tokens_from_string(text + response)
                     response = re.sub(r"^[^\{]*", "", response)
                     response = re.sub(r"[^\}]*$", "", response)
                     print(response)
@@ -108,6 +112,8 @@ class CommunityReportsExtractor:
                 add_community_info2graph(graph, ents, response["title"])
                 res_str.append(self._get_text_output(response))
                 res_dict.append(response)
+                over += 1
+                if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
 
         return CommunityReportsResult(
             structured_output=res_dict,
```
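Two details of the bookkeeping above, restated as a runnable sketch: `total` is simply the community count across all levels (`len(comm.items())` equals `len(comm)`), and the extractor invokes the callback with a single `msg` keyword, so whatever callable is passed in must accept it. The nested dict below imitates the `{level: {community_id: ...}}` shape returned by `leiden.run`; its exact contents are an assumption for illustration:

```python
from timeit import default_timer as timer

# Imitation of leiden.run output: {level: {community_id: ents}}.
communities = {
    "0": {"c0": {"weight": 3.0}, "c1": {"weight": 1.0}},
    "1": {"c2": {"weight": 2.0}},
}

total = sum([len(comm.items()) for _, comm in communities.items()])
assert total == 3 == sum(len(c) for c in communities.values())  # equivalent, simpler form

def callback(msg=""):
    print(f"[graphrag] {msg}")  # stand-in for RAGFlow's task-progress callback

st, over, token_count = timer(), 0, 0
for comm in communities.values():
    for _cm_id in comm:
        token_count += 1000      # pretend each community report cost ~1k tokens
        over += 1
        callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
```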
graphrag/graph_extractor.py
CHANGED

```diff
@@ -21,13 +21,14 @@ import numbers
 import re
 import traceback
 from dataclasses import dataclass
-from typing import Any, Mapping
+from typing import Any, Mapping, Callable
 import tiktoken
 from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
 from rag.llm.chat_model import Base as CompletionLLM
 import networkx as nx
 from rag.utils import num_tokens_from_string
+from timeit import default_timer as timer
 
 DEFAULT_TUPLE_DELIMITER = "<|>"
 DEFAULT_RECORD_DELIMITER = "##"
@@ -103,7 +104,9 @@ class GraphExtractor:
         self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
 
     def __call__(
-        self, texts: list[str], prompt_variables: dict[str, Any] | None = None
+        self, texts: list[str],
+        prompt_variables: dict[str, Any] | None = None,
+        callback: Callable | None = None
     ) -> GraphExtractionResult:
         """Call method definition."""
         if prompt_variables is None:
@@ -127,12 +130,17 @@ class GraphExtractor:
             ),
         }
 
+        st = timer()
+        total = len(texts)
+        total_token_count = 0
         for doc_index, text in enumerate(texts):
             try:
                 # Invoke the entity extraction
-                result = self._process_document(text, prompt_variables)
+                result, token_count = self._process_document(text, prompt_variables)
                 source_doc_map[doc_index] = text
                 all_records[doc_index] = result
+                total_token_count += token_count
+                if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
             except Exception as e:
                 logging.exception("error extracting graph")
                 self._on_error(
@@ -162,9 +170,11 @@ class GraphExtractor:
             **prompt_variables,
             self._input_text_key: text,
         }
+        token_count = 0
         text = perform_variable_replacements(self._extraction_prompt, variables=variables)
-        gen_conf = {"temperature": 0.
+        gen_conf = {"temperature": 0.3}
         response = self._llm.chat(text, [], gen_conf)
+        token_count = num_tokens_from_string(text + response)
 
         results = response or ""
         history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
@@ -185,7 +195,7 @@ class GraphExtractor:
             if continuation != "YES":
                 break
 
-        return results
+        return results, token_count
 
     def _process_results(
         self,
```
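Note that `_process_document` counts only the first prompt/response round trip; tokens spent in the `CONTINUE_PROMPT` gleaning loop below it are not added to `token_count`. The counting itself leans on `rag.utils.num_tokens_from_string`; a plausible minimal version, assuming it wraps `tiktoken` (the encoding name is a guess, not taken from this diff):

```python
import tiktoken

def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
    """Count BPE tokens in a string; sketch of the helper imported from rag.utils."""
    encoding = tiktoken.get_encoding(encoding_name)
    return len(encoding.encode(string))

prompt = "-Goal- Given a text document, identify all entities and relationships ..."
response = '("entity"<|>ACME<|>organization<|>A company mentioned in the text)##'
print(num_tokens_from_string(prompt + response))  # tokens for one chat round trip
```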
graphrag/index.py
CHANGED

```diff
@@ -86,7 +86,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
     for i in range(len(chunks)):
         tkn_cnt = num_tokens_from_string(chunks[i])
         if cnt+tkn_cnt >= left_token_count and texts:
-            threads.append(exe.submit(ext, texts, {"entity_types": entity_types}))
+            threads.append(exe.submit(ext, texts, {"entity_types": entity_types}, callback))
            texts = []
            cnt = 0
        texts.append(chunks[i])
@@ -98,7 +98,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
     graphs = []
     for i, _ in enumerate(threads):
         graphs.append(_.result().output)
-        callback(0.5 + 0.1*i/len(threads))
+        callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")
 
     graph = reduce(graph_merge, graphs)
     er = EntityResolution(llm_bdl)
@@ -125,7 +125,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
 
     callback(0.6, "Extracting community reports.")
     cr = CommunityReportsExtractor(llm_bdl)
-    cr = cr(graph)
+    cr = cr(graph, callback=callback)
     for community, desc in zip(cr.structured_output, cr.output):
         chunk = {
             "title_tks": rag_tokenizer.tokenize(community["title"]),
```
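The first hunk works because `ThreadPoolExecutor.submit(fn, *args)` forwards extra positional arguments to `fn`, which is how the progress callback reaches `GraphExtractor.__call__` inside worker threads; the second maps completed threads onto the 0.5–0.6 progress band. A self-contained sketch with stub functions (names are stand-ins, not the real extractor):

```python
from concurrent.futures import ThreadPoolExecutor

def ext(texts, prompt_variables, callback=None):
    # Stub for GraphExtractor.__call__; the real one reports per-document progress.
    if callback:
        callback(msg=f"extracted {len(texts)} chunk(s)")
    return {"output": len(texts)}

def callback(prog=None, msg=""):
    # Accepts both call styles used above: callback(0.55, "...") and callback(msg="...").
    print(prog, msg)

with ThreadPoolExecutor(max_workers=2) as exe:
    threads = [exe.submit(ext, ["chunk-1", "chunk-2"], {"entity_types": ["person"]}, callback)]
    for i, t in enumerate(threads):
        t.result()
        callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")
```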
rag/nlp/search.py
CHANGED

```diff
@@ -138,7 +138,7 @@ class Dealer:
         es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
         if self.es.getTotal(res) == 0 and "knn" in s:
             bqry, _ = self.qryr.question(qst, min_match="10%")
-            bqry = self._add_filters(bqry)
+            bqry = self._add_filters(bqry, req)
             s["query"] = bqry.to_dict()
             s["knn"]["filter"] = bqry.to_dict()
             s["knn"]["similarity"] = 0.17
```
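A one-argument fix, but easy to miss in review: this branch only runs on the zero-hit knn retry, so calling `_add_filters(bqry)` without the request would fail only on that rare path. A hypothetical stub showing the call shape; the real `_add_filters` body is not part of this diff:

```python
class Dealer:
    def _add_filters(self, bqry, req):
        # Hypothetical filter logic: re-apply the request's kb_ids filter
        # to the rebuilt low-min_match query.
        if req.get("kb_ids"):
            bqry.setdefault("filter", []).append({"terms": {"kb_ids": req["kb_ids"]}})
        return bqry

d = Dealer()
print(d._add_filters({}, {"kb_ids": ["kb-1"]}))
# {'filter': [{'terms': {'kb_ids': ['kb-1']}}]}
```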