File size: 2,388 Bytes
30e7daa ae28beb 30e7daa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import logging
import os
import threading
from langchain_groq import ChatGroq
from matplotlib import pyplot as plt
import matplotlib
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from csv_service import clean_data
import seaborn as sns
from langchain_experimental.tools import PythonAstREPLTool
from langchain_experimental.agents import create_pandas_dataframe_agent
from util_service import _prompt_generator
load_dotenv()

# Set up logging first so configuration problems below can be reported.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Thread-safe round-robin key management for langchain_csv_chat.
current_langchain_key_index = 0
current_langchain_key_lock = threading.Lock()

# Load environment variables.
# GROQ_API_KEYS is a comma-separated list of keys. Surrounding whitespace and
# empty entries (e.g. from "k1, k2," ) are dropped so every entry is a usable
# key; an unset variable yields an empty list instead of an AttributeError.
groq_api_keys = [
    key.strip()
    for key in os.getenv("GROQ_API_KEYS", "").split(",")
    if key.strip()
]
if not groq_api_keys:
    logger.warning(
        "GROQ_API_KEYS is not set or contains no keys; "
        "langchain_csv_chat will return None."
    )
model_name = os.getenv("GROQ_LLM_MODEL")
def _next_langchain_key():
    """Return ``(index, key)`` for the next Groq API key in round-robin order.

    The shared cursor is read and advanced under ``current_langchain_key_lock``
    so concurrent callers never observe a half-updated index. Callers must
    ensure ``groq_api_keys`` is non-empty.
    """
    global current_langchain_key_index
    with current_langchain_key_lock:
        # Modulo keeps the cursor in range even if it was left past the end.
        index = current_langchain_key_index % len(groq_api_keys)
        current_langchain_key_index = index + 1
        return index, groq_api_keys[index]


def _build_csv_agent(api_key, data):
    """Create a pandas-dataframe agent over *data* backed by a Groq chat model."""
    llm = ChatGroq(model=model_name, api_key=api_key)
    # Extra REPL tool pre-seeded with the dataframe and plotting libraries so
    # generated code can use them without importing.
    tool = PythonAstREPLTool(locals={
        "df": data,
        "pd": pd,
        "np": np,
        "plt": plt,
        "sns": sns,
        "matplotlib": matplotlib,
    })
    # SECURITY: allow_dangerous_code=True lets the agent execute LLM-generated
    # Python inside this process. Only expose this path to trusted users.
    return create_pandas_dataframe_agent(
        llm,
        data,
        agent_type="tool-calling",
        verbose=True,
        allow_dangerous_code=True,
        extra_tools=[tool],
        return_intermediate_steps=True,
    )


def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
    """Answer *question* about the CSV at *csv_url* with a LangChain agent.

    Makes up to ``len(groq_api_keys)`` attempts. Each attempt takes the next
    key from a cursor shared across threads and calls. Any exception raised
    while building or invoking the agent moves on to the next key.

    Args:
        csv_url: Location of the CSV; passed to ``clean_data`` and to the
            prompt generator.
        question: The user's natural-language question.
        chart_required: Forwarded to ``_prompt_generator``.

    Returns:
        The ``"output"`` entry of the agent result. Returns None if the prompt
        could not be built, if every attempted key failed, or if no keys are
        configured.

    Raises:
        Whatever ``clean_data`` raises; it is called outside the retry loop.
    """
    data = clean_data(csv_url)
    if not groq_api_keys:
        logger.error("All API keys have been exhausted.")
        return None

    # The prompt does not depend on the API key, so build it once rather than
    # per attempt. A failure here is reported as None, as it was before.
    try:
        prompt = _prompt_generator(question, chart_required, csv_url)
    except Exception:
        logger.exception("Failed to build prompt for %s", csv_url)
        return None

    for _ in range(len(groq_api_keys)):
        key_index, api_key = _next_langchain_key()
        try:
            agent = _build_csv_agent(api_key, data)
            result = agent.invoke({"input": prompt})
        except Exception as exc:
            # Broad catch is deliberate: any failure (rate limit, auth, agent
            # error) is treated as a reason to try the next key.
            logger.warning("Error with key index %s: %s", key_index, exc)
            continue
        return result.get("output")

    logger.error("All API keys have been exhausted.")
    return None