import matplotlib.pyplot as plt
import numpy as np
import torch


def plot_losses(log_dir):
    """Plot training losses from TensorBoard logs in ``log_dir``.

    Placeholder: this currently does nothing and returns ``None``.
    In practice, you'd use TensorBoard directly.
    """
    pass


def save_checkpoint(model, optimizer, epoch, path):
    """Save model/optimizer state and the epoch number to ``path``.

    Args:
        model: Module whose ``state_dict()`` is stored under ``'model_state_dict'``.
        optimizer: Optimizer whose ``state_dict()`` is stored under
            ``'optimizer_state_dict'``.
        epoch: Value stored under ``'epoch'``.
        path: Destination passed to ``torch.save``.
    """
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)


def load_checkpoint(model, optimizer, path, map_location=None):
    """Restore ``model`` and ``optimizer`` from a checkpoint written by ``save_checkpoint``.

    Args:
        model: Module to load ``'model_state_dict'`` into (in place).
        optimizer: Optimizer to load ``'optimizer_state_dict'`` into (in place).
        path: Checkpoint file passed to ``torch.load``.
        map_location: Forwarded to ``torch.load``. For example, pass ``"cpu"`` to
            load a checkpoint saved on GPU onto a machine without CUDA.
            ``None`` keeps ``torch.load``'s default behaviour.

    Returns:
        The ``'epoch'`` value stored in the checkpoint.

    Raises:
        KeyError: If the checkpoint lacks one of the expected keys.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']


def _to_hwc_image(samples):
    """Convert a (C, H, W) or (H, W) tensor/array into an array ``imshow`` accepts."""
    if isinstance(samples, torch.Tensor):
        # .numpy() fails on tensors that require grad or live on a GPU.
        array = samples.detach().cpu().numpy()
    else:
        array = np.asarray(samples)
    if array.ndim == 2:
        return array
    if array.ndim != 3:
        raise ValueError(
            f"expected a (C, H, W) or (H, W) image, got shape {array.shape}"
        )
    image = np.transpose(array, (1, 2, 0))
    # imshow rejects (H, W, 1); single-channel images must be 2-D.
    if image.shape[-1] == 1:
        image = image[..., 0]
    return image


def show_samples(samples):
    """Display generated samples.

    Args:
        samples: Image as a channels-first ``(C, H, W)`` tensor/array, or a 2-D
            ``(H, W)`` one. Single-channel images are shown in grayscale.
    """
    image = _to_hwc_image(samples)
    plt.figure(figsize=(10, 10))
    # 2-D data would otherwise be colour-mapped with matplotlib's default cmap.
    plt.imshow(image, cmap="gray" if image.ndim == 2 else None)
    plt.axis('off')
    plt.show()