Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True):
    """Change the number of input channels of the first matching ``nn.Conv2d``.

    The first ``nn.Conv2d`` yielded by ``model.modules()`` (``model`` itself
    included) whose ``in_channels`` equals ``default_in_channels`` is patched
    in place:

    * ``pretrained=False``: the weight is re-allocated for the new input
      channel count and re-initialised with ``Conv2d.reset_parameters()``.
    * ``new_in_channels == 1``: the existing kernels are summed over the
      input-channel axis (dim 1 of the conv weight).
    * otherwise: the original input channels are cycled
      (``new[:, i] = old[:, i % default_in_channels]``) and the result is
      rescaled by ``default_in_channels / new_in_channels``.

    The new weight keeps the dtype and device of the original weight.

    Args:
        model: module to search; patched in place.
        new_in_channels: desired number of input channels.
        default_in_channels: ``in_channels`` value identifying the conv to patch.
        pretrained: if True, derive the new weight from the existing one;
            otherwise re-initialise it.

    Raises:
        ValueError: if no ``nn.Conv2d`` with ``in_channels == default_in_channels``
            exists in ``model``. Nothing is modified in that case.
    """
    # Locate the first matching conv. The `else` runs only if the loop never
    # breaks; without it, `module` would silently be the last module visited.
    for module in model.modules():
        if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels:
            break
    else:
        raise ValueError(
            f"No nn.Conv2d with in_channels={default_in_channels} found in model"
        )

    weight = module.weight.detach()
    module.in_channels = new_in_channels

    if not pretrained:
        # new_empty keeps the original dtype/device; the (uninitialised) values
        # are overwritten by reset_parameters() right after.
        module.weight = nn.parameter.Parameter(
            weight.new_empty(
                module.out_channels,
                new_in_channels // module.groups,
                *module.kernel_size,
            )
        )
        module.reset_parameters()
    elif new_in_channels == 1:
        new_weight = weight.sum(1, keepdim=True)
        module.weight = nn.parameter.Parameter(new_weight)
    else:
        # Gather input channel i from original channel i % default_in_channels
        # (assumes groups == 1, as the per-channel cycling always did), then
        # rescale.
        idx = torch.arange(new_in_channels, device=weight.device) % default_in_channels
        new_weight = weight[:, idx] * (default_in_channels / new_in_channels)
        module.weight = nn.parameter.Parameter(new_weight)
def replace_strides_with_dilation(module, dilation_rate):
    """Patch every ``nn.Conv2d`` in ``module`` to use dilation instead of stride.

    Each conv (including ``module`` itself if it is one) is modified in place
    to use:

    * ``stride = (1, 1)``
    * ``dilation = (dilation_rate, dilation_rate)``
    * ``padding = ((kh // 2) * dilation_rate, (kw // 2) * dilation_rate)``

    With this padding, odd-sized kernels keep the spatial size unchanged.

    Args:
        module: root module whose Conv2d submodules are patched in place.
        dilation_rate: dilation applied to both spatial dimensions.
    """
    for mod in module.modules():
        if not isinstance(mod, nn.Conv2d):
            continue

        mod.stride = (1, 1)
        mod.dilation = (dilation_rate, dilation_rate)
        kh, kw = mod.kernel_size
        # Height and width padding are computed independently so non-square
        # kernels (e.g. 1x7) are padded correctly on each axis.
        mod.padding = ((kh // 2) * dilation_rate, (kw // 2) * dilation_rate)
        # NOTE(review): for padding_mode != "zeros", PyTorch pads using a value
        # cached when the conv was constructed, so this assignment may not take
        # effect for such convs. Confirm if they occur in practice.

        # Hack for EfficientNet: some conv implementations carry their own
        # `static_padding` submodule. It is replaced with Identity, presumably
        # because padding is now handled by `mod.padding` above.
        if hasattr(mod, "static_padding"):
            mod.static_padding = nn.Identity()