Spaces:
Sleeping
Sleeping
import timm | |
import functools | |
import torch.utils.model_zoo as model_zoo | |
from .resnet import resnet_encoders | |
from .dpn import dpn_encoders | |
from .vgg import vgg_encoders | |
from .senet import senet_encoders | |
from .densenet import densenet_encoders | |
from .inceptionresnetv2 import inceptionresnetv2_encoders | |
from .inceptionv4 import inceptionv4_encoders | |
from .efficientnet import efficient_net_encoders | |
from .mobilenet import mobilenet_encoders | |
from .xception import xception_encoders | |
from .timm_efficientnet import timm_efficientnet_encoders | |
from .timm_resnest import timm_resnest_encoders | |
from .timm_res2net import timm_res2net_encoders | |
from .timm_regnet import timm_regnet_encoders | |
from .timm_sknet import timm_sknet_encoders | |
from .timm_mobilenetv3 import timm_mobilenetv3_encoders | |
from .timm_gernet import timm_gernet_encoders | |
from .mix_transformer import mix_transformer_encoders | |
from .mobileone import mobileone_encoders | |
from .timm_universal import TimmUniversalEncoder | |
from ._preprocessing import preprocess_input | |
# Merge every family-specific registry into one ``name -> entry`` mapping.
# A name that appears in several registries takes the value from the last one
# listed here while keeping the position where it was first seen -- the same
# result a sequence of ``dict.update`` calls would produce.
encoders = {
    **resnet_encoders,
    **dpn_encoders,
    **vgg_encoders,
    **senet_encoders,
    **densenet_encoders,
    **inceptionresnetv2_encoders,
    **inceptionv4_encoders,
    **efficient_net_encoders,
    **mobilenet_encoders,
    **xception_encoders,
    **timm_efficientnet_encoders,
    **timm_resnest_encoders,
    **timm_res2net_encoders,
    **timm_regnet_encoders,
    **timm_sknet_encoders,
    **timm_mobilenetv3_encoders,
    **timm_gernet_encoders,
    **mix_transformer_encoders,
    **mobileone_encoders,
}
def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
    """Build an encoder by name, optionally loading pretrained weights.

    Names prefixed with ``"tu-"`` are handed to ``TimmUniversalEncoder`` with the
    prefix stripped. Any other name must be a key of the module-level
    ``encoders`` registry.

    Args:
        name: Registry key, or ``"tu-<timm model name>"``.
        in_channels: Number of input channels. Passed to ``set_in_channels`` for
            registry encoders, or to ``TimmUniversalEncoder``.
        depth: Encoder depth, passed to the encoder constructor.
        weights: For registry encoders, a key into the entry's
            ``"pretrained_settings"``; that setting's ``"url"`` is loaded via
            ``model_zoo.load_url``. ``None`` skips loading. For ``"tu-"``
            encoders only ``weights is not None`` is used (as ``pretrained``).
        output_stride: Forwarded to ``TimmUniversalEncoder``; for registry
            encoders, ``make_dilated(output_stride)`` is called when it is not 32.
        **kwargs: Forwarded to ``TimmUniversalEncoder`` only; ignored otherwise.

    Returns:
        The constructed encoder.

    Raises:
        KeyError: If ``name`` is not a registered encoder, or ``weights`` is not
            one of that encoder's pretrained settings.
    """
    if name.startswith("tu-"):
        return TimmUniversalEncoder(
            name=name[3:],
            in_channels=in_channels,
            depth=depth,
            output_stride=output_stride,
            pretrained=weights is not None,
            **kwargs,
        )

    try:
        entry = encoders[name]
    except KeyError:
        raise KeyError(
            "Wrong encoder name `{}`, supported encoders: {}".format(
                name, list(encoders.keys())
            )
        ) from None

    Encoder = entry["encoder"]
    # Work on a copy: updating the registry's own params dict would leak this
    # call's ``depth`` into the shared module-level ``encoders`` mapping.
    params = dict(entry["params"])
    params["depth"] = depth
    encoder = Encoder(**params)

    if weights is not None:
        # ``.get`` so an entry without pretrained settings still produces the
        # descriptive KeyError below rather than a bare one from the handler.
        pretrained_settings = entry.get("pretrained_settings", {})
        try:
            settings = pretrained_settings[weights]
        except KeyError:
            raise KeyError(
                "Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
                    weights, name, list(pretrained_settings.keys())
                )
            ) from None
        encoder.load_state_dict(model_zoo.load_url(settings["url"]))

    encoder.set_in_channels(in_channels, pretrained=weights is not None)
    if output_stride != 32:
        encoder.make_dilated(output_stride)
    return encoder
def get_encoder_names():
    """Return a new list of every name registered in ``encoders``, in registry order."""
    return [encoder_name for encoder_name in encoders]
def get_preprocessing_params(encoder_name, pretrained="imagenet"):
    """Return the input-preprocessing parameters for an encoder.

    Args:
        encoder_name: Registry key, or ``"tu-<timm model name>"`` to read the
            pretrained config from timm.
        pretrained: Key into the registry entry's ``"pretrained_settings"``.
            Not used for ``"tu-"`` encoders.

    Returns:
        dict with keys ``"input_space"`` (default ``"RGB"``), ``"input_range"``
        (default ``[0, 1]``), ``"mean"`` and ``"std"``. The last three are
        returned as lists.

    Raises:
        ValueError: If a timm model has no pretrained weights, or ``pretrained``
            is not an available option for a registry encoder.
        KeyError: If a non-``"tu-"`` ``encoder_name`` is not registered.
    """
    if encoder_name.startswith("tu-"):
        encoder_name = encoder_name[3:]
        if not timm.models.is_model_pretrained(encoder_name):
            raise ValueError(
                f"{encoder_name} does not have pretrained weights and preprocessing parameters"
            )
        settings = timm.models.get_pretrained_cfg(encoder_name).__dict__
    else:
        if encoder_name not in encoders:
            # Match get_encoder's message instead of a bare KeyError(encoder_name).
            raise KeyError(
                "Wrong encoder name `{}`, supported encoders: {}".format(
                    encoder_name, list(encoders.keys())
                )
            )
        all_settings = encoders[encoder_name]["pretrained_settings"]
        if pretrained not in all_settings:
            raise ValueError(
                "Available pretrained options {}".format(all_settings.keys())
            )
        settings = all_settings[pretrained]

    return {
        "input_space": settings.get("input_space", "RGB"),
        "input_range": list(settings.get("input_range", [0, 1])),
        "mean": list(settings["mean"]),
        "std": list(settings["std"]),
    }
def get_preprocessing_fn(encoder_name, pretrained="imagenet"):
    """Return ``preprocess_input`` with this encoder's preprocessing parameters bound.

    The parameters come from ``get_preprocessing_params(encoder_name,
    pretrained=pretrained)`` and are bound as keyword arguments via
    ``functools.partial``.
    """
    return functools.partial(
        preprocess_input,
        **get_preprocessing_params(encoder_name, pretrained=pretrained),
    )