stay hard
Browse files- controller.py +13 -0
controller.py
CHANGED
|
@@ -27,6 +27,7 @@ import matplotlib
|
|
| 27 |
import seaborn as sns
|
| 28 |
from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
|
| 29 |
from rethink_gemini_agents.gemini_langchain_service import langchain_gemini_csv_chat
|
|
|
|
| 30 |
from rethink_gemini_agents.rethink_chat import gemini_llm_chat
|
| 31 |
from util_service import _prompt_generator, process_answer
|
| 32 |
from fastapi.middleware.cors import CORSMiddleware
|
|
@@ -797,6 +798,18 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
| 797 |
|
| 798 |
loop = asyncio.get_running_loop()
|
| 799 |
# First, try the langchain-based method if the question qualifies
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 800 |
if if_initial_chart_question(query):
|
| 801 |
langchain_result = await loop.run_in_executor(
|
| 802 |
process_executor, langchain_csv_chart, csv_url, query, True
|
|
|
|
| 27 |
import seaborn as sns
|
| 28 |
from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
|
| 29 |
from rethink_gemini_agents.gemini_langchain_service import langchain_gemini_csv_chat
|
| 30 |
+
from rethink_gemini_agents.rethink_chart import gemini_llm_chart
|
| 31 |
from rethink_gemini_agents.rethink_chat import gemini_llm_chat
|
| 32 |
from util_service import _prompt_generator, process_answer
|
| 33 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
| 798 |
|
| 799 |
loop = asyncio.get_running_loop()
|
| 800 |
# First, try the langchain-based method if the question qualifies
|
| 801 |
+
|
| 802 |
+
if if_initial_chat_question(query):
|
| 803 |
+
answer = await asyncio.to_thread(
|
| 804 |
+
langchain_gemini_csv_chat, csv_url, query, False
|
| 805 |
+
)
|
| 806 |
+
logger.info("gemini langchain_answer --> ", answer)
|
| 807 |
+
return {"answer": jsonable_encoder(answer)}
|
| 808 |
+
|
| 809 |
+
gemini_answer = await asyncio.to_thread(gemini_llm_chart, csv_url, query)
|
| 810 |
+
logger.info("gemini_answer --> ", gemini_answer)
|
| 811 |
+
return {"answer": gemini_answer}
|
| 812 |
+
|
| 813 |
if if_initial_chart_question(query):
|
| 814 |
langchain_result = await loop.run_in_executor(
|
| 815 |
process_executor, langchain_csv_chart, csv_url, query, True
|