Grad-CDM / debug.py
nazgut's picture
Upload 24 files
8abfb97 verified
raw
history blame
777 Bytes
import torch
from dataloader import get_dataloaders
from config import Config
from noise_scheduler import FrequencyAwareNoise
import matplotlib.pyplot as plt
def _to_displayable_image(img):
    """Convert one image tensor into an array that ``plt.imshow`` accepts.

    Applies ``x * 0.5 + 0.5`` (which maps [-1, 1] onto [0, 1]) and clips the
    result to [0, 1]. Noisy samples routinely land outside that range, and
    imshow would otherwise clip them itself with a warning. It also drops a
    singleton channel axis, because imshow rejects ``(H, W, 1)`` arrays.

    Args:
        img: A single image tensor, assumed channel-first ``(C, H, W)`` and
            normalized to roughly [-1, 1] (TODO confirm against the
            dataloader's transforms).

    Returns:
        A NumPy array of shape ``(H, W)`` or ``(H, W, C)``.
    """
    # detach/cpu so .numpy() works for GPU tensors or tensors that track grads.
    arr = img.detach().cpu().permute(1, 2, 0) * 0.5 + 0.5
    arr = arr.clamp(0.0, 1.0)
    if arr.shape[-1] == 1:
        arr = arr.squeeze(-1)
    return arr.numpy()


def debug_data(*, t: int = 500):
    """Show one training image next to its noised version at timestep ``t``.

    Builds the training dataloader from ``Config()`` and takes its first
    batch. It noises the batch with ``FrequencyAwareNoise.apply_noise`` using
    the same timestep for every sample. It then shows the first image of the
    batch before and after noising, side by side in a matplotlib figure.

    Args:
        t: Timestep passed to ``apply_noise`` for every sample in the batch.
            Defaults to 500. It presumably must be smaller than the
            scheduler's number of timesteps (TODO confirm in
            ``FrequencyAwareNoise``).

    Raises:
        RuntimeError: If the training dataloader yields no batches.
    """
    config = Config()
    train_loader, _ = get_dataloaders(config)
    batch = next(iter(train_loader), None)
    if batch is None:
        raise RuntimeError("training dataloader yielded no batches")
    x0, _ = batch

    noise_scheduler = FrequencyAwareNoise(config)
    # One integer timestep per sample, placed on the same device as the batch.
    timesteps = torch.full((x0.shape[0],), t, dtype=torch.long, device=x0.device)
    # Visualization only: no autograd graph is needed.
    with torch.no_grad():
        xt = noise_scheduler.apply_noise(x0, timesteps)

    plt.figure(figsize=(10, 5))
    # Left panel: original sample.
    plt.subplot(1, 2, 1)
    plt.imshow(_to_displayable_image(x0[0]))
    plt.title("Original")
    # Right panel: the same sample after noising.
    plt.subplot(1, 2, 2)
    plt.imshow(_to_displayable_image(xt[0]))
    plt.title(f"Noisy (t={t})")
    plt.show()
if __name__ == "__main__":
    # Open the visualization only when run as a script, not when imported.
    debug_data()