"""
This script defines the MIPHEI-ViT architecture for image-to-image translation.
Some modules in this file are adapted from: https://github.com/hustvl/ViTMatte/
"""
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from timm.models import VisionTransformer, SwinTransformer
from timm.models import load_state_dict_from_hf


class Basic_Conv3x3(nn.Module):
    """
    Basic convolution block: a 3x3 Conv2d followed by BatchNorm2d and ReLU.
    https://github.com/hustvl/ViTMatte/blob/main/modeling/decoder/detail_capture.py#L5
    """
    def __init__(
        self,
        in_chans,
        out_chans,
        stride=2,
        padding=1,
    ):
        super().__init__()
        self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_chans)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)

        return x


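# Illustrative sketch (not executed at import time): with the default stride of 2,
# Basic_Conv3x3 halves the spatial resolution; the 64x64 input below is an
# arbitrary example size.
# >>> blk = Basic_Conv3x3(in_chans=3, out_chans=48)
# >>> blk(torch.randn(1, 3, 64, 64)).shape
# torch.Size([1, 48, 32, 32])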
class ConvStream(nn.Module):
    """
    Simple ConvStream containing a series of basic conv3x3 layers to extract detail features.
    """
    def __init__(
        self,
        in_chans=4,
        out_chans=[48, 96, 192],
    ):
        super().__init__()
        self.convs = nn.ModuleList()

        self.conv_chans = out_chans.copy()
        self.conv_chans.insert(0, in_chans)

        for i in range(len(self.conv_chans) - 1):
            in_chan_ = self.conv_chans[i]
            out_chan_ = self.conv_chans[i + 1]
            self.convs.append(
                Basic_Conv3x3(in_chan_, out_chan_)
            )

    def forward(self, x):
        out_dict = {'D0': x}
        for i in range(len(self.convs)):
            x = self.convs[i](x)
            name_ = 'D' + str(i + 1)
            out_dict[name_] = x

        return out_dict


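# Illustrative sketch (not executed at import time): ConvStream keeps the raw input
# as 'D0' and stacks stride-2 Basic_Conv3x3 layers, so each level 'D1'..'D3' halves
# the resolution. The 3-channel 256x256 input is an arbitrary example.
# >>> stream = ConvStream(in_chans=3)
# >>> feats = stream(torch.randn(1, 3, 256, 256))
# >>> [(k, tuple(v.shape)) for k, v in feats.items()]
# [('D0', (1, 3, 256, 256)), ('D1', (1, 48, 128, 128)),
#  ('D2', (1, 96, 64, 64)), ('D3', (1, 192, 32, 32))]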
class SegmentationHead(nn.Sequential):

    def __init__(
        self, in_channels, out_channels, kernel_size=3, activation=None, use_attention=False,
    ):
        if use_attention:
            attention = AttentionBlock(in_channels)
        else:
            attention = nn.Identity()
        conv2d = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
        )
        # Fall back to Identity so nn.Sequential never receives None as a module.
        activation = activation if activation is not None else nn.Identity()
        super().__init__(attention, conv2d, activation)


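# Illustrative sketch (not executed at import time): a SegmentationHead maps the
# fused decoder features to a single output channel, optionally gated by an
# AttentionBlock; the shapes below are arbitrary example values.
# >>> head = SegmentationHead(in_channels=32, out_channels=1, activation=nn.Tanh(), use_attention=True)
# >>> head(torch.randn(1, 32, 256, 256)).shape
# torch.Size([1, 1, 256, 256])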
class AttentionBlock(nn.Module):
    """
    Attention gate.

    Parameters:
    -----------
    in_chns : int
        Number of input channels.

    Forward Input:
    --------------
    x : torch.Tensor
        Input tensor of shape [B, C, H, W].

    Returns:
    --------
    torch.Tensor
        Reweighted tensor of the same shape as input.
    """
    def __init__(self, in_chns):
        super(AttentionBlock, self).__init__()

        self.psi = nn.Sequential(
            nn.Conv2d(in_chns, in_chns // 2, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(in_chns // 2),
            nn.ReLU(),
            nn.Conv2d(in_chns // 2, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        g = self.psi(x)
        return x * g


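# Illustrative sketch (not executed at import time): the gate computes a
# single-channel sigmoid map and reweights the input, preserving its shape;
# the sizes below are arbitrary example values.
# >>> gate = AttentionBlock(in_chns=32)
# >>> gate(torch.randn(2, 32, 64, 64)).shape
# torch.Size([2, 32, 64, 64])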
class Fusion_Block(nn.Module):
    """
    Simple fusion block to fuse features from the ConvStream and the plain Vision Transformer.
    """
    def __init__(
        self,
        in_chans,
        out_chans,
    ):
        super().__init__()
        self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1)

    def forward(self, x, D):
        F_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        out = torch.cat([D, F_up], dim=1)
        out = self.conv(out)

        return out


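# Illustrative sketch (not executed at import time): a Fusion_Block upsamples the
# transformer features by 2x, concatenates the matching ConvStream detail map, and
# fuses them with a stride-1 Basic_Conv3x3. The channel counts below assume a
# 1536-dim ViT feature map and the 'D3' detail level; they are example values only.
# >>> fuse = Fusion_Block(in_chans=1536 + 192, out_chans=256)
# >>> vit_feat = torch.randn(1, 1536, 16, 16)
# >>> detail = torch.randn(1, 192, 32, 32)
# >>> fuse(vit_feat, detail).shape
# torch.Size([1, 256, 32, 32])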
class MIPHEIViT(nn.Module):
    """
    U-Net-style architecture inspired by ViTMatte, using a Vision Transformer (ViT or Swin)
    as encoder and a convolutional decoder. Designed for dense image prediction tasks,
    such as image-to-image translation.

    Parameters:
    -----------
    encoder : nn.Module
        A ViT- or Swin-based encoder that outputs spatial feature maps.
    decoder : nn.Module
        A decoder module that maps encoder features (and optionally the original image)
        to the output prediction.

    Example:
    --------
    encoder = Encoder(vit)
    model = MIPHEIViT(encoder=encoder, decoder=Detail_Capture(emb_chans=encoder.embed_dim))
    output = model(input_tensor)
    """
    def __init__(self,
                 encoder,
                 decoder,
                 ):
        super(MIPHEIViT, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.initialize()

    def forward(self, x):
        features = self.encoder(x)
        outputs = self.decoder(features, x)
        return outputs

    def initialize(self):
        pass

    @classmethod
    def from_pretrained_hf(cls, repo_path=None, repo_id=None):
        from safetensors.torch import load_file
        import json
        if repo_path:
            weights_path = os.path.join(repo_path, "model.safetensors")
            config_path = os.path.join(repo_path, "config_hf.json")
        else:
            from huggingface_hub import hf_hub_download
            weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
            config_path = hf_hub_download(repo_id=repo_id, filename="config_hf.json")

        with open(config_path, "r") as f:
            config = json.load(f)

        img_size = config["img_size"]
        nc_out = len(config["targ_channel_names"])
        use_attention = config["use_attention"]
        hoptimus_hf_id = config["hoptimus_hf_id"]

        vit = get_hoptimus0_hf(hoptimus_hf_id)
        vit.set_input_size(img_size=(img_size, img_size))
        encoder = Encoder(vit)
        decoder = Detail_Capture(emb_chans=encoder.embed_dim, out_chans=nc_out, use_attention=use_attention, activation=nn.Tanh())
        model = cls(encoder=encoder, decoder=decoder)
        state_dict = load_file(weights_path)
        state_dict = merge_lora_weights(model, state_dict)
        load_info = model.load_state_dict(state_dict, strict=False)
        validate_load_info(load_info)
        model.eval()
        return model

    def set_input_size(self, img_size):
        if any((s & (s - 1)) != 0 or s == 0 for s in img_size):
            raise ValueError("Both height and width in img_size must be powers of 2")
        if any(s < 128 for s in img_size):
            raise ValueError("Height and width must be greater than or equal to 128")
        self.encoder.vit.set_input_size(img_size=img_size)
        self.encoder.grid_size = self.encoder.vit.patch_embed.grid_size


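# Usage sketch (illustrative, not executed at import time): load pretrained weights
# from the Hugging Face Hub and run inference. The repo id below is a placeholder,
# and the input spatial size must match the img_size stored in the checkpoint's
# config_hf.json; the size can later be changed with model.set_input_size((h, w))
# (powers of 2, >= 128).
# >>> model = MIPHEIViT.from_pretrained_hf(repo_id="<org>/<miphei-vit-repo>")
# >>> size = model.encoder.vit.patch_embed.img_size
# >>> with torch.inference_mode():
# ...     pred = model(torch.randn(1, 3, *size))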
class Encoder(nn.Module):
    """
    Wraps a Vision Transformer (ViT or Swin) to produce feature maps compatible
    with U-Net-like architectures. It reshapes and resizes transformer outputs
    into spatial feature maps.

    Parameters:
    -----------
    vit : VisionTransformer or SwinTransformer
        A pretrained transformer model from `timm` that outputs patch embeddings.
    """
    def __init__(self, vit):
        super().__init__()
        if not isinstance(vit, (VisionTransformer, SwinTransformer)):
            raise ValueError(f"Expected a VisionTransformer or SwinTransformer, got {type(vit)}")
        self.vit = vit

        self.is_swint = isinstance(vit, SwinTransformer)
        self.grid_size = self.vit.patch_embed.grid_size
        if self.is_swint:
            self.num_prefix_tokens = 0
            self.embed_dim = self.vit.embed_dim * 2 ** (self.vit.num_layers - 1)
        else:
            self.num_prefix_tokens = self.vit.num_prefix_tokens
            self.embed_dim = self.vit.embed_dim
        patch_size = self.vit.patch_embed.patch_size
        img_size = self.vit.patch_embed.img_size
        assert img_size[0] % 16 == 0
        assert img_size[1] % 16 == 0

        if self.is_swint:
            self.scale_factor = (2., 2.)
        else:
            if patch_size != (16, 16):
                target_grid_size = (img_size[0] / 16, img_size[1] / 16)
                self.scale_factor = (target_grid_size[0] / self.grid_size[0], target_grid_size[1] / self.grid_size[1])
            else:
                self.scale_factor = None

    def forward(self, x):
        features = self.vit(x)
        if self.is_swint:
            features = features.permute(0, 3, 1, 2)
        else:
            features = features[:, self.num_prefix_tokens:]
            features = features.permute(0, 2, 1)
            features = features.view((-1, self.embed_dim, *self.grid_size))
        if self.scale_factor is not None:
            features = F.interpolate(features, scale_factor=self.scale_factor, mode="bicubic")
        return features


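# Illustrative sketch (not executed at import time): wrapping a small timm ViT.
# With a 16x16 patch size the patch tokens are reshaped directly into a stride-16
# feature map. The model name below is only an example; the project itself loads
# H-optimus-0 via get_hoptimus0_hf.
# >>> vit = timm.create_model("vit_small_patch16_224", pretrained=False,
# ...                         num_classes=0, global_pool="")
# >>> enc = Encoder(vit)
# >>> enc(torch.randn(1, 3, 224, 224)).shape
# torch.Size([1, 384, 14, 14])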
class Detail_Capture(nn.Module):
    """
    Simple and lightweight detail capture module for ViT matting, used here as the decoder.
    """
    def __init__(
        self,
        emb_chans,
        in_chans=3,
        out_chans=1,
        convstream_out=[48, 96, 192],
        fusion_out=[256, 128, 64, 32],
        use_attention=True,
        activation=torch.nn.Identity()
    ):
        super().__init__()
        assert len(fusion_out) == len(convstream_out) + 1

        self.convstream = ConvStream(in_chans=in_chans, out_chans=convstream_out)
        self.conv_chans = self.convstream.conv_chans
        self.num_heads = out_chans

        self.fusion_blks = nn.ModuleList()
        self.fus_channs = fusion_out.copy()
        self.fus_channs.insert(0, emb_chans)
        for i in range(len(self.fus_channs) - 1):
            self.fusion_blks.append(
                Fusion_Block(
                    in_chans=self.fus_channs[i] + self.conv_chans[-(i + 1)],
                    out_chans=self.fus_channs[i + 1],
                )
            )

        for idx in range(self.num_heads):
            setattr(self, f'segmentation_head_{idx}', SegmentationHead(
                in_channels=fusion_out[-1],
                out_channels=1,
                activation=activation,
                kernel_size=3,
                use_attention=use_attention
            ))

    def forward(self, features, images):
        detail_features = self.convstream(images)
        for i in range(len(self.fusion_blks)):
            d_name_ = 'D' + str(len(self.fusion_blks) - i - 1)
            features = self.fusion_blks[i](features, detail_features[d_name_])

        outputs = []
        for idx_head in range(self.num_heads):
            segmentation_head = getattr(self, f'segmentation_head_{idx_head}')
            output = segmentation_head(features)
            outputs.append(output)
        outputs = torch.cat(outputs, dim=1)

        return outputs


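# Illustrative sketch (not executed at import time): the decoder consumes a
# stride-16 ViT feature map plus the original image and returns one channel per
# head. The 1536-dim features and 8 output channels are example values.
# >>> dec = Detail_Capture(emb_chans=1536, in_chans=3, out_chans=8,
# ...                      use_attention=True, activation=nn.Tanh())
# >>> vit_feat = torch.randn(1, 1536, 16, 16)   # 256 / 16 = 16
# >>> image = torch.randn(1, 3, 256, 256)
# >>> dec(vit_feat, image).shape
# torch.Size([1, 8, 256, 256])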
def merge_lora_weights(model, state_dict, alpha=1.0, block_prefix="encoder.vit.blocks"):
    """
    Merges LoRA weights into the base attention Q and V projection weights for each transformer block.
    We keep LoRA weights in model.safetensors to avoid shipping the original foundation model weights in the repo.

    Parameters:
    -----------
    model : torch.nn.Module
        The model containing the transformer blocks to modify (e.g., ViT backbone).
    state_dict : dict
        The state_dict containing LoRA matrices with keys formatted as
        '{block_prefix}.{idx}.attn.qkv.lora_q.A', etc.
        This dict is modified in-place to remove LoRA weights after merging.
    alpha : float, optional
        Scaling factor for the LoRA update. Defaults to 1.0.
    block_prefix : str, optional
        Prefix to locate transformer blocks in the model. Defaults to "encoder.vit.blocks".

    Returns:
    --------
    dict
        The modified state_dict with LoRA weights removed after merging.
    """
    with torch.no_grad():
        for idx in range(len(model.encoder.vit.blocks)):
            prefix = f"{block_prefix}.{idx}.attn.qkv"

            A_q = state_dict.pop(f"{prefix}.lora_q.A")
            B_q = state_dict.pop(f"{prefix}.lora_q.B")
            A_v = state_dict.pop(f"{prefix}.lora_v.A")
            B_v = state_dict.pop(f"{prefix}.lora_v.B")

            delta_q = (alpha * A_q @ B_q).T
            delta_v = (alpha * A_v @ B_v).T

            W = model.get_parameter(f"{prefix}.weight")
            dim = delta_q.shape[0]
            assert W.shape[0] == 3 * dim, f"Unexpected QKV shape: {W.shape}"

            W[:dim, :] += delta_q
            W[2 * dim:, :] += delta_v

    return state_dict


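# Shape sketch of the merge above (illustrative values, rank 2, dim 4): each LoRA
# pair (A, B) forms a low-rank update (alpha * A @ B).T that is added to the Q rows
# (first `dim` rows) and V rows (last `dim` rows) of the packed qkv weight of shape
# (3 * dim, dim).
# >>> dim, rank = 4, 2
# >>> W = torch.zeros(3 * dim, dim)
# >>> A, B = torch.randn(dim, rank), torch.randn(rank, dim)
# >>> W[:dim, :] += (1.0 * A @ B).T   # Q rows; V rows would use W[2 * dim:, :]
# >>> W.shape
# torch.Size([12, 4])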
def get_hoptimus0_hf(repo_id):
    """H-optimus-0 foundation model from a Hugging Face repo id.
    """
    model = timm.create_model(
        "vit_giant_patch14_reg4_dinov2", img_size=224,
        drop_path_rate=0., num_classes=0,
        global_pool="", pretrained=False, init_values=1e-5,
        dynamic_img_size=False)
    state_dict = load_state_dict_from_hf(repo_id, weights_only=True)
    model.load_state_dict(state_dict)
    return model


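# Usage sketch (illustrative, not executed at import time): the repo id is a
# placeholder for the H-optimus-0 weights repository configured in config_hf.json
# under "hoptimus_hf_id".
# >>> vit = get_hoptimus0_hf("<hoptimus-hf-repo-id>")
# >>> vit.set_input_size(img_size=(256, 256))  # as done in MIPHEIViT.from_pretrained_hf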
def validate_load_info(load_info):
    """
    Validates the result of model.load_state_dict(..., strict=False).

    Raises:
        ValueError if unexpected keys are found,
        or if missing keys are not related to the allowed encoder modules.
    """
    if load_info.unexpected_keys:
        raise ValueError(f"Unexpected keys in state_dict: {load_info.unexpected_keys}")

    for key in load_info.missing_keys:
        if ".lora" in key:
            raise ValueError(f"Missing LoRA checkpoint in state_dict: {key}")
        elif not any(part in key for part in ["encoder.vit.", "encoder.model."]):
            raise ValueError(f"Missing key in state_dict: {key}")