# Import necessary modules
import asyncio
import logging
import os
import threading
import uuid

from dotenv import load_dotenv
from fastapi.encoders import jsonable_encoder
from langchain_experimental.agents import create_pandas_dataframe_agent
from langchain_experimental.tools import PythonAstREPLTool
from langchain_groq.chat_models import ChatGroq
from langchain_groq import ChatGroq
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import pandas as pd
from pandasai import SmartDataframe
from pydantic import BaseModel
import seaborn as sns

from csv_service import clean_data, extract_chart_filenames
from gemini_langchain_agent import langchain_gemini_csv_handler
from openai_pandasai_service import openai_chart
from supabase_service import upload_file_to_supabase
from util_service import _prompt_generator, process_answer

# Select the non-interactive backend so charts can be rendered without a display.
matplotlib.use('Agg')

load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

image_file_path = os.getenv("IMAGE_FILE_PATH")
image_not_found = os.getenv("IMAGE_NOT_FOUND")
allowed_hosts = os.getenv("ALLOWED_HOSTS", "").split(",")

# Load environment variables.
# GROQ_API_KEYS is a comma-separated list; surrounding whitespace and empty
# entries are dropped.  A missing variable yields an empty list (instead of an
# AttributeError at import time), which the key-rotation loops treat as
# "all keys exhausted".
groq_api_keys = [
    key.strip()
    for key in os.getenv("GROQ_API_KEYS", "").split(",")
    if key.strip()
]
model_name = os.getenv("GROQ_LLM_MODEL")

if not groq_api_keys:
    logger.warning("GROQ_API_KEYS is not set or empty; Groq-backed handlers will be skipped.")


class CsvUrlRequest(BaseModel):
    """Request body carrying a single CSV URL."""

    csv_url: str


class ImageRequest(BaseModel):
    """Request body carrying the path of an image."""

    image_path: str


class CsvCommonHeadersRequest(BaseModel):
    """Request body carrying a list of file URLs."""

    file_urls: list[str]


class CsvsMergeRequest(BaseModel):
    """Request body for merging several files on the given columns."""

    file_urls: list[str]
    merge_type: str
    common_columns_name: list[str]


# Thread-safe key management for groq_chat
current_groq_key_index = 0
current_groq_key_lock = threading.Lock()

# Thread-safe key management for langchain_csv_chat
current_langchain_key_index = 0
current_langchain_key_lock = threading.Lock()


# CHAT CODING STARTS FROM HERE
def handle_out_of_range_float(value):
    """Replace float values that JSON cannot represent.

    Parameters:
    - value: Any value; only Python floats and NumPy floating scalars are
      inspected, everything else is returned unchanged.

    Returns:
    - None for NaN, "Infinity" for +inf, "-Infinity" for -inf, otherwise the
      value itself.
    """
    # np.float32 and friends are not subclasses of float, so check np.floating too.
    if isinstance(value, (float, np.floating)):
        if np.isnan(value):
            return None
        if np.isinf(value):
            # Keep the sign so negative infinity is distinguishable.
            return "Infinity" if value > 0 else "-Infinity"
    return value
# Modified groq_chat function with thread-safe key rotation
def _sanitize_groq_answer(answer):
    """Convert a PandasAI answer into plain Python data with JSON-safe floats.

    Every scalar is passed through handle_out_of_range_float individually.
    DataFrames become a list of row dicts, Series become a dict, lists and
    dicts are sanitised element by element, and anything else is wrapped as
    {"answer": str(value)}.
    """
    if isinstance(answer, pd.DataFrame):
        # DataFrame.apply would hand whole columns to the helper, so sanitise
        # cell by cell after converting to records.
        return [
            {column: handle_out_of_range_float(cell) for column, cell in row.items()}
            for row in answer.to_dict(orient="records")
        ]
    if isinstance(answer, pd.Series):
        return {
            label: handle_out_of_range_float(cell)
            for label, cell in answer.to_dict().items()
        }
    if isinstance(answer, list):
        return [handle_out_of_range_float(item) for item in answer]
    if isinstance(answer, dict):
        return {k: handle_out_of_range_float(v) for k, v in answer.items()}
    return {"answer": str(handle_out_of_range_float(answer))}


def groq_chat(csv_url: str, question: str):
    """Answer a question about a CSV using PandasAI's SmartDataframe with a Groq LLM.

    Each configured Groq API key is tried at most once per call.  When an
    error message contains "429" the shared key index is advanced (wrapping
    around) and the next key is tried; any other error aborts immediately.

    Parameters:
    - csv_url (str): Location of the CSV, passed to clean_data.
    - question (str): The question passed to SmartDataframe.chat.

    Returns:
    - The sanitised answer (list, dict, or {"answer": "<text>"}), or None when
      a non-429 error occurs or every key has been tried without success.
    """
    global current_groq_key_index

    key_count = len(groq_api_keys)
    for _ in range(key_count):
        with current_groq_key_lock:
            # Modulo keeps the index valid even if it was left out of range.
            key_index = current_groq_key_index % key_count
            current_api_key = groq_api_keys[key_index]

        try:
            # Delete cache file if exists
            cache_db_path = "/workspace/cache/cache_db_0.11.db"
            if os.path.exists(cache_db_path):
                try:
                    os.remove(cache_db_path)
                except Exception as e:
                    logger.info(f"Error deleting cache DB file: {e}")

            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            # Generate unique filename using UUID
            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            # Configure SmartDataframe with chart settings
            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,  # Enable chart saving
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),  # Directory to save
                    'custom_chart_filename': chart_filename  # Unique filename
                }
            )

            answer = df.chat(question)
            return _sanitize_groq_answer(answer)

        except Exception as e:
            error_message = str(e)
            if "429" not in error_message:
                logger.info(f"Error with API key index {key_index}: {error_message}")
                return None

            with current_groq_key_lock:
                # Advance only if no other thread already moved past the
                # failing key, so concurrent 429s do not skip keys.
                if current_groq_key_index % key_count == key_index:
                    current_groq_key_index = (key_index + 1) % key_count

    logger.info("All API keys exhausted.")
    return None
# Modified langchain_csv_chat with thread-safe key rotation
def _build_groq_pandas_agent(api_key, data, extra_locals=None):
    """Create a LangChain pandas-dataframe agent backed by a Groq chat model.

    Parameters:
    - api_key: Groq API key used for this agent.
    - data: The dataframe object handed to the agent and exposed as `df`.
    - extra_locals (dict | None): Additional names exposed to the Python tool.

    Returns:
    - The agent returned by create_pandas_dataframe_agent.
    """
    llm = ChatGroq(model=model_name, api_key=api_key)

    tool_locals = {
        "df": data,
        "pd": pd,
        "np": np,
        "plt": plt,
        "sns": sns,
        "matplotlib": matplotlib,
    }
    if extra_locals:
        tool_locals.update(extra_locals)
    tool = PythonAstREPLTool(locals=tool_locals)

    # SECURITY NOTE: allow_dangerous_code=True lets the agent execute
    # model-generated Python in this process.
    return create_pandas_dataframe_agent(
        llm,
        data,
        agent_type="tool-calling",
        verbose=True,
        allow_dangerous_code=True,
        extra_tools=[tool],
        return_intermediate_steps=True
    )


def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
    """Answer a question about a CSV with a LangChain pandas agent on Groq.

    Keys are used round-robin: every attempt takes the key at the shared
    index and advances it.  Each configured key is tried at most once per
    call; any exception moves on to the next key.

    Parameters:
    - csv_url (str): Location of the CSV, passed to clean_data.
    - question (str): The user question.
    - chart_required (bool): Forwarded to _prompt_generator.

    Returns:
    - The agent's "output" value, or None when every key failed.

    Exceptions raised by clean_data propagate to the caller.
    """
    global current_langchain_key_index

    data = clean_data(csv_url)
    key_count = len(groq_api_keys)

    for _ in range(key_count):
        with current_langchain_key_lock:
            current_key = current_langchain_key_index % key_count
            api_key = groq_api_keys[current_key]
            current_langchain_key_index = current_key + 1

        try:
            agent = _build_groq_pandas_agent(api_key, data)
            prompt = _prompt_generator(question, chart_required, csv_url)
            result = agent.invoke({"input": prompt})
            return result.get("output")
        except Exception as e:
            logger.info(f"Error with key index {current_key}: {str(e)}")

    # If all keys are exhausted, return None
    logger.info("All API keys have been exhausted.")
    return None


# CHART CODING STARTS FROM HERE
instructions = """
- Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).
- For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)
- Use colorblind-friendly palette
- Read above instructions and follow them.
"""

# Thread-safe configuration for chart endpoints
current_groq_chart_key_index = 0
current_groq_chart_lock = threading.Lock()

current_langchain_chart_key_index = 0
current_langchain_chart_lock = threading.Lock()


def model():
    """Return a ChatGroq client for the key at the current chart key index.

    Raises:
    - Exception: when the chart key index is past the end of the key list.
    """
    with current_groq_chart_lock:
        if current_groq_chart_key_index >= len(groq_api_keys):
            raise Exception("All API keys exhausted for chart generation")
        api_key = groq_api_keys[current_groq_chart_key_index]
    return ChatGroq(model=model_name, api_key=api_key)


def _clear_pandasai_cache():
    """Best-effort removal of the cache DB file before a chart request."""
    cache_db_path = "/workspace/cache/cache_db_0.11.db"
    if os.path.exists(cache_db_path):
        try:
            os.remove(cache_db_path)
        except Exception as e:
            logger.info(f"Cache cleanup error: {e}")


def groq_chart(csv_url: str, question: str):
    """Generate a chart for a CSV using PandasAI's SmartDataframe with a Groq LLM.

    Each configured key is tried at most once per call.  An error whose
    message contains "429" rotates to the next key; any other error aborts.

    Parameters:
    - csv_url (str): Location of the CSV, passed to clean_data.
    - question (str): The chart request; the module-level `instructions`
      text is appended to it.

    Returns:
    - The value returned by SmartDataframe.chat when process_answer reports
      it as usable (presumably the saved chart path -- confirm against
      PandasAI),
    - "Chart not generated" when process_answer flags the answer,
    - {"error": "<message>"} on a non-429 error,
    - None when every key was rate limited.
    """
    global current_groq_chart_key_index

    key_count = len(groq_api_keys)
    for _ in range(key_count):
        with current_groq_chart_lock:
            key_index = current_groq_chart_key_index % key_count
            current_api_key = groq_api_keys[key_index]

        try:
            # Clean cache before processing
            _clear_pandasai_cache()

            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            # Generate unique filename using UUID
            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            # Configure SmartDataframe with chart settings
            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,  # Enable chart saving
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),  # Directory to save
                    'custom_chart_filename': chart_filename  # Unique filename
                }
            )

            answer = df.chat(question + instructions)

            if process_answer(answer):
                return "Chart not generated"
            return answer

        except Exception as e:
            error = str(e)
            if "429" not in error:
                logger.info(f"Chart generation error: {error}")
                return {"error": error}

            with current_groq_chart_lock:
                # Advance only if another thread has not already rotated past
                # the failing key, so concurrent 429s do not skip keys.
                if current_groq_chart_key_index % key_count == key_index:
                    current_groq_chart_key_index = (key_index + 1) % key_count

    logger.info("All API keys exhausted for chart generation")
    return None


def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
    """Generate chart files for a CSV with a LangChain pandas agent on Groq.

    Keys are used round-robin, one attempt per configured key.  An attempt
    succeeds when extract_chart_filenames finds at least one filename in the
    agent output.

    Parameters:
    - csv_url (str): Location of the CSV, passed to clean_data.
    - question (str): The chart request.
    - chart_required (bool): Accepted for signature compatibility; the prompt
      is always generated with chart_required=True.

    Returns:
    - The list returned by extract_chart_filenames, or None when no attempt
      produced a chart.

    Exceptions raised by clean_data propagate to the caller.
    """
    global current_langchain_chart_key_index

    data = clean_data(csv_url)
    key_count = len(groq_api_keys)

    for _ in range(key_count):
        # Select the key outside the try block so `current_key` is always
        # bound when the except clause logs it.
        with current_langchain_chart_lock:
            current_key = current_langchain_chart_key_index % key_count
            api_key = groq_api_keys[current_key]
            current_langchain_chart_key_index = (current_key + 1) % key_count

        try:
            agent = _build_groq_pandas_agent(api_key, data, {"uuid": uuid})
            result = agent.invoke({"input": _prompt_generator(question, True, csv_url)})
            output = result.get("output", "")

            # Verify chart file creation
            chart_files = extract_chart_filenames(output)
            if len(chart_files) > 0:
                return chart_files

            logger.info(f"Langchain chart error (key {current_key}): {output}")

        except Exception as e:
            logger.info(f"Langchain chart error (key {current_key}): {str(e)}")

    logger.info("All API keys exhausted for chart generation")
    return None


# ---------------------------------------------------------------------------
# Public async entry points
# ---------------------------------------------------------------------------
async def csv_chat(csv_url: str, query: str):
    """
    Generate a response based on the provided CSV URL and query.

    Prioritizes LangChain-Groq, then raw Groq, and finally LangChain-Gemini
    as fallback.  The blocking handlers run in worker threads via
    asyncio.to_thread.

    Parameters:
    - csv_url (str): The URL of the CSV file.
    - query (str): The query for generating the response.

    Returns:
    - dict: {"answer": <encoded answer>} from the first handler that yields a
      usable result, {"answer": "Sorry, I couldn't find relevant data..."}
      when all handlers fail, or {"answer": "error"} on an unexpected error.

    Example:
    - csv_url: "https://example.com/data.csv"
    - query: "What is the total sales for the year 2022?"
    Returns:
    - dict: {"answer": "The total sales for 2022 is $100,000."}
    """
    try:
        updated_query = f"{query} and Do not show any charts or graphs."

        # --- 1. First Attempt: LangChain Groq ---
        try:
            lang_groq_answer = await asyncio.to_thread(
                langchain_csv_chat, csv_url, updated_query, False
            )
            logger.info("LangChain-Groq answer: %s", lang_groq_answer)

            if lang_groq_answer is not None:
                return {"answer": jsonable_encoder(lang_groq_answer)}
            logger.info("LangChain-Groq error: LangChain-Groq response not usable, falling back to raw Groq")
        except Exception as lang_groq_error:
            logger.info(f"LangChain-Groq error: {str(lang_groq_error)}")

        # --- 2. Second Attempt: Raw Groq Chat ---
        try:
            raw_groq_answer = await asyncio.to_thread(groq_chat, csv_url, updated_query)
            logger.info("Raw Groq answer: %s", raw_groq_answer)

            # Check for None first so process_answer is never handed None; a
            # truthy process_answer result marks the answer as unusable.
            if raw_groq_answer is not None and not process_answer(raw_groq_answer):
                return {"answer": jsonable_encoder(raw_groq_answer)}
            logger.info("Raw Groq error: Raw Groq response not usable, falling back to LangChain-Gemini")
        except Exception as raw_groq_error:
            logger.info(f"Raw Groq error: {str(raw_groq_error)}")

        # --- 3. Final Attempt: LangChain Gemini ---
        try:
            gemini_answer = await asyncio.to_thread(
                langchain_gemini_csv_handler, csv_url, updated_query, False
            )
            logger.info("LangChain-Gemini answer: %s", gemini_answer)

            if gemini_answer is not None:
                return {"answer": jsonable_encoder(gemini_answer)}
            logger.info("LangChain-Gemini error: All fallbacks exhausted")
        except Exception as gemini_error:
            logger.info(f"LangChain-Gemini error: {str(gemini_error)}")

        return {"answer": "Sorry, I couldn't find relevant data..."}

    except Exception as e:
        logger.info(f"Unexpected error: {str(e)}")
        return {"answer": "error"}


async def _upload_chart_and_cleanup(image_path: str, chat_id: str) -> dict:
    """Upload a local chart image and delete the local file afterwards.

    The local file is removed whether or not the upload succeeds, and a
    failure to remove it is only logged, so a successful upload is never
    turned into an error by cleanup.

    Returns:
    - dict: {"image_url": <value returned by upload_file_to_supabase>}.

    Raises:
    - Whatever upload_file_to_supabase raises.
    """
    unique_name = f'{uuid.uuid4()}.png'
    try:
        public_url = await upload_file_to_supabase(image_path, unique_name, chat_id)
    finally:
        try:
            os.remove(image_path)  # Remove the local image file after upload
        except OSError as cleanup_error:
            logger.info(f"Could not remove local chart file {image_path}: {cleanup_error}")
    logger.info(f"Uploaded chart: {public_url}")
    return {"image_url": public_url}


async def csv_chart(csv_url: str, query: str, chat_id: str):
    """
    Generate a chart based on the provided CSV URL and query.

    Prioritizes PandasAI Groq, then LangChain Gemini, and finally LangChain
    Groq as fallback.  The first handler that yields a chart path has its
    image uploaded and the local file removed.

    Parameters:
    - csv_url (str): The URL of the CSV file.
    - query (str): The query for generating the chart.
    - chat_id (str): Forwarded to upload_file_to_supabase together with the
      image path and a generated file name.

    Returns:
    - dict: A dictionary containing either:
        - {"image_url": "https://example.com/chart.png"} on success, or
        - {"error": "error message"} on failure

    Example:
    - csv_url: "https://example.com/data.csv"
    - query: "Show sales trends as a line chart"
    Returns:
    - dict: {"image_url": "https://storage.example.com/chart_uuid.png"}
    """
    try:
        # Commented out for now because aiml api is not working
        # try:
        #     # --- 1. First Attempt: OpenAI ---
        #     openai_result = await asyncio.to_thread(openai_chart, csv_url, query)
        #     logger.info("OpenAI chart result: %s", openai_result)
        #     if openai_result and openai_result != 'Chart not generated':
        #         return await _upload_chart_and_cleanup(openai_result, chat_id)
        #     raise Exception("OpenAI failed to generate chart")
        # except Exception as openai_error:
        #     logger.info(f"OpenAI failed ({str(openai_error)}), trying raw Groq...")

        # --- 2. Second Attempt: Raw Groq ---
        try:
            groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
            logger.info("Raw Groq chart result: %s", groq_result)

            # groq_chart can also return None, an {"error": ...} dict or the
            # "Chart not generated" marker; only a non-empty string is
            # treated as a chart path.
            if (
                isinstance(groq_result, str)
                and groq_result.strip()
                and groq_result != 'Chart not generated'
            ):
                return await _upload_chart_and_cleanup(groq_result, chat_id)

            raise Exception("Raw Groq failed to generate chart")

        except Exception as groq_error:
            logger.info(f"Raw Groq failed ({str(groq_error)}), trying LangChain Gemini...")

        # --- 3. Third Attempt: LangChain Gemini ---
        try:
            gemini_result = await asyncio.to_thread(
                langchain_gemini_csv_handler, csv_url, query, True
            )
            logger.info("LangChain Gemini chart result: %s", gemini_result)

            # --- i) If Gemini result is a string, upload it ---
            if isinstance(gemini_result, str) and gemini_result.strip():
                clean_path = gemini_result.strip()
                return await _upload_chart_and_cleanup(clean_path, chat_id)

            # --- ii) If Gemini result is a list, upload the first element ---
            if isinstance(gemini_result, list) and len(gemini_result) > 0:
                return await _upload_chart_and_cleanup(gemini_result[0], chat_id)

            raise Exception("LangChain Gemini returned empty result")

        except Exception as gemini_error:
            logger.info(f"LangChain Gemini failed ({str(gemini_error)}), trying LangChain Groq...")

        # --- 4. Final Attempt: LangChain Groq ---
        try:
            lc_groq_paths = await asyncio.to_thread(
                langchain_csv_chart, csv_url, query, True
            )
            logger.info("LangChain Groq chart result: %s", lc_groq_paths)

            if isinstance(lc_groq_paths, list) and lc_groq_paths:
                return await _upload_chart_and_cleanup(lc_groq_paths[0], chat_id)

            return {"error": "All chart generation methods failed"}

        except Exception as lc_groq_error:
            logger.info(f"LangChain Groq failed: {str(lc_groq_error)}")
            return {"error": "Could not generate chart"}

    except Exception as e:
        logger.info(f"Critical error: {str(e)}")
        return {"error": "Internal system error"}