Index error when using batch size > 1

#11, opened by ce-mwest

Fine-tuning the model with the prebuilt loss functions fails with an index error when the batch size is > 1 and the number of bounding boxes varies per image. The cause is the following code snippet:

    def _get_target_classes_one_hot(self, outputs, targets, indices):
        """
        Create one_hot based on the matching indices
        """
        logits = outputs["logits"]
        # Add offsets to class_labels to select the correct label map
        class_labels = torch.cat(
            [
                target["class_labels"][J] + len(outputs["label_maps"][i]) if i > 0 else target["class_labels"][J]
                for i, (target, (_, J)) in enumerate(zip(targets, indices))
            ]
        )
        label_maps = torch.cat(outputs["label_maps"], dim=0)

        idx = self._get_source_permutation_idx(indices)
        target_classes_onehot = torch.zeros_like(logits, device=logits.device, dtype=torch.long)
        target_classes_onehot[idx] = label_maps[class_labels].to(torch.long)

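The offset len(outputs["label_maps"][i]) is computed from the current element's label map only. Because torch.cat(outputs["label_maps"], dim=0) stacks the label maps one after another, the offset for element i should be the total length of all preceding label maps (elements 0 to i-1), not the length of element i itself.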
For example:

    targets[0]["class_labels"] = [0, 1, 2, 3]
    targets[1]["class_labels"] = [0, 1, 2, 3, 4]
    => class_labels = [0, 1, 2, 3, 5, 6, 7, 8, 9]

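Assuming label_maps[0] and label_maps[1] have 4 and 5 rows, the concatenated label_maps has only 9 rows (valid indices 0 to 8), so index 9 is out of bounds and label_maps[class_labels].to(torch.long) raises an index error. With a cumulative offset, the second element would be shifted by 4 instead, giving [4, 5, 6, 7, 8].
Is this expected, or am I doing something wrong?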
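A minimal pure-Python reproduction of the offset arithmetic, using the numbers from the example. Assumptions: label_maps[0] and label_maps[1] have 4 and 5 rows, the matcher indices J are treated as the identity permutation, and plain lists stand in for the tensors so the sketch runs without torch. The cumulative variant is one possible fix, not confirmed library behaviour.

```python
from itertools import accumulate

# Example from the report (assumed label-map sizes: 4 and 5 rows).
class_labels_per_target = [[0, 1, 2, 3], [0, 1, 2, 3, 4]]
label_map_sizes = [4, 5]
total_rows = sum(label_map_sizes)  # rows after torch.cat(label_maps, dim=0)

# Current logic: offset is len(label_maps[i]) for i > 0, else 0.
buggy = [
    c + (label_map_sizes[i] if i > 0 else 0)
    for i, labels in enumerate(class_labels_per_target)
    for c in labels
]

# Cumulative logic: offset is the total number of rows of all preceding label maps.
offsets = list(accumulate(label_map_sizes, initial=0))[:-1]  # [0, 4]
fixed = [
    c + offsets[i]
    for i, labels in enumerate(class_labels_per_target)
    for c in labels
]

print(buggy)  # [0, 1, 2, 3, 5, 6, 7, 8, 9]
print(fixed)  # [0, 1, 2, 3, 4, 5, 6, 7, 8]
# Every index must be < total_rows to be valid for the concatenated label_maps.
print(max(buggy) < total_rows, max(fixed) < total_rows)  # False True
```

In the torch snippet, this would roughly correspond to replacing len(outputs["label_maps"][i]) with sum(len(m) for m in outputs["label_maps"][:i]); that is a hypothetical patch and has not been tested against the library.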
