Spaces:
Sleeping
Sleeping
Add HuggingFace Space files: app.py, README.md with metadata, requirements.txt, and .gitignore
89a46dd
| import gradio as gr | |
| import shap | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import datasets, transforms | |
| from sklearn.ensemble import RandomForestClassifier | |
| from sklearn.model_selection import train_test_split | |
| import json | |
| import io | |
| from PIL import Image | |
import warnings
# Silence every Python warning process-wide. NOTE(review): this also hides
# deprecation warnings from torch/shap/keras; consider narrowing the filter.
warnings.filterwarnings("ignore")
# Configure TensorFlow to avoid GPU issues.
# These environment variables must be set BEFORE `import tensorflow` below,
# otherwise TensorFlow has already read them.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Suppress TensorFlow (C++ backend) log output
# Hide all CUDA GPUs from this process. NOTE(review): this variable is
# process-wide, not TensorFlow-specific. torch was imported above, but
# PyTorch initialises CUDA lazily, so this presumably also hides GPUs from
# PyTorch and DEVICE will resolve to "cpu". Confirm that is intended.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
# Disable GPU for TensorFlow via its own API (belt-and-braces alongside the
# environment variable above).
tf.config.set_visible_devices([], 'GPU')
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
# Set random seeds for reproducibility. Only the PyTorch and NumPy global
# RNGs are seeded here; no TensorFlow seed is set.
torch.manual_seed(42)
np.random.seed(42)
| # ============================================================================ | |
| # MNIST Model Definition (for Pixel-level SHAP) | |
| # ============================================================================ | |
class MNISTNet(nn.Module):
    """Small CNN for 28x28 single-channel (MNIST) digit images.

    Layout: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU ->
    2x2 max-pool -> channel dropout -> flatten (64 * 12 * 12 = 9216) ->
    fc(9216->128) -> ReLU -> dropout -> fc(128->10) -> softmax.

    ``forward`` returns class *probabilities* (softmax over 10 classes), not
    logits; anything training this model must take the log before
    ``F.nll_loss``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout on the 4-D (N, C, H, W) conv feature maps.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout on the 2-D (N, 128) fully-connected activations.
        # nn.Dropout2d is meant for 3-D/4-D inputs; PyTorch deprecates 2-D
        # input to it and recommends plain dropout instead.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a (N, 1, 28, 28) batch to (N, 10) class probabilities."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.softmax(x, dim=1)
| # ============================================================================ | |
| # Global Variables and Model Loading | |
| # ============================================================================ | |
# Run PyTorch models on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# MNIST preprocessing: PIL image -> float tensor -> normalized with
# mean 0.1307 / std 0.3081 (single channel).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Lazily-populated caches; the initialize_* helpers fill them on first use.
mnist_model = mnist_background = None
resnet_model = resnet_explainer = None
tabular_model = tabular_explainer = tabular_data = None
| # ============================================================================ | |
| # Helper Functions | |
| # ============================================================================ | |
def initialize_mnist_model(train_samples=10000, epochs=1):
    """Create (once) the MNIST classifier and the DeepExplainer background set.

    Previously the network was only instantiated with random weights, so its
    predictions and SHAP maps were meaningless. It is now trained briefly on
    a subset of the MNIST training split before being cached.

    Args:
        train_samples: Number of MNIST training images to fit on (taken from
            the start of the training split). ``0`` skips training and keeps
            random weights, which was the old behavior.
        epochs: Number of passes over that subset.

    Returns:
        ``(mnist_model, mnist_background)``: the cached model in eval mode on
        ``DEVICE``, and the first 100 images of the unshuffled MNIST test
        split (CPU tensor) used as the SHAP background.
    """
    global mnist_model, mnist_background
    if mnist_model is not None:
        return mnist_model, mnist_background

    # Background set: first 100 test images (shuffle=False keeps this stable).
    test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=200, shuffle=False)
    images, _ = next(iter(test_loader))
    background = images[:100]

    model = MNISTNet().to(DEVICE)

    n_train = max(0, int(train_samples))
    n_epochs = max(0, int(epochs))
    if n_train > 0 and n_epochs > 0:
        train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
        subset = torch.utils.data.Subset(train_dataset, range(min(n_train, len(train_dataset))))
        # Seeded generator so the shuffle order (and thus the weights) is reproducible.
        train_loader = torch.utils.data.DataLoader(
            subset,
            batch_size=128,
            shuffle=True,
            generator=torch.Generator().manual_seed(42),
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        model.train()
        for _ in range(n_epochs):
            for batch, target in train_loader:
                batch, target = batch.to(DEVICE), target.to(DEVICE)
                optimizer.zero_grad()
                probs = model(batch)
                # MNISTNet.forward returns softmax probabilities, so take the
                # log here for NLL loss; the clamp guards against log(0).
                loss = F.nll_loss(torch.log(probs.clamp_min(1e-12)), target)
                loss.backward()
                optimizer.step()
    model.eval()

    # Publish to the module-level cache only after everything succeeded, so a
    # failed download/training run is retried on the next call.
    mnist_model, mnist_background = model, background
    return mnist_model, mnist_background
def _load_imagenet_class_names(json_path="imagenet_class_index.json"):
    """Return the 1000 ImageNet class names, downloading the index file if needed.

    Reads ``json_path`` if it exists. Otherwise, or if reading fails, the file
    is downloaded from the TensorFlow bucket and read. If that also fails,
    placeholder names ``class_0`` ... ``class_999`` are returned.
    """
    import urllib.request

    def _read(path):
        with open(path, encoding="utf-8") as fh:
            class_idx = json.load(fh)
        # Keys are "0".."999"; element 1 of each entry is the label text used here.
        return [class_idx[str(i)][1] for i in range(1000)]

    if os.path.exists(json_path):
        try:
            class_names = _read(json_path)
            print(f"✓ Loaded {len(class_names)} ImageNet class names")
            return class_names
        except Exception as e:
            print(f"⚠ Error loading class names: {e}")

    print("Downloading ImageNet class names...")
    url = "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
    try:
        # Explicit timeout so a stalled network cannot hang model initialisation.
        with urllib.request.urlopen(url, timeout=30) as resp:
            payload = resp.read()
        with open(json_path, "wb") as fh:
            fh.write(payload)
        class_names = _read(json_path)
        print(f"✓ Downloaded and loaded {len(class_names)} ImageNet class names")
        return class_names
    except Exception as e:
        print(f"⚠ Could not download class names: {e}")
        print("Using placeholder names...")
        return [f"class_{i}" for i in range(1000)]


def initialize_resnet_model():
    """Load ResNet50 (ImageNet weights) and build its SHAP explainer, once.

    Returns:
        ``(resnet_model, resnet_explainer)``. The explainer wraps the model
        with an ``inpaint_telea`` image masker for 224x224x3 inputs and labels
        outputs with the ImageNet class names.
    """
    global resnet_model, resnet_explainer
    # Check the explainer (published last), so a half-finished earlier
    # attempt is retried instead of returning a None explainer.
    if resnet_explainer is not None:
        return resnet_model, resnet_explainer

    model = ResNet50(weights="imagenet")
    class_names = _load_imagenet_class_names()

    def f(x):
        # Use the *return value* of preprocess_input. For integer (e.g.
        # uint8) input Keras returns a new float array and leaves x
        # untouched. For float input the in-place result is mean-subtracted
        # but not RGB->BGR reordered. The previous in-place call therefore fed
        # ResNet50 un-preprocessed pixels.
        return model(preprocess_input(np.array(x, dtype=np.float32)))

    masker = shap.maskers.Image("inpaint_telea", (224, 224, 3))
    explainer = shap.Explainer(f, masker, output_names=class_names)

    resnet_model, resnet_explainer = model, explainer
    return resnet_model, resnet_explainer
def initialize_tabular_model():
    """Train (once) a Random Forest on the SHAP Adult Income data and build its explainer.

    Returns:
        ``(tabular_model, tabular_explainer, (X_test, y_test))``: the cached
        ``RandomForestClassifier`` (100 trees, seed 42), a
        ``shap.TreeExplainer`` over it, and the 20 % hold-out split as a
        pandas DataFrame / Series.
    """
    global tabular_model, tabular_explainer, tabular_data
    if tabular_model is not None:
        return tabular_model, tabular_explainer, tabular_data

    X, y = shap.datasets.adult()
    # Normalise to pandas so the split keeps DataFrame/Series types (.iloc).
    if not isinstance(X, pd.DataFrame):
        X = pd.DataFrame(X)
    if not isinstance(y, pd.Series):
        y = pd.Series(y)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    model = RandomForestClassifier(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)
    explainer = shap.TreeExplainer(model)

    # Publish all cached values together, and only after training and
    # explainer construction succeeded. Otherwise a failure would leave an
    # unfitted model cached with no explainer, and every later call would break.
    tabular_model, tabular_explainer, tabular_data = model, explainer, (X_test, y_test)
    return tabular_model, tabular_explainer, tabular_data
| # ============================================================================ | |
| # SHAP Explanation Functions | |
| # ============================================================================ | |
def explain_mnist_digit(digit_index):
    """Render a pixel-level DeepExplainer SHAP plot for one MNIST demo digit.

    Args:
        digit_index: Position among the 10 demo images (MNIST test images
            100-109; images 0-99 are the background set). Cast to ``int``
            because Gradio sliders may deliver floats, and clamped to 0..9.

    Returns:
        ``(PIL.Image, str)`` with the rendered plot and a prediction summary,
        or ``(None, "Error: ...")`` if anything fails.
    """
    figs_before = set(plt.get_fignums())
    try:
        model, background = initialize_mnist_model()

        test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
        idx = min(max(int(digit_index), 0), 9)
        image, actual = test_dataset[100 + idx]
        test_image = image.unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            output = model(test_image)
        pred = int(output.argmax(dim=1).item())

        explainer = shap.DeepExplainer(model, background.to(DEVICE))
        shap_values = explainer.shap_values(test_image)

        # Older SHAP returns a list with one (N, C, H, W) array per class.
        # SHAP >= 0.45 returns a single (N, C, H, W, classes) array.
        # image_plot wants a list of (N, H, W, C) arrays, one per class.
        if isinstance(shap_values, list):
            shap_numpy = [np.transpose(np.asarray(s), (0, 2, 3, 1)) for s in shap_values]
        else:
            shap_numpy = list(np.transpose(np.asarray(shap_values), (4, 0, 2, 3, 1)))
        test_numpy = np.transpose(test_image.cpu().numpy(), (0, 2, 3, 1))

        # image_plot creates its own figure; suptitle targets that current figure.
        shap.image_plot(shap_numpy, -test_numpy, show=False)
        plt.suptitle(f'Actual: {actual}, Predicted: {pred}', fontsize=14, y=1.02)

        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight', dpi=150)
        buf.seek(0)
        img = Image.open(buf)
        return img, f"Prediction: {pred} (Actual: {actual})"
    except Exception as e:
        return None, f"Error: {str(e)}"
    finally:
        # Close every figure this call opened, including on error paths.
        for num in set(plt.get_fignums()) - figs_before:
            plt.close(num)
def explain_imagenet_image(image):
    """Explain ResNet50's top-4 ImageNet predictions for an uploaded image.

    Args:
        image: The uploaded picture as a NumPy array (as passed by
            ``gr.Image``), or ``None`` when nothing was uploaded.

    Returns:
        ``(PIL.Image, str)`` with the rendered SHAP image plot and a status
        message, or ``(None, message)`` on missing input or failure.
    """
    figs_before = set(plt.get_fignums())
    try:
        # Validate input before paying for the (slow) ResNet50 load.
        if image is None:
            return None, "Please upload an image"
        _, explainer = initialize_resnet_model()

        # convert("RGB") turns grayscale / LA / RGBA / palette uploads into
        # three 8-bit channels (alpha is dropped), then resize to 224x224.
        rgb = Image.fromarray(image).convert("RGB").resize((224, 224))
        batch = np.expand_dims(np.array(rgb, dtype=np.uint8), axis=0)

        shap_values = explainer(batch, max_evals=100, batch_size=50,
                                outputs=shap.Explanation.argsort.flip[:4])

        # image_plot draws into a figure it creates itself.
        shap.image_plot(shap_values, show=False)

        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight', dpi=150)
        buf.seek(0)
        result_img = Image.open(buf)
        return result_img, "SHAP explanation generated successfully"
    except Exception as e:
        return None, f"Error: {str(e)}"
    finally:
        # Close every figure this call opened, including on error paths.
        for num in set(plt.get_fignums()) - figs_before:
            plt.close(num)
def explain_tabular_sample(sample_index):
    """Render a SHAP waterfall plot for one row of the Adult Income hold-out set.

    Args:
        sample_index: Row position in the hold-out set. Cast to ``int``
            because Gradio sliders may deliver floats, and clamped to the
            valid range.

    Returns:
        ``(PIL.Image, str)`` with the waterfall plot and a prediction summary,
        or ``(None, message)`` including the traceback on failure.
    """
    figs_before = set(plt.get_fignums())
    try:
        model, explainer, (X_test, y_test) = initialize_tabular_model()

        # SHAP values are computed for the selected row only (below), so any
        # index within the hold-out set is valid.
        idx = min(max(int(sample_index), 0), len(X_test) - 1)

        # Select the row, handling both DataFrame and NumPy containers.
        if hasattr(X_test, 'iloc'):
            X_sample = X_test.iloc[[idx]]
            actual = y_test.iloc[idx]
        else:
            X_sample = X_test[idx:idx + 1]
            actual = y_test[idx]

        # TreeExplainer attributions are computed row by row, so explaining
        # just this row gives the same values as slicing a larger batch.
        shap_values = explainer(X_sample)

        plt.figure(figsize=(10, 8))
        # Output index 1: attributions toward the classifier's second class.
        shap.plots.waterfall(shap_values[0, :, 1], show=False)

        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight', dpi=150)
        buf.seek(0)
        img = Image.open(buf)

        pred = model.predict(X_sample)[0]
        return img, f"Prediction: {pred} (Actual: {actual})"
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        return None, f"Error: {str(e)}\n\nDetails:\n{error_details}"
    finally:
        # Close every figure this call opened, including on error paths.
        for num in set(plt.get_fignums()) - figs_before:
            plt.close(num)
| # ============================================================================ | |
| # Gradio Interface | |
| # ============================================================================ | |
# Markdown blurbs for the UI. They are kept at module level, flush-left, so
# code indentation can never leak into the rendered text.
_MNIST_TAB_MD = """
### Pixel-level SHAP Explanations
This method uses **DeepExplainer** to show which pixels contribute to the model's prediction.
The plot shows the input digit followed by one heat map per digit class (0-9):
- **Red pixels**: Increase the model's probability for that column's class
- **Blue pixels**: Decrease the model's probability for that column's class
"""

_IMAGENET_TAB_MD = """
### Image Segmentation SHAP Explanations
This method uses **Partition Explainer** with image masking to explain ResNet50 predictions.
Upload an image to see which regions contribute to the top predicted classes.
"""

_TABULAR_TAB_MD = """
### Tabular Data SHAP Explanations
This method uses **TreeExplainer** to explain Random Forest predictions on the Adult Income dataset.
The waterfall plot shows how each feature contributes to the prediction.
"""

_ABOUT_SHAP_MD = """
---
### About SHAP
SHAP (SHapley Additive exPlanations) is a unified approach to explain the output of machine learning models.
It connects game theory with local explanations and provides consistent and locally accurate feature attributions.
"""


def _add_explanation_tab(title, intro_md, make_input, explain_fn, output_label, result_label):
    """Add one tab: intro text, then input + button (left) and plot + text (right).

    ``make_input`` is a zero-argument callable that creates the input
    component. It is invoked inside the left column so the component is
    placed there. The button calls ``explain_fn(input)`` and routes its
    ``(image, text)`` result to the right-hand components.
    """
    with gr.Tab(title):
        gr.Markdown(intro_md)
        with gr.Row():
            with gr.Column():
                control = make_input()
                button = gr.Button("Generate Explanation", variant="primary")
            with gr.Column():
                plot = gr.Image(label=output_label)
                result = gr.Textbox(label=result_label)
        button.click(fn=explain_fn, inputs=[control], outputs=[plot, result])


def create_demo():
    """Build the Gradio Blocks UI with one tab per SHAP explanation method.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="SHAP Explanations Demo") as demo:
        gr.Markdown("# SHAP (SHapley Additive exPlanations) Demo")
        gr.Markdown("This demo showcases three different SHAP explanation methods for machine learning models.")

        with gr.Tabs():
            # Tab 1: MNIST pixel-level explanations
            _add_explanation_tab(
                "1. Pixel-level (MNIST Digits)",
                _MNIST_TAB_MD,
                lambda: gr.Slider(minimum=0, maximum=9, step=1, value=0,
                                  label="Select Test Image Index"),
                explain_mnist_digit,
                "SHAP Explanation",
                "Prediction Result",
            )
            # Tab 2: ImageNet image explanations
            _add_explanation_tab(
                "2. Image Segmentation (ImageNet)",
                _IMAGENET_TAB_MD,
                lambda: gr.Image(label="Upload Image"),
                explain_imagenet_image,
                "SHAP Explanation",
                "Status",
            )
            # Tab 3: tabular (Adult Income) explanations
            _add_explanation_tab(
                "3. Tabular Data (Adult Income)",
                _TABULAR_TAB_MD,
                lambda: gr.Slider(minimum=0, maximum=99, step=1, value=0,
                                  label="Select Sample Index"),
                explain_tabular_sample,
                "SHAP Waterfall Plot",
                "Prediction Result",
            )

        gr.Markdown(_ABOUT_SHAP_MD)
    return demo
| # ============================================================================ | |
| # Main | |
| # ============================================================================ | |
if __name__ == "__main__":
    # Serve on all network interfaces, port 7860, without a public share link.
    app = create_demo()
    app.launch(server_name="0.0.0.0", server_port=7860, share=False)