# import json | |
# import numpy as np | |
# import pandas as pd | |
# import re | |
# import os | |
# import uuid | |
# import logging | |
# from io import StringIO | |
# import sys | |
# import traceback | |
# from typing import Optional, Dict, Any, List | |
# from pydantic import BaseModel, Field | |
# from google.generativeai import GenerativeModel, configure | |
# from dotenv import load_dotenv | |
# import seaborn as sns | |
# import datetime as dt | |
# from supabase_service import upload_file_to_supabase | |
# pd.set_option('display.max_columns', None) | |
# pd.set_option('display.max_rows', None) | |
# pd.set_option('display.max_colwidth', None) | |
# load_dotenv() | |
# API_KEYS = os.getenv("GEMINI_API_KEYS", "").split(",")[::-1] | |
# MODEL_NAME = 'gemini-2.0-flash' | |
# class FileProps(BaseModel): | |
# fileName: str | |
# filePath: str | |
# fileType: str # 'csv' | 'image' | |
# class Files(BaseModel): | |
# csv_files: List[FileProps] | |
# image_files: List[FileProps] | |
# class FileBoxProps(BaseModel): | |
# files: Files | |
# os.environ['MPLBACKEND'] = 'agg' | |
# import matplotlib.pyplot as plt | |
# plt.show = lambda: None | |
# logging.basicConfig( | |
# level=logging.INFO, | |
# format='%(asctime)s - %(levelname)s - %(message)s' | |
# ) | |
# logger = logging.getLogger(__name__) | |
# class GeminiKeyManager: | |
# """Manage multiple Gemini API keys with failover""" | |
# def __init__(self, api_keys: List[str]): | |
# self.original_keys = api_keys.copy() | |
# self.available_keys = api_keys.copy() | |
# self.active_key = None | |
# self.failed_keys = {} | |
# def configure(self) -> bool: | |
# while self.available_keys: | |
# key = self.available_keys.pop(0) | |
# try: | |
# configure(api_key=key) | |
# self.active_key = key | |
# logger.info(f"Configured with key: {self._mask_key(key)}") | |
# return True | |
# except Exception as e: | |
# self.failed_keys[key] = str(e) | |
# logger.error(f"Key failed: {self._mask_key(key)}. Error: {str(e)}") | |
# logger.critical("All API keys failed") | |
# return False | |
# def _mask_key(self, key: str) -> str: | |
# return f"{key[:8]}...{key[-4:]}" if key else "" | |
# class PythonREPL: | |
# """Secure Python REPL with file generation tracking""" | |
# def __init__(self, df: pd.DataFrame): | |
# self.df = df | |
# self.output_dir = os.path.abspath(f'generated_outputs/{uuid.uuid4()}') | |
# os.makedirs(self.output_dir, exist_ok=True) | |
# self.local_env = { | |
# "pd": pd, | |
# "df": self.df.copy(), | |
# "plt": plt, | |
# "os": os, | |
# "uuid": uuid, | |
# "sns": sns, | |
# "json": json, | |
# "dt": dt, | |
# "output_dir": self.output_dir | |
# } | |
# def execute(self, code: str) -> Dict[str, Any]: | |
# print('Executing code...', code) | |
# old_stdout = sys.stdout | |
# sys.stdout = mystdout = StringIO() | |
# file_tracker = { | |
# 'csv_files': set(), | |
# 'image_files': set() | |
# } | |
# try: | |
# code = f""" | |
# import matplotlib.pyplot as plt | |
# plt.switch_backend('agg') | |
# {code} | |
# plt.close('all') | |
# """ | |
# exec(code, self.local_env) | |
# self.df = self.local_env.get('df', self.df) | |
# # Track generated files | |
# for fname in os.listdir(self.output_dir): | |
# if fname.endswith('.csv'): | |
# file_tracker['csv_files'].add(fname) | |
# elif fname.lower().endswith(('.png', '.jpg', '.jpeg')): | |
# file_tracker['image_files'].add(fname) | |
# error = False | |
# except Exception as e: | |
# error_msg = traceback.format_exc() | |
# error = True | |
# finally: | |
# sys.stdout = old_stdout | |
# return { | |
# "output": mystdout.getvalue(), | |
# "error": error, | |
# "error_message": error_msg if error else None, | |
# "df": self.local_env.get('df', self.df), | |
# "output_dir": self.output_dir, | |
# "files": { | |
# "csv": [os.path.join(self.output_dir, f) for f in file_tracker['csv_files']], | |
# "images": [os.path.join(self.output_dir, f) for f in file_tracker['image_files']] | |
# } | |
# } | |
# class RethinkAgent(BaseModel): | |
# df: pd.DataFrame | |
# max_retries: int = Field(default=5, ge=1) | |
# gemini_model: Optional[GenerativeModel] = None | |
# current_retry: int = Field(default=0, ge=0) | |
# repl: Optional[PythonREPL] = None | |
# key_manager: Optional[GeminiKeyManager] = None | |
# class Config: | |
# arbitrary_types_allowed = True | |
# def _extract_code(self, response: str) -> str: | |
# code_match = re.search(r'```python(.*?)```', response, re.DOTALL) | |
# return code_match.group(1).strip() if code_match else response.strip() | |
# def _generate_initial_prompt(self, query: str) -> str: | |
# return f"""Generate DIRECT EXECUTION CODE (no functions, no explanations) following STRICT RULES: | |
# MANDATORY REQUIREMENTS: | |
# 1. Operate directly on existing 'df' variable | |
# 2. Save ALL final DataFrames to CSV using: df.to_csv(f'{{output_dir}}/descriptive_name.csv') | |
# 3. For visualizations: plt.savefig(f'{{output_dir}}/chart_name.png') | |
# 4. Use EXACTLY this structure: | |
# # Data processing | |
# df_processed = df[...] # filtering/grouping | |
# # Save results | |
# df_processed.to_csv(f'{{output_dir}}/result.csv') | |
# # Visualizations (if needed) | |
# plt.figure() | |
# ... plotting code ... | |
# plt.savefig(f'{{output_dir}}/chart.png') | |
# plt.close() | |
# FORBIDDEN: | |
# - Function definitions | |
# - Dummy data creation | |
# - Any code blocks besides pandas operations and matplotlib | |
# - Print statements showing dataframes | |
# DATAFRAME COLUMNS: {', '.join(self.df.columns)} | |
# DATAFRAME'S FIRST FIVE ROWS: {self.df.head().to_dict('records')} | |
# USER QUERY: {query} | |
# EXAMPLE RESPONSE FOR "Sales by region": | |
# # Data processing | |
# sales_by_region = df.groupby('region')['sales'].sum().reset_index() | |
# # Save results | |
# sales_by_region.to_csv(f'{{output_dir}}/sales_by_region.csv') | |
# """ | |
# def _generate_retry_prompt(self, query: str, error: str, code: str) -> str: | |
# return f"""FIX THIS CODE (failed with: {error}) by STRICTLY FOLLOWING: | |
# 1. REMOVE ALL FUNCTION DEFINITIONS | |
# 2. ENSURE DIRECT DF OPERATIONS | |
# 3. USE EXPLICIT output_dir PATHS | |
# 4. ADD NECESSARY IMPORTS IF MISSING | |
# 5. VALIDATE COLUMN NAMES EXIST | |
# BAD CODE: | |
# {code} | |
# CORRECTED CODE:""" | |
# def initialize_model(self, api_keys: List[str]) -> bool: | |
# self.key_manager = GeminiKeyManager(api_keys) | |
# if not self.key_manager.configure(): | |
# raise RuntimeError("API key initialization failed") | |
# try: | |
# self.gemini_model = GenerativeModel(MODEL_NAME) | |
# return True | |
# except Exception as e: | |
# logger.error(f"Model init failed: {str(e)}") | |
# return False | |
# def generate_code(self, query: str, error: Optional[str] = None, previous_code: Optional[str] = None) -> str: | |
# prompt = self._generate_retry_prompt(query, error, previous_code) if error else self._generate_initial_prompt(query) | |
# try: | |
# response = self.gemini_model.generate_content(prompt) | |
# return self._extract_code(response.text) | |
# except Exception as e: | |
# if self.key_manager.available_keys and self.key_manager.configure(): | |
# return self.generate_code(query, error, previous_code) | |
# raise | |
# def execute_query(self, query: str) -> Dict[str, Any]: | |
# self.repl = PythonREPL(self.df) | |
# result = None | |
# while self.current_retry < self.max_retries: | |
# try: | |
# code = self.generate_code(query, | |
# result["error_message"] if result else None, | |
# result["code"] if result else None) | |
# execution_result = self.repl.execute(code) | |
# if execution_result["error"]: | |
# self.current_retry += 1 | |
# result = { | |
# "error_message": execution_result["error_message"], | |
# "code": code | |
# } | |
# else: | |
# return { | |
# "text": execution_result["output"], | |
# "csv_files": execution_result["files"]["csv"], | |
# "image_files": execution_result["files"]["images"] | |
# } | |
# except Exception as e: | |
# return { | |
# "error": f"Critical failure: {str(e)}", | |
# "csv_files": [], | |
# "image_files": [] | |
# } | |
# return { | |
# "error": f"Failed after {self.max_retries} retries", | |
# "csv_files": [], | |
# "image_files": [] | |
# } | |
# def gemini_llm_chat(csv_url: str, query: str) -> Dict[str, Any]: | |
# try: | |
# df = pd.read_csv(csv_url) | |
# agent = RethinkAgent(df=df) | |
# if not agent.initialize_model(API_KEYS): | |
# return {"error": "API configuration failed"} | |
# result = agent.execute_query(query) | |
# if "error" in result: | |
# return result | |
# return { | |
# "message": result["text"], | |
# "csv_files": result["csv_files"], | |
# "image_files": result["image_files"] | |
# } | |
# except Exception as e: | |
# logger.error(f"Processing failed: {str(e)}") | |
# return { | |
# "error": f"Processing error: {str(e)}", | |
# "csv_files": [], | |
# "image_files": [] | |
# } | |
# async def generate_csv_report(csv_url: str, query: str) -> FileBoxProps: | |
# try: | |
# result = gemini_llm_chat(csv_url, query) | |
# logger.info(f"Raw result from gemini_llm_chat: {result}") | |
# csv_files = [] | |
# image_files = [] | |
# # Check if we got the expected response structure | |
# if isinstance(result, dict) and 'csv_files' in result and 'image_files' in result: | |
# # Process CSV files | |
# for csv_path in result['csv_files']: | |
# if os.path.exists(csv_path): | |
# file_name = os.path.basename(csv_path) | |
# try: | |
# unique_file_name = f"{uuid.uuid4()}_{file_name}" | |
# public_url = await upload_file_to_supabase( | |
# file_path=csv_path, | |
# file_name=unique_file_name | |
# ) | |
# csv_files.append(FileProps( | |
# fileName=file_name, | |
# filePath=public_url, | |
# fileType="csv" | |
# )) | |
# os.remove(csv_path) # Clean up | |
# except Exception as upload_error: | |
# logger.error(f"Failed to upload CSV {file_name}: {str(upload_error)}") | |
# continue | |
# # Process image files | |
# for img_path in result['image_files']: | |
# if os.path.exists(img_path): | |
# file_name = os.path.basename(img_path) | |
# try: | |
# unique_file_name = f"{uuid.uuid4()}_{file_name}" | |
# public_url = await upload_file_to_supabase( | |
# file_path=img_path, | |
# file_name=unique_file_name | |
# ) | |
# image_files.append(FileProps( | |
# fileName=file_name, | |
# filePath=public_url, | |
# fileType="image" | |
# )) | |
# os.remove(img_path) # Clean up | |
# except Exception as upload_error: | |
# logger.error(f"Failed to upload image {file_name}: {str(upload_error)}") | |
# continue | |
# return FileBoxProps( | |
# files=Files( | |
# csv_files=csv_files, | |
# image_files=image_files | |
# ) | |
# ) | |
# else: | |
# raise ValueError("Unexpected response format from gemini_llm_chat") | |
# except Exception as e: | |
# logger.error(f"Report generation failed: {str(e)}") | |
# # Return empty response but log the files we found | |
# if 'csv_files' in locals() and 'image_files' in locals(): | |
# logger.info(f"Files that were generated but not processed: CSV: {result.get('csv_files', [])}, Images: {result.get('image_files', [])}") | |
# return FileBoxProps( | |
# files=Files( | |
# csv_files=[], | |
# image_files=[] | |
# ) | |
# ) | |
# Newly Modified code with openai | |
# Import necessary modules | |
import asyncio | |
import os | |
import threading | |
from typing import Any, Dict, Union | |
import uuid | |
from fastapi.encoders import jsonable_encoder | |
from langchain_openai import ChatOpenAI | |
import numpy as np | |
import pandas as pd | |
from pandasai import SmartDataframe | |
from langchain_groq.chat_models import ChatGroq | |
from dotenv import load_dotenv | |
from pydantic import BaseModel | |
from csv_service import clean_data, extract_chart_filenames | |
from langchain_groq import ChatGroq | |
import pandas as pd | |
from langchain_experimental.tools import PythonAstREPLTool | |
from langchain_experimental.agents import create_pandas_dataframe_agent | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib | |
import seaborn as sns | |
from gemini_langchain_agent import langchain_gemini_csv_handler | |
from supabase_service import upload_file_to_supabase | |
from util_service import _prompt_generator, process_answer | |
import matplotlib | |
import logging | |
# Select the non-interactive Agg backend so figures are rendered to files only.
matplotlib.use('Agg')
load_dotenv()

# Set up logging first so configuration problems below can be reported.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _read_api_keys(env_var: str) -> list[str]:
    """Parse a comma-separated list of API keys from an environment variable.

    Surrounding whitespace is stripped and empty entries are dropped. A
    missing or empty variable yields an empty list (and a warning) instead
    of failing at import time with ``AttributeError`` on ``None.split``.

    Args:
        env_var: Name of the environment variable to read.

    Returns:
        The configured keys, in the order they appear in the variable.
    """
    raw_value = os.getenv(env_var, "")
    keys = [key.strip() for key in raw_value.split(",") if key.strip()]
    if not keys:
        logger.warning("No API keys configured in %s", env_var)
    return keys


image_file_path = os.getenv("IMAGE_FILE_PATH")
image_not_found = os.getenv("IMAGE_NOT_FOUND")
allowed_hosts = os.getenv("ALLOWED_HOSTS", "").split(",")

# Load environment variables
groq_api_keys = _read_api_keys("GROQ_API_KEYS")
model_name = os.getenv("GROQ_LLM_MODEL")
openai_api_keys = _read_api_keys("OPENAI_API_KEYS")
# Both names are read from the same variable; both are kept because either
# may be referenced elsewhere.
openai_base_url = os.getenv("OPENAI_BASE_URL")
openai_api_base = os.getenv("OPENAI_BASE_URL")
class CsvUrlRequest(BaseModel):
    """Pydantic model with a single string field holding a CSV URL."""

    # Location of the CSV file, as a plain string.
    csv_url: str
class ImageRequest(BaseModel):
    """Pydantic model with a single string field holding an image path."""

    # Path of the image, as a plain string.
    image_path: str
class CsvCommonHeadersRequest(BaseModel):
    """Pydantic model carrying a list of file URLs."""

    # URLs of the files to inspect.
    # NOTE(review): presumably CSV files whose shared headers are computed
    # by the consumer of this model - confirm against the endpoint.
    file_urls: list[str]
class CsvsMergeRequest(BaseModel):
    """Pydantic model describing a merge over several files."""

    # URLs of the files to merge.
    file_urls: list[str]
    # Kind of merge to perform.
    # NOTE(review): accepted values are not visible here - confirm with the
    # code that consumes this model.
    merge_type: str
    # Names of the columns the files have in common.
    common_columns_name: list[str]
# Thread-safe key management for openai_chat | |
# Key-rotation state. Each of the three chat paths (openai_chat, groq_chat,
# langchain_csv_chat) has its own cursor into its key list, and a dedicated
# lock that is held whenever that cursor is read or advanced.
current_openai_key_index = current_groq_key_index = current_langchain_key_index = 0

current_openai_key_lock = threading.Lock()
current_groq_key_lock = threading.Lock()
current_langchain_key_lock = threading.Lock()
# CHAT CODING STARTS FROM HERE | |
def handle_out_of_range_float(value):
    """Replace float values that strict JSON cannot represent.

    Args:
        value: Any object. Only Python floats and NumPy floating scalars are
            inspected; every other value is returned untouched.

    Returns:
        ``None`` for NaN, ``"Infinity"`` for positive infinity,
        ``"-Infinity"`` for negative infinity, otherwise ``value`` itself.
    """
    # np.floating covers scalars such as np.float32 that do not subclass float.
    if isinstance(value, (float, np.floating)):
        if np.isnan(value):
            return None
        if np.isinf(value):
            # Preserve the sign; -inf used to be reported as "Infinity".
            return "Infinity" if value > 0 else "-Infinity"
    return value
# Modified groq_chat function with thread-safe key rotation | |
def groq_chat(csv_url: str, question: str):
    """Answer a question about a CSV through PandasAI backed by a Groq model.

    The key at ``current_groq_key_index`` is tried first. Whenever an attempt
    raises, the error is logged and the shared index is advanced to the next
    key. The index is never reset here, so once it passes the last key every
    later call returns the "exhausted" error dict.

    Args:
        csv_url: Location of the CSV, passed to ``clean_data``.
        question: Prompt passed to ``SmartDataframe.chat``.

    Returns:
        - a list of row dicts when the answer is a DataFrame,
        - a dict when the answer is a Series or a dict,
        - a list when the answer is a list,
        - ``{"answer": "<text>"}`` for any other answer,
        - ``{"error": "All API keys exhausted."}`` when no key is left at the
          start of an attempt,
        - ``None`` when the last remaining key fails during this call.
        NaN and infinite floats are replaced via ``handle_out_of_range_float``.
    """
    global current_groq_key_index

    cache_db_path = "/workspace/cache/cache_db_0.11.db"

    while True:
        with current_groq_key_lock:
            if current_groq_key_index >= len(groq_api_keys):
                return {"error": "All API keys exhausted."}
            key_index = current_groq_key_index
            current_api_key = groq_api_keys[key_index]

        try:
            # Delete the cache file if it exists; failure to do so is not fatal.
            if os.path.exists(cache_db_path):
                try:
                    os.remove(cache_db_path)
                except OSError as e:
                    logger.warning(f"Error deleting cache DB file: {e}")

            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            # Generate a unique chart filename using a UUID.
            chart_filename = f"chart_{uuid.uuid4()}.png"

            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': "generated_charts",
                    'custom_chart_filename': chart_filename,
                },
            )
            answer = df.chat(question)

            # Sanitise value by value on plain Python containers.
            # DataFrame.apply would hand whole columns to the helper (which
            # then does nothing), and Series.apply re-infers the dtype and
            # can turn the substituted None back into NaN.
            if isinstance(answer, pd.DataFrame):
                processed = [
                    {col: handle_out_of_range_float(val) for col, val in row.items()}
                    for row in answer.to_dict(orient="records")
                ]
            elif isinstance(answer, pd.Series):
                processed = {
                    idx: handle_out_of_range_float(val)
                    for idx, val in answer.to_dict().items()
                }
            elif isinstance(answer, list):
                processed = [handle_out_of_range_float(item) for item in answer]
            elif isinstance(answer, dict):
                processed = {k: handle_out_of_range_float(v) for k, v in answer.items()}
            else:
                processed = {"answer": str(handle_out_of_range_float(answer))}
            return processed

        except Exception as e:
            # Always record why the key was abandoned before rotating.
            logger.error(f"Error with API key index {key_index}: {e}")
            with current_groq_key_lock:
                current_groq_key_index += 1
                if current_groq_key_index >= len(groq_api_keys):
                    logger.error("All API keys exhausted.")
                    return None
# Modified langchain_csv_chat with thread-safe key rotation | |
def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
    """Answer a question about a CSV with a LangChain pandas agent on Groq.

    Keys are taken round-robin from ``groq_api_keys`` through the shared
    ``current_langchain_key_index``; at most one attempt per configured key
    is made within a single call.

    Args:
        csv_url: Location of the CSV, passed to ``clean_data``.
        question: User question, wrapped by ``_prompt_generator``.
        chart_required: Forwarded to ``_prompt_generator``.

    Returns:
        The ``"output"`` entry of the agent result, or ``None`` when every
        attempt raised.
    """
    global current_langchain_key_index

    data = clean_data(csv_url)

    for _ in range(len(groq_api_keys)):
        # Pick the next key and advance the shared cursor under the lock.
        with current_langchain_key_lock:
            if current_langchain_key_index >= len(groq_api_keys):
                current_langchain_key_index = 0
            api_key = groq_api_keys[current_langchain_key_index]
            current_key = current_langchain_key_index
            current_langchain_key_index = current_key + 1

        try:
            llm = ChatGroq(model=model_name, api_key=api_key)

            # Names made available to code executed by the extra REPL tool.
            repl_namespace = {
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
            }
            repl_tool = PythonAstREPLTool(locals=repl_namespace)

            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[repl_tool],
                return_intermediate_steps=True,
            )

            agent_input = {"input": _prompt_generator(question, chart_required)}
            return agent.invoke(agent_input).get("output")
        except Exception as exc:
            print(f"Error with key index {current_key}: {str(exc)}")

    # If all keys are exhausted, return None
    print("All API keys have been exhausted.")
    return None
def handle_out_of_range_float(value):
    """Replace float values that strict JSON cannot represent.

    This is a re-definition of a function of the same name declared earlier
    in the module; being later, it is the one bound at import time.

    Args:
        value: Any object. Only Python floats and NumPy floating scalars are
            inspected; every other value is returned untouched.

    Returns:
        ``None`` for NaN, ``"Infinity"`` for positive infinity,
        ``"-Infinity"`` for negative infinity, otherwise ``value`` itself.
    """
    # np.floating covers scalars such as np.float32 that do not subclass float.
    if isinstance(value, (float, np.floating)):
        if np.isnan(value):
            return None
        if np.isinf(value):
            # Preserve the sign; -inf used to be reported as "Infinity".
            return "Infinity" if value > 0 else "-Infinity"
    return value
# CHART CODING STARTS FROM HERE | |
# Extra guidance appended to the user's question by the chart functions.
instructions = (
    "\n"
    "- Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).\n"
    "- For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)\n"
    "- Use colorblind-friendly palette\n"
    "- Read above instructions and follow them.\n"
)

# Key-rotation state for the chart endpoints: one cursor and one guarding
# lock per chart path (groq_chart and langchain_csv_chart).
current_groq_chart_key_index = current_langchain_chart_key_index = 0

current_groq_chart_lock = threading.Lock()
current_langchain_chart_lock = threading.Lock()
def model():
    """Build a ChatGroq client from the chart key currently selected.

    Returns:
        A ``ChatGroq`` instance configured with ``model_name`` and the key at
        ``current_groq_chart_key_index``.

    Raises:
        Exception: when the chart key index is past the last configured key.
    """
    # The cursor is only read here, so no ``global`` declaration is required.
    with current_groq_chart_lock:
        key_position = current_groq_chart_key_index
        if key_position >= len(groq_api_keys):
            raise Exception("All API keys exhausted for chart generation")
        selected_key = groq_api_keys[key_position]
    return ChatGroq(model=model_name, api_key=selected_key)
def groq_chart(csv_url: str, question: str):
    """Generate a chart for a CSV through PandasAI backed by a Groq model.

    At most one attempt per configured Groq key is made. After any failed
    attempt the error is logged and the shared chart key index moves to the
    next key (wrapping around).

    Args:
        csv_url: Location of the CSV, passed to ``clean_data``.
        question: User request; the module-level ``instructions`` text is
            appended before it is sent to ``SmartDataframe.chat``.

    Returns:
        - the value returned by ``SmartDataframe.chat`` when
          ``process_answer`` reports it as falsy,
        - the string ``"Chart not generated"`` when ``process_answer``
          returns a truthy value,
        - ``None`` when every attempt raised.
        NOTE(review): callers appear to treat the chat result as a chart
        file path - confirm.
    """
    global current_groq_chart_key_index

    cache_db_path = "/workspace/cache/cache_db_0.11.db"

    for _ in range(len(groq_api_keys)):
        try:
            # Clean the cache before processing; failure to do so is not fatal.
            if os.path.exists(cache_db_path):
                try:
                    os.remove(cache_db_path)
                except OSError as e:
                    logger.warning(f"Cache cleanup error: {e}")

            data = clean_data(csv_url)
            with current_groq_chart_lock:
                current_api_key = groq_api_keys[current_groq_chart_key_index]
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            # Generate a unique chart filename using a UUID.
            chart_filename = f"chart_{uuid.uuid4()}.png"

            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': "generated_charts",
                    'custom_chart_filename': chart_filename,
                },
            )
            answer = df.chat(question + instructions)
            if process_answer(answer):
                return "Chart not generated"
            return answer

        except Exception as e:
            # The old guard ('"429" in error or error is not None') was always
            # true, so every error rotated the key and none was ever reported.
            # Rotation on any error is kept; the error is now logged.
            logger.error(f"Chart generation error: {e}")
            with current_groq_chart_lock:
                current_groq_chart_key_index = (
                    current_groq_chart_key_index + 1
                ) % len(groq_api_keys)

    logger.error("All API keys exhausted for chart generation")
    return None
def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
    """Generate chart files for a CSV with a LangChain pandas agent on Groq.

    One attempt is made per configured Groq key. Each attempt takes the key
    at ``current_langchain_chart_key_index`` and advances that shared cursor
    (wrapping around).

    Args:
        csv_url: Location of the CSV; passed to ``clean_data`` and also
            embedded in the prompt text.
        question: User request, embedded in the prompt text.
        chart_required: Accepted but not read; the prompt is always built
            with ``True``.

    Returns:
        The non-empty result of ``extract_chart_filenames`` applied to the
        agent output, or ``None`` when no attempt produced any filename.
    """
    global current_langchain_chart_key_index

    data = clean_data(csv_url)
    key_count = len(groq_api_keys)

    for attempt in range(key_count):
        try:
            # Pick the key for this attempt and move the shared cursor on.
            with current_langchain_chart_lock:
                api_key = groq_api_keys[current_langchain_chart_key_index]
                current_key = current_langchain_chart_key_index
                current_langchain_chart_key_index = (current_key + 1) % len(groq_api_keys)

            llm = ChatGroq(model=model_name, api_key=api_key)

            # Names made available to code executed by the extra REPL tool.
            repl_namespace = {
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
                "uuid": uuid,
            }
            repl_tool = PythonAstREPLTool(locals=repl_namespace)

            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[repl_tool],
                return_intermediate_steps=True,
            )

            prompt = _prompt_generator(
                f"{question} and use this csv_url: {csv_url} to read the csv file",
                True,
            )
            output = agent.invoke({"input": prompt}).get("output", "")

            # Verify chart file creation
            chart_files = extract_chart_filenames(output)
            if len(chart_files) > 0:
                return chart_files

            # No filenames found: report the raw output unless this was the
            # last attempt.
            if attempt < len(groq_api_keys) - 1:
                print(f"Langchain chart error (key {current_key}): {output}")
        except Exception as exc:
            print(f"Langchain chart error (key {current_key}): {str(exc)}")

    print("All API keys exhausted for chart generation")
    return None
####################################### OpenAI + PandasAI ####################################### | |
# Modified openai_chat function with thread-safe key rotation | |
# Model identifier passed to ChatOpenAI by openai_chat and openai_chart.
openai_model_name = "gpt-4o"
def openai_chat(csv_url: str, question: str):
    """Answer a question about a CSV through PandasAI backed by an OpenAI model.

    The key at ``current_openai_key_index`` is tried first. Whenever an
    attempt raises, the error is logged and the shared index is advanced to
    the next key. The index is never reset here, so once it passes the last
    key every later call returns the "exhausted" error dict.

    Args:
        csv_url: Location of the CSV, passed to ``clean_data``.
        question: Prompt passed to ``SmartDataframe.chat``.

    Returns:
        - a list of row dicts when the answer is a DataFrame,
        - a dict when the answer is a Series or a dict,
        - a list when the answer is a list,
        - ``{"answer": "<text>"}`` for any other answer,
        - ``{"error": "All API keys exhausted."}`` when no key is left at the
          start of an attempt,
        - ``None`` when the last remaining key fails during this call.
        NaN and infinite floats are replaced via ``handle_out_of_range_float``.
    """
    global current_openai_key_index

    cache_db_path = "/workspace/cache/cache_db_0.11.db"

    while True:
        with current_openai_key_lock:
            if current_openai_key_index >= len(openai_api_keys):
                return {"error": "All API keys exhausted."}
            key_index = current_openai_key_index
            current_api_key = openai_api_keys[key_index]

        try:
            # Delete the cache file if it exists; failure to do so is not fatal.
            if os.path.exists(cache_db_path):
                try:
                    os.remove(cache_db_path)
                except OSError as e:
                    logger.warning(f"Error deleting cache DB file: {e}")

            data = clean_data(csv_url)
            llm = ChatOpenAI(
                model=openai_model_name,
                api_key=current_api_key,
                base_url=openai_api_base,
            )

            # Generate a unique chart filename using a UUID.
            chart_filename = f"chart_{uuid.uuid4()}.png"

            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': "generated_charts",
                    'custom_chart_filename': chart_filename,
                },
            )
            answer = df.chat(question)

            # Sanitise value by value on plain Python containers.
            # DataFrame.apply would hand whole columns to the helper (which
            # then does nothing), and Series.apply re-infers the dtype and
            # can turn the substituted None back into NaN.
            if isinstance(answer, pd.DataFrame):
                processed = [
                    {col: handle_out_of_range_float(val) for col, val in row.items()}
                    for row in answer.to_dict(orient="records")
                ]
            elif isinstance(answer, pd.Series):
                processed = {
                    idx: handle_out_of_range_float(val)
                    for idx, val in answer.to_dict().items()
                }
            elif isinstance(answer, list):
                processed = [handle_out_of_range_float(item) for item in answer]
            elif isinstance(answer, dict):
                processed = {k: handle_out_of_range_float(v) for k, v in answer.items()}
            else:
                processed = {"answer": str(handle_out_of_range_float(answer))}
            return processed

        except Exception as e:
            # Always record why the key was abandoned before rotating.
            logger.error(f"Error with API key index {key_index}: {e}")
            with current_openai_key_lock:
                current_openai_key_index += 1
                if current_openai_key_index >= len(openai_api_keys):
                    logger.error("All API keys exhausted.")
                    return None
def openai_chart(csv_url: str, question: str):
    """Generate a chart for a CSV through PandasAI backed by an OpenAI model.

    At most one attempt per configured OpenAI key is made. The previous
    ``while True`` loop wrapped the key index with a modulo, so when every
    key failed the index never reached the end of the list and the loop
    never terminated.

    Args:
        csv_url: Location of the CSV, passed to ``clean_data``.
        question: User request; the module-level ``instructions`` text is
            appended before it is sent to ``SmartDataframe.chat``.

    Returns:
        - the value returned by ``SmartDataframe.chat`` when
          ``process_answer`` reports it as falsy,
        - the string ``"Chart not generated"`` when ``process_answer``
          returns a truthy value,
        - ``{"error": "All API keys exhausted."}`` when the shared key index
          is already past the last key (or no key is configured),
        - ``None`` when every attempt made by this call raised.
    """
    global current_openai_key_index

    cache_db_path = "/workspace/cache/cache_db_0.11.db"

    # Keep the historical result for an empty key list.
    if not openai_api_keys:
        return {"error": "All API keys exhausted."}

    for _ in range(len(openai_api_keys)):
        with current_openai_key_lock:
            # The index is shared with openai_chat, which can leave it past
            # the end of the list.
            if current_openai_key_index >= len(openai_api_keys):
                return {"error": "All API keys exhausted."}
            key_index = current_openai_key_index
            current_api_key = openai_api_keys[key_index]

        try:
            # Delete the cache file if it exists; failure to do so is not fatal.
            if os.path.exists(cache_db_path):
                try:
                    os.remove(cache_db_path)
                except OSError as e:
                    logger.warning(f"Error deleting cache DB file: {e}")

            data = clean_data(csv_url)
            llm = ChatOpenAI(
                model=openai_model_name,
                api_key=current_api_key,
                base_url=openai_api_base,
            )

            # Generate a unique chart filename using a UUID.
            chart_filename = f"chart_{uuid.uuid4()}.png"

            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': "generated_charts",
                    'custom_chart_filename': chart_filename,
                },
            )
            answer = df.chat(question + instructions)
            if process_answer(answer):
                return "Chart not generated"
            return answer

        except Exception as e:
            # The old guard ('"429" in error or error is not None') was always
            # true, so every error rotates to the next key; that is kept.
            logger.error(f"Error with API key index {key_index}: {e}")
            with current_openai_key_lock:
                current_openai_key_index = (
                    current_openai_key_index + 1
                ) % len(openai_api_keys)

    logger.error("All API keys exhausted for chart generation")
    return None
####################################### Start with lc_gemini ####################################### | |
# async def csv_chat(csv_url: str, query: str): | |
# """ | |
# Generate a response based on the provided CSV URL and query. | |
# Prioritizes LangChain-Gemini, then LangChain-Groq, then raw OpenAI and finally raw Groq as fallback. | |
# Parameters: | |
# - csv_url (str): The URL of the CSV file. | |
# - query (str): The query for generating the response. | |
# Returns: | |
# - dict: A dictionary containing the generated response. | |
# Example: | |
# - csv_url: "https://example.com/data.csv" | |
# - query: "What is the total sales for the year 2022?" | |
# Returns: | |
# - dict: {"answer": "The total sales for 2022 is $100,000."} | |
# """ | |
# try: | |
# updated_query = f"{query} and Do not show any charts or graphs." | |
# # --- 1. First Attempt: LangChain Gemini --- | |
# try: | |
# gemini_answer = await asyncio.to_thread( | |
# langchain_gemini_csv_handler, csv_url, updated_query, False | |
# ) | |
# print("LangChain-Gemini answer:", gemini_answer) | |
# if not process_answer(gemini_answer) or gemini_answer is None: | |
# return {"answer": jsonable_encoder(gemini_answer)} | |
# raise Exception("LangChain-Gemini response not usable, falling back to LangChain-Groq") | |
# except Exception as gemini_error: | |
# print(f"LangChain-Gemini error: {str(gemini_error)}") | |
# # --- 2. Second Attempt: LangChain Groq --- | |
# try: | |
# lang_groq_answer = await asyncio.to_thread( | |
# langchain_csv_chat, csv_url, updated_query, False | |
# ) | |
# print("LangChain-Groq answer:", lang_groq_answer) | |
# if not process_answer(lang_groq_answer): | |
# return {"answer": jsonable_encoder(lang_groq_answer)} | |
# raise Exception("LangChain-Groq response not usable, falling back to raw Groq") | |
# except Exception as lang_groq_error: | |
# print(f"LangChain-Groq error: {str(lang_groq_error)}") | |
# # --- 3. Final Attempt: Raw OpenAI Chat --- | |
# try: | |
# raw_openai_answer = await asyncio.to_thread(openai_chat, csv_url, updated_query) | |
# print("Raw OpenAI answer:", raw_openai_answer) | |
# if process_answer(raw_openai_answer) == "Empty response received." or raw_openai_answer is None: | |
# return {"answer": "Sorry, I couldn't find relevant data..."} | |
# if process_answer(raw_openai_answer): | |
# except Exception as openai_exception: | |
# print(f"Raw OpenAI error: {str(openai_exception)}") | |
# # --- 4. Final Attempt: Raw Groq Chat --- | |
# try: | |
# raw_groq_answer = await asyncio.to_thread(groq_chat, csv_url, updated_query) | |
# print("Raw Groq answer:", raw_groq_answer) | |
# if process_answer(raw_groq_answer) == "Empty response received." or raw_groq_answer is None: | |
# return {"answer": "Sorry, I couldn't find relevant data..."} | |
# if process_answer(raw_groq_answer): | |
# raise Exception("All fallbacks exhausted") | |
# return {"answer": jsonable_encoder(raw_groq_answer)} | |
# except Exception as raw_groq_error: | |
# print(f"Raw Groq error: {str(raw_groq_error)}") | |
# return {"answer": "error"} | |
# except Exception as e: | |
# print(f"Unexpected error: {str(e)}") | |
# return {"answer": "error"} | |
# async def csv_chart(csv_url: str, query: str): | |
# """ | |
# Generate a chart based on the provided CSV URL and query. | |
# Prioritizes raw OpenAI, then raw Groq, then LangChain Gemini, and finally LangChain Groq as fallback. | |
# Parameters: | |
# - csv_url (str): The URL of the CSV file. | |
# - query (str): The query for generating the chart. | |
# Returns: | |
# - dict: A dictionary containing either: | |
# - {"image_url": "https://example.com/chart.png"} on success, or | |
# - {"error": "error message"} on failure | |
# Example: | |
# - csv_url: "https://example.com/data.csv" | |
# - query: "Show sales trends as a line chart" | |
# Returns: | |
# - dict: {"image_url": "https://storage.example.com/chart_uuid.png"} | |
# """ | |
# async def upload_and_return(image_path: str) -> dict: | |
# """Helper function to handle image uploads""" | |
# unique_name = f'{uuid.uuid4()}.png' | |
# public_url = await upload_file_to_supabase(image_path, unique_name) | |
# print(f"Uploaded chart: {public_url}") | |
# os.remove(image_path) # Remove the local image file after upload | |
# return {"image_url": public_url} | |
# try: | |
# # --- 1. First Attempt: Raw OpenAI --- | |
# try: | |
# openai_result = await asyncio.to_thread(openai_chart, csv_url, query) | |
# print(f"OpenAI chart result:", openai_result) | |
# if openai_result and openai_result != 'Chart not generated': | |
# return await upload_and_return(openai_result) | |
# raise Exception("OpenAI failed to generate chart") | |
# except Exception as openai_error: | |
# print(f"OpenAI failed ({str(openai_error)}), trying LangChain Gemini...") | |
# # --- 2.. First Attempt: Raw Groq --- | |
# try: | |
# groq_result = await asyncio.to_thread(groq_chart, csv_url, query) | |
# print(f"Raw Groq chart result:", groq_result) | |
# if groq_result and groq_result != 'Chart not generated': | |
# return await upload_and_return(groq_result) | |
# raise Exception("Raw Groq failed to generate chart") | |
# except Exception as groq_error: | |
# print(f"Raw Groq failed ({str(groq_error)}), trying LangChain Gemini...") | |
# # --- 3. Second Attempt: LangChain Gemini --- | |
# try: | |
# gemini_result = await asyncio.to_thread( | |
# langchain_gemini_csv_handler, csv_url, query, True | |
# ) | |
# print("LangChain Gemini chart result:", gemini_result) | |
# # --- i) If Gemini result is a string, return it --- | |
# if gemini_result and isinstance(gemini_result, str): | |
# clean_path = gemini_result.strip() | |
# return await upload_and_return(clean_path) | |
# # --- ii) If Gemini result is a list, return the first element --- | |
# if gemini_result and isinstance(gemini_result, list) and len(gemini_result) > 0: | |
# return await upload_and_return(gemini_result[0]) | |
# raise Exception("LangChain Gemini returned empty result") | |
# except Exception as gemini_error: | |
# print(f"LangChain Gemini failed ({str(gemini_error)}), trying LangChain Groq...") | |
# # --- 4. Final Attempt: LangChain Groq --- | |
# try: | |
# lc_groq_paths = await asyncio.to_thread( | |
# langchain_csv_chart, csv_url, query, True | |
# ) | |
# print("LangChain Groq chart result:", lc_groq_paths) | |
# if isinstance(lc_groq_paths, list) and lc_groq_paths: | |
# return await upload_and_return(lc_groq_paths[0]) | |
# return {"error": "All chart generation methods failed"} | |
# except Exception as lc_groq_error: | |
# print(f"LangChain Groq failed: {str(lc_groq_error)}") | |
# return {"error": "Could not generate chart"} | |
# except Exception as e: | |
# print(f"Critical error: {str(e)}") | |
# return {"error": "Internal system error"} | |
####################################### Optimized Version ####################################### | |
async def csv_chat(csv_url: str, query: str) -> Dict[str, Any]:
    """
    Answer a natural-language question about a CSV file.

    Providers are tried in this order: LangChain-Gemini, LangChain-Groq,
    raw OpenAI, raw Groq. Each provider runs in a worker thread via
    ``asyncio.to_thread``. The chain moves on to the next provider only
    when the current one raises or its answer is flagged as unusable by
    ``process_answer``; an empty answer (``None`` or the
    "Empty response received." marker) ends the chain immediately with
    the "no relevant data" reply.

    Parameters:
    - csv_url (str): The URL of the CSV file.
    - query (str): The question to answer. An instruction not to produce
      charts is appended before it is sent to the providers.

    Returns:
    - dict: ``{"answer": ...}`` holding the JSON-encoded provider answer,
      the "no relevant data" reply, or a generic error message when every
      provider failed.

    Example:
    - csv_url: "https://example.com/data.csv"
    - query: "What is the total sales for the year 2022?"
    - result: {"answer": "The total sales for 2022 is $100,000."}
    """
    no_data_reply = "Sorry, I couldn't find relevant data..."
    failure_reply = "An error occurred while processing your request."
    text_only_query = f"{query} and Do not show any charts or graphs."

    async def attempt(label: str, backend, *call_args) -> Dict[str, Any]:
        """Run one provider and classify its outcome as a status dict."""
        try:
            logger.info(f"Attempting {label}")
            reply = await asyncio.to_thread(backend, *call_args)
            if reply is None:
                logger.warning(f"{label} returned None")
                return {"status": "empty", "answer": None}

            # process_answer yields the empty marker, a truthy value for an
            # unusable reply, or a falsy value for a usable one.
            verdict = process_answer(reply)
            if verdict == "Empty response received.":
                logger.warning(f"{label} returned empty response")
                outcome = "empty"
            elif verdict:
                logger.warning(f"{label} response not usable")
                outcome = "invalid"
            else:
                logger.info(f"{label} succeeded")
                outcome = "success"
            return {"status": outcome, "answer": reply}
        except Exception as exc:
            logger.error(f"{label} failed: {str(exc)}")
            return {"status": "error", "error": str(exc)}

    # (label, callable, positional arguments) in priority order.
    backends = (
        ("LangChain-Gemini", langchain_gemini_csv_handler, (csv_url, text_only_query, False)),
        ("LangChain-Groq", langchain_csv_chat, (csv_url, text_only_query, False)),
        ("Raw OpenAI", openai_chat, (csv_url, text_only_query)),
        ("Raw Groq", groq_chat, (csv_url, text_only_query)),
    )

    for label, backend, call_args in backends:
        outcome = await attempt(label, backend, *call_args)
        status = outcome["status"]
        if status == "success":
            return {"answer": jsonable_encoder(outcome["answer"])}
        if status == "empty":
            # An empty reply is treated as "no data", not as a provider failure.
            return {"answer": no_data_reply}

    # Every provider either raised or produced an unusable reply.
    logger.error("All chat methods failed to produce a valid response")
    return {"answer": failure_reply}
async def csv_chart(csv_url: str, query: str) -> Dict[str, str]:
    """
    Generate a chart for a CSV file and return its public URL.

    Providers are tried in this order: raw OpenAI, raw Groq, LangChain
    Gemini, LangChain Groq. Each provider runs in a worker thread via
    ``asyncio.to_thread`` and is expected to return a local image path
    (or a non-empty list of paths, of which the first is used). The first
    image that uploads successfully wins.

    Parameters:
    - csv_url (str): The URL of the CSV file.
    - query (str): The query describing the chart to generate.

    Returns:
    - dict: A dictionary containing either:
        - {"image_url": "https://example.com/chart.png"} on success, or
        - {"error": "error message"} on failure

    Example:
    - csv_url: "https://example.com/data.csv"
    - query: "Show sales trends as a line chart"
    - result: {"image_url": "https://storage.example.com/chart_uuid.png"}
    """

    async def upload_and_return(image_path: str) -> Dict[str, str]:
        """
        Upload a local chart image and delete the local copy.

        The local file is removed whether or not the upload succeeds, so a
        failed upload does not leave orphaned images behind while the next
        provider is tried.

        Raises:
        - FileNotFoundError: if ``image_path`` does not exist.
        - Exception: whatever ``upload_file_to_supabase`` raises (logged,
          then re-raised unchanged).
        """
        try:
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image file not found at {image_path}")
            try:
                unique_name = f'{uuid.uuid4()}.png'
                public_url = await upload_file_to_supabase(image_path, unique_name)
            finally:
                # Best-effort cleanup: a removal failure must not mask the
                # upload result or the upload error.
                try:
                    os.remove(image_path)
                except OSError as e:
                    logger.warning(f"Failed to remove local image file: {e}")
            logger.info(f"Uploaded chart: {public_url}")
            return {"image_url": public_url}
        except Exception as e:
            logger.error(f"Error in upload_and_return: {e}")
            raise  # bare raise keeps the original traceback intact

    async def try_generation(method_name: str, method, *args) -> Union[str, None]:
        """
        Run one chart provider and normalise its result to a single path.

        Returns the image path, or ``None`` when the provider raised or
        returned nothing usable; this helper never propagates ``Exception``.
        """
        try:
            logger.info(f"Attempting chart generation with {method_name}")
            result = await asyncio.to_thread(method, *args)
            if not result or result == 'Chart not generated':
                raise ValueError(f"{method_name} returned empty or invalid result")
            if isinstance(result, str):
                return result.strip()
            if isinstance(result, list):
                # Non-empty is guaranteed by the truthiness check above.
                return result[0]
            raise ValueError(f"{method_name} returned unexpected result type")
        except Exception as e:
            logger.warning(f"{method_name} failed: {str(e)}")
            return None

    # (label, callable, *positional arguments) in priority order. The
    # LangChain handlers take a trailing flag that selects chart mode.
    generation_methods = [
        ("Raw OpenAI", openai_chart, csv_url, query),
        ("Raw Groq", groq_chart, csv_url, query),
        ("LangChain Gemini", langchain_gemini_csv_handler, csv_url, query, True),
        ("LangChain Groq", langchain_csv_chart, csv_url, query, True),
    ]
    last_attempt = len(generation_methods)

    for attempt, (method_name, method, *args) in enumerate(generation_methods, 1):
        image_path = await try_generation(method_name, method, *args)
        if not image_path:
            continue
        # Only the upload can raise here; generation errors were already
        # absorbed by try_generation.
        try:
            return await upload_and_return(image_path)
        except Exception as e:
            logger.error(f"Error processing {method_name}: {e}")
            if attempt == last_attempt:
                logger.error("All chart generation methods failed")
                return {"error": "Could not generate chart using any available method"}

    return {"error": "All chart generation methods failed"}
# Example usage: | |
# csv_url = './documents/titanic.csv' | |
# query = "Create a pie chart of male vs female passengers?" | |
# result = openai_chart(csv_url, query) | |
# print(result) |