import torch
import numpy as np
from scipy.fftpack import dctn, idctn


class FrequencyAwareNoise:
    """Forward-diffusion noiser whose noise is shaped per patch in DCT space.

    A linear beta schedule is built from ``config.beta_start``,
    ``config.beta_end`` and ``config.T``; ``config.patch_size`` sets the side
    length of the square tiles used by :meth:`apply_noise`.
    """

    def __init__(self, config):
        self.config = config
        self.betas = torch.linspace(config.beta_start, config.beta_end, config.T)
        self.alphas = 1. - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
        # NumPy mirrors of the schedule, kept for callers working in NumPy.
        self.betas_np = self.betas.numpy()
        self.alphas_np = self.alphas.numpy()
        self.alpha_bars_np = self.alpha_bars.numpy()

    @staticmethod
    def _frequency_weights(height, width):
        """Return the ``(height, width)`` noise scale for DCT coefficients.

        The weight grows linearly with the index sum ``u + v``: 0.1 at the DC
        coefficient up to 1.0 at the highest-frequency corner.  A 1x1 tile has
        only a DC term, so the denominator is clamped to 1 (weight 0.1)
        instead of dividing by zero.
        """
        u = np.arange(height)[:, None]
        v = np.arange(width)[None, :]
        max_freq = max(height + width - 2, 1)
        return 0.1 + 0.9 * (u + v) / max_freq

    def apply_noise(self, x0, t, noise=None):
        """Add frequency-shaped noise to ``x0`` using a patch-wise DCT.

        For every ``patch_size`` x ``patch_size`` tile (edge tiles may be
        smaller), white Gaussian noise is expressed as orthonormal DCT
        coefficients, scaled by :meth:`_frequency_weights`, and transformed
        back to the spatial domain.  The orthonormal DCT is linear, so noising
        the DCT coefficients of ``x0`` and inverting equals adding the shaped
        spatial noise to ``x0`` directly; ``x0`` itself is never transformed.

        Args:
            x0: clean batch of shape ``(B, C, H, W)``.
            t: integer timestep(s) - a single value shared by the whole batch
                (int or 1-element tensor) or one value per sample (``B``
                values).
            noise: optional white Gaussian noise with ``x0``'s shape.  Its
                per-tile orthonormal DCT is used as the unshaped DCT-domain
                noise.  When ``None``, noise is drawn from NumPy's global RNG
                (``np.random.randn``), one tile at a time in row-major order.

        Returns:
            ``(xt, noise_spatial)``: the noised batch and the shaped noise in
            the spatial domain (the target a noise-prediction model regresses).
            Both have ``x0``'s shape, dtype and device.

        Raises:
            ValueError: if ``t`` holds neither 1 nor ``B`` timesteps, or
                ``noise`` does not have ``x0``'s shape.
        """
        B, C, H, W = x0.shape
        device = x0.device
        patch_size = self.config.patch_size

        if noise is not None:
            noise = torch.as_tensor(noise)
            if noise.shape != x0.shape:
                raise ValueError(
                    f"noise shape {tuple(noise.shape)} does not match "
                    f"x0 shape {tuple(x0.shape)}"
                )

        # The schedule lookup does not depend on the tile: do it once.
        t_idx = torch.as_tensor(t).detach().cpu().reshape(-1)
        alpha_bars = self.alpha_bars[t_idx]
        if alpha_bars.numel() == 1:
            alpha_bars = alpha_bars.expand(B)
        elif alpha_bars.numel() != B:
            raise ValueError(
                f"t must hold 1 or {B} timesteps, got {alpha_bars.numel()}"
            )
        alpha_bars = alpha_bars.to(device).view(B, 1, 1, 1)

        # Spatial-domain version of the shaped noise (the training target).
        noise_spatial = torch.zeros_like(x0)
        for i in range(0, H, patch_size):
            for j in range(0, W, patch_size):
                rows = slice(i, i + patch_size)
                cols = slice(j, j + patch_size)
                tile_h = min(patch_size, H - i)
                tile_w = min(patch_size, W - j)

                if noise is None:
                    noise_dct = np.random.randn(B, C, tile_h, tile_w)
                else:
                    # The orthonormal DCT of white Gaussian noise is again
                    # white Gaussian noise, so this matches the randn draw.
                    tile = noise[:, :, rows, cols].detach().cpu().numpy()
                    noise_dct = dctn(tile.astype(np.float64), axes=(2, 3), norm='ortho')

                noise_dct = noise_dct * self._frequency_weights(tile_h, tile_w)
                noise_tile = idctn(noise_dct, axes=(2, 3), norm='ortho')
                noise_spatial[:, :, rows, cols] = torch.from_numpy(noise_tile).to(
                    device=device, dtype=noise_spatial.dtype
                )

        xt = alpha_bars.sqrt() * x0 + (1 - alpha_bars).sqrt() * noise_spatial
        return xt.to(x0.dtype), noise_spatial

    def debug_noise_stats(self, x0, t):
        """Run :meth:`apply_noise` and print value ranges and the noise std.

        Returns the ``(xt, noise)`` pair produced by :meth:`apply_noise`.
        """
        xt, noise = self.apply_noise(x0, t)
        print(f"Input range: [{x0.min().item():.4f}, {x0.max().item():.4f}]")
        print(f"Noise range: [{noise.min().item():.4f}, {noise.max().item():.4f}]")
        print(f"Noisy range: [{xt.min().item():.4f}, {xt.max().item():.4f}]")
        print(f"Noise std: {noise.std().item():.4f}")
        return xt, noise