File size: 3,941 Bytes
751d628
c6951f4
751d628
 
 
1bbca12
 
c6951f4
 
 
 
 
 
 
 
751d628
c6951f4
 
751d628
 
c6951f4
 
 
 
 
 
 
1bbca12
 
 
 
 
 
 
 
751d628
c6951f4
 
 
 
 
751d628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from typing import TypedDict, List, Dict, Optional, Any, Union
from langchain_core.messages import BaseMessage
import logging

logger = logging.getLogger(__name__)

class JARVISState(TypedDict):
    """
    State dictionary for the JARVIS GAIA Agent, used with LangGraph to manage task processing.

    Declared as a total TypedDict, so static type checkers treat every key as
    required. At runtime an instance is a plain ``dict``: ``validate_state``
    fills in any optional key that is missing or ``None``, and ``reset_state``
    builds a fully populated instance from just a task id and question.

    Attributes:
        task_id: Unique identifier for the GAIA task (``validate_state`` rejects empty values).
        question: The question text to be answered (``validate_state`` rejects empty values).
        tools_needed: List of tool names to be used for the task
            (defaults to ``["search_tool"]``).
        web_results: List of web search results (e.g., from SERPAPI, DuckDuckGo).
        file_results: Parsed content from text, CSV, Excel, or audio files.
        image_results: OCR or description results from image files.
        calculation_results: Results from mathematical calculations.
        document_results: Extracted content from PDF or text documents.
        multi_hop_results: Results from iterative multi-hop searches (supports strings or dicts).
        messages: List of messages for LLM context (e.g., user prompts, system instructions).
        answer: Final answer for the task, formatted for GAIA submission.
        results_table: List of task results for Gradio display (Task ID, Question, Answer).
        status_output: Status message for Gradio output (e.g., submission result).
        error: Optional error message if task processing fails; ``None`` when no error is set.
        metadata: Optional metadata (e.g., timestamps, tool execution status);
            defaults to an empty dict.

    String-valued result fields default to ``""`` and list-valued fields to ``[]``.
    """
    # Required inputs.
    task_id: str
    question: str
    # Planning and per-tool results.
    tools_needed: List[str]
    web_results: List[str]
    file_results: str
    image_results: str
    calculation_results: str
    document_results: str
    multi_hop_results: List[Union[str, Dict[str, Any]]]
    # LLM context and final outputs.
    messages: List[BaseMessage]
    answer: str
    results_table: List[Dict[str, str]]
    status_output: str
    error: Optional[str]
    metadata: Optional[Dict[str, Any]]

def validate_state(state: JARVISState) -> JARVISState:
    """
    Validate required JARVISState fields and initialize missing optional ones.

    The ``state`` dict is modified in place and also returned. Any optional key
    that is absent or set to ``None`` is replaced with a freshly built default
    (new list/dict objects on every call, so defaults are never shared between
    states). Keys that already hold a non-``None`` value are left untouched.

    Args:
        state: Input state dictionary; must contain a non-empty ``task_id``
            and ``question``.

    Returns:
        The same ``state`` object, with defaults filled in.

    Raises:
        ValueError: If ``task_id`` or ``question`` is missing or empty.
    """
    # Check required fields first (task_id, then question); each failure is
    # logged exactly once before raising.
    for required_key in ("task_id", "question"):
        if not state.get(required_key):
            message = f"{required_key} is required"
            logger.error("State validation failed: %s", message)
            raise ValueError(message)

    # Built inside the function so every call gets brand-new mutable containers.
    defaults: Dict[str, Any] = {
        "tools_needed": ["search_tool"],
        "web_results": [],
        "file_results": "",
        "image_results": "",
        "calculation_results": "",
        "document_results": "",
        "multi_hop_results": [],
        "messages": [],
        "answer": "",
        "results_table": [],
        "status_output": "",
        "error": None,
        "metadata": {},
    }
    for key, default in defaults.items():
        # ``get`` returning None covers both "key absent" and "explicitly None".
        if state.get(key) is None:
            state[key] = default

    logger.debug("Validated state for task %s", state["task_id"])
    return state

def reset_state(task_id: str, question: str) -> JARVISState:
    """
    Build a brand-new, fully populated JARVISState for a task.

    All result fields start out empty (``""``, ``[]`` or ``{}``), ``error`` is
    ``None``, and ``"search_tool"`` is the only preselected tool. The new state
    is passed through ``validate_state``, so an empty ``task_id`` or
    ``question`` is rejected with ``ValueError``.

    Args:
        task_id: Task identifier.
        question: Question text.

    Returns:
        A freshly initialized JARVISState.
    """
    fresh: JARVISState = {
        "task_id": task_id,
        "question": question,
        "tools_needed": ["search_tool"],
        "web_results": [],
        "file_results": "",
        "image_results": "",
        "calculation_results": "",
        "document_results": "",
        "multi_hop_results": [],
        "messages": [],
        "answer": "",
        "results_table": [],
        "status_output": "",
        "error": None,
        "metadata": {},
    }
    return validate_state(fresh)