|
import os |
|
import numpy as np |
|
import gradio as gr |
|
import cv2 |
|
import tensorflow as tf |
|
import keras |
|
from huggingface_hub import hf_hub_download |
|
import pandas as pd |
|
from PIL import Image |
|
|
|
|
|
# Label order of the models' output vector.
# NOTE(review): this order is not alphabetical ('mel' comes last), so it
# presumably mirrors the label encoding used at training time — confirm
# against the training code before reordering anything.
CLASS_NAMES = ['akiec', 'bcc', 'bkl', 'df', 'nv', 'vasc', 'mel']

# Reverse lookup: class code -> column index inside a prediction vector.
label_to_index = dict(zip(CLASS_NAMES, range(len(CLASS_NAMES))))

# Coarse benign/malignant group used to derive the global diagnosis.
# NOTE(review): 'akiec' is grouped as benign here; in the HAM10000 taxonomy it
# covers actinic keratoses and intraepithelial carcinoma, which is often
# treated as (pre-)malignant — confirm this grouping is intentional.
_MALIGNANT_CODES = ('bcc', 'mel')
diagnosis_map = {
    code: ('Malin' if code in _MALIGNANT_CODES else 'Bénin')
    for code in CLASS_NAMES
}
|
|
|
|
|
def _load_for_inference(model_file):
    """Load a saved ``.keras`` model for prediction only (``compile=False``)."""
    return keras.saving.load_model(model_file, compile=False)


# ResNet50 and DenseNet201 are fetched from the Hugging Face Hub;
# hf_hub_download returns the local path of the downloaded file.
resnet_path = hf_hub_download(repo_id="ericjedha/resnet50", filename="Resnet50.keras")
densenet_path = hf_hub_download(repo_id="ericjedha/densenet201", filename="Densenet201.keras")

model_resnet50 = _load_for_inference(resnet_path)
model_densenet = _load_for_inference(densenet_path)
# NOTE(review): Xception is loaded from a relative path, i.e. resolved against
# the current working directory — confirm the file ships with the app.
model_xcept = _load_for_inference("Xception.keras")
|
|
|
|
|
from tensorflow.keras.applications.xception import preprocess_input as preprocess_xception |
|
from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet |
|
from tensorflow.keras.applications.densenet import preprocess_input as preprocess_densenet |
|
|
|
|
|
def get_primary_input_name(model):
    """Return the name of the model's first (primary) input.

    Any ``:<n>`` suffix (as in TensorFlow tensor names such as ``"x:0"``) is
    stripped. When ``model.inputs`` is not a non-empty list, the fallback name
    ``"input_1"`` is returned.
    """
    declared = model.inputs
    if not isinstance(declared, list) or not declared:
        return "input_1"
    base_name, _sep, _rest = declared[0].name.partition(':')
    return base_name
|
|
|
def safe_forward(model, x):
    """Run ``model`` in inference mode on ``x`` and return a NumPy array.

    The batch is passed as a dict keyed by the model's primary input name,
    which is intended to avoid Keras UserWarnings about the input structure.
    """
    feed = {get_primary_input_name(model): x}
    outputs = model(feed, training=False)
    return outputs.numpy()
|
|
|
|
|
def predict_single(image_pil, weights=(0.45, 0.25, 0.30), mel_densenet_share=0.5):
    """Run the Xception/ResNet50/DenseNet201 ensemble on one image.

    Args:
        image_pil: Input image. PIL images are converted to RGB first, so RGBA
            (e.g. PNG with alpha), palette and grayscale uploads still produce
            the 3-channel arrays the backbones expect. Any other array-like is
            passed to ``np.array`` as-is (assumed HxWx3 RGB — not checked).
        weights: Blend weights for (Xception, ResNet50, DenseNet201), in that
            order.
        mel_densenet_share: Fraction of the final 'mel' score taken from
            DenseNet201 alone; the rest comes from the weighted ensemble.
            The default 0.5 reproduces the historical behaviour.

    Returns:
        1-D NumPy array with one score per class, indexed consistently with
        ``label_to_index``. Scores are not renormalised after the 'mel'
        re-blend, so they need not sum to exactly 1.
    """
    # Normalise the colour mode; otherwise RGBA/LA/P/L images give 4-channel
    # or 2-D arrays that the 3-channel backbones cannot consume.
    if isinstance(image_pil, Image.Image):
        image_pil = image_pil.convert("RGB")
    img_np = np.array(image_pil)

    # Resize to the input sizes used by each backbone below.
    img_299 = cv2.resize(img_np, (299, 299))
    img_224 = cv2.resize(img_np, (224, 224))

    # np.expand_dims returns a view, and Keras' NumPy preprocess_input works in
    # place on float arrays. Copy each batch so ResNet's preprocessing cannot
    # leak into DenseNet's input, which shares img_224.
    bx = preprocess_xception(np.expand_dims(img_299, axis=0).copy())
    br = preprocess_resnet(np.expand_dims(img_224, axis=0).copy())
    bd = preprocess_densenet(np.expand_dims(img_224, axis=0).copy())

    pred_x = safe_forward(model_xcept, bx)
    pred_r = safe_forward(model_resnet50, br)
    pred_d = safe_forward(model_densenet, bd)

    # Weighted average of the three models' outputs (batch of one).
    preds = weights[0] * pred_x + weights[1] * pred_r + weights[2] * pred_d

    # Re-blend the melanoma column towards DenseNet201's own score.
    mel_idx = label_to_index['mel']
    preds[:, mel_idx] = (
        (1.0 - mel_densenet_share) * preds[:, mel_idx]
        + mel_densenet_share * pred_d[:, mel_idx]
    )

    return preds[0]
|
|
|
|
|
def gradio_predict(image_pil):
    """Gradio callback: turn an uploaded image into the two UI outputs.

    Args:
        image_pil: Image from the upload widget, or None when nothing was
            uploaded.

    Returns:
        A ``(diagnosis, table)`` pair. ``diagnosis`` is "Bénin" or "Malin", or
        a French user message on the no-image / error paths. ``table`` is a
        DataFrame with columns "Classe" and "Probabilité", sorted by decreasing
        score, or None on the no-image / error paths.
    """
    # Guard clause: the button was pressed without an image.
    if image_pil is None:
        return "Veuillez uploader une image.", None

    try:
        scores = predict_single(image_pil)

        def group_total(group):
            # Sum the per-class scores whose coarse diagnosis equals `group`.
            return sum(
                scores[idx]
                for idx, code in enumerate(CLASS_NAMES)
                if diagnosis_map[code] == group
            )

        benign_prob = group_total("Bénin")
        malign_prob = group_total("Malin")

        # Ties go to "Bénin".
        if benign_prob >= malign_prob:
            global_diag = "Bénin"
        else:
            global_diag = "Malin"

        per_class = {code: float(scores[idx]) for idx, code in enumerate(CLASS_NAMES)}
        table = (
            pd.DataFrame.from_dict(per_class, orient='index', columns=['Probabilité'])
            .sort_values(by='Probabilité', ascending=False)
            .rename_axis("Classe")
            .reset_index()
        )
        return global_diag, table

    except Exception as e:
        # UI boundary: print the full traceback to the console and show a
        # generic message rather than letting the callback raise.
        print(f"Erreur dans gradio_predict : {e}")
        import traceback

        traceback.print_exc()
        return "Erreur lors du traitement de l'image.", None
|
|
|
|
|
# Example images offered below the upload widget.
# NOTE(review): relative paths, resolved against the working directory —
# confirm these files ship alongside the app.
example_paths = ["exemple1.jpg", "exemple2.jpg", "exemple3.jpg"]


# Build the UI. The upload widget, button and examples go in the left column;
# the global diagnosis and per-class bar plot go in the right column.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Analyse de lésions cutanées (Ensemble de modèles)")
    gr.Markdown("Cet outil propose une prédiction de la nature de la lésion (Bénin/Malin).")
    with gr.Row():
        # Left column: input side.
        with gr.Column(scale=1):
            # type="pil" makes Gradio hand a PIL image to the callback.
            input_image = gr.Image(type="pil", label="Uploader une image de lésion")
            submit_btn = gr.Button("Analyser", variant="primary")
            # No fn is given, so selecting an example only fills input_image;
            # the user still has to press the button.
            gr.Examples(examples=example_paths, inputs=input_image)
        # Right column: outputs produced by gradio_predict.
        with gr.Column(scale=1):
            output_label = gr.Label(label="Diagnostic global")
            # Plots the "Classe"/"Probabilité" columns of the DataFrame
            # returned by gradio_predict.
            output_plot = gr.BarPlot(label="Probabilités par classe", x="Classe", y="Probabilité", y_lim=[0, 1])

    # One image in, (diagnosis text, probability table) out.
    submit_btn.click(fn=gradio_predict, inputs=input_image, outputs=[output_label, output_plot])
|
|
|
def main():
    """Launch the Gradio app with default server settings."""
    demo.launch()


if __name__ == "__main__":
    main()