import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
from huggingface_hub import hf_hub_download


# --- Model Definition ---
class FocalModulation(nn.Module):
    """Focal modulation: aggregates hierarchical local context with depthwise
    convolutions and modulates the query features with the gated context."""

    def __init__(
        self,
        dim,
        focal_window=9,
        focal_level=2,
        use_postln=False,
        normalize_modulator=False,
    ):
        super().__init__()
        self.dim, self.focal_window, self.focal_level = dim, focal_window, focal_level
        self.use_postln, self.normalize_modulator = use_postln, normalize_modulator
        # Single linear projection producing query, context, and per-level gates.
        self.f = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=True)
        self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, groups=1, bias=True)
        self.act = nn.GELU()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(0.1)
        self.focal_layers = nn.ModuleList()
        for k in range(self.focal_level):
            kernel_size = self.focal_window + 2 * k * 2
            self.focal_layers.append(
                nn.Sequential(
                    nn.Conv2d(
                        dim,
                        dim,
                        kernel_size=kernel_size,
                        stride=1,
                        groups=dim,
                        padding=kernel_size // 2,
                        bias=False,
                    ),
                    nn.GELU(),
                )
            )
        if self.use_postln:
            self.ln = nn.LayerNorm(dim)

    def forward(self, x):
        # x: (B, H, W, C)
        C = x.shape[-1]
        q, ctx, gates = torch.split(self.f(x), (C, C, self.focal_level + 1), -1)
        ctx = ctx.permute(0, 3, 1, 2).contiguous()
        ctx_all = 0
        for l in range(self.focal_level):
            ctx = self.focal_layers[l](ctx)
            ctx_all += ctx * gates[:, :, :, l : l + 1].permute(0, 3, 1, 2).contiguous()
        # Global context from average pooling, weighted by the last gate.
        ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
        ctx_all += (
            ctx_global
            * gates[:, :, :, self.focal_level :].permute(0, 3, 1, 2).contiguous()
        )
        if self.normalize_modulator:
            ctx_all /= self.focal_level + 1
        x_out = q * self.h(ctx_all).permute(0, 2, 3, 1).contiguous()
        x_out = self.proj_drop(self.proj(x_out))
        return self.ln(x_out) if self.use_postln else x_out


class CrossAttention(nn.Module):
    """Cross-attention: queries come from the near-focus stream, keys and
    values from the far-focus stream."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.1, proj_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.wq, self.wk, self.wv = (
            nn.Linear(dim, dim, bias=qkv_bias) for _ in range(3)
        )
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x_near, x_far):
        B, N, C = x_near.shape
        q_near = (
            self.wq(x_near)
            .reshape(B, N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        k_far = (
            self.wk(x_far)
            .reshape(B, N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        v_far = (
            self.wv(x_far)
            .reshape(B, N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        attn = (q_near @ k_far.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(attn.softmax(dim=-1))
        x = (attn @ v_far).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(x))


class CrossViTBlock(nn.Module):
    """Pre-norm transformer block that fuses the two views via cross-attention."""

    def __init__(self, dim, num_heads=8, mlp_ratio=4.0, drop=0.1):
        super().__init__()
        self.norm1_near, self.norm1_far = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.cross_attn = CrossAttention(dim, num_heads, attn_drop=drop, proj_drop=drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(drop),
        )

    def forward(self, x_near, x_far):
        x_fused = x_near + self.cross_attn(
            self.norm1_near(x_near), self.norm1_far(x_far)
        )
        return x_fused + self.mlp(self.norm2(x_fused))


class FocalTransformerBlock(nn.Module):
    """Pre-norm transformer block built around focal modulation."""

    def __init__(self, dim, focal_window=9, focal_level=2, mlp_ratio=4.0, drop=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.focal_mod = FocalModulation(dim, focal_window, focal_level)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        # The input x is expected to be of shape (B, H, W, C);
        # LayerNorm is applied to the last dimension (C).
        x_norm1 = self.norm1(x)
        # Focal modulation expects (B, H, W, C) and returns (B, H, W, C).
        x_focal = self.focal_mod(x_norm1)
        x = x + x_focal
        # Second LayerNorm, then the MLP applied to the last dimension (C).
        x_norm2 = self.norm2(x)
        x_mlp = self.mlp(x_norm2)
        x = x + x_mlp
        return x


class PatchEmbed(nn.Module):
    """Splits an image into non-overlapping patches and projects them to embed_dim."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size**2
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # (B, 3, H, W) -> (B, H/patch, W/patch, embed_dim)
        return self.norm(self.proj(x).permute(0, 2, 3, 1))


class FocalCrossViTHybrid(nn.Module):
    """Hybrid fusion network: CrossViT blocks exchange information between the
    near- and far-focus streams, focal transformer blocks refine the fused
    representation, and a convolutional decoder reconstructs the all-in-focus image."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        embed_dim=768,
        depth_cross=4,
        depth_focal=6,
        num_heads=12,
        focal_window=9,
        focal_level=3,
    ):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, 3, embed_dim)
        self.grid_size = img_size // patch_size
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.grid_size, self.grid_size, embed_dim)
        )
        self.cross_blocks = nn.ModuleList(
            [CrossViTBlock(embed_dim, num_heads) for _ in range(depth_cross)]
        )
        self.focal_blocks = nn.ModuleList(
            [
                FocalTransformerBlock(embed_dim, focal_window, focal_level)
                for _ in range(depth_focal)
            ]
        )
        self.fusion_norm = nn.LayerNorm(embed_dim)
        # Decoder upsamples the 14x14 token grid back to a 224x224 RGB image.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, 3, 3, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, near, far):
        x_near = self.patch_embed(near) + self.pos_embed
        x_far = self.patch_embed(far) + self.pos_embed
        x_near_flat, x_far_flat = x_near.flatten(1, 2), x_far.flatten(1, 2)
        for block in self.cross_blocks:
            x_fused = block(x_near_flat, x_far_flat)
            x_near_flat = 0.5 * x_near_flat + 0.5 * x_fused
            x_far_flat = 0.5 * x_far_flat + 0.5 * block(x_far_flat, x_near_flat)
        x_fused = (x_near_flat + x_far_flat) / 2
        x_fused = x_fused.view_as(x_near)
        for block in self.focal_blocks:
            x_fused = block(x_fused)
        return self.decoder(self.fusion_norm(x_fused).permute(0, 3, 1, 2))


# --- Global Variables ---
model = None
device = None


# --- Inference Logic ---
def load_model():
    """Loads the model from HuggingFace Hub and caches it in a global variable."""
    global model, device
    if model is not None:
        return model, device
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Download the checkpoint from the HuggingFace Hub.
        model_path = hf_hub_download(
            repo_id="divitmittal/HybridTransformer-MFIF",
            filename="best_model.pth",
            cache_dir="./model_cache",
        )
        model_instance = FocalCrossViTHybrid(img_size=224).to(device)
        checkpoint = torch.load(model_path, map_location=device)
        state_dict = checkpoint.get("model_state_dict", checkpoint)
        # Strip the "module." prefix left over from DataParallel training, if present.
        if any(key.startswith("module.") for key in state_dict.keys()):
            state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        model_instance.load_state_dict(state_dict)
        model_instance.eval()
        model = model_instance
        return model, device
    except Exception as e:
        # Surface any loading failure in the UI instead of crashing the app.
        raise gr.Error(f"Failed to load the model from HuggingFace Hub: {e}")


# --- Image Processing ---
def get_transform():
    """Resize to 224x224 and normalize with ImageNet statistics."""
    return transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )


def denormalize(tensor):
    """Reverses the ImageNet normalization and clamps the result to [0, 1]."""
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(1, 3, 1, 1)
    return torch.clamp(tensor * std + mean, 0, 1)


def fuse_images(near_img, far_img):
    """Takes two PIL images and returns the fused PIL image."""
    if near_img is None or far_img is None:
        raise gr.Error("Please upload both a near-focus and a far-focus image.")
    model, device = load_model()
    transform = get_transform()
    # Ensure 3-channel RGB input; uploads may be RGBA or grayscale.
    near_tensor = transform(near_img.convert("RGB")).unsqueeze(0).to(device)
    far_tensor = transform(far_img.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        fused_tensor = model(near_tensor, far_tensor)
    fused_tensor_denorm = denormalize(fused_tensor)
    fused_np = fused_tensor_denorm.squeeze(0).permute(1, 2, 0).cpu().numpy()
    fused_pil = Image.fromarray((fused_np * 255).astype(np.uint8))
    return fused_pil


# --- Gradio Interface ---
title = "Hybrid Transformer for Multi-Focus Image Fusion"
description = """
This demo showcases a transformer-based deep learning model that combines a
Focal Transformer and a Cross-view Vision Transformer (CrossViT) to fuse a
near-focus and a far-focus image into a single, all-in-focus image.
Upload one of each to see it in action.
"""
article = "GitHub Repository"
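# A minimal shape sanity check (illustrative sketch only, not executed by the
# app): the hybrid model maps two (B, 3, 224, 224) inputs to a fused
# (B, 3, 224, 224) output, since the 14x14 token grid is upsampled 16x by the
# decoder.
#
#   _model = FocalCrossViTHybrid(img_size=224)
#   _near, _far = torch.randn(1, 3, 224, 224), torch.randn(1, 3, 224, 224)
#   assert _model(_near, _far).shape == (1, 3, 224, 224)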
" with gr.Blocks() as iface: gr.Markdown(f"

{title}

") gr.Markdown(description) with gr.Row(): with gr.Column(): near_img = gr.Image(type="pil", label="Near-Focus Image") far_img = gr.Image(type="pil", label="Far-Focus Image") submit_btn = gr.Button("Fuse Images") with gr.Column(): fused_img = gr.Image(type="pil", label="Fused Image") gr.Examples( examples=[ [ "assets/lytro-01-A.jpg", "assets/lytro-01-B.jpg", ] ], inputs=[near_img, far_img, fused_img], outputs=fused_img, fn=fuse_images, cache_examples=False, ) gr.Markdown(article) submit_btn.click(fn=fuse_images, inputs=[near_img, far_img], outputs=fused_img) if __name__ == "__main__": iface.launch()