# Inference script for the trained diffusion model
import math
import warnings

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# [Copy all the model architecture classes here - TimeEmbedding, ResidualBlock,
#  SimpleUNet, DDPMScheduler, etc. `load_model` below needs SimpleUNet and
#  DDPMScheduler; the usage example at the bottom also needs a
#  `generate_images` helper.]


def load_model(checkpoint_path, device='cuda'):
    """Load the trained diffusion model and its noise scheduler.

    Args:
        checkpoint_path: Path to a checkpoint saved as a dict with the keys
            ``'model_config'``, ``'model_state_dict'``, ``'diffusion_config'``
            and, optionally, ``'model_info'``.
        device: Device to load onto (e.g. ``'cuda'``, ``'cuda:0'``, ``'cpu'``).
            If a CUDA device is requested but CUDA is not available, loading
            falls back to ``'cpu'`` with a ``RuntimeWarning`` instead of
            failing inside ``torch.load``.

    Returns:
        A ``(model, scheduler, model_info)`` tuple:
        - the ``SimpleUNet`` built from ``'model_config'``, with the saved
          weights loaded, moved to ``device`` and put in eval mode;
        - a ``DDPMScheduler`` built from ``'diffusion_config'`` on ``device``;
        - the checkpoint's ``'model_info'`` entry, or ``{}`` if it is absent.

    Raises:
        KeyError: If ``'model_config'``, ``'model_state_dict'`` or
            ``'diffusion_config'`` is missing from the checkpoint.
    """
    # Without this check, torch.load raises when asked to map tensors onto
    # CUDA on a CPU-only machine; degrade gracefully to CPU instead.
    if torch.device(device).type == 'cuda' and not torch.cuda.is_available():
        warnings.warn(
            f"CUDA requested (device={device!r}) but not available; "
            "loading on CPU instead.",
            RuntimeWarning,
            stacklevel=2,
        )
        device = 'cpu'

    # NOTE: torch.load unpickles the file and can execute arbitrary code;
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Initialize model with saved config
    model = SimpleUNet(**checkpoint['model_config'])
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()  # inference only: disable dropout / use running norm stats

    # Initialize scheduler. Build a new dict so the runtime device overrides
    # any 'device' entry stored in the saved config (passing both as keyword
    # arguments would raise "got multiple values for keyword argument").
    diffusion_config = {**checkpoint['diffusion_config'], 'device': device}
    scheduler = DDPMScheduler(**diffusion_config)

    # 'model_info' is descriptive metadata; its absence should not prevent
    # the model from being used for inference.
    model_info = checkpoint.get('model_info', {})

    return model, scheduler, model_info


# Usage example:
# model, scheduler, info = load_model('complete_diffusion_model.pth')
# generated_images = generate_images(model, scheduler, num_images=4)