# Import necessary modules
import asyncio
import hmac
import os
import threading
import uuid
from concurrent.futures import ProcessPoolExecutor
from typing import Dict
from urllib.parse import unquote

import matplotlib

# Select the non-interactive backend BEFORE pyplot (or any library that imports
# pyplot, e.g. seaborn / pandasai) is loaded: the server has no display.
matplotlib.use("Agg")

import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import seaborn as sns  # noqa: E402
from dotenv import load_dotenv  # noqa: E402
from fastapi import FastAPI, Header, HTTPException  # noqa: E402
from fastapi.encoders import jsonable_encoder  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from fastapi.responses import FileResponse  # noqa: E402
from langchain_experimental.agents import create_pandas_dataframe_agent  # noqa: E402
from langchain_experimental.tools import PythonAstREPLTool  # noqa: E402
from langchain_groq import ChatGroq  # noqa: E402
from pandasai import SmartDataframe  # noqa: E402
from pydantic import BaseModel  # noqa: E402

import csv_service  # noqa: E402
from csv_service import clean_data, extract_chart_filenames  # noqa: E402
from intitial_q_handler import if_initial_chart_question, if_initial_chat_question  # noqa: E402
from util_service import _prompt_generator, process_answer  # noqa: E402

# Initialize FastAPI app
app = FastAPI()

# Ensure the cache directory, the pandasai log file and the chart directory exist.
os.makedirs("/app/cache", exist_ok=True)  # also creates /app itself
with open("/app/pandasai.log", "a"):  # create the file if it doesn't exist
    pass
os.makedirs("/app/generated_charts", exist_ok=True)

load_dotenv()

image_file_path = os.getenv("IMAGE_FILE_PATH")
image_not_found = os.getenv("IMAGE_NOT_FOUND")

# Comma-separated list of CORS origins; blank entries (e.g. from an unset
# variable or a trailing comma) are dropped.
allowed_hosts = [
    host.strip() for host in os.getenv("ALLOWED_HOSTS", "").split(",") if host.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_hosts,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load environment variables.
# GROQ_API_KEYS is a comma-separated list of keys used for rate-limit rotation.
groq_api_keys = [
    key.strip() for key in os.getenv("GROQ_API_KEYS", "").split(",") if key.strip()
]
if not groq_api_keys:
    # Fail fast with a clear message instead of an AttributeError on None.split
    # (unset) or silently trying an empty key (set but blank).
    raise RuntimeError("GROQ_API_KEYS must contain at least one comma-separated API key")
model_name = os.getenv("GROQ_LLM_MODEL")


class CsvUrlRequest(BaseModel):
    csv_url: str


class ImageRequest(BaseModel):
    image_path: str


# Thread-safe key management for groq_chat
current_groq_key_index = 0
current_groq_key_lock = threading.Lock()

# Thread-safe key management for langchain_csv_chat
current_langchain_key_index = 0
current_langchain_key_lock = threading.Lock()


# PING CHECK
@app.get("/ping")
async def root():
    """Liveness probe."""
    return {"message": "Pong !!"}


# BASIC KNOWLEDGE BASED ON CSV
# Remove trailing slash from the URL otherwise it will redirect to GET method
@app.post("/api/basic_csv_data")
async def basic_csv_data(request: CsvUrlRequest):
    """Return ``csv_service.get_csv_basic_info`` for the (URL-decoded) CSV URL.

    Raises:
        HTTPException(400): if fetching/parsing the CSV fails for any reason.
    """
    try:
        decoded_url = unquote(request.csv_url)
        print(f"Fetching CSV data from URL: {decoded_url}")
        csv_data = csv_service.get_csv_basic_info(decoded_url)
        print(f"CSV data fetched successfully: {csv_data}")
        return {"data": csv_data}
    except Exception as e:
        print(f"Error while fetching CSV data: {e}")
        raise HTTPException(status_code=400, detail=f"Failed to retrieve CSV data: {str(e)}")


# GET THE CHART FROM A SPECIFIC FILE PATH
@app.post("/api/get-chart")
async def get_image(request: ImageRequest, authorization: str = Header(None)):
    """Serve the PNG at ``request.image_path`` to a caller holding ``AUTH_TOKEN``.

    Returns:
        A ``FileResponse`` for the file, or ``{"answer": "error"}`` if the path
        is not an existing regular file.

    Raises:
        HTTPException(401): missing/malformed ``Authorization`` header or empty token.
        HTTPException(403): token does not match ``AUTH_TOKEN``.
    """
    if not authorization:
        raise HTTPException(status_code=401, detail="Authorization header missing")
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid authorization header format")
    token = authorization.split(" ")[1]
    if not token:
        raise HTTPException(status_code=401, detail="Token missing")
    expected_token = os.getenv("AUTH_TOKEN") or ""
    # Constant-time comparison so the token cannot be recovered via response timing.
    if not expected_token or not hmac.compare_digest(token.encode(), expected_token.encode()):
        raise HTTPException(status_code=403, detail="Invalid token")

    # NOTE(security): the path comes straight from the request body, so any
    # authenticated caller can read any file the process can read. Consider
    # restricting it to the chart output directories.
    chart_path = request.image_path
    # FileResponse only touches the file when the response is sent, so a
    # try/except around its construction cannot catch a missing file; check first.
    if not os.path.isfile(chart_path):
        print(f"Error: chart file not found: {chart_path}")
        return {"answer": "error"}
    return FileResponse(chart_path, media_type="image/png")


# GET CSV DATA FOR GENERATING THE TABLE
@app.post("/api/csv_data")
async def get_csv_data(request: CsvUrlRequest):
    """Return ``csv_service.generate_csv_data`` for the (URL-decoded) CSV URL.

    Raises:
        HTTPException(400): if fetching/parsing the CSV fails for any reason.
    """
    try:
        decoded_url = unquote(request.csv_url)
        return csv_service.generate_csv_data(decoded_url)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to retrieve CSV data: {str(e)}")


# CHAT CODING STARTS FROM HERE

def _sanitize_pandasai_answer(answer):
    """Convert a pandasai ``chat`` answer into JSON-safe Python data.

    Every scalar passes through ``handle_out_of_range_float`` (NaN -> None,
    infinities -> strings). DataFrame/Series are converted with ``to_dict``
    *first* and then sanitised cell by cell: ``DataFrame.apply`` would hand the
    helper whole columns, and ``Series.apply`` would coerce ``None`` back to NaN.
    """
    if isinstance(answer, pd.DataFrame):
        return [
            {column: handle_out_of_range_float(cell) for column, cell in row.items()}
            for row in answer.to_dict(orient="records")
        ]
    if isinstance(answer, pd.Series):
        return {key: handle_out_of_range_float(val) for key, val in answer.to_dict().items()}
    if isinstance(answer, list):
        return [handle_out_of_range_float(item) for item in answer]
    if isinstance(answer, dict):
        return {k: handle_out_of_range_float(v) for k, v in answer.items()}
    return {"answer": str(handle_out_of_range_float(answer))}


# groq_chat with thread-safe key rotation
def groq_chat(csv_url: str, question: str):
    """Answer *question* about the CSV at *csv_url* using pandasai + Groq.

    Each configured API key is tried at most once per call, starting at the
    shared ``current_groq_key_index``. A "429" in the error message advances
    the shared index (wrapping around) and retries with the next key; any other
    error is returned immediately.

    Returns:
        JSON-safe data (see ``_sanitize_pandasai_answer``), or
        ``{"error": message}`` on failure / when every key was rate-limited.
    """
    global current_groq_key_index

    if not groq_api_keys:
        return {"error": "All API keys exhausted."}

    # Delete the pandasai cache DB if it exists (presumably so cached answers
    # are not replayed).
    # NOTE(review): this path is shared by concurrent requests and differs from
    # the /app/cache directory created at startup; confirm it is still correct.
    cache_db_path = "/workspace/cache/cache_db_0.11.db"
    if os.path.exists(cache_db_path):
        try:
            os.remove(cache_db_path)
        except Exception as e:
            print(f"Error deleting cache DB file: {e}")

    # Load the CSV once; retries only need a different key, not a re-download.
    try:
        data = clean_data(csv_url)
    except Exception as e:
        return {"error": str(e)}

    for _ in range(len(groq_api_keys)):
        with current_groq_key_lock:
            current_api_key = groq_api_keys[current_groq_key_index % len(groq_api_keys)]
        try:
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            # Generate unique filename using UUID
            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            # Configure SmartDataframe with chart settings
            df = SmartDataframe(
                data,
                config={
                    "llm": llm,
                    "save_charts": True,  # Enable chart saving
                    "open_charts": False,
                    "save_charts_path": os.path.dirname(chart_path),  # Directory to save
                    "custom_chart_filename": chart_filename,  # Unique filename
                },
            )
            return _sanitize_pandasai_answer(df.chat(question))
        except Exception as e:
            error_message = str(e)
            if "429" not in error_message:
                return {"error": error_message}
            # Rate-limited: rotate the shared index so the next attempt (and
            # later requests) use the next key. Wrapping keeps rotation usable
            # after transient rate limits instead of exhausting keys forever.
            with current_groq_key_lock:
                current_groq_key_index = (current_groq_key_index + 1) % len(groq_api_keys)

    return {"error": "All API keys exhausted."}
# langchain_csv_chat with thread-safe key rotation
def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
    """Answer *question* about the CSV at *csv_url* with a LangChain pandas agent.

    Keys are used round-robin through the shared ``current_langchain_key_index``:
    every attempt advances the index (successful or not), and each configured
    key is tried at most once per call. Any exception moves on to the next key.

    Args:
        csv_url: Decoded CSV URL, loaded with ``clean_data``.
        question: The user's question.
        chart_required: Forwarded to ``_prompt_generator``.

    Returns:
        The agent result's ``"output"`` value, or
        ``{"error": "All API keys exhausted"}`` if every key failed.
    """
    global current_langchain_key_index

    data = clean_data(csv_url)

    for _ in range(len(groq_api_keys)):
        # Claim a key index and advance the shared cursor atomically.
        with current_langchain_key_lock:
            key_index = current_langchain_key_index % len(groq_api_keys)
            current_langchain_key_index = key_index + 1
        api_key = groq_api_keys[key_index]

        try:
            llm = ChatGroq(model=model_name, api_key=api_key)
            tool = PythonAstREPLTool(locals={
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
            })
            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="openai-tools",
                verbose=True,
                # The agent executes LLM-written Python against the dataframe.
                allow_dangerous_code=True,
                extra_tools=[tool],
                return_intermediate_steps=True,
            )
            prompt = _prompt_generator(question, chart_required)
            result = agent.invoke({"input": prompt})
            return result.get("output")
        except Exception as e:
            print(f"Error with key index {key_index}: {str(e)}")

    return {"error": "All API keys exhausted"}


# Async endpoint with non-blocking execution
@app.post("/api/csv-chat")
async def csv_chat(request: Dict, authorization: str = Header(None)):
    """Answer a free-text question about a CSV.

    Body keys: ``query`` (question) and ``csv_url`` (URL-encoded CSV URL).
    Initial/suggested questions (``if_initial_chat_question``) go straight to
    the LangChain agent. Everything else tries pandasai/Groq first and falls
    back to LangChain when ``process_answer`` flags the Groq answer.
    The blocking helpers run in worker threads via ``asyncio.to_thread``.

    Raises:
        HTTPException(401/403): bad or missing bearer token.
        HTTPException(400): ``query`` or ``csv_url`` missing/not a string.
    """
    import hmac  # local import keeps this block self-contained

    # Authorization checks
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid authorization")
    token = authorization.split(" ")[1]
    expected_token = os.getenv("AUTH_TOKEN") or ""
    # Constant-time comparison so the token cannot be recovered via response timing.
    if not expected_token or not hmac.compare_digest(token.encode(), expected_token.encode()):
        raise HTTPException(status_code=403, detail="Invalid token")

    query = request.get("query")
    csv_url = request.get("csv_url")
    # Validate before the catch-all below, so bad input is a 400 rather than a
    # TypeError from unquote(None) reported as a generic "error".
    if not isinstance(query, str) or not isinstance(csv_url, str) or not query or not csv_url:
        raise HTTPException(status_code=400, detail="Both 'query' and 'csv_url' are required")

    try:
        decoded_url = unquote(csv_url)

        if if_initial_chat_question(query):
            answer = await asyncio.to_thread(langchain_csv_chat, decoded_url, query, False)
            print("langchain_answer:", answer)
            return {"answer": jsonable_encoder(answer)}

        # Process with groq_chat first
        groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
        print("groq_answer:", groq_answer)

        groq_status = process_answer(groq_answer)  # evaluate once, branch twice
        if groq_status == "Empty response received.":
            return {"answer": "Sorry, I couldn't find relevant data..."}
        if groq_status:
            # Groq answer unusable: fall back to the LangChain agent.
            lang_answer = await asyncio.to_thread(langchain_csv_chat, decoded_url, query, False)
            if process_answer(lang_answer):
                return {"answer": "error"}
            return {"answer": jsonable_encoder(lang_answer)}

        return {"answer": jsonable_encoder(groq_answer)}
    except Exception as e:
        print(f"Error processing request: {str(e)}")
        return {"answer": "error"}


def handle_out_of_range_float(value):
    """Make a float JSON-compliant; return any other value unchanged.

    NaN -> ``None``, +inf -> ``"Infinity"``, -inf -> ``"-Infinity"``.
    Accepts Python floats and NumPy floating scalars (e.g. ``np.float32``,
    which is not a ``float`` subclass).
    """
    if isinstance(value, (float, np.floating)):
        if np.isnan(value):
            return None
        if np.isinf(value):
            return "Infinity" if value > 0 else "-Infinity"
    return value


# CHART CODING STARTS FROM HERE
# instructions = """
# - Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).
# - For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)
# - Use colorblind-friendly palette
# - Read above instructions and follow them.
# """ # # Thread-safe configuration for chart endpoints # current_groq_chart_key_index = 0 # current_groq_chart_lock = threading.Lock() # current_langchain_chart_key_index = 0 # current_langchain_chart_lock = threading.Lock() # def model(): # global current_groq_chart_key_index, current_groq_chart_lock # with current_groq_chart_lock: # if current_groq_chart_key_index >= len(groq_api_keys): # raise Exception("All API keys exhausted for chart generation") # api_key = groq_api_keys[current_groq_chart_key_index] # return ChatGroq(model=model_name, api_key=api_key) # def groq_chart(csv_url: str, question: str): # global current_groq_chart_key_index, current_groq_chart_lock # for attempt in range(len(groq_api_keys)): # try: # # Clean cache before processing # cache_db_path = "/workspace/cache/cache_db_0.11.db" # if os.path.exists(cache_db_path): # try: # os.remove(cache_db_path) # except Exception as e: # print(f"Cache cleanup error: {e}") # data = clean_data(csv_url) # with current_groq_chart_lock: # current_api_key = groq_api_keys[current_groq_chart_key_index] # llm = ChatGroq(model=model_name, api_key=current_api_key) # # Generate unique filename using UUID # chart_filename = f"chart_{uuid.uuid4()}.png" # chart_path = os.path.join("generated_charts", chart_filename) # # Configure SmartDataframe with chart settings # df = SmartDataframe( # data, # config={ # 'llm': llm, # 'save_charts': True, # Enable chart saving # 'open_charts': False, # 'save_charts_path': os.path.dirname(chart_path), # Directory to save # 'custom_chart_filename': chart_filename # Unique filename # } # ) # answer = df.chat(question + instructions) # if process_answer(answer): # return "Chart not generated" # return answer # except Exception as e: # error = str(e) # if "429" in error: # with current_groq_chart_lock: # current_groq_chart_key_index = (current_groq_chart_key_index + 1) % len(groq_api_keys) # else: # print(f"Chart generation error: {error}") # return {"error": error} # return {"error": 
"All API keys exhausted for chart generation"} # def langchain_csv_chart(csv_url: str, question: str, chart_required: bool): # global current_langchain_chart_key_index, current_langchain_chart_lock # data = clean_data(csv_url) # for attempt in range(len(groq_api_keys)): # try: # with current_langchain_chart_lock: # api_key = groq_api_keys[current_langchain_chart_key_index] # current_key = current_langchain_chart_key_index # current_langchain_chart_key_index = (current_langchain_chart_key_index + 1) % len(groq_api_keys) # llm = ChatGroq(model=model_name, api_key=api_key) # tool = PythonAstREPLTool(locals={ # "df": data, # "pd": pd, # "np": np, # "plt": plt, # "sns": sns, # "matplotlib": matplotlib, # "uuid": uuid # }) # agent = create_pandas_dataframe_agent( # llm, # data, # agent_type="openai-tools", # verbose=True, # allow_dangerous_code=True, # extra_tools=[tool], # return_intermediate_steps=True # ) # result = agent.invoke({"input": _prompt_generator(question, True)}) # output = result.get("output", "") # # Verify chart file creation # chart_files = extract_chart_filenames(output) # if len(chart_files) > 0: # return chart_files # if attempt < len(groq_api_keys) - 1: # print(f"Langchain chart error (key {current_key}): {output}") # except Exception as e: # print(f"Langchain chart error (key {current_key}): {str(e)}") # return "Chart generation failed after all retries" # @app.post("/api/csv-chart") # async def csv_chart(request: dict, authorization: str = Header(None)): # # Authorization verification # if not authorization or not authorization.startswith("Bearer "): # raise HTTPException(status_code=401, detail="Authorization required") # token = authorization.split(" ")[1] # if token != os.getenv("AUTH_TOKEN"): # raise HTTPException(status_code=403, detail="Invalid credentials") # try: # query = request.get("query", "") # csv_url = unquote(request.get("csv_url", "")) # # Parallel processing with thread pool # if if_initial_chart_question(query): # chart_paths = 
await asyncio.to_thread( # langchain_csv_chart, csv_url, query, True # ) # print(chart_paths) # if len(chart_paths) > 0: # return FileResponse(f"{image_file_path}/{chart_paths[0]}", media_type="image/png") # # Groq-based chart generation # groq_result = await asyncio.to_thread(groq_chart, csv_url, query) # print(f"Generated Chart: {groq_result}") # if groq_result != 'Chart not generated': # return FileResponse(groq_result, media_type="image/png") # # Fallback to Langchain # langchain_paths = await asyncio.to_thread( # langchain_csv_chart, csv_url, query, True # ) # print (langchain_paths) # if len(langchain_paths) > 0: # return FileResponse(f"{image_file_path}/{langchain_paths[0]}", media_type="image/png") # else: # return {"error": "All chart generation methods failed"} # except Exception as e: # print(f"Critical chart error: {str(e)}") # return {"error": "Internal system error"} # MERGED CALL # class CSVData(BaseModel): # csv_url: str # query: str # chart_required: bool # @app.post("/api/v1/csv_chat") # async def csv_chat(csv_data: CSVData, authorization: str = Header(None)): # # Authorization verification # if not authorization or not authorization.startswith("Bearer "): # raise HTTPException(status_code=401, detail="Authorization required") # token = authorization.split(" ")[1] # if token != os.getenv("AUTH_TOKEN"): # raise HTTPException(status_code=403, detail="Invalid credentials") # csv_url = csv_data.csv_url # query = csv_data.query # chart_required = csv_data.chart_required # if(chart_required == True): # try: # # Parallel processing with thread pool # if if_initial_chart_question(query): # chart_path = await asyncio.to_thread( # langchain_csv_chart, csv_url, query, True # ) # if "temp" in chart_path: # print("langchain chart Generated") # return FileResponse('temp.png', media_type="image/png") # return {"error": "Chart generation failed"} # # Groq-based chart generation # groq_result = await asyncio.to_thread(groq_chart, csv_url, query) # if groq_result 
# == "Chart Generated":
# return FileResponse("exports/charts/temp_chart.png")
# # Fallback to Langchain
# langchain_path = await asyncio.to_thread(
# langchain_csv_chart, csv_url, query, True
# )
# if "temp" in langchain_path:
# print("langchain chart Generated")
# return FileResponse('temp.png', media_type="image/png")
# return {"error": "All chart generation methods failed"}
# except Exception as e:
# print(f"Critical chart error: {str(e)}")
# raise HTTPException(status_code=500, detail="Internal server error")
# else:
# try:
# if if_initial_chat_question(query):
# answer = await asyncio.to_thread(
# langchain_csv_chat, csv_url, query, False
# )
# print("langchain_answer:", answer)
# return {"answer": jsonable_encoder(answer)}
# # Process with groq_chat first
# groq_answer = await asyncio.to_thread(groq_chat, csv_url, query)
# print("groq_answer:", groq_answer)
# if process_answer(groq_answer) == "Empty response received.":
# return {"answer": "Sorry, I couldn't find relevant data..."}
# if process_answer(groq_answer):
# lang_answer = await asyncio.to_thread(
# langchain_csv_chat, csv_url, query, False
# )
# if process_answer(lang_answer):
# return {"answer": "error"}
# return {"answer": jsonable_encoder(lang_answer)}
# return {"answer": jsonable_encoder(groq_answer)}
# except Exception as e:
# print(f"Error processing request: {str(e)}")
# raise HTTPException(status_code=500, detail="Internal server error")

# Global locks for key rotation (chart endpoints).
# NOTE: the chart helpers run inside ProcessPoolExecutor workers, so each worker
# process has its own copy of these globals. The locks only serialise threads
# within one process, and rotation state is not shared between workers or with
# the parent process.
current_groq_chart_key_index = 0
current_groq_chart_lock = threading.Lock()

current_langchain_chart_key_index = 0
current_langchain_chart_lock = threading.Lock()

max_cpus = os.cpu_count()
print("Available CPUs:", max_cpus)

# Use a process pool for chart generation (pyplot is not safe to drive from
# several threads at once). Size is configurable; defaults to the previous 4.
chart_max_workers = max(1, int(os.getenv("CHART_MAX_WORKERS", "4")))
process_executor = ProcessPoolExecutor(max_workers=chart_max_workers)

# Extra plotting guidance appended to every groq chart question.
_CHART_INSTRUCTIONS = """
- Ensure each value is clearly visible.
- Adjust font sizes, rotate labels if necessary.
- Use a colorblind-friendly palette.
- Arrange multiple charts in a grid if needed.
"""


def _find_saved_chart(expected_path, answer):
    """Return the path of a chart file that actually exists on disk, else None.

    Checks the path we asked pandasai to use first, then the answer itself when
    it is a string (pandasai may return the path it saved the plot to).
    """
    if os.path.isfile(expected_path):
        return expected_path
    if isinstance(answer, str) and os.path.isfile(answer):
        return answer
    return None


def _first_existing_file(paths):
    """Return the first entry of *paths* that is an existing file, else None.

    Non-list inputs (the chart helpers return a str/dict on failure) give None.
    """
    if not isinstance(paths, list):
        return None
    for path in paths:
        if isinstance(path, str) and os.path.isfile(path):
            return path
    return None


# --- GROQ-BASED CHART GENERATION ---
def groq_chart(csv_url: str, question: str):
    """Generate a chart for *question* about the CSV at *csv_url* via pandasai + Groq.

    Each key is tried at most once. A "429" error advances the shared chart key
    index (wrapping) and retries; any other error is returned immediately.
    Figures are closed after every attempt, including failed ones.

    Returns:
        The path of a chart file that exists on disk;
        ``"Chart not generated"`` if ``process_answer`` flags the answer or no
        chart file was written; or ``{"error": message}`` on failure.
    """
    global current_groq_chart_key_index

    if not groq_api_keys:
        return {"error": "All API keys exhausted for chart generation"}

    # Load the CSV once; retries only need a different key.
    try:
        data = clean_data(csv_url)
    except Exception as e:
        print(f"Groq chart generation error: {e}")
        return {"error": str(e)}

    for _ in range(len(groq_api_keys)):
        with current_groq_chart_lock:
            current_api_key = groq_api_keys[current_groq_chart_key_index % len(groq_api_keys)]

        # Unique filename so concurrent requests never overwrite each other.
        chart_filename = f"chart_{uuid.uuid4().hex}.png"
        chart_path = os.path.join("generated_charts", chart_filename)

        try:
            llm = ChatGroq(model=model_name, api_key=current_api_key)
            df = SmartDataframe(
                data,
                config={
                    "llm": llm,
                    "save_charts": True,
                    "open_charts": False,
                    "save_charts_path": os.path.dirname(chart_path),
                    # NOTE(review): confirm the installed pandasai honours this key;
                    # _find_saved_chart below does not rely on it.
                    "custom_chart_filename": chart_filename,
                },
            )
            answer = df.chat(question + _CHART_INSTRUCTIONS)

            # If process_answer indicates a problem, report failure to the caller.
            if process_answer(answer):
                return "Chart not generated"
            # Only hand back a path that really exists; FileResponse would
            # otherwise fail with a 500 when the response is sent.
            saved_path = _find_saved_chart(chart_path, answer)
            return saved_path if saved_path is not None else "Chart not generated"
        except Exception as e:
            error = str(e)
            if "429" not in error:
                print(f"Groq chart generation error: {error}")
                return {"error": error}
            with current_groq_chart_lock:
                current_groq_chart_key_index = (current_groq_chart_key_index + 1) % len(groq_api_keys)
        finally:
            # Always release pyplot figures, even when the attempt failed.
            plt.close("all")

    return {"error": "All API keys exhausted for chart generation"}


# --- LANGCHAIN-BASED CHART GENERATION ---
def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
    """Generate chart(s) for *question* about the CSV at *csv_url* via a LangChain agent.

    Keys are used round-robin (the shared index advances on every attempt).
    An attempt whose output names no chart files is retried with the next key.

    Args:
        csv_url: Decoded CSV URL, loaded with ``clean_data``.
        question: The user's question.
        chart_required: Forwarded to ``_prompt_generator``. All visible
            callers pass True.

    Returns:
        A list of chart paths joined onto ``IMAGE_FILE_PATH`` (bare filenames
        if it is unset), or ``"Chart generation failed after all retries"``.
    """
    global current_langchain_chart_key_index

    data = clean_data(csv_url)
    # os.path.join(None, ...) raises; fall back to bare filenames when unset.
    chart_dir = image_file_path or ""

    for attempt in range(len(groq_api_keys)):
        with current_langchain_chart_lock:
            current_key = current_langchain_chart_key_index % len(groq_api_keys)
            current_langchain_chart_key_index = (current_key + 1) % len(groq_api_keys)
        api_key = groq_api_keys[current_key]

        try:
            llm = ChatGroq(model=model_name, api_key=api_key)
            tool = PythonAstREPLTool(locals={
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
                "uuid": uuid,
            })
            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="openai-tools",
                verbose=True,
                # The agent executes LLM-written Python against the dataframe.
                allow_dangerous_code=True,
                extra_tools=[tool],
                return_intermediate_steps=True,
            )
            result = agent.invoke({"input": _prompt_generator(question, chart_required)})
            output = result.get("output", "")

            # extract_chart_filenames presumably returns a list of filenames.
            chart_files = extract_chart_filenames(output)
            if chart_files:
                return [os.path.join(chart_dir, name) for name in chart_files]
            if attempt < len(groq_api_keys) - 1:
                print(f"Langchain chart error (key {current_key}): {output}")
        except Exception as e:
            print(f"Langchain chart error (key {current_key}): {str(e)}")
        finally:
            # Always release pyplot figures, even when the attempt failed.
            plt.close("all")

    return "Chart generation failed after all retries"


# --- FASTAPI ENDPOINT FOR CHART GENERATION ---
@app.post("/api/csv-chart")
async def csv_chart(request: dict, authorization: str = Header(None)):
    """Generate a chart PNG for ``request["query"]`` over ``request["csv_url"]``.

    Strategy, each step run in ``process_executor`` so blocking work stays off
    the event loop:
    1. LangChain agent, only for initial chart questions;
    2. pandasai/Groq;
    3. LangChain agent again as a fallback.
    The first step that yields an existing image file wins.

    Returns:
        ``FileResponse`` (image/png) on success, otherwise an ``{"error": ...}`` dict.

    Raises:
        HTTPException(401/403): bad or missing bearer token.
    """
    import hmac  # local import keeps this block self-contained

    # --- Authorization Check ---
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Authorization required")
    token = authorization.split(" ")[1]
    expected_token = os.getenv("AUTH_TOKEN") or ""
    # Constant-time comparison so the token cannot be recovered via response timing.
    if not expected_token or not hmac.compare_digest(token.encode(), expected_token.encode()):
        raise HTTPException(status_code=403, detail="Invalid credentials")

    try:
        query = request.get("query", "")
        csv_url = unquote(request.get("csv_url", ""))
        loop = asyncio.get_running_loop()

        # First, try the langchain-based method if the question qualifies
        if if_initial_chart_question(query):
            langchain_result = await loop.run_in_executor(
                process_executor, langchain_csv_chart, csv_url, query, True
            )
            print("Langchain chart result:", langchain_result)
            chart = _first_existing_file(langchain_result)
            if chart is not None:
                return FileResponse(chart, media_type="image/png")

        # Next, try the groq-based method
        groq_result = await loop.run_in_executor(process_executor, groq_chart, csv_url, query)
        print(f"Groq chart result: {groq_result}")
        if (
            isinstance(groq_result, str)
            and groq_result != "Chart not generated"
            and os.path.isfile(groq_result)
        ):
            return FileResponse(groq_result, media_type="image/png")

        # Fallback: try langchain-based again
        langchain_paths = await loop.run_in_executor(
            process_executor, langchain_csv_chart, csv_url, query, True
        )
        print("Fallback langchain chart result:", langchain_paths)
        chart = _first_existing_file(langchain_paths)
        if chart is not None:
            return FileResponse(chart, media_type="image/png")
        return {"error": "All chart generation methods failed"}
    except Exception as e:
        print(f"Critical chart error: {str(e)}")
        return {"error": "Internal system error"}