#!/usr/bin/env python3 """ Simple Metrics Evaluation for Frequency-Aware Super-Denoiser ============================================================ Calculates PSNR, SSIM, and MSE metrics using existing sampling methods """ import torch import torch.nn.functional as F import numpy as np from PIL import Image import os from skimage.metrics import structural_similarity as ssim import matplotlib.pyplot as plt # Import model components from model import SmoothDiffusionUNet from noise_scheduler import FrequencyAwareNoise from config import Config from dataloader import get_dataloaders from sample import frequency_aware_sample def calculate_psnr(img1, img2, max_val=2.0): """Calculate PSNR between two images""" mse = F.mse_loss(img1, img2) if mse == 0: return float('inf') return 20 * torch.log10(torch.tensor(max_val) / torch.sqrt(mse)) def calculate_ssim(img1, img2): """Calculate SSIM between two images""" # Convert to numpy and ensure proper format img1_np = img1.detach().cpu().numpy().transpose(1, 2, 0) img2_np = img2.detach().cpu().numpy().transpose(1, 2, 0) # Normalize to [0,1] img1_np = (img1_np + 1) / 2 img2_np = (img2_np + 1) / 2 img1_np = np.clip(img1_np, 0, 1) img2_np = np.clip(img2_np, 0, 1) return ssim(img1_np, img2_np, multichannel=True, channel_axis=2, data_range=1.0) def add_noise(image, noise_level=0.2): """Add Gaussian noise to images""" noise = torch.randn_like(image) * noise_level return torch.clamp(image + noise, -1, 1) def evaluate_model(): """Simplified model evaluation using existing sampling methods""" print("๐Ÿ” FREQUENCY-AWARE SUPER-DENOISER METRICS EVALUATION") print("=" * 60) # Setup device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') config = Config() # Load model model = SmoothDiffusionUNet(config).to(device) if os.path.exists('model_final.pth'): checkpoint = torch.load('model_final.pth', map_location=device, weights_only=False) model.load_state_dict(checkpoint) print("โœ… Model loaded successfully") else: print("โŒ No trained model found! Please run training first.") return model.eval() scheduler = FrequencyAwareNoise(config) # Get test data try: _, test_loader = get_dataloaders(config) print(f"โœ… Test data loaded: {len(test_loader)} batches") except: print("โŒ Could not load test data") return # Evaluation metrics storage metrics = { 'reconstruction_mse': [], 'reconstruction_psnr': [], 'reconstruction_ssim': [], 'enhancement_mse': [], 'enhancement_psnr': [], 'enhancement_ssim': [] } print("\n๐Ÿ“Š Evaluating reconstruction quality...") with torch.no_grad(): for i, (images, _) in enumerate(test_loader): if i >= 20: # Evaluate on 20 batches for speed break images = images.to(device) batch_size = min(4, images.shape[0]) # Process 4 images at a time images = images[:batch_size] print(f" Processing batch {i+1}/20...") # Test 1: Reconstruction from low noise # Add light noise and see how well we can reconstruct lightly_noisy = add_noise(images, noise_level=0.1) # Apply noise using the scheduler t_light = torch.full((batch_size,), 50, device=device, dtype=torch.long) # Light noise noisy_imgs, noise_spatial = scheduler.apply_noise(images, t_light) # Reconstruct by predicting the noise predicted_noise = model(noisy_imgs, t_light) # Simple reconstruction alpha_bar = scheduler.alpha_bars[50].item() reconstructed = (noisy_imgs - np.sqrt(1 - alpha_bar) * predicted_noise) / np.sqrt(alpha_bar) # Calculate reconstruction metrics for j in range(batch_size): original = images[j] recon = reconstructed[j] # MSE mse_val = F.mse_loss(original, recon).item() metrics['reconstruction_mse'].append(mse_val) # PSNR psnr_val = calculate_psnr(original, recon, max_val=2.0).item() metrics['reconstruction_psnr'].append(psnr_val) # SSIM ssim_val = calculate_ssim(original, recon) metrics['reconstruction_ssim'].append(ssim_val) # Test 2: Enhancement from noisy images # Add more significant noise and test enhancement noisy_enhanced = add_noise(images, noise_level=0.3) # Apply heavier noise with scheduler t_heavy = torch.full((batch_size,), 150, device=device, dtype=torch.long) heavy_noisy, _ = scheduler.apply_noise(images, t_heavy) # Multi-step denoising simulation enhanced = heavy_noisy.clone() timesteps = [150, 100, 50, 25, 10, 5, 1] for t_val in timesteps: t_tensor = torch.full((batch_size,), max(t_val, 0), device=device, dtype=torch.long) pred_noise = model(enhanced, t_tensor) # Simple denoising step if t_val > 0: alpha_bar = scheduler.alpha_bars[t_val].item() enhanced = (enhanced - 0.1 * pred_noise) enhanced = torch.clamp(enhanced, -1, 1) # Calculate enhancement metrics for j in range(batch_size): original = images[j] enhanced_img = enhanced[j] mse_val = F.mse_loss(original, enhanced_img).item() metrics['enhancement_mse'].append(mse_val) psnr_val = calculate_psnr(original, enhanced_img, max_val=2.0).item() metrics['enhancement_psnr'].append(psnr_val) ssim_val = calculate_ssim(original, enhanced_img) metrics['enhancement_ssim'].append(ssim_val) # Calculate final statistics print("\n๐Ÿ“ˆ FINAL METRICS RESULTS:") print("=" * 60) print("๐ŸŽฏ RECONSTRUCTION PERFORMANCE (Light Noise โ†’ Original):") recon_mse = np.mean(metrics['reconstruction_mse']) recon_psnr = np.mean(metrics['reconstruction_psnr']) recon_ssim = np.mean(metrics['reconstruction_ssim']) print(f" MSE: {recon_mse:.6f} ยฑ {np.std(metrics['reconstruction_mse']):.6f}") print(f" PSNR: {recon_psnr:.2f} ยฑ {np.std(metrics['reconstruction_psnr']):.2f} dB") print(f" SSIM: {recon_ssim:.4f} ยฑ {np.std(metrics['reconstruction_ssim']):.4f}") print("\n๐Ÿงน ENHANCEMENT PERFORMANCE (Heavy Noise โ†’ Original):") enh_mse = np.mean(metrics['enhancement_mse']) enh_psnr = np.mean(metrics['enhancement_psnr']) enh_ssim = np.mean(metrics['enhancement_ssim']) print(f" MSE: {enh_mse:.6f} ยฑ {np.std(metrics['enhancement_mse']):.6f}") print(f" PSNR: {enh_psnr:.2f} ยฑ {np.std(metrics['enhancement_psnr']):.2f} dB") print(f" SSIM: {enh_ssim:.4f} ยฑ {np.std(metrics['enhancement_ssim']):.4f}") # Generate performance grades def grade_metric(value, thresholds, metric_name): if metric_name == 'MSE': if value < thresholds[0]: return "Excellent โœ…" elif value < thresholds[1]: return "Very Good ๐ŸŸข" elif value < thresholds[2]: return "Good ๐Ÿ”ต" else: return "Fair ๐ŸŸก" else: # PSNR, SSIM if value > thresholds[0]: return "Excellent โœ…" elif value > thresholds[1]: return "Very Good ๐ŸŸข" elif value > thresholds[2]: return "Good ๐Ÿ”ต" else: return "Fair ๐ŸŸก" print("\n๐Ÿ† RECONSTRUCTION GRADES:") print(f" MSE: {grade_metric(recon_mse, [0.01, 0.05, 0.1], 'MSE')}") print(f" PSNR: {grade_metric(recon_psnr, [35, 30, 25], 'PSNR')}") print(f" SSIM: {grade_metric(recon_ssim, [0.9, 0.8, 0.7], 'SSIM')}") print("\n๐Ÿ† ENHANCEMENT GRADES:") print(f" MSE: {grade_metric(enh_mse, [0.05, 0.1, 0.2], 'MSE')}") print(f" PSNR: {grade_metric(enh_psnr, [30, 25, 20], 'PSNR')}") print(f" SSIM: {grade_metric(enh_ssim, [0.85, 0.75, 0.65], 'SSIM')}") # Create summary for README print("\n๐Ÿ“‹ SUMMARY FOR README:") print("=" * 60) print("Reconstruction Performance:") print(f"- MSE: {recon_mse:.6f}") print(f"- PSNR: {recon_psnr:.1f} dB") print(f"- SSIM: {recon_ssim:.4f}") print("\nEnhancement Performance:") print(f"- MSE: {enh_mse:.6f}") print(f"- PSNR: {enh_psnr:.1f} dB") print(f"- SSIM: {enh_ssim:.4f}") print("\n๐ŸŽ‰ Metrics evaluation completed!") return { 'recon_mse': recon_mse, 'recon_psnr': recon_psnr, 'recon_ssim': recon_ssim, 'enh_mse': enh_mse, 'enh_psnr': enh_psnr, 'enh_ssim': enh_ssim } if __name__ == "__main__": evaluate_model()