|
import logging |
|
import os |
|
import threading |
|
import uuid |
|
from dotenv import load_dotenv |
|
from langchain_groq import ChatGroq |
|
from matplotlib import pyplot as plt |
|
import matplotlib |
|
import numpy as np |
|
import pandas as pd |
|
from csv_service import clean_data, extract_chart_filenames |
|
from langchain_experimental.tools import PythonAstREPLTool |
|
from langchain_experimental.agents import create_pandas_dataframe_agent |
|
from util_service import _prompt_generator |
|
import seaborn as sns |
|
|
|
# Load variables from a .env file (python-dotenv) so the os.getenv() calls
# further down this module can see them.
load_dotenv()

# Shared round-robin cursor over groq_api_keys and the lock guarding it.
# Not used by langchain_csv_chart (which has its own cursor); presumably
# consumed by other LangChain helpers elsewhere in this module.
current_langchain_key_index, current_langchain_key_lock = 0, threading.Lock()
|
|
|
|
|
# Groq API keys, given as a comma-separated list in GROQ_API_KEYS and tried in
# rotation by the LangChain helpers. Surrounding whitespace and empty entries
# (e.g. from a trailing comma) are discarded. A missing variable yields an empty
# list instead of an AttributeError at import time, so the retry loops simply
# make no attempts.
groq_api_keys = [key for key in map(str.strip, os.getenv("GROQ_API_KEYS", "").split(",")) if key]

# Groq model identifier passed to ChatGroq; None when GROQ_LLM_MODEL is unset.
model_name = os.getenv("GROQ_LLM_MODEL")
|
|
|
|
|
# Root logging at INFO; basicConfig is a no-op if the root logger already has
# handlers configured.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Round-robin cursor into groq_api_keys used by langchain_csv_chart, plus the
# lock that makes its read-and-advance atomic across threads.
current_langchain_chart_key_index, current_langchain_chart_lock = 0, threading.Lock()
|
|
|
def _next_chart_api_key():
    """Return ``(index, api_key)`` for the next chart key and advance the cursor.

    The read and the advance happen together under
    ``current_langchain_chart_lock``, so concurrent callers cannot read the
    same cursor value before it is advanced. Callers must ensure
    ``groq_api_keys`` is non-empty.
    """
    global current_langchain_chart_key_index

    with current_langchain_chart_lock:
        index = current_langchain_chart_key_index
        api_key = groq_api_keys[index]
        current_langchain_chart_key_index = (index + 1) % len(groq_api_keys)
    return index, api_key


def _build_chart_agent(api_key, data):
    """Create the Groq-backed pandas DataFrame agent for one chart attempt.

    ``data`` is handed to the agent directly and is also bound as ``df`` in an
    extra ``PythonAstREPLTool``. That tool's namespace also pre-binds
    pd, np, plt, sns, matplotlib and uuid (presumably so generated plotting
    code can use them without importing).
    """
    llm = ChatGroq(model=model_name, api_key=api_key)
    repl_tool = PythonAstREPLTool(locals={
        "df": data,
        "pd": pd,
        "np": np,
        "plt": plt,
        "sns": sns,
        "matplotlib": matplotlib,
        "uuid": uuid,
    })
    return create_pandas_dataframe_agent(
        llm,
        data,
        agent_type="tool-calling",
        verbose=True,
        allow_dangerous_code=True,
        extra_tools=[repl_tool],
        return_intermediate_steps=True,
    )


def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
    """Ask a Groq-backed pandas agent to answer *question* by producing a chart.

    The CSV at *csv_url* is loaded with ``clean_data`` and passed to the agent.
    The URL is also embedded in the prompt. Up to ``len(groq_api_keys)``
    attempts are made, each using the next key from the shared round-robin
    cursor, until ``extract_chart_filenames`` finds a chart in the agent output.

    Args:
        csv_url: Location of the CSV file to analyse.
        question: Natural-language request forwarded to the agent.
        chart_required: Currently unused; the prompt is always generated with
            ``True``. NOTE(review): confirm whether this should be forwarded
            to ``_prompt_generator``.

    Returns:
        The non-empty result of ``extract_chart_filenames`` from the first
        attempt that yields one. Returns ``None`` when every attempt fails or
        produces no chart, including when no API keys are configured.

    Raises:
        Whatever ``clean_data`` raises, since it runs before the retry loop.
    """
    data = clean_data(csv_url)

    for _attempt in range(len(groq_api_keys)):
        # Pick the key outside the try so the handler below always has a bound
        # key index to report.
        key_index, api_key = _next_chart_api_key()
        try:
            agent = _build_chart_agent(api_key, data)
            result = agent.invoke({"input": _prompt_generator(f"{question} and use this csv_url: {csv_url} to read the csv file", True)})
            output = result.get("output", "")
            chart_files = extract_chart_filenames(output)
        except Exception:
            # Deliberately broad: any failure with this key (presumably rate
            # limits, invalid keys or agent errors) should fall through to the
            # next key. The traceback is kept for diagnosis.
            logger.warning("Langchain chart error (key %s)", key_index, exc_info=True)
            continue

        if chart_files:
            return chart_files

        # The agent ran but no chart filename was found; record its answer,
        # including on the final attempt.
        logger.info("Langchain chart produced no chart (key %s): %s", key_index, output)

    logger.warning("All API keys exhausted for chart generation")
    return None