# FastApi / lc_groq_chart.py
# Last commit 30e7daa by Soumik555: "openai key rotate" (file size 2.8 kB)
import logging
import os
import threading
import uuid
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from matplotlib import pyplot as plt
import matplotlib
import numpy as np
import pandas as pd
from csv_service import clean_data, extract_chart_filenames
from langchain_experimental.tools import PythonAstREPLTool
from langchain_experimental.agents import create_pandas_dataframe_agent
from util_service import _prompt_generator
import seaborn as sns
load_dotenv()

# Thread-safe key management for langchain_csv_chat.
# NOTE(review): these two names are not referenced by the code visible in this
# file; kept as-is in case other modules import them.
current_langchain_key_index = 0
current_langchain_key_lock = threading.Lock()


def _load_groq_api_keys():
    """Parse the comma-separated ``GROQ_API_KEYS`` env var into a list of keys.

    Whitespace around each key is stripped and empty entries (for example
    from a trailing comma) are dropped, so "k1, k2," yields ["k1", "k2"].

    Returns:
        A non-empty list of API key strings.

    Raises:
        RuntimeError: if the variable is unset or contains no usable key.
            (Previously an unset variable crashed with an opaque
            ``AttributeError`` on ``None.split``.)
    """
    raw_keys = os.getenv("GROQ_API_KEYS")
    keys = [key.strip() for key in (raw_keys or "").split(",") if key.strip()]
    if not keys:
        raise RuntimeError("GROQ_API_KEYS is not set or contains no API keys")
    return keys


# Load environment variables (must run after load_dotenv() above).
groq_api_keys = _load_groq_api_keys()
# NOTE(review): may be None if GROQ_LLM_MODEL is unset; it is passed straight
# to ChatGroq(model=...) — confirm that is acceptable.
model_name = os.getenv("GROQ_LLM_MODEL")

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Round-robin cursor into groq_api_keys used by langchain_csv_chart; it is
# read and advanced only while holding current_langchain_chart_lock.
current_langchain_chart_key_index = 0
current_langchain_chart_lock = threading.Lock()
def _next_groq_key():
    """Return ``(index, api_key)`` for the next key in the round-robin rotation.

    The shared cursor is read and advanced while holding
    ``current_langchain_chart_lock``, so concurrent callers never observe
    the same cursor value for a single rotation step.
    """
    global current_langchain_chart_key_index
    with current_langchain_chart_lock:
        index = current_langchain_chart_key_index
        current_langchain_chart_key_index = (index + 1) % len(groq_api_keys)
    return index, groq_api_keys[index]


def _build_chart_agent(api_key, data):
    """Create a Groq-backed pandas-dataframe agent for one attempt.

    An extra Python REPL tool is registered whose namespace pre-binds ``df``
    (the cleaned data) together with pandas, numpy, matplotlib/pyplot,
    seaborn and uuid for the generated plotting code.
    """
    llm = ChatGroq(model=model_name, api_key=api_key)
    repl_tool = PythonAstREPLTool(locals={
        "df": data,
        "pd": pd,
        "np": np,
        "plt": plt,
        "sns": sns,
        "matplotlib": matplotlib,
        "uuid": uuid,
    })
    return create_pandas_dataframe_agent(
        llm,
        data,
        agent_type="tool-calling",
        verbose=True,
        # NOTE(security): this opts in to executing LLM-written Python. If
        # `question`/`csv_url` come from end users, prompt injection can run
        # arbitrary code — deploy only in a sandboxed environment.
        allow_dangerous_code=True,
        extra_tools=[repl_tool],
        return_intermediate_steps=True,
    )


def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
    """Ask a Groq-backed pandas agent to produce a chart answering *question*.

    The CSV is loaded once via ``clean_data``. Each attempt then takes the
    next API key from the shared round-robin rotation, builds a fresh agent
    and runs it. The first attempt whose output yields chart filenames, as
    found by ``extract_chart_filenames``, is returned. At most
    ``len(groq_api_keys)`` attempts are made.

    Args:
        csv_url: Location of the CSV. It is passed to ``clean_data`` and also
            echoed in the prompt so the agent can read the file itself.
        question: The user's charting request.
        chart_required: Currently unused; the prompt is always built with
            ``_prompt_generator(..., True)``. Kept for caller compatibility.

    Returns:
        The non-empty result of ``extract_chart_filenames`` on success, or
        ``None`` when every attempt failed or produced no chart file.
    """
    data = clean_data(csv_url)
    prompt_question = f"{question} and use this csv_url: {csv_url} to read the csv file"

    for _ in range(len(groq_api_keys)):
        # Pre-bind so the except-handler can always log, even if rotation fails.
        key_index = None
        try:
            key_index, api_key = _next_groq_key()
            agent = _build_chart_agent(api_key, data)
            result = agent.invoke({"input": _prompt_generator(prompt_question, True)})
            output = result.get("output", "")

            # Verify chart file creation
            chart_files = extract_chart_filenames(output)
            if chart_files:
                return chart_files
            # Log every no-chart attempt, including the last one.
            logger.info("Langchain chart error (key %s): %s", key_index, output)
        except Exception as exc:
            # Provider/agent failures vary widely; log with traceback and
            # move on to the next key rather than aborting the request.
            logger.warning(
                "Langchain chart error (key %s): %s", key_index, exc, exc_info=True
            )

    logger.info("All API keys exhausted for chart generation")
    return None