File size: 3,597 Bytes
8abfb97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torchvision
from torchvision.utils import save_image, make_grid
import os
from config import Config
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from sample import frequency_aware_sample

def _find_latest_log_dir(root='./logs'):
    """Return the path of the lexicographically greatest subdirectory of *root*.

    Returns ``None`` when *root* is missing, is not a directory, or has no
    subdirectories. "Latest" means greatest by name, so this assumes run
    directories are named so that they sort chronologically (e.g. timestamps)
    — TODO confirm against the training script's naming scheme.
    """
    if not os.path.isdir(root):
        return None
    subdirs = [name for name in os.listdir(root)
               if os.path.isdir(os.path.join(root, name))]
    if not subdirs:
        return None
    return os.path.join(root, max(subdirs))


def _find_latest_checkpoint(log_path):
    """Return ``(epoch, filename)`` of the highest-epoch checkpoint in *log_path*.

    Candidates are files named ``model_epoch_<N>...pth`` where ``<N>`` parses
    as an integer. Files whose epoch field is not an integer (for example
    ``model_epoch_final.pth``) are skipped rather than aborting the scan.
    Ties on epoch are broken by filename. Returns ``None`` if nothing matches.
    """
    candidates = []
    for filename in os.listdir(log_path):
        if not (filename.startswith('model_epoch_') and filename.endswith('.pth')):
            continue
        # 'model_epoch_12.pth' -> '12.pth' -> '12'; 'model_epoch_12_x.pth' -> '12'
        epoch_field = filename.split('_')[2].split('.')[0]
        try:
            epoch = int(epoch_field)
        except ValueError:
            print(f"Skipping checkpoint with non-numeric epoch: {filename}")
            continue
        candidates.append((epoch, filename))
    return max(candidates) if candidates else None


def _load_checkpoint(checkpoint_path, device):
    """Load the checkpoint at *checkpoint_path*, mapping tensors onto *device*.

    ``weights_only=False`` is passed explicitly because PyTorch 2.6 changed the
    default to ``True``, which refuses to unpickle arbitrary Python objects,
    such as a config object that may be stored under the ``'config'`` key.
    Full unpickling can execute arbitrary code, so only point this at
    checkpoints you produced yourself. Requires PyTorch >= 1.13.
    """
    return torch.load(checkpoint_path, map_location=device, weights_only=False)


def _build_model(checkpoint, device):
    """Construct the model and noise scheduler and load weights from *checkpoint*.

    *checkpoint* is either a training dict containing ``'model_state_dict'``
    (and optionally ``'config'``, ``'epoch'``, ``'loss'``) or a bare state_dict.
    Returns ``(model, noise_scheduler)`` with the model in eval mode.
    """
    # Prefer the config saved alongside the weights so the architecture matches.
    config = checkpoint['config'] if 'config' in checkpoint else Config()

    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)

    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        saved_epoch = checkpoint.get('epoch', 'unknown')
        saved_loss = checkpoint.get('loss', 'unknown')
        print(f"Loaded model from epoch {saved_epoch}, loss: {saved_loss}")
    else:
        # The file holds only the state_dict, not a training-checkpoint dict.
        model.load_state_dict(checkpoint)
        print("Loaded model state dict")

    # Inference mode: dropout / batch-norm layers (if any) must not behave as
    # in training while we sample.
    model.eval()
    return model, noise_scheduler


def _report_sample_stats(samples):
    """Print range/mean/std of *samples* plus a heuristic "looks like noise" verdict."""
    lo = samples.min().item()
    hi = samples.max().item()
    mean = samples.mean().item()
    std = samples.std().item()

    print("Sample statistics:")
    print(f"  Range: [{lo:.3f}, {hi:.3f}]")
    print(f"  Mean: {mean:.3f}")
    print(f"  Std: {std:.3f}")

    # Heuristics only: near-constant output, or zero-mean/unit-ish-variance
    # output that resembles the initial Gaussian noise.
    if std < 0.1:
        print("WARNING: Samples have very low variance - might be noise!")
    elif abs(mean) < 0.01 and std > 0.8:
        print("WARNING: Samples look like random noise!")
    else:
        print("Samples look reasonable!")


def test_latest_checkpoint():
    """Test the latest checkpoint with frequency-aware sampling.

    Picks the greatest-named directory under ``./logs``. Within it, picks the
    ``model_epoch_<N>...pth`` file with the highest ``N``. Then it loads the
    model and draws 8 samples with ``frequency_aware_sample``, saves the grid
    to ``test_samples_epoch_<N>_fixed.png`` in the current directory, and
    prints summary statistics. Progress and problems are reported via
    ``print``. Always returns ``None``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    log_path = _find_latest_log_dir()
    if log_path is None:
        print("No log directories found!")
        return
    print(f"Testing latest log directory: {log_path}")

    latest = _find_latest_checkpoint(log_path)
    if latest is None:
        print("No checkpoint files found!")
        return
    latest_epoch, latest_file = latest
    checkpoint_path = os.path.join(log_path, latest_file)
    print(f"Testing checkpoint: {latest_file} (epoch {latest_epoch})")

    checkpoint = _load_checkpoint(checkpoint_path, device)
    model, noise_scheduler = _build_model(checkpoint, device)

    print("\n=== Generating samples with frequency-aware approach ===")
    try:
        samples, grid = frequency_aware_sample(model, noise_scheduler, device, n_samples=8)

        save_path = f"test_samples_epoch_{latest_epoch}_fixed.png"
        save_image(grid, save_path, normalize=False)
        print(f"Samples saved to: {save_path}")

        _report_sample_stats(samples)
    except Exception as e:
        # Top-level boundary of a diagnostic script: report the failure with a
        # traceback instead of crashing.
        print(f"Error during sampling: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    # Script entry point: run the checkpoint smoke test only when this file is
    # executed directly, not when it is imported.
    test_latest_checkpoint()