|
import torch |
|
from model import SmoothDiffusionUNet |
|
from noise_scheduler import FrequencyAwareNoise |
|
from config import Config |
|
from torchvision.utils import save_image, make_grid |
|
from dataloader import get_dataloaders |
|
import numpy as np |
|
|
|
def _denoise(model, noise_scheduler, x, timesteps, noise_scale):
    """Iteratively clean ``x`` with the model's noise predictions.

    For each ``t`` in ``timesteps`` (visited in the given order) the model
    predicts the noise in ``x``. That prediction, damped by ``noise_scale``, is
    removed with the DDPM x0-estimate formula
    ``(x - sqrt(1 - a_bar) * eps) / sqrt(a_bar)``, where
    ``a_bar = noise_scheduler.alpha_bars[t]``. No fresh noise is re-injected
    between steps, and the result is clamped to [-1, 1] after every step.

    Args:
        model: Noise-prediction network called as ``model(x, t)``.
        noise_scheduler: Object exposing an indexable ``alpha_bars`` tensor.
        x: Batch of images to clean (batch dimension first).
        timesteps: Sequence of integer timesteps to visit, in order.
        noise_scale: Multiplier applied to the predicted noise (< 1 damps it).

    Returns:
        The cleaned batch, same shape as ``x``, clamped to [-1, 1].
    """
    # One timestep value per sample, sized from x rather than hard-coded.
    batch_size = x.shape[0]
    for t_val in timesteps:
        t_tensor = torch.full((batch_size,), t_val, device=x.device, dtype=torch.long)
        predicted_noise = model(x, t_tensor)
        alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
        x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * noise_scale) / np.sqrt(
            alpha_bar_t
        )
        x = torch.clamp(x, -1, 1)
    return x


def _save_samples(x, path):
    """Map ``x`` from [-1, 1] to [0, 1] and save it to ``path`` as a 2-column grid."""
    display = torch.clamp((x + 1) / 2, 0, 1)
    grid = make_grid(display, nrow=2, normalize=False)
    save_image(grid, path)


def _smart_noise(real_images, num_samples):
    """Gaussian noise rescaled to the global mean/std of ``real_images``.

    Samples share the per-sample shape of ``real_images`` instead of a
    hard-coded 3x64x64. Values are clamped to [-1, 1].
    """
    shape = (num_samples, *real_images.shape[1:])
    noise = torch.randn(shape, device=real_images.device)
    noise = noise * real_images.std().item() + real_images.mean().item()
    return torch.clamp(noise, -1, 1)


def _blend_real_images(real_images, num_samples, num_sources=3, noise_std=0.15):
    """Random convex combinations of real images plus light Gaussian noise.

    Each output sample is a weighted sum of ``num_sources`` images drawn (with
    replacement) from ``real_images``. The weights are non-negative and sum to
    one. Gaussian noise with std ``noise_std`` is then added, and the result is
    clamped to [-1, 1].
    """
    device = real_images.device
    indices = torch.randint(
        0, len(real_images), (num_samples, num_sources), device=device
    )
    weights = torch.rand(num_samples, num_sources, device=device)
    weights = weights / weights.sum(dim=1, keepdim=True)
    # Gather every source at once: (num_samples, num_sources, *image_shape).
    sources = real_images[indices]
    # Broadcast the weights over the image dims, then sum out the source axis.
    weights = weights.reshape(weights.shape + (1,) * (sources.dim() - 2))
    blended = (weights * sources).sum(dim=1)
    blended = blended + torch.randn_like(blended) * noise_std
    return torch.clamp(blended, -1, 1)


def _frequency_init(real_images, num_samples, rng):
    """Images synthesized from random low/mid-frequency DCT coefficients.

    Coefficients are set per sample and per channel:
    - the top-left 8x8 block gets Gaussian values damped by ``1 / (1 + u + v)``;
    - the block where both indices fall in 8..19 gets Gaussian values
      scaled by 0.1;
    - all other coefficients are zero.

    An orthonormal inverse 2-D DCT over the last two axes maps these to pixel
    space. The batch is then zero-centred, rescaled to the global std of
    ``real_images``, and clamped to [-1, 1].

    Args:
        real_images: Reference batch; supplies the per-sample shape, the
            target std, and the device.
        num_samples: Number of images to synthesize.
        rng: ``numpy.random.Generator`` used for the coefficients.
    """
    # scipy.fft supersedes the legacy scipy.fftpack interface.
    from scipy.fft import idctn

    coeffs = np.zeros((num_samples, *real_images.shape[1:]))

    low = coeffs[..., :8, :8]
    u = np.arange(low.shape[-2])[:, None]
    v = np.arange(low.shape[-1])[None, :]
    coeffs[..., :8, :8] = rng.standard_normal(low.shape) / (1 + u + v)

    # Sized from the actual slice so images smaller than 20 px do not fail.
    mid = coeffs[..., 8:20, 8:20]
    coeffs[..., 8:20, 8:20] = rng.standard_normal(mid.shape) * 0.1

    spatial = idctn(coeffs, axes=(-2, -1), norm="ortho")
    freq_images = torch.from_numpy(spatial).float().to(real_images.device)
    freq_images = freq_images - freq_images.mean()
    # Guard against division by a zero std (e.g. all-zero coefficients).
    freq_images = freq_images / freq_images.std().clamp_min(1e-8) * real_images.std()
    return torch.clamp(freq_images, -1, 1)


def hybrid_generation(checkpoint_path="model_final.pth", num_samples=4, seed=None):
    """Hybrid approach: Use model as super-denoiser rather than pure generator.

    Loads a ``SmoothDiffusionUNet`` state dict and takes up to 8 images from
    the first training batch as reference material. It then runs three
    starting-point strategies through a short, damped denoising schedule:

    1. Gaussian noise matched to the real images' mean/std.
    2. Random convex blends of real images plus light noise.
    3. Random low/mid-frequency DCT patterns.

    Each result is saved as a PNG grid in the working directory.

    Args:
        checkpoint_path: Path of the state-dict checkpoint to load.
        num_samples: Number of images generated per method.
        seed: Optional seed for the torch and NumPy RNGs (reproducible output).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if seed is not None:
        torch.manual_seed(seed)
    rng = np.random.default_rng(seed)

    # NOTE(review): torch.load unpickles the file (fully so on older torch
    # versions). Only load trusted checkpoints; weights_only=True suffices
    # for a plain state dict on torch versions that support it.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = Config()
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
    model.load_state_dict(checkpoint)
    model.eval()

    train_loader, _ = get_dataloaders(config)
    real_batch, _ = next(iter(train_loader))
    real_images = real_batch[:8].to(device)

    print("=== HYBRID GENERATION APPROACH ===")

    with torch.no_grad():
        print("\n--- Method 1: Smart Noise Initialization ---")
        smart_noise = _smart_noise(real_images, num_samples)
        print(f"Smart noise stats: mean={smart_noise.mean():.3f}, std={smart_noise.std():.3f}")
        # Descending timesteps and damping factor for this initialization.
        timesteps = [150, 120, 90, 70, 50, 35, 25, 15, 8, 3, 1]
        x = _denoise(model, noise_scheduler, smart_noise, timesteps, noise_scale=0.7)
        _save_samples(x, "smart_noise_generation.png")
        print(f"Smart noise result: range=[{x.min():.3f}, {x.max():.3f}], std={x.std():.3f}")
        print("Saved to smart_noise_generation.png")

        print("\n--- Method 2: Blended Real Images ---")
        blended = _blend_real_images(real_images, num_samples)
        # Blends already look like images, so a shorter, gentler schedule is used.
        light_timesteps = [80, 60, 40, 25, 12, 5, 1]
        x = _denoise(model, noise_scheduler, blended, light_timesteps, noise_scale=0.5)
        _save_samples(x, "blended_generation.png")
        print(f"Blended result: range=[{x.min():.3f}, {x.max():.3f}], std={x.std():.3f}")
        print("Saved to blended_generation.png")

        print("\n--- Method 3: Frequency-Domain Initialization ---")
        freq_images = _frequency_init(real_images, num_samples, rng)
        freq_timesteps = [100, 75, 55, 40, 28, 18, 10, 4, 1]
        x = _denoise(model, noise_scheduler, freq_images, freq_timesteps, noise_scale=0.6)
        _save_samples(x, "frequency_generation.png")
        print(f"Frequency result: range=[{x.min():.3f}, {x.max():.3f}], std={x.std():.3f}")
        print("Saved to frequency_generation.png")

    print("\n=== RESULTS ===")
    print("Generated files:")
    print("- smart_noise_generation.png (noise matching training stats)")
    print("- blended_generation.png (combinations of real images)")
    print("- frequency_generation.png (frequency-domain initialization)")
    print("\nYour model works as a super-denoiser!")
    print("It can clean up any reasonable starting point to look more image-like.")
|
|
|
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script, not on import.
    hybrid_generation()
|
|