# Standard library
import contextlib
import io
import json
import traceback
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional

# Third-party
import matplotlib.pyplot as plt
import pandas as pd
from pydantic import BaseModel
class CodeResponse(BaseModel):
    """Container for code-related responses"""

    # Language of the snippet held in `code`; defaults to "python".
    language: str = "python"

    # Source text of the snippet (required).
    code: str
class ChartSpecification(BaseModel):
    """Details about requested charts"""

    # Human-readable description of the chart (required).
    image_description: str

    # Optional source code for the chart; None when no code was supplied.
    code: Optional[str] = None
class AnalysisOperation(BaseModel):
    """Container for a single analysis operation with its code and result"""

    # The snippet to run for this operation.
    code: CodeResponse

    # Short label describing what the operation computes.
    description: str
class CsvChatResult(BaseModel):
    """Structured response for CSV-related AI interactions"""

    # Free-form tag for the kind of response (required string).
    response_type: str

    # Conversational text accompanying the analysis.
    casual_response: str

    # Analysis steps, each carrying its own code and description.
    analysis_operations: List[AnalysisOperation]

    # Optional chart requests; None when no charts were requested.
    charts: Optional[List[ChartSpecification]] = None
class PythonExecutor:
    """Handles execution of Python code and dummy image generation for CSV analysis"""

    def __init__(self, df: pd.DataFrame, charts_folder: str = "generated_charts"):
        """
        Initialize the PythonExecutor with a DataFrame

        Args:
            df (pd.DataFrame): The DataFrame to operate on
            charts_folder (str): Folder to save charts in (created if missing)
        """
        self.df = df
        self.charts_folder = Path(charts_folder)
        # parents=True so a nested folder such as "out/charts" can be created
        # in one go instead of raising FileNotFoundError.
        self.charts_folder.mkdir(parents=True, exist_ok=True)

    def execute_code(self, code: str) -> Dict[str, Any]:
        """
        Execute Python code and return the output and any generated plots

        The code runs with `pd`, `plt`, `json` and `df` (this executor's
        DataFrame) available as globals. While it runs, `plt.show` is replaced
        by a function that saves every open figure as PNG bytes; figures that
        are still open after a successful run are captured the same way.

        Args:
            code (str): Python code to execute

        Returns:
            dict: Dictionary with the keys
                'output' (str): everything the code printed to stdout,
                    including text printed before a failure,
                'error' (dict | None): {"message", "traceback"} when the code
                    raised an Exception, otherwise None,
                'plots' (list[bytes]): PNG images of the captured figures.
        """
        error = None
        plots: List[bytes] = []
        stdout = io.StringIO()
        original_show = plt.show

        def custom_show(*args: Any, **kwargs: Any) -> None:
            """Custom show function that saves plots instead of displaying them"""
            # Arguments are accepted and ignored so calls such as
            # plt.show(block=False) do not fail.
            for fig_num in plt.get_fignums():
                figure = plt.figure(fig_num)
                buf = io.BytesIO()
                figure.savefig(buf, format='png', bbox_inches='tight')
                plots.append(buf.getvalue())
            plt.close('all')

        exec_globals = {
            'pd': pd,
            'plt': plt,
            'json': json,
            'df': self.df,
            '__builtins__': __builtins__,
        }

        # NOTE: plt.show is patched on the shared pyplot module, so the patch
        # is visible process-wide until it is restored in `finally`.
        plt.show = custom_show
        try:
            # SECURITY: exec() runs the given code with full builtins and no
            # sandbox. Only pass code from a trusted source, or isolate this
            # process.
            with contextlib.redirect_stdout(stdout):
                exec(code, exec_globals)
            # Capture figures the code created but never passed to plt.show().
            custom_show()
        except Exception as e:
            error = {
                "message": str(e),
                "traceback": traceback.format_exc()
            }
        finally:
            plt.show = original_show
            # Drop anything still open (e.g. after an error) so figures do not
            # leak into the next execution.
            plt.close('all')

        return {
            'output': stdout.getvalue(),
            'error': error,
            'plots': plots
        }

    def save_plot_dummy(self, plot_data: bytes, description: str) -> str:
        """
        Save plot to charts folder and return a dummy URL

        Args:
            plot_data (bytes): Image data in bytes
            description (str): Description of the plot (currently unused)

        Returns:
            str: Dummy URL for the chart
        """
        # Random name so repeated saves never overwrite each other.
        filename = f"chart_{uuid.uuid4().hex}.png"
        filepath = self.charts_folder / filename

        with open(filepath, 'wb') as f:
            f.write(plot_data)

        # Placeholder URL; only the file name matches the file saved above.
        return f"https://example.com/charts/{filename}"

    def process_response(self, response: "CsvChatResult") -> str:
        """
        Process the CsvChatResult response and generate formatted output

        Runs every analysis operation and every chart that carries code, then
        joins the casual response, the per-operation output (or error message)
        and the chart links into a single newline-separated string.

        Args:
            response (CsvChatResult): Response from CSV analysis

        Returns:
            str: Formatted output with results and dummy image URLs
        """
        output_parts = [response.casual_response]

        for operation in response.analysis_operations:
            result = self.execute_code(operation.code.code)
            output_parts.append(f"\n{operation.description}:")

            if result['error']:
                output_parts.append(f"Error: {result['error']['message']}")
            else:
                output_parts.append(result['output'].strip())

        if response.charts:
            output_parts.append("\nVisualizations:")

            for chart in response.charts:
                if not chart.code:
                    # Nothing to execute for this chart.
                    continue

                result = self.execute_code(chart.code)

                for plot_data in result['plots']:
                    dummy_url = self.save_plot_dummy(plot_data, chart.image_description)
                    output_parts.append(f"\n{chart.image_description}")
                    # Emit the chart as a Markdown image so the URL returned by
                    # save_plot_dummy actually appears in the output.
                    output_parts.append(f"![{chart.image_description}]({dummy_url})")

                # Report the error even when some plots were captured before
                # the code failed.
                if result['error']:
                    output_parts.append(f"\nError generating {chart.image_description}: {result['error']['message']}")

        return "\n".join(output_parts)