Fixed langchain gemini agent underscore issue
Browse files- gemini_langchain_agent.py +127 -0
- orchestrator_functions.py +64 -56
gemini_langchain_agent.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import uuid
|
4 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
5 |
+
import pandas as pd
|
6 |
+
from langchain_core.prompts import ChatPromptTemplate
|
7 |
+
from langchain_experimental.tools import PythonAstREPLTool
|
8 |
+
from langchain_experimental.agents import create_pandas_dataframe_agent
|
9 |
+
from dotenv import load_dotenv
|
10 |
+
import numpy as np
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
import matplotlib
|
13 |
+
import seaborn as sns
|
14 |
+
import datetime as dt
|
15 |
+
|
16 |
+
# Set the backend for matplotlib to 'Agg' to avoid GUI issues
matplotlib.use('Agg')

load_dotenv()
model_name = 'gemini-2.0-flash'  # Model name for Google Generative AI

# Fail fast with a clear message instead of an AttributeError on `.split`
# when the environment variable is missing or empty.
_raw_api_keys = os.getenv("GEMINI_API_KEYS")
if not _raw_api_keys:
    raise RuntimeError(
        "GEMINI_API_KEYS environment variable is not set; "
        "expected a comma-separated list of API keys."
    )
# Strip whitespace so both "key1,key2" and "key1, key2" work; drop empty
# entries produced by trailing commas.
google_api_keys = [key.strip() for key in _raw_api_keys.split(",") if key.strip()]
current_key_index = 0  # Global index for API keys (rotated on failure)
23 |
+
|
24 |
+
def create_agent(llm, data, tools):
    """Build a pandas-dataframe agent for *data* using *llm*.

    The agent is tool-calling, verbose, allowed to execute generated
    Python code, augmented with *tools*, and configured to return its
    intermediate steps alongside the final output.
    """
    agent_config = dict(
        agent_type="tool-calling",
        verbose=True,
        allow_dangerous_code=True,
        extra_tools=tools,
        return_intermediate_steps=True,
    )
    return create_pandas_dataframe_agent(llm, data, **agent_config)
36 |
+
|
37 |
+
|
38 |
+
def _prompt_generator(question: str, chart_required: bool):
    """Return a ChatPromptTemplate for a CSV-analysis query.

    Args:
        question: The user's natural-language question.
        chart_required: When True, return the chart-generation prompt;
            otherwise return the plain analysis (chat) prompt.

    Fix: literal braces in *question* are escaped before the text is fed
    to ``ChatPromptTemplate.from_template``, which would otherwise parse
    ``{...}`` in the user's question as template input variables and fail.
    """
    safe_question = question.replace("{", "{{").replace("}", "}}")

    chat_prompt = f"""You are a senior data analyst working with CSV data. Adhere strictly to the following guidelines:

1. **Data Verification:** Always inspect the data with `.sample(5).to_dict()` before performing any analysis.
2. **Data Integrity:** Ensure proper handling of null values to maintain accuracy and reliability.
3. **Communication:** Provide concise, professional, and well-structured responses.
4. Avoid including any internal processing details or references to the methods used to generate your response (ex: based on the tool call, using the function -> These types of phrases.)

**Query:** {safe_question}

"""

    # NOTE(review): the ``{{unique_id}}`` occurrences below render as
    # ``{unique_id}`` in the f-string, which ``from_template`` then treats
    # as a template input variable — presumably unintended; confirm before
    # changing, as it alters the prompt text the model sees.
    chart_prompt = f"""You are a senior data analyst working with CSV data. Follow these rules STRICTLY:

1. Generate ONE unique identifier FIRST using: unique_id = uuid.uuid4().hex
2. Visualization requirements:
   - Adjust font sizes, rotate labels (45° if needed), truncate for readability
   - Figure size: (12, 6)
   - Descriptive titles (fontsize=14)
   - Colorblind-friendly palettes
3. File handling rules:
   - Create MAXIMUM 2 charts if absolutely necessary
   - For multiple charts:
     * Arrange in grid format (2x1 vertical layout preferred)
     * Use SAME unique_id with suffixes:
       - f"{{unique_id}}_1.png"
       - f"{{unique_id}}_2.png"
   - Save EXCLUSIVELY to "generated_charts" folder
   - File naming: f"chart_{{unique_id}}.png" (for single chart)
4. FINAL OUTPUT MUST BE:
   - For single chart: f"generated_charts/chart_{{unique_id}}.png"
   - For multiple charts: f"generated_charts/chart_{{unique_id}}.png" (combined grid image)
   - **ONLY return this full path string, nothing else**

**Query:** {safe_question}

IMPORTANT:
- Generate the unique_id FIRST before any operations
- Use THE SAME unique_id throughout entire process
- NEVER generate new UUIDs after initial creation
- Return EXACT filepath string of the final saved chart
"""

    if chart_required:
        return ChatPromptTemplate.from_template(chart_prompt)
    else:
        return ChatPromptTemplate.from_template(chat_prompt)
87 |
+
|
88 |
+
def langchain_gemini_csv_handler(csv_url: str, question: str, chart_required: bool):
    """Answer *question* about the CSV at *csv_url* via a Gemini agent.

    Rotates through the configured Google API keys on failure (the global
    ``current_key_index`` persists across calls so successive requests keep
    using the last working key). Returns the agent's output string, or
    ``None`` once every key has been tried unsuccessfully.
    """
    global current_key_index
    data = pd.read_csv(csv_url)

    total_keys = len(google_api_keys)
    # Try at most one attempt per configured key.
    for _attempt in range(total_keys):
        try:
            api_key = google_api_keys[current_key_index]
            print(f"Using API key index {current_key_index}")

            llm = ChatGoogleGenerativeAI(model=model_name, api_key=api_key)

            # Python REPL tool pre-loaded with the dataframe and the
            # libraries the prompts expect to be available.
            repl_tool = PythonAstREPLTool(
                locals={
                    "df": data,
                    "pd": pd,
                    "np": np,
                    "plt": plt,
                    "sns": sns,
                    "matplotlib": matplotlib,
                    "uuid": uuid,
                    "dt": dt,
                },
            )

            agent = create_agent(llm, data, [repl_tool])
            prompt = _prompt_generator(question, chart_required)
            result = agent.invoke({"input": prompt})
            return result.get("output")

        except Exception as err:
            # Rotate to the next key and retry on any failure.
            print(f"Error using API key index {current_key_index}: {err}")
            current_key_index = (current_key_index + 1) % total_keys

    print("All API keys have been exhausted.")
    return None
orchestrator_functions.py
CHANGED
@@ -19,6 +19,7 @@ import numpy as np
|
|
19 |
import matplotlib.pyplot as plt
|
20 |
import matplotlib
|
21 |
import seaborn as sns
|
|
|
22 |
from supabase_service import upload_image_to_supabase
|
23 |
from util_service import _prompt_generator, process_answer
|
24 |
import matplotlib
|
@@ -385,78 +386,85 @@ async def csv_chart(csv_url: str, query: str):
|
|
385 |
|
386 |
|
387 |
|
388 |
-
async def csv_chat(csv_url: str, query: str):
|
389 |
-
|
390 |
-
|
391 |
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
|
408 |
-
|
409 |
-
|
410 |
|
411 |
-
|
412 |
-
|
413 |
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
|
423 |
|
424 |
|
425 |
|
426 |
|
427 |
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
|
432 |
-
#
|
433 |
-
|
434 |
-
#
|
435 |
-
#
|
436 |
|
437 |
-
#
|
438 |
-
#
|
439 |
|
440 |
-
#
|
441 |
-
#
|
442 |
|
443 |
-
#
|
444 |
|
445 |
-
|
446 |
-
|
|
|
447 |
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
#
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
459 |
|
460 |
-
|
461 |
-
|
462 |
-
|
|
|
19 |
import matplotlib.pyplot as plt
|
20 |
import matplotlib
|
21 |
import seaborn as sns
|
22 |
+
from gemini_langchain_agent import langchain_gemini_csv_handler
|
23 |
from supabase_service import upload_image_to_supabase
|
24 |
from util_service import _prompt_generator, process_answer
|
25 |
import matplotlib
|
|
|
386 |
|
387 |
|
388 |
|
389 |
+
# async def csv_chat(csv_url: str, query: str):
|
390 |
+
# try:
|
391 |
+
# updated_query = f"{query} and Do not show any charts or graphs."
|
392 |
|
393 |
+
# # Process with langchain_chat first
|
394 |
+
# try:
|
395 |
+
# lang_answer = await asyncio.to_thread(
|
396 |
+
# langchain_csv_chat, csv_url, query, False
|
397 |
+
# )
|
398 |
+
# if lang_answer is not None:
|
399 |
+
# return {"answer": jsonable_encoder(lang_answer)}
|
400 |
+
# raise Exception("Langchain failed to process")
|
401 |
+
# except Exception as langchain_error:
|
402 |
+
# print(f"Langchain error, falling back to Groq: {str(langchain_error)}")
|
403 |
|
404 |
+
# # Process with groq_chat if langchain fails
|
405 |
+
# try:
|
406 |
+
# groq_answer = await asyncio.to_thread(groq_chat, csv_url, updated_query)
|
407 |
+
# print("groq_answer:", groq_answer)
|
408 |
|
409 |
+
# if process_answer(groq_answer) == "Empty response received.":
|
410 |
+
# return {"answer": "Sorry, I couldn't find relevant data..."}
|
411 |
|
412 |
+
# if process_answer(groq_answer) or groq_answer is None:
|
413 |
+
# return {"answer": "error"}
|
414 |
|
415 |
+
# return {"answer": jsonable_encoder(groq_answer)}
|
416 |
+
# except Exception as groq_error:
|
417 |
+
# print(f"Groq processing error: {str(groq_error)}")
|
418 |
+
# return {"answer": "error"}
|
419 |
|
420 |
+
# except Exception as e:
|
421 |
+
# print(f"Error processing request: {str(e)}")
|
422 |
+
# return {"answer": "error"}
|
423 |
|
424 |
|
425 |
|
426 |
|
427 |
|
428 |
|
429 |
+
async def csv_chat(csv_url: str, query: str):
    """Answer a chat-style (no-chart) question about a CSV file.

    Strategy: try the Gemini LangChain handler first; if it raises or
    returns ``None``, fall back to the Groq-backed LangChain handler.
    Both handlers are synchronous, so they run via ``asyncio.to_thread``.

    Returns a ``{"answer": ...}`` dict — the encoded answer on success,
    an apology string when no relevant data is found, or ``"error"``.
    """
    try:
        # Primary path: Gemini LangChain agent.
        try:
            gemini_answer = await asyncio.to_thread(
                langchain_gemini_csv_handler, csv_url, query, False
            )
            if gemini_answer is not None:
                return {"answer": jsonable_encoder(gemini_answer)}
            # None means every Gemini API key failed — fall through.
        except Exception as gemini_error:
            print(f"Gemini error, falling back to Groq-LangChain: {str(gemini_error)}")

        # Fallback path: Groq-backed LangChain handler.
        try:
            lang_answer = await asyncio.to_thread(
                langchain_csv_chat, csv_url, query, False
            )
            # process_answer() is truthy when the answer is unusable.
            if not process_answer(lang_answer):
                return {"answer": jsonable_encoder(lang_answer)}
            return {"answer": "Sorry, I couldn't find relevant data..."}
        except Exception as langchain_error:
            print(f"LangChain processing error: {str(langchain_error)}")
            return {"answer": "error"}

    except Exception as e:
        print(f"Error processing request: {str(e)}")
        return {"answer": "error"}