Modifies the orchestrator; adds the code_exec tool and openai_chat (chat will be added later)
Browse files- orchestrator_agent.py +50 -40
orchestrator_agent.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
|
2 |
-
from datetime import datetime
|
3 |
import os
|
4 |
from typing import Dict, List, Any
|
5 |
from pydantic_ai import Agent
|
@@ -8,12 +7,10 @@ from pydantic_ai.providers.google_gla import GoogleGLAProvider
|
|
8 |
from pydantic_ai import RunContext
|
9 |
from pydantic import BaseModel
|
10 |
from google.api_core.exceptions import ResourceExhausted # Import the exception for quota exhaustion
|
11 |
-
from code_exec_service import run_analysis
|
12 |
from csv_service import get_csv_basic_info
|
13 |
from orchestrator_functions import csv_chart, csv_chat
|
14 |
from dotenv import load_dotenv
|
15 |
|
16 |
-
|
17 |
load_dotenv()
|
18 |
|
19 |
|
@@ -128,14 +125,14 @@ def create_agent(csv_url: str, api_key: str, conversation_history: List) -> Agen
|
|
128 |
- Highlight limitations/caveats
|
129 |
|
130 |
5. TOOL USAGE:
|
131 |
-
-
|
132 |
-
-
|
133 |
-
- Chart Generation Tool
|
134 |
|
135 |
## Current Context:
|
136 |
- Working with CSV_URL: {csv_url}
|
137 |
- Dataset overview: {csv_metadata}
|
138 |
- Your conversation history: {conversation_history}
|
|
|
139 |
|
140 |
## Response Template:
|
141 |
1. Confirm understanding of request
|
@@ -144,44 +141,58 @@ def create_agent(csv_url: str, api_key: str, conversation_history: List) -> Agen
|
|
144 |
4. Provide interpretation
|
145 |
5. Offer next-step suggestions
|
146 |
"""
|
147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
model=initialize_model(api_key),
|
149 |
deps_type=str,
|
150 |
tools=[generate_csv_answer, generate_chart],
|
151 |
system_prompt=system_prompt
|
152 |
)
|
153 |
-
|
154 |
-
@gemini_csv_orchestrator_agent.tool_plain
|
155 |
-
def python_code_executor(analysis_code: str) -> dict:
|
156 |
-
"""Run the LLM-supplied analysis code through run_analysis and return the outcome as a string.
|
157 |
-
|
158 |
-
Args:
|
159 |
-
analysis_code (str): Python source code to execute, passed unchanged to run_analysis.
|
160 |
-
Ex:
|
161 |
-
df = pd.read_csv({csv_url})
|
162 |
-
len(df)
|
163 |
-
|
164 |
-
Returns:
|
165 |
-
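str: The captured output on success, or the error message on failure (the "-> dict" annotation in the signature does not match what is actually returned).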
|
166 |
-
"""
|
167 |
-
|
168 |
-
print(f'LLM Passed a code: {analysis_code}')
|
169 |
-
result = run_analysis(analysis_code)
|
170 |
-
if result['success']:
|
171 |
-
print("Execution successful")
|
172 |
-
print("Execution time:", result['execution_time'], "seconds")
|
173 |
-
print("Output:", result['output'].strip())
|
174 |
-
print("Result:", result['result'])
|
175 |
-
print("Variables:", list(result['variables'].keys()))
|
176 |
-
# convert the result to a string
|
177 |
-
result_str = str(result['output'])
|
178 |
-
return result_str
|
179 |
-
else:
|
180 |
-
print("Execution failed")
|
181 |
-
print("Error:", result['error'])
|
182 |
-
error_str = str(result['error'])
|
183 |
-
return error_str
|
184 |
-
return gemini_csv_orchestrator_agent
|
185 |
|
186 |
def csv_orchestrator_chat(csv_url: str, user_question: str, conversation_history: List) -> str:
|
187 |
print("CSV URL:", csv_url)
|
@@ -205,4 +216,3 @@ def csv_orchestrator_chat(csv_url: str, user_question: str, conversation_history
|
|
205 |
# If all keys are exhausted or fail
|
206 |
print("All API keys have been exhausted or failed.")
|
207 |
return None
|
208 |
-
|
|
|
1 |
|
|
|
2 |
import os
|
3 |
from typing import Dict, List, Any
|
4 |
from pydantic_ai import Agent
|
|
|
7 |
from pydantic_ai import RunContext
|
8 |
from pydantic import BaseModel
|
9 |
from google.api_core.exceptions import ResourceExhausted # Import the exception for quota exhaustion
|
|
|
10 |
from csv_service import get_csv_basic_info
|
11 |
from orchestrator_functions import csv_chart, csv_chat
|
12 |
from dotenv import load_dotenv
|
13 |
|
|
|
14 |
load_dotenv()
|
15 |
|
16 |
|
|
|
125 |
- Highlight limitations/caveats
|
126 |
|
127 |
5. TOOL USAGE:
|
128 |
+
- Can process statistical operations
|
129 |
+
- Supports visualization libraries
|
|
|
130 |
|
131 |
## Current Context:
|
132 |
- Working with CSV_URL: {csv_url}
|
133 |
- Dataset overview: {csv_metadata}
|
134 |
- Your conversation history: {conversation_history}
|
135 |
+
- Output format: Markdown compatible
|
136 |
|
137 |
## Response Template:
|
138 |
1. Confirm understanding of request
|
|
|
141 |
4. Provide interpretation
|
142 |
5. Offer next-step suggestions
|
143 |
"""
|
144 |
+
|
145 |
+
# system_prompt = (
|
146 |
+
# "You are a data analyst. "
|
147 |
+
# "You have all the tools you need to answer any question. "
|
148 |
+
# "If the user asks for multiple answers or charts, break the question into several well-defined questions. "
|
149 |
+
# "Pass the CSV URL or file path along with the questions to the tools to generate the answer. "
|
150 |
+
# "The tools are actually LLMs with Python code execution capabilities. "
|
151 |
+
# "Modify the query if needed to make it simpler for the LLM to understand. "
|
152 |
+
# "Answer in a friendly and helpful manner. "
|
153 |
+
# "**Format images** in Markdown: `![Description](image_url)`. "
|
154 |
+
# f"Your CSV URL is {csv_url}. "
|
155 |
+
# f"Your CSV metadata is {csv_metadata}."
|
156 |
+
# )
|
157 |
+
|
158 |
+
|
159 |
+
# system_prompt = (
|
160 |
+
# "You are a data analyst assistant with limited tool capabilities. "
|
161 |
+
# "Available tools can only handle simple data queries: "
|
162 |
+
# "- Count rows/columns\n- Calculate basic stats (avg, sum, min/max)\n"
|
163 |
+
# "- Create simple visualizations (pie charts, bar graphs)\n"
|
164 |
+
# "- Show column names/types\n\n"
|
165 |
+
|
166 |
+
# "Query Handling Rules:\n"
|
167 |
+
# "1. If query is complex, ambiguous, or exceeds tool capabilities:\n"
|
168 |
+
# " - Break into simpler sub-questions\n"
|
169 |
+
# " - Ask for clarification\n"
|
170 |
+
# " - Rephrase to nearest simple query\n"
|
171 |
+
# "2. For 'full report' requests:\n"
|
172 |
+
# " - Outline possible analysis steps\n"
|
173 |
+
# " - Ask user to select one component at a time\n\n"
|
174 |
+
|
175 |
+
# "Examples:\n"
|
176 |
+
# "- Bad query: 'Show me everything'\n"
|
177 |
+
# " Response: 'I can show row count (10), columns (5: Name, Age...), "
|
178 |
+
# "or a pie chart of categories. Which would you like?'\n"
|
179 |
+
# "- Bad query: 'Analyze trends'\n"
|
180 |
+
# " Response: 'For trend analysis, I can show monthly averages or "
|
181 |
+
# "year-over-year comparisons. Please specify time period and metric.'\n\n"
|
182 |
+
|
183 |
+
# "Current CSV Context:\n"
|
184 |
+
# f"- URL: {csv_url}\n"
|
185 |
+
# f"- Metadata: {csv_metadata}\n\n"
|
186 |
+
|
187 |
+
# "Always format images as: ![Description](image_url)"
|
188 |
+
# )
|
189 |
+
|
190 |
+
return Agent(
|
191 |
model=initialize_model(api_key),
|
192 |
deps_type=str,
|
193 |
tools=[generate_csv_answer, generate_chart],
|
194 |
system_prompt=system_prompt
|
195 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
|
197 |
def csv_orchestrator_chat(csv_url: str, user_question: str, conversation_history: List) -> str:
|
198 |
print("CSV URL:", csv_url)
|
|
|
216 |
# If all keys are exhausted or fail
|
217 |
print("All API keys have been exhausted or failed.")
|
218 |
return None
|
|