"""Smoke-test a trained diffusion checkpoint by sampling an image grid.

Loads a ``SmoothDiffusionUNet`` checkpoint, draws samples with
``simple_sample`` and writes the returned grid to a timestamped PNG in the
current working directory.
"""
import torch
import torchvision
from torchvision.utils import save_image, make_grid
import os
import argparse
from datetime import datetime

from config import Config
from model import SmoothDiffusionUNet
from noise_scheduler_simple import FrequencyAwareNoise
from sample_simple import simple_sample


def _torch_load(checkpoint_path, device):
    """Load a ``torch.save`` file onto ``device``, allowing non-tensor objects.

    PyTorch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which
    refuses to unpickle arbitrary classes such as a ``Config`` instance stored
    under the ``'config'`` key of a training checkpoint. We therefore opt out
    explicitly. Releases older than 1.13 do not accept the keyword at all,
    hence the ``TypeError`` fallback (their default already unpickles fully).

    SECURITY: full unpickling can execute arbitrary code embedded in the
    file -- only load checkpoints from sources you trust.
    """
    try:
        return torch.load(checkpoint_path, map_location=device, weights_only=False)
    except TypeError:
        # torch < 1.13: ``weights_only`` is not a recognised parameter.
        return torch.load(checkpoint_path, map_location=device)


def load_model(checkpoint_path, device):
    """Load model from checkpoint.

    Args:
        checkpoint_path: Path to a file written by ``torch.save``. Either a
            training checkpoint dict containing ``'model_state_dict'`` (and
            optionally ``'config'``, ``'epoch'``, ``'loss'``) or a bare model
            state dict.
        device: ``torch.device`` the checkpoint tensors and model are mapped to.

    Returns:
        Tuple ``(model, noise_scheduler, config)``. ``model`` lives on
        ``device`` and has been switched to eval mode.
    """
    print(f"Loading model from: {checkpoint_path}")

    checkpoint = _torch_load(checkpoint_path, device)

    # Prefer the config stored alongside the weights so the architecture
    # matches what was trained; otherwise use the project defaults.
    if 'config' in checkpoint:
        config = checkpoint['config']
    else:
        config = Config()  # Fallback to default config

    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)

    if 'model_state_dict' in checkpoint:
        # Full training checkpoint: weights plus bookkeeping metadata.
        model.load_state_dict(checkpoint['model_state_dict'])
        epoch = checkpoint.get('epoch', 'unknown')
        loss = checkpoint.get('loss', 'unknown')
        print(f"Loaded model from epoch {epoch}, loss: {loss}")
    else:
        # Handle simple state dict (final model)
        model.load_state_dict(checkpoint)
        print("Loaded model state dict")

    # Inference only: put any dropout / normalization layers into
    # evaluation behaviour before sampling.
    model.eval()
    return model, noise_scheduler, config


def test_checkpoint(checkpoint_path, device, n_samples=16):
    """Sample from a single checkpoint and save the resulting image grid.

    Args:
        checkpoint_path: Path passed through to :func:`load_model`.
        device: ``torch.device`` used for loading and sampling.
        n_samples: Number of samples requested from ``simple_sample``.

    Returns:
        The ``(samples, grid)`` pair returned by ``simple_sample``. The grid
        is also written to ``test_samples_simple_<YYYYmmdd_HHMMSS>.png`` in
        the current working directory.
    """
    model, noise_scheduler, config = load_model(checkpoint_path, device)

    # Timestamped name so repeated runs never overwrite earlier results.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = f"test_samples_simple_{timestamp}.png"

    print(f"Testing checkpoint with {n_samples} samples...")
    # Sampling does not need gradients; disabling autograd avoids keeping
    # the computation graph of every denoising step in memory.
    with torch.no_grad():
        samples, grid = simple_sample(model, noise_scheduler, device, n_samples=n_samples)

    # normalize=False: the grid is saved as returned by simple_sample
    # (assumed to already be in the [0, 1] range save_image expects -- TODO confirm).
    save_image(grid, save_path, normalize=False)
    print(f"Samples saved to: {save_path}")

    return samples, grid


def main():
    """Parse command-line arguments, pick a device and sample from the checkpoint."""
    parser = argparse.ArgumentParser(description='Test trained diffusion model (simple version)')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to checkpoint file')
    parser.add_argument('--n_samples', type=int, default=16, help='Number of samples to generate')
    parser.add_argument('--device', type=str, default='auto', help='Device to use (cuda/cpu/auto)')
    args = parser.parse_args()

    # Fail fast with a usage error instead of a deep traceback later on.
    if args.n_samples < 1:
        parser.error("--n_samples must be a positive integer")
    if not os.path.isfile(args.checkpoint):
        parser.error(f"checkpoint file not found: {args.checkpoint}")

    # Setup device
    if args.device == 'auto':
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
        if device.type == 'cuda' and not torch.cuda.is_available():
            parser.error("CUDA device requested but CUDA is not available")
    print(f"Using device: {device}")

    # Test the checkpoint
    print("=== Testing Checkpoint with Simple DDPM ===")
    test_checkpoint(args.checkpoint, device, args.n_samples)


if __name__ == "__main__":
    main()