# FastApi / gemini_langchain_agent.py
# Uploaded by Soumik555 - commit "Cron job" (2ccbdb1), 13.1 kB
import os
import re
import uuid
from langchain_google_genai import ChatGoogleGenerativeAI
import pandas as pd
from langchain_core.prompts import ChatPromptTemplate
from langchain_experimental.tools import PythonAstREPLTool
from langchain_experimental.agents import create_pandas_dataframe_agent
from dotenv import load_dotenv
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import datetime as dt
# Set the backend for matplotlib to 'Agg' to avoid GUI issues
matplotlib.use('Agg')
load_dotenv()

model_name = 'gemini-2.0-flash'  # Specify the model name

# GEMINI_API_KEYS is read as a comma-separated list of API keys.
# Default to an empty string so a missing variable does not crash the import
# with an opaque AttributeError on None, and strip each entry so values such
# as "key1, key2" or a trailing comma do not produce malformed/empty keys.
google_api_keys = [
    key.strip()
    for key in os.getenv("GEMINI_API_KEYS", "").split(",")
    if key.strip()
]
if not google_api_keys:
    print("Warning: GEMINI_API_KEYS is not set or empty; no LLM instances were created.")

# Create pre-initialized LLM instances (one per API key)
llm_instances = [
    ChatGoogleGenerativeAI(model=model_name, api_key=key)
    for key in google_api_keys
]

current_instance_index = 0  # Track current instance being used
def create_agent(llm, data, tools):
    """Build a tool-calling pandas DataFrame agent.

    Thin wrapper around ``create_pandas_dataframe_agent`` that applies the
    fixed options used by this module.

    Args:
        llm: Chat model instance passed through to the agent factory.
        data: DataFrame passed through to the agent factory.
        tools: Extra tools forwarded as ``extra_tools``.

    Returns:
        Whatever ``create_pandas_dataframe_agent`` returns for these options.
    """
    agent_options = {
        "agent_type": "tool-calling",
        "verbose": True,
        # Required by the factory because the agent executes generated code.
        "allow_dangerous_code": True,
        "extra_tools": tools,
        "return_intermediate_steps": True,
    }
    return create_pandas_dataframe_agent(llm, data, **agent_options)
def _prompt_generator(question: str, chart_required: bool):
    """Return the prompt template for a chat answer or a chart request.

    The user's question is interpolated into one of two instruction texts
    (plain analysis, or chart generation with file-naming rules) and the
    selected text is wrapped with ``ChatPromptTemplate.from_template``.

    Args:
        question: The user's query, embedded verbatim in the prompt text.
        chart_required: When true the chart instructions are used, otherwise
            the plain analysis instructions.

    Returns:
        A ``ChatPromptTemplate`` built from the selected text.

    NOTE(review): the text is interpolated *before* it is handed to
    ``from_template``, so single-brace sequences left in it (the
    ``{unique_id}`` placeholders of the chart text, or braces typed in the
    question) are presumably parsed as template variables - confirm this is
    intended.
    """
    chart_text = f"""You are a senior data analyst working with CSV data. Follow these rules STRICTLY:
1. Generate ONE unique identifier FIRST using: unique_id = uuid.uuid4().hex
2. Visualization requirements:
- Adjust font sizes, rotate labels (45° if needed), truncate for readability
- Figure size: (12, 6)
- Descriptive titles (fontsize=14)
- Colorblind-friendly palettes
- Do not use any visualization library other than matplotlib or seaborn
3. File handling rules:
- Create MAXIMUM 2 charts if absolutely necessary
- For multiple charts:
* Arrange in grid format (2x1 vertical layout preferred)
* Use SAME unique_id with suffixes:
- f"{{unique_id}}_1.png"
- f"{{unique_id}}_2.png"
- Save EXCLUSIVELY to "generated_charts" folder
- File naming: f"chart_{{unique_id}}.png" (for single chart)
4. FINAL OUTPUT MUST BE:
- For single chart: f"generated_charts/chart_{{unique_id}}.png"
- For multiple charts: f"generated_charts/chart_{{unique_id}}.png" (combined grid image)
- **ONLY return this full path string, nothing else**
**Query:** {question}
IMPORTANT:
- Generate the unique_id FIRST before any operations
- Use THE SAME unique_id throughout entire process
- NEVER generate new UUIDs after initial creation
- Return EXACT filepath string of the final saved chart
"""
    chat_text = f"""You are a senior data analyst working with CSV data. Adhere strictly to the following guidelines:
1. **Data Verification:** Always inspect the data with `.sample(5).to_dict()` before performing any analysis.
2. **Data Integrity:** Ensure proper handling of null values to maintain accuracy and reliability.
3. **Communication:** Provide concise, professional, and well-structured responses.
4. Avoid including any internal processing details or references to the methods used to generate your response (ex: based on the tool call, using the function -> These types of phrases.)
**Query:** {question}
"""
    # Only the selected text is turned into a template.
    selected_text = chart_text if chart_required else chat_text
    return ChatPromptTemplate.from_template(selected_text)
def langchain_gemini_csv_handler(csv_url: str, question: str, chart_required: bool):
    """Answer a question about a CSV file using the pooled Gemini instances.

    The CSV is loaded into a DataFrame, then each pre-initialized LLM instance
    is tried at most once per call, starting with the instance that was last
    known to work and wrapping around the pool. The first instance that yields
    a non-None ``output`` wins and is remembered for the next call.

    Previously the shared index was only ever incremented, so once every
    instance had failed a single time (even transiently) all later calls
    returned ``None`` without attempting any request until the process was
    restarted. The index is no longer advanced past the end of the pool.

    Args:
        csv_url: Path or URL handed to ``pd.read_csv``.
        question: The user's query.
        chart_required: Selects the chart prompt instead of the chat prompt.

    Returns:
        The agent's ``output`` value, or ``None`` if every instance failed
        (or no instances are configured).

    Raises:
        Any exception raised by ``pd.read_csv``; loading happens outside the
        per-instance error handling, as before.
    """
    global current_instance_index
    data = pd.read_csv(csv_url)

    instance_count = len(llm_instances)
    if instance_count == 0:
        print("All LLM instances have been exhausted.")
        return None

    # Normalise in case the shared index was left at/after the end of the pool.
    start_index = current_instance_index % instance_count

    # Try all available instances, one bounded pass per call
    for offset in range(instance_count):
        index = (start_index + offset) % instance_count
        try:
            llm = llm_instances[index]
            print(f"Using LLM instance index {index}")
            # Fresh tool per attempt so a failed attempt's REPL state is not reused
            tool = PythonAstREPLTool(
                locals={
                    "df": data,
                    "pd": pd,
                    "np": np,
                    "plt": plt,
                    "sns": sns,
                    "matplotlib": matplotlib,
                    "uuid": uuid,
                    "dt": dt,
                },
            )
            agent = create_agent(llm, data, [tool])
            # NOTE(review): the prompt template object itself is passed as the
            # agent input, not rendered text - confirm this is intended.
            prompt = _prompt_generator(question, chart_required)
            result = agent.invoke({"input": prompt})
            output = result.get("output")
            if output is None:
                raise ValueError("Received None response from agent")
            # Remember the working instance so the next call starts here.
            current_instance_index = index
            return output
        except Exception as e:
            # Broad catch is deliberate: any failure moves on to the next key.
            print(f"Error using LLM instance index {index}: {e}")

    print("All LLM instances have been exhausted.")
    return None
# import os
# import re
# import uuid
# from langchain_google_genai import ChatGoogleGenerativeAI
# import pandas as pd
# from langchain_core.prompts import ChatPromptTemplate
# from langchain_experimental.tools import PythonAstREPLTool
# from langchain_experimental.agents import create_pandas_dataframe_agent
# from dotenv import load_dotenv
# import numpy as np
# import matplotlib.pyplot as plt
# import matplotlib
# import seaborn as sns
# import datetime as dt
# # Set the backend for matplotlib to 'Agg' to avoid GUI issues
# matplotlib.use('Agg')
# load_dotenv()
# model_name = 'gemini-2.0-flash' # Specify the model name
# google_api_keys = os.getenv("GEMINI_API_KEYS").split(",")
# # Create pre-initialized LLM instances
# llm_instances = [
# ChatGoogleGenerativeAI(model=model_name, api_key=key)
# for key in google_api_keys
# ]
# current_instance_index = 0 # Track current instance being used
# def is_retryable_error(error: Exception) -> bool:
# """Check if the error should trigger a retry with next instance"""
# error_str = str(error).lower()
# retry_conditions = [
# # Rate limiting and quota errors
# '429' in error_str,
# 'quota' in error_str,
# 'rate limit' in error_str,
# 'resource exhausted' in error_str,
# 'exceeded' in error_str,
# 'limit reached' in error_str,
# # Authentication and permission errors
# 'permission denied' in error_str,
# 'invalid api key' in error_str,
# 'authentication' in error_str,
# # Server errors
# '500' in error_str,
# '503' in error_str,
# 'service unavailable' in error_str,
# # Connection issues
# 'timeout' in error_str,
# 'connection' in error_str,
# # Content policy
# 'content policy' in error_str,
# 'safety' in error_str,
# 'blocked' in error_str
# ]
# return any(retry_conditions)
# def create_agent(llm, data, tools):
# """Create agent with tool names"""
# return create_pandas_dataframe_agent(
# llm,
# data,
# agent_type="tool-calling",
# verbose=True,
# allow_dangerous_code=True,
# extra_tools=tools,
# return_intermediate_steps=True
# )
# def _prompt_generator(question: str, chart_required: bool):
# chat_prompt = f"""You are a senior data analyst working with CSV data. Adhere strictly to the following guidelines:
# 1. **Data Verification:** Always inspect the data with `.sample(5).to_dict()` before performing any analysis.
# 2. **Data Integrity:** Ensure proper handling of null values to maintain accuracy and reliability.
# 3. **Communication:** Provide concise, professional, and well-structured responses.
# 4. Avoid including any internal processing details or references to the methods used to generate your response (ex: based on the tool call, using the function -> These types of phrases.)
# **Query:** {question}
# """
# chart_prompt = f"""You are a senior data analyst working with CSV data. Follow these rules STRICTLY:
# 1. Generate ONE unique identifier FIRST using: unique_id = uuid.uuid4().hex
# 2. Visualization requirements:
# - Adjust font sizes, rotate labels (45° if needed), truncate for readability
# - Figure size: (12, 6)
# - Descriptive titles (fontsize=14)
# - Colorblind-friendly palettes
# 3. File handling rules:
# - Create MAXIMUM 2 charts if absolutely necessary
# - For multiple charts:
# * Arrange in grid format (2x1 vertical layout preferred)
# * Use SAME unique_id with suffixes:
# - f"{{unique_id}}_1.png"
# - f"{{unique_id}}_2.png"
# - Save EXCLUSIVELY to "generated_charts" folder
# - File naming: f"chart_{{unique_id}}.png" (for single chart)
# 4. FINAL OUTPUT MUST BE:
# - For single chart: f"generated_charts/chart_{{unique_id}}.png"
# - For multiple charts: f"generated_charts/chart_{{unique_id}}.png" (combined grid image)
# - **ONLY return this full path string, nothing else**
# **Query:** {question}
# IMPORTANT:
# - Generate the unique_id FIRST before any operations
# - Use THE SAME unique_id throughout entire process
# - NEVER generate new UUIDs after initial creation
# - Return EXACT filepath string of the final saved chart
# """
# if chart_required:
# return ChatPromptTemplate.from_template(chart_prompt)
# else:
# return ChatPromptTemplate.from_template(chat_prompt)
# def langchain_gemini_csv_handler(csv_url: str, question: str, chart_required: bool):
# global current_instance_index
# data = pd.read_csv(csv_url)
# # Track first error in case all instances fail
# first_error = None
# while current_instance_index < len(llm_instances):
# try:
# llm = llm_instances[current_instance_index]
# print(f"Attempting with LLM instance {current_instance_index + 1}/{len(llm_instances)}")
# # Create tool with validated name
# tool = PythonAstREPLTool(
# locals={
# "df": data,
# "pd": pd,
# "np": np,
# "plt": plt,
# "sns": sns,
# "matplotlib": matplotlib,
# "uuid": uuid,
# "dt": dt
# },
# )
# agent = create_agent(llm, data, [tool])
# prompt = _prompt_generator(question, chart_required)
# result = agent.invoke({"input": prompt})
# output = result.get("output")
# if output is None:
# raise ValueError("Received None response from agent")
# if isinstance(output, str) and any(err in output.lower() for err in ['quota', 'limit', 'exhausted']):
# raise ValueError(f"API limitation detected in response: {output}")
# return output
# except Exception as e:
# error_msg = f"Error with instance {current_instance_index}: {str(e)}"
# print(error_msg)
# # Store first error if not set
# if first_error is None:
# first_error = error_msg
# # Check if we should try next instance
# if is_retryable_error(e):
# current_instance_index += 1
# continue
# else:
# # Non-retryable error - return immediately
# return {
# "error": "Non-retryable error occurred",
# "details": str(e),
# "instance": current_instance_index
# }
# # All instances exhausted
# error_response = {
# "error": "All API instances failed",
# "details": first_error or "Unknown error",
# "attempted_instances": current_instance_index
# }
# print(error_response)
# return error_response