ericjedha commited on
Commit
3595c95
·
verified ·
1 Parent(s): e177ec4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -58
app.py CHANGED
@@ -4,7 +4,6 @@ import gradio as gr
4
  import cv2
5
  import tensorflow as tf
6
  import keras
7
- from keras.models import Model
8
  from huggingface_hub import hf_hub_download
9
  import pandas as pd
10
  from PIL import Image
@@ -30,13 +29,15 @@ from tensorflow.keras.applications.xception import preprocess_input as preproces
30
  from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet
31
  from tensorflow.keras.applications.densenet import preprocess_input as preprocess_densenet
32
 
33
- # ---- Fonctions utilitaires ----
34
  def get_primary_input_name(model):
 
35
  if isinstance(model.inputs, list) and len(model.inputs) > 0:
36
  return model.inputs[0].name.split(':')[0]
37
  return "input_1"
38
 
39
  def safe_forward(model, x):
 
40
  input_name = get_primary_input_name(model)
41
  return model({input_name: x}, training=False).numpy()
42
 
@@ -59,57 +60,13 @@ def predict_single(image_pil, weights=(0.45, 0.25, 0.30)):
59
  preds[:, mel_idx] = (0.5 * preds[:, mel_idx] + 0.5 * pred_d[:, mel_idx])
60
  return preds[0]
61
 
62
-
63
- # ---- Grad-CAM (Version Finale avec nom de couche fixe) ----
64
- def make_gradcam(image_pil, model, class_index=None):
65
- # ===== CORRECTION FINALE : On utilise le nom connu et fiable de la dernière couche de features de DenseNet201 =====
66
- last_conv_layer_name = "relu"
67
- # =================================================================================================================
68
-
69
- img_np = np.array(image_pil)
70
- img_resized = cv2.resize(img_np, (224, 224))
71
- img_array_preprocessed = preprocess_densenet(np.expand_dims(img_resized, axis=0))
72
-
73
- grad_model = Model(model.inputs, [model.get_layer(last_conv_layer_name).output, model.output])
74
- input_name = get_primary_input_name(model)
75
- input_for_model = {input_name: img_array_preprocessed}
76
-
77
- with tf.GradientTape() as tape:
78
- last_conv_layer_output, preds = grad_model(input_for_model, training=False)
79
- if isinstance(preds, list):
80
- preds = preds[0]
81
- if class_index is None:
82
- class_index = tf.argmax(preds[0])
83
- class_channel = preds[:, class_index]
84
-
85
- grads = tape.gradient(class_channel, last_conv_layer_output)
86
- if grads is None:
87
- print("Erreur: Le gradient est None.")
88
- return img_resized
89
-
90
- # Correction de la régression : on moyenne UNIQUEMENT sur les axes spatiaux
91
- pooled_grads = tf.reduce_mean(grads, axis=(0, 1))
92
-
93
- last_conv_layer_output = last_conv_layer_output[0]
94
- heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
95
- heatmap = tf.squeeze(heatmap)
96
- heatmap = tf.maximum(heatmap, 0) / (tf.math.reduce_max(heatmap) + 1e-8)
97
- heatmap = heatmap.numpy()
98
-
99
- heatmap = np.uint8(255 * heatmap)
100
- heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
101
-
102
- img_bgr = cv2.cvtColor(img_resized, cv2.COLOR_RGB2BGR)
103
- superimposed_img = cv2.addWeighted(img_bgr, 0.6, heatmap, 0.4, 0)
104
-
105
- return cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)
106
-
107
- # ---- Fonction Gradio ----
108
  def gradio_predict(image_pil):
109
  if image_pil is None:
110
- return "Veuillez uploader une image.", None, None
111
  try:
112
  probs = predict_single(image_pil)
 
113
  benign_prob = sum(probs[i] for i, cls in enumerate(CLASS_NAMES) if diagnosis_map[cls] == "Bénin")
114
  malign_prob = sum(probs[i] for i, cls in enumerate(CLASS_NAMES) if diagnosis_map[cls] == "Malin")
115
  global_diag = "Bénin" if benign_prob >= malign_prob else "Malin"
@@ -120,21 +77,21 @@ def gradio_predict(image_pil):
120
  df.index.name = "Classe"
121
  df = df.reset_index()
122
 
123
- top_class_idx = np.argmax(probs)
124
- gradcam_img = make_gradcam(image_pil, model_densenet, class_index=top_class_idx)
125
 
126
- return global_diag, df, gradcam_img
127
  except Exception as e:
128
  print(f"Erreur dans gradio_predict : {e}")
129
  import traceback
130
  traceback.print_exc()
131
- return "Erreur lors du traitement de l'image.", None, None
132
 
133
- # ---- Gradio UI ----
134
  example_paths = ["exemple1.jpg", "exemple2.jpg", "exemple3.jpg"]
 
135
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
136
- gr.Markdown("# Analyse de lésions cutanées (Ensemble de modèles + Grad-CAM)")
137
- gr.Markdown("Cet outil propose une prédiction de la nature de la lésion (Bénin/Malin) avec explication visuelle.")
138
  with gr.Row():
139
  with gr.Column(scale=1):
140
  input_image = gr.Image(type="pil", label="Uploader une image de lésion")
@@ -143,8 +100,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
143
  with gr.Column(scale=1):
144
  output_label = gr.Label(label="Diagnostic global")
145
  output_plot = gr.BarPlot(label="Probabilités par classe", x="Classe", y="Probabilité", y_lim=[0, 1])
146
- output_gradcam = gr.Image(label="Visualisation Grad-CAM")
147
- submit_btn.click(fn=gradio_predict, inputs=input_image, outputs=[output_label, output_plot, output_gradcam])
 
 
148
 
149
  if __name__ == "__main__":
150
  demo.launch()
 
4
  import cv2
5
  import tensorflow as tf
6
  import keras
 
7
  from huggingface_hub import hf_hub_download
8
  import pandas as pd
9
  from PIL import Image
 
29
  from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet
30
  from tensorflow.keras.applications.densenet import preprocess_input as preprocess_densenet
31
 
32
+ # ---- Fonctions utilitaires robustes pour les modèles Keras ----
33
  def get_primary_input_name(model):
34
+ """Retourne le nom de la couche d'input principale du modèle."""
35
  if isinstance(model.inputs, list) and len(model.inputs) > 0:
36
  return model.inputs[0].name.split(':')[0]
37
  return "input_1"
38
 
39
  def safe_forward(model, x):
40
+ """Appelle un modèle en utilisant le nom d'input correct pour éviter les UserWarnings."""
41
  input_name = get_primary_input_name(model)
42
  return model({input_name: x}, training=False).numpy()
43
 
 
60
  preds[:, mel_idx] = (0.5 * preds[:, mel_idx] + 0.5 * pred_d[:, mel_idx])
61
  return preds[0]
62
 
63
+ # ---- Fonction Gradio (Simplifiée) ----
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def gradio_predict(image_pil):
65
  if image_pil is None:
66
+ return "Veuillez uploader une image.", None
67
  try:
68
  probs = predict_single(image_pil)
69
+
70
  benign_prob = sum(probs[i] for i, cls in enumerate(CLASS_NAMES) if diagnosis_map[cls] == "Bénin")
71
  malign_prob = sum(probs[i] for i, cls in enumerate(CLASS_NAMES) if diagnosis_map[cls] == "Malin")
72
  global_diag = "Bénin" if benign_prob >= malign_prob else "Malin"
 
77
  df.index.name = "Classe"
78
  df = df.reset_index()
79
 
80
+ # On ne retourne que les informations pour le label et le bar plot
81
+ return global_diag, df
82
 
 
83
  except Exception as e:
84
  print(f"Erreur dans gradio_predict : {e}")
85
  import traceback
86
  traceback.print_exc()
87
+ return "Erreur lors du traitement de l'image.", None
88
 
89
+ # ---- Gradio UI (Simplifiée) ----
90
  example_paths = ["exemple1.jpg", "exemple2.jpg", "exemple3.jpg"]
91
+
92
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
93
+ gr.Markdown("# Analyse de lésions cutanées (Ensemble de modèles)")
94
+ gr.Markdown("Cet outil propose une prédiction de la nature de la lésion (Bénin/Malin).")
95
  with gr.Row():
96
  with gr.Column(scale=1):
97
  input_image = gr.Image(type="pil", label="Uploader une image de lésion")
 
100
  with gr.Column(scale=1):
101
  output_label = gr.Label(label="Diagnostic global")
102
  output_plot = gr.BarPlot(label="Probabilités par classe", x="Classe", y="Probabilité", y_lim=[0, 1])
103
+ # La sortie Grad-CAM a été enlevée
104
+
105
+ # L'appel au clic a été mis à jour pour ne gérer que 2 sorties
106
+ submit_btn.click(fn=gradio_predict, inputs=input_image, outputs=[output_label, output_plot])
107
 
108
  if __name__ == "__main__":
109
  demo.launch()