|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from mmdet.core import multi_apply |
|
from torch import nn |
|
|
|
from mmocr.models.builder import LOSSES |
|
|
|
|
|
@LOSSES.register_module()
class FCELoss(nn.Module):
    """The class for implementing FCENet loss.

    FCENet(CVPR2021): `Fourier Contour Embedding for Arbitrary-shaped Text
    Detection <https://arxiv.org/abs/2104.10442>`_

    Args:
        fourier_degree (int) : The maximum Fourier transform degree k.
        num_sample (int) : The number of points sampled on each contour
            when computing the regression loss. If it is too small, FCENet
            tends to overfit.
        ohem_ratio (float): The negative/positive ratio in OHEM.
    """

    def __init__(self, fourier_degree, num_sample, ohem_ratio=3.):
        super().__init__()
        self.fourier_degree = fourier_degree
        self.num_sample = num_sample
        self.ohem_ratio = ohem_ratio

    def forward(self, preds, _, p3_maps, p4_maps, p5_maps):
        """Compute FCENet loss.

        Args:
            preds (list[list[Tensor]]): The outer list indicates feature
                levels, paired in order with ``p3_maps``, ``p4_maps`` and
                ``p5_maps``. Each inner list holds the classification
                prediction map and the regression map of that level, both
                with shape :math:`(N, C, H, W)`.
            _: Unused.
            p3_maps (list[ndarray]): Per-image level 3 ground truth target
                maps, each with shape :math:`(C, H, W)` where
                :math:`C = 4k + 5` (k = ``fourier_degree``).
            p4_maps (list[ndarray]): Per-image level 4 ground truth target
                maps, each with shape :math:`(C, H, W)`.
            p5_maps (list[ndarray]): Per-image level 5 ground truth target
                maps, each with shape :math:`(C, H, W)`.

        Returns:
            dict: A loss dict with ``loss_text``, ``loss_center``,
            ``loss_reg_x`` and ``loss_reg_y``, each summed over the levels.
        """
        assert isinstance(preds, list)
        assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\
            'fourier degree not equal in FCEhead and FCEtarget'

        device = preds[0][0].device

        # Stack each level's per-image numpy targets into a single
        # (N, C, H, W) float tensor on the prediction device.
        gts = [
            torch.from_numpy(np.stack(maps)).float().to(device)
            for maps in (p3_maps, p4_maps, p5_maps)
        ]

        # multi_apply yields one list per output of forward_single
        # (tr, tcl, reg_x, reg_y), each holding one entry per level.
        losses = multi_apply(self.forward_single, preds, gts)

        # Summing from a device-resident zero keeps the result a float
        # tensor on `device`; `sum` uses out-of-place `+`, so `zero` is
        # never mutated and can be shared by all four terms.
        zero = torch.tensor(0., dtype=torch.float, device=device)
        loss_tr, loss_tcl, loss_reg_x, loss_reg_y = (
            sum(level_losses, zero) for level_losses in losses)

        return dict(
            loss_text=loss_tr,
            loss_center=loss_tcl,
            loss_reg_x=loss_reg_x,
            loss_reg_y=loss_reg_y,
        )

    def forward_single(self, pred, gt):
        """Compute the four FCENet loss terms for a single feature level.

        Args:
            pred (list[Tensor]): ``[cls_pred, reg_pred]`` with shape
                :math:`(N, C, H, W)`. ``cls_pred`` has 4 channels: 2 text
                region logits followed by 2 text center logits.
                ``reg_pred`` holds the real coefficients in channels
                ``[0, 2k+1)`` and the imaginary ones in ``[2k+1, 2(2k+1))``,
                where k = ``fourier_degree``.
            gt (Tensor): Target of shape :math:`(N, 4k+5, H, W)` whose
                channels are: text region mask, text center mask, train
                mask, then 2k+1 real and 2k+1 imaginary coefficient maps.

        Returns:
            tuple[Tensor]: ``(loss_tr, loss_tcl, loss_reg_x, loss_reg_y)``.
        """
        cls_pred = pred[0].permute(0, 2, 3, 1).contiguous()
        reg_pred = pred[1].permute(0, 2, 3, 1).contiguous()
        gt = gt.permute(0, 2, 3, 1).contiguous()

        k = 2 * self.fourier_degree + 1
        tr_pred = cls_pred[:, :, :, :2].view(-1, 2)
        tcl_pred = cls_pred[:, :, :, 2:].view(-1, 2)
        x_pred = reg_pred[:, :, :, 0:k].view(-1, k)
        y_pred = reg_pred[:, :, :, k:2 * k].view(-1, k)

        tr_mask = gt[:, :, :, :1].view(-1)
        tcl_mask = gt[:, :, :, 1:2].view(-1)
        train_mask = gt[:, :, :, 2:3].view(-1)
        x_map = gt[:, :, :, 3:3 + k].view(-1, k)
        y_map = gt[:, :, :, 3 + k:].view(-1, k)

        tr_train_mask = train_mask * tr_mask
        device = x_map.device
        has_pos = tr_train_mask.sum().item() > 0

        # Boolean masks reused below. NOTE(review): `neg_mask` is the
        # complement of `tr_train_mask`, so it also covers pixels whose
        # train_mask is 0 -- confirm that is intended.
        pos_mask = tr_train_mask.bool()
        neg_mask = (1 - tr_train_mask).bool()

        # text region loss
        loss_tr = self.ohem(tr_pred, tr_mask.long(), train_mask.long())

        # text center region loss
        loss_tcl = torch.tensor(0.).float().to(device)
        if has_pos:
            loss_tcl = F.cross_entropy(tcl_pred[pos_mask],
                                       tcl_mask[pos_mask].long())
            # Mean-reduced cross entropy over zero elements is NaN, so only
            # add the negative term when negative pixels actually exist.
            if neg_mask.any():
                loss_tcl_neg = F.cross_entropy(tcl_pred[neg_mask],
                                               tcl_mask[neg_mask].long())
                loss_tcl = loss_tcl + 0.5 * loss_tcl_neg

        # regression loss
        loss_reg_x = torch.tensor(0.).float().to(device)
        loss_reg_y = torch.tensor(0.).float().to(device)
        if has_pos:
            # Center pixels (tcl_mask == 1) get weight 1, other positive
            # text pixels get 0.5.
            weight = (tr_mask[pos_mask].float() +
                      tcl_mask[pos_mask].float()) / 2
            weight = weight.contiguous().view(-1, 1)

            # fourier2poly is row-wise, so selecting the positive rows first
            # gives the same values as transforming the whole map and
            # masking afterwards, without the wasted work.
            ft_x, ft_y = self.fourier2poly(x_map[pos_mask], y_map[pos_mask])
            ft_x_pre, ft_y_pre = self.fourier2poly(x_pred[pos_mask],
                                                   y_pred[pos_mask])

            loss_reg_x = torch.mean(weight * F.smooth_l1_loss(
                ft_x_pre, ft_x, reduction='none'))
            loss_reg_y = torch.mean(weight * F.smooth_l1_loss(
                ft_y_pre, ft_y, reduction='none'))

        return loss_tr, loss_tcl, loss_reg_x, loss_reg_y

    def ohem(self, predict, target, train_mask):
        """Cross entropy with online hard negative mining.

        Args:
            predict (Tensor): Two-class logits of shape (M, 2).
            target (Tensor): Integer class labels (0 or 1) of shape (M,).
            train_mask (Tensor): Integer mask of shape (M,); zero entries
                are excluded from both positives and negatives.

        Returns:
            Tensor: (summed positive loss + summed loss of the hardest
            negatives) / (n_pos + n_neg). With positives present, at most
            ``ohem_ratio * n_pos`` negatives are kept. Without positives, at
            most 100 negatives are kept and the denominator is fixed at 100,
            even if fewer negatives exist.
        """
        device = train_mask.device
        pos = (target * train_mask).bool()
        neg = ((1 - target) * train_mask).bool()

        n_pos = pos.float().sum()

        # Per-pixel negative losses; trimmed to the hardest ones below.
        loss_neg = F.cross_entropy(
            predict[neg], target[neg], reduction='none')

        if n_pos.item() > 0:
            loss_pos = F.cross_entropy(
                predict[pos], target[pos], reduction='sum')
            n_neg = min(
                int(neg.float().sum().item()),
                int(self.ohem_ratio * n_pos.float()))
        else:
            loss_pos = torch.tensor(0.).to(device)
            n_neg = 100

        if len(loss_neg) > n_neg:
            loss_neg, _ = torch.topk(loss_neg, n_neg)

        return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float()

    def fourier2poly(self, real_maps, imag_maps):
        """Transform Fourier coefficient maps to polygon maps.

        Let a_j and b_j be the real and imaginary coefficients for
        j = -d..d (d = ``fourier_degree``). Let
        theta(j, n) = 2 * pi * j * n / N (N = ``num_sample``). Each row is
        then mapped to N points:

            x_n = sum_j a_j * cos(theta) - b_j * sin(theta)
            y_n = sum_j a_j * sin(theta) + b_j * cos(theta)

        That is, x_n + i * y_n = sum_j (a_j + i * b_j) * exp(i * theta).

        Args:
            real_maps (Tensor): A map composed of the real parts of the
                Fourier coefficients, whose shape is (-1, 2k+1).
            imag_maps (Tensor): A map composed of the imaginary parts of
                the Fourier coefficients, whose shape is (-1, 2k+1).

        Returns:
            x_maps (Tensor): A map composed of the x values of the polygon
                represented by n sample points (xn, yn), whose shape is
                (-1, n).
            y_maps (Tensor): A map composed of the y values of the polygon
                represented by n sample points (xn, yn), whose shape is
                (-1, n).
        """
        device = real_maps.device

        k_vect = torch.arange(
            -self.fourier_degree,
            self.fourier_degree + 1,
            dtype=torch.float,
            device=device).view(-1, 1)
        i_vect = torch.arange(
            0, self.num_sample, dtype=torch.float, device=device).view(1, -1)

        transform_matrix = 2 * np.pi / self.num_sample * torch.mm(
            k_vect, i_vect)
        # Each trig table feeds two einsums; compute them once.
        cos_matrix = torch.cos(transform_matrix)
        sin_matrix = torch.sin(transform_matrix)

        x1 = torch.einsum('ak, kn-> an', real_maps, cos_matrix)
        x2 = torch.einsum('ak, kn-> an', imag_maps, sin_matrix)
        y1 = torch.einsum('ak, kn-> an', real_maps, sin_matrix)
        y2 = torch.einsum('ak, kn-> an', imag_maps, cos_matrix)

        x_maps = x1 - x2
        y_maps = y1 + y2

        return x_maps, y_maps
|
|