added gemini too
Browse files- controller.py +6 -2
- orchestrator_agent.py +4 -3
controller.py
CHANGED
@@ -301,6 +301,8 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
|
|
301 |
csv_url = request.get("csv_url")
|
302 |
decoded_url = unquote(csv_url)
|
303 |
detailed_answer = request.get("detailed_answer")
|
|
|
|
|
304 |
|
305 |
if if_initial_chat_question(query):
|
306 |
answer = await asyncio.to_thread(
|
@@ -312,7 +314,7 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
|
|
312 |
# Orchestrate the execution
|
313 |
if detailed_answer is True:
|
314 |
orchestrator_answer = await asyncio.to_thread(
|
315 |
-
csv_orchestrator_chat, decoded_url, query
|
316 |
)
|
317 |
if orchestrator_answer is not None:
|
318 |
return {"answer": jsonable_encoder(orchestrator_answer)}
|
@@ -798,6 +800,8 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
798 |
query = request.get("query", "")
|
799 |
csv_url = unquote(request.get("csv_url", ""))
|
800 |
detailed_answer = request.get("detailed_answer", False)
|
|
|
|
|
801 |
|
802 |
loop = asyncio.get_running_loop()
|
803 |
# First, try the langchain-based method if the question qualifies
|
@@ -817,7 +821,7 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
817 |
# Use orchestrator to handle the user's chart query first
|
818 |
if detailed_answer is True:
|
819 |
orchestrator_answer = await asyncio.to_thread(
|
820 |
-
csv_orchestrator_chat, csv_url, query
|
821 |
)
|
822 |
|
823 |
if orchestrator_answer is not None:
|
|
|
301 |
csv_url = request.get("csv_url")
|
302 |
decoded_url = unquote(csv_url)
|
303 |
detailed_answer = request.get("detailed_answer")
|
304 |
+
conversation_history = request.get("conversation_history", [])
|
305 |
+
return {"answer": jsonable_encoder(conversation_history)}
|
306 |
|
307 |
if if_initial_chat_question(query):
|
308 |
answer = await asyncio.to_thread(
|
|
|
314 |
# Orchestrate the execution
|
315 |
if detailed_answer is True:
|
316 |
orchestrator_answer = await asyncio.to_thread(
|
317 |
+
csv_orchestrator_chat, decoded_url, query, conversation_history
|
318 |
)
|
319 |
if orchestrator_answer is not None:
|
320 |
return {"answer": jsonable_encoder(orchestrator_answer)}
|
|
|
800 |
query = request.get("query", "")
|
801 |
csv_url = unquote(request.get("csv_url", ""))
|
802 |
detailed_answer = request.get("detailed_answer", False)
|
803 |
+
conversation_history = request.get("conversation_history", [])
|
804 |
+
return {"orchestrator_response": jsonable_encoder(conversation_history)}
|
805 |
|
806 |
loop = asyncio.get_running_loop()
|
807 |
# First, try the langchain-based method if the question qualifies
|
|
|
821 |
# Use orchestrator to handle the user's chart query first
|
822 |
if detailed_answer is True:
|
823 |
orchestrator_answer = await asyncio.to_thread(
|
824 |
+
csv_orchestrator_chat, csv_url, query, conversation_history
|
825 |
)
|
826 |
|
827 |
if orchestrator_answer is not None:
|
orchestrator_agent.py
CHANGED
@@ -90,7 +90,7 @@ async def generate_chart(csv_url: str, user_questions: List[str]) -> Any:
|
|
90 |
return charts
|
91 |
|
92 |
# Function to create an agent with a specific CSV URL
|
93 |
-
def create_agent(csv_url: str, api_key: str) -> Agent:
|
94 |
csv_metadata = get_csv_basic_info(csv_url)
|
95 |
|
96 |
system_prompt = f"""
|
@@ -130,6 +130,7 @@ def create_agent(csv_url: str, api_key: str) -> Agent:
|
|
130 |
## Current Context:
|
131 |
- Working with CSV_URL: {csv_url}
|
132 |
- Dataset overview: {csv_metadata}
|
|
|
133 |
- Output format: Markdown compatible
|
134 |
|
135 |
## Response Template:
|
@@ -192,7 +193,7 @@ def create_agent(csv_url: str, api_key: str) -> Agent:
|
|
192 |
system_prompt=system_prompt
|
193 |
)
|
194 |
|
195 |
-
def csv_orchestrator_chat(csv_url: str, user_question: str) -> str:
|
196 |
print("CSV URL:", csv_url)
|
197 |
print("User questions:", user_question)
|
198 |
|
@@ -200,7 +201,7 @@ def csv_orchestrator_chat(csv_url: str, user_question: str) -> str:
|
|
200 |
for api_key in GEMINI_API_KEYS:
|
201 |
try:
|
202 |
print(f"Attempting with API key: {api_key}")
|
203 |
-
agent = create_agent(csv_url, api_key)
|
204 |
result = agent.run_sync(user_question)
|
205 |
print("Orchestrator Result:", result.data)
|
206 |
return result.data
|
|
|
90 |
return charts
|
91 |
|
92 |
# Function to create an agent with a specific CSV URL
|
93 |
+
def create_agent(csv_url: str, api_key: str, conversation_history: List) -> Agent:
|
94 |
csv_metadata = get_csv_basic_info(csv_url)
|
95 |
|
96 |
system_prompt = f"""
|
|
|
130 |
## Current Context:
|
131 |
- Working with CSV_URL: {csv_url}
|
132 |
- Dataset overview: {csv_metadata}
|
133 |
+
- Your conversation history: {conversation_history}
|
134 |
- Output format: Markdown compatible
|
135 |
|
136 |
## Response Template:
|
|
|
193 |
system_prompt=system_prompt
|
194 |
)
|
195 |
|
196 |
+
def csv_orchestrator_chat(csv_url: str, user_question: str, conversation_history: List) -> str:
|
197 |
print("CSV URL:", csv_url)
|
198 |
print("User questions:", user_question)
|
199 |
|
|
|
201 |
for api_key in GEMINI_API_KEYS:
|
202 |
try:
|
203 |
print(f"Attempting with API key: {api_key}")
|
204 |
+
agent = create_agent(csv_url, api_key, conversation_history)
|
205 |
result = agent.run_sync(user_question)
|
206 |
print("Orchestrator Result:", result.data)
|
207 |
return result.data
|