FastApi / orchestrator_agent.py
Soumik555's picture
added gemini too
48e6960
raw
history blame
8.21 kB
import os
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from google.api_core.exceptions import ResourceExhausted  # Import the exception for quota exhaustion
from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai import RunContext
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.providers.google_gla import GoogleGLAProvider

from csv_service import get_csv_basic_info
from orchestrator_functions import csv_chart, csv_chat
# Populate os.environ from a local .env file (python-dotenv) so the
# GEMINI_API_KEYS lookup below can see keys defined there.
load_dotenv()
# Load all API keys from the GEMINI_API_KEYS environment variable, which is
# expected to hold a comma-separated list (e.g. "key1,key2").
# Entries are whitespace-stripped and blanks are dropped. A plain
# "".split(",") would yield [""] for an unset variable, which would make the
# orchestrator attempt a request with an empty key.
GEMINI_API_KEYS = [
    key.strip()
    for key in os.getenv("GEMINI_API_KEYS", "").split(",")
    if key.strip()
]
# Function to initialize the model with a specific API key
def initialize_model(api_key: str, model_name: str = "gemini-2.0-flash") -> GeminiModel:
    """Create a Gemini model client bound to a single API key.

    Args:
        api_key: API key handed to ``GoogleGLAProvider``.
        model_name: Gemini model identifier. Defaults to ``"gemini-2.0-flash"``
            so existing callers get the same model as before.

    Returns:
        A ``GeminiModel`` for ``model_name`` that authenticates with ``api_key``.
    """
    return GeminiModel(
        model_name,
        provider=GoogleGLAProvider(api_key=api_key),
    )
# Define the tools
async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> List[Dict[str, Any]]:
    """
    This function generates answers for the given user questions using the CSV URL.
    It uses the csv_chat function to process each question and return the answers.

    Args:
        csv_url (str): The URL of the CSV file.
        user_questions (List[str]): A list of user questions.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the question and answer for each question.

    Example:
        [
            {"question": "What is the average age of the customers?", "answer": "The average age is 35."},
            {"question": "What is the most common gender?", "answer": "The most common gender is Male."}
        ]
    """
    # NOTE: this docstring is registered with the agent as the tool description,
    # so its contract must match the data actually returned below.
    print("LLM using the csv chat function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)
    # Questions are answered one at a time, in input order (each csv_chat call is
    # awaited before the next starts), and the result list preserves that order.
    return [
        {"question": question, "answer": await csv_chat(csv_url, question)}
        for question in user_questions
    ]
async def generate_chart(csv_url: str, user_questions: List[str]) -> List[Dict[str, Any]]:
    """
    This function generates charts for the given user questions using the CSV URL.
    It uses the csv_chart function to process each question and return the chart URLs.
    It returns a list of dictionaries containing the question and chart image URL for each question.

    Args:
        csv_url (str): The URL of the CSV file.
        user_questions (List[str]): A list of user questions.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries with keys "question" and "image_url" for each question.

    Example:
        [
            {"question": "What is the average age of the customers?", "image_url": "https://example.com/chart1.png"},
            {"question": "What is the most common gender?", "image_url": "https://example.com/chart2.png"}
        ]
    """
    # NOTE: this docstring is registered with the agent as the tool description.
    # It previously advertised a "chart_url" key while the code returns
    # "image_url". The docstring now matches the data so the model reads the
    # right field.
    print("LLM using the csv chart function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)
    # Charts are generated one at a time, in input order; each csv_chart call is
    # awaited before the next starts.
    return [
        {"question": question, "image_url": await csv_chart(csv_url, question)}
        for question in user_questions
    ]
# Function to create an agent with a specific CSV URL
def create_agent(csv_url: str, api_key: str, conversation_history: List) -> Agent:
    """Build the orchestrator agent for one CSV file and one API key.

    The system prompt embeds the CSV URL, the dataset summary returned by
    ``get_csv_basic_info`` and the conversation history (via its ``str()``
    form). The agent can call the ``generate_csv_answer`` and
    ``generate_chart`` tools.

    Args:
        csv_url: URL of the CSV file the agent will analyze.
        api_key: Gemini API key used to build the model client.
        conversation_history: Prior conversation turns; interpolated verbatim
            into the system prompt.

    Returns:
        A configured ``pydantic_ai.Agent``.
    """
    # get_csv_basic_info runs on every call, so each new agent re-reads the
    # dataset summary. NOTE(review): presumably this fetches the CSV; confirm cost.
    csv_metadata = get_csv_basic_info(csv_url)
    system_prompt = f"""
# Role: Expert Data Analysis Assistant
## Capabilities:
- Break complex queries into simpler sub-tasks
## Instruction Framework:
1. QUERY PROCESSING:
- If request contains multiple questions:
a) Decompose into logical sub-questions
b) Process sequentially
c) Combine results coherently
2. DATA HANDLING:
- Always verify CSV structure matches the request
- Handle missing/ambiguous data by:
a) Asking clarifying questions OR
b) Making reasonable assumptions (state them clearly)
3. VISUALIZATION STANDARDS:
- Format images as: `![Description](direct-url)`
- Include axis labels and titles
- Use appropriate chart types
4. COMMUNICATION PROTOCOL:
- Friendly, professional tone
- Explain technical terms
- Summarize key findings
- Highlight limitations/caveats
5. TOOL USAGE:
- Can process statistical operations
- Supports visualization libraries
## Current Context:
- Working with CSV_URL: {csv_url}
- Dataset overview: {csv_metadata}
- Your conversation history: {conversation_history}
- Output format: Markdown compatible
## Response Template:
1. Confirm understanding of request
2. Outline analysis approach
3. Present results with visualizations (if applicable)
4. Provide interpretation
5. Offer next-step suggestions
"""
    return Agent(
        model=initialize_model(api_key),
        # NOTE(review): neither tool takes a RunContext and no deps are passed at
        # run time, so deps_type appears unused; kept for compatibility.
        deps_type=str,
        tools=[generate_csv_answer, generate_chart],
        system_prompt=system_prompt,
    )
def _mask_api_key(api_key: str) -> str:
    """Return a log-safe form of *api_key* that reveals at most its last four characters."""
    # Short keys are fully hidden; revealing four characters of them would leak too much.
    return f"...{api_key[-4:]}" if len(api_key) > 8 else "****"


def csv_orchestrator_chat(csv_url: str, user_question: str, conversation_history: List) -> Optional[str]:
    """Answer a question about a CSV, falling back across Gemini API keys.

    Keys from ``GEMINI_API_KEYS`` are tried in order, and a fresh agent is
    built for each attempt. The first successful run's output is returned.
    Any exception moves on to the next key.

    Args:
        csv_url: URL of the CSV file to analyze.
        user_question: The user's natural-language request.
        conversation_history: Prior conversation turns, passed to ``create_agent``.

    Returns:
        The agent's output (``result.data``), or ``None`` if every key failed
        or no usable key is configured.
    """
    print("CSV URL:", csv_url)
    print("User questions:", user_question)
    # Iterate through all API keys
    for api_key in GEMINI_API_KEYS:
        if not api_key.strip():
            # Skip blank entries (e.g. from an unset variable or a trailing comma).
            continue
        # Never log the full secret; only a masked suffix.
        masked_key = _mask_api_key(api_key)
        try:
            print(f"Attempting with API key: {masked_key}")
            agent = create_agent(csv_url, api_key, conversation_history)
            result = agent.run_sync(user_question)
            print("Orchestrator Result:", result.data)
            return result.data
        # The original `except ResourceExhausted or Exception` evaluated to just
        # ResourceExhausted; the generic case is handled by the next clause.
        # NOTE(review): pydantic_ai's Gemini client may surface quota errors as its
        # own HTTP error type rather than google.api_core's ResourceExhausted; if
        # so, they land in the generic branch below. Confirm.
        except ResourceExhausted:
            print(f"Quota exhausted for API key: {masked_key}. Switching to the next key.")
            continue  # Move to the next key
        except Exception as e:
            print(f"Error with API key {masked_key}: {e}")
            continue  # Move to the next key
    # If all keys are exhausted or fail
    print("All API keys have been exhausted or failed.")
    return None