added openai react

Files changed:
- controller.py  +70 -44
- openai_react_agent_service.py  +362 -0

controller.py
CHANGED

@@ -27,6 +27,7 @@ import matplotlib
 import seaborn as sns
 from gemini_report_generator import generate_csv_report
 from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
+from openai_react_agent_service import openai_react_chat
 from orchestrator_agent import csv_orchestrator_chat
 from supabase_service import upload_file_to_supabase
 from util_service import _prompt_generator, process_answer
@@ -344,9 +345,10 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
         query = request.get("query")
         csv_url = request.get("csv_url")
         decoded_url = unquote(csv_url)
-        detailed_answer = request.get("detailed_answer")
+        detailed_answer = request.get("detailed_answer", False)
         conversation_history = request.get("conversation_history", [])
-        generate_report = request.get("generate_report")
+        generate_report = request.get("generate_report", False)
+        is_pro = request.get("is_pro", False)

         if generate_report is True:
             report_files = await generate_csv_report(csv_url, query)
@@ -368,22 +370,31 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
         if orchestrator_answer is not None:
             return {"answer": jsonable_encoder(orchestrator_answer)}

+        # if the user is pro, then we use the openai_react_agent first
+        if is_pro is True:
+            openai_answer = await asyncio.to_thread(
+                openai_react_chat, decoded_url, query, False
+            )
+            logger.info("openai_answer:", openai_answer)
+            if openai_answer is not None:
+                return {"answer": jsonable_encoder(openai_answer)}
+
         # Process with groq_chat first
-        groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
-        logger.info("groq_answer:", groq_answer)
+        # groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
+        # logger.info("groq_answer:", groq_answer)

-        if process_answer(groq_answer) == "Empty response received.":
-            return {"answer": "Sorry, I couldn't find relevant data..."}
+        # if process_answer(groq_answer) == "Empty response received.":
+        #     return {"answer": "Sorry, I couldn't find relevant data..."}

-        if process_answer(groq_answer):
-            lang_answer = await asyncio.to_thread(
-                langchain_csv_chat, decoded_url, query, False
-            )
-            if process_answer(lang_answer):
-                return {"answer": "error"}
-            return {"answer": jsonable_encoder(lang_answer)}
+        # if process_answer(groq_answer):
+        #     lang_answer = await asyncio.to_thread(
+        #         langchain_csv_chat, decoded_url, query, False
+        #     )
+        #     if process_answer(lang_answer):
+        #         return {"answer": "error"}
+        #     return {"answer": jsonable_encoder(lang_answer)}

-        return {"answer": jsonable_encoder(groq_answer)}
+        # return {"answer": jsonable_encoder(groq_answer)}

     except Exception as e:
         logger.error(f"Error processing request: {str(e)}")
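For reference, a request that exercises the new pro branch of csv_chat might look like the sketch below; the route path, host, and auth header are assumptions for illustration, and only keys the handler actually reads are shown:

    import requests

    # Hypothetical host/route; "is_pro": True routes the query to openai_react_chat first.
    resp = requests.post(
        "http://localhost:8000/csv_chat",
        headers={"Authorization": "Bearer <token>"},
        json={
            "query": "What is the average revenue per region?",
            "csv_url": "https://example.com/sales.csv",
            "detailed_answer": False,
            "conversation_history": [],
            "generate_report": False,
            "is_pro": True,
        },
    )
    print(resp.json())  # {"answer": ...}
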
@@ -851,6 +862,7 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
         detailed_answer = request.get("detailed_answer", False)
         conversation_history = request.get("conversation_history", [])
         generate_report = request.get("generate_report", False)
+        is_pro = request.get("is_pro", False)

         if generate_report is True:
             report_files = await generate_csv_report(csv_url, query)
@@ -881,38 +893,52 @@

         if orchestrator_answer is not None:
             return {"orchestrator_response": jsonable_encoder(orchestrator_answer)}
-
+
+        # If the user has a pro subscription, start with the OpenAI ReAct agent
+        if is_pro is True:
+            openai_react_answer = await asyncio.to_thread(
+                process_executor, openai_react_chat, csv_url, query, True
+            )
+            if openai_react_answer is not None:
+                chart_path = openai_react_answer
+                logger.info("Uploading the chart to supabase...")
+                unique_file_name =f'{str(uuid.uuid4())}.png'
+                image_public_url = await upload_file_to_supabase(f"{chart_path}", unique_file_name)
+                logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
+                os.remove(chart_path)
+                return {"image_url": image_public_url}
+
         # Next, try the groq-based method
-        groq_result = await loop.run_in_executor(
-            process_executor, groq_chart, csv_url, query
-        )
-        logger.info(f"Groq chart result: {groq_result}")
-        if isinstance(groq_result, str) and groq_result != "Chart not generated":
-            unique_file_name =f'{str(uuid.uuid4())}.png'
-            logger.info("Uploading the chart to supabase...")
-            image_public_url = await upload_file_to_supabase(f"{groq_result}", unique_file_name)
-            logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
-            os.remove(groq_result)
-            return {"image_url": image_public_url}
-            # return FileResponse(groq_result, media_type="image/png")
+        # groq_result = await loop.run_in_executor(
+        #     process_executor, groq_chart, csv_url, query
+        # )
+        # logger.info(f"Groq chart result: {groq_result}")
+        # if isinstance(groq_result, str) and groq_result != "Chart not generated":
+        #     unique_file_name =f'{str(uuid.uuid4())}.png'
+        #     logger.info("Uploading the chart to supabase...")
+        #     image_public_url = await upload_file_to_supabase(f"{groq_result}", unique_file_name)
+        #     logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
+        #     os.remove(groq_result)
+        #     return {"image_url": image_public_url}
+        #     # return FileResponse(groq_result, media_type="image/png")

-        # Fallback: try langchain-based again
-        logger.error("Groq chart generation failed, trying langchain....")
-        langchain_paths = await loop.run_in_executor(
-            process_executor, langchain_csv_chart, csv_url, query, True
-        )
-        logger.info("Fallback langchain chart result:", langchain_paths)
-        if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
-            unique_file_name =f'{str(uuid.uuid4())}.png'
-            logger.info("Uploading the chart to supabase...")
-            image_public_url = await upload_file_to_supabase(f"{langchain_paths[0]}", unique_file_name)
-            logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
-            os.remove(langchain_paths[0])
-            return {"image_url": image_public_url}
-            # return FileResponse(langchain_paths[0], media_type="image/png")
-        else:
-            logger.error("All chart generation methods failed")
-            return {"answer": "error"}
+        # # Fallback: try langchain-based again
+        # logger.error("Groq chart generation failed, trying langchain....")
+        # langchain_paths = await loop.run_in_executor(
+        #     process_executor, langchain_csv_chart, csv_url, query, True
+        # )
+        # logger.info("Fallback langchain chart result:", langchain_paths)
+        # if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
+        #     unique_file_name =f'{str(uuid.uuid4())}.png'
+        #     logger.info("Uploading the chart to supabase...")
+        #     image_public_url = await upload_file_to_supabase(f"{langchain_paths[0]}", unique_file_name)
+        #     logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
+        #     os.remove(langchain_paths[0])
+        #     return {"image_url": image_public_url}
+        #     # return FileResponse(langchain_paths[0], media_type="image/png")
+        # else:
+        #     logger.error("All chart generation methods failed")
+        #     return {"answer": "error"}

     except Exception as e:
         logger.error(f"Critical chart error: {str(e)}")
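The chart endpoint gates on the same flag; a parallel sketch for csv_chart (path and payload values again illustrative):

    import requests

    resp = requests.post(
        "http://localhost:8000/csv_chart",  # hypothetical route
        headers={"Authorization": "Bearer <token>"},
        json={
            "query": "Plot monthly revenue as a bar chart",
            "csv_url": "https://example.com/sales.csv",
            "is_pro": True,
        },
    )
    print(resp.json())  # {"image_url": "<public chart URL>"} on success
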
openai_react_agent_service.py
ADDED
@@ -0,0 +1,362 @@
+import json
+import numpy as np
+import pandas as pd
+import re
+import os
+import uuid
+import logging
+import time
+import threading
+from io import StringIO
+import sys
+import traceback
+from typing import Optional, Dict, Any, List, Set
+from pydantic import BaseModel, Field
+from dotenv import load_dotenv
+import seaborn as sns
+import datetime as dt
+from langchain_openai import ChatOpenAI
+
+# Configure pandas display options
+pd.set_option('display.max_columns', None)
+pd.set_option('display.max_rows', None)
+pd.set_option('display.max_colwidth', None)
+
+# Load environment variables
+load_dotenv()
+
+# Configuration constants
+API_KEYS = os.getenv("OPENAI_API_KEYS", "").split(",")
+MODEL_NAME = 'gpt-4o'
+KEY_RETRY_DELAY = 40  # seconds
+
+# Configure non-interactive matplotlib backend
+os.environ['MPLBACKEND'] = 'agg'
+import matplotlib.pyplot as plt
+plt.show = lambda: None  # Disable display
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s',
+)
+logger = logging.getLogger(__name__)
+
+def handle_out_of_range_float(value):
+    """Handle NaN and Inf values in numeric data"""
+    if isinstance(value, float):
+        if np.isnan(value):
+            return None
+        elif np.isinf(value):
+            return "Infinity"
+    return value
+
+
class OpenAIKeyManager:
|
55 |
+
"""Manage multiple OpenAI API keys with validation, failover, and delayed retries"""
|
56 |
+
|
57 |
+
def __init__(self, api_keys: List[str]):
|
58 |
+
self.original_keys = api_keys.copy()
|
59 |
+
self.available_keys = api_keys.copy()
|
60 |
+
self.active_key = None
|
61 |
+
self.failed_keys: Dict[str, float] = {} # key: timestamp when failed
|
62 |
+
self.llm_instance = None
|
63 |
+
self.lock = threading.Lock()
|
64 |
+
|
65 |
+
def configure(self) -> bool:
|
66 |
+
"""Validate and activate an OpenAI API key with retry logic"""
|
67 |
+
with self.lock:
|
68 |
+
# First try available keys
|
69 |
+
while self.available_keys:
|
70 |
+
key = self.available_keys.pop(0)
|
71 |
+
if self._try_key(key):
|
72 |
+
return True
|
73 |
+
|
74 |
+
# Then check if any failed keys are ready for retry
|
75 |
+
now = time.time()
|
76 |
+
retry_keys = [
|
77 |
+
k for k, ts in self.failed_keys.items()
|
78 |
+
if (now - ts) >= KEY_RETRY_DELAY
|
79 |
+
]
|
80 |
+
|
81 |
+
for key in retry_keys:
|
82 |
+
if self._try_key(key):
|
83 |
+
del self.failed_keys[key]
|
84 |
+
return True
|
85 |
+
|
86 |
+
logger.critical("All API keys failed (including retries)")
|
87 |
+
return False
|
88 |
+
|
89 |
+
def _try_key(self, key: str) -> bool:
|
90 |
+
"""Attempt to use a specific key, return True if successful"""
|
91 |
+
try:
|
92 |
+
self.llm_instance = ChatOpenAI(
|
93 |
+
model=MODEL_NAME,
|
94 |
+
api_key=key,
|
95 |
+
temperature=0,
|
96 |
+
max_retries=0
|
97 |
+
)
|
98 |
+
self.llm_instance.invoke("test") # Simple test call
|
99 |
+
self.active_key = key
|
100 |
+
logger.info(f"Active_Key: {self._mask_key(key)}")
|
101 |
+
return True
|
102 |
+
except Exception as e:
|
103 |
+
self.failed_keys[key] = time.time()
|
104 |
+
logger.error(f"Key failed: {self._mask_key(key)} - {str(e)}")
|
105 |
+
return False
|
106 |
+
|
107 |
+
def rotate_key(self) -> bool:
|
108 |
+
"""Rotate to the next available API key (including retries)"""
|
109 |
+
return self.configure()
|
110 |
+
|
111 |
+
def get_llm_instance(self) -> ChatOpenAI:
|
112 |
+
"""Get the configured LLM instance"""
|
113 |
+
return self.llm_instance
|
114 |
+
|
115 |
+
def _mask_key(self, key: str) -> str:
|
116 |
+
"""Mask API key for secure logging"""
|
117 |
+
return f"{key[:8]}...{key[-4:]}" if key else ""
|
118 |
+
|
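As an aside, the failover flow above can be driven directly; a minimal sketch with placeholder keys (real keys come from the OPENAI_API_KEYS environment variable):

    from openai_react_agent_service import OpenAIKeyManager

    manager = OpenAIKeyManager(["sk-key-one", "sk-key-two"])  # placeholder keys
    if manager.configure():  # activates the first key that passes the test call
        llm = manager.get_llm_instance()
        try:
            print(llm.invoke("ping").content)
        except Exception:
            # On a runtime failure, rotate to the next usable key
            # (or a failed one whose KEY_RETRY_DELAY has elapsed).
            if manager.rotate_key():
                llm = manager.get_llm_instance()
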
+class PythonREPL:
+    """Secure Python REPL environment for code execution"""
+
+    def __init__(self, df: pd.DataFrame):
+        self.df = df
+        self.local_env = {
+            "pd": pd,
+            "df": self.df.copy(),
+            "plt": plt,
+            "os": os,
+            "uuid": uuid,
+            "sns": sns,
+            "json": json,
+            "dt": dt,
+            "np": np,
+        }
+        os.makedirs('generated_charts', exist_ok=True)
+
+    def execute(self, code: str) -> Dict[str, Any]:
+        """Execute Python code in a secure environment"""
+        old_stdout = sys.stdout
+        sys.stdout = mystdout = StringIO()
+        error_msg = None
+
+        try:
+            # Ensure proper matplotlib configuration
+            code = f"""
+import matplotlib.pyplot as plt
+plt.switch_backend('agg')
+{code}
+plt.close('all')
+"""
+            exec(code, self.local_env)
+            self.df = self.local_env.get('df', self.df)
+            error = False
+        except Exception as e:
+            error_msg = traceback.format_exc()
+            error = True
+        finally:
+            sys.stdout = old_stdout
+
+        return {
+            "output": mystdout.getvalue(),
+            "error": error,
+            "error_message": error_msg if error else None,
+            "df": self.local_env.get('df', self.df)
+        }
+
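A quick sketch of how the REPL captures output; the DataFrame and snippet are made up for illustration:

    import pandas as pd
    from openai_react_agent_service import PythonREPL

    repl = PythonREPL(pd.DataFrame({"x": [1, 2, 3]}))
    result = repl.execute("print(df['x'].sum())")
    print(result["error"])   # False
    print(result["output"])  # "6" -- whatever the executed code printed
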
+class RethinkAgent(BaseModel):
+    """AI agent for data analysis with automatic error correction"""
+
+    df: pd.DataFrame
+    max_retries: int = Field(default=5, ge=1)
+    current_retry: int = Field(default=0, ge=0)
+    repl: Optional[PythonREPL] = None
+    key_manager: Optional[OpenAIKeyManager] = None
+    llm: Optional[ChatOpenAI] = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def _extract_code(self, response: str) -> str:
+        """Extract Python code from markdown response"""
+        code_match = re.search(r'```python(.*?)```', response, re.DOTALL)
+        if code_match:
+            return code_match.group(1).strip()
+        code_match = re.search(r'```(.*?)```', response, re.DOTALL)
+        return code_match.group(1).strip() if code_match else response.strip()
+
+    def _generate_initial_prompt(self, query: str, chart: bool = False) -> str:
+        """Generate the initial prompt for the LLM"""
+        columns = "\n".join([f"{col} ({self.df[col].dtype})" for col in self.df.columns])
+
+        if chart:
+            return f"""
+Generate Python code to create visualization(s) for this DataFrame with columns:
+{columns}
+
+First 5 rows:
+{self.df.head().to_string()}
+
+Query: {query}
+
+Requirements:
+1. Save visualizations to 'generated_charts/' with UUID filename (use uuid.uuid4())
+2. Use plt.savefig() with format='png'
+3. No plt.show() calls allowed
+4. After saving each chart, logger.info exactly: CHART_SAVED: generated_charts/<uuid>.png
+5. Start with 'import pandas as pd', 'import matplotlib.pyplot as plt', etc.
+6. The DataFrame is available as 'df'
+7. Wrap code in ```python``` blocks
+8. If Question is illogical and cannot be answered, explain using logger.info()
+"""
+        else:
+            return f"""
+Generate Python code to analyze this DataFrame with columns:
+{columns}
+
+First 5 rows:
+{self.df.head().to_string()}
+
+Query: {query}
+
+Requirements:
+1. Use logger.info() to show results with clear explanations
+2. If Question is illogical and cannot be answered, explain using logger.info()
+3. Start with necessary imports ('import pandas as pd', etc.)
+4. The DataFrame is available as 'df'
+5. For tabular results, use markdown formatting
+6. Wrap code in ```python``` blocks
+"""
+
+    def _generate_retry_prompt(self, query: str, error: str, code: str, chart: bool = False) -> str:
+        """Generate a retry prompt when code execution fails"""
+        if chart:
+            return f"""
+The previous code failed with this error:
+{error}
+
+Here was the code that failed:
+{code}
+
+Please fix the code to:
+1. Create the requested visualization(s)
+2. Save to 'generated_charts/' with UUID filename
+3. logger.info CHART_SAVED messages
+4. Handle the error: {error}
+
+Original query: {query}
+
+Show the corrected code in ```python``` blocks
+"""
+        else:
+            return f"""
+The previous code failed with this error:
+{error}
+
+Here was the code that failed:
+{code}
+
+Please fix the code to:
+1. Complete the analysis requested
+2. Handle the error: {error}
+3. Include clear output formatting
+
+Original query: {query}
+
+Show the corrected code in ```python``` blocks
+"""
+
+    def initialize_model(self, api_keys: List[str]) -> bool:
+        """Initialize OpenAI client with key rotation"""
+        self.key_manager = OpenAIKeyManager(api_keys)
+        if not self.key_manager.configure():
+            raise RuntimeError("All API keys failed")
+        self.llm = self.key_manager.get_llm_instance()
+        return True
+
+    def generate_code(self, query: str, error: Optional[str] = None,
+                      previous_code: Optional[str] = None, chart: bool = False) -> str:
+        """Generate Python code to answer the query"""
+        prompt = self._generate_retry_prompt(query, error, previous_code, chart) if error else self._generate_initial_prompt(query, chart)
+
+        try:
+            response = self.llm.invoke(prompt)
+            return self._extract_code(response.content)
+        except Exception as e:
+            logger.error(f"API error: {str(e)}")
+            if self.key_manager.rotate_key():
+                self.llm = self.key_manager.get_llm_instance()
+                return self.generate_code(query, error, previous_code, chart)
+            raise
+
+    def execute_query(self, query: str, chart: bool = False) -> str:
+        """Execute the query with automatic error correction"""
+        self.repl = PythonREPL(self.df)
+        error = None
+        previous_code = None
+
+        while self.current_retry < self.max_retries:
+            try:
+                code = self.generate_code(query, error, previous_code, chart)
+                result = self.repl.execute(code)
+
+                if result["error"]:
+                    self.current_retry += 1
+                    error = result["error_message"]
+                    previous_code = code
+                    logger.warning(f"Retry {self.current_retry}/{self.max_retries}")
+                else:
+                    self.df = result["df"]
+                    return result["output"]
+            except Exception as e:
+                logger.error(f"Critical error: {str(e)}")
+                return f"System error: {str(e)}"
+
+        return f"Failed after {self.max_retries} retries. Last error: {error}"
+
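Used on its own, the agent's generate-execute-retry loop looks like this sketch (CSV path, query, and key are placeholders):

    import pandas as pd
    from openai_react_agent_service import RethinkAgent

    df = pd.read_csv("local_sales.csv")          # placeholder file
    agent = RethinkAgent(df=df, max_retries=3)   # cap the correction loop at 3 attempts
    agent.initialize_model(["sk-placeholder"])   # raises RuntimeError if no key validates
    print(agent.execute_query("Summarise missing values per column"))
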
+def openai_react_chat(csv_url: str, query: str, chart: bool = False) -> Optional[Dict]:
+    """Main function to execute data analysis queries"""
+    try:
+        # Read and validate input data
+        df = pd.read_csv(csv_url)
+        if df.empty:
+            raise ValueError("Empty DataFrame loaded from CSV")
+
+        agent = RethinkAgent(df=df)
+
+        if not agent.initialize_model(API_KEYS):
+            logger.error("Failed to initialize model")
+            return None
+
+        result = agent.execute_query(query, chart)
+
+        # Process different response types
+        if isinstance(result, pd.DataFrame):
+            processed = result.apply(handle_out_of_range_float).to_dict(orient="records")
+        elif isinstance(result, pd.Series):
+            processed = result.apply(handle_out_of_range_float).to_dict()
+        elif isinstance(result, list):
+            processed = [handle_out_of_range_float(item) for item in result]
+        elif isinstance(result, dict):
+            processed = {k: handle_out_of_range_float(v) for k, v in result.items()}
+        else:
+            processed = {"answer": str(handle_out_of_range_float(result))}
+
+        logger.info("Analysis completed successfully")
+
+        if chart and isinstance(result, str) and result.startswith("CHART_SAVED:"):
+            result = result.strip()  # Remove any leading/trailing spaces or newlines
+            match = re.search(r'CHART_SAVED:\s*(\S+)', result)
+            if match:
+                chart_path = match.group(1)
+                logger.info("Chart Path:", chart_path)
+                return chart_path
+            else:
+                logger.info("Could not extract chart path from response")
+                return None
+
+        return processed
+    except Exception as e:
+        logger.error(f"Error in openai_llm_chat: {str(e)}")
+        return None
+
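End to end, the controller offloads this entry point with asyncio.to_thread; called standalone it behaves like this sketch (URL and queries are placeholders):

    from openai_react_agent_service import openai_react_chat

    # Q&A mode: returns a processed dict payload, or None on failure.
    answer = openai_react_chat("https://example.com/sales.csv", "Average revenue per region?")
    print(answer)

    # Chart mode: on success returns the saved PNG path parsed from the
    # CHART_SAVED marker, e.g. 'generated_charts/<uuid>.png'.
    chart_path = openai_react_chat("https://example.com/sales.csv", "Plot revenue by region", chart=True)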