# Grad-CDM / hybrid_generation.py
import torch
import numpy as np
from scipy.fftpack import idctn

from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from config import Config
from dataloader import get_dataloaders
from torchvision.utils import save_image, make_grid
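
# A minimal sketch of the denoising loop shared by the three methods below,
# factored out here for reference; the methods inline it so each stays
# self-contained. It assumes, as elsewhere in this file, that the scheduler
# exposes an `alpha_bars` tensor indexed by integer timestep. The `strength`
# factor (0.5-0.7 below) is a heuristic damping of the standard DDPM x0
# estimate, not part of a textbook sampler.
def progressive_denoise(model, scheduler, x, timesteps, strength, device):
    for t_val in timesteps:
        t_tensor = torch.full((x.shape[0],), t_val, device=device, dtype=torch.long)
        predicted_noise = model(x, t_tensor)
        alpha_bar_t = scheduler.alpha_bars[t_val].item()
        # Damped x0 estimate: (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
        x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * strength) / np.sqrt(alpha_bar_t)
        x = torch.clamp(x, -1, 1)
    return x
# Example usage (equivalent to Method 1 below):
#   x = progressive_denoise(model, noise_scheduler, smart_noise, timesteps, 0.7, device)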

def hybrid_generation():
    """Hybrid approach: use the model as a super-denoiser rather than a pure
    generator, starting from three structured initializations: statistics-matched
    noise, blends of real images, and DCT-domain frequency patterns."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model and noise scheduler
    checkpoint = torch.load('model_final.pth', map_location=device)
    config = Config()
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
    model.load_state_dict(checkpoint)
    model.eval()

    # Load real training data for smart initialization
    train_loader, _ = get_dataloaders(config)
    real_batch, _ = next(iter(train_loader))
    real_images = real_batch[:8].to(device)

    print("=== HYBRID GENERATION APPROACH ===")

    with torch.no_grad():
        # Method 1: Smart noise initialization
        print("\n--- Method 1: Smart Noise Initialization ---")

        # Initialize with noise whose statistics match the training data
        smart_noise = torch.randn(4, 3, 64, 64, device=device)
        smart_noise = smart_noise * real_images.std().item()  # match training data std
        smart_noise = smart_noise + real_images.mean().item()  # match training data mean
        smart_noise = torch.clamp(smart_noise, -1, 1)
        print(f"Smart noise stats: mean={smart_noise.mean():.3f}, std={smart_noise.std():.3f}")
        # Apply progressive denoising
        timesteps = [150, 120, 90, 70, 50, 35, 25, 15, 8, 3, 1]
        x = smart_noise.clone()
        for t_val in timesteps:
            t_tensor = torch.full((4,), t_val, device=device, dtype=torch.long)
            predicted_noise = model(x, t_tensor)
            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
            x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * 0.7) / np.sqrt(alpha_bar_t)
            x = torch.clamp(x, -1, 1)

        # Save result
        smart_display = torch.clamp((x + 1) / 2, 0, 1)
        smart_grid = make_grid(smart_display, nrow=2, normalize=False)
        save_image(smart_grid, "smart_noise_generation.png")
        print(f"Smart noise result: range=[{x.min():.3f}, {x.max():.3f}], std={x.std():.3f}")
        print("Saved to smart_noise_generation.png")

        # Method 2: Blended real images + denoising
        print("\n--- Method 2: Blended Real Images ---")

        # Create new combinations by blending random real images
        indices = torch.randint(0, len(real_images), (4, 3))  # pick 3 random images per output
        weights = torch.rand(4, 3, device=device)
        weights = weights / weights.sum(dim=1, keepdim=True)  # normalize weights
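        # With normalized weights each output is a convex combination of three
        # training images, so the blend stays within the data's value range.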
        blended = torch.zeros(4, 3, 64, 64, device=device)
        for i in range(4):
            for j in range(3):
                blended[i] += weights[i, j] * real_images[indices[i, j]]

        # Add mild noise so the denoiser has something to remove
        noise = torch.randn_like(blended) * 0.15
        blended = blended + noise
        blended = torch.clamp(blended, -1, 1)
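        # The 0.15 noise level is a heuristic: it gives the model a corruption
        # to remove without burying the blended structure. It is not calibrated
        # against the scheduler's own noise levels.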

        # Light denoising to clean up
        light_timesteps = [80, 60, 40, 25, 12, 5, 1]
        x = blended.clone()
        for t_val in light_timesteps:
            t_tensor = torch.full((4,), t_val, device=device, dtype=torch.long)
            predicted_noise = model(x, t_tensor)
            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
            x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * 0.5) / np.sqrt(alpha_bar_t)
            x = torch.clamp(x, -1, 1)

        # Save result
        blended_display = torch.clamp((x + 1) / 2, 0, 1)
        blended_grid = make_grid(blended_display, nrow=2, normalize=False)
        save_image(blended_grid, "blended_generation.png")
        print(f"Blended result: range=[{x.min():.3f}, {x.max():.3f}], std={x.std():.3f}")
        print("Saved to blended_generation.png")

        # Method 3: Frequency-domain initialization
        print("\n--- Method 3: Frequency-Domain Initialization ---")

        # Start with structured noise in the DCT domain, then convert to spatial
        # (idctn is imported at the top of the file); build on CPU first, since
        # the per-channel fill below assigns NumPy-derived tensors
        freq_images = torch.zeros(4, 3, 64, 64)
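        # In the DCT basis, low-index coefficients (small u, v) encode smooth,
        # large-scale structure and higher indices encode finer texture; the
        # 1/(1 + u + v) falloff below is a rough heuristic imitating the
        # decaying power spectrum of natural images.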
        for i in range(4):
            for c in range(3):
                # Create a structured frequency pattern
                freq_pattern = np.zeros((64, 64))
                # Low-frequency components (overall shape/color)
                for u in range(0, 8):
                    for v in range(0, 8):
                        freq_pattern[u, v] = np.random.randn() * (1.0 / (1 + u + v))
                # Mid-frequency components (textures)
                for u in range(8, 20):
                    for v in range(8, 20):
                        freq_pattern[u, v] = np.random.randn() * 0.1
                # Convert to the spatial domain
                spatial = idctn(freq_pattern, norm='ortho')
                freq_images[i, c] = torch.from_numpy(spatial).float()

        # Normalize to the training data's range and statistics
        freq_images = freq_images.to(device)
        freq_images = freq_images - freq_images.mean()
        freq_images = freq_images / freq_images.std() * real_images.std()
        freq_images = torch.clamp(freq_images, -1, 1)
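        # Matching the training set's mean/std and clamping to [-1, 1] keeps the
        # initialization inside the input range the model was trained on.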

        # Apply denoising
        freq_timesteps = [100, 75, 55, 40, 28, 18, 10, 4, 1]
        x = freq_images.clone()
        for t_val in freq_timesteps:
            t_tensor = torch.full((4,), t_val, device=device, dtype=torch.long)
            predicted_noise = model(x, t_tensor)
            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
            x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * 0.6) / np.sqrt(alpha_bar_t)
            x = torch.clamp(x, -1, 1)

        # Save result
        freq_display = torch.clamp((x + 1) / 2, 0, 1)
        freq_grid = make_grid(freq_display, nrow=2, normalize=False)
        save_image(freq_grid, "frequency_generation.png")
        print(f"Frequency result: range=[{x.min():.3f}, {x.max():.3f}], std={x.std():.3f}")
        print("Saved to frequency_generation.png")

    print("\n=== RESULTS ===")
    print("Generated files:")
    print("- smart_noise_generation.png (noise matching training stats)")
    print("- blended_generation.png (combinations of real images)")
    print("- frequency_generation.png (frequency-domain initialization)")
    print("\nYour model works as a super-denoiser!")
    print("It can clean up any reasonable starting point to look more image-like.")


if __name__ == "__main__":
    hybrid_generation()