|
import os |
|
import numpy as np |
|
import gradio as gr |
|
import cv2 |
|
import tensorflow as tf |
|
import keras |
|
from keras.models import Model |
|
from huggingface_hub import hf_hub_download |
|
import pandas as pd |
|
from PIL import Image |
|
|
|
|
|
# Output classes, in the index order of the models' score vectors.
# NOTE(review): the order is not alphabetical ('mel' comes last); it presumably
# mirrors the label encoding used at training time — confirm before changing.
CLASS_NAMES = [
    'akiec',
    'bcc',
    'bkl',
    'df',
    'nv',
    'vasc',
    'mel',
]

# Class name -> position in the score vector.
label_to_index = dict(zip(CLASS_NAMES, range(len(CLASS_NAMES))))

# Class name -> coarse verdict shown to the user ("Bénin" = benign, "Malin" = malignant).
# NOTE(review): 'akiec' is mapped to benign here; in HAM10000-style label sets it is
# often grouped with (pre-)malignant lesions — confirm this is intended.
diagnosis_map = {
    'akiec': 'Bénin',
    'bcc': 'Malin',
    'bkl': 'Bénin',
    'df': 'Bénin',
    'nv': 'Bénin',
    'vasc': 'Bénin',
    'mel': 'Malin',
}
|
|
|
|
|
# Fetch the ResNet50 and DenseNet201 checkpoints from the Hugging Face Hub;
# hf_hub_download returns a local file path that is then handed to Keras.
resnet_path = hf_hub_download(repo_id="ericjedha/resnet50", filename="Resnet50.keras")
densenet_path = hf_hub_download(repo_id="ericjedha/densenet201", filename="Densenet201.keras")

# compile=False: in this file the models are only called for inference
# (training=False), so no optimizer/loss configuration is restored.
model_resnet50 = keras.saving.load_model(resnet_path, compile=False)
model_densenet = keras.saving.load_model(densenet_path, compile=False)
# NOTE(review): unlike the two models above, Xception is loaded from a path
# relative to the current working directory rather than from the Hub.
model_xcept = keras.saving.load_model("Xception.keras", compile=False)
|
|
|
|
|
from tensorflow.keras.applications.xception import preprocess_input as preprocess_xception |
|
from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet |
|
from tensorflow.keras.applications.densenet import preprocess_input as preprocess_densenet |
|
|
|
|
|
def get_primary_input_name(model):
    """Return the name of ``model``'s first input tensor.

    Any ``":<suffix>"`` part of the tensor name (e.g. ``"input_layer:0"``) is
    dropped. When ``model.inputs`` is not a non-empty list, the fallback name
    ``"input_1"`` is returned instead.
    """
    inputs = model.inputs
    if not (isinstance(inputs, list) and inputs):
        return "input_1"
    head, _sep, _rest = inputs[0].name.partition(':')
    return head
|
|
|
def safe_forward(model, x):
    """Call ``model`` on the batch ``x`` in inference mode.

    The batch is wrapped in a one-entry dict keyed by the model's primary input
    name (as given by ``get_primary_input_name``), and ``training=False`` is
    passed explicitly. Whatever the model call returns is returned unchanged.
    """
    feed = {get_primary_input_name(model): x}
    return model(feed, training=False)
|
|
|
|
|
def predict_single(image_pil, weights=(0.45, 0.25, 0.30)):
    """Run the Xception / ResNet50 / DenseNet201 ensemble on one image.

    Args:
        image_pil: the image to classify. A PIL image is converted to RGB first,
            so RGBA / grayscale / palette uploads still give the 3-channel input
            the models are fed. Any other input (e.g. an array) is used as-is.
        weights: blending weights for the (Xception, ResNet50, DenseNet201)
            outputs, in that order.

    Returns:
        dict mapping "ensemble", "xception", "resnet50" and "densenet201" to a
        1-D array of per-class scores indexed like ``CLASS_NAMES`` (presumably
        softmax probabilities — confirm against the model heads). In the
        "ensemble" vector the 'mel' score is additionally averaged with
        DenseNet201's own 'mel' score, so that vector need not sum to 1.
    """
    # Non-RGB PIL modes would otherwise yield 4- or 1-channel arrays.
    if hasattr(image_pil, "convert"):
        image_pil = image_pil.convert("RGB")
    img_np = np.asarray(image_pil)

    # Xception is fed 299x299 images; ResNet50 and DenseNet201 are fed 224x224.
    batch_299 = np.expand_dims(cv2.resize(img_np, (299, 299)), axis=0)
    batch_224 = np.expand_dims(cv2.resize(img_np, (224, 224)), axis=0)

    # Keras preprocess_input may modify float arrays in place, so ResNet50 gets
    # its own copy of the 224x224 batch that DenseNet201 also consumes.
    pred_x = np.asarray(safe_forward(model_xcept, preprocess_xception(batch_299)))
    pred_r = np.asarray(safe_forward(model_resnet50, preprocess_resnet(batch_224.copy())))
    pred_d = np.asarray(safe_forward(model_densenet, preprocess_densenet(batch_224)))

    preds_ensemble = weights[0] * pred_x + weights[1] * pred_r + weights[2] * pred_d

    # Give DenseNet201 extra weight on melanoma: average the ensemble 'mel'
    # score with DenseNet201's own 'mel' score.
    mel_idx = label_to_index['mel']
    preds_ensemble[:, mel_idx] = 0.5 * preds_ensemble[:, mel_idx] + 0.5 * pred_d[:, mel_idx]

    return {
        "ensemble": preds_ensemble[0],
        "xception": pred_x[0],
        "resnet50": pred_r[0],
        "densenet201": pred_d[0],
    }
|
|
|
|
|
|
|
def _gradcam_preprocessor(model):
    """Return the ImageNet preprocessing function that matches ``model``."""
    # Identity check first: the loaded models' ``name`` need not contain the
    # architecture name, and a miss would silently fall back to DenseNet scaling.
    name = model.name.lower()
    if model is model_xcept or 'xception' in name:
        return preprocess_xception
    if model is model_resnet50 or 'resnet50' in name:
        return preprocess_resnet
    return preprocess_densenet


def _gradcam_heatmap(model, conv_layer, model_inputs, class_index):
    """Compute a Grad-CAM heatmap (float32, values in [0, 1], conv-grid size).

    Returns None, after printing the reason, when the gradients are None or
    non-finite, or when the resulting heatmap is empty or non-finite.
    """
    grad_model = Model(model.inputs, [conv_layer.output, model.output])

    with tf.GradientTape() as tape:
        conv_output, preds = grad_model(model_inputs, training=False)
        if isinstance(preds, list):
            preds = preds[0]
        class_channel = preds[:, class_index]
    grads = tape.gradient(class_channel, conv_output)

    if grads is None:
        print("Gradients sont None - retour de l'image originale")
        return None
    if tf.reduce_any(tf.math.is_nan(grads)) or tf.reduce_any(tf.math.is_inf(grads)):
        print("Gradients contiennent des NaN/inf - retour de l'image originale")
        return None

    # Per-channel weights: gradients averaged over batch and spatial axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    # Weighted sum of the first sample's feature maps: (h, w, c) @ (c, 1) -> (h, w).
    # axis=-1 keeps the spatial dims even when the conv grid is 1 pixel wide/high.
    heatmap = tf.squeeze(conv_output[0] @ pooled_grads[..., tf.newaxis], axis=-1)

    heatmap = tf.maximum(heatmap, 0)
    max_val = tf.math.reduce_max(heatmap)
    if max_val == 0:
        print("Heatmap max est 0 - création d'une heatmap neutre")
        heatmap = tf.ones_like(heatmap) * 0.5
    else:
        heatmap = heatmap / max_val
    heatmap_np = heatmap.numpy()

    if heatmap_np.size == 0:
        print("Heatmap vide - retour de l'image originale")
        return None
    if np.any(np.isnan(heatmap_np)) or np.any(np.isinf(heatmap_np)):
        print("Heatmap contient des NaN/inf après conversion - retour de l'image originale")
        return None
    return np.clip(heatmap_np.astype(np.float32), 0, 1)


def _overlay_heatmap(img_rgb, heatmap):
    """Blend a [0, 1] heatmap (JET colormap, 40%) over an RGB image; None if resizing fails."""
    try:
        heatmap_resized = cv2.resize(heatmap, (img_rgb.shape[1], img_rgb.shape[0]))
    except cv2.error as e:
        print(f"Erreur OpenCV resize: {e}")
        return None

    heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET)
    # applyColorMap produces BGR, so blend in BGR space and convert back to RGB.
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    superimposed = cv2.addWeighted(img_bgr, 0.6, heatmap_colored, 0.4, 0)
    return cv2.cvtColor(superimposed, cv2.COLOR_BGR2RGB)


def make_gradcam(image_pil, model, last_conv_layer_name, class_index):
    """Overlay a Grad-CAM heatmap for ``class_index`` on the input image.

    Args:
        image_pil: input image; a PIL image is converted to RGB first.
        model: Keras model to explain. Its preprocessing function is chosen by
            ``_gradcam_preprocessor`` (Xception, ResNet50, otherwise DenseNet).
        last_conv_layer_name: name of the layer whose activations are weighted.
        class_index: index of the output class to explain.

    Returns:
        RGB image array at the model's input resolution with the heatmap
        blended in. If the layer is missing or the heatmap cannot be computed
        or resized, the resized input image is returned without overlay, and
        the reason is printed.
    """
    if hasattr(image_pil, "convert"):
        image_pil = image_pil.convert("RGB")
    img_np = np.asarray(image_pil)

    # Assumes channels-last input_shape (batch, height, width, channels);
    # cv2.resize takes dsize as (width, height).
    height, width = model.input_shape[1:3]
    img_resized = cv2.resize(img_np, (width, height))

    preprocessor = _gradcam_preprocessor(model)
    # Copy so an in-place preprocess cannot alter img_resized, which is reused below.
    img_array_preprocessed = preprocessor(np.expand_dims(img_resized, axis=0).copy())

    try:
        conv_layer = model.get_layer(last_conv_layer_name)
    except ValueError:
        print(f"Couche '{last_conv_layer_name}' non trouvée dans le modèle")
        return img_resized

    model_inputs = {get_primary_input_name(model): img_array_preprocessed}
    heatmap = _gradcam_heatmap(model, conv_layer, model_inputs, class_index)
    if heatmap is None:
        return img_resized

    overlay = _overlay_heatmap(img_resized, heatmap)
    return img_resized if overlay is None else overlay
|
|
|
|
|
|
|
def _build_probability_table(probs, class_names):
    """Return a DataFrame of per-class probabilities sorted in decreasing order.

    Columns: "Classe" (class name), "Probabilité" (float) and "Pourcentage"
    (the probability formatted like "12.3%"). ``probs`` is expected to hold one
    score per entry of ``class_names``.
    """
    confidences = {name: float(prob) for name, prob in zip(class_names, probs)}
    df = pd.DataFrame.from_dict(confidences, orient='index', columns=['Probabilité'])
    df = df.sort_values(by='Probabilité', ascending=False)
    df.index.name = "Classe"
    df = df.reset_index()
    df['Pourcentage'] = df['Probabilité'].apply(lambda x: f"{x*100:.1f}%")
    return df


def _gradcam_for_top_class(image_pil, all_preds, top_class_idx):
    """Grad-CAM for ``top_class_idx``, explained by the model that scores it highest.

    Best-effort: any exception is printed with its traceback and None is
    returned, so the rest of the prediction output is still shown.
    """
    try:
        model_map = {"xception": model_xcept, "resnet50": model_resnet50, "densenet201": model_densenet}
        layer_map = {"xception": "block14_sepconv2_act", "resnet50": "conv5_block3_out", "densenet201": "relu"}

        # Ties resolve to the first model in model_map order.
        explainer_model_name = max(model_map, key=lambda name: all_preds[name][top_class_idx])
        explainer_model = model_map[explainer_model_name]
        explainer_layer = layer_map[explainer_model_name]

        print(f"Génération du Grad-CAM avec le modèle '{explainer_model_name}' sur la couche '{explainer_layer}'.")
        return make_gradcam(image_pil, explainer_model, explainer_layer, class_index=top_class_idx)
    except Exception as e:
        import traceback

        print("--- ERREUR LORS DE LA GÉNÉRATION DE GRAD-CAM (le reste de l'app continue) ---")
        print(e)
        traceback.print_exc()
        return None


def gradio_predict(image_pil):
    """Gradio callback: classify the image and build the three UI outputs.

    Returns:
        (diagnosis text, probability table, Grad-CAM image).
        - The diagnosis text is "<Bénin|Malin> - xx.x%" for the top ensemble class.
        - The probability table is built by ``_build_probability_table``.
        - The Grad-CAM image is None when its generation fails.
        When no image is given, or the prediction itself raises, a message
        string and (None, None) are returned instead.
    """
    if image_pil is None:
        return "Veuillez uploader une image.", None, None

    try:
        all_preds = predict_single(image_pil)
        ensemble_probs = all_preds["ensemble"]

        top_class_idx = int(np.argmax(ensemble_probs))
        global_diag = diagnosis_map[CLASS_NAMES[top_class_idx]]
        top_class_prob = float(ensemble_probs[top_class_idx])
        diagnostic_with_percentage = f"{global_diag} - {top_class_prob*100:.1f}%"

        df = _build_probability_table(ensemble_probs, CLASS_NAMES)
        gradcam_img = _gradcam_for_top_class(image_pil, all_preds, top_class_idx)

        return diagnostic_with_percentage, df, gradcam_img

    except Exception as e:
        import traceback

        print(f"Erreur majeure dans gradio_predict : {e}")
        traceback.print_exc()
        return "Erreur lors du traitement de l'image.", None, None
|
|
|
|
|
|
|
# Example images listed under the upload widget.
# NOTE(review): relative paths — presumably resolved against the working
# directory; confirm the files ship alongside the app.
example_paths = ["exemple1.jpg", "exemple2.jpg", "exemple3.jpg"]

# Layout: title/description, then a row with the input column (upload,
# button, examples) on the left and the result column on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Analyse de lésions cutanées (Ensemble de modèles + Grad-CAM)")
    gr.Markdown("Cet outil propose une prédiction de la nature de la lésion (Bénin/Malin) avec explication visuelle dynamique.")
    with gr.Row():
        with gr.Column(scale=1):
            # type="pil": gradio_predict receives a PIL image (it also handles None).
            input_image = gr.Image(type="pil", label="Uploader une image de lésion")
            submit_btn = gr.Button("Analyser", variant="primary")
            gr.Examples(examples=example_paths, inputs=input_image)
        with gr.Column(scale=1):
            # Shows the "<Bénin|Malin> - xx.x%" string returned by gradio_predict.
            output_label = gr.Label(label="Diagnostic global")

            # Fed by the DataFrame returned by gradio_predict; x/y/text must match
            # its column names ("Classe", "Probabilité", "Pourcentage").
            # NOTE(review): check that the installed Gradio version's BarPlot
            # accepts the `text` / `text_position` arguments.
            output_plot = gr.BarPlot(
                label="Probabilités par classe",
                x="Classe",
                y="Probabilité",
                y_lim=[0, 1],
                text="Pourcentage",
                text_position="inside"
            )
            output_gradcam = gr.Image(label="Visualisation Grad-CAM (Modèle 'le plus sûr')")
    # gradio_predict returns a 3-tuple that fills these outputs positionally.
    submit_btn.click(fn=gradio_predict, inputs=input_image, outputs=[output_label, output_plot, output_gradcam])

if __name__ == "__main__":
    demo.launch()