import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import os
# --- Model Definition ---
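# FocalModulation: extracts hierarchical local context with gated depth-wise
# convolutions plus a global average-pooled branch, then modulates the query
# features with the aggregated context (FocalNet-style modulation).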
class FocalModulation(nn.Module):
def __init__(
self,
dim,
focal_window=9,
focal_level=2,
use_postln=False,
normalize_modulator=False,
):
super().__init__()
self.dim, self.focal_window, self.focal_level = dim, focal_window, focal_level
self.use_postln, self.normalize_modulator = use_postln, normalize_modulator
self.f = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=True)
self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, groups=1, bias=True)
self.act = nn.GELU()
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(0.1)
self.focal_layers = nn.ModuleList()
for k in range(self.focal_level):
kernel_size = self.focal_window + 2 * k * 2
self.focal_layers.append(
nn.Sequential(
nn.Conv2d(
dim,
dim,
kernel_size=kernel_size,
stride=1,
groups=dim,
padding=kernel_size // 2,
bias=False,
),
nn.GELU(),
)
)
if self.use_postln:
self.ln = nn.LayerNorm(dim)
def forward(self, x):
C = x.shape[-1]
q, ctx, gates = torch.split(self.f(x), (C, C, self.focal_level + 1), -1)
ctx = ctx.permute(0, 3, 1, 2).contiguous()
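        # q: (B, H, W, C) query features; ctx: (B, C, H, W) context features;
        # gates: (B, H, W, focal_level + 1) soft weights for the focal levels
        # plus the global branch.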
ctx_all = 0
for l in range(self.focal_level):
ctx = self.focal_layers[l](ctx)
ctx_all += ctx * gates[:, :, :, l : l + 1].permute(0, 3, 1, 2).contiguous()
ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
ctx_all += (
ctx_global
* gates[:, :, :, self.focal_level :].permute(0, 3, 1, 2).contiguous()
)
if self.normalize_modulator:
ctx_all /= self.focal_level + 1
x_out = q * self.h(ctx_all).permute(0, 2, 3, 1).contiguous()
x_out = self.proj_drop(self.proj(x_out))
return self.ln(x_out) if self.use_postln else x_out
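# CrossAttention: queries come from one focus view while keys/values come from
# the other, so each branch can attend to complementary detail in the paired image.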
class CrossAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.1, proj_drop=0.1):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.wq, self.wk, self.wv = (
nn.Linear(dim, dim, bias=qkv_bias) for _ in range(3)
)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x_near, x_far):
B, N, C = x_near.shape
q_near = (
self.wq(x_near)
.reshape(B, N, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
k_far = (
self.wk(x_far)
.reshape(B, N, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
v_far = (
self.wv(x_far)
.reshape(B, N, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
attn = (q_near @ k_far.transpose(-2, -1)) * self.scale
attn = self.attn_drop(attn.softmax(dim=-1))
x = (attn @ v_far).transpose(1, 2).reshape(B, N, C)
return self.proj_drop(self.proj(x))
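# CrossViTBlock: pre-norm cross-attention between the two views followed by an
# MLP, each wrapped in a residual connection.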
class CrossViTBlock(nn.Module):
def __init__(self, dim, num_heads=8, mlp_ratio=4.0, drop=0.1):
super().__init__()
self.norm1_near, self.norm1_far = nn.LayerNorm(dim), nn.LayerNorm(dim)
self.cross_attn = CrossAttention(dim, num_heads, attn_drop=drop, proj_drop=drop)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
nn.GELU(),
nn.Dropout(drop),
nn.Linear(int(dim * mlp_ratio), dim),
nn.Dropout(drop),
)
def forward(self, x_near, x_far):
x_fused = x_near + self.cross_attn(
self.norm1_near(x_near), self.norm1_far(x_far)
)
return x_fused + self.mlp(self.norm2(x_fused))
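# FocalTransformerBlock: pre-norm focal modulation followed by an MLP, each with
# a residual connection; operates on (B, H, W, C) feature maps.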
class FocalTransformerBlock(nn.Module):
def __init__(self, dim, focal_window=9, focal_level=2, mlp_ratio=4.0, drop=0.1):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.focal_mod = FocalModulation(dim, focal_window, focal_level)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
nn.GELU(),
nn.Dropout(drop),
nn.Linear(int(dim * mlp_ratio), dim),
nn.Dropout(drop),
)
def forward(self, x):
# The input x is expected to be of shape (B, H, W, C)
# LayerNorm is applied to the last dimension (C)
x_norm1 = self.norm1(x)
# Focal modulation expects (B, H, W, C) and returns (B, H, W, C)
x_focal = self.focal_mod(x_norm1)
x = x + x_focal
# Second LayerNorm
x_norm2 = self.norm2(x)
# MLP is applied to the last dimension (C)
x_mlp = self.mlp(x_norm2)
x = x + x_mlp
return x
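# PatchEmbed: non-overlapping Conv2d patchification followed by LayerNorm,
# producing a (B, H/patch, W/patch, embed_dim) feature grid.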
class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.grid_size = img_size // patch_size
self.num_patches = self.grid_size**2
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x):
return self.norm(self.proj(x).permute(0, 2, 3, 1))
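# FocalCrossViTHybrid: patch-embeds both inputs, fuses them with cross-attention
# blocks, refines the fused grid with focal transformer blocks, and decodes back
# to image space with a transposed-convolution decoder.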
class FocalCrossViTHybrid(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
embed_dim=768,
depth_cross=4,
depth_focal=6,
num_heads=12,
focal_window=9,
focal_level=3,
):
super().__init__()
self.patch_embed = PatchEmbed(img_size, patch_size, 3, embed_dim)
self.grid_size = img_size // patch_size
self.pos_embed = nn.Parameter(
torch.zeros(1, self.grid_size, self.grid_size, embed_dim)
)
self.cross_blocks = nn.ModuleList(
[CrossViTBlock(embed_dim, num_heads) for _ in range(depth_cross)]
)
self.focal_blocks = nn.ModuleList(
[
FocalTransformerBlock(embed_dim, focal_window, focal_level)
for _ in range(depth_focal)
]
)
self.fusion_norm = nn.LayerNorm(embed_dim)
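        # Decoder: four stride-2 transposed convolutions upsample the patch grid
        # by 16x (e.g. 14x14 -> 224x224 for 224px inputs), followed by a 3x3 conv
        # and a sigmoid to produce the RGB output.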
self.decoder = nn.Sequential(
nn.ConvTranspose2d(embed_dim, 512, 4, 2, 1),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.ConvTranspose2d(512, 256, 4, 2, 1),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.ConvTranspose2d(256, 128, 4, 2, 1),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 64, 4, 2, 1),
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.Conv2d(64, 3, 3, 1, 1),
nn.Sigmoid(),
)
def forward(self, near, far):
x_near = self.patch_embed(near) + self.pos_embed
x_far = self.patch_embed(far) + self.pos_embed
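        # Flatten the (B, H, W, C) grids into token sequences (B, N, C) for cross-attention.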
x_near_flat, x_far_flat = x_near.flatten(1, 2), x_far.flatten(1, 2)
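        # Each pass blends a stream 50/50 with its cross-attended update from the other stream.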
for block in self.cross_blocks:
x_fused = block(x_near_flat, x_far_flat)
x_near_flat = 0.5 * x_near_flat + 0.5 * x_fused
x_far_flat = 0.5 * x_far_flat + 0.5 * block(x_far_flat, x_near_flat)
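        # Average the two streams, restore the (B, H, W, C) grid, and refine with the focal blocks.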
x_fused = (x_near_flat + x_far_flat) / 2
x_fused = x_fused.view_as(x_near)
for block in self.focal_blocks:
x_fused = block(x_fused)
return self.decoder(self.fusion_norm(x_fused).permute(0, 3, 1, 2))
# --- Global Variables ---
model = None
device = None
# --- Inference Logic ---
def load_model():
"""Loads the model and caches it in a global variable."""
global model, device
if model is not None:
return model, device
try:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "best_model.pth"
model_instance = FocalCrossViTHybrid(img_size=224).to(device)
if not os.path.exists(MODEL_PATH):
raise FileNotFoundError(
"Model checkpoint 'best_model.pth' not found. Please ensure it is available in the Space."
)
checkpoint = torch.load(MODEL_PATH, map_location=device)
state_dict = checkpoint.get("model_state_dict", checkpoint)
        if any(key.startswith("module.") for key in state_dict):
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
model_instance.load_state_dict(state_dict)
model_instance.eval()
model = model_instance
return model, device
except Exception as e:
# Catch any exception during loading and show it in the UI
raise gr.Error(f"Failed to load the model: {e}")
# Image processing functions
def get_transform():
return transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
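# Invert the ImageNet normalization applied in get_transform() and clamp to [0, 1].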
def denormalize(tensor):
mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(1, 3, 1, 1)
return torch.clamp(tensor * std + mean, 0, 1)
def fuse_images(near_img, far_img):
"""Takes two PIL images and returns the fused PIL image."""
if near_img is None or far_img is None:
raise gr.Error("Please upload both a near-focus and a far-focus image.")
model, device = load_model()
transform = get_transform()
near_tensor = transform(near_img).unsqueeze(0).to(device)
far_tensor = transform(far_img).unsqueeze(0).to(device)
with torch.no_grad():
fused_tensor = model(near_tensor, far_tensor)
fused_tensor_denorm = denormalize(fused_tensor)
fused_np = fused_tensor_denorm.squeeze(0).permute(1, 2, 0).cpu().numpy()
fused_pil = Image.fromarray((fused_np * 255).astype(np.uint8))
return fused_pil
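# Example (sketch): fuse_images can also be called outside the Gradio UI; the
# file names below are hypothetical placeholders.
#   near = Image.open("near_focus.jpg").convert("RGB")
#   far = Image.open("far_focus.jpg").convert("RGB")
#   fuse_images(near, far).save("fused.jpg")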
# --- Gradio Interface ---
title = "Hybrid Transformer for Multi-Focus Image Fusion"
description = """
This demo showcases a transformer-based deep learning model that combines a Focal Transformer with a cross-attention Vision Transformer (CrossViT)
to fuse a near-focus and a far-focus image into a single, all-in-focus image.
Upload one image of each to see it in action.
"""
article = "<p style='text-align: center'><a href='https://github.com/DivitMittal/HybridTransformer-MFIF' target='_blank'>GitHub Repository</a></p>"
with gr.Blocks() as iface:
gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
gr.Markdown(description)
with gr.Row():
with gr.Column():
near_img = gr.Image(type="pil", label="Near-Focus Image")
far_img = gr.Image(type="pil", label="Far-Focus Image")
submit_btn = gr.Button("Fuse Images")
with gr.Column():
fused_img = gr.Image(type="pil", label="Fused Image")
gr.Examples(
examples=[
[
"assets/lytro-01-A.jpg",
"assets/lytro-01-B.jpg",
]
],
        inputs=[near_img, far_img],
outputs=fused_img,
fn=fuse_images,
cache_examples=False,
)
gr.Markdown(article)
submit_btn.click(fn=fuse_images, inputs=[near_img, far_img], outputs=fused_img)
if __name__ == "__main__":
iface.launch()