"""Define a data enrichment agent. | |
Works with a chat model with tool calling support. | |
""" | |
import json
from typing import Any, Dict, List, Literal, Optional, cast

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel, Field

from researchgraph import prompts
from researchgraph.configuration import Configuration
from researchgraph.state import InputState, OutputState, State
from researchgraph.tools import scrape_website, search, get_file_content
from researchgraph.utils import init_model
async def call_agent_model(
    state: State, *, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
    """Invoke the primary LLM to choose the next research action.

    Builds the tool set (the research tools plus the schema-driven "Info"
    submission tool and a "GetFile" fetch tool), formats the configured
    prompt, and calls the model with ``tool_choice="any"``. If the model
    calls "Info", any concurrent tool calls are discarded and the extracted
    payload is surfaced in the returned state update.

    Returns a state delta containing the new message(s), the extracted
    ``info`` (or ``None`` if research continues), and a loop-step increment.
    """
    configuration = Configuration.from_runnable_config(config)

    # "Info" is the submission tool: its parameters are the user-defined
    # extraction schema, so calling it means "I have finished researching".
    info_tool = {
        "name": "Info",
        "description": "Call this when you have gathered all the relevant info",
        "parameters": configuration.extraction_schema,
    }

    # "GetFile" lets the model pull a task file from the scoring system.
    get_file_tool = {
        "name": "GetFile",
        "description": "Fetch content from the scoring system for a given task ID",
        "parameters": {
            "type": "object",
            "properties": {
                "task_id": {
                    "type": "string",
                    "description": "The ID of the task/file to fetch",
                }
            },
            "required": ["task_id"],
        },
    }

    # Prepend the formatted prompt to the running message history.
    prompt_text = configuration.prompt.format(
        info=json.dumps(configuration.extraction_schema, indent=2),
        question=state.question,
        task_id=state.task_id,
    )
    conversation = [HumanMessage(content=prompt_text)] + state.messages

    # tool_choice="any" asks the model to call at least one tool.
    model = init_model(config).bind_tools(
        [scrape_website, search, get_file_content, info_tool, get_file_tool],
        tool_choice="any",
    )
    response = cast(AIMessage, await model.ainvoke(conversation))

    # Find a submission ("Info") call, if the model made one.
    info_call = next(
        (tc for tc in (response.tool_calls or []) if tc["name"] == "Info"),
        None,
    )
    info = None
    if info_call is not None:
        # The agent is submitting its answer; keep only the "Info" call so it
        # cannot simultaneously continue researching.
        info = info_call["args"]
        response.tool_calls = [info_call]

    outgoing: List[BaseMessage] = [response]
    if not response.tool_calls:  # The LLM ignored tool_choice; nudge it back.
        outgoing.append(
            HumanMessage(content="Please respond by calling one of the provided tools.")
        )
    return {
        "messages": outgoing,
        "info": info,
        # One more loop iteration (presumably accumulated by a state reducer —
        # the original comment says "Add 1 to the step count").
        "loop_step": 1,
    }
class InfoIsSatisfactory(BaseModel):
    """Validate whether the current extracted info is satisfactory and complete."""

    # Reasoning comes first so the model justifies the verdict before emitting it.
    reason: List[str] = Field(
        description="First, provide reasoning for why this is either good or bad as a final result. Must include at least 3 reasons."
    )
    # The actual verdict; False sends the agent back to continue researching.
    is_satisfactory: bool = Field(
        description="After providing your reasoning, provide a value indicating whether the result is satisfactory. If not, you will continue researching."
    )
    # Only meaningful when is_satisfactory is False; fed back to the agent as
    # guidance for the next research loop.
    improvement_instructions: Optional[str] = Field(
        description="If the result is not satisfactory, provide clear and specific instructions on what needs to be improved or added to make the information satisfactory."
        " This should include details on missing information, areas that need more depth, or specific aspects to focus on in further research.",
        default=None,
    )
async def reflect(
    state: State, *, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
    """Validate the quality of the data enrichment agent's output.

    Reconstructs the conversation (main prompt + prior messages, excluding
    the agent's final "Info" call), asks a structured-output model whether
    the presumed info is satisfactory, and returns either a success
    ToolMessage carrying the accepted info or an error ToolMessage with
    improvement instructions.

    Raises:
        ValueError: if the last message is not an AIMessage with tool calls.
    """
    configuration = Configuration.from_runnable_config(config)
    # NOTE(review): this uses prompts.MAIN_PROMPT while call_agent_model uses
    # configuration.prompt — confirm the divergence is intentional.
    p = prompts.MAIN_PROMPT.format(
        info=json.dumps(configuration.extraction_schema, indent=2),
        question=state.question,
        task_id=state.task_id,
    )
    last_message = state.messages[-1]
    # The tool_call id is indexed below, so insist on a non-empty tool_calls
    # list up front instead of failing later with an opaque IndexError.
    if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
        raise ValueError(
            f"{reflect.__name__} expects the last message in the state to be an AI message with tool calls."
            f" Got: {type(last_message)}"
        )
    # Replay the history minus the agent's submission, then append the
    # checker's question about the presumed info.
    messages = [HumanMessage(content=p)] + state.messages[:-1]
    presumed_info = state.info
    checker_prompt = """I am thinking of calling the info tool with the info below. \
Is this good? Give your reasoning as well. \
You can encourage the Assistant to look at specific URLs if that seems relevant, or do more searches.
If you don't think it is good, you should be very specific about what could be improved.
{presumed_info}"""
    p1 = checker_prompt.format(presumed_info=json.dumps(presumed_info or {}, indent=2))
    messages.append(HumanMessage(content=p1))
    raw_model = init_model(config)
    bound_model = raw_model.with_structured_output(InfoIsSatisfactory)
    response = cast(InfoIsSatisfactory, await bound_model.ainvoke(messages))
    if response.is_satisfactory and presumed_info:
        # Accept: persist the info and close out the Info tool call.
        return {
            "info": presumed_info,
            "messages": [
                ToolMessage(
                    tool_call_id=last_message.tool_calls[0]["id"],
                    content="\n".join(response.reason),
                    name="Info",
                    additional_kwargs={"artifact": response.model_dump()},
                    status="success",
                )
            ],
        }
    # Reject: report back with the checker's improvement instructions so the
    # agent can continue researching.
    return {
        "messages": [
            ToolMessage(
                tool_call_id=last_message.tool_calls[0]["id"],
                content=f"Unsatisfactory response:\n{response.improvement_instructions}",
                name="Info",
                additional_kwargs={"artifact": response.model_dump()},
                status="error",
            )
        ]
    }
def route_after_agent(
    state: State,
) -> Literal["reflect", "tools", "call_agent_model", "__end__"]:
    """Pick the next node after the agent has acted.

    Returns "call_agent_model" to recover when the last message is not an
    AIMessage, "reflect" when the agent submitted its findings via the
    "Info" tool, and "tools" for any other tool call.
    """
    newest = state.messages[-1]

    # Defensive recovery: if something other than the agent produced the last
    # message (a bug or unexpected behavior elsewhere), call the agent again
    # rather than crashing.
    if not isinstance(newest, AIMessage):
        return "call_agent_model"

    calls = newest.tool_calls
    # A leading "Info" call means the model is submitting its extraction
    # output — hand it to the checker.
    if calls and calls[0]["name"] == "Info":
        return "reflect"

    # Any other tool call: keep researching.
    return "tools"
def route_after_checker(
    state: State, config: RunnableConfig
) -> Literal["__end__", "call_agent_model"]:
    """Schedule the next node after the checker's evaluation.

    Ends the run when the loop budget is exhausted or the checker accepted
    the info; otherwise sends the agent back to research.

    Raises:
        ValueError: if info exists but the last message is not a ToolMessage
            (the checker always responds with one).
    """
    configurable = Configuration.from_runnable_config(config)
    # Budget exhausted: stop regardless of result quality.
    if state.loop_step >= configurable.max_loops:
        return "__end__"
    # Nothing extracted yet — keep researching.
    if not state.info:
        return "call_agent_model"
    last_message = state.messages[-1]
    if not isinstance(last_message, ToolMessage):
        # Fixed grammar of the original error text ("a tool messages").
        raise ValueError(
            f"{route_after_checker.__name__} expected a tool message. Received: {type(last_message)}."
        )
    if last_message.status == "error":
        # Research deemed unsatisfactory — try again.
        return "call_agent_model"
    # The checker accepted the info.
    return "__end__"
# Create the researcher graph: the agent loops between calling the model,
# executing research tools, and reflecting on a submitted answer.
researcher_workflow = StateGraph(
    State, input=InputState, output=OutputState, config_schema=Configuration
)
researcher_workflow.add_node(call_agent_model)
researcher_workflow.add_node(reflect)
# "tools" executes the imported research tools. NOTE(review): the "GetFile"
# dict tool bound in call_agent_model has no matching executor here (only
# get_file_content is registered) — confirm "GetFile" calls are handled.
researcher_workflow.add_node(
    "tools", ToolNode([search, scrape_website, get_file_content])
)
researcher_workflow.add_edge("__start__", "call_agent_model")
# route_after_agent -> "reflect" | "tools" | "call_agent_model" | "__end__"
researcher_workflow.add_conditional_edges("call_agent_model", route_after_agent)
# Tool results always go back to the agent for the next decision.
researcher_workflow.add_edge("tools", "call_agent_model")
# route_after_checker -> "__end__" | "call_agent_model"
researcher_workflow.add_conditional_edges("reflect", route_after_checker)
researchgraph = researcher_workflow.compile()
researchgraph.name = "Agent"