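"""Smart search tool for smolagents.

Runs a web search to locate a topic's Wikipedia page, fetches the page's raw
wiki markup from the MediaWiki API, and returns cleaned plain text with
tables flattened into tab-separated rows.
"""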

import logging
import re
from typing import Optional

import requests
from smolagents import Tool
from smolagents.default_tools import DuckDuckGoSearchTool

logger = logging.getLogger(__name__)


class SmartSearchTool(Tool):
    name = "smart_search"
    description = """A smart search tool that searches Wikipedia for information."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query to find information",
        }
    }
    output_type = "string"

    def __init__(self):
        super().__init__()
        self.web_search_tool = DuckDuckGoSearchTool(max_results=1)
        self.api_url = "https://en.wikipedia.org/w/api.php"
        self.headers = {
            "User-Agent": "SmartSearchTool/1.0 (https://github.com/yourusername/yourproject; [email protected])"
        }

    def get_wikipedia_page(self, title: str) -> Optional[str]:
        """Get the raw wiki markup of a Wikipedia page."""
        try:
            params = {
                "action": "query",
                "prop": "revisions",
                "rvprop": "content",
                "rvslots": "main",
                "format": "json",
                "titles": title,
                "redirects": 1,
            }
            response = requests.get(
                self.api_url, params=params, headers=self.headers, timeout=30
            )
            response.raise_for_status()
            data = response.json()
            # Extract the markup of the first page that has revisions
            pages = data.get("query", {}).get("pages", {})
            for page_data in pages.values():
                if "revisions" in page_data:
                    return page_data["revisions"][0]["slots"]["main"]["*"]
            return None
        except Exception as e:
            logger.error(f"Error getting Wikipedia page: {e}")
            return None
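    # Shape of a typical API response (abbreviated), which the loop above
    # unwraps:
    #   {"query": {"pages": {"<pageid>": {"revisions":
    #       [{"slots": {"main": {"*": "<wiki markup>"}}}]}}}}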

    def clean_wiki_content(self, content: str) -> str:
        """Clean Wikipedia content by removing markup and formatting."""
        # Remove numeric citations like [1]
        content = re.sub(r"\[\d+\]", "", content)
        # Remove [edit] links
        content = re.sub(r"\[edit\]", "", content)
        # Remove file links
        content = re.sub(r"\[\[File:.*?\]\]", "", content)
        # Convert [[target|label]] and [[target]] links to plain text
        content = re.sub(r"\[\[(?:[^|\]]*\|)?([^\]]+)\]\]", r"\1", content)
        # Remove HTML comments
        content = re.sub(r"<!--.*?-->", "", content, flags=re.DOTALL)
        # Remove templates
        content = re.sub(r"\{\{.*?\}\}", "", content)
        # Remove small tags
        content = re.sub(r"<small>.*?</small>", "", content)
        # Collapse runs of blank lines
        content = re.sub(r"\n\s*\n", "\n\n", content)
        return content.strip()
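    # Example (illustrative): clean_wiki_content turns
    #   "She won a [[Latin Grammy Award|Latin Grammy]] in 2000.{{citation needed}}"
    # into "She won a Latin Grammy in 2000."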

    def format_wiki_table(self, table_content: str) -> str:
        """Format a Wikipedia table into readable tab-separated text."""
        rows = table_content.strip().split("\n")
        formatted_rows = []
        current_row = []
        for row in rows:
            # Skip blank lines and table structure markers: table start ({|),
            # table end (|}), row separators (|-), and captions (|+)
            if (
                not row.strip()
                or row.startswith("{|")
                or row.startswith("|}")
                or row.startswith("|-")
                or row.startswith("|+")
            ):
                if current_row:
                    formatted_rows.append("\t".join(current_row))
                    current_row = []
                continue
            # Split the row into cells on | or ! separators
            cells = []
            cell_parts = re.split(r"[|!]", row)
            for cell in cell_parts[1:]:  # Skip the first empty part
                cell = cell.strip()
                # Remove references before generic tag stripping, otherwise the
                # <ref> tags are gone and the reference text survives
                cell = re.sub(r"<ref.*?</ref>", "", cell)
                cell = re.sub(r"<ref.*?/>", "", cell)
                # Replace line breaks with spaces, drop <small> blocks,
                # then strip any remaining HTML tags
                cell = re.sub(r"<br\s*/?>", " ", cell)
                cell = re.sub(r"<small>.*?</small>", "", cell)
                cell = re.sub(r"<.*?>", "", cell)
                # Convert wiki links to plain text and remove templates
                cell = re.sub(r"\[\[.*?\|(.*?)\]\]", r"\1", cell)
                cell = re.sub(r"\[\[(.*?)\]\]", r"\1", cell)
                cell = re.sub(r"\{\{.*?\}\}", "", cell)
                # Strip leftover table/cell attributes in one pass
                cell = re.sub(
                    r'(?:rowspan|colspan|class|style|align|width|bgcolor'
                    r'|valign|border|cellpadding|cellspacing)=".*?"',
                    "",
                    cell,
                )
                # Normalize whitespace
                cell = re.sub(r"\s+", " ", cell).strip()
                cells.append(cell)
            if cells:
                current_row.extend(cells)
        if current_row:
            formatted_rows.append("\t".join(current_row))
        return "\n".join(formatted_rows)

    def extract_wikipedia_title(self, search_result: str) -> Optional[str]:
        """Extract a Wikipedia page title from a search result."""
        # Look for Wikipedia links in the format [Title - Wikipedia](url)
        wiki_match = re.search(
            r"\[([^\]]+)\s*-\s*Wikipedia\]\(https://en\.wikipedia\.org/wiki/[^)]+\)",
            search_result,
        )
        if wiki_match:
            return wiki_match.group(1).strip()
        return None
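    # Example (illustrative): a search result containing
    #   [Mercedes Sosa - Wikipedia](https://en.wikipedia.org/wiki/Mercedes_Sosa)
    # yields the title "Mercedes Sosa".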

    def forward(self, query: str) -> str:
        """Run a web search, fetch the matching Wikipedia page, and return it cleaned."""
        logger.info(f"Starting smart search for query: {query}")
        # First do a web search to find the Wikipedia page
        search_result = self.web_search_tool.forward(query)
        logger.info(f"Web search results: {search_result[:100]}...")
        # Extract the Wikipedia page title from the search results
        wiki_title = self.extract_wikipedia_title(search_result)
        if not wiki_title:
            return f"Could not find Wikipedia page in search results for '{query}'."
        # Get the Wikipedia page content
        page_content = self.get_wikipedia_page(wiki_title)
        if not page_content:
            return f"Could not find Wikipedia page for '{wiki_title}'."
        # Walk the markup line by line, formatting each table as it closes
        formatted_content = []
        current_section = []
        in_table = False
        table_content = []
        for line in page_content.split("\n"):
            if line.startswith("{|"):
                in_table = True
                table_content = [line]
            elif line.startswith("|}"):
                in_table = False
                table_content.append(line)
                formatted_table = self.format_wiki_table("\n".join(table_content))
                if formatted_table:
                    current_section.append(formatted_table)
            elif in_table:
                table_content.append(line)
            else:
                if line.strip():
                    current_section.append(line)
                elif current_section:
                    formatted_content.append("\n".join(current_section))
                    current_section = []
        if current_section:
            formatted_content.append("\n".join(current_section))
        # Clean and return the formatted content
        cleaned_content = self.clean_wiki_content("\n\n".join(formatted_content))
        return f"Wikipedia content for '{wiki_title}':\n\n{cleaned_content}"


def main(query: str) -> str:
    """
    Test function to run the SmartSearchTool directly.

    Args:
        query: The search query to test

    Returns:
        The search results
    """
    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    # Create and run the tool
    tool = SmartSearchTool()
    result = tool.forward(query)
    # Print the result
    print("\nSearch Results:")
    print("-" * 80)
    print(result)
    print("-" * 80)
    return result


if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        query = " ".join(sys.argv[1:])
        main(query)
    else:
        print("Usage: python tool.py <search query>")
        print("Example: python tool.py 'Mercedes Sosa discography'")