feat(tools): add more tools to extend the functionality of jarvis
Files changed:
- .gitignore +1 -0
- README.md +4 -0
- app.py +326 -344
- project_structure.txt +22 -0
- requirements.txt +5 -2
- result.txt +0 -0
- retriever.py +165 -15
- state.py +83 -6
- test.py +231 -8
- tools/__init__.py +3 -1
- tools/answer_generator.py +129 -0
- tools/calculator.py +28 -8
- tools/document_retriever.py +39 -22
- tools/duckduckgo_search.py +95 -2
- tools/file_fetcher.py +42 -0
- tools/file_parser.py +93 -17
- tools/guest_info.py +40 -13
- tools/hub_stats.py +43 -6
- tools/image_parser.py +34 -16
- tools/search.py +82 -85
- tools/weather_info.py +33 -6
.gitignore
CHANGED
@@ -41,6 +41,7 @@ coverage.xml
 *.py,cover
 .tox/
 .pytest_cache/
+cache/
 
 # Logs and temporary files
 *.log
README.md
CHANGED
@@ -74,6 +74,10 @@ jarvis_gaia_agent/
 - `SERPAPI_API_KEY`: SERPAPI key for web searches.
 - `OPENWEATHERMAP_API_KEY`: OpenWeatherMap key for weather queries.
 - `SPACE_ID`: `onisj/jarvis_gaia_agent`.
+- Install dependencies:
+  ```bash
+  pip install -r requirements.txt
+  ```
 
 ## Setup and Local Testing
 
app.py
CHANGED
@@ -3,6 +3,7 @@ import json
 import logging
 import asyncio
 import aiohttp
+import ssl
 import nest_asyncio
 import requests
 import pandas as pd
@@ -10,18 +11,25 @@ from typing import Dict, Any, List
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.messages import SystemMessage, HumanMessage
 from langgraph.graph import StateGraph, END
+import torch
 from sentence_transformers import SentenceTransformer
 import gradio as gr
 from dotenv import load_dotenv
 from huggingface_hub import InferenceClient
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import together
-from state import JARVISState
-from tools import … (old multi-line import truncated in the rendered diff)
+from state import JARVISState, validate_state, reset_state
+from tools.answer_generator import generate_answer, preprocess_question
+from tools.file_fetcher import fetch_task_file
+from tools.search import search_tool, multi_hop_search_tool
+from tools.file_parser import file_parser_tool
+from tools.image_parser import image_parser_tool
+from tools.calculator import calculator_tool
+from tools.document_retriever import document_retriever_tool
+from tools.duckduckgo_search import duckduckgo_search_tool
+from tools.weather_info import weather_info_tool
+from tools.hub_stats import hub_stats_tool
+from tools.guest_info import guest_info_retriever_tool
 
 # Setup logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -33,10 +41,10 @@ nest_asyncio.apply()
 # Load environment variables
 load_dotenv()
 SPACE_ID = os.getenv("SPACE_ID", "onisj/jarvis_gaia_agent")
-GAIA_API_URL = "https://agents-course-unit4-… (old value truncated in the rendered diff)
-GAIA_FILE_URL = f"{GAIA_API_URL}/files/"
+GAIA_API_URL = "https://agents-course-unit4-api-1.hf.space/api"
 TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
 HF_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
+OPENWEATHERMAP_API_KEY = os.getenv("OPENWEATHERMAP_API_KEY")
 
 # Verify environment variables
 if not SPACE_ID:
@@ -45,6 +53,8 @@ if not HF_API_TOKEN:
     raise ValueError("HUGGINGFACEHUB_API_TOKEN not set")
 if not TOGETHER_API_KEY:
     raise ValueError("TOGETHER_API_KEY not set")
+if not OPENWEATHERMAP_API_KEY:
+    logger.warning("OPENWEATHERMAP_API_KEY not set; weather_info_tool may fail")
 logger.info(f"SPACE_ID: {SPACE_ID}")
 
 # Model configuration
@@ -56,23 +66,20 @@ HF_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
 
 # Initialize LLM clients
 def initialize_llm():
-    # Try Together AI models
     for model in TOGETHER_MODELS:
         try:
             together.api_key = TOGETHER_API_KEY
             client = together.Together()
-            # Test the model
             response = client.chat.completions.create(
                 model=model,
                 messages=[{"role": "user", "content": "Test"}],
                 max_tokens=10
             )
             logger.info(f"Initialized Together AI model: {model}")
-            return client, "together"
+            return client, "together", model
         except Exception as e:
             logger.warning(f"Failed to initialize Together AI model {model}: {e}")
 
-    # Fallback to Hugging Face Inference API
     try:
         client = InferenceClient(
             model=HF_MODEL,
@@ -80,381 +87,355 @@ def initialize_llm():
             timeout=30
         )
         logger.info(f"Initialized Hugging Face Inference API model: {HF_MODEL}")
-        return client, "hf_api"
+        return client, "hf_api", HF_MODEL
     except Exception as e:
         logger.warning(f"Failed to initialize HF Inference API: {e}")
 
-    # Fallback to local Hugging Face model
     try:
         tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, token=HF_API_TOKEN)
         model = AutoModelForCausalLM.from_pretrained(HF_MODEL, token=HF_API_TOKEN, device_map="auto")
         logger.info(f"Initialized local Hugging Face model: {HF_MODEL}")
-        return (model, tokenizer), "hf_local"
+        return (model, tokenizer), "hf_local", HF_MODEL
     except Exception as e:
         logger.error(f"Failed to initialize local HF model: {e}")
         raise Exception("No LLM could be initialized")
 
-llm_client, llm_type = initialize_llm()
+llm_client, llm_type, llm_model = initialize_llm()
 
 # Initialize embedder
+_embedder = None
+
+def get_embedder():
+    global _embedder
+    if _embedder is None:
+        try:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            _embedder = SentenceTransformer(
+                "all-MiniLM-L6-v2",
+                device=device,
+                cache_folder="./cache"
+            )
+            logger.info(f"SentenceTransformer initialized on {device.upper()}")
+        except Exception as e:
+            logger.error(f"Failed to initialize SentenceTransformer: {e}")
+            raise RuntimeError(f"Embedder initialization failed: {e}")
    return _embedder
+
 try:
-    embedder = … (old initializer truncated in the rendered diff)
-    logger.info("Sentence transformer initialized")
+    embedder = get_embedder()
 except Exception as e:
     logger.error(f"Failed to initialize embedder: {e}")
     embedder = None
 
-# … (a deleted comment block, roughly 20 lines, truncated in the rendered diff)
+# Log device
+device = "cuda" if torch.cuda.is_available() else "cpu"
+logger.info(f"Using device: {device}")
+
+# HTTP session with SSL handling
+async def create_http_session():
+    ssl_context = ssl.create_default_context()
+    ssl_context.check_hostname = False
+    ssl_context.verify_mode = ssl.CERT_NONE
+    return aiohttp.ClientSession(
+        connector=aiohttp.TCPConnector(ssl=ssl_context),
+        timeout=aiohttp.ClientTimeout(total=30)
+    )
+
+# Tool registration
+tools = {
+    "search_tool": search_tool,
+    "multi_hop_search_tool": multi_hop_search_tool,
+    "file_parser_tool": file_parser_tool,
+    "image_parser_tool": image_parser_tool,
+    "calculator_tool": calculator_tool,
+    "document_retriever_tool": document_retriever_tool,
+    "duckduckgo_search_tool": duckduckgo_search_tool,
+    "weather_info_tool": weather_info_tool,
+    "hub_stats_tool": hub_stats_tool,
+    "guest_info_retriever_tool": guest_info_retriever_tool,
+}
 
 # Parse question to select tools
 async def parse_question(state: JARVISState) -> JARVISState:
+    """
+    Parse the question to select appropriate tools using LLM with retries, preprocess the question, and integrate file-based tools.
+
+    Args:
+        state (JARVISState): The input state containing task_id, question.
+
+    Returns:
+        JARVISState: Updated state with selected tools_needed and metadata.
+    """
+    state = validate_state(state)
+    task_id = state["task_id"]
+    question = state["question"]
+
+    logger.info(f"Task {task_id} Parsing question: {question}")
     try:
-        … (two old setup lines truncated in the rendered diff)
+        # Preprocess question
+        processed_question = await preprocess_question(question)
+        if processed_question != question:
+            logger.info(f"Task {task_id} Preprocessed question: {processed_question}")
+            state["question"] = processed_question
+            question = processed_question
+
+        # Default to search_tool
         tools_needed = ["search_tool"]
+
+        # LLM-based tool selection
        if llm_client:
             prompt = ChatPromptTemplate.from_messages([
                 SystemMessage(content="""Select tools from: ['search_tool', 'multi_hop_search_tool', 'file_parser_tool', 'image_parser_tool', 'calculator_tool', 'document_retriever_tool', 'duckduckgo_search_tool', 'weather_info_tool', 'hub_stats_tool', 'guest_info_retriever_tool'].
+
+Return a JSON list of all relevant tools, e.g., ["search_tool", "duckduckgo_search_tool"].
+
 Rules:
-… (the old rule lines are truncated in the rendered diff)
+- Include "search_tool" for web-based questions unless purely computational or file-based.
+- Include "multi_hop_search_tool" for questions with >20 words or requiring multiple steps.
+- Include "file_parser_tool" for 'data', 'table', 'excel', 'csv', 'txt', 'mp3', or file extensions.
+- Include "image_parser_tool" for 'image', 'video', 'picture', or 'painting'.
+- Include "calculator_tool" for 'calculate', 'math', 'sum', 'average', 'total', or numerical operations.
+- Include "document_retriever_tool" for 'document', 'pdf', 'report', or 'paper'.
+- Include "duckduckgo_search_tool" for 'search', 'wikipedia', 'online', or general knowledge.
+- Include "weather_info_tool" for 'weather', 'temperature', or 'forecast'.
+- Include "hub_stats_tool" for 'model', 'huggingface', or 'dataset'.
+- Include "guest_info_retriever_tool" for 'guest', 'name', 'relation', or 'person'.
+- Select multiple tools if the question spans multiple domains (e.g., web and file).
 - Output ONLY valid JSON."""),
                 HumanMessage(content=f"Query: {question}")
             ])
-            … (the old inline LLM call and keyword checks, roughly 60 lines, are truncated in the rendered diff)
-            if any(… (truncated)
-            if … (truncated)
+            messages = prompt.format_messages()
+
+            for attempt in range(3):  # Retry up to 3 times
+                try:
+                    formatted_messages = [
+                        {"role": "system" if isinstance(m, SystemMessage) else "user", "content": m.content}
+                        for m in messages
+                    ]
+                    if llm_type == "hf_local":
+                        model, tokenizer = llm_client
+                        inputs = tokenizer.apply_chat_template(
+                            formatted_messages,
+                            return_tensors="pt"
+                        ).to(model.device)
+                        outputs = model.generate(inputs, max_new_tokens=100, temperature=0.5)
+                        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+                    elif llm_type == "together":
+                        response = llm_client.chat.completions.create(
+                            model=llm_model,
+                            messages=formatted_messages,
+                            max_tokens=100,
+                            temperature=0.5
+                        )
+                        response = response.choices[0].message.content.strip()
+                    else:  # hf_api
+                        response = llm_client.chat.completions.create(
+                            messages=formatted_messages,
+                            max_tokens=100,
+                            temperature=0.5
+                        )
+                        response = response.choices[0].message.content.strip()
+
+                    logger.info(f"Task {task_id} LLM tool selection response: {response}")
+                    try:
+                        tools_needed = json.loads(response)
+                        if isinstance(tools_needed, list) and all(isinstance(t, str) and t in tools for t in tools_needed):
+                            break  # Valid response, exit retry loop
+                        else:
+                            raise ValueError("Invalid tool list format")
+                    except json.JSONDecodeError as e:
+                        logger.warning(f"Task {task_id}: Invalid JSON (attempt {attempt + 1}): {e}")
+                        if attempt == 2:
+                            tools_needed = ["search_tool"]  # Fallback after retries
+                except Exception as e:
+                    logger.warning(f"Task {task_id} Tool selection failed (attempt {attempt + 1}): {e}")
+                    if attempt == 2:
+                        tools_needed = ["search_tool"]  # Fallback after retries
+
+        # Fallback to keyword-based selection if LLM fails
+        if tools_needed == ["search_tool"] and not any(kw in question.lower() for kw in ["calculate", "math", "image", "document", "file", "weather", "guest", "model"]):
+            question_lower = question.lower()
+            if any(kw in question_lower for kw in ["excel", "csv", "mp3", "data", "table", "xlsx"]):
+                tools_needed.append("file_parser_tool")
+            if any(kw in question_lower for kw in ["image", "video", "picture", "painting"]):
+                tools_needed.append("image_parser_tool")
+            if any(kw in question_lower for kw in ["calculate", "math", "sum", "average", "total"]):
+                tools_needed.append("calculator_tool")
+            if any(kw in question_lower for kw in ["document", "pdf", "report", "paper"]):
+                tools_needed.append("document_retriever_tool")
+            if any(kw in question_lower for kw in ["search", "wikipedia", "online"]):
+                tools_needed.append("duckduckgo_search_tool")
+            if any(kw in question_lower for kw in ["weather", "temperature", "forecast"]):
+                tools_needed.append("weather_info_tool")
+            if any(kw in question_lower for kw in ["model", "huggingface", "dataset"]):
+                tools_needed.append("hub_stats_tool")
+            if any(kw in question_lower for kw in ["guest", "name", "relation", "person"]):
+                tools_needed.append("guest_info_retriever_tool")
+            if len(question.split()) > 20 or "multiple" in question_lower:
+                tools_needed.append("multi_hop_search_tool")
+
+        # Integrate file-based tools
+        file_results = await fetch_task_file(task_id, question)
+        for ext, content in file_results.items():
+            if content:
+                os.makedirs("temp", exist_ok=True)
+                file_path = f"temp/{task_id}.{ext}"
+                with open(file_path, "wb") as f:
+                    f.write(content)
+                state["metadata"] = state.get("metadata", {}) | {"file_ext": ext, "file_path": file_path}
+                if ext in ["txt", "csv", "xlsx", "mp3"] and "file_parser_tool" not in tools_needed:
                     tools_needed.append("file_parser_tool")
+                elif ext in ["jpg", "png"] and "image_parser_tool" not in tools_needed:
                     tools_needed.append("image_parser_tool")
+                elif ext == "pdf" and "document_retriever_tool" not in tools_needed:
                     tools_needed.append("document_retriever_tool")
 
-        state["tools_needed"] = list(set(tools_needed))
-        logger.info(f"Task {task_id}: Selected tools: {tools_needed}")
+        state["tools_needed"] = list(set(tools_needed))  # Remove duplicates
+        logger.info(f"Task {task_id} Selected tools: {state['tools_needed']}")
         return state
     except Exception as e:
-        logger.error(f"… (old message truncated in the rendered diff)
+        logger.error(f"Task {task_id} Tool selection failed: {e}")
         state["error"] = f"Parse question failed: {str(e)}"
         state["tools_needed"] = ["search_tool"]
         return state
 
 # Tool dispatcher
 async def tool_dispatcher(state: JARVISState) -> JARVISState:
+    state = validate_state(state)
     try:
-        … (old setup lines truncated in the rendered diff)
-        file_type = "xlsx"
-        for tool in updated_state["tools_needed"]:
-            try:
-                if tool == "search_tool":
-                    result = search_tool(updated_state["question"])
-                    updated_state["web_results"].extend([str(r) for r in result])
-                elif tool == "multi_hop_search_tool":
-                    result = await multi_hop_search_tool.ainvoke({"query": updated_state["question"], "steps": 3, "llm_client": llm_client, "llm_type": llm_type})
-                    updated_state["multi_hop_results"].extend([r["content"] for r in result])
-                    await asyncio.sleep(2)
-                elif tool == "file_parser_tool":
-                    for ext in ["txt", "csv", "xlsx"]:
-                        file_path = await download_file(updated_state["task_id"], ext)
-                        if file_path:
-                            result = file_parser_tool(file_path)
-                            updated_state["file_results"] = str(result)
-                            break
-                elif tool == "image_parser_tool":
-                    file_path = await download_file(updated_state["task_id"], "jpg")
-                    if file_path:
-                        result = image_parser_tool(file_path)
-                        updated_state["image_results"] = str(result)
-                elif tool == "calculator_tool":
-                    result = calculator_tool(updated_state["question"])
-                    updated_state["calculation_results"] = str(result)
-                elif tool == "document_retriever_tool":
-                    file_path = await download_file(updated_state["task_id"], "pdf")
-                    if file_path:
-                        result = document_retriever_tool({"task_id": updated_state["task_id"], "query": updated_state["question"], "file_type": "pdf"})
-                        updated_state["document_results"] = str(result)
-                elif tool == "duckduckgo_search_tool":
-                    result = duckduckgo_search_tool(updated_state["question"])
-                    updated_state["web_results"].append(str(result))
-                elif tool == "weather_info_tool":
-                    location = updated_state["question"].split("weather in ")[1].split()[0] if "weather in" in updated_state["question"].lower() else "Unknown"
-                    result = weather_info_tool({"location": location})
-                    updated_state["web_results"].append(str(result))
-                elif tool == "hub_stats_tool":
-                    author = updated_state["question"].split("by ")[1].split()[0] if "by" in updated_state["question"].lower() else "Unknown"
-                    result = hub_stats_tool({"author": author})
-                    updated_state["web_results"].append(str(result))
-                elif tool == "guest_info_retriever_tool":
-                    query = updated_state["question"].split("about ")[1] if "about" in updated_state["question"].lower() else updated_state["question"]
-                    result = guest_info_retriever_tool({"query": query})
-                    updated_state["web_results"].append(str(result))
-                updated_state["metadata"] = updated_state.get("metadata", {}) | {f"{tool}_executed": True}
-            except Exception as e:
-                logger.warning(f"Error in tool {tool} for task {updated_state['task_id']}: {str(e)}")
-                updated_state["error"] = f"Tool {tool} failed: {str(e)}"
-                updated_state["metadata"] = updated_state.get("metadata", {}) | {f"{tool}_error": str(e)}
-
-        logger.info(f"Task {updated_state['task_id']}: Tool results: {updated_state}")
-        return updated_state
-    except Exception as e:
-        logger.error(f"Tool dispatch failed for task {state['task_id']}: {e}")
-        updated_state["error"] = f"Tool dispatch failed: {str(e)}"
-        return updated_state
-
-# Reasoning
-async def reasoning(state: JARVISState) -> Dict[str, Any]:
-    try:
-        prompt = ChatPromptTemplate.from_messages([
-            SystemMessage(content="""Provide ONLY the exact answer (e.g., '90', 'HUE'). For USD, use two decimal places (e.g., '1234.00'). For lists, use comma-separated values (e.g., 'Smith, Lee'). For IOC codes, use three-letter codes (e.g., 'ARG'). No explanations or conversational text."""),
-            HumanMessage(content="""Task: {task_id}
-Question: {question}
-Web results: {web_results}
-Multi-hop results: {multi_hop_results}
-File results: {file_results}
-Image results: {image_results}
-Calculation results: {calculation_results}
-Document results: {document_results}""")
-        ])
-        messages = [
-            {"role": "system", "content": prompt[0].content},
-            {"role": "user", "content": prompt[1].content.format(
-                task_id=state["task_id"],
-                question=state["question"],
-                web_results="\n".join(state["web_results"]),
-                multi_hop_results="\n".join(state["multi_hop_results"]),
-                file_results=state["file_results"],
-                image_results=state["image_results"],
-                calculation_results=state["calculation_results"],
-                document_results=state["document_results"]
-            )}
-        ]
-        for attempt in range(3):
-            try:
-                if … (the retry body, roughly 35 lines, is truncated in the rendered diff)
-            except Exception as e:
-                logger.warning(f"… (truncated)
-    except Exception as e:
-        logger.error(f"… (truncated)
-        state["error"] = f"… (truncated)
-        return … (truncated)
-
-# Router
-def router(state: JARVISState) -> str:
-    if state["tools_needed"]:
-        return "tool_dispatcher"
-    return "reasoning"
+        task_id = state["task_id"]
+        question = state["question"]
+        tools_needed = state["tools_needed"]
+
+        for tool_name in tools_needed:
+            try:
+                if tool_name == "search_tool":
+                    result = await tools["search_tool"].ainvoke({"query": question})
+                    state["web_results"].extend([str(r) for r in result] if result else ["No results from search_tool"])
+                elif tool_name == "multi_hop_search_tool":
+                    result = await tools["multi_hop_search_tool"].ainvoke({
+                        "query": question,
+                        "steps": 3,
+                        "llm_client": llm_client,
+                        "llm_type": llm_type,
+                        "llm_model": llm_model
+                    })
+                    state["multi_hop_results"].extend([r["content"] if isinstance(r, dict) else str(r) for r in result] if result else ["No results from multi_hop_search_tool"])
+                elif tool_name == "file_parser_tool":
+                    file_path = state["metadata"].get("file_path")
+                    file_ext = state["metadata"].get("file_ext")
+                    if file_path and os.path.exists(file_path) and file_ext:
+                        result = await tools["file_parser_tool"].ainvoke({
+                            "task_id": task_id,
+                            "file_type": file_ext,
+                            "file_path": file_path,
+                            "query": question
+                        })
+                        state["file_results"] = str(result) if result else "No file results"
+                    else:
+                        state["file_results"] = "No file available"
+                elif tool_name == "image_parser_tool":
+                    file_path = state["metadata"].get("file_path")
+                    if file_path and os.path.exists(file_path) and file_path.split('.')[-1] in ["jpg", "png"]:
+                        result = await tools["image_parser_tool"].ainvoke({"task_id": task_id, "file_path": file_path})
+                        state["image_results"] = str(result) if result else "No image results"
+                    else:
+                        state["image_results"] = "No image available"
+                elif tool_name == "calculator_tool":
+                    result = await tools["calculator_tool"].ainvoke({"expression": question})
+                    state["calculation_results"] = str(result) if result else "No calculation results"
+                elif tool_name == "document_retriever_tool":
+                    file_path = state["metadata"].get("file_path")
+                    if file_path and os.path.exists(file_path) and file_path.split('.')[-1] == "pdf":
+                        result = await tools["document_retriever_tool"].ainvoke({
+                            "task_id": task_id,
+                            "query": question,
+                            "file_path": file_path
+                        })
+                        state["document_results"] = str(result) if result else "No document results"
+                    else:
+                        state["document_results"] = "No document available"
+                elif tool_name == "duckduckgo_search_tool":
+                    result = await tools["duckduckgo_search_tool"].ainvoke({
+                        "query": question,
+                        "original_query": question,
+                        "embedder": embedder
+                    })
+                    state["web_results"].extend(result if isinstance(result, list) else [str(result)] if result else ["No results from duckduckgo_search_tool"])
+                elif tool_name == "weather_info_tool":
+                    location = question.split()[-1] if "weather" in question.lower() else "Unknown"
+                    result = await tools["weather_info_tool"].ainvoke({"location": location})
+                    state["web_results"].append(str(result) if result else "No weather results")
+                elif tool_name == "hub_stats_tool":
+                    author = question.split("by ")[1].split()[0] if "by" in question.lower() else "Unknown"
+                    result = await tools["hub_stats_tool"].ainvoke({"author": author})
+                    state["web_results"].append(str(result) if result else "No hub stats results")
+                elif tool_name == "guest_info_retriever_tool":
+                    result = await tools["guest_info_retriever_tool"].ainvoke({"query": question})
+                    state["web_results"].append(str(result) if result else "No guest info results")
+
+                state["metadata"] = state.get("metadata", {}) | {f"{tool_name}_executed": True}
+                logger.info(f"Task {task_id}: Executed {tool_name}")
+            except Exception as e:
+                logger.warning(f"Tool {tool_name} failed for task {task_id}: {e}")
+                state["metadata"] = state.get("metadata", {}) | {f"{tool_name}_error": str(e)}
+
+        # Ensure results are populated
+        state["web_results"] = state.get("web_results", ["No web results found"])
+        state["file_results"] = state.get("file_results", "No file results found")
+        state["image_results"] = state.get("image_results", "No image results found")
+        state["document_results"] = state.get("document_results", "No document results found")
+        state["calculation_results"] = state.get("calculation_results", "No calculation results found")
+
+        state["answer"] = await generate_answer(
+            task_id=task_id,
+            question=question,
+            search_results=state.get("web_results", []) + [
+                r["content"] if isinstance(r, dict) else str(r) for r in state.get("multi_hop_results", [])
+            ],
+            file_results=state.get("file_results", "") + state.get("document_results", "") + state.get("image_results", "") + state.get("calculation_results", ""),
+            llm_client=llm_client
+        )
+
+        logger.info(f"Task {task_id}: Generated answer: {state['answer']}")
+        return state
    except Exception as e:
+        logger.error(f"Tool dispatch failed: {e}")
+        state["error"] = f"Tool dispatch failed: {e}"
+        return state
 
 # Define StateGraph
 workflow = StateGraph(JARVISState)
-workflow.add_node("… (old node wiring truncated in the rendered diff)
+workflow.add_node("parse_question", parse_question)
 workflow.add_node("tool_dispatcher", tool_dispatcher)
-workflow.… (three truncated wiring calls; the recoverable arguments follow)
-    "parse",
-    router,
-    {
-        "tool_dispatcher": "tool_dispatcher",
-        "reasoning": "reasoning"
-    }
-)
-workflow.add_edge("tool_dispatcher", "reasoning")
-workflow.add_edge("reasoning", END)
+workflow.set_entry_point("parse_question")
+workflow.add_edge("parse_question", "tool_dispatcher")
+workflow.add_edge("tool_dispatcher", END)
 graph = workflow.compile()
 
 # Agent class
 class JARVISAgent:
     def __init__(self):
-        self.state = … (old constructor call truncated in the rendered diff)
-            question="",
-            tools_needed=[],
-            web_results=[],
-            file_results="",
-            image_results="",
-            calculation_results="",
-            document_results="",
-            multi_hop_results=[],
-            messages=[],
-            answer="",
-            results_table=[],
-            status_output="",
-            error=None,
-            metadata={}
-        )
+        self.state = reset_state(task_id="init", question="Agent initialized")
+        self.state["results_table"] = []  # Initialize as empty list
         logger.info("JARVISAgent initialized.")
 
     async def process_question(self, task_id: str, question: str) -> str:
-        state = … (old constructor call truncated in the rendered diff)
-            task_id=task_id,
-            question=question,
-            tools_needed=["search_tool"],
-            web_results=[],
-            file_results="",
-            image_results="",
-            calculation_results="",
-            document_results="",
-            multi_hop_results=[],
-            messages=[HumanMessage(content=question)],
-            answer="",
-            results_table=[],
-            status_output="",
-            error=None,
-            metadata={}
-        )
+        state = reset_state(task_id=task_id, question=question)
         try:
             result = await graph.ainvoke(state)
-            answer = result… (truncated in the rendered diff)
-            logger.info(f"Task {task_id}… (truncated)
-            self.state… (two truncated updates)
+            answer = result.get("answer", "Unknown")
+            logger.info(f"Task {task_id} Final answer: {answer}")
+            self.state["results_table"].append({"Task ID": task_id, "Question": question, "Answer": answer})
+            self.state["metadata"] = {"last_task_id": task_id, "answer": answer}
             return answer
         except Exception as e:
             logger.error(f"Error processing task {task_id}: {e}")
-            self.state… (two truncated updates)
+            self.state["results_table"].append({"Task ID": task_id, "Question": question, "Answer": f"Error: {e}"})
+            self.state["error"] = f"Task {task_id} failed: {str(e)}"
             return f"Error: {str(e)}"
         finally:
-            for ext in ["txt", "csv", "xlsx", "jpg", "pdf"]:
+            for ext in ["txt", "csv", "xlsx", "mp3", "jpg", "png", "pdf"]:
                 file_path = f"temp/{task_id}.{ext}"
                 if os.path.exists(file_path):
                     try:
@@ -466,25 +447,26 @@ class JARVISAgent:
     async def process_all_questions(self, profile: gr.OAuthProfile | None):
         if not profile:
             logger.error("User not logged in.")
-            self.state… (truncated update)
-            return pd.DataFrame(self.state… (truncated)
+            self.state["status_output"] = "Please Login to Hugging Face."
+            return pd.DataFrame(self.state["results_table"]), self.state["status_output"]
 
-        username = … (truncated in the rendered diff)
+        username = profile.username
         logger.info(f"User logged in: {username}")
         questions_url = f"{GAIA_API_URL}/questions"
         submit_url = f"{GAIA_API_URL}/submit"
         agent_code = f"https://huggingface.co/spaces/{SPACE_ID}/tree/main"
 
         try:
-            … (old fetch logic, three lines, truncated in the rendered diff)
+            async with await create_http_session() as session:
+                async with session.get(questions_url) as response:
+                    response.raise_for_status()
+                    questions = await response.json()
             logger.info(f"Fetched {len(questions)} questions.")
         except Exception as e:
             logger.error(f"Error fetching questions: {e}")
-            self.state… (two truncated updates)
-            return pd.DataFrame(self.state… (truncated)
+            self.state["status_output"] = f"Error fetching questions: {e}"
+            self.state["error"] = f"Fetch questions failed: {str(e)}"
+            return pd.DataFrame(self.state["results_table"]), self.state["status_output"]
 
         answers_payload = []
         for item in questions:
@@ -498,33 +480,34 @@ class JARVISAgent:
 
         if not answers_payload:
             logger.error("No answers generated.")
-            self.state… (two truncated updates)
-            return pd.DataFrame(self.state… (truncated)
+            self.state["status_output"] = "No answers to submit."
+            self.state["error"] = "No answers generated"
+            return pd.DataFrame(self.state["results_table"]), self.state["status_output"]
 
         submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
         try:
-            … (old submit logic, four lines, truncated in the rendered diff)
+            async with await create_http_session() as session:
+                async with session.post(submit_url, json=submission_data) as response:
+                    response.raise_for_status()
+                    result_data = await response.json()
+            self.state["status_output"] = (
                 f"Submission Successful!\n"
                 f"User: {result_data.get('username')}\n"
                 f"Overall Score: {result_data.get('score', 'N/A')}% "
                 f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
                 f"Message: {result_data.get('message', 'No message received.')}"
             )
-            self.state… (truncated update)
+            self.state["metadata"] = self.state.get("metadata", {}) | {"submission_score": result_data.get('score', 'N/A')}
         except Exception as e:
             logger.error(f"Submission failed: {e}")
-            self.state… (two truncated updates)
+            self.state["status_output"] = f"Submission Failed: {e}"
+            self.state["error"] = f"Submission failed: {str(e)}"
 
-        return pd.DataFrame(self.state.results_table), self.state… (truncated)
+        return pd.DataFrame(self.state["results_table"] if self.state["results_table"] else [], columns=["Task ID", "Question", "Answer"]), self.state["status_output"]
 
 # Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("#… (old title truncated in the rendered diff)
+    gr.Markdown("# JARVIS GAIA Agent")
     gr.Markdown(
         """
         **Instructions:**
@@ -539,7 +522,6 @@ with gr.Blocks() as demo:
     )
     with gr.Row():
         gr.LoginButton(value="Login to Hugging Face")
-        # Removed gr.LogoutButton due to deprecation
     run_button = gr.Button("Run Evaluation & Submit All Answers")
     status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
     results_table = gr.DataFrame(label="Questions and Answers", wrap=True, headers=["Task ID", "Question", "Answer"])
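For orientation, a minimal local smoke test of the reworked two-node graph (parse_question, then tool_dispatcher, which now calls generate_answer directly) could look like the sketch below. The task ID and question are made-up placeholders, and it assumes the environment variables checked at import time (HUGGINGFACEHUB_API_TOKEN, TOGETHER_API_KEY) are set.

```python
# Hypothetical smoke test for the new graph wiring; not part of this commit.
import asyncio

from app import JARVISAgent


async def main():
    agent = JARVISAgent()
    # process_question resets state, runs parse_question -> tool_dispatcher,
    # and cleans up any temp/<task_id>.<ext> files in its finally block.
    answer = await agent.process_question("demo-task-1", "What is the capital of France?")
    print(answer)


if __name__ == "__main__":
    asyncio.run(main())
```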
project_structure.txt
ADDED
@@ -0,0 +1,22 @@
+.
+├── app.py
+├── dockerfile
+├── README.md
+├── requirements.txt
+├── retriever.py
+├── state.py
+└── tools
+    ├── __init__.py
+    ├── answer_generator.py
+    ├── calculator.py
+    ├── document_retriever.py
+    ├── duckduckgo_search.py
+    ├── file_fetcher.py
+    ├── file_parser.py
+    ├── guest_info.py
+    ├── hub_stats.py
+    ├── image_parser.py
+    ├── search.py
+    └── weather_info.py
+
+3 directories, 18 files
requirements.txt
CHANGED
@@ -20,8 +20,11 @@ transformers
 asyncio
 serpapi
 duckduckgo-search
-torch
+torch==2.2.2
 together
 google-search-results
 beautifulsoup4
-gradio[oauth]
+gradio[oauth]
+nltk
+speechrecognition
+rank_bm25
result.txt
CHANGED
The diff for this file is too large to render. See raw diff.
retriever.py
CHANGED
@@ -1,25 +1,109 @@
-import … (old imports, five lines, truncated in the rendered diff)
+import json
+import os
+import logging
+import torch
+from typing import List
+from langchain_core.documents import Document
+from sentence_transformers import SentenceTransformer
+try:
+    from datasets import load_dataset
+except ImportError:
+    load_dataset = None
 
-… (old module setup truncated in the rendered diff)
+logger = logging.getLogger(__name__)
+
+def get_device():
+    """
+    Determine the appropriate device for PyTorch.
+
+    Returns:
+        str: Device name ('cuda', 'mps', or 'cpu').
+    """
+    if torch.cuda.is_available():
+        return "cuda"
+    elif torch.backends.mps.is_available():
+        return "mps"
+    return "cpu"
+
+def load_guest_dataset(dataset_path: str = "agents-course/unit3-invitees") -> List[Document]:
+    """
+    Load guest dataset from a local JSON file or Hugging Face dataset.
+
+    Args:
+        dataset_path (str): Path to local JSON file or Hugging Face dataset name.
+
+    Returns:
+        List[Document]: List of Document objects with guest information.
+    """
     try:
-        docs = [
-            Document(
-                page_content="\n".join([
-                    … (old field lines truncated in the rendered diff)
-                ]),
-                metadata={ … (old metadata truncated in the rendered diff)
-            )
-            for guest in guest_dataset
-        ]
+        # Try loading from Hugging Face dataset if datasets library is available
+        if load_dataset and not os.path.exists(dataset_path):
+            logger.info(f"Attempting to load Hugging Face dataset: {dataset_path}")
+            guest_dataset = load_dataset(dataset_path, split="train")
+            docs = [
+                Document(
+                    page_content="\n".join([
+                        f"Name: {guest['name']}",
+                        f"Relation: {guest['relation']}",
+                        f"Description: {guest['description']}",
+                        f"Email: {guest['email']}"
+                    ]),
+                    metadata={
+                        "name": guest["name"],
+                        "relation": guest["relation"],
+                        "description": guest["description"],
+                        "email": guest["email"]
+                    }
+                )
+                for guest in guest_dataset
+            ]
+            logger.info(f"Loaded {len(docs)} guests from Hugging Face dataset")
+            return docs
+
+        # Try loading from local JSON file
+        if os.path.exists(dataset_path):
+            logger.info(f"Loading guest dataset from local path: {dataset_path}")
+            with open(dataset_path, 'r') as f:
+                guests = json.load(f)
+            docs = [
+                Document(
+                    page_content=guest.get('description', ''),
+                    metadata={
+                        'name': guest.get('name', ''),
+                        'relation': guest.get('relation', ''),
+                        'description': guest.get('description', ''),
+                        'email': guest.get('email', '')  # Optional email field
+                    }
+                )
+                for guest in guests
+            ]
+            logger.info(f"Loaded {len(docs)} guests from local JSON")
+            return docs
+
+        # Fallback to mock dataset if both fail
+        logger.warning(f"Dataset not found at {dataset_path}, using mock dataset")
+        docs = [
+            Document(
+                page_content="\n".join([
+                    "Name: Dr. Nikola Tesla",
+                    "Relation: old friend from university days",
+                    "Description: Dr. Nikola Tesla is an old friend from your university days. He's recently patented a new wireless energy transmission system.",
+                    "Email: [email protected]"
+                ]),
+                metadata={
+                    "name": "Dr. Nikola Tesla",
+                    "relation": "old friend from university days",
+                    "description": "Dr. Nikola Tesla is an old friend from your university days. He's recently patented a new wireless energy transmission system.",
+                    "email": "[email protected]"
+                }
+            )
+        ]
+        logger.info("Loaded mock dataset with 1 guest")
+        return docs
+
     except Exception as e:
-        … (old error handling truncated in the rendered diff)
+        logger.error(f"Failed to load guest dataset: {e}")
+        # Return mock dataset as final fallback
         docs = [
             Document(
                 page_content="\n".join([
@@ -28,7 +112,73 @@ def load_guest_dataset():
                     "Description: Dr. Nikola Tesla is an old friend from your university days. He's recently patented a new wireless energy transmission system.",
                     "Email: [email protected]"
                 ]),
-                metadata={ … (old metadata truncated in the rendered diff)
+                metadata={
+                    "name": "Dr. Nikola Tesla",
+                    "relation": "old friend from university days",
+                    "description": "Dr. Nikola Tesla is an old friend from your university days. He's recently patented a new wireless energy transmission system.",
+                    "email": "[email protected]"
+                }
             )
         ]
-        … (old return truncated in the rendered diff)
+        logger.info("Loaded mock dataset with 1 guest due to error")
+        return docs
+
+class BM25Retriever:
+    """
+    A retriever class using SentenceTransformer for embedding-based search.
+    """
+    def __init__(self, dataset_path: str):
+        """
+        Initialize the retriever with a SentenceTransformer model.
+
+        Args:
+            dataset_path (str): Path to the dataset for retrieval.
+
+        Raises:
+            Exception: If embedder initialization fails.
+        """
+        try:
+            self.model = SentenceTransformer("all-MiniLM-L6-v2", device=get_device())
+            self.dataset_path = dataset_path
+            logger.info("Initialized SentenceTransformer")
+        except Exception as e:
+            logger.error(f"Failed to initialize embedder: {e}")
+            raise
+
+    def search(self, query: str) -> List[dict]:
+        """
+        Search the dataset for relevant guest information.
+
+        Args:
+            query (str): Search query (e.g., guest name or relation).
+
+        Returns:
+            List[dict]: List of matching guest metadata dictionaries.
+        """
+        try:
+            # Load dataset
+            docs = load_guest_dataset(self.dataset_path)
+            if not docs:
+                logger.warning("No documents available for search")
+                return []
+
+            # Convert documents to text for BM25 (using metadata for consistency)
+            texts = [f"{doc.metadata['name']} {doc.metadata['relation']} {doc.metadata['description']}" for doc in docs]
+            from langchain_community.retrievers import BM25Retriever
+            retriever = BM25Retriever.from_texts(texts)
+            retriever.k = 3  # Limit to top 3 results
+
+            # Perform search
+            results = retriever.invoke(query)
+            # Map results back to original metadata
+            matches = [
+                docs[i].metadata
+                for i in range(len(docs))
+                if any(f"{docs[i].metadata['name']} {docs[i].metadata['relation']} {docs[i].metadata['description']}" in r.page_content for r in results)
+            ]
+            logger.info(f"Found {len(matches)} matches for query: {query}")
+            return matches[:3]  # Return top 3 matches
+
+        except Exception as e:
+            logger.error(f"Search failed for query '{query}': {e}")
+            return []
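A short usage sketch of the reworked retriever follows; the local invitees.json path is illustrative, and with no local file the loader falls back to the agents-course/unit3-invitees dataset or the mock guest.

```python
# Illustrative only: "invitees.json" is a hypothetical local dataset path.
from retriever import BM25Retriever

retriever = BM25Retriever("invitees.json")
for guest in retriever.search("old friend from university"):
    # search() returns up to three metadata dicts with name/relation/description/email
    print(guest["name"], "-", guest["relation"])
```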
state.py
CHANGED
@@ -1,5 +1,8 @@
|
|
1 |
-
from typing import TypedDict, List, Dict, Optional, Any
|
2 |
from langchain_core.messages import BaseMessage
|
|
|
|
|
|
|
3 |
|
4 |
class JARVISState(TypedDict):
|
5 |
"""
|
@@ -10,11 +13,11 @@ class JARVISState(TypedDict):
|
|
10 |
question: The question text to be answered.
|
11 |
tools_needed: List of tool names to be used for the task.
|
12 |
web_results: List of web search results (e.g., from SERPAPI, DuckDuckGo).
|
13 |
-
file_results: Parsed content from text, CSV, or
|
14 |
image_results: OCR or description results from image files.
|
15 |
calculation_results: Results from mathematical calculations.
|
16 |
-
document_results: Extracted content from PDF documents.
|
17 |
-
        multi_hop_results: Results from iterative multi-hop searches.
        messages: List of messages for LLM context (e.g., user prompts, system instructions).
        answer: Final answer for the task, formatted for GAIA submission.
        results_table: List of task results for Gradio display (Task ID, Question, Answer).

@@ -30,10 +33,84 @@ class JARVISState(TypedDict):
     image_results: str
     calculation_results: str
     document_results: str
-    multi_hop_results: List[str]
+    multi_hop_results: List[Union[str, Dict[str, Any]]]
     messages: List[BaseMessage]
     answer: str
     results_table: List[Dict[str, str]]
     status_output: str
     error: Optional[str]
     metadata: Optional[Dict[str, Any]]

from typing import TypedDict, List, Dict, Optional, Any, Union
from langchain_core.messages import BaseMessage
import logging

logger = logging.getLogger(__name__)

class JARVISState(TypedDict):
    """
    ...
        question: The question text to be answered.
        tools_needed: List of tool names to be used for the task.
        web_results: List of web search results (e.g., from SERPAPI, DuckDuckGo).
        file_results: Parsed content from text, CSV, Excel, or audio files.
        image_results: OCR or description results from image files.
        calculation_results: Results from mathematical calculations.
        document_results: Extracted content from PDF or text documents.
        multi_hop_results: Results from iterative multi-hop searches (supports strings or dicts).
        messages: List of messages for LLM context (e.g., user prompts, system instructions).
        answer: Final answer for the task, formatted for GAIA submission.
        results_table: List of task results for Gradio display (Task ID, Question, Answer).
    ...
    """
    ...
    image_results: str
    calculation_results: str
    document_results: str
    multi_hop_results: List[Union[str, Dict[str, Any]]]
    messages: List[BaseMessage]
    answer: str
    results_table: List[Dict[str, str]]
    status_output: str
    error: Optional[str]
    metadata: Optional[Dict[str, Any]]

def validate_state(state: JARVISState) -> JARVISState:
    """
    Validate and initialize JARVISState fields.

    Args:
        state: Input state dictionary.

    Returns:
        Validated and initialized state.
    """
    try:
        if not state.get("task_id"):
            logger.error("task_id is required")
            raise ValueError("task_id is required")
        if not state.get("question"):
            logger.error("question is required")
            raise ValueError("question is required")

        # Initialize default values if missing
        defaults = {
            "tools_needed": ["search_tool"],
            "web_results": [],
            "file_results": "",
            "image_results": "",
            "calculation_results": "",
            "document_results": "",
            "multi_hop_results": [],
            "messages": [],
            "answer": "",
            "results_table": [],
            "status_output": "",
            "error": None,
            "metadata": {}
        }
        for key, default in defaults.items():
            if key not in state or state[key] is None:
                state[key] = default

        logger.debug(f"Validated state for task {state['task_id']}")
        return state
    except Exception as e:
        logger.error(f"State validation failed: {e}")
        raise

def reset_state(task_id: str, question: str) -> JARVISState:
    """
    Create a fresh JARVISState for a new task.

    Args:
        task_id: Task identifier.
        question: Question text.

    Returns:
        Initialized JARVISState.
    """
    state = JARVISState(
        task_id=task_id,
        question=question,
        tools_needed=["search_tool"],
        web_results=[],
        file_results="",
        image_results="",
        calculation_results="",
        document_results="",
        multi_hop_results=[],
        messages=[],
        answer="",
        results_table=[],
        status_output="",
        error=None,
        metadata={}
    )
    return validate_state(state)
test.py
CHANGED
@@ -1,10 +1,233 @@
import asyncio
import os
import logging
import tempfile
from pathlib import Path
from app import JARVISAgent, llm_client, llm_type, llm_model, embedder
from tools.search import search_tool, multi_hop_search_tool
from tools.file_parser import file_parser_tool
from tools.image_parser import image_parser_tool
from tools.calculator import calculator_tool
from tools.document_retriever import document_retriever_tool
from tools.duckduckgo_search import duckduckgo_search_tool
from tools.weather_info import weather_info_tool
from tools.hub_stats import hub_stats_tool
from tools.guest_info import guest_info_retriever_tool
from tools.file_fetcher import fetch_task_file
from tools.answer_generator import preprocess_question, filter_results
from state import validate_state, reset_state, JARVISState

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

async def test_tools():
    """Test all tools."""
    logger.info("Testing Search Tool (SerpAPI)...")
    try:
        if not os.getenv("SERPAPI_API_KEY"):
            logger.warning("Search Warning: SERPAPI_API_KEY not set")
        else:
            result = await search_tool.ainvoke({"query": "What is the capital of France?"})
            logger.info(f"Search Result: {result}")
    except Exception as e:
        logger.error(f"Search Error: {e}")

    logger.info("Testing Multi-Hop Search Tool...")
    try:
        result = await multi_hop_search_tool.ainvoke({
            "query": "What is the population of France's capital?",
            "steps": 2,
            "llm_client": llm_client,
            "llm_type": llm_type,
            "llm_model": llm_model
        })
        logger.info(f"Multi-Hop Search Result: {result}")
    except Exception as e:
        logger.error(f"Multi-Hop Search Error: {e}")

    logger.info("Testing DuckDuckGo Search Tool...")
    try:
        result = await duckduckgo_search_tool.ainvoke({
            "query": "What is the capital of France?",
            "original_query": "What is the capital of France?",
            "embedder": embedder
        })
        logger.info(f"DuckDuckGo Result: {result}")
    except Exception as e:
        logger.error(f"DuckDuckGo Error: {e}")

    logger.info("Testing Weather Info Tool...")
    try:
        if not os.getenv("OPENWEATHERMAP_API_KEY"):
            logger.warning("Weather Warning: OPENWEATHERMAP_API_KEY not set")
        else:
            result = await weather_info_tool.ainvoke({"location": "London"})
            logger.info(f"Weather Result: {result}")
    except Exception as e:
        logger.error(f"Weather Error: {e}")

    logger.info("Testing Document Retriever Tool...")
    try:
        from PyPDF2 import PdfWriter
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
            writer = PdfWriter()
            from PyPDF2.generic import NameObject, create_string_object
            page = writer.add_blank_page(width=72, height=72)
            page[NameObject("/Contents")] = create_string_object("Sample document content for testing.")
            writer.write(tmp)
            tmp_path = tmp.name
        result = await document_retriever_tool.ainvoke({
            "task_id": "test_task",
            "query": "Sample question",
            "file_path": tmp_path
        })
        logger.info(f"Document Retriever Result: {result}")
        os.unlink(tmp_path)
    except Exception as e:
        logger.error(f"Document Retriever Error: {e}")

    logger.info("Testing Image Parser Tool...")
    try:
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            # Create a minimal PNG (1x1 pixel)
            from PIL import Image
            img = Image.new('RGB', (1, 1), color='white')
            img.save(tmp.name, 'PNG')
            tmp_path = tmp.name
        result = await image_parser_tool.ainvoke({"task_id": "test_task", "file_path": tmp_path})
        logger.info(f"Image Parser Result: {result}")
        os.unlink(tmp_path)
    except Exception as e:
        logger.error(f"Image Parser Error: {e}")

    logger.info("Testing File Parser Tool...")
    try:
        with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tmp:
            tmp.write(b"Sample text file content")
            tmp_path = tmp.name
        result = await file_parser_tool.ainvoke({
            "task_id": "test_task",
            "file_type": "txt",
            "file_path": tmp_path,
            "query": "What is in the file?"
        })
        logger.info(f"File Parser Result: {result}")
        os.unlink(tmp_path)
    except Exception as e:
        logger.error(f"File Parser Error: {e}")

    logger.info("Testing Calculator Tool...")
    try:
        result = await calculator_tool.ainvoke({"expression": "2 + 2"})
        logger.info(f"Calculator Result: {result}")
    except Exception as e:
        logger.error(f"Calculator Error: {e}")

    logger.info("Testing Hub Stats Tool...")
    try:
        if not os.getenv("HUGGINGFACEHUB_API_TOKEN"):
            logger.warning("Hub Stats Warning: HUGGINGFACEHUB_API_TOKEN not set")
        else:
            result = await hub_stats_tool.ainvoke({"author": "meta-llama"})
            logger.info(f"Hub Stats Result: {result}")
    except Exception as e:
        logger.error(f"Hub Stats Error: {e}")

    logger.info("Testing Guest Info Retriever Tool...")
    try:
        result = await guest_info_retriever_tool.ainvoke({"query": "Who is the guest named John?"})
        logger.info(f"Guest Info Result: {result}")
    except Exception as e:
        logger.error(f"Guest Info Error: {e}")

async def test_file_fetcher():
    """Test file fetcher."""
    logger.info("Testing File Fetcher...")
    try:
        result = await fetch_task_file("8e867cd7-cff9-4e6c-867a-ff5ddc2550be", "Sample question with data")
        logger.info(f"File Fetcher Result: {result}")
    except Exception as e:
        logger.error(f"File Fetcher Error: {e}")

async def test_answer_generator():
    """Test answer generator functions."""
    logger.info("Testing Preprocess Question...")
    try:
        result = await preprocess_question("What's the weather in Paris?")
        logger.info(f"Preprocess Question Result: {result}")
    except Exception as e:
        logger.error(f"Preprocess Question Error: {e}")

    logger.info("Testing Filter Results...")
    try:
        results = ["Paris is the capital of France.", "Florida is a state.", "Paris is in Texas."]
        filtered = filter_results(results, "What is the capital of France?")
        logger.info(f"Filter Results: {filtered}")
    except Exception as e:
        logger.error(f"Filter Results Error: {e}")

async def test_state_management():
    """Test state management functions."""
    logger.info("Testing Reset State...")
    try:
        state = reset_state("test_task", "What is the capital of France?")
        logger.info(f"Reset State Result: {state}")
    except Exception as e:
        logger.error(f"Reset State Error: {e}")

    logger.info("Testing Validate State...")
    try:
        invalid_state = {"task_id": "", "question": ""}
        validate_state(invalid_state)
        logger.error("Validate State should have failed")
    except ValueError as e:
        logger.info(f"Validate State Error (expected): {e}")

    try:
        valid_state = reset_state("test_task", "Sample question")
        validated = validate_state(valid_state)
        logger.info(f"Validate State Result: {validated}")
    except Exception as e:
        logger.error(f"Validate State Error: {e}")

async def test_agent():
    """Test JARVISAgent with various cases."""
    logger.info("Testing JARVISAgent (Simple Question)...")
    try:
        agent = JARVISAgent()
        answer = await agent.process_question("test_task", "What is the capital of France?")
        logger.info(f"JARVISAgent Answer: {answer}")
    except Exception as e:
        logger.error(f"JARVISAgent Error: {e}")

    logger.info("Testing JARVISAgent (Edge Case: Empty Question)...")
    try:
        agent = JARVISAgent()
        answer = await agent.process_question("test_task", "")
        logger.info(f"JARVISAgent Empty Question Answer: {answer}")
    except Exception as e:
        logger.info(f"JARVISAgent Empty Question Error (expected): {e}")

async def main():
    required_envs = [
        "HUGGINGFACEHUB_API_TOKEN",
        "TOGETHER_API_KEY",
        "OPENWEATHERMAP_API_KEY",
        "SERPAPI_API_KEY"
    ]
    for env in required_envs:
        if not os.getenv(env):
            logger.warning(f"{env} not set, some tools may fail")

    await test_tools()
    await test_file_fetcher()
    await test_answer_generator()
    await test_state_management()
    await test_agent()

if __name__ == "__main__":
    try:
        asyncio.run(main())
    except Exception as e:
        logger.error(f"Test script failed: {e}")
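
The suite runs end to end with `python test.py`; a sketch for exercising a single group in isolation, assuming `app.py` can initialize its LLM clients on import:

```python
import asyncio
from test import test_state_management  # importing test also pulls in app's LLM setup

# Runs only the state-management checks; these need no external API keys.
asyncio.run(test_state_management())
```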
tools/__init__.py
CHANGED
@@ -6,4 +6,6 @@ from .document_retriever import document_retriever_tool
 from .duckduckgo_search import duckduckgo_search_tool
 from .weather_info import weather_info_tool
 from .hub_stats import hub_stats_tool
 from .guest_info import guest_info_retriever_tool
+from .file_fetcher import fetch_task_file
+from .answer_generator import generate_answer, preprocess_question  # , filter_results, get_embedder
tools/answer_generator.py
ADDED
@@ -0,0 +1,129 @@
import nltk
import logging
import numpy as np
from typing import List, Any
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import SystemMessage, HumanMessage
from sentence_transformers import SentenceTransformer

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(name)s - %(message)s')
logger = logging.getLogger(__name__)

# Download NLTK data
try:
    nltk.download('punkt', quiet=True)
    nltk.download('stopwords', quiet=True)
except Exception as e:
    logger.warning(f"NLTK data download failed: {e}")

# Global embedder
_embedder = None

def get_embedder():
    global _embedder
    if _embedder is None:
        try:
            _embedder = SentenceTransformer(
                "all-MiniLM-L6-v2",
                device="cpu",
                cache_folder="./cache"
            )
            logger.info("SentenceTransformer initialized")
        except Exception as e:
            logger.error(f"Failed to initialize SentenceTransformer: {e}")
            raise RuntimeError(f"Embedder initialization failed: {e}")
    return _embedder

def filter_results(search_results: List[str], question: str) -> List[str]:
    try:
        if not search_results or not question:
            return search_results

        embedder = get_embedder()
        # Normalize embeddings so the dot product below equals the cosine
        # similarity that the 0.5 threshold assumes.
        question_embedding = embedder.encode([question], convert_to_numpy=True, normalize_embeddings=True)
        result_embeddings = embedder.encode(search_results, convert_to_numpy=True, normalize_embeddings=True)

        similarities = np.dot(result_embeddings, question_embedding.T).flatten()
        filtered_results = [
            search_results[i] for i in range(len(search_results))
            if similarities[i] > 0.5 and search_results[i].strip()
        ]

        return filtered_results if filtered_results else search_results[:3]
    except Exception as e:
        logger.warning(f"Result filtering failed: {e}")
        return search_results[:3]

async def preprocess_question(question: str) -> str:
    """Preprocess the question to clean and standardize it."""
    try:
        question = question.strip().lower()
        if not question.endswith("?"):
            question += "?"
        logger.debug(f"Preprocessed question: {question}")
        return question
    except Exception as e:
        logger.error(f"Error preprocessing question: {e}")
        return question

async def generate_answer(
    task_id: str,
    question: str,
    search_results: List[str],
    file_results: str,
    llm_client: Any
) -> str:
    """Generate an answer using LLM with search and file results."""
    try:
        if not search_results:
            search_results = ["No search results available."]
        if not file_results:
            file_results = "No file results available."

        context = "\n".join([str(r) for r in search_results]) + "\n" + file_results
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content="""You are an assistant answering questions using provided context.
            - Use ONLY the context to formulate a concise, accurate answer.
            - If the context is insufficient, state: 'Insufficient information to answer.'
            - Do NOT generate or assume information beyond the context.
            - Return a single, clear sentence or phrase as the answer."""),
            HumanMessage(content=f"Context: {context}\nQuestion: {question}")
        ])

        messages = [
            {"role": "system", "content": prompt.messages[0].content},
            {"role": "user", "content": prompt.messages[1].content}
        ]

        if isinstance(llm_client, tuple):  # hf_local
            model, tokenizer = llm_client
            inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
            outputs = model.generate(inputs, max_new_tokens=100, temperature=0.7)
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        elif hasattr(llm_client, "chat"):  # together
            response = llm_client.chat.completions.create(
                model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
                messages=messages,
                max_tokens=100,
                temperature=0.7,
                top_p=0.9,
                frequency_penalty=0.5
            )
            response = response.choices[0].message.content.strip()
        else:  # hf_api
            response = llm_client.chat.completions.create(
                messages=messages,
                max_tokens=100,
                temperature=0.7
            )
            response = response.choices[0].message.content.strip()

        answer = response.strip()
        if not answer or answer.lower() == "none":
            answer = "Insufficient information to answer."
        logger.info(f"Task {task_id}: Generated answer: {answer}")
        return answer
    except Exception as e:
        logger.error(f"Task {task_id}: Answer generation failed: {e}")
        return "Error generating answer."
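
A short sketch of the filtering step, assuming the MiniLM weights can be fetched into ./cache on first use:

```python
from tools.answer_generator import filter_results

snippets = [
    "Paris is the capital of France.",
    "Florida is a state.",
    "Paris is in Texas.",
]
# Keeps snippets whose cosine similarity to the question exceeds 0.5,
# falling back to the first three snippets when nothing passes.
print(filter_results(snippets, "What is the capital of France?"))
```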
tools/calculator.py
CHANGED
@@ -1,15 +1,35 @@
-from langchain_core.tools import tool
-from sympy import sympify

import logging
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)

class CalculatorInput(BaseModel):
    expression: str = Field(description="Mathematical expression to evaluate")

async def calculator_func(expression: str) -> str:
    """
    Evaluate a mathematical expression and return the result as a string.

    Args:
        expression (str): Mathematical expression (e.g., '2 + 2').

    Returns:
        str: Result of the expression.
    """
    try:
        logger.info(f"Evaluating expression: {expression}")
        # Restricted eval: builtins are stripped, but this is not a full
        # sandbox; only evaluate trusted expressions.
        result = eval(expression, {"__builtins__": {}}, {})
        if isinstance(result, float):
            return f"{result:.2f}" if "USD" in expression else str(result)
        return str(result)
    except Exception as e:
        logger.error(f"Calculator error: {e}")
        return f"Error: {e}"

calculator_tool = StructuredTool.from_function(
    func=calculator_func,
    name="calculator_tool",
    args_schema=CalculatorInput,
    coroutine=calculator_func
)
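
A usage sketch; StructuredTool validates the arguments against CalculatorInput and routes `ainvoke` to the coroutine:

```python
import asyncio
from tools.calculator import calculator_tool

result = asyncio.run(calculator_tool.ainvoke({"expression": "2 + 2 * 3"}))
print(result)  # "8"
```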
tools/document_retriever.py
CHANGED
@@ -1,30 +1,47 @@
-from langchain_core.tools import tool
-from langchain_community.document_loaders import TextLoader, CSVLoader, PyPDFLoader

import logging
import os
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from typing import Optional

logger = logging.getLogger(__name__)

class DocumentRetrieverInput(BaseModel):
    task_id: str = Field(description="Task identifier")
    query: str = Field(description="Search query")
    file_path: Optional[str] = Field(description="Path to document file", default=None)

async def document_retriever_func(task_id: str, query: str, file_path: Optional[str] = None) -> str:
    """
    Retrieve content from documents for a given task and query.

    Args:
        task_id (str): Task identifier.
        query (str): Search query.
        file_path (Optional[str]): Path to document file.

    Returns:
        str: Retrieved document content or error message.
    """
    try:
        if file_path and os.path.exists(file_path):
            logger.info(f"Retrieving document from {file_path} for task {task_id}")
            if file_path.endswith('.pdf'):
                import PyPDF2
                with open(file_path, 'rb') as f:
                    reader = PyPDF2.PdfReader(f)
                    text = "".join(page.extract_text() or "" for page in reader.pages)
                return text[:500] if text else "No text extracted"
            return "Unsupported file format"
        logger.warning(f"No valid documents found for task {task_id}")
        return "Document not found"
    except Exception as e:
        logger.error(f"Error retrieving document for task {task_id}: {e}")
        return f"Error: {str(e)}"

document_retriever_tool = StructuredTool.from_function(
    func=document_retriever_func,
    name="document_retriever_tool",
    args_schema=DocumentRetrieverInput,
    coroutine=document_retriever_func
)
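
A sketch of calling the retriever against a local PDF; the id and path below are placeholders:

```python
import asyncio
from tools.document_retriever import document_retriever_tool

# Returns at most the first 500 characters of extracted PDF text.
text = asyncio.run(document_retriever_tool.ainvoke({
    "task_id": "demo-task",     # placeholder id
    "query": "What does the document say?",
    "file_path": "sample.pdf",  # placeholder path
}))
print(text)
```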
tools/duckduckgo_search.py
CHANGED
@@ -1,6 +1,99 @@
-from smolagents import Tool, DuckDuckGoSearchTool

import logging
import os
import asyncio
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from typing import Optional, List
from duckduckgo_search import DDGS
from serpapi import GoogleSearch

logger = logging.getLogger(__name__)

class DuckDuckGoSearchInput(BaseModel):
    query: str = Field(description="Search query")
    original_query: str = Field(description="Original query for context")
    embedder: Optional[object] = Field(description="SentenceTransformer embedder", default=None)

async def duckduckgo_search_func(query: str, original_query: str, embedder: Optional[object] = None) -> List[str]:
    """
    Perform a DuckDuckGo search with retries and fall back to SerpAPI if needed.

    Args:
        query (str): Search query.
        original_query (str): Original query for context.
        embedder (Optional[object]): SentenceTransformer for result filtering.

    Returns:
        List[str]: List of search result snippets.
    """
    async def try_duckduckgo(query: str, max_retries: int = 3) -> List[str]:
        for attempt in range(max_retries):
            try:
                logger.info(f"DuckDuckGo search attempt {attempt + 1} for query: {query}")
                with DDGS() as ddgs:
                    results = [r['body'] for r in ddgs.text(query, max_results=5)]
                return results
            except Exception as e:
                if "Ratelimit" in str(e) and attempt < max_retries - 1:
                    wait_time = 2 ** attempt  # Exponential backoff: 1s, 2s, 4s
                    logger.warning(f"DuckDuckGo rate limit hit, retrying in {wait_time}s: {e}")
                    await asyncio.sleep(wait_time)
                else:
                    logger.error(f"DuckDuckGo search failed for query '{query}': {e}")
                    raise e
        return []

    async def try_serpapi(query: str, max_retries: int = 3) -> List[str]:
        if not os.getenv("SERPAPI_API_KEY"):
            logger.warning("SERPAPI_API_KEY not set, cannot use SerpAPI fallback")
            return []
        for attempt in range(max_retries):
            try:
                logger.info(f"SerpAPI search attempt {attempt + 1} for query: {query}")
                params = {
                    "q": query,
                    "api_key": os.getenv("SERPAPI_API_KEY"),
                    "num": 5
                }
                search = GoogleSearch(params)
                results = search.get_dict().get("organic_results", [])
                return [result.get("snippet", "") for result in results if "snippet" in result]
            except Exception as e:
                if attempt < max_retries - 1:
                    wait_time = 2 ** attempt  # Exponential backoff: 1s, 2s, 4s
                    logger.warning(f"SerpAPI search failed, retrying in {wait_time}s: {e}")
                    await asyncio.sleep(wait_time)
                else:
                    logger.error(f"SerpAPI search failed for query '{query}': {e}")
                    return []

    try:
        # Try DuckDuckGo with retries; a hard failure should still reach the
        # SerpAPI fallback, so catch it here instead of letting it propagate.
        logger.info(f"Executing DuckDuckGo search for query: {query}")
        try:
            results = await try_duckduckgo(query)
        except Exception as e:
            logger.warning(f"DuckDuckGo search raised, falling back to SerpAPI: {e}")
            results = []

        # Fall back to SerpAPI if DuckDuckGo fails
        if not results:
            logger.info(f"DuckDuckGo returned no results, falling back to SerpAPI for query: {query}")
            results = await try_serpapi(query)

        # Rank results if embedder is provided
        if embedder and results:
            from sentence_transformers import util
            query_embedding = embedder.encode(original_query, convert_to_tensor=True)
            result_embeddings = embedder.encode(results, convert_to_tensor=True)
            scores = util.cos_sim(query_embedding, result_embeddings)[0]
            ranked_results = [results[i] for i in scores.argsort(descending=True)]
            return ranked_results[:3]

        return results[:3] if results else []
    except Exception as e:
        logger.error(f"Search failed for query '{query}': {e}")
        return []

duckduckgo_search_tool = StructuredTool.from_function(
    func=duckduckgo_search_func,
    name="duckduckgo_search_tool",
    args_schema=DuckDuckGoSearchInput,
    coroutine=duckduckgo_search_func
)
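
A usage sketch (requires network access); passing embedder=None skips the semantic re-ranking branch:

```python
import asyncio
from tools.duckduckgo_search import duckduckgo_search_tool

snippets = asyncio.run(duckduckgo_search_tool.ainvoke({
    "query": "capital of France",
    "original_query": "What is the capital of France?",
    "embedder": None,  # no re-ranking; returns the top-3 raw snippets
}))
print(snippets)
```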
tools/file_fetcher.py
ADDED
@@ -0,0 +1,42 @@
import os
import ssl
import aiohttp
import logging
from typing import Dict
from urllib.parse import urljoin

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(name)s - %(message)s')
logger = logging.getLogger(__name__)

async def fetch_task_file(task_id: str, question: str) -> Dict[str, bytes]:
    """
    Fetch a file associated with a task from the GAIA API.
    Returns a dictionary of file extensions to content.
    """
    results = {}
    base_url = "https://gaia-benchmark-api.hf.space/files/"  # Updated URL
    extensions = ["xlsx", "csv", "pdf", "txt", "mp3", "jpg", "png"]

    ssl_context = ssl.create_default_context()
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE

    async with aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(ssl_context=ssl_context),
        timeout=aiohttp.ClientTimeout(total=30)
    ) as session:
        for ext in extensions:
            file_url = urljoin(base_url, f"{task_id}/{task_id}.{ext}")
            try:
                async with session.get(file_url) as response:
                    if response.status == 200:
                        content = await response.read()
                        results[ext] = content
                        logger.info(f"Fetched {ext} for task {task_id}")
                    else:
                        logger.warning(f"No {ext} for task {task_id}: HTTP {response.status}")
            except Exception as e:
                logger.warning(f"Error fetching {ext} for task {task_id}: {str(e)}")

    return results
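
A sketch of probing the file endpoint for one task; the UUID below is a placeholder:

```python
import asyncio
from tools.file_fetcher import fetch_task_file

files = asyncio.run(fetch_task_file("00000000-0000-0000-0000-000000000000", "Sample question"))
for ext, content in files.items():
    print(f"{ext}: {len(content)} bytes")  # one entry per extension that returned HTTP 200
```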
tools/file_parser.py
CHANGED
@@ -1,36 +1,112 @@
-from langchain_core.tools import tool
-    file_path = f"temp_{task_id}.{file_type}"

import logging
import os
import pandas as pd
import PyPDF2
import speech_recognition as sr
import re
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from typing import Optional

logger = logging.getLogger(__name__)

class FileParserInput(BaseModel):
    task_id: str = Field(description="Task identifier")
    file_type: str = Field(description="File extension (e.g., pdf, csv)")
    file_path: str = Field(description="Path to the file")
    query: Optional[str] = Field(description="Query related to the file", default=None)

async def file_parser_func(task_id: str, file_type: str, file_path: str, query: Optional[str] = None) -> str:
    """
    Parse a file based on task_id, file_type, file_path, and query context.

    Args:
        task_id (str): Task identifier.
        file_type (str): File extension (e.g., 'xlsx', 'mp3', 'pdf').
        file_path (str): Path to the file.
        query (Optional[str]): Question context to guide parsing (e.g., for specific data extraction).

    Returns:
        str: Parsed content or error message.
    """
    try:
        if not os.path.exists(file_path):
            logger.warning(f"File not found: {file_path}")
            return "File not found"

        logger.info(f"Parsing file: {file_path} for task {task_id}")

        if file_type in ["xlsx", "xls"]:
            df = pd.read_excel(file_path, engine="openpyxl")
            if query and ("sum" in query.lower() or "total" in query.lower()):
                numerical_cols = df.select_dtypes(include=['float64', 'int64']).columns
                if numerical_cols.empty:
                    return "No numerical data found"
                if "food" in query.lower():
                    food_rows = df[df.apply(lambda x: "food" in str(x).lower(), axis=1)]
                    if not food_rows.empty and numerical_cols[0] in food_rows:
                        total = food_rows[numerical_cols[0]].sum()
                        return f"{total:.2f}"
                total = df[numerical_cols[0]].sum()
                return f"{total:.2f}"
            return df.to_string(index=False)

        elif file_type == "csv":
            df = pd.read_csv(file_path)
            if query and ("sum" in query.lower() or "total" in query.lower()):
                numerical_cols = df.select_dtypes(include=['float64', 'int64']).columns
                if numerical_cols.empty:
                    return "No numerical data found"
                total = df[numerical_cols[0]].sum()
                return f"{total:.2f}"
            return df.to_string(index=False)

        elif file_type == "pdf":
            with open(file_path, "rb") as f:
                reader = PyPDF2.PdfReader(f)
                text = "".join(page.extract_text() or "" for page in reader.pages)
            if query and "page number" in query.lower():
                pages = re.findall(r'\b\d+\b', text)
                return ", ".join(sorted(pages, key=int)) if pages else "No page numbers found"
            return text.strip() or "No text extracted"

        elif file_type == "txt":
            with open(file_path, "r", encoding="utf-8") as f:
                text = f.read()
            if query and "page number" in query.lower():
                pages = re.findall(r'\b\d+\b', text)
                return ", ".join(sorted(pages, key=int)) if pages else "No page numbers found"
            return text.strip()

        elif file_type == "mp3":
            # Note: sr.AudioFile natively reads WAV/AIFF/FLAC; an .mp3 input
            # may need conversion first (e.g., with pydub) before this step.
            recognizer = sr.Recognizer()
            with sr.AudioFile(file_path) as source:
                audio = recognizer.record(source)
            try:
                text = recognizer.recognize_google(audio)
                logger.debug(f"Transcribed audio: {text}")
                if query and "page number" in query.lower():
                    pages = re.findall(r'\b\d+\b', text)
                    return ", ".join(sorted(pages, key=int)) if pages else "No page numbers provided"
                return text
            except sr.UnknownValueError:
                logger.error("Could not understand audio")
                return "No text transcribed from audio"
            except Exception as e:
                logger.error(f"Audio parsing failed: {e}")
                return "Error transcribing audio"

        else:
            logger.warning(f"Unsupported file type: {file_type}")
            return f"Unsupported file type: {file_type}"

    except Exception as e:
        logger.error(f"Error parsing file for task {task_id}: {e}")
        return f"Error: {str(e)}"

file_parser_tool = StructuredTool.from_function(
    func=file_parser_func,
    name="file_parser_tool",
    args_schema=FileParserInput,
    coroutine=file_parser_func
)
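
A sketch of the query-aware CSV path, which sums the first numerical column when the question asks for a total:

```python
import asyncio
import tempfile
from tools.file_parser import file_parser_tool

with tempfile.NamedTemporaryFile(suffix=".csv", mode="w", delete=False) as tmp:
    tmp.write("item,cost\nfood,12.5\ndrink,7.5\n")
    path = tmp.name

result = asyncio.run(file_parser_tool.ainvoke({
    "task_id": "demo-task",  # placeholder id
    "file_type": "csv",
    "file_path": path,
    "query": "What is the total cost?",
}))
print(result)  # "20.00"
```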
tools/guest_info.py
CHANGED
@@ -1,20 +1,47 @@
-from langchain_core.tools import tool
-from retriever import load_guest_dataset

import logging
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from datasets import load_dataset
from rank_bm25 import BM25Okapi

logger = logging.getLogger(__name__)

class GuestInfoInput(BaseModel):
    query: str = Field(description="Query about guest information")

async def guest_info_func(query: str) -> str:
    """
    Retrieve guest information based on a query.

    Args:
        query (str): Query about guest information.

    Returns:
        str: Guest information or error message.
    """
    try:
        logger.info(f"Retrieving guest info for query: {query}")
        dataset = load_dataset("agents-course/unit3-invitees", split="train")
        logger.info(f"Loaded {len(dataset)} guests from Hugging Face dataset")

        documents = [f"{row['name']} {row['relation']}" for row in dataset]
        tokenized_docs = [doc.lower().split() for doc in documents]
        bm25 = BM25Okapi(tokenized_docs)

        tokenized_query = query.lower().split()
        scores = bm25.get_scores(tokenized_query)
        best_idx = int(scores.argmax())  # cast the numpy index for dataset row access

        if scores[best_idx] > 0:
            return f"Guest: {dataset[best_idx]['name']}, Relation: {dataset[best_idx]['relation']}"
        return "No matching guest found"
    except Exception as e:
        logger.error(f"Error retrieving guest info for query '{query}': {e}")
        return f"Error: {str(e)}"

guest_info_retriever_tool = StructuredTool.from_function(
    func=guest_info_func,
    name="guest_info_retriever_tool",
    args_schema=GuestInfoInput,
    coroutine=guest_info_func
)
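
A usage sketch; the first call downloads the agents-course/unit3-invitees dataset, and the guest name is illustrative:

```python
import asyncio
from tools.guest_info import guest_info_retriever_tool

# BM25 scores the query tokens against "name relation" strings from the dataset.
answer = asyncio.run(guest_info_retriever_tool.ainvoke({"query": "Who is the guest named Ada?"}))
print(answer)
```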
tools/hub_stats.py
CHANGED
@@ -1,17 +1,54 @@
import aiohttp
import ssl
import logging
from langchain_core.tools import tool
from tenacity import retry, stop_after_attempt, wait_exponential
from typing import Optional
import json
import os

logger = logging.getLogger(__name__)

@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10))
async def fetch_hf_models(author: str) -> Optional[list]:
    # The models endpoint returns a JSON array, hence the list annotation.
    url = f"https://huggingface.co/api/models?author={author}&sort=downloads&direction=-1&limit=1"
    ssl_context = ssl.create_default_context()
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, ssl=ssl_context) as response:
                response.raise_for_status()
                return await response.json()
    except aiohttp.ClientError as e:
        logger.error(f"Failed to fetch models for {author}: {e}")
        raise

@tool
async def hub_stats_tool(author: str) -> str:
    """
    Fetch the most downloaded model from a specific author on Hugging Face Hub.

    Args:
        author (str): Hugging Face author username.

    Returns:
        str: Model information or error message.
    """
    try:
        # Check local cache
        cache_file = f"temp/hf_cache_{author}.json"
        if os.path.exists(cache_file):
            with open(cache_file, "r") as f:
                models = json.load(f)
            logger.debug(f"Loaded cached models for {author}")
        else:
            models = await fetch_hf_models(author)
            os.makedirs("temp", exist_ok=True)
            with open(cache_file, "w") as f:
                json.dump(models, f)

        if models and isinstance(models, list):
            model = models[0]
            return f"The most downloaded model by {author} is {model['id']} with {model.get('downloads', 0):,} downloads."
        return f"No models found for author {author}."
    except Exception as e:
        logger.error(f"Error fetching models for {author}: {e}")
tools/image_parser.py
CHANGED
@@ -1,25 +1,43 @@
-from langchain_core.tools import tool
-reader = easyocr.Reader(['en'])

import logging
import os
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
import easyocr

logger = logging.getLogger(__name__)

class ImageParserInput(BaseModel):
    task_id: str = Field(description="Task identifier")
    file_path: str = Field(description="Path to the image file")

async def image_parser_func(task_id: str, file_path: str) -> str:
    """
    Parse text from an image file using OCR.

    Args:
        task_id (str): Task identifier.
        file_path (str): Path to the image file.

    Returns:
        str: Extracted text or error message.
    """
    try:
        if not os.path.exists(file_path):
            logger.warning(f"Image file not found: {file_path}")
            return "Image file not found"

        logger.info(f"Parsing image: {file_path} for task {task_id}")
        reader = easyocr.Reader(['en'], model_storage_directory='./cache')
        result = reader.readtext(file_path, detail=0)
        text = " ".join(result).strip()
        return text if text else "No text extracted from image"
    except Exception as e:
        logger.error(f"Error parsing image for task {task_id}: {e}")
        return f"Error: {str(e)}"

image_parser_tool = StructuredTool.from_function(
    func=image_parser_func,
    name="image_parser_tool",
    args_schema=ImageParserInput,
    coroutine=image_parser_func
)
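
A usage sketch, assuming an image exists at the placeholder path and the EasyOCR English model can be cached under ./cache:

```python
import asyncio
from tools.image_parser import image_parser_tool

text = asyncio.run(image_parser_tool.ainvoke({
    "task_id": "demo-task",      # placeholder id
    "file_path": "receipt.png",  # placeholder path
}))
print(text)
```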
tools/search.py
CHANGED
@@ -1,106 +1,103 @@
import logging
import os
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from typing import Optional, List
from serpapi import GoogleSearch

logger = logging.getLogger(__name__)

class SearchInput(BaseModel):
    query: str = Field(description="Search query")

async def search_func(query: str) -> List[str]:
    """
    Perform a web search using SerpAPI and return relevant snippets.

    Args:
        query (str): The search query to execute.

    Returns:
        List[str]: A list of search result snippets.
    """
    try:
        logger.info(f"Executing SerpAPI search for query: {query}")
        params = {
            "q": query,
            "api_key": os.getenv("SERPAPI_API_KEY"),
            "num": 10
        }
        search = GoogleSearch(params)
        results = search.get_dict().get("organic_results", [])
        return [result.get("snippet", "") for result in results if "snippet" in result]
    except Exception as e:
        logger.error(f"SerpAPI search failed for query '{query}': {e}")
        return []

search_tool = StructuredTool.from_function(
    func=search_func,
    name="search_tool",
    args_schema=SearchInput,
    coroutine=search_func
)

class MultiHopSearchInput(BaseModel):
    query: str = Field(description="Multi-hop search query")
    steps: int = Field(description="Number of search steps", ge=1, le=3)
    llm_client: Optional[object] = Field(description="LLM client", default=None)
    llm_type: Optional[str] = Field(description="LLM type", default="together")
    llm_model: Optional[str] = Field(description="LLM model", default="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free")

async def multi_hop_search_func(
    query: str,
    steps: int,
    llm_client: Optional[object] = None,
    llm_type: Optional[str] = "together",
    llm_model: Optional[str] = "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free"
) -> List[str]:
    """
    Perform a multi-hop web search using SerpAPI with iterative query refinement.

    Args:
        query (str): The initial multi-hop search query.
        steps (int): Number of search steps to perform (1 to 3).
        llm_client (Optional[object]): LLM client for query refinement.
        llm_type (Optional[str]): Type of LLM (e.g., 'together').
        llm_model (Optional[str]): LLM model name.

    Returns:
        List[str]: A list of search result snippets from all steps.
    """
    try:
        logger.info(f"Executing multi-hop search for query: {query}, steps: {steps}")
        results = []
        current_query = query

        for step in range(steps):
            logger.info(f"Multi-hop step {step + 1}: {current_query}")
            step_results = await search_func(current_query)
            results.extend(step_results)

            if step < steps - 1 and llm_client:
                prompt = f"Given the query '{current_query}' and results: {step_results[:3]}, generate a follow-up search query to refine or expand the search."
                messages = [
                    {"role": "system", "content": "Generate a single search query as a string."},
                    {"role": "user", "content": prompt}
                ]
                if llm_type == "together":
                    response = llm_client.chat.completions.create(
                        model=llm_model,
                        messages=messages,
                        max_tokens=50,
                        temperature=0.7
                    )
                    current_query = response.choices[0].message.content.strip()
                else:
                    logger.warning("LLM not configured for multi-hop refinement")
                    break

        return results[:5] if results else ["No results found"]
    except Exception as e:
        logger.error(f"Multi-hop search failed for query '{query}': {e}")
        return [f"Error: {str(e)}"]

multi_hop_search_tool = StructuredTool.from_function(
    func=multi_hop_search_func,
    name="multi_hop_search_tool",
    args_schema=MultiHopSearchInput,
    coroutine=multi_hop_search_func
)
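
A sketch of a two-step multi-hop call, assuming SERPAPI_API_KEY is set; with llm_client=None the query is simply not refined between steps:

```python
import asyncio
from tools.search import multi_hop_search_tool

snippets = asyncio.run(multi_hop_search_tool.ainvoke({
    "query": "population of the capital of France",
    "steps": 2,          # schema allows 1-3
    "llm_client": None,  # no LLM refinement between hops
}))
print(snippets)
```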
tools/weather_info.py
CHANGED
@@ -1,23 +1,50 @@
-async def weather_info_tool(location: str) -> str:

import aiohttp
import ssl
import logging
import os
from langchain_core.tools import tool
from tenacity import retry, stop_after_attempt, wait_exponential
from dotenv import load_dotenv

logger = logging.getLogger(__name__)
load_dotenv()

@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10))
async def fetch_weather(location: str, api_key: str) -> dict:
    url = f"http://api.openweathermap.org/data/2.5/weather?q={location}&appid={api_key}&units=metric"
    ssl_context = ssl.create_default_context()
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, ssl=ssl_context) as response:
                response.raise_for_status()
                return await response.json()
    except aiohttp.ClientError as e:
        logger.error(f"Failed to fetch weather for {location}: {e}")
        raise

@tool
async def weather_info_tool(location: str, query_type: str = "current") -> str:
    """
    Fetch weather information for a given location.

    Args:
        location (str): City or location name.
        query_type (str): Type of weather query ('current', 'forecast'; default: 'current').

    Returns:
        str: Weather information or error message.
    """
    try:
        api_key = os.getenv("OPENWEATHERMAP_API_KEY")
        if not api_key:
            logger.error("OPENWEATHERMAP_API_KEY not set")
            return "Weather unavailable: API key missing"

        if query_type != "current":
            logger.warning(f"Query type '{query_type}' not supported; using current weather")
            query_type = "current"

        response = await fetch_weather(location, api_key)
        if response.get("cod") == 200:
            condition = response["weather"][0]["description"]
            temp = response["main"]["temp"]