|
import torch |
|
import numpy as np |
|
|
|
class FrequencyAwareNoise:
    """Forward-diffusion (noising) process with a linear beta schedule.

    Despite the class name, the current implementation applies standard
    isotropic DDPM Gaussian noise; no frequency-dependent weighting is done.

    Args:
        config: Object exposing ``beta_start``, ``beta_end`` and ``T``
            (the number of diffusion steps in the schedule).
    """

    def __init__(self, config):
        self.config = config
        # T betas spaced linearly from beta_start to beta_end (both inclusive).
        self.betas = torch.linspace(config.beta_start, config.beta_end, config.T)
        self.alphas = 1. - self.betas
        # alpha_bar[t] = alpha[0] * ... * alpha[t]; lets x_t be sampled from x_0
        # in a single step instead of iterating the chain.
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def apply_noise(self, x0, t, noise=None):
        """Standard DDPM noise application: sample x_t directly from x_0.

        Computes ``xt = sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * noise``.

        Args:
            x0: Clean input tensor of any rank; its first dimension is treated
                as the batch dimension when ``t`` holds one step per sample.
            t: Timestep index into the schedule -- a Python int (same step for
                every sample) or an integer tensor of shape ``(batch,)``.
            noise: Optional noise tensor shaped like ``x0``. When omitted,
                standard-normal noise is drawn with ``torch.randn_like(x0)``.

        Returns:
            Tuple ``(xt, noise)``: the noised tensor and the noise that was used.
        """
        if noise is None:
            noise = torch.randn_like(x0)

        # Move the schedule to x0's device before indexing, so a device-resident
        # `t` can index it.
        alpha_bars = self.alpha_bars.to(x0.device)

        # Shape the coefficient as (batch, 1, ..., 1) with as many trailing
        # singleton dims as x0 has, so it broadcasts per sample for inputs of
        # any rank (a fixed 4-D view mis-broadcasts e.g. 2-D or 3-D inputs).
        coef_shape = (-1,) + (1,) * (x0.dim() - 1)
        alpha_bar_t = alpha_bars[t].reshape(coef_shape)

        xt = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * noise
        return xt, noise

    def debug_noise_stats(self, x0, t):
        """Noise ``x0`` at step ``t`` with fresh noise and print diagnostics.

        Prints the min/max of the input, the noise and the noised tensor, plus
        the noise standard deviation.

        Returns:
            The ``(xt, noise)`` pair produced by :meth:`apply_noise`.
        """
        xt, noise = self.apply_noise(x0, t)
        print(f"Input range: [{x0.min().item():.4f}, {x0.max().item():.4f}]")
        print(f"Noise range: [{noise.min().item():.4f}, {noise.max().item():.4f}]")
        print(f"Noisy range: [{xt.min().item():.4f}, {xt.max().item():.4f}]")
        print(f"Noise std: {noise.std().item():.4f}")
        return xt, noise
|
|