# Copyright (c) Meta Platforms, Inc. and affiliates.
# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
# https://github.com/cvg/pixloc
# Released under the Apache License 2.0

"""
Flexible UNet model which takes a Torchvision backbone (VGG or ResNet) as encoder.
Predicts multi-level features and makes sure that they are well aligned.
"""

import torch
import torch.nn as nn
import torchvision

from .base import BaseModel
from .utils import checkpointed
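

# A single decoder stage: bilinearly upsample the coarser decoder features,
# crop and concatenate the matching encoder skip connection, then apply a
# few 3x3 conv + (optional) norm + ReLU layers.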
class DecoderBlock(nn.Module):
    def __init__(
        self, previous, skip, out, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
    ):
        super().__init__()
        self.upsample = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )
        layers = []
        for i in range(num_convs):
            conv = nn.Conv2d(
                previous + skip if i == 0 else out,
                out,
                kernel_size=3,
                padding=1,
                bias=norm is None,
                padding_mode=padding,
            )
            layers.append(conv)
            if norm is not None:
                layers.append(norm(out))
            layers.append(nn.ReLU(inplace=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, previous, skip):
        upsampled = self.upsample(previous)
        # If the shape of the input map `skip` is not a multiple of 2,
        # it will not match the shape of the upsampled map `upsampled`.
        # If the downsampling uses ceil_mode=False, we need to crop `skip`.
        # If it uses ceil_mode=True (not supported here), we should pad it.
        _, _, hu, wu = upsampled.shape
        _, _, hs, ws = skip.shape
        assert (hu <= hs) and (wu <= ws), "Using ceil_mode=True in pooling?"
        # assert (hu == hs) and (wu == ws), 'Careful about padding'
        skip = skip[:, :, :hu, :wu]
        return self.layers(torch.cat([upsampled, skip], dim=1))
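

# A 1x1 convolution that projects a feature map to the requested output
# dimension, applied at each selected scale.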
class AdaptationBlock(nn.Sequential):
    def __init__(self, inp, out):
        conv = nn.Conv2d(inp, out, kernel_size=1, padding=0, bias=True)
        super().__init__(conv)
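

# Encoder-decoder feature extractor. The encoder is a (possibly pretrained)
# Torchvision backbone split into downsampling blocks; the decoder upsamples
# back with skip connections, and adaptation layers output feature maps at
# the scales listed in `output_scales`.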
class FeatureExtractor(BaseModel):
    default_conf = {
        "pretrained": True,
        "input_dim": 3,
        "output_scales": [0, 2, 4],  # what scales to adapt and output
        "output_dim": 128,  # number of channels in output feature maps
        "encoder": "vgg16",  # string (torchvision net) or list of channels
        "num_downsample": 4,  # how many downsampling blocks (if VGG-style net)
        "decoder": [64, 64, 64, 64],  # list of channels of decoder
        "decoder_norm": "nn.BatchNorm2d",  # normalization in decoder blocks
        "do_average_pooling": False,
        "checkpointed": False,  # whether to use gradient checkpointing
        "padding": "zeros",
    }

    # ImageNet statistics, used to normalize the input when the encoder
    # weights are pretrained.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
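
    # Split a Torchvision backbone into a list of sequential blocks such that
    # block i produces the feature map at scale 2**i. Returns the blocks and
    # the output channel count of each block (used for the decoder skips).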
    def build_encoder(self, conf):
        assert isinstance(conf.encoder, str)
        if conf.pretrained:
            assert conf.input_dim == 3
        Encoder = getattr(torchvision.models, conf.encoder)
        encoder = Encoder(weights="DEFAULT" if conf.pretrained else None)
        Block = checkpointed(torch.nn.Sequential, do=conf.checkpointed)
        assert max(conf.output_scales) <= conf.num_downsample

        if conf.encoder.startswith("vgg"):
            # Parse the layers and pack them into downsampling blocks.
            # It's easy for VGG-style nets because of their linear structure.
            # This does not handle strided convs and residual connections.
            skip_dims = []
            previous_dim = None
            blocks = [[]]
            for i, layer in enumerate(encoder.features):
                if isinstance(layer, torch.nn.Conv2d):
                    # Change the first conv layer if the input dim mismatches
                    if i == 0 and conf.input_dim != layer.in_channels:
                        args = {k: getattr(layer, k) for k in layer.__constants__}
                        args.pop("output_padding")
                        layer = torch.nn.Conv2d(
                            **{**args, "in_channels": conf.input_dim}
                        )
                    previous_dim = layer.out_channels
                elif isinstance(layer, torch.nn.MaxPool2d):
                    assert previous_dim is not None
                    skip_dims.append(previous_dim)
                    if (conf.num_downsample + 1) == len(blocks):
                        break
                    blocks.append([])  # start a new block
                    if conf.do_average_pooling:
                        assert layer.dilation == 1
                        layer = torch.nn.AvgPool2d(
                            kernel_size=layer.kernel_size,
                            stride=layer.stride,
                            padding=layer.padding,
                            ceil_mode=layer.ceil_mode,
                            count_include_pad=False,
                        )
                blocks[-1].append(layer)
            encoder = [Block(*b) for b in blocks]
        elif conf.encoder.startswith("resnet"):
            # Manually define the ResNet blocks such that the downsampling comes first
            assert conf.encoder[len("resnet") :] in ["18", "34", "50", "101"]
            assert conf.input_dim == 3, "Unsupported for now."
            block1 = torch.nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)
            block2 = torch.nn.Sequential(encoder.maxpool, encoder.layer1)
            block3 = encoder.layer2
            block4 = encoder.layer3
            block5 = encoder.layer4
            blocks = [block1, block2, block3, block4, block5]
            # Extract the output dimension of each block
            skip_dims = [encoder.conv1.out_channels]
            for i in range(1, 5):
                modules = getattr(encoder, f"layer{i}")[-1]._modules
                conv = sorted(k for k in modules if k.startswith("conv"))[-1]
                skip_dims.append(modules[conv].out_channels)
            # Add a dummy block such that the first one does not downsample
            encoder = [torch.nn.Identity()] + [Block(b) for b in blocks]
            skip_dims = [3] + skip_dims
            # Trim based on the requested encoder size
            encoder = encoder[: conf.num_downsample + 1]
            skip_dims = skip_dims[: conf.num_downsample + 1]
        else:
            raise NotImplementedError(conf.encoder)

        assert (conf.num_downsample + 1) == len(encoder)
        encoder = nn.ModuleList(encoder)
        return encoder, skip_dims
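
    # Assemble encoder, optional decoder, and per-scale adaptation layers.
    # The decoder channel list is consumed coarse-to-fine: decoder[0] refines
    # the coarsest encoder output, decoder[-1] produces the finest map.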
    def _init(self, conf):
        # Encoder
        self.encoder, skip_dims = self.build_encoder(conf)
        self.skip_dims = skip_dims

        def update_padding(module):
            if isinstance(module, nn.Conv2d):
                module.padding_mode = conf.padding

        if conf.padding != "zeros":
            self.encoder.apply(update_padding)

        # Decoder
        if conf.decoder is not None:
            assert len(conf.decoder) == (len(skip_dims) - 1)
            Block = checkpointed(DecoderBlock, do=conf.checkpointed)
            norm = eval(conf.decoder_norm) if conf.decoder_norm else None  # noqa
            previous = skip_dims[-1]
            decoder = []
            for out, skip in zip(conf.decoder, skip_dims[:-1][::-1]):
                decoder.append(
                    Block(previous, skip, out, norm=norm, padding=conf.padding)
                )
                previous = out
            self.decoder = nn.ModuleList(decoder)

        # Adaptation layers
        adaptation = []
        for idx, i in enumerate(conf.output_scales):
            if conf.decoder is None or i == (len(self.encoder) - 1):
                input_ = skip_dims[i]
            else:
                input_ = conf.decoder[-1 - i]
            # output_dim can be an int (same for all scales) or a list (per scale)
            dim = conf.output_dim
            if not isinstance(dim, int):
                dim = dim[idx]
            block = AdaptationBlock(input_, dim)
            adaptation.append(block)
        self.adaptation = nn.ModuleList(adaptation)
        self.scales = [2**s for s in conf.output_scales]
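
    # Forward pass: `data["image"]` is a (B, C, H, W) float tensor; when the
    # encoder is pretrained, it is normalized with the ImageNet statistics
    # above (which assumes RGB input roughly in [0, 1]).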
    def _forward(self, data):
        image = data["image"]
        if self.conf.pretrained:
            mean, std = image.new_tensor(self.mean), image.new_tensor(self.std)
            image = (image - mean[:, None, None]) / std[:, None, None]

        skip_features = []
        features = image
        for block in self.encoder:
            features = block(features)
            skip_features.append(features)

        if self.conf.decoder:
            pre_features = [skip_features[-1]]
            for block, skip in zip(self.decoder, skip_features[:-1][::-1]):
                pre_features.append(block(pre_features[-1], skip))
            pre_features = pre_features[::-1]  # fine to coarse
        else:
            pre_features = skip_features

        out_features = []
        for adapt, i in zip(self.adaptation, self.conf.output_scales):
            out_features.append(adapt(pre_features[i]))
        pred = {"feature_maps": out_features, "skip_features": skip_features}
        return pred

    def loss(self, pred, data):
        raise NotImplementedError

    def metrics(self, pred, data):
        raise NotImplementedError
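

if __name__ == "__main__":
    # Minimal smoke test, assuming the PixLoc-style BaseModel accepts a conf
    # dict merged into default_conf and makes the model callable on a data
    # dict; adjust if the actual BaseModel interface differs.
    model = FeatureExtractor({"encoder": "vgg16", "pretrained": False})
    image = torch.rand(1, 3, 240, 320)  # arbitrary test resolution
    pred = model({"image": image})
    for scale, fmap in zip(model.scales, pred["feature_maps"]):
        # Each map is aligned with the input image at a stride of `scale`.
        print(f"scale 1/{scale}: {tuple(fmap.shape)}")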