Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import torch.nn as nn | |
| from .resnet_backbone import ResNetBackbone | |
class ResNet50(nn.Module):
    """ResNet-50 feature extractor initialised from a chosen set of pretrained weights.

    Builds a ``ResNetBackbone`` without its own pretrained weights, then
    overwrites its parameters/buffers in place with one of:

    * ``"mocov2"``     -- MoCo v2 checkpoint ``moco_v2_800ep_pretrain.pth.tar``
    * ``"swav"``       -- SwAV checkpoint ``swav_800ep_pretrain.pth.tar``
    * ``"supervised"`` -- torchvision's ImageNet-pretrained ResNet-50 weights

    Args:
        weight_type: Which pretrained weights to load (one of the keys above).
        use_dilated_resnet: If True, build the ``"resnet50_dilated8"`` backbone
            variant instead of ``"resnet50"``.
        pretrained_dir: Directory holding the MoCo v2 / SwAV checkpoint files.
            ``None`` means ``DEFAULT_PRETRAINED_DIR``.

    Attributes:
        network: The wrapped ``ResNetBackbone``; ``forward`` delegates to it.
        n_embs: ``network.num_features``, as reported by the backbone.
        use_dilated_resnet: The flag passed at construction.
    """

    # Directory the self-supervised checkpoints were originally read from.
    DEFAULT_PRETRAINED_DIR = "/users/gyungin/sos/networks/pretrained"
    # Checkpoint filename for each self-supervised training method.
    _CHECKPOINT_FILES = {
        "mocov2": "moco_v2_800ep_pretrain.pth.tar",
        "swav": "swav_800ep_pretrain.pth.tar",
    }

    def __init__(
        self,
        weight_type: str = "supervised",
        use_dilated_resnet: bool = True,
        pretrained_dir: "str | None" = None,
    ):
        super().__init__()
        backbone = "resnet50_dilated8" if use_dilated_resnet else "resnet50"
        self.network = ResNetBackbone(backbone=backbone, pretrained=None)
        self.n_embs = self.network.num_features
        self.use_dilated_resnet = use_dilated_resnet
        self._load_pretrained(weight_type, pretrained_dir=pretrained_dir)

    @staticmethod
    def _drop_keys(state_dict, substrings) -> None:
        """Delete, in place, every entry of *state_dict* whose key contains any of *substrings*."""
        for key in list(state_dict):  # snapshot keys: we mutate while iterating
            if any(sub in key for sub in substrings):
                del state_dict[key]

    @staticmethod
    def _torchvision_supervised_state_dict():
        """Return the state dict of torchvision's ImageNet-pretrained ResNet-50."""
        # Local import: torchvision is only needed for the supervised weights.
        from torchvision.models.resnet import resnet50

        try:
            # torchvision >= 0.13 deprecates `pretrained` in favour of `weights`.
            from torchvision.models import ResNet50_Weights
        except ImportError:
            model = resnet50(pretrained=True, progress=True)
        else:
            # IMAGENET1K_V1 is what the legacy `pretrained=True` resolved to.
            model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1, progress=True)
        # Taking state_dict() from an instantiated model (rather than a raw
        # downloaded .pth file) includes BatchNorm `num_batches_tracked` buffers;
        # presumably this is what keeps the entry count aligned with
        # self.network -- NOTE(review): confirm.
        return model.state_dict()

    def _load_pretrained(
        self, training_method: str, pretrained_dir: "str | None" = None
    ) -> None:
        """Copy pretrained weights into ``self.network`` in place.

        Checkpoint entries are matched to ``self.network.state_dict()`` by
        *position*, not by name. After the head entries are dropped, the
        checkpoint must therefore list its tensors in the same order as the
        network.

        Args:
            training_method: ``"mocov2"``, ``"swav"`` or ``"supervised"``.
            pretrained_dir: Directory with the MoCo v2 / SwAV checkpoints.
                ``None`` means ``DEFAULT_PRETRAINED_DIR``. Unused for
                ``"supervised"``.

        Raises:
            ValueError: If *training_method* is not one of the supported values.
            AssertionError: If the number of entries differs after filtering.
        """
        if pretrained_dir is None:
            pretrained_dir = self.DEFAULT_PRETRAINED_DIR

        if training_method == "mocov2":
            # map_location="cpu": checkpoints saved from GPU must load on CPU-only hosts.
            checkpoint = torch.load(
                os.path.join(pretrained_dir, self._CHECKPOINT_FILES["mocov2"]),
                map_location="cpu",
            )
            state_dict = checkpoint["state_dict"]
            # Drop "fc.0"/"fc.2" entries (presumably the MoCo v2 MLP head).
            self._drop_keys(state_dict, ("fc.0", "fc.2"))
        elif training_method == "swav":
            state_dict = torch.load(
                os.path.join(pretrained_dir, self._CHECKPOINT_FILES["swav"]),
                map_location="cpu",
            )
            self._drop_keys(state_dict, ("projection_head", "prototypes"))
        elif training_method == "supervised":
            state_dict = self._torchvision_supervised_state_dict()
            # Drop the ImageNet classifier.
            self._drop_keys(state_dict, ("fc.weight", "fc.bias"))
        else:
            raise ValueError(
                f"Unknown training_method {training_method!r}; "
                "expected one of 'mocov2', 'swav', 'supervised'."
            )

        curr_state_dict = self.network.state_dict()
        assert len(curr_state_dict) == len(state_dict), f"# layers are different: {len(curr_state_dict)} != {len(state_dict)}"
        # state_dict() tensors share storage with the module's params/buffers,
        # so copy_ updates self.network in place.
        for dst, src in zip(curr_state_dict.values(), state_dict.values()):
            dst.copy_(src)
        print(f"ResNet50{' (dilated)' if self.use_dilated_resnet else ''} intialised with {training_method} weights is loaded.")

    def forward(self, x):
        """Run the wrapped backbone on *x* and return its output unchanged."""
        return self.network(x)
if __name__ == '__main__':
    # Script entry point: build the model initialised with MoCo v2 weights.
    resnet = ResNet50(weight_type="mocov2")