FastApi / orchestrator_agent.py
Soumik555's picture
browse url infinite loading fixed with backend check
8690a80
raw
history blame
6.63 kB
import os
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
# Raised by Google client libraries when an API key's quota is exhausted.
from google.api_core.exceptions import ResourceExhausted
from pydantic import BaseModel
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.providers.google_gla import GoogleGLAProvider

from csv_service import get_csv_basic_info
from orchestrator_functions import csv_chart, csv_chat
load_dotenv()
# Load all API keys from the environment variable
GEMINI_API_KEYS = os.getenv("GEMINI_API_KEYS", "").split(",") # Expecting a comma-separated list of keys
# Function to initialize the model with a specific API key
def initialize_model(api_key: str) -> GeminiModel:
    """Build a Gemini 2.0 Flash model instance bound to the given API key.

    Args:
        api_key: A single Gemini API key used to authenticate the provider.

    Returns:
        GeminiModel: A model configured with the Google GLA provider.
    """
    provider = GoogleGLAProvider(api_key=api_key)
    return GeminiModel('gemini-2.0-flash', provider=provider)
# Define the tools
async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
    """
    This function generates answers for the given user questions using the CSV URL.
    It uses the csv_chat function to process each question and return the answers.
    Args:
        csv_url (str): The URL of the CSV file.
        user_questions (List[str]): A list of user questions.
    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the question and answer for each question.
    Example:
        [
            {"question": "What is the average age of the customers?", "answer": "The average age is 35."},
            {"question": "What is the most common gender?", "answer": "The most common gender is Male."}
        ]
    """
    print("LLM using the csv chat function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)
    # Answer each question sequentially; csv_chat is awaited per question,
    # so results come back in the same order the questions were asked.
    return [
        {"question": question, "answer": await csv_chat(csv_url, question)}
        for question in user_questions
    ]
async def generate_chart(csv_url: str, user_questions: List[str]) -> Any:
    """
    This function generates charts for the given user questions using the CSV URL.
    It uses the csv_chart function to process each question and return the chart URLs.
    It returns a list of dictionaries containing the question and chart URL for each question.
    Args:
        csv_url (str): The URL of the CSV file.
        user_questions (List[str]): A list of user questions.
    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the question and chart URL for each question.
    Example:
        [
            {"question": "What is the average age of the customers?", "chart_url": "https://example.com/chart1.png"},
            {"question": "What is the most common gender?", "chart_url": "https://example.com/chart2.png"}
        ]
    """
    print("LLM using the csv chart function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)
    # Render one chart per question, in order; each csv_chart call is awaited
    # before moving to the next question.
    return [
        {"question": question, "image_url": await csv_chart(csv_url, question)}
        for question in user_questions
    ]
# Function to create an agent with a specific CSV URL
def create_agent(csv_url: str, api_key: str, conversation_history: List) -> Agent:
csv_metadata = get_csv_basic_info(csv_url)
system_prompt = f"""
# Role: Expert Data Analyst Assistant
**Specialization:** CSV Data Analysis & Visualization
## Core Responsibilities:
1. **Data Analysis:** Perform thorough analysis of CSV data to extract insights
2. **Visualization:** Create clear, informative visualizations using available libraries
3. **Guidance:** Help users formulate better data questions and understand results
## Technical Specifications:
- **Available Libraries:** matplotlib, seaborn
- **Output Format:** Markdown compatible (including visualizations as `![Description](direct-url)`)
- **Data Handling:**
- Auto-verify CSV structure before analysis
- Handle missing data by either:
- Making clear assumptions (and stating them)
- Requesting user clarification when critical
## Workflow Rules:
1. **Query Processing:**
- Break complex questions into logical steps
- Optimize questions before tool execution
- Process multi-part queries sequentially and combine results
2. **Tool Usage:**
- Primary tools:
- `generate_csv_answer` for data analysis
- `generate_chart` for visualizations
- Never disclose tool names or internal processes
- If requested visualization isn't available (plotly, bokeh, etc.):
- Suggest closest alternative
- Provide clear explanation
3. **User Interaction:**
- When question relates to dataset:
- First use tools to generate potential answers
- Then cross-check with user if needed
- Maintain friendly yet professional tone
- Read questions carefully before responding
## Current Context:
- **Dataset URL:** {csv_url}
- **Metadata:** {csv_metadata}
- **Conversation History:** {conversation_history}
## Style Guidelines:
- Prioritize clarity over technical jargon
- Present one logical thought per paragraph
- Use bullet points for complex information
- Always verify critical assumptions with users
"""
return Agent(
model=initialize_model(api_key),
deps_type=str,
tools=[generate_csv_answer, generate_chart],
system_prompt=system_prompt
)
def csv_orchestrator_chat(csv_url: str, user_question: str, conversation_history: List) -> Optional[str]:
    """Answer a user question about a CSV, rotating through Gemini API keys.

    Tries each key in GEMINI_API_KEYS in order; on quota exhaustion
    (ResourceExhausted) or any other per-key failure, it moves on to the
    next key instead of failing the whole request.

    Args:
        csv_url: URL of the CSV file under discussion.
        user_question: The user's question for this turn.
        conversation_history: Prior chat turns, forwarded to the agent's
            system prompt for context.

    Returns:
        The agent's answer text, or None if every API key failed.
    """
    print("CSV URL:", csv_url)
    print("User questions:", user_question)
    # Iterate through all API keys until one succeeds
    for api_key in GEMINI_API_KEYS:
        # Guard against empty entries (e.g. unset env var yields [""])
        if not api_key.strip():
            continue
        # Only log a key prefix — never the full credential
        key_label = f"{api_key[:6]}..."
        try:
            print(f"Attempting with API key: {key_label}")
            agent = create_agent(csv_url, api_key, conversation_history)
            result = agent.run_sync(user_question)
            print("Orchestrator Result:", result.data)
            return result.data
        # NOTE: the original `except ResourceExhausted or Exception` evaluated
        # to just ResourceExhausted; these are the two handlers it intended.
        except ResourceExhausted:
            print(f"Quota exhausted for API key: {key_label}. Switching to the next key.")
            continue  # Move to the next key
        except Exception as e:
            print(f"Error with API key {key_label}: {e}")
            continue  # Move to the next key
    # If all keys are exhausted or fail
    print("All API keys have been exhausted or failed.")
    return None