Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Tuple, Union | |
import torch.nn as nn | |
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer | |
from torch import Tensor | |
from mmseg.models.decode_heads.decode_head import BaseDecodeHead | |
from mmseg.models.losses import accuracy | |
from mmseg.models.utils import resize | |
from mmseg.registry import MODELS | |
from mmseg.utils import OptConfigType, SampleList | |
# Register the head so configs can build it with ``type='DDRHead'``.
@MODELS.register_module()
class DDRHead(BaseDecodeHead):
    """Decode head for DDRNet.

    The head has two branches, each built by :meth:`_make_base_head`:

    - a context branch (``head`` + ``cls_seg``) applied to the main feature
      map, used in both training and inference;
    - a spatial branch (``aux_head`` + ``aux_cls_seg``) applied to an
      auxiliary feature map with ``in_channels // 2`` channels, used only
      in training mode.

    ``loss_decode`` must be configured as a list of two losses. The first is
    applied to the context logits and the second to the spatial logits.

    Args:
        in_channels (int): Number of channels of the main input feature map.
        channels (int): Number of channels each branch produces before
            classification.
        num_classes (int): Number of classes.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        **kwargs: Forwarded unchanged to :class:`BaseDecodeHead`.
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 num_classes: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 **kwargs):
        # NOTE(review): the dict defaults are shared across instances; this
        # is only safe while nothing downstream mutates norm_cfg/act_cfg.
        super().__init__(
            in_channels,
            channels,
            num_classes=num_classes,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)

        # Context branch: consumes the full-width feature map.
        self.head = self._make_base_head(self.in_channels, self.channels)
        # Spatial (auxiliary) branch: consumes a map with half the channels.
        self.aux_head = self._make_base_head(self.in_channels // 2,
                                             self.channels)
        # The inherited ``cls_seg`` serves the context branch; the auxiliary
        # branch gets its own 1x1 classifier.
        self.aux_cls_seg = nn.Conv2d(
            self.channels, self.out_channels, kernel_size=1)

    def init_weights(self):
        """Initialize the head's conv and BatchNorm2d parameters.

        Every ``nn.Conv2d`` submodule (including ``aux_cls_seg``) gets
        Kaiming-normal weights in ``fan_out`` mode for ReLU. Every
        ``nn.BatchNorm2d`` submodule gets weight 1 and bias 0. Conv biases
        are left as they are.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            # NOTE(review): only nn.BatchNorm2d and its subclasses match
            # here; confirm whether norm layers built from other norm_cfg
            # types (e.g. SyncBN) are meant to be covered as well.
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(
            self,
            inputs: Union[Tensor,
                          Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
        """Compute segmentation logits.

        Args:
            inputs (Tensor | tuple[Tensor]): In training mode, a pair
                ``(c3_feat, c5_feat)``. ``c5_feat`` feeds the context
                branch and ``c3_feat`` feeds the spatial branch. In eval
                mode, a single feature map for the context branch.

        Returns:
            Tensor | tuple[Tensor]: ``(context_logits, spatial_logits)`` in
            training mode, otherwise only the context logits.
        """
        if self.training:
            c3_feat, c5_feat = inputs
            x_c = self.head(c5_feat)
            x_c = self.cls_seg(x_c)
            x_s = self.aux_head(c3_feat)
            x_s = self.aux_cls_seg(x_s)
            return x_c, x_s

        x_c = self.head(inputs)
        x_c = self.cls_seg(x_c)
        return x_c

    def _make_base_head(self, in_channels: int,
                        channels: int) -> nn.Sequential:
        """Build one branch body.

        Args:
            in_channels (int): Channels of the input feature map.
            channels (int): Channels produced by the 3x3 convolution.

        Returns:
            nn.Sequential: A ``ConvModule`` that applies norm, activation,
            then a 3x3 conv (padding 1). It is followed by a norm layer and
            an activation applied to the conv output.
        """
        layers = [
            ConvModule(
                in_channels,
                channels,
                kernel_size=3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                order=('norm', 'act', 'conv')),
            # build_norm_layer returns a (name, layer) pair; keep the layer.
            build_norm_layer(self.norm_cfg, channels)[1],
            build_activation_layer(self.act_cfg),
        ]
        return nn.Sequential(*layers)

    def loss_by_feat(self, seg_logits: Tuple[Tensor],
                     batch_data_samples: SampleList) -> dict:
        """Compute the context loss, spatial loss and accuracy.

        Args:
            seg_logits (tuple[Tensor]): ``(context_logit, spatial_logit)``
                as returned by :meth:`forward` in training mode.
            batch_data_samples (list): Data samples carrying the
                ground-truth segmentation maps.

        Returns:
            dict: ``loss_context`` (from ``loss_decode[0]``),
            ``loss_spatial`` (from ``loss_decode[1]``) and ``acc_seg``
            (accuracy of the context logits).
        """
        loss = dict()
        context_logit, spatial_logit = seg_logits
        seg_label = self._stack_batch_gt(batch_data_samples)

        # Upsample both predictions to the spatial size of the labels.
        context_logit = resize(
            context_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        spatial_logit = resize(
            spatial_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        # Presumably (N, 1, H, W) -> (N, H, W), the layout the losses and
        # accuracy() expect -- confirm against _stack_batch_gt.
        seg_label = seg_label.squeeze(1)

        # NOTE(review): ignore_index is forwarded only to accuracy(), not to
        # the losses; the configured losses presumably handle ignored pixels
        # themselves -- confirm against the loss config.
        loss['loss_context'] = self.loss_decode[0](context_logit, seg_label)
        loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label)
        loss['acc_seg'] = accuracy(
            context_logit, seg_label, ignore_index=self.ignore_index)
        return loss