JuyeopDang committed on
Commit 5ab5cab · verified · 1 Parent(s): cda6ed1

Upload 35 files

auto_encoder/components/distributions.py ADDED
@@ -0,0 +1,43 @@
1
+ #source: https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/distributions/distributions.py
2
+ import torch
3
+ import numpy as np
4
+
5
+ class DiagonalGaussianDistribution(object):
6
+ def __init__(self, parameters, deterministic=False):
7
+ self.parameters = parameters
8
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
9
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
10
+ self.deterministic = deterministic
11
+ self.std = torch.exp(0.5 * self.logvar)
12
+ self.var = torch.exp(self.logvar)
13
+ if self.deterministic:
14
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
15
+
16
+ def sample(self):
17
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
18
+ return x
19
+
20
+ def kl(self, other=None):
21
+ if self.deterministic:
22
+ return torch.Tensor([0.])
23
+ else:
24
+ if other is None:
25
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
26
+ + self.var - 1.0 - self.logvar,
27
+ dim=[1, 2, 3])
28
+ else:
29
+ return 0.5 * torch.sum(
30
+ torch.pow(self.mean - other.mean, 2) / other.var
31
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
32
+ dim=[1, 2, 3])
33
+
34
+ def nll(self, sample, dims=[1,2,3]):
35
+ if self.deterministic:
36
+ return torch.Tensor([0.])
37
+ logtwopi = np.log(2.0 * np.pi)
38
+ return 0.5 * torch.sum(
39
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
40
+ dim=dims)
41
+
42
+ def mode(self):
43
+ return self.mean
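
For reference, a minimal usage sketch of the class above (illustrative, not part of the commit; shapes are made up): the input tensor is read as mean and log-variance concatenated along the channel dimension, sample() draws with the reparameterization trick, and kl() returns the closed-form KL divergence to a standard normal per batch element.

import torch
from auto_encoder.components.distributions import DiagonalGaussianDistribution

# Hypothetical shapes: batch of 2, 3 latent channels, 8x8 spatial grid.
params = torch.randn(2, 2 * 3, 8, 8)      # [mean, logvar] stacked along dim=1
posterior = DiagonalGaussianDistribution(params)

z = posterior.sample()                    # reparameterized draw, shape [2, 3, 8, 8]
kl = posterior.kl()                       # KL(q || N(0, I)) summed over dims [1, 2, 3], shape [2]
print(z.shape, kl.shape)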
auto_encoder/components/nonlinearity.py ADDED
@@ -0,0 +1,5 @@
1
+ import torch
2
+
3
+ def nonlinearity(x):
4
+ # swish
5
+ return x*torch.sigmoid(x)
auto_encoder/components/normalize.py ADDED
@@ -0,0 +1,4 @@
1
+ import torch
2
+
3
+ def Normalize(in_channels : int, num_groups : int = 32):
4
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
auto_encoder/components/resnet_block.py ADDED
@@ -0,0 +1,46 @@
1
+ # https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/model.py#L368
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from auto_encoder.components.normalize import Normalize
6
+ from auto_encoder.components.nonlinearity import nonlinearity
7
+
8
+ class ResnetBlock(nn.Module):
9
+ def __init__(self, *, in_channels : int, out_channels : int = None, conv_shortcut=False, dropout):
10
+ super().__init__()
11
+ self.in_channels = in_channels
12
+ out_channels = in_channels if out_channels is None else out_channels
13
+ self.out_channels = out_channels
14
+ self.use_conv_shortcut = conv_shortcut
15
+
16
+ self.norm1 = Normalize(in_channels)
17
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1)
18
+ self.norm2 = Normalize(out_channels)
19
+ self.dropout = torch.nn.Dropout(dropout)
20
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1)
21
+
22
+ if self.in_channels != self.out_channels:
23
+ if self.use_conv_shortcut:
24
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
25
+ else:
26
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
27
+
28
+ def forward(self, x):
29
+ h = x
30
+ h = self.norm1(h)
31
+ h = nonlinearity(h)
32
+ h = self.conv1(h)
33
+ h = self.norm2(h)
34
+ h = nonlinearity(h)
35
+ h = self.dropout(h)
36
+ h = self.conv2(h)
37
+
38
+ if self.in_channels != self.out_channels:
39
+ if self.use_conv_shortcut:
40
+ x = self.conv_shortcut(x)
41
+ else:
42
+ x = self.nin_shortcut(x)
43
+
44
+ return x+h
45
+
46
+
auto_encoder/components/sampling.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class Upsample(nn.Module):
5
+ def __init__(self, in_channels : int, with_conv : bool):
6
+ super().__init__()
7
+ self.with_conv = with_conv
8
+ if self.with_conv:
9
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size = 3, stride = 1, padding = 1)
10
+
11
+ def forward(self, x):
12
+ x = torch.nn.functional.interpolate(x, scale_factor = 2.0, mode = "nearest")
13
+ if self.with_conv:
14
+ x = self.conv(x)
15
+ return x
16
+
17
+ class Downsample(nn.Module):
18
+ def __init__(self, in_channels : int, with_conv : bool):
19
+ super().__init__()
20
+ self.with_conv = with_conv
21
+ if self.with_conv:
22
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size = 3, stride = 2, padding = 0)
23
+
24
+ def forward(self, x):
25
+ if self.with_conv:
26
+ pad = (0, 1, 0, 1)
27
+ x = torch.nn.functional.pad(x, pad, mode = "constant", value = 0)
28
+ x = self.conv(x)
29
+ else:
30
+ x = torch.nn.functional.avg_pool2d(x, kernel_size = 2, stride = 2)
31
+ return x
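
As a quick sanity check (illustrative only, not part of the commit), both modules change the spatial size by a factor of 2: Upsample via nearest-neighbor interpolation followed by an optional 3x3 convolution, Downsample via an asymmetrically padded stride-2 convolution (or average pooling when with_conv=False).

import torch
from auto_encoder.components.sampling import Upsample, Downsample

x = torch.randn(1, 64, 32, 32)                    # illustrative input
print(Upsample(64, with_conv=True)(x).shape)      # torch.Size([1, 64, 64, 64])
print(Downsample(64, with_conv=True)(x).shape)    # torch.Size([1, 64, 16, 16])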
auto_encoder/models/auto_encoder.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ from auto_encoder.models.decoder import Decoder
4
+ from auto_encoder.models.encoder import Encoder
5
+ import yaml
6
+
7
+ class AutoEncoder(nn.Module):
8
+ def __init__(self, config_path : str):
9
+ super().__init__()
10
+ with open(config_path, "r") as file:
11
+ config = yaml.safe_load(file)
12
+ self.add_module('encoder', Encoder(**config["encoder"]))
13
+ self.add_module('decoder', Decoder(**config["decoder"]))
14
+
15
+ def encode(self, x):
16
+ h = self.encoder(x)
17
+ return h
18
+
19
+ def decode(self, z):
20
+ z = self.decoder(z)
21
+ return z
22
+
23
+ def reconstruct(self, x):
24
+ return self.decode(self.encode(x))
25
+
26
+ def loss(self, x):
27
+ x_hat = self(x)
28
+ return F.mse_loss(x, x_hat)
29
+
30
+ def forward(self, x):
31
+ return self.reconstruct(x)
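
A minimal end-to-end sketch (not part of the commit), assuming it is run from the repository root so that configs/composite_config.yaml from this upload is on disk; the encoder/decoder keyword arguments come straight from that file.

import torch
from auto_encoder.models.auto_encoder import AutoEncoder

model = AutoEncoder("configs/composite_config.yaml")
x = torch.randn(1, 3, 256, 256)          # resolution must match the config (256)
x_hat = model(x)                         # encode -> decode
print(x_hat.shape)                       # torch.Size([1, 3, 256, 256])
print(model.loss(x).item())              # plain MSE reconstruction loss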
auto_encoder/models/decoder.py ADDED
@@ -0,0 +1,78 @@
1
+ #source : https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/model.py#L368
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from auto_encoder.components.normalize import Normalize
6
+ from auto_encoder.components.resnet_block import ResnetBlock
7
+ from auto_encoder.components.sampling import Upsample
8
+ from auto_encoder.components.nonlinearity import nonlinearity
9
+
10
+ class Decoder(nn.Module):
11
+ def __init__(self, *, in_channels, out_channels, resolution, channels, channel_multipliers = (1, 2, 4, 8), z_channels, num_res_blocks,
12
+ dropout = 0.0, resample_with_conv : bool = True):
13
+ super().__init__()
14
+ self.ch = channels
15
+ self.num_resolutions = len(channel_multipliers)
16
+ self.num_res_blocks = num_res_blocks
17
+ self.in_channels = in_channels
18
+ self.z_channels = z_channels
19
+
20
+ in_ch_mult = (1 , ) + tuple(channel_multipliers)
21
+ block_in = self.ch * in_ch_mult[self.num_resolutions - 1]
22
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
23
+ self.z_shape = (1 , z_channels, curr_res, curr_res)
24
+ print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
25
+
26
+ # z to block_in
27
+ self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size = 3, stride = 1, padding = 1)
28
+
29
+ # middle
30
+ self.mid = nn.Module()
31
+ self.mid.block_1 = ResnetBlock(in_channels = block_in, out_channels = block_in, dropout = dropout)
32
+ self.mid.block_2 = ResnetBlock(in_channels = block_in, out_channels = block_in, dropout = dropout)
33
+
34
+ # upsampling
35
+
36
+ self.up = nn.ModuleList()
37
+ for i_level in reversed(range(self.num_resolutions)):
38
+ block = nn.ModuleList()
39
+ block_out = self.ch * channel_multipliers[i_level]
40
+ for i_block in range(self.num_res_blocks + 1):
41
+ block.append(ResnetBlock(in_channels = block_in, out_channels = block_out,
42
+ dropout = dropout))
43
+ block_in = block_out
44
+ up = nn.Module()
45
+ up.block = block
46
+ if i_level != 0:
47
+ up.upsample = Upsample(block_in, resample_with_conv)
48
+ curr_res = curr_res * 2
49
+ self.up.insert(0, up)
50
+
51
+ # end
52
+ self.norm_out = Normalize(block_in)
53
+ self.conv_out = torch.nn.Conv2d(block_in, out_channels,
54
+ kernel_size = 3, stride = 1, padding = 1)
55
+
56
+ def forward(self, z):
57
+ assert z.shape[1:] == self.z_shape[1:]
58
+ self.last_z_shape = z.shape
59
+
60
+ # z to block_in
61
+ h = self.conv_in(z)
62
+
63
+ # middle
64
+ h = self.mid.block_1(h)
65
+ h = self.mid.block_2(h)
66
+
67
+ # upsampling
68
+ for i_level in reversed(range(self.num_resolutions)):
69
+ for i_block in range(self.num_res_blocks + 1):
70
+ h = self.up[i_level].block[i_block](h)
71
+ if i_level != 0:
72
+ h = self.up[i_level].upsample(h)
73
+
74
+ # end
75
+ h = self.norm_out(h)
76
+ h = nonlinearity(h)
77
+ h = self.conv_out(h)
78
+ return h
auto_encoder/models/encoder.py ADDED
@@ -0,0 +1,71 @@
1
+ #source : https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/model.py#L368
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from auto_encoder.components.normalize import Normalize
6
+ from auto_encoder.components.resnet_block import ResnetBlock
7
+ from auto_encoder.components.sampling import Downsample
8
+ from auto_encoder.components.nonlinearity import nonlinearity
9
+
10
+ class Encoder(nn.Module):
11
+ def __init__(self, *, in_channels, resolution, channels, channel_multipliers = (1, 2, 4, 8), z_channels, num_res_blocks,
12
+ dropout = 0.0, resample_with_conv : bool = True, double_z : bool = True):
13
+ super().__init__()
14
+ self.ch = channels
15
+ self.num_resolutions = len(channel_multipliers)
16
+ self.num_res_blocks = num_res_blocks
17
+ self.in_channels = in_channels
18
+ self.z_channels = z_channels
19
+
20
+ # downsampling
21
+ self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size = 3, stride = 1, padding = 1)
22
+ curr_res = resolution
23
+ in_ch_mult = (1, ) + tuple(channel_multipliers)
24
+ self.in_ch_mult = in_ch_mult
25
+ self.down = nn.ModuleList()
26
+ for i_level in range(self.num_resolutions):
27
+ block = nn.ModuleList()
28
+ block_in = self.ch * in_ch_mult[i_level]
29
+ block_out = self.ch * channel_multipliers[i_level]
30
+ for i_block in range(self.num_res_blocks):
31
+ block.append(ResnetBlock(in_channels = block_in, out_channels = block_out, dropout = dropout))
32
+ block_in = block_out
33
+ down = nn.Module()
34
+ down.block = block
35
+ if i_level != self.num_resolutions - 1:
36
+ down.downsample = Downsample(block_in, resample_with_conv)
37
+ curr_res = curr_res // 2
38
+ self.down.append(down)
39
+
40
+ # middle
41
+ self.mid = nn.Module()
42
+ self.mid.block_1 = ResnetBlock(in_channels = block_in, out_channels = block_in, dropout = dropout)
43
+ self.mid.block_2 = ResnetBlock(in_channels = block_in, out_channels = block_in, dropout = dropout)
44
+
45
+ # end
46
+ self.norm_out = Normalize(block_in)
47
+ self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels if double_z else z_channels,
48
+ kernel_size = 3, stride = 1, padding = 1)
49
+
50
+ def forward(self, x):
51
+ # downsampling
52
+ hs = [self.conv_in(x)]
53
+ for i_level in range(self.num_resolutions):
54
+ for i_block in range(self.num_res_blocks):
55
+ h = self.down[i_level].block[i_block](hs[-1])
56
+
57
+ hs.append(h)
58
+ if i_level != self.num_resolutions - 1:
59
+ hs.append(self.down[i_level].downsample(hs[-1]))
60
+
61
+ # middle
62
+ h = hs[-1]
63
+ h = self.mid.block_1(h)
64
+ h = self.mid.block_2(h)
65
+
66
+ # end
67
+ h = self.norm_out(h)
68
+ h = nonlinearity(h)
69
+ h = self.conv_out(h)
70
+ return h
71
+
auto_encoder/models/variational_auto_encoder.py ADDED
@@ -0,0 +1,45 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from auto_encoder.models.encoder import Encoder
6
+ from auto_encoder.models.decoder import Decoder
7
+ import yaml
8
+ from auto_encoder.components.distributions import DiagonalGaussianDistribution
9
+
10
+ class VariationalAutoEncoder(nn.Module):
11
+ def __init__(self, config_path):
12
+ super().__init__()
13
+ with open(config_path, "r") as file:
14
+ config = yaml.safe_load(file)
15
+ self.add_module('encoder', Encoder(**config["encoder"]))
16
+ self.add_module('decoder', Decoder(**config["decoder"]))
17
+ self.embed_dim = config['vae']['embed_dim']
18
+ self.kld_weight = float(config['vae']['kld_weight'])
19
+
20
+ self.quant_conv = torch.nn.Conv2d(self.decoder.z_channels, 2*self.embed_dim, 1)
21
+ self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, self.decoder.z_channels, 1)
22
+
23
+ def encode(self, x):
24
+ h = self.encoder(x)
25
+ moments = self.quant_conv(h)
26
+ posterior = DiagonalGaussianDistribution(moments)
27
+ return posterior
28
+
29
+ def decode(self, z):
30
+ z = self.post_quant_conv(z)
31
+ dec = self.decoder(z)
32
+ return dec
33
+
34
+ def loss(self, x):
35
+ x_hat, posterior = self(x)
36
+ return F.mse_loss(x, x_hat) + self.kld_weight * posterior.kl().mean()
37
+
38
+ def forward(self, input, sample_posterior=True):
39
+ posterior = self.encode(input)
40
+ if sample_posterior:
41
+ z = posterior.sample()
42
+ else:
43
+ z = posterior.mode()
44
+ dec = self.decode(z)
45
+ return dec, posterior
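
A matching sketch for the variational version (illustrative, not part of the commit), again assuming the shipped configs/composite_config.yaml: encode() returns a DiagonalGaussianDistribution, forward() optionally samples it, and loss() combines the reconstruction MSE with the KL term scaled by kld_weight.

import torch
from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder

vae = VariationalAutoEncoder("configs/composite_config.yaml")
x = torch.randn(2, 3, 256, 256)
x_hat, posterior = vae(x)                      # stochastic pass through the latent
print(x_hat.shape, posterior.mean.shape)       # [2, 3, 256, 256] and [2, 3, 64, 64]
print(vae.loss(x).item())                      # MSE + kld_weight * posterior.kl().mean()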
clip/encoders/image_encoder.py ADDED
@@ -0,0 +1,44 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class ImageEncoder(nn.Module):
5
+ def __init__(self, in_channels: int, resolution: int, patch_size: int,
6
+ number_of_features: int, number_of_heads:int, number_of_transformer_layers: int,
7
+ embed_dim: int):
8
+ super().__init__()
9
+ self.resolution = resolution
10
+ self.embed_dim = embed_dim
11
+ self.conv = nn.Conv2d(in_channels=in_channels, out_channels=number_of_features,
12
+ kernel_size=patch_size, stride=patch_size, bias=False)
13
+ self.number_of_patches = (resolution // patch_size) ** 2
14
+ self.positional_embedding = nn.Parameter(torch.zeros(1, self.number_of_patches + 1, number_of_features))
15
+ self.class_embedding = nn.Parameter(torch.zeros(1, 1, number_of_features))
16
+
17
+ self.ln_pre = nn.LayerNorm(number_of_features)
18
+ self.transformer = nn.TransformerEncoder(
19
+ nn.TransformerEncoderLayer(d_model=number_of_features, nhead=number_of_heads, batch_first=True),
20
+ num_layers=number_of_transformer_layers
21
+ )
22
+
23
+ self.ln_post = nn.LayerNorm(number_of_features)
24
+ self.fc = nn.Linear(number_of_features, embed_dim)
25
+
26
+ # initialize
27
+ nn.init.kaiming_normal_(self.positional_embedding, nonlinearity='relu')
28
+ nn.init.kaiming_normal_(self.class_embedding, nonlinearity='relu')
29
+ nn.init.kaiming_normal_(self.fc.weight, nonlinearity='relu')
30
+
31
+ def forward(self, x: torch.Tensor):
32
+ x = self.conv(x) # [batch_size, number_of_features, grid, grid]
33
+ x = x.flatten(2) # [batch_size, number_of_features, grid ** 2 = number_of_patches]
34
+ x = x.transpose(1, 2) # [batch_size, number_of_patches, number_of_features]
35
+
36
+ class_embeddings = self.class_embedding.expand(x.shape[0], -1, -1)
37
+ x = torch.cat([class_embeddings, x], dim=1)
38
+ x = x + self.positional_embedding
39
+ x = self.ln_pre(x)
40
+ x = self.transformer(x) # [batch_size, length_of_sequence, number_of_features]
41
+ x = x.permute(1, 0, 2) # [length_of_sequence, batch_size, number_of_features]
42
+ x = self.ln_post(x[0])
43
+ x = self.fc(x) # [batch_size, embed_dim]
44
+ return x
clip/encoders/text_encoder.py ADDED
@@ -0,0 +1,29 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class TextEncoder(nn.Module):
5
+ def __init__(self, number_of_features: int, number_of_heads: int, number_of_transformer_layers: int,
6
+ context_length, embed_dim):
7
+ super().__init__()
8
+ self.vocab_size = 32000 # AutoTokenizer: "koclip/koclip-base-pt"
9
+ self.token_embedding = nn.Embedding(self.vocab_size, number_of_features)
10
+ self.positional_embedding = nn.Parameter(torch.zeros(context_length, number_of_features))
11
+ self.transformer = nn.TransformerEncoder(
12
+ nn.TransformerEncoderLayer(d_model=number_of_features, nhead=number_of_heads, batch_first=True),
13
+ num_layers=number_of_transformer_layers
14
+ )
15
+ self.text_projection = nn.Linear(number_of_features, embed_dim)
16
+
17
+ # initialize
18
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
19
+ nn.init.xavier_uniform_(self.positional_embedding)
20
+ nn.init.kaiming_normal_(self.text_projection.weight, nonlinearity='relu')
21
+
22
+ def forward(self, x):
23
+ eot_token_idx = (x == 2).nonzero(as_tuple=True)[1] # Assume EOT token ID is 2
24
+ x = self.token_embedding(x)
25
+ x = x + self.positional_embedding[:x.size(1), :]
26
+ x = self.transformer(x)
27
+ x = x[torch.arange(x.shape[0]), eot_token_idx]
28
+ x = self.text_projection(x)
29
+ return x
clip/models/clip.py ADDED
@@ -0,0 +1,70 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import yaml
5
+ import numpy as np
6
+
7
+ from clip.encoders.image_encoder import ImageEncoder
8
+ from clip.encoders.text_encoder import TextEncoder
9
+ from helper.tokenizer import Tokenizer
10
+
11
+ class CLIP(nn.Module):
12
+ def __init__(self, config_path):
13
+ super().__init__()
14
+ with open(config_path, "r") as file:
15
+ config = yaml.safe_load(file)
16
+
17
+ self.image_encoder = ImageEncoder(**config["image_encoder"])
18
+ self.text_encoder = TextEncoder(**config["text_encoder"])
19
+ self.tokenizer = Tokenizer()
20
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
21
+
22
+ # initialize
23
+ for module in self.modules():
24
+ if isinstance(module, nn.Linear):
25
+ nn.init.xavier_normal_(module.weight)
26
+ if module.bias is not None:
27
+ nn.init.constant_(module.bias, 0)
28
+ elif isinstance(module, nn.Conv2d):
29
+ nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
30
+ elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
31
+ nn.init.constant_(module.weight, 1)
32
+ nn.init.constant_(module.bias, 0)
33
+
34
+ def loss(self, image, text):
35
+ image_features, text_features = self(image, text, tokenize=False)
36
+
37
+ # Normalize features
38
+ image_features = F.normalize(image_features, dim=1)
39
+ text_features = F.normalize(text_features, dim=1)
40
+
41
+ # Cosine similarity as logits with learned temperature
42
+ logits = torch.matmul(image_features, text_features.t()) * self.logit_scale.exp()
43
+ labels = torch.arange(logits.shape[0], dtype=torch.long, device=logits.device)
44
+
45
+ # Cross-entropy loss
46
+ loss_i2t = F.cross_entropy(logits, labels)
47
+ loss_t2i = F.cross_entropy(logits.t(), labels)
48
+
49
+ return (loss_i2t + loss_t2i) / 2
50
+
51
+ def text_encode(self, text, tokenize=True):
52
+ if tokenize:
53
+ tokens = self.tokenizer.tokenize(text)
54
+ else:
55
+ tokens = text
56
+ text_features = self.text_encoder(tokens)
57
+ if text_features.dim() < 2:
58
+ text_features = text_features.unsqueeze(0)
59
+ return text_features
60
+
61
+ def forward(self, image, text, tokenize=True):
62
+ image_features = self.image_encoder(image)
63
+ text_features = self.text_encode(text, tokenize)  # TextEncoder.forward() only accepts token ids, so route through text_encode()
64
+
65
+ if image_features.dim() < 2:
66
+ image_features = image_features.unsqueeze(0)
67
+ if text_features.dim() < 2:
68
+ text_features = text_features.unsqueeze(0)
69
+
70
+ return image_features, text_features
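
A small smoke-test sketch (illustrative, not part of the commit), using the configs/composite_clip_config.yaml from this upload. Tokenization is skipped here by passing pre-built token ids with an EOT id of 2 per row, matching the assumption in TextEncoder; note that constructing CLIP also instantiates the helper Tokenizer, which may download its underlying pretrained tokenizer.

import torch
from clip.models.clip import CLIP

model = CLIP("configs/composite_clip_config.yaml")
images = torch.randn(4, 3, 256, 256)                 # resolution / in_channels from the config
tokens = torch.randint(3, 32000, (4, 77))            # context_length 77, vocab size 32000
tokens[:, -1] = 2                                    # one EOT token (id 2) per sequence
image_f, text_f = model(images, tokens, tokenize=False)
print(image_f.shape, text_f.shape)                   # [4, 128] and [4, 128] (embed_dim)
print(model.loss(images, tokens).item())             # symmetric image/text cross-entropy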
clip/models/ko_clip.py ADDED
@@ -0,0 +1,26 @@
1
+ import torch.nn as nn
2
+
3
+ from transformers import AutoModel, AutoTokenizer
4
+
5
+ class KoCLIPWrapper(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+ self.model_name = "Bingsu/clip-vit-base-patch32-ko"
9
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
10
+ self.model = AutoModel.from_pretrained(self.model_name)
11
+
12
+ def loss(self, inputs):
13
+ outputs = self(inputs)
14
+ return outputs.loss
15
+
16
+ def text_encode(self, text, tokenize):
17
+ if tokenize:
18
+ tokens = self.tokenizer(text, padding='max_length', max_length=77, truncation=True, return_tensors="pt")
19
+ else:
20
+ tokens = text
21
+ tokens = tokens.to(self.model.device)
22
+ return self.model.get_text_features(**tokens)
23
+
24
+ def forward(self, inputs):
25
+ outputs = self.model(**inputs, return_loss=True)
26
+ return outputs # [1, 512] , [1, 512]
configs/composite_clip_config.yaml ADDED
@@ -0,0 +1,15 @@
1
+ text_encoder:
2
+ number_of_features: 512
3
+ number_of_heads: 8
4
+ number_of_transformer_layers: 6
5
+ context_length: 77
6
+ embed_dim: 128
7
+
8
+ image_encoder:
9
+ in_channels: 3
10
+ resolution: 256
11
+ patch_size: 16
12
+ number_of_features: 768
13
+ number_of_heads: 12
14
+ number_of_transformer_layers: 4
15
+ embed_dim: 128
configs/composite_config.yaml ADDED
@@ -0,0 +1,38 @@
1
+ encoder:
2
+ in_channels: 3
3
+ resolution: 256
4
+ channels: 128
5
+ channel_multipliers: [1, 2, 4]
6
+ z_channels: 3
7
+ num_res_blocks: 2
8
+ dropout: 0.0
9
+
10
+ decoder:
11
+ in_channels: 3
12
+ out_channels: 3
13
+ resolution: 256
14
+ channels: 128
15
+ channel_multipliers: [1, 2, 4]
16
+ z_channels: 6
17
+ num_res_blocks: 2
18
+ dropout: 0.0
19
+
20
+ vae:
21
+ embed_dim: 3
22
+ kld_weight: 1e-6
23
+
24
+ sampler:
25
+ beta: 'sigmoid'
26
+ T: 1000
27
+ sampling_T: 50
28
+ eta: 1
29
+
30
+ cond_encoder:
31
+ embed_dim: 512
32
+ cond_dim: 768
33
+ cond_drop_prob: 0.2
34
+
35
+ unet:
36
+ dim: 192
37
+ dim_mults: [1, 2, 4, 8]
38
+ cond_dim: 768
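
Each top-level section of this file maps directly onto a constructor: the models simply splat the corresponding dictionary into keyword arguments, so the keys must match the parameter names exactly. A minimal loading sketch (illustrative):

import yaml
from auto_encoder.models.encoder import Encoder

with open("configs/composite_config.yaml") as f:
    config = yaml.safe_load(f)
encoder = Encoder(**config["encoder"])    # keys map 1:1 to Encoder.__init__ keyword arguments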
diffusion_model/models/clip_latent_diffusion_model.py ADDED
@@ -0,0 +1,40 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder
5
+ from diffusion_model.models.latent_diffusion_model import LatentDiffusionModel
6
+ from clip.models.clip import CLIP
7
+
8
+ class CLIPLatentDiffusionModel(LatentDiffusionModel) :
9
+ def __init__(self, network : nn.Module, sampler : nn.Module,
10
+ auto_encoder : VariationalAutoEncoder, clip : CLIP, image_shape):
11
+ super().__init__(network, sampler, auto_encoder)  # LatentDiffusionModel takes no image_shape; it derives the latent shape from the auto-encoder
12
+ self.clip = clip
13
+ self.clip.eval()
14
+ for param in self.clip.parameters():
15
+ param.requires_grad = False
16
+
17
+ def loss(self, x0, text):
18
+ text = self.clip.text_encode(text, tokenize=False)
19
+ x0 = self.auto_encoder.encode(x0).sample()
20
+ eps = torch.randn_like(x0)
21
+ t = torch.randint(0, self.T, (x0.size(0),), device = x0.device)
22
+ x_t = self.sampler.q_sample(x0, t, eps)
23
+ eps_hat = self.network(x=x_t, t=t, y=text)
24
+ return self.weighted_loss(t, eps, eps_hat)
25
+
26
+ @torch.no_grad()
27
+ def forward(self, text, n_samples : int = 4):
28
+ text = self.clip.text_encode(text)
29
+ text = text.repeat(n_samples, 1)
30
+ x_T = torch.randn(n_samples, *self.image_shape, device = next(self.buffers(), None).device)  # image_shape holds the latent shape set by LatentDiffusionModel
31
+ sample = self.sampler(x_T = x_T, y=text)
32
+ return self.auto_encoder.decode(sample)
33
+
34
+ @torch.no_grad()
35
+ def generate_sequence(self, text, n_samples : int = 4):
36
+ text = self.clip.text_encode(text)
37
+ text = text.repeat(n_samples, 1)
38
+ x_T = torch.randn(n_samples, *self.image_shape, device = next(self.buffers(), None).device)
39
+ sample_sequence = self.sampler.reverse_process(x_T, y = text, only_last=False)
40
+ return sample_sequence
diffusion_model/models/diffusion_model.py ADDED
@@ -0,0 +1,42 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from einops import reduce
4
+
5
+ from helper.util import extract
6
+
7
+ class DiffusionModel(nn.Module) :
8
+ def __init__(self, network : nn.Module, sampler : nn.Module, image_shape):
9
+ super().__init__()
10
+ self.add_module('sampler', sampler)
11
+ self.add_module('network', network)
12
+ self.sampler.set_network(network)
13
+ self.T = sampler.T
14
+ self.image_shape = image_shape
15
+
16
+ # loss weight
17
+ alpha_bar = self.sampler.alpha_bar
18
+ snr = alpha_bar / (1 - alpha_bar)
19
+ clipped_snr = snr.clone()
20
+ clipped_snr.clamp_(max = 5)
21
+ self.register_buffer('loss_weight', clipped_snr / snr)
22
+
23
+ def weighted_loss(self, t, eps, eps_hat):
24
+ loss = nn.functional.mse_loss(eps, eps_hat, reduction='none')
25
+ loss = reduce(loss, 'b ... -> b', 'mean')
26
+ loss = loss * extract(self.loss_weight, t, loss.shape)
27
+ return loss.mean()
28
+
29
+ def loss(self, x0, **kwargs):
30
+ eps = torch.randn_like(x0)
31
+ t = torch.randint(0, self.T, (x0.size(0),), device = x0.device)
32
+ x_t = self.sampler.q_sample(x0, t, eps)
33
+ eps_hat = self.network(x = x_t, t = t, **kwargs)
34
+ return self.weighted_loss(t, eps, eps_hat)
35
+
36
+ @torch.no_grad()
37
+ def forward(self, n_samples: int = 4, only_last: bool = True, gamma = None, **kwargs):
38
+ """
39
+ If only_last is False, the output is the full sequence of generated samples rather than only the final one
40
+ """
41
+ x_T = torch.randn(n_samples, *self.image_shape, device = next(self.buffers(), None).device)
42
+ return self.sampler(x_T = x_T, only_last=only_last, gamma = gamma, **kwargs)
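
A minimal training-side wiring sketch (illustrative, not part of the commit): the sampler provides the noise schedule and q_sample, the network predicts the noise, and weighted_loss applies the clipped-SNR weighting. The dimensions below are made up for a small unconditional model; sampling with a bare Unet would additionally need scalar timesteps broadcast to a batch, which is what UnetWrapper (later in this commit) handles.

import torch
from diffusion_model.network.unet import Unet
from diffusion_model.models.diffusion_model import DiffusionModel
from diffusion_model.sampler.ddpm import DDPM

sampler = DDPM("configs/composite_config.yaml")            # reads the sampler: section (beta schedule, T)
network = Unet(dim=64, dim_mults=(1, 2, 4), channels=3)    # unconditional: cond_dim left unset
model = DiffusionModel(network, sampler, image_shape=(3, 32, 32))

x0 = torch.randn(8, 3, 32, 32)                             # stand-in for a normalized image batch
print(model.loss(x0).item())                               # SNR-weighted noise-prediction loss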
diffusion_model/models/latent_diffusion_model.py ADDED
@@ -0,0 +1,37 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder
5
+ from diffusion_model.models.diffusion_model import DiffusionModel
6
+
7
+ class LatentDiffusionModel(DiffusionModel) :
8
+ def __init__(self, network : nn.Module, sampler : nn.Module, auto_encoder : VariationalAutoEncoder):
9
+ super().__init__(network, sampler, None)
10
+ self.auto_encoder = auto_encoder
11
+ self.auto_encoder.eval()
12
+ for param in self.auto_encoder.parameters():
13
+ param.requires_grad = False
14
+ # The image shape is the latent shape
15
+ self.image_shape = [*self.auto_encoder.decoder.z_shape[1:]]
16
+ self.image_shape[0] = self.auto_encoder.embed_dim
17
+
18
+ def loss(self, x0, **kwargs):
19
+ x0 = self.auto_encoder.encode(x0).sample()
20
+ eps = torch.randn_like(x0)
21
+ t = torch.randint(0, self.T, (x0.size(0),), device = x0.device)
22
+ x_t = self.sampler.q_sample(x0, t, eps)
23
+ eps_hat = self.network(x = x_t, t = t, **kwargs)
24
+ return self.weighted_loss(t, eps, eps_hat)
25
+
26
+ # The forward function outputs the generated latents
27
+ # Therefore, sample() should be used for sampling data, not latents
28
+ @torch.no_grad()
29
+ def sample(self, n_samples: int = 4, gamma = None, **kwargs):
30
+ sample = self(n_samples, gamma=gamma, **kwargs)
31
+ return self.auto_encoder.decode(sample)
32
+
33
+ @torch.no_grad()
34
+ def generate_sequence(self, n_samples: int = 4, gamma = None, **kwargs):
35
+ sequence = self(n_samples, only_last=False, gamma = gamma, **kwargs)
36
+ sample = self.auto_encoder.decode(sequence[-1])
37
+ return sequence, sample
diffusion_model/network/attention.py ADDED
@@ -0,0 +1,187 @@
1
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/attend.py
2
+ from functools import wraps
3
+ from packaging import version
4
+ from collections import namedtuple
5
+
6
+ import torch
7
+ from torch import nn, einsum
8
+ import torch.nn.functional as F
9
+
10
+ from einops import rearrange, repeat
11
+ from functools import partial
12
+
13
+ # constants
14
+
15
+ AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
16
+
17
+ # helpers
18
+
19
+ def exists(val):
20
+ return val is not None
21
+
22
+ def default(val, d):
23
+ return val if exists(val) else d
24
+
25
+ def once(fn):
26
+ called = False
27
+ @wraps(fn)
28
+ def inner(x):
29
+ nonlocal called
30
+ if called:
31
+ return
32
+ called = True
33
+ return fn(x)
34
+ return inner
35
+
36
+ print_once = once(print)
37
+
38
+ class RMSNorm(nn.Module):
39
+ def __init__(self, dim):
40
+ super().__init__()
41
+ self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
42
+
43
+ def forward(self, x):
44
+ return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5)
45
+
46
+ # main class
47
+
48
+ class Attend(nn.Module):
49
+ def __init__(
50
+ self,
51
+ dropout = 0.,
52
+ flash = False,
53
+ scale = None
54
+ ):
55
+ super().__init__()
56
+ self.dropout = dropout
57
+ self.scale = scale
58
+ self.attn_dropout = nn.Dropout(dropout)
59
+
60
+ self.flash = flash
61
+ assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
62
+
63
+ # determine efficient attention configs for cuda and cpu
64
+
65
+ self.cpu_config = AttentionConfig(True, True, True)
66
+ self.cuda_config = None
67
+
68
+ if not torch.cuda.is_available() or not flash:
69
+ return
70
+
71
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
72
+
73
+ device_version = version.parse(f'{device_properties.major}.{device_properties.minor}')
74
+
75
+ if device_version > version.parse('8.0'):
76
+ print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
77
+ self.cuda_config = AttentionConfig(True, False, False)
78
+ else:
79
+ print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
80
+ self.cuda_config = AttentionConfig(False, True, True)
81
+
82
+ def flash_attn(self, q, k, v):
83
+ _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
84
+
85
+ if exists(self.scale):
86
+ default_scale = q.shape[-1] ** -0.5  # SDPA's default softmax scale is 1 / sqrt(head_dim)
87
+ q = q * (self.scale / default_scale)
88
+
89
+ q, k, v = map(lambda t: t.contiguous(), (q, k, v))
90
+
91
+ # Check if there is a compatible device for flash attention
92
+
93
+ config = self.cuda_config if is_cuda else self.cpu_config
94
+
95
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
96
+
97
+ with torch.backends.cuda.sdp_kernel(**config._asdict()):
98
+ out = F.scaled_dot_product_attention(
99
+ q, k, v,
100
+ dropout_p = self.dropout if self.training else 0.
101
+ )
102
+
103
+ return out
104
+
105
+ def forward(self, q, k, v):
106
+ """
107
+ einstein notation
108
+ b - batch
109
+ h - heads
110
+ n, i, j - sequence length (base sequence length, source, target)
111
+ d - feature dimension
112
+ """
113
+
114
+ q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
115
+
116
+ if self.flash:
117
+ return self.flash_attn(q, k, v)
118
+
119
+ scale = default(self.scale, q.shape[-1] ** -0.5)
120
+
121
+ # similarity
122
+
123
+ sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
124
+
125
+ # attention
126
+
127
+ attn = sim.softmax(dim = -1)
128
+ attn = self.attn_dropout(attn)
129
+
130
+ # aggregate values
131
+
132
+ out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
133
+
134
+ return out
135
+
136
+ class LinearAttention(nn.Module):
137
+ def __init__(self, dim, heads = 4, dim_head = 32):
138
+ super().__init__()
139
+ self.scale = dim_head ** -0.5
140
+ self.heads = heads
141
+ hidden_dim = dim_head * heads
142
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
143
+
144
+ self.to_out = nn.Sequential(
145
+ nn.Conv2d(hidden_dim, dim, 1),
146
+ RMSNorm(dim)
147
+ )
148
+
149
+ def forward(self, x):
150
+ b, c, h, w = x.shape
151
+ qkv = self.to_qkv(x).chunk(3, dim = 1)
152
+ q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
153
+
154
+ q = q.softmax(dim = -2)
155
+ k = k.softmax(dim = -1)
156
+
157
+ q = q * self.scale
158
+
159
+ context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
160
+
161
+ out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
162
+ out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
163
+ return self.to_out(out)
164
+
165
+ class Attention(nn.Module):
166
+ def __init__(self, dim, heads = 4, dim_head = 32):
167
+ super().__init__()
168
+ self.scale = dim_head ** -0.5
169
+ self.heads = heads
170
+ hidden_dim = dim_head * heads
171
+
172
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
173
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
174
+
175
+ def forward(self, x):
176
+ b, c, h, w = x.shape
177
+ qkv = self.to_qkv(x).chunk(3, dim = 1)
178
+ q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
179
+
180
+ q = q * self.scale
181
+
182
+ sim = einsum('b h d i, b h d j -> b h i j', q, k)
183
+ attn = sim.softmax(dim = -1)
184
+ out = einsum('b h i j, b h d j -> b h i d', attn, v)
185
+
186
+ out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
187
+ return self.to_out(out)
diffusion_model/network/blocks.py ADDED
@@ -0,0 +1,89 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class AdaptiveGroupNorm(nn.Module):
5
+ def __init__(self, num_groups, num_channels, emb_dim, eps=1e-5):
6
+ super().__init__()
7
+ self.num_groups = num_groups
8
+ self.num_channels = num_channels
9
+ self.eps = eps
10
+ # Use a standard GroupNorm, but without learnable affine parameters
11
+ self.norm = nn.GroupNorm(num_groups, num_channels, eps=eps, affine=False)
12
+
13
+ # Linear layers to project the embedding to gamma and beta
14
+ self.gamma_proj = nn.Linear(emb_dim, num_channels)
15
+ self.beta_proj = nn.Linear(emb_dim, num_channels)
16
+
17
+ def forward(self, x, emb):
18
+ """
19
+ Args:
20
+ x: Input tensor of shape [B, C, H, W].
21
+ emb: Embedding tensor of shape [B, emb_dim].
22
+
23
+ Returns:
24
+ Normalized tensor with adaptive scaling and shifting.
25
+ """
26
+ # Normalize as usual with GroupNorm
27
+ normalized = self.norm(x)
28
+
29
+ # Get gamma and beta from the embedding
30
+ gamma = self.gamma_proj(emb)
31
+ beta = self.beta_proj(emb)
32
+
33
+ # Reshape for broadcasting: [B, C] -> [B, C, 1, 1]
34
+ gamma = gamma.view(-1, self.num_channels, 1, 1)
35
+ beta = beta.view(-1, self.num_channels, 1, 1)
36
+
37
+ # Apply adaptive scaling and shifting
38
+ return gamma * normalized + beta
39
+
40
+ class DepthwiseSeparableConv2d(nn.Module):
41
+ def __init__(self, dim_in, dim_out, kernel_size, padding):
42
+ super().__init__()
43
+ self.depthwise = nn.Conv2d(dim_in, dim_in, kernel_size, padding=padding, groups=dim_in)
44
+ self.pointwise = nn.Conv2d(dim_in, dim_out, 1) # 1x1 convolution
45
+
46
+ def forward(self, x):
47
+ x = self.depthwise(x)
48
+ x = self.pointwise(x)
49
+ return x
50
+
51
+ class Block(nn.Module):
52
+ def __init__(self, dim, dim_out, groups, emb_dim, dropout=0.0, use_depthwise=False):
53
+ super().__init__()
54
+ self.norm = AdaptiveGroupNorm(groups, dim, emb_dim)
55
+ if use_depthwise:
56
+ self.proj = DepthwiseSeparableConv2d(dim, dim_out, kernel_size=3, padding=1)
57
+ else:
58
+ self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
59
+ self.act = nn.SiLU()
60
+ self.dropout = nn.Dropout(dropout)
61
+
62
+ def forward(self, x, emb):
63
+ x = self.norm(x, emb) # Pre-normalization
64
+ x = self.proj(x)
65
+ x = self.act(x)
66
+ return self.dropout(x)
67
+
68
+ class ResnetBlock(nn.Module):
69
+ def __init__(self, dim: int, dim_out: int, t_emb_dim: int, *,
70
+ y_emb_dim: int = None, groups: int = 32, dropout: float = 0.0, residual_scale=1.0):
71
+ super().__init__()
72
+ if y_emb_dim is None:
73
+ y_emb_dim = 0
74
+ emb_dim = t_emb_dim + y_emb_dim
75
+
76
+ self.block1 = Block(dim, dim_out, groups, emb_dim, dropout) # Pass emb_dim
77
+ self.block2 = Block(dim_out, dim_out, groups, emb_dim, dropout) # Pass emb_dim
78
+ self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
79
+ self.residual_scale = nn.Parameter(torch.tensor(residual_scale))
80
+
81
+ def forward(self, x, t_emb, y_emb=None):
82
+ cond_emb = t_emb
83
+ if y_emb is not None:
84
+ cond_emb = torch.cat([cond_emb, y_emb], dim=-1)
85
+
86
+ h = self.block1(x, cond_emb) # Pass combined embedding to Block
87
+ h = self.block2(h, cond_emb) # Pass combined embedding to Block
88
+
89
+ return self.residual_scale * h + self.res_conv(x) # Scale the residual
diffusion_model/network/timestep_embedding.py ADDED
@@ -0,0 +1,42 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ class SinusoidalEmbedding(nn.Module):
6
+ def __init__(self, embed_dim : int, theta : int = 10000):
7
+ """
8
+ Creates sinusoidal embeddings for timesteps.
9
+
10
+ Args:
11
+ embed_dim: The dimensionality of the embedding.
12
+ theta: The base for the log-spaced frequencies.
13
+ """
14
+ super().__init__()
15
+ self.embed_dim = embed_dim
16
+ self.theta = theta
17
+
18
+ def forward(self, x):
19
+ """
20
+ Computes sinusoidal embeddings for the input timesteps.
21
+
22
+ Args:
23
+ x: A 1D torch.Tensor of timesteps (shape: [batch_size]).
24
+
25
+ Returns:
26
+ A torch.Tensor of sinusoidal embeddings (shape: [batch_size, embed_dim]).
27
+ """
28
+ assert isinstance(x, torch.Tensor) # Input must be a torch.Tensor
29
+ assert x.ndim == 1 # Input must be a 1D tensor
30
+ assert isinstance(self.embed_dim, int) and self.embed_dim > 0 # embed_dim must be a positive integer
31
+
32
+ half_dim = self.embed_dim // 2
33
+ # Create a sequence of log-spaced frequencies
34
+ embeddings = math.log(self.theta) / (half_dim - 1)
35
+ embeddings = torch.exp(torch.arange(half_dim, device=x.device) * -embeddings)
36
+ # Outer product: timesteps x frequencies
37
+ embeddings = x[:, None] * embeddings[None, :]
38
+ embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
39
+ # Handle odd embedding dimensions
40
+ if self.embed_dim % 2 == 1:
41
+ embeddings = torch.cat([embeddings, torch.zeros_like(embeddings[:, :1])], dim=-1)
42
+ return embeddings
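
A quick shape check (illustrative): the module maps a batch of integer timesteps to fixed sinusoidal features of size embed_dim, which the U-Net's time MLP then projects.

import torch
from diffusion_model.network.timestep_embedding import SinusoidalEmbedding

emb = SinusoidalEmbedding(embed_dim=192)
t = torch.randint(0, 1000, (8,))     # one timestep per batch element
print(emb(t).shape)                  # torch.Size([8, 192])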
diffusion_model/network/unet.py ADDED
@@ -0,0 +1,217 @@
1
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
2
+
3
+ from functools import partial
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import Module, ModuleList
7
+ from diffusion_model.network.attention import LinearAttention, Attention
8
+ from diffusion_model.network.timestep_embedding import SinusoidalEmbedding
9
+ from diffusion_model.network.blocks import ResnetBlock
10
+
11
+ def exists(x):
12
+ return x is not None
13
+
14
+ def default(val, d):
15
+ if exists(val):
16
+ return val
17
+ return d() if callable(d) else d
18
+
19
+ def cast_tuple(t, length = 1):
20
+ if isinstance(t, tuple):
21
+ return t
22
+ return ((t,) * length)
23
+
24
+ def divisible_by(numer, denom):
25
+ return (numer % denom) == 0
26
+
27
+ # small helper modules
28
+
29
+ class DownSample(nn.Module):
30
+ def __init__(self, dim: int, dim_out: int):
31
+ """
32
+ Downsamples the spatial dimensions by a factor of 2 using a strided convolution.
33
+
34
+ Args:
35
+ dim: Input channel dimension.
36
+ """
37
+ super().__init__()
38
+ self.downsample = nn.Conv2d(dim, dim_out, kernel_size=4, stride=2, padding=1)
39
+
40
+ def forward(self, x: torch.tensor) -> torch.tensor:
41
+ """
42
+ Forward pass.
43
+
44
+ Args:
45
+ x: Input tensor of shape [B, C, H, W].
46
+
47
+ Returns:
48
+ Downsampled tensor of shape [B, C, H/2, W/2].
49
+ """
50
+ return self.downsample(x)
51
+
52
+ class UpSample(nn.Module):
53
+ def __init__(self, dim: int, dim_out: int):
54
+ """
55
+ Upsamples the spatial dimensions by a factor of 2 using a transposed convolution.
56
+
57
+ Args:
58
+ dim: Input channel dimension.
59
+ """
60
+ super().__init__()
61
+ self.upsample = nn.ConvTranspose2d(dim, dim_out, kernel_size=4, stride=2, padding=1)
62
+
63
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
64
+ """
65
+ Forward pass.
66
+
67
+ Args:
68
+ x: Input tensor of shape [B, C, H, W].
69
+
70
+ Returns:
71
+ Upsampled tensor of shape [B, C, 2*H, 2*W].
72
+ """
73
+ return self.upsample(x)
74
+
75
+
76
+ # model
77
+
78
+ class Unet(Module):
79
+ def __init__(
80
+ self,
81
+ dim,
82
+ init_dim = None,
83
+ out_dim = None,
84
+ cond_dim = None,
85
+ dim_mults = (1, 2, 4, 8),
86
+ channels = 3,
87
+ dropout = 0.,
88
+ attn_dim_head = 32,
89
+ attn_heads = 4,
90
+ full_attn = None, # defaults to full attention only for inner most layer
91
+ ):
92
+ super().__init__()
93
+
94
+ # determine dimensions
95
+
96
+ self.channels = channels
97
+ input_channels = channels
98
+
99
+ init_dim = default(init_dim, dim)
100
+ self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
101
+
102
+ dims = [*map(lambda m: dim * m, dim_mults)]
103
+ in_out = list(zip(dims[:-1], dims[1:]))
104
+
105
+ # time embeddings
106
+ time_dim = dim * 4
107
+
108
+ sinu_pos_emb = SinusoidalEmbedding(dim)
109
+
110
+ self.time_mlp = nn.Sequential(
111
+ sinu_pos_emb,
112
+ nn.Linear(dim, time_dim),
113
+ nn.GELU(),
114
+ nn.Linear(time_dim, time_dim)
115
+ )
116
+
117
+ # attention
118
+
119
+ if not full_attn:
120
+ full_attn = (*((False,) * (len(dim_mults) - 1)), True)
121
+
122
+ num_stages = len(dim_mults)
123
+ full_attn = cast_tuple(full_attn, num_stages)
124
+ attn_heads = cast_tuple(attn_heads, num_stages)
125
+ attn_dim_head = cast_tuple(attn_dim_head, num_stages)
126
+
127
+ assert len(full_attn) == len(dim_mults)
128
+
129
+ # prepare blocks
130
+
131
+ FullAttention = Attention
132
+ resnet_block = partial(ResnetBlock,
133
+ t_emb_dim = time_dim, y_emb_dim = cond_dim, dropout = dropout)
134
+
135
+ # layers
136
+
137
+ self.downs = ModuleList([])
138
+ self.ups = ModuleList([])
139
+ num_resolutions = len(in_out)
140
+
141
+ for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(in_out, full_attn, attn_heads, attn_dim_head)):
142
+ is_last = ind >= (num_resolutions - 1)
143
+
144
+ attn_klass = FullAttention if layer_full_attn else LinearAttention
145
+
146
+ self.downs.append(ModuleList([
147
+ resnet_block(dim_in, dim_in),
148
+ resnet_block(dim_in, dim_in),
149
+ attn_klass(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
150
+ DownSample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
151
+ ]))
152
+
153
+ mid_dim = dims[-1]
154
+ self.mid_block1 = resnet_block(mid_dim, mid_dim)
155
+ self.mid_attn = FullAttention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1])
156
+ self.mid_block2 = resnet_block(mid_dim, mid_dim)
157
+
158
+ for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))):
159
+ is_last = ind == (len(in_out) - 1)
160
+
161
+ attn_klass = FullAttention if layer_full_attn else LinearAttention
162
+
163
+ self.ups.append(ModuleList([
164
+ resnet_block(dim_out + dim_in, dim_out),
165
+ resnet_block(dim_out + dim_in, dim_out),
166
+ attn_klass(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
167
+ UpSample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
168
+ ]))
169
+
170
+ default_out_dim = channels
171
+ self.out_dim = default(out_dim, default_out_dim)
172
+
173
+ self.final_res_block = resnet_block(init_dim * 2, init_dim)
174
+ self.final_conv = nn.Conv2d(init_dim, self.out_dim, 1)
175
+
176
+ @property
177
+ def downsample_factor(self):
178
+ return 2 ** (len(self.downs) - 1)
179
+
180
+ def forward(self, x, t, y = None):
181
+ assert all([divisible_by(d, self.downsample_factor) for d in x.shape[-2:]]), f'your input dimensions {x.shape[-2:]} need to be divisible by {self.downsample_factor}, given the unet'
182
+
183
+ x = self.init_conv(x)
184
+ r = x.clone()
185
+
186
+ t = self.time_mlp(t)
187
+
188
+ h = []
189
+
190
+ for block1, block2, attn, downsample in self.downs:
191
+ x = block1(x, t, y)
192
+ h.append(x)
193
+
194
+ x = block2(x, t, y)
195
+ x = attn(x) + x
196
+ h.append(x)
197
+
198
+ x = downsample(x)
199
+
200
+ x = self.mid_block1(x, t, y)
201
+ x = self.mid_attn(x) + x
202
+ x = self.mid_block2(x, t, y)
203
+
204
+ for block1, block2, attn, upsample in self.ups:
205
+ x = torch.cat((x, h.pop()), dim = 1)
206
+ x = block1(x, t, y)
207
+
208
+ x = torch.cat((x, h.pop()), dim = 1)
209
+ x = block2(x, t, y)
210
+ x = attn(x) + x
211
+
212
+ x = upsample(x)
213
+
214
+ x = torch.cat((x, r), dim = 1)
215
+
216
+ x = self.final_res_block(x, t, y)
217
+ return self.final_conv(x)
diffusion_model/network/unet_wrapper.py ADDED
@@ -0,0 +1,32 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import yaml
4
+ import transformers
5
+
6
+ class UnetWrapper(nn.Module):
7
+ def __init__(self, Unet: nn.Module, config_path: str,
8
+ cond_encoder = None):
9
+ super().__init__()
10
+ with open(config_path, "r") as file:
11
+ config = yaml.safe_load(file)['unet']
12
+ self.add_module('network', Unet(**config))
13
+
14
+ # ConditionalEncoder
15
+ self.add_module('cond_encoder', cond_encoder)
16
+
17
+ def forward(self, x, t, y=None, cond_drop_all:bool = False):
18
+ if t.dim() == 0:
19
+ t = x.new_full((x.size(0), ), t, dtype = torch.int, device = x.device)
20
+ if y is not None:
21
+ assert self.cond_encoder is not None, 'You need to set ConditionalEncoder for conditional sampling.'
22
+ if isinstance(y, str) or isinstance(y, transformers.tokenization_utils_base.BatchEncoding):
23
+ y = self.cond_encoder(y, cond_drop_all=cond_drop_all).to(x.device)
24
+ else:
25
+ if torch.is_tensor(y) == False:
26
+ y = torch.tensor([y], device=x.device)
27
+ y = self.cond_encoder(y, cond_drop_all=cond_drop_all).squeeze()
28
+ if y.size(0) != x.size(0):
29
+ y = y.repeat(x.size(0), 1)
30
+ return self.network(x, t, y)
31
+ else:
32
+ return self.network(x, t)
diffusion_model/sampler/base_sampler.py ADDED
@@ -0,0 +1,69 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from tqdm import tqdm
4
+ import yaml
5
+
6
+ from helper.util import extract
7
+ from helper.beta_generator import BetaGenerator
8
+ from abc import ABC, abstractmethod
9
+
10
+ class BaseSampler(nn.Module, ABC):
11
+ def __init__(self, config_path : str):
12
+ super().__init__()
13
+ with open(config_path, "r") as file:
14
+ self.config = yaml.safe_load(file)['sampler']
15
+ self.T = self.config['T']
16
+ beta_generator = BetaGenerator(T=self.T)
17
+ self.timesteps = None
18
+
19
+ self.register_buffer('beta', getattr(beta_generator,
20
+ f"{self.config['beta']}_beta_schedule",
21
+ beta_generator.linear_beta_schedule)())
22
+
23
+ self.register_buffer('alpha', 1 - self.beta)
24
+ self.register_buffer('alpha_sqrt', self.alpha.sqrt())
25
+ self.register_buffer('alpha_bar', torch.cumprod(self.alpha, dim = 0))
26
+
27
+ @abstractmethod
28
+ @torch.no_grad()
29
+ def get_x_prev(self, x, idx, eps_hat):  # matches the (x, idx, eps_hat) call in p_sample and the subclass implementations
30
+ pass
31
+
32
+ def set_network(self, network : nn.Module):
33
+ self.network = network
34
+
35
+ def q_sample(self, x0, t, eps = None):
36
+ alpha_t_bar = extract(self.alpha_bar, t, x0.shape)
37
+ if eps is None:
38
+ eps = torch.randn_like(x0)
39
+ q_xt_x0 = alpha_t_bar.sqrt() * x0 + (1 - alpha_t_bar).sqrt() * eps
40
+ return q_xt_x0
41
+
42
+ @torch.no_grad()
43
+ def reverse_process(self, x_T, only_last=True, **kwargs):
44
+ x = x_T
45
+ if only_last:
46
+ for i, t in tqdm(enumerate(reversed(self.timesteps))):
47
+ idx = len(self.timesteps) - i - 1
48
+ x = self.p_sample(x, t, idx, **kwargs)
49
+ return x
50
+ else:
51
+ x_seq = []
52
+ x_seq.append(x)
53
+ for i, t in tqdm(enumerate(reversed(self.timesteps))):
54
+ idx = len(self.timesteps) - i - 1
55
+ x_seq.append(self.p_sample(x_seq[-1], t, idx, **kwargs))
56
+ return x_seq
57
+
58
+ @torch.no_grad()
59
+ def p_sample(self, x, t, idx, gamma = None, **kwargs):
60
+ eps_hat = self.network(x = x, t = t, **kwargs)
61
+ if gamma is not None:
62
+ eps_null = self.network(x = x, t = t, cond_drop_all=True, **kwargs)
63
+ eps_hat = gamma * eps_hat + (1 - gamma) * eps_null
64
+ x = self.get_x_prev(x, idx, eps_hat)
65
+ return x
66
+
67
+ @torch.no_grad()
68
+ def forward(self, x_T, **kwargs):
69
+ return self.reverse_process(x_T, **kwargs)
diffusion_model/sampler/ddim.py ADDED
@@ -0,0 +1,29 @@
1
+ import torch
2
+
3
+ from diffusion_model.sampler.base_sampler import BaseSampler
4
+
5
+ class DDIM(BaseSampler):
6
+ def __init__(self, config_path):
7
+ super().__init__(config_path)
8
+ self.sampling_T = self.config['sampling_T']
9
+ step = self.T // self.sampling_T
10
+ self.timesteps = torch.arange(0, self.T, step, dtype=torch.int)
11
+
12
+ self.ddim_alpha = self.alpha_bar[self.timesteps]
13
+ self.sqrt_one_minus_alpha_bar = (1. - self.ddim_alpha).sqrt()
14
+ self.alpha_bar_prev = torch.cat([self.ddim_alpha[0:1], self.ddim_alpha[:-1]])
15
+ self.sigma = (self.config['eta'] *
16
+ torch.sqrt((1-self.alpha_bar_prev) / (1-self.ddim_alpha) *
17
+ (1 - self.ddim_alpha / self.alpha_bar_prev)))
18
+
19
+ def get_x_prev(self, x, tau, eps_hat) :
20
+ alpha_prev = self.alpha_bar_prev[tau]
21
+ sigma = self.sigma[tau]
22
+
23
+ x0_hat = (x - self.sqrt_one_minus_alpha_bar[tau] * eps_hat) \
24
+ / (self.ddim_alpha[tau] ** 0.5)
25
+ dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * eps_hat
26
+ if sigma == 0. : noise = 0.
27
+ else : noise = torch.randn_like(x, device = x.device)
28
+ x = alpha_prev.sqrt() * x0_hat + dir_xt + sigma * noise
29
+ return x
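
For reference, the update implemented by get_x_prev above is the standard DDIM step (a restatement of the code, with the noise scale controlled by eta; eta = 0 gives deterministic sampling):

\[
x_{\tau_{i-1}} = \sqrt{\bar{\alpha}_{\tau_{i-1}}}\,\hat{x}_0
+ \sqrt{1 - \bar{\alpha}_{\tau_{i-1}} - \sigma_{\tau_i}^2}\,\hat{\epsilon}_\theta
+ \sigma_{\tau_i} z,
\qquad
\hat{x}_0 = \frac{x_{\tau_i} - \sqrt{1-\bar{\alpha}_{\tau_i}}\,\hat{\epsilon}_\theta}{\sqrt{\bar{\alpha}_{\tau_i}}},
\quad z \sim \mathcal{N}(0, I).
\]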
diffusion_model/sampler/ddpm.py ADDED
@@ -0,0 +1,20 @@
1
+ import torch
2
+
3
+ from diffusion_model.sampler.base_sampler import BaseSampler
4
+
5
+ class DDPM(BaseSampler):
6
+ def __init__(self, config_path):
7
+ super().__init__(config_path)
8
+ self.timesteps = torch.arange(0, self.T, dtype=torch.int)
9
+ self.sqrt_one_minus_alpha_bar = (1. - self.alpha_bar).sqrt()
10
+ self.alpha_bar_prev = torch.cat([self.alpha_bar[0:1], self.alpha_bar[:-1]])
11
+ self.sigma = (((1 - self.alpha_bar_prev) / (1 - self.alpha_bar)) * self.beta).sqrt()
12
+
13
+ @torch.no_grad()
14
+ def get_x_prev(self, x, t, eps_hat):
15
+ x = (1 / self.alpha_sqrt[t]) \
16
+ * (x - (self.beta[t] / self.sqrt_one_minus_alpha_bar[t] * eps_hat))
17
+ z = torch.randn_like(x) if t > 0 else 0.
18
+ x = x + self.sigma[t] * z
19
+ return x
20
+
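
Likewise, get_x_prev here is the DDPM ancestral step (restated from the code above):

\[
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\hat{\epsilon}_\theta(x_t, t)\right) + \sigma_t z,
\qquad
\sigma_t^2 = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t,
\quad z \sim \mathcal{N}(0, I)\ \text{for } t > 0.
\]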
helper/beta_generator.py ADDED
@@ -0,0 +1,47 @@
1
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L442
2
+ import torch
3
+ import math
4
+
5
+ class BetaGenerator():
6
+ def __init__(self, T) :
7
+ self.T = T
8
+
9
+ def fixed_beta_schedule(self, beta) :
10
+ betas = torch.Tensor.repeat(torch.Tensor([beta]) , self.T)
11
+ return betas
12
+
13
+ def linear_beta_schedule(self):
14
+ """
15
+ linear schedule, proposed in original ddpm paper
16
+ """
17
+ scale = 1000 / self.T
18
+ beta_start = scale * 0.0001
19
+ beta_end = scale * 0.02
20
+ return torch.linspace(beta_start, beta_end, self.T)
21
+
22
+ def cosine_beta_schedule(self, s = 0.008):
23
+ """
24
+ cosine schedule
25
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
26
+ """
27
+ steps = self.T + 1
28
+ t = torch.linspace(0, self.T, steps, dtype = torch.float32) / self.T
29
+ alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
30
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
31
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
32
+ return torch.clip(betas, 0, 0.999)
33
+
34
+ def sigmoid_beta_schedule(self, start = -3, end = 3, tau = 1):
35
+ """
36
+ sigmoid schedule
37
+ proposed in https://arxiv.org/abs/2212.11972 - Figure 8
38
+ better for images > 64x64, when used during training
39
+ """
40
+ steps = self.T + 1
41
+ t = torch.linspace(0, self.T, steps, dtype = torch.float32) / self.T
42
+ v_start = torch.tensor(start / tau).sigmoid()
43
+ v_end = torch.tensor(end / tau).sigmoid()
44
+ alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
45
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
46
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
47
+ return torch.clip(betas, 0, 0.999)
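
A short inspection sketch (illustrative): the sampler: section of composite_config.yaml selects 'sigmoid', which BaseSampler resolves to sigmoid_beta_schedule via getattr.

import torch
from helper.beta_generator import BetaGenerator

gen = BetaGenerator(T=1000)
betas = gen.sigmoid_beta_schedule()               # schedule named in the config
alpha_bar = torch.cumprod(1 - betas, dim=0)
print(betas.shape, betas[0].item(), alpha_bar[-1].item())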
helper/cond_encoder.py ADDED
@@ -0,0 +1,69 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import yaml
4
+
5
+ class BaseCondEncoder(nn.Module):
6
+ def __init__(
7
+ self,
8
+ config_path
9
+ ):
10
+ super().__init__()
11
+ with open(config_path, "r") as file:
12
+ self.config = yaml.safe_load(file)['cond_encoder']
13
+ self.embed_dim = self.config['embed_dim']
14
+ self.cond_dim = self.config['cond_dim']
15
+ if 'cond_drop_prob' in self.config:
16
+ self.cond_drop_prob = self.config['cond_drop_prob']
17
+ self.null_embedding = nn.Parameter(torch.randn(self.embed_dim))
18
+ else:
19
+ self.cond_drop_prob = 0.0
20
+
21
+ self.cond_mlp = nn.Sequential(
22
+ nn.Linear(self.embed_dim, self.cond_dim),
23
+ nn.GELU(),
24
+ nn.Linear(self.cond_dim, self.cond_dim)
25
+ )
26
+
27
+ def cond_drop(self, y: torch.tensor):
28
+ if self.training and self.cond_drop_prob > 0.0:
29
+ flags = torch.zeros((y.size(0), ), device=y.device).float().uniform_(0, 1) < self.cond_drop_prob
30
+ y[flags] = self.null_embedding.to(y.dtype)
31
+ return y
32
+
33
+ class CLIPEncoder(BaseCondEncoder):
34
+ def __init__(
35
+ self,
36
+ clip,
37
+ config_path
38
+ ):
39
+ super().__init__(config_path)
40
+ self.clip = clip
41
+ self.clip.eval()
42
+ for param in self.clip.parameters():
43
+ param.requires_grad = False
44
+
45
+ def forward(self, y, cond_drop_all:bool = False):
46
+ if isinstance(y, str):
47
+ y = self.clip.text_encode(y, tokenize=True)
48
+ else:
49
+ y = self.clip.text_encode(y, tokenize=False)
50
+ y = self.cond_drop(y) # Only training
51
+ if cond_drop_all:
52
+ y[:] = self.null_embedding
53
+ return self.cond_mlp(y)
54
+
55
+ class ClassEncoder(BaseCondEncoder):
56
+ def __init__(
57
+ self,
58
+ config_path
59
+ ):
60
+ super().__init__(config_path)
61
+ self.num_cond = self.config['num_cond']
62
+ self.embed = nn.Embedding(self.num_cond, self.embed_dim)
63
+
64
+ def forward(self, y, cond_drop_all:bool = False):
65
+ y = self.embed(y)
66
+ y = self.cond_drop(y) # Only training
67
+ if cond_drop_all:
68
+ y[:] = self.null_embedding
69
+ return self.cond_mlp(y)
helper/data_generator.py ADDED
@@ -0,0 +1,129 @@
1
+ from torchvision.datasets import CIFAR10, CelebA
2
+ from torch.utils.data import DataLoader, Dataset
3
+ from torchvision.transforms import Compose, ToTensor, Lambda, CenterCrop, Resize, RandomHorizontalFlip
4
+ import os
5
+ import torch
6
+ import json
7
+ from PIL import Image as im
8
+ from helper.tokenizer import Tokenizer
9
+ from transformers import AutoProcessor
10
+
11
+ def center_crop_and_resize(img, crop_size, resize_size):
12
+ width, height = img.size
13
+
14
+ # 1. Center Crop
15
+ left = (width - crop_size) / 2
16
+ top = (height - crop_size) / 2
17
+ right = (width + crop_size) / 2
18
+ bottom = (height + crop_size) / 2
19
+
20
+ img_cropped = img.crop((left, top, right, bottom))
21
+
22
+ # 2. Resize
23
+ img_resized = img_cropped.resize((resize_size, resize_size), im.Resampling.BICUBIC)
24
+
25
+ return img_resized
26
+
27
+ class UnlabelDataset(Dataset):
28
+ def __init__(self, path, transform):
29
+ self.path = path
30
+ self.file_list = os.listdir(path)
31
+ self.transform = transform
32
+
33
+ def __len__(self) :
34
+ return len(self.file_list)
35
+
36
+ def __getitem__(self, index):
37
+ img_path = self.path + self.file_list[index]
38
+ image = im.open(img_path)
39
+ image = self.transform(image)
40
+ return image
41
+
42
+ class CompositeDataset(Dataset):
43
+ def __init__(self, path, text_path, processor: AutoProcessor = None):
44
+ self.path = path
45
+ self.text_path = text_path
46
+ self.tokenizer = Tokenizer()
47
+ self.processor = processor
48
+
49
+ self.file_numbers = os.listdir(path)
50
+ self.file_numbers = [ os.path.splitext(filename)[0] for filename in self.file_numbers ]
51
+
52
+ self.transform = Compose([
53
+ ToTensor(),
54
+ CenterCrop(400),
55
+ Resize(256, antialias=True),
56
+ RandomHorizontalFlip(),
57
+ Lambda(lambda x: (x - 0.5) * 2)
58
+ ])
59
+
60
+ def __len__(self) :
61
+ return len(self.file_numbers)
62
+
63
+ def get_text(self, text_path):
64
+ with open(text_path, encoding = 'CP949') as f:
65
+ text = json.load(f)['description']['impression']['description']
66
+ return text
67
+
68
+ def __getitem__(self, idx) :
69
+ img_path = self.path + self.file_numbers[idx] + '.png'
70
+ text_path = self.text_path + self.file_numbers[idx] + '.json'
71
+ image = im.open(img_path)
72
+ text = self.get_text(text_path)
73
+ if self.processor is not None:
74
+ image = center_crop_and_resize(image, 400, 256)
75
+ inputs = self.processor(
76
+ text=text,
77
+ images=image,
78
+ return_tensors="pt",
79
+ padding='max_length',
80
+ max_length=77,
81
+ truncation=True,
82
+ )
83
+ for j in inputs:
84
+ inputs[j] = inputs[j].squeeze(0)
85
+ return inputs
86
+ else:
87
+ image = self.transform(image)
88
+ text = self.tokenizer.tokenize(text)
89
+ for j in text:
90
+ text[j] = text[j].squeeze(0)
91
+ return image, text
92
+
93
+ class DataGenerator():
94
+ def __init__(self, num_workers: int = 4, pin_memory: bool = True):
95
+ self.transform = Compose([
96
+ ToTensor(),
97
+ Lambda(lambda x: (x - 0.5) * 2)
98
+ ])
99
+ self.num_workers = num_workers
100
+ self.pin_memory = pin_memory
101
+
102
+ def cifar10(self, path = './datasets', batch_size : int = 64, train : bool = True):
103
+ train_data = CIFAR10(path, download = True, train = train, transform = self.transform)
104
+ dl = DataLoader(train_data, batch_size, shuffle = True, num_workers=self.num_workers, pin_memory=self.pin_memory)
105
+ return dl
106
+
107
+ def celeba(self, path = './datasets', batch_size : int = 16):
108
+ train_data = CelebA(path, transform = Compose([
109
+ ToTensor(),
110
+ CenterCrop(178),
111
+ Resize(128),
112
+ Lambda(lambda x: (x - 0.5) * 2)
113
+ ]))
114
+ dl = DataLoader(train_data, batch_size, shuffle = True, num_workers=self.num_workers, pin_memory=self.pin_memory)
115
+ return dl
116
+
117
+ def composite(self, path, text_path, batch_size : int = 16, is_process: bool = False):
118
+ processor = None
119
+ if is_process:
120
+ model_name = "Bingsu/clip-vit-base-patch32-ko"
121
+ processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
122
+ dataset = CompositeDataset(path, text_path, processor)
123
+ return DataLoader(dataset, batch_size=batch_size, shuffle=True,
124
+ num_workers=self.num_workers, pin_memory=self.pin_memory)
125
+
126
+ def random_data(self, size, batch_size : int = 4):
127
+ train_data = torch.randn(size)
128
+ return DataLoader(train_data, batch_size)
129
+
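A usage sketch for DataGenerator (not part of the diff); dataset locations and batch sizes are placeholders. The CIFAR-10 and CelebA loaders emit images rescaled to [-1, 1] by the Lambda((x - 0.5) * 2) transform.

import torch
from helper.data_generator import DataGenerator

dg = DataGenerator(num_workers=4, pin_memory=True)

train_dl = dg.cifar10(path="./datasets", batch_size=64, train=True)
images, labels = next(iter(train_dl))
print(images.shape, images.min().item(), images.max().item())  # (64, 3, 32, 32), ~-1.0, ~1.0

# Synthetic batches for quick smoke tests of a model's forward pass.
rand_dl = dg.random_data(size=(16, 3, 32, 32), batch_size=4)
for x in rand_dl:
    print(x.shape)  # (4, 3, 32, 32)
    break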
helper/ema.py ADDED
@@ -0,0 +1,375 @@
1
+ # https://github.com/lucidrains/ema-pytorch/tree/main
2
+
3
+ from __future__ import annotations
4
+ from typing import Callable
5
+
6
+ from copy import deepcopy
7
+ from functools import partial
8
+
9
+ import torch
10
+ from torch import nn, Tensor
11
+ from torch.nn import Module
12
+
13
+ def exists(val):
14
+ return val is not None
15
+
16
+ def divisible_by(num, den):
17
+ return (num % den) == 0
18
+
19
+ def get_module_device(m: Module):
20
+ return next(m.parameters()).device
21
+
22
+ def maybe_coerce_dtype(t, dtype):
23
+ if t.dtype == dtype:
24
+ return t
25
+
26
+ return t.to(dtype)
27
+
28
+ def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False, coerce_dtype = False):
29
+ if auto_move_device:
30
+ src = src.to(tgt.device)
31
+
32
+ if coerce_dtype:
33
+ src = maybe_coerce_dtype(src, tgt.dtype)
34
+
35
+ tgt.copy_(src)
36
+
37
+ def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False, coerce_dtype = False):
38
+ if auto_move_device:
39
+ src = src.to(tgt.device)
40
+
41
+ if coerce_dtype:
42
+ src = maybe_coerce_dtype(src, tgt.dtype)
43
+
44
+ tgt.lerp_(src, weight)
45
+
46
+ class EMA(Module):
47
+ """
48
+ Implements exponential moving average shadowing for your model.
49
+
50
+ Utilizes an inverse decay schedule to manage longer term training runs.
51
+ By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
52
+
53
+ @crowsonkb's notes on EMA Warmup:
54
+
55
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
56
+ good values for models you plan to train for a million or more steps (reaches decay
57
+ factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
58
+ you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
59
+ 215.4k steps).
60
+
61
+ Args:
62
+ inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
63
+ power (float): Exponential factor of EMA warmup. Default: 2/3.
64
+ min_value (float): The minimum EMA decay rate. Default: 0.
65
+ """
66
+
67
+ def __init__(
68
+ self,
69
+ model: Module,
70
+ ema_model: Module | Callable[[], Module] | None = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
71
+ beta = 0.9999,
72
+ update_after_step = 100,
73
+ update_every = 10,
74
+ inv_gamma = 1.0,
75
+ power = 2 / 3,
76
+ min_value = 0.0,
77
+ param_or_buffer_names_no_ema: set[str] = set(),
78
+ ignore_names: set[str] = set(),
79
+ ignore_startswith_names: set[str] = set(),
80
+ include_online_model = True, # set this to False if you do not wish for the online model to be saved along with the ema model (managed externally)
81
+ allow_different_devices = False, # if the EMA model is on a different device (say CPU), automatically move the tensor
82
+ use_foreach = False,
83
+ update_model_with_ema_every = None, # update the model with EMA model weights every number of steps, for better continual learning https://arxiv.org/abs/2406.02596
84
+ update_model_with_ema_beta = 0., # amount of model weight to keep when updating to EMA (hare to tortoise)
85
+ forward_method_names: tuple[str, ...] = (),
86
+ move_ema_to_online_device = False,
87
+ coerce_dtype = False,
88
+ lazy_init_ema = False,
89
+ ):
90
+ super().__init__()
91
+ self.beta = beta
92
+
93
+ self.is_frozen = beta == 1.
94
+
95
+ # whether to include the online model within the module tree, so that state_dict also saves it
96
+
97
+ self.include_online_model = include_online_model
98
+
99
+ if include_online_model:
100
+ self.online_model = model
101
+ else:
102
+ self.online_model = [model] # hack
103
+
104
+ # handle callable returning ema module
105
+
106
+ if not isinstance(ema_model, Module) and callable(ema_model):
107
+ ema_model = ema_model()
108
+
109
+ # ema model
110
+
111
+ self.ema_model = None
112
+ self.forward_method_names = forward_method_names
113
+
114
+ if not lazy_init_ema:
115
+ self.init_ema(ema_model)
116
+ else:
117
+ assert not exists(ema_model)
118
+
119
+ # tensor update functions
120
+
121
+ self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices, coerce_dtype = coerce_dtype)
122
+ self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices, coerce_dtype = coerce_dtype)
123
+
124
+ # updating hyperparameters
125
+
126
+ self.update_every = update_every
127
+ self.update_after_step = update_after_step
128
+
129
+ self.inv_gamma = inv_gamma
130
+ self.power = power
131
+ self.min_value = min_value
132
+
133
+ assert isinstance(param_or_buffer_names_no_ema, (set, list))
134
+ self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer
135
+
136
+ self.ignore_names = ignore_names
137
+ self.ignore_startswith_names = ignore_startswith_names
138
+
139
+ # continual learning related
140
+
141
+ self.update_model_with_ema_every = update_model_with_ema_every
142
+ self.update_model_with_ema_beta = update_model_with_ema_beta
143
+
144
+ # whether to manage if EMA model is kept on a different device
145
+
146
+ self.allow_different_devices = allow_different_devices
147
+
148
+ # whether to coerce dtype when copy or lerp from online to EMA model
149
+
150
+ self.coerce_dtype = coerce_dtype
151
+
152
+ # whether to move EMA model to online model device automatically
153
+
154
+ self.move_ema_to_online_device = move_ema_to_online_device
155
+
156
+ # whether to use foreach
157
+
158
+ if use_foreach:
159
+ assert hasattr(torch, '_foreach_lerp_') and hasattr(torch, '_foreach_copy_'), 'your version of torch does not have the prerequisite foreach functions'
160
+
161
+ self.use_foreach = use_foreach
162
+
163
+ # init and step states
164
+
165
+ self.register_buffer('initted', torch.tensor(False))
166
+ self.register_buffer('step', torch.tensor(0))
167
+
168
+ def init_ema(
169
+ self,
170
+ ema_model: Module | None = None
171
+ ):
172
+ self.ema_model = ema_model
173
+
174
+ if not exists(self.ema_model):
175
+ try:
176
+ self.ema_model = deepcopy(self.model)
177
+ except Exception as e:
178
+ print(f'Error: While trying to deepcopy model: {e}')
179
+ print('Your model was not copyable. Please make sure you are not using any LazyLinear')
180
+ exit()
181
+
182
+ for p in self.ema_model.parameters():
183
+ p.detach_()
184
+
185
+ # forwarding methods
186
+
187
+ for forward_method_name in self.forward_method_names:
188
+ fn = getattr(self.ema_model, forward_method_name)
189
+ setattr(self, forward_method_name, fn)
190
+
191
+ # parameter and buffer names
192
+
193
+ self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
194
+ self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}
195
+
196
+ def add_to_optimizer_post_step_hook(self, optimizer):
197
+ assert hasattr(optimizer, 'register_step_post_hook')
198
+
199
+ def hook(*_):
200
+ self.update()
201
+
202
+ return optimizer.register_step_post_hook(hook)
203
+
204
+ @property
205
+ def model(self):
206
+ return self.online_model if self.include_online_model else self.online_model[0]
207
+
208
+ def eval(self):
209
+ return self.ema_model.eval()
210
+
211
+ @torch.no_grad()
212
+ def forward_eval(self, *args, **kwargs):
213
+ # handy function for invoking ema model with no grad + eval
214
+ training = self.ema_model.training
215
+ out = self.ema_model(*args, **kwargs)
216
+ self.ema_model.train(training)
217
+ return out
218
+
219
+ def restore_ema_model_device(self):
220
+ device = self.initted.device
221
+ self.ema_model.to(device)
222
+
223
+ def get_params_iter(self, model):
224
+ for name, param in model.named_parameters():
225
+ if name not in self.parameter_names:
226
+ continue
227
+ yield name, param
228
+
229
+ def get_buffers_iter(self, model):
230
+ for name, buffer in model.named_buffers():
231
+ if name not in self.buffer_names:
232
+ continue
233
+ yield name, buffer
234
+
235
+ def copy_params_from_model_to_ema(self):
236
+ copy = self.inplace_copy
237
+
238
+ for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
239
+ copy(ma_params.data, current_params.data)
240
+
241
+ for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
242
+ copy(ma_buffers.data, current_buffers.data)
243
+
244
+ def copy_params_from_ema_to_model(self):
245
+ copy = self.inplace_copy
246
+
247
+ for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
248
+ copy(current_params.data, ma_params.data)
249
+
250
+ for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
251
+ copy(current_buffers.data, ma_buffers.data)
252
+
253
+ def update_model_with_ema(self, decay = None):
254
+ if not exists(decay):
255
+ decay = self.update_model_with_ema_beta
256
+
257
+ if decay == 0.:
258
+ return self.copy_params_from_ema_to_model()
259
+
260
+ self.update_moving_average(self.model, self.ema_model, decay)
261
+
262
+ def get_current_decay(self):
263
+ epoch = (self.step - self.update_after_step - 1).clamp(min = 0.)
264
+ value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
265
+
266
+ if epoch.item() <= 0:
267
+ return 0.
268
+
269
+ return value.clamp(min = self.min_value, max = self.beta).item()
270
+
271
+ def update(self):
272
+ step = self.step.item()
273
+ self.step += 1
274
+
275
+ if not self.initted.item():
276
+ if not exists(self.ema_model):
277
+ self.init_ema()
278
+
279
+ self.copy_params_from_model_to_ema()
280
+ self.initted.data.copy_(torch.tensor(True))
281
+ return
282
+
283
+ should_update = divisible_by(step, self.update_every)
284
+
285
+ if should_update and step <= self.update_after_step:
286
+ self.copy_params_from_model_to_ema()
287
+ return
288
+
289
+ if should_update:
290
+ self.update_moving_average(self.ema_model, self.model)
291
+
292
+ if exists(self.update_model_with_ema_every) and divisible_by(step, self.update_model_with_ema_every):
293
+ self.update_model_with_ema()
294
+
295
+ @torch.no_grad()
296
+ def update_moving_average(self, ma_model, current_model, current_decay = None):
297
+ if self.is_frozen:
298
+ return
299
+
300
+ # move ema model to online model device if not same and needed
301
+
302
+ if self.move_ema_to_online_device and get_module_device(ma_model) != get_module_device(current_model):
303
+ ma_model.to(get_module_device(current_model))
304
+
305
+ # get current decay
306
+
307
+ if not exists(current_decay):
308
+ current_decay = self.get_current_decay()
309
+
310
+ # store all source and target tensors to copy or lerp
311
+
312
+ tensors_to_copy = []
313
+ tensors_to_lerp = []
314
+
315
+ # loop through parameters
316
+
317
+ for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model), self.get_params_iter(ma_model)):
318
+ if name in self.ignore_names:
319
+ continue
320
+
321
+ if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
322
+ continue
323
+
324
+ if name in self.param_or_buffer_names_no_ema:
325
+ tensors_to_copy.append((ma_params.data, current_params.data))
326
+ continue
327
+
328
+ tensors_to_lerp.append((ma_params.data, current_params.data))
329
+
330
+ # loop through buffers
331
+
332
+ for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model), self.get_buffers_iter(ma_model)):
333
+ if name in self.ignore_names:
334
+ continue
335
+
336
+ if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
337
+ continue
338
+
339
+ if name in self.param_or_buffer_names_no_ema:
340
+ tensors_to_copy.append((ma_buffer.data, current_buffer.data))
341
+ continue
342
+
343
+ tensors_to_lerp.append((ma_buffer.data, current_buffer.data))
344
+
345
+ # execute inplace copy or lerp
346
+
347
+ if not self.use_foreach:
348
+
349
+ for tgt, src in tensors_to_copy:
350
+ self.inplace_copy(tgt, src)
351
+
352
+ for tgt, src in tensors_to_lerp:
353
+ self.inplace_lerp(tgt, src, 1. - current_decay)
354
+
355
+ else:
356
+ # use foreach if available and specified
357
+
358
+ if self.allow_different_devices:
359
+ tensors_to_copy = [(tgt, src.to(tgt.device)) for tgt, src in tensors_to_copy]
360
+ tensors_to_lerp = [(tgt, src.to(tgt.device)) for tgt, src in tensors_to_lerp]
361
+
362
+ if self.coerce_dtype:
363
+ tensors_to_copy = [(tgt, maybe_coerce_dtype(src, tgt.dtype)) for tgt, src in tensors_to_copy]
364
+ tensors_to_lerp = [(tgt, maybe_coerce_dtype(src, tgt.dtype)) for tgt, src in tensors_to_lerp]
365
+
366
+ if len(tensors_to_copy) > 0:
367
+ tgt_copy, src_copy = zip(*tensors_to_copy)
368
+ torch._foreach_copy_(tgt_copy, src_copy)
369
+
370
+ if len(tensors_to_lerp) > 0:
371
+ tgt_lerp, src_lerp = zip(*tensors_to_lerp)
372
+ torch._foreach_lerp_(tgt_lerp, src_lerp, 1. - current_decay)
373
+
374
+ def __call__(self, *args, **kwargs):
375
+ return self.ema_model(*args, **kwargs)
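A usage sketch for the EMA wrapper with its defaults (beta=0.9999, update_after_step=100, update_every=10); the tiny model and loss below are placeholders, not part of the diff.

import torch
from torch import nn
from helper.ema import EMA

model = nn.Linear(8, 8)                      # placeholder online model
ema = EMA(model, beta=0.9999, update_after_step=100, update_every=10)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

for _ in range(1000):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    ema.update()   # copies weights until update_after_step, then lerps with the warmed-up decay

# With inv_gamma=1 and power=2/3 the decay 1 - (1 + n)**(-2/3) reaches ~0.999 around
# n = 31.6K steps, since 31600**(2/3) is roughly 1000, matching the docstring's note.
samples = ema(torch.randn(4, 8))   # __call__ runs the EMA copy
ema_module = ema.ema_model         # or use the shadowed module directly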
helper/loader.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from helper.ema import EMA
4
+ from transformers import get_cosine_schedule_with_warmup
5
+
6
+ class Loader():
7
+ def __init__(self, device = None):
8
+ self.device = device
9
+
10
+ def print_model(self, check_point):
11
+ print("Epoch: " + str(check_point["epoch"]))
12
+ print("Training step: " + str(check_point["training_steps"]))
13
+ print("Best loss: " + str(check_point["best_loss"]))
14
+ print("Batch size: " + str(check_point["batch_size"]))
15
+ print("Number of batches: " + str(check_point["number_of_batches"]))
16
+
17
+ def model_load(self, file_name : str, model : nn.Module,
18
+ print_dict : bool = True, is_ema: bool = True):
19
+ check_point = torch.load(file_name + ".pth", map_location=self.device,
20
+ weights_only=True)
21
+ if print_dict: self.print_model(check_point)
22
+ if is_ema:
23
+ model = EMA(model)
24
+ model.load_state_dict(check_point['ema_state_dict'])
25
+ model = model.ema_model
26
+ else:
27
+ model.load_state_dict(check_point['model_state_dict'])
28
+ model.eval()
29
+ print("===Model loaded!===")
30
+ return model
31
+
32
+ def load_for_training(self, file_name: str, model: nn.Module, print_dict: bool = True):
33
+ check_point = torch.load(file_name + ".pth", map_location=self.device,
34
+ weights_only=True)
35
+ if print_dict: self.print_model(check_point)
36
+ model.load_state_dict(check_point['model_state_dict'])
37
+ model.train()
38
+ ema = EMA(model)
39
+ ema.load_state_dict(check_point['ema_state_dict'])
40
+ ema.train()
41
+ optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-4)
42
+ optimizer.load_state_dict(check_point["optimizer_state_dict"])
43
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
44
+ scheduler.load_state_dict(check_point["scheduler_state_dict"])
45
+ epoch = check_point["epoch"]
46
+ loss = check_point["best_loss"]
47
+ print("===Model/EMA/Optimizer/Scheduler/Epoch/Loss loaded!===")
48
+ return model, ema, optimizer, scheduler, epoch, loss
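A usage sketch for Loader (not part of the diff); the checkpoint path and the network being restored are placeholders. Both methods append ".pth" to the name they are given.

import torch
from helper.loader import Loader

loader = Loader(device="cuda" if torch.cuda.is_available() else "cpu")
unet = ...  # placeholder: the nn.Module whose weights the checkpoint holds

# Inference: is_ema=True unwraps the EMA weights and puts the module in eval mode.
model = loader.model_load("checkpoints/ddpm_cifar10", model=unet, is_ema=True)

# Resuming training: also restores the EMA wrapper, optimizer, scheduler, epoch and best loss.
model, ema, optimizer, scheduler, epoch, best_loss = loader.load_for_training(
    "checkpoints/ddpm_cifar10", model=unet)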
helper/painter.py ADDED
@@ -0,0 +1,51 @@
1
+ import io
2
+ from PIL import Image as im
3
+ from tqdm import tqdm
4
+ import matplotlib.pyplot as plt
5
+ import torch
6
+ import numpy as np
7
+
8
+ class Painter(object):
9
+ def __init__(self) :
10
+ pass
11
+
12
+ def show_images(self, images, title : str = '', index : bool = False, cmap = None, show = True):
13
+ if isinstance(images, torch.Tensor):
14
+ images = images.permute(0, 2, 3, 1)
15
+ images = images.detach().cpu().numpy()
16
+ images = np.clip(images / 2 + 0.5, 0, 1)
17
+
18
+ fig = plt.figure(figsize=(8, 8))
19
+ rows = int(len(images) ** (1 / 2))
20
+ cols = round(len(images) / rows)
21
+
22
+ idx = 0
23
+ for _ in range(rows):
24
+ for _ in range(cols):
25
+ fig.add_subplot(rows, cols, idx + 1)
26
+
27
+ if idx < len(images):
28
+ plt.imshow(images[idx], cmap = cmap)
29
+ if index :
30
+ plt.title(idx + 1)
31
+ plt.axis('off')
32
+ idx += 1
33
+ fig.suptitle(title, fontsize=30)
34
+ if show:
35
+ plt.show()
36
+
37
+ def show_first_batch(self, loader):
38
+ for batch in loader:
39
+ self.show_images(images = batch, title = "First Batch")
40
+ break
41
+
42
+ def make_gif(self, images, file_name):
43
+ imgs = []
44
+ for i in tqdm(range(len(images))):
45
+ img_buf = io.BytesIO()
46
+ self.show_images(images[i], title = 't = ' + str(i), show=False)
47
+ plt.savefig(img_buf, format='png')
48
+ imgs.append(im.open(img_buf))
49
+ imgs[0].save(file_name + '.gif', format='GIF', append_images=imgs, save_all=True, duration=1, loop=0)
50
+ plt.close('all')
51
+
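A usage sketch for Painter (not part of the diff); it expects (N, C, H, W) tensors in [-1, 1], matching the loaders above, and make_gif writes file_name + '.gif'. The random batch stands in for model samples.

import torch
from helper.painter import Painter

painter = Painter()
batch = torch.rand(16, 3, 32, 32) * 2 - 1      # stand-in for samples in [-1, 1]
painter.show_images(batch, title="Samples", index=True)

# A fake "denoising trajectory": a list of batches, e.g. intermediate x_t's from sampling.
trajectory = [batch + 0.1 * t * torch.randn_like(batch) for t in range(10)]
painter.make_gif(trajectory, "denoising_demo")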
helper/tokenizer.py ADDED
@@ -0,0 +1,9 @@
1
+ from transformers import AutoTokenizer
2
+
3
+ class Tokenizer:
4
+ def __init__(self, model_name="Bingsu/clip-vit-base-patch32-ko"):
5
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
6
+ self.vocab_size = self.tokenizer.vocab_size
7
+
8
+ def tokenize(self, text):
9
+ return self.tokenizer(text, padding='max_length', max_length=77, truncation=True, return_tensors='pt')
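A usage sketch for Tokenizer (not part of the diff); the example captions are made up, and the padding/truncation to 77 tokens matches CLIP's text context length.

from helper.tokenizer import Tokenizer

tokenizer = Tokenizer()  # defaults to "Bingsu/clip-vit-base-patch32-ko"
batch = tokenizer.tokenize(["a chest X-ray with no acute findings", "심장 비대가 관찰됨"])
print(batch["input_ids"].shape, batch["attention_mask"].shape)  # both (2, 77)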
helper/trainer.py ADDED
@@ -0,0 +1,99 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.data import DataLoader
4
+ from accelerate import Accelerator
5
+ from tqdm import tqdm
6
+ from typing import Callable
7
+ from helper.ema import EMA
8
+
9
+ class Trainer():
10
+ def __init__(self,
11
+ model: nn.Module,
12
+ loss_fn: Callable,
13
+ ema: EMA = None,
14
+ optimizer: torch.optim.Optimizer = None,
15
+ scheduler: torch.optim.lr_scheduler = None,
16
+ start_epoch = 0,
17
+ best_loss = float("inf"),
18
+ accumulation_steps: int = 1,
19
+ max_grad_norm: float = 1.0):
20
+ self.accelerator = Accelerator(mixed_precision = 'fp16', gradient_accumulation_steps=accumulation_steps)
21
+ self.model = model.to(self.accelerator.device)
22
+ if ema is None:
23
+ self.ema = EMA(self.model).to(self.accelerator.device)
24
+ else:
25
+ self.ema = ema.to(self.accelerator.device)
26
+ self.loss_fn = loss_fn
27
+ self.optimizer = optimizer
28
+ if self.optimizer is None:
29
+ self.optimizer = torch.optim.AdamW(self.model.parameters(), lr = 1e-4)
30
+ self.scheduler = scheduler
31
+ if self.scheduler is None:
32
+ self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=100)
33
+ self.start_epoch = start_epoch
34
+ self.best_loss = best_loss
35
+ self.accumulation_steps = accumulation_steps
36
+ self.max_grad_norm = max_grad_norm
37
+
38
+ def train(self, dl : DataLoader, epochs: int, file_name : str, no_label : bool = False):
39
+ self.model.train()
40
+ self.model, self.optimizer, data_loader, self.scheduler = self.accelerator.prepare(
41
+ self.model, self.optimizer, dl, self.scheduler
42
+ )
43
+
44
+ for epoch in range(self.start_epoch + 1, epochs + 1):
45
+ epoch_loss = 0.0
46
+ progress_bar = tqdm(data_loader, leave=False, desc=f"Epoch {epoch}/{epochs}", colour="#005500", disable = not self.accelerator.is_local_main_process)
47
+ for step, batch in enumerate(progress_bar):
48
+ with self.accelerator.accumulate(self.model): # Context manager for accumulation
49
+ if no_label:
50
+ if isinstance(batch, list):
51
+ x = batch[0].to(self.accelerator.device)
52
+ else:
53
+ x = batch.to(self.accelerator.device)
54
+ else:
55
+ x, y = batch[0].to(self.accelerator.device), batch[1].to(self.accelerator.device)
56
+
57
+ with self.accelerator.autocast():
58
+ if no_label:
59
+ loss = self.loss_fn(x)
60
+ else:
61
+ loss = self.loss_fn(x, y=y)
62
+
63
+ # Backward pass; Accelerate scales the loss for mixed precision and gradient accumulation
64
+ self.accelerator.backward(loss)
65
+
66
+ # Gradient Clipping:
67
+ if self.max_grad_norm is not None and self.accelerator.sync_gradients:
68
+ self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
69
+
70
+ # Optimizer step (under accelerator.accumulate, the step is skipped until gradients sync)
71
+ self.optimizer.step()
72
+ self.ema.update()
73
+ self.optimizer.zero_grad()
74
+
75
+ epoch_loss += loss.item()
76
+ progress_bar.set_postfix(loss=epoch_loss / (step + 1)) # running average loss over the epoch so far
77
+
78
+ self.accelerator.wait_for_everyone()
79
+ if self.accelerator.is_main_process:
80
+ epoch_loss = epoch_loss / len(progress_bar)
81
+ self.scheduler.step()
82
+ log_string = f"Loss at epoch {epoch}: {epoch_loss :.4f}"
83
+
84
+ # Save the best model
85
+ if self.best_loss > epoch_loss:
86
+ self.best_loss = epoch_loss
87
+ torch.save({
88
+ "model_state_dict": self.accelerator.get_state_dict(self.model),
89
+ "ema_state_dict": self.ema.state_dict(),
90
+ "optimizer_state_dict": self.optimizer.state_dict(),
91
+ "scheduler_state_dict": self.scheduler.state_dict(),
92
+ "epoch": epoch,
93
+ "training_steps": epoch * len(dl),
94
+ "best_loss": self.best_loss,
95
+ "batch_size": dl.batch_size,
96
+ "number_of_batches": len(dl)
97
+ }, file_name + '.pth')
98
+ log_string += " --> Best model ever (stored)"
99
+ print(log_string)
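A usage sketch that ties the helpers together (not part of the diff); the diffusion model and its loss_fn are placeholders for whatever module this repository trains, and loss_fn must return a scalar loss for a batch, as Trainer assumes above.

import torch.nn as nn
from helper.data_generator import DataGenerator
from helper.trainer import Trainer

model: nn.Module = ...   # placeholder: e.g. a DDPM/LDM wrapper with trainable parameters
loss_fn = model          # placeholder: callable returning a scalar loss, loss_fn(x) or loss_fn(x, y=y)

dl = DataGenerator().cifar10(batch_size=64)
trainer = Trainer(model, loss_fn, accumulation_steps=2, max_grad_norm=1.0)

# CIFAR-10 batches are (image, label) pairs; with no_label=True only the images reach loss_fn.
trainer.train(dl, epochs=100, file_name="checkpoints/ddpm_cifar10", no_label=True)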
helper/util.py ADDED
@@ -0,0 +1,4 @@
1
+ def extract(a, t, x_shape):
2
+ b, *_ = t.shape
3
+ out = a.gather(-1, t)
4
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
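extract is the usual DDPM helper for indexing a 1-D schedule by per-sample timesteps and broadcasting the result over image dimensions; a small sketch follows (not part of the diff; the linear betas are illustrative only).

import torch
from helper.util import extract

betas = torch.linspace(1e-4, 2e-2, 1000)   # a (T,) schedule, illustrative values
t = torch.randint(0, 1000, (8,))           # one timestep per batch element
x = torch.randn(8, 3, 32, 32)

coeff = extract(betas, t, x.shape)         # (8, 1, 1, 1): betas[t_i] reshaped to broadcast
scaled = coeff * x                         # multiplies each sample by its own coefficient
print(coeff.shape)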