# DiffuSynthV0.2/model/diffusion.py
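"""Conditional diffusion module for DiffuSynth: a label- and timestep-conditioned
U-Net denoiser plus DDPM-style training and evaluation utilities."""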
import json
from functools import partial
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import random
from metrics.IS import get_inception_score
from tools import create_key
from model.diffusion_components import default, ConvNextBlock, ResnetBlock, SinusoidalPositionEmbeddings, Residual, \
PreNorm, \
Downsample, Upsample, exists, q_sample, get_beta_schedule, pad_and_concat, ConditionalEmbedding, \
LinearCrossAttention, LinearCrossAttentionAdd
class ConditionedUnet(nn.Module):
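    """Label- and timestep-conditioned U-Net denoiser.

    The class label (e.g. instrument family) is embedded via ``ConditionalEmbedding``
    and injected through cross-attention blocks at every resolution; the diffusion
    timestep is encoded with sinusoidal position embeddings and passed to every
    ConvNext/ResNet block. ``down_dims`` and ``up_dims`` must have equal length,
    with ``down_dims[0] == up_dims[-1]`` and ``up_dims[0] == down_dims[-1]``.
    """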
def __init__(
self,
in_dim,
out_dim=None,
down_dims=None,
up_dims=None,
mid_depth=3,
with_time_emb=True,
time_dim=None,
resnet_block_groups=8,
use_convnext=True,
convnext_mult=2,
attn_type="linear_cat",
n_label_class=11,
condition_type="instrument_family",
label_emb_dim=128,
):
super().__init__()
self.label_embedding = ConditionalEmbedding(int(n_label_class + 1), int(label_emb_dim), condition_type)
if up_dims is None:
up_dims = [128, 128, 64, 32]
if down_dims is None:
down_dims = [32, 32, 64, 128]
out_dim = default(out_dim, in_dim)
assert len(down_dims) == len(up_dims), "len(down_dims) != len(up_dims)"
assert down_dims[0] == up_dims[-1], "down_dims[0] != up_dims[-1]"
assert up_dims[0] == down_dims[-1], "up_dims[0] != down_dims[-1]"
down_in_out = list(zip(down_dims[:-1], down_dims[1:]))
up_in_out = list(zip(up_dims[:-1], up_dims[1:]))
print(f"down_in_out: {down_in_out}")
print(f"up_in_out: {up_in_out}")
time_dim = default(time_dim, int(down_dims[0] * 4))
self.init_conv = nn.Conv2d(in_dim, down_dims[0], 7, padding=3)
if use_convnext:
block_klass = partial(ConvNextBlock, mult=convnext_mult)
else:
block_klass = partial(ResnetBlock, groups=resnet_block_groups)
        if attn_type == "linear_cat":
            attn_klass = LinearCrossAttention
        elif attn_type == "linear_add":
            attn_klass = LinearCrossAttentionAdd
        else:
            raise NotImplementedError(f"Unknown attn_type: {attn_type}")
# time embeddings
if with_time_emb:
self.time_mlp = nn.Sequential(
SinusoidalPositionEmbeddings(down_dims[0]),
nn.Linear(down_dims[0], time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim),
)
else:
time_dim = None
self.time_mlp = None
# left layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
skip_dims = []
for down_dim_in, down_dim_out in down_in_out:
self.downs.append(
nn.ModuleList(
[
block_klass(down_dim_in, down_dim_out, time_emb_dim=time_dim),
Residual(PreNorm(down_dim_out, attn_klass(down_dim_out, label_emb_dim=label_emb_dim, ))),
block_klass(down_dim_out, down_dim_out, time_emb_dim=time_dim),
Residual(PreNorm(down_dim_out, attn_klass(down_dim_out, label_emb_dim=label_emb_dim, ))),
Downsample(down_dim_out),
]
)
)
skip_dims.append(down_dim_out)
# bottleneck
mid_dim = down_dims[-1]
self.mid_left = nn.ModuleList([])
self.mid_right = nn.ModuleList([])
for _ in range(mid_depth - 1):
self.mid_left.append(block_klass(mid_dim, mid_dim, time_emb_dim=time_dim))
self.mid_right.append(block_klass(mid_dim * 2, mid_dim, time_emb_dim=time_dim))
self.mid_mid = nn.ModuleList(
[
block_klass(mid_dim, mid_dim, time_emb_dim=time_dim),
Residual(PreNorm(mid_dim, attn_klass(mid_dim, label_emb_dim=label_emb_dim, ))),
block_klass(mid_dim, mid_dim, time_emb_dim=time_dim),
]
)
# right layers
for ind, (up_dim_in, up_dim_out) in enumerate(up_in_out):
skip_dim = skip_dims.pop() # down_dim_out
self.ups.append(
nn.ModuleList(
[
# pop&cat (h/2, w/2, down_dim_out)
block_klass(up_dim_in + skip_dim, up_dim_in, time_emb_dim=time_dim),
Residual(PreNorm(up_dim_in, attn_klass(up_dim_in, label_emb_dim=label_emb_dim, ))),
Upsample(up_dim_in),
# pop&cat (h, w, down_dim_out)
block_klass(up_dim_in + skip_dim, up_dim_out, time_emb_dim=time_dim),
Residual(PreNorm(up_dim_out, attn_klass(up_dim_out, label_emb_dim=label_emb_dim, ))),
# pop&cat (h, w, down_dim_out)
block_klass(up_dim_out + skip_dim, up_dim_out, time_emb_dim=time_dim),
Residual(PreNorm(up_dim_out, attn_klass(up_dim_out, label_emb_dim=label_emb_dim, ))),
]
)
)
self.final_conv = nn.Sequential(
block_klass(down_dims[0] + up_dims[-1], up_dims[-1]), nn.Conv2d(up_dims[-1], out_dim, 3, padding=1)
)
def size(self):
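        """Print the total and trainable parameter counts."""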
total_params = sum(p.numel() for p in self.parameters())
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")
def forward(self, x, time, condition=None):
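        """Predict the noise residual for a batch of inputs.

        Args:
            x: input tensor of shape (B, in_dim, H, W).
            time: diffusion timesteps of shape (B,).
            condition: optional label/conditioning input; if None, the
                cross-attention blocks receive no condition embedding.
        """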
if condition is not None:
condition_emb = self.label_embedding(condition)
else:
condition_emb = None
h = []
x = self.init_conv(x)
h.append(x)
time_emb = self.time_mlp(time) if exists(self.time_mlp) else None
# downsample
for block1, attn1, block2, attn2, downsample in self.downs:
x = block1(x, time_emb)
x = attn1(x, condition_emb)
h.append(x)
x = block2(x, time_emb)
x = attn2(x, condition_emb)
h.append(x)
x = downsample(x)
h.append(x)
# bottleneck
for block in self.mid_left:
x = block(x, time_emb)
h.append(x)
(block1, attn, block2) = self.mid_mid
x = block1(x, time_emb)
x = attn(x, condition_emb)
x = block2(x, time_emb)
for block in self.mid_right:
            # U-Net skip connection: concatenate the stored encoder feature map before the block.
x = pad_and_concat(h.pop(), x)
x = block(x, time_emb)
# upsample
for block1, attn1, upsample, block2, attn2, block3, attn3 in self.ups:
x = pad_and_concat(h.pop(), x)
x = block1(x, time_emb)
x = attn1(x, condition_emb)
x = upsample(x)
x = pad_and_concat(h.pop(), x)
x = block2(x, time_emb)
x = attn2(x, condition_emb)
x = pad_and_concat(h.pop(), x)
x = block3(x, time_emb)
x = attn3(x, condition_emb)
x = pad_and_concat(h.pop(), x)
x = self.final_conv(x)
return x
def conditional_p_losses(denoise_model, x_start, t, condition, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod,
noise=None, loss_type="l1"):
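    """Conditional DDPM noise-prediction loss.

    Noises ``x_start`` at timestep ``t`` via ``q_sample`` and trains ``denoise_model``
    to recover the injected noise under an l1, l2, or Huber objective.
    """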
if noise is None:
noise = torch.randn_like(x_start)
x_noisy = q_sample(x_start=x_start, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod,
sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
predicted_noise = denoise_model(x_noisy, t, condition)
if loss_type == 'l1':
loss = F.l1_loss(noise, predicted_noise)
elif loss_type == 'l2':
loss = F.mse_loss(noise, predicted_noise)
elif loss_type == "huber":
loss = F.smooth_l1_loss(noise, predicted_noise)
else:
        raise NotImplementedError(f"Unknown loss_type: {loss_type}")
return loss
def evaluate_diffusion_model(device, model, iterator, BATCH_SIZE, timesteps, unetConfig, encodes2embeddings_mapping,
uncondition_rate, unconditional_condition):
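    """Estimate the diffusion (Huber) loss over 500 random batches and return its mean."""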
model.to(device)
model.eval()
eva_loss = []
sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, _, _ = get_beta_schedule(timesteps)
    for _ in range(500):
data, attributes = next(iter(iterator))
data = data.to(device)
conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
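        # Condition dropout, mirroring training: with probability `uncondition_rate`,
        # use the unconditional condition instead of a sample-specific one.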
selected_conditions = [
unconditional_condition if random.random() < uncondition_rate else random.choice(conditions_of_one_sample)
for conditions_of_one_sample in conditions]
selected_conditions = torch.stack(selected_conditions).float().to(device)
t = torch.randint(0, timesteps, (BATCH_SIZE,), device=device).long()
loss = conditional_p_losses(model, data, t, selected_conditions, loss_type="huber",
sqrt_alphas_cumprod=sqrt_alphas_cumprod,
sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod)
eva_loss.append(loss.item())
initial_loss = np.mean(eva_loss)
return initial_loss
def get_diffusion_model(model_Config, load_pretrain=False, model_name=None, device="cpu"):
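    """Build a ConditionedUnet from ``model_Config``; optionally restore weights from
    ``models/{model_name}_UNet.pth`` and switch the model to eval mode."""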
UNet = ConditionedUnet(**model_Config)
print(f"Model intialized, size: {sum(p.numel() for p in UNet.parameters() if p.requires_grad)}")
UNet.to(device)
if load_pretrain:
print(f"Loading weights from models/{model_name}_UNet.pth")
checkpoint = torch.load(f'models/{model_name}_UNet.pth', map_location=device)
UNet.load_state_dict(checkpoint['model_state_dict'])
UNet.eval()
return UNet
def train_diffusion_model(VAE, text_encoder, CLAP_tokenizer, timbre_encoder, device, init_model_name, unetConfig, BATCH_SIZE, timesteps, lr, max_iter, iterator, load_pretrain,
encodes2embeddings_mapping, uncondition_rate, unconditional_condition, save_steps=5000, init_loss=None, save_model_name=None,
n_IS_batches=50):
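    """Train the conditional diffusion U-Net.

    Runs up to ``max_iter`` optimisation steps with condition dropout at rate
    ``uncondition_rate``, checkpoints to ``models/{save_model_name}_UNet.pth`` every
    ``save_steps`` steps, and logs an Inception Score to TensorBoard every 20000 steps.
    Returns the model and its optimizer.
    """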
if save_model_name is None:
save_model_name = init_model_name
def save_model_hyperparameter(model_name, unetConfig, BATCH_SIZE, lr, model_size, current_iter, current_loss):
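        """Write the U-Net config and current training state to models/hyperparameters/."""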
model_hyperparameter = unetConfig
model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
model_hyperparameter["lr"] = lr
model_hyperparameter["model_size"] = model_size
model_hyperparameter["current_iter"] = current_iter
model_hyperparameter["current_loss"] = current_loss
with open(f"models/hyperparameters/{model_name}_UNet.json", "w") as json_file:
json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)
model = ConditionedUnet(**unetConfig)
model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {model_size}")
model.to(device)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, amsgrad=False)
if load_pretrain:
print(f"Loading weights from models/{init_model_name}_UNet.pt")
checkpoint = torch.load(f'models/{init_model_name}_UNet.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
else:
print("Model initialized.")
if max_iter == 0:
print("Return model directly.")
return model, optimizer
train_loss = []
writer = SummaryWriter(f'runs/{save_model_name}_UNet')
if init_loss is None:
previous_loss = evaluate_diffusion_model(device, model, iterator, BATCH_SIZE, timesteps, unetConfig, encodes2embeddings_mapping,
uncondition_rate, unconditional_condition)
else:
previous_loss = init_loss
print(f"initial_IS: {previous_loss}")
sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, _, _ = get_beta_schedule(timesteps)
model.train()
    for _ in range(max_iter):
data, attributes = next(iter(iterator))
data = data.to(device)
conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
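        # Condition dropout: with probability `uncondition_rate`, replace the per-sample
        # condition with the unconditional embedding (classifier-free-guidance-style training).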
unconditional_condition_copy = torch.tensor(unconditional_condition, dtype=torch.float32).to(device).detach()
selected_conditions = [unconditional_condition_copy if random.random() < uncondition_rate else random.choice(
conditions_of_one_sample) for conditions_of_one_sample in conditions]
selected_conditions = torch.stack(selected_conditions).float().to(device)
optimizer.zero_grad()
t = torch.randint(0, timesteps, (BATCH_SIZE,), device=device).long()
loss = conditional_p_losses(model, data, t, selected_conditions, loss_type="huber",
sqrt_alphas_cumprod=sqrt_alphas_cumprod,
sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod)
loss.backward()
optimizer.step()
train_loss.append(loss.item())
        # Recover the global step count from the Adam state of the first parameter.
        step = int(next(iter(optimizer.state_dict()['state'].values()))['step'])
if step % 100 == 0:
print('%d step' % (step))
if step % save_steps == 0:
current_loss = np.mean(train_loss[-save_steps:])
print(f"current_loss = {current_loss}")
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, f'models/{save_model_name}_UNet.pth')
save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, lr, model_size, step, current_loss)
if step % 20000 == 0:
current_IS = get_inception_score(device, model, VAE, text_encoder, CLAP_tokenizer, timbre_encoder, n_IS_batches,
positive_prompts="", negative_prompts="", CFG=1, sample_steps=20, task="STFT")
print('current_IS: %.5f' % current_IS)
current_loss = np.mean(train_loss[-save_steps:])
            writer.add_scalar("current_IS", current_IS, step)
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, f'models/history/{save_model_name}_{step}_UNet.pth')
save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, lr, model_size, step, current_loss)
return model, optimizer
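# Hypothetical usage sketch -- the config values below are illustrative placeholders,
# not the settings shipped with DiffuSynth:
#   unet_config = {"in_dim": 4, "attn_type": "linear_cat", "n_label_class": 11}
#   unet = get_diffusion_model(unet_config, load_pretrain=False, device="cpu")
#   unet.size()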