import torch import torch.nn.functional as F from fvcore.nn import sigmoid_focal_loss_jit from torch import nn import torch.distributed as dist from torch.distributed import get_world_size from torchvision import ops def is_dist_avail_and_initialized(): if not dist.is_available(): return False if not dist.is_initialized(): return False return True def get_fed_loss_classes(gt_classes, num_fed_loss_classes, num_classes, weight): """ Args: gt_classes: a long tensor of shape R that contains the gt class label of each proposal. num_fed_loss_classes: minimum number of classes to keep when calculating federated loss. Will sample negative classes if number of unique gt_classes is smaller than this value. num_classes: number of foreground classes weight: probabilities used to sample negative classes Returns: Tensor: classes to keep when calculating the federated loss, including both unique gt classes and sampled negative classes. """ unique_gt_classes = torch.unique(gt_classes) prob = unique_gt_classes.new_ones(num_classes + 1).float() prob[-1] = 0 if len(unique_gt_classes) < num_fed_loss_classes: prob[:num_classes] = weight.float().clone() prob[unique_gt_classes] = 0 sampled_negative_classes = torch.multinomial( prob, num_fed_loss_classes - len(unique_gt_classes), replacement=False ) fed_loss_classes = torch.cat([unique_gt_classes, sampled_negative_classes]) else: fed_loss_classes = unique_gt_classes return fed_loss_classes class CriterionDynamicK(nn.Module): """ This class computes the loss for DiffusionDet. The process happens in two steps: 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) """ def __init__(self, config, num_classes, weight_dict): """ Create the criterion. Parameters: num_classes: number of object categories, omitting the special no-object category weight_dict: dict containing as key the names of the losses and as values their relative weight. """ super().__init__() self.config = config self.num_classes = num_classes self.matcher = HungarianMatcherDynamicK(config) self.weight_dict = weight_dict self.eos_coef = config.no_object_weight self.use_focal = config.use_focal self.use_fed_loss = config.use_fed_loss if self.use_focal: self.focal_loss_alpha = config.alpha self.focal_loss_gamma = config.gamma # copy-paste from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/roi_heads/fast_rcnn.py#L356 def loss_labels(self, outputs, targets, indices): """Classification loss (NLL) targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] """ assert 'pred_logits' in outputs src_logits = outputs['pred_logits'] batch_size = len(targets) # idx = self._get_src_permutation_idx(indices) # target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device) src_logits_list = [] target_classes_o_list = [] # target_classes[idx] = target_classes_o for batch_idx in range(batch_size): valid_query = indices[batch_idx][0] gt_multi_idx = indices[batch_idx][1] if len(gt_multi_idx) == 0: continue bz_src_logits = src_logits[batch_idx] target_classes_o = targets[batch_idx]["labels"] target_classes[batch_idx, valid_query] = target_classes_o[gt_multi_idx] src_logits_list.append(bz_src_logits[valid_query]) target_classes_o_list.append(target_classes_o[gt_multi_idx]) if self.use_focal or self.use_fed_loss: num_boxes = torch.cat(target_classes_o_list).shape[0] if len(target_classes_o_list) != 0 else 1 target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], self.num_classes + 1], dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) gt_classes = torch.argmax(target_classes_onehot, dim=-1) target_classes_onehot = target_classes_onehot[:, :, :-1] src_logits = src_logits.flatten(0, 1) target_classes_onehot = target_classes_onehot.flatten(0, 1) if self.use_focal: cls_loss = sigmoid_focal_loss_jit(src_logits, target_classes_onehot, alpha=self.focal_loss_alpha, gamma=self.focal_loss_gamma, reduction="none") else: cls_loss = F.binary_cross_entropy_with_logits(src_logits, target_classes_onehot, reduction="none") if self.use_fed_loss: K = self.num_classes N = src_logits.shape[0] fed_loss_classes = get_fed_loss_classes( gt_classes, num_fed_loss_classes=self.fed_loss_num_classes, num_classes=K, weight=self.fed_loss_cls_weights, ) fed_loss_classes_mask = fed_loss_classes.new_zeros(K + 1) fed_loss_classes_mask[fed_loss_classes] = 1 fed_loss_classes_mask = fed_loss_classes_mask[:K] weight = fed_loss_classes_mask.view(1, K).expand(N, K).float() loss_ce = torch.sum(cls_loss * weight) / num_boxes else: loss_ce = torch.sum(cls_loss) / num_boxes losses = {'loss_ce': loss_ce} else: raise NotImplementedError return losses def loss_boxes(self, outputs, targets, indices): """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. """ assert 'pred_boxes' in outputs # idx = self._get_src_permutation_idx(indices) src_boxes = outputs['pred_boxes'] batch_size = len(targets) pred_box_list = [] pred_norm_box_list = [] tgt_box_list = [] tgt_box_xyxy_list = [] for batch_idx in range(batch_size): valid_query = indices[batch_idx][0] gt_multi_idx = indices[batch_idx][1] if len(gt_multi_idx) == 0: continue bz_image_whwh = targets[batch_idx]['image_size_xyxy'] bz_src_boxes = src_boxes[batch_idx] bz_target_boxes = targets[batch_idx]["boxes"] # normalized (cx, cy, w, h) bz_target_boxes_xyxy = targets[batch_idx]["boxes_xyxy"] # absolute (x1, y1, x2, y2) pred_box_list.append(bz_src_boxes[valid_query]) pred_norm_box_list.append(bz_src_boxes[valid_query] / bz_image_whwh) # normalize (x1, y1, x2, y2) tgt_box_list.append(bz_target_boxes[gt_multi_idx]) tgt_box_xyxy_list.append(bz_target_boxes_xyxy[gt_multi_idx]) if len(pred_box_list) != 0: src_boxes = torch.cat(pred_box_list) src_boxes_norm = torch.cat(pred_norm_box_list) # normalized (x1, y1, x2, y2) target_boxes = torch.cat(tgt_box_list) target_boxes_abs_xyxy = torch.cat(tgt_box_xyxy_list) num_boxes = src_boxes.shape[0] losses = {} # require normalized (x1, y1, x2, y2) loss_bbox = F.l1_loss(src_boxes_norm, ops.box_convert(target_boxes, 'cxcywh', 'xyxy'), reduction='none') losses['loss_bbox'] = loss_bbox.sum() / num_boxes # loss_giou = giou_loss(box_ops.box_cxcywh_to_xyxy(src_boxes), box_ops.box_cxcywh_to_xyxy(target_boxes)) loss_giou = 1 - torch.diag(ops.generalized_box_iou(src_boxes, target_boxes_abs_xyxy)) losses['loss_giou'] = loss_giou.sum() / num_boxes else: losses = {'loss_bbox': outputs['pred_boxes'].sum() * 0, 'loss_giou': outputs['pred_boxes'].sum() * 0} return losses def get_loss(self, loss, outputs, targets, indices): loss_map = { 'labels': self.loss_labels, 'boxes': self.loss_boxes, } assert loss in loss_map, f'do you really want to compute {loss} loss?' return loss_map[loss](outputs, targets, indices) def forward(self, outputs, targets): """ This performs the loss computation. Parameters: outputs: dict of tensors, see the output specification of the model for the format targets: list of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the losses applied, see each loss' doc """ outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'} # Retrieve the matching between the outputs of the last layer and the targets indices, _ = self.matcher(outputs_without_aux, targets) # Compute all the requested losses losses = {} for loss in ["labels", "boxes"]: losses.update(self.get_loss(loss, outputs, targets, indices)) # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. if 'aux_outputs' in outputs: for i, aux_outputs in enumerate(outputs['aux_outputs']): indices, _ = self.matcher(aux_outputs, targets) for loss in ["labels", "boxes"]: if loss == 'masks': # Intermediate masks losses are too costly to compute, we ignore them. continue l_dict = self.get_loss(loss, aux_outputs, targets, indices) l_dict = {k + f'_{i}': v for k, v in l_dict.items()} losses.update(l_dict) return losses def get_in_boxes_info(boxes, target_gts): xy_target_gts = ops.box_convert(target_gts, 'cxcywh', 'xyxy') # (x1, y1, x2, y2) anchor_center_x = boxes[:, 0].unsqueeze(1) anchor_center_y = boxes[:, 1].unsqueeze(1) # whether the center of each anchor is inside a gt box b_l = anchor_center_x > xy_target_gts[:, 0].unsqueeze(0) b_r = anchor_center_x < xy_target_gts[:, 2].unsqueeze(0) b_t = anchor_center_y > xy_target_gts[:, 1].unsqueeze(0) b_b = anchor_center_y < xy_target_gts[:, 3].unsqueeze(0) # (b_l.long()+b_r.long()+b_t.long()+b_b.long())==4 [300,num_gt] , is_in_boxes = ((b_l.long() + b_r.long() + b_t.long() + b_b.long()) == 4) is_in_boxes_all = is_in_boxes.sum(1) > 0 # [num_query] # in fixed center center_radius = 2.5 # Modified to self-adapted sampling --- the center size depends on the size of the gt boxes # https://github.com/dulucas/UVO_Challenge/blob/main/Track1/detection/mmdet/core/bbox/assigners/rpn_sim_ota_assigner.py#L212 b_l = anchor_center_x > ( target_gts[:, 0] - (center_radius * (xy_target_gts[:, 2] - xy_target_gts[:, 0]))).unsqueeze(0) b_r = anchor_center_x < ( target_gts[:, 0] + (center_radius * (xy_target_gts[:, 2] - xy_target_gts[:, 0]))).unsqueeze(0) b_t = anchor_center_y > ( target_gts[:, 1] - (center_radius * (xy_target_gts[:, 3] - xy_target_gts[:, 1]))).unsqueeze(0) b_b = anchor_center_y < ( target_gts[:, 1] + (center_radius * (xy_target_gts[:, 3] - xy_target_gts[:, 1]))).unsqueeze(0) is_in_centers = ((b_l.long() + b_r.long() + b_t.long() + b_b.long()) == 4) is_in_centers_all = is_in_centers.sum(1) > 0 is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all is_in_boxes_and_center = (is_in_boxes & is_in_centers) return is_in_boxes_anchor, is_in_boxes_and_center class HungarianMatcherDynamicK(nn.Module): """This class computes an assignment between the targets and the predictions of the network For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more predictions than targets. In this case, we do a 1-to-k (dynamic) matching of the best predictions, while the others are un-matched (and thus treated as non-objects). """ def __init__(self, config): super().__init__() self.use_focal = config.use_focal self.use_fed_loss = config.use_fed_loss self.cost_class = config.class_weight self.cost_giou = config.giou_weight self.cost_bbox = config.l1_weight self.ota_k = config.ota_k if self.use_focal: self.focal_loss_alpha = config.alpha self.focal_loss_gamma = config.gamma assert self.cost_class != 0 or self.cost_bbox != 0 or self.cost_giou != 0, "all costs cant be 0" def forward(self, outputs, targets): """ simOTA for detr""" with torch.no_grad(): bs, num_queries = outputs["pred_logits"].shape[:2] # We flatten to compute the cost matrices in a batch if self.use_focal or self.use_fed_loss: out_prob = outputs["pred_logits"].sigmoid() # [batch_size, num_queries, num_classes] out_bbox = outputs["pred_boxes"] # [batch_size, num_queries, 4] else: out_prob = outputs["pred_logits"].softmax(-1) # [batch_size, num_queries, num_classes] out_bbox = outputs["pred_boxes"] # [batch_size, num_queries, 4] indices = [] matched_ids = [] assert bs == len(targets) for batch_idx in range(bs): bz_boxes = out_bbox[batch_idx] # [num_proposals, 4] bz_out_prob = out_prob[batch_idx] bz_tgt_ids = targets[batch_idx]["labels"] num_insts = len(bz_tgt_ids) if num_insts == 0: # empty object in key frame non_valid = torch.zeros(bz_out_prob.shape[0]).to(bz_out_prob) > 0 indices_batchi = (non_valid, torch.arange(0, 0).to(bz_out_prob)) matched_qidx = torch.arange(0, 0).to(bz_out_prob) indices.append(indices_batchi) matched_ids.append(matched_qidx) continue bz_gtboxs = targets[batch_idx]['boxes'] # [num_gt, 4] normalized (cx, xy, w, h) bz_gtboxs_abs_xyxy = targets[batch_idx]['boxes_xyxy'] fg_mask, is_in_boxes_and_center = get_in_boxes_info( ops.box_convert(bz_boxes, 'xyxy', 'cxcywh'), # absolute (cx, cy, w, h) ops.box_convert(bz_gtboxs_abs_xyxy, 'xyxy', 'cxcywh') # absolute (cx, cy, w, h) ) pair_wise_ious = ops.box_iou(bz_boxes, bz_gtboxs_abs_xyxy) # Compute the classification cost. if self.use_focal: alpha = self.focal_loss_alpha gamma = self.focal_loss_gamma neg_cost_class = (1 - alpha) * (bz_out_prob ** gamma) * (-(1 - bz_out_prob + 1e-8).log()) pos_cost_class = alpha * ((1 - bz_out_prob) ** gamma) * (-(bz_out_prob + 1e-8).log()) cost_class = pos_cost_class[:, bz_tgt_ids] - neg_cost_class[:, bz_tgt_ids] elif self.use_fed_loss: # focal loss degenerates to naive one neg_cost_class = (-(1 - bz_out_prob + 1e-8).log()) pos_cost_class = (-(bz_out_prob + 1e-8).log()) cost_class = pos_cost_class[:, bz_tgt_ids] - neg_cost_class[:, bz_tgt_ids] else: cost_class = -bz_out_prob[:, bz_tgt_ids] # Compute the L1 cost between boxes # image_size_out = torch.cat([v["image_size_xyxy"].unsqueeze(0) for v in targets]) # image_size_out = image_size_out.unsqueeze(1).repeat(1, num_queries, 1).flatten(0, 1) # image_size_tgt = torch.cat([v["image_size_xyxy_tgt"] for v in targets]) bz_image_size_out = targets[batch_idx]['image_size_xyxy'] bz_image_size_tgt = targets[batch_idx]['image_size_xyxy_tgt'] bz_out_bbox_ = bz_boxes / bz_image_size_out # normalize (x1, y1, x2, y2) bz_tgt_bbox_ = bz_gtboxs_abs_xyxy / bz_image_size_tgt # normalize (x1, y1, x2, y2) cost_bbox = torch.cdist(bz_out_bbox_, bz_tgt_bbox_, p=1) cost_giou = -ops.generalized_box_iou(bz_boxes, bz_gtboxs_abs_xyxy) # Final cost matrix cost = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou + 100.0 * ( ~is_in_boxes_and_center) # cost = (cost_class + 3.0 * cost_giou + 100.0 * (~is_in_boxes_and_center)) # [num_query,num_gt] cost[~fg_mask] = cost[~fg_mask] + 10000.0 # if bz_gtboxs.shape[0]>0: indices_batchi, matched_qidx = self.dynamic_k_matching(cost, pair_wise_ious, bz_gtboxs.shape[0]) indices.append(indices_batchi) matched_ids.append(matched_qidx) return indices, matched_ids def dynamic_k_matching(self, cost, pair_wise_ious, num_gt): matching_matrix = torch.zeros_like(cost) # [300,num_gt] ious_in_boxes_matrix = pair_wise_ious n_candidate_k = self.ota_k # Take the sum of the predicted value and the top 10 iou of gt with the largest iou as dynamic_k topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=0) dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1) for gt_idx in range(num_gt): _, pos_idx = torch.topk(cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False) matching_matrix[:, gt_idx][pos_idx] = 1.0 del topk_ious, dynamic_ks, pos_idx anchor_matching_gt = matching_matrix.sum(1) if (anchor_matching_gt > 1).sum() > 0: _, cost_argmin = torch.min(cost[anchor_matching_gt > 1], dim=1) matching_matrix[anchor_matching_gt > 1] *= 0 matching_matrix[anchor_matching_gt > 1, cost_argmin,] = 1 while (matching_matrix.sum(0) == 0).any(): num_zero_gt = (matching_matrix.sum(0) == 0).sum() matched_query_id = matching_matrix.sum(1) > 0 cost[matched_query_id] += 100000.0 unmatch_id = torch.nonzero(matching_matrix.sum(0) == 0, as_tuple=False).squeeze(1) for gt_idx in unmatch_id: pos_idx = torch.argmin(cost[:, gt_idx]) matching_matrix[:, gt_idx][pos_idx] = 1.0 if (matching_matrix.sum(1) > 1).sum() > 0: # If a query matches more than one gt _, cost_argmin = torch.min(cost[anchor_matching_gt > 1], dim=1) # find gt for these queries with minimal cost matching_matrix[anchor_matching_gt > 1] *= 0 # reset mapping relationship matching_matrix[anchor_matching_gt > 1, cost_argmin,] = 1 # keep gt with minimal cost assert not (matching_matrix.sum(0) == 0).any() selected_query = matching_matrix.sum(1) > 0 gt_indices = matching_matrix[selected_query].max(1)[1] assert selected_query.sum() == len(gt_indices) cost[matching_matrix == 0] = cost[matching_matrix == 0] + float('inf') matched_query_id = torch.min(cost, dim=0)[1] return (selected_query, gt_indices), matched_query_id