# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
from mmcv.ops import point_sample
from mmengine.structures import InstanceData
from torch import Tensor

from mmseg.registry import TASK_UTILS
from mmseg.utils import ConfigType, SampleList


def seg_data_to_instance_data(ignore_index: int,
batch_data_samples: SampleList):
"""Convert the paradigm of ground truth from semantic segmentation to
instance segmentation.
Args:
ignore_index (int): The label index to be ignored.
batch_data_samples (List[SegDataSample]): The Data
Samples. It usually includes information such as
`gt_sem_seg`.
Returns:
tuple[Tensor]: A tuple contains two lists.
- batch_gt_instances (List[InstanceData]): Batch of
gt_instance. It usually includes ``labels``, each is
unique ground truth label id of images, with
shape (num_gt, ) and ``masks``, each is ground truth
masks of each instances of a image, shape (num_gt, h, w).
- batch_img_metas (List[Dict]): List of image meta information.
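
    Example:
        A minimal illustrative sketch; the toy 1x4x4 semantic map and
        ``ignore_index=255`` below are made-up values:

        >>> import torch
        >>> from mmengine.structures import PixelData
        >>> from mmseg.structures import SegDataSample
        >>> sem_seg = torch.tensor([[[0, 0, 1, 1],
        ...                          [0, 0, 1, 1],
        ...                          [2, 2, 255, 255],
        ...                          [2, 2, 255, 255]]])
        >>> data_sample = SegDataSample()
        >>> data_sample.gt_sem_seg = PixelData(data=sem_seg)
        >>> instances = seg_data_to_instance_data(255, [data_sample])
        >>> labels, masks = instances[0].labels, instances[0].masks
        >>> masks.shape  # one binary mask per non-ignored class id
        torch.Size([3, 4, 4])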
"""
batch_gt_instances = []
for data_sample in batch_data_samples:
gt_sem_seg = data_sample.gt_sem_seg.data
classes = torch.unique(
gt_sem_seg,
sorted=False,
return_inverse=False,
return_counts=False)
# remove ignored region
gt_labels = classes[classes != ignore_index]
masks = []
for class_id in gt_labels:
masks.append(gt_sem_seg == class_id)
if len(masks) == 0:
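            # every pixel of this image is ignored; build an empty
            # (0, h, w) mask tensor so downstream code still works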
gt_masks = torch.zeros(
(0, gt_sem_seg.shape[-2],
gt_sem_seg.shape[-1])).to(gt_sem_seg).long()
else:
gt_masks = torch.stack(masks).squeeze(1).long()
instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
batch_gt_instances.append(instance_data)
return batch_gt_instances


class MatchMasks:
    """Match the mask predictions to ground-truth instances and generate the
    classification and mask targets for each query.

Args:
num_points (int): the number of sampled points to compute cost.
num_queries (int): the number of prediction masks.
num_classes (int): the number of classes.
assigner (BaseAssigner): the assigner to compute matching.
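
    Example:
        An illustrative construction; the assigner config below is an
        assumption (any assigner registered in ``TASK_UTILS`` can be used)
        and the numbers are typical values, not requirements:

        >>> assigner_cfg = dict(
        ...     type='HungarianAssigner',
        ...     match_costs=[
        ...         dict(type='ClassificationCost', weight=2.0),
        ...         dict(type='CrossEntropyLossCost', weight=5.0,
        ...              use_sigmoid=True),
        ...         dict(type='DiceCost', weight=5.0, pred_act=True, eps=1.0)
        ...     ])
        >>> matcher = MatchMasks(
        ...     num_points=12544, num_queries=100, num_classes=150,
        ...     assigner=assigner_cfg)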
"""
def __init__(self,
num_points: int,
num_queries: int,
num_classes: int,
assigner: ConfigType = None):
        assert assigner is not None, \
            "'assigner' in decode_head.train_cfg cannot be None"
assert num_points > 0, 'num_points should be a positive integer.'
self.num_points = num_points
self.num_queries = num_queries
self.num_classes = num_classes
self.assigner = TASK_UTILS.build(assigner)
    def get_targets(self, cls_scores: Tensor, mask_preds: Tensor,
                    batch_gt_instances: List[InstanceData]) -> Tuple:
        """Compute the best mask matches for all images for a decoder layer.

        Args:
            cls_scores (Tensor): Classification score logits from a single
                decoder layer for all images, with shape
                (batch_size, num_queries, cls_out_channels).
            mask_preds (Tensor): Mask logits from a single decoder layer
                for all images, with shape (batch_size, num_queries, h, w).
            batch_gt_instances (List[InstanceData]): Each contains
                ``labels`` and ``masks``.

        Returns:
            tuple: A tuple containing the following targets.

                - labels (Tensor): Labels of all images, with shape
                  (batch_size, num_queries).
                - mask_targets (Tensor): Mask targets of all images,
                  concatenated along the batch dimension, with shape
                  (num_total_gts, h, w).
                - mask_weights (Tensor): Mask weights of all images, with
                  shape (batch_size, num_queries).
                - avg_factor (int): Average factor that is used to average
                  the loss, usually equal to the total number of
                  ground-truth instances in the batch.
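
        Example:
            A shape-level sketch only; ``matcher`` is assumed to be a built
            ``MatchMasks`` instance, ``batch_gt_instances`` comes from
            ``seg_data_to_instance_data`` and the sizes below are made up:

            >>> cls_scores = torch.randn(2, 100, 151)
            >>> mask_preds = torch.randn(2, 100, 64, 64)
            >>> targets = matcher.get_targets(cls_scores, mask_preds,
            ...                               batch_gt_instances)
            >>> labels, mask_targets, mask_weights, avg_factor = targets
            >>> labels.shape, mask_weights.shape
            (torch.Size([2, 100]), torch.Size([2, 100]))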
"""
batch_size = cls_scores.shape[0]
        results = {
            'labels': [],
            'mask_targets': [],
            'mask_weights': [],
        }
for i in range(batch_size):
labels, mask_targets, mask_weights\
= self._get_targets_single(cls_scores[i],
mask_preds[i],
batch_gt_instances[i])
results['labels'].append(labels)
results['mask_targets'].append(mask_targets)
results['mask_weights'].append(mask_weights)
# shape (batch_size, num_queries)
labels = torch.stack(results['labels'], dim=0)
        # shape (num_total_gts, h, w)
mask_targets = torch.cat(results['mask_targets'], dim=0)
# shape (batch_size, num_queries)
mask_weights = torch.stack(results['mask_weights'], dim=0)
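        # the average factor for the loss is the total number of
        # ground-truth instances across the whole batch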
avg_factor = sum(
[len(gt_instances.labels) for gt_instances in batch_gt_instances])
        return labels, mask_targets, mask_weights, avg_factor
def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
gt_instances: InstanceData) \
-> Tuple[Tensor, Tensor, Tensor]:
"""Compute a set of best mask matches for one image.
Args:
            cls_score (Tensor): Classification score logits from a single
                decoder layer for one image. Shape
                (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits from a single decoder layer for
                one image. Shape (num_queries, h, w).
gt_instances (:obj:`InstanceData`): It contains ``labels`` and
``masks``.
Returns:
            tuple[Tensor]: A tuple containing the following for one image.

                - labels (Tensor): Labels of one image, with
                  shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of one image, with
                  shape (num_gts, h, w).
                - mask_weights (Tensor): Mask weights of one image, with
                  shape (num_queries, ).
"""
gt_labels = gt_instances.labels
gt_masks = gt_instances.masks
# when "gt_labels" is empty, classify all queries to background
if len(gt_labels) == 0:
labels = gt_labels.new_full((self.num_queries, ),
self.num_classes,
dtype=torch.long)
            mask_targets = gt_labels  # an empty tensor: nothing to supervise
mask_weights = gt_labels.new_zeros((self.num_queries, ))
return labels, mask_targets, mask_weights
# sample points
num_queries = cls_score.shape[0]
num_gts = gt_labels.shape[0]
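        # uniformly sample normalized (x, y) point coordinates in [0, 1);
        # the same points are used for both predictions and ground truth,
        # so the assignment cost is computed on a sparse set of points
        # instead of on full-resolution masks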
point_coords = torch.rand((1, self.num_points, 2),
device=cls_score.device)
# shape (num_queries, num_points)
mask_points_pred = point_sample(
mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
1)).squeeze(1)
# shape (num_gts, num_points)
gt_points_masks = point_sample(
gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
1)).squeeze(1)
sampled_gt_instances = InstanceData(
labels=gt_labels, masks=gt_points_masks)
sampled_pred_instances = InstanceData(
scores=cls_score, masks=mask_points_pred)
        # assign each query to a ground-truth instance (or to background)
        matched_query_inds, matched_label_inds = self.assigner.assign(
pred_instances=sampled_pred_instances,
gt_instances=sampled_gt_instances)
labels = gt_labels.new_full((self.num_queries, ),
self.num_classes,
dtype=torch.long)
        labels[matched_query_inds] = gt_labels[matched_label_inds]
mask_weights = gt_labels.new_zeros((self.num_queries, ))
        mask_weights[matched_query_inds] = 1
mask_targets = gt_masks[matched_label_inds]
return labels, mask_targets, mask_weights