import json from torch.utils.tensorboard import SummaryWriter import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from six.moves import xrange from einops import rearrange from torchvision import models def Normalize(in_channels, num_groups=32, norm_type="groupnorm"): """Normalization layer""" if norm_type == "batchnorm": return torch.nn.BatchNorm2d(in_channels) return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) def nonlinearity(x, act_type="relu"): """Nonlinear activation function""" if act_type == "relu": return F.relu(x) else: # swish return x * torch.sigmoid(x) class VectorQuantizer(nn.Module): """Vector quantization layer""" def __init__(self, num_embeddings, embedding_dim, commitment_cost): super(VectorQuantizer, self).__init__() self._embedding_dim = embedding_dim self._num_embeddings = num_embeddings self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings) self._commitment_cost = commitment_cost def forward(self, inputs): # convert inputs from BCHW -> BHWC inputs = inputs.permute(0, 2, 3, 1).contiguous() input_shape = inputs.shape # Flatten input BCHW -> (BHW)C flat_input = inputs.view(-1, self._embedding_dim) # Calculate distances (input-embedding)^2 distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True) + torch.sum(self._embedding.weight ** 2, dim=1) - 2 * torch.matmul(flat_input, self._embedding.weight.t())) # Encoding (one-hot-encoding matrix) encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device) encodings.scatter_(1, encoding_indices, 1) # Quantize and unflatten quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape) # Loss e_latent_loss = F.mse_loss(quantized.detach(), inputs) q_latent_loss = F.mse_loss(quantized, inputs.detach()) loss = q_latent_loss + self._commitment_cost * e_latent_loss quantized = inputs + (quantized - inputs).detach() avg_probs = torch.mean(encodings, dim=0) perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) # convert quantized from BHWC -> BCHW min_encodings, min_encoding_indices = None, None return quantized.permute(0, 3, 1, 2).contiguous(), loss, (perplexity, min_encodings, min_encoding_indices) class VectorQuantizerEMA(nn.Module): """Vector quantization layer based on exponential moving average""" def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5): super(VectorQuantizerEMA, self).__init__() self._embedding_dim = embedding_dim self._num_embeddings = num_embeddings self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) self._embedding.weight.data.normal_() self._commitment_cost = commitment_cost self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings)) self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim)) self._ema_w.data.normal_() self._decay = decay self._epsilon = epsilon def forward(self, inputs): # convert inputs from BCHW -> BHWC inputs = inputs.permute(0, 2, 3, 1).contiguous() input_shape = inputs.shape # Flatten input flat_input = inputs.view(-1, self._embedding_dim) # Calculate distances distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True) + torch.sum(self._embedding.weight ** 2, dim=1) - 2 * torch.matmul(flat_input, self._embedding.weight.t())) # Encoding encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device) encodings.scatter_(1, encoding_indices, 1) # Quantize and unflatten quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape) # Use EMA to update the embedding vectors if self.training: self._ema_cluster_size = self._ema_cluster_size * self._decay + \ (1 - self._decay) * torch.sum(encodings, 0) # Laplace smoothing of the cluster size n = torch.sum(self._ema_cluster_size.data) self._ema_cluster_size = ( (self._ema_cluster_size + self._epsilon) / (n + self._num_embeddings * self._epsilon) * n) dw = torch.matmul(encodings.t(), flat_input) self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw) self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1)) # Loss e_latent_loss = F.mse_loss(quantized.detach(), inputs) loss = self._commitment_cost * e_latent_loss # Straight Through Estimator quantized = inputs + (quantized - inputs).detach() avg_probs = torch.mean(encodings, dim=0) perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) # convert quantized from BHWC -> BCHW min_encodings, min_encoding_indices = None, None return quantized.permute(0, 3, 1, 2).contiguous(), loss, (perplexity, min_encodings, min_encoding_indices) class DownSample(nn.Module): """DownSample layer""" def __init__(self, in_channels, out_channels): super(DownSample, self).__init__() self._conv2d = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1) def forward(self, x): return self._conv2d(x) class UpSample(nn.Module): """UpSample layer""" def __init__(self, in_channels, out_channels): super(UpSample, self).__init__() self._conv2d = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1) def forward(self, x): return self._conv2d(x) class ResnetBlock(nn.Module): """ResnetBlock is a combination of non-linearity, convolution, and normalization""" def __init__(self, *, in_channels, out_channels=None, double_conv=False, conv_shortcut=False, dropout=0.0, temb_channels=512, norm_type="groupnorm", act_type="relu", num_groups=32): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.act_type = act_type self.norm1 = Normalize(in_channels, norm_type=norm_type, num_groups=num_groups) self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels > 0: self.temb_proj = torch.nn.Linear(temb_channels, out_channels) self.double_conv = double_conv if self.double_conv: self.norm2 = Normalize(out_channels, norm_type=norm_type, num_groups=num_groups) self.dropout = torch.nn.Dropout(dropout) self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x, temb=None): h = x h = self.norm1(h) h = nonlinearity(h, act_type=self.act_type) h = self.conv1(h) if temb is not None: h = h + self.temb_proj(nonlinearity(temb, act_type=self.act_type))[:, :, None, None] if self.double_conv: h = self.norm2(h) h = nonlinearity(h, act_type=self.act_type) h = self.dropout(h) h = self.conv2(h) if self.in_channels != self.out_channels: if self.use_conv_shortcut: x = self.conv_shortcut(x) else: x = self.nin_shortcut(x) return x + h class LinearAttention(nn.Module): """Efficient attention block based on """ def __init__(self, dim, heads=4, dim_head=32, with_skip=True): super().__init__() self.heads = heads hidden_dim = dim_head * heads self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) self.with_skip = with_skip if self.with_skip: self.nin_shortcut = torch.nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0) def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3) k = k.softmax(dim=-1) context = torch.einsum('bhdn,bhen->bhde', k, v) out = torch.einsum('bhde,bhdn->bhen', context, q) out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) if self.with_skip: return self.to_out(out) + self.nin_shortcut(x) return self.to_out(out) class Encoder(nn.Module): """The encoder, consisting of alternating stacks of ResNet blocks, efficient attention modules, and downsampling layers.""" def __init__(self, in_channels, hidden_channels, embedding_dim, block_depth=2, attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu", num_groups=32): super(Encoder, self).__init__() if attn_pos is None: attn_pos = [] self._layers = nn.ModuleList([DownSample(in_channels, hidden_channels[0])]) current_channel = hidden_channels[0] for i in range(1, len(hidden_channels)): for _ in range(block_depth - 1): self._layers.append(ResnetBlock(in_channels=current_channel, out_channels=current_channel, double_conv=False, conv_shortcut=False, norm_type=norm_type, act_type=act_type, num_groups=num_groups)) if current_channel in attn_pos: self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip)) self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups)) self._layers.append(nn.ReLU()) self._layers.append(DownSample(current_channel, hidden_channels[i])) current_channel = hidden_channels[i] for _ in range(block_depth - 1): self._layers.append(ResnetBlock(in_channels=current_channel, out_channels=current_channel, double_conv=False, conv_shortcut=False, norm_type=norm_type, act_type=act_type, num_groups=num_groups)) if current_channel in attn_pos: self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip)) # Conv1x1: hidden_channels[-1] -> embedding_dim self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups)) self._layers.append(nn.ReLU()) self._layers.append(nn.Conv2d(in_channels=current_channel, out_channels=embedding_dim, kernel_size=1, stride=1)) def forward(self, x): for layer in self._layers: x = layer(x) return x class Decoder(nn.Module): """The decoder, consisting of alternating stacks of ResNet blocks, efficient attention modules, and upsampling layers.""" def __init__(self, embedding_dim, hidden_channels, out_channels, block_depth=2, attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu", num_groups=32): super(Decoder, self).__init__() if attn_pos is None: attn_pos = [] reversed_hidden_channels = list(reversed(hidden_channels)) # Conv1x1: hidden_channels[-1] -> embedding_dim self._layers = nn.ModuleList([nn.Conv2d(in_channels=embedding_dim, out_channels=reversed_hidden_channels[0], kernel_size=1, stride=1, bias=False)]) current_channel = reversed_hidden_channels[0] for _ in range(block_depth - 1): if current_channel in attn_pos: self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip)) self._layers.append(ResnetBlock(in_channels=current_channel, out_channels=current_channel, double_conv=False, conv_shortcut=False, norm_type=norm_type, act_type=act_type, num_groups=num_groups)) for i in range(1, len(reversed_hidden_channels)): self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups)) self._layers.append(nn.ReLU()) self._layers.append(UpSample(current_channel, reversed_hidden_channels[i])) current_channel = reversed_hidden_channels[i] for _ in range(block_depth - 1): if current_channel in attn_pos: self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip)) self._layers.append(ResnetBlock(in_channels=current_channel, out_channels=current_channel, double_conv=False, conv_shortcut=False, norm_type=norm_type, act_type=act_type, num_groups=num_groups)) self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups)) self._layers.append(nn.ReLU()) self._layers.append(UpSample(current_channel, current_channel)) # final layers self._layers.append(ResnetBlock(in_channels=current_channel, out_channels=out_channels, double_conv=False, conv_shortcut=False, norm_type=norm_type, act_type=act_type, num_groups=num_groups)) def forward(self, x): for layer in self._layers: x = layer(x) log_magnitude = torch.nn.functional.softplus(x[:, 0, :, :]) cos_phase = torch.tanh(x[:, 1, :, :]) sin_phase = torch.tanh(x[:, 2, :, :]) x = torch.stack([log_magnitude, cos_phase, sin_phase], dim=1) return x class VQGAN_Discriminator(nn.Module): """The discriminator employs an 18-layer-ResNet architecture , with the first layer replaced by a 2D convolutional layer that accommodates spectral representation inputs and the last two layers replaced by a binary classifier layer.""" def __init__(self, in_channels=1): super(VQGAN_Discriminator, self).__init__() resnet = models.resnet18(pretrained=True) # 修改第一层以接受单通道(黑白)图像 resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) # 使用ResNet的特征提取部分 self.features = nn.Sequential(*list(resnet.children())[:-2]) # 添加判别器的额外层 self.classifier = nn.Sequential( nn.Linear(512, 1), nn.Sigmoid() ) def forward(self, x): x = self.features(x) x = nn.functional.adaptive_avg_pool2d(x, (1, 1)) x = torch.flatten(x, 1) x = self.classifier(x) return x class VQGAN(nn.Module): """The VQ-GAN model. """ def __init__(self, in_channels, hidden_channels, embedding_dim, out_channels, block_depth=2, attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu", num_embeddings=1024, commitment_cost=0.25, decay=0.99, num_groups=32): super(VQGAN, self).__init__() self._encoder = Encoder(in_channels, hidden_channels, embedding_dim, block_depth=block_depth, attn_pos=attn_pos, attn_with_skip=attn_with_skip, norm_type=norm_type, act_type="act_type", num_groups=num_groups) if decay > 0.0: self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim, commitment_cost, decay) else: self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost) self._decoder = Decoder(embedding_dim, hidden_channels, out_channels, block_depth=block_depth, attn_pos=attn_pos, attn_with_skip=attn_with_skip, norm_type=norm_type, act_type=act_type, num_groups=num_groups) def forward(self, x): z = self._encoder(x) quantized, vq_loss, (perplexity, _, _) = self._vq_vae(z) x_recon = self._decoder(quantized) return vq_loss, x_recon, perplexity class ReconstructionLoss(nn.Module): def __init__(self, w1, w2, epsilon=1e-3): super(ReconstructionLoss, self).__init__() self.w1 = w1 self.w2 = w2 self.epsilon = epsilon def weighted_mae_loss(self, y_true, y_pred): # avoid divide by zero y_true_safe = torch.clamp(y_true, min=self.epsilon) # compute weighted MAE loss = torch.mean(torch.abs(y_pred - y_true) / y_true_safe) return loss def mae_loss(self, y_true, y_pred): loss = torch.mean(torch.abs(y_pred - y_true)) return loss def forward(self, y_pred, y_true): # loss for magnitude channel log_magnitude_loss = self.w1 * self.weighted_mae_loss(y_pred[:, 0, :, :], y_true[:, 0, :, :]) # loss for phase channels phase_loss = self.w2 * self.mae_loss(y_pred[:, 1:, :, :], y_true[:, 1:, :, :]) # sum up rec_loss = log_magnitude_loss + phase_loss return log_magnitude_loss, phase_loss, rec_loss def evaluate_VQGAN(model, discriminator, iterator, reconstructionLoss, adversarial_loss, trainingConfig): model.to(trainingConfig["device"]) model.eval() train_res_error = [] for i in xrange(100): data = next(iter(iterator)) data = data.to(trainingConfig["device"]) # true/fake labels real_labels = torch.ones(data.size(0), 1).to(trainingConfig["device"]) vq_loss, data_recon, perplexity = model(data) fake_preds = discriminator(data_recon) adver_loss = adversarial_loss(fake_preds, real_labels) log_magnitude_loss, phase_loss, rec_loss = reconstructionLoss(data_recon, data) loss = rec_loss + trainingConfig["vq_weight"] * vq_loss + trainingConfig["adver_weight"] * adver_loss train_res_error.append(loss.item()) initial_loss = np.mean(train_res_error) return initial_loss def get_VQGAN(model_Config, load_pretrain=False, model_name=None, device="cpu"): VQVAE = VQGAN(**model_Config) print(f"Model intialized, size: {sum(p.numel() for p in VQVAE.parameters() if p.requires_grad)}") VQVAE.to(device) if load_pretrain: print(f"Loading weights from models/{model_name}_imageVQVAE.pth") checkpoint = torch.load(f'models/{model_name}_imageVQVAE.pth', map_location=device) VQVAE.load_state_dict(checkpoint['model_state_dict']) VQVAE.eval() return VQVAE def train_VQGAN(model_Config, trainingConfig, iterator): def save_model_hyperparameter(model_Config, trainingConfig, current_iter, log_magnitude_loss, phase_loss, current_perplexity, current_vq_loss, current_loss): model_name = trainingConfig["model_name"] model_hyperparameter = model_Config model_hyperparameter.update(trainingConfig) model_hyperparameter["current_iter"] = current_iter model_hyperparameter["log_magnitude_loss"] = log_magnitude_loss model_hyperparameter["phase_loss"] = phase_loss model_hyperparameter["erplexity"] = current_perplexity model_hyperparameter["vq_loss"] = current_vq_loss model_hyperparameter["total_loss"] = current_loss with open(f"models/hyperparameters/{model_name}_VQGAN_STFT.json", "w") as json_file: json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4) # initialize VAE model = VQGAN(**model_Config) print(f"VQ_VAE size: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") model.to(trainingConfig["device"]) VAE_optimizer = torch.optim.Adam(model.parameters(), lr=trainingConfig["lr"], amsgrad=False) model_name = trainingConfig["model_name"] if trainingConfig["load_pretrain"]: print(f"Loading weights from models/{model_name}_imageVQVAE.pth") checkpoint = torch.load(f'models/{model_name}_imageVQVAE.pth', map_location=trainingConfig["device"]) model.load_state_dict(checkpoint['model_state_dict']) VAE_optimizer.load_state_dict(checkpoint['optimizer_state_dict']) else: print("VAE initialized.") if trainingConfig["max_iter"] == 0: print("Return VAE directly.") return model # initialize discriminator discriminator = VQGAN_Discriminator(model_Config["in_channels"]) print(f"Discriminator size: {sum(p.numel() for p in discriminator.parameters() if p.requires_grad)}") discriminator.to(trainingConfig["device"]) discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=trainingConfig["d_lr"], amsgrad=False) if trainingConfig["load_pretrain"]: print(f"Loading weights from models/{model_name}_imageVQVAE_discriminator.pth") checkpoint = torch.load(f'models/{model_name}_imageVQVAE_discriminator.pth', map_location=trainingConfig["device"]) discriminator.load_state_dict(checkpoint['model_state_dict']) discriminator_optimizer.load_state_dict(checkpoint['optimizer_state_dict']) else: print("Discriminator initialized.") # Training train_res_phase_loss, train_res_perplexity, train_res_log_magnitude_loss, train_res_vq_loss, train_res_loss = [], [], [], [], [] train_discriminator_loss, train_adverserial_loss = [], [] reconstructionLoss = ReconstructionLoss(w1=trainingConfig["w1"], w2=trainingConfig["w2"], epsilon=trainingConfig["threshold"]) adversarial_loss = nn.BCEWithLogitsLoss() writer = SummaryWriter(f'runs/{model_name}_VQVAE_lr=1e-4') previous_lowest_loss = evaluate_VQGAN(model, discriminator, iterator, reconstructionLoss, adversarial_loss, trainingConfig) print(f"initial_loss: {previous_lowest_loss}") model.train() for i in xrange(trainingConfig["max_iter"]): data = next(iter(iterator)) data = data.to(trainingConfig["device"]) # true/fake labels real_labels = torch.ones(data.size(0), 1).to(trainingConfig["device"]) fake_labels = torch.zeros(data.size(0), 1).to(trainingConfig["device"]) # update discriminator discriminator_optimizer.zero_grad() vq_loss, data_recon, perplexity = model(data) real_preds = discriminator(data) fake_preds = discriminator(data_recon.detach()) loss_real = adversarial_loss(real_preds, real_labels) loss_fake = adversarial_loss(fake_preds, fake_labels) loss_D = loss_real + loss_fake loss_D.backward() discriminator_optimizer.step() # update VQVAE VAE_optimizer.zero_grad() fake_preds = discriminator(data_recon) adver_loss = adversarial_loss(fake_preds, real_labels) log_magnitude_loss, phase_loss, rec_loss = reconstructionLoss(data_recon, data) loss = rec_loss + trainingConfig["vq_weight"] * vq_loss + trainingConfig["adver_weight"] * adver_loss loss.backward() VAE_optimizer.step() train_discriminator_loss.append(loss_D.item()) train_adverserial_loss.append(trainingConfig["adver_weight"] * adver_loss.item()) train_res_log_magnitude_loss.append(log_magnitude_loss.item()) train_res_phase_loss.append(phase_loss.item()) train_res_perplexity.append(perplexity.item()) train_res_vq_loss.append(trainingConfig["vq_weight"] * vq_loss.item()) train_res_loss.append(loss.item()) step = int(VAE_optimizer.state_dict()['state'][list(VAE_optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy()) save_steps = trainingConfig["save_steps"] if (i + 1) % 100 == 0: print('%d step' % (step)) if (i + 1) % save_steps == 0: current_discriminator_loss = np.mean(train_discriminator_loss[-save_steps:]) current_adverserial_loss = np.mean(train_adverserial_loss[-save_steps:]) current_log_magnitude_loss = np.mean(train_res_log_magnitude_loss[-save_steps:]) current_phase_loss = np.mean(train_res_phase_loss[-save_steps:]) current_perplexity = np.mean(train_res_perplexity[-save_steps:]) current_vq_loss = np.mean(train_res_vq_loss[-save_steps:]) current_loss = np.mean(train_res_loss[-save_steps:]) print('discriminator_loss: %.3f' % current_discriminator_loss) print('adverserial_loss: %.3f' % current_adverserial_loss) print('log_magnitude_loss: %.3f' % current_log_magnitude_loss) print('phase_loss: %.3f' % current_phase_loss) print('perplexity: %.3f' % current_perplexity) print('vq_loss: %.3f' % current_vq_loss) print('total_loss: %.3f' % current_loss) writer.add_scalar(f"log_magnitude_loss", current_log_magnitude_loss, step) writer.add_scalar(f"phase_loss", current_phase_loss, step) writer.add_scalar(f"perplexity", current_perplexity, step) writer.add_scalar(f"vq_loss", current_vq_loss, step) writer.add_scalar(f"total_loss", current_loss, step) if current_loss < previous_lowest_loss: previous_lowest_loss = current_loss torch.save({ 'model_state_dict': model.state_dict(), 'optimizer_state_dict': VAE_optimizer.state_dict(), }, f'models/{model_name}_imageVQVAE.pth') torch.save({ 'model_state_dict': discriminator.state_dict(), 'optimizer_state_dict': discriminator_optimizer.state_dict(), }, f'models/{model_name}_imageVQVAE_discriminator.pth') save_model_hyperparameter(model_Config, trainingConfig, step, current_log_magnitude_loss, current_phase_loss, current_perplexity, current_vq_loss, current_loss) return model