Upload 35 files
- auto_encoder/components/distributions.py +43 -0
- auto_encoder/components/nonlinearity.py +5 -0
- auto_encoder/components/normalize.py +4 -0
- auto_encoder/components/resnet_block.py +46 -0
- auto_encoder/components/sampling.py +31 -0
- auto_encoder/models/auto_encoder.py +31 -0
- auto_encoder/models/decoder.py +78 -0
- auto_encoder/models/encoder.py +71 -0
- auto_encoder/models/variational_auto_encoder.py +45 -0
- clip/encoders/image_encoder.py +44 -0
- clip/encoders/text_encoder.py +29 -0
- clip/models/clip.py +70 -0
- clip/models/ko_clip.py +26 -0
- configs/composite_clip_config.yaml +15 -0
- configs/composite_config.yaml +38 -0
- diffusion_model/models/clip_latent_diffusion_model.py +40 -0
- diffusion_model/models/diffusion_model.py +42 -0
- diffusion_model/models/latent_diffusion_model.py +37 -0
- diffusion_model/network/attention.py +187 -0
- diffusion_model/network/blocks.py +89 -0
- diffusion_model/network/timestep_embedding.py +42 -0
- diffusion_model/network/unet.py +217 -0
- diffusion_model/network/unet_wrapper.py +32 -0
- diffusion_model/sampler/base_sampler.py +69 -0
- diffusion_model/sampler/ddim.py +29 -0
- diffusion_model/sampler/ddpm.py +20 -0
- helper/beta_generator.py +47 -0
- helper/cond_encoder.py +69 -0
- helper/data_generator.py +129 -0
- helper/ema.py +375 -0
- helper/loader.py +48 -0
- helper/painter.py +51 -0
- helper/tokenizer.py +9 -0
- helper/trainer.py +99 -0
- helper/util.py +4 -0
auto_encoder/components/distributions.py
ADDED
@@ -0,0 +1,43 @@
#source: https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/distributions/distributions.py
import torch
import numpy as np

class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                       + self.var - 1.0 - self.logvar,
                                       dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])

    def nll(self, sample, dims=[1,2,3]):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean
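For orientation, a minimal usage sketch (not part of the upload) of the posterior object the VAE encoder returns; the shapes assume 2 * z_channels = 6 moment channels:

import torch
from auto_encoder.components.distributions import DiagonalGaussianDistribution

params = torch.randn(4, 6, 32, 32)   # concatenated mean and log-variance, 3 channels each
posterior = DiagonalGaussianDistribution(params)
z = posterior.sample()               # reparameterized draw, shape [4, 3, 32, 32]
z_mode = posterior.mode()            # the mean, same shape
kl = posterior.kl()                  # per-sample KL to N(0, I), shape [4]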
auto_encoder/components/nonlinearity.py
ADDED
@@ -0,0 +1,5 @@
import torch

def nonlinearity(x):
    # swish
    return x*torch.sigmoid(x)
auto_encoder/components/normalize.py
ADDED
@@ -0,0 +1,4 @@
import torch

def Normalize(in_channels : int, num_groups : int = 32):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
auto_encoder/components/resnet_block.py
ADDED
@@ -0,0 +1,46 @@
# https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/model.py#L368
import torch
import torch.nn as nn

from auto_encoder.components.normalize import Normalize
from auto_encoder.components.nonlinearity import nonlinearity

class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels : int, out_channels : int = None, conv_shortcut=False, dropout):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x+h
auto_encoder/components/sampling.py
ADDED
@@ -0,0 +1,31 @@
import torch
import torch.nn as nn

class Upsample(nn.Module):
    def __init__(self, in_channels : int, with_conv : bool):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size = 3, stride = 1, padding = 1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor = 2.0, mode = "nearest")
        if self.with_conv:
            x = self.conv(x)
        return x

class Downsample(nn.Module):
    def __init__(self, in_channels : int, with_conv : bool):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size = 3, stride = 2, padding = 0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode = "constant", value = 0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size = 2, stride = 2)
        return x
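A quick shape check (illustrative only) for the two resampling modules:

import torch
from auto_encoder.components.sampling import Upsample, Downsample

x = torch.randn(1, 64, 32, 32)
print(Upsample(64, with_conv=True)(x).shape)    # torch.Size([1, 64, 64, 64])
print(Downsample(64, with_conv=True)(x).shape)  # torch.Size([1, 64, 16, 16]); the (0, 1, 0, 1) padding keeps the stride-2 conv aligned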
auto_encoder/models/auto_encoder.py
ADDED
@@ -0,0 +1,31 @@
import torch.nn as nn
import torch.nn.functional as F
from auto_encoder.models.decoder import Decoder
from auto_encoder.models.encoder import Encoder
import yaml

class AutoEncoder(nn.Module):
    def __init__(self, config_path : str):
        super().__init__()
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)
        self.add_module('encoder', Encoder(**config["encoder"]))
        self.add_module('decoder', Decoder(**config["decoder"]))

    def encode(self, x):
        h = self.encoder(x)
        return h

    def decode(self, z):
        z = self.decoder(z)
        return z

    def reconstruct(self, x):
        return self.decode(self.encode(x))

    def loss(self, x):
        x_hat = self(x)
        return F.mse_loss(x, x_hat)

    def forward(self, x):
        return self.reconstruct(x)
auto_encoder/models/decoder.py
ADDED
@@ -0,0 +1,78 @@
#source : https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/model.py#L368
import torch
import torch.nn as nn
import numpy as np
from auto_encoder.components.normalize import Normalize
from auto_encoder.components.resnet_block import ResnetBlock
from auto_encoder.components.sampling import Upsample
from auto_encoder.components.nonlinearity import nonlinearity

class Decoder(nn.Module):
    def __init__(self, *, in_channels, out_channels, resolution, channels, channel_multipliers = (1, 2, 4, 8), z_channels, num_res_blocks,
                 dropout = 0.0, resample_with_conv : bool = True):
        super().__init__()
        self.ch = channels
        self.num_resolutions = len(channel_multipliers)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels
        self.z_channels = z_channels

        in_ch_mult = (1 , ) + tuple(channel_multipliers)
        block_in = self.ch * in_ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1 , z_channels, curr_res, curr_res)
        print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size = 3, stride = 1, padding = 1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels = block_in, out_channels = block_in, dropout = dropout)
        self.mid.block_2 = ResnetBlock(in_channels = block_in, out_channels = block_in, dropout = dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            block_out = self.ch * channel_multipliers[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels = block_in, out_channels = block_out,
                                         dropout = dropout))
                block_in = block_out
            up = nn.Module()
            up.block = block
            if i_level != 0:
                up.upsample = Upsample(block_in, resample_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_channels,
                                        kernel_size = 3, stride = 1, padding = 1)

    def forward(self, z):
        assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
auto_encoder/models/encoder.py
ADDED
@@ -0,0 +1,71 @@
#source : https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/model.py#L368
import torch
import torch.nn as nn

from auto_encoder.components.normalize import Normalize
from auto_encoder.components.resnet_block import ResnetBlock
from auto_encoder.components.sampling import Downsample
from auto_encoder.components.nonlinearity import nonlinearity

class Encoder(nn.Module):
    def __init__(self, *, in_channels, resolution, channels, channel_multipliers = (1, 2, 4, 8), z_channels, num_res_blocks,
                 dropout = 0.0, resample_with_conv : bool = True, double_z : bool = True):
        super().__init__()
        self.ch = channels
        self.num_resolutions = len(channel_multipliers)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels
        self.z_channels = z_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size = 3, stride = 1, padding = 1)
        curr_res = resolution
        in_ch_mult = (1, ) + tuple(channel_multipliers)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            block_in = self.ch * in_ch_mult[i_level]
            block_out = self.ch * channel_multipliers[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels = block_in, out_channels = block_out, dropout = dropout))
                block_in = block_out
            down = nn.Module()
            down.block = block
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resample_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels = block_in, out_channels = block_in, dropout = dropout)
        self.mid.block_2 = ResnetBlock(in_channels = block_in, out_channels = block_in, dropout = dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels if double_z else z_channels,
                                        kernel_size = 3, stride = 1, padding = 1)

    def forward(self, x):
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
auto_encoder/models/variational_auto_encoder.py
ADDED
@@ -0,0 +1,45 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from auto_encoder.models.encoder import Encoder
from auto_encoder.models.decoder import Decoder
import yaml
from auto_encoder.components.distributions import DiagonalGaussianDistribution

class VariationalAutoEncoder(nn.Module):
    def __init__(self, config_path):
        super().__init__()
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)
        self.add_module('encoder', Encoder(**config["encoder"]))
        self.add_module('decoder', Decoder(**config["decoder"]))
        self.embed_dim = config['vae']['embed_dim']
        self.kld_weight = float(config['vae']['kld_weight'])

        self.quant_conv = torch.nn.Conv2d(self.decoder.z_channels, 2*self.embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, self.decoder.z_channels, 1)

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def loss(self, x):
        x_hat, posterior = self(x)
        return F.mse_loss(x, x_hat) + self.kld_weight * posterior.kl().mean()

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior
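A minimal usage sketch, assuming a config file shaped like configs/composite_config.yaml below (the path is illustrative):

import torch
from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder

vae = VariationalAutoEncoder("configs/composite_config.yaml")
x = torch.randn(2, 3, 256, 256)      # images at the configured resolution
x_hat, posterior = vae(x)            # reconstruction and the diagonal-Gaussian posterior
loss = vae.loss(x)                   # MSE reconstruction term + kld_weight * mean KL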
clip/encoders/image_encoder.py
ADDED
@@ -0,0 +1,44 @@
import torch
import torch.nn as nn

class ImageEncoder(nn.Module):
    def __init__(self, in_channels: int, resolution: int, patch_size: int,
                 number_of_features: int, number_of_heads:int, number_of_transformer_layers: int,
                 embed_dim: int):
        super().__init__()
        self.resolution = resolution
        self.embed_dim = embed_dim
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=number_of_features,
                              kernel_size=patch_size, stride=patch_size, bias=False)
        self.number_of_patches = (resolution // patch_size) ** 2
        self.positional_embedding = nn.Parameter(torch.zeros(1, self.number_of_patches + 1, number_of_features))
        self.class_embedding = nn.Parameter(torch.zeros(1, 1, number_of_features))

        self.ln_pre = nn.LayerNorm(number_of_features)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=number_of_features, nhead=number_of_heads, batch_first=True),
            num_layers=number_of_transformer_layers
        )

        self.ln_post = nn.LayerNorm(number_of_features)
        self.fc = nn.Linear(number_of_features, embed_dim)

        # initialize
        nn.init.kaiming_normal_(self.positional_embedding, nonlinearity='relu')
        nn.init.kaiming_normal_(self.class_embedding, nonlinearity='relu')
        nn.init.kaiming_normal_(self.fc.weight, nonlinearity='relu')

    def forward(self, x: torch.Tensor):
        x = self.conv(x)        # [batch_size, number_of_features, grid, grid]
        x = x.flatten(2)        # [batch_size, number_of_features, grid ** 2 = number_of_patches]
        x = x.transpose(1, 2)   # [batch_size, number_of_patches, number_of_features]

        class_embeddings = self.class_embedding.expand(x.shape[0], -1, -1)
        x = torch.cat([class_embeddings, x], dim=1)
        x = x + self.positional_embedding
        x = self.ln_pre(x)
        x = self.transformer(x)   # [batch_size, length_of_sequence, number_of_features]
        x = x.permute(1, 0, 2)    # [length_of_sequence, batch_size, number_of_features]
        x = self.ln_post(x[0])
        x = self.fc(x)            # [batch_size, embed_dim]
        return x
clip/encoders/text_encoder.py
ADDED
@@ -0,0 +1,29 @@
import torch
import torch.nn as nn

class TextEncoder(nn.Module):
    def __init__(self, number_of_features: int, number_of_heads: int, number_of_transformer_layers: int,
                 context_length, embed_dim):
        super().__init__()
        self.vocab_size = 32000   # AutoTokenizer: "koclip/koclip-base-pt"
        self.token_embedding = nn.Embedding(self.vocab_size, number_of_features)
        self.positional_embedding = nn.Parameter(torch.zeros(context_length, number_of_features))
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=number_of_features, nhead=number_of_heads, batch_first=True),
            num_layers=number_of_transformer_layers
        )
        self.text_projection = nn.Linear(number_of_features, embed_dim)

        # initialize
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.xavier_uniform_(self.positional_embedding)
        nn.init.kaiming_normal_(self.text_projection.weight, nonlinearity='relu')

    def forward(self, x):
        eot_token_idx = (x == 2).nonzero(as_tuple=True)[1]   # Assume EOT token ID is 2
        x = self.token_embedding(x)
        x = x + self.positional_embedding[:x.size(1), :]
        x = self.transformer(x)
        x = x[torch.arange(x.shape[0]), eot_token_idx]
        x = self.text_projection(x)
        return x
clip/models/clip.py
ADDED
@@ -0,0 +1,70 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
import numpy as np

from clip.encoders.image_encoder import ImageEncoder
from clip.encoders.text_encoder import TextEncoder
from helper.tokenizer import Tokenizer

class CLIP(nn.Module):
    def __init__(self, config_path):
        super().__init__()
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)

        self.image_encoder = ImageEncoder(**config["image_encoder"])
        self.text_encoder = TextEncoder(**config["text_encoder"])
        self.tokenizer = Tokenizer()
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        # initialize
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def loss(self, image, text):
        image_features, text_features = self(image, text, tokenize=False)

        # Normalize features
        image_features = F.normalize(image_features, dim=1)
        text_features = F.normalize(text_features, dim=1)

        # Cosine similarity as logits with learned temperature
        logits = torch.matmul(image_features, text_features.t()) * self.logit_scale.exp()
        labels = torch.arange(logits.shape[0], dtype=torch.long, device=logits.device)

        # Cross-entropy loss
        loss_i2t = F.cross_entropy(logits, labels)
        loss_t2i = F.cross_entropy(logits.t(), labels)

        return (loss_i2t + loss_t2i) / 2

    def text_encode(self, text, tokenize=True):
        if tokenize:
            tokens = self.tokenizer.tokenize(text)
        else:
            tokens = text
        text_features = self.text_encoder(tokens)
        if text_features.dim() < 2:
            text_features = text_features.unsqueeze(0)
        return text_features

    def forward(self, image, text, tokenize=True):
        image_features = self.image_encoder(image)
        text_features = self.text_encode(text, tokenize)  # go through text_encode, which handles optional tokenization; TextEncoder itself only accepts token ids

        if image_features.dim() < 2:
            image_features = image_features.unsqueeze(0)
        if text_features.dim() < 2:
            text_features = text_features.unsqueeze(0)

        return image_features, text_features
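For reference, the loss above is the standard symmetric contrastive objective: with L2-normalized image and text embeddings $\hat{I}_i$, $\hat{T}_j$ and learned logit scale $e^{\tau}$ (initialized to $1/0.07$),

$$S_{ij} = e^{\tau}\, \hat{I}_i \cdot \hat{T}_j, \qquad \mathcal{L} = \tfrac{1}{2}\left(\mathrm{CE}(S, y) + \mathrm{CE}(S^{\top}, y)\right), \quad y_i = i,$$

where CE is row-wise cross-entropy, so each image is pulled toward its paired caption and vice versa.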
clip/models/ko_clip.py
ADDED
@@ -0,0 +1,26 @@
import torch.nn as nn

from transformers import AutoModel, AutoTokenizer

class KoCLIPWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.model_name = "Bingsu/clip-vit-base-patch32-ko"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModel.from_pretrained(self.model_name)

    def loss(self, inputs):
        outputs = self(inputs)
        return outputs.loss

    def text_encode(self, text, tokenize):
        if tokenize:
            tokens = self.tokenizer(text, padding='max_length', max_length=77, truncation=True, return_tensors="pt")
        else:
            tokens = text
        tokens = tokens.to(self.model.device)
        return self.model.get_text_features(**tokens)

    def forward(self, inputs):
        outputs = self.model(**inputs, return_loss=True)
        return outputs   # [1, 512] , [1, 512]
configs/composite_clip_config.yaml
ADDED
@@ -0,0 +1,15 @@
text_encoder:
  number_of_features: 512
  number_of_heads: 8
  number_of_transformer_layers: 6
  context_length: 77
  embed_dim: 128

image_encoder:
  in_channels: 3
  resolution: 256
  patch_size: 16
  number_of_features: 768
  number_of_heads: 12
  number_of_transformer_layers: 4
  embed_dim: 128
configs/composite_config.yaml
ADDED
@@ -0,0 +1,38 @@
encoder:
  in_channels: 3
  resolution: 256
  channels: 128
  channel_multipliers: [1, 2, 4]
  z_channels: 3
  num_res_blocks: 2
  dropout: 0.0

decoder:
  in_channels: 3
  out_channels: 3
  resolution: 256
  channels: 128
  channel_multipliers: [1, 2, 4]
  z_channels: 6
  num_res_blocks: 2
  dropout: 0.0

vae:
  embed_dim: 3
  kld_weight: 1e-6

sampler:
  beta: 'sigmoid'
  T: 1000
  sampling_T: 50
  eta: 1

cond_encoder:
  embed_dim: 512
  cond_dim: 768
  cond_drop_prob: 0.2

unet:
  dim: 192
  dim_mults: [1, 2, 4, 8]
  cond_dim: 768
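A small sketch (illustrative) of how these sections are consumed; the encoder and decoder constructors are keyword-only, so the keys must match the signatures exactly:

import yaml
from auto_encoder.models.encoder import Encoder
from auto_encoder.models.decoder import Decoder

with open("configs/composite_config.yaml", "r") as f:
    config = yaml.safe_load(f)

encoder = Encoder(**config["encoder"])
decoder = Decoder(**config["decoder"])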
diffusion_model/models/clip_latent_diffusion_model.py
ADDED
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn

from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder
from diffusion_model.models.latent_diffusion_model import LatentDiffusionModel
from clip.models.clip import CLIP

class CLIPLatentDiffusionModel(LatentDiffusionModel) :
    def __init__(self, network : nn.Module, sampler : nn.Module,
                 auto_encoder : VariationalAutoEncoder, clip : CLIP, image_shape=None):
        super().__init__(network, sampler, auto_encoder)  # the parent derives the latent shape from the autoencoder, so image_shape is not forwarded
        self.clip = clip
        self.clip.eval()
        for param in self.clip.parameters():
            param.requires_grad = False

    def loss(self, x0, text):
        text = self.clip.text_encode(text, tokenize=False)
        x0 = self.auto_encoder.encode(x0).sample()
        eps = torch.randn_like(x0)
        t = torch.randint(0, self.T, (x0.size(0),), device = x0.device)
        x_t = self.sampler.q_sample(x0, t, eps)
        eps_hat = self.network(x=x_t, t=t, y=text)
        return self.weighted_loss(t, eps, eps_hat)

    @torch.no_grad()
    def forward(self, text, n_samples : int = 4):
        text = self.clip.text_encode(text)
        text = text.repeat(n_samples, 1)
        x_T = torch.randn(n_samples, *self.image_shape, device = next(self.buffers(), None).device)  # image_shape holds the latent shape set by the parent class
        sample = self.sampler(x_T = x_T, y=text)
        return self.auto_encoder.decode(sample)

    @torch.no_grad()
    def generate_sequence(self, text, n_samples : int = 4):
        text = self.clip.text_encode(text)
        text = text.repeat(n_samples, 1)
        x_T = torch.randn(n_samples, *self.image_shape, device = next(self.buffers(), None).device)
        sample_sequence = self.sampler.reverse_process(x_T, y = text, only_last=False)
        return sample_sequence
diffusion_model/models/diffusion_model.py
ADDED
@@ -0,0 +1,42 @@
import torch
import torch.nn as nn
from einops import reduce

from helper.util import extract

class DiffusionModel(nn.Module) :
    def __init__(self, network : nn.Module, sampler : nn.Module, image_shape):
        super().__init__()
        self.add_module('sampler', sampler)
        self.add_module('network', network)
        self.sampler.set_network(network)
        self.T = sampler.T
        self.image_shape = image_shape

        # loss weight
        alpha_bar = self.sampler.alpha_bar
        snr = alpha_bar / (1 - alpha_bar)
        clipped_snr = snr.clone()
        clipped_snr.clamp_(max = 5)
        self.register_buffer('loss_weight', clipped_snr / snr)

    def weighted_loss(self, t, eps, eps_hat):
        loss = nn.functional.mse_loss(eps, eps_hat, reduction='none')
        loss = reduce(loss, 'b ... -> b', 'mean')
        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    def loss(self, x0, **kwargs):
        eps = torch.randn_like(x0)
        t = torch.randint(0, self.T, (x0.size(0),), device = x0.device)
        x_t = self.sampler.q_sample(x0, t, eps)
        eps_hat = self.network(x = x_t, t = t, **kwargs)
        return self.weighted_loss(t, eps, eps_hat)

    @torch.no_grad()
    def forward(self, n_samples: int = 4, only_last: bool = True, gamma = None, **kwargs):
        """
        If only_last is False, the output will be the sequence of the generated points.
        """
        x_T = torch.randn(n_samples, *self.image_shape, device = next(self.buffers(), None).device)
        return self.sampler(x_T = x_T, only_last=only_last, gamma = gamma, **kwargs)
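The loss_weight buffer implements a min-SNR style weighting of the epsilon-prediction objective:

$$\mathrm{SNR}(t) = \frac{\bar\alpha_t}{1 - \bar\alpha_t}, \qquad w(t) = \frac{\min(\mathrm{SNR}(t),\, 5)}{\mathrm{SNR}(t)}, \qquad \mathcal{L} = \mathbb{E}_{t,\, x_0,\, \epsilon}\big[\, w(t)\, \lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2 \,\big],$$

i.e. the min-SNR-gamma weighting with gamma = 5, which down-weights low-noise timesteps where the SNR is large.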
diffusion_model/models/latent_diffusion_model.py
ADDED
@@ -0,0 +1,37 @@
import torch
import torch.nn as nn

from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder
from diffusion_model.models.diffusion_model import DiffusionModel

class LatentDiffusionModel(DiffusionModel) :
    def __init__(self, network : nn.Module, sampler : nn.Module, auto_encoder : VariationalAutoEncoder):
        super().__init__(network, sampler, None)
        self.auto_encoder = auto_encoder
        self.auto_encoder.eval()
        for param in self.auto_encoder.parameters():
            param.requires_grad = False
        # The image shape is the latent shape
        self.image_shape = [*self.auto_encoder.decoder.z_shape[1:]]
        self.image_shape[0] = self.auto_encoder.embed_dim

    def loss(self, x0, **kwargs):
        x0 = self.auto_encoder.encode(x0).sample()
        eps = torch.randn_like(x0)
        t = torch.randint(0, self.T, (x0.size(0),), device = x0.device)
        x_t = self.sampler.q_sample(x0, t, eps)
        eps_hat = self.network(x = x_t, t = t, **kwargs)
        return self.weighted_loss(t, eps, eps_hat)

    # The forward function outputs the generated latents
    # Therefore, sample() should be used for sampling data, not latents
    @torch.no_grad()
    def sample(self, n_samples: int = 4, gamma = None, **kwargs):
        sample = self(n_samples, gamma=gamma, **kwargs)
        return self.auto_encoder.decode(sample)

    @torch.no_grad()
    def generate_sequence(self, n_samples: int = 4, gamma = None, **kwargs):
        sequence = self(n_samples, only_last=False, gamma = gamma, **kwargs)
        sample = self.auto_encoder.decode(sequence[-1])
        return sequence, sample
diffusion_model/network/attention.py
ADDED
@@ -0,0 +1,187 @@
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/attend.py
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from functools import partial

# constants

AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)

class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5)

# main class

class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False,
        scale = None
    ):
        super().__init__()
        self.dropout = dropout
        self.scale = scale
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = AttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        device_version = version.parse(f'{device_properties.major}.{device_properties.minor}')

        if device_version > version.parse('8.0'):
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        if exists(self.scale):
            default_scale = q.shape[-1]
            q = q * (self.scale / default_scale)

        q, k, v = map(lambda t: t.contiguous(), (q, k, v))

        # Check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p = self.dropout if self.training else 0.
            )

        return out

    def forward(self, q, k, v):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        if self.flash:
            return self.flash_attn(q, k, v)

        scale = default(self.scale, q.shape[-1] ** -0.5)

        # similarity

        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

        return out

class LinearAttention(nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            RMSNorm(dim)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)

        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        q = q * self.scale

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        return self.to_out(out)

class Attention(nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)

        q = q * self.scale

        sim = einsum('b h d i, b h d j -> b h i j', q, k)
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)

        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return self.to_out(out)
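A quick shape check (illustrative) of the two attention variants; LinearAttention aggregates in O(N) over the pixel sequence, while Attention materializes the full O(N^2) similarity matrix:

import torch
from diffusion_model.network.attention import LinearAttention, Attention

x = torch.randn(2, 64, 16, 16)
print(LinearAttention(64)(x).shape)  # torch.Size([2, 64, 16, 16])
print(Attention(64)(x).shape)        # torch.Size([2, 64, 16, 16])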
diffusion_model/network/blocks.py
ADDED
@@ -0,0 +1,89 @@
import torch
import torch.nn as nn

class AdaptiveGroupNorm(nn.Module):
    def __init__(self, num_groups, num_channels, emb_dim, eps=1e-5):
        super().__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        # Use a standard GroupNorm, but without learnable affine parameters
        self.norm = nn.GroupNorm(num_groups, num_channels, eps=eps, affine=False)

        # Linear layers to project the embedding to gamma and beta
        self.gamma_proj = nn.Linear(emb_dim, num_channels)
        self.beta_proj = nn.Linear(emb_dim, num_channels)

    def forward(self, x, emb):
        """
        Args:
            x: Input tensor of shape [B, C, H, W].
            emb: Embedding tensor of shape [B, emb_dim].

        Returns:
            Normalized tensor with adaptive scaling and shifting.
        """
        # Normalize as usual with GroupNorm
        normalized = self.norm(x)

        # Get gamma and beta from the embedding
        gamma = self.gamma_proj(emb)
        beta = self.beta_proj(emb)

        # Reshape for broadcasting: [B, C] -> [B, C, 1, 1]
        gamma = gamma.view(-1, self.num_channels, 1, 1)
        beta = beta.view(-1, self.num_channels, 1, 1)

        # Apply adaptive scaling and shifting
        return gamma * normalized + beta

class DepthwiseSeparableConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding):
        super().__init__()
        self.depthwise = nn.Conv2d(dim_in, dim_in, kernel_size, padding=padding, groups=dim_in)
        self.pointwise = nn.Conv2d(dim_in, dim_out, 1)   # 1x1 convolution

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x

class Block(nn.Module):
    def __init__(self, dim, dim_out, groups, emb_dim, dropout=0.0, use_depthwise=False):
        super().__init__()
        self.norm = AdaptiveGroupNorm(groups, dim, emb_dim)
        if use_depthwise:
            self.proj = DepthwiseSeparableConv2d(dim, dim_out, kernel_size=3, padding=1)
        else:
            self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.act = nn.SiLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, emb):
        x = self.norm(x, emb)   # Pre-normalization
        x = self.proj(x)
        x = self.act(x)
        return self.dropout(x)

class ResnetBlock(nn.Module):
    def __init__(self, dim: int, dim_out: int, t_emb_dim: int, *,
                 y_emb_dim: int = None, groups: int = 32, dropout: float = 0.0, residual_scale=1.0):
        super().__init__()
        if y_emb_dim is None:
            y_emb_dim = 0
        emb_dim = t_emb_dim + y_emb_dim

        self.block1 = Block(dim, dim_out, groups, emb_dim, dropout)   # Pass emb_dim
        self.block2 = Block(dim_out, dim_out, groups, emb_dim, dropout)   # Pass emb_dim
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
        self.residual_scale = nn.Parameter(torch.tensor(residual_scale))

    def forward(self, x, t_emb, y_emb=None):
        cond_emb = t_emb
        if y_emb is not None:
            cond_emb = torch.cat([cond_emb, y_emb], dim=-1)

        h = self.block1(x, cond_emb)   # Pass combined embedding to Block
        h = self.block2(h, cond_emb)   # Pass combined embedding to Block

        return self.residual_scale * h + self.res_conv(x)   # Scale the residual
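A minimal sketch (not part of the upload) of the adaptive-normalization interface used by Block and ResnetBlock:

import torch
from diffusion_model.network.blocks import AdaptiveGroupNorm

norm = AdaptiveGroupNorm(num_groups=8, num_channels=32, emb_dim=128)
x = torch.randn(4, 32, 16, 16)   # feature map
emb = torch.randn(4, 128)        # e.g. a timestep (+ condition) embedding
out = norm(x, emb)               # same shape as x; per-channel gamma and beta are predicted from emb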
diffusion_model/network/timestep_embedding.py
ADDED
@@ -0,0 +1,42 @@
import torch
import torch.nn as nn
import math

class SinusoidalEmbedding(nn.Module):
    def __init__(self, embed_dim : int, theta : int = 10000):
        """
        Creates sinusoidal embeddings for timesteps.

        Args:
            embed_dim: The dimensionality of the embedding.
            theta: The base for the log-spaced frequencies.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.theta = theta

    def forward(self, x):
        """
        Computes sinusoidal embeddings for the input timesteps.

        Args:
            x: A 1D torch.Tensor of timesteps (shape: [batch_size]).

        Returns:
            A torch.Tensor of sinusoidal embeddings (shape: [batch_size, embed_dim]).
        """
        assert isinstance(x, torch.Tensor)   # Input must be a torch.Tensor
        assert x.ndim == 1   # Input must be a 1D tensor
        assert isinstance(self.embed_dim, int) and self.embed_dim > 0   # embed_dim must be a positive integer

        half_dim = self.embed_dim // 2
        # Create a sequence of log-spaced frequencies
        embeddings = math.log(self.theta) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=x.device) * -embeddings)
        # Outer product: timesteps x frequencies
        embeddings = x[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        # Handle odd embedding dimensions
        if self.embed_dim % 2 == 1:
            embeddings = torch.cat([embeddings, torch.zeros_like(embeddings[:, :1])], dim=-1)
        return embeddings
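A short usage sketch (illustrative) of the embedding's shape contract:

import torch
from diffusion_model.network.timestep_embedding import SinusoidalEmbedding

emb = SinusoidalEmbedding(embed_dim=64)
t = torch.randint(0, 1000, (8,))   # one integer timestep per batch element
print(emb(t).shape)                # torch.Size([8, 64]); first half sine terms, second half cosine terms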
diffusion_model/network/unet.py
ADDED
@@ -0,0 +1,217 @@
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

from functools import partial
import torch
from torch import nn
from torch.nn import Module, ModuleList
from diffusion_model.network.attention import LinearAttention, Attention
from diffusion_model.network.timestep_embedding import SinusoidalEmbedding
from diffusion_model.network.blocks import ResnetBlock

def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def cast_tuple(t, length = 1):
    if isinstance(t, tuple):
        return t
    return ((t,) * length)

def divisible_by(numer, denom):
    return (numer % denom) == 0

# small helper modules

class DownSample(nn.Module):
    def __init__(self, dim: int, dim_out: int):
        """
        Downsamples the spatial dimensions by a factor of 2 using a strided convolution.

        Args:
            dim: Input channel dimension.
        """
        super().__init__()
        self.downsample = nn.Conv2d(dim, dim_out, kernel_size=4, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape [B, C, H, W].

        Returns:
            Downsampled tensor of shape [B, C, H/2, W/2].
        """
        return self.downsample(x)

class UpSample(nn.Module):
    def __init__(self, dim: int, dim_out: int):
        """
        Upsamples the spatial dimensions by a factor of 2 using a transposed convolution.

        Args:
            dim: Input channel dimension.
        """
        super().__init__()
        self.upsample = nn.ConvTranspose2d(dim, dim_out, kernel_size=4, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape [B, C, H, W].

        Returns:
            Upsampled tensor of shape [B, C, 2*H, 2*W].
        """
        return self.upsample(x)


# model

class Unet(Module):
    def __init__(
        self,
        dim,
        init_dim = None,
        out_dim = None,
        cond_dim = None,
        dim_mults = (1, 2, 4, 8),
        channels = 3,
        dropout = 0.,
        attn_dim_head = 32,
        attn_heads = 4,
        full_attn = None,    # defaults to full attention only for inner most layer
    ):
        super().__init__()

        # determine dimensions

        self.channels = channels
        input_channels = channels

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)

        dims = [*map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # time embeddings
        time_dim = dim * 4

        sinu_pos_emb = SinusoidalEmbedding(dim)

        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # attention

        if not full_attn:
            full_attn = (*((False,) * (len(dim_mults) - 1)), True)

        num_stages = len(dim_mults)
        full_attn = cast_tuple(full_attn, num_stages)
        attn_heads = cast_tuple(attn_heads, num_stages)
        attn_dim_head = cast_tuple(attn_dim_head, num_stages)

        assert len(full_attn) == len(dim_mults)

        # prepare blocks

        FullAttention = Attention
        resnet_block = partial(ResnetBlock,
                               t_emb_dim = time_dim, y_emb_dim = cond_dim, dropout = dropout)

        # layers

        self.downs = ModuleList([])
        self.ups = ModuleList([])
        num_resolutions = len(in_out)

        for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(in_out, full_attn, attn_heads, attn_dim_head)):
            is_last = ind >= (num_resolutions - 1)

            attn_klass = FullAttention if layer_full_attn else LinearAttention

            self.downs.append(ModuleList([
                resnet_block(dim_in, dim_in),
                resnet_block(dim_in, dim_in),
                attn_klass(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
                DownSample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = resnet_block(mid_dim, mid_dim)
        self.mid_attn = FullAttention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1])
        self.mid_block2 = resnet_block(mid_dim, mid_dim)

        for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))):
            is_last = ind == (len(in_out) - 1)

            attn_klass = FullAttention if layer_full_attn else LinearAttention

            self.ups.append(ModuleList([
                resnet_block(dim_out + dim_in, dim_out),
                resnet_block(dim_out + dim_in, dim_out),
                attn_klass(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
                UpSample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
            ]))

        default_out_dim = channels
        self.out_dim = default(out_dim, default_out_dim)

        self.final_res_block = resnet_block(init_dim * 2, init_dim)
        self.final_conv = nn.Conv2d(init_dim, self.out_dim, 1)

    @property
    def downsample_factor(self):
        return 2 ** (len(self.downs) - 1)

    def forward(self, x, t, y = None):
        assert all([divisible_by(d, self.downsample_factor) for d in x.shape[-2:]]), f'your input dimensions {x.shape[-2:]} need to be divisible by {self.downsample_factor}, given the unet'

        x = self.init_conv(x)
        r = x.clone()

        t = self.time_mlp(t)

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t, y)
            h.append(x)

            x = block2(x, t, y)
            x = attn(x) + x
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t, y)
        x = self.mid_attn(x) + x
        x = self.mid_block2(x, t, y)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t, y)

            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t, y)
            x = attn(x) + x

            x = upsample(x)

        x = torch.cat((x, r), dim = 1)

        x = self.final_res_block(x, t, y)
        return self.final_conv(x)
diffusion_model/network/unet_wrapper.py
ADDED
@@ -0,0 +1,32 @@
import torch
import torch.nn as nn
import yaml
import transformers

class UnetWrapper(nn.Module):
    def __init__(self, Unet: nn.Module, config_path: str,
                 cond_encoder = None):
        super().__init__()
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)['unet']
        self.add_module('network', Unet(**config))

        # ConditionalEncoder
        self.add_module('cond_encoder', cond_encoder)

    def forward(self, x, t, y=None, cond_drop_all:bool = False):
        if t.dim() == 0:
            t = x.new_full((x.size(0), ), t, dtype = torch.int, device = x.device)
        if y is not None:
            assert self.cond_encoder is not None, 'You need to set ConditionalEncoder for conditional sampling.'
            if isinstance(y, str) or isinstance(y, transformers.tokenization_utils_base.BatchEncoding):
                y = self.cond_encoder(y, cond_drop_all=cond_drop_all).to(x.device)
            else:
                if torch.is_tensor(y) == False:
                    y = torch.tensor([y], device=x.device)
                y = self.cond_encoder(y, cond_drop_all=cond_drop_all).squeeze()
            if y.size(0) != x.size(0):
                y = y.repeat(x.size(0), 1)
            return self.network(x, t, y)
        else:
            return self.network(x, t)
diffusion_model/sampler/base_sampler.py
ADDED
@@ -0,0 +1,69 @@
import torch
import torch.nn as nn
from tqdm import tqdm
import yaml

from helper.util import extract
from helper.beta_generator import BetaGenerator
from abc import ABC, abstractmethod

class BaseSampler(nn.Module, ABC):
    def __init__(self, config_path : str):
        super().__init__()
        with open(config_path, "r") as file:
            self.config = yaml.safe_load(file)['sampler']
        self.T = self.config['T']
        beta_generator = BetaGenerator(T=self.T)
        self.timesteps = None

        self.register_buffer('beta', getattr(beta_generator,
                                             f"{self.config['beta']}_beta_schedule",
                                             beta_generator.linear_beta_schedule)())

        self.register_buffer('alpha', 1 - self.beta)
        self.register_buffer('alpha_sqrt', self.alpha.sqrt())
        self.register_buffer('alpha_bar', torch.cumprod(self.alpha, dim = 0))

    @abstractmethod
    @torch.no_grad()
    def get_x_prev(self, x, t, idx, eps_hat):
        pass

    def set_network(self, network : nn.Module):
        self.network = network

    def q_sample(self, x0, t, eps = None):
        alpha_t_bar = extract(self.alpha_bar, t, x0.shape)
        if eps is None:
            eps = torch.randn_like(x0)
        q_xt_x0 = alpha_t_bar.sqrt() * x0 + (1 - alpha_t_bar).sqrt() * eps
        return q_xt_x0

    @torch.no_grad()
    def reverse_process(self, x_T, only_last=True, **kwargs):
        x = x_T
        if only_last:
            for i, t in tqdm(enumerate(reversed(self.timesteps))):
                idx = len(self.timesteps) - i - 1
                x = self.p_sample(x, t, idx, **kwargs)
            return x
        else:
            x_seq = []
            x_seq.append(x)
            for i, t in tqdm(enumerate(reversed(self.timesteps))):
                idx = len(self.timesteps) - i - 1
                x_seq.append(self.p_sample(x_seq[-1], t, idx, **kwargs))
            return x_seq

    @torch.no_grad()
    def p_sample(self, x, t, idx, gamma = None, **kwargs):
        eps_hat = self.network(x = x, t = t, **kwargs)
        if gamma is not None:
            eps_null = self.network(x = x, t = t, cond_drop_all=True, **kwargs)
            eps_hat = gamma * eps_hat + (1 - gamma) * eps_null
        x = self.get_x_prev(x, idx, eps_hat)
        return x

    @torch.no_grad()
    def forward(self, x_T, **kwargs):
        return self.reverse_process(x_T, **kwargs)
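
Note (not part of the uploaded files): q_sample is the closed-form forward process x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, and p_sample optionally blends the conditional and null-conditioned predictions (classifier-free guidance, with gamma as the guidance weight). A small standalone sketch of the q_sample algebra, using only helpers from this upload:

# Sketch of the forward-noising formula used by BaseSampler.q_sample.
import torch
from helper.util import extract
from helper.beta_generator import BetaGenerator

T = 1000
beta = BetaGenerator(T=T).linear_beta_schedule()
alpha_bar = torch.cumprod(1 - beta, dim=0)

x0 = torch.randn(8, 3, 32, 32)
t = torch.randint(0, T, (8,))                        # one timestep per sample
eps = torch.randn_like(x0)

a_bar = extract(alpha_bar, t, x0.shape)              # shape (8, 1, 1, 1), broadcastable over images
x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps   # same computation as q_sample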
diffusion_model/sampler/ddim.py
ADDED
@@ -0,0 +1,29 @@
import torch

from diffusion_model.sampler.base_sampler import BaseSampler

class DDIM(BaseSampler):
    def __init__(self, config_path):
        super().__init__(config_path)
        self.sampling_T = self.config['sampling_T']
        step = self.T // self.sampling_T
        self.timesteps = torch.arange(0, self.T, step, dtype=torch.int)

        self.ddim_alpha = self.alpha_bar[self.timesteps]
        self.sqrt_one_minus_alpha_bar = (1. - self.ddim_alpha).sqrt()
        self.alpha_bar_prev = torch.cat([self.ddim_alpha[0:1], self.ddim_alpha[:-1]])
        self.sigma = (self.config['eta'] *
                      torch.sqrt((1-self.alpha_bar_prev) / (1-self.ddim_alpha) *
                                 (1 - self.ddim_alpha / self.alpha_bar_prev)))

    def get_x_prev(self, x, tau, eps_hat) :
        alpha_prev = self.alpha_bar_prev[tau]
        sigma = self.sigma[tau]

        x0_hat = (x - self.sqrt_one_minus_alpha_bar[tau] * eps_hat) \
                 / (self.ddim_alpha[tau] ** 0.5)
        dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * eps_hat
        if sigma == 0. : noise = 0.
        else : noise = torch.randn_like(x, device = x.device)
        x = alpha_prev.sqrt() * x0_hat + dir_xt + sigma * noise
        return x
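
Note (not part of the uploaded files): get_x_prev implements the DDIM step x_prev = sqrt(alpha_bar_prev) * x0_hat + sqrt(1 - alpha_bar_prev - sigma^2) * eps_hat + sigma * z, which is fully deterministic when eta = 0. A standalone numeric sketch of that step, with a random tensor standing in for the network prediction:

# Sketch of the DDIM update math with eta = 0 (deterministic sampling).
import torch
from helper.beta_generator import BetaGenerator

T, sampling_T, eta = 1000, 50, 0.0
beta = BetaGenerator(T=T).linear_beta_schedule()
alpha_bar = torch.cumprod(1 - beta, dim=0)

timesteps = torch.arange(0, T, T // sampling_T)              # sub-sampled steps, as in DDIM.__init__
ddim_alpha = alpha_bar[timesteps]
alpha_bar_prev = torch.cat([ddim_alpha[0:1], ddim_alpha[:-1]])
sigma = eta * torch.sqrt((1 - alpha_bar_prev) / (1 - ddim_alpha) * (1 - ddim_alpha / alpha_bar_prev))

x = torch.randn(1, 3, 32, 32)
eps_hat = torch.randn_like(x)                                # stands in for the network output
tau = len(timesteps) - 1                                     # last sub-sampled index
x0_hat = (x - (1 - ddim_alpha[tau]).sqrt() * eps_hat) / ddim_alpha[tau].sqrt()
x_prev = alpha_bar_prev[tau].sqrt() * x0_hat + (1 - alpha_bar_prev[tau] - sigma[tau] ** 2).sqrt() * eps_hat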
diffusion_model/sampler/ddpm.py
ADDED
@@ -0,0 +1,20 @@
import torch

from diffusion_model.sampler.base_sampler import BaseSampler

class DDPM(BaseSampler):
    def __init__(self, config_path):
        super().__init__(config_path)
        self.timesteps = torch.arange(0, self.T, dtype=torch.int)
        self.sqrt_one_minus_alpha_bar = (1. - self.alpha_bar).sqrt()
        self.alpha_bar_prev = torch.cat([self.alpha_bar[0:1], self.alpha_bar[:-1]])
        self.sigma = (((1 - self.alpha_bar_prev) / (1 - self.alpha_bar)) * self.beta).sqrt()

    @torch.no_grad()
    def get_x_prev(self, x, t, eps_hat):
        x = (1 / self.alpha_sqrt[t]) \
            * (x - (self.beta[t] / self.sqrt_one_minus_alpha_bar[t] * eps_hat))
        z = torch.randn_like(x) if t > 0 else 0.
        x = x + self.sigma[t] * z
        return x
helper/beta_generator.py
ADDED
@@ -0,0 +1,47 @@
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L442
import torch
import math

class BetaGenerator():
    def __init__(self, T) :
        self.T = T

    def fixed_beta_schedule(self, beta) :
        betas = torch.Tensor.repeat(torch.Tensor([beta]) , self.T)
        return betas

    def linear_beta_schedule(self):
        """
        linear schedule, proposed in original ddpm paper
        """
        scale = 1000 / self.T
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return torch.linspace(beta_start, beta_end, self.T)

    def cosine_beta_schedule(self, s = 0.008):
        """
        cosine schedule
        as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
        """
        steps = self.T + 1
        t = torch.linspace(0, self.T, steps, dtype = torch.float32) / self.T
        alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0, 0.999)

    def sigmoid_beta_schedule(self, start = -3, end = 3, tau = 1):
        """
        sigmoid schedule
        proposed in https://arxiv.org/abs/2212.11972 - Figure 8
        better for images > 64x64, when used during training
        """
        steps = self.T + 1
        t = torch.linspace(0, self.T, steps, dtype = torch.float32) / self.T
        v_start = torch.tensor(start / tau).sigmoid()
        v_end = torch.tensor(end / tau).sigmoid()
        alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0, 0.999)
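
Note (not part of the uploaded files): a tiny sketch comparing how fast the cumulative signal level alpha_bar decays under the linear and cosine schedules, which is the practical difference between them:

# Compare alpha_bar decay for two schedules at T = 1000.
import torch
from helper.beta_generator import BetaGenerator

gen = BetaGenerator(T=1000)
for name in ('linear', 'cosine'):
    beta = getattr(gen, f'{name}_beta_schedule')()
    alpha_bar = torch.cumprod(1 - beta, dim=0)
    print(name, 'alpha_bar at t=100/500/999:',
          [round(alpha_bar[i].item(), 4) for i in (100, 500, 999)])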
helper/cond_encoder.py
ADDED
@@ -0,0 +1,69 @@
import torch
import torch.nn as nn
import yaml

class BaseCondEncoder(nn.Module):
    def __init__(
        self,
        config_path
    ):
        super().__init__()
        with open(config_path, "r") as file:
            self.config = yaml.safe_load(file)['cond_encoder']
        self.embed_dim = self.config['embed_dim']
        self.cond_dim = self.config['cond_dim']
        if 'cond_drop_prob' in self.config:
            self.cond_drop_prob = self.config['cond_drop_prob']
            self.null_embedding = nn.Parameter(torch.randn(self.embed_dim))
        else:
            self.cond_drop_prob = 0.0

        self.cond_mlp = nn.Sequential(
            nn.Linear(self.embed_dim, self.cond_dim),
            nn.GELU(),
            nn.Linear(self.cond_dim, self.cond_dim)
        )

    def cond_drop(self, y: torch.tensor):
        if self.training and self.cond_drop_prob > 0.0:
            flags = torch.zeros((y.size(0), ), device=y.device).float().uniform_(0, 1) < self.cond_drop_prob
            y[flags] = self.null_embedding.to(y.dtype)
        return y

class CLIPEncoder(BaseCondEncoder):
    def __init__(
        self,
        clip,
        config_path
    ):
        super().__init__(config_path)
        self.clip = clip
        self.clip.eval()
        for param in self.clip.parameters():
            param.requires_grad = False

    def forward(self, y, cond_drop_all:bool = False):
        if isinstance(y, str):
            y = self.clip.text_encode(y, tokenize=True)
        else:
            y = self.clip.text_encode(y, tokenize=False)
        y = self.cond_drop(y) # Only training
        if cond_drop_all:
            y[:] = self.null_embedding
        return self.cond_mlp(y)

class ClassEncoder(BaseCondEncoder):
    def __init__(
        self,
        config_path
    ):
        super().__init__(config_path)
        self.num_cond = self.config['num_cond']
        self.embed = nn.Embedding(self.num_cond, self.embed_dim)

    def forward(self, y, cond_drop_all:bool = False):
        y = self.embed(y)
        y = self.cond_drop(y) # Only training
        if cond_drop_all:
            y[:] = self.null_embedding
        return self.cond_mlp(y)
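
Note (not part of the uploaded files): cond_drop is what enables classifier-free guidance later on, because the model is sometimes trained on a learned null embedding instead of the real condition. A standalone sketch of that masking logic with toy tensors (the dimensions and probability are illustrative):

# Sketch of the condition-dropout mask used by BaseCondEncoder.cond_drop.
import torch

embed_dim, cond_drop_prob = 16, 0.1
null_embedding = torch.randn(embed_dim)

y = torch.randn(8, embed_dim)                                    # batch of condition embeddings
flags = torch.zeros(8).float().uniform_(0, 1) < cond_drop_prob   # same mask construction as cond_drop
y[flags] = null_embedding.to(y.dtype)                            # dropped rows become the null embedding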
helper/data_generator.py
ADDED
@@ -0,0 +1,129 @@
from torchvision.datasets import CIFAR10, CelebA
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor, Lambda, CenterCrop, Resize, RandomHorizontalFlip
import os
import torch
import json
from PIL import Image as im
from helper.tokenizer import Tokenizer
from transformers import AutoProcessor

def center_crop_and_resize(img, crop_size, resize_size):
    width, height = img.size

    # 1. Center Crop
    left = (width - crop_size) / 2
    top = (height - crop_size) / 2
    right = (width + crop_size) / 2
    bottom = (height + crop_size) / 2

    img_cropped = img.crop((left, top, right, bottom))

    # 2. Resize
    img_resized = img_cropped.resize((resize_size, resize_size), im.Resampling.BICUBIC)

    return img_resized

class UnlabelDataset(Dataset):
    def __init__(self, path, transform):
        self.path = path
        self.file_list = os.listdir(path)
        self.transform = transform

    def __len__(self) :
        return len(self.file_list)

    def __getitem__(self, index):
        img_path = self.path + self.file_list[index]
        image = im.open(img_path)
        image = self.transform(image)
        return image

class CompositeDataset(Dataset):
    def __init__(self, path, text_path, processor: AutoProcessor = None):
        self.path = path
        self.text_path = text_path
        self.tokenizer = Tokenizer()
        self.processor = processor

        self.file_numbers = os.listdir(path)
        self.file_numbers = [ os.path.splitext(filename)[0] for filename in self.file_numbers ]

        self.transform = Compose([
            ToTensor(),
            CenterCrop(400),
            Resize(256, antialias=True),
            RandomHorizontalFlip(),
            Lambda(lambda x: (x - 0.5) * 2)
        ])

    def __len__(self) :
        return len(self.file_numbers)

    def get_text(self, text_path):
        with open(text_path, encoding = 'CP949') as f:
            text = json.load(f)['description']['impression']['description']
        return text

    def __getitem__(self, idx) :
        img_path = self.path + self.file_numbers[idx] + '.png'
        text_path = self.text_path + self.file_numbers[idx] + '.json'
        image = im.open(img_path)
        text = self.get_text(text_path)
        if self.processor is not None:
            image = center_crop_and_resize(image, 400, 256)
            inputs = self.processor(
                text=text,
                images=image,
                return_tensors="pt",
                padding='max_length',
                max_length=77,
                truncation=True,
            )
            for j in inputs:
                inputs[j] = inputs[j].squeeze(0)
            return inputs
        else:
            image = self.transform(image)
            text = self.tokenizer.tokenize(text)
            for j in text:
                text[j] = text[j].squeeze(0)
            return image, text

class DataGenerator():
    def __init__(self, num_workers: int = 4, pin_memory: bool = True):
        self.transform = Compose([
            ToTensor(),
            Lambda(lambda x: (x - 0.5) * 2)
        ])
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def cifar10(self, path = './datasets', batch_size : int = 64, train : bool = True):
        train_data = CIFAR10(path, download = True, train = train, transform = self.transform)
        dl = DataLoader(train_data, batch_size, shuffle = True, num_workers=self.num_workers, pin_memory=self.pin_memory)
        return dl

    def celeba(self, path = './datasets', batch_size : int = 16):
        train_data = CelebA(path, transform = Compose([
            ToTensor(),
            CenterCrop(178),
            Resize(128),
            Lambda(lambda x: (x - 0.5) * 2)
        ]))
        dl = DataLoader(train_data, batch_size, shuffle = True, num_workers=self.num_workers, pin_memory=self.pin_memory)
        return dl

    def composite(self, path, text_path, batch_size : int = 16, is_process: bool = False):
        processor = None
        if is_process:
            model_name = "Bingsu/clip-vit-base-patch32-ko"
            processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
        dataset = CompositeDataset(path, text_path, processor)
        return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=self.pin_memory)

    def random_data(self, size, batch_size : int = 4):
        train_data = torch.randn(size)
        return DataLoader(train_data, batch_size)
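
Usage sketch (reviewer note, not part of the uploaded files): building a CIFAR-10 loader (images scaled to [-1, 1] by the Lambda transform) and previewing one batch with helper.painter.Painter. The first call downloads CIFAR-10 into ./datasets.

# Hedged usage sketch combining DataGenerator and Painter.
from helper.data_generator import DataGenerator
from helper.painter import Painter

dg = DataGenerator(num_workers=2)
dl = dg.cifar10(path='./datasets', batch_size=16)   # CIFAR10 yields (image, label) pairs

painter = Painter()
for images, labels in dl:
    painter.show_images(images, title='CIFAR-10 sample')
    break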
helper/ema.py
ADDED
@@ -0,0 +1,375 @@
# https://github.com/lucidrains/ema-pytorch/tree/main

from __future__ import annotations
from typing import Callable

from copy import deepcopy
from functools import partial

import torch
from torch import nn, Tensor
from torch.nn import Module

def exists(val):
    return val is not None

def divisible_by(num, den):
    return (num % den) == 0

def get_module_device(m: Module):
    return next(m.parameters()).device

def maybe_coerce_dtype(t, dtype):
    if t.dtype == dtype:
        return t

    return t.to(dtype)

def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False, coerce_dtype = False):
    if auto_move_device:
        src = src.to(tgt.device)

    if coerce_dtype:
        src = maybe_coerce_dtype(src, tgt.dtype)

    tgt.copy_(src)

def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False, coerce_dtype = False):
    if auto_move_device:
        src = src.to(tgt.device)

    if coerce_dtype:
        src = maybe_coerce_dtype(src, tgt.dtype)

    tgt.lerp_(src, weight)

class EMA(Module):
    """
    Implements exponential moving average shadowing for your model.

    Utilizes an inverse decay schedule to manage longer term training runs.
    By adjusting the power, you can control how fast EMA will ramp up to your specified beta.

    @crowsonkb's notes on EMA Warmup:

    If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
    good values for models you plan to train for a million or more steps (reaches decay
    factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
    you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
    215.4k steps).

    Args:
        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
        power (float): Exponential factor of EMA warmup. Default: 2/3.
        min_value (float): The minimum EMA decay rate. Default: 0.
    """

    def __init__(
        self,
        model: Module,
        ema_model: Module | Callable[[], Module] | None = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
        beta = 0.9999,
        update_after_step = 100,
        update_every = 10,
        inv_gamma = 1.0,
        power = 2 / 3,
        min_value = 0.0,
        param_or_buffer_names_no_ema: set[str] = set(),
        ignore_names: set[str] = set(),
        ignore_startswith_names: set[str] = set(),
        include_online_model = True, # set this to False if you do not wish for the online model to be saved along with the ema model (managed externally)
        allow_different_devices = False, # if the EMA model is on a different device (say CPU), automatically move the tensor
        use_foreach = False,
        update_model_with_ema_every = None, # update the model with EMA model weights every number of steps, for better continual learning https://arxiv.org/abs/2406.02596
        update_model_with_ema_beta = 0., # amount of model weight to keep when updating to EMA (hare to tortoise)
        forward_method_names: tuple[str, ...] = (),
        move_ema_to_online_device = False,
        coerce_dtype = False,
        lazy_init_ema = False,
    ):
        super().__init__()
        self.beta = beta

        self.is_frozen = beta == 1.

        # whether to include the online model within the module tree, so that state_dict also saves it

        self.include_online_model = include_online_model

        if include_online_model:
            self.online_model = model
        else:
            self.online_model = [model] # hack

        # handle callable returning ema module

        if not isinstance(ema_model, Module) and callable(ema_model):
            ema_model = ema_model()

        # ema model

        self.ema_model = None
        self.forward_method_names = forward_method_names

        if not lazy_init_ema:
            self.init_ema(ema_model)
        else:
            assert not exists(ema_model)

        # tensor update functions

        self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices, coerce_dtype = coerce_dtype)
        self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices, coerce_dtype = coerce_dtype)

        # updating hyperparameters

        self.update_every = update_every
        self.update_after_step = update_after_step

        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value

        assert isinstance(param_or_buffer_names_no_ema, (set, list))
        self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer

        self.ignore_names = ignore_names
        self.ignore_startswith_names = ignore_startswith_names

        # continual learning related

        self.update_model_with_ema_every = update_model_with_ema_every
        self.update_model_with_ema_beta = update_model_with_ema_beta

        # whether to manage if EMA model is kept on a different device

        self.allow_different_devices = allow_different_devices

        # whether to coerce dtype when copy or lerp from online to EMA model

        self.coerce_dtype = coerce_dtype

        # whether to move EMA model to online model device automatically

        self.move_ema_to_online_device = move_ema_to_online_device

        # whether to use foreach

        if use_foreach:
            assert hasattr(torch, '_foreach_lerp_') and hasattr(torch, '_foreach_copy_'), 'your version of torch does not have the prerequisite foreach functions'

        self.use_foreach = use_foreach

        # init and step states

        self.register_buffer('initted', torch.tensor(False))
        self.register_buffer('step', torch.tensor(0))

    def init_ema(
        self,
        ema_model: Module | None = None
    ):
        self.ema_model = ema_model

        if not exists(self.ema_model):
            try:
                self.ema_model = deepcopy(self.model)
            except Exception as e:
                print(f'Error: While trying to deepcopy model: {e}')
                print('Your model was not copyable. Please make sure you are not using any LazyLinear')
                exit()

        for p in self.ema_model.parameters():
            p.detach_()

        # forwarding methods

        for forward_method_name in self.forward_method_names:
            fn = getattr(self.ema_model, forward_method_name)
            setattr(self, forward_method_name, fn)

        # parameter and buffer names

        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}

    def add_to_optimizer_post_step_hook(self, optimizer):
        assert hasattr(optimizer, 'register_step_post_hook')

        def hook(*_):
            self.update()

        return optimizer.register_step_post_hook(hook)

    @property
    def model(self):
        return self.online_model if self.include_online_model else self.online_model[0]

    def eval(self):
        return self.ema_model.eval()

    @torch.no_grad()
    def forward_eval(self, *args, **kwargs):
        # handy function for invoking ema model with no grad + eval
        training = self.ema_model.training
        out = self.ema_model(*args, **kwargs)
        self.ema_model.train(training)
        return out

    def restore_ema_model_device(self):
        device = self.initted.device
        self.ema_model.to(device)

    def get_params_iter(self, model):
        for name, param in model.named_parameters():
            if name not in self.parameter_names:
                continue
            yield name, param

    def get_buffers_iter(self, model):
        for name, buffer in model.named_buffers():
            if name not in self.buffer_names:
                continue
            yield name, buffer

    def copy_params_from_model_to_ema(self):
        copy = self.inplace_copy

        for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
            copy(ma_params.data, current_params.data)

        for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
            copy(ma_buffers.data, current_buffers.data)

    def copy_params_from_ema_to_model(self):
        copy = self.inplace_copy

        for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
            copy(current_params.data, ma_params.data)

        for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
            copy(current_buffers.data, ma_buffers.data)

    def update_model_with_ema(self, decay = None):
        if not exists(decay):
            decay = self.update_model_with_ema_beta

        if decay == 0.:
            return self.copy_params_from_ema_to_model()

        self.update_moving_average(self.model, self.ema_model, decay)

    def get_current_decay(self):
        epoch = (self.step - self.update_after_step - 1).clamp(min = 0.)
        value = 1 - (1 + epoch / self.inv_gamma) ** - self.power

        if epoch.item() <= 0:
            return 0.

        return value.clamp(min = self.min_value, max = self.beta).item()

    def update(self):
        step = self.step.item()
        self.step += 1

        if not self.initted.item():
            if not exists(self.ema_model):
                self.init_ema()

            self.copy_params_from_model_to_ema()
            self.initted.data.copy_(torch.tensor(True))
            return

        should_update = divisible_by(step, self.update_every)

        if should_update and step <= self.update_after_step:
            self.copy_params_from_model_to_ema()
            return

        if should_update:
            self.update_moving_average(self.ema_model, self.model)

        if exists(self.update_model_with_ema_every) and divisible_by(step, self.update_model_with_ema_every):
            self.update_model_with_ema()

    @torch.no_grad()
    def update_moving_average(self, ma_model, current_model, current_decay = None):
        if self.is_frozen:
            return

        # move ema model to online model device if not same and needed

        if self.move_ema_to_online_device and get_module_device(ma_model) != get_module_device(current_model):
            ma_model.to(get_module_device(current_model))

        # get current decay

        if not exists(current_decay):
            current_decay = self.get_current_decay()

        # store all source and target tensors to copy or lerp

        tensors_to_copy = []
        tensors_to_lerp = []

        # loop through parameters

        for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model), self.get_params_iter(ma_model)):
            if name in self.ignore_names:
                continue

            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
                continue

            if name in self.param_or_buffer_names_no_ema:
                tensors_to_copy.append((ma_params.data, current_params.data))
                continue

            tensors_to_lerp.append((ma_params.data, current_params.data))

        # loop through buffers

        for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model), self.get_buffers_iter(ma_model)):
            if name in self.ignore_names:
                continue

            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
                continue

            if name in self.param_or_buffer_names_no_ema:
                tensors_to_copy.append((ma_buffer.data, current_buffer.data))
                continue

            tensors_to_lerp.append((ma_buffer.data, current_buffer.data))

        # execute inplace copy or lerp

        if not self.use_foreach:

            for tgt, src in tensors_to_copy:
                self.inplace_copy(tgt, src)

            for tgt, src in tensors_to_lerp:
                self.inplace_lerp(tgt, src, 1. - current_decay)

        else:
            # use foreach if available and specified

            if self.allow_different_devices:
                tensors_to_copy = [(tgt, src.to(tgt.device)) for tgt, src in tensors_to_copy]
                tensors_to_lerp = [(tgt, src.to(tgt.device)) for tgt, src in tensors_to_lerp]

            if self.coerce_dtype:
                tensors_to_copy = [(tgt, maybe_coerce_dtype(src, tgt.dtype)) for tgt, src in tensors_to_copy]
                tensors_to_lerp = [(tgt, maybe_coerce_dtype(src, tgt.dtype)) for tgt, src in tensors_to_lerp]

            if len(tensors_to_copy) > 0:
                tgt_copy, src_copy = zip(*tensors_to_copy)
                torch._foreach_copy_(tgt_copy, src_copy)

            if len(tensors_to_lerp) > 0:
                tgt_lerp, src_lerp = zip(*tensors_to_lerp)
                torch._foreach_lerp_(tgt_lerp, src_lerp, 1. - current_decay)

    def __call__(self, *args, **kwargs):
        return self.ema_model(*args, **kwargs)
helper/loader.py
ADDED
@@ -0,0 +1,48 @@
import torch
import torch.nn as nn
from helper.ema import EMA
from transformers import get_cosine_schedule_with_warmup

class Loader():
    def __init__(self, device = None):
        self.device = device

    def print_model(self, check_point):
        print("Epoch: " + str(check_point["epoch"]))
        print("Training step: " + str(check_point["training_steps"]))
        print("Best loss: " + str(check_point["best_loss"]))
        print("Batch size: " + str(check_point["batch_size"]))
        print("Number of batches: " + str(check_point["number_of_batches"]))

    def model_load(self, file_name : str, model : nn.Module,
                   print_dict : bool = True, is_ema: bool = True):
        check_point = torch.load(file_name + ".pth", map_location=self.device,
                                 weights_only=True)
        if print_dict: self.print_model(check_point)
        if is_ema:
            model = EMA(model)
            model.load_state_dict(check_point['ema_state_dict'])
            model = model.ema_model
        else:
            model.load_state_dict(check_point['model_state_dict'])
        model.eval()
        print("===Model loaded!===")
        return model

    def load_for_training(self, file_name: str, model: nn.Module, print_dict: bool = True):
        check_point = torch.load(file_name + ".pth", map_location=self.device,
                                 weights_only=True)
        if print_dict: self.print_model(check_point)
        model.load_state_dict(check_point['model_state_dict'])
        model.train()
        ema = EMA(model)
        ema.load_state_dict(check_point['ema_state_dict'])
        ema.train()
        optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-4)
        optimizer.load_state_dict(check_point["optimizer_state_dict"])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
        scheduler.load_state_dict(check_point["scheduler_state_dict"])
        epoch = check_point["epoch"]
        loss = check_point["best_loss"]
        print("===Model/EMA/Optimizer/Scheduler/Epoch/Loss loaded!===")
        return model, ema, optimizer, scheduler, epoch, loss
helper/painter.py
ADDED
@@ -0,0 +1,51 @@
import io
from PIL import Image as im
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import numpy as np

class Painter(object):
    def __init__(self) :
        pass

    def show_images(self, images, title : str = '', index : bool = False, cmap = None, show = True):
        images = images.permute(0, 2, 3, 1)
        if type(images) is torch.Tensor:
            images = images.detach().cpu().numpy()
        images = np.clip(images / 2 + 0.5, 0, 1)

        fig = plt.figure(figsize=(8, 8))
        rows = int(len(images) ** (1 / 2))
        cols = round(len(images) / rows)

        idx = 0
        for _ in range(rows):
            for _ in range(cols):
                fig.add_subplot(rows, cols, idx + 1)

                if idx < len(images):
                    plt.imshow(images[idx], cmap = cmap)
                    if index :
                        plt.title(idx + 1)
                    plt.axis('off')
                    idx += 1
        fig.suptitle(title, fontsize=30)
        if show:
            plt.show()

    def show_first_batch(self, loader):
        for batch in loader:
            self.show_images(images = batch, title = "First Batch")
            break

    def make_gif(self, images, file_name):
        imgs = []
        for i in tqdm(range(len(images))):
            img_buf = io.BytesIO()
            self.show_images(images[i], title = 't = ' + str(i), show=False)
            plt.savefig(img_buf, format='png')
            imgs.append(im.open(img_buf))
        imgs[0].save(file_name + '.gif', format='GIF', append_images=imgs, save_all=True, duration=1, loop=0)
        plt.close('all')
helper/tokenizer.py
ADDED
@@ -0,0 +1,9 @@
from transformers import AutoTokenizer

class Tokenizer:
    def __init__(self, model_name="Bingsu/clip-vit-base-patch32-ko"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.vocab_size = self.tokenizer.vocab_size

    def tokenize(self, text):
        return self.tokenizer(text, padding='max_length', max_length=77, truncation=True, return_tensors='pt')
helper/trainer.py
ADDED
@@ -0,0 +1,99 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from accelerate import Accelerator
from tqdm import tqdm
from typing import Callable
from helper.ema import EMA

class Trainer():
    def __init__(self,
                 model: nn.Module,
                 loss_fn: Callable,
                 ema: EMA = None,
                 optimizer: torch.optim.Optimizer = None,
                 scheduler: torch.optim.lr_scheduler = None,
                 start_epoch = 0,
                 best_loss = float("inf"),
                 accumulation_steps: int = 1,
                 max_grad_norm: float = 1.0):
        self.accelerator = Accelerator(mixed_precision = 'fp16', gradient_accumulation_steps=accumulation_steps)
        self.model = model.to(self.accelerator.device)
        if ema is None:
            self.ema = EMA(self.model).to(self.accelerator.device)
        else:
            self.ema = ema.to(self.accelerator.device)
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        if self.optimizer is None:
            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr = 1e-4)
        self.scheduler = scheduler
        if self.scheduler is None:
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=100)
        self.start_epoch = start_epoch
        self.best_loss = best_loss
        self.accumulation_steps = accumulation_steps
        self.max_grad_norm = max_grad_norm

    def train(self, dl : DataLoader, epochs: int, file_name : str, no_label : bool = False):
        self.model.train()
        self.model, self.optimizer, data_loader, self.scheduler = self.accelerator.prepare(
            self.model, self.optimizer, dl, self.scheduler
        )

        for epoch in range(self.start_epoch + 1, epochs + 1):
            epoch_loss = 0.0
            progress_bar = tqdm(data_loader, leave=False, desc=f"Epoch {epoch}/{epochs}", colour="#005500", disable = not self.accelerator.is_local_main_process)
            for step, batch in enumerate(progress_bar):
                with self.accelerator.accumulate(self.model): # Context manager for accumulation
                    if no_label:
                        if isinstance(batch, list):
                            x = batch[0].to(self.accelerator.device)
                        else:
                            x = batch.to(self.accelerator.device)
                    else:
                        x, y = batch[0].to(self.accelerator.device), batch[1].to(self.accelerator.device)

                    with self.accelerator.autocast():
                        if no_label:
                            loss = self.loss_fn(x)
                        else:
                            loss = self.loss_fn(x, y=y)

                    # Normalize the loss
                    self.accelerator.backward(loss)

                    # Gradient Clipping:
                    if self.max_grad_norm is not None and self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

                    # Only step optimizer and scheduler when we have accumulated enough
                    self.optimizer.step()
                    self.ema.update()
                    self.optimizer.zero_grad()

                epoch_loss += loss.item()
                progress_bar.set_postfix(loss=epoch_loss / (min(step + 1, len(data_loader)))) # Correct progress bar update

            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process:
                epoch_loss = epoch_loss / len(progress_bar)
                self.scheduler.step()
                log_string = f"Loss at epoch {epoch}: {epoch_loss :.4f}"

                # Save the best model
                if self.best_loss > epoch_loss:
                    self.best_loss = epoch_loss
                    torch.save({
                        "model_state_dict": self.accelerator.get_state_dict(self.model),
                        "ema_state_dict": self.ema.state_dict(),
                        "optimizer_state_dict": self.optimizer.state_dict(),
                        "scheduler_state_dict": self.scheduler.state_dict(),
                        "epoch": epoch,
                        "training_steps": epoch * len(dl),
                        "best_loss": self.best_loss,
                        "batch_size": dl.batch_size,
                        "number_of_batches": len(dl)
                    }, file_name + '.pth')
                    log_string += " --> Best model ever (stored)"
                print(log_string)
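
Usage sketch (reviewer note, not part of the uploaded files): a toy run of the Trainer with a hypothetical stand-in model and loss; in the actual Space the model would be the diffusion model and loss_fn its noise-prediction loss. Note the Trainer hard-codes fp16 mixed precision, so this sketch assumes a GPU is available.

# Hedged, toy example of the Trainer loop; model, loss_fn and file name are hypothetical.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from helper.trainer import Trainer

model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 8))
loss_fn = lambda x: nn.functional.mse_loss(model(x), x)   # unconditional loss, hence no_label=True

data = TensorDataset(torch.randn(256, 8))
dl = DataLoader(data, batch_size=32, shuffle=True)

trainer = Trainer(model, loss_fn)
trainer.train(dl, epochs=2, file_name='toy_checkpoint', no_label=True)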
helper/util.py
ADDED
@@ -0,0 +1,4 @@
def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
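
Note (not part of the uploaded files): extract() gathers one schedule value per sample and reshapes it so it broadcasts over the remaining image dimensions, which is how the samplers index alpha_bar per timestep. A quick illustration:

# Illustration of extract() broadcasting behaviour.
import torch
from helper.util import extract

alpha_bar = torch.linspace(1.0, 0.01, steps=1000)   # stand-in schedule of length T
t = torch.tensor([0, 499, 999])                     # one timestep per sample
out = extract(alpha_bar, t, x_shape=(3, 3, 32, 32))
print(out.shape)                                    # torch.Size([3, 1, 1, 1])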