import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Sinusoidal timestep embedding (Transformer/DDPM style): fixed log-spaced
# frequencies with concatenated sin and cos halves.
class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000.
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        self.register_buffer('emb', emb)

    def forward(self, t):
        # t: [B] timesteps -> [B, dim] embedding.
        emb = t.float()[:, None] * self.emb[None, :]
        emb = torch.cat((torch.sin(emb), torch.cos(emb)), dim=-1)
        return emb
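
# Example shape check (illustrative, not in the original file):
#   TimeEmbedding(128)(torch.tensor([0, 10, 500])).shape  # -> torch.Size([3, 128])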

# Basic conv block: conv -> add time embedding -> GroupNorm -> SiLU.
# With up=True the conv is a stride-2 transposed conv that doubles H and W.
class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        else:
            self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm = nn.GroupNorm(8, out_ch)  # out_ch must be divisible by 8
        self.act = nn.SiLU()

    def forward(self, x, t):
        h = self.conv(x)
        # Project the time embedding to out_ch and broadcast it over H and W.
        time_emb = self.time_mlp(t)
        h = h + time_emb[:, :, None, None]
        h = self.norm(h)
        h = self.act(h)
        return h

# U-Net backbone: three down stages, a two-block bottleneck, and three up
# stages with skip connections. Channel counts in the comments assume
# base_channels=64.
class SmoothDiffusionUNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Time embedding
        self.time_mlp = TimeEmbedding(config.time_emb_dim)

        # Downsample blocks (spatial pooling happens in forward())
        self.down1 = Block(config.in_channels, config.base_channels, config.time_emb_dim)
        self.down2 = Block(config.base_channels, config.base_channels*2, config.time_emb_dim)
        self.down3 = Block(config.base_channels*2, config.base_channels*4, config.time_emb_dim)

        # Middle (bottleneck) blocks
        self.mid1 = Block(config.base_channels*4, config.base_channels*4, config.time_emb_dim)
        self.mid2 = Block(config.base_channels*4, config.base_channels*4, config.time_emb_dim)

        # Upsample blocks; input widths include the concatenated skip connections
        self.up1 = Block(config.base_channels*4, config.base_channels*2, config.time_emb_dim, up=True)
        self.up2 = Block(config.base_channels*6, config.base_channels, config.time_emb_dim, up=True)  # up1 out + h3 skip: 128 + 256 = 384
        self.up3 = Block(config.base_channels*3, config.base_channels, config.time_emb_dim, up=True)  # up2 out + h2 skip: 64 + 128 = 192

        # Final output; input includes the h1 skip: 64 + 64 = 128
        self.out = nn.Conv2d(config.base_channels*2, config.in_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        # Time embedding
        t_emb = self.time_mlp(t)
        
        # Downsample path
        h1 = self.down1(x, t_emb)  # [B, 64, H, W]
        h2 = self.down2(F.max_pool2d(h1, 2), t_emb)  # [B, 128, H/2, W/2]
        h3 = self.down3(F.max_pool2d(h2, 2), t_emb)  # [B, 256, H/4, W/4]
        
        # Bottleneck
        h = self.mid1(F.max_pool2d(h3, 2), t_emb)  # [B, 256, H/8, W/8]
        h = self.mid2(h, t_emb)  # [B, 256, H/8, W/8]
        
        # Upsample path
        h = self.up1(h, t_emb)  # [B, 128, H/4, W/4]
        h = torch.cat([h, h3], dim=1)  # [B, 384, H/4, W/4]
        h = self.up2(h, t_emb)  # [B, 64, H/2, W/2]
        h = torch.cat([h, h2], dim=1)  # [B, 192, H/2, W/2]
        h = self.up3(h, t_emb)  # [B, 64, H, W]
        h = torch.cat([h, h1], dim=1)  # [B, 128, H, W]
        
        return self.out(h)
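

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal smoke test assuming a plain config object exposing the three fields
# the model reads: in_channels, base_channels, time_emb_dim. `UNetConfig` is a
# hypothetical stand-in for whatever config class the project actually uses.
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class UNetConfig:
        in_channels: int = 3
        base_channels: int = 64   # keep divisible by 8 for GroupNorm(8, ...)
        time_emb_dim: int = 128

    model = SmoothDiffusionUNet(UNetConfig())
    x = torch.randn(2, 3, 32, 32)     # H and W must be divisible by 8 (three 2x poolings)
    t = torch.randint(0, 1000, (2,))  # integer diffusion timesteps
    out = model(x, t)
    assert out.shape == x.shape       # per-pixel prediction (e.g. a noise estimate)
    print(out.shape)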