|
|
import torch |
|
|
from dataloader import get_dataloaders |
|
|
from config import Config |
|
|
from noise_scheduler import FrequencyAwareNoise |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
def to_display_image(img):
    """Convert one channels-first image tensor into an array ``plt.imshow`` can render.

    The ``* 0.5 + 0.5`` step maps [-1, 1] to [0, 1]. The [-1, 1] normalization
    is assumed from the original display code — TODO confirm against the
    dataloader's transforms.

    Args:
        img: a single image tensor laid out as (C, H, W).

    Returns:
        A NumPy array of shape (H, W, C), or (H, W) when C == 1, with values
        clamped to [0, 1].
    """
    # detach()/cpu() so .numpy() also works for grad-tracking or GPU tensors.
    img = img.detach().cpu().float()
    # Noised samples may fall outside [-1, 1]; clamp instead of letting imshow
    # clip with a warning.
    img = (img * 0.5 + 0.5).clamp(0.0, 1.0)
    img = img.permute(1, 2, 0)
    # imshow rejects (H, W, 1) arrays; drop the channel axis for single-channel images.
    if img.shape[-1] == 1:
        img = img[..., 0]
    return img.numpy()


def debug_data(timestep=500):
    """Plot the first training image next to a noised copy of it.

    Draws one batch from the training loader and applies
    ``FrequencyAwareNoise.apply_noise`` to the whole batch at ``timestep``.
    It then shows the first sample before and after noising, side by side.

    Args:
        timestep: timestep index passed to ``apply_noise`` for every sample in
            the batch. Defaults to 500, the previously hard-coded value.
    """
    config = Config()
    train_loader, _ = get_dataloaders(config)
    # Assumes the loader yields (images, labels) pairs; labels are unused here.
    x0, _ = next(iter(train_loader))

    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(to_display_image(x0[0]))
    plt.title("Original")

    noise_scheduler = FrequencyAwareNoise(config)
    # One timestep per sample, created on the same device as the batch.
    t = torch.full((x0.shape[0],), timestep, dtype=torch.long, device=x0.device)
    xt = noise_scheduler.apply_noise(x0, t)

    plt.subplot(1, 2, 2)
    plt.imshow(to_display_image(xt[0]))
    plt.title(f"Noisy (t={timestep})")

    plt.show()
|
|
|
|
|
if __name__ == "__main__":
    # Entry point when executed as a script: render the original-vs-noisy
    # comparison figure. Importing this module does not trigger the plot.
    debug_data()