# Grad-CDM / alternative_sampling.py
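"""Alternative sampling strategies for a trained SmoothDiffusionUNet.

* deterministic_sample: a handful of large denoising steps with light re-noising.
* progressive_sample: a DDPM-style reverse chain started from a moderate timestep,
  with reduced posterior noise for stability.

Both routines save a 2x2 image grid to disk and return the raw samples and the grid.
"""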
import torch
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from config import Config
from torchvision.utils import save_image, make_grid
import numpy as np
def deterministic_sample(model, noise_scheduler, device, n_samples=4):
"""Deterministic sampling - just do a few big denoising steps"""
config = Config()
model.eval()
with torch.no_grad():
        # Start from reduced-amplitude Gaussian noise (std 0.5) rather than unit variance
x = torch.randn(n_samples, 3, config.image_size, config.image_size, device=device) * 0.5
print(f"Starting simplified sampling for {n_samples} samples...")
# Use fewer, bigger steps - more like denoising than full diffusion
timesteps = [400, 300, 200, 150, 100, 70, 50, 30, 20, 10, 5, 1]
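        # NOTE: these hard-coded timesteps assume the noise scheduler exposes at
        # least 401 steps (alpha_bars must be indexable up to t=400)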
for i, t_val in enumerate(timesteps):
print(f"Step {i+1}/{len(timesteps)}, t={t_val}")
t_tensor = torch.full((n_samples,), t_val, device=device, dtype=torch.long)
# Get model prediction
predicted_noise = model(x, t_tensor)
# Simple denoising step
alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
# Predict clean image
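            # Since x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
            # solving for x_0 with the predicted eps gives the estimate below.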
pred_x0 = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
pred_x0 = torch.clamp(pred_x0, -1, 1)
# Move towards clean prediction
if i < len(timesteps) - 1:
# Not final step - blend
next_t = timesteps[i + 1]
alpha_bar_next = noise_scheduler.alpha_bars[next_t].item()
                # Re-noise for the next step, but at only 10% of the standard
                # forward-process amplitude, to keep the trajectory stable
                noise_scale = np.sqrt(1 - alpha_bar_next)
                noise = torch.randn_like(x) * 0.1
x = np.sqrt(alpha_bar_next) * pred_x0 + noise_scale * noise
else:
# Final step
x = pred_x0
x = torch.clamp(x, -1.5, 1.5) # Prevent drift
if i % 3 == 0:
print(f" Current range: [{x.min():.3f}, {x.max():.3f}], std: {x.std():.3f}")
# Final clamp
x = torch.clamp(x, -1, 1)
print(f"Final samples:")
print(f" Range: [{x.min():.3f}, {x.max():.3f}]")
print(f" Mean: {x.mean():.3f}, Std: {x.std():.3f}")
# Convert to display range
x_display = torch.clamp((x + 1) / 2, 0, 1)
# Create and save grid
grid = make_grid(x_display, nrow=2, normalize=False, pad_value=1.0)
save_image(grid, "simplified_samples.png")
print(f"Samples saved to simplified_samples.png")
return x, grid
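
# Not used above: a minimal sketch of a strictly deterministic DDIM-style
# (eta = 0) update, in case fully deterministic sampling is wanted. It only
# relies on the alpha_bars buffer that the samplers above already read from
# the noise scheduler; the function itself is illustrative, not part of the
# original pipeline.
def ddim_step(x, predicted_noise, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM update from timestep t to the previous timestep."""
    pred_x0 = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
    pred_x0 = torch.clamp(pred_x0, -1, 1)
    return np.sqrt(alpha_bar_prev) * pred_x0 + np.sqrt(1 - alpha_bar_prev) * predicted_noise
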
def progressive_sample(model, noise_scheduler, device, n_samples=4):
"""Progressive denoising - start from less noise"""
config = Config()
model.eval()
with torch.no_grad():
        # Start from low-amplitude Gaussian noise (std 0.3) instead of pure unit-variance noise
x = torch.randn(n_samples, 3, config.image_size, config.image_size, device=device) * 0.3
print(f"Starting progressive denoising for {n_samples} samples...")
# Start from a moderate timestep instead of maximum noise
start_t = 200
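        # The reverse chain below visits t = start_t - 1, ..., 1, 0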
for step, t in enumerate(reversed(range(0, start_t))):
if step % 50 == 0:
print(f"Denoising step {step}/{start_t}, t={t}")
t_tensor = torch.full((n_samples,), t, device=device, dtype=torch.long)
# Get prediction
predicted_noise = model(x, t_tensor)
            # Standard DDPM reverse step, with clamped x0 and reduced posterior noise for stability
alpha_t = noise_scheduler.alphas[t].item()
alpha_bar_t = noise_scheduler.alpha_bars[t].item()
beta_t = noise_scheduler.betas[t].item()
if t > 0:
alpha_bar_prev = noise_scheduler.alpha_bars[t-1].item()
# Predict x0
pred_x0 = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
pred_x0 = torch.clamp(pred_x0, -1, 1)
# Posterior mean
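                    # mu_t = [sqrt(alpha_t) * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)] * x_t
                    #      + [sqrt(alpha_bar_{t-1}) * beta_t / (1 - alpha_bar_t)] * x_0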
coeff1 = np.sqrt(alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
coeff2 = np.sqrt(alpha_bar_prev) * beta_t / (1 - alpha_bar_t)
mean = coeff1 * x + coeff2 * pred_x0
# Reduced noise for stability
if t > 1:
posterior_variance = beta_t * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
noise = torch.randn_like(x)
# Reduce noise by half for more stability
x = mean + np.sqrt(posterior_variance) * noise * 0.5
else:
x = mean
else:
x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
# Gentle clamping
x = torch.clamp(x, -1.2, 1.2)
x = torch.clamp(x, -1, 1)
print(f"Progressive samples:")
print(f" Range: [{x.min():.3f}, {x.max():.3f}]")
print(f" Mean: {x.mean():.3f}, Std: {x.std():.3f}")
x_display = torch.clamp((x + 1) / 2, 0, 1)
grid = make_grid(x_display, nrow=2, normalize=False, pad_value=1.0)
save_image(grid, "progressive_samples.png")
print(f"Samples saved to progressive_samples.png")
return x, grid
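
# For reference: a minimal sketch of the unmodified DDPM posterior step that
# progressive_sample adapts (it scales the posterior noise by 0.5). This reuses
# the same alphas / alpha_bars / betas buffers read above; it is illustrative
# and not called anywhere in this file.
def standard_ddpm_step(x, predicted_noise, noise_scheduler, t):
    """Single standard DDPM reverse step with the full posterior variance."""
    alpha_t = noise_scheduler.alphas[t].item()
    alpha_bar_t = noise_scheduler.alpha_bars[t].item()
    beta_t = noise_scheduler.betas[t].item()
    pred_x0 = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
    pred_x0 = torch.clamp(pred_x0, -1, 1)
    if t == 0:
        return pred_x0
    alpha_bar_prev = noise_scheduler.alpha_bars[t - 1].item()
    coeff1 = np.sqrt(alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
    coeff2 = np.sqrt(alpha_bar_prev) * beta_t / (1 - alpha_bar_t)
    mean = coeff1 * x + coeff2 * pred_x0
    posterior_variance = beta_t * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
    return mean + np.sqrt(posterior_variance) * torch.randn_like(x)
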
def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load trained weights (assumes model_final.pth holds a raw state_dict)
checkpoint = torch.load('model_final.pth', map_location=device)
config = Config()
model = SmoothDiffusionUNet(config).to(device)
noise_scheduler = FrequencyAwareNoise(config)
model.load_state_dict(checkpoint)
print("=== TRYING DETERMINISTIC SAMPLING ===")
deterministic_sample(model, noise_scheduler, device, n_samples=4)
print("\n=== TRYING PROGRESSIVE SAMPLING ===")
progressive_sample(model, noise_scheduler, device, n_samples=4)
if __name__ == "__main__":
main()