Autoregressive Transformer trained on TinyStories

This is an autoregressive decoder-only transformer model trained on the TinyStories dataset using JAX and Flax NNX.

Model Details

  • Model Type: Autoregressive Decoder-only Transformer
  • Framework: JAX + Flax NNX
  • Dataset: TinyStories
  • Parameters: ~85.0M
  • Precision: Mixed (FP32 parameters, BF16 computation)
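
The ~85.0M figure can be reproduced from the values listed under Architecture, but only under assumptions this card does not state. The sketch assumes a SwiGLU-style feed-forward block (three 512×2048 projections), bias-free projections, one norm scale vector per normalization layer, no learned position table (positions come from RoPE), and an output head that is not tied to the token embedding. With a standard two-matrix MLP and a tied head, the same formula gives about 50.9M instead.

```python
# Hypothetical parameter count. Assumes a SwiGLU MLP, no biases, one scale vector
# per norm, RoPE (no position table) and an untied LM head; none of this is confirmed.
VOCAB = 50_257
HIDDEN = 512
LAYERS = 8
INTERMEDIATE = 2_048

def count_params(vocab=VOCAB, d=HIDDEN, n_layers=LAYERS, d_ff=INTERMEDIATE,
                 swiglu=True, tied_head=False):
    embed = vocab * d                        # token embedding table
    attn = 4 * d * d                         # Q, K, V and output projections
    mlp = (3 if swiglu else 2) * d * d_ff    # gate/up/down (SwiGLU) or up/down
    norms = 2 * d                            # two norm scale vectors per block
    per_layer = attn + mlp + norms
    final_norm = d                           # norm before the output head
    head = 0 if tied_head else vocab * d     # separate output projection
    return embed + n_layers * per_layer + final_norm + head

total = count_params()
print(f"{total:,} parameters (~{total / 1e6:.1f}M)")
# 85,026,304 parameters (~85.0M)
```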

Architecture

  • Hidden Size: 512
  • Number of Layers: 8
  • Attention Heads: 8
  • Intermediate Size: 2048
  • Max Position Embeddings: 256
  • Vocab Size: 50257
  • Rotary Position Embeddings: Yes
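
Rotary position embeddings (RoPE) encode position by rotating pairs of query/key dimensions through position-dependent angles, so the attention score between two tokens depends only on their relative offset. Below is a minimal JAX sketch sized for this card (512 / 8 heads = 64 dimensions per head, 256 positions). The frequency base of 10,000 and the split-halves pairing (rather than interleaved pairs) are common conventions assumed here, not details taken from this model's code.

```python
import jax
import jax.numpy as jnp

HEAD_DIM = 512 // 8   # hidden_size / num_attention_heads = 64
MAX_LEN = 256         # max_position_embeddings

def rope_tables(head_dim=HEAD_DIM, max_len=MAX_LEN, base=10_000.0):
    # One frequency per dimension pair: base^(-2i / head_dim)
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2) / head_dim))
    angles = jnp.outer(jnp.arange(max_len), inv_freq)  # (max_len, head_dim // 2)
    return jnp.cos(angles), jnp.sin(angles)

def apply_rope(x, cos, sin):
    # x: (..., seq_len, head_dim). Dimension i is rotated together with i + head_dim // 2.
    seq_len, half = x.shape[-2], x.shape[-1] // 2
    c, s = cos[:seq_len], sin[:seq_len]
    x1, x2 = x[..., :half], x[..., half:]
    return jnp.concatenate([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1)

cos, sin = rope_tables()
q = jax.random.normal(jax.random.PRNGKey(0), (HEAD_DIM,))
k = jax.random.normal(jax.random.PRNGKey(1), (HEAD_DIM,))
# Place the same q and k at positions 0..7, then rotate them
rq = apply_rope(jnp.tile(q, (8, 1)), cos, sin)
rk = apply_rope(jnp.tile(k, (8, 1)), cos, sin)
# (query 3, key 1) and (query 5, key 3) share an offset of 2, so their scores agree
print(jnp.dot(rq[3], rk[1]), jnp.dot(rq[5], rk[3]))
```

Because each pair is only rotated, vector norms are unchanged, and position 0 leaves the input untouched.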

Training Details

  • Training Steps: 3,120
  • Batch Size: 32
  • Gradient Accumulation: 4
  • Learning Rate: 0.0003
  • Training Duration: 0.43 hours
  • Final Eval Loss: 1.7966
  • Final Eval Perplexity: 6.1702
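
Perplexity is usually defined as the exponential of the mean per-token cross-entropy measured in nats. Applied to the reported eval loss (about 1.797), that definition gives about 6.03, a little below the reported 6.17. The card does not say how its perplexity was aggregated. One possibility, assumed here and not confirmed, is that per-batch perplexities were averaged: by Jensen's inequality that average is never smaller than exp of the same mean loss. The plain-Python sketch below checks both points; the per-batch losses are made up for illustration.

```python
import math

FINAL_EVAL_LOSS = 1.7965960502624512      # reported loss (assumed to be in nats)
REPORTED_PERPLEXITY = 6.170201301574707   # reported perplexity

# Standard definition: perplexity = exp(mean token-level cross-entropy)
ppl_from_loss = math.exp(FINAL_EVAL_LOSS)
print(f"exp(loss) = {ppl_from_loss:.3f}, reported = {REPORTED_PERPLEXITY:.3f}")
# exp(loss) = 6.029, reported = 6.170

def mean_of_exp(losses):
    # Average of per-batch perplexities
    return sum(math.exp(loss) for loss in losses) / len(losses)

def exp_of_mean(losses):
    # Perplexity of the averaged loss
    return math.exp(sum(losses) / len(losses))

hypothetical_batch_losses = [1.6, 1.8, 2.0]   # illustrative only, not from training logs
print(mean_of_exp(hypothetical_batch_losses) >= exp_of_mean(hypothetical_batch_losses))
# True
```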

Usage

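# This model was trained with JAX/Flax NNX and requires the custom transformer
# implementation to load and use. See the repository for implementation details.

from transformers import AutoTokenizer
import jax.numpy as jnp

# Load the GPT-2 tokenizer (vocabulary size 50257, matching the model)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 defines no pad token, so reuse the end-of-sequence token for padding
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Example text generation (requires custom model loading)
prompt = "Once upon a time, there was a little"
# Encode the prompt as a JAX int array of shape (1, seq_len)
input_ids = jnp.asarray(tokenizer(prompt, return_tensors="np")["input_ids"])
# ... (model loading and generation code)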
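
Generation depends on the custom model code, which this card does not include. To illustrate how autoregressive decoding works within the model's 256-token context, the sketch below runs a greedy loop against a hypothetical `logits_fn`: any callable mapping token IDs of shape `(1, seq_len)` to logits of shape `(1, seq_len, vocab)`. A toy stand-in replaces the real model so the loop runs on its own. Sampling strategies such as temperature or top-k are not shown.

```python
import jax
import jax.numpy as jnp

MAX_POSITIONS = 256  # max_position_embeddings: the model's context window

def greedy_generate(logits_fn, input_ids, max_new_tokens, eos_id=None):
    """Append the highest-scoring next token, one step at a time.

    logits_fn is hypothetical: (1, seq_len) int array -> (1, seq_len, vocab) logits.
    """
    ids = jnp.asarray(input_ids, dtype=jnp.int32)
    for _ in range(max_new_tokens):
        context = ids[:, -MAX_POSITIONS:]  # crop to the last 256 tokens
        logits = logits_fn(context)
        next_id = jnp.argmax(logits[:, -1, :], axis=-1).astype(jnp.int32)  # shape (1,)
        ids = jnp.concatenate([ids, next_id[:, None]], axis=1)
        if eos_id is not None and int(next_id[0]) == eos_id:
            break  # stop once end-of-sequence is produced
    return ids

def toy_logits_fn(context, vocab_size=10):
    # Stand-in for the real model: always predicts (last token + 1) mod vocab_size
    return jax.nn.one_hot((context + 1) % vocab_size, vocab_size)

print(greedy_generate(toy_logits_fn, [[3]], max_new_tokens=4))
# [[3 4 5 6 7]]
```

With the real model, `logits_fn` would wrap the loaded Flax NNX module, the prompt IDs would come from the GPT-2 tokenizer, and `eos_id=tokenizer.eos_token_id` would stop generation at the end of a story.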

Training Configuration

model:
  hidden_size: 512
  num_layers: 8  
  num_attention_heads: 8
  intermediate_size: 2048
  max_position_embeddings: 256

training:
  learning_rate: 0.0003
  batch_size: 32
  epochs: 10
  warmup_ratio: 0.1
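
The configuration gives a peak learning rate and a warmup ratio but not the decay shape. The minimal sketch below assumes that the 0.1 warmup ratio applies to the 3,120 training steps (giving 312 linear warmup steps) and that a cosine decay to zero follows. Both the decay shape and what a "step" counts are assumptions. It also records the effective batch: 32 sequences × 4 accumulation steps = 128 sequences per optimizer update. If the training loop uses optax (not stated in this card), `optax.warmup_cosine_decay_schedule` builds a schedule of this kind and `optax.MultiSteps` handles gradient accumulation.

```python
import math

PEAK_LR = 3e-4          # learning_rate
TOTAL_STEPS = 3_120     # Training Steps (assumed to count optimizer updates)
WARMUP_RATIO = 0.1      # warmup_ratio
WARMUP_STEPS = round(WARMUP_RATIO * TOTAL_STEPS)  # 312
EFFECTIVE_BATCH = 32 * 4                          # batch_size x accumulation steps = 128

def learning_rate(step):
    """Linear warmup to PEAK_LR, then cosine decay to 0 (decay shape assumed)."""
    if step < WARMUP_STEPS:
        return PEAK_LR * step / WARMUP_STEPS
    progress = min(1.0, (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS))
    return 0.5 * PEAK_LR * (1.0 + math.cos(math.pi * progress))

for step in (0, 156, 312, 1_716, 3_120):
    print(step, f"{learning_rate(step):.2e}")
```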

Files

  • config.json: Model configuration
  • train_history.json: Training metrics and duration
  • tokenizer/: GPT-2 tokenizer files
  • model_checkpoint/: Best model checkpoint
  • tensorboard_logs/: Training logs for TensorBoard

License

MIT License - see LICENSE file for details.

Dataset used to train thiomajid/StoriesLM