import os | |
import re | |
import uuid | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
import pandas as pd | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_experimental.tools import PythonAstREPLTool | |
from langchain_experimental.agents import create_pandas_dataframe_agent | |
from dotenv import load_dotenv | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib | |
import seaborn as sns | |
import datetime as dt | |
# Use the non-interactive 'Agg' backend so charts can be rendered without a GUI.
matplotlib.use('Agg')

# Pull settings from a local .env file (if present) into the environment.
load_dotenv()

model_name = 'gemini-2.0-flash'  # Gemini model used by every LLM instance

# GEMINI_API_KEYS holds a comma-separated list of API keys.  Surrounding
# whitespace and empty entries (e.g. from a trailing comma) are discarded so
# that no client is created with a blank or malformed key.
google_api_keys = [
    key.strip()
    for key in (os.getenv("GEMINI_API_KEYS") or "").split(",")
    if key.strip()
]
if not google_api_keys:
    # Fail at import time with an actionable message instead of the
    # AttributeError that ``None.split`` would raise for an unset variable.
    raise RuntimeError(
        "GEMINI_API_KEYS is not set or contains no usable keys; "
        "expected a comma-separated list of API keys."
    )

# One pre-initialized chat model per API key.
llm_instances = [
    ChatGoogleGenerativeAI(model=model_name, api_key=key)
    for key in google_api_keys
]

# Index into ``llm_instances`` of the instance to try first on the next request.
current_instance_index = 0
def create_agent(llm, data, tools):
    """Build a tool-calling pandas DataFrame agent.

    Args:
        llm: Chat model that drives the agent.
        data: DataFrame handed to ``create_pandas_dataframe_agent``.
        tools: Additional tools forwarded as ``extra_tools``.

    Returns:
        The object returned by ``create_pandas_dataframe_agent``, configured
        to be verbose and to return its intermediate steps.
    """
    agent_options = {
        "agent_type": "tool-calling",
        "verbose": True,
        # Explicit opt-in flag of create_pandas_dataframe_agent.
        "allow_dangerous_code": True,
        "extra_tools": tools,
        "return_intermediate_steps": True,
    }
    return create_pandas_dataframe_agent(llm, data, **agent_options)
def _prompt_generator(question: str, chart_required: bool):
    """Build the prompt template for a text answer or a chart request.

    Args:
        question: The user's question about the CSV data.
        chart_required: When true, return the chart-generation instructions;
            otherwise return the plain analysis instructions.

    Returns:
        A ``ChatPromptTemplate`` created with ``from_template``.  The chart
        template keeps ``{unique_id}`` as a placeholder in its file-name
        examples; the user's question never contributes template variables.
    """
    # The text below is parsed as a template by ``from_template``.  Doubling
    # the braces keeps any "{" or "}" typed by the user literal instead of
    # turning it into a template variable (or a parse error when unbalanced).
    safe_question = question.replace("{", "{{").replace("}", "}}")

    if chart_required:
        template_text = f"""You are a senior data analyst working with CSV data. Follow these rules STRICTLY:
1. Generate ONE unique identifier FIRST using: unique_id = uuid.uuid4().hex
2. Visualization requirements:
- Adjust font sizes, rotate labels (45° if needed), truncate for readability
- Figure size: (12, 6)
- Descriptive titles (fontsize=14)
- Colorblind-friendly palettes
- Do not use any visualization library other than matplotlib or seaborn
3. File handling rules:
- Create MAXIMUM 2 charts if absolutely necessary
- For multiple charts:
* Arrange in grid format (2x1 vertical layout preferred)
* Use SAME unique_id with suffixes:
- f"{{unique_id}}_1.png"
- f"{{unique_id}}_2.png"
- Save EXCLUSIVELY to "generated_charts" folder
- File naming: f"chart_{{unique_id}}.png" (for single chart)
4. FINAL OUTPUT MUST BE:
- For single chart: f"generated_charts/chart_{{unique_id}}.png"
- For multiple charts: f"generated_charts/chart_{{unique_id}}.png" (combined grid image)
- **ONLY return this full path string, nothing else**
**Query:** {safe_question}
IMPORTANT:
- Generate the unique_id FIRST before any operations
- Use THE SAME unique_id throughout entire process
- NEVER generate new UUIDs after initial creation
- Return EXACT filepath string of the final saved chart
"""
    else:
        template_text = f"""You are a senior data analyst working with CSV data. Adhere strictly to the following guidelines:
1. **Data Verification:** Always inspect the data with `.sample(5).to_dict()` before performing any analysis.
2. **Data Integrity:** Ensure proper handling of null values to maintain accuracy and reliability.
3. **Communication:** Provide concise, professional, and well-structured responses.
4. Avoid including any internal processing details or references to the methods used to generate your response (ex: based on the tool call, using the function -> These types of phrases.)
**Query:** {safe_question}
"""

    return ChatPromptTemplate.from_template(template_text)
def _render_prompt(prompt) -> str:
    """Return the plain-text instructions held by a prompt template.

    The agent's ``input`` field expects text, while ``_prompt_generator``
    returns a ``ChatPromptTemplate``.  Every template variable is filled with
    its own ``{name}`` spelling so literal placeholders written in the prompt
    (such as ``{unique_id}``) appear unchanged in the rendered text.
    """
    if isinstance(prompt, str):
        return prompt
    literal_fields = {name: "{" + name + "}" for name in prompt.input_variables}
    return prompt.format_messages(**literal_fields)[0].content


def langchain_gemini_csv_handler(csv_url: str, question: str, chart_required: bool):
    """Answer a question about a CSV file with a pandas agent.

    Each call makes at most one attempt per configured LLM instance, starting
    with the instance remembered in ``current_instance_index`` and wrapping
    around the list, so earlier failures never disable the handler for later
    calls.

    Args:
        csv_url: Location passed to ``pandas.read_csv``.
        question: The user's question about the data.
        chart_required: Whether the chart prompt should be used instead of
            the plain analysis prompt.

    Returns:
        The agent's ``output`` value, or ``None`` when the prompt cannot be
        built or every LLM instance fails.

    Raises:
        Any exception raised by ``pandas.read_csv``; loading the CSV happens
        outside the retry logic.
    """
    global current_instance_index

    data = pd.read_csv(csv_url)

    instance_count = len(llm_instances)
    if instance_count == 0:
        print("All LLM instances have been exhausted.")
        return None

    # The prompt does not depend on the LLM instance, so build it once.  A
    # failure here is not an API problem and must not rotate through keys.
    try:
        agent_input = _render_prompt(_prompt_generator(question, chart_required))
    except Exception as e:
        print(f"Error building prompt: {e}")
        return None

    # The modulo also recovers from an out-of-range index left by older state.
    first_index = current_instance_index % instance_count
    for offset in range(instance_count):
        index = (first_index + offset) % instance_count
        # Remember the instance in use so the next call starts with it.
        current_instance_index = index
        try:
            llm = llm_instances[index]
            print(f"Using LLM instance index {index}")
            # Fresh tool per attempt so each run starts from the same names.
            tool = PythonAstREPLTool(
                locals={
                    "df": data,
                    "pd": pd,
                    "np": np,
                    "plt": plt,
                    "sns": sns,
                    "matplotlib": matplotlib,
                    "uuid": uuid,
                    "dt": dt,
                },
            )
            agent = create_agent(llm, data, [tool])
            result = agent.invoke({"input": agent_input})
            output = result.get("output")
            if output is None:
                raise ValueError("Received None response from agent")
            return output
        except Exception as e:
            print(f"Error using LLM instance index {index}: {e}")
        finally:
            # pyplot keeps figures alive until they are closed; release any
            # figures the agent's code created during this attempt.
            plt.close("all")

    print("All LLM instances have been exhausted.")
    return None
# import os | |
# import re | |
# import uuid | |
# from langchain_google_genai import ChatGoogleGenerativeAI | |
# import pandas as pd | |
# from langchain_core.prompts import ChatPromptTemplate | |
# from langchain_experimental.tools import PythonAstREPLTool | |
# from langchain_experimental.agents import create_pandas_dataframe_agent | |
# from dotenv import load_dotenv | |
# import numpy as np | |
# import matplotlib.pyplot as plt | |
# import matplotlib | |
# import seaborn as sns | |
# import datetime as dt | |
# # Set the backend for matplotlib to 'Agg' to avoid GUI issues | |
# matplotlib.use('Agg') | |
# load_dotenv() | |
# model_name = 'gemini-2.0-flash' # Specify the model name | |
# google_api_keys = os.getenv("GEMINI_API_KEYS").split(",") | |
# # Create pre-initialized LLM instances | |
# llm_instances = [ | |
# ChatGoogleGenerativeAI(model=model_name, api_key=key) | |
# for key in google_api_keys | |
# ] | |
# current_instance_index = 0 # Track current instance being used | |
# def is_retryable_error(error: Exception) -> bool: | |
# """Check if the error should trigger a retry with next instance""" | |
# error_str = str(error).lower() | |
# retry_conditions = [ | |
# # Rate limiting and quota errors | |
# '429' in error_str, | |
# 'quota' in error_str, | |
# 'rate limit' in error_str, | |
# 'resource exhausted' in error_str, | |
# 'exceeded' in error_str, | |
# 'limit reached' in error_str, | |
# # Authentication and permission errors | |
# 'permission denied' in error_str, | |
# 'invalid api key' in error_str, | |
# 'authentication' in error_str, | |
# # Server errors | |
# '500' in error_str, | |
# '503' in error_str, | |
# 'service unavailable' in error_str, | |
# # Connection issues | |
# 'timeout' in error_str, | |
# 'connection' in error_str, | |
# # Content policy | |
# 'content policy' in error_str, | |
# 'safety' in error_str, | |
# 'blocked' in error_str | |
# ] | |
# return any(retry_conditions) | |
# def create_agent(llm, data, tools): | |
# """Create agent with tool names""" | |
# return create_pandas_dataframe_agent( | |
# llm, | |
# data, | |
# agent_type="tool-calling", | |
# verbose=True, | |
# allow_dangerous_code=True, | |
# extra_tools=tools, | |
# return_intermediate_steps=True | |
# ) | |
# def _prompt_generator(question: str, chart_required: bool): | |
# chat_prompt = f"""You are a senior data analyst working with CSV data. Adhere strictly to the following guidelines: | |
# 1. **Data Verification:** Always inspect the data with `.sample(5).to_dict()` before performing any analysis. | |
# 2. **Data Integrity:** Ensure proper handling of null values to maintain accuracy and reliability. | |
# 3. **Communication:** Provide concise, professional, and well-structured responses. | |
# 4. Avoid including any internal processing details or references to the methods used to generate your response (ex: based on the tool call, using the function -> These types of phrases.) | |
# **Query:** {question} | |
# """ | |
# chart_prompt = f"""You are a senior data analyst working with CSV data. Follow these rules STRICTLY: | |
# 1. Generate ONE unique identifier FIRST using: unique_id = uuid.uuid4().hex | |
# 2. Visualization requirements: | |
# - Adjust font sizes, rotate labels (45° if needed), truncate for readability | |
# - Figure size: (12, 6) | |
# - Descriptive titles (fontsize=14) | |
# - Colorblind-friendly palettes | |
# 3. File handling rules: | |
# - Create MAXIMUM 2 charts if absolutely necessary | |
# - For multiple charts: | |
# * Arrange in grid format (2x1 vertical layout preferred) | |
# * Use SAME unique_id with suffixes: | |
# - f"{{unique_id}}_1.png" | |
# - f"{{unique_id}}_2.png" | |
# - Save EXCLUSIVELY to "generated_charts" folder | |
# - File naming: f"chart_{{unique_id}}.png" (for single chart) | |
# 4. FINAL OUTPUT MUST BE: | |
# - For single chart: f"generated_charts/chart_{{unique_id}}.png" | |
# - For multiple charts: f"generated_charts/chart_{{unique_id}}.png" (combined grid image) | |
# - **ONLY return this full path string, nothing else** | |
# **Query:** {question} | |
# IMPORTANT: | |
# - Generate the unique_id FIRST before any operations | |
# - Use THE SAME unique_id throughout entire process | |
# - NEVER generate new UUIDs after initial creation | |
# - Return EXACT filepath string of the final saved chart | |
# """ | |
# if chart_required: | |
# return ChatPromptTemplate.from_template(chart_prompt) | |
# else: | |
# return ChatPromptTemplate.from_template(chat_prompt) | |
# def langchain_gemini_csv_handler(csv_url: str, question: str, chart_required: bool): | |
# global current_instance_index | |
# data = pd.read_csv(csv_url) | |
# # Track first error in case all instances fail | |
# first_error = None | |
# while current_instance_index < len(llm_instances): | |
# try: | |
# llm = llm_instances[current_instance_index] | |
# print(f"Attempting with LLM instance {current_instance_index + 1}/{len(llm_instances)}") | |
# # Create tool with validated name | |
# tool = PythonAstREPLTool( | |
# locals={ | |
# "df": data, | |
# "pd": pd, | |
# "np": np, | |
# "plt": plt, | |
# "sns": sns, | |
# "matplotlib": matplotlib, | |
# "uuid": uuid, | |
# "dt": dt | |
# }, | |
# ) | |
# agent = create_agent(llm, data, [tool]) | |
# prompt = _prompt_generator(question, chart_required) | |
# result = agent.invoke({"input": prompt}) | |
# output = result.get("output") | |
# if output is None: | |
# raise ValueError("Received None response from agent") | |
# if isinstance(output, str) and any(err in output.lower() for err in ['quota', 'limit', 'exhausted']): | |
# raise ValueError(f"API limitation detected in response: {output}") | |
# return output | |
# except Exception as e: | |
# error_msg = f"Error with instance {current_instance_index}: {str(e)}" | |
# print(error_msg) | |
# # Store first error if not set | |
# if first_error is None: | |
# first_error = error_msg | |
# # Check if we should try next instance | |
# if is_retryable_error(e): | |
# current_instance_index += 1 | |
# continue | |
# else: | |
# # Non-retryable error - return immediately | |
# return { | |
# "error": "Non-retryable error occurred", | |
# "details": str(e), | |
# "instance": current_instance_index | |
# } | |
# # All instances exhausted | |
# error_response = { | |
# "error": "All API instances failed", | |
# "details": first_error or "Unknown error", | |
# "attempted_instances": current_instance_index | |
# } | |
# print(error_response) | |
# return error_response |