Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from abc import abstractmethod | |
from typing import Union | |
import torch | |
import torch.nn.functional as F | |
from mmengine.structures import InstanceData | |
from torch import Tensor | |
from mmseg.registry import TASK_UTILS | |
class BaseMatchCost:
    """Common parent of the match cost callables.

    It only records the weighting factor; the actual cost computation is
    supplied by subclasses that override ``__call__``.

    Args:
        weight (Union[float, int]): Factor stored as ``self.weight``.
            Defaults to 1.
    """

    def __init__(self, weight: Union[float, int] = 1.) -> None:
        # Stored as given (no casting) so subclasses can multiply by it.
        self.weight = weight

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute the match cost (placeholder in the base class).

        The base implementation does nothing and implicitly returns
        ``None``; subclasses override it with a real computation.

        Args:
            pred_instances (InstanceData): Instances of model predictions.
                It often includes "labels" and "scores".
            gt_instances (InstanceData): Ground truth of instance
                annotations. It usually includes "labels".

        Returns:
            Tensor: In subclasses, the match cost matrix of shape
            (num_preds, num_gts).
        """
class ClassificationCost(BaseMatchCost):
    """Softmax-based classification match cost.

    The cost of pairing a prediction with a ground truth is the negated
    softmax probability the prediction assigns to the ground-truth class,
    multiplied by ``weight``.

    Args:
        weight (Union[float, int]): Cost weight. Defaults to 1.

    Examples:
        >>> import torch
        >>> from mmengine.structures import InstanceData
        >>> from mmseg.models.assigners import ClassificationCost
        >>> cost = ClassificationCost()
        >>> preds = InstanceData(scores=torch.zeros(4, 3))
        >>> gts = InstanceData(labels=torch.tensor([0, 1, 2]))
        >>> cost(preds, gts)
        tensor([[-0.3333, -0.3333, -0.3333],
                [-0.3333, -0.3333, -0.3333],
                [-0.3333, -0.3333, -0.3333],
                [-0.3333, -0.3333, -0.3333]])
    """

    def __init__(self, weight: Union[float, int] = 1) -> None:
        super().__init__(weight=weight)

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (InstanceData): Its "scores" field holds the
                predicted classification logits, of shape
                (num_queries, num_class).
            gt_instances (InstanceData): Its "labels" field should have
                shape (num_gt, ).

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).
        """
        assert hasattr(pred_instances, 'scores'), \
            "pred_instances must contain 'scores'"
        assert hasattr(gt_instances, 'labels'), \
            "gt_instances must contain 'labels'"

        logits = pred_instances.scores
        target_classes = gt_instances.labels

        # Normalise over the class axis, then pick, for every query, the
        # probability of each ground-truth class (one column per gt).
        probabilities = logits.softmax(dim=-1)
        return -probabilities[:, target_classes] * self.weight
class DiceCost(BaseMatchCost):
    """Cost of mask assignments based on dice losses.

    Args:
        pred_act (bool): Whether to apply sigmoid to mask_pred.
            Defaults to False.
        eps (float): Smoothing term added to both the numerator and the
            denominator of the dice ratio. Defaults to 1e-3.
        naive_dice (bool): If True, use the naive dice loss
            in which the power of the number in the denominator is
            the first power. If False, use the second power that
            is adopted by K-Net and SOLO. Defaults to True.
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self,
                 pred_act: bool = False,
                 eps: float = 1e-3,
                 naive_dice: bool = True,
                 weight: Union[float, int] = 1.) -> None:
        super().__init__(weight=weight)
        self.pred_act = pred_act
        self.eps = eps
        self.naive_dice = naive_dice

    def _binary_mask_dice_loss(self, mask_preds: Tensor,
                               gt_masks: Tensor) -> Tensor:
        """Compute the pairwise dice loss between predictions and targets.

        Args:
            mask_preds (Tensor): Mask prediction in shape (num_queries, *).
            gt_masks (Tensor): Ground truth in shape (num_gt, *)
                store 0 or 1, 0 for negative class and 1 for
                positive class.

        Returns:
            Tensor: Dice cost matrix in shape (num_queries, num_gt).
        """
        # Cast BOTH operands to float32. The ground truth was already
        # cast; doing the same for the predictions keeps the einsum
        # operands on a single dtype when predictions arrive as
        # half/bfloat16/double, and is a no-op for float32 inputs.
        mask_preds = mask_preds.flatten(1).float()
        gt_masks = gt_masks.flatten(1).float()

        # Pairwise overlap between every prediction and every gt mask.
        numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks)
        if self.naive_dice:
            # First-power denominator: sum of the mask values.
            denominator = mask_preds.sum(-1)[:, None] + \
                gt_masks.sum(-1)[None, :]
        else:
            # Second-power denominator: sum of the squared mask values.
            denominator = mask_preds.pow(2).sum(1)[:, None] + \
                gt_masks.pow(2).sum(1)[None, :]
        # eps on both sides keeps the ratio finite for empty masks.
        loss = 1 - (numerator + self.eps) / (denominator + self.eps)
        return loss

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (InstanceData): Predicted instances which
                must contain "masks".
            gt_instances (InstanceData): Ground truth which must contain
                "masks".

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).
        """
        assert hasattr(pred_instances, 'masks'), \
            "pred_instances must contain 'masks'"
        assert hasattr(gt_instances, 'masks'), \
            "gt_instances must contain 'masks'"
        pred_masks = pred_instances.masks
        gt_masks = gt_instances.masks

        # Predictions are used as given unless pred_act asks for a
        # sigmoid to be applied first.
        if self.pred_act:
            pred_masks = pred_masks.sigmoid()
        dice_cost = self._binary_mask_dice_loss(pred_masks, gt_masks)
        return dice_cost * self.weight
class CrossEntropyLossCost(BaseMatchCost):
    """Match cost based on binary cross entropy between masks.

    Args:
        use_sigmoid (bool): Whether the prediction uses sigmoid
            or softmax. Only the sigmoid path is implemented; passing
            False makes ``__call__`` raise ``NotImplementedError``.
            Defaults to True.
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self,
                 use_sigmoid: bool = True,
                 weight: Union[float, int] = 1.) -> None:
        super().__init__(weight=weight)
        self.use_sigmoid = use_sigmoid

    def _binary_cross_entropy(self, cls_pred: Tensor,
                              gt_labels: Tensor) -> Tensor:
        """Pairwise binary cross entropy between logits and targets.

        Args:
            cls_pred (Tensor): The prediction with shape
                (num_queries, 1, *) or (num_queries, *).
            gt_labels (Tensor): The learning label of prediction with
                shape (num_gt, *).

        Returns:
            Tensor: Cross entropy cost matrix in shape
            (num_queries, num_gt).
        """
        logits = cls_pred.flatten(1).float()
        targets = gt_labels.flatten(1).float()
        num_elements = logits.shape[1]

        # Element-wise loss each logit would incur if its target were 1
        # (resp. 0); the target-dependent mixing happens in the einsums.
        loss_as_positive = F.binary_cross_entropy_with_logits(
            logits, torch.ones_like(logits), reduction='none')
        loss_as_negative = F.binary_cross_entropy_with_logits(
            logits, torch.zeros_like(logits), reduction='none')

        positive_term = torch.einsum('nc,mc->nm', loss_as_positive, targets)
        negative_term = torch.einsum('nc,mc->nm', loss_as_negative,
                                     1 - targets)
        # Average over the flattened element axis.
        return (positive_term + negative_term) / num_elements

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (:obj:`InstanceData`): Predicted instances which
                must contain ``masks``.
            gt_instances (:obj:`InstanceData`): Ground truth which must
                contain ``masks``.

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).

        Raises:
            NotImplementedError: If ``use_sigmoid`` is False.
        """
        assert hasattr(pred_instances, 'masks'), \
            "pred_instances must contain 'masks'"
        assert hasattr(gt_instances, 'masks'), \
            "gt_instances must contain 'masks'"
        predicted = pred_instances.masks
        target = gt_instances.masks

        if not self.use_sigmoid:
            raise NotImplementedError
        return self._binary_cross_entropy(predicted, target) * self.weight