ericjedha commited on
Commit
825c5d4
·
verified ·
1 Parent(s): b83f95d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -39
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import numpy as np
2
  import gradio as gr
3
  import cv2
@@ -19,14 +22,16 @@ diagnosis_map = {
19
  'vasc': 'Bénin',
20
  'mel': 'Malin'
21
  }
 
22
 
23
- # ---- Chargement modèles ----
24
  resnet_path = hf_hub_download(repo_id="ericjedha/resnet50", filename="Resnet50.keras")
25
  densenet_path = hf_hub_download(repo_id="ericjedha/densenet201", filename="Densenet201.keras")
26
 
27
- model_xcept = keras.saving.load_model("Xception.keras", compile=False) # si local
28
  model_resnet50 = keras.saving.load_model(resnet_path, compile=False)
29
  model_densenet = keras.saving.load_model(densenet_path, compile=False)
 
30
 
31
  # ---- Préprocesseurs ----
32
  from tensorflow.keras.applications.xception import preprocess_input as preprocess_xception
@@ -37,27 +42,32 @@ def load_image(path, target_size):
37
  img = image.load_img(path, target_size=target_size)
38
  return image.img_to_array(img)
39
 
 
 
 
 
 
 
 
 
 
 
40
  # ---- Prédiction single image ----
41
  def predict_single(img_path, weights=(0.45, 0.25, 0.30)):
42
  bx = preprocess_xception(np.expand_dims(load_image(img_path, (299, 299)), axis=0))
43
  br = preprocess_resnet(np.expand_dims(load_image(img_path, (224, 224)), axis=0))
44
  bd = preprocess_densenet(np.expand_dims(load_image(img_path, (224, 224)), axis=0))
45
 
46
- pred_x = model_xcept.predict(bx, verbose=0)
47
- pred_r = model_resnet50.predict(br, verbose=0)
48
- pred_d = model_densenet.predict(bd, verbose=0)
49
-
50
- # ⚡ Assurer shape (1,7)
51
- for p in [pred_x, pred_r, pred_d]:
52
- if p.ndim == 1:
53
- p = np.expand_dims(p, axis=0)
54
 
55
  preds = (weights[0] * pred_x + weights[1] * pred_r + weights[2] * pred_d)
56
 
57
- # boost MEL avec DenseNet
58
- mel_idx = label_to_index['mel']
59
- if mel_idx < pred_d.shape[1]:
60
- preds[:, mel_idx] = 0.5 * preds[:, mel_idx] + 0.5 * pred_d[:, mel_idx]
61
 
62
  return preds[0]
63
 
@@ -70,10 +80,7 @@ def make_gradcam(img_path, model, last_conv_layer_name="conv5_block32_concat", c
70
 
71
  if class_index is None:
72
  preds = model.predict(input_array)
73
- if isinstance(preds, list):
74
- preds = preds[0]
75
- if preds.ndim == 1:
76
- preds = np.expand_dims(preds, axis=0)
77
  class_index = np.argmax(preds[0])
78
 
79
  grad_model = Model(inputs=model.inputs, outputs=[
@@ -83,24 +90,28 @@ def make_gradcam(img_path, model, last_conv_layer_name="conv5_block32_concat", c
83
 
84
  with tf.GradientTape() as tape:
85
  conv_outputs, predictions = grad_model(input_array)
 
 
 
 
86
  predictions = tf.convert_to_tensor(predictions)
87
- if predictions.ndim == 1:
88
- predictions = tf.expand_dims(predictions, axis=0)
89
- loss = predictions[:, class_index]
90
 
91
- grads = tape.gradient(loss, conv_outputs)[0] # (H, W, C)
92
- pooled_grads = tf.reduce_mean(grads, axis=(0, 1)) # (C,)
93
 
94
- conv_outputs = conv_outputs[0] # (H, W, C)
95
- heatmap = tf.reduce_sum(conv_outputs * pooled_grads, axis=-1) # (H, W)
96
 
 
 
 
97
  heatmap = np.maximum(heatmap, 0) / (np.max(heatmap) + 1e-6)
 
98
  heatmap = cv2.resize(heatmap.numpy(), (224, 224))
99
  heatmap = np.uint8(255 * heatmap)
100
  heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
101
 
102
  superimposed_img = cv2.addWeighted(
103
- cv2.cvtColor(img_array.astype("uint8"), cv2.COLOR_RGB2BGR),
104
  0.6, heatmap, 0.4, 0
105
  )
106
  return cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)
@@ -109,29 +120,23 @@ def make_gradcam(img_path, model, last_conv_layer_name="conv5_block32_concat", c
109
  def gradio_predict(image_file):
110
  probs = predict_single(image_file)
111
 
112
- # Classe globale
 
 
 
113
  benign_prob = sum(probs[i] for i, cls in enumerate(CLASS_NAMES) if diagnosis_map[cls] == "Bénin")
114
  malign_prob = sum(probs[i] for i, cls in enumerate(CLASS_NAMES) if diagnosis_map[cls] == "Malin")
115
  global_diag = "Bénin" if benign_prob >= malign_prob else "Malin"
116
 
117
- # BarPlot
118
- sorted_idx = np.argsort(-probs)
119
- sorted_labels = [CLASS_NAMES[i].upper() for i in sorted_idx]
120
- sorted_probs = (probs[sorted_idx] * 100).tolist()
121
- bar_data = {"Classes": sorted_labels, "Probabilité (%)": sorted_probs}
122
 
123
- # Grad-CAM
124
  top_class = np.argmax(probs)
125
  gradcam_img = make_gradcam(image_file, model_densenet, class_index=top_class)
126
 
127
- return global_diag, bar_data, gradcam_img
128
 
129
  # ---- Gradio UI ----
130
- examples = [
131
- "exemple1.jpg",
132
- "exemple2.jpg",
133
- "exemple3.jpg"
134
- ]
135
 
136
  demo = gr.Interface(
137
  fn=gradio_predict,
 
1
+ import os
2
+ #os.environ["KERAS_BACKEND"] = "jax"
3
+
4
  import numpy as np
5
  import gradio as gr
6
  import cv2
 
22
  'vasc': 'Bénin',
23
  'mel': 'Malin'
24
  }
25
+ NUM_CLASSES = len(CLASS_NAMES)
26
 
27
+ # ---- Téléchargement des modèles depuis Hugging Face ----
28
  resnet_path = hf_hub_download(repo_id="ericjedha/resnet50", filename="Resnet50.keras")
29
  densenet_path = hf_hub_download(repo_id="ericjedha/densenet201", filename="Densenet201.keras")
30
 
31
+ # ---- Chargement modèles ----
32
  model_resnet50 = keras.saving.load_model(resnet_path, compile=False)
33
  model_densenet = keras.saving.load_model(densenet_path, compile=False)
34
+ model_xcept = keras.saving.load_model("Xception.keras", compile=False)
35
 
36
  # ---- Préprocesseurs ----
37
  from tensorflow.keras.applications.xception import preprocess_input as preprocess_xception
 
42
  img = image.load_img(path, target_size=target_size)
43
  return image.img_to_array(img)
44
 
45
+ # ---- Sécurité sur les prédictions ----
46
+ def safe_preds(pred_array, target_len=NUM_CLASSES):
47
+ pred_array = np.array(pred_array)
48
+ if pred_array.shape[1] < target_len:
49
+ zeros = np.zeros((pred_array.shape[0], target_len - pred_array.shape[1]))
50
+ pred_array = np.concatenate([pred_array, zeros], axis=1)
51
+ elif pred_array.shape[1] > target_len:
52
+ pred_array = pred_array[:, :target_len]
53
+ return pred_array
54
+
55
  # ---- Prédiction single image ----
56
  def predict_single(img_path, weights=(0.45, 0.25, 0.30)):
57
  bx = preprocess_xception(np.expand_dims(load_image(img_path, (299, 299)), axis=0))
58
  br = preprocess_resnet(np.expand_dims(load_image(img_path, (224, 224)), axis=0))
59
  bd = preprocess_densenet(np.expand_dims(load_image(img_path, (224, 224)), axis=0))
60
 
61
+ pred_x = safe_preds(model_xcept.predict(bx, verbose=0))
62
+ pred_r = safe_preds(model_resnet50.predict(br, verbose=0))
63
+ pred_d = safe_preds(model_densenet.predict(bd, verbose=0))
 
 
 
 
 
64
 
65
  preds = (weights[0] * pred_x + weights[1] * pred_r + weights[2] * pred_d)
66
 
67
+ # boost MEL avec DenseNet si possible
68
+ mel_idx = label_to_index.get('mel', None)
69
+ if mel_idx is not None and mel_idx < preds.shape[1] and mel_idx < pred_d.shape[1]:
70
+ preds[:, mel_idx] = (0.5 * preds[:, mel_idx] + 0.5 * pred_d[:, mel_idx])
71
 
72
  return preds[0]
73
 
 
80
 
81
  if class_index is None:
82
  preds = model.predict(input_array)
83
+ preds = np.array(preds)
 
 
 
84
  class_index = np.argmax(preds[0])
85
 
86
  grad_model = Model(inputs=model.inputs, outputs=[
 
90
 
91
  with tf.GradientTape() as tape:
92
  conv_outputs, predictions = grad_model(input_array)
93
+
94
+ # sécuriser predictions
95
+ if isinstance(predictions, list):
96
+ predictions = predictions[0]
97
  predictions = tf.convert_to_tensor(predictions)
 
 
 
98
 
99
+ loss = predictions[:, class_index]
 
100
 
101
+ grads = tape.gradient(loss, conv_outputs)[0]
102
+ pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
103
 
104
+ conv_outputs = conv_outputs[0]
105
+ heatmap = conv_outputs @ pooled_grads[..., tf.newaxis]
106
+ heatmap = tf.squeeze(heatmap)
107
  heatmap = np.maximum(heatmap, 0) / (np.max(heatmap) + 1e-6)
108
+
109
  heatmap = cv2.resize(heatmap.numpy(), (224, 224))
110
  heatmap = np.uint8(255 * heatmap)
111
  heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
112
 
113
  superimposed_img = cv2.addWeighted(
114
+ cv2.cvtColor(img_array.astype("uint8"), cv2.COLOR_RGB2BGR),
115
  0.6, heatmap, 0.4, 0
116
  )
117
  return cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)
 
120
  def gradio_predict(image_file):
121
  probs = predict_single(image_file)
122
 
123
+ sorted_idx = np.argsort(-probs)
124
+ sorted_labels = [CLASS_NAMES[i].upper() for i in sorted_idx]
125
+ sorted_probs = probs[sorted_idx] * 100
126
+
127
  benign_prob = sum(probs[i] for i, cls in enumerate(CLASS_NAMES) if diagnosis_map[cls] == "Bénin")
128
  malign_prob = sum(probs[i] for i, cls in enumerate(CLASS_NAMES) if diagnosis_map[cls] == "Malin")
129
  global_diag = "Bénin" if benign_prob >= malign_prob else "Malin"
130
 
131
+ bar_data = {"Classes": sorted_labels, "Probabilité (%)": sorted_probs.tolist()}
 
 
 
 
132
 
 
133
  top_class = np.argmax(probs)
134
  gradcam_img = make_gradcam(image_file, model_densenet, class_index=top_class)
135
 
136
+ return global_diag, gr.BarPlot.update(value=bar_data, x="Classes", y="Probabilité (%)", title="Distribution des classes"), gradcam_img
137
 
138
  # ---- Gradio UI ----
139
+ examples = ["exemple1.jpg", "exemple2.jpg", "exemple3.jpg"]
 
 
 
 
140
 
141
  demo = gr.Interface(
142
  fn=gradio_predict,