import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Basic residual block: two 3x3 conv + BatchNorm layers plus an additive shortcut.

    Spatial size is preserved (3x3 kernels, padding 1, stride 1). When the channel
    count changes, the shortcut is a 1x1 conv + BatchNorm projection so it can be
    added to the main path; otherwise it is the identity.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Each conv feeds straight into BatchNorm (which has its own learnable shift),
        # so the conv bias would be redundant.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if in_channels != out_channels:
            # Projection shortcut: match channel count so the residual can be added.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        residual = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return self.relu(out)


class ComplexUNet(nn.Module):
    """Four-level U-Net built from ``ResidualBlock``s, mapping 3 channels to 3 channels.

    The encoder widths are ``c, 2c, 4c, 8c``, the bottleneck is ``16c``, and the
    decoder mirrors the encoder with skip connections concatenated on the channel
    axis. The output goes through a sigmoid, so every value lies in [0, 1].

    Input is expected as ``(N, 3, H, W)``. H and W need not be divisible by 16,
    because each upsampling step is sized to its skip connection. Both must be at
    least 16, though: four 2x2 max-pools have to leave at least one pixel.

    Args:
        base_channels: width ``c`` of the first encoder stage. The default of 96 is
            noted in the original code as the trained architecture, so keep it when
            loading those weights.
    """

    def __init__(self, base_channels=96):  # Default to the trained architecture
        super().__init__()
        c = base_channels
        self.pool = nn.MaxPool2d(2, 2)

        # Encoder
        self.enc1 = ResidualBlock(3, c)
        self.enc2 = ResidualBlock(c, c * 2)
        self.enc3 = ResidualBlock(c * 2, c * 4)
        self.enc4 = ResidualBlock(c * 4, c * 8)
        self.bottleneck = ResidualBlock(c * 8, c * 16)

        # Decoder upsamplers (each doubles spatial size and halves channels)
        self.upconv1 = nn.ConvTranspose2d(c * 16, c * 8, kernel_size=2, stride=2)
        self.upconv2 = nn.ConvTranspose2d(c * 8, c * 4, kernel_size=2, stride=2)
        self.upconv3 = nn.ConvTranspose2d(c * 4, c * 2, kernel_size=2, stride=2)
        self.upconv4 = nn.ConvTranspose2d(c * 2, c, kernel_size=2, stride=2)

        # Decoder blocks: input is [upsampled, skip] concatenated, hence doubled channels
        self.dec_conv1 = ResidualBlock(c * 16, c * 8)
        self.dec_conv2 = ResidualBlock(c * 8, c * 4)
        self.dec_conv3 = ResidualBlock(c * 4, c * 2)
        self.dec_conv4 = ResidualBlock(c * 2, c)

        self.final_conv = nn.Conv2d(c, 3, kernel_size=1)

    def _up_and_merge(self, upconv, dec_conv, x, skip):
        """Upsample ``x``, concatenate ``[up, skip]`` on channels, then apply ``dec_conv``.

        ``output_size`` makes the transposed conv emit exactly the skip's spatial
        size. With kernel 2 / stride 2 it would otherwise produce ``2*floor(s/2)``,
        which is one pixel short whenever the encoder pooled an odd dimension, and
        ``torch.cat`` would then fail. When sizes already match, this is the same
        computation as ``upconv(x)``.
        """
        up = upconv(x, output_size=list(skip.shape[-2:]))
        return dec_conv(torch.cat([up, skip], dim=1))

    def forward(self, x):
        """Run the U-Net on ``x`` and return a sigmoid-activated 3-channel map of the same H, W."""
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        b = self.bottleneck(self.pool(e4))

        d1 = self._up_and_merge(self.upconv1, self.dec_conv1, b, e4)
        d2 = self._up_and_merge(self.upconv2, self.dec_conv2, d1, e3)
        d3 = self._up_and_merge(self.upconv3, self.dec_conv3, d2, e2)
        d4 = self._up_and_merge(self.upconv4, self.dec_conv4, d3, e1)

        out = self.final_conv(d4)
        return torch.sigmoid(out)