import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of timesteps.

    Maps a 1-D tensor ``t`` of shape ``[B]`` to ``[B, dim]`` by concatenating
    ``sin`` and ``cos`` of ``t`` scaled by a geometric sequence of
    frequencies running from ``1`` down to ``1 / max_period``.

    Args:
        dim: Output embedding width. Odd widths are supported by zero-padding
            the last channel, so the output width always equals ``dim``.
        max_period: Sets the lowest frequency (default 10000, matching the
            previously hard-coded constant).
    """

    def __init__(self, dim, max_period=10000):
        super().__init__()
        self.dim = dim
        half_dim = dim // 2
        # max(..., 1) guards half_dim == 1 (dim 2 or 3): dividing by zero
        # there produced a NaN frequency buffer and NaN embeddings.
        scale = math.log(max_period) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -scale)
        # Buffer name/shape unchanged so existing state_dicts still load.
        self.register_buffer('emb', emb)

    def forward(self, t):
        """Embed timesteps.

        Args:
            t: 1-D tensor of timesteps, shape ``[B]`` (cast to float).

        Returns:
            Tensor of shape ``[B, dim]``: ``[sin(t*f), cos(t*f)]`` (+ a zero
            column when ``dim`` is odd).
        """
        emb = t.float()[:, None] * self.emb[None, :]
        emb = torch.cat((torch.sin(emb), torch.cos(emb)), dim=-1)
        if self.dim % 2:
            # Odd dim: 2 * (dim // 2) == dim - 1 columns; pad to match the
            # nn.Linear(dim, ...) layers that consume this embedding.
            emb = F.pad(emb, (0, 1))
        return emb


class Block(nn.Module):
    """Conv (or transposed conv) -> add projected time embedding -> GroupNorm -> SiLU.

    Args:
        in_ch: Input channels.
        out_ch: Output channels.
        time_emb_dim: Width of the time embedding passed to ``forward``.
        up: If True, use a ``ConvTranspose2d`` (kernel 4, stride 2, padding 1),
            which doubles H and W; otherwise a 3x3 ``Conv2d`` with padding 1,
            which keeps H and W.
        groups: Preferred number of GroupNorm groups (default 8). GroupNorm
            needs ``out_ch % num_groups == 0``, so ``gcd(groups, out_ch)`` is
            used; this equals ``groups`` whenever ``out_ch`` is divisible by it.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False, groups=8):
        super().__init__()
        # Submodule registration order kept as before (parameter ordering).
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        else:
            self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm = nn.GroupNorm(math.gcd(groups, out_ch), out_ch)
        self.act = nn.SiLU()

    def forward(self, x, t):
        """Apply the block.

        Args:
            x: Feature map ``[B, in_ch, H, W]``.
            t: Time embedding ``[B, time_emb_dim]``.

        Returns:
            ``[B, out_ch, H', W']`` where H', W' are H, W (or doubled if ``up``).
        """
        h = self.conv(x)
        time_emb = self.time_mlp(t)
        # Broadcast the per-sample, per-channel bias over the spatial dims.
        h = h + time_emb[:, :, None, None]
        h = self.norm(h)
        h = self.act(h)
        return h


class SmoothDiffusionUNet(nn.Module):
    """Three-level U-Net conditioned on a sinusoidal time embedding.

    ``config`` must provide ``in_channels``, ``base_channels`` and
    ``time_emb_dim``. Below, C denotes ``config.base_channels``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        C = config.base_channels
        # Time embedding (sinusoidal only; attribute name kept for state_dict
        # compatibility even though it is not an MLP).
        self.time_mlp = TimeEmbedding(config.time_emb_dim)

        # Downsample blocks
        self.down1 = Block(config.in_channels, C, config.time_emb_dim)
        self.down2 = Block(C, C * 2, config.time_emb_dim)
        self.down3 = Block(C * 2, C * 4, config.time_emb_dim)

        # Middle blocks
        self.mid1 = Block(C * 4, C * 4, config.time_emb_dim)
        self.mid2 = Block(C * 4, C * 4, config.time_emb_dim)

        # Upsample blocks (input widths include the concatenated skip)
        self.up1 = Block(C * 4, C * 2, config.time_emb_dim, up=True)
        self.up2 = Block(C * 6, C, config.time_emb_dim, up=True)  # 2C + 4C (h3)
        self.up3 = Block(C * 3, C, config.time_emb_dim, up=True)  # C + 2C (h2)

        # Final output
        self.out = nn.Conv2d(C * 2, config.in_channels, kernel_size=3, padding=1)  # C + C (h1)

    @staticmethod
    def _match_size(h, skip):
        """Resize ``h`` (nearest) to ``skip``'s H, W when they differ.

        Max-pooling floors odd sizes, so a stride-2 transposed conv does not
        always restore the skip's size; without this, ``torch.cat`` fails for
        inputs whose H or W is not a multiple of 8. No-op otherwise.
        """
        if h.shape[-2:] != skip.shape[-2:]:
            h = F.interpolate(h, size=skip.shape[-2:], mode='nearest')
        return h

    def forward(self, x, t):
        """Run the U-Net.

        Args:
            x: Input ``[B, in_channels, H, W]``; H and W should be >= 8.
            t: Timesteps, a ``[B]`` tensor, or a scalar / 0-d tensor that is
                broadcast to the whole batch.

        Returns:
            Tensor ``[B, in_channels, H, W]``.
        """
        t = torch.as_tensor(t, device=x.device)
        if t.dim() == 0:
            t = t.expand(x.shape[0])

        # Time embedding
        t_emb = self.time_mlp(t)

        # Downsample path
        h1 = self.down1(x, t_emb)                         # [B, C,  H,   W]
        h2 = self.down2(F.max_pool2d(h1, 2), t_emb)       # [B, 2C, H/2, W/2]
        h3 = self.down3(F.max_pool2d(h2, 2), t_emb)       # [B, 4C, H/4, W/4]

        # Bottleneck
        h = self.mid1(F.max_pool2d(h3, 2), t_emb)         # [B, 4C, H/8, W/8]
        h = self.mid2(h, t_emb)                           # [B, 4C, H/8, W/8]

        # Upsample path
        h = self._match_size(self.up1(h, t_emb), h3)      # [B, 2C, H/4, W/4]
        h = torch.cat([h, h3], dim=1)                     # [B, 6C, H/4, W/4]
        h = self._match_size(self.up2(h, t_emb), h2)      # [B, C,  H/2, W/2]
        h = torch.cat([h, h2], dim=1)                     # [B, 3C, H/2, W/2]
        h = self._match_size(self.up3(h, t_emb), h1)      # [B, C,  H,   W]
        h = torch.cat([h, h1], dim=1)                     # [B, 2C, H,   W]
        return self.out(h)