"""Gemini-backed CSV question answering via a LangChain pandas agent.

Module-level side effects (preserved from the original): selects the
non-GUI matplotlib backend, loads a .env file, and pre-builds one
ChatGoogleGenerativeAI client per API key listed (comma-separated) in
the GEMINI_API_KEYS environment variable.
"""
import datetime as dt
import os
import uuid

import matplotlib

# Select the non-interactive 'Agg' backend BEFORE pyplot is imported so
# no GUI toolkit is ever initialised (required on headless servers).
matplotlib.use('Agg')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_experimental.agents import create_pandas_dataframe_agent
from langchain_experimental.tools import PythonAstREPLTool
from langchain_google_genai import ChatGoogleGenerativeAI

load_dotenv()

model_name = 'gemini-2.0-flash'  # Specify the model name

# BUGFIX: os.getenv() returns None when the variable is unset, which
# previously crashed with an opaque AttributeError on .split().
_raw_keys = os.getenv("GEMINI_API_KEYS")
if not _raw_keys:
    raise RuntimeError(
        "GEMINI_API_KEYS environment variable is not set; expected a "
        "comma-separated list of Google API keys."
    )
google_api_keys = [key.strip() for key in _raw_keys.split(",") if key.strip()]

# Pre-initialized LLM clients, one per key — used as a failover pool.
llm_instances = [
    ChatGoogleGenerativeAI(model=model_name, api_key=key)
    for key in google_api_keys
]
current_instance_index = 0  # Index of the instance that last succeeded


def create_agent(llm, data, tools):
    """Create a tool-calling pandas DataFrame agent.

    Args:
        llm: A LangChain chat model instance.
        data: The pandas DataFrame the agent analyses.
        tools: Extra tools (e.g. a Python REPL) exposed to the agent.

    Returns:
        An AgentExecutor that returns its intermediate steps.
    """
    return create_pandas_dataframe_agent(
        llm,
        data,
        agent_type="tool-calling",
        verbose=True,
        # The pandas agent executes model-generated Python; acceptable
        # here only because the questions come from trusted users.
        allow_dangerous_code=True,
        extra_tools=tools,
        return_intermediate_steps=True,
    )


def _prompt_generator(question: str, chart_required: bool, csv_url: str):
    """Build the analysis prompt for *question*.

    Args:
        question: The user's natural-language query.
        chart_required: When True, use the chart-generation prompt;
            otherwise use the plain text-analysis prompt.
        csv_url: Location of the CSV file the agent must read.

    Returns:
        A ChatPromptTemplate wrapping the selected prompt text.
    """
    # BUGFIX: csv_url is now quoted inside read_csv(...). Previously it
    # was interpolated bare, instructing the model to emit invalid
    # Python such as `pd.read_csv(https://...)`.
    chat_prompt = f"""You are a senior data analyst working with CSV data. Adhere strictly to the following guidelines:
1. **Data Verification:** Always inspect the data with `.sample(5).to_dict()` before performing any analysis.
2. **Data Integrity:** Ensure proper handling of null values to maintain accuracy and reliability.
3. **Communication:** Provide concise, professional, and well-structured responses.
4. Avoid including any internal processing details or references to the methods used to generate your response (ex: based on the tool call, using the function -> These types of phrases.)
5. Always use pd.read_csv("{csv_url}") to read the CSV file.
**Query:** {question}
"""
    chart_prompt = f"""You are a senior data analyst working with CSV data. Follow these rules STRICTLY:
1. Generate ONE unique identifier FIRST using: unique_id = uuid.uuid4().hex
2. Visualization requirements:
   - Adjust font sizes, rotate labels (45° if needed), truncate for readability
   - Figure size: (12, 6)
   - Descriptive titles (fontsize=14)
   - Colorblind-friendly palettes
   - Do not use any visualization library other than matplotlib or seaborn
3. File handling rules:
   - Create MAXIMUM 2 charts if absolutely necessary
   - For multiple charts:
     * Arrange in grid format (2x1 vertical layout preferred)
     * Use SAME unique_id with suffixes:
       - f"{{unique_id}}_1.png"
       - f"{{unique_id}}_2.png"
   - Save EXCLUSIVELY to "generated_charts" folder
   - File naming: f"chart_{{unique_id}}.png" (for single chart)
4. FINAL OUTPUT MUST BE:
   - For single chart: f"generated_charts/chart_{{unique_id}}.png"
   - For multiple charts: f"generated_charts/chart_{{unique_id}}.png" (combined grid image)
   - **ONLY return this full path string, nothing else**
**Query:** {question}
IMPORTANT:
- Generate the unique_id FIRST before any operations
- Use THE SAME unique_id throughout entire process
- NEVER generate new UUIDs after initial creation
- Return EXACT filepath string of the final saved chart
- Always use pd.read_csv("{csv_url}") to read the CSV file
"""
    if chart_required:
        return ChatPromptTemplate.from_template(chart_prompt)
    return ChatPromptTemplate.from_template(chat_prompt)


def langchain_gemini_csv_handler(csv_url: str, question: str, chart_required: bool):
    """Answer *question* about the CSV at *csv_url*, with LLM failover.

    Tries each pre-initialized LLM instance in round-robin order,
    starting from the one that last succeeded. Returns the agent's
    output string, or None when every instance fails.

    Args:
        csv_url: Location of the CSV file to analyse.
        question: The user's natural-language query.
        chart_required: Whether a chart should be produced.

    Returns:
        The agent output (str), or None if all instances failed.
    """
    global current_instance_index
    data = pd.read_csv(csv_url)

    # BUGFIX: the original only ever incremented the global index, so
    # after one full round of failures every later call returned None
    # immediately and permanently. Round-robin keeps failover sticky
    # (start from the last-good instance) without exhausting the pool.
    for attempt in range(len(llm_instances)):
        idx = (current_instance_index + attempt) % len(llm_instances)
        try:
            llm = llm_instances[idx]
            print(f"Using LLM instance index {idx}")

            # REPL tool pre-loaded with the libraries the prompts rely on.
            tool = PythonAstREPLTool(
                locals={
                    "df": data,
                    "pd": pd,
                    "np": np,
                    "plt": plt,
                    "sns": sns,
                    "matplotlib": matplotlib,
                    "uuid": uuid,
                    "dt": dt,
                },
            )

            agent = create_agent(llm, data, [tool])
            prompt = _prompt_generator(question, chart_required, csv_url)
            result = agent.invoke({"input": prompt})

            output = result.get("output")
            if output is None:
                raise ValueError("Received None response from agent")

            current_instance_index = idx  # Remember the working instance
            return output
        except Exception as e:
            # Broad catch is deliberate: any failure (quota, network,
            # bad model output) should trigger failover to the next key.
            print(f"Error using LLM instance index {idx}: {e}")

    print("All LLM instances have been exhausted.")
    return None