# skin_care / app.py
# ericjedha's picture
# Update app.py
# 4e23572 verified
# raw
# history blame
# 6.74 kB
import os
import numpy as np
import gradio as gr
import cv2
import tensorflow as tf
import keras
from keras.models import Model
from huggingface_hub import hf_hub_download
import pandas as pd
from PIL import Image
# ---- Configuration ----
# Output order of the classifiers: position i of every prediction vector
# corresponds to CLASS_NAMES[i].
# NOTE(review): this order (with 'mel' last, not alphabetical) must match the
# label encoding the models were trained with — confirm.
CLASS_NAMES = ['akiec', 'bcc', 'bkl', 'df', 'nv', 'vasc', 'mel']

# Reverse lookup: class name -> column index in the prediction vector.
label_to_index = dict(zip(CLASS_NAMES, range(len(CLASS_NAMES))))

# Class -> coarse verdict shown to the user ('Bénin' = benign, 'Malin' = malignant).
# NOTE(review): 'akiec' is counted as benign here; in the HAM10000 convention it
# denotes actinic keratoses / intraepithelial carcinoma, which is often grouped
# with (pre)malignant lesions — confirm this mapping is intended.
diagnosis_map = {
    'akiec': 'Bénin',
    'bcc': 'Malin',
    'bkl': 'Bénin',
    'df': 'Bénin',
    'nv': 'Bénin',
    'vasc': 'Bénin',
    'mel': 'Malin',
}
# ---- Model download / loading ----
# ResNet50 and DenseNet201 are fetched from the Hugging Face Hub; hf_hub_download
# returns a local file path that is then handed to Keras.
resnet_path = hf_hub_download(
    repo_id="ericjedha/resnet50",
    filename="Resnet50.keras",
)
densenet_path = hf_hub_download(
    repo_id="ericjedha/densenet201",
    filename="Densenet201.keras",
)

# compile=False: no training configuration is needed, the models are only
# called for inference below.
model_resnet50 = keras.saving.load_model(resnet_path, compile=False)
model_densenet = keras.saving.load_model(densenet_path, compile=False)
# Xception is read from a relative path, i.e. from the current working directory.
model_xcept = keras.saving.load_model("Xception.keras", compile=False)
# ---- Préprocesseurs ----
from tensorflow.keras.applications.xception import preprocess_input as preprocess_xception
from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet
from tensorflow.keras.applications.densenet import preprocess_input as preprocess_densenet
# ---- Robust helper functions for Keras models ----
def get_primary_input_name(model):
    """Return the name of the model's first (primary) input tensor.

    The name is meant to be used as the key when calling the model with a
    dict. Tensor names of the form ``"input_layer:0"`` are reduced to the part
    before the first ``":"`` (``"input_layer"``).

    Args:
        model: a Keras model, or any object exposing an ``inputs`` list/tuple
            whose first item has a ``name`` attribute.

    Returns:
        The primary input name, or ``"input_1"`` when the model exposes no
        inputs (attribute missing, ``None`` or empty).
    """
    # getattr with a default also absorbs an AttributeError raised by the
    # `inputs` property itself (e.g. a model whose inputs are not defined yet).
    inputs = getattr(model, "inputs", None)
    # Accept any list or tuple, not only `list`, so tuple-valued inputs are
    # not silently mapped to the fallback name.
    if isinstance(inputs, (list, tuple)) and inputs:
        return inputs[0].name.split(':')[0]
    # Fallback name when no input information is available.
    return "input_1"
def safe_forward(model, x):
    """Run `model` on the batch `x` in inference mode; return a NumPy array.

    The batch is fed as a single-entry dict keyed by the model's primary input
    name (from get_primary_input_name), which is how this file avoids Keras
    UserWarnings about the input structure.
    """
    feed = {get_primary_input_name(model): x}
    output = model(feed, training=False)
    return output.numpy()
# ---- Prediction (uses safe_forward) ----
def predict_single(image_pil, weights=(0.45, 0.25, 0.30)):
    """Return the ensemble score vector for a single image.

    The image is resized to 299x299 for Xception and to 224x224 for ResNet50
    and DenseNet201; each batch goes through its model's own preprocessing.
    The three outputs are blended with `weights`, then the 'mel' score is
    averaged with DenseNet's own 'mel' score.

    Args:
        image_pil: input image. PIL images are converted to RGB first, so
            RGBA / grayscale / palette images still give 3-channel arrays.
        weights: blend weights for (Xception, ResNet50, DenseNet201).

    Returns:
        1-D np.ndarray with one score per class, indexed like CLASS_NAMES.
    """
    # Force 3 channels: an RGBA PNG would give an (H, W, 4) array and a
    # grayscale image an (H, W) array, neither of which the models accept.
    if isinstance(image_pil, Image.Image):
        image_pil = image_pil.convert("RGB")
    img_np = np.array(image_pil)
    img_299 = cv2.resize(img_np, (299, 299))
    img_224 = cv2.resize(img_np, (224, 224))

    # Keras preprocess_input overwrites floating-point NumPy inputs in place.
    # np.expand_dims returns a view, so without copies the ResNet preprocessing
    # would alter img_224 before DenseNet sees it. Each model gets its own buffer.
    bx = preprocess_xception(np.copy(np.expand_dims(img_299, axis=0)))
    br = preprocess_resnet(np.copy(np.expand_dims(img_224, axis=0)))
    bd = preprocess_densenet(np.copy(np.expand_dims(img_224, axis=0)))

    pred_x = safe_forward(model_xcept, bx)
    pred_r = safe_forward(model_resnet50, br)
    pred_d = safe_forward(model_densenet, bd)

    preds = weights[0] * pred_x + weights[1] * pred_r + weights[2] * pred_d

    # Melanoma heuristic: give DenseNet's own 'mel' score extra influence.
    # NOTE(review): if the models output softmax probabilities, this row no
    # longer sums to 1 after the adjustment — confirm that is intended.
    mel_idx = label_to_index['mel']
    preds[:, mel_idx] = 0.5 * preds[:, mel_idx] + 0.5 * pred_d[:, mel_idx]
    return preds[0]
# ---- Grad-CAM ----
def make_gradcam(image_pil, model, last_conv_layer_name="conv5_block32_concat", class_index=None):
    """Return a Grad-CAM overlay of `model`'s decision for `image_pil`.

    The image is resized to 224x224 and preprocessed with the DenseNet
    preprocessor, so `model` is expected to share that input pipeline (it is
    called with model_densenet in this file).

    Args:
        image_pil: input image. PIL images are converted to RGB first.
        model: Keras functional model containing `last_conv_layer_name`.
        last_conv_layer_name: convolutional layer whose activations are
            weighted by the spatially pooled gradients.
        class_index: class to explain; defaults to the model's top prediction.

    Returns:
        np.ndarray of shape (224, 224, 3), dtype uint8, in RGB order: the
        resized image blended (60/40) with a JET-colored heatmap.

    Raises:
        ValueError: if no gradient flows from the class score to the layer.
    """
    # Guarantee 3 channels for preprocessing and cv2.cvtColor(RGB2BGR).
    if isinstance(image_pil, Image.Image):
        image_pil = image_pil.convert("RGB")
    img_np = np.array(image_pil)
    img_resized = cv2.resize(img_np, (224, 224))
    # Copy: preprocess_input may overwrite float inputs in place, and
    # img_resized is reused below for the overlay.
    img_array_preprocessed = preprocess_densenet(np.copy(np.expand_dims(img_resized, axis=0)))

    # Model exposing both the chosen conv activations and the final predictions.
    grad_model = Model(model.inputs, [model.get_layer(last_conv_layer_name).output, model.output])
    input_name = get_primary_input_name(model)
    input_tensor = tf.convert_to_tensor(img_array_preprocessed)

    with tf.GradientTape() as tape:
        # Watching the input records every downstream op, so the gradient
        # w.r.t. the intermediate conv output exists even if weights are frozen.
        tape.watch(input_tensor)
        last_conv_layer_output, preds = grad_model({input_name: input_tensor}, training=False)
        if isinstance(preds, (list, tuple)):
            preds = preds[0]
        if class_index is None:
            class_index = tf.argmax(preds[0])
        class_channel = preds[:, class_index]

    grads = tape.gradient(class_channel, last_conv_layer_output)
    if grads is None:
        raise ValueError(
            f"No gradient from the class score to layer '{last_conv_layer_name}'."
        )

    # Channel importance = gradient averaged over batch and spatial axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    heatmap = last_conv_layer_output[0] @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap, axis=-1)
    # ReLU first, then normalise by the max of the ReLU'd map. This avoids a
    # negative or ~zero divisor when every raw value is <= 0.
    heatmap = tf.maximum(heatmap, 0)
    heatmap = heatmap / (tf.math.reduce_max(heatmap) + 1e-8)
    heatmap = np.uint8(255 * heatmap.numpy())

    # The conv feature map is much coarser than the 224x224 image (presumably
    # 7x7 for DenseNet201 at this size). cv2.addWeighted needs both arrays to
    # have the same size, so upsample before coloring. dsize is (width, height).
    heatmap = cv2.resize(heatmap, (img_resized.shape[1], img_resized.shape[0]))
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    img_bgr = cv2.cvtColor(img_resized, cv2.COLOR_RGB2BGR)
    superimposed_img = cv2.addWeighted(img_bgr, 0.6, heatmap, 0.4, 0)
    return cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)
# ---- Gradio callback ----
def gradio_predict(image_pil):
    """Gradio callback: classify the image and build the three UI outputs.

    Args:
        image_pil: the uploaded image, or None when nothing was uploaded.

    Returns:
        (global_diag, df, gradcam_img), where:
        - global_diag is "Bénin" or "Malin", whichever group has the larger
          summed score (ties count as "Bénin").
        - df is a DataFrame with columns "Classe" / "Probabilité", sorted by
          descending score.
        - gradcam_img is the Grad-CAM overlay for the top class.

        Special cases:
        - No image: returns (message, None, None).
        - Prediction error: returns (message, None, None).
        - Only the Grad-CAM step fails: the diagnosis and table are still
          returned, with None for the image.
    """
    import traceback

    if image_pil is None:
        return "Veuillez uploader une image.", None, None
    try:
        probs = predict_single(image_pil)
        # A size mismatch means the labels would be wrong; fail loudly.
        if len(probs) != len(CLASS_NAMES):
            raise ValueError(
                f"Model returned {len(probs)} scores for {len(CLASS_NAMES)} classes."
            )
        benign_prob = sum(p for p, cls in zip(probs, CLASS_NAMES) if diagnosis_map[cls] == "Bénin")
        malign_prob = sum(p for p, cls in zip(probs, CLASS_NAMES) if diagnosis_map[cls] == "Malin")
        global_diag = "Bénin" if benign_prob >= malign_prob else "Malin"

        confidences = {cls: float(p) for cls, p in zip(CLASS_NAMES, probs)}
        df = pd.DataFrame.from_dict(confidences, orient='index', columns=['Probabilité'])
        df = df.sort_values(by='Probabilité', ascending=False)
        df.index.name = "Classe"
        df = df.reset_index()
        top_class_idx = int(np.argmax(probs))
    except Exception as e:
        print(f"Erreur dans gradio_predict : {e}")
        traceback.print_exc()
        return "Erreur lors du traitement de l'image.", None, None

    # Grad-CAM is only a visual explanation. If it fails, keep the diagnosis
    # and probabilities instead of discarding the whole result.
    try:
        gradcam_img = make_gradcam(image_pil, model_densenet, class_index=top_class_idx)
    except Exception as e:
        print(f"Erreur Grad-CAM (diagnostic conservé) : {e}")
        traceback.print_exc()
        gradcam_img = None
    return global_diag, df, gradcam_img
# ---- Gradio UI ----
# Example images listed under the upload widget. The paths are relative to the
# working directory.
# NOTE(review): confirm these files ship alongside the app.
example_paths = ["exemple1.jpg", "exemple2.jpg", "exemple3.jpg"]
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Analyse de lésions cutanées (Ensemble de modèles + Grad-CAM)")
    gr.Markdown("Cet outil propose une prédiction de la nature de la lésion (Bénin/Malin) avec explication visuelle.")
    with gr.Row():
        # Left column: image input (delivered as a PIL image), submit button, examples.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Uploader une image de lésion")
            submit_btn = gr.Button("Analyser", variant="primary")
            gr.Examples(examples=example_paths, inputs=input_image)
        # Right column: the three outputs filled by gradio_predict, in the same order.
        with gr.Column(scale=1):
            output_label = gr.Label(label="Diagnostic global")
            # x / y must match the "Classe" / "Probabilité" columns of the
            # DataFrame returned by gradio_predict.
            output_plot = gr.BarPlot(label="Probabilités par classe", x="Classe", y="Probabilité", y_lim=[0, 1])
            output_gradcam = gr.Image(label="Visualisation Grad-CAM")
    # Clicking the button runs gradio_predict. Its 3-tuple result maps
    # positionally onto these outputs.
    submit_btn.click(fn=gradio_predict, inputs=input_image, outputs=[output_label, output_plot, output_gradcam])
# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()