File size: 7,816 Bytes
f8d95b7
 
d7d1d4e
 
 
7797795
d7d1d4e
b707dc6
d7d1d4e
 
 
 
b707dc6
 
 
 
d7d1d4e
cdffd76
938b9b4
e35735a
f8d95b7
 
e8e8da2
f8d95b7
b707dc6
d7d1d4e
e8e8da2
d7d1d4e
 
837fd40
938b9b4
 
 
 
b707dc6
d7d1d4e
e8e8da2
d7d1d4e
 
837fd40
938b9b4
 
 
 
b707dc6
d7d1d4e
e8e8da2
d7d1d4e
69a0b7f
837fd40
b707dc6
d7d1d4e
e8e8da2
7797795
d7d1d4e
 
 
837fd40
938b9b4
 
 
 
 
 
 
b707dc6
d7d1d4e
e8e8da2
 
1f6b1ac
d7d1d4e
 
 
e326328
e8e8da2
d7d1d4e
 
 
 
 
 
e8e8da2
d7d1d4e
 
 
 
 
 
 
 
e8e8da2
d7d1d4e
af41fa4
938b9b4
 
 
 
 
d7d1d4e
e8e8da2
d7d1d4e
 
e326328
d7d1d4e
e8e8da2
d7d1d4e
938b9b4
d7d1d4e
 
e8e8da2
d7d1d4e
 
 
e326328
 
d7d1d4e
e8e8da2
7abd4e3
d7d1d4e
 
e8e8da2
d7d1d4e
 
e8e8da2
f8d95b7
 
 
 
 
 
 
 
 
 
 
837fd40
e8e8da2
e326328
938b9b4
e326328
938b9b4
e326328
 
938b9b4
 
 
 
 
 
e8e8da2
3ab74be
938b9b4
 
e8e8da2
938b9b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8e8da2
938b9b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import os
from dotenv import load_dotenv
import uuid
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict, Any, List, Literal, Optional
import pandas as pd
import numpy as np
import json
import io
import contextlib
import traceback
import time
from datetime import datetime, timedelta
import seaborn as sns
import scipy.stats as stats
from pydantic import BaseModel
from tabulate import tabulate
import re

from supabase_service import upload_file_to_supabase

# Load environment variables from a .env file into the process environment.
# This runs at import time, as a side effect of importing this module.
# NOTE(review): this happens after `supabase_service` is imported above; if
# that module reads environment variables at import time they would not be
# loaded yet — confirm.
load_dotenv()

class CodeResponse(BaseModel):
    """A snippet of source code together with the language it is written in."""
    # Language of the snippet; defaults to Python when not supplied.
    language: str = "python"
    # The source text itself.
    code: str

    def clean_code(self) -> str:
        """Return the code with newline characters at its end stripped.

        Only trailing newlines are removed; leading whitespace, indentation
        and line breaks inside the snippet are returned untouched.
        """
        source = self.code
        return source.rstrip('\n')


class ChartSpecification(BaseModel):
    """Description of one requested chart and, optionally, the code that draws it."""
    # Human-readable description of the image.
    image_description: str
    # Plotting code for the chart; may be absent.
    code: Optional[str] = None

    def clean_description(self) -> str:
        """Return the description on a single line.

        Every newline becomes a space, then surrounding whitespace is trimmed.
        """
        single_line = ' '.join(self.image_description.split('\n'))
        return single_line.strip()


class AnalysisOperation(BaseModel):
    """One analysis step: the code to run and the variable holding its result."""
    # Code snippet to execute for this step.
    code: CodeResponse
    # Name of the variable the snippet is expected to assign its result to.
    result_var: str


class CsvChatResult(BaseModel):
    """Structured reply produced for a CSV chat request."""
    # Which kind of reply this is.
    response_type: Literal["casual", "data_analysis", "visualization", "mixed"]
    # Free-text part of the reply.
    casual_response: str
    # Analysis steps to execute, in order.
    analysis_operations: List[AnalysisOperation]
    # Charts to render, if any.
    charts: Optional[List[ChartSpecification]] = None

    def clean_casual_response(self) -> str:
        """Return the free-text reply with stray single newlines flattened.

        A reply containing a blank line (two consecutive newlines) is treated
        as intentionally formatted and returned unchanged. Otherwise every
        newline is replaced by a space.
        """
        text = self.casual_response
        has_paragraph_break = '\n\n' in text
        return text if has_paragraph_break else ' '.join(text.split('\n'))


class PythonExecutor:
    """Run generated Python snippets against a DataFrame and render the results.

    Variables assigned by an executed snippet are kept in ``exec_locals`` and
    remain visible to later snippets run through the same executor.

    NOTE(security): snippets are run with ``exec`` and unrestricted builtins.
    Nothing in this class sandboxes them, so only run code from a trusted
    source, or run the executor inside an isolated process/container.
    """

    def __init__(self, df: pd.DataFrame, charts_folder: str = "generated_charts"):
        """Create an executor bound to ``df``.

        Args:
            df: DataFrame exposed to executed code under the name ``df``.
            charts_folder: Directory used for temporary chart files; it is
                created (including missing parents) if it does not exist.
        """
        self.df = df
        self.charts_folder = Path(charts_folder)
        # parents=True so a nested charts_folder does not fail on first use.
        self.charts_folder.mkdir(parents=True, exist_ok=True)
        self.exec_locals = {}

    def execute_code(self, code: str) -> Dict[str, Any]:
        """Execute ``code`` and collect its stdout, figures and variables.

        While the snippet runs, ``plt.show`` is replaced by a function that
        renders every open figure to PNG bytes instead of displaying it.
        Figures still open when the snippet finishes successfully are captured
        the same way, and all figures are closed afterwards so they cannot
        leak into a later execution.

        Args:
            code: Python source to run.

        Returns:
            A dict with the keys ``output`` (captured stdout, also populated
            when the snippet raised), ``error`` (``None`` or a dict with
            ``message`` and ``traceback``), ``plots`` (list of PNG images as
            bytes) and ``locals`` (the persistent ``exec_locals`` dict).
        """
        error = None
        plots: List[bytes] = []
        stdout = io.StringIO()
        original_show = plt.show

        def capture_open_figures(*args, **kwargs):
            """Stand-in for ``plt.show``; arguments such as ``block`` are ignored."""
            for fig_num in plt.get_fignums():
                buf = io.BytesIO()
                plt.figure(fig_num).savefig(buf, format='png', bbox_inches='tight')
                plots.append(buf.getvalue())
            plt.close('all')

        injected = {
            'pd': pd, 'np': np, 'df': self.df,
            'plt': plt, 'sns': sns, 'tabulate': tabulate,
            'stats': stats, 'datetime': datetime,
            'timedelta': timedelta, 'time': time,
            'json': json, '__builtins__': __builtins__,
        }
        # exec() with separate globals and locals behaves like a class body:
        # functions and lambdas defined by the snippet cannot see the snippet's
        # own top-level variables. Running in one merged namespace avoids that.
        namespace = {**injected, **self.exec_locals}

        plt.show = capture_open_figures
        try:
            with contextlib.redirect_stdout(stdout):
                exec(code, namespace)
            # Pick up figures the snippet created without calling plt.show().
            capture_open_figures()
        except Exception as exc:
            error = {"message": str(exc), "traceback": traceback.format_exc()}
        finally:
            plt.show = original_show
            # Never leave figures behind for the next execution to pick up.
            plt.close('all')
            # Persist what the snippet defined (even if it failed part-way),
            # leaving out the injected helpers it did not rebind. The dict is
            # updated in place so references handed out earlier stay valid.
            self.exec_locals.clear()
            self.exec_locals.update(
                (name, value) for name, value in namespace.items()
                if name not in injected or value is not injected[name]
            )

        return {
            'output': stdout.getvalue(),
            'error': error,
            'plots': plots,
            'locals': self.exec_locals
        }

    async def save_plot_to_supabase(self, plot_data: bytes, description: str, chat_id: str) -> str:
        """Upload PNG bytes through ``upload_file_to_supabase`` and return its result.

        The image is written to a uniquely named temporary file inside
        ``charts_folder``, uploaded, and the temporary file is removed again
        whether or not the upload succeeded.

        Args:
            plot_data: PNG image bytes.
            description: Unused here; kept for interface compatibility.
            chat_id: Passed through to the upload helper.

        Returns:
            Whatever ``upload_file_to_supabase`` returns (used by callers as
            the public URL of the image).

        Raises:
            Exception: If the upload fails; the original error is chained.
        """
        filename = f"chart_{uuid.uuid4().hex}.png"
        filepath = self.charts_folder / filename

        try:
            filepath.write_bytes(plot_data)
            try:
                return await upload_file_to_supabase(
                    file_path=str(filepath),
                    file_name=filename,
                    chat_id=chat_id
                )
            except Exception as e:
                raise Exception(f"Failed to upload plot to Supabase: {e}") from e
        finally:
            # Best-effort cleanup: a failed delete must not mask the upload
            # outcome (successful URL or upload error).
            with contextlib.suppress(OSError):
                filepath.unlink()

    @staticmethod
    def _json_fallback(value: Any) -> Any:
        """Convert values ``json`` cannot encode natively (numpy types, dates, ...)."""
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        return str(value)

    @staticmethod
    def _is_empty(result: Any) -> bool:
        """Return True when ``result`` is a sized container with no elements."""
        try:
            return len(result) == 0
        except TypeError:
            # Unsized values (numbers, 0-d arrays, ...) are never "empty".
            return False

    def _format_result(self, result: Any) -> str:
        """Render a result value as display text.

        DataFrames and Series use ``to_string()``. Dicts and lists are
        pretty-printed as JSON; values JSON cannot encode are converted via
        ``_json_fallback``, and if encoding still fails the plain ``str()``
        of the value is returned. Anything else is converted with ``str()``,
        and its newlines are flattened to spaces unless the text looks like
        code (contains a code fence, ``def `` or ``class ``).
        """
        if isinstance(result, (pd.DataFrame, pd.Series)):
            return result.to_string()
        if isinstance(result, (dict, list)):
            try:
                return json.dumps(result, indent=2, default=self._json_fallback)
            except (TypeError, ValueError):
                # e.g. unsupported key types or circular references.
                return str(result)

        # Clean string representation while preserving essential newlines
        text = str(result)
        looks_like_code = any(marker in text for marker in ('```', 'def ', 'class '))
        if '\n' in text and not looks_like_code:
            return text.replace('\n', ' ')
        return text

    async def process_response(self, response: "CsvChatResult", chat_id: str) -> str:
        """Execute a structured response and assemble the final markdown text.

        Runs every analysis operation in order and appends either its error,
        the formatted value of its result variable, or (when that variable is
        missing or ``None``) whatever it printed. Then runs the code of every
        chart, uploads each produced image and appends a markdown image link.

        Args:
            response: The structured response to process.
            chat_id: Passed through to the chart upload.

        Returns:
            All output parts joined with newlines.
        """
        output_parts = [response.clean_casual_response()]

        # Process analysis operations
        for operation in response.analysis_operations:
            execution_result = self.execute_code(operation.code.clean_code())
            result = self.exec_locals.get(operation.result_var)
            error = execution_result['error']

            if error:
                output_parts.append(f"\n❌ Error in operation '{operation.result_var}':")
                output_parts.append(f"```python\n{error['message']}\n```")
            elif result is None:
                # No result variable was produced: fall back to printed output.
                output_str = execution_result['output'].strip()
                if output_str:
                    output_parts.append(f"\nOutput for '{operation.result_var}':")
                    output_parts.append(f"```\n{output_str}\n```")
            elif self._is_empty(result):
                output_parts.append(f"\n⚠️ Values are missing - Operation '{operation.result_var}' returned no data")
            else:
                output_parts.append(f"\n🔹 Result for '{operation.result_var}':")
                output_parts.append(f"```python\n{self._format_result(result)}\n```")

        # Process charts
        if response.charts:
            output_parts.append("\n📊 Visualizations:")
            for chart in response.charts:
                if not chart.code:
                    continue
                description = chart.clean_description()
                chart_result = self.execute_code(chart.code)
                if chart_result['plots']:
                    for plot_data in chart_result['plots']:
                        try:
                            public_url = await self.save_plot_to_supabase(
                                plot_data=plot_data,
                                description=description,
                                chat_id=chat_id
                            )
                            output_parts.append(f"\n🖼️ {description}")
                            output_parts.append(f"![{description}]({public_url})")
                        except Exception as e:
                            output_parts.append(f"\n⚠️ Error uploading chart: {str(e)}")
                elif chart_result['error']:
                    output_parts.append(f"```python\nError generating chart: {chart_result['error']['message']}\n```")
                else:
                    output_parts.append(f"\n⚠️ No chart generated for '{description}'")

        return "\n".join(output_parts)