Description

Implementation of the KV cache introduced in the Attention Sinks paper. It allows the model to generate beyond the length of its context window without losing fluency in the conversation. This is done by always keeping the first few tokens ("sink tokens") in the KV cache, as models often pay a large amount of attention to them. Because past non-sink tokens are discarded, the model loses the ability to generate tokens that depend on the discarded context. It also bounds the memory footprint of the KV cache.

This implementation matches the SinkCache class present in transformers<4.53.0.

Sink Cache diagram from the original paper
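
As a rough illustration of the eviction policy, the sketch below shows the core idea only; the actual SinkCache implementation in generate.py additionally handles model-specific details such as rotary position handling for cached keys. Once the cache grows past window_length, the first num_sink_tokens entries are kept and the oldest non-sink entries are dropped:

import torch

def evict(cache_layer, window_length=256, num_sink_tokens=4):
    # illustrative sketch only, not the actual SinkCache code
    # cache_layer: keys or values for one layer, shape [batch, num_heads, seq_len, head_dim]
    seq_len = cache_layer.shape[-2]
    if seq_len <= window_length:
        return cache_layer  # nothing to evict yet
    sinks = cache_layer[:, :, :num_sink_tokens, :]                      # always keep the sink tokens
    recent = cache_layer[:, :, -(window_length - num_sink_tokens):, :]  # keep the most recent tokens
    return torch.cat([sinks, recent], dim=-2)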

Base model

Qwen/Qwen3-0.6B

Model compatibility

  • Decoder-only transformers models

Additional Arguments

  • window_length (int, optional, defaults to 256): The length of the context window.
  • num_sink_tokens (int, optional, defaults to 4): The number of sink tokens. See the original paper for more information.
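
Both arguments are passed directly to generate, alongside the flags that select this custom generation method. A minimal sketch with placeholder values (model and model_inputs are prepared as in the example below):

gen_out = model.generate(
    **model_inputs,
    max_new_tokens=100,
    custom_generate="transformers-community/sink_cache",
    trust_remote_code=True,
    window_length=512,   # number of cache entries kept per layer (sink tokens included)
    num_sink_tokens=4,   # leading tokens that are never evicted
)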

Output Type changes

  • When return_dict_in_generate=True, output.past_key_values will be a SinkCache instance. SinkCache is defined in generate.py, in this repository.
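
Since the class lives in this repository rather than in transformers, the simplest checks are by name or through the generic Cache API. A small sketch, assuming gen_out comes from a call with return_dict_in_generate=True and that this SinkCache subclasses transformers' Cache like the original class did:

from transformers import Cache

cache = gen_out.past_key_values
print(type(cache).__name__)      # SinkCache
print(isinstance(cache, Cache))  # expected True: the object follows the standard Cache API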

Example usage

We can use the custom generation method in this repository like the base generate from transformers:

# requires `transformers>=4.52.0`
from transformers import AutoModelForCausalLM, AutoTokenizer

# Preparing model, tokenizer, and model inputs
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", device_map="auto")
messages = [{"role": "user", "content": "Tell me a story about a cat."}]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Using sink cache
gen_out = model.generate(
    # usual `generate` arguments
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    # select the sink cache method (defaults: `window_length=256`, `num_sink_tokens=4`)
    custom_generate="transformers-community/sink_cache",
    trust_remote_code=True,
)
print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
assert "sinkcache" in str(type(gen_out.past_key_values)).lower()
# ['user\nTell me a story about a cat.\nassistant\n<think>\n\n</think>\n\nOnce upon a time, in a cozy village nestled
# between rolling hills and a sparkling lake, there lived a cat named Luna. Luna was small and fluffy, with a curious
# eyes that sparkled with wonder. She had a soft, warm coat that shimmered like the morning sun, and her tail was
# always wagging in playful motions.\n\nOne day, while exploring the village, Luna noticed a curious sight: a young
# boy playing with a ball on the lake. She followed him closely, her heart racing']

Continuing the example above, we can confirm some properties of the SinkCache:

# `max_new_tokens` < `window_length` in the example above -> matches output with the default cache
gen_out = model.generate(
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
)
print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
assert "dynamiccache" in str(type(gen_out.past_key_values)).lower()
# ['user\nTell me a story about a cat.\nassistant\n<think>\n\n</think>\n\nOnce upon a time, in a cozy village nestled
# between rolling hills and a sparkling lake, there lived a cat named Luna. Luna was small and fluffy, with a curious
# eyes that sparkled with wonder. She had a soft, warm coat that shimmered like the morning sun, and her tail was
# always wagging in playful motions.\n\nOne day, while exploring the village, Luna noticed a curious sight: a young
# boy playing with a ball on the lake. She followed him closely, her heart racing']

# if we set a smaller `window_length`, the story is less coherent after that point, but the used cache is also
# significantly smaller
gen_out = model.generate(
    # usual `generate` arguments
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    # sink cache arguments
    custom_generate="transformers-community/sink_cache",
    trust_remote_code=True,
    window_length=50,
)
print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
# ["user\nTell me a story about a cat.\nassistant\n<think>\n\n</think>\n\nOnce upon a time, in a cozy village nestled
# between rolling hills and a sparkling lake, there lived a cat named Luna. Luna was small and fluffy, with a curious
# heart. She loved exploring the village and playing with her friends.\n\nOne day, Luna noticed something unusual.
# She looked around and saw a shadow moving in the dark. She ran quickly, but she couldn't see the shadow. She
# thought maybe it was a ghost or something else.\n\nAs she was running, she heard a voice."]
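
Continuing from the last call, the smaller footprint can be confirmed by checking how many entries the returned cache actually holds (a sketch, assuming this SinkCache exposes the standard get_seq_length method of the Cache API):

# the sink cache never holds more than `window_length` entries per layer,
# no matter how many tokens were processed
print(gen_out.past_key_values.get_seq_length())  # <= 50 for the `window_length=50` call above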