# Grad-CDM / sample_simple.py
# Uploaded by nazgut ("Upload 24 files", commit 8abfb97, 2.98 kB).
import torch
import torchvision
from torchvision.utils import save_image
import os
from config import Config
def simple_sample(model, noise_scheduler, device, epoch=None, writer=None, n_samples=4):
    """Draw images from ``model`` with the standard DDPM ancestral sampler.

    Starts from Gaussian noise and runs ``config.T`` reverse steps. At each step
    the model predicts the noise, x0 is reconstructed from it, and (except at
    the last step) x_{t-1} is sampled from the Gaussian posterior
    q(x_{t-1} | x_t, x0).

    Args:
        model: noise-prediction network, called as ``model(x, t_tensor)``.
        noise_scheduler: object exposing ``alphas``, ``alpha_bars`` and
            ``betas`` tensors indexed by integer timestep.
        device: device on which sampling runs.
        epoch: optional epoch number. When given, the grid is saved to
            ``samples/epoch_{epoch}.png``, and sample statistics are printed
            every 10 epochs.
        writer: optional TensorBoard-style writer; receives the grid via
            ``add_image('Samples', grid, epoch)``.
        n_samples: number of images to generate.

    Returns:
        ``(x, grid)``: ``x`` is the batch of samples clamped to [-1, 1], and
        ``grid`` is an image grid of those samples mapped linearly to [0, 1].
    """
    config = Config()
    # Remember the caller's mode. This function is called during training, and
    # leaving the model in eval mode afterwards would silently disable
    # dropout/batch-norm updates for the rest of training.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Start with random noise
            x = torch.randn(n_samples, 3, config.image_size, config.image_size, device=device)
            print(f"Starting reverse diffusion for {n_samples} samples...")
            # Move scheduler tensors to device
            alphas = noise_scheduler.alphas.to(device)
            alpha_bars = noise_scheduler.alpha_bars.to(device)
            betas = noise_scheduler.betas.to(device)
            # Reverse diffusion process: t = T-1, ..., 0
            for step, t in enumerate(reversed(range(config.T))):
                if step % 100 == 0:
                    print(f"Step {step}/{config.T}, t={t}")
                t_tensor = torch.full((n_samples,), t, device=device, dtype=torch.long)
                # Predict the noise component of x_t
                pred_noise = model(x, t_tensor)
                alpha_t = alphas[t]
                alpha_bar_t = alpha_bars[t]
                beta_t = betas[t]
                # Reconstruct x0 from x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps
                pred_x0 = (x - torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_bar_t)
                if t > 0:
                    alpha_bar_prev = alpha_bars[t - 1]
                    # Mean and variance of the posterior q(x_{t-1} | x_t, x0)
                    mean = (torch.sqrt(alpha_bar_prev) * beta_t / (1 - alpha_bar_t)) * pred_x0 + \
                        (torch.sqrt(alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)) * x
                    variance = (1 - alpha_bar_prev) / (1 - alpha_bar_t) * beta_t
                    x = mean + torch.sqrt(variance) * torch.randn_like(x)
                else:
                    # Final step is deterministic: use the x0 estimate directly
                    x = pred_x0
            # Clamp to valid range
            x = torch.clamp(x, -1, 1)
    finally:
        model.train(was_training)

    # Debug: print sample statistics
    if epoch is not None and epoch % 10 == 0:
        print(f"Sample stats at epoch {epoch}: range [{x.min().item():.3f}, {x.max().item():.3f}], mean {x.mean().item():.3f}")
    # x is in [-1, 1], so map it to [0, 1] with a fixed affine transform.
    # normalize=True would instead stretch each batch to its own min/max,
    # which makes washed-out or low-contrast samples look better than they are.
    grid = torchvision.utils.make_grid((x + 1) / 2, nrow=2)
    if writer is not None:
        writer.add_image('Samples', grid, epoch)
    if epoch is not None:
        os.makedirs("samples", exist_ok=True)
        save_image(grid, f"samples/epoch_{epoch}.png")
    return x, grid
# Use the simple sampler
def sample(model, noise_scheduler, device, epoch=None, writer=None, n_samples=4):
    """Public sampling entry point.

    Delegates to ``simple_sample`` with the same arguments and returns its
    ``(samples, grid)`` result unchanged.
    """
    return simple_sample(
        model,
        noise_scheduler,
        device,
        epoch=epoch,
        writer=writer,
        n_samples=n_samples,
    )