from dataclasses import dataclass
from typing import Dict, List, Optional

import torch
from torchvision.transforms import Resize
from transformers import PreTrainedModel
from transformers.utils import ModelOutput, torch_int

from rfdetr import RFDETRBase, RFDETRLarge
from rfdetr.util.misc import NestedTensor

from .configuration_rf_detr import RFDetrConfig


@dataclass
class RFDetrObjectDetectionOutput(ModelOutput):
    loss: Optional[torch.Tensor] = None
    loss_dict: Optional[Dict[str, torch.Tensor]] = None
    logits: Optional[torch.FloatTensor] = None
    pred_boxes: Optional[torch.FloatTensor] = None
    aux_outputs: Optional[List[Dict[str, torch.Tensor]]] = None
    enc_outputs: Optional[Dict[str, torch.Tensor]] = None


class RFDetrModelForObjectDetection(PreTrainedModel):
    config_class = RFDetrConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        models = {
            'RFDETRBase': RFDETRBase,
            'RFDETRLarge': RFDETRLarge,
        }
        rf_detr_model = models[config.model_name](
            out_feature_indexes=config.out_feature_indexes,
            dec_layers=config.dec_layers,
            two_stage=config.two_stage,
            bbox_reparam=config.bbox_reparam,
            lite_refpoint_refine=config.lite_refpoint_refine,
            layer_norm=config.layer_norm,
            amp=config.amp,
            num_classes=config.num_classes,
            resolution=config.resolution,
            group_detr=config.group_detr,
            gradient_checkpointing=config.gradient_checkpointing,
            num_queries=config.num_queries,
            encoder=config.encoder,
            hidden_dim=config.hidden_dim,
            sa_nheads=config.sa_nheads,
            ca_nheads=config.ca_nheads,
            dec_n_points=config.dec_n_points,
            projector_scale=config.projector_scale,
            pretrain_weights=config.pretrain_weights,
        )
        self.model = rf_detr_model.model.model
        self.criterion = rf_detr_model.model.criterion

    def compute_loss(self, outputs, labels=None):
        """
        Parameters
        ----------
        outputs : dict
            Raw outputs from the wrapped RF-DETR model.
        labels : list[Dict[str, torch.Tensor]], optional
            List of bounding boxes and class labels for each image in the batch.
            If None, no loss is computed and ``(None, None)`` is returned.
        """
        loss = None
        loss_dict = None

        if labels is not None:
            losses = self.criterion(outputs, targets=labels)
            loss_dict = {
                'loss_fl': losses["loss_ce"],
                'class_error': losses["class_error"],
                'cardinality_error': losses["cardinality_error"],
                'loss_bbox': losses["loss_bbox"],
                'loss_giou': losses["loss_giou"],
            }
            loss = sum(loss_dict[k] for k in ['loss_fl', 'loss_bbox', 'loss_giou'])

        return loss, loss_dict

    def validate_labels(self, labels):
        """
        Check that every box has positive width and height, and rename the
        ``class_labels`` key to ``labels`` as expected by the criterion.
        """
        for label_idx, label in enumerate(labels):
            boxes = label["boxes"]
            degenerate_boxes = boxes[:, 2:] <= 0
            if degenerate_boxes.any():
                bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                degen_bb: List[float] = boxes[bb_idx].tolist()
                torch._assert(
                    False,
                    "All bounding boxes should have positive height and width."
                    f" Found invalid box {degen_bb} for target at index {label_idx}.",
                )

            if 'class_labels' in label:
                label['labels'] = label.pop('class_labels')

    def resize_labels(self, labels, h, w):
        """
        Rescale box coordinates from the original image size (h, w) to the model's
        resolution, then normalize them to [0, 1].
        """
        hr = self.config.resolution / float(h)
        wr = self.config.resolution / float(w)

        for label in labels:
            boxes = label["boxes"]

            boxes[:, [0, 2]] *= wr
            boxes[:, [1, 3]] *= hr

            boxes[:] /= self.config.resolution
            label["boxes"] = boxes

    def _onnx_interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Interpolate the pre-trained position encodings so the model can be used on
        higher-resolution images. This implementation supports torch.jit tracing while
        maintaining backwards compatibility with the original implementation.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
        - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
        """
        position_embeddings = self.model.backbone[0].encoder.encoder.embeddings.position_embeddings
        config = self.model.backbone[0].encoder.encoder.embeddings.config

        num_patches = embeddings.shape[1] - 1
        num_positions = position_embeddings.shape[1] - 1

        # Return the stored embeddings unchanged when no interpolation is needed.
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return position_embeddings

        # Separate the class-token position embedding from the patch position embeddings.
        class_pos_embed = position_embeddings[:, 0]
        patch_pos_embed = position_embeddings[:, 1:]
        dim = embeddings.shape[-1]

        # Convert pixel dimensions to patch-grid dimensions.
        height = height // config.patch_size
        width = width // config.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        target_dtype = patch_pos_embed.dtype

        # Upcast to float32 for the bicubic interpolation, then cast back to the original dtype.
        patch_pos_embed = torch.nn.functional.interpolate(
            patch_pos_embed.to(dtype=torch.float32),
            size=(torch_int(height), torch_int(width)),
            mode="bicubic",
            align_corners=False,
            antialias=False,
        ).to(dtype=target_dtype)

        if not torch.jit.is_tracing():
            if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
                raise ValueError("Width or height does not match with the interpolated position embeddings")

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.Tensor, pixel_mask: Optional[torch.Tensor] = None, labels=None, **kwargs) -> RFDetrObjectDetectionOutput:
        """
        Parameters
        ----------
        pixel_values : torch.Tensor
            Input tensor representing image pixel values.
        pixel_mask : Optional[torch.Tensor]
            Padding mask for the batch; it is resized alongside ``pixel_values``. If omitted,
            an all-zeros (no padding) mask is created.
        labels : Optional[List[Dict[str, torch.Tensor]]]
            List of annotations associated with the image or batch of images. For object
            detection, each annotation should be a dictionary with the following keys:
            - boxes (FloatTensor[N, 4]): the ground-truth boxes in [center_x, center_y, width, height] format
            - class_labels (Int64Tensor[N]): the class label for each ground-truth box

        Returns
        -------
        RFDetrObjectDetectionOutput
            Object containing
            - loss: sum of focal loss, bounding box loss, and generalized IoU loss
            - loss_dict: dictionary of individual losses
            - logits
            - pred_boxes
            - aux_outputs
            - enc_outputs
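
        Examples
        --------
        A minimal inference sketch, assuming ``config`` is an ``RFDetrConfig`` instance
        (tensor shapes and values are illustrative)::

            >>> model = RFDetrModelForObjectDetection(config)
            >>> pixel_values = torch.rand(1, 3, 480, 640)
            >>> outputs = model(pixel_values)
            >>> logits, boxes = outputs.logits, outputs.pred_boxes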
        """
        if torch.jit.is_tracing():
            resize = Resize((self.config.resolution, self.config.resolution), antialias=False)
            self.model.backbone[0].encoder.encoder.embeddings.interpolate_pos_encoding = self._onnx_interpolate_pos_encoding
        else:
            resize = Resize((self.config.resolution, self.config.resolution))

        if labels is not None:
            self.validate_labels(labels)
            _, _, h, w = pixel_values.shape
            self.resize_labels(labels, h, w)
        else:
            # No labels: put the wrapped model and criterion into eval-mode behavior.
            self.model.training = False
            self.model.transformer.training = False
            for layer in self.model.transformer.decoder.layers:
                layer.training = False
            self.criterion.training = False

        pixel_values = resize(pixel_values)
        if pixel_mask is None:
            # An all-zeros mask means no padded pixels; keep it on the same device as the input.
            pixel_mask = torch.zeros(
                [pixel_values.shape[0], self.config.resolution, self.config.resolution],
                dtype=torch.bool,
                device=pixel_values.device,
            )
        else:
            pixel_mask = resize(pixel_mask)

        samples = NestedTensor(pixel_values, pixel_mask)
        outputs = self.model(samples)

        loss, loss_dict = self.compute_loss(outputs, labels)

        return RFDetrObjectDetectionOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=outputs["pred_logits"],
            pred_boxes=outputs["pred_boxes"],
            aux_outputs=outputs["aux_outputs"],
            enc_outputs=outputs["enc_outputs"],
        )


__all__ = [
    "RFDetrModelForObjectDetection",
]
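
# A minimal training-step sketch (not executed on import; the config arguments and data
# below are illustrative assumptions, not the package's documented defaults):
#
#     config = RFDetrConfig(model_name="RFDETRBase")
#     model = RFDetrModelForObjectDetection(config)
#     pixel_values = torch.rand(2, 3, 480, 640)
#     labels = [
#         {"boxes": torch.tensor([[320.0, 240.0, 100.0, 80.0]]),  # pixel-space [cx, cy, w, h]
#          "class_labels": torch.tensor([3])},
#         {"boxes": torch.tensor([[100.0, 100.0, 50.0, 50.0]]),
#          "class_labels": torch.tensor([1])},
#     ]
#     outputs = model(pixel_values, labels=labels)
#     outputs.loss.backward()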