import json
import random
from functools import partial

import numpy as np
import torch
import torch.nn.functional as F
from six.moves import xrange
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from metrics.IS import get_inception_score
from model.diffusion_components import (
    default, ConvNextBlock, ResnetBlock, SinusoidalPositionEmbeddings, Residual, PreNorm,
    Downsample, Upsample, exists, q_sample, get_beta_schedule, pad_and_concat, ConditionalEmbedding,
    LinearCrossAttention, LinearCrossAttentionAdd,
)
from tools import create_key


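# ConditionedUnet: a U-Net noise predictor for DDPM-style diffusion, conditioned on a
# label/text embedding (injected through cross-attention) and on the diffusion timestep
# (injected through a sinusoidal time embedding). The encoder ("downs"), bottleneck
# ("mid_*") and decoder ("ups") exchange activations through skip connections that are
# collected in forward().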
class ConditionedUnet(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim=None,
        down_dims=None,
        up_dims=None,
        mid_depth=3,
        with_time_emb=True,
        time_dim=None,
        resnet_block_groups=8,
        use_convnext=True,
        convnext_mult=2,
        attn_type="linear_cat",
        n_label_class=11,
        condition_type="instrument_family",
        label_emb_dim=128,
    ):
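        """Conditional U-Net noise predictor.

        Argument sketch (inferred from how the values are used below, so treat these
        descriptions as assumptions rather than authoritative documentation):
        in_dim/out_dim are the input/output channel counts (out_dim defaults to in_dim);
        down_dims/up_dims are the encoder/decoder channel widths and must mirror each
        other; mid_depth sets the number of bottleneck block pairs; attn_type selects
        the cross-attention variant ("linear_cat" or "linear_add"); n_label_class + 1
        sizes the label-embedding table, the extra index presumably reserving a slot
        for the unconditional (null) label.
        """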
        super().__init__()
        self.label_embedding = ConditionalEmbedding(int(n_label_class + 1), int(label_emb_dim), condition_type)
        if up_dims is None:
            up_dims = [128, 128, 64, 32]
        if down_dims is None:
            down_dims = [32, 32, 64, 128]
        out_dim = default(out_dim, in_dim)
        assert len(down_dims) == len(up_dims), "len(down_dims) != len(up_dims)"
        assert down_dims[0] == up_dims[-1], "down_dims[0] != up_dims[-1]"
        assert up_dims[0] == down_dims[-1], "up_dims[0] != down_dims[-1]"
        down_in_out = list(zip(down_dims[:-1], down_dims[1:]))
        up_in_out = list(zip(up_dims[:-1], up_dims[1:]))
        print(f"down_in_out: {down_in_out}")
        print(f"up_in_out: {up_in_out}")
        time_dim = default(time_dim, int(down_dims[0] * 4))
        self.init_conv = nn.Conv2d(in_dim, down_dims[0], 7, padding=3)

        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)
        if attn_type == "linear_cat":
            attn_klass = LinearCrossAttention
        elif attn_type == "linear_add":
            attn_klass = LinearCrossAttentionAdd
        else:
            raise NotImplementedError()
        # time embeddings
        if with_time_emb:
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(down_dims[0]),
                nn.Linear(down_dims[0], time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # left layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        skip_dims = []
        for down_dim_in, down_dim_out in down_in_out:
            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(down_dim_in, down_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(down_dim_out, attn_klass(down_dim_out, label_emb_dim=label_emb_dim))),
                        block_klass(down_dim_out, down_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(down_dim_out, attn_klass(down_dim_out, label_emb_dim=label_emb_dim))),
                        Downsample(down_dim_out),
                    ]
                )
            )
            skip_dims.append(down_dim_out)

        # bottleneck
        mid_dim = down_dims[-1]
        self.mid_left = nn.ModuleList([])
        self.mid_right = nn.ModuleList([])
        for _ in range(mid_depth - 1):
            self.mid_left.append(block_klass(mid_dim, mid_dim, time_emb_dim=time_dim))
            self.mid_right.append(block_klass(mid_dim * 2, mid_dim, time_emb_dim=time_dim))
        self.mid_mid = nn.ModuleList(
            [
                block_klass(mid_dim, mid_dim, time_emb_dim=time_dim),
                Residual(PreNorm(mid_dim, attn_klass(mid_dim, label_emb_dim=label_emb_dim))),
                block_klass(mid_dim, mid_dim, time_emb_dim=time_dim),
            ]
        )
        # right layers
        for ind, (up_dim_in, up_dim_out) in enumerate(up_in_out):
            skip_dim = skip_dims.pop()  # down_dim_out
            self.ups.append(
                nn.ModuleList(
                    [
                        # pop & concat (h/2, w/2, down_dim_out)
                        block_klass(up_dim_in + skip_dim, up_dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(up_dim_in, attn_klass(up_dim_in, label_emb_dim=label_emb_dim))),
                        Upsample(up_dim_in),
                        # pop & concat (h, w, down_dim_out)
                        block_klass(up_dim_in + skip_dim, up_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(up_dim_out, attn_klass(up_dim_out, label_emb_dim=label_emb_dim))),
                        # pop & concat (h, w, down_dim_out)
                        block_klass(up_dim_out + skip_dim, up_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(up_dim_out, attn_klass(up_dim_out, label_emb_dim=label_emb_dim))),
                    ]
                )
            )

        self.final_conv = nn.Sequential(
            block_klass(down_dims[0] + up_dims[-1], up_dims[-1]), nn.Conv2d(up_dims[-1], out_dim, 3, padding=1)
        )

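    # Parameter-count report; handy when comparing U-Net configurations.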
    def size(self):
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total parameters: {total_params}")
        print(f"Trainable parameters: {trainable_params}")

    def forward(self, x, time, condition=None):
        if condition is not None:
            condition_emb = self.label_embedding(condition)
        else:
            condition_emb = None

        h = []
        x = self.init_conv(x)
        h.append(x)
        time_emb = self.time_mlp(time) if exists(self.time_mlp) else None

        # downsample
        for block1, attn1, block2, attn2, downsample in self.downs:
            x = block1(x, time_emb)
            x = attn1(x, condition_emb)
            h.append(x)
            x = block2(x, time_emb)
            x = attn2(x, condition_emb)
            h.append(x)
            x = downsample(x)
            h.append(x)

        # bottleneck
        for block in self.mid_left:
            x = block(x, time_emb)
            h.append(x)
        (block1, attn, block2) = self.mid_mid
        x = block1(x, time_emb)
        x = attn(x, condition_emb)
        x = block2(x, time_emb)
        for block in self.mid_right:
            # concatenate the matching encoder activation (U-Net skip connection)
            x = pad_and_concat(h.pop(), x)
            x = block(x, time_emb)

        # upsample
        for block1, attn1, upsample, block2, attn2, block3, attn3 in self.ups:
            x = pad_and_concat(h.pop(), x)
            x = block1(x, time_emb)
            x = attn1(x, condition_emb)
            x = upsample(x)
            x = pad_and_concat(h.pop(), x)
            x = block2(x, time_emb)
            x = attn2(x, condition_emb)
            x = pad_and_concat(h.pop(), x)
            x = block3(x, time_emb)
            x = attn3(x, condition_emb)

        x = pad_and_concat(h.pop(), x)
        x = self.final_conv(x)
        return x


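# Standard DDPM training objective: corrupt x_start to x_noisy at timestep t through the
# forward process (q_sample is assumed to compute
#   x_noisy = sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise),
# then regress the injected noise with the denoising U-Net under the chosen loss.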
def conditional_p_losses(denoise_model, x_start, t, condition, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod,
                         noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)
    x_noisy = q_sample(x_start=x_start, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod,
                       sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
    predicted_noise = denoise_model(x_noisy, t, condition)
    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()
    return loss


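# Rough evaluation pass: averages the diffusion training loss over a fixed number of
# randomly drawn batches, using the same condition-dropout scheme as training.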
def evaluate_diffusion_model(device, model, iterator, BATCH_SIZE, timesteps, unetConfig, encodes2embeddings_mapping,
                             uncondition_rate, unconditional_condition):
    model.to(device)
    model.eval()
    eva_loss = []
    sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, _, _ = get_beta_schedule(timesteps)
    for i in xrange(500):
        data, attributes = next(iter(iterator))
        data = data.to(device)
        conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
        # with probability uncondition_rate, swap in the unconditional embedding
        selected_conditions = [
            unconditional_condition if random.random() < uncondition_rate else random.choice(conditions_of_one_sample)
            for conditions_of_one_sample in conditions]
        selected_conditions = torch.stack(selected_conditions).float().to(device)
        t = torch.randint(0, timesteps, (BATCH_SIZE,), device=device).long()
        loss = conditional_p_losses(model, data, t, selected_conditions, loss_type="huber",
                                    sqrt_alphas_cumprod=sqrt_alphas_cumprod,
                                    sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod)
        eva_loss.append(loss.item())
    initial_loss = np.mean(eva_loss)
    return initial_loss


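# Builds a ConditionedUnet from a config dict and optionally restores weights from
# models/<model_name>_UNet.pth (the checkpoint is expected to hold 'model_state_dict').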
def get_diffusion_model(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    UNet = ConditionedUnet(**model_Config)
    print(f"Model initialized, size: {sum(p.numel() for p in UNet.parameters() if p.requires_grad)}")
    UNet.to(device)
    if load_pretrain:
        print(f"Loading weights from models/{model_name}_UNet.pth")
        checkpoint = torch.load(f'models/{model_name}_UNet.pth', map_location=device)
        UNet.load_state_dict(checkpoint['model_state_dict'])
    UNet.eval()
    return UNet


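# Main training loop: classifier-free-guidance-style condition dropout, Huber
# noise-prediction loss, checkpointing every save_steps optimizer steps, and an Inception
# Score evaluation (via metrics.IS.get_inception_score) every 20000 steps, logged to
# TensorBoard.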
def train_diffusion_model(VAE, text_encoder, CLAP_tokenizer, timbre_encoder, device, init_model_name, unetConfig,
                          BATCH_SIZE, timesteps, lr, max_iter, iterator, load_pretrain,
                          encodes2embeddings_mapping, uncondition_rate, unconditional_condition, save_steps=5000,
                          init_loss=None, save_model_name=None, n_IS_batches=50):
    if save_model_name is None:
        save_model_name = init_model_name

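    # Persist the hyperparameters next to each checkpoint so a run can be identified later.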
    def save_model_hyperparameter(model_name, unetConfig, BATCH_SIZE, lr, model_size, current_iter, current_loss):
        model_hyperparameter = dict(unetConfig)  # copy so the caller's config is not mutated
        model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
        model_hyperparameter["lr"] = lr
        model_hyperparameter["model_size"] = model_size
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["current_loss"] = current_loss
        with open(f"models/hyperparameters/{model_name}_UNet.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    model = ConditionedUnet(**unetConfig)
    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {model_size}")
    model.to(device)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, amsgrad=False)
    if load_pretrain:
        print(f"Loading weights from models/{init_model_name}_UNet.pth")
        checkpoint = torch.load(f'models/{init_model_name}_UNet.pth')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Model initialized.")
    if max_iter == 0:
        print("Returning model directly.")
        return model, optimizer

    train_loss = []
    writer = SummaryWriter(f'runs/{save_model_name}_UNet')
    if init_loss is None:
        previous_loss = evaluate_diffusion_model(device, model, iterator, BATCH_SIZE, timesteps, unetConfig,
                                                 encodes2embeddings_mapping, uncondition_rate, unconditional_condition)
    else:
        previous_loss = init_loss
    print(f"initial_loss: {previous_loss}")
    sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, _, _ = get_beta_schedule(timesteps)
    model.train()
    for i in xrange(max_iter):
        data, attributes = next(iter(iterator))
        data = data.to(device)
        conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
        unconditional_condition_copy = torch.tensor(unconditional_condition, dtype=torch.float32).to(device).detach()
        # condition dropout: with probability uncondition_rate, train on the unconditional embedding
        selected_conditions = [unconditional_condition_copy if random.random() < uncondition_rate else random.choice(
            conditions_of_one_sample) for conditions_of_one_sample in conditions]
        selected_conditions = torch.stack(selected_conditions).float().to(device)

        optimizer.zero_grad()
        t = torch.randint(0, timesteps, (BATCH_SIZE,), device=device).long()
        loss = conditional_p_losses(model, data, t, selected_conditions, loss_type="huber",
                                    sqrt_alphas_cumprod=sqrt_alphas_cumprod,
                                    sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod)
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())

        # recover the global optimizer step count (survives checkpoint reloads)
        step = int(optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].numpy())
        if step % 100 == 0:
            print('%d step' % step)
        if step % save_steps == 0:
            current_loss = np.mean(train_loss[-save_steps:])
            print(f"current_loss = {current_loss}")
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'models/{save_model_name}_UNet.pth')
            save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, lr, model_size, step, current_loss)
        if step % 20000 == 0:
            current_IS = get_inception_score(device, model, VAE, text_encoder, CLAP_tokenizer, timbre_encoder,
                                             n_IS_batches, positive_prompts="", negative_prompts="", CFG=1,
                                             sample_steps=20, task="STFT")
            print('current_IS: %.5f' % current_IS)
            current_loss = np.mean(train_loss[-save_steps:])
            writer.add_scalar("current_IS", current_IS, step)
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'models/history/{save_model_name}_{step}_UNet.pth')
            save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, lr, model_size, step, current_loss)
    return model, optimizer