# rf-detr-base / modeling_rf_detr.py
from dataclasses import dataclass
from typing import List, Dict
import torch
from torchvision.transforms import Resize
from transformers import PreTrainedModel
from transformers.utils import ModelOutput, torch_int
from rfdetr import RFDETRBase, RFDETRLarge
from rfdetr.util.misc import NestedTensor
from .configuration_rf_detr import RFDetrConfig
### NOTE: this module only works with transformers 4.50.3 and Python 3.11
@dataclass
class RFDetrObjectDetectionOutput(ModelOutput):
loss: torch.Tensor = None
loss_dict: Dict[str, torch.Tensor] = None
logits: torch.FloatTensor = None
pred_boxes: torch.FloatTensor = None
aux_outputs: List[Dict[str, torch.Tensor]] = None
enc_outputs: Dict[str, torch.Tensor] = None
class RFDetrModelForObjectDetection(PreTrainedModel):
config_class = RFDetrConfig
def __init__(self, config):
super().__init__(config)
self.config = config
models = {
'RFDETRBase': RFDETRBase,
'RFDETRLarge': RFDETRLarge,
}
rf_detr_model = models[config.model_name](
out_feature_indexes = config.out_feature_indexes,
dec_layers = config.dec_layers,
two_stage = config.two_stage,
bbox_reparam = config.bbox_reparam,
lite_refpoint_refine = config.lite_refpoint_refine,
layer_norm = config.layer_norm,
amp = config.amp,
num_classes = config.num_classes,
resolution = config.resolution,
group_detr = config.group_detr,
gradient_checkpointing = config.gradient_checkpointing,
num_queries = config.num_queries,
encoder = config.encoder,
hidden_dim = config.hidden_dim,
sa_nheads = config.sa_nheads,
ca_nheads = config.ca_nheads,
dec_n_points = config.dec_n_points,
projector_scale = config.projector_scale,
pretrain_weights = config.pretrain_weights,
)
self.model = rf_detr_model.model.model
self.criterion = rf_detr_model.model.criterion
def compute_loss(self, outputs, labels=None):
"""
        Parameters
        ----------
        outputs : dict
            Raw outputs from the underlying rfdetr model.
        labels : list[Dict[str, torch.Tensor]], optional
            List of bounding boxes and class labels for each image in the batch,
            with boxes in normalized [center_x, center_y, width, height] format.
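
        Examples
        --------
        A minimal sketch of the expected ``labels`` structure (values are made up
        for illustration; ``model`` and ``outputs`` are assumed to already exist):

        >>> labels = [{
        ...     "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.3]]),  # normalized cx, cy, w, h
        ...     "labels": torch.tensor([1]),
        ... }]
        >>> loss, loss_dict = model.compute_loss(outputs, labels)  # doctest: +SKIP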
"""
        loss = None
        loss_dict = None
        # loss is only computed when labels are provided (e.g. during training)
        if labels is not None:
            losses = self.criterion(outputs, targets=labels)
            loss_dict = {
                'loss_fl': losses["loss_ce"],
                ### class error and cardinality error are for logging purposes only, no back propagation
                'class_error': losses["class_error"],
                'cardinality_error': losses["cardinality_error"],
                'loss_bbox': losses["loss_bbox"],
                'loss_giou': losses["loss_giou"],
            }
            loss = sum(loss_dict[k] for k in ['loss_fl', 'loss_bbox', 'loss_giou'])
        return loss, loss_dict
def validate_labels(self, labels):
# Check for degenerate boxes
for label_idx, label in enumerate(labels):
boxes = label["boxes"]
degenerate_boxes = boxes[:, 2:] <= 0
if degenerate_boxes.any():
                # report the first degenerate box in the assertion message
bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
degen_bb: List[float] = boxes[bb_idx].tolist()
torch._assert(
False,
"All bounding boxes should have positive height and width."
f" Found invalid box {degen_bb} for target at index {label_idx}.",
)
# rename key class_labels to labels for compute_loss
if 'class_labels' in label.keys():
label['labels'] = label.pop('class_labels')
def resize_labels(self, labels, h, w):
"""
        Resize box coordinates to the model's resolution and normalize them to [0, 1].
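
        For illustration (numbers made up): with an 800x600 input and ``resolution`` = 560,
        an x-coordinate of 400 is scaled by 560 / 800 to 280 and then divided by 560,
        giving 0.5, which equals 400 / 800. Boxes therefore end up normalized by the
        original image size.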
"""
hr = self.config.resolution / float(h)
wr = self.config.resolution / float(w)
for label in labels:
boxes = label["boxes"]
# resize boxes to model's resolution
boxes[:, [0, 2]] *= wr
boxes[:, [1, 3]] *= hr
# normalize to [0, 1] by model's resolution
boxes[:] /= self.config.resolution
label["boxes"] = boxes
### modified from https://github.com/roboflow/rf-detr/blob/develop/rfdetr/models/backbone/dinov2_with_windowed_attn.py
def _onnx_interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
        This method interpolates the pre-trained position encodings so that the model can
        be used on higher-resolution images. The implementation supports torch.jit tracing
        while maintaining backwards compatibility with the original implementation.
Adapted from:
- https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
- https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
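
        For illustration (numbers are assumptions, not read from the config): with a patch
        size of 14 and a pre-trained 37x37 position grid (1369 patch positions plus the
        class token), calling this with height = width = 560 resamples the grid to
        560 // 14 = 40 per side, i.e. 40x40 = 1600 positions.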
"""
position_embeddings = self.model.backbone[0].encoder.encoder.embeddings.position_embeddings
config = self.model.backbone[0].encoder.encoder.embeddings.config
num_patches = embeddings.shape[1] - 1
num_positions = position_embeddings.shape[1] - 1
# Skip interpolation for matching dimensions (unless tracing)
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return position_embeddings
# Handle class token and patch embeddings separately
class_pos_embed = position_embeddings[:, 0]
patch_pos_embed = position_embeddings[:, 1:]
dim = embeddings.shape[-1]
# Calculate new dimensions
height = height // config.patch_size
width = width // config.patch_size
# Reshape for interpolation
sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
# Store original dtype for restoration after interpolation
target_dtype = patch_pos_embed.dtype
# Interpolate at float32 precision
### disable antialiasing for ONNX export
patch_pos_embed = torch.nn.functional.interpolate(
patch_pos_embed.to(dtype=torch.float32),
size=(torch_int(height), torch_int(width)), # Explicit size instead of scale_factor
mode="bicubic",
align_corners=False,
antialias=False,
).to(dtype=target_dtype)
# Validate output dimensions if not tracing
if not torch.jit.is_tracing():
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")
# Reshape back to original format
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
# Combine class and patch embeddings
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
    def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor = None, labels=None, **kwargs) -> ModelOutput:
"""
Parameters
----------
        pixel_values : torch.Tensor
            Input tensor of image pixel values with shape (batch_size, 3, height, width).
        pixel_mask : Optional[torch.Tensor]
            Padding mask matching the spatial size of `pixel_values`. If omitted, an
            all-False mask is created at the model's resolution.
        labels : Optional[List[Dict[str, torch.Tensor | List]]]
            List of annotations associated with the image or batch of images. For object
            detection, each annotation should be a dictionary with the following keys:
            - boxes (FloatTensor[N, 4]): the ground-truth boxes in [center_x, center_y, width, height]
              format, in absolute pixel coordinates of the input image
            - class_labels (Int64Tensor[N]): the class label for each ground-truth box
Returns
-------
        RFDetrObjectDetectionOutput
            Object containing
            - loss: sum of focal loss, bounding box loss, and generalized IoU loss
              (None when no labels are provided)
            - loss_dict: dictionary of the individual losses (None when no labels are provided)
            - logits
            - pred_boxes
            - aux_outputs
            - enc_outputs
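
        Examples
        --------
        A minimal usage sketch. The repository id is a placeholder, and it is assumed that
        the repository registers this class for AutoModelForObjectDetection via
        ``trust_remote_code``:

        >>> import torch
        >>> from transformers import AutoModelForObjectDetection
        >>> model = AutoModelForObjectDetection.from_pretrained(
        ...     "<repo-id>", trust_remote_code=True
        ... )  # doctest: +SKIP
        >>> pixel_values = torch.rand(1, 3, 640, 640)
        >>> outputs = model(pixel_values)  # doctest: +SKIP
        >>> outputs.logits.shape, outputs.pred_boxes.shape  # doctest: +SKIP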
"""
if torch.jit.is_tracing():
### disable antialiasing for ONNX export
resize = Resize((self.config.resolution, self.config.resolution), antialias=False)
self.model.backbone[0].encoder.encoder.embeddings.interpolate_pos_encoding = self._onnx_interpolate_pos_encoding
else:
resize = Resize((self.config.resolution, self.config.resolution))
if labels is not None:
self.validate_labels(labels)
_,_,h,w = pixel_values.shape
self.resize_labels(labels, h, w) # reshape labels with model's resolution
        else:
            # inference: force eval-mode behavior on the model, decoder layers, and criterion
self.model.training = False
self.model.transformer.training = False
for layer in self.model.transformer.decoder.layers:
layer.training = False
self.criterion.training = False
# resize pixel values and mask to model's resolution
pixel_values = resize(pixel_values)
if pixel_mask is None:
            pixel_mask = torch.zeros([pixel_values.shape[0], self.config.resolution, self.config.resolution], dtype=torch.bool, device=pixel_values.device)
else:
pixel_mask = resize(pixel_mask)
samples = NestedTensor(pixel_values, pixel_mask)
outputs = self.model(samples)
        # compute loss; returns (None, None) when labels are not provided
loss, loss_dict = self.compute_loss(outputs, labels)
return RFDetrObjectDetectionOutput(
loss=loss,
loss_dict=loss_dict,
logits=outputs["pred_logits"],
pred_boxes=outputs["pred_boxes"],
aux_outputs=outputs["aux_outputs"],
enc_outputs=outputs["enc_outputs"],
)
__all__ = [
"RFDetrModelForObjectDetection"
]