stay hard
Browse files- controller.py +106 -36
controller.py
CHANGED
@@ -50,7 +50,6 @@ os.makedirs("/app/cache", exist_ok=True)
|
|
50 |
|
51 |
os.makedirs("/app", exist_ok=True)
|
52 |
open("/app/pandasai.log", "a").close() # Create the file if it doesn't exist
|
53 |
-
open("/app/api_key_rotation.log", "a").close() # Create the file if it doesn't exist
|
54 |
|
55 |
# Ensure the generated_charts directory exists
|
56 |
os.makedirs("/app/generated_charts", exist_ok=True)
|
@@ -296,29 +295,60 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
|
|
296 |
query = request.get("query")
|
297 |
csv_url = request.get("csv_url")
|
298 |
decoded_url = unquote(csv_url)
|
299 |
-
|
300 |
if if_initial_chat_question(query):
|
301 |
-
|
302 |
-
|
|
|
|
|
|
|
|
|
303 |
)
|
304 |
-
logger.info("
|
305 |
-
if
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
|
|
|
|
|
|
310 |
langchain_csv_chat, decoded_url, query, False
|
311 |
-
|
312 |
-
logger.info("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
313 |
return {"answer": jsonable_encoder(answer)}
|
314 |
-
|
315 |
-
|
316 |
|
317 |
-
|
318 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
319 |
# Process with groq_chat first
|
320 |
groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
|
321 |
-
logger.info("groq_answer
|
322 |
|
323 |
if process_answer(groq_answer) == "Empty response received.":
|
324 |
return {"answer": "Sorry, I couldn't find relevant data..."}
|
@@ -328,7 +358,20 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
|
|
328 |
langchain_csv_chat, decoded_url, query, False
|
329 |
)
|
330 |
if process_answer(lang_answer):
|
331 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
332 |
return {"answer": jsonable_encoder(lang_answer)}
|
333 |
|
334 |
return {"answer": jsonable_encoder(groq_answer)}
|
@@ -625,7 +668,7 @@ current_langchain_chart_lock = threading.Lock()
|
|
625 |
|
626 |
|
627 |
# Use a process pool to run CPU-bound chart generation
|
628 |
-
process_executor = ProcessPoolExecutor(max_workers=
|
629 |
|
630 |
# --- GROQ-BASED CHART GENERATION ---
|
631 |
# def groq_chart(csv_url: str, question: str):
|
@@ -799,33 +842,53 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
799 |
|
800 |
loop = asyncio.get_running_loop()
|
801 |
# First, try the langchain-based method if the question qualifies
|
802 |
-
|
803 |
-
|
804 |
if if_initial_chart_question(query):
|
|
|
|
|
|
|
|
|
|
|
805 |
langchain_gemini_chart_answer = await asyncio.to_thread(
|
806 |
-
|
807 |
)
|
808 |
logger.info("gemini langchain_answer --> ", langchain_gemini_chart_answer)
|
809 |
if isinstance(langchain_gemini_chart_answer, list) and len(langchain_gemini_chart_answer) > 0:
|
810 |
return FileResponse(langchain_gemini_chart_answer[0], media_type="image/png")
|
811 |
|
812 |
logger.error("Gemini Langchain chart generation failed for initial question, trying groq langchain....")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
813 |
langchain_result = await loop.run_in_executor(
|
814 |
process_executor, langchain_csv_chart, csv_url, query, True
|
815 |
)
|
816 |
-
logger.info("
|
817 |
if isinstance(langchain_result, list) and len(langchain_result) > 0:
|
818 |
return FileResponse(langchain_result[0], media_type="image/png")
|
819 |
|
820 |
-
|
821 |
-
|
822 |
-
|
|
|
|
|
|
|
|
|
823 |
gemini_answer = await asyncio.to_thread(gemini_llm_chart, csv_url, query)
|
824 |
logger.info("gemini_answer --> ", gemini_answer)
|
825 |
if(isinstance(gemini_answer, str) and gemini_answer != "Chart path not found" and ".png" in gemini_answer):
|
826 |
return FileResponse(gemini_answer, media_type="image/png")
|
827 |
|
828 |
logger.error("Gemini chart generation failed, trying groq....")
|
|
|
|
|
|
|
|
|
|
|
|
|
829 |
# Next, try the groq-based method
|
830 |
groq_result = await loop.run_in_executor(
|
831 |
process_executor, groq_chart, csv_url, query
|
@@ -835,22 +898,29 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
835 |
return FileResponse(groq_result, media_type="image/png")
|
836 |
|
837 |
# Fallback: try langchain-based again
|
838 |
-
logger.error("Groq chart generation failed, trying langchain
|
839 |
langchain_paths = await loop.run_in_executor(
|
840 |
process_executor, langchain_csv_chart, csv_url, query, True
|
841 |
)
|
842 |
-
logger.info("
|
843 |
if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
|
844 |
return FileResponse(langchain_paths[0], media_type="image/png")
|
845 |
-
|
846 |
-
# Fallback: try gemini-based again
|
847 |
-
logger.error("Groq langchain chart generation failed, trying Gemini Langchain agent....")
|
848 |
-
gemini_paths = await asyncio.to_thread(langchain_gemini_csv_chart, csv_url, query, True)
|
849 |
-
if isinstance(gemini_paths, list) and len(gemini_paths) > 0:
|
850 |
-
return FileResponse(gemini_paths[0], media_type="image/png")
|
851 |
else:
|
852 |
-
|
853 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
854 |
|
855 |
except Exception as e:
|
856 |
logger.error(f"Critical chart error: {str(e)}")
|
|
|
50 |
|
51 |
os.makedirs("/app", exist_ok=True)
|
52 |
open("/app/pandasai.log", "a").close() # Create the file if it doesn't exist
|
|
|
53 |
|
54 |
# Ensure the generated_charts directory exists
|
55 |
os.makedirs("/app/generated_charts", exist_ok=True)
|
|
|
295 |
query = request.get("query")
|
296 |
csv_url = request.get("csv_url")
|
297 |
decoded_url = unquote(csv_url)
|
298 |
+
|
299 |
if if_initial_chat_question(query):
|
300 |
+
|
301 |
+
|
302 |
+
# --- Gemini-Langchain-based and Rethink-based chat generation starts---
|
303 |
+
|
304 |
+
langchain_gemini_answer = await asyncio.to_thread(
|
305 |
+
langchain_gemini_csv_chat, decoded_url, query, True
|
306 |
)
|
307 |
+
logger.info("Gemini langchain_answer:", langchain_gemini_answer)
|
308 |
+
if langchain_gemini_answer != None and process_answer(langchain_gemini_answer) != True:
|
309 |
+
return {"answer": jsonable_encoder(langchain_gemini_answer)}
|
310 |
+
|
311 |
+
# --- Gemini-Langchain-based and Rethink-based chat generation ends---
|
312 |
+
|
313 |
+
|
314 |
+
|
315 |
+
answer = await asyncio.to_thread(
|
316 |
langchain_csv_chat, decoded_url, query, False
|
317 |
+
)
|
318 |
+
logger.info("Groq langchain_answer:", answer)
|
319 |
+
if answer != None and process_answer(answer) != True:
|
320 |
+
return {"answer": jsonable_encoder(answer)}
|
321 |
+
|
322 |
+
|
323 |
+
# --- Gemini-Langchain-based and Rethink-based chat generation starts---
|
324 |
+
|
325 |
+
|
326 |
+
answer = await asyncio.to_thread(
|
327 |
+
gemini_llm_chat, decoded_url, query
|
328 |
+
)
|
329 |
+
logger.info("Rethink gemini 1st:", answer)
|
330 |
return {"answer": jsonable_encoder(answer)}
|
|
|
|
|
331 |
|
332 |
+
|
333 |
+
# --- Gemini-Langchain-based and Rethink-based chat generation ends---
|
334 |
+
|
335 |
+
|
336 |
+
# --- Gemini-Langchain-based and Rethink-based chat generation starts---
|
337 |
+
|
338 |
+
|
339 |
+
answer = await asyncio.to_thread(
|
340 |
+
gemini_llm_chat, decoded_url, query
|
341 |
+
)
|
342 |
+
logger.info("Rethink gemini 1st:", answer)
|
343 |
+
if(answer != None and process_answer(answer) != True):
|
344 |
+
return {"answer": jsonable_encoder(answer)}
|
345 |
+
|
346 |
+
|
347 |
+
# --- Gemini-Langchain-based and Rethink-based chat generation ends---
|
348 |
+
|
349 |
# Process with groq_chat first
|
350 |
groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
|
351 |
+
logger.info("groq_answer:", groq_answer)
|
352 |
|
353 |
if process_answer(groq_answer) == "Empty response received.":
|
354 |
return {"answer": "Sorry, I couldn't find relevant data..."}
|
|
|
358 |
langchain_csv_chat, decoded_url, query, False
|
359 |
)
|
360 |
if process_answer(lang_answer):
|
361 |
+
logger.error("Error in Groq_Langchain, Trying Gemini...")
|
362 |
+
|
363 |
+
# --- Gemini-Langchain-based and Rethink-based chat generation starts---
|
364 |
+
|
365 |
+
langchain_gemini_answer = await asyncio.to_thread(
|
366 |
+
langchain_gemini_csv_chat, decoded_url, query, True
|
367 |
+
)
|
368 |
+
logger.info("Gemini langchain_answer:", langchain_gemini_answer)
|
369 |
+
if langchain_gemini_answer != None and process_answer(langchain_gemini_answer) != True:
|
370 |
+
return {"answer": jsonable_encoder(langchain_gemini_answer)}
|
371 |
+
|
372 |
+
# --- Gemini-Langchain-based and Rethink-based chat generation ends---
|
373 |
+
|
374 |
+
|
375 |
return {"answer": jsonable_encoder(lang_answer)}
|
376 |
|
377 |
return {"answer": jsonable_encoder(groq_answer)}
|
|
|
668 |
|
669 |
|
670 |
# Use a process pool to run CPU-bound chart generation
|
671 |
+
process_executor = ProcessPoolExecutor(max_workers=max_cpus-2)
|
672 |
|
673 |
# --- GROQ-BASED CHART GENERATION ---
|
674 |
# def groq_chart(csv_url: str, question: str):
|
|
|
842 |
|
843 |
loop = asyncio.get_running_loop()
|
844 |
# First, try the langchain-based method if the question qualifies
|
|
|
|
|
845 |
if if_initial_chart_question(query):
|
846 |
+
|
847 |
+
|
848 |
+
# --- Gemini-Langchain-based and Rethink-based chart generation starts here ---
|
849 |
+
|
850 |
+
|
851 |
langchain_gemini_chart_answer = await asyncio.to_thread(
|
852 |
+
langchain_gemini_csv_chart, csv_url, query, True
|
853 |
)
|
854 |
logger.info("gemini langchain_answer --> ", langchain_gemini_chart_answer)
|
855 |
if isinstance(langchain_gemini_chart_answer, list) and len(langchain_gemini_chart_answer) > 0:
|
856 |
return FileResponse(langchain_gemini_chart_answer[0], media_type="image/png")
|
857 |
|
858 |
logger.error("Gemini Langchain chart generation failed for initial question, trying groq langchain....")
|
859 |
+
|
860 |
+
|
861 |
+
|
862 |
+
|
863 |
+
# --- Gemini-Langchain-based and Rethink-based chart generation ends here ---
|
864 |
+
|
865 |
+
|
866 |
langchain_result = await loop.run_in_executor(
|
867 |
process_executor, langchain_csv_chart, csv_url, query, True
|
868 |
)
|
869 |
+
logger.info("Langchain chart result:", langchain_result)
|
870 |
if isinstance(langchain_result, list) and len(langchain_result) > 0:
|
871 |
return FileResponse(langchain_result[0], media_type="image/png")
|
872 |
|
873 |
+
|
874 |
+
|
875 |
+
|
876 |
+
# --- Gemini-Langchain-based and Rethink-based chart generation starts here ---
|
877 |
+
|
878 |
+
|
879 |
+
logger.error("Groq Langchain chart generation failed, trying rethink gemini....")
|
880 |
gemini_answer = await asyncio.to_thread(gemini_llm_chart, csv_url, query)
|
881 |
logger.info("gemini_answer --> ", gemini_answer)
|
882 |
if(isinstance(gemini_answer, str) and gemini_answer != "Chart path not found" and ".png" in gemini_answer):
|
883 |
return FileResponse(gemini_answer, media_type="image/png")
|
884 |
|
885 |
logger.error("Gemini chart generation failed, trying groq....")
|
886 |
+
|
887 |
+
|
888 |
+
# --- Gemini-Langchain-based and Rethink-based chart generation ends here ---
|
889 |
+
|
890 |
+
|
891 |
+
|
892 |
# Next, try the groq-based method
|
893 |
groq_result = await loop.run_in_executor(
|
894 |
process_executor, groq_chart, csv_url, query
|
|
|
898 |
return FileResponse(groq_result, media_type="image/png")
|
899 |
|
900 |
# Fallback: try langchain-based again
|
901 |
+
logger.error("Groq chart generation failed, trying langchain....")
|
902 |
langchain_paths = await loop.run_in_executor(
|
903 |
process_executor, langchain_csv_chart, csv_url, query, True
|
904 |
)
|
905 |
+
logger.info("Fallback langchain chart result:", langchain_paths)
|
906 |
if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
|
907 |
return FileResponse(langchain_paths[0], media_type="image/png")
|
|
|
|
|
|
|
|
|
|
|
|
|
908 |
else:
|
909 |
+
|
910 |
+
# --- Gemini-Langchain-based and Rethink-based chart generation starts here ---
|
911 |
+
|
912 |
+
|
913 |
+
langchain_gemini_chart_answer = await asyncio.to_thread(
|
914 |
+
langchain_gemini_csv_chart, csv_url, query, True
|
915 |
+
)
|
916 |
+
logger.info("gemini langchain_answer --> ", langchain_gemini_chart_answer)
|
917 |
+
if isinstance(langchain_gemini_chart_answer, list) and len(langchain_gemini_chart_answer) > 0:
|
918 |
+
return FileResponse(langchain_gemini_chart_answer[0], media_type="image/png")
|
919 |
+
|
920 |
+
# --- Gemini-Langchain-based and Rethink-based chart generation ends here ---
|
921 |
+
|
922 |
+
logger.error("All chart generation methods failed")
|
923 |
+
return {"error": "All chart generation methods failed"}
|
924 |
|
925 |
except Exception as e:
|
926 |
logger.error(f"Critical chart error: {str(e)}")
|