|
|
import torch |
|
|
import torchvision |
|
|
from torchvision.utils import save_image, make_grid |
|
|
import os |
|
|
from config import Config |
|
|
from model import SmoothDiffusionUNet |
|
|
from noise_scheduler import FrequencyAwareNoise |
|
|
from sample import frequency_aware_sample |
|
|
|
|
|
def _latest_log_dir(logs_root='./logs'):
    """Return the name of the lexicographically last sub-directory of ``logs_root``.

    Returns ``None`` when ``logs_root`` does not exist, is not a directory, or
    contains no sub-directories. "Latest" is decided by name, not by mtime.
    """
    if not os.path.isdir(logs_root):
        return None
    subdirs = [
        name for name in os.listdir(logs_root)
        if os.path.isdir(os.path.join(logs_root, name))
    ]
    # NOTE(review): lexical order equals chronological order only if run
    # directory names are timestamped / zero-padded — confirm naming scheme.
    return max(subdirs) if subdirs else None


def _latest_checkpoint(log_path):
    """Return ``(epoch, filename)`` of the highest-epoch checkpoint in ``log_path``.

    Candidates are files starting with ``model_epoch_`` and ending with
    ``.pth``; the epoch is the integer in the third ``_``-separated field, up to
    its first ``.``. Files whose epoch field is not an integer (for example
    ``model_epoch_final.pth``) are skipped rather than aborting the scan.
    Epoch ties are broken by filename. Returns ``None`` when nothing matches.
    """
    candidates = []
    for name in os.listdir(log_path):
        if not (name.startswith('model_epoch_') and name.endswith('.pth')):
            continue
        try:
            epoch = int(name.split('_')[2].split('.')[0])
        except ValueError:
            # Non-numeric epoch tag: not an epoch checkpoint, ignore it.
            continue
        candidates.append((epoch, name))
    return max(candidates) if candidates else None


def _load_model(checkpoint_path, device):
    """Build the model and noise scheduler and load weights from ``checkpoint_path``.

    Accepts either a dict holding ``'model_state_dict'`` (optionally also
    ``'config'``, ``'epoch'``, ``'loss'``) or a bare state dict. Uses the
    checkpoint's ``'config'`` when present, otherwise a default ``Config()``.

    Returns:
        ``(model, noise_scheduler)``, with the model moved to ``device`` and
        switched to eval mode.
    """
    # weights_only=False: the checkpoint may carry a pickled Config object,
    # which the weights-only loader (the default since PyTorch 2.6) rejects.
    # This unpickles arbitrary objects — only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    config = checkpoint['config'] if 'config' in checkpoint else Config()
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)

    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        saved_epoch = checkpoint.get('epoch', 'unknown')
        loss = checkpoint.get('loss', 'unknown')
        print(f"Loaded model from epoch {saved_epoch}, loss: {loss}")
    else:
        model.load_state_dict(checkpoint)
        print("Loaded model state dict")

    # Sample in inference mode so any dropout / normalization layers behave
    # deterministically; harmless if the sampler also calls eval().
    model.eval()
    return model, noise_scheduler


def _report_sample_stats(samples):
    """Print range, mean and std of ``samples`` plus a heuristic sanity verdict."""
    low = samples.min().item()
    high = samples.max().item()
    mean = samples.mean().item()
    std = samples.std().item()

    print("Sample statistics:")
    print(f" Range: [{low:.3f}, {high:.3f}]")
    print(f" Mean: {mean:.3f}")
    print(f" Std: {std:.3f}")

    # Heuristic thresholds only; not a quality metric.
    if std < 0.1:
        print("WARNING: Samples have very low variance - might be noise!")
    elif abs(mean) < 0.01 and std > 0.8:
        print("WARNING: Samples look like random noise!")
    else:
        print("Samples look reasonable!")


def test_latest_checkpoint():
    """Test the latest checkpoint with frequency-aware sampling.

    Picks the lexicographically last run directory under ``./logs``, loads its
    highest-epoch ``model_epoch_<N>.pth`` checkpoint, draws 8 samples with
    ``frequency_aware_sample``, saves the returned grid to
    ``test_samples_epoch_<N>_fixed.png`` in the working directory and prints
    summary statistics. Always returns ``None``; missing logs/checkpoints and
    sampling errors are reported on stdout.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    latest_log = _latest_log_dir('./logs')
    if latest_log is None:
        print("No log directories found!")
        return
    log_path = os.path.join('./logs', latest_log)
    print(f"Testing latest log directory: {log_path}")

    latest = _latest_checkpoint(log_path)
    if latest is None:
        print("No checkpoint files found!")
        return
    latest_epoch, latest_file = latest
    checkpoint_path = os.path.join(log_path, latest_file)
    print(f"Testing checkpoint: {latest_file} (epoch {latest_epoch})")

    model, noise_scheduler = _load_model(checkpoint_path, device)

    print("\n=== Generating samples with frequency-aware approach ===")
    try:
        # Expected to return (samples tensor, image grid tensor).
        samples, grid = frequency_aware_sample(model, noise_scheduler, device, n_samples=8)

        save_path = f"test_samples_epoch_{latest_epoch}_fixed.png"
        save_image(grid, save_path, normalize=False)
        print(f"Samples saved to: {save_path}")

        _report_sample_stats(samples)
    except Exception as e:
        # Top-level boundary of a diagnostic script: report and keep going.
        print(f"Error during sampling: {e}")
        import traceback
        traceback.print_exc()
|
|
|
|
|
# Script entry point: run the checkpoint smoke test only when executed
# directly, not when this module is imported (e.g. by a test collector).
if __name__ == "__main__":
    test_latest_checkpoint()
|
|
|