"""PETRA / SIFRACT-ML inference API.

Serves a Keras image model downloaded from the Hugging Face Hub behind a
small FastAPI app. Uploaded images are resized to ``INPUT_SHAPE``, converted
to RGB, scaled to [0, 1] and passed to the model.
"""

import os
from io import BytesIO
from typing import Optional

import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download, login
from PIL import Image

# Keras selects its backend once, when the package is first imported, so the
# backend has to be chosen *before* ``import keras`` or it is silently ignored.
# setdefault keeps an explicitly configured KERAS_BACKEND from the environment.
os.environ.setdefault("KERAS_BACKEND", "tensorflow")

import keras  # noqa: E402  (must follow the KERAS_BACKEND assignment above)

# Initialize FastAPI app
app = FastAPI(title="PETRA API", description="API PETRA, Input Shape (200, 200, 3)")

# Authenticate with the Hugging Face Hub only when a non-empty token is set.
_hf_token = os.environ.get("HF_TOKENS")
if _hf_token:
    login(token=_hf_token)

# Expected model input shape: (height, width, channels).
INPUT_SHAPE = (200, 200, 3)

# Where the model weights live on the Hugging Face Hub.
MODEL_REPO_ID = "ncardian/sifrac-ml"
MODEL_FILENAME = "sifract.keras"

# Load the model once at import time; fail fast if it cannot be loaded.
# (Alternative for a repo saved with keras.saving.save_model():
#  model = keras.saving.load_model("hf://ncardian/petra"))
try:
    model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
    model = keras.models.load_model(model_path)
except Exception as e:
    raise RuntimeError(f"Error loading model: {str(e)}") from e


def preprocess_image(image: Image.Image) -> np.ndarray:
    """Turn a PIL image into a model-ready batch containing one image.

    The image is resized to ``INPUT_SHAPE`` (height, width), converted to
    3-channel RGB and scaled to float32 values in [0, 1].

    Args:
        image: A decoded PIL image in any mode. Grayscale is replicated to
            three channels, alpha is dropped, and palette/bilevel/CMYK images
            are expanded to real RGB colours.

    Returns:
        A float32 array of shape ``(1, INPUT_SHAPE[0], INPUT_SHAPE[1], 3)``.
    """
    height, width = INPUT_SHAPE[0], INPUT_SHAPE[1]
    # PIL sizes are (width, height).
    resized = image.resize((width, height))
    # Converting *after* resizing keeps results identical to the previous
    # channel handling for "L", "RGB" and "RGBA" inputs, while fixing modes
    # that raw np.array() mishandled (e.g. "P" gave palette indices, "1" gave
    # booleans, "LA" gave two channels, "CMYK" was treated as RGBA).
    rgb = resized.convert("RGB")
    image_array = np.asarray(rgb, dtype=np.float32) / 255.0
    # Add the batch dimension.
    return np.expand_dims(image_array, axis=0)


@app.get("/")
async def root():
    """Return the service name and the expected model input shape."""
    return {"message": "SIFRACT-ML API", "input_shape": INPUT_SHAPE}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Endpoint for making predictions with the Keras model.
    Accepts an image file and returns model predictions.

    Raises:
        HTTPException 400: the upload is not declared as an image, or its
            bytes cannot be decoded as an image.
        HTTPException 500: preprocessing or inference failed.
    """
    # content_type can be None when the client omits the header.
    content_type = file.content_type or ""
    if not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    contents = await file.read()
    try:
        image = Image.open(BytesIO(contents))
        # Image.open is lazy; decode now so corrupt/truncated data is reported
        # as a client error rather than surfacing later as a 500.
        image.load()
    except (OSError, Image.DecompressionBombError) as e:
        # OSError also covers PIL.UnidentifiedImageError.
        raise HTTPException(status_code=400, detail=f"Invalid image file: {str(e)}") from e

    try:
        processed_image = preprocess_image(image)
        # NOTE(review): model.predict is synchronous and blocks the event loop
        # for the duration of inference; consider run_in_threadpool if request
        # concurrency matters (confirm the model is thread-safe first).
        # verbose=0 stops Keras from printing a progress bar on every request.
        prediction = model.predict(processed_image, verbose=0)
        # Convert numpy array to list for JSON serialization.
        prediction = prediction.tolist()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e

    return JSONResponse(content={"prediction": prediction})