Kevin Hu committed on
Commit · 758538f
Parent(s): 7056954
Cache the result from llm for graphrag and raptor (#4051)
### What problem does this PR solve?
#4045
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- api/db/services/task_service.py +6 -2
- graphrag/__init__.py +0 -0
- graphrag/claim_extractor.py +5 -5
- graphrag/community_reports_extractor.py +3 -3
- graphrag/description_summary.py +3 -3
- graphrag/entity_resolution.py +4 -3
- graphrag/extractor.py +34 -0
- graphrag/graph_extractor.py +6 -9
- graphrag/mind_map_extractor.py +3 -3
- graphrag/utils.py +52 -0
- rag/raptor.py +24 -4
- rag/svr/task_executor.py +21 -6
    	
api/db/services/task_service.py  CHANGED

```diff
@@ -271,7 +271,7 @@ def queue_tasks(doc: dict, bucket: str, name: str):
 
 
 def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
-    idx = bisect.bisect_left(prev_tasks, task…
+    idx = bisect.bisect_left(prev_tasks, task.get("from_page", 0), key=lambda x: x.get("from_page",0))
     if idx >= len(prev_tasks):
         return 0
     prev_task = prev_tasks[idx]
@@ -279,7 +279,11 @@ def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config:
         return 0
     task["chunk_ids"] = prev_task["chunk_ids"]
     task["progress"] = 1.0
-    …
+    if "from_page" in task and "to_page" in task:
+        task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): "
+    else:
+        task["progress_msg"] = ""
+    task["progress_msg"] += "reused previous task's chunks."
     prev_task["chunk_ids"] = ""
 
     return len(task["chunk_ids"].split())
```
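A minimal, self-contained sketch (sample data invented here, not from the repository) of how the updated lookup behaves: `prev_tasks` is assumed to be sorted by `from_page`, and `bisect.bisect_left` with a `key` (Python 3.10+) finds the previous task starting at the same page, whose chunks are then reused. The real function performs additional checks (elided in the diff above) before reusing chunks.

```python
import bisect

prev_tasks = [
    {"from_page": 0, "to_page": 5, "chunk_ids": "a b c"},
    {"from_page": 5, "to_page": 10, "chunk_ids": "d e"},
]
task = {"from_page": 5, "to_page": 10}

# Locate the previous task covering the same page range.
idx = bisect.bisect_left(prev_tasks, task.get("from_page", 0),
                         key=lambda x: x.get("from_page", 0))
if idx < len(prev_tasks):
    prev_task = prev_tasks[idx]
    task["chunk_ids"] = prev_task["chunk_ids"]
    task["progress"] = 1.0
    task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): reused previous task's chunks."
    print(len(task["chunk_ids"].split()))  # -> 2 reused chunks
```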
    	
graphrag/__init__.py  ADDED

File without changes (empty file).
    	
graphrag/claim_extractor.py  CHANGED

```diff
@@ -16,6 +16,7 @@ from typing import Any
 import tiktoken
 
 from graphrag.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
+from graphrag.extractor import Extractor
 from rag.llm.chat_model import Base as CompletionLLM
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
 
@@ -33,10 +34,9 @@ class ClaimExtractorResult:
     source_docs: dict[str, Any]
 
 
-class ClaimExtractor:
+class ClaimExtractor(Extractor):
     """Claim extractor class definition."""
 
-    _llm: CompletionLLM
     _extraction_prompt: str
     _summary_prompt: str
     _output_formatter_prompt: str
@@ -169,7 +169,7 @@ class ClaimExtractor:
                 }
         text = perform_variable_replacements(self._extraction_prompt, variables=variables)
         gen_conf = {"temperature": 0.5}
-        results = self.…
+        results = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
         claims = results.strip().removesuffix(completion_delimiter)
         history = [{"role": "system", "content": text}, {"role": "assistant", "content": results}]
 
@@ -177,7 +177,7 @@ class ClaimExtractor:
         for i in range(self._max_gleanings):
             text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
             history.append({"role": "user", "content": text})
-            extension = self.…
+            extension = self._chat("", history, gen_conf)
             claims += record_delimiter + extension.strip().removesuffix(
                 completion_delimiter
             )
@@ -188,7 +188,7 @@ class ClaimExtractor:
 
             history.append({"role": "assistant", "content": extension})
             history.append({"role": "user", "content": LOOP_PROMPT})
-            continuation = self.…
+            continuation = self._chat("", history, self._loop_args)
             if continuation != "YES":
                 break
 
```
    	
graphrag/community_reports_extractor.py  CHANGED

```diff
@@ -15,6 +15,7 @@ import networkx as nx
 import pandas as pd
 from graphrag import leiden
 from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
+from graphrag.extractor import Extractor
 from graphrag.leiden import add_community_info2graph
 from rag.llm.chat_model import Base as CompletionLLM
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
@@ -30,10 +31,9 @@ class CommunityReportsResult:
     structured_output: list[dict]
 
 
-class CommunityReportsExtractor:
+class CommunityReportsExtractor(Extractor):
     """Community reports extractor class definition."""
 
-    _llm: CompletionLLM
     _extraction_prompt: str
     _output_formatter_prompt: str
     _on_error: ErrorHandlerFn
@@ -74,7 +74,7 @@ class CommunityReportsExtractor:
                 text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
                 gen_conf = {"temperature": 0.3}
                 try:
-                    response = self.…
+                    response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
                     token_count += num_tokens_from_string(text + response)
                     response = re.sub(r"^[^\{]*", "", response)
                     response = re.sub(r"[^\}]*$", "", response)
```
    	
graphrag/description_summary.py  CHANGED

```diff
@@ -8,6 +8,7 @@ Reference:
 import json
 from dataclasses import dataclass
 
+from graphrag.extractor import Extractor
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
 from rag.llm.chat_model import Base as CompletionLLM
 
@@ -42,10 +43,9 @@ class SummarizationResult:
     description: str
 
 
-class SummarizeExtractor:
+class SummarizeExtractor(Extractor):
     """Unipartite graph extractor class definition."""
 
-    _llm: CompletionLLM
     _entity_name_key: str
     _input_descriptions_key: str
     _summarization_prompt: str
@@ -143,4 +143,4 @@ class SummarizeExtractor:
                     self._input_descriptions_key: json.dumps(sorted(descriptions)),
                 }
         text = perform_variable_replacements(self._summarization_prompt, variables=variables)
-        return self.…
+        return self._chat("", [{"role": "user", "content": text}])
```
    	
graphrag/entity_resolution.py  CHANGED

```diff
@@ -21,6 +21,8 @@ from dataclasses import dataclass
 from typing import Any
 
 import networkx as nx
+
+from graphrag.extractor import Extractor
 from rag.nlp import is_english
 import editdistance
 from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
@@ -39,10 +41,9 @@ class EntityResolutionResult:
     output: nx.Graph
 
 
-class EntityResolution:
+class EntityResolution(Extractor):
     """Entity resolution class definition."""
 
-    _llm: CompletionLLM
     _resolution_prompt: str
     _output_formatter_prompt: str
     _on_error: ErrorHandlerFn
@@ -117,7 +118,7 @@ class EntityResolution:
                 }
                 text = perform_variable_replacements(self._resolution_prompt, variables=variables)
 
-                response = self.…
+                response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
                 result = self._process_results(len(candidate_resolution_i[1]), response,
                                                prompt_variables.get(self._record_delimiter_key,
                                                                     DEFAULT_RECORD_DELIMITER),
```
    	
graphrag/extractor.py  ADDED

```diff
@@ -0,0 +1,34 @@
+#
+#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+from graphrag.utils import get_llm_cache, set_llm_cache
+from rag.llm.chat_model import Base as CompletionLLM
+
+
+class Extractor:
+    _llm: CompletionLLM
+
+    def __init__(self, llm_invoker: CompletionLLM):
+        self._llm = llm_invoker
+
+    def _chat(self, system, history, gen_conf):
+        response = get_llm_cache(self._llm.llm_name, system, history, gen_conf)
+        if response:
+            return response
+        response = self._llm.chat(system, history, gen_conf)
+        if response.find("**ERROR**") >= 0:
+            raise Exception(response)
+        set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
+        return response
```
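A self-contained sketch of the read-through caching pattern that the new `Extractor._chat` centralizes: check the cache, call the model only on a miss, store the reply. The in-memory dict and `FakeLLM` below are stand-ins for the Redis-backed `get_llm_cache`/`set_llm_cache` and the real chat model; they are not part of the PR.

```python
_cache: dict[str, str] = {}

class FakeLLM:
    llm_name = "demo-llm"
    def chat(self, system, history, gen_conf):
        return f"reply to: {history[-1]['content']}"

def chat_with_cache(llm, system, history, gen_conf):
    key = repr((llm.llm_name, system, history, gen_conf))
    if key in _cache:                 # cache hit: no LLM call
        return _cache[key]
    response = llm.chat(system, history, gen_conf)
    if "**ERROR**" in response:       # same error convention as the diff
        raise Exception(response)
    _cache[key] = response            # cache miss: store for next time
    return response

llm = FakeLLM()
print(chat_with_cache(llm, "", [{"role": "user", "content": "Output:"}], {"temperature": 0.3}))
print(chat_with_cache(llm, "", [{"role": "user", "content": "Output:"}], {"temperature": 0.3}))  # served from cache
```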
    	
graphrag/graph_extractor.py  CHANGED

```diff
@@ -12,6 +12,8 @@ import traceback
 from typing import Any, Callable, Mapping
 from dataclasses import dataclass
 import tiktoken
+
+from graphrag.extractor import Extractor
 from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
 from rag.llm.chat_model import Base as CompletionLLM
@@ -34,10 +36,9 @@ class GraphExtractionResult:
     source_docs: dict[Any, Any]
 
 
-class GraphExtractor:
+class GraphExtractor(Extractor):
     """Unipartite graph extractor class definition."""
 
-    _llm: CompletionLLM
     _join_descriptions: bool
     _tuple_delimiter_key: str
     _record_delimiter_key: str
@@ -165,9 +166,7 @@ class GraphExtractor:
         token_count = 0
         text = perform_variable_replacements(self._extraction_prompt, variables=variables)
         gen_conf = {"temperature": 0.3}
-        response = self.…
-        if response.find("**ERROR**") >= 0:
-            raise Exception(response)
+        response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
         token_count = num_tokens_from_string(text + response)
 
         results = response or ""
@@ -177,9 +176,7 @@ class GraphExtractor:
         for i in range(self._max_gleanings):
             text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
             history.append({"role": "user", "content": text})
-            response = self.…
-            if response.find("**ERROR**") >=0:
-                raise Exception(response)
+            response = self._chat("", history, gen_conf)
             results += response or ""
 
             # if this is the final glean, don't bother updating the continuation flag
@@ -187,7 +184,7 @@ class GraphExtractor:
                 break
             history.append({"role": "assistant", "content": response})
             history.append({"role": "user", "content": LOOP_PROMPT})
-            continuation = self.…
+            continuation = self._chat("", history, self._loop_args)
             if continuation != "YES":
                 break
 
```
    	
graphrag/mind_map_extractor.py  CHANGED

```diff
@@ -23,6 +23,7 @@ from typing import Any
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 
+from graphrag.extractor import Extractor
 from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
 from rag.llm.chat_model import Base as CompletionLLM
@@ -37,8 +38,7 @@ class MindMapResult:
     output: dict
 
 
-class MindMapExtractor:
-    _llm: CompletionLLM
+class MindMapExtractor(Extractor):
     _input_text_key: str
     _mind_map_prompt: str
     _on_error: ErrorHandlerFn
@@ -190,7 +190,7 @@ class MindMapExtractor:
         }
         text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
         gen_conf = {"temperature": 0.5}
-        response = self.…
+        response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
         response = re.sub(r"```[^\n]*", "", response)
         logging.debug(response)
         logging.debug(self._todict(markdown_to_json.dictify(response)))
```
    	
graphrag/utils.py  CHANGED

```diff
@@ -6,9 +6,15 @@ Reference:
 """
 
 import html
+import json
 import re
 from typing import Any, Callable
 
+import numpy as np
+import xxhash
+
+from rag.utils.redis_conn import REDIS_CONN
+
 ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
 
 
@@ -60,3 +66,49 @@ def dict_has_keys_with_types(
             return False
     return True
 
+
+def get_llm_cache(llmnm, txt, history, genconf):
+    hasher = xxhash.xxh64()
+    hasher.update(str(llmnm).encode("utf-8"))
+    hasher.update(str(txt).encode("utf-8"))
+    hasher.update(str(history).encode("utf-8"))
+    hasher.update(str(genconf).encode("utf-8"))
+
+    k = hasher.hexdigest()
+    bin = REDIS_CONN.get(k)
+    if not bin:
+        return
+    return bin.decode("utf-8")
+
+
+def set_llm_cache(llmnm, txt, v: str, history, genconf):
+    hasher = xxhash.xxh64()
+    hasher.update(str(llmnm).encode("utf-8"))
+    hasher.update(str(txt).encode("utf-8"))
+    hasher.update(str(history).encode("utf-8"))
+    hasher.update(str(genconf).encode("utf-8"))
+
+    k = hasher.hexdigest()
+    REDIS_CONN.set(k, v.encode("utf-8"), 24*3600)
+
+
+def get_embed_cache(llmnm, txt):
+    hasher = xxhash.xxh64()
+    hasher.update(str(llmnm).encode("utf-8"))
+    hasher.update(str(txt).encode("utf-8"))
+
+    k = hasher.hexdigest()
+    bin = REDIS_CONN.get(k)
+    if not bin:
+        return
+    return np.array(json.loads(bin.decode("utf-8")))
+
+
+def set_embed_cache(llmnm, txt, arr):
+    hasher = xxhash.xxh64()
+    hasher.update(str(llmnm).encode("utf-8"))
+    hasher.update(str(txt).encode("utf-8"))
+
+    k = hasher.hexdigest()
+    arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
+    REDIS_CONN.set(k, arr.encode("utf-8"), 24*3600)
```
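A short sketch of the cache-key scheme used by `get_llm_cache`/`set_llm_cache` above: the key is the xxhash64 digest of the model name, prompt text, message history, and generation config, so an identical request always maps to the same Redis key, and entries are written with a 24-hour (24*3600 s) expiry. The helper below only reproduces the key derivation and needs the `xxhash` package but no Redis.

```python
import xxhash

def llm_cache_key(llmnm, txt, history, genconf) -> str:
    # Same hashing order as get_llm_cache/set_llm_cache in the diff.
    hasher = xxhash.xxh64()
    for part in (llmnm, txt, history, genconf):
        hasher.update(str(part).encode("utf-8"))
    return hasher.hexdigest()

k1 = llm_cache_key("demo-llm", "system prompt", [{"role": "user", "content": "Output:"}], {"temperature": 0.3})
k2 = llm_cache_key("demo-llm", "system prompt", [{"role": "user", "content": "Output:"}], {"temperature": 0.3})
assert k1 == k2  # identical requests hit the same cache entry
```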
    	
rag/raptor.py  CHANGED

```diff
@@ -21,6 +21,7 @@ import umap
 import numpy as np
 from sklearn.mixture import GaussianMixture
 
+from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache
 from rag.utils import truncate
 
 
@@ -33,6 +34,27 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         self._prompt = prompt
         self._max_token = max_token
 
+    def _chat(self, system, history, gen_conf):
+        response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
+        if response:
+            return response
+        response = self._llm_model.chat(system, history, gen_conf)
+        if response.find("**ERROR**") >= 0:
+            raise Exception(response)
+        set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
+        return response
+
+    def _embedding_encode(self, txt):
+        response = get_embed_cache(self._embd_model.llm_name, txt)
+        if response:
+            return response
+        embds, _ = self._embd_model.encode([txt])
+        if len(embds) < 1 or len(embds[0]) < 1:
+            raise Exception("Embedding error: ")
+        embds = embds[0]
+        set_embed_cache(self._embd_model.llm_name, txt, embds)
+        return embds
+
     def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int):
         max_clusters = min(self._max_cluster, len(embeddings))
         n_clusters = np.arange(1, max_clusters)
@@ -57,7 +79,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                 texts = [chunks[i][0] for i in ck_idx]
                 len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
                 cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
-                cnt = self.…
+                cnt = self._chat("You're a helpful assistant.",
                                            [{"role": "user",
                                              "content": self._prompt.format(cluster_content=cluster_content)}],
                                            {"temperature": 0.3, "max_tokens": self._max_token}
@@ -67,9 +89,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                 logging.debug(f"SUM: {cnt}")
                 embds, _ = self._embd_model.encode([cnt])
                 with lock:
-                    …
-                        return
-                    chunks.append((cnt, embds[0]))
+                    chunks.append((cnt, self._embedding_encode(cnt)))
             except Exception as e:
                 logging.exception("summarize got exception")
                 return e
```
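A minimal sketch of the embedding read-through cache that `_embedding_encode` adds to RAPTOR, with an in-memory dict and a fake embedding model standing in for the Redis cache and the real encoder; all names below are illustrative, not from the repository.

```python
import numpy as np

_embed_cache: dict[str, np.ndarray] = {}

class FakeEmbedModel:
    llm_name = "demo-embed"
    def encode(self, texts):
        # Returns (vectors, token_count), mirroring the interface used in raptor.py.
        return np.random.rand(len(texts), 4), 0

def embedding_encode(embd_model, txt):
    key = f"{embd_model.llm_name}:{txt}"
    if key in _embed_cache:            # reuse the vector computed on a previous run
        return _embed_cache[key]
    embds, _ = embd_model.encode([txt])
    if len(embds) < 1 or len(embds[0]) < 1:
        raise Exception("Embedding error: ")
    _embed_cache[key] = embds[0]
    return embds[0]

vec = embedding_encode(FakeEmbedModel(), "cluster summary text")
print(vec.shape)  # (4,)
```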
    	
rag/svr/task_executor.py  CHANGED

```diff
@@ -19,6 +19,8 @@
 
 import sys
 from api.utils.log_utils import initRootLogger
+from graphrag.utils import get_llm_cache, set_llm_cache
+
 CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
 CONSUMER_NAME = "task_executor_" + CONSUMER_NO
 initRootLogger(CONSUMER_NAME)
@@ -232,9 +234,6 @@ def build_chunks(task, progress_callback):
         if not d.get("image"):
             _ = d.pop("image", None)
             d["img_id"] = ""
-            d["page_num_int"] = []
-            d["position_int"] = []
-            d["top_int"] = []
             docs.append(d)
             continue
 
@@ -262,8 +261,16 @@ def build_chunks(task, progress_callback):
         progress_callback(msg="Start to generate keywords for every chunk ...")
         chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
         for d in docs:
-            …
-            …
+            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords",
+                                   {"topn": task["parser_config"]["auto_keywords"]})
+            if not cached:
+                cached = keyword_extraction(chat_mdl, d["content_with_weight"],
+                                            task["parser_config"]["auto_keywords"])
+                if cached:
+                    set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords",
+                                  {"topn": task["parser_config"]["auto_keywords"]})
+
+            d["important_kwd"] = cached.split(",")
             d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
         progress_callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))
 
@@ -272,7 +279,15 @@ def build_chunks(task, progress_callback):
         progress_callback(msg="Start to generate questions for every chunk ...")
         chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
         for d in docs:
-            …
+            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question",
+                                   {"topn": task["parser_config"]["auto_questions"]})
+            if not cached:
+                cached = question_proposal(chat_mdl, d["content_with_weight"], task["parser_config"]["auto_questions"])
+                if cached:
+                    set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question",
+                                  {"topn": task["parser_config"]["auto_questions"]})
+
+            d["question_kwd"] = cached.split("\n")
             d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
         progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
 
```
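A hedged sketch of the keyword-caching flow added to `build_chunks`: the cache key combines the chunk text, a tag ("keywords" or "question"), and the topn setting, so changing `auto_keywords` or `auto_questions` in the parser config produces a different key and bypasses stale entries. Every helper below is a stub invented for illustration; only the control flow mirrors the diff.

```python
def get_llm_cache(llm_name, text, tag, conf):   # stand-in for graphrag.utils.get_llm_cache
    return None

def set_llm_cache(llm_name, text, value, tag, conf):  # stand-in for graphrag.utils.set_llm_cache
    pass

def keyword_extraction(chat_mdl, text, topn):   # stand-in for the real LLM keyword call
    return "alpha,beta,gamma"

class ChatModel:
    llm_name = "demo-llm"

def keywords_for_chunk(chat_mdl, d, topn):
    # Check the cache first; only call the LLM on a miss, then store the result.
    cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
    if not cached:
        cached = keyword_extraction(chat_mdl, d["content_with_weight"], topn)
        if cached:
            set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
    d["important_kwd"] = cached.split(",")
    return d

print(keywords_for_chunk(ChatModel(), {"content_with_weight": "some chunk text"}, 5)["important_kwd"])
```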