import contextlib
import io
import json
import os
import time
import traceback
import uuid
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as stats
import seaborn as sns
from dotenv import load_dotenv
from pydantic import BaseModel
from tabulate import tabulate

from supabase_service import upload_file_to_supabase

load_dotenv()

class CodeResponse(BaseModel):
    """Container for code-related responses."""

    language: str = "python"
    code: str


class ChartSpecification(BaseModel):
    """Details about a requested chart."""

    image_description: str
    code: Optional[str] = None


class AnalysisOperation(BaseModel):
    """Container for a single analysis operation with its code and result variable."""

    code: CodeResponse
    result_var: str


class CsvChatResult(BaseModel):
    """Structured response for CSV-related AI interactions."""

    casual_response: str
    analysis_operations: Optional[AnalysisOperation]
    charts: Optional[ChartSpecification]

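
# Illustrative example of the structured payload this module consumes; the
# field values below are hypothetical and only show the expected shape:
#
#   CsvChatResult(
#       casual_response="Here is the average sales per region.",
#       analysis_operations=AnalysisOperation(
#           code=CodeResponse(code="result = df.groupby('region')['sales'].mean()"),
#           result_var="result",
#       ),
#       charts=ChartSpecification(
#           image_description="Average sales by region",
#           code="df.groupby('region')['sales'].mean().plot(kind='bar'); plt.show()",
#       ),
#   )
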
class PythonExecutor:
    """Handles execution of Python code with comprehensive data analysis libraries."""

    def __init__(self, df: pd.DataFrame, charts_folder: str = "generated_charts"):
        """
        Initialize the PythonExecutor with a DataFrame.

        Args:
            df (pd.DataFrame): The DataFrame to operate on
            charts_folder (str): Folder to save charts in
        """
        self.df = df
        self.charts_folder = Path(charts_folder)
        self.charts_folder.mkdir(exist_ok=True)
        self.exec_locals = {}

    def execute_code(self, code: str) -> Dict[str, Any]:
        """
        Execute Python code with full data analysis context and return results.

        Args:
            code (str): Python code to execute

        Returns:
            dict: Dictionary containing execution results and any generated plots
        """
        output = ""
        error = None
        plots = []

        stdout = io.StringIO()

        # Swap out plt.show so figures are captured as PNG bytes instead of displayed.
        original_show = plt.show

        def custom_show(*args, **kwargs):
            """Custom show function that saves plots instead of displaying them."""
            for fig_num in plt.get_fignums():
                figure = plt.figure(fig_num)

                buf = io.BytesIO()
                figure.savefig(buf, format='png', bbox_inches='tight')
                buf.seek(0)
                plots.append(buf.read())
            plt.close('all')

        try:
            # Namespace made available to the executed code.
            exec_globals = {
                # Core data handling
                'pd': pd,
                'np': np,
                'df': self.df,

                # Plotting and table formatting
                'plt': plt,
                'sns': sns,
                'tabulate': tabulate,

                # Statistics
                'stats': stats,

                # Date and time helpers
                'datetime': datetime,
                'timedelta': timedelta,
                'time': time,

                # Misc
                'json': json,
                '__builtins__': __builtins__,
            }

            plt.show = custom_show

            with contextlib.redirect_stdout(stdout):
                exec(code, exec_globals, self.exec_locals)

            output = stdout.getvalue()

        except Exception as e:
            error = {
                "message": str(e),
                "traceback": traceback.format_exc()
            }
        finally:
            # Always restore the original plt.show, even on failure.
            plt.show = original_show

        return {
            'output': output,
            'error': error,
            'plots': plots,
            'locals': self.exec_locals
        }

    async def save_plot_to_supabase(self, plot_data: bytes, description: str, chat_id: str) -> str:
        """
        Save plot to Supabase storage and return the public URL.

        Args:
            plot_data (bytes): Image data in bytes
            description (str): Description of the plot
            chat_id (str): ID of the chat session

        Returns:
            str: Public URL of the uploaded chart
        """
        # Write the image to a temporary local file before uploading.
        filename = f"chart_{uuid.uuid4().hex}.png"
        filepath = self.charts_folder / filename

        with open(filepath, 'wb') as f:
            f.write(plot_data)

        try:
            public_url = await upload_file_to_supabase(
                file_path=str(filepath),
                file_name=filename,
                chat_id=chat_id
            )

            # Remove the local copy once the upload succeeds.
            os.remove(filepath)

            return public_url
        except Exception as e:
            # Remove the local copy on failure as well, then surface the error.
            if os.path.exists(filepath):
                os.remove(filepath)
            raise Exception(f"Failed to upload plot to Supabase: {e}")

    def _format_result(self, result: Any) -> str:
        """Format the result for display."""
        if isinstance(result, (pd.DataFrame, pd.Series)):
            # Round-trip through JSON so frames and series print as indented records.
            json_str = result.to_json(orient='records', date_format='iso')
            return json.dumps(json.loads(json_str), indent=2)
        elif isinstance(result, (dict, list)):
            return json.dumps(result, indent=2)
        return str(result)

    async def process_response(self, response: CsvChatResult, chat_id: str) -> str:
        """Process the response with proper variable handling and error checking."""
        output_parts = [response.casual_response]

        execution_result = None
        operation = None

        # Run the requested analysis operation, if any.
        if response.analysis_operations is not None:
            try:
                operation = response.analysis_operations
                if operation and operation.code and operation.code.code:
                    execution_result = self.execute_code(operation.code.code)

                    # Look up the variable the generated code was asked to populate.
                    result = self.exec_locals.get(operation.result_var)

                    if execution_result.get('error'):
                        output_parts.append(f"\n❌ Error in operation '{operation.result_var}':")
                        output_parts.append("```python\n" + execution_result['error']['message'] + "\n```")
                    elif result is not None:
                        if hasattr(result, '__len__') and len(result) == 0:
                            output_parts.append(f"\n⚠️ Values are missing - Operation '{operation.result_var}' returned no data")
                        else:
                            output_parts.append(f"\n🔹 Result for '{operation.result_var}':")
                            output_parts.append("```python\n" + self._format_result(result) + "\n```")
                    else:
                        # No result variable was set; fall back to whatever the code printed.
                        output_str = execution_result.get('output', '').strip()
                        if output_str:
                            output_parts.append("```\n" + output_str + "\n```")
                else:
                    output_parts.append("\n⚠️ Invalid analysis operation - missing code or result variable")
            except Exception as e:
                output_parts.append(f"\n❌ Error processing analysis operation: {str(e)}")
                if operation:
                    output_parts.append(f"Operation: {operation.result_var}")

        # Generate and upload any requested chart.
        if response.charts is not None:
            chart = response.charts
            try:
                if chart and (chart.code or chart.image_description):
                    if chart.code:
                        chart_result = self.execute_code(chart.code)
                        if chart_result.get('plots'):
                            for plot_data in chart_result['plots']:
                                try:
                                    public_url = await self.save_plot_to_supabase(
                                        plot_data=plot_data,
                                        description=chart.image_description,
                                        chat_id=chat_id
                                    )
                                    output_parts.append(f"\n🖼️ {chart.image_description}")
                                    # Embed the uploaded chart via its public URL.
                                    output_parts.append(f"![{chart.image_description}]({public_url})")
                                except Exception as e:
                                    output_parts.append(f"\n⚠️ Error uploading chart: {str(e)}")
                        elif chart_result.get('error'):
                            output_parts.append("```python\n" + f"Error generating {chart.image_description}: {chart_result['error']['message']}" + "\n```")
                        else:
                            output_parts.append(f"\n⚠️ No chart generated for '{chart.image_description}'")
                    else:
                        output_parts.append(f"\n⚠️ No code provided for chart: {chart.image_description}")
                else:
                    output_parts.append("\n⚠️ Invalid chart specification")
            except Exception as e:
                output_parts.append(f"\n❌ Error processing chart: {str(e)}")

        return "\n".join(output_parts)
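

# Minimal usage sketch (illustrative only): run a snippet against a tiny,
# hypothetical DataFrame and inspect the executor's result dictionary. The
# column names and the snippet are made up; running this module also assumes
# the local supabase_service dependency and environment variables are in place.
if __name__ == "__main__":
    sample_df = pd.DataFrame({
        "region": ["north", "south", "north", "south"],
        "sales": [100, 80, 120, 90],
    })
    executor = PythonExecutor(sample_df)
    run = executor.execute_code(
        "result = df.groupby('region')['sales'].mean()\nprint(result)"
    )
    print(run["output"])            # stdout captured from the executed snippet
    print(run["error"])             # None on success, otherwise message + traceback
    print(run["locals"]["result"])  # value left behind in exec_locals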