import json
import random
from functools import partial

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from six.moves import xrange
from torch.utils.tensorboard import SummaryWriter

from metrics.IS import get_inception_score
from tools import create_key
from model.diffusion_components import (
    default, ConvNextBlock, ResnetBlock, SinusoidalPositionEmbeddings, Residual, PreNorm,
    Downsample, Upsample, exists, q_sample, get_beta_schedule, pad_and_concat,
    ConditionalEmbedding, LinearCrossAttention, LinearCrossAttentionAdd,
)


class ConditionedUnet(nn.Module):
    def __init__(
            self,
            in_dim,
            out_dim=None,
            down_dims=None,
            up_dims=None,
            mid_depth=3,
            with_time_emb=True,
            time_dim=None,
            resnet_block_groups=8,
            use_convnext=True,
            convnext_mult=2,
            attn_type="linear_cat",
            n_label_class=11,
            condition_type="instrument_family",
            label_emb_dim=128,
    ):
        super().__init__()

        self.label_embedding = ConditionalEmbedding(int(n_label_class + 1), int(label_emb_dim), condition_type)

        if up_dims is None:
            up_dims = [128, 128, 64, 32]
        if down_dims is None:
            down_dims = [32, 32, 64, 128]
        out_dim = default(out_dim, in_dim)

        assert len(down_dims) == len(up_dims), "len(down_dims) != len(up_dims)"
        assert down_dims[0] == up_dims[-1], "down_dims[0] != up_dims[-1]"
        assert up_dims[0] == down_dims[-1], "up_dims[0] != down_dims[-1]"

        down_in_out = list(zip(down_dims[:-1], down_dims[1:]))
        up_in_out = list(zip(up_dims[:-1], up_dims[1:]))
        print(f"down_in_out: {down_in_out}")
        print(f"up_in_out: {up_in_out}")

        time_dim = default(time_dim, int(down_dims[0] * 4))

        self.init_conv = nn.Conv2d(in_dim, down_dims[0], 7, padding=3)

        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        if attn_type == "linear_cat":
            attn_klass = partial(LinearCrossAttention)
        elif attn_type == "linear_add":
            attn_klass = partial(LinearCrossAttentionAdd)
        else:
            raise NotImplementedError()

        # time embeddings
        if with_time_emb:
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(down_dims[0]),
                nn.Linear(down_dims[0], time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # left layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        skip_dims = []
        for down_dim_in, down_dim_out in down_in_out:
            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(down_dim_in, down_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(down_dim_out, attn_klass(down_dim_out, label_emb_dim=label_emb_dim))),
                        block_klass(down_dim_out, down_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(down_dim_out, attn_klass(down_dim_out, label_emb_dim=label_emb_dim))),
                        Downsample(down_dim_out),
                    ]
                )
            )
            skip_dims.append(down_dim_out)

        # bottleneck
        mid_dim = down_dims[-1]
        self.mid_left = nn.ModuleList([])
        self.mid_right = nn.ModuleList([])
        for _ in range(mid_depth - 1):
            self.mid_left.append(block_klass(mid_dim, mid_dim, time_emb_dim=time_dim))
            self.mid_right.append(block_klass(mid_dim * 2, mid_dim, time_emb_dim=time_dim))
        self.mid_mid = nn.ModuleList(
            [
                block_klass(mid_dim, mid_dim, time_emb_dim=time_dim),
                Residual(PreNorm(mid_dim, attn_klass(mid_dim, label_emb_dim=label_emb_dim))),
                block_klass(mid_dim, mid_dim, time_emb_dim=time_dim),
            ]
        )

        # right layers
        for ind, (up_dim_in, up_dim_out) in enumerate(up_in_out):
            skip_dim = skip_dims.pop()  # down_dim_out
            self.ups.append(
                nn.ModuleList(
                    [
                        # pop&cat (h/2, w/2, down_dim_out)
                        block_klass(up_dim_in + skip_dim, up_dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(up_dim_in, attn_klass(up_dim_in, label_emb_dim=label_emb_dim))),
                        Upsample(up_dim_in),
                        # pop&cat (h, w, down_dim_out)
                        block_klass(up_dim_in + skip_dim, up_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(up_dim_out, attn_klass(up_dim_out, label_emb_dim=label_emb_dim))),
                        # pop&cat (h, w, down_dim_out)
                        block_klass(up_dim_out + skip_dim, up_dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(up_dim_out, attn_klass(up_dim_out, label_emb_dim=label_emb_dim))),
                    ]
                )
            )

        self.final_conv = nn.Sequential(
            block_klass(down_dims[0] + up_dims[-1], up_dims[-1]),
            nn.Conv2d(up_dims[-1], out_dim, 3, padding=1),
        )

    def size(self):
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total parameters: {total_params}")
        print(f"Trainable parameters: {trainable_params}")

    def forward(self, x, time, condition=None):
        if condition is not None:
            condition_emb = self.label_embedding(condition)
        else:
            condition_emb = None

        h = []
        x = self.init_conv(x)
        h.append(x)
        time_emb = self.time_mlp(time) if exists(self.time_mlp) else None

        # downsample
        for block1, attn1, block2, attn2, downsample in self.downs:
            x = block1(x, time_emb)
            x = attn1(x, condition_emb)
            h.append(x)
            x = block2(x, time_emb)
            x = attn2(x, condition_emb)
            h.append(x)
            x = downsample(x)
            h.append(x)

        # bottleneck
        for block in self.mid_left:
            x = block(x, time_emb)
            h.append(x)
        (block1, attn, block2) = self.mid_mid
        x = block1(x, time_emb)
        x = attn(x, condition_emb)
        x = block2(x, time_emb)
        for block in self.mid_right:
            # This is U-Net!!!
            x = pad_and_concat(h.pop(), x)
            x = block(x, time_emb)

        # upsample
        for block1, attn1, upsample, block2, attn2, block3, attn3 in self.ups:
            x = pad_and_concat(h.pop(), x)
            x = block1(x, time_emb)
            x = attn1(x, condition_emb)
            x = upsample(x)
            x = pad_and_concat(h.pop(), x)
            x = block2(x, time_emb)
            x = attn2(x, condition_emb)
            x = pad_and_concat(h.pop(), x)
            x = block3(x, time_emb)
            x = attn3(x, condition_emb)

        x = pad_and_concat(h.pop(), x)
        x = self.final_conv(x)
        return x
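

# The helper below is NOT used by the model: it is a minimal sketch of what the imported
# `pad_and_concat` is assumed to do, inferred only from how it is used in `forward`
# (the popped skip feature and the current feature map are joined along the channel axis,
# so the channel counts `up_dim_* + skip_dim` above make sense, and any spatial mismatch
# introduced by down/upsampling has to be padded or cropped away first). The authoritative
# implementation lives in model.diffusion_components; treat this only as documentation.
def _pad_and_concat_sketch(skip, x):
    """Pad/crop `x` spatially to match `skip`, then concatenate on the channel dimension."""
    dh = skip.shape[-2] - x.shape[-2]
    dw = skip.shape[-1] - x.shape[-1]
    # F.pad takes (left, right, top, bottom) for the last two dims; negative values crop
    x = F.pad(x, (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2))
    return torch.cat([skip, x], dim=1)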


def conditional_p_losses(denoise_model, x_start, t, condition, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod,
                         noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod,
                       sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
    predicted_noise = denoise_model(x_noisy, t, condition)

    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss


def evaluate_diffusion_model(device, model, iterator, BATCH_SIZE, timesteps, unetConfig, encodes2embeddings_mapping,
                             uncondition_rate, unconditional_condition):
    model.to(device)
    model.eval()
    eva_loss = []
    sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, _, _ = get_beta_schedule(timesteps)
    for i in xrange(500):
        data, attributes = next(iter(iterator))
        data = data.to(device)
        conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
        selected_conditions = [
            unconditional_condition if random.random() < uncondition_rate
            else random.choice(conditions_of_one_sample)
            for conditions_of_one_sample in conditions]
        selected_conditions = torch.stack(selected_conditions).float().to(device)

        t = torch.randint(0, timesteps, (BATCH_SIZE,), device=device).long()
        loss = conditional_p_losses(model, data, t, selected_conditions, loss_type="huber",
                                    sqrt_alphas_cumprod=sqrt_alphas_cumprod,
                                    sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod)
        eva_loss.append(loss.item())
    initial_loss = np.mean(eva_loss)
    return initial_loss
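

# The helper below is a minimal sketch, not the imported `q_sample`: it writes out the standard
# DDPM forward-noising step x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps that
# `conditional_p_losses` relies on. The per-sample coefficient lookup (indexing the schedule with
# `t` and broadcasting over the image dimensions) is an assumption about how the imported helper
# behaves; the authoritative version is in model.diffusion_components.
def _q_sample_sketch(x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise):
    # assumes the schedules are 1-D arrays/tensors indexed by timestep
    shape = (x_start.shape[0],) + (1,) * (x_start.dim() - 1)
    a = torch.as_tensor(sqrt_alphas_cumprod, device=x_start.device, dtype=x_start.dtype)[t].reshape(shape)
    b = torch.as_tensor(sqrt_one_minus_alphas_cumprod, device=x_start.device, dtype=x_start.dtype)[t].reshape(shape)
    return a * x_start + b * noise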


def get_diffusion_model(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    UNet = ConditionedUnet(**model_Config)
    print(f"Model initialized, size: {sum(p.numel() for p in UNet.parameters() if p.requires_grad)}")
    UNet.to(device)
    if load_pretrain:
        print(f"Loading weights from models/{model_name}_UNet.pth")
        checkpoint = torch.load(f'models/{model_name}_UNet.pth', map_location=device)
        UNet.load_state_dict(checkpoint['model_state_dict'])
    UNet.eval()
    return UNet
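

# Minimal usage sketch (illustrative, not part of the pipeline): the keys of `model_Config`
# mirror the keyword arguments of ConditionedUnet; the values below are placeholders rather
# than the configuration used in the actual experiments.
def _example_build_model():
    unet_config = {
        "in_dim": 1,  # channels of the spectrogram-like input fed to the U-Net (placeholder)
        "down_dims": [32, 32, 64, 128],
        "up_dims": [128, 128, 64, 32],
        "attn_type": "linear_cat",  # or "linear_add"
        "n_label_class": 11,
        "condition_type": "instrument_family",
        "label_emb_dim": 128,
    }
    return get_diffusion_model(unet_config, load_pretrain=False, device="cpu")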


def train_diffusion_model(VAE, text_encoder, CLAP_tokenizer, timbre_encoder, device, init_model_name, unetConfig,
                          BATCH_SIZE, timesteps, lr, max_iter, iterator, load_pretrain, encodes2embeddings_mapping,
                          uncondition_rate, unconditional_condition, save_steps=5000, init_loss=None,
                          save_model_name=None, n_IS_batches=50):
    if save_model_name is None:
        save_model_name = init_model_name

    def save_model_hyperparameter(model_name, unetConfig, BATCH_SIZE, lr, model_size, current_iter, current_loss):
        model_hyperparameter = unetConfig
        model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
        model_hyperparameter["lr"] = lr
        model_hyperparameter["model_size"] = model_size
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["current_loss"] = current_loss
        with open(f"models/hyperparameters/{model_name}_UNet.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    model = ConditionedUnet(**unetConfig)
    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {model_size}")
    model.to(device)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, amsgrad=False)

    if load_pretrain:
        print(f"Loading weights from models/{init_model_name}_UNet.pth")
        checkpoint = torch.load(f'models/{init_model_name}_UNet.pth')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Model initialized.")

    if max_iter == 0:
        print("Returning model directly.")
        return model, optimizer

    train_loss = []
    writer = SummaryWriter(f'runs/{save_model_name}_UNet')

    if init_loss is None:
        previous_loss = evaluate_diffusion_model(device, model, iterator, BATCH_SIZE, timesteps, unetConfig,
                                                 encodes2embeddings_mapping, uncondition_rate, unconditional_condition)
    else:
        previous_loss = init_loss
    print(f"initial_loss: {previous_loss}")

    sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, _, _ = get_beta_schedule(timesteps)

    model.train()
    for i in xrange(max_iter):
        data, attributes = next(iter(iterator))
        data = data.to(device)
        conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
        unconditional_condition_copy = torch.tensor(unconditional_condition, dtype=torch.float32).to(device).detach()
        # with probability `uncondition_rate`, replace the condition with the unconditional embedding
        selected_conditions = [unconditional_condition_copy if random.random() < uncondition_rate
                               else random.choice(conditions_of_one_sample)
                               for conditions_of_one_sample in conditions]
        selected_conditions = torch.stack(selected_conditions).float().to(device)

        optimizer.zero_grad()
        t = torch.randint(0, timesteps, (BATCH_SIZE,), device=device).long()
        loss = conditional_p_losses(model, data, t, selected_conditions, loss_type="huber",
                                    sqrt_alphas_cumprod=sqrt_alphas_cumprod,
                                    sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod)
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())

        # read the global step counter from the Adam state of the first parameter
        step = int(optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].numpy())
        if step % 100 == 0:
            print('%d step' % step)

        if step % save_steps == 0:
            current_loss = np.mean(train_loss[-save_steps:])
            print(f"current_loss = {current_loss}")
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'models/{save_model_name}_UNet.pth')
            save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, lr, model_size, step, current_loss)

        if step % 20000 == 0:
            current_IS = get_inception_score(device, model, VAE, text_encoder, CLAP_tokenizer, timbre_encoder,
                                             n_IS_batches, positive_prompts="", negative_prompts="", CFG=1,
                                             sample_steps=20, task="STFT")
            print('current_IS: %.5f' % current_IS)
            current_loss = np.mean(train_loss[-save_steps:])
            writer.add_scalar("current_IS", current_IS, step)
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'models/history/{save_model_name}_{step}_UNet.pth')
            save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, lr, model_size, step, current_loss)

    return model, optimizer
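

# Minimal restore sketch (illustrative): reloads a checkpoint written by the training loop above.
# The dict layout ('model_state_dict' / 'optimizer_state_dict') matches what torch.save is given
# in train_diffusion_model; `unet_config` and `lr` are assumed to match the original run.
def _example_restore_checkpoint(unet_config, path, lr=1e-4, device="cpu"):
    model = ConditionedUnet(**unet_config).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=False)
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer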