Kevin Hu committed
Commit 6d4f792 · 1 Parent(s): b6ce919

refine log info about graphrag progress (#1823)

### What problem does this PR solve?



### Type of change

- [x] Refactoring

api/db/services/document_service.py CHANGED

```diff
@@ -317,7 +317,8 @@ class DocumentService(CommonService):
             if 0 <= t.progress < 1:
                 finished = False
             prg += t.progress if t.progress >= 0 else 0
-            msg.append(t.progress_msg)
+            if t.progress_msg not in msg:
+                msg.append(t.progress_msg)
             if t.progress == -1:
                 bad += 1
         prg /= len(tsks)
```
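
Why the guard matters: when several tasks of one document report identical progress text, the old loop appended the same message repeatedly to the summary. A minimal, self-contained sketch of the aggregation (the `Task` dataclass is a hypothetical stand-in for the project's task row, not its actual model):

```python
from dataclasses import dataclass

@dataclass
class Task:
    progress: float       # -1 = failed, [0, 1) = running, 1 = done
    progress_msg: str

def aggregate_progress(tsks: list[Task]) -> tuple[float, list[str], int]:
    finished, prg, bad, msg = True, 0.0, 0, []
    for t in tsks:
        if 0 <= t.progress < 1:
            finished = False
        prg += t.progress if t.progress >= 0 else 0
        if t.progress_msg not in msg:  # the new dedup guard
            msg.append(t.progress_msg)
        if t.progress == -1:
            bad += 1
    prg /= len(tsks)
    return prg, msg, bad

# Two tasks emitting the same message now surface it once:
print(aggregate_progress([Task(0.4, "Entity extraction"), Task(0.8, "Entity extraction")]))
# (0.6..., ['Entity extraction'], 0)
```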
graphrag/community_reports_extractor.py CHANGED

```diff
@@ -23,16 +23,16 @@ import logging
 import re
 import traceback
 from dataclasses import dataclass
-from typing import Any, List
-
+from typing import Any, List, Callable
 import networkx as nx
 import pandas as pd
-
 from graphrag import leiden
 from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
 from graphrag.leiden import add_community_info2graph
 from rag.llm.chat_model import Base as CompletionLLM
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
+from rag.utils import num_tokens_from_string
+from timeit import default_timer as timer
 
 log = logging.getLogger(__name__)
 
@@ -67,11 +67,14 @@ class CommunityReportsExtractor:
         self._on_error = on_error or (lambda _e, _s, _d: None)
         self._max_report_length = max_report_length or 1500
 
-    def __call__(self, graph: nx.Graph):
+    def __call__(self, graph: nx.Graph, callback: Callable | None = None):
         communities: dict[str, dict[str, List]] = leiden.run(graph, {})
+        total = sum([len(comm.items()) for _, comm in communities.items()])
         relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
         res_str = []
         res_dict = []
+        over, token_count = 0, 0
+        st = timer()
         for level, comm in communities.items():
             for cm_id, ents in comm.items():
                 weight = ents["weight"]
@@ -84,9 +87,10 @@ class CommunityReportsExtractor:
                     "relation_df": rela_df.to_csv(index_label="id")
                 }
                 text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
-                gen_conf = {"temperature": 0.5}
+                gen_conf = {"temperature": 0.3}
                 try:
                     response = self._llm.chat(text, [], gen_conf)
+                    token_count += num_tokens_from_string(text + response)
                     response = re.sub(r"^[^\{]*", "", response)
                     response = re.sub(r"[^\}]*$", "", response)
                     print(response)
@@ -108,6 +112,8 @@ class CommunityReportsExtractor:
                 add_community_info2graph(graph, ents, response["title"])
                 res_str.append(self._get_text_output(response))
                 res_dict.append(response)
+                over += 1
+                if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
 
         return CommunityReportsResult(
             structured_output=res_dict,
```
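
The extractor now takes an optional `callback` and invokes it keyword-only as `callback(msg=...)` after each community report, while `graphrag/index.py` also calls the same callback positionally as `callback(0.6, "...")`. A compatible consumer therefore needs both parameters optional; a minimal sketch (the logger setup is an assumption, not part of this PR):

```python
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("graphrag.progress")

def progress_callback(prog: float | None = None, msg: str = ""):
    # Handles both call styles used in this PR:
    #   callback(0.6, "Extracting community reports.")  # positional
    #   callback(msg="Communities: 3/12, ...")          # keyword-only
    if prog is not None:
        log.info("progress=%.2f %s", prog, msg)
    else:
        log.info(msg)

progress_callback(0.6, "Extracting community reports.")
progress_callback(msg="Communities: 3/12, elapsed: 41.8s, used tokens: 15324")
```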
graphrag/graph_extractor.py CHANGED

```diff
@@ -21,13 +21,14 @@ import numbers
 import re
 import traceback
 from dataclasses import dataclass
-from typing import Any, Mapping
+from typing import Any, Mapping, Callable
 import tiktoken
 from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
 from rag.llm.chat_model import Base as CompletionLLM
 import networkx as nx
 from rag.utils import num_tokens_from_string
+from timeit import default_timer as timer
 
 DEFAULT_TUPLE_DELIMITER = "<|>"
 DEFAULT_RECORD_DELIMITER = "##"
@@ -103,7 +104,9 @@ class GraphExtractor:
         self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
 
     def __call__(
-        self, texts: list[str], prompt_variables: dict[str, Any] | None = None
+        self, texts: list[str],
+        prompt_variables: dict[str, Any] | None = None,
+        callback: Callable | None = None
     ) -> GraphExtractionResult:
         """Call method definition."""
         if prompt_variables is None:
@@ -127,12 +130,17 @@ class GraphExtractor:
             ),
         }
 
+        st = timer()
+        total = len(texts)
+        total_token_count = 0
         for doc_index, text in enumerate(texts):
             try:
                 # Invoke the entity extraction
-                result = self._process_document(text, prompt_variables)
+                result, token_count = self._process_document(text, prompt_variables)
                 source_doc_map[doc_index] = text
                 all_records[doc_index] = result
+                total_token_count += token_count
+                if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
             except Exception as e:
                 logging.exception("error extracting graph")
                 self._on_error(
@@ -162,9 +170,11 @@ class GraphExtractor:
             **prompt_variables,
             self._input_text_key: text,
         }
+        token_count = 0
         text = perform_variable_replacements(self._extraction_prompt, variables=variables)
-        gen_conf = {"temperature": 0.5}
+        gen_conf = {"temperature": 0.3}
         response = self._llm.chat(text, [], gen_conf)
+        token_count = num_tokens_from_string(text + response)
 
         results = response or ""
         history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
@@ -185,7 +195,7 @@ class GraphExtractor:
             if continuation != "YES":
                 break
 
-        return results
+        return results, token_count
 
     def _process_results(
         self,
```
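
`_process_document` now returns a `(results, token_count)` tuple, counting tokens for the prompt plus response with `rag.utils.num_tokens_from_string`. A sketch of that kind of accounting, assuming a tiktoken encoding comparable to what the real helper uses (its internals are not shown in this diff):

```python
import tiktoken

def count_round_trip_tokens(prompt: str, response: str) -> int:
    # Assumption: rag.utils.num_tokens_from_string does tiktoken-style counting;
    # cl100k_base is a stand-in encoding, not necessarily the one the project uses.
    enc = tiktoken.get_encoding("cl100k_base")
    return len(enc.encode(prompt + response))

prompt = "Extract entities of types [ORG, PERSON] from: ..."
response = '("entity"<|>RAGFLOW<|>ORG<|>an open-source RAG engine)##'
print(count_round_trip_tokens(prompt, response))
```

Going by the hunks shown, only the first round trip is counted; responses from the subsequent CONTINUE_PROMPT gleaning loop are not added to `token_count`.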
graphrag/index.py CHANGED

```diff
@@ -86,7 +86,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
     for i in range(len(chunks)):
         tkn_cnt = num_tokens_from_string(chunks[i])
         if cnt+tkn_cnt >= left_token_count and texts:
-            threads.append(exe.submit(ext, texts, {"entity_types": entity_types}))
+            threads.append(exe.submit(ext, texts, {"entity_types": entity_types}, callback))
             texts = []
             cnt = 0
         texts.append(chunks[i])
@@ -98,7 +98,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
     graphs = []
     for i, _ in enumerate(threads):
         graphs.append(_.result().output)
-        callback(0.5 + 0.1*i/len(threads))
+        callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")
 
     graph = reduce(graph_merge, graphs)
     er = EntityResolution(llm_bdl)
@@ -125,7 +125,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
 
     callback(0.6, "Extracting community reports.")
     cr = CommunityReportsExtractor(llm_bdl)
-    cr = cr(graph)
+    cr = cr(graph, callback=callback)
     for community, desc in zip(cr.structured_output, cr.output):
         chunk = {
             "title_tks": rag_tokenizer.tokenize(community["title"]),
```
rag/nlp/search.py CHANGED

```diff
@@ -138,7 +138,7 @@ class Dealer:
         es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
         if self.es.getTotal(res) == 0 and "knn" in s:
             bqry, _ = self.qryr.question(qst, min_match="10%")
-            bqry = self._add_filters(bqry)
+            bqry = self._add_filters(bqry, req)
             s["query"] = bqry.to_dict()
             s["knn"]["filter"] = bqry.to_dict()
             s["knn"]["similarity"] = 0.17
```