"""Diagnostic script for a trained SmoothDiffusionUNet checkpoint.

Loads ``model_final.pth``, measures how well the model reconstructs noised
training images, and writes sample grids produced by two heuristic
(partial-denoising) sampling procedures for visual comparison.
"""
import numpy as np
import torch
from torchvision.utils import make_grid, save_image

from config import Config
from dataloader import get_dataloaders
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise

# Maximum number of images used by each diagnostic (and the grid size).
NUM_IMAGES = 4


def _timestep_batch(t_val, batch_size, device):
    """Return a 1-D long tensor of length ``batch_size`` filled with ``t_val``."""
    return torch.full((batch_size,), t_val, device=device, dtype=torch.long)


def _estimate_x0(x_t, noise_pred, alpha_bar, noise_scale=1.0):
    """Estimate the clean image from ``x_t`` and a noise prediction.

    Computes ``(x_t - sqrt(1 - alpha_bar) * noise_pred * noise_scale) / sqrt(alpha_bar)``
    and clamps the result to [-1, 1]. With ``noise_scale < 1`` only part of the
    predicted noise is removed.

    NOTE(review): this inverts the standard DDPM forward process
    ``x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps``; confirm that
    ``FrequencyAwareNoise.apply_noise`` uses that form.
    """
    x0 = (x_t - np.sqrt(1 - alpha_bar) * noise_pred * noise_scale) / np.sqrt(alpha_bar)
    return torch.clamp(x0, -1, 1)


def _save_grid(images, path, pad_value=0.0):
    """Save ``images`` as a 2-column grid to ``path``.

    Pixels are mapped with ``(x + 1) / 2`` and clamped to [0, 1], i.e. the
    images are assumed to be in [-1, 1].
    """
    display = torch.clamp((images + 1) / 2, 0, 1)
    grid = make_grid(display, nrow=2, normalize=False, pad_value=pad_value)
    save_image(grid, path)


def _load_model(device):
    """Build the model from ``Config`` and load ``model_final.pth`` into it.

    Returns:
        ``(config, model, noise_scheduler)`` with the model in eval mode.
    """
    checkpoint = torch.load('model_final.pth', map_location=device)
    config = Config()
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
    model.load_state_dict(checkpoint)
    model.eval()
    return config, model, noise_scheduler


def _test_reconstruction(model, noise_scheduler, real_images, device, timesteps=(50, 200, 400)):
    """Noise real images at each timestep, reconstruct them in one step, report MSE.

    Writes ``reconstruction_t{t}.png`` for every timestep.
    """
    print("\n=== TESTING MODEL ON REAL DATA ===")
    batch_size = real_images.shape[0]
    for t_val in timesteps:
        t_tensor = _timestep_batch(t_val, batch_size, device)
        # The second return value (the applied noise) is not needed here.
        x_noisy, _ = noise_scheduler.apply_noise(real_images, t_tensor)
        noise_pred = model(x_noisy, t_tensor)
        alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
        x_reconstructed = _estimate_x0(x_noisy, noise_pred, alpha_bar_t)

        mse = torch.mean((x_reconstructed - real_images) ** 2).item()
        print(f"\nTimestep {t_val}:")
        print(f" Reconstruction error: {mse:.6f}")

        _save_grid(x_reconstructed, f"reconstruction_t{t_val}.png")
        print(f" Reconstruction saved to reconstruction_t{t_val}.png")


def _interpolation_sampling(model, noise_scheduler, real_images, device):
    """Blend the first two real images and repeatedly partially denoise the blends.

    Writes ``interpolation_sampling.png``. Skipped when fewer than two real
    images are available.
    """
    print("\n=== TRYING INTERPOLATION SAMPLING ===")
    if real_images.shape[0] < 2:
        print("Skipping interpolation sampling: need at least 2 real images")
        return

    x1 = real_images[0:1]  # first real image
    x2 = real_images[1:2]  # second real image
    # (4, 1, 1, 1) weights broadcast against (1, C, H, W) images -> 4 blends,
    # going from pure x2 (alpha=0) to pure x1 (alpha=1).
    alphas = torch.linspace(0, 1, 4, device=device).view(-1, 1, 1, 1)
    x = alphas * x1 + (1 - alphas) * x2

    print("Starting from real image interpolation...")
    print(f"Interpolation range: [{x.min():.3f}, {x.max():.3f}]")

    # NOTE(review): heuristic, not a DDPM/DDIM update. Each step replaces x
    # with a partial x0 estimate (30% of the predicted noise removed) and feeds
    # it back at a lower timestep without re-adding noise.
    for t_val in [100, 80, 60, 40, 25, 15, 8, 3, 1]:
        t_tensor = _timestep_batch(t_val, x.shape[0], device)
        predicted_noise = model(x, t_tensor)
        alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
        x = _estimate_x0(x, predicted_noise, alpha_bar_t, noise_scale=0.3)

    print(f"Interpolation result range: [{x.min():.3f}, {x.max():.3f}]")
    _save_grid(x, "interpolation_sampling.png")
    print("Interpolation sampling saved to interpolation_sampling.png")


def _minimal_noise_sampling(model, noise_scheduler, sample_shape, device, batch_size=NUM_IMAGES):
    """Start from low-amplitude Gaussian noise and apply a few partial denoising steps.

    ``sample_shape`` is the per-image shape (taken from the training data so it
    matches the model's expected input). Writes ``minimal_noise_sampling.png``.
    """
    print("\n=== TRYING MINIMAL NOISE SAMPLING ===")
    x = torch.randn(batch_size, *sample_shape, device=device) * 0.1  # very light noise

    # Same heuristic as interpolation sampling, removing 50% of predicted noise.
    for t_val in [50, 30, 15, 5, 1]:
        t_tensor = _timestep_batch(t_val, batch_size, device)
        predicted_noise = model(x, t_tensor)
        alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
        x = _estimate_x0(x, predicted_noise, alpha_bar_t, noise_scale=0.5)

    print(f"Minimal noise result range: [{x.min():.3f}, {x.max():.3f}]")
    print(f"Minimal noise result std: {x.std():.3f}")
    _save_grid(x, "minimal_noise_sampling.png")
    print("Minimal noise sampling saved to minimal_noise_sampling.png")


def diagnose_and_fix():
    """Final diagnosis and alternative sampling approach.

    Loads the trained model, saves a grid of real training images, then runs
    the reconstruction test and two heuristic sampling procedures (all under
    ``torch.no_grad()``), writing PNG files to the working directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config, model, noise_scheduler = _load_model(device)

    print("=== FINAL DIAGNOSIS ===")

    # Real training data used as the reference for every diagnostic.
    # Assumes the loader yields (images, labels) pairs.
    train_loader, _ = get_dataloaders(config)
    real_batch, _ = next(iter(train_loader))
    real_images = real_batch[:NUM_IMAGES].to(device)

    print(f"Real training data range: [{real_images.min():.3f}, {real_images.max():.3f}]")
    print(f"Real training data mean: {real_images.mean():.3f}, std: {real_images.std():.3f}")

    _save_grid(real_images, "real_training_images.png", pad_value=1.0)
    print("Real training images saved to real_training_images.png")

    with torch.no_grad():
        _test_reconstruction(model, noise_scheduler, real_images, device)
        _interpolation_sampling(model, noise_scheduler, real_images, device)
        _minimal_noise_sampling(model, noise_scheduler, tuple(real_images.shape[1:]), device)

    print("\n=== SUMMARY ===")
    print("Generated files:")
    print("- real_training_images.png (what we want to achieve)")
    print("- reconstruction_t*.png (model's denoising ability)")
    print("- interpolation_sampling.png (interpolation approach)")
    print("- minimal_noise_sampling.png (light noise approach)")


if __name__ == "__main__":
    diagnose_and_fix()