FastApi / lc_groq_chat.py
Soumik555's picture
openai key rotate
30e7daa
raw
history blame
2.38 kB
import logging
import os
import threading
from langchain_groq import ChatGroq
from matplotlib import pyplot as plt
import matplotlib
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from csv_service import clean_data
import seaborn as sns
from langchain_experimental.tools import PythonAstREPLTool
from langchain_experimental.agents import create_pandas_dataframe_agent
from util_service import _prompt_generator
load_dotenv()

# Thread-safe key management for langchain_csv_chat: index of the next Groq
# API key to hand out. Read and advanced only while holding the lock below.
current_langchain_key_index = 0
current_langchain_key_lock = threading.Lock()

# Load environment variables.
# GROQ_API_KEYS is a comma-separated list. Surrounding whitespace and empty
# entries (e.g. "k1, k2," ) are dropped so every remaining key is usable as-is,
# and a missing variable yields an empty list instead of crashing the import.
groq_api_keys = [
    key.strip()
    for key in os.getenv("GROQ_API_KEYS", "").split(",")
    if key.strip()
]
model_name = os.getenv("GROQ_LLM_MODEL")

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

if not groq_api_keys:
    logger.warning(
        "GROQ_API_KEYS is not set or contains no keys; "
        "langchain_csv_chat will return None."
    )
def _next_langchain_key():
    """Return ``(index, key)`` for the next Groq API key in round-robin order.

    Advances the shared ``current_langchain_key_index`` by one while holding
    ``current_langchain_key_lock``, wrapping back to 0 past the last key.
    Callers must ensure ``groq_api_keys`` is non-empty.
    """
    global current_langchain_key_index
    with current_langchain_key_lock:
        index = current_langchain_key_index % len(groq_api_keys)
        current_langchain_key_index = index + 1
    return index, groq_api_keys[index]


def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
    """Answer ``question`` about the CSV at ``csv_url`` using a Groq-backed pandas agent.

    The CSV is loaded via ``clean_data``. Then up to ``len(groq_api_keys)``
    attempts are made, each using the next key from the shared round-robin
    rotation. The first attempt that completes without raising wins.

    Args:
        csv_url: Location of the CSV, passed to ``clean_data``.
        question: User question, turned into the agent prompt by
            ``_prompt_generator``.
        chart_required: Forwarded to ``_prompt_generator``.

    Returns:
        The ``"output"`` entry of the agent result from the first successful
        attempt, or ``None`` if every attempt failed (or no keys are configured).

    Raises:
        Anything raised by ``clean_data``. Exceptions inside an attempt are
        logged and the next key is tried instead.
    """
    data = clean_data(csv_url)

    for _ in range(len(groq_api_keys)):
        current_key, api_key = _next_langchain_key()
        try:
            llm = ChatGroq(model=model_name, api_key=api_key)
            # Built per attempt, so each attempt gets its own locals dict.
            # Note that "df" is the shared `data` object, not a copy.
            tool = PythonAstREPLTool(locals={
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
            })
            # NOTE(review): allow_dangerous_code=True lets the agent execute
            # LLM-generated Python in this process, and the prompt is derived
            # from user input. Confirm this runs in a sandbox.
            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[tool],
                return_intermediate_steps=True,
            )
            prompt = _prompt_generator(question, chart_required)
            result = agent.invoke({"input": prompt})
            return result.get("output")
        except Exception as e:
            # Any failure (rate limit, auth, agent error, ...) moves on to the
            # next key. Keep the traceback for diagnosis.
            logger.warning(
                "Error with key index %s: %s", current_key, e, exc_info=True
            )

    # If all keys are exhausted, return None
    logger.error("All API keys have been exhausted.")
    return None