# FastApi / groq_chart.py
# Soumik555 · "Cron job" · commit 2ccbdb1 · 3.66 kB
from util_service import process_answer
import os
import threading
import uuid
from dotenv import load_dotenv
from langchain_groq import ChatGroq
import pandas as pd
from pandasai import SmartDataframe
import numpy as np
import logging
from csv_service import clean_data
from util_service import handle_out_of_range_float
load_dotenv()

# Thread-safe key management for langchain_csv_chat.
# NOTE(review): these names are not used by the code shown in this file;
# presumably another module imports them — keep them stable.
current_langchain_key_index = 0
current_langchain_key_lock = threading.Lock()

# Set up logging before reading configuration so problems below can be reported.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _parse_api_keys(raw):
    """Split a comma-separated API-key string into a clean list.

    Whitespace around each key is stripped and empty entries (e.g. from a
    trailing comma) are dropped. ``None`` or a blank string yields ``[]``.
    """
    if not raw:
        return []
    return [key.strip() for key in raw.split(",") if key.strip()]


# Load environment variables. GROQ_API_KEYS is a comma-separated list; a missing
# or blank value yields an empty list instead of an AttributeError at import.
groq_api_keys = _parse_api_keys(os.getenv("GROQ_API_KEYS"))
if not groq_api_keys:
    logger.warning("GROQ_API_KEYS is missing or empty; chart generation will not run")
model_name = os.getenv("GROQ_LLM_MODEL")
# Plotting guidance appended to every chart question before it is sent to the
# LLM (groq_chart calls df.chat(question + instructions)). The exact text is
# part of the prompt, so it is assembled line by line: a leading newline, one
# bullet per line, and a trailing newline.
instructions = "\n".join(
    [
        "",
        "- Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).",
        "- For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)",
        "- Use colorblind-friendly palette",
        "- Read above instructions and follow them.",
        "- Please do not use any visualization library other than matplotlib or seaborn.",
        "",
    ]
)

# Shared round-robin state for the chart endpoint: the position in
# groq_api_keys of the key currently in use, and the lock that guards
# reading and rotating that position across threads.
current_groq_chart_key_index: int = 0
current_groq_chart_lock = threading.Lock()
def model():
    """Build a ChatGroq client for the chart key currently selected.

    The key is ``groq_api_keys[current_groq_chart_key_index]``, read under
    ``current_groq_chart_lock``; the client uses ``model_name``.

    Raises:
        Exception: if the current index is not a valid position in
            ``groq_api_keys``. Because ``groq_chart`` wraps the index modulo
            the key count, in practice this means no keys are configured.
    """
    with current_groq_chart_lock:
        position = current_groq_chart_key_index
        key_count = len(groq_api_keys)
        if not position < key_count:
            raise Exception("All API keys exhausted for chart generation")
        selected_key = groq_api_keys[position]
    return ChatGroq(api_key=selected_key, model=model_name)
# Cache DB file deleted before every chart attempt (best effort).
_CACHE_DB_PATH = "/workspace/cache/cache_db_0.11.db"
# Directory SmartDataframe is told to save generated charts into.
_CHARTS_DIR = "generated_charts"


def _clear_cache_db(path=_CACHE_DB_PATH):
    """Delete the cache DB file at *path* if it exists; failures are logged, not raised.

    NOTE(review): presumably pandasai's response cache, cleared so each attempt
    queries the LLM afresh — confirm.
    """
    try:
        os.remove(path)
    except FileNotFoundError:
        pass  # nothing cached yet
    except OSError as e:
        logger.warning("Cache cleanup error: %s", e)


def _rotate_chart_key(used_index):
    """Advance the shared chart key index to the key after *used_index*.

    The index is only moved if it still equals *used_index*. When several
    threads hit a 429 on the same key, the first one rotates and the others
    leave the index alone, instead of each advancing it and skipping keys
    that were never tried.
    """
    global current_groq_chart_key_index
    with current_groq_chart_lock:
        if current_groq_chart_key_index == used_index:
            current_groq_chart_key_index = (used_index + 1) % len(groq_api_keys)


def groq_chart(csv_url: str, question: str):
    """Answer a charting question about a CSV using pandasai with a Groq LLM.

    Each attempt clears the cache DB, loads the CSV through ``clean_data``,
    and asks a ``SmartDataframe`` backed by ``ChatGroq`` to answer
    ``question`` with the module-level ``instructions`` appended. Charts are
    saved under ``generated_charts`` with a UUID-based filename.

    At most ``len(groq_api_keys)`` attempts are made. An error whose message
    contains "429" is treated as a rate limit: the shared key index is
    rotated to the next key and the request is retried.

    Args:
        csv_url: Location of the CSV, passed to ``clean_data``.
        question: Natural-language chart request.

    Returns:
        ``"Chart not generated"`` when ``process_answer`` flags the answer;
        otherwise the value returned by ``SmartDataframe.chat``;
        ``{"error": <message>}`` for any failure without "429" in its message;
        ``None`` when every attempt hit a 429 (or no keys are configured).
    """
    for _ in range(len(groq_api_keys)):
        # Snapshot the key for this attempt so a 429 rotates away from exactly
        # the key that was used, even if another thread rotated meanwhile.
        with current_groq_chart_lock:
            key_index = current_groq_chart_key_index
            current_api_key = groq_api_keys[key_index]
        try:
            _clear_cache_db()
            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)
            # Unique filename so concurrent requests don't overwrite each other's chart.
            chart_filename = f"chart_{uuid.uuid4()}.png"
            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': _CHARTS_DIR,
                    'custom_chart_filename': chart_filename,
                },
            )
            answer = df.chat(question + instructions)
            if process_answer(answer):
                return "Chart not generated"
            return answer
        except Exception as e:
            error = str(e)
            if "429" in error:
                # Rate-limited: switch to the next key and try again.
                _rotate_chart_key(key_index)
                continue
            logger.error("Chart generation error: %s", error)
            return {"error": error}
    logger.warning("All API keys exhausted for chart generation")
    return None