4D masks support in Transformers
With a recently merged PR*, transformers can now accept custom 4D attention masks as arguments to the `.forward()` method. Why are they 4D, and what opportunities does this open?
Understanding Attention Masks in Transformers
In Natural Language Processing, attention masks in transformers have long been a subtle yet crucial component. Most of the time we pay little attention to attention masks (pun?) and pass `None`, the default argument, which suffices for many applications. Standard attention masks, as generated by tokenizers, are 2D tensors shaped [batch_size, total_sequence_length], filled with ones for real tokens and zeros for padding. This format only says which positions hold real tokens; it cannot express which token may attend to which.
However, look deeper into transformer architectures and you'll discover that these masks undergo a transformation (another pun?). Inside the model, the 2D mask becomes a 4D tensor shaped [batch_size, heads, input_ids_length, total_sequence_length] (the heads dimension is typically 1 and broadcast across all heads). This format allows for more nuanced attention strategies, such as causal decoding, which uses a lower triangular matrix of ones, extended on the left by a rectangle of ones covering the cached tokens when a Key-Value (KV) cache is present. This structure ensures that each token attends only to itself and to the tokens before it. For example, with 3 tokens in the KV cache and 5 new input tokens, the mask looks like this:
tensor([[[[1, 1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 0],
          [1, 1, 1, 1, 1, 1, 1, 1]]]])
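The expansion from a 2D mask to a 4D causal mask can be sketched in pure Python. `expand_mask` is a hypothetical helper, not the library's code: inside transformers the mask is a tensor and, depending on the attention implementation, may be converted to an additive form with large negative values at masked positions. The sketch assumes the first `total_len - new_len` positions of each row are already in the KV cache.

```python
def expand_mask(mask_2d, new_len):
    """Expand a 2D mask [batch, total_len] into a 4D causal mask
    [batch, 1, new_len, total_len].

    Assumes the first total_len - new_len positions of every row are already
    in the KV cache, and that 1 marks a real token and 0 marks padding.
    """
    mask_4d = []
    for row in mask_2d:
        total_len = len(row)
        past_len = total_len - new_len
        query_rows = []
        for i in range(new_len):
            # Query i sits at absolute position past_len + i: it may see every
            # real (non-padding) key at that position or earlier.
            query_rows.append(
                [1 if j <= past_len + i and row[j] == 1 else 0 for j in range(total_len)]
            )
        mask_4d.append([query_rows])  # singleton head dimension, broadcast over heads
    return mask_4d


mask = expand_mask([[1] * 8], new_len=5)  # 3 cached tokens + 5 new ones
for query_row in mask[0][0]:
    print(query_row)
# [1, 1, 1, 1, 0, 0, 0, 0]
# [1, 1, 1, 1, 1, 0, 0, 0]
# [1, 1, 1, 1, 1, 1, 0, 0]
# [1, 1, 1, 1, 1, 1, 1, 0]
# [1, 1, 1, 1, 1, 1, 1, 1]
```

This reproduces the tensor above: the first three columns form the "rectangle" for the cached tokens, and the rest is the lower triangle.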
The Emergence of Custom 4D Masks
Traditionally, these 4D masks were internal representations, not directly accessible or modifiable by users. The recently merged pull request, which allows passing custom 4D masks to the `.forward()` method in transformers, marks a significant advancement. It enables more complex attention patterns, essential for tasks that require non-sequential token processing or multiple sequences in a single input tensor.
An important note: when supplying a custom mask, you will most likely also need to provide a `position_ids` tensor shaped [batch_size, input_ids_length], so that the positional encodings follow the intended (non-sequential) order of tokens rather than their raw index in the input. Not every model accepts a custom `position_ids` argument, but many modern ones, like Llama, do. Others may require patching.
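To see why `position_ids` matter, here is a minimal pure-Python sketch. The helper `position_ids_from_mask` is hypothetical, not a transformers API. It assumes a square 0/1 mask with one row per input token, no KV cache and no padding, where each token sees exactly itself plus the earlier tokens of its own sequence or tree branch; under that assumption a token's position is the number of tokens it can see minus one.

```python
def position_ids_from_mask(mask):
    """Hypothetical helper: derive position ids from a square 0/1 mask.

    Assumes no KV cache and no padding, and that every token attends exactly to
    itself plus the earlier tokens of its own sequence (or tree branch).
    """
    # A token that can see k tokens (itself included) sits at position k - 1.
    return [sum(row) - 1 for row in mask]


# Two sequences of two tokens each, packed into one row of length 4.
packed_mask = [
    [1, 0, 0, 0],
    [1, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 1, 1],
]
default_positions = list(range(len(packed_mask)))      # typical default if none are passed: [0, 1, 2, 3]
custom_positions = position_ids_from_mask(packed_mask)  # what the packed layout needs: [0, 1, 0, 1]
```

Without the custom positions, the second sequence would be encoded as if it continued the first one at positions 2 and 3.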
Below are a few use cases showing improvements in memory and/or efficiency with 4D masks:
1. Memory-efficient beam search
Consider beam search, a common strategy in language models for generating text. Traditionally, candidate sequences that share a common prefix but differ in their endings are processed as separate rows of a batch, each repeating the shared prefix and consuming significant memory. With 4D masks, these sequences can be compactly represented as a single sequence, leading to substantial memory savings.
Suppose in step 1 of beam search we got 2 beams:
- cat sat on the
- cat sat on my
At step 2 we evaluate 4 candidates for the next position:
- cat sat on the mat
- cat sat on the floor
- cat sat on my chair
- cat sat on my desk
Typically, these would be evaluated as a batch of size 4 and length 5. Here is how to pack them into a single sequence of length 9:
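tokens | cat | sat | on | the | my | mat | floor | chair | desk |
---|---|---|---|---|---|---|---|---|---|
position_ids | 0 | 1 | 2 | 3 | 3 | 4 | 4 | 4 | 4 |
mask (the) | 1 | 1 | 1 | 1 | . | . | . | . | . |
mask (my) | 1 | 1 | 1 | . | 1 | . | . | . | . |
mask (mat) | 1 | 1 | 1 | 1 | . | 1 | . | . | . |
mask (floor) | 1 | 1 | 1 | 1 | . | . | 1 | . | . |
mask (chair) | 1 | 1 | 1 | . | 1 | . | . | 1 | . |
mask (desk) | 1 | 1 | 1 | . | 1 | . | . | . | 1 |

Each mask row belongs to the token in parentheses: `1` means that token may attend to the token in that column, `.` means it is masked. The rows for `cat sat on` follow the usual causal pattern and are omitted.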
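The packing itself can be automated. Below is a minimal pure-Python sketch; the function `pack_candidates` is hypothetical, not part of transformers, and token strings stand in for token ids. It merges the candidates into a prefix tree, orders the nodes by depth so that every token comes after its ancestors, and lets each token attend only to itself and its ancestors.

```python
def pack_candidates(candidates):
    """Pack candidate sequences that share prefixes into one tree-shaped sequence.

    Returns (tokens, position_ids, mask), where mask[i][j] == 1 iff token j is
    token i itself or one of its ancestors in the prefix tree.
    """
    node_of = {}  # (parent node id, token) -> node id; -1 is the virtual root
    tokens, parents, depths = [], [], []
    for seq in candidates:
        parent = -1
        for tok in seq:
            key = (parent, tok)
            if key not in node_of:  # a new branch starts here
                node_of[key] = len(tokens)
                tokens.append(tok)
                parents.append(parent)
                depths.append(0 if parent == -1 else depths[parent] + 1)
            parent = node_of[key]

    # Stable sort by depth: parents always precede their children.
    order = sorted(range(len(tokens)), key=lambda node: depths[node])
    new_index = {node: i for i, node in enumerate(order)}

    n = len(order)
    mask = [[0] * n for _ in range(n)]
    for node in order:
        row = new_index[node]
        ancestor = node
        while ancestor != -1:  # walk up to the root, unmasking each ancestor
            mask[row][new_index[ancestor]] = 1
            ancestor = parents[ancestor]

    packed_tokens = [tokens[node] for node in order]
    position_ids = [depths[node] for node in order]  # position = depth in the tree
    return packed_tokens, position_ids, mask


candidates = [
    "cat sat on the mat".split(),
    "cat sat on the floor".split(),
    "cat sat on my chair".split(),
    "cat sat on my desk".split(),
]
tokens, position_ids, mask = pack_candidates(candidates)
print(tokens)        # ['cat', 'sat', 'on', 'the', 'my', 'mat', 'floor', 'chair', 'desk']
print(position_ids)  # [0, 1, 2, 3, 3, 4, 4, 4, 4]
print(mask[7])       # row for "chair": [1, 1, 1, 0, 1, 0, 0, 1, 0]
```

Sorting by depth is just one valid ordering; any order in which parents precede their children works, because the mask, not the order, encodes the tree.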
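Assuming that the tokens from step 1 (`cat sat on the my`) are already processed and stored in the KV cache, we need to pass to `model.forward()`:
- `input_ids`: the tokenized sequence `mat floor chair desk`
- `position_ids`: a tensor shaped (1, 4) holding the last four position ids from the table above, `[[4, 4, 4, 4]]`
- `attention_mask`: a tensor shaped (1, 1, 4, 9) holding the last four rows of the mask above
- the KV cache (`past_key_values`), with a length of 5 tokens (those with `position_id <= 3`).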
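To make the shapes concrete, here is a minimal pure-Python sketch that slices the full 9 × 9 tree mask from the table down to the tokens that are not cached yet: with 5 cached tokens and 4 new ones, the mask has 4 query rows and 9 key columns. `split_for_cache` and `shape` are hypothetical helpers. With a real model you would wrap the results with `torch.tensor(...)`; whether the model expects a 0/1 mask or an additive float mask, and in which dtype, may depend on your transformers version, so treat that part as an assumption to check.

```python
def split_for_cache(position_ids, mask, cache_len):
    """Keep only the entries for tokens that are not in the KV cache yet."""
    new_position_ids = [position_ids[cache_len:]]  # shape (1, n_new)
    new_mask = [[mask[cache_len:]]]                # shape (1, 1, n_new, total_len)
    return new_position_ids, new_mask


def shape(nested):
    """Shape of a nested list, assuming it is rectangular."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


# Packed order: cat sat on the my mat floor chair desk
position_ids = [0, 1, 2, 3, 3, 4, 4, 4, 4]
mask = [
    [1, 0, 0, 0, 0, 0, 0, 0, 0],  # cat
    [1, 1, 0, 0, 0, 0, 0, 0, 0],  # sat
    [1, 1, 1, 0, 0, 0, 0, 0, 0],  # on
    [1, 1, 1, 1, 0, 0, 0, 0, 0],  # the
    [1, 1, 1, 0, 1, 0, 0, 0, 0],  # my
    [1, 1, 1, 1, 0, 1, 0, 0, 0],  # mat
    [1, 1, 1, 1, 0, 0, 1, 0, 0],  # floor
    [1, 1, 1, 0, 1, 0, 0, 1, 0],  # chair
    [1, 1, 1, 0, 1, 0, 0, 0, 1],  # desk
]

# cat sat on the my (5 tokens, position_id <= 3) are already in the KV cache.
new_position_ids, new_mask = split_for_cache(position_ids, mask, cache_len=5)
print(new_position_ids)         # [[4, 4, 4, 4]]
print(shape(new_position_ids))  # (1, 4)
print(shape(new_mask))          # (1, 1, 4, 9)
```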
As a result, the step processes 9 token positions instead of 4 × 5 = 20, roughly a 2x saving in memory, and the saving grows as more beams and longer sequences are used.
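The arithmetic behind this saving can be checked with a small sketch (the helper `count_positions` is hypothetical). In the batched layout every candidate carries its full length, while in the packed layout each distinct prefix, i.e. each node of the prefix tree, appears only once.

```python
def count_positions(candidates):
    """Return (batched, packed) token-position counts for a set of candidates."""
    batched = sum(len(seq) for seq in candidates)
    # Each distinct non-empty prefix corresponds to one token in the packed tree.
    prefixes = {tuple(seq[:k]) for seq in candidates for k in range(1, len(seq) + 1)}
    return batched, len(prefixes)


candidates = [
    "cat sat on the mat".split(),
    "cat sat on the floor".split(),
    "cat sat on my chair".split(),
    "cat sat on my desk".split(),
]
batched, packed = count_positions(candidates)
print(batched, packed)  # 20 9
```

A longer shared prefix widens the gap: the batched layout repeats it once per candidate, while the packed layout stores it once.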
This method comes from the paper SpecInfer: Accelerating Generative Large Language Model Serving with Speculative Inference and Token Tree Verification by Xupeng Miao et al., where the authors introduce a topology-aware causal mask to fuse the tree attention computation of all tokens into a single kernel when building a tree of token candidates in speculative decoding.
2. Sequence packing in SFT (supervised fine-tuning)
Based on the discussion here: https://github.com/huggingface/trl/issues/805.
When fine-tuning language models on sequences of varied length, the standard approach is to pack them together, separated by end-of-sequence (EOS) tokens, as introduced in the T5 paper. However, this doesn't prevent cross-sequence attention. By employing 4D masks, attention can be confined to each individual sequence even when several sequences are packed together, so tokens of one sequence can no longer attend to tokens of another.
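A minimal pure-Python sketch of such a mask follows; `packed_causal_mask` is a hypothetical helper. It builds a block-diagonal causal mask for sequences of the given lengths and, following the note on `position_ids` earlier, restarts positions at 0 for each sequence. Adding the batch and head dimensions (shape (1, 1, total, total)) and converting to a tensor are left out.

```python
def packed_causal_mask(seq_lengths):
    """Block-diagonal causal mask plus per-sequence position_ids for packed sequences."""
    total = sum(seq_lengths)
    mask = [[0] * total for _ in range(total)]
    position_ids = []
    start = 0
    for length in seq_lengths:
        for i in range(length):
            position_ids.append(i)  # positions restart at 0 for every sequence
            for j in range(i + 1):  # causal: itself and earlier tokens...
                mask[start + i][start + j] = 1  # ...of the same sequence only
        start += length
    return position_ids, mask


position_ids, mask = packed_causal_mask([3, 2])
print(position_ids)  # [0, 1, 2, 0, 1]
for row in mask:
    print(row)
# [1, 0, 0, 0, 0]
# [1, 1, 0, 0, 0]
# [1, 1, 1, 0, 0]
# [0, 0, 0, 1, 0]
# [0, 0, 0, 1, 1]
```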
See the picture below for an illustration of the method:
3. Lookahead decoding
A novel method, Lookahead Decoding, proposed by Yichao Fu et al. in "Break the Sequential Dependency of LLM Inference Using Lookahead Decoding", blends token generation and validation in a single pass. Before 4D mask support, this technique required a custom forward function with fine-grained manipulation of each token's attention. Now it can be done by simply providing the right mask.
The Future of Attention Masks
Looking ahead, the evolution of attention masks might not stop at four dimensions. A fifth dimension, representing the layer index, could unlock new experimental possibilities, such as layer-specific attention patterns. Another area to explore is the head dimension, which could allow sparse attention in the Big Bird style.
Conclusion
The integration of custom 4D attention masks in transformers offers greater flexibility and efficiency in handling complex language tasks. By enabling more intricate attention mechanisms, these masks open the door to novel applications and improved performance in existing ones.
- Thanks to Arthur Zucker, UniverseFly, KexinFeng, PhilJd, shentianxiao for support, suggestions and testing of the 4D masks PR.