import torch from model import SmoothDiffusionUNet from noise_scheduler import FrequencyAwareNoise from config import Config from torchvision.utils import save_image, make_grid from dataloader import get_dataloaders import numpy as np def hybrid_generation(): """Hybrid approach: Use model as super-denoiser rather than pure generator""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Load model checkpoint = torch.load('model_final.pth', map_location=device) config = Config() model = SmoothDiffusionUNet(config).to(device) noise_scheduler = FrequencyAwareNoise(config) model.load_state_dict(checkpoint) model.eval() # Load real training data for smart initialization train_loader, _ = get_dataloaders(config) real_batch, _ = next(iter(train_loader)) real_images = real_batch[:8].to(device) print("=== HYBRID GENERATION APPROACH ===") with torch.no_grad(): # Method 1: Smart noise initialization print("\n--- Method 1: Smart Noise Initialization ---") # Initialize with noise that has similar statistics to training data smart_noise = torch.randn(4, 3, 64, 64, device=device) smart_noise = smart_noise * real_images.std().item() # Match training data std smart_noise = smart_noise + real_images.mean().item() # Match training data mean smart_noise = torch.clamp(smart_noise, -1, 1) print(f"Smart noise stats: mean={smart_noise.mean():.3f}, std={smart_noise.std():.3f}") # Apply progressive denoising timesteps = [150, 120, 90, 70, 50, 35, 25, 15, 8, 3, 1] x = smart_noise.clone() for t_val in timesteps: t_tensor = torch.full((4,), t_val, device=device, dtype=torch.long) predicted_noise = model(x, t_tensor) alpha_bar_t = noise_scheduler.alpha_bars[t_val].item() x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * 0.7) / np.sqrt(alpha_bar_t) x = torch.clamp(x, -1, 1) # Save result smart_display = torch.clamp((x + 1) / 2, 0, 1) smart_grid = make_grid(smart_display, nrow=2, normalize=False) save_image(smart_grid, "smart_noise_generation.png") print(f"Smart noise result: range=[{x.min():.3f}, {x.max():.3f}], std={x.std():.3f}") print("Saved to smart_noise_generation.png") # Method 2: Blended real images + denoising print("\n--- Method 2: Blended Real Images ---") # Create new combinations by blending random real images indices = torch.randint(0, len(real_images), (4, 3)) # Pick 3 random images for each output weights = torch.rand(4, 3, device=device) weights = weights / weights.sum(dim=1, keepdim=True) # Normalize weights blended = torch.zeros(4, 3, 64, 64, device=device) for i in range(4): for j in range(3): blended[i] += weights[i, j] * real_images[indices[i, j]] # Add some noise to make it more interesting noise = torch.randn_like(blended) * 0.15 blended = blended + noise blended = torch.clamp(blended, -1, 1) # Light denoising to clean up light_timesteps = [80, 60, 40, 25, 12, 5, 1] x = blended.clone() for t_val in light_timesteps: t_tensor = torch.full((4,), t_val, device=device, dtype=torch.long) predicted_noise = model(x, t_tensor) alpha_bar_t = noise_scheduler.alpha_bars[t_val].item() x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * 0.5) / np.sqrt(alpha_bar_t) x = torch.clamp(x, -1, 1) # Save result blended_display = torch.clamp((x + 1) / 2, 0, 1) blended_grid = make_grid(blended_display, nrow=2, normalize=False) save_image(blended_grid, "blended_generation.png") print(f"Blended result: range=[{x.min():.3f}, {x.max():.3f}], std={x.std():.3f}") print("Saved to blended_generation.png") # Method 3: Frequency-domain initialization print("\n--- Method 3: Frequency-Domain Initialization ---") # Start with structured noise in frequency domain, then convert to spatial from scipy.fftpack import dctn, idctn freq_images = torch.zeros(4, 3, 64, 64, device=device) for i in range(4): for c in range(3): # Create structured frequency pattern freq_pattern = np.zeros((64, 64)) # Add some low-frequency components (overall shape/color) for u in range(0, 8): for v in range(0, 8): freq_pattern[u, v] = np.random.randn() * (1.0 / (1 + u + v)) # Add some mid-frequency components (textures) for u in range(8, 20): for v in range(8, 20): freq_pattern[u, v] = np.random.randn() * 0.1 # Convert to spatial domain spatial = idctn(freq_pattern, norm='ortho') freq_images[i, c] = torch.from_numpy(spatial).float() # Normalize to training data range freq_images = freq_images.to(device) freq_images = freq_images - freq_images.mean() freq_images = freq_images / freq_images.std() * real_images.std() freq_images = torch.clamp(freq_images, -1, 1) # Apply denoising freq_timesteps = [100, 75, 55, 40, 28, 18, 10, 4, 1] x = freq_images.clone() for t_val in freq_timesteps: t_tensor = torch.full((4,), t_val, device=device, dtype=torch.long) predicted_noise = model(x, t_tensor) alpha_bar_t = noise_scheduler.alpha_bars[t_val].item() x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * 0.6) / np.sqrt(alpha_bar_t) x = torch.clamp(x, -1, 1) # Save result freq_display = torch.clamp((x + 1) / 2, 0, 1) freq_grid = make_grid(freq_display, nrow=2, normalize=False) save_image(freq_grid, "frequency_generation.png") print(f"Frequency result: range=[{x.min():.3f}, {x.max():.3f}], std={x.std():.3f}") print("Saved to frequency_generation.png") print("\n=== RESULTS ===") print("Generated files:") print("- smart_noise_generation.png (noise matching training stats)") print("- blended_generation.png (combinations of real images)") print("- frequency_generation.png (frequency-domain initialization)") print("\nYour model works as a super-denoiser!") print("It can clean up any reasonable starting point to look more image-like.") if __name__ == "__main__": hybrid_generation()