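"""main_v2.py: run a smolagents CodeAgent through a LangGraph execution loop.

A manager CodeAgent coordinates web, data-analysis, and media sub-agents.
A checkpointed LangGraph state graph drives the agent step by step, streams
progress to the console, and the final answer is extracted from the result.
"""
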
import asyncio
import importlib.resources
import logging
import os
import uuid  # for generating thread IDs for the checkpointer
from typing import AsyncIterator, Optional, TypedDict

import litellm
import yaml
from dotenv import find_dotenv, load_dotenv
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from phoenix.otel import register
from smolagents import CodeAgent, LiteLLMModel
from smolagents.memory import ActionStep, FinalAnswerStep
from smolagents.monitoring import LogLevel

from agents import create_data_analysis_agent, create_media_agent, create_web_agent
from prompts import MANAGER_SYSTEM_PROMPT
from tools import perform_calculation, web_search
from utils import extract_final_answer

# Verbose LiteLLM request/response logging; useful when debugging provider calls
litellm._turn_on_debug()

# Configure OpenTelemetry tracing via Phoenix. register() already attaches a
# span processor and exporter to the tracer provider it returns, so a single
# call is all the setup needed before instrumenting smolagents.
tracer_provider = register()
SmolagentsInstrumentor().instrument(tracer_provider=tracer_provider)

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv(find_dotenv())

# Get required environment variables with validation
API_BASE = os.getenv("API_BASE")
API_KEY = os.getenv("API_KEY")
MODEL_ID = os.getenv("MODEL_ID")
if not all([API_BASE, API_KEY, MODEL_ID]):
    raise ValueError(
        "Missing required environment variables: API_BASE, API_KEY, MODEL_ID"
    )
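
# API_BASE / API_KEY / MODEL_ID configure the LiteLLMModel instantiated below;
# python-dotenv loads them from a .env file when one is present.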


# Define the state types for our graph
class AgentState(TypedDict):
    task: str
    current_step: Optional[dict]  # Store serializable dict instead of ActionStep
    error: Optional[str]
    answer_text: Optional[str]
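

# answer_text doubles as the loop terminator: the graph keeps re-entering
# process_step until it is set or an error occurs (see should_continue below).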

# Initialize model with error handling
try:
    model = LiteLLMModel(
        api_base=API_BASE,
        api_key=API_KEY,
        model_id=MODEL_ID,
    )
except Exception as e:
    logger.error(f"Failed to initialize model: {str(e)}")
    raise

web_agent = create_web_agent(model)
data_agent = create_data_analysis_agent(model)
media_agent = create_media_agent(model)

tools = [
    # DuckDuckGoSearchTool(max_results=3),
    # VisitWebpageTool(max_output_length=1000),
    web_search,
    perform_calculation,
]

# Initialize agent with error handling
try:
    prompt_templates = yaml.safe_load(
        importlib.resources.files("smolagents.prompts")
        .joinpath("code_agent.yaml")
        .read_text()
    )
    # prompt_templates["system_prompt"] = MANAGER_SYSTEM_PROMPT
    agent = CodeAgent(
        add_base_tools=True,
        additional_authorized_imports=[
            "json",
            "pandas",
            "numpy",
            "re",
        ],
        # max_steps=10,
        managed_agents=[web_agent, data_agent, media_agent],
        model=model,
        prompt_templates=prompt_templates,
        tools=tools,
        step_callbacks=None,
        verbosity_level=LogLevel.ERROR,
    )
    agent.logger.console.width = 66
    agent.visualize()
    logger.info(f"Registered tools: {list(agent.tools)}")
except Exception as e:
    logger.error(f"Failed to initialize agent: {str(e)}")
    raise
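

# LangGraph node: each invocation drives the smolagents run (stream=True,
# reset=False) and records the latest ActionStep or the final answer in state.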
async def process_step(state: AgentState) -> AgentState:
    """Process a single step of the agent's execution."""
    try:
        # Clear previous step results before running agent.run
        state["current_step"] = None
        state["answer_text"] = None
        state["error"] = None
        steps = agent.run(
            task=state["task"],
            additional_args=None,
            images=None,
            # max_steps=1,  # Process one step at a time
            stream=True,
            reset=False,  # Maintain agent's internal state across process_step calls
        )
        for step in steps:
            if isinstance(step, ActionStep):
                # Convert ActionStep to a serializable dict using the correct attributes
                state["current_step"] = {
                    "step_number": step.step_number,
                    "model_output": step.model_output,
                    "observations": step.observations,
                    "tool_calls": [
                        {"name": tc.name, "arguments": tc.arguments}
                        for tc in (step.tool_calls or [])
                    ],
                    "action_output": step.action_output,
                }
                logger.info(f"Processed action step {step.step_number}")
                logger.info(f"Step {step.step_number} details: {step}")
                logger.info("Sleeping for 60 seconds...")
                # Pause between steps (e.g. to stay under provider rate limits);
                # asyncio.sleep avoids blocking the event loop, unlike time.sleep.
                await asyncio.sleep(60)
            elif isinstance(step, FinalAnswerStep):
                state["answer_text"] = step.final_answer
                logger.info("Processed final answer")
                logger.debug(f"Final answer details: {step}")
                logger.info(f"Extracted answer text: {state['answer_text']}")
                # Return immediately when we get a final answer
                return state
        # If the loop finishes without a FinalAnswerStep, return the current state
        return state
    except Exception as e:
        state["error"] = str(e)
        logger.error(f"Error during agent execution step: {str(e)}")
        return state


def should_continue(state: AgentState) -> bool:
    """Determine if the agent should continue processing steps."""
    # Continue only while there is neither an answer nor an error
    continue_execution = (
        state.get("answer_text") is None and state.get("error") is None
    )
    logger.debug(
        f"Checking should_continue: has_answer={state.get('answer_text') is not None}, "
        f"has_error={state.get('error') is not None} -> Continue={continue_execution}"
    )
    return continue_execution
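

# Graph topology: START -> process_step, which loops back on itself while
# should_continue returns True and transitions to END otherwise.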
# Build the LangGraph graph once with persistence
memory = MemorySaver()
builder = StateGraph(AgentState)
builder.add_node("process_step", process_step)
builder.add_edge(START, "process_step")
builder.add_conditional_edges(
    "process_step", should_continue, {True: "process_step", False: END}
)
graph = builder.compile(checkpointer=memory)
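
# MemorySaver is an in-process checkpointer: graph state is persisted per
# thread_id (supplied via the config dict) only for the lifetime of this process.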


async def stream_execution(task: str, thread_id: str) -> AsyncIterator[AgentState]:
    """Stream the execution of the agent, yielding the state after each node."""
    if not task:
        raise ValueError("Task cannot be empty")
    logger.info(f"Initializing agent execution for task: {task}")
    # Initialize the state
    initial_state: AgentState = {
        "task": task,
        "current_step": None,
        "error": None,
        "answer_text": None,
    }
    # Pass thread_id via the config dict so the checkpointer can persist state
    async for chunk in graph.astream(
        initial_state, {"configurable": {"thread_id": thread_id}}
    ):
        # graph.astream yields {node_name: state} mappings; unwrap them so
        # callers receive plain AgentState dicts, as the return type promises.
        for state in chunk.values():
            yield state
            # Propagate error immediately if it occurs without an answer
            if state.get("error") and not state.get("answer_text"):
                logger.error(f"Propagating error from stream: {state['error']}")
                raise RuntimeError(state["error"])


async def run_with_streaming(task: str, thread_id: str) -> dict:
    """Run the agent with streaming output and return the results."""
    last_state = None
    steps = []
    error = None
    final_answer_text = None
    try:
        logger.info(f"Starting execution run for task: {task}")
        async for state in stream_execution(task, thread_id):
            last_state = state
            if current_step := state.get("current_step"):
                if not steps or steps[-1]["step_number"] != current_step["step_number"]:
                    steps.append(current_step)
                    # Keep print here for direct user feedback during streaming
                    print(f"\nStep {current_step['step_number']}:")
                    print(f"Model Output: {current_step['model_output']}")
                    print(f"Observations: {current_step['observations']}")
                    if current_step.get("tool_calls"):
                        print("Tool Calls:")
                        for tc in current_step["tool_calls"]:
                            print(f"  - {tc['name']}: {tc['arguments']}")
                    if current_step.get("action_output"):
                        print(f"Action Output: {current_step['action_output']}")
        # After the stream is finished, process the last state
        logger.info("Stream finished.")
        if last_state:
            final_answer_text = last_state.get("answer_text")
            error = last_state.get("error")
            logger.info(
                f"Final answer text extracted from last state: {final_answer_text}"
            )
            logger.info(f"Error extracted from last state: {error}")
            # Ensure the steps list stays consistent with the final state
            last_step_in_state = last_state.get("current_step")
            if last_step_in_state and (
                not steps
                or steps[-1]["step_number"] != last_step_in_state["step_number"]
            ):
                logger.debug("Adding last step from final state to steps list.")
                steps.append(last_step_in_state)
        return {"steps": steps, "final_answer": final_answer_text, "error": error}
    except Exception as e:
        import traceback

        logger.error(
            f"Exception during run_with_streaming: {str(e)}\n{traceback.format_exc()}"
        )
        # Attempt to return based on the last known state even if the exception
        # occurred outside the stream
        final_answer_text = None
        error_msg = str(e)
        if last_state:
            final_answer_text = last_state.get("answer_text")
        return {"steps": steps, "final_answer": final_answer_text, "error": error_msg}


def main(task: str, thread_id: Optional[str] = None):
    """Wrap the task in GAIA instructions, run the agent, and return the answer."""
    # A str(uuid.uuid4()) default in the signature would be evaluated once at
    # import time and shared across calls, so generate the ID per call instead.
    if thread_id is None:
        thread_id = str(uuid.uuid4())
    # Enhance the question with minimal instructions
    enhanced_question = f"""
    GAIA Question: {task}
    Please solve this multi-step reasoning problem by:
    1. Breaking it down into logical steps
    2. Using specialized agents when needed
    3. Providing the final answer in the exact format requested
    """
    logger.info(
        f"Starting agent run from __main__ for task: '{task}' with thread_id: {thread_id}"
    )
    result = asyncio.run(run_with_streaming(enhanced_question, thread_id))
    logger.info("Agent run finished.")
    logger.info(f"Result: {result}")
    return extract_final_answer(result)


if __name__ == "__main__":
    # Example usage
    task_to_run = "What is the capital of France?"
    thread_id = str(uuid.uuid4())  # Generate a unique thread ID for this run
    final_answer = main(task_to_run, thread_id)
    print(f"Final Answer: {final_answer}")