|
import datetime as dt
import json
import logging
import os
import re
import sys
import threading
import time
import traceback
import uuid
from io import StringIO
from typing import Any, Dict, List, Optional, Set, Union

import numpy as np
import pandas as pd
import seaborn as sns
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
|
|
|
|
|
# Show complete frames when they are rendered as text: no column or row
# truncation and no clipping of wide cell values.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_colwidth', None)

# Load variables from a local .env file into the process environment.
load_dotenv()

# OPENAI_API_KEYS holds a comma-separated list of keys. Blank entries are
# dropped: "".split(",") yields [""], which would otherwise be tried as a key
# when the variable is unset, and stray spaces around commas are stripped.
API_KEYS = [
    key.strip()
    for key in os.getenv("OPENAI_API_KEYS", "").split(",")
    if key.strip()
]
MODEL_NAME = 'gpt-4o'
# Seconds a failed key has to wait before it becomes eligible for a retry.
KEY_RETRY_DELAY = 40

# Force a non-interactive matplotlib backend so charts render without a display.
os.environ['MPLBACKEND'] = 'agg'
import matplotlib.pyplot as plt

# NOTE(review): seaborn is imported before this point and pulls in matplotlib,
# so the environment variable above may be read too late to take effect in this
# process. Selecting the backend explicitly does not depend on import order.
plt.switch_backend('agg')
# Generated code must never block on (or try to open) a window.
plt.show = lambda: None

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
|
|
|
def handle_out_of_range_float(value):
    """Replace non-finite floats with JSON-friendly stand-ins.

    Args:
        value: Any object. Only Python floats and NumPy floating scalars are
            inspected; everything else is returned untouched.

    Returns:
        ``None`` for NaN, ``"Infinity"`` for positive infinity,
        ``"-Infinity"`` for negative infinity, otherwise ``value`` itself.
    """
    # np.floating is included because NumPy scalars such as np.float32 are not
    # subclasses of the builtin float and would otherwise slip through.
    if isinstance(value, (float, np.floating)):
        if np.isnan(value):
            return None
        if np.isinf(value):
            # Keep the sign: negative infinity must not be reported as positive.
            return "Infinity" if value > 0 else "-Infinity"
    return value
|
|
|
class OpenAIKeyManager:
    """Manage multiple OpenAI API keys with validation, failover, and delayed retries.

    Keys are tried in the order given. A key that fails validation is
    remembered together with the time of the failure and becomes eligible for
    another attempt once ``KEY_RETRY_DELAY`` seconds have passed.
    """

    def __init__(self, api_keys: List[str]):
        """Store the candidate keys.

        Args:
            api_keys: Candidate API keys. Blank or whitespace-only entries are
                discarded and surrounding whitespace is stripped, so an unset
                or sloppily formatted environment variable does not produce
                bogus keys. The caller's list is not modified.
        """
        cleaned = [key.strip() for key in api_keys if key and key.strip()]
        self.original_keys = cleaned.copy()
        # Keys that have not been tried yet; consumed from the front.
        self.available_keys = cleaned.copy()
        self.active_key: Optional[str] = None
        # key -> time.time() of its most recent failed validation.
        self.failed_keys: Dict[str, float] = {}
        self.llm_instance = None
        self.lock = threading.Lock()

    def configure(self) -> bool:
        """Validate and activate an OpenAI API key with retry logic.

        First consumes the untried keys in order; if none of them works, keys
        whose last failure is at least ``KEY_RETRY_DELAY`` seconds old are
        attempted again.

        Returns:
            True as soon as one key validates, False if every attempt failed.
        """
        with self.lock:
            while self.available_keys:
                key = self.available_keys.pop(0)
                if self._try_key(key):
                    return True

            now = time.time()
            # Snapshot the eligible keys: _try_key updates failed_keys while
            # this list is being walked.
            retry_keys = [
                key for key, failed_at in self.failed_keys.items()
                if (now - failed_at) >= KEY_RETRY_DELAY
            ]
            for key in retry_keys:
                if self._try_key(key):
                    self.failed_keys.pop(key, None)
                    return True

            logger.critical("All API keys failed (including retries)")
            return False

    def _try_key(self, key: str) -> bool:
        """Attempt to use a specific key, return True if successful.

        Builds a client for ``key`` and sends a one-word probe request. The
        client only replaces ``llm_instance`` once the probe has succeeded, so
        a failed attempt never leaves a client bound to a bad key behind.
        """
        try:
            candidate = ChatOpenAI(
                model=MODEL_NAME,
                api_key=key,
                temperature=0,
                max_retries=0,
            )
            candidate.invoke("test")
        except Exception as e:
            self.failed_keys[key] = time.time()
            logger.error("Key failed: %s - %s", self._mask_key(key), e)
            return False

        self.llm_instance = candidate
        self.active_key = key
        logger.info("Active_Key: %s", self._mask_key(key))
        return True

    def rotate_key(self) -> bool:
        """Rotate to the next available API key (including retries).

        NOTE(review): the key that is active when this is called is not
        recorded in ``failed_keys``, so it is never offered again by
        ``configure`` — confirm that this is intended.
        """
        return self.configure()

    def get_llm_instance(self) -> Optional["ChatOpenAI"]:
        """Get the configured LLM instance (None until a key has validated)."""
        return self.llm_instance

    def _mask_key(self, key: str) -> str:
        """Mask API key for secure logging.

        Shows the first 8 and last 4 characters of the key. Keys of 12
        characters or fewer are fully starred out, because head and tail would
        otherwise reveal the entire value.
        """
        if not key:
            return ""
        if len(key) <= 12:
            return "*" * len(key)
        return f"{key[:8]}...{key[-4:]}"
|
|
|
class PythonREPL:
    """In-process Python execution environment for generated analysis code.

    NOTE(security): despite being used like a sandbox, this is a plain
    ``exec()`` inside the current interpreter, and the namespace handed to the
    code includes ``os``. Nothing here restricts what the executed code can
    do; only run code from a source you trust, or isolate the process.
    """

    def __init__(self, df: pd.DataFrame):
        """Prepare the execution namespace.

        Args:
            df: DataFrame to analyse. The executed code receives a copy under
                the name ``df``, so the caller's frame is not mutated in place.
        """
        self.df = df
        # Names visible to the executed code. The namespace is reused across
        # execute() calls, so definitions made by one snippet stay available.
        self.local_env = {
            "pd": pd,
            "df": self.df.copy(),
            "plt": plt,
            "os": os,
            "uuid": uuid,
            "sns": sns,
            "json": json,
            "dt": dt,
            "np": np,
        }
        # Charts produced by generated code are saved into this directory.
        os.makedirs('generated_charts', exist_ok=True)

    def execute(self, code: str) -> Dict[str, Any]:
        """Run ``code`` and capture whatever it writes to stdout.

        Args:
            code: Python source to execute.

        Returns:
            A dict with:
              * ``output`` – text written to ``sys.stdout`` during the run,
              * ``error`` – True if the code raised an ``Exception``,
              * ``error_message`` – formatted traceback, or None on success,
              * ``df`` – the ``df`` object currently bound in the namespace.
        """
        # Joined line by line so the user code always starts at column 0,
        # independent of how this method itself is indented.
        wrapped = "\n".join([
            "import matplotlib.pyplot as plt",
            "plt.switch_backend('agg')",
            code,
        ])

        captured = StringIO()
        previous_stdout = sys.stdout
        sys.stdout = captured
        error = False
        error_msg = None

        try:
            exec(wrapped, self.local_env)
            # Pick up a DataFrame the code may have rebound.
            self.df = self.local_env.get('df', self.df)
        except Exception:
            error = True
            error_msg = traceback.format_exc()
        finally:
            sys.stdout = previous_stdout
            # Close figures here rather than at the end of the executed code,
            # so they are released even when the code raised part-way through.
            plt.close('all')

        return {
            "output": captured.getvalue(),
            "error": error,
            "error_message": error_msg,
            "df": self.local_env.get('df', self.df),
        }
|
|
|
class RethinkAgent(BaseModel):
    """AI agent for data analysis with automatic error correction.

    The agent asks the LLM for Python code that answers a query about ``df``,
    runs that code in a ``PythonREPL``, and, when execution fails, feeds the
    traceback and the failing code back to the LLM for another attempt.
    """

    # DataFrame under analysis; replaced by the REPL's frame after a
    # successful run.
    df: pd.DataFrame
    # Upper bound on failed executions per execute_query() call.
    max_retries: int = Field(default=5, ge=1)
    # Failed executions counted so far in the current execute_query() call.
    current_retry: int = Field(default=0, ge=0)
    repl: Optional[PythonREPL] = None
    key_manager: Optional[OpenAIKeyManager] = None
    llm: Optional[ChatOpenAI] = None

    class Config:
        # Needed because the fields hold non-pydantic types (DataFrame, etc.).
        arbitrary_types_allowed = True

    def _extract_code(self, response: str) -> str:
        """Extract Python code from markdown response.

        Prefers a fenced block tagged ``python``, then any fenced block, and
        finally falls back to the whole response text.
        """
        for pattern in (r'```python(.*?)```', r'```(.*?)```'):
            fenced = re.search(pattern, response, re.DOTALL)
            if fenced:
                return fenced.group(1).strip()
        return response.strip()

    def _generate_initial_prompt(self, query: str, chart: bool = False) -> str:
        """Generate the initial prompt for the LLM.

        Args:
            query: The user's question.
            chart: True to ask for visualisation code, False for analysis code.

        Returns:
            The prompt text: column names and dtypes, the first five rows,
            the query, and the requirements for the generated code.
        """
        # Iterating dtypes (rather than indexing df[col]) also works when the
        # frame has duplicate column names.
        columns = "\n".join(
            f"{col} ({dtype})" for col, dtype in self.df.dtypes.items()
        )
        preview = self.df.head().to_string()

        # The executed code has its stdout captured and no ``logger`` in its
        # namespace, so results and the CHART_SAVED marker must be print()ed.
        if chart:
            intro = ("Generate Python code to create visualization(s) "
                     "for this DataFrame with columns:")
            requirements = [
                "1. Save visualizations to 'generated_charts/' with UUID filename (use uuid.uuid4())",
                "2. Use plt.savefig() with format='png'",
                "3. No plt.show() calls allowed",
                "4. After saving each chart, print exactly: CHART_SAVED: generated_charts/<uuid>.png",
                "5. Start with 'import pandas as pd', 'import matplotlib.pyplot as plt', etc.",
                "6. The DataFrame is available as 'df'",
                "7. Wrap code in ```python``` blocks",
                "8. If Question is illogical and cannot be answered, explain using print()",
            ]
        else:
            intro = "Generate Python code to analyze this DataFrame with columns:"
            requirements = [
                "1. Use print() to show results with clear explanations",
                "2. If Question is illogical and cannot be answered, explain using print()",
                "3. Start with necessary imports ('import pandas as pd', etc.)",
                "4. The DataFrame is available as 'df'",
                "5. For tabular results, use markdown formatting",
                "6. Wrap code in ```python``` blocks",
            ]

        return "\n".join([
            "",
            intro,
            columns,
            "",
            "First 5 rows:",
            preview,
            "",
            f"Query: {query}",
            "",
            "Requirements:",
            *requirements,
            "",
        ])

    def _generate_retry_prompt(self, query: str, error: str, code: str, chart: bool = False) -> str:
        """Generate a retry prompt when code execution fails.

        Args:
            query: The original user question.
            error: Traceback text of the failed run.
            code: The code that failed.
            chart: True when a visualisation was requested.

        Returns:
            A prompt asking the LLM for a corrected version of ``code``.
        """
        if chart:
            goals = [
                "1. Create the requested visualization(s)",
                "2. Save to 'generated_charts/' with UUID filename",
                "3. print CHART_SAVED messages",
                f"4. Handle the error: {error}",
            ]
        else:
            goals = [
                "1. Complete the analysis requested",
                f"2. Handle the error: {error}",
                "3. Include clear output formatting",
            ]

        return "\n".join([
            "",
            "The previous code failed with this error:",
            f"{error}",
            "",
            "Here was the code that failed:",
            f"{code}",
            "",
            "Please fix the code to:",
            *goals,
            "",
            f"Original query: {query}",
            "",
            "Show the corrected code in ```python``` blocks",
            "",
        ])

    def initialize_model(self, api_keys: List[str]) -> bool:
        """Initialize OpenAI client with key rotation.

        Args:
            api_keys: Candidate API keys handed to ``OpenAIKeyManager``.

        Returns:
            True once a key has been activated.

        Raises:
            RuntimeError: If no key could be activated.
        """
        self.key_manager = OpenAIKeyManager(api_keys)
        if not self.key_manager.configure():
            raise RuntimeError("All API keys failed")
        self.llm = self.key_manager.get_llm_instance()
        return True

    def generate_code(self, query: str, error: Optional[str] = None,
                      previous_code: Optional[str] = None, chart: bool = False) -> str:
        """Generate Python code to answer the query.

        Args:
            query: The user's question.
            error: Traceback of the previous failed run; when given, a retry
                prompt is used instead of the initial prompt.
            previous_code: The code that produced ``error``.
            chart: True to request visualisation code.

        Returns:
            The code extracted from the LLM response.

        Raises:
            RuntimeError: If ``initialize_model`` has not been called.
            Exception: The last LLM error, once key rotation is exhausted.
        """
        if self.llm is None or self.key_manager is None:
            raise RuntimeError("Model not initialized; call initialize_model() first")

        if error:
            prompt = self._generate_retry_prompt(query, error, previous_code, chart)
        else:
            prompt = self._generate_initial_prompt(query, chart)

        # Retry with the next key for as long as the key manager can supply one.
        while True:
            try:
                response = self.llm.invoke(prompt)
            except Exception as e:
                logger.error("API error: %s", e)
                if not self.key_manager.rotate_key():
                    raise
                self.llm = self.key_manager.get_llm_instance()
                continue
            # Kept outside the try so a parsing problem is not mistaken for an
            # API failure and does not trigger a key rotation.
            return self._extract_code(response.content)

    def execute_query(self, query: str, chart: bool = False) -> str:
        """Execute the query with automatic error correction.

        Args:
            query: The user's question.
            chart: True to request visualisation code.

        Returns:
            The stdout captured from the first successful run; otherwise a
            ``"System error: ..."`` message if code generation itself raised,
            or a ``"Failed after N retries..."`` message carrying the last
            traceback.
        """
        self.repl = PythonREPL(self.df)
        # Each query gets a fresh retry budget; without the reset a second call
        # on the same agent would inherit the failures of the previous one.
        self.current_retry = 0
        error = None
        previous_code = None

        while self.current_retry < self.max_retries:
            try:
                code = self.generate_code(query, error, previous_code, chart)
                result = self.repl.execute(code)

                if not result["error"]:
                    self.df = result["df"]
                    return result["output"]

                # Remember what went wrong so the next prompt can ask for a fix.
                self.current_retry += 1
                error = result["error_message"]
                previous_code = code
                logger.warning("Retry %s/%s", self.current_retry, self.max_retries)
            except Exception as e:
                logger.error("Critical error: %s", e)
                return f"System error: {str(e)}"

        return f"Failed after {self.max_retries} retries. Last error: {error}"
|
|
|
def openai_react_chat(csv_url: str, query: str, chart: bool = False) -> Optional[Union[Dict, List, str]]:
    """Main function to execute data analysis queries.

    Loads the CSV, lets a ``RethinkAgent`` generate and run code for ``query``
    and converts the outcome into a JSON-friendly value.

    Args:
        csv_url: Path or URL passed to ``pandas.read_csv``.
        query: The user's question about the data.
        chart: True to request a chart instead of a textual analysis.

    Returns:
        * the chart file path (str) when ``chart`` is set and the output
          contains a ``CHART_SAVED: <path>`` marker (the first one wins);
        * otherwise the processed result – for the string that
          ``execute_query`` returns this is ``{"answer": <text>}``;
        * None if anything failed, including an empty CSV or a
          ``CHART_SAVED:`` marker without a path.
    """
    try:
        df = pd.read_csv(csv_url)
        if df.empty:
            raise ValueError("Empty DataFrame loaded from CSV")

        agent = RethinkAgent(df=df)
        if not agent.initialize_model(API_KEYS):
            logger.error("Failed to initialize model")
            return None

        result = agent.execute_query(query, chart)

        # NaN/Inf are replaced on plain Python values after conversion:
        # DataFrame.apply would hand whole columns to the helper, and pandas
        # may turn a None written into a float column back into NaN.
        if isinstance(result, pd.DataFrame):
            processed = [
                {key: handle_out_of_range_float(value) for key, value in row.items()}
                for row in result.to_dict(orient="records")
            ]
        elif isinstance(result, pd.Series):
            processed = {
                key: handle_out_of_range_float(value)
                for key, value in result.to_dict().items()
            }
        elif isinstance(result, list):
            processed = [handle_out_of_range_float(item) for item in result]
        elif isinstance(result, dict):
            processed = {k: handle_out_of_range_float(v) for k, v in result.items()}
        else:
            processed = {"answer": str(handle_out_of_range_float(result))}

        logger.info("Analysis completed successfully")

        # The marker is looked for anywhere in the output, since the generated
        # code may print other text (or warnings) before it.
        if chart and isinstance(result, str) and "CHART_SAVED:" in result:
            match = re.search(r'CHART_SAVED:\s*(\S+)', result)
            if match:
                chart_path = match.group(1)
                logger.info("Chart Path: %s", chart_path)
                return chart_path
            logger.info("Could not extract chart path from response")
            return None

        return processed
    except Exception as e:
        logger.error("Error in openai_react_chat: %s", e)
        return None
|
|
|
|