# CIFAR-10 Diffusion Model

A lightweight diffusion model trained from scratch on the CIFAR-10 dataset in just 14.5 minutes using PyTorch.

## Model Description

This is a **SimpleUNet-based diffusion model** trained to generate 32×32 RGB images similar to the CIFAR-10 dataset. The model demonstrates the fundamentals of diffusion-based image generation with a compact architecture suitable for educational purposes and quick experimentation.

### Key Features

- **Fast Training**: Complete training in under 15 minutes on an RTX 3060
- **Lightweight**: Only 16.8M parameters (~64 MB model size)
- **Educational**: Clean, well-documented code for learning diffusion models
- **Efficient Inference**: Generate images in seconds on consumer GPUs

## Model Details

| Attribute | Value |
|-----------|-------|
| **Architecture** | SimpleUNet with ResNet blocks + Attention |
| **Parameters** | 16,808,835 |
| **Dataset** | CIFAR-10 (50,000 training images) |
| **Image Size** | 32×32 RGB |
| **Training Steps** | 7,820 (20 epochs × 391 batches) |
| **Training Time** | 14.54 minutes |
| **Hardware** | NVIDIA RTX 3060 (0.43 GB VRAM used) |
| **Framework** | PyTorch 2.0+ |
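
As a sanity check, the parameter count above can be recomputed from the released checkpoint. The snippet below is a minimal sketch that assumes the checkpoint layout used in Quick Start (a `model_state_dict` entry); any buffers stored in the state dict would be counted as well.

```python
import torch

# Minimal sketch: recompute the parameter count from the saved weights.
# Assumes the checkpoint stores a 'model_state_dict' entry as shown in Quick Start;
# any registered buffers in the state_dict are counted too.
checkpoint = torch.load('complete_diffusion_model.pth', map_location='cpu')
num_params = sum(p.numel() for p in checkpoint['model_state_dict'].values())
print(f"Parameters: {num_params:,}")  # expected: 16,808,835
```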

## Quick Start

### Installation

```bash
pip install torch torchvision matplotlib tqdm pillow numpy
```

### Basic Usage

```python
import torch
import matplotlib.pyplot as plt

# SimpleUNet and DDPMScheduler are defined in inference_example.py (see Files below)

# Load the checkpoint (model config, diffusion config, and weights)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
checkpoint = torch.load('complete_diffusion_model.pth', map_location=device)

# Rebuild the model from its saved config and load the weights
model = SimpleUNet(**checkpoint['model_config'])
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

# Initialize scheduler
scheduler = DDPMScheduler(**checkpoint['diffusion_config'])

# Generate images with a simplified deterministic sampler
@torch.no_grad()
def generate_images(model, scheduler, num_images=4):
    device = next(model.parameters()).device
    images = torch.randn(num_images, 3, 32, 32, device=device)  # start from pure noise

    for t in range(999, -1, -20):  # 50 denoising steps
        timestep = torch.full((num_images,), t, device=device)
        noise_pred = model(images, timestep)

        # Simplified DDPM step using the cumulative alphas
        alpha_t = scheduler.alpha_cumprod[t]
        alpha_prev = scheduler.alpha_cumprod[t - 20] if t >= 20 else torch.tensor(1.0, device=device)

        # Estimate the clean image, then step to the previous (less noisy) timestep
        pred_x0 = (images - torch.sqrt(1 - alpha_t) * noise_pred) / torch.sqrt(alpha_t)
        images = torch.sqrt(alpha_prev) * pred_x0 + torch.sqrt(1 - alpha_prev) * noise_pred

    return images

# Generate a batch of images (see below for displaying them)
generated = generate_images(model, scheduler)
```
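
The tensors returned by `generate_images` are in the model's training range. Assuming the usual DDPM convention of images normalized to [-1, 1] (not confirmed by this repository; adjust if the training pipeline used a different normalization), they can be displayed like this:

```python
# Map from the assumed [-1, 1] training range back to [0, 1] for display
imgs = (generated.clamp(-1, 1) + 1) / 2
imgs = imgs.cpu().permute(0, 2, 3, 1).numpy()  # NCHW -> NHWC for matplotlib

fig, axes = plt.subplots(1, len(imgs), figsize=(3 * len(imgs), 3))
for ax, img in zip(axes, imgs):
    ax.imshow(img)
    ax.axis('off')
plt.tight_layout()
plt.show()
```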

## Training Details

- **Loss Function**: MSE between predicted and actual noise
- **Optimizer**: AdamW (lr=1e-4, weight_decay=1e-6)
- **Scheduler**: CosineAnnealingLR
- **Batch Size**: 128
- **Final Loss**: 0.0363 (73% reduction from the initial loss)
- **Diffusion Steps**: 1000 (linear beta schedule); see the training-step sketch after this list
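
As referenced above, here is a minimal sketch of one training step under this setup. It is illustrative rather than the repository's exact code: `model` is assumed to be the SimpleUNet, batches are assumed normalized to [-1, 1], and the 1e-4 to 0.02 beta endpoints are the common DDPM defaults rather than values confirmed here.

```python
import torch
import torch.nn.functional as F

# Linear beta schedule; endpoints are standard DDPM defaults (assumed, not confirmed)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_cumprod = torch.cumprod(1.0 - betas, dim=0)

num_epochs = 20
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)  # model: the SimpleUNet
lr_schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)  # call .step() once per epoch

def training_step(x0):
    """One DDPM training step on a batch x0 of clean images (assumed in [-1, 1])."""
    b = x0.shape[0]
    t = torch.randint(0, T, (b,), device=x0.device)              # random timestep per image
    noise = torch.randn_like(x0)                                  # the noise the model must predict
    a = alpha_cumprod.to(x0.device)[t].view(b, 1, 1, 1)
    x_t = torch.sqrt(a) * x0 + torch.sqrt(1 - a) * noise          # forward (noising) process
    loss = F.mse_loss(model(x_t, t), noise)                       # MSE between predicted and actual noise
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```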

## Performance

### Training Loss Curve

The training loss decreases steadily:

- **Epoch 1**: 0.1349 → **Epoch 20**: 0.0363
- **Best Loss**: 0.0358 (Epoch 19)
- Stable convergence without overfitting

### Generation Quality

- Captures CIFAR-10 color distributions
- Generates diverse, non-repetitive outputs
- Produces abstract patterns rather than recognizable objects (would need longer training)
- Suitable for color/texture generation tasks

## Files in this Repository

| File | Description | Size |
|------|-------------|------|
| `complete_diffusion_model.pth` | Full model with config and weights | ~64 MB |
| `diffusion_model_final.pth` | Training checkpoint (epoch 20) | ~64 MB |
| `model_info.json` | Training metadata and hyperparameters | <1 KB |
| `inference_example.py` | Complete inference script with model classes | ~5 KB |

## Model Architecture

```
SimpleUNet(
  time_embedding: TimeEmbedding(128)
  encoder: 3 ResNet blocks with downsampling
  middle: ResNet + Self-Attention + ResNet
  decoder: 3 ResNet blocks with upsampling
  output: GroupNorm → SiLU → Conv2d
)
```
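
The full layer definitions live in `inference_example.py`. For orientation, the sketch below shows what a sinusoidal `TimeEmbedding(128)` module typically looks like; the actual class in this repository may differ in details such as projection sizes or activations.

```python
import math
import torch
import torch.nn as nn

class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding followed by a small MLP (illustrative sketch)."""
    def __init__(self, dim=128):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(nn.Linear(dim, dim * 4), nn.SiLU(), nn.Linear(dim * 4, dim * 4))

    def forward(self, t):
        # Build sin/cos features at geometrically spaced frequencies, then project
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / half)
        angles = t[:, None].float() * freqs[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (B, dim)
        return self.mlp(emb)
```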

## Use Cases

- **Educational**: Learn diffusion model fundamentals
- **Research**: Baseline for diffusion experiments
- **Art**: Generate abstract textures and patterns
- **Prototyping**: Quick diffusion model testing

## Limitations & Improvements

### Current Limitations

- Generates abstract patterns rather than recognizable objects
- Trained at a small 32×32 resolution
- Limited to 20 training epochs

### Suggested Improvements

1. **Extended Training**: 50-100 epochs for better object generation
2. **Larger Architecture**: Increase model capacity
3. **Advanced Sampling**: Implement DDIM or DPM-Solver++ (see the sampling sketch after this list)
4. **Higher Resolution**: Train on 64×64 or 128×128 images
5. **Better Datasets**: Use CelebA-HQ or custom datasets
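
For improvement 3: the Quick Start sampler is essentially DDIM with η = 0. A generalized DDIM update, which adds an η-controlled noise term, could look like the sketch below (illustrative only, not this repository's code; `alpha_t` and `alpha_prev` are cumulative ᾱ values taken from the scheduler).

```python
import torch

def ddim_step(x_t, noise_pred, alpha_t, alpha_prev, eta=0.0):
    """One DDIM update from timestep t to the previous timestep.

    eta=0 gives the deterministic sampler used in Quick Start;
    eta=1 approaches stochastic DDPM sampling.
    """
    pred_x0 = (x_t - torch.sqrt(1 - alpha_t) * noise_pred) / torch.sqrt(alpha_t)
    sigma = eta * torch.sqrt((1 - alpha_prev) / (1 - alpha_t)) * torch.sqrt(1 - alpha_t / alpha_prev)
    dir_xt = torch.sqrt(torch.clamp(1 - alpha_prev - sigma**2, min=0.0)) * noise_pred
    return torch.sqrt(alpha_prev) * pred_x0 + dir_xt + sigma * torch.randn_like(x_t)
```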

## Citation

```bibtex
@misc{cifar10-diffusion-2025,
  title={CIFAR-10 Diffusion Model: Fast Training Implementation},
  author={Karthik},
  year={2025},
  publisher={Hugging Face},
  howpublished={\url{https://huggingface.co/karthik-2905/DiffusionPretrained}}
}
```

## License

MIT License - Free for research and commercial use.

---

**Want to train your own?** Check out the [full implementation](https://github.com/GruheshKurra/DiffusionModelPretrained) with Jupyter notebooks and step-by-step training code!

**Training Stats**: 16.8M params • 14.5 min training • RTX 3060 • PyTorch 2.0