Spaces:
Sleeping
Sleeping
from . import encoders | |
from . import decoders | |
from .decoders.unet import Unet | |
from .decoders.unetplusplus import UnetPlusPlus | |
from .decoders.manet import MAnet | |
from .decoders.linknet import Linknet | |
from .decoders.fpn import FPN | |
from .decoders.lightfpn import LightFPN | |
from .decoders.pspnet import PSPNet | |
from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus | |
from .decoders.pan import PAN | |
from .base.hub_mixin import from_pretrained | |
from .__version__ import __version__ | |
# some private imports for create_model function | |
from typing import Optional as _Optional | |
import torch as _torch | |
def create_model(
    arch: str,
    encoder_name: str = "resnet34",
    encoder_weights: _Optional[str] = "imagenet",
    in_channels: int = 3,
    classes: int = 1,
    **kwargs,
) -> _torch.nn.Module:
    """Models entrypoint, allows to create any model architecture just with
    parameters, without using its class.

    Args:
        arch: Architecture name, matched case-insensitively against the class
            names of the supported models (e.g. ``"unet"``, ``"FPN"``,
            ``"deeplabv3plus"``).
        encoder_name: Forwarded unchanged to the model constructor.
        encoder_weights: Forwarded unchanged to the model constructor; the
            type hint allows ``None`` (its meaning is up to the model class).
        in_channels: Forwarded unchanged to the model constructor.
        classes: Forwarded unchanged to the model constructor.
        **kwargs: Any further keyword arguments, forwarded unchanged.

    Returns:
        An instance of the model class selected by ``arch``.

    Raises:
        KeyError: If ``arch`` does not name a supported architecture; the
            message lists the accepted (lower-cased) names.
    """
    archs = [
        Unet,
        UnetPlusPlus,
        MAnet,
        Linknet,
        FPN,
        LightFPN,
        PSPNet,
        DeepLabV3,
        DeepLabV3Plus,
        PAN,
    ]
    # Keys are lower-cased class names, which makes the lookup below
    # case-insensitive with respect to ``arch``.
    archs_dict = {a.__name__.lower(): a for a in archs}
    try:
        model_class = archs_dict[arch.lower()]
    except KeyError:
        # ``from None`` drops the internal dict-lookup KeyError from the
        # traceback; the error raised here already names the bad value and
        # the valid options.
        raise KeyError(
            "Wrong architecture type `{}`. Available options are: {}".format(
                arch, list(archs_dict.keys())
            )
        ) from None
    return model_class(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
        **kwargs,
    )
# Public API of the package. The groups below are concatenated into one flat
# list of names, in the order that ``from <package> import *`` will export them.
__all__ = (
    # submodules
    ["encoders", "decoders"]
    # segmentation model classes
    + [
        "Unet",
        "UnetPlusPlus",
        "MAnet",
        "Linknet",
        "FPN",
        "LightFPN",
        "PSPNet",
        "DeepLabV3",
        "DeepLabV3Plus",
        "PAN",
    ]
    # helpers and package metadata
    + ["from_pretrained", "create_model", "__version__"]
)