import json
from typing import List, Literal, Optional

import pandas as pd
from dotenv import load_dotenv
from pydantic import BaseModel
from pydantic_ai import Agent

from csv_service import clean_data
from python_code_executor_service import PythonExecutor
from together_ai_instance_provider import InstanceProvider

load_dotenv()
instance_provider = InstanceProvider()


class CodeResponse(BaseModel):
    """Container for code-related responses"""
    language: str = "python"
    code: str


class ChartSpecification(BaseModel):
    """Details about requested charts"""
    image_description: str
    code: Optional[str] = None


class AnalysisOperation(BaseModel):
    """Container for a single analysis operation with its code and result"""
    code: CodeResponse
    result_var: str


class CsvChatResult(BaseModel):
    """Structured response for CSV-related AI interactions"""
    response_type: Literal["casual", "data_analysis", "visualization", "mixed"]
    # Casual chat response
    casual_response: str
    # Data analysis components
    analysis_operations: List[AnalysisOperation]
    # Visualization components
    charts: Optional[List[ChartSpecification]] = None


def get_csv_info(df: pd.DataFrame) -> dict:
    """Get metadata/info about the CSV.

    Returns a dict with row/column counts, two sample rows, dtypes, and
    column-name lists split into numeric and string-typed columns.
    """
    info = {
        'num_rows': len(df),
        'num_cols': len(df.columns),
        'example_rows': df.head(2).to_dict('records'),
        'dtypes': {col: str(df[col].dtype) for col in df.columns},
        'columns': list(df.columns),
        'numeric_columns': [col for col in df.columns
                            if pd.api.types.is_numeric_dtype(df[col])],
        'categorical_columns': [col for col in df.columns
                                if pd.api.types.is_string_dtype(df[col])],
    }
    return info


def get_csv_system_prompt(df: pd.DataFrame) -> str:
    """Generate system prompt for CSV analysis.

    Embeds the DataFrame's metadata (shape, columns, sample rows, dtypes)
    into the LLM instructions so the model can write code against 'df'.
    """
    csv_info = get_csv_info(df)
    # NOTE: the doubled braces in the example chart code are f-string escapes
    # so the rendered prompt contains literal {p.get_height():.1f}.
    prompt = f"""
You're a CSV analysis assistant. The pandas DataFrame is loaded as 'df' - use this variable.

CSV Info:
- Rows: {csv_info['num_rows']}, Cols: {csv_info['num_cols']}
- Columns: {csv_info['columns']}
- Sample: {csv_info['example_rows']}
- Dtypes: {csv_info['dtypes']}

Strict Rules:
1. Never recreate 'df' - use the existing variable
2. For analysis:
- Complete code without imports
- Use df directly (e.g., print(df[...].mean()))
3. For visualizations:
- Create the most professional, publication-quality charts possible
- Maximize descriptive elements and detail while maintaining clarity
- Figure size: (14, 8) for complex charts, (12, 6) for simpler ones
- Use comprehensive titles (fontsize=16) and axis labels (fontsize=14)
- Include informative legends (fontsize=12) when appropriate
- Add annotations for important data points where valuable
- Rotate x-labels (45° if needed) with fontsize=12 for readability
- Use colorblind-friendly palettes (seaborn 'deep', 'muted', or 'colorblind')
- Add gridlines (alpha=0.3) when they improve readability
- Include proper margins and padding to prevent label cutoff
- For distributions, include kernel density estimates when appropriate
- For time series, use appropriate date formatting and markers
- Do not use any visualization library other than matplotlib or seaborn
- Complete code with plt.tight_layout() before plt.show()
- Example professional chart:
plt.figure(figsize=(14, 8))
ax = sns.barplot(x='category', y='value', data=df, palette='muted', ci=None)
plt.title('Detailed Analysis of Values by Category', fontsize=16, pad=20)
plt.xlabel('Category', fontsize=14)
plt.ylabel('Average Value', fontsize=14)
plt.xticks(rotation=45, ha='right', fontsize=12)
plt.yticks(fontsize=12)
ax.grid(True, linestyle='--', alpha=0.3)
for p in ax.patches:
    ax.annotate(f'{{p.get_height():.1f}}',
                (p.get_x() + p.get_width() / 2., p.get_height()),
                ha='center', va='center', xytext=(0, 10),
                textcoords='offset points', fontsize=12)
plt.tight_layout()
plt.show()
4. For Lists, Records, Tables, Dictionaries...etc for any data structure, always return them as JSON with correct indentation.

IMPORTANT: Code must be syntactically perfect and executable as-is.
"""
    return prompt


def create_csv_agent(df: pd.DataFrame, max_retries: int = 1) -> Agent:
    """Create and return a CSV analysis agent with API key rotation.

    Tries up to ``max_retries`` provider instances; on failure the bad key
    is reported to the provider and the next instance is tried.

    Raises:
        RuntimeError: if no instance yields a working agent; the last
            underlying exception is chained as the cause.
    """
    csv_system_prompt = get_csv_system_prompt(df)
    last_error: Optional[Exception] = None
    for attempt in range(max_retries):
        # Pre-bind so the except handler never sees an unbound name when
        # get_instance() itself raises (original code had an UnboundLocalError here).
        model = None
        try:
            model = instance_provider.get_instance()
            if model is None:
                raise RuntimeError("No available API instances")
            csv_agent = Agent(
                model=model,
                output_type=CsvChatResult,
                system_prompt=csv_system_prompt,
            )
            return csv_agent
        except Exception as e:
            last_error = e
            api_key = instance_provider.get_api_key_for_model(model)
            if api_key:
                print(f"Error with API key (attempt {attempt + 1}): {str(e)}")
                instance_provider.report_error(api_key)
            # Fall through: the loop retries with the next instance.
    raise RuntimeError(
        f"Failed to create agent after {max_retries} attempts"
    ) from last_error


async def query_csv_agent(csv_url: str, question: str, chat_id: str) -> str:
    """Query the CSV agent with a DataFrame and question and return formatted output.

    Loads and cleans the CSV at ``csv_url``, runs the agent on ``question``,
    coerces the agent output into a ``CsvChatResult`` if needed, and executes
    the generated code via ``PythonExecutor``.

    Raises:
        ValueError: if the agent response cannot be parsed into CsvChatResult.
    """
    # Get the DataFrame from the CSV URL
    df = clean_data(csv_url)

    # Create agent and get response
    agent = create_csv_agent(df)
    result = await agent.run(question)

    # Process the response through PythonExecutor
    executor = PythonExecutor(df)

    # Convert the raw output to CsvChatResult if needed
    if not isinstance(result.output, CsvChatResult):
        # Handle case where output needs conversion (dict or JSON string)
        try:
            response_data = (result.output if isinstance(result.output, dict)
                             else json.loads(result.output))
            chat_result = CsvChatResult(**response_data)
        except Exception as e:
            raise ValueError(f"Could not parse agent response: {str(e)}")
    else:
        chat_result = result.output

    print("Chat Result Original Object:", chat_result)

    # Process and format the response
    formatted_output = await executor.process_response(chat_result, chat_id)
    return formatted_output