# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
from mmcv.ops import point_sample
from mmengine.structures import InstanceData
from torch import Tensor

from mmseg.registry import TASK_UTILS
from mmseg.utils import ConfigType, SampleList


def seg_data_to_instance_data(ignore_index: int,
                              batch_data_samples: SampleList
                              ) -> List[InstanceData]:
    """Convert the paradigm of ground truth from semantic segmentation to
    instance segmentation.

    Args:
        ignore_index (int): The label index to be ignored.
        batch_data_samples (List[SegDataSample]): The Data
            Samples. It usually includes information such as
            `gt_sem_seg`.

    Returns:
        tuple[Tensor]: A tuple contains two lists.
            - batch_gt_instances (List[InstanceData]): Batch of
                gt_instance. It usually includes ``labels``, each is
                unique ground truth label id of images, with
                shape (num_gt, ) and ``masks``, each is ground truth
                masks of each instances of a image, shape (num_gt, h, w).
            - batch_img_metas (List[Dict]): List of image meta information.
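
    Examples:
        An illustrative sketch (hypothetical values, not part of the
        original API docs): a (1, 2, 2) semantic map with classes {0, 1}
        and one pixel labeled with the ignore index 255.

        >>> import torch
        >>> from mmengine.structures import PixelData
        >>> from mmseg.structures import SegDataSample
        >>> data_sample = SegDataSample()
        >>> data_sample.gt_sem_seg = PixelData(
        ...     data=torch.tensor([[[0, 0], [1, 255]]]))
        >>> instances = seg_data_to_instance_data(255, [data_sample])
        >>> sorted(instances[0].labels.tolist())  # ignore index is dropped
        [0, 1]
        >>> instances[0].masks.shape
        torch.Size([2, 2, 2])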
    """
    batch_gt_instances = []

    for data_sample in batch_data_samples:
        gt_sem_seg = data_sample.gt_sem_seg.data
        classes = torch.unique(gt_sem_seg, sorted=False)

        # remove ignored region
        gt_labels = classes[classes != ignore_index]

        masks = []
        for class_id in gt_labels:
            masks.append(gt_sem_seg == class_id)

        if len(masks) == 0:
            gt_masks = gt_sem_seg.new_zeros(
                (0, gt_sem_seg.shape[-2], gt_sem_seg.shape[-1]),
                dtype=torch.long)
        else:
            gt_masks = torch.stack(masks).squeeze(1).long()

        instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
        batch_gt_instances.append(instance_data)
    return batch_gt_instances


class MatchMasks:
    """Match the predictions to category labels.

    Args:
        num_points (int): the number of sampled points to compute cost.
        num_queries (int): the number of prediction masks.
        num_classes (int): the number of classes.
        assigner (BaseAssigner): the assigner to compute matching.
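
    Examples:
        A minimal usage sketch. The assigner config mirrors a typical
        Mask2Former ``train_cfg``; the exact cost types and weights below
        are illustrative choices, not the only valid ones.

        >>> import torch
        >>> from mmengine.structures import InstanceData
        >>> matcher = MatchMasks(
        ...     num_points=112,
        ...     num_queries=100,
        ...     num_classes=19,
        ...     assigner=dict(
        ...         type='HungarianAssigner',
        ...         match_costs=[
        ...             dict(type='ClassificationCost', weight=2.0),
        ...             dict(type='CrossEntropyLossCost', weight=5.0,
        ...                  use_sigmoid=True),
        ...             dict(type='DiceCost', weight=5.0, pred_act=True,
        ...                  eps=1.0)
        ...         ]))
        >>> cls_scores = torch.randn(2, 100, 20)  # (bs, num_queries, cls+1)
        >>> mask_preds = torch.randn(2, 100, 64, 64)
        >>> gt = [InstanceData(labels=torch.tensor([1, 7]),
        ...                    masks=torch.randint(0, 2, (2, 64, 64)))
        ...       for _ in range(2)]
        >>> targets = matcher.get_targets(cls_scores, mask_preds, gt)
        >>> labels, mask_targets, mask_weights, avg_factor = targets
        >>> labels.shape, mask_weights.shape, avg_factor
        (torch.Size([2, 100]), torch.Size([2, 100]), 4)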
    """

    def __init__(self,
                 num_points: int,
                 num_queries: int,
                 num_classes: int,
                 assigner: ConfigType = None):
        assert assigner is not None, \
            "'assigner' in decode_head.train_cfg cannot be None"
        assert num_points > 0, 'num_points should be a positive integer.'
        self.num_points = num_points
        self.num_queries = num_queries
        self.num_classes = num_classes
        self.assigner = TASK_UTILS.build(assigner)

    def get_targets(self, cls_scores: List[Tensor], mask_preds: List[Tensor],
                    batch_gt_instances: List[InstanceData]) -> Tuple:
        """Compute best mask matches for all images for a decoder layer.

        Args:
            cls_scores (List[Tensor]): Classification score logits from a
                single decoder layer for all images. Each with shape
                (num_queries, cls_out_channels).
            mask_preds (List[Tensor]): Mask logits from a single decoder
                layer for all images. Each with shape (num_queries, h, w).
            batch_gt_instances (List[InstanceData]): Each contains
                ``labels`` and ``masks``.

        Returns:
            tuple: a tuple containing the following targets.

                - labels (List[Tensor]): Labels of all images.\
                    Each with shape (num_queries, ).
                - mask_targets (List[Tensor]): Mask targets of\
                    all images. Each with shape (num_queries, h, w).
                - mask_weights (List[Tensor]): Mask weights of\
                    all images. Each with shape (num_queries, ).
                - avg_factor (int): Average factor that is used to
                    average the loss. `avg_factor` is usually equal
                    to the number of positive priors.
        """
        # ``len`` works whether ``cls_scores`` is a list of per-image
        # tensors or a single batched tensor
        batch_size = len(cls_scores)
        results = {
            'labels': [],
            'mask_targets': [],
            'mask_weights': [],
        }
        for i in range(batch_size):
            labels, mask_targets, mask_weights\
                = self._get_targets_single(cls_scores[i],
                                           mask_preds[i],
                                           batch_gt_instances[i])
            results['labels'].append(labels)
            results['mask_targets'].append(mask_targets)
            results['mask_weights'].append(mask_weights)

        # shape (batch_size, num_queries)
        labels = torch.stack(results['labels'], dim=0)
        # shape (sum(num_gts), h, w), concatenated over images
        mask_targets = torch.cat(results['mask_targets'], dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(results['mask_weights'], dim=0)

        avg_factor = sum(
            len(gt_instances.labels) for gt_instances in batch_gt_instances)

        return labels, mask_targets, mask_weights, avg_factor

    def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
                            gt_instances: InstanceData) \
            -> Tuple[Tensor, Tensor, Tensor]:
        """Compute a set of best mask matches for one image.

        Args:
            cls_score (Tensor): Classification score logits from a single
                decoder layer for one image. Shape (num_queries,
                cls_out_channels).
            mask_pred (Tensor): Mask logits from a single decoder layer for
                one image. Shape (num_queries, h, w).
            gt_instances (:obj:`InstanceData`): It contains ``labels`` and
                ``masks``.

        Returns:
            tuple[Tensor]: A tuple containing the following for one image.

                - labels (Tensor): Labels of each image. \
                    shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of each image. \
                    shape (num_queries, h, w).
                - mask_weights (Tensor): Mask weights of each image. \
                    shape (num_queries, ).
        """
        gt_labels = gt_instances.labels
        gt_masks = gt_instances.masks
        # when "gt_labels" is empty, classify all queries to background
        if len(gt_labels) == 0:
            labels = gt_labels.new_full((self.num_queries, ),
                                        self.num_classes,
                                        dtype=torch.long)
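            # ``mask_targets`` is only an empty placeholder here: every
            # mask weight is zero, so it never contributes to the loss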
            mask_targets = gt_labels
            mask_weights = gt_labels.new_zeros((self.num_queries, ))
            return labels, mask_targets, mask_weights
        # sample points
        num_queries = cls_score.shape[0]
        num_gts = gt_labels.shape[0]

        point_coords = torch.rand((1, self.num_points, 2),
                                  device=cls_score.device)
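        # the same random points (normalized (x, y) coords in [0, 1], as
        # ``point_sample`` expects) are evaluated on every predicted and
        # gt mask, so the per-point matching costs are directly comparable
        # while staying much cheaper than comparing full (h, w) masks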
        # shape (num_queries, num_points)
        mask_points_pred = point_sample(
            mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
                                                        1)).squeeze(1)
        # shape (num_gts, num_points)
        gt_points_masks = point_sample(
            gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
                                                               1)).squeeze(1)

        # wrap the sampled points in InstanceData so the assigner sees the
        # same interface it would get for full-resolution masks
        sampled_gt_instances = InstanceData(
            labels=gt_labels, masks=gt_points_masks)
        sampled_pred_instances = InstanceData(
            scores=cls_score, masks=mask_points_pred)
        # assign each query to at most one gt instance via the configured
        # assigner (e.g. Hungarian matching on the sampled-point costs)
        matched_query_inds, matched_label_inds = self.assigner.assign(
            pred_instances=sampled_pred_instances,
            gt_instances=sampled_gt_instances)
        # matched queries take the label of their gt instance; unmatched
        # queries keep the background label ``self.num_classes``
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[matched_query_inds] = gt_labels[matched_label_inds]

        # only matched queries contribute to the mask loss
        mask_weights = gt_labels.new_zeros((self.num_queries, ))
        mask_weights[matched_query_inds] = 1
        mask_targets = gt_masks[matched_label_inds]

        return labels, mask_targets, mask_weights