stay hard
Browse files- controller.py +5 -116
controller.py
CHANGED
@@ -26,9 +26,6 @@ import matplotlib.pyplot as plt
|
|
26 |
import matplotlib
|
27 |
import seaborn as sns
|
28 |
from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
|
29 |
-
from rethink_gemini_agents.gemini_langchain_service import langchain_gemini_csv_chart, langchain_gemini_csv_chat
|
30 |
-
from rethink_gemini_agents.rethink_chart import gemini_llm_chart
|
31 |
-
from rethink_gemini_agents.rethink_chat import gemini_llm_chat
|
32 |
from util_service import _prompt_generator, process_answer
|
33 |
from fastapi.middleware.cors import CORSMiddleware
|
34 |
import matplotlib
|
@@ -297,55 +294,12 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
|
|
297 |
decoded_url = unquote(csv_url)
|
298 |
|
299 |
if if_initial_chat_question(query):
|
300 |
-
|
301 |
-
|
302 |
-
# --- Gemini-Langchain-based and Rethink-based chat generation starts---
|
303 |
-
|
304 |
-
langchain_gemini_answer = await asyncio.to_thread(
|
305 |
-
langchain_gemini_csv_chat, decoded_url, query, True
|
306 |
-
)
|
307 |
-
logger.info("Gemini langchain_answer:", langchain_gemini_answer)
|
308 |
-
if langchain_gemini_answer != None and process_answer(langchain_gemini_answer) != True:
|
309 |
-
return {"answer": jsonable_encoder(langchain_gemini_answer)}
|
310 |
-
|
311 |
-
# --- Gemini-Langchain-based and Rethink-based chat generation ends---
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
answer = await asyncio.to_thread(
|
316 |
langchain_csv_chat, decoded_url, query, False
|
317 |
)
|
318 |
-
logger.info("
|
319 |
-
if answer != None and process_answer(answer) != True:
|
320 |
-
return {"answer": jsonable_encoder(answer)}
|
321 |
-
|
322 |
-
|
323 |
-
# --- Gemini-Langchain-based and Rethink-based chat generation starts---
|
324 |
-
|
325 |
-
|
326 |
-
answer = await asyncio.to_thread(
|
327 |
-
gemini_llm_chat, decoded_url, query
|
328 |
-
)
|
329 |
-
logger.info("Rethink gemini 1st:", answer)
|
330 |
return {"answer": jsonable_encoder(answer)}
|
331 |
-
|
332 |
-
|
333 |
-
# --- Gemini-Langchain-based and Rethink-based chat generation ends---
|
334 |
-
|
335 |
-
|
336 |
-
# --- Gemini-Langchain-based and Rethink-based chat generation starts---
|
337 |
-
|
338 |
-
|
339 |
-
answer = await asyncio.to_thread(
|
340 |
-
gemini_llm_chat, decoded_url, query
|
341 |
-
)
|
342 |
-
logger.info("Rethink gemini 1st:", answer)
|
343 |
-
if(answer != None and process_answer(answer) != True):
|
344 |
-
return {"answer": jsonable_encoder(answer)}
|
345 |
-
|
346 |
-
|
347 |
-
# --- Gemini-Langchain-based and Rethink-based chat generation ends---
|
348 |
-
|
349 |
# Process with groq_chat first
|
350 |
groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
|
351 |
logger.info("groq_answer:", groq_answer)
|
@@ -358,20 +312,7 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
|
|
358 |
langchain_csv_chat, decoded_url, query, False
|
359 |
)
|
360 |
if process_answer(lang_answer):
|
361 |
-
|
362 |
-
|
363 |
-
# --- Gemini-Langchain-based and Rethink-based chat generation starts---
|
364 |
-
|
365 |
-
langchain_gemini_answer = await asyncio.to_thread(
|
366 |
-
langchain_gemini_csv_chat, decoded_url, query, True
|
367 |
-
)
|
368 |
-
logger.info("Gemini langchain_answer:", langchain_gemini_answer)
|
369 |
-
if langchain_gemini_answer != None and process_answer(langchain_gemini_answer) != True:
|
370 |
-
return {"answer": jsonable_encoder(langchain_gemini_answer)}
|
371 |
-
|
372 |
-
# --- Gemini-Langchain-based and Rethink-based chat generation ends---
|
373 |
-
|
374 |
-
|
375 |
return {"answer": jsonable_encoder(lang_answer)}
|
376 |
|
377 |
return {"answer": jsonable_encoder(groq_answer)}
|
@@ -843,51 +784,12 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
843 |
loop = asyncio.get_running_loop()
|
844 |
# First, try the langchain-based method if the question qualifies
|
845 |
if if_initial_chart_question(query):
|
846 |
-
|
847 |
-
|
848 |
-
# --- Gemini-Langchain-based and Rethink-based chart generation starts here ---
|
849 |
-
|
850 |
-
|
851 |
-
langchain_gemini_chart_answer = await asyncio.to_thread(
|
852 |
-
langchain_gemini_csv_chart, csv_url, query, True
|
853 |
-
)
|
854 |
-
logger.info("gemini langchain_answer --> ", langchain_gemini_chart_answer)
|
855 |
-
if isinstance(langchain_gemini_chart_answer, list) and len(langchain_gemini_chart_answer) > 0:
|
856 |
-
return FileResponse(langchain_gemini_chart_answer[0], media_type="image/png")
|
857 |
-
|
858 |
-
logger.error("Gemini Langchain chart generation failed for initial question, trying groq langchain....")
|
859 |
-
|
860 |
-
|
861 |
-
|
862 |
-
|
863 |
-
# --- Gemini-Langchain-based and Rethink-based chart generation ends here ---
|
864 |
-
|
865 |
-
|
866 |
langchain_result = await loop.run_in_executor(
|
867 |
process_executor, langchain_csv_chart, csv_url, query, True
|
868 |
)
|
869 |
logger.info("Langchain chart result:", langchain_result)
|
870 |
if isinstance(langchain_result, list) and len(langchain_result) > 0:
|
871 |
return FileResponse(langchain_result[0], media_type="image/png")
|
872 |
-
|
873 |
-
|
874 |
-
|
875 |
-
|
876 |
-
# --- Gemini-Langchain-based and Rethink-based chart generation starts here ---
|
877 |
-
|
878 |
-
|
879 |
-
logger.error("Groq Langchain chart generation failed, trying rethink gemini....")
|
880 |
-
gemini_answer = await asyncio.to_thread(gemini_llm_chart, csv_url, query)
|
881 |
-
logger.info("gemini_answer --> ", gemini_answer)
|
882 |
-
if(isinstance(gemini_answer, str) and gemini_answer != "Chart path not found" and ".png" in gemini_answer):
|
883 |
-
return FileResponse(gemini_answer, media_type="image/png")
|
884 |
-
|
885 |
-
logger.error("Gemini chart generation failed, trying groq....")
|
886 |
-
|
887 |
-
|
888 |
-
# --- Gemini-Langchain-based and Rethink-based chart generation ends here ---
|
889 |
-
|
890 |
-
|
891 |
|
892 |
# Next, try the groq-based method
|
893 |
groq_result = await loop.run_in_executor(
|
@@ -906,21 +808,8 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
906 |
if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
|
907 |
return FileResponse(langchain_paths[0], media_type="image/png")
|
908 |
else:
|
909 |
-
|
910 |
-
|
911 |
-
|
912 |
-
|
913 |
-
langchain_gemini_chart_answer = await asyncio.to_thread(
|
914 |
-
langchain_gemini_csv_chart, csv_url, query, True
|
915 |
-
)
|
916 |
-
logger.info("gemini langchain_answer --> ", langchain_gemini_chart_answer)
|
917 |
-
if isinstance(langchain_gemini_chart_answer, list) and len(langchain_gemini_chart_answer) > 0:
|
918 |
-
return FileResponse(langchain_gemini_chart_answer[0], media_type="image/png")
|
919 |
-
|
920 |
-
# --- Gemini-Langchain-based and Rethink-based chart generation ends here ---
|
921 |
-
|
922 |
-
logger.error("All chart generation methods failed")
|
923 |
-
return {"error": "All chart generation methods failed"}
|
924 |
|
925 |
except Exception as e:
|
926 |
logger.error(f"Critical chart error: {str(e)}")
|
|
|
26 |
import matplotlib
|
27 |
import seaborn as sns
|
28 |
from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
|
|
|
|
|
|
|
29 |
from util_service import _prompt_generator, process_answer
|
30 |
from fastapi.middleware.cors import CORSMiddleware
|
31 |
import matplotlib
|
|
|
294 |
decoded_url = unquote(csv_url)
|
295 |
|
296 |
if if_initial_chat_question(query):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
answer = await asyncio.to_thread(
|
298 |
langchain_csv_chat, decoded_url, query, False
|
299 |
)
|
300 |
+
logger.info("langchain_answer:", answer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
301 |
return {"answer": jsonable_encoder(answer)}
|
302 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
# Process with groq_chat first
|
304 |
groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
|
305 |
logger.info("groq_answer:", groq_answer)
|
|
|
312 |
langchain_csv_chat, decoded_url, query, False
|
313 |
)
|
314 |
if process_answer(lang_answer):
|
315 |
+
return {"answer": "error"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
return {"answer": jsonable_encoder(lang_answer)}
|
317 |
|
318 |
return {"answer": jsonable_encoder(groq_answer)}
|
|
|
784 |
loop = asyncio.get_running_loop()
|
785 |
# First, try the langchain-based method if the question qualifies
|
786 |
if if_initial_chart_question(query):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
787 |
langchain_result = await loop.run_in_executor(
|
788 |
process_executor, langchain_csv_chart, csv_url, query, True
|
789 |
)
|
790 |
logger.info("Langchain chart result:", langchain_result)
|
791 |
if isinstance(langchain_result, list) and len(langchain_result) > 0:
|
792 |
return FileResponse(langchain_result[0], media_type="image/png")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
793 |
|
794 |
# Next, try the groq-based method
|
795 |
groq_result = await loop.run_in_executor(
|
|
|
808 |
if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
|
809 |
return FileResponse(langchain_paths[0], media_type="image/png")
|
810 |
else:
|
811 |
+
logger.error("All chart generation methods failed")
|
812 |
+
return {"error": "All chart generation methods failed"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
813 |
|
814 |
except Exception as e:
|
815 |
logger.error(f"Critical chart error: {str(e)}")
|