Soumik555 commited on
Commit
48e6960
·
1 Parent(s): 489d74e

added gemini too

Browse files
Files changed (2) hide show
  1. controller.py +6 -2
  2. orchestrator_agent.py +4 -3
controller.py CHANGED
@@ -301,6 +301,8 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
301
  csv_url = request.get("csv_url")
302
  decoded_url = unquote(csv_url)
303
  detailed_answer = request.get("detailed_answer")
 
 
304
 
305
  if if_initial_chat_question(query):
306
  answer = await asyncio.to_thread(
@@ -312,7 +314,7 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
312
  # Orchestrate the execution
313
  if detailed_answer is True:
314
  orchestrator_answer = await asyncio.to_thread(
315
- csv_orchestrator_chat, decoded_url, query
316
  )
317
  if orchestrator_answer is not None:
318
  return {"answer": jsonable_encoder(orchestrator_answer)}
@@ -798,6 +800,8 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
798
  query = request.get("query", "")
799
  csv_url = unquote(request.get("csv_url", ""))
800
  detailed_answer = request.get("detailed_answer", False)
 
 
801
 
802
  loop = asyncio.get_running_loop()
803
  # First, try the langchain-based method if the question qualifies
@@ -817,7 +821,7 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
817
  # Use orchestrator to handle the user's chart query first
818
  if detailed_answer is True:
819
  orchestrator_answer = await asyncio.to_thread(
820
- csv_orchestrator_chat, csv_url, query
821
  )
822
 
823
  if orchestrator_answer is not None:
 
301
  csv_url = request.get("csv_url")
302
  decoded_url = unquote(csv_url)
303
  detailed_answer = request.get("detailed_answer")
304
+ conversation_history = request.get("conversation_history", [])
305
+ return {"answer": jsonable_encoder(conversation_history)}
306
 
307
  if if_initial_chat_question(query):
308
  answer = await asyncio.to_thread(
 
314
  # Orchestrate the execution
315
  if detailed_answer is True:
316
  orchestrator_answer = await asyncio.to_thread(
317
+ csv_orchestrator_chat, decoded_url, query, conversation_history
318
  )
319
  if orchestrator_answer is not None:
320
  return {"answer": jsonable_encoder(orchestrator_answer)}
 
800
  query = request.get("query", "")
801
  csv_url = unquote(request.get("csv_url", ""))
802
  detailed_answer = request.get("detailed_answer", False)
803
+ conversation_history = request.get("conversation_history", [])
804
+ return {"orchestrator_response": jsonable_encoder(conversation_history)}
805
 
806
  loop = asyncio.get_running_loop()
807
  # First, try the langchain-based method if the question qualifies
 
821
  # Use orchestrator to handle the user's chart query first
822
  if detailed_answer is True:
823
  orchestrator_answer = await asyncio.to_thread(
824
+ csv_orchestrator_chat, csv_url, query, conversation_history
825
  )
826
 
827
  if orchestrator_answer is not None:
orchestrator_agent.py CHANGED
@@ -90,7 +90,7 @@ async def generate_chart(csv_url: str, user_questions: List[str]) -> Any:
90
  return charts
91
 
92
  # Function to create an agent with a specific CSV URL
93
- def create_agent(csv_url: str, api_key: str) -> Agent:
94
  csv_metadata = get_csv_basic_info(csv_url)
95
 
96
  system_prompt = f"""
@@ -130,6 +130,7 @@ def create_agent(csv_url: str, api_key: str) -> Agent:
130
  ## Current Context:
131
  - Working with CSV_URL: {csv_url}
132
  - Dataset overview: {csv_metadata}
 
133
  - Output format: Markdown compatible
134
 
135
  ## Response Template:
@@ -192,7 +193,7 @@ def create_agent(csv_url: str, api_key: str) -> Agent:
192
  system_prompt=system_prompt
193
  )
194
 
195
- def csv_orchestrator_chat(csv_url: str, user_question: str) -> str:
196
  print("CSV URL:", csv_url)
197
  print("User questions:", user_question)
198
 
@@ -200,7 +201,7 @@ def csv_orchestrator_chat(csv_url: str, user_question: str) -> str:
200
  for api_key in GEMINI_API_KEYS:
201
  try:
202
  print(f"Attempting with API key: {api_key}")
203
- agent = create_agent(csv_url, api_key)
204
  result = agent.run_sync(user_question)
205
  print("Orchestrator Result:", result.data)
206
  return result.data
 
90
  return charts
91
 
92
  # Function to create an agent with a specific CSV URL
93
+ def create_agent(csv_url: str, api_key: str, conversation_history: List) -> Agent:
94
  csv_metadata = get_csv_basic_info(csv_url)
95
 
96
  system_prompt = f"""
 
130
  ## Current Context:
131
  - Working with CSV_URL: {csv_url}
132
  - Dataset overview: {csv_metadata}
133
+ - Your conversation history: {conversation_history}
134
  - Output format: Markdown compatible
135
 
136
  ## Response Template:
 
193
  system_prompt=system_prompt
194
  )
195
 
196
+ def csv_orchestrator_chat(csv_url: str, user_question: str, conversation_history: List) -> str:
197
  print("CSV URL:", csv_url)
198
  print("User questions:", user_question)
199
 
 
201
  for api_key in GEMINI_API_KEYS:
202
  try:
203
  print(f"Attempting with API key: {api_key}")
204
+ agent = create_agent(csv_url, api_key, conversation_history)
205
  result = agent.run_sync(user_question)
206
  print("Orchestrator Result:", result.data)
207
  return result.data