FastApi / controller.py
Soumik555's picture
blank image issue on multiple req
b8d0141
raw
history blame
29.6 kB
# Import necessary modules
import os
import asyncio
import threading
import uuid
from fastapi import FastAPI, HTTPException, Header
from fastapi.encoders import jsonable_encoder
from typing import Dict
from fastapi.responses import FileResponse
import numpy as np
import pandas as pd
from pandasai import SmartDataframe
from langchain_groq.chat_models import ChatGroq
from dotenv import load_dotenv
from pydantic import BaseModel
from csv_service import clean_data, extract_chart_filenames
from urllib.parse import unquote
import csv_service
from langchain_groq import ChatGroq
import pandas as pd
from langchain_experimental.tools import PythonAstREPLTool
from langchain_experimental.agents import create_pandas_dataframe_agent
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
from util_service import _prompt_generator, process_answer
from fastapi.middleware.cors import CORSMiddleware
import matplotlib
# Select the non-interactive Agg backend so charts can be rendered without a display.
matplotlib.use('Agg')
# Initialize FastAPI app
app = FastAPI()
# Ensure the cache directory exists
os.makedirs("/app/cache", exist_ok=True)
os.makedirs("/app", exist_ok=True)
open("/app/pandasai.log", "a").close()  # Create the file if it doesn't exist
# Ensure the generated_charts directory exists
os.makedirs("/app/generated_charts", exist_ok=True)
# Load variables from a .env file (if any) before reading the configuration below.
load_dotenv()
image_file_path = os.getenv("IMAGE_FILE_PATH")
image_not_found = os.getenv("IMAGE_NOT_FOUND")
# Comma-separated list of allowed CORS origins; blank entries (e.g. when the
# variable is unset) are dropped instead of being passed on as "" origins.
allowed_hosts = [
    host.strip()
    for host in os.getenv("ALLOWED_HOSTS", "").split(",")
    if host.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_hosts,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load environment variables
# Default to "" so a missing GROQ_API_KEYS does not crash the import with
# AttributeError ('NoneType' has no attribute 'split').
groq_api_keys = os.getenv("GROQ_API_KEYS", "").split(",")
model_name = os.getenv("GROQ_LLM_MODEL")
class CsvUrlRequest(BaseModel):
    """Request body carrying the URL of the CSV file to process."""

    # URL of the CSV; the CSV endpoints pass it through ``unquote`` before use.
    csv_url: str
class ImageRequest(BaseModel):
    """Request body for the chart download endpoint (``/api/get-chart``)."""

    # Path of the image file that ``get_image`` hands to ``FileResponse``.
    image_path: str
# Thread-safe key management for groq_chat
# Index into groq_api_keys of the key groq_chat currently uses. The lock guards
# reads/updates of the index because groq_chat is run via asyncio.to_thread.
current_groq_key_index = 0
current_groq_key_lock = threading.Lock()
# Thread-safe key management for langchain_csv_chat
# Index of the next key langchain_csv_chat will hand out, with its own lock.
current_langchain_key_index = 0
current_langchain_key_lock = threading.Lock()
# PING CHECK
@app.get("/ping")
async def root():
    """Liveness probe: always reply with the fixed "Pong !!" payload."""
    pong = {"message": "Pong !!"}
    return pong
# BASIC KNOWLEDGE BASED ON CSV
# Remove trailing slash from the URL otherwise it will redirect to GET method
@app.post("/api/basic_csv_data")
async def basic_csv_data(request: CsvUrlRequest):
    """Return basic information about the CSV referenced by the request.

    Args:
        request: Body whose ``csv_url`` is the (possibly URL-encoded) CSV URL.

    Returns:
        ``{"data": <result of csv_service.get_csv_basic_info>}``.

    Raises:
        HTTPException: 400 when decoding or loading the CSV fails.
    """
    try:
        decoded_url = unquote(request.csv_url)
        print(f"Fetching CSV data from URL: {decoded_url}")
        # get_csv_basic_info is a synchronous call; run it in a worker thread
        # (as csv_chat does for its helpers) so it does not stall the event
        # loop for other requests while the CSV is being loaded.
        csv_data = await asyncio.to_thread(
            csv_service.get_csv_basic_info, decoded_url
        )
        print(f"CSV data fetched successfully: {csv_data}")
        return {"data": csv_data}
    except Exception as e:
        print(f"Error while fetching CSV data: {e}")
        raise HTTPException(
            status_code=400, detail=f"Failed to retrieve CSV data: {str(e)}"
        ) from e
# GET THE CHART FROM A SPECIFIC FILE PATH
@app.post("/api/get-chart")
async def get_image(request: ImageRequest, authorization: str = Header(None)):
    """Return the PNG stored at ``request.image_path`` to an authorized caller.

    Args:
        request: Body whose ``image_path`` is the path of the chart file.
        authorization: ``Authorization`` header; must be ``Bearer <AUTH_TOKEN>``.

    Returns:
        A ``FileResponse`` with media type ``image/png``, or
        ``{"answer": "error"}`` when the file is missing or cannot be served.

    Raises:
        HTTPException: 401 for a missing/malformed header or empty token,
            403 when the token does not match the ``AUTH_TOKEN`` env variable.
    """
    if not authorization:
        raise HTTPException(status_code=401, detail="Authorization header missing")
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid authorization header format")
    token = authorization.split(" ")[1]
    if not token:
        raise HTTPException(status_code=401, detail="Token missing")
    if token != os.getenv("AUTH_TOKEN"):
        raise HTTPException(status_code=403, detail="Invalid token")
    try:
        # Local name chosen so it does not shadow the module-level image_file_path.
        requested_path = request.image_path
        # NOTE(review): the path is client-supplied and served verbatim, so any
        # file readable by the process can be downloaded by a token holder.
        # Consider restricting it to the chart output directory.
        #
        # FileResponse only touches the file while the response is being sent,
        # i.e. after this handler has returned, so a missing file would bypass
        # the except clause below. Check for it up front instead.
        if not os.path.isfile(requested_path):
            print(f"Error: image file not found: {requested_path}")
            return {"answer": "error"}
        return FileResponse(requested_path, media_type="image/png")
    except Exception as e:
        print(f"Error: {e}")
        return {"answer": "error"}
# GET CSV DATA FOR GENERATING THE TABLE
@app.post("/api/csv_data")
async def get_csv_data(request: CsvUrlRequest):
    """Return the table data generated from the CSV referenced by the request.

    Args:
        request: Body whose ``csv_url`` is the (possibly URL-encoded) CSV URL.

    Returns:
        Whatever ``csv_service.generate_csv_data`` returns for the decoded URL.

    Raises:
        HTTPException: 400 when decoding or loading the CSV fails.
    """
    try:
        decoded_url = unquote(request.csv_url)
        # generate_csv_data is a synchronous call; run it in a worker thread
        # (as csv_chat does for its helpers) so the event loop stays free.
        return await asyncio.to_thread(csv_service.generate_csv_data, decoded_url)
    except Exception as e:
        raise HTTPException(
            status_code=400, detail=f"Failed to retrieve CSV data: {str(e)}"
        ) from e
# CHAT CODING STARTS FROM HERE
def _sanitize_groq_answer(answer):
    """Convert a ``df.chat`` answer into plain Python data with safe floats.

    Values are sanitised one by one *after* conversion to Python containers.
    ``DataFrame.apply`` would hand whole columns to the scalar helper, which
    then returns them untouched, leaving NaN/inf in the result.
    """
    if isinstance(answer, pd.DataFrame):
        return [
            {column: handle_out_of_range_float(cell) for column, cell in row.items()}
            for row in answer.to_dict(orient="records")
        ]
    if isinstance(answer, pd.Series):
        return {
            label: handle_out_of_range_float(cell)
            for label, cell in answer.to_dict().items()
        }
    if isinstance(answer, list):
        return [handle_out_of_range_float(item) for item in answer]
    if isinstance(answer, dict):
        return {k: handle_out_of_range_float(v) for k, v in answer.items()}
    return {"answer": str(handle_out_of_range_float(answer))}


# Modified groq_chat function with thread-safe key rotation
def groq_chat(csv_url: str, question: str):
    """Answer *question* about the CSV at *csv_url* using a SmartDataframe.

    Each call tries every configured Groq key at most once, starting from the
    shared ``current_groq_key_index``. A failure whose message contains "429"
    rotates to the next key (wrapping around), so an earlier burst of rate
    limits no longer leaves the function permanently "exhausted".

    Args:
        csv_url: URL of the CSV, passed to ``clean_data``.
        question: Prompt forwarded to ``SmartDataframe.chat``.

    Returns:
        The sanitised answer (list of records for a DataFrame, dict for a
        Series or dict, list for a list, ``{"answer": str}`` otherwise), or
        ``{"error": <message>}`` on failure / when every key was rate limited.
    """
    global current_groq_key_index, current_groq_key_lock
    key_count = len(groq_api_keys)
    for _ in range(key_count):
        with current_groq_key_lock:
            key_index = current_groq_key_index % key_count
            current_api_key = groq_api_keys[key_index]
        try:
            # Delete cache file if exists
            cache_db_path = "/workspace/cache/cache_db_0.11.db"
            if os.path.exists(cache_db_path):
                try:
                    os.remove(cache_db_path)
                except Exception as e:
                    print(f"Error deleting cache DB file: {e}")
            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)
            # Generate unique filename using UUID
            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)
            # Configure SmartDataframe with chart settings
            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,  # Enable chart saving
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),  # Directory to save
                    'custom_chart_filename': chart_filename  # Unique filename
                }
            )
            answer = df.chat(question)
            return _sanitize_groq_answer(answer)
        except Exception as e:
            error_message = str(e)
            if "429" not in error_message:
                return {"error": error_message}
            with current_groq_key_lock:
                # Advance only if no other thread already rotated away from the
                # key that just failed; otherwise concurrent 429s would skip
                # keys that were never tried.
                if current_groq_key_index % key_count == key_index:
                    current_groq_key_index = (key_index + 1) % key_count
    return {"error": "All API keys exhausted."}
# Modified langchain_csv_chat with thread-safe key rotation
def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
    """Answer *question* about the CSV at *csv_url* with a pandas agent.

    The cleaned data is loaded once, then one attempt is made per configured
    Groq key. Keys are handed out round-robin from the shared
    ``current_langchain_key_index`` under its lock. Any exception is logged
    and the next key is tried.

    Returns:
        The agent's ``"output"`` value on the first successful attempt, or
        ``{"error": "All API keys exhausted"}`` when every attempt failed.
    """
    global current_langchain_key_index, current_langchain_key_lock
    frame = clean_data(csv_url)
    tries = 0
    while tries < len(groq_api_keys):
        # Pick the next key and advance the shared cursor atomically.
        with current_langchain_key_lock:
            if current_langchain_key_index >= len(groq_api_keys):
                current_langchain_key_index = 0
            current_key = current_langchain_key_index
            api_key = groq_api_keys[current_key]
            current_langchain_key_index = current_key + 1
            tries += 1
        try:
            chat_model = ChatGroq(model=model_name, api_key=api_key)
            # Names made available to the code executed by the REPL tool.
            repl_namespace = {
                "df": frame,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib
            }
            repl_tool = PythonAstREPLTool(locals=repl_namespace)
            agent = create_pandas_dataframe_agent(
                chat_model,
                frame,
                agent_type="openai-tools",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[repl_tool],
                return_intermediate_steps=True
            )
            agent_input = {"input": _prompt_generator(question, chart_required)}
            outcome = agent.invoke(agent_input)
            return outcome.get("output")
        except Exception as e:
            print(f"Error with key index {current_key}: {str(e)}")
    return {"error": "All API keys exhausted"}
# Async endpoint with non-blocking execution
@app.post("/api/csv-chat")
async def csv_chat(request: Dict, authorization: str = Header(None)):
    """Answer a question about a CSV, trying groq_chat before langchain.

    The JSON body is read for the keys ``query`` and ``csv_url``. Questions
    accepted by ``if_initial_chat_question`` go straight to
    ``langchain_csv_chat``. Everything else is sent to ``groq_chat`` first and
    falls back to ``langchain_csv_chat`` when ``process_answer`` reports a
    problem with the groq answer. Both helpers run in worker threads.

    Returns:
        ``{"answer": ...}`` holding the encoded answer, an apology message, or
        the string ``"error"``.

    Raises:
        HTTPException: 401 for a missing/malformed header, 403 for a bad token.
    """
    # Authorization checks
    if not (authorization and authorization.startswith("Bearer ")):
        raise HTTPException(status_code=401, detail="Invalid authorization")
    bearer_token = authorization.split(" ")[1]
    if bearer_token != os.getenv("AUTH_TOKEN"):
        raise HTTPException(status_code=403, detail="Invalid token")

    try:
        query = request.get("query")
        decoded_url = unquote(request.get("csv_url"))

        if if_initial_chat_question(query):
            answer = await asyncio.to_thread(
                langchain_csv_chat, decoded_url, query, False
            )
            print("langchain_answer:", answer)
            return {"answer": jsonable_encoder(answer)}

        # Process with groq_chat first
        groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
        print("groq_answer:", groq_answer)

        if process_answer(groq_answer) == "Empty response received.":
            return {"answer": "Sorry, I couldn't find relevant data..."}

        if not process_answer(groq_answer):
            # The groq answer is usable as-is.
            return {"answer": jsonable_encoder(groq_answer)}

        # The groq answer was rejected: fall back to the langchain agent.
        lang_answer = await asyncio.to_thread(
            langchain_csv_chat, decoded_url, query, False
        )
        if process_answer(lang_answer):
            return {"answer": "error"}
        return {"answer": jsonable_encoder(lang_answer)}
    except Exception as e:
        print(f"Error processing request: {str(e)}")
        return {"answer": "error"}
def handle_out_of_range_float(value):
    """Replace float values that JSON cannot represent.

    Args:
        value: Any object. Only Python floats and NumPy floating scalars
            (including ones such as ``np.float32`` that do not subclass
            ``float``) are inspected.

    Returns:
        ``None`` for NaN, ``"Infinity"`` for positive infinity,
        ``"-Infinity"`` for negative infinity (the sign is kept rather than
        collapsing both infinities into one label); any other value unchanged.
    """
    if isinstance(value, (float, np.floating)):
        if np.isnan(value):
            return None
        if np.isinf(value):
            return "Infinity" if value > 0 else "-Infinity"
    return value
# CHART CODING STARTS FROM HERE
# instructions = """
# - Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).
# - For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)
# - Use colorblind-friendly palette
# - Read above instructions and follow them.
# """
# # Thread-safe configuration for chart endpoints
# current_groq_chart_key_index = 0
# current_groq_chart_lock = threading.Lock()
# current_langchain_chart_key_index = 0
# current_langchain_chart_lock = threading.Lock()
# def model():
# global current_groq_chart_key_index, current_groq_chart_lock
# with current_groq_chart_lock:
# if current_groq_chart_key_index >= len(groq_api_keys):
# raise Exception("All API keys exhausted for chart generation")
# api_key = groq_api_keys[current_groq_chart_key_index]
# return ChatGroq(model=model_name, api_key=api_key)
# def groq_chart(csv_url: str, question: str):
# global current_groq_chart_key_index, current_groq_chart_lock
# for attempt in range(len(groq_api_keys)):
# try:
# # Clean cache before processing
# cache_db_path = "/workspace/cache/cache_db_0.11.db"
# if os.path.exists(cache_db_path):
# try:
# os.remove(cache_db_path)
# except Exception as e:
# print(f"Cache cleanup error: {e}")
# data = clean_data(csv_url)
# with current_groq_chart_lock:
# current_api_key = groq_api_keys[current_groq_chart_key_index]
# llm = ChatGroq(model=model_name, api_key=current_api_key)
# # Generate unique filename using UUID
# chart_filename = f"chart_{uuid.uuid4()}.png"
# chart_path = os.path.join("generated_charts", chart_filename)
# # Configure SmartDataframe with chart settings
# df = SmartDataframe(
# data,
# config={
# 'llm': llm,
# 'save_charts': True, # Enable chart saving
# 'open_charts': False,
# 'save_charts_path': os.path.dirname(chart_path), # Directory to save
# 'custom_chart_filename': chart_filename # Unique filename
# }
# )
# answer = df.chat(question + instructions)
# if process_answer(answer):
# return "Chart not generated"
# return answer
# except Exception as e:
# error = str(e)
# if "429" in error:
# with current_groq_chart_lock:
# current_groq_chart_key_index = (current_groq_chart_key_index + 1) % len(groq_api_keys)
# else:
# print(f"Chart generation error: {error}")
# return {"error": error}
# return {"error": "All API keys exhausted for chart generation"}
# def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
# global current_langchain_chart_key_index, current_langchain_chart_lock
# data = clean_data(csv_url)
# for attempt in range(len(groq_api_keys)):
# try:
# with current_langchain_chart_lock:
# api_key = groq_api_keys[current_langchain_chart_key_index]
# current_key = current_langchain_chart_key_index
# current_langchain_chart_key_index = (current_langchain_chart_key_index + 1) % len(groq_api_keys)
# llm = ChatGroq(model=model_name, api_key=api_key)
# tool = PythonAstREPLTool(locals={
# "df": data,
# "pd": pd,
# "np": np,
# "plt": plt,
# "sns": sns,
# "matplotlib": matplotlib,
# "uuid": uuid
# })
# agent = create_pandas_dataframe_agent(
# llm,
# data,
# agent_type="openai-tools",
# verbose=True,
# allow_dangerous_code=True,
# extra_tools=[tool],
# return_intermediate_steps=True
# )
# result = agent.invoke({"input": _prompt_generator(question, True)})
# output = result.get("output", "")
# # Verify chart file creation
# chart_files = extract_chart_filenames(output)
# if len(chart_files) > 0:
# return chart_files
# if attempt < len(groq_api_keys) - 1:
# print(f"Langchain chart error (key {current_key}): {output}")
# except Exception as e:
# print(f"Langchain chart error (key {current_key}): {str(e)}")
# return "Chart generation failed after all retries"
# @app.post("/api/csv-chart")
# async def csv_chart(request: dict, authorization: str = Header(None)):
# # Authorization verification
# if not authorization or not authorization.startswith("Bearer "):
# raise HTTPException(status_code=401, detail="Authorization required")
# token = authorization.split(" ")[1]
# if token != os.getenv("AUTH_TOKEN"):
# raise HTTPException(status_code=403, detail="Invalid credentials")
# try:
# query = request.get("query", "")
# csv_url = unquote(request.get("csv_url", ""))
# # Parallel processing with thread pool
# if if_initial_chart_question(query):
# chart_paths = await asyncio.to_thread(
# langchain_csv_chart, csv_url, query, True
# )
# print(chart_paths)
# if len(chart_paths) > 0:
# return FileResponse(f"{image_file_path}/{chart_paths[0]}", media_type="image/png")
# # Groq-based chart generation
# groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
# print(f"Generated Chart: {groq_result}")
# if groq_result != 'Chart not generated':
# return FileResponse(groq_result, media_type="image/png")
# # Fallback to Langchain
# langchain_paths = await asyncio.to_thread(
# langchain_csv_chart, csv_url, query, True
# )
# print (langchain_paths)
# if len(langchain_paths) > 0:
# return FileResponse(f"{image_file_path}/{langchain_paths[0]}", media_type="image/png")
# else:
# return {"error": "All chart generation methods failed"}
# except Exception as e:
# print(f"Critical chart error: {str(e)}")
# return {"error": "Internal system error"}
# MERGED CALL
# class CSVData(BaseModel):
# csv_url: str
# query: str
# chart_required: bool
# @app.post("/api/v1/csv_chat")
# async def csv_chat(csv_data: CSVData, authorization: str = Header(None)):
# # Authorization verification
# if not authorization or not authorization.startswith("Bearer "):
# raise HTTPException(status_code=401, detail="Authorization required")
# token = authorization.split(" ")[1]
# if token != os.getenv("AUTH_TOKEN"):
# raise HTTPException(status_code=403, detail="Invalid credentials")
# csv_url = csv_data.csv_url
# query = csv_data.query
# chart_required = csv_data.chart_required
# if(chart_required == True):
# try:
# # Parallel processing with thread pool
# if if_initial_chart_question(query):
# chart_path = await asyncio.to_thread(
# langchain_csv_chart, csv_url, query, True
# )
# if "temp" in chart_path:
# print("langchain chart Generated")
# return FileResponse('temp.png', media_type="image/png")
# return {"error": "Chart generation failed"}
# # Groq-based chart generation
# groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
# if groq_result == "Chart Generated":
# return FileResponse("exports/charts/temp_chart.png")
# # Fallback to Langchain
# langchain_path = await asyncio.to_thread(
# langchain_csv_chart, csv_url, query, True
# )
# if "temp" in langchain_path:
# print("langchain chart Generated")
# return FileResponse('temp.png', media_type="image/png")
# return {"error": "All chart generation methods failed"}
# except Exception as e:
# print(f"Critical chart error: {str(e)}")
# raise HTTPException(status_code=500, detail="Internal server error")
# else:
# try:
# if if_initial_chat_question(query):
# answer = await asyncio.to_thread(
# langchain_csv_chat, csv_url, query, False
# )
# print("langchain_answer:", answer)
# return {"answer": jsonable_encoder(answer)}
# # Process with groq_chat first
# groq_answer = await asyncio.to_thread(groq_chat, csv_url, query)
# print("groq_answer:", groq_answer)
# if process_answer(groq_answer) == "Empty response received.":
# return {"answer": "Sorry, I couldn't find relevant data..."}
# if process_answer(groq_answer):
# lang_answer = await asyncio.to_thread(
# langchain_csv_chat, csv_url, query, False
# )
# if process_answer(lang_answer):
# return {"answer": "error"}
# return {"answer": jsonable_encoder(lang_answer)}
# return {"answer": jsonable_encoder(groq_answer)}
# except Exception as e:
# print(f"Error processing request: {str(e)}")
# raise HTTPException(status_code=500, detail="Internal server error")
import os
import asyncio
import threading
import uuid
from fastapi import FastAPI, HTTPException, Header
from fastapi.responses import FileResponse
from urllib.parse import unquote
from pydantic import BaseModel
from concurrent.futures import ProcessPoolExecutor
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
import numpy as np
import seaborn as sns
# Import your custom modules (assumed available)
from csv_service import clean_data, extract_chart_filenames
from langchain_experimental.tools import PythonAstREPLTool
from langchain_experimental.agents import create_pandas_dataframe_agent
from langchain_groq import ChatGroq
from util_service import _prompt_generator, process_answer
from intitial_q_handler import if_initial_chart_question
# Use non-interactive backend
matplotlib.use('Agg')
# NOTE: the FastAPI ``app`` created earlier in this module is reused here.
# Creating a second ``FastAPI()`` at this point would rebind ``app`` to a fresh
# instance, so the server would lose every route and the CORS middleware that
# were registered on the first one; only the chart endpoint would remain.
# Environment variables and configuration
import os
groq_api_keys = os.getenv("GROQ_API_KEYS", "").split(",")
model_name = os.getenv("GROQ_LLM_MODEL")
image_file_path = os.getenv("IMAGE_FILE_PATH")  # e.g. "/app/generated_charts"
# Global locks for key rotation (chart endpoints)
current_groq_chart_key_index = 0
current_groq_chart_lock = threading.Lock()
current_langchain_chart_key_index = 0
current_langchain_chart_lock = threading.Lock()
# Use a process pool to run CPU-bound chart generation
process_executor = ProcessPoolExecutor(max_workers=2)
# --- GROQ-BASED CHART GENERATION ---
def groq_chart(csv_url: str, question: str):
    """Generate a chart for *question* with a SmartDataframe.

    One attempt is made per configured Groq key. A failure whose message
    contains "429" rotates ``current_groq_chart_key_index`` and retries; any
    other failure ends the call. The csv-chart endpoint submits this function
    to a process pool, so the index and lock are per worker process there.

    Args:
        csv_url: URL of the CSV, passed to ``clean_data``.
        question: User prompt; rendering instructions are appended to it.

    Returns:
        The path of an existing chart file on success,
        ``"Chart not generated"`` when ``process_answer`` rejects the answer
        or no chart file can be found, or ``{"error": <message>}`` on failure.
    """
    global current_groq_chart_key_index, current_groq_chart_lock
    from pandasai import SmartDataframe  # Import here if not already imported

    # Extra rendering instructions appended to every prompt.
    instructions = (
        "\n"
        "- Ensure each value is clearly visible.\n"
        "- Adjust font sizes, rotate labels if necessary.\n"
        "- Use a colorblind-friendly palette.\n"
        "- Arrange multiple charts in a grid if needed.\n"
    )
    for attempt in range(len(groq_api_keys)):
        try:
            data = clean_data(csv_url)
            with current_groq_chart_lock:
                current_api_key = groq_api_keys[current_groq_chart_key_index]
            llm = ChatGroq(model=model_name, api_key=current_api_key)
            # Generate a unique filename and full path for the chart
            chart_filename = f"chart_{uuid.uuid4().hex}.png"
            chart_path = os.path.join("generated_charts", chart_filename)
            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),
                    'custom_chart_filename': chart_filename
                }
            )
            answer = df.chat(question + instructions)
            # If process_answer indicates a problem, return a failure message.
            if process_answer(answer):
                return "Chart not generated"
            # Only hand back a path that really exists: the caller feeds it to
            # FileResponse, which fails on a missing file.
            if os.path.isfile(chart_path):
                return chart_path
            # NOTE(review): presumably the answer itself can be the path of the
            # saved chart when the custom filename is not honoured; confirm
            # against the installed pandasai version.
            if isinstance(answer, str) and os.path.isfile(answer):
                return answer
            print(f"Groq chart generation error: no chart file at {chart_path}")
            return "Chart not generated"
        except Exception as e:
            error = str(e)
            if "429" in error:
                with current_groq_chart_lock:
                    current_groq_chart_key_index = (current_groq_chart_key_index + 1) % len(groq_api_keys)
            else:
                print(f"Groq chart generation error: {error}")
                return {"error": error}
        finally:
            # Close figures on every path (including errors) so state does not
            # carry over into the next request handled by this process.
            plt.close('all')
    return {"error": "All API keys exhausted for chart generation"}
# --- LANGCHAIN-BASED CHART GENERATION ---
def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
    """Generate chart files for *question* with a pandas dataframe agent.

    The cleaned data is loaded once, then one attempt is made per configured
    Groq key. Keys are taken round-robin from
    ``current_langchain_chart_key_index``. An attempt succeeds when
    ``extract_chart_filenames`` finds at least one filename in the agent
    output.

    Args:
        csv_url: URL of the CSV, passed to ``clean_data``.
        question: User prompt, wrapped by ``_prompt_generator(question, True)``.
        chart_required: Accepted for interface compatibility; not read here.

    Returns:
        A list of chart paths (each filename joined onto ``image_file_path``)
        on success, otherwise the string
        ``"Chart generation failed after all retries"``.
    """
    global current_langchain_chart_key_index, current_langchain_chart_lock
    data = clean_data(csv_url)
    # Defined up front so the except handler can always reference it.
    current_key = None
    for attempt in range(len(groq_api_keys)):
        try:
            with current_langchain_chart_lock:
                api_key = groq_api_keys[current_langchain_chart_key_index]
                current_key = current_langchain_chart_key_index
                current_langchain_chart_key_index = (current_langchain_chart_key_index + 1) % len(groq_api_keys)
            llm = ChatGroq(model=model_name, api_key=api_key)
            tool = PythonAstREPLTool(locals={
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
                "uuid": uuid
            })
            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="openai-tools",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[tool],
                return_intermediate_steps=True
            )
            result = agent.invoke({"input": _prompt_generator(question, True)})
            output = result.get("output", "")
            # Extract chart filenames (assuming extract_chart_filenames returns a list)
            chart_files = extract_chart_filenames(output)
            if len(chart_files) > 0:
                # Return full paths (join with your image_file_path)
                return [os.path.join(image_file_path, f) for f in chart_files]
            if attempt < len(groq_api_keys) - 1:
                print(f"Langchain chart error (key {current_key}): {output}")
        except Exception as e:
            print(f"Langchain chart error (key {current_key}): {str(e)}")
        finally:
            # Close figures after every attempt, including failed ones, so
            # they do not pile up in the worker process that is reused for
            # later requests.
            plt.close('all')
    return "Chart generation failed after all retries"
def _first_existing_chart(paths):
    """Return the first entry of *paths* that names an existing file, else None."""
    for candidate in paths:
        if isinstance(candidate, str) and os.path.isfile(candidate):
            return candidate
    return None


# --- FASTAPI ENDPOINT FOR CHART GENERATION ---
@app.post("/api/csv-chart")
async def csv_chart(request: dict, authorization: str = Header(None)):
    """Generate a chart from CSV data and return it as a PNG.

    The JSON body is read for the keys ``query`` and ``csv_url``. Generation
    functions are submitted to ``process_executor`` so several requests can
    run in parallel. Order of attempts: langchain (only when
    ``if_initial_chart_question`` accepts the query), then groq, then
    langchain as a fallback. A result is served only when it points at a file
    that exists; otherwise the next method is tried, rather than letting
    ``FileResponse`` fail while the response is being sent.

    Returns:
        A ``FileResponse`` (``image/png``) for the first chart found, or an
        ``{"error": ...}`` dict when every method failed or an exception
        occurred.

    Raises:
        HTTPException: 401 for a missing/malformed header, 403 for a bad token.
    """
    # --- Authorization Check ---
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Authorization required")
    token = authorization.split(" ")[1]
    if token != os.getenv("AUTH_TOKEN"):
        raise HTTPException(status_code=403, detail="Invalid credentials")
    try:
        query = request.get("query", "")
        csv_url = unquote(request.get("csv_url", ""))
        loop = asyncio.get_running_loop()
        # First, try the langchain-based method if the question qualifies
        if if_initial_chart_question(query):
            langchain_result = await loop.run_in_executor(
                process_executor, langchain_csv_chart, csv_url, query, True
            )
            print("Langchain chart result:", langchain_result)
            if isinstance(langchain_result, list):
                chart_file = _first_existing_chart(langchain_result)
                if chart_file is not None:
                    return FileResponse(chart_file, media_type="image/png")
        # Next, try the groq-based method
        groq_result = await loop.run_in_executor(
            process_executor, groq_chart, csv_url, query
        )
        print(f"Groq chart result: {groq_result}")
        if (
            isinstance(groq_result, str)
            and groq_result != "Chart not generated"
            and os.path.isfile(groq_result)
        ):
            return FileResponse(groq_result, media_type="image/png")
        # Fallback: try langchain-based again
        langchain_paths = await loop.run_in_executor(
            process_executor, langchain_csv_chart, csv_url, query, True
        )
        print("Fallback langchain chart result:", langchain_paths)
        if isinstance(langchain_paths, list):
            chart_file = _first_existing_chart(langchain_paths)
            if chart_file is not None:
                return FileResponse(chart_file, media_type="image/png")
        return {"error": "All chart generation methods failed"}
    except Exception as e:
        print(f"Critical chart error: {str(e)}")
        return {"error": "Internal system error"}