onisj commited on
Commit
751d628
·
1 Parent(s): 853221a

feat(tools): add more tool to extend the functionaily of jarvis

Browse files
.gitignore CHANGED
@@ -41,6 +41,7 @@ coverage.xml
41
  *.py,cover
42
  .tox/
43
  .pytest_cache/
 
44
 
45
  # Logs and temporary files
46
  *.log
 
41
  *.py,cover
42
  .tox/
43
  .pytest_cache/
44
+ cache/
45
 
46
  # Logs and temporary files
47
  *.log
README.md CHANGED
@@ -74,6 +74,10 @@ jarvis_gaia_agent/
74
  - `SERPAPI_API_KEY`: SERPAPI key for web searches.
75
  - `OPENWEATHERMAP_API_KEY`: OpenWeatherMap key for weather queries.
76
  - `SPACE_ID`: `onisj/jarvis_gaia_agent`.
 
 
 
 
77
 
78
  ## Setup and Local Testing
79
 
 
74
  - `SERPAPI_API_KEY`: SERPAPI key for web searches.
75
  - `OPENWEATHERMAP_API_KEY`: OpenWeatherMap key for weather queries.
76
  - `SPACE_ID`: `onisj/jarvis_gaia_agent`.
77
+ - Install dependencies:
78
+ ```bash
79
+ pip install -r requirements.txt
80
+ ```
81
 
82
  ## Setup and Local Testing
83
 
app.py CHANGED
@@ -3,6 +3,7 @@ import json
3
  import logging
4
  import asyncio
5
  import aiohttp
 
6
  import nest_asyncio
7
  import requests
8
  import pandas as pd
@@ -10,18 +11,25 @@ from typing import Dict, Any, List
10
  from langchain_core.prompts import ChatPromptTemplate
11
  from langchain_core.messages import SystemMessage, HumanMessage
12
  from langgraph.graph import StateGraph, END
 
13
  from sentence_transformers import SentenceTransformer
14
  import gradio as gr
15
  from dotenv import load_dotenv
16
  from huggingface_hub import InferenceClient
17
  from transformers import AutoTokenizer, AutoModelForCausalLM
18
  import together
19
- from state import JARVISState
20
- from tools import (
21
- search_tool, multi_hop_search_tool, file_parser_tool, image_parser_tool,
22
- calculator_tool, document_retriever_tool, duckduckgo_search_tool,
23
- weather_info_tool, hub_stats_tool, guest_info_retriever_tool
24
- )
 
 
 
 
 
 
25
 
26
  # Setup logging
27
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -33,10 +41,10 @@ nest_asyncio.apply()
33
  # Load environment variables
34
  load_dotenv()
35
  SPACE_ID = os.getenv("SPACE_ID", "onisj/jarvis_gaia_agent")
36
- GAIA_API_URL = "https://agents-course-unit4-scoring.hf.space"
37
- GAIA_FILE_URL = f"{GAIA_API_URL}/files/"
38
  TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
39
  HF_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
 
40
 
41
  # Verify environment variables
42
  if not SPACE_ID:
@@ -45,6 +53,8 @@ if not HF_API_TOKEN:
45
  raise ValueError("HUGGINGFACEHUB_API_TOKEN not set")
46
  if not TOGETHER_API_KEY:
47
  raise ValueError("TOGETHER_API_KEY not set")
 
 
48
  logger.info(f"SPACE_ID: {SPACE_ID}")
49
 
50
  # Model configuration
@@ -56,23 +66,20 @@ HF_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
56
 
57
  # Initialize LLM clients
58
  def initialize_llm():
59
- # Try Together AI models
60
  for model in TOGETHER_MODELS:
61
  try:
62
  together.api_key = TOGETHER_API_KEY
63
  client = together.Together()
64
- # Test the model
65
  response = client.chat.completions.create(
66
  model=model,
67
  messages=[{"role": "user", "content": "Test"}],
68
  max_tokens=10
69
  )
70
  logger.info(f"Initialized Together AI model: {model}")
71
- return client, "together"
72
  except Exception as e:
73
  logger.warning(f"Failed to initialize Together AI model {model}: {e}")
74
 
75
- # Fallback to Hugging Face Inference API
76
  try:
77
  client = InferenceClient(
78
  model=HF_MODEL,
@@ -80,381 +87,355 @@ def initialize_llm():
80
  timeout=30
81
  )
82
  logger.info(f"Initialized Hugging Face Inference API model: {HF_MODEL}")
83
- return client, "hf_api"
84
  except Exception as e:
85
  logger.warning(f"Failed to initialize HF Inference API: {e}")
86
 
87
- # Fallback to local Hugging Face model
88
  try:
89
  tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, token=HF_API_TOKEN)
90
  model = AutoModelForCausalLM.from_pretrained(HF_MODEL, token=HF_API_TOKEN, device_map="auto")
91
  logger.info(f"Initialized local Hugging Face model: {HF_MODEL}")
92
- return (model, tokenizer), "hf_local"
93
  except Exception as e:
94
  logger.error(f"Failed to initialize local HF model: {e}")
95
  raise Exception("No LLM could be initialized")
96
 
97
- llm_client, llm_type = initialize_llm()
98
 
99
  # Initialize embedder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  try:
101
- embedder = SentenceTransformer("all-MiniLM-L6-v2")
102
- logger.info("Sentence transformer initialized")
103
  except Exception as e:
104
  logger.error(f"Failed to initialize embedder: {e}")
105
  embedder = None
106
 
107
- # Download file with local fallback
108
- async def download_file(task_id: str, ext: str) -> str | None:
109
- try:
110
- url = f"{GAIA_FILE_URL}{task_id}.{ext}"
111
- async with aiohttp.ClientSession() as session:
112
- async with session.get(url, timeout=10) as resp:
113
- logger.info(f"GAIA API test for task {task_id} with .{ext}: HTTP {resp.status}")
114
- if resp.status == 200:
115
- os.makedirs("temp", exist_ok=True)
116
- file_path = f"temp/{task_id}.{ext}"
117
- with open(file_path, "wb") as f:
118
- f.write(await resp.read())
119
- return file_path
120
- except Exception as e:
121
- logger.warning(f"File download failed for {task_id}.{ext}: {e}")
122
- local_path = f"temp/{task_id}.{ext}"
123
- if os.path.exists(local_path):
124
- logger.info(f"Using local file: {local_path}")
125
- return local_path
126
- return None
 
 
 
 
 
 
 
127
 
128
  # Parse question to select tools
129
  async def parse_question(state: JARVISState) -> JARVISState:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  try:
131
- question = state["question"]
132
- task_id = state["task_id"]
 
 
 
 
 
 
133
  tools_needed = ["search_tool"]
134
-
 
135
  if llm_client:
136
  prompt = ChatPromptTemplate.from_messages([
137
  SystemMessage(content="""Select tools from: ['search_tool', 'multi_hop_search_tool', 'file_parser_tool', 'image_parser_tool', 'calculator_tool', 'document_retriever_tool', 'duckduckgo_search_tool', 'weather_info_tool', 'hub_stats_tool', 'guest_info_retriever_tool'].
138
- Return JSON list, e.g., ["search_tool", "file_parser_tool"].
 
 
139
  Rules:
140
- - Always include "search_tool" unless purely computational.
141
- - Use "multi_hop_search_tool" for complex queries (over 20 words or requiring multiple steps).
142
- - Use "file_parser_tool" for data, tables, or Excel.
143
- - Use "image_parser_tool" for images/videos.
144
- - Use "calculator_tool" for math calculations.
145
- - Use "document_retriever_tool" for documents/PDFs.
146
- - Use "duckduckgo_search_tool" for additional search capability.
147
- - Use "weather_info_tool" for weather-related queries.
148
- - Use "hub_stats_tool" for Hugging Face Hub queries.
149
- - Use "guest_info_retriever_tool" for guest-related queries.
 
150
  - Output ONLY valid JSON."""),
151
  HumanMessage(content=f"Query: {question}")
152
  ])
153
- try:
154
- if llm_type == "hf_local":
155
- model, tokenizer = llm_client
156
- inputs = tokenizer.apply_chat_template(
157
- [{"role": "system", "content": prompt[0].content}, {"role": "user", "content": prompt[1].content}],
158
- return_tensors="pt"
159
- ).to(model.device)
160
- outputs = model.generate(inputs, max_new_tokens=512, temperature=0.7)
161
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
162
- tools_needed = json.loads(response.strip())
163
- elif llm_type == "together":
164
- response = llm_client.chat.completions.create(
165
- model=llm_client.model,
166
- messages=[
167
- {"role": "system", "content": prompt[0].content},
168
- {"role": "user", "content": prompt[1].content}
169
- ],
170
- max_tokens=512,
171
- temperature=0.7
172
- )
173
- tools_needed = json.loads(response.choices[0].message.content.strip())
174
- else: # hf_api
175
- response = llm_client.chat.completions.create(
176
- model=HF_MODEL,
177
- messages=[
178
- {"role": "system", "content": prompt[0].content},
179
- {"role": "user", "content": prompt[1].content}
180
- ],
181
- max_tokens=512,
182
- temperature=0.7
183
- )
184
- tools_needed = json.loads(response.choices[0].message.content.strip())
185
-
186
- valid_tools = {
187
- "search_tool", "multi_hop_search_tool", "file_parser_tool", "image_parser_tool",
188
- "calculator_tool", "document_retriever_tool", "duckduckgo_search_tool",
189
- "weather_info_tool", "hub_stats_tool", "guest_info_retriever_tool"
190
- }
191
- tools_needed = [tool for tool in tools_needed if tool in valid_tools]
192
- except Exception as e:
193
- logger.warning(f"Task {task_id} tool selection failed: {e}")
194
- state["error"] = f"Tool selection failed: {str(e)}"
195
-
196
- # Keyword-based fallback
197
- question_lower = question.lower()
198
- if any(word in question_lower for word in ["image", "video", "picture"]):
199
- tools_needed.append("image_parser_tool")
200
- if any(word in question_lower for word in ["data", "table", "excel", ".txt", ".csv", ".xlsx"]):
201
- tools_needed.append("file_parser_tool")
202
- if any(word in question_lower for word in ["calculate", "math", "sum", "average", "total"]):
203
- tools_needed.append("calculator_tool")
204
- if any(word in question_lower for word in ["document", "pdf", "report", "menu"]):
205
- tools_needed.append("document_retriever_tool")
206
- if any(word in question_lower for word in ["weather", "temperature"]):
207
- tools_needed.append("weather_info_tool")
208
- if any(word in question_lower for word in ["model", "huggingface", "dataset"]):
209
- tools_needed.append("hub_stats_tool")
210
- if any(word in question_lower for word in ["guest", "name", "relation", "person"]):
211
- tools_needed.append("guest_info_retriever_tool")
212
- if len(question.split()) > 20 or "multiple" in question_lower:
213
- tools_needed.append("multi_hop_search_tool")
214
- if any(word in question_lower for word in ["search", "wikipedia", "online"]):
215
- tools_needed.append("duckduckgo_search_tool")
216
-
217
- # Check file availability
218
- for ext in ["txt", "csv", "xlsx", "jpg", "pdf"]:
219
- file_path = await download_file(task_id, ext)
220
- if file_path:
221
- if ext in ["txt", "csv", "xlsx"] and "file_parser_tool" not in tools_needed:
 
 
 
 
 
 
 
 
 
 
 
222
  tools_needed.append("file_parser_tool")
223
- if ext == "jpg" and "image_parser_tool" not in tools_needed:
224
  tools_needed.append("image_parser_tool")
225
- if ext == "pdf" and "document_retriever_tool" not in tools_needed:
226
  tools_needed.append("document_retriever_tool")
227
- state["metadata"] = state.get("metadata", {}) | {"file_ext": ext, "file_path": file_path}
228
- break
229
-
230
- state["tools_needed"] = list(set(tools_needed))
231
- logger.info(f"Task {task_id}: Selected tools: {tools_needed}")
232
  return state
233
  except Exception as e:
234
- logger.error(f"Error parsing task {task_id}: {e}")
235
  state["error"] = f"Parse question failed: {str(e)}"
236
  state["tools_needed"] = ["search_tool"]
237
  return state
238
 
239
  # Tool dispatcher
240
  async def tool_dispatcher(state: JARVISState) -> JARVISState:
 
241
  try:
242
- updated_state = state.copy()
243
- file_type = "jpg" if "image" in state["question"].lower() else "txt"
244
- if any(word in state["question"].lower() for word in ["menu", "report"]):
245
- file_type = "pdf"
246
- elif "data" in state["question"].lower():
247
- file_type = "xlsx"
248
-
249
- for tool in updated_state["tools_needed"]:
250
- try:
251
- if tool == "search_tool":
252
- result = search_tool(updated_state["question"])
253
- updated_state["web_results"].extend([str(r) for r in result])
254
- elif tool == "multi_hop_search_tool":
255
- result = await multi_hop_search_tool.ainvoke({"query": updated_state["question"], "steps": 3, "llm_client": llm_client, "llm_type": llm_type})
256
- updated_state["multi_hop_results"].extend([r["content"] for r in result])
257
- await asyncio.sleep(2)
258
- elif tool == "file_parser_tool":
259
- for ext in ["txt", "csv", "xlsx"]:
260
- file_path = await download_file(updated_state["task_id"], ext)
261
- if file_path:
262
- result = file_parser_tool(file_path)
263
- updated_state["file_results"] = str(result)
264
- break
265
- elif tool == "image_parser_tool":
266
- file_path = await download_file(updated_state["task_id"], "jpg")
267
- if file_path:
268
- result = image_parser_tool(file_path)
269
- updated_state["image_results"] = str(result)
270
- elif tool == "calculator_tool":
271
- result = calculator_tool(updated_state["question"])
272
- updated_state["calculation_results"] = str(result)
273
- elif tool == "document_retriever_tool":
274
- file_path = await download_file(updated_state["task_id"], "pdf")
275
- if file_path:
276
- result = document_retriever_tool({"task_id": updated_state["task_id"], "query": updated_state["question"], "file_type": "pdf"})
277
- updated_state["document_results"] = str(result)
278
- elif tool == "duckduckgo_search_tool":
279
- result = duckduckgo_search_tool(updated_state["question"])
280
- updated_state["web_results"].append(str(result))
281
- elif tool == "weather_info_tool":
282
- location = updated_state["question"].split("weather in ")[1].split()[0] if "weather in" in updated_state["question"].lower() else "Unknown"
283
- result = weather_info_tool({"location": location})
284
- updated_state["web_results"].append(str(result))
285
- elif tool == "hub_stats_tool":
286
- author = updated_state["question"].split("by ")[1].split()[0] if "by" in updated_state["question"].lower() else "Unknown"
287
- result = hub_stats_tool({"author": author})
288
- updated_state["web_results"].append(str(result))
289
- elif tool == "guest_info_retriever_tool":
290
- query = updated_state["question"].split("about ")[1] if "about" in updated_state["question"].lower() else updated_state["question"]
291
- result = guest_info_retriever_tool({"query": query})
292
- updated_state["web_results"].append(str(result))
293
- updated_state["metadata"] = updated_state.get("metadata", {}) | {f"{tool}_executed": True}
294
- except Exception as e:
295
- logger.warning(f"Error in tool {tool} for task {updated_state['task_id']}: {str(e)}")
296
- updated_state["error"] = f"Tool {tool} failed: {str(e)}"
297
- updated_state["metadata"] = updated_state.get("metadata", {}) | {f"{tool}_error": str(e)}
298
-
299
- logger.info(f"Task {updated_state['task_id']}: Tool results: {updated_state}")
300
- return updated_state
301
- except Exception as e:
302
- logger.error(f"Tool dispatch failed for task {state['task_id']}: {e}")
303
- updated_state["error"] = f"Tool dispatch failed: {str(e)}"
304
- return updated_state
305
-
306
- # Reasoning
307
- async def reasoning(state: JARVISState) -> Dict[str, Any]:
308
- try:
309
- prompt = ChatPromptTemplate.from_messages([
310
- SystemMessage(content="""Provide ONLY the exact answer (e.g., '90', 'HUE'). For USD, use two decimal places (e.g., '1234.00'). For lists, use comma-separated values (e.g., 'Smith, Lee'). For IOC codes, use three-letter codes (e.g., 'ARG'). No explanations or conversational text."""),
311
- HumanMessage(content="""Task: {task_id}
312
- Question: {question}
313
- Web results: {web_results}
314
- Multi-hop results: {multi_hop_results}
315
- File results: {file_results}
316
- Image results: {image_results}
317
- Calculation results: {calculation_results}
318
- Document results: {document_results}""")
319
- ])
320
- messages = [
321
- {"role": "system", "content": prompt[0].content},
322
- {"role": "user", "content": prompt[1].content.format(
323
- task_id=state["task_id"],
324
- question=state["question"],
325
- web_results="\n".join(state["web_results"]),
326
- multi_hop_results="\n".join(state["multi_hop_results"]),
327
- file_results=state["file_results"],
328
- image_results=state["image_results"],
329
- calculation_results=state["calculation_results"],
330
- document_results=state["document_results"]
331
- )}
332
- ]
333
- for attempt in range(3):
334
  try:
335
- if llm_type == "hf_local":
336
- model, tokenizer = llm_client
337
- inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
338
- outputs = model.generate(inputs, max_new_tokens=512, temperature=0.7)
339
- answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
340
- elif llm_type == "together":
341
- response = llm_client.chat.completions.create(
342
- model=llm_client.model,
343
- messages=messages,
344
- max_tokens=512,
345
- temperature=0.7
346
- )
347
- answer = response.choices[0].message.content.strip()
348
- else: # hf_api
349
- response = llm_client.chat.completions.create(
350
- model=HF_MODEL,
351
- messages=messages,
352
- max_tokens=512,
353
- temperature=0.7
354
- )
355
- answer = response.choices[0].message.content.strip()
356
-
357
- # Format answer
358
- if "USD" in state["question"].lower():
359
- try:
360
- answer = f"{float(answer):.2f}"
361
- except ValueError:
362
- pass
363
- if "before and after" in state["question"].lower():
364
- answer = answer.replace(" and ", ", ")
365
- if "IOC code" in state["question"].lower():
366
- answer = answer.upper()[:3]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
 
368
- logger.info(f"Task {state['task_id']}: Answer: {answer}")
369
- return {"answer": answer}
370
  except Exception as e:
371
- logger.warning(f"LLM retry {attempt + 1}/3 for task {state['task_id']}: {e}")
372
- await asyncio.sleep(2)
373
- state["error"] = "LLM failed after retries"
374
- return {"answer": "Error: LLM failed after retries"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
  except Exception as e:
376
- logger.error(f"Reasoning failed for task {state['task_id']}: {e}")
377
- state["error"] = f"Reasoning failed: {str(e)}"
378
- return {"answer": f"Error: {str(e)}"}
379
-
380
- # Router
381
- def router(state: JARVISState) -> str:
382
- if state["tools_needed"]:
383
- return "tool_dispatcher"
384
- return "reasoning"
385
 
386
  # Define StateGraph
387
  workflow = StateGraph(JARVISState)
388
- workflow.add_node("parse", parse_question)
389
  workflow.add_node("tool_dispatcher", tool_dispatcher)
390
- workflow.add_node("reasoning", reasoning)
391
- workflow.set_entry_point("parse")
392
- workflow.add_conditional_edges(
393
- "parse",
394
- router,
395
- {
396
- "tool_dispatcher": "tool_dispatcher",
397
- "reasoning": "reasoning"
398
- }
399
- )
400
- workflow.add_edge("tool_dispatcher", "reasoning")
401
- workflow.add_edge("reasoning", END)
402
  graph = workflow.compile()
403
 
404
  # Agent class
405
  class JARVISAgent:
406
  def __init__(self):
407
- self.state = JARVISState(
408
- task_id="",
409
- question="",
410
- tools_needed=[],
411
- web_results=[],
412
- file_results="",
413
- image_results="",
414
- calculation_results="",
415
- document_results="",
416
- multi_hop_results=[],
417
- messages=[],
418
- answer="",
419
- results_table=[],
420
- status_output="",
421
- error=None,
422
- metadata={}
423
- )
424
  logger.info("JARVISAgent initialized.")
425
 
426
  async def process_question(self, task_id: str, question: str) -> str:
427
- state = JARVISState(
428
- task_id=task_id,
429
- question=question,
430
- tools_needed=["search_tool"],
431
- web_results=[],
432
- file_results="",
433
- image_results="",
434
- calculation_results="",
435
- document_results="",
436
- multi_hop_results=[],
437
- messages=[HumanMessage(content=question)],
438
- answer="",
439
- results_table=[],
440
- status_output="",
441
- error=None,
442
- metadata={}
443
- )
444
  try:
445
  result = await graph.ainvoke(state)
446
- answer = result["answer"] or "Unknown"
447
- logger.info(f"Task {task_id}: Final answer: {answer}")
448
- self.state.results_table.append({"Task ID": task_id, "Question": question, "Answer": answer})
449
- self.state.metadata = self.state.get("metadata", {}) | {"last_task": task_id, "answer": answer}
450
  return answer
451
  except Exception as e:
452
  logger.error(f"Error processing task {task_id}: {e}")
453
- self.state.results_table.append({"Task ID": task_id, "Question": question, "Answer": f"Error: {e}"})
454
- self.state.error = f"Task {task_id} failed: {str(e)}"
455
  return f"Error: {str(e)}"
456
  finally:
457
- for ext in ["txt", "csv", "xlsx", "jpg", "pdf"]:
458
  file_path = f"temp/{task_id}.{ext}"
459
  if os.path.exists(file_path):
460
  try:
@@ -466,25 +447,26 @@ class JARVISAgent:
466
  async def process_all_questions(self, profile: gr.OAuthProfile | None):
467
  if not profile:
468
  logger.error("User not logged in.")
469
- self.state.status_output = "Please Login to Hugging Face."
470
- return pd.DataFrame(self.state.results_table), self.state.status_output
471
 
472
- username = f"{profile.username}"
473
  logger.info(f"User logged in: {username}")
474
  questions_url = f"{GAIA_API_URL}/questions"
475
  submit_url = f"{GAIA_API_URL}/submit"
476
  agent_code = f"https://huggingface.co/spaces/{SPACE_ID}/tree/main"
477
 
478
  try:
479
- response = requests.get(questions_url, timeout=15)
480
- response.raise_for_status()
481
- questions = response.json()
 
482
  logger.info(f"Fetched {len(questions)} questions.")
483
  except Exception as e:
484
  logger.error(f"Error fetching questions: {e}")
485
- self.state.status_output = f"Error fetching questions: {e}"
486
- self.state.error = f"Fetch questions failed: {str(e)}"
487
- return pd.DataFrame(self.state.results_table), self.state.status_output
488
 
489
  answers_payload = []
490
  for item in questions:
@@ -498,33 +480,34 @@ class JARVISAgent:
498
 
499
  if not answers_payload:
500
  logger.error("No answers generated.")
501
- self.state.status_output = "No answers to submit."
502
- self.state.error = "No answers generated"
503
- return pd.DataFrame(self.state.results_table), self.state.status_output
504
 
505
  submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
506
  try:
507
- response = requests.post(submit_url, json=submission_data, timeout=120)
508
- response.raise_for_status()
509
- result_data = response.json()
510
- self.state.status_output = (
 
511
  f"Submission Successful!\n"
512
  f"User: {result_data.get('username')}\n"
513
  f"Overall Score: {result_data.get('score', 'N/A')}% "
514
  f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
515
  f"Message: {result_data.get('message', 'No message received.')}"
516
  )
517
- self.state.metadata = self.state.get("metadata", {}) | {"submission_score": result_data.get('score', 'N/A')}
518
  except Exception as e:
519
  logger.error(f"Submission failed: {e}")
520
- self.state.status_output = f"Submission Failed: {e}"
521
- self.state.error = f"Submission failed: {str(e)}"
522
 
523
- return pd.DataFrame(self.state.results_table), self.state.status_output
524
 
525
  # Gradio interface
526
  with gr.Blocks() as demo:
527
- gr.Markdown("# Evolved JARVIS GAIA Agent")
528
  gr.Markdown(
529
  """
530
  **Instructions:**
@@ -539,7 +522,6 @@ with gr.Blocks() as demo:
539
  )
540
  with gr.Row():
541
  gr.LoginButton(value="Login to Hugging Face")
542
- # Removed gr.LogoutButton due to deprecation
543
  run_button = gr.Button("Run Evaluation & Submit All Answers")
544
  status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
545
  results_table = gr.DataFrame(label="Questions and Answers", wrap=True, headers=["Task ID", "Question", "Answer"])
 
3
  import logging
4
  import asyncio
5
  import aiohttp
6
+ import ssl
7
  import nest_asyncio
8
  import requests
9
  import pandas as pd
 
11
  from langchain_core.prompts import ChatPromptTemplate
12
  from langchain_core.messages import SystemMessage, HumanMessage
13
  from langgraph.graph import StateGraph, END
14
+ import torch
15
  from sentence_transformers import SentenceTransformer
16
  import gradio as gr
17
  from dotenv import load_dotenv
18
  from huggingface_hub import InferenceClient
19
  from transformers import AutoTokenizer, AutoModelForCausalLM
20
  import together
21
+ from state import JARVISState, validate_state, reset_state
22
+ from tools.answer_generator import generate_answer, preprocess_question
23
+ from tools.file_fetcher import fetch_task_file
24
+ from tools.search import search_tool, multi_hop_search_tool
25
+ from tools.file_parser import file_parser_tool
26
+ from tools.image_parser import image_parser_tool
27
+ from tools.calculator import calculator_tool
28
+ from tools.document_retriever import document_retriever_tool
29
+ from tools.duckduckgo_search import duckduckgo_search_tool
30
+ from tools.weather_info import weather_info_tool
31
+ from tools.hub_stats import hub_stats_tool
32
+ from tools.guest_info import guest_info_retriever_tool
33
 
34
  # Setup logging
35
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
41
  # Load environment variables
42
  load_dotenv()
43
  SPACE_ID = os.getenv("SPACE_ID", "onisj/jarvis_gaia_agent")
44
+ GAIA_API_URL = "https://agents-course-unit4-api-1.hf.space/api"
 
45
  TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
46
  HF_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
47
+ OPENWEATHERMAP_API_KEY = os.getenv("OPENWEATHERMAP_API_KEY")
48
 
49
  # Verify environment variables
50
  if not SPACE_ID:
 
53
  raise ValueError("HUGGINGFACEHUB_API_TOKEN not set")
54
  if not TOGETHER_API_KEY:
55
  raise ValueError("TOGETHER_API_KEY not set")
56
+ if not OPENWEATHERMAP_API_KEY:
57
+ logger.warning("OPENWEATHERMAP_API_KEY not set; weather_info_tool may fail")
58
  logger.info(f"SPACE_ID: {SPACE_ID}")
59
 
60
  # Model configuration
 
66
 
67
  # Initialize LLM clients
68
  def initialize_llm():
 
69
  for model in TOGETHER_MODELS:
70
  try:
71
  together.api_key = TOGETHER_API_KEY
72
  client = together.Together()
 
73
  response = client.chat.completions.create(
74
  model=model,
75
  messages=[{"role": "user", "content": "Test"}],
76
  max_tokens=10
77
  )
78
  logger.info(f"Initialized Together AI model: {model}")
79
+ return client, "together", model
80
  except Exception as e:
81
  logger.warning(f"Failed to initialize Together AI model {model}: {e}")
82
 
 
83
  try:
84
  client = InferenceClient(
85
  model=HF_MODEL,
 
87
  timeout=30
88
  )
89
  logger.info(f"Initialized Hugging Face Inference API model: {HF_MODEL}")
90
+ return client, "hf_api", HF_MODEL
91
  except Exception as e:
92
  logger.warning(f"Failed to initialize HF Inference API: {e}")
93
 
 
94
  try:
95
  tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, token=HF_API_TOKEN)
96
  model = AutoModelForCausalLM.from_pretrained(HF_MODEL, token=HF_API_TOKEN, device_map="auto")
97
  logger.info(f"Initialized local Hugging Face model: {HF_MODEL}")
98
+ return (model, tokenizer), "hf_local", HF_MODEL
99
  except Exception as e:
100
  logger.error(f"Failed to initialize local HF model: {e}")
101
  raise Exception("No LLM could be initialized")
102
 
103
+ llm_client, llm_type, llm_model = initialize_llm()
104
 
105
  # Initialize embedder
106
+ _embedder = None
107
+
108
+ def get_embedder():
109
+ global _embedder
110
+ if _embedder is None:
111
+ try:
112
+ device = "cuda" if torch.cuda.is_available() else "cpu"
113
+ _embedder = SentenceTransformer(
114
+ "all-MiniLM-L6-v2",
115
+ device=device,
116
+ cache_folder="./cache"
117
+ )
118
+ logger.info(f"SentenceTransformer initialized on {device.upper()}")
119
+ except Exception as e:
120
+ logger.error(f"Failed to initialize SentenceTransformer: {e}")
121
+ raise RuntimeError(f"Embedder initialization failed: {e}")
122
+ return _embedder
123
+
124
  try:
125
+ embedder = get_embedder()
 
126
  except Exception as e:
127
  logger.error(f"Failed to initialize embedder: {e}")
128
  embedder = None
129
 
130
+ # Log device
131
+ device = "cuda" if torch.cuda.is_available() else "cpu"
132
+ logger.info(f"Using device: {device}")
133
+
134
+ # HTTP session with SSL handling
135
+ async def create_http_session():
136
+ ssl_context = ssl.create_default_context()
137
+ ssl_context.check_hostname = False
138
+ ssl_context.verify_mode = ssl.CERT_NONE
139
+ return aiohttp.ClientSession(
140
+ connector=aiohttp.TCPConnector(ssl=ssl_context),
141
+ timeout=aiohttp.ClientTimeout(total=30)
142
+ )
143
+
144
+ # Tool registration
145
+ tools = {
146
+ "search_tool": search_tool,
147
+ "multi_hop_search_tool": multi_hop_search_tool,
148
+ "file_parser_tool": file_parser_tool,
149
+ "image_parser_tool": image_parser_tool,
150
+ "calculator_tool": calculator_tool,
151
+ "document_retriever_tool": document_retriever_tool,
152
+ "duckduckgo_search_tool": duckduckgo_search_tool,
153
+ "weather_info_tool": weather_info_tool,
154
+ "hub_stats_tool": hub_stats_tool,
155
+ "guest_info_retriever_tool": guest_info_retriever_tool,
156
+ }
157
 
158
  # Parse question to select tools
159
  async def parse_question(state: JARVISState) -> JARVISState:
160
+ """
161
+ Parse the question to select appropriate tools using LLM with retries, preprocess the question, and integrate file-based tools.
162
+
163
+ Args:
164
+ state (JARVISState): The input state containing task_id, question.
165
+
166
+ Returns:
167
+ JARVISState: Updated state with selected tools_needed and metadata.
168
+ """
169
+ state = validate_state(state)
170
+ task_id = state["task_id"]
171
+ question = state["question"]
172
+
173
+ logger.info(f"Task {task_id} Parsing question: {question}")
174
  try:
175
+ # Preprocess question
176
+ processed_question = await preprocess_question(question)
177
+ if processed_question != question:
178
+ logger.info(f"Task {task_id} Preprocessed question: {processed_question}")
179
+ state["question"] = processed_question
180
+ question = processed_question
181
+
182
+ # Default to search_tool
183
  tools_needed = ["search_tool"]
184
+
185
+ # LLM-based tool selection
186
  if llm_client:
187
  prompt = ChatPromptTemplate.from_messages([
188
  SystemMessage(content="""Select tools from: ['search_tool', 'multi_hop_search_tool', 'file_parser_tool', 'image_parser_tool', 'calculator_tool', 'document_retriever_tool', 'duckduckgo_search_tool', 'weather_info_tool', 'hub_stats_tool', 'guest_info_retriever_tool'].
189
+
190
+ Return a JSON list of all relevant tools, e.g., ["search_tool", "duckduckgo_search_tool"].
191
+
192
  Rules:
193
+ - Include "search_tool" for web-based questions unless purely computational or file-based.
194
+ - Include "multi_hop_search_tool" for questions with >20 words or requiring multiple steps.
195
+ - Include "file_parser_tool" for 'data', 'table', 'excel', 'csv', 'txt', 'mp3', or file extensions.
196
+ - Include "image_parser_tool" for 'image', 'video', 'picture', or 'painting'.
197
+ - Include "calculator_tool" for 'calculate', 'math', 'sum', 'average', 'total', or numerical operations.
198
+ - Include "document_retriever_tool" for 'document', 'pdf', 'report', or 'paper'.
199
+ - Include "duckduckgo_search_tool" for 'search', 'wikipedia', 'online', or general knowledge.
200
+ - Include "weather_info_tool" for 'weather', 'temperature', or 'forecast'.
201
+ - Include "hub_stats_tool" for 'model', 'huggingface', or 'dataset'.
202
+ - Include "guest_info_retriever_tool" for 'guest', 'name', 'relation', or 'person'.
203
+ - Select multiple tools if the question spans multiple domains (e.g., web and file).
204
  - Output ONLY valid JSON."""),
205
  HumanMessage(content=f"Query: {question}")
206
  ])
207
+ messages = prompt.format_messages()
208
+
209
+ for attempt in range(3): # Retry up to 3 times
210
+ try:
211
+ formatted_messages = [
212
+ {"role": "system" if isinstance(m, SystemMessage) else "user", "content": m.content}
213
+ for m in messages
214
+ ]
215
+ if llm_type == "hf_local":
216
+ model, tokenizer = llm_client
217
+ inputs = tokenizer.apply_chat_template(
218
+ formatted_messages,
219
+ return_tensors="pt"
220
+ ).to(model.device)
221
+ outputs = model.generate(inputs, max_new_tokens=100, temperature=0.5)
222
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
223
+ elif llm_type == "together":
224
+ response = llm_client.chat.completions.create(
225
+ model=llm_model,
226
+ messages=formatted_messages,
227
+ max_tokens=100,
228
+ temperature=0.5
229
+ )
230
+ response = response.choices[0].message.content.strip()
231
+ else: # hf_api
232
+ response = llm_client.chat.completions.create(
233
+ messages=formatted_messages,
234
+ max_tokens=100,
235
+ temperature=0.5
236
+ )
237
+ response = response.choices[0].message.content.strip()
238
+
239
+ logger.info(f"Task {task_id} LLM tool selection response: {response}")
240
+ try:
241
+ tools_needed = json.loads(response)
242
+ if isinstance(tools_needed, list) and all(isinstance(t, str) and t in tools for t in tools_needed):
243
+ break # Valid response, exit retry loop
244
+ else:
245
+ raise ValueError("Invalid tool list format")
246
+ except json.JSONDecodeError as e:
247
+ logger.warning(f"Task {task_id}: Invalid JSON (attempt {attempt + 1}): {e}")
248
+ if attempt == 2:
249
+ tools_needed = ["search_tool"] # Fallback after retries
250
+ except Exception as e:
251
+ logger.warning(f"Task {task_id} Tool selection failed (attempt {attempt + 1}): {e}")
252
+ if attempt == 2:
253
+ tools_needed = ["search_tool"] # Fallback after retries
254
+
255
+ # Fallback to keyword-based selection if LLM fails
256
+ if tools_needed == ["search_tool"] and not any(kw in question.lower() for kw in ["calculate", "math", "image", "document", "file", "weather", "guest", "model"]):
257
+ question_lower = question.lower()
258
+ if any(kw in question_lower for kw in ["excel", "csv", "mp3", "data", "table", "xlsx"]):
259
+ tools_needed.append("file_parser_tool")
260
+ if any(kw in question_lower for kw in ["image", "video", "picture", "painting"]):
261
+ tools_needed.append("image_parser_tool")
262
+ if any(kw in question_lower for kw in ["calculate", "math", "sum", "average", "total"]):
263
+ tools_needed.append("calculator_tool")
264
+ if any(kw in question_lower for kw in ["document", "pdf", "report", "paper"]):
265
+ tools_needed.append("document_retriever_tool")
266
+ if any(kw in question_lower for kw in ["search", "wikipedia", "online"]):
267
+ tools_needed.append("duckduckgo_search_tool")
268
+ if any(kw in question_lower for kw in ["weather", "temperature", "forecast"]):
269
+ tools_needed.append("weather_info_tool")
270
+ if any(kw in question_lower for kw in ["model", "huggingface", "dataset"]):
271
+ tools_needed.append("hub_stats_tool")
272
+ if any(kw in question_lower for kw in ["guest", "name", "relation", "person"]):
273
+ tools_needed.append("guest_info_retriever_tool")
274
+ if len(question.split()) > 20 or "multiple" in question_lower:
275
+ tools_needed.append("multi_hop_search_tool")
276
+
277
+ # Integrate file-based tools
278
+ file_results = await fetch_task_file(task_id, question)
279
+ for ext, content in file_results.items():
280
+ if content:
281
+ os.makedirs("temp", exist_ok=True)
282
+ file_path = f"temp/{task_id}.{ext}"
283
+ with open(file_path, "wb") as f:
284
+ f.write(content)
285
+ state["metadata"] = state.get("metadata", {}) | {"file_ext": ext, "file_path": file_path}
286
+ if ext in ["txt", "csv", "xlsx", "mp3"] and "file_parser_tool" not in tools_needed:
287
  tools_needed.append("file_parser_tool")
288
+ elif ext in ["jpg", "png"] and "image_parser_tool" not in tools_needed:
289
  tools_needed.append("image_parser_tool")
290
+ elif ext == "pdf" and "document_retriever_tool" not in tools_needed:
291
  tools_needed.append("document_retriever_tool")
292
+
293
+ state["tools_needed"] = list(set(tools_needed)) # Remove duplicates
294
+ logger.info(f"Task {task_id} Selected tools: {state['tools_needed']}")
 
 
295
  return state
296
  except Exception as e:
297
+ logger.error(f"Task {task_id} Tool selection failed: {e}")
298
  state["error"] = f"Parse question failed: {str(e)}"
299
  state["tools_needed"] = ["search_tool"]
300
  return state
301
 
302
  # Tool dispatcher
303
  async def tool_dispatcher(state: JARVISState) -> JARVISState:
304
+ state = validate_state(state)
305
  try:
306
+ task_id = state["task_id"]
307
+ question = state["question"]
308
+ tools_needed = state["tools_needed"]
309
+
310
+ for tool_name in tools_needed:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  try:
312
+ if tool_name == "search_tool":
313
+ result = await tools["search_tool"].ainvoke({"query": question})
314
+ state["web_results"].extend([str(r) for r in result] if result else ["No results from search_tool"])
315
+ elif tool_name == "multi_hop_search_tool":
316
+ result = await tools["multi_hop_search_tool"].ainvoke({
317
+ "query": question,
318
+ "steps": 3,
319
+ "llm_client": llm_client,
320
+ "llm_type": llm_type,
321
+ "llm_model": llm_model
322
+ })
323
+ state["multi_hop_results"].extend([r["content"] if isinstance(r, dict) else str(r) for r in result] if result else ["No results from multi_hop_search_tool"])
324
+ elif tool_name == "file_parser_tool":
325
+ file_path = state["metadata"].get("file_path")
326
+ file_ext = state["metadata"].get("file_ext")
327
+ if file_path and os.path.exists(file_path) and file_ext:
328
+ result = await tools["file_parser_tool"].ainvoke({
329
+ "task_id": task_id,
330
+ "file_type": file_ext,
331
+ "file_path": file_path,
332
+ "query": question
333
+ })
334
+ state["file_results"] = str(result) if result else "No file results"
335
+ else:
336
+ state["file_results"] = "No file available"
337
+ elif tool_name == "image_parser_tool":
338
+ file_path = state["metadata"].get("file_path")
339
+ if file_path and os.path.exists(file_path) and file_path.split('.')[-1] in ["jpg", "png"]:
340
+ result = await tools["image_parser_tool"].ainvoke({"task_id": task_id, "file_path": file_path})
341
+ state["image_results"] = str(result) if result else "No image results"
342
+ else:
343
+ state["image_results"] = "No image available"
344
+ elif tool_name == "calculator_tool":
345
+ result = await tools["calculator_tool"].ainvoke({"expression": question})
346
+ state["calculation_results"] = str(result) if result else "No calculation results"
347
+ elif tool_name == "document_retriever_tool":
348
+ file_path = state["metadata"].get("file_path")
349
+ if file_path and os.path.exists(file_path) and file_path.split('.')[-1] == "pdf":
350
+ result = await tools["document_retriever_tool"].ainvoke({
351
+ "task_id": task_id,
352
+ "query": question,
353
+ "file_path": file_path
354
+ })
355
+ state["document_results"] = str(result) if result else "No document results"
356
+ else:
357
+ state["document_results"] = "No document available"
358
+ elif tool_name == "duckduckgo_search_tool":
359
+ result = await tools["duckduckgo_search_tool"].ainvoke({
360
+ "query": question,
361
+ "original_query": question,
362
+ "embedder": embedder
363
+ })
364
+ state["web_results"].extend(result if isinstance(result, list) else [str(result)] if result else ["No results from duckduckgo_search_tool"])
365
+ elif tool_name == "weather_info_tool":
366
+ location = question.split()[-1] if "weather" in question.lower() else "Unknown"
367
+ result = await tools["weather_info_tool"].ainvoke({"location": location})
368
+ state["web_results"].append(str(result) if result else "No weather results")
369
+ elif tool_name == "hub_stats_tool":
370
+ author = question.split("by ")[1].split()[0] if "by" in question.lower() else "Unknown"
371
+ result = await tools["hub_stats_tool"].ainvoke({"author": author})
372
+ state["web_results"].append(str(result) if result else "No hub stats results")
373
+ elif tool_name == "guest_info_retriever_tool":
374
+ result = await tools["guest_info_retriever_tool"].ainvoke({"query": question})
375
+ state["web_results"].append(str(result) if result else "No guest info results")
376
 
377
+ state["metadata"] = state.get("metadata", {}) | {f"{tool_name}_executed": True}
378
+ logger.info(f"Task {task_id}: Executed {tool_name}")
379
  except Exception as e:
380
+ logger.warning(f"Tool {tool_name} failed for task {task_id}: {e}")
381
+ state["metadata"] = state.get("metadata", {}) | {f"{tool_name}_error": str(e)}
382
+
383
+ # Ensure results are populated
384
+ state["web_results"] = state.get("web_results", ["No web results found"])
385
+ state["file_results"] = state.get("file_results", "No file results found")
386
+ state["image_results"] = state.get("image_results", "No image results found")
387
+ state["document_results"] = state.get("document_results", "No document results found")
388
+ state["calculation_results"] = state.get("calculation_results", "No calculation results found")
389
+
390
+ state["answer"] = await generate_answer(
391
+ task_id=task_id,
392
+ question=question,
393
+ search_results=state.get("web_results", []) + [
394
+ r["content"] if isinstance(r, dict) else str(r) for r in state.get("multi_hop_results", [])
395
+ ],
396
+ file_results=state.get("file_results", "") + state.get("document_results", "") + state.get("image_results", "") + state.get("calculation_results", ""),
397
+ llm_client=llm_client
398
+ )
399
+
400
+ logger.info(f"Task {task_id}: Generated answer: {state['answer']}")
401
+ return state
402
  except Exception as e:
403
+ logger.error(f"Tool dispatch failed: {e}")
404
+ state["error"] = f"Tool dispatch failed: {e}"
405
+ return state
 
 
 
 
 
 
406
 
407
  # Define StateGraph
408
  workflow = StateGraph(JARVISState)
409
+ workflow.add_node("parse_question", parse_question)
410
  workflow.add_node("tool_dispatcher", tool_dispatcher)
411
+ workflow.set_entry_point("parse_question")
412
+ workflow.add_edge("parse_question", "tool_dispatcher")
413
+ workflow.add_edge("tool_dispatcher", END)
 
 
 
 
 
 
 
 
 
414
  graph = workflow.compile()
415
 
416
  # Agent class
417
  class JARVISAgent:
418
  def __init__(self):
419
+ self.state = reset_state(task_id="init", question="Agent initialized")
420
+ self.state["results_table"] = [] # Initialize as empty list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
  logger.info("JARVISAgent initialized.")
422
 
423
  async def process_question(self, task_id: str, question: str) -> str:
424
+ state = reset_state(task_id=task_id, question=question)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
  try:
426
  result = await graph.ainvoke(state)
427
+ answer = result.get("answer", "Unknown")
428
+ logger.info(f"Task {task_id} Final answer: {answer}")
429
+ self.state["results_table"].append({"Task ID": task_id, "Question": question, "Answer": answer})
430
+ self.state["metadata"] = {"last_task_id": task_id, "answer": answer}
431
  return answer
432
  except Exception as e:
433
  logger.error(f"Error processing task {task_id}: {e}")
434
+ self.state["results_table"].append({"Task ID": task_id, "Question": question, "Answer": f"Error: {e}"})
435
+ self.state["error"] = f"Task {task_id} failed: {str(e)}"
436
  return f"Error: {str(e)}"
437
  finally:
438
+ for ext in ["txt", "csv", "xlsx", "mp3", "jpg", "png", "pdf"]:
439
  file_path = f"temp/{task_id}.{ext}"
440
  if os.path.exists(file_path):
441
  try:
 
447
  async def process_all_questions(self, profile: gr.OAuthProfile | None):
448
  if not profile:
449
  logger.error("User not logged in.")
450
+ self.state["status_output"] = "Please Login to Hugging Face."
451
+ return pd.DataFrame(self.state["results_table"]), self.state["status_output"]
452
 
453
+ username = profile.username
454
  logger.info(f"User logged in: {username}")
455
  questions_url = f"{GAIA_API_URL}/questions"
456
  submit_url = f"{GAIA_API_URL}/submit"
457
  agent_code = f"https://huggingface.co/spaces/{SPACE_ID}/tree/main"
458
 
459
  try:
460
+ async with await create_http_session() as session:
461
+ async with session.get(questions_url) as response:
462
+ response.raise_for_status()
463
+ questions = await response.json()
464
  logger.info(f"Fetched {len(questions)} questions.")
465
  except Exception as e:
466
  logger.error(f"Error fetching questions: {e}")
467
+ self.state["status_output"] = f"Error fetching questions: {e}"
468
+ self.state["error"] = f"Fetch questions failed: {str(e)}"
469
+ return pd.DataFrame(self.state["results_table"]), self.state["status_output"]
470
 
471
  answers_payload = []
472
  for item in questions:
 
480
 
481
  if not answers_payload:
482
  logger.error("No answers generated.")
483
+ self.state["status_output"] = "No answers to submit."
484
+ self.state["error"] = "No answers generated"
485
+ return pd.DataFrame(self.state["results_table"]), self.state["status_output"]
486
 
487
  submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
488
  try:
489
+ async with await create_http_session() as session:
490
+ async with session.post(submit_url, json=submission_data) as response:
491
+ response.raise_for_status()
492
+ result_data = await response.json()
493
+ self.state["status_output"] = (
494
  f"Submission Successful!\n"
495
  f"User: {result_data.get('username')}\n"
496
  f"Overall Score: {result_data.get('score', 'N/A')}% "
497
  f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
498
  f"Message: {result_data.get('message', 'No message received.')}"
499
  )
500
+ self.state["metadata"] = self.state.get("metadata", {}) | {"submission_score": result_data.get('score', 'N/A')}
501
  except Exception as e:
502
  logger.error(f"Submission failed: {e}")
503
+ self.state["status_output"] = f"Submission Failed: {e}"
504
+ self.state["error"] = f"Submission failed: {str(e)}"
505
 
506
+ return pd.DataFrame(self.state["results_table"] if self.state["results_table"] else [], columns=["Task ID", "Question", "Answer"]), self.state["status_output"]
507
 
508
  # Gradio interface
509
  with gr.Blocks() as demo:
510
+ gr.Markdown("# JARVIS GAIA Agent")
511
  gr.Markdown(
512
  """
513
  **Instructions:**
 
522
  )
523
  with gr.Row():
524
  gr.LoginButton(value="Login to Hugging Face")
 
525
  run_button = gr.Button("Run Evaluation & Submit All Answers")
526
  status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
527
  results_table = gr.DataFrame(label="Questions and Answers", wrap=True, headers=["Task ID", "Question", "Answer"])
project_structure.txt ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .
2
+ ├── app.py
3
+ ├── dockerfile
4
+ ├── README.md
5
+ ├── requirements.txt
6
+ ├── retriever.py
7
+ ├── state.py
8
+ └── tools
9
+ ├── __init__.py
10
+ ├── answer_generator.py
11
+ ├── calculator.py
12
+ ├── document_retriever.py
13
+ ├── duckduckgo_search.py
14
+ ├── file_fetcher.py
15
+ ├── file_parser.py
16
+ ├── guest_info.py
17
+ ├── hub_stats.py
18
+ ├── image_parser.py
19
+ ├── search.py
20
+ └── weather_info.py
21
+
22
+ 3 directories, 18 files
requirements.txt CHANGED
@@ -20,8 +20,11 @@ transformers
20
  asyncio
21
  serpapi
22
  duckduckgo-search
23
- torch
24
  together
25
  google-search-results
26
  beautifulsoup4
27
- gradio[oauth]
 
 
 
 
20
  asyncio
21
  serpapi
22
  duckduckgo-search
23
+ torch==2.2.2
24
  together
25
  google-search-results
26
  beautifulsoup4
27
+ gradio[oauth]
28
+ nlkt
29
+ speechrecognition
30
+ rank_bm25
result.txt CHANGED
The diff for this file is too large to render. See raw diff
 
retriever.py CHANGED
@@ -1,25 +1,109 @@
1
- import datasets
2
- from langchain.docstore.document import Document
3
- from langchain_community.retrievers import BM25Retriever
4
- from smolagents import Tool
 
 
 
 
 
 
 
5
 
6
- def load_guest_dataset():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  try:
8
- guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  docs = [
10
  Document(
11
  page_content="\n".join([
12
- f"Name: {guest['name']}",
13
- f"Relation: {guest['relation']}",
14
- f"Description: {guest['description']}",
15
- f"Email: {guest['email']}"
16
  ]),
17
- metadata={"name": guest["name"]}
 
 
 
 
 
18
  )
19
- for guest in guest_dataset
20
  ]
 
 
 
21
  except Exception as e:
22
- # Fallback mock dataset
 
23
  docs = [
24
  Document(
25
  page_content="\n".join([
@@ -28,7 +112,73 @@ def load_guest_dataset():
28
  "Description: Dr. Nikola Tesla is an old friend from your university days. He's recently patented a new wireless energy transmission system.",
29
  "Email: [email protected]"
30
  ]),
31
- metadata={"name": "Dr. Nikola Tesla"}
 
 
 
 
 
32
  )
33
  ]
34
- return docs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import logging
4
+ import torch
5
+ from typing import List
6
+ from langchain_core.documents import Document
7
+ from sentence_transformers import SentenceTransformer
8
+ try:
9
+ from datasets import load_dataset
10
+ except ImportError:
11
+ load_dataset = None
12
 
13
+ logger = logging.getLogger(__name__)
14
+
15
+ def get_device():
16
+ """
17
+ Determine the appropriate device for PyTorch.
18
+
19
+ Returns:
20
+ str: Device name ('cuda', 'mps', or 'cpu').
21
+ """
22
+ if torch.cuda.is_available():
23
+ return "cuda"
24
+ elif torch.backends.mps.is_available():
25
+ return "mps"
26
+ return "cpu"
27
+
28
+ def load_guest_dataset(dataset_path: str = "agents-course/unit3-invitees") -> List[Document]:
29
+ """
30
+ Load guest dataset from a local JSON file or Hugging Face dataset.
31
+
32
+ Args:
33
+ dataset_path (str): Path to local JSON file or Hugging Face dataset name.
34
+
35
+ Returns:
36
+ List[Document]: List of Document objects with guest information.
37
+ """
38
  try:
39
+ # Try loading from Hugging Face dataset if datasets library is available
40
+ if load_dataset and not os.path.exists(dataset_path):
41
+ logger.info(f"Attempting to load Hugging Face dataset: {dataset_path}")
42
+ guest_dataset = load_dataset(dataset_path, split="train")
43
+ docs = [
44
+ Document(
45
+ page_content="\n".join([
46
+ f"Name: {guest['name']}",
47
+ f"Relation: {guest['relation']}",
48
+ f"Description: {guest['description']}",
49
+ f"Email: {guest['email']}"
50
+ ]),
51
+ metadata={
52
+ "name": guest["name"],
53
+ "relation": guest["relation"],
54
+ "description": guest["description"],
55
+ "email": guest["email"]
56
+ }
57
+ )
58
+ for guest in guest_dataset
59
+ ]
60
+ logger.info(f"Loaded {len(docs)} guests from Hugging Face dataset")
61
+ return docs
62
+
63
+ # Try loading from local JSON file
64
+ if os.path.exists(dataset_path):
65
+ logger.info(f"Loading guest dataset from local path: {dataset_path}")
66
+ with open(dataset_path, 'r') as f:
67
+ guests = json.load(f)
68
+ docs = [
69
+ Document(
70
+ page_content=guest.get('description', ''),
71
+ metadata={
72
+ 'name': guest.get('name', ''),
73
+ 'relation': guest.get('relation', ''),
74
+ 'description': guest.get('description', ''),
75
+ 'email': guest.get('email', '') # Optional email field
76
+ }
77
+ )
78
+ for guest in guests
79
+ ]
80
+ logger.info(f"Loaded {len(docs)} guests from local JSON")
81
+ return docs
82
+
83
+ # Fallback to mock dataset if both fail
84
+ logger.warning(f"Dataset not found at {dataset_path}, using mock dataset")
85
  docs = [
86
  Document(
87
  page_content="\n".join([
88
+ "Name: Dr. Nikola Tesla",
89
+ "Relation: old friend from university days",
90
+ "Description: Dr. Nikola Tesla is an old friend from your university days. He's recently patented a new wireless energy transmission system.",
91
+ "Email: [email protected]"
92
  ]),
93
+ metadata={
94
+ "name": "Dr. Nikola Tesla",
95
+ "relation": "old friend from university days",
96
+ "description": "Dr. Nikola Tesla is an old friend from your university days. He's recently patented a new wireless energy transmission system.",
97
+ "email": "[email protected]"
98
+ }
99
  )
 
100
  ]
101
+ logger.info("Loaded mock dataset with 1 guest")
102
+ return docs
103
+
104
  except Exception as e:
105
+ logger.error(f"Failed to load guest dataset: {e}")
106
+ # Return mock dataset as final fallback
107
  docs = [
108
  Document(
109
  page_content="\n".join([
 
112
  "Description: Dr. Nikola Tesla is an old friend from your university days. He's recently patented a new wireless energy transmission system.",
113
  "Email: [email protected]"
114
  ]),
115
+ metadata={
116
+ "name": "Dr. Nikola Tesla",
117
+ "relation": "old friend from university days",
118
+ "description": "Dr. Nikola Tesla is an old friend from your university days. He's recently patented a new wireless energy transmission system.",
119
+ "email": "[email protected]"
120
+ }
121
  )
122
  ]
123
+ logger.info("Loaded mock dataset with 1 guest due to error")
124
+ return docs
125
+
126
+ class BM25Retriever:
127
+ """
128
+ A retriever class using SentenceTransformer for embedding-based search.
129
+ """
130
+ def __init__(self, dataset_path: str):
131
+ """
132
+ Initialize the retriever with a SentenceTransformer model.
133
+
134
+ Args:
135
+ dataset_path (str): Path to the dataset for retrieval.
136
+
137
+ Raises:
138
+ Exception: If embedder initialization fails.
139
+ """
140
+ try:
141
+ self.model = SentenceTransformer("all-MiniLM-L6-v2", device=get_device())
142
+ self.dataset_path = dataset_path
143
+ logger.info("Initialized SentenceTransformer")
144
+ except Exception as e:
145
+ logger.error(f"Failed to initialize embedder: {e}")
146
+ raise
147
+
148
+ def search(self, query: str) -> List[dict]:
149
+ """
150
+ Search the dataset for relevant guest information.
151
+
152
+ Args:
153
+ query (str): Search query (e.g., guest name or relation).
154
+
155
+ Returns:
156
+ List[dict]: List of matching guest metadata dictionaries.
157
+ """
158
+ try:
159
+ # Load dataset
160
+ docs = load_guest_dataset(self.dataset_path)
161
+ if not docs:
162
+ logger.warning("No documents available for search")
163
+ return []
164
+
165
+ # Convert documents to text for BM25 (using metadata for consistency)
166
+ texts = [f"{doc.metadata['name']} {doc.metadata['relation']} {doc.metadata['description']}" for doc in docs]
167
+ from langchain_community.retrievers import BM25Retriever
168
+ retriever = BM25Retriever.from_texts(texts)
169
+ retriever.k = 3 # Limit to top 3 results
170
+
171
+ # Perform search
172
+ results = retriever.invoke(query)
173
+ # Map results back to original metadata
174
+ matches = [
175
+ docs[i].metadata
176
+ for i in range(len(docs))
177
+ if any(f"{docs[i].metadata['name']} {docs[i].metadata['relation']} {docs[i].metadata['description']}" in r.page_content for r in results)
178
+ ]
179
+ logger.info(f"Found {len(matches)} matches for query: {query}")
180
+ return matches[:3] # Return top 3 matches
181
+
182
+ except Exception as e:
183
+ logger.error(f"Search failed for query '{query}': {e}")
184
+ return []
state.py CHANGED
@@ -1,5 +1,8 @@
1
- from typing import TypedDict, List, Dict, Optional, Any
2
  from langchain_core.messages import BaseMessage
 
 
 
3
 
4
  class JARVISState(TypedDict):
5
  """
@@ -10,11 +13,11 @@ class JARVISState(TypedDict):
10
  question: The question text to be answered.
11
  tools_needed: List of tool names to be used for the task.
12
  web_results: List of web search results (e.g., from SERPAPI, DuckDuckGo).
13
- file_results: Parsed content from text, CSV, or Excel files.
14
  image_results: OCR or description results from image files.
15
  calculation_results: Results from mathematical calculations.
16
- document_results: Extracted content from PDF documents.
17
- multi_hop_results: Results from iterative multi-hop searches.
18
  messages: List of messages for LLM context (e.g., user prompts, system instructions).
19
  answer: Final answer for the task, formatted for GAIA submission.
20
  results_table: List of task results for Gradio display (Task ID, Question, Answer).
@@ -30,10 +33,84 @@ class JARVISState(TypedDict):
30
  image_results: str
31
  calculation_results: str
32
  document_results: str
33
- multi_hop_results: List[str]
34
  messages: List[BaseMessage]
35
  answer: str
36
  results_table: List[Dict[str, str]]
37
  status_output: str
38
  error: Optional[str]
39
- metadata: Optional[Dict[str, Any]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TypedDict, List, Dict, Optional, Any, Union
2
  from langchain_core.messages import BaseMessage
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
 
7
  class JARVISState(TypedDict):
8
  """
 
13
  question: The question text to be answered.
14
  tools_needed: List of tool names to be used for the task.
15
  web_results: List of web search results (e.g., from SERPAPI, DuckDuckGo).
16
+ file_results: Parsed content from text, CSV, Excel, or audio files.
17
  image_results: OCR or description results from image files.
18
  calculation_results: Results from mathematical calculations.
19
+ document_results: Extracted content from PDF or text documents.
20
+ multi_hop_results: Results from iterative multi-hop searches (supports strings or dicts).
21
  messages: List of messages for LLM context (e.g., user prompts, system instructions).
22
  answer: Final answer for the task, formatted for GAIA submission.
23
  results_table: List of task results for Gradio display (Task ID, Question, Answer).
 
33
  image_results: str
34
  calculation_results: str
35
  document_results: str
36
+ multi_hop_results: List[Union[str, Dict[str, Any]]]
37
  messages: List[BaseMessage]
38
  answer: str
39
  results_table: List[Dict[str, str]]
40
  status_output: str
41
  error: Optional[str]
42
+ metadata: Optional[Dict[str, Any]]
43
+
44
+ def validate_state(state: JARVISState) -> JARVISState:
45
+ """
46
+ Validate and initialize JARVISState fields.
47
+
48
+ Args:
49
+ state: Input state dictionary.
50
+
51
+ Returns:
52
+ Validated and initialized state.
53
+ """
54
+ try:
55
+ if not state.get("task_id"):
56
+ logger.error("task_id is required")
57
+ raise ValueError("task_id is required")
58
+ if not state.get("question"):
59
+ logger.error("question is required")
60
+ raise ValueError("question is required")
61
+
62
+ # Initialize default values if missing
63
+ defaults = {
64
+ "tools_needed": ["search_tool"],
65
+ "web_results": [],
66
+ "file_results": "",
67
+ "image_results": "",
68
+ "calculation_results": "",
69
+ "document_results": "",
70
+ "multi_hop_results": [],
71
+ "messages": [],
72
+ "answer": "",
73
+ "results_table": [],
74
+ "status_output": "",
75
+ "error": None,
76
+ "metadata": {}
77
+ }
78
+ for key, default in defaults.items():
79
+ if key not in state or state[key] is None:
80
+ state[key] = default
81
+
82
+ logger.debug(f"Validated state for task {state['task_id']}")
83
+ return state
84
+ except Exception as e:
85
+ logger.error(f"State validation failed: {e}")
86
+ raise
87
+
88
+ def reset_state(task_id: str, question: str) -> JARVISState:
89
+ """
90
+ Create a fresh JARVISState for a new task.
91
+
92
+ Args:
93
+ task_id: Task identifier.
94
+ question: Question text.
95
+
96
+ Returns:
97
+ Initialized JARVISState.
98
+ """
99
+ state = JARVISState(
100
+ task_id=task_id,
101
+ question=question,
102
+ tools_needed=["search_tool"],
103
+ web_results=[],
104
+ file_results="",
105
+ image_results="",
106
+ calculation_results="",
107
+ document_results="",
108
+ multi_hop_results=[],
109
+ messages=[],
110
+ answer="",
111
+ results_table=[],
112
+ status_output="",
113
+ error=None,
114
+ metadata={}
115
+ )
116
+ return validate_state(state)
test.py CHANGED
@@ -1,10 +1,233 @@
1
- from serpapi import GoogleSearch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- params = {
4
- "q": "drop shipping",
5
- "api_key": "e44c79583cac0e507fee32d564f190b7290a313d886edd5ba5fccc93df932733"
6
- }
7
 
8
- search = GoogleSearch(params)
9
- results = search.get_dict()
10
- ai_overview = results["ai_overview"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ import logging
4
+ import tempfile
5
+ from pathlib import Path
6
+ from app import JARVISAgent, llm_client, llm_type, llm_model, embedder
7
+ from tools.search import search_tool, multi_hop_search_tool
8
+ from tools.file_parser import file_parser_tool
9
+ from tools.image_parser import image_parser_tool
10
+ from tools.calculator import calculator_tool
11
+ from tools.document_retriever import document_retriever_tool
12
+ from tools.duckduckgo_search import duckduckgo_search_tool
13
+ from tools.weather_info import weather_info_tool
14
+ from tools.hub_stats import hub_stats_tool
15
+ from tools.guest_info import guest_info_retriever_tool
16
+ from tools.file_fetcher import fetch_task_file
17
+ from tools.answer_generator import preprocess_question, filter_results
18
+ from state import validate_state, reset_state, JARVISState
19
 
20
+ # Setup logging
21
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
22
+ logger = logging.getLogger(__name__)
 
23
 
24
+ async def test_tools():
25
+ """Test all tools."""
26
+ logger.info("Testing Search Tool (SerpAPI)...")
27
+ try:
28
+ if not os.getenv("SERPAPI_API_KEY"):
29
+ logger.warning("Search Warning: SERPAPI_API_KEY not set")
30
+ else:
31
+ result = await search_tool.ainvoke({"query": "What is the capital of France?"})
32
+ logger.info(f"Search Result: {result}")
33
+ except Exception as e:
34
+ logger.error(f"Search Error: {e}")
35
+
36
+ logger.info("Testing Multi-Hop Search Tool...")
37
+ try:
38
+ result = await multi_hop_search_tool.ainvoke({
39
+ "query": "What is the population of France's capital?",
40
+ "steps": 2,
41
+ "llm_client": llm_client,
42
+ "llm_type": llm_type,
43
+ "llm_model": llm_model
44
+ })
45
+ logger.info(f"Multi-Hop Search Result: {result}")
46
+ except Exception as e:
47
+ logger.error(f"Multi-Hop Search Error: {e}")
48
+
49
+ logger.info("Testing DuckDuckGo Search Tool...")
50
+ try:
51
+ result = await duckduckgo_search_tool.ainvoke({
52
+ "query": "What is the capital of France?",
53
+ "original_query": "What is the capital of France?",
54
+ "embedder": embedder
55
+ })
56
+ logger.info(f"DuckDuckGo Result: {result}")
57
+ except Exception as e:
58
+ logger.error(f"DuckDuckGo Error: {e}")
59
+
60
+ logger.info("Testing Weather Info Tool...")
61
+ try:
62
+ if not os.getenv("OPENWEATHERMAP_API_KEY"):
63
+ logger.warning("Weather Warning: OPENWEATHERMAP_API_KEY not set")
64
+ else:
65
+ result = await weather_info_tool.ainvoke({"location": "London"})
66
+ logger.info(f"Weather Result: {result}")
67
+ except Exception as e:
68
+ logger.error(f"Weather Error: {e}")
69
+
70
+ logger.info("Testing Document Retriever Tool...")
71
+ try:
72
+ from PyPDF2 import PdfWriter
73
+ with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
74
+ writer = PdfWriter()
75
+ from PyPDF2.generic import NameObject, create_string_object
76
+ page = writer.add_blank_page(width=72, height=72)
77
+ page[NameObject("/Contents")] = create_string_object("Sample document content for testing.")
78
+ writer.write(tmp)
79
+ tmp_path = tmp.name
80
+ result = await document_retriever_tool.ainvoke({
81
+ "task_id": "test_task",
82
+ "query": "Sample question",
83
+ "file_path": tmp_path
84
+ })
85
+ logger.info(f"Document Retriever Result: {result}")
86
+ os.unlink(tmp_path)
87
+ except Exception as e:
88
+ logger.error(f"Document Retriever Error: {e}")
89
+
90
+ logger.info("Testing Image Parser Tool...")
91
+ try:
92
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
93
+ # Create a minimal PNG (1x1 pixel)
94
+ from PIL import Image
95
+ img = Image.new('RGB', (1, 1), color='white')
96
+ img.save(tmp.name, 'PNG')
97
+ tmp_path = tmp.name
98
+ result = await image_parser_tool.ainvoke({"task_id": "test_task", "file_path": tmp_path})
99
+ logger.info(f"Image Parser Result: {result}")
100
+ os.unlink(tmp_path)
101
+ except Exception as e:
102
+ logger.error(f"Image Parser Error: {e}")
103
+
104
+ logger.info("Testing File Parser Tool...")
105
+ try:
106
+ with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tmp:
107
+ tmp.write(b"Sample text file content")
108
+ tmp_path = tmp.name
109
+ result = await file_parser_tool.ainvoke({
110
+ "task_id": "test_task",
111
+ "file_type": "txt",
112
+ "file_path": tmp_path,
113
+ "query": "What is in the file?"
114
+ })
115
+ logger.info(f"File Parser Result: {result}")
116
+ os.unlink(tmp_path)
117
+ except Exception as e:
118
+ logger.error(f"File Parser Error: {e}")
119
+
120
+ logger.info("Testing Calculator Tool...")
121
+ try:
122
+ result = await calculator_tool.ainvoke({"expression": "2 + 2"})
123
+ logger.info(f"Calculator Result: {result}")
124
+ except Exception as e:
125
+ logger.error(f"Calculator Error: {e}")
126
+
127
+ logger.info("Testing Hub Stats Tool...")
128
+ try:
129
+ if not os.getenv("HUGGINGFACEHUB_API_TOKEN"):
130
+ logger.warning("Hub Stats Warning: HUGGINGFACEHUB_API_TOKEN not set")
131
+ else:
132
+ result = await hub_stats_tool.ainvoke({"author": "meta-llama"})
133
+ logger.info(f"Hub Stats Result: {result}")
134
+ except Exception as e:
135
+ logger.error(f"Hub Stats Error: {e}")
136
+
137
+ logger.info("Testing Guest Info Retriever Tool...")
138
+ try:
139
+ result = await guest_info_retriever_tool.ainvoke({"query": "Who is the guest named John?"})
140
+ logger.info(f"Guest Info Result: {result}")
141
+ except Exception as e:
142
+ logger.error(f"Guest Info Error: {e}")
143
+
144
+ async def test_file_fetcher():
145
+ """Test file fetcher."""
146
+ logger.info("Testing File Fetcher...")
147
+ try:
148
+ result = await fetch_task_file("8e867cd7-cff9-4e6c-867a-ff5ddc2550be", "Sample question with data")
149
+ logger.info(f"File Fetcher Result: {result}")
150
+ except Exception as e:
151
+ logger.error(f"File Fetcher Error: {e}")
152
+
153
+ async def test_answer_generator():
154
+ """Test answer generator functions."""
155
+ logger.info("Testing Preprocess Question...")
156
+ try:
157
+ result = await preprocess_question("What's the weather in Paris?")
158
+ logger.info(f"Preprocess Question Result: {result}")
159
+ except Exception as e:
160
+ logger.error(f"Preprocess Question Error: {e}")
161
+
162
+ logger.info("Testing Filter Results...")
163
+ try:
164
+ results = ["Paris is the capital of France.", "Florida is a state.", "Paris is in Texas."]
165
+ filtered = filter_results(results, "What is the capital of France?")
166
+ logger.info(f"Filter Results: {filtered}")
167
+ except Exception as e:
168
+ logger.error(f"Filter Results Error: {e}")
169
+
170
+ async def test_state_management():
171
+ """Test state management functions."""
172
+ logger.info("Testing Reset State...")
173
+ try:
174
+ state = reset_state("test_task", "What is the capital of France?")
175
+ logger.info(f"Reset State Result: {state}")
176
+ except Exception as e:
177
+ logger.error(f"Reset State Error: {e}")
178
+
179
+ logger.info("Testing Validate State...")
180
+ try:
181
+ invalid_state = {"task_id": "", "question": ""}
182
+ validate_state(invalid_state)
183
+ logger.error("Validate State should have failed")
184
+ except ValueError as e:
185
+ logger.info(f"Validate State Error (expected): {e}")
186
+
187
+ try:
188
+ valid_state = reset_state("test_task", "Sample question")
189
+ validated = validate_state(valid_state)
190
+ logger.info(f"Validate State Result: {validated}")
191
+ except Exception as e:
192
+ logger.error(f"Validate State Error: {e}")
193
+
194
+ async def test_agent():
195
+ """Test JARVISAgent with various cases."""
196
+ logger.info("Testing JARVISAgent (Simple Question)...")
197
+ try:
198
+ agent = JARVISAgent()
199
+ answer = await agent.process_question("test_task", "What is the capital of France?")
200
+ logger.info(f"JARVISAgent Answer: {answer}")
201
+ except Exception as e:
202
+ logger.error(f"JARVISAgent Error: {e}")
203
+
204
+ logger.info("Testing JARVISAgent (Edge Case: Empty Question)...")
205
+ try:
206
+ agent = JARVISAgent()
207
+ answer = await agent.process_question("test_task", "")
208
+ logger.info(f"JARVISAgent Empty Question Answer: {answer}")
209
+ except Exception as e:
210
+ logger.info(f"JARVISAgent Empty Question Error (expected): {e}")
211
+
212
+ async def main():
213
+ required_envs = [
214
+ "HUGGINGFACEHUB_API_TOKEN",
215
+ "TOGETHER_API_KEY",
216
+ "OPENWEATHERMAP_API_KEY",
217
+ "SERPAPI_API_KEY"
218
+ ]
219
+ for env in required_envs:
220
+ if not os.getenv(env):
221
+ logger.warning(f"{env} not set, some tools may fail")
222
+
223
+ await test_tools()
224
+ await test_file_fetcher()
225
+ await test_answer_generator()
226
+ await test_state_management()
227
+ await test_agent()
228
+
229
+ if __name__ == "__main__":
230
+ try:
231
+ asyncio.run(main())
232
+ except Exception as e:
233
+ logger.error(f"Test script failed: {e}")
tools/__init__.py CHANGED
@@ -6,4 +6,6 @@ from .document_retriever import document_retriever_tool
6
  from .duckduckgo_search import duckduckgo_search_tool
7
  from .weather_info import weather_info_tool
8
  from .hub_stats import hub_stats_tool
9
- from .guest_info import guest_info_retriever_tool
 
 
 
6
  from .duckduckgo_search import duckduckgo_search_tool
7
  from .weather_info import weather_info_tool
8
  from .hub_stats import hub_stats_tool
9
+ from .guest_info import guest_info_retriever_tool
10
+ from .file_fetcher import fetch_task_file
11
+ from .answer_generator import generate_answer, preprocess_question#, filter_results, get_embedder
tools/answer_generator.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nltk
2
+ import logging
3
+ import numpy as np
4
+ from typing import List, Any
5
+ from langchain_core.prompts import ChatPromptTemplate
6
+ from langchain_core.messages import SystemMessage, HumanMessage
7
+ from sentence_transformers import SentenceTransformer
8
+
9
+ # Setup logging
10
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(name)s - %(message)s')
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # Download NLTK data
14
+ try:
15
+ nltk.download('punkt', quiet=True)
16
+ nltk.download('stopwords', quiet=True)
17
+ except Exception as e:
18
+ logger.warning(f"NLTK data download failed: {e}")
19
+
20
+ # Global embedder
21
+ _embedder = None
22
+
23
+ def get_embedder():
24
+ global _embedder
25
+ if _embedder is None:
26
+ try:
27
+ _embedder = SentenceTransformer(
28
+ "all-MiniLM-L6-v2",
29
+ device="cpu",
30
+ cache_folder="./cache"
31
+ )
32
+ logger.info("SentenceTransformer initialized")
33
+ except Exception as e:
34
+ logger.error(f"Failed to initialize SentenceTransformer: {e}")
35
+ raise RuntimeError(f"Embedder initialization failed: {e}")
36
+ return _embedder
37
+
38
+ def filter_results(search_results: List[str], question: str) -> List[str]:
39
+ try:
40
+ if not search_results or not question:
41
+ return search_results
42
+
43
+ embedder = get_embedder()
44
+ question_embedding = embedder.encode([question], convert_to_numpy=True)
45
+ result_embeddings = embedder.encode(search_results, convert_to_numpy=True)
46
+
47
+ similarities = np.dot(result_embeddings, question_embedding.T).flatten()
48
+ filtered_results = [
49
+ search_results[i] for i in range(len(search_results))
50
+ if similarities[i] > 0.5 and search_results[i].strip()
51
+ ]
52
+
53
+ return filtered_results if filtered_results else search_results[:3]
54
+ except Exception as e:
55
+ logger.warning(f"Result filtering failed: {e}")
56
+ return search_results[:3]
57
+
58
+ async def preprocess_question(question: str) -> str:
59
+ """Preprocess the question to clean and standardize it."""
60
+ try:
61
+ question = question.strip().lower()
62
+ if not question.endswith("?"):
63
+ question += "?"
64
+ logger.debug(f"Preprocessed question: {question}")
65
+ return question
66
+ except Exception as e:
67
+ logger.error(f"Error preprocessing question: {e}")
68
+ return question
69
+
70
+ async def generate_answer(
71
+ task_id: str,
72
+ question: str,
73
+ search_results: List[str],
74
+ file_results: str,
75
+ llm_client: Any
76
+ ) -> str:
77
+ """Generate an answer using LLM with search and file results."""
78
+ try:
79
+ if not search_results:
80
+ search_results = ["No search results available."]
81
+ if not file_results:
82
+ file_results = "No file results available."
83
+
84
+ context = "\n".join([str(r) for r in search_results]) + "\n" + file_results
85
+ prompt = ChatPromptTemplate.from_messages([
86
+ SystemMessage(content="""You are an assistant answering questions using provided context.
87
+ - Use ONLY the context to formulate a concise, accurate answer.
88
+ - If the context is insufficient, state: 'Insufficient information to answer.'
89
+ - Do NOT generate or assume information beyond the context.
90
+ - Return a single, clear sentence or phrase as the answer."""),
91
+ HumanMessage(content=f"Context: {context}\nQuestion: {question}")
92
+ ])
93
+
94
+ messages = [
95
+ {"role": "system", "content": prompt[0].content},
96
+ {"role": "user", "content": prompt[1].content}
97
+ ]
98
+
99
+ if isinstance(llm_client, tuple): # hf_local
100
+ model, tokenizer = llm_client
101
+ inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
102
+ outputs = model.generate(inputs, max_new_tokens=100, temperature=0.7)
103
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
104
+ elif hasattr(llm_client, "chat"): # together
105
+ response = llm_client.chat.completions.create(
106
+ model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
107
+ messages=messages,
108
+ max_tokens=100,
109
+ temperature=0.7,
110
+ top_p=0.9,
111
+ frequency_penalty=0.5
112
+ )
113
+ response = response.choices[0].message.content.strip()
114
+ else: # hf_api
115
+ response = llm_client.chat.completions.create(
116
+ messages=messages,
117
+ max_tokens=100,
118
+ temperature=0.7
119
+ )
120
+ response = response.choices[0].message.content.strip()
121
+
122
+ answer = response.strip()
123
+ if not answer or answer.lower() == "none":
124
+ answer = "Insufficient information to answer."
125
+ logger.info(f"Task {task_id}: Generated answer: {answer}")
126
+ return answer
127
+ except Exception as e:
128
+ logger.error(f"Task {task_id}: Answer generation failed: {e}")
129
+ return "Error generating answer."
tools/calculator.py CHANGED
@@ -1,15 +1,35 @@
1
- from langchain_core.tools import tool
2
- from sympy import sympify
3
  import logging
 
 
4
 
5
  logger = logging.getLogger(__name__)
6
 
7
- @tool
8
- async def calculator_tool(expression: str) -> str:
9
- """Evaluate a mathematical expression."""
 
 
 
 
 
 
 
 
 
 
10
  try:
11
- result = sympify(expression)
 
 
 
12
  return str(result)
13
  except Exception as e:
14
- logger.error(f"Error evaluating expression '{expression}': {e}")
15
- return f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
+ from langchain_core.tools import StructuredTool
3
+ from pydantic import BaseModel, Field
4
 
5
  logger = logging.getLogger(__name__)
6
 
7
+ class CalculatorInput(BaseModel):
8
+ expression: str = Field(description="Mathematical expression to evaluate")
9
+
10
+ async def calculator_func(expression: str) -> str:
11
+ """
12
+ Evaluate a mathematical expression and return the result as a string.
13
+
14
+ Args:
15
+ expression (str): Mathematical expression (e.g., '2 + 2').
16
+
17
+ Returns:
18
+ str: Result of the expression.
19
+ """
20
  try:
21
+ logger.info(f"Evaluating expression: {expression}")
22
+ result = eval(expression, {"__builtins__": {}}, {}) # Safe eval
23
+ if isinstance(result, float):
24
+ return f"{result:.2f}" if "USD" in expression else str(result)
25
  return str(result)
26
  except Exception as e:
27
+ logger.error(f"Calculator error: {e}")
28
+ return f"Error: {e}"
29
+
30
+ calculator_tool = StructuredTool.from_function(
31
+ func=calculator_func,
32
+ name="calculator_tool",
33
+ args_schema=CalculatorInput,
34
+ coroutine=calculator_func
35
+ )
tools/document_retriever.py CHANGED
@@ -1,30 +1,47 @@
1
- from langchain_core.tools import tool
2
- from langchain_community.document_loaders import TextLoader, CSVLoader, PyPDFLoader
3
  import logging
4
  import os
 
 
 
5
 
6
  logger = logging.getLogger(__name__)
7
 
8
- @tool
9
- async def document_retriever_tool(task_id: str, query: str, file_type: str) -> str:
10
- """Retrieve content from a document."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  try:
12
- file_path = f"temp_{task_id}.{file_type}"
13
- if not os.path.exists(file_path):
14
- logger.warning(f"Document not found: {file_path}")
15
- return "Document not found"
16
-
17
- if file_type == "txt":
18
- loader = TextLoader(file_path)
19
- elif file_type == "csv":
20
- loader = CSVLoader(file_path)
21
- elif file_type == "pdf":
22
- loader = PyPDFLoader(file_path)
23
- else:
24
- return f"Unsupported file type: {file_type}"
25
-
26
- docs = loader.load()
27
- return "\n".join(doc.page_content for doc in docs)
28
  except Exception as e:
29
  logger.error(f"Error retrieving document for task {task_id}: {e}")
30
- return f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
  import os
3
+ from langchain_core.tools import StructuredTool
4
+ from pydantic import BaseModel, Field
5
+ from typing import Optional
6
 
7
  logger = logging.getLogger(__name__)
8
 
9
+ class DocumentRetrieverInput(BaseModel):
10
+ task_id: str = Field(description="Task identifier")
11
+ query: str = Field(description="Search query")
12
+ file_path: Optional[str] = Field(description="Path to document file", default=None)
13
+
14
+ async def document_retriever_func(task_id: str, query: str, file_path: Optional[str] = None) -> str:
15
+ """
16
+ Retrieve content from documents for a given task and query.
17
+
18
+ Args:
19
+ task_id (str): Task identifier.
20
+ query (str): Search query.
21
+ file_path (Optional[str]): Path to document file.
22
+
23
+ Returns:
24
+ str: Retrieved document content or error message.
25
+ """
26
  try:
27
+ if file_path and os.path.exists(file_path):
28
+ logger.info(f"Retrieving document from {file_path} for task {task_id}")
29
+ if file_path.endswith('.pdf'):
30
+ import PyPDF2
31
+ with open(file_path, 'rb') as f:
32
+ reader = PyPDF2.PdfReader(f)
33
+ text = "".join(page.extract_text() or "" for page in reader.pages)
34
+ return text[:500] if text else "No text extracted"
35
+ return "Unsupported file format"
36
+ logger.warning(f"No valid documents found for task {task_id}")
37
+ return "Document not found"
 
 
 
 
 
38
  except Exception as e:
39
  logger.error(f"Error retrieving document for task {task_id}: {e}")
40
+ return f"Error: {str(e)}"
41
+
42
+ document_retriever_tool = StructuredTool.from_function(
43
+ func=document_retriever_func,
44
+ name="document_retriever_tool",
45
+ args_schema=DocumentRetrieverInput,
46
+ coroutine=document_retriever_func
47
+ )
tools/duckduckgo_search.py CHANGED
@@ -1,6 +1,99 @@
1
- from smolagents import Tool, DuckDuckGoSearchTool
2
  import logging
 
 
 
 
 
 
 
3
 
4
  logger = logging.getLogger(__name__)
5
 
6
- duckduckgo_search_tool = DuckDuckGoSearchTool()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
+ import os
3
+ import asyncio
4
+ from langchain_core.tools import StructuredTool
5
+ from pydantic import BaseModel, Field
6
+ from typing import Optional, List
7
+ from duckduckgo_search import DDGS
8
+ from serpapi import GoogleSearch
9
 
10
  logger = logging.getLogger(__name__)
11
 
12
+ class DuckDuckGoSearchInput(BaseModel):
13
+ query: str = Field(description="Search query")
14
+ original_query: str = Field(description="Original query for context")
15
+ embedder: Optional[object] = Field(description="SentenceTransformer embedder", default=None)
16
+
17
+ async def duckduckgo_search_func(query: str, original_query: str, embedder: Optional[object] = None) -> List[str]:
18
+ """
19
+ Perform a DuckDuckGo search with retries and fall back to SerpAPI if needed.
20
+
21
+ Args:
22
+ query (str): Search query.
23
+ original_query (str): Original query for context.
24
+ embedder (Optional[object]): SentenceTransformer for result filtering.
25
+
26
+ Returns:
27
+ List[str]: List of search result snippets.
28
+ """
29
+ async def try_duckduckgo(query: str, max_retries: int = 3) -> List[str]:
30
+ for attempt in range(max_retries):
31
+ try:
32
+ logger.info(f"DuckDuckGo search attempt {attempt + 1} for query: {query}")
33
+ with DDGS() as ddgs:
34
+ results = [r['body'] for r in ddgs.text(query, max_results=5)]
35
+ return results
36
+ except Exception as e:
37
+ if "Ratelimit" in str(e) and attempt < max_retries - 1:
38
+ wait_time = 2 ** attempt # Exponential backoff: 1s, 2s, 4s
39
+ logger.warning(f"DuckDuckGo rate limit hit, retrying in {wait_time}s: {e}")
40
+ await asyncio.sleep(wait_time)
41
+ else:
42
+ logger.error(f"DuckDuckGo search failed for query '{query}': {e}")
43
+ raise e
44
+ return []
45
+
46
+ async def try_serpapi(query: str, max_retries: int = 3) -> List[str]:
47
+ if not os.getenv("SERPAPI_API_KEY"):
48
+ logger.warning("SERPAPI_API_KEY not set, cannot use SerpAPI fallback")
49
+ return []
50
+ for attempt in range(max_retries):
51
+ try:
52
+ logger.info(f"SerpAPI search attempt {attempt + 1} for query: {query}")
53
+ params = {
54
+ "q": query,
55
+ "api_key": os.getenv("SERPAPI_API_KEY"),
56
+ "num": 5
57
+ }
58
+ search = GoogleSearch(params)
59
+ results = search.get_dict().get("organic_results", [])
60
+ return [result.get("snippet", "") for result in results if "snippet" in result]
61
+ except Exception as e:
62
+ if attempt < max_retries - 1:
63
+ wait_time = 2 ** attempt # Exponential backoff: 1s, 2s, 4s
64
+ logger.warning(f"SerpAPI search failed, retrying in {wait_time}s: {e}")
65
+ await asyncio.sleep(wait_time)
66
+ else:
67
+ logger.error(f"SerpAPI search failed for query '{query}': {e}")
68
+ return []
69
+
70
+ try:
71
+ # Try DuckDuckGo with retries
72
+ logger.info(f"Executing DuckDuckGo search for query: {query}")
73
+ results = await try_duckduckgo(query)
74
+
75
+ # Fall back to SerpAPI if DuckDuckGo fails
76
+ if not results:
77
+ logger.info(f"DuckDuckGo returned no results, falling back to SerpAPI for query: {query}")
78
+ results = await try_serpapi(query)
79
+
80
+ # Rank results if embedder is provided
81
+ if embedder and results:
82
+ from sentence_transformers import util
83
+ query_embedding = embedder.encode(original_query, convert_to_tensor=True)
84
+ result_embeddings = embedder.encode(results, convert_to_tensor=True)
85
+ scores = util.cos_sim(query_embedding, result_embeddings)[0]
86
+ ranked_results = [results[i] for i in scores.argsort(descending=True)]
87
+ return ranked_results[:3]
88
+
89
+ return results[:3] if results else []
90
+ except Exception as e:
91
+ logger.error(f"Search failed for query '{query}': {e}")
92
+ return []
93
+
94
+ duckduckgo_search_tool = StructuredTool.from_function(
95
+ func=duckduckgo_search_func,
96
+ name="duckduckgo_search_tool",
97
+ args_schema=DuckDuckGoSearchInput,
98
+ coroutine=duckduckgo_search_func
99
+ )
tools/file_fetcher.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import ssl
3
+ import aiohttp
4
+ import logging
5
+ from typing import Dict
6
+ from urllib.parse import urljoin
7
+
8
+ # Setup logging
9
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(name)s - %(message)s')
10
+ logger = logging.getLogger(__name__)
11
+
12
+ async def fetch_task_file(task_id: str, question: str) -> Dict[str, bytes]:
13
+ """
14
+ Fetch a file associated with a task from the GAIA API.
15
+ Returns a dictionary of file extensions to content.
16
+ """
17
+ results = {}
18
+ base_url = "https://gaia-benchmark-api.hf.space/files/" # Updated URL
19
+ extensions = ["xlsx", "csv", "pdf", "txt", "mp3", "jpg", "png"]
20
+
21
+ ssl_context = ssl.create_default_context()
22
+ ssl_context.check_hostname = False
23
+ ssl_context.verify_mode = ssl.CERT_NONE
24
+
25
+ async with aiohttp.ClientSession(
26
+ connector=aiohttp.TCPConnector(ssl_context=ssl_context),
27
+ timeout=aiohttp.ClientTimeout(total=30)
28
+ ) as session:
29
+ for ext in extensions:
30
+ file_url = urljoin(base_url, f"{task_id}/{task_id}.{ext}")
31
+ try:
32
+ async with session.get(file_url) as response:
33
+ if response.status == 200:
34
+ content = await response.read()
35
+ results[ext] = content
36
+ logger.info(f"Fetched {ext} for task {task_id}")
37
+ else:
38
+ logger.warning(f"No {ext} for task {task_id}: HTTP {response.status}")
39
+ except Exception as e:
40
+ logger.warning(f"Error fetching {ext} for task {task_id}: {str(e)}")
41
+
42
+ return results
tools/file_parser.py CHANGED
@@ -1,36 +1,112 @@
1
- from langchain_core.tools import tool
2
- import pandas as pd
3
- import PyPDF2
4
  import logging
5
  import os
 
 
 
 
 
 
 
6
 
7
  logger = logging.getLogger(__name__)
8
 
9
- @tool
10
- async def file_parser_tool(task_id: str, file_type: str) -> str:
11
- """Parse a file based on task_id and file_type."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  try:
13
- file_path = f"temp_{task_id}.{file_type}"
14
  if not os.path.exists(file_path):
15
  logger.warning(f"File not found: {file_path}")
16
  return "File not found"
17
 
18
- if file_type == "csv":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  df = pd.read_csv(file_path)
20
- return df.to_string()
21
- elif file_type == "txt":
22
- with open(file_path, "r", encoding="utf-8") as f:
23
- return f.read()
 
 
 
 
24
  elif file_type == "pdf":
25
  with open(file_path, "rb") as f:
26
  reader = PyPDF2.PdfReader(f)
27
- text = "".join(page.extract_text() for page in reader.pages)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  return text
29
- elif file_type in ["xlsx", "xls"]:
30
- df = pd.read_excel(file_path, engine="openpyxl")
31
- return df.to_string()
 
 
 
 
32
  else:
 
33
  return f"Unsupported file type: {file_type}"
 
34
  except Exception as e:
35
  logger.error(f"Error parsing file for task {task_id}: {e}")
36
- return f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
  import os
3
+ import pandas as pd
4
+ import PyPDF2
5
+ import speech_recognition as sr
6
+ import re
7
+ from langchain_core.tools import StructuredTool
8
+ from pydantic import BaseModel, Field
9
+ from typing import Optional
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
+ class FileParserInput(BaseModel):
14
+ task_id: str = Field(description="Task identifier")
15
+ file_type: str = Field(description="File extension (e.g., pdf, csv)")
16
+ file_path: str = Field(description="Path to the file")
17
+ query: Optional[str] = Field(description="Query related to the file", default=None)
18
+
19
+ async def file_parser_func(task_id: str, file_type: str, file_path: str, query: Optional[str] = None) -> str:
20
+ """
21
+ Parse a file based on task_id, file_type, file_path, and query context.
22
+
23
+ Args:
24
+ task_id (str): Task identifier.
25
+ file_type (str): File extension (e.g., 'xlsx', 'mp3', 'pdf').
26
+ file_path (str): Path to the file.
27
+ query (Optional[str]): Question context to guide parsing (e.g., for specific data extraction).
28
+
29
+ Returns:
30
+ str: Parsed content or error message.
31
+ """
32
  try:
 
33
  if not os.path.exists(file_path):
34
  logger.warning(f"File not found: {file_path}")
35
  return "File not found"
36
 
37
+ logger.info(f"Parsing file: {file_path} for task {task_id}")
38
+
39
+ if file_type in ["xlsx", "xls"]:
40
+ df = pd.read_excel(file_path, engine="openpyxl")
41
+ if query and ("sum" in query.lower() or "total" in query.lower()):
42
+ numerical_cols = df.select_dtypes(include=['float64', 'int64']).columns
43
+ if numerical_cols.empty:
44
+ return "No numerical data found"
45
+ if "food" in query.lower():
46
+ food_rows = df[df.apply(lambda x: "food" in str(x).lower(), axis=1)]
47
+ if not food_rows.empty and numerical_cols[0] in food_rows:
48
+ total = food_rows[numerical_cols[0]].sum()
49
+ return f"{total:.2f}"
50
+ total = df[numerical_cols[0]].sum()
51
+ return f"{total:.2f}"
52
+ return df.to_string(index=False)
53
+
54
+ elif file_type == "csv":
55
  df = pd.read_csv(file_path)
56
+ if query and ("sum" in query.lower() or "total" in query.lower()):
57
+ numerical_cols = df.select_dtypes(include=['float64', 'int64']).columns
58
+ if numerical_cols.empty:
59
+ return "No numerical data found"
60
+ total = df[numerical_cols[0]].sum()
61
+ return f"{total:.2f}"
62
+ return df.to_string(index=False)
63
+
64
  elif file_type == "pdf":
65
  with open(file_path, "rb") as f:
66
  reader = PyPDF2.PdfReader(f)
67
+ text = "".join(page.extract_text() or "" for page in reader.pages)
68
+ if query and "page number" in query.lower():
69
+ pages = re.findall(r'\b\d+\b', text)
70
+ return ", ".join(sorted(pages, key=int)) if pages else "No page numbers found"
71
+ return text.strip() or "No text extracted"
72
+
73
+ elif file_type == "txt":
74
+ with open(file_path, "r", encoding="utf-8") as f:
75
+ text = f.read()
76
+ if query and "page number" in query.lower():
77
+ pages = re.findall(r'\b\d+\b', text)
78
+ return ", ".join(sorted(pages, key=int)) if pages else "No page numbers found"
79
+ return text.strip()
80
+
81
+ elif file_type == "mp3":
82
+ recognizer = sr.Recognizer()
83
+ with sr.AudioFile(file_path) as source:
84
+ audio = recognizer.record(source)
85
+ try:
86
+ text = recognizer.recognize_google(audio)
87
+ logger.debug(f"Transcribed audio: {text}")
88
+ if query and "page number" in query.lower():
89
+ pages = re.findall(r'\b\d+\b', text)
90
+ return ", ".join(sorted(pages, key=int)) if pages else "No page numbers provided"
91
  return text
92
+ except sr.UnknownValueError:
93
+ logger.error("Could not understand audio")
94
+ return "No text transcribed from audio"
95
+ except Exception as e:
96
+ logger.error(f"Audio parsing failed: {e}")
97
+ return "Error transcribing audio"
98
+
99
  else:
100
+ logger.warning(f"Unsupported file type: {file_type}")
101
  return f"Unsupported file type: {file_type}"
102
+
103
  except Exception as e:
104
  logger.error(f"Error parsing file for task {task_id}: {e}")
105
+ return f"Error: {str(e)}"
106
+
107
+ file_parser_tool = StructuredTool.from_function(
108
+ func=file_parser_func,
109
+ name="file_parser_tool",
110
+ args_schema=FileParserInput,
111
+ coroutine=file_parser_func
112
+ )
tools/guest_info.py CHANGED
@@ -1,20 +1,47 @@
1
- from langchain_core.tools import tool
2
- from retriever import load_guest_dataset
3
  import logging
 
 
 
 
4
 
5
  logger = logging.getLogger(__name__)
6
 
7
- @tool
8
- async def guest_info_retriever_tool(query: str) -> str:
9
- """Retrieve detailed information about gala guests based on their name or relation."""
 
 
 
 
 
 
 
 
 
 
10
  try:
11
- docs = load_guest_dataset()
12
- from langchain_community.retrievers import BM25Retriever
13
- retriever = BM25Retriever.from_documents(docs)
14
- results = retriever.get_relevant_documents(query)
15
- if results:
16
- return "\n\n".join([doc.page_content for doc in results[:3]])
17
- return "No matching guest information found."
 
 
 
 
 
 
 
 
18
  except Exception as e:
19
  logger.error(f"Error retrieving guest info for query '{query}': {e}")
20
- return f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
+ from langchain_core.tools import StructuredTool
3
+ from pydantic import BaseModel, Field
4
+ from datasets import load_dataset
5
+ from rank_bm25 import BM25Okapi
6
 
7
  logger = logging.getLogger(__name__)
8
 
9
+ class GuestInfoInput(BaseModel):
10
+ query: str = Field(description="Query about guest information")
11
+
12
+ async def guest_info_func(query: str) -> str:
13
+ """
14
+ Retrieve guest information based on a query.
15
+
16
+ Args:
17
+ query (str): Query about guest information.
18
+
19
+ Returns:
20
+ str: Guest information or error message.
21
+ """
22
  try:
23
+ logger.info(f"Retrieving guest info for query: {query}")
24
+ dataset = load_dataset("agents-course/unit3-invitees", split="train")
25
+ logger.info(f"Loaded {len(dataset)} guests from Hugging Face dataset")
26
+
27
+ documents = [f"{row['name']} {row['relation']}" for row in dataset]
28
+ tokenized_docs = [doc.lower().split() for doc in documents]
29
+ bm25 = BM25Okapi(tokenized_docs)
30
+
31
+ tokenized_query = query.lower().split()
32
+ scores = bm25.get_scores(tokenized_query)
33
+ best_idx = scores.argmax()
34
+
35
+ if scores[best_idx] > 0:
36
+ return f"Guest: {dataset[best_idx]['name']}, Relation: {dataset[best_idx]['relation']}"
37
+ return "No matching guest found"
38
  except Exception as e:
39
  logger.error(f"Error retrieving guest info for query '{query}': {e}")
40
+ return f"Error: {str(e)}"
41
+
42
+ guest_info_retriever_tool = StructuredTool.from_function(
43
+ func=guest_info_func,
44
+ name="guest_info_retriever_tool",
45
+ args_schema=GuestInfoInput,
46
+ coroutine=guest_info_func
47
+ )
tools/hub_stats.py CHANGED
@@ -1,17 +1,54 @@
1
- from langchain_core.tools import tool
2
- from huggingface_hub import list_models
3
  import logging
 
 
 
 
 
4
 
5
  logger = logging.getLogger(__name__)
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  @tool
8
  async def hub_stats_tool(author: str) -> str:
9
- """Fetch the most downloaded model from a specific author on Hugging Face Hub."""
 
 
 
 
 
 
 
 
10
  try:
11
- models = list(list_models(author=author, sort="downloads", direction=-1, limit=1))
12
- if models:
 
 
 
 
 
 
 
 
 
 
 
13
  model = models[0]
14
- return f"The most downloaded model by {author} is {model.id} with {model.downloads:,} downloads."
15
  return f"No models found for author {author}."
16
  except Exception as e:
17
  logger.error(f"Error fetching models for {author}: {e}")
 
1
+ import aiohttp
2
+ import ssl
3
  import logging
4
+ from langchain_core.tools import tool
5
+ from tenacity import retry, stop_after_attempt, wait_exponential
6
+ from typing import Optional
7
+ import json
8
+ import os
9
 
10
  logger = logging.getLogger(__name__)
11
 
12
+ @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10))
13
+ async def fetch_hf_models(author: str) -> Optional[dict]:
14
+ url = f"https://huggingface.co/api/models?author={author}&sort=downloads&direction=-1&limit=1"
15
+ ssl_context = ssl.create_default_context()
16
+ try:
17
+ async with aiohttp.ClientSession() as session:
18
+ async with session.get(url, ssl=ssl_context) as response:
19
+ response.raise_for_status()
20
+ return await response.json()
21
+ except aiohttp.ClientError as e:
22
+ logger.error(f"Failed to fetch models for {author}: {e}")
23
+ raise
24
+
25
  @tool
26
  async def hub_stats_tool(author: str) -> str:
27
+ """
28
+ Fetch the most downloaded model from a specific author on Hugging Face Hub.
29
+
30
+ Args:
31
+ author (str): Hugging Face author username.
32
+
33
+ Returns:
34
+ str: Model information or error message.
35
+ """
36
  try:
37
+ # Check local cache
38
+ cache_file = f"temp/hf_cache_{author}.json"
39
+ if os.path.exists(cache_file):
40
+ with open(cache_file, "r") as f:
41
+ models = json.load(f)
42
+ logger.debug(f"Loaded cached models for {author}")
43
+ else:
44
+ models = await fetch_hf_models(author)
45
+ os.makedirs("temp", exist_ok=True)
46
+ with open(cache_file, "w") as f:
47
+ json.dump(models, f)
48
+
49
+ if models and isinstance(models, list) and models:
50
  model = models[0]
51
+ return f"The most downloaded model by {author} is {model['id']} with {model.get('downloads', 0):,} downloads."
52
  return f"No models found for author {author}."
53
  except Exception as e:
54
  logger.error(f"Error fetching models for {author}: {e}")
tools/image_parser.py CHANGED
@@ -1,25 +1,43 @@
1
- from langchain_core.tools import tool
2
- import easyocr
3
  import logging
4
  import os
 
 
 
5
 
6
  logger = logging.getLogger(__name__)
7
- reader = easyocr.Reader(['en'])
8
 
9
- @tool
10
- async def image_parser_tool(file_path: str, task: str = "describe", match_query: str = "") -> str:
11
- """Parse text from an image."""
 
 
 
 
 
 
 
 
 
 
 
 
12
  try:
13
  if not os.path.exists(file_path):
14
- logger.warning(f"Image not found: {file_path}")
15
- return "Image not found"
16
 
17
- results = reader.readtext(file_path)
18
- text = " ".join(result[1] for result in results)
19
-
20
- if task == "match" and match_query:
21
- return str(match_query.lower() in text.lower())
22
- return text
23
  except Exception as e:
24
- logger.error(f"Error parsing image {file_path}: {e}")
25
- return f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
  import os
3
+ from langchain_core.tools import StructuredTool
4
+ from pydantic import BaseModel, Field
5
+ import easyocr
6
 
7
  logger = logging.getLogger(__name__)
 
8
 
9
+ class ImageParserInput(BaseModel):
10
+ task_id: str = Field(description="Task identifier")
11
+ file_path: str = Field(description="Path to the image file")
12
+
13
+ async def image_parser_func(task_id: str, file_path: str) -> str:
14
+ """
15
+ Parse text from an image file using OCR.
16
+
17
+ Args:
18
+ task_id (str): Task identifier.
19
+ file_path (str): Path to the image file.
20
+
21
+ Returns:
22
+ str: Extracted text or error message.
23
+ """
24
  try:
25
  if not os.path.exists(file_path):
26
+ logger.warning(f"Image file not found: {file_path}")
27
+ return "Image file not found"
28
 
29
+ logger.info(f"Parsing image: {file_path} for task {task_id}")
30
+ reader = easyocr.Reader(['en'], model_storage_directory='./cache')
31
+ result = reader.readtext(file_path, detail=0)
32
+ text = " ".join(result).strip()
33
+ return text if text else "No text extracted from image"
 
34
  except Exception as e:
35
+ logger.error(f"Error parsing image for task {task_id}: {e}")
36
+ return f"Error: {str(e)}"
37
+
38
+ image_parser_tool = StructuredTool.from_function(
39
+ func=image_parser_func,
40
+ name="image_parser_tool",
41
+ args_schema=ImageParserInput,
42
+ coroutine=image_parser_func
43
+ )
tools/search.py CHANGED
@@ -1,106 +1,103 @@
 
1
  import os
2
- import json
3
- import asyncio
4
- # from serpapi import GoogleSearch
5
- from google_search_results import GoogleSearch
6
- from langchain.tools import Tool
7
- from typing import List, Dict, Any
8
- from langchain_core.prompts import ChatPromptTemplate
9
- from langchain_core.messages import SystemMessage, HumanMessage
10
 
11
- def search_tool(query: str) -> List[str]:
 
 
 
 
 
12
  """
13
- Perform a web search using SERPAPI with retries.
14
 
15
  Args:
16
- query: Search query string.
17
 
18
  Returns:
19
- List of search result snippets.
20
-
21
- Raises:
22
- Exception: If search fails after retries.
23
  """
24
- params = {
25
- "q": query,
26
- "api_key": os.getenv("SERPAPI_API_KEY"),
27
- "num": 5,
28
- }
29
-
30
- for attempt in range(3):
31
- try:
32
- search = GoogleSearch(params, timeout=30)
33
- results = search.get_dict()
34
- organic_results = results.get("organic_results", [])
35
- return [r.get("snippet", "") for r in organic_results]
36
- except Exception as e:
37
- print(f"INFO - SERPAPI retry {attempt + 1}/3 due to: {e}")
38
- asyncio.sleep(2)
39
-
40
- raise Exception("SERPAPI failed after retries")
41
 
42
- async def multi_hop_search_tool(query: str, steps: int = 3, llm_client: Any = None, llm_type: str = None) -> List[Dict[str, str]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  """
44
- Perform iterative web searches for complex queries, refining the query using an LLM.
45
 
46
  Args:
47
- query: Initial search query.
48
- steps: Number of search iterations.
49
- llm_client: LLM client for query refinement.
50
- llm_type: Type of LLM client ("together", "hf_api", or "hf_local").
 
51
 
52
  Returns:
53
- List of dictionaries containing search result content.
54
  """
55
- results = []
56
- current_query = query
57
-
58
- for step in range(steps):
59
- try:
60
- # Perform search
61
- search_results = search_tool(current_query)
62
- results.extend([{"content": str(r)} for r in search_results])
 
63
 
64
- # Refine query using LLM if available
65
- if llm_client and step < steps - 1:
66
- prompt = ChatPromptTemplate.from_messages([
67
- SystemMessage(content="""Refine the following query to dig deeper into the topic, focusing on missing details or related aspects. Return ONLY the refined query as plain text, no explanations."""),
68
- HumanMessage(content=f"Original query: {current_query}\nPrevious results: {json.dumps(search_results[:2], indent=2)}")
69
- ])
70
  messages = [
71
- {"role": "system", "content": prompt[0].content},
72
- {"role": "user", "content": prompt[1].content}
73
  ]
74
-
75
- try:
76
- if llm_type == "hf_local":
77
- model, tokenizer = llm_client
78
- inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("mps")
79
- outputs = model.generate(inputs, max_new_tokens=100, temperature=0.7)
80
- refined_query = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
81
- else:
82
- response = llm_client.chat.completions.create(
83
- model=llm_client.model if llm_type == "together" else "meta-llama/Llama-3.2-1B-Instruct",
84
- messages=messages,
85
- max_tokens=100,
86
- temperature=0.7
87
- )
88
- refined_query = response.choices[0].message.content.strip()
89
-
90
- current_query = refined_query if refined_query else f"more details on {current_query}"
91
- except Exception as e:
92
- print(f"INFO - Query refinement failed at step {step + 1}: {e}")
93
- current_query = f"more details on {current_query}"
94
-
95
- await asyncio.sleep(1) # Rate limit
96
- except Exception as e:
97
- print(f"INFO - Multi-hop search step {step + 1} failed: {e}")
98
- break
99
-
100
- return results
101
 
102
- multi_hop_search_tool = Tool.from_function(
103
- func=multi_hop_search_tool,
104
  name="multi_hop_search_tool",
105
- description="Performs iterative web searches for complex queries, refining the query with an LLM."
 
106
  )
 
1
+ import logging
2
  import os
3
+ from langchain_core.tools import StructuredTool
4
+ from pydantic import BaseModel, Field
5
+ from typing import Optional, List
6
+ from serpapi import GoogleSearch
 
 
 
 
7
 
8
+ logger = logging.getLogger(__name__)
9
+
10
+ class SearchInput(BaseModel):
11
+ query: str = Field(description="Search query")
12
+
13
+ async def search_func(query: str) -> List[str]:
14
  """
15
+ Perform a web search using SerpAPI and return relevant snippets.
16
 
17
  Args:
18
+ query (str): The search query to execute.
19
 
20
  Returns:
21
+ List[str]: A list of search result snippets.
 
 
 
22
  """
23
+ try:
24
+ logger.info(f"Executing SerpAPI search for query: {query}")
25
+ params = {
26
+ "q": query,
27
+ "api_key": os.getenv("SERPAPI_API_KEY"),
28
+ "num": 10
29
+ }
30
+ search = GoogleSearch(params)
31
+ results = search.get_dict().get("organic_results", [])
32
+ return [result.get("snippet", "") for result in results if "snippet" in result]
33
+ except Exception as e:
34
+ logger.error(f"SerpAPI search failed for query '{query}': {e}")
35
+ return []
 
 
 
 
36
 
37
+ search_tool = StructuredTool.from_function(
38
+ func=search_func,
39
+ name="search_tool",
40
+ args_schema=SearchInput,
41
+ coroutine=search_func
42
+ )
43
+
44
+ class MultiHopSearchInput(BaseModel):
45
+ query: str = Field(description="Multi-hop search query")
46
+ steps: int = Field(description="Number of search steps", ge=1, le=3)
47
+ llm_client: Optional[object] = Field(description="LLM client", default=None)
48
+ llm_type: Optional[str] = Field(description="LLM type", default="together")
49
+ llm_model: Optional[str] = Field(description="LLM model", default="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free")
50
+
51
+ async def multi_hop_search_func(query: str, steps: int, llm_client: Optional[object] = None, llm_type: Optional[str] = "together", llm_model: Optional[str] = "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free") -> List[str]:
52
  """
53
+ Perform a multi-hop web search using SerpAPI with iterative query refinement.
54
 
55
  Args:
56
+ query (str): The initial multi-hop search query.
57
+ steps (int): Number of search steps to perform (1 to 3).
58
+ llm_client (Optional[object]): LLM client for query refinement.
59
+ llm_type (Optional[str]): Type of LLM (e.g., 'together').
60
+ llm_model (Optional[str]): LLM model name.
61
 
62
  Returns:
63
+ List[str]: A list of search result snippets from all steps.
64
  """
65
+ try:
66
+ logger.info(f"Executing multi-hop search for query: {query}, steps: {steps}")
67
+ results = []
68
+ current_query = query
69
+
70
+ for step in range(steps):
71
+ logger.info(f"Multi-hop step {step + 1}: {current_query}")
72
+ step_results = await search_func(current_query)
73
+ results.extend(step_results)
74
 
75
+ if step < steps - 1 and llm_client:
76
+ prompt = f"Given the query '{current_query}' and results: {step_results[:3]}, generate a follow-up search query to refine or expand the search."
 
 
 
 
77
  messages = [
78
+ {"role": "system", "content": "Generate a single search query as a string."},
79
+ {"role": "user", "content": prompt}
80
  ]
81
+ if llm_type == "together":
82
+ response = llm_client.chat.completions.create(
83
+ model=llm_model,
84
+ messages=messages,
85
+ max_tokens=50,
86
+ temperature=0.7
87
+ )
88
+ current_query = response.choices[0].message.content.strip()
89
+ else:
90
+ logger.warning("LLM not configured for multi-hop refinement")
91
+ break
92
+
93
+ return results[:5] if results else ["No results found"]
94
+ except Exception as e:
95
+ logger.error(f"Multi-hop search failed for query '{query}': {e}")
96
+ return [f"Error: {str(e)}"]
 
 
 
 
 
 
 
 
 
 
 
97
 
98
+ multi_hop_search_tool = StructuredTool.from_function(
99
+ func=multi_hop_search_func,
100
  name="multi_hop_search_tool",
101
+ args_schema=MultiHopSearchInput,
102
+ coroutine=multi_hop_search_func
103
  )
tools/weather_info.py CHANGED
@@ -1,23 +1,50 @@
1
- from langchain_core.tools import tool
2
- import requests
3
  import logging
4
  import os
 
 
5
  from dotenv import load_dotenv
6
 
7
  logger = logging.getLogger(__name__)
8
  load_dotenv()
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  @tool
11
- async def weather_info_tool(location: str) -> str:
12
- """Fetch real weather information for a given location."""
 
 
 
 
 
 
 
 
 
13
  try:
14
  api_key = os.getenv("OPENWEATHERMAP_API_KEY")
15
  if not api_key:
16
  logger.error("OPENWEATHERMAP_API_KEY not set")
17
  return "Weather unavailable: API key missing"
18
 
19
- url = f"http://api.openweathermap.org/data/2.5/weather?q={location}&appid={api_key}&units=metric"
20
- response = requests.get(url).json()
 
 
 
21
  if response.get("cod") == 200:
22
  condition = response["weather"][0]["description"]
23
  temp = response["main"]["temp"]
 
1
+ import aiohttp
2
+ import ssl
3
  import logging
4
  import os
5
+ from langchain_core.tools import tool
6
+ from tenacity import retry, stop_after_attempt, wait_exponential
7
  from dotenv import load_dotenv
8
 
9
  logger = logging.getLogger(__name__)
10
  load_dotenv()
11
 
12
+ @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10))
13
+ async def fetch_weather(location: str, api_key: str) -> dict:
14
+ url = f"http://api.openweathermap.org/data/2.5/weather?q={location}&appid={api_key}&units=metric"
15
+ ssl_context = ssl.create_default_context()
16
+ try:
17
+ async with aiohttp.ClientSession() as session:
18
+ async with session.get(url, ssl=ssl_context) as response:
19
+ response.raise_for_status()
20
+ return await response.json()
21
+ except aiohttp.ClientError as e:
22
+ logger.error(f"Failed to fetch weather for {location}: {e}")
23
+ raise
24
+
25
  @tool
26
+ async def weather_info_tool(location: str, query_type: str = "current") -> str:
27
+ """
28
+ Fetch weather information for a given location.
29
+
30
+ Args:
31
+ location (str): City or location name.
32
+ query_type (str): Type of weather query ('current', 'forecast'; default: 'current').
33
+
34
+ Returns:
35
+ str: Weather information or error message.
36
+ """
37
  try:
38
  api_key = os.getenv("OPENWEATHERMAP_API_KEY")
39
  if not api_key:
40
  logger.error("OPENWEATHERMAP_API_KEY not set")
41
  return "Weather unavailable: API key missing"
42
 
43
+ if query_type != "current":
44
+ logger.warning(f"Query type '{query_type}' not supported; using current weather")
45
+ query_type = "current"
46
+
47
+ response = await fetch_weather(location, api_key)
48
  if response.get("cod") == 200:
49
  condition = response["weather"][0]["description"]
50
  temp = response["main"]["temp"]