# skin_care / app.py
# Hugging Face Space file by ericjedha — commit 1c99721 ("Update app.py", verified), 6.11 kB.
import os
import numpy as np
import gradio as gr
import cv2
import tensorflow as tf
import keras
from keras.models import Model
from keras.preprocessing import image
from huggingface_hub import hf_hub_download
# ---- Configuration ----
# Class labels in the column order of the models' prediction vectors
# (CLASS_NAMES[i] is used to label probs[i]).
# NOTE(review): the order is not alphabetical ('mel' is last) — presumably it
# mirrors the label encoding used at training time; confirm before changing.
CLASS_NAMES = ['akiec', 'bcc', 'bkl', 'df', 'nv', 'vasc', 'mel']

# Reverse lookup: class name -> column index in a prediction vector.
label_to_index = dict(zip(CLASS_NAMES, range(len(CLASS_NAMES))))

# Coarse benign/malignant grouping used for the global diagnosis.
# Only 'bcc' and 'mel' are treated as malignant; every other class is benign.
# NOTE(review): 'akiec' (actinic keratosis / intraepithelial carcinoma) is
# often grouped as malignant or pre-malignant — confirm that mapping it to
# 'Bénin' is intended.
diagnosis_map = {
    name: ('Malin' if name in ('bcc', 'mel') else 'Bénin')
    for name in CLASS_NAMES
}
# ---- Model download ----
# Fetch the ResNet50 and DenseNet201 checkpoints from the Hugging Face Hub;
# hf_hub_download returns a local file path that load_model can read.
resnet_path = hf_hub_download(repo_id="ericjedha/resnet50", filename="Resnet50.keras")
densenet_path = hf_hub_download(repo_id="ericjedha/densenet201", filename="Densenet201.keras")

# Load the three ensemble members without compiling them: the app only calls
# them for inference, never trains. Loading happens in the order listed.
# NOTE(review): unlike the other two, Xception is read from the working
# directory instead of the Hub — confirm the file ships alongside app.py.
model_resnet50, model_densenet, model_xcept = (
    keras.saving.load_model(checkpoint, compile=False)
    for checkpoint in (resnet_path, densenet_path, "Xception.keras")
)
# ---- Préprocesseurs ----
from tensorflow.keras.applications.xception import preprocess_input as preprocess_xception
from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet
from tensorflow.keras.applications.densenet import preprocess_input as preprocess_densenet
def load_image(path, target_size):
    """Read an image file, resize it, and return it as an array.

    Args:
        path: Path to the image file.
        target_size: Size passed to ``image.load_img`` (Keras expects
            ``(height, width)``).

    Returns:
        The resized image as returned by ``image.img_to_array``.
    """
    return image.img_to_array(image.load_img(path, target_size=target_size))
# ---- Wrapper robuste pour prédictions ----
def get_primary_input_name(model):
    """Return the name of the model's first input, or None if unavailable.

    Any ``":<index>"`` suffix (TensorFlow-style names such as ``"input_1:0"``)
    is stripped so the result can be used as a key when feeding the model a
    dict of inputs.

    Returns None, instead of raising, when the model has no ``inputs``
    attribute, an empty input list, or a first input without a usable name.
    """
    inputs = getattr(model, "inputs", None)
    # Guard the [0] access: an empty list used to raise IndexError.
    if not isinstance(inputs, (list, tuple)) or not inputs:
        return None
    name = getattr(inputs[0], "name", None)
    if not name:
        return None
    return name.split(":")[0]
def safe_forward(model, x):
    """Run an inference forward pass, degrading gracefully on failure.

    The input is first fed as a dict keyed by the model's primary input name
    (from ``get_primary_input_name``). If that call fails, for example because
    the saved model's input name does not match, it is retried positionally.
    If every attempt fails, the errors are printed and an all-zero array of
    shape ``(batch, len(CLASS_NAMES))`` is returned so callers can continue.

    Args:
        model: Callable model (called with ``training=False``).
        x: Input batch; NumPy arrays are cast to float32.

    Returns:
        NumPy array of model outputs, or the all-zero fallback.
    """
    if isinstance(x, np.ndarray):
        x = x.astype(np.float32, copy=False)

    try:
        input_name = get_primary_input_name(model)
    except Exception:
        # A model whose inputs cannot be inspected may still accept a
        # positional call below.
        input_name = None

    calls = []
    if input_name:
        calls.append(("dict", lambda: model({input_name: x}, training=False)))
    calls.append(("positional", lambda: model(x, training=False)))

    # getattr: the error path must not raise on a model without `.name`.
    model_name = getattr(model, "name", type(model).__name__)
    for mode, call in calls:
        try:
            out = call()
            return out.numpy() if hasattr(out, "numpy") else np.asarray(out)
        except Exception as e:
            print(f"[safe_forward] Erreur avec {model_name} ({mode}): {e}")

    # Fallback sized to the actual batch, matching the float32 model outputs.
    batch = x.shape[0] if getattr(x, "ndim", 0) >= 1 else 1
    return np.zeros((batch, len(CLASS_NAMES)), dtype=np.float32)
# ---- Prédiction single image ----
def predict_single(img_path, weights=(0.45, 0.25, 0.30)):
    """Predict per-class scores for one image with the 3-model ensemble.

    Each model gets its own preprocessed copy of the image: Xception at
    299x299, ResNet50 and DenseNet201 at 224x224. The outputs are combined
    as a weighted sum. The 'mel' score is then averaged with DenseNet's own
    'mel' score, and the vector is rescaled to sum to 1.

    Args:
        img_path: Path to the image file.
        weights: Ensemble weights, in the order (xception, resnet, densenet).

    Returns:
        1-D array with one score per entry of CLASS_NAMES. It is all zeros
        only if every model fell back to zeros in ``safe_forward``.

    Raises:
        ValueError: if ``weights`` does not have exactly three entries.
    """
    w_x, w_r, w_d = weights

    img_299 = np.expand_dims(load_image(img_path, (299, 299)), axis=0)
    # ResNet50 and DenseNet use the same 224x224 resize, so read it only once.
    img_224 = np.expand_dims(load_image(img_path, (224, 224)), axis=0)

    # Keras preprocess_input functions overwrite float NumPy inputs in place,
    # so each model must receive its own copy of the shared 224x224 image.
    bx = preprocess_xception(img_299)
    br = preprocess_resnet(img_224.copy())
    bd = preprocess_densenet(img_224.copy())

    pred_x = safe_forward(model_xcept, bx)
    pred_r = safe_forward(model_resnet50, br)
    pred_d = safe_forward(model_densenet, bd)

    preds = w_x * pred_x + w_r * pred_r + w_d * pred_d

    # Boost MEL with DenseNet: average the ensemble 'mel' score with
    # DenseNet's own 'mel' score.
    mel_idx = label_to_index['mel']
    preds[:, mel_idx] = 0.5 * preds[:, mel_idx] + 0.5 * pred_d[:, mel_idx]

    # The boost (and weights not summing to 1) breaks normalization. Rescale
    # so displayed percentages sum to 100; a positive per-row factor leaves
    # the class ranking and any benign/malignant sum comparison unchanged.
    total = preds.sum(axis=1, keepdims=True)
    np.divide(preds, total, out=preds, where=total > 0)
    return preds[0]
# ---- Grad-CAM ----
def _gradcam_heatmap(model, input_array, last_conv_layer_name, class_index):
    """Compute a Grad-CAM heatmap in [0, 1] at the conv layer's spatial size."""
    grad_model = Model(
        inputs=model.inputs,
        outputs=[model.get_layer(last_conv_layer_name).output, model.output],
    )
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(input_array, training=False)
        loss = predictions[:, class_index]
    grads = tape.gradient(loss, conv_outputs)
    if grads is None:
        raise ValueError(
            f"No gradient flows from class {class_index} to layer '{last_conv_layer_name}'."
        )
    # Per-channel importance: gradients averaged over batch + spatial axes,
    # leaving one weight per channel (grads keeps its batch axis here).
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    # (H, W, C) @ (C, 1) -> (H, W, 1); index the last axis rather than
    # squeezing so a 1-pixel-wide map keeps its 2-D shape.
    heatmap = (conv_outputs[0] @ pooled_grads[:, tf.newaxis])[..., 0].numpy()
    # Keep only positive evidence, then scale to [0, 1].
    heatmap = np.maximum(heatmap, 0)
    return heatmap / (heatmap.max() + 1e-6)


def _overlay_heatmap(img_array, heatmap):
    """Blend a [0, 1] heatmap (JET colormap, 40%) over an RGB 0-255 image (60%)."""
    height, width = img_array.shape[:2]
    # cv2.resize takes (width, height).
    heatmap = cv2.resize(heatmap.astype(np.float32), (width, height))
    colored = cv2.applyColorMap(np.uint8(np.clip(255 * heatmap, 0, 255)), cv2.COLORMAP_JET)
    base_bgr = cv2.cvtColor(img_array.astype("uint8"), cv2.COLOR_RGB2BGR)
    blended = cv2.addWeighted(base_bgr, 0.6, colored, 0.4, 0)
    return cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)


def make_gradcam(img_path, model, last_conv_layer_name="conv5_block32_concat", class_index=None,
                 target_size=(224, 224), preprocess_fn=None):
    """Return a Grad-CAM visualization for ``class_index`` overlaid on the image.

    Args:
        img_path: Path to the image file.
        model: Keras functional model to explain. The defaults (224x224 input,
            DenseNet preprocessing, layer ``conv5_block32_concat``) target the
            DenseNet201 ensemble member.
        last_conv_layer_name: Layer whose activations are weighted by the
            pooled gradients of the class score.
        class_index: Output column to explain. Defaults to the model's own
            top prediction.
        target_size: Size the image is resized to before preprocessing.
        preprocess_fn: Model-specific preprocessing. Defaults to
            ``preprocess_densenet``.

    Returns:
        RGB uint8 image of size ``target_size``: 60% original image and 40%
        JET-colored heatmap.

    Raises:
        ValueError: if the layer does not exist or no gradient reaches it.
    """
    if preprocess_fn is None:
        preprocess_fn = preprocess_densenet

    img_array = load_image(img_path, target_size)
    # preprocess_input rescales float arrays in place (and expand_dims is a
    # view). Preprocess a copy so img_array keeps its 0-255 pixels for the
    # overlay; otherwise the base image would be normalized values cast to uint8.
    input_array = preprocess_fn(np.expand_dims(img_array.copy(), axis=0))

    if class_index is None:
        preds = safe_forward(model, input_array)
        class_index = int(np.argmax(preds[0]))

    heatmap = _gradcam_heatmap(model, input_array, last_conv_layer_name, class_index)
    return _overlay_heatmap(img_array, heatmap)
# ---- Fonction Gradio ----
def gradio_predict(image_file):
    """Gradio callback: global diagnosis, per-class bar plot, Grad-CAM image.

    Args:
        image_file: Path of the uploaded image (the input component uses
            ``type="filepath"``), or None when nothing was uploaded.

    Returns:
        A 3-tuple matching the Interface outputs:
        - the 'Bénin'/'Malin' label;
        - a bar-plot update with per-class percentages sorted in descending order;
        - the Grad-CAM overlay, or None if it could not be computed.
        If no prediction is available, returns ("Erreur", None, None).
    """
    import traceback

    import pandas as pd

    if image_file is None:
        return "Erreur", None, None
    try:
        probs = predict_single(image_file)
    except Exception:
        print("Erreur dans gradio_predict :")
        traceback.print_exc()
        return "Erreur", None, None

    # All zeros means every model hit the safe_forward fallback. Reporting
    # "Bénin" (0 >= 0) in that case would be a fabricated diagnosis.
    if not np.any(probs):
        print("Erreur dans gradio_predict : aucune prédiction valide")
        return "Erreur", None, None

    sorted_idx = np.argsort(-probs)
    bar_data = pd.DataFrame({
        "Classes": [CLASS_NAMES[i].upper() for i in sorted_idx],
        "Probabilité (%)": (probs[sorted_idx] * 100).tolist(),
    })

    benign_prob = sum(p for p, cls in zip(probs, CLASS_NAMES) if diagnosis_map[cls] == "Bénin")
    malign_prob = sum(p for p, cls in zip(probs, CLASS_NAMES) if diagnosis_map[cls] == "Malin")
    global_diag = "Bénin" if benign_prob >= malign_prob else "Malin"

    # Grad-CAM is only an explanation aid: if it fails, still return the prediction.
    try:
        top_class = int(np.argmax(probs))
        gradcam_img = make_gradcam(image_file, model_densenet, class_index=top_class)
    except Exception:
        print("Erreur Grad-CAM dans gradio_predict :")
        traceback.print_exc()
        gradcam_img = None

    # gr.update works across Gradio versions; BarPlot.update was removed in
    # Gradio 4. BarPlot expects a pandas DataFrame as its value.
    plot_update = gr.update(
        value=bar_data,
        x="Classes", y="Probabilité (%)",
        title="Distribution des classes",
    )
    return global_diag, plot_update, gradcam_img
# ---- Gradio UI ----
# Example images offered under the input. Only files that exist are passed to
# Gradio: presumably a missing example path would fail when the UI is built.
# NOTE(review): confirm exemple1-3.jpg are shipped with the app.
examples = [
    path
    for path in ("exemple1.jpg", "exemple2.jpg", "exemple3.jpg")
    if os.path.exists(path)
] or None

demo = gr.Interface(
    fn=gradio_predict,
    inputs=gr.Image(type="filepath", label="Uploader une image de lésion"),
    outputs=[
        gr.Label(label="Diagnostic global"),
        gr.BarPlot(label="Probabilités par classe"),
        gr.Image(label="Visualisation Grad-CAM"),
    ],
    examples=examples,
    title="Analyse de lésions cutanées (Ensemble de modèles + Grad-CAM)",
    # A benign/malignant verdict from an image model is not a diagnosis;
    # say so in the UI.
    description=(
        "Cet outil propose une prédiction de la nature de la lésion (Bénin/Malin) "
        "avec explication visuelle. "
        "Outil expérimental : il ne remplace pas l'avis d'un dermatologue."
    ),
)

if __name__ == "__main__":
    demo.launch()