File size: 2,388 Bytes
30e7daa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae28beb
30e7daa
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import logging
import os
import threading
from langchain_groq import ChatGroq
from matplotlib import pyplot as plt
import matplotlib
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from csv_service import clean_data
import seaborn as sns
from langchain_experimental.tools import PythonAstREPLTool
from langchain_experimental.agents import create_pandas_dataframe_agent
from util_service import _prompt_generator

load_dotenv()

# Thread-safe key management for langchain_csv_chat: the cursor below is
# shared by every caller and must only be read/written while holding the lock.
current_langchain_key_index = 0
current_langchain_key_lock = threading.Lock()

# Load environment variables.
# GROQ_API_KEYS is a comma-separated list. Surrounding whitespace is stripped
# and empty entries (e.g. from a trailing comma) are dropped, so "k1, k2,"
# yields ["k1", "k2"]. A missing variable yields an empty list instead of an
# AttributeError at import time.
groq_api_keys = [
    key.strip()
    for key in os.getenv("GROQ_API_KEYS", "").split(",")
    if key.strip()
]
model_name = os.getenv("GROQ_LLM_MODEL")

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

if not groq_api_keys:
    # With no keys, langchain_csv_chat cannot make any attempt and returns None.
    logger.warning("GROQ_API_KEYS is not set or contains no keys.")

def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
    """Answer a question about a CSV using a Groq-backed LangChain pandas agent.

    The CSV is loaded with ``clean_data`` and handed to a pandas-dataframe
    agent that also gets a Python REPL tool preloaded with ``df``, ``pd``,
    ``np``, ``plt``, ``sns`` and ``matplotlib``.

    Each configured Groq API key is tried at most once per call. The first key
    is picked round-robin from the module-level cursor. If a run raises, the
    call moves on to the following keys in order until every key has been
    tried.

    Args:
        csv_url: Location of the CSV; passed to ``clean_data`` and to
            ``_prompt_generator``.
        question: The user's question, forwarded to ``_prompt_generator``.
        chart_required: Forwarded to ``_prompt_generator`` (presumably asks the
            agent for a chart; confirm against that helper).

    Returns:
        The agent result's ``"output"`` value from the first key whose run did
        not raise, or ``None`` if every key failed or no keys are configured.
    """
    global current_langchain_key_index

    data = clean_data(csv_url)
    num_keys = len(groq_api_keys)
    start_index = None

    for attempt in range(num_keys):
        with current_langchain_key_lock:
            # Advance the shared cursor once per attempt, so the next caller
            # starts after the keys this call has already consumed.
            slot = current_langchain_key_index % num_keys
            current_langchain_key_index = slot + 1
            if start_index is None:
                start_index = slot

        # Derive this call's key from its own start slot rather than from the
        # shared cursor. Otherwise concurrent callers could make this call
        # retry one key twice while never trying another.
        key_index = (start_index + attempt) % num_keys
        api_key = groq_api_keys[key_index]

        try:
            llm = ChatGroq(model=model_name, api_key=api_key)
            # Build a fresh REPL tool per attempt so state left by a failed
            # run does not leak into the retry.
            tool = PythonAstREPLTool(locals={
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib
            })

            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,  # agent executes LLM-generated Python
                extra_tools=[tool],
                return_intermediate_steps=True
            )

            prompt = _prompt_generator(question, chart_required, csv_url)
            result = agent.invoke({"input": prompt})
            return result.get("output")

        except Exception as e:
            # Any failure (rate limit, auth, agent/tool error) falls through to
            # the next key; keep the traceback so real bugs remain diagnosable.
            logger.warning("Error with key index %s: %s", key_index, e, exc_info=True)

    # Every key has been tried once (or none are configured).
    logger.error("All API keys have been exhausted.")
    return None