Update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,6 @@ import gradio as gr
|
|
4 |
import cv2
|
5 |
import tensorflow as tf
|
6 |
import keras
|
7 |
-
from keras.models import Model
|
8 |
from huggingface_hub import hf_hub_download
|
9 |
import pandas as pd
|
10 |
from PIL import Image
|
@@ -30,13 +29,15 @@ from tensorflow.keras.applications.xception import preprocess_input as preproces
|
|
30 |
from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet
|
31 |
from tensorflow.keras.applications.densenet import preprocess_input as preprocess_densenet
|
32 |
|
33 |
-
# ---- Fonctions utilitaires ----
|
34 |
def get_primary_input_name(model):
|
|
|
35 |
if isinstance(model.inputs, list) and len(model.inputs) > 0:
|
36 |
return model.inputs[0].name.split(':')[0]
|
37 |
return "input_1"
|
38 |
|
39 |
def safe_forward(model, x):
|
|
|
40 |
input_name = get_primary_input_name(model)
|
41 |
return model({input_name: x}, training=False).numpy()
|
42 |
|
@@ -59,57 +60,13 @@ def predict_single(image_pil, weights=(0.45, 0.25, 0.30)):
|
|
59 |
preds[:, mel_idx] = (0.5 * preds[:, mel_idx] + 0.5 * pred_d[:, mel_idx])
|
60 |
return preds[0]
|
61 |
|
62 |
-
|
63 |
-
# ---- Grad-CAM (Version Finale avec nom de couche fixe) ----
|
64 |
-
def make_gradcam(image_pil, model, class_index=None):
|
65 |
-
# ===== CORRECTION FINALE : On utilise le nom connu et fiable de la dernière couche de features de DenseNet201 =====
|
66 |
-
last_conv_layer_name = "relu"
|
67 |
-
# =================================================================================================================
|
68 |
-
|
69 |
-
img_np = np.array(image_pil)
|
70 |
-
img_resized = cv2.resize(img_np, (224, 224))
|
71 |
-
img_array_preprocessed = preprocess_densenet(np.expand_dims(img_resized, axis=0))
|
72 |
-
|
73 |
-
grad_model = Model(model.inputs, [model.get_layer(last_conv_layer_name).output, model.output])
|
74 |
-
input_name = get_primary_input_name(model)
|
75 |
-
input_for_model = {input_name: img_array_preprocessed}
|
76 |
-
|
77 |
-
with tf.GradientTape() as tape:
|
78 |
-
last_conv_layer_output, preds = grad_model(input_for_model, training=False)
|
79 |
-
if isinstance(preds, list):
|
80 |
-
preds = preds[0]
|
81 |
-
if class_index is None:
|
82 |
-
class_index = tf.argmax(preds[0])
|
83 |
-
class_channel = preds[:, class_index]
|
84 |
-
|
85 |
-
grads = tape.gradient(class_channel, last_conv_layer_output)
|
86 |
-
if grads is None:
|
87 |
-
print("Erreur: Le gradient est None.")
|
88 |
-
return img_resized
|
89 |
-
|
90 |
-
# Correction de la régression : on moyenne UNIQUEMENT sur les axes spatiaux
|
91 |
-
pooled_grads = tf.reduce_mean(grads, axis=(0, 1))
|
92 |
-
|
93 |
-
last_conv_layer_output = last_conv_layer_output[0]
|
94 |
-
heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
|
95 |
-
heatmap = tf.squeeze(heatmap)
|
96 |
-
heatmap = tf.maximum(heatmap, 0) / (tf.math.reduce_max(heatmap) + 1e-8)
|
97 |
-
heatmap = heatmap.numpy()
|
98 |
-
|
99 |
-
heatmap = np.uint8(255 * heatmap)
|
100 |
-
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
|
101 |
-
|
102 |
-
img_bgr = cv2.cvtColor(img_resized, cv2.COLOR_RGB2BGR)
|
103 |
-
superimposed_img = cv2.addWeighted(img_bgr, 0.6, heatmap, 0.4, 0)
|
104 |
-
|
105 |
-
return cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)
|
106 |
-
|
107 |
-
# ---- Fonction Gradio ----
|
108 |
def gradio_predict(image_pil):
|
109 |
if image_pil is None:
|
110 |
-
return "Veuillez uploader une image.", None
|
111 |
try:
|
112 |
probs = predict_single(image_pil)
|
|
|
113 |
benign_prob = sum(probs[i] for i, cls in enumerate(CLASS_NAMES) if diagnosis_map[cls] == "Bénin")
|
114 |
malign_prob = sum(probs[i] for i, cls in enumerate(CLASS_NAMES) if diagnosis_map[cls] == "Malin")
|
115 |
global_diag = "Bénin" if benign_prob >= malign_prob else "Malin"
|
@@ -120,21 +77,21 @@ def gradio_predict(image_pil):
|
|
120 |
df.index.name = "Classe"
|
121 |
df = df.reset_index()
|
122 |
|
123 |
-
|
124 |
-
|
125 |
|
126 |
-
return global_diag, df, gradcam_img
|
127 |
except Exception as e:
|
128 |
print(f"Erreur dans gradio_predict : {e}")
|
129 |
import traceback
|
130 |
traceback.print_exc()
|
131 |
-
return "Erreur lors du traitement de l'image.", None
|
132 |
|
133 |
-
# ---- Gradio UI ----
|
134 |
example_paths = ["exemple1.jpg", "exemple2.jpg", "exemple3.jpg"]
|
|
|
135 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
136 |
-
gr.Markdown("# Analyse de lésions cutanées (Ensemble de modèles
|
137 |
-
gr.Markdown("Cet outil propose une prédiction de la nature de la lésion (Bénin/Malin)
|
138 |
with gr.Row():
|
139 |
with gr.Column(scale=1):
|
140 |
input_image = gr.Image(type="pil", label="Uploader une image de lésion")
|
@@ -143,8 +100,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
143 |
with gr.Column(scale=1):
|
144 |
output_label = gr.Label(label="Diagnostic global")
|
145 |
output_plot = gr.BarPlot(label="Probabilités par classe", x="Classe", y="Probabilité", y_lim=[0, 1])
|
146 |
-
|
147 |
-
|
|
|
|
|
148 |
|
149 |
if __name__ == "__main__":
|
150 |
demo.launch()
|
|
|
4 |
import cv2
|
5 |
import tensorflow as tf
|
6 |
import keras
|
|
|
7 |
from huggingface_hub import hf_hub_download
|
8 |
import pandas as pd
|
9 |
from PIL import Image
|
|
|
29 |
from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet
|
30 |
from tensorflow.keras.applications.densenet import preprocess_input as preprocess_densenet
|
31 |
|
32 |
+
# ---- Fonctions utilitaires robustes pour les modèles Keras ----
|
33 |
def get_primary_input_name(model):
|
34 |
+
"""Retourne le nom de la couche d'input principale du modèle."""
|
35 |
if isinstance(model.inputs, list) and len(model.inputs) > 0:
|
36 |
return model.inputs[0].name.split(':')[0]
|
37 |
return "input_1"
|
38 |
|
39 |
def safe_forward(model, x):
|
40 |
+
"""Appelle un modèle en utilisant le nom d'input correct pour éviter les UserWarnings."""
|
41 |
input_name = get_primary_input_name(model)
|
42 |
return model({input_name: x}, training=False).numpy()
|
43 |
|
|
|
60 |
preds[:, mel_idx] = (0.5 * preds[:, mel_idx] + 0.5 * pred_d[:, mel_idx])
|
61 |
return preds[0]
|
62 |
|
63 |
+
# ---- Fonction Gradio (Simplifiée) ----
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
def gradio_predict(image_pil):
|
65 |
if image_pil is None:
|
66 |
+
return "Veuillez uploader une image.", None
|
67 |
try:
|
68 |
probs = predict_single(image_pil)
|
69 |
+
|
70 |
benign_prob = sum(probs[i] for i, cls in enumerate(CLASS_NAMES) if diagnosis_map[cls] == "Bénin")
|
71 |
malign_prob = sum(probs[i] for i, cls in enumerate(CLASS_NAMES) if diagnosis_map[cls] == "Malin")
|
72 |
global_diag = "Bénin" if benign_prob >= malign_prob else "Malin"
|
|
|
77 |
df.index.name = "Classe"
|
78 |
df = df.reset_index()
|
79 |
|
80 |
+
# On ne retourne que les informations pour le label et le bar plot
|
81 |
+
return global_diag, df
|
82 |
|
|
|
83 |
except Exception as e:
|
84 |
print(f"Erreur dans gradio_predict : {e}")
|
85 |
import traceback
|
86 |
traceback.print_exc()
|
87 |
+
return "Erreur lors du traitement de l'image.", None
|
88 |
|
89 |
+
# ---- Gradio UI (Simplifiée) ----
|
90 |
example_paths = ["exemple1.jpg", "exemple2.jpg", "exemple3.jpg"]
|
91 |
+
|
92 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
93 |
+
gr.Markdown("# Analyse de lésions cutanées (Ensemble de modèles)")
|
94 |
+
gr.Markdown("Cet outil propose une prédiction de la nature de la lésion (Bénin/Malin).")
|
95 |
with gr.Row():
|
96 |
with gr.Column(scale=1):
|
97 |
input_image = gr.Image(type="pil", label="Uploader une image de lésion")
|
|
|
100 |
with gr.Column(scale=1):
|
101 |
output_label = gr.Label(label="Diagnostic global")
|
102 |
output_plot = gr.BarPlot(label="Probabilités par classe", x="Classe", y="Probabilité", y_lim=[0, 1])
|
103 |
+
# La sortie Grad-CAM a été enlevée
|
104 |
+
|
105 |
+
# L'appel au clic a été mis à jour pour ne gérer que 2 sorties
|
106 |
+
submit_btn.click(fn=gradio_predict, inputs=input_image, outputs=[output_label, output_plot])
|
107 |
|
108 |
if __name__ == "__main__":
|
109 |
demo.launch()
|