import os
import numpy as np
import gradio as gr
import cv2
import tensorflow as tf
import keras
from keras.models import Model
from huggingface_hub import hf_hub_download
import pandas as pd
from PIL import Image

# ---- Configuration ----
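# Lesion class codes (HAM10000 convention): akiec = actinic keratoses / intraepithelial
# carcinoma, bcc = basal cell carcinoma, bkl = benign keratosis-like lesions,
# df = dermatofibroma, nv = melanocytic nevi, vasc = vascular lesions, mel = melanoma.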
CLASS_NAMES = ['akiec', 'bcc', 'bkl', 'df', 'nv', 'vasc', 'mel']
label_to_index = {name: i for i, name in enumerate(CLASS_NAMES)}
diagnosis_map = {
    'akiec': 'Bénin', 'bcc': 'Malin', 'bkl': 'Bénin', 'df': 'Bénin',
    'nv': 'Bénin', 'vasc': 'Bénin', 'mel': 'Malin'
}

# ---- Model downloads ----
resnet_path = hf_hub_download(repo_id="ericjedha/resnet50", filename="Resnet50.keras")
densenet_path = hf_hub_download(repo_id="ericjedha/densenet201", filename="Densenet201.keras")

model_resnet50 = keras.saving.load_model(resnet_path, compile=False)
model_densenet = keras.saving.load_model(densenet_path, compile=False)
model_xcept = keras.saving.load_model("Xception.keras", compile=False)

# ---- Preprocessors ----
from tensorflow.keras.applications.xception import preprocess_input as preprocess_xception
from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet
from tensorflow.keras.applications.densenet import preprocess_input as preprocess_densenet

# ---- Utility functions ----
def get_primary_input_name(model):
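    """Return the name of the model's first input tensor, falling back to 'input_1'."""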
    if isinstance(model.inputs, list) and len(model.inputs) > 0:
        return model.inputs[0].name.split(':')[0]
    return "input_1" 

def safe_forward(model, x):
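    """Run an inference-mode forward pass, passing the input as a dict keyed by the model's primary input name."""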
    input_name = get_primary_input_name(model)
    return model({input_name: x}, training=False) 

# ---- Prediction ----
def predict_single(image_pil, weights=(0.45, 0.25, 0.30)):
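    """Run the three backbones on a PIL image and blend their softmax outputs.

    The image is resized to 299x299 for Xception and 224x224 for ResNet50/DenseNet201.
    `weights` applies to (Xception, ResNet50, DenseNet201) in that order; the melanoma
    ('mel') probability is then re-weighted 50/50 with DenseNet201's own score.
    Returns a dict with the ensemble vector and each model's probability vector.
    """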
    img_np = np.array(image_pil.convert("RGB"))  # make sure the array is 3-channel RGB
    img_299_arr = np.expand_dims(cv2.resize(img_np, (299, 299)), axis=0)
    img_224_arr = np.expand_dims(cv2.resize(img_np, (224, 224)), axis=0)
    
    pred_x_tensor = safe_forward(model_xcept, preprocess_xception(img_299_arr))
    pred_r_tensor = safe_forward(model_resnet50, preprocess_resnet(img_224_arr))
    pred_d_tensor = safe_forward(model_densenet, preprocess_densenet(img_224_arr))

    pred_x, pred_r, pred_d = pred_x_tensor.numpy(), pred_r_tensor.numpy(), pred_d_tensor.numpy()
    
    preds_ensemble = (weights[0] * pred_x + weights[1] * pred_r + weights[2] * pred_d)
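    # Boost DenseNet201's influence on melanoma: average the ensemble 'mel' probability
    # with DenseNet201's own 'mel' probability.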
    mel_idx = label_to_index['mel']
    preds_ensemble[:, mel_idx] = (0.5 * preds_ensemble[:, mel_idx] + 0.5 * pred_d[:, mel_idx])
    
    return {
        "ensemble": preds_ensemble[0], "xception": pred_x[0],
        "resnet50": pred_r[0], "densenet201": pred_d[0]
    }
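
# Hypothetical local usage sketch (the file name is illustrative, assumes an RGB image on disk):
#   probs = predict_single(Image.open("lesion.jpg"))["ensemble"]
#   top = CLASS_NAMES[int(np.argmax(probs))]
#   print(top, diagnosis_map[top])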

# ---- Grad-CAM ----
def make_gradcam(image_pil, model, last_conv_layer_name, class_index):
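    """Return a Grad-CAM overlay for `class_index` computed on `last_conv_layer_name`.

    A sub-model exposes the target conv layer's activations alongside the predictions;
    the activations are weighted by the pooled gradients of the class score, normalized,
    and blended onto the resized input image. On any failure the plain resized image is
    returned instead of an overlay.
    """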
    input_size = model.input_shape[1:3]
    img_np = np.array(image_pil.convert("RGB"))  # make sure the array is 3-channel RGB
    img_resized = cv2.resize(img_np, input_size)

    if 'xception' in model.name: 
        preprocessor = preprocess_xception
    elif 'resnet50' in model.name: 
        preprocessor = preprocess_resnet
    else: 
        preprocessor = preprocess_densenet
    
    img_array_preprocessed = preprocessor(np.expand_dims(img_resized, axis=0))
    
    # Check that the requested layer exists
    try:
        conv_layer = model.get_layer(last_conv_layer_name)
    except ValueError:
        print(f"Couche '{last_conv_layer_name}' non trouvée dans le modèle")
        return img_resized
    
    grad_model = Model(model.inputs, [conv_layer.output, model.output])
    input_name = get_primary_input_name(model)
    input_for_model = {input_name: img_array_preprocessed}

    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(input_for_model, training=False)
        if isinstance(preds, list): 
            preds = preds[0]
        class_channel = preds[:, class_index]

    grads = tape.gradient(class_channel, last_conv_layer_output)
    
    # Safety checks
    if grads is None:
        print("Gradients are None - returning the original image")
        return img_resized

    # Check for NaN or inf values
    if tf.reduce_any(tf.math.is_nan(grads)) or tf.reduce_any(tf.math.is_inf(grads)):
        print("Gradients contain NaN/inf - returning the original image")
        return img_resized

    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    last_conv_layer_output = last_conv_layer_output[0]
    
    # Compute the class-activation heatmap
    heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)
    
    # Safe normalization
    heatmap = tf.maximum(heatmap, 0)
    max_val = tf.math.reduce_max(heatmap)
    
    if max_val == 0:
        print("Heatmap max est 0 - création d'une heatmap neutre")
        heatmap = tf.ones_like(heatmap) * 0.5
    else:
        heatmap = heatmap / max_val
    
    heatmap_np = heatmap.numpy()
    
    # Final checks before resizing
    if heatmap_np.size == 0:
        print("Empty heatmap - returning the original image")
        return img_resized

    if np.any(np.isnan(heatmap_np)) or np.any(np.isinf(heatmap_np)):
        print("Heatmap contains NaN/inf after conversion - returning the original image")
        return img_resized
    
    # Safe resizing
    try:
        # Make sure heatmap_np is float32 and within [0, 1]
        heatmap_np = np.clip(heatmap_np.astype(np.float32), 0, 1)
        heatmap_resized = cv2.resize(heatmap_np, (img_resized.shape[1], img_resized.shape[0]))
    except cv2.error as e:
        print(f"OpenCV resize error: {e}")
        return img_resized
    
    # Convert the heatmap to a color map
    heatmap_uint8 = np.uint8(255 * heatmap_resized)
    heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
    
    # Overlay the heatmap on the input image
    img_bgr = cv2.cvtColor(img_resized, cv2.COLOR_RGB2BGR)
    superimposed_img = cv2.addWeighted(img_bgr, 0.6, heatmap_colored, 0.4, 0)
    
    return cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)

# ---- Gradio callback (percentages on the chart + guarded Grad-CAM) ----
def gradio_predict(image_pil):
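    """Gradio callback.

    Returns the Bénin/Malin diagnosis with its probability, a per-class probability
    DataFrame for the bar chart, and a Grad-CAM overlay (or None if it fails) generated
    by whichever backbone is most confident in the predicted class.
    """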
    if image_pil is None: return "Veuillez uploader une image.", None, None
    try:
        all_preds = predict_single(image_pil)
        ensemble_probs = all_preds["ensemble"]
        
        top_class_idx = np.argmax(ensemble_probs)
        top_class_name = CLASS_NAMES[top_class_idx]
        global_diag = diagnosis_map[top_class_name]
        
        # Probability of the top prediction, shown next to the diagnosis
        top_class_prob = float(ensemble_probs[top_class_idx])
        diagnostic_with_percentage = f"{global_diag} - {top_class_prob*100:.1f}%"
        
        # Prepare the data for the bar chart
        confidences = {name: float(prob) for name, prob in zip(CLASS_NAMES, ensemble_probs)}
        
        df = pd.DataFrame.from_dict(confidences, orient='index', columns=['Probabilité'])
        df = df.sort_values(by='Probabilité', ascending=False)
        df.index.name = "Classe"
        df = df.reset_index()
        
        # Add a column with percentage labels for the bars
        df['Pourcentage'] = df['Probabilité'].apply(lambda x: f"{x*100:.1f}%")

        # --- Guarded Grad-CAM block ---
        gradcam_img = None  # stays None if Grad-CAM fails
        try:
            model_confidences = {
                "xception": all_preds["xception"][top_class_idx],
                "resnet50": all_preds["resnet50"][top_class_idx],
                "densenet201": all_preds["densenet201"][top_class_idx]
            }
            explainer_model_name = max(model_confidences, key=model_confidences.get)

            model_map = {"xception": model_xcept, "resnet50": model_resnet50, "densenet201": model_densenet}
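            # Last convolutional activation of each backbone, as named in keras.applications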
            layer_map = {"xception": "block14_sepconv2_act", "resnet50": "conv5_block3_out", "densenet201": "relu"}
            
            explainer_model = model_map[explainer_model_name]
            explainer_layer = layer_map[explainer_model_name]
            
            print(f"Génération du Grad-CAM avec le modèle '{explainer_model_name}' sur la couche '{explainer_layer}'.")
            gradcam_img = make_gradcam(image_pil, explainer_model, explainer_layer, class_index=top_class_idx)
        except Exception as e:
            print("--- ERROR WHILE GENERATING GRAD-CAM (the rest of the app keeps running) ---")
            print(e)
            # gradcam_img stays None; Gradio will show an empty image box
        # --- End of guarded block ---

        return diagnostic_with_percentage, df, gradcam_img
        
    except Exception as e:
        print(f"Erreur majeure dans gradio_predict : {e}")
        import traceback
        traceback.print_exc()
        return "Erreur lors du traitement de l'image.", None, None


# ---- Gradio UI (percentage labels on the bars) ----
example_paths = ["exemple1.jpg", "exemple2.jpg", "exemple3.jpg"]
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Analyse de lésions cutanées (Ensemble de modèles + Grad-CAM)")
    gr.Markdown("Cet outil propose une prédiction de la nature de la lésion (Bénin/Malin) avec explication visuelle dynamique.")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Uploader une image de lésion")
            submit_btn = gr.Button("Analyser", variant="primary")
            gr.Examples(examples=example_paths, inputs=input_image)
        with gr.Column(scale=1):
            output_label = gr.Label(label="Diagnostic global")
            # Bar chart configured to draw text labels on the bars
            output_plot = gr.BarPlot(
                label="Probabilités par classe", 
                x="Classe", 
                y="Probabilité", 
                y_lim=[0, 1],
                text="Pourcentage",  # Affiche la colonne "Pourcentage" sur les barres
                text_position="inside"  # Position du texte à l'intérieur des barres
            )
            output_gradcam = gr.Image(label="Visualisation Grad-CAM (Modèle 'le plus sûr')")
    submit_btn.click(fn=gradio_predict, inputs=input_image, outputs=[output_label, output_plot, output_gradcam])

if __name__ == "__main__":
    demo.launch()