import logging
from typing import Any, Dict, List, Optional, TypedDict, Union

from langchain_core.messages import BaseMessage

logger = logging.getLogger(__name__)


class JARVISState(TypedDict):
    """
    State dictionary for the JARVIS GAIA Agent, used with LangGraph to manage task processing.

    Attributes:
        task_id: Unique identifier for the GAIA task.
        question: The question text to be answered.
        tools_needed: List of tool names to be used for the task.
        web_results: List of web search results (e.g., from SERPAPI, DuckDuckGo).
        file_results: Parsed content from text, CSV, Excel, or audio files.
        image_results: OCR or description results from image files.
        calculation_results: Results from mathematical calculations.
        document_results: Extracted content from PDF or text documents.
        multi_hop_results: Results from iterative multi-hop searches (supports strings or dicts).
        messages: List of messages for LLM context (e.g., user prompts, system instructions).
        answer: Final answer for the task, formatted for GAIA submission.
        results_table: List of task results for Gradio display (Task ID, Question, Answer).
        status_output: Status message for Gradio output (e.g., submission result).
        error: Optional error message if task processing fails.
        metadata: Optional metadata (e.g., timestamps, tool execution status).
    """

    task_id: str
    question: str
    tools_needed: List[str]
    web_results: List[str]
    file_results: str
    image_results: str
    calculation_results: str
    document_results: str
    multi_hop_results: List[Union[str, Dict[str, Any]]]
    messages: List[BaseMessage]
    answer: str
    results_table: List[Dict[str, str]]
    status_output: str
    error: Optional[str]
    metadata: Optional[Dict[str, Any]]


def _default_state_fields() -> Dict[str, Any]:
    """Return a fresh mapping of every non-required JARVISState field to its default.

    This is the single source of truth for defaults, shared by
    ``validate_state`` and ``reset_state``. A new dict with new list/dict
    values is built on every call, so two states never share a mutable
    container.
    """
    return {
        "tools_needed": ["search_tool"],
        "web_results": [],
        "file_results": "",
        "image_results": "",
        "calculation_results": "",
        "document_results": "",
        "multi_hop_results": [],
        "messages": [],
        "answer": "",
        "results_table": [],
        "status_output": "",
        "error": None,
        "metadata": {},
    }


def _require_field(state: JARVISState, field: str) -> None:
    """Log once and raise ValueError if *field* is missing or falsy in *state*."""
    if not state.get(field):
        message = f"{field} is required"
        logger.error(message)
        raise ValueError(message)


def validate_state(state: JARVISState) -> JARVISState:
    """
    Validate and initialize JARVISState fields in place.

    ``task_id`` and ``question`` must be present and truthy. Every other
    field that is missing, or present with the value ``None``, is set to
    its default. ``error`` defaults to ``None``; ``metadata`` defaults to
    ``{}``.

    Args:
        state: Input state dictionary; may be partial. It is mutated.

    Returns:
        The same dict object, with defaults filled in.

    Raises:
        ValueError: If ``task_id`` or ``question`` is missing or empty.
    """
    _require_field(state, "task_id")
    _require_field(state, "question")

    for key, default in _default_state_fields().items():
        # A missing key and an explicit None are both treated as "unset".
        if state.get(key) is None:
            state[key] = default

    logger.debug("Validated state for task %s", state["task_id"])
    return state


def reset_state(task_id: str, question: str) -> JARVISState:
    """
    Create a fresh JARVISState for a new task.

    Args:
        task_id: Task identifier.
        question: Question text.

    Returns:
        A fully populated JARVISState with every non-required field at its
        default value.

    Raises:
        ValueError: If ``task_id`` or ``question`` is empty (raised by
            ``validate_state``).
    """
    state = JARVISState(
        task_id=task_id,
        question=question,
        **_default_state_fields(),
    )
    return validate_state(state)