bitskip-v1-earlyexit

BitSkip v1 with 8-bit activation quantization and ternary weights (no Hadamard transform)

Model Description

This model implements a 24-layer transformer with early exit loss and quadratic layer dropout for efficient inference. It was trained on the TinyStories dataset with layer-wise auxiliary supervision to enable flexible speed-quality tradeoffs during inference.

Architecture Details

  • Layers: 24
  • Hidden dimension: 2048
  • Attention heads: 32 (64-dimensional each)
  • Key-Value heads: 8 (Grouped Query Attention with 4:1 ratio)
  • FFN intermediate size: 4096
  • Position embeddings: Rotary Position Embeddings (RoPE)
  • Normalization: RMSNorm
  • Activation: SwiGLU (for MLP)
  • Parameters: ~1.06B
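
For reference, the architecture above corresponds roughly to the following Llama-style configuration (a minimal sketch using Hugging Face's LlamaConfig; the model class and values such as rms_norm_eps are assumptions, not this repository's code):

from transformers import LlamaConfig

# Hypothetical mapping of the architecture listed above onto a Llama-style config.
config = LlamaConfig(
    vocab_size=50257,             # GPT-2 BPE tokenizer
    hidden_size=2048,
    intermediate_size=4096,       # SwiGLU FFN
    num_hidden_layers=24,
    num_attention_heads=32,       # 64-dimensional heads
    num_key_value_heads=8,        # Grouped Query Attention, 4:1 ratio
    max_position_embeddings=512,  # matches the training sequence length
    rms_norm_eps=1e-5,            # assumed; not stated on this card
    hidden_act="silu",            # SwiGLU gate activation
)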

Quantization Scheme

  • Weights: Ternary {-1, 0, 1}
  • Activations: 8-bit quantization
  • Hadamard transform: not applied
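
As a rough illustration, the scheme above can be expressed with a BitNet-style absmean ternary weight quantizer and a per-token absmax 8-bit activation quantizer (a sketch of the general technique, not the exact quantizer used in training):

import torch

def quantize_weights_ternary(w: torch.Tensor):
    # Absmean quantization: scale by the mean |w|, then round to {-1, 0, +1}.
    scale = w.abs().mean().clamp(min=1e-5)
    w_q = (w / scale).round().clamp(-1, 1)
    return w_q, scale  # dequantize as w_q * scale

def quantize_activations_int8(x: torch.Tensor):
    # Per-token absmax quantization into the signed 8-bit range [-127, 127].
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 127.0
    x_q = (x / scale).round().clamp(-127, 127)
    return x_q, scale  # dequantize as x_q * scale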

Training Details

Dataset

  • Source: TinyStories (2.1M stories)
  • Tokenizer: GPT-2 BPE (vocab size: 50,257)
  • Sequence length: 512 tokens

Training Techniques

Quadratic Layer Dropout:

  • Progressive dropout: p_l = 0.5 × (l/L)²
  • Normalized so Σp_l = 1.0
  • Never drops final layer
  • Trains earlier layers to produce representations that remain useful when later layers are skipped (see the schedule sketch below)
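
A sketch of one reading of this schedule (the exact indexing and normalization used in training may differ):

import torch

def layer_dropout_probs(num_layers: int = 24) -> torch.Tensor:
    # Quadratic schedule: deeper layers are dropped more often.
    l = torch.arange(1, num_layers + 1, dtype=torch.float32)
    p = 0.5 * (l / num_layers) ** 2
    p[-1] = 0.0      # the final layer is never dropped
    p = p / p.sum()  # normalize so the probabilities sum to 1.0
    return p

# probs[l - 1] is the probability of skipping layer l during training.
probs = layer_dropout_probs(24)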

Early Exit Loss:

  • All layers share the same LM head
  • Loss = main_loss + 0.3 × early_exit_loss
  • Layer-proportional weighting: w_i = (i+1)/L
  • Enables flexible early exit at inference time (see the loss sketch below)
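
A sketch of how the combined objective can be computed with a single shared LM head (variable names are illustrative; whether the final layer is included in the auxiliary term is an assumption):

import torch.nn.functional as F

def combined_loss(hidden_states, lm_head, labels, num_layers=24, aux_weight=0.3):
    # hidden_states: list of per-layer hidden states (length num_layers)
    # lm_head: the single output projection shared by all layers
    main_logits = lm_head(hidden_states[-1])
    main_loss = F.cross_entropy(
        main_logits.view(-1, main_logits.size(-1)), labels.view(-1)
    )

    # Auxiliary early-exit losses with layer-proportional weights w_i = (i + 1) / L.
    early_exit_loss = 0.0
    for i, h in enumerate(hidden_states[:-1]):
        logits = lm_head(h)
        w = (i + 1) / num_layers
        early_exit_loss = early_exit_loss + w * F.cross_entropy(
            logits.view(-1, logits.size(-1)), labels.view(-1)
        )

    return main_loss + aux_weight * early_exit_loss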

Hyperparameters

  • Optimizer: AdamW
  • Learning rate: 6e-4
  • Warmup steps: 1000
  • Batch size: 16 (effective: 64)
  • Training steps: 50000
  • Gradient clipping: 1.0
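
An equivalent optimizer setup in PyTorch would look roughly like this (a sketch; the post-warmup learning-rate schedule is not stated on this card, so a constant rate after linear warmup is assumed):

import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(8, 8)  # stand-in for the actual model
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)

# Linear warmup over the first 1000 steps, then a constant rate (assumed).
warmup_steps = 1000
scheduler = LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / warmup_steps))

for step in range(50_000):
    loss = model(torch.randn(16, 8)).pow(2).mean()  # dummy loss for illustration
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping at 1.0
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()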

Performance

Perplexity (TinyStories validation)

| Exit Layer      | Perplexity | Speed (tok/s) |
|-----------------|------------|---------------|
| All layers (24) | TBD        | TBD           |
| Layer 18        | TBD        | TBD           |
| Layer 12        | TBD        | TBD           |
| Layer 6         | TBD        | TBD           |
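
The table can be filled in with a loop such as the following (a sketch; it assumes the set_exit_layer method shown under Usage and a list of validation strings named val_texts, and it measures forward-pass throughput rather than generation speed):

import math
import time
import torch

@torch.no_grad()
def eval_exit_layer(model, tokenizer, val_texts, exit_layer):
    model.set_exit_layer(exit_layer)  # early-exit hook described under Usage
    total_loss, total_tokens, start = 0.0, 0, time.time()
    for text in val_texts:
        enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        out = model(**enc, labels=enc["input_ids"])
        n = enc["input_ids"].numel()
        total_loss += out.loss.item() * n
        total_tokens += n
    perplexity = math.exp(total_loss / total_tokens)
    tokens_per_second = total_tokens / (time.time() - start)
    return perplexity, tokens_per_second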

Training Stability

  • Gradient norms: stayed in the 2-5 range
  • Final loss: TBD

Usage

Installation

pip install transformers torch

Basic Inference

from transformers import AutoTokenizer, AutoModelForCausalLM

# Load model
model = AutoModelForCausalLM.from_pretrained("your-username/bitskip-v1-earlyexit")
tokenizer = AutoTokenizer.from_pretrained("your-username/bitskip-v1-earlyexit")

# Generate text
inputs = tokenizer("Once upon a time", return_tensors="pt")
outputs = model.generate(**inputs, max_length=100)
print(tokenizer.decode(outputs[0]))

Early Exit Inference

# Exit at layer 12 for faster inference
model.set_exit_layer(12)
outputs = model.generate(**inputs, max_length=100)
# 1.5-2x faster with minimal quality loss

Benchmark Different Exit Layers

for exit_layer in [6, 12, 18, 24]:
    model.set_exit_layer(exit_layer)
    outputs = model.generate(**inputs, max_length=100)
    print(f"Layer {exit_layer}: {tokenizer.decode(outputs[0])}")

Limitations

  • Inference speed: The model uses fake quantization from quantization-aware training (QAT) without specialized low-bit kernels, so inference is currently slower than full precision despite the lower bit-width
  • Training instability: 4-bit models (v2) exhibit gradient explosion (norms 50-110) requiring careful hyperparameter tuning
  • Dataset scope: Trained only on TinyStories; may not generalize to other domains without fine-tuning

Citation

If you use this model, please cite:

@article{bitnet,
  title={BitNet: Scaling 1-bit Transformers for Large Language Models},
  author={Wang, Hongyu and Ma, Shuming and Dong, Li and others},
  journal={arXiv preprint arXiv:2310.11453},
  year={2023}
}

@article{layerskip,
  title={LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding},
  author={Elhoushi, Mostafa and Shrivastava, Akshat and Liskovich, Diana and others},
  journal={arXiv preprint arXiv:2404.16710},
  year={2024}
}

License

MIT License

Contact

For questions or issues, please open an issue on the model repository.
