mjschock committed
Commit e4c7240 · unverified · 1 parent: 837e221

Enhance agent functionality in main_v2.py by adding WikipediaSearchTool and updating DuckDuckGoSearchTool and VisitWebpageTool parameters. Modify agent initialization to accommodate new tools and increase max results and output length. Update requirements.txt to include Wikipedia-API dependency. Refactor imports for better organization across agent modules.
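
The net effect on the manager agent in main_v2.py is summarized by the sketch below. It is a condensed illustration of the diff that follows, not part of the commit itself: the model id is a placeholder, and the prompt-template and telemetry setup that main_v2.py performs is elided.

from smolagents import CodeAgent, LiteLLMModel
from smolagents.default_tools import (
    DuckDuckGoSearchTool,
    VisitWebpageTool,
    WikipediaSearchTool,
)
from smolagents.monitoring import LogLevel

# Placeholder model; main_v2.py builds its own LiteLLMModel and prompt templates.
model = LiteLLMModel(model_id="openai/gpt-4o-mini")

agent = CodeAgent(
    model=model,
    tools=[
        DuckDuckGoSearchTool(max_results=3),       # raised from max_results=1
        VisitWebpageTool(max_output_length=1024),  # raised from max_output_length=256
        WikipediaSearchTool(),                     # newly added in this commit
    ],
    step_callbacks=None,
    verbosity_level=LogLevel.ERROR,
)

result = agent.run(
    task="When was the Eiffel Tower completed?",  # example task
    max_steps=5,  # raised from 3
    reset=True,
    stream=False,
)
print(result)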

agents/__init__.py CHANGED
@@ -1,9 +1,5 @@
-from .web_agent import create_web_agent
 from .data_agent import create_data_agent
 from .media_agent import create_media_agent
+from .web_agent import create_web_agent
 
-__all__ = [
-    'create_web_agent',
-    'create_data_agent',
-    'create_media_agent'
-]
+__all__ = ["create_web_agent", "create_data_agent", "create_media_agent"]

agents/data_agent/__init__.py CHANGED
@@ -1,3 +1,3 @@
 from .agent import create_data_agent
 
-__all__ = ['create_data_agent']
+__all__ = ["create_data_agent"]

agents/data_agent/agent.py CHANGED
@@ -1,8 +1,11 @@
 import importlib
+
 import yaml
 from smolagents import CodeAgent
+
 from tools import parse_csv, perform_calculation
 
+
 def create_data_agent(model):
     """
     Create a specialized agent for data analysis tasks.
@@ -30,4 +33,4 @@ def create_data_agent(model):
         prompt_templates=prompt_templates,
     )
 
-    return data_agent
+    return data_agent

agents/media_agent/__init__.py CHANGED
@@ -1,3 +1,3 @@
 from .agent import create_media_agent
 
-__all__ = ['create_media_agent']
+__all__ = ["create_media_agent"]

agents/media_agent/agent.py CHANGED
@@ -1,8 +1,11 @@
 import importlib
+
 import yaml
 from smolagents import CodeAgent
+
 from tools import analyze_image, read_pdf
 
+
 def create_media_agent(model):
     """
     Create a specialized agent for handling media (images, PDFs).
@@ -30,4 +33,4 @@ def create_media_agent(model):
         prompt_templates=prompt_templates,
    )
 
-    return media_agent
+    return media_agent

agents/web_agent/__init__.py CHANGED
@@ -1,3 +1,3 @@
 from .agent import create_web_agent
 
-__all__ = ['create_web_agent']
+__all__ = ["create_web_agent"]

agents/web_agent/agent.py CHANGED
@@ -1,7 +1,10 @@
 import importlib
+
 import yaml
 from smolagents import CodeAgent
-from tools import web_search, browse_webpage, find_in_page, extract_dates
+
+from tools import browse_webpage, extract_dates, find_in_page, web_search
+
 
 def create_web_agent(model):
     """
@@ -30,4 +33,4 @@ def create_web_agent(model):
         prompt_templates=prompt_templates,
     )
 
-    return web_agent
+    return web_agent

main.py CHANGED
@@ -14,12 +14,12 @@ from langgraph.graph import END, START, StateGraph
 from openinference.instrumentation.smolagents import SmolagentsInstrumentor
 from opentelemetry.sdk.trace.export import BatchSpanProcessor
 from phoenix.otel import register
+from prompts import MANAGER_SYSTEM_PROMPT
 from smolagents import CodeAgent, LiteLLMModel
 from smolagents.memory import ActionStep, FinalAnswerStep
 from smolagents.monitoring import LogLevel
 
 from agents import create_data_analysis_agent, create_media_agent, create_web_agent
-from prompts import MANAGER_SYSTEM_PROMPT
 from tools import perform_calculation, web_search
 from utils import extract_final_answer
 

main_v2.py CHANGED
@@ -11,7 +11,11 @@ from phoenix.otel import register
 
 # from smolagents import CodeAgent, LiteLLMModel, LiteLLMRouterModel
 from smolagents import CodeAgent, LiteLLMModel
-from smolagents.default_tools import DuckDuckGoSearchTool, VisitWebpageTool
+from smolagents.default_tools import (
+    DuckDuckGoSearchTool,
+    VisitWebpageTool,
+    WikipediaSearchTool,
+)
 from smolagents.monitoring import LogLevel
 
 from agents.data_agent.agent import create_data_agent
@@ -67,8 +71,9 @@ agent = CodeAgent(
     model=model,
     prompt_templates=prompt_templates,
     tools=[
-        DuckDuckGoSearchTool(max_results=1),
-        VisitWebpageTool(max_output_length=256),
+        DuckDuckGoSearchTool(max_results=3),
+        VisitWebpageTool(max_output_length=1024),
+        WikipediaSearchTool(),
     ],
     step_callbacks=None,
     verbosity_level=LogLevel.ERROR,
@@ -81,7 +86,7 @@ def main(task: str):
     result = agent.run(
         additional_args=None,
         images=None,
-        max_steps=3,
+        max_steps=5,
         reset=True,
         stream=False,
         task=task,

requirements.txt CHANGED
@@ -14,3 +14,4 @@ wikipedia-api>=0.8.1
 langchain>=0.1.0
 langchain-community>=0.0.10
 pandas>=2.0.0
+Wikipedia-API>=0.8.1

tools/__init__.py CHANGED
@@ -1,21 +1,23 @@
-from .wikipedia_rag import WikipediaRAGTool
-from .web_search import web_search
-from .browse_webpage import browse_webpage
 from .analyze_image import analyze_image
-from .read_pdf import read_pdf
-from .parse_csv import parse_csv
-from .find_in_page import find_in_page
+from .browse_webpage import browse_webpage
 from .extract_dates import extract_dates
+from .find_in_page import find_in_page
+from .parse_csv import parse_csv
 from .perform_calculation import perform_calculation
+from .read_pdf import read_pdf
+from .web_search import web_search
+from .wiki_search.tool import wiki
+from .wikipedia_rag import WikipediaRAGTool
 
 __all__ = [
-    'WikipediaRAGTool',
-    'web_search',
-    'browse_webpage',
-    'analyze_image',
-    'read_pdf',
-    'parse_csv',
-    'find_in_page',
-    'extract_dates',
-    'perform_calculation'
-]
+    "WikipediaRAGTool",
+    "web_search",
+    "browse_webpage",
+    "analyze_image",
+    "read_pdf",
+    "parse_csv",
+    "find_in_page",
+    "extract_dates",
+    "perform_calculation",
+    "wiki",
+]

tools/analyze_image/__init__.py CHANGED
@@ -1,3 +1,3 @@
 from .tool import analyze_image
 
-__all__ = ['analyze_image']
+__all__ = ["analyze_image"]

tools/analyze_image/tool.py CHANGED
@@ -1,9 +1,11 @@
 import io
-from typing import Dict, Any
+from typing import Any, Dict
+
 import requests
 from PIL import Image
 from smolagents import tool
 
+
 @tool
 def analyze_image(image_url: str) -> Dict[str, Any]:
     """
@@ -36,4 +38,4 @@ def analyze_image(image_url: str) -> Dict[str, Any]:
             "aspect_ratio": width / height,
         }
     except Exception as e:
-        return {"error": str(e)}
+        return {"error": str(e)}

tools/browse_webpage/__init__.py CHANGED
@@ -1,3 +1,3 @@
 from .tool import browse_webpage
 
-__all__ = ['browse_webpage']
+__all__ = ["browse_webpage"]

tools/browse_webpage/tool.py CHANGED
@@ -1,8 +1,10 @@
-from typing import Dict, Any
+from typing import Any, Dict
+
 import requests
 from bs4 import BeautifulSoup
 from smolagents import tool
 
+
 @tool
 def browse_webpage(url: str) -> Dict[str, Any]:
     """
@@ -40,4 +42,4 @@ def browse_webpage(url: str) -> Dict[str, Any]:
 
         return {"title": title, "content": text_content, "links": links}
     except Exception as e:
-        return {"error": str(e)}
+        return {"error": str(e)}

tools/extract_dates/__init__.py CHANGED
@@ -1,3 +1,3 @@
 from .tool import extract_dates
 
-__all__ = ['extract_dates']
+__all__ = ["extract_dates"]

tools/extract_dates/tool.py CHANGED
@@ -1,7 +1,9 @@
-from typing import List
 import re
+from typing import List
+
 from smolagents import tool
 
+
 @tool
 def extract_dates(text: str) -> List[str]:
     """
@@ -27,4 +29,4 @@ def extract_dates(text: str) -> List[str]:
         matches = re.findall(pattern, text, re.IGNORECASE)
         results.extend(matches)
 
-    return results
+    return results

tools/find_in_page/__init__.py CHANGED
@@ -1,3 +1,3 @@
 from .tool import find_in_page
 
-__all__ = ['find_in_page']
+__all__ = ["find_in_page"]

tools/find_in_page/tool.py CHANGED
@@ -1,7 +1,9 @@
-from typing import List, Dict, Any
 import re
+from typing import Any, Dict, List
+
 from smolagents import tool
 
+
 @tool
 def find_in_page(page_content: Dict[str, Any], query: str) -> List[str]:
     """
@@ -25,4 +27,4 @@ def find_in_page(page_content: Dict[str, Any], query: str) -> List[str]:
         if query.lower() in sentence.lower():
             results.append(sentence)
 
-    return results
+    return results

tools/parse_csv/__init__.py CHANGED
@@ -1,3 +1,3 @@
 from .tool import parse_csv
 
-__all__ = ['parse_csv']
+__all__ = ["parse_csv"]

tools/parse_csv/tool.py CHANGED
@@ -1,9 +1,11 @@
 import io
-from typing import Dict, Any
-import requests
+from typing import Any, Dict
+
 import pandas as pd
+import requests
 from smolagents import tool
 
+
 @tool
 def parse_csv(csv_url: str) -> Dict[str, Any]:
     """
@@ -35,4 +37,4 @@ def parse_csv(csv_url: str) -> Dict[str, Any]:
             "column_dtypes": {col: str(df[col].dtype) for col in columns},
         }
     except Exception as e:
-        return {"error": str(e)}
+        return {"error": str(e)}

tools/perform_calculation/__init__.py CHANGED
@@ -1,3 +1,3 @@
 from .tool import perform_calculation
 
-__all__ = ['perform_calculation']
+__all__ = ["perform_calculation"]

tools/perform_calculation/tool.py CHANGED
@@ -1,7 +1,9 @@
-from typing import Dict, Any
 import math
+from typing import Any, Dict
+
 from smolagents import tool
 
+
 @tool
 def perform_calculation(expression: str) -> Dict[str, Any]:
     """
@@ -35,4 +37,4 @@ def perform_calculation(expression: str) -> Dict[str, Any]:
 
         return {"result": result}
     except Exception as e:
-        return {"error": str(e)}
+        return {"error": str(e)}

tools/read_pdf/__init__.py CHANGED
@@ -1,3 +1,3 @@
 from .tool import read_pdf
 
-__all__ = ['read_pdf']
+__all__ = ["read_pdf"]

tools/read_pdf/tool.py CHANGED
@@ -1,6 +1,7 @@
 import requests
 from smolagents import tool
 
+
 @tool
 def read_pdf(pdf_url: str) -> str:
     """
@@ -21,4 +22,4 @@ def read_pdf(pdf_url: str) -> str:
         # such as PyPDF2, pdfplumber, or pdf2text
         return "PDF content extraction would happen here in a real implementation"
     except Exception as e:
-        return f"Error: {str(e)}"
+        return f"Error: {str(e)}"

tools/web_search/__init__.py CHANGED
@@ -1,3 +1,3 @@
 from .tool import web_search
 
-__all__ = ['web_search']
+__all__ = ["web_search"]

tools/web_search/tool.py CHANGED
@@ -1,6 +1,7 @@
 from smolagents import tool
 from smolagents.default_tools import DuckDuckGoSearchTool
 
+
 @tool
 def web_search(query: str) -> str:
     """

tools/wiki_search/tool.py ADDED
@@ -0,0 +1,40 @@
+import wikipediaapi
+from smolagents import tool
+
+
+@tool
+def wiki(query: str) -> str:
+    """
+    Search and retrieve information from Wikipedia using the Wikipedia-API library.
+
+    Args:
+        query: The search query to look up on Wikipedia
+
+    Returns:
+        A string containing the Wikipedia page summary and relevant sections
+    """
+    # Initialize Wikipedia API with a user agent
+    wiki_wiki = wikipediaapi.Wikipedia(
+        user_agent="HF-Agents-Course (https://huggingface.co/courses/agents-course)",
+        language="en",
+    )
+
+    # Search for the page
+    page = wiki_wiki.page(query)
+
+    if not page.exists():
+        return f"No Wikipedia page found for query: {query}"
+
+    # Get the page summary
+    result = f"Title: {page.title}\n\n"
+    result += f"Summary: {page.summary}\n\n"
+
+    # Add the first few sections if they exist
+    if page.sections:
+        result += "Sections:\n"
+        for section in page.sections[
+            :3
+        ]:  # Limit to first 3 sections to avoid too much text
+            result += f"\n{section.title}:\n{section.text[:500]}...\n"
+
+    return result

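A quick way to try the new tool in isolation is sketched below. It assumes that calling a smolagents @tool-decorated object invokes the underlying function, and note that wikipediaapi looks pages up by title rather than full-text search, so the query should be close to an actual article title.

from tools.wiki_search.tool import wiki  # also re-exported as `wiki` from the tools package

print(wiki(query="Alan Turing"))        # title, summary, and first few sections
print(wiki(query="no-such-page-xyz"))   # -> "No Wikipedia page found for query: ..."
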
tools/wikipedia_rag/__init__.py CHANGED
@@ -1,3 +1,3 @@
 from .tool import WikipediaRAGTool
 
-__all__ = ['WikipediaRAGTool']
+__all__ = ["WikipediaRAGTool"]

tools/wikipedia_rag/run.py CHANGED
@@ -1,22 +1,30 @@
-import os
 import argparse
+import os
+
 from dotenv import load_dotenv
 from tool import WikipediaRAGTool
 
+
 def main():
     # Load environment variables
     load_dotenv()
-
+
     # Set up argument parser
-    parser = argparse.ArgumentParser(description='Run Wikipedia RAG Tool')
-    parser.add_argument('--query', type=str, required=True, help='Search query for Wikipedia articles')
-    parser.add_argument('--dataset-path', type=str, default='wikipedia-structured-contents',
-                        help='Path to the Wikipedia dataset')
+    parser = argparse.ArgumentParser(description="Run Wikipedia RAG Tool")
+    parser.add_argument(
+        "--query", type=str, required=True, help="Search query for Wikipedia articles"
+    )
+    parser.add_argument(
+        "--dataset-path",
+        type=str,
+        default="wikipedia-structured-contents",
+        help="Path to the Wikipedia dataset",
+    )
     args = parser.parse_args()
-
+
     # Initialize the tool
     tool = WikipediaRAGTool(dataset_path=args.dataset_path)
-
+
     # Run the query
     print(f"\nQuery: {args.query}")
     print("-" * 50)
@@ -24,5 +32,6 @@ def main():
     print(f"Result: {result}")
     print("-" * 50)
 
+
 if __name__ == "__main__":
-    main()
+    main()

tools/wikipedia_rag/tool.py CHANGED
@@ -1,17 +1,19 @@
 import os
-import pandas as pd
 from typing import List, Optional
+
+import pandas as pd
 from langchain.docstore.document import Document
 from langchain_community.retrievers import BM25Retriever
 from smolagents import Tool
 
+
 class WikipediaRAGTool(Tool):
     name = "wikipedia_rag"
     description = "Retrieves relevant information from Wikipedia articles using RAG."
     inputs = {
         "query": {
             "type": "string",
-            "description": "The search query to find relevant Wikipedia content."
+            "description": "The search query to find relevant Wikipedia content.",
         }
     }
     output_type = "string"
@@ -27,24 +29,24 @@ class WikipediaRAGTool(Tool):
         try:
             # Load the dataset
             df = pd.read_csv(os.path.join(self.dataset_path, "wikipedia_articles.csv"))
-
+
             # Convert each article into a Document
             self.docs = [
                 Document(
                     page_content=f"Title: {row['title']}\n\nContent: {row['content']}",
                     metadata={
-                        "title": row['title'],
-                        "url": row['url'],
-                        "category": row.get('category', '')
-                    }
+                        "title": row["title"],
+                        "url": row["url"],
+                        "category": row.get("category", ""),
+                    },
                 )
                 for _, row in df.iterrows()
             ]
-
+
             # Initialize the retriever
             self.retriever = BM25Retriever.from_documents(self.docs)
             self.is_initialized = True
-
+
         except Exception as e:
             print(f"Error loading documents: {e}")
             raise
@@ -53,17 +55,17 @@ class WikipediaRAGTool(Tool):
         """Process the query and return relevant Wikipedia content."""
         if not self.is_initialized:
            self._load_documents()
-
+
        if not self.retriever:
            return "Error: Retriever not initialized properly."
-
+
        try:
            # Get relevant documents
            results = self.retriever.get_relevant_documents(query)
-
+
            if not results:
                return "No relevant Wikipedia articles found."
-
+
            # Format the results
            formatted_results = []
            for doc in results[:3]:  # Return top 3 most relevant results
@@ -74,8 +76,8 @@ class WikipediaRAGTool(Tool):
                    f"Category: {metadata['category']}\n"
                    f"Content: {doc.page_content[:500]}...\n"
                )
-
+
            return "\n\n".join(formatted_results)
-
+
        except Exception as e:
-            return f"Error retrieving information: {str(e)}"
+            return f"Error retrieving information: {str(e)}"
