# sifrac-ml / app.py
# Author: ncardian
# Commit d4d5df3: "Create app.py" (verified)
import os
from typing import Optional
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
import numpy as np
from PIL import Image
from io import BytesIO
from huggingface_hub import login, hf_hub_download
import keras
# Expected model input shape as (height, width, channels); uploads are
# resized to this size before inference.
INPUT_SHAPE = (200, 200, 3)

# Initialize FastAPI app. The description is built from INPUT_SHAPE so the
# advertised shape cannot drift from the one actually used.
app = FastAPI(
    title="PETRA API",
    description=f"API PETRA, Input Shape {INPUT_SHAPE}",
)

# Login to Hugging Face only when a token is provided. An empty value is
# treated as "no token" rather than being passed to login().
_hf_token = os.environ.get("HF_TOKENS")
if _hf_token:
    login(token=_hf_token)

# NOTE(review): Keras picks its backend when it is first imported, and
# `import keras` already ran at the top of this file, so this assignment does
# not change the backend of the loaded Keras. Move it above `import keras`
# if the backend must be forced.
os.environ["KERAS_BACKEND"] = "tensorflow"
# Load the model from the Hugging Face Hub once, at import time, so every
# request handled by this module reuses the same `model` instance.
MODEL_REPO_ID = "ncardian/sifrac-ml"
MODEL_FILENAME = "sifract.keras"
try:
    model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
    model = keras.models.load_model(model_path)
except Exception as e:
    # Fail fast: the API cannot serve predictions without a model. Chain the
    # original exception so its traceback is preserved in the startup logs.
    raise RuntimeError(f"Error loading model: {e}") from e
def preprocess_image(image: Image.Image) -> np.ndarray:
    """
    Preprocess the uploaded image to match model input requirements.

    The image is converted to 3-channel RGB, resized to the spatial size
    given by ``INPUT_SHAPE`` and scaled to the range ``[0, 1]``.

    Args:
        image: A PIL image in any mode Pillow can convert to RGB
            (for example ``"1"``, ``"L"``, ``"P"``, ``"LA"``, ``"RGBA"``,
            ``"CMYK"``).

    Returns:
        A ``float32`` array of shape ``(1, height, width, 3)``, i.e. a batch
        containing a single image.
    """
    # Normalize the colour mode through Pillow instead of slicing channels by
    # hand. This resolves palette ("P") images to real colours, maps bilevel
    # ("1") images to 0/255, and handles LA/CMYK correctly. Alpha is discarded.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # PIL's resize takes (width, height); INPUT_SHAPE is (height, width, channels).
    height, width = INPUT_SHAPE[0], INPUT_SHAPE[1]
    image = image.resize((width, height))

    # uint8 RGB -> float32 in [0, 1].
    image_array = np.asarray(image, dtype=np.float32) / 255.0

    # Add the batch dimension expected by model.predict.
    return np.expand_dims(image_array, axis=0)
@app.get("/")
async def root():
    """Describe the service and report the input shape the model expects."""
    info = {"message": "SIFRACT-ML API"}
    info["input_shape"] = INPUT_SHAPE
    return info
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Endpoint for making predictions with the Keras model.

    Accepts an uploaded image file, preprocesses it with ``preprocess_image``
    and returns the model output.

    Returns:
        A JSON response ``{"prediction": [...]}`` holding the array returned
        by ``model.predict``, converted to nested lists.

    Raises:
        HTTPException: 400 if the upload is not declared as an image or
            cannot be decoded as one; 500 if preprocessing or inference fails.
    """
    # Local import: fastapi is already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    # content_type is None when the client omits the header.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    contents = await file.read()

    try:
        image = Image.open(BytesIO(contents))
        # Image.open is lazy; force a full decode here so corrupt or
        # truncated uploads are reported as a client error, not a 500.
        image.load()
    except (OSError, Image.DecompressionBombError) as e:
        # OSError also covers PIL.UnidentifiedImageError.
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}") from e

    try:
        processed_image = preprocess_image(image)
        # model.predict is blocking work; run it in a worker thread so the
        # event loop keeps serving other requests. verbose=0 stops Keras from
        # printing a progress bar on every call.
        prediction = await run_in_threadpool(model.predict, processed_image, verbose=0)
        # Convert numpy array to nested lists for JSON serialization.
        payload = {"prediction": prediction.tolist()}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}") from e

    return JSONResponse(content=payload)