# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/JunMa11/SegWithDistMap/blob/
master/code/train_LA_HD.py (Apache-2.0 License)"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage import distance_transform_edt as distance
from torch import Tensor
from mmseg.registry import MODELS
from .utils import get_class_weight, weighted_loss


def compute_dtm(img_gt: Tensor, pred: Tensor) -> Tensor:
    """Compute the distance transform map of the foreground in the mask.

    Args:
        img_gt: Ground truth of the image, shape (b, h, w).
        pred: Predictions of the segmentation head after softmax,
            shape (b, c, h, w).

    Returns:
        output: the foreground Distance Transform Map (DTM)
            dtm(x) = 0                          if x is on the segmentation boundary
                     inf_{y in boundary} |x-y|  if x is inside the segmentation
    """
    fg_dtm = torch.zeros_like(pred)
    out_shape = pred.shape
    for b in range(out_shape[0]):  # batch size
        for c in range(1, out_shape[1]):  # channel 0 is background by default
            posmask = img_gt[b].byte()
            if posmask.any():
                # Euclidean distance transform of the foreground mask.
                posdis = distance(posmask)
                fg_dtm[b][c] = torch.from_numpy(posdis)
    return fg_dtm
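

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how ``compute_dtm``
# turns a binary ground-truth mask into a per-class distance map.  The helper
# name ``_demo_compute_dtm`` and the toy shapes/values below are assumptions
# made for demonstration only.
# ---------------------------------------------------------------------------
def _demo_compute_dtm() -> Tensor:
    gt = torch.zeros(1, 8, 8, dtype=torch.long)
    gt[0, 2:6, 2:6] = 1  # a square foreground object
    # Fake 2-class (background + foreground) softmax output.
    pred_soft = torch.rand(1, 2, 8, 8).softmax(dim=1)
    dtm = compute_dtm(gt, pred_soft)
    # Channel 0 (background) stays zero; channel 1 holds, for every
    # foreground pixel, the Euclidean distance to the nearest background
    # pixel.
    assert dtm.shape == pred_soft.shape
    return dtm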


@weighted_loss
def hd_loss(seg_soft: Tensor,
            gt: Tensor,
            seg_dtm: Tensor,
            gt_dtm: Tensor,
            class_weight=None,
            ignore_index=255) -> Tensor:
    """Compute the Hausdorff distance loss for segmentation.

    Args:
        seg_soft: softmax results, shape (b, c, x, y).
        gt: ground truth, shape (b, x, y).
        seg_dtm: segmentation distance transform map, shape (b, c, x, y).
        gt_dtm: ground truth distance transform map, shape (b, c, x, y).

    Returns:
        output: hd_loss.
    """
    assert seg_soft.shape[0] == gt.shape[0]
    total_loss = 0
    num_class = seg_soft.shape[1]
    if class_weight is not None:
        assert class_weight.shape[0] == num_class
    for i in range(1, num_class):  # skip the background channel
        if i != ignore_index:
            # Squared error between the soft prediction of class ``i`` and
            # the ground-truth labels (treated as a binary foreground mask),
            # weighted by the squared distance transform maps.
            delta_s = (seg_soft[:, i, ...] - gt.float())**2
            s_dtm = seg_dtm[:, i, ...]**2
            g_dtm = gt_dtm[:, i, ...]**2
            dtm = s_dtm + g_dtm
            multiplied = torch.einsum('bxy, bxy->bxy', delta_s, dtm)
            class_loss = multiplied.mean()
            if class_weight is not None:
                class_loss *= class_weight[i]
            total_loss += class_loss
    return total_loss / num_class
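

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): wiring
# ``compute_dtm`` and ``hd_loss`` together by hand.  The helper name
# ``_demo_hd_loss`` and the toy shapes are assumptions for demonstration
# only; inside the library the same wiring happens in
# ``HuasdorffDisstanceLoss.forward``.
# ---------------------------------------------------------------------------
def _demo_hd_loss() -> Tensor:
    gt = torch.zeros(1, 8, 8, dtype=torch.long)
    gt[0, 2:6, 2:6] = 1  # binary ground truth: one foreground square
    seg_soft = torch.randn(1, 2, 8, 8).softmax(dim=1)
    # Distance maps of the ground truth and of the hard (argmax) prediction.
    gt_dtm = compute_dtm(gt, seg_soft).float()
    seg_dtm = compute_dtm(seg_soft.argmax(dim=1), seg_soft).float()
    # Per-pixel term: (s_i - g)^2 * (dtm_seg_i^2 + dtm_gt_i^2), averaged.
    return hd_loss(seg_soft, gt, seg_dtm=seg_dtm, gt_dtm=gt_dtm)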


@MODELS.register_module()
class HuasdorffDisstanceLoss(nn.Module):
    """HuasdorffDisstanceLoss. This loss is proposed in `How Distance
    Transform Maps Boost Segmentation CNNs: An Empirical Study
    <http://proceedings.mlr.press/v121/ma20b.html>`_.

    Args:
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        class_weight (list[float] | str, optional): Weight of each class. If
            in str format, read them from a file. Defaults to None.
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        ignore_index (int | None): The label index to be ignored.
            Defaults to 255.
        loss_name (str): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_huasdorff_disstance'.
    """

    def __init__(self,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 ignore_index=255,
                 loss_name='loss_huasdorff_disstance',
                 **kwargs):
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)
        self._loss_name = loss_name
        self.ignore_index = ignore_index

    def forward(self,
                pred: Tensor,
                target: Tensor,
                avg_factor=None,
                reduction_override=None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predictions of the segmentation head, shape
                (B, C, H, W).
            target (Tensor): Ground truth of the image, shape (B, H, W).
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            Tensor: Loss tensor.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.class_weight is not None:
            class_weight = pred.new_tensor(self.class_weight)
        else:
            class_weight = None
        pred_soft = F.softmax(pred, dim=1)
        # Zero out ignored pixels before computing the distance maps.
        valid_mask = (target != self.ignore_index).long()
        target = target * valid_mask
        with torch.no_grad():
            # Distance map of the ground truth ...
            gt_dtm = compute_dtm(target.cpu(), pred_soft)
            gt_dtm = gt_dtm.float()
            # ... and of the hard (argmax) prediction.
            seg_dtm2 = compute_dtm(
                pred_soft.argmax(dim=1, keepdim=False).cpu(), pred_soft)
            seg_dtm2 = seg_dtm2.float()

        loss_hd = self.loss_weight * hd_loss(
            pred_soft,
            target,
            seg_dtm=seg_dtm2,
            gt_dtm=gt_dtm,
            reduction=reduction,
            avg_factor=avg_factor,
            class_weight=class_weight,
            ignore_index=self.ignore_index)
        return loss_hd

    @property
    def loss_name(self):
        """Return the name of this loss item."""
        return self._loss_name
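

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module), assuming an
# mmsegmentation environment so the relative import above resolves (e.g. run
# via ``python -m`` on this module's dotted path): build the loss, feed it
# random logits and labels, and get a scalar back.  The shapes and the ignore
# label below are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    criterion = HuasdorffDisstanceLoss(loss_weight=1.0, ignore_index=255)
    logits = torch.randn(2, 2, 32, 32)           # (B, C, H, W) raw logits
    labels = torch.randint(0, 2, (2, 32, 32))    # (B, H, W) class indices
    loss = criterion(logits, labels)
    print(f'Hausdorff distance loss: {loss.item():.4f}')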