|
import torch |
|
from model import SmoothDiffusionUNet |
|
from noise_scheduler import FrequencyAwareNoise |
|
from config import Config |
|
from torchvision.utils import save_image |
|
import numpy as np |
|
|
|
def _reconstruct_x0(x_noisy, noise_pred, alpha_bar_t):
    """Estimate the clean image from a noisy sample and the predicted noise.

    Solves ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps`` for
    ``x_0``, substituting the model's prediction for ``eps``, and clamps the
    result to ``[-1, 1]``.

    NOTE(review): this assumes ``FrequencyAwareNoise.apply_noise`` uses that
    standard closed-form forward step with its ``alpha_bars`` table — confirm
    against the scheduler implementation.

    Args:
        x_noisy: Noisy image tensor at timestep t.
        noise_pred: Predicted noise, same shape as ``x_noisy``.
        alpha_bar_t: Cumulative alpha product at timestep t (Python float).

    Returns:
        The clamped estimate of the clean image, same shape as ``x_noisy``.
    """
    # Plain-float square roots keep NumPy scalars out of tensor arithmetic.
    sqrt_alpha_bar = alpha_bar_t ** 0.5
    sqrt_one_minus_alpha_bar = (1.0 - alpha_bar_t) ** 0.5
    x0 = (x_noisy - sqrt_one_minus_alpha_bar * noise_pred) / sqrt_alpha_bar
    return torch.clamp(x0, -1, 1)


def test_model_quality(checkpoint_path="model_final.pth"):
    """Test if the model can actually denoise a simple synthetic pattern.

    Loads a saved ``state_dict`` into ``SmoothDiffusionUNet``, then builds a
    1x3x64x64 test pattern in ``[-1, 1]`` and noises it at several timesteps
    with ``FrequencyAwareNoise``. For each timestep it prints:

    * the MSE/MAE between predicted and true noise;
    * the MSE of a one-step reconstruction of the clean pattern;
    * a heuristic verdict. Noise MSE above 1.0 is a failed fit. Otherwise,
      reconstruction MSE above 0.5 is a poor reconstruction.

    Finally, it saves clean, noisy and denoised PNGs at the heaviest timestep
    tested. It prints an overall pass verdict when that reconstruction's MSE is
    below 0.1.

    NOTE(review): the ``test_`` prefix means pytest may collect this function
    if the file is on a test path; it needs the checkpoint file on disk.

    Args:
        checkpoint_path: Path of the saved ``state_dict`` to evaluate. Defaults
            to ``"model_final.pth"``.

    Returns:
        None. Results are printed; images are written to relative paths
        (i.e. the current working directory).
    """
    # Status markers are written as escapes (check mark, warning sign, cross
    # mark) so they survive a non-UTF-8 round-trip of this source file.
    ok_mark = "\u2705"
    warn_mark = "\u26a0\ufe0f"
    fail_mark = "\u274c"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = Config()
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
    model.load_state_dict(checkpoint)
    model.eval()

    print("=== TESTING MODEL DENOISING ABILITY ===")

    test_timesteps = [50, 100, 200, 400]
    # The saved images use the heaviest noise level exercised in the loop.
    heavy_t = max(test_timesteps)

    with torch.no_grad():
        # Synthetic pattern: one flat rectangle per channel on a zero
        # background, with intensities spanning the [-1, 1] range.
        x_clean = torch.zeros(1, 3, 64, 64, device=device)
        x_clean[0, 0, 20:44, 20:44] = 1.0
        x_clean[0, 1, 10:30, 40:60] = -1.0
        x_clean[0, 2, 35:50, 10:25] = 0.5

        print(f"Created test pattern with range [{x_clean.min():.3f}, {x_clean.max():.3f}]")

        for t_val in test_timesteps:
            print(f"\n--- Testing at timestep {t_val} ---")

            t_tensor = torch.full((1,), t_val, device=device, dtype=torch.long)

            # apply_noise yields the noisy sample and the noise the model is
            # trained to predict for it.
            x_noisy, noise_target = noise_scheduler.apply_noise(x_clean, t_tensor)
            noise_pred = model(x_noisy, t_tensor)

            error = noise_pred - noise_target
            mse = torch.mean(error ** 2)
            mae = torch.mean(torch.abs(error))

            print(f" Noisy image range: [{x_noisy.min():.3f}, {x_noisy.max():.3f}]")
            print(f" Target noise range: [{noise_target.min():.3f}, {noise_target.max():.3f}]")
            print(f" Predicted noise range: [{noise_pred.min():.3f}, {noise_pred.max():.3f}]")
            print(f" MSE: {mse.item():.6f}")
            print(f" MAE: {mae.item():.6f}")

            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
            x_reconstructed = _reconstruct_x0(x_noisy, noise_pred, alpha_bar_t)

            reconstruction_error = torch.mean((x_reconstructed - x_clean) ** 2)
            print(f" Reconstruction MSE: {reconstruction_error.item():.6f}")

            if mse.item() > 1.0:
                print(f" {fail_mark} High prediction error - model didn't learn well")
            elif reconstruction_error.item() > 0.5:
                print(f" {warn_mark} Poor reconstruction - model learned noise but not images")
            else:
                print(f" {ok_mark} Good denoising performance")

        print("\n=== SAVING TEST IMAGES ===")

        # Map the [-1, 1] pattern to [0, 1] for saving.
        x_clean_display = (x_clean + 1) / 2
        save_image(x_clean_display, "test_pattern_clean.png")
        print("Clean test pattern saved to test_pattern_clean.png")

        # A fresh noise sample at the heaviest timestep (independent of the
        # sample drawn for that timestep inside the loop).
        t_heavy = torch.full((1,), heavy_t, device=device, dtype=torch.long)
        x_heavy_noisy, _ = noise_scheduler.apply_noise(x_clean, t_heavy)
        # Noisy values can leave [-1, 1], so clamp after rescaling.
        x_heavy_display = torch.clamp((x_heavy_noisy + 1) / 2, 0, 1)
        save_image(x_heavy_display, "test_pattern_noisy.png")
        print("Noisy test pattern saved to test_pattern_noisy.png")

        noise_pred = model(x_heavy_noisy, t_heavy)
        alpha_bar_t = noise_scheduler.alpha_bars[heavy_t].item()
        x_denoised = _reconstruct_x0(x_heavy_noisy, noise_pred, alpha_bar_t)
        x_denoised_display = (x_denoised + 1) / 2
        save_image(x_denoised_display, "test_pattern_denoised.png")
        print("Denoised test pattern saved to test_pattern_denoised.png")

        final_error = torch.mean((x_denoised - x_clean) ** 2)
        print(f"Final reconstruction error: {final_error.item():.6f}")

        if final_error.item() < 0.1:
            print(f"{ok_mark} Model can denoise simple patterns!")
        else:
            print(f"{fail_mark} Model cannot denoise - training was unsuccessful")
|
|
|
if __name__ == "__main__":
    # Only run the denoising check when executed directly, not on import.
    test_model_quality()
|
|