"""JARVIS GAIA agent application.

Wires together an LLM backend (Together AI, the Hugging Face Inference API or
a local Hugging Face model), a set of tools, a two-node LangGraph workflow
(``parse_question`` -> ``tool_dispatcher``) and a Gradio UI that fetches the
GAIA questions, answers them and submits the answers.
"""

import asyncio
import json
import logging
import os
import re
import ssl
from typing import Any, Dict, List

import aiohttp
import gradio as gr
import nest_asyncio
import pandas as pd
import requests
import together
import torch
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import END, StateGraph
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

from state import JARVISState, reset_state, validate_state
from tools.answer_generator import generate_answer, preprocess_question
from tools.calculator import calculator_tool
from tools.document_retriever import document_retriever_tool
from tools.duckduckgo_search import duckduckgo_search_tool
from tools.file_fetcher import fetch_task_file
from tools.file_parser import file_parser_tool
from tools.guest_info import guest_info_retriever_tool
from tools.hub_stats import hub_stats_tool
from tools.image_parser import image_parser_tool
from tools.search import multi_hop_search_tool, search_tool
from tools.weather_info import weather_info_tool

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Apply nest_asyncio
nest_asyncio.apply()

# Load environment variables
load_dotenv()
SPACE_ID = os.getenv("SPACE_ID", "onisj/jarvis_gaia_agent")
GAIA_API_URL = "https://agents-course-unit4-api-1.hf.space/api"
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
HF_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
OPENWEATHERMAP_API_KEY = os.getenv("OPENWEATHERMAP_API_KEY")

# Verify environment variables
if not SPACE_ID:
    raise ValueError("SPACE_ID not set")
if not HF_API_TOKEN:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN not set")
if not TOGETHER_API_KEY:
    raise ValueError("TOGETHER_API_KEY not set")
if not OPENWEATHERMAP_API_KEY:
    logger.warning("OPENWEATHERMAP_API_KEY not set; weather_info_tool may fail")
logger.info(f"SPACE_ID: {SPACE_ID}")

# Model configuration
TOGETHER_MODELS = [
    "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-70B-free",
]
HF_MODEL = "meta-llama/Llama-3.2-1B-Instruct"

# Column layout shared by every DataFrame handed to the Gradio results table.
RESULT_COLUMNS = ["Task ID", "Question", "Answer"]

# Number of times the LLM is asked for a tool list before falling back.
_MAX_TOOL_SELECTION_ATTEMPTS = 3

# Punctuation stripped from words extracted out of free-text questions.
_TRAILING_PUNCTUATION = "?.!,;:\"'"

# System prompt used for LLM-based tool selection.
TOOL_SELECTION_SYSTEM_PROMPT = """Select tools from: ['search_tool', 'multi_hop_search_tool', 'file_parser_tool', 'image_parser_tool', 'calculator_tool', 'document_retriever_tool', 'duckduckgo_search_tool', 'weather_info_tool', 'hub_stats_tool', 'guest_info_retriever_tool'].
Return a JSON list of all relevant tools, e.g., ["search_tool", "duckduckgo_search_tool"].
Rules:
- Include "search_tool" for web-based questions unless purely computational or file-based.
- Include "multi_hop_search_tool" for questions with >20 words or requiring multiple steps.
- Include "file_parser_tool" for 'data', 'table', 'excel', 'csv', 'txt', 'mp3', or file extensions.
- Include "image_parser_tool" for 'image', 'video', 'picture', or 'painting'.
- Include "calculator_tool" for 'calculate', 'math', 'sum', 'average', 'total', or numerical operations.
- Include "document_retriever_tool" for 'document', 'pdf', 'report', or 'paper'.
- Include "duckduckgo_search_tool" for 'search', 'wikipedia', 'online', or general knowledge.
- Include "weather_info_tool" for 'weather', 'temperature', or 'forecast'.
- Include "hub_stats_tool" for 'model', 'huggingface', or 'dataset'.
- Include "guest_info_retriever_tool" for 'guest', 'name', 'relation', or 'person'.
- Select multiple tools if the question spans multiple domains (e.g., web and file).
- Output ONLY valid JSON."""

# Keyword -> tool rules for the non-LLM fallback, in the order tools are appended.
_KEYWORD_TOOL_RULES = (
    (("excel", "csv", "mp3", "data", "table", "xlsx"), "file_parser_tool"),
    (("image", "video", "picture", "painting"), "image_parser_tool"),
    (("calculate", "math", "sum", "average", "total"), "calculator_tool"),
    (("document", "pdf", "report", "paper"), "document_retriever_tool"),
    (("search", "wikipedia", "online"), "duckduckgo_search_tool"),
    (("weather", "temperature", "forecast"), "weather_info_tool"),
    (("model", "huggingface", "dataset"), "hub_stats_tool"),
    (("guest", "name", "relation", "person"), "guest_info_retriever_tool"),
)


# Initialize LLM clients
def initialize_llm():
    """Return the first LLM backend that can be initialized.

    Backends are tried in this order: each Together AI model in
    ``TOGETHER_MODELS`` (verified with a small test completion), the Hugging
    Face Inference API client, then a locally loaded Hugging Face model.

    Returns:
        tuple: ``(client, llm_type, model_name)`` where ``llm_type`` is one of
        ``"together"``, ``"hf_api"`` or ``"hf_local"``. For ``"hf_local"`` the
        client is a ``(model, tokenizer)`` pair.

    Raises:
        RuntimeError: If none of the backends could be initialized.
    """
    for model in TOGETHER_MODELS:
        try:
            together.api_key = TOGETHER_API_KEY
            client = together.Together()
            # Smoke-test the model; only success/failure matters here.
            client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": "Test"}],
                max_tokens=10
            )
            logger.info(f"Initialized Together AI model: {model}")
            return client, "together", model
        except Exception as e:
            logger.warning(f"Failed to initialize Together AI model {model}: {e}")

    try:
        client = InferenceClient(
            model=HF_MODEL,
            token=HF_API_TOKEN,
            timeout=30
        )
        logger.info(f"Initialized Hugging Face Inference API model: {HF_MODEL}")
        return client, "hf_api", HF_MODEL
    except Exception as e:
        logger.warning(f"Failed to initialize HF Inference API: {e}")

    try:
        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, token=HF_API_TOKEN)
        model = AutoModelForCausalLM.from_pretrained(HF_MODEL, token=HF_API_TOKEN, device_map="auto")
        logger.info(f"Initialized local Hugging Face model: {HF_MODEL}")
        return (model, tokenizer), "hf_local", HF_MODEL
    except Exception as e:
        logger.error(f"Failed to initialize local HF model: {e}")

    raise RuntimeError("No LLM could be initialized")


llm_client, llm_type, llm_model = initialize_llm()

# Initialize embedder
_embedder = None


def get_embedder():
    """Return the module-wide SentenceTransformer, creating it on first use.

    Returns:
        SentenceTransformer: The cached ``all-MiniLM-L6-v2`` embedder, placed
        on CUDA when ``torch.cuda.is_available()`` and on CPU otherwise.

    Raises:
        RuntimeError: If the embedder cannot be constructed.
    """
    global _embedder
    if _embedder is None:
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            _embedder = SentenceTransformer(
                "all-MiniLM-L6-v2",
                device=device,
                cache_folder="./cache"
            )
            logger.info(f"SentenceTransformer initialized on {device.upper()}")
        except Exception as e:
            logger.error(f"Failed to initialize SentenceTransformer: {e}")
            raise RuntimeError(f"Embedder initialization failed: {e}") from e
    return _embedder


try:
    embedder = get_embedder()
except Exception as e:
    # Best effort: the app keeps running and passes embedder=None to tools.
    logger.error(f"Failed to initialize embedder: {e}")
    embedder = None

# Log device
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")


# HTTP session with SSL handling
async def create_http_session(verify_ssl: bool = False):
    """Create an aiohttp client session with a 30 second total timeout.

    Args:
        verify_ssl: When False (the default, matching the historical
            behaviour) hostname checking and certificate verification are
            turned off. Pass True to use the default verifying SSL context.

    Returns:
        aiohttp.ClientSession: A new session; the caller is responsible for
        closing it (e.g. via ``async with``).
    """
    ssl_context = ssl.create_default_context()
    if not verify_ssl:
        # SECURITY NOTE(review): this disables TLS certificate validation and
        # exposes requests to man-in-the-middle attacks. Kept as the default
        # only for backward compatibility; prefer verify_ssl=True.
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
    return aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(ssl=ssl_context),
        timeout=aiohttp.ClientTimeout(total=30)
    )


# Tool registration
tools = {
    "search_tool": search_tool,
    "multi_hop_search_tool": multi_hop_search_tool,
    "file_parser_tool": file_parser_tool,
    "image_parser_tool": image_parser_tool,
    "calculator_tool": calculator_tool,
    "document_retriever_tool": document_retriever_tool,
    "duckduckgo_search_tool": duckduckgo_search_tool,
    "weather_info_tool": weather_info_tool,
    "hub_stats_tool": hub_stats_tool,
    "guest_info_retriever_tool": guest_info_retriever_tool,
}


def _query_llm_for_tools(formatted_messages: List[Dict[str, str]]) -> str:
    """Send the tool-selection chat to the active LLM backend and return its text."""
    if llm_type == "hf_local":
        model, tokenizer = llm_client
        inputs = tokenizer.apply_chat_template(
            formatted_messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)
        outputs = model.generate(inputs, max_new_tokens=100, temperature=0.5)
        # generate() returns prompt + completion; decode only the new tokens,
        # otherwise the echoed prompt makes the output unparseable as JSON.
        new_tokens = outputs[0][inputs.shape[-1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    if llm_type == "together":
        response = llm_client.chat.completions.create(
            model=llm_model,
            messages=formatted_messages,
            max_tokens=100,
            temperature=0.5
        )
        return response.choices[0].message.content.strip()
    # hf_api
    response = llm_client.chat.completions.create(
        messages=formatted_messages,
        max_tokens=100,
        temperature=0.5
    )
    return response.choices[0].message.content.strip()


def _select_tools_with_llm(task_id: str, question: str) -> List[str]:
    """Ask the LLM for a JSON tool list, retrying; return ["search_tool"] on failure."""
    prompt = ChatPromptTemplate.from_messages([
        SystemMessage(content=TOOL_SELECTION_SYSTEM_PROMPT),
        HumanMessage(content=f"Query: {question}")
    ])
    messages = prompt.format_messages()
    # Loop-invariant: the same chat is sent on every attempt.
    formatted_messages = [
        {"role": "system" if isinstance(m, SystemMessage) else "user", "content": m.content}
        for m in messages
    ]

    for attempt in range(_MAX_TOOL_SELECTION_ATTEMPTS):
        try:
            response = _query_llm_for_tools(formatted_messages)
            logger.info(f"Task {task_id} LLM tool selection response: {response}")
            try:
                candidate = json.loads(response)
            except json.JSONDecodeError as e:
                logger.warning(f"Task {task_id}: Invalid JSON (attempt {attempt + 1}): {e}")
                continue
            # Accept only a non-empty list of registered tool names; an empty
            # list would silently run no tool at all.
            if (
                isinstance(candidate, list)
                and candidate
                and all(isinstance(t, str) and t in tools for t in candidate)
            ):
                return candidate
            raise ValueError("Invalid tool list format")
        except Exception as e:
            logger.warning(f"Task {task_id} Tool selection failed (attempt {attempt + 1}): {e}")

    return ["search_tool"]  # Fallback after retries


def _select_tools_by_keyword(question: str) -> List[str]:
    """Return the extra tools suggested by keywords found in the question."""
    question_lower = question.lower()
    selected = [
        tool_name
        for keywords, tool_name in _KEYWORD_TOOL_RULES
        if any(kw in question_lower for kw in keywords)
    ]
    if len(question.split()) > 20 or "multiple" in question_lower:
        selected.append("multi_hop_search_tool")
    return selected


# Parse question to select tools
async def parse_question(state: JARVISState) -> JARVISState:
    """
    Parse the question to select appropriate tools using LLM with retries,
    preprocess the question, and integrate file-based tools.

    Any file returned by ``fetch_task_file`` is written to
    ``temp/<task_id>.<ext>`` and recorded in ``state["metadata"]``; the tool
    matching its extension is added to the selection.

    Args:
        state (JARVISState): The input state containing task_id, question.

    Returns:
        JARVISState: Updated state with selected tools_needed and metadata.
        On failure ``state["error"]`` is set and tools_needed falls back to
        ``["search_tool"]``.
    """
    state = validate_state(state)
    task_id = state["task_id"]
    question = state["question"]
    logger.info(f"Task {task_id} Parsing question: {question}")
    try:
        # Preprocess question
        processed_question = await preprocess_question(question)
        if processed_question != question:
            logger.info(f"Task {task_id} Preprocessed question: {processed_question}")
            state["question"] = processed_question
            question = processed_question

        # Default to search_tool
        tools_needed = ["search_tool"]

        # LLM-based tool selection
        if llm_client:
            tools_needed = _select_tools_with_llm(task_id, question)

        # Fallback to keyword-based selection if the LLM produced nothing
        # beyond the default. (The previous extra guard excluded the very
        # keywords checked here, which made most branches unreachable.)
        if tools_needed == ["search_tool"]:
            tools_needed.extend(_select_tools_by_keyword(question))

        # Integrate file-based tools
        file_results = await fetch_task_file(task_id, question)
        for ext, content in file_results.items():
            if content:
                os.makedirs("temp", exist_ok=True)
                file_path = f"temp/{task_id}.{ext}"
                with open(file_path, "wb") as f:
                    f.write(content)
                state["metadata"] = state.get("metadata", {}) | {"file_ext": ext, "file_path": file_path}
                if ext in ["txt", "csv", "xlsx", "mp3"] and "file_parser_tool" not in tools_needed:
                    tools_needed.append("file_parser_tool")
                elif ext in ["jpg", "png"] and "image_parser_tool" not in tools_needed:
                    tools_needed.append("image_parser_tool")
                elif ext == "pdf" and "document_retriever_tool" not in tools_needed:
                    tools_needed.append("document_retriever_tool")

        # Remove duplicates while keeping a deterministic, first-seen order.
        state["tools_needed"] = list(dict.fromkeys(tools_needed))
        logger.info(f"Task {task_id} Selected tools: {state['tools_needed']}")
        return state
    except Exception as e:
        logger.error(f"Task {task_id} Tool selection failed: {e}")
        state["error"] = f"Parse question failed: {str(e)}"
        state["tools_needed"] = ["search_tool"]
        return state


def _extract_location(question: str) -> str:
    """Return the last word of a weather question (punctuation stripped) or "Unknown"."""
    if "weather" not in question.lower():
        return "Unknown"
    words = question.split()
    location = words[-1].strip(_TRAILING_PUNCTUATION) if words else ""
    return location or "Unknown"


def _extract_author(question: str) -> str:
    """Return the word following a standalone "by" in the question, or "Unknown"."""
    match = re.search(r"\bby\s+(\S+)", question, flags=re.IGNORECASE)
    if not match:
        return "Unknown"
    return match.group(1).rstrip(_TRAILING_PUNCTUATION) or "Unknown"


# Tool dispatcher
async def tool_dispatcher(state: JARVISState) -> JARVISState:
    """Run every tool in ``state["tools_needed"]`` and generate the answer.

    Each tool runs in its own try/except: a failure is logged and recorded in
    ``state["metadata"]`` as ``<tool_name>_error`` without stopping the other
    tools. Afterwards the collected results are passed to ``generate_answer``
    and the answer is stored in ``state["answer"]``.

    Args:
        state (JARVISState): State produced by ``parse_question``.

    Returns:
        JARVISState: The updated state. If dispatching as a whole fails,
        ``state["error"]`` is set instead of an answer.
    """
    state = validate_state(state)
    try:
        task_id = state["task_id"]
        question = state["question"]
        tools_needed = state["tools_needed"]

        for tool_name in tools_needed:
            try:
                metadata = state.get("metadata") or {}
                if tool_name == "search_tool":
                    result = await tools["search_tool"].ainvoke({"query": question})
                    state["web_results"].extend(
                        [str(r) for r in result] if result else ["No results from search_tool"]
                    )
                elif tool_name == "multi_hop_search_tool":
                    result = await tools["multi_hop_search_tool"].ainvoke({
                        "query": question,
                        "steps": 3,
                        "llm_client": llm_client,
                        "llm_type": llm_type,
                        "llm_model": llm_model
                    })
                    state["multi_hop_results"].extend(
                        [r["content"] if isinstance(r, dict) else str(r) for r in result]
                        if result else ["No results from multi_hop_search_tool"]
                    )
                elif tool_name == "file_parser_tool":
                    file_path = metadata.get("file_path")
                    file_ext = metadata.get("file_ext")
                    if file_path and os.path.exists(file_path) and file_ext:
                        result = await tools["file_parser_tool"].ainvoke({
                            "task_id": task_id,
                            "file_type": file_ext,
                            "file_path": file_path,
                            "query": question
                        })
                        state["file_results"] = str(result) if result else "No file results"
                    else:
                        state["file_results"] = "No file available"
                elif tool_name == "image_parser_tool":
                    file_path = metadata.get("file_path")
                    if file_path and os.path.exists(file_path) and file_path.split('.')[-1] in ["jpg", "png"]:
                        result = await tools["image_parser_tool"].ainvoke({"task_id": task_id, "file_path": file_path})
                        state["image_results"] = str(result) if result else "No image results"
                    else:
                        state["image_results"] = "No image available"
                elif tool_name == "calculator_tool":
                    result = await tools["calculator_tool"].ainvoke({"expression": question})
                    state["calculation_results"] = str(result) if result else "No calculation results"
                elif tool_name == "document_retriever_tool":
                    file_path = metadata.get("file_path")
                    if file_path and os.path.exists(file_path) and file_path.split('.')[-1] == "pdf":
                        result = await tools["document_retriever_tool"].ainvoke({
                            "task_id": task_id,
                            "query": question,
                            "file_path": file_path
                        })
                        state["document_results"] = str(result) if result else "No document results"
                    else:
                        state["document_results"] = "No document available"
                elif tool_name == "duckduckgo_search_tool":
                    result = await tools["duckduckgo_search_tool"].ainvoke({
                        "query": question,
                        "original_query": question,
                        "embedder": embedder
                    })
                    if isinstance(result, list):
                        state["web_results"].extend(result)
                    elif result:
                        state["web_results"].extend([str(result)])
                    else:
                        state["web_results"].extend(["No results from duckduckgo_search_tool"])
                elif tool_name == "weather_info_tool":
                    location = _extract_location(question)
                    result = await tools["weather_info_tool"].ainvoke({"location": location})
                    state["web_results"].append(str(result) if result else "No weather results")
                elif tool_name == "hub_stats_tool":
                    author = _extract_author(question)
                    result = await tools["hub_stats_tool"].ainvoke({"author": author})
                    state["web_results"].append(str(result) if result else "No hub stats results")
                elif tool_name == "guest_info_retriever_tool":
                    result = await tools["guest_info_retriever_tool"].ainvoke({"query": question})
                    state["web_results"].append(str(result) if result else "No guest info results")

                state["metadata"] = state.get("metadata", {}) | {f"{tool_name}_executed": True}
                logger.info(f"Task {task_id}: Executed {tool_name}")
            except Exception as e:
                # Best effort per tool: record the failure and keep going.
                logger.warning(f"Tool {tool_name} failed for task {task_id}: {e}")
                state["metadata"] = state.get("metadata", {}) | {f"{tool_name}_error": str(e)}

        # Ensure results are populated
        state["web_results"] = state.get("web_results", ["No web results found"])
        state["file_results"] = state.get("file_results", "No file results found")
        state["image_results"] = state.get("image_results", "No image results found")
        state["document_results"] = state.get("document_results", "No document results found")
        state["calculation_results"] = state.get("calculation_results", "No calculation results found")

        state["answer"] = await generate_answer(
            task_id=task_id,
            question=question,
            search_results=state.get("web_results", []) + [
                r["content"] if isinstance(r, dict) else str(r)
                for r in state.get("multi_hop_results", [])
            ],
            file_results=state.get("file_results", "")
            + state.get("document_results", "")
            + state.get("image_results", "")
            + state.get("calculation_results", ""),
            llm_client=llm_client
        )
        logger.info(f"Task {task_id}: Generated answer: {state['answer']}")
        return state
    except Exception as e:
        logger.error(f"Tool dispatch failed: {e}")
        state["error"] = f"Tool dispatch failed: {e}"
        return state


# Define StateGraph
workflow = StateGraph(JARVISState)
workflow.add_node("parse_question", parse_question)
workflow.add_node("tool_dispatcher", tool_dispatcher)
workflow.set_entry_point("parse_question")
workflow.add_edge("parse_question", "tool_dispatcher")
workflow.add_edge("tool_dispatcher", END)
graph = workflow.compile()


# Agent class
class JARVISAgent:
    """Runs GAIA questions through the compiled graph and tracks the results.

    ``self.state`` holds the UI-facing bookkeeping: ``results_table`` (one row
    per processed question), ``status_output``, ``metadata`` and ``error``.
    """

    def __init__(self):
        self.state = reset_state(task_id="init", question="Agent initialized")
        self.state["results_table"] = []  # Initialize as empty list
        logger.info("JARVISAgent initialized.")

    def _results_frame(self) -> pd.DataFrame:
        """Build the results DataFrame with a consistent column layout."""
        return pd.DataFrame(self.state["results_table"] or [], columns=RESULT_COLUMNS)

    async def process_question(self, task_id: str, question: str) -> str:
        """Answer a single question and record it in the results table.

        Args:
            task_id: Identifier of the GAIA task.
            question: The question text.

        Returns:
            str: The generated answer ("Unknown" when the graph result has no
            answer), or ``"Error: <message>"`` if processing raised.
        """
        state = reset_state(task_id=task_id, question=question)
        try:
            result = await graph.ainvoke(state)
            answer = result.get("answer", "Unknown")
            logger.info(f"Task {task_id} Final answer: {answer}")
            self.state["results_table"].append({"Task ID": task_id, "Question": question, "Answer": answer})
            self.state["metadata"] = {"last_task_id": task_id, "answer": answer}
            return answer
        except Exception as e:
            logger.error(f"Error processing task {task_id}: {e}")
            self.state["results_table"].append({"Task ID": task_id, "Question": question, "Answer": f"Error: {e}"})
            self.state["error"] = f"Task {task_id} failed: {str(e)}"
            return f"Error: {str(e)}"
        finally:
            # Remove any temp file written for this task by parse_question.
            for ext in ["txt", "csv", "xlsx", "mp3", "jpg", "png", "pdf"]:
                file_path = f"temp/{task_id}.{ext}"
                if os.path.exists(file_path):
                    try:
                        os.remove(file_path)
                        logger.info(f"Removed temp file: {file_path}")
                    except Exception as e:
                        logger.error(f"Error removing file {file_path}: {e}")

    async def process_all_questions(self, profile: gr.OAuthProfile | None):
        """Fetch all GAIA questions, answer them and submit the answers.

        Args:
            profile: The logged-in Hugging Face profile, or None when the
                user is not logged in.

        Returns:
            tuple: ``(results DataFrame, status message)`` for the Gradio
            outputs.
        """
        if not profile:
            logger.error("User not logged in.")
            self.state["status_output"] = "Please Login to Hugging Face."
            return self._results_frame(), self.state["status_output"]

        username = profile.username
        logger.info(f"User logged in: {username}")

        # Start every run with an empty table so repeated clicks do not
        # accumulate duplicate rows from earlier runs.
        self.state["results_table"] = []

        questions_url = f"{GAIA_API_URL}/questions"
        submit_url = f"{GAIA_API_URL}/submit"
        agent_code = f"https://huggingface.co/spaces/{SPACE_ID}/tree/main"

        try:
            async with await create_http_session() as session:
                async with session.get(questions_url) as response:
                    response.raise_for_status()
                    questions = await response.json()
                    logger.info(f"Fetched {len(questions)} questions.")
        except Exception as e:
            logger.error(f"Error fetching questions: {e}")
            self.state["status_output"] = f"Error fetching questions: {e}"
            self.state["error"] = f"Fetch questions failed: {str(e)}"
            return self._results_frame(), self.state["status_output"]

        answers_payload = []
        for item in questions:
            task_id = item.get("task_id")
            question = item.get("question")
            if not task_id or not question:
                logger.warning(f"Skipping invalid item: {item}")
                continue
            answer = await self.process_question(task_id, question)
            answers_payload.append({"task_id": task_id, "submitted_answer": answer})

        if not answers_payload:
            logger.error("No answers generated.")
            self.state["status_output"] = "No answers to submit."
            self.state["error"] = "No answers generated"
            return self._results_frame(), self.state["status_output"]

        submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
        try:
            async with await create_http_session() as session:
                async with session.post(submit_url, json=submission_data) as response:
                    response.raise_for_status()
                    result_data = await response.json()
                    self.state["status_output"] = (
                        f"Submission Successful!\n"
                        f"User: {result_data.get('username')}\n"
                        f"Overall Score: {result_data.get('score', 'N/A')}% "
                        f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
                        f"Message: {result_data.get('message', 'No message received.')}"
                    )
                    self.state["metadata"] = self.state.get("metadata", {}) | {
                        "submission_score": result_data.get('score', 'N/A')
                    }
        except Exception as e:
            logger.error(f"Submission failed: {e}")
            self.state["status_output"] = f"Submission Failed: {e}"
            self.state["error"] = f"Submission failed: {str(e)}"

        return self._results_frame(), self.state["status_output"]


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# JARVIS GAIA Agent")
    gr.Markdown(
        """
        **Instructions:**
        1. Log in to Hugging Face using the button below.
        2. Click 'Run Evaluation & Submit All Answers' to process GAIA questions and submit.
        ---
        **Disclaimers:**
        Uses Hugging Face Inference, Together AI, SERPAPI, and OpenWeatherMap for GAIA benchmark.
        """
    )
    with gr.Row():
        gr.LoginButton(value="Login to Hugging Face")
    run_button = gr.Button("Run Evaluation & Submit All Answers")
    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
    results_table = gr.DataFrame(label="Questions and Answers", wrap=True, headers=["Task ID", "Question", "Answer"])

    agent = JARVISAgent()

    run_button.click(
        fn=agent.process_all_questions,
        outputs=[results_table, status_output]
    )

if __name__ == "__main__":
    logger.info("\n" + "-"*30 + " App Starting " + "-"*30)
    logger.info(f"SPACE_ID: {SPACE_ID}")
    logger.info("Launching Gradio Interface...")
    demo.launch(debug=True, share=False)