# Grad-CDM / test_quality.py
import torch
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from config import Config
from torchvision.utils import save_image
import numpy as np
def test_model_quality():
"""Test if the model can actually denoise"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model
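    # NOTE: model_final.pth is assumed to hold a raw state_dict; a checkpoint saved
    # as a wrapper dict (e.g. {'model_state_dict': ...}) would need the matching key here.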
    checkpoint = torch.load('model_final.pth', map_location=device)
    config = Config()
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
    model.load_state_dict(checkpoint)
    model.eval()

    print("=== TESTING MODEL DENOISING ABILITY ===")
    with torch.no_grad():
        # Create a simple test pattern with clear shapes that should be easy to denoise
        x_clean = torch.zeros(1, 3, 64, 64, device=device)
        x_clean[0, 0, 20:44, 20:44] = 1.0   # square in the red channel
        x_clean[0, 1, 10:30, 40:60] = -1.0  # rectangle in the green channel
        x_clean[0, 2, 35:50, 10:25] = 0.5   # rectangle in the blue channel

        print(f"Created test pattern with range [{x_clean.min():.3f}, {x_clean.max():.3f}]")

        # Test at different noise levels
        test_timesteps = [50, 100, 200, 400]
        for t_val in test_timesteps:
            print(f"\n--- Testing at timestep {t_val} ---")
            t_tensor = torch.full((1,), t_val, device=device, dtype=torch.long)

            # Add noise like in training
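            # apply_noise is assumed to follow the standard DDPM forward process,
            #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
            # returning both the noised image and the sampled noise target eps.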
            x_noisy, noise_target = noise_scheduler.apply_noise(x_clean, t_tensor)

            # Get model prediction
            noise_pred = model(x_noisy, t_tensor)

            # Calculate accuracy of the noise prediction
            mse = torch.mean((noise_pred - noise_target) ** 2)
            mae = torch.mean(torch.abs(noise_pred - noise_target))

            print(f" Noisy image range: [{x_noisy.min():.3f}, {x_noisy.max():.3f}]")
            print(f" Target noise range: [{noise_target.min():.3f}, {noise_target.max():.3f}]")
            print(f" Predicted noise range: [{noise_pred.min():.3f}, {noise_pred.max():.3f}]")
            print(f" MSE: {mse.item():.6f}")
            print(f" MAE: {mae.item():.6f}")

            # Try to reconstruct clean image
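            # One-step estimate of the clean image from the predicted noise:
            #   x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_bar_t)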
            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
            x_reconstructed = (x_noisy - np.sqrt(1 - alpha_bar_t) * noise_pred) / np.sqrt(alpha_bar_t)
            x_reconstructed = torch.clamp(x_reconstructed, -1, 1)
            reconstruction_error = torch.mean((x_reconstructed - x_clean) ** 2)
            print(f" Reconstruction MSE: {reconstruction_error.item():.6f}")

            if mse.item() > 1.0:
                print(" ❌ High prediction error - model didn't learn well")
            elif reconstruction_error.item() > 0.5:
                print(" ⚠️ Poor reconstruction - model learned noise but not images")
            else:
                print(" ✅ Good denoising performance")
        # Save test images
        print("\n=== SAVING TEST IMAGES ===")

        # Save original test pattern (mapped from [-1, 1] to [0, 1] for saving)
        x_clean_display = (x_clean + 1) / 2
        save_image(x_clean_display, "test_pattern_clean.png")
        print("Clean test pattern saved to test_pattern_clean.png")

        # Save a heavily noised version
        t_heavy = torch.full((1,), 400, device=device, dtype=torch.long)
        x_heavy_noisy, _ = noise_scheduler.apply_noise(x_clean, t_heavy)
        x_heavy_display = torch.clamp((x_heavy_noisy + 1) / 2, 0, 1)
        save_image(x_heavy_display, "test_pattern_noisy.png")
        print("Noisy test pattern saved to test_pattern_noisy.png")
        # Try to denoise it with a single-step reconstruction at t=400
        noise_pred = model(x_heavy_noisy, t_heavy)
        alpha_bar_t = noise_scheduler.alpha_bars[400].item()
        x_denoised = (x_heavy_noisy - np.sqrt(1 - alpha_bar_t) * noise_pred) / np.sqrt(alpha_bar_t)
        x_denoised = torch.clamp(x_denoised, -1, 1)
        x_denoised_display = (x_denoised + 1) / 2
        save_image(x_denoised_display, "test_pattern_denoised.png")
        print("Denoised test pattern saved to test_pattern_denoised.png")

        final_error = torch.mean((x_denoised - x_clean) ** 2)
        print(f"Final reconstruction error: {final_error.item():.6f}")

        if final_error.item() < 0.1:
            print("✅ Model can denoise simple patterns!")
        else:
            print("❌ Model cannot denoise - training was unsuccessful")
if __name__ == "__main__":
    test_model_quality()