# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/JunMa11/SegWithDistMap/blob/
master/code/train_LA_HD.py (Apache-2.0 License)"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage import distance_transform_edt as distance
from torch import Tensor

from mmseg.registry import MODELS
from .utils import get_class_weight, weighted_loss


def compute_dtm(img_gt: Tensor, pred: Tensor) -> Tensor:
    """Compute the per-class foreground distance transform map of a mask.

    For every sample ``b`` and every class channel ``c >= 1`` (channel 0 is
    treated as background and left at zero), the binary mask
    ``img_gt[b] == c`` is passed to ``scipy.ndimage.distance_transform_edt``.

    Args:
        img_gt (Tensor): Label map, (b, h, w). It is moved to the CPU before
            the distance transform because SciPy works on NumPy arrays.
        pred (Tensor): Predictions of the segmentation head after softmax,
            (b, c, h, w). Only its shape, dtype and device are used.

    Returns:
        Tensor: The foreground distance map, with the same shape, dtype and
        device as ``pred``::

            dtm[b, c](x) = 0                                 if img_gt[b](x) != c
            dtm[b, c](x) = min_{y: img_gt[b](y) != c} |x - y|  otherwise

        Channels of classes that do not occur in ``img_gt[b]`` stay zero.
    """
    num_batch, num_class = pred.shape[:2]
    # Fill on the CPU, where SciPy produces its result, and move to the
    # target device once at the end rather than once per (sample, class).
    fg_dtm = torch.zeros(pred.shape, dtype=pred.dtype)
    labels = img_gt.detach().cpu()
    for b in range(num_batch):
        for c in range(1, num_class):  # channel 0 is background
            # Per-class foreground mask: each channel gets its own map.
            posmask = (labels[b] == c).numpy()
            if posmask.any():
                fg_dtm[b, c] = torch.from_numpy(distance(posmask))
    return fg_dtm.to(pred.device)


@weighted_loss
def hd_loss(seg_soft: Tensor,
            gt: Tensor,
            seg_dtm: Tensor,
            gt_dtm: Tensor,
            class_weight=None,
            ignore_index=255) -> Tensor:
    """Compute the Hausdorff distance (HD) loss for segmentation.

    Each class channel ``i >= 1`` is processed unless ``i == ignore_index``;
    channel 0 (background) is skipped. For each such channel:

    1. Take the per-pixel squared error between the predicted probability
       and the one-hot ground truth ``gt == i``.
    2. Weight it by ``seg_dtm[:, i] ** 2 + gt_dtm[:, i] ** 2``.
    3. Average over all pixels.

    The per-class values, each optionally scaled by ``class_weight[i]``, are
    summed and divided by the total number of channels.

    Args:
        seg_soft (Tensor): Softmax probabilities, shape (b, c, *spatial).
        gt (Tensor): Ground-truth label map, shape (b, *spatial).
        seg_dtm (Tensor): Distance transform map of the predicted
            segmentation, shape (b, c, *spatial).
        gt_dtm (Tensor): Distance transform map of the ground truth,
            shape (b, c, *spatial).
        class_weight (Tensor | Sequence[float], optional): One weight per
            channel. Defaults to None.
        ignore_index (int | None): Class channel that is skipped.
            Defaults to 255.

    Returns:
        Tensor: Scalar HD loss (0 if no channel contributes).
    """
    assert seg_soft.shape[0] == gt.shape[0]
    num_class = seg_soft.shape[1]
    if class_weight is not None:
        # One weight per class channel. (``ndim`` of a 1-D weight tensor is
        # always 1, so it cannot be used for this check.)
        assert len(class_weight) == num_class
    # Start from a tensor so the result is a Tensor even if no class counts.
    total_loss = seg_soft.new_zeros(())
    for i in range(1, num_class):
        if i == ignore_index:
            continue
        # Compare the class-i probability with the one-hot target of class i.
        delta_s = (seg_soft[:, i, ...] - (gt == i).float())**2
        dtm = seg_dtm[:, i, ...]**2 + gt_dtm[:, i, ...]**2
        class_loss = (delta_s * dtm).mean()
        if class_weight is not None:
            class_loss = class_loss * class_weight[i]
        total_loss = total_loss + class_loss

    return total_loss / num_class


@MODELS.register_module()
class HuasdorffDisstanceLoss(nn.Module):
    """Hausdorff distance loss based on distance transform maps.

    This loss is proposed in `How Distance Transform Maps Boost Segmentation
    CNNs: An Empirical Study <http://proceedings.mlr.press/v121/ma20b.html>`_.
    The class name keeps its original spelling so that existing configs
    referring to it by name keep working.

    Args:
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        ignore_index (int | None): The label index to be ignored. Pixels with
            this label are relabelled as background (class 0) before the loss
            is computed; ``None`` disables relabelling. Default: 255.
        loss_name (str): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_huasdorff_disstance'.
    """

    def __init__(self,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 ignore_index=255,
                 loss_name='loss_huasdorff_disstance',
                 **kwargs):
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)
        self._loss_name = loss_name
        self.ignore_index = ignore_index

    def forward(self,
                pred: Tensor,
                target: Tensor,
                avg_factor=None,
                reduction_override=None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Raw (pre-softmax) predictions of the segmentation
                head, (B, C, H, W). Softmax over dim 1 is applied here.
            target (Tensor): Ground truth of the image. (B, H, W)
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Options are "none", "mean" and "sum".
        Returns:
            Tensor: Loss tensor, scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction
        class_weight = (
            pred.new_tensor(self.class_weight)
            if self.class_weight is not None else None)

        pred_soft = F.softmax(pred, dim=1)
        if self.ignore_index is not None:
            # Ignored pixels are relabelled as background (class 0) rather
            # than being excluded from the loss.
            valid_mask = (target != self.ignore_index).long()
            target = target * valid_mask

        # The distance maps act as fixed weights: no gradient through them.
        with torch.no_grad():
            gt_dtm = compute_dtm(target.cpu(), pred_soft).float()
            seg_dtm = compute_dtm(
                pred_soft.argmax(dim=1, keepdim=False).cpu(),
                pred_soft).float()

        loss_hd = self.loss_weight * hd_loss(
            pred_soft,
            target,
            seg_dtm=seg_dtm,
            gt_dtm=gt_dtm,
            reduction=reduction,
            avg_factor=avg_factor,
            class_weight=class_weight,
            ignore_index=self.ignore_index)
        return loss_hd

    @property
    def loss_name(self):
        """Name of the loss item (``loss_name`` passed at construction)."""
        return self._loss_name