marcjordan commited on
Commit
396ed0c
verified
1 Parent(s): 7298f61

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +46 -0
inference.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+ from PIL import Image
4
+ import io
5
+
6
+ # Carga el modelo SavedModel
7
+ model = tf.saved_model.load("efficientnet_alzheimer")
8
+
9
+ # Define la funci贸n de preprocesamiento de im谩genes
10
+ def preprocess_image(image):
11
+ image = image.resize((200, 200))
12
+ image = np.array(image) / 255.0
13
+ image = np.expand_dims(image, axis=0)
14
+ return image
15
+
16
+ def predict(image: Image.Image):
17
+ """
18
+ Realiza la predicci贸n de la imagen.
19
+
20
+ Args:
21
+ image: Imagen en formato PIL.Image
22
+
23
+ Returns:
24
+ dict: Un diccionario con la predicci贸n y la probabilidad.
25
+ """
26
+
27
+ image = preprocess_image(image)
28
+
29
+ # Realiza la predicci贸n
30
+ predictions = model(image)
31
+
32
+ # Obt茅n la clase predicha y la probabilidad
33
+ predicted_class = np.argmax(predictions)
34
+ confidence = np.max(predictions)
35
+
36
+ # Mapea la clase num茅rica al nombre de la clase
37
+ class_names = {
38
+ 0: "VeryMildDemented",
39
+ 1: "NonDemented",
40
+ 2: "ModerateDemented",
41
+ 3: "MildDemented"
42
+ }
43
+ predicted_class_name = class_names.get(predicted_class, "Clase desconocida")
44
+
45
+ # Devuelve el resultado
46
+ return {"prediction": predicted_class_name, "confidence": confidence.item()}