"""Gradio app: skin-lesion classification with a CNN ensemble plus Grad-CAM.

Up to three Keras backbones (Xception, ResNet50, DenseNet201) are loaded at
import time; their per-class outputs are blended into a weighted ensemble and,
on demand, the backbone most confident in the top class renders a Grad-CAM
heatmap.
"""
import os

# Disable the GPU and TensorFlow logs. These must be set BEFORE TensorFlow (or
# Keras, which imports it) is loaded: TF_CPP_MIN_LOG_LEVEL is read when the
# native library initialises, so setting it after the import has no effect on
# import-time logging.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import time  # noqa: E402
import traceback  # noqa: E402

import cv2  # noqa: E402
import gradio as gr  # noqa: E402
import keras  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import plotly.express as px  # noqa: E402
import tensorflow as tf  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402
from keras.models import Model  # noqa: E402
from keras.preprocessing import image  # noqa: E402,F401  (kept from the original imports)
from PIL import Image  # noqa: E402
from tensorflow.keras.applications.densenet import preprocess_input as preprocess_densenet  # noqa: E402
from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet  # noqa: E402
from tensorflow.keras.applications.xception import preprocess_input as preprocess_xception  # noqa: E402

tf.config.set_visible_devices([], 'GPU')

# ---- Configuration ----
# NOTE(review): this order must match the label order used at training time;
# 'mel' being last (not alphabetical) is unusual — confirm against training code.
CLASS_NAMES = ['akiec', 'bcc', 'bkl', 'df', 'nv', 'vasc', 'mel']
label_to_index = {name: i for i, name in enumerate(CLASS_NAMES)}
# Maps each lesion class to the benign/malignant verdict shown in the UI.
diagnosis_map = {
    'akiec': 'Bénin', 'bcc': 'Malin', 'bkl': 'Bénin', 'df': 'Bénin',
    'nv': 'Bénin', 'vasc': 'Bénin', 'mel': 'Malin',
}

# Order in which backbones are predicted, weighted and considered for Grad-CAM.
BACKBONE_ORDER = ('xception', 'resnet50', 'densenet201')

# PIL resize target (width, height) and preprocessing function per backbone.
_BACKBONE_INPUT = {
    'xception': ((299, 299), preprocess_xception),
    'resnet50': ((224, 224), preprocess_resnet),
    'densenet201': ((224, 224), preprocess_densenet),
}


# ---- Model loading ----
def _load_hub_model(repo_id, filename):
    """Download *filename* from the Hugging Face Hub repo *repo_id* and load it."""
    path = hf_hub_download(repo_id=repo_id, filename=filename)
    return keras.saving.load_model(path, compile=False)


def _load_local_model(path):
    """Load a local ``.keras`` file, or return None if the file does not exist."""
    if not os.path.exists(path):
        return None
    return keras.saving.load_model(path, compile=False)


def load_models_safely():
    """Load every backbone that can be loaded.

    Returns:
        dict mapping 'resnet50', 'densenet201' and 'xception' to a Keras model,
        or to None when that model could not be loaded.

    Raises:
        RuntimeError: if no model at all could be loaded.
    """
    specs = (
        ('resnet50', "📥 Téléchargement ResNet50...", "✅ ResNet50 chargé",
         lambda: _load_hub_model("ericjedha/resnet50", "Resnet50.keras")),
        ('densenet201', "📥 Téléchargement DenseNet201...", "✅ DenseNet201 chargé",
         lambda: _load_hub_model("ericjedha/densenet201", "Densenet201.keras")),
        ('xception', "📥 Chargement Xception local...", "✅ Xception chargé",
         lambda: _load_local_model("Xception.keras")),
    )
    models = {}
    for key, start_msg, ok_msg, loader in specs:
        print(start_msg)
        try:
            model = loader()
        except Exception as e:
            # Best-effort: one failing backbone must not prevent the app from
            # starting, but the failure is reported instead of being hidden.
            print(f"⚠️ Échec du chargement de {key}: {e}")
            model = None
        if model is not None:
            print(ok_msg)
        models[key] = model

    loaded = [key for key, model in models.items() if model is not None]
    if not loaded:
        raise RuntimeError("❌ Aucun modèle n'a pu être chargé!")
    print(f"🎯 Modèles chargés: {loaded}")
    return models


try:
    models_dict = load_models_safely()
except Exception as e:
    print(f"🚨 ERREUR CRITIQUE: {e}")
    models_dict = {}
model_resnet50 = models_dict.get('resnet50')
model_densenet = models_dict.get('densenet201')
model_xcept = models_dict.get('xception')


def _available_models():
    """Return ``(backbone_name, model)`` pairs in BACKBONE_ORDER (models may be None)."""
    return (
        ('xception', model_xcept),
        ('resnet50', model_resnet50),
        ('densenet201', model_densenet),
    )


# ---- Utils ----
def _renorm_safe(p: np.ndarray) -> np.ndarray:
    """Return *p* turned into a probability vector.

    Negative entries are clipped to 0 before dividing by the sum. When the sum
    is not strictly positive or not finite (e.g. NaN in the input), a uniform
    float32 distribution of the same length is returned instead.
    """
    p = np.clip(np.asarray(p), 0.0, None)  # avoid negative values
    total = np.sum(p)
    if not np.isfinite(total) or total <= 0:
        return np.ones_like(p, dtype=np.float32) / len(p)
    return p / total


def get_primary_input_name(model):
    """Return the name of the model's first input (without any ':N' suffix)."""
    if isinstance(model.inputs, list) and len(model.inputs) > 0:
        return model.inputs[0].name.split(':')[0]
    return "input_1"


def _update_progress(progress, value, desc=None, animate=False, sleep=0.00):
    """Update a Gradio progress bar on a best-effort basis.

    Args:
        progress: a ``gr.Progress`` object, or None (then nothing happens).
        value: target value on a 0-100 or 0-1 scale; anything above 1.0 is
            divided by 100. Unparseable values are treated as 0.
        desc: optional text displayed with the bar.
        animate: if True and the value increases, step from the last value to
            the new one in 8 increments, 0.02 s apart.
        sleep: seconds to wait after the update so the UI can refresh.

    Errors raised by the progress object are swallowed on purpose: progress
    reporting must never abort the computation it reports on.
    """
    if progress is None:
        return

    try:
        val = float(value)
        if val > 1.0:
            val = val / 100.0
    except (TypeError, ValueError):
        val = 0.0

    # Last value reported through this helper, stored on the progress object.
    last_val = getattr(progress, "_last_val", 0.0)

    def _emit(v):
        if desc:
            progress(v, desc=desc)
        else:
            progress(v)

    try:
        if animate and val > last_val:
            for step in np.linspace(last_val, val, 8):
                _emit(float(step))
                time.sleep(0.02)  # smoothing speed
        else:
            _emit(val)
    except Exception:
        pass

    try:
        progress._last_val = val
    except AttributeError:
        pass

    # Optional short delay to force the UI to refresh.
    if sleep > 0:
        time.sleep(sleep)


# ---- PREDICT SINGLE ----
def _prepare_batch(pil_img, size, preprocessor):
    """Resize *pil_img* to *size* (width, height; bilinear), preprocess, add a batch axis."""
    resized = np.array(pil_img.resize(size, resample=Image.BILINEAR))
    return np.expand_dims(preprocessor(resized), axis=0)


def predict_single(img_input, weights=(0.45, 0.25, 0.30), normalize=True):
    """Run every loaded backbone on one image and build the weighted ensemble.

    Args:
        img_input: path to an image file, or a PIL image.
        weights: ensemble weights for (xception, resnet50, densenet201);
            weights of missing models are simply skipped.
        normalize: if True, the ensemble is renormalised with ``_renorm_safe``.

    Returns:
        dict with one entry per model that ran (its first output row) plus
        'ensemble'. The ensemble's 'mel' score is averaged 50/50 with
        DenseNet201's 'mel' score when DenseNet201 is available.

    Raises:
        ValueError: if *img_input* is neither a str nor a PIL image.
        RuntimeError: if no model is loaded (instead of silently returning a
            meaningless uniform prediction).
    """
    if isinstance(img_input, str):
        with Image.open(img_input) as src:  # close the file handle after decoding
            pil_img = src.convert("RGB")
    elif isinstance(img_input, Image.Image):
        pil_img = img_input.convert("RGB")
    else:
        raise ValueError("img_input doit être un chemin (str) ou une image PIL")

    preds = {}
    for name, model in _available_models():
        if model is None:
            continue
        size, preprocessor = _BACKBONE_INPUT[name]
        preds[name] = model.predict(_prepare_batch(pil_img, size, preprocessor), verbose=0)[0]

    if not preds:
        raise RuntimeError("❌ Aucun modèle disponible pour la prédiction.")

    ensemble = np.zeros(len(CLASS_NAMES), dtype=np.float32)
    for idx, name in enumerate(BACKBONE_ORDER):
        if name in preds:
            ensemble += weights[idx] * preds[name]

    if 'densenet201' in preds:
        # Give DenseNet201 extra say on melanoma, the most critical class.
        mel_idx = label_to_index['mel']
        ensemble[mel_idx] = 0.5 * ensemble[mel_idx] + 0.5 * preds['densenet201'][mel_idx]

    if normalize:
        ensemble = _renorm_safe(ensemble)

    preds['ensemble'] = ensemble
    return preds


# ---- Grad-CAM helpers ----
LAST_CONV_LAYERS = {
    "xception": "block14_sepconv2_act",
    "resnet50": "conv5_block3_out",
    "densenet201": "conv5_block32_concat",
}


def _guess_backbone_name(model):
    """Guess the backbone key from ``model.name``; None if it is not recognisable."""
    name = (getattr(model, "name", "") or "").lower()
    if "xception" in name:
        return "xception"
    if "resnet" in name:
        return "resnet50"
    if "densenet" in name:
        return "densenet201"
    return None


def find_last_dense_layer(model):
    """Return the last ``keras.layers.Dense`` layer of *model*.

    Raises:
        ValueError: if the model has no Dense layer.
    """
    for layer in reversed(model.layers):
        if isinstance(layer, keras.layers.Dense):
            return layer
    raise ValueError("Aucune couche Dense trouvée dans le modèle.")


def _model_input_hw(model, backbone):
    """Return the (height, width) *model* expects, falling back to the backbone default."""
    shape = model.input_shape
    if isinstance(shape, list):  # multi-input model: use the first input
        shape = shape[0]
    height, width = shape[1:3]
    if height is None or width is None:
        width, height = _BACKBONE_INPUT.get(backbone, ((224, 224), None))[0]
    return int(height), int(width)


# ---- GRAD-CAM ----
def make_gradcam(image_pil, model, last_conv_layer_name, class_index, progress=None, backbone=None):
    """Overlay a Grad-CAM heatmap for *class_index* on *image_pil*.

    Args:
        image_pil: input PIL image.
        model: Keras model to explain; if None the image is returned unchanged
            (as a NumPy array).
        last_conv_layer_name: layer whose activations are weighted by the
            spatially averaged gradients of the class score.
        class_index: index of the output class to explain.
        progress: optional ``gr.Progress`` updated along the way.
        backbone: optional backbone key ('xception', 'resnet50', 'densenet201')
            selecting the preprocessing, so it matches ``predict_single``. When
            omitted it is guessed from ``model.name``; unknown names fall back
            to DenseNet preprocessing.

    Returns:
        RGB uint8 array: the 60/40 image/heatmap blend at the model's input
        size; the resized image alone if the layer is missing or no gradient
        flows; the original image if anything else fails.
    """
    if model is None:
        return np.array(image_pil)
    try:
        _update_progress(progress, 0, desc="Préparation de l'image...")
        backbone = backbone or _guess_backbone_name(model)
        preprocessor = _BACKBONE_INPUT.get(backbone, (None, preprocess_densenet))[1]

        height, width = _model_input_hw(model, backbone)
        img_np = np.array(image_pil.convert("RGB"))
        # cv2.resize takes dsize as (width, height) — the reverse of Keras' (H, W).
        img_resized = cv2.resize(img_np, (width, height))
        img_array_preprocessed = preprocessor(np.expand_dims(img_resized, axis=0))

        _update_progress(progress, 20, desc="Calcul des gradients...")
        try:
            conv_layer = model.get_layer(last_conv_layer_name)
        except ValueError:
            return img_resized

        # Model returning both the conv activations and the predictions.
        grad_model = Model(model.inputs, [conv_layer.output, model.output])
        input_for_model = {get_primary_input_name(model): img_array_preprocessed}

        with tf.GradientTape() as tape:
            last_conv_layer_output, preds = grad_model(input_for_model, training=False)
            if isinstance(preds, list):
                preds = preds[0]
            class_channel = preds[:, int(class_index)]

        grads = tape.gradient(class_channel, last_conv_layer_output)
        if grads is None:
            return img_resized

        _update_progress(progress, 40, desc="Pooling des gradients...")
        pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
        last_conv_layer_output = last_conv_layer_output[0]

        _update_progress(progress, 55, desc="Construction de la heatmap...")
        heatmap = tf.squeeze(last_conv_layer_output @ pooled_grads[..., tf.newaxis])
        heatmap = tf.maximum(heatmap, 0)  # keep only positive influence
        max_val = tf.math.reduce_max(heatmap)
        if max_val == 0:
            heatmap = tf.ones_like(heatmap) * 0.5  # flat map instead of dividing by 0
        else:
            heatmap = heatmap / max_val

        _update_progress(progress, 70, desc="Conversion NumPy...")
        heatmap_np = np.clip(heatmap.numpy().astype(np.float32), 0, 1)

        _update_progress(progress, 80, desc="Application du colormap...")
        heatmap_resized = cv2.resize(heatmap_np, (img_resized.shape[1], img_resized.shape[0]))
        heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET)
        img_bgr = cv2.cvtColor(img_resized, cv2.COLOR_RGB2BGR)
        superimposed_img = cv2.addWeighted(img_bgr, 0.6, heatmap_colored, 0.4, 0)

        _update_progress(progress, 100, desc="✅ Grad-CAM terminé !")
        return cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)
    except Exception:
        traceback.print_exc()
        return np.array(image_pil)


# ---- Shared state between the two UI actions ----
# NOTE(review): module-level globals are shared by every Gradio session, so
# concurrent users can see each other's last image; consider gr.State.
current_image = None
current_predictions = None


# ---- Gradio UI callbacks ----
def quick_predict_ui(image_pil):
    """Classify the uploaded image; return (label text, plotly figure, status text).

    On success the image and its predictions are stored in the module globals
    for ``generate_gradcam_ui``.
    """
    global current_image, current_predictions
    if image_pil is None:
        return "Veuillez uploader une image.", None, "❌ Erreur: Aucune image fournie."
    try:
        all_preds = predict_single(image_pil)
        # Update both together so the Grad-CAM step never pairs a new image
        # with stale predictions when prediction fails.
        current_image = image_pil
        current_predictions = all_preds

        ensemble_probs = all_preds["ensemble"]
        top_class_name = CLASS_NAMES[int(np.argmax(ensemble_probs))]
        global_diag = diagnosis_map[top_class_name]

        df = pd.DataFrame({
            'Classe': CLASS_NAMES,
            'Probabilité': [float(prob * 100) for _, prob in zip(CLASS_NAMES, ensemble_probs)],
        })
        df = df.sort_values(by='Probabilité', ascending=False)
        df['Pourcentage'] = df['Probabilité'].apply(lambda x: f"{x:.1f}%")

        fig = px.bar(
            df, x="Classe", y="Probabilité", color="Probabilité",
            color_continuous_scale=px.colors.sequential.Viridis,
            title="Probabilités par classe", text="Pourcentage",
        )
        # Small bars cannot fit their label inside, so put it above them.
        fig.update_traces(
            textposition=["outside" if val <= 10 else "inside" for val in df['Probabilité']]
        )
        fig.update_layout(xaxis_title="", yaxis_title="Probabilité (%)", height=400)

        return f"{global_diag} ({top_class_name.upper()})", fig, "✅ Analyse terminée. Prêt pour Grad-CAM."
    except Exception as e:
        return f"Erreur: {e}", None, "❌ Erreur lors de l'analyse."


def generate_gradcam_ui(progress=gr.Progress()):
    """Render Grad-CAM for the last analysed image; return (image array, status text).

    The explaining backbone is the loaded model with the highest probability
    for the ensemble's top class.
    """
    if current_image is None or current_predictions is None:
        return None, "❌ Aucun résultat précédent — lance d'abord l'analyse rapide."
    try:
        _update_progress(progress, 0, desc="Début de la génération Grad-CAM...")
        top_class_idx = int(np.argmax(current_predictions["ensemble"]))

        candidates = [
            (name, model, current_predictions[name][top_class_idx])
            for name, model in _available_models()
            if model is not None and name in current_predictions
        ]
        if not candidates:
            return None, "❌ Aucun modèle disponible pour Grad-CAM."

        explainer_model_name, explainer_model, conf = max(candidates, key=lambda t: t[2])
        explainer_layer = LAST_CONV_LAYERS.get(explainer_model_name)

        _update_progress(progress, 5, desc=f"Génération Grad-CAM avec {explainer_model_name}...")
        gradcam_img = make_gradcam(
            current_image, explainer_model, explainer_layer,
            class_index=top_class_idx, progress=progress,
            backbone=explainer_model_name,  # same preprocessing as predict_single
        )
        _update_progress(progress, 100, desc="✅ Grad-CAM généré !")
        return gradcam_img, f"✅ Grad-CAM généré avec {explainer_model_name} (confiance: {conf:.1%})"
    except Exception as e:
        traceback.print_exc()
        return None, f"❌ Erreur: {e}"


# ---- GRADIO INTERFACE ----
example_paths = ["ISIC_0024627.jpg", "ISIC_0025539.jpg", "ISIC_0031410.jpg"]

with gr.Blocks(theme=gr.themes.Soft(), title="Analyse de lésions") as demo:
    gr.Markdown("# 🔬 Analyse de lésions cutanées")

    models_status = [
        label
        for label, model in (
            ("✅ ResNet50", model_resnet50),
            ("✅ DenseNet201", model_densenet),
            ("✅ Xception", model_xcept),
        )
        if model is not None
    ]
    gr.Markdown(f"**Modèles chargés:** {', '.join(models_status) if models_status else 'AUCUN'}")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="📸 Uploader une image")
            with gr.Row():
                quick_btn = gr.Button("⚡ Analyse Rapide", variant="primary")
                gradcam_btn = gr.Button("🎯 Carte de chaleur", variant="secondary")
            # Only offer example images that are actually present on disk.
            available_examples = [path for path in example_paths if os.path.exists(path)]
            if available_examples:
                gr.Examples(examples=available_examples, inputs=input_image)
        with gr.Column(scale=2):
            output_label = gr.Label(label="📊 Diagnostic global")
            output_plot = gr.Plot(label="📈 Probabilités")
            output_gradcam = gr.Image(label="🔍 Visualisation Grad-CAM")
            output_status = gr.Textbox(label="Statut", interactive=False)

    quick_btn.click(fn=quick_predict_ui, inputs=input_image,
                    outputs=[output_label, output_plot, output_status])
    gradcam_btn.click(fn=generate_gradcam_ui, inputs=[],
                      outputs=[output_gradcam, output_status])

if __name__ == "__main__":
    if all(m is None for m in [model_resnet50, model_densenet, model_xcept]):
        print("\n\n🚨 ATTENTION: Aucun modèle n'a été chargé. L'application ne fonctionnera pas.\n\n")
    demo.launch()