ah_ad / inference.py
marcjordan's picture
Upload inference.py
396ed0c verified
import tensorflow as tf
import numpy as np
from PIL import Image
import io
# Carga el modelo SavedModel
model = tf.saved_model.load("efficientnet_alzheimer")
# Define la funci贸n de preprocesamiento de im谩genes
def preprocess_image(image):
image = image.resize((200, 200))
image = np.array(image) / 255.0
image = np.expand_dims(image, axis=0)
return image
def predict(image: Image.Image):
"""
Realiza la predicci贸n de la imagen.
Args:
image: Imagen en formato PIL.Image
Returns:
dict: Un diccionario con la predicci贸n y la probabilidad.
"""
image = preprocess_image(image)
# Realiza la predicci贸n
predictions = model(image)
# Obt茅n la clase predicha y la probabilidad
predicted_class = np.argmax(predictions)
confidence = np.max(predictions)
# Mapea la clase num茅rica al nombre de la clase
class_names = {
0: "VeryMildDemented",
1: "NonDemented",
2: "ModerateDemented",
3: "MildDemented"
}
predicted_class_name = class_names.get(predicted_class, "Clase desconocida")
# Devuelve el resultado
return {"prediction": predicted_class_name, "confidence": confidence.item()}