File size: 1,433 Bytes
8abfb97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import numpy as np

class FrequencyAwareNoise:
    """Forward-diffusion noise scheduler with a linear beta schedule.

    Despite the class name, ``apply_noise`` currently implements the
    standard DDPM forward process with no frequency-dependent weighting.

    Args:
        config: Object exposing ``beta_start`` and ``beta_end`` (endpoints of
            the linear beta schedule) and ``T`` (number of diffusion steps).
    """

    def __init__(self, config):
        self.config = config
        # beta_t for t = 0..T-1, linearly spaced. Created on the default
        # device and moved to the input's device on demand in apply_noise.
        self.betas = torch.linspace(config.beta_start, config.beta_end, config.T)
        self.alphas = 1. - self.betas
        # alpha_bar_t = prod_{s <= t} alpha_s
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def apply_noise(self, x0, t, noise=None):
        """Sample x_t from q(x_t | x_0) using the standard DDPM closed form.

        x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise

        Args:
            x0: Clean input tensor whose first dimension is the batch
                dimension. Any number of trailing dimensions is supported.
            t: Integer timestep(s) indexing the schedule. May be a Python
                int, a 0-d tensor, or a 1-D integer tensor/sequence with one
                entry per batch element. Values are not range-checked here
                (doing so would force a device sync on CUDA).
            noise: Optional noise tensor broadcastable to ``x0``. When
                omitted it is drawn with ``torch.randn_like(x0)``.

        Returns:
            Tuple ``(xt, noise)``: the noised tensor and the noise used.
        """
        if noise is None:
            noise = torch.randn_like(x0)

        device = x0.device

        # Move the schedule (a no-op if already there) and the timesteps to
        # x0's device so indexing works for CPU and CUDA inputs alike.
        alpha_bars = self.alpha_bars.to(device)
        t = torch.as_tensor(t, device=device)

        # Shape (batch, 1, ..., 1) with the same rank as x0, so the per-sample
        # coefficient broadcasts over every non-batch dimension. A fixed
        # view(-1, 1, 1, 1) would silently mis-broadcast non-4-D inputs.
        coef_shape = (-1,) + (1,) * (x0.dim() - 1)
        alpha_bar_t = alpha_bars[t].reshape(coef_shape)

        # Standard DDPM: xt = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
        xt = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * noise

        return xt, noise

    def debug_noise_stats(self, x0, t):
        """Noise ``x0`` at timestep(s) ``t`` and print range/std diagnostics.

        Args:
            x0: Clean input tensor, as accepted by ``apply_noise``.
            t: Timestep(s), as accepted by ``apply_noise``.

        Returns:
            Tuple ``(xt, noise)`` from ``apply_noise``.
        """
        xt, noise = self.apply_noise(x0, t)
        print(f"Input range: [{x0.min().item():.4f}, {x0.max().item():.4f}]")
        print(f"Noise range: [{noise.min().item():.4f}, {noise.max().item():.4f}]")
        print(f"Noisy range: [{xt.min().item():.4f}, {xt.max().item():.4f}]")
        print(f"Noise std: {noise.std().item():.4f}")
        return xt, noise