# gemini_langchain_agent.py — FastApi service (commit a7202b3)
import os
import re
import uuid
from langchain_google_genai import ChatGoogleGenerativeAI
import pandas as pd
from langchain_core.prompts import ChatPromptTemplate
from langchain_experimental.tools import PythonAstREPLTool
from langchain_experimental.agents import create_pandas_dataframe_agent
from dotenv import load_dotenv
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import datetime as dt
# Use a non-interactive backend so chart generation works in server
# threads/processes with no display attached.
matplotlib.use('Agg')
load_dotenv()

model_name = 'gemini-2.0-flash'  # Gemini model used for every agent call

# GEMINI_API_KEYS holds a comma-separated list of API keys; fail fast with a
# clear message if it is missing (os.getenv returns None, and None.split(...)
# would otherwise crash with an opaque AttributeError).
_raw_keys = os.getenv("GEMINI_API_KEYS")
if not _raw_keys:
    raise RuntimeError(
        "GEMINI_API_KEYS environment variable is not set; "
        "expected a comma-separated list of Gemini API keys"
    )
# Strip whitespace and drop empty entries (e.g. a trailing comma).
google_api_keys = [key.strip() for key in _raw_keys.split(",") if key.strip()]

# Pre-initialized LLM clients, one per API key, used for failover rotation.
llm_instances = [
    ChatGoogleGenerativeAI(model=model_name, api_key=key)
    for key in google_api_keys
]
current_instance_index = 0  # Index of the instance currently in use
def create_agent(llm, data, tools):
    """Build a pandas-dataframe agent around *llm* for *data*.

    The agent runs in tool-calling mode with code execution enabled, exposes
    the given extra *tools*, and returns its intermediate steps alongside the
    final output.
    """
    agent_options = {
        "agent_type": "tool-calling",
        "verbose": True,
        "allow_dangerous_code": True,
        "extra_tools": tools,
        "return_intermediate_steps": True,
    }
    return create_pandas_dataframe_agent(llm, data, **agent_options)
def _prompt_generator(question: str, chart_required: bool) -> ChatPromptTemplate:
    """Build the analyst prompt for the agent.

    Returns a chart-generation prompt when *chart_required* is True, otherwise
    a plain analysis prompt, wrapped in a ChatPromptTemplate.

    ChatPromptTemplate.from_template parses the string with f-string-style
    templating, so literal braces in the user's question would be treated as
    template variables — a stray '{' would raise ValueError at parse time.
    We therefore escape braces in the question before interpolating it.
    """
    # Escape braces so from_template does not treat user text as variables.
    safe_question = question.replace("{", "{{").replace("}", "}}")

    chat_prompt = f"""You are a senior data analyst working with CSV data. Adhere strictly to the following guidelines:
1. **Data Verification:** Always inspect the data with `.sample(5).to_dict()` before performing any analysis.
2. **Data Integrity:** Ensure proper handling of null values to maintain accuracy and reliability.
3. **Communication:** Provide concise, professional, and well-structured responses.
4. Avoid including any internal processing details or references to the methods used to generate your response (ex: based on the tool call, using the function -> These types of phrases.)
**Query:** {safe_question}
"""
    chart_prompt = f"""You are a senior data analyst working with CSV data. Follow these rules STRICTLY:
1. Generate ONE unique identifier FIRST using: unique_id = uuid.uuid4().hex
2. Visualization requirements:
   - Adjust font sizes, rotate labels (45° if needed), truncate for readability
   - Figure size: (12, 6)
   - Descriptive titles (fontsize=14)
   - Colorblind-friendly palettes
   - Do not use any visualization library other than matplotlib or seaborn
3. File handling rules:
   - Create MAXIMUM 2 charts if absolutely necessary
   - For multiple charts:
     * Arrange in grid format (2x1 vertical layout preferred)
     * Use SAME unique_id with suffixes:
       - f"{{unique_id}}_1.png"
       - f"{{unique_id}}_2.png"
   - Save EXCLUSIVELY to "generated_charts" folder
   - File naming: f"chart_{{unique_id}}.png" (for single chart)
4. FINAL OUTPUT MUST BE:
   - For single chart: f"generated_charts/chart_{{unique_id}}.png"
   - For multiple charts: f"generated_charts/chart_{{unique_id}}.png" (combined grid image)
   - **ONLY return this full path string, nothing else**
**Query:** {safe_question}
IMPORTANT:
- Generate the unique_id FIRST before any operations
- Use THE SAME unique_id throughout entire process
- NEVER generate new UUIDs after initial creation
- Return EXACT filepath string of the final saved chart
"""
    if chart_required:
        return ChatPromptTemplate.from_template(chart_prompt)
    else:
        return ChatPromptTemplate.from_template(chat_prompt)
def langchain_gemini_csv_handler(csv_url: str, question: str, chart_required: bool):
    """Answer *question* about the CSV at *csv_url* with a Gemini pandas agent.

    Rotates through the pre-initialized LLM instances (one per API key) until
    one succeeds. Returns the agent's output, or None if every instance fails.

    Raises whatever pandas.read_csv raises if the CSV itself cannot be loaded.
    """
    global current_instance_index
    data = pd.read_csv(csv_url)

    total = len(llm_instances)
    # Try every instance at most once per call, starting from the instance
    # that last worked. Wrapping around with modulo — instead of only ever
    # incrementing the global index — means a transient failure (e.g. a rate
    # limit) on the last key does not permanently exhaust the pool for all
    # future calls, which the previous implementation did.
    for attempt in range(total):
        index = (current_instance_index + attempt) % total
        print(f"Using LLM instance index {index}")
        try:
            llm = llm_instances[index]
            # Python REPL tool pre-loaded with the dataframe and the
            # libraries the prompts instruct the agent to use.
            tool = PythonAstREPLTool(
                locals={
                    "df": data,
                    "pd": pd,
                    "np": np,
                    "plt": plt,
                    "sns": sns,
                    "matplotlib": matplotlib,
                    "uuid": uuid,
                    "dt": dt,
                },
            )
            agent = create_agent(llm, data, [tool])
            prompt = _prompt_generator(question, chart_required)
            result = agent.invoke({"input": prompt})
            output = result.get("output")
            if output is None:
                raise ValueError("Received None response from agent")
            # Remember the working instance so the next call starts here.
            current_instance_index = index
            return output
        except Exception as e:
            # Broad catch is deliberate: any per-key failure (auth, quota,
            # model error) should trigger failover to the next instance.
            print(f"Error using LLM instance index {index}: {e}")

    print("All LLM instances have been exhausted.")
    return None