File size: 6,248 Bytes
8abfb97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import torch
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from config import Config
from torchvision.utils import save_image, make_grid
import numpy as np

def deterministic_sample(model, noise_scheduler, device, n_samples=4):
    """Generate samples with a short, hand-picked schedule of large denoising steps.

    Despite the name, the procedure is not seed-independent: the starting
    tensor comes from ``torch.randn`` and every non-final step adds a small
    amount of fresh ``torch.randn_like`` noise.

    Args:
        model: Callable invoked as ``model(x, t)``; its output is treated as
            the predicted noise. Switched to eval mode here and left there.
        noise_scheduler: Object exposing ``alpha_bars``, indexed by timestep
            (indices up to 400 are read).
        device: Device on which the sample and timestep tensors are created.
        n_samples: Number of images to generate.

    Returns:
        Tuple ``(x, grid)``: ``x`` is the final batch clamped to [-1, 1] and
        ``grid`` is the image grid built from ``x`` rescaled to [0, 1].

    Side effects:
        Prints progress to stdout and writes ``simplified_samples.png``.
    """
    cfg = Config()
    model.eval()

    with torch.no_grad():
        # Damped Gaussian start (half the standard deviation of unit noise).
        shape = (n_samples, 3, cfg.image_size, cfg.image_size)
        x = 0.5 * torch.randn(shape, device=device)

        print(f"Starting simplified sampling for {n_samples} samples...")

        # A dozen coarse steps instead of walking the full diffusion chain.
        timesteps = [400, 300, 200, 150, 100, 70, 50, 30, 20, 10, 5, 1]
        last_index = len(timesteps) - 1

        for i, t_val in enumerate(timesteps):
            print(f"Step {i+1}/{len(timesteps)}, t={t_val}")

            step_ids = torch.full((n_samples,), t_val, device=device, dtype=torch.long)
            eps_hat = model(x, step_ids)

            # Invert the forward-noising relation to estimate the clean image.
            abar_now = noise_scheduler.alpha_bars[t_val].item()
            x0_hat = (x - np.sqrt(1 - abar_now) * eps_hat) / np.sqrt(abar_now)
            x0_hat = torch.clamp(x0_hat, -1, 1)

            if i == last_index:
                # Last step: keep the clean estimate as-is.
                x = x0_hat
            else:
                # Re-noise the estimate to the level of the next timestep,
                # but with the injected noise shrunk to a tenth.
                abar_next = noise_scheduler.alpha_bars[timesteps[i + 1]].item()
                sigma_next = np.sqrt(1 - abar_next)
                jitter = torch.randn_like(x) * 0.1
                x = np.sqrt(abar_next) * x0_hat + sigma_next * jitter

            # Keep values from drifting far outside the image range.
            x = torch.clamp(x, -1.5, 1.5)

            if i % 3 == 0:
                print(f"  Current range: [{x.min():.3f}, {x.max():.3f}], std: {x.std():.3f}")

        # Bring the result back into the nominal [-1, 1] image range.
        x = torch.clamp(x, -1, 1)

        print("Final samples:")
        print(f"  Range: [{x.min():.3f}, {x.max():.3f}]")
        print(f"  Mean: {x.mean():.3f}, Std: {x.std():.3f}")

        # Map [-1, 1] -> [0, 1] for saving as an image.
        x_display = torch.clamp((x + 1) / 2, 0, 1)

        grid = make_grid(x_display, nrow=2, normalize=False, pad_value=1.0)
        save_image(grid, "simplified_samples.png")
        print("Samples saved to simplified_samples.png")

        return x, grid

def progressive_sample(model, noise_scheduler, device, n_samples=4):
    """Denoise step by step from timestep 199 down to 0, starting from mild noise.

    Each step with ``t > 0`` forms a clamped clean-image estimate, combines it
    with the current sample into a posterior mean, and (for ``t > 1``) adds
    posterior noise scaled by one half. The ``t == 0`` step applies the
    clean-image formula directly, without the [-1, 1] clamp on the estimate.

    Args:
        model: Callable invoked as ``model(x, t)``; its output is treated as
            the predicted noise. Switched to eval mode here and left there.
        noise_scheduler: Object exposing ``alphas``, ``alpha_bars`` and
            ``betas``, each indexed by timestep (indices 0..199 are read).
        device: Device on which the sample and timestep tensors are created.
        n_samples: Number of images to generate.

    Returns:
        Tuple ``(x, grid)``: ``x`` is the final batch clamped to [-1, 1] and
        ``grid`` is the image grid built from ``x`` rescaled to [0, 1].

    Side effects:
        Prints progress to stdout and writes ``progressive_samples.png``.
    """
    cfg = Config()
    model.eval()

    with torch.no_grad():
        # Low-amplitude Gaussian start rather than unit-variance noise.
        shape = (n_samples, 3, cfg.image_size, cfg.image_size)
        x = 0.3 * torch.randn(shape, device=device)

        print(f"Starting progressive denoising for {n_samples} samples...")

        # Begin part-way through the chain rather than at its noisiest end.
        start_t = 200

        for step, t in enumerate(range(start_t - 1, -1, -1)):
            if step % 50 == 0:
                print(f"Denoising step {step}/{start_t}, t={t}")

            step_ids = torch.full((n_samples,), t, device=device, dtype=torch.long)
            eps_hat = model(x, step_ids)

            # Per-timestep schedule scalars, read as plain Python floats.
            alpha = noise_scheduler.alphas[t].item()
            abar = noise_scheduler.alpha_bars[t].item()
            beta = noise_scheduler.betas[t].item()

            if t == 0:
                # Final step: unclamped clean-image estimate.
                x = (x - np.sqrt(1 - abar) * eps_hat) / np.sqrt(abar)
            else:
                abar_prev = noise_scheduler.alpha_bars[t - 1].item()

                # Clean-image estimate, limited to the valid image range.
                x0_hat = (x - np.sqrt(1 - abar) * eps_hat) / np.sqrt(abar)
                x0_hat = torch.clamp(x0_hat, -1, 1)

                # Posterior mean: weighted mix of current sample and estimate.
                w_xt = np.sqrt(alpha) * (1 - abar_prev) / (1 - abar)
                w_x0 = np.sqrt(abar_prev) * beta / (1 - abar)
                mean = w_xt * x + w_x0 * x0_hat

                if t > 1:
                    # Inject posterior noise at half strength.
                    variance = beta * (1 - abar_prev) / (1 - abar)
                    fresh = torch.randn_like(x)
                    x = mean + np.sqrt(variance) * fresh * 0.5
                else:
                    # t == 1: take the mean without adding noise.
                    x = mean

            # Loose clamp so intermediate samples cannot run away.
            x = torch.clamp(x, -1.2, 1.2)

        x = torch.clamp(x, -1, 1)

        print("Progressive samples:")
        print(f"  Range: [{x.min():.3f}, {x.max():.3f}]")
        print(f"  Mean: {x.mean():.3f}, Std: {x.std():.3f}")

        # Map [-1, 1] -> [0, 1] for saving as an image.
        x_display = torch.clamp((x + 1) / 2, 0, 1)
        grid = make_grid(x_display, nrow=2, normalize=False, pad_value=1.0)
        save_image(grid, "progressive_samples.png")
        print("Samples saved to progressive_samples.png")

        return x, grid

def main():
    """Load the trained model from ``model_final.pth`` and run both samplers.

    Uses CUDA when available, otherwise the CPU. Each sampler is run with
    four samples, prints its own progress and writes its own PNG file.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    state_dict = torch.load('model_final.pth', map_location=device)
    cfg = Config()

    net = SmoothDiffusionUNet(cfg).to(device)
    scheduler = FrequencyAwareNoise(cfg)
    net.load_state_dict(state_dict)

    # Banner text paired with the sampler it announces, in execution order.
    runs = (
        ("=== TRYING DETERMINISTIC SAMPLING ===", deterministic_sample),
        ("\n=== TRYING PROGRESSIVE SAMPLING ===", progressive_sample),
    )
    for banner, sampler in runs:
        print(banner)
        sampler(net, scheduler, device, n_samples=4)

# Run the sampling demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()