# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
from mmcv.ops import point_sample
from mmengine.structures import InstanceData
from torch import Tensor

from mmseg.registry import TASK_UTILS
from mmseg.utils import ConfigType, SampleList


def seg_data_to_instance_data(ignore_index: int,
                              batch_data_samples: SampleList):
    """Convert the paradigm of ground truth from semantic segmentation to
    instance segmentation.

    Args:
        ignore_index (int): The label index to be ignored.
        batch_data_samples (List[SegDataSample]): The Data Samples.
            It usually includes information such as `gt_sem_seg`.

    Returns:
        List[InstanceData]: Batch of gt_instance. Each ``InstanceData``
            usually includes ``labels``, the unique ground truth label
            ids of one image with shape (num_gt, ), and ``masks``, the
            ground truth mask of each instance in that image with shape
            (num_gt, h, w).
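
    Example:
        A minimal sketch of the conversion. ``SegDataSample`` and
        ``PixelData`` are the real mmseg/mmengine structures; the tensor
        values below are made up for illustration:

        >>> from mmengine.structures import PixelData
        >>> from mmseg.structures import SegDataSample
        >>> sem_seg = torch.tensor([[[0, 0, 1], [1, 255, 255]]])
        >>> data_sample = SegDataSample()
        >>> data_sample.gt_sem_seg = PixelData(data=sem_seg)
        >>> gt_instances = seg_data_to_instance_data(
        ...     ignore_index=255, batch_data_samples=[data_sample])
        >>> sorted(gt_instances[0].labels.tolist())  # 255 is dropped
        [0, 1]
        >>> gt_instances[0].masks.shape  # one binary mask per class
        torch.Size([2, 2, 3])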
""" | |
    batch_gt_instances = []
    for data_sample in batch_data_samples:
        gt_sem_seg = data_sample.gt_sem_seg.data
        classes = torch.unique(
            gt_sem_seg,
            sorted=False,
            return_inverse=False,
            return_counts=False)

        # remove ignored region
        gt_labels = classes[classes != ignore_index]

        masks = []
        for class_id in gt_labels:
            masks.append(gt_sem_seg == class_id)

        if len(masks) == 0:
            gt_masks = torch.zeros(
                (0, gt_sem_seg.shape[-2],
                 gt_sem_seg.shape[-1])).to(gt_sem_seg).long()
        else:
            # gt_sem_seg has shape (1, h, w); stacking gives
            # (num_gt, 1, h, w), so drop the channel dim
            gt_masks = torch.stack(masks).squeeze(1).long()

        instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
        batch_gt_instances.append(instance_data)
    return batch_gt_instances


class MatchMasks:
    """Match the predictions to category labels.

    Args:
        num_points (int): the number of sampled points to compute cost.
        num_queries (int): the number of prediction masks.
        num_classes (int): the number of classes.
        assigner (BaseAssigner): the assigner to compute matching.
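
    Example:
        A construction sketch. The assigner config mirrors a typical
        Hungarian-matching setup built through ``TASK_UTILS``; the exact
        cost types and weights below are illustrative, not prescriptive:

        >>> assigner = dict(
        ...     type='HungarianAssigner',
        ...     match_costs=[
        ...         dict(type='ClassificationCost', weight=2.0),
        ...         dict(type='CrossEntropyLossCost', weight=5.0,
        ...              use_sigmoid=True),
        ...         dict(type='DiceCost', weight=5.0,
        ...              pred_act=True, eps=1.0)
        ...     ])
        >>> matcher = MatchMasks(
        ...     num_points=12544, num_queries=100,
        ...     num_classes=19, assigner=assigner)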
""" | |

    def __init__(self,
                 num_points: int,
                 num_queries: int,
                 num_classes: int,
                 assigner: ConfigType = None):
        assert assigner is not None, \
            "'assigner' in decode_head.train_cfg cannot be None"
        assert num_points > 0, 'num_points should be a positive integer.'
        self.num_points = num_points
        self.num_queries = num_queries
        self.num_classes = num_classes
        self.assigner = TASK_UTILS.build(assigner)

    def get_targets(self, cls_scores: Tensor, mask_preds: Tensor,
                    batch_gt_instances: List[InstanceData]) -> Tuple:
        """Compute best mask matches for all images for a decoder layer.

        Args:
            cls_scores (Tensor): Classification score logits from a single
                decoder layer for all images, with shape
                (batch_size, num_queries, cls_out_channels).
            mask_preds (Tensor): Mask logits from a single decoder layer
                for all images, with shape (batch_size, num_queries, h, w).
            batch_gt_instances (List[InstanceData]): each contains
                ``labels`` and ``masks``.

        Returns:
            tuple: a tuple containing the following targets.

                - labels (Tensor): Labels of all images, with shape
                  (batch_size, num_queries, ).
                - mask_targets (Tensor): Mask targets of all images,
                  with shape (num_total_gts, h, w).
                - mask_weights (Tensor): Mask weights of all images,
                  with shape (batch_size, num_queries, ).
                - avg_factor (int): Average factor that is used to
                  average the loss. `avg_factor` is usually equal
                  to the number of positive priors.
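
        Example:
            A shape-level sketch with random inputs; ``matcher`` is a
            ``MatchMasks`` instance (100 queries, 19 classes) and
            ``batch_gt_instances`` comes from
            ``seg_data_to_instance_data``:

            >>> cls_scores = torch.randn(2, 100, 20)
            >>> mask_preds = torch.randn(2, 100, 64, 64)
            >>> (labels, mask_targets,
            ...  mask_weights, avg_factor) = matcher.get_targets(
            ...      cls_scores, mask_preds, batch_gt_instances)
            >>> labels.shape, mask_weights.shape
            (torch.Size([2, 100]), torch.Size([2, 100]))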
""" | |
        batch_size = cls_scores.shape[0]
        results = {
            'labels': [],
            'mask_targets': [],
            'mask_weights': [],
        }
        for i in range(batch_size):
            labels, mask_targets, mask_weights = self._get_targets_single(
                cls_scores[i], mask_preds[i], batch_gt_instances[i])
            results['labels'].append(labels)
            results['mask_targets'].append(mask_targets)
            results['mask_weights'].append(mask_weights)

        # shape (batch_size, num_queries)
        labels = torch.stack(results['labels'], dim=0)
        # shape (num_total_gts, h, w)
        mask_targets = torch.cat(results['mask_targets'], dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(results['mask_weights'], dim=0)
        avg_factor = sum(
            len(gt_instances.labels) for gt_instances in batch_gt_instances)

        return labels, mask_targets, mask_weights, avg_factor

    def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
                            gt_instances: InstanceData
                            ) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute a set of best mask matches for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for one
                image. Shape (num_queries, h, w).
            gt_instances (:obj:`InstanceData`): It contains ``labels`` and
                ``masks``.

        Returns:
            tuple[Tensor]: A tuple containing the following for one image.

                - labels (Tensor): Labels for one image.
                  Shape (num_queries, ).
                - mask_targets (Tensor): Mask targets for one image.
                  Shape (num_gts, h, w).
                - mask_weights (Tensor): Mask weights for one image.
                  Shape (num_queries, ).
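
        Example:
            A shape-level sketch for one image with random inputs
            (``matcher`` is a ``MatchMasks`` instance with 100 queries;
            20 classification channels assumed):

            >>> cls_score = torch.randn(100, 20)
            >>> mask_pred = torch.randn(100, 64, 64)
            >>> gt = InstanceData(
            ...     labels=torch.tensor([3, 7]),
            ...     masks=torch.randint(0, 2, (2, 64, 64)))
            >>> labels, mask_targets, mask_weights = (
            ...     matcher._get_targets_single(cls_score, mask_pred, gt))
            >>> labels.shape, mask_targets.shape, mask_weights.shape
            (torch.Size([100]), torch.Size([2, 64, 64]), torch.Size([100]))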
""" | |
        gt_labels = gt_instances.labels
        gt_masks = gt_instances.masks
        # when "gt_labels" is empty, classify all queries to background
        if len(gt_labels) == 0:
            labels = gt_labels.new_full((self.num_queries, ),
                                        self.num_classes,
                                        dtype=torch.long)
            # an empty (0, )-shaped tensor; it contributes no rows when
            # concatenated across the batch in `get_targets`
            mask_targets = gt_labels
            mask_weights = gt_labels.new_zeros((self.num_queries, ))
            return labels, mask_targets, mask_weights

        # sample points
        num_queries = cls_score.shape[0]
        num_gts = gt_labels.shape[0]
        point_coords = torch.rand((1, self.num_points, 2),
                                  device=cls_score.device)
        # shape (num_queries, num_points)
        mask_points_pred = point_sample(
            mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
                                                        1)).squeeze(1)
        # shape (num_gts, num_points)
        gt_points_masks = point_sample(
            gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
                                                               1)).squeeze(1)

        sampled_gt_instances = InstanceData(
            labels=gt_labels, masks=gt_points_masks)
        sampled_pred_instances = InstanceData(
            scores=cls_score, masks=mask_points_pred)
        # assign and sample
        matched_query_inds, matched_label_inds = self.assigner.assign(
            pred_instances=sampled_pred_instances,
            gt_instances=sampled_gt_instances)
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[matched_query_inds] = gt_labels[matched_label_inds]

        mask_weights = gt_labels.new_zeros((self.num_queries, ))
        mask_weights[matched_query_inds] = 1
        mask_targets = gt_masks[matched_label_inds]
        return labels, mask_targets, mask_weights
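

# A minimal end-to-end sketch of how these utilities are typically wired
# together inside a decode head's loss computation. The names `matcher`,
# `all_cls_scores` and `all_mask_preds` are illustrative, not part of
# this module:
#
#     batch_gt_instances = seg_data_to_instance_data(
#         ignore_index=255, batch_data_samples=batch_data_samples)
#     for cls_scores, mask_preds in zip(all_cls_scores, all_mask_preds):
#         labels, mask_targets, mask_weights, avg_factor = \
#             matcher.get_targets(cls_scores, mask_preds,
#                                 batch_gt_instances)
#         # ...compute classification and mask losses from these targets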