import math

import torch
from torch import nn
import torch.nn.functional as F

from .ops import (ConstantInput, ConvLayer, StyledConv, ToFlow, ToRGB, Direction)


class FlowResBlock(nn.Module):
    """Residual block of two style-modulated convs with a learnable branch scale.

    ``conv1`` maps ``in_channel -> out_channel`` and its output is the skip
    path. The residual branch (GroupNorm -> ``conv2``) is multiplied by the
    per-channel parameter ``gamma``, initialised to 1e-5, so at initialisation
    the block output is almost exactly the ``conv1`` output.
    """

    def __init__(self, in_channel, out_channel, style_dim):
        super().__init__()
        # nn.GroupNorm(32, C) requires out_channel to be divisible by 32.
        self.norm = nn.GroupNorm(32, out_channel)
        self.conv1 = StyledConv(in_channel, out_channel, 3, style_dim, False)
        self.conv2 = StyledConv(out_channel, out_channel, 3, style_dim, False)
        # Per-channel residual scale, broadcast over batch and spatial dims.
        self.gamma = nn.Parameter(1e-5 * torch.ones([1, out_channel, 1, 1]))

    def forward(self, x, style):
        """Apply the block to ``x``, modulating both convs with ``style``."""
        skip = self.conv1(x, style)
        h = self.norm(skip)
        h = self.conv2(h, style)
        return self.gamma * h + skip


class ResBlock(nn.Module):
    """Unmodulated residual block: two 3x3 ConvLayers plus a skip path.

    The skip is a bias-free, non-activated 1x1 ConvLayer when the channel count
    changes, and the identity otherwise.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv1 = ConvLayer(in_channel, out_channel, 3, upsample=False)
        self.conv2 = ConvLayer(out_channel, out_channel, 3, upsample=False)
        if in_channel != out_channel:
            self.skip = ConvLayer(in_channel, out_channel, 1, upsample=False,
                                  activate=False, bias=False)
        else:
            self.skip = torch.nn.Identity()

    def forward(self, x):
        """Return ``(conv2(conv1(x)) + skip(x)) / sqrt(2)``."""
        h = self.conv2(self.conv1(x))
        # Dividing the sum of two branches by sqrt(2) is the usual
        # variance-preserving residual scaling (assumes uncorrelated branches).
        return (h + self.skip(x)) / math.sqrt(2)


def _flow_res_stack(channel, style_dim, depth=4):
    """Return ``depth`` channel-preserving FlowResBlocks as an ``nn.ModuleList``."""
    return nn.ModuleList([FlowResBlock(channel, channel, style_dim) for _ in range(depth)])


def _res_stack(channel, depth=4):
    """Return ``depth`` channel-preserving ResBlocks as an ``nn.ModuleList``."""
    return nn.ModuleList([ResBlock(channel, channel) for _ in range(depth)])


class Decoder(nn.Module):
    """Latent-to-image decoder built from a constant stem and seven flow stages.

    Every stage is made of the following parts:
      * a styled conv created with flag ``True`` (the first stage's input conv);
      * four FlowResBlocks;
      * a ToFlow, which consumes one encoder feature map and, after the first
        stage, the previous stage's flow;
      * (after the first stage) a ConvLayer with ``upsample=True``, whose output
        on the previous warped feature is added to this stage's warped feature;
      * four ResBlocks applied to the warped feature;
      * a ToRGB, which after the first stage also receives the previous rgb.

    Assuming ``StyledConv(..., True)`` doubles H and W, the stages run at
    8, 16, 32, 64, 128, 256 and 512 px starting from the 4x4 constant input
    (TODO confirm against ``ops.StyledConv``).
    """

    def __init__(self, style_dim, motion_dim, scale=1):
        super().__init__()
        channels = [512 * scale, 256 * scale, 128 * scale, 64 * scale]

        # Attribute names below are the state_dict keys and must not change.
        # Keep the construction order as well: it fixes parameter registration
        # order and the order in which random initialisers run.
        self.direction = Direction(style_dim, motion_dim)

        # Stem: learned 4x4 constant, then one styled conv with flag False.
        self.input = ConstantInput(channels[0], size=4)
        self.conv1 = StyledConv(channels[0], channels[0], 3, style_dim, False)

        # Stage "_512" (first stage: no *_up conv, nothing to carry over).
        self.conv_512_1 = StyledConv(channels[0], channels[0], 3, style_dim, True)
        self.conv_512_2 = _flow_res_stack(channels[0], style_dim)
        self.conv_512_2_rgb = _res_stack(channels[0])
        self.rgb_512 = ToRGB(channels[0])
        self.flow_512 = ToFlow(channels[0], style_dim)

        # Stage 2.
        self.conv2_1 = StyledConv(channels[0], channels[0], 3, style_dim, True)
        self.conv2_2 = _flow_res_stack(channels[0], style_dim)
        self.conv2_2_up = ConvLayer(channels[0], channels[0], 3, upsample=True)
        self.conv2_2_rgb = _res_stack(channels[0])
        self.rgb2 = ToRGB(channels[0])
        self.flow2 = ToFlow(channels[0], style_dim)

        # Stage 3.
        self.conv3_1 = StyledConv(channels[0], channels[0], 3, style_dim, True)
        self.conv3_2 = _flow_res_stack(channels[0], style_dim)
        self.conv3_2_up = ConvLayer(channels[0], channels[0], 3, upsample=True)
        self.conv3_2_rgb = _res_stack(channels[0])
        self.rgb3 = ToRGB(channels[0])
        self.flow3 = ToFlow(channels[0], style_dim)

        # Stage 4.
        self.conv4_1 = StyledConv(channels[0], channels[0], 3, style_dim, True)
        self.conv4_2 = _flow_res_stack(channels[0], style_dim)
        self.conv4_2_up = ConvLayer(channels[0], channels[0], 3, upsample=True)
        self.conv4_2_rgb = _res_stack(channels[0])
        self.rgb4 = ToRGB(channels[0])
        self.flow4 = ToFlow(channels[0], style_dim)

        # Stage 5: channels[0] -> channels[1].
        self.conv5_1 = StyledConv(channels[0], channels[1], 3, style_dim, True)
        self.conv5_2 = _flow_res_stack(channels[1], style_dim)
        self.conv5_2_up = ConvLayer(channels[0], channels[1], 3, upsample=True)
        self.conv5_2_rgb = _res_stack(channels[1])
        self.rgb5 = ToRGB(channels[1])
        self.flow5 = ToFlow(channels[1], style_dim)

        # Stage 6: channels[1] -> channels[2].
        self.conv6_1 = StyledConv(channels[1], channels[2], 3, style_dim, True)
        self.conv6_2 = _flow_res_stack(channels[2], style_dim)
        self.conv6_2_up = ConvLayer(channels[1], channels[2], 3, upsample=True)
        self.conv6_2_rgb = _res_stack(channels[2])
        self.rgb6 = ToRGB(channels[2])
        self.flow6 = ToFlow(channels[2], style_dim)

        # Stage 7: channels[2] -> channels[3].
        self.conv7_1 = StyledConv(channels[2], channels[3], 3, style_dim, True)
        self.conv7_2 = _flow_res_stack(channels[3], style_dim)
        self.conv7_2_up = ConvLayer(channels[2], channels[3], 3, upsample=True)
        self.conv7_2_rgb = _res_stack(channels[3])
        self.rgb7 = ToRGB(channels[3])
        self.flow7 = ToFlow(channels[3], style_dim)

    def navigation(self, z_s2r, alpha):
        """Shift the latent ``z_s2r`` along motion directions from ``self.direction``.

        Args:
            z_s2r: base latent code.
            alpha: ``None`` (no shift), a 1-item sequence ``[a_target]``, or a
                sequence of at least 3 items ``[a_target, a_source, a_start]``
                (items beyond the third are ignored). Each item is mapped to a
                direction ``d(.)`` by ``self.direction``.

        Returns:
            ``z_s2r`` when ``alpha`` is None; ``z_s2r + d(a_target)`` for one
            item; ``z_s2r + (d(a_target) - d(a_start)) + d(a_source)`` otherwise.

        Raises:
            ValueError: if ``alpha`` has 0 or 2 items. The old ``len > 1``
                check let a 2-item input reach ``alpha[2]`` and fail with an
                opaque IndexError.
        """
        if alpha is None:
            return z_s2r

        n_alpha = len(alpha)
        if n_alpha == 1:
            return z_s2r + self.direction(alpha[0])
        if n_alpha >= 3:
            z_r2t = self.direction(alpha[0])  # target
            z_r2s = self.direction(alpha[1])  # source
            z_start = self.direction(alpha[2])  # start
            # Add the target-minus-start displacement, then the source direction.
            return z_s2r + (z_r2t - z_start) + z_r2s
        raise ValueError(
            f"alpha must contain 1 item or at least 3 items "
            f"(target, source, start); got {n_alpha}")

    def apply_flow(self, h, mask, flow, feat):
        """Warp ``feat`` with ``flow`` and blend it into ``h`` using ``mask``.

        ``flow`` is passed as the ``grid`` of ``F.grid_sample``. For a 4-D
        ``feat`` that means shape (N, H_out, W_out, 2) with coordinates
        normalised to [-1, 1]. ``forward`` does not call this method.

        Returns:
            ``(feat_warp, h_new)`` with ``feat_warp = grid_sample(feat, flow) * mask``
            and ``h_new = feat_warp + (1 - mask) * h``.
        """
        # align_corners=False is PyTorch's default. Passing it explicitly gives
        # the same result and silences the "default behavior has changed" warning.
        feat_warp = F.grid_sample(feat, flow, align_corners=False) * mask
        h = feat_warp + (1 - mask) * h
        return feat_warp, h

    def forward(self, z_s2r, alpha, feats):
        """Decode the navigated latent into the final rgb output.

        Args:
            z_s2r: base latent code, presumably (bs, style_dim) - TODO confirm.
            alpha: motion inputs forwarded to ``navigation`` (None, or a
                sequence with 1 or >= 3 items).
            feats: indexable collection of at least 7 feature maps.
                ``feats[i]`` is given to the ToFlow of stage ``i`` (stage
                ``_512`` first, stage 7 last).

        Returns:
            The output of the last stage's ToRGB (``rgb7``).
        """
        z_s2t = self.navigation(z_s2r, alpha)

        h = self.input(z_s2t)
        h = self.conv1(h, z_s2t)

        # Per stage:
        # (input styled conv, FlowResBlocks, ToFlow, warp upsampler, ResBlocks, ToRGB).
        stages = (
            (self.conv_512_1, self.conv_512_2, self.flow_512, None,
             self.conv_512_2_rgb, self.rgb_512),
            (self.conv2_1, self.conv2_2, self.flow2, self.conv2_2_up,
             self.conv2_2_rgb, self.rgb2),
            (self.conv3_1, self.conv3_2, self.flow3, self.conv3_2_up,
             self.conv3_2_rgb, self.rgb3),
            (self.conv4_1, self.conv4_2, self.flow4, self.conv4_2_up,
             self.conv4_2_rgb, self.rgb4),
            (self.conv5_1, self.conv5_2, self.flow5, self.conv5_2_up,
             self.conv5_2_rgb, self.rgb5),
            (self.conv6_1, self.conv6_2, self.flow6, self.conv6_2_up,
             self.conv6_2_rgb, self.rgb6),
            (self.conv7_1, self.conv7_2, self.flow7, self.conv7_2_up,
             self.conv7_2_rgb, self.rgb7),
        )

        h_warp = rgb = h_flow = None
        for i, (conv_in, flow_blocks, to_flow, warp_up, rgb_blocks, to_rgb) in enumerate(stages):
            h = conv_in(h, z_s2t)
            for block in flow_blocks:
                h = block(h, z_s2t)

            if i == 0:
                # First stage: no previous flow / warped feature to carry over.
                h_warp, h, h_flow = to_flow(h, z_s2t, feats[0])
            else:
                prev_warp = h_warp
                h_warp, h, h_flow = to_flow(h, z_s2t, feats[i], h_flow)
                h_warp = h_warp + warp_up(prev_warp)

            for block in rgb_blocks:
                h_warp = block(h_warp)
            rgb = to_rgb(h_warp) if i == 0 else to_rgb(h_warp, rgb)

        return rgb