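# NOTE: the lines below are residue copied from a web file-viewer page
# (apparently author, commit message, commit hash, view links and file size);
# they are kept as comments so that this module parses as Python.
# wangerniu's picture
# 添加必要文件 (commit message: "Add necessary files")
# c9b5796
# raw
# history blame
# 19.6 kB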
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#
import copy
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import _utils as utils
from ._base import EncoderMixin
# Names exported by ``from <this module> import *``; the building blocks
# (SEBlock, MobileOneBlock) and the ``mobileone_encoders`` registry remain
# importable explicitly by name.
__all__ = ["MobileOne", "reparameterize_model"]
class SEBlock(nn.Module):
    """Squeeze and Excite module.

    Pytorch implementation of `Squeeze-and-Excitation Networks` -
    https://arxiv.org/pdf/1709.01507.pdf
    """

    def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
        """Construct a Squeeze and Excite Module.

        :param in_channels: Number of input channels.
        :param rd_ratio: Input channel reduction ratio. The bottleneck width is
            ``int(in_channels * rd_ratio)``, clamped to at least 1.
        """
        super(SEBlock, self).__init__()
        # Without the clamp, small inputs (e.g. in_channels < 16 at the default
        # ratio) produce a zero-width bottleneck. The gate would then depend
        # only on the ``expand`` bias, which PyTorch does not initialize for a
        # conv whose fan-in is 0.
        reduced_channels = max(1, int(in_channels * rd_ratio))
        self.reduce = nn.Conv2d(
            in_channels=in_channels,
            out_channels=reduced_channels,
            kernel_size=1,
            stride=1,
            bias=True,
        )
        self.expand = nn.Conv2d(
            in_channels=reduced_channels,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            bias=True,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply forward pass.

        :param inputs: Feature map of shape (N, C, H, W).
        :return: ``inputs`` rescaled channel-wise by a sigmoid gate.
        """
        _, c, h, w = inputs.size()
        # Average over the whole spatial extent -> one value per channel.
        x = F.avg_pool2d(inputs, kernel_size=[h, w])
        x = self.reduce(x)
        x = F.relu(x)
        x = self.expand(x)
        x = torch.sigmoid(x)
        x = x.view(-1, c, 1, 1)
        return inputs * x
class MobileOneBlock(nn.Module):
    """MobileOne building block.

    This block has a multi-branched architecture at train-time
    and plain-CNN style architecture at inference time.
    For more details, please refer to our paper:
    `An Improved One millisecond Mobile Backbone` -
    https://arxiv.org/pdf/2206.04040.pdf
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        inference_mode: bool = False,
        use_se: bool = False,
        num_conv_branches: int = 1,
    ) -> None:
        """Construct a MobileOneBlock module.

        :param in_channels: Number of channels in the input.
        :param out_channels: Number of channels produced by the block.
        :param kernel_size: Size of the convolution kernel.
        :param stride: Stride size.
        :param padding: Zero-padding size of the kxk convolution(s). For the
            train-time branch outputs to be summable (and fusable) with the
            1x1 scale branch, it should equal ``dilation * (kernel_size // 2)``.
        :param dilation: Kernel dilation factor of the kxk convolution(s).
        :param groups: Group number.
        :param inference_mode: If True, instantiates model in inference mode.
        :param use_se: Whether to use SE-ReLU activations.
        :param num_conv_branches: Number of linear conv branches.
        """
        super(MobileOneBlock, self).__init__()
        self.inference_mode = inference_mode
        self.groups = groups
        self.stride = stride
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_conv_branches = num_conv_branches

        # Check if SE-ReLU is requested
        if use_se:
            self.se = SEBlock(out_channels)
        else:
            self.se = nn.Identity()
        self.activation = nn.ReLU()

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=True,
            )
        else:
            # Re-parameterizable skip connection: a bare BatchNorm, only built
            # when the block keeps both channel count and resolution.
            self.rbr_skip = (
                nn.BatchNorm2d(num_features=in_channels)
                if out_channels == in_channels and stride == 1
                else None
            )

            # Re-parameterizable conv branches. ``dilation`` must be applied
            # here as well; otherwise a train-mode block and an inference-mode
            # block built from the same arguments compute different functions.
            self.rbr_conv = nn.ModuleList(
                [
                    self._conv_bn(
                        kernel_size=kernel_size, padding=padding, dilation=dilation
                    )
                    for _ in range(self.num_conv_branches)
                ]
            )

            # Re-parameterizable 1x1 scale branch, built only when the main
            # kernel is larger than 1x1.
            self.rbr_scale = None
            if kernel_size > 1:
                self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        # Inference mode forward pass: a single fused conv.
        if self.inference_mode:
            return self.activation(self.se(self.reparam_conv(x)))

        # Multi-branched train-time forward pass.
        # Skip branch output
        identity_out = 0
        if self.rbr_skip is not None:
            identity_out = self.rbr_skip(x)

        # Scale branch output
        scale_out = 0
        if self.rbr_scale is not None:
            scale_out = self.rbr_scale(x)

        # Other branches
        out = scale_out + identity_out
        for ix in range(self.num_conv_branches):
            out += self.rbr_conv[ix](x)

        return self.activation(self.se(out))

    def reparameterize(self):
        """Fuse the train-time branches into a single conv, in place.

        Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference. No-op if the block is already in inference mode.
        """
        if self.inference_mode:
            return
        kernel, bias = self._get_kernel_bias()
        # Conv geometry (including dilation) is copied from the first kxk branch.
        self.reparam_conv = nn.Conv2d(
            in_channels=self.rbr_conv[0].conv.in_channels,
            out_channels=self.rbr_conv[0].conv.out_channels,
            kernel_size=self.rbr_conv[0].conv.kernel_size,
            stride=self.rbr_conv[0].conv.stride,
            padding=self.rbr_conv[0].conv.padding,
            dilation=self.rbr_conv[0].conv.dilation,
            groups=self.rbr_conv[0].conv.groups,
            bias=True,
        )
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # Delete un-used branches (all remaining parameters are detached too).
        for para in self.parameters():
            para.detach_()
        self.__delattr__("rbr_conv")
        self.__delattr__("rbr_scale")
        if hasattr(self, "rbr_skip"):
            self.__delattr__("rbr_skip")

        self.inference_mode = True

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Obtain the re-parameterized kernel and bias.

        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83

        :return: Tuple of (kernel, bias) after fusing branches.
        """
        # get weights and bias of scale branch
        kernel_scale = 0
        bias_scale = 0
        if self.rbr_scale is not None:
            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
            # Pad the 1x1 kernel to kxk with its weight at the centre tap. With
            # padding == dilation * (kernel_size // 2), that tap reads the same
            # input pixel as the 1x1 conv.
            pad = self.kernel_size // 2
            kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])

        # get weights and bias of skip branch
        kernel_identity = 0
        bias_identity = 0
        if self.rbr_skip is not None:
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)

        # get weights and bias of conv branches
        kernel_conv = 0
        bias_conv = 0
        for ix in range(self.num_conv_branches):
            _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
            kernel_conv += _kernel
            bias_conv += _bias

        kernel_final = kernel_conv + kernel_scale + kernel_identity
        bias_final = bias_conv + bias_scale + bias_identity
        return kernel_final, bias_final

    def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
        """Fuse batchnorm layer with preceding conv layer.

        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95

        :param branch: Either a conv-BN ``nn.Sequential`` built by ``_conv_bn``
            or the bare ``nn.BatchNorm2d`` of the skip branch.
        :return: Tuple of (kernel, bias) after fusing batchnorm.
        """
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            # Express the identity as a kxk (grouped) conv kernel: a single 1 at
            # the centre tap of each channel's own input slice. Cached.
            if not hasattr(self, "id_tensor"):
                input_dim = self.in_channels // self.groups
                kernel_value = torch.zeros(
                    (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
                    dtype=branch.weight.dtype,
                    device=branch.weight.device,
                )
                for i in range(self.in_channels):
                    kernel_value[
                        i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2
                    ] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        # BN(y) = gamma * (y - mean) / std + beta, folded into kernel and bias.
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def _conv_bn(
        self, kernel_size: int, padding: int, dilation: int = 1
    ) -> nn.Sequential:
        """Construct conv-batchnorm layers.

        :param kernel_size: Size of the convolution kernel.
        :param padding: Zero-padding size.
        :param dilation: Kernel dilation factor (defaults to 1).
        :return: Conv-BN module.
        """
        mod_list = nn.Sequential()
        mod_list.add_module(
            "conv",
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                padding=padding,
                dilation=dilation,
                groups=self.groups,
                bias=False,
            ),
        )
        mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
        return mod_list
class MobileOne(nn.Module, EncoderMixin):
    """MobileOne Model

    Pytorch implementation of `An Improved One millisecond Mobile Backbone` -
    https://arxiv.org/pdf/2206.04040.pdf
    """

    def __init__(
        self,
        out_channels,
        num_blocks_per_stage: Optional[List[int]] = None,
        width_multipliers: Optional[List[float]] = None,
        inference_mode: bool = False,
        use_se: bool = False,
        depth=5,
        in_channels=3,
        num_conv_branches: int = 1,
    ) -> None:
        """Construct MobileOne model.

        :param out_channels: Channel counts reported for the ``depth + 1``
            features returned by ``forward`` (stored, not validated); the first
            entry is the input channel count.
        :param num_blocks_per_stage: List of number of blocks for stages 1-4.
            Defaults to ``[2, 8, 10, 1]``.
        :param width_multipliers: Required list of 4 width multipliers, one per
            stage 1-4.
        :param inference_mode: If True, instantiates model in inference mode.
        :param use_se: Whether to use SE-ReLU activations.
        :param depth: Number of stages whose outputs ``forward`` returns, in
            addition to the input itself.
        :param in_channels: Number of channels of the input.
        :param num_conv_branches: Number of linear conv branches.
        :raises ValueError: If ``width_multipliers`` is missing or does not
            contain exactly 4 values.
        """
        super().__init__()
        if num_blocks_per_stage is None:
            num_blocks_per_stage = [2, 8, 10, 1]
        # Explicit check instead of ``assert``: asserts vanish under
        # ``python -O``, and ``len(None)`` would be an opaque TypeError.
        if width_multipliers is None or len(width_multipliers) != 4:
            raise ValueError(
                "width_multipliers must contain exactly 4 values, got {!r}".format(
                    width_multipliers
                )
            )
        self.inference_mode = inference_mode
        self._out_channels = out_channels
        self.in_planes = min(64, int(64 * width_multipliers[0]))
        self.use_se = use_se
        self.num_conv_branches = num_conv_branches
        self._depth = depth
        self._in_channels = in_channels
        # The stem below is built directly with ``in_channels`` inputs, so there
        # is no 3-channel conv to patch. (Calling ``set_in_channels`` here
        # failed for in_channels != 3 because ``stage0`` did not exist yet.)
        # Only the reported channel count of the identity feature changes.
        if in_channels != 3:
            self._out_channels = tuple([in_channels] + list(out_channels)[1:])

        # Build stages
        self.stage0 = MobileOneBlock(
            in_channels=self._in_channels,
            out_channels=self.in_planes,
            kernel_size=3,
            stride=2,
            padding=1,
            inference_mode=self.inference_mode,
        )
        self.cur_layer_idx = 1
        self.stage1 = self._make_stage(
            int(64 * width_multipliers[0]), num_blocks_per_stage[0], num_se_blocks=0
        )
        self.stage2 = self._make_stage(
            int(128 * width_multipliers[1]), num_blocks_per_stage[1], num_se_blocks=0
        )
        self.stage3 = self._make_stage(
            int(256 * width_multipliers[2]),
            num_blocks_per_stage[2],
            num_se_blocks=int(num_blocks_per_stage[2] // 2) if use_se else 0,
        )
        self.stage4 = self._make_stage(
            int(512 * width_multipliers[3]),
            num_blocks_per_stage[3],
            num_se_blocks=num_blocks_per_stage[3] if use_se else 0,
        )

    def get_stages(self):
        """Return the stages in order; index 0 passes the input through."""
        return [
            nn.Identity(),
            self.stage0,
            self.stage1,
            self.stage2,
            self.stage3,
            self.stage4,
        ]

    def _make_stage(
        self, planes: int, num_blocks: int, num_se_blocks: int
    ) -> nn.Sequential:
        """Build a stage of MobileOne model.

        Each block is a depthwise kxk MobileOneBlock followed by a pointwise
        1x1 MobileOneBlock. Updates ``self.in_planes`` and
        ``self.cur_layer_idx`` as a side effect.

        :param planes: Number of output channels.
        :param num_blocks: Number of blocks in this stage.
        :param num_se_blocks: Number of SE blocks in this stage.
        :return: A stage of MobileOne model.
        :raises ValueError: If ``num_se_blocks`` exceeds ``num_blocks``.
        """
        # Validated once up front rather than on every loop iteration.
        if num_se_blocks > num_blocks:
            raise ValueError("Number of SE blocks cannot exceed number of layers.")

        # Only the first block of the stage uses stride 2.
        strides = [2] + [1] * (num_blocks - 1)
        blocks = []
        for ix, stride in enumerate(strides):
            # SE is enabled on the last ``num_se_blocks`` blocks of the stage.
            use_se = ix >= (num_blocks - num_se_blocks)

            # Depthwise conv
            blocks.append(
                MobileOneBlock(
                    in_channels=self.in_planes,
                    out_channels=self.in_planes,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    groups=self.in_planes,
                    inference_mode=self.inference_mode,
                    use_se=use_se,
                    num_conv_branches=self.num_conv_branches,
                )
            )
            # Pointwise conv
            blocks.append(
                MobileOneBlock(
                    in_channels=self.in_planes,
                    out_channels=planes,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    groups=1,
                    inference_mode=self.inference_mode,
                    use_se=use_se,
                    num_conv_branches=self.num_conv_branches,
                )
            )
            self.in_planes = planes
            self.cur_layer_idx += 1
        return nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass.

        :return: List of ``depth + 1`` feature maps; the first is the input.
        """
        stages = self.get_stages()
        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            features.append(x)
        return features

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Load weights, ignoring the classifier head (``linear.*``) if present.

        The caller's mapping is not modified. ``copy.copy`` keeps the
        ``_metadata`` attribute of an OrderedDict state dict. Extra positional
        and keyword arguments (e.g. ``strict``) are forwarded, and the
        base-class result is returned.
        """
        state_dict = copy.copy(state_dict)
        state_dict.pop("linear.weight", None)
        state_dict.pop("linear.bias", None)
        return super().load_state_dict(state_dict, *args, **kwargs)

    def set_in_channels(self, in_channels, pretrained=True):
        """Change first convolution channels.

        Patches the input convolution(s) of the stem block ``stage0`` so the
        encoder accepts ``in_channels`` inputs, and updates the reported channel
        count of the identity feature.

        :param in_channels: New number of input channels.
        :param pretrained: Forwarded to ``utils.patch_first_conv`` (presumably
            whether the existing stem weights are reused - confirm there).
        """
        if in_channels == 3:
            return
        self._in_channels = in_channels
        self._out_channels = tuple([in_channels] + list(self._out_channels)[1:])

        stem = self.stage0
        if stem.in_channels == in_channels:
            # Stem was already built for this many inputs; nothing to patch.
            return

        # Every stem branch reads the raw input, so each one must be patched
        # for the branch outputs to stay summable. A re-parameterized stem has
        # a single fused conv instead of the branch modules.
        if stem.inference_mode:
            to_patch = [stem.reparam_conv]
        else:
            to_patch = list(stem.rbr_conv)
            if stem.rbr_scale is not None:
                to_patch.append(stem.rbr_scale)
        for module in to_patch:
            utils.patch_first_conv(
                model=module,
                new_in_channels=in_channels,
                pretrained=pretrained,
            )
        stem.in_channels = in_channels
def reparameterize_model(model: torch.nn.Module) -> nn.Module:
    """Build an inference-ready copy of ``model``.

    Every submodule exposing a ``reparameterize`` method (e.g. MobileOneBlock)
    has it invoked on a deep copy, collapsing the train-time multi-branch
    structure into single convolutions. The argument itself is left untouched.

    :param model: MobileOne model in train mode.
    :return: MobileOne model in inference mode.
    """
    # Work on a deep copy so the caller's (trainable) graph is preserved.
    fused = copy.deepcopy(model)
    for submodule in fused.modules():
        if not hasattr(submodule, "reparameterize"):
            continue
        submodule.reparameterize()
    return fused
def _imagenet_pretrained_settings(url):
    """Return a fresh ``pretrained_settings`` mapping for an ImageNet checkpoint.

    A new dict (with new mean/std lists) is built on every call, so registry
    entries never share mutable objects.
    """
    return {
        "imagenet": {
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "url": url,
            "input_space": "RGB",
            "input_range": [0, 1],
        }
    }


# Registry of MobileOne variants: the encoder class, where to fetch the
# pretrained "unfused" ImageNet weights, and the constructor keyword arguments.
mobileone_encoders = {
    "mobileone_s0": {
        "encoder": MobileOne,
        "pretrained_settings": _imagenet_pretrained_settings(
            "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s0_unfused.pth.tar"  # noqa
        ),
        "params": {
            "out_channels": (3, 48, 48, 128, 256, 1024),
            "width_multipliers": (0.75, 1.0, 1.0, 2.0),
            "num_conv_branches": 4,
            "inference_mode": False,
        },
    },
    "mobileone_s1": {
        "encoder": MobileOne,
        "pretrained_settings": _imagenet_pretrained_settings(
            "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s1_unfused.pth.tar"  # noqa
        ),
        "params": {
            "out_channels": (3, 64, 96, 192, 512, 1280),
            "width_multipliers": (1.5, 1.5, 2.0, 2.5),
            "inference_mode": False,
        },
    },
    "mobileone_s2": {
        "encoder": MobileOne,
        "pretrained_settings": _imagenet_pretrained_settings(
            "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s2_unfused.pth.tar"  # noqa
        ),
        "params": {
            "out_channels": (3, 64, 96, 256, 640, 2048),
            "width_multipliers": (1.5, 2.0, 2.5, 4.0),
            "inference_mode": False,
        },
    },
    "mobileone_s3": {
        "encoder": MobileOne,
        "pretrained_settings": _imagenet_pretrained_settings(
            "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s3_unfused.pth.tar"  # noqa
        ),
        "params": {
            "out_channels": (3, 64, 128, 320, 768, 2048),
            "width_multipliers": (2.0, 2.5, 3.0, 4.0),
            "inference_mode": False,
        },
    },
    "mobileone_s4": {
        "encoder": MobileOne,
        "pretrained_settings": _imagenet_pretrained_settings(
            "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s4_unfused.pth.tar"  # noqa
        ),
        "params": {
            "out_channels": (3, 64, 192, 448, 896, 2048),
            "width_multipliers": (3.0, 3.5, 3.5, 4.0),
            "use_se": True,
            "inference_mode": False,
        },
    },
}