import os
import re
import uuid
from langchain_google_genai import ChatGoogleGenerativeAI
import pandas as pd
from langchain_core.prompts import ChatPromptTemplate
from langchain_experimental.tools import PythonAstREPLTool
from langchain_experimental.agents import create_pandas_dataframe_agent
from dotenv import load_dotenv
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import datetime as dt
# Set the backend for matplotlib to 'Agg' to avoid GUI issues
matplotlib.use('Agg')
load_dotenv()
model_name = 'gemini-2.0-flash' # Specify the model name
google_api_keys = os.getenv("GEMINI_API_KEYS").split(",")
# Create pre-initialized LLM instances
llm_instances = [
ChatGoogleGenerativeAI(model=model_name, api_key=key)
for key in google_api_keys
]
current_instance_index = 0 # Track current instance being used
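# The handler below tries these instances in order, advancing to the next API
# key whenever a call fails (e.g. on a rate-limit or quota error).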
def create_agent(llm, data, tools):
"""Create agent with tool names"""
return create_pandas_dataframe_agent(
llm,
data,
agent_type="tool-calling",
verbose=True,
allow_dangerous_code=True,
extra_tools=tools,
return_intermediate_steps=True
)
def _prompt_generator(question: str, chart_required: bool):
chat_prompt = f"""You are a senior data analyst working with CSV data. Adhere strictly to the following guidelines:
1. **Data Verification:** Always inspect the data with `.sample(5).to_dict()` before performing any analysis.
2. **Data Integrity:** Ensure proper handling of null values to maintain accuracy and reliability.
3. **Communication:** Provide concise, professional, and well-structured responses.
    4. **Presentation:** Do not include internal processing details or references to the methods used to generate your response (e.g., phrases like "based on the tool call" or "using the function").
**Query:** {question}
"""
chart_prompt = f"""You are a senior data analyst working with CSV data. Follow these rules STRICTLY:
1. Generate ONE unique identifier FIRST using: unique_id = uuid.uuid4().hex
2. Visualization requirements:
        - Adjust font sizes, rotate labels (45° if needed), and truncate long labels for readability
- Figure size: (12, 6)
- Descriptive titles (fontsize=14)
- Colorblind-friendly palettes
- Do not use any visualization library other than matplotlib or seaborn
3. File handling rules:
        - Create a MAXIMUM of 2 charts, and only if absolutely necessary
- For multiple charts:
* Arrange in grid format (2x1 vertical layout preferred)
* Use SAME unique_id with suffixes:
- f"{{unique_id}}_1.png"
- f"{{unique_id}}_2.png"
- Save EXCLUSIVELY to "generated_charts" folder
- File naming: f"chart_{{unique_id}}.png" (for single chart)
4. FINAL OUTPUT MUST BE:
- For single chart: f"generated_charts/chart_{{unique_id}}.png"
- For multiple charts: f"generated_charts/chart_{{unique_id}}.png" (combined grid image)
- **ONLY return this full path string, nothing else**
**Query:** {question}
IMPORTANT:
- Generate the unique_id FIRST before any operations
- Use THE SAME unique_id throughout entire process
- NEVER generate new UUIDs after initial creation
- Return EXACT filepath string of the final saved chart
"""
if chart_required:
return ChatPromptTemplate.from_template(chart_prompt)
else:
return ChatPromptTemplate.from_template(chat_prompt)
def langchain_gemini_csv_handler(csv_url: str, question: str, chart_required: bool):
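    """Answer a question about a CSV file using a Gemini-backed pandas agent.

    Loads the CSV at `csv_url`, builds a Python REPL tool pre-loaded with the
    DataFrame and common libraries, and retries across the available LLM
    instances until one succeeds. Returns the agent's output (a chart file
    path when `chart_required` is True), or None if every instance fails.
    """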
global current_instance_index
data = pd.read_csv(csv_url)
# Try all available instances
while current_instance_index < len(llm_instances):
try:
llm = llm_instances[current_instance_index]
print(f"Using LLM instance index {current_instance_index}")
# Create tool with validated name
tool = PythonAstREPLTool(
locals={
"df": data,
"pd": pd,
"np": np,
"plt": plt,
"sns": sns,
"matplotlib": matplotlib,
"uuid": uuid,
"dt": dt
},
)
agent = create_agent(llm, data, [tool])
prompt = _prompt_generator(question, chart_required)
result = agent.invoke({"input": prompt})
output = result.get("output")
if output is None:
raise ValueError("Received None response from agent")
return output
except Exception as e:
print(f"Error using LLM instance index {current_instance_index}: {e}")
current_instance_index += 1
print("All LLM instances have been exhausted.")
return None
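
# Minimal usage sketch. Assumes a local "sample.csv" and a .env file that sets
# GEMINI_API_KEYS; both the path and the question are placeholders, not part of
# this module.
if __name__ == "__main__":
    answer = langchain_gemini_csv_handler(
        csv_url="sample.csv",  # hypothetical CSV path
        question="Which column has the most null values?",
        chart_required=False,
    )
    print(answer)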