|
|
|
""" |
|
Simple Metrics Evaluation for Frequency-Aware Super-Denoiser |
|
============================================================ |
|
Calculates PSNR, SSIM, and MSE metrics using existing sampling methods |
|
""" |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import numpy as np |
|
from PIL import Image |
|
import os |
|
from skimage.metrics import structural_similarity as ssim |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
from model import SmoothDiffusionUNet |
|
from noise_scheduler import FrequencyAwareNoise |
|
from config import Config |
|
from dataloader import get_dataloaders |
|
from sample import frequency_aware_sample |
|
|
|
def calculate_psnr(img1, img2, max_val=2.0):
    """Calculate the peak signal-to-noise ratio (PSNR) between two images.

    Args:
        img1: Floating-point tensor holding the first image.
        img2: Floating-point tensor holding the second image; it is compared
            against ``img1`` with ``F.mse_loss``.
        max_val: Dynamic range of the pixel values. The default of 2.0
            presumably matches images scaled to [-1, 1] -- confirm against
            the data pipeline.

    Returns:
        A 0-dim tensor with the PSNR in dB, computed as
        ``20 * log10(max_val / sqrt(mse))``. The value is ``inf`` when the
        two images are identical.
    """
    mse = F.mse_loss(img1, img2)
    if mse == 0:
        # Always hand back a tensor: callers use ``.item()`` on the result,
        # which a plain Python float does not provide.
        return torch.full_like(mse, float('inf'))
    return 20 * torch.log10(torch.tensor(max_val) / torch.sqrt(mse))
|
|
|
def calculate_ssim(img1, img2):
    """Compute the structural similarity (SSIM) of two image tensors.

    Each tensor is detached, copied to the CPU and converted to NumPy, its
    axes are permuted with ``(1, 2, 0)`` so the leading axis becomes the
    last one, and its values are mapped through ``(x + 1) / 2`` and clipped
    to [0, 1] before scoring.

    Args:
        img1: 3-D tensor for the first image (presumably channel-first with
            values in [-1, 1] -- confirm against the data pipeline).
        img2: 3-D tensor for the second image, same layout as ``img1``.

    Returns:
        The value returned by scikit-image's ``structural_similarity`` with
        the last axis treated as channels and ``data_range=1.0``.
    """

    def _as_unit_range_image(tensor):
        # Move the leading axis to the end, then rescale and clip to [0, 1].
        array = np.transpose(tensor.detach().cpu().numpy(), (1, 2, 0))
        return np.clip((array + 1) / 2, 0, 1)

    first = _as_unit_range_image(img1)
    second = _as_unit_range_image(img2)

    # NOTE(review): both ``multichannel`` and ``channel_axis`` are passed,
    # presumably to cover older and newer scikit-image releases -- confirm
    # against the pinned version before dropping either keyword.
    return ssim(first, second, multichannel=True, channel_axis=2, data_range=1.0)
|
|
|
def add_noise(image, noise_level=0.2):
    """Return a copy of ``image`` corrupted by Gaussian noise.

    Args:
        image: Tensor to perturb; it is not modified in place.
        noise_level: Factor applied to standard-normal noise drawn with the
            same shape, dtype and device as ``image``.

    Returns:
        A new tensor equal to ``image`` plus the scaled noise, with every
        value clamped to the interval [-1, 1].
    """
    perturbation = noise_level * torch.randn_like(image)
    return (image + perturbation).clamp(-1, 1)
|
|
|
def evaluate_model(checkpoint_path='model_final.pth', max_batches=20, samples_per_batch=4):
    """Evaluate the trained denoiser on test images and print MSE/PSNR/SSIM.

    Two experiments are run on the first ``samples_per_batch`` images of up
    to ``max_batches`` test batches:

    * Reconstruction: images are noised with ``scheduler.apply_noise`` at
      timestep 50, the model predicts the noise once, and the clean image is
      estimated as ``(x_t - sqrt(1 - alpha_bar) * eps) / sqrt(alpha_bar)``.
    * Enhancement: images are noised at timestep 150 and then refined by
      repeatedly subtracting ``0.1 * predicted_noise`` over a fixed ladder
      of timesteps, clamping to [-1, 1] after every step.

    Args:
        checkpoint_path: Path of the state-dict checkpoint to load.
        max_batches: Maximum number of test batches to evaluate.
        samples_per_batch: Maximum number of images taken from each batch.

    Returns:
        A dict with the mean metrics under the keys ``recon_mse``,
        ``recon_psnr``, ``recon_ssim``, ``enh_mse``, ``enh_psnr`` and
        ``enh_ssim``; or ``None`` when the checkpoint is missing, the test
        data cannot be loaded, or no image was evaluated.
    """
    recon_timestep = 50
    heavy_timestep = 150
    refinement_timesteps = [150, 100, 50, 25, 10, 5, 1]
    refinement_step = 0.1

    print("FREQUENCY-AWARE SUPER-DENOISER METRICS EVALUATION")
    print("=" * 60)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config = Config()

    model = SmoothDiffusionUNet(config).to(device)
    if not os.path.exists(checkpoint_path):
        print("❌ No trained model found! Please run training first.")
        return None
    # NOTE(review): weights_only=False unpickles arbitrary objects. Only load
    # checkpoints from a trusted source, or switch to weights_only=True once
    # the checkpoint is confirmed to be a plain state dict.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint)
    print("✅ Model loaded successfully")

    model.eval()
    scheduler = FrequencyAwareNoise(config)

    try:
        _, test_loader = get_dataloaders(config)
        print(f"✅ Test data loaded: {len(test_loader)} batches")
    except Exception as exc:
        # Report the actual cause; a bare ``except`` would also swallow
        # KeyboardInterrupt and hide what went wrong.
        print(f"❌ Could not load test data: {exc}")
        return None

    metrics = {
        'reconstruction_mse': [],
        'reconstruction_psnr': [],
        'reconstruction_ssim': [],
        'enhancement_mse': [],
        'enhancement_psnr': [],
        'enhancement_ssim': [],
    }

    # The reconstruction timestep is fixed, so its scaling factors are
    # computed once. float() keeps them plain Python scalars instead of
    # mixing NumPy scalars with torch tensors.
    alpha_bar = scheduler.alpha_bars[recon_timestep].item()
    sqrt_alpha_bar = float(np.sqrt(alpha_bar))
    sqrt_one_minus_alpha_bar = float(np.sqrt(1 - alpha_bar))

    print("\nEvaluating reconstruction quality...")

    with torch.no_grad():
        # zip() with range() stops without fetching an extra batch from the
        # loader once the limit is reached.
        for i, (images, _) in zip(range(max_batches), test_loader):
            images = images[:samples_per_batch].to(device)
            batch_size = images.shape[0]

            print(f" Processing batch {i+1}/{max_batches}...")

            # Reconstruction: single-step estimate of the clean image.
            t_light = torch.full((batch_size,), recon_timestep, device=device, dtype=torch.long)
            noisy_imgs, _ = scheduler.apply_noise(images, t_light)
            predicted_noise = model(noisy_imgs, t_light)
            reconstructed = (
                noisy_imgs - sqrt_one_minus_alpha_bar * predicted_noise
            ) / sqrt_alpha_bar
            _append_image_metrics(metrics, 'reconstruction', images, reconstructed)

            # Enhancement: heuristic multi-step refinement from heavy noise.
            t_heavy = torch.full((batch_size,), heavy_timestep, device=device, dtype=torch.long)
            enhanced, _ = scheduler.apply_noise(images, t_heavy)
            for t_val in refinement_timesteps:
                t_tensor = torch.full((batch_size,), t_val, device=device, dtype=torch.long)
                pred_noise = model(enhanced, t_tensor)
                enhanced = torch.clamp(enhanced - refinement_step * pred_noise, -1, 1)
            _append_image_metrics(metrics, 'enhancement', images, enhanced)

    if not metrics['reconstruction_mse']:
        # Averaging empty lists would only produce NaN and runtime warnings.
        print("❌ No test images were evaluated; nothing to report.")
        return None

    print("\nFINAL METRICS RESULTS:")
    print("=" * 60)

    recon_mse, recon_psnr, recon_ssim = _report_performance(
        "🎯 RECONSTRUCTION PERFORMANCE (Light Noise → Original):",
        metrics,
        'reconstruction',
    )
    enh_mse, enh_psnr, enh_ssim = _report_performance(
        "\n🧹 ENHANCEMENT PERFORMANCE (Heavy Noise → Original):",
        metrics,
        'enhancement',
    )

    print("\nRECONSTRUCTION GRADES:")
    print(f" MSE: {_grade_metric(recon_mse, [0.01, 0.05, 0.1], 'MSE')}")
    print(f" PSNR: {_grade_metric(recon_psnr, [35, 30, 25], 'PSNR')}")
    print(f" SSIM: {_grade_metric(recon_ssim, [0.9, 0.8, 0.7], 'SSIM')}")

    print("\nENHANCEMENT GRADES:")
    print(f" MSE: {_grade_metric(enh_mse, [0.05, 0.1, 0.2], 'MSE')}")
    print(f" PSNR: {_grade_metric(enh_psnr, [30, 25, 20], 'PSNR')}")
    print(f" SSIM: {_grade_metric(enh_ssim, [0.85, 0.75, 0.65], 'SSIM')}")

    print("\nSUMMARY FOR README:")
    print("=" * 60)
    print("Reconstruction Performance:")
    print(f"- MSE: {recon_mse:.6f}")
    print(f"- PSNR: {recon_psnr:.1f} dB")
    print(f"- SSIM: {recon_ssim:.4f}")
    print("\nEnhancement Performance:")
    print(f"- MSE: {enh_mse:.6f}")
    print(f"- PSNR: {enh_psnr:.1f} dB")
    print(f"- SSIM: {enh_ssim:.4f}")

    print("\nMetrics evaluation completed!")
    return {
        'recon_mse': recon_mse,
        'recon_psnr': recon_psnr,
        'recon_ssim': recon_ssim,
        'enh_mse': enh_mse,
        'enh_psnr': enh_psnr,
        'enh_ssim': enh_ssim,
    }


def _append_image_metrics(metrics, prefix, originals, outputs):
    """Append per-image MSE, PSNR and SSIM to the ``prefix``-keyed lists."""
    for original, output in zip(originals, outputs):
        metrics[f'{prefix}_mse'].append(F.mse_loss(original, output).item())
        # float() works whether the PSNR comes back as a 0-dim tensor or as
        # a plain Python float, so an infinite PSNR cannot crash the loop.
        metrics[f'{prefix}_psnr'].append(
            float(calculate_psnr(original, output, max_val=2.0))
        )
        metrics[f'{prefix}_ssim'].append(calculate_ssim(original, output))


def _report_performance(title, metrics, prefix):
    """Print mean and std of the ``prefix`` metrics and return the three means."""
    mse_values = metrics[f'{prefix}_mse']
    psnr_values = metrics[f'{prefix}_psnr']
    ssim_values = metrics[f'{prefix}_ssim']

    mean_mse = np.mean(mse_values)
    mean_psnr = np.mean(psnr_values)
    mean_ssim = np.mean(ssim_values)

    print(title)
    print(f" MSE: {mean_mse:.6f} ± {np.std(mse_values):.6f}")
    print(f" PSNR: {mean_psnr:.2f} ± {np.std(psnr_values):.2f} dB")
    print(f" SSIM: {mean_ssim:.4f} ± {np.std(ssim_values):.4f}")
    return mean_mse, mean_psnr, mean_ssim


def _grade_metric(value, thresholds, metric_name):
    """Map a metric value to a grade label using three ordered thresholds.

    For ``'MSE'`` a value below a threshold earns that grade (lower is
    better); for every other metric name a value above it does.
    """
    labels = ("Excellent ✅", "Very Good 🟢", "Good 🔵")
    lower_is_better = metric_name == 'MSE'
    for limit, label in zip(thresholds, labels):
        passed = value < limit if lower_is_better else value > limit
        if passed:
            return label
    return "Fair 🟡"
|
|
|
if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script,
    # not when it is imported as a module.
    evaluate_model()
|
|