File size: 2,747 Bytes
30e7daa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae28beb
30e7daa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import logging
import os
import threading
import uuid
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from matplotlib import pyplot as plt
import matplotlib
import numpy as np
import pandas as pd
from csv_service import clean_data, extract_chart_filenames
from langchain_experimental.tools import PythonAstREPLTool
from langchain_experimental.agents import create_pandas_dataframe_agent
from util_service import _prompt_generator
import seaborn as sns

load_dotenv()

# Set up logging (done first so configuration problems below can be reported).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Thread-safe key management for langchain_csv_chat.
# NOTE(review): these two globals are not referenced in this file; presumably
# they are used by a langchain_csv_chat defined or imported elsewhere — confirm.
current_langchain_key_index = 0
current_langchain_key_lock = threading.Lock()

# Load environment variables.
# GROQ_API_KEYS is a comma-separated list. Whitespace around entries and empty
# entries (e.g. "k1, k2,") are dropped so they are never sent as bogus keys.
_raw_groq_api_keys = os.getenv("GROQ_API_KEYS")
if _raw_groq_api_keys is None:
    raise RuntimeError("GROQ_API_KEYS environment variable is not set")
groq_api_keys = [key.strip() for key in _raw_groq_api_keys.split(",") if key.strip()]
if not groq_api_keys:
    logger.warning("GROQ_API_KEYS is set but contains no usable keys")
# NOTE(review): may be None if GROQ_LLM_MODEL is unset; ChatGroq will then fail
# at call time.
model_name = os.getenv("GROQ_LLM_MODEL")

# Round-robin cursor into groq_api_keys for langchain_csv_chart. It is only
# read and advanced while holding current_langchain_chart_lock.
current_langchain_chart_key_index = 0
current_langchain_chart_lock = threading.Lock()

def _next_chart_api_key():
    """Return ``(index, key)`` for the next Groq API key in round-robin order.

    The shared cursor is read and advanced while holding
    ``current_langchain_chart_lock`` so concurrent callers update it atomically.
    """
    global current_langchain_chart_key_index

    with current_langchain_chart_lock:
        index = current_langchain_chart_key_index
        current_langchain_chart_key_index = (index + 1) % len(groq_api_keys)
        return index, groq_api_keys[index]


def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
    """Ask a LangChain pandas agent to draw a chart from a CSV; return chart files.

    The CSV is loaded once via ``clean_data``. Then up to ``len(groq_api_keys)``
    attempts are made. Each attempt uses the next Groq API key in round-robin
    order and a fresh copy of the data.

    Args:
        csv_url: CSV location, passed to ``clean_data`` and to the prompt builder.
        question: The user's question, embedded in the agent prompt.
        chart_required: Currently unused. ``_prompt_generator`` is always called
            with ``True`` as its second argument (presumably the chart flag —
            confirm). Kept for interface compatibility.

    Returns:
        The value of ``extract_chart_filenames`` for the first attempt whose
        output names at least one chart file. Returns ``None`` if every attempt
        raised or produced no chart.
    """
    data = clean_data(csv_url)

    for _ in range(len(groq_api_keys)):
        # Rotate the key outside the try so key_index is always bound when we log.
        key_index, api_key = _next_chart_api_key()
        try:
            # Each attempt gets its own copy: the agent executes LLM-generated
            # code (allow_dangerous_code=True) that may mutate ``df`` in place,
            # and a failed attempt must not leak those changes into the next one.
            df = data.copy()

            llm = ChatGroq(model=model_name, api_key=api_key)
            tool = PythonAstREPLTool(locals={
                "df": df,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
                "uuid": uuid,
            })

            agent = create_pandas_dataframe_agent(
                llm,
                df,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[tool],
                return_intermediate_steps=True,
            )

            result = agent.invoke({"input": _prompt_generator(question, True, csv_url)})
            output = result.get("output", "")

            # Verify chart file creation
            chart_files = extract_chart_filenames(output)
        except Exception as exc:
            logger.warning("Langchain chart error (key %s): %s", key_index, exc, exc_info=True)
            continue

        if chart_files:
            return chart_files

        # Log every chart-less attempt, including the last one, so its output
        # is not lost before the "exhausted" message below.
        logger.warning("Langchain chart error (key %s): %s", key_index, output)

    logger.info("All API keys exhausted for chart generation")
    return None