skin_care / app.py
ericjedha's picture
Update app.py
f99f68e verified
raw
history blame
15.5 kB
import os
import numpy as np
import gradio as gr
import cv2
import tensorflow as tf
import keras
from keras.models import Model
from keras.preprocessing import image
from huggingface_hub import hf_hub_download
import pandas as pd
from PIL import Image
import plotly.express as px
import time
# Run on CPU only and quieten TensorFlow's native logging.
# NOTE(review): tensorflow is imported before these variables are set, so
# they may not affect anything TF already configured during that import —
# TODO confirm; set_visible_devices below hides the GPUs explicitly anyway.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '-1',
    'TF_CPP_MIN_LOG_LEVEL': '2',
})
tf.config.set_visible_devices([], 'GPU')
# ---- Configuration ----
# Lesion classes; index i of every model output vector is read as CLASS_NAMES[i].
CLASS_NAMES = ['akiec', 'bcc', 'bkl', 'df', 'nv', 'vasc', 'mel']
# Reverse lookup: class code -> output index.
label_to_index = dict(zip(CLASS_NAMES, range(len(CLASS_NAMES))))
# Coarse verdict shown to the user for each predicted class.
# NOTE(review): 'akiec' is mapped to benign here; it is often grouped with
# the malignant classes — confirm this is intended.
diagnosis_map = {
    'akiec': 'Bénin',
    'bcc': 'Malin',
    'bkl': 'Bénin',
    'df': 'Bénin',
    'nv': 'Bénin',
    'vasc': 'Bénin',
    'mel': 'Malin',
}
# ---- Chargement des modèles ----
def load_models_safely():
    """Load the three ensemble backbones, tolerating individual failures.

    ResNet50 and DenseNet201 are downloaded from the Hugging Face Hub;
    Xception is read from a local ``Xception.keras`` file if it exists.
    A backbone that cannot be loaded is recorded as ``None`` and the reason
    is printed, so one broken model does not take the whole app down.

    Returns:
        dict: keys ``'resnet50'``, ``'densenet201'`` and ``'xception'`` mapped
        to the loaded Keras model, or ``None`` when loading failed.

    Raises:
        RuntimeError: if none of the three models could be loaded.
    """

    def _from_hub(repo_id, filename):
        # Fetch the file from the Hub, then deserialize it without compiling.
        path = hf_hub_download(repo_id=repo_id, filename=filename)
        return keras.saving.load_model(path, compile=False)

    def _from_local(path):
        # A missing local file is not an error: the backbone is just skipped.
        if not os.path.exists(path):
            print(f"⚠️ {path} introuvable, modèle ignoré")
            return None
        return keras.saving.load_model(path, compile=False)

    # (key, start message, success message, loader)
    specs = (
        ("resnet50", "📥 Téléchargement ResNet50...", "✅ ResNet50 chargé",
         lambda: _from_hub("ericjedha/resnet50", "Resnet50.keras")),
        ("densenet201", "📥 Téléchargement DenseNet201...", "✅ DenseNet201 chargé",
         lambda: _from_hub("ericjedha/densenet201", "Densenet201.keras")),
        ("xception", "📥 Chargement Xception local...", "✅ Xception chargé",
         lambda: _from_local("Xception.keras")),
    )

    models = {}
    for key, start_msg, ok_msg, loader in specs:
        print(start_msg)
        try:
            model = loader()
        except Exception as e:  # best effort: keep going with the other backbones
            print(f"⚠️ Échec du chargement de {key}: {e}")
            model = None
        else:
            if model is not None:
                print(ok_msg)
        models[key] = model

    loaded = [key for key, model in models.items() if model is not None]
    if not loaded:
        raise RuntimeError("❌ Aucun modèle n'a pu être chargé!")
    print(f"🎯 Modèles chargés: {loaded}")
    return models
# Load once at import time; any individual backbone may end up as None.
try:
    models_dict = load_models_safely()
    model_resnet50, model_densenet, model_xcept = (
        models_dict.get(name) for name in ('resnet50', 'densenet201', 'xception')
    )
except Exception as e:
    # Nothing usable: keep the module importable; the UI banner and the
    # __main__ guard both report the missing models.
    print(f"🚨 ERREUR CRITIQUE: {e}")
    model_resnet50 = model_densenet = model_xcept = None
# ---- Préprocesseurs ----
from tensorflow.keras.applications.xception import preprocess_input as preprocess_xception
from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet
from tensorflow.keras.applications.densenet import preprocess_input as preprocess_densenet
# ---- Utils ----
def _renorm_safe(p: np.ndarray) -> np.ndarray:
p = np.clip(p, 0.0, None) # Évite les valeurs négatives
s = np.sum(p)
if s <= 0:
return np.ones_like(p, dtype=np.float32) / len(p)
normalized = p / s
return normalized / np.sum(normalized) if np.sum(normalized) > 1.0001 else normalized
def get_primary_input_name(model):
    """Return the name of the model's first input, cut at the first ':'.

    Falls back to "input_1" when ``model.inputs`` is not a non-empty list.
    """
    inputs = model.inputs
    if not isinstance(inputs, list) or not inputs:
        return "input_1"
    head, _sep, _rest = inputs[0].name.partition(':')
    return head
# Helper progress robuste
import time
import numpy as np
import time
import numpy as np
def _update_progress(progress, value, desc=None, animate=False, sleep=0.00):
"""
Met à jour la barre de progression Gradio.
- progress : objet gr.Progress
- value : valeur cible (0–100 ou 0–1)
- desc : texte affiché
- animate : si True, interpolation fluide entre l'ancienne valeur et la nouvelle
- sleep : temps d'attente (secondes) après update, pour forcer l'UI à se rafraîchir
"""
if progress is None:
return
# normalisation
try:
val = float(value)
if val > 1.0:
val = val / 100.0
except Exception:
val = 0.0
# récupérer la dernière valeur connue
last_val = getattr(progress, "_last_val", 0.0)
try:
if animate and val > last_val:
# interpolation fluide
steps = 8
for step in np.linspace(last_val, val, steps):
if desc:
progress(float(step), desc=desc)
else:
progress(float(step))
time.sleep(0.02) # vitesse de lissage
else:
# mise à jour directe
if desc:
progress(val, desc=desc)
else:
progress(val)
except Exception:
pass
# sauvegarder la valeur pour la prochaine fois
progress._last_val = val
# petit délai optionnel pour forcer le rafraîchissement
if sleep > 0:
time.sleep(sleep)
# ---- PREDICT SINGLE ----
def predict_single(img_input, weights=(0.45, 0.25, 0.30), normalize=True):
    """Run every loaded backbone on one image and combine them into an ensemble.

    Args:
        img_input: path to an image file (``str`` or ``os.PathLike``) or a
            PIL image; it is converted to RGB.
        weights: ensemble weights for (xception, resnet50, densenet201), in
            that order. Weights of backbones that are not loaded are skipped.
        normalize: if True, pass the ensemble through ``_renorm_safe`` so it
            sums to 1 even when some backbones are missing.

    Returns:
        dict: one entry per loaded backbone (``'xception'``, ``'resnet50'``,
        ``'densenet201'``) holding its output vector, plus ``'ensemble'``, a
        float32 vector of length ``len(CLASS_NAMES)``.

    Raises:
        ValueError: if ``img_input`` is neither a path nor a PIL image.
    """
    if isinstance(img_input, (str, os.PathLike)):
        # The context manager guarantees the file handle is released.
        with Image.open(img_input) as opened:
            pil_img = opened.convert("RGB")
    elif isinstance(img_input, Image.Image):
        pil_img = img_input.convert("RGB")
    else:
        raise ValueError("img_input doit être un chemin (str) ou une image PIL")

    # (key, model, square input side, preprocess_input); order matches `weights`.
    backbones = (
        ("xception", model_xcept, 299, preprocess_xception),
        ("resnet50", model_resnet50, 224, preprocess_resnet),
        ("densenet201", model_densenet, 224, preprocess_densenet),
    )

    preds = {}
    ensemble = np.zeros(len(CLASS_NAMES), dtype=np.float32)
    for weight_idx, (key, model, side, preprocess) in enumerate(backbones):
        if model is None:
            continue
        resized = np.array(pil_img.resize((side, side), resample=Image.BILINEAR))
        batch = np.expand_dims(preprocess(resized), axis=0)
        preds[key] = model.predict(batch, verbose=0)[0]
        ensemble += weights[weight_idx] * preds[key]

    if "densenet201" in preds:
        # Blend the ensemble's melanoma score 50/50 with DenseNet201's own
        # melanoma score, giving DenseNet201 extra weight on 'mel'.
        mel_idx = label_to_index["mel"]
        ensemble[mel_idx] = 0.5 * ensemble[mel_idx] + 0.5 * preds["densenet201"][mel_idx]

    if normalize:
        ensemble = _renorm_safe(ensemble)
    preds["ensemble"] = ensemble
    return preds
# ---- Helpers Grad-CAM ----
# Grad-CAM target layer (last convolutional feature layer) for each backbone key.
LAST_CONV_LAYERS = dict(
    xception="block14_sepconv2_act",
    resnet50="conv5_block3_out",
    densenet201="conv5_block32_concat",
)
def _guess_backbone_name(model):
name = (getattr(model, "name", "") or "").lower()
if "xception" in name: return "xception"
if "resnet" in name: return "resnet50"
if "densenet" in name: return "densenet201"
return None
def find_last_dense_layer(model):
    """Return the last ``keras.layers.Dense`` among ``model.layers``.

    Only the layers listed in ``model.layers`` are inspected; there is no
    recursion into nested sub-models.

    Raises:
        ValueError: if none of those layers is a Dense layer.
    """
    dense_layers = [layer for layer in model.layers if isinstance(layer, keras.layers.Dense)]
    if not dense_layers:
        raise ValueError("Aucune couche Dense trouvée dans le modèle.")
    return dense_layers[-1]
# ---- GRAD-CAM ----
def _gradcam_backbone(model, last_conv_layer_name):
    """Identify which backbone ``model`` is, to pick matching preprocessing.

    Tries the model name first (``_guess_backbone_name``). If that is
    inconclusive, it looks the target layer name up in LAST_CONV_LAYERS,
    whose layer names each belong to one backbone. Returns None if both fail.
    """
    backbone = _guess_backbone_name(model)
    if backbone is None:
        backbone = next(
            (name for name, layer in LAST_CONV_LAYERS.items() if layer == last_conv_layer_name),
            None,
        )
    return backbone


def _gradcam_heatmap(model, conv_layer, img_batch, class_index, progress=None):
    """Compute the Grad-CAM heatmap for one preprocessed image batch.

    Assumes a channels-last (batch, H, W, C) conv output — TODO confirm for
    any new backbone. Returns a 2-D tensor scaled to [0, 1] (constant 0.5
    when the map is all zeros), or None when no gradient reaches the layer.
    """
    grad_model = Model(model.inputs, [conv_layer.output, model.output])
    inputs = {get_primary_input_name(model): img_batch}
    with tf.GradientTape() as tape:
        conv_output, preds = grad_model(inputs, training=False)
        if isinstance(preds, list):
            preds = preds[0]
        # The class slice must be recorded on the tape to be differentiable.
        class_score = preds[:, int(class_index)]
    grads = tape.gradient(class_score, conv_output)
    if grads is None:
        return None

    _update_progress(progress, 40, desc="Pooling des gradients...")
    # One importance weight per channel: mean gradient over batch and space.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    _update_progress(progress, 55, desc="Construction de la heatmap...")
    # (H, W, C) @ (C, 1) -> (H, W, 1); dropping only the last axis keeps a
    # 1x1 feature map 2-D (tf.squeeze would turn it into a scalar).
    heatmap = (conv_output[0] @ pooled_grads[..., tf.newaxis])[..., 0]
    heatmap = tf.maximum(heatmap, 0)  # keep only positive evidence for the class
    max_val = tf.math.reduce_max(heatmap)
    if max_val == 0:
        return tf.ones_like(heatmap) * 0.5
    return heatmap / max_val


def _overlay_heatmap(img_rgb, heatmap):
    """Blend a [0, 1] float heatmap onto an RGB uint8 image (JET colormap).

    The heatmap is resized to the image; the result is 60% image + 40%
    colored heatmap, returned in RGB order.
    """
    heatmap = cv2.resize(heatmap, (img_rgb.shape[1], img_rgb.shape[0]))
    colored = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)  # BGR
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    blended = cv2.addWeighted(img_bgr, 0.6, colored, 0.4, 0)
    return cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)


def make_gradcam(image_pil, model, last_conv_layer_name, class_index, progress=None):
    """Overlay a Grad-CAM heatmap for ``class_index`` on an image.

    Args:
        image_pil: input PIL image (converted to RGB internally).
        model: Keras classifier to explain, or None.
        last_conv_layer_name: name of the layer whose activations are
            weighted by the pooled class gradients.
        class_index: index of the output class to explain.
        progress: optional gr.Progress-like object forwarded to _update_progress.

    Returns:
        np.ndarray: RGB image. On success, the image resized to the model's
        input size with the heatmap blended in. If the layer is missing or no
        gradient is available, the resized image without overlay. If
        ``model`` is None or any error occurs, the original image.
    """
    if model is None:
        return np.array(image_pil)
    try:
        _update_progress(progress, 0, desc="Préparation de l'image...")
        backbone = _gradcam_backbone(model, last_conv_layer_name)
        # Unknown backbones default to DenseNet preprocessing.
        preprocessor = {
            "xception": preprocess_xception,
            "resnet50": preprocess_resnet,
        }.get(backbone, preprocess_densenet)

        # Assumes a channels-last input_shape (batch, height, width, channels).
        height, width = model.input_shape[1:3]
        if not height or not width:
            # Dynamic spatial size: use the resolution predict_single feeds this backbone.
            height = width = 299 if backbone == "xception" else 224
        img_np = np.array(image_pil.convert("RGB"))
        # cv2.resize takes dsize as (width, height), not (height, width).
        img_resized = cv2.resize(img_np, (int(width), int(height)))
        img_batch = preprocessor(np.expand_dims(img_resized, axis=0))

        _update_progress(progress, 20, desc="Calcul des gradients...")
        try:
            conv_layer = model.get_layer(last_conv_layer_name)
        except ValueError:
            print(f"⚠️ Grad-CAM: couche {last_conv_layer_name!r} introuvable, image renvoyée sans heatmap")
            return img_resized

        heatmap = _gradcam_heatmap(model, conv_layer, img_batch, class_index, progress)
        if heatmap is None:
            print("⚠️ Grad-CAM: aucun gradient, image renvoyée sans heatmap")
            return img_resized

        _update_progress(progress, 70, desc="Conversion NumPy...")
        heatmap_np = np.clip(heatmap.numpy().astype(np.float32), 0, 1)

        _update_progress(progress, 80, desc="Application du colormap...")
        result = _overlay_heatmap(img_resized, heatmap_np)

        _update_progress(progress, 100, desc="✅ Grad-CAM terminé !")
        return result
    except Exception:
        # Best effort: show the untouched input rather than crashing the UI.
        import traceback
        traceback.print_exc()
        return np.array(image_pil)
# ---- Shared state ----
# Last analysed image and its predict_single() result: written by
# quick_predict_ui, read by generate_gradcam_ui.
# NOTE(review): as module globals these are shared by every Gradio session
# in the process, so concurrent users can overwrite each other's state —
# consider gr.State if multi-user use matters.
current_image = current_predictions = None
# ---- Fonctions pour l'UI Gradio ----
def _probability_figure(ensemble_probs):
    """Build the Plotly bar chart of per-class ensemble probabilities in %."""
    rows = [(name, float(prob * 100)) for name, prob in zip(CLASS_NAMES, ensemble_probs)]
    df = pd.DataFrame(rows, columns=['Classe', 'Probabilité'])
    df = df.sort_values(by='Probabilité', ascending=False)
    df['Pourcentage'] = df['Probabilité'].map(lambda x: f"{x:.1f}%")
    fig = px.bar(df,
                 x="Classe",
                 y="Probabilité",
                 color="Probabilité",
                 color_continuous_scale=px.colors.sequential.Viridis,
                 title="Probabilités par classe",
                 text="Pourcentage")
    # Labels of small bars (<= 10%) go outside the bar, presumably so they stay readable.
    fig.update_traces(textposition=["outside" if v <= 10 else "inside" for v in df['Probabilité']])
    fig.update_layout(xaxis_title="", yaxis_title="Probabilité (%)", height=400)
    return fig


def quick_predict_ui(image_pil):
    """Gradio handler for "Analyse Rapide": classify the image and plot probabilities.

    On success, stores the image together with its predictions in the
    module-level state used by generate_gradcam_ui. On failure, that state is
    cleared, so Grad-CAM can never pair a new image with stale predictions.

    Returns:
        (diagnosis text, Plotly figure or None, status message).
    """
    global current_image, current_predictions
    if image_pil is None:
        return "Veuillez uploader une image.", None, "❌ Erreur: Aucune image fournie."
    try:
        all_preds = predict_single(image_pil)
        ensemble_probs = all_preds["ensemble"]
        top_class_name = CLASS_NAMES[int(np.argmax(ensemble_probs))]
        global_diag = diagnosis_map[top_class_name]
        fig = _probability_figure(ensemble_probs)
    except Exception as e:
        current_image = current_predictions = None
        return f"Erreur: {e}", None, "❌ Erreur lors de l'analyse."
    # Publish image and predictions together, only once both are valid.
    current_image, current_predictions = image_pil, all_preds
    return f"{global_diag} ({top_class_name.upper()})", fig, "✅ Analyse terminée. Prêt pour Grad-CAM."
def generate_gradcam_ui(progress=gr.Progress()):
    """Gradio handler for "Carte de chaleur": Grad-CAM for the last quick analysis.

    The explainer is the loaded backbone that gave the highest probability
    to the ensemble's top class; its target layer comes from LAST_CONV_LAYERS.

    Returns:
        (Grad-CAM image or None, status message).
    """
    global current_image, current_predictions
    if current_image is None or current_predictions is None:
        return None, "❌ Aucun résultat précédent — lance d'abord l'analyse rapide."
    try:
        _update_progress(progress, 0, desc="Début de la génération Grad-CAM...")
        top_class_idx = int(np.argmax(current_predictions["ensemble"]))
        available = (
            ("xception", model_xcept),
            ("resnet50", model_resnet50),
            ("densenet201", model_densenet),
        )
        candidates = [
            (name, mdl, current_predictions[name][top_class_idx])
            for name, mdl in available
            if mdl is not None
        ]
        if not candidates:
            return None, "❌ Aucun modèle disponible pour Grad-CAM."

        best_name, best_model, best_conf = max(candidates, key=lambda cand: cand[2])
        target_layer = LAST_CONV_LAYERS.get(best_name)
        _update_progress(progress, 5, desc=f"Génération Grad-CAM avec {best_name}...")
        overlay = make_gradcam(current_image, best_model, target_layer,
                               class_index=top_class_idx, progress=progress)
        _update_progress(progress, 100, desc="✅ Grad-CAM généré !")
        return overlay, f"✅ Grad-CAM généré avec {best_name} (confiance: {best_conf:.1%})"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"❌ Erreur: {e}"
# ---- INTERFACE GRADIO ----
example_paths = ["ISIC_0024627.jpg", "ISIC_0025539.jpg", "ISIC_0031410.jpg"]

with gr.Blocks(theme=gr.themes.Soft(), title="Analyse de lésions") as demo:
    gr.Markdown("# 🔬 Analyse de lésions cutanées")

    # Startup banner listing the backbones that were actually loaded.
    models_status = [
        badge
        for badge, mdl in (
            ("✅ ResNet50", model_resnet50),
            ("✅ DenseNet201", model_densenet),
            ("✅ Xception", model_xcept),
        )
        if mdl
    ]
    gr.Markdown(f"**Modèles chargés:** {', '.join(models_status) if models_status else 'AUCUN'}")

    with gr.Row():
        # Left column: input image, actions and example images.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="📸 Uploader une image")
            with gr.Row():
                quick_btn = gr.Button("⚡ Analyse Rapide", variant="primary")
                gradcam_btn = gr.Button("🎯 Carte de chaleur", variant="secondary")
            gr.Examples(examples=example_paths, inputs=input_image)
        # Right column: diagnosis, probability chart, Grad-CAM and status.
        with gr.Column(scale=2):
            output_label = gr.Label(label="📊 Diagnostic global")
            output_plot = gr.Plot(label="📈 Probabilités")
            output_gradcam = gr.Image(label="🔍 Visualisation Grad-CAM")
            output_status = gr.Textbox(label="Statut", interactive=False)

    quick_btn.click(fn=quick_predict_ui, inputs=input_image,
                    outputs=[output_label, output_plot, output_status])
    gradcam_btn.click(fn=generate_gradcam_ui, inputs=[],
                      outputs=[output_gradcam, output_status])
if __name__ == "__main__":
    # Launch even without models: the UI banner then shows "AUCUN".
    if not any(m is not None for m in (model_resnet50, model_densenet, model_xcept)):
        print("\n\n🚨 ATTENTION: Aucun modèle n'a été chargé. L'application ne fonctionnera pas.\n\n")
    demo.launch()