# Source: FastApi / cerebras_csv_agent.py
# Author: Soumik555
# Commit 611893f: robust null checking for process_response
import pandas as pd
import json
from typing import List, Literal, Optional
from pydantic import BaseModel
from dotenv import load_dotenv
from pydantic_ai import Agent
from csv_service import clean_data
from python_code_executor_service import PythonExecutor
from cerebras_instance_provider import InstanceProvider
import logging
# Load variables from a local .env file into the process environment before
# any of the objects below are constructed.
load_dotenv()
# Shared provider that create_csv_agent() asks for model instances and to
# which it reports failing API keys.
# NOTE(review): presumably reads its API keys from the environment populated
# above - confirm in cerebras_instance_provider.
instance_provider = InstanceProvider()
# Configure the root logger to emit records at INFO level and above.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class CodeResponse(BaseModel):
    """A snippet of source code together with the language it is written in."""

    # Language of ``code``; defaults to "python" when omitted.
    language: str = "python"
    # The source text itself (required).
    code: str
class ChartSpecification(BaseModel):
    """Description of a requested chart, optionally with its code."""

    # Free-text description of the image (required).
    image_description: str
    # Code associated with the chart; None when no code was supplied.
    code: Optional[str] = None
class AnalysisOperation(BaseModel):
    """One analysis step: the code to run plus a reference to its result."""

    # The code for this operation.
    code: CodeResponse
    # NOTE(review): presumably the name of the variable in which the executed
    # code leaves its result - confirm against PythonExecutor.
    result_var: str
class CsvChatResult(BaseModel):
    """Structured agent output for a CSV chat turn.

    Holds the conversational reply, a single analysis operation and an
    optional single chart specification.
    """

    # Conversational (non-analytical) part of the reply.
    casual_response: str

    # Data analysis component. A list-valued variant
    # (List[AnalysisOperation]) is currently disabled in favour of exactly
    # one operation.
    analysis_operations: AnalysisOperation

    # Visualization component. A list-valued variant
    # (Optional[List[ChartSpecification]]) is currently disabled in favour of
    # at most one chart.
    charts: Optional[ChartSpecification] = None
def get_csv_info(df: pd.DataFrame) -> dict:
    """Summarise a DataFrame for inclusion in a prompt.

    Args:
        df: The DataFrame to describe.

    Returns:
        A dict with the row and column counts, the first two rows as
        records, a column -> dtype-name mapping, the column list, and the
        columns pandas classifies as numeric and as string typed.
    """
    column_names = list(df.columns)
    is_numeric = pd.api.types.is_numeric_dtype
    is_string = pd.api.types.is_string_dtype

    # The first rows are taken before the per-column pass, as the remaining
    # entries only depend on the columns themselves.
    sample_rows = df.head(2).to_dict('records')

    dtype_names = {}
    numeric_names = []
    string_names = []
    for name in column_names:
        column = df[name]
        dtype_names[name] = str(column.dtype)
        if is_numeric(column):
            numeric_names.append(name)
        if is_string(column):
            string_names.append(name)

    return {
        'num_rows': len(df),
        'num_cols': len(column_names),
        'example_rows': sample_rows,
        'dtypes': dtype_names,
        'columns': column_names,
        'numeric_columns': numeric_names,
        'categorical_columns': string_names,
    }
def get_csv_system_prompt(df: pd.DataFrame) -> str:
    """Build the system prompt handed to the CSV analysis agent.

    The prompt embeds the DataFrame's shape, column names, sample rows and
    dtypes (as reported by get_csv_info) followed by fixed instructions on
    returning executable code and on chart styling.

    Args:
        df: The DataFrame the agent will be asked about.

    Returns:
        The fully rendered prompt text.
    """
    meta = get_csv_info(df)
    row_count, col_count = meta['num_rows'], meta['num_cols']
    # The prompt body stays flush-left: any indentation here would become
    # part of the text sent to the model.
    return f"""
You're a CSV analysis assistant. The pandas DataFrame is loaded as 'df' - use this variable.
CSV Info:
- Shape: {row_count} rows × {col_count} cols
- Columns: {meta['columns']}
- Sample: {meta['example_rows']}
- Dtypes: {meta['dtypes']}
STRICT REQUIREMENTS:
1. NEVER calculate or predict values yourself - ALWAYS return executable code that would produce the result
2. Use existing 'df' - never recreate it
3. For any data structures (Lists, Records, Tables, Dictionaries, etc.), always return them as JSON with correct indentation
4. For charts:
- Use matplotlib/seaborn only
- Professional quality: proper sizing, labels, titles
- Figure size: (14, 8) for complex, (12, 6) for simple
- Clear titles (fontsize=16), labels (fontsize=14)
- Rotate x-labels if needed (45°, fontsize=12)
- Add annotations/gridlines where helpful
- Use colorblind-friendly palettes
- Always include plt.tight_layout()
Example professional chart:
plt.figure(figsize=(14, 8))
sns.barplot(x='category', y='value', data=df, palette='muted')
plt.title('Value by Category', fontsize=16)
plt.xlabel('Category', fontsize=14)
plt.ylabel('Value', fontsize=14)
plt.xticks(rotation=45)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()
Example professional response for a dataframe:
num_rows = len(df)
Return complete, executable code.
"""
def create_csv_agent(df: pd.DataFrame, max_retries: int = 1) -> Agent:
    """Create a CSV analysis agent, retrying with a fresh model instance on failure.

    Args:
        df: DataFrame the agent will be asked about; it is used here only to
            build the system prompt.
        max_retries: Maximum number of attempts at obtaining a model instance
            and constructing the agent.

    Returns:
        An Agent whose output type is CsvChatResult and whose system prompt
        describes ``df``.

    Raises:
        RuntimeError: If every attempt failed. The last underlying error, if
            any, is attached as the exception's cause.
    """
    csv_system_prompt = get_csv_system_prompt(df)
    last_error: Optional[Exception] = None

    for attempt in range(1, max_retries + 1):
        # Reset on every attempt so the handler below never sees an unbound
        # name (get_instance() raised) or the model of a previous attempt.
        model = None
        try:
            model = instance_provider.get_instance()
            if model is None:
                raise RuntimeError("No available API instances")
            return Agent(
                model=model,
                output_type=CsvChatResult,
                system_prompt=csv_system_prompt,
            )
        except Exception as e:
            last_error = e
            # Always record the failure, even when no API key can be blamed.
            logger.warning(
                "Agent creation failed (attempt %d/%d): %s",
                attempt, max_retries, e,
            )
            # Only a model that was actually obtained can be mapped to a key.
            api_key = (
                instance_provider.get_api_key_for_model(model)
                if model is not None
                else None
            )
            if api_key:
                instance_provider.report_error(api_key)

    raise RuntimeError(
        f"Failed to create agent after {max_retries} attempts"
    ) from last_error
async def query_csv_agent(csv_url: str, question: str, chat_id: str) -> str:
    """Answer a question about a CSV and return the executor's formatted output.

    The CSV is loaded through clean_data, an agent is built for the resulting
    DataFrame and run on ``question``, and the agent's structured answer is
    handed to PythonExecutor.process_response together with ``chat_id``.

    Args:
        csv_url: Location of the CSV, passed to clean_data.
        question: The user's question, passed to the agent.
        chat_id: Identifier forwarded to the executor's process_response.

    Returns:
        Whatever PythonExecutor.process_response returns for the parsed
        agent answer.

    Raises:
        ValueError: If the agent output is not a CsvChatResult and cannot be
            converted into one; the original error is attached as the cause.
        RuntimeError: Propagated from create_csv_agent when no agent could be
            created.
    """
    # Get the DataFrame from the CSV URL
    df = clean_data(csv_url)

    # Create agent and get response
    agent = create_csv_agent(df)
    result = await agent.run(question)

    # The executor runs the generated code against the same DataFrame
    executor = PythonExecutor(df)

    output = result.output
    if isinstance(output, CsvChatResult):
        chat_result = output
    else:
        # Fall back to building the model from a dict or a JSON string
        try:
            response_data = output if isinstance(output, dict) else json.loads(output)
            chat_result = CsvChatResult(**response_data)
        except Exception as e:
            raise ValueError(f"Could not parse agent response: {str(e)}") from e

    # A %s placeholder is required: passing the object as a bare extra
    # argument makes logging fail to format the record instead of emitting it.
    logger.info("Chat Result Original Object: %s", chat_result)

    # Process and format the response
    formatted_output = await executor.process_response(chat_result, chat_id)
    return formatted_output