# skin_care / app.py — Hugging Face Space file by ericjedha
# Last commit: "Update app.py" (825c5d4, verified) — 5.94 kB
import os
#os.environ["KERAS_BACKEND"] = "jax"
import numpy as np
import gradio as gr
import cv2
import tensorflow as tf
import keras
from keras.models import Model
from keras.preprocessing import image
from huggingface_hub import hf_hub_download
# ---- Configuration ----
# Column order of every prediction vector: score i belongs to CLASS_NAMES[i].
# NOTE(review): this order is not alphabetical (sorted codes would put 'mel'
# before 'nv'); directory-based Keras loaders sort class folders, so confirm
# it matches the label order the models were trained with.
CLASS_NAMES = ['akiec', 'bcc', 'bkl', 'df', 'nv', 'vasc', 'mel']
NUM_CLASSES = len(CLASS_NAMES)
label_to_index = dict(zip(CLASS_NAMES, range(NUM_CLASSES)))

# Verdict shown to the user for each lesion class: 'bcc' and 'mel' are reported
# as malignant ("Malin"), every other class as benign ("Bénin").
# NOTE(review): 'akiec' presumably stands for actinic keratosis /
# intraepithelial carcinoma, which is often grouped with the malignant
# classes — confirm that "Bénin" is intended here.
diagnosis_map = {
    name: ('Malin' if name in ('bcc', 'mel') else 'Bénin')
    for name in CLASS_NAMES
}
# ---- Download model files from the Hugging Face Hub ----
# hf_hub_download returns the local path of the downloaded file.
resnet_path, densenet_path = (
    hf_hub_download(repo_id=repo, filename=fname)
    for repo, fname in (
        ("ericjedha/resnet50", "Resnet50.keras"),
        ("ericjedha/densenet201", "Densenet201.keras"),
    )
)
# ---- Load the models ----
# compile=False: in this file the models are only used for predict() and
# gradient computation, so no optimizer/loss state is restored.
# NOTE(review): Xception is read from a path relative to the working directory
# rather than the Hub — presumably the file ships next to app.py; confirm.
model_resnet50, model_densenet, model_xcept = (
    keras.saving.load_model(model_file, compile=False)
    for model_file in (resnet_path, densenet_path, "Xception.keras")
)
# ---- Préprocesseurs ----
from tensorflow.keras.applications.xception import preprocess_input as preprocess_xception
from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet
from tensorflow.keras.applications.densenet import preprocess_input as preprocess_densenet
def load_image(path, target_size):
    """Decode the image file at *path*, resized to *target_size*.

    Args:
        path: filesystem path of the image.
        target_size: ``(height, width)`` forwarded to ``image.load_img``.

    Returns:
        The pixel data as a NumPy array, as produced by ``image.img_to_array``.
    """
    pil_img = image.load_img(path, target_size=target_size)
    pixels = image.img_to_array(pil_img)
    return pixels
# ---- Prediction safety ----
def safe_preds(pred_array, target_len=NUM_CLASSES):
    """Force a 2-D batch of score rows to exactly *target_len* columns.

    Rows that are too wide keep only their first *target_len* columns; rows
    that are too narrow are right-padded with zeros. Used before blending the
    outputs of several models so their widths always agree.

    Args:
        pred_array: array-like of shape ``(batch, n)``, e.g. ``model.predict`` output.
        target_len: number of columns wanted in the result.

    Returns:
        NumPy array of shape ``(batch, target_len)``.
    """
    arr = np.array(pred_array)
    width = arr.shape[1]
    if width > target_len:
        return arr[:, :target_len]
    if width < target_len:
        padding = np.zeros((arr.shape[0], target_len - width))
        return np.concatenate([arr, padding], axis=1)
    return arr
# ---- Single-image prediction ----
def predict_single(img_path, weights=(0.45, 0.25, 0.30)):
    """Return the ensemble class scores for one image.

    The image is scored by Xception (299x299 input), ResNet50 and DenseNet201
    (224x224 input), each with its own ``preprocess_input``. The three score
    rows are blended with *weights*. The 'mel' score is then averaged 50/50
    with the DenseNet201 'mel' score alone.

    Args:
        img_path: path of the image file.
        weights: blend weights for (Xception, ResNet50, DenseNet201).

    Returns:
        1-D NumPy array of NUM_CLASSES scores, indexed like CLASS_NAMES.
        NOTE(review): after the 'mel' adjustment the scores are not guaranteed
        to sum to 1, even if every model outputs a softmax.
    """
    # Decode each input resolution once; ResNet50 and DenseNet201 share 224x224.
    img_299 = load_image(img_path, (299, 299))
    img_224 = load_image(img_path, (224, 224))

    # Keras' preprocess_input overwrites float NumPy inputs in place, so each
    # model that consumes the shared 224x224 array gets its own copy.
    bx = preprocess_xception(np.expand_dims(img_299, axis=0))
    br = preprocess_resnet(np.expand_dims(img_224.copy(), axis=0))
    bd = preprocess_densenet(np.expand_dims(img_224.copy(), axis=0))

    pred_x = safe_preds(model_xcept.predict(bx, verbose=0))
    pred_r = safe_preds(model_resnet50.predict(br, verbose=0))
    pred_d = safe_preds(model_densenet.predict(bd, verbose=0))

    # Weighted blend of the three models.
    preds = weights[0] * pred_x + weights[1] * pred_r + weights[2] * pred_d

    # Pull the 'mel' score toward DenseNet201's opinion; the guards make this
    # a no-op if 'mel' is not a known / in-range class.
    mel_idx = label_to_index.get('mel', None)
    if mel_idx is not None and mel_idx < preds.shape[1] and mel_idx < pred_d.shape[1]:
        preds[:, mel_idx] = 0.5 * preds[:, mel_idx] + 0.5 * pred_d[:, mel_idx]

    return preds[0]
# ---- Grad-CAM ----
def make_gradcam(img_path, model, last_conv_layer_name="conv5_block32_concat", class_index=None):
    """Overlay a Grad-CAM heatmap for *class_index* on the image at *img_path*.

    The image is loaded at 224x224 and preprocessed with the DenseNet
    ``preprocess_input``, so *model* must accept that input. The default layer
    name is presumably DenseNet201's last concat block.

    Args:
        img_path: path of the image file.
        model: Keras functional model containing *last_conv_layer_name*.
        last_conv_layer_name: layer whose activations are weighted by the
            class gradients.
        class_index: output column to explain; when None, the model's
            top-scoring class for this image.

    Returns:
        224x224 RGB uint8 image: 60% original pixels + 40% JET-colored heatmap.
    """
    img = image.load_img(img_path, target_size=(224, 224))
    img_array = image.img_to_array(img)
    # preprocess_input overwrites float NumPy inputs in place (and expand_dims
    # returns a view), so preprocess a copy: img_array must keep the raw pixel
    # values for the overlay at the end.
    input_array = preprocess_densenet(np.expand_dims(img_array.copy(), axis=0))

    if class_index is None:
        preds = np.array(model.predict(input_array, verbose=0))
        class_index = np.argmax(preds[0])
    class_index = int(class_index)

    # Model exposing both the target conv activations and the final scores.
    grad_model = Model(
        inputs=model.inputs,
        outputs=[model.get_layer(last_conv_layer_name).output, model.output],
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(input_array)
        # A multi-output model returns a list; explain its first output.
        if isinstance(predictions, list):
            predictions = predictions[0]
        predictions = tf.convert_to_tensor(predictions)
        loss = predictions[:, class_index]

    # Drop the batch axis: gradients and activations are both (H, W, C).
    grads = tape.gradient(loss, conv_outputs)[0]
    # One weight per channel: average the gradient over the spatial axes only.
    # Reducing over all three axes collapsed this to a scalar and broke the matmul.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1))
    heatmap = tf.squeeze(conv_outputs[0] @ pooled_grads[..., tf.newaxis])

    # Keep positive evidence only and scale into [0, 1]. Convert explicitly to a
    # float32 NumPy array (what cv2.resize expects) instead of relying on
    # tensor methods after NumPy ufuncs have already produced an ndarray.
    heatmap = np.maximum(np.asarray(heatmap, dtype="float32"), 0)
    heatmap = heatmap / (heatmap.max() + 1e-6)
    heatmap = cv2.resize(heatmap, (224, 224))
    heatmap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)

    # OpenCV works in BGR: blend in BGR, then return RGB.
    background = cv2.cvtColor(img_array.astype("uint8"), cv2.COLOR_RGB2BGR)
    superimposed_img = cv2.addWeighted(background, 0.6, heatmap, 0.4, 0)
    return cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)
# ---- Gradio callback ----
def gradio_predict(image_file):
    """Run the ensemble on one uploaded image and build the three UI outputs.

    Args:
        image_file: path of the uploaded image (the UI input is a ``gr.Image``
            with ``type="filepath"``), or None when nothing was uploaded.

    Returns:
        Tuple ``(global_diag, bar_plot_update, gradcam_img)``:
        - ``global_diag``: "Bénin" or "Malin", whichever group of classes
          (per ``diagnosis_map``) has the larger summed score; ties give "Bénin".
        - ``bar_plot_update``: ``gr.update`` carrying a DataFrame of per-class
          scores in percent, highest first, plus the plot's axes and title.
        - ``gradcam_img``: DenseNet201 Grad-CAM overlay for the top class.

    Raises:
        gr.Error: if no image was provided (Gradio shows the message to the user).
    """
    # Gradio's BarPlot takes its data as a pandas DataFrame.
    import pandas as pd

    if image_file is None:
        raise gr.Error("Veuillez fournir une image de lésion.")

    probs = predict_single(image_file)

    # Per-class scores as percentages, highest first.
    sorted_idx = np.argsort(-probs)
    bar_data = pd.DataFrame({
        "Classes": [CLASS_NAMES[i].upper() for i in sorted_idx],
        "Probabilité (%)": (probs[sorted_idx] * 100).tolist(),
    })

    # Aggregate class scores into the two verdict groups.
    benign_prob = sum(p for p, cls in zip(probs, CLASS_NAMES) if diagnosis_map[cls] == "Bénin")
    malign_prob = sum(p for p, cls in zip(probs, CLASS_NAMES) if diagnosis_map[cls] == "Malin")
    global_diag = "Bénin" if benign_prob >= malign_prob else "Malin"

    top_class = int(np.argmax(probs))
    gradcam_img = make_gradcam(image_file, model_densenet, class_index=top_class)

    # `gr.BarPlot.update` was removed in Gradio 4; `gr.update` is available in
    # Gradio 3 and 4+ and also carries the axis/title settings for the plot.
    bar_update = gr.update(
        value=bar_data,
        x="Classes",
        y="Probabilité (%)",
        title="Distribution des classes",
    )
    return global_diag, bar_update, gradcam_img
# ---- Gradio UI ----
# Example inputs listed under the upload widget.
# NOTE(review): relative paths — presumably these images sit next to app.py.
examples = ["exemple1.jpg", "exemple2.jpg", "exemple3.jpg"]


def _build_demo():
    """Assemble the single-page interface around ``gradio_predict``.

    The output widgets are listed in the same order as the tuple returned by
    ``gradio_predict``: verdict label, per-class bar plot, Grad-CAM image.
    """
    upload = gr.Image(type="filepath", label="Uploader une image de lésion")
    result_widgets = [
        gr.Label(label="Diagnostic global"),
        gr.BarPlot(label="Probabilités par classe"),
        gr.Image(label="Visualisation Grad-CAM"),
    ]
    return gr.Interface(
        fn=gradio_predict,
        inputs=upload,
        outputs=result_widgets,
        examples=examples,
        title="Analyse de lésions cutanées (Ensemble de modèles + Grad-CAM)",
        description="Cet outil propose une prédiction de la nature de la lésion (Bénin/Malin) avec explication visuelle.",
    )


demo = _build_demo()

if __name__ == "__main__":
    demo.launch()