# Grad-CDM / utils.py
# Uploaded by nazgut ("Upload 24 files", commit 8abfb97); original size 843 bytes.
import matplotlib.pyplot as plt
import numpy as np
import torch
def plot_losses(log_dir):
    """Placeholder for plotting training losses from TensorBoard logs.

    This is currently a no-op. ``log_dir`` is accepted for API compatibility
    but is not read, and the function returns ``None``. Open the logs with
    TensorBoard directly to inspect the loss curves.
    """
    return None
def save_checkpoint(model, optimizer, epoch, path):
    """Serialize training state to ``path`` using ``torch.save``.

    The saved object is a dict with the keys ``'epoch'``,
    ``'model_state_dict'`` and ``'optimizer_state_dict'``, in that order.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Value stored under ``'epoch'``.
        path: Destination passed straight to ``torch.save``.
    """
    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(payload, path)
def load_checkpoint(model, optimizer, path, map_location='cpu'):
    """Restore model (and optionally optimizer) state from a checkpoint.

    Expects the dict layout written by ``save_checkpoint``: the keys
    ``'epoch'``, ``'model_state_dict'`` and ``'optimizer_state_dict'``.

    Args:
        model: Module whose parameters are overwritten in place via
            ``load_state_dict``.
        optimizer: Optimizer whose state is restored in place, or ``None``
            to skip the optimizer state (e.g. when loading for inference).
        path: File path (or file-like object) passed to ``torch.load``.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so
            a checkpoint saved from a GPU also loads on a CPU-only machine;
            ``load_state_dict`` then copies the tensors onto the devices of
            the model's/optimizer's parameters. Pass ``None`` to restore
            tensors onto the devices they were saved from.

    Returns:
        The epoch value stored in the checkpoint.

    Raises:
        KeyError: If the checkpoint lacks one of the expected keys.
    """
    # Without map_location, CUDA tensors in the file are restored to CUDA,
    # which fails outright when no GPU is available.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']
def show_samples(samples):
    """Display one generated image with matplotlib.

    Args:
        samples: Image tensor in channel-first layout ``(C, H, W)``, or an
            already-2-D ``(H, W)`` grayscale image. A single-channel
            ``(1, H, W)`` input is shown as grayscale. The tensor may live on
            any device and may require grad; it is detached and copied to
            host memory before plotting. Values are passed to ``imshow``
            unchanged (no rescaling).
    """
    # .numpy() raises for CUDA tensors and for tensors that require grad,
    # so detach and move to CPU first.
    image = samples.detach().cpu().numpy()
    if image.ndim == 3:
        # matplotlib expects channel-last (H, W, C).
        image = np.transpose(image, (1, 2, 0))
        if image.shape[-1] == 1:
            # imshow rejects (H, W, 1); drop the singleton channel axis.
            image = image[..., 0]
    # Only apply a colormap to 2-D data; RGB(A) images ignore it anyway.
    cmap = 'gray' if image.ndim == 2 else None
    plt.figure(figsize=(10, 10))
    plt.imshow(image, cmap=cmap)
    plt.axis('off')
    plt.show()