File size: 5,564 Bytes
d7d1d4e
 
837fd40
d7d1d4e
 
 
 
 
0c663c0
7f73a04
d7d1d4e
 
 
 
 
7f73a04
 
 
d7d1d4e
 
 
 
 
 
 
 
 
 
 
 
 
69a0b7f
d7d1d4e
 
 
837fd40
 
d7d1d4e
837fd40
 
484c2da
 
837fd40
 
484c2da
 
837fd40
d7d1d4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d161b96
d7d1d4e
9d31b19
d161b96
9d31b19
 
a25d048
d7d1d4e
 
 
 
611893f
 
 
 
 
a25d048
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d161b96
611893f
 
 
 
a25d048
d7d1d4e
01d3a80
d7d1d4e
5005f7f
d7d1d4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f73a04
d7d1d4e
 
 
 
 
 
f8d95b7
d7d1d4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41e2747
7f73a04
d7d1d4e
 
1dd7d6b
d7d1d4e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import pandas as pd
import json
from typing import List, Literal, Optional
from pydantic import BaseModel
from dotenv import load_dotenv
from pydantic_ai import Agent
from csv_service import clean_data
from python_code_executor_service import PythonExecutor
from cerebras_instance_provider import InstanceProvider
import logging

# Load environment variables from a .env file (python-dotenv) at import time.
# NOTE(review): presumably this must run before InstanceProvider() so that the
# provider can read its API keys from the environment — confirm.
load_dotenv()

# Module-wide provider shared by every agent created in this file; it hands out
# model instances and is told about failing API keys (see create_csv_agent).
instance_provider = InstanceProvider()

# Configure root logging at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class CodeResponse(BaseModel):
    """Container for code-related responses.

    Holds one snippet of source code together with the language it is
    written in.
    """
    # Language of the snippet; defaults to "python" when not supplied.
    language: str = "python"
    # The source text of the snippet (required).
    code: str

class ChartSpecification(BaseModel):
    """Details about requested charts.

    Describes a single chart and, optionally, carries code associated with it.
    """
    # Free-text description of the chart image (required).
    image_description: str
    # Optional code for the chart; None when not provided.
    # NOTE(review): presumably the plotting code to execute — confirm against
    # the consumer of this model.
    code: Optional[str] = None

class AnalysisOperation(BaseModel):
    """Container for a single analysis operation with its code and result.

    Pairs the code of one operation with the name under which its result is
    referenced.
    """
    # The code of this operation, wrapped in a CodeResponse.
    code: CodeResponse
    # NOTE(review): presumably the name of the variable that holds the result
    # once `code` has been executed — confirm against the executor.
    result_var: str

class CsvChatResult(BaseModel):
    """Structured response for CSV-related AI interactions.

    This is the output type requested from the agent: a conversational reply,
    exactly one analysis operation, and at most one chart specification.
    """
    
    # Casual chat response
    casual_response: str
    
    # Data analysis components
    # A list-valued variant is kept below for reference; the active field holds
    # a single operation despite its plural name.
    # analysis_operations: List[AnalysisOperation]
    analysis_operations: AnalysisOperation
    
    # Visualization components
    # A list-valued variant is kept below for reference; the active field holds
    # at most one chart and defaults to None when no chart is requested.
    # charts: Optional[List[ChartSpecification]] = None
    charts: Optional[ChartSpecification] = None
    
    
def get_csv_info(df: pd.DataFrame) -> dict:
    """Summarise a DataFrame as a plain dictionary of metadata.

    Args:
        df: The DataFrame to describe.

    Returns:
        A dict with the row and column counts, the first two rows as records,
        a column -> dtype-name mapping, the column names, and the subsets of
        columns that pandas reports as numeric and as string-typed (the
        latter under the 'categorical_columns' key).
    """
    row_count = len(df)
    column_names = list(df.columns)
    sample_records = df.head(2).to_dict('records')

    # Classify every column in a single pass over the column names.
    dtype_names = {}
    numeric_names = []
    string_names = []
    for name in column_names:
        series = df[name]
        dtype_names[name] = str(series.dtype)
        if pd.api.types.is_numeric_dtype(series):
            numeric_names.append(name)
        if pd.api.types.is_string_dtype(series):
            string_names.append(name)

    return {
        'num_rows': row_count,
        'num_cols': len(column_names),
        'example_rows': sample_records,
        'dtypes': dtype_names,
        'columns': column_names,
        'numeric_columns': numeric_names,
        'categorical_columns': string_names,
    }


def get_csv_system_prompt(df: pd.DataFrame) -> str:
    """Build the system prompt used for CSV analysis.

    The prompt embeds the DataFrame's shape, column names, sample rows and
    dtypes (as reported by ``get_csv_info``), followed by the fixed rules and
    examples the model is asked to follow.

    Args:
        df: The DataFrame the model will be asked about.

    Returns:
        The complete prompt text.
    """
    meta = get_csv_info(df)
    n_rows, n_cols = meta['num_rows'], meta['num_cols']
    column_names = meta['columns']
    sample_rows = meta['example_rows']
    dtype_map = meta['dtypes']

    return f"""
You're a CSV analysis assistant. The pandas DataFrame is loaded as 'df' - use this variable.

CSV Info:
- Shape: {n_rows} rows × {n_cols} cols
- Columns: {column_names}
- Sample: {sample_rows}
- Dtypes: {dtype_map}

STRICT REQUIREMENTS:
1. NEVER calculate or predict values yourself - ALWAYS return executable code that would produce the result
2. Use existing 'df' - never recreate it
3. For any data structures (Lists, Records, Tables, Dictionaries, etc.), always return them as JSON with correct indentation
4. For charts:
   - Use matplotlib/seaborn only
   - Professional quality: proper sizing, labels, titles
   - Figure size: (14, 8) for complex, (12, 6) for simple
   - Clear titles (fontsize=16), labels (fontsize=14)
   - Rotate x-labels if needed (45°, fontsize=12)
   - Add annotations/gridlines where helpful
   - Use colorblind-friendly palettes
   - Always include plt.tight_layout()

Example professional chart:
plt.figure(figsize=(14, 8))
sns.barplot(x='category', y='value', data=df, palette='muted')
plt.title('Value by Category', fontsize=16)
plt.xlabel('Category', fontsize=14)
plt.ylabel('Value', fontsize=14)
plt.xticks(rotation=45)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

Example professional response for a dataframe:
num_rows = len(df)


Return complete, executable code.
"""


def create_csv_agent(df: pd.DataFrame, max_retries: int = 1) -> Agent:
    """Create and return a CSV analysis agent with API key rotation.

    Each attempt asks the module-level ``instance_provider`` for a model and
    builds an ``Agent`` whose output type is ``CsvChatResult`` and whose
    system prompt describes ``df``. When an attempt fails, the API key
    belonging to the model that was used (if any) is reported to the provider
    before the next attempt.

    Args:
        df: DataFrame the agent will be asked about; used to build the prompt.
        max_retries: Maximum number of creation attempts.

    Returns:
        The first agent that is created successfully.

    Raises:
        RuntimeError: If no attempt succeeds. The last underlying error, if
            any, is attached as the exception's cause.
    """
    csv_system_prompt = get_csv_system_prompt(df)
    last_error: Optional[Exception] = None

    for attempt in range(max_retries):
        # Reset per attempt so the handler never sees an unbound name (when
        # get_instance() itself raised) or a model left over from an earlier
        # attempt.
        model = None
        try:
            model = instance_provider.get_instance()
            if model is None:
                raise RuntimeError("No available API instances")

            return Agent(
                model=model,
                output_type=CsvChatResult,
                system_prompt=csv_system_prompt,
            )
        except Exception as e:
            last_error = e
            # Always record the failure, even when no key can be blamed for it.
            logger.warning(
                "Error creating CSV agent (attempt %d of %d): %s",
                attempt + 1,
                max_retries,
                e,
            )
            # Only a model that was actually obtained can be mapped to a key.
            if model is not None:
                api_key = instance_provider.get_api_key_for_model(model)
                if api_key:
                    instance_provider.report_error(api_key)

    raise RuntimeError(
        f"Failed to create agent after {max_retries} attempts"
    ) from last_error


async def query_csv_agent(csv_url: str, question: str, chat_id: str) -> str:
    """Query the CSV agent with a DataFrame and question and return formatted output.

    Loads the CSV through ``clean_data``, asks a freshly created agent the
    question, normalises the agent output to a ``CsvChatResult`` and hands it
    to ``PythonExecutor.process_response`` together with ``chat_id``.

    Args:
        csv_url: Location of the CSV, passed to ``clean_data``.
        question: The user's question, passed to the agent as-is.
        chat_id: Identifier forwarded to the executor.

    Returns:
        Whatever ``PythonExecutor.process_response`` returns for the result.

    Raises:
        ValueError: If the agent output is not a ``CsvChatResult`` and cannot
            be converted into one; the original error is attached as cause.
    """
    # Get the DataFrame from the CSV URL
    df = clean_data(csv_url)

    # Create agent and get response
    agent = create_csv_agent(df)
    result = await agent.run(question)

    # Process the response through PythonExecutor
    executor = PythonExecutor(df)

    # Convert the raw output to CsvChatResult if needed
    output = result.output
    if isinstance(output, CsvChatResult):
        chat_result = output
    else:
        # A dict is used directly; anything else is parsed as JSON text.
        try:
            response_data = output if isinstance(output, dict) else json.loads(output)
            chat_result = CsvChatResult(**response_data)
        except Exception as e:
            raise ValueError(f"Could not parse agent response: {str(e)}") from e

    # Logged for both branches. The %s placeholder is required: passing the
    # object as an extra argument without one makes logging fail to format
    # the record, so the message would never be emitted.
    logger.info("Chat Result Original Object: %s", chat_result)

    # Process and format the response
    formatted_output = await executor.process_response(chat_result, chat_id)

    return formatted_output