Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional | |
import torch.nn as nn | |
from mmcv.cnn import ConvModule | |
from mmengine.model import BaseModule | |
from torch import Tensor | |
from mmseg.registry import MODELS | |
from mmseg.utils import OptConfigType | |
class BasicBlock(BaseModule):
    """Basic block from `ResNet <https://arxiv.org/abs/1512.03385>`_.

    The input goes through two 3x3 ``ConvModule`` layers, the result is
    summed with the identity branch (passed through ``downsample`` when one
    is given), and an optional output activation is applied.

    Args:
        in_channels (int): Input channels.
        channels (int): Output channels.
        stride (int): Stride of the first 3x3 convolution (``conv1``).
            Default: 1.
        downsample (nn.Module, optional): Downsample operation on identity.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU', inplace=True).
        act_cfg_out (dict, optional): Config dict for activation layer at the
            last of the block. A falsy value (e.g. None) disables the output
            activation. Default: dict(type='ReLU', inplace=True).
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    expansion = 1

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 act_cfg_out: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg)
        self.conv1 = ConvModule(
            in_channels,
            channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # No activation here: the output activation (if any) is applied
        # after the residual sum in ``forward``.
        self.conv2 = ConvModule(
            channels,
            channels,
            kernel_size=3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.downsample = downsample
        # ``self.act`` is only created when an output activation is
        # requested; ``forward`` checks for its presence with ``hasattr``.
        if act_cfg_out:
            self.act = MODELS.build(act_cfg_out)

    def forward(self, x: Tensor) -> Tensor:
        """Run the block on ``x`` and return the activated residual sum.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: ``conv2(conv1(x))`` plus the identity branch, passed
            through the output activation when one was configured.
        """
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        # Compare against None explicitly: container modules such as
        # nn.Sequential define __len__, so plain truthiness would depend on
        # the number of children rather than on whether a module was given.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        # ``act`` only exists when ``act_cfg_out`` was truthy in __init__.
        if hasattr(self, 'act'):
            out = self.act(out)
        return out
class Bottleneck(BaseModule):
    """Bottleneck block from `ResNet <https://arxiv.org/abs/1512.03385>`_.

    The input goes through a 1x1, a 3x3 and a 1x1 ``ConvModule``; the result
    is summed with the identity branch (passed through ``downsample`` when
    one is given), and an optional output activation is applied.

    Args:
        in_channels (int): Input channels.
        channels (int): Channels of the two inner convolutions. The block
            outputs ``channels * expansion`` channels.
        stride (int): Stride of the 3x3 convolution (``conv2``). Default: 1.
        downsample (nn.Module, optional): Downsample operation on identity.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU', inplace=True).
        act_cfg_out (dict, optional): Config dict for activation layer at
            the last of the block. A falsy value (e.g. None) disables the
            output activation. Default: None.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    expansion = 2

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 act_cfg_out: OptConfigType = None,
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg)
        self.conv1 = ConvModule(
            in_channels,
            channels,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.conv2 = ConvModule(
            channels,
            channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # No activation here: the output activation (if any) is applied
        # after the residual sum in ``forward``.
        self.conv3 = ConvModule(
            channels,
            channels * self.expansion,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=None)
        # ``self.act`` is only created when an output activation is
        # requested; ``forward`` checks for its presence with ``hasattr``.
        if act_cfg_out:
            self.act = MODELS.build(act_cfg_out)
        self.downsample = downsample

    def forward(self, x: Tensor) -> Tensor:
        """Run the block on ``x`` and return the residual sum.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: ``conv3(conv2(conv1(x)))`` plus the identity branch,
            passed through the output activation when one was configured.
        """
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        # Compare against None explicitly: container modules such as
        # nn.Sequential define __len__, so plain truthiness would depend on
        # the number of children rather than on whether a module was given.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        # ``act`` only exists when ``act_cfg_out`` was truthy in __init__.
        if hasattr(self, 'act'):
            out = self.act(out)
        return out