|
from util_service import process_answer |
|
import os |
|
import threading |
|
import uuid |
|
from dotenv import load_dotenv |
|
from langchain_groq import ChatGroq |
|
import pandas as pd |
|
from pandasai import SmartDataframe |
|
import numpy as np |
|
import logging |
|
from csv_service import clean_data |
|
from util_service import handle_out_of_range_float |
|
|
|
load_dotenv()

logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)


# Key-rotation state for LangChain callers. Not referenced in this part of the
# file; presumably used elsewhere — TODO confirm before removing.
current_langchain_key_index = 0

current_langchain_key_lock = threading.Lock()


# GROQ_API_KEYS is a comma-separated list, e.g. "key1,key2".
# Surrounding whitespace is stripped and empty entries are dropped, so "k1, k2,"
# yields ["k1", "k2"]. An unset variable yields an empty list instead of an
# AttributeError at import time; the chart code below handles an empty list.
groq_api_keys = [
    key.strip()
    for key in os.getenv("GROQ_API_KEYS", "").split(",")
    if key.strip()
]

if not groq_api_keys:
    logger.warning("GROQ_API_KEYS is not set or contains no keys")

# Groq model identifier passed to ChatGroq; None if the variable is unset.
model_name = os.getenv("GROQ_LLM_MODEL")
|
|
|
|
|
# Extra prompt text appended to every chart question in groq_chart()
# (sent as `question + instructions`). It asks the LLM for readable,
# colorblind-friendly charts drawn only with matplotlib or seaborn.
instructions = """

- Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).
- For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)
- Use colorblind-friendly palette
- Read above instructions and follow them.
- Please do not use any visualization library other than matplotlib or seaborn.

"""


# Index into groq_api_keys of the key currently used for chart generation.
# groq_chart() advances it (modulo the number of keys) after a "429" error.
current_groq_chart_key_index = 0

# Guards reads and updates of current_groq_chart_key_index.
current_groq_chart_lock = threading.Lock()
|
|
|
def model():
    """Return a ChatGroq client bound to the currently selected chart API key.

    Returns:
        A ``ChatGroq`` instance using ``model_name`` and the key at
        ``current_groq_chart_key_index``.

    Raises:
        RuntimeError: if the current key index is out of range (for example
            when no Groq API keys are configured). RuntimeError subclasses
            Exception, so existing ``except Exception`` handlers still work.
    """
    # Reading module globals needs no `global` statement; the lock keeps the
    # bounds check and the lookup consistent with concurrent rotation.
    with current_groq_chart_lock:
        if current_groq_chart_key_index >= len(groq_api_keys):
            raise RuntimeError("All API keys exhausted for chart generation")
        api_key = groq_api_keys[current_groq_chart_key_index]

    return ChatGroq(model=model_name, api_key=api_key)
|
|
|
# Presumably pandasai's on-disk response cache; it is deleted before each
# attempt so a previously cached answer is not reused — TODO confirm.
_PANDASAI_CACHE_DB_PATH = "/workspace/cache/cache_db_0.11.db"

# Directory (relative to the current working directory) where charts are saved.
_CHARTS_DIR = "generated_charts"


def _clear_pandasai_cache():
    """Best-effort delete of the cache DB file; failures are logged, never raised."""
    try:
        os.remove(_PANDASAI_CACHE_DB_PATH)
    except FileNotFoundError:
        pass  # Nothing to clean up.
    except OSError as e:
        logger.warning(f"Cache cleanup error: {e}")


def _build_chart_dataframe(data, api_key):
    """Wrap ``data`` in a SmartDataframe driven by a Groq LLM using ``api_key``.

    Charts are saved (not opened) under ``_CHARTS_DIR`` with a uuid4-based
    file name.
    """
    llm = ChatGroq(model=model_name, api_key=api_key)
    chart_filename = f"chart_{uuid.uuid4()}.png"
    return SmartDataframe(
        data,
        config={
            "llm": llm,
            "save_charts": True,
            "open_charts": False,
            "save_charts_path": _CHARTS_DIR,
            "custom_chart_filename": chart_filename,
        },
    )


def groq_chart(csv_url: str, question: str):
    """Ask the LLM to produce a chart answering ``question`` over ``csv_url``.

    Makes at most ``len(groq_api_keys)`` attempts. When an attempt fails with
    an error whose message contains "429", the shared chart key index is
    advanced and the next attempt uses the next key. Any other error ends the
    call immediately.

    Args:
        csv_url: Location of the CSV, passed to ``clean_data``.
        question: Natural-language request; ``instructions`` is appended.

    Returns:
        ``"Chart not generated"`` when ``process_answer(answer)`` is truthy;
        otherwise the ``SmartDataframe.chat`` answer;
        ``{"error": <message>}`` on a non-"429" failure;
        ``None`` when every attempt was rate-limited (or no keys exist).
    """
    global current_groq_chart_key_index

    # Loaded on the first attempt and reused on retries, so a rate-limit
    # retry does not re-fetch and re-clean the CSV.
    data = None

    for _ in range(len(groq_api_keys)):
        key_index = None  # Index of the key used by this attempt, once chosen.
        try:
            _clear_pandasai_cache()

            if data is None:
                data = clean_data(csv_url)

            with current_groq_chart_lock:
                key_index = current_groq_chart_key_index
                current_api_key = groq_api_keys[key_index]

            df = _build_chart_dataframe(data, current_api_key)
            answer = df.chat(question + instructions)

            if process_answer(answer):
                return "Chart not generated"
            return answer

        except Exception as e:
            error = str(e)
            if "429" not in error:
                logger.exception(f"Chart generation error: {error}")
                return {"error": error}

            logger.warning("Rate limited (429) during chart generation; rotating Groq API key")
            with current_groq_chart_lock:
                # Advance only if no concurrent call has already rotated past
                # the key this attempt used; otherwise several threads hitting
                # 429 on the same key would skip keys or wrap back to a
                # throttled one. If the failure happened before a key was
                # chosen, rotate unconditionally (original behaviour).
                if key_index is None or current_groq_chart_key_index == key_index:
                    current_groq_chart_key_index = (
                        current_groq_chart_key_index + 1
                    ) % len(groq_api_keys)

    logger.warning("All API keys exhausted for chart generation")
    return None