---
title: Attention Visualization for Translation 
emoji: 👀
colorFrom: pink
colorTo: yellow
sdk: gradio
sdk_version: 5.25.0
app_file: app.py
pinned: false
license: mit
short_description: Visualize attention in translation (EN-ZH)
tags:
  - attention
  - visualization
  - translation
  - cross-attention
  - self-attention
  - transformer
  - encoder-decoder
---

This app helps users understand how the attention layers of a transformer model behave. It visualizes the cross-attention and self-attention weights of an encoder-decoder model, showing how tokens align between the source and target sentences as well as within each of them.

The app uses the `Helsinki-NLP/opus-mt-en-zh` model to translate English into Chinese. With `output_attentions=True`, the model also returns its attention weights:

| Attention Type      | Shape                                      | Role                                               |
|---------------------|---------------------------------------------|----------------------------------------------------|
| encoder_attentions  | (layers, B, heads, src_len, src_len)        | Encoder self-attention on source tokens            |
| decoder_attentions  | (layers, B, heads, tgt_len, tgt_len)        | Decoder self-attention on generated tokens         |
| cross_attentions    | (layers, B, heads, tgt_len, src_len)        | Decoder attention over source tokens (encoder outputs) |
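The shapes in the table treat `layers` as an array axis. In a forward pass of a Hugging Face `transformers` model, each of these outputs is a tuple with one tensor per layer, each shaped `(B, heads, query_len, key_len)`, and stacking the tuple gives the layout above. The NumPy sketch below mimics this with random placeholder arrays, not real model output. The sizes are assumptions borrowed from the example shape in the note, and the placeholder decoder weights skip the causal mask a real decoder applies.

```python
import numpy as np

# Assumed sizes, taken from the example shape in the note:
# 6 layers, batch size 1, 8 heads, 18 source tokens, 24 target tokens.
NUM_LAYERS, BATCH, HEADS, SRC_LEN, TGT_LEN = 6, 1, 8, 18, 24

rng = np.random.default_rng(0)


def fake_layer_attention(q_len, k_len):
    """Placeholder for one layer's attention tensor, shape (B, heads, q_len, k_len).

    Rows are normalized so each query's weights sum to 1, like softmax output.
    No causal mask is applied, unlike real decoder self-attention.
    """
    scores = rng.random((BATCH, HEADS, q_len, k_len))
    return scores / scores.sum(axis=-1, keepdims=True)


# transformers returns one tensor per layer in a tuple; these tuples mimic that.
encoder_attentions = tuple(fake_layer_attention(SRC_LEN, SRC_LEN) for _ in range(NUM_LAYERS))
decoder_attentions = tuple(fake_layer_attention(TGT_LEN, TGT_LEN) for _ in range(NUM_LAYERS))
cross_attentions = tuple(fake_layer_attention(TGT_LEN, SRC_LEN) for _ in range(NUM_LAYERS))

# Stacking along a new leading axis yields (layers, B, heads, q_len, k_len).
enc = np.stack(encoder_attentions)
dec = np.stack(decoder_attentions)
cross = np.stack(cross_attentions)

print(enc.shape)    # (6, 1, 8, 18, 18)
print(dec.shape)    # (6, 1, 8, 24, 24)
print(cross.shape)  # (6, 1, 8, 24, 18)
```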

The visualizations use the attention weights from the last encoder layer and the last decoder layer, averaged over all attention heads.
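A minimal sketch of the "last layer, mean over heads" reduction, assuming the weights are already stacked as `(layers, B, heads, q_len, k_len)` NumPy arrays. The helper name `last_layer_head_average` is hypothetical and not taken from the app's `app.py`, and the input is random placeholder data.

```python
import numpy as np


def last_layer_head_average(attentions, batch_index=0):
    """Average the last layer's attention over all heads.

    `attentions` is assumed to have shape (layers, B, heads, q_len, k_len).
    Returns a (q_len, k_len) matrix for one example in the batch.
    """
    last_layer = attentions[-1]          # (B, heads, q_len, k_len)
    example = last_layer[batch_index]    # (heads, q_len, k_len)
    return example.mean(axis=0)          # (q_len, k_len)


# Placeholder cross-attention weights: 6 layers, batch 1, 8 heads, 24 target x 18 source.
rng = np.random.default_rng(0)
cross = rng.random((6, 1, 8, 24, 18))
cross /= cross.sum(axis=-1, keepdims=True)  # make each query row a distribution

avg = last_layer_head_average(cross)
print(avg.shape)                            # (24, 18)
print(np.allclose(avg.sum(axis=-1), 1.0))   # True
# Most-attended source token index for each target token:
print(avg.argmax(axis=-1))
```

Because every head's row is a probability distribution, their mean is one too, so each row of the averaged matrix can still be read as "how target token *i* spreads its attention over the source tokens".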

**Notes:**
* `attention_weights = softmax(Q @ K.T / sqrt(d_k))` - for each query (the current token), a probability distribution over all keys (the tokens being attended to), so each row of weights sums to 1.
* `(layers, B, heads, tgt_len, src_len)` - e.g. `(6, 1, 8, 24, 18)` for `cross_attentions`: 6 layers, batch size 1, 8 heads, 24 target tokens and 18 source tokens.
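The formula in the first note can be sketched directly in NumPy. This is a single-head illustration with random `Q` and `K`. It omits the learned projections, masking, and multi-head split of the real model, and `d_k = 64` is an illustrative choice.

```python
import numpy as np


def attention_weights(Q, K):
    """Compute softmax(Q @ K.T / sqrt(d_k)) row-wise, one row per query."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                # (num_queries, num_keys)
    scores -= scores.max(axis=-1, keepdims=True)   # subtract row max for numerical stability
    exp_scores = np.exp(scores)
    return exp_scores / exp_scores.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 64))   # 4 queries (e.g., target tokens), d_k = 64
K = rng.standard_normal((6, 64))   # 6 keys (e.g., source tokens)

W = attention_weights(Q, K)
print(W.shape)                           # (4, 6)
print(np.allclose(W.sum(axis=-1), 1.0))  # True
```

Subtracting the row maximum before `exp` does not change the softmax result. It only keeps large scores from overflowing.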