Soumik555 commited on
Commit
d784ff5
·
1 Parent(s): 2a7fd4f

Changed supabase query

Browse files
controller.py CHANGED
@@ -76,6 +76,7 @@ class CsvUrlRequest(BaseModel):
76
 
77
  class ImageRequest(BaseModel):
78
  image_path: str
 
79
 
80
  class FileProps(BaseModel):
81
  fileName: str
@@ -176,7 +177,7 @@ async def get_image(request: ImageRequest, authorization: str = Header(None)):
176
  image_file_path = request.image_path
177
  unique_file_name =f'{str(uuid.uuid4())}.png'
178
  logger.info("Uploading the chart to supabase...")
179
- image_public_url = await upload_file_to_supabase(f"{image_file_path}", unique_file_name)
180
  logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
181
  os.remove(image_file_path)
182
  return {"image_url": image_public_url}
@@ -219,15 +220,7 @@ def groq_chat(csv_url: str, question: str):
219
  current_api_key = groq_api_keys[current_groq_key_index]
220
 
221
  try:
222
- # Delete cache file if exists
223
- # cache_db_path = "/app/cache/cache_db_0.11.db"
224
- # if os.path.exists(cache_db_path):
225
- # try:
226
- # os.remove(cache_db_path)
227
- # print(f"Deleted cache DB file: {cache_db_path}")
228
- # except Exception as e:
229
- # print(f"Error deleting cache DB file: {e}")
230
-
231
  data = clean_data(csv_url)
232
  llm = ChatGroq(model=model_name, api_key=current_api_key)
233
  # Generate unique filename using UUID
@@ -347,9 +340,10 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
347
  detailed_answer = request.get("detailed_answer")
348
  conversation_history = request.get("conversation_history", [])
349
  generate_report = request.get("generate_report")
 
350
 
351
  if generate_report is True:
352
- report_files = await generate_csv_report(csv_url, query)
353
  if report_files is not None:
354
  return {"answer": jsonable_encoder(report_files)}
355
 
@@ -363,7 +357,7 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
363
  # Orchestrate the execution
364
  if detailed_answer is True:
365
  orchestrator_answer = await asyncio.to_thread(
366
- csv_orchestrator_chat, decoded_url, query, conversation_history
367
  )
368
  if orchestrator_answer is not None:
369
  return {"answer": jsonable_encoder(orchestrator_answer)}
@@ -851,9 +845,10 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
851
  detailed_answer = request.get("detailed_answer", False)
852
  conversation_history = request.get("conversation_history", [])
853
  generate_report = request.get("generate_report", False)
 
854
 
855
  if generate_report is True:
856
- report_files = await generate_csv_report(csv_url, query)
857
  if report_files is not None:
858
  return {"orchestrator_response": jsonable_encoder(report_files)}
859
 
@@ -867,7 +862,7 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
867
  if isinstance(langchain_result, list) and len(langchain_result) > 0:
868
  unique_file_name =f'{str(uuid.uuid4())}.png'
869
  logger.info("Uploading the chart to supabase...")
870
- image_public_url = await upload_file_to_supabase(f"{langchain_result[0]}", unique_file_name)
871
  logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
872
  os.remove(langchain_result[0])
873
  return {"image_url": image_public_url}
@@ -876,7 +871,7 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
876
  # Use orchestrator to handle the user's chart query first
877
  if detailed_answer is True:
878
  orchestrator_answer = await asyncio.to_thread(
879
- csv_orchestrator_chat, csv_url, query, conversation_history
880
  )
881
 
882
  if orchestrator_answer is not None:
@@ -890,7 +885,7 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
890
  if isinstance(groq_result, str) and groq_result != "Chart not generated":
891
  unique_file_name =f'{str(uuid.uuid4())}.png'
892
  logger.info("Uploading the chart to supabase...")
893
- image_public_url = await upload_file_to_supabase(f"{groq_result}", unique_file_name)
894
  logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
895
  os.remove(groq_result)
896
  return {"image_url": image_public_url}
@@ -905,7 +900,7 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
905
  if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
906
  unique_file_name =f'{str(uuid.uuid4())}.png'
907
  logger.info("Uploading the chart to supabase...")
908
- image_public_url = await upload_file_to_supabase(f"{langchain_paths[0]}", unique_file_name)
909
  logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
910
  os.remove(langchain_paths[0])
911
  return {"image_url": image_public_url}
 
76
 
77
  class ImageRequest(BaseModel):
78
  image_path: str
79
+ chat_id: str
80
 
81
  class FileProps(BaseModel):
82
  fileName: str
 
177
  image_file_path = request.image_path
178
  unique_file_name =f'{str(uuid.uuid4())}.png'
179
  logger.info("Uploading the chart to supabase...")
180
+ image_public_url = await upload_file_to_supabase(f"{image_file_path}", unique_file_name, chat_id=request.chat_id)
181
  logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
182
  os.remove(image_file_path)
183
  return {"image_url": image_public_url}
 
220
  current_api_key = groq_api_keys[current_groq_key_index]
221
 
222
  try:
223
+
 
 
 
 
 
 
 
 
224
  data = clean_data(csv_url)
225
  llm = ChatGroq(model=model_name, api_key=current_api_key)
226
  # Generate unique filename using UUID
 
340
  detailed_answer = request.get("detailed_answer")
341
  conversation_history = request.get("conversation_history", [])
342
  generate_report = request.get("generate_report")
343
+ chat_id = request.get("chat_id")
344
 
345
  if generate_report is True:
346
+ report_files = await generate_csv_report(csv_url, query, chat_id)
347
  if report_files is not None:
348
  return {"answer": jsonable_encoder(report_files)}
349
 
 
357
  # Orchestrate the execution
358
  if detailed_answer is True:
359
  orchestrator_answer = await asyncio.to_thread(
360
+ csv_orchestrator_chat, decoded_url, query, conversation_history, chat_id
361
  )
362
  if orchestrator_answer is not None:
363
  return {"answer": jsonable_encoder(orchestrator_answer)}
 
845
  detailed_answer = request.get("detailed_answer", False)
846
  conversation_history = request.get("conversation_history", [])
847
  generate_report = request.get("generate_report", False)
848
+ chat_id = request.get("chat_id", "")
849
 
850
  if generate_report is True:
851
+ report_files = await generate_csv_report(csv_url, query, chat_id)
852
  if report_files is not None:
853
  return {"orchestrator_response": jsonable_encoder(report_files)}
854
 
 
862
  if isinstance(langchain_result, list) and len(langchain_result) > 0:
863
  unique_file_name =f'{str(uuid.uuid4())}.png'
864
  logger.info("Uploading the chart to supabase...")
865
+ image_public_url = await upload_file_to_supabase(f"{langchain_result[0]}", unique_file_name, chat_id=chat_id)
866
  logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
867
  os.remove(langchain_result[0])
868
  return {"image_url": image_public_url}
 
871
  # Use orchestrator to handle the user's chart query first
872
  if detailed_answer is True:
873
  orchestrator_answer = await asyncio.to_thread(
874
+ csv_orchestrator_chat, csv_url, query, conversation_history, chat_id
875
  )
876
 
877
  if orchestrator_answer is not None:
 
885
  if isinstance(groq_result, str) and groq_result != "Chart not generated":
886
  unique_file_name =f'{str(uuid.uuid4())}.png'
887
  logger.info("Uploading the chart to supabase...")
888
+ image_public_url = await upload_file_to_supabase(f"{groq_result}", unique_file_name, chat_id=chat_id)
889
  logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
890
  os.remove(groq_result)
891
  return {"image_url": image_public_url}
 
900
  if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
901
  unique_file_name =f'{str(uuid.uuid4())}.png'
902
  logger.info("Uploading the chart to supabase...")
903
+ image_public_url = await upload_file_to_supabase(f"{langchain_paths[0]}", unique_file_name, chat_id=chat_id)
904
  logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
905
  os.remove(langchain_paths[0])
906
  return {"image_url": image_public_url}
gemini_report_generator.py CHANGED
@@ -288,7 +288,7 @@ def gemini_llm_chat(csv_url: str, query: str) -> Dict[str, Any]:
288
  }
289
 
290
 
291
- async def generate_csv_report(csv_url: str, query: str) -> FileBoxProps:
292
  try:
293
  result = gemini_llm_chat(csv_url, query)
294
  logger.info(f"Raw result from gemini_llm_chat: {result}")
@@ -306,7 +306,8 @@ async def generate_csv_report(csv_url: str, query: str) -> FileBoxProps:
306
  unique_file_name = f"{uuid.uuid4()}_{file_name}"
307
  public_url = await upload_file_to_supabase(
308
  file_path=csv_path,
309
- file_name=unique_file_name
 
310
  )
311
  csv_files.append(FileProps(
312
  fileName=file_name,
@@ -326,7 +327,8 @@ async def generate_csv_report(csv_url: str, query: str) -> FileBoxProps:
326
  unique_file_name = f"{uuid.uuid4()}_{file_name}"
327
  public_url = await upload_file_to_supabase(
328
  file_path=img_path,
329
- file_name=unique_file_name
 
330
  )
331
  image_files.append(FileProps(
332
  fileName=file_name,
 
288
  }
289
 
290
 
291
+ async def generate_csv_report(csv_url: str, query: str, chat_id: str) -> FileBoxProps:
292
  try:
293
  result = gemini_llm_chat(csv_url, query)
294
  logger.info(f"Raw result from gemini_llm_chat: {result}")
 
306
  unique_file_name = f"{uuid.uuid4()}_{file_name}"
307
  public_url = await upload_file_to_supabase(
308
  file_path=csv_path,
309
+ file_name=unique_file_name,
310
+ chat_id=chat_id
311
  )
312
  csv_files.append(FileProps(
313
  fileName=file_name,
 
327
  unique_file_name = f"{uuid.uuid4()}_{file_name}"
328
  public_url = await upload_file_to_supabase(
329
  file_path=img_path,
330
+ file_name=unique_file_name,
331
+ chat_id=chat_id
332
  )
333
  image_files.append(FileProps(
334
  fileName=file_name,
orchestrator_agent.py CHANGED
@@ -56,7 +56,7 @@ async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
56
  answers.append(dict(question=question, answer=answer))
57
  return answers
58
 
59
- async def generate_chart(csv_url: str, user_questions: List[str]) -> Any:
60
 
61
  """
62
  This function generates charts for the given user questions using the CSV URL.
@@ -84,13 +84,13 @@ async def generate_chart(csv_url: str, user_questions: List[str]) -> Any:
84
  charts = []
85
  # Loop through the user questions and generate charts for each
86
  for question in user_questions:
87
- chart = await csv_chart(csv_url, question)
88
  charts.append(dict(question=question, image_url=chart))
89
 
90
  return charts
91
 
92
  # Function to create an agent with a specific CSV URL
93
- def create_agent(csv_url: str, api_key: str, conversation_history: List) -> Agent:
94
  csv_metadata = get_csv_basic_info(csv_url)
95
 
96
  system_prompt = f"""
@@ -115,6 +115,7 @@ def create_agent(csv_url: str, api_key: str, conversation_history: List) -> Agen
115
  - **Dataset:** {csv_url}
116
  - **Metadata:** {csv_metadata}
117
  - **History:** {conversation_history}
 
118
 
119
  ## Required Output:
120
  For every question return:
@@ -133,7 +134,7 @@ For every question return:
133
  system_prompt=system_prompt
134
  )
135
 
136
- def csv_orchestrator_chat(csv_url: str, user_question: str, conversation_history: List) -> str:
137
  print("CSV URL:", csv_url)
138
  print("User questions:", user_question)
139
 
@@ -141,7 +142,7 @@ def csv_orchestrator_chat(csv_url: str, user_question: str, conversation_history
141
  for api_key in GEMINI_API_KEYS:
142
  try:
143
  print(f"Attempting with API key: {api_key}")
144
- agent = create_agent(csv_url, api_key, conversation_history)
145
  result = agent.run_sync(user_question)
146
  print("Orchestrator Result:", result.data)
147
  return result.data
 
56
  answers.append(dict(question=question, answer=answer))
57
  return answers
58
 
59
+ async def generate_chart(csv_url: str, user_questions: List[str], chat_id: str) -> Any:
60
 
61
  """
62
  This function generates charts for the given user questions using the CSV URL.
 
84
  charts = []
85
  # Loop through the user questions and generate charts for each
86
  for question in user_questions:
87
+ chart = await csv_chart(csv_url, question, chat_id)
88
  charts.append(dict(question=question, image_url=chart))
89
 
90
  return charts
91
 
92
  # Function to create an agent with a specific CSV URL
93
+ def create_agent(csv_url: str, api_key: str, conversation_history: List, chat_id: str) -> Agent:
94
  csv_metadata = get_csv_basic_info(csv_url)
95
 
96
  system_prompt = f"""
 
115
  - **Dataset:** {csv_url}
116
  - **Metadata:** {csv_metadata}
117
  - **History:** {conversation_history}
118
+ - **Chat ID:** {chat_id}
119
 
120
  ## Required Output:
121
  For every question return:
 
134
  system_prompt=system_prompt
135
  )
136
 
137
+ def csv_orchestrator_chat(csv_url: str, user_question: str, conversation_history: List, chat_id: str) -> str:
138
  print("CSV URL:", csv_url)
139
  print("User questions:", user_question)
140
 
 
142
  for api_key in GEMINI_API_KEYS:
143
  try:
144
  print(f"Attempting with API key: {api_key}")
145
+ agent = create_agent(csv_url, api_key, conversation_history, chat_id)
146
  result = agent.run_sync(user_question)
147
  print("Orchestrator Result:", result.data)
148
  return result.data
orchestrator_functions.py CHANGED
@@ -569,7 +569,7 @@ async def csv_chat(csv_url: str, query: str):
569
 
570
 
571
 
572
- async def csv_chart(csv_url: str, query: str):
573
  """
574
  Generate a chart based on the provided CSV URL and query.
575
  Prioritizes OpenAI, then raw Groq, then LangChain Gemini, and finally LangChain Groq as fallback.
@@ -590,10 +590,10 @@ async def csv_chart(csv_url: str, query: str):
590
  - dict: {"image_url": "https://storage.example.com/chart_uuid.png"}
591
  """
592
 
593
- async def upload_and_return(image_path: str) -> dict:
594
  """Helper function to handle image uploads"""
595
  unique_name = f'{uuid.uuid4()}.png'
596
- public_url = await upload_file_to_supabase(image_path, unique_name)
597
  logger.info(f"Uploaded chart: {public_url}")
598
  os.remove(image_path) # Remove the local image file after upload
599
  return {"image_url": public_url}
@@ -605,7 +605,7 @@ async def csv_chart(csv_url: str, query: str):
605
  logger.info(f"OpenAI chart result:", openai_result)
606
 
607
  if openai_result and openai_result != 'Chart not generated':
608
- return await upload_and_return(openai_result)
609
 
610
  raise Exception("OpenAI failed to generate chart")
611
 
@@ -617,7 +617,7 @@ async def csv_chart(csv_url: str, query: str):
617
  logger.info(f"Raw Groq chart result:", groq_result)
618
 
619
  if groq_result and groq_result != 'Chart not generated':
620
- return await upload_and_return(groq_result)
621
 
622
  raise Exception("Raw Groq failed to generate chart")
623
 
@@ -634,11 +634,11 @@ async def csv_chart(csv_url: str, query: str):
634
  # --- i) If Gemini result is a string, return it ---
635
  if gemini_result and isinstance(gemini_result, str):
636
  clean_path = gemini_result.strip()
637
- return await upload_and_return(clean_path)
638
 
639
  # --- ii) If Gemini result is a list, return the first element ---
640
  if gemini_result and isinstance(gemini_result, list) and len(gemini_result) > 0:
641
- return await upload_and_return(gemini_result[0])
642
 
643
  raise Exception("LangChain Gemini returned empty result")
644
 
@@ -653,7 +653,7 @@ async def csv_chart(csv_url: str, query: str):
653
  logger.info("LangChain Groq chart result:", lc_groq_paths)
654
 
655
  if isinstance(lc_groq_paths, list) and lc_groq_paths:
656
- return await upload_and_return(lc_groq_paths[0])
657
 
658
  return {"error": "All chart generation methods failed"}
659
 
 
569
 
570
 
571
 
572
+ async def csv_chart(csv_url: str, query: str, chat_id: str):
573
  """
574
  Generate a chart based on the provided CSV URL and query.
575
  Prioritizes OpenAI, then raw Groq, then LangChain Gemini, and finally LangChain Groq as fallback.
 
590
  - dict: {"image_url": "https://storage.example.com/chart_uuid.png"}
591
  """
592
 
593
+ async def upload_and_return(image_path: str, chat_id: str) -> dict:
594
  """Helper function to handle image uploads"""
595
  unique_name = f'{uuid.uuid4()}.png'
596
+ public_url = await upload_file_to_supabase(image_path, unique_name, chat_id)
597
  logger.info(f"Uploaded chart: {public_url}")
598
  os.remove(image_path) # Remove the local image file after upload
599
  return {"image_url": public_url}
 
605
  logger.info(f"OpenAI chart result:", openai_result)
606
 
607
  if openai_result and openai_result != 'Chart not generated':
608
+ return await upload_and_return(openai_result, chat_id)
609
 
610
  raise Exception("OpenAI failed to generate chart")
611
 
 
617
  logger.info(f"Raw Groq chart result:", groq_result)
618
 
619
  if groq_result and groq_result != 'Chart not generated':
620
+ return await upload_and_return(groq_result, chat_id)
621
 
622
  raise Exception("Raw Groq failed to generate chart")
623
 
 
634
  # --- i) If Gemini result is a string, return it ---
635
  if gemini_result and isinstance(gemini_result, str):
636
  clean_path = gemini_result.strip()
637
+ return await upload_and_return(clean_path, chat_id)
638
 
639
  # --- ii) If Gemini result is a list, return the first element ---
640
  if gemini_result and isinstance(gemini_result, list) and len(gemini_result) > 0:
641
+ return await upload_and_return(gemini_result[0], chat_id)
642
 
643
  raise Exception("LangChain Gemini returned empty result")
644
 
 
653
  logger.info("LangChain Groq chart result:", lc_groq_paths)
654
 
655
  if isinstance(lc_groq_paths, list) and lc_groq_paths:
656
+ return await upload_and_return(lc_groq_paths[0], chat_id)
657
 
658
  return {"error": "All chart generation methods failed"}
659
 
supabase_service.py CHANGED
@@ -15,10 +15,10 @@ supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
15
  # Define the bucket name (you can create one in the Supabase Storage section)
16
  BUCKET_NAME = "csvcharts"
17
 
18
- async def upload_file_to_supabase(file_path: str, file_name: str) -> str:
19
  """
20
  Uploads an image to Supabase Storage and returns the public URL.
21
- Also saves the mapping between public_url and chart_name in the database.
22
  """
23
  if not os.path.exists(file_path):
24
  raise FileNotFoundError(f"The file {file_path} does not exist.")
@@ -35,13 +35,16 @@ async def upload_file_to_supabase(file_path: str, file_name: str) -> str:
35
  public_url = supabase.storage.from_(BUCKET_NAME).get_public_url(file_name)
36
  print("Public URL:", public_url)
37
 
38
- # Save the mapping to the database
39
  try:
40
  supabase.table("chart_mappings").insert({
41
  "public_url": public_url,
42
- "chart_name": file_name
 
43
  }).execute()
44
  except Exception as e:
 
 
45
  raise Exception(f"Failed to save mapping to database: {e}")
46
 
47
  return public_url
 
15
  # Define the bucket name (you can create one in the Supabase Storage section)
16
  BUCKET_NAME = "csvcharts"
17
 
18
+ async def upload_file_to_supabase(file_path: str, file_name: str, chat_id: str) -> str:
19
  """
20
  Uploads an image to Supabase Storage and returns the public URL.
21
+ Saves the mapping between public_url, chart_name, and chat_id in the database.
22
  """
23
  if not os.path.exists(file_path):
24
  raise FileNotFoundError(f"The file {file_path} does not exist.")
 
35
  public_url = supabase.storage.from_(BUCKET_NAME).get_public_url(file_name)
36
  print("Public URL:", public_url)
37
 
38
+ # Save the mapping to the database including chat_id
39
  try:
40
  supabase.table("chart_mappings").insert({
41
  "public_url": public_url,
42
+ "chart_name": file_name,
43
+ "chat_id": chat_id
44
  }).execute()
45
  except Exception as e:
46
+ # Try to delete the uploaded file if DB insertion fails
47
+ supabase.storage.from_(BUCKET_NAME).remove([file_name])
48
  raise Exception(f"Failed to save mapping to database: {e}")
49
 
50
  return public_url