HichTala committed · commit 89b2487 · verified · 1 parent: 02988cf

Upload 5 files
configuration_diffusiondet.py CHANGED
@@ -91,8 +91,8 @@ class DiffusionDetConfig(PretrainedConfig):
 
         # Auto mapping
         self.auto_map = {
-            "AutoConfig": "diffusiondet.configuration_diffusiondet.DiffusionDetConfig",
-            "AutoModelForObjectDetection": "diffusiondet.modeling_diffusiondet.DiffusionDet"
+            "AutoConfig": "configuration_diffusiondet.DiffusionDetConfig",
+            "AutoModelForObjectDetection": "modeling_diffusiondet.DiffusionDet"
         }
 
         # Backbone.
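
The updated auto_map entries reference the modules relative to the repository root, which is what the Auto classes expect when the custom code files sit next to the weights on the Hub. A minimal sketch of how the mapping is consumed (the repo id below is a placeholder, not taken from this commit):

from transformers import AutoConfig, AutoModelForObjectDetection

config = AutoConfig.from_pretrained("<user>/<diffusiondet-repo>", trust_remote_code=True)
model = AutoModelForObjectDetection.from_pretrained("<user>/<diffusiondet-repo>", trust_remote_code=True)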
head.py ADDED
@@ -0,0 +1,386 @@
import copy
import math
from dataclasses import astuple

import torch
from torch import nn
from torch.nn.modules.transformer import _get_activation_fn
from torchvision.ops import RoIAlign

_DEFAULT_SCALE_CLAMP = math.log(1000.0 / 16)


def convert_boxes_to_pooler_format(bboxes):
    bs, num_proposals = bboxes.shape[:2]
    sizes = torch.full((bs,), num_proposals).to(bboxes.device)
    aggregated_bboxes = bboxes.view(bs * num_proposals, -1)
    indices = torch.repeat_interleave(
        torch.arange(len(sizes), dtype=aggregated_bboxes.dtype, device=aggregated_bboxes.device), sizes
    )
    return torch.cat([indices[:, None], aggregated_bboxes], dim=1)


def assign_boxes_to_levels(
    bboxes,
    min_level,
    max_level,
    canonical_box_size,
    canonical_level,
):
    aggregated_bboxes = bboxes.view(bboxes.shape[0] * bboxes.shape[1], -1)
    area = (aggregated_bboxes[:, 2] - aggregated_bboxes[:, 0]) * (aggregated_bboxes[:, 3] - aggregated_bboxes[:, 1])
    box_sizes = torch.sqrt(area)
    # Eqn.(1) in FPN paper
    level_assignments = torch.floor(canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8))
    # clamp level to (min, max), in case the box size is too large or too small
    # for the available feature maps
    level_assignments = torch.clamp(level_assignments, min=min_level, max=max_level)
    return level_assignments.to(torch.int64) - min_level


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class HeadDynamicK(nn.Module):
    def __init__(self, config, roi_input_shape):
        super().__init__()
        num_classes = config.num_labels

        ddet_head = DiffusionDetHead(config, roi_input_shape, num_classes)
        self.num_head = config.num_heads
        self.head_series = nn.ModuleList([copy.deepcopy(ddet_head) for _ in range(self.num_head)])
        self.return_intermediate = config.deep_supervision

        # Gaussian random feature embedding layer for time
        self.hidden_dim = config.hidden_dim
        time_dim = self.hidden_dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(self.hidden_dim),
            nn.Linear(self.hidden_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # Init parameters.
        self.use_focal = config.use_focal
        self.use_fed_loss = config.use_fed_loss
        self.num_classes = num_classes
        if self.use_focal or self.use_fed_loss:
            prior_prob = config.prior_prob
            self.bias_value = -math.log((1 - prior_prob) / prior_prob)
        self._reset_parameters()

    def _reset_parameters(self):
        # init all parameters.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

            # initialize the bias for focal loss and fed loss.
            if self.use_focal or self.use_fed_loss:
                if p.shape[-1] == self.num_classes or p.shape[-1] == self.num_classes + 1:
                    nn.init.constant_(p, self.bias_value)

    def forward(self, features, bboxes, t):
        # assert t shape (batch_size)
        time = self.time_mlp(t)

        inter_class_logits = []
        inter_pred_bboxes = []

        bs = len(features[0])

        class_logits, pred_bboxes = None, None
        for head_idx, ddet_head in enumerate(self.head_series):
            class_logits, pred_bboxes, proposal_features = ddet_head(features, bboxes, time)
            if self.return_intermediate:
                inter_class_logits.append(class_logits)
                inter_pred_bboxes.append(pred_bboxes)
            bboxes = pred_bboxes.detach()

        if self.return_intermediate:
            return torch.stack(inter_class_logits), torch.stack(inter_pred_bboxes)

        return class_logits[None], pred_bboxes[None]


class DynamicConv(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.hidden_dim = config.hidden_dim
        self.dim_dynamic = config.dim_dynamic
        self.num_dynamic = config.num_dynamic
        self.num_params = self.hidden_dim * self.dim_dynamic
        self.dynamic_layer = nn.Linear(self.hidden_dim, self.num_dynamic * self.num_params)

        self.norm1 = nn.LayerNorm(self.dim_dynamic)
        self.norm2 = nn.LayerNorm(self.hidden_dim)

        self.activation = nn.ReLU(inplace=True)

        pooler_resolution = config.pooler_resolution
        num_output = self.hidden_dim * pooler_resolution ** 2
        self.out_layer = nn.Linear(num_output, self.hidden_dim)
        self.norm3 = nn.LayerNorm(self.hidden_dim)

    def forward(self, pro_features, roi_features):
        features = roi_features.permute(1, 0, 2)
        parameters = self.dynamic_layer(pro_features).permute(1, 0, 2)

        param1 = parameters[:, :, :self.num_params].view(-1, self.hidden_dim, self.dim_dynamic)
        param2 = parameters[:, :, self.num_params:].view(-1, self.dim_dynamic, self.hidden_dim)

        features = torch.bmm(features, param1)
        features = self.norm1(features)
        features = self.activation(features)

        features = torch.bmm(features, param2)
        features = self.norm2(features)
        features = self.activation(features)

        features = features.flatten(1)
        features = self.out_layer(features)
        features = self.norm3(features)
        features = self.activation(features)

        return features


class DiffusionDetHead(nn.Module):
    def __init__(self, config, roi_input_shape, num_classes):
        super().__init__()

        dim_feedforward = config.dim_feedforward
        nhead = config.num_attn_heads
        dropout = config.dropout
        activation = config.activation
        in_features = config.roi_head_in_features
        pooler_resolution = config.pooler_resolution
        pooler_scales = tuple(1.0 / roi_input_shape[k]['stride'] for k in in_features)
        sampling_ratio = config.sampling_ratio

        self.hidden_dim = config.hidden_dim

        self.pooler = ROIPooler(
            output_size=pooler_resolution,
            scales=pooler_scales,
            sampling_ratio=sampling_ratio,
        )

        # dynamic.
        self.self_attn = nn.MultiheadAttention(self.hidden_dim, nhead, dropout=dropout)
        self.inst_interact = DynamicConv(config)

        self.linear1 = nn.Linear(self.hidden_dim, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, self.hidden_dim)

        self.norm1 = nn.LayerNorm(self.hidden_dim)
        self.norm2 = nn.LayerNorm(self.hidden_dim)
        self.norm3 = nn.LayerNorm(self.hidden_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

        # block time mlp
        self.block_time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(self.hidden_dim * 4, self.hidden_dim * 2))

        # cls.
        num_cls = config.num_cls
        cls_module = list()
        for _ in range(num_cls):
            cls_module.append(nn.Linear(self.hidden_dim, self.hidden_dim, False))
            cls_module.append(nn.LayerNorm(self.hidden_dim))
            cls_module.append(nn.ReLU(inplace=True))
        self.cls_module = nn.ModuleList(cls_module)

        # reg.
        num_reg = config.num_reg
        reg_module = list()
        for _ in range(num_reg):
            reg_module.append(nn.Linear(self.hidden_dim, self.hidden_dim, False))
            reg_module.append(nn.LayerNorm(self.hidden_dim))
            reg_module.append(nn.ReLU(inplace=True))
        self.reg_module = nn.ModuleList(reg_module)

        # pred.
        self.use_focal = config.use_focal
        self.use_fed_loss = config.use_fed_loss
        if self.use_focal or self.use_fed_loss:
            self.class_logits = nn.Linear(self.hidden_dim, num_classes)
        else:
            self.class_logits = nn.Linear(self.hidden_dim, num_classes + 1)
        self.bboxes_delta = nn.Linear(self.hidden_dim, 4)
        self.scale_clamp = _DEFAULT_SCALE_CLAMP
        self.bbox_weights = (2.0, 2.0, 1.0, 1.0)

    def forward(self, features, bboxes, time_emb):
        bs, num_proposals = bboxes.shape[:2]

        # roi_feature.
        roi_features = self.pooler(features, bboxes)

        pro_features = roi_features.view(bs, num_proposals, self.hidden_dim, -1).mean(-1)

        roi_features = roi_features.view(bs * num_proposals, self.hidden_dim, -1).permute(2, 0, 1)

        # self_att.
        pro_features = pro_features.view(bs, num_proposals, self.hidden_dim).permute(1, 0, 2)
        pro_features2 = self.self_attn(pro_features, pro_features, value=pro_features)[0]
        pro_features = pro_features + self.dropout1(pro_features2)
        pro_features = self.norm1(pro_features)

        # inst_interact.
        pro_features = pro_features.view(num_proposals, bs, self.hidden_dim).permute(1, 0, 2).reshape(
            1, bs * num_proposals, self.hidden_dim
        )
        pro_features2 = self.inst_interact(pro_features, roi_features)
        pro_features = pro_features + self.dropout2(pro_features2)
        obj_features = self.norm2(pro_features)

        # obj_feature.
        obj_features2 = self.linear2(self.dropout(self.activation(self.linear1(obj_features))))
        obj_features = obj_features + self.dropout3(obj_features2)
        obj_features = self.norm3(obj_features)

        fc_feature = obj_features.transpose(0, 1).reshape(bs * num_proposals, -1)

        scale_shift = self.block_time_mlp(time_emb)
        scale_shift = torch.repeat_interleave(scale_shift, num_proposals, dim=0)
        scale, shift = scale_shift.chunk(2, dim=1)
        fc_feature = fc_feature * (scale + 1) + shift

        cls_feature = fc_feature.clone()
        reg_feature = fc_feature.clone()
        for cls_layer in self.cls_module:
            cls_feature = cls_layer(cls_feature)
        for reg_layer in self.reg_module:
            reg_feature = reg_layer(reg_feature)
        class_logits = self.class_logits(cls_feature)
        bboxes_deltas = self.bboxes_delta(reg_feature)
        pred_bboxes = self.apply_deltas(bboxes_deltas, bboxes.view(-1, 4))

        return class_logits.view(bs, num_proposals, -1), pred_bboxes.view(bs, num_proposals, -1), obj_features

    def apply_deltas(self, deltas, boxes):
        """
        Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`.

        Args:
            deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1.
                deltas[i] represents k potentially different class-specific
                box transformations for the single box boxes[i].
            boxes (Tensor): boxes to transform, of shape (N, 4)
        """
        boxes = boxes.to(deltas.dtype)

        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        ctr_x = boxes[:, 0] + 0.5 * widths
        ctr_y = boxes[:, 1] + 0.5 * heights

        wx, wy, ww, wh = self.bbox_weights
        dx = deltas[:, 0::4] / wx
        dy = deltas[:, 1::4] / wy
        dw = deltas[:, 2::4] / ww
        dh = deltas[:, 3::4] / wh

        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=self.scale_clamp)
        dh = torch.clamp(dh, max=self.scale_clamp)

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        pred_boxes = torch.zeros_like(deltas)
        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w  # x1
        pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h  # y1
        pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w  # x2
        pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h  # y2

        return pred_boxes


class ROIPooler(nn.Module):
    """
    Region of interest feature map pooler that supports pooling from one or more
    feature maps.
    """

    def __init__(
        self,
        output_size,
        scales,
        sampling_ratio,
        canonical_box_size=224,
        canonical_level=4,
    ):
        super().__init__()

        min_level = -(math.log2(scales[0]))
        max_level = -(math.log2(scales[-1]))

        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        assert len(output_size) == 2 and isinstance(output_size[0], int) and isinstance(output_size[1], int)
        assert math.isclose(min_level, int(min_level)) and math.isclose(max_level, int(max_level))
        assert (len(scales) == max_level - min_level + 1)
        assert 0 <= min_level <= max_level
        assert canonical_box_size > 0

        self.output_size = output_size
        self.min_level = int(min_level)
        self.max_level = int(max_level)
        self.canonical_level = canonical_level
        self.canonical_box_size = canonical_box_size
        self.level_poolers = nn.ModuleList(
            RoIAlign(
                output_size, spatial_scale=scale, sampling_ratio=sampling_ratio, aligned=True
            )
            for scale in scales
        )

    def forward(self, x, bboxes):
        num_level_assignments = len(self.level_poolers)
        assert len(x) == num_level_assignments and len(bboxes) == x[0].size(0)

        pooler_fmt_boxes = convert_boxes_to_pooler_format(bboxes)

        if num_level_assignments == 1:
            return self.level_poolers[0](x[0], pooler_fmt_boxes)

        level_assignments = assign_boxes_to_levels(
            bboxes, self.min_level, self.max_level, self.canonical_box_size, self.canonical_level
        )

        batches = pooler_fmt_boxes.shape[0]
        channels = x[0].shape[1]
        output_size = self.output_size[0]
        sizes = (batches, channels, output_size, output_size)

        output = torch.zeros(sizes, dtype=x[0].dtype, device=x[0].device)

        for level, (x_level, pooler) in enumerate(zip(x, self.level_poolers)):
            inds = (level_assignments == level).nonzero(as_tuple=True)[0]
            pooler_fmt_boxes_level = pooler_fmt_boxes[inds]
            # Use index_put_ instead of advance indexing, to avoid pytorch/issues/49852
            output.index_put_((inds,), pooler(x_level, pooler_fmt_boxes_level))

        return output
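
A minimal sketch (not part of the commit) of how the multi-level ROIPooler defined above can be exercised on dummy FPN features; the image size, channel count and strides are assumptions for illustration:

import torch

# Four FPN levels at strides 4, 8, 16 and 32 for an 800x800 image.
pooler = ROIPooler(output_size=7, scales=(1 / 4, 1 / 8, 1 / 16, 1 / 32), sampling_ratio=2)

batch_size, num_proposals, channels = 2, 100, 256
features = [torch.randn(batch_size, channels, 800 // s, 800 // s) for s in (4, 8, 16, 32)]

# Random (x1, y1, x2, y2) proposals in absolute pixel coordinates.
boxes = torch.rand(batch_size, num_proposals, 4) * 400
boxes[..., 2:] = boxes[..., :2] + boxes[..., 2:] + 1.0  # guarantee x2 > x1 and y2 > y1

roi_features = pooler(features, boxes)
print(roi_features.shape)  # torch.Size([200, 256, 7, 7])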
image_processing_diffusiondet.py ADDED
@@ -0,0 +1,1632 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for DiffusionDet."""
16
+
17
+ import io
18
+ import pathlib
19
+ from collections import defaultdict
20
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
21
+
22
+ import numpy as np
23
+ from transformers.feature_extraction_utils import BatchFeature
24
+ from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
25
+ from transformers.image_transforms import (
26
+ PaddingMode,
27
+ center_to_corners_format,
28
+ corners_to_center_format,
29
+ id_to_rgb,
30
+ pad,
31
+ rescale,
32
+ resize,
33
+ rgb_to_id,
34
+ to_channel_dimension_format,
35
+ )
36
+
37
+ from transformers.image_utils import (
38
+ IMAGENET_DEFAULT_MEAN,
39
+ IMAGENET_DEFAULT_STD,
40
+ AnnotationFormat,
41
+ AnnotationType,
42
+ ChannelDimension,
43
+ ImageInput,
44
+ PILImageResampling,
45
+ get_image_size,
46
+ infer_channel_dimension_format,
47
+ is_scaled_image,
48
+ make_list_of_images,
49
+ to_numpy_array,
50
+ valid_images,
51
+ validate_annotations,
52
+ validate_kwargs,
53
+ validate_preprocess_arguments
54
+ )
55
+
56
+ from transformers.utils import (
57
+ TensorType,
58
+ is_flax_available,
59
+ is_jax_tensor,
60
+ is_tf_available,
61
+ is_tf_tensor,
62
+ is_torch_tensor,
63
+ is_vision_available
64
+ )
65
+ from transformers.utils import (
66
+ is_torch_available,
67
+ is_scipy_available,
68
+ logging
69
+ )
70
+
71
+
72
+ if is_torch_available():
73
+ import torch
74
+ from torch import nn
75
+
76
+ if is_vision_available():
77
+ import PIL
78
+
79
+ if is_scipy_available():
80
+ import scipy.special
81
+ import scipy.stats
82
+
83
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
84
+
85
+ SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)
86
+
87
+
88
+ # Copied from transformers.models.detr.image_processing_detr.get_size_with_aspect_ratio
89
+ def get_size_with_aspect_ratio(image_size, size, max_size=None) -> Tuple[int, int]:
90
+ """
91
+ Computes the output image size given the input image size and the desired output size.
92
+
93
+ Args:
94
+ image_size (`Tuple[int, int]`):
95
+ The input image size.
96
+ size (`int`):
97
+ The desired output size.
98
+ max_size (`int`, *optional*):
99
+ The maximum allowed output size.
100
+ """
101
+ height, width = image_size
102
+ raw_size = None
103
+ if max_size is not None:
104
+ min_original_size = float(min((height, width)))
105
+ max_original_size = float(max((height, width)))
106
+ if max_original_size / min_original_size * size > max_size:
107
+ raw_size = max_size * min_original_size / max_original_size
108
+ size = int(round(raw_size))
109
+
110
+ if (height <= width and height == size) or (width <= height and width == size):
111
+ oh, ow = height, width
112
+ elif width < height:
113
+ ow = size
114
+ if max_size is not None and raw_size is not None:
115
+ oh = int(raw_size * height / width)
116
+ else:
117
+ oh = int(size * height / width)
118
+ else:
119
+ oh = size
120
+ if max_size is not None and raw_size is not None:
121
+ ow = int(raw_size * width / height)
122
+ else:
123
+ ow = int(size * width / height)
124
+
125
+ return (oh, ow)
126
+
127
+
128
+ # Copied from transformers.models.detr.image_processing_detr.get_resize_output_image_size
129
+ def get_resize_output_image_size(
130
+ input_image: np.ndarray,
131
+ size: Union[int, Tuple[int, int], List[int]],
132
+ max_size: Optional[int] = None,
133
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
134
+ ) -> Tuple[int, int]:
135
+ """
136
+ Computes the output image size given the input image size and the desired output size. If the desired output size
137
+ is a tuple or list, the output image size is returned as is. If the desired output size is an integer, the output
138
+ image size is computed by keeping the aspect ratio of the input image size.
139
+
140
+ Args:
141
+ input_image (`np.ndarray`):
142
+ The image to resize.
143
+ size (`int` or `Tuple[int, int]` or `List[int]`):
144
+ The desired output size.
145
+ max_size (`int`, *optional*):
146
+ The maximum allowed output size.
147
+ input_data_format (`ChannelDimension` or `str`, *optional*):
148
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
149
+ """
150
+ image_size = get_image_size(input_image, input_data_format)
151
+ if isinstance(size, (list, tuple)):
152
+ return size
153
+
154
+ return get_size_with_aspect_ratio(image_size, size, max_size)
155
+
156
+
157
+ # Copied from transformers.models.detr.image_processing_detr.get_image_size_for_max_height_width
158
+ def get_image_size_for_max_height_width(
159
+ input_image: np.ndarray,
160
+ max_height: int,
161
+ max_width: int,
162
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
163
+ ) -> Tuple[int, int]:
164
+ """
165
+ Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio.
166
+ Important, even if image_height < max_height and image_width < max_width, the image will be resized
167
+ to at least one of the edges be equal to max_height or max_width.
168
+
169
+ For example:
170
+ - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
171
+ - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
172
+
173
+ Args:
174
+ input_image (`np.ndarray`):
175
+ The image to resize.
176
+ max_height (`int`):
177
+ The maximum allowed height.
178
+ max_width (`int`):
179
+ The maximum allowed width.
180
+ input_data_format (`ChannelDimension` or `str`, *optional*):
181
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
182
+ """
183
+ image_size = get_image_size(input_image, input_data_format)
184
+ height, width = image_size
185
+ height_scale = max_height / height
186
+ width_scale = max_width / width
187
+ min_scale = min(height_scale, width_scale)
188
+ new_height = int(height * min_scale)
189
+ new_width = int(width * min_scale)
190
+ return new_height, new_width
191
+
192
+
193
+ # Copied from transformers.models.detr.image_processing_detr.get_numpy_to_framework_fn
194
+ def get_numpy_to_framework_fn(arr) -> Callable:
195
+ """
196
+ Returns a function that converts a numpy array to the framework of the input array.
197
+
198
+ Args:
199
+ arr (`np.ndarray`): The array to convert.
200
+ """
201
+ if isinstance(arr, np.ndarray):
202
+ return np.array
203
+ if is_tf_available() and is_tf_tensor(arr):
204
+ import tensorflow as tf
205
+
206
+ return tf.convert_to_tensor
207
+ if is_torch_available() and is_torch_tensor(arr):
208
+ import torch
209
+
210
+ return torch.tensor
211
+ if is_flax_available() and is_jax_tensor(arr):
212
+ import jax.numpy as jnp
213
+
214
+ return jnp.array
215
+ raise ValueError(f"Cannot convert arrays of type {type(arr)}")
216
+
217
+
218
+ # Copied from transformers.models.detr.image_processing_detr.safe_squeeze
219
+ def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray:
220
+ """
221
+ Squeezes an array, but only if the axis specified has dim 1.
222
+ """
223
+ if axis is None:
224
+ return arr.squeeze()
225
+
226
+ try:
227
+ return arr.squeeze(axis=axis)
228
+ except ValueError:
229
+ return arr
230
+
231
+
232
+ # Copied from transformers.models.detr.image_processing_detr.normalize_annotation
233
+ def normalize_annotation(annotation: Dict, image_size: Tuple[int, int]) -> Dict:
234
+ image_height, image_width = image_size
235
+ norm_annotation = {}
236
+ for key, value in annotation.items():
237
+ if key == "boxes":
238
+ boxes = value
239
+ boxes = corners_to_center_format(boxes)
240
+ boxes /= np.asarray([image_width, image_height, image_width, image_height], dtype=np.float32)
241
+ norm_annotation[key] = boxes
242
+ else:
243
+ norm_annotation[key] = value
244
+ return norm_annotation
245
+
246
+
247
+ # Copied from transformers.models.detr.image_processing_detr.max_across_indices
248
+ def max_across_indices(values: Iterable[Any]) -> List[Any]:
249
+ """
250
+ Return the maximum value across all indices of an iterable of values.
251
+ """
252
+ return [max(values_i) for values_i in zip(*values)]
253
+
254
+
255
+ # Copied from transformers.models.detr.image_processing_detr.get_max_height_width
256
+ def get_max_height_width(
257
+ images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
258
+ ) -> List[int]:
259
+ """
260
+ Get the maximum height and width across all images in a batch.
261
+ """
262
+ if input_data_format is None:
263
+ input_data_format = infer_channel_dimension_format(images[0])
264
+
265
+ if input_data_format == ChannelDimension.FIRST:
266
+ _, max_height, max_width = max_across_indices([img.shape for img in images])
267
+ elif input_data_format == ChannelDimension.LAST:
268
+ max_height, max_width, _ = max_across_indices([img.shape for img in images])
269
+ else:
270
+ raise ValueError(f"Invalid channel dimension format: {input_data_format}")
271
+ return (max_height, max_width)
272
+
273
+
274
+ # Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
275
+ def make_pixel_mask(
276
+ image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
277
+ ) -> np.ndarray:
278
+ """
279
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
280
+
281
+ Args:
282
+ image (`np.ndarray`):
283
+ Image to make the pixel mask for.
284
+ output_size (`Tuple[int, int]`):
285
+ Output size of the mask.
286
+ """
287
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
288
+ mask = np.zeros(output_size, dtype=np.int64)
289
+ mask[:input_height, :input_width] = 1
290
+ return mask
291
+
292
+
293
+ # Copied from transformers.models.detr.image_processing_detr.convert_coco_poly_to_mask
294
+ def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndarray:
295
+ """
296
+ Convert a COCO polygon annotation to a mask.
297
+
298
+ Args:
299
+ segmentations (`List[List[float]]`):
300
+ List of polygons, each polygon represented by a list of x-y coordinates.
301
+ height (`int`):
302
+ Height of the mask.
303
+ width (`int`):
304
+ Width of the mask.
305
+ """
306
+ try:
307
+ from pycocotools import mask as coco_mask
308
+ except ImportError:
309
+ raise ImportError("Pycocotools is not installed in your environment.")
310
+
311
+ masks = []
312
+ for polygons in segmentations:
313
+ rles = coco_mask.frPyObjects(polygons, height, width)
314
+ mask = coco_mask.decode(rles)
315
+ if len(mask.shape) < 3:
316
+ mask = mask[..., None]
317
+ mask = np.asarray(mask, dtype=np.uint8)
318
+ mask = np.any(mask, axis=2)
319
+ masks.append(mask)
320
+ if masks:
321
+ masks = np.stack(masks, axis=0)
322
+ else:
323
+ masks = np.zeros((0, height, width), dtype=np.uint8)
324
+
325
+ return masks
326
+
327
+
328
+ # Copied from transformers.models.detr.image_processing_detr.prepare_coco_detection_annotation with DETR->DeformableDetr
329
+ def prepare_coco_detection_annotation(
330
+ image,
331
+ target,
332
+ return_segmentation_masks: bool = False,
333
+ input_data_format: Optional[Union[ChannelDimension, str]] = None,
334
+ ):
335
+ """
336
+ Convert the target in COCO format into the format expected by DeformableDetr.
337
+ """
338
+ image_height, image_width = get_image_size(image, channel_dim=input_data_format)
339
+
340
+ image_id = target["image_id"]
341
+ image_id = np.asarray([image_id], dtype=np.int64)
342
+
343
+ # Get all COCO annotations for the given image.
344
+ annotations = target["annotations"]
345
+ annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0]
346
+
347
+ classes = [obj["category_id"] for obj in annotations]
348
+ classes = np.asarray(classes, dtype=np.int64)
349
+
350
+ # for conversion to coco api
351
+ area = np.asarray([obj["area"] for obj in annotations], dtype=np.float32)
352
+ iscrowd = np.asarray([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in annotations], dtype=np.int64)
353
+
354
+ boxes = [obj["bbox"] for obj in annotations]
355
+ # guard against no boxes via resizing
356
+ boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
357
+ boxes[:, 2:] += boxes[:, :2]
358
+ boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
359
+ boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
360
+
361
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
362
+
363
+ new_target = {}
364
+ new_target["image_id"] = image_id
365
+ new_target["class_labels"] = classes[keep]
366
+ new_target["boxes"] = boxes[keep]
367
+ new_target["area"] = area[keep]
368
+ new_target["iscrowd"] = iscrowd[keep]
369
+ new_target["orig_size"] = np.asarray([int(image_height), int(image_width)], dtype=np.int64)
370
+
371
+ if annotations and "keypoints" in annotations[0]:
372
+ keypoints = [obj["keypoints"] for obj in annotations]
373
+ # Converting the filtered keypoints list to a numpy array
374
+ keypoints = np.asarray(keypoints, dtype=np.float32)
375
+ # Apply the keep mask here to filter the relevant annotations
376
+ keypoints = keypoints[keep]
377
+ num_keypoints = keypoints.shape[0]
378
+ keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
379
+ new_target["keypoints"] = keypoints
380
+
381
+ if return_segmentation_masks:
382
+ segmentation_masks = [obj["segmentation"] for obj in annotations]
383
+ masks = convert_coco_poly_to_mask(segmentation_masks, image_height, image_width)
384
+ new_target["masks"] = masks[keep]
385
+
386
+ return new_target
387
+
388
+
389
+ # Copied from transformers.models.detr.image_processing_detr.masks_to_boxes
390
+ def masks_to_boxes(masks: np.ndarray) -> np.ndarray:
391
+ """
392
+ Compute the bounding boxes around the provided panoptic segmentation masks.
393
+
394
+ Args:
395
+ masks: masks in format `[number_masks, height, width]` where N is the number of masks
396
+
397
+ Returns:
398
+ boxes: bounding boxes in format `[number_masks, 4]` in xyxy format
399
+ """
400
+ if masks.size == 0:
401
+ return np.zeros((0, 4))
402
+
403
+ h, w = masks.shape[-2:]
404
+ y = np.arange(0, h, dtype=np.float32)
405
+ x = np.arange(0, w, dtype=np.float32)
406
+ # see https://github.com/pytorch/pytorch/issues/50276
407
+ y, x = np.meshgrid(y, x, indexing="ij")
408
+
409
+ x_mask = masks * np.expand_dims(x, axis=0)
410
+ x_max = x_mask.reshape(x_mask.shape[0], -1).max(-1)
411
+ x = np.ma.array(x_mask, mask=~(np.array(masks, dtype=bool)))
412
+ x_min = x.filled(fill_value=1e8)
413
+ x_min = x_min.reshape(x_min.shape[0], -1).min(-1)
414
+
415
+ y_mask = masks * np.expand_dims(y, axis=0)
416
+ y_max = y_mask.reshape(x_mask.shape[0], -1).max(-1)
417
+ y = np.ma.array(y_mask, mask=~(np.array(masks, dtype=bool)))
418
+ y_min = y.filled(fill_value=1e8)
419
+ y_min = y_min.reshape(y_min.shape[0], -1).min(-1)
420
+
421
+ return np.stack([x_min, y_min, x_max, y_max], 1)
422
+
423
+
424
+ # Copied from transformers.models.detr.image_processing_detr.prepare_coco_panoptic_annotation with DETR->DeformableDetr
425
+ def prepare_coco_panoptic_annotation(
426
+ image: np.ndarray,
427
+ target: Dict,
428
+ masks_path: Union[str, pathlib.Path],
429
+ return_masks: bool = True,
430
+ input_data_format: Union[ChannelDimension, str] = None,
431
+ ) -> Dict:
432
+ """
433
+ Prepare a coco panoptic annotation for DeformableDetr.
434
+ """
435
+ image_height, image_width = get_image_size(image, channel_dim=input_data_format)
436
+ annotation_path = pathlib.Path(masks_path) / target["file_name"]
437
+
438
+ new_target = {}
439
+ new_target["image_id"] = np.asarray([target["image_id"] if "image_id" in target else target["id"]], dtype=np.int64)
440
+ new_target["size"] = np.asarray([image_height, image_width], dtype=np.int64)
441
+ new_target["orig_size"] = np.asarray([image_height, image_width], dtype=np.int64)
442
+
443
+ if "segments_info" in target:
444
+ masks = np.asarray(PIL.Image.open(annotation_path), dtype=np.uint32)
445
+ masks = rgb_to_id(masks)
446
+
447
+ ids = np.array([segment_info["id"] for segment_info in target["segments_info"]])
448
+ masks = masks == ids[:, None, None]
449
+ masks = masks.astype(np.uint8)
450
+ if return_masks:
451
+ new_target["masks"] = masks
452
+ new_target["boxes"] = masks_to_boxes(masks)
453
+ new_target["class_labels"] = np.array(
454
+ [segment_info["category_id"] for segment_info in target["segments_info"]], dtype=np.int64
455
+ )
456
+ new_target["iscrowd"] = np.asarray(
457
+ [segment_info["iscrowd"] for segment_info in target["segments_info"]], dtype=np.int64
458
+ )
459
+ new_target["area"] = np.asarray(
460
+ [segment_info["area"] for segment_info in target["segments_info"]], dtype=np.float32
461
+ )
462
+
463
+ return new_target
464
+
465
+
466
+ # Copied from transformers.models.detr.image_processing_detr.get_segmentation_image
467
+ def get_segmentation_image(
468
+ masks: np.ndarray, input_size: Tuple, target_size: Tuple, stuff_equiv_classes, deduplicate=False
469
+ ):
470
+ h, w = input_size
471
+ final_h, final_w = target_size
472
+
473
+ m_id = scipy.special.softmax(masks.transpose(0, 1), -1)
474
+
475
+ if m_id.shape[-1] == 0:
476
+ # We didn't detect any mask :(
477
+ m_id = np.zeros((h, w), dtype=np.int64)
478
+ else:
479
+ m_id = m_id.argmax(-1).reshape(h, w)
480
+
481
+ if deduplicate:
482
+ # Merge the masks corresponding to the same stuff class
483
+ for equiv in stuff_equiv_classes.values():
484
+ for eq_id in equiv:
485
+ m_id[m_id == eq_id] = equiv[0]
486
+
487
+ seg_img = id_to_rgb(m_id)
488
+ seg_img = resize(seg_img, (final_w, final_h), resample=PILImageResampling.NEAREST)
489
+ return seg_img
490
+
491
+
492
+ # Copied from transformers.models.detr.image_processing_detr.get_mask_area
493
+ def get_mask_area(seg_img: np.ndarray, target_size: Tuple[int, int], n_classes: int) -> np.ndarray:
494
+ final_h, final_w = target_size
495
+ np_seg_img = seg_img.astype(np.uint8)
496
+ np_seg_img = np_seg_img.reshape(final_h, final_w, 3)
497
+ m_id = rgb_to_id(np_seg_img)
498
+ area = [(m_id == i).sum() for i in range(n_classes)]
499
+ return area
500
+
501
+
502
+ # Copied from transformers.models.detr.image_processing_detr.score_labels_from_class_probabilities
503
+ def score_labels_from_class_probabilities(logits: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
504
+ probs = scipy.special.softmax(logits, axis=-1)
505
+ labels = probs.argmax(-1, keepdims=True)
506
+ scores = np.take_along_axis(probs, labels, axis=-1)
507
+ scores, labels = scores.squeeze(-1), labels.squeeze(-1)
508
+ return scores, labels
509
+
510
+
511
+ # Copied from transformers.models.detr.image_processing_detr.post_process_panoptic_sample
512
+ def post_process_panoptic_sample(
513
+ out_logits: np.ndarray,
514
+ masks: np.ndarray,
515
+ boxes: np.ndarray,
516
+ processed_size: Tuple[int, int],
517
+ target_size: Tuple[int, int],
518
+ is_thing_map: Dict,
519
+ threshold=0.85,
520
+ ) -> Dict:
521
+ """
522
+ Converts the output of [`DetrForSegmentation`] into panoptic segmentation predictions for a single sample.
523
+
524
+ Args:
525
+ out_logits (`torch.Tensor`):
526
+ The logits for this sample.
527
+ masks (`torch.Tensor`):
528
+ The predicted segmentation masks for this sample.
529
+ boxes (`torch.Tensor`):
530
+ The predicted bounding boxes for this sample. The boxes are in the normalized format `(center_x, center_y,
531
+ width, height)` and values between `[0, 1]`, relative to the size of the image (disregarding padding).
532
+ processed_size (`Tuple[int, int]`):
533
+ The processed size of the image `(height, width)`, as returned by the preprocessing step i.e. the size
534
+ after data augmentation but before batching.
535
+ target_size (`Tuple[int, int]`):
536
+ The target size of the image, `(height, width)` corresponding to the requested final size of the
537
+ prediction.
538
+ is_thing_map (`Dict`):
539
+ A dictionary mapping class indices to a boolean value indicating whether the class is a thing or not.
540
+ threshold (`float`, *optional*, defaults to 0.85):
541
+ The threshold used to binarize the segmentation masks.
542
+ """
543
+ # we filter empty queries and detection below threshold
544
+ scores, labels = score_labels_from_class_probabilities(out_logits)
545
+ keep = (labels != out_logits.shape[-1] - 1) & (scores > threshold)
546
+
547
+ cur_scores = scores[keep]
548
+ cur_classes = labels[keep]
549
+ cur_boxes = center_to_corners_format(boxes[keep])
550
+
551
+ if len(cur_boxes) != len(cur_classes):
552
+ raise ValueError("Not as many boxes as there are classes")
553
+
554
+ cur_masks = masks[keep]
555
+ cur_masks = resize(cur_masks[:, None], processed_size, resample=PILImageResampling.BILINEAR)
556
+ cur_masks = safe_squeeze(cur_masks, 1)
557
+ b, h, w = cur_masks.shape
558
+
559
+ # It may be that we have several predicted masks for the same stuff class.
560
+ # In the following, we track the list of masks ids for each stuff class (they are merged later on)
561
+ cur_masks = cur_masks.reshape(b, -1)
562
+ stuff_equiv_classes = defaultdict(list)
563
+ for k, label in enumerate(cur_classes):
564
+ if not is_thing_map[label]:
565
+ stuff_equiv_classes[label].append(k)
566
+
567
+ seg_img = get_segmentation_image(cur_masks, processed_size, target_size, stuff_equiv_classes, deduplicate=True)
568
+ area = get_mask_area(cur_masks, processed_size, n_classes=len(cur_scores))
569
+
570
+ # We filter out any mask that is too small
571
+ if cur_classes.size() > 0:
572
+ # We now filter empty masks as long as we find some
573
+ filtered_small = np.array([a <= 4 for a in area], dtype=bool)
574
+ while filtered_small.any():
575
+ cur_masks = cur_masks[~filtered_small]
576
+ cur_scores = cur_scores[~filtered_small]
577
+ cur_classes = cur_classes[~filtered_small]
578
+ seg_img = get_segmentation_image(cur_masks, (h, w), target_size, stuff_equiv_classes, deduplicate=True)
579
+ area = get_mask_area(seg_img, target_size, n_classes=len(cur_scores))
580
+ filtered_small = np.array([a <= 4 for a in area], dtype=bool)
581
+ else:
582
+ cur_classes = np.ones((1, 1), dtype=np.int64)
583
+
584
+ segments_info = [
585
+ {"id": i, "isthing": is_thing_map[cat], "category_id": int(cat), "area": a}
586
+ for i, (cat, a) in enumerate(zip(cur_classes, area))
587
+ ]
588
+ del cur_classes
589
+
590
+ with io.BytesIO() as out:
591
+ PIL.Image.fromarray(seg_img).save(out, format="PNG")
592
+ predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
593
+
594
+ return predictions
595
+
596
+
597
+ # Copied from transformers.models.detr.image_processing_detr.resize_annotation
598
+ def resize_annotation(
599
+ annotation: Dict[str, Any],
600
+ orig_size: Tuple[int, int],
601
+ target_size: Tuple[int, int],
602
+ threshold: float = 0.5,
603
+ resample: PILImageResampling = PILImageResampling.NEAREST,
604
+ ):
605
+ """
606
+ Resizes an annotation to a target size.
607
+
608
+ Args:
609
+ annotation (`Dict[str, Any]`):
610
+ The annotation dictionary.
611
+ orig_size (`Tuple[int, int]`):
612
+ The original size of the input image.
613
+ target_size (`Tuple[int, int]`):
614
+ The target size of the image, as returned by the preprocessing `resize` step.
615
+ threshold (`float`, *optional*, defaults to 0.5):
616
+ The threshold used to binarize the segmentation masks.
617
+ resample (`PILImageResampling`, defaults to `PILImageResampling.NEAREST`):
618
+ The resampling filter to use when resizing the masks.
619
+ """
620
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(target_size, orig_size))
621
+ ratio_height, ratio_width = ratios
622
+
623
+ new_annotation = {}
624
+ new_annotation["size"] = target_size
625
+
626
+ for key, value in annotation.items():
627
+ if key == "boxes":
628
+ boxes = value
629
+ scaled_boxes = boxes * np.asarray([ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32)
630
+ new_annotation["boxes"] = scaled_boxes
631
+ elif key == "area":
632
+ area = value
633
+ scaled_area = area * (ratio_width * ratio_height)
634
+ new_annotation["area"] = scaled_area
635
+ elif key == "masks":
636
+ masks = value[:, None]
637
+ masks = np.array([resize(mask, target_size, resample=resample) for mask in masks])
638
+ masks = masks.astype(np.float32)
639
+ masks = masks[:, 0] > threshold
640
+ new_annotation["masks"] = masks
641
+ elif key == "size":
642
+ new_annotation["size"] = target_size
643
+ else:
644
+ new_annotation[key] = value
645
+
646
+ return new_annotation
647
+
648
+
649
+ # Copied from transformers.models.detr.image_processing_detr.binary_mask_to_rle
650
+ def binary_mask_to_rle(mask):
651
+ """
652
+ Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format.
653
+
654
+ Args:
655
+ mask (`torch.Tensor` or `numpy.array`):
656
+ A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target
657
+ segment_id or class_id.
658
+ Returns:
659
+ `List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE
660
+ format.
661
+ """
662
+ if is_torch_tensor(mask):
663
+ mask = mask.numpy()
664
+
665
+ pixels = mask.flatten()
666
+ pixels = np.concatenate([[0], pixels, [0]])
667
+ runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
668
+ runs[1::2] -= runs[::2]
669
+ return list(runs)
670
+
671
+
672
+ # Copied from transformers.models.detr.image_processing_detr.convert_segmentation_to_rle
673
+ def convert_segmentation_to_rle(segmentation):
674
+ """
675
+ Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format.
676
+
677
+ Args:
678
+ segmentation (`torch.Tensor` or `numpy.array`):
679
+ A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
680
+ Returns:
681
+ `List[List]`: A list of lists, where each list is the run-length encoding of a segment / class id.
682
+ """
683
+ segment_ids = torch.unique(segmentation)
684
+
685
+ run_length_encodings = []
686
+ for idx in segment_ids:
687
+ mask = torch.where(segmentation == idx, 1, 0)
688
+ rle = binary_mask_to_rle(mask)
689
+ run_length_encodings.append(rle)
690
+
691
+ return run_length_encodings
692
+
693
+
694
+ # Copied from transformers.models.detr.image_processing_detr.remove_low_and_no_objects
695
+ def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
696
+ """
697
+ Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and
698
+ `labels`.
699
+
700
+ Args:
701
+ masks (`torch.Tensor`):
702
+ A tensor of shape `(num_queries, height, width)`.
703
+ scores (`torch.Tensor`):
704
+ A tensor of shape `(num_queries)`.
705
+ labels (`torch.Tensor`):
706
+ A tensor of shape `(num_queries)`.
707
+ object_mask_threshold (`float`):
708
+ A number between 0 and 1 used to binarize the masks.
709
+ Raises:
710
+ `ValueError`: Raised when the first dimension doesn't match in all input tensors.
711
+ Returns:
712
+ `Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the region
713
+ < `object_mask_threshold`.
714
+ """
715
+ if not (masks.shape[0] == scores.shape[0] == labels.shape[0]):
716
+ raise ValueError("mask, scores and labels must have the same shape!")
717
+
718
+ to_keep = labels.ne(num_labels) & (scores > object_mask_threshold)
719
+
720
+ return masks[to_keep], scores[to_keep], labels[to_keep]
721
+
722
+
723
+ # Copied from transformers.models.detr.image_processing_detr.check_segment_validity
724
+ def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8):
725
+ # Get the mask associated with the k class
726
+ mask_k = mask_labels == k
727
+ mask_k_area = mask_k.sum()
728
+
729
+ # Compute the area of all the stuff in query k
730
+ original_area = (mask_probs[k] >= mask_threshold).sum()
731
+ mask_exists = mask_k_area > 0 and original_area > 0
732
+
733
+ # Eliminate disconnected tiny segments
734
+ if mask_exists:
735
+ area_ratio = mask_k_area / original_area
736
+ if not area_ratio.item() > overlap_mask_area_threshold:
737
+ mask_exists = False
738
+
739
+ return mask_exists, mask_k
740
+
741
+
742
+ # Copied from transformers.models.detr.image_processing_detr.compute_segments
743
+ def compute_segments(
744
+ mask_probs,
745
+ pred_scores,
746
+ pred_labels,
747
+ mask_threshold: float = 0.5,
748
+ overlap_mask_area_threshold: float = 0.8,
749
+ label_ids_to_fuse: Optional[Set[int]] = None,
750
+ target_size: Tuple[int, int] = None,
751
+ ):
752
+ height = mask_probs.shape[1] if target_size is None else target_size[0]
753
+ width = mask_probs.shape[2] if target_size is None else target_size[1]
754
+
755
+ segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
756
+ segments: List[Dict] = []
757
+
758
+ if target_size is not None:
759
+ mask_probs = nn.functional.interpolate(
760
+ mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
761
+ )[0]
762
+
763
+ current_segment_id = 0
764
+
765
+ # Weigh each mask by its prediction score
766
+ mask_probs *= pred_scores.view(-1, 1, 1)
767
+ mask_labels = mask_probs.argmax(0) # [height, width]
768
+
769
+ # Keep track of instances of each class
770
+ stuff_memory_list: Dict[str, int] = {}
771
+ for k in range(pred_labels.shape[0]):
772
+ pred_class = pred_labels[k].item()
773
+ should_fuse = pred_class in label_ids_to_fuse
774
+
775
+ # Check if mask exists and large enough to be a segment
776
+ mask_exists, mask_k = check_segment_validity(
777
+ mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold
778
+ )
779
+
780
+ if mask_exists:
781
+ if pred_class in stuff_memory_list:
782
+ current_segment_id = stuff_memory_list[pred_class]
783
+ else:
784
+ current_segment_id += 1
785
+
786
+ # Add current object segment to final segmentation map
787
+ segmentation[mask_k] = current_segment_id
788
+ segment_score = round(pred_scores[k].item(), 6)
789
+ segments.append(
790
+ {
791
+ "id": current_segment_id,
792
+ "label_id": pred_class,
793
+ "was_fused": should_fuse,
794
+ "score": segment_score,
795
+ }
796
+ )
797
+ if should_fuse:
798
+ stuff_memory_list[pred_class] = current_segment_id
799
+
800
+ return segmentation, segments
801
+
802
+
803
+ class DiffusionDetImageProcessor(BaseImageProcessor):
804
+ r"""
805
+ Constructs a DiffusionDet image processor.
806
+
807
+ Args:
808
+ format (`str`, *optional*, defaults to `"coco_detection"`):
809
+ Data format of the annotations. One of "coco_detection" or "coco_panoptic".
810
+ do_resize (`bool`, *optional*, defaults to `True`):
811
+ Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be
812
+ overridden by the `do_resize` parameter in the `preprocess` method.
813
+ size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 800, "longest_edge": 1333}`):
814
+ Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter
815
+ in the `preprocess` method. Available options are:
816
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
817
+ Do NOT keep the aspect ratio.
818
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
819
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
820
+ less or equal to `longest_edge`.
821
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
822
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
823
+ `max_width`.
824
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
825
+ Resampling filter to use if resizing the image.
826
+ do_rescale (`bool`, *optional*, defaults to `True`):
827
+ Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
828
+ `do_rescale` parameter in the `preprocess` method.
829
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
830
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
831
+ `preprocess` method.
832
+ do_normalize:
833
+ Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
834
+ `preprocess` method.
835
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
836
+ Mean values to use when normalizing the image. Can be a single value or a list of values, one for each
837
+ channel. Can be overridden by the `image_mean` parameter in the `preprocess` method.
838
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
839
+ Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one
840
+ for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method.
841
+ do_convert_annotations (`bool`, *optional*, defaults to `True`):
842
+ Controls whether to convert the annotations to the format expected by the DETR model. Converts the
843
+ bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
844
+ Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
845
+ do_pad (`bool`, *optional*, defaults to `True`):
846
+ Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
847
+ method. If `True`, padding will be applied to the bottom and right of the image with zeros.
848
+ If `pad_size` is provided, the image will be padded to the specified dimensions.
849
+ Otherwise, the image will be padded to the maximum height and width of the batch.
850
+ pad_size (`Dict[str, int]`, *optional*):
851
+ The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
852
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
853
+ height and width in the batch.
854
+ """
855
+
856
+ model_input_names = ["pixel_values", "pixel_mask"]
857
+
858
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.__init__
859
+ def __init__(
860
+ self,
861
+ format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION,
862
+ do_resize: bool = True,
863
+ size: Dict[str, int] = None,
864
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
865
+ do_rescale: bool = True,
866
+ rescale_factor: Union[int, float] = 1 / 255,
867
+ do_normalize: bool = True,
868
+ image_mean: Union[float, List[float]] = None,
869
+ image_std: Union[float, List[float]] = None,
870
+ do_convert_annotations: Optional[bool] = None,
871
+ do_pad: bool = True,
872
+ pad_size: Optional[Dict[str, int]] = None,
873
+ **kwargs,
874
+ ) -> None:
875
+ if "pad_and_return_pixel_mask" in kwargs:
876
+ do_pad = kwargs.pop("pad_and_return_pixel_mask")
877
+
878
+ if "max_size" in kwargs:
879
+ logger.warning_once(
880
+ "The `max_size` parameter is deprecated and will be removed in v4.26. "
881
+ "Please specify in `size['longest_edge'] instead`.",
882
+ )
883
+ max_size = kwargs.pop("max_size")
884
+ else:
885
+ max_size = None if size is None else 1333
886
+
887
+ size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333}
888
+ size = get_size_dict(size, max_size=max_size, default_to_square=False)
889
+
890
+ # Backwards compatibility
891
+ if do_convert_annotations is None:
892
+ do_convert_annotations = do_normalize
893
+
894
+ super().__init__(**kwargs)
895
+ self.format = format
896
+ self.do_resize = do_resize
897
+ self.size = size
898
+ self.resample = resample
899
+ self.do_rescale = do_rescale
900
+ self.rescale_factor = rescale_factor
901
+ self.do_normalize = do_normalize
902
+ self.do_convert_annotations = do_convert_annotations
903
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
904
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
905
+ self.do_pad = do_pad
906
+ self.pad_size = pad_size
907
+ self._valid_processor_keys = [
908
+ "images",
909
+ "annotations",
910
+ "return_segmentation_masks",
911
+ "masks_path",
912
+ "do_resize",
913
+ "size",
914
+ "resample",
915
+ "do_rescale",
916
+ "rescale_factor",
917
+ "do_normalize",
918
+ "do_convert_annotations",
919
+ "image_mean",
920
+ "image_std",
921
+ "do_pad",
922
+ "pad_size",
923
+ "format",
924
+ "return_tensors",
925
+ "data_format",
926
+ "input_data_format",
927
+ ]
928
+
929
+ @classmethod
930
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.from_dict with Detr->DeformableDetr
931
+ def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
932
+ """
933
+ Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
934
+ created using from_dict and kwargs e.g. `DiffusionDetImageProcessor.from_pretrained(checkpoint, size=600,
935
+ max_size=800)`
936
+ """
937
+ image_processor_dict = image_processor_dict.copy()
938
+ if "max_size" in kwargs:
939
+ image_processor_dict["max_size"] = kwargs.pop("max_size")
940
+ if "pad_and_return_pixel_mask" in kwargs:
941
+ image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
942
+ return super().from_dict(image_processor_dict, **kwargs)
943
+
944
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation with DETR->DiffusionDet
945
+ def prepare_annotation(
946
+ self,
947
+ image: np.ndarray,
948
+ target: Dict,
949
+ format: Optional[AnnotationFormat] = None,
950
+ return_segmentation_masks: bool = None,
951
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
952
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
953
+ ) -> Dict:
954
+ """
955
+ Prepare an annotation for feeding into the DiffusionDet model.
956
+ """
957
+ format = format if format is not None else self.format
958
+
959
+ if format == AnnotationFormat.COCO_DETECTION:
960
+ return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
961
+ target = prepare_coco_detection_annotation(
962
+ image, target, return_segmentation_masks, input_data_format=input_data_format
963
+ )
964
+ elif format == AnnotationFormat.COCO_PANOPTIC:
965
+ return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks
966
+ target = prepare_coco_panoptic_annotation(
967
+ image,
968
+ target,
969
+ masks_path=masks_path,
970
+ return_masks=return_segmentation_masks,
971
+ input_data_format=input_data_format,
972
+ )
973
+ else:
974
+ raise ValueError(f"Format {format} is not supported.")
975
+ return target
976
+
977
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize
978
+ def resize(
979
+ self,
980
+ image: np.ndarray,
981
+ size: Dict[str, int],
982
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
983
+ data_format: Optional[ChannelDimension] = None,
984
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
985
+ **kwargs,
986
+ ) -> np.ndarray:
987
+ """
988
+ Resize the image according to the given `size` dictionary. The supported key combinations are
989
+ described under the `size` argument below; the edge-based options preserve the aspect ratio.
990
+
991
+ Args:
992
+ image (`np.ndarray`):
993
+ Image to resize.
994
+ size (`Dict[str, int]`):
995
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
996
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
997
+ Do NOT keep the aspect ratio.
998
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
999
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
1000
+ less or equal to `longest_edge`.
1001
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
1002
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
1003
+ `max_width`.
1004
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
1005
+ Resampling filter to use if resizing the image.
1006
+ data_format (`str` or `ChannelDimension`, *optional*):
1007
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
1008
+ image is used.
1009
+ input_data_format (`ChannelDimension` or `str`, *optional*):
1010
+ The channel dimension format of the input image. If not provided, it will be inferred.
1011
+ """
1012
+ if "max_size" in kwargs:
1013
+ logger.warning_once(
1014
+ "The `max_size` parameter is deprecated and will be removed in v4.26. "
1015
+ "Please specify in `size['longest_edge'] instead`.",
1016
+ )
1017
+ max_size = kwargs.pop("max_size")
1018
+ else:
1019
+ max_size = None
1020
+ size = get_size_dict(size, max_size=max_size, default_to_square=False)
1021
+ if "shortest_edge" in size and "longest_edge" in size:
1022
+ new_size = get_resize_output_image_size(
1023
+ image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
1024
+ )
1025
+ elif "max_height" in size and "max_width" in size:
1026
+ new_size = get_image_size_for_max_height_width(
1027
+ image, size["max_height"], size["max_width"], input_data_format=input_data_format
1028
+ )
1029
+ elif "height" in size and "width" in size:
1030
+ new_size = (size["height"], size["width"])
1031
+ else:
1032
+ raise ValueError(
1033
+ "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
1034
+ f" {size.keys()}."
1035
+ )
1036
+ image = resize(
1037
+ image,
1038
+ size=new_size,
1039
+ resample=resample,
1040
+ data_format=data_format,
1041
+ input_data_format=input_data_format,
1042
+ **kwargs,
1043
+ )
1044
+ return image
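# A simplified, hypothetical helper sketching the aspect-ratio logic behind the
# `shortest_edge`/`longest_edge` option above; the real `get_resize_output_image_size`
# helper may differ in rounding details.
def _resolve_size(height: int, width: int, shortest_edge: int, longest_edge: int):
    scale = shortest_edge / min(height, width)
    if max(height, width) * scale > longest_edge:
        scale = longest_edge / max(height, width)
    return round(height * scale), round(width * scale)

_resolve_size(480, 640, shortest_edge=800, longest_edge=1333)   # -> (800, 1067)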
1045
+
1046
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize_annotation
1047
+ def resize_annotation(
1048
+ self,
1049
+ annotation,
1050
+ orig_size,
1051
+ size,
1052
+ resample: PILImageResampling = PILImageResampling.NEAREST,
1053
+ ) -> Dict:
1054
+ """
1055
+ Resize the annotation to match the resized image. If size is an int, smaller edge of the mask will be matched
1056
+ to this number.
1057
+ """
1058
+ return resize_annotation(annotation, orig_size=orig_size, target_size=size, resample=resample)
1059
+
1060
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
1061
+ def rescale(
1062
+ self,
1063
+ image: np.ndarray,
1064
+ rescale_factor: float,
1065
+ data_format: Optional[Union[str, ChannelDimension]] = None,
1066
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
1067
+ ) -> np.ndarray:
1068
+ """
1069
+ Rescale the image by the given factor. image = image * rescale_factor.
1070
+
1071
+ Args:
1072
+ image (`np.ndarray`):
1073
+ Image to rescale.
1074
+ rescale_factor (`float`):
1075
+ The value to use for rescaling.
1076
+ data_format (`str` or `ChannelDimension`, *optional*):
1077
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
1078
+ image is used. Can be one of:
1079
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
1080
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
1081
+ input_data_format (`str` or `ChannelDimension`, *optional*):
1082
+ The channel dimension format for the input image. If unset, is inferred from the input image. Can be
1083
+ one of:
1084
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
1085
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
1086
+ """
1087
+ return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
1088
+
1089
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation
1090
+ def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict:
1091
+ """
1092
+ Normalize the boxes in the annotation from `[top_left_x, top_left_y, bottom_right_x, bottom_right_y]` to
1093
+ `[center_x, center_y, width, height]` format and from absolute to relative pixel values.
1094
+ """
1095
+ return normalize_annotation(annotation, image_size=image_size)
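# A tiny sketch of the conversion described above, for one absolute
# (top_left_x, top_left_y, bottom_right_x, bottom_right_y) box on an image of
# size (height, width); `_to_relative_cxcywh` is a hypothetical helper name.
import numpy as np

def _to_relative_cxcywh(box, height, width):
    x1, y1, x2, y2 = box
    cx, cy, w, h = (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1
    return np.array([cx / width, cy / height, w / width, h / height])

_to_relative_cxcywh([100, 50, 300, 250], height=480, width=640)
# -> approximately array([0.3125, 0.3125, 0.3125, 0.4167])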
1096
+
1097
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._update_annotation_for_padded_image
1098
+ def _update_annotation_for_padded_image(
1099
+ self,
1100
+ annotation: Dict,
1101
+ input_image_size: Tuple[int, int],
1102
+ output_image_size: Tuple[int, int],
1103
+ padding,
1104
+ update_bboxes,
1105
+ ) -> Dict:
1106
+ """
1107
+ Update the annotation for a padded image.
1108
+ """
1109
+ new_annotation = {}
1110
+ new_annotation["size"] = output_image_size
1111
+
1112
+ for key, value in annotation.items():
1113
+ if key == "masks":
1114
+ masks = value
1115
+ masks = pad(
1116
+ masks,
1117
+ padding,
1118
+ mode=PaddingMode.CONSTANT,
1119
+ constant_values=0,
1120
+ input_data_format=ChannelDimension.FIRST,
1121
+ )
1122
+ masks = safe_squeeze(masks, 1)
1123
+ new_annotation["masks"] = masks
1124
+ elif key == "boxes" and update_bboxes:
1125
+ boxes = value
1126
+ boxes *= np.asarray(
1127
+ [
1128
+ input_image_size[1] / output_image_size[1],
1129
+ input_image_size[0] / output_image_size[0],
1130
+ input_image_size[1] / output_image_size[1],
1131
+ input_image_size[0] / output_image_size[0],
1132
+ ]
1133
+ )
1134
+ new_annotation["boxes"] = boxes
1135
+ elif key == "size":
1136
+ new_annotation["size"] = output_image_size
1137
+ else:
1138
+ new_annotation[key] = value
1139
+ return new_annotation
1140
+
1141
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image
1142
+ def _pad_image(
1143
+ self,
1144
+ image: np.ndarray,
1145
+ output_size: Tuple[int, int],
1146
+ annotation: Optional[Dict[str, Any]] = None,
1147
+ constant_values: Union[float, Iterable[float]] = 0,
1148
+ data_format: Optional[ChannelDimension] = None,
1149
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
1150
+ update_bboxes: bool = True,
1151
+ ) -> np.ndarray:
1152
+ """
1153
+ Pad an image with zeros to the given size.
1154
+ """
1155
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
1156
+ output_height, output_width = output_size
1157
+
1158
+ pad_bottom = output_height - input_height
1159
+ pad_right = output_width - input_width
1160
+ padding = ((0, pad_bottom), (0, pad_right))
1161
+ padded_image = pad(
1162
+ image,
1163
+ padding,
1164
+ mode=PaddingMode.CONSTANT,
1165
+ constant_values=constant_values,
1166
+ data_format=data_format,
1167
+ input_data_format=input_data_format,
1168
+ )
1169
+ if annotation is not None:
1170
+ annotation = self._update_annotation_for_padded_image(
1171
+ annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes
1172
+ )
1173
+ return padded_image, annotation
1174
+
1175
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.pad
1176
+ def pad(
1177
+ self,
1178
+ images: List[np.ndarray],
1179
+ annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
1180
+ constant_values: Union[float, Iterable[float]] = 0,
1181
+ return_pixel_mask: bool = True,
1182
+ return_tensors: Optional[Union[str, TensorType]] = None,
1183
+ data_format: Optional[ChannelDimension] = None,
1184
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
1185
+ update_bboxes: bool = True,
1186
+ pad_size: Optional[Dict[str, int]] = None,
1187
+ ) -> BatchFeature:
1188
+ """
1189
+ Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
1190
+ in the batch and optionally returns their corresponding pixel mask.
1191
+
1192
+ Args:
1193
+ images (List[`np.ndarray`]):
1194
+ Images to pad.
1195
+ annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
1196
+ Annotations to transform according to the padding that is applied to the images.
1197
+ constant_values (`float` or `Iterable[float]`, *optional*):
1198
+ The value to use for the padding if `mode` is `"constant"`.
1199
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
1200
+ Whether to return a pixel mask.
1201
+ return_tensors (`str` or `TensorType`, *optional*):
1202
+ The type of tensors to return. Can be one of:
1203
+ - Unset: Return a list of `np.ndarray`.
1204
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
1205
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
1206
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
1207
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
1208
+ data_format (`str` or `ChannelDimension`, *optional*):
1209
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
1210
+ input_data_format (`ChannelDimension` or `str`, *optional*):
1211
+ The channel dimension format of the input image. If not provided, it will be inferred.
1212
+ update_bboxes (`bool`, *optional*, defaults to `True`):
1213
+ Whether to update the bounding boxes in the annotations to match the padded images. If the
1214
+ bounding boxes have not been converted to relative coordinates and `(center_x, center_y, width, height)`
1215
+ format, the bounding boxes will not be updated.
1216
+ pad_size (`Dict[str, int]`, *optional*):
1217
+ The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
1218
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
1219
+ height and width in the batch.
1220
+ """
1221
+ pad_size = pad_size if pad_size is not None else self.pad_size
1222
+ if pad_size is not None:
1223
+ padded_size = (pad_size["height"], pad_size["width"])
1224
+ else:
1225
+ padded_size = get_max_height_width(images, input_data_format=input_data_format)
1226
+
1227
+ annotation_list = annotations if annotations is not None else [None] * len(images)
1228
+ padded_images = []
1229
+ padded_annotations = []
1230
+ for image, annotation in zip(images, annotation_list):
1231
+ padded_image, padded_annotation = self._pad_image(
1232
+ image,
1233
+ padded_size,
1234
+ annotation,
1235
+ constant_values=constant_values,
1236
+ data_format=data_format,
1237
+ input_data_format=input_data_format,
1238
+ update_bboxes=update_bboxes,
1239
+ )
1240
+ padded_images.append(padded_image)
1241
+ padded_annotations.append(padded_annotation)
1242
+
1243
+ data = {"pixel_values": padded_images}
1244
+
1245
+ if return_pixel_mask:
1246
+ masks = [
1247
+ make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format)
1248
+ for image in images
1249
+ ]
1250
+ data["pixel_mask"] = masks
1251
+
1252
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
1253
+
1254
+ if annotations is not None:
1255
+ encoded_inputs["labels"] = [
1256
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in padded_annotations
1257
+ ]
1258
+
1259
+ return encoded_inputs
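# A bare-bones numpy sketch of the batch padding and pixel-mask idea implemented
# above, assuming channels-first images; the real method also rewrites annotations
# and can pad to a fixed `pad_size`.
import numpy as np

images = [np.ones((3, 480, 600)), np.ones((3, 512, 640))]
max_h = max(im.shape[1] for im in images)
max_w = max(im.shape[2] for im in images)

pixel_values, pixel_mask = [], []
for im in images:
    pad_b, pad_r = max_h - im.shape[1], max_w - im.shape[2]
    pixel_values.append(np.pad(im, ((0, 0), (0, pad_b), (0, pad_r))))   # zero-pad bottom/right
    mask = np.zeros((max_h, max_w), dtype=np.int64)
    mask[: im.shape[1], : im.shape[2]] = 1                              # 1 = real pixel, 0 = padding
    pixel_mask.append(mask)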
1260
+
1261
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.preprocess
1262
+ def preprocess(
1263
+ self,
1264
+ images: ImageInput,
1265
+ annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
1266
+ return_segmentation_masks: bool = None,
1267
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
1268
+ do_resize: Optional[bool] = None,
1269
+ size: Optional[Dict[str, int]] = None,
1270
+ resample=None, # PILImageResampling
1271
+ do_rescale: Optional[bool] = None,
1272
+ rescale_factor: Optional[Union[int, float]] = None,
1273
+ do_normalize: Optional[bool] = None,
1274
+ do_convert_annotations: Optional[bool] = None,
1275
+ image_mean: Optional[Union[float, List[float]]] = None,
1276
+ image_std: Optional[Union[float, List[float]]] = None,
1277
+ do_pad: Optional[bool] = None,
1278
+ format: Optional[Union[str, AnnotationFormat]] = None,
1279
+ return_tensors: Optional[Union[TensorType, str]] = None,
1280
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
1281
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
1282
+ pad_size: Optional[Dict[str, int]] = None,
1283
+ **kwargs,
1284
+ ) -> BatchFeature:
1285
+ """
1286
+ Preprocess an image or a batch of images so that it can be used by the model.
1287
+
1288
+ Args:
1289
+ images (`ImageInput`):
1290
+ Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging
1291
+ from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
1292
+ annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
1293
+ List of annotations associated with the image or batch of images. If annotation is for object
1294
+ detection, the annotations should be a dictionary with the following keys:
1295
+ - "image_id" (`int`): The image id.
1296
+ - "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
1297
+ dictionary. An image can have no annotations, in which case the list should be empty.
1298
+ If annotation is for segmentation, the annotations should be a dictionary with the following keys:
1299
+ - "image_id" (`int`): The image id.
1300
+ - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
1301
+ An image can have no segments, in which case the list should be empty.
1302
+ - "file_name" (`str`): The file name of the image.
1303
+ return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks):
1304
+ Whether to return segmentation masks.
1305
+ masks_path (`str` or `pathlib.Path`, *optional*):
1306
+ Path to the directory containing the segmentation masks.
1307
+ do_resize (`bool`, *optional*, defaults to self.do_resize):
1308
+ Whether to resize the image.
1309
+ size (`Dict[str, int]`, *optional*, defaults to self.size):
1310
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
1311
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
1312
+ Do NOT keep the aspect ratio.
1313
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
1314
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
1315
+ less or equal to `longest_edge`.
1316
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
1317
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
1318
+ `max_width`.
1319
+ resample (`PILImageResampling`, *optional*, defaults to self.resample):
1320
+ Resampling filter to use when resizing the image.
1321
+ do_rescale (`bool`, *optional*, defaults to self.do_rescale):
1322
+ Whether to rescale the image.
1323
+ rescale_factor (`float`, *optional*, defaults to self.rescale_factor):
1324
+ Rescale factor to use when rescaling the image.
1325
+ do_normalize (`bool`, *optional*, defaults to self.do_normalize):
1326
+ Whether to normalize the image.
1327
+ do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations):
1328
+ Whether to convert the annotations to the format expected by the model. Converts the bounding
1329
+ boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)`
1330
+ and in relative coordinates.
1331
+ image_mean (`float` or `List[float]`, *optional*, defaults to self.image_mean):
1332
+ Mean to use when normalizing the image.
1333
+ image_std (`float` or `List[float]`, *optional*, defaults to self.image_std):
1334
+ Standard deviation to use when normalizing the image.
1335
+ do_pad (`bool`, *optional*, defaults to self.do_pad):
1336
+ Whether to pad the image. If `True`, padding will be applied to the bottom and right of
1337
+ the image with zeros. If `pad_size` is provided, the image will be padded to the specified
1338
+ dimensions. Otherwise, the image will be padded to the maximum height and width of the batch.
1339
+ format (`str` or `AnnotationFormat`, *optional*, defaults to self.format):
1340
+ Format of the annotations.
1341
+ return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
1342
+ Type of tensors to return. If `None`, will return the list of images.
1343
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
1344
+ The channel dimension format for the output image. Can be one of:
1345
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
1346
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
1347
+ - Unset: Use the channel dimension format of the input image.
1348
+ input_data_format (`ChannelDimension` or `str`, *optional*):
1349
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
1350
+ from the input image. Can be one of:
1351
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
1352
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
1353
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
1354
+ pad_size (`Dict[str, int]`, *optional*):
1355
+ The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
1356
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
1357
+ height and width in the batch.
1358
+ """
1359
+ if "pad_and_return_pixel_mask" in kwargs:
1360
+ logger.warning_once(
1361
+ "The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version, "
1362
+ "use `do_pad` instead."
1363
+ )
1364
+ do_pad = kwargs.pop("pad_and_return_pixel_mask")
1365
+
1366
+ if "max_size" in kwargs:
1367
+ logger.warning_once(
1368
+ "The `max_size` argument is deprecated and will be removed in a future version, use"
1369
+ " `size['longest_edge']` instead."
1370
+ )
1371
+ size = kwargs.pop("max_size")
1372
+
1373
+ do_resize = self.do_resize if do_resize is None else do_resize
1374
+ size = self.size if size is None else size
1375
+ size = get_size_dict(size=size, default_to_square=False)
1376
+ resample = self.resample if resample is None else resample
1377
+ do_rescale = self.do_rescale if do_rescale is None else do_rescale
1378
+ rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
1379
+ do_normalize = self.do_normalize if do_normalize is None else do_normalize
1380
+ image_mean = self.image_mean if image_mean is None else image_mean
1381
+ image_std = self.image_std if image_std is None else image_std
1382
+ do_convert_annotations = (
1383
+ self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations
1384
+ )
1385
+ do_pad = self.do_pad if do_pad is None else do_pad
1386
+ pad_size = self.pad_size if pad_size is None else pad_size
1387
+ format = self.format if format is None else format
1388
+
1389
+ images = make_list_of_images(images)
1390
+
1391
+ if not valid_images(images):
1392
+ raise ValueError(
1393
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
1394
+ "torch.Tensor, tf.Tensor or jax.ndarray."
1395
+ )
1396
+ validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
1397
+
1398
+ # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated.
1399
+ validate_preprocess_arguments(
1400
+ do_rescale=do_rescale,
1401
+ rescale_factor=rescale_factor,
1402
+ do_normalize=do_normalize,
1403
+ image_mean=image_mean,
1404
+ image_std=image_std,
1405
+ do_resize=do_resize,
1406
+ size=size,
1407
+ resample=resample,
1408
+ )
1409
+
1410
+ if annotations is not None and isinstance(annotations, dict):
1411
+ annotations = [annotations]
1412
+
1413
+ if annotations is not None and len(images) != len(annotations):
1414
+ raise ValueError(
1415
+ f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
1416
+ )
1417
+
1418
+ format = AnnotationFormat(format)
1419
+ if annotations is not None:
1420
+ validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
1421
+
1422
+ if (
1423
+ masks_path is not None
1424
+ and format == AnnotationFormat.COCO_PANOPTIC
1425
+ and not isinstance(masks_path, (pathlib.Path, str))
1426
+ ):
1427
+ raise ValueError(
1428
+ "The path to the directory containing the mask PNG files should be provided as a"
1429
+ f" `pathlib.Path` or string object, but is {type(masks_path)} instead."
1430
+ )
1431
+
1432
+ # All transformations expect numpy arrays
1433
+ images = [to_numpy_array(image) for image in images]
1434
+
1435
+ if is_scaled_image(images[0]) and do_rescale:
1436
+ logger.warning_once(
1437
+ "It looks like you are trying to rescale already rescaled images. If the input"
1438
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
1439
+ )
1440
+
1441
+ if input_data_format is None:
1442
+ # We assume that all images have the same channel dimension format.
1443
+ input_data_format = infer_channel_dimension_format(images[0])
1444
+
1445
+ # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
1446
+ if annotations is not None:
1447
+ prepared_images = []
1448
+ prepared_annotations = []
1449
+ for image, target in zip(images, annotations):
1450
+ target = self.prepare_annotation(
1451
+ image,
1452
+ target,
1453
+ format,
1454
+ return_segmentation_masks=return_segmentation_masks,
1455
+ masks_path=masks_path,
1456
+ input_data_format=input_data_format,
1457
+ )
1458
+ prepared_images.append(image)
1459
+ prepared_annotations.append(target)
1460
+ images = prepared_images
1461
+ annotations = prepared_annotations
1462
+ del prepared_images, prepared_annotations
1463
+
1464
+ # transformations
1465
+ if do_resize:
1466
+ if annotations is not None:
1467
+ resized_images, resized_annotations = [], []
1468
+ for image, target in zip(images, annotations):
1469
+ orig_size = get_image_size(image, input_data_format)
1470
+ resized_image = self.resize(
1471
+ image, size=size, resample=resample, input_data_format=input_data_format
1472
+ )
1473
+ resized_annotation = self.resize_annotation(
1474
+ target, orig_size, get_image_size(resized_image, input_data_format)
1475
+ )
1476
+ resized_images.append(resized_image)
1477
+ resized_annotations.append(resized_annotation)
1478
+ images = resized_images
1479
+ annotations = resized_annotations
1480
+ del resized_images, resized_annotations
1481
+ else:
1482
+ images = [
1483
+ self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
1484
+ for image in images
1485
+ ]
1486
+
1487
+ if do_rescale:
1488
+ images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]
1489
+
1490
+ if do_normalize:
1491
+ images = [
1492
+ self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
1493
+ ]
1494
+
1495
+ if do_convert_annotations and annotations is not None:
1496
+ annotations = [
1497
+ self.normalize_annotation(annotation, get_image_size(image, input_data_format))
1498
+ for annotation, image in zip(annotations, images)
1499
+ ]
1500
+
1501
+ if do_pad:
1502
+ # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
1503
+ encoded_inputs = self.pad(
1504
+ images,
1505
+ annotations=annotations,
1506
+ return_pixel_mask=True,
1507
+ data_format=data_format,
1508
+ input_data_format=input_data_format,
1509
+ update_bboxes=do_convert_annotations,
1510
+ return_tensors=return_tensors,
1511
+ pad_size=pad_size,
1512
+ )
1513
+ else:
1514
+ images = [
1515
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
1516
+ for image in images
1517
+ ]
1518
+ encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
1519
+ if annotations is not None:
1520
+ encoded_inputs["labels"] = [
1521
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
1522
+ ]
1523
+
1524
+ return encoded_inputs
1525
+
1526
+ # POSTPROCESSING METHODS - TODO: add support for other frameworks
1527
+ def post_process(self, outputs, target_sizes):
1528
+ """
1529
+ Converts the raw output of [`DiffusionDet`] into final bounding boxes in (top_left_x,
1530
+ top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch.
1531
+
1532
+ Args:
1533
+ outputs ([`DiffusionDetOutput`]):
1534
+ Raw outputs of the model.
1535
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
1536
+ Tensor containing the size (height, width) of each image of the batch. For evaluation, this must be the
1537
+ original image size (before any data augmentation). For visualization, this should be the image size
1538
+ after data augment, but before padding.
1539
+ Returns:
1540
+ `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
1541
+ in the batch as predicted by the model.
1542
+ """
1543
+ logger.warning_once(
1544
+ "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
1545
+ " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
1546
+ )
1547
+
1548
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
1549
+
1550
+ if len(out_logits) != len(target_sizes):
1551
+ raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
1552
+ if target_sizes.shape[1] != 2:
1553
+ raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
1554
+
1555
+ prob = out_logits.sigmoid()
1556
+ topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
1557
+ scores = topk_values
1558
+ topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
1559
+ labels = topk_indexes % out_logits.shape[2]
1560
+ boxes = center_to_corners_format(out_bbox)
1561
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
1562
+
1563
+ # and from relative [0, 1] to absolute [0, height] coordinates
1564
+ img_h, img_w = target_sizes.unbind(1)
1565
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
1566
+ boxes = boxes * scale_fct[:, None, :]
1567
+
1568
+ results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
1569
+
1570
+ return results
1571
+
1572
+ def post_process_object_detection(
1573
+ self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None, top_k: int = 100
1574
+ ):
1575
+ """
1576
+ Converts the raw output of [`DiffusionDet`] into final bounding boxes in (top_left_x,
1577
+ top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch.
1578
+
1579
+ Args:
1580
+ outputs ([`DiffusionDetOutput`]):
1581
+ Raw outputs of the model.
1582
+ threshold (`float`, *optional*, defaults to 0.5):
1583
+ Score threshold to keep object detection predictions.
1584
+ target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
1585
+ Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
1586
+ (height, width) of each image in the batch. If left to None, predictions will not be resized.
1587
+ top_k (`int`, *optional*, defaults to 100):
1588
+ Keep only top k bounding boxes before filtering by thresholding.
1589
+
1590
+ Returns:
1591
+ `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
1592
+ in the batch as predicted by the model.
1593
+ """
1594
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
1595
+
1596
+ if target_sizes is not None:
1597
+ if len(out_logits) != len(target_sizes):
1598
+ raise ValueError(
1599
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
1600
+ )
1601
+
1602
+ prob = out_logits.sigmoid()
1603
+ prob = prob.view(out_logits.shape[0], -1)
1604
+ k_value = min(top_k, prob.size(1))
1605
+ topk_values, topk_indexes = torch.topk(prob, k_value, dim=1)
1606
+ scores = topk_values
1607
+ topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
1608
+ labels = topk_indexes % out_logits.shape[2]
1609
+ boxes = center_to_corners_format(out_bbox)
1610
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
1611
+
1612
+ # and from relative [0, 1] to absolute [0, height] coordinates
1613
+ if target_sizes is not None:
1614
+ if isinstance(target_sizes, List):
1615
+ img_h = torch.Tensor([i[0] for i in target_sizes])
1616
+ img_w = torch.Tensor([i[1] for i in target_sizes])
1617
+ else:
1618
+ img_h, img_w = target_sizes.unbind(1)
1619
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
1620
+ boxes = boxes * scale_fct[:, None, :]
1621
+
1622
+ results = []
1623
+ for s, l, b in zip(scores, labels, boxes):
1624
+ score = s[s > threshold]
1625
+ label = l[s > threshold]
1626
+ box = b[s > threshold]
1627
+ results.append({"scores": score, "labels": label, "boxes": box})
1628
+
1629
+ return results
1630
+
1631
+
1632
+ __all__ = ["DiffusionDetImageProcessor"]
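# A minimal standalone sketch of the post-processing math in
# `post_process_object_detection` (sigmoid scores, top-k selection over all
# class/proposal pairs, cxcywh -> xyxy conversion, rescaling to pixels and
# thresholding); the random logits and boxes below are placeholders.
import torch
from torchvision.ops import box_convert

batch_size, num_proposals, num_classes = 1, 300, 80
logits = torch.randn(batch_size, num_proposals, num_classes)
pred_boxes = torch.rand(batch_size, num_proposals, 4)              # normalized (cx, cy, w, h)

prob = logits.sigmoid().view(batch_size, -1)
scores, topk_idx = torch.topk(prob, k=100, dim=1)
box_idx = torch.div(topk_idx, num_classes, rounding_mode="floor")  # proposal index
labels = topk_idx % num_classes                                    # class index
boxes = box_convert(pred_boxes, "cxcywh", "xyxy")
boxes = torch.gather(boxes, 1, box_idx.unsqueeze(-1).repeat(1, 1, 4))

img_h, img_w = 480, 640
boxes = boxes * torch.tensor([img_w, img_h, img_w, img_h])         # to absolute pixel coordinates
keep = scores[0] > 0.5
result = {"scores": scores[0][keep], "labels": labels[0][keep], "boxes": boxes[0][keep]}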
loss.py ADDED
@@ -0,0 +1,415 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from fvcore.nn import sigmoid_focal_loss_jit
4
+ from torch import nn
5
+
6
+ import torch.distributed as dist
7
+ from torch.distributed import get_world_size
8
+ from torchvision import ops
9
+
10
+
11
+ def is_dist_avail_and_initialized():
12
+ if not dist.is_available():
13
+ return False
14
+ if not dist.is_initialized():
15
+ return False
16
+ return True
17
+
18
+
19
+ def get_fed_loss_classes(gt_classes, num_fed_loss_classes, num_classes, weight):
20
+ """
21
+ Args:
22
+ gt_classes: a long tensor of shape R that contains the gt class label of each proposal.
23
+ num_fed_loss_classes: minimum number of classes to keep when calculating federated loss.
24
+ Will sample negative classes if number of unique gt_classes is smaller than this value.
25
+ num_classes: number of foreground classes
26
+ weight: probabilities used to sample negative classes
27
+ Returns:
28
+ Tensor:
29
+ classes to keep when calculating the federated loss, including both unique gt
30
+ classes and sampled negative classes.
31
+ """
32
+ unique_gt_classes = torch.unique(gt_classes)
33
+ prob = unique_gt_classes.new_ones(num_classes + 1).float()
34
+ prob[-1] = 0
35
+ if len(unique_gt_classes) < num_fed_loss_classes:
36
+ prob[:num_classes] = weight.float().clone()
37
+ prob[unique_gt_classes] = 0
38
+ sampled_negative_classes = torch.multinomial(
39
+ prob, num_fed_loss_classes - len(unique_gt_classes), replacement=False
40
+ )
41
+ fed_loss_classes = torch.cat([unique_gt_classes, sampled_negative_classes])
42
+ else:
43
+ fed_loss_classes = unique_gt_classes
44
+ return fed_loss_classes
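# A small usage sketch of `get_fed_loss_classes` with dummy values: when fewer
# than `num_fed_loss_classes` unique gt classes are present, extra negative
# classes are sampled according to `weight` (the background index is never drawn).
import torch

num_classes = 10
gt_classes = torch.tensor([2, 2, 7])            # classes seen in this batch
weight = torch.ones(num_classes)                # uniform sampling probabilities
fed_classes = get_fed_loss_classes(gt_classes, num_fed_loss_classes=5,
                                   num_classes=num_classes, weight=weight)
# fed_classes contains 2 and 7 plus three sampled negative classes.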
45
+
46
+
47
+ class CriterionDynamicK(nn.Module):
48
+ """ This class computes the loss for DiffusionDet.
49
+ The process happens in two steps:
50
+ 1) we compute a dynamic 1-to-k assignment between ground truth boxes and the outputs of the model
51
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
52
+ """
53
+
54
+ def __init__(self, config, num_classes, weight_dict):
55
+ """ Create the criterion.
56
+ Parameters:
57
+ num_classes: number of object categories, omitting the special no-object category
58
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
59
+ """
60
+ super().__init__()
61
+ self.config = config
62
+ self.num_classes = num_classes
63
+ self.matcher = HungarianMatcherDynamicK(config)
64
+ self.weight_dict = weight_dict
65
+ self.eos_coef = config.no_object_weight
66
+ self.use_focal = config.use_focal
67
+ self.use_fed_loss = config.use_fed_loss
68
+
69
+ if self.use_focal:
70
+ self.focal_loss_alpha = config.alpha
71
+ self.focal_loss_gamma = config.gamma
72
+
73
+ # copy-paste from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/roi_heads/fast_rcnn.py#L356
74
+ def loss_labels(self, outputs, targets, indices):
75
+ """Classification loss (NLL)
76
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
77
+ """
78
+ assert 'pred_logits' in outputs
79
+ src_logits = outputs['pred_logits']
80
+ batch_size = len(targets)
81
+
82
+ # idx = self._get_src_permutation_idx(indices)
83
+ # target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
84
+ target_classes = torch.full(src_logits.shape[:2], self.num_classes,
85
+ dtype=torch.int64, device=src_logits.device)
86
+ src_logits_list = []
87
+ target_classes_o_list = []
88
+ # target_classes[idx] = target_classes_o
89
+ for batch_idx in range(batch_size):
90
+ valid_query = indices[batch_idx][0]
91
+ gt_multi_idx = indices[batch_idx][1]
92
+ if len(gt_multi_idx) == 0:
93
+ continue
94
+ bz_src_logits = src_logits[batch_idx]
95
+ target_classes_o = targets[batch_idx]["labels"]
96
+ target_classes[batch_idx, valid_query] = target_classes_o[gt_multi_idx]
97
+
98
+ src_logits_list.append(bz_src_logits[valid_query])
99
+ target_classes_o_list.append(target_classes_o[gt_multi_idx])
100
+
101
+ if self.use_focal or self.use_fed_loss:
102
+ num_boxes = torch.cat(target_classes_o_list).shape[0] if len(target_classes_o_list) != 0 else 1
103
+
104
+ target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], self.num_classes + 1],
105
+ dtype=src_logits.dtype, layout=src_logits.layout,
106
+ device=src_logits.device)
107
+ target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
108
+
109
+ gt_classes = torch.argmax(target_classes_onehot, dim=-1)
110
+ target_classes_onehot = target_classes_onehot[:, :, :-1]
111
+
112
+ src_logits = src_logits.flatten(0, 1)
113
+ target_classes_onehot = target_classes_onehot.flatten(0, 1)
114
+ if self.use_focal:
115
+ cls_loss = sigmoid_focal_loss_jit(src_logits, target_classes_onehot, alpha=self.focal_loss_alpha,
116
+ gamma=self.focal_loss_gamma, reduction="none")
117
+ else:
118
+ cls_loss = F.binary_cross_entropy_with_logits(src_logits, target_classes_onehot, reduction="none")
119
+ if self.use_fed_loss:
120
+ K = self.num_classes
121
+ N = src_logits.shape[0]
122
+ fed_loss_classes = get_fed_loss_classes(
123
+ gt_classes,
124
+ num_fed_loss_classes=self.fed_loss_num_classes,
125
+ num_classes=K,
126
+ weight=self.fed_loss_cls_weights,
127
+ )
128
+ fed_loss_classes_mask = fed_loss_classes.new_zeros(K + 1)
129
+ fed_loss_classes_mask[fed_loss_classes] = 1
130
+ fed_loss_classes_mask = fed_loss_classes_mask[:K]
131
+ weight = fed_loss_classes_mask.view(1, K).expand(N, K).float()
132
+
133
+ loss_ce = torch.sum(cls_loss * weight) / num_boxes
134
+ else:
135
+ loss_ce = torch.sum(cls_loss) / num_boxes
136
+
137
+ losses = {'loss_ce': loss_ce}
138
+ else:
139
+ raise NotImplementedError
140
+
141
+ return losses
142
+
143
+ def loss_boxes(self, outputs, targets, indices):
144
+ """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
145
+ targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
146
+ The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
147
+ """
148
+ assert 'pred_boxes' in outputs
149
+ # idx = self._get_src_permutation_idx(indices)
150
+ src_boxes = outputs['pred_boxes']
151
+
152
+ batch_size = len(targets)
153
+ pred_box_list = []
154
+ pred_norm_box_list = []
155
+ tgt_box_list = []
156
+ tgt_box_xyxy_list = []
157
+ for batch_idx in range(batch_size):
158
+ valid_query = indices[batch_idx][0]
159
+ gt_multi_idx = indices[batch_idx][1]
160
+ if len(gt_multi_idx) == 0:
161
+ continue
162
+ bz_image_whwh = targets[batch_idx]['image_size_xyxy']
163
+ bz_src_boxes = src_boxes[batch_idx]
164
+ bz_target_boxes = targets[batch_idx]["boxes"] # normalized (cx, cy, w, h)
165
+ bz_target_boxes_xyxy = targets[batch_idx]["boxes_xyxy"] # absolute (x1, y1, x2, y2)
166
+ pred_box_list.append(bz_src_boxes[valid_query])
167
+ pred_norm_box_list.append(bz_src_boxes[valid_query] / bz_image_whwh) # normalize (x1, y1, x2, y2)
168
+ tgt_box_list.append(bz_target_boxes[gt_multi_idx])
169
+ tgt_box_xyxy_list.append(bz_target_boxes_xyxy[gt_multi_idx])
170
+
171
+ if len(pred_box_list) != 0:
172
+ src_boxes = torch.cat(pred_box_list)
173
+ src_boxes_norm = torch.cat(pred_norm_box_list) # normalized (x1, y1, x2, y2)
174
+ target_boxes = torch.cat(tgt_box_list)
175
+ target_boxes_abs_xyxy = torch.cat(tgt_box_xyxy_list)
176
+ num_boxes = src_boxes.shape[0]
177
+
178
+ losses = {}
179
+ # require normalized (x1, y1, x2, y2)
180
+ loss_bbox = F.l1_loss(src_boxes_norm, ops.box_convert(target_boxes, 'cxcywh', 'xyxy'), reduction='none')
181
+ losses['loss_bbox'] = loss_bbox.sum() / num_boxes
182
+
183
+ # loss_giou = giou_loss(box_ops.box_cxcywh_to_xyxy(src_boxes), box_ops.box_cxcywh_to_xyxy(target_boxes))
184
+ loss_giou = 1 - torch.diag(ops.generalized_box_iou(src_boxes, target_boxes_abs_xyxy))
185
+ losses['loss_giou'] = loss_giou.sum() / num_boxes
186
+ else:
187
+ losses = {'loss_bbox': outputs['pred_boxes'].sum() * 0,
188
+ 'loss_giou': outputs['pred_boxes'].sum() * 0}
189
+
190
+ return losses
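# A minimal sketch of the two box losses used above with dummy boxes: an L1 loss
# on normalized (x1, y1, x2, y2) coordinates and a GIoU loss on absolute boxes,
# both averaged over the number of matched boxes.
import torch
import torch.nn.functional as F
from torchvision import ops

pred_xyxy = torch.tensor([[10., 10., 50., 60.], [100., 80., 180., 200.]])
tgt_xyxy = torch.tensor([[12., 11., 48., 58.], [90., 85., 170., 210.]])
image_whwh = torch.tensor([640., 480., 640., 480.])

num_boxes = pred_xyxy.shape[0]
loss_bbox = F.l1_loss(pred_xyxy / image_whwh, tgt_xyxy / image_whwh, reduction="none").sum() / num_boxes
loss_giou = (1 - torch.diag(ops.generalized_box_iou(pred_xyxy, tgt_xyxy))).sum() / num_boxes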
191
+
192
+ def get_loss(self, loss, outputs, targets, indices):
193
+ loss_map = {
194
+ 'labels': self.loss_labels,
195
+ 'boxes': self.loss_boxes,
196
+ }
197
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
198
+ return loss_map[loss](outputs, targets, indices)
199
+
200
+ def forward(self, outputs, targets):
201
+ """ This performs the loss computation.
202
+ Parameters:
203
+ outputs: dict of tensors, see the output specification of the model for the format
204
+ targets: list of dicts, such that len(targets) == batch_size.
205
+ The expected keys in each dict depends on the losses applied, see each loss' doc
206
+ """
207
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
208
+
209
+ # Retrieve the matching between the outputs of the last layer and the targets
210
+ indices, _ = self.matcher(outputs_without_aux, targets)
211
+
212
+ # Compute all the requested losses
213
+ losses = {}
214
+ for loss in ["labels", "boxes"]:
215
+ losses.update(self.get_loss(loss, outputs, targets, indices))
216
+
217
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
218
+ if 'aux_outputs' in outputs:
219
+ for i, aux_outputs in enumerate(outputs['aux_outputs']):
220
+ indices, _ = self.matcher(aux_outputs, targets)
221
+ for loss in ["labels", "boxes"]:
222
+ if loss == 'masks':
223
+ # Intermediate masks losses are too costly to compute, we ignore them.
224
+ continue
225
+
226
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices)
227
+ l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
228
+ losses.update(l_dict)
229
+
230
+ return losses
231
+
232
+
233
+ def get_in_boxes_info(boxes, target_gts):
234
+ xy_target_gts = ops.box_convert(target_gts, 'cxcywh', 'xyxy') # (x1, y1, x2, y2)
235
+
236
+ anchor_center_x = boxes[:, 0].unsqueeze(1)
237
+ anchor_center_y = boxes[:, 1].unsqueeze(1)
238
+
239
+ # whether the center of each anchor is inside a gt box
240
+ b_l = anchor_center_x > xy_target_gts[:, 0].unsqueeze(0)
241
+ b_r = anchor_center_x < xy_target_gts[:, 2].unsqueeze(0)
242
+ b_t = anchor_center_y > xy_target_gts[:, 1].unsqueeze(0)
243
+ b_b = anchor_center_y < xy_target_gts[:, 3].unsqueeze(0)
244
+ # (b_l.long()+b_r.long()+b_t.long()+b_b.long())==4 [300,num_gt] ,
245
+ is_in_boxes = ((b_l.long() + b_r.long() + b_t.long() + b_b.long()) == 4)
246
+ is_in_boxes_all = is_in_boxes.sum(1) > 0 # [num_query]
247
+ # in fixed center
248
+ center_radius = 2.5
249
+ # Modified to self-adapted sampling --- the center size depends on the size of the gt boxes
250
+ # https://github.com/dulucas/UVO_Challenge/blob/main/Track1/detection/mmdet/core/bbox/assigners/rpn_sim_ota_assigner.py#L212
251
+ b_l = anchor_center_x > (
252
+ target_gts[:, 0] - (center_radius * (xy_target_gts[:, 2] - xy_target_gts[:, 0]))).unsqueeze(0)
253
+ b_r = anchor_center_x < (
254
+ target_gts[:, 0] + (center_radius * (xy_target_gts[:, 2] - xy_target_gts[:, 0]))).unsqueeze(0)
255
+ b_t = anchor_center_y > (
256
+ target_gts[:, 1] - (center_radius * (xy_target_gts[:, 3] - xy_target_gts[:, 1]))).unsqueeze(0)
257
+ b_b = anchor_center_y < (
258
+ target_gts[:, 1] + (center_radius * (xy_target_gts[:, 3] - xy_target_gts[:, 1]))).unsqueeze(0)
259
+
260
+ is_in_centers = ((b_l.long() + b_r.long() + b_t.long() + b_b.long()) == 4)
261
+ is_in_centers_all = is_in_centers.sum(1) > 0
262
+
263
+ is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
264
+ is_in_boxes_and_center = (is_in_boxes & is_in_centers)
265
+
266
+ return is_in_boxes_anchor, is_in_boxes_and_center
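# A compact sketch of the broadcasted "center inside box" test above: proposal
# centers of shape [num_proposals, 1] are compared against gt box edges of shape
# [num_gt], giving a [num_proposals, num_gt] boolean mask. Values are dummies.
import torch

centers = torch.tensor([[30., 30.], [200., 150.]])                      # (cx, cy) per proposal
gt_xyxy = torch.tensor([[10., 10., 50., 60.], [100., 80., 180., 200.]])

cx, cy = centers[:, 0:1], centers[:, 1:2]
x1, y1, x2, y2 = gt_xyxy.unbind(dim=1)
inside = (cx > x1) & (cx < x2) & (cy > y1) & (cy < y2)
# inside[i, j] is True when proposal i's center lies strictly inside gt box j.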
267
+
268
+
269
+ class HungarianMatcherDynamicK(nn.Module):
270
+ """This class computes an assignment between the targets and the predictions of the network
271
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
272
+ there are more predictions than targets. In this case, we do a 1-to-k (dynamic) matching of the best predictions,
273
+ while the others are un-matched (and thus treated as non-objects).
274
+ """
275
+
276
+ def __init__(self, config):
277
+ super().__init__()
278
+ self.use_focal = config.use_focal
279
+ self.use_fed_loss = config.use_fed_loss
280
+ self.cost_class = config.class_weight
281
+ self.cost_giou = config.giou_weight
282
+ self.cost_bbox = config.l1_weight
283
+ self.ota_k = config.ota_k
284
+
285
+ if self.use_focal:
286
+ self.focal_loss_alpha = config.alpha
287
+ self.focal_loss_gamma = config.gamma
288
+
289
+ assert self.cost_class != 0 or self.cost_bbox != 0 or self.cost_giou != 0, "all costs cant be 0"
290
+
291
+ def forward(self, outputs, targets):
292
+ """ simOTA for detr"""
293
+ with torch.no_grad():
294
+ bs, num_queries = outputs["pred_logits"].shape[:2]
295
+ # We flatten to compute the cost matrices in a batch
296
+ if self.use_focal or self.use_fed_loss:
297
+ out_prob = outputs["pred_logits"].sigmoid() # [batch_size, num_queries, num_classes]
298
+ out_bbox = outputs["pred_boxes"] # [batch_size, num_queries, 4]
299
+ else:
300
+ out_prob = outputs["pred_logits"].softmax(-1) # [batch_size, num_queries, num_classes]
301
+ out_bbox = outputs["pred_boxes"] # [batch_size, num_queries, 4]
302
+
303
+ indices = []
304
+ matched_ids = []
305
+ assert bs == len(targets)
306
+ for batch_idx in range(bs):
307
+ bz_boxes = out_bbox[batch_idx] # [num_proposals, 4]
308
+ bz_out_prob = out_prob[batch_idx]
309
+ bz_tgt_ids = targets[batch_idx]["labels"]
310
+ num_insts = len(bz_tgt_ids)
311
+ if num_insts == 0: # empty object in key frame
312
+ non_valid = torch.zeros(bz_out_prob.shape[0]).to(bz_out_prob) > 0
313
+ indices_batchi = (non_valid, torch.arange(0, 0).to(bz_out_prob))
314
+ matched_qidx = torch.arange(0, 0).to(bz_out_prob)
315
+ indices.append(indices_batchi)
316
+ matched_ids.append(matched_qidx)
317
+ continue
318
+
319
+ bz_gtboxs = targets[batch_idx]['boxes'] # [num_gt, 4] normalized (cx, xy, w, h)
320
+ bz_gtboxs_abs_xyxy = targets[batch_idx]['boxes_xyxy']
321
+ fg_mask, is_in_boxes_and_center = get_in_boxes_info(
322
+ ops.box_convert(bz_boxes, 'xyxy', 'cxcywh'), # absolute (cx, cy, w, h)
323
+ ops.box_convert(bz_gtboxs_abs_xyxy, 'xyxy', 'cxcywh') # absolute (cx, cy, w, h)
324
+ )
325
+
326
+ pair_wise_ious = ops.box_iou(bz_boxes, bz_gtboxs_abs_xyxy)
327
+
328
+ # Compute the classification cost.
329
+ if self.use_focal:
330
+ alpha = self.focal_loss_alpha
331
+ gamma = self.focal_loss_gamma
332
+ neg_cost_class = (1 - alpha) * (bz_out_prob ** gamma) * (-(1 - bz_out_prob + 1e-8).log())
333
+ pos_cost_class = alpha * ((1 - bz_out_prob) ** gamma) * (-(bz_out_prob + 1e-8).log())
334
+ cost_class = pos_cost_class[:, bz_tgt_ids] - neg_cost_class[:, bz_tgt_ids]
335
+ elif self.use_fed_loss:
336
+ # focal loss degenerates to naive one
337
+ neg_cost_class = (-(1 - bz_out_prob + 1e-8).log())
338
+ pos_cost_class = (-(bz_out_prob + 1e-8).log())
339
+ cost_class = pos_cost_class[:, bz_tgt_ids] - neg_cost_class[:, bz_tgt_ids]
340
+ else:
341
+ cost_class = -bz_out_prob[:, bz_tgt_ids]
342
+
343
+ # Compute the L1 cost between boxes
344
+ # image_size_out = torch.cat([v["image_size_xyxy"].unsqueeze(0) for v in targets])
345
+ # image_size_out = image_size_out.unsqueeze(1).repeat(1, num_queries, 1).flatten(0, 1)
346
+ # image_size_tgt = torch.cat([v["image_size_xyxy_tgt"] for v in targets])
347
+
348
+ bz_image_size_out = targets[batch_idx]['image_size_xyxy']
349
+ bz_image_size_tgt = targets[batch_idx]['image_size_xyxy_tgt']
350
+
351
+ bz_out_bbox_ = bz_boxes / bz_image_size_out # normalize (x1, y1, x2, y2)
352
+ bz_tgt_bbox_ = bz_gtboxs_abs_xyxy / bz_image_size_tgt # normalize (x1, y1, x2, y2)
353
+ cost_bbox = torch.cdist(bz_out_bbox_, bz_tgt_bbox_, p=1)
354
+
355
+ cost_giou = -ops.generalized_box_iou(bz_boxes, bz_gtboxs_abs_xyxy)
356
+
357
+ # Final cost matrix
358
+ cost = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou + 100.0 * (
359
+ ~is_in_boxes_and_center)
360
+ # cost = (cost_class + 3.0 * cost_giou + 100.0 * (~is_in_boxes_and_center)) # [num_query,num_gt]
361
+ cost[~fg_mask] = cost[~fg_mask] + 10000.0
362
+
363
+ # if bz_gtboxs.shape[0]>0:
364
+ indices_batchi, matched_qidx = self.dynamic_k_matching(cost, pair_wise_ious, bz_gtboxs.shape[0])
365
+
366
+ indices.append(indices_batchi)
367
+ matched_ids.append(matched_qidx)
368
+
369
+ return indices, matched_ids
370
+
371
+ def dynamic_k_matching(self, cost, pair_wise_ious, num_gt):
372
+ matching_matrix = torch.zeros_like(cost) # [300,num_gt]
373
+ ious_in_boxes_matrix = pair_wise_ious
374
+ n_candidate_k = self.ota_k
375
+
376
+ # For each gt box, sum its top `ota_k` IoUs with the proposals and use that (clamped to at least 1) as dynamic_k
377
+ topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=0)
378
+ dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
379
+
380
+ for gt_idx in range(num_gt):
381
+ _, pos_idx = torch.topk(cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
382
+ matching_matrix[:, gt_idx][pos_idx] = 1.0
383
+
384
+ del topk_ious, dynamic_ks, pos_idx
385
+
386
+ anchor_matching_gt = matching_matrix.sum(1)
387
+
388
+ if (anchor_matching_gt > 1).sum() > 0:
389
+ _, cost_argmin = torch.min(cost[anchor_matching_gt > 1], dim=1)
390
+ matching_matrix[anchor_matching_gt > 1] *= 0
391
+ matching_matrix[anchor_matching_gt > 1, cost_argmin] = 1
392
+
393
+ while (matching_matrix.sum(0) == 0).any():
394
+ num_zero_gt = (matching_matrix.sum(0) == 0).sum()
395
+ matched_query_id = matching_matrix.sum(1) > 0
396
+ cost[matched_query_id] += 100000.0
397
+ unmatch_id = torch.nonzero(matching_matrix.sum(0) == 0, as_tuple=False).squeeze(1)
398
+ for gt_idx in unmatch_id:
399
+ pos_idx = torch.argmin(cost[:, gt_idx])
400
+ matching_matrix[:, gt_idx][pos_idx] = 1.0
401
+ if (matching_matrix.sum(1) > 1).sum() > 0: # If a query matches more than one gt
402
+ _, cost_argmin = torch.min(cost[anchor_matching_gt > 1],
403
+ dim=1) # find gt for these queries with minimal cost
404
+ matching_matrix[anchor_matching_gt > 1] *= 0 # reset mapping relationship
405
+ matching_matrix[anchor_matching_gt > 1, cost_argmin] = 1 # keep gt with minimal cost
406
+
407
+ assert not (matching_matrix.sum(0) == 0).any()
408
+ selected_query = matching_matrix.sum(1) > 0
409
+ gt_indices = matching_matrix[selected_query].max(1)[1]
410
+ assert selected_query.sum() == len(gt_indices)
411
+
412
+ cost[matching_matrix == 0] = cost[matching_matrix == 0] + float('inf')
413
+ matched_query_id = torch.min(cost, dim=0)[1]
414
+
415
+ return (selected_query, gt_indices), matched_query_id
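# A toy walk-through of the dynamic-k rule above with dummy values: for each gt
# box, k is the clamped sum of its top `ota_k` IoUs with the proposals, and the
# k lowest-cost proposals are assigned to it.
import torch

ious = torch.tensor([[0.60, 0.10],
                     [0.55, 0.05],
                     [0.20, 0.70],
                     [0.05, 0.65]])                        # [num_proposals, num_gt]
cost = 1.0 - ious                                          # pretend cost = 1 - IoU
ota_k = 2

topk_ious, _ = torch.topk(ious, ota_k, dim=0)              # best `ota_k` IoUs per gt
dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)    # here: tensor([1, 1])

matching = torch.zeros_like(cost)
for gt_idx in range(cost.shape[1]):
    _, pos = torch.topk(cost[:, gt_idx], k=int(dynamic_ks[gt_idx]), largest=False)
    matching[pos, gt_idx] = 1.0                            # proposal 0 -> gt 0, proposal 2 -> gt 1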
modeling_diffusiondet.py ADDED
@@ -0,0 +1,424 @@
1
+ import math
2
+ import random
3
+ from collections import namedtuple, OrderedDict
4
+ from dataclasses import dataclass
5
+ from typing import Dict, List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ from torchvision import ops
11
+ from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork
12
+ from transformers import PreTrainedModel
13
+ import wandb
14
+
15
+ from transformers.utils.backbone_utils import load_backbone
16
+ from .configuration_diffusiondet import DiffusionDetConfig
17
+
18
+ from .head import HeadDynamicK
19
+ from .loss import CriterionDynamicK
20
+
21
+ from transformers.utils import ModelOutput
22
+
23
+ ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
24
+
25
+
26
+ def default(val, d):
27
+ if val is not None:
28
+ return val
29
+ return d() if callable(d) else d
30
+
31
+
32
+ def extract(a, t, x_shape):
33
+ """extract the appropriate t index for a batch of indices"""
34
+ batch_size = t.shape[0]
35
+ out = a.gather(-1, t)
36
+ return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
37
+
38
+
39
+ def cosine_beta_schedule(timesteps, s=0.008):
40
+ """
41
+ cosine schedule
42
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
43
+ """
44
+ steps = timesteps + 1
45
+ x = torch.linspace(0, timesteps, steps)
46
+ alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
47
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
48
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
49
+ return torch.clip(betas, 0, 0.999)
50
+
51
+ @dataclass
52
+ class DiffusionDetOutput(ModelOutput):
53
+ """
54
+ Output type of DiffusionDet.
55
+ """
56
+
57
+ loss: Optional[torch.FloatTensor] = None
58
+ loss_dict: Optional[Dict] = None
59
+ logits: torch.FloatTensor = None
60
+ labels: torch.IntTensor = None
61
+ pred_boxes: torch.FloatTensor = None
62
+
63
+ class DiffusionDet(PreTrainedModel):
64
+ """
65
+ Implement DiffusionDet
66
+ """
67
+ config_class = DiffusionDetConfig
68
+ main_input_name = "pixel_values"
69
+
70
+ def __init__(self, config):
71
+ super(DiffusionDet, self).__init__(config)
72
+
73
+ self.in_features = config.roi_head_in_features
74
+ self.num_classes = config.num_labels
75
+ self.num_proposals = config.num_proposals
76
+ self.num_heads = config.num_heads
77
+
78
+ self.backbone = load_backbone(config)
79
+ self.fpn = FeaturePyramidNetwork(
80
+ in_channels_list=self.backbone.channels,
81
+ out_channels=config.fpn_out_channels,
82
+ # extra_blocks=LastLevelMaxPool(),
83
+ )
84
+
85
+ # build diffusion
86
+ betas = cosine_beta_schedule(1000)
87
+ alphas_cumprod = torch.cumprod(1 - betas, dim=0)
88
+
89
+ timesteps, = betas.shape
90
+ sampling_timesteps = config.sample_step
91
+
92
+ self.register_buffer('alphas_cumprod', alphas_cumprod)
93
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
94
+ self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
95
+ self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
96
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
97
+
98
+ self.num_timesteps = int(timesteps)
99
+ self.sampling_timesteps = default(sampling_timesteps, timesteps)
100
+ self.ddim_sampling_eta = 1.
101
+ self.scale = config.snr_scale
102
+ assert self.sampling_timesteps <= timesteps
103
+
104
+ roi_input_shape = {
105
+ 'p2': {'stride': 4},
106
+ 'p3': {'stride': 8},
107
+ 'p4': {'stride': 16},
108
+ 'p5': {'stride': 32},
109
+ 'p6': {'stride': 64}
110
+ }
111
+ self.head = HeadDynamicK(config, roi_input_shape=roi_input_shape)
112
+
113
+ self.deep_supervision = config.deep_supervision
114
+ self.use_focal = config.use_focal
115
+ self.use_fed_loss = config.use_fed_loss
116
+ self.use_nms = config.use_nms
117
+
118
+ weight_dict = {
119
+ "loss_ce": config.class_weight, "loss_bbox": config.l1_weight, "loss_giou": config.giou_weight
120
+ }
121
+ if self.deep_supervision:
122
+ aux_weight_dict = {}
123
+ for i in range(self.num_heads - 1):
124
+ aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
125
+ weight_dict.update(aux_weight_dict)
126
+
127
+ self.criterion = CriterionDynamicK(config, num_classes=self.num_classes, weight_dict=weight_dict)
128
+
129
+ def predict_noise_from_start(self, x_t, t, x0):
130
+ return (
131
+ (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) /
132
+ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
133
+ )
134
+
135
+ def model_predictions(self, backbone_feats, images_whwh, x, t):
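+ # Map the signal-scaled cxcywh proposals back to absolute xyxy boxes for the detection head,
+ # then renormalize the head's boxes to obtain x_0 and derive the implied noise.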
136
+ x_boxes = torch.clamp(x, min=-1 * self.scale, max=self.scale)
137
+ x_boxes = ((x_boxes / self.scale) + 1) / 2
138
+ x_boxes = ops.box_convert(x_boxes, 'cxcywh', 'xyxy')
139
+ x_boxes = x_boxes * images_whwh[:, None, :]
140
+ outputs_class, outputs_coord = self.head(backbone_feats, x_boxes, t)
141
+
142
+ x_start = outputs_coord[-1] # (batch, num_proposals, 4) predict boxes: absolute coordinates (x1, y1, x2, y2)
143
+ x_start = x_start / images_whwh[:, None, :]
144
+ x_start = ops.box_convert(x_start, 'xyxy', 'cxcywh')
145
+ x_start = (x_start * 2 - 1.) * self.scale
146
+ x_start = torch.clamp(x_start, min=-1 * self.scale, max=self.scale)
147
+ pred_noise = self.predict_noise_from_start(x, t, x_start)
148
+
149
+ return ModelPrediction(pred_noise, x_start), outputs_class, outputs_coord
150
+
151
+ @torch.no_grad()
152
+ def ddim_sample(self, batched_inputs, backbone_feats, images_whwh):
153
+ bs = len(batched_inputs)
154
+ image_sizes = batched_inputs.shape
155
+ shape = (bs, self.num_proposals, 4)
156
+
157
+ # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
158
+ times = torch.linspace(-1, self.num_timesteps - 1, steps=self.sampling_timesteps + 1)
159
+ times = list(reversed(times.int().tolist()))
160
+ time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
161
+
162
+ img = torch.randn(shape, device=self.device)
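+ # Sampling starts from num_proposals boxes of pure Gaussian noise per image.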
163
+
164
+ ensemble_score, ensemble_label, ensemble_coord = [], [], []
165
+ outputs_class, outputs_coord = None, None
166
+ for time, time_next in time_pairs:
167
+ time_cond = torch.full((bs,), time, device=self.device, dtype=torch.long)
168
+
169
+ preds, outputs_class, outputs_coord = self.model_predictions(backbone_feats, images_whwh, img, time_cond)
170
+ pred_noise, x_start = preds.pred_noise, preds.pred_x_start
171
+
172
+ score_per_image, box_per_image = outputs_class[-1][0], outputs_coord[-1][0]
173
+ threshold = 0.5
174
+ score_per_image = torch.sigmoid(score_per_image)
175
+ value, _ = torch.max(score_per_image, -1, keepdim=False)
176
+ keep_idx = value > threshold
177
+ num_remain = torch.sum(keep_idx)
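+ # Box renewal: keep only proposals whose best class score clears the threshold; the dropped slots are
+ # refilled with fresh Gaussian noise before the next denoising step.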
178
+
179
+ pred_noise = pred_noise[:, keep_idx, :]
180
+ x_start = x_start[:, keep_idx, :]
181
+ img = img[:, keep_idx, :]
182
+
183
+ if time_next < 0:
184
+ img = x_start
185
+ continue
186
+
187
+ alpha = self.alphas_cumprod[time]
188
+ alpha_next = self.alphas_cumprod[time_next]
189
+
190
+ sigma = self.ddim_sampling_eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
191
+ c = (1 - alpha_next - sigma ** 2).sqrt()
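+ # DDIM update: img = sqrt(alpha_bar_next) * x_start + c * pred_noise + sigma * noise,
+ # with ddim_sampling_eta controlling how much fresh noise is injected.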
192
+
193
+ noise = torch.randn_like(img)
194
+
195
+ img = x_start * alpha_next.sqrt() + \
196
+ c * pred_noise + \
197
+ sigma * noise
198
+
199
+ img = torch.cat((img, torch.randn(1, self.num_proposals - num_remain, 4, device=img.device)), dim=1)
200
+
201
+ if self.sampling_timesteps > 1:
202
+ box_pred_per_image, scores_per_image, labels_per_image = self.inference(outputs_class[-1],
203
+ outputs_coord[-1])
204
+ ensemble_score.append(scores_per_image)
205
+ ensemble_label.append(labels_per_image)
206
+ ensemble_coord.append(box_pred_per_image)
207
+
208
+ if self.sampling_timesteps > 1:
209
+ box_pred_per_image = torch.cat(ensemble_coord, dim=0)
210
+ scores_per_image = torch.cat(ensemble_score, dim=0)
211
+ labels_per_image = torch.cat(ensemble_label, dim=0)
212
+
213
+ if self.use_nms:
214
+ keep = ops.batched_nms(box_pred_per_image, scores_per_image, labels_per_image, 0.5)
215
+ box_pred_per_image = box_pred_per_image[keep]
216
+ scores_per_image = scores_per_image[keep]
217
+ labels_per_image = labels_per_image[keep]
218
+
219
+ return box_pred_per_image, scores_per_image, labels_per_image
220
+ else:
221
+ return self.inference(outputs_class[-1], outputs_coord[-1])
222
+
223
+ def q_sample(self, x_start, t, noise=None):
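+ # Forward diffusion: sample x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I).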
224
+ if noise is None:
225
+ noise = torch.randn_like(x_start)
226
+
227
+ sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
228
+ sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
229
+
230
+ return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
231
+
232
+ def forward(self, pixel_values, labels):
233
+ """
234
+ Args:
+ pixel_values: batched images of shape (batch_size, num_channels, height, width).
+ labels: per-image targets providing `class_labels`, normalized `boxes` in (cx, cy, w, h) format and the image
+ `size`; used to build the diffused proposals and to compute the loss.
235
+ """
236
+ images = pixel_values.to(self.device)
237
+ images_whwh = list()
238
+ for image in images:
239
+ h, w = image.shape[-2:]
240
+ images_whwh.append(torch.tensor([w, h, w, h], device=self.device))
241
+ images_whwh = torch.stack(images_whwh)
242
+
243
+ features = self.backbone(images)
244
+ features = OrderedDict(
245
+ [(key, feature) for key, feature in zip(self.backbone.out_features, features.feature_maps)]
246
+ )
247
+ features = self.fpn(features) # [144, 72, 36, 18]
248
+ features = [features[f] for f in features.keys()]
249
+
250
+ # if self.training:
251
+ labels = list(map(lambda tensor: tensor.to(self.device), labels))
252
+ targets, x_boxes, noises, ts = self.prepare_targets(labels)
253
+
254
+ ts = ts.squeeze(-1)
255
+ x_boxes = x_boxes * images_whwh[:, None, :]
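+ # The diffused boxes are normalized xyxy; scale them to absolute pixel coordinates for the head.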
256
+
257
+ outputs_class, outputs_coord = self.head(features, x_boxes, ts)
258
+ output = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
259
+
260
+ if self.deep_supervision:
261
+ output['aux_outputs'] = [{'pred_logits': a, 'pred_boxes': b}
262
+ for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
263
+
264
+ loss_dict = self.criterion(output, targets)
265
+ weight_dict = self.criterion.weight_dict
266
+ for k in loss_dict.keys():
267
+ if k in weight_dict:
268
+ loss_dict[k] *= weight_dict[k]
269
+ loss_dict['loss'] = sum([loss_dict[k] for k in weight_dict.keys()])
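+ # The reported loss is the weighted sum of every matched loss term, including the per-head
+ # deep-supervision terms when they are enabled.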
270
+
271
+ wandb_logs_values = ["loss_ce", "loss_bbox", "loss_giou"]
272
+
273
+ if wandb.run is not None: # only log when a wandb run has been initialised
+ prefix = 'train' if self.training else 'eval'
+ wandb.log({f'{prefix}/{k}': v.detach().cpu().numpy() for k, v in loss_dict.items() if k in wandb_logs_values})
277
+
278
+ if not self.training:
279
+ pred_boxes, pred_logits, pred_labels = self.ddim_sample(pixel_values, features, images_whwh)
280
+ return DiffusionDetOutput(
281
+ loss=loss_dict['loss'],
282
+ loss_dict=loss_dict,
283
+ logits=pred_logits,
284
+ labels=pred_labels,
285
+ pred_boxes=pred_boxes,
286
+ )
287
+
288
+ return DiffusionDetOutput(
289
+ loss=loss_dict['loss'],
290
+ loss_dict=loss_dict,
291
+ logits=output['pred_logits'],
292
+ pred_boxes=output['pred_boxes']
293
+ )
294
+
295
+ def prepare_diffusion_concat(self, gt_boxes):
296
+ """
297
+ :param gt_boxes: (cx, cy, w, h), normalized
298
+ :return: diffused boxes (normalized xyxy), the sampled noise and the sampled timestep t
299
+ """
300
+ t = torch.randint(0, self.num_timesteps, (1,), device=self.device).long()
301
+ noise = torch.randn(self.num_proposals, 4, device=self.device)
302
+
303
+ num_gt = gt_boxes.shape[0]
304
+ if not num_gt: # generate fake gt boxes if empty gt boxes
305
+ gt_boxes = torch.as_tensor([[0.5, 0.5, 1., 1.]], dtype=torch.float, device=self.device)
306
+ num_gt = 1
307
+
308
+ if num_gt < self.num_proposals:
309
+ box_placeholder = torch.randn(self.num_proposals - num_gt, 4,
310
+ device=self.device) / 6. + 0.5 # 3sigma = 1/2 --> sigma: 1/6
311
+ box_placeholder[:, 2:] = torch.clip(box_placeholder[:, 2:], min=1e-4)
312
+ x_start = torch.cat((gt_boxes, box_placeholder), dim=0)
313
+ elif num_gt > self.num_proposals:
314
+ select_mask = [True] * self.num_proposals + [False] * (num_gt - self.num_proposals)
315
+ random.shuffle(select_mask)
316
+ x_start = gt_boxes[select_mask]
317
+ else:
318
+ x_start = gt_boxes
319
+
320
+ x_start = (x_start * 2. - 1.) * self.scale
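+ # Signal scaling: map boxes from [0, 1] to [-scale, scale] so their magnitude is comparable to the Gaussian noise.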
321
+
322
+ # noise sample
323
+ x = self.q_sample(x_start=x_start, t=t, noise=noise)
324
+
325
+ x = torch.clamp(x, min=-1 * self.scale, max=self.scale)
326
+ x = ((x / self.scale) + 1) / 2.
327
+
328
+ diff_boxes = ops.box_convert(x, 'cxcywh', 'xyxy')
329
+
330
+ return diff_boxes, noise, t
331
+
332
+ def prepare_targets(self, targets):
333
+ new_targets = []
334
+ diffused_boxes = []
335
+ noises = []
336
+ ts = []
337
+ for target in targets:
338
+ h, w = target.size
339
+ image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)
340
+ gt_classes = target.class_labels.to(self.device)
341
+ gt_boxes = target.boxes.to(self.device)
342
+ d_boxes, d_noise, d_t = self.prepare_diffusion_concat(gt_boxes)
343
+ image_size_xyxy_tgt = image_size_xyxy.unsqueeze(0).repeat(len(gt_boxes), 1)
344
+ gt_boxes = gt_boxes * image_size_xyxy
345
+ gt_boxes = ops.box_convert(gt_boxes, 'cxcywh', 'xyxy')
346
+
347
+ diffused_boxes.append(d_boxes)
348
+ noises.append(d_noise)
349
+ ts.append(d_t)
350
+ new_targets.append({
351
+ "labels": gt_classes,
352
+ "boxes": target.boxes.to(self.device),
353
+ "boxes_xyxy": gt_boxes,
354
+ "image_size_xyxy": image_size_xyxy.to(self.device),
355
+ "image_size_xyxy_tgt": image_size_xyxy_tgt.to(self.device),
356
+ "area": ops.box_area(target.boxes.to(self.device)),
357
+ })
358
+
359
+ return new_targets, torch.stack(diffused_boxes), torch.stack(noises), torch.stack(ts)
360
+
361
+ def inference(self, box_cls, box_pred):
362
+ """
363
+ Arguments:
364
+ box_cls (Tensor): tensor of shape (batch_size, num_proposals, K).
365
+ The tensor predicts the classification probability for each proposal.
366
+ box_pred (Tensor): tensors of shape (batch_size, num_proposals, 4).
367
+ The tensor predicts 4-vector (x1, y1, x2, y2) box
368
+ regression values for every proposal
370
+
371
+ Returns:
372
+ boxes, scores and labels: one entry per image (returned as single tensors when sampling_timesteps > 1).
373
+ """
374
+ results = []
375
+ boxes_output = []
376
+ logits_output = []
377
+ labels_output = []
378
+
379
+ if self.use_focal or self.use_fed_loss:
380
+ scores = torch.sigmoid(box_cls)
381
+ labels = torch.arange(self.num_classes, device=self.device). \
382
+ unsqueeze(0).repeat(self.num_proposals, 1).flatten(0, 1)
383
+
384
+ for i, (scores_per_image, box_pred_per_image) in enumerate(zip(
385
+ scores, box_pred
386
+ )):
387
+ scores_per_image, topk_indices = scores_per_image.flatten(0, 1).topk(self.num_proposals, sorted=False)
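+ # With focal/federated loss every (proposal, class) pair is scored; the top num_proposals pairs become detections.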
388
+ labels_per_image = labels[topk_indices]
389
+ box_pred_per_image = box_pred_per_image.view(-1, 1, 4).repeat(1, self.num_classes, 1).view(-1, 4)
390
+ box_pred_per_image = box_pred_per_image[topk_indices]
391
+
392
+ if self.sampling_timesteps > 1:
393
+ return box_pred_per_image, scores_per_image, labels_per_image
394
+
395
+ if self.use_nms:
396
+ keep = ops.batched_nms(box_pred_per_image, scores_per_image, labels_per_image, 0.5)
397
+ box_pred_per_image = box_pred_per_image[keep]
398
+ scores_per_image = scores_per_image[keep]
399
+ labels_per_image = labels_per_image[keep]
400
+
401
+ boxes_output.append(box_pred_per_image)
402
+ logits_output.append(scores_per_image)
403
+ labels_output.append(labels_per_image)
404
+ else:
405
+ # For each box we assign the best class or the second best if the best one is `no_object`.
406
+ scores, labels = F.softmax(box_cls, dim=-1)[:, :, :-1].max(-1)
407
+
408
+ for i, (scores_per_image, labels_per_image, box_pred_per_image) in enumerate(zip(
409
+ scores, labels, box_pred
410
+ )):
411
+ if self.sampling_timesteps > 1:
412
+ return box_pred_per_image, scores_per_image, labels_per_image
413
+
414
+ if self.use_nms:
415
+ keep = ops.batched_nms(box_pred_per_image, scores_per_image, labels_per_image, 0.5)
416
+ box_pred_per_image = box_pred_per_image[keep]
417
+ scores_per_image = scores_per_image[keep]
418
+ labels_per_image = labels_per_image[keep]
419
+
420
+ boxes_output.append(box_pred_per_image)
421
+ logits_output.append(scores_per_image)
422
+ labels_output.append(labels_per_image)
423
+
424
+ return boxes_output, logits_output, labels_output
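+
+ # A minimal usage sketch, assuming the checkpoint is published together with these files and loaded with
+ # remote code enabled; the repository id below is a placeholder.
+ #
+ # from transformers import AutoModelForObjectDetection
+ # model = AutoModelForObjectDetection.from_pretrained("<user>/<diffusiondet-repo>", trust_remote_code=True)
+ # outputs = model(pixel_values, labels)  # labels: per-image targets with class_labels, normalized boxes and size
+ # boxes, scores, class_ids = outputs.pred_boxes, outputs.logits, outputs.labels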