|
import json |
|
import numpy as np |
|
import pandas as pd |
|
import re |
|
import os |
|
import uuid |
|
import logging |
|
from io import StringIO |
|
import sys |
|
import traceback |
|
from typing import Optional, Dict, Any, List |
|
from pydantic import BaseModel, Field |
|
from google.generativeai import GenerativeModel, configure |
|
from dotenv import load_dotenv |
|
import seaborn as sns |
|
import datetime as dt |
|
|
|
from supabase_service import upload_file_to_supabase |
|
|
|
# Lift pandas' display limits on columns, rows and cell width so that any
# DataFrame rendered as text is not truncated.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_colwidth', None)

# Load settings from a local .env file (if present) into the environment.
load_dotenv()

# Comma-separated Gemini API keys, used in the reverse of the listed order.
# Blank entries are dropped and surrounding whitespace is stripped:
# "".split(",") yields [""], which would otherwise be handed to the key
# manager as if it were a real key when GEMINI_API_KEYS is unset.
API_KEYS = [
    key.strip()
    for key in os.getenv("GEMINI_API_KEYS", "").split(",")
    if key.strip()
][::-1]
MODEL_NAME = 'gemini-2.0-flash'
|
|
|
class FileProps(BaseModel):
    """Descriptor for one generated file returned to the caller."""

    # Base name of the file as it was written by the generated code.
    fileName: str
    # Location the file can be fetched from (this module stores the URL
    # returned by the upload helper here).
    filePath: str
    # Category tag; this module uses "csv" and "image".
    fileType: str
|
|
|
class Files(BaseModel):
    """Generated files grouped by kind."""

    # Tabular outputs.
    csv_files: List[FileProps]
    # Chart / picture outputs.
    image_files: List[FileProps]
|
|
|
class FileBoxProps(BaseModel):
    """Top-level response envelope wrapping the generated file lists."""

    # CSV and image descriptors produced for a single report request.
    files: Files
|
|
|
# Force the non-interactive Agg backend.  seaborn is imported earlier in this
# file and itself imports matplotlib, so setting MPLBACKEND here may be too
# late to influence backend selection; matplotlib.use() switches explicitly.
# The environment variable is still set for any child processes.
os.environ['MPLBACKEND'] = 'agg'
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

# Generated code may call plt.show(), possibly with arguments such as
# block=False; turn it into a no-op that accepts any call signature.
plt.show = lambda *args, **kwargs: None

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
|
|
|
class GeminiKeyManager:
    """Manage multiple Gemini API keys with failover.

    Keys are consumed from ``available_keys`` in order.  ``configure`` keeps
    trying until one key is accepted by the SDK's ``configure`` call or the
    list is exhausted.  Rejected keys and their error text are recorded in
    ``failed_keys``.
    """

    # How many leading / trailing characters ``_mask_key`` leaves visible.
    _VISIBLE_PREFIX = 8
    _VISIBLE_SUFFIX = 4

    def __init__(self, api_keys: List[str]):
        """Store the candidate keys.

        Args:
            api_keys: Candidate API keys in the order they should be tried.
                Blank entries are ignored and surrounding whitespace is
                stripped, so the result of splitting an empty or sloppy
                comma-separated setting does not produce bogus keys.
        """
        cleaned = [key.strip() for key in (api_keys or []) if key and key.strip()]
        self.original_keys = cleaned
        self.available_keys = cleaned.copy()
        self.active_key = None
        self.failed_keys = {}

    def configure(self) -> bool:
        """Activate the next usable key.

        Returns:
            True once a key has been configured (it becomes ``active_key``),
            False when every remaining key raised during configuration.
        """
        while self.available_keys:
            key = self.available_keys.pop(0)
            try:
                # Inside a method body this name resolves to the module-level
                # SDK function, not to this method.
                # NOTE(review): the SDK call may only store the key; an
                # invalid key could surface on the first request - confirm.
                configure(api_key=key)
            except Exception as exc:
                self.failed_keys[key] = str(exc)
                logger.error(
                    "Key failed: %s. Error: %s", self._mask_key(key), exc
                )
                continue
            self.active_key = key
            logger.info("Configured with key: %s", self._mask_key(key))
            return True
        logger.critical("All API keys failed")
        return False

    def _mask_key(self, key: str) -> str:
        """Return a log-safe rendering of *key*.

        Long keys keep their first 8 and last 4 characters.  Keys too short
        for that (where prefix and suffix would expose the whole value) are
        replaced by ``"***"``.  An empty key yields ``""``.
        """
        if not key:
            return ""
        if len(key) <= self._VISIBLE_PREFIX + self._VISIBLE_SUFFIX:
            return "***"
        return f"{key[:self._VISIBLE_PREFIX]}...{key[-self._VISIBLE_SUFFIX:]}"
|
|
|
class PythonREPL:
    """Run generated Python against a DataFrame and track the files it writes.

    NOTE(security): the code is run with ``exec`` inside this process and is
    handed ``os`` among other modules.  Nothing here sandboxes it, so only
    run code from a source you are prepared to trust.
    """

    _IMAGE_SUFFIXES = ('.png', '.jpg', '.jpeg')

    def __init__(self, df: pd.DataFrame):
        """Create a per-instance output directory and execution namespace.

        Args:
            df: Source data.  The namespace receives a copy under ``df``.
        """
        self.df = df
        self.output_dir = os.path.abspath(f'generated_outputs/{uuid.uuid4()}')
        os.makedirs(self.output_dir, exist_ok=True)
        # Names visible to the executed code.  The same dict is reused by
        # every execute() call, so state persists between calls.
        self.local_env = {
            "pd": pd,
            "df": self.df.copy(),
            "plt": plt,
            "os": os,
            "uuid": uuid,
            "sns": sns,
            "json": json,
            "dt": dt,
            "output_dir": self.output_dir,
        }

    def execute(self, code: str) -> Dict[str, Any]:
        """Execute *code* and report its output, errors and generated files.

        Args:
            code: Python source to run in the instance namespace.

        Returns:
            A dict with keys ``output`` (captured stdout), ``error`` (bool),
            ``error_message`` (traceback text or None), ``df`` (the ``df``
            binding after the run), ``output_dir`` and ``files`` (``csv`` and
            ``images`` path lists).  The file lists are only filled when the
            code ran without raising.
        """
        logger.info("Executing code...\n%s", code)

        # Prefix kept at column 0 so multi-line generated code stays valid.
        wrapped = (
            "import matplotlib.pyplot as plt\n"
            "plt.switch_backend('agg')\n"
            f"{code}\n"
        )

        captured = StringIO()
        csv_names: List[str] = []
        image_names: List[str] = []
        error = False
        error_msg = None

        old_stdout = sys.stdout
        sys.stdout = captured
        try:
            exec(wrapped, self.local_env)
            self.df = self.local_env.get('df', self.df)

            # Sorted for a deterministic order; suffixes are compared
            # case-insensitively for both kinds of file.
            for fname in sorted(os.listdir(self.output_dir)):
                lowered = fname.lower()
                if lowered.endswith('.csv'):
                    csv_names.append(fname)
                elif lowered.endswith(self._IMAGE_SUFFIXES):
                    image_names.append(fname)
        except Exception:
            error = True
            error_msg = traceback.format_exc()
        finally:
            sys.stdout = old_stdout
            # Release figures even when the generated code failed part-way.
            plt.close('all')

        return {
            "output": captured.getvalue(),
            "error": error,
            "error_message": error_msg,
            "df": self.local_env.get('df', self.df),
            "output_dir": self.output_dir,
            "files": {
                "csv": [os.path.join(self.output_dir, f) for f in csv_names],
                "images": [os.path.join(self.output_dir, f) for f in image_names],
            },
        }
|
|
|
class RethinkAgent(BaseModel):
    """Ask Gemini for pandas/matplotlib code, run it, and retry on failure."""

    df: pd.DataFrame
    max_retries: int = Field(default=5, ge=1)
    gemini_model: Optional[GenerativeModel] = None
    current_retry: int = Field(default=0, ge=0)
    repl: Optional[PythonREPL] = None
    key_manager: Optional[GeminiKeyManager] = None
    conversation: List[Dict[str, Any]] = Field(default_factory=list)

    class Config:
        arbitrary_types_allowed = True

    def _extract_code(self, response: str) -> str:
        """Pull the code out of a model reply.

        A fence tagged ``python`` wins.  Otherwise the first fenced block of
        any kind (bare or differently tagged) is used, so a reply wrapped in
        plain fences is not executed with the fence markers still attached.
        With no fence at all the stripped reply is returned.
        """
        code_match = re.search(r'```python(.*?)```', response, re.DOTALL)
        if code_match is None:
            code_match = re.search(r'```[^\n`]*\n(.*?)```', response, re.DOTALL)
        return code_match.group(1).strip() if code_match else response.strip()

    def _generate_initial_prompt(self, query: str) -> str:
        """Build the first-attempt prompt from the query, history and data."""
        # Column labels are not guaranteed to be strings (e.g. integers).
        columns = ', '.join(map(str, self.df.columns))
        initial_prompt = f"""Generate DIRECT EXECUTION CODE (no functions, no explanations) following STRICT RULES:

CONVERSATION HISTORY:
{self.conversation}

MANDATORY REQUIREMENTS:
1. Operate directly on existing 'df' variable
2. Save ALL final DataFrames to CSV using: df.to_csv(f'{{output_dir}}/descriptive_name.csv')
3. For visualizations: plt.savefig(f'{{output_dir}}/chart_name.png')
4. Use EXACTLY this structure:
# Data processing
df_processed = df[...] # filtering/grouping
# Save results
df_processed.to_csv(f'{{output_dir}}/result.csv')
# Visualizations (if needed)
plt.figure()
... plotting code ...
plt.savefig(f'{{output_dir}}/chart.png')
plt.close()

FORBIDDEN:
- Function definitions
- Dummy data creation
- Any code blocks besides pandas operations and matplotlib
- Print statements showing dataframes
- Using any visualization library other than matplotlib or seaborn

DATAFRAME COLUMNS: {columns}
DATAFRAME'S FIRST FIVE ROWS: {self.df.head().to_dict('records')}
USER QUERY: {query}

EXAMPLE RESPONSE FOR "Sales by region":
# Data processing
sales_by_region = df.groupby('region')['sales'].sum().reset_index()
# Save results
sales_by_region.to_csv(f'{{output_dir}}/sales_by_region.csv')
"""
        # The value needs a %s placeholder; passing it as a bare extra
        # argument makes the logging module report a formatting error.
        logger.info('Conversation history: %s', self.conversation)
        return initial_prompt

    def _generate_retry_prompt(self, query: str, error: str, code: str) -> str:
        """Build the prompt asking the model to repair *code*.

        NOTE(review): *query* is accepted but not included in the prompt
        text; confirm whether the retry should restate the user request.
        """
        return f"""FIX THIS CODE (failed with: {error}) by STRICTLY FOLLOWING:

1. REMOVE ALL FUNCTION DEFINITIONS
2. ENSURE DIRECT DF OPERATIONS
3. USE EXPLICIT output_dir PATHS
4. ADD NECESSARY IMPORTS IF MISSING
5. VALIDATE COLUMN NAMES EXIST

BAD CODE:
{code}

CORRECTED CODE:"""

    def initialize_model(self, api_keys: List[str]) -> bool:
        """Configure an API key and create the Gemini model.

        Returns:
            True when the model object was created, False when constructing
            it raised.

        Raises:
            RuntimeError: if none of *api_keys* could be configured.
        """
        self.key_manager = GeminiKeyManager(api_keys)
        if not self.key_manager.configure():
            raise RuntimeError("API key initialization failed")
        try:
            self.gemini_model = GenerativeModel(MODEL_NAME)
        except Exception as e:
            logger.error(f"Model init failed: {str(e)}")
            return False
        return True

    def generate_code(self, query: str, error: Optional[str] = None, previous_code: Optional[str] = None) -> str:
        """Request code from the model, failing over to the next API key.

        Args:
            query: The user's request.
            error: Traceback text from the previous attempt, if any.  When
                truthy the retry prompt is used instead of the initial one.
            previous_code: The code that produced *error*.

        Returns:
            The code extracted from the model reply.

        Raises:
            Exception: whatever the model call raised once no further key
                can be configured.
        """
        if error:
            prompt = self._generate_retry_prompt(query, error, previous_code)
        else:
            prompt = self._generate_initial_prompt(query)

        while True:
            try:
                response = self.gemini_model.generate_content(prompt)
                return self._extract_code(response.text)
            except Exception:
                manager = self.key_manager
                if manager is None or not manager.available_keys:
                    raise
                if not manager.configure():
                    raise
                # Build a fresh model after switching keys.
                # NOTE(review): the SDK model object appears to cache the
                # client it first used, so the old instance would presumably
                # keep sending the previous key - confirm against the SDK.
                self.gemini_model = GenerativeModel(MODEL_NAME)

    def execute_query(self, query: str) -> Dict[str, Any]:
        """Generate and run code for *query*, retrying on execution errors.

        Returns:
            On success a dict with ``text``, ``csv_files`` and
            ``image_files``.  On failure a dict with ``error`` and empty
            ``csv_files`` / ``image_files`` lists.
        """
        self.repl = PythonREPL(self.df)
        # Start every query with a full retry budget; the counter is instance
        # state and would otherwise stay exhausted after one failed query.
        self.current_retry = 0
        last_failure = None

        while self.current_retry < self.max_retries:
            try:
                code = self.generate_code(
                    query,
                    last_failure["error_message"] if last_failure else None,
                    last_failure["code"] if last_failure else None,
                )
                execution_result = self.repl.execute(code)
            except Exception as e:
                return {
                    "error": f"Critical failure: {str(e)}",
                    "csv_files": [],
                    "image_files": [],
                }

            if not execution_result["error"]:
                return {
                    "text": execution_result["output"],
                    "csv_files": execution_result["files"]["csv"],
                    "image_files": execution_result["files"]["images"],
                }

            self.current_retry += 1
            logger.warning(
                "Generated code failed (attempt %d of %d)",
                self.current_retry,
                self.max_retries,
            )
            # Fed into the retry prompt on the next iteration.
            last_failure = {
                "error_message": execution_result["error_message"],
                "code": code,
            }

        return {
            "error": f"Failed after {self.max_retries} retries",
            "csv_files": [],
            "image_files": [],
        }
|
|
|
def gemini_llm_chat(csv_url: str, query: str, conversation_history: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Load a CSV, run the agent on *query*, and return its outputs.

    Args:
        csv_url: Path or URL handed to ``pd.read_csv``.
        query: The user's request.
        conversation_history: Prior conversation entries passed to the
            agent.  ``None`` is treated as an empty history.

    Returns:
        On success a dict with ``message``, ``csv_files`` and
        ``image_files``.  On any failure a dict with ``error`` plus
        ``csv_files`` / ``image_files``.  Every return carries both file
        keys so callers can rely on one shape.
    """
    try:
        df = pd.read_csv(csv_url)
        agent = RethinkAgent(df=df, conversation=conversation_history or [])

        if not agent.initialize_model(API_KEYS):
            # Include the file keys here too; without them the caller in this
            # module treats the response as malformed.
            return {
                "error": "API configuration failed",
                "csv_files": [],
                "image_files": [],
            }

        result = agent.execute_query(query)
        if "error" in result:
            return result

        return {
            "message": result["text"],
            "csv_files": result["csv_files"],
            "image_files": result["image_files"],
        }
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Processing failed: %s", e)
        return {
            "error": f"Processing error: {str(e)}",
            "csv_files": [],
            "image_files": [],
        }
|
|
|
|
|
async def _upload_generated_files(paths: List[str], file_type: str, label: str, chat_id: str) -> List[FileProps]:
    """Upload each existing file in *paths* and describe the ones that succeeded.

    Args:
        paths: Local file paths to upload.
        file_type: Value stored in ``FileProps.fileType``.
        label: Human-readable kind used in log messages.
        chat_id: Forwarded to the upload helper.

    Returns:
        One ``FileProps`` per uploaded file, in input order.  Files that are
        missing or fail to upload are skipped (and logged).
    """
    uploaded: List[FileProps] = []
    for path in paths:
        if not os.path.exists(path):
            logger.warning("Skipping missing %s file: %s", label, path)
            continue

        file_name = os.path.basename(path)
        try:
            # A random prefix keeps the stored name unique.
            public_url = await upload_file_to_supabase(
                file_path=path,
                file_name=f"{uuid.uuid4()}_{file_name}",
                chat_id=chat_id,
            )
            uploaded.append(FileProps(
                fileName=file_name,
                filePath=public_url,
                fileType=file_type,
            ))
        except Exception as upload_error:
            logger.error(f"Failed to upload {label} {file_name}: {str(upload_error)}")
            continue

        # Local cleanup is reported separately so a delete problem is not
        # mistaken for an upload failure.
        try:
            os.remove(path)
        except OSError as cleanup_error:
            logger.warning("Could not delete %s: %s", path, cleanup_error)
            continue
        try:
            os.rmdir(os.path.dirname(path))
        except OSError:
            # Directory still holds other files (or is already gone).
            pass
    return uploaded


async def generate_csv_report(csv_url: str, query: str, chat_id: str, conversation_history: List[Dict[str, Any]]) -> FileBoxProps:
    """Run the chat pipeline and upload whatever files it produced.

    Args:
        csv_url: Location of the source CSV.
        query: The user's request.
        chat_id: Forwarded to the upload helper.
        conversation_history: Prior conversation entries.

    Returns:
        A ``FileBoxProps`` listing the uploaded CSV and image files.  On any
        failure the lists are empty; errors are logged, not raised.
    """
    result = None
    try:
        # NOTE(review): this is a synchronous call inside a coroutine, so it
        # runs to completion before control returns to the event loop.
        result = gemini_llm_chat(csv_url, query, conversation_history)
        logger.info(f"Raw result from gemini_llm_chat: {result}")

        if not (isinstance(result, dict) and 'csv_files' in result and 'image_files' in result):
            raise ValueError("Unexpected response format from gemini_llm_chat")

        if "error" in result:
            logger.error("gemini_llm_chat reported an error: %s", result["error"])

        csv_files = await _upload_generated_files(result['csv_files'], "csv", "CSV", chat_id)
        image_files = await _upload_generated_files(result['image_files'], "image", "image", chat_id)

        return FileBoxProps(
            files=Files(
                csv_files=csv_files,
                image_files=image_files,
            )
        )
    except Exception as e:
        logger.error(f"Report generation failed: {str(e)}")

        # Only a dict result can be inspected for leftover files.
        if isinstance(result, dict):
            logger.info(f"Files that were generated but not processed: CSV: {result.get('csv_files', [])}, Images: {result.get('image_files', [])}")
        return FileBoxProps(
            files=Files(
                csv_files=[],
                image_files=[],
            )
        )
|
|
|
|