# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead

# `mmcv.ops` may not be importable in the installed mmcv build. Fall back to
# None so that importing this module still succeeds; PSAHead.__init__ checks
# for None and raises a RuntimeError pointing the user to mmcv-full.
try:
    from mmcv.ops import PSAMask
except ModuleNotFoundError:
    PSAMask = None


@MODELS.register_module()
class PSAHead(BaseDecodeHead):
    """Point-wise Spatial Attention Network for Scene Parsing.

    This head is the implementation of `PSANet
    <https://hszhao.github.io/papers/eccv18_psanet.pdf>`_.

    Args:
        mask_size (tuple[int]): The PSA mask size as ``(mask_h, mask_w)``.
            It usually equals input size.
        psa_type (str): The type of psa module. Options are 'collect',
            'distribute', 'bi-direction'. Default: 'bi-direction'
        compact (bool): Whether to use the compact attention map. If True,
            the PSAMask op is skipped and the attention output is used
            directly (transposed for the 'collect' type, or for the
            distribute branch of 'bi-direction'). Default: False.
        shrink_factor (int): The downsample factors of psa mask. Default: 2.
        normalization_factor (float | None): The normalize factor of
            attention; the aggregated features are multiplied by
            ``1 / normalization_factor``. If None, ``mask_h * mask_w`` is
            used. Default: 1.0.
        psa_softmax (bool): Whether use softmax for attention.
            Default: True.
        **kwargs: Forwarded to ``BaseDecodeHead.__init__``.

    Raises:
        RuntimeError: If the ``PSAMask`` op could not be imported from
            ``mmcv.ops``.
    """

    def __init__(self,
                 mask_size,
                 psa_type='bi-direction',
                 compact=False,
                 shrink_factor=2,
                 normalization_factor=1.0,
                 psa_softmax=True,
                 **kwargs):
        if PSAMask is None:
            raise RuntimeError('Please install mmcv-full for PSAMask ops')
        super().__init__(**kwargs)
        assert psa_type in ['collect', 'distribute', 'bi-direction']
        self.psa_type = psa_type
        self.compact = compact
        self.shrink_factor = shrink_factor
        self.mask_size = mask_size
        mask_h, mask_w = mask_size
        self.psa_softmax = psa_softmax
        if normalization_factor is None:
            normalization_factor = mask_h * mask_w
        self.normalization_factor = normalization_factor

        # 1x1 channel reduction followed by the attention generator, which
        # emits one channel per mask position (mask_h * mask_w).
        self.reduce = ConvModule(
            self.in_channels,
            self.channels,
            kernel_size=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.attention = nn.Sequential(
            ConvModule(
                self.channels,
                self.channels,
                kernel_size=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            nn.Conv2d(
                self.channels, mask_h * mask_w, kernel_size=1, bias=False))
        if psa_type == 'bi-direction':
            # The bi-direction variant runs a second, parallel
            # reduce/attention pair for the distribute branch.
            self.reduce_p = ConvModule(
                self.in_channels,
                self.channels,
                kernel_size=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            self.attention_p = nn.Sequential(
                ConvModule(
                    self.channels,
                    self.channels,
                    kernel_size=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg),
                nn.Conv2d(
                    self.channels, mask_h * mask_w, kernel_size=1, bias=False))
            self.psamask_collect = PSAMask('collect', mask_size)
            self.psamask_distribute = PSAMask('distribute', mask_size)
        else:
            self.psamask = PSAMask(psa_type, mask_size)
        # 1x1 projection back to in_channels. No padding: padding a 1x1
        # conv would enlarge the map by a ring of values computed from zero
        # input, which the later resize would blend into the real features.
        self.proj = ConvModule(
            self.channels * (2 if psa_type == 'bi-direction' else 1),
            self.in_channels,
            kernel_size=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # Fuses the concatenation of the identity features and the projected
        # attention features (hence in_channels * 2).
        self.bottleneck = ConvModule(
            self.in_channels * 2,
            self.channels,
            kernel_size=3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _psa_shrunk_size(self, h, w):
        """Return ``(h, w, align_corners)`` for the shrunk attention map.

        When neither ``h`` nor ``w`` is divisible by ``shrink_factor`` the
        size is rounded up and ``align_corners`` is True; otherwise the size
        is floor-divided and ``align_corners`` is False.
        """
        factor = self.shrink_factor
        if h % factor and w % factor:
            return (h - 1) // factor + 1, (w - 1) // factor + 1, True
        return h // factor, w // factor, False

    def _psa_aggregate(self, feat, attn, n, c, h, w):
        """Aggregate ``feat`` with attention ``attn`` via a batched matmul.

        ``feat`` is viewed as ``(n, c, h*w)`` and ``attn`` as
        ``(n, h*w, h*w)``; the product is reshaped to ``(n, c, h, w)`` and
        scaled by ``1 / normalization_factor``.
        """
        aggregated = torch.bmm(
            feat.view(n, c, h * w), attn.view(n, h * w, h * w))
        return aggregated.view(n, c, h, w) * (1.0 / self.normalization_factor)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Backbone features, passed through ``_transform_inputs``
                to obtain a 4-D ``(n, c, h, w)`` feature map.

        Returns:
            The output of ``cls_seg`` applied to the fused features
            (presumably per-class segmentation logits; defined by the base
            class).
        """
        x = self._transform_inputs(inputs)
        identity = x
        # Overwritten below when the features are shrunk; the same value is
        # reused for the final upsampling back to the identity size.
        align_corners = self.align_corners
        if self.psa_type in ['collect', 'distribute']:
            out = self.reduce(x)
            n, c, h, w = out.size()
            if self.shrink_factor != 1:
                h, w, align_corners = self._psa_shrunk_size(h, w)
                out = resize(
                    out,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
            y = self.attention(out)
            if self.compact:
                # Compact mode skips PSAMask; only 'collect' is transposed.
                if self.psa_type == 'collect':
                    y = y.view(n, h * w,
                               h * w).transpose(1, 2).view(n, h * w, h, w)
            else:
                y = self.psamask(y)
            if self.psa_softmax:
                y = F.softmax(y, dim=1)
            out = self._psa_aggregate(out, y, n, c, h, w)
        else:
            x_col = self.reduce(x)
            x_dis = self.reduce_p(x)
            n, c, h, w = x_col.size()
            if self.shrink_factor != 1:
                h, w, align_corners = self._psa_shrunk_size(h, w)
                x_col = resize(
                    x_col,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
                x_dis = resize(
                    x_dis,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
            y_col = self.attention(x_col)
            y_dis = self.attention_p(x_dis)
            if self.compact:
                # Compact mode skips PSAMask; only the distribute attention
                # is transposed.
                y_dis = y_dis.view(n, h * w,
                                   h * w).transpose(1, 2).view(n, h * w, h, w)
            else:
                y_col = self.psamask_collect(y_col)
                y_dis = self.psamask_distribute(y_dis)
            if self.psa_softmax:
                y_col = F.softmax(y_col, dim=1)
                y_dis = F.softmax(y_dis, dim=1)
            x_col = self._psa_aggregate(x_col, y_col, n, c, h, w)
            x_dis = self._psa_aggregate(x_dis, y_dis, n, c, h, w)
            out = torch.cat([x_col, x_dis], 1)
        out = self.proj(out)
        # Upsample the attention features back to the input resolution so
        # they can be concatenated with the identity branch.
        out = resize(
            out,
            size=identity.shape[2:],
            mode='bilinear',
            align_corners=align_corners)
        out = self.bottleneck(torch.cat((identity, out), dim=1))
        out = self.cls_seg(out)
        return out