File size: 4,539 Bytes
4fbcf68
 
 
 
 
 
 
 
 
 
 
4f3a783
 
 
4fbcf68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f05635
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fbcf68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118

import os
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from google.api_core.exceptions import ResourceExhausted  # Import the exception for quota exhaustion
from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai import RunContext
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.providers.google_gla import GoogleGLAProvider

from csv_service import get_csv_basic_info
from orchestrator_functions import csv_chart, csv_chat

# Load variables from a local .env file (if one exists) into the process
# environment before the os.getenv() lookup below runs.
load_dotenv()


# Load all API keys from the GEMINI_API_KEYS environment variable, which is
# expected to hold a comma-separated list of keys. Whitespace around each key
# is stripped and blank entries are dropped, so an unset or empty variable
# yields [] instead of [""] (a single bogus empty key).
GEMINI_API_KEYS = [
    key.strip()
    for key in os.getenv("GEMINI_API_KEYS", "").split(",")
    if key.strip()
]

# Build a Gemini model bound to one specific API key
def initialize_model(api_key: str) -> GeminiModel:
    """Return a ``gemini-2.0-flash`` model that authenticates with ``api_key``.

    Args:
        api_key: API key handed to the ``GoogleGLAProvider``.

    Returns:
        A ``GeminiModel`` using that provider.
    """
    provider = GoogleGLAProvider(api_key=api_key)
    return GeminiModel('gemini-2.0-flash', provider=provider)

# Define the tools
async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
    """Answer one or more questions about the data in a CSV file.

    Args:
        csv_url: URL of the CSV file to query.
        user_questions: Questions to answer; each one is asked separately.

    Returns:
        A list with one ``{"question": ..., "answer": ...}`` entry per
        question, in the order the questions were given.
    """
    # NOTE: this function is registered as an agent tool, and pydantic_ai uses
    # the docstring above as the tool/parameter description shown to the
    # model, so keep it accurate and model-readable.
    print("LLM using the csv chat function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)

    # Questions are awaited one at a time (sequentially), preserving order.
    return [
        dict(question=question, answer=await csv_chat(csv_url, question))
        for question in user_questions
    ]

async def generate_chart(csv_url: str, user_questions: List[str]) -> Any:
    """Create one chart per request from the data in a CSV file.

    Args:
        csv_url: URL of the CSV file to chart.
        user_questions: Chart requests; each one is handled separately.

    Returns:
        A list with one ``{"question": ..., "image_url": ...}`` entry per
        request, in the order the requests were given.
    """
    # NOTE: this function is registered as an agent tool, and pydantic_ai uses
    # the docstring above as the tool/parameter description shown to the
    # model, so keep it accurate and model-readable.
    print("LLM using the csv chart function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)

    # Requests are awaited one at a time (sequentially), preserving order.
    return [
        dict(question=question, image_url=await csv_chart(csv_url, question))
        for question in user_questions
    ]

# Build an agent whose system prompt is tailored to one CSV file
def create_agent(csv_url: str, api_key: str) -> Agent:
    """Create an agent primed with context about a single CSV file.

    The result of ``get_csv_basic_info(csv_url)`` is embedded, together with
    the URL itself, in the system prompt. The agent is given the
    ``generate_csv_answer`` and ``generate_chart`` tools.

    Args:
        csv_url: URL of the CSV file the agent will be asked about.
        api_key: API key used to build the underlying model.

    Returns:
        The configured ``Agent``.
    """
    metadata = get_csv_basic_info(csv_url)

    # The prompt is assembled from named sections; joined with no separator
    # they form exactly one contiguous string.
    role_and_capabilities = (
        "You are a data analyst assistant with limited tool capabilities. "
        "Available tools can only handle simple data queries: "
        "- Count rows/columns\n- Calculate basic stats (avg, sum, min/max)\n"
        "- Create simple visualizations (pie charts, bar graphs)\n"
        "- Show column names/types\n\n"
    )
    handling_rules = (
        "Query Handling Rules:\n"
        "1. If query is complex, ambiguous, or exceeds tool capabilities:\n"
        "   - Break into simpler sub-questions\n"
        "   - Ask for clarification\n"
        "   - Rephrase to nearest simple query\n"
        "2. For 'full report' requests:\n"
        "   - Outline possible analysis steps\n"
        "   - Ask user to select one component at a time\n\n"
    )
    worked_examples = (
        "Examples:\n"
        "- Bad query: 'Show me everything'\n"
        "  Response: 'I can show row count (10), columns (5: Name, Age...), "
        "or a pie chart of categories. Which would you like?'\n"
        "- Bad query: 'Analyze trends'\n"
        "  Response: 'For trend analysis, I can show monthly averages or "
        "year-over-year comparisons. Please specify time period and metric.'\n\n"
    )
    csv_context = (
        "Current CSV Context:\n"
        f"- URL: {csv_url}\n"
        f"- Metadata: {metadata}\n\n"
    )
    image_rule = "Always format images as: ![Chart Description](direct-image-url)"

    prompt_sections = [
        role_and_capabilities,
        handling_rules,
        worked_examples,
        csv_context,
        image_rule,
    ]
    full_prompt = "".join(prompt_sections)

    agent_tools = [generate_csv_answer, generate_chart]
    return Agent(
        model=initialize_model(api_key),
        deps_type=str,
        tools=agent_tools,
        system_prompt=full_prompt,
    )

def _mask_api_key(api_key: str) -> str:
    """Return a log-safe label for ``api_key`` exposing at most its last 4 characters."""
    if len(api_key) <= 4:
        return "****"
    return f"****{api_key[-4:]}"


def csv_orchestrator_chat(csv_url: str, user_question: str) -> Optional[str]:
    """Run the user's question through the agent, rotating API keys on failure.

    Each key in ``GEMINI_API_KEYS`` is tried in order; the first successful
    run's ``result.data`` is returned. A key that raises is skipped and the
    next one is tried.

    Args:
        csv_url: URL of the CSV file the question is about.
        user_question: The question to send to the agent.

    Returns:
        The agent's ``result.data`` from the first key that succeeds, or
        ``None`` if there are no usable keys or every key failed.
    """
    print("CSV URL:", csv_url)
    print("User questions:", user_question)

    # Iterate through all API keys
    for raw_key in GEMINI_API_KEYS:
        api_key = raw_key.strip()
        if not api_key:
            # An unset/empty GEMINI_API_KEYS splits to [""]; don't spend a
            # request on a key that cannot work.
            continue

        # Never write the full secret to stdout/logs.
        key_label = _mask_api_key(api_key)
        try:
            print(f"Attempting with API key: {key_label}")
            agent = create_agent(csv_url, api_key)
            result = agent.run_sync(user_question)
            print("Orchestrator Result:", result.data)
            return result.data
        except ResourceExhausted:
            # The original `except ResourceExhausted or Exception` evaluated to
            # just `ResourceExhausted`; spell that out explicitly.
            print(f"Quota exhausted for API key: {key_label}. Switching to the next key.")
        except Exception as e:
            # Deliberately broad: any failure with this key is treated as
            # "try the next key" rather than aborting the whole request.
            print(f"Error with API key {key_label}: {e}")

    # If all keys are exhausted or fail
    print("All API keys have been exhausted or failed.")
    return None