Autoregressive Transformer trained on TinyStories
This is an autoregressive decoder-only transformer model trained on the TinyStories dataset using JAX and Flax NNX.
Model Details
- Model Type: Autoregressive Decoder-only Transformer
- Framework: JAX + Flax NNX
- Dataset: TinyStories
- Parameters: ~85.0M
- Precision: Mixed (FP32 parameters, BF16 computation)
Architecture
- Hidden Size: 512
- Number of Layers: 8
- Attention Heads: 8
- Intermediate Size: 2048
- Max Position Embeddings: 256
- Vocab Size: 50257
- Rotary Position Embeddings: True
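As a rough sanity check, the ~85.0M parameter figure can be approximately reproduced from the architecture values above. The sketch below relies on details this card does not state, so treat them as assumptions: untied input and output embeddings, a gated (SwiGLU-style) feed-forward block with three projection matrices, no bias terms, and two normalization scale vectors per layer plus a final one. The function name `estimate_params` is hypothetical. Rotary position embeddings add no learned parameters.

```python
def estimate_params(hidden=512, layers=8, intermediate=2048, vocab=50257,
                    tied_embeddings=False, gated_mlp=True):
    """Rough parameter estimate; the layer layout is assumed, not confirmed by the card."""
    embed = vocab * hidden                            # token embedding table
    head = 0 if tied_embeddings else vocab * hidden   # output projection (assumed untied)
    attn = 4 * hidden * hidden                        # Q, K, V and output projections
    mlp = (3 if gated_mlp else 2) * hidden * intermediate  # gated MLP has 3 matrices
    norms = 2 * hidden                                # two norm scales per layer (assumed)
    final_norm = hidden                               # one norm scale before the head
    return embed + head + layers * (attn + mlp + norms) + final_norm


total = estimate_params()
print(f"{total / 1e6:.1f}M parameters")
```

Under these assumptions the estimate comes to about 85.0M, which is consistent with the reported figure but does not confirm the layout. With a non-gated two-matrix MLP the same formula gives roughly 76.6M.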
Training Details
- Training Steps: 3,120
- Batch Size: 32
- Gradient Accumulation: 4
- Learning Rate: 0.0003
- Training Duration: 0.43 hours
- Final Eval Loss: 1.7965960502624512
- Final Eval Perplexity: 6.170201301574707
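Perplexity is conventionally the exponential of the mean cross-entropy loss in nats. The short check below applies that relation to the two reported values. It gives about 6.03 for exp(loss), somewhat below the reported 6.17. One possible explanation, which is an assumption and not stated in this card, is that perplexity was averaged per batch instead of being computed from the averaged loss; the mean of per-batch exponentials is never smaller than the exponential of the mean.

```python
import math

eval_loss = 1.7965960502624512          # reported final eval loss
reported_ppl = 6.170201301574707        # reported final eval perplexity

ppl_from_loss = math.exp(eval_loss)     # perplexity implied by the mean loss
loss_from_ppl = math.log(reported_ppl)  # loss implied by the reported perplexity

print(f"exp(loss) = {ppl_from_loss:.2f}")
print(f"log(ppl)  = {loss_from_ppl:.2f}")
```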
Usage
# This model was trained with JAX/Flax and requires the custom transformer implementation
# to load and use. See the repository for implementation details.
from transformers import AutoTokenizer
import jax.numpy as jnp
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Example text generation (requires custom model loading)
prompt = "Once upon a time, there was a little"
# ... (model loading and generation code)
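Since the loading code is not shown here, the following is only a framework-agnostic sketch of how greedy autoregressive decoding typically proceeds once a model is available. `next_token_logits` is a hypothetical callable standing in for the custom model's forward pass; this card does not define it. The context is truncated to the 256-position limit listed under Architecture, and a toy stand-in model is used so the loop can be run on its own.

```python
from typing import Callable, List, Optional, Sequence

MAX_POSITIONS = 256  # from "Max Position Embeddings" in this card


def greedy_generate(
    next_token_logits: Callable[[Sequence[int]], Sequence[float]],  # hypothetical model call
    prompt_ids: Sequence[int],
    max_new_tokens: int = 50,
    eos_token_id: Optional[int] = None,
) -> List[int]:
    """Append the highest-scoring token until EOS or the token budget is reached."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        context = ids[-MAX_POSITIONS:]        # keep only the most recent positions
        logits = next_token_logits(context)   # scores for the next token
        next_id = max(range(len(logits)), key=logits.__getitem__)  # argmax
        ids.append(next_id)
        if eos_token_id is not None and next_id == eos_token_id:
            break
    return ids


# Toy stand-in over a 5-token vocabulary: always favors (last_id + 1) % 5.
def toy_model(context: Sequence[int]) -> List[float]:
    scores = [0.0] * 5
    scores[(context[-1] + 1) % 5] = 1.0
    return scores


print(greedy_generate(toy_model, [0], max_new_tokens=6, eos_token_id=4))
```

With the real model, `prompt_ids` would come from `tokenizer(prompt)["input_ids"]`, `eos_token_id` from `tokenizer.eos_token_id`, and the result would be turned back into text with `tokenizer.decode(ids)`.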
Training Configuration
model:
  hidden_size: 512
  num_layers: 8
  num_attention_heads: 8
  intermediate_size: 2048
  max_position_embeddings: 256
training:
  learning_rate: 0.0003
  batch_size: 32
  epochs: 10
  warmup_ratio: 0.1
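The training values in this card imply a few derived quantities. The arithmetic below assumes that the reported 3,120 training steps count optimizer updates (after gradient accumulation) and that `warmup_ratio` is applied to the total step count; neither is stated in the card. `lr_at` is a hypothetical helper: the schedule used after warmup is not documented, so the learning rate is simply held constant there.

```python
BATCH_SIZE = 32        # per-step batch size
GRAD_ACCUM = 4         # gradient accumulation steps
TOTAL_STEPS = 3120     # reported training steps
EPOCHS = 10
WARMUP_RATIO = 0.1
PEAK_LR = 3e-4
MAX_SEQ_LEN = 256      # max position embeddings

effective_batch = BATCH_SIZE * GRAD_ACCUM               # sequences per optimizer update
warmup_steps = round(TOTAL_STEPS * WARMUP_RATIO)        # assumed: ratio of total steps
steps_per_epoch = TOTAL_STEPS // EPOCHS
max_tokens_per_update = effective_batch * MAX_SEQ_LEN   # upper bound, full-length sequences


def lr_at(step: int) -> float:
    """Hypothetical schedule: linear warmup to PEAK_LR, then constant."""
    if step < warmup_steps:
        return PEAK_LR * (step + 1) / warmup_steps
    return PEAK_LR


print(effective_batch, warmup_steps, steps_per_epoch, max_tokens_per_update)
```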
Files
- config.json: Model configuration
- train_history.json: Training metrics and duration
- tokenizer/: GPT-2 tokenizer files
- model_checkpoint/: Best model checkpoint
- tensorboard_logs/: Training logs for TensorBoard
License
MIT License - see LICENSE file for details.