|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
def plot_losses(log_dir):
    """Placeholder for plotting training losses from TensorBoard logs.

    Not implemented yet: no logs are read and no figure is produced. A
    ``UserWarning`` is emitted so callers do not mistake the no-op for a
    successful plot.

    Args:
        log_dir: Presumably the directory holding TensorBoard event files;
            currently unused apart from being echoed in the warning.

    Returns:
        None.
    """
    import warnings

    # TODO: read scalar summaries from ``log_dir`` and plot them. Parsing
    # TensorBoard event files needs the ``tensorboard`` package, which this
    # module does not currently depend on.
    warnings.warn(
        f"plot_losses is not implemented; nothing was plotted for {log_dir!r}",
        stacklevel=2,
    )
|
|
|
|
|
def save_checkpoint(model, optimizer, epoch, path):
    """Serialize the current training state to ``path`` via ``torch.save``.

    The saved object is a dict with the keys ``'epoch'``,
    ``'model_state_dict'`` and ``'optimizer_state_dict'``, in that order.

    Args:
        model: Object whose ``state_dict()`` is stored.
        optimizer: Object whose ``state_dict()`` is stored.
        epoch: Value stored as-is under ``'epoch'``.
        path: Destination handed directly to ``torch.save``.

    Returns:
        None.
    """
    # Snapshot model first, then optimizer (same order as a dict literal).
    model_state = model.state_dict()
    optimizer_state = optimizer.state_dict()

    payload = {
        'epoch': epoch,
        'model_state_dict': model_state,
        'optimizer_state_dict': optimizer_state,
    }
    torch.save(payload, path)
|
|
|
|
|
def load_checkpoint(model, optimizer, path, map_location=None):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Expects the dict layout written by ``save_checkpoint``: keys ``'epoch'``,
    ``'model_state_dict'`` and ``'optimizer_state_dict'``.

    Args:
        model: Object that receives ``checkpoint['model_state_dict']`` via
            ``load_state_dict``.
        optimizer: Object that receives ``checkpoint['optimizer_state_dict']``
            via ``load_state_dict``, or ``None`` to skip restoring optimizer
            state (e.g. when loading for inference only).
        path: File path or file-like object passed to ``torch.load``.
        map_location: Forwarded to ``torch.load``. Use e.g. ``'cpu'`` to load
            a checkpoint saved from GPU tensors on a CPU-only machine.
            ``None`` keeps ``torch.load``'s default device placement.

    Returns:
        The value stored under ``'epoch'``.

    Raises:
        KeyError: If the checkpoint is missing one of the keys read here.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted
    # sources (newer torch versions also support ``weights_only=True``).
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']
|
|
|
|
|
def show_samples(samples):
    """Display a single channels-first image tensor (e.g. a sample grid).

    Args:
        samples: 3-D tensor laid out as ``(C, H, W)``. It may live on any
            device and may require grad; it is detached and copied to CPU
            before conversion to NumPy. When ``C == 1`` the channel axis is
            dropped and the image is drawn in grayscale. Otherwise it is
            transposed to ``(H, W, C)`` for ``plt.imshow``.

    NOTE(review): imshow clips float RGB data outside [0, 1]; if samples come
    from a generator with outputs in [-1, 1], rescale before calling — confirm
    with callers.
    """
    # Tensor.numpy() raises for CUDA tensors and for tensors requiring grad.
    image = samples.detach().cpu().numpy()

    plt.figure(figsize=(10, 10))
    if image.ndim == 3 and image.shape[0] == 1:
        # imshow rejects (H, W, 1) arrays; show the single channel as gray.
        plt.imshow(image[0], cmap='gray')
    else:
        plt.imshow(np.transpose(image, (1, 2, 0)))
    plt.axis('off')
    plt.show()