File size: 1,433 Bytes
8abfb97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import torch
import numpy as np
class FrequencyAwareNoise:
    """Forward-diffusion (noising) helper built on a linear beta schedule.

    Despite the class name, ``apply_noise`` currently implements the
    standard, frequency-agnostic DDPM forward process. The schedule holds
    ``T`` betas spaced linearly from ``beta_start`` to ``beta_end``.
    """

    def __init__(self, config):
        """Precompute the noise schedule.

        Args:
            config: Any object exposing ``beta_start``, ``beta_end`` and
                ``T`` (number of diffusion steps) attributes.
        """
        self.config = config
        # beta_t for t = 0 .. T-1, linearly spaced with inclusive endpoints.
        self.betas = torch.linspace(config.beta_start, config.beta_end, config.T)
        self.alphas = 1.0 - self.betas
        # alpha_bar_t = prod_{s <= t} alpha_s; lets apply_noise jump from x0
        # straight to xt in one step.
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def apply_noise(self, x0, t, noise=None):
        """Noise clean samples with the standard DDPM closed form.

        xt = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise

        Args:
            x0: Clean samples with the batch dimension first. Inputs of any
                rank are accepted; the per-sample coefficients broadcast over
                all trailing dimensions.
            t: Timestep index (int) or integer tensor of per-sample indices
                into the schedule (length ``config.T``), i.e. 0 .. T-1.
            noise: Optional noise tensor shaped like ``x0``. When omitted it
                is drawn with ``torch.randn_like(x0)``.

        Returns:
            Tuple ``(xt, noise)``, where ``noise`` is the exact tensor that was
            mixed in, so callers can reuse it.
        """
        if noise is None:
            noise = torch.randn_like(x0)

        # Keep the schedule on the same device as the data.
        alpha_bars = self.alpha_bars.to(x0.device)

        # Shape the coefficients as (batch, 1, ..., 1) with x0's rank. A fixed
        # view(-1, 1, 1, 1) only suits 4-D inputs. For other ranks it
        # broadcasts silently into the wrong shape; e.g. a (B, C, L) input
        # would yield a (B, B, C, L) result.
        trailing = (1,) * max(x0.dim() - 1, 0)
        alpha_bar_t = alpha_bars[t].reshape(-1, *trailing)

        xt = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * noise
        return xt, noise

    def debug_noise_stats(self, x0, t):
        """Noise ``x0`` at timestep(s) ``t`` and print diagnostic statistics.

        Calls ``apply_noise`` with freshly drawn noise, then prints:
        - the min/max of the input, the noise and the noised result;
        - the noise standard deviation.

        Returns:
            The ``(xt, noise)`` tuple produced by ``apply_noise``.
        """
        xt, noise = self.apply_noise(x0, t)
        print(f"Input range: [{x0.min().item():.4f}, {x0.max().item():.4f}]")
        print(f"Noise range: [{noise.min().item():.4f}, {noise.max().item():.4f}]")
        print(f"Noisy range: [{xt.min().item():.4f}, {xt.max().item():.4f}]")
        print(f"Noise std: {noise.std().item():.4f}")
        return xt, noise
|