# VecMapLocNet/models/feature_extractor_v3.py
import logging
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor
from .base import BaseModel
logger = logging.getLogger(__name__)
class DecoderBlock(nn.Module):
    """Upsample a coarser map and fuse it with a skip connection.

    Note: the `previous` argument below is the channel count of the *skip*
    tensor; in `forward`, the convolutions are applied to `skip`, while the
    upsampled map is expected to already have `out` channels.
    """
def __init__(
self, previous, out, ksize=3, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
):
super().__init__()
layers = []
for i in range(num_convs):
conv = nn.Conv2d(
previous if i == 0 else out,
out,
kernel_size=ksize,
padding=ksize // 2,
bias=norm is None,
padding_mode=padding,
)
layers.append(conv)
if norm is not None:
layers.append(norm(out))
layers.append(nn.ReLU(inplace=True))
self.layers = nn.Sequential(*layers)
def forward(self, previous, skip):
_, _, hp, wp = previous.shape
_, _, hs, ws = skip.shape
scale = 2 ** np.round(np.log2(np.array([hs / hp, ws / wp])))
upsampled = nn.functional.interpolate(
previous, scale_factor=scale.tolist(), mode="bilinear", align_corners=False
)
# If the shape of the input map `skip` is not a multiple of 2,
# it will not match the shape of the upsampled map `upsampled`.
        # If the downsampling uses ceil_mode=False, we need to crop `skip`.
# If it uses ceil_mode=True (not supported here), we should pad it.
_, _, hu, wu = upsampled.shape
_, _, hs, ws = skip.shape
if (hu <= hs) and (wu <= ws):
skip = skip[:, :, :hu, :wu]
elif (hu >= hs) and (wu >= ws):
skip = nn.functional.pad(skip, [0, wu - ws, 0, hu - hs])
else:
raise ValueError(
f"Inconsistent skip vs upsampled shapes: {(hs, ws)}, {(hu, wu)}"
)
return self.layers(skip) + upsampled
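
# A minimal usage sketch for DecoderBlock (names and shapes are illustrative,
# not part of the original code). The block upsamples `previous` to roughly
# the resolution of `skip`, cropping or padding `skip` by a few pixels when
# odd input sizes make the two shapes disagree:
#
#   block = DecoderBlock(previous=128, out=64)  # `previous` = skip channels
#   prev = torch.rand(1, 64, 16, 16)            # coarser map, `out` channels
#   skip = torch.rand(1, 128, 33, 33)           # odd size: cropped to 32x32
#   feat = block(prev, skip)                    # -> (1, 64, 32, 32)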
class FPN(nn.Module):
    """Top-down feature pyramid decoder.

    Presumably kept from earlier versions of this file; the v3
    FeatureExtractor below builds its own U-Net and does not use it.
    """
def __init__(self, in_channels_list, out_channels, **kw):
super().__init__()
self.first = nn.Conv2d(
in_channels_list[-1], out_channels, 1, padding=0, bias=True
)
self.blocks = nn.ModuleList(
[
DecoderBlock(c, out_channels, ksize=1, **kw)
for c in in_channels_list[::-1][1:]
]
)
self.out = nn.Sequential(
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
def forward(self, layers):
feats = None
for idx, x in enumerate(reversed(layers.values())):
if feats is None:
feats = self.first(x)
else:
feats = self.blocks[idx - 1](feats, x)
out = self.out(feats)
return out
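
# A minimal usage sketch for FPN (shapes are illustrative). The deepest level
# is projected by a 1x1 conv, then each shallower level is merged top-down by
# a DecoderBlock; `layers` mimics the ordered dict that torchvision's
# create_feature_extractor returns:
#
#   fpn = FPN(in_channels_list=[64, 128, 256], out_channels=32)
#   layers = {
#       "layer1": torch.rand(1, 64, 64, 64),
#       "layer2": torch.rand(1, 128, 32, 32),
#       "layer3": torch.rand(1, 256, 16, 16),
#   }
#   out = fpn(layers)  # -> (1, 32, 64, 64), at the shallowest resolution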
def remove_conv_stride(conv):
    """Return a stride-1 copy of `conv` that shares its weight and bias."""
    conv_new = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        bias=conv.bias is not None,
        stride=1,
        padding=conv.padding,
        padding_mode=conv.padding_mode,  # preserve non-default padding modes
    )
    conv_new.weight = conv.weight
    conv_new.bias = conv.bias
    return conv_new
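
# A minimal usage sketch for remove_conv_stride (not called in this file):
# the returned conv shares the original weight and bias tensors, so only the
# stride changes and the output resolution doubles:
#
#   conv = nn.Conv2d(3, 64, 7, stride=2, padding=3)
#   conv_s1 = remove_conv_stride(conv)
#   x = torch.rand(1, 3, 64, 64)
#   conv(x).shape     # -> (1, 64, 32, 32)
#   conv_s1(x).shape  # -> (1, 64, 64, 64)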
class DoubleConv(nn.Module):
    """Two 3x3 conv + ReLU layers (currently unused; the FeatureExtractor
    below uses `conv_block`, which adds batch normalization)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
class FeatureExtractor(BaseModel):
    default_conf = {
        "pretrained": True,
        "input_dim": 3,
        "output_dim": 128,  # number of channels in the output feature maps
        "encoder": "resnet50",  # torchvision net as string (only validated; the U-Net below is built from scratch)
        "remove_stride_from_first_conv": False,
        "num_downsample": None,  # number of downsampling blocks
        "decoder_norm": "nn.BatchNorm2d",  # normalization in decoder blocks
        "do_average_pooling": False,
        "checkpointed": False,  # whether to use gradient checkpointing
    }
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
    def build_encoder(self, conf):
        # Despite its name, this builds both the encoder and the decoder of
        # a plain U-Net; `conf.encoder` and `conf.pretrained` are only
        # validated here, not used to load a torchvision backbone.
        assert isinstance(conf.encoder, str)
        if conf.pretrained:
            assert conf.input_dim == 3
        # Encoder
self.conv1 = self.conv_block(conf.input_dim, 64)
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = self.conv_block(64, 128)
self.pool2 = nn.MaxPool2d(2, 2)
self.conv3 = self.conv_block(128, 256)
self.pool3 = nn.MaxPool2d(2, 2)
self.conv4 = self.conv_block(256, 512)
self.pool4 = nn.MaxPool2d(2, 2)
self.conv5 = self.conv_block(512, 1024)
# Decoder
self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
self.conv6 = self.conv_block(1024, 512)
self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
self.conv7 = self.conv_block(512, 256)
self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
self.conv8 = self.conv_block(256, 128)
self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
self.conv9 = self.conv_block(128, 64)
self.conv10 = nn.Conv2d(64, conf.output_dim, 1)
    def unet(self, x):
        # Encoder: four 2x2 max-pool stages, so the input height and width
        # must be divisible by 16.
conv1 = self.conv1(x)
pool1 = self.pool1(conv1)
conv2 = self.conv2(pool1)
pool2 = self.pool2(conv2)
conv3 = self.conv3(pool2)
pool3 = self.pool3(conv3)
conv4 = self.conv4(pool3)
pool4 = self.pool4(conv4)
conv5 = self.conv5(pool4)
        # Decoder: transposed convolutions upsample back to the input
        # resolution, concatenating the matching encoder map at each stage.
up6 = self.up6(conv5)
concat6 = torch.cat([up6, conv4], dim=1)
conv6 = self.conv6(concat6)
up7 = self.up7(conv6)
concat7 = torch.cat([up7, conv3], dim=1)
conv7 = self.conv7(concat7)
up8 = self.up8(conv7)
concat8 = torch.cat([up8, conv2], dim=1)
conv8 = self.conv8(concat8)
up9 = self.up9(conv8)
concat9 = torch.cat([up9, conv1], dim=1)
conv9 = self.conv9(concat9)
output = self.conv10(conv9)
return output
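
    # A hedged shape trace for a 1x3x256x256 input (illustrative only):
    #   encoder: conv1 @256 -> conv2 @128 -> conv3 @64 -> conv4 @32 -> conv5 @16
    #   decoder: up6 @32 -> up7 @64 -> up8 @128 -> up9 @256
    #   conv10 output: (1, output_dim, 256, 256)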
def conv_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def _init(self, conf):
# Preprocessing
self.register_buffer("mean_", torch.tensor(self.mean), persistent=False)
self.register_buffer("std_", torch.tensor(self.std), persistent=False)
# Encoder
self.build_encoder(conf)
    def _forward(self, data):
        image = data["image"]
        # Normalize with the ImageNet statistics registered in `_init`
        # (broadcast over the spatial dimensions).
        image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]
        output = self.unet(image)
        pred = {"feature_maps": [output]}
        return pred
if __name__ == "__main__":
    # Assumes BaseModel merges `default_conf` with the overrides passed in,
    # so an empty dict selects the defaults.
    model = FeatureExtractor({})
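    # A minimal smoke test, assuming BaseModel dispatches forward() to
    # _forward() as in OrienterNet-style base classes. The 256x256 input is
    # a multiple of 16, as required by the four pooling stages in `unet`.
    pred = model({"image": torch.rand(1, 3, 256, 256)})
    print(pred["feature_maps"][0].shape)  # expected: (1, 128, 256, 256)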