# skin_care / app.py
# Page header captured with this file: author ericjedha, commit ad465b4
# ("Update app.py", verified), 6.44 kB.
import os
import numpy as np
import gradio as gr
import cv2
import tensorflow as tf
import keras
from keras.models import Model
from huggingface_hub import hf_hub_download
import pandas as pd
from PIL import Image
# ---- Configuration ----
# Class labels, in the order of the models' output units: entry i of a
# prediction vector is the score of CLASS_NAMES[i].
CLASS_NAMES = ['akiec', 'bcc', 'bkl', 'df', 'nv', 'vasc', 'mel']

# Reverse lookup: class label -> position in a prediction vector.
label_to_index = dict(zip(CLASS_NAMES, range(len(CLASS_NAMES))))

# Classes reported as malignant ("Malin"); every other class is "Bénin".
# NOTE(review): 'akiec' is mapped to benign here, although it is often grouped
# with malignant / pre-malignant lesions in binary splits — confirm intent.
_MALIGNANT_CLASSES = frozenset({'bcc', 'mel'})

# Class label -> global diagnosis group, in CLASS_NAMES order.
diagnosis_map = {
    name: ('Malin' if name in _MALIGNANT_CLASSES else 'Bénin')
    for name in CLASS_NAMES
}
# ---- Model downloads ----
# Fetch the ResNet50 and DenseNet weights from the Hugging Face Hub;
# hf_hub_download returns the local path of the downloaded file.
resnet_path = hf_hub_download(repo_id="ericjedha/resnet50", filename="Resnet50.keras")
densenet_path = hf_hub_download(repo_id="ericjedha/densenet201", filename="Densenet201.keras")
# compile=False: the models are only called for inference (training=False),
# so no optimizer/loss configuration is needed.
model_resnet50 = keras.saving.load_model(resnet_path, compile=False)
model_densenet = keras.saving.load_model(densenet_path, compile=False)
# NOTE(review): unlike the two models above, Xception is loaded from a path
# relative to the working directory — confirm "Xception.keras" ships with the app.
model_xcept = keras.saving.load_model("Xception.keras", compile=False)
# ---- Préprocesseurs ----
from tensorflow.keras.applications.xception import preprocess_input as preprocess_xception
from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet
from tensorflow.keras.applications.densenet import preprocess_input as preprocess_densenet
# ---- Fonctions utilitaires robustes pour les modèles Keras ----
def get_primary_input_name(model):
    """Return the name of the model's first input tensor.

    Any ``":<index>"`` suffix in the tensor name is stripped. When
    ``model.inputs`` is not a non-empty list, the fallback ``"input_1"`` is
    returned instead.
    """
    inputs = model.inputs
    if not isinstance(inputs, list) or not inputs:
        return "input_1"
    head, _sep, _rest = inputs[0].name.partition(':')
    return head
def safe_forward(model, x):
    """Run ``model`` on the batch ``x`` in inference mode; return a NumPy array.

    The batch is fed as a dict keyed by the model's first input name (see
    ``get_primary_input_name``), with ``training=False``, and the output is
    converted with ``.numpy()``.
    """
    feed = {get_primary_input_name(model): x}
    outputs = model(feed, training=False)
    return outputs.numpy()
# ---- Prédiction ----
def predict_single(image_pil, weights=(0.45, 0.25, 0.30)):
    """Return the ensemble class scores for a single image.

    The image is scored by three models, each at its own input size and with
    its own Keras ``preprocess_input``: Xception (299x299), ResNet50 (224x224)
    and DenseNet (224x224). Their outputs are mixed with ``weights``. The
    'mel' score is then averaged with DenseNet's own 'mel' score, and the
    vector is rescaled to sum to 1.

    Args:
        image_pil: PIL image (converted to RGB first) or an RGB uint8 array.
        weights: Mixing weights for (Xception, ResNet50, DenseNet), in that
            order.

    Returns:
        1-D NumPy array whose i-th entry is the score of ``CLASS_NAMES[i]``.
    """
    # RGBA (e.g. PNG with alpha), grayscale or palette images would feed the
    # models the wrong number of channels; normalise to 3-channel RGB.
    if isinstance(image_pil, Image.Image):
        image_pil = image_pil.convert("RGB")
    img_np = np.array(image_pil)

    img_299 = cv2.resize(img_np, (299, 299))  # Xception input size
    img_224 = cv2.resize(img_np, (224, 224))  # ResNet50 / DenseNet input size

    bx = preprocess_xception(np.expand_dims(img_299, axis=0))
    br = preprocess_resnet(np.expand_dims(img_224, axis=0))
    bd = preprocess_densenet(np.expand_dims(img_224, axis=0))

    pred_x = safe_forward(model_xcept, bx)
    pred_r = safe_forward(model_resnet50, br)
    pred_d = safe_forward(model_densenet, bd)

    preds = weights[0] * pred_x + weights[1] * pred_r + weights[2] * pred_d

    # Give DenseNet extra say on melanoma: average the ensemble 'mel' score
    # with DenseNet's own 'mel' score.
    mel_idx = label_to_index['mel']
    preds[:, mel_idx] = 0.5 * preds[:, mel_idx] + 0.5 * pred_d[:, mel_idx]

    # Assuming each model ends in a softmax and the weights sum to 1, the mix
    # above sums to 1, but the 'mel' re-blend breaks that. The scores are shown
    # to users as probabilities, so rescale them to sum to 1. A common positive
    # factor leaves the class ranking and any comparison of summed scores unchanged.
    probs = preds[0]
    total = probs.sum()
    if total > 0:
        probs = probs / total
    return probs
# ---- Grad-CAM ----
def _gradcam_heatmap(model, model_input, last_conv_layer_name, class_index):
    """Compute a Grad-CAM heatmap in [0, 1] at the conv layer's resolution.

    Returns a 2-D NumPy array (height, width) of the chosen conv layer.
    """
    grad_model = Model(
        model.inputs,
        [model.get_layer(last_conv_layer_name).output, model.output],
    )
    with tf.GradientTape() as tape:
        conv_output, preds = grad_model(model_input, training=False)
        if isinstance(preds, list):
            preds = preds[0]
        if class_index is None:
            class_index = tf.argmax(preds[0])
        class_score = preds[:, class_index]
    grads = tape.gradient(class_score, conv_output)

    # One weight per channel: average the gradients over batch, height AND
    # width (axes 0, 1, 2). Averaging over (0, 1) only left a (W, C) tensor
    # that the matmul below silently paired with the rows of the activation
    # map (it only ran because H == W).
    channel_weights = tf.reduce_mean(grads, axis=(0, 1, 2))

    # Weighted sum of the activation channels -> (H, W).
    heatmap = conv_output[0] @ channel_weights[:, tf.newaxis]
    heatmap = tf.squeeze(heatmap, axis=-1)

    # ReLU first, then scale by the positive maximum. Dividing by the raw max
    # produced negative values when the whole map was negative.
    heatmap = tf.maximum(heatmap, 0)
    heatmap = heatmap / (tf.math.reduce_max(heatmap) + 1e-8)
    return heatmap.numpy()


def _overlay_heatmap(img_rgb, heatmap):
    """Colour-map a [0, 1] heatmap and blend it over an RGB uint8 image.

    Returns an RGB uint8 image with the same height/width as ``img_rgb``.
    """
    height, width = img_rgb.shape[:2]
    # The conv activations are coarser than the image, and cv2.addWeighted
    # requires both operands to have the same size, so upsample first.
    heatmap = cv2.resize(heatmap.astype(np.float32), (width, height))
    heatmap_u8 = np.uint8(255 * np.clip(heatmap, 0.0, 1.0))
    # applyColorMap takes a single-channel uint8 image and returns a BGR image.
    colored = cv2.applyColorMap(heatmap_u8, cv2.COLORMAP_JET)
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    blended = cv2.addWeighted(img_bgr, 0.6, colored, 0.4, 0)
    return cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)


def make_gradcam(image_pil, model, last_conv_layer_name="conv5_block32_concat",
                 class_index=None, preprocess_fn=None, input_size=(224, 224)):
    """Return a Grad-CAM overlay (RGB uint8) explaining ``model`` on one image.

    Args:
        image_pil: PIL image (converted to RGB first) or an RGB uint8 array.
        model: Keras model to explain.
        last_conv_layer_name: Name of the convolutional layer whose
            activations are weighted. The default follows Keras' DenseNet201
            layer naming.
        class_index: Output index to explain; defaults to the model's top
            prediction.
        preprocess_fn: Preprocessing applied to the resized batch; defaults
            to the DenseNet preprocessor (the previous hard-coded behaviour).
        input_size: ``(width, height)`` the image is resized to before
            inference; also the size of the returned overlay.

    Returns:
        RGB uint8 array of shape ``(height, width, 3)``.
    """
    if preprocess_fn is None:
        preprocess_fn = preprocess_densenet
    # RGBA / grayscale inputs would break both the model and COLOR_RGB2BGR.
    if isinstance(image_pil, Image.Image):
        image_pil = image_pil.convert("RGB")

    img_resized = cv2.resize(np.array(image_pil), tuple(input_size))
    batch = preprocess_fn(np.expand_dims(img_resized, axis=0))
    model_input = {get_primary_input_name(model): batch}

    heatmap = _gradcam_heatmap(model, model_input, last_conv_layer_name, class_index)
    return _overlay_heatmap(img_resized, heatmap)
# ---- Fonction Gradio ----
def gradio_predict(image_pil):
    """Gradio callback for the "Analyser" button.

    Returns a 3-tuple matching the UI outputs:
    (global diagnosis text, DataFrame with "Classe"/"Probabilité" columns
    sorted by descending score, Grad-CAM overlay image).
    Without an image, or when any step raises, a message is returned in the
    first slot and ``None`` in the other two. On error, the exception is
    printed along with its traceback.
    """
    if image_pil is None:
        return "Veuillez uploader une image.", None, None

    try:
        probs = predict_single(image_pil)

        def group_score(group):
            # Sum of the class scores whose diagnosis_map entry equals `group`.
            return sum(
                probs[i]
                for i, name in enumerate(CLASS_NAMES)
                if diagnosis_map[name] == group
            )

        benign_score = group_score("Bénin")
        malign_score = group_score("Malin")
        global_diag = "Bénin" if benign_score >= malign_score else "Malin"

        # One row per class, highest score first.
        table = pd.DataFrame({
            "Classe": CLASS_NAMES,
            "Probabilité": [float(probs[i]) for i in range(len(CLASS_NAMES))],
        })
        table = table.sort_values(by="Probabilité", ascending=False)
        table = table.reset_index(drop=True)

        # Explain the top-scoring class on the DenseNet model.
        best_idx = np.argmax(probs)
        overlay = make_gradcam(image_pil, model_densenet, class_index=best_idx)
        return global_diag, table, overlay
    except Exception as exc:
        import traceback
        print(f"Erreur dans gradio_predict : {exc}")
        traceback.print_exc()
        return "Erreur lors du traitement de l'image.", None, None
# ---- Gradio UI ----
# Example images offered under the upload widget (paths relative to the
# working directory).
example_paths = ["exemple1.jpg", "exemple2.jpg", "exemple3.jpg"]
# Two-column layout: upload widget, button and examples on the left;
# diagnosis, per-class bar plot and Grad-CAM overlay on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Analyse de lésions cutanées (Ensemble de modèles + Grad-CAM)")
    gr.Markdown("Cet outil propose une prédiction de la nature de la lésion (Bénin/Malin) avec explication visuelle.")
    with gr.Row():
        with gr.Column(scale=1):
            # type="pil": the callback receives a PIL image (None when empty).
            input_image = gr.Image(type="pil", label="Uploader une image de lésion")
            submit_btn = gr.Button("Analyser", variant="primary")
            gr.Examples(examples=example_paths, inputs=input_image)
        with gr.Column(scale=1):
            output_label = gr.Label(label="Diagnostic global")
            # x/y name the DataFrame columns produced by gradio_predict.
            output_plot = gr.BarPlot(label="Probabilités par classe", x="Classe", y="Probabilité", y_lim=[0, 1])
            output_gradcam = gr.Image(label="Visualisation Grad-CAM")
    # The outputs list order must match the 3-tuple returned by gradio_predict.
    submit_btn.click(fn=gradio_predict, inputs=input_image, outputs=[output_label, output_plot, output_gradcam])
if __name__ == "__main__":
    demo.launch()