---
title: Attention Visualization for Translation
emoji: 👀
colorFrom: pink
colorTo: yellow
sdk: gradio
sdk_version: 5.25.0
app_file: app.py
pinned: false
license: mit
short_description: Visualize attention in translation (EN-ZH)
tags:
- attention
- visualization
- translation
- cross-attention
- self-attention
- transformer
- encoder-decoder
---

This app helps users understand how the attention layers in transformer models behave. It visualizes the cross-attention and self-attention weights of an encoder-decoder model, showing how tokens align between the source and target sequences (cross-attention) and within each sequence (self-attention).

The app uses the `Helsinki-NLP/opus-mt-en-zh` model to translate English into Chinese. With `output_attentions=True`, the model also returns its attention weights. Each field is a tuple with one tensor per layer; the shapes below show them stacked along a leading `layers` axis, where `B` is the batch size:

| Attention Type | Shape | Role |
|---------------------|---------------------------------------------|----------------------------------------------------|
| encoder_attentions | (layers, B, heads, src_len, src_len) | Encoder self-attention on source tokens |
| decoder_attentions | (layers, B, heads, tgt_len, tgt_len) | Decoder self-attention on generated tokens |
| cross_attentions | (layers, B, heads, tgt_len, src_len) | Decoder attention over source tokens (encoder outputs) |

The app takes the weights from the last encoder and decoder layers and averages them over all attention heads. These head-averaged weights are what the visualizations display.

**Note:**

* `attention_weights = softmax(Q @ K.T / sqrt(d_k))`: for each query (the current token), a probability distribution over all keys (the tokens being attended to).
* `(layers, B, heads, tgt_len, src_len)`, e.g. `(6, 1, 8, 24, 18)`: 6 layers, batch size 1, 8 heads, 24 target tokens, and 18 source tokens.
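The two steps described above, computing `softmax(Q @ K.T / sqrt(d_k))` and then averaging the last layer's weights over heads, can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the app's actual code: random matrices stand in for the model's query and key projections, the per-head size `d_k = 64` is an assumed value, and the helpers `attention_weights` and `last_layer_head_mean` are hypothetical names that are not part of the app or of `transformers`. The sizes mirror the cross-attention example in the note.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the max along the axis for numerical stability before exponentiating
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def attention_weights(q: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Hypothetical helper: softmax(Q @ K.T / sqrt(d_k)) for inputs shaped (..., len, d_k)."""
    d_k = q.shape[-1]
    # Transpose only the last two axes of K so leading (batch, head) axes broadcast
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)
    # Normalize over keys: each query row becomes a probability distribution
    return softmax(scores, axis=-1)


def last_layer_head_mean(per_layer: tuple) -> np.ndarray:
    """Hypothetical helper: take the last layer's (B, heads, q_len, k_len) tensor, average over heads."""
    return per_layer[-1].mean(axis=1)  # -> (B, q_len, k_len)


# Sizes follow the note's example; d_k = 64 is an assumption for illustration only.
layers, B, heads, tgt_len, src_len, d_k = 6, 1, 8, 24, 18, 64
rng = np.random.default_rng(0)

# One cross-attention tensor per layer: decoder queries attend over encoder keys.
# Random Q and K are placeholders for the model's learned projections.
cross = tuple(
    attention_weights(
        rng.standard_normal((B, heads, tgt_len, d_k)),
        rng.standard_normal((B, heads, src_len, d_k)),
    )
    for _ in range(layers)
)

print(np.stack(cross).shape)  # (6, 1, 8, 24, 18)
avg = last_layer_head_mean(cross)
print(avg.shape)  # (1, 24, 18)
# A mean of probability distributions is still a distribution, so each row sums to 1
print(np.allclose(avg.sum(axis=-1), 1.0))  # True
```

Averaging over heads keeps each query row normalized, which is why the head-averaged matrix can be read directly as a soft alignment between target and source tokens.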