Spaces:
Sleeping
Sleeping
import torch | |
from . import initialization as init | |
from .hub_mixin import SMPHubMixin | |
import torch.nn as nn | |
class SegmentationModel(torch.nn.Module, SMPHubMixin):
    """Base class for encoder-decoder segmentation models.

    This class defines no ``__init__``; subclasses are expected to set the
    attributes the methods below read:

    * ``encoder`` -- called on the input and exposing an ``output_stride``
      attribute; its outputs are unpacked as positional arguments to the
      decoder.
    * ``decoder`` -- module that combines the encoder outputs.
    * ``segmentation_head`` -- module applied to the decoder output.
    * ``classification_head`` -- optional module (may be ``None``). It is
      initialized by :meth:`initialize` but is NOT applied in :meth:`forward`.
    """

    def initialize(self):
        """Initialize decoder and head weights using the ``initialization`` module.

        ``classification_head`` is skipped when it is ``None``.
        """
        init.initialize_decoder(self.decoder)
        init.initialize_head(self.segmentation_head)
        if self.classification_head is not None:
            init.initialize_head(self.classification_head)

    def check_input_shape(self, x):
        """Check that the spatial size of ``x`` is divisible by the encoder's output stride.

        Args:
            x: tensor whose last two dimensions are (height, width).

        Raises:
            RuntimeError: if height or width is not a multiple of
                ``self.encoder.output_stride``. The message suggests the
                smallest padded shape that would be accepted.
        """
        h, w = x.shape[-2:]
        output_stride = self.encoder.output_stride
        if h % output_stride == 0 and w % output_stride == 0:
            return

        # Round each side up to the next multiple of the stride; a side that is
        # already a multiple stays unchanged.
        new_h = (h + output_stride - 1) // output_stride * output_stride
        new_w = (w + output_stride - 1) // output_stride * output_stride
        raise RuntimeError(
            f"Wrong input shape height={h}, width={w}. Expected image height and width "
            f"divisible by {output_stride}. Consider pad your images to shape ({new_h}, {new_w})."
        )

    def forward(self, x):
        """Pass ``x`` through the model's encoder, decoder and segmentation head in turn.

        Args:
            x: input tensor; its height and width must be divisible by
                ``self.encoder.output_stride`` (see :meth:`check_input_shape`).

        Returns:
            The output of ``segmentation_head``. ``classification_head`` is not
            applied here, so this single tensor is the only output.

        Raises:
            RuntimeError: propagated from :meth:`check_input_shape`.
        """
        self.check_input_shape(x)
        features = self.encoder(x)
        decoder_output = self.decoder(*features)
        return self.segmentation_head(decoder_output)

    def predict(self, x):
        """Inference method: switch to ``eval`` mode and call ``.forward(x)`` under ``torch.no_grad()``.

        The model is left in ``eval`` mode afterwards; call ``.train()`` to
        resume training.

        Args:
            x: 4D torch tensor with shape (batch_size, channels, height, width)

        Returns:
            prediction: 4D torch tensor with shape (batch_size, classes, height, width)
        """
        if self.training:
            self.eval()
        # Inference needs no gradients. Disabling autograd avoids recording a
        # computation graph (and the activation memory it holds), and the
        # returned tensor is not attached to any graph.
        with torch.no_grad():
            return self.forward(x)