|
import os
from io import BytesIO
from typing import Optional

# Keras reads KERAS_BACKEND once, when the package is first imported, so the
# default backend has to be chosen before the `import keras` below. setdefault
# still lets a deployment override it through the real environment.
os.environ.setdefault("KERAS_BACKEND", "tensorflow")

import numpy as np
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from huggingface_hub import login, hf_hub_download
from PIL import Image

import keras  # must stay after the KERAS_BACKEND assignment above
|
|
|
|
|
# Height, width, channels of the image tensor fed to the model.
# preprocess_image passes (INPUT_SHAPE[1], INPUT_SHAPE[0]) to PIL's resize,
# which takes (width, height).
INPUT_SHAPE = (200, 200, 3)

# The description is derived from INPUT_SHAPE so the two cannot drift apart.
# It renders exactly as "API PETRA, Input Shape (200, 200, 3)".
app = FastAPI(title="PETRA API", description=f"API PETRA, Input Shape {INPUT_SHAPE}")

# Authenticate with the Hugging Face Hub only when HF_TOKENS holds a non-empty
# value. An empty variable previously led to login(token="").
hf_token = os.environ.get("HF_TOKENS")
if hf_token:
    login(token=hf_token)

# KERAS_BACKEND is deliberately not assigned here. Keras reads it only when the
# package is first imported, so it is configured before `import keras` in the
# import section. An assignment at this point (after the import) has no effect.
|
|
|
|
|
try:
    # Fetch the serialized Keras model from the Hugging Face Hub repository and
    # deserialize it once, at import time, for reuse by every request.
    model_path = hf_hub_download(repo_id="ncardian/sifrac-ml", filename="sifract.keras")
    model = keras.models.load_model(model_path)
except Exception as e:
    # Fail at startup instead of serving requests without a model. Chain the
    # original exception so its traceback is reported as the direct cause.
    raise RuntimeError(f"Error loading model: {e}") from e
|
|
|
def preprocess_image(image: Image.Image) -> np.ndarray:
    """
    Preprocess the uploaded image to match model input requirements.

    The image is converted to 3-channel RGB, resized to the height/width in
    INPUT_SHAPE, scaled from [0, 255] to [0.0, 1.0] as float32, and given a
    leading batch axis.

    Args:
        image: A PIL image. Most PIL modes are accepted (RGB, RGBA, L, LA, P, 1,
            CMYK, YCbCr, ...) because the image is converted to RGB first.

    Returns:
        A float32 array of shape (1, INPUT_SHAPE[0], INPUT_SHAPE[1], 3).
    """
    # convert("RGB") does proper colour conversion for every mode. This drops
    # alpha for RGBA/LA, replicates L to three channels, and maps palette and
    # CMYK images to real RGB. The previous manual channel handling stacked
    # palette indices as grey, left LA images with 2 channels, and treated the
    # C, M, Y planes of a CMYK image as R, G, B.
    rgb_image = image.convert("RGB")

    # PIL sizes are (width, height); INPUT_SHAPE is (height, width, channels).
    rgb_image = rgb_image.resize((INPUT_SHAPE[1], INPUT_SHAPE[0]))

    # Scale 8-bit channel values to [0.0, 1.0] in float32.
    image_array = np.asarray(rgb_image, dtype=np.float32) / 255.0

    # Add the batch dimension expected by model.predict.
    return np.expand_dims(image_array, axis=0)
|
|
|
@app.get("/")
async def root():
    """Health/info endpoint: report the service name and the model's input shape."""
    banner = "SIFRACT-ML API"
    return dict(message=banner, input_shape=INPUT_SHAPE)
|
|
|
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Endpoint for making predictions with the Keras model.

    Accepts an image file and returns model predictions.

    Args:
        file: The uploaded file. Its declared content type must start with
            ``image/``.

    Returns:
        A JSONResponse of the form ``{"prediction": <model output as nested lists>}``.

    Raises:
        HTTPException: status 400 when the upload is not declared as an image
            or its bytes cannot be decoded as one; status 500 for any other
            failure while reading, preprocessing, or running the model.
    """
    # content_type is None when the client sends no Content-Type for the part.
    # Treat that as "not an image" (400) instead of failing on None.startswith.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        contents = await file.read()

        try:
            image = Image.open(BytesIO(contents))
            # Image.open is lazy. load() forces the full decode so corrupt,
            # truncated or unidentifiable data is reported as a client error
            # (UnidentifiedImageError is a subclass of OSError).
            image.load()
        except (OSError, Image.DecompressionBombError) as e:
            raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

        processed_image = preprocess_image(image)

        # verbose=0 suppresses Keras' per-call progress output in server logs.
        # NOTE(review): model.predict is blocking CPU/GPU work inside an
        # `async def` endpoint, so it stalls the event loop while it runs.
        # Consider offloading it once thread-safety of the model is confirmed.
        prediction = model.predict(processed_image, verbose=0)

        prediction = prediction.tolist()
        return JSONResponse(content={"prediction": prediction})

    except HTTPException:
        # Let the deliberate 400 above pass through unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}") from e