Spaces:
Starting
Starting
File size: 3,941 Bytes
751d628 c6951f4 751d628 1bbca12 c6951f4 751d628 c6951f4 751d628 c6951f4 1bbca12 751d628 c6951f4 751d628 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
from typing import TypedDict, List, Dict, Optional, Any, Union
from langchain_core.messages import BaseMessage
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class JARVISState(TypedDict):
    """
    Typed state dictionary for the JARVIS GAIA Agent (used with LangGraph).

    Attributes:
        task_id: Unique identifier of the GAIA task.
        question: Text of the question to answer.
        tools_needed: Names of the tools to use for the task.
        web_results: Web search results (e.g., from SERPAPI, DuckDuckGo).
        file_results: Parsed content of text, CSV, Excel, or audio files.
        image_results: OCR or description output for image files.
        calculation_results: Output of mathematical calculations.
        document_results: Content extracted from PDF or text documents.
        multi_hop_results: Results of iterative multi-hop searches; each
            entry is either a string or a dict.
        messages: Messages forming the LLM context (e.g., user prompts,
            system instructions).
        answer: Final answer, formatted for GAIA submission.
        results_table: Task results for the Gradio display
            (Task ID, Question, Answer).
        status_output: Status message for the Gradio output
            (e.g., submission result).
        error: Error message if task processing failed, otherwise None.
        metadata: Optional extra data (e.g., timestamps, tool execution
            status).
    """

    # --- Task definition ---
    task_id: str
    question: str
    tools_needed: List[str]

    # --- Per-tool outputs ---
    web_results: List[str]
    file_results: str
    image_results: str
    calculation_results: str
    document_results: str
    multi_hop_results: List[Union[str, Dict[str, Any]]]

    # --- LLM context and final answer ---
    messages: List[BaseMessage]
    answer: str

    # --- Gradio display ---
    results_table: List[Dict[str, str]]
    status_output: str

    # --- Diagnostics ---
    error: Optional[str]
    metadata: Optional[Dict[str, Any]]
def validate_state(state: JARVISState) -> JARVISState:
    """
    Validate the required fields of a JARVISState and fill in defaults.

    The state is updated in place and the same object is returned.

    Args:
        state: Input state dictionary. "task_id" and "question" must be
            present and non-empty; every other key that is missing or set
            to None is replaced by its default value.

    Returns:
        The same state object, with all defaulted keys populated.

    Raises:
        ValueError: If "task_id" or "question" is missing or empty.
        Exception: Any other error raised while reading or updating the
            state is logged and re-raised unchanged.
    """
    try:
        for required in ("task_id", "question"):
            if not state.get(required):
                raise ValueError(f"{required} is required")

        # Built fresh on every call, so the list/dict defaults are never
        # shared between two states.
        defaults = {
            "tools_needed": ["search_tool"],
            "web_results": [],
            "file_results": "",
            "image_results": "",
            "calculation_results": "",
            "document_results": "",
            "multi_hop_results": [],
            "messages": [],
            "answer": "",
            "results_table": [],
            "status_output": "",
            "error": None,
            "metadata": {},
        }
        for key, default in defaults.items():
            # A key explicitly set to None is treated the same as a missing key.
            if state.get(key) is None:
                state[key] = default
    except Exception as e:
        # Single logging point: each failure is reported exactly once
        # before being propagated to the caller.
        logger.error("State validation failed: %s", e)
        raise

    logger.debug("Validated state for task %s", state["task_id"])
    return state
def reset_state(task_id: str, question: str) -> JARVISState:
    """
    Build a brand-new JARVISState for a task.

    Every field other than the task id and question starts at its default
    value; the result is passed through validate_state before being returned.

    Args:
        task_id: Task identifier.
        question: Question text.

    Returns:
        Initialized JARVISState, as returned by validate_state.
    """
    fresh: JARVISState = {
        "task_id": task_id,
        "question": question,
        "tools_needed": ["search_tool"],
        "web_results": [],
        "file_results": "",
        "image_results": "",
        "calculation_results": "",
        "document_results": "",
        "multi_hop_results": [],
        "messages": [],
        "answer": "",
        "results_table": [],
        "status_output": "",
        "error": None,
        "metadata": {},
    }
    return validate_state(fresh)