Grad-CDM / simple_metrics.py
nazgut's picture
Upload 2 files
455ba60 verified
#!/usr/bin/env python3
"""
Simple Metrics Evaluation for Frequency-Aware Super-Denoiser
============================================================
Calculates PSNR, SSIM, and MSE metrics using existing sampling methods
"""
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import os
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt
# Import model components
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from config import Config
from dataloader import get_dataloaders
from sample import frequency_aware_sample
def calculate_psnr(img1, img2, max_val=2.0):
    """Calculate the peak signal-to-noise ratio (PSNR) between two images.

    Args:
        img1: First image tensor.
        img2: Second image tensor, compared element-wise against ``img1``.
        max_val: Dynamic range of the pixel values. The default of 2.0
            matches images scaled to [-1, 1].

    Returns:
        A 0-dim tensor holding the PSNR in dB. Identical inputs yield an
        infinite PSNR, also returned as a tensor so that callers can
        uniformly call ``.item()`` on the result.
    """
    mse = F.mse_loss(img1, img2)
    if mse == 0:
        # Return a tensor (not a Python float) so the return type does not
        # depend on the input values; ``float('inf').item()`` would raise.
        return torch.full_like(mse, float('inf'))
    return 20 * torch.log10(torch.tensor(max_val) / torch.sqrt(mse))
def calculate_ssim(img1, img2):
    """Compute the structural similarity (SSIM) of two image tensors.

    Each tensor is detached, moved to the CPU and converted to a NumPy array
    whose first axis is moved to the end via ``transpose(1, 2, 0)``, so the
    inputs must be 3-D (presumably channel-first images; confirm against
    the callers). Values are then mapped with ``(x + 1) / 2`` and clipped to
    [0, 1] before SSIM is evaluated with ``data_range=1.0``.

    Returns:
        The SSIM score produced by ``skimage.metrics.structural_similarity``.
    """
    # Tensor -> NumPy with the leading axis moved last.
    arrays = [
        tensor.detach().cpu().numpy().transpose(1, 2, 0)
        for tensor in (img1, img2)
    ]
    # Shift/scale to [0, 1] and clip anything that falls outside.
    first, second = [np.clip((array + 1) / 2, 0, 1) for array in arrays]
    # Both ``multichannel`` and ``channel_axis`` are passed, exactly as the
    # original call did; which one takes effect depends on the installed
    # scikit-image version.
    return ssim(
        first,
        second,
        multichannel=True,
        channel_axis=2,
        data_range=1.0,
    )
def add_noise(image, noise_level=0.2):
    """Return ``image`` perturbed by zero-mean Gaussian noise.

    Args:
        image: Input tensor; the noise has the same shape, dtype and device.
        noise_level: Multiplier applied to the standard-normal samples.

    Returns:
        The noisy tensor with every value clamped to [-1, 1].
    """
    perturbation = torch.randn_like(image).mul(noise_level)
    noisy = image + perturbation
    return noisy.clamp(-1, 1)
def evaluate_model(max_batches=20, images_per_batch=4):
    """Evaluate the trained model and print MSE / PSNR / SSIM statistics.

    Two experiments are run on batches from the test loader:

    * Reconstruction: images are noised by the scheduler at timestep 50 and
      recovered in one step from the model's noise prediction.
    * Enhancement: images are noised at timestep 150 and refined over a
      fixed list of timesteps, subtracting a scaled noise prediction and
      clamping to [-1, 1] at each step.

    Args:
        max_batches: Maximum number of test batches to evaluate.
        images_per_batch: Maximum number of images used from each batch.

    Returns:
        A dict with the mean ``recon_*`` and ``enh_*`` values for MSE, PSNR
        and SSIM, or ``None`` when the checkpoint ``model_final.pth`` is
        missing or the test data cannot be loaded.
    """
    print("πŸ” FREQUENCY-AWARE SUPER-DENOISER METRICS EVALUATION")
    print("=" * 60)

    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config = Config()

    # Load model
    model = SmoothDiffusionUNet(config).to(device)
    if not os.path.exists('model_final.pth'):
        print("❌ No trained model found! Please run training first.")
        return
    # NOTE(security): weights_only=False unpickles arbitrary objects, so only
    # load checkpoints from a trusted source.
    checkpoint = torch.load('model_final.pth', map_location=device, weights_only=False)
    model.load_state_dict(checkpoint)
    print("βœ… Model loaded successfully")
    model.eval()
    scheduler = FrequencyAwareNoise(config)

    # Get test data
    try:
        _, test_loader = get_dataloaders(config)
        print(f"βœ… Test data loaded: {len(test_loader)} batches")
    except Exception as exc:
        # Catch Exception (not a bare except) so Ctrl-C / SystemExit still
        # propagate, and report why loading failed.
        print("❌ Could not load test data")
        print(f" Reason: {exc}")
        return

    # Evaluation metrics storage
    metrics = {
        'reconstruction_mse': [],
        'reconstruction_psnr': [],
        'reconstruction_ssim': [],
        'enhancement_mse': [],
        'enhancement_psnr': [],
        'enhancement_ssim': [],
    }

    def record_metrics(prefix, originals, outputs):
        """Append per-image MSE, PSNR and SSIM under the ``prefix`` keys."""
        for original, output in zip(originals, outputs):
            metrics[f'{prefix}_mse'].append(F.mse_loss(original, output).item())
            # float() accepts both a 0-dim tensor and a plain Python float.
            metrics[f'{prefix}_psnr'].append(
                float(calculate_psnr(original, output, max_val=2.0))
            )
            metrics[f'{prefix}_ssim'].append(calculate_ssim(original, output))

    light_step = 50    # timestep used for the light-noise reconstruction test
    heavy_step = 150   # timestep used for the heavy-noise enhancement test
    refine_steps = (150, 100, 50, 25, 10, 5, 1)
    total_batches = min(max_batches, len(test_loader))

    print("\nπŸ“Š Evaluating reconstruction quality...")
    with torch.no_grad():
        for i, (images, _) in enumerate(test_loader):
            if i >= max_batches:  # Evaluate on a limited number of batches for speed
                break
            images = images.to(device)
            batch_size = min(images_per_batch, images.shape[0])
            images = images[:batch_size]
            print(f" Processing batch {i+1}/{total_batches}...")

            # Test 1: Reconstruction from low noise
            t_light = torch.full((batch_size,), light_step, device=device, dtype=torch.long)
            noisy_imgs, _ = scheduler.apply_noise(images, t_light)
            predicted_noise = model(noisy_imgs, t_light)
            # One-step estimate of the clean image from the predicted noise.
            alpha_bar = scheduler.alpha_bars[light_step].item()
            reconstructed = (
                noisy_imgs - np.sqrt(1 - alpha_bar) * predicted_noise
            ) / np.sqrt(alpha_bar)
            record_metrics('reconstruction', images, reconstructed)

            # Test 2: Enhancement from heavily noised images
            t_heavy = torch.full((batch_size,), heavy_step, device=device, dtype=torch.long)
            heavy_noisy, _ = scheduler.apply_noise(images, t_heavy)

            # Multi-step denoising simulation
            enhanced = heavy_noisy.clone()
            for t_val in refine_steps:
                t_tensor = torch.full((batch_size,), t_val, device=device, dtype=torch.long)
                pred_noise = model(enhanced, t_tensor)
                # Simple denoising step: remove a fixed fraction of the
                # predicted noise and keep values in the image range.
                enhanced = torch.clamp(enhanced - 0.1 * pred_noise, -1, 1)
            record_metrics('enhancement', images, enhanced)

    # Calculate final statistics
    print("\nπŸ“ˆ FINAL METRICS RESULTS:")
    print("=" * 60)
    print("🎯 RECONSTRUCTION PERFORMANCE (Light Noise β†’ Original):")
    recon_mse = np.mean(metrics['reconstruction_mse'])
    recon_psnr = np.mean(metrics['reconstruction_psnr'])
    recon_ssim = np.mean(metrics['reconstruction_ssim'])
    print(f" MSE: {recon_mse:.6f} Β± {np.std(metrics['reconstruction_mse']):.6f}")
    print(f" PSNR: {recon_psnr:.2f} Β± {np.std(metrics['reconstruction_psnr']):.2f} dB")
    print(f" SSIM: {recon_ssim:.4f} Β± {np.std(metrics['reconstruction_ssim']):.4f}")

    print("\n🧹 ENHANCEMENT PERFORMANCE (Heavy Noise β†’ Original):")
    enh_mse = np.mean(metrics['enhancement_mse'])
    enh_psnr = np.mean(metrics['enhancement_psnr'])
    enh_ssim = np.mean(metrics['enhancement_ssim'])
    print(f" MSE: {enh_mse:.6f} Β± {np.std(metrics['enhancement_mse']):.6f}")
    print(f" PSNR: {enh_psnr:.2f} Β± {np.std(metrics['enhancement_psnr']):.2f} dB")
    print(f" SSIM: {enh_ssim:.4f} Β± {np.std(metrics['enhancement_ssim']):.4f}")

    # Generate performance grades
    def grade_metric(value, thresholds, metric_name):
        """Map a metric value to a grade label using three thresholds.

        For 'MSE' lower is better (thresholds ascending); for every other
        metric name higher is better (thresholds descending).
        """
        if metric_name == 'MSE':
            if value < thresholds[0]:
                return "Excellent βœ…"
            elif value < thresholds[1]:
                return "Very Good 🟒"
            elif value < thresholds[2]:
                return "Good πŸ”΅"
            else:
                return "Fair 🟑"
        else:  # PSNR, SSIM
            if value > thresholds[0]:
                return "Excellent βœ…"
            elif value > thresholds[1]:
                return "Very Good 🟒"
            elif value > thresholds[2]:
                return "Good πŸ”΅"
            else:
                return "Fair 🟑"

    print("\nπŸ† RECONSTRUCTION GRADES:")
    print(f" MSE: {grade_metric(recon_mse, [0.01, 0.05, 0.1], 'MSE')}")
    print(f" PSNR: {grade_metric(recon_psnr, [35, 30, 25], 'PSNR')}")
    print(f" SSIM: {grade_metric(recon_ssim, [0.9, 0.8, 0.7], 'SSIM')}")
    print("\nπŸ† ENHANCEMENT GRADES:")
    print(f" MSE: {grade_metric(enh_mse, [0.05, 0.1, 0.2], 'MSE')}")
    print(f" PSNR: {grade_metric(enh_psnr, [30, 25, 20], 'PSNR')}")
    print(f" SSIM: {grade_metric(enh_ssim, [0.85, 0.75, 0.65], 'SSIM')}")

    # Create summary for README
    print("\nπŸ“‹ SUMMARY FOR README:")
    print("=" * 60)
    print("Reconstruction Performance:")
    print(f"- MSE: {recon_mse:.6f}")
    print(f"- PSNR: {recon_psnr:.1f} dB")
    print(f"- SSIM: {recon_ssim:.4f}")
    print("\nEnhancement Performance:")
    print(f"- MSE: {enh_mse:.6f}")
    print(f"- PSNR: {enh_psnr:.1f} dB")
    print(f"- SSIM: {enh_ssim:.4f}")
    print("\nπŸŽ‰ Metrics evaluation completed!")

    return {
        'recon_mse': recon_mse,
        'recon_psnr': recon_psnr,
        'recon_ssim': recon_ssim,
        'enh_mse': enh_mse,
        'enh_psnr': enh_psnr,
        'enh_ssim': enh_ssim,
    }
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    evaluate_model()