File size: 5,841 Bytes
4fbcf68 4f3a783 4fbcf68 bdba660 4fbcf68 d784ff5 bdba660 4fbcf68 d784ff5 4fbcf68 d784ff5 4fbcf68 d3c4ed6 c5658f1 8688d5e c5658f1 8690a80 c5658f1 8690a80 c5658f1 d784ff5 c5658f1 d3c4ed6 bb534e9 a000f1e 4fbcf68 d784ff5 4fbcf68 d784ff5 4fbcf68 30e7daa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import os
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from google.api_core.exceptions import ResourceExhausted  # Import the exception for quota exhaustion
from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai import RunContext
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.providers.google_gla import GoogleGLAProvider

from csv_service import get_csv_basic_info
from orchestrator_functions import csv_chart, csv_chat
load_dotenv()
# Load all API keys from the environment variable.
# GEMINI_API_KEYS is expected to hold a comma-separated list of keys.
# Whitespace around each key is stripped and empty entries are dropped, so an
# unset or empty variable yields [] instead of [""] (str.split never returns
# an empty list) and "k1, k2" does not produce a key with a leading space.
GEMINI_API_KEYS = [
    key.strip()
    for key in os.getenv("GEMINI_API_KEYS", "").split(",")
    if key.strip()
]
# Function to initialize the model with a specific API key
def initialize_model(api_key: str) -> GeminiModel:
    """Build a 'gemini-2.0-flash' GeminiModel that authenticates with ``api_key``.

    Args:
        api_key (str): The key handed to the GoogleGLAProvider.

    Returns:
        GeminiModel: A model bound to a provider created from ``api_key``.
    """
    gla_provider = GoogleGLAProvider(api_key=api_key)
    return GeminiModel('gemini-2.0-flash', provider=gla_provider)
# Define the tools
async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
    """
    Answer each of the given user questions about the CSV file at the given URL.

    Every question is handed to the csv_chat function together with the CSV URL,
    one at a time and in the order given, and the answers are collected.

    Args:
        csv_url (str): The URL of the CSV file.
        user_questions (List[str]): A list of user questions.

    Returns:
        List[Dict[str, Any]]: One dictionary per question, containing the question and its answer.

    Example:
        [
            {"question": "What is the average age of the customers?", "answer": "The average age is 35."},
            {"question": "What is the most common gender?", "answer": "The most common gender is Male."}
        ]
    """
    print("LLM using the csv chat function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)

    # Awaiting inside the comprehension keeps the calls strictly sequential:
    # each csv_chat call finishes before the next question is started.
    return [
        {"question": item, "answer": await csv_chat(csv_url, item)}
        for item in user_questions
    ]
async def generate_chart(csv_url: str, user_questions: List[str], chat_id: str) -> Any:
    """
    This function generates charts for the given user questions using the CSV URL.
    It uses the csv_chart function to process each question, one at a time and in
    the order given. It returns a list of dictionaries containing the question and,
    under the "image_url" key, the value csv_chart returned for that question.

    Args:
        csv_url (str): The URL of the CSV file.
        user_questions (List[str]): A list of user questions.
        chat_id (str): The chat identifier, forwarded unchanged to csv_chart.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the question and image URL for each question.

    Example:
        [
            {"question": "What is the average age of the customers?", "image_url": "https://example.com/chart1.png"},
            {"question": "What is the most common gender?", "image_url": "https://example.com/chart2.png"}
        ]
    """
    print("LLM using the csv chart function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)

    # Awaiting inside the comprehension keeps chart generation sequential:
    # each csv_chart call finishes before the next question is started.
    return [
        {"question": question, "image_url": await csv_chart(csv_url, question, chat_id)}
        for question in user_questions
    ]
# Function to create an agent with a specific CSV URL
def create_agent(csv_url: str, api_key: str, conversation_history: List, chat_id: str) -> Agent:
    """Build an Agent whose system prompt is specific to one CSV file and chat.

    The CSV's basic info is fetched with get_csv_basic_info and interpolated
    into the system prompt together with the CSV URL, the conversation history
    and the chat ID. The agent is given the generate_csv_answer and
    generate_chart tools.

    Args:
        csv_url (str): The URL of the CSV file.
        api_key (str): The key passed to initialize_model to build the model.
        conversation_history (List): Prior conversation, embedded in the prompt
            via its default string formatting.
        chat_id (str): The chat identifier, embedded in the prompt.

    Returns:
        Agent: An agent configured with the model, tools and system prompt.
    """
    csv_metadata = get_csv_basic_info(csv_url)

    # The prompt text is kept flush-left so that no source indentation leaks
    # into the string sent to the model.
    # NOTE(review): the Output Format rule previously read "visualizations as ``"
    # with an empty code span; the markdown image template below restores an
    # explicit format and uses the "image_url" key that generate_chart returns.
    system_prompt = f"""
# Role: Data Analyst Assistant
**Specialization:** CSV Analysis & Visualization
## Key Rules:
1. **Always provide both:**
- Complete textual answer with explanations
- Visualization when applicable
2. **Output Format:** Markdown compatible (visualizations as `![description](image_url)`)
3. **Tool Handling:**
- Use `generate_csv_answer` for analysis
- Use `generate_chart` for visuals
- Never disclose tool names
4. **Visualization Fallback:**
- If requested library (plotly, bokeh etc.) isn't available:
- Provide closest alternative
- Explain the limitation
## Current Context:
- **Dataset:** {csv_url}
- **Metadata:** {csv_metadata}
- **History:** {conversation_history}
- **Chat ID:** {chat_id}
## Required Output:
For every question return:
1. Clear analysis answer
2. Visualization (when possible, in markdown format)
3. Follow-up suggestions
**Critical:** Never return partial responses - always combine both textual answers and visualizations when applicable.
"""
    return Agent(
        model=initialize_model(api_key),
        deps_type=str,
        tools=[generate_csv_answer, generate_chart],
        system_prompt=system_prompt,
    )
def csv_orchestrator_chat(csv_url: str, user_question: str, conversation_history: List, chat_id: str) -> Optional[str]:
    """Answer a question about a CSV, falling back across the configured API keys.

    Each key in GEMINI_API_KEYS is tried in order: an agent is created for it
    with create_agent and run synchronously on the question. The first
    successful result is returned; any failure moves on to the next key.

    Args:
        csv_url (str): The URL of the CSV file.
        user_question (str): The question to send to the agent.
        conversation_history (List): Prior conversation, forwarded to create_agent.
        chat_id (str): The chat identifier, forwarded to create_agent.

    Returns:
        Optional[str]: ``result.data`` from the first successful run, or None
        when every key failed or no usable key is configured.
    """
    print("CSV URL:", csv_url)
    print("User questions:", user_question)

    # Iterate through all API keys. Keys are identified by their 1-based
    # position in log output so the secret itself is never printed.
    for position, raw_key in enumerate(GEMINI_API_KEYS, start=1):
        api_key = raw_key.strip()
        # Skip blank entries (e.g. from an unset variable or a stray comma)
        # instead of issuing a request that cannot authenticate.
        if not api_key:
            continue
        try:
            print(f"Attempting with API key #{position}")
            agent = create_agent(csv_url, api_key, conversation_history, chat_id)
            result = agent.run_sync(user_question)
            print("Orchestrator Result:", result.data)
            return result.data
        except ResourceExhausted:
            # Was `except ResourceExhausted or Exception`, which evaluates to
            # `except ResourceExhausted`; spelled out so the intent is clear.
            print(f"Quota exhausted for API key #{position}. Switching to the next key.")
        except Exception as e:
            # Deliberate best-effort: any other failure also moves to the next key.
            print(f"Error with API key #{position}: {e}")

    # If all keys are exhausted or fail
    print("All API keys have been exhausted or failed.")
    return None
|