# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
from mmcv.ops import point_sample
from mmengine.structures import InstanceData
from torch import Tensor

from mmseg.registry import TASK_UTILS
from mmseg.utils import ConfigType, SampleList


def seg_data_to_instance_data(ignore_index: int,
                              batch_data_samples: SampleList):
    """Convert the paradigm of ground truth from semantic segmentation to
    instance segmentation.

    Args:
        ignore_index (int): The label index to be ignored.
        batch_data_samples (List[SegDataSample]): The Data Samples.
            It usually includes information such as `gt_sem_seg`.

    Returns:
        List[InstanceData]: Batch of gt_instance. Each ``InstanceData``
            usually includes ``labels``, the unique ground truth label ids
            of one image with shape (num_gt, ), and ``masks``, the binary
            ground truth mask of each instance with shape (num_gt, h, w).
    """
    batch_gt_instances = []

    for data_sample in batch_data_samples:
        gt_sem_seg = data_sample.gt_sem_seg.data
        classes = torch.unique(
            gt_sem_seg,
            sorted=False,
            return_inverse=False,
            return_counts=False)

        # remove ignored region
        gt_labels = classes[classes != ignore_index]

        # build one binary mask per remaining class id
        masks = []
        for class_id in gt_labels:
            masks.append(gt_sem_seg == class_id)

        if len(masks) == 0:
            # no foreground classes in this image: emit an empty (0, h, w)
            # tensor on the same device as the ground truth
            gt_masks = torch.zeros(
                (0, gt_sem_seg.shape[-2],
                 gt_sem_seg.shape[-1])).to(gt_sem_seg).long()
        else:
            gt_masks = torch.stack(masks).squeeze(1).long()

        instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
        batch_gt_instances.append(instance_data)
    return batch_gt_instances


class MatchMasks:
    """Match the predictions to category labels.

    Args:
        num_points (int): the number of sampled points to compute cost.
        num_queries (int): the number of prediction masks.
        num_classes (int): the number of classes.
        assigner (BaseAssigner): the assigner to compute matching.
    """

    def __init__(self,
                 num_points: int,
                 num_queries: int,
                 num_classes: int,
                 assigner: ConfigType = None):
        # NOTE: original message concatenated two literals without a space
        # ("...train_cfgcannot be None"); fixed here.
        assert assigner is not None, \
            "'assigner' in decode_head.train_cfg cannot be None"
        assert num_points > 0, 'num_points should be a positive integer.'
        self.num_points = num_points
        self.num_queries = num_queries
        self.num_classes = num_classes
        self.assigner = TASK_UTILS.build(assigner)

    def get_targets(self, cls_scores: Tensor, mask_preds: Tensor,
                    batch_gt_instances: List[InstanceData]) -> Tuple:
        """Compute best mask matches for all images for a decoder layer.

        Args:
            cls_scores (Tensor): Mask score logits from a single decoder
                layer for all images, with shape
                (batch_size, num_queries, cls_out_channels).
            mask_preds (Tensor): Mask logits from a single decoder layer
                for all images, with shape (batch_size, num_queries, h, w).
            batch_gt_instances (List[InstanceData]): each contains
                ``labels`` and ``masks``.

        Returns:
            tuple: a tuple containing the following targets.

                - labels (Tensor): Labels of all images, with shape
                    (batch_size, num_queries).
                - mask_targets (Tensor): Mask targets of all images,
                    concatenated along dim 0, with shape
                    (num_total_gts, h, w).
                - mask_weights (Tensor): Mask weights of all images,
                    with shape (batch_size, num_queries).
                - avg_factor (int): Average factor that is used to
                    average the loss. `avg_factor` is usually equal
                    to the number of positive priors.
        """
        # cls_scores/mask_preds are indexed with .shape[0] and [i], so they
        # must be batched Tensors (the previous List[Tensor] hint was wrong).
        batch_size = cls_scores.shape[0]
        results = {
            'labels': [],
            'mask_targets': [],
            'mask_weights': [],
        }
        for i in range(batch_size):
            labels, mask_targets, mask_weights \
                = self._get_targets_single(cls_scores[i], mask_preds[i],
                                           batch_gt_instances[i])
            results['labels'].append(labels)
            results['mask_targets'].append(mask_targets)
            results['mask_weights'].append(mask_weights)

        # shape (batch_size, num_queries)
        labels = torch.stack(results['labels'], dim=0)
        # shape (num_total_gts, h, w): per-image matched gt masks are
        # concatenated (not stacked) because each image may contribute a
        # different number of ground truths
        mask_targets = torch.cat(results['mask_targets'], dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(results['mask_weights'], dim=0)

        avg_factor = sum(
            [len(gt_instances.labels) for gt_instances in batch_gt_instances])

        res = (labels, mask_targets, mask_weights, avg_factor)

        return res

    def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
                            gt_instances: InstanceData) \
            -> Tuple[Tensor, Tensor, Tensor]:
        """Compute a set of best mask matches for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder
                layer for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for
                one image. Shape (num_queries, h, w).
            gt_instances (:obj:`InstanceData`): It contains ``labels`` and
                ``masks``.

        Returns:
            tuple[Tensor]: A tuple containing the following for one image.

                - labels (Tensor): Labels of each image.
                    shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of each image.
                    shape (num_matched_gts, h, w).
                - mask_weights (Tensor): Mask weights of each image.
                    shape (num_queries, ).
        """
        gt_labels = gt_instances.labels
        gt_masks = gt_instances.masks
        # when "gt_labels" is empty, classify all queries to background
        if len(gt_labels) == 0:
            labels = gt_labels.new_full((self.num_queries, ),
                                        self.num_classes,
                                        dtype=torch.long)
            mask_targets = gt_labels
            mask_weights = gt_labels.new_zeros((self.num_queries, ))
            return labels, mask_targets, mask_weights

        # sample the SAME random point set for predictions and ground truth
        # so per-point matching costs are compared at identical locations
        num_queries = cls_score.shape[0]
        num_gts = gt_labels.shape[0]
        point_coords = torch.rand((1, self.num_points, 2),
                                  device=cls_score.device)
        # shape (num_queries, num_points)
        mask_points_pred = point_sample(
            mask_pred.unsqueeze(1),
            point_coords.repeat(num_queries, 1, 1)).squeeze(1)
        # shape (num_gts, num_points)
        gt_points_masks = point_sample(
            gt_masks.unsqueeze(1).float(),
            point_coords.repeat(num_gts, 1, 1)).squeeze(1)

        sampled_gt_instances = InstanceData(
            labels=gt_labels, masks=gt_points_masks)
        sampled_pred_instances = InstanceData(
            scores=cls_score, masks=mask_points_pred)
        # assign and sample (fixed "quiery" typo in the local names below)
        matched_query_inds, matched_label_inds = self.assigner.assign(
            pred_instances=sampled_pred_instances,
            gt_instances=sampled_gt_instances)
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[matched_query_inds] = gt_labels[matched_label_inds]

        mask_weights = gt_labels.new_zeros((self.num_queries, ))
        mask_weights[matched_query_inds] = 1
        mask_targets = gt_masks[matched_label_inds]

        return labels, mask_targets, mask_weights