# Grad-CDM / sample.py
import torch
import torchvision
from torchvision.utils import save_image
import os
import numpy as np
from scipy.fftpack import dctn, idctn
from config import Config
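
# The samplers below only assume a scheduler object exposing three 1-D tensors
# indexed by integer timestep: `betas`, `alphas`, and `alpha_bars`. As a minimal
# illustrative sketch (an assumption, not necessarily the scheduler used in
# training), a standard linear-beta DDPM schedule would look like this:
class _ReferenceNoiseScheduler:
    """Illustrative stand-in for the scheduler interface these samplers expect."""

    def __init__(self, num_timesteps=500, beta_start=1e-4, beta_end=0.02):
        # Linear beta schedule over the training horizon (500 steps, per the
        # "Much fewer than 500" comment below)
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        # Cumulative products: alpha_bar_t = prod_{s<=t} alpha_s
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
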
def frequency_aware_sample(model, noise_scheduler, device, epoch=None, writer=None, n_samples=4):
    """OPTIMIZED sampling for frequency-aware trained models."""
    config = Config()
    model.eval()

    with torch.no_grad():
        # Start with moderate noise instead of extreme noise.
        # The model excels at moderate denoising, not extreme noise removal.
        x = torch.randn(n_samples, 3, config.image_size, config.image_size, device=device) * 0.4

        print(f"Starting optimized frequency-aware sampling for {n_samples} samples...")
        print(f"Initial moderate noise range: [{x.min().item():.3f}, {x.max().item():.3f}]")

        # Use an adaptive timestep schedule - fewer steps, bigger jumps.
        # This works better with frequency-aware training.
        total_steps = 100  # Much fewer than 500

        # Create a quadratic decay schedule, starting from t=300
        # instead of t=499 (skip extreme noise)
        timesteps = []
        for i in range(total_steps):
            t = int(300 * (1 - i / total_steps) ** 2)
            timesteps.append(max(t, 0))
        timesteps = sorted(set(timesteps), reverse=True)  # Remove duplicates

        print(f"Using {len(timesteps)} adaptive timesteps: {timesteps[:10]}...{timesteps[-5:]}")

        for step, t in enumerate(timesteps):
            if step % 20 == 0:
                print(f"  Step {step}/{len(timesteps)}, t={t}, range: [{x.min().item():.3f}, {x.max().item():.3f}]")

            t_tensor = torch.full((n_samples,), t, device=device, dtype=torch.long)

            # Get the model's noise prediction
            predicted_noise = model(x, t_tensor)

            # Get noise schedule parameters
            alpha_t = noise_scheduler.alphas[t].item()
            alpha_bar_t = noise_scheduler.alpha_bars[t].item()
            beta_t = noise_scheduler.betas[t].item()

            if step < len(timesteps) - 1:
                # Not the final step
                next_t = timesteps[step + 1]
                alpha_bar_prev = noise_scheduler.alpha_bars[next_t].item()

                # Predict the clean image with stability clamping
                pred_x0 = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
                pred_x0 = torch.clamp(pred_x0, -1.2, 1.2)  # Prevent extreme values

                # Compute the posterior mean with frequency-aware adjustments
                coeff1 = np.sqrt(alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
                coeff2 = np.sqrt(alpha_bar_prev) * beta_t / (1 - alpha_bar_t)
                posterior_mean = coeff1 * x + coeff2 * pred_x0

                # Add controlled noise - much less than standard DDPM
                if next_t > 0:
                    posterior_variance = beta_t * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
                    noise = torch.randn_like(x)
                    # Reduce noise for stability - key for frequency-aware models
                    noise_scale = np.sqrt(posterior_variance) * 0.3  # 70% less noise
                    x = posterior_mean + noise_scale * noise
                else:
                    x = posterior_mean
            else:
                # Final step - direct x0 prediction
                x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)

            # Gentle clamping to prevent drift (key for long sampling chains)
            x = torch.clamp(x, -1.3, 1.3)

        # Final processing
        x = torch.clamp(x, -1, 1)

        print("Final sample statistics:")
        print(f"  Range: [{x.min().item():.3f}, {x.max().item():.3f}]")
        print(f"  Mean: {x.mean().item():.3f}, Std: {x.std().item():.3f}")

        # Quality checks
        unique_vals = len(torch.unique(torch.round(x * 100) / 100))
        print(f"  Unique values (x100): {unique_vals}")
        if unique_vals < 20:
            print("  ⚠️ Low diversity - might be collapsed")
        elif x.std().item() < 0.05:
            print("  ⚠️ Very low variance - uniform output")
        elif x.std().item() > 0.9:
            print("  ⚠️ High variance - might still be noisy")
        else:
            print("  ✅ Good sample diversity and range!")

        # Convert to display format
        x_display = torch.clamp((x + 1.0) / 2.0, 0, 1)

        # Create a grid with proper formatting
        grid = torchvision.utils.make_grid(x_display, nrow=2, normalize=False, pad_value=1.0)

        # Log/save with epoch info
        if writer and epoch is not None:
            writer.add_image('Samples', grid, epoch)
        if epoch is not None:
            os.makedirs("samples", exist_ok=True)
            save_image(grid, f"samples/epoch_{epoch}.png")

        return x, grid

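# For reference, the update inlined in the samplers above and below is the
# standard DDPM posterior (Ho et al., 2020):
#   mu_t = sqrt(alpha_t) * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * x_t
#        + sqrt(alpha_bar_{t-1}) * beta_t / (1 - alpha_bar_t) * x0_hat
# with variance beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).
# A small standalone sketch of one reverse step (the name and signature here
# are illustrative, not part of the repo's API):
def _ddpm_posterior_step(x, pred_x0, alpha_t, alpha_bar_t, alpha_bar_prev, beta_t):
    """Illustrative DDPM posterior mean/variance for one reverse step."""
    coeff1 = np.sqrt(alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
    coeff2 = np.sqrt(alpha_bar_prev) * beta_t / (1 - alpha_bar_t)
    mean = coeff1 * x + coeff2 * pred_x0
    variance = beta_t * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
    return mean, variance
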
# Alternative sampling method specifically for frequency-aware models
def progressive_frequency_sample(model, noise_scheduler, device, epoch=None, writer=None, n_samples=4):
    """Progressive sampling - fewer steps, more stable for frequency-aware models."""
    config = Config()
    model.eval()

    with torch.no_grad():
        # Start from moderate noise instead of maximum
        x = torch.randn(n_samples, 3, config.image_size, config.image_size, device=device) * 0.4

        print(f"Starting progressive frequency sampling for {n_samples} samples...")

        # Use fewer, larger steps - better for frequency-aware training
        timesteps = [300, 250, 200, 150, 120, 90, 70, 50, 35, 25, 15, 8, 3, 1]

        for i, t_val in enumerate(timesteps):
            print(f"Step {i+1}/{len(timesteps)}, t={t_val}")

            t_tensor = torch.full((n_samples,), t_val, device=device, dtype=torch.long)

            # Get the model's noise prediction
            predicted_noise = model(x, t_tensor)

            # Get schedule parameters
            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()

            # Predict the clean image
            pred_x0 = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
            pred_x0 = torch.clamp(pred_x0, -1, 1)

            # Move towards the clean prediction
            if i < len(timesteps) - 1:
                next_t = timesteps[i + 1]
                alpha_bar_next = noise_scheduler.alpha_bars[next_t].item()

                # Blend the current image with the clean prediction
                blend_factor = 0.3  # How much to trust the clean prediction
                x = (1 - blend_factor) * x + blend_factor * pred_x0

                # Add controlled noise for the next step
                noise_scale = np.sqrt(1 - alpha_bar_next) * 0.2  # Reduced noise
                noise = torch.randn_like(x)
                x = np.sqrt(alpha_bar_next) * x + noise_scale * noise
            else:
                # Final step
                x = pred_x0

            # Prevent drift
            x = torch.clamp(x, -1.2, 1.2)

        # Final cleanup
        x = torch.clamp(x, -1, 1)

        print(f"Progressive samples - Range: [{x.min():.3f}, {x.max():.3f}], Mean: {x.mean():.3f}, Std: {x.std():.3f}")

        # Convert to display range and create a grid
        x_display = torch.clamp((x + 1) / 2, 0, 1)
        grid = torchvision.utils.make_grid(x_display, nrow=2, normalize=False, pad_value=1.0)

        if writer and epoch is not None:
            writer.add_image('Progressive_Samples', grid, epoch)
        if epoch is not None:
            os.makedirs("samples", exist_ok=True)
            save_image(grid, f"samples/progressive_epoch_{epoch}.png")

        return x, grid

def optimized_frequency_sample(model, noise_scheduler, device, epoch=None, writer=None, n_samples=4):
    """Optimized sampling with adaptive timesteps for frequency-aware models."""
    config = Config()
    model.eval()

    with torch.no_grad():
        # Start with moderate noise
        x = torch.randn(n_samples, 3, config.image_size, config.image_size, device=device) * 0.5

        print(f"Starting optimized frequency sampling for {n_samples} samples...")

        # Adaptive timestep schedule - more steps where the model is most effective
        early_steps = list(range(400, 200, -25))   # Coarse denoising
        middle_steps = list(range(200, 50, -15))   # Fine denoising
        final_steps = list(range(50, 0, -5))       # Detail refinement
        timesteps = early_steps + middle_steps + final_steps

        for i, t_val in enumerate(timesteps):
            if i % 10 == 0:
                print(f"Step {i+1}/{len(timesteps)}, t={t_val}")

            t_tensor = torch.full((n_samples,), t_val, device=device, dtype=torch.long)

            # Get the model's noise prediction
            predicted_noise = model(x, t_tensor)

            # Standard DDPM step with stability improvements
            alpha_t = noise_scheduler.alphas[t_val].item()
            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
            beta_t = noise_scheduler.betas[t_val].item()
            if i + 1 < len(timesteps):
                # Not the last scheduled step: standard posterior update
                next_t = timesteps[i + 1]
                alpha_bar_prev = noise_scheduler.alpha_bars[next_t].item()

                # Predict x0
                pred_x0 = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
                pred_x0 = torch.clamp(pred_x0, -1, 1)

                # Compute the posterior mean
                coeff1 = np.sqrt(alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
                coeff2 = np.sqrt(alpha_bar_prev) * beta_t / (1 - alpha_bar_t)
                mean = coeff1 * x + coeff2 * pred_x0

                # Add noise with adaptive scaling
                if t_val > 5:
                    posterior_variance = beta_t * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
                    # Reduce noise in later steps for stability
                    noise_scale = 1.0 if t_val > 100 else 0.5
                    noise = torch.randn_like(x)
                    x = mean + np.sqrt(posterior_variance) * noise * noise_scale
                else:
                    x = mean
            else:
                # Last scheduled step. The schedule bottoms out at t=5 and never
                # reaches t=0, so gate on the step index (a `t_val > 0` check
                # would leave this branch unreachable) and jump straight to the
                # x0 prediction.
                x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
            # Adaptive clamping - tighter as we get closer to the final image
            clamp_range = 2.0 if t_val > 200 else 1.5 if t_val > 50 else 1.2
            x = torch.clamp(x, -clamp_range, clamp_range)

        # Final clamp to the data range
        x = torch.clamp(x, -1, 1)

        print(f"Optimized samples - Range: [{x.min():.3f}, {x.max():.3f}], Mean: {x.mean():.3f}, Std: {x.std():.3f}")

        # Quality check
        unique_vals = len(torch.unique(torch.round(x * 100) / 100))
        if unique_vals > 50:
            print("✅ Good diversity in generated samples")
        else:
            print("⚠️ Low diversity - samples might be collapsed")

        # Convert to display range and create a grid
        x_display = torch.clamp((x + 1) / 2, 0, 1)
        grid = torchvision.utils.make_grid(x_display, nrow=2, normalize=False, pad_value=1.0)

        if writer and epoch is not None:
            writer.add_image('Optimized_Samples', grid, epoch)
        if epoch is not None:
            os.makedirs("samples", exist_ok=True)
            save_image(grid, f"samples/optimized_epoch_{epoch}.png")

        return x, grid

# Aggressive sampling method leveraging the model's strong denoising ability
def aggressive_frequency_sample(model, noise_scheduler, device, epoch=None, writer=None, n_samples=4):
    """Aggressive sampling - leverages the model's strong denoising ability."""
    config = Config()
    model.eval()

    with torch.no_grad():
        # Start with stronger noise since the model handles it well
        x = torch.randn(n_samples, 3, config.image_size, config.image_size, device=device) * 0.8

        print(f"Starting aggressive frequency sampling for {n_samples} samples...")
        print(f"Initial noise range: [{x.min():.3f}, {x.max():.3f}], std: {x.std():.3f}")

        # Work in the model's sweet spot - it excels at moderate denoising -
        # so take several medium-strength denoising steps
        timesteps = [350, 280, 220, 170, 130, 100, 75, 55, 40, 28, 18, 10, 5, 2, 1]

        for i, t_val in enumerate(timesteps):
            t_tensor = torch.full((n_samples,), t_val, device=device, dtype=torch.long)

            # Get the model's noise prediction
            predicted_noise = model(x, t_tensor)

            # The model predicts noise accurately, so trust it more
            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()

            # Predict the clean image
            pred_x0 = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
            pred_x0 = torch.clamp(pred_x0, -1, 1)

            if i < len(timesteps) - 2:  # Not one of the final steps
                # Move aggressively toward the clean prediction
                alpha_bar_next = noise_scheduler.alpha_bars[timesteps[i + 1]].item()

                # Trust the model more (higher blend factor) at low noise levels
                trust_factor = 0.6 if t_val > 100 else 0.8
                x = (1 - trust_factor) * x + trust_factor * pred_x0

                # Add fresh noise for the next iteration
                if t_val > 10:
                    noise_strength = np.sqrt(1 - alpha_bar_next) * 0.4
                    fresh_noise = torch.randn_like(x)
                    x = np.sqrt(alpha_bar_next) * x + noise_strength * fresh_noise
            elif i == len(timesteps) - 2:  # Second-to-last step
                # Almost final - very gentle noise
                x = 0.2 * x + 0.8 * pred_x0
                tiny_noise = torch.randn_like(x) * 0.05
                x = x + tiny_noise
            else:  # Final step
                x = pred_x0

            # Prevent explosion but allow more range
            x = torch.clamp(x, -1.5, 1.5)

            if i % 3 == 0:
                print(f"  Step {i+1}/{len(timesteps)}, t={t_val}, range: [{x.min():.3f}, {x.max():.3f}], std: {x.std():.3f}")

        # Final clamp to the data range
        x = torch.clamp(x, -1, 1)

        print(f"Aggressive samples - Range: [{x.min():.3f}, {x.max():.3f}], Mean: {x.mean():.3f}, Std: {x.std():.3f}")

        # Quality metrics
        unique_vals = len(torch.unique(torch.round(x * 200) / 200))  # Higher-resolution check
        print(f"Unique values (x200): {unique_vals}")
        if x.std().item() < 0.05:
            print("❌ Very low variance - output collapsed")
        elif x.std().item() < 0.15:
            print("⚠️ Low variance - output may be too smooth")
        elif x.std().item() > 0.6:
            print("⚠️ High variance - output may be noisy")
        else:
            print("✅ Good variance - output looks promising")

        if unique_vals < 20:
            print("❌ Very low diversity")
        elif unique_vals < 100:
            print("⚠️ Moderate diversity")
        else:
            print("✅ Good diversity")

        # Convert to display range and create a grid
        x_display = torch.clamp((x + 1) / 2, 0, 1)
        grid = torchvision.utils.make_grid(x_display, nrow=2, normalize=False, pad_value=1.0)

        if writer and epoch is not None:
            writer.add_image('Aggressive_Samples', grid, epoch)
        if epoch is not None:
            os.makedirs("samples", exist_ok=True)
            save_image(grid, f"samples/aggressive_epoch_{epoch}.png")

        return x, grid

# Keep the old function name for compatibility
def sample(model, noise_scheduler, device, epoch=None, writer=None, n_samples=4):
    return frequency_aware_sample(model, noise_scheduler, device, epoch, writer, n_samples)
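
# Minimal smoke test, assuming only the interfaces used above. `_DummyModel`
# is a hypothetical stand-in (the real denoising network lives elsewhere in
# this repo), and `_ReferenceNoiseScheduler` is the illustrative scheduler
# sketched near the top of this file; swap in the project's actual model and
# scheduler for real sampling.
if __name__ == "__main__":
    class _DummyModel(torch.nn.Module):
        """Hypothetical stub that 'predicts' zero noise at every timestep."""

        def forward(self, x, t):
            return torch.zeros_like(x)

    _device = "cuda" if torch.cuda.is_available() else "cpu"
    _model = _DummyModel().to(_device)
    _scheduler = _ReferenceNoiseScheduler()
    samples, sample_grid = sample(_model, _scheduler, _device, n_samples=4)
    print(f"Generated tensor of shape {tuple(samples.shape)}")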