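# The following header lines appear to be copied from the Hugging Face file
# viewer; they are kept as comments so the module parses.
# SHAP_DEMO / app.py
# xxnithicxx's picture
# Add HuggingFace Space files: app.py, README.md with metadata, requirements.txt, and .gitignore
# 89a46dd
# raw
# history blame
# 15.7 kB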
import gradio as gr
import shap
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import json
import io
from PIL import Image
import warnings
# Silence every Python warning for the whole process. Note this also hides
# library deprecation warnings (PyTorch, SHAP, Keras, ...).
warnings.filterwarnings("ignore")
# Configure TensorFlow to avoid GPU issues. These environment variables are set
# before the `import tensorflow` below so TensorFlow sees them when it loads.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Suppress TensorFlow's native (C++) log messages
# NOTE(review): CUDA_VISIBLE_DEVICES is a process-wide CUDA setting, not a
# TensorFlow-only one. PyTorch is imported above but only queries CUDA later
# (torch.cuda.is_available() when DEVICE is computed), so this likely also
# forces the PyTorch MNIST model onto CPU. Confirm that is intended;
# tf.config.set_visible_devices below already hides GPUs from TensorFlow.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # Hide CUDA GPUs (meant to keep TensorFlow on CPU)
import tensorflow as tf
# Disable GPU for TensorFlow: ResNet50 (loaded lazily later) runs on CPU.
tf.config.set_visible_devices([], 'GPU')
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
# Seed PyTorch and NumPy for reproducibility (TensorFlow and Python's
# `random` module are not seeded here).
torch.manual_seed(42)
np.random.seed(42)
# ============================================================================
# MNIST Model Definition (for Pixel-level SHAP)
# ============================================================================
class MNISTNet(nn.Module):
    """Small CNN classifier for 28x28 single-channel MNIST digits.

    Architecture: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU
    -> 2x2 max-pool -> channel dropout -> flatten -> fc(9216->128) -> ReLU
    -> dropout -> fc(128->10) -> softmax.

    ``forward`` returns class probabilities (softmax output), not logits, so
    SHAP attributions computed on this model are in probability space.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout on the 4-D (N, 64, 12, 12) pooled feature maps.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout on the 2-D (N, 128) hidden vector. Dropout2d is
        # meant for 3-D/4-D inputs; PyTorch deprecates passing it a 2-D tensor
        # (slated to become an error) and recommends plain Dropout, which keeps
        # the same behaviour. Dropout has no parameters, so state_dicts are unaffected.
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12: two unpadded 3x3 convs shrink 28 -> 26 -> 24,
        # then the 2x2 max-pool halves that to 12.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map an (N, 1, 28, 28) image batch to (N, 10) class probabilities."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.softmax(x, dim=1)
# ============================================================================
# Global Variables and Model Loading
# ============================================================================
# Device for the PyTorch MNIST network: CUDA when PyTorch reports a usable GPU, else CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Per-image preprocessing for the MNIST datasets constructed below (this only
# defines the transform; nothing is loaded here): PIL image -> float tensor,
# then normalisation with mean 0.1307 / std 0.3081 (the statistics commonly
# quoted for MNIST).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)
# Lazily-populated caches: each stays None until the matching initialize_*()
# helper below builds it on first use.
mnist_model = mnist_background = None
resnet_model = resnet_explainer = None
tabular_model = tabular_explainer = tabular_data = None
# ============================================================================
# Helper Functions
# ============================================================================
def _train_mnist_model(model, num_batches, batch_size=64, lr=1e-3):
    """Briefly fit ``model`` in place on at most ``num_batches`` MNIST training mini-batches."""
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    for step, (data, target) in enumerate(train_loader):
        if step >= num_batches:
            break
        data, target = data.to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        probs = model(data)
        # MNISTNet.forward already applies softmax, so train with NLL on
        # log-probabilities; the clamp keeps log() finite if a probability underflows.
        loss = F.nll_loss(torch.log(probs.clamp_min(1e-12)), target)
        loss.backward()
        optimizer.step()
    model.eval()


def initialize_mnist_model(train_batches=200):
    """Lazily build (and briefly train) the MNIST model and the SHAP background batch.

    On the first call this downloads MNIST into ``./data`` if needed and takes
    the first 100 test images as the DeepExplainer background. It then trains a
    fresh ``MNISTNet`` for ``train_batches`` mini-batches, so the explanations
    describe a real classifier rather than randomly initialised weights. Later
    calls return the cached objects.

    Args:
        train_batches: Number of 64-image training mini-batches to run on first
            initialisation. ``0`` skips training and keeps the random weights
            (the previous behaviour).

    Returns:
        ``(model, background)``: the model in eval mode on ``DEVICE``, and a CPU
        tensor holding the first 100 normalised MNIST test images.
    """
    global mnist_model, mnist_background
    if mnist_model is None or mnist_background is None:
        test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=200, shuffle=False)
        images, _ = next(iter(test_loader))
        background = images[:100]

        model = MNISTNet().to(DEVICE)
        if train_batches > 0:
            _train_mnist_model(model, train_batches)
        model.eval()

        # Publish to the module cache only after every step succeeded, so a
        # failed first call is retried instead of leaving a half-built model.
        mnist_model, mnist_background = model, background
    return mnist_model, mnist_background
def _class_names_from_index(class_idx):
    """Extract the 1000 readable labels from an ``imagenet_class_index.json`` mapping."""
    return [class_idx[str(i)][1] for i in range(1000)]


def _load_imagenet_class_names(
    json_path="imagenet_class_index.json",
    url="https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json",
    timeout=30,
):
    """Return 1000 ImageNet class names: from the local file, else a download, else placeholders."""
    if os.path.exists(json_path):
        try:
            with open(json_path, encoding="utf-8") as fh:
                class_names = _class_names_from_index(json.load(fh))
            print(f"✓ Loaded {len(class_names)} ImageNet class names")
            return class_names
        except Exception as e:
            print(f"⚠ Error loading class names: {e}")

    print("Downloading ImageNet class names...")
    try:
        import urllib.request
        # urlopen with a timeout so a stalled network cannot hang model start-up
        # (urlretrieve offers no timeout parameter).
        with urllib.request.urlopen(url, timeout=timeout) as response:
            payload = response.read()
        class_names = _class_names_from_index(json.loads(payload))
    except Exception as e:
        print(f"⚠ Could not download class names: {e}")
        print("Using placeholder names...")
        return [f"class_{i}" for i in range(1000)]

    # Cache for the next start-up; a failed write (e.g. read-only disk) must not
    # throw away names that were downloaded successfully.
    try:
        with open(json_path, "wb") as fh:
            fh.write(payload)
    except OSError as e:
        print(f"⚠ Could not cache class names to {json_path}: {e}")
    print(f"✓ Downloaded and loaded {len(class_names)} ImageNet class names")
    return class_names


def initialize_resnet_model():
    """Lazily load ResNet50 (ImageNet weights) and build its SHAP image explainer.

    The explainer masks 224x224x3 images with the ``inpaint_telea`` image masker
    and labels outputs with human-readable ImageNet class names when available.

    Returns:
        ``(model, explainer)``; both are cached in module globals after the
        first successful call.
    """
    global resnet_model, resnet_explainer
    if resnet_model is None or resnet_explainer is None:
        model = ResNet50(weights="imagenet")
        class_names = _load_imagenet_class_names()

        def f(x):
            # Keras' preprocess_input returns its result, and for non-float input it
            # converts to a new float array instead of editing ``x`` in place.
            # explain_imagenet_image passes uint8 images (so the masked batches
            # arriving here are presumably uint8 too); calling it only for a side
            # effect would leave them unpreprocessed. Cast and use the return value.
            return model(preprocess_input(np.array(x, dtype=np.float32)))

        masker = shap.maskers.Image("inpaint_telea", (224, 224, 3))
        explainer = shap.Explainer(f, masker, output_names=class_names)

        # Publish both together so a partial failure is retried on the next call.
        resnet_model, resnet_explainer = model, explainer
    return resnet_model, resnet_explainer
def initialize_tabular_model():
    """Lazily train the Adult-income Random Forest and build its TreeExplainer.

    On the first call this loads ``shap.datasets.adult()``, holds out 20% of the
    rows (``random_state=42``), fits a 100-tree ``RandomForestClassifier`` on the
    rest, and wraps it in ``shap.TreeExplainer``. Later calls return the cache.

    Returns:
        ``(model, explainer, (X_test, y_test))`` for the held-out split.
    """
    global tabular_model, tabular_explainer, tabular_data
    if tabular_model is None or tabular_explainer is None or tabular_data is None:
        X, y = shap.datasets.adult()
        # Defensive: normalise to pandas in case the SHAP version returns arrays.
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)
        if not isinstance(y, pd.Series):
            y = pd.Series(y)

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )

        model = RandomForestClassifier(n_estimators=100, random_state=42)
        model.fit(X_train, y_train)
        explainer = shap.TreeExplainer(model)

        # Publish only after fitting and explainer creation both succeeded, so a
        # failure cannot leave an unfitted model cached with no explainer.
        tabular_model, tabular_explainer, tabular_data = model, explainer, (X_test, y_test)
    return tabular_model, tabular_explainer, tabular_data
# ============================================================================
# SHAP Explanation Functions
# ============================================================================
def explain_mnist_digit(digit_index):
    """Explain one held-out MNIST digit with SHAP's DeepExplainer.

    Args:
        digit_index: Position (0-9) among MNIST test images 100-109; images
            0-99 serve as the explainer background. The value is converted with
            ``int()`` and clamped into range.

    Returns:
        ``(image, message)``: a PIL image with one SHAP map per model output plus
        a prediction summary, or ``(None, "Error: ...")`` on failure.
    """
    try:
        model, background = initialize_mnist_model()

        test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=200, shuffle=False)
        images, targets = next(iter(test_loader))
        # Images 0-99 are the background batch; explain images outside it.
        test_images = images[100:110]
        test_targets = targets[100:110].numpy()

        # Slider values may arrive as floats; clamp so out-of-range API calls cannot fail.
        idx = max(0, min(int(digit_index), len(test_images) - 1))
        test_image = test_images[[idx]].to(DEVICE)  # keeps the batch dim: (1, 1, 28, 28)
        actual = test_targets[idx]

        with torch.no_grad():
            pred = int(model(test_image).argmax(dim=1).item())

        explainer = shap.DeepExplainer(model, background.to(DEVICE))
        shap_values = explainer.shap_values(test_image)

        # Depending on the installed SHAP version, multi-output shap_values is either
        # a list of per-output (N, C, H, W) arrays or one (N, C, H, W, outputs) array.
        # Normalise to the list form that shap.image_plot expects.
        if isinstance(shap_values, list):
            per_output = [np.asarray(s) for s in shap_values]
        else:
            stacked = np.asarray(shap_values)
            if stacked.ndim == test_image.dim() + 1:
                per_output = [stacked[..., k] for k in range(stacked.shape[-1])]
            else:
                per_output = [stacked]

        # NCHW -> NHWC, the channels-last layout shap.image_plot draws.
        shap_numpy = [np.transpose(s, (0, 2, 3, 1)) for s in per_output]
        test_numpy = np.transpose(test_image.cpu().numpy(), (0, 2, 3, 1))

        # shap.image_plot creates its own figure, so no plt.figure() is made
        # beforehand (an extra one here would never be closed and would leak).
        # Pixel values are negated for display.
        shap.image_plot(shap_numpy, -test_numpy, show=False)
        fig = plt.gcf()
        fig.suptitle(f'Actual: {actual}, Predicted: {pred}', fontsize=14, y=1.02)

        buf = io.BytesIO()
        try:
            fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
        finally:
            plt.close(fig)
        buf.seek(0)
        img = Image.open(buf)
        img.load()  # decode now rather than lazily from the buffer later

        return img, f"Prediction: {pred} (Actual: {actual})"
    except Exception as e:
        return None, f"Error: {str(e)}"
def explain_imagenet_image(image):
    """Explain ResNet50's four highest-scoring ImageNet classes for an uploaded image.

    Args:
        image: Image array from the Gradio upload widget (passed to
            ``PIL.Image.fromarray``), or ``None`` when nothing was uploaded.

    Returns:
        ``(image, message)``: a PIL image of the SHAP maps and a status string,
        or ``(None, message)`` when there is no input or an error occurs.
    """
    # Validate first so an empty submit does not trigger the slow ResNet50 load.
    if image is None:
        return None, "Please upload an image"
    try:
        _, explainer = initialize_resnet_model()

        # convert("RGB") maps grayscale, RGBA, LA, palette, ... uploads onto the
        # 3-channel layout the masker was built for (224, 224, 3).
        img = Image.fromarray(image).convert("RGB").resize((224, 224))
        img_array = np.asarray(img, dtype=np.uint8)[np.newaxis, ...]  # (1, 224, 224, 3)

        # Explain the top-4 outputs; max_evals caps the number of model evaluations.
        shap_values = explainer(img_array, max_evals=100, batch_size=50,
                                outputs=shap.Explanation.argsort.flip[:4])

        # shap.image_plot creates its own figure; close exactly that one afterwards
        # (pre-creating another figure here would leak it).
        shap.image_plot(shap_values, show=False)
        fig = plt.gcf()
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
        finally:
            plt.close(fig)
        buf.seek(0)
        result_img = Image.open(buf)
        result_img.load()  # decode now rather than lazily from the buffer later

        return result_img, "SHAP explanation generated successfully"
    except Exception as e:
        return None, f"Error: {str(e)}"
def explain_tabular_sample(sample_index):
    """Explain the Random Forest's prediction for one held-out Adult-income row.

    Args:
        sample_index: Row position in the held-out test split. The value is
            converted with ``int()`` and clamped into range.

    Returns:
        ``(image, message)``: a PIL waterfall plot of the SHAP values for class
        index 1 plus a prediction summary, or ``(None, error text with
        traceback)`` on failure.
    """
    try:
        model, explainer, (X_test, y_test) = initialize_tabular_model()

        idx = max(0, min(int(sample_index), len(X_test) - 1))
        if hasattr(X_test, 'iloc'):
            # DataFrame/Series: positional indexing; the list keeps X 2-D.
            X_sample = X_test.iloc[[idx]]
            actual = y_test.iloc[idx]
        else:
            # NumPy arrays
            X_sample = X_test[idx:idx + 1]
            actual = y_test[idx]

        # Tree SHAP values for a row do not depend on the other rows being
        # explained, so explaining only the selected row gives the same values
        # as explaining a larger slice and indexing into it, at a fraction of the cost.
        shap_values = explainer(X_sample)

        canvas = plt.figure(figsize=(10, 8))
        try:
            # Explanation is indexed (row, feature, output); output 1 is the
            # second entry of model.classes_.
            shap.plots.waterfall(shap_values[0, :, 1], show=False)
            fig = plt.gcf()
            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
            plt.close(fig)
        finally:
            # Normally fig is canvas; closing an already-closed figure is a no-op.
            plt.close(canvas)
        buf.seek(0)
        img = Image.open(buf)
        img.load()  # decode now rather than lazily from the buffer later

        pred = model.predict(X_sample)[0]
        return img, f"Prediction: {pred} (Actual: {actual})"
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        return None, f"Error: {str(e)}\n\nDetails:\n{error_details}"
# ============================================================================
# Gradio Interface
# ============================================================================
def create_demo():
    """Build the three-tab Gradio Blocks UI and return it without launching.

    Components render in the order they are created inside each ``with``
    context, so the statement order below defines the page layout.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for ``launch()``.
    """
    with gr.Blocks(title="SHAP Explanations Demo") as demo:
        gr.Markdown("# SHAP (SHapley Additive exPlanations) Demo")
        gr.Markdown("This demo showcases three different SHAP explanation methods for machine learning models.")
        with gr.Tabs():
            # Tab 1: MNIST Pixel-level Explanations
            with gr.Tab("1. Pixel-level (MNIST Digits)"):
                gr.Markdown("""
### Pixel-level SHAP Explanations
This method uses **DeepExplainer** to show which pixels contribute to the model's prediction.
- **Red pixels**: Increase the probability of the predicted class
- **Blue pixels**: Decrease the probability of the predicted class
""")
                with gr.Row():
                    with gr.Column():
                        # Range 0-9 matches the ten MNIST test images (indices
                        # 100-109) that explain_mnist_digit chooses from.
                        mnist_slider = gr.Slider(minimum=0, maximum=9, step=1, value=0,
                                                 label="Select Test Image Index")
                        mnist_button = gr.Button("Generate Explanation", variant="primary")
                    with gr.Column():
                        mnist_output = gr.Image(label="SHAP Explanation")
                        mnist_text = gr.Textbox(label="Prediction Result")
                # The handler returns (image, text), mapped positionally onto outputs.
                mnist_button.click(
                    fn=explain_mnist_digit,
                    inputs=[mnist_slider],
                    outputs=[mnist_output, mnist_text]
                )
            # Tab 2: ImageNet Image Explanations
            with gr.Tab("2. Image Segmentation (ImageNet)"):
                gr.Markdown("""
### Image Segmentation SHAP Explanations
This method uses **Partition Explainer** with image masking to explain ResNet50 predictions.
Upload an image to see which regions contribute to the top predicted classes.
""")
                with gr.Row():
                    with gr.Column():
                        # NOTE(review): explain_imagenet_image calls Image.fromarray on
                        # this value, so it presumably relies on gr.Image's default
                        # numpy output type; confirm if the component config changes.
                        image_input = gr.Image(label="Upload Image")
                        image_button = gr.Button("Generate Explanation", variant="primary")
                    with gr.Column():
                        image_output = gr.Image(label="SHAP Explanation")
                        image_text = gr.Textbox(label="Status")
                image_button.click(
                    fn=explain_imagenet_image,
                    inputs=[image_input],
                    outputs=[image_output, image_text]
                )
            # Tab 3: Tabular Data Explanations
            with gr.Tab("3. Tabular Data (Adult Income)"):
                gr.Markdown("""
### Tabular Data SHAP Explanations
This method uses **TreeExplainer** to explain Random Forest predictions on the Adult Income dataset.
The waterfall plot shows how each feature contributes to the prediction.
""")
                with gr.Row():
                    with gr.Column():
                        # Range 0-99 selects among the first 100 rows of the held-out test split.
                        tabular_slider = gr.Slider(minimum=0, maximum=99, step=1, value=0,
                                                   label="Select Sample Index")
                        tabular_button = gr.Button("Generate Explanation", variant="primary")
                    with gr.Column():
                        tabular_output = gr.Image(label="SHAP Waterfall Plot")
                        tabular_text = gr.Textbox(label="Prediction Result")
                tabular_button.click(
                    fn=explain_tabular_sample,
                    inputs=[tabular_slider],
                    outputs=[tabular_output, tabular_text]
                )
        # Footer shown below all tabs.
        gr.Markdown("""
---
### About SHAP
SHAP (SHapley Additive exPlanations) is a unified approach to explain the output of machine learning models.
It connects game theory with local explanations and provides consistent and locally accurate feature attributions.
""")
    return demo
# ============================================================================
# Main
# ============================================================================
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on port 7860.
    demo = create_demo()
    demo.launch(
        server_name="0.0.0.0",  # listen on every network interface, not just localhost
        server_port=7860,
        share=False,  # no public share link
    )