DeltaNet Model Checkpoint for the Forgetting Transformer Paper

This is the final checkpoint of the 760M-parameter DeltaNet model from the main experiment of the ICLR 2025 paper Forgetting Transformer: Softmax Attention with a Forget Gate.

Model Details

Model Description

  • Developed by: Zhixuan Lin
  • Model type: DeltaNet
  • Language(s) (NLP): English
  • License: MIT

Model Sources

Uses

Direct Use

First, install the forgetting-transformer repository as a Python package, along with the required dependencies. We pin versions that are known to work, but pinning is not strictly required:

# We recommend recording the commit hash you install. We may introduce breaking changes in the future.
# First, uninstall any existing copy to prevent potential conflicts
pip uninstall -y forgetting_transformer && pip install -U git+https://github.com/zhixuan-lin/forgetting-transformer
pip install pytest einops numpy
pip install torch==2.4.0
pip install transformers==4.44.0
# Other flash-linear-attention commits are not guaranteed to work; we may fix this later
pip install --no-deps --force-reinstall git+https://github.com/sustcsonglin/flash-linear-attention.git@1c5937eeeb8b0aa17bed5ee6dae345b353196bd4
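After installing, you can quickly confirm that the pinned versions actually ended up in your environment by querying the installed package metadata. The following is a minimal sketch that uses only the standard library's `importlib.metadata`. The helper name `find_version_mismatches` is hypothetical, and the sketch assumes that a local build tag such as `+cu121` in the torch version string can be ignored. flash-linear-attention is pinned by commit rather than by version, so it is not checked.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version

# Versions pinned in the install commands above. flash-linear-attention is
# pinned by commit hash, not by version, so it is not checked here.
PINNED = {
    "torch": "2.4.0",
    "transformers": "4.44.0",
}


def find_version_mismatches(pins: dict[str, str]) -> dict[str, str | None]:
    """Return {package: installed_version} for each package that does not match
    its pin. A value of None means the package is not installed at all."""
    mismatches: dict[str, str | None] = {}
    for package, expected in pins.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            mismatches[package] = None
            continue
        # Ignore local build tags such as "+cu121" (assumption: any CUDA build is fine).
        if installed.split("+")[0] != expected:
            mismatches[package] = installed
    return mismatches


if __name__ == "__main__":
    problems = find_version_mismatches(PINNED)
    for package, installed in problems.items():
        print(f"{package}: expected {PINNED[package]}, found {installed or 'not installed'}")
    if not problems:
        print("All pinned packages match.")
```

An empty result means every checked package matches its pin.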

Usage example:

import forgetting_transformer.model  # Needed to register the model classes
import forgetting_transformer.tokenizer  # Needed to register the tokenizer class
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained("zhixuan-lin/delta_net-760m-longcrawl64-48b")
tokenizer = AutoTokenizer.from_pretrained("zhixuan-lin/delta_net-760m-longcrawl64-48b", add_bos_token=True, clean_up_tokenization_spaces=False)

# Generation using HF api
prompt = "The best thing to do in San Francisco is"
model = model.cuda()
encoded = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    output = model.generate(
        encoded,
        max_new_tokens=30,
    )[0]
pred = tokenizer.decode(output, skip_special_tokens=True)
print(pred)

# You can also compute the loss or logits directly. The inputs are the labels shifted right by one position, with BOS first
batch_size, seq_len = encoded.shape
labels = encoded
input_ids = torch.roll(labels, shifts=1, dims=-1)
input_ids[:, 0] = tokenizer.bos_token_id  # 50256
out = model(input_ids=input_ids, labels=labels)
assert out.loss.size() == (batch_size, seq_len)
# Logits are not returned (to save memory) if labels are given
assert out.logits is None
# To get logits, don't provide labels
out = model(input_ids=input_ids)
assert out.logits.size() == (batch_size, seq_len, tokenizer.vocab_size)
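The loss example above builds the model inputs by shifting the labels one position to the right and placing the BOS token first, and `out.loss` holds one loss value per token. The following pure-Python sketch spells out both steps on toy data, without needing a GPU or the checkpoint. `shift_right` mirrors the `torch.roll` plus BOS assignment, and the hypothetical helpers `loss_by_position` and `perplexity` show two common ways to reduce a (batch_size, seq_len) per-token loss. It assumes the losses are natural-log cross-entropy values. With the real model you could obtain the nested list via `out.loss.detach().cpu().tolist()`, or stay in PyTorch with `out.loss.mean(dim=0)` and `out.loss.mean().exp()`.

```python
from __future__ import annotations

import math


def shift_right(labels: list[list[int]], bos_id: int) -> list[list[int]]:
    """Build inputs from targets: shift each row one position to the right and
    put bos_id first. Same result as torch.roll(labels, shifts=1, dims=-1)
    followed by input_ids[:, 0] = bos_id (the wrapped-around last token is overwritten)."""
    return [[bos_id] + row[:-1] if row else [] for row in labels]


def loss_by_position(per_token_loss: list[list[float]]) -> list[float]:
    """Average a (batch_size, seq_len) per-token loss over the batch dimension,
    giving the mean loss at each position in the sequence."""
    batch_size = len(per_token_loss)
    return [sum(column) / batch_size for column in zip(*per_token_loss)]


def perplexity(per_token_loss: list[list[float]]) -> float:
    """exp(mean loss over all tokens); assumes natural-log cross-entropy values."""
    flat = [loss for row in per_token_loss for loss in row]
    return math.exp(sum(flat) / len(flat))


if __name__ == "__main__":
    # Made-up token ids and losses, for illustration only.
    labels = [[464, 1266, 1517], [318, 257, 1332]]
    print(shift_right(labels, bos_id=50256))  # [[50256, 464, 1266], [50256, 318, 257]]
    losses = [[2.0, 1.0, 0.5], [4.0, 3.0, 1.5]]
    print(loss_by_position(losses))  # [3.0, 2.0, 1.0]
    print(round(perplexity(losses), 3))  # 7.389
```

Averaging over the batch rather than over positions keeps the position axis, which is useful for seeing how the loss changes as more context becomes available.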

Limitations

This is a small model trained on a small number of tokens from LongCrawl64, provided for reproducibility and research purposes. Moreover, LongCrawl64 is a long-context dataset built for research and is not designed for optimal downstream task performance (it also uses an unusual tokenization process, see here). Therefore, this model is suitable only for research purposes (e.g., inspecting attention maps). If you want to compare DeltaNet with models trained under a different setting or on a different dataset, you should train it from scratch on your own dataset under your own setting for the comparison.

Training Details

Training Data

This model was trained on roughly 48B tokens from LongCrawl64, with a training context length of 16k tokens.
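To put the token budget in perspective, here is a rough back-of-the-envelope count of training sequences. It assumes that "16k" means 16,384 tokens and that every training sequence fills the full context; both are assumptions, so treat the result as an order-of-magnitude estimate only:

```latex
\frac{48 \times 10^{9}\ \text{tokens}}{16{,}384\ \text{tokens/sequence}} \approx 2.93 \times 10^{6}\ \text{sequences}
```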

Training Procedure

Please see our paper for details. The training code is also provided in our official repository.

BibTeX:

@inproceedings{
lin2025forgetting,
title={Forgetting Transformer: Softmax Attention with a Forget Gate},
author={Zhixuan Lin and Evgenii Nikishin and Xu He and Aaron Courville},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025},
url={https://openreview.net/forum?id=q2Lnyegkr8}
}
Checkpoint format: Safetensors; model size: 835M params; tensor type: F32