import inspect
import os
import traceback

import torch
import torchvision
from torchvision.utils import save_image, make_grid

from config import Config
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from sample import frequency_aware_sample


def _find_latest_log_dir(log_root):
    """Return the lexicographically last subdirectory name of *log_root*, or None."""
    # isdir (not exists): os.listdir would raise if log_root were a plain file.
    if not os.path.isdir(log_root):
        return None
    log_dirs = [
        name for name in os.listdir(log_root)
        if os.path.isdir(os.path.join(log_root, name))
    ]
    # Lexicographic order only equals "newest" if run directories are named so
    # they sort chronologically (e.g. timestamps) -- TODO confirm naming scheme.
    return max(log_dirs) if log_dirs else None


def _find_latest_checkpoint(log_path):
    """Return (epoch, filename) for the highest-epoch 'model_epoch_<N>*.pth' file, or None."""
    checkpoint_files = []
    for file_name in os.listdir(log_path):
        if not (file_name.startswith('model_epoch_') and file_name.endswith('.pth')):
            continue
        try:
            # 'model_epoch_10.pth' -> '10'; the prefix guarantees index 2 exists.
            epoch = int(file_name.split('_')[2].split('.')[0])
        except ValueError:
            # e.g. 'model_epoch_best.pth': no numeric epoch, skip rather than crash.
            print(f"Skipping checkpoint with non-numeric epoch: {file_name}")
            continue
        checkpoint_files.append((epoch, file_name))
    # Tuples compare by epoch first, then filename -- same as sorting and taking [-1].
    return max(checkpoint_files) if checkpoint_files else None


def _load_checkpoint(checkpoint_path, device):
    """Load *checkpoint_path* onto *device*, allowing pickled non-tensor objects."""
    # The checkpoint may contain a pickled Config object under 'config'. PyTorch
    # >= 2.6 defaults to weights_only=True, which refuses arbitrary classes, so
    # request full unpickling when the argument exists (PyTorch >= 1.13).
    # SECURITY: weights_only=False executes pickle -- only load trusted,
    # self-produced checkpoints.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        return torch.load(checkpoint_path, map_location=device, weights_only=False)
    return torch.load(checkpoint_path, map_location=device)


def _report_sample_stats(samples):
    """Print range/mean/std of *samples* and a heuristic verdict on whether they look like noise."""
    # Reduce once; each .item() forces a device sync.
    sample_min = samples.min().item()
    sample_max = samples.max().item()
    sample_mean = samples.mean().item()
    sample_std = samples.std().item()

    print("Sample statistics:")
    print(f" Range: [{sample_min:.3f}, {sample_max:.3f}]")
    print(f" Mean: {sample_mean:.3f}")
    print(f" Std: {sample_std:.3f}")

    # Heuristics: a near-zero spread suggests collapsed output; zero mean with a
    # large spread resembles the initial Gaussian noise (assumes samples are not
    # renormalized by frequency_aware_sample -- TODO confirm).
    if sample_std < 0.1:
        print("WARNING: Samples have very low variance - might be noise!")
    elif abs(sample_mean) < 0.01 and sample_std > 0.8:
        print("WARNING: Samples look like random noise!")
    else:
        print("Samples look reasonable!")


def test_latest_checkpoint(log_root='./logs', n_samples=8):
    """Test the latest checkpoint with frequency-aware sampling.

    Picks the lexicographically last run directory under *log_root*, loads its
    highest-epoch ``model_epoch_<N>.pth`` checkpoint, generates *n_samples*
    images with ``frequency_aware_sample`` and saves the returned grid to
    ``test_samples_epoch_<N>_fixed.png`` in the current directory.

    Args:
        log_root: Directory containing one subdirectory per training run.
        n_samples: Number of samples passed to ``frequency_aware_sample``.

    Returns:
        None. Progress and diagnostics are printed; sampling errors are printed
        with a traceback rather than raised.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Find latest log directory
    latest_log = _find_latest_log_dir(log_root)
    if latest_log is None:
        print("No log directories found!")
        return
    log_path = os.path.join(log_root, latest_log)
    print(f"Testing latest log directory: {log_path}")

    # Find latest checkpoint in that directory
    latest = _find_latest_checkpoint(log_path)
    if latest is None:
        print("No checkpoint files found!")
        return
    latest_epoch, latest_file = latest
    checkpoint_path = os.path.join(log_path, latest_file)
    print(f"Testing checkpoint: {latest_file} (epoch {latest_epoch})")

    checkpoint = _load_checkpoint(checkpoint_path, device)

    # Prefer the config saved alongside the weights so the architecture matches.
    config = checkpoint['config'] if 'config' in checkpoint else Config()
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)

    if 'model_state_dict' in checkpoint:
        # Full training checkpoint with metadata.
        model.load_state_dict(checkpoint['model_state_dict'])
        epoch = checkpoint.get('epoch', 'unknown')
        loss = checkpoint.get('loss', 'unknown')
        print(f"Loaded model from epoch {epoch}, loss: {loss}")
    else:
        # Bare state_dict saved directly.
        model.load_state_dict(checkpoint)
        print("Loaded model state dict")

    # Inference: disable training-mode behaviour (dropout, batch-norm updates).
    model.eval()

    # Generate samples using frequency-aware sampling
    print("\n=== Generating samples with frequency-aware approach ===")
    try:
        samples, grid = frequency_aware_sample(
            model, noise_scheduler, device, n_samples=n_samples
        )

        save_path = f"test_samples_epoch_{latest_epoch}_fixed.png"
        save_image(grid, save_path, normalize=False)
        print(f"Samples saved to: {save_path}")

        _report_sample_stats(samples)
    except Exception as e:
        # Top-level script boundary: report and keep the traceback visible.
        print(f"Error during sampling: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    test_latest_checkpoint()