#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#
import copy
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import _utils as utils
from ._base import EncoderMixin

__all__ = ["MobileOne", "reparameterize_model"]


class SEBlock(nn.Module):
    """Squeeze and Excite module.

    Pytorch implementation of `Squeeze-and-Excitation Networks` -
    https://arxiv.org/pdf/1709.01507.pdf
    """

    def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
        """Construct a Squeeze and Excite Module.

        :param in_channels: Number of input channels.
        :param rd_ratio: Input channel reduction ratio.
        """
        super(SEBlock, self).__init__()
        self.reduce = nn.Conv2d(
            in_channels=in_channels,
            out_channels=int(in_channels * rd_ratio),
            kernel_size=1,
            stride=1,
            bias=True,
        )
        self.expand = nn.Conv2d(
            in_channels=int(in_channels * rd_ratio),
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            bias=True,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        b, c, h, w = inputs.size()
        x = F.avg_pool2d(inputs, kernel_size=[h, w])
        x = self.reduce(x)
        x = F.relu(x)
        x = self.expand(x)
        x = torch.sigmoid(x)
        x = x.view(-1, c, 1, 1)
        return inputs * x
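
# A minimal usage sketch of the SE block (illustrative only; the helper below is
# not part of the original module). Shapes assume NCHW input; SE rescales
# channels without changing the spatial size.
def _seblock_example() -> None:  # hypothetical helper, for illustration
    se = SEBlock(in_channels=64)
    x = torch.randn(2, 64, 32, 32)
    y = se(x)
    # With the default rd_ratio, the block squeezes to 64 * 0.0625 = 4 channels,
    # then expands back, so the output shape matches the input shape exactly.
    assert y.shape == x.shape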
class MobileOneBlock(nn.Module):
    """MobileOne building block.

    This block has a multi-branched architecture at train-time
    and a plain-CNN style architecture at inference time.
    For more details, please refer to our paper:
    `An Improved One millisecond Mobile Backbone` -
    https://arxiv.org/pdf/2206.04040.pdf
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        inference_mode: bool = False,
        use_se: bool = False,
        num_conv_branches: int = 1,
    ) -> None:
        """Construct a MobileOneBlock module.

        :param in_channels: Number of channels in the input.
        :param out_channels: Number of channels produced by the block.
        :param kernel_size: Size of the convolution kernel.
        :param stride: Stride size.
        :param padding: Zero-padding size.
        :param dilation: Kernel dilation factor.
        :param groups: Group number.
        :param inference_mode: If True, instantiates model in inference mode.
        :param use_se: Whether to use SE-ReLU activations.
        :param num_conv_branches: Number of linear conv branches.
        """
        super(MobileOneBlock, self).__init__()
        self.inference_mode = inference_mode
        self.groups = groups
        self.stride = stride
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_conv_branches = num_conv_branches

        # Check if SE-ReLU is requested
        if use_se:
            self.se = SEBlock(out_channels)
        else:
            self.se = nn.Identity()
        self.activation = nn.ReLU()

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=True,
            )
        else:
            # Re-parameterizable skip connection
            self.rbr_skip = (
                nn.BatchNorm2d(num_features=in_channels)
                if out_channels == in_channels and stride == 1
                else None
            )

            # Re-parameterizable conv branches
            rbr_conv = list()
            for _ in range(self.num_conv_branches):
                rbr_conv.append(self._conv_bn(kernel_size=kernel_size, padding=padding))
            self.rbr_conv = nn.ModuleList(rbr_conv)

            # Re-parameterizable scale branch
            self.rbr_scale = None
            if kernel_size > 1:
                self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        # Inference mode forward pass.
        if self.inference_mode:
            return self.activation(self.se(self.reparam_conv(x)))

        # Multi-branched train-time forward pass.
        # Skip branch output
        identity_out = 0
        if self.rbr_skip is not None:
            identity_out = self.rbr_skip(x)

        # Scale branch output
        scale_out = 0
        if self.rbr_scale is not None:
            scale_out = self.rbr_scale(x)

        # Other branches
        out = scale_out + identity_out
        for ix in range(self.num_conv_branches):
            out += self.rbr_conv[ix](x)

        return self.activation(self.se(out))

    def reparameterize(self):
        """Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize the multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference.
        """
        if self.inference_mode:
            return
        kernel, bias = self._get_kernel_bias()
        self.reparam_conv = nn.Conv2d(
            in_channels=self.rbr_conv[0].conv.in_channels,
            out_channels=self.rbr_conv[0].conv.out_channels,
            kernel_size=self.rbr_conv[0].conv.kernel_size,
            stride=self.rbr_conv[0].conv.stride,
            padding=self.rbr_conv[0].conv.padding,
            dilation=self.rbr_conv[0].conv.dilation,
            groups=self.rbr_conv[0].conv.groups,
            bias=True,
        )
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # Delete un-used branches
        for para in self.parameters():
            para.detach_()
        self.__delattr__("rbr_conv")
        self.__delattr__("rbr_scale")
        if hasattr(self, "rbr_skip"):
            self.__delattr__("rbr_skip")

        self.inference_mode = True

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Obtain the re-parameterized kernel and bias.

        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83

        :return: Tuple of (kernel, bias) after fusing branches.
        """
        # get weights and bias of scale branch
        kernel_scale = 0
        bias_scale = 0
        if self.rbr_scale is not None:
            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
            # Pad scale branch kernel to match conv branch kernel size.
            pad = self.kernel_size // 2
            kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])

        # get weights and bias of skip branch
        kernel_identity = 0
        bias_identity = 0
        if self.rbr_skip is not None:
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)

        # get weights and bias of conv branches
        kernel_conv = 0
        bias_conv = 0
        for ix in range(self.num_conv_branches):
            _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
            kernel_conv += _kernel
            bias_conv += _bias

        kernel_final = kernel_conv + kernel_scale + kernel_identity
        bias_final = bias_conv + bias_scale + bias_identity
        return kernel_final, bias_final

    def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
        """Fuse batchnorm layer with preceding conv layer.

        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95

        :param branch: Branch to fuse; either a conv-BN sequential or a bare BN layer.
        :return: Tuple of (kernel, bias) after fusing batchnorm.
        """
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, "id_tensor"):
                input_dim = self.in_channels // self.groups
                kernel_value = torch.zeros(
                    (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
                    dtype=branch.weight.dtype,
                    device=branch.weight.device,
                )
                for i in range(self.in_channels):
                    kernel_value[
                        i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2
                    ] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
        """Construct conv-batchnorm layers.

        :param kernel_size: Size of the convolution kernel.
        :param padding: Zero-padding size.
        :return: Conv-BN module.
        """
        mod_list = nn.Sequential()
        mod_list.add_module(
            "conv",
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                padding=padding,
                groups=self.groups,
                bias=False,
            ),
        )
        mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
        return mod_list
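
# Illustrative sketch (the helper below is not part of the original module):
# after calling ``reparameterize()`` the fused single-branch block should
# reproduce the multi-branch output, because conv and batchnorm are both
# linear operators once BN runs on fixed statistics.
def _reparameterize_block_example() -> None:  # hypothetical helper, for illustration
    block = MobileOneBlock(
        in_channels=32, out_channels=32, kernel_size=3, padding=1, num_conv_branches=4
    )
    block.eval()  # BN must use running statistics for the fusion to be exact
    x = torch.randn(1, 32, 16, 16)
    with torch.no_grad():
        before = block(x)
        block.reparameterize()
        after = block(x)
    # Outputs should agree up to floating-point tolerance.
    assert torch.allclose(before, after, atol=1e-5)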
""" if isinstance(branch, nn.Sequential): kernel = branch.conv.weight running_mean = branch.bn.running_mean running_var = branch.bn.running_var gamma = branch.bn.weight beta = branch.bn.bias eps = branch.bn.eps else: assert isinstance(branch, nn.BatchNorm2d) if not hasattr(self, "id_tensor"): input_dim = self.in_channels // self.groups kernel_value = torch.zeros( (self.in_channels, input_dim, self.kernel_size, self.kernel_size), dtype=branch.weight.dtype, device=branch.weight.device, ) for i in range(self.in_channels): kernel_value[ i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2 ] = 1 self.id_tensor = kernel_value kernel = self.id_tensor running_mean = branch.running_mean running_var = branch.running_var gamma = branch.weight beta = branch.bias eps = branch.eps std = (running_var + eps).sqrt() t = (gamma / std).reshape(-1, 1, 1, 1) return kernel * t, beta - running_mean * gamma / std def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential: """Construct conv-batchnorm layers. :param kernel_size: Size of the convolution kernel. :param padding: Zero-padding size. :return: Conv-BN module. """ mod_list = nn.Sequential() mod_list.add_module( "conv", nn.Conv2d( in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=kernel_size, stride=self.stride, padding=padding, groups=self.groups, bias=False, ), ) mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels)) return mod_list class MobileOne(nn.Module, EncoderMixin): """MobileOne Model Pytorch implementation of `An Improved One millisecond Mobile Backbone` - https://arxiv.org/pdf/2206.04040.pdf """ def __init__( self, out_channels, num_blocks_per_stage: List[int] = [2, 8, 10, 1], width_multipliers: Optional[List[float]] = None, inference_mode: bool = False, use_se: bool = False, depth=5, in_channels=3, num_conv_branches: int = 1, ) -> None: """Construct MobileOne model. :param num_blocks_per_stage: List of number of blocks per stage. :param num_classes: Number of classes in the dataset. :param width_multipliers: List of width multiplier for blocks in a stage. :param inference_mode: If True, instantiates model in inference mode. :param use_se: Whether to use SE-ReLU activations. :param num_conv_branches: Number of linear conv branches. 
""" super().__init__() assert len(width_multipliers) == 4 self.inference_mode = inference_mode self._out_channels = out_channels self.in_planes = min(64, int(64 * width_multipliers[0])) self.use_se = use_se self.num_conv_branches = num_conv_branches self._depth = depth self._in_channels = in_channels self.set_in_channels(self._in_channels) # Build stages self.stage0 = MobileOneBlock( in_channels=self._in_channels, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1, inference_mode=self.inference_mode, ) self.cur_layer_idx = 1 self.stage1 = self._make_stage( int(64 * width_multipliers[0]), num_blocks_per_stage[0], num_se_blocks=0 ) self.stage2 = self._make_stage( int(128 * width_multipliers[1]), num_blocks_per_stage[1], num_se_blocks=0 ) self.stage3 = self._make_stage( int(256 * width_multipliers[2]), num_blocks_per_stage[2], num_se_blocks=int(num_blocks_per_stage[2] // 2) if use_se else 0, ) self.stage4 = self._make_stage( int(512 * width_multipliers[3]), num_blocks_per_stage[3], num_se_blocks=num_blocks_per_stage[3] if use_se else 0, ) def get_stages(self): return [ nn.Identity(), self.stage0, self.stage1, self.stage2, self.stage3, self.stage4, ] def _make_stage( self, planes: int, num_blocks: int, num_se_blocks: int ) -> nn.Sequential: """Build a stage of MobileOne model. :param planes: Number of output channels. :param num_blocks: Number of blocks in this stage. :param num_se_blocks: Number of SE blocks in this stage. :return: A stage of MobileOne model. """ # Get strides for all layers strides = [2] + [1] * (num_blocks - 1) blocks = [] for ix, stride in enumerate(strides): use_se = False if num_se_blocks > num_blocks: raise ValueError( "Number of SE blocks cannot " "exceed number of layers." ) if ix >= (num_blocks - num_se_blocks): use_se = True # Depthwise conv blocks.append( MobileOneBlock( in_channels=self.in_planes, out_channels=self.in_planes, kernel_size=3, stride=stride, padding=1, groups=self.in_planes, inference_mode=self.inference_mode, use_se=use_se, num_conv_branches=self.num_conv_branches, ) ) # Pointwise conv blocks.append( MobileOneBlock( in_channels=self.in_planes, out_channels=planes, kernel_size=1, stride=1, padding=0, groups=1, inference_mode=self.inference_mode, use_se=use_se, num_conv_branches=self.num_conv_branches, ) ) self.in_planes = planes self.cur_layer_idx += 1 return nn.Sequential(*blocks) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply forward pass.""" stages = self.get_stages() features = [] for i in range(self._depth + 1): x = stages[i](x) features.append(x) return features def load_state_dict(self, state_dict, **kwargs): state_dict.pop("linear.weight", None) state_dict.pop("linear.bias", None) super().load_state_dict(state_dict, **kwargs) def set_in_channels(self, in_channels, pretrained=True): """Change first convolution channels""" if in_channels == 3: return self._in_channels = in_channels self._out_channels = tuple([in_channels] + list(self._out_channels)[1:]) utils.patch_first_conv( model=self.stage0.rbr_conv, new_in_channels=in_channels, pretrained=pretrained, ) utils.patch_first_conv( model=self.stage0.rbr_scale, new_in_channels=in_channels, pretrained=pretrained, ) def reparameterize_model(model: torch.nn.Module) -> nn.Module: """Return a model where a multi-branched structure used in training is re-parameterized into a single branch for inference. :param model: MobileOne model in train mode. :return: MobileOne model in inference mode. 
""" # Avoid editing original graph model = copy.deepcopy(model) for module in model.modules(): if hasattr(module, "reparameterize"): module.reparameterize() return model mobileone_encoders = { "mobileone_s0": { "encoder": MobileOne, "pretrained_settings": { "imagenet": { "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225], "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s0_unfused.pth.tar", # noqa "input_space": "RGB", "input_range": [0, 1], } }, "params": { "out_channels": (3, 48, 48, 128, 256, 1024), "width_multipliers": (0.75, 1.0, 1.0, 2.0), "num_conv_branches": 4, "inference_mode": False, }, }, "mobileone_s1": { "encoder": MobileOne, "pretrained_settings": { "imagenet": { "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225], "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s1_unfused.pth.tar", # noqa "input_space": "RGB", "input_range": [0, 1], } }, "params": { "out_channels": (3, 64, 96, 192, 512, 1280), "width_multipliers": (1.5, 1.5, 2.0, 2.5), "inference_mode": False, }, }, "mobileone_s2": { "encoder": MobileOne, "pretrained_settings": { "imagenet": { "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225], "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s2_unfused.pth.tar", # noqa "input_space": "RGB", "input_range": [0, 1], } }, "params": { "out_channels": (3, 64, 96, 256, 640, 2048), "width_multipliers": (1.5, 2.0, 2.5, 4.0), "inference_mode": False, }, }, "mobileone_s3": { "encoder": MobileOne, "pretrained_settings": { "imagenet": { "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225], "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s3_unfused.pth.tar", # noqa "input_space": "RGB", "input_range": [0, 1], } }, "params": { "out_channels": (3, 64, 128, 320, 768, 2048), "width_multipliers": (2.0, 2.5, 3.0, 4.0), "inference_mode": False, }, }, "mobileone_s4": { "encoder": MobileOne, "pretrained_settings": { "imagenet": { "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225], "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s4_unfused.pth.tar", # noqa "input_space": "RGB", "input_range": [0, 1], } }, "params": { "out_channels": (3, 64, 192, 448, 896, 2048), "width_multipliers": (3.0, 3.5, 3.5, 4.0), "use_se": True, "inference_mode": False, }, }, }