# Grad-CDM / noise_scheduler.py
# nazgut's picture
# Upload 24 files
# 8abfb97 verified
import torch
import numpy as np
from scipy.fftpack import dctn, idctn
class FrequencyAwareNoise:
    """Forward-diffusion noising that injects noise patch-wise in DCT space.

    The image is split into ``config.patch_size`` x ``config.patch_size``
    tiles (edge tiles may be smaller). Each tile is transformed with an
    orthonormal 2-D DCT, Gaussian noise is scaled per coefficient by a weight
    that grows linearly from 0.1 at (u, v) = (0, 0) to 1.0 at the highest
    frequency, and the result is mixed with the clean coefficients using the
    cumulative schedule ``alpha_bars``.

    ``config`` must expose ``beta_start``, ``beta_end``, ``T`` and
    ``patch_size``.
    """

    def __init__(self, config):
        """Build a linear beta schedule of ``config.T`` steps and its cumulants."""
        self.config = config
        self.betas = torch.linspace(config.beta_start, config.beta_end, config.T)
        self.alphas = 1. - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
        # NumPy views of the schedule, used by the scipy-based DCT code path.
        self.betas_np = self.betas.numpy()
        self.alphas_np = self.alphas.numpy()
        self.alpha_bars_np = self.alpha_bars.numpy()

    @staticmethod
    def _frequency_weights(height, width):
        """Return the (height, width) per-coefficient noise weights in [0.1, 1.0]."""
        max_freq = height + width - 2
        if max_freq <= 0:
            # A 1x1 tile only has the DC coefficient; use the DC weight instead
            # of dividing by zero.
            return np.full((height, width), 0.1)
        u = np.arange(height).reshape(-1, 1)
        v = np.arange(width).reshape(1, -1)
        return 0.1 + 0.9 * (u + v) / max_freq

    def _alpha_bars_for(self, t, batch_size):
        """Look up alpha_bar for timestep(s) ``t`` as a (batch_size, 1, 1, 1) array."""
        # Convert explicitly so NumPy indexing never depends on implicit
        # tensor -> array coercion; also accepts a plain int or 0-d tensor.
        t_idx = np.atleast_1d(torch.as_tensor(t).detach().cpu().numpy()).reshape(-1)
        alpha_bars = self.alpha_bars_np[t_idx].reshape(-1, 1, 1, 1)
        if alpha_bars.shape[0] != batch_size:
            # Best-effort fallback kept from the original: when the number of
            # timesteps differs from the batch size, the first timestep is
            # applied to every sample.
            alpha_bars = np.broadcast_to(alpha_bars[0:1], (batch_size, 1, 1, 1))
        return alpha_bars

    def apply_noise(self, x0, t, noise=None):
        """Add frequency-weighted noise to ``x0`` at timestep(s) ``t``.

        Args:
            x0: Clean input tensor of shape (B, C, H, W).
            t: Integer timestep index/indices into the schedule (tensor, array
                or int). One value per sample, or a single value shared by the
                whole batch.
            noise: Optional tensor with the same shape as ``x0`` used as the
                noise source instead of drawing from ``np.random``. It is
                DCT-transformed per patch and then frequency-weighted, so the
                call becomes deterministic for a given ``noise``.

        Returns:
            Tuple ``(xt, noise_spatial)``, both shaped like ``x0``: the noised
            input and the frequency-weighted noise mapped back to the spatial
            domain.

        Raises:
            ValueError: If ``noise`` is given with a shape different from ``x0``.
        """
        B, C, H, W = x0.shape
        device = x0.device
        patch_size = self.config.patch_size

        xt = torch.zeros_like(x0)
        noise_spatial = torch.zeros_like(x0)

        # One device->host transfer for the whole batch; detach so inputs that
        # track gradients can still be converted to NumPy.
        x0_np = x0.detach().cpu().numpy()

        base_noise_np = None
        if noise is not None:
            if tuple(noise.shape) != tuple(x0.shape):
                raise ValueError(
                    f"noise shape {tuple(noise.shape)} does not match "
                    f"x0 shape {tuple(x0.shape)}"
                )
            base_noise_np = noise.detach().cpu().numpy()

        # Schedule terms do not depend on the patch, so compute them once.
        alpha_bars = self._alpha_bars_for(t, B)
        signal_scale = np.sqrt(alpha_bars)
        noise_scale = np.sqrt(1 - alpha_bars)

        # Weights depend only on the tile shape (edge tiles may be smaller).
        weights_by_shape = {}

        for i in range(0, H, patch_size):
            for j in range(0, W, patch_size):
                rows = slice(i, i + patch_size)
                cols = slice(j, j + patch_size)

                # DCT per patch
                dct = dctn(x0_np[:, :, rows, cols], axes=(2, 3), norm='ortho')

                # Noise in the DCT domain
                if base_noise_np is None:
                    noise_dct = np.random.randn(*dct.shape)
                else:
                    noise_dct = dctn(
                        base_noise_np[:, :, rows, cols], axes=(2, 3), norm='ortho'
                    )

                # Frequency-dependent scaling
                tile_shape = dct.shape[2:]
                weights = weights_by_shape.get(tile_shape)
                if weights is None:
                    weights = self._frequency_weights(*tile_shape)
                    weights_by_shape[tile_shape] = weights
                noise_dct = noise_dct * weights

                # Mix signal and noise in the DCT domain, then invert.
                noisy_dct = signal_scale * dct + noise_scale * noise_dct
                noisy_patch = idctn(noisy_dct, axes=(2, 3), norm='ortho')
                # The returned noise is the weighted noise in the spatial
                # domain, i.e. the quantity that was actually mixed in.
                noise_patch_spatial = idctn(noise_dct, axes=(2, 3), norm='ortho')

                xt[:, :, rows, cols] = torch.from_numpy(noisy_patch).float().to(device)
                noise_spatial[:, :, rows, cols] = (
                    torch.from_numpy(noise_patch_spatial).float().to(device)
                )

        return xt, noise_spatial

    def debug_noise_stats(self, x0, t):
        """Run ``apply_noise`` and print value ranges and the noise std.

        Returns the same ``(xt, noise)`` tuple as ``apply_noise``.
        """
        xt, noise = self.apply_noise(x0, t)
        print(f"Input range: [{x0.min().item():.4f}, {x0.max().item():.4f}]")
        print(f"Noise range: [{noise.min().item():.4f}, {noise.max().item():.4f}]")
        print(f"Noisy range: [{xt.min().item():.4f}, {xt.max().item():.4f}]")
        print(f"Noise std: {noise.std().item():.4f}")
        return xt, noise