# FastApi / gemini_report_generator.py
# (Header copied from the file-hosting page: author Soumik555, commit a172af8
#  "added openai in orchestrator"; raw / history / blame view; 45.7 kB.)
# import json
# import numpy as np
# import pandas as pd
# import re
# import os
# import uuid
# import logging
# from io import StringIO
# import sys
# import traceback
# from typing import Optional, Dict, Any, List
# from pydantic import BaseModel, Field
# from google.generativeai import GenerativeModel, configure
# from dotenv import load_dotenv
# import seaborn as sns
# import datetime as dt
# from supabase_service import upload_file_to_supabase
# pd.set_option('display.max_columns', None)
# pd.set_option('display.max_rows', None)
# pd.set_option('display.max_colwidth', None)
# load_dotenv()
# API_KEYS = os.getenv("GEMINI_API_KEYS", "").split(",")[::-1]
# MODEL_NAME = 'gemini-2.0-flash'
# class FileProps(BaseModel):
# fileName: str
# filePath: str
# fileType: str # 'csv' | 'image'
# class Files(BaseModel):
# csv_files: List[FileProps]
# image_files: List[FileProps]
# class FileBoxProps(BaseModel):
# files: Files
# os.environ['MPLBACKEND'] = 'agg'
# import matplotlib.pyplot as plt
# plt.show = lambda: None
# logging.basicConfig(
# level=logging.INFO,
# format='%(asctime)s - %(levelname)s - %(message)s'
# )
# logger = logging.getLogger(__name__)
# class GeminiKeyManager:
# """Manage multiple Gemini API keys with failover"""
# def __init__(self, api_keys: List[str]):
# self.original_keys = api_keys.copy()
# self.available_keys = api_keys.copy()
# self.active_key = None
# self.failed_keys = {}
# def configure(self) -> bool:
# while self.available_keys:
# key = self.available_keys.pop(0)
# try:
# configure(api_key=key)
# self.active_key = key
# logger.info(f"Configured with key: {self._mask_key(key)}")
# return True
# except Exception as e:
# self.failed_keys[key] = str(e)
# logger.error(f"Key failed: {self._mask_key(key)}. Error: {str(e)}")
# logger.critical("All API keys failed")
# return False
# def _mask_key(self, key: str) -> str:
# return f"{key[:8]}...{key[-4:]}" if key else ""
# class PythonREPL:
# """Secure Python REPL with file generation tracking"""
# def __init__(self, df: pd.DataFrame):
# self.df = df
# self.output_dir = os.path.abspath(f'generated_outputs/{uuid.uuid4()}')
# os.makedirs(self.output_dir, exist_ok=True)
# self.local_env = {
# "pd": pd,
# "df": self.df.copy(),
# "plt": plt,
# "os": os,
# "uuid": uuid,
# "sns": sns,
# "json": json,
# "dt": dt,
# "output_dir": self.output_dir
# }
# def execute(self, code: str) -> Dict[str, Any]:
# print('Executing code...', code)
# old_stdout = sys.stdout
# sys.stdout = mystdout = StringIO()
# file_tracker = {
# 'csv_files': set(),
# 'image_files': set()
# }
# try:
# code = f"""
# import matplotlib.pyplot as plt
# plt.switch_backend('agg')
# {code}
# plt.close('all')
# """
# exec(code, self.local_env)
# self.df = self.local_env.get('df', self.df)
# # Track generated files
# for fname in os.listdir(self.output_dir):
# if fname.endswith('.csv'):
# file_tracker['csv_files'].add(fname)
# elif fname.lower().endswith(('.png', '.jpg', '.jpeg')):
# file_tracker['image_files'].add(fname)
# error = False
# except Exception as e:
# error_msg = traceback.format_exc()
# error = True
# finally:
# sys.stdout = old_stdout
# return {
# "output": mystdout.getvalue(),
# "error": error,
# "error_message": error_msg if error else None,
# "df": self.local_env.get('df', self.df),
# "output_dir": self.output_dir,
# "files": {
# "csv": [os.path.join(self.output_dir, f) for f in file_tracker['csv_files']],
# "images": [os.path.join(self.output_dir, f) for f in file_tracker['image_files']]
# }
# }
# class RethinkAgent(BaseModel):
# df: pd.DataFrame
# max_retries: int = Field(default=5, ge=1)
# gemini_model: Optional[GenerativeModel] = None
# current_retry: int = Field(default=0, ge=0)
# repl: Optional[PythonREPL] = None
# key_manager: Optional[GeminiKeyManager] = None
# class Config:
# arbitrary_types_allowed = True
# def _extract_code(self, response: str) -> str:
# code_match = re.search(r'```python(.*?)```', response, re.DOTALL)
# return code_match.group(1).strip() if code_match else response.strip()
# def _generate_initial_prompt(self, query: str) -> str:
# return f"""Generate DIRECT EXECUTION CODE (no functions, no explanations) following STRICT RULES:
# MANDATORY REQUIREMENTS:
# 1. Operate directly on existing 'df' variable
# 2. Save ALL final DataFrames to CSV using: df.to_csv(f'{{output_dir}}/descriptive_name.csv')
# 3. For visualizations: plt.savefig(f'{{output_dir}}/chart_name.png')
# 4. Use EXACTLY this structure:
# # Data processing
# df_processed = df[...] # filtering/grouping
# # Save results
# df_processed.to_csv(f'{{output_dir}}/result.csv')
# # Visualizations (if needed)
# plt.figure()
# ... plotting code ...
# plt.savefig(f'{{output_dir}}/chart.png')
# plt.close()
# FORBIDDEN:
# - Function definitions
# - Dummy data creation
# - Any code blocks besides pandas operations and matplotlib
# - Print statements showing dataframes
# DATAFRAME COLUMNS: {', '.join(self.df.columns)}
# DATAFRAME'S FIRST FIVE ROWS: {self.df.head().to_dict('records')}
# USER QUERY: {query}
# EXAMPLE RESPONSE FOR "Sales by region":
# # Data processing
# sales_by_region = df.groupby('region')['sales'].sum().reset_index()
# # Save results
# sales_by_region.to_csv(f'{{output_dir}}/sales_by_region.csv')
# """
# def _generate_retry_prompt(self, query: str, error: str, code: str) -> str:
# return f"""FIX THIS CODE (failed with: {error}) by STRICTLY FOLLOWING:
# 1. REMOVE ALL FUNCTION DEFINITIONS
# 2. ENSURE DIRECT DF OPERATIONS
# 3. USE EXPLICIT output_dir PATHS
# 4. ADD NECESSARY IMPORTS IF MISSING
# 5. VALIDATE COLUMN NAMES EXIST
# BAD CODE:
# {code}
# CORRECTED CODE:"""
# def initialize_model(self, api_keys: List[str]) -> bool:
# self.key_manager = GeminiKeyManager(api_keys)
# if not self.key_manager.configure():
# raise RuntimeError("API key initialization failed")
# try:
# self.gemini_model = GenerativeModel(MODEL_NAME)
# return True
# except Exception as e:
# logger.error(f"Model init failed: {str(e)}")
# return False
# def generate_code(self, query: str, error: Optional[str] = None, previous_code: Optional[str] = None) -> str:
# prompt = self._generate_retry_prompt(query, error, previous_code) if error else self._generate_initial_prompt(query)
# try:
# response = self.gemini_model.generate_content(prompt)
# return self._extract_code(response.text)
# except Exception as e:
# if self.key_manager.available_keys and self.key_manager.configure():
# return self.generate_code(query, error, previous_code)
# raise
# def execute_query(self, query: str) -> Dict[str, Any]:
# self.repl = PythonREPL(self.df)
# result = None
# while self.current_retry < self.max_retries:
# try:
# code = self.generate_code(query,
# result["error_message"] if result else None,
# result["code"] if result else None)
# execution_result = self.repl.execute(code)
# if execution_result["error"]:
# self.current_retry += 1
# result = {
# "error_message": execution_result["error_message"],
# "code": code
# }
# else:
# return {
# "text": execution_result["output"],
# "csv_files": execution_result["files"]["csv"],
# "image_files": execution_result["files"]["images"]
# }
# except Exception as e:
# return {
# "error": f"Critical failure: {str(e)}",
# "csv_files": [],
# "image_files": []
# }
# return {
# "error": f"Failed after {self.max_retries} retries",
# "csv_files": [],
# "image_files": []
# }
# def gemini_llm_chat(csv_url: str, query: str) -> Dict[str, Any]:
# try:
# df = pd.read_csv(csv_url)
# agent = RethinkAgent(df=df)
# if not agent.initialize_model(API_KEYS):
# return {"error": "API configuration failed"}
# result = agent.execute_query(query)
# if "error" in result:
# return result
# return {
# "message": result["text"],
# "csv_files": result["csv_files"],
# "image_files": result["image_files"]
# }
# except Exception as e:
# logger.error(f"Processing failed: {str(e)}")
# return {
# "error": f"Processing error: {str(e)}",
# "csv_files": [],
# "image_files": []
# }
# async def generate_csv_report(csv_url: str, query: str) -> FileBoxProps:
# try:
# result = gemini_llm_chat(csv_url, query)
# logger.info(f"Raw result from gemini_llm_chat: {result}")
# csv_files = []
# image_files = []
# # Check if we got the expected response structure
# if isinstance(result, dict) and 'csv_files' in result and 'image_files' in result:
# # Process CSV files
# for csv_path in result['csv_files']:
# if os.path.exists(csv_path):
# file_name = os.path.basename(csv_path)
# try:
# unique_file_name = f"{uuid.uuid4()}_{file_name}"
# public_url = await upload_file_to_supabase(
# file_path=csv_path,
# file_name=unique_file_name
# )
# csv_files.append(FileProps(
# fileName=file_name,
# filePath=public_url,
# fileType="csv"
# ))
# os.remove(csv_path) # Clean up
# except Exception as upload_error:
# logger.error(f"Failed to upload CSV {file_name}: {str(upload_error)}")
# continue
# # Process image files
# for img_path in result['image_files']:
# if os.path.exists(img_path):
# file_name = os.path.basename(img_path)
# try:
# unique_file_name = f"{uuid.uuid4()}_{file_name}"
# public_url = await upload_file_to_supabase(
# file_path=img_path,
# file_name=unique_file_name
# )
# image_files.append(FileProps(
# fileName=file_name,
# filePath=public_url,
# fileType="image"
# ))
# os.remove(img_path) # Clean up
# except Exception as upload_error:
# logger.error(f"Failed to upload image {file_name}: {str(upload_error)}")
# continue
# return FileBoxProps(
# files=Files(
# csv_files=csv_files,
# image_files=image_files
# )
# )
# else:
# raise ValueError("Unexpected response format from gemini_llm_chat")
# except Exception as e:
# logger.error(f"Report generation failed: {str(e)}")
# # Return empty response but log the files we found
# if 'csv_files' in locals() and 'image_files' in locals():
# logger.info(f"Files that were generated but not processed: CSV: {result.get('csv_files', [])}, Images: {result.get('image_files', [])}")
# return FileBoxProps(
# files=Files(
# csv_files=[],
# image_files=[]
# )
# )
# Newly Modified code with openai
# Import necessary modules
import asyncio
import os
import threading
from typing import Any, Dict, Union
import uuid
from fastapi.encoders import jsonable_encoder
from langchain_openai import ChatOpenAI
import numpy as np
import pandas as pd
from pandasai import SmartDataframe
from langchain_groq.chat_models import ChatGroq
from dotenv import load_dotenv
from pydantic import BaseModel
from csv_service import clean_data, extract_chart_filenames
from langchain_groq import ChatGroq
import pandas as pd
from langchain_experimental.tools import PythonAstREPLTool
from langchain_experimental.agents import create_pandas_dataframe_agent
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
from gemini_langchain_agent import langchain_gemini_csv_handler
from supabase_service import upload_file_to_supabase
from util_service import _prompt_generator, process_answer
import matplotlib
import logging
# Use the non-interactive Agg backend: charts are only ever saved to files.
matplotlib.use('Agg')

# Load variables from a .env file (if any) before anything reads os.environ.
load_dotenv()


def _env_list(name: str) -> list[str]:
    """Return the comma-separated entries of environment variable *name*.

    Entries are stripped of surrounding whitespace and empty entries are
    dropped, so ``"k1, k2,"`` gives ``["k1", "k2"]`` and an unset variable
    gives ``[]`` instead of raising ``AttributeError`` on ``None.split``.
    """
    return [item.strip() for item in os.getenv(name, "").split(",") if item.strip()]


image_file_path = os.getenv("IMAGE_FILE_PATH")
image_not_found = os.getenv("IMAGE_NOT_FOUND")
allowed_hosts = os.getenv("ALLOWED_HOSTS", "").split(",")

# Provider credentials and model settings.
groq_api_keys = _env_list("GROQ_API_KEYS")
model_name = os.getenv("GROQ_LLM_MODEL")
openai_api_keys = _env_list("OPENAI_API_KEYS")
# Both names hold OPENAI_BASE_URL; openai_api_base is the one the OpenAI
# helpers pass to ChatOpenAI. Read the variable once.
openai_base_url = os.getenv("OPENAI_BASE_URL")
openai_api_base = openai_base_url

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Pydantic request model carrying a single CSV location.
# NOTE(review): not referenced elsewhere in this module; presumably the request
# body of an API route defined in another file — confirm.
class CsvUrlRequest(BaseModel):
    csv_url: str  # CSV location; presumably the kind of value clean_data() accepts — confirm
# Pydantic request model carrying a single image path.
# NOTE(review): not referenced elsewhere in this module; presumably used by a
# route defined in another file — confirm.
class ImageRequest(BaseModel):
    image_path: str  # path or URL of the image — confirm which against the caller
# Pydantic request model carrying a list of CSV file locations.
# NOTE(review): judging by the name, used to look up column headers shared by
# the files; not referenced in this module — confirm against its route.
class CsvCommonHeadersRequest(BaseModel):
    file_urls: list[str]  # one entry per CSV file
# Pydantic request model describing a merge of several CSV files.
# NOTE(review): not referenced in this module; presumably consumed by a merge
# endpoint defined elsewhere — confirm.
class CsvsMergeRequest(BaseModel):
    file_urls: list[str]  # one entry per CSV file to merge
    merge_type: str  # merge strategy; possibly a pandas `how` value — confirm with the handler
    common_columns_name: list[str]  # columns to merge on
# Round-robin API-key cursors for the chat helpers. Each index selects the next
# key to use and is only read/written while holding its paired lock.
current_openai_key_index, current_openai_key_lock = 0, threading.Lock()  # openai_chat / openai_chart
current_groq_key_index, current_groq_key_lock = 0, threading.Lock()  # groq_chat
current_langchain_key_index, current_langchain_key_lock = 0, threading.Lock()  # langchain_csv_chat
# CHAT CODING STARTS FROM HERE
def handle_out_of_range_float(value):
    """Make a scalar safe for strict JSON encoding.

    NaN becomes ``None``. Positive infinity becomes ``"Infinity"`` and negative
    infinity becomes ``"-Infinity"``. Every other value is returned unchanged,
    including finite floats and non-float values.

    ``np.floating`` is accepted alongside ``float`` so NumPy scalars such as
    ``np.float32``, which do not subclass ``float``, are handled too.
    """
    if isinstance(value, (float, np.floating)):
        if np.isnan(value):
            return None
        if np.isinf(value):
            # Keep the sign: -inf must not be reported as "Infinity".
            return "Infinity" if value > 0 else "-Infinity"
    return value
# Modified groq_chat function with thread-safe key rotation
def groq_chat(csv_url: str, question: str):
    """Answer *question* about the CSV at *csv_url* with PandasAI backed by Groq.

    Keys come from ``groq_api_keys``, starting at the shared cursor
    ``current_groq_key_index``. The cursor stays on a key that works. After a
    failure it moves to the next key, wrapping around. Each call tries every
    key at most once, so transient errors cannot leave the cursor past the end
    of the list and disable this function for the rest of the process.

    Args:
        csv_url: CSV location, passed to ``clean_data``.
        question: Natural-language question forwarded to ``SmartDataframe.chat``.

    Returns:
        A JSON-safe form of the answer (NaN -> None, +/-inf -> strings):
        - a list of row dicts for a DataFrame;
        - a dict for a Series or a dict;
        - a list for a list;
        - ``{"answer": str(answer)}`` for anything else.

        ``{"error": "All API keys exhausted."}`` if no Groq keys are configured.
        ``None`` if every key fails.
    """
    global current_groq_key_index, current_groq_key_lock

    key_count = len(groq_api_keys)
    if key_count == 0:
        return {"error": "All API keys exhausted."}

    for _ in range(key_count):
        with current_groq_key_lock:
            # Normalize the cursor in case something left it out of range.
            current_groq_key_index %= key_count
            key_index = current_groq_key_index
            current_api_key = groq_api_keys[key_index]
        try:
            # Delete the on-disk cache DB (presumably PandasAI's response cache)
            # so previous answers are not reused — confirm.
            cache_db_path = "/workspace/cache/cache_db_0.11.db"
            if os.path.exists(cache_db_path):
                try:
                    os.remove(cache_db_path)
                except Exception as e:
                    logger.warning(f"Error deleting cache DB file: {e}")

            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            # Unique chart filename so concurrent requests do not collide.
            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)
            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),
                    'custom_chart_filename': chart_filename,
                },
            )
            answer = df.chat(question)

            # Sanitize element by element on plain Python containers.
            # DataFrame.apply hands whole columns to the function, so nothing
            # gets replaced. Series.apply lets pandas turn None back into NaN
            # in float columns.
            if isinstance(answer, pd.DataFrame):
                return [
                    {col: handle_out_of_range_float(val) for col, val in row.items()}
                    for row in answer.to_dict(orient="records")
                ]
            if isinstance(answer, pd.Series):
                return {k: handle_out_of_range_float(v) for k, v in answer.to_dict().items()}
            if isinstance(answer, list):
                return [handle_out_of_range_float(item) for item in answer]
            if isinstance(answer, dict):
                return {k: handle_out_of_range_float(v) for k, v in answer.items()}
            return {"answer": str(handle_out_of_range_float(answer))}
        except Exception as e:
            logger.error(f"Error with API key index {key_index}: {e}")
            with current_groq_key_lock:
                # Skip the failing key unless another thread already moved on.
                if current_groq_key_index == key_index:
                    current_groq_key_index = (key_index + 1) % key_count

    logger.error("All API keys exhausted.")
    return None
# Modified langchain_csv_chat with thread-safe key rotation
def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
    """Answer *question* with a LangChain pandas-dataframe agent running on Groq.

    The CSV is loaded once with ``clean_data``. Up to ``len(groq_api_keys)``
    attempts are made. Each attempt claims the key at the shared cursor
    ``current_langchain_key_index`` and advances the cursor (wrapping back to
    0) under ``current_langchain_key_lock``, whether or not it succeeds.

    Args:
        csv_url: CSV location passed to ``clean_data``.
        question: User question, wrapped by ``_prompt_generator``.
        chart_required: Forwarded to ``_prompt_generator``.

    Returns:
        The agent result's ``"output"`` entry, or ``None`` if every attempt raised.
    """
    global current_langchain_key_index, current_langchain_key_lock

    data = clean_data(csv_url)

    for _ in range(len(groq_api_keys)):
        # Claim a key slot atomically so concurrent callers spread across keys.
        with current_langchain_key_lock:
            if current_langchain_key_index >= len(groq_api_keys):
                current_langchain_key_index = 0
            slot = current_langchain_key_index
            api_key = groq_api_keys[slot]
            current_langchain_key_index = slot + 1

        try:
            repl_namespace = {
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
            }
            agent = create_pandas_dataframe_agent(
                ChatGroq(model=model_name, api_key=api_key),
                data,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[PythonAstREPLTool(locals=repl_namespace)],
                return_intermediate_steps=True,
            )
            result = agent.invoke({"input": _prompt_generator(question, chart_required)})
            return result.get("output")
        except Exception as e:
            print(f"Error with key index {slot}: {str(e)}")

    # If all keys are exhausted, return None
    print("All API keys have been exhausted.")
    return None
def handle_out_of_range_float(value):
    """Make a scalar safe for strict JSON encoding.

    NaN becomes ``None``. Positive infinity becomes ``"Infinity"`` and negative
    infinity becomes ``"-Infinity"``. Every other value is returned unchanged,
    including finite floats and non-float values.

    ``np.floating`` is accepted alongside ``float`` so NumPy scalars such as
    ``np.float32``, which do not subclass ``float``, are handled too.

    NOTE: this module defines ``handle_out_of_range_float`` twice. This later
    definition is the one bound at import time, so keep both copies identical.
    """
    if isinstance(value, (float, np.floating)):
        if np.isnan(value):
            return None
        if np.isinf(value):
            # Keep the sign: -inf must not be reported as "Infinity".
            return "Infinity" if value > 0 else "-Infinity"
    return value
# CHART CODING STARTS FROM HERE
# Chart-readability guidance appended verbatim to the user's question
# (``question + instructions``) by groq_chart and openai_chart before it is
# sent to SmartDataframe.chat. The text is part of the LLM prompt.
instructions = """
- Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).
- For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)
- Use colorblind-friendly palette
- Read above instructions and follow them.
"""
# Round-robin API-key cursors for the chart helpers (groq_chart / model and
# langchain_csv_chart). Each index is only touched while holding its paired lock.
current_groq_chart_key_index, current_groq_chart_lock = 0, threading.Lock()
current_langchain_chart_key_index, current_langchain_chart_lock = 0, threading.Lock()
def model():
    """Build a ChatGroq client for the key at the chart cursor.

    Reads ``current_groq_chart_key_index`` under its lock without advancing it.

    Raises:
        Exception: if the cursor points past the end of ``groq_api_keys``
            (e.g. when no keys are configured).
    """
    with current_groq_chart_lock:
        key_index = current_groq_chart_key_index
        if key_index >= len(groq_api_keys):
            raise Exception("All API keys exhausted for chart generation")
        selected_key = groq_api_keys[key_index]
    return ChatGroq(model=model_name, api_key=selected_key)
def groq_chart(csv_url: str, question: str):
    """Ask PandasAI (on Groq) to draw a chart for *question* over the CSV at *csv_url*.

    Tries each key in ``groq_api_keys`` at most once, starting at the shared
    cursor ``current_groq_chart_key_index``. Any failure advances the cursor to
    the next key, wrapping around. The previous condition
    ``"429" in error or error is not None`` was always true, so this was
    already the effective behavior. Its unreachable ``else`` branch has been
    removed, and failures are now logged instead of silently swallowed.

    Returns:
        Whatever ``SmartDataframe.chat`` returned (normally the chart path).
        ``"Chart not generated"`` when ``process_answer`` flags the answer.
        ``None`` when every key fails or no keys are configured.
    """
    global current_groq_chart_key_index, current_groq_chart_lock

    key_count = len(groq_api_keys)
    for _ in range(key_count):
        with current_groq_chart_lock:
            key_index = current_groq_chart_key_index
            current_api_key = groq_api_keys[key_index]
        try:
            # Clean cache before processing
            cache_db_path = "/workspace/cache/cache_db_0.11.db"
            if os.path.exists(cache_db_path):
                try:
                    os.remove(cache_db_path)
                except Exception as e:
                    logger.warning(f"Cache cleanup error: {e}")

            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            # Unique chart filename so concurrent requests do not collide.
            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)
            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),
                    'custom_chart_filename': chart_filename,
                },
            )
            answer = df.chat(question + instructions)
            if process_answer(answer):
                return "Chart not generated"
            return answer
        except Exception as e:
            # Rate limits (429) and other errors alike: move on to the next key.
            logger.error(f"Chart generation error with key index {key_index}: {e}")
            with current_groq_chart_lock:
                # Skip the failing key unless another thread already moved on.
                if current_groq_chart_key_index == key_index:
                    current_groq_chart_key_index = (key_index + 1) % key_count

    logger.error("All API keys exhausted for chart generation")
    return None
def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
    """Produce chart files with a LangChain pandas agent running on Groq.

    The CSV is loaded once. Up to ``len(groq_api_keys)`` attempts are made. Each
    attempt claims the key at ``current_langchain_chart_key_index`` and advances
    that cursor modulo the key count under its lock. The prompt always asks for
    a chart and includes *csv_url*. ``chart_required`` is accepted for interface
    compatibility but is not used.

    Returns:
        The chart filenames found by ``extract_chart_filenames`` in the agent's
        output, or ``None`` if no attempt produced any.
    """
    global current_langchain_chart_key_index, current_langchain_chart_lock

    data = clean_data(csv_url)
    total_keys = len(groq_api_keys)

    for attempt in range(total_keys):
        is_last_attempt = attempt == total_keys - 1
        try:
            with current_langchain_chart_lock:
                api_key = groq_api_keys[current_langchain_chart_key_index]
                current_key = current_langchain_chart_key_index
                current_langchain_chart_key_index = (current_langchain_chart_key_index + 1) % total_keys

            repl_namespace = {
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
                "uuid": uuid,
            }
            agent = create_pandas_dataframe_agent(
                ChatGroq(model=model_name, api_key=api_key),
                data,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[PythonAstREPLTool(locals=repl_namespace)],
                return_intermediate_steps=True,
            )
            prompt = _prompt_generator(f"{question} and use this csv_url: {csv_url} to read the csv file", True)
            output = agent.invoke({"input": prompt}).get("output", "")

            # The reply only counts if it names at least one saved chart file.
            chart_files = extract_chart_filenames(output)
            if len(chart_files) > 0:
                return chart_files
            if not is_last_attempt:
                print(f"Langchain chart error (key {current_key}): {output}")
        except Exception as e:
            print(f"Langchain chart error (key {current_key}): {str(e)}")

    print("All API keys exhausted for chart generation")
    return None
####################################### OpenAI + PandasAI #######################################
# Modified openai_chat function with thread-safe key rotation
# Chat model requested from the OpenAI-compatible endpoint by openai_chat / openai_chart.
openai_model_name = "gpt-4o"
def openai_chat(csv_url: str, question: str):
    """Answer *question* about the CSV at *csv_url* with PandasAI backed by OpenAI.

    Keys come from ``openai_api_keys``, starting at the cursor
    ``current_openai_key_index``, which openai_chart also uses. The cursor
    stays on a key that works. After a failure it moves to the next key,
    wrapping around. Each call tries every key at most once, and the cursor
    stays in range, so transient errors no longer disable this function (and
    openai_chart) for the rest of the process.

    Args:
        csv_url: CSV location, passed to ``clean_data``.
        question: Natural-language question forwarded to ``SmartDataframe.chat``.

    Returns:
        A JSON-safe form of the answer (NaN -> None, +/-inf -> strings):
        - a list of row dicts for a DataFrame;
        - a dict for a Series or a dict;
        - a list for a list;
        - ``{"answer": str(answer)}`` for anything else.

        ``{"error": "All API keys exhausted."}`` if no OpenAI keys are configured.
        ``None`` if every key fails.
    """
    global current_openai_key_index, current_openai_key_lock

    key_count = len(openai_api_keys)
    if key_count == 0:
        return {"error": "All API keys exhausted."}

    for _ in range(key_count):
        with current_openai_key_lock:
            # Normalize the cursor in case something left it out of range.
            current_openai_key_index %= key_count
            key_index = current_openai_key_index
            current_api_key = openai_api_keys[key_index]
        try:
            # Delete the on-disk cache DB (presumably PandasAI's response cache)
            # so previous answers are not reused — confirm.
            cache_db_path = "/workspace/cache/cache_db_0.11.db"
            if os.path.exists(cache_db_path):
                try:
                    os.remove(cache_db_path)
                except Exception as e:
                    logger.warning(f"Error deleting cache DB file: {e}")

            data = clean_data(csv_url)
            llm = ChatOpenAI(model=openai_model_name, api_key=current_api_key, base_url=openai_api_base)

            # Unique chart filename so concurrent requests do not collide.
            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)
            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),
                    'custom_chart_filename': chart_filename,
                },
            )
            answer = df.chat(question)

            # Sanitize element by element on plain Python containers.
            # DataFrame.apply hands whole columns to the function, so nothing
            # gets replaced. Series.apply lets pandas turn None back into NaN
            # in float columns.
            if isinstance(answer, pd.DataFrame):
                return [
                    {col: handle_out_of_range_float(val) for col, val in row.items()}
                    for row in answer.to_dict(orient="records")
                ]
            if isinstance(answer, pd.Series):
                return {k: handle_out_of_range_float(v) for k, v in answer.to_dict().items()}
            if isinstance(answer, list):
                return [handle_out_of_range_float(item) for item in answer]
            if isinstance(answer, dict):
                return {k: handle_out_of_range_float(v) for k, v in answer.items()}
            return {"answer": str(handle_out_of_range_float(answer))}
        except Exception as e:
            logger.error(f"Error with API key index {key_index}: {e}")
            with current_openai_key_lock:
                # Skip the failing key unless another thread already moved on.
                if current_openai_key_index == key_index:
                    current_openai_key_index = (key_index + 1) % key_count

    logger.error("All API keys exhausted.")
    return None
def openai_chart(csv_url: str, question: str):
    """Ask PandasAI (on OpenAI) to draw a chart for *question* over the CSV at *csv_url*.

    Uses the cursor ``current_openai_key_index``, which openai_chat also uses.
    Any failure advances the cursor to the next key, wrapping around.

    The original ``while True`` loop never terminated when every key failed:
    the modulo rotation kept the cursor in range forever, which hung the
    calling worker thread. Attempts are now bounded to one per key. Failures
    are always logged. The ``"429" in error or error is not None`` test was
    always true, so its dead ``else`` branch has been removed.

    Returns:
        Whatever ``SmartDataframe.chat`` returned (normally the chart path).
        ``"Chart not generated"`` when ``process_answer`` flags the answer.
        ``{"error": "All API keys exhausted."}`` when the cursor is already past
        the end of the key list (including when no keys are configured).
        ``None`` when every key fails.
    """
    global current_openai_key_index, current_openai_key_lock

    # At least one pass so an empty key list still reports exhaustion.
    for _ in range(max(len(openai_api_keys), 1)):
        with current_openai_key_lock:
            if current_openai_key_index >= len(openai_api_keys):
                return {"error": "All API keys exhausted."}
            key_index = current_openai_key_index
            current_api_key = openai_api_keys[key_index]
        try:
            # Delete cache file if exists
            cache_db_path = "/workspace/cache/cache_db_0.11.db"
            if os.path.exists(cache_db_path):
                try:
                    os.remove(cache_db_path)
                except Exception as e:
                    logger.warning(f"Error deleting cache DB file: {e}")

            data = clean_data(csv_url)
            llm = ChatOpenAI(model=openai_model_name, api_key=current_api_key, base_url=openai_api_base)

            # Unique chart filename so concurrent requests do not collide.
            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)
            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),
                    'custom_chart_filename': chart_filename,
                },
            )
            answer = df.chat(question + instructions)
            if process_answer(answer):
                return "Chart not generated"
            return answer
        except Exception as e:
            # Rate limits (429) and other errors alike: move on to the next key.
            logger.error(f"Error with API key index {key_index}: {e}")
            with current_openai_key_lock:
                # Skip the failing key unless another thread already moved it.
                if current_openai_key_index == key_index:
                    current_openai_key_index = (key_index + 1) % len(openai_api_keys)

    logger.error("All API keys exhausted for chart generation")
    return None
####################################### Start with lc_gemini #######################################
# async def csv_chat(csv_url: str, query: str):
# """
# Generate a response based on the provided CSV URL and query.
# Prioritizes LangChain-Gemini, then LangChain-Groq, then raw OpenAI and finally raw Groq as fallback.
# Parameters:
# - csv_url (str): The URL of the CSV file.
# - query (str): The query for generating the response.
# Returns:
# - dict: A dictionary containing the generated response.
# Example:
# - csv_url: "https://example.com/data.csv"
# - query: "What is the total sales for the year 2022?"
# Returns:
# - dict: {"answer": "The total sales for 2022 is $100,000."}
# """
# try:
# updated_query = f"{query} and Do not show any charts or graphs."
# # --- 1. First Attempt: LangChain Gemini ---
# try:
# gemini_answer = await asyncio.to_thread(
# langchain_gemini_csv_handler, csv_url, updated_query, False
# )
# print("LangChain-Gemini answer:", gemini_answer)
# if not process_answer(gemini_answer) or gemini_answer is None:
# return {"answer": jsonable_encoder(gemini_answer)}
# raise Exception("LangChain-Gemini response not usable, falling back to LangChain-Groq")
# except Exception as gemini_error:
# print(f"LangChain-Gemini error: {str(gemini_error)}")
# # --- 2. Second Attempt: LangChain Groq ---
# try:
# lang_groq_answer = await asyncio.to_thread(
# langchain_csv_chat, csv_url, updated_query, False
# )
# print("LangChain-Groq answer:", lang_groq_answer)
# if not process_answer(lang_groq_answer):
# return {"answer": jsonable_encoder(lang_groq_answer)}
# raise Exception("LangChain-Groq response not usable, falling back to raw Groq")
# except Exception as lang_groq_error:
# print(f"LangChain-Groq error: {str(lang_groq_error)}")
# # --- 3. Final Attempt: Raw OpenAI Chat ---
# try:
# raw_openai_answer = await asyncio.to_thread(openai_chat, csv_url, updated_query)
# print("Raw OpenAI answer:", raw_openai_answer)
# if process_answer(raw_openai_answer) == "Empty response received." or raw_openai_answer is None:
# return {"answer": "Sorry, I couldn't find relevant data..."}
# if process_answer(raw_openai_answer):
# except Exception as openai_exception:
# print(f"Raw OpenAI error: {str(openai_exception)}")
# # --- 4. Final Attempt: Raw Groq Chat ---
# try:
# raw_groq_answer = await asyncio.to_thread(groq_chat, csv_url, updated_query)
# print("Raw Groq answer:", raw_groq_answer)
# if process_answer(raw_groq_answer) == "Empty response received." or raw_groq_answer is None:
# return {"answer": "Sorry, I couldn't find relevant data..."}
# if process_answer(raw_groq_answer):
# raise Exception("All fallbacks exhausted")
# return {"answer": jsonable_encoder(raw_groq_answer)}
# except Exception as raw_groq_error:
# print(f"Raw Groq error: {str(raw_groq_error)}")
# return {"answer": "error"}
# except Exception as e:
# print(f"Unexpected error: {str(e)}")
# return {"answer": "error"}
# async def csv_chart(csv_url: str, query: str):
# """
# Generate a chart based on the provided CSV URL and query.
# Prioritizes raw OpenAI, then raw Groq, then LangChain Gemini, and finally LangChain Groq as fallback.
# Parameters:
# - csv_url (str): The URL of the CSV file.
# - query (str): The query for generating the chart.
# Returns:
# - dict: A dictionary containing either:
# - {"image_url": "https://example.com/chart.png"} on success, or
# - {"error": "error message"} on failure
# Example:
# - csv_url: "https://example.com/data.csv"
# - query: "Show sales trends as a line chart"
# Returns:
# - dict: {"image_url": "https://storage.example.com/chart_uuid.png"}
# """
# async def upload_and_return(image_path: str) -> dict:
# """Helper function to handle image uploads"""
# unique_name = f'{uuid.uuid4()}.png'
# public_url = await upload_file_to_supabase(image_path, unique_name)
# print(f"Uploaded chart: {public_url}")
# os.remove(image_path) # Remove the local image file after upload
# return {"image_url": public_url}
# try:
# # --- 1. First Attempt: Raw OpenAI ---
# try:
# openai_result = await asyncio.to_thread(openai_chart, csv_url, query)
# print(f"OpenAI chart result:", openai_result)
# if openai_result and openai_result != 'Chart not generated':
# return await upload_and_return(openai_result)
# raise Exception("OpenAI failed to generate chart")
# except Exception as openai_error:
# print(f"OpenAI failed ({str(openai_error)}), trying LangChain Gemini...")
# # --- 2.. First Attempt: Raw Groq ---
# try:
# groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
# print(f"Raw Groq chart result:", groq_result)
# if groq_result and groq_result != 'Chart not generated':
# return await upload_and_return(groq_result)
# raise Exception("Raw Groq failed to generate chart")
# except Exception as groq_error:
# print(f"Raw Groq failed ({str(groq_error)}), trying LangChain Gemini...")
# # --- 3. Second Attempt: LangChain Gemini ---
# try:
# gemini_result = await asyncio.to_thread(
# langchain_gemini_csv_handler, csv_url, query, True
# )
# print("LangChain Gemini chart result:", gemini_result)
# # --- i) If Gemini result is a string, return it ---
# if gemini_result and isinstance(gemini_result, str):
# clean_path = gemini_result.strip()
# return await upload_and_return(clean_path)
# # --- ii) If Gemini result is a list, return the first element ---
# if gemini_result and isinstance(gemini_result, list) and len(gemini_result) > 0:
# return await upload_and_return(gemini_result[0])
# raise Exception("LangChain Gemini returned empty result")
# except Exception as gemini_error:
# print(f"LangChain Gemini failed ({str(gemini_error)}), trying LangChain Groq...")
# # --- 4. Final Attempt: LangChain Groq ---
# try:
# lc_groq_paths = await asyncio.to_thread(
# langchain_csv_chart, csv_url, query, True
# )
# print("LangChain Groq chart result:", lc_groq_paths)
# if isinstance(lc_groq_paths, list) and lc_groq_paths:
# return await upload_and_return(lc_groq_paths[0])
# return {"error": "All chart generation methods failed"}
# except Exception as lc_groq_error:
# print(f"LangChain Groq failed: {str(lc_groq_error)}")
# return {"error": "Could not generate chart"}
# except Exception as e:
# print(f"Critical error: {str(e)}")
# return {"error": "Internal system error"}
####################################### Optimized Version #######################################
async def csv_chat(csv_url: str, query: str) -> Dict[str, Any]:
    """
    Generate a response based on the provided CSV URL and query.
    Prioritizes LangChain-Gemini, then LangChain-Groq, then raw OpenAI and finally raw Groq as fallback.

    Providers are tried in that order until one returns a usable answer. An
    empty result does not end the chain. Empty means ``None``, or an answer
    that ``process_answer`` reports as "Empty response received.". The chat
    helpers in this module also return ``None`` when all of their API keys
    fail, so stopping there would skip every remaining fallback.

    Parameters:
    - csv_url (str): The URL of the CSV file.
    - query (str): The query for generating the response.

    Returns:
    - dict: ``{"answer": <json-encoded answer>}`` from the first provider that
      succeeds. Otherwise ``{"answer": <fallback text>}`` if at least one
      provider came back empty, or ``{"answer": <error text>}`` if every
      provider raised or returned an unusable answer.

    Example:
    - csv_url: "https://example.com/data.csv"
    - query: "What is the total sales for the year 2022?"
    Returns:
    - dict: {"answer": "The total sales for 2022 is $100,000."}
    """
    updated_query = f"{query} and Do not show any charts or graphs."
    fallback_answer = "Sorry, I couldn't find relevant data..."
    error_answer = "An error occurred while processing your request."

    async def try_chat_method(method_name: str, method, *args) -> Dict[str, Any]:
        """Run one provider in a worker thread and classify its result.

        Returns a dict whose "status" is "success", "empty", "invalid" or "error".
        """
        try:
            logger.info(f"Attempting {method_name}")
            answer = await asyncio.to_thread(method, *args)

            if answer is None:
                logger.warning(f"{method_name} returned None")
                return {"status": "empty", "answer": None}

            processed = process_answer(answer)
            if processed == "Empty response received.":
                logger.warning(f"{method_name} returned empty response")
                return {"status": "empty", "answer": answer}
            elif processed:
                logger.warning(f"{method_name} response not usable")
                return {"status": "invalid", "answer": answer}
            else:
                logger.info(f"{method_name} succeeded")
                return {"status": "success", "answer": answer}
        except Exception as e:
            logger.error(f"{method_name} failed: {str(e)}")
            return {"status": "error", "error": str(e)}

    # Define the methods to try in priority order
    chat_methods = [
        ("LangChain-Gemini", langchain_gemini_csv_handler, csv_url, updated_query, False),
        ("LangChain-Groq", langchain_csv_chat, csv_url, updated_query, False),
        ("Raw OpenAI", openai_chat, csv_url, updated_query),
        ("Raw Groq", groq_chat, csv_url, updated_query),
    ]

    saw_empty = False
    for method_name, method, *args in chat_methods:
        result = await try_chat_method(method_name, method, *args)
        if result["status"] == "success":
            return {"answer": jsonable_encoder(result["answer"])}
        if result["status"] == "empty":
            # Not final: the next provider may still answer.
            saw_empty = True

    if saw_empty:
        logger.warning("No chat method produced an answer; returning fallback answer")
        return {"answer": fallback_answer}

    # If all methods failed or returned invalid responses
    logger.error("All chat methods failed to produce a valid response")
    return {"answer": error_answer}
async def csv_chart(csv_url: str, query: str) -> Dict[str, str]:
    """
    Generate a chart based on the provided CSV URL and query.
    Prioritizes raw OpenAI, then raw Groq, then LangChain Gemini, and finally LangChain Groq as fallback.

    Each generator runs in a worker thread. The first one that yields a chart
    path has that file uploaded to Supabase, and the local copy is deleted. If
    the upload fails, the next generator is tried.

    Parameters:
    - csv_url (str): The URL of the CSV file.
    - query (str): The query for generating the chart.

    Returns:
    - dict: either {"image_url": <public URL>} on success, or
      {"error": <message>} when no generator produced an uploadable chart.

    Example:
    - csv_url: "https://example.com/data.csv"
    - query: "Show sales trends as a line chart"
    Returns:
    - dict: {"image_url": "https://storage.example.com/chart_uuid.png"}
    """
    async def upload_and_return(image_path: str) -> Dict[str, str]:
        """Upload *image_path* under a random name, delete it locally, return its URL."""
        try:
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image file not found at {image_path}")
            public_url = await upload_file_to_supabase(image_path, f'{uuid.uuid4()}.png')
            logger.info(f"Uploaded chart: {public_url}")

            # Best-effort cleanup: a leftover local file is not an error.
            try:
                os.remove(image_path)
            except OSError as cleanup_error:
                logger.warning(f"Failed to remove local image file: {cleanup_error}")

            return {"image_url": public_url}
        except Exception as upload_error:
            logger.error(f"Error in upload_and_return: {upload_error}")
            raise

    async def try_generation(method_name: str, method, *args) -> Union[str, None]:
        """Run one generator in a thread; return a chart path, or None on any failure."""
        try:
            logger.info(f"Attempting chart generation with {method_name}")
            outcome = await asyncio.to_thread(method, *args)

            if not outcome or outcome == 'Chart not generated':
                raise ValueError(f"{method_name} returned empty or invalid result")
            if isinstance(outcome, str):
                return outcome.strip()
            if isinstance(outcome, list) and outcome:
                # List-returning generators: use the first chart path.
                return outcome[0]
            raise ValueError(f"{method_name} returned unexpected result type")
        except Exception as generation_error:
            logger.warning(f"{method_name} failed: {str(generation_error)}")
            return None

    generation_methods = [
        ("Raw OpenAI", openai_chart, csv_url, query),
        ("Raw Groq", groq_chart, csv_url, query),
        ("LangChain Gemini", lambda u, q: langchain_gemini_csv_handler(u, q, True), csv_url, query),
        ("LangChain Groq", lambda u, q: langchain_csv_chart(u, q, True), csv_url, query),
    ]

    for method_name, method, *args in generation_methods:
        try:
            chart_path = await try_generation(method_name, method, *args)
            if chart_path:
                return await upload_and_return(chart_path)
        except Exception as e:
            logger.error(f"Error processing {method_name}: {e}")

    if generation_methods:
        logger.error("All chart generation methods failed")
        return {"error": "Could not generate chart using any available method"}
    return {"error": "All chart generation methods failed"}
# Example usage:
# csv_url = './documents/titanic.csv'
# query = "Create a pie chart of male vs female passengers?"
# result = openai_chart(csv_url, query)
# print(result)