mjschock commited on
Commit
2da6a11
·
unverified ·
1 Parent(s): 0305659

Refactor main_v2.py to update task formatting for dual answer requests, enhancing response structure. Implement error handling for JSON parsing in agent results, ensuring robust output. Add unit tests in test_questions.py to validate succinct answer accuracy against expected values. Remove unused extract_final_answer utility from utils.py, streamlining the codebase.

Browse files
Files changed (7) hide show
  1. .vscode/tasks.json +7 -2
  2. app.py +1 -1
  3. main.py +2 -2
  4. main_v2.py +41 -20
  5. test_questions.py +35 -0
  6. tools/smart_search/tool.py +92 -78
  7. utils.py +0 -70
.vscode/tasks.json CHANGED
@@ -4,7 +4,7 @@
4
  "version": "2.0.0",
5
  "tasks": [
6
  {
7
- "label": "serve",
8
  "type": "shell",
9
  "command": "uv run python -m phoenix.server.main serve"
10
  },
@@ -12,6 +12,11 @@
12
  "label": "run",
13
  "type": "shell",
14
  "command": "uv run python -m main_v2"
 
 
 
 
 
15
  }
16
  ]
17
- }
 
4
  "version": "2.0.0",
5
  "tasks": [
6
  {
7
+ "label": "prerun",
8
  "type": "shell",
9
  "command": "uv run python -m phoenix.server.main serve"
10
  },
 
12
  "label": "run",
13
  "type": "shell",
14
  "command": "uv run python -m main_v2"
15
+ },
16
+ {
17
+ "label": "test",
18
+ "type": "shell",
19
+ "command": "uv run python -m unittest test_questions.py"
20
  }
21
  ]
22
+ }
app.py CHANGED
@@ -30,7 +30,7 @@ class BasicAgent:
30
  # question_counter += 1
31
 
32
  # if question_counter > 1:
33
- # return "This is a default answer."
34
 
35
  # fixed_answer = "This is a default answer."
36
  # print(f"Agent returning fixed answer: {fixed_answer}")
 
30
  # question_counter += 1
31
 
32
  # if question_counter > 1:
33
+ # return "This is a default answer."
34
 
35
  # fixed_answer = "This is a default answer."
36
  # print(f"Agent returning fixed answer: {fixed_answer}")
main.py CHANGED
@@ -14,14 +14,14 @@ from langgraph.graph import END, START, StateGraph
14
  from openinference.instrumentation.smolagents import SmolagentsInstrumentor
15
  from opentelemetry.sdk.trace.export import BatchSpanProcessor
16
  from phoenix.otel import register
17
- from prompts import MANAGER_SYSTEM_PROMPT
18
  from smolagents import CodeAgent, LiteLLMModel
19
  from smolagents.memory import ActionStep, FinalAnswerStep
20
  from smolagents.monitoring import LogLevel
 
21
 
22
  from agents import create_data_analysis_agent, create_media_agent, create_web_agent
 
23
  from tools import perform_calculation, web_search
24
- from utils import extract_final_answer
25
 
26
  litellm._turn_on_debug()
27
 
 
14
  from openinference.instrumentation.smolagents import SmolagentsInstrumentor
15
  from opentelemetry.sdk.trace.export import BatchSpanProcessor
16
  from phoenix.otel import register
 
17
  from smolagents import CodeAgent, LiteLLMModel
18
  from smolagents.memory import ActionStep, FinalAnswerStep
19
  from smolagents.monitoring import LogLevel
20
+ from utils import extract_final_answer
21
 
22
  from agents import create_data_analysis_agent, create_media_agent, create_web_agent
23
+ from prompts import MANAGER_SYSTEM_PROMPT
24
  from tools import perform_calculation, web_search
 
25
 
26
  litellm._turn_on_debug()
27
 
main_v2.py CHANGED
@@ -13,7 +13,6 @@ from smolagents import CodeAgent, LiteLLMModel
13
  from smolagents.monitoring import LogLevel
14
 
15
  from tools.smart_search.tool import SmartSearchTool
16
- from utils import extract_final_answer
17
 
18
  _disable_debugging()
19
 
@@ -78,20 +77,32 @@ agent.visualize()
78
 
79
 
80
  def main(task: str):
81
- # Format the task with GAIA-style instructions
82
- # gaia_task = f"""Instructions:
83
- # 1. Your response must contain ONLY the answer to the question, nothing else
84
- # 2. Do not repeat the question or any part of it
85
- # 3. Do not include any explanations, reasoning, or context
86
- # 4. Do not include source attribution or references
87
- # 5. Do not use phrases like "The answer is" or "I found that"
88
- # 6. Do not include any formatting, bullet points, or line breaks
89
- # 7. If the answer is a number, return only the number
90
- # 8. If the answer requires multiple items, separate them with commas
91
- # 9. If the answer requires ordering, maintain the specified order
92
- # 10. Use the most direct and succinct form possible
93
-
94
- # {task}"""
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  result = agent.run(
97
  additional_args=None,
@@ -99,13 +110,23 @@ def main(task: str):
99
  max_steps=3,
100
  reset=True,
101
  stream=False,
102
- task=task,
103
- # task=gaia_task,
104
  )
105
 
106
- logger.info(f"Result: {result}")
 
 
107
 
108
- return extract_final_answer(result)
 
 
 
 
 
 
 
 
 
109
 
110
 
111
  if __name__ == "__main__":
@@ -119,7 +140,7 @@ if __name__ == "__main__":
119
  response.raise_for_status()
120
  questions_data = response.json()
121
 
122
- for question_data in questions_data[:1]:
123
  file_name = question_data["file_name"]
124
  level = question_data["Level"]
125
  question = question_data["question"]
 
13
  from smolagents.monitoring import LogLevel
14
 
15
  from tools.smart_search.tool import SmartSearchTool
 
16
 
17
  _disable_debugging()
18
 
 
77
 
78
 
79
  def main(task: str):
80
+ # Format the task to request both succinct and verbose answers
81
+ formatted_task = f"""Please provide two answers to the following question:
82
+
83
+ 1. A succinct answer that follows these rules:
84
+ - Contains ONLY the answer, nothing else
85
+ - Does not repeat the question
86
+ - Does not include explanations, reasoning, or context
87
+ - Does not include source attribution or references
88
+ - Does not use phrases like "The answer is" or "I found that"
89
+ - Does not include formatting, bullet points, or line breaks
90
+ - If the answer is a number, return only the number
91
+ - If the answer requires multiple items, separate them with commas
92
+ - If the answer requires ordering, maintain the specified order
93
+ - Uses the most direct and succinct form possible
94
+
95
+ 2. A verbose answer that includes:
96
+ - The complete answer with all relevant details
97
+ - Explanations and reasoning
98
+ - Context and background information
99
+ - Source attribution where appropriate
100
+
101
+ Question: {task}
102
+
103
+ Please format your response as a JSON object with two keys:
104
+ - "succinct_answer": The concise answer following the rules above
105
+ - "verbose_answer": The detailed explanation with context"""
106
 
107
  result = agent.run(
108
  additional_args=None,
 
110
  max_steps=3,
111
  reset=True,
112
  stream=False,
113
+ task=formatted_task,
 
114
  )
115
 
116
+ # Parse the result into a dictionary
117
+ try:
118
+ import json
119
 
120
+ # Find the JSON object in the response
121
+ json_str = result[result.find("{") : result.rfind("}") + 1]
122
+ parsed_result = json.loads(json_str)
123
+ except (ValueError, AttributeError) as e:
124
+ logger.error(f"Error parsing result: {e}")
125
+ # If parsing fails, return the raw result
126
+ return result
127
+
128
+ logger.info(f"Result: {parsed_result}")
129
+ return parsed_result["succinct_answer"]
130
 
131
 
132
  if __name__ == "__main__":
 
140
  response.raise_for_status()
141
  questions_data = response.json()
142
 
143
+ for question_data in questions_data[:2]:
144
  file_name = question_data["file_name"]
145
  level = question_data["Level"]
146
  question = question_data["question"]
test_questions.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import requests
3
+ from main_v2 import main
4
+
5
+ class TestQuestions(unittest.TestCase):
6
+ def setUp(self):
7
+ self.api_url = "https://agents-course-unit4-scoring.hf.space"
8
+ self.questions_url = f"{self.api_url}/questions"
9
+
10
+ # Get questions from the API
11
+ response = requests.get(self.questions_url, timeout=15)
12
+ response.raise_for_status()
13
+ self.questions = response.json()
14
+
15
+ # Expected answers for each question
16
+ self.expected_answers = {
17
+ "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia.": "3",
18
+ # Add more expected answers as you verify them
19
+ }
20
+
21
+ def test_questions(self):
22
+ """Test each question and verify the succinct answer matches the expected value."""
23
+ for question_data in self.questions:
24
+ question = question_data["question"]
25
+ if question in self.expected_answers:
26
+ expected_answer = self.expected_answers[question]
27
+ actual_answer = main(question)
28
+ self.assertEqual(
29
+ actual_answer,
30
+ expected_answer,
31
+ f"Question: {question}\nExpected: {expected_answer}\nGot: {actual_answer}"
32
+ )
33
+
34
+ if __name__ == "__main__":
35
+ unittest.main()
tools/smart_search/tool.py CHANGED
@@ -1,10 +1,10 @@
1
  import logging
2
  import re
3
- from typing import List, Dict, Optional
 
 
4
  from smolagents import Tool
5
  from smolagents.default_tools import DuckDuckGoSearchTool
6
- import requests
7
- from bs4 import BeautifulSoup
8
 
9
  logger = logging.getLogger(__name__)
10
 
@@ -12,7 +12,12 @@ logger = logging.getLogger(__name__)
12
  class SmartSearchTool(Tool):
13
  name = "smart_search"
14
  description = """A smart search tool that searches Wikipedia for information."""
15
- inputs = {"query": {"type": "string", "description": "The search query to find information"}}
 
 
 
 
 
16
  output_type = "string"
17
 
18
  def __init__(self):
@@ -20,30 +25,30 @@ class SmartSearchTool(Tool):
20
  self.web_search_tool = DuckDuckGoSearchTool(max_results=1)
21
  self.api_url = "https://en.wikipedia.org/w/api.php"
22
  self.headers = {
23
- 'User-Agent': 'SmartSearchTool/1.0 (https://github.com/yourusername/yourproject; [email protected])'
24
  }
25
 
26
  def get_wikipedia_page(self, title: str) -> Optional[str]:
27
  """Get the raw wiki markup of a Wikipedia page."""
28
  try:
29
  params = {
30
- 'action': 'query',
31
- 'prop': 'revisions',
32
- 'rvprop': 'content',
33
- 'rvslots': 'main',
34
- 'format': 'json',
35
- 'titles': title,
36
- 'redirects': 1
37
  }
38
  response = requests.get(self.api_url, params=params, headers=self.headers)
39
  response.raise_for_status()
40
  data = response.json()
41
-
42
  # Extract page content
43
- pages = data.get('query', {}).get('pages', {})
44
  for page_id, page_data in pages.items():
45
- if 'revisions' in page_data:
46
- return page_data['revisions'][0]['slots']['main']['*']
47
  return None
48
  except Exception as e:
49
  logger.error(f"Error getting Wikipedia page: {e}")
@@ -52,117 +57,126 @@ class SmartSearchTool(Tool):
52
  def clean_wiki_content(self, content: str) -> str:
53
  """Clean Wikipedia content by removing markup and formatting."""
54
  # Remove citations
55
- content = re.sub(r'\[\d+\]', '', content)
56
  # Remove edit links
57
- content = re.sub(r'\[edit\]', '', content)
58
  # Remove file links
59
- content = re.sub(r'\[\[File:.*?\]\]', '', content)
60
  # Convert links to just text
61
- content = re.sub(r'\[\[(?:[^|\]]*\|)?([^\]]+)\]\]', r'\1', content)
62
  # Remove HTML comments
63
- content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)
64
  # Remove templates
65
- content = re.sub(r'\{\{.*?\}\}', '', content)
66
  # Remove small tags
67
- content = re.sub(r'<small>.*?</small>', '', content)
68
  # Normalize whitespace
69
- content = re.sub(r'\n\s*\n', '\n\n', content)
70
  return content.strip()
71
 
72
  def format_wiki_table(self, table_content: str) -> str:
73
  """Format a Wikipedia table into readable text."""
74
  # Split into rows
75
- rows = table_content.strip().split('\n')
76
  formatted_rows = []
77
  current_row = []
78
-
79
  for row in rows:
80
  # Skip empty rows and table structure markers
81
- if not row.strip() or row.startswith('|-') or row.startswith('|+'):
82
  if current_row:
83
- formatted_rows.append('\t'.join(current_row))
84
  current_row = []
85
  continue
86
-
87
  # Extract cells
88
  cells = []
89
  # Split the row into cells using | or ! as separators
90
- cell_parts = re.split(r'[|!]', row)
91
  for cell in cell_parts[1:]: # Skip the first empty part
92
  # Clean up the cell content
93
  cell = cell.strip()
94
  # Remove any remaining markup
95
- cell = re.sub(r'<.*?>', '', cell) # Remove HTML tags
96
- cell = re.sub(r'\[\[.*?\|(.*?)\]\]', r'\1', cell) # Convert links
97
- cell = re.sub(r'\[\[(.*?)\]\]', r'\1', cell) # Convert simple links
98
- cell = re.sub(r'\{\{.*?\}\}', '', cell) # Remove templates
99
- cell = re.sub(r'<small>.*?</small>', '', cell) # Remove small tags
100
- cell = re.sub(r'rowspan="\d+"', '', cell) # Remove rowspan
101
- cell = re.sub(r'colspan="\d+"', '', cell) # Remove colspan
102
- cell = re.sub(r'class=".*?"', '', cell) # Remove class attributes
103
- cell = re.sub(r'style=".*?"', '', cell) # Remove style attributes
104
- cell = re.sub(r'align=".*?"', '', cell) # Remove align attributes
105
- cell = re.sub(r'width=".*?"', '', cell) # Remove width attributes
106
- cell = re.sub(r'bgcolor=".*?"', '', cell) # Remove bgcolor attributes
107
- cell = re.sub(r'valign=".*?"', '', cell) # Remove valign attributes
108
- cell = re.sub(r'border=".*?"', '', cell) # Remove border attributes
109
- cell = re.sub(r'cellpadding=".*?"', '', cell) # Remove cellpadding attributes
110
- cell = re.sub(r'cellspacing=".*?"', '', cell) # Remove cellspacing attributes
111
- cell = re.sub(r'<ref.*?</ref>', '', cell) # Remove references
112
- cell = re.sub(r'<ref.*?/>', '', cell) # Remove empty references
113
- cell = re.sub(r'<br\s*/?>', ' ', cell) # Replace line breaks with spaces
114
- cell = re.sub(r'\s+', ' ', cell) # Normalize whitespace
 
 
 
 
 
 
115
  cells.append(cell)
116
-
117
  if cells:
118
  current_row.extend(cells)
119
-
120
  if current_row:
121
- formatted_rows.append('\t'.join(current_row))
122
-
123
  if formatted_rows:
124
- return '\n'.join(formatted_rows)
125
- return ''
126
 
127
  def extract_wikipedia_title(self, search_result: str) -> Optional[str]:
128
  """Extract Wikipedia page title from search result."""
129
  # Look for Wikipedia links in the format [Title - Wikipedia](url)
130
- wiki_match = re.search(r'\[([^\]]+)\s*-\s*Wikipedia\]\(https://en\.wikipedia\.org/wiki/[^)]+\)', search_result)
 
 
 
131
  if wiki_match:
132
  return wiki_match.group(1).strip()
133
  return None
134
 
135
  def forward(self, query: str) -> str:
136
  logger.info(f"Starting smart search for query: {query}")
137
-
138
  # First do a web search to find the Wikipedia page
139
  search_result = self.web_search_tool.forward(query)
140
  logger.info(f"Web search results: {search_result[:100]}...")
141
-
142
  # Extract Wikipedia page title from search results
143
  wiki_title = self.extract_wikipedia_title(search_result)
144
  if not wiki_title:
145
  return f"Could not find Wikipedia page in search results for '{query}'."
146
-
147
  # Get Wikipedia page content
148
  page_content = self.get_wikipedia_page(wiki_title)
149
  if not page_content:
150
  return f"Could not find Wikipedia page for '{wiki_title}'."
151
-
152
  # Format tables and content
153
  formatted_content = []
154
  current_section = []
155
  in_table = False
156
  table_content = []
157
-
158
- for line in page_content.split('\n'):
159
- if line.startswith('{|'):
160
  in_table = True
161
  table_content = [line]
162
- elif line.startswith('|}'):
163
  in_table = False
164
  table_content.append(line)
165
- formatted_table = self.format_wiki_table('\n'.join(table_content))
166
  if formatted_table:
167
  current_section.append(formatted_table)
168
  elif in_table:
@@ -171,49 +185,49 @@ class SmartSearchTool(Tool):
171
  if line.strip():
172
  current_section.append(line)
173
  elif current_section:
174
- formatted_content.append('\n'.join(current_section))
175
  current_section = []
176
-
177
  if current_section:
178
- formatted_content.append('\n'.join(current_section))
179
-
180
  # Clean and return the formatted content
181
- cleaned_content = self.clean_wiki_content('\n\n'.join(formatted_content))
182
  return f"Wikipedia content for '{wiki_title}':\n\n{cleaned_content}"
183
 
184
 
185
  def main(query: str) -> str:
186
  """
187
  Test function to run the SmartSearchTool directly.
188
-
189
  Args:
190
  query: The search query to test
191
-
192
  Returns:
193
  The search results
194
  """
195
  # Configure logging
196
  logging.basicConfig(
197
  level=logging.INFO,
198
- format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
199
  )
200
-
201
  # Create and run the tool
202
  tool = SmartSearchTool()
203
  result = tool.forward(query)
204
-
205
  # Print the result
206
  print("\nSearch Results:")
207
  print("-" * 80)
208
  print(result)
209
  print("-" * 80)
210
-
211
  return result
212
 
213
 
214
  if __name__ == "__main__":
215
  import sys
216
-
217
  if len(sys.argv) > 1:
218
  query = " ".join(sys.argv[1:])
219
  main(query)
 
1
  import logging
2
  import re
3
+ from typing import Optional
4
+
5
+ import requests
6
  from smolagents import Tool
7
  from smolagents.default_tools import DuckDuckGoSearchTool
 
 
8
 
9
  logger = logging.getLogger(__name__)
10
 
 
12
  class SmartSearchTool(Tool):
13
  name = "smart_search"
14
  description = """A smart search tool that searches Wikipedia for information."""
15
+ inputs = {
16
+ "query": {
17
+ "type": "string",
18
+ "description": "The search query to find information",
19
+ }
20
+ }
21
  output_type = "string"
22
 
23
  def __init__(self):
 
25
  self.web_search_tool = DuckDuckGoSearchTool(max_results=1)
26
  self.api_url = "https://en.wikipedia.org/w/api.php"
27
  self.headers = {
28
+ "User-Agent": "SmartSearchTool/1.0 (https://github.com/yourusername/yourproject; [email protected])"
29
  }
30
 
31
  def get_wikipedia_page(self, title: str) -> Optional[str]:
32
  """Get the raw wiki markup of a Wikipedia page."""
33
  try:
34
  params = {
35
+ "action": "query",
36
+ "prop": "revisions",
37
+ "rvprop": "content",
38
+ "rvslots": "main",
39
+ "format": "json",
40
+ "titles": title,
41
+ "redirects": 1,
42
  }
43
  response = requests.get(self.api_url, params=params, headers=self.headers)
44
  response.raise_for_status()
45
  data = response.json()
46
+
47
  # Extract page content
48
+ pages = data.get("query", {}).get("pages", {})
49
  for page_id, page_data in pages.items():
50
+ if "revisions" in page_data:
51
+ return page_data["revisions"][0]["slots"]["main"]["*"]
52
  return None
53
  except Exception as e:
54
  logger.error(f"Error getting Wikipedia page: {e}")
 
57
  def clean_wiki_content(self, content: str) -> str:
58
  """Clean Wikipedia content by removing markup and formatting."""
59
  # Remove citations
60
+ content = re.sub(r"\[\d+\]", "", content)
61
  # Remove edit links
62
+ content = re.sub(r"\[edit\]", "", content)
63
  # Remove file links
64
+ content = re.sub(r"\[\[File:.*?\]\]", "", content)
65
  # Convert links to just text
66
+ content = re.sub(r"\[\[(?:[^|\]]*\|)?([^\]]+)\]\]", r"\1", content)
67
  # Remove HTML comments
68
+ content = re.sub(r"<!--.*?-->", "", content, flags=re.DOTALL)
69
  # Remove templates
70
+ content = re.sub(r"\{\{.*?\}\}", "", content)
71
  # Remove small tags
72
+ content = re.sub(r"<small>.*?</small>", "", content)
73
  # Normalize whitespace
74
+ content = re.sub(r"\n\s*\n", "\n\n", content)
75
  return content.strip()
76
 
77
  def format_wiki_table(self, table_content: str) -> str:
78
  """Format a Wikipedia table into readable text."""
79
  # Split into rows
80
+ rows = table_content.strip().split("\n")
81
  formatted_rows = []
82
  current_row = []
83
+
84
  for row in rows:
85
  # Skip empty rows and table structure markers
86
+ if not row.strip() or row.startswith("|-") or row.startswith("|+"):
87
  if current_row:
88
+ formatted_rows.append("\t".join(current_row))
89
  current_row = []
90
  continue
91
+
92
  # Extract cells
93
  cells = []
94
  # Split the row into cells using | or ! as separators
95
+ cell_parts = re.split(r"[|!]", row)
96
  for cell in cell_parts[1:]: # Skip the first empty part
97
  # Clean up the cell content
98
  cell = cell.strip()
99
  # Remove any remaining markup
100
+ cell = re.sub(r"<.*?>", "", cell) # Remove HTML tags
101
+ cell = re.sub(r"\[\[.*?\|(.*?)\]\]", r"\1", cell) # Convert links
102
+ cell = re.sub(r"\[\[(.*?)\]\]", r"\1", cell) # Convert simple links
103
+ cell = re.sub(r"\{\{.*?\}\}", "", cell) # Remove templates
104
+ cell = re.sub(r"<small>.*?</small>", "", cell) # Remove small tags
105
+ cell = re.sub(r'rowspan="\d+"', "", cell) # Remove rowspan
106
+ cell = re.sub(r'colspan="\d+"', "", cell) # Remove colspan
107
+ cell = re.sub(r'class=".*?"', "", cell) # Remove class attributes
108
+ cell = re.sub(r'style=".*?"', "", cell) # Remove style attributes
109
+ cell = re.sub(r'align=".*?"', "", cell) # Remove align attributes
110
+ cell = re.sub(r'width=".*?"', "", cell) # Remove width attributes
111
+ cell = re.sub(r'bgcolor=".*?"', "", cell) # Remove bgcolor attributes
112
+ cell = re.sub(r'valign=".*?"', "", cell) # Remove valign attributes
113
+ cell = re.sub(r'border=".*?"', "", cell) # Remove border attributes
114
+ cell = re.sub(
115
+ r'cellpadding=".*?"', "", cell
116
+ ) # Remove cellpadding attributes
117
+ cell = re.sub(
118
+ r'cellspacing=".*?"', "", cell
119
+ ) # Remove cellspacing attributes
120
+ cell = re.sub(r"<ref.*?</ref>", "", cell) # Remove references
121
+ cell = re.sub(r"<ref.*?/>", "", cell) # Remove empty references
122
+ cell = re.sub(
123
+ r"<br\s*/?>", " ", cell
124
+ ) # Replace line breaks with spaces
125
+ cell = re.sub(r"\s+", " ", cell) # Normalize whitespace
126
  cells.append(cell)
127
+
128
  if cells:
129
  current_row.extend(cells)
130
+
131
  if current_row:
132
+ formatted_rows.append("\t".join(current_row))
133
+
134
  if formatted_rows:
135
+ return "\n".join(formatted_rows)
136
+ return ""
137
 
138
  def extract_wikipedia_title(self, search_result: str) -> Optional[str]:
139
  """Extract Wikipedia page title from search result."""
140
  # Look for Wikipedia links in the format [Title - Wikipedia](url)
141
+ wiki_match = re.search(
142
+ r"\[([^\]]+)\s*-\s*Wikipedia\]\(https://en\.wikipedia\.org/wiki/[^)]+\)",
143
+ search_result,
144
+ )
145
  if wiki_match:
146
  return wiki_match.group(1).strip()
147
  return None
148
 
149
  def forward(self, query: str) -> str:
150
  logger.info(f"Starting smart search for query: {query}")
151
+
152
  # First do a web search to find the Wikipedia page
153
  search_result = self.web_search_tool.forward(query)
154
  logger.info(f"Web search results: {search_result[:100]}...")
155
+
156
  # Extract Wikipedia page title from search results
157
  wiki_title = self.extract_wikipedia_title(search_result)
158
  if not wiki_title:
159
  return f"Could not find Wikipedia page in search results for '{query}'."
160
+
161
  # Get Wikipedia page content
162
  page_content = self.get_wikipedia_page(wiki_title)
163
  if not page_content:
164
  return f"Could not find Wikipedia page for '{wiki_title}'."
165
+
166
  # Format tables and content
167
  formatted_content = []
168
  current_section = []
169
  in_table = False
170
  table_content = []
171
+
172
+ for line in page_content.split("\n"):
173
+ if line.startswith("{|"):
174
  in_table = True
175
  table_content = [line]
176
+ elif line.startswith("|}"):
177
  in_table = False
178
  table_content.append(line)
179
+ formatted_table = self.format_wiki_table("\n".join(table_content))
180
  if formatted_table:
181
  current_section.append(formatted_table)
182
  elif in_table:
 
185
  if line.strip():
186
  current_section.append(line)
187
  elif current_section:
188
+ formatted_content.append("\n".join(current_section))
189
  current_section = []
190
+
191
  if current_section:
192
+ formatted_content.append("\n".join(current_section))
193
+
194
  # Clean and return the formatted content
195
+ cleaned_content = self.clean_wiki_content("\n\n".join(formatted_content))
196
  return f"Wikipedia content for '{wiki_title}':\n\n{cleaned_content}"
197
 
198
 
199
  def main(query: str) -> str:
200
  """
201
  Test function to run the SmartSearchTool directly.
202
+
203
  Args:
204
  query: The search query to test
205
+
206
  Returns:
207
  The search results
208
  """
209
  # Configure logging
210
  logging.basicConfig(
211
  level=logging.INFO,
212
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
213
  )
214
+
215
  # Create and run the tool
216
  tool = SmartSearchTool()
217
  result = tool.forward(query)
218
+
219
  # Print the result
220
  print("\nSearch Results:")
221
  print("-" * 80)
222
  print(result)
223
  print("-" * 80)
224
+
225
  return result
226
 
227
 
228
  if __name__ == "__main__":
229
  import sys
230
+
231
  if len(sys.argv) > 1:
232
  query = " ".join(sys.argv[1:])
233
  main(query)
utils.py DELETED
@@ -1,70 +0,0 @@
1
- import re
2
- from typing import Union
3
-
4
-
5
- def extract_final_answer(result: Union[str, dict]) -> str:
6
- """
7
- Extract the final answer from the agent's result, removing explanations.
8
- GAIA requires concise, properly formatted answers.
9
-
10
- Args:
11
- result: The full result from the agent, either a string or a dictionary
12
-
13
- Returns:
14
- Extracted final answer
15
- """
16
- # Handle dictionary input
17
- if isinstance(result, dict):
18
- # Check for final_answer_text first (from agent output)
19
- if "final_answer_text" in result:
20
- return str(result["final_answer_text"])
21
- # Fall back to final_answer key
22
- if "final_answer" in result:
23
- return str(result["final_answer"])
24
- return "No final answer found in result"
25
-
26
- # Handle string input (original logic)
27
- # First check if there's a specific final_answer marker
28
- if "final_answer(" in result:
29
- # Try to extract the answer from final_answer call
30
- pattern = r"final_answer\(['\"](.*?)['\"]\)"
31
- matches = re.findall(pattern, result)
32
- if matches:
33
- return matches[-1] # Return the last final_answer if multiple exist
34
-
35
- # If no final_answer marker, look for lines that might contain the answer
36
- lines = result.strip().split("\n")
37
-
38
- # Check for typical patterns indicating a final answer
39
- for line in reversed(lines): # Start from the end
40
- line = line.strip()
41
-
42
- # Skip empty lines
43
- if not line:
44
- continue
45
-
46
- # Look for patterns like "Answer:", "Final answer:", etc.
47
- if re.match(r"^(answer|final answer|result):?\s+", line.lower()):
48
- return line.split(":", 1)[1].strip()
49
-
50
- # Check for answers that are comma-separated lists (common in GAIA)
51
- if (
52
- "," in line
53
- and len(line.split(",")) > 1
54
- and not line.startswith("#")
55
- and not line.startswith("print(")
56
- ):
57
- # It might be a comma-separated list answer
58
- return line
59
-
60
- # If no clear answer pattern is found, return the last non-empty line
61
- # (often the answer is simply the last output)
62
- for line in reversed(lines):
63
- if (
64
- line.strip()
65
- and not line.strip().startswith("#")
66
- and not line.strip().startswith("print(")
67
- ):
68
- return line.strip()
69
-
70
- return "No answer found"