"""CSV chat agent: builds a pydantic-ai agent over a DataFrame and executes its answers."""

import json
import logging
from typing import List, Literal, Optional

import pandas as pd
from dotenv import load_dotenv
from pydantic import BaseModel
from pydantic_ai import Agent

from csv_service import clean_data
from python_code_executor_service import PythonExecutor
from together_ai_instance_provider import InstanceProvider

logger = logging.getLogger(__name__)

load_dotenv()

# Shared provider used to obtain models and to report failing API keys.
instance_provider = InstanceProvider()


class CodeResponse(BaseModel):
    """Container for code-related responses."""

    language: str = "python"
    code: str


class ChartSpecification(BaseModel):
    """Details about requested charts."""

    image_description: str
    code: Optional[str] = None


class AnalysisOperation(BaseModel):
    """Container for a single analysis operation with its code and result."""

    code: CodeResponse
    description: str


class CsvChatResult(BaseModel):
    """Structured response for CSV-related AI interactions."""

    response_type: Literal["casual", "data_analysis", "visualization", "mixed"]
    # Casual chat response
    casual_response: str
    # Data analysis components
    analysis_operations: List[AnalysisOperation]
    # Visualization components
    charts: Optional[List[ChartSpecification]] = None


def _is_categorical_column(series: pd.Series) -> bool:
    """Return True if *series* holds text or uses the pandas ``category`` dtype."""
    # Whether is_string_dtype accepts CategoricalDtype depends on the pandas
    # version and the category values, so check the dtype explicitly.
    if isinstance(series.dtype, pd.CategoricalDtype):
        return True
    if pd.api.types.is_string_dtype(series):
        return True
    # On recent pandas, is_string_dtype rejects object columns of strings that
    # also contain missing values (NaN/None); still treat those as text.
    return series.dtype == object and pd.api.types.infer_dtype(series, skipna=True) == "string"


def get_csv_info(df: pd.DataFrame) -> dict:
    """Collect metadata about *df* for the agent's system prompt.

    Returns a dict with the row/column counts, the first two rows as records,
    each column's dtype as a string, the column list, and the names of the
    numeric and text/categorical columns.
    """
    columns = list(df.columns)
    # Iterate (name, Series) pairs instead of indexing df[col]: with duplicate
    # column names df[col] returns a DataFrame, on which `.dtype` fails.
    named_series = list(df.items())
    return {
        "num_rows": len(df),
        "num_cols": len(columns),
        "example_rows": df.head(2).to_dict("records"),
        "dtypes": {name: str(s.dtype) for name, s in named_series},
        "columns": columns,
        "numeric_columns": [
            name for name, s in named_series if pd.api.types.is_numeric_dtype(s)
        ],
        "categorical_columns": [
            name for name, s in named_series if _is_categorical_column(s)
        ],
    }


def get_csv_system_prompt(df: pd.DataFrame) -> str:
    """Build the system prompt describing *df* and the code-generation rules."""
    csv_info = get_csv_info(df)
    prompt = f"""
You're a CSV analysis assistant. The pandas DataFrame is loaded as 'df' - use this variable.

CSV Info:
- Rows: {csv_info['num_rows']}, Cols: {csv_info['num_cols']}
- Columns: {csv_info['columns']}
- Sample: {csv_info['example_rows']}
- Dtypes: {csv_info['dtypes']}

Strict Rules:
1. Never recreate 'df' - use the existing variable
2. For analysis:
   - Include necessary imports (except pandas) and include complete code
   - Use df directly (e.g., print(df[...].mean()))
3. For visualizations:
   - Complete code with plt.show()
   - Example: plt.bar(df['x'], df['y']) \n plt.show()
4. For Lists and Dictionaries, return them as JSON
   Example:
   import json
   print(json.dumps(df[df['col'] == 'val'].to_dict('records'), indent=2))
"""
    return prompt


def create_csv_agent(df: pd.DataFrame, max_retries: int = 1) -> Agent:
    """Create a CSV analysis agent, rotating API instances on failure.

    Each attempt takes a model from ``instance_provider``. If building the
    agent fails, the API key behind that attempt's model (if any) is reported
    to the provider before the next attempt.

    Args:
        df: DataFrame the agent's system prompt describes.
        max_retries: Number of attempts to make.

    Raises:
        RuntimeError: If no attempt succeeds; the last underlying error is
            chained as ``__cause__``.
    """
    csv_system_prompt = get_csv_system_prompt(df)
    last_error: Optional[Exception] = None

    for attempt in range(max_retries):
        # Reset every attempt so a failure inside get_instance() neither hits
        # an unbound name nor reports the previous attempt's key.
        model = None
        try:
            model = instance_provider.get_instance()
            if model is None:
                raise RuntimeError("No available API instances")
            return Agent(
                model=model,
                output_type=CsvChatResult,
                system_prompt=csv_system_prompt,
            )
        except Exception as e:
            last_error = e
            logger.warning("Error with API key (attempt %d): %s", attempt + 1, e)
            if model is not None:
                api_key = instance_provider.get_api_key_for_model(model)
                if api_key:
                    instance_provider.report_error(api_key)

    raise RuntimeError(
        f"Failed to create agent after {max_retries} attempts"
    ) from last_error


def _coerce_chat_result(output: object) -> CsvChatResult:
    """Return *output* as a CsvChatResult, parsing a dict or JSON string if needed.

    Raises:
        ValueError: If *output* cannot be turned into a CsvChatResult.
    """
    if isinstance(output, CsvChatResult):
        return output
    try:
        response_data = output if isinstance(output, dict) else json.loads(output)
        return CsvChatResult(**response_data)
    except Exception as e:
        raise ValueError(f"Could not parse agent response: {str(e)}") from e


async def query_csv_agent(csv_url: str, question: str) -> str:
    """Answer *question* about the CSV at *csv_url* and return formatted output.

    Loads the CSV through ``clean_data``, runs a freshly created agent on the
    question, coerces its output to ``CsvChatResult`` and passes that to
    ``PythonExecutor.process_response``, whose result is returned.

    Raises:
        RuntimeError: If the agent cannot be created.
        ValueError: If the agent output cannot be parsed into CsvChatResult.
    """
    # NOTE(review): clean_data is a synchronous call inside a coroutine; if it
    # does network/disk I/O it blocks the event loop. Consider
    # asyncio.to_thread once its thread-safety is confirmed.
    df = clean_data(csv_url)

    agent = create_csv_agent(df)
    result = await agent.run(question)

    chat_result = _coerce_chat_result(result.output)
    logger.debug("Chat Result Original Object: %s", chat_result)

    # SECURITY NOTE(review): chat_result carries model-generated code; the
    # executor presumably runs it against df — confirm it is sandboxed.
    executor = PythonExecutor(df)
    return executor.process_response(chat_result)