import torch
from torch import nn

import gradio as gr

from utils import save_data, get_attn_list, get_top_attns, get_encoder_attn_list
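
# Helpers from utils.py (not shown here). Assumed contract, inferred from how they
# are used below: get_attn_list / get_encoder_attn_list select the given layer's
# attention weights, and get_top_attns returns, for each token, a dict like
# {"top_index": [...], "top_values": [...]} with its most attended positions and
# their scores, which is what the JS hover code reads.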

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "Helsinki-NLP/opus-mt-en-zh"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Index of the last decoder layer; its attention weights are the ones visualized.
layer_index = model.config.decoder_layers - 1
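
# Quick sanity check of the model outside the UI (illustrative only):
#   example = tokenizer("A bird can fly", return_tensors="pt")
#   out = model.generate(**example, num_beams=1)
#   print(tokenizer.decode(out[0], skip_special_tokens=True))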


def translate_text(input_text):
    inputs = tokenizer(input_text, return_tensors="pt", padding=True)
    with torch.no_grad():
        translated = model.generate(
            **inputs,
            return_dict_in_generate=True,
            output_scores=True,
            output_attentions=True,
            num_beams=1,
        )
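
    # With return_dict_in_generate=True and output_attentions=True, `translated`
    # exposes cross_attentions, decoder_attentions and encoder_attentions. Per the
    # Transformers documentation (shapes assumed rather than re-verified here), the
    # cross/decoder attentions come as one tuple per generated token, each holding
    # one tensor per decoder layer, while encoder_attentions holds one tensor per
    # encoder layer.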

    # Drop the decoder start token and the trailing </s> before decoding.
    outputs = tokenizer.decode(translated.sequences[0][1:-1])

    # Convert ids to tokens and strip the SentencePiece word-boundary marker.
    src_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    src_tokens = [token.lstrip('▁_') for token in src_tokens]

    tgt_tokens = tokenizer.convert_ids_to_tokens(translated.sequences[0])[1:]
    tgt_tokens = [token.lstrip('▁_') for token in tgt_tokens]

    # Per-token top attention scores for the selected layer.
    avg_cross_attn_list = get_attn_list(translated.cross_attentions, layer_index)
    cross_attn_scores = get_top_attns(avg_cross_attn_list)

    avg_decoder_attn_list = get_attn_list(translated.decoder_attentions, layer_index)
    decoder_attn_scores = get_top_attns(avg_decoder_attn_list)

    avg_encoder_attn_list = get_encoder_attn_list(translated.encoder_attentions, layer_index)
    encoder_attn_scores = get_top_attns(avg_encoder_attn_list)

    # The order of the returned values must match the `outputs=` list passed to
    # translate_button.click() below.
    return (
        outputs,
        render_cross_attn_html(src_tokens, tgt_tokens),
        cross_attn_scores,
        render_encoder_decoder_attn_html(tgt_tokens, "Output"),
        decoder_attn_scores,
        render_encoder_decoder_attn_html(src_tokens, "Input"),
        encoder_attn_scores,
    )


def render_cross_attn_html(src_tokens, tgt_tokens):
    """Render the token spans and score placeholders for the cross-attention view."""
    src_html = ""
    for i, token in enumerate(src_tokens):
        src_html += f'<span class="token src-token" data-index="{i}">{token}</span> '

    tgt_html = ""
    for i, token in enumerate(tgt_tokens):
        tgt_html += f'<span class="token tgt-token" data-index="{i}">{token}</span> '

    html = f"""
    <div class="tgt-token-wrapper-text">Output Tokens</div>
    <div class="tgt-token-wrapper">{tgt_html}</div>
    <hr class="token-wrapper-separator">
    <div class="src-token-wrapper-text">Input Tokens</div>
    <div class="src-token-wrapper">{src_html}</div>
    <div class="scores"><span class="score-1 score"></span><span class="score-2 score"></span><span class="score-3 score"></span></div>
    """
    return html


def render_encoder_decoder_attn_html(tokens, token_type):
    """Render the token spans for a self-attention view ("Output" or "Input")."""
    tokens_html = ""
    class_name = "decoder"
    if token_type == "Input":
        class_name = "encoder"

    for i, token in enumerate(tokens):
        tokens_html += f'<span class="token {class_name}-token" data-index="{i}">{token}</span> '

    html = f"""
    <div class="tgt-token-wrapper-text">{token_type} Tokens</div>
    <div class="tgt-token-wrapper">{tokens_html}</div>
    <div class="scores"><span class="score-1 {class_name}-score"></span><span class="score-2 {class_name}-score"></span><span class="score-3 {class_name}-score"></span></div>
    """
    return html
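
# Both self-attention views share this renderer: "Output" yields decoder-token
# spans with decoder-score placeholders, "Input" yields encoder-token spans with
# encoder-score placeholders; the JS hover handlers below fill these in.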


css = """
.output-html-desc {padding-top: 1rem}
.output-html {padding-top: 1rem; padding-bottom: 1rem;}
.output-html-row {margin-bottom: .5rem; border: var(--block-border-width) solid var(--block-border-color); border-radius: var(--block-radius);}
.token {padding: .5rem; border-radius: 5px; cursor: pointer;}
.tgt-token-wrapper {line-height: 2.5rem; padding: .5rem;}
.src-token-wrapper {line-height: 2.5rem; padding: .5rem;}
.src-token-wrapper-text {position: absolute; bottom: .75rem; color: #71717a;}
.tgt-token-wrapper-text {position: absolute; top: .75rem; color: #71717a;}
.token-wrapper-separator {margin-top: 1rem; margin-bottom: 1rem}
.note-text {margin-bottom: 3.5rem;}
.scores { position: absolute; bottom: 0.75rem; color: rgb(113, 113, 122); right: 1rem;}
.score-1 { display: none; background-color: #FF8865; padding: .5rem; border-radius: var(--block-radius); margin-right: .75rem;}
.score-2 { display: none; background-color: #FFD2C4; padding: .5rem; border-radius: var(--block-radius); margin-right: .75rem;}
.score-3 { display: none; background-color: #FFF3F0; padding: .5rem; border-radius: var(--block-radius); margin-right: .75rem;}
"""


js = """
function showCrossAttFun(attn_scores, decoder_attn, encoder_attn) {

    // Token spans rendered by the Python HTML components, plus the score
    // placeholders that show the top-3 attention weights on hover. The "- 1"
    // lengths keep the trailing end-of-sequence token from being highlighted.
    const scrTokens = document.querySelectorAll('.src-token');
    const srcLen = scrTokens.length - 1;
    const targetTokens = document.querySelectorAll('.tgt-token');
    const scores = document.querySelectorAll('.score');

    const decoderTokens = document.querySelectorAll('.decoder-token');
    const decLen = decoderTokens.length - 1;
    const decoderScores = document.querySelectorAll('.decoder-score');

    const encoderTokens = document.querySelectorAll('.encoder-token');
    const encLen = encoderTokens.length - 1;
    const encoderScores = document.querySelectorAll('.encoder-score');

    // When an output (Chinese) token is hovered, highlight its top-3 attended
    // source (English) tokens and show the corresponding attention scores.
    function onTgtHover(event, idx) {
        event.style.backgroundColor = "#C6E6E6";

        const srcIdx0 = attn_scores[idx]['top_index'][0];
        if (srcIdx0 < srcLen) {
            const srcEl0 = scrTokens[srcIdx0];
            srcEl0.style.backgroundColor = "#FF8865";
            scores[0].textContent = attn_scores[idx]['top_values'][0];
            scores[0].style.display = "initial";
        }

        const srcIdx1 = attn_scores[idx]['top_index'][1];
        if (srcIdx1 < srcLen) {
            const srcEl1 = scrTokens[srcIdx1];
            srcEl1.style.backgroundColor = "#FFD2C4";
            scores[1].textContent = attn_scores[idx]['top_values'][1];
            scores[1].style.display = "initial";
        }

        const srcIdx2 = attn_scores[idx]['top_index'][2];
        if (srcIdx2 < srcLen) {
            const srcEl2 = scrTokens[srcIdx2];
            srcEl2.style.backgroundColor = "#FFF3F0";
            scores[2].textContent = attn_scores[idx]['top_values'][2];
            scores[2].style.display = "initial";
        }
    }

    // Clear the source-token highlights and scores when the pointer leaves.
    function outHover(event, idx) {
        event.style.backgroundColor = "";
        const srcIdx0 = attn_scores[idx]['top_index'][0];
        const srcIdx1 = attn_scores[idx]['top_index'][1];
        const srcIdx2 = attn_scores[idx]['top_index'][2];
        const srcEl0 = scrTokens[srcIdx0];
        srcEl0.style.backgroundColor = "";
        scores[0].textContent = "";
        scores[0].style.display = "none";
        const srcEl1 = scrTokens[srcIdx1];
        srcEl1.style.backgroundColor = "";
        scores[1].textContent = "";
        scores[1].style.display = "none";
        const srcEl2 = scrTokens[srcIdx2];
        srcEl2.style.backgroundColor = "";
        scores[2].textContent = "";
        scores[2].style.display = "none";
    }

    // Decoder self-attention: highlight the top-3 already-generated tokens the
    // hovered output token attends to, and grey out the tokens to its right,
    // which it cannot attend to because decoding is causal.
    function onDecodeHover(event, idx) {
        const idx0 = decoder_attn[idx]['top_index'][0];
        if (idx0 < decLen) {
            const el0 = decoderTokens[idx0];
            el0.style.backgroundColor = "#FF8865";
            decoderScores[0].textContent = decoder_attn[idx]['top_values'][0];
            decoderScores[0].style.display = "initial";
        }

        const idx1 = decoder_attn[idx]['top_index'][1];
        if (idx1 < decLen) {
            const el1 = decoderTokens[idx1];
            el1.style.backgroundColor = "#FFD2C4";
            decoderScores[1].textContent = decoder_attn[idx]['top_values'][1];
            decoderScores[1].style.display = "initial";
        }

        const idx2 = decoder_attn[idx]['top_index'][2];
        if (idx2 < decLen) {
            const el2 = decoderTokens[idx2];
            el2.style.backgroundColor = "#FFF3F0";
            decoderScores[2].textContent = decoder_attn[idx]['top_values'][2];
            decoderScores[2].style.display = "initial";
        }

        for (let i = idx + 1; i < decoderTokens.length; i++) {
            decoderTokens[i].style.color = "#ccc9c9";
        }
    }

    // Undo the decoder highlights and restore the greyed-out tokens when the
    // pointer leaves.
    function outDecodeHover(event, idx) {
        event.style.backgroundColor = "";
        const idx0 = decoder_attn[idx]['top_index'][0];
        const el0 = decoderTokens[idx0];
        el0.style.backgroundColor = "";
        decoderScores[0].textContent = "";
        decoderScores[0].style.display = "none";

        const idx1 = decoder_attn[idx]['top_index'][1];
        if (idx1 || idx1 == 0) {
            const el1 = decoderTokens[idx1];
            el1.style.backgroundColor = "";
            decoderScores[1].textContent = "";
            decoderScores[1].style.display = "none";
        }

        const idx2 = decoder_attn[idx]['top_index'][2];
        if (idx2 || idx2 == 0) {
            const el2 = decoderTokens[idx2];
            el2.style.backgroundColor = "";
            decoderScores[2].textContent = "";
            decoderScores[2].style.display = "none";
        }

        for (let i = idx + 1; i < decoderTokens.length; i++) {
            decoderTokens[i].style.color = "black";
        }
    }

    // When an input (English) token is hovered, highlight the top-3 input tokens
    // it attends to in the encoder self-attention.
    function onEncodeHover(event, idx) {
        const idx0 = encoder_attn[idx]['top_index'][0];
        if (idx0 < encLen) {
            const el0 = encoderTokens[idx0];
            el0.style.backgroundColor = "#89C6C6";
            encoderScores[0].textContent = encoder_attn[idx]['top_values'][0];
            encoderScores[0].style.display = "initial";
            encoderScores[0].style.backgroundColor = "#89C6C6";
        }

        const idx1 = encoder_attn[idx]['top_index'][1];
        if (idx1 < encLen) {
            const el1 = encoderTokens[idx1];
            el1.style.backgroundColor = "#C6E6E6";
            encoderScores[1].textContent = encoder_attn[idx]['top_values'][1];
            encoderScores[1].style.display = "initial";
            encoderScores[1].style.backgroundColor = "#C6E6E6";
        }

        const idx2 = encoder_attn[idx]['top_index'][2];
        if (idx2 < encLen) {
            const el2 = encoderTokens[idx2];
            el2.style.backgroundColor = "#E5F5F5";
            encoderScores[2].textContent = encoder_attn[idx]['top_values'][2];
            encoderScores[2].style.display = "initial";
            encoderScores[2].style.backgroundColor = "#E5F5F5";
        }
    }

    // Clear the encoder highlights and scores when the pointer leaves.
    function outEncodeHover(event, idx) {
        event.style.backgroundColor = "";
        const idx0 = encoder_attn[idx]['top_index'][0];
        const el0 = encoderTokens[idx0];
        el0.style.backgroundColor = "";
        encoderScores[0].textContent = "";
        encoderScores[0].style.display = "none";

        const idx1 = encoder_attn[idx]['top_index'][1];
        if (idx1 || idx1 == 0) {
            const el1 = encoderTokens[idx1];
            el1.style.backgroundColor = "";
            encoderScores[1].textContent = "";
            encoderScores[1].style.display = "none";
        }

        const idx2 = encoder_attn[idx]['top_index'][2];
        if (idx2 || idx2 == 0) {
            const el2 = encoderTokens[idx2];
            el2.style.backgroundColor = "";
            encoderScores[2].textContent = "";
            encoderScores[2].style.display = "none";
        }
    }

    // Attach the hover handlers to the rendered token spans.
    targetTokens.forEach((el, idx) => {
        el.addEventListener("mouseover", () => onTgtHover(el, idx));
        el.addEventListener("mouseout", () => outHover(el, idx));
    });

    decoderTokens.forEach((el, idx) => {
        el.addEventListener("mouseover", () => onDecodeHover(el, idx));
        el.addEventListener("mouseout", () => outDecodeHover(el, idx));
    });

    encoderTokens.forEach((el, idx) => {
        el.addEventListener("mouseover", () => onEncodeHover(el, idx));
        el.addEventListener("mouseout", () => outEncodeHover(el, idx));
    });
}
"""


with gr.Blocks(css=css) as demo:
    gr.Markdown("""
    ## 🕸️ Visualize Attention in Translated Text (English to Chinese)
    This app helps you understand the attention layers of transformer models by visualizing the cross-attention and self-attention weights of an encoder-decoder model, showing how the source and target tokens align with each other and within themselves.

    After translating your English input to Chinese, you can explore the cross-attention and self-attention of the translation in the sections below.
    """)

    with gr.Row():
        with gr.Column():
            input_box = gr.Textbox(lines=4, label="Input Text (English)")
        with gr.Column():
            output_box = gr.Textbox(lines=4, label="Translated Text (Chinese)")

    gr.Examples(
        examples=[
            ["A bird can fly and so can a fly"],
            ["She sat by the river bank, letting the cool breeze and the sound of flowing water calm her thoughts."]
        ],
        inputs=[input_box]
    )

    translate_button = gr.Button("Translate", variant="primary")

    cross_attn = gr.JSON(value=[], visible=False)
    decoder_attn = gr.JSON(value=[], visible=False)
    encoder_attn = gr.JSON(value=[], visible=False)
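    # These hidden JSON components carry the per-token top-attention scores from
    # translate_text to the browser; the js callback attached to output_box.change
    # below receives their values as arguments.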

    gr.Markdown(
        """
        ## Check Cross-Attention
        Cross-attention is a key component of encoder-decoder transformers: while generating each output (Chinese) token, the decoder attends to the encoded input (English) sequence.
        Hover your mouse over an output (Chinese) word/token to see which input (English) words/tokens it attends to most.
        """,
        elem_classes="output-html-desc"
    )
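
    # Illustrative shape of one entry in the hidden `cross_attn` JSON (values made
    # up): {"top_index": [2, 0, 5], "top_values": [0.61, 0.2, 0.07]}. The JS hover
    # handlers read only the "top_index" and "top_values" keys.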

    with gr.Row(elem_classes="output-html-row"):
        output_html = gr.HTML(label="Cross Attention", elem_classes="output-html")

    gr.Markdown(
        """
        ## Check Encoder Self-Attention
        Hover your mouse over an input (English) word/token to see which other input words/tokens it attends to.
        """,
        elem_classes="output-html-desc"
    )

    with gr.Row(elem_classes="output-html-row"):
        encoder_output_html = gr.HTML(label="Encoder Attention", elem_classes="output-html")

    gr.Markdown(
        """
        ## Check Decoder Self-Attention
        Hover your mouse over an output (Chinese) word/token to see which other output words/tokens it attends to.
        Notice that a decoder token only attends to tokens to its left: when each token is generated, it can attend only to the past, not to the future.
        """,
        elem_classes="output-html-desc"
    )

    with gr.Row(elem_classes="output-html-row"):
        decoder_output_html = gr.HTML(label="Decoder Attention", elem_classes="output-html")

    translate_button.click(
        fn=translate_text,
        inputs=input_box,
        outputs=[output_box, output_html, cross_attn, decoder_output_html, decoder_attn, encoder_output_html, encoder_attn],
    )

    output_box.change(None, [cross_attn, decoder_attn, encoder_attn], None, js=js)
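    # The change handler above re-runs the JS after every translation so that the
    # hover listeners are attached to the freshly rendered token spans; Gradio
    # passes the three hidden JSON values to showCrossAttFun as its arguments.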

    gr.Markdown("**Note:** I'm using an encoder-decoder transformer model (`Helsinki-NLP/opus-mt-en-zh`) so that cross-attention can be obtained from the decoder layers.",
                elem_classes="note-text")


demo.launch()