# FastApi/controller.py
# Import necessary modules
import asyncio
import logging
import os
import threading
import uuid
from concurrent.futures import ProcessPoolExecutor
from typing import Dict, List
from urllib.parse import unquote

import matplotlib
matplotlib.use('Agg')  # Non-interactive backend so charts can render outside the main thread
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from dotenv import load_dotenv
from fastapi import FastAPI, Header, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from langchain_experimental.agents import create_pandas_dataframe_agent
from langchain_experimental.tools import PythonAstREPLTool
from langchain_groq import ChatGroq
from pandasai import SmartDataframe
from pydantic import BaseModel

from csv_service import clean_data, extract_chart_filenames, generate_csv_data, get_csv_basic_info
from gemini_report_generator import generate_csv_report
from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
from orchestrator_agent import csv_orchestrator_chat
from supabase_service import upload_file_to_supabase
from util_service import _prompt_generator, process_answer
# Initialize FastAPI app
app = FastAPI()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Record the available CPUs (used later to size the ProcessPoolExecutor)
max_cpus = os.cpu_count()
logger.info(f"Max CPUs: {max_cpus}")

# Ensure the working directories and the pandasai log file exist
os.makedirs("/app/cache", exist_ok=True)
os.makedirs("/app", exist_ok=True)
open("/app/pandasai.log", "a").close()  # Create the file if it doesn't exist

# Ensure the generated_charts directory exists
os.makedirs("/app/generated_charts", exist_ok=True)
load_dotenv()

image_file_path = os.getenv("IMAGE_FILE_PATH")
image_not_found = os.getenv("IMAGE_NOT_FOUND")
allowed_hosts = os.getenv("ALLOWED_HOSTS", "").split(",")

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_hosts,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the Groq API keys and model name from the environment
groq_api_keys = os.getenv("GROQ_API_KEYS", "").split(",")
model_name = os.getenv("GROQ_LLM_MODEL")
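# A minimal sketch of the .env entries this module assumes (variable names are taken
# from the os.getenv calls in this file; the values below are placeholders only):
#
#   GROQ_API_KEYS=key_1,key_2,key_3
#   GROQ_LLM_MODEL=<any Groq-hosted model name>
#   ALLOWED_HOSTS=http://localhost:3000
#   AUTH_TOKEN=<shared bearer token checked by the endpoints>
#   IMAGE_FILE_PATH=/app/generated_charts
#   IMAGE_NOT_FOUND=<fallback image path>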
class CsvUrlRequest(BaseModel):
    csv_url: str


class ImageRequest(BaseModel):
    image_path: str
    chat_id: str


class FileProps(BaseModel):
    fileName: str
    filePath: str
    fileType: str  # 'csv' | 'image'


class Files(BaseModel):
    csv_files: List[FileProps]
    image_files: List[FileProps]


class FileBoxProps(BaseModel):
    files: Files
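# For reference, FileBoxProps serializes to JSON shaped like (illustrative values only):
# {
#   "files": {
#     "csv_files":   [{"fileName": "...", "filePath": "...", "fileType": "csv"}],
#     "image_files": [{"fileName": "...", "filePath": "...", "fileType": "image"}]
#   }
# }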
dummy_response = FileBoxProps(
    files=Files(
        csv_files=[
            FileProps(
                fileName="sales_data.csv",
                filePath="/downloads/sales_data.csv",
                fileType="csv"
            ),
            FileProps(
                fileName="customer_data.csv",
                filePath="/downloads/customer_data.csv",
                fileType="csv"
            )
        ],
        image_files=[
            FileProps(
                fileName="chart.png",
                filePath="/downloads/chart.png",
                fileType="image"
            ),
            FileProps(
                fileName="graph.jpg",
                filePath="/downloads/graph.jpg",
                fileType="image"
            )
        ]
    )
)
# Thread-safe key management for groq_chat
current_groq_key_index = 0
current_groq_key_lock = threading.Lock()
# Thread-safe key management for langchain_csv_chat
current_langchain_key_index = 0
current_langchain_key_lock = threading.Lock()
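# The index/lock pairs above implement a simple thread-safe round-robin over
# groq_api_keys. A minimal sketch of the same pattern as a standalone helper
# (illustrative only; next_groq_key is a hypothetical name, not used below):
#
#   def next_groq_key() -> str:
#       global current_groq_key_index
#       with current_groq_key_lock:
#           key = groq_api_keys[current_groq_key_index % len(groq_api_keys)]
#           current_groq_key_index += 1
#       return key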
# PING CHECK
@app.get("/ping")
async def root():
    return {"message": "Pong !!"}
# BASIC KNOWLEDGE BASED ON CSV
# Keep the URL free of a trailing slash, otherwise FastAPI redirects the request to GET
@app.post("/api/basic_csv_data")
async def basic_csv_data(request: CsvUrlRequest):
    try:
        decoded_url = unquote(request.csv_url)
        logger.info(f"Fetching CSV data from URL: {decoded_url}")
        # Run the synchronous function in the process pool so the event loop stays free
        loop = asyncio.get_running_loop()
        csv_data = await loop.run_in_executor(
            process_executor, get_csv_basic_info, decoded_url
        )
        logger.info(f"CSV data fetched successfully: {csv_data}")
        return {"data": csv_data}
    except Exception as e:
        logger.error(f"Error while fetching CSV data: {e}")
        raise HTTPException(status_code=400, detail=f"Failed to retrieve CSV data: {str(e)}")
# GET THE CHART FROM A SPECIFIC FILE PATH
@app.post("/api/get-chart")
async def get_image(request: ImageRequest, authorization: str = Header(None)):
    if not authorization:
        raise HTTPException(status_code=401, detail="Authorization header missing")
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid authorization header format")
    token = authorization.split(" ")[1]
    if not token:
        raise HTTPException(status_code=401, detail="Token missing")
    if token != os.getenv("AUTH_TOKEN"):
        raise HTTPException(status_code=403, detail="Invalid token")

    try:
        image_file_path = request.image_path
        unique_file_name = f'{str(uuid.uuid4())}.png'
        logger.info("Uploading the chart to Supabase...")
        image_public_url = await upload_file_to_supabase(image_file_path, unique_file_name, chat_id=request.chat_id)
        logger.info(f"Image uploaded to Supabase, public URL: {image_public_url}")
        os.remove(image_file_path)
        return {"image_url": image_public_url}
        # return FileResponse(image_file_path, media_type="image/png")
    except Exception as e:
        logger.error(f"Error: {e}")
        return {"answer": "error"}
# GET CSV DATA FOR GENERATING THE TABLE
@app.post("/api/csv_data")
async def get_csv_data(request: CsvUrlRequest):
    try:
        decoded_url = unquote(request.csv_url)
        logger.info(f"Fetching CSV data from URL: {decoded_url}")
        loop = asyncio.get_running_loop()
        csv_data = await loop.run_in_executor(
            process_executor, generate_csv_data, decoded_url
        )
        return csv_data
    except Exception as e:
        logger.error(f"Error while fetching CSV data: {e}")
        raise HTTPException(status_code=400, detail=f"Failed to retrieve CSV data: {str(e)}")
# CHAT CODING STARTS FROM HERE

# groq_chat with thread-safe key rotation
def groq_chat(csv_url: str, question: str):
    global current_groq_key_index
    while True:
        with current_groq_key_lock:
            if current_groq_key_index >= len(groq_api_keys):
                return {"error": "All API keys exhausted."}
            current_api_key = groq_api_keys[current_groq_key_index]
        try:
            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            # Generate a unique filename using UUID
            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            # Configure SmartDataframe with chart settings
            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,                              # Enable chart saving
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),  # Directory to save into
                    'custom_chart_filename': chart_filename,          # Unique filename
                    'enable_cache': False
                }
            )
            answer = df.chat(question)

            # Normalise the different response types into JSON-serialisable structures
            if isinstance(answer, pd.DataFrame):
                processed = answer.apply(handle_out_of_range_float).to_dict(orient="records")
            elif isinstance(answer, pd.Series):
                processed = answer.apply(handle_out_of_range_float).to_dict()
            elif isinstance(answer, list):
                processed = [handle_out_of_range_float(item) for item in answer]
            elif isinstance(answer, dict):
                processed = {k: handle_out_of_range_float(v) for k, v in answer.items()}
            else:
                processed = {"answer": str(handle_out_of_range_float(answer))}
            return processed
        except Exception as e:
            error_message = str(e)
            if error_message != "":
                # Any failure (not only a 429) moves this worker on to the next key
                logger.warning("groq_chat failed with the current key; switching to the next API key.")
                with current_groq_key_lock:
                    current_groq_key_index += 1
                    if current_groq_key_index >= len(groq_api_keys):
                        return {"error": "All API keys exhausted."}
            else:
                logger.error("Error in groq_chat: %s", e)
                return {"error": error_message}
# langchain_csv_chat with thread-safe key rotation
def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
    global current_langchain_key_index
    data = clean_data(csv_url)
    attempts = 0
    while attempts < len(groq_api_keys):
        with current_langchain_key_lock:
            if current_langchain_key_index >= len(groq_api_keys):
                current_langchain_key_index = 0
            api_key = groq_api_keys[current_langchain_key_index]
            current_key = current_langchain_key_index
            current_langchain_key_index += 1
        attempts += 1
        try:
            llm = ChatGroq(model=model_name, api_key=api_key)
            tool = PythonAstREPLTool(locals={
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib
            })
            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[tool],
                return_intermediate_steps=True
            )
            prompt = _prompt_generator(question, chart_required)
            result = agent.invoke({"input": prompt})
            return result.get("output")
        except Exception as e:
            error_message = str(e)
            if error_message != "":
                # The index was already advanced above, so the next attempt uses the next key
                logger.warning(f"langchain_csv_chat failed with key index {current_key}; retrying with the next API key.")
            else:
                logger.error(f"Error with API key index {current_key}: {error_message}")
                return {"error": error_message}
    return {"error": "All API keys exhausted"}
# Async endpoint with non-blocking execution
@app.post("/api/csv-chat")
async def csv_chat(request: Dict, authorization: str = Header(None)):
    # Authorization checks
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid authorization")
    token = authorization.split(" ")[1]
    if token != os.getenv("AUTH_TOKEN"):
        raise HTTPException(status_code=403, detail="Invalid token")

    try:
        query = request.get("query")
        csv_url = request.get("csv_url")
        decoded_url = unquote(csv_url)
        detailed_answer = request.get("detailed_answer")
        conversation_history = request.get("conversation_history", [])
        generate_report = request.get("generate_report")
        chat_id = request.get("chat_id")

        if generate_report is True:
            report_files = await generate_csv_report(csv_url, query, chat_id)
            if report_files is not None:
                return {"answer": jsonable_encoder(report_files)}

        if if_initial_chat_question(query):
            answer = await asyncio.to_thread(
                langchain_csv_chat, decoded_url, query, False
            )
            logger.info(f"langchain_answer: {answer}")
            return {"answer": jsonable_encoder(answer)}

        # Let the orchestrator handle detailed answers first
        if detailed_answer is True:
            orchestrator_answer = await asyncio.to_thread(
                csv_orchestrator_chat, decoded_url, query, conversation_history, chat_id
            )
            if orchestrator_answer is not None:
                return {"answer": jsonable_encoder(orchestrator_answer)}

        # Process with groq_chat first, falling back to langchain_csv_chat
        groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
        logger.info(f"groq_answer: {groq_answer}")

        if process_answer(groq_answer) == "Empty response received.":
            return {"answer": "Sorry, I couldn't find relevant data..."}

        if process_answer(groq_answer):
            lang_answer = await asyncio.to_thread(
                langchain_csv_chat, decoded_url, query, False
            )
            if process_answer(lang_answer):
                return {"answer": "error"}
            return {"answer": jsonable_encoder(lang_answer)}

        return {"answer": jsonable_encoder(groq_answer)}
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        return {"answer": "error"}
def handle_out_of_range_float(value):
    """Replace NaN with None and +/-Inf with the string "Infinity" so values stay JSON-safe."""
    if isinstance(value, float):
        if np.isnan(value):
            return None
        elif np.isinf(value):
            return "Infinity"
    return value
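# Example behaviour (values chosen for illustration):
#   handle_out_of_range_float(float("nan"))  -> None
#   handle_out_of_range_float(float("inf"))  -> "Infinity"
#   handle_out_of_range_float(3.14)          -> 3.14
#   handle_out_of_range_float("text")        -> "text"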
# CHART CODING STARTS FROM HERE
instructions = """
- Ensure that every value is clearly visible; adjust the font size, rotate the labels, or truncate them if needed to improve readability.
- For multiple charts, arrange them in a grid (2x2, 3x3, etc.).
- Use a colorblind-friendly palette.
- Follow all of the instructions above.
"""
# Thread-safe configuration for chart endpoints
current_groq_chart_key_index = 0
current_groq_chart_lock = threading.Lock()
# current_langchain_chart_key_index = 0
# current_langchain_chart_lock = threading.Lock()


def model():
    global current_groq_chart_key_index
    with current_groq_chart_lock:
        if current_groq_chart_key_index >= len(groq_api_keys):
            raise Exception("All API keys exhausted for chart generation")
        api_key = groq_api_keys[current_groq_chart_key_index]
        return ChatGroq(model=model_name, api_key=api_key)
def groq_chart(csv_url: str, question: str):
    global current_groq_chart_key_index
    for attempt in range(len(groq_api_keys)):
        try:
            data = clean_data(csv_url)
            with current_groq_chart_lock:
                current_api_key = groq_api_keys[current_groq_chart_key_index]
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            # Generate a unique filename using UUID
            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            # Configure SmartDataframe with chart settings
            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,                              # Enable chart saving
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),  # Directory to save into
                    'custom_chart_filename': chart_filename,          # Unique filename
                    'enable_cache': False
                }
            )
            answer = df.chat(question + instructions)

            if process_answer(answer):
                return "Chart not generated"
            return answer
        except Exception as e:
            error = str(e)
            if error != "":
                # Any failure rotates to the next key (wrap around so the shared index stays valid)
                with current_groq_chart_lock:
                    current_groq_chart_key_index = (current_groq_chart_key_index + 1) % len(groq_api_keys)
            else:
                logger.error(f"Chart generation error: {error}")
                return {"error": error}
    return {"error": "All API keys exhausted for chart generation"}
# def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
# global current_langchain_chart_key_index, current_langchain_chart_lock
# data = clean_data(csv_url)
# for attempt in range(len(groq_api_keys)):
# try:
# with current_langchain_chart_lock:
# api_key = groq_api_keys[current_langchain_chart_key_index]
# current_key = current_langchain_chart_key_index
# current_langchain_chart_key_index = (current_langchain_chart_key_index + 1) % len(groq_api_keys)
# llm = ChatGroq(model=model_name, api_key=api_key)
# tool = PythonAstREPLTool(locals={
# "df": data,
# "pd": pd,
# "np": np,
# "plt": plt,
# "sns": sns,
# "matplotlib": matplotlib,
# "uuid": uuid
# })
# agent = create_pandas_dataframe_agent(
# llm,
# data,
# agent_type="openai-tools",
# verbose=True,
# allow_dangerous_code=True,
# extra_tools=[tool],
# return_intermediate_steps=True
# )
# result = agent.invoke({"input": _prompt_generator(question, True)})
# output = result.get("output", "")
# # Verify chart file creation
# chart_files = extract_chart_filenames(output)
# if len(chart_files) > 0:
# return chart_files
# if attempt < len(groq_api_keys) - 1:
# print(f"Langchain chart error (key {current_key}): {output}")
# except Exception as e:
# print(f"Langchain chart error (key {current_key}): {str(e)}")
# return "Chart generation failed after all retries"
# @app.post("/api/csv-chart")
# async def csv_chart(request: dict, authorization: str = Header(None)):
# # Authorization verification
# if not authorization or not authorization.startswith("Bearer "):
# raise HTTPException(status_code=401, detail="Authorization required")
# token = authorization.split(" ")[1]
# if token != os.getenv("AUTH_TOKEN"):
# raise HTTPException(status_code=403, detail="Invalid credentials")
# try:
# query = request.get("query", "")
# csv_url = unquote(request.get("csv_url", ""))
# # Parallel processing with thread pool
# if if_initial_chart_question(query):
# chart_paths = await asyncio.to_thread(
# langchain_csv_chart, csv_url, query, True
# )
# print(chart_paths)
# if len(chart_paths) > 0:
# return FileResponse(f"{image_file_path}/{chart_paths[0]}", media_type="image/png")
# # Groq-based chart generation
# groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
# print(f"Generated Chart: {groq_result}")
# if groq_result != 'Chart not generated':
# return FileResponse(groq_result, media_type="image/png")
# # Fallback to Langchain
# langchain_paths = await asyncio.to_thread(
# langchain_csv_chart, csv_url, query, True
# )
# print (langchain_paths)
# if len(langchain_paths) > 0:
# return FileResponse(f"{image_file_path}/{langchain_paths[0]}", media_type="image/png")
# else:
# return {"error": "All chart generation methods failed"}
# except Exception as e:
# print(f"Critical chart error: {str(e)}")
# return {"error": "Internal system error"}
# MERGED CALL
# class CSVData(BaseModel):
# csv_url: str
# query: str
# chart_required: bool
# @app.post("/api/v1/csv_chat")
# async def csv_chat(csv_data: CSVData, authorization: str = Header(None)):
# # Authorization verification
# if not authorization or not authorization.startswith("Bearer "):
# raise HTTPException(status_code=401, detail="Authorization required")
# token = authorization.split(" ")[1]
# if token != os.getenv("AUTH_TOKEN"):
# raise HTTPException(status_code=403, detail="Invalid credentials")
# csv_url = csv_data.csv_url
# query = csv_data.query
# chart_required = csv_data.chart_required
# if(chart_required == True):
# try:
# # Parallel processing with thread pool
# if if_initial_chart_question(query):
# chart_path = await asyncio.to_thread(
# langchain_csv_chart, csv_url, query, True
# )
# if "temp" in chart_path:
# print("langchain chart Generated")
# return FileResponse('temp.png', media_type="image/png")
# return {"error": "Chart generation failed"}
# # Groq-based chart generation
# groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
# if groq_result == "Chart Generated":
# return FileResponse("exports/charts/temp_chart.png")
# # Fallback to Langchain
# langchain_path = await asyncio.to_thread(
# langchain_csv_chart, csv_url, query, True
# )
# if "temp" in langchain_path:
# print("langchain chart Generated")
# return FileResponse('temp.png', media_type="image/png")
# return {"error": "All chart generation methods failed"}
# except Exception as e:
# print(f"Critical chart error: {str(e)}")
# raise HTTPException(status_code=500, detail="Internal server error")
# else:
# try:
# if if_initial_chat_question(query):
# answer = await asyncio.to_thread(
# langchain_csv_chat, csv_url, query, False
# )
# print("langchain_answer:", answer)
# return {"answer": jsonable_encoder(answer)}
# # Process with groq_chat first
# groq_answer = await asyncio.to_thread(groq_chat, csv_url, query)
# print("groq_answer:", groq_answer)
# if process_answer(groq_answer) == "Empty response received.":
# return {"answer": "Sorry, I couldn't find relevant data..."}
# if process_answer(groq_answer):
# lang_answer = await asyncio.to_thread(
# langchain_csv_chat, csv_url, query, False
# )
# if process_answer(lang_answer):
# return {"answer": "error"}
# return {"answer": jsonable_encoder(lang_answer)}
# return {"answer": jsonable_encoder(groq_answer)}
# except Exception as e:
# print(f"Error processing request: {str(e)}")
# raise HTTPException(status_code=500, detail="Internal server error")
# Global index/lock for key rotation (langchain chart generation)
current_langchain_chart_key_index = 0
current_langchain_chart_lock = threading.Lock()

# Use a process pool to run CPU-bound chart generation; keep at least one worker
# even on small machines (max_cpus - 2 would be invalid when max_cpus <= 2)
process_executor = ProcessPoolExecutor(max_workers=max(1, (max_cpus or 1) - 2))
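# Note: callables and arguments submitted to a ProcessPoolExecutor must be picklable,
# which is why the handlers below pass plain strings (csv_url, question) rather than
# live objects. A minimal usage sketch of the pattern used in the endpoints:
#
#   loop = asyncio.get_running_loop()
#   result = await loop.run_in_executor(process_executor, groq_chart, csv_url, question)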
# --- GROQ-BASED CHART GENERATION ---
# def groq_chart(csv_url: str, question: str):
# """
# Generate a chart using the groq-based method.
# Modifications:
# • No deletion of a shared cache file (avoid interference).
# • After chart generation, close all matplotlib figures.
# • Return the full path of the saved chart.
# """
# global current_groq_chart_key_index, current_groq_chart_lock
# for attempt in range(len(groq_api_keys)):
# try:
# # Instead of deleting a global cache file, you might later configure a per-request cache.
# cache_db_path = "/app/cache/cache_db_0.11.db"
# if os.path.exists(cache_db_path):
# try:
# os.remove(cache_db_path)
# print(f"Deleted cache DB file: {cache_db_path}")
# except Exception as e:
# print(f"Error deleting cache DB file: {e}")
# chart_dir = "generated_charts"
# if not os.path.exists(chart_dir):
# os.makedirs(chart_dir)
# data = clean_data(csv_url)
# with current_groq_chart_lock:
# current_api_key = groq_api_keys[current_groq_chart_key_index]
# llm = ChatGroq(model=model_name, api_key=current_api_key)
# # Generate a unique filename and full path for the chart
# chart_filename = f"chart_{uuid.uuid4().hex}.png"
# chart_path = os.path.join("generated_charts", chart_filename)
# # Configure your dataframe tool (e.g. using SmartDataframe) to save charts.
# # (Assuming your SmartDataframe uses these settings to save charts.)
# from pandasai import SmartDataframe # Import here if not already imported
# df = SmartDataframe(
# data,
# config={
# 'llm': llm,
# 'save_charts': True,
# 'open_charts': False,
# 'save_charts_path': os.path.dirname(chart_path),
# 'custom_chart_filename': chart_filename
# }
# )
# # Append any extra instructions if needed
# instructions = """
# - Ensure each value is clearly visible.
# - Adjust font sizes, rotate labels if necessary.
# - Use a colorblind-friendly palette.
# - Arrange multiple charts in a grid if needed.
# """
# answer = df.chat(question + instructions)
# # Make sure to close figures so they don't conflict between processes
# plt.close('all')
# # If process_answer indicates a problem, return a failure message.
# if process_answer(answer):
# return "Chart not generated"
# # Return the chart path that was used for saving
# return chart_path
# except Exception as e:
# error = str(e)
# if "429" in error:
# with current_groq_chart_lock:
# current_groq_chart_key_index = (current_groq_chart_key_index + 1) % len(groq_api_keys)
# else:
# print(f"Groq chart generation error: {error}")
# return {"error": error}
# return {"error": "All API keys exhausted for chart generation"}
# --- LANGCHAIN-BASED CHART GENERATION ---
def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
    """
    Generate a chart using the langchain-based method.
    Modifications:
      • No shared deletion of cache.
      • Close matplotlib figures after generation.
      • Return a list of full chart file paths.
    """
    global current_langchain_chart_key_index
    data = clean_data(csv_url)
    for attempt in range(len(groq_api_keys)):
        try:
            with current_langchain_chart_lock:
                api_key = groq_api_keys[current_langchain_chart_key_index]
                current_key = current_langchain_chart_key_index
                current_langchain_chart_key_index = (current_langchain_chart_key_index + 1) % len(groq_api_keys)

            llm = ChatGroq(model=model_name, api_key=api_key)
            tool = PythonAstREPLTool(locals={
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
                "uuid": uuid
            })
            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[tool],
                return_intermediate_steps=True
            )
            result = agent.invoke({"input": _prompt_generator(question, True)})
            output = result.get("output", "")

            # Close figures so parallel requests do not interfere with each other
            plt.close('all')

            # Extract chart filenames (extract_chart_filenames is expected to return a list)
            chart_files = extract_chart_filenames(output)
            if len(chart_files) > 0:
                # Return full paths (joined with image_file_path)
                return [os.path.join(image_file_path, f) for f in chart_files]

            if attempt < len(groq_api_keys) - 1:
                logger.info(f"Langchain chart error (key {current_key}): {output}")
        except Exception as e:
            error_message = str(e)
            if error_message != "":
                # The index was already advanced above, so the next attempt uses the next key
                logger.warning(f"langchain_csv_chart failed with key index {current_key}; retrying with the next API key.")
            else:
                logger.error(f"Error with API key index {current_key}: {error_message}")
                return {"error": error_message}
    logger.error("All API keys exhausted for chart generation")
    return "Chart generation failed after all retries"
# --- FASTAPI ENDPOINT FOR CHART GENERATION ---
@app.post("/api/csv-chart")
async def csv_chart(request: dict, authorization: str = Header(None)):
    """
    Endpoint for generating a chart from CSV data.
    Uses a ProcessPoolExecutor to run the (CPU-bound) chart generation functions in
    separate processes so that multiple requests can run in parallel.
    """
    # --- Authorization Check ---
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Authorization required")
    token = authorization.split(" ")[1]
    if token != os.getenv("AUTH_TOKEN"):
        raise HTTPException(status_code=403, detail="Invalid credentials")

    try:
        query = request.get("query", "")
        csv_url = unquote(request.get("csv_url", ""))
        detailed_answer = request.get("detailed_answer", False)
        conversation_history = request.get("conversation_history", [])
        generate_report = request.get("generate_report", False)
        chat_id = request.get("chat_id", "")

        if generate_report is True:
            report_files = await generate_csv_report(csv_url, query, chat_id)
            if report_files is not None:
                return {"orchestrator_response": jsonable_encoder(report_files)}

        loop = asyncio.get_running_loop()

        # First, try the langchain-based method if the question qualifies
        if if_initial_chart_question(query):
            langchain_result = await loop.run_in_executor(
                process_executor, langchain_csv_chart, csv_url, query, True
            )
            logger.info(f"Langchain chart result: {langchain_result}")
            if isinstance(langchain_result, list) and len(langchain_result) > 0:
                unique_file_name = f'{str(uuid.uuid4())}.png'
                logger.info("Uploading the chart to Supabase...")
                image_public_url = await upload_file_to_supabase(langchain_result[0], unique_file_name, chat_id=chat_id)
                logger.info(f"Image uploaded to Supabase, public URL: {image_public_url}")
                os.remove(langchain_result[0])
                return {"image_url": image_public_url}
                # return FileResponse(langchain_result[0], media_type="image/png")

        # Use the orchestrator to handle the user's chart query first
        if detailed_answer is True:
            orchestrator_answer = await asyncio.to_thread(
                csv_orchestrator_chat, csv_url, query, conversation_history, chat_id
            )
            if orchestrator_answer is not None:
                return {"orchestrator_response": jsonable_encoder(orchestrator_answer)}

        # Next, try the groq-based method
        groq_result = await loop.run_in_executor(
            process_executor, groq_chart, csv_url, query
        )
        logger.info(f"Groq chart result: {groq_result}")
        if isinstance(groq_result, str) and groq_result != "Chart not generated":
            unique_file_name = f'{str(uuid.uuid4())}.png'
            logger.info("Uploading the chart to Supabase...")
            image_public_url = await upload_file_to_supabase(groq_result, unique_file_name, chat_id=chat_id)
            logger.info(f"Image uploaded to Supabase, public URL: {image_public_url}")
            os.remove(groq_result)
            return {"image_url": image_public_url}
            # return FileResponse(groq_result, media_type="image/png")

        # Fallback: try the langchain-based method again
        logger.error("Groq chart generation failed, trying langchain....")
        langchain_paths = await loop.run_in_executor(
            process_executor, langchain_csv_chart, csv_url, query, True
        )
        logger.info(f"Fallback langchain chart result: {langchain_paths}")
        if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
            unique_file_name = f'{str(uuid.uuid4())}.png'
            logger.info("Uploading the chart to Supabase...")
            image_public_url = await upload_file_to_supabase(langchain_paths[0], unique_file_name, chat_id=chat_id)
            logger.info(f"Image uploaded to Supabase, public URL: {image_public_url}")
            os.remove(langchain_paths[0])
            return {"image_url": image_public_url}
            # return FileResponse(langchain_paths[0], media_type="image/png")
        else:
            logger.error("All chart generation methods failed")
            return {"answer": "error"}
    except Exception as e:
        logger.error(f"Critical chart error: {str(e)}")
        return {"answer": "error"}