import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding with log-spaced frequencies."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = dim // 2
        # Precompute the frequency table and register it as a buffer so it
        # follows the module across devices and dtypes.
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        self.register_buffer('emb', emb)

    def forward(self, t):
        # t: (batch,) integer timesteps -> (batch, dim) sin/cos embedding.
        emb = t.float()[:, None] * self.emb[None, :]
        emb = torch.cat((torch.sin(emb), torch.cos(emb)), dim=-1)
        return emb

class Block(nn.Module):
    """Conv (or transposed-conv) block with additive time conditioning, GroupNorm, and SiLU."""

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            # Transposed conv doubles the spatial resolution on the decoder path.
            self.conv = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        else:
            self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm = nn.GroupNorm(8, out_ch)
        self.act = nn.SiLU()

    def forward(self, x, t):
        h = self.conv(x)
        # Project the time embedding to out_ch and broadcast-add it over the spatial dims.
        time_emb = self.time_mlp(t)
        h = h + time_emb[:, :, None, None]
        h = self.norm(h)
        h = self.act(h)
        return h

class SmoothDiffusionUNet(nn.Module):
    """U-Net with three encoder stages, a two-block bottleneck, and three decoder stages with skip connections."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Timestep embedding shared by all blocks.
        self.time_mlp = TimeEmbedding(config.time_emb_dim)

        # Encoder.
        self.down1 = Block(config.in_channels, config.base_channels, config.time_emb_dim)
        self.down2 = Block(config.base_channels, config.base_channels * 2, config.time_emb_dim)
        self.down3 = Block(config.base_channels * 2, config.base_channels * 4, config.time_emb_dim)

        # Bottleneck.
        self.mid1 = Block(config.base_channels * 4, config.base_channels * 4, config.time_emb_dim)
        self.mid2 = Block(config.base_channels * 4, config.base_channels * 4, config.time_emb_dim)

        # Decoder. The input widths of up2 and up3 account for the skip feature
        # concatenated after the preceding up block (see forward).
        self.up1 = Block(config.base_channels * 4, config.base_channels * 2, config.time_emb_dim, up=True)
        self.up2 = Block(config.base_channels * 6, config.base_channels, config.time_emb_dim, up=True)
        self.up3 = Block(config.base_channels * 3, config.base_channels, config.time_emb_dim, up=True)

        # Output head; its input is up3's output concatenated with the h1 skip.
        self.out = nn.Conv2d(config.base_channels * 2, config.in_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        t_emb = self.time_mlp(t)

        # Encoder: each stage runs at half the resolution of the previous one.
        h1 = self.down1(x, t_emb)
        h2 = self.down2(F.max_pool2d(h1, 2), t_emb)
        h3 = self.down3(F.max_pool2d(h2, 2), t_emb)

        # Bottleneck at 1/8 of the input resolution.
        h = self.mid1(F.max_pool2d(h3, 2), t_emb)
        h = self.mid2(h, t_emb)

        # Decoder: upsample, then concatenate the skip feature of matching resolution.
        h = self.up1(h, t_emb)
        h = torch.cat([h, h3], dim=1)
        h = self.up2(h, t_emb)
        h = torch.cat([h, h2], dim=1)
        h = self.up3(h, t_emb)
        h = torch.cat([h, h1], dim=1)

        return self.out(h)
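

# Minimal smoke test, a sketch rather than part of the model code: the config below is an
# assumption, not the project's actual config class; any object exposing in_channels,
# base_channels, and time_emb_dim works, provided base_channels is divisible by 8 (GroupNorm),
# time_emb_dim is even, and the input height/width are divisible by 8 (three poolings).
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(in_channels=3, base_channels=64, time_emb_dim=128)  # hypothetical values
    model = SmoothDiffusionUNet(config)
    x = torch.randn(2, config.in_channels, 32, 32)  # dummy batch of 32x32 images
    t = torch.randint(0, 1000, (2,))                # integer diffusion timesteps
    out = model(x, t)
    print(out.shape)  # expected: torch.Size([2, 3, 32, 32])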