from mmdet.core import BitmapMasks

from mmocr.models.builder import LOSSES
from mmocr.utils import check_argument
from . import PANLoss


@LOSSES.register_module()
class PSELoss(PANLoss):
r"""The class for implementing PSENet loss. This is partially adapted from |
|
https://github.com/whai362/PSENet. |
|
|
|
PSENet: `Shape Robust Text Detection with |
|
Progressive Scale Expansion Network <https://arxiv.org/abs/1806.02559>`_. |
|
|
|
Args: |
|
alpha (float): Text loss coefficient, and :math:`1-\alpha` is the |
|
kernel loss coefficient. |
|
ohem_ratio (float): The negative/positive ratio in ohem. |
|
reduction (str): The way to reduce the loss. Available options are |
|
"mean" and "sum". |
|
""" |
|
|
|
    def __init__(self,
                 alpha=0.7,
                 ohem_ratio=3,
                 reduction='mean',
                 kernel_sample_type='adaptive'):
        super().__init__()
        assert reduction in ['mean', 'sum'], \
            "reduction must be either of ['mean', 'sum']"
        self.alpha = alpha
        self.ohem_ratio = ohem_ratio
        self.reduction = reduction
        self.kernel_sample_type = kernel_sample_type
|
    def forward(self, score_maps, downsample_ratio, gt_kernels, gt_mask):
        """Compute PSENet loss.

        Args:
            score_maps (tensor): The output tensor of size
                :math:`(N, C, H, W)`, where channel 0 is the text score map
                and the remaining :math:`C - 1` channels are the kernel
                score maps.
            downsample_ratio (float): The downsample ratio between
                score_maps and the input img.
            gt_kernels (list[BitmapMasks]): The kernel list with each
                element being the text kernel mask for one img.
            gt_mask (list[BitmapMasks]): The effective mask list with each
                element being the effective mask for one img.

        Returns:
            dict: A loss dict with ``loss_text`` and ``loss_kernel``.
        """
        assert check_argument.is_type_list(gt_kernels, BitmapMasks)
        assert check_argument.is_type_list(gt_mask, BitmapMasks)
        assert isinstance(downsample_ratio, float)
        losses = []

        pred_texts = score_maps[:, 0, :, :]
        pred_kernels = score_maps[:, 1:, :, :]
        feature_sz = score_maps.size()
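
        # Bring the ground truth to prediction resolution: rescale the
        # bitmap masks by the downsample ratio, convert them to tensors of
        # the feature size, and move them onto the prediction device.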
        gt_kernels = [item.rescale(downsample_ratio) for item in gt_kernels]
        gt_kernels = self.bitmasks2tensor(gt_kernels, feature_sz[2:])
        gt_kernels = [item.to(score_maps.device) for item in gt_kernels]

        gt_mask = [item.rescale(downsample_ratio) for item in gt_mask]
        gt_mask = self.bitmasks2tensor(gt_mask, feature_sz[2:])
        gt_mask = [item.to(score_maps.device) for item in gt_mask]
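
        # Text loss: sample hard negatives with OHEM at the configured
        # negative/positive ratio, then compute the dice loss on the
        # sampled pixels.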
        sampled_masks_text = self.ohem_batch(pred_texts.detach(),
                                             gt_kernels[0], gt_mask[0])
        loss_texts = self.dice_loss_with_logits(pred_texts, gt_kernels[0],
                                                sampled_masks_text)
        losses.append(self.alpha * loss_texts)
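
        # Kernel losses are computed only inside text regions: 'hard' masks
        # by the ground-truth text map, 'adaptive' by the predicted one.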
        if self.kernel_sample_type == 'hard':
            sampled_masks_kernel = (gt_kernels[0] > 0.5).float() * \
                gt_mask[0].float()
        elif self.kernel_sample_type == 'adaptive':
            sampled_masks_kernel = (pred_texts > 0).float() * \
                gt_mask[0].float()
        else:
            raise NotImplementedError
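
        # One dice loss per kernel level; the losses are averaged so that
        # the number of levels does not change the loss scale.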
        num_kernel = pred_kernels.shape[1]
        assert num_kernel == len(gt_kernels) - 1
        loss_list = []
        for idx in range(num_kernel):
            loss_kernels = self.dice_loss_with_logits(
                pred_kernels[:, idx, :, :], gt_kernels[1 + idx],
                sampled_masks_kernel)
            loss_list.append(loss_kernels)

        losses.append((1 - self.alpha) * sum(loss_list) / len(loss_list))
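
        # Each term is a per-image loss tensor; reduce it over the batch
        # as configured.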
        if self.reduction == 'mean':
            losses = [item.mean() for item in losses]
        elif self.reduction == 'sum':
            losses = [item.sum() for item in losses]
        else:
            raise NotImplementedError
        results = dict(loss_text=losses[0], loss_kernel=losses[1])
        return results
|
|
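# A minimal usage sketch (illustration only, not part of the module). It
# assumes a seven-level kernel setup (one text map plus six kernel maps),
# a 4x downsampled score map, and all-ones dummy masks standing in for
# real annotations:
#
#     import numpy as np
#     import torch
#
#     loss = PSELoss(alpha=0.7, ohem_ratio=3, reduction='mean')
#     score_maps = torch.randn(2, 7, 160, 160)
#     gt_kernels = [
#         BitmapMasks([np.ones((640, 640), dtype=np.uint8)] * 7, 640, 640)
#         for _ in range(2)
#     ]
#     gt_mask = [
#         BitmapMasks([np.ones((640, 640), dtype=np.uint8)], 640, 640)
#         for _ in range(2)
#     ]
#     results = loss(score_maps, 0.25, gt_kernels, gt_mask)
#     # -> {'loss_text': tensor(...), 'loss_kernel': tensor(...)}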