# Grad-CDM / simple_test.py
# Source: Hugging Face repository upload by nazgut ("Upload 24 files"),
# commit 8abfb97 (verified), 3.6 kB.
import torch
import torchvision
from torchvision.utils import save_image, make_grid
import os
from config import Config
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from sample import frequency_aware_sample
def _find_latest_log_dir(logs_root='./logs'):
    """Return the path of the last (by name) sub-directory of *logs_root*, or None.

    Directory names are compared as plain strings, so "latest" is only
    chronological if run names sort that way (e.g. zero-padded timestamps).
    Returns None when *logs_root* is missing, is not a directory, or has no
    sub-directories.
    """
    if not os.path.isdir(logs_root):
        return None
    log_dirs = [
        item for item in os.listdir(logs_root)
        if os.path.isdir(os.path.join(logs_root, item))
    ]
    if not log_dirs:
        return None
    return os.path.join(logs_root, max(log_dirs))


def _parse_checkpoint_epoch(filename):
    """Return the integer epoch encoded in a ``model_epoch_<N>.pth`` name, or None.

    Files that don't match the pattern, or whose epoch part is not an integer
    (e.g. ``model_epoch_best.pth``), yield None instead of raising.
    """
    if not (filename.startswith('model_epoch_') and filename.endswith('.pth')):
        return None
    try:
        return int(filename.split('_')[2].split('.')[0])
    except ValueError:
        return None


def _find_latest_checkpoint(log_path):
    """Return ``(epoch, filename)`` of the highest-epoch checkpoint in *log_path*, or None.

    Ties on epoch are broken by filename, exactly like sorting the tuples.
    """
    checkpoint_files = []
    for filename in os.listdir(log_path):
        epoch = _parse_checkpoint_epoch(filename)
        if epoch is not None:
            checkpoint_files.append((epoch, filename))
    return max(checkpoint_files) if checkpoint_files else None


def _load_checkpoint(checkpoint_path, device):
    """Load a checkpoint onto *device*, allowing pickled Python objects.

    The checkpoint may carry a pickled ``config`` object. PyTorch >= 2.6
    defaults ``torch.load`` to ``weights_only=True``, which refuses to
    unpickle such objects, so full unpickling is requested explicitly.
    Older torch versions (no ``weights_only`` parameter) get the plain call.

    SECURITY NOTE: ``weights_only=False`` runs arbitrary pickle code. Only
    use this on checkpoints you trust (presumably ones written by this
    project's own training runs under ./logs).
    """
    import inspect

    if 'weights_only' in inspect.signature(torch.load).parameters:
        return torch.load(checkpoint_path, map_location=device, weights_only=False)
    return torch.load(checkpoint_path, map_location=device)


def _report_sample_stats(samples):
    """Print range/mean/std of *samples* and a heuristic verdict on their quality."""
    s_min = samples.min().item()
    s_max = samples.max().item()
    mean = samples.mean().item()
    std = samples.std().item()

    print("Sample statistics:")
    print(f" Range: [{s_min:.3f}, {s_max:.3f}]")
    print(f" Mean: {mean:.3f}")
    print(f" Std: {std:.3f}")

    # Rough heuristics: near-zero spread means the output collapsed to a flat
    # image; ~zero mean with large spread presumably means the samples are
    # still close to the initial Gaussian noise. Thresholds are empirical.
    if std < 0.1:
        print("WARNING: Samples have very low variance - might be noise!")
    elif abs(mean) < 0.01 and std > 0.8:
        print("WARNING: Samples look like random noise!")
    else:
        print("Samples look reasonable!")


def test_latest_checkpoint():
    """Test the latest checkpoint with frequency-aware sampling.

    Picks the last run directory under ``./logs`` (by name) and loads its
    highest-epoch ``model_epoch_<N>.pth`` checkpoint. It then draws 8 samples
    with ``frequency_aware_sample`` and saves the returned grid to
    ``test_samples_epoch_<N>_fixed.png`` in the current directory. Progress
    and diagnostics are printed; nothing is returned. Sampling errors are
    reported with a traceback instead of propagating.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    log_path = _find_latest_log_dir('./logs')
    if log_path is None:
        print("No log directories found!")
        return
    print(f"Testing latest log directory: {log_path}")

    latest = _find_latest_checkpoint(log_path)
    if latest is None:
        print("No checkpoint files found!")
        return
    latest_epoch, latest_file = latest
    checkpoint_path = os.path.join(log_path, latest_file)
    print(f"Testing checkpoint: {latest_file} (epoch {latest_epoch})")

    checkpoint = _load_checkpoint(checkpoint_path, device)

    # Prefer the config stored alongside the weights (presumably so the
    # architecture matches them); otherwise fall back to the default Config.
    config = checkpoint['config'] if 'config' in checkpoint else Config()
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)

    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        epoch = checkpoint.get('epoch', 'unknown')
        loss = checkpoint.get('loss', 'unknown')
        print(f"Loaded model from epoch {epoch}, loss: {loss}")
    else:
        # Bare state_dict saved without a metadata wrapper.
        model.load_state_dict(checkpoint)
        print("Loaded model state dict")

    # Inference mode: affects dropout / batch-norm layers, if the model has any.
    model.eval()

    print("\n=== Generating samples with frequency-aware approach ===")
    try:
        samples, grid = frequency_aware_sample(model, noise_scheduler, device, n_samples=8)
        save_path = f"test_samples_epoch_{latest_epoch}_fixed.png"
        save_image(grid, save_path, normalize=False)
        print(f"Samples saved to: {save_path}")
        _report_sample_stats(samples)
    except Exception as e:
        # Top-level boundary of a diagnostic script: report the failure with
        # its full traceback rather than crashing.
        print(f"Error during sampling: {e}")
        import traceback
        traceback.print_exc()
# Script entry point: run the checkpoint smoke test only when this file is
# executed directly (e.g. ``python simple_test.py``), not when imported.
if __name__ == '__main__':
    test_latest_checkpoint()