File size: 5,841 Bytes
4fbcf68
 
 
 
 
 
 
 
 
 
 
4f3a783
 
 
4fbcf68
 
 
 
 
 
 
 
 
 
 
 
 
 
bdba660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fbcf68
 
 
 
 
 
 
 
 
 
 
 
d784ff5
bdba660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fbcf68
 
 
 
 
 
 
 
d784ff5
4fbcf68
 
 
 
 
d784ff5
4fbcf68
 
d3c4ed6
c5658f1
 
 
 
 
 
 
8688d5e
c5658f1
 
 
 
 
 
 
 
8690a80
 
c5658f1
8690a80
c5658f1
d784ff5
c5658f1
 
 
 
 
 
 
 
d3c4ed6
bb534e9
a000f1e
 
4fbcf68
 
 
 
 
 
d784ff5
4fbcf68
 
 
 
 
 
 
d784ff5
4fbcf68
 
 
 
 
 
 
 
 
 
 
 
 
30e7daa
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161

import os
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from google.api_core.exceptions import ResourceExhausted  # Import the exception for quota exhaustion
from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai import RunContext
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.providers.google_gla import GoogleGLAProvider

from csv_service import get_csv_basic_info
from orchestrator_functions import csv_chart, csv_chat

# Runs before GEMINI_API_KEYS is read below; presumably this loads variables
# from a local .env file into the process environment (python-dotenv) — confirm.
load_dotenv()


# Load all API keys from the GEMINI_API_KEYS environment variable, which is
# expected to hold a comma-separated list of keys. Whitespace around each key
# is stripped and empty entries are dropped, so an unset or empty variable
# yields [] instead of [""] (a single bogus empty key).
GEMINI_API_KEYS = [
    key.strip()
    for key in os.getenv("GEMINI_API_KEYS", "").split(",")
    if key.strip()
]

def initialize_model(api_key: str) -> GeminiModel:
    """Build a `gemini-2.0-flash` model bound to the given API key.

    Args:
        api_key (str): Key handed to the `GoogleGLAProvider`.

    Returns:
        GeminiModel: Model instance using a provider created with `api_key`.
    """
    gla_provider = GoogleGLAProvider(api_key=api_key)
    return GeminiModel('gemini-2.0-flash', provider=gla_provider)

# Define the tools
async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
    """
    Answer each of the given user questions against the CSV at the given URL.

    Every question is handed to the csv_chat function, one at a time and in
    order, and the results are collected into a list.

    Args:
        csv_url (str): The URL of the CSV file.
        user_questions (List[str]): A list of user questions.

    Returns:
        List[Dict[str, Any]]: One dictionary per question, holding the question and its answer.

    Example:
        [
            {"question": "What is the average age of the customers?", "answer": "The average age is 35."},
            {"question": "What is the most common gender?", "answer": "The most common gender is Male."}
        ]
    """
    # NOTE(review): this function is registered in `tools=`; pydantic_ai
    # presumably exposes the docstring above to the model as the tool
    # description — keep it accurate. Confirm against the library docs.
    print("LLM using the csv chat function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)

    # Each question is awaited before the next one starts, so the answers
    # come back in the same order as the questions.
    return [
        {"question": item, "answer": await csv_chat(csv_url, item)}
        for item in user_questions
    ]

async def generate_chart(csv_url: str, user_questions: List[str], chat_id: str) -> Any:
    """
    This function generates charts for the given user questions using the CSV URL.
    It uses the csv_chart function to process each question, in order, and
    returns a list of dictionaries containing the question and the chart image
    URL for each question.

    Args:
        csv_url (str): The URL of the CSV file.
        user_questions (List[str]): A list of user questions.
        chat_id (str): The ID of the current chat, forwarded to csv_chart.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries, one per question, with the
        question under "question" and the csv_chart result under "image_url".

    Example:
        [
            {"question": "What is the average age of the customers?", "image_url": "https://example.com/chart1.png"},
            {"question": "What is the most common gender?", "image_url": "https://example.com/chart2.png"}
        ]
    """
    # NOTE(review): this function is registered in `tools=`; pydantic_ai
    # presumably exposes the docstring above to the model as the tool
    # description, so the documented key must match the one returned
    # below ("image_url"). Confirm against the library docs.
    print("LLM using the csv chart function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)

    # Create an array to accumulate the charts
    charts = []
    # Charts are generated one question at a time so the output order matches
    # the order of the questions.
    for question in user_questions:
        chart = await csv_chart(csv_url, question, chat_id)
        charts.append({"question": question, "image_url": chart})

    return charts

def create_agent(csv_url: str, api_key: str, conversation_history: List, chat_id: str) -> Agent:
    """Build an Agent for one CSV file, one API key and one chat.

    The system prompt embeds the dataset URL, the metadata returned by
    `get_csv_basic_info(csv_url)`, the conversation history and the chat ID.
    `generate_csv_answer` and `generate_chart` are registered as tools.

    Args:
        csv_url (str): The URL of the CSV file.
        api_key (str): API key passed to `initialize_model`.
        conversation_history (List): Prior conversation, interpolated into the prompt as-is.
        chat_id (str): Chat identifier, interpolated into the prompt.

    Returns:
        Agent: The configured agent.
    """
    # Metadata is fetched first because the prompt text below embeds it.
    dataset_info = get_csv_basic_info(csv_url)

    instructions = f"""
# Role: Data Analyst Assistant
**Specialization:** CSV Analysis & Visualization

## Key Rules:
1. **Always provide both:** 
   - Complete textual answer with explanations
   - Visualization when applicable
2. **Output Format:** Markdown compatible (visualizations as `![Image Description](url generated by tool)`)
3. **Tool Handling:**
   - Use `generate_csv_answer` for analysis
   - Use `generate_chart` for visuals
   - Never disclose tool names
4. **Visualization Fallback:**
   - If requested library (plotly, bokeh etc.) isn't available:
     - Provide closest alternative
     - Explain the limitation

## Current Context:
- **Dataset:** {csv_url}
- **Metadata:** {dataset_info}
- **History:** {conversation_history}
- **Chat ID:** {chat_id}

## Required Output:
For every question return:
1. Clear analysis answer
2. Visualization (when possible, in markdown format)
3. Follow-up suggestions

**Critical:** Never return partial responses - always combine both textual answers and visualizations when applicable.
"""

    gemini = initialize_model(api_key)
    toolset = [generate_csv_answer, generate_chart]
    return Agent(
        model=gemini,
        deps_type=str,
        tools=toolset,
        system_prompt=instructions,
    )

def _mask_api_key(api_key: str) -> str:
    """Return a log-safe form of `api_key` exposing at most its last four characters."""
    if len(api_key) <= 8:
        # Too short to reveal any part without giving most of it away.
        return "****"
    return f"****{api_key[-4:]}"


def csv_orchestrator_chat(csv_url: str, user_question: str, conversation_history: List, chat_id: str) -> Optional[str]:
    """Answer a user question about a CSV, trying each configured API key in turn.

    For every key in `GEMINI_API_KEYS` an agent is created with `create_agent`
    and run synchronously on the question. The first successful result is
    returned; on any failure the next key is tried.

    Args:
        csv_url (str): The URL of the CSV file.
        user_question (str): The question passed to the agent.
        conversation_history (List): Prior conversation, forwarded to `create_agent`.
        chat_id (str): Chat identifier, forwarded to `create_agent`.

    Returns:
        Optional[str]: `result.data` from the first successful run, or None
        when there are no usable keys or every key failed.
    """
    print("CSV URL:", csv_url)
    print("User questions:", user_question)

    # Iterate through all API keys
    for api_key in GEMINI_API_KEYS:
        api_key = api_key.strip()
        if not api_key:
            # An unset or sloppy GEMINI_API_KEYS value can contain blank
            # entries; a request with an empty key cannot succeed.
            continue

        # Never print the full key: logs are not a safe place for secrets.
        key_label = _mask_api_key(api_key)
        try:
            print(f"Attempting with API key: {key_label}")
            agent = create_agent(csv_url, api_key, conversation_history, chat_id)
            result = agent.run_sync(user_question)
            print("Orchestrator Result:", result.data)
            return result.data
        except ResourceExhausted:
            # The original `ResourceExhausted or Exception` evaluated to just
            # `ResourceExhausted`; spelled out here so the intent is explicit.
            print(f"Quota exhausted for API key: {key_label}. Switching to the next key.")
        except Exception as e:
            # Deliberate best-effort: any other failure also moves on to the next key.
            print(f"Error with API key {key_label}: {e}")

    # If all keys are exhausted or fail
    print("All API keys have been exhausted or failed.")
    return None