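# Grad-CDM / test_simple.py
# Uploaded by nazgut ("Upload 24 files"), commit 8abfb97 (verified).
# Smoke test: load a trained checkpoint and save a grid of samples produced by
# the simple DDPM sampler.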
import torch
import torchvision
from torchvision.utils import save_image, make_grid
import os
import argparse
from datetime import datetime
from config import Config
from model import SmoothDiffusionUNet
from noise_scheduler_simple import FrequencyAwareNoise
from sample_simple import simple_sample
def load_model(checkpoint_path, device):
    """Load a SmoothDiffusionUNet and its noise scheduler from a checkpoint.

    Two checkpoint layouts are accepted:

    * a dict containing ``'model_state_dict'`` (optionally also ``'config'``,
      ``'epoch'`` and ``'loss'``);
    * a bare ``state_dict`` (e.g. a final exported model).

    If the checkpoint carries no ``'config'`` entry, a default ``Config()`` is
    used, so the architecture must match the default configuration.

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: Device used both as ``torch.load``'s ``map_location`` and as
            the target of ``model.to(...)``.

    Returns:
        ``(model, noise_scheduler, config)``. ``model`` lives on ``device``
        and has been switched to eval mode.
    """
    print(f"Loading model from: {checkpoint_path}")

    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle
    # arbitrary Python objects such as the config stored under 'config'
    # (presumably a pickled Config instance -- confirm against the training
    # script). Request a full unpickle explicitly.
    # SECURITY: a full unpickle can execute code embedded in the file; only
    # load checkpoints from trusted sources.
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    except TypeError:
        # torch < 1.13 has no `weights_only` keyword (it always unpickles fully).
        checkpoint = torch.load(checkpoint_path, map_location=device)

    # Initialize model and noise scheduler from the saved config when present.
    if 'config' in checkpoint:
        config = checkpoint['config']
    else:
        print("No config found in checkpoint; using default Config()")
        config = Config()  # Fallback to default config

    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)

    # Load model state
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        epoch = checkpoint.get('epoch', 'unknown')
        loss = checkpoint.get('loss', 'unknown')
        print(f"Loaded model from epoch {epoch}, loss: {loss}")
    else:
        # Handle simple state dict (final model)
        model.load_state_dict(checkpoint)
        print("Loaded model state dict")

    # This function prepares the model for sampling: switch to inference
    # mode so any dropout / batch-norm layers behave deterministically.
    model.eval()
    return model, noise_scheduler, config
def test_checkpoint(checkpoint_path, device, n_samples=16, save_path=None):
    """Sample from a checkpoint with the simple DDPM sampler and save a grid.

    Args:
        checkpoint_path: Path to the checkpoint, forwarded to ``load_model``.
        device: Device the model is loaded onto and sampled on.
        n_samples: Number of samples requested from ``simple_sample``.
        save_path: Output image path. When ``None`` (the default), a
            timestamped ``test_samples_simple_<YYYYmmdd_HHMMSS>.png`` is
            written to the current working directory.

    Returns:
        The ``(samples, grid)`` pair returned by ``simple_sample``.
    """
    # The config returned by load_model is not needed here.
    model, noise_scheduler, _config = load_model(checkpoint_path, device)

    if save_path is None:
        # Timestamped default so repeated runs do not overwrite each other.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        save_path = f"test_samples_simple_{timestamp}.png"

    print(f"Testing checkpoint with {n_samples} samples...")
    samples, grid = simple_sample(model, noise_scheduler, device, n_samples=n_samples)

    # normalize=False writes the grid as-is; this assumes simple_sample already
    # returns values in save_image's expected [0, 1] range -- TODO confirm.
    save_image(grid, save_path, normalize=False)
    print(f"Samples saved to: {save_path}")
    return samples, grid


# This module is named test_*.py and this helper is named test_*, so pytest's
# default discovery would collect it and fail on its non-fixture parameters
# (checkpoint_path, device). Opt it out of collection explicitly.
test_checkpoint.__test__ = False
def main(argv=None):
    """Command-line entry point: parse arguments and test one checkpoint.

    Args:
        argv: Optional list of arguments (without the program name). ``None``
            makes argparse read ``sys.argv[1:]``, matching the previous
            behaviour; passing a list makes the CLI callable from code.

    Exits with status 2 (argparse usage error) if ``--n_samples`` is not
    positive or the checkpoint file does not exist.
    """
    parser = argparse.ArgumentParser(description='Test trained diffusion model (simple version)')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to checkpoint file')
    parser.add_argument('--n_samples', type=int, default=16, help='Number of samples to generate')
    parser.add_argument('--device', type=str, default='auto', help='Device to use (cuda/cpu/auto)')
    args = parser.parse_args(argv)

    # Fail fast with a usage message instead of a traceback from deep inside
    # torch.load or the sampler.
    if args.n_samples < 1:
        parser.error(f"--n_samples must be a positive integer, got {args.n_samples}")
    if not os.path.isfile(args.checkpoint):
        parser.error(f"checkpoint file not found: {args.checkpoint}")

    # Setup device: 'auto' prefers CUDA when available; anything else
    # (e.g. 'cpu', 'cuda:1') is handed to torch.device verbatim.
    if args.device == 'auto':
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}")

    # Test the checkpoint
    print("=== Testing Checkpoint with Simple DDPM ===")
    test_checkpoint(args.checkpoint, device, args.n_samples)


if __name__ == "__main__":
    main()