# Standard library
import contextlib
import io
import json
import os
import re
import time
import traceback
import uuid
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as stats
import seaborn as sns
from dotenv import load_dotenv
from pydantic import BaseModel
from tabulate import tabulate

# Project-local
from supabase_service import upload_file_to_supabase

# Read a .env file (if any) into the process environment at import time.
load_dotenv()


class CodeResponse(BaseModel):
    """Container for code-related responses."""

    # Language of the snippet; defaults to "python".
    language: str = "python"
    # Source text of the snippet.
    code: str

    def clean_code(self) -> str:
        """Return ``code`` with trailing newline characters removed.

        Only newline characters at the very end are stripped; leading
        whitespace and internal line structure are left untouched.
        """
        return self.code.rstrip('\n')


class ChartSpecification(BaseModel):
    """Details about requested charts."""

    # Human-readable description of the chart.
    image_description: str
    # Optional Python source that draws the chart; None when no code was supplied.
    code: Optional[str] = None

    def clean_description(self) -> str:
        """Return the description on a single line.

        Every newline is replaced with a space and surrounding whitespace is
        stripped, which keeps the text readable when embedded in one line.
        """
        return self.image_description.replace('\n', ' ').strip()


class AnalysisOperation(BaseModel):
    """Container for a single analysis operation with its code and result."""

    # Snippet to execute for this operation.
    code: CodeResponse
    # Name of the variable the snippet is expected to assign its result to.
    result_var: str


class CsvChatResult(BaseModel):
    """Structured response for CSV-related AI interactions."""

    # Category of the response.
    response_type: Literal["casual", "data_analysis", "visualization", "mixed"]
    # Free-text part of the answer.
    casual_response: str
    # Analysis snippets to run, in order.
    analysis_operations: List[AnalysisOperation]
    # Charts to render; None when no charts were requested.
    charts: Optional[List[ChartSpecification]] = None

    def clean_casual_response(self) -> str:
        """Clean casual response by replacing newlines with spaces when appropriate.

        Text that contains a blank line (two consecutive newlines) is returned
        unchanged; otherwise every newline is replaced with a single space.
        """
        # A blank line signals deliberate paragraph structure, so keep it as is.
        if '\n\n' in self.casual_response:
            return self.casual_response
        return self.casual_response.replace('\n', ' ')


class PythonExecutor:
    """Handles execution of Python code with comprehensive data analysis libraries.

    Snippets are run with ``exec`` and see ``pd``, ``np``, ``plt``, ``sns``,
    ``stats``, ``tabulate``, ``datetime``, ``timedelta``, ``time``, ``json``
    and the wrapped DataFrame (as ``df``). Names a snippet defines are kept
    in ``exec_locals`` and stay visible to later snippets run on the same
    instance.

    SECURITY: snippets run in this process with unrestricted builtins. Only
    pass code from a trusted source, or run the executor inside a sandbox.
    """

    def __init__(self, df: pd.DataFrame, charts_folder: str = "generated_charts"):
        """Store the DataFrame and make sure the chart staging folder exists.

        Args:
            df: DataFrame exposed to executed snippets as ``df``.
            charts_folder: Directory used to stage chart files before upload.
        """
        self.df = df
        self.charts_folder = Path(charts_folder)
        # parents=True so a nested folder such as "out/charts" can be created.
        self.charts_folder.mkdir(parents=True, exist_ok=True)
        self.exec_locals: Dict[str, Any] = {}

    def execute_code(self, code: str) -> Dict[str, Any]:
        """Execute ``code`` and collect its stdout, figures and any error.

        ``plt.show`` is temporarily replaced so that figures are rendered to
        PNG bytes instead of being displayed. Figures still open when the
        snippet finishes are captured as well, and all figures are closed
        afterwards so they cannot leak into the next snippet.

        Args:
            code: Python source to run.

        Returns:
            A dict with the keys ``'output'`` (captured stdout), ``'error'``
            (``None``, or a dict with ``'message'`` and ``'traceback'``),
            ``'plots'`` (list of PNG images as bytes) and ``'locals'`` (the
            persistent ``exec_locals`` mapping).
        """
        error = None
        plots: List[bytes] = []
        stdout = io.StringIO()
        original_show = plt.show

        def capture_open_figures(*args: Any, **kwargs: Any) -> None:
            # Accept and ignore arguments so calls like plt.show(block=False) work.
            for fig_num in plt.get_fignums():
                figure = plt.figure(fig_num)
                buf = io.BytesIO()
                figure.savefig(buf, format='png', bbox_inches='tight')
                plots.append(buf.getvalue())
            plt.close('all')

        base_namespace = {
            'pd': pd, 'np': np, 'df': self.df,
            'plt': plt, 'sns': sns, 'tabulate': tabulate,
            'stats': stats, 'datetime': datetime,
            'timedelta': timedelta, 'time': time,
            'json': json, '__builtins__': __builtins__,
        }
        # One namespace for globals and locals: with two separate dicts,
        # functions (and, before Python 3.12, comprehensions) defined by the
        # snippet cannot see the snippet's own top-level variables.
        namespace = dict(base_namespace)
        namespace.update(self.exec_locals)  # names from earlier snippets win

        plt.show = capture_open_figures
        try:
            with contextlib.redirect_stdout(stdout):
                # SECURITY: executes arbitrary code; see the class docstring.
                exec(code, namespace)
                # Pick up figures the snippet drew but never passed to plt.show().
                capture_open_figures()
        except Exception as e:
            error = {"message": str(e), "traceback": traceback.format_exc()}
        finally:
            plt.show = original_show
            # Drop any figure left behind by a failed snippet.
            plt.close('all')
            self._sync_exec_locals(namespace, base_namespace)

        return {
            # Read after the run so output printed before an error is kept.
            'output': stdout.getvalue(),
            'error': error,
            'plots': plots,
            'locals': self.exec_locals
        }

    def _sync_exec_locals(self, namespace: Dict[str, Any], base_namespace: Dict[str, Any]) -> None:
        """Copy names the snippet defined, rebound or deleted into ``exec_locals``."""
        # Names the snippet deleted must not survive in the persistent state.
        for name in list(self.exec_locals):
            if name not in namespace:
                del self.exec_locals[name]
        for name, value in namespace.items():
            if name == '__builtins__':
                continue
            untouched_builtin_binding = (
                name in base_namespace
                and value is base_namespace[name]
                and name not in self.exec_locals
            )
            if untouched_builtin_binding:
                # Pre-bound helper (pd, np, df, ...) the snippet did not rebind.
                continue
            self.exec_locals[name] = value

    async def save_plot_to_supabase(self, plot_data: bytes, description: str, chat_id: str) -> str:
        """Stage ``plot_data`` as a PNG file, upload it and return the upload result.

        Args:
            plot_data: PNG image bytes.
            description: Chart description (accepted for interface
                compatibility; not used by the upload).
            chat_id: Forwarded to ``upload_file_to_supabase``.

        Returns:
            The value returned by ``upload_file_to_supabase``.

        Raises:
            Exception: If the upload fails; the original error is chained.
            OSError: If the staging file cannot be written.
        """
        filename = f"chart_{uuid.uuid4().hex}.png"
        filepath = self.charts_folder / filename

        try:
            with open(filepath, 'wb') as f:
                f.write(plot_data)
            try:
                return await upload_file_to_supabase(
                    file_path=str(filepath),
                    file_name=filename,
                    chat_id=chat_id
                )
            except Exception as e:
                raise Exception(f"Failed to upload plot to Supabase: {e}") from e
        finally:
            # Best-effort cleanup: a failed delete must not turn a successful
            # upload into an error, nor mask the real failure.
            with contextlib.suppress(OSError):
                os.remove(filepath)

    @staticmethod
    def _json_default(value: Any) -> Any:
        """Fallback serializer for values ``json`` cannot encode natively."""
        to_list = getattr(value, 'tolist', None)
        if callable(to_list):
            # numpy scalars/arrays and pandas Series become plain Python values.
            return to_list()
        return str(value)

    def _format_result(self, result: Any) -> str:
        """Format result with safe newline handling.

        DataFrames and Series use ``to_string()``. Dicts and lists are dumped
        as indented JSON, falling back to ``str()`` when they cannot be
        encoded. Anything else is converted with ``str()``; newlines are then
        replaced with spaces unless the text looks like code.
        """
        if isinstance(result, (pd.DataFrame, pd.Series)):
            return result.to_string()
        if isinstance(result, (dict, list)):
            try:
                return json.dumps(result, indent=2, default=self._json_default)
            except (TypeError, ValueError):
                # e.g. tuple keys or circular references: fall through to str().
                pass

        str_result = str(result)
        if '\n' in str_result and not any(x in str_result for x in ['```', 'def ', 'class ']):
            return str_result.replace('\n', ' ')
        return str_result

    @staticmethod
    def _is_empty(result: Any) -> bool:
        """Return True when ``result`` is a sized object of length zero."""
        try:
            return len(result) == 0
        except TypeError:
            # No usable length (plain scalars, 0-d numpy arrays): not empty.
            return False

    def _render_operation(self, operation: "AnalysisOperation") -> List[str]:
        """Run one analysis operation and return its output fragments."""
        execution_result = self.execute_code(operation.code.clean_code())
        result = self.exec_locals.get(operation.result_var)

        if execution_result['error']:
            return [
                f"\n❌ Error in operation '{operation.result_var}':",
                f"```python\n{execution_result['error']['message']}\n```",
            ]
        if result is not None:
            if self._is_empty(result):
                return [f"\n⚠️ Values are missing - Operation '{operation.result_var}' returned no data"]
            return [
                f"\n🔹 Result for '{operation.result_var}':",
                f"```python\n{self._format_result(result)}\n```",
            ]
        # No result variable was set: fall back to whatever the snippet printed.
        output_str = execution_result['output'].strip()
        if output_str:
            return [
                f"\nOutput for '{operation.result_var}':",
                f"```\n{output_str}\n```",
            ]
        return []

    async def _render_chart(self, chart: "ChartSpecification", chat_id: str) -> List[str]:
        """Run one chart's code, upload its figures and return output fragments."""
        description = chart.clean_description()
        if not chart.code:
            return [f"\n⚠️ No chart generated for '{description}'"]

        chart_result = self.execute_code(chart.code)
        if chart_result['plots']:
            parts: List[str] = []
            for plot_data in chart_result['plots']:
                try:
                    public_url = await self.save_plot_to_supabase(
                        plot_data=plot_data,
                        description=description,
                        chat_id=chat_id
                    )
                except Exception as e:
                    parts.append(f"\n⚠️ Error uploading chart: {str(e)}")
                else:
                    parts.append(f"\n🖼️ {description}")
                    # Embed the uploaded image; without this the URL is lost.
                    parts.append(f"![{description}]({public_url})")
            return parts
        if chart_result['error']:
            return [f"```python\nError generating chart: {chart_result['error']['message']}\n```"]
        return [f"\n⚠️ No chart generated for '{description}'"]

    async def process_response(self, response: "CsvChatResult", chat_id: str) -> str:
        """Process response with intelligent newline handling.

        Runs every analysis operation, then every chart, and joins the casual
        response and all generated fragments with newlines.

        Args:
            response: Structured response whose operations and charts to run.
            chat_id: Forwarded to the chart upload.

        Returns:
            The assembled text.
        """
        output_parts = [response.clean_casual_response()]

        for operation in response.analysis_operations:
            output_parts.extend(self._render_operation(operation))

        if response.charts:
            output_parts.append("\n📊 Visualizations:")
            for chart in response.charts:
                output_parts.extend(await self._render_chart(chart, chat_id))

        return "\n".join(output_parts)