Soumik555 commited on
Commit
a8cab65
·
1 Parent(s): d40dc9d
controller.py CHANGED
@@ -26,7 +26,7 @@ import matplotlib.pyplot as plt
26
  import matplotlib
27
  import seaborn as sns
28
  from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
29
- from rethink_gemini_agents.gemini_langchain_service import langchain_gemini_csv_chat
30
  from rethink_gemini_agents.rethink_chart import gemini_llm_chart
31
  from rethink_gemini_agents.rethink_chat import gemini_llm_chat
32
  from util_service import _prompt_generator, process_answer
@@ -302,18 +302,19 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
302
  langchain_gemini_csv_chat, decoded_url, query, False
303
  )
304
  logger.info("gemini langchain_answer --> ", answer)
 
 
 
 
 
 
 
 
305
  return {"answer": jsonable_encoder(answer)}
 
 
306
 
307
- gemini_answer = await asyncio.to_thread(gemini_llm_chat, decoded_url, query)
308
- logger.info("gemini_answer --> ", gemini_answer)
309
- return {"answer": gemini_answer}
310
 
311
- if if_initial_chat_question(query):
312
- answer = await asyncio.to_thread(
313
- langchain_csv_chat, decoded_url, query, False
314
- )
315
- logger.info("langchain_answer --> ", answer)
316
- return {"answer": jsonable_encoder(answer)}
317
 
318
  # Process with groq_chat first
319
  groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
@@ -799,26 +800,32 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
799
  loop = asyncio.get_running_loop()
800
  # First, try the langchain-based method if the question qualifies
801
 
802
- if if_initial_chat_question(query):
 
803
  langchain_gemini_chart_answer = await asyncio.to_thread(
804
  langchain_gemini_csv_chat, csv_url, query, False
805
  )
806
  logger.info("gemini langchain_answer --> ", langchain_gemini_chart_answer)
807
  if isinstance(langchain_gemini_chart_answer, list) and len(langchain_gemini_chart_answer) > 0:
808
  return FileResponse(langchain_gemini_chart_answer[0], media_type="image/png")
809
-
810
- gemini_answer = await asyncio.to_thread(gemini_llm_chart, csv_url, query)
811
- logger.info("gemini_answer --> ", gemini_answer)
812
- return FileResponse(gemini_answer, media_type="image/png")
813
 
814
- if if_initial_chart_question(query):
815
  langchain_result = await loop.run_in_executor(
816
  process_executor, langchain_csv_chart, csv_url, query, True
817
  )
818
- logger.info("Langchain chart result:", langchain_result)
819
  if isinstance(langchain_result, list) and len(langchain_result) > 0:
820
  return FileResponse(langchain_result[0], media_type="image/png")
 
 
 
 
 
 
 
 
821
 
 
822
  # Next, try the groq-based method
823
  groq_result = await loop.run_in_executor(
824
  process_executor, groq_chart, csv_url, query
@@ -828,13 +835,19 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
828
  return FileResponse(groq_result, media_type="image/png")
829
 
830
  # Fallback: try langchain-based again
831
- logger.error("Groq chart generation failed, trying langchain....")
832
  langchain_paths = await loop.run_in_executor(
833
  process_executor, langchain_csv_chart, csv_url, query, True
834
  )
835
- logger.info("Fallback langchain chart result:", langchain_paths)
836
  if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
837
  return FileResponse(langchain_paths[0], media_type="image/png")
 
 
 
 
 
 
838
  else:
839
  logger.error("All chart generation methods failed")
840
  return {"error": "All chart generation methods failed"}
 
26
  import matplotlib
27
  import seaborn as sns
28
  from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
29
+ from rethink_gemini_agents.gemini_langchain_service import langchain_gemini_csv_chart, langchain_gemini_csv_chat
30
  from rethink_gemini_agents.rethink_chart import gemini_llm_chart
31
  from rethink_gemini_agents.rethink_chat import gemini_llm_chat
32
  from util_service import _prompt_generator, process_answer
 
302
  langchain_gemini_csv_chat, decoded_url, query, False
303
  )
304
  logger.info("gemini langchain_answer --> ", answer)
305
+ if(answer and process_answer(answer)):
306
+
307
+ logger.error("Gemini chat initial query failed, trying groq based langchain_csv_chat...")
308
+
309
+ answer = await asyncio.to_thread(
310
+ langchain_csv_chat, decoded_url, query, False
311
+ )
312
+ logger.info("groq langchain_answer --> ", answer)
313
  return {"answer": jsonable_encoder(answer)}
314
+
315
+
316
 
 
 
 
317
 
 
 
 
 
 
 
318
 
319
  # Process with groq_chat first
320
  groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
 
800
  loop = asyncio.get_running_loop()
801
  # First, try the langchain-based method if the question qualifies
802
 
803
+
804
+ if if_initial_chart_question(query):
805
  langchain_gemini_chart_answer = await asyncio.to_thread(
806
  langchain_gemini_csv_chat, csv_url, query, False
807
  )
808
  logger.info("gemini langchain_answer --> ", langchain_gemini_chart_answer)
809
  if isinstance(langchain_gemini_chart_answer, list) and len(langchain_gemini_chart_answer) > 0:
810
  return FileResponse(langchain_gemini_chart_answer[0], media_type="image/png")
 
 
 
 
811
 
812
+ logger.error("Gemini Langchain chart generation failed for initial question, trying groq langchain....")
813
  langchain_result = await loop.run_in_executor(
814
  process_executor, langchain_csv_chart, csv_url, query, True
815
  )
816
+ logger.info("groq langchain_answer -->", langchain_result)
817
  if isinstance(langchain_result, list) and len(langchain_result) > 0:
818
  return FileResponse(langchain_result[0], media_type="image/png")
819
+
820
+ logger.error("Groq Langchain chart generation failed for initial question, trying gemini rethink agent....")
821
+
822
+ # If not initial question then try gemini langchain method
823
+ gemini_answer = await asyncio.to_thread(gemini_llm_chart, csv_url, query)
824
+ logger.info("gemini_answer --> ", gemini_answer)
825
+ if(isinstance(gemini_answer, str) and gemini_answer != "Chart path not found" and ".png" in gemini_answer):
826
+ return FileResponse(gemini_answer, media_type="image/png")
827
 
828
+ logger.error("Gemini chart generation failed, trying groq....")
829
  # Next, try the groq-based method
830
  groq_result = await loop.run_in_executor(
831
  process_executor, groq_chart, csv_url, query
 
835
  return FileResponse(groq_result, media_type="image/png")
836
 
837
  # Fallback: try langchain-based again
838
+ logger.error("Groq chart generation failed, trying langchain (Groq)....")
839
  langchain_paths = await loop.run_in_executor(
840
  process_executor, langchain_csv_chart, csv_url, query, True
841
  )
842
+ logger.info("Groq langchain image paths -->", langchain_paths)
843
  if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
844
  return FileResponse(langchain_paths[0], media_type="image/png")
845
+
846
+ # Fallback: try gemini-based again
847
+ logger.error("Groq langchain chart generation failed, trying Gemini Langchain agent....")
848
+ gemini_paths = await asyncio.to_thread(langchain_gemini_csv_chart, csv_url, query, True)
849
+ if isinstance(gemini_paths, list) and len(gemini_paths) > 0:
850
+ return FileResponse(gemini_paths[0], media_type="image/png")
851
  else:
852
  logger.error("All chart generation methods failed")
853
  return {"error": "All chart generation methods failed"}
rethink_gemini_agents/rethink_chart.py CHANGED
@@ -232,10 +232,10 @@ def gemini_llm_chart(csv_url: str, query: str) -> str:
232
  return chart_path
233
  else:
234
  print("Chart path not found")
235
- return None
236
  else:
237
  print("Unexpected result format:", type(result))
238
- return None
239
 
240
 
241
 
 
232
  return chart_path
233
  else:
234
  print("Chart path not found")
235
+ return "Chart path not found"
236
  else:
237
  print("Unexpected result format:", type(result))
238
+ return "Chart path not found"
239
 
240
 
241
 
rethink_gemini_agents/rethink_chat.py CHANGED
@@ -212,31 +212,36 @@ class RethinkAgent(BaseModel):
212
 
213
 
214
  def gemini_llm_chat(csv_url: str, query: str) -> str:
 
 
215
  # Assuming clean_data and RethinkAgent are defined elsewhere
216
- df = clean_data(csv_url)
217
- agent = RethinkAgent(df=df)
218
 
219
- # Assuming API_KEYS is defined elsewhere
220
- if not agent.initialize_model(API_KEYS):
221
- print("Failed to initialize model with provided keys")
222
- exit(1)
223
 
224
- result = agent.execute_query(query)
225
 
226
- # Process different response types
227
- if isinstance(result, pd.DataFrame):
228
- processed = result.apply(handle_out_of_range_float).to_dict(orient="records")
229
- elif isinstance(result, pd.Series):
230
  processed = result.apply(handle_out_of_range_float).to_dict()
231
- elif isinstance(result, list):
232
- processed = [handle_out_of_range_float(item) for item in result]
233
- elif isinstance(result, dict):
234
- processed = {k: handle_out_of_range_float(v) for k, v in result.items()}
235
- else:
236
- processed = {"answer": str(handle_out_of_range_float(result))}
237
 
238
- logger.info(f"gemini processed result: {processed}")
239
- return processed
 
 
 
240
 
241
  # uvicorn controller:app --host localhost --port 8000 --reload
242
 
 
212
 
213
 
214
  def gemini_llm_chat(csv_url: str, query: str) -> str:
215
+
216
+ try:
217
  # Assuming clean_data and RethinkAgent are defined elsewhere
218
+ df = clean_data(csv_url)
219
+ agent = RethinkAgent(df=df)
220
 
221
+ # Assuming API_KEYS is defined elsewhere
222
+ if not agent.initialize_model(API_KEYS):
223
+ print("Failed to initialize model with provided keys")
224
+ exit(1)
225
 
226
+ result = agent.execute_query(query)
227
 
228
+ # Process different response types
229
+ if isinstance(result, pd.DataFrame):
230
+ processed = result.apply(handle_out_of_range_float).to_dict(orient="records")
231
+ elif isinstance(result, pd.Series):
232
  processed = result.apply(handle_out_of_range_float).to_dict()
233
+ elif isinstance(result, list):
234
+ processed = [handle_out_of_range_float(item) for item in result]
235
+ elif isinstance(result, dict):
236
+ processed = {k: handle_out_of_range_float(v) for k, v in result.items()}
237
+ else:
238
+ processed = {"answer": str(handle_out_of_range_float(result))}
239
 
240
+ logger.info(f"gemini processed result: {processed}")
241
+ return processed
242
+ except Exception as e:
243
+ logger.error(f"Error in gemini_llm_chat: {str(e)}")
244
+ return None
245
 
246
  # uvicorn controller:app --host localhost --port 8000 --reload
247
 
util_service.py CHANGED
@@ -1,7 +1,7 @@
1
  from langchain_core.prompts import ChatPromptTemplate
2
  import numpy as np
3
 
4
- keywords = ["unfortunately", "unsupported", "error", "sorry", "response", "unable", "because"]
5
 
6
  def contains_keywords(text, keywords):
7
  return any(keyword.lower() in text.lower() for keyword in keywords)
 
1
  from langchain_core.prompts import ChatPromptTemplate
2
  import numpy as np
3
 
4
+ keywords = ["unfortunately", "unsupported", "error", "sorry", "response", "unable", "because", "too many"]
5
 
6
  def contains_keywords(text, keywords):
7
  return any(keyword.lower() in text.lower() for keyword in keywords)