import torch
import numpy as np


class FrequencyAwareNoise:
    """Forward-diffusion noise scheduler.

    NOTE(review): despite the class name, the current implementation applies
    standard (frequency-agnostic) DDPM noise; frequency-aware behaviour is not
    implemented yet.

    Args:
        config: object exposing ``beta_start``, ``beta_end`` and ``T`` (the
            number of diffusion steps). Betas are spaced linearly from
            ``beta_start`` to ``beta_end`` over ``T`` steps.
    """

    def __init__(self, config):
        self.config = config
        # Linear beta schedule with T entries.
        self.betas = torch.linspace(config.beta_start, config.beta_end, config.T)
        self.alphas = 1. - self.betas
        # alpha_bar_t = prod_{s<=t} alpha_s, indexed by timestep t in [0, T).
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def apply_noise(self, x0, t, noise=None):
        """Standard DDPM forward process: sample x_t given x_0.

        Computes ``x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise``.

        Args:
            x0: clean input tensor whose leading dimension is the batch
                (any rank; previously only 4-D inputs broadcast correctly).
            t: timestep index/indices into the schedule -- a Python int, a
                sequence, or an integer tensor with one entry per batch item
                (a single timestep broadcasts over the whole batch).
            noise: optional noise tensor shaped like ``x0``; standard normal
                noise is drawn when omitted.

        Returns:
            Tuple ``(xt, noise)`` -- the noised tensor and the noise used.
        """
        if noise is None:
            noise = torch.randn_like(x0)

        device = x0.device
        # Put both the schedule and the indices on x0's device so indexing
        # works regardless of where the caller created `t`.
        alpha_bars = self.alpha_bars.to(device)
        t = torch.as_tensor(t, device=device)
        alpha_bar_t = alpha_bars[t]

        # Shape (B, 1, ..., 1) with as many trailing singleton dims as x0 has
        # non-batch dims, so the per-sample coefficient broadcasts over every
        # feature dim of x0 whatever its rank.
        alpha_bar_t = alpha_bar_t.reshape(-1, *([1] * (x0.dim() - 1)))

        xt = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * noise
        return xt, noise

    def debug_noise_stats(self, x0, t):
        """Noise ``x0`` at timestep(s) ``t`` and print range/std diagnostics.

        Returns:
            The ``(xt, noise)`` tuple from :meth:`apply_noise`.
        """
        xt, noise = self.apply_noise(x0, t)
        print(f"Input range: [{x0.min().item():.4f}, {x0.max().item():.4f}]")
        print(f"Noise range: [{noise.min().item():.4f}, {noise.max().item():.4f}]")
        print(f"Noisy range: [{xt.min().item():.4f}, {xt.max().item():.4f}]")
        print(f"Noise std: {noise.std().item():.4f}")
        return xt, noise