# bitskip-v1-earlyexit

BitSkip v1 with 8-bit activation quantization and ternary weights (no Hadamard transform).

## Model Description

This model implements a 24-layer transformer with early exit loss and quadratic layer dropout for efficient inference. It was trained on the TinyStories dataset with layer-wise auxiliary supervision to enable flexible speed-quality tradeoffs during inference.

## Architecture Details

- Layers: 24
- Hidden dimension: 2048
- Attention heads: 32 (64-dimensional each)
- Key-Value heads: 8 (Grouped Query Attention with 4:1 ratio)
- FFN intermediate size: 4096
- Position embeddings: Rotary Position Embeddings (RoPE)
- Normalization: RMSNorm
- Activation: SwiGLU (for MLP)
- Parameters: ~1.06B (see the count sketch after this list)
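
As a sanity check on the ~1.06B figure, the parameter count can be reproduced from the dimensions above. The sketch below assumes bias-free projections, an untied LM head, and a three-matrix SwiGLU block; these are assumptions, since the card does not state them explicitly.

```python
# Rough parameter count for the configuration listed above (sanity-check sketch).
d_model, n_layers, vocab = 2048, 24, 50257
n_heads, n_kv_heads, head_dim, d_ffn = 32, 8, 64, 4096

attn = d_model * (n_heads * head_dim)            # q_proj
attn += 2 * d_model * (n_kv_heads * head_dim)    # k_proj + v_proj (GQA)
attn += (n_heads * head_dim) * d_model           # o_proj

mlp = 3 * d_model * d_ffn                        # gate, up, down (SwiGLU)

total = n_layers * (attn + mlp) + 2 * vocab * d_model  # + embeddings + LM head
print(f"~{total / 1e9:.2f}B parameters")         # ~1.06B
```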

## Quantization Scheme

- Weights: ternary {-1, 0, +1}
- Activations: 8-bit quantization
- Hadamard transform: not used
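
A common recipe for this weight/activation combination is BitNet b1.58-style absmean weight quantization paired with per-token absmax 8-bit activation quantization. The sketch below illustrates that recipe; it is an assumption and may not match this checkpoint's exact scheme.

```python
import torch

def quantize_weights_ternary(w: torch.Tensor, eps: float = 1e-5):
    # Absmean scaling, then round-and-clip to {-1, 0, +1} (BitNet b1.58 style).
    scale = w.abs().mean().clamp(min=eps)
    w_q = (w / scale).round().clamp(-1, 1)
    return w_q, scale  # dequantize as w_q * scale

def quantize_activations_int8(x: torch.Tensor, eps: float = 1e-5):
    # Per-token absmax scaling into the signed 8-bit range [-127, 127].
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps) / 127.0
    x_q = (x / scale).round().clamp(-127, 127)
    return x_q, scale
```

During quantization-aware training these steps are typically applied with a straight-through estimator (e.g. `w + (w_q * scale - w).detach()`) so gradients flow to the underlying full-precision weights; the Limitations section below notes that this "fake quantization" is also what the released model uses at inference.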

## Training Details

### Dataset

- Source: TinyStories (2.1M stories)
- Tokenizer: GPT-2 BPE (vocab size: 50,257)
- Sequence length: 512 tokens
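
A minimal preprocessing sketch under these settings is shown below; the Hub dataset ID and the `text` column name are assumptions, not taken from the card.

```python
from datasets import load_dataset
from transformers import AutoTokenizer

dataset = load_dataset("roneneldan/TinyStories", split="train")  # assumed dataset ID
tokenizer = AutoTokenizer.from_pretrained("gpt2")                # BPE, vocab size 50,257

def tokenize(batch):
    # Truncate each story to the 512-token training sequence length.
    return tokenizer(batch["text"], truncation=True, max_length=512)

tokenized = dataset.map(tokenize, batched=True, remove_columns=["text"])
```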

### Training Techniques

**Quadratic Layer Dropout**

- Progressive dropout: p_l = 0.5 × (l/L)²
- Normalized so Σp_l = 1.0
- Never drops final layer
- Later layers are dropped more often, so earlier layers learn representations that remain useful on their own (see the sketch below)
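
A literal reading of the schedule above gives the helper below; how the normalization and the final-layer exemption interact in the actual training code is an assumption.

```python
def layer_dropout_schedule(num_layers: int = 24, scale: float = 0.5):
    # Quadratic per-layer skip rates: 0.5 * (l / L)^2, final layer never dropped,
    # then rescaled so the probabilities sum to 1 (as described in the card).
    raw = [scale * (l / num_layers) ** 2 for l in range(1, num_layers + 1)]
    raw[-1] = 0.0                       # always execute the final layer
    total = sum(raw)
    return [p / total for p in raw]
```

During training, block `l` is then skipped for a given batch with probability `p[l]`, which pushes predictive capacity toward the earlier layers that early exit relies on.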

**Early Exit Loss**

- All layers share the same LM head
- Loss = main_loss + 0.3 × early_exit_loss
- Layer-proportional weighting: w_i = (i+1)/L
- Enables flexible early exit at inference
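
A sketch of the combined objective follows. The function name, the shift handling of `labels`, and whether the auxiliary sum also covers the final layer are illustrative assumptions; only the 0.3 coefficient, the shared LM head, and the w_i = (i+1)/L weighting come from the card.

```python
import torch.nn.functional as F

def combined_loss(hidden_states, lm_head, labels, alpha: float = 0.3):
    # hidden_states: list of per-layer outputs (B, T, D); labels: (B, T), already
    # shifted for next-token prediction. Every layer reuses the same lm_head.
    num_layers = len(hidden_states)
    main = F.cross_entropy(lm_head(hidden_states[-1]).flatten(0, 1), labels.flatten())
    aux = 0.0
    for i, h in enumerate(hidden_states[:-1]):
        w = (i + 1) / num_layers  # layer-proportional weighting
        aux = aux + w * F.cross_entropy(lm_head(h).flatten(0, 1), labels.flatten())
    return main + alpha * aux
```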

### Hyperparameters

- Optimizer: AdamW
- Learning rate: 6e-4
- Warmup steps: 1000
- Batch size: 16 (effective: 64)
- Training steps: 50000
- Gradient clipping: 1.0
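
Put together, these hyperparameters correspond to a loop along the lines of the sketch below. The gradient-accumulation factor of 4 (batch 16 -> effective 64), the cosine decay after warmup, and the weight-decay value are assumptions; `model` and `dataloader` are placeholders.

```python
import torch
from transformers import get_cosine_schedule_with_warmup

def train_sketch(model, dataloader, accum_steps: int = 4):
    # accum_steps = 4 is an assumption (batch size 16 -> effective 64).
    optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4, weight_decay=0.1)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=1000, num_training_steps=50_000
    )
    for step, batch in enumerate(dataloader):
        loss = model(**batch).loss / accum_steps  # batches include labels
        loss.backward()
        if (step + 1) % accum_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
```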

## Performance

### Perplexity (TinyStories validation)

| Exit Layer | Perplexity | Speed (tok/s) |
|---|---|---|
| All layers | TBD | TBD |
| Layer 18 | TBD | TBD |
| Layer 12 | TBD | TBD |
| Layer 6 | TBD | TBD |
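
Once measured, the perplexity at each exit depth can be computed with a loop like the one below (a sketch; it assumes fixed-length 512-token validation batches that include labels, so averaging per-batch losses approximates the per-token average).

```python
import math, torch

@torch.no_grad()
def validation_perplexity(model, dataloader, exit_layer=None):
    if exit_layer is not None:
        model.set_exit_layer(exit_layer)   # early-exit API from the Usage section
    model.eval()
    losses = [model(**batch).loss.item() for batch in dataloader]
    return math.exp(sum(losses) / len(losses))
```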

### Training Stability

- Gradient norms: 2-5
- Final loss: TBD

## Usage

### Installation

```bash
pip install transformers torch
```

### Basic Inference

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained("your-username/bitskip-v1-earlyexit")
tokenizer = AutoTokenizer.from_pretrained("your-username/bitskip-v1-earlyexit")

# Generate text
inputs = tokenizer("Once upon a time", return_tensors="pt")
outputs = model.generate(**inputs, max_length=100)
print(tokenizer.decode(outputs[0]))
```

### Early Exit Inference

```python
# Exit at layer 12 for faster inference
model.set_exit_layer(12)
outputs = model.generate(**inputs, max_length=100)
# Roughly 1.5-2x faster than the full 24-layer forward pass, with minimal quality loss
```

### Benchmark Different Exit Layers

```python
import time

# Compare output quality and rough decoding speed at several exit depths
for exit_layer in [6, 12, 18, 24]:
    model.set_exit_layer(exit_layer)
    start = time.time()
    outputs = model.generate(**inputs, max_length=100)
    speed = (outputs.shape[1] - inputs["input_ids"].shape[1]) / (time.time() - start)
    print(f"Layer {exit_layer} ({speed:.1f} tok/s): {tokenizer.decode(outputs[0])}")
```

## Limitations

- Inference speed: quantization is simulated (fake-quantization QAT) without specialized low-bit kernels, so inference is slower than the full-precision baseline despite the lower bit-width
- Training instability: 4-bit models (v2) exhibit gradient explosion (norms 50-110), requiring careful hyperparameter tuning
- Dataset scope: Trained only on TinyStories; may not generalize to other domains without fine-tuning

## Citation

If you use this model, please cite:

```bibtex
@article{bitnet,
  title   = {BitNet: Scaling 1-bit Transformers for Large Language Models},
  author  = {Wang, Hongyu and Ma, Shuming and Dong, Li and others},
  journal = {arXiv preprint arXiv:2310.11453},
  year    = {2023}
}

@article{layerskip,
  title   = {LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding},
  author  = {Elhoushi, Mostafa and Shrivastava, Akshat and Liskovich, Diana and others},
  journal = {arXiv preprint arXiv:2404.16710},
  year    = {2024}
}
```

## License

MIT License

## Contact

For questions or issues, please open an issue on the model repository.