# Scraped Hugging Face Spaces page header ("Spaces: Sleeping"); not part of the program.
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import Dataset, DataLoader, Subset | |
| from torchvision import transforms, datasets | |
| from PIL import Image | |
| from tqdm.auto import tqdm | |
| import torch.nn.functional as F | |
| from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, FullGrad | |
| from matplotlib import colormaps | |
| import numpy as np | |
| import gradio as gr | |
class CNN(nn.Module):
    """Small three-stage convolutional classifier.

    Each stage is a 3x3 convolution (stride 1, padding 1, so the spatial size
    is preserved) followed by ReLU and a 2x2 max-pool that halves height and
    width. After three stages the feature map is ``input_size // 8`` pixels
    per side with 64 channels; it is flattened and fed through two linear
    layers.

    Args:
        num_classes: Number of output logits. Defaults to 2.
        input_size: Side length, in pixels, of the square input images.
            Defaults to 224. Must be at least 8 so that something is left
            after three halvings.

    Raises:
        ValueError: If ``input_size`` is smaller than 8.
    """

    def __init__(self, num_classes: int = 2, input_size: int = 224) -> None:
        super().__init__()
        if input_size < 8:
            raise ValueError(
                f"input_size must be >= 8 to survive three 2x2 poolings, got {input_size}"
            )

        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)

        # Pooling layer (shared by all three stages; it has no parameters)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        # Fully connected layers. Three floor-halvings equal one floor
        # division by 8, so this matches the flattened conv output.
        pooled_side = input_size // 8
        self.fc1 = nn.Linear(64 * pooled_side * pooled_side, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (un-normalised) class logits for a batch of images."""
        # Convolutional layers with ReLU activation and pooling
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))

        # Flatten everything except the batch dimension
        x = torch.flatten(x, 1)

        # Fully connected layers; no activation on the final logits
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# Preprocessing applied to every uploaded image before it reaches the model.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224
    transforms.ToTensor(),  # Convert to tensor
    # NOTE(review): these look like the commonly used ImageNet channel
    # statistics -- confirm they match what the model was trained with.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Build the network and restore the trained weights on CPU.
model = CNN()
# NOTE(review): torch.load unpickles the checkpoint file; only load files you
# trust. On PyTorch versions that support it, consider `weights_only=True`.
model.load_state_dict(
    torch.load("trained-cnn-concrete-crack.model", map_location=torch.device("cpu"))
)
# The app only performs inference, so keep the module in evaluation mode.
model.eval()

# Colormap used to turn a grayscale CAM into a colour heatmap.
magmaify = colormaps['magma']
def compute_gradcam(img_tensor, layer_idx, typeCAM):
    """Compute a colour-mapped class-activation map for one image.

    Args:
        img_tensor: Preprocessed image batch passed to the CAM object as
            ``input_tensor``. Only the first item's map is returned.
        layer_idx: 1-based index of the convolutional layer to explain
            (1 -> ``model.conv1``, 2 -> ``model.conv2``, 3 -> ``model.conv3``).
            Integral floats such as ``2.0`` are accepted.
        typeCAM: Name of the CAM method; one of the keys of the table below.

    Returns:
        The first CAM of the batch after applying the ``magmaify`` colormap.

    Raises:
        KeyError: If ``typeCAM`` is not a known CAM method name.
        IndexError: If ``layer_idx`` is outside ``1..3``.
    """
    cam_classes = {
        "GradCAM": GradCAM,
        "HiResCAM": HiResCAM,
        "ScoreCAM": ScoreCAM,
        "GradCAMPlusPlus": GradCAMPlusPlus,
        "AblationCAM": AblationCAM,
        "XGradCAM": XGradCAM,
        "FullGrad": FullGrad,
    }
    conv_layers = (model.conv1, model.conv2, model.conv3)

    cam_class = cam_classes[typeCAM]

    # UI sliders can deliver the value as a float, which cannot index a tuple.
    layer_number = int(layer_idx)
    # Reject 0 and negatives explicitly: `conv_layers[0 - 1]` would otherwise
    # silently select the last layer.
    if not 1 <= layer_number <= len(conv_layers):
        raise IndexError(
            f"layer_idx must be between 1 and {len(conv_layers)}, got {layer_idx!r}"
        )

    # The CAM object registers hooks on the model; the context manager
    # releases them so repeated requests do not leave stale hooks behind.
    with cam_class(model=model, target_layers=[conv_layers[layer_number - 1]]) as cam:
        # targets=None lets the library explain the highest-scoring class.
        grayscale_cam = cam(input_tensor=img_tensor, targets=None)

    # The CAM result is batched; take the single image's 2-D map.
    return magmaify(grayscale_cam[0])
def predict_and_gradcam(model, img, layer_idx, typeCAM):
    """Classify one image and compute its CAM heatmap.

    Args:
        model: Network used for the class prediction.
        img: Input image, either a NumPy array or a PIL image.
        layer_idx: 1-based convolutional layer index forwarded to
            ``compute_gradcam``.
        typeCAM: CAM method name forwarded to ``compute_gradcam``.

    Returns:
        A ``(predicted_label, gradcam)`` tuple, where ``predicted_label`` is
        the predicted class index as a string and ``gradcam`` is the heatmap
        returned by ``compute_gradcam``.
    """
    # Preprocess the image. Let Pillow infer the mode from the array shape and
    # then force three channels, so grayscale or RGBA input still matches the
    # 3-channel normalisation in `transform`.
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img.astype('uint8'))
    if isinstance(img, Image.Image):
        img = img.convert('RGB')
    img_tensor = transform(img).unsqueeze(0)  # add the batch dimension

    # Get predicted class index; gradients are not needed for the prediction.
    with torch.no_grad():
        output = model(img_tensor)
    predicted_label = str(output.argmax(dim=1).item())

    # Compute GradCAM (outside no_grad: gradient-based CAMs need gradients)
    gradcam = compute_gradcam(img_tensor, layer_idx, typeCAM)
    return predicted_label, gradcam
# Maps the stringified class index from `predict_and_gradcam` to a display label.
# NOTE(review): presumably this matches the class order used in training -- confirm.
idx_to_lbl = {"0": "Cracked", "1": "Uncracked"}


# Define a function to be used in Gradio app
def classify_image(image, layer_idx, typeCAM):
    """Gradio callback: return the predicted label and the CAM heatmap.

    Args:
        image: Uploaded image from the Gradio image input, or ``None`` when
            nothing has been uploaded.
        layer_idx: 1-based convolutional layer index chosen in the UI.
        typeCAM: CAM method name chosen in the UI.

    Returns:
        A ``(label_text, heatmap)`` tuple for the two output components.

    Raises:
        gr.Error: If no image was provided.
    """
    # Submitting the form without an image passes None; report that to the
    # user instead of failing deep inside the preprocessing pipeline.
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Predict label and get GradCAM
    label, gradcam_img = predict_and_gradcam(model, image, layer_idx, typeCAM)
    return idx_to_lbl[label], gradcam_img
# HTML shown under the title of the app.
description = """\
<hr><center>Upload an image of concrete and get the predicted label along with the GradCAM heatmap. <br><br>
<img src="https://www.huggingface.co/spaces/1rsh/concrete-crack-gradcam/resolve/main/header.jpeg" width=200px></img></center>
\
"""

# CAM methods offered in the dropdown; each name must be a key that
# `compute_gradcam` knows how to resolve.
typeCAMs = ["GradCAM", "HiResCAM", "ScoreCAM", "GradCAMPlusPlus", "AblationCAM", "XGradCAM", "FullGrad"]

# Define Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(),
        # 1..3 selects which convolutional layer the CAM is computed on.
        gr.Slider(minimum=1, maximum=3, step=1, value=1),
        gr.Dropdown(choices=typeCAMs, value="GradCAM"),
    ],
    outputs=[
        gr.Textbox(label="Predicted Label"),
        gr.Image(label="GradCAM Heatmap"),
    ],
    title="Concrete Crack Detection with GradCAM",
    description=description,
    # Flagging is disabled; Gradio documents the string "never" for this
    # rather than a boolean.
    allow_flagging="never",
    theme=gr.themes.Monochrome(font=gr.themes.GoogleFont("IBM Plex Mono")),
)

# Launch the interface
iface.launch()