File size: 8,022 Bytes
d7d1d4e 837fd40 d7d1d4e 0c663c0 d7d1d4e 69a0b7f d7d1d4e 837fd40 d7d1d4e 837fd40 d7d1d4e 837fd40 d7d1d4e 837fd40 d7d1d4e 7a58a46 a25d048 d7d1d4e d161b96 d7d1d4e 9d31b19 d161b96 9d31b19 a25d048 d7d1d4e a25d048 d161b96 a25d048 d7d1d4e 01d3a80 d7d1d4e 5005f7f d7d1d4e f8d95b7 d7d1d4e 41e2747 d7d1d4e 1dd7d6b d7d1d4e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 |
import pandas as pd
import json
from typing import List, Literal, Optional
from pydantic import BaseModel
from dotenv import load_dotenv
from pydantic_ai import Agent
from csv_service import clean_data
from python_code_executor_service import PythonExecutor
from cerebras_instance_provider import InstanceProvider
# Load variables from a .env file into the environment before the provider
# is created (presumably API credentials used by InstanceProvider — confirm).
load_dotenv()
# Module-level provider shared by every create_csv_agent() call: it hands out
# model instances and receives error reports for the API key behind a model.
instance_provider = InstanceProvider()
class CodeResponse(BaseModel):
    """Container for code-related responses"""
    # NOTE(review): pydantic uses class docstrings as JSON-schema descriptions,
    # which likely reach the LLM through Agent(output_type=CsvChatResult);
    # edit the docstring and field order deliberately.

    # Language tag for `code`; defaults to "python".
    language: str = "python"
    # The code text produced by the model (presumably executed downstream by
    # PythonExecutor — confirm).
    code: str
class ChartSpecification(BaseModel):
    """Details about requested charts"""
    # NOTE(review): the docstring above likely becomes part of the schema the
    # LLM sees (pydantic JSON-schema description) — change it deliberately.

    # Free-text description of the chart image.
    image_description: str
    # Optional plotting code for the chart; None when the model supplies none.
    code: Optional[str] = None
class AnalysisOperation(BaseModel):
    """Container for a single analysis operation with its code and result"""
    # NOTE(review): the docstring above likely becomes part of the schema the
    # LLM sees (pydantic JSON-schema description) — change it deliberately.

    # The analysis code for this operation.
    code: CodeResponse
    # Presumably the name of the variable in `code` whose value is the
    # operation's result — confirm against PythonExecutor.
    result_var: str
class CsvChatResult(BaseModel):
    """Structured response for CSV-related AI interactions"""
    # NOTE(review): this model is passed as Agent(output_type=...); pydantic
    # turns its docstring, field names and field order into the JSON schema the
    # LLM is asked to fill, so edits here change model-visible behaviour.

    # Which kind of answer this is; limited to the four literals below.
    response_type: Literal["casual", "data_analysis", "visualization", "mixed"]
    # Casual chat response
    # Required even for non-casual answers (no default is given).
    casual_response: str
    # Data analysis components
    # Required; the model must send an empty list when there is no analysis.
    analysis_operations: List[AnalysisOperation]
    # Visualization components
    # Optional; None when no charts are requested.
    charts: Optional[List[ChartSpecification]] = None
def get_csv_info(df: pd.DataFrame) -> dict:
    """Summarise a DataFrame's metadata for prompt construction.

    Args:
        df: The DataFrame to describe.

    Returns:
        A dict with these keys:
            num_rows / num_cols: the DataFrame's dimensions.
            example_rows: the first two rows as a list of ``{column: value}``
                dicts.
            dtypes: ``{column: dtype name}``.
            columns: column labels in order.
            numeric_columns: columns pandas considers numeric (bool included).
            categorical_columns: text-like columns, i.e. ``category`` dtype,
                or string dtype once missing values are ignored.

    Columns are visited through ``DataFrame.items()``, so duplicate labels are
    inspected one Series at a time. ``df[label]`` would return a DataFrame in
    that case and has no ``.dtype``.
    """
    numeric_columns = []
    categorical_columns = []
    for col, series in df.items():
        if pd.api.types.is_numeric_dtype(series):
            numeric_columns.append(col)
        elif isinstance(series.dtype, pd.CategoricalDtype) or pd.api.types.is_string_dtype(
            # pandas >= 2 reports an object column holding NaN/None as
            # non-string, so judge text columns on their non-null values only.
            series.dropna()
        ):
            categorical_columns.append(col)

    return {
        'num_rows': len(df),
        'num_cols': len(df.columns),
        'example_rows': df.head(2).to_dict('records'),
        'dtypes': {col: str(series.dtype) for col, series in df.items()},
        'columns': list(df.columns),
        'numeric_columns': numeric_columns,
        'categorical_columns': categorical_columns,
    }
# def get_csv_system_prompt(df: pd.DataFrame) -> str:
# """Generate system prompt for CSV analysis"""
# csv_info = get_csv_info(df)
# prompt = f"""
# You're a CSV analysis assistant. The pandas DataFrame is loaded as 'df' - use this variable.
# CSV Info:
# - Rows: {csv_info['num_rows']}, Cols: {csv_info['num_cols']}
# - Columns: {csv_info['columns']}
# - Sample: {csv_info['example_rows']}
# - Dtypes: {csv_info['dtypes']}
# Strict Rules:
# 1. Never recreate 'df' - use the existing variable
# 2. For analysis:
# - Complete code without imports
# - Use df directly (e.g., print(df[...].mean()))
# 3. For visualizations:
# - Create the most professional, publication-quality charts possible
# - Maximize descriptive elements and detail while maintaining clarity
# - Figure size: (14, 8) for complex charts, (12, 6) for simpler ones
# - Use comprehensive titles (fontsize=16) and axis labels (fontsize=14)
# - Include informative legends (fontsize=12) when appropriate
# - Add annotations for important data points where valuable
# - Rotate x-labels (45° if needed) with fontsize=12 for readability
# - Use colorblind-friendly palettes (seaborn 'deep', 'muted', or 'colorblind')
# - Add gridlines (alpha=0.3) when they improve readability
# - Include proper margins and padding to prevent label cutoff
# - For distributions, include kernel density estimates when appropriate
# - For time series, use appropriate date formatting and markers
# - Do not use any visualization library other than matplotlib or seaborn
# - Complete code with plt.tight_layout() before plt.show()
# - Example professional chart:
# plt.figure(figsize=(14, 8))
# ax = sns.barplot(x='category', y='value', data=df, palette='muted', ci=None)
# plt.title('Detailed Analysis of Values by Category', fontsize=16, pad=20)
# plt.xlabel('Category', fontsize=14)
# plt.ylabel('Average Value', fontsize=14)
# plt.xticks(rotation=45, ha='right', fontsize=12)
# plt.yticks(fontsize=12)
# ax.grid(True, linestyle='--', alpha=0.3)
# for p in ax.patches:
# ax.annotate(f'{{p.get_height():.1f}}',
# (p.get_x() + p.get_width() / 2., p.get_height()),
# ha='center', va='center',
# xytext=(0, 10),
# textcoords='offset points',
# fontsize=12)
# plt.tight_layout()
# plt.show()
# 4. For Lists, Records, Tables, Dictionaries...etc for any data structure, always return them as JSON with correct indentation.
# IMPORTANT: Code must be syntactically perfect and executable as-is.
# """
# return prompt
def get_csv_system_prompt(df: pd.DataFrame) -> str:
    """Build the system prompt for the CSV analysis agent.

    The prompt embeds the DataFrame's shape, column labels, a two-row sample
    and per-column dtypes, all taken from ``get_csv_info``. It then lists the
    numbered rules the model must follow when writing analysis and chart code.
    The rules are numbered 1-3 with no duplicates; an earlier version numbered
    two items "3." and repeated the JSON instruction.

    Args:
        df: The DataFrame the generated code will run against. The prompt
            tells the model it is already available as ``df``.

    Returns:
        The system prompt text.
    """
    csv_info = get_csv_info(df)
    # The prompt body is kept flush-left so source indentation does not leak
    # into the text sent to the model.
    prompt = f"""
You're a CSV analysis assistant. The pandas DataFrame is loaded as 'df' - use this variable.
CSV Info:
- Shape: {csv_info['num_rows']} rows × {csv_info['num_cols']} cols
- Columns: {csv_info['columns']}
- Sample: {csv_info['example_rows']}
- Dtypes: {csv_info['dtypes']}
Requirements:
1. Use existing 'df' - never recreate it
2. For lists, records, tables, dictionaries, or any other data structure, always return them as properly indented JSON
3. For charts:
   - Use matplotlib/seaborn only
   - Professional quality: proper sizing, labels, titles
   - Figure size: (14, 8) for complex, (12, 6) for simple
   - Clear titles (fontsize=16), labels (fontsize=14)
   - Rotate x-labels if needed (45°, fontsize=12)
   - Add annotations/gridlines where helpful
   - Use colorblind-friendly palettes
   - Always include plt.tight_layout()
Example professional chart:
plt.figure(figsize=(14, 8))
sns.barplot(x='category', y='value', data=df, palette='muted')
plt.title('Value by Category', fontsize=16)
plt.xlabel('Category', fontsize=14)
plt.ylabel('Value', fontsize=14)
plt.xticks(rotation=45)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()
Return complete, executable code.
"""
    return prompt
def create_csv_agent(df: pd.DataFrame, max_retries: int = 1) -> Agent:
    """Create a CSV analysis agent, rotating API instances on failure.

    Each attempt asks the module-level ``instance_provider`` for a model and
    builds a pydantic-ai ``Agent`` whose output type is ``CsvChatResult``.
    When an attempt fails, the error is printed. If a model was obtained and
    an API key can be resolved for it, that key is reported to the provider
    before the next attempt.

    Args:
        df: DataFrame used to build the system prompt.
        max_retries: Total number of attempts to make (not extra retries).

    Returns:
        A configured ``Agent``.

    Raises:
        RuntimeError: If every attempt fails, or if ``max_retries`` < 1. The
            last underlying error, if any, is chained as ``__cause__``.
    """
    csv_system_prompt = get_csv_system_prompt(df)
    last_error: Optional[Exception] = None

    for attempt in range(1, max_retries + 1):
        # Reset on every attempt so the except-branch never sees an unbound
        # name (get_instance() raised) or a stale model from a prior attempt.
        model = None
        try:
            model = instance_provider.get_instance()
            if model is None:
                raise RuntimeError("No available API instances")
            return Agent(
                model=model,
                output_type=CsvChatResult,
                system_prompt=csv_system_prompt,
            )
        except Exception as e:
            last_error = e
            # Always surface the error; previously it was dropped silently
            # whenever no API key could be resolved.
            print(f"Error creating CSV agent (attempt {attempt}/{max_retries}): {e}")
            if model is not None:
                api_key = instance_provider.get_api_key_for_model(model)
                if api_key:
                    instance_provider.report_error(api_key)

    raise RuntimeError(
        f"Failed to create agent after {max_retries} attempts"
    ) from last_error
async def query_csv_agent(csv_url: str, question: str, chat_id: str) -> str:
    """Answer ``question`` about the CSV at ``csv_url`` and return formatted output.

    Steps:
        1. Load the CSV with ``clean_data``.
        2. Build an agent for it and run the question.
        3. Coerce the agent output into ``CsvChatResult`` if it is not one.
        4. Hand the result to ``PythonExecutor.process_response``.

    Args:
        csv_url: Location passed to ``clean_data`` to obtain the DataFrame.
        question: The user's question, passed to ``agent.run``.
        chat_id: Forwarded to ``PythonExecutor.process_response``.

    Returns:
        Whatever ``PythonExecutor.process_response`` returns.

    Raises:
        ValueError: If the agent output is neither a ``CsvChatResult`` nor a
            dict or JSON string that validates as one. The original parsing
            or validation error is chained as ``__cause__``.
    """
    # Get the DataFrame from the CSV URL
    df = clean_data(csv_url)
    # Create agent and get response
    agent = create_csv_agent(df)
    result = await agent.run(question)
    # Process the response through PythonExecutor
    executor = PythonExecutor(df)

    output = result.output
    if isinstance(output, CsvChatResult):
        chat_result = output
    else:
        # Fallback when the agent did not return the structured type.
        # json.JSONDecodeError and pydantic's ValidationError both subclass
        # ValueError; json.loads raises TypeError for non-string input.
        try:
            response_data = output if isinstance(output, dict) else json.loads(output)
            chat_result = CsvChatResult.model_validate(response_data)
        except (TypeError, ValueError) as e:
            raise ValueError(f"Could not parse agent response: {e}") from e
    print("Chat Result Original Object:", chat_result)

    # Process and format the response
    formatted_output = await executor.process_response(chat_result, chat_id)
    return formatted_output
|