added openai react
Browse files- controller.py +44 -70
controller.py
CHANGED
@@ -27,7 +27,6 @@ import matplotlib
|
|
27 |
import seaborn as sns
|
28 |
from gemini_report_generator import generate_csv_report
|
29 |
from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
|
30 |
-
from openai_react_agent_service import openai_react_chat
|
31 |
from orchestrator_agent import csv_orchestrator_chat
|
32 |
from supabase_service import upload_file_to_supabase
|
33 |
from util_service import _prompt_generator, process_answer
|
@@ -345,10 +344,9 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
|
|
345 |
query = request.get("query")
|
346 |
csv_url = request.get("csv_url")
|
347 |
decoded_url = unquote(csv_url)
|
348 |
-
detailed_answer = request.get("detailed_answer"
|
349 |
conversation_history = request.get("conversation_history", [])
|
350 |
-
generate_report = request.get("generate_report"
|
351 |
-
is_pro = request.get("is_pro", False)
|
352 |
|
353 |
if generate_report is True:
|
354 |
report_files = await generate_csv_report(csv_url, query)
|
@@ -370,31 +368,22 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
|
|
370 |
if orchestrator_answer is not None:
|
371 |
return {"answer": jsonable_encoder(orchestrator_answer)}
|
372 |
|
373 |
-
# if the user is pro, then we use the openai_react_agent first
|
374 |
-
if is_pro is True:
|
375 |
-
openai_answer = await asyncio.to_thread(
|
376 |
-
openai_react_chat, decoded_url, query, False
|
377 |
-
)
|
378 |
-
logger.info("openai_answer:", openai_answer)
|
379 |
-
if openai_answer is not None:
|
380 |
-
return {"answer": jsonable_encoder(openai_answer)}
|
381 |
-
|
382 |
# Process with groq_chat first
|
383 |
-
|
384 |
-
|
385 |
|
386 |
-
|
387 |
-
|
388 |
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
|
397 |
-
|
398 |
|
399 |
except Exception as e:
|
400 |
logger.error(f"Error processing request: {str(e)}")
|
@@ -862,7 +851,6 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
862 |
detailed_answer = request.get("detailed_answer", False)
|
863 |
conversation_history = request.get("conversation_history", [])
|
864 |
generate_report = request.get("generate_report", False)
|
865 |
-
is_pro = request.get("is_pro", False)
|
866 |
|
867 |
if generate_report is True:
|
868 |
report_files = await generate_csv_report(csv_url, query)
|
@@ -893,52 +881,38 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
893 |
|
894 |
if orchestrator_answer is not None:
|
895 |
return {"orchestrator_response": jsonable_encoder(orchestrator_answer)}
|
896 |
-
|
897 |
-
# If user have a pro subscription start with openai-reAct agent
|
898 |
-
if is_pro is True:
|
899 |
-
openai_react_answer = await asyncio.to_thread(
|
900 |
-
process_executor, openai_react_chat, csv_url, query, True
|
901 |
-
)
|
902 |
-
if openai_react_answer is not None:
|
903 |
-
chart_path = openai_react_answer
|
904 |
-
logger.info("Uploading the chart to supabase...")
|
905 |
-
unique_file_name =f'{str(uuid.uuid4())}.png'
|
906 |
-
image_public_url = await upload_file_to_supabase(f"{chart_path}", unique_file_name)
|
907 |
-
logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
|
908 |
-
os.remove(chart_path)
|
909 |
-
return {"image_url": image_public_url}
|
910 |
-
|
911 |
# Next, try the groq-based method
|
912 |
-
|
913 |
-
|
914 |
-
|
915 |
-
|
916 |
-
|
917 |
-
|
918 |
-
|
919 |
-
|
920 |
-
|
921 |
-
|
922 |
-
|
923 |
-
|
924 |
|
925 |
-
#
|
926 |
-
|
927 |
-
|
928 |
-
|
929 |
-
|
930 |
-
|
931 |
-
|
932 |
-
|
933 |
-
|
934 |
-
|
935 |
-
|
936 |
-
|
937 |
-
|
938 |
-
|
939 |
-
|
940 |
-
|
941 |
-
|
942 |
|
943 |
except Exception as e:
|
944 |
logger.error(f"Critical chart error: {str(e)}")
|
|
|
27 |
import seaborn as sns
|
28 |
from gemini_report_generator import generate_csv_report
|
29 |
from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
|
|
|
30 |
from orchestrator_agent import csv_orchestrator_chat
|
31 |
from supabase_service import upload_file_to_supabase
|
32 |
from util_service import _prompt_generator, process_answer
|
|
|
344 |
query = request.get("query")
|
345 |
csv_url = request.get("csv_url")
|
346 |
decoded_url = unquote(csv_url)
|
347 |
+
detailed_answer = request.get("detailed_answer")
|
348 |
conversation_history = request.get("conversation_history", [])
|
349 |
+
generate_report = request.get("generate_report")
|
|
|
350 |
|
351 |
if generate_report is True:
|
352 |
report_files = await generate_csv_report(csv_url, query)
|
|
|
368 |
if orchestrator_answer is not None:
|
369 |
return {"answer": jsonable_encoder(orchestrator_answer)}
|
370 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
371 |
# Process with groq_chat first
|
372 |
+
groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
|
373 |
+
logger.info("groq_answer:", groq_answer)
|
374 |
|
375 |
+
if process_answer(groq_answer) == "Empty response received.":
|
376 |
+
return {"answer": "Sorry, I couldn't find relevant data..."}
|
377 |
|
378 |
+
if process_answer(groq_answer):
|
379 |
+
lang_answer = await asyncio.to_thread(
|
380 |
+
langchain_csv_chat, decoded_url, query, False
|
381 |
+
)
|
382 |
+
if process_answer(lang_answer):
|
383 |
+
return {"answer": "error"}
|
384 |
+
return {"answer": jsonable_encoder(lang_answer)}
|
385 |
|
386 |
+
return {"answer": jsonable_encoder(groq_answer)}
|
387 |
|
388 |
except Exception as e:
|
389 |
logger.error(f"Error processing request: {str(e)}")
|
|
|
851 |
detailed_answer = request.get("detailed_answer", False)
|
852 |
conversation_history = request.get("conversation_history", [])
|
853 |
generate_report = request.get("generate_report", False)
|
|
|
854 |
|
855 |
if generate_report is True:
|
856 |
report_files = await generate_csv_report(csv_url, query)
|
|
|
881 |
|
882 |
if orchestrator_answer is not None:
|
883 |
return {"orchestrator_response": jsonable_encoder(orchestrator_answer)}
|
884 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
885 |
# Next, try the groq-based method
|
886 |
+
groq_result = await loop.run_in_executor(
|
887 |
+
process_executor, groq_chart, csv_url, query
|
888 |
+
)
|
889 |
+
logger.info(f"Groq chart result: {groq_result}")
|
890 |
+
if isinstance(groq_result, str) and groq_result != "Chart not generated":
|
891 |
+
unique_file_name =f'{str(uuid.uuid4())}.png'
|
892 |
+
logger.info("Uploading the chart to supabase...")
|
893 |
+
image_public_url = await upload_file_to_supabase(f"{groq_result}", unique_file_name)
|
894 |
+
logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
|
895 |
+
os.remove(groq_result)
|
896 |
+
return {"image_url": image_public_url}
|
897 |
+
# return FileResponse(groq_result, media_type="image/png")
|
898 |
|
899 |
+
# Fallback: try langchain-based again
|
900 |
+
logger.error("Groq chart generation failed, trying langchain....")
|
901 |
+
langchain_paths = await loop.run_in_executor(
|
902 |
+
process_executor, langchain_csv_chart, csv_url, query, True
|
903 |
+
)
|
904 |
+
logger.info("Fallback langchain chart result:", langchain_paths)
|
905 |
+
if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
|
906 |
+
unique_file_name =f'{str(uuid.uuid4())}.png'
|
907 |
+
logger.info("Uploading the chart to supabase...")
|
908 |
+
image_public_url = await upload_file_to_supabase(f"{langchain_paths[0]}", unique_file_name)
|
909 |
+
logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
|
910 |
+
os.remove(langchain_paths[0])
|
911 |
+
return {"image_url": image_public_url}
|
912 |
+
# return FileResponse(langchain_paths[0], media_type="image/png")
|
913 |
+
else:
|
914 |
+
logger.error("All chart generation methods failed")
|
915 |
+
return {"answer": "error"}
|
916 |
|
917 |
except Exception as e:
|
918 |
logger.error(f"Critical chart error: {str(e)}")
|