Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
class UNet(nn.Module):
    """U-Net style encoder-decoder that maps a 3-channel image to 18 channels.

    The encoder has four convolution stages (64, 128, 256 and 512 channels),
    each followed by a 2x2 max-pool, then a 1024-channel bottleneck. The
    decoder mirrors it with four "bilinear upsample + 3x3 conv" steps; before
    each decoder stage the upsampled map is concatenated channel-wise with the
    output of the matching encoder stage (skip connection).

    Input is a batched tensor ``(N, 3, H, W)`` and the output is
    ``(N, 18, H, W)``. ``H`` and ``W`` must be at least 16 so that four
    poolings leave a non-empty feature map. They do not need to be multiples
    of 16: when pooling floors an odd size, the upsampled map is resized to
    the skip map before concatenation. In training mode BatchNorm additionally
    requires more than one value per channel at the bottleneck.

    Attribute names and the layer order inside every ``nn.Sequential`` are
    part of the ``state_dict`` key layout and must stay stable so existing
    checkpoints keep loading.
    """

    def __init__(self) -> None:
        """Build all encoder, bottleneck and decoder layers."""
        super().__init__()
        # Per-channel statistics kept as plain attributes; nothing in this
        # class applies them. NOTE(review): values look like the usual
        # ImageNet mean/std, presumably for use by callers - confirm.
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

        # Modules are created in the same order as before so that parameter
        # ordering and seeded initialisation stay reproducible.

        # Downsampler
        self.enc_conv0 = self._conv_stage(3, 64)
        self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv1 = self._conv_stage(64, 128)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv2 = self._conv_stage(128, 256)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv3 = self._conv_stage(256, 512)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # bottleneck
        self.bottleneck_conv = self._conv_stage(512, 1024)

        # Upsampler: every decoder stage receives twice the upsampler's output
        # channels because the skip connection is concatenated on dim 1.
        self.upsample0 = self._upsampler(1024, 512)
        self.dec_conv0 = self._conv_stage(1024, 512)
        self.upsample1 = self._upsampler(512, 256)
        self.dec_conv1 = self._conv_stage(512, 256)
        self.upsample2 = self._upsampler(256, 128)
        self.dec_conv2 = self._conv_stage(256, 128)
        self.upsample3 = self._upsampler(128, 64)
        # Last stage: two conv/activation/norm triples followed by a 1x1
        # projection to the 18 output channels (no activation or norm after).
        self.dec_conv3 = nn.Sequential(
            *self._conv_stage(128, 64, num_convs=2),
            nn.Conv2d(in_channels=64, out_channels=18, kernel_size=1, stride=1, padding=0),
        )

    @staticmethod
    def _conv_stage(in_channels: int, out_channels: int, num_convs: int = 3) -> nn.Sequential:
        """Return ``num_convs`` repeats of 3x3 Conv -> LeakyReLU -> BatchNorm."""
        layers = []
        channels = in_channels
        for _ in range(num_convs):
            layers.append(
                nn.Conv2d(in_channels=channels, out_channels=out_channels,
                          kernel_size=3, stride=1, padding=1)
            )
            layers.append(nn.LeakyReLU(inplace=True))
            layers.append(nn.BatchNorm2d(out_channels))
            # Only the first convolution of a stage changes the channel count.
            channels = out_channels
        return nn.Sequential(*layers)

    @staticmethod
    def _upsampler(in_channels: int, out_channels: int) -> nn.Sequential:
        """Return a 2x bilinear upsample followed by a 3x3 convolution."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, padding=1),
        )

    @staticmethod
    def _merge(up: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Concatenate ``up`` with ``skip`` on the channel axis.

        Max-pooling floors odd sizes, so a 2x upsample can come out smaller
        than the skip map. In that case ``up`` is bilinearly resized to the
        skip map's height and width first. When the sizes already agree the
        tensor is passed through untouched.
        """
        if up.shape[-2:] != skip.shape[-2:]:
            up = nn.functional.interpolate(
                up, size=tuple(skip.shape[-2:]), mode='bilinear', align_corners=True
            )
        return torch.cat([up, skip], dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network.

        Args:
            x: Input tensor of shape ``(N, 3, H, W)`` with ``H, W >= 16``.

        Returns:
            Tensor of shape ``(N, 18, H, W)``; the last layer is a plain
            convolution, so values are unnormalised.
        """
        # encoder
        e0 = self.enc_conv0(x)
        e1 = self.enc_conv1(self.pool0(e0))
        e2 = self.enc_conv2(self.pool1(e1))
        e3 = self.enc_conv3(self.pool2(e2))

        # bottleneck
        b = self.bottleneck_conv(self.pool3(e3))

        # decoder: upsample, merge with the matching encoder output, refine
        d0 = self.dec_conv0(self._merge(self.upsample0(b), e3))
        d1 = self.dec_conv1(self._merge(self.upsample1(d0), e2))
        d2 = self.dec_conv2(self._merge(self.upsample2(d1), e1))
        d3 = self.dec_conv3(self._merge(self.upsample3(d2), e0))
        return d3