# FastApi / orchestrator_agent.py
# Soumik555's picture
# Report generator
# 8f05635
# raw
# history blame
# 4.54 kB
import os
from typing import Dict, List, Any, Optional

from dotenv import load_dotenv
from google.api_core.exceptions import ResourceExhausted # Import the exception for quota exhaustion
from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai import RunContext
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.providers.google_gla import GoogleGLAProvider

from csv_service import get_csv_basic_info
from orchestrator_functions import csv_chart, csv_chat
load_dotenv()
# Load all Gemini API keys from the GEMINI_API_KEYS environment variable, which is
# expected to be a comma-separated list. Surrounding whitespace is stripped and empty
# entries are dropped, so an unset variable (or a trailing comma) yields no bogus ""
# key -- a plain "".split(",") would return [""].
GEMINI_API_KEYS = [
    key.strip()
    for key in os.getenv("GEMINI_API_KEYS", "").split(",")
    if key.strip()
]
# Model factory: one GeminiModel per API key.
def initialize_model(api_key: str) -> GeminiModel:
    """Return a ``gemini-2.0-flash`` GeminiModel authenticated with *api_key*.

    The key is passed to a GoogleGLAProvider, which the model uses as its provider.
    """
    gla_provider = GoogleGLAProvider(api_key=api_key)
    return GeminiModel('gemini-2.0-flash', provider=gla_provider)
# Define the tools
async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
    """Answer simple questions about a CSV file, one question at a time.

    Args:
        csv_url: Location of the CSV file to query.
        user_questions: Simple data questions to answer.

    Returns:
        A list of dicts with keys ``question`` and ``answer``, in the same order
        as ``user_questions``.
    """
    print("LLM using the csv chat function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)
    # Each await completes before the next question is sent, so questions are
    # answered sequentially and results keep the input order.
    return [
        {"question": question_text, "answer": await csv_chat(csv_url, question_text)}
        for question_text in user_questions
    ]
async def generate_chart(csv_url: str, user_questions: List[str]) -> Any:
    """Create a simple chart (e.g. pie or bar) for each question about a CSV file.

    Args:
        csv_url: Location of the CSV file to visualize.
        user_questions: One chart request per entry.

    Returns:
        A list of dicts with keys ``question`` and ``image_url``, in the same
        order as ``user_questions``.
    """
    print("LLM using the csv chart function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)
    # NOTE(review): the value returned by csv_chart is stored as "image_url";
    # presumably it is a URL to the rendered image -- confirm in orchestrator_functions.
    return [
        {"question": request, "image_url": await csv_chart(csv_url, request)}
        for request in user_questions
    ]
# Function to create an agent with a specific CSV URL
def create_agent(csv_url: str, api_key: str) -> Agent:
    """Build a pydantic_ai Agent scoped to a single CSV file.

    The CSV URL and the metadata returned by ``get_csv_basic_info`` are embedded
    in the system prompt. The agent is given the ``generate_csv_answer`` and
    ``generate_chart`` tools.

    Args:
        csv_url: Location of the CSV the agent will answer questions about.
        api_key: Gemini API key used to construct the model.

    Returns:
        An ``Agent`` backed by ``initialize_model(api_key)``.
    """
    csv_metadata = get_csv_basic_info(csv_url)
    system_prompt = (
        "You are a data analyst assistant with limited tool capabilities. "
        # Newline after the header so the first bullet starts on its own line,
        # like every other item in this list.
        "Available tools can only handle simple data queries:\n"
        "- Count rows/columns\n- Calculate basic stats (avg, sum, min/max)\n"
        "- Create simple visualizations (pie charts, bar graphs)\n"
        "- Show column names/types\n\n"
        "Query Handling Rules:\n"
        "1. If query is complex, ambiguous, or exceeds tool capabilities:\n"
        " - Break into simpler sub-questions\n"
        " - Ask for clarification\n"
        " - Rephrase to nearest simple query\n"
        "2. For 'full report' requests:\n"
        " - Outline possible analysis steps\n"
        " - Ask user to select one component at a time\n\n"
        "Examples:\n"
        "- Bad query: 'Show me everything'\n"
        " Response: 'I can show row count (10), columns (5: Name, Age...), "
        "or a pie chart of categories. Which would you like?'\n"
        "- Bad query: 'Analyze trends'\n"
        " Response: 'For trend analysis, I can show monthly averages or "
        "year-over-year comparisons. Please specify time period and metric.'\n\n"
        "Current CSV Context:\n"
        f"- URL: {csv_url}\n"
        f"- Metadata: {csv_metadata}\n\n"
        "Always format images as: ![Chart Description](direct-image-url)"
    )
    return Agent(
        model=initialize_model(api_key),
        deps_type=str,
        tools=[generate_csv_answer, generate_chart],
        system_prompt=system_prompt
    )
def _mask_api_key(api_key: str) -> str:
    """Return a log-safe form of *api_key* that reveals at most its last 4 characters."""
    return f"...{api_key[-4:]}" if len(api_key) > 4 else "****"


def csv_orchestrator_chat(csv_url: str, user_question: str) -> Optional[str]:
    """Answer ``user_question`` about the CSV at ``csv_url``, rotating through API keys.

    Each key in ``GEMINI_API_KEYS`` is tried in order, and a fresh agent is built
    for every attempt. If an attempt fails, whether from quota exhaustion or any
    other exception, the failure is logged and the next key is tried. Only a
    masked form of each key is ever printed.

    Args:
        csv_url: Location of the CSV file to analyse.
        user_question: The user's question, passed to the agent as-is.

    Returns:
        ``result.data`` from the first successful agent run, or ``None`` if every
        key failed or no usable keys are configured.
    """
    print("CSV URL:", csv_url)
    print("User questions:", user_question)
    for api_key in GEMINI_API_KEYS:
        # Skip blank entries (e.g. from an empty env var or trailing comma).
        if not api_key.strip():
            continue
        masked_key = _mask_api_key(api_key)
        try:
            print(f"Attempting with API key: {masked_key}")
            agent = create_agent(csv_url, api_key)
            result = agent.run_sync(user_question)
            print("Orchestrator Result:", result.data)
            return result.data
        except ResourceExhausted:
            # NOTE(review): ResourceExhausted comes from google-api-core; confirm the
            # installed pydantic_ai GeminiModel actually raises it on quota errors,
            # otherwise those errors are handled by the generic branch below.
            print(f"Quota exhausted for API key: {masked_key}. Switching to the next key.")
        except Exception as e:
            # Best-effort rotation: any other failure also moves on to the next key.
            print(f"Error with API key {masked_key}: {e}")
    # If all keys are exhausted or fail
    print("All API keys have been exhausted or failed.")
    return None