import logging
import os
import threading

import matplotlib
import numpy as np
import pandas as pd
import seaborn as sns
from dotenv import load_dotenv
from langchain_experimental.agents import create_pandas_dataframe_agent
from langchain_experimental.tools import PythonAstREPLTool
from langchain_groq import ChatGroq
from matplotlib import pyplot as plt

from csv_service import clean_data
from util_service import _prompt_generator

load_dotenv()

# Set up logging (configured before key parsing so configuration problems
# can be reported).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Thread-safe round-robin key management for langchain_csv_chat.
current_langchain_key_index = 0
current_langchain_key_lock = threading.Lock()

# Load environment variables.
# GROQ_API_KEYS is a comma-separated list; surrounding whitespace and empty
# entries (e.g. from a trailing comma) are ignored so they don't waste a
# failover attempt on an invalid key.
groq_api_keys = [
    key.strip()
    for key in os.getenv("GROQ_API_KEYS", "").split(",")
    if key.strip()
]
model_name = os.getenv("GROQ_LLM_MODEL")

if not groq_api_keys:
    logger.error("GROQ_API_KEYS is not set or contains no keys; "
                 "langchain_csv_chat will always return None.")


def _next_api_key():
    """Return ``(index, key)`` for the next Groq API key in round-robin order.

    The shared cursor is advanced under ``current_langchain_key_lock`` so
    concurrent callers are spread across keys. Must only be called when
    ``groq_api_keys`` is non-empty.
    """
    global current_langchain_key_index
    with current_langchain_key_lock:
        index = current_langchain_key_index % len(groq_api_keys)
        current_langchain_key_index = index + 1
        return index, groq_api_keys[index]


def _build_agent(api_key, data):
    """Create a pandas-dataframe agent for ``data`` backed by Groq ``api_key``."""
    llm = ChatGroq(model=model_name, api_key=api_key)
    # A fresh REPL per attempt, so state left behind by a failed attempt's
    # generated code does not leak into the next one.
    tool = PythonAstREPLTool(locals={
        "df": data,
        "pd": pd,
        "np": np,
        "plt": plt,
        "sns": sns,
        "matplotlib": matplotlib,
    })
    # NOTE(review): create_pandas_dataframe_agent may already register its
    # own Python REPL tool; confirm this extra tool does not collide with it
    # by name, which some tool-calling backends reject.
    # SECURITY NOTE(review): allow_dangerous_code=True executes LLM-generated
    # Python derived from the caller's question on this host. Ensure callers
    # are trusted or the process is sandboxed.
    return create_pandas_dataframe_agent(
        llm,
        data,
        agent_type="tool-calling",
        verbose=True,
        allow_dangerous_code=True,
        extra_tools=[tool],
        return_intermediate_steps=True,
    )


def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
    """Answer ``question`` about the CSV at ``csv_url`` with a LangChain agent.

    Each configured Groq API key is tried at most once, starting from the
    shared round-robin cursor. The first successful answer is returned.

    Args:
        csv_url: Location of the CSV, passed to ``clean_data``.
        question: The user's natural-language question.
        chart_required: Forwarded to ``_prompt_generator`` to shape the prompt.

    Returns:
        The agent's ``"output"`` value, or ``None`` if every key failed or no
        keys are configured.

    Raises:
        Whatever ``clean_data`` raises; loading happens before failover.
    """
    data = clean_data(csv_url)

    for _ in range(len(groq_api_keys)):
        key_index, api_key = _next_api_key()
        try:
            agent = _build_agent(api_key, data)
            prompt = _prompt_generator(question, chart_required)
            result = agent.invoke({"input": prompt})
            return result.get("output")
        except Exception as e:
            # Broad catch is deliberate: any failure (rate limit, auth,
            # agent error) falls over to the next key.
            logger.warning("Error with key index %s: %s", key_index, e)

    # If all keys are exhausted, return None.
    logger.error("All API keys have been exhausted.")
    return None