stay hard
Browse files- controller.py +32 -19
- rethink_gemini_agents/rethink_chart.py +2 -2
- rethink_gemini_agents/rethink_chat.py +24 -19
- util_service.py +1 -1
controller.py
CHANGED
@@ -26,7 +26,7 @@ import matplotlib.pyplot as plt
|
|
26 |
import matplotlib
|
27 |
import seaborn as sns
|
28 |
from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
|
29 |
-
from rethink_gemini_agents.gemini_langchain_service import langchain_gemini_csv_chat
|
30 |
from rethink_gemini_agents.rethink_chart import gemini_llm_chart
|
31 |
from rethink_gemini_agents.rethink_chat import gemini_llm_chat
|
32 |
from util_service import _prompt_generator, process_answer
|
@@ -302,18 +302,19 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
|
|
302 |
langchain_gemini_csv_chat, decoded_url, query, False
|
303 |
)
|
304 |
logger.info("gemini langchain_answer --> ", answer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
305 |
return {"answer": jsonable_encoder(answer)}
|
|
|
|
|
306 |
|
307 |
-
gemini_answer = await asyncio.to_thread(gemini_llm_chat, decoded_url, query)
|
308 |
-
logger.info("gemini_answer --> ", gemini_answer)
|
309 |
-
return {"answer": gemini_answer}
|
310 |
|
311 |
-
if if_initial_chat_question(query):
|
312 |
-
answer = await asyncio.to_thread(
|
313 |
-
langchain_csv_chat, decoded_url, query, False
|
314 |
-
)
|
315 |
-
logger.info("langchain_answer --> ", answer)
|
316 |
-
return {"answer": jsonable_encoder(answer)}
|
317 |
|
318 |
# Process with groq_chat first
|
319 |
groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
|
@@ -799,26 +800,32 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
799 |
loop = asyncio.get_running_loop()
|
800 |
# First, try the langchain-based method if the question qualifies
|
801 |
|
802 |
-
|
|
|
803 |
langchain_gemini_chart_answer = await asyncio.to_thread(
|
804 |
langchain_gemini_csv_chat, csv_url, query, False
|
805 |
)
|
806 |
logger.info("gemini langchain_answer --> ", langchain_gemini_chart_answer)
|
807 |
if isinstance(langchain_gemini_chart_answer, list) and len(langchain_gemini_chart_answer) > 0:
|
808 |
return FileResponse(langchain_gemini_chart_answer[0], media_type="image/png")
|
809 |
-
|
810 |
-
gemini_answer = await asyncio.to_thread(gemini_llm_chart, csv_url, query)
|
811 |
-
logger.info("gemini_answer --> ", gemini_answer)
|
812 |
-
return FileResponse(gemini_answer, media_type="image/png")
|
813 |
|
814 |
-
|
815 |
langchain_result = await loop.run_in_executor(
|
816 |
process_executor, langchain_csv_chart, csv_url, query, True
|
817 |
)
|
818 |
-
logger.info("
|
819 |
if isinstance(langchain_result, list) and len(langchain_result) > 0:
|
820 |
return FileResponse(langchain_result[0], media_type="image/png")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
821 |
|
|
|
822 |
# Next, try the groq-based method
|
823 |
groq_result = await loop.run_in_executor(
|
824 |
process_executor, groq_chart, csv_url, query
|
@@ -828,13 +835,19 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
828 |
return FileResponse(groq_result, media_type="image/png")
|
829 |
|
830 |
# Fallback: try langchain-based again
|
831 |
-
logger.error("Groq chart generation failed, trying langchain....")
|
832 |
langchain_paths = await loop.run_in_executor(
|
833 |
process_executor, langchain_csv_chart, csv_url, query, True
|
834 |
)
|
835 |
-
logger.info("
|
836 |
if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
|
837 |
return FileResponse(langchain_paths[0], media_type="image/png")
|
|
|
|
|
|
|
|
|
|
|
|
|
838 |
else:
|
839 |
logger.error("All chart generation methods failed")
|
840 |
return {"error": "All chart generation methods failed"}
|
|
|
26 |
import matplotlib
|
27 |
import seaborn as sns
|
28 |
from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
|
29 |
+
from rethink_gemini_agents.gemini_langchain_service import langchain_gemini_csv_chart, langchain_gemini_csv_chat
|
30 |
from rethink_gemini_agents.rethink_chart import gemini_llm_chart
|
31 |
from rethink_gemini_agents.rethink_chat import gemini_llm_chat
|
32 |
from util_service import _prompt_generator, process_answer
|
|
|
302 |
langchain_gemini_csv_chat, decoded_url, query, False
|
303 |
)
|
304 |
logger.info("gemini langchain_answer --> ", answer)
|
305 |
+
if(answer and process_answer(answer)):
|
306 |
+
|
307 |
+
logger.error("Gemini chat initial query failed, trying groq based langchain_csv_chat...")
|
308 |
+
|
309 |
+
answer = await asyncio.to_thread(
|
310 |
+
langchain_csv_chat, decoded_url, query, False
|
311 |
+
)
|
312 |
+
logger.info("groq langchain_answer --> ", answer)
|
313 |
return {"answer": jsonable_encoder(answer)}
|
314 |
+
|
315 |
+
|
316 |
|
|
|
|
|
|
|
317 |
|
|
|
|
|
|
|
|
|
|
|
|
|
318 |
|
319 |
# Process with groq_chat first
|
320 |
groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
|
|
|
800 |
loop = asyncio.get_running_loop()
|
801 |
# First, try the langchain-based method if the question qualifies
|
802 |
|
803 |
+
|
804 |
+
if if_initial_chart_question(query):
|
805 |
langchain_gemini_chart_answer = await asyncio.to_thread(
|
806 |
langchain_gemini_csv_chat, csv_url, query, False
|
807 |
)
|
808 |
logger.info("gemini langchain_answer --> ", langchain_gemini_chart_answer)
|
809 |
if isinstance(langchain_gemini_chart_answer, list) and len(langchain_gemini_chart_answer) > 0:
|
810 |
return FileResponse(langchain_gemini_chart_answer[0], media_type="image/png")
|
|
|
|
|
|
|
|
|
811 |
|
812 |
+
logger.error("Gemini Langchain chart generation failed for initial question, trying groq langchain....")
|
813 |
langchain_result = await loop.run_in_executor(
|
814 |
process_executor, langchain_csv_chart, csv_url, query, True
|
815 |
)
|
816 |
+
logger.info("groq langchain_answer -->", langchain_result)
|
817 |
if isinstance(langchain_result, list) and len(langchain_result) > 0:
|
818 |
return FileResponse(langchain_result[0], media_type="image/png")
|
819 |
+
|
820 |
+
logger.error("Groq Langchain chart generation failed for initial question, trying gemini rethink agent....")
|
821 |
+
|
822 |
+
# If not initial question then try gemini langchain method
|
823 |
+
gemini_answer = await asyncio.to_thread(gemini_llm_chart, csv_url, query)
|
824 |
+
logger.info("gemini_answer --> ", gemini_answer)
|
825 |
+
if(isinstance(gemini_answer, str) and gemini_answer != "Chart path not found" and ".png" in gemini_answer):
|
826 |
+
return FileResponse(gemini_answer, media_type="image/png")
|
827 |
|
828 |
+
logger.error("Gemini chart generation failed, trying groq....")
|
829 |
# Next, try the groq-based method
|
830 |
groq_result = await loop.run_in_executor(
|
831 |
process_executor, groq_chart, csv_url, query
|
|
|
835 |
return FileResponse(groq_result, media_type="image/png")
|
836 |
|
837 |
# Fallback: try langchain-based again
|
838 |
+
logger.error("Groq chart generation failed, trying langchain (Groq)....")
|
839 |
langchain_paths = await loop.run_in_executor(
|
840 |
process_executor, langchain_csv_chart, csv_url, query, True
|
841 |
)
|
842 |
+
logger.info("Groq langchain image paths -->", langchain_paths)
|
843 |
if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
|
844 |
return FileResponse(langchain_paths[0], media_type="image/png")
|
845 |
+
|
846 |
+
# Fallback: try gemini-based again
|
847 |
+
logger.error("Groq langchain chart generation failed, trying Gemini Langchain agent....")
|
848 |
+
gemini_paths = await asyncio.to_thread(langchain_gemini_csv_chart, csv_url, query, True)
|
849 |
+
if isinstance(gemini_paths, list) and len(gemini_paths) > 0:
|
850 |
+
return FileResponse(gemini_paths[0], media_type="image/png")
|
851 |
else:
|
852 |
logger.error("All chart generation methods failed")
|
853 |
return {"error": "All chart generation methods failed"}
|
rethink_gemini_agents/rethink_chart.py
CHANGED
@@ -232,10 +232,10 @@ def gemini_llm_chart(csv_url: str, query: str) -> str:
|
|
232 |
return chart_path
|
233 |
else:
|
234 |
print("Chart path not found")
|
235 |
-
return
|
236 |
else:
|
237 |
print("Unexpected result format:", type(result))
|
238 |
-
return
|
239 |
|
240 |
|
241 |
|
|
|
232 |
return chart_path
|
233 |
else:
|
234 |
print("Chart path not found")
|
235 |
+
return "Chart path not found"
|
236 |
else:
|
237 |
print("Unexpected result format:", type(result))
|
238 |
+
return "Chart path not found"
|
239 |
|
240 |
|
241 |
|
rethink_gemini_agents/rethink_chat.py
CHANGED
@@ -212,31 +212,36 @@ class RethinkAgent(BaseModel):
|
|
212 |
|
213 |
|
214 |
def gemini_llm_chat(csv_url: str, query: str) -> str:
|
|
|
|
|
215 |
# Assuming clean_data and RethinkAgent are defined elsewhere
|
216 |
-
|
217 |
-
|
218 |
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
|
224 |
-
|
225 |
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
processed = result.apply(handle_out_of_range_float).to_dict()
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
|
238 |
-
|
239 |
-
|
|
|
|
|
|
|
240 |
|
241 |
# uvicorn controller:app --host localhost --port 8000 --reload
|
242 |
|
|
|
212 |
|
213 |
|
214 |
def gemini_llm_chat(csv_url: str, query: str) -> str:
|
215 |
+
|
216 |
+
try:
|
217 |
# Assuming clean_data and RethinkAgent are defined elsewhere
|
218 |
+
df = clean_data(csv_url)
|
219 |
+
agent = RethinkAgent(df=df)
|
220 |
|
221 |
+
# Assuming API_KEYS is defined elsewhere
|
222 |
+
if not agent.initialize_model(API_KEYS):
|
223 |
+
print("Failed to initialize model with provided keys")
|
224 |
+
exit(1)
|
225 |
|
226 |
+
result = agent.execute_query(query)
|
227 |
|
228 |
+
# Process different response types
|
229 |
+
if isinstance(result, pd.DataFrame):
|
230 |
+
processed = result.apply(handle_out_of_range_float).to_dict(orient="records")
|
231 |
+
elif isinstance(result, pd.Series):
|
232 |
processed = result.apply(handle_out_of_range_float).to_dict()
|
233 |
+
elif isinstance(result, list):
|
234 |
+
processed = [handle_out_of_range_float(item) for item in result]
|
235 |
+
elif isinstance(result, dict):
|
236 |
+
processed = {k: handle_out_of_range_float(v) for k, v in result.items()}
|
237 |
+
else:
|
238 |
+
processed = {"answer": str(handle_out_of_range_float(result))}
|
239 |
|
240 |
+
logger.info(f"gemini processed result: {processed}")
|
241 |
+
return processed
|
242 |
+
except Exception as e:
|
243 |
+
logger.error(f"Error in gemini_llm_chat: {str(e)}")
|
244 |
+
return None
|
245 |
|
246 |
# uvicorn controller:app --host localhost --port 8000 --reload
|
247 |
|
util_service.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from langchain_core.prompts import ChatPromptTemplate
|
2 |
import numpy as np
|
3 |
|
4 |
-
keywords = ["unfortunately", "unsupported", "error", "sorry", "response", "unable", "because"]
|
5 |
|
6 |
def contains_keywords(text, keywords):
|
7 |
return any(keyword.lower() in text.lower() for keyword in keywords)
|
|
|
1 |
from langchain_core.prompts import ChatPromptTemplate
|
2 |
import numpy as np
|
3 |
|
4 |
+
keywords = ["unfortunately", "unsupported", "error", "sorry", "response", "unable", "because", "too many"]
|
5 |
|
6 |
def contains_keywords(text, keywords):
|
7 |
return any(keyword.lower() in text.lower() for keyword in keywords)
|