|
import gradio as gr |
|
import torch |
|
import torch.nn as nn |
|
from torchvision import transforms |
|
from PIL import Image |
|
import numpy as np |
|
import os |
|
|
|
|
|
|
|
|
|
|
class FocalModulation(nn.Module): |
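    """Focal modulation: gathers multi-scale context with depthwise convolutions,
    gates it per focal level, and uses it to modulate each query location.
    Expects channels-last input of shape (B, H, W, C)."""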
|
def __init__( |
|
self, |
|
dim, |
|
focal_window=9, |
|
focal_level=2, |
|
use_postln=False, |
|
normalize_modulator=False, |
|
): |
|
super().__init__() |
|
self.dim, self.focal_window, self.focal_level = dim, focal_window, focal_level |
|
self.use_postln, self.normalize_modulator = use_postln, normalize_modulator |
|
self.f = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=True) |
|
self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, groups=1, bias=True) |
|
self.act = nn.GELU() |
|
self.proj = nn.Linear(dim, dim) |
|
self.proj_drop = nn.Dropout(0.1) |
|
self.focal_layers = nn.ModuleList() |
|
for k in range(self.focal_level): |
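            # The depthwise kernel grows with each focal level, enlarging the receptive field.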
|
kernel_size = self.focal_window + 2 * k * 2 |
|
self.focal_layers.append( |
|
nn.Sequential( |
|
nn.Conv2d( |
|
dim, |
|
dim, |
|
kernel_size=kernel_size, |
|
stride=1, |
|
groups=dim, |
|
padding=kernel_size // 2, |
|
bias=False, |
|
), |
|
nn.GELU(), |
|
) |
|
) |
|
if self.use_postln: |
|
self.ln = nn.LayerNorm(dim) |
|
|
|
def forward(self, x): |
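        """x: (B, H, W, C). Project into query, context, and per-level gates, aggregate
        local and global context, then modulate the query with the gated context."""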
|
C = x.shape[-1] |
|
q, ctx, gates = torch.split(self.f(x), (C, C, self.focal_level + 1), -1) |
|
ctx = ctx.permute(0, 3, 1, 2).contiguous() |
|
ctx_all = 0 |
|
for l in range(self.focal_level): |
|
ctx = self.focal_layers[l](ctx) |
|
ctx_all += ctx * gates[:, :, :, l : l + 1].permute(0, 3, 1, 2).contiguous() |
|
ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True)) |
|
ctx_all += ( |
|
ctx_global |
|
* gates[:, :, :, self.focal_level :].permute(0, 3, 1, 2).contiguous() |
|
) |
|
if self.normalize_modulator: |
|
ctx_all /= self.focal_level + 1 |
|
x_out = q * self.h(ctx_all).permute(0, 2, 3, 1).contiguous() |
|
x_out = self.proj_drop(self.proj(x_out)) |
|
return self.ln(x_out) if self.use_postln else x_out |
|
|
|
|
|
class CrossAttention(nn.Module): |
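    """Multi-head cross-attention: queries come from one token stream, keys and values from the other."""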
|
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.1, proj_drop=0.1): |
|
super().__init__() |
|
self.num_heads = num_heads |
|
head_dim = dim // num_heads |
|
self.scale = head_dim**-0.5 |
|
self.wq, self.wk, self.wv = ( |
|
nn.Linear(dim, dim, bias=qkv_bias) for _ in range(3) |
|
) |
|
self.attn_drop = nn.Dropout(attn_drop) |
|
self.proj = nn.Linear(dim, dim) |
|
self.proj_drop = nn.Dropout(proj_drop) |
|
|
|
def forward(self, x_near, x_far): |
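        """Near-focus tokens (queries) attend to far-focus tokens (keys and values)."""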
|
B, N, C = x_near.shape |
|
q_near = ( |
|
self.wq(x_near) |
|
.reshape(B, N, self.num_heads, C // self.num_heads) |
|
.permute(0, 2, 1, 3) |
|
) |
|
k_far = ( |
|
self.wk(x_far) |
|
.reshape(B, N, self.num_heads, C // self.num_heads) |
|
.permute(0, 2, 1, 3) |
|
) |
|
v_far = ( |
|
self.wv(x_far) |
|
.reshape(B, N, self.num_heads, C // self.num_heads) |
|
.permute(0, 2, 1, 3) |
|
) |
|
attn = (q_near @ k_far.transpose(-2, -1)) * self.scale |
|
attn = self.attn_drop(attn.softmax(dim=-1)) |
|
x = (attn @ v_far).transpose(1, 2).reshape(B, N, C) |
|
return self.proj_drop(self.proj(x)) |
|
|
|
|
|
class CrossViTBlock(nn.Module): |
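    """Fuses two token streams with cross-attention followed by an MLP, both with residual connections."""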
|
def __init__(self, dim, num_heads=8, mlp_ratio=4.0, drop=0.1): |
|
super().__init__() |
|
self.norm1_near, self.norm1_far = nn.LayerNorm(dim), nn.LayerNorm(dim) |
|
self.cross_attn = CrossAttention(dim, num_heads, attn_drop=drop, proj_drop=drop) |
|
self.norm2 = nn.LayerNorm(dim) |
|
self.mlp = nn.Sequential( |
|
nn.Linear(dim, int(dim * mlp_ratio)), |
|
nn.GELU(), |
|
nn.Dropout(drop), |
|
nn.Linear(int(dim * mlp_ratio), dim), |
|
nn.Dropout(drop), |
|
) |
|
|
|
def forward(self, x_near, x_far): |
|
x_fused = x_near + self.cross_attn( |
|
self.norm1_near(x_near), self.norm1_far(x_far) |
|
) |
|
return x_fused + self.mlp(self.norm2(x_fused)) |
|
|
|
|
|
class FocalTransformerBlock(nn.Module): |
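    """Pre-norm transformer block in which focal modulation replaces self-attention."""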
|
def __init__(self, dim, focal_window=9, focal_level=2, mlp_ratio=4.0, drop=0.1): |
|
super().__init__() |
|
self.norm1 = nn.LayerNorm(dim) |
|
self.focal_mod = FocalModulation(dim, focal_window, focal_level) |
|
self.norm2 = nn.LayerNorm(dim) |
|
self.mlp = nn.Sequential( |
|
nn.Linear(dim, int(dim * mlp_ratio)), |
|
nn.GELU(), |
|
nn.Dropout(drop), |
|
nn.Linear(int(dim * mlp_ratio), dim), |
|
nn.Dropout(drop), |
|
) |
|
|
|
    def forward(self, x):

        # Focal modulation branch with a pre-norm residual connection.
        x_norm1 = self.norm1(x)

        x_focal = self.focal_mod(x_norm1)

        x = x + x_focal

        # MLP branch with a pre-norm residual connection.
        x_norm2 = self.norm2(x)

        x_mlp = self.mlp(x_norm2)

        x = x + x_mlp

        return x
|
|
|
|
|
class PatchEmbed(nn.Module): |
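    """Non-overlapping patch embedding via a strided convolution; outputs channels-last features."""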
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): |
|
super().__init__() |
|
self.grid_size = img_size // patch_size |
|
self.num_patches = self.grid_size**2 |
|
self.proj = nn.Conv2d( |
|
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size |
|
) |
|
self.norm = nn.LayerNorm(embed_dim) |
|
|
|
def forward(self, x): |
|
return self.norm(self.proj(x).permute(0, 2, 3, 1)) |
|
|
|
|
|
class FocalCrossViTHybrid(nn.Module): |
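    """Hybrid fusion network: cross-attention blocks exchange information between the
    near- and far-focus streams, focal transformer blocks refine the fused features,
    and a convolutional decoder reconstructs the all-in-focus image."""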
|
def __init__( |
|
self, |
|
img_size=224, |
|
patch_size=16, |
|
embed_dim=768, |
|
depth_cross=4, |
|
depth_focal=6, |
|
num_heads=12, |
|
focal_window=9, |
|
focal_level=3, |
|
): |
|
super().__init__() |
|
self.patch_embed = PatchEmbed(img_size, patch_size, 3, embed_dim) |
|
self.grid_size = img_size // patch_size |
|
self.pos_embed = nn.Parameter( |
|
torch.zeros(1, self.grid_size, self.grid_size, embed_dim) |
|
) |
|
self.cross_blocks = nn.ModuleList( |
|
[CrossViTBlock(embed_dim, num_heads) for _ in range(depth_cross)] |
|
) |
|
self.focal_blocks = nn.ModuleList( |
|
[ |
|
FocalTransformerBlock(embed_dim, focal_window, focal_level) |
|
for _ in range(depth_focal) |
|
] |
|
) |
|
self.fusion_norm = nn.LayerNorm(embed_dim) |
|
self.decoder = nn.Sequential( |
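            # Four 2x transposed-conv stages upsample the patch grid (14x14 for 224x224 inputs) back to full resolution.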
|
nn.ConvTranspose2d(embed_dim, 512, 4, 2, 1), |
|
nn.BatchNorm2d(512), |
|
nn.ReLU(True), |
|
nn.ConvTranspose2d(512, 256, 4, 2, 1), |
|
nn.BatchNorm2d(256), |
|
nn.ReLU(True), |
|
nn.ConvTranspose2d(256, 128, 4, 2, 1), |
|
nn.BatchNorm2d(128), |
|
nn.ReLU(True), |
|
nn.ConvTranspose2d(128, 64, 4, 2, 1), |
|
nn.BatchNorm2d(64), |
|
nn.ReLU(True), |
|
nn.Conv2d(64, 3, 3, 1, 1), |
|
nn.Sigmoid(), |
|
) |
|
|
|
def forward(self, near, far): |
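        """near, far: (B, 3, H, W) image batches; returns the fused image with values in [0, 1]."""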
|
x_near = self.patch_embed(near) + self.pos_embed |
|
x_far = self.patch_embed(far) + self.pos_embed |
|
x_near_flat, x_far_flat = x_near.flatten(1, 2), x_far.flatten(1, 2) |
|
for block in self.cross_blocks: |
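            # Cross-attend in both directions and blend each stream with its fused counterpart.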
|
x_fused = block(x_near_flat, x_far_flat) |
|
x_near_flat = 0.5 * x_near_flat + 0.5 * x_fused |
|
x_far_flat = 0.5 * x_far_flat + 0.5 * block(x_far_flat, x_near_flat) |
|
x_fused = (x_near_flat + x_far_flat) / 2 |
|
x_fused = x_fused.view_as(x_near) |
|
for block in self.focal_blocks: |
|
x_fused = block(x_fused) |
|
return self.decoder(self.fusion_norm(x_fused).permute(0, 3, 1, 2)) |
|
|
|
|
|
|
|
model = None |
|
device = None |
|
|
|
|
|
|
|
|
|
|
|
def load_model(): |
|
"""Loads the model and caches it in a global variable.""" |
|
global model, device |
|
if model is not None: |
|
return model, device |
|
|
|
try: |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
MODEL_PATH = "best_model.pth" |
|
|
|
model_instance = FocalCrossViTHybrid(img_size=224).to(device) |
|
|
|
if not os.path.exists(MODEL_PATH): |
|
raise FileNotFoundError( |
|
"Model checkpoint 'best_model.pth' not found. Please ensure it is available in the Space." |
|
) |
|
|
|
checkpoint = torch.load(MODEL_PATH, map_location=device) |
|
|
|
state_dict = checkpoint.get("model_state_dict", checkpoint) |
|
        # Strip the "module." prefix left by (Distributed)DataParallel checkpoints, if present.
        if any(key.startswith("module.") for key in state_dict):
|
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} |
|
|
|
model_instance.load_state_dict(state_dict) |
|
model_instance.eval() |
|
|
|
model = model_instance |
|
return model, device |
|
except Exception as e: |
|
|
|
raise gr.Error(f"Failed to load the model: {e}") |
|
|
|
|
|
|
|
def get_transform(): |
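    """Resize to 224x224, convert to a tensor, and normalize with ImageNet statistics."""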
|
return transforms.Compose( |
|
[ |
|
transforms.Resize((224, 224)), |
|
transforms.ToTensor(), |
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), |
|
] |
|
) |
|
|
|
|
|
def denormalize(tensor): |
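    """Undo the ImageNet normalization and clamp the result to [0, 1] for display."""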
|
mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(1, 3, 1, 1) |
|
std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(1, 3, 1, 1) |
|
return torch.clamp(tensor * std + mean, 0, 1) |
|
|
|
|
|
def fuse_images(near_img, far_img): |
|
"""Takes two PIL images and returns the fused PIL image.""" |
|
if near_img is None or far_img is None: |
|
raise gr.Error("Please upload both a near-focus and a far-focus image.") |
|
|
|
model, device = load_model() |
|
|
|
transform = get_transform() |
|
near_tensor = transform(near_img).unsqueeze(0).to(device) |
|
far_tensor = transform(far_img).unsqueeze(0).to(device) |
|
|
|
with torch.no_grad(): |
|
fused_tensor = model(near_tensor, far_tensor) |
|
|
|
fused_tensor_denorm = denormalize(fused_tensor) |
|
fused_np = fused_tensor_denorm.squeeze(0).permute(1, 2, 0).cpu().numpy() |
|
fused_pil = Image.fromarray((fused_np * 255).astype(np.uint8)) |
|
|
|
return fused_pil |
|
|
|
|
|
|
|
|
|
title = "Hybrid Transformer for Multi-Focus Image Fusion" |
|
description = """ |
|
This demo showcases a transformer-based deep learning model that combines a Focal Transformer with a Cross-View Vision Transformer (CrossViT)
|
to fuse a near-focus and a far-focus image into a single, all-in-focus image. |
|
Upload one of each to see it in action. |
|
""" |
|
article = "<p style='text-align: center'><a href='https://github.com/DivitMittal/HybridTransformer-MFIF' target='_blank'>GitHub Repository</a></p>" |
|
|
|
with gr.Blocks() as iface: |
|
gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>") |
|
gr.Markdown(description) |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
near_img = gr.Image(type="pil", label="Near-Focus Image") |
|
far_img = gr.Image(type="pil", label="Far-Focus Image") |
|
submit_btn = gr.Button("Fuse Images") |
|
with gr.Column(): |
|
fused_img = gr.Image(type="pil", label="Fused Image") |
|
|
|
gr.Examples( |
|
examples=[ |
|
[ |
|
"assets/lytro-01-A.jpg", |
|
"assets/lytro-01-B.jpg", |
|
] |
|
], |
|
        inputs=[near_img, far_img],
|
outputs=fused_img, |
|
fn=fuse_images, |
|
cache_examples=False, |
|
) |
|
|
|
gr.Markdown(article) |
|
|
|
submit_btn.click(fn=fuse_images, inputs=[near_img, far_img], outputs=fused_img) |
|
|
|
|
|
if __name__ == "__main__": |
|
iface.launch() |
|
|