# skin_care / app.py — Hugging Face Space by ericjedha
# Last commit: 5a52e73 ("Update app.py", verified), 7.05 kB
import os
import numpy as np
import gradio as gr
import cv2
import tensorflow as tf
import keras
from keras.models import Model
from huggingface_hub import hf_hub_download
import pandas as pd
from PIL import Image
# ---- Configuration ----
# Position i in CLASS_NAMES names column i of the ensemble's probability
# vector (gradio_predict maps probs[i] back to CLASS_NAMES[i]).
CLASS_NAMES = ['akiec', 'bcc', 'bkl', 'df', 'nv', 'vasc', 'mel']

# Reverse lookup: class name -> column index.
label_to_index = dict(zip(CLASS_NAMES, range(len(CLASS_NAMES))))

# Coarse benign / malignant grouping used for the global diagnosis.
diagnosis_map = {
    # NOTE(review): 'akiec' (actinic keratosis / intraepithelial carcinoma) is
    # often grouped with malignant lesions in HAM10000 binary splits — confirm
    # that mapping it to 'Bénin' is intended.
    'akiec': 'Bénin',
    'bcc': 'Malin',
    'bkl': 'Bénin',
    'df': 'Bénin',
    'nv': 'Bénin',
    'vasc': 'Bénin',
    'mel': 'Malin',
}
# ---- Model download & loading ----
# ResNet50 and DenseNet201 checkpoints are fetched from the Hugging Face Hub;
# Xception is read from "Xception.keras", a path relative to the current
# working directory. All three are loaded uncompiled (inference only).
resnet_path, densenet_path = (
    hf_hub_download(repo_id=repo, filename=fname)
    for repo, fname in (
        ("ericjedha/resnet50", "Resnet50.keras"),
        ("ericjedha/densenet201", "Densenet201.keras"),
    )
)
model_resnet50, model_densenet, model_xcept = (
    keras.saving.load_model(checkpoint, compile=False)
    for checkpoint in (resnet_path, densenet_path, "Xception.keras")
)
# ---- Préprocesseurs ----
from tensorflow.keras.applications.xception import preprocess_input as preprocess_xception
from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet
from tensorflow.keras.applications.densenet import preprocess_input as preprocess_densenet
# ---- Utility helpers ----
def get_primary_input_name(model):
    """Return the name of ``model``'s first input, without any ``:N`` suffix.

    Falls back to ``"input_1"`` when ``model.inputs`` is not a non-empty list.
    """
    inputs = model.inputs
    if not (isinstance(inputs, list) and inputs):
        return "input_1"
    # Keep only the text before the first ':' (e.g. "input_1:0" -> "input_1").
    return inputs[0].name.partition(':')[0]
def safe_forward(model, x):
    """Run ``model`` once with ``training=False`` and return the output as NumPy.

    The batch ``x`` is fed as a single-entry dict keyed by the model's first
    input name (as reported by ``get_primary_input_name``).
    """
    feed = {get_primary_input_name(model): x}
    result = model(feed, training=False)
    return result.numpy()
# ---- Prediction ----
def predict_single(image_pil, weights=(0.45, 0.25, 0.30)):
    """Return the ensemble score vector for a single image.

    The image is resized to 299x299 for Xception and to 224x224 for ResNet50
    and DenseNet201. Each batch goes through its network's own
    ``preprocess_input``. The three outputs are blended as
    ``weights[0]*xception + weights[1]*resnet50 + weights[2]*densenet``.
    The 'mel' column is then replaced by the mean of that blended value and
    DenseNet's own 'mel' score. The returned scores are not renormalised
    after that adjustment.

    Args:
        image_pil: Input image, converted with ``np.array``. It is assumed to
            be an RGB image (TODO confirm for callers other than the Gradio UI).
        weights: Blend weights for (Xception, ResNet50, DenseNet201).

    Returns:
        1-D array (row 0 of the batch) with one score per entry of CLASS_NAMES.
    """
    pixels = np.array(image_pil)
    batch_299 = np.expand_dims(cv2.resize(pixels, (299, 299)), axis=0)
    batch_224 = np.expand_dims(cv2.resize(pixels, (224, 224)), axis=0)

    # Network-specific input normalisation.
    x_xcept = preprocess_xception(batch_299)
    x_resnet = preprocess_resnet(batch_224)
    x_densenet = preprocess_densenet(batch_224)

    out_xcept = safe_forward(model_xcept, x_xcept)
    out_resnet = safe_forward(model_resnet50, x_resnet)
    out_densenet = safe_forward(model_densenet, x_densenet)

    blended = weights[0] * out_xcept + weights[1] * out_resnet + weights[2] * out_densenet

    # In the final 'mel' value, DenseNet's score carries weight
    # 0.5 * weights[2] + 0.5.
    mel = label_to_index['mel']
    blended[:, mel] = 0.5 * blended[:, mel] + 0.5 * out_densenet[:, mel]
    return blended[0]
# ---- Grad-CAM ----
def _layer_output_rank(layer):
    """Return the rank of ``layer``'s single output, or None if it is unknown.

    tf.keras (Keras 2) exposes ``layer.output_shape``. Keras 3 removed that
    attribute, so fall back to the symbolic ``layer.output.shape``. Layers
    whose shape cannot be read, or that have several outputs (a list of
    shapes), are reported as None.
    """
    try:
        shape = layer.output_shape  # tf.keras / Keras 2
    except AttributeError:
        try:
            shape = layer.output.shape  # Keras 3
        except (AttributeError, ValueError):
            return None
    if shape is None or isinstance(shape, list):
        # A list here means one shape per output; its length is not a rank.
        return None
    try:
        return len(shape)
    except (TypeError, ValueError):  # e.g. a TensorShape of unknown rank
        return None


def find_last_feature_map_layer(model):
    """Return the name of the last layer of ``model`` whose output is 4-D.

    ``model.layers`` is scanned from the end. This is used to pick the
    feature-map layer for Grad-CAM.

    Raises:
        ValueError: if no layer has a 4-D output.
    """
    for layer in reversed(model.layers):
        if _layer_output_rank(layer) == 4:
            return layer.name
    raise ValueError("Impossible de trouver une couche de features (sortie 4D) dans le modèle.")
def _gradcam_heatmap(model, batch, class_index, layer_name):
    """Compute a Grad-CAM heatmap for ``layer_name``; return None if no gradient.

    Returns a 2-D float32 array with the spatial size of ``layer_name``'s
    output, clipped at 0 and scaled so its maximum is (close to) 1.
    """
    grad_model = Model(model.inputs, [model.get_layer(layer_name).output, model.output])
    inputs = tf.convert_to_tensor(batch)
    feed = {get_primary_input_name(model): inputs}
    with tf.GradientTape() as tape:
        # Watch the input explicitly so the forward pass is recorded even when
        # the model has no trainable variables (e.g. a frozen backbone).
        tape.watch(inputs)
        conv_output, preds = grad_model(feed, training=False)
        if isinstance(preds, (list, tuple)):
            preds = preds[0]
        if class_index is None:
            class_index = tf.argmax(preds[0])
        class_score = preds[:, class_index]
    grads = tape.gradient(class_score, conv_output)
    if grads is None:
        return None
    # One weight per channel: average the (1, H, W, C) gradients over the
    # batch AND both spatial axes. Averaging only axes (0, 1) leaves a (W, C)
    # tensor and yields a wrong map.
    channel_weights = tf.reduce_mean(grads, axis=(0, 1, 2))
    heatmap = tf.reduce_sum(conv_output[0] * channel_weights, axis=-1)  # (H, W)
    heatmap = tf.cast(heatmap, tf.float32)  # keep the 1e-8 epsilon meaningful
    # ReLU first, then normalise by the max of the ReLU'd map.
    heatmap = tf.maximum(heatmap, 0.0)
    heatmap = heatmap / (tf.reduce_max(heatmap) + 1e-8)
    return heatmap.numpy()


def _overlay_heatmap(img_rgb, heatmap):
    """Colour-map ``heatmap`` (values in [0, 1]) with JET and blend it over ``img_rgb``.

    The heatmap is first upsampled to the image size, because the feature
    map is much coarser than the image and ``cv2.addWeighted`` requires
    equal-sized inputs. The result is 60 % image and 40 % heatmap, returned
    as an RGB image with the same dtype as ``img_rgb`` (uint8 here).
    """
    height, width = img_rgb.shape[:2]
    heat = cv2.resize(np.asarray(heatmap, dtype=np.float32), (width, height),
                      interpolation=cv2.INTER_LINEAR)
    heat = np.uint8(255 * np.clip(heat, 0.0, 1.0))
    heat_bgr = cv2.applyColorMap(heat, cv2.COLORMAP_JET)
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    blended = cv2.addWeighted(img_bgr, 0.6, heat_bgr, 0.4, 0)
    return cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)


def make_gradcam(image_pil, model, class_index=None):
    """Return a 224x224 RGB image of ``image_pil`` overlaid with a Grad-CAM heatmap.

    The image is resized to 224x224 and preprocessed with DenseNet's
    ``preprocess_input``, so ``model`` is expected to take that input.
    The heatmap is taken from the last layer with a 4-D output
    (``find_last_feature_map_layer``).

    Args:
        image_pil: Input image, converted with ``np.array`` (assumed RGB).
        model: Keras model to explain.
        class_index: Output column to explain. If None, the model's argmax
            for this image is used.

    Returns:
        The resized image with the heatmap blended in. If the gradient
        cannot be computed, the resized image is returned without overlay.

    Raises:
        ValueError: if ``model`` has no layer with a 4-D output.
    """
    last_conv_layer_name = find_last_feature_map_layer(model)
    img_resized = cv2.resize(np.array(image_pil), (224, 224))
    batch = preprocess_densenet(np.expand_dims(img_resized, axis=0))

    heatmap = _gradcam_heatmap(model, batch, class_index, last_conv_layer_name)
    if heatmap is None:
        print("Erreur: Le gradient est None. Impossible de générer la heatmap.")
        return img_resized
    return _overlay_heatmap(img_resized, heatmap)
# ---- Gradio callback ----
def gradio_predict(image_pil):
    """Gradio callback: run the ensemble and build the three UI outputs.

    Args:
        image_pil: Uploaded image, or None when nothing was uploaded.

    Returns:
        A 3-tuple ``(diagnosis, table, gradcam)``:

        - ``diagnosis`` is "Bénin" or "Malin", whichever group has the larger
          summed probability; ties go to "Bénin".
        - ``table`` is a DataFrame with columns "Classe" / "Probabilité",
          sorted by decreasing probability.
        - ``gradcam`` is the Grad-CAM overlay for the top class, or None if
          it could not be produced.

        If no image is given, or prediction fails, a message string and
        ``None, None`` are returned instead.
    """
    import traceback

    if image_pil is None:
        return "Veuillez uploader une image.", None, None

    try:
        probs = predict_single(image_pil)
        benign_prob = sum(p for p, cls in zip(probs, CLASS_NAMES) if diagnosis_map[cls] == "Bénin")
        malign_prob = sum(p for p, cls in zip(probs, CLASS_NAMES) if diagnosis_map[cls] == "Malin")
        global_diag = "Bénin" if benign_prob >= malign_prob else "Malin"

        confidences = {cls: float(p) for cls, p in zip(CLASS_NAMES, probs)}
        df = pd.DataFrame.from_dict(confidences, orient='index', columns=['Probabilité'])
        df = df.sort_values(by='Probabilité', ascending=False)
        df.index.name = "Classe"
        df = df.reset_index()
    except Exception as e:  # UI boundary: log and show a message instead of crashing
        print(f"Erreur dans gradio_predict : {e}")
        traceback.print_exc()
        return "Erreur lors du traitement de l'image.", None, None

    # Grad-CAM is only a visual aid: a failure there must not discard the
    # diagnosis and probability table computed above.
    try:
        top_class_idx = int(np.argmax(probs))
        gradcam_img = make_gradcam(image_pil, model_densenet, class_index=top_class_idx)
    except Exception as e:
        print(f"Erreur dans make_gradcam : {e}")
        traceback.print_exc()
        gradcam_img = None

    return global_diag, df, gradcam_img
# ---- Gradio UI ----
example_paths = ["exemple1.jpg", "exemple2.jpg", "exemple3.jpg"]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Analyse de lésions cutanées (Ensemble de modèles + Grad-CAM)")
    gr.Markdown("Cet outil propose une prédiction de la nature de la lésion (Bénin/Malin) avec explication visuelle.")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Uploader une image de lésion")
            submit_btn = gr.Button("Analyser", variant="primary")
            gr.Examples(examples=example_paths, inputs=input_image)
        with gr.Column(scale=1):
            output_label = gr.Label(label="Diagnostic global")
            output_plot = gr.BarPlot(label="Probabilités par classe", x="Classe", y="Probabilité", y_lim=[0, 1])
            output_gradcam = gr.Image(label="Visualisation Grad-CAM")
    # gradio_predict returns three values (diagnosis, probability table,
    # Grad-CAM image). Wire one output component per value so the Grad-CAM
    # panel is actually updated rather than the third value being dropped.
    submit_btn.click(
        fn=gradio_predict,
        inputs=input_image,
        outputs=[output_label, output_plot, output_gradcam],
    )

if __name__ == "__main__":
    demo.launch()