|
|
|
import os
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from google.api_core.exceptions import ResourceExhausted
from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai import RunContext
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.providers.google_gla import GoogleGLAProvider

from csv_service import get_csv_basic_info
from orchestrator_functions import csv_chart, csv_chat
|
|
|
load_dotenv() |
|
|
|
|
|
|
|
# Gemini API keys, read from the comma-separated GEMINI_API_KEYS environment
# variable. Entries are stripped and blank ones dropped: "".split(",") returns
# [""], so without the filter an unset variable would yield one empty "key".
GEMINI_API_KEYS = [
    key.strip()
    for key in os.getenv("GEMINI_API_KEYS", "").split(",")
    if key.strip()
]
|
|
|
|
|
def initialize_model(api_key: str) -> GeminiModel:
    """Build the Gemini model used by the agent for one API key.

    Args:
        api_key (str): API key passed to the Google GLA provider.

    Returns:
        GeminiModel: A 'gemini-2.0-flash' model bound to a provider created
        from ``api_key``.
    """
    gla_provider = GoogleGLAProvider(api_key=api_key)
    return GeminiModel('gemini-2.0-flash', provider=gla_provider)
|
|
|
|
|
async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
    """
    This function generates answers for the given user questions using the CSV URL.
    It uses the csv_chat function to process each question and return the answers.

    Args:
        csv_url (str): The URL of the CSV file.
        user_questions (List[str]): A list of user questions.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the question and answer for each question.

    Example:
        [
            {"question": "What is the average age of the customers?", "answer": "The average age is 35."},
            {"question": "What is the most common gender?", "answer": "The most common gender is Male."}
        ]
    """
    # NOTE(review): this coroutine is registered as an agent tool in
    # create_agent; pydantic-ai is expected to surface the docstring above to
    # the model as the tool description, so its wording is kept stable.
    print("LLM using the csv chat function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)

    # Questions are awaited one at a time, in the order they were given.
    return [
        {"question": query, "answer": await csv_chat(csv_url, query)}
        for query in user_questions
    ]
|
|
|
async def generate_chart(csv_url: str, user_questions: List[str]) -> Any:
    """
    This function generates charts for the given user questions using the CSV URL.
    It uses the csv_chart function to process each question and return the chart URLs.
    It returns a list of dictionaries containing the question and chart URL for each question.

    Args:
        csv_url (str): The URL of the CSV file.
        user_questions (List[str]): A list of user questions.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries, one per question, with the
        keys "question" and "image_url" (the value returned by csv_chart).

    Example:
        [
            {"question": "What is the average age of the customers?", "image_url": "https://example.com/chart1.png"},
            {"question": "What is the most common gender?", "image_url": "https://example.com/chart2.png"}
        ]
    """
    # NOTE(review): this coroutine is registered as an agent tool in
    # create_agent; pydantic-ai is expected to surface the docstring above to
    # the model, so the documented key must match the key actually returned.
    print("LLM using the csv chart function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)

    charts = []
    # Charts are produced one question at a time, in the order given.
    for question in user_questions:
        chart = await csv_chart(csv_url, question)
        charts.append({"question": question, "image_url": chart})

    return charts
|
|
|
|
|
def _build_system_prompt(csv_url: str, csv_metadata: Any, conversation_history: List) -> str:
    """Render the agent's system prompt for one CSV and conversation."""
    # The image format line spells out Markdown image syntax; "image_url" is
    # the key under which generate_chart returns each chart location.
    return f"""
    # Role: Expert Data Analysis Assistant

    ## Capabilities:
    - Break complex queries into simpler sub-tasks

    ## Instruction Framework:
    1. QUERY PROCESSING:
       - If request contains multiple questions:
         a) Decompose into logical sub-questions
         b) Process sequentially
         c) Combine results coherently

    2. DATA HANDLING:
       - Always verify CSV structure matches the request
       - Handle missing/ambiguous data by:
         a) Asking clarifying questions OR
         b) Making reasonable assumptions (state them clearly)

    3. VISUALIZATION STANDARDS:
       - Format images as: `![Description](image_url)`
       - Include axis labels and titles
       - Use appropriate chart types

    4. COMMUNICATION PROTOCOL:
       - Friendly, professional tone
       - Explain technical terms
       - Summarize key findings
       - Highlight limitations/caveats

    5. TOOL USAGE:
       - Can process statistical operations
       - Supports visualization libraries

    ## Current Context:
    - Working with CSV_URL: {csv_url}
    - Dataset overview: {csv_metadata}
    - Your conversation history: {conversation_history}
    - Output format: Markdown compatible

    ## Response Template:
    1. Confirm understanding of request
    2. Outline analysis approach
    3. Present results with visualizations (if applicable)
    4. Provide interpretation
    5. Offer next-step suggestions
    """


def create_agent(csv_url: str, api_key: str, conversation_history: List) -> Agent:
    """Create the data-analysis agent for one CSV file and one API key.

    Looks up the CSV's basic info with get_csv_basic_info, embeds it together
    with the CSV URL and the conversation history in the system prompt, and
    returns an Agent that can call generate_csv_answer and generate_chart.

    Args:
        csv_url (str): The URL of the CSV file.
        api_key (str): API key used to build the model via initialize_model.
        conversation_history (List): Earlier conversation, interpolated as-is
            into the system prompt.

    Returns:
        Agent: The configured agent.
    """
    csv_metadata = get_csv_basic_info(csv_url)
    system_prompt = _build_system_prompt(csv_url, csv_metadata, conversation_history)

    return Agent(
        model=initialize_model(api_key),
        deps_type=str,
        tools=[generate_csv_answer, generate_chart],
        system_prompt=system_prompt,
    )
|
|
|
def csv_orchestrator_chat(csv_url: str, user_question: str, conversation_history: List) -> Optional[str]:
    """Answer a question about a CSV, trying each configured API key in turn.

    For every entry in GEMINI_API_KEYS an agent is created with create_agent
    and run synchronously on ``user_question``. The first successful result is
    returned; a key that raises is logged and the next key is tried.

    Args:
        csv_url (str): The URL of the CSV file.
        user_question (str): The question to pass to the agent.
        conversation_history (List): Earlier conversation, forwarded to
            create_agent.

    Returns:
        Optional[str]: The agent's result data, or None when every key was
        blank, exhausted, or failed.
    """
    print("CSV URL:", csv_url)
    print("User questions:", user_question)

    for position, raw_key in enumerate(GEMINI_API_KEYS, start=1):
        api_key = raw_key.strip()
        # Blank entries (unset variable, trailing comma) can never authenticate.
        if not api_key:
            continue

        try:
            # Keys are identified by position so the secret never reaches the logs.
            print(f"Attempting with API key #{position}")
            agent = create_agent(csv_url, api_key, conversation_history)
            result = agent.run_sync(user_question)
            print("Orchestrator Result:", result.data)
            return result.data
        except ResourceExhausted:
            print(f"Quota exhausted for API key #{position}. Switching to the next key.")
        except Exception as e:
            # Best-effort fallback: any other failure also moves on to the next key.
            print(f"Error with API key #{position}: {e}")

    print("All API keys have been exhausted or failed.")
    return None
|
|