---
title: Attention Visualization for Translation
emoji: π
colorFrom: pink
colorTo: yellow
sdk: gradio
sdk_version: 5.25.0
app_file: app.py
pinned: false
license: mit
short_description: Visualize attention in translation (EN-ZH)
tags:
- attention
- visualization
- translation
- cross-attention
- self-attention
- transformer
- encoder-decoder
---
This app aims to help users better understand the behavior of the attention layers in transformer models. It visualizes the cross-attention and self-attention weights of an encoder-decoder model, showing the alignment between and within the source and target tokens.
The app uses the `Helsinki-NLP/opus-mt-en-zh` model to translate from English to Chinese. By setting `output_attentions=True`, the attention weights are returned as follows (a usage sketch follows the table):
| Attention Type | Shape | Role |
|---------------------|---------------------------------------------|----------------------------------------------------|
| encoder_attentions | (layers, B, heads, src_len, src_len) | Encoder self-attention on source tokens |
| decoder_attentions | (layers, B, heads, tgt_len, tgt_len) | Decoder self-attention on generated tokens |
| cross_attentions | (layers, B, heads, tgt_len, src_len) | Decoder attention over source tokens (encoder outputs) |
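A minimal sketch of how these tensors can be obtained with `transformers` (the input sentence and variable names are illustrative and not taken from the app's code):

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "Helsinki-NLP/opus-mt-en-zh"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

inputs = tokenizer("Attention is all you need.", return_tensors="pt")

# Translate first, then re-run a forward pass with the generated ids as
# decoder inputs so that all three attention types are returned at once.
decoder_input_ids = model.generate(**inputs, max_new_tokens=64)
outputs = model(**inputs, decoder_input_ids=decoder_input_ids, output_attentions=True)

# Each field is a tuple with one tensor per layer:
#   encoder_attentions[i]: (B, heads, src_len, src_len)
#   decoder_attentions[i]: (B, heads, tgt_len, tgt_len)
#   cross_attentions[i]:   (B, heads, tgt_len, src_len)
print(len(outputs.cross_attentions), outputs.cross_attentions[-1].shape)
```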
The visualization is built by taking the weights from the last encoder and decoder layers and averaging them over all attention heads, giving one attention map per attention type (see the heatmap sketch below).
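As a rough illustration (reusing `outputs`, `inputs`, `tokenizer`, and `decoder_input_ids` from the sketch above; the app's actual plotting code may differ), the head-averaged cross-attention of the last layer can be rendered as a heatmap:

```python
import matplotlib.pyplot as plt

# Last-layer cross-attention, batch item 0, averaged over heads -> (tgt_len, src_len)
avg_cross = outputs.cross_attentions[-1][0].mean(dim=0)

src_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
tgt_tokens = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])

fig, ax = plt.subplots()
ax.imshow(avg_cross.detach().numpy(), cmap="viridis")
ax.set_xticks(range(len(src_tokens)), src_tokens, rotation=90)
ax.set_yticks(range(len(tgt_tokens)), tgt_tokens)
ax.set_xlabel("source tokens")
ax.set_ylabel("target tokens")
ax.set_title("Cross-attention (last layer, avg over heads)")
fig.tight_layout()
plt.show()
```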
**Note:**
* `attention_weights = softmax(Q @ K.T / sqrt(d_k))` - a probability distribution over all keys (i.e., tokens being attended to) for each query (i.e., the current token); see the toy example after this list.
* Cross-attention shape `(layers, B, heads, tgt_len, src_len)` - e.g. `(6, 1, 8, 24, 18)`: 6 layers, batch size 1, 8 heads, 24 target tokens attending over 18 source tokens.
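A toy numeric check of the formula above (random Q/K with arbitrary shapes, not taken from the app):

```python
import torch
import torch.nn.functional as F

d_k = 64
Q = torch.randn(5, d_k)   # 5 query tokens
K = torch.randn(7, d_k)   # 7 key tokens

attention_weights = F.softmax(Q @ K.T / d_k**0.5, dim=-1)   # shape (5, 7)
print(attention_weights.sum(dim=-1))  # each row sums to 1: a distribution over keys
```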