|
print("Importing standard...") |
|
from abc import ABC, abstractmethod |
|
|
|
print("Importing external...") |
|
import torch |
|
from torch.nn.functional import binary_cross_entropy |
|
|
|
|
|
|
|
print("Importing internal...") |
|
from utils import preprocess_masks_features, get_row_col, symlog, calculate_iou |
|
|
|
|
|
|
|
def my_lovasz_hinge(logits, gt, downsample=False):
    """Lovasz hinge loss computed independently per row, averaged over rows.

    Args:
        logits: 2-D tensor ``(rows, pixels)`` of real-valued scores; positive
            values vote for the foreground label. Rows are independent.
        gt: tensor with the same shape as ``logits`` holding binary (0/1 or
            bool) ground-truth labels.
        downsample: if truthy (an int ``s >= 1``), keep only every ``s``-th
            column, starting at a random phase drawn uniformly from
            ``[0, s)``, before computing the loss. ``False``/``0`` disables it.

    Returns:
        0-d tensor: the NaN-ignoring mean (``torch.nanmean``) of the per-row
        losses.
    """
    if downsample:
        step = int(downsample)
        # torch.randint's upper bound is exclusive: drawing from [0, step)
        # makes every phase reachable (including step - 1) and keeps
        # step == 1 valid (the old `step - 1` bound made it an empty range).
        offset = int(torch.randint(step, (1,)))
        logits, gt = logits[:, offset::step], gt[:, offset::step]

    gt = 1.0 * gt  # promote bool/int labels to floating point
    areas = gt.sum(dim=1, keepdim=True)  # number of positive pixels per row

    # Hinge errors: large when the logit's sign disagrees with the label.
    signs = 2 * gt - 1
    errors = 1 - logits * signs
    errors_sorted, perm = torch.sort(errors, dim=1, descending=True)
    gt_sorted = torch.gather(gt, 1, perm)

    # Lovasz-extension weights: Jaccard loss after admitting each successive
    # (sorted) error, then first-differenced along the row.
    intersection = areas - gt_sorted.cumsum(dim=1)
    union = areas + (1 - gt_sorted).cumsum(dim=1)
    jaccard = 1 - intersection / union
    jaccard[:, 1:] = jaccard[:, 1:] - jaccard[:, :-1]

    loss = (torch.relu(errors_sorted) * jaccard).sum(dim=1)
    return torch.nanmean(loss)
|
|
|
|
|
def focal_loss(scores, targets, alpha=0.25, gamma=2):
    """Element-wise binary focal loss computed from probabilities (no reduction).

    ``scores`` are used directly as probabilities: they are handed straight to
    ``binary_cross_entropy`` rather than passed through a sigmoid first.

    Args:
        scores: predicted probabilities.
        targets: labels (0/1 or soft), same shape as ``scores``.
        alpha: weight for the positive label; the negative label gets
            ``1 - alpha``. Any value that is not ``>= 0`` disables weighting.
        gamma: focusing exponent applied to ``1 - p_t``.

    Returns:
        Tensor shaped like ``scores`` holding the per-element loss.
    """
    complement = 1 - targets
    # Probability the model assigns to the true label.
    p_true = scores * targets + (1 - scores) * complement
    weighted = binary_cross_entropy(scores, targets, reduction="none") * (
        (1 - p_true) ** gamma
    )
    if alpha >= 0:
        class_weight = alpha * targets + (1 - alpha) * complement
        weighted = class_weight * weighted
    return weighted
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_distances(features, refs, sigma, norm_p, square_distances, H, W):
    """Scaled distances between every pixel feature and every reference vector.

    Takes the ``p``-norm of ``features - refs`` over dim 2 (the channel axis).
    If ``square_distances`` is true, squares it. Divides by ``2 * sigma**2``
    and reshapes to ``(B, M, H * W)``, where ``B`` and ``M`` are the first two
    sizes of ``refs``.

    ``features`` and ``refs`` must broadcast against each other. For example,
    ``IISLoss.get_mask_from_query`` passes ``(1, 1, F, H*W)`` features and a
    ``(1, 1, F, 1)`` reference.
    """
    batch, n_refs = refs.shape[0], refs.shape[1]
    norms = torch.norm(features - refs, dim=2, p=norm_p, keepdim=True)
    if square_distances:
        norms = norms**2
    return (norms / (2 * sigma**2)).reshape(batch, n_refs, H * W)
|
|
|
|
|
def activate(features, masks, activation, use_sigma, offset_pos, ret_prediction): |
|
|
|
|
|
assert activation in ["sigmoid", "symlog"] |
|
if masks is None: |
|
B, M = 1, 1 |
|
F, N = sorted(features.shape) |
|
H, W = [int(N ** (0.5))] * 2 |
|
features = features.reshape(1, 1, -1, H * W) |
|
else: |
|
masks, features, M, B, H, W, F = preprocess_masks_features(masks, features) |
|
|
|
|
|
if use_sigma: |
|
sigma = torch.nn.functional.softplus(features)[:, :, -1:] |
|
features = features[:, :, :-1] |
|
F = features.shape[2] |
|
else: |
|
sigma = 1 |
|
features = symlog(features) if activation == "symlog" else torch.sigmoid(features) |
|
if offset_pos: |
|
assert F >= 2 |
|
row, col = get_row_col(H, W, features.device) |
|
row = row.reshape(1, 1, 1, H, 1).expand(B, 1, 1, H, W).reshape(B, 1, 1, H * W) |
|
col = col.reshape(1, 1, 1, 1, W).expand(B, 1, 1, H, W).reshape(B, 1, 1, H * W) |
|
positional_features = torch.cat([row, col], dim=2) |
|
features[:, :, :2] = features[:, :, :2] + positional_features |
|
prediction = features.reshape(B, 1, -1, H, W) if ret_prediction else None |
|
if masks is None: |
|
features = features.reshape(-1, H * W) |
|
sigma = sigma.reshape(-1, H * W) if use_sigma else 1 |
|
return features, sigma, H, W |
|
return features, masks, sigma, prediction, B, M, F, H, W |
|
|
|
|
|
class AbstractLoss(ABC):
    """Interface every loss in this module implements.

    Both hooks are declared as static methods, so callers can fetch them
    straight from the class without creating an instance.
    """

    @staticmethod
    @abstractmethod
    def loss(features, masks, ret_prediction=False, **kwargs):
        """Return the training loss of ``features`` with respect to ``masks``."""

    @staticmethod
    @abstractmethod
    def get_mask_from_query(features, sindex, **kwargs):
        """Return the mask predicted from the query pixel ``sindex``."""
|
|
|
|
|
class IISLoss(AbstractLoss):
    """Loss and inference hooks for the "iis" embedding loss.

    Features are activated with ``symlog``. A pixel's score with respect to a
    query pixel is ``exp(-||f_pixel - f_query||^2 / 2)``, i.e.
    ``get_distances`` with L2 norm, squared distances and sigma 1. Training
    pushes those scores towards the query's ground-truth mask.
    """

    @staticmethod
    def loss(features, masks, ret_prediction=False, K=3, logger=None):
        """Focal + Lovasz-hinge loss over ``K`` random query pixels per mask.

        Args:
            features: raw network features, preprocessed by ``activate``
                together with ``masks``.
            masks: ground-truth masks. After preprocessing they are indexed as
                ``masks[b, m, 0, pixel]`` and used as a boolean index.
                Presumably bool with shape (B, M, 1, H*W); TODO confirm
                against ``preprocess_masks_features``.
            ret_prediction: forwarded to ``activate``; when True the activated
                features come back as ``prediction``.
            K: query pixels sampled inside each mask. Each mask needs at least
                ``K`` pixels, otherwise the stacking/reshaping below fails on
                ragged selections.
            logger: unused.

        Returns:
            ``(loss, prediction)``: ``loss`` is a 0-d tensor (mean focal loss
            plus Lovasz hinge); ``prediction`` is as returned by ``activate``.
        """
        features, masks, sigma, prediction, B, M, F, H, W = activate(
            features, masks, "symlog", False, False, ret_prediction
        )
        # One random pixel order, shared by every mask.
        rindices = torch.randperm(H * W, device=masks.device)

        # sindices[b, m, :] = the first K pixels (in random order) inside
        # mask (b, m), i.e. K random query pixels. Shape (B, M, K).
        sindices = torch.stack(
            [
                torch.stack([rindices[masks[b, m, 0, rindices]][:K] for m in range(M)])
                for b in range(B)
            ]
        )
        # Gather each query pixel's feature vector:
        # feats_at_sindices[b, m, k, f] = features[b, 0, f, sindices[b, m, k]].
        # The permute/expand assumes features is (B, 1, F, H*W); TODO confirm.
        feats_at_sindices = torch.gather(
            features.permute(0, 3, 1, 2).expand(B, H * W, K, F),
            dim=1,
            index=sindices.reshape(B, M, K, 1).expand(B, M, K, F),
        )
        feats_at_sindices = feats_at_sindices.reshape(B, M, K, F, 1)
        # (B, M*K, H*W): scaled squared L2 distance from every pixel to every query.
        dists = get_distances(
            features, feats_at_sindices.reshape(B, M * K, F, 1), sigma, 2, True, H, W
        )
        score = torch.exp(-dists)  # in (0, 1]; 1 at the query's own feature
        # Mask m repeated for each of its K queries, aligned with `score`.
        targets = (
            masks.expand(B, M, K, H * W).reshape(B, M * K, H * W).float()
        )
        floss = focal_loss(score, targets).mean()
        # Map scores from (0, 1] to (-1, 1] so they act as hinge logits.
        lloss = my_lovasz_hinge(
            score.view(B * M * K, H * W) * 2 - 1,
            targets.view(B * M * K, H * W),
        )
        loss = floss + lloss
        return loss, prediction

    @staticmethod
    def get_mask_from_query(features, sindex):
        """Predict the mask selected by query pixel ``sindex`` in one image.

        Args:
            features: raw features of a single image, in the 2-D layout that
                ``activate`` accepts when ``masks`` is None.
            sindex: flat pixel index (into ``H*W``) of the query.

        Returns:
            Bool tensor of shape ``(1, 1, H*W)``: pixels whose score
            ``exp(-dist)`` with respect to the query exceeds 0.5.
        """
        features, _, H, W = activate(features, None, "symlog", False, False, False)
        F = features.shape[0]
        query_feat = features[:, sindex]
        dists = get_distances(
            features.reshape(1, 1, F, H * W),
            query_feat.reshape(1, 1, F, 1),
            1,
            2,
            True,
            H,
            W,
        )
        score = torch.exp(-dists)
        pred = score > 0.5
        return pred
|
|
|
|
|
def iis_iou(features, masks, get_mask_from_query, K=20): |
|
masks, features, M, B, H, W, F = preprocess_masks_features(masks, features) |
|
|
|
|
|
rindices = torch.randperm(H * W).to(masks.device) |
|
sindices = torch.stack( |
|
[ |
|
torch.stack([rindices[masks[b, m, 0, rindices]][:K] for m in range(M)]) |
|
for b in range(B) |
|
] |
|
) |
|
cum_iou, n_samples = 0, 0 |
|
for b in range(B): |
|
for m in range(M): |
|
for k in range(K): |
|
sindex = sindices[b, m, k] |
|
pred = get_mask_from_query(features[b, 0], sindex) |
|
iou = calculate_iou(pred, masks[b, m, 0, :]) |
|
cum_iou += iou |
|
n_samples += 1 |
|
|
|
return cum_iou / n_samples |
|
|
|
|
|
# Names accepted by get_loss_class / get_loss / get_get_mask_from_query.
losses_names = ["iis"]
|
|
|
|
|
|
|
def get_loss_class(loss_name):
    """Return the loss class registered under ``loss_name``.

    Raises:
        NotImplementedError: if ``loss_name`` is not a known loss (currently
            only ``"iis"``). The message names the rejected value so that
            configuration typos are easy to spot.
    """
    if loss_name == "iis":
        return IISLoss
    raise NotImplementedError(f"Unknown loss name: {loss_name!r}")
|
|
|
|
|
def get_get_mask_from_query(loss_name):
    """Return the ``get_mask_from_query`` static method of the loss named ``loss_name``."""
    return get_loss_class(loss_name).get_mask_from_query
|
|
|
|
|
def get_loss(loss_name):
    """Return the ``loss`` static method of the loss named ``loss_name``."""
    return get_loss_class(loss_name).loss
|
|