# Inference script for the trained diffusion model
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import matplotlib.pyplot as plt | |
from tqdm import tqdm | |
import math | |
# [Paste the model and scheduler class definitions here - TimeEmbedding, ResidualBlock, SimpleUNet, DDPMScheduler, etc. load_model() below requires SimpleUNet and DDPMScheduler.]
def load_model(checkpoint_path, device='cuda'):
    """Load the trained diffusion model and its noise scheduler from a checkpoint.

    The loaded checkpoint is indexed by string keys and must contain:

    * ``'model_config'``      - keyword arguments for ``SimpleUNet``;
    * ``'model_state_dict'``  - state dict passed to ``load_state_dict``;
    * ``'diffusion_config'``  - keyword arguments for ``DDPMScheduler``;
    * ``'model_info'``        - returned to the caller unchanged.

    Args:
        checkpoint_path: File passed to ``torch.load``.
        device: Used as ``map_location`` when loading, as the target of
            ``model.to``, and as the ``device`` argument of the scheduler.
            Defaults to ``'cuda'``.

    Returns:
        A ``(model, scheduler, model_info)`` tuple; ``model`` has already been
        moved to ``device`` and switched to evaluation mode.

    Raises:
        KeyError: If any of the keys listed above is missing from the checkpoint.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources. On PyTorch >= 2.6 torch.load defaults to
    # weights_only=True; confirm the saved configs / model_info load under it.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Fail fast, before building the network, and report every missing key at
    # once instead of crashing on whichever lookup happens to come first.
    required_keys = ('model_config', 'model_state_dict', 'diffusion_config', 'model_info')
    missing = [key for key in required_keys if key not in checkpoint]
    if missing:
        raise KeyError(
            f"checkpoint {checkpoint_path!r} is missing required keys: {missing}"
        )

    # Rebuild the network from the saved constructor arguments, then restore weights.
    model = SimpleUNet(**checkpoint['model_config'])
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    # Inference only: put the module in eval mode (affects layers such as
    # dropout / batch-norm, if the model has any).
    model.eval()

    # Recreate the scheduler from the saved diffusion settings on the same device.
    scheduler = DDPMScheduler(**checkpoint['diffusion_config'], device=device)
    return model, scheduler, checkpoint['model_info']
# Usage example (a generate_images() sampling function must also be defined):
# model, scheduler, info = load_model('complete_diffusion_model.pth')
# generated_images = generate_images(model, scheduler, num_images=4)