Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # -*- coding:utf-8 -*- | |
| # Author: Donny You([email protected]) | |
| import torch.nn as nn | |
| from networks.resnet_models import * | |
class NormalResnetBackbone(nn.Module):
    """Wrap a ResNet so that it returns the outputs of all four residual stages.

    The original network's ``prefix`` and ``maxpool`` modules form the stem;
    its average-pool and fully-connected head are not copied over.
    """

    def __init__(self, orig_resnet):
        super(NormalResnetBackbone, self).__init__()
        # Channel count reported by get_num_features(); factories may overwrite it.
        self.num_features = 2048
        # Reuse the (possibly pretrained) modules of the original network as-is.
        # Assignment order matters: it fixes the submodule / state_dict key order.
        self.prefix, self.maxpool = orig_resnet.prefix, orig_resnet.maxpool
        self.layer1, self.layer2 = orig_resnet.layer1, orig_resnet.layer2
        self.layer3, self.layer4 = orig_resnet.layer3, orig_resnet.layer4

    def get_num_features(self):
        """Return the value stored in ``num_features``."""
        return self.num_features

    def forward(self, x):
        """Run the stem and every stage; return the four stage outputs as a list."""
        out = self.maxpool(self.prefix(x))
        stage_outputs = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
            stage_outputs.append(out)
        return stage_outputs
class DilatedResnetBackbone(nn.Module):
    """Wrap a ResNet, replacing the stride of its late stages with dilation.

    ``dilate_scale == 8`` removes the stride of ``layer3`` and ``layer4``;
    ``dilate_scale == 16`` removes only the stride of ``layer4``. Any other
    value leaves the network unchanged. The convolutions of ``orig_resnet``
    are modified in place.

    Args:
        orig_resnet: object exposing ``prefix``, ``maxpool`` and ``layer1``..``layer4``.
        dilate_scale: 8 or 16, as described above.
        multi_grid: per-block multipliers for the dilation of ``layer4``: block ``i``
            gets ``base * multi_grid[i]``. ``None`` applies the base dilation to the
            whole stage. ``layer4`` is indexed directly, so ``multi_grid`` must not be
            longer than ``layer4``; extra blocks beyond its length are left untouched.
    """

    def __init__(self, orig_resnet, dilate_scale=8, multi_grid=(1, 2, 4)):
        super(DilatedResnetBackbone, self).__init__()
        from functools import partial

        self.num_features = 2048

        # Base dilation for layer4 (before any multi-grid multiplier); None means
        # the requested scale is not handled and nothing is modified.
        base_rate = None
        if dilate_scale == 8:
            orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
            base_rate = 4
        elif dilate_scale == 16:
            base_rate = 2

        if base_rate is not None:
            if multi_grid is None:
                orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=base_rate))
            else:
                for idx, grid in enumerate(multi_grid):
                    block = orig_resnet.layer4[idx]
                    block.apply(partial(self._nostride_dilate, dilate=int(base_rate * grid)))

        # Keep the stem and the four stages; the pooling/classification head is dropped.
        self.prefix = orig_resnet.prefix
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def _nostride_dilate(self, m, dilate):
        """Turn one strided convolution into a dilated, stride-1 one (used with ``Module.apply``).

        Only modules whose class name contains ``'Conv'`` are touched. A conv with
        stride (2, 2) gets stride (1, 1) and, if it is 3x3, half of ``dilate`` as
        dilation/padding; any other 3x3 conv gets the full ``dilate``.
        """
        if 'Conv' not in type(m).__name__:
            return
        if m.stride == (2, 2):
            m.stride = (1, 1)
            rate = dilate // 2
        else:
            rate = dilate
        # Only 3x3 kernels are dilated; padding == dilation keeps the spatial size.
        if m.kernel_size == (3, 3):
            m.dilation = (rate, rate)
            m.padding = (rate, rate)

    def get_num_features(self):
        """Return the value stored in ``num_features``."""
        return self.num_features

    def forward(self, x):
        """Run the stem and every stage; return the four stage outputs as a list."""
        out = self.maxpool(self.prefix(x))
        stage_outputs = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
            stage_outputs.append(out)
        return stage_outputs
def ResNetBackbone(backbone=None, width_multiplier=1.0, pretrained=None, multi_grid=None, norm_type='batchnorm'):
    """Build a ResNet feature extractor from an architecture name.

    Args:
        backbone: architecture key, e.g. ``'resnet50'``, ``'resnet101_dilated8'`` or
            ``'deepbase_resnet50_dilated8'``. Keys ending in ``_dilated8`` /
            ``_dilated16`` are wrapped in ``DilatedResnetBackbone`` with the matching
            ``dilate_scale``; the remaining keys use ``NormalResnetBackbone``.
        width_multiplier: forwarded to ``resnet50`` for every ``resnet50*`` key and
            ignored by all other architectures. NOTE(review): ``num_features`` is not
            rescaled with it — confirm whether the factory changes the output width.
        pretrained: forwarded to the ResNet factory. For ``deepbase_resnet50``,
            ``deepbase_resnet50_dilated8`` and ``deepbase_resnet101_dilated8`` a truthy
            value is replaced by a fixed local checkpoint path.
        multi_grid: forwarded to ``DilatedResnetBackbone`` for the dilated keys.
        norm_type: accepted for interface compatibility; currently unused.

    Returns:
        The wrapped backbone. Its ``num_features`` is set to 512 for the
        ResNet-18/34 keys and left at the wrapper's own default otherwise.

    Raises:
        ValueError: if ``backbone`` is not a recognised architecture key
            (a subclass of ``Exception``, so existing broad handlers still work).
    """
    # Local ImageNet checkpoints (relative paths) substituted for a truthy
    # ``pretrained`` by some deep-base variants.
    deepbase50_weights = 'models/backbones/pretrained/3x3resnet50-imagenet.pth'
    # Same directory as the ResNet-50 checkpoint; the old 'backbones/backbones/...'
    # prefix looked like a typo. NOTE(review): confirm the file lives here.
    deepbase101_weights = 'models/backbones/pretrained/3x3resnet101-imagenet.pth'
    # NOTE(review): deepbase_resnet50_dilated16, deepbase_resnet101 and
    # deepbase_resnet101_dilated16 pass ``pretrained`` through unchanged, unlike
    # the variants above — confirm whether that is intended.

    arch = backbone
    if arch == 'resnet18':
        orig_resnet = resnet18(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)
        arch_net.num_features = 512
    elif arch == 'resnet18_dilated8':
        orig_resnet = resnet18(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
        arch_net.num_features = 512
    elif arch == 'resnet34':
        orig_resnet = resnet34(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)
        arch_net.num_features = 512
    elif arch == 'resnet34_dilated8':
        orig_resnet = resnet34(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
        arch_net.num_features = 512
    elif arch == 'resnet34_dilated16':
        orig_resnet = resnet34(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
        arch_net.num_features = 512
    elif arch == 'resnet50':
        orig_resnet = resnet50(pretrained=pretrained, width_multiplier=width_multiplier)
        arch_net = NormalResnetBackbone(orig_resnet)
    elif arch == 'resnet50_dilated8':
        orig_resnet = resnet50(pretrained=pretrained, width_multiplier=width_multiplier)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
    elif arch == 'resnet50_dilated16':
        # Forward width_multiplier like the other resnet50 variants.
        orig_resnet = resnet50(pretrained=pretrained, width_multiplier=width_multiplier)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
    elif arch == 'deepbase_resnet50':
        if pretrained:
            pretrained = deepbase50_weights
        orig_resnet = deepbase_resnet50(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)
    elif arch == 'deepbase_resnet50_dilated8':
        if pretrained:
            pretrained = deepbase50_weights
        orig_resnet = deepbase_resnet50(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
    elif arch == 'deepbase_resnet50_dilated16':
        orig_resnet = deepbase_resnet50(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
    elif arch == 'resnet101':
        orig_resnet = resnet101(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)
    elif arch == 'resnet101_dilated8':
        orig_resnet = resnet101(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
    elif arch == 'resnet101_dilated16':
        orig_resnet = resnet101(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
    elif arch == 'deepbase_resnet101':
        orig_resnet = deepbase_resnet101(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)
    elif arch == 'deepbase_resnet101_dilated8':
        if pretrained:
            pretrained = deepbase101_weights
        orig_resnet = deepbase_resnet101(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
    elif arch == 'deepbase_resnet101_dilated16':
        orig_resnet = deepbase_resnet101(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
    else:
        raise ValueError('Architecture undefined!')
    return arch_net