Grad-CDM / model.py
nazgut's picture
Upload 24 files
8abfb97 verified
raw
history blame
3.36 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of timesteps.

    Maps a batch of timesteps ``t`` to vectors of width ``dim``. The first
    ``dim // 2`` columns hold ``sin(t * f_k)`` and the next ``dim // 2`` hold
    ``cos(t * f_k)``. For ``dim >= 4`` the frequencies ``f_k`` decay
    geometrically from 1 down to 1/10000.

    Args:
        dim: Output width. When ``dim`` is odd, a trailing zero column is
            appended so the output is always exactly ``dim`` wide.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = dim // 2
        # Clamp the denominator: for dim in {2, 3}, half_dim - 1 == 0, which
        # would give an infinite scale and then 0 * -inf = NaN frequencies.
        emb = torch.log(torch.tensor(10000)) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        # Persistent buffer: follows .to(device) and is saved in state_dict.
        self.register_buffer('emb', emb)

    def forward(self, t):
        """Embed timesteps.

        Args:
            t: 1-D tensor of B timesteps (any numeric dtype), or a 0-d scalar,
               which is treated as a batch of one.

        Returns:
            Float tensor of shape ``[B, dim]``.
        """
        if t.dim() == 0:
            t = t.unsqueeze(0)
        emb = t.float()[:, None] * self.emb[None, :]
        emb = torch.cat((torch.sin(emb), torch.cos(emb)), dim=-1)
        if self.dim % 2 == 1:
            # The sin/cos halves give 2 * (dim // 2) == dim - 1 columns. Pad one
            # zero column so consumers sized for `dim` (e.g. nn.Linear) still fit.
            emb = F.pad(emb, (0, 1))
        return emb
class Block(nn.Module):
    """Convolution with a timestep-conditioned bias, then GroupNorm and SiLU.

    With ``up=False`` the convolution is a 3x3, padding-1 ``Conv2d``, so the
    spatial size is unchanged. With ``up=True`` it is a 4x4, stride-2,
    padding-1 ``ConvTranspose2d``, which doubles the spatial size.

    Args:
        in_ch: Channels of the input feature map.
        out_ch: Channels produced; must be divisible by 8 for ``GroupNorm(8, ...)``.
        time_emb_dim: Width of the time-embedding vector passed to ``forward``.
        up: Select the upsampling transposed convolution.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        # Construction order is kept as time_mlp -> conv -> norm -> act so that
        # seeded initialization and parameter ordering stay the same.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.conv = (
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
            if up
            else nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        )
        self.norm = nn.GroupNorm(8, out_ch)
        self.act = nn.SiLU()

    def forward(self, x, t):
        """Return ``SiLU(GroupNorm(conv(x) + Linear(t)))``.

        The projected time embedding is broadcast over the spatial axes as a
        per-channel bias. Assumes ``t`` is ``[B, time_emb_dim]``; TODO confirm
        with callers.
        """
        conv_out = self.conv(x)
        channel_bias = self.time_mlp(t)[:, :, None, None]
        return self.act(self.norm(conv_out + channel_bias))
class SmoothDiffusionUNet(nn.Module):
    """Small three-level UNet conditioned on a timestep embedding.

    ``config`` must expose ``in_channels``, ``base_channels`` (written C below;
    every Block width must be divisible by 8 for its GroupNorm) and
    ``time_emb_dim``. The output has ``in_channels`` channels and the same
    spatial size as the input.

    Inputs whose height/width are not multiples of 8 (the total downsampling
    of three 2x max-pools) are zero-padded on the bottom/right. The output is
    then cropped back to the original size. Without this, the transposed-conv
    upsampling would not match the skip connections and ``torch.cat`` would
    fail.
    """

    # Total spatial reduction at the bottleneck: three 2x max-pools.
    _SPATIAL_MULTIPLE = 8

    def __init__(self, config):
        super().__init__()
        self.config = config
        c = config.base_channels
        temb = config.time_emb_dim
        # NOTE: construction order is preserved so seeded initialization and
        # state_dict/parameter ordering are unchanged.
        # Time embedding (sinusoidal; no learned MLP despite the name)
        self.time_mlp = TimeEmbedding(temb)
        # Downsample blocks (3x3 convs; pooling happens in forward)
        self.down1 = Block(config.in_channels, c, temb)
        self.down2 = Block(c, c * 2, temb)
        self.down3 = Block(c * 2, c * 4, temb)
        # Middle blocks
        self.mid1 = Block(c * 4, c * 4, temb)
        self.mid2 = Block(c * 4, c * 4, temb)
        # Upsample blocks (transposed convs, each doubles H and W)
        self.up1 = Block(c * 4, c * 2, temb, up=True)
        self.up2 = Block(c * 6, c, temb, up=True)  # 2C (up1) + 4C (skip h3)
        self.up3 = Block(c * 3, c, temb, up=True)  # C (up2) + 2C (skip h2)
        # Final output: C (up3) + C (skip h1) -> in_channels
        self.out = nn.Conv2d(c * 2, config.in_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        """Predict a tensor with the same shape as ``x``.

        Args:
            x: Input images, assumed ``[B, in_channels, H, W]``.
            t: Timesteps passed to the time embedding (one per batch item).

        Returns:
            Tensor of shape ``[B, in_channels, H, W]``.
        """
        height, width = x.shape[-2:]
        m = self._SPATIAL_MULTIPLE
        pad_h = (-height) % m
        pad_w = (-width) % m
        if pad_h or pad_w:
            # F.pad pads the last dims first: (left, right, top, bottom).
            x = F.pad(x, (0, pad_w, 0, pad_h))

        # Time embedding
        t_emb = self.time_mlp(t)

        # Downsample path (H, W below are the possibly padded sizes)
        h1 = self.down1(x, t_emb)                       # [B, C,  H,   W]
        h2 = self.down2(F.max_pool2d(h1, 2), t_emb)     # [B, 2C, H/2, W/2]
        h3 = self.down3(F.max_pool2d(h2, 2), t_emb)     # [B, 4C, H/4, W/4]

        # Bottleneck
        h = self.mid1(F.max_pool2d(h3, 2), t_emb)       # [B, 4C, H/8, W/8]
        h = self.mid2(h, t_emb)                         # [B, 4C, H/8, W/8]

        # Upsample path with skip connections
        h = self.up1(h, t_emb)                          # [B, 2C, H/4, W/4]
        h = torch.cat([h, h3], dim=1)                   # [B, 6C, H/4, W/4]
        h = self.up2(h, t_emb)                          # [B, C,  H/2, W/2]
        h = torch.cat([h, h2], dim=1)                   # [B, 3C, H/2, W/2]
        h = self.up3(h, t_emb)                          # [B, C,  H,   W]
        h = torch.cat([h, h1], dim=1)                   # [B, 2C, H,   W]

        out = self.out(h)
        if pad_h or pad_w:
            out = out[..., :height, :width]
        return out