import json
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from six.moves import xrange
from einops import rearrange
from torchvision import models
def Normalize(in_channels, num_groups=32, norm_type="groupnorm"):
"""Normalization layer"""
if norm_type == "batchnorm":
return torch.nn.BatchNorm2d(in_channels)
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
def nonlinearity(x, act_type="relu"):
"""Nonlinear activation function"""
if act_type == "relu":
return F.relu(x)
else:
# swish
return x * torch.sigmoid(x)
class VectorQuantizer(nn.Module):
"""Vector quantization layer"""
def __init__(self, num_embeddings, embedding_dim, commitment_cost):
super(VectorQuantizer, self).__init__()
self._embedding_dim = embedding_dim
self._num_embeddings = num_embeddings
self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)
self._commitment_cost = commitment_cost
def forward(self, inputs):
# convert inputs from BCHW -> BHWC
inputs = inputs.permute(0, 2, 3, 1).contiguous()
input_shape = inputs.shape
        # Flatten input BHWC -> (BHW)C
flat_input = inputs.view(-1, self._embedding_dim)
# Calculate distances (input-embedding)^2
distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
+ torch.sum(self._embedding.weight ** 2, dim=1)
- 2 * torch.matmul(flat_input, self._embedding.weight.t()))
# Encoding (one-hot-encoding matrix)
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
encodings.scatter_(1, encoding_indices, 1)
# Quantize and unflatten
quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
# Loss
e_latent_loss = F.mse_loss(quantized.detach(), inputs)
q_latent_loss = F.mse_loss(quantized, inputs.detach())
loss = q_latent_loss + self._commitment_cost * e_latent_loss
quantized = inputs + (quantized - inputs).detach()
avg_probs = torch.mean(encodings, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
# convert quantized from BHWC -> BCHW
min_encodings, min_encoding_indices = None, None
return quantized.permute(0, 3, 1, 2).contiguous(), loss, (perplexity, min_encodings, min_encoding_indices)
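# Illustrative usage of VectorQuantizer (a sketch only; the shapes and hyperparameters below are
# assumptions for demonstration, not values taken from this repository's configs). The EMA variant
# below exposes the same interface:
#
#   vq = VectorQuantizer(num_embeddings=512, embedding_dim=64, commitment_cost=0.25)
#   z = torch.randn(8, 64, 16, 16)                    # BCHW latent map from an encoder
#   quantized, vq_loss, (perplexity, _, _) = vq(z)    # quantized keeps the BCHW shape of z
#   # vq_loss is added to the reconstruction loss; perplexity tracks how evenly the codebook is used.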
class VectorQuantizerEMA(nn.Module):
"""Vector quantization layer based on exponential moving average"""
def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
super(VectorQuantizerEMA, self).__init__()
self._embedding_dim = embedding_dim
self._num_embeddings = num_embeddings
self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
self._embedding.weight.data.normal_()
self._commitment_cost = commitment_cost
self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
self._ema_w.data.normal_()
self._decay = decay
self._epsilon = epsilon
def forward(self, inputs):
# convert inputs from BCHW -> BHWC
inputs = inputs.permute(0, 2, 3, 1).contiguous()
input_shape = inputs.shape
# Flatten input
flat_input = inputs.view(-1, self._embedding_dim)
# Calculate distances
distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
+ torch.sum(self._embedding.weight ** 2, dim=1)
- 2 * torch.matmul(flat_input, self._embedding.weight.t()))
# Encoding
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
encodings.scatter_(1, encoding_indices, 1)
# Quantize and unflatten
quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
# Use EMA to update the embedding vectors
if self.training:
self._ema_cluster_size = self._ema_cluster_size * self._decay + \
(1 - self._decay) * torch.sum(encodings, 0)
# Laplace smoothing of the cluster size
n = torch.sum(self._ema_cluster_size.data)
self._ema_cluster_size = (
(self._ema_cluster_size + self._epsilon)
/ (n + self._num_embeddings * self._epsilon) * n)
dw = torch.matmul(encodings.t(), flat_input)
self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
# Loss
e_latent_loss = F.mse_loss(quantized.detach(), inputs)
loss = self._commitment_cost * e_latent_loss
# Straight Through Estimator
quantized = inputs + (quantized - inputs).detach()
avg_probs = torch.mean(encodings, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
# convert quantized from BHWC -> BCHW
min_encodings, min_encoding_indices = None, None
return quantized.permute(0, 3, 1, 2).contiguous(), loss, (perplexity, min_encodings, min_encoding_indices)
class DownSample(nn.Module):
"""DownSample layer"""
def __init__(self, in_channels, out_channels):
super(DownSample, self).__init__()
self._conv2d = nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=4,
stride=2, padding=1)
def forward(self, x):
return self._conv2d(x)
class UpSample(nn.Module):
"""UpSample layer"""
def __init__(self, in_channels, out_channels):
super(UpSample, self).__init__()
self._conv2d = nn.ConvTranspose2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=4,
stride=2, padding=1)
def forward(self, x):
return self._conv2d(x)
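# Shape sketch for the sampling layers (illustrative; the sizes are assumptions): with
# kernel_size=4, stride=2, padding=1, DownSample halves the spatial resolution and UpSample
# doubles it, e.g.
#
#   down = DownSample(in_channels=1, out_channels=32)
#   up = UpSample(in_channels=32, out_channels=1)
#   x = torch.randn(4, 1, 128, 128)
#   h = down(x)    # -> (4, 32, 64, 64)
#   y = up(h)      # -> (4, 1, 128, 128)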
class ResnetBlock(nn.Module):
"""ResnetBlock is a combination of non-linearity, convolution, and normalization"""
def __init__(self, *, in_channels, out_channels=None, double_conv=False, conv_shortcut=False,
dropout=0.0, temb_channels=512, norm_type="groupnorm", act_type="relu", num_groups=32):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.act_type = act_type
self.norm1 = Normalize(in_channels, norm_type=norm_type, num_groups=num_groups)
self.conv1 = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels,
out_channels)
self.double_conv = double_conv
if self.double_conv:
self.norm2 = Normalize(out_channels, norm_type=norm_type, num_groups=num_groups)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x, temb=None):
h = x
h = self.norm1(h)
h = nonlinearity(h, act_type=self.act_type)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb, act_type=self.act_type))[:, :, None, None]
if self.double_conv:
h = self.norm2(h)
h = nonlinearity(h, act_type=self.act_type)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
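# Illustrative ResnetBlock usage (a sketch; the channel counts are assumptions). When
# in_channels != out_channels, the skip path is projected with a 1x1 convolution (or a 3x3 one if
# conv_shortcut=True) so that the residual sum x + h is well defined:
#
#   block = ResnetBlock(in_channels=64, out_channels=128, double_conv=True,
#                       temb_channels=0, norm_type="groupnorm", act_type="relu", num_groups=32)
#   h = block(torch.randn(2, 64, 32, 32))    # -> (2, 128, 32, 32)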
class LinearAttention(nn.Module):
"""Efficient attention block based on <https://proceedings.mlr.press/v119/katharopoulos20a.html>"""
def __init__(self, dim, heads=4, dim_head=32, with_skip=True):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
self.with_skip = with_skip
if self.with_skip:
self.nin_shortcut = torch.nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
if self.with_skip:
return self.to_out(out) + self.nin_shortcut(x)
return self.to_out(out)
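# Illustrative LinearAttention usage (a sketch; the dimensions are assumptions). The einsum
# formulation scales linearly in the number of spatial positions h * w, unlike standard softmax
# attention, and the output shape always matches the input shape:
#
#   attn = LinearAttention(dim=128, heads=4, dim_head=32, with_skip=True)
#   y = attn(torch.randn(2, 128, 32, 32))    # -> (2, 128, 32, 32)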
class Encoder(nn.Module):
"""The encoder, consisting of alternating stacks of ResNet blocks, efficient attention modules, and downsampling layers."""
def __init__(self, in_channels, hidden_channels, embedding_dim, block_depth=2,
attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu", num_groups=32):
super(Encoder, self).__init__()
if attn_pos is None:
attn_pos = []
self._layers = nn.ModuleList([DownSample(in_channels, hidden_channels[0])])
current_channel = hidden_channels[0]
for i in range(1, len(hidden_channels)):
for _ in range(block_depth - 1):
self._layers.append(ResnetBlock(in_channels=current_channel,
out_channels=current_channel,
double_conv=False,
conv_shortcut=False,
norm_type=norm_type,
act_type=act_type,
num_groups=num_groups))
if current_channel in attn_pos:
self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
self._layers.append(nn.ReLU())
self._layers.append(DownSample(current_channel, hidden_channels[i]))
current_channel = hidden_channels[i]
for _ in range(block_depth - 1):
self._layers.append(ResnetBlock(in_channels=current_channel,
out_channels=current_channel,
double_conv=False,
conv_shortcut=False,
norm_type=norm_type,
act_type=act_type,
num_groups=num_groups))
if current_channel in attn_pos:
self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
# Conv1x1: hidden_channels[-1] -> embedding_dim
self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
self._layers.append(nn.ReLU())
self._layers.append(nn.Conv2d(in_channels=current_channel,
out_channels=embedding_dim,
kernel_size=1,
stride=1))
def forward(self, x):
for layer in self._layers:
x = layer(x)
return x
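# Illustrative Encoder configuration (a sketch; every value is an assumption chosen only to show
# the shapes). Each entry of hidden_channels adds one DownSample stage, so the spatial size shrinks
# by a factor of 2 ** len(hidden_channels), and the final 1x1 convolution projects to embedding_dim:
#
#   enc = Encoder(in_channels=3, hidden_channels=[64, 128, 256], embedding_dim=64,
#                 block_depth=2, attn_pos=[256], norm_type="groupnorm", act_type="relu")
#   z = enc(torch.randn(1, 3, 128, 128))    # -> (1, 64, 16, 16)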
class Decoder(nn.Module):
"""The decoder, consisting of alternating stacks of ResNet blocks, efficient attention modules, and upsampling layers."""
def __init__(self, embedding_dim, hidden_channels, out_channels, block_depth=2,
attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu",
num_groups=32):
super(Decoder, self).__init__()
if attn_pos is None:
attn_pos = []
reversed_hidden_channels = list(reversed(hidden_channels))
        # Conv1x1: embedding_dim -> hidden_channels[-1]
self._layers = nn.ModuleList([nn.Conv2d(in_channels=embedding_dim,
out_channels=reversed_hidden_channels[0],
kernel_size=1, stride=1, bias=False)])
current_channel = reversed_hidden_channels[0]
for _ in range(block_depth - 1):
if current_channel in attn_pos:
self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
self._layers.append(ResnetBlock(in_channels=current_channel,
out_channels=current_channel,
double_conv=False,
conv_shortcut=False,
norm_type=norm_type,
act_type=act_type,
num_groups=num_groups))
for i in range(1, len(reversed_hidden_channels)):
self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
self._layers.append(nn.ReLU())
self._layers.append(UpSample(current_channel, reversed_hidden_channels[i]))
current_channel = reversed_hidden_channels[i]
for _ in range(block_depth - 1):
if current_channel in attn_pos:
self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
self._layers.append(ResnetBlock(in_channels=current_channel,
out_channels=current_channel,
double_conv=False,
conv_shortcut=False,
norm_type=norm_type,
act_type=act_type,
num_groups=num_groups))
self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
self._layers.append(nn.ReLU())
self._layers.append(UpSample(current_channel, current_channel))
# final layers
self._layers.append(ResnetBlock(in_channels=current_channel,
out_channels=out_channels,
double_conv=False,
conv_shortcut=False,
norm_type=norm_type,
act_type=act_type,
num_groups=num_groups))
def forward(self, x):
for layer in self._layers:
x = layer(x)
log_magnitude = torch.nn.functional.softplus(x[:, 0, :, :])
cos_phase = torch.tanh(x[:, 1, :, :])
sin_phase = torch.tanh(x[:, 2, :, :])
x = torch.stack([log_magnitude, cos_phase, sin_phase], dim=1)
return x
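# Illustrative Decoder usage (a sketch; the shapes are assumptions and mirror the Encoder example
# above, upsampling by 2 ** len(hidden_channels)). forward() assumes out_channels == 3 and returns
# the stack [softplus(log-magnitude), tanh(cos-phase), tanh(sin-phase)]:
#
#   dec = Decoder(embedding_dim=64, hidden_channels=[64, 128, 256], out_channels=3,
#                 block_depth=2, attn_pos=[256], norm_type="groupnorm", act_type="relu")
#   x_hat = dec(torch.randn(1, 64, 16, 16))    # -> (1, 3, 128, 128)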
class VQGAN_Discriminator(nn.Module):
"""The discriminator employs an 18-layer-ResNet architecture , with the first layer replaced by a 2D convolutional
layer that accommodates spectral representation inputs and the last two layers replaced by a binary classifier
layer."""
def __init__(self, in_channels=1):
super(VQGAN_Discriminator, self).__init__()
resnet = models.resnet18(pretrained=True)
        # Replace the first layer so it accepts in_channels input channels (e.g. single-channel grayscale images)
resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # Use the ResNet feature-extraction layers (everything except the average pooling and fc head)
self.features = nn.Sequential(*list(resnet.children())[:-2])
        # Binary classifier head that outputs raw logits: the training loop uses
        # nn.BCEWithLogitsLoss, which applies the sigmoid internally, so adding
        # nn.Sigmoid here as well would squash the scores twice.
        self.classifier = nn.Sequential(
            nn.Linear(512, 1)
        )
def forward(self, x):
x = self.features(x)
x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
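# Illustrative discriminator usage (a sketch; the input shape is an assumption). The head emits one
# raw logit per image, intended for nn.BCEWithLogitsLoss:
#
#   disc = VQGAN_Discriminator(in_channels=3)
#   logits = disc(torch.randn(2, 3, 128, 128))    # -> (2, 1)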
class VQGAN(nn.Module):
"""The VQ-GAN model. <https://openaccess.thecvf.com/content/CVPR2021/html/Esser_Taming_Transformers_for_High-Resolution_Image_Synthesis_CVPR_2021_paper.html?ref=>"""
def __init__(self, in_channels, hidden_channels, embedding_dim, out_channels, block_depth=2,
attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu",
num_embeddings=1024, commitment_cost=0.25, decay=0.99, num_groups=32):
super(VQGAN, self).__init__()
self._encoder = Encoder(in_channels, hidden_channels, embedding_dim, block_depth=block_depth,
                                attn_pos=attn_pos, attn_with_skip=attn_with_skip, norm_type=norm_type, act_type=act_type, num_groups=num_groups)
if decay > 0.0:
self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
commitment_cost, decay)
else:
self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
commitment_cost)
self._decoder = Decoder(embedding_dim, hidden_channels, out_channels, block_depth=block_depth,
attn_pos=attn_pos, attn_with_skip=attn_with_skip, norm_type=norm_type,
act_type=act_type, num_groups=num_groups)
def forward(self, x):
z = self._encoder(x)
quantized, vq_loss, (perplexity, _, _) = self._vq_vae(z)
x_recon = self._decoder(quantized)
return vq_loss, x_recon, perplexity
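# Illustrative end-to-end VQGAN forward pass (a sketch; every value is an assumption chosen for
# shape checking, not a recommended configuration):
#
#   model = VQGAN(in_channels=3, hidden_channels=[64, 128, 256], embedding_dim=64, out_channels=3,
#                 block_depth=2, attn_pos=[256], num_embeddings=1024, commitment_cost=0.25,
#                 decay=0.99)                 # decay > 0 selects the EMA quantizer
#   x = torch.randn(1, 3, 128, 128)           # e.g. [log-magnitude, cos-phase, sin-phase] spectrogram
#   vq_loss, x_recon, perplexity = model(x)   # x_recon has the same shape as x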
class ReconstructionLoss(nn.Module):
def __init__(self, w1, w2, epsilon=1e-3):
super(ReconstructionLoss, self).__init__()
self.w1 = w1
self.w2 = w2
self.epsilon = epsilon
def weighted_mae_loss(self, y_true, y_pred):
# avoid divide by zero
y_true_safe = torch.clamp(y_true, min=self.epsilon)
# compute weighted MAE
loss = torch.mean(torch.abs(y_pred - y_true) / y_true_safe)
return loss
def mae_loss(self, y_true, y_pred):
loss = torch.mean(torch.abs(y_pred - y_true))
return loss
def forward(self, y_pred, y_true):
        # loss for the log-magnitude channel (arguments passed in the declared (y_true, y_pred) order)
        log_magnitude_loss = self.w1 * self.weighted_mae_loss(y_true[:, 0, :, :], y_pred[:, 0, :, :])
        # loss for the phase channels
        phase_loss = self.w2 * self.mae_loss(y_true[:, 1:, :, :], y_pred[:, 1:, :, :])
# sum up
rec_loss = log_magnitude_loss + phase_loss
return log_magnitude_loss, phase_loss, rec_loss
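# Illustrative use of ReconstructionLoss (a sketch; the weights are assumptions). Channel 0 is
# scored with a relative (weighted) MAE for the log-magnitude, channels 1: with a plain MAE for the
# phase components:
#
#   criterion = ReconstructionLoss(w1=1.0, w2=1.0, epsilon=1e-3)
#   y_true = torch.rand(2, 3, 128, 128)
#   y_pred = torch.rand(2, 3, 128, 128)
#   log_magnitude_loss, phase_loss, rec_loss = criterion(y_pred, y_true)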
def evaluate_VQGAN(model, discriminator, iterator, reconstructionLoss, adversarial_loss, trainingConfig):
model.to(trainingConfig["device"])
model.eval()
train_res_error = []
for i in xrange(100):
data = next(iter(iterator))
data = data.to(trainingConfig["device"])
# true/fake labels
real_labels = torch.ones(data.size(0), 1).to(trainingConfig["device"])
vq_loss, data_recon, perplexity = model(data)
fake_preds = discriminator(data_recon)
adver_loss = adversarial_loss(fake_preds, real_labels)
log_magnitude_loss, phase_loss, rec_loss = reconstructionLoss(data_recon, data)
loss = rec_loss + trainingConfig["vq_weight"] * vq_loss + trainingConfig["adver_weight"] * adver_loss
train_res_error.append(loss.item())
initial_loss = np.mean(train_res_error)
return initial_loss
def get_VQGAN(model_Config, load_pretrain=False, model_name=None, device="cpu"):
VQVAE = VQGAN(**model_Config)
print(f"Model intialized, size: {sum(p.numel() for p in VQVAE.parameters() if p.requires_grad)}")
VQVAE.to(device)
if load_pretrain:
print(f"Loading weights from models/{model_name}_imageVQVAE.pth")
checkpoint = torch.load(f'models/{model_name}_imageVQVAE.pth', map_location=device)
VQVAE.load_state_dict(checkpoint['model_state_dict'])
VQVAE.eval()
return VQVAE
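# Illustrative call to get_VQGAN (a sketch; the config values and model name are assumptions, and
# with load_pretrain=True the checkpoint models/<model_name>_imageVQVAE.pth must already exist):
#
#   model_Config = dict(in_channels=3, hidden_channels=[64, 128, 256], embedding_dim=64,
#                       out_channels=3, num_embeddings=1024, commitment_cost=0.25, decay=0.99)
#   model = get_VQGAN(model_Config, load_pretrain=False, device="cpu")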
def train_VQGAN(model_Config, trainingConfig, iterator):
def save_model_hyperparameter(model_Config, trainingConfig, current_iter,
log_magnitude_loss, phase_loss, current_perplexity, current_vq_loss,
current_loss):
model_name = trainingConfig["model_name"]
model_hyperparameter = model_Config
model_hyperparameter.update(trainingConfig)
model_hyperparameter["current_iter"] = current_iter
model_hyperparameter["log_magnitude_loss"] = log_magnitude_loss
model_hyperparameter["phase_loss"] = phase_loss
model_hyperparameter["erplexity"] = current_perplexity
model_hyperparameter["vq_loss"] = current_vq_loss
model_hyperparameter["total_loss"] = current_loss
with open(f"models/hyperparameters/{model_name}_VQGAN_STFT.json", "w") as json_file:
json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)
    # initialize the VQ-GAN generator (encoder, vector quantizer, decoder)
model = VQGAN(**model_Config)
print(f"VQ_VAE size: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
model.to(trainingConfig["device"])
VAE_optimizer = torch.optim.Adam(model.parameters(), lr=trainingConfig["lr"], amsgrad=False)
model_name = trainingConfig["model_name"]
if trainingConfig["load_pretrain"]:
print(f"Loading weights from models/{model_name}_imageVQVAE.pth")
checkpoint = torch.load(f'models/{model_name}_imageVQVAE.pth', map_location=trainingConfig["device"])
model.load_state_dict(checkpoint['model_state_dict'])
VAE_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
else:
print("VAE initialized.")
if trainingConfig["max_iter"] == 0:
print("Return VAE directly.")
return model
# initialize discriminator
discriminator = VQGAN_Discriminator(model_Config["in_channels"])
print(f"Discriminator size: {sum(p.numel() for p in discriminator.parameters() if p.requires_grad)}")
discriminator.to(trainingConfig["device"])
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=trainingConfig["d_lr"], amsgrad=False)
if trainingConfig["load_pretrain"]:
print(f"Loading weights from models/{model_name}_imageVQVAE_discriminator.pth")
checkpoint = torch.load(f'models/{model_name}_imageVQVAE_discriminator.pth', map_location=trainingConfig["device"])
discriminator.load_state_dict(checkpoint['model_state_dict'])
discriminator_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
else:
print("Discriminator initialized.")
# Training
train_res_phase_loss, train_res_perplexity, train_res_log_magnitude_loss, train_res_vq_loss, train_res_loss = [], [], [], [], []
    train_discriminator_loss, train_adversarial_loss = [], []
reconstructionLoss = ReconstructionLoss(w1=trainingConfig["w1"], w2=trainingConfig["w2"], epsilon=trainingConfig["threshold"])
adversarial_loss = nn.BCEWithLogitsLoss()
writer = SummaryWriter(f'runs/{model_name}_VQVAE_lr=1e-4')
previous_lowest_loss = evaluate_VQGAN(model, discriminator, iterator,
reconstructionLoss, adversarial_loss, trainingConfig)
print(f"initial_loss: {previous_lowest_loss}")
model.train()
for i in xrange(trainingConfig["max_iter"]):
data = next(iter(iterator))
data = data.to(trainingConfig["device"])
# true/fake labels
real_labels = torch.ones(data.size(0), 1).to(trainingConfig["device"])
fake_labels = torch.zeros(data.size(0), 1).to(trainingConfig["device"])
# update discriminator
discriminator_optimizer.zero_grad()
vq_loss, data_recon, perplexity = model(data)
real_preds = discriminator(data)
fake_preds = discriminator(data_recon.detach())
loss_real = adversarial_loss(real_preds, real_labels)
loss_fake = adversarial_loss(fake_preds, fake_labels)
loss_D = loss_real + loss_fake
loss_D.backward()
discriminator_optimizer.step()
# update VQVAE
VAE_optimizer.zero_grad()
fake_preds = discriminator(data_recon)
adver_loss = adversarial_loss(fake_preds, real_labels)
log_magnitude_loss, phase_loss, rec_loss = reconstructionLoss(data_recon, data)
loss = rec_loss + trainingConfig["vq_weight"] * vq_loss + trainingConfig["adver_weight"] * adver_loss
loss.backward()
VAE_optimizer.step()
train_discriminator_loss.append(loss_D.item())
        train_adversarial_loss.append(trainingConfig["adver_weight"] * adver_loss.item())
train_res_log_magnitude_loss.append(log_magnitude_loss.item())
train_res_phase_loss.append(phase_loss.item())
train_res_perplexity.append(perplexity.item())
train_res_vq_loss.append(trainingConfig["vq_weight"] * vq_loss.item())
train_res_loss.append(loss.item())
step = int(VAE_optimizer.state_dict()['state'][list(VAE_optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy())
save_steps = trainingConfig["save_steps"]
if (i + 1) % 100 == 0:
print('%d step' % (step))
if (i + 1) % save_steps == 0:
current_discriminator_loss = np.mean(train_discriminator_loss[-save_steps:])
            current_adversarial_loss = np.mean(train_adversarial_loss[-save_steps:])
current_log_magnitude_loss = np.mean(train_res_log_magnitude_loss[-save_steps:])
current_phase_loss = np.mean(train_res_phase_loss[-save_steps:])
current_perplexity = np.mean(train_res_perplexity[-save_steps:])
current_vq_loss = np.mean(train_res_vq_loss[-save_steps:])
current_loss = np.mean(train_res_loss[-save_steps:])
print('discriminator_loss: %.3f' % current_discriminator_loss)
            print('adversarial_loss: %.3f' % current_adversarial_loss)
print('log_magnitude_loss: %.3f' % current_log_magnitude_loss)
print('phase_loss: %.3f' % current_phase_loss)
print('perplexity: %.3f' % current_perplexity)
print('vq_loss: %.3f' % current_vq_loss)
print('total_loss: %.3f' % current_loss)
writer.add_scalar(f"log_magnitude_loss", current_log_magnitude_loss, step)
writer.add_scalar(f"phase_loss", current_phase_loss, step)
writer.add_scalar(f"perplexity", current_perplexity, step)
writer.add_scalar(f"vq_loss", current_vq_loss, step)
writer.add_scalar(f"total_loss", current_loss, step)
if current_loss < previous_lowest_loss:
previous_lowest_loss = current_loss
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': VAE_optimizer.state_dict(),
}, f'models/{model_name}_imageVQVAE.pth')
torch.save({
'model_state_dict': discriminator.state_dict(),
'optimizer_state_dict': discriminator_optimizer.state_dict(),
}, f'models/{model_name}_imageVQVAE_discriminator.pth')
save_model_hyperparameter(model_Config, trainingConfig, step,
current_log_magnitude_loss, current_phase_loss, current_perplexity, current_vq_loss,
current_loss)
return model