# skin_care / app.py — Hugging Face Space by ericjedha
# Last commit: "Update app.py" (0805e5b, verified), 10.4 kB
import os
import numpy as np
import gradio as gr
import cv2
import tensorflow as tf
import keras
from keras.models import Model
from huggingface_hub import hf_hub_download
import pandas as pd
from PIL import Image
# ---- Configuration ----
# Class codes; position i is assumed to correspond to output unit i of every
# model (predictions are indexed through this list and through label_to_index).
# NOTE(review): these look like the HAM10000 diagnosis codes. 'akiec'
# (actinic keratosis / intraepithelial carcinoma) is mapped to 'Bénin' below,
# whereas many benign/malignant splits of that dataset count it as malignant —
# confirm this is intended before relying on the displayed verdict.
CLASS_NAMES = ['akiec', 'bcc', 'bkl', 'df', 'nv', 'vasc', 'mel']

# Reverse lookup: class code -> output index.
label_to_index = dict(zip(CLASS_NAMES, range(len(CLASS_NAMES))))

# Coarse verdict shown to the user for each class: only 'bcc' and 'mel' are
# reported as malignant ('Malin'); every other class is reported as 'Bénin'.
diagnosis_map = {
    code: 'Malin' if code in ('bcc', 'mel') else 'Bénin'
    for code in CLASS_NAMES
}
# ---- Model download / loading ----
# ResNet50 and DenseNet201 are fetched from the Hugging Face Hub;
# hf_hub_download returns the local path of the downloaded file.
resnet_path = hf_hub_download(
    repo_id="ericjedha/resnet50",
    filename="Resnet50.keras",
)
densenet_path = hf_hub_download(
    repo_id="ericjedha/densenet201",
    filename="Densenet201.keras",
)

# All three models are loaded uncompiled (compile=False); they are only
# called directly for inference elsewhere in this file.
# NOTE(review): unlike the other two, Xception is read from a path relative to
# the current working directory, so loading fails at import time if
# "Xception.keras" is not present there.
model_resnet50, model_densenet, model_xcept = (
    keras.saving.load_model(path, compile=False)
    for path in (resnet_path, densenet_path, "Xception.keras")
)
# ---- Préprocesseurs ----
from tensorflow.keras.applications.xception import preprocess_input as preprocess_xception
from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet
from tensorflow.keras.applications.densenet import preprocess_input as preprocess_densenet
# ---- Fonctions utilitaires ----
def get_primary_input_name(model):
    """Return the name of *model*'s first input, minus any ':<n>' suffix.

    Falls back to "input_1" when ``model.inputs`` is not a non-empty list.
    """
    inputs = model.inputs
    if not (isinstance(inputs, list) and inputs):
        return "input_1"
    first_name = inputs[0].name
    # Strip a tensor-style ":0" suffix if present.
    return first_name.partition(':')[0]
def safe_forward(model, x):
    """Call *model* on batch *x* in inference mode (``training=False``).

    The batch is passed as a one-entry dict keyed by the model's primary
    input name, as returned by ``get_primary_input_name``.
    """
    feed = {get_primary_input_name(model): x}
    return model(feed, training=False)
# ---- Prédiction ----
def predict_single(image_pil, weights=(0.45, 0.25, 0.30)):
    """Run the three CNNs on one image and blend their class scores.

    Args:
        image_pil: input image. PIL images of any mode (RGBA, L, P, ...) are
            first converted to RGB, so that every model receives a
            (1, H, W, 3) batch just like for an RGB upload.
        weights: blending weights for (xception, resnet50, densenet201);
            only the first three entries are used.

    Returns:
        dict with keys "ensemble", "xception", "resnet50", "densenet201",
        each a 1-D numpy array of per-class scores indexed like CLASS_NAMES.
    """
    if hasattr(image_pil, "convert"):
        image_pil = image_pil.convert("RGB")
    img_np = np.array(image_pil)

    # Xception is fed 299x299 inputs; ResNet50 and DenseNet201 224x224.
    img_299_arr = np.expand_dims(cv2.resize(img_np, (299, 299)), axis=0)
    img_224_arr = np.expand_dims(cv2.resize(img_np, (224, 224)), axis=0)

    # keras preprocess_input may overwrite float inputs in place; the 224 batch
    # is shared by two models, so hand ResNet a copy.
    pred_x = np.asarray(safe_forward(model_xcept, preprocess_xception(img_299_arr)))
    pred_r = np.asarray(safe_forward(model_resnet50, preprocess_resnet(img_224_arr.copy())))
    pred_d = np.asarray(safe_forward(model_densenet, preprocess_densenet(img_224_arr)))

    preds_ensemble = weights[0] * pred_x + weights[1] * pred_r + weights[2] * pred_d

    # Melanoma re-weighting: the 'mel' column is averaged 50/50 with DenseNet's
    # own 'mel' score, so DenseNet weighs more on this class. After this the
    # row may no longer sum to 1, even if each model's output does.
    mel_idx = label_to_index['mel']
    preds_ensemble[:, mel_idx] = 0.5 * preds_ensemble[:, mel_idx] + 0.5 * pred_d[:, mel_idx]

    return {
        "ensemble": preds_ensemble[0],
        "xception": pred_x[0],
        "resnet50": pred_r[0],
        "densenet201": pred_d[0],
    }
# ---- Grad-CAM ----
def _gradcam_preprocessor(model):
    """Return the keras.applications preprocess_input matching *model*.

    Chosen from ``model.name`` (case-insensitive): names containing
    "xception" or "resnet50" get that family's preprocessing, anything else
    gets DenseNet preprocessing.
    NOTE(review): this relies on the saved models keeping their family name;
    a model saved under a generic name (e.g. "functional") silently gets
    DenseNet preprocessing — confirm against the .keras files.
    """
    name = model.name.lower()
    if "xception" in name:
        return preprocess_xception
    if "resnet50" in name:
        return preprocess_resnet
    return preprocess_densenet


def _overlay_heatmap(img_rgb, heatmap):
    """Colorize a [0, 1] heatmap with JET and blend it 40% over *img_rgb*.

    Returns the blended RGB image, or *img_rgb* unchanged if OpenCV cannot
    resize the heatmap.
    """
    try:
        heatmap = np.clip(heatmap.astype(np.float32), 0, 1)
        heatmap = cv2.resize(heatmap, (img_rgb.shape[1], img_rgb.shape[0]))
    except cv2.error as e:
        print(f"Erreur OpenCV resize: {e}")
        return img_rgb
    heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
    # applyColorMap yields BGR: blend in BGR space, then convert back to RGB.
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    superimposed = cv2.addWeighted(img_bgr, 0.6, heatmap_colored, 0.4, 0)
    return cv2.cvtColor(superimposed, cv2.COLOR_BGR2RGB)


def make_gradcam(image_pil, model, last_conv_layer_name, class_index):
    """Build a Grad-CAM overlay explaining *model*'s score for *class_index*.

    Args:
        image_pil: input image (PIL); converted to RGB and resized to the
            model's input size.
        model: channels-last Keras model; its preprocessing is chosen from
            ``model.name`` (see ``_gradcam_preprocessor``).
        last_conv_layer_name: layer whose activations are weighted by the
            spatially pooled gradients.
        class_index: output column to explain.

    Returns:
        An RGB image at the model's input size: the heatmap blended over the
        resized image, or the plain resized image when the layer is missing
        or the gradients/heatmap are unusable (each case is printed).
        Other errors (e.g. while building the gradient model) propagate.
    """
    # input_shape is (batch, height, width, channels); cv2.resize wants (width, height).
    height, width = model.input_shape[1:3]
    if hasattr(image_pil, "convert"):
        image_pil = image_pil.convert("RGB")
    img_resized = cv2.resize(np.array(image_pil), (width, height))
    img_batch = _gradcam_preprocessor(model)(np.expand_dims(img_resized, axis=0))

    try:
        conv_layer = model.get_layer(last_conv_layer_name)
    except ValueError:
        print(f"Couche '{last_conv_layer_name}' non trouvée dans le modèle")
        return img_resized

    # Second model sharing the same graph that also exposes the conv activations.
    grad_model = Model(model.inputs, [conv_layer.output, model.output])
    model_inputs = {get_primary_input_name(model): img_batch}
    with tf.GradientTape() as tape:
        conv_output, preds = grad_model(model_inputs, training=False)
        if isinstance(preds, list):
            preds = preds[0]
        # Slice inside the tape so the gradient path to conv_output is recorded.
        class_score = preds[:, class_index]
    grads = tape.gradient(class_score, conv_output)

    if grads is None:
        print("Gradients sont None - retour de l'image originale")
        return img_resized
    if tf.reduce_any(tf.math.is_nan(grads)) or tf.reduce_any(tf.math.is_inf(grads)):
        print("Gradients contiennent des NaN/inf - retour de l'image originale")
        return img_resized

    # Channel weights = gradients averaged over batch and spatial axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    # Weighted sum of feature maps; squeeze only the trailing axis so that a
    # 1x1 spatial map stays 2-D and can still be resized.
    heatmap = tf.squeeze(conv_output[0] @ pooled_grads[..., tf.newaxis], axis=-1)
    heatmap = tf.maximum(heatmap, 0)
    max_val = tf.math.reduce_max(heatmap)
    if max_val == 0:
        print("Heatmap max est 0 - création d'une heatmap neutre")
        heatmap = tf.ones_like(heatmap) * 0.5
    else:
        heatmap = heatmap / max_val

    heatmap_np = heatmap.numpy()
    if heatmap_np.size == 0:
        print("Heatmap vide - retour de l'image originale")
        return img_resized
    if not np.all(np.isfinite(heatmap_np)):
        print("Heatmap contient des NaN/inf après conversion - retour de l'image originale")
        return img_resized

    return _overlay_heatmap(img_resized, heatmap_np)
# ---- Gradio callback (percentages + error-tolerant Grad-CAM) ----
def _gradcam_for_top_class(image_pil, all_preds, top_class_idx):
    """Best-effort Grad-CAM overlay for the ensemble's predicted class.

    The explainer is whichever single model (xception, resnet50,
    densenet201) gives *top_class_idx* the highest score; ties go to the
    first in that order. Any exception is printed with its traceback and
    swallowed so the rest of the UI still renders; None is then returned
    and Gradio shows an empty image box.
    """
    import traceback

    try:
        model_map = {"xception": model_xcept, "resnet50": model_resnet50, "densenet201": model_densenet}
        layer_map = {"xception": "block14_sepconv2_act", "resnet50": "conv5_block3_out", "densenet201": "relu"}
        explainer_model_name = max(model_map, key=lambda name: all_preds[name][top_class_idx])
        explainer_layer = layer_map[explainer_model_name]
        print(f"Génération du Grad-CAM avec le modèle '{explainer_model_name}' sur la couche '{explainer_layer}'.")
        return make_gradcam(image_pil, model_map[explainer_model_name], explainer_layer, class_index=top_class_idx)
    except Exception as e:
        print("--- ERREUR LORS DE LA GÉNÉRATION DE GRAD-CAM (le reste de l'app continue) ---")
        print(e)
        traceback.print_exc()
        return None


def gradio_predict(image_pil):
    """Gradio callback: classify *image_pil* and build the three UI outputs.

    Returns:
        (diagnosis_text, probabilities_df, gradcam_image) where
          - diagnosis_text is "<verdict> - <p>%" for the ensemble's top class,
            the verdict coming from diagnosis_map;
          - probabilities_df has columns "Classe", "Probabilité" and
            "Pourcentage", sorted by decreasing probability;
          - gradcam_image is the overlay from the most confident single model,
            or None if Grad-CAM generation failed.
        A missing image or any other error yields (message, None, None).
    """
    import traceback

    if image_pil is None:
        return "Veuillez uploader une image.", None, None
    try:
        all_preds = predict_single(image_pil)
        ensemble_probs = all_preds["ensemble"]
        top_class_idx = int(np.argmax(ensemble_probs))
        top_class_name = CLASS_NAMES[top_class_idx]
        top_class_prob = float(ensemble_probs[top_class_idx])
        diagnostic_with_percentage = f"{diagnosis_map[top_class_name]} - {top_class_prob*100:.1f}%"

        confidences = {name: float(ensemble_probs[i]) for i, name in enumerate(CLASS_NAMES)}
        df = pd.DataFrame.from_dict(confidences, orient='index', columns=['Probabilité'])
        df = df.sort_values(by='Probabilité', ascending=False)
        df.index.name = "Classe"
        df = df.reset_index()
        # Preformatted percentage label (e.g. "87.3%") for display.
        df['Pourcentage'] = df['Probabilité'].map(lambda p: f"{p*100:.1f}%")

        gradcam_img = _gradcam_for_top_class(image_pil, all_preds, top_class_idx)
        return diagnostic_with_percentage, df, gradcam_img
    except Exception as e:
        # Top-level UI boundary: report and show a generic message instead of crashing.
        print(f"Erreur majeure dans gradio_predict : {e}")
        traceback.print_exc()
        return "Erreur lors du traitement de l'image.", None, None
# ---- Gradio UI ----
example_paths = ["exemple1.jpg", "exemple2.jpg", "exemple3.jpg"]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Analyse de lésions cutanées (Ensemble de modèles + Grad-CAM)")
    gr.Markdown("Cet outil propose une prédiction de la nature de la lésion (Bénin/Malin) avec explication visuelle dynamique.")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Uploader une image de lésion")
            submit_btn = gr.Button("Analyser", variant="primary")
            gr.Examples(examples=example_paths, inputs=input_image)
        with gr.Column(scale=1):
            output_label = gr.Label(label="Diagnostic global")
            # gr.BarPlot does not document `text` / `text_position` parameters,
            # so per-bar labels cannot be requested; the preformatted
            # "Pourcentage" column is shown in the hover tooltip instead.
            output_plot = gr.BarPlot(
                label="Probabilités par classe",
                x="Classe",
                y="Probabilité",
                y_lim=[0, 1],
                tooltip=["Classe", "Pourcentage"],
            )
            output_gradcam = gr.Image(label="Visualisation Grad-CAM (Modèle 'le plus sûr')")
    # Event listeners must be attached inside the Blocks context.
    submit_btn.click(fn=gradio_predict, inputs=input_image, outputs=[output_label, output_plot, output_gradcam])

if __name__ == "__main__":
    demo.launch()