|
import os |
|
import re |
|
import uuid |
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
import pandas as pd |
|
from langchain_core.prompts import ChatPromptTemplate |
|
from langchain_experimental.tools import PythonAstREPLTool |
|
from langchain_experimental.agents import create_pandas_dataframe_agent |
|
from dotenv import load_dotenv |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import matplotlib |
|
import seaborn as sns |
|
import datetime as dt |
|
|
|
|
|
# Non-interactive backend: charts are saved to files, never shown on screen,
# so this also works on headless servers. Must be set before any figure is made.
matplotlib.use('Agg')

load_dotenv()

model_name = 'gemini-2.0-flash'

# GEMINI_API_KEYS holds a comma-separated list of API keys; one client is
# built per key so the handler can rotate to the next key on quota errors.
_raw_api_keys = os.getenv("GEMINI_API_KEYS")
if not _raw_api_keys:
    # Fail fast with an actionable message instead of the opaque
    # AttributeError that `None.split(",")` would raise.
    raise RuntimeError(
        "GEMINI_API_KEYS environment variable is not set; "
        "provide a comma-separated list of Gemini API keys."
    )
# Strip whitespace and drop empty fragments (e.g. a trailing comma) so we
# never construct a client around a blank key.
google_api_keys = [key.strip() for key in _raw_api_keys.split(",") if key.strip()]

llm_instances = [
    ChatGoogleGenerativeAI(model=model_name, api_key=key)
    for key in google_api_keys
]

# Index of the LLM client currently in use; advanced (and never reset) by the
# handler when a client fails, so exhausted keys are not retried.
current_instance_index = 0
|
|
|
def create_agent(llm, data, tools):
    """Build a pandas-dataframe agent around the given chat model.

    Args:
        llm: Chat model that drives the agent.
        data: DataFrame the agent analyses.
        tools: Extra tools exposed to the agent in addition to the defaults.

    Returns:
        An agent executor configured for tool-calling, with verbose logging,
        dangerous (arbitrary Python) code execution enabled, and intermediate
        steps included in the invoke result.
    """
    agent_config = {
        "agent_type": "tool-calling",
        "verbose": True,
        "allow_dangerous_code": True,
        "extra_tools": tools,
        "return_intermediate_steps": True,
    }
    return create_pandas_dataframe_agent(llm, data, **agent_config)
|
|
|
def _prompt_generator(question: str, chart_required: bool):
    """Return the ChatPromptTemplate appropriate for the request.

    Two prompt variants exist: a plain analysis prompt, and a chart prompt
    with strict file-naming/saving rules. The user's question is interpolated
    into the text before the template is built.

    NOTE(review): after f-string interpolation the chart prompt contains
    single-brace `{unique_id}` placeholders, which `from_template` will parse
    as template input variables — confirm downstream usage never formats the
    template, or escape the braces.

    Args:
        question: The user's natural-language query about the CSV data.
        chart_required: When True, use the chart-generation prompt.

    Returns:
        A ChatPromptTemplate wrapping the selected prompt text.
    """
    analysis_text = f"""You are a senior data analyst working with CSV data. Adhere strictly to the following guidelines:



1. **Data Verification:** Always inspect the data with `.sample(5).to_dict()` before performing any analysis.

2. **Data Integrity:** Ensure proper handling of null values to maintain accuracy and reliability.

3. **Communication:** Provide concise, professional, and well-structured responses.

4. Avoid including any internal processing details or references to the methods used to generate your response (ex: based on the tool call, using the function -> These types of phrases.)



**Query:** {question}

"""

    chart_text = f"""You are a senior data analyst working with CSV data. Follow these rules STRICTLY:



1. Generate ONE unique identifier FIRST using: unique_id = uuid.uuid4().hex

2. Visualization requirements:

   - Adjust font sizes, rotate labels (45° if needed), truncate for readability

   - Figure size: (12, 6)

   - Descriptive titles (fontsize=14)

   - Colorblind-friendly palettes

   - Do not use any visualization library other than matplotlib or seaborn

3. File handling rules:

   - Create MAXIMUM 2 charts if absolutely necessary

   - For multiple charts:

     * Arrange in grid format (2x1 vertical layout preferred)

     * Use SAME unique_id with suffixes:

       - f"{{unique_id}}_1.png"

       - f"{{unique_id}}_2.png"

   - Save EXCLUSIVELY to "generated_charts" folder

   - File naming: f"chart_{{unique_id}}.png" (for single chart)

4. FINAL OUTPUT MUST BE:

   - For single chart: f"generated_charts/chart_{{unique_id}}.png"

   - For multiple charts: f"generated_charts/chart_{{unique_id}}.png" (combined grid image)

   - **ONLY return this full path string, nothing else**



**Query:** {question}



IMPORTANT:

- Generate the unique_id FIRST before any operations

- Use THE SAME unique_id throughout entire process

- NEVER generate new UUIDs after initial creation

- Return EXACT filepath string of the final saved chart

"""

    selected_text = chart_text if chart_required else analysis_text
    return ChatPromptTemplate.from_template(selected_text)
|
|
|
def langchain_gemini_csv_handler(csv_url: str, question: str, chart_required: bool):
    """Answer a question about the CSV at *csv_url*, rotating LLM keys on failure.

    Walks the module-level pool of Gemini clients starting at the shared
    ``current_instance_index``. Any exception advances the index — the advance
    is persistent across calls, so a failed/exhausted key is never retried.

    Args:
        csv_url: Path or URL of the CSV file to load.
        question: Natural-language question about the data.
        chart_required: When True, the agent is prompted to produce chart files.

    Returns:
        The agent's output (answer text or chart path), or ``None`` once
        every LLM instance has failed.
    """
    global current_instance_index

    data = pd.read_csv(csv_url)

    while current_instance_index < len(llm_instances):
        try:
            active_llm = llm_instances[current_instance_index]
            print(f"Using LLM instance index {current_instance_index}")

            # Sandbox namespace handed to the agent's Python REPL tool: the
            # dataframe plus the plotting/util modules the prompts rely on.
            repl_tool = PythonAstREPLTool(
                locals={
                    "df": data,
                    "pd": pd,
                    "np": np,
                    "plt": plt,
                    "sns": sns,
                    "matplotlib": matplotlib,
                    "uuid": uuid,
                    "dt": dt,
                },
            )

            agent = create_agent(active_llm, data, [repl_tool])
            prompt = _prompt_generator(question, chart_required)
            response = agent.invoke({"input": prompt})
            answer = response.get("output")

            if answer is None:
                raise ValueError("Received None response from agent")

            return answer

        except Exception as e:
            print(f"Error using LLM instance index {current_instance_index}: {e}")
            current_instance_index += 1

    print("All LLM instances have been exhausted.")
    return None
|
|
|
|