"""FastAPI service for chatting with, and charting, CSV files via Groq-backed LLMs."""

# Import necessary modules
import asyncio
import os
import secrets
import threading
import uuid
from typing import Dict
from urllib.parse import unquote

import matplotlib

# Select the non-interactive backend before pyplot is imported: charts are
# only ever written to disk, never shown on screen.
matplotlib.use('Agg')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from dotenv import load_dotenv
from fastapi import FastAPI, Header, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from langchain_experimental.agents import create_pandas_dataframe_agent
from langchain_experimental.tools import PythonAstREPLTool
from langchain_groq import ChatGroq
from pandasai import SmartDataframe
from pydantic import BaseModel

import csv_service
from csv_service import clean_data, extract_chart_filenames
from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
from util_service import _prompt_generator, process_answer

# Initialize FastAPI app
app = FastAPI()

# Ensure the cache directory and the pandasai log file exist.
os.makedirs("/app/cache", exist_ok=True)
with open("/app/pandasai.log", "a"):
    pass  # opening in append mode creates the file if it doesn't exist

# Ensure the generated_charts directory exists
os.makedirs("/app/generated_charts", exist_ok=True)

load_dotenv()


def _env_list(name: str) -> list:
    """Return the comma-separated values of environment variable *name*.

    Surrounding whitespace is stripped and empty entries are dropped, so
    "a, b," yields ["a", "b"] and an unset variable yields [].
    """
    return [item.strip() for item in os.getenv(name, "").split(",") if item.strip()]


image_file_path = os.getenv("IMAGE_FILE_PATH")
image_not_found = os.getenv("IMAGE_NOT_FOUND")
allowed_hosts = _env_list("ALLOWED_HOSTS")

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_hosts,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load environment variables
groq_api_keys = _env_list("GROQ_API_KEYS")
if not groq_api_keys:
    print("Warning: GROQ_API_KEYS is empty or unset; LLM requests will fail.")
model_name = os.getenv("GROQ_LLM_MODEL")

# Directories /api/get-chart is allowed to serve files from: IMAGE_FILE_PATH
# and the generated_charts directories that chart generation writes into.
_CHART_ROOTS = tuple(
    os.path.realpath(root)
    for root in (image_file_path, "/app/generated_charts", "generated_charts")
    if root
)


class CsvUrlRequest(BaseModel):
    csv_url: str


class ImageRequest(BaseModel):
    image_path: str


# Thread-safe key management for groq_chat
current_groq_key_index = 0
current_groq_key_lock = threading.Lock()

# Thread-safe key management for langchain_csv_chat
current_langchain_key_index = 0
current_langchain_key_lock = threading.Lock()


def _is_under_chart_root(path: str) -> bool:
    """Return True if the already-resolved *path* lies inside one of _CHART_ROOTS."""
    for root in _CHART_ROOTS:
        try:
            if os.path.commonpath([path, root]) == root:
                return True
        except ValueError:  # e.g. paths on different drives (Windows)
            continue
    return False


# PING CHECK
@app.get("/ping")
async def root():
    """Liveness probe."""
    return {"message": "Pong !!"}


# BASIC KNOWLEDGE BASED ON CSV
# Remove trailing slash from the URL otherwise it will redirect to GET method
@app.post("/api/basic_csv_data")
async def basic_csv_data(request: CsvUrlRequest):
    """Return ``csv_service.get_csv_basic_info`` for the URL-decoded CSV URL.

    Raises:
        HTTPException(400): if retrieving or summarising the CSV fails.
    """
    try:
        decoded_url = unquote(request.csv_url)
        print(f"Fetching CSV data from URL: {decoded_url}")
        # Synchronous helper that fetches the CSV; run it in a worker thread
        # so the event loop keeps serving other requests meanwhile.
        csv_data = await asyncio.to_thread(csv_service.get_csv_basic_info, decoded_url)
        print(f"CSV data fetched successfully: {csv_data}")
        return {"data": csv_data}
    except Exception as e:
        print(f"Error while fetching CSV data: {e}")
        raise HTTPException(status_code=400, detail=f"Failed to retrieve CSV data: {str(e)}") from e


# GET THE CHART FROM A SPECIFIC FILE PATH
@app.post("/api/get-chart")
async def get_image(request: ImageRequest, authorization: str = Header(None)):
    """Serve a previously generated chart PNG.

    Only files inside _CHART_ROOTS are served; any other path (including
    "../" traversal or symlinks that resolve elsewhere) is rejected.

    Raises:
        HTTPException(401): missing/malformed Authorization header or empty token.
        HTTPException(403): wrong token, or a path outside the chart directories.

    Returns:
        The PNG as a FileResponse, or {"answer": "error"} if the file is missing.
    """
    if not authorization:
        raise HTTPException(status_code=401, detail="Authorization header missing")
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid authorization header format")
    token = authorization.split(" ")[1]
    if not token:
        raise HTTPException(status_code=401, detail="Token missing")
    expected_token = os.getenv("AUTH_TOKEN")
    # Constant-time comparison so the token cannot be probed via response timing.
    if not expected_token or not secrets.compare_digest(token.encode(), expected_token.encode()):
        raise HTTPException(status_code=403, detail="Invalid token")

    # image_path comes from the client: resolve it and confine it to the chart
    # directories so this endpoint cannot be used to read arbitrary files.
    try:
        resolved_path = os.path.realpath(request.image_path)
    except (OSError, ValueError) as e:  # e.g. embedded NUL byte
        print(f"Error: {e}")
        return {"answer": "error"}
    if not _is_under_chart_root(resolved_path):
        raise HTTPException(status_code=403, detail="Access to this path is not allowed")
    # FileResponse only checks the file when the response is sent (after this
    # function returns), so check here to return the intended error payload.
    if not os.path.isfile(resolved_path):
        print(f"Error: chart not found at {resolved_path}")
        return {"answer": "error"}
    return FileResponse(resolved_path, media_type="image/png")


# GET CSV DATA FOR GENERATING THE TABLE
@app.post("/api/csv_data")
async def get_csv_data(request: CsvUrlRequest):
    """Return ``csv_service.generate_csv_data`` for the URL-decoded CSV URL.

    Raises:
        HTTPException(400): if the CSV cannot be retrieved or processed.
    """
    try:
        decoded_url = unquote(request.csv_url)
        # Synchronous helper; keep it off the event loop.
        return await asyncio.to_thread(csv_service.generate_csv_data, decoded_url)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to retrieve CSV data: {str(e)}") from e
import secrets

# CHAT CODING STARTS FROM HERE


def handle_out_of_range_float(value):
    """Make a scalar JSON-safe.

    NaN becomes None, +inf becomes "Infinity" and -inf becomes "-Infinity".
    Python floats and NumPy floating scalars are handled; any other value is
    returned unchanged.
    """
    if isinstance(value, (float, np.floating)):
        if np.isnan(value):
            return None
        if np.isinf(value):
            return "Infinity" if value > 0 else "-Infinity"
    return value


def _json_safe_answer(answer):
    """Convert a pandasai answer into JSON-serialisable data.

    DataFrames become a list of row dicts, Series a dict, lists/dicts are
    sanitised element-wise, and anything else is wrapped as {"answer": str}.
    """
    if isinstance(answer, pd.DataFrame):
        # DataFrame.apply would hand whole columns to handle_out_of_range_float,
        # which returns them untouched; sanitise cell by cell instead.
        return [
            {column: handle_out_of_range_float(cell) for column, cell in row.items()}
            for row in answer.to_dict(orient="records")
        ]
    if isinstance(answer, pd.Series):
        # Sanitise after to_dict: Series.apply can turn None back into NaN.
        return {key: handle_out_of_range_float(item) for key, item in answer.to_dict().items()}
    if isinstance(answer, list):
        return [handle_out_of_range_float(item) for item in answer]
    if isinstance(answer, dict):
        return {k: handle_out_of_range_float(v) for k, v in answer.items()}
    return {"answer": str(handle_out_of_range_float(answer))}


def groq_chat(csv_url: str, question: str):
    """Answer *question* about the CSV at *csv_url* using pandasai + Groq.

    API keys are used round-robin via a shared, lock-protected index. On a
    rate-limit error (message containing "429") the index advances and the
    request is retried, at most once per key per request. The index wraps
    around, so transient 429s never lock the service out permanently.
    Any other error is returned immediately.

    Returns:
        JSON-safe answer data (see _json_safe_answer) or {"error": message}.
    """
    global current_groq_key_index
    key_count = len(groq_api_keys)
    for _attempt in range(key_count):
        with current_groq_key_lock:
            key_index = current_groq_key_index % key_count
            current_api_key = groq_api_keys[key_index]
        try:
            # Best-effort removal of a cache DB (presumably pandasai's query
            # cache) before each request.
            cache_db_path = "/workspace/cache/cache_db_0.11.db"
            if os.path.exists(cache_db_path):
                try:
                    os.remove(cache_db_path)
                except Exception as e:
                    print(f"Error deleting cache DB file: {e}")

            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            # Generate unique filename using UUID
            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            # Configure SmartDataframe with chart settings
            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,  # Enable chart saving
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),  # Directory to save
                    'custom_chart_filename': chart_filename,  # Unique filename
                },
            )
            return _json_safe_answer(df.chat(question))
        except Exception as e:
            error_message = str(e)
            if "429" not in error_message:
                return {"error": error_message}
            with current_groq_key_lock:
                # Advance only if no other thread has already moved past this key.
                if current_groq_key_index % key_count == key_index:
                    current_groq_key_index = (key_index + 1) % key_count
    return {"error": "All API keys exhausted."}


def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
    """Answer *question* about the CSV at *csv_url* with a LangChain pandas agent.

    API keys are handed out round-robin across calls via a shared,
    lock-protected index. On any error the next key is tried, at most once
    per key per call.

    Returns:
        The agent's "output" value, or {"error": "All API keys exhausted"}
        if every key failed. Errors raised by clean_data propagate.
    """
    global current_langchain_key_index
    data = clean_data(csv_url)
    key_count = len(groq_api_keys)
    for _attempt in range(key_count):
        with current_langchain_key_lock:
            key_index = current_langchain_key_index % key_count
            api_key = groq_api_keys[key_index]
            current_langchain_key_index = key_index + 1
        try:
            llm = ChatGroq(model=model_name, api_key=api_key)
            # Extra REPL tool exposing the dataframe and plotting libraries to
            # the agent's generated code.
            tool = PythonAstREPLTool(locals={
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
            })
            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="openai-tools",
                verbose=True,
                # Explicit opt-in: the agent executes model-generated Python.
                allow_dangerous_code=True,
                extra_tools=[tool],
                return_intermediate_steps=True,
            )
            prompt = _prompt_generator(question, chart_required)
            result = agent.invoke({"input": prompt})
            return result.get("output")
        except Exception as e:
            print(f"Error with key index {key_index}: {str(e)}")
    return {"error": "All API keys exhausted"}


# Async endpoint with non-blocking execution
@app.post("/api/csv-chat")
async def csv_chat(request: Dict, authorization: str = Header(None)):
    """Answer a natural-language question about a CSV file.

    Expects a JSON body with "query" and "csv_url". Initial questions go
    straight to the LangChain agent. Otherwise pandasai (groq_chat) is tried
    first, with the LangChain agent as a fallback when process_answer flags
    the groq answer. Both helpers run in worker threads.

    Raises:
        HTTPException(401): missing or non-Bearer Authorization header.
        HTTPException(403): token does not match AUTH_TOKEN (or AUTH_TOKEN unset).
    """
    # Authorization checks
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid authorization")
    token = authorization.split(" ")[1]
    expected_token = os.getenv("AUTH_TOKEN")
    # Constant-time comparison; an unset/empty AUTH_TOKEN rejects everyone.
    if not expected_token or not secrets.compare_digest(token.encode(), expected_token.encode()):
        raise HTTPException(status_code=403, detail="Invalid token")

    try:
        query = request.get("query")
        decoded_url = unquote(request.get("csv_url"))

        if if_initial_chat_question(query):
            answer = await asyncio.to_thread(langchain_csv_chat, decoded_url, query, False)
            print("langchain_answer:", answer)
            return {"answer": jsonable_encoder(answer)}

        # Process with groq_chat first
        groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
        print("groq_answer:", groq_answer)
        groq_problem = process_answer(groq_answer)
        if groq_problem == "Empty response received.":
            return {"answer": "Sorry, I couldn't find relevant data..."}
        if not groq_problem:
            return {"answer": jsonable_encoder(groq_answer)}

        # groq answer was flagged: fall back to the LangChain agent.
        lang_answer = await asyncio.to_thread(langchain_csv_chat, decoded_url, query, False)
        if process_answer(lang_answer):
            return {"answer": "error"}
        return {"answer": jsonable_encoder(lang_answer)}
    except Exception as e:
        print(f"Error processing request: {str(e)}")
        return {"answer": "error"}


# CHART CODING STARTS FROM HERE
# Earlier drafts of the chart helpers and an unfinished merged
# "/api/v1/csv_chat" endpoint used to sit here as commented-out code. They were
# removed (recover them from version control if needed). The live chart
# implementation is groq_chart / langchain_csv_chart / csv_chart, defined later
# in this file.
# --- CHART GENERATION ---
import asyncio
import os
import secrets
import threading
import uuid
from concurrent.futures import ProcessPoolExecutor
from urllib.parse import unquote

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from fastapi import FastAPI, Header, HTTPException
from fastapi.responses import FileResponse
from langchain_experimental.agents import create_pandas_dataframe_agent
from langchain_experimental.tools import PythonAstREPLTool
from langchain_groq import ChatGroq
from pydantic import BaseModel

# Import your custom modules (assumed available)
from csv_service import clean_data, extract_chart_filenames
from intitial_q_handler import if_initial_chart_question
from util_service import _prompt_generator, process_answer

# Use non-interactive backend
matplotlib.use('Agg')

# NOTE: deliberately no `app = FastAPI()` here. Re-creating the app would
# replace the module-level `app` defined at the top of the file and silently
# drop every route and middleware registered on it before this point.

# Environment variables and configuration (whitespace stripped, empty
# entries dropped).
groq_api_keys = [key.strip() for key in os.getenv("GROQ_API_KEYS", "").split(",") if key.strip()]
model_name = os.getenv("GROQ_LLM_MODEL")
image_file_path = os.getenv("IMAGE_FILE_PATH")  # e.g. "/app/generated_charts"

# Key-rotation state for the chart helpers. The helpers run inside
# process_executor workers, so each worker process rotates its own copy of
# these indices; the locks only serialise access within one process.
current_groq_chart_key_index = 0
current_groq_chart_lock = threading.Lock()
current_langchain_chart_key_index = 0
current_langchain_chart_lock = threading.Lock()

# Chart generation runs in separate processes so matplotlib's global pyplot
# state is never shared between concurrent requests.
process_executor = ProcessPoolExecutor(max_workers=2)

# Extra presentation instructions appended to every pandasai chart question.
_CHART_INSTRUCTIONS = """
- Ensure each value is clearly visible.
- Adjust font sizes, rotate labels if necessary.
- Use a colorblind-friendly palette.
- Arrange multiple charts in a grid if needed.
"""


# --- GROQ-BASED CHART GENERATION ---
def groq_chart(csv_url: str, question: str):
    """Generate a chart for *question* with pandasai + Groq.

    Keys are tried in turn. A rate-limit error (message containing "429")
    advances to the next key; any other error is returned immediately.
    Matplotlib figures are closed after every attempt, successful or not.

    Returns:
        The absolute path of an existing chart PNG on success,
        "Chart not generated" if no chart file was produced, or
        {"error": message} on failure.
    """
    global current_groq_chart_key_index
    key_count = len(groq_api_keys)
    for _attempt in range(key_count):
        with current_groq_chart_lock:
            key_index = current_groq_chart_key_index % key_count
            current_api_key = groq_api_keys[key_index]
        try:
            from pandasai import SmartDataframe  # local import, as before

            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            # Generate a unique filename and path for the chart
            chart_filename = f"chart_{uuid.uuid4().hex}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),
                    'custom_chart_filename': chart_filename,
                },
            )
            answer = df.chat(question + _CHART_INSTRUCTIONS)

            # If process_answer indicates a problem, return a failure message.
            if process_answer(answer):
                return "Chart not generated"
            # pandasai may report the saved chart's path as its answer;
            # otherwise try the filename we requested. Only return a path that
            # exists, so csv_chart falls back instead of serving a missing file.
            for candidate in (answer, chart_path):
                if isinstance(candidate, str) and os.path.isfile(candidate):
                    return os.path.abspath(candidate)
            return "Chart not generated"
        except Exception as e:
            error = str(e)
            if "429" not in error:
                print(f"Groq chart generation error: {error}")
                return {"error": error}
            with current_groq_chart_lock:
                # Advance only if nobody has already moved past this key.
                if current_groq_chart_key_index % key_count == key_index:
                    current_groq_chart_key_index = (key_index + 1) % key_count
        finally:
            # Always release figures, including when chart generation raised.
            plt.close('all')
    return {"error": "All API keys exhausted for chart generation"}


# --- LANGCHAIN-BASED CHART GENERATION ---
def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
    """Generate a chart for *question* with a LangChain pandas agent.

    Keys are handed out round-robin, one attempt per key. Chart filenames
    are extracted from the agent output and joined onto IMAGE_FILE_PATH;
    only files that actually exist are returned.

    Returns:
        A non-empty list of chart file paths, or the string
        "Chart generation failed after all retries". Errors raised by
        clean_data propagate.
    """
    global current_langchain_chart_key_index
    data = clean_data(csv_url)
    key_count = len(groq_api_keys)
    for _attempt in range(key_count):
        with current_langchain_chart_lock:
            key_index = current_langchain_chart_key_index % key_count
            api_key = groq_api_keys[key_index]
            current_langchain_chart_key_index = (key_index + 1) % key_count
        try:
            llm = ChatGroq(model=model_name, api_key=api_key)
            tool = PythonAstREPLTool(locals={
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
                "uuid": uuid,
            })
            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="openai-tools",
                verbose=True,
                # Explicit opt-in: the agent executes model-generated Python.
                allow_dangerous_code=True,
                extra_tools=[tool],
                return_intermediate_steps=True,
            )
            result = agent.invoke({"input": _prompt_generator(question, chart_required)})
            output = result.get("output", "")

            chart_paths = [
                os.path.join(image_file_path or "", name)
                for name in extract_chart_filenames(output)
            ]
            existing_paths = [path for path in chart_paths if os.path.isfile(path)]
            if existing_paths:
                return existing_paths
            print(f"Langchain chart error (key {key_index}): {output}")
        except Exception as e:
            print(f"Langchain chart error (key {key_index}): {str(e)}")
        finally:
            # Close figures to avoid leaking them between tasks in this worker.
            plt.close('all')
    return "Chart generation failed after all retries"


# --- FASTAPI ENDPOINT FOR CHART GENERATION ---
@app.post("/api/csv-chart")
async def csv_chart(request: dict, authorization: str = Header(None)):
    """Generate a chart PNG from CSV data.

    Expects a JSON body with "query" and "csv_url". Strategy: LangChain first
    for initial chart questions, then pandasai (groq_chart), then LangChain
    again as a fallback. Each helper runs in process_executor.

    Raises:
        HTTPException(401): missing or non-Bearer Authorization header.
        HTTPException(403): token does not match AUTH_TOKEN (or AUTH_TOKEN unset).
    """
    # --- Authorization Check ---
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Authorization required")
    token = authorization.split(" ")[1]
    expected_token = os.getenv("AUTH_TOKEN")
    # Constant-time comparison; an unset/empty AUTH_TOKEN rejects everyone.
    if not expected_token or not secrets.compare_digest(token.encode(), expected_token.encode()):
        raise HTTPException(status_code=403, detail="Invalid credentials")

    try:
        query = request.get("query", "")
        csv_url = unquote(request.get("csv_url", ""))
        loop = asyncio.get_running_loop()

        # First, try the langchain-based method if the question qualifies
        if if_initial_chart_question(query):
            langchain_result = await loop.run_in_executor(
                process_executor, langchain_csv_chart, csv_url, query, True
            )
            print("Langchain chart result:", langchain_result)
            if isinstance(langchain_result, list) and langchain_result:
                return FileResponse(langchain_result[0], media_type="image/png")

        # Next, try the groq-based method
        groq_result = await loop.run_in_executor(process_executor, groq_chart, csv_url, query)
        print(f"Groq chart result: {groq_result}")
        if isinstance(groq_result, str) and groq_result != "Chart not generated":
            return FileResponse(groq_result, media_type="image/png")

        # Fallback: try langchain-based again
        langchain_paths = await loop.run_in_executor(
            process_executor, langchain_csv_chart, csv_url, query, True
        )
        print("Fallback langchain chart result:", langchain_paths)
        if isinstance(langchain_paths, list) and langchain_paths:
            return FileResponse(langchain_paths[0], media_type="image/png")
        return {"error": "All chart generation methods failed"}
    except Exception as e:
        print(f"Critical chart error: {str(e)}")
        return {"error": "Internal system error"}