ericjedha commited on
Commit
ff5de12
·
verified ·
1 Parent(s): b58aa9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -49
app.py CHANGED
@@ -1,74 +1,136 @@
1
- import gradio as gr
 
 
2
  import numpy as np
3
- import tensorflow as tf
4
  import cv2
5
- import matplotlib.pyplot as plt
6
- from tensorflow.keras.applications.xception import preprocess_input, decode_predictions
 
 
7
 
8
- # Charger le modèle
9
- model = tf.keras.models.load_model("Xception-Baseline.keras")
 
 
 
 
 
 
 
 
 
 
10
 
11
- # Nom de la dernière couche convolutive
12
- last_conv_layer_name = "block14_sepconv2_act"
 
 
13
 
14
- # Fonction Grad-CAM
15
- def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
16
- grad_model = tf.keras.models.Model(
17
- [model.inputs],
18
- [model.get_layer(last_conv_layer_name).output, model.output]
19
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  with tf.GradientTape() as tape:
22
- conv_outputs, predictions = grad_model(img_array)
23
- if pred_index is None:
24
- pred_index = tf.argmax(predictions[0])
25
- class_channel = predictions[:, pred_index]
26
 
27
- grads = tape.gradient(class_channel, conv_outputs)
28
  pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
 
29
  conv_outputs = conv_outputs[0]
30
  heatmap = conv_outputs @ pooled_grads[..., tf.newaxis]
31
  heatmap = tf.squeeze(heatmap)
32
- heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
33
- return heatmap.numpy(), predictions.numpy()
34
 
35
- # Superposition de la heatmap
36
- def overlay_heatmap(original_img, heatmap, alpha=0.4):
37
- heatmap = cv2.resize(heatmap, (original_img.shape[1], original_img.shape[0]))
38
  heatmap = np.uint8(255 * heatmap)
39
- heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
40
- superimposed_img = cv2.addWeighted(heatmap_color, alpha, original_img, 1 - alpha, 0)
41
- return superimposed_img
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- # Fonction principale pour Gradio
44
- def gradcam_interface(img):
45
- # Convertir l'image
46
- img_resized = cv2.resize(img, (299, 299))
47
- img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
48
- input_array = preprocess_input(np.expand_dims(img_rgb.astype(np.float32), axis=0))
49
 
50
- # Générer heatmap
51
- heatmap, preds = make_gradcam_heatmap(input_array, model, last_conv_layer_name)
52
 
53
- # Décodage prédiction (si modèle pré-entraîné sur ImageNet, sinon adapter)
54
- class_idx = np.argmax(preds[0])
55
- confidence = preds[0][class_idx]
56
- decoded = f"Classe prédite : {class_idx} | Confiance : {confidence:.2f}"
57
 
58
- # Appliquer la heatmap
59
- heatmap_overlay = overlay_heatmap(img_resized, heatmap)
60
 
61
- return heatmap_overlay, decoded
 
 
 
 
 
62
 
63
- # Interface Gradio
64
  demo = gr.Interface(
65
- fn=gradcam_interface,
66
- inputs=gr.Image(type="numpy", label="Image"),
67
  outputs=[
68
- gr.Image(type="numpy", label="Grad-CAM"),
69
- gr.Text(label="Prédiction")
 
70
  ],
71
- title="Grad-CAM Visualizer - Xception"
 
 
72
  )
73
 
74
- demo.launch()
 
 
1
+ import os
2
+ os.environ["KERAS_BACKEND"] = "jax"
3
+
4
  import numpy as np
5
+ import gradio as gr
6
  import cv2
7
+ import tensorflow as tf
8
+ import keras
9
+ from keras.models import Model
10
+ from keras.preprocessing import image
11
 
12
+ # ---- Configuration ----
13
+ CLASS_NAMES = ['akiec', 'bcc', 'bkl', 'df', 'nv', 'vasc', 'mel']
14
+ label_to_index = {name: i for i, name in enumerate(CLASS_NAMES)}
15
+ diagnosis_map = {
16
+ 'akiec': 'Bénin',
17
+ 'bcc': 'Malin',
18
+ 'bkl': 'Bénin',
19
+ 'df': 'Bénin',
20
+ 'nv': 'Bénin',
21
+ 'vasc': 'Bénin',
22
+ 'mel': 'Malin'
23
+ }
24
 
25
+ # ---- Chargement modèles ----
26
+ model_xcept = keras.saving.load_model("Xception.keras", compile=False)
27
+ model_resnet50 = keras.saving.load_model("hf://ericjedha/resnet50")
28
+ model_densenet = keras.saving.load_model("hf://ericjedha/densenet201")
29
 
30
+ # ---- Préprocesseurs ----
31
+ from tensorflow.keras.applications.xception import preprocess_input as preprocess_xception
32
+ from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet
33
+ from tensorflow.keras.applications.densenet import preprocess_input as preprocess_densenet
34
+
35
+ def load_image(path, target_size):
36
+ img = image.load_img(path, target_size=target_size)
37
+ return image.img_to_array(img)
38
+
39
+ # ---- Prédiction single image ----
40
+ def predict_single(img_path, weights=(0.45, 0.25, 0.30)):
41
+ bx = preprocess_xception(np.expand_dims(load_image(img_path, (299, 299)), axis=0))
42
+ br = preprocess_resnet(np.expand_dims(load_image(img_path, (224, 224)), axis=0))
43
+ bd = preprocess_densenet(np.expand_dims(load_image(img_path, (224, 224)), axis=0))
44
+
45
+ pred_x = model_xcept.predict(bx, verbose=0)
46
+ pred_r = model_resnet50.predict(br, verbose=0)
47
+ pred_d = model_densenet.predict(bd, verbose=0)
48
+
49
+ preds = (weights[0] * pred_x + weights[1] * pred_r + weights[2] * pred_d)
50
+
51
+ # boost MEL avec DenseNet
52
+ mel_idx = label_to_index['mel']
53
+ preds[:, mel_idx] = (0.5 * preds[:, mel_idx] + 0.5 * pred_d[:, mel_idx])
54
+
55
+ return preds[0]
56
+
57
+ # ---- Grad-CAM ----
58
+ def make_gradcam(img_path, model, last_conv_layer_name="conv5_block32_concat", class_index=None):
59
+ img = image.load_img(img_path, target_size=(224, 224))
60
+ img_array = image.img_to_array(img)
61
+ input_array = np.expand_dims(img_array, axis=0)
62
+ input_array = preprocess_densenet(input_array)
63
+
64
+ if class_index is None:
65
+ preds = model.predict(input_array)
66
+ class_index = np.argmax(preds[0])
67
+
68
+ grad_model = Model(inputs=model.inputs, outputs=[
69
+ model.get_layer(last_conv_layer_name).output,
70
+ model.output
71
+ ])
72
 
73
  with tf.GradientTape() as tape:
74
+ conv_outputs, predictions = grad_model(input_array)
75
+ loss = predictions[:, class_index]
 
 
76
 
77
+ grads = tape.gradient(loss, conv_outputs)[0]
78
  pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
79
+
80
  conv_outputs = conv_outputs[0]
81
  heatmap = conv_outputs @ pooled_grads[..., tf.newaxis]
82
  heatmap = tf.squeeze(heatmap)
83
+ heatmap = np.maximum(heatmap, 0) / (np.max(heatmap) + 1e-6)
 
84
 
85
+ heatmap = cv2.resize(heatmap.numpy(), (224, 224))
 
 
86
  heatmap = np.uint8(255 * heatmap)
87
+ heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
88
+
89
+ superimposed_img = cv2.addWeighted(
90
+ cv2.cvtColor(img_array.astype("uint8"), cv2.COLOR_RGB2BGR),
91
+ 0.6, heatmap, 0.4, 0
92
+ )
93
+ return cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)
94
+
95
+ # ---- Fonction Gradio ----
96
+ def gradio_predict(image_file):
97
+ probs = predict_single(image_file)
98
+
99
+ sorted_idx = np.argsort(-probs)
100
+ sorted_labels = [CLASS_NAMES[i].upper() for i in sorted_idx]
101
+ sorted_probs = probs[sorted_idx] * 100
102
 
103
+ benign_prob = sum(probs[i] for i, cls in enumerate(CLASS_NAMES) if diagnosis_map[cls] == "Bénin")
104
+ malign_prob = sum(probs[i] for i, cls in enumerate(CLASS_NAMES) if diagnosis_map[cls] == "Malin")
105
+ global_diag = "Bénin" if benign_prob >= malign_prob else "Malin"
 
 
 
106
 
107
+ bar_data = {"Classes": sorted_labels, "Probabilité (%)": sorted_probs.tolist()}
 
108
 
109
+ # Grad-CAM sur la meilleure classe
110
+ top_class = np.argmax(probs)
111
+ gradcam_img = make_gradcam(image_file, model_densenet, class_index=top_class)
 
112
 
113
+ return global_diag, gr.BarPlot.update(value=bar_data, x="Classes", y="Probabilité (%)", title="Distribution des classes"), gradcam_img
 
114
 
115
+ # ---- Gradio UI ----
116
+ examples = [
117
+ "exemple1.jpg",
118
+ "exemple2.jpg",
119
+ "exemple3.jpg"
120
+ ]
121
 
 
122
  demo = gr.Interface(
123
+ fn=gradio_predict,
124
+ inputs=gr.Image(type="filepath", label="Uploader une image de lésion"),
125
  outputs=[
126
+ gr.Label(label="Diagnostic global"),
127
+ gr.BarPlot(label="Probabilités par classe"),
128
+ gr.Image(label="Visualisation Grad-CAM")
129
  ],
130
+ examples=examples,
131
+ title="Analyse de lésions cutanées (Ensemble de modèles + Grad-CAM)",
132
+ description="Cet outil propose une prédiction de la nature de la lésion (Bénin/Malin) avec explication visuelle."
133
  )
134
 
135
+ if __name__ == "__main__":
136
+ demo.launch()