Soumik555 commited on
Commit
66fea52
·
1 Parent(s): a8cab65
Files changed (1) hide show
  1. controller.py +106 -36
controller.py CHANGED
@@ -50,7 +50,6 @@ os.makedirs("/app/cache", exist_ok=True)
50
 
51
  os.makedirs("/app", exist_ok=True)
52
  open("/app/pandasai.log", "a").close() # Create the file if it doesn't exist
53
- open("/app/api_key_rotation.log", "a").close() # Create the file if it doesn't exist
54
 
55
  # Ensure the generated_charts directory exists
56
  os.makedirs("/app/generated_charts", exist_ok=True)
@@ -296,29 +295,60 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
296
  query = request.get("query")
297
  csv_url = request.get("csv_url")
298
  decoded_url = unquote(csv_url)
299
-
300
  if if_initial_chat_question(query):
301
- answer = await asyncio.to_thread(
302
- langchain_gemini_csv_chat, decoded_url, query, False
 
 
 
 
303
  )
304
- logger.info("gemini langchain_answer --> ", answer)
305
- if(answer and process_answer(answer)):
306
-
307
- logger.error("Gemini chat initial query failed, trying groq based langchain_csv_chat...")
308
-
309
- answer = await asyncio.to_thread(
 
 
 
310
  langchain_csv_chat, decoded_url, query, False
311
- )
312
- logger.info("groq langchain_answer --> ", answer)
 
 
 
 
 
 
 
 
 
 
 
313
  return {"answer": jsonable_encoder(answer)}
314
-
315
-
316
 
317
-
318
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  # Process with groq_chat first
320
  groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
321
- logger.info("groq_answer --> ", groq_answer)
322
 
323
  if process_answer(groq_answer) == "Empty response received.":
324
  return {"answer": "Sorry, I couldn't find relevant data..."}
@@ -328,7 +358,20 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
328
  langchain_csv_chat, decoded_url, query, False
329
  )
330
  if process_answer(lang_answer):
331
- return {"answer": "error"}
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  return {"answer": jsonable_encoder(lang_answer)}
333
 
334
  return {"answer": jsonable_encoder(groq_answer)}
@@ -625,7 +668,7 @@ current_langchain_chart_lock = threading.Lock()
625
 
626
 
627
  # Use a process pool to run CPU-bound chart generation
628
- process_executor = ProcessPoolExecutor(max_workers=(os.cpu_count()-2))
629
 
630
  # --- GROQ-BASED CHART GENERATION ---
631
  # def groq_chart(csv_url: str, question: str):
@@ -799,33 +842,53 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
799
 
800
  loop = asyncio.get_running_loop()
801
  # First, try the langchain-based method if the question qualifies
802
-
803
-
804
  if if_initial_chart_question(query):
 
 
 
 
 
805
  langchain_gemini_chart_answer = await asyncio.to_thread(
806
- langchain_gemini_csv_chat, csv_url, query, False
807
  )
808
  logger.info("gemini langchain_answer --> ", langchain_gemini_chart_answer)
809
  if isinstance(langchain_gemini_chart_answer, list) and len(langchain_gemini_chart_answer) > 0:
810
  return FileResponse(langchain_gemini_chart_answer[0], media_type="image/png")
811
 
812
  logger.error("Gemini Langchain chart generation failed for initial question, trying groq langchain....")
 
 
 
 
 
 
 
813
  langchain_result = await loop.run_in_executor(
814
  process_executor, langchain_csv_chart, csv_url, query, True
815
  )
816
- logger.info("groq langchain_answer -->", langchain_result)
817
  if isinstance(langchain_result, list) and len(langchain_result) > 0:
818
  return FileResponse(langchain_result[0], media_type="image/png")
819
 
820
- logger.error("Groq Langchain chart generation failed for initial question, trying gemini rethink agent....")
821
-
822
- # If not initial question then try gemini langchain method
 
 
 
 
823
  gemini_answer = await asyncio.to_thread(gemini_llm_chart, csv_url, query)
824
  logger.info("gemini_answer --> ", gemini_answer)
825
  if(isinstance(gemini_answer, str) and gemini_answer != "Chart path not found" and ".png" in gemini_answer):
826
  return FileResponse(gemini_answer, media_type="image/png")
827
 
828
  logger.error("Gemini chart generation failed, trying groq....")
 
 
 
 
 
 
829
  # Next, try the groq-based method
830
  groq_result = await loop.run_in_executor(
831
  process_executor, groq_chart, csv_url, query
@@ -835,22 +898,29 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
835
  return FileResponse(groq_result, media_type="image/png")
836
 
837
  # Fallback: try langchain-based again
838
- logger.error("Groq chart generation failed, trying langchain (Groq)....")
839
  langchain_paths = await loop.run_in_executor(
840
  process_executor, langchain_csv_chart, csv_url, query, True
841
  )
842
- logger.info("Groq langchain image paths -->", langchain_paths)
843
  if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
844
  return FileResponse(langchain_paths[0], media_type="image/png")
845
-
846
- # Fallback: try gemini-based again
847
- logger.error("Groq langchain chart generation failed, trying Gemini Langchain agent....")
848
- gemini_paths = await asyncio.to_thread(langchain_gemini_csv_chart, csv_url, query, True)
849
- if isinstance(gemini_paths, list) and len(gemini_paths) > 0:
850
- return FileResponse(gemini_paths[0], media_type="image/png")
851
  else:
852
- logger.error("All chart generation methods failed")
853
- return {"error": "All chart generation methods failed"}
 
 
 
 
 
 
 
 
 
 
 
 
 
854
 
855
  except Exception as e:
856
  logger.error(f"Critical chart error: {str(e)}")
 
50
 
51
  os.makedirs("/app", exist_ok=True)
52
  open("/app/pandasai.log", "a").close() # Create the file if it doesn't exist
 
53
 
54
  # Ensure the generated_charts directory exists
55
  os.makedirs("/app/generated_charts", exist_ok=True)
 
295
  query = request.get("query")
296
  csv_url = request.get("csv_url")
297
  decoded_url = unquote(csv_url)
298
+
299
  if if_initial_chat_question(query):
300
+
301
+
302
+ # --- Gemini-Langchain-based and Rethink-based chat generation starts---
303
+
304
+ langchain_gemini_answer = await asyncio.to_thread(
305
+ langchain_gemini_csv_chat, decoded_url, query, True
306
  )
307
+ logger.info("Gemini langchain_answer:", langchain_gemini_answer)
308
+ if langchain_gemini_answer != None and process_answer(langchain_gemini_answer) != True:
309
+ return {"answer": jsonable_encoder(langchain_gemini_answer)}
310
+
311
+ # --- Gemini-Langchain-based and Rethink-based chat generation ends---
312
+
313
+
314
+
315
+ answer = await asyncio.to_thread(
316
  langchain_csv_chat, decoded_url, query, False
317
+ )
318
+ logger.info("Groq langchain_answer:", answer)
319
+ if answer != None and process_answer(answer) != True:
320
+ return {"answer": jsonable_encoder(answer)}
321
+
322
+
323
+ # --- Gemini-Langchain-based and Rethink-based chat generation starts---
324
+
325
+
326
+ answer = await asyncio.to_thread(
327
+ gemini_llm_chat, decoded_url, query
328
+ )
329
+ logger.info("Rethink gemini 1st:", answer)
330
  return {"answer": jsonable_encoder(answer)}
 
 
331
 
332
+
333
+ # --- Gemini-Langchain-based and Rethink-based chat generation ends---
334
+
335
+
336
+ # --- Gemini-Langchain-based and Rethink-based chat generation starts---
337
+
338
+
339
+ answer = await asyncio.to_thread(
340
+ gemini_llm_chat, decoded_url, query
341
+ )
342
+ logger.info("Rethink gemini 1st:", answer)
343
+ if(answer != None and process_answer(answer) != True):
344
+ return {"answer": jsonable_encoder(answer)}
345
+
346
+
347
+ # --- Gemini-Langchain-based and Rethink-based chat generation ends---
348
+
349
  # Process with groq_chat first
350
  groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
351
+ logger.info("groq_answer:", groq_answer)
352
 
353
  if process_answer(groq_answer) == "Empty response received.":
354
  return {"answer": "Sorry, I couldn't find relevant data..."}
 
358
  langchain_csv_chat, decoded_url, query, False
359
  )
360
  if process_answer(lang_answer):
361
+ logger.error("Error in Groq_Langchain, Trying Gemini...")
362
+
363
+ # --- Gemini-Langchain-based and Rethink-based chat generation starts---
364
+
365
+ langchain_gemini_answer = await asyncio.to_thread(
366
+ langchain_gemini_csv_chat, decoded_url, query, True
367
+ )
368
+ logger.info("Gemini langchain_answer:", langchain_gemini_answer)
369
+ if langchain_gemini_answer != None and process_answer(langchain_gemini_answer) != True:
370
+ return {"answer": jsonable_encoder(langchain_gemini_answer)}
371
+
372
+ # --- Gemini-Langchain-based and Rethink-based chat generation ends---
373
+
374
+
375
  return {"answer": jsonable_encoder(lang_answer)}
376
 
377
  return {"answer": jsonable_encoder(groq_answer)}
 
668
 
669
 
670
  # Use a process pool to run CPU-bound chart generation
671
+ process_executor = ProcessPoolExecutor(max_workers=max_cpus-2)
672
 
673
  # --- GROQ-BASED CHART GENERATION ---
674
  # def groq_chart(csv_url: str, question: str):
 
842
 
843
  loop = asyncio.get_running_loop()
844
  # First, try the langchain-based method if the question qualifies
 
 
845
  if if_initial_chart_question(query):
846
+
847
+
848
+ # --- Gemini-Langchain-based and Rethink-based chart generation starts here ---
849
+
850
+
851
  langchain_gemini_chart_answer = await asyncio.to_thread(
852
+ langchain_gemini_csv_chart, csv_url, query, True
853
  )
854
  logger.info("gemini langchain_answer --> ", langchain_gemini_chart_answer)
855
  if isinstance(langchain_gemini_chart_answer, list) and len(langchain_gemini_chart_answer) > 0:
856
  return FileResponse(langchain_gemini_chart_answer[0], media_type="image/png")
857
 
858
  logger.error("Gemini Langchain chart generation failed for initial question, trying groq langchain....")
859
+
860
+
861
+
862
+
863
+ # --- Gemini-Langchain-based and Rethink-based chart generation ends here ---
864
+
865
+
866
  langchain_result = await loop.run_in_executor(
867
  process_executor, langchain_csv_chart, csv_url, query, True
868
  )
869
+ logger.info("Langchain chart result:", langchain_result)
870
  if isinstance(langchain_result, list) and len(langchain_result) > 0:
871
  return FileResponse(langchain_result[0], media_type="image/png")
872
 
873
+
874
+
875
+
876
+ # --- Gemini-Langchain-based and Rethink-based chart generation starts here ---
877
+
878
+
879
+ logger.error("Groq Langchain chart generation failed, trying rethink gemini....")
880
  gemini_answer = await asyncio.to_thread(gemini_llm_chart, csv_url, query)
881
  logger.info("gemini_answer --> ", gemini_answer)
882
  if(isinstance(gemini_answer, str) and gemini_answer != "Chart path not found" and ".png" in gemini_answer):
883
  return FileResponse(gemini_answer, media_type="image/png")
884
 
885
  logger.error("Gemini chart generation failed, trying groq....")
886
+
887
+
888
+ # --- Gemini-Langchain-based and Rethink-based chart generation ends here ---
889
+
890
+
891
+
892
  # Next, try the groq-based method
893
  groq_result = await loop.run_in_executor(
894
  process_executor, groq_chart, csv_url, query
 
898
  return FileResponse(groq_result, media_type="image/png")
899
 
900
  # Fallback: try langchain-based again
901
+ logger.error("Groq chart generation failed, trying langchain....")
902
  langchain_paths = await loop.run_in_executor(
903
  process_executor, langchain_csv_chart, csv_url, query, True
904
  )
905
+ logger.info("Fallback langchain chart result:", langchain_paths)
906
  if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
907
  return FileResponse(langchain_paths[0], media_type="image/png")
 
 
 
 
 
 
908
  else:
909
+
910
+ # --- Gemini-Langchain-based and Rethink-based chart generation starts here ---
911
+
912
+
913
+ langchain_gemini_chart_answer = await asyncio.to_thread(
914
+ langchain_gemini_csv_chart, csv_url, query, True
915
+ )
916
+ logger.info("gemini langchain_answer --> ", langchain_gemini_chart_answer)
917
+ if isinstance(langchain_gemini_chart_answer, list) and len(langchain_gemini_chart_answer) > 0:
918
+ return FileResponse(langchain_gemini_chart_answer[0], media_type="image/png")
919
+
920
+ # --- Gemini-Langchain-based and Rethink-based chart generation ends here ---
921
+
922
+ logger.error("All chart generation methods failed")
923
+ return {"error": "All chart generation methods failed"}
924
 
925
  except Exception as e:
926
  logger.error(f"Critical chart error: {str(e)}")