|
import torch |
|
import torchvision |
|
from torchvision.utils import save_image, make_grid |
|
import os |
|
import argparse |
|
from datetime import datetime |
|
from config import Config |
|
from model import SmoothDiffusionUNet |
|
from noise_scheduler_simple import FrequencyAwareNoise |
|
from sample_simple import simple_sample |
|
|
|
def load_model(checkpoint_path, device):
    """Load a model and its noise scheduler from a checkpoint file.

    Two checkpoint layouts are accepted:

    * a mapping containing ``'model_state_dict'`` (optionally alongside
      ``'config'``, ``'epoch'`` and ``'loss'`` entries), or
    * a bare state dict, which is passed to ``load_state_dict`` unchanged.

    Args:
        checkpoint_path: Path of the checkpoint file, handed to ``torch.load``.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        A ``(model, noise_scheduler, config)`` tuple. The model is returned
        in evaluation mode because this script only uses it for sampling.
    """
    print(f"Loading model from: {checkpoint_path}")

    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be loaded here. On PyTorch releases where ``weights_only``
    # defaults to True, a checkpoint that stores a pickled config object may
    # be rejected -- confirm against the PyTorch version in use.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Use the config stored in the checkpoint when there is one; otherwise
    # fall back to a default-constructed Config.
    if 'config' in checkpoint:
        config = checkpoint['config']
    else:
        config = Config()

    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)

    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        epoch = checkpoint.get('epoch', 'unknown')
        loss = checkpoint.get('loss', 'unknown')
        print(f"Loaded model from epoch {epoch}, loss: {loss}")
    else:
        # No wrapper keys: treat the loaded object itself as the state dict.
        model.load_state_dict(checkpoint)
        print("Loaded model state dict")

    # A freshly constructed nn.Module is in training mode; switch to eval so
    # layers such as dropout / batch norm (if the model has any) behave
    # deterministically during sampling.
    model.eval()

    return model, noise_scheduler, config
|
|
|
def test_checkpoint(checkpoint_path, device, n_samples=16):
    """Sample from a single checkpoint and save the resulting image grid.

    The checkpoint is loaded with ``load_model``, ``simple_sample`` is asked
    for ``n_samples`` samples, and the grid it returns is written to a
    timestamped ``test_samples_simple_<YYYYmmdd_HHMMSS>.png`` file (a relative
    path, so it lands in the current working directory).

    Args:
        checkpoint_path: Path of the checkpoint to test.
        device: Device used for loading the model and for sampling.
        n_samples: Number of samples requested from the sampler.

    Returns:
        The ``(samples, grid)`` pair produced by ``simple_sample``.
    """
    # The config returned by load_model is not needed for sampling here.
    model, noise_scheduler, _ = load_model(checkpoint_path, device)

    # Second-resolution timestamp keeps successive runs from overwriting
    # each other's output.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = f"test_samples_simple_{timestamp}.png"

    print(f"Testing checkpoint with {n_samples} samples...")
    samples, grid = simple_sample(
        model,
        noise_scheduler,
        device,
        n_samples=n_samples,
    )

    # The grid is saved without normalization (normalize=False).
    save_image(grid, save_path, normalize=False)
    print(f"Samples saved to: {save_path}")

    return samples, grid
|
|
|
def main():
    """Command-line entry point.

    Parses ``--checkpoint`` (required), ``--n_samples`` (default 16) and
    ``--device`` (default ``'auto'``), resolves the torch device, and runs
    ``test_checkpoint`` on the given checkpoint.
    """
    cli = argparse.ArgumentParser(
        description='Test trained diffusion model (simple version)',
    )
    cli.add_argument(
        '--checkpoint', type=str, required=True,
        help='Path to checkpoint file',
    )
    cli.add_argument(
        '--n_samples', type=int, default=16,
        help='Number of samples to generate',
    )
    cli.add_argument(
        '--device', type=str, default='auto',
        help='Device to use (cuda/cpu/auto)',
    )
    args = cli.parse_args()

    # An explicit device string is used as given; 'auto' picks CUDA when it
    # is available and falls back to the CPU otherwise.
    if args.device != 'auto':
        device = torch.device(args.device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")

    print("=== Testing Checkpoint with Simple DDPM ===")
    test_checkpoint(args.checkpoint, device, args.n_samples)


# Run the CLI only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|