import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as F
from model import DRRRDBNet
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
gen2 = DRRRDBNet(3, 3, 64, 32, 2, 0.2).to(device)
def load_weights(checkpoint_file, model):
    """Load a checkpoint, keeping only the keys whose names and shapes match the current model (partial loading)."""
    print("=> Loading weights from:", checkpoint_file)
checkpoint = torch.load(checkpoint_file, map_location=device)
model_state_dict = model.state_dict()
state_dict = {
k: v for k, v in checkpoint["state_dict"].items()
if k in model_state_dict and v.size() == model_state_dict[k].size()
}
model_state_dict.update(state_dict)
model.load_state_dict(model_state_dict)
print("Successfully loaded the pretrained model weights")
return model
gen2 = load_weights("gen174.pth.tar", gen2)
gen2.eval()
def enable_dropout(model):
"""Keeps dropout layers active during inference."""
for m in model.modules():
if isinstance(m, nn.Dropout):
m.train()
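# Note: the isinstance check above catches nn.Dropout only; if model.py uses
# nn.Dropout2d/nn.Dropout3d instead, broaden it to
# isinstance(m, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)).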
def mcd_superres_crop(image, mc_passes=5):
"""
1) Random crop input (200x200)
2) Upscale that cropped patch 4× with bicubic so users can compare visually
3) Run multiple forward passes (MCD) on the cropped patch
4) Return: (cropped_image_4x, mean_SR, std_heatmap)
"""
if image is None:
return None, None, None
# A) Random crop 200x200
    transform_crop = transforms.Compose([
        transforms.RandomCrop((200, 200)),           # raises if the upload is smaller than 200x200
        transforms.ToTensor(),
        transforms.Normalize((0, 0, 0), (1, 1, 1))   # mean 0 / std 1 is effectively a no-op
    ])
lr_tensor = transform_crop(image) # shape: (3, 200, 200)
# Convert the cropped tensor to a PIL image
cropped_pil = F.to_pil_image(lr_tensor.clone().clamp_(0,1))
# B) Upscale the cropped patch 4× using bicubic
w, h = cropped_pil.size
cropped_pil_4x = cropped_pil.resize((w*4, h*4), Image.BICUBIC)
# Move the cropped tensor to device for SR
lr_tensor = lr_tensor.unsqueeze(0).to(device)
    # C) Monte Carlo Dropout: multiple stochastic forward passes
    gen2.eval()           # keep BatchNorm (and everything else) in eval mode
    enable_dropout(gen2)  # then switch dropout layers back on for sampling
    sr_passes = []
    with torch.no_grad():
        for _ in range(mc_passes):
            sr_passes.append(gen2(lr_tensor))
# Stack across passes -> (mc_passes, 1, 3, H, W)
stacked = torch.stack(sr_passes, dim=0)
# Mean & std across 'mc_passes' dimension
mean_batch = torch.mean(stacked, dim=0)
    # unbiased=False gives the population std and avoids a NaN map when mc_passes == 1
    std_batch = torch.std(stacked, dim=0, unbiased=False)
# Convert mean SR to PIL
mean_batch = mean_batch.squeeze(0).clamp_(0,1) # shape (3, H, W)
mean_pil = F.to_pil_image(mean_batch.cpu())
# D) Build a STD heatmap (collapsing across channels)
    std_map = torch.mean(std_batch, dim=1)  # shape: (1, H, W); squeezed below
s_min, s_max = std_map.min(), std_map.max()
if (s_max - s_min) < 1e-8:
std_norm = std_map.clone()
else:
std_norm = (std_map - s_min) / (s_max - s_min)
# Convert std map to a color image via matplotlib's 'jet' colormap
std_map_np = std_norm.squeeze().cpu().numpy()
colored_std = plt.cm.jet(std_map_np) # shape: (H, W, 4)
colored_std = (colored_std[..., :3] * 255).astype(np.uint8)
stdmap_pil = Image.fromarray(colored_std)
# Return the 4× upscaled crop, the mean SR output, and the STD heatmap
return cropped_pil_4x, mean_pil, stdmap_pil
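# Standalone usage sketch (outside the Gradio UI); the file names and pass
# count below are illustrative only:
#
#   img = Image.open("some_image.png").convert("RGB")
#   crop_4x, sr_mean, sr_std = mcd_superres_crop(img, mc_passes=10)
#   sr_mean.save("sr_mean.png")
#   sr_std.save("sr_uncertainty.png")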
demo = gr.Interface(
fn=mcd_superres_crop,
    inputs=[
        gr.Image(type="pil", label="Upload an image"),
        gr.Slider(minimum=1, maximum=20, value=5, step=1, label="MC Dropout Passes")
    ],
outputs=[
gr.Image(type="pil", label="1) Random Crop 4x Upscaled using bicubic interpolation"),
gr.Image(type="pil", label="2) Super-Resolved (Mean)"),
gr.Image(type="pil", label="3) STD Heatmap")
],
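    # examples=["examples/sample1.png"],  # hypothetical path: the description
    # mentions example images, so wire them up here if they ship with the repo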
title="Uncertainity Estimation for Super Resolution using ESRGAN.",
description = """
This is the demo for our paper: Uncertainity Estimation for Super Resolution using ESRGAN.
Authors: Matias Valdenegro Toro, John Smith, & Maniraj Sai.
Presented at the 2025 VISAPP Conference.
Usage: Upload an image (or use one of the examples below) and click "Submit."
"""
,
article = """
This demo showcases an enhanced ESRGAN approach for image super-resolution. First, we take a 200×200 crop from the uploaded image to reduce computational load. Then we apply a 4x upscale using our ESRGAN model, which has been modified to incorporate dropout layers. Through multiple forward passes (Monte Carlo Dropout), the demo not only produces a high-resolution output but also estimates pixel-wise uncertainty. Visualizing these uncertainties as a color-coded heatmap shows which regions of the image the model is less confident about, an important insight for understanding model performance and reliability.
@inproceedings{your_paper_2025,
  title={Uncertainty Estimation for Super Resolution using ESRGAN},
  author={Matias Valdenegro Toro and John Smith and Maniraj Sai Adapa},
  booktitle={VISAPP Conference 2025},
  year={2025}
}
"""
)
if __name__ == "__main__":
demo.launch()