# skin_care / app.py
# ericjedha's picture
# Update app.py
# e177ec4 verified
# raw
# history blame
# 6.51 kB
import os
import numpy as np
import gradio as gr
import cv2
import tensorflow as tf
import keras
from keras.models import Model
from huggingface_hub import hf_hub_download
import pandas as pd
from PIL import Image
# ---- Configuration ----
# Class labels, in the order used to index the model output vectors.
CLASS_NAMES = ['akiec', 'bcc', 'bkl', 'df', 'nv', 'vasc', 'mel']

# Reverse lookup: class label -> column index in a prediction vector.
label_to_index = dict(zip(CLASS_NAMES, range(len(CLASS_NAMES))))

# Coarse diagnosis shown to the user for each class label.
_MALIGNANT_CLASSES = ('bcc', 'mel')
diagnosis_map = {
    name: ('Malin' if name in _MALIGNANT_CLASSES else 'Bénin')
    for name in CLASS_NAMES
}
# ---- Model download and loading ----
def _load_for_inference(path):
    """Load a saved Keras model from ``path`` without compiling it."""
    return keras.saving.load_model(path, compile=False)


# The ResNet50 and DenseNet201 files are fetched from the Hugging Face Hub.
resnet_path = hf_hub_download(
    repo_id="ericjedha/resnet50",
    filename="Resnet50.keras",
)
densenet_path = hf_hub_download(
    repo_id="ericjedha/densenet201",
    filename="Densenet201.keras",
)

model_resnet50 = _load_for_inference(resnet_path)
model_densenet = _load_for_inference(densenet_path)
# The Xception file is read from a path relative to the working directory.
model_xcept = _load_for_inference("Xception.keras")
# ---- Préprocesseurs ----
from tensorflow.keras.applications.xception import preprocess_input as preprocess_xception
from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet
from tensorflow.keras.applications.densenet import preprocess_input as preprocess_densenet
# ---- Fonctions utilitaires ----
def get_primary_input_name(model):
    """Return the name of the model's first input, without any ``:N`` suffix.

    Falls back to ``"input_1"`` when ``model.inputs`` is not a non-empty list.
    """
    inputs = model.inputs
    if not isinstance(inputs, list) or not inputs:
        return "input_1"
    # Keep only the part before the first ':' (e.g. "input_layer:0").
    base_name, _, _ = inputs[0].name.partition(':')
    return base_name
def safe_forward(model, x):
    """Run ``model`` on ``x`` in inference mode and return a NumPy array.

    The batch is passed as a dict keyed by the model's primary input name,
    as reported by ``get_primary_input_name``.
    """
    feed = {get_primary_input_name(model): x}
    outputs = model(feed, training=False)
    return outputs.numpy()
# ---- Prédiction ----
def predict_single(image_pil, weights=(0.45, 0.25, 0.30)):
    """Return the weighted-ensemble class scores for a single image.

    The image is resized to 299x299 for Xception and to 224x224 for ResNet50
    and DenseNet, each copy going through that architecture's own
    ``preprocess_input`` function.

    Args:
        image_pil: PIL image to classify. It is converted to RGB first so that
            RGBA, palette or grayscale uploads yield a 3-channel array instead
            of failing inside the models.
        weights: Ensemble weights, in the order (Xception, ResNet50, DenseNet).

    Returns:
        A 1-D NumPy array of per-class scores, indexed like ``CLASS_NAMES``.
        Because of the melanoma adjustment below, the scores are not
        re-normalised after blending.
    """
    img_np = np.array(image_pil.convert("RGB"))
    img_299 = cv2.resize(img_np, (299, 299))
    img_224 = cv2.resize(img_np, (224, 224))

    bx = preprocess_xception(np.expand_dims(img_299, axis=0))
    br = preprocess_resnet(np.expand_dims(img_224, axis=0))
    bd = preprocess_densenet(np.expand_dims(img_224, axis=0))

    pred_x = safe_forward(model_xcept, bx)
    pred_r = safe_forward(model_resnet50, br)
    pred_d = safe_forward(model_densenet, bd)

    preds = weights[0] * pred_x + weights[1] * pred_r + weights[2] * pred_d

    # The melanoma column is replaced by the mean of the ensemble value and
    # the DenseNet value, giving DenseNet extra influence on that class.
    mel_idx = label_to_index['mel']
    preds[:, mel_idx] = 0.5 * preds[:, mel_idx] + 0.5 * pred_d[:, mel_idx]

    # Single-image batch: return the only row.
    return preds[0]
# ---- Grad-CAM (Version Finale avec nom de couche fixe) ----
def make_gradcam(image_pil, model, class_index=None):
    """Blend a Grad-CAM heatmap for one class onto the input image.

    The image is resized to 224x224, preprocessed with the DenseNet
    ``preprocess_input`` function and run through a sub-model exposing both
    the activations of the layer named ``"relu"`` and the final predictions.
    The gradient of the selected class score with respect to those
    activations weights each channel of the activation map.

    Args:
        image_pil: PIL image to explain; converted to RGB before processing.
        model: Keras model that contains a layer named ``"relu"``.
        class_index: Index of the class to explain. When ``None``, the class
            with the highest predicted score is used.

    Returns:
        An RGB array with the colour-mapped heatmap blended over the resized
        image, or the resized image alone when no gradient is available.
    """
    # Per the original author's note, "relu" is the last feature layer of
    # DenseNet201.
    last_conv_layer_name = "relu"

    img_np = np.array(image_pil.convert("RGB"))
    img_resized = cv2.resize(img_np, (224, 224))
    img_array_preprocessed = preprocess_densenet(np.expand_dims(img_resized, axis=0))

    grad_model = Model(
        model.inputs,
        [model.get_layer(last_conv_layer_name).output, model.output],
    )
    input_for_model = {get_primary_input_name(model): img_array_preprocessed}

    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(input_for_model, training=False)
        if isinstance(preds, list):
            preds = preds[0]
        if class_index is None:
            class_index = tf.argmax(preds[0])
        class_channel = preds[:, class_index]

    grads = tape.gradient(class_channel, last_conv_layer_output)
    if grads is None:
        print("Erreur: Le gradient est None.")
        return img_resized

    # One importance weight per channel: average over the batch axis and both
    # spatial axes. Averaging over (0, 1) only would leave a (W, C) tensor and
    # apply a different weight vector to each row of the activation map.
    # NOTE(review): assumes the layer output is (batch, H, W, C) - confirm.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    activations = last_conv_layer_output[0]
    heatmap = tf.squeeze(activations @ pooled_grads[..., tf.newaxis])

    # Keep positive evidence only, then scale by the max of the clipped map so
    # the denominator can never be negative.
    heatmap = tf.maximum(heatmap, 0)
    heatmap = heatmap / (tf.math.reduce_max(heatmap) + 1e-8)
    heatmap = heatmap.numpy().astype(np.float32)

    # The activation map is coarser than the image; cv2.addWeighted needs both
    # inputs to have the same size, so upsample before colouring.
    height, width = img_resized.shape[:2]
    heatmap = cv2.resize(heatmap, (width, height))
    heatmap = np.uint8(255 * np.clip(heatmap, 0.0, 1.0))
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

    # applyColorMap yields BGR, so blend in BGR and convert back to RGB.
    img_bgr = cv2.cvtColor(img_resized, cv2.COLOR_RGB2BGR)
    superimposed_img = cv2.addWeighted(img_bgr, 0.6, heatmap, 0.4, 0)
    return cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)
# ---- Fonction Gradio ----
def gradio_predict(image_pil):
    """Gradio callback: classify an image and build the three UI outputs.

    Returns a tuple ``(diagnosis, table, gradcam)``:
      * ``diagnosis`` - "Bénin" or "Malin", whichever group of classes has the
        larger summed score (ties go to "Bénin"), or a message string when no
        image was given or processing failed;
      * ``table`` - DataFrame with columns "Classe" and "Probabilité", sorted
        by decreasing score, or ``None``;
      * ``gradcam`` - Grad-CAM overlay for the top-scoring class, computed
        with the DenseNet model, or ``None``.
    """
    if image_pil is None:
        return "Veuillez uploader une image.", None, None

    try:
        probs = predict_single(image_pil)

        def group_score(group):
            # Summed score of every class mapped to the given diagnosis group.
            return sum(
                probs[i]
                for i, cls in enumerate(CLASS_NAMES)
                if diagnosis_map[cls] == group
            )

        benign_prob = group_score("Bénin")
        malign_prob = group_score("Malin")
        if benign_prob >= malign_prob:
            global_diag = "Bénin"
        else:
            global_diag = "Malin"

        confidences = {cls: float(probs[i]) for i, cls in enumerate(CLASS_NAMES)}
        table = (
            pd.DataFrame.from_dict(confidences, orient='index', columns=['Probabilité'])
            .sort_values(by='Probabilité', ascending=False)
            .rename_axis("Classe")
            .reset_index()
        )

        # Explain the single most likely class.
        best_idx = np.argmax(probs)
        overlay = make_gradcam(image_pil, model_densenet, class_index=best_idx)
        return global_diag, table, overlay
    except Exception as exc:
        # UI boundary: log the failure and show a generic message instead of
        # propagating the exception to Gradio.
        print(f"Erreur dans gradio_predict : {exc}")
        import traceback
        traceback.print_exc()
        return "Erreur lors du traitement de l'image.", None, None
# ---- Gradio UI ----
# Sample images offered under the upload widget.
example_paths = ["exemple1.jpg", "exemple2.jpg", "exemple3.jpg"]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Analyse de lésions cutanées (Ensemble de modèles + Grad-CAM)")
    gr.Markdown("Cet outil propose une prédiction de la nature de la lésion (Bénin/Malin) avec explication visuelle.")

    with gr.Row():
        # Left column: image upload, trigger button and examples.
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="Uploader une image de lésion",
            )
            submit_btn = gr.Button("Analyser", variant="primary")
            gr.Examples(examples=example_paths, inputs=input_image)

        # Right column: diagnosis, per-class bar plot and Grad-CAM overlay.
        with gr.Column(scale=1):
            output_label = gr.Label(label="Diagnostic global")
            output_plot = gr.BarPlot(
                label="Probabilités par classe",
                x="Classe",
                y="Probabilité",
                y_lim=[0, 1],
            )
            output_gradcam = gr.Image(label="Visualisation Grad-CAM")

    # The three outputs match the tuple returned by gradio_predict.
    submit_btn.click(
        fn=gradio_predict,
        inputs=input_image,
        outputs=[output_label, output_plot, output_gradcam],
    )

if __name__ == "__main__":
    demo.launch()