Enhance agent functionality in main_v2.py by adding WikipediaSearchTool and updating DuckDuckGoSearchTool and VisitWebpageTool parameters. Modify agent initialization to accommodate new tools and increase max results and output length. Update requirements.txt to include Wikipedia-API dependency. Refactor imports for better organization across agent modules.
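As context for the change, here is a minimal sketch of the tool wiring this commit ends up with in main_v2.py, condensed from the diff below; the model id is a placeholder, and the prompt templates, managed sub-agents, and tracing setup are omitted:

from smolagents import CodeAgent, LiteLLMModel
from smolagents.default_tools import (
    DuckDuckGoSearchTool,
    VisitWebpageTool,
    WikipediaSearchTool,
)

# Placeholder model id; the real script configures its own LiteLLM model,
# prompt templates, and sub-agents before building the manager agent.
model = LiteLLMModel(model_id="openrouter/some-model")
agent = CodeAgent(
    model=model,
    tools=[
        DuckDuckGoSearchTool(max_results=3),       # raised max results
        VisitWebpageTool(max_output_length=1024),  # raised output length
        WikipediaSearchTool(),                     # newly added Wikipedia tool
    ],
)
result = agent.run(task="example question", max_steps=5)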
- agents/__init__.py +2 -6
- agents/data_agent/__init__.py +1 -1
- agents/data_agent/agent.py +4 -1
- agents/media_agent/__init__.py +1 -1
- agents/media_agent/agent.py +4 -1
- agents/web_agent/__init__.py +1 -1
- agents/web_agent/agent.py +5 -2
- main.py +1 -1
- main_v2.py +9 -4
- requirements.txt +1 -0
- tools/__init__.py +18 -16
- tools/analyze_image/__init__.py +1 -1
- tools/analyze_image/tool.py +4 -2
- tools/browse_webpage/__init__.py +1 -1
- tools/browse_webpage/tool.py +4 -2
- tools/extract_dates/__init__.py +1 -1
- tools/extract_dates/tool.py +4 -2
- tools/find_in_page/__init__.py +1 -1
- tools/find_in_page/tool.py +4 -2
- tools/parse_csv/__init__.py +1 -1
- tools/parse_csv/tool.py +5 -3
- tools/perform_calculation/__init__.py +1 -1
- tools/perform_calculation/tool.py +4 -2
- tools/read_pdf/__init__.py +1 -1
- tools/read_pdf/tool.py +2 -1
- tools/web_search/__init__.py +1 -1
- tools/web_search/tool.py +1 -0
- tools/wiki_search/tool.py +40 -0
- tools/wikipedia_rag/__init__.py +1 -1
- tools/wikipedia_rag/run.py +18 -9
- tools/wikipedia_rag/tool.py +18 -16
agents/__init__.py
CHANGED
@@ -1,9 +1,5 @@
-from .web_agent import create_web_agent
 from .data_agent import create_data_agent
 from .media_agent import create_media_agent
+from .web_agent import create_web_agent
 
-__all__ = [
-    'create_web_agent',
-    'create_data_agent',
-    'create_media_agent'
-]
+__all__ = ["create_web_agent", "create_data_agent", "create_media_agent"]
agents/data_agent/__init__.py
CHANGED
@@ -1,3 +1,3 @@
 from .agent import create_data_agent
 
-__all__ = [
+__all__ = ["create_data_agent"]
agents/data_agent/agent.py
CHANGED
@@ -1,8 +1,11 @@
 import importlib
+
 import yaml
 from smolagents import CodeAgent
+
 from tools import parse_csv, perform_calculation
 
+
 def create_data_agent(model):
     """
     Create a specialized agent for data analysis tasks.
@@ -30,4 +33,4 @@ def create_data_agent(model):
         prompt_templates=prompt_templates,
     )
 
-    return data_agent
+    return data_agent
agents/media_agent/__init__.py
CHANGED
@@ -1,3 +1,3 @@
 from .agent import create_media_agent
 
-__all__ = [
+__all__ = ["create_media_agent"]
agents/media_agent/agent.py
CHANGED
@@ -1,8 +1,11 @@
 import importlib
+
 import yaml
 from smolagents import CodeAgent
+
 from tools import analyze_image, read_pdf
 
+
 def create_media_agent(model):
     """
     Create a specialized agent for handling media (images, PDFs).
@@ -30,4 +33,4 @@ def create_media_agent(model):
         prompt_templates=prompt_templates,
     )
 
-    return media_agent
+    return media_agent
agents/web_agent/__init__.py
CHANGED
@@ -1,3 +1,3 @@
 from .agent import create_web_agent
 
-__all__ = [
+__all__ = ["create_web_agent"]
agents/web_agent/agent.py
CHANGED
@@ -1,7 +1,10 @@
 import importlib
+
 import yaml
 from smolagents import CodeAgent
-
+
+from tools import browse_webpage, extract_dates, find_in_page, web_search
+
 
 def create_web_agent(model):
     """
@@ -30,4 +33,4 @@ def create_web_agent(model):
         prompt_templates=prompt_templates,
     )
 
-    return web_agent
+    return web_agent
main.py
CHANGED
@@ -14,12 +14,12 @@ from langgraph.graph import END, START, StateGraph
 from openinference.instrumentation.smolagents import SmolagentsInstrumentor
 from opentelemetry.sdk.trace.export import BatchSpanProcessor
 from phoenix.otel import register
+from prompts import MANAGER_SYSTEM_PROMPT
 from smolagents import CodeAgent, LiteLLMModel
 from smolagents.memory import ActionStep, FinalAnswerStep
 from smolagents.monitoring import LogLevel
 
 from agents import create_data_analysis_agent, create_media_agent, create_web_agent
-from prompts import MANAGER_SYSTEM_PROMPT
 from tools import perform_calculation, web_search
 from utils import extract_final_answer
 
main_v2.py
CHANGED
@@ -11,7 +11,11 @@ from phoenix.otel import register
 
 # from smolagents import CodeAgent, LiteLLMModel, LiteLLMRouterModel
 from smolagents import CodeAgent, LiteLLMModel
-from smolagents.default_tools import
+from smolagents.default_tools import (
+    DuckDuckGoSearchTool,
+    VisitWebpageTool,
+    WikipediaSearchTool,
+)
 from smolagents.monitoring import LogLevel
 
 from agents.data_agent.agent import create_data_agent
@@ -67,8 +71,9 @@ agent = CodeAgent(
     model=model,
     prompt_templates=prompt_templates,
     tools=[
-        DuckDuckGoSearchTool(max_results=
-        VisitWebpageTool(max_output_length=
+        DuckDuckGoSearchTool(max_results=3),
+        VisitWebpageTool(max_output_length=1024),
+        WikipediaSearchTool(),
     ],
     step_callbacks=None,
     verbosity_level=LogLevel.ERROR,
@@ -81,7 +86,7 @@ def main(task: str):
     result = agent.run(
         additional_args=None,
         images=None,
-        max_steps=
+        max_steps=5,
        reset=True,
         stream=False,
         task=task,
requirements.txt
CHANGED
@@ -14,3 +14,4 @@ wikipedia-api>=0.8.1
 langchain>=0.1.0
 langchain-community>=0.0.10
 pandas>=2.0.0
+Wikipedia-API>=0.8.1
tools/__init__.py
CHANGED
@@ -1,21 +1,23 @@
-from .wikipedia_rag import WikipediaRAGTool
-from .web_search import web_search
-from .browse_webpage import browse_webpage
 from .analyze_image import analyze_image
-from .
-from .parse_csv import parse_csv
-from .find_in_page import find_in_page
+from .browse_webpage import browse_webpage
 from .extract_dates import extract_dates
+from .find_in_page import find_in_page
+from .parse_csv import parse_csv
 from .perform_calculation import perform_calculation
+from .read_pdf import read_pdf
+from .web_search import web_search
+from .wiki_search.tool import wiki
+from .wikipedia_rag import WikipediaRAGTool
 
 __all__ = [
-
-
-
-
-
-
-
-
-
-
+    "WikipediaRAGTool",
+    "web_search",
+    "browse_webpage",
+    "analyze_image",
+    "read_pdf",
+    "parse_csv",
+    "find_in_page",
+    "extract_dates",
+    "perform_calculation",
+    "wiki",
+]
tools/analyze_image/__init__.py
CHANGED
@@ -1,3 +1,3 @@
 from .tool import analyze_image
 
-__all__ = [
+__all__ = ["analyze_image"]
tools/analyze_image/tool.py
CHANGED
@@ -1,9 +1,11 @@
 import io
-from typing import
+from typing import Any, Dict
+
 import requests
 from PIL import Image
 from smolagents import tool
 
+
 @tool
 def analyze_image(image_url: str) -> Dict[str, Any]:
     """
@@ -36,4 +38,4 @@ def analyze_image(image_url: str) -> Dict[str, Any]:
             "aspect_ratio": width / height,
         }
     except Exception as e:
-        return {"error": str(e)}
+        return {"error": str(e)}
tools/browse_webpage/__init__.py
CHANGED
@@ -1,3 +1,3 @@
 from .tool import browse_webpage
 
-__all__ = [
+__all__ = ["browse_webpage"]
tools/browse_webpage/tool.py
CHANGED
@@ -1,8 +1,10 @@
-from typing import
+from typing import Any, Dict
+
 import requests
 from bs4 import BeautifulSoup
 from smolagents import tool
 
+
 @tool
 def browse_webpage(url: str) -> Dict[str, Any]:
     """
@@ -40,4 +42,4 @@ def browse_webpage(url: str) -> Dict[str, Any]:
 
         return {"title": title, "content": text_content, "links": links}
     except Exception as e:
-        return {"error": str(e)}
+        return {"error": str(e)}
tools/extract_dates/__init__.py
CHANGED
@@ -1,3 +1,3 @@
 from .tool import extract_dates
 
-__all__ = [
+__all__ = ["extract_dates"]
tools/extract_dates/tool.py
CHANGED
@@ -1,7 +1,9 @@
-from typing import List
 import re
+from typing import List
+
 from smolagents import tool
 
+
 @tool
 def extract_dates(text: str) -> List[str]:
     """
@@ -27,4 +29,4 @@ def extract_dates(text: str) -> List[str]:
         matches = re.findall(pattern, text, re.IGNORECASE)
         results.extend(matches)
 
-    return results
+    return results
tools/find_in_page/__init__.py
CHANGED
@@ -1,3 +1,3 @@
 from .tool import find_in_page
 
-__all__ = [
+__all__ = ["find_in_page"]
tools/find_in_page/tool.py
CHANGED
@@ -1,7 +1,9 @@
-from typing import List, Dict, Any
 import re
+from typing import Any, Dict, List
+
 from smolagents import tool
 
+
 @tool
 def find_in_page(page_content: Dict[str, Any], query: str) -> List[str]:
     """
@@ -25,4 +27,4 @@ def find_in_page(page_content: Dict[str, Any], query: str) -> List[str]:
         if query.lower() in sentence.lower():
             results.append(sentence)
 
-    return results
+    return results
tools/parse_csv/__init__.py
CHANGED
@@ -1,3 +1,3 @@
 from .tool import parse_csv
 
-__all__ = [
+__all__ = ["parse_csv"]
tools/parse_csv/tool.py
CHANGED
@@ -1,9 +1,11 @@
 import io
-from typing import
-
+from typing import Any, Dict
+
 import pandas as pd
+import requests
 from smolagents import tool
 
+
 @tool
 def parse_csv(csv_url: str) -> Dict[str, Any]:
     """
@@ -35,4 +37,4 @@ def parse_csv(csv_url: str) -> Dict[str, Any]:
             "column_dtypes": {col: str(df[col].dtype) for col in columns},
         }
     except Exception as e:
-        return {"error": str(e)}
+        return {"error": str(e)}
tools/perform_calculation/__init__.py
CHANGED
@@ -1,3 +1,3 @@
 from .tool import perform_calculation
 
-__all__ = [
+__all__ = ["perform_calculation"]
tools/perform_calculation/tool.py
CHANGED
@@ -1,7 +1,9 @@
-from typing import Dict, Any
 import math
+from typing import Any, Dict
+
 from smolagents import tool
 
+
 @tool
 def perform_calculation(expression: str) -> Dict[str, Any]:
     """
@@ -35,4 +37,4 @@ def perform_calculation(expression: str) -> Dict[str, Any]:
 
         return {"result": result}
     except Exception as e:
-        return {"error": str(e)}
+        return {"error": str(e)}
tools/read_pdf/__init__.py
CHANGED
@@ -1,3 +1,3 @@
 from .tool import read_pdf
 
-__all__ = [
+__all__ = ["read_pdf"]
tools/read_pdf/tool.py
CHANGED
@@ -1,6 +1,7 @@
 import requests
 from smolagents import tool
 
+
 @tool
 def read_pdf(pdf_url: str) -> str:
     """
@@ -21,4 +22,4 @@ def read_pdf(pdf_url: str) -> str:
         # such as PyPDF2, pdfplumber, or pdf2text
         return "PDF content extraction would happen here in a real implementation"
     except Exception as e:
-        return f"Error: {str(e)}"
+        return f"Error: {str(e)}"
tools/web_search/__init__.py
CHANGED
@@ -1,3 +1,3 @@
 from .tool import web_search
 
-__all__ = [
+__all__ = ["web_search"]
tools/web_search/tool.py
CHANGED
@@ -1,6 +1,7 @@
 from smolagents import tool
 from smolagents.default_tools import DuckDuckGoSearchTool
 
+
 @tool
 def web_search(query: str) -> str:
     """
tools/wiki_search/tool.py
ADDED
@@ -0,0 +1,40 @@
+import wikipediaapi
+from smolagents import tool
+
+
+@tool
+def wiki(query: str) -> str:
+    """
+    Search and retrieve information from Wikipedia using the Wikipedia-API library.
+
+    Args:
+        query: The search query to look up on Wikipedia
+
+    Returns:
+        A string containing the Wikipedia page summary and relevant sections
+    """
+    # Initialize Wikipedia API with a user agent
+    wiki_wiki = wikipediaapi.Wikipedia(
+        user_agent="HF-Agents-Course (https://huggingface.co/courses/agents-course)",
+        language="en",
+    )
+
+    # Search for the page
+    page = wiki_wiki.page(query)
+
+    if not page.exists():
+        return f"No Wikipedia page found for query: {query}"
+
+    # Get the page summary
+    result = f"Title: {page.title}\n\n"
+    result += f"Summary: {page.summary}\n\n"
+
+    # Add the first few sections if they exist
+    if page.sections:
+        result += "Sections:\n"
+        for section in page.sections[
+            :3
+        ]:  # Limit to first 3 sections to avoid too much text
+            result += f"\n{section.title}:\n{section.text[:500]}...\n"
+
+    return result
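A quick, hedged usage sketch for the newly added wiki tool: the object produced by the @tool decorator should be directly callable, so it can be smoke-tested outside an agent (the query string is only an example):

from tools.wiki_search.tool import wiki

# Direct call for a smoke test; inside an agent the tool is invoked by name.
print(wiki("Python (programming language)"))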
tools/wikipedia_rag/__init__.py
CHANGED
@@ -1,3 +1,3 @@
 from .tool import WikipediaRAGTool
 
-__all__ = [
+__all__ = ["WikipediaRAGTool"]
tools/wikipedia_rag/run.py
CHANGED
@@ -1,22 +1,30 @@
-import os
 import argparse
+import os
+
 from dotenv import load_dotenv
 from tool import WikipediaRAGTool
 
+
 def main():
     # Load environment variables
     load_dotenv()
-
+
     # Set up argument parser
-    parser = argparse.ArgumentParser(description=
-    parser.add_argument(
-
-
+    parser = argparse.ArgumentParser(description="Run Wikipedia RAG Tool")
+    parser.add_argument(
+        "--query", type=str, required=True, help="Search query for Wikipedia articles"
+    )
+    parser.add_argument(
+        "--dataset-path",
+        type=str,
+        default="wikipedia-structured-contents",
+        help="Path to the Wikipedia dataset",
+    )
     args = parser.parse_args()
-
+
     # Initialize the tool
     tool = WikipediaRAGTool(dataset_path=args.dataset_path)
-
+
     # Run the query
     print(f"\nQuery: {args.query}")
     print("-" * 50)
@@ -24,5 +32,6 @@ def main():
     print(f"Result: {result}")
     print("-" * 50)
 
+
 if __name__ == "__main__":
-    main()
+    main()
tools/wikipedia_rag/tool.py
CHANGED
@@ -1,17 +1,19 @@
 import os
-import pandas as pd
 from typing import List, Optional
+
+import pandas as pd
 from langchain.docstore.document import Document
 from langchain_community.retrievers import BM25Retriever
 from smolagents import Tool
 
+
 class WikipediaRAGTool(Tool):
     name = "wikipedia_rag"
     description = "Retrieves relevant information from Wikipedia articles using RAG."
     inputs = {
         "query": {
             "type": "string",
-            "description": "The search query to find relevant Wikipedia content."
+            "description": "The search query to find relevant Wikipedia content.",
         }
     }
     output_type = "string"
@@ -27,24 +29,24 @@ class WikipediaRAGTool(Tool):
         try:
             # Load the dataset
            df = pd.read_csv(os.path.join(self.dataset_path, "wikipedia_articles.csv"))
-
+
             # Convert each article into a Document
             self.docs = [
                 Document(
                     page_content=f"Title: {row['title']}\n\nContent: {row['content']}",
                     metadata={
-                        "title": row[
-                        "url": row[
-                        "category": row.get(
-                    }
+                        "title": row["title"],
+                        "url": row["url"],
+                        "category": row.get("category", ""),
+                    },
                 )
                 for _, row in df.iterrows()
             ]
-
+
             # Initialize the retriever
             self.retriever = BM25Retriever.from_documents(self.docs)
             self.is_initialized = True
-
+
         except Exception as e:
             print(f"Error loading documents: {e}")
             raise
@@ -53,17 +55,17 @@ class WikipediaRAGTool(Tool):
         """Process the query and return relevant Wikipedia content."""
         if not self.is_initialized:
             self._load_documents()
-
+
         if not self.retriever:
             return "Error: Retriever not initialized properly."
-
+
         try:
             # Get relevant documents
             results = self.retriever.get_relevant_documents(query)
-
+
             if not results:
                 return "No relevant Wikipedia articles found."
-
+
             # Format the results
             formatted_results = []
             for doc in results[:3]:  # Return top 3 most relevant results
@@ -74,8 +76,8 @@ class WikipediaRAGTool(Tool):
                     f"Category: {metadata['category']}\n"
                     f"Content: {doc.page_content[:500]}...\n"
                 )
-
+
             return "\n\n".join(formatted_results)
-
+
         except Exception as e:
-            return f"Error retrieving information: {str(e)}"
+            return f"Error retrieving information: {str(e)}"
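Finally, a hedged sketch of exercising WikipediaRAGTool on its own, mirroring run.py; the dataset path is run.py's default, and calling forward() directly is an assumption since the invoking line does not appear in this diff:

from tools.wikipedia_rag.tool import WikipediaRAGTool

# Assumes wikipedia_articles.csv exists under the default dataset path used by run.py.
tool = WikipediaRAGTool(dataset_path="wikipedia-structured-contents")
print(tool.forward("history of the printing press"))  # direct forward() call is assumed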