ericjedha commited on
Commit
68b6287
·
verified ·
1 Parent(s): 94b316b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -173
app.py CHANGED
@@ -1,184 +1,191 @@
1
- import os, cv2, time, numpy as np
2
- import tensorflow as tf
3
- from tensorflow.keras.applications import (
4
- Xception, ResNet50, DenseNet201,
5
- xception, resnet50, densenet
6
- )
7
- from tensorflow.keras.preprocessing import image as keras_image
8
  import gradio as gr
 
9
  import matplotlib.pyplot as plt
10
-
11
- # --- GPU setup ---
12
- gpus = tf.config.list_physical_devices("GPU")
13
- if gpus:
14
- try:
15
- for gpu in gpus:
16
- tf.config.experimental.set_memory_growth(gpu, True)
17
- print("✅ GPU activé :", gpus)
18
- except RuntimeError as e:
19
- print("⚠️ Erreur GPU:", e)
20
- else:
21
- print(" Aucun GPU détecté.")
22
-
23
- # --- Utils ---
24
- def _update_progress(progress, start, end, desc, steps=20, delay=0.01):
25
- """Anime la barre de progression de start→end en plusieurs ticks"""
26
- if progress is None:
 
 
 
 
 
 
27
  return
28
- for i in range(steps+1):
29
- val = start + (end-start)*i/steps
30
- try:
31
- progress(val/100, desc=desc)
32
- except Exception:
33
- try:
34
- progress(val/100, desc)
35
- except:
36
- pass
37
- time.sleep(delay)
38
-
39
- # --- Models ---
40
- model_xcept, model_resnet50, model_densenet = None, None, None
41
- try: model_xcept = Xception(weights="imagenet"); print("Xception OK")
42
- except Exception as e: print("Xception KO", e)
43
- try: model_resnet50 = ResNet50(weights="imagenet"); print("ResNet50 OK")
44
- except Exception as e: print("ResNet50 KO", e)
45
- try: model_densenet = DenseNet201(weights="imagenet"); print("DenseNet201 OK")
46
- except Exception as e: print("DenseNet KO", e)
47
-
48
- LAST_CONV_LAYERS = {"xception":"block14_sepconv2_act",
49
- "resnet50":"conv5_block3_out",
50
- "densenet201":"conv5_block32_concat"}
51
-
52
- # --- Preprocess ---
53
- def preprocess_input(img_path, model_name):
54
- img_size = {"xception":299, "resnet50":224, "densenet201":224}[model_name]
55
- img = keras_image.load_img(img_path, target_size=(img_size,img_size))
56
- x = keras_image.img_to_array(img)
57
- x = np.expand_dims(x, axis=0)
58
- if model_name=="xception": return xception.preprocess_input(x)
59
- if model_name=="resnet50": return resnet50.preprocess_input(x)
60
- return densenet.preprocess_input(x)
61
-
62
- def decode_preds(preds, model_name, top=3):
63
- if model_name=="xception": return xception.decode_predictions(preds, top=top)[0]
64
- if model_name=="resnet50": return resnet50.decode_predictions(preds, top=top)[0]
65
- return densenet.decode_predictions(preds, top=top)[0]
66
-
67
- # --- GradCAM ---
68
- def make_gradcam(img_pil, model, last_conv, class_idx, progress=None):
69
  try:
70
- _update_progress(progress,0,25,"Prétraitement image")
71
- input_size = model.input_shape[1:3]
72
- img_resized = img_pil.resize(input_size)
73
- x = np.expand_dims(np.array(img_resized).astype("float32"),0)
74
- preprocess_fn = (xception.preprocess_input if x.shape[1]==299
75
- else resnet50.preprocess_input)
76
- x = preprocess_fn(x)
77
-
78
- _update_progress(progress,25,50,"Calcul gradients")
79
- grad_model = tf.keras.models.Model([model.inputs],
80
- [model.get_layer(last_conv).output,
81
- model.output])
82
- with tf.GradientTape() as tape:
83
- conv_outputs, preds = grad_model(x)
84
- loss = preds[:, class_idx]
85
- grads = tape.gradient(loss, conv_outputs)
86
- pooled = tf.reduce_mean(grads, axis=(0,1,2)).numpy()
87
- conv_out = conv_outputs[0].numpy()
88
- heatmap = np.dot(conv_out, pooled)
89
- heatmap = np.maximum(heatmap,0)/np.max(heatmap)
90
-
91
- _update_progress(progress,50,75,"Génération heatmap")
92
- heatmap = cv2.resize(heatmap, img_pil.size)
93
- heatmap = np.uint8(255*heatmap)
94
- heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
95
-
96
- _update_progress(progress,75,100,"Fusion image/heatmap")
97
- img_bgr = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
98
- superimposed = cv2.addWeighted(img_bgr,0.6,heatmap,0.4,0)
99
- return cv2.cvtColor(superimposed, cv2.COLOR_BGR2RGB)
100
- except Exception as e:
101
- print("GradCAM Error:",e)
102
- return np.array(img_pil)
103
-
104
- # --- Global state ---
105
- current_image=None
106
- current_preds=None
107
-
108
- # --- Prediction ---
109
- def quick_predict_ui(img_pil, progress=gr.Progress()):
110
- global current_image, current_preds
111
- if img_pil is None: return {}, plt.figure(), "❌ Aucune image"
112
-
113
- current_image = img_pil
114
- current_preds={}
115
- results, fig = {}, plt.figure()
116
- ax=fig.add_subplot(111)
117
-
118
- for model_name, model in [("xception",model_xcept),
119
- ("resnet50",model_resnet50),
120
- ("densenet201",model_densenet)]:
121
- if model is None: continue
122
- _update_progress(progress,0,30,f"Prédiction {model_name}")
123
- tmp="tmp.jpg"; img_pil.save(tmp)
124
- x = preprocess_input(tmp, model_name)
125
- preds = model.predict(x,verbose=0)
126
- current_preds[model_name]=preds[0]
127
- top = decode_preds(preds, model_name, top=3)
128
- for cls,desc,prob in top: results[desc]=float(prob)
129
- ax.barh([desc for _,desc,prob in top],
130
- [prob for _,desc,prob in top], label=model_name)
131
- _update_progress(progress,30,100,f"{model_name} terminé")
132
-
133
- # ensemble
134
- if len(current_preds)>1:
135
- avg = np.mean(list(current_preds.values()),axis=0)
136
- current_preds["ensemble"]=avg
137
- idx=np.argmax(avg)
138
- label=f"Ensemble : {idx} ({avg[idx]:.2%})"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  else:
140
- label="Prédiction modèle unique"
141
 
142
- ax.legend(); plt.tight_layout()
143
- return {"label":label}, fig, "✅ Prédictions terminées"
144
 
145
- # --- GradCAM UI ---
146
- def generate_gradcam_ui(progress=gr.Progress()):
147
- global current_image,current_preds
148
- if current_image is None or current_preds is None:
149
- return None,"❌ Pas d'image/prédiction"
150
- try:
151
- ensemble=current_preds.get("ensemble")
152
- if ensemble is None: return None,"❌ Pas d'Ensemble"
153
- class_idx=int(np.argmax(ensemble))
154
- candidates=[]
155
- if model_xcept: candidates.append(("xception",model_xcept,current_preds["xception"][class_idx]))
156
- if model_resnet50: candidates.append(("resnet50",model_resnet50,current_preds["resnet50"][class_idx]))
157
- if model_densenet: candidates.append(("densenet201",model_densenet,current_preds["densenet201"][class_idx]))
158
- if not candidates: return None,"❌ Aucun modèle dispo"
159
- name,model,conf=max(candidates,key=lambda t:t[2])
160
- layer=LAST_CONV_LAYERS[name]
161
- gradcam=make_gradcam(current_image,model,layer,class_idx,progress=progress)
162
- return gradcam,f"✅ GradCAM {name} ({conf:.1%})"
163
- except Exception as e:
164
- return None,f"❌ Erreur: {e}"
165
-
166
- # --- UI ---
167
- with gr.Blocks(title="Diag rapide") as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  gr.Markdown("## 🧠 Diagnostic IA avec Grad-CAM")
 
169
  with gr.Row():
170
- with gr.Column(scale=1):
171
- input_img=gr.Image(type="pil",label="🩻 Image à analyser")
172
- btn_predict=gr.Button("⚡ Analyse rapide")
173
- btn_gradcam=gr.Button("🎨 Générer Grad-CAM")
174
- with gr.Column(scale=2):
175
- out_label=gr.Label()
176
- out_plot=gr.Plot()
177
- out_grad=gr.Image()
178
- out_status=gr.Textbox(label="Statut",interactive=False)
179
- btn_predict.click(quick_predict_ui,inputs=input_img,
180
- outputs=[out_label,out_plot,out_status])
181
- btn_gradcam.click(generate_gradcam_ui,inputs=[],
182
- outputs=[out_grad,out_status])
 
 
 
 
 
 
 
183
 
184
  demo.launch()
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import numpy as np
3
  import matplotlib.pyplot as plt
4
+ import cv2
5
+ import tensorflow as tf
6
+ from tensorflow.keras.models import load_model
7
+ import os
8
+
9
+ # -----------------------
10
+ # Chargement des modèles
11
+ # -----------------------
12
+ MODELS_DIR = "models"
13
+ AVAILABLE_MODELS = {
14
+ "ResNet50": os.path.join(MODELS_DIR, "resnet50.h5"),
15
+ "EfficientNetB0": os.path.join(MODELS_DIR, "efficientnetb0.h5"),
16
+ "MobileNetV2": os.path.join(MODELS_DIR, "mobilenetv2.h5"),
17
+ }
18
+ loaded_models = {}
19
+ current_preds = {}
20
+
21
+ # -----------------------
22
+ # Progress helper
23
+ # -----------------------
24
+ def _update_progress(progress, step, total=100, desc=None):
25
+ """Met à jour la barre de progression uniformisée (0-100)."""
26
+ if progress is None:
27
  return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  try:
29
+ val = float(step) / float(total)
30
+ val = min(max(val, 0.0), 1.0)
31
+ except Exception:
32
+ val = 0.0
33
+ try:
34
+ if desc:
35
+ progress(val, desc=desc)
36
+ else:
37
+ progress(val)
38
+ except Exception:
39
+ pass
40
+
41
+ # -----------------------
42
+ # Prétraitement
43
+ # -----------------------
44
+ def preprocess_image(image, target_size=(224, 224)):
45
+ img = cv2.resize(image, target_size)
46
+ img = img / 255.0
47
+ return np.expand_dims(img, axis=0)
48
+
49
+ # -----------------------
50
+ # Grad-CAM
51
+ # -----------------------
52
+ def make_gradcam(model, img_array, layer_name, progress=None):
53
+ _update_progress(progress, 10, desc="Préparation...")
54
+ grad_model = tf.keras.models.Model(
55
+ [model.inputs], [model.get_layer(layer_name).output, model.output]
56
+ )
57
+
58
+ with tf.GradientTape() as tape:
59
+ conv_outputs, predictions = grad_model(img_array)
60
+ pred_index = tf.argmax(predictions[0])
61
+ loss = predictions[:, pred_index]
62
+
63
+ _update_progress(progress, 40, desc="Calcul des gradients...")
64
+
65
+ grads = tape.gradient(loss, conv_outputs)
66
+ pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
67
+ conv_outputs = conv_outputs[0]
68
+
69
+ _update_progress(progress, 70, desc="Génération de la heatmap...")
70
+
71
+ heatmap = tf.reduce_mean(tf.multiply(pooled_grads, conv_outputs), axis=-1)
72
+ heatmap = np.maximum(heatmap, 0)
73
+ heatmap /= np.max(heatmap) if np.max(heatmap) != 0 else 1
74
+
75
+ heatmap = cv2.resize(heatmap, (img_array.shape[2], img_array.shape[1]))
76
+ heatmap = np.uint8(255 * heatmap)
77
+ heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
78
+
79
+ _update_progress(progress, 90, desc="Application overlay...")
80
+
81
+ superimposed_img = heatmap * 0.4 + img_array[0] * 255
82
+ _update_progress(progress, 100, desc="✅ Grad-CAM terminé !")
83
+
84
+ return np.uint8(superimposed_img)
85
+
86
+ # -----------------------
87
+ # UI : prédiction rapide
88
+ # -----------------------
89
+ def quick_predict_ui(image, model_names, progress=gr.Progress()):
90
+ global current_preds
91
+ current_preds = {}
92
+ if image is None:
93
+ return "⚠️ Pas d'image", None, None
94
+
95
+ _update_progress(progress, 5, desc="Prétraitement...")
96
+ img_array = preprocess_image(image)
97
+
98
+ fig, ax = plt.subplots()
99
+ ax.set_title("Prédictions modèles")
100
+ bar_labels, bar_values = [], []
101
+
102
+ step = 30
103
+ for idx, name in enumerate(model_names):
104
+ _update_progress(progress, 5 + step * (idx+1) / len(model_names),
105
+ desc=f"Prédiction avec {name}...")
106
+
107
+ if name not in loaded_models:
108
+ loaded_models[name] = load_model(AVAILABLE_MODELS[name])
109
+ model = loaded_models[name]
110
+
111
+ preds = model.predict(img_array, verbose=0)[0]
112
+ current_preds[name] = preds
113
+ bar_labels.append(name)
114
+ bar_values.append(float(np.max(preds)))
115
+
116
+ _update_progress(progress, 70, desc="Agrégation...")
117
+
118
+ if len(current_preds) > 1:
119
+ avg = np.mean(list(current_preds.values()), axis=0)
120
+ current_preds["ensemble"] = avg
121
+ idx = np.argmax(avg)
122
+ label = f"Ensemble : {idx} ({avg[idx]:.2%})"
123
  else:
124
+ label = "Prédiction modèle unique"
125
 
126
+ ax.bar(bar_labels, bar_values)
127
+ ax.set_ylabel("Confiance max")
128
 
129
+ _update_progress(progress, 100, desc="✅ Prédiction terminée !")
130
+
131
+ return label, fig, "✅ Prédictions terminées"
132
+
133
+ # -----------------------
134
+ # UI : Grad-CAM
135
+ # -----------------------
136
+ def generate_gradcam_ui(image, explainer_model_name, progress=gr.Progress()):
137
+ if image is None:
138
+ return None, "⚠️ Pas d'image"
139
+
140
+ _update_progress(progress, 0, desc="Démarrage Grad-CAM...")
141
+
142
+ if explainer_model_name not in loaded_models:
143
+ loaded_models[explainer_model_name] = load_model(AVAILABLE_MODELS[explainer_model_name])
144
+ model = loaded_models[explainer_model_name]
145
+
146
+ img_array = preprocess_image(image)
147
+ _update_progress(progress, 20, desc=f"Génération Grad-CAM ({explainer_model_name})...")
148
+
149
+ layer_name = None
150
+ for lname in reversed([l.name for l in model.layers]):
151
+ if "conv" in lname or "block" in lname:
152
+ layer_name = lname
153
+ break
154
+ if not layer_name:
155
+ return None, "⚠️ Pas de couche conv trouvée"
156
+
157
+ cam_img = make_gradcam(model, img_array, layer_name, progress=progress)
158
+
159
+ _update_progress(progress, 100, desc="✅ Grad-CAM généré !")
160
+
161
+ return cam_img, "✅ Grad-CAM généré"
162
+
163
+ # -----------------------
164
+ # Interface Gradio
165
+ # -----------------------
166
+ with gr.Blocks() as demo:
167
  gr.Markdown("## 🧠 Diagnostic IA avec Grad-CAM")
168
+
169
  with gr.Row():
170
+ with gr.Column():
171
+ img_input = gr.Image(label="Image à analyser", type="numpy")
172
+ model_selector = gr.CheckboxGroup(choices=list(AVAILABLE_MODELS.keys()),
173
+ value=["ResNet50"], label="Modèles à utiliser")
174
+ btn_predict = gr.Button("⚡ Analyse rapide")
175
+ btn_gradcam = gr.Button("🎨 Générer Grad-CAM")
176
+
177
+ with gr.Column():
178
+ label_out = gr.Label(label="Étiquette")
179
+ plot_out = gr.Plot(label="Histogramme")
180
+ status_out = gr.Textbox(label="Statut")
181
+
182
+ gradcam_out = gr.Image(label="Grad-CAM")
183
+
184
+ btn_predict.click(fn=quick_predict_ui,
185
+ inputs=[img_input, model_selector],
186
+ outputs=[label_out, plot_out, status_out])
187
+ btn_gradcam.click(fn=generate_gradcam_ui,
188
+ inputs=[img_input, model_selector],
189
+ outputs=[gradcam_out, status_out])
190
 
191
  demo.launch()