File size: 1,843 Bytes
962e656
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
from io import BytesIO
from typing import Dict

import uvicorn
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from pydantic import BaseModel
import tensorflow as tf

# Carga el modelo SavedModel (ajusta la ruta si es necesario)
model = tf.saved_model.load("./efficientnet_alzheimer")

app = FastAPI(title="API de clasificaci贸n de Alzheimer")

# Define la funci贸n de preprocesamiento de im谩genes (la misma que en inference.py)
def preprocess_image(image):
    image = image.resize((200, 200))
    image = np.array(image) / 255.0
    image = np.expand_dims(image, axis=0)
    return image

# Define el formato de la respuesta de la API
class Prediction(BaseModel):
    prediction: str
    confidence: float

# Endpoint para la predicci贸n
@app.post("/predict", response_model=Prediction)
async def predict_image(image: UploadFile = File(...)):
    # Lee la imagen
    contents = await image.read()
    image = Image.open(BytesIO(contents))

    # Preprocesa la imagen
    processed_image = preprocess_image(image)

    # Realiza la predicci贸n
    predictions = model(processed_image)

    # Obt茅n la clase predicha y la probabilidad
    predicted_class = np.argmax(predictions)
    confidence = np.max(predictions)

    # Mapea la clase num茅rica al nombre de la clase (igual que en inference.py)
    class_names = {
        0: "VeryMildDemented",
        1: "NonDemented",
        2: "ModerateDemented",
        3: "MildDemented"
    }
    predicted_class_name = class_names.get(predicted_class, "Clase desconocida")

    # Devuelve la predicci贸n y la confianza
    return {"prediction": predicted_class_name, "confidence": confidence.item()}

# Punto de entrada para ejecutar la aplicaci贸n con Uvicorn
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 8000)))