FastApi / orchestrator_functions.py
Soumik555's picture
Changed supabase query
d784ff5
# Import necessary modules
import asyncio
import logging
import os
import threading
import uuid
from fastapi.encoders import jsonable_encoder
import numpy as np
import pandas as pd
from pandasai import SmartDataframe
from langchain_groq.chat_models import ChatGroq
from dotenv import load_dotenv
from pydantic import BaseModel
from csv_service import clean_data, extract_chart_filenames
from langchain_groq import ChatGroq
import pandas as pd
from langchain_experimental.tools import PythonAstREPLTool
from langchain_experimental.agents import create_pandas_dataframe_agent
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
from gemini_langchain_agent import langchain_gemini_csv_handler
from openai_pandasai_service import openai_chart
from supabase_service import upload_file_to_supabase
from util_service import _prompt_generator, process_answer
import matplotlib
# Non-interactive backend: charts in this module are only rendered to files.
matplotlib.use('Agg')

# Load environment variables (from a .env file, if present) and read config.
load_dotenv()
image_file_path = os.getenv("IMAGE_FILE_PATH")
image_not_found = os.getenv("IMAGE_NOT_FOUND")
allowed_hosts = os.getenv("ALLOWED_HOSTS", "").split(",")

# Comma-separated Groq API keys. Whitespace around each key is stripped and
# empty entries (e.g. from a trailing comma) are dropped. A missing variable
# yields an empty list instead of crashing at import time with
# "'NoneType' object has no attribute 'split'"; the Groq helpers below then
# return their "exhausted" result and the orchestrators fall back to other
# providers.
groq_api_keys = [
    key.strip()
    for key in os.getenv("GROQ_API_KEYS", "").split(",")
    if key.strip()
]
if not groq_api_keys:
    # getLogger(...).warning does not configure the root logger, so the
    # logging.basicConfig call further down in this module still takes effect.
    logging.getLogger(__name__).warning(
        "GROQ_API_KEYS is not set or empty; Groq-based helpers are disabled."
    )
model_name = os.getenv("GROQ_LLM_MODEL")
class CsvUrlRequest(BaseModel):
    """Request body that carries the URL of a single CSV file."""

    csv_url: str  # where the CSV file can be fetched from
class ImageRequest(BaseModel):
    """Request body that refers to an image by its path."""

    image_path: str  # path of the image being requested
class CsvCommonHeadersRequest(BaseModel):
    """Request body listing the URLs of several CSV files.

    NOTE(review): judging by the name, the endpoint using this model looks for
    the column headers the files share; that endpoint is not in this module.
    """

    file_urls: list[str]  # one URL per CSV file
class CsvsMergeRequest(BaseModel):
    """Request body for merging several CSV files.

    NOTE(review): the accepted values of ``merge_type`` are not defined in this
    module; confirm them against the merge endpoint.
    """

    file_urls: list[str]  # URLs of the CSV files to merge
    merge_type: str  # kind of merge requested (values validated elsewhere, if at all)
    common_columns_name: list[str]  # presumably the column(s) to merge on; confirm
# Shared key-rotation state for groq_chat: the position in groq_api_keys to
# use next, plus the lock to hold while reading or changing that position.
current_groq_key_index, current_groq_key_lock = 0, threading.Lock()

# Independent key-rotation state (position + lock) for langchain_csv_chat.
current_langchain_key_index, current_langchain_key_lock = 0, threading.Lock()

# CHAT CODING STARTS FROM HERE
# NOTE: handle_out_of_range_float is defined later in this module (in the chart
# section). A second, identical definition used to sit here. The later `def`
# rebinds the name at import time, so this copy was dead code and edits to it
# had no effect. groq_chat looks the name up when it is called, after the
# module has finished importing, so it uses the later definition.
# Set up logging: INFO-level root configuration (a no-op if the root logger
# already has handlers) and a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
def _to_json_safe(answer):
    """
    Convert a PandasAI answer into plain, JSON-serializable Python data.

    Out-of-range floats (NaN/inf) are replaced via handle_out_of_range_float.
    DataFrames and Series are converted with ``to_dict`` first and then
    sanitized value by value. Sanitizing inside pandas does not work:
    ``DataFrame.apply`` hands whole columns (not cells) to the function, and
    pandas may coerce ``None`` results in a float column back to NaN.

    Returns:
    - list[dict] for a DataFrame (one dict per row), dict for a Series or
      dict, list for a list, and {"answer": str(...)} for anything else.
    """
    if isinstance(answer, pd.DataFrame):
        return [
            {column: handle_out_of_range_float(cell) for column, cell in row.items()}
            for row in answer.to_dict(orient="records")
        ]
    if isinstance(answer, pd.Series):
        return {
            key: handle_out_of_range_float(value)
            for key, value in answer.to_dict().items()
        }
    if isinstance(answer, list):
        return [handle_out_of_range_float(item) for item in answer]
    if isinstance(answer, dict):
        return {key: handle_out_of_range_float(value) for key, value in answer.items()}
    return {"answer": str(handle_out_of_range_float(answer))}


# Modified groq_chat function with thread-safe key rotation
def groq_chat(csv_url: str, question: str):
    """
    Answer a question about a CSV file with PandasAI backed by a Groq LLM.

    Keys come from groq_api_keys, starting at the shared
    current_groq_key_index. When an attempt fails with an error whose message
    contains "429" (rate limit), the shared index is advanced and the
    question is retried with the next key. The index is never reset, so once
    every key has been rate limited this function keeps reporting exhaustion
    until the process restarts.

    Parameters:
    - csv_url (str): The URL of the CSV file, loaded with clean_data.
    - question (str): The question passed to SmartDataframe.chat.

    Returns:
    - The JSON-safe answer produced by _to_json_safe on success.
    - {"error": "All API keys exhausted."} if no key is left when an attempt starts.
    - None on any non-rate-limit error, or when the last key gets rate limited.
    """
    global current_groq_key_index
    while True:
        with current_groq_key_lock:
            if current_groq_key_index >= len(groq_api_keys):
                return {"error": "All API keys exhausted."}
            key_index = current_groq_key_index
            current_api_key = groq_api_keys[key_index]
        try:
            # Delete PandasAI's on-disk cache DB if present (presumably so an
            # answer cached for an earlier request is not returned again).
            cache_db_path = "/workspace/cache/cache_db_0.11.db"
            if os.path.exists(cache_db_path):
                try:
                    os.remove(cache_db_path)
                except Exception as e:
                    logger.info(f"Error deleting cache DB file: {e}")
            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)
            # UUID-based name so charts from different requests do not collide.
            chart_filename = f"chart_{uuid.uuid4()}.png"
            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': "generated_charts",
                    'custom_chart_filename': chart_filename,
                },
            )
            return _to_json_safe(df.chat(question))
        except Exception as e:
            error_message = str(e)
            if "429" not in error_message:
                logger.info(f"Error with API key index {key_index}: {error_message}")
                return None
            with current_groq_key_lock:
                # Advance only if no other thread has already moved past the
                # key that just failed. Otherwise two threads rate limited on
                # the same key would each advance the index and skip a healthy key.
                if current_groq_key_index == key_index:
                    current_groq_key_index += 1
                exhausted = current_groq_key_index >= len(groq_api_keys)
            if exhausted:
                logger.info("All API keys exhausted.")
                return None
# Modified langchain_csv_chat with thread-safe key rotation
def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
    """
    Answer a question about a CSV file with a LangChain pandas agent on Groq.

    The CSV is loaded once with clean_data; errors there propagate to the
    caller. Then up to len(groq_api_keys) attempts are made. Each attempt
    claims the next key round-robin from the shared
    current_langchain_key_index, so consecutive calls spread over all keys.
    Any exception during an attempt is logged and the next key is tried.

    Parameters:
    - csv_url (str): The URL of the CSV file.
    - question (str): The user's question, wrapped by _prompt_generator.
    - chart_required (bool): Forwarded to _prompt_generator.

    Returns:
    - The agent result's "output" value from the first attempt that does not
      raise, or None if every attempt failed or no keys are configured.
    """
    global current_langchain_key_index
    data = clean_data(csv_url)
    for _ in range(len(groq_api_keys)):
        # Claim a key: wrap the shared cursor back to 0 once it has run past
        # the last key, then leave it pointing at the key after ours.
        with current_langchain_key_lock:
            slot = current_langchain_key_index if current_langchain_key_index < len(groq_api_keys) else 0
            api_key = groq_api_keys[slot]
            current_langchain_key_index = slot + 1
        try:
            # Built fresh on every attempt so state left behind by one attempt
            # cannot carry over into the next.
            repl_namespace = dict(df=data, pd=pd, np=np, plt=plt, sns=sns, matplotlib=matplotlib)
            agent = create_pandas_dataframe_agent(
                ChatGroq(model=model_name, api_key=api_key),
                data,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[PythonAstREPLTool(locals=repl_namespace)],
                return_intermediate_steps=True,
            )
            result = agent.invoke({"input": _prompt_generator(question, chart_required)})
            return result.get("output")
        except Exception as e:
            logger.info(f"Error with key index {slot}: {str(e)}")
    # Every attempt raised (or there were no keys to try).
    logger.info("All API keys have been exhausted.")
    return None
def handle_out_of_range_float(value):
    """
    Make a single value safe for strict JSON serialization.

    Strict JSON has no NaN or infinity, so:
    - NaN becomes None;
    - +inf becomes "Infinity" and -inf becomes "-Infinity";
    - everything else (finite floats, non-float values) is returned unchanged.

    NumPy float scalars of any precision are handled too. This includes
    np.float32, which, unlike np.float64, is not a subclass of float.
    """
    if isinstance(value, (float, np.floating)):
        if np.isnan(value):
            return None
        if np.isinf(value):
            return "Infinity" if value > 0 else "-Infinity"
    return value
# CHART CODING STARTS FROM HERE
# Extra guidance that groq_chart appends to every chart question
# (`question + instructions`). This is runtime prompt text: every character,
# including the leading and trailing newline, is sent to the LLM.
instructions = (
    "\n"
    "- Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).\n"
    "- For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)\n"
    "- Use colorblind-friendly palette\n"
    "- Read above instructions and follow them.\n"
)
# Key-rotation state for the chart helpers. Each index is a position in
# groq_api_keys and should only be read or changed while holding its lock.
current_groq_chart_key_index, current_groq_chart_lock = 0, threading.Lock()  # model(), groq_chart
current_langchain_chart_key_index, current_langchain_chart_lock = 0, threading.Lock()  # langchain_csv_chart
def model():
    """
    Build a ChatGroq client for the key currently selected for charts.

    Reads current_groq_chart_key_index under its lock; the index is not
    advanced here.

    Raises:
    - Exception("All API keys exhausted for chart generation") when the index
      is not a valid position in groq_api_keys (e.g. no keys are configured).
    """
    with current_groq_chart_lock:
        position = current_groq_chart_key_index
        if position < len(groq_api_keys):
            api_key = groq_api_keys[position]
        else:
            raise Exception("All API keys exhausted for chart generation")
    return ChatGroq(model=model_name, api_key=api_key)
def groq_chart(csv_url: str, question: str):
    """
    Generate a chart for a CSV file with PandasAI backed by a Groq LLM.

    Makes up to len(groq_api_keys) attempts with the key at the shared
    current_groq_chart_key_index. On an error whose message contains "429"
    (rate limit), the index is rotated (modulo the number of keys) and the
    next attempt starts. Any other error ends the call immediately.

    Parameters:
    - csv_url (str): The URL of the CSV file, loaded with clean_data.
    - question (str): The chart request; `instructions` is appended to it.

    Returns:
    - Whatever SmartDataframe.chat returned (presumably the saved chart's
      path, since csv_chart uploads it as a file), unless process_answer
      flags it, in which case "Chart not generated".
    - {"error": <message>} on a non-rate-limit error.
    - None when every attempt was rate limited.
    """
    global current_groq_chart_key_index
    for _ in range(len(groq_api_keys)):
        try:
            # Drop PandasAI's cache DB before each attempt; failing to delete
            # it is only logged.
            cache_db_path = "/workspace/cache/cache_db_0.11.db"
            if os.path.exists(cache_db_path):
                try:
                    os.remove(cache_db_path)
                except Exception as e:
                    logger.info(f"Cache cleanup error: {e}")
            data = clean_data(csv_url)
            with current_groq_chart_lock:
                current_api_key = groq_api_keys[current_groq_chart_key_index]
            smart_df = SmartDataframe(
                data,
                config={
                    'llm': ChatGroq(model=model_name, api_key=current_api_key),
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': "generated_charts",
                    # UUID-based name so charts from different requests do not collide.
                    'custom_chart_filename': f"chart_{uuid.uuid4()}.png",
                },
            )
            answer = smart_df.chat(question + instructions)
            return "Chart not generated" if process_answer(answer) else answer
        except Exception as e:
            error = str(e)
            if "429" not in error:
                logger.info(f"Chart generation error: {error}")
                return {"error": error}
            with current_groq_chart_lock:
                current_groq_chart_key_index = (current_groq_chart_key_index + 1) % len(groq_api_keys)
    logger.info("All API keys exhausted for chart generation")
    return None
def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
    """
    Generate chart file(s) for a CSV file with a LangChain pandas agent on Groq.

    The CSV is loaded once. Then up to len(groq_api_keys) attempts are made,
    each with the next key round-robin from the shared
    current_langchain_chart_key_index. An attempt succeeds as soon as
    extract_chart_filenames finds at least one chart file name in the agent's
    output. Exceptions are logged and the next key is tried.

    Parameters:
    - csv_url (str): The URL of the CSV file; it is also embedded in the prompt.
    - question (str): What to chart.
    - chart_required (bool): Currently unused; True is always passed to
      _prompt_generator.

    Returns:
    - The non-empty result of extract_chart_filenames (csv_chart uses its
      first element as a file path), or None if no attempt produced a chart.
    """
    global current_langchain_chart_key_index
    data = clean_data(csv_url)
    key_count = len(groq_api_keys)
    for attempt in range(key_count):
        try:
            # Claim the current key and move the shared cursor to the next one.
            with current_langchain_chart_lock:
                api_key = groq_api_keys[current_langchain_chart_key_index]
                current_key = current_langchain_chart_key_index
                current_langchain_chart_key_index = (current_key + 1) % key_count
            llm = ChatGroq(model=model_name, api_key=api_key)
            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[
                    PythonAstREPLTool(
                        locals=dict(
                            df=data, pd=pd, np=np, plt=plt, sns=sns,
                            matplotlib=matplotlib, uuid=uuid,
                        )
                    )
                ],
                return_intermediate_steps=True,
            )
            prompt = _prompt_generator(f"{question} and use this csv_url: {csv_url} to read the csv file", True)
            output = agent.invoke({"input": prompt}).get("output", "")
            chart_files = extract_chart_filenames(output)
            if len(chart_files) > 0:
                return chart_files
            # No chart this time; the final attempt's output is not logged.
            if attempt < key_count - 1:
                logger.info(f"Langchain chart error (key {current_key}): {output}")
        except Exception as e:
            logger.info(f"Langchain chart error (key {current_key}): {str(e)}")
    logger.info("All API keys exhausted for chart generation")
    return None
###########################################################################################################################
# async def csv_chart(csv_url: str, query: str):
# """
# Generate a chart based on the provided CSV URL and query.
# Parameters:
# - csv_url (str): The URL of the CSV file.
# - query (str): The query for generating the chart.
# Returns:
# - dict: A dictionary containing the generated chart image URL.
# Example:
# - csv_url: "https://example.com/data.csv"
# - query: "Generate a bar chart showing sales by region."
# Returns:
# - dict: {"image_url": "https://example.com/chart.png"}.
# """
# try:
# # First try Groq-based chart generation
# try:
# groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
# logger.info(f"Generated Chart (Groq): {groq_result}")
# if groq_result != 'Chart not generated':
# unique_file_name = f'{str(uuid.uuid4())}.png'
# image_public_url = await upload_file_to_supabase(groq_result, unique_file_name)
# logger.info(f"Image uploaded to Supabase: {image_public_url}")
# return {"image_url": image_public_url}
# except Exception as groq_error:
# logger.info(f"Groq chart generation failed, falling back to Langchain: {str(groq_error)}")
# # Fallback to Langchain if Groq fails
# try:
# langchain_paths = await asyncio.to_thread(langchain_csv_chart, csv_url, query, True)
# logger.info("Fallback langchain chart result:", langchain_paths)
# if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
# unique_file_name = f'{str(uuid.uuid4())}.png'
# logger.info("Uploading the chart to supabase...")
# image_public_url = await upload_file_to_supabase(langchain_paths[0], unique_file_name)
# logger.info("Image uploaded to Supabase and Image URL is... ", image_public_url)
# return {"image_url": image_public_url}
# except Exception as langchain_error:
# logger.info(f"Langchain chart generation also failed: {str(langchain_error)}")
# try:
# # Last resort: Try with the gemini langchain agent
# logger.info("Trying with the gemini langchain agent...")
# lc_gemini_chart_result = await asyncio.to_thread(langchain_gemini_csv_handler, csv_url, query, True)
# if lc_gemini_chart_result is not None:
# clean_path = lc_gemini_chart_result.strip()
# unique_file_name = f'{str(uuid.uuid4())}.png'
# logger.info("Uploading the chart to supabase...")
# image_public_url = await upload_file_to_supabase(clean_path, unique_file_name)
# logger.info("Image uploaded to Supabase and Image URL is... ", image_public_url)
# return {"image_url": image_public_url}
# except Exception as gemini_error:
# logger.info(f"Gemini Langchain chart generation also failed: {str(gemini_error)}")
# # If both methods fail
# return {"error": "Could not generate the chart, please try again."}
# except Exception as e:
# logger.info(f"Critical chart error: {str(e)}")
# return {"error": "Internal system error"}
# async def csv_chat(csv_url: str, query: str):
# """
# Generate a response based on the provided CSV URL and query.
# Parameters:
# - csv_url (str): The URL of the CSV file.
# - query (str): The query for generating the response.
# Returns:
# - dict: A dictionary containing the generated response.
# Example:
# - csv_url: "https://example.com/data.csv"
# - query: "What is the total sales for the year 2022?"
# Returns:
# - dict: {"answer": "The total sales for 2022 is $100,000."}.
# """
# try:
# updated_query = f"{query} and Do not show any charts or graphs."
# # Process with Groq first
# try:
# groq_answer = await asyncio.to_thread(groq_chat, csv_url, updated_query)
# logger.info("groq_answer:", groq_answer)
# if process_answer(groq_answer) == "Empty response received." or groq_answer == None:
# return {"answer": "Sorry, I couldn't find relevant data..."}
# if process_answer(groq_answer) or groq_answer == None:
# raise Exception("Groq response not usable, falling back to LangChain")
# return {"answer": jsonable_encoder(groq_answer)}
# except Exception as groq_error:
# logger.info(f"Groq error, falling back to LangChain: {str(groq_error)}")
# # Process with LangChain if Groq fails
# try:
# lang_answer = await asyncio.to_thread(
# langchain_csv_chat, csv_url, query, False
# )
# if not process_answer(lang_answer):
# return {"answer": jsonable_encoder(lang_answer)}
# return {"answer": "Sorry, I couldn't find relevant data..."}
# except Exception as langchain_error:
# logger.info(f"LangChain processing error: {str(langchain_error)}")
# # last resort: Try with the gemini langchain agent
# try:
# gemini_answer = await asyncio.to_thread(
# langchain_gemini_csv_handler, csv_url, query, False
# )
# if not process_answer(gemini_answer):
# return {"answer": jsonable_encoder(gemini_answer)}
# return {"answer": "Sorry, I couldn't find relevant data..."}
# except Exception as gemini_error:
# logger.info(f"Gemini Langchain processing error: {str(gemini_error)}")
# return {"answer": "error"}
# except Exception as e:
# logger.info(f"Error processing request: {str(e)}")
# return {"answer": "error"}
####################################### Active implementations (csv_chat: LangChain-Groq first; csv_chart: OpenAI first) #######################################
async def csv_chat(csv_url: str, query: str):
    """
    Answer a question about a CSV file, trying several backends in turn.

    Order: LangChain-Groq, then raw Groq (PandasAI), then LangChain-Gemini.
    Each backend runs in a worker thread via asyncio.to_thread, and the first
    usable answer is returned.

    Parameters:
    - csv_url (str): The URL of the CSV file.
    - query (str): The question. " and Do not show any charts or graphs." is
      appended before it is passed to every backend.

    Returns:
    - dict: {"answer": <JSON-encodable answer>} on success,
      {"answer": "Sorry, I couldn't find relevant data..."} when every backend
      fails, or {"answer": "error"} on an unexpected error.

    Example:
    - csv_url: "https://example.com/data.csv"
    - query: "What is the total sales for the year 2022?"
    Returns:
    - dict: {"answer": "The total sales for 2022 is $100,000."}
    """
    try:
        updated_query = f"{query} and Do not show any charts or graphs."
        # --- 1. First Attempt: LangChain Groq ---
        try:
            lang_groq_answer = await asyncio.to_thread(
                langchain_csv_chat, csv_url, updated_query, False
            )
            # Values go in as %-style args. logging formats `msg % args`, so
            # an extra arg without a placeholder fails to format and is never
            # logged.
            logger.info("LangChain-Groq answer: %s", lang_groq_answer)
            if lang_groq_answer is not None:
                return {"answer": jsonable_encoder(lang_groq_answer)}
            raise Exception("LangChain-Groq response not usable, falling back to raw Groq")
        except Exception as lang_groq_error:
            logger.info(f"LangChain-Groq error: {str(lang_groq_error)}")
        # --- 2. Second Attempt: Raw Groq Chat ---
        try:
            raw_groq_answer = await asyncio.to_thread(groq_chat, csv_url, updated_query)
            logger.info("Raw Groq answer: %s", raw_groq_answer)
            # groq_chat signals failure with None or with an error-only dict
            # such as {"error": "All API keys exhausted."}. Neither may be
            # shown to the user as an answer.
            groq_failed = raw_groq_answer is None or (
                isinstance(raw_groq_answer, dict) and set(raw_groq_answer) == {"error"}
            )
            # A truthy process_answer result (this covers
            # "Empty response received.") marks the answer as unusable.
            if groq_failed or process_answer(raw_groq_answer):
                raise Exception("Raw Groq response not usable, falling back to LangChain-Gemini")
            return {"answer": jsonable_encoder(raw_groq_answer)}
        except Exception as raw_groq_error:
            logger.info(f"Raw Groq error: {str(raw_groq_error)}")
        # --- 3. Final Attempt: LangChain Gemini ---
        try:
            gemini_answer = await asyncio.to_thread(
                langchain_gemini_csv_handler, csv_url, updated_query, False
            )
            logger.info("LangChain-Gemini answer: %s", gemini_answer)
            if gemini_answer is not None:
                return {"answer": jsonable_encoder(gemini_answer)}
            raise Exception("All fallbacks exhausted")
        except Exception as gemini_error:
            logger.info(f"LangChain-Gemini error: {str(gemini_error)}")
            return {"answer": "Sorry, I couldn't find relevant data..."}
    except Exception as e:
        logger.info(f"Unexpected error: {str(e)}")
        return {"answer": "error"}
async def csv_chart(csv_url: str, query: str, chat_id: str):
    """
    Generate a chart for a CSV file and upload it, trying several backends.

    Order: OpenAI, then raw Groq (PandasAI), then LangChain Gemini, and
    finally LangChain Groq. The first chart file produced is uploaded with
    upload_file_to_supabase, and the local file is deleted.

    Parameters:
    - csv_url (str): The URL of the CSV file.
    - query (str): The query for generating the chart.
    - chat_id (str): Passed through to upload_file_to_supabase.

    Returns:
    - dict: A dictionary containing either:
        - {"image_url": "https://example.com/chart.png"} on success, or
        - {"error": "error message"} on failure

    Example:
    - csv_url: "https://example.com/data.csv"
    - query: "Show sales trends as a line chart"
    Returns:
    - dict: {"image_url": "https://storage.example.com/chart_uuid.png"}
    """
    async def upload_and_return(image_path: str, chat_id: str) -> dict:
        """
        Upload a local chart image, then delete the local file.

        Upload errors propagate, so the caller moves on to the next backend.
        The local file is removed whether or not the upload succeeded, so
        failed attempts do not leave orphaned images. A failure to remove it
        is only logged: raising would discard an upload that already worked
        and needlessly trigger the next backend.
        """
        unique_name = f'{uuid.uuid4()}.png'
        try:
            public_url = await upload_file_to_supabase(image_path, unique_name, chat_id)
        finally:
            try:
                os.remove(image_path)
            except OSError as cleanup_error:
                logger.info(f"Could not remove local chart {image_path}: {cleanup_error}")
        logger.info(f"Uploaded chart: {public_url}")
        return {"image_url": public_url}

    try:
        try:
            # --- 1. First Attempt: OpenAI ---
            openai_result = await asyncio.to_thread(openai_chart, csv_url, query)
            # Values go in as %-style args. An extra arg without a placeholder
            # makes logging fail to format the record.
            logger.info("OpenAI chart result: %s", openai_result)
            if openai_result and openai_result != 'Chart not generated':
                return await upload_and_return(openai_result, chat_id)
            raise Exception("OpenAI failed to generate chart")
        except Exception as openai_error:
            logger.info(f"OpenAI failed ({str(openai_error)}), trying raw Groq...")
        # --- 2. Second Attempt: Raw Groq ---
        try:
            groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
            logger.info("Raw Groq chart result: %s", groq_result)
            # groq_chart reports failures as {"error": ...} or None; only a
            # non-empty string (the chart path) is worth uploading.
            if isinstance(groq_result, str) and groq_result and groq_result != 'Chart not generated':
                return await upload_and_return(groq_result, chat_id)
            raise Exception("Raw Groq failed to generate chart")
        except Exception as groq_error:
            logger.info(f"Raw Groq failed ({str(groq_error)}), trying LangChain Gemini...")
        # --- 3. Third Attempt: LangChain Gemini ---
        try:
            gemini_result = await asyncio.to_thread(
                langchain_gemini_csv_handler, csv_url, query, True
            )
            logger.info("LangChain Gemini chart result: %s", gemini_result)
            # --- i) If Gemini result is a string, upload it as a path ---
            if gemini_result and isinstance(gemini_result, str):
                clean_path = gemini_result.strip()
                return await upload_and_return(clean_path, chat_id)
            # --- ii) If Gemini result is a list, upload its first element ---
            if gemini_result and isinstance(gemini_result, list) and len(gemini_result) > 0:
                return await upload_and_return(gemini_result[0], chat_id)
            raise Exception("LangChain Gemini returned empty result")
        except Exception as gemini_error:
            logger.info(f"LangChain Gemini failed ({str(gemini_error)}), trying LangChain Groq...")
        # --- 4. Final Attempt: LangChain Groq ---
        try:
            lc_groq_paths = await asyncio.to_thread(
                langchain_csv_chart, csv_url, query, True
            )
            logger.info("LangChain Groq chart result: %s", lc_groq_paths)
            if isinstance(lc_groq_paths, list) and lc_groq_paths:
                return await upload_and_return(lc_groq_paths[0], chat_id)
            return {"error": "All chart generation methods failed"}
        except Exception as lc_groq_error:
            logger.info(f"LangChain Groq failed: {str(lc_groq_error)}")
            return {"error": "Could not generate chart"}
    except Exception as e:
        logger.info(f"Critical error: {str(e)}")
        return {"error": "Internal system error"}