Spaces:
Sleeping
Sleeping
# | |
# For licensing see accompanying LICENSE file. | |
# Copyright (C) 2022 Apple Inc. All Rights Reserved. | |
# | |
import copy | |
from typing import List, Optional, Tuple | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from . import _utils as utils | |
from ._base import EncoderMixin | |
# Public names exported by ``from <this module> import *``.
__all__ = [
    "MobileOne",
    "reparameterize_model",
]
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel gating.

    PyTorch implementation of `Squeeze-and-Excitation Networks` -
    https://arxiv.org/pdf/1709.01507.pdf

    The input is average-pooled over its full spatial extent, squeezed through
    a 1x1 "reduce" conv + ReLU, expanded back with a 1x1 "expand" conv +
    sigmoid, and the resulting per-channel factors rescale the input.
    """

    def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
        """Create the reduce/expand 1x1 convolutions.

        :param in_channels: Number of input (and output) channels.
        :param rd_ratio: Channel reduction ratio; the bottleneck has
            ``int(in_channels * rd_ratio)`` channels.
        """
        super().__init__()
        bottleneck = int(in_channels * rd_ratio)
        # The attribute names ``reduce``/``expand`` define the state_dict keys.
        self.reduce = nn.Conv2d(
            in_channels, bottleneck, kernel_size=1, stride=1, bias=True
        )
        self.expand = nn.Conv2d(
            bottleneck, in_channels, kernel_size=1, stride=1, bias=True
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Rescale each channel of the 4-D ``inputs`` by its SE gate."""
        _, channels, height, width = inputs.size()
        # A pooling window covering the whole map yields one value per channel.
        squeezed = F.avg_pool2d(inputs, kernel_size=[height, width])
        gate = torch.sigmoid(self.expand(F.relu(self.reduce(squeezed))))
        return inputs * gate.view(-1, channels, 1, 1)
class MobileOneBlock(nn.Module):
    """MobileOne building block.

    This block has a multi-branched architecture at train time and a plain
    CNN-style (single convolution) architecture at inference time.
    For more details, please refer to the paper
    `An Improved One millisecond Mobile Backbone` -
    https://arxiv.org/pdf/2206.04040.pdf
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        inference_mode: bool = False,
        use_se: bool = False,
        num_conv_branches: int = 1,
    ) -> None:
        """Construct a MobileOneBlock module.

        :param in_channels: Number of channels in the input.
        :param out_channels: Number of channels produced by the block.
        :param kernel_size: Size of the (square) convolution kernel.
        :param stride: Stride size.
        :param padding: Zero-padding size of the kxk conv branches.
        :param dilation: Kernel dilation factor. Applied both to the
            inference-time conv and to the train-time kxk conv branches, so
            that re-parameterization preserves the block's output.
        :param groups: Group number.
        :param inference_mode: If True, instantiates the block directly in its
            re-parameterized (single conv) form.
        :param use_se: Whether to apply an SEBlock before the ReLU.
        :param num_conv_branches: Number of parallel kxk conv-BN branches.
        """
        super(MobileOneBlock, self).__init__()
        self.inference_mode = inference_mode
        self.groups = groups
        self.stride = stride
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_conv_branches = num_conv_branches

        # Check if SE-ReLU is requested
        self.se = SEBlock(out_channels) if use_se else nn.Identity()
        self.activation = nn.ReLU()

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=True,
            )
        else:
            # Re-parameterizable skip connection (BN only); it exists only when
            # the block preserves both channel count and spatial size.
            self.rbr_skip = (
                nn.BatchNorm2d(num_features=in_channels)
                if out_channels == in_channels and stride == 1
                else None
            )
            # Re-parameterizable kxk conv branches. They must use the same
            # dilation as the fused conv built in ``reparameterize``.
            self.rbr_conv = nn.ModuleList(
                [
                    self._conv_bn(
                        kernel_size=kernel_size, padding=padding, dilation=dilation
                    )
                    for _ in range(self.num_conv_branches)
                ]
            )
            # Re-parameterizable 1x1 scale branch (redundant when k == 1).
            self.rbr_scale = None
            if kernel_size > 1:
                self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass.

        In inference mode this is ReLU(SE(conv(x))). At train time the outputs
        of the skip, scale and kxk conv branches are summed before SE and ReLU.
        """
        # Inference mode forward pass.
        if self.inference_mode:
            return self.activation(self.se(self.reparam_conv(x)))

        # Multi-branched train-time forward pass.
        identity_out = 0
        if self.rbr_skip is not None:
            identity_out = self.rbr_skip(x)
        scale_out = 0
        if self.rbr_scale is not None:
            scale_out = self.rbr_scale(x)
        # ``out`` is a freshly allocated tensor (or the int 0), so the in-place
        # accumulation below never modifies a branch output.
        out = scale_out + identity_out
        for branch in self.rbr_conv:
            out += branch(x)
        return self.activation(self.se(out))

    def reparameterize(self) -> None:
        """Fuse all train-time branches into a single conv.

        Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf, the multi-branched architecture
        used at training time is re-parameterized into a plain CNN-like
        structure for inference. Calling this on a block that is already in
        inference mode is a no-op.
        """
        if self.inference_mode:
            return
        kernel, bias = self._get_kernel_bias()
        ref_conv = self.rbr_conv[0].conv
        self.reparam_conv = nn.Conv2d(
            in_channels=ref_conv.in_channels,
            out_channels=ref_conv.out_channels,
            kernel_size=ref_conv.kernel_size,
            stride=ref_conv.stride,
            padding=ref_conv.padding,
            dilation=ref_conv.dilation,
            groups=ref_conv.groups,
            bias=True,
        )
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # Delete un-used branches
        for para in self.parameters():
            para.detach_()
        self.__delattr__("rbr_conv")
        self.__delattr__("rbr_scale")
        if hasattr(self, "rbr_skip"):
            self.__delattr__("rbr_skip")
        # The identity kernel cached by ``_fuse_bn_tensor`` is a plain tensor
        # attribute (not a parameter/buffer); drop it so it does not linger.
        if hasattr(self, "id_tensor"):
            self.__delattr__("id_tensor")
        self.inference_mode = True

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Obtain the re-parameterized kernel and bias.

        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
        :return: Tuple of (kernel, bias) after fusing branches.
        """
        # Scale branch: fuse, then zero-pad the 1x1 kernel to kxk (centred).
        kernel_scale = 0
        bias_scale = 0
        if self.rbr_scale is not None:
            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
            pad = self.kernel_size // 2
            kernel_scale = F.pad(kernel_scale, [pad, pad, pad, pad])

        # Skip branch.
        kernel_identity = 0
        bias_identity = 0
        if self.rbr_skip is not None:
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)

        # kxk conv branches.
        kernel_conv = 0
        bias_conv = 0
        for branch in self.rbr_conv:
            branch_kernel, branch_bias = self._fuse_bn_tensor(branch)
            kernel_conv += branch_kernel
            bias_conv += branch_bias

        kernel_final = kernel_conv + kernel_scale + kernel_identity
        bias_final = bias_conv + bias_scale + bias_identity
        return kernel_final, bias_final

    def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
        """Fuse a batchnorm layer with its preceding conv layer.

        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
        :param branch: Either a conv-BN ``nn.Sequential`` built by
            ``_conv_bn`` or the ``nn.BatchNorm2d`` skip branch.
        :return: Tuple of (kernel, bias) after fusing batchnorm.
        """
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            bn = branch.bn
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, "id_tensor"):
                # Identity written as a (grouped) kxk conv kernel: a single 1 at
                # the spatial centre connecting each channel to itself.
                input_dim = self.in_channels // self.groups
                kernel_value = torch.zeros(
                    (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
                    dtype=branch.weight.dtype,
                    device=branch.weight.device,
                )
                centre = self.kernel_size // 2
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, centre, centre] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            bn = branch
        # BN(y) = gamma * (y - mean) / std + beta, folded into kernel and bias.
        std = (bn.running_var + bn.eps).sqrt()
        t = (bn.weight / std).reshape(-1, 1, 1, 1)
        return kernel * t, bn.bias - bn.running_mean * bn.weight / std

    def _conv_bn(
        self, kernel_size: int, padding: int, dilation: int = 1
    ) -> nn.Sequential:
        """Construct conv-batchnorm layers.

        :param kernel_size: Size of the convolution kernel.
        :param padding: Zero-padding size.
        :param dilation: Kernel dilation factor of the conv.
        :return: Conv-BN module with children named ``conv`` and ``bn``.
        """
        mod_list = nn.Sequential()
        mod_list.add_module(
            "conv",
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                padding=padding,
                dilation=dilation,
                groups=self.groups,
                bias=False,
            ),
        )
        mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
        return mod_list
class MobileOne(nn.Module, EncoderMixin):
    """MobileOne encoder.

    PyTorch implementation of `An Improved One millisecond Mobile Backbone` -
    https://arxiv.org/pdf/2206.04040.pdf. Instead of classification logits,
    ``forward`` returns the output of every stage.
    """

    def __init__(
        self,
        out_channels,
        num_blocks_per_stage: Optional[List[int]] = None,
        width_multipliers: Optional[List[float]] = None,
        inference_mode: bool = False,
        use_se: bool = False,
        depth: int = 5,
        in_channels: int = 3,
        num_conv_branches: int = 1,
    ) -> None:
        """Construct MobileOne model.

        :param out_channels: Channel counts reported for the returned features,
            stored as ``self._out_channels``; the first entry is the input
            channel count.
        :param num_blocks_per_stage: Number of blocks in each of the four
            stages. Defaults to ``[2, 8, 10, 1]``.
        :param width_multipliers: Four width multipliers, one per stage.
            Required.
        :param inference_mode: If True, instantiates model in inference mode.
        :param use_se: Whether to use SE-ReLU blocks (last half of stage 3 and
            all of stage 4).
        :param depth: ``forward`` runs stages ``0..depth`` of ``get_stages()``.
        :param in_channels: Number of channels of the input image.
        :param num_conv_branches: Number of linear conv branches.
        :raises ValueError: If ``width_multipliers`` is missing or does not
            contain exactly four values.
        """
        super().__init__()
        if num_blocks_per_stage is None:
            num_blocks_per_stage = [2, 8, 10, 1]
        if width_multipliers is None or len(width_multipliers) != 4:
            raise ValueError("width_multipliers must contain exactly 4 values.")

        self.inference_mode = inference_mode
        self._out_channels = out_channels
        self.in_planes = min(64, int(64 * width_multipliers[0]))
        self.use_se = use_se
        self.num_conv_branches = num_conv_branches
        self._depth = depth
        self._in_channels = in_channels
        if in_channels != 3:
            # Same bookkeeping as set_in_channels(): the first reported
            # feature is the raw input, so its channel count follows it.
            self._out_channels = tuple([in_channels] + list(out_channels)[1:])

        # Build stages. The stem is created directly with ``in_channels``
        # inputs; calling set_in_channels() here, before ``stage0`` exists,
        # would fail for any value other than 3.
        self.stage0 = MobileOneBlock(
            in_channels=self._in_channels,
            out_channels=self.in_planes,
            kernel_size=3,
            stride=2,
            padding=1,
            inference_mode=self.inference_mode,
        )
        self.cur_layer_idx = 1
        self.stage1 = self._make_stage(
            int(64 * width_multipliers[0]), num_blocks_per_stage[0], num_se_blocks=0
        )
        self.stage2 = self._make_stage(
            int(128 * width_multipliers[1]), num_blocks_per_stage[1], num_se_blocks=0
        )
        self.stage3 = self._make_stage(
            int(256 * width_multipliers[2]),
            num_blocks_per_stage[2],
            num_se_blocks=int(num_blocks_per_stage[2] // 2) if use_se else 0,
        )
        self.stage4 = self._make_stage(
            int(512 * width_multipliers[3]),
            num_blocks_per_stage[3],
            num_se_blocks=num_blocks_per_stage[3] if use_se else 0,
        )

    def get_stages(self):
        """Return the stages in order; index 0 is an identity (the raw input)."""
        return [
            nn.Identity(),
            self.stage0,
            self.stage1,
            self.stage2,
            self.stage3,
            self.stage4,
        ]

    def _make_stage(
        self, planes: int, num_blocks: int, num_se_blocks: int
    ) -> nn.Sequential:
        """Build a stage of MobileOne model.

        Each block is a depthwise 3x3 MobileOneBlock followed by a pointwise
        1x1 MobileOneBlock; only the first block of the stage has stride 2.

        :param planes: Number of output channels.
        :param num_blocks: Number of blocks in this stage (at least 1).
        :param num_se_blocks: Number of SE blocks in this stage; SE is used by
            the last ``num_se_blocks`` blocks.
        :return: A stage of MobileOne model.
        :raises ValueError: If ``num_blocks < 1`` or
            ``num_se_blocks > num_blocks``.
        """
        if num_blocks < 1:
            raise ValueError("Each stage needs at least one block.")
        if num_se_blocks > num_blocks:
            raise ValueError("Number of SE blocks cannot exceed number of layers.")

        # Get strides for all layers
        strides = [2] + [1] * (num_blocks - 1)
        blocks = []
        for ix, stride in enumerate(strides):
            use_se = ix >= num_blocks - num_se_blocks
            # Depthwise conv
            blocks.append(
                MobileOneBlock(
                    in_channels=self.in_planes,
                    out_channels=self.in_planes,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    groups=self.in_planes,
                    inference_mode=self.inference_mode,
                    use_se=use_se,
                    num_conv_branches=self.num_conv_branches,
                )
            )
            # Pointwise conv
            blocks.append(
                MobileOneBlock(
                    in_channels=self.in_planes,
                    out_channels=planes,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    groups=1,
                    inference_mode=self.inference_mode,
                    use_se=use_se,
                    num_conv_branches=self.num_conv_branches,
                )
            )
            self.in_planes = planes
            self.cur_layer_idx += 1
        return nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Run stages ``0..depth`` and return every stage's output.

        The first element is the unchanged input (identity stage).
        """
        stages = self.get_stages()
        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            features.append(x)
        return features

    def load_state_dict(self, state_dict, **kwargs):
        """Load weights after dropping ``linear.*`` keys.

        This encoder has no ``linear`` module, so a classifier head present in
        a checkpoint is discarded. Note: the keys are popped from the caller's
        ``state_dict`` in place.

        :return: The result of ``nn.Module.load_state_dict`` (missing and
            unexpected keys), which was previously discarded.
        """
        state_dict.pop("linear.weight", None)
        state_dict.pop("linear.bias", None)
        return super().load_state_dict(state_dict, **kwargs)

    def set_in_channels(self, in_channels, pretrained=True):
        """Change the number of input channels of the stem (``stage0``).

        :param in_channels: New number of input channels; 3 is a no-op.
        :param pretrained: Forwarded to ``utils.patch_first_conv``.
        """
        if in_channels == 3:
            return

        self._in_channels = in_channels
        self._out_channels = tuple([in_channels] + list(self._out_channels)[1:])

        if self.stage0.inference_mode:
            # Re-parameterized (or inference-mode) stem: one fused conv.
            stem_parts = [self.stage0.reparam_conv]
        else:
            # Every train-time branch reads the input image, so each branch
            # needs its own patch (patching a whole list patches one conv only).
            stem_parts = list(self.stage0.rbr_conv)
            if self.stage0.rbr_scale is not None:
                stem_parts.append(self.stage0.rbr_scale)

        for part in stem_parts:
            utils.patch_first_conv(
                model=part,
                new_in_channels=in_channels,
                pretrained=pretrained,
            )
def reparameterize_model(model: torch.nn.Module) -> nn.Module:
    """Return an inference-ready, re-parameterized copy of ``model``.

    ``model`` itself is never modified: it is deep-copied first, and every
    sub-module of the copy that defines a ``reparameterize`` method (e.g.
    ``MobileOneBlock``, which collapses its train-time branches into one conv)
    has that method invoked.

    :param model: MobileOne model in train mode.
    :return: MobileOne model in inference mode.
    """
    fused = copy.deepcopy(model)
    # Lazy filter over ``modules()``: same traversal as a plain for/if loop.
    fusable = (sub for sub in fused.modules() if hasattr(sub, "reparameterize"))
    for sub in fusable:
        sub.reparameterize()
    return fused
def _mobileone_imagenet_settings(variant: str) -> dict:
    """Build the ImageNet pretrained-weight settings for a MobileOne variant.

    A fresh dict (with fresh lists) is returned on every call, so the entries
    of ``mobileone_encoders`` never share mutable objects.

    :param variant: Variant tag used in the checkpoint file name, e.g. "s0".
    """
    return {
        "imagenet": {
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "url": (
                "https://docs-assets.developer.apple.com/ml-research/datasets/"
                f"mobileone/mobileone_{variant}_unfused.pth.tar"
            ),
            "input_space": "RGB",
            "input_range": [0, 1],
        }
    }


# Registry of MobileOne encoder variants: encoder class, pretrained checkpoint
# settings and the keyword arguments passed to ``MobileOne``.
mobileone_encoders = {
    "mobileone_s0": {
        "encoder": MobileOne,
        "pretrained_settings": _mobileone_imagenet_settings("s0"),
        "params": {
            "out_channels": (3, 48, 48, 128, 256, 1024),
            "width_multipliers": (0.75, 1.0, 1.0, 2.0),
            "num_conv_branches": 4,
            "inference_mode": False,
        },
    },
    "mobileone_s1": {
        "encoder": MobileOne,
        "pretrained_settings": _mobileone_imagenet_settings("s1"),
        "params": {
            "out_channels": (3, 64, 96, 192, 512, 1280),
            "width_multipliers": (1.5, 1.5, 2.0, 2.5),
            "inference_mode": False,
        },
    },
    "mobileone_s2": {
        "encoder": MobileOne,
        "pretrained_settings": _mobileone_imagenet_settings("s2"),
        "params": {
            "out_channels": (3, 64, 96, 256, 640, 2048),
            "width_multipliers": (1.5, 2.0, 2.5, 4.0),
            "inference_mode": False,
        },
    },
    "mobileone_s3": {
        "encoder": MobileOne,
        "pretrained_settings": _mobileone_imagenet_settings("s3"),
        "params": {
            "out_channels": (3, 64, 128, 320, 768, 2048),
            "width_multipliers": (2.0, 2.5, 3.0, 4.0),
            "inference_mode": False,
        },
    },
    "mobileone_s4": {
        "encoder": MobileOne,
        "pretrained_settings": _mobileone_imagenet_settings("s4"),
        "params": {
            "out_channels": (3, 64, 192, 448, 896, 2048),
            "width_multipliers": (3.0, 3.5, 3.5, 4.0),
            "use_se": True,
            "inference_mode": False,
        },
    },
}