Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Union | |
import torch | |
import torch.nn as nn | |
from mmseg.registry import MODELS | |
from .utils import weight_reduce_loss | |
def _expand_onehot_labels_dice(pred: torch.Tensor, | |
target: torch.Tensor) -> torch.Tensor: | |
"""Expand onehot labels to match the size of prediction. | |
Args: | |
pred (torch.Tensor): The prediction, has a shape (N, num_class, H, W). | |
target (torch.Tensor): The learning label of the prediction, | |
has a shape (N, H, W). | |
Returns: | |
torch.Tensor: The target after one-hot encoding, | |
has a shape (N, num_class, H, W). | |
""" | |
num_classes = pred.shape[1] | |
one_hot_target = torch.clamp(target, min=0, max=num_classes) | |
one_hot_target = torch.nn.functional.one_hot(one_hot_target, | |
num_classes + 1) | |
one_hot_target = one_hot_target[..., :num_classes].permute(0, 3, 1, 2) | |
return one_hot_target | |
def dice_loss(pred: torch.Tensor, | |
target: torch.Tensor, | |
weight: Union[torch.Tensor, None], | |
eps: float = 1e-3, | |
reduction: Union[str, None] = 'mean', | |
naive_dice: Union[bool, None] = False, | |
avg_factor: Union[int, None] = None, | |
ignore_index: Union[int, None] = 255) -> float: | |
"""Calculate dice loss, there are two forms of dice loss is supported: | |
- the one proposed in `V-Net: Fully Convolutional Neural | |
Networks for Volumetric Medical Image Segmentation | |
<https://arxiv.org/abs/1606.04797>`_. | |
- the dice loss in which the power of the number in the | |
denominator is the first power instead of the second | |
power. | |
Args: | |
pred (torch.Tensor): The prediction, has a shape (n, *) | |
target (torch.Tensor): The learning label of the prediction, | |
shape (n, *), same shape of pred. | |
weight (torch.Tensor, optional): The weight of loss for each | |
prediction, has a shape (n,). Defaults to None. | |
eps (float): Avoid dividing by zero. Default: 1e-3. | |
reduction (str, optional): The method used to reduce the loss into | |
a scalar. Defaults to 'mean'. | |
Options are "none", "mean" and "sum". | |
naive_dice (bool, optional): If false, use the dice | |
loss defined in the V-Net paper, otherwise, use the | |
naive dice loss in which the power of the number in the | |
denominator is the first power instead of the second | |
power.Defaults to False. | |
avg_factor (int, optional): Average factor that is used to average | |
the loss. Defaults to None. | |
ignore_index (int, optional): The label index to be ignored. | |
Defaults to 255. | |
""" | |
if ignore_index is not None: | |
num_classes = pred.shape[1] | |
pred = pred[:, torch.arange(num_classes) != ignore_index, :, :] | |
target = target[:, torch.arange(num_classes) != ignore_index, :, :] | |
assert pred.shape[1] != 0 # if the ignored index is the only class | |
input = pred.flatten(1) | |
target = target.flatten(1).float() | |
a = torch.sum(input * target, 1) | |
if naive_dice: | |
b = torch.sum(input, 1) | |
c = torch.sum(target, 1) | |
d = (2 * a + eps) / (b + c + eps) | |
else: | |
b = torch.sum(input * input, 1) + eps | |
c = torch.sum(target * target, 1) + eps | |
d = (2 * a) / (b + c) | |
loss = 1 - d | |
if weight is not None: | |
assert weight.ndim == loss.ndim | |
assert len(weight) == len(pred) | |
loss = weight_reduce_loss(loss, weight, reduction, avg_factor) | |
return loss | |
class DiceLoss(nn.Module): | |
def __init__(self, | |
use_sigmoid=True, | |
activate=True, | |
reduction='mean', | |
naive_dice=False, | |
loss_weight=1.0, | |
ignore_index=255, | |
eps=1e-3, | |
loss_name='loss_dice'): | |
"""Compute dice loss. | |
Args: | |
use_sigmoid (bool, optional): Whether to the prediction is | |
used for sigmoid or softmax. Defaults to True. | |
activate (bool): Whether to activate the predictions inside, | |
this will disable the inside sigmoid operation. | |
Defaults to True. | |
reduction (str, optional): The method used | |
to reduce the loss. Options are "none", | |
"mean" and "sum". Defaults to 'mean'. | |
naive_dice (bool, optional): If false, use the dice | |
loss defined in the V-Net paper, otherwise, use the | |
naive dice loss in which the power of the number in the | |
denominator is the first power instead of the second | |
power. Defaults to False. | |
loss_weight (float, optional): Weight of loss. Defaults to 1.0. | |
ignore_index (int, optional): The label index to be ignored. | |
Default: 255. | |
eps (float): Avoid dividing by zero. Defaults to 1e-3. | |
loss_name (str, optional): Name of the loss item. If you want this | |
loss item to be included into the backward graph, `loss_` must | |
be the prefix of the name. Defaults to 'loss_dice'. | |
""" | |
super().__init__() | |
self.use_sigmoid = use_sigmoid | |
self.reduction = reduction | |
self.naive_dice = naive_dice | |
self.loss_weight = loss_weight | |
self.eps = eps | |
self.activate = activate | |
self.ignore_index = ignore_index | |
self._loss_name = loss_name | |
def forward(self, | |
pred, | |
target, | |
weight=None, | |
avg_factor=None, | |
reduction_override=None, | |
ignore_index=255, | |
**kwargs): | |
"""Forward function. | |
Args: | |
pred (torch.Tensor): The prediction, has a shape (n, *). | |
target (torch.Tensor): The label of the prediction, | |
shape (n, *), same shape of pred. | |
weight (torch.Tensor, optional): The weight of loss for each | |
prediction, has a shape (n,). Defaults to None. | |
avg_factor (int, optional): Average factor that is used to average | |
the loss. Defaults to None. | |
reduction_override (str, optional): The reduction method used to | |
override the original reduction method of the loss. | |
Options are "none", "mean" and "sum". | |
Returns: | |
torch.Tensor: The calculated loss | |
""" | |
one_hot_target = target | |
if (pred.shape != target.shape): | |
one_hot_target = _expand_onehot_labels_dice(pred, target) | |
assert reduction_override in (None, 'none', 'mean', 'sum') | |
reduction = ( | |
reduction_override if reduction_override else self.reduction) | |
if self.activate: | |
if self.use_sigmoid: | |
pred = pred.sigmoid() | |
elif pred.shape[1] != 1: | |
# softmax does not work when there is only 1 class | |
pred = pred.softmax(dim=1) | |
loss = self.loss_weight * dice_loss( | |
pred, | |
one_hot_target, | |
weight, | |
eps=self.eps, | |
reduction=reduction, | |
naive_dice=self.naive_dice, | |
avg_factor=avg_factor, | |
ignore_index=self.ignore_index) | |
return loss | |
def loss_name(self): | |
"""Loss Name. | |
This function must be implemented and will return the name of this | |
loss function. This name will be used to combine different loss items | |
by simple sum operation. In addition, if you want this loss item to be | |
included into the backward graph, `loss_` must be the prefix of the | |
name. | |
Returns: | |
str: The name of this loss item. | |
""" | |
return self._loss_name | |