actualbrain's picture
score-45, gpt-4.1
3c0a133
raw
history blame contribute delete
10.4 kB
"""Define a data enrichment agent.
Works with a chat model with tool calling support.
"""
import json
from typing import Any, Dict, List, Literal, Optional, cast
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel, Field
from researchgraph import prompts
from researchgraph.configuration import Configuration
from researchgraph.state import InputState, OutputState, State
from researchgraph.tools import scrape_website, search, get_file_content
from researchgraph.utils import init_model
async def call_agent_model(
    state: State, *, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
    """Call the primary Language Model (LLM) to decide on the next research action.

    Steps:
    1. Load the configuration and build the 'Info' tool from the user-defined
       extraction schema.
    2. Format the configured prompt (schema, question, task_id) and prepend it
       to the message history.
    3. Bind the executable research tools plus 'Info' to the model and invoke
       it, forcing a tool call (``tool_choice="any"``).
    4. If the model submitted 'Info', keep only that call so it cannot submit
       and research in the same turn.
    5. If the model ignored ``tool_choice`` and made no tool call, append a
       reminder asking it to call a tool.

    Args:
        state: Current graph state; ``question``, ``task_id`` and ``messages``
            are read.
        config: Optional runnable config used to build ``Configuration`` and to
            initialise the model.

    Returns:
        A partial state update with the new ``messages``, the submitted
        ``info`` (``None`` when 'Info' was not called) and ``loop_step`` of 1.
        The 1 is meant to be added to the running step count; this presumes
        an additive reducer on ``State.loop_step`` (confirm in
        ``researchgraph.state``).
    """
    configuration = Configuration.from_runnable_config(config)

    # Calling 'Info' is how the agent submits its final extraction; its
    # argument schema is the user-defined extraction schema.
    info_tool = {
        "name": "Info",
        "description": "Call this when you have gathered all the relevant info",
        "parameters": configuration.extraction_schema,
    }

    p = configuration.prompt.format(
        info=json.dumps(configuration.extraction_schema, indent=2),
        question=state.question,
        task_id=state.task_id,
    )
    messages = [HumanMessage(content=p)] + state.messages

    # Offer only tools the graph can execute. The "tools" ToolNode runs
    # search / scrape_website / get_file_content, and 'Info' is handled by
    # `reflect`. A tool with no executor could only ever produce a failed call.
    raw_model = init_model(config)
    model = raw_model.bind_tools(
        [scrape_website, search, get_file_content, info_tool],
        tool_choice="any",
    )
    response = cast(AIMessage, await model.ainvoke(messages))

    # First 'Info' call, if any.
    info_call = next(
        (tc for tc in (response.tool_calls or []) if tc["name"] == "Info"), None
    )
    info = None
    if info_call is not None:
        info = info_call["args"]
        # The agent is submitting its answer; drop any research calls issued
        # alongside it so routing goes straight to `reflect`.
        response.tool_calls = [info_call]

    response_messages: List[BaseMessage] = [response]
    if not response.tool_calls:  # The LLM did not respect tool_choice.
        response_messages.append(
            HumanMessage(content="Please respond by calling one of the provided tools.")
        )
    return {
        "messages": response_messages,
        "info": info,
        # Add 1 to the step count.
        "loop_step": 1,
    }
class InfoIsSatisfactory(BaseModel):
    """Validate whether the current extracted info is satisfactory and complete."""
    # Structured-output schema for the checker in `reflect`, which passes this
    # class to `with_structured_output`.
    # NOTE(review): the class docstring and Field descriptions are presumably
    # forwarded to the model as the schema description, so treat them as
    # prompt text and edit with care.
    # Field order mirrors the descriptions ("First, provide reasoning...",
    # "After providing your reasoning..."). It presumably exists to make the
    # model reason before giving a verdict, so keep it.

    # Free-form reasons. "At least 3" is only requested in the description,
    # not enforced by validation.
    reason: List[str] = Field(
        description="First, provide reasoning for why this is either good or bad as a final result. Must include at least 3 reasons."
    )
    # Verdict. `reflect` accepts the info only when this is True AND the
    # presumed info is non-empty.
    is_satisfactory: bool = Field(
        description="After providing your reasoning, provide a value indicating whether the result is satisfactory. If not, you will continue researching."
    )
    # Optional guidance. `reflect` embeds it in the error ToolMessage sent back
    # to the agent when the info is not accepted.
    improvement_instructions: Optional[str] = Field(
        description="If the result is not satisfactory, provide clear and specific instructions on what needs to be improved or added to make the information satisfactory."
        " This should include details on missing information, areas that need more depth, or specific aspects to focus on in further research.",
        default=None,
    )
async def reflect(
    state: State, *, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
    """Validate the quality of the data enrichment agent's output.

    Rebuilds the agent's context: the same configured prompt that
    ``call_agent_model`` uses, followed by the message history without the
    pending 'Info' call. It then asks a structured-output checker
    (``InfoIsSatisfactory``) whether the presumed info is good enough. The
    pending tool call is answered with a ``ToolMessage`` carrying the verdict.

    Args:
        state: Current graph state. The last message must be the agent's
            ``AIMessage`` carrying the 'Info' tool call; ``state.info`` holds
            the presumed info.
        config: Optional runnable config used to build ``Configuration`` and to
            initialise the checker model.

    Returns:
        If the checker is satisfied and the info is non-empty: ``info`` plus a
        ``ToolMessage`` (``status="success"``) whose content is the checker's
        reasons. Otherwise, a ``ToolMessage`` (``status="error"``) with the
        improvement instructions, or the reasons when none were given.

    Raises:
        ValueError: If the last message is not an ``AIMessage`` with at least
            one tool call.
    """
    configuration = Configuration.from_runnable_config(config)

    last_message = state.messages[-1]
    if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
        raise ValueError(
            f"{reflect.__name__} expects the last message in the state to be an AI message with tool calls."
            f" Got: {type(last_message)}"
        )
    # The reply below answers this pending call. route_after_agent only routes
    # here when the first tool call is "Info".
    tool_call_id = last_message.tool_calls[0]["id"]

    # Use the same configured prompt as call_agent_model, so the checker judges
    # the result against the instructions the agent actually received.
    p = configuration.prompt.format(
        info=json.dumps(configuration.extraction_schema, indent=2),
        question=state.question,
        task_id=state.task_id,
    )
    # Leave out the last AIMessage. Its tool call has no ToolMessage reply yet,
    # presumably something chat models reject.
    messages = [HumanMessage(content=p)] + state.messages[:-1]
    presumed_info = state.info
    checker_prompt = """I am thinking of calling the info tool with the info below. \
Is this good? Give your reasoning as well. \
You can encourage the Assistant to look at specific URLs if that seems relevant, or do more searches.
If you don't think it is good, you should be very specific about what could be improved.
{presumed_info}"""
    p1 = checker_prompt.format(presumed_info=json.dumps(presumed_info or {}, indent=2))
    messages.append(HumanMessage(content=p1))

    raw_model = init_model(config)
    bound_model = raw_model.with_structured_output(InfoIsSatisfactory)
    response = cast(InfoIsSatisfactory, await bound_model.ainvoke(messages))

    if response.is_satisfactory and presumed_info:
        return {
            "info": presumed_info,
            "messages": [
                ToolMessage(
                    tool_call_id=tool_call_id,
                    content="\n".join(response.reason),
                    name="Info",
                    additional_kwargs={"artifact": response.model_dump()},
                    status="success",
                )
            ],
        }

    # improvement_instructions is optional (e.g. a satisfactory verdict on
    # empty info). Fall back to the reasons so the agent never gets "None".
    feedback = response.improvement_instructions or "\n".join(response.reason)
    return {
        "messages": [
            ToolMessage(
                tool_call_id=tool_call_id,
                content=f"Unsatisfactory response:\n{feedback}",
                name="Info",
                additional_kwargs={"artifact": response.model_dump()},
                status="error",
            )
        ]
    }
def route_after_agent(
    state: State,
) -> Literal["reflect", "tools", "call_agent_model", "__end__"]:
    """Pick the node that follows the agent's turn.

    Routing is decided from the newest message in ``state.messages``:

    * not an ``AIMessage``: something upstream went off-script (e.g. a
      reminder HumanMessage was appended), so ask the agent again;
    * the first tool call is "Info": the agent is submitting its extraction,
      so have ``reflect`` review it;
    * anything else: run the requested research tools.
    """
    latest = state.messages[-1]

    # Recovery path: prefer re-prompting the agent over crashing.
    if not isinstance(latest, AIMessage):
        return "call_agent_model"

    calls = latest.tool_calls
    submitted_info = bool(calls) and calls[0]["name"] == "Info"
    return "reflect" if submitted_info else "tools"
def route_after_checker(
    state: State, config: RunnableConfig
) -> Literal["__end__", "call_agent_model"]:
    """Decide whether research continues after the checker's verdict.

    Research stops once ``state.loop_step`` reaches ``max_loops`` from the
    configuration, whatever the verdict. Within budget, the agent is sent
    back to work when no info has been recorded or when ``reflect`` answered
    with an error-status ``ToolMessage``. Otherwise the run ends.

    Raises:
        ValueError: If, within budget and with info present, the newest
            message is not a ``ToolMessage``.
    """
    settings = Configuration.from_runnable_config(config)
    latest = state.messages[-1]

    # Loop budget exhausted: finish with whatever has been gathered.
    if not (state.loop_step < settings.max_loops):
        return "__end__"

    if not state.info:
        return "call_agent_model"

    if not isinstance(latest, ToolMessage):
        raise ValueError(
            f"{route_after_checker.__name__} expected a tool messages. Received: {type(latest)}."
        )

    # An "error" status from reflect means the result was judged unsatisfactory.
    return "call_agent_model" if latest.status == "error" else "__end__"
# Researcher graph topology:
#   __start__        -> call_agent_model
#   call_agent_model -> route_after_agent   (reflect | tools | call_agent_model)
#   tools            -> call_agent_model
#   reflect          -> route_after_checker (call_agent_model | __end__)
researcher_workflow = StateGraph(
    State,
    input=InputState,
    output=OutputState,
    config_schema=Configuration,
)

# Nodes. The LLM-backed steps are registered under their function names.
researcher_workflow.add_node("call_agent_model", call_agent_model)
researcher_workflow.add_node("reflect", reflect)
researcher_workflow.add_node(
    "tools", ToolNode([search, scrape_website, get_file_content])
)

# Fixed edges.
researcher_workflow.add_edge("__start__", "call_agent_model")
researcher_workflow.add_edge("tools", "call_agent_model")

# Routed edges.
researcher_workflow.add_conditional_edges("call_agent_model", route_after_agent)
researcher_workflow.add_conditional_edges("reflect", route_after_checker)

researchgraph = researcher_workflow.compile()
researchgraph.name = "Agent"