import torch
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from config import Config
from torchvision.utils import save_image
import numpy as np


def _make_test_pattern(device):
    """Build a 1x3x64x64 image with one constant-valued rectangle per channel."""
    pattern = torch.zeros(1, 3, 64, 64, device=device)
    pattern[0, 0, 20:44, 20:44] = 1.0   # Red square
    pattern[0, 1, 10:30, 40:60] = -1.0  # Green rectangle
    pattern[0, 2, 35:50, 10:25] = 0.5   # Blue rectangle
    return pattern


def _predict_x0(x_noisy, noise_pred, alpha_bar_t):
    """Estimate the clean image from a noisy sample and the predicted noise.

    Computes ``(x_noisy - sqrt(1 - alpha_bar_t) * noise_pred) / sqrt(alpha_bar_t)``
    and clamps the result to [-1, 1].

    NOTE(review): this inversion assumes the scheduler's forward process is
    ``x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * noise`` -- confirm
    against ``FrequencyAwareNoise.apply_noise``.
    """
    # Convert to plain Python floats so the tensor arithmetic below never
    # mixes NumPy scalar types with torch tensors.
    signal_scale = float(np.sqrt(alpha_bar_t))
    noise_scale = float(np.sqrt(1 - alpha_bar_t))
    x0_estimate = (x_noisy - noise_scale * noise_pred) / signal_scale
    return torch.clamp(x0_estimate, -1, 1)


def test_model_quality(
    checkpoint_path='model_final.pth',
    test_timesteps=(50, 100, 200, 400),
    heavy_timestep=400,
):
    """Test if the model can actually denoise.

    Loads a checkpoint, noises a synthetic test pattern at several timesteps,
    prints noise-prediction and reconstruction errors for each, then saves
    clean / noisy / denoised PNGs for the heavy-noise case into the current
    working directory.

    Args:
        checkpoint_path: Path of the state dict passed to ``torch.load``.
        test_timesteps: Timesteps at which prediction error is reported.
        heavy_timestep: Timestep used for the saved noisy/denoised images.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model.
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source (consider weights_only=True where supported).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = Config()
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
    model.load_state_dict(checkpoint)
    model.eval()

    print("=== TESTING MODEL DENOISING ABILITY ===")

    with torch.no_grad():
        # Clear patterns that should be easy to denoise.
        x_clean = _make_test_pattern(device)
        print(f"Created test pattern with range [{x_clean.min():.3f}, {x_clean.max():.3f}]")

        # Test at different noise levels.
        for t_val in test_timesteps:
            print(f"\n--- Testing at timestep {t_val} ---")
            t_tensor = torch.full((1,), t_val, device=device, dtype=torch.long)

            # Add noise like in training, then ask the model for the noise.
            x_noisy, noise_target = noise_scheduler.apply_noise(x_clean, t_tensor)
            noise_pred = model(x_noisy, t_tensor)

            # Noise-prediction accuracy.
            residual = noise_pred - noise_target
            mse = torch.mean(residual ** 2)
            mae = torch.mean(torch.abs(residual))

            print(f" Noisy image range: [{x_noisy.min():.3f}, {x_noisy.max():.3f}]")
            print(f" Target noise range: [{noise_target.min():.3f}, {noise_target.max():.3f}]")
            print(f" Predicted noise range: [{noise_pred.min():.3f}, {noise_pred.max():.3f}]")
            print(f" MSE: {mse.item():.6f}")
            print(f" MAE: {mae.item():.6f}")

            # Try to reconstruct the clean image from the predicted noise.
            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
            x_reconstructed = _predict_x0(x_noisy, noise_pred, alpha_bar_t)
            reconstruction_error = torch.mean((x_reconstructed - x_clean) ** 2)
            print(f" Reconstruction MSE: {reconstruction_error.item():.6f}")

            if mse.item() > 1.0:
                print(" ❌ High prediction error - model didn't learn well")
            elif reconstruction_error.item() > 0.5:
                print(" ⚠️ Poor reconstruction - model learned noise but not images")
            else:
                print(" ✅ Good denoising performance")

        # Save test images.
        print("\n=== SAVING TEST IMAGES ===")

        # Original test pattern, mapped from [-1, 1] to [0, 1] for display.
        x_clean_display = (x_clean + 1) / 2
        save_image(x_clean_display, "test_pattern_clean.png")
        print("Clean test pattern saved to test_pattern_clean.png")

        # Heavily noised version (fresh noise sample); clamped because the
        # noisy values can fall outside the displayable range.
        t_heavy = torch.full((1,), heavy_timestep, device=device, dtype=torch.long)
        x_heavy_noisy, _ = noise_scheduler.apply_noise(x_clean, t_heavy)
        x_heavy_display = torch.clamp((x_heavy_noisy + 1) / 2, 0, 1)
        save_image(x_heavy_display, "test_pattern_noisy.png")
        print("Noisy test pattern saved to test_pattern_noisy.png")

        # Try to denoise it in a single step.
        noise_pred = model(x_heavy_noisy, t_heavy)
        alpha_bar_t = noise_scheduler.alpha_bars[heavy_timestep].item()
        x_denoised = _predict_x0(x_heavy_noisy, noise_pred, alpha_bar_t)
        x_denoised_display = (x_denoised + 1) / 2
        save_image(x_denoised_display, "test_pattern_denoised.png")
        print("Denoised test pattern saved to test_pattern_denoised.png")

        final_error = torch.mean((x_denoised - x_clean) ** 2)
        print(f"Final reconstruction error: {final_error.item():.6f}")

        if final_error.item() < 0.1:
            print("✅ Model can denoise simple patterns!")
        else:
            print("❌ Model cannot denoise - training was unsuccessful")


if __name__ == "__main__":
    test_model_quality()