|
import torch |
|
from torch import nn |
|
import gradio as gr |
|
from utils import save_data, get_attn_list, get_top_attns |
|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
# Pretrained MarianMT English->Chinese translation model from the
# Helsinki-NLP OPUS-MT collection (encoder-decoder, so it exposes the
# cross-attentions this app visualizes).
model_name = "Helsinki-NLP/opus-mt-en-zh"

# NOTE: downloads/loads weights at import time (network + disk I/O).
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Index of the last decoder layer; its attention maps are the ones
# extracted and displayed (see get_attn_list calls in translate_text).
layer_index = model.config.decoder_layers - 1
|
|
|
|
|
|
|
def translate_text(input_text):
    """Translate English text to Chinese and collect attention data.

    Args:
        input_text: English source text (a single string).

    Returns:
        A 5-tuple consumed by the Gradio click handler:
        (translated text, cross-attention HTML, cross-attention top-k
        scores, decoder self-attention HTML, decoder self-attention
        top-k scores).
    """
    inputs = tokenizer(input_text, return_tensors="pt", padding=True)
    with torch.no_grad():
        # num_beams=1 (greedy) keeps exactly one attention matrix per
        # generated token, which the score extraction below relies on.
        translated = model.generate(
            **inputs,
            return_dict_in_generate=True,
            output_scores=True,
            output_attentions=True,
            num_beams=1,
        )

    # skip_special_tokens drops the decoder-start and EOS tokens robustly.
    # The previous `[1:][:-1]` slicing assumed the sequence always ended
    # with EOS and silently dropped a real token when generation stopped
    # at max length instead.
    outputs = tokenizer.decode(translated.sequences[0], skip_special_tokens=True)

    # Strip SentencePiece word-boundary markers for display.
    src_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    src_tokens = [token.lstrip('▁_') for token in src_tokens]

    # Drop the decoder-start token so target tokens line up with the
    # per-step attention lists returned by generate().
    tgt_tokens = tokenizer.convert_ids_to_tokens(translated.sequences[0])[1:]
    tgt_tokens = [token.lstrip('▁_') for token in tgt_tokens]

    # Top-k attended positions per generated token, from the last decoder
    # layer (see layer_index above).
    avg_cross_attn_list = get_attn_list(translated.cross_attentions, layer_index)
    cross_attn_scores = get_top_attns(avg_cross_attn_list)

    avg_decoder_attn_list = get_attn_list(translated.decoder_attentions, layer_index)
    decoder_attn_scores = get_top_attns(avg_decoder_attn_list)

    return outputs, render_cross_attn_html(src_tokens, tgt_tokens), cross_attn_scores, render_encoder_decoder_attn_html(tgt_tokens, "Output"), decoder_attn_scores
|
|
|
|
|
def render_cross_attn_html(src_tokens, tgt_tokens):
    """Render the cross-attention panel as an HTML fragment.

    Args:
        src_tokens: display strings for the input (English) tokens.
        tgt_tokens: display strings for the output (Chinese) tokens.

    Returns:
        HTML with one <span> per token (carrying a data-index so the
        injected JavaScript can attach hover handlers) plus three empty
        score badges that the JavaScript fills in on hover.
    """
    # Trailing space after each span keeps the tokens visually separated.
    src_html = "".join(
        f'<span class="token src-token" data-index="{i}">{token}</span> '
        for i, token in enumerate(src_tokens)
    )
    tgt_html = "".join(
        f'<span class="token tgt-token" data-index="{i}">{token}</span> '
        for i, token in enumerate(tgt_tokens)
    )

    # Fixed: the scores wrapper previously ended with a stray opening
    # `<div>` instead of `</div>`, leaving an unbalanced element in the DOM.
    html = f"""
    <div class="tgt-token-wrapper-text">Output Tokens</div>
    <div class="tgt-token-wrapper">{tgt_html}</div>
    <hr class="token-wrapper-seperator">
    <div class="src-token-wrapper-text">Input Tokens</div>
    <div class="src-token-wrapper">{src_html}</div>
    <div class="scores"><span class="score-1 score"></span><span class="score-2 score"></span><span class="score-3 score"></span></div>
    """
    return html
|
|
|
def render_encoder_decoder_attn_html(tokens, type):
    """Render the decoder self-attention panel as an HTML fragment.

    Args:
        tokens: display strings for the decoder-side tokens.
        type: heading label for the token group (e.g. "Output"). The
            parameter name shadows the builtin but is kept for
            backward compatibility with existing callers.

    Returns:
        HTML with one indexed <span> per token plus three empty score
        badges that the injected JavaScript fills in on hover.
    """
    tokens_html = "".join(
        f'<span class="token decoder-token" data-index="{i}">{token}</span> '
        for i, token in enumerate(tokens)
    )

    # Fixed: the scores wrapper previously ended with a stray opening
    # `<div>` instead of `</div>`, leaving an unbalanced element in the DOM.
    html = f"""
    <div class="tgt-token-wrapper-text">{type} Tokens</div>
    <div class="tgt-token-wrapper">{tokens_html}</div>
    <div class="scores"><span class="score-1 decoder-score"></span><span class="score-2 decoder-score"></span><span class="score-3 decoder-score"></span></div>
    """
    return html
|
|
|
|
|
# Page styles injected into gr.Blocks: token "chips", the absolutely
# positioned section labels, and the three colour-graded score badges
# (score-1/2/3). The badge colours (#FF8865 / #FFD2C4 / #FFF3F0) must stay
# in sync with the highlight colours hard-coded in the JavaScript string
# below. NOTE: "seperator" is misspelled, but the HTML generated by
# render_cross_attn_html uses the same spelling, so it must not be
# changed here alone.
css = """
.output-html-desc {padding-top: 1rem}
.output-html {padding-top: 1rem; padding-bottom: 1rem;}
.output-html-row {margin-bottom: .5rem; border: var(--block-border-width) solid var(--block-border-color); border-radius: var(--block-radius);}
.token {padding: .5rem; border-radius: 5px;}
.tgt-token {cursor: pointer;}
.tgt-token-wrapper {line-height: 2.5rem; padding: .5rem;}
.src-token-wrapper {line-height: 2.5rem; padding: .5rem;}
.src-token-wrapper-text {position: absolute; bottom: .75rem; color: #71717a;}
.tgt-token-wrapper-text {position: absolute; top: .75rem; color: #71717a;}
.token-wrapper-seperator {margin-top: 1rem; margin-bottom: 1rem}
.note-text {margin-bottom: 3.5rem;}
.scores { position: absolute; bottom: 0.75rem; color: rgb(113, 113, 122); right: 1rem;}
.score-1 { display: none; background-color: #FF8865; padding: .5rem; border-radius: var(--block-radius); margin-right: .75rem;}
.score-2 { display: none; background-color: #FFD2C4; padding: .5rem; border-radius: var(--block-radius); margin-right: .75rem;}
.score-3 { display: none; background-color: #FFF3F0; padding: .5rem; border-radius: var(--block-radius); margin-right: .75rem;}
"""
|
|
|
# Client-side handler injected via the `js=` argument of
# output_box.change(). Gradio calls it with the values of the two hidden
# JSON components (cross-attention scores, decoder self-attention scores).
# It attaches mouseover/mouseout listeners to the token <span>s rendered
# by the Python HTML helpers and, on hover, highlights the top-3 attended
# tokens and shows their scores. The `- 1` on srcLen/decLen skips the
# trailing EOS token so it is never highlighted. Hover colours must match
# the .score-1/2/3 background colours in the CSS above.
js = """
function showCrossAttFun(attn_scores, decoder_attn) {

    const scrTokens = document.querySelectorAll('.src-token');
    const srcLen = scrTokens.length - 1
    const targetTokens = document.querySelectorAll('.tgt-token');
    const scores = document.querySelectorAll('.score');

    const decoderTokens = document.querySelectorAll('.decoder-token');
    const decLen = decoderTokens.length - 1
    const decoderScores = document.querySelectorAll('.decoder-score');

    function onTgtHover(event, idx) {
        event.style.backgroundColor = "#C6E6E6";

        srcIdx0 = attn_scores[idx]['top_index'][0]
        if (srcIdx0 < srcLen) {
            srcEl0 = scrTokens[srcIdx0]
            srcEl0.style.backgroundColor = "#FF8865"
            scores[0].textContent = attn_scores[idx]['top_values'][0]
            scores[0].style.display = "initial";
        }

        srcIdx1 = attn_scores[idx]['top_index'][1]
        if (srcIdx1 < srcLen) {
            srcEl1 = scrTokens[srcIdx1]
            srcEl1.style.backgroundColor = "#FFD2C4"
            scores[1].textContent = attn_scores[idx]['top_values'][1]
            scores[1].style.display = "initial";
        }

        srcIdx2 = attn_scores[idx]['top_index'][2]
        if (srcIdx2 < srcLen) {
            srcEl2 = scrTokens[srcIdx2]
            srcEl2.style.backgroundColor = "#FFF3F0"
            scores[2].textContent = attn_scores[idx]['top_values'][2]
            scores[2].style.display = "initial";
        }
    }

    function outHover(event, idx) {
        event.style.backgroundColor = "";
        srcIdx0 = attn_scores[idx]['top_index'][0]
        srcIdx1 = attn_scores[idx]['top_index'][1]
        srcIdx2 = attn_scores[idx]['top_index'][2]
        srcEl0 = scrTokens[srcIdx0]
        srcEl0.style.backgroundColor = ""
        scores[0].textContent = ""
        scores[0].style.display = "none";
        srcEl1 = scrTokens[srcIdx1]
        srcEl1.style.backgroundColor = ""
        scores[1].textContent = ""
        scores[1].style.display = "none";
        srcEl2 = scrTokens[srcIdx2]
        srcEl2.style.backgroundColor = ""
        scores[2].textContent = ""
        scores[2].style.display = "none";
    }

    function onSelfHover(event, idx) {
        event.style.backgroundColor = "#C6E6E6";

        idx0 = decoder_attn[idx]['top_index'][0]
        if (idx0 < decLen) {
            el0 = decoderTokens[idx0]
            el0.style.backgroundColor = "#FF8865"
            decoderScores[0].textContent = decoder_attn[idx]['top_values'][0]
            decoderScores[0].style.display = "initial";
        }

        idx1 = decoder_attn[idx]['top_index'][1]
        if (idx1 < decLen) {
            el1 = decoderTokens[idx1]
            el1.style.backgroundColor = "#FFD2C4"
            decoderScores[1].textContent = decoder_attn[idx]['top_values'][1]
            decoderScores[1].style.display = "initial";
        }

        idx2 = decoder_attn[idx]['top_index'][2]
        if (idx2 < decLen) {
            el2 = decoderTokens[idx2]
            el2.style.backgroundColor = "#FFF3F0"
            decoderScores[2].textContent = decoder_attn[idx]['top_values'][2]
            decoderScores[2].style.display = "initial";
        }

        for (i=idx+1; i < decoderTokens.length; i++) {
            decoderTokens[i].style.color = "#aaa8a8";
        }

    }

    function outSelfHover(event, idx) {
        event.style.backgroundColor = "";
        idx0 = decoder_attn[idx]['top_index'][0]
        el0 = decoderTokens[idx0]
        el0.style.backgroundColor = ""
        decoderScores[0].textContent = ""
        decoderScores[0].style.display = "none";

        idx1 = decoder_attn[idx]['top_index'][1]
        if (idx1 || idx1 == 0) {
            el1 = decoderTokens[idx1]
            el1.style.backgroundColor = ""
            decoderScores[1].textContent = ""
            decoderScores[1].style.display = "none";
        }

        idx2 = decoder_attn[idx]['top_index'][2]
        if (idx2 || idx2 == 0) {
            el2 = decoderTokens[idx2]
            el2.style.backgroundColor = ""
            decoderScores[2].textContent = ""
            decoderScores[2].style.display = "none";
        }

        for (i=idx+1; i < decoderTokens.length; i++) {
            decoderTokens[i].style.color = "black";
        }
    }

    targetTokens.forEach((el, idx) => {
        el.addEventListener("mouseover", () => {
            onTgtHover(el, idx)
        })
    });

    targetTokens.forEach((el, idx) => {
        el.addEventListener("mouseout", () => {
            outHover(el, idx)
        })
    });

    decoderTokens.forEach((el, idx) => {
        el.addEventListener("mouseover", () => {
            onSelfHover(el, idx)
        })
    });

    decoderTokens.forEach((el, idx) => {
        el.addEventListener("mouseout", () => {
            outSelfHover(el, idx)
        })
    });
}
"""
|
|
|
|
|
|
|
# Page layout: translation boxes, example prompts, hidden JSON holders for
# the attention scores, and two HTML panels that the injected JavaScript
# decorates with hover interactions.
with gr.Blocks(css=css) as demo:
    gr.Markdown("""
## 🕸️ Visualize Attentions in Translated Text (English to Chinese)
After translating your English input to Chinese, you can check the cross attentions and self-attentions of the translation in the lower section of the page.
""")

    # Side-by-side input/output translation boxes.
    with gr.Row():
        with gr.Column():
            input_box = gr.Textbox(lines=4, label="Input Text (English)")
        with gr.Column():
            output_box = gr.Textbox(lines=4, label="Translated Text (Chinese)")

    gr.Examples(
        examples=[
            ["They heard the click of the front door and knew that the Dursleys had left the house."],
            ["Azkaban was a fortress where the most dangerous dark wizards were held, guarded by creatures called Dementors."]
        ],
        inputs=[input_box]
    )

    translate_button = gr.Button("Translate", variant="primary")

    # Hidden JSON components ferry the attention scores from translate_text
    # to the client-side hover handlers (passed into `js` below).
    cross_attn = gr.JSON(value=[], visible=False)
    decoder_attn = gr.JSON(value=[], visible=False)

    gr.Markdown(
        """
## Check Cross Attentions
Cross attention is a key component in transformers, where a sequence (English Text) can attend to another sequence’s information (Chinese Text).
Hover your mouse over an output (Chinese) word/token to see which input (English) word/token it is attending to.
""",
        elem_classes="output-html-desc"
    )
    with gr.Row(elem_classes="output-html-row"):
        output_html = gr.HTML(label="Cross Attention", elem_classes="output-html")

    gr.Markdown(
        """
## Check Self Attentions for Decoder
Hover your mouse over an output (Chinese) word/token to see which word/token it is self-attending to.
Notice that decoder tokens only attend to tokens on its left as during the generation of each token, it pays attention only to the past not to the future.
""",
        elem_classes="output-html-desc"
    )

    with gr.Row(elem_classes="output-html-row"):
        # Fixed label typo: was "Decoder Attention)" with a stray parenthesis.
        decoder_output_html = gr.HTML(label="Decoder Attention", elem_classes="output-html")

    translate_button.click(fn=translate_text, inputs=input_box, outputs=[output_box, output_html, cross_attn, decoder_output_html, decoder_attn])

    # Re-run the JS handler whenever a new translation arrives so the fresh
    # token spans get their hover listeners attached.
    output_box.change(None, [cross_attn, decoder_attn], None, js=js)

    gr.Markdown("**Note:** I'm using a transformer model of encoder-decoder architecture (`Helsinki-NLP/opus-mt-en-zh`) in order to obtain cross attention from the decoder layers. ",
                elem_classes="note-text")


demo.launch()
|
|