|
import os |
|
from supabase import create_client, Client |
|
from dotenv import load_dotenv |
|
import uuid |
|
import matplotlib.pyplot as plt |
|
from pathlib import Path |
|
from typing import Dict, Any, List, Optional |
|
import pandas as pd |
|
import numpy as np |
|
import json |
|
import io |
|
import contextlib |
|
import traceback |
|
import time |
|
from datetime import datetime, timedelta |
|
import seaborn as sns |
|
import scipy.stats as stats |
|
from pydantic import BaseModel |
|
|
|
from supabase_service import upload_file_to_supabase |
|
|
|
|
|
# Load variables from a local .env file into the process environment at import time.
# NOTE(review): this runs after the `supabase_service` import above; if that module
# reads its configuration from the environment at import time, the .env values will
# not be visible to it yet — confirm, and move this call above that import if so.
load_dotenv()
|
|
|
class CodeResponse(BaseModel):
    """Container for a code snippet returned in a CSV chat response.

    ``PythonExecutor.process_response`` passes ``code`` straight to
    ``PythonExecutor.execute_code``; ``language`` is not checked before
    execution.
    """

    # Language tag of `code`; informational only (defaults to "python").
    language: str = "python"

    # Source text to execute.
    code: str
|
|
|
|
|
class ChartSpecification(BaseModel):
    """Details about a requested chart.

    ``PythonExecutor.process_response`` executes ``code`` (when present) and
    uploads every figure it produces; charts without ``code`` are skipped.
    """

    # Human-readable caption; used as the heading above each uploaded image
    # and in error messages for this chart.
    image_description: str

    # Plotting code to execute; None means nothing is rendered for this chart.
    code: Optional[str] = None
|
|
|
|
|
class AnalysisOperation(BaseModel):
    """Container for a single analysis step: the code to run and its description.

    The result is not stored on this model; ``PythonExecutor.process_response``
    computes it by executing ``code.code``.
    """

    # Code to execute for this step (its `.code` string is what gets run).
    code: CodeResponse

    # Heading printed (followed by ":") above this step's output.
    description: str
|
|
|
|
|
class CsvChatResult(BaseModel):
    """Structured response for CSV-related AI interactions.

    Consumed by ``PythonExecutor.process_response``, which emits
    ``casual_response`` first, then executes each analysis operation in order,
    then renders ``charts`` (if any).
    """

    # Response category; not read by PythonExecutor in this file —
    # presumably used by callers for routing (TODO confirm).
    response_type: str

    # Conversational text placed at the top of the formatted output.
    casual_response: str

    # Analysis steps executed in list order.
    analysis_operations: List[AnalysisOperation]

    # Optional charts to render and upload; None or empty means no
    # "Visualizations" section is emitted.
    charts: Optional[List[ChartSpecification]] = None
|
|
|
|
|
class PythonExecutor:
    """Handles execution of Python code with comprehensive data analysis libraries.

    SECURITY NOTE: ``execute_code`` runs arbitrary code via ``exec`` with the
    full set of builtins available and no sandboxing. ``process_response`` feeds
    it code taken from a ``CsvChatResult`` (presumably model-generated — confirm),
    so that code must be treated as untrusted; isolate the process (container,
    restricted user, resource limits) before exposing this to end users.
    """

    def __init__(self, df: pd.DataFrame, charts_folder: str = "generated_charts"):
        """
        Initialize the PythonExecutor with a DataFrame.

        Args:
            df (pd.DataFrame): The DataFrame exposed to executed code as ``df``.
                Executed code receives this very object (not a copy), so
                in-place modifications persist across ``execute_code`` calls.
            charts_folder (str): Folder where chart images are written
                temporarily before upload. Created, including any missing
                parent directories, if it does not exist.
        """
        self.df = df
        self.charts_folder = Path(charts_folder)
        # parents=True so nested locations such as "tmp/charts" also work.
        self.charts_folder.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _capture_open_figures(plots: List[bytes]) -> None:
        """Render every open matplotlib figure to PNG bytes, append them to *plots*, then close all figures."""
        for fig_num in plt.get_fignums():
            buf = io.BytesIO()
            plt.figure(fig_num).savefig(buf, format='png', bbox_inches='tight')
            plots.append(buf.getvalue())
        plt.close('all')

    def execute_code(self, code: str) -> Dict[str, Any]:
        """
        Execute Python code with full data analysis context and return results.

        The code runs with ``pd``, ``np``, ``df``, ``plt``, ``sns``, ``stats``,
        ``datetime``, ``timedelta``, ``time`` and ``json`` pre-bound, and its
        stdout is captured. While it runs, ``plt.show`` is replaced so that open
        figures are rendered to PNG bytes instead of being displayed. Figures
        still open when the code finishes successfully are captured the same
        way, so plots are returned even if the code never calls ``plt.show()``.
        All figures are closed afterwards, so nothing leaks into later runs.

        Args:
            code (str): Python code to execute (unsandboxed — see class docstring).

        Returns:
            dict: ``{'output': str, 'error': dict | None, 'plots': list[bytes]}``.
            ``output`` is the captured stdout, including anything printed before
            an exception. ``error`` is ``None`` on success or
            ``{'message': str, 'traceback': str}`` on failure. ``plots`` holds
            one PNG image per captured figure.
        """
        error = None
        plots: List[bytes] = []
        stdout = io.StringIO()
        original_show = plt.show

        def custom_show(*args, **kwargs):
            """Stand-in for plt.show: capture figures instead of displaying them.

            Accepts and ignores plt.show's arguments (e.g. ``block=False``) so
            such calls do not raise TypeError.
            """
            self._capture_open_figures(plots)

        try:
            exec_globals = {
                # Data handling
                'pd': pd,
                'np': np,
                'df': self.df,
                # Visualization
                'plt': plt,
                'sns': sns,
                # Statistics
                'stats': stats,
                # Time utilities
                'datetime': datetime,
                'timedelta': timedelta,
                'time': time,
                # Serialization / builtins
                'json': json,
                '__builtins__': __builtins__,
            }

            # Patch the module attribute so `plt.show()` inside the executed
            # code resolves to the capturing version.
            plt.show = custom_show

            with contextlib.redirect_stdout(stdout):
                exec(code, exec_globals)

            # Figures created but never "shown" are still generated plots.
            self._capture_open_figures(plots)

        except Exception as e:
            error = {
                "message": str(e),
                "traceback": traceback.format_exc()
            }
        finally:
            plt.show = original_show
            # Drop figures left behind by a failed run so they cannot be
            # picked up by the next execution's plt.show().
            plt.close('all')

        return {
            'output': stdout.getvalue(),
            'error': error,
            'plots': plots
        }

    async def save_plot_to_supabase(self, plot_data: bytes, description: str, chat_id: str) -> str:
        """
        Save a plot to Supabase storage and return the public URL.

        The image is written to a uniquely named file in ``self.charts_folder``
        and handed to ``upload_file_to_supabase``. The local file is deleted
        afterwards, whether or not the upload succeeds.

        Args:
            plot_data (bytes): Image data (PNG when produced by ``execute_code``).
            description (str): Description of the plot. Not used by this method;
                kept for interface compatibility.
            chat_id (str): ID of the chat session, forwarded to the uploader.

        Returns:
            str: The value returned by ``upload_file_to_supabase`` (the public URL).

        Raises:
            Exception: If the upload fails. The original error is chained as
                ``__cause__``.
        """
        filename = f"chart_{uuid.uuid4().hex}.png"
        filepath = self.charts_folder / filename

        filepath.write_bytes(plot_data)

        try:
            return await upload_file_to_supabase(
                file_path=str(filepath),
                file_name=filename,
                chat_id=chat_id
            )
        except Exception as e:
            raise Exception(f"Failed to upload plot to Supabase: {e}") from e
        finally:
            # Always remove the temporary local copy.
            filepath.unlink(missing_ok=True)

    def _looks_like_structured_data(self, output: str) -> bool:
        """
        Heuristically decide whether *output* should be shown in a code block.

        Returns True when the stripped text is wrapped in ``{...}`` or ``[...]``
        (JSON-/list-like), or when it spans several lines and contains ``=``.
        """
        output = output.strip()
        return (
            (output.startswith('{') and output.endswith('}'))
            or (output.startswith('[') and output.endswith(']'))
            or ('\n' in output and '=' in output)
        )

    def _format_operation_result(self, result: Dict[str, Any]) -> str:
        """Format one ``execute_code`` result as markdown: fenced error, fenced structured output, or plain text."""
        if result['error']:
            return "```python\n" + f"Error: {result['error']['message']}" + "\n```"
        output = result['output'].strip()
        if self._looks_like_structured_data(output):
            return "```python\n" + output + "\n```"
        return output

    @staticmethod
    def _chart_markdown(description: str, url: str) -> str:
        """Return a markdown image link for an uploaded chart, escaping brackets in the alt text."""
        alt_text = description.replace("[", "\\[").replace("]", "\\]")
        return f"![{alt_text}]({url})"

    async def _render_chart(self, chart: "ChartSpecification", chat_id: str) -> List[str]:
        """Execute one chart's code, upload every produced plot, and return the markdown lines to append."""
        parts: List[str] = []
        result = self.execute_code(chart.code)

        for plot_data in result['plots']:
            try:
                public_url = await self.save_plot_to_supabase(
                    plot_data=plot_data,
                    description=chart.image_description,
                    chat_id=chat_id
                )
            except Exception as e:
                parts.append(f"\nError uploading chart: {e}")
                continue
            parts.append(f"\n{chart.image_description}")
            parts.append(self._chart_markdown(chart.image_description, public_url))

        # Report execution errors even if some plots were captured before the failure.
        if result['error']:
            parts.append("```python\n" + f"Error generating {chart.image_description}: {result['error']['message']}" + "\n```")

        return parts

    async def process_response(self, response: "CsvChatResult", chat_id: str) -> str:
        """
        Process a CsvChatResult and build the formatted markdown reply.

        The output contains the casual response, then each analysis operation's
        description and result, then a "Visualizations" section with a markdown
        image link per uploaded chart. Structured results and errors are wrapped
        in fenced code blocks.

        NOTE: ``execute_code`` is synchronous, so the event loop is blocked while
        generated code runs. It is not offloaded to a thread because it patches
        the process-global ``plt.show``.

        Args:
            response (CsvChatResult): Response from CSV analysis.
            chat_id (str): ID of the chat session (used for chart uploads).

        Returns:
            str: Formatted output with results and image URLs.
        """
        output_parts = [response.casual_response]

        for operation in response.analysis_operations:
            result = self.execute_code(operation.code.code)
            output_parts.append(f"\n{operation.description}:")
            output_parts.append(self._format_operation_result(result))

        if response.charts:
            output_parts.append("\nVisualizations:")
            for chart in response.charts:
                # Charts without code have nothing to render.
                if chart.code:
                    output_parts.extend(await self._render_chart(chart, chat_id))

        return "\n".join(output_parts)