Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from mmseg.models.losses.utils import get_class_weight
from mmseg.registry import MODELS
@MODELS.register_module()
class OhemCrossEntropy(nn.Module):
    """OhemCrossEntropy loss.

    This func is modified from
    `PIDNet <https://github.com/XuJiacong/PIDNet/blob/main/utils/criterion.py#L43>`_.  # noqa

    Licensed under the MIT License.

    Per-pixel cross-entropy is averaged only over "hard" pixels: valid
    (non-ignored) pixels whose softmax probability for the ground-truth class
    is below a threshold. The threshold is the larger of ``thres`` and the
    ground-truth probability of the ``min_kept``-th least confident valid
    pixel (clamped to the number of valid pixels).

    Args:
        ignore_label (int): Labels to ignore when computing the loss.
            Default: 255
        thres (float, optional): The threshold for hard example selection.
            Pixels whose ground-truth probability is below it are treated as
            hard examples. If None, only the ``min_kept``-based threshold is
            used. Default: 0.7.
        min_kept (int): The minimum number of predictions to keep. Values
            below 1 are clamped to 1. Default: 100000.
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        class_weight (list[float] | str, optional): Weight of each class. If
            in str format, it is treated as a file path and the weights are
            read from that file. Defaults to None.
        loss_name (str): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_ohem'.
    """

    def __init__(self,
                 ignore_label: int = 255,
                 thres: Optional[float] = 0.7,
                 min_kept: int = 100000,
                 loss_weight: float = 1.0,
                 class_weight: Optional[Union[List[float], str]] = None,
                 loss_name: str = 'loss_ohem'):
        super().__init__()
        self.thresh = thres
        self.min_kept = max(1, min_kept)
        self.ignore_label = ignore_label
        self.loss_weight = loss_weight
        self.loss_name_ = loss_name
        # A str is a path to a weight file; load it once here so forward()
        # can always turn self.class_weight into a tensor.
        if isinstance(class_weight, str):
            class_weight = get_class_weight(class_weight)
        self.class_weight = class_weight

    def forward(self, score: Tensor, target: Tensor) -> Tensor:
        """Forward function.

        Args:
            score (Tensor): Predictions (logits) of the segmentation head,
                of shape (N, C, H, W).
            target (Tensor): Ground truth of the image, of shape (N, H, W).

        Returns:
            Tensor: Scalar loss tensor. If every pixel is ignored, a zero
            that is still attached to the autograd graph.
        """
        pred = F.softmax(score, dim=1)
        if self.class_weight is not None:
            class_weight = score.new_tensor(self.class_weight)
        else:
            class_weight = None

        pixel_losses = F.cross_entropy(
            score,
            target,
            weight=class_weight,
            ignore_index=self.ignore_label,
            reduction='none').contiguous().view(-1)  # (N*H*W)
        mask = target.contiguous().view(-1) != self.ignore_label  # (N*H*W)

        if not mask.any():
            # Every pixel is ignored: there is nothing to mine. Ignored pixels
            # contribute 0 with reduction='none', so this is a zero that keeps
            # the graph connected (backward()/DDP still work).
            return pixel_losses.sum() * 0.0

        # Map ignored labels to a valid class index so gather() stays in
        # range; those pixels are dropped again through `mask`.
        tmp_target = target.clone()  # (N, H, W)
        tmp_target[tmp_target == self.ignore_label] = 0
        # Ground-truth class probability: (N, C, H, W) -> (N, 1, H, W)
        gt_prob = pred.gather(1, tmp_target.unsqueeze(1))
        # Flatten, keep valid pixels, sort ascending (least confident first):
        # (N, 1, H, W) -> (M,), M = number of valid pixels; ind: (M,)
        gt_prob, ind = gt_prob.contiguous().view(-1)[mask].contiguous().sort()

        min_value = gt_prob[min(self.min_kept, gt_prob.numel() - 1)]
        if self.thresh is None:
            threshold = min_value
        else:
            threshold = max(min_value, self.thresh)

        # Reorder losses to match the sorted probabilities.
        pixel_losses = pixel_losses[mask][ind]
        hard = gt_prob < threshold
        if not hard.any():
            # The least confident pixels all tie at the threshold (e.g. the
            # softmax saturated to exactly 1.0). A strict comparison selects
            # nothing and mean() of an empty tensor is NaN, so include ties.
            hard = gt_prob <= threshold
        return self.loss_weight * pixel_losses[hard].mean()

    @property
    def loss_name(self):
        """str: Name of this loss item, read as an attribute.

        If you want this loss item to be included into the backward graph,
        `loss_` must be the prefix of the name.
        """
        return self.loss_name_