stay hard
Browse files- .env +1 -0
- controller.py +11 -3
- rethink_gemini_agents/gemini_langchain_service.py +212 -0
.env
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
GOOGLE_GENERATIVE_AI_API_KEYS=AIzaSyC6CuXP7oMHbQymv5desJ7HJypSPisYN7s,AIzaSyAzV1YGajXhC2N8n8b3bgU1PHUXNWdZiUk,AIzaSyAYvv5urC0lhzNYYO1f4a4EYqTsZrmubrM,AIzaSyD7VsIKjtNBlQUWXQ_bIFbl240f2AUT7nc,AIzaSyCFnFsqplkNeQFjRh2EhkK90t48wkmyJQU
|
2 |
GOOGLE_GENERATIVE_AI_MODEL=gemini-2.0-flash-lite-preview-02-05
|
|
|
3 |
|
4 |
# Pandas API keys
|
5 |
PANDASAI_API_KEYS=$2a$10$VVwPEnzFxnEnJhk2u5ef1ewTuT3rNK59QpYQWAhUY29FHH4b7fwNC,$2a$10$5ikmN9RtNWHvP8aLnHfm.epO/XhVF1Pvk1Chy2Fqa.4x232a374xK,$2a$10$aAvr1DH3Pt3KLPDYa.JED..d83Pl6M4xnQd6uY8fadNkqSEv9KaYK,$2a$10$tJkqyS9Us36ernP1N4/8dOE088rCm7MC3gIj2RQMlaalY34EkkDHy,$2a$10$V0tThT/XnmlHbJucM00yN.hxz9r3ZwVqe0sQRQwDZAGHmhMq81D7O,$2a$10$d9vj8iPtD/L/i2B5AhiKTexOSpZ52XZTRUDkZa4p0vnI6RCj7f0K2,$2a$10$PZdCvVJB8301iDIZrZ2z8uB9d68kaBeOjaOIbbXGgqlZ2frbTm0eG,$2a$10$SHK.YTrTQcol/yM/RD8tZOcIF2fUTXtaETpDo8G0At90NxQ1HGk.C,$2a$10$QYz2Fp2fFZNq80HjAC/Okuy/PZFMgGgpPuQAyFDVtvB0G9bCn8Cee,$2a$10$SGY3HoCX0jbBXHSbpwGH1OEC/yPwT5792MjSZeWYVLew52pE4gR0y,$2a$10$QHPpvXwCXhHtKyx4jWMTh.8Mz1azTEQbDdDMpmikOzdgKtFfOq3FG,$2a$10$KoTsqdLPNIBiLRHWUg/6guqxNrB4ByljnMDTN0HJXmGl.PagdxpGm,$2a$10$ERsxnbIwk0LOMqmFX1SfjuMSXzh5gsBqm1BnYXFNEBAS3J1AfK24m,$2a$10$zwX4F0/pxXgmuAfDteFlHeXswX8cvVAvkv8mBAJ4WLvAEaUM3v266,$2a$10$LPA4FUIjg6CbZYEhi3NLRuY2Yar5SbT9gYoQ/oZuPaFUxNUyaJ/ii,$2a$10$kLDISr9ivaqcYiAZ1TmBOeclXK0C5a/LPPB3Rsxme19NwVPhznQya,$2a$10$qpoxy4k4sQya0tY7/lSEkuEuwVQGEl757A.jVPGNEh6p5tN6Yofyq,$2a$10$TDndpw.NWwx2k5X.9eI30uAaga8pbYO/erUEblVGcj6ydzSgzdVde,$2a$10$TtZtCWXgVSUhaNMMsuOjLuC6tCY1GTzUR/PvIUdowXYQdmefgpvbW,$2a$10$Orj1ZiURJkREK30gdwEYLeV7mY657jJhif8SckIPdvctjkWHXHrq6,$2a$10$CxEXDLjFtK1.nE9GuIt1duxLbvYtz2EA7x1LqddNF44kKVcc8aGZC
|
|
|
1 |
GOOGLE_GENERATIVE_AI_API_KEYS=AIzaSyC6CuXP7oMHbQymv5desJ7HJypSPisYN7s,AIzaSyAzV1YGajXhC2N8n8b3bgU1PHUXNWdZiUk,AIzaSyAYvv5urC0lhzNYYO1f4a4EYqTsZrmubrM,AIzaSyD7VsIKjtNBlQUWXQ_bIFbl240f2AUT7nc,AIzaSyCFnFsqplkNeQFjRh2EhkK90t48wkmyJQU
|
2 |
GOOGLE_GENERATIVE_AI_MODEL=gemini-2.0-flash-lite-preview-02-05
|
3 |
+
GOOGLE_GENERATIVE_AI_MODEL_LANGCHAIN_AGENT=gemini-2.0-flash
|
4 |
|
5 |
# Pandas API keys
|
6 |
PANDASAI_API_KEYS=$2a$10$VVwPEnzFxnEnJhk2u5ef1ewTuT3rNK59QpYQWAhUY29FHH4b7fwNC,$2a$10$5ikmN9RtNWHvP8aLnHfm.epO/XhVF1Pvk1Chy2Fqa.4x232a374xK,$2a$10$aAvr1DH3Pt3KLPDYa.JED..d83Pl6M4xnQd6uY8fadNkqSEv9KaYK,$2a$10$tJkqyS9Us36ernP1N4/8dOE088rCm7MC3gIj2RQMlaalY34EkkDHy,$2a$10$V0tThT/XnmlHbJucM00yN.hxz9r3ZwVqe0sQRQwDZAGHmhMq81D7O,$2a$10$d9vj8iPtD/L/i2B5AhiKTexOSpZ52XZTRUDkZa4p0vnI6RCj7f0K2,$2a$10$PZdCvVJB8301iDIZrZ2z8uB9d68kaBeOjaOIbbXGgqlZ2frbTm0eG,$2a$10$SHK.YTrTQcol/yM/RD8tZOcIF2fUTXtaETpDo8G0At90NxQ1HGk.C,$2a$10$QYz2Fp2fFZNq80HjAC/Okuy/PZFMgGgpPuQAyFDVtvB0G9bCn8Cee,$2a$10$SGY3HoCX0jbBXHSbpwGH1OEC/yPwT5792MjSZeWYVLew52pE4gR0y,$2a$10$QHPpvXwCXhHtKyx4jWMTh.8Mz1azTEQbDdDMpmikOzdgKtFfOq3FG,$2a$10$KoTsqdLPNIBiLRHWUg/6guqxNrB4ByljnMDTN0HJXmGl.PagdxpGm,$2a$10$ERsxnbIwk0LOMqmFX1SfjuMSXzh5gsBqm1BnYXFNEBAS3J1AfK24m,$2a$10$zwX4F0/pxXgmuAfDteFlHeXswX8cvVAvkv8mBAJ4WLvAEaUM3v266,$2a$10$LPA4FUIjg6CbZYEhi3NLRuY2Yar5SbT9gYoQ/oZuPaFUxNUyaJ/ii,$2a$10$kLDISr9ivaqcYiAZ1TmBOeclXK0C5a/LPPB3Rsxme19NwVPhznQya,$2a$10$qpoxy4k4sQya0tY7/lSEkuEuwVQGEl757A.jVPGNEh6p5tN6Yofyq,$2a$10$TDndpw.NWwx2k5X.9eI30uAaga8pbYO/erUEblVGcj6ydzSgzdVde,$2a$10$TtZtCWXgVSUhaNMMsuOjLuC6tCY1GTzUR/PvIUdowXYQdmefgpvbW,$2a$10$Orj1ZiURJkREK30gdwEYLeV7mY657jJhif8SckIPdvctjkWHXHrq6,$2a$10$CxEXDLjFtK1.nE9GuIt1duxLbvYtz2EA7x1LqddNF44kKVcc8aGZC
|
controller.py
CHANGED
@@ -26,6 +26,7 @@ import matplotlib.pyplot as plt
|
|
26 |
import matplotlib
|
27 |
import seaborn as sns
|
28 |
from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
|
|
|
29 |
from rethink_gemini_agents.rethink_chat import gemini_llm_chat
|
30 |
from util_service import _prompt_generator, process_answer
|
31 |
from fastapi.middleware.cors import CORSMiddleware
|
@@ -295,9 +296,16 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
|
|
295 |
csv_url = request.get("csv_url")
|
296 |
decoded_url = unquote(csv_url)
|
297 |
|
298 |
-
|
299 |
-
|
300 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
301 |
|
302 |
if if_initial_chat_question(query):
|
303 |
answer = await asyncio.to_thread(
|
|
|
26 |
import matplotlib
|
27 |
import seaborn as sns
|
28 |
from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
|
29 |
+
from rethink_gemini_agents.gemini_langchain_service import langchain_gemini_csv_chat
|
30 |
from rethink_gemini_agents.rethink_chat import gemini_llm_chat
|
31 |
from util_service import _prompt_generator, process_answer
|
32 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
296 |
csv_url = request.get("csv_url")
|
297 |
decoded_url = unquote(csv_url)
|
298 |
|
299 |
+
if if_initial_chat_question(query):
|
300 |
+
answer = await asyncio.to_thread(
|
301 |
+
langchain_gemini_csv_chat, decoded_url, query, False
|
302 |
+
)
|
303 |
+
logger.info("gemini langchain_answer --> ", answer)
|
304 |
+
return {"answer": jsonable_encoder(answer)}
|
305 |
+
|
306 |
+
gemini_answer = await asyncio.to_thread(gemini_llm_chat, decoded_url, query)
|
307 |
+
logger.info("gemini_answer --> ", gemini_answer)
|
308 |
+
return {"answer": gemini_answer}
|
309 |
|
310 |
if if_initial_chat_question(query):
|
311 |
answer = await asyncio.to_thread(
|
rethink_gemini_agents/gemini_langchain_service.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import uuid
|
4 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
5 |
+
import pandas as pd
|
6 |
+
from langchain_core.prompts import ChatPromptTemplate
|
7 |
+
from langchain_experimental.tools import PythonAstREPLTool
|
8 |
+
from langchain_experimental.agents import create_pandas_dataframe_agent
|
9 |
+
from dotenv import load_dotenv
|
10 |
+
import numpy as np
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
import matplotlib
|
13 |
+
import seaborn as sns
|
14 |
+
|
15 |
+
|
16 |
+
# Set the backend for matplotlib to 'Agg' to avoid GUI issues
|
17 |
+
matplotlib.use('Agg')
|
18 |
+
|
19 |
+
load_dotenv()
|
20 |
+
model_name = os.getenv("GOOGLE_GENERATIVE_AI_MODEL_LANGCHAIN_AGENT")
|
21 |
+
google_api_keys = os.getenv("GOOGLE_GENERATIVE_AI_API_KEYS").split(",")
|
22 |
+
current_key_index = 0 # Global index for API keys
|
23 |
+
|
24 |
+
|
25 |
+
def _prompt_generator(question: str, chart_required: bool):
|
26 |
+
|
27 |
+
chat_prompt = f"""You are a senior data analyst working with CSV data. Adhere strictly to the following guidelines:
|
28 |
+
|
29 |
+
1. **Data Verification:** Always inspect the data with `.sample(5).to_dict()` before performing any analysis.
|
30 |
+
2. **Data Integrity:** Ensure proper handling of null values to maintain accuracy and reliability.
|
31 |
+
3. **Communication:** Provide concise, professional, and well-structured responses.
|
32 |
+
4. Avoid including any internal processing details or references to the methods used to generate your response (ex: based on the tool call, using the function -> These types of phrases.)
|
33 |
+
|
34 |
+
**Query:** {question}
|
35 |
+
|
36 |
+
"""
|
37 |
+
|
38 |
+
chart_prompt = f"""You are a senior data analyst working with CSV data. Follow these rules STRICTLY:
|
39 |
+
|
40 |
+
1. Generate ONE unique identifier FIRST using: unique_id = uuid.uuid4().hex
|
41 |
+
2. Visualization requirements:
|
42 |
+
- Adjust font sizes, rotate labels (45° if needed), truncate for readability
|
43 |
+
- Figure size: (12, 6)
|
44 |
+
- Descriptive titles (fontsize=14)
|
45 |
+
- Colorblind-friendly palettes
|
46 |
+
3. File handling rules:
|
47 |
+
- Create MAXIMUM 2 charts if absolutely necessary
|
48 |
+
- For multiple charts:
|
49 |
+
* Arrange in grid format (2x1 vertical layout preferred)
|
50 |
+
* Use SAME unique_id with suffixes:
|
51 |
+
- f"{{unique_id}}_1.png"
|
52 |
+
- f"{{unique_id}}_2.png"
|
53 |
+
- Save EXCLUSIVELY to "generated_charts" folder
|
54 |
+
- File naming: f"chart_{{unique_id}}.png" (for single chart)
|
55 |
+
4. FINAL OUTPUT MUST BE:
|
56 |
+
- For single chart: f"generated_charts/chart_{{unique_id}}.png"
|
57 |
+
- For multiple charts: f"generated_charts/chart_{{unique_id}}.png" (combined grid image)
|
58 |
+
- ONLY return this full path string, nothing else
|
59 |
+
|
60 |
+
**Query:** {question}
|
61 |
+
|
62 |
+
IMPORTANT:
|
63 |
+
- Generate the unique_id FIRST before any operations
|
64 |
+
- Use THE SAME unique_id throughout entire process
|
65 |
+
- NEVER generate new UUIDs after initial creation
|
66 |
+
- Return EXACT filepath string of the final saved chart
|
67 |
+
"""
|
68 |
+
|
69 |
+
|
70 |
+
if chart_required:
|
71 |
+
return ChatPromptTemplate.from_template(chart_prompt)
|
72 |
+
else:
|
73 |
+
return ChatPromptTemplate.from_template(chat_prompt)
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
def langchain_gemini_csv_chat(csv_url: str, question: str, chart_required: bool):
|
78 |
+
global current_key_index
|
79 |
+
|
80 |
+
data = pd.read_csv(csv_url)
|
81 |
+
# Try each API key until a successful response is generated or keys run out
|
82 |
+
attempts = 0
|
83 |
+
total_keys = len(google_api_keys)
|
84 |
+
while attempts < total_keys:
|
85 |
+
try:
|
86 |
+
# Select the current API key
|
87 |
+
api_key = google_api_keys[current_key_index]
|
88 |
+
print(f"Using API key index {current_key_index}")
|
89 |
+
|
90 |
+
# Initialize the LLM with the current API key
|
91 |
+
llm = ChatGoogleGenerativeAI(model=model_name, api_key=api_key)
|
92 |
+
|
93 |
+
# Prepare the Python REPL tool with the dataframe and necessary libraries
|
94 |
+
tool = PythonAstREPLTool(locals={
|
95 |
+
"df": data,
|
96 |
+
"pd": pd,
|
97 |
+
"np": np,
|
98 |
+
"plt": plt, # Ensure plt is available
|
99 |
+
"sns": sns,
|
100 |
+
"matplotlib": matplotlib,
|
101 |
+
"uuid": uuid,
|
102 |
+
})
|
103 |
+
|
104 |
+
# Create the pandas agent with the provided tools and settings
|
105 |
+
agent = create_pandas_dataframe_agent(
|
106 |
+
llm,
|
107 |
+
data,
|
108 |
+
agent_type="openai-tools",
|
109 |
+
verbose=True,
|
110 |
+
allow_dangerous_code=True,
|
111 |
+
extra_tools=[tool],
|
112 |
+
return_intermediate_steps=True
|
113 |
+
)
|
114 |
+
|
115 |
+
chat_prompt = _prompt_generator(question, chart_required)
|
116 |
+
# Attempt to invoke the agent with the question
|
117 |
+
result = agent.invoke({"input": chat_prompt})
|
118 |
+
# If successful, return the output
|
119 |
+
return result.get("output")
|
120 |
+
|
121 |
+
except Exception as e:
|
122 |
+
# Log the error along with the current API key index
|
123 |
+
print(f"Error using API key index {current_key_index}: {e}")
|
124 |
+
|
125 |
+
# Move to the next API key
|
126 |
+
current_key_index += 1
|
127 |
+
attempts += 1
|
128 |
+
|
129 |
+
# If all keys have been exhausted, exit the loop
|
130 |
+
if current_key_index >= total_keys:
|
131 |
+
print("All API keys have been exhausted.")
|
132 |
+
return None
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
def langchain_gemini_csv_chart(csv_url: str, question: str, chart_required: bool):
|
138 |
+
global current_key_index
|
139 |
+
data = pd.read_csv(csv_url)
|
140 |
+
|
141 |
+
# Try each API key until a successful response is generated or keys run out
|
142 |
+
attempts = 0
|
143 |
+
total_keys = len(google_api_keys)
|
144 |
+
while attempts < total_keys:
|
145 |
+
try:
|
146 |
+
# Select the current API key
|
147 |
+
api_key = google_api_keys[current_key_index]
|
148 |
+
print(f"Using API key index {current_key_index}")
|
149 |
+
|
150 |
+
# Initialize the LLM with the current API key
|
151 |
+
llm = ChatGoogleGenerativeAI(model=model_name, api_key=api_key)
|
152 |
+
|
153 |
+
# Prepare the Python REPL tool with the dataframe and necessary libraries
|
154 |
+
tool = PythonAstREPLTool(locals={
|
155 |
+
"df": data,
|
156 |
+
"pd": pd,
|
157 |
+
"np": np,
|
158 |
+
"plt": plt, # Ensure plt is available
|
159 |
+
"sns": sns,
|
160 |
+
"matplotlib": matplotlib
|
161 |
+
})
|
162 |
+
|
163 |
+
# Create the pandas agent with the provided tools and settings
|
164 |
+
agent = create_pandas_dataframe_agent(
|
165 |
+
llm,
|
166 |
+
data,
|
167 |
+
agent_type="openai-tools",
|
168 |
+
verbose=True,
|
169 |
+
allow_dangerous_code=True,
|
170 |
+
extra_tools=[tool],
|
171 |
+
return_intermediate_steps=True
|
172 |
+
)
|
173 |
+
|
174 |
+
chart_prompt = _prompt_generator(question, chart_required)
|
175 |
+
# Attempt to invoke the agent with the question
|
176 |
+
result = agent.invoke({"input": chart_prompt})
|
177 |
+
# If successful, return the output
|
178 |
+
return result.get("output")
|
179 |
+
|
180 |
+
except Exception as e:
|
181 |
+
# Log the error along with the current API key index
|
182 |
+
print(f"Error using API key index {current_key_index}: {e}")
|
183 |
+
|
184 |
+
# Move to the next API key
|
185 |
+
current_key_index += 1
|
186 |
+
attempts += 1
|
187 |
+
|
188 |
+
# If all keys have been exhausted, exit the loop
|
189 |
+
if current_key_index >= total_keys:
|
190 |
+
print("All API keys have been exhausted.")
|
191 |
+
return None
|
192 |
+
|
193 |
+
|
194 |
+
|
195 |
+
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
# Example usage:
|
200 |
+
# if __name__ == "__main__":
|
201 |
+
# csv_url = "./documents/titanic.csv"
|
202 |
+
# question = "Create 2 beautiful visualizations of the data using different chart styles (line, bar etc..), return file names"
|
203 |
+
# output = langchain_gemini_csv_chat(csv_url, question, True)
|
204 |
+
# print("Agent output:", output)
|
205 |
+
|
206 |
+
# # Define a regex pattern that matches 'temp' followed by one or more digits.
|
207 |
+
# pattern = r"temp\d+"
|
208 |
+
|
209 |
+
# # Use re.findall to extract all occurrences that match the pattern.
|
210 |
+
# names = re.findall(pattern, output)
|
211 |
+
|
212 |
+
# print(names)
|