import logging

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

from .base import BaseModel

logger = logging.getLogger(__name__)


class DecoderBlock(nn.Module):
    def __init__(
        self, previous, out, ksize=3, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
    ):
        super().__init__()
        layers = []
        for i in range(num_convs):
            conv = nn.Conv2d(
                previous if i == 0 else out,
                out,
                kernel_size=ksize,
                padding=ksize // 2,
                bias=norm is None,
                padding_mode=padding,
            )
            layers.append(conv)
            if norm is not None:
                layers.append(norm(out))
            layers.append(nn.ReLU(inplace=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, previous, skip):
        _, _, hp, wp = previous.shape
        _, _, hs, ws = skip.shape
        # Upsample `previous` by the power-of-two factor separating it from `skip`.
        scale = 2 ** np.round(np.log2(np.array([hs / hp, ws / wp])))
        upsampled = nn.functional.interpolate(
            previous, scale_factor=scale.tolist(), mode="bilinear", align_corners=False
        )
        # If the shape of the input map `skip` is not a multiple of 2,
        # it will not match the shape of the upsampled map `upsampled`.
        # If the downsampling uses ceil_mode=False, we need to crop `skip`.
        # If it uses ceil_mode=True (not supported here), we should pad it.
        _, _, hu, wu = upsampled.shape
        _, _, hs, ws = skip.shape
        if (hu <= hs) and (wu <= ws):
            skip = skip[:, :, :hu, :wu]
        elif (hu >= hs) and (wu >= ws):
            skip = nn.functional.pad(skip, [0, wu - ws, 0, hu - hs])
        else:
            raise ValueError(
                f"Inconsistent skip vs upsampled shapes: {(hs, ws)}, {(hu, wu)}"
            )
        return self.layers(skip) + upsampled


class FPN(nn.Module):
    def __init__(self, in_channels_list, out_channels, **kw):
        super().__init__()
        # 1x1 projection for the deepest level, lateral blocks for the others.
        self.first = nn.Conv2d(
            in_channels_list[-1], out_channels, 1, padding=0, bias=True
        )
        self.blocks = nn.ModuleList(
            [
                DecoderBlock(c, out_channels, ksize=1, **kw)
                for c in in_channels_list[::-1][1:]
            ]
        )
        self.out = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, layers):
        # Merge the levels from deepest to shallowest, FPN-style.
        feats = None
        for idx, x in enumerate(reversed(layers.values())):
            if feats is None:
                feats = self.first(x)
            else:
                feats = self.blocks[idx - 1](feats, x)
        out = self.out(feats)
        return out


def remove_conv_stride(conv):
    # Copy of `conv` with stride 1; the weights are shared with the original.
    conv_new = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        bias=conv.bias is not None,
        stride=1,
        padding=conv.padding,
    )
    conv_new.weight = conv.weight
    conv_new.bias = conv.bias
    return conv_new


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)
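

# Hedged usage sketch (not exercised by this module): the FPN above consumes
# an ordered mapping of feature maps, shallow to deep, such as the output of
# torchvision's create_feature_extractor. The resnet50 node names and channel
# counts below are illustrative assumptions, not values taken from this file.
def _example_fpn_on_resnet50():
    backbone = torchvision.models.resnet50(weights=None)
    extractor = create_feature_extractor(
        backbone, return_nodes=["layer1", "layer2", "layer3", "layer4"]
    )
    fpn = FPN(in_channels_list=[256, 512, 1024, 2048], out_channels=128)
    feats = extractor(torch.rand(1, 3, 224, 224))  # dict of 4 feature maps
    return fpn(feats)  # (1, 128, 56, 56): the resolution of the shallowest level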


class FeatureExtractor(BaseModel):
    # Some of these options are not used by the current UNet-based encoder.
    default_conf = {
        "pretrained": True,
        "input_dim": 3,
        "output_dim": 128,  # number of channels in the output feature maps
        "encoder": "resnet50",  # torchvision model, given as a string
        "remove_stride_from_first_conv": False,
        "num_downsample": None,  # how many downsampling blocks to use
        "decoder_norm": "nn.BatchNorm2d",  # normalization in the decoder blocks
        "do_average_pooling": False,
        "checkpointed": False,  # whether to use gradient checkpointing
    }
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def build_encoder(self, conf):
        assert isinstance(conf.encoder, str)
        if conf.pretrained:
            assert conf.input_dim == 3
        # Encoder
        self.conv1 = self.conv_block(conf.input_dim, 64)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = self.conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = self.conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.conv4 = self.conv_block(256, 512)
        self.pool4 = nn.MaxPool2d(2, 2)
        self.conv5 = self.conv_block(512, 1024)
        # Decoder
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = self.conv_block(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = self.conv_block(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = self.conv_block(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = self.conv_block(128, 64)
        self.conv10 = nn.Conv2d(64, conf.output_dim, 1)
        # return encoder, layers

    def unet(self, x):
        # Encoder
        conv1 = self.conv1(x)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3)
        pool4 = self.pool4(conv4)
        conv5 = self.conv5(pool4)
        # Decoder
        up6 = self.up6(conv5)
        concat6 = torch.cat([up6, conv4], dim=1)
        conv6 = self.conv6(concat6)
        up7 = self.up7(conv6)
        concat7 = torch.cat([up7, conv3], dim=1)
        conv7 = self.conv7(concat7)
        up8 = self.up8(conv7)
        concat8 = torch.cat([up8, conv2], dim=1)
        conv8 = self.conv8(concat8)
        up9 = self.up9(conv8)
        concat9 = torch.cat([up9, conv1], dim=1)
        conv9 = self.conv9(concat9)
        output = self.conv10(conv9)
        return output

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def _init(self, conf):
        # Preprocessing
        self.register_buffer("mean_", torch.tensor(self.mean), persistent=False)
        self.register_buffer("std_", torch.tensor(self.std), persistent=False)
        # Encoder
        self.build_encoder(conf)

    def _forward(self, data):
        image = data["image"]
        image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]
        output = self.unet(image)
        # output = self.decoder(skip_features)
        pred = {"feature_maps": [output]}
        return pred


if __name__ == "__main__":
    model = FeatureExtractor()
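    # Minimal smoke test, assuming the BaseModel interface from `.base` builds
    # the default conf when none is given and dispatches __call__ to `_forward`
    # with a data dictionary (a common pattern, but an assumption here).
    # The input size (1, 3, 256, 256) is an arbitrary choice; it must be
    # divisible by 16 so the four pool/upsample stages round-trip exactly.
    data = {"image": torch.rand(1, 3, 256, 256)}
    with torch.no_grad():
        pred = model(data)
    print([tuple(f.shape) for f in pred["feature_maps"]])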