import itertools
import warnings

import numpy as np
import torch
import torch.nn.functional as F
from mmdet.core import BitmapMasks
from torch import nn

from mmocr.models.builder import LOSSES
from mmocr.utils import check_argument


@LOSSES.register_module()
class PANLoss(nn.Module):
    """The class for implementing PANet loss. This was partially adapted from
    https://github.com/WenmuZhou/PAN.pytorch.

    PANet: `Efficient and Accurate Arbitrary-Shaped Text Detection with Pixel
    Aggregation Network <https://arxiv.org/abs/1908.05900>`_.

    Args:
        alpha (float): The kernel loss coefficient.
        beta (float): The aggregation and discriminative loss coefficient.
        delta_aggregation (float): The constant for aggregation loss.
        delta_discrimination (float): The constant for discriminative loss.
        ohem_ratio (float): The negative/positive ratio in OHEM.
        reduction (str): The way to reduce the loss, available options are
            'mean' and 'sum'.
        speedup_bbox_thr (int): Speed up the loss computation by randomly
            sampling this many text instances per image when it is positive
            and smaller than the instance count; -1 disables the speedup.
    """

    def __init__(self,
                 alpha=0.5,
                 beta=0.25,
                 delta_aggregation=0.5,
                 delta_discrimination=3,
                 ohem_ratio=3,
                 reduction='mean',
                 speedup_bbox_thr=-1):
        super().__init__()
        assert reduction in ['mean', 'sum'], \
            "reduction must be 'mean' or 'sum'"
        self.alpha = alpha
        self.beta = beta
        self.delta_aggregation = delta_aggregation
        self.delta_discrimination = delta_discrimination
        self.ohem_ratio = ohem_ratio
        self.reduction = reduction
        self.speedup_bbox_thr = speedup_bbox_thr

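    # Note on the coefficients above: ``forward`` combines the four loss
    # terms as
    #   L = L_text + alpha * L_kernel + beta * (L_agg + L_dis)
    # which matches the scaling list ``[1, alpha, beta, beta]`` used there.
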
    def bitmasks2tensor(self, bitmasks, target_sz):
        """Convert BitmapMasks to tensors.

        Args:
            bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
                for one img.
            target_sz (tuple(int, int)): The target tensor size
                :math:`(H, W)`.

        Returns:
            list[Tensor]: The list of kernel tensors. Each element stands for
            one kernel level.
        """
        assert check_argument.is_type_list(bitmasks, BitmapMasks)
        assert isinstance(target_sz, tuple)

        batch_size = len(bitmasks)
        num_masks = len(bitmasks[0])

        results = []

        for level_inx in range(num_masks):
            kernel = []
            for batch_inx in range(batch_size):
                mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx])
                # hxw
                mask_sz = mask.shape
                # pad the mask up to target_sz: (left, right, top, bottom)
                pad = [
                    0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
                ]
                mask = F.pad(mask, pad, mode='constant', value=0)
                kernel.append(mask)
            kernel = torch.stack(kernel)
            results.append(kernel)

        return results

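    # A minimal shape sketch for ``bitmasks2tensor`` (the sizes here are
    # illustrative assumptions, not values from a real pipeline):
    #
    #   masks = [BitmapMasks([np.zeros((60, 100), np.uint8)], 60, 100)]
    #   out = PANLoss().bitmasks2tensor(masks, (64, 128))
    #   assert out[0].shape == (1, 64, 128)  # zero-padded right and bottom
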
    def forward(self, preds, downsample_ratio, gt_kernels, gt_mask):
        """Compute PANet loss.

        Args:
            preds (Tensor): The output tensor of size :math:`(N, 6, H, W)`.
            downsample_ratio (float): The downsample ratio between preds
                and the input img.
            gt_kernels (list[BitmapMasks]): The kernel list with each element
                being the text kernel mask for one img.
            gt_mask (list[BitmapMasks]): The effective mask list
                with each element being the effective mask for one img.

        Returns:
            dict: A loss dict with ``loss_text``, ``loss_kernel``,
            ``loss_aggregation`` and ``loss_discrimination``.
        """

        assert check_argument.is_type_list(gt_kernels, BitmapMasks)
        assert check_argument.is_type_list(gt_mask, BitmapMasks)
        assert isinstance(downsample_ratio, float)

        pred_texts = preds[:, 0, :, :]
        pred_kernels = preds[:, 1, :, :]
        inst_embed = preds[:, 2:, :, :]
        feature_sz = preds.size()

        # rescale the ground truth to the prediction resolution and convert
        # it to tensors on the same device as preds
        mapping = {'gt_kernels': gt_kernels, 'gt_mask': gt_mask}
        gt = {}
        for key, value in mapping.items():
            gt[key] = value
            gt[key] = [item.rescale(downsample_ratio) for item in gt[key]]
            gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
            gt[key] = [item.to(preds.device) for item in gt[key]]
        loss_aggrs, loss_discrs = self.aggregation_discrimination_loss(
            gt['gt_kernels'][0], gt['gt_kernels'][1], inst_embed)

        # compute the text loss on OHEM-sampled pixels
        sampled_mask = self.ohem_batch(pred_texts.detach(),
                                       gt['gt_kernels'][0], gt['gt_mask'][0])
        loss_texts = self.dice_loss_with_logits(pred_texts,
                                                gt['gt_kernels'][0],
                                                sampled_mask)

        # compute the kernel loss only inside text regions of the
        # effective mask
        sampled_masks_kernel = (gt['gt_kernels'][0] > 0.5).float() * (
            gt['gt_mask'][0].float())
        loss_kernels = self.dice_loss_with_logits(pred_kernels,
                                                  gt['gt_kernels'][1],
                                                  sampled_masks_kernel)
        losses = [loss_texts, loss_kernels, loss_aggrs, loss_discrs]
        if self.reduction == 'mean':
            losses = [item.mean() for item in losses]
        elif self.reduction == 'sum':
            losses = [item.sum() for item in losses]
        else:
            raise NotImplementedError

        coefs = [1, self.alpha, self.beta, self.beta]
        losses = [item * scale for item, scale in zip(losses, coefs)]

        results = dict()
        results.update(
            loss_text=losses[0],
            loss_kernel=losses[1],
            loss_aggregation=losses[2],
            loss_discrimination=losses[3])
        return results

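    # A minimal end-to-end sketch (synthetic inputs; the shapes and the 0.5
    # downsample ratio are assumptions for illustration only):
    #
    #   loss = PANLoss()
    #   preds = torch.randn(1, 6, 64, 64)
    #   text = np.zeros((128, 128), np.uint8)
    #   text[10:50, 10:50] = 1  # one text instance labelled 1
    #   kernel = np.zeros((128, 128), np.uint8)
    #   kernel[20:40, 20:40] = 1
    #   gt_kernels = [BitmapMasks([text, kernel], 128, 128)]
    #   gt_mask = [BitmapMasks([np.ones((128, 128), np.uint8)], 128, 128)]
    #   losses = loss(preds, 0.5, gt_kernels, gt_mask)  # dict of 4 terms
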
    def aggregation_discrimination_loss(self, gt_texts, gt_kernels,
                                        inst_embeds):
        """Compute the aggregation and discriminative losses.

        Args:
            gt_texts (Tensor): The ground truth text mask of size
                :math:`(N, 1, H, W)`.
            gt_kernels (Tensor): The ground truth text kernel mask of
                size :math:`(N, 1, H, W)`.
            inst_embeds (Tensor): The text instance embedding tensor
                of size :math:`(N, 4, H, W)`.

        Returns:
            (Tensor, Tensor): A tuple of aggregation loss and discriminative
            loss before reduction.
        """

        batch_size = gt_texts.size()[0]
        gt_texts = gt_texts.contiguous().reshape(batch_size, -1)
        gt_kernels = gt_kernels.contiguous().reshape(batch_size, -1)

        assert inst_embeds.shape[1] == 4
        inst_embeds = inst_embeds.contiguous().reshape(batch_size, 4, -1)

        loss_aggrs = []
        loss_discrs = []

        # for each image
        for text, kernel, embed in zip(gt_texts, gt_kernels, inst_embeds):

            text_num = int(text.max().item())
            loss_aggr_img = []
            kernel_avgs = []
            select_num = self.speedup_bbox_thr
            if 0 < select_num < text_num:
                inds = np.random.choice(
                    text_num, select_num, replace=False) + 1
            else:
                inds = range(1, text_num + 1)

            for i in inds:
                # for each text instance
                kernel_i = (kernel == i)
                if kernel_i.sum() == 0 or (text == i).sum() == 0:
                    continue

                # compute the kernel mean G(K_i)
                avg = embed[:, kernel_i].mean(1)
                kernel_avgs.append(avg)

                embed_i = embed[:, text == i]
                # ||F(p) - G(K_i)|| - delta_aggregation
                distance = (embed_i - avg.reshape(4, 1)).norm(
                    2, dim=0) - self.delta_aggregation
                # compute D(p, K_i)
                hinge = torch.max(
                    distance,
                    torch.tensor(0, device=distance.device,
                                 dtype=torch.float)).pow(2)

                aggr = torch.log(hinge + 1).mean()
                loss_aggr_img.append(aggr)

            num_inst = len(loss_aggr_img)
            if num_inst > 0:
                loss_aggr_img = torch.stack(loss_aggr_img).mean()
            else:
                loss_aggr_img = torch.tensor(
                    0, device=gt_texts.device, dtype=torch.float)
            loss_aggrs.append(loss_aggr_img)

            loss_discr_img = 0
            for avg_i, avg_j in itertools.combinations(kernel_avgs, 2):
                # compute D(K_i, K_j) between every pair of kernel means
                distance_ij = self.delta_discrimination - (avg_i -
                                                           avg_j).norm(2)
                D_ij = torch.max(
                    distance_ij,
                    torch.tensor(
                        0, device=distance_ij.device,
                        dtype=torch.float)).pow(2)
                loss_discr_img += torch.log(D_ij + 1)

            if num_inst > 1:
                loss_discr_img /= (num_inst * (num_inst - 1))
            else:
                loss_discr_img = torch.tensor(
                    0, device=gt_texts.device, dtype=torch.float)
            if num_inst == 0:
                warnings.warn('num of instance is 0')
            loss_discrs.append(loss_discr_img)
        return torch.stack(loss_aggrs), torch.stack(loss_discrs)

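    # A worked hinge value under the default deltas (illustrative numbers
    # only): if ||F(p) - G(K_i)|| = 1.5 and delta_aggregation = 0.5, then
    # D(p, K_i) = max(1.5 - 0.5, 0)^2 = 1.0, so pixel p contributes
    # ln(1 + 1) ~= 0.69 to the aggregation loss of its instance.
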
    def dice_loss_with_logits(self, pred, target, mask):
        """Compute the dice loss on raw prediction logits.

        Args:
            pred (Tensor): The prediction logits of size :math:`(N, H, W)`.
            target (Tensor): The binary target of the same size.
            mask (Tensor): The sampled pixel mask of the same size.

        Returns:
            Tensor: The dice loss of size :math:`(N)`, one value per image.
        """
        smooth = 0.001

        pred = torch.sigmoid(pred)
        # binarize the target in place
        target[target <= 0.5] = 0
        target[target > 0.5] = 1
        pred = pred.contiguous().view(pred.size()[0], -1)
        target = target.contiguous().view(target.size()[0], -1)
        mask = mask.contiguous().view(mask.size()[0], -1)

        # restrict both maps to the sampled pixels
        pred = pred * mask
        target = target * mask

        a = torch.sum(pred * target, 1) + smooth
        b = torch.sum(pred * pred, 1) + smooth
        c = torch.sum(target * target, 1) + smooth
        d = (2 * a) / (b + c)
        return 1 - d

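    # A dice sketch on toy tensors (illustrative values; large logits drive
    # the sigmoid towards 1, so the loss approaches 0):
    #
    #   pred = torch.full((1, 4), 10.0)
    #   target = torch.ones(1, 4)
    #   mask = torch.ones(1, 4)
    #   PANLoss().dice_loss_with_logits(pred, target, mask)  # ~= 0
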
    def ohem_img(self, text_score, gt_text, gt_mask):
        """Sample the top-k maximal negative samples and all positive samples.

        Args:
            text_score (Tensor): The text score of size :math:`(H, W)`.
            gt_text (Tensor): The ground truth text mask of size
                :math:`(H, W)`.
            gt_mask (Tensor): The effective region mask of size
                :math:`(H, W)`.

        Returns:
            Tensor: The sampled pixel mask of size :math:`(H, W)`.
        """
        assert isinstance(text_score, torch.Tensor)
        assert isinstance(gt_text, torch.Tensor)
        assert isinstance(gt_mask, torch.Tensor)
        assert len(text_score.shape) == 2
        assert text_score.shape == gt_text.shape
        assert gt_text.shape == gt_mask.shape

        # count positives inside the effective region; cap negatives at
        # ohem_ratio times the positive count
        pos_num = int(torch.sum(gt_text > 0.5).item()) - int(
            torch.sum((gt_text > 0.5) * (gt_mask <= 0.5)).item())
        neg_num = int(torch.sum(gt_text <= 0.5).item())
        neg_num = int(min(pos_num * self.ohem_ratio, neg_num))

        if pos_num == 0 or neg_num == 0:
            warnings.warn('pos_num = 0 or neg_num = 0')
            return gt_mask.bool()

        # keep the hardest negatives: those with the highest text scores
        neg_score = text_score[gt_text <= 0.5]
        neg_score_sorted, _ = torch.sort(neg_score, descending=True)
        threshold = neg_score_sorted[neg_num - 1]
        sampled_mask = (((text_score >= threshold) + (gt_text > 0.5)) > 0) * (
            gt_mask > 0.5)
        return sampled_mask

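    # OHEM sketch: with ``ohem_ratio=3`` and, say, 100 positive pixels, the
    # 300 highest-scoring negatives are kept (fewer if the image has fewer
    # negatives), and everything outside the sampled set or the effective
    # mask is excluded from the text loss.
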
    def ohem_batch(self, text_scores, gt_texts, gt_mask):
        """OHEM sampling for a batch of imgs.

        Args:
            text_scores (Tensor): The text scores of size :math:`(N, H, W)`.
            gt_texts (Tensor): The gt text masks of size :math:`(N, H, W)`.
            gt_mask (Tensor): The gt effective mask of size :math:`(N, H, W)`.

        Returns:
            Tensor: The sampled mask of size :math:`(N, H, W)`.
        """
        assert isinstance(text_scores, torch.Tensor)
        assert isinstance(gt_texts, torch.Tensor)
        assert isinstance(gt_mask, torch.Tensor)
        assert len(text_scores.shape) == 3
        assert text_scores.shape == gt_texts.shape
        assert gt_texts.shape == gt_mask.shape

        sampled_masks = []
        for i in range(text_scores.shape[0]):
            sampled_masks.append(
                self.ohem_img(text_scores[i], gt_texts[i], gt_mask[i]))

        sampled_masks = torch.stack(sampled_masks)

        return sampled_masks