# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead

try:
    from mmcv.ops import PSAMask
except ModuleNotFoundError:
    PSAMask = None


@MODELS.register_module()
class PSAHead(BaseDecodeHead):
    """Point-wise Spatial Attention Network for Scene Parsing.

    This head is the implementation of `PSANet
    <https://hszhao.github.io/papers/eccv18_psanet.pdf>`_.

    Args:
        mask_size (tuple[int]): The PSA mask size as ``(mask_h, mask_w)``.
            It usually equals input size.
        psa_type (str): The type of psa module. Options are 'collect',
            'distribute', 'bi-direction'. Default: 'bi-direction'
        compact (bool): Whether use compact map for 'collect' mode.
            When False the attention is passed through ``PSAMask``.
            Default: False.
        shrink_factor (int): The downsample factors of psa mask. A value
            of 1 disables the downsampling. Default: 2.
        normalization_factor (float | None): The normalize factor of
            attention; the aggregated features are multiplied by
            ``1.0 / normalization_factor``. If None, ``mask_h * mask_w``
            is used. Default: 1.0.
        psa_softmax (bool): Whether use softmax for attention.
            Default: True.

    Raises:
        RuntimeError: If ``mmcv.ops.PSAMask`` could not be imported.
    """

    def __init__(self,
                 mask_size,
                 psa_type='bi-direction',
                 compact=False,
                 shrink_factor=2,
                 normalization_factor=1.0,
                 psa_softmax=True,
                 **kwargs):
        if PSAMask is None:
            raise RuntimeError('Please install mmcv-full for PSAMask ops')
        super().__init__(**kwargs)
        assert psa_type in ['collect', 'distribute', 'bi-direction']
        self.psa_type = psa_type
        self.compact = compact
        self.shrink_factor = shrink_factor
        self.mask_size = mask_size
        mask_h, mask_w = mask_size
        self.psa_softmax = psa_softmax
        if normalization_factor is None:
            normalization_factor = mask_h * mask_w
        self.normalization_factor = normalization_factor

        bi_direction = psa_type == 'bi-direction'
        # Config shared by every ConvModule of this head.
        cfgs = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        def build_attention():
            # 1x1 ConvModule followed by a bias-free 1x1 conv that emits
            # one channel per mask position (mask_h * mask_w).
            return nn.Sequential(
                ConvModule(
                    self.channels, self.channels, kernel_size=1, **cfgs),
                nn.Conv2d(
                    self.channels,
                    mask_h * mask_w,
                    kernel_size=1,
                    bias=False))

        # Sub-modules are created in the same order as before so that
        # parameter initialisation and state_dict layout are unchanged.
        self.reduce = ConvModule(
            self.in_channels, self.channels, kernel_size=1, **cfgs)
        self.attention = build_attention()
        if bi_direction:
            # Second ("_p") branch, used for the distribute direction in
            # forward().
            self.reduce_p = ConvModule(
                self.in_channels, self.channels, kernel_size=1, **cfgs)
            self.attention_p = build_attention()
            self.psamask_collect = PSAMask('collect', mask_size)
            self.psamask_distribute = PSAMask('distribute', mask_size)
        else:
            self.psamask = PSAMask(psa_type, mask_size)
        # NOTE(review): padding=1 together with kernel_size=1 presumably
        # enlarges the projected map by one pixel on every side; forward()
        # resizes the result back to the input size afterwards. Kept as is
        # to preserve existing behaviour -- confirm it is intentional.
        self.proj = ConvModule(
            self.channels * (2 if bi_direction else 1),
            self.in_channels,
            kernel_size=1,
            padding=1,
            **cfgs)
        self.bottleneck = ConvModule(
            self.in_channels * 2,
            self.channels,
            kernel_size=3,
            padding=1,
            **cfgs)

    def _shrunk_size(self, h, w):
        """Return ``(h, w, align_corners)`` after applying shrink_factor.

        Only meant to be called when ``self.shrink_factor != 1``.
        """
        factor = self.shrink_factor
        # NOTE(review): the rounded-up size is used only when BOTH h and w
        # leave a remainder; a single non-divisible side falls through to
        # floor division. Preserved from the original code.
        if h % factor and w % factor:
            return (h - 1) // factor + 1, (w - 1) // factor + 1, True
        return h // factor, w // factor, False

    @staticmethod
    def _transpose_compact(attn, n, h, w):
        """Swap the two ``h * w`` axes of a compact attention map."""
        return attn.view(n, h * w, h * w).transpose(1, 2).view(
            n, h * w, h, w)

    def _aggregate(self, feat, attn):
        """Aggregate ``feat`` with ``attn`` by a batched matrix product.

        ``feat`` is (n, c, h, w) and ``attn`` must be viewable as
        (n, h * w, h * w). The result is (n, c, h, w), scaled by
        ``1.0 / self.normalization_factor``.
        """
        n, c, h, w = feat.size()
        fused = torch.bmm(
            feat.view(n, c, h * w), attn.view(n, h * w, h * w))
        return fused.view(n, c, h, w) * (1.0 / self.normalization_factor)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        identity = x
        # May be overridden below when the features are shrunk; the final
        # resize back to the identity size uses the overridden value.
        align_corners = self.align_corners
        if self.psa_type in ['collect', 'distribute']:
            out = self.reduce(x)
            n, c, h, w = out.size()
            if self.shrink_factor != 1:
                h, w, align_corners = self._shrunk_size(h, w)
                out = resize(
                    out,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
            y = self.attention(out)
            if self.compact:
                # In single-direction compact mode only 'collect' is
                # transposed; 'distribute' uses the map as produced.
                if self.psa_type == 'collect':
                    y = self._transpose_compact(y, n, h, w)
            else:
                y = self.psamask(y)
            if self.psa_softmax:
                y = F.softmax(y, dim=1)
            out = self._aggregate(out, y)
        else:
            x_col = self.reduce(x)
            x_dis = self.reduce_p(x)
            n, c, h, w = x_col.size()
            if self.shrink_factor != 1:
                h, w, align_corners = self._shrunk_size(h, w)
                x_col = resize(
                    x_col,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
                x_dis = resize(
                    x_dis,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
            y_col = self.attention(x_col)
            y_dis = self.attention_p(x_dis)
            if self.compact:
                # In bi-direction compact mode only the distribute map is
                # transposed; the collect map is used as produced.
                y_dis = self._transpose_compact(y_dis, n, h, w)
            else:
                y_col = self.psamask_collect(y_col)
                y_dis = self.psamask_distribute(y_dis)
            if self.psa_softmax:
                y_col = F.softmax(y_col, dim=1)
                y_dis = F.softmax(y_dis, dim=1)
            x_col = self._aggregate(x_col, y_col)
            x_dis = self._aggregate(x_dis, y_dis)
            out = torch.cat([x_col, x_dis], 1)
        out = self.proj(out)
        out = resize(
            out,
            size=identity.shape[2:],
            mode='bilinear',
            align_corners=align_corners)
        # Fuse the attention output with the untouched input features.
        out = self.bottleneck(torch.cat((identity, out), dim=1))
        out = self.cls_seg(out)
        return out