skin_care / app.py
ericjedha's picture
Update app.py
3595c95 verified
raw
history blame
4.69 kB
import os
import numpy as np
import gradio as gr
import cv2
import tensorflow as tf
import keras
from huggingface_hub import hf_hub_download
import pandas as pd
from PIL import Image
# ---- Configuration ----
# Output classes, listed in the order of the models' output columns
# (presumably HAM10000 lesion codes; this order must match the one used at
# training time — TODO confirm).
CLASS_NAMES = ['akiec', 'bcc', 'bkl', 'df', 'nv', 'vasc', 'mel']

# Reverse lookup: class name -> column index in the prediction vector.
label_to_index = dict(zip(CLASS_NAMES, range(len(CLASS_NAMES))))

# Coarse benign ("Bénin") / malignant ("Malin") grouping used to derive the
# global diagnosis from the per-class scores.
# NOTE(review): 'akiec' is grouped as 'Bénin'. If it is the HAM10000 code for
# actinic keratoses / intraepithelial carcinoma (Bowen's disease), it is often
# treated as (pre-)malignant — confirm this grouping is intended.
diagnosis_map = {
    'akiec': 'Bénin',
    'bcc': 'Malin',
    'bkl': 'Bénin',
    'df': 'Bénin',
    'nv': 'Bénin',
    'vasc': 'Bénin',
    'mel': 'Malin',
}
# ---- Model download & loading ----
# ResNet50 and DenseNet201 are fetched from the Hugging Face Hub
# (hf_hub_download returns a local file path); Xception is read from a
# "Xception.keras" file resolved relative to the working directory.
resnet_path = hf_hub_download(
    repo_id="ericjedha/resnet50",
    filename="Resnet50.keras",
)
densenet_path = hf_hub_download(
    repo_id="ericjedha/densenet201",
    filename="Densenet201.keras",
)

# compile=False: the models are only used for forward passes in this file,
# so they are not compiled after loading. Loaded in the same order as listed.
model_resnet50, model_densenet, model_xcept = (
    keras.saving.load_model(path, compile=False)
    for path in (resnet_path, densenet_path, "Xception.keras")
)
# ---- Préprocesseurs ----
from tensorflow.keras.applications.xception import preprocess_input as preprocess_xception
from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet
from tensorflow.keras.applications.densenet import preprocess_input as preprocess_densenet
# ---- Robust helpers for calling the Keras models ----
def get_primary_input_name(model):
    """Return the name of the model's first input tensor.

    Any ``:<index>`` suffix (TensorFlow-style tensor naming) is stripped.
    Falls back to ``"input_1"`` when ``model.inputs`` is not a non-empty list.
    """
    inputs = model.inputs
    if not isinstance(inputs, list) or not inputs:
        return "input_1"
    head_name, _, _ = inputs[0].name.partition(':')
    return head_name
def safe_forward(model, x):
    """Run one inference-mode forward pass of ``model`` on the batch ``x``.

    The batch is fed as a dict keyed by the model's primary input name (see
    ``get_primary_input_name``), which the original author did to avoid Keras
    UserWarnings about the input structure. The model is called with
    ``training=False`` and its output is converted with ``.numpy()``.
    """
    feed = {get_primary_input_name(model): x}
    outputs = model(feed, training=False)
    return outputs.numpy()
# ---- Prediction ----
def predict_single(image_pil, weights=(0.45, 0.25, 0.30)):
    """Return the ensemble class scores for a single image.

    The image is resized and preprocessed separately for each backbone
    (299x299 for Xception, 224x224 for ResNet50 and DenseNet201). Each model
    is run once, and the three outputs are combined as a weighted sum. The
    'mel' column is then blended with DenseNet's own 'mel' score.

    Args:
        image_pil: a PIL image. Non-RGB PIL images (e.g. grayscale, RGBA,
            palette) are converted to RGB first. Objects without a PIL
            ``mode`` (e.g. an ``(H, W, 3)`` NumPy array) are used as-is.
        weights: ``(xception, resnet50, densenet)`` weights of the weighted
            sum. The defaults sum to 1.0.

    Returns:
        1-D NumPy array of scores, indexed like ``CLASS_NAMES`` (assumed:
        model output columns follow ``CLASS_NAMES`` order — TODO confirm).
    """
    # Force a 3-channel RGB image. Grayscale ("L"), RGBA or palette ("P")
    # uploads would otherwise produce a 2-D or 4-channel array, and the
    # backbones' forward pass would fail on it.
    if getattr(image_pil, "mode", "RGB") != "RGB":
        image_pil = image_pil.convert("RGB")
    img_np = np.array(image_pil)

    # Each backbone has its own input resolution and preprocessing function.
    img_299 = cv2.resize(img_np, (299, 299))  # Xception
    img_224 = cv2.resize(img_np, (224, 224))  # ResNet50 / DenseNet201
    bx = preprocess_xception(np.expand_dims(img_299, axis=0))
    br = preprocess_resnet(np.expand_dims(img_224, axis=0))
    bd = preprocess_densenet(np.expand_dims(img_224, axis=0))

    # Batch of one: each prediction is presumably (1, len(CLASS_NAMES)).
    pred_x = safe_forward(model_xcept, bx)
    pred_r = safe_forward(model_resnet50, br)
    pred_d = safe_forward(model_densenet, bd)

    preds = weights[0] * pred_x + weights[1] * pred_r + weights[2] * pred_d

    # Melanoma override: average the ensemble's 'mel' score with DenseNet's
    # 'mel' score, giving DenseNet extra say on melanoma. After this step the
    # row is no longer guaranteed to sum to 1.
    mel_idx = label_to_index['mel']
    preds[:, mel_idx] = 0.5 * preds[:, mel_idx] + 0.5 * pred_d[:, mel_idx]
    return preds[0]
# ---- Gradio callback ----
def gradio_predict(image_pil):
    """Click handler: run the ensemble and format its results for the UI.

    Returns a 2-tuple matching ``outputs=[output_label, output_plot]``:
    - the global diagnosis, "Bénin" or "Malin", decided by comparing the summed
      scores of each group in ``diagnosis_map`` (ties go to "Bénin");
    - a DataFrame with columns ``Classe`` / ``Probabilité``, sorted by
      decreasing score.
    If no image was given, or any error occurs, returns a French message and
    ``None`` for the plot; on error the traceback is printed to the console.
    """
    if image_pil is None:
        return "Veuillez uploader une image.", None
    try:
        probs = predict_single(image_pil)
        # Score per class name; iteration order follows CLASS_NAMES.
        scores = {cls: probs[i] for i, cls in enumerate(CLASS_NAMES)}

        # Aggregate the class scores into the two coarse diagnosis groups.
        benign_prob = sum(s for cls, s in scores.items() if diagnosis_map[cls] == "Bénin")
        malign_prob = sum(s for cls, s in scores.items() if diagnosis_map[cls] == "Malin")
        if benign_prob >= malign_prob:
            global_diag = "Bénin"
        else:
            global_diag = "Malin"

        # Table feeding the bar plot, highest probability first.
        confidences = {cls: float(s) for cls, s in scores.items()}
        df = (
            pd.DataFrame.from_dict(confidences, orient='index', columns=['Probabilité'])
            .sort_values(by='Probabilité', ascending=False)
            .rename_axis("Classe")
            .reset_index()
        )
        return global_diag, df
    except Exception as e:
        # UI boundary: report the failure instead of crashing the app.
        print(f"Erreur dans gradio_predict : {e}")
        import traceback
        traceback.print_exc()
        return "Erreur lors du traitement de l'image.", None
# ---- Gradio UI (simplified) ----
# Example images shown under the upload widget; these relative paths are
# assumed to exist alongside the app — TODO confirm.
example_paths = ["exemple1.jpg", "exemple2.jpg", "exemple3.jpg"]
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Analyse de lésions cutanées (Ensemble de modèles)")
    gr.Markdown("Cet outil propose une prédiction de la nature de la lésion (Bénin/Malin).")
    # Two-column layout: upload, button and examples on the left; results on
    # the right. Components render in the order they are created here.
    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" makes Gradio hand a PIL image to gradio_predict.
            input_image = gr.Image(type="pil", label="Uploader une image de lésion")
            submit_btn = gr.Button("Analyser", variant="primary")
            gr.Examples(examples=example_paths, inputs=input_image)
        with gr.Column(scale=1):
            output_label = gr.Label(label="Diagnostic global")
            # Plots the "Classe" / "Probabilité" columns of the DataFrame
            # returned by gradio_predict.
            output_plot = gr.BarPlot(label="Probabilités par classe", x="Classe", y="Probabilité", y_lim=[0, 1])
            # The Grad-CAM output has been removed.
    # The click handler now returns exactly 2 values, one per output below.
    submit_btn.click(fn=gradio_predict, inputs=input_image, outputs=[output_label, output_plot])
if __name__ == "__main__":
    demo.launch()