import torch
import torch.nn as nn
import timm
from transformers import PreTrainedModel, PretrainedConfig


class HarpoonConfig(PretrainedConfig):
    """Configuration for :class:`HarpoonModel`.

    Args:
        num_classes: Number of class scores predicted per anchor; the head
            emits ``5 + num_classes`` values per anchor.
        resolution: Input resolution stored on the config and copied onto the
            model. It is not read by ``HarpoonModel.forward``; presumably it
            records the expected input size for pre-processing — confirm
            with callers.
        num_anchors: Number of predictions per spatial cell. Defaults to 3,
            the value previously hard-coded in the detection head, so existing
            configs/checkpoints build the same architecture.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "harpoon"

    def __init__(self, num_classes=1, resolution=544, num_anchors=3, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.resolution = resolution
        self.num_anchors = num_anchors


class HarpoonModel(PreTrainedModel):
    """ConvNeXt-Small backbone followed by a small convolutional detection head.

    The head maps backbone features to ``(5 + num_classes) * num_anchors``
    output channels per spatial location. Its 3x3 (padding=1) and 1x1
    convolutions preserve the spatial size of the backbone feature map.
    """

    config_class = HarpoonConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_classes = config.num_classes
        self.resolution = config.resolution
        # getattr keeps config objects that predate ``num_anchors`` working.
        self.num_anchors = getattr(config, "num_anchors", 3)

        # ConvNeXt Small backbone; classifier removed (num_classes=0) and
        # global pooling disabled (global_pool='') so a spatial map is returned.
        self.backbone = timm.create_model(
            'convnext_small',
            pretrained=False,
            num_classes=0,
            global_pool='',
            in_chans=3,
        )

        # Derive the head's input width from the backbone instead of
        # hard-coding it (768 for convnext_small), so the two stay in sync.
        in_channels = self.backbone.num_features
        out_channels = (5 + self.num_classes) * self.num_anchors

        # Harpoon Core detection head
        self.harpoon_core = nn.Sequential(
            nn.Conv2d(in_channels, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, out_channels, 1),
        )

        self._initialize_harpoon_core()

    def _initialize_harpoon_core(self):
        """Initialize Harpoon Core detection head.

        Conv weights get Kaiming-normal init (fan_out, ReLU gain) and zero
        bias; BatchNorm layers get weight 1 and bias 0. The backbone is left
        with timm's own initialization.
        """
        for m in self.harpoon_core.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the backbone and detection head.

        Args:
            x: Input image batch. Assumed to be ``(B, 3, H, W)`` given
                ``in_chans=3`` — TODO confirm expected normalization/size.

        Returns:
            The raw head output with ``(5 + num_classes) * num_anchors``
            channels at the backbone's spatial resolution; no decoding,
            activation or reshaping into per-anchor predictions is applied.
        """
        # ConvNeXt backbone feature extraction
        features = self.backbone(x)
        # Harpoon Core detection
        detections = self.harpoon_core(features)
        return detections