import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
from huggingface_hub import hf_hub_download
# --- Model Definition ---
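# Architecture overview (as implemented below): both input views are patch-embedded,
# fused by a stack of cross-attention blocks (each view's tokens query the other view),
# refined by focal-modulation transformer blocks, and decoded back to a full-resolution
# RGB image with transposed convolutions.
#
# FocalModulation (next class) replaces self-attention with gated multi-scale context:
# depthwise convolutions of growing kernel size plus a globally pooled context are
# gated, summed, and used to modulate a per-token query.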
class FocalModulation(nn.Module):
def __init__(
self,
dim,
focal_window=9,
focal_level=2,
use_postln=False,
normalize_modulator=False,
):
super().__init__()
self.dim, self.focal_window, self.focal_level = dim, focal_window, focal_level
self.use_postln, self.normalize_modulator = use_postln, normalize_modulator
self.f = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=True)
self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, groups=1, bias=True)
self.act = nn.GELU()
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(0.1)
self.focal_layers = nn.ModuleList()
for k in range(self.focal_level):
            kernel_size = self.focal_window + 2 * k * 2  # grows by 4 per focal level; must match the trained checkpoint's conv shapes
self.focal_layers.append(
nn.Sequential(
nn.Conv2d(
dim,
dim,
kernel_size=kernel_size,
stride=1,
groups=dim,
padding=kernel_size // 2,
bias=False,
),
nn.GELU(),
)
)
if self.use_postln:
self.ln = nn.LayerNorm(dim)
def forward(self, x):
C = x.shape[-1]
q, ctx, gates = torch.split(self.f(x), (C, C, self.focal_level + 1), -1)
ctx = ctx.permute(0, 3, 1, 2).contiguous()
ctx_all = 0
for l in range(self.focal_level):
ctx = self.focal_layers[l](ctx)
ctx_all += ctx * gates[:, :, :, l : l + 1].permute(0, 3, 1, 2).contiguous()
ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
ctx_all += (
ctx_global
* gates[:, :, :, self.focal_level :].permute(0, 3, 1, 2).contiguous()
)
if self.normalize_modulator:
ctx_all /= self.focal_level + 1
x_out = q * self.h(ctx_all).permute(0, 2, 3, 1).contiguous()
x_out = self.proj_drop(self.proj(x_out))
return self.ln(x_out) if self.use_postln else x_out
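# Cross-attention between the two focus views: queries come from one view (x_near),
# keys and values from the other (x_far), so each near-focus token aggregates
# information from the far-focus image.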
class CrossAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.1, proj_drop=0.1):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.wq, self.wk, self.wv = (
nn.Linear(dim, dim, bias=qkv_bias) for _ in range(3)
)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x_near, x_far):
B, N, C = x_near.shape
q_near = (
self.wq(x_near)
.reshape(B, N, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
k_far = (
self.wk(x_far)
.reshape(B, N, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
v_far = (
self.wv(x_far)
.reshape(B, N, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
attn = (q_near @ k_far.transpose(-2, -1)) * self.scale
attn = self.attn_drop(attn.softmax(dim=-1))
x = (attn @ v_far).transpose(1, 2).reshape(B, N, C)
return self.proj_drop(self.proj(x))
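# A pre-norm transformer block built around CrossAttention: the near-view tokens are
# fused with the far-view tokens via cross-attention, then passed through an MLP,
# with residual connections around both sub-layers.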
class CrossViTBlock(nn.Module):
def __init__(self, dim, num_heads=8, mlp_ratio=4.0, drop=0.1):
super().__init__()
self.norm1_near, self.norm1_far = nn.LayerNorm(dim), nn.LayerNorm(dim)
self.cross_attn = CrossAttention(dim, num_heads, attn_drop=drop, proj_drop=drop)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
nn.GELU(),
nn.Dropout(drop),
nn.Linear(int(dim * mlp_ratio), dim),
nn.Dropout(drop),
)
def forward(self, x_near, x_far):
x_fused = x_near + self.cross_attn(
self.norm1_near(x_near), self.norm1_far(x_far)
)
return x_fused + self.mlp(self.norm2(x_fused))
class FocalTransformerBlock(nn.Module):
def __init__(self, dim, focal_window=9, focal_level=2, mlp_ratio=4.0, drop=0.1):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.focal_mod = FocalModulation(dim, focal_window, focal_level)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
nn.GELU(),
nn.Dropout(drop),
nn.Linear(int(dim * mlp_ratio), dim),
nn.Dropout(drop),
)
def forward(self, x):
# The input x is expected to be of shape (B, H, W, C)
# LayerNorm is applied to the last dimension (C)
x_norm1 = self.norm1(x)
# Focal modulation expects (B, H, W, C) and returns (B, H, W, C)
x_focal = self.focal_mod(x_norm1)
x = x + x_focal
# Second LayerNorm
x_norm2 = self.norm2(x)
# MLP is applied to the last dimension (C)
x_mlp = self.mlp(x_norm2)
x = x + x_mlp
return x
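# Non-overlapping patch embedding: a strided convolution turns (B, 3, H, W) into a
# channels-last token grid (B, H/patch, W/patch, embed_dim), followed by LayerNorm.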
class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.grid_size = img_size // patch_size
self.num_patches = self.grid_size**2
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x):
return self.norm(self.proj(x).permute(0, 2, 3, 1))
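# Full fusion model: PatchEmbed plus a learned positional embedding for each view,
# depth_cross CrossViTBlocks operating on flattened tokens, depth_focal
# FocalTransformerBlocks on the fused token grid, and a convolutional decoder that
# upsamples 16x (four stride-2 transposed convolutions) back to the input resolution.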
class FocalCrossViTHybrid(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
embed_dim=768,
depth_cross=4,
depth_focal=6,
num_heads=12,
focal_window=9,
focal_level=3,
):
super().__init__()
self.patch_embed = PatchEmbed(img_size, patch_size, 3, embed_dim)
self.grid_size = img_size // patch_size
self.pos_embed = nn.Parameter(
torch.zeros(1, self.grid_size, self.grid_size, embed_dim)
)
self.cross_blocks = nn.ModuleList(
[CrossViTBlock(embed_dim, num_heads) for _ in range(depth_cross)]
)
self.focal_blocks = nn.ModuleList(
[
FocalTransformerBlock(embed_dim, focal_window, focal_level)
for _ in range(depth_focal)
]
)
self.fusion_norm = nn.LayerNorm(embed_dim)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(embed_dim, 512, 4, 2, 1),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.ConvTranspose2d(512, 256, 4, 2, 1),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.ConvTranspose2d(256, 128, 4, 2, 1),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 64, 4, 2, 1),
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.Conv2d(64, 3, 3, 1, 1),
nn.Sigmoid(),
)
def forward(self, near, far):
x_near = self.patch_embed(near) + self.pos_embed
x_far = self.patch_embed(far) + self.pos_embed
x_near_flat, x_far_flat = x_near.flatten(1, 2), x_far.flatten(1, 2)
for block in self.cross_blocks:
x_fused = block(x_near_flat, x_far_flat)
x_near_flat = 0.5 * x_near_flat + 0.5 * x_fused
x_far_flat = 0.5 * x_far_flat + 0.5 * block(x_far_flat, x_near_flat)
x_fused = (x_near_flat + x_far_flat) / 2
x_fused = x_fused.view_as(x_near)
for block in self.focal_blocks:
x_fused = block(x_fused)
return self.decoder(self.fusion_norm(x_fused).permute(0, 3, 1, 2))
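# Shape sanity check (a sketch, not executed on import): with the default img_size=224
# and patch_size=16 the token grid is 14x14, and the decoder's four stride-2 stages
# upsample 14 -> 28 -> 56 -> 112 -> 224, so the output matches the input resolution.
#
#   net = FocalCrossViTHybrid(img_size=224)
#   near = torch.randn(1, 3, 224, 224)
#   far = torch.randn(1, 3, 224, 224)
#   assert net(near, far).shape == (1, 3, 224, 224)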
# --- Global Variables ---
model = None
device = None
# --- Inference Logic ---
def load_model():
"""Loads the model from HuggingFace Hub and caches it in a global variable."""
global model, device
if model is not None:
return model, device
try:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Download model from HuggingFace Hub
model_path = hf_hub_download(
repo_id="divitmittal/HybridTransformer-MFIF",
filename="best_model.pth",
cache_dir="./model_cache"
)
model_instance = FocalCrossViTHybrid(img_size=224).to(device)
checkpoint = torch.load(model_path, map_location=device)
state_dict = checkpoint.get("model_state_dict", checkpoint)
if any(key.startswith("module.") for key in state_dict.keys()):
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
model_instance.load_state_dict(state_dict)
model_instance.eval()
model = model_instance
return model, device
except Exception as e:
# Catch any exception during loading and show it in the UI
raise gr.Error(f"Failed to load the model from HuggingFace Hub: {e}")
# --- Image Processing ---
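# Inputs are resized to the model's 224x224 resolution and normalized with the
# standard ImageNet mean/std; denormalize() inverts that normalization (and clamps to
# [0, 1]) so the network output can be rendered as an ordinary RGB image.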
def get_transform():
return transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
def denormalize(tensor):
mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(1, 3, 1, 1)
return torch.clamp(tensor * std + mean, 0, 1)
def fuse_images(near_img, far_img):
"""Takes two PIL images and returns the fused PIL image."""
if near_img is None or far_img is None:
raise gr.Error("Please upload both a near-focus and a far-focus image.")
model, device = load_model()
transform = get_transform()
near_tensor = transform(near_img).unsqueeze(0).to(device)
far_tensor = transform(far_img).unsqueeze(0).to(device)
with torch.no_grad():
fused_tensor = model(near_tensor, far_tensor)
fused_tensor_denorm = denormalize(fused_tensor)
fused_np = fused_tensor_denorm.squeeze(0).permute(1, 2, 0).cpu().numpy()
fused_pil = Image.fromarray((fused_np * 255).astype(np.uint8))
return fused_pil
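# Programmatic usage (a sketch, not executed on import): the same fuse_images() the UI
# calls can be used directly, for example with the bundled example pair:
#
#   near = Image.open("assets/lytro-01-A.jpg").convert("RGB")
#   far = Image.open("assets/lytro-01-B.jpg").convert("RGB")
#   fuse_images(near, far).save("fused.png")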
# --- Gradio Interface ---
title = "Hybrid Transformer for Multi-Focus Image Fusion"
description = """
This demo showcases a hybrid transformer model that combines a Focal Transformer and a Cross-view Vision Transformer (CrossViT)
to fuse a near-focus and a far-focus image into a single, all-in-focus image.
Upload one of each to see it in action.
"""
article = "<p style='text-align: center'><a href='https://github.com/DivitMittal/HybridTransformer-MFIF' target='_blank'>GitHub Repository</a></p>"
with gr.Blocks() as iface:
gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
gr.Markdown(description)
with gr.Row():
with gr.Column():
near_img = gr.Image(type="pil", label="Near-Focus Image")
far_img = gr.Image(type="pil", label="Far-Focus Image")
submit_btn = gr.Button("Fuse Images")
with gr.Column():
fused_img = gr.Image(type="pil", label="Fused Image")
gr.Examples(
examples=[
[
"assets/lytro-01-A.jpg",
"assets/lytro-01-B.jpg",
]
],
        inputs=[near_img, far_img],
outputs=fused_img,
fn=fuse_images,
cache_examples=False,
)
gr.Markdown(article)
submit_btn.click(fn=fuse_images, inputs=[near_img, far_img], outputs=fused_img)
if __name__ == "__main__":
iface.launch()