"""Application tests for the frequency-aware super-denoiser.

Runs eight demos (noise removal, sharpening, texture synthesis,
interpolation, style transfer, progressive enhancement, medical-style
enhancement, and single-pass real-time enhancement) and writes comparison
grids to the applications_test/ directory.
"""
import os

import numpy as np
import torch
from torchvision.utils import save_image, make_grid

from config import Config
from dataloader import get_dataloaders
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise

def create_test_applications():
    """Comprehensive test of all super-denoiser applications."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the trained model; the checkpoint is assumed to be a raw
    # state dict saved with torch.save(model.state_dict(), ...).
    config = Config()
    checkpoint = torch.load('model_final.pth', map_location=device)
    model = SmoothDiffusionUNet(config).to(device)
    model.load_state_dict(checkpoint)
    model.eval()
    noise_scheduler = FrequencyAwareNoise(config)

    # Pull a batch of real images to degrade and restore.
    train_loader, _ = get_dataloaders(config)
    real_batch, _ = next(iter(train_loader))
    real_images = real_batch[:8].to(device)

    print("=== COMPREHENSIVE SUPER-DENOISER APPLICATIONS TEST ===")
    os.makedirs("applications_test", exist_ok=True)

    with torch.no_grad():
        print("\n🔧 APPLICATION 1: NOISE REMOVAL")
        print("Use case: Cleaning noisy photos, low-light images, old scans")

        clean_img = real_images[0:1]

        # Additive Gaussian noise.
        gaussian_noisy = clean_img + torch.randn_like(clean_img) * 0.2
        gaussian_noisy = torch.clamp(gaussian_noisy, -1, 1)

        # Salt-and-pepper noise: each corrupted pixel is forced to an
        # extreme value, -1 (pepper) or +1 (salt), never an in-between 0.
        salt_pepper = clean_img.clone()
        mask = torch.rand_like(clean_img) < 0.1
        salt_pepper[mask] = torch.randint_like(salt_pepper[mask], 0, 2).float() * 2 - 1

        denoised_gaussian = denoise_image(model, noise_scheduler, gaussian_noisy, strength=0.6)
        denoised_salt_pepper = denoise_image(model, noise_scheduler, salt_pepper, strength=0.8)

        noise_comparison = torch.cat([
            clean_img, gaussian_noisy, denoised_gaussian,
            clean_img, salt_pepper, denoised_salt_pepper
        ], dim=0)
        save_comparison(noise_comparison, "applications_test/01_noise_removal.png",
                        labels=["Original", "Gaussian Noise", "Denoised",
                                "Original", "Salt&Pepper", "Denoised"])
        print("✅ Noise removal test saved to applications_test/01_noise_removal.png")

        print("\n📸 APPLICATION 2: IMAGE SHARPENING & ENHANCEMENT")
        print("Use case: Enhancing blurry photos, improving image quality")

        blur_img = real_images[1:2]

        mild_blur = apply_blur(blur_img, sigma=0.8)
        heavy_blur = apply_blur(blur_img, sigma=2.0)

        enhanced_mild = enhance_image(model, noise_scheduler, mild_blur, enhancement=0.5)
        enhanced_heavy = enhance_image(model, noise_scheduler, heavy_blur, enhancement=0.8)

        enhancement_comparison = torch.cat([
            blur_img, mild_blur, enhanced_mild,
            blur_img, heavy_blur, enhanced_heavy
        ], dim=0)
        save_comparison(enhancement_comparison, "applications_test/02_image_enhancement.png",
                        labels=["Original", "Mild Blur", "Enhanced",
                                "Original", "Heavy Blur", "Enhanced"])
        print("✅ Enhancement test saved to applications_test/02_image_enhancement.png")

        print("\n🎨 APPLICATION 3: TEXTURE SYNTHESIS & ARTISTIC CREATION")
        print("Use case: Creating new textures, artistic effects, style transfer")

        patterns = []

        organic = create_organic_pattern(device)
        refined_organic = refine_pattern(model, noise_scheduler, organic, steps=8)
        patterns.extend([organic, refined_organic])

        geometric = create_geometric_pattern(device)
        refined_geometric = refine_pattern(model, noise_scheduler, geometric, steps=6)
        patterns.extend([geometric, refined_geometric])

        abstract = create_abstract_pattern(device)
        refined_abstract = refine_pattern(model, noise_scheduler, abstract, steps=10)
        patterns.extend([abstract, refined_abstract])

        pattern_grid = torch.cat(patterns, dim=0)
        save_comparison(pattern_grid, "applications_test/03_texture_synthesis.png",
                        labels=["Organic Raw", "Organic Refined", "Geometric Raw",
                                "Geometric Refined", "Abstract Raw", "Abstract Refined"])
        print("✅ Texture synthesis test saved to applications_test/03_texture_synthesis.png")

        print("\n🔄 APPLICATION 4: IMAGE INTERPOLATION & MORPHING")
        print("Use case: Creating smooth transitions, morphing between images")

        img1 = real_images[2:3]
        img2 = real_images[3:4]

        interpolations = []
        alphas = [0.0, 0.25, 0.5, 0.75, 1.0]

        for alpha in alphas:
            # Linear blend, plus a little noise so the model has something
            # to refine away.
            interp = alpha * img1 + (1 - alpha) * img2
            interp = interp + torch.randn_like(interp) * 0.05
            refined = refine_interpolation(model, noise_scheduler, interp)
            interpolations.append(refined)

        interp_grid = torch.cat(interpolations, dim=0)
        save_comparison(interp_grid, "applications_test/04_image_interpolation.png",
                        labels=[f"α={a:.2f}" for a in alphas])
        print("✅ Interpolation test saved to applications_test/04_image_interpolation.png")

        print("\n🖼️ APPLICATION 5: STYLE TRANSFER & ARTISTIC EFFECTS")
        print("Use case: Applying artistic styles, creating stylized versions")

        content_img = real_images[4:5]

        styles = []

        high_contrast = create_high_contrast_version(content_img)
        refined_contrast = apply_style_refinement(model, noise_scheduler, high_contrast, "contrast")
        styles.extend([high_contrast, refined_contrast])

        soft_style = create_soft_version(content_img)
        refined_soft = apply_style_refinement(model, noise_scheduler, soft_style, "soft")
        styles.extend([soft_style, refined_soft])

        edge_style = create_edge_enhanced_version(content_img)
        refined_edge = apply_style_refinement(model, noise_scheduler, edge_style, "edge")
        styles.extend([edge_style, refined_edge])

        styles_with_original = torch.cat([content_img] + styles, dim=0)
        save_comparison(styles_with_original, "applications_test/05_style_transfer.png",
                        labels=["Original", "High Contrast", "Refined", "Soft", "Refined",
                                "Edge Enhanced", "Refined"])
        print("✅ Style transfer test saved to applications_test/05_style_transfer.png")

        print("\n⚡ APPLICATION 6: PROGRESSIVE ENHANCEMENT")
        print("Use case: Showing different enhancement levels, user control")

        base_img = real_images[5:6]

        # Degrade with noise plus blur, then enhance at increasing levels.
        degraded = base_img + torch.randn_like(base_img) * 0.15
        degraded = apply_blur(degraded, sigma=1.2)

        enhancement_levels = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
        progressive = [degraded]  # level 0.0 is the untouched degraded image

        for level in enhancement_levels[1:]:
            enhanced = progressive_enhance(model, noise_scheduler, degraded, level)
            progressive.append(enhanced)

        prog_grid = torch.cat(progressive, dim=0)
        save_comparison(prog_grid, "applications_test/06_progressive_enhancement.png",
                        labels=[f"Level {l:.1f}" for l in enhancement_levels])
        print("✅ Progressive enhancement test saved to applications_test/06_progressive_enhancement.png")

        print("\n🔬 APPLICATION 7: MEDICAL/SCIENTIFIC SIMULATION")
        print("Use case: Enhancing low-quality scientific images")

        scientific_img = real_images[6:7]

        # Low-contrast image.
        low_contrast = scientific_img * 0.3 + 0.1
        enhanced_contrast = enhance_medical_image(model, noise_scheduler, low_contrast, "contrast")

        # Noisy scan.
        noisy_scan = scientific_img + torch.randn_like(scientific_img) * 0.25
        enhanced_scan = enhance_medical_image(model, noise_scheduler, noisy_scan, "noise")

        # Blurry micrograph.
        blurry_micro = apply_blur(scientific_img, sigma=1.5)
        enhanced_micro = enhance_medical_image(model, noise_scheduler, blurry_micro, "sharpness")

        medical_comparison = torch.cat([
            low_contrast, enhanced_contrast,
            noisy_scan, enhanced_scan,
            blurry_micro, enhanced_micro
        ], dim=0)
        save_comparison(medical_comparison, "applications_test/07_medical_enhancement.png",
                        labels=["Low Contrast", "Enhanced", "Noisy Scan", "Denoised",
                                "Blurry Micro", "Sharpened"])
        print("✅ Medical enhancement test saved to applications_test/07_medical_enhancement.png")

        print("\n⚡ APPLICATION 8: REAL-TIME ENHANCEMENT SIMULATION")
        print("Use case: Fast single-pass enhancement for real-time applications")

        realtime_img = real_images[7:8]

        # Simulated degradations: a dim, noisy video call; a soft mobile
        # photo; and a dark, noisy security-camera frame.
        video_call = realtime_img * 0.6 + torch.randn_like(realtime_img) * 0.1
        enhanced_video = single_pass_enhance(model, noise_scheduler, video_call)

        mobile_photo = realtime_img + torch.randn_like(realtime_img) * 0.08
        mobile_photo = apply_blur(mobile_photo, sigma=0.5)
        enhanced_mobile = single_pass_enhance(model, noise_scheduler, mobile_photo)

        security_cam = realtime_img * 0.4 + torch.randn_like(realtime_img) * 0.2
        enhanced_security = single_pass_enhance(model, noise_scheduler, security_cam)

        realtime_comparison = torch.cat([
            video_call, enhanced_video,
            mobile_photo, enhanced_mobile,
            security_cam, enhanced_security
        ], dim=0)
        save_comparison(realtime_comparison, "applications_test/08_realtime_enhancement.png",
                        labels=["Video Call", "Enhanced", "Mobile Photo", "Enhanced",
                                "Security Cam", "Enhanced"])
        print("✅ Real-time enhancement test saved to applications_test/08_realtime_enhancement.png")

        print("\n🎉 SUMMARY: ALL APPLICATIONS TESTED")
        print("=" * 50)
        print("Your frequency-aware super-denoiser model successfully handles:")
        print("1. ✅ Noise removal (Gaussian, salt & pepper)")
        print("2. ✅ Image sharpening and enhancement")
        print("3. ✅ Texture synthesis and artistic creation")
        print("4. ✅ Image interpolation and morphing")
        print("5. ✅ Style transfer and artistic effects")
        print("6. ✅ Progressive enhancement with user control")
        print("7. ✅ Medical/scientific image enhancement")
        print("8. ✅ Real-time enhancement applications")
        print("\nAll test results saved in 'applications_test/' directory")
        print("Your model is ready for production use! 🚀")

def denoise_image(model, noise_scheduler, noisy_img, strength=0.5):
    """Apply denoising with controlled strength (0 = none, 1 = strongest)."""
    # Descending timestep schedule scaled by strength; the final step at
    # t=1 is a light clean-up pass.
    timesteps = [int(strength * 100), int(strength * 60), int(strength * 30), int(strength * 10), 1]
    x = noisy_img.clone()

    for t_val in timesteps:
        if t_val > 0:
            t_tensor = torch.full((x.shape[0],), t_val, device=x.device, dtype=torch.long)
            predicted_noise = model(x, t_tensor)
            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
            x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * strength) / np.sqrt(alpha_bar_t)
            x = torch.clamp(x, -1, 1)

    return x

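# The update used throughout this file is a strength-scaled x0 prediction:
# under the standard DDPM forward process
#     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
# solving for x_0 gives the estimate below. This standalone helper is a
# minimal reference sketch of the full-strength (strength = 1) step; it
# assumes, as the functions in this file already do, that the scheduler
# exposes a 1-D tensor `alpha_bars` of cumulative alpha products.
def predict_x0(model, noise_scheduler, x, t_val):
    """One-shot estimate of the clean image x_0 from a noisy x at t_val."""
    t_tensor = torch.full((x.shape[0],), t_val, device=x.device, dtype=torch.long)
    predicted_noise = model(x, t_tensor)
    alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
    x0_hat = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
    return torch.clamp(x0_hat, -1, 1)
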
					
					
						
def enhance_image(model, noise_scheduler, blurry_img, enhancement=0.5):
    """Enhance blurry or low-quality images."""
    timesteps = [int(enhancement * 80), int(enhancement * 50), int(enhancement * 25), int(enhancement * 10)]
    x = blurry_img.clone()

    for t_val in timesteps:
        if t_val > 0:
            t_tensor = torch.full((x.shape[0],), t_val, device=x.device, dtype=torch.long)
            predicted_noise = model(x, t_tensor)
            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
            x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * enhancement) / np.sqrt(alpha_bar_t)
            x = torch.clamp(x, -1, 1)

    return x

def refine_pattern(model, noise_scheduler, pattern, steps=5):
    """Refine generated patterns over a descending timestep schedule."""
    # Build exactly `steps` timesteps from 60 down to 5, so the larger
    # step counts requested by the texture-synthesis test add real passes.
    timesteps = [int(t) for t in np.linspace(60, 5, num=steps)]
    x = pattern.clone()

    for t_val in timesteps:
        t_tensor = torch.full((x.shape[0],), t_val, device=x.device, dtype=torch.long)
        predicted_noise = model(x, t_tensor)
        alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
        x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * 0.4) / np.sqrt(alpha_bar_t)
        x = torch.clamp(x, -1, 1)

    return x

def refine_interpolation(model, noise_scheduler, interp_img):
    """Refine interpolated images with a few light denoising passes."""
    timesteps = [30, 20, 10, 5]
    x = interp_img.clone()

    for t_val in timesteps:
        t_tensor = torch.full((x.shape[0],), t_val, device=x.device, dtype=torch.long)
        predicted_noise = model(x, t_tensor)
        alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
        x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * 0.3) / np.sqrt(alpha_bar_t)
        x = torch.clamp(x, -1, 1)

    return x

def apply_style_refinement(model, noise_scheduler, styled_img, style_type):
    """Apply style-specific refinement."""
    if style_type == "contrast":
        timesteps = [40, 25, 10]
        strength = 0.4
    elif style_type == "soft":
        timesteps = [60, 35, 15, 5]
        strength = 0.3
    else:  # "edge"
        timesteps = [35, 20, 8]
        strength = 0.5

    x = styled_img.clone()
    for t_val in timesteps:
        t_tensor = torch.full((x.shape[0],), t_val, device=x.device, dtype=torch.long)
        predicted_noise = model(x, t_tensor)
        alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
        x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * strength) / np.sqrt(alpha_bar_t)
        x = torch.clamp(x, -1, 1)

    return x

def progressive_enhance(model, noise_scheduler, degraded_img, level):
    """Apply progressive enhancement; `level` in [0, 1] controls intensity."""
    if level == 0:
        return degraded_img

    max_timestep = int(level * 100)
    timesteps = [max_timestep, int(max_timestep * 0.6), int(max_timestep * 0.3)]
    timesteps = [t for t in timesteps if t > 0]

    x = degraded_img.clone()
    for t_val in timesteps:
        t_tensor = torch.full((x.shape[0],), t_val, device=x.device, dtype=torch.long)
        predicted_noise = model(x, t_tensor)
        alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
        x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * level) / np.sqrt(alpha_bar_t)
        x = torch.clamp(x, -1, 1)

    return x

def enhance_medical_image(model, noise_scheduler, medical_img, enhancement_type):
    """Enhance medical/scientific images."""
    if enhancement_type == "contrast":
        timesteps = [50, 30, 15]
        strength = 0.6
    elif enhancement_type == "noise":
        timesteps = [80, 50, 25, 10]
        strength = 0.7
    else:  # "sharpness"
        timesteps = [60, 35, 18, 8]
        strength = 0.5

    x = medical_img.clone()
    for t_val in timesteps:
        t_tensor = torch.full((x.shape[0],), t_val, device=x.device, dtype=torch.long)
        predicted_noise = model(x, t_tensor)
        alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
        x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * strength) / np.sqrt(alpha_bar_t)
        x = torch.clamp(x, -1, 1)

    return x

def single_pass_enhance(model, noise_scheduler, input_img):
    """Fast single-pass enhancement for real-time use."""
    t_val = 25  # single moderate timestep: one model forward pass per frame
    t_tensor = torch.full((input_img.shape[0],), t_val, device=input_img.device, dtype=torch.long)
    predicted_noise = model(input_img, t_tensor)
    alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
    enhanced = (input_img - np.sqrt(1 - alpha_bar_t) * predicted_noise * 0.5) / np.sqrt(alpha_bar_t)
    return torch.clamp(enhanced, -1, 1)

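# A rough latency probe for the single-pass path, to back up the
# "real-time" framing. Illustrative sketch only: the 1x3x64x64 input shape
# is an assumption matching the pattern helpers below, not a documented
# model requirement.
def benchmark_single_pass(model, noise_scheduler, device, n_runs=20):
    import time
    dummy = torch.randn(1, 3, 64, 64, device=device)
    with torch.no_grad():
        single_pass_enhance(model, noise_scheduler, dummy)  # warm-up
        if device.type == "cuda":
            torch.cuda.synchronize()
        start = time.time()
        for _ in range(n_runs):
            single_pass_enhance(model, noise_scheduler, dummy)
        if device.type == "cuda":
            torch.cuda.synchronize()
    return (time.time() - start) / n_runs  # mean seconds per frame
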
					
					
						
def apply_blur(img, sigma=1.0):
    """Apply Gaussian blur via a depthwise convolution."""
    kernel_size = int(sigma * 4) * 2 + 1  # odd size covering roughly ±2 sigma
    kernel = create_gaussian_kernel(kernel_size, sigma)
    kernel = kernel.view(1, 1, kernel_size, kernel_size).repeat(3, 1, 1, 1).to(img.device)
    return torch.nn.functional.conv2d(img, kernel, padding=kernel_size // 2, groups=3)

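# For comparison, recent torchvision releases ship an equivalent blur as
# transforms.functional.gaussian_blur. A hedged alternative, assuming that
# function is available in the installed torchvision:
def apply_blur_tv(img, sigma=1.0):
    import torchvision.transforms.functional as TF
    kernel_size = int(sigma * 4) * 2 + 1
    return TF.gaussian_blur(img, kernel_size=kernel_size, sigma=sigma)
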
					
					
						
def create_gaussian_kernel(kernel_size, sigma):
    """Create a normalized 2-D Gaussian blur kernel as the outer product
    of two normalized 1-D Gaussians."""
    x = torch.arange(kernel_size, dtype=torch.float32) - kernel_size // 2
    gauss = torch.exp(-x**2 / (2 * sigma**2))
    kernel_1d = gauss / gauss.sum()
    kernel_2d = kernel_1d[:, None] * kernel_1d[None, :]
    return kernel_2d

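# Sanity check for the kernel helper (runs without the model): a blur
# kernel must be non-negative and sum to 1 so that blurring preserves
# overall brightness.
def _check_gaussian_kernel(kernel_size=5, sigma=1.0):
    k = create_gaussian_kernel(kernel_size, sigma)
    assert k.min() >= 0
    assert torch.isclose(k.sum(), torch.tensor(1.0), atol=1e-5)
    return k
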
					
					
						
def create_organic_pattern(device):
    """Create an organic texture pattern: low-amplitude noise plus a
    smooth sinusoidal structure."""
    pattern = torch.randn(1, 3, 64, 64, device=device) * 0.3

    x, y = torch.meshgrid(torch.linspace(-1, 1, 64), torch.linspace(-1, 1, 64), indexing='ij')
    x, y = x.to(device), y.to(device)
    structure = torch.sin(x * 3) * torch.cos(y * 3) * 0.2
    pattern[0] += structure.unsqueeze(0)  # broadcast across all 3 channels
    return torch.clamp(pattern, -1, 1)

def create_geometric_pattern(device):
    """Create a geometric pattern: an 8x8-block checkerboard plus noise."""
    pattern = torch.zeros(1, 3, 64, 64, device=device)

    for i in range(0, 64, 8):
        for j in range(0, 64, 8):
            if (i // 8 + j // 8) % 2 == 0:
                pattern[0, :, i:i+8, j:j+8] = 0.5
            else:
                pattern[0, :, i:i+8, j:j+8] = -0.5

    pattern += torch.randn_like(pattern) * 0.1
    return torch.clamp(pattern, -1, 1)

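# A loop-free equivalent of the checkerboard construction above, kept for
# reference (same 8x8 blocks and ±0.5 values, built with broadcasting):
def create_geometric_pattern_vectorized(device):
    idx = torch.arange(64, device=device) // 8
    board = ((idx[:, None] + idx[None, :]) % 2 == 0).float() - 0.5
    pattern = board.expand(1, 3, 64, 64).clone()
    pattern += torch.randn_like(pattern) * 0.1
    return torch.clamp(pattern, -1, 1)
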
					
					
						
def create_abstract_pattern(device):
    """Create an abstract pattern: noise plus interfering sine waves, with
    a different mix per channel."""
    pattern = torch.randn(1, 3, 64, 64, device=device) * 0.4

    x, y = torch.meshgrid(torch.linspace(0, 2 * np.pi, 64), torch.linspace(0, 2 * np.pi, 64), indexing='ij')
    x, y = x.to(device), y.to(device)
    wave1 = torch.sin(x * 2) * torch.cos(y * 3) * 0.3
    wave2 = torch.sin(x * 4 + y * 2) * 0.2
    pattern[0, 0] += wave1
    pattern[0, 1] += wave2
    pattern[0, 2] += (wave1 + wave2) * 0.5
    return torch.clamp(pattern, -1, 1)

def create_high_contrast_version(img):
    """Create a high-contrast version by scaling intensities around zero."""
    return torch.clamp(img * 1.5, -1, 1)

def create_soft_version(img):
    """Create a soft/dreamy version: mild blur plus reduced intensity."""
    return apply_blur(img, sigma=0.8) * 0.8

def create_edge_enhanced_version(img):
    """Create an edge-enhanced version with a 3x3 sharpening kernel."""
    edge_kernel = torch.tensor([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]], dtype=torch.float32)
    edge_kernel = edge_kernel.view(1, 1, 3, 3).repeat(3, 1, 1, 1).to(img.device)
    edge_enhanced = torch.nn.functional.conv2d(img, edge_kernel, padding=1, groups=3)
    return torch.clamp(edge_enhanced, -1, 1)

def save_comparison(images, filepath, labels=None):
    """Save a comparison grid. `labels` documents the panel order at the
    call site; the saved grid itself is unannotated."""
    # Map from [-1, 1] back to [0, 1] for display.
    display_images = torch.clamp((images + 1) / 2, 0, 1)

    # Single row for small sets, two rows otherwise.
    nrow = len(images) if len(images) <= 4 else len(images) // 2
    grid = make_grid(display_images, nrow=nrow, normalize=False, pad_value=1.0)
    save_image(grid, filepath)

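# If the labels should appear in the output image itself, one option is to
# rasterize the grid and draw them with Pillow. A hedged sketch (default
# bitmap font, single-row layout, equal-width cells assumed):
def save_labeled_comparison(images, filepath, labels):
    from PIL import Image, ImageDraw
    display = torch.clamp((images + 1) / 2, 0, 1)
    grid = make_grid(display, nrow=len(images), normalize=False, pad_value=1.0)
    arr = (grid.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    pil = Image.fromarray(arr)
    draw = ImageDraw.Draw(pil)
    cell_w = pil.width // len(labels)
    for i, label in enumerate(labels):
        draw.text((i * cell_w + 4, 2), label, fill=(255, 0, 0))
    pil.save(filepath)
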
					
					
						
if __name__ == "__main__":
    create_test_applications()