Soumik555 commited on
Commit
d67a459
·
1 Parent(s): 66fea52
Files changed (1) hide show
  1. controller.py +5 -116
controller.py CHANGED
@@ -26,9 +26,6 @@ import matplotlib.pyplot as plt
26
  import matplotlib
27
  import seaborn as sns
28
  from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
29
- from rethink_gemini_agents.gemini_langchain_service import langchain_gemini_csv_chart, langchain_gemini_csv_chat
30
- from rethink_gemini_agents.rethink_chart import gemini_llm_chart
31
- from rethink_gemini_agents.rethink_chat import gemini_llm_chat
32
  from util_service import _prompt_generator, process_answer
33
  from fastapi.middleware.cors import CORSMiddleware
34
  import matplotlib
@@ -297,55 +294,12 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
297
  decoded_url = unquote(csv_url)
298
 
299
  if if_initial_chat_question(query):
300
-
301
-
302
- # --- Gemini-Langchain-based and Rethink-based chat generation starts---
303
-
304
- langchain_gemini_answer = await asyncio.to_thread(
305
- langchain_gemini_csv_chat, decoded_url, query, True
306
- )
307
- logger.info("Gemini langchain_answer:", langchain_gemini_answer)
308
- if langchain_gemini_answer != None and process_answer(langchain_gemini_answer) != True:
309
- return {"answer": jsonable_encoder(langchain_gemini_answer)}
310
-
311
- # --- Gemini-Langchain-based and Rethink-based chat generation ends---
312
-
313
-
314
-
315
  answer = await asyncio.to_thread(
316
  langchain_csv_chat, decoded_url, query, False
317
  )
318
- logger.info("Groq langchain_answer:", answer)
319
- if answer != None and process_answer(answer) != True:
320
- return {"answer": jsonable_encoder(answer)}
321
-
322
-
323
- # --- Gemini-Langchain-based and Rethink-based chat generation starts---
324
-
325
-
326
- answer = await asyncio.to_thread(
327
- gemini_llm_chat, decoded_url, query
328
- )
329
- logger.info("Rethink gemini 1st:", answer)
330
  return {"answer": jsonable_encoder(answer)}
331
-
332
-
333
- # --- Gemini-Langchain-based and Rethink-based chat generation ends---
334
-
335
-
336
- # --- Gemini-Langchain-based and Rethink-based chat generation starts---
337
-
338
-
339
- answer = await asyncio.to_thread(
340
- gemini_llm_chat, decoded_url, query
341
- )
342
- logger.info("Rethink gemini 1st:", answer)
343
- if(answer != None and process_answer(answer) != True):
344
- return {"answer": jsonable_encoder(answer)}
345
-
346
-
347
- # --- Gemini-Langchain-based and Rethink-based chat generation ends---
348
-
349
  # Process with groq_chat first
350
  groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
351
  logger.info("groq_answer:", groq_answer)
@@ -358,20 +312,7 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
358
  langchain_csv_chat, decoded_url, query, False
359
  )
360
  if process_answer(lang_answer):
361
- logger.error("Error in Groq_Langchain, Trying Gemini...")
362
-
363
- # --- Gemini-Langchain-based and Rethink-based chat generation starts---
364
-
365
- langchain_gemini_answer = await asyncio.to_thread(
366
- langchain_gemini_csv_chat, decoded_url, query, True
367
- )
368
- logger.info("Gemini langchain_answer:", langchain_gemini_answer)
369
- if langchain_gemini_answer != None and process_answer(langchain_gemini_answer) != True:
370
- return {"answer": jsonable_encoder(langchain_gemini_answer)}
371
-
372
- # --- Gemini-Langchain-based and Rethink-based chat generation ends---
373
-
374
-
375
  return {"answer": jsonable_encoder(lang_answer)}
376
 
377
  return {"answer": jsonable_encoder(groq_answer)}
@@ -843,51 +784,12 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
843
  loop = asyncio.get_running_loop()
844
  # First, try the langchain-based method if the question qualifies
845
  if if_initial_chart_question(query):
846
-
847
-
848
- # --- Gemini-Langchain-based and Rethink-based chart generation starts here ---
849
-
850
-
851
- langchain_gemini_chart_answer = await asyncio.to_thread(
852
- langchain_gemini_csv_chart, csv_url, query, True
853
- )
854
- logger.info("gemini langchain_answer --> ", langchain_gemini_chart_answer)
855
- if isinstance(langchain_gemini_chart_answer, list) and len(langchain_gemini_chart_answer) > 0:
856
- return FileResponse(langchain_gemini_chart_answer[0], media_type="image/png")
857
-
858
- logger.error("Gemini Langchain chart generation failed for initial question, trying groq langchain....")
859
-
860
-
861
-
862
-
863
- # --- Gemini-Langchain-based and Rethink-based chart generation ends here ---
864
-
865
-
866
  langchain_result = await loop.run_in_executor(
867
  process_executor, langchain_csv_chart, csv_url, query, True
868
  )
869
  logger.info("Langchain chart result:", langchain_result)
870
  if isinstance(langchain_result, list) and len(langchain_result) > 0:
871
  return FileResponse(langchain_result[0], media_type="image/png")
872
-
873
-
874
-
875
-
876
- # --- Gemini-Langchain-based and Rethink-based chart generation starts here ---
877
-
878
-
879
- logger.error("Groq Langchain chart generation failed, trying rethink gemini....")
880
- gemini_answer = await asyncio.to_thread(gemini_llm_chart, csv_url, query)
881
- logger.info("gemini_answer --> ", gemini_answer)
882
- if(isinstance(gemini_answer, str) and gemini_answer != "Chart path not found" and ".png" in gemini_answer):
883
- return FileResponse(gemini_answer, media_type="image/png")
884
-
885
- logger.error("Gemini chart generation failed, trying groq....")
886
-
887
-
888
- # --- Gemini-Langchain-based and Rethink-based chart generation ends here ---
889
-
890
-
891
 
892
  # Next, try the groq-based method
893
  groq_result = await loop.run_in_executor(
@@ -906,21 +808,8 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
906
  if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
907
  return FileResponse(langchain_paths[0], media_type="image/png")
908
  else:
909
-
910
- # --- Gemini-Langchain-based and Rethink-based chart generation starts here ---
911
-
912
-
913
- langchain_gemini_chart_answer = await asyncio.to_thread(
914
- langchain_gemini_csv_chart, csv_url, query, True
915
- )
916
- logger.info("gemini langchain_answer --> ", langchain_gemini_chart_answer)
917
- if isinstance(langchain_gemini_chart_answer, list) and len(langchain_gemini_chart_answer) > 0:
918
- return FileResponse(langchain_gemini_chart_answer[0], media_type="image/png")
919
-
920
- # --- Gemini-Langchain-based and Rethink-based chart generation ends here ---
921
-
922
- logger.error("All chart generation methods failed")
923
- return {"error": "All chart generation methods failed"}
924
 
925
  except Exception as e:
926
  logger.error(f"Critical chart error: {str(e)}")
 
26
  import matplotlib
27
  import seaborn as sns
28
  from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
 
 
 
29
  from util_service import _prompt_generator, process_answer
30
  from fastapi.middleware.cors import CORSMiddleware
31
  import matplotlib
 
294
  decoded_url = unquote(csv_url)
295
 
296
  if if_initial_chat_question(query):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  answer = await asyncio.to_thread(
298
  langchain_csv_chat, decoded_url, query, False
299
  )
300
+ logger.info("langchain_answer:", answer)
 
 
 
 
 
 
 
 
 
 
 
301
  return {"answer": jsonable_encoder(answer)}
302
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  # Process with groq_chat first
304
  groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
305
  logger.info("groq_answer:", groq_answer)
 
312
  langchain_csv_chat, decoded_url, query, False
313
  )
314
  if process_answer(lang_answer):
315
+ return {"answer": "error"}
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  return {"answer": jsonable_encoder(lang_answer)}
317
 
318
  return {"answer": jsonable_encoder(groq_answer)}
 
784
  loop = asyncio.get_running_loop()
785
  # First, try the langchain-based method if the question qualifies
786
  if if_initial_chart_question(query):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
787
  langchain_result = await loop.run_in_executor(
788
  process_executor, langchain_csv_chart, csv_url, query, True
789
  )
790
  logger.info("Langchain chart result:", langchain_result)
791
  if isinstance(langchain_result, list) and len(langchain_result) > 0:
792
  return FileResponse(langchain_result[0], media_type="image/png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
793
 
794
  # Next, try the groq-based method
795
  groq_result = await loop.run_in_executor(
 
808
  if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
809
  return FileResponse(langchain_paths[0], media_type="image/png")
810
  else:
811
+ logger.error("All chart generation methods failed")
812
+ return {"error": "All chart generation methods failed"}
 
 
 
 
 
 
 
 
 
 
 
 
 
813
 
814
  except Exception as e:
815
  logger.error(f"Critical chart error: {str(e)}")