Index error when using batch size > 1 #11
opened by ce-mwest
Fine-tuning the model with the prebuilt loss functions raises an IndexError when the batch size is > 1 and the number of bounding boxes varies per image, due to the following code snippet:
```python
def _get_target_classes_one_hot(self, outputs, targets, indices):
    """
    Create one_hot based on the matching indices
    """
    logits = outputs["logits"]
    # Add offsets to class_labels to select the correct label map
    class_labels = torch.cat(
        [
            target["class_labels"][J] + len(outputs["label_maps"][i]) if i > 0 else target["class_labels"][J]
            for i, (target, (_, J)) in enumerate(zip(targets, indices))
        ]
    )
    label_maps = torch.cat(outputs["label_maps"], dim=0)
    idx = self._get_source_permutation_idx(indices)
    target_classes_onehot = torch.zeros_like(logits, device=logits.device, dtype=torch.long)
    target_classes_onehot[idx] = label_maps[class_labels].to(torch.long)
```
The offset `len(outputs["label_maps"][i])` is computed from the current element's label map only, instead of the cumulative length of all *preceding* label maps.
For example, assuming per-image label maps of length 4 and 5:

```
targets[0]["class_labels"] = [0, 1, 2, 3]
targets[1]["class_labels"] = [0, 1, 2, 3, 4]
=> class_labels = [0, 1, 2, 3, 5, 6, 7, 8, 9]
```
Since the concatenated `label_maps` only has 9 rows (valid indices 0-8), this leads to an IndexError at `label_maps[class_labels].to(torch.long)`. The offset for the second image should be the cumulative length of the preceding label maps, i.e. 4, which would give `[4, 5, 6, 7, 8]`.
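Here is a minimal standalone reproduction of the offset logic (toy tensors matching the example above; the matcher indices `J` are dropped for brevity):

```python
import torch

# Per-image label maps with 4 and 5 classes (the feature dim 8 is arbitrary)
label_maps = [torch.randn(4, 8), torch.randn(5, 8)]
targets = [
    {"class_labels": torch.tensor([0, 1, 2, 3])},
    {"class_labels": torch.tensor([0, 1, 2, 3, 4])},
]

# Same offset logic as in the snippet above: each image's labels are
# shifted by the length of its *own* label map, not the preceding ones
class_labels = torch.cat(
    [
        target["class_labels"] + len(label_maps[i]) if i > 0 else target["class_labels"]
        for i, target in enumerate(targets)
    ]
)
print(class_labels)  # tensor([0, 1, 2, 3, 5, 6, 7, 8, 9])

label_maps_cat = torch.cat(label_maps, dim=0)  # 9 rows, valid indices 0..8
label_maps_cat[class_labels]  # IndexError: index 9 is out of bounds for dimension 0 with size 9
```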
Is this expected, or am I doing something wrong?
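For reference, a possible fix (untested sketch, everything else unchanged) would be to offset each image's labels by the total length of all preceding label maps:

```python
def _get_target_classes_one_hot(self, outputs, targets, indices):
    """
    Create one_hot based on the matching indices
    """
    logits = outputs["logits"]
    # Offset each image's class_labels by the total number of classes in
    # all *preceding* label maps, so indices stay valid after torch.cat
    offsets = [0]
    for label_map in outputs["label_maps"][:-1]:
        offsets.append(offsets[-1] + len(label_map))
    class_labels = torch.cat(
        [
            target["class_labels"][J] + offsets[i]
            for i, (target, (_, J)) in enumerate(zip(targets, indices))
        ]
    )
    label_maps = torch.cat(outputs["label_maps"], dim=0)
    idx = self._get_source_permutation_idx(indices)
    target_classes_onehot = torch.zeros_like(logits, device=logits.device, dtype=torch.long)
    target_classes_onehot[idx] = label_maps[class_labels].to(torch.long)
```

With the example above this yields `class_labels = [0, 1, 2, 3, 4, 5, 6, 7, 8]`, which indexes the concatenated label maps correctly.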