title: Attention Visualization for Translation
emoji: 👀
colorFrom: pink
colorTo: yellow
sdk: gradio
sdk_version: 5.25.0
app_file: app.py
pinned: false
license: mit
short_description: Visualize attention in translation (EN-ZH)
tags:
  - attention
  - visualization
  - translation
  - cross-attention
  - self-attention
  - transformer
  - encoder-decoder

This app helps users understand how the attention layers in transformer models behave. It visualizes the cross-attention and self-attention weights of an encoder-decoder model, showing the alignment between source and target tokens (cross-attention) and within each sequence (self-attention).

The app uses the Helsinki-NLP/opus-mt-en-zh model to translate from English to Chinese. With output_attentions=True set, the attention weights are stored as follows:

| Attention Type | Shape | Role |
| --- | --- | --- |
| encoder_attentions | (layers, B, heads, src_len, src_len) | Encoder self-attention on source tokens |
| decoder_attentions | (layers, B, heads, tgt_len, tgt_len) | Decoder self-attention on generated tokens |
| cross_attentions | (layers, B, heads, tgt_len, src_len) | Decoder attention over source tokens (encoder outputs) |

The app takes the weights from the last encoder and decoder layers and averages them over all attention heads. These head-averaged weights are what the attention visualization displays.
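The last-layer, head-averaging step can be sketched with NumPy as below. This is a minimal illustration, not the code in app.py: the function name `head_averaged_attention` and the toy random weights are assumptions made for the example, and the only thing taken from the text above is the (layers, B, heads, tgt_len, src_len) layout.

```python
import numpy as np

def head_averaged_attention(attentions, layer=-1, batch_index=0):
    """Pick one layer and one batch item, then average over heads.

    `attentions` is indexed as (layers, B, heads, q_len, k_len),
    either as one array or as a sequence of per-layer arrays.
    Returns a 2-D array of shape (q_len, k_len).
    """
    layer_attn = np.asarray(attentions[layer])   # (B, heads, q_len, k_len)
    return layer_attn[batch_index].mean(axis=0)  # mean over the heads axis

# Toy stand-in for cross_attentions (hypothetical sizes, random values).
rng = np.random.default_rng(0)
layers, B, heads, tgt_len, src_len = 6, 1, 8, 5, 4
logits = rng.normal(size=(layers, B, heads, tgt_len, src_len))
toy_cross = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

avg = head_averaged_attention(toy_cross)
print(avg.shape)  # one row per target token, one column per source token
```

Because every head's row is a probability distribution, each row of the averaged matrix still sums to 1, so it can be drawn directly as a heatmap.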

Note:

  • attention_weights = softmax(Q @ K.T / sqrt(d_k)) - A probability distribution over all keys (i.e., tokens being attended to) for each query (i.e., the current token).
  • (layers, B, heads, tgt_len, src_len) - e.g. (6, 1, 8, 24, 18) for cross-attention: 6 layers, batch size 1, 8 heads, 24 target tokens, 18 source tokens
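The softmax formula in the note can be checked with a small NumPy sketch. It assumes single-head, unbatched Q and K matrices with random values; the function name `attention_weights` and the sizes are chosen only for illustration and are not taken from the model.

```python
import numpy as np

def attention_weights(Q, K):
    """softmax(Q @ K.T / sqrt(d_k)), applied row by row.

    Q has shape (q_len, d_k) and K has shape (k_len, d_k).
    Returns an array of shape (q_len, k_len).
    """
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    # Subtract the row maximum for numerical stability; softmax is unchanged.
    scores = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(scores)
    return exp / exp.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
Q = rng.normal(size=(3, 16))  # 3 query tokens, d_k = 16
K = rng.normal(size=(5, 16))  # 5 key tokens
W = attention_weights(Q, K)
print(W.shape)         # (3, 5)
print(W.sum(axis=-1))  # each row sums to 1 (up to floating-point error)
```

Each row of W is the distribution over keys for one query, which is why a row of the heatmap can be read as "where this token looks".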