import logging

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

import feature_extractor_models as smp

from .base import BaseModel

logger = logging.getLogger(__name__)


class FeatureExtractor(BaseModel):
    """Dense image feature extractor built on segmentation-style networks.

    Normalizes the input image with ImageNet statistics, then runs one of
    the decoder architectures from ``feature_extractor_models`` (FPN,
    LightFPN, PSPNet) to produce a feature map with ``output_dim`` channels.
    """

    default_conf = {
        "pretrained": True,
        "input_dim": 3,
        "output_dim": 128,  # number of channels in output feature maps
        "encoder": "resnet50",  # torchvision net as string
        "remove_stride_from_first_conv": False,
        "num_downsample": None,  # how many downsample blocks
        "decoder_norm": "nn.BatchNorm2d",  # normalization in decoder blocks
        "do_average_pooling": False,
        "checkpointed": False,  # whether to use gradient checkpointing
        "architecture": "FPN",
    }

    # ImageNet normalization statistics (RGB channel order).
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def build_encoder(self, conf):
        """Validate encoder-related configuration.

        NOTE(review): this only checks the config; the network itself is
        constructed in ``_init`` via ``feature_extractor_models``.
        """
        assert isinstance(conf.encoder, str)
        if conf.pretrained:
            # Pretrained ImageNet weights require 3-channel (RGB) input.
            assert conf.input_dim == 3

    def _init(self, conf):
        # Preprocessing buffers; excluded from checkpoints (persistent=False).
        self.register_buffer("mean_", torch.tensor(self.mean), persistent=False)
        self.register_buffer("std_", torch.tensor(self.std), persistent=False)

        # Kwargs shared by all supported architectures.
        # Fix: honor conf.pretrained instead of always loading ImageNet
        # weights (previously encoder_weights was hard-coded to "imagenet").
        common = dict(
            encoder_name=conf.encoder,  # e.g. mobilenet_v2 or efficientnet-b7
            encoder_weights="imagenet" if conf.pretrained else None,
            in_channels=conf.input_dim,  # 1 for gray-scale, 3 for RGB, etc.
            classes=conf.output_dim,  # channels of the output feature map
            activation=None,
        )
        if conf.architecture == "FPN":
            # upsampling=2: final output upsampling factor (library default is 8).
            self.fmodel = smp.FPN(upsampling=2, **common)
        elif conf.architecture == "LightFPN":
            # NOTE(review): `smp.L` looks like a truncated attribute name —
            # presumably a LightFPN class in feature_extractor_models;
            # confirm against that module before relying on this branch.
            self.fmodel = smp.L(upsampling=2, **common)
        elif conf.architecture == "PSP":
            self.fmodel = smp.PSPNet(upsampling=4, **common)
        else:
            # Fix: previous message ("Only FPN") omitted the other
            # supported architectures.
            raise ValueError(
                f"Unknown architecture {conf.architecture!r}; "
                "expected one of 'FPN', 'LightFPN', 'PSP'."
            )

    def _forward(self, data):
        """Normalize ``data['image']`` and return dense feature maps.

        Assumes the image tensor is (..., C, H, W) with C matching
        ``input_dim`` — the normalization broadcasts over the channel axis.
        """
        image = data["image"]
        image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]
        output = self.fmodel(image)
        return {"feature_maps": [output]}


if __name__ == '__main__':
    model = FeatureExtractor()