# skin_care / app.py
# Last commit: 8976fd8 "Update app.py" (verified) by ericjedha — 2.62 kB
import gradio as gr
import numpy as np
import tensorflow as tf
import cv2
import matplotlib.pyplot as plt
from tensorflow.keras.applications.xception import preprocess_input, decode_predictions
# Layer whose activations drive Grad-CAM: the final activation of the last
# separable-conv block (layer name as used by keras.applications.Xception).
last_conv_layer_name = "block14_sepconv2_act"

# Trained classifier, loaded once when the module is imported and then reused
# by every call to gradcam_interface.
model = tf.keras.models.load_model("Xception-Baseline.keras")
# Fonction Grad-CAM
def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
grad_model = tf.keras.models.Model(
[model.inputs],
[model.get_layer(last_conv_layer_name).output, model.output]
)
with tf.GradientTape() as tape:
conv_outputs, predictions = grad_model(img_array)
if pred_index is None:
pred_index = tf.argmax(predictions[0])
class_channel = predictions[:, pred_index]
grads = tape.gradient(class_channel, conv_outputs)
pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
conv_outputs = conv_outputs[0]
heatmap = conv_outputs @ pooled_grads[..., tf.newaxis]
heatmap = tf.squeeze(heatmap)
heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
return heatmap.numpy(), predictions.numpy()
# Heatmap overlay: colorize the heatmap and blend it over the image.
def overlay_heatmap(original_img, heatmap, alpha=0.4):
    """Blend a JET-colored heatmap on top of an RGB image.

    Args:
        original_img: uint8 image with 3 channels in RGB order (as produced
            by ``gr.Image(type="numpy")`` and passed by gradcam_interface).
        heatmap: 2-D float array, expected in [0, 1]. NaNs are treated as 0,
            and values are clipped to [0, 1] before coloring.
        alpha: Weight of the heatmap in the blend. The image gets
            ``1 - alpha``.

    Returns:
        uint8 RGB image with the same height and width as ``original_img``.
    """
    height, width = original_img.shape[:2]
    # cv2.resize takes (width, height).
    heatmap = cv2.resize(heatmap, (width, height))
    # Sanitize before the uint8 cast. NaN or out-of-range floats would
    # otherwise produce undefined or wrapped-around byte values.
    heatmap = np.clip(np.nan_to_num(heatmap), 0.0, 1.0)
    heatmap = np.uint8(255 * heatmap)
    heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # applyColorMap returns BGR. Convert it so the channel order matches the
    # RGB image; otherwise red/blue (hot/cold) appear swapped.
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
    return cv2.addWeighted(heatmap_color, alpha, original_img, 1 - alpha, 0)
# Gradio entry point: image in, (Grad-CAM overlay, prediction text) out.
def gradcam_interface(img):
    """Classify ``img`` and return a Grad-CAM overlay plus a text summary.

    Args:
        img: Image from ``gr.Image(type="numpy")``, which Gradio delivers in
            RGB channel order. Grayscale (2-D or single-channel) and RGBA
            inputs are converted to 3-channel RGB here.

    Returns:
        (overlay, text): a 299x299 overlay image, and a string giving the
        predicted class index and its score.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Veuillez fournir une image.")
    # Normalize to 3-channel RGB. No BGR->RGB swap is applied: Gradio numpy
    # images are already RGB, and swapping would feed the model BGR pixels.
    if img.ndim == 2 or img.shape[-1] == 1:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[-1] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
    # 299x299 is Xception's default input size. Presumably the model was
    # trained at this size; confirm against model.input_shape.
    img_resized = cv2.resize(img, (299, 299))
    # astype() copies, so preprocess_input cannot mutate img_resized.
    input_array = preprocess_input(np.expand_dims(img_resized.astype(np.float32), axis=0))
    heatmap, preds = make_gradcam_heatmap(input_array, model, last_conv_layer_name)
    # Report the raw class index. decode_predictions only applies to
    # ImageNet heads; map indices to label names here if a label list exists.
    class_idx = int(np.argmax(preds[0]))
    confidence = preds[0][class_idx]
    decoded = f"Classe prédite : {class_idx} | Confiance : {confidence:.2f}"
    heatmap_overlay = overlay_heatmap(img_resized, heatmap)
    return heatmap_overlay, decoded
# Gradio UI: one image input, an overlay image and a text label as outputs.
demo = gr.Interface(
    fn=gradcam_interface,
    inputs=gr.Image(type="numpy", label="Image"),
    outputs=[
        gr.Image(type="numpy", label="Grad-CAM"),
        gr.Text(label="Prédiction"),
    ],
    title="Grad-CAM Visualizer - Xception",
)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()