import logging
import os
import threading
import uuid

import matplotlib
import numpy as np
import pandas as pd
import seaborn as sns
from dotenv import load_dotenv
from langchain_experimental.agents import create_pandas_dataframe_agent
from langchain_experimental.tools import PythonAstREPLTool
from langchain_groq import ChatGroq
from matplotlib import pyplot as plt

from csv_service import clean_data, extract_chart_filenames
from util_service import _prompt_generator

load_dotenv()

# Set up logging (configured before reading keys so config problems can be reported)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _parse_api_keys(raw):
    """Split a comma-separated key string into stripped, non-empty keys.

    Returns an empty list when *raw* is ``None``/empty or contains only
    separators and whitespace, instead of raising or yielding blank keys.
    """
    if not raw:
        return []
    return [key.strip() for key in raw.split(",") if key.strip()]


# Load environment variables
groq_api_keys = _parse_api_keys(os.getenv("GROQ_API_KEYS"))
model_name = os.getenv("GROQ_LLM_MODEL")
if not groq_api_keys:
    logger.warning("GROQ_API_KEYS is not set or empty; no LLM calls will be attempted")

# Thread-safe key management for langchain_csv_chat
current_langchain_key_index = 0
current_langchain_key_lock = threading.Lock()

# Thread-safe round-robin key management for langchain_csv_chart
current_langchain_chart_key_index = 0
current_langchain_chart_lock = threading.Lock()


def _next_chart_key():
    """Return ``(index, api_key)`` for the next chart key and advance the shared cursor.

    The cursor is shared by every caller of langchain_csv_chart; reading and
    advancing it happen together under ``current_langchain_chart_lock``.
    Callers must ensure ``groq_api_keys`` is non-empty.
    """
    global current_langchain_chart_key_index
    with current_langchain_chart_lock:
        index = current_langchain_chart_key_index
        api_key = groq_api_keys[index]
        current_langchain_chart_key_index = (index + 1) % len(groq_api_keys)
    return index, api_key


def _build_chart_agent(api_key, data):
    """Create a Groq-backed pandas agent whose Python REPL can use plotting libraries."""
    llm = ChatGroq(model=model_name, api_key=api_key)
    # Extra REPL tool exposing the data plus plotting libs and uuid
    # (presumably so generated code can pick unique chart filenames).
    tool = PythonAstREPLTool(locals={
        "df": data,
        "pd": pd,
        "np": np,
        "plt": plt,
        "sns": sns,
        "matplotlib": matplotlib,
        "uuid": uuid,
    })
    # allow_dangerous_code=True lets the agent execute LLM-generated Python in
    # this process; only acceptable where inputs and deployment are trusted.
    return create_pandas_dataframe_agent(
        llm,
        data,
        agent_type="tool-calling",
        verbose=True,
        allow_dangerous_code=True,
        extra_tools=[tool],
        return_intermediate_steps=True,
    )


def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
    """Ask an LLM agent to answer *question* about the CSV at *csv_url* with a chart.

    Each configured Groq API key is tried at most once, in round-robin order
    shared across calls. An attempt succeeds when ``extract_chart_filenames``
    finds at least one chart filename in the agent's output.

    Args:
        csv_url: Location of the CSV; passed to ``clean_data`` and the prompt generator.
        question: The user's question about the data.
        chart_required: Currently unused (see note below); kept for interface compatibility.

    Returns:
        The non-empty result of ``extract_chart_filenames`` on success, or
        ``None`` once every key has been tried without producing a chart.
    """
    data = clean_data(csv_url)

    for _ in range(len(groq_api_keys)):
        # Select the key outside the try so it is always bound when logging failures.
        current_key, api_key = _next_chart_key()
        try:
            agent = _build_chart_agent(api_key, data)
            # NOTE(review): chart_required is ignored; the prompt always requests a chart.
            result = agent.invoke({"input": _prompt_generator(question, True, csv_url)})
            output = result.get("output", "")

            # Verify chart file creation
            chart_files = extract_chart_filenames(output)
            if chart_files:
                return chart_files
            # Log on every attempt (including the last) so the failing output is not lost.
            logger.warning("Langchain chart error (key %s): %s", current_key, output)
        except Exception as e:
            logger.warning("Langchain chart error (key %s): %s", current_key, e)

    logger.warning("All API keys exhausted for chart generation")
    return None