marcjordan commited on
Commit
962e656
verified
1 Parent(s): 812e01e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from io import BytesIO
3
+ from typing import Dict
4
+
5
+ import uvicorn
6
+ from fastapi import FastAPI, File, UploadFile
7
+ from PIL import Image
8
+ from pydantic import BaseModel
9
+ import tensorflow as tf
10
+
11
+ # Carga el modelo SavedModel (ajusta la ruta si es necesario)
12
+ model = tf.saved_model.load("./efficientnet_alzheimer")
13
+
14
+ app = FastAPI(title="API de clasificaci贸n de Alzheimer")
15
+
16
+ # Define la funci贸n de preprocesamiento de im谩genes (la misma que en inference.py)
17
+ def preprocess_image(image):
18
+ image = image.resize((200, 200))
19
+ image = np.array(image) / 255.0
20
+ image = np.expand_dims(image, axis=0)
21
+ return image
22
+
23
+ # Define el formato de la respuesta de la API
24
+ class Prediction(BaseModel):
25
+ prediction: str
26
+ confidence: float
27
+
28
+ # Endpoint para la predicci贸n
29
+ @app.post("/predict", response_model=Prediction)
30
+ async def predict_image(image: UploadFile = File(...)):
31
+ # Lee la imagen
32
+ contents = await image.read()
33
+ image = Image.open(BytesIO(contents))
34
+
35
+ # Preprocesa la imagen
36
+ processed_image = preprocess_image(image)
37
+
38
+ # Realiza la predicci贸n
39
+ predictions = model(processed_image)
40
+
41
+ # Obt茅n la clase predicha y la probabilidad
42
+ predicted_class = np.argmax(predictions)
43
+ confidence = np.max(predictions)
44
+
45
+ # Mapea la clase num茅rica al nombre de la clase (igual que en inference.py)
46
+ class_names = {
47
+ 0: "VeryMildDemented",
48
+ 1: "NonDemented",
49
+ 2: "ModerateDemented",
50
+ 3: "MildDemented"
51
+ }
52
+ predicted_class_name = class_names.get(predicted_class, "Clase desconocida")
53
+
54
+ # Devuelve la predicci贸n y la confianza
55
+ return {"prediction": predicted_class_name, "confidence": confidence.item()}
56
+
57
+ # Punto de entrada para ejecutar la aplicaci贸n con Uvicorn
58
+ if __name__ == "__main__":
59
+ uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 8000)))