"""Gradio demo showcasing three SHAP explanation methods.

1. Pixel-level attributions for a small PyTorch CNN on MNIST (DeepExplainer).
2. Region-level attributions for Keras ResNet50 on ImageNet (shap.Explainer
   with an inpainting image masker).
3. Feature attributions for a scikit-learn RandomForest on the Adult income
   dataset (TreeExplainer).

Models and explainers are built lazily on first use and cached in module-level
globals.
"""

import io
import json
import os
import traceback
import urllib.request
import warnings

import gradio as gr
import matplotlib

# Figures are only ever rendered to PNG (from Gradio worker threads), so use the
# non-interactive backend instead of whatever GUI backend the host might pick.
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import shap  # noqa: E402
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
import torch.nn.functional as F  # noqa: E402
from PIL import Image  # noqa: E402
from sklearn.ensemble import RandomForestClassifier  # noqa: E402
from sklearn.model_selection import train_test_split  # noqa: E402
from torchvision import datasets, transforms  # noqa: E402

warnings.filterwarnings("ignore")

# Configure TensorFlow before importing it.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Suppress TensorFlow C++ log output
# Hides all CUDA devices from this process. torch has been imported but has not
# initialized CUDA yet, so this most likely also makes torch fall back to CPU
# (DEVICE below) -- NOTE(review): confirm that is intended.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import tensorflow as tf  # noqa: E402

# Disable GPU for TensorFlow explicitly as well.
tf.config.set_visible_devices([], 'GPU')

from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input  # noqa: E402

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)


# ============================================================================
# MNIST Model Definition (for Pixel-level SHAP)
# ============================================================================

class MNISTNet(nn.Module):
    """Small CNN for 28x28 single-channel MNIST digits.

    ``forward`` returns softmax probabilities over the 10 digit classes (not
    logits); those probabilities are what DeepExplainer attributes.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        # Applied after flattening to (N, 128), where channel-wise Dropout2d
        # does not apply; plain element-wise dropout is the intended op.
        self.dropout2 = nn.Dropout(0.5)
        # 64 channels * 12 * 12: two valid 3x3 convs (28->26->24), then 2x2 pool.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.softmax(x, dim=1)


# ============================================================================
# Global Variables and Model Loading
# ============================================================================

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST preprocessing (standard MNIST mean/std normalization)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

MNIST_DATA_DIR = './data'
MNIST_WEIGHTS_PATH = "mnist_cnn.pt"   # trained weights are cached here
MNIST_TRAIN_EPOCHS = 1
MNIST_BACKGROUND_SIZE = 100           # first N test images form the SHAP background
MNIST_DEMO_SLICE = slice(100, 110)    # the ten test images selectable in the UI

IMAGENET_CLASS_INDEX_PATH = "imagenet_class_index.json"
IMAGENET_CLASS_INDEX_URL = (
    "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
)

# Lazily-initialized models/explainers (populated on first use)
mnist_model = None
mnist_background = None
mnist_test_images = None
mnist_test_targets = None
mnist_explainer = None
resnet_model = None
resnet_explainer = None
tabular_model = None
tabular_explainer = None
tabular_data = None


# ============================================================================
# Helper Functions
# ============================================================================

def _figure_to_image(dpi=150):
    """Render the current matplotlib figure to a PIL image and close all figures.

    shap's plotting helpers open figures of their own, so closing only the
    current figure (as ``plt.close()`` does) can leak earlier ones.  pyplot's
    global state also means plotting is assumed to be serialized.
    """
    buf = io.BytesIO()
    try:
        plt.savefig(buf, format='png', bbox_inches='tight', dpi=dpi)
    finally:
        plt.close('all')
    buf.seek(0)
    img = Image.open(buf)
    img.load()  # decode now so the image does not depend on the buffer later
    return img


def _clamp_index(value, upper):
    """Convert a UI value (Gradio sliders may send floats) to an int in [0, upper]."""
    try:
        index = int(value)
    except (TypeError, ValueError):
        index = 0
    return max(0, min(index, upper))


def _train_mnist_model(model, epochs=MNIST_TRAIN_EPOCHS):
    """Train ``model`` in place on the MNIST training split, then set eval mode."""
    train_dataset = datasets.MNIST(MNIST_DATA_DIR, train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model.train()
    for _ in range(epochs):
        for data, target in train_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            optimizer.zero_grad()
            probs = model(data)
            # forward() already applies softmax, so take the log for NLL loss;
            # the clamp keeps log() finite if a probability underflows to 0.
            loss = F.nll_loss(torch.log(probs.clamp_min(1e-12)), target)
            loss.backward()
            optimizer.step()
    model.eval()


def initialize_mnist_model():
    """Lazily build the MNIST model, SHAP background and demo images.

    On first call: loads the MNIST test split (first 200 images), keeps the
    first 100 as the DeepExplainer background and images 100-109 as the demo
    images, then loads cached weights from ``MNIST_WEIGHTS_PATH`` or trains the
    network and writes the weights there.

    Returns:
        (model, background): the eval-mode ``MNISTNet`` on ``DEVICE`` and the
        background image tensor (on CPU).
    """
    global mnist_model, mnist_background, mnist_test_images, mnist_test_targets
    if mnist_model is None:
        test_dataset = datasets.MNIST(MNIST_DATA_DIR, train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=200, shuffle=False)
        images, targets = next(iter(test_loader))

        model = MNISTNet().to(DEVICE)
        if os.path.exists(MNIST_WEIGHTS_PATH):
            # weights_only avoids unpickling arbitrary objects from the file.
            state = torch.load(MNIST_WEIGHTS_PATH, map_location=DEVICE, weights_only=True)
            model.load_state_dict(state)
        else:
            _train_mnist_model(model)
            torch.save(model.state_dict(), MNIST_WEIGHTS_PATH)
        model.eval()

        mnist_background = images[:MNIST_BACKGROUND_SIZE]
        mnist_test_images = images[MNIST_DEMO_SLICE]
        mnist_test_targets = targets[MNIST_DEMO_SLICE]
        # Publish the model last so a failed initialization is retried next call.
        mnist_model = model
    return mnist_model, mnist_background


def _get_mnist_explainer(model, background):
    """Return the cached DeepExplainer for ``model``, creating it on first use."""
    global mnist_explainer
    if mnist_explainer is None:
        mnist_explainer = shap.DeepExplainer(model, background.to(DEVICE))
    return mnist_explainer


def _split_shap_outputs(shap_values):
    """Normalize DeepExplainer output to a list with one (N, C, H, W) array per class.

    Older shap releases return such a list directly; newer releases return a
    single array with the model outputs stacked on a trailing axis.
    """
    if isinstance(shap_values, list):
        return shap_values
    values = np.asarray(shap_values)
    if values.ndim == 5:
        return [values[..., k] for k in range(values.shape[-1])]
    return [values]


def _read_imagenet_class_names(json_path):
    """Read the 1000 ImageNet labels from a Keras ``imagenet_class_index.json``.

    The file maps ``"0"`` .. ``"999"`` to pairs whose second element is the label.
    """
    with open(json_path, encoding="utf-8") as fh:
        class_idx = json.load(fh)
    return [class_idx[str(i)][1] for i in range(1000)]


def _load_imagenet_class_names(json_path=IMAGENET_CLASS_INDEX_PATH):
    """Return ImageNet class names, downloading the index file when needed.

    Best effort: falls back to ``class_0`` .. ``class_999`` placeholders when
    the file can neither be read nor downloaded, so the demo still works offline.
    """
    if os.path.exists(json_path):
        try:
            class_names = _read_imagenet_class_names(json_path)
            print(f"✓ Loaded {len(class_names)} ImageNet class names")
            return class_names
        except Exception as e:
            print(f"⚠ Error loading class names: {e}")

    print("Downloading ImageNet class names...")
    try:
        urllib.request.urlretrieve(IMAGENET_CLASS_INDEX_URL, json_path)
        class_names = _read_imagenet_class_names(json_path)
        print(f"✓ Downloaded and loaded {len(class_names)} ImageNet class names")
        return class_names
    except Exception as e:
        print(f"⚠ Could not download class names: {e}")
        print("Using placeholder names...")
        return [f"class_{i}" for i in range(1000)]


def initialize_resnet_model():
    """Lazily build ResNet50 (ImageNet weights) and its image SHAP explainer.

    Returns:
        (resnet_model, resnet_explainer)
    """
    global resnet_model, resnet_explainer
    if resnet_explainer is None:
        if resnet_model is None:
            resnet_model = ResNet50(weights="imagenet")
        model = resnet_model
        class_names = _load_imagenet_class_names()

        def predict(x):
            # preprocess_input's result must be used: it may return a new array
            # (integer input is cast to float) and its RGB->BGR flip exists only
            # in the returned view, not in the argument.
            return model(preprocess_input(x.copy()))

        masker = shap.maskers.Image("inpaint_telea", (224, 224, 3))
        resnet_explainer = shap.Explainer(predict, masker, output_names=class_names)
    return resnet_model, resnet_explainer


def initialize_tabular_model():
    """Lazily train a RandomForest on the Adult dataset and build a TreeExplainer.

    Returns:
        (tabular_model, tabular_explainer, (X_test, y_test)) where ``X_test`` is
        a DataFrame and ``y_test`` a Series (20% hold-out, random_state=42).
    """
    global tabular_model, tabular_explainer, tabular_data
    if tabular_model is None:
        X, y = shap.datasets.adult()
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)
        if not isinstance(y, pd.Series):
            y = pd.Series(y)

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )

        model = RandomForestClassifier(n_estimators=100, random_state=42)
        model.fit(X_train, y_train)

        tabular_explainer = shap.TreeExplainer(model)
        tabular_data = (X_test, y_test)
        # Publish the model last so a failed initialization is retried next call.
        tabular_model = model
    return tabular_model, tabular_explainer, tabular_data


# ============================================================================
# SHAP Explanation Functions
# ============================================================================

def explain_mnist_digit(digit_index):
    """Explain one of the ten MNIST demo images with DeepExplainer.

    Args:
        digit_index: position (0-9) within the demo images; clamped into range.

    Returns:
        (PIL image of the per-class SHAP plot or None, status text)
    """
    try:
        model, background = initialize_mnist_model()
        explainer = _get_mnist_explainer(model, background)

        idx = _clamp_index(digit_index, len(mnist_test_images) - 1)
        test_image = mnist_test_images[[idx]].to(DEVICE)  # keep the batch dim
        actual = int(mnist_test_targets[idx])

        with torch.no_grad():
            pred = int(model(test_image).argmax(dim=1).item())

        per_class = _split_shap_outputs(explainer.shap_values(test_image))
        # (N, C, H, W) -> (N, H, W, C), the layout shap.image_plot expects.
        shap_numpy = [np.transpose(s, (0, 2, 3, 1)) for s in per_class]
        test_numpy = np.transpose(test_image.cpu().numpy(), (0, 2, 3, 1))

        # image_plot creates its own figure, so none is created beforehand.
        shap.image_plot(shap_numpy, -test_numpy, show=False)
        plt.suptitle(f'Actual: {actual}, Predicted: {pred}', fontsize=14, y=1.02)
        img = _figure_to_image()

        return img, f"Prediction: {pred} (Actual: {actual})"
    except Exception as e:
        plt.close('all')
        return None, f"Error: {str(e)}"


def explain_imagenet_image(image):
    """Explain ResNet50's top-4 ImageNet predictions for an uploaded image.

    Args:
        image: uint8 numpy array from ``gr.Image`` (grayscale, RGB or RGBA), or None.

    Returns:
        (PIL image of the SHAP plot or None, status text)
    """
    try:
        if image is None:
            return None, "Please upload an image"

        _, explainer = initialize_resnet_model()

        if image.ndim == 3 and image.shape[2] == 1:
            image = image[:, :, 0]  # PIL cannot build an image from (H, W, 1)
        # convert("RGB") replicates grayscale and drops an alpha channel.
        img = Image.fromarray(image).convert("RGB").resize((224, 224))
        img_array = np.expand_dims(np.array(img, dtype=np.uint8), axis=0)

        shap_values = explainer(img_array, max_evals=100, batch_size=50,
                                outputs=shap.Explanation.argsort.flip[:4])

        shap.image_plot(shap_values, show=False)
        result_img = _figure_to_image()

        return result_img, "SHAP explanation generated successfully"
    except Exception as e:
        plt.close('all')
        return None, f"Error: {str(e)}"


def explain_tabular_sample(sample_index):
    """Explain the RandomForest prediction for one Adult test row (waterfall plot).

    Args:
        sample_index: row position in the test split; clamped into range.

    Returns:
        (PIL image of the waterfall plot or None, status text)
    """
    try:
        model, explainer, (X_test, y_test) = initialize_tabular_model()

        idx = _clamp_index(sample_index, len(X_test) - 1)
        X_sample = X_test.iloc[[idx]]
        actual = y_test.iloc[idx]

        # TreeExplainer attributions are computed per row, so explaining only the
        # selected row matches explaining a larger batch and indexing into it.
        shap_values = explainer(X_sample)

        plt.figure(figsize=(10, 8))
        # Index 1 on the output axis selects the classifier's second class.
        shap.plots.waterfall(shap_values[0, :, 1], show=False)
        img = _figure_to_image()

        pred = model.predict(X_sample)[0]
        return img, f"Prediction: {pred} (Actual: {actual})"
    except Exception as e:
        plt.close('all')
        error_details = traceback.format_exc()
        return None, f"Error: {str(e)}\n\nDetails:\n{error_details}"


# ============================================================================
# Gradio Interface
# ============================================================================

def create_demo():
    """Build the three-tab Gradio Blocks interface."""
    with gr.Blocks(title="SHAP Explanations Demo") as demo:
        gr.Markdown("# SHAP (SHapley Additive exPlanations) Demo")
        gr.Markdown("This demo showcases three different SHAP explanation methods for machine learning models.")

        with gr.Tabs():
            # Tab 1: MNIST Pixel-level Explanations
            with gr.Tab("1. Pixel-level (MNIST Digits)"):
                gr.Markdown("""
                ### Pixel-level SHAP Explanations

                This method uses **DeepExplainer** to show which pixels contribute to the model's prediction.
                Each column of the plot corresponds to one digit class.

                - **Red pixels**: Increase the model's output for the class in that column
                - **Blue pixels**: Decrease the model's output for the class in that column
                """)

                with gr.Row():
                    with gr.Column():
                        mnist_slider = gr.Slider(minimum=0, maximum=9, step=1, value=0,
                                                 label="Select Test Image Index")
                        mnist_button = gr.Button("Generate Explanation", variant="primary")
                    with gr.Column():
                        mnist_output = gr.Image(label="SHAP Explanation")
                        mnist_text = gr.Textbox(label="Prediction Result")

                mnist_button.click(
                    fn=explain_mnist_digit,
                    inputs=[mnist_slider],
                    outputs=[mnist_output, mnist_text],
                )

            # Tab 2: ImageNet Image Explanations
            with gr.Tab("2. Image Segmentation (ImageNet)"):
                gr.Markdown("""
                ### Image Segmentation SHAP Explanations

                This method uses **Partition Explainer** with image masking to explain ResNet50 predictions.
                Upload an image to see which regions contribute to the top predicted classes.
                """)

                with gr.Row():
                    with gr.Column():
                        image_input = gr.Image(label="Upload Image")
                        image_button = gr.Button("Generate Explanation", variant="primary")
                    with gr.Column():
                        image_output = gr.Image(label="SHAP Explanation")
                        image_text = gr.Textbox(label="Status")

                image_button.click(
                    fn=explain_imagenet_image,
                    inputs=[image_input],
                    outputs=[image_output, image_text],
                )

            # Tab 3: Tabular Data Explanations
            with gr.Tab("3. Tabular Data (Adult Income)"):
                gr.Markdown("""
                ### Tabular Data SHAP Explanations

                This method uses **TreeExplainer** to explain Random Forest predictions on the Adult Income dataset.
                The waterfall plot shows how each feature contributes to the prediction.
                """)

                with gr.Row():
                    with gr.Column():
                        tabular_slider = gr.Slider(minimum=0, maximum=99, step=1, value=0,
                                                   label="Select Sample Index")
                        tabular_button = gr.Button("Generate Explanation", variant="primary")
                    with gr.Column():
                        tabular_output = gr.Image(label="SHAP Waterfall Plot")
                        tabular_text = gr.Textbox(label="Prediction Result")

                tabular_button.click(
                    fn=explain_tabular_sample,
                    inputs=[tabular_slider],
                    outputs=[tabular_output, tabular_text],
                )

        gr.Markdown("""
        ---
        ### About SHAP

        SHAP (SHapley Additive exPlanations) is a unified approach to explain the output of machine learning models.
        It connects game theory with local explanations and provides consistent and locally accurate feature attributions.
        """)

    return demo


# ============================================================================
# Main
# ============================================================================

if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)