import torch
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from config import Config
from torchvision.utils import save_image, make_grid
import numpy as np


def deterministic_sample(model, noise_scheduler, device, n_samples=4):
    """Deterministic sampling - just do a few big denoising steps"""
    config = Config()
    model.eval()

    with torch.no_grad():
        # Start with noise but not too extreme
        x = torch.randn(n_samples, 3, config.image_size, config.image_size, device=device) * 0.5

        print(f"Starting simplified sampling for {n_samples} samples...")

        # Use fewer, bigger steps - more like denoising than full diffusion
        timesteps = [400, 300, 200, 150, 100, 70, 50, 30, 20, 10, 5, 1]

        for i, t_val in enumerate(timesteps):
            print(f"Step {i+1}/{len(timesteps)}, t={t_val}")

            t_tensor = torch.full((n_samples,), t_val, device=device, dtype=torch.long)

            # Get model prediction
            predicted_noise = model(x, t_tensor)

            # Simple denoising step
            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()

            # Predict clean image
            pred_x0 = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
            pred_x0 = torch.clamp(pred_x0, -1, 1)
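            # Note (added): the line above inverts the standard DDPM forward relation
            # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
            # i.e. x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_bar_t).
            # This assumes noise_scheduler.alpha_bars holds the cumulative products of
            # (1 - beta_t) and that the model predicts the added noise eps.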

            # Move towards clean prediction
            if i < len(timesteps) - 1:
                # Not final step - blend
                next_t = timesteps[i + 1]
                alpha_bar_next = noise_scheduler.alpha_bars[next_t].item()

                # Add some noise for next step
                noise_scale = np.sqrt(1 - alpha_bar_next)
                noise = torch.randn_like(x) * 0.1  # Much less noise

                x = np.sqrt(alpha_bar_next) * pred_x0 + noise_scale * noise
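                # Note (added): this jump re-applies the forward process at the next
                # timestep, x_next = sqrt(alpha_bar_next) * x0_hat + sqrt(1 - alpha_bar_next) * eps,
                # except the fresh noise is deliberately damped (0.1x), so the trajectory
                # is far less stochastic than a full DDPM reverse chain.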
            else:
                # Final step
                x = pred_x0

            x = torch.clamp(x, -1.5, 1.5)  # Prevent drift

            if i % 3 == 0:
                print(f"  Current range: [{x.min():.3f}, {x.max():.3f}], std: {x.std():.3f}")

        # Final clamp
        x = torch.clamp(x, -1, 1)

        print("Final samples:")
        print(f"  Range: [{x.min():.3f}, {x.max():.3f}]")
        print(f"  Mean: {x.mean():.3f}, Std: {x.std():.3f}")

        # Convert to display range
        x_display = torch.clamp((x + 1) / 2, 0, 1)

        # Create and save grid
        grid = make_grid(x_display, nrow=2, normalize=False, pad_value=1.0)
        save_image(grid, "simplified_samples.png")
        print("Samples saved to simplified_samples.png")

    return x, grid


def progressive_sample(model, noise_scheduler, device, n_samples=4):
    """Progressive denoising - start from less noise"""
    config = Config()
    model.eval()

    with torch.no_grad():
        # Start from moderately noisy image instead of pure noise
        x = torch.randn(n_samples, 3, config.image_size, config.image_size, device=device) * 0.3

        print(f"Starting progressive denoising for {n_samples} samples...")

        # Start from a moderate timestep instead of maximum noise
        start_t = 200

        for step, t in enumerate(reversed(range(0, start_t))):
            if step % 50 == 0:
                print(f"Denoising step {step}/{start_t}, t={t}")

            t_tensor = torch.full((n_samples,), t, device=device, dtype=torch.long)

            # Get prediction
            predicted_noise = model(x, t_tensor)

            # Standard DDPM step but with more stability
            alpha_t = noise_scheduler.alphas[t].item()
            alpha_bar_t = noise_scheduler.alpha_bars[t].item()
            beta_t = noise_scheduler.betas[t].item()

            if t > 0:
                alpha_bar_prev = noise_scheduler.alpha_bars[t-1].item()

                # Predict x0
                pred_x0 = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
                pred_x0 = torch.clamp(pred_x0, -1, 1)

                # Posterior mean
                coeff1 = np.sqrt(alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
                coeff2 = np.sqrt(alpha_bar_prev) * beta_t / (1 - alpha_bar_t)
                mean = coeff1 * x + coeff2 * pred_x0
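                # Reference (added): coeff1 and coeff2 above are the standard DDPM
                # posterior-mean coefficients for q(x_{t-1} | x_t, x_0):
                #   mu_t = sqrt(alpha_t) * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * x_t
                #        + sqrt(alpha_bar_{t-1}) * beta_t / (1 - alpha_bar_t) * x0_hat
                # and posterior_variance below is beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t),
                # assuming alphas = 1 - betas and alpha_bars is their cumulative product.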

                # Reduced noise for stability
                if t > 1:
                    posterior_variance = beta_t * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
                    noise = torch.randn_like(x)
                    # Reduce noise by half for more stability
                    x = mean + np.sqrt(posterior_variance) * noise * 0.5
                else:
                    x = mean
            else:
                x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)

            # Gentle clamping
            x = torch.clamp(x, -1.2, 1.2)

        x = torch.clamp(x, -1, 1)

        print("Progressive samples:")
        print(f"  Range: [{x.min():.3f}, {x.max():.3f}]")
        print(f"  Mean: {x.mean():.3f}, Std: {x.std():.3f}")

        x_display = torch.clamp((x + 1) / 2, 0, 1)
        grid = make_grid(x_display, nrow=2, normalize=False, pad_value=1.0)
        save_image(grid, "progressive_samples.png")
        print("Samples saved to progressive_samples.png")

    return x, grid


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    checkpoint = torch.load('model_final.pth', map_location=device)

    config = Config()
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)

    model.load_state_dict(checkpoint)
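    # Note (added): the call above assumes 'model_final.pth' stores a raw state_dict
    # (saved via torch.save(model.state_dict(), ...)). If training instead saved a
    # wrapper dict, e.g. {'model_state_dict': ...} (hypothetical key), unwrap it
    # before calling load_state_dict.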
print("=== TRYING DETERMINISTIC SAMPLING ===")
deterministic_sample(model, noise_scheduler, device, n_samples=4)
print("\n=== TRYING PROGRESSIVE SAMPLING ===")
progressive_sample(model, noise_scheduler, device, n_samples=4)
if __name__ == "__main__":
main()