|
import torch |
|
import numpy as np |
|
from scipy.fftpack import dctn, idctn |
|
|
|
class FrequencyAwareNoise:
    """Forward-diffusion noiser with frequency-weighted, patch-wise DCT noise.

    The image is tiled into ``config.patch_size`` x ``config.patch_size``
    patches (patches on the right/bottom edge may be smaller). White Gaussian
    noise is taken as each patch's orthonormal 2-D DCT coefficients. Each
    coefficient ``(u, v)`` is scaled by ``0.1 + 0.9 * (u + v) / (h + w - 2)``
    for an ``h`` x ``w`` patch, and the result is transformed back to pixel
    space. With this weighting the DC term gets 0.1 and the highest frequency
    gets 1.0. A 1x1 patch has only a DC term, which gets 0.1.

    Attributes read from ``config``: ``beta_start``, ``beta_end`` and ``T``
    (linear beta schedule) and ``patch_size``.
    """

    def __init__(self, config):
        self.config = config

        # Linear beta schedule and the derived alpha / cumulative alpha-bar terms.
        self.betas = torch.linspace(config.beta_start, config.beta_end, config.T)
        self.alphas = 1. - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

        # NumPy mirrors of the schedule (Tensor.numpy() shares memory with
        # these CPU tensors).
        self.betas_np = self.betas.numpy()
        self.alphas_np = self.alphas.numpy()
        self.alpha_bars_np = self.alpha_bars.numpy()

    def _alpha_bars_for(self, t, batch_size):
        """Return alpha_bar[t] as a ``(batch_size, 1, 1, 1)`` float tensor on CPU.

        ``t`` may be an int, a scalar tensor, or a 1-D tensor/sequence of
        length 1 (shared by the whole batch) or ``batch_size``.
        """
        # Index on CPU so a CUDA ``t`` can index the CPU-resident schedule.
        t_idx = torch.as_tensor(t).detach().cpu().reshape(-1)
        alpha_bars = self.alpha_bars[t_idx]
        if alpha_bars.numel() == 1:
            alpha_bars = alpha_bars.expand(batch_size)
        elif alpha_bars.numel() != batch_size:
            raise ValueError(
                f"t has {alpha_bars.numel()} timesteps; expected 1 or {batch_size}"
            )
        return alpha_bars.reshape(-1, 1, 1, 1)

    @staticmethod
    def _frequency_weights(h, w):
        """Return the ``(h, w)`` DCT noise-weight matrix for one patch."""
        # Clamp the divisor: a 1x1 patch has h + w - 2 == 0 (DC only -> 0.1).
        max_freq = max(h + w - 2, 1)
        u = np.arange(h).reshape(-1, 1)
        v = np.arange(w).reshape(1, -1)
        return 0.1 + 0.9 * (u + v) / max_freq

    def apply_noise(self, x0, t, noise=None):
        """Noise ``x0`` to timestep ``t`` with frequency-weighted patch DCT noise.

        Args:
            x0: Clean batch of shape ``(B, C, H, W)``.
            t: Timestep index(es) into the schedule. Accepts an int, a scalar
                tensor, or a 1-D tensor/sequence of length 1 (shared by the
                batch) or ``B``.
            noise: Optional white (unweighted) Gaussian noise of shape
                ``(B, C, H, W)``. It is interpreted as each patch's DCT
                coefficients before frequency weighting. When ``None``, it is
                drawn patch by patch from NumPy's global RNG, so seed it with
                ``np.random.seed`` for reproducibility.

        Returns:
            A tuple ``(xt, noise_spatial)``:

            - ``noise_spatial`` is the frequency-weighted noise in pixel space.
            - ``xt = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise_spatial``.

            Both tensors have ``x0``'s dtype and device, and neither is
            connected to ``x0``'s autograd graph.

        Raises:
            ValueError: If ``patch_size < 1``, if ``t`` has a length other than
                1 or B, or if ``noise`` has the wrong shape.
        """
        B, C, H, W = x0.shape
        device = x0.device
        patch_size = self.config.patch_size
        if patch_size < 1:
            raise ValueError(f"patch_size must be >= 1, got {patch_size}")

        # Hoisted out of the patch loop: the timestep lookup is patch-invariant.
        alpha_bars = self._alpha_bars_for(t, B).to(device=device, dtype=x0.dtype)

        white = None
        if noise is not None:
            white = torch.as_tensor(noise).detach().cpu().double().numpy()
            if white.shape != (B, C, H, W):
                raise ValueError(
                    f"noise has shape {white.shape}; expected {(B, C, H, W)}"
                )

        noise_np = np.zeros((B, C, H, W), dtype=np.float64)
        for i in range(0, H, patch_size):
            ph = min(patch_size, H - i)
            for j in range(0, W, patch_size):
                pw = min(patch_size, W - j)
                if white is None:
                    # Same per-patch draw order/shape as before, so runs seeded
                    # via np.random.seed keep their noise stream.
                    coeffs = np.random.randn(B, C, ph, pw)
                else:
                    coeffs = white[:, :, i:i + ph, j:j + pw]
                weighted = coeffs * self._frequency_weights(ph, pw)
                noise_np[:, :, i:i + ph, j:j + pw] = idctn(
                    weighted, axes=(2, 3), norm='ortho'
                )

        noise_spatial = torch.from_numpy(noise_np).to(device=device, dtype=x0.dtype)

        # The DCT is linear, so mixing in the DCT domain and inverting equals
        # mixing x0 with the inverse-transformed noise directly in pixel space.
        xt = alpha_bars.sqrt() * x0.detach() + (1 - alpha_bars).sqrt() * noise_spatial
        return xt, noise_spatial

    def debug_noise_stats(self, x0, t):
        """Noise ``x0`` at ``t``, print value ranges and the noise std, and return ``(xt, noise)``."""
        xt, noise = self.apply_noise(x0, t)
        print(f"Input range: [{x0.min().item():.4f}, {x0.max().item():.4f}]")
        print(f"Noise range: [{noise.min().item():.4f}, {noise.max().item():.4f}]")
        print(f"Noisy range: [{xt.min().item():.4f}, {xt.max().item():.4f}]")
        print(f"Noise std: {noise.std().item():.4f}")
        return xt, noise