import torch
import gradio as gr
import heapq
import pickle

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "Helsinki-NLP/opus-mt-en-zh"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Index of the last decoder layer; its cross-attention is what the app visualizes.
layer_index = model.config.decoder_layers - 1
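
# `model.generate(..., output_attentions=True)` returns `cross_attentions` as a
# tuple with one entry per generated token; each entry is itself a tuple with
# one tensor per decoder layer of shape (batch, num_heads, 1, src_len) when
# decoding greedily with the KV cache.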


def save_data(outputs, src_tokens, tgt_tokens, attn_scores):
    """Dump one run's data to disk for offline inspection (not called by the app)."""
    data = {
        'outputs': outputs,
        'src_tokens': src_tokens,
        'tgt_tokens': tgt_tokens,
        'attn_scores': attn_scores,
    }

    with open("data.pkl", "wb") as f:
        pickle.dump(data, f)
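
# A saved run can be reloaded later, e.g.:
#   with open("data.pkl", "rb") as f:
#       data = pickle.load(f)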


def get_attn_list(cross_attentions):
    """Average the chosen layer's cross-attention over heads for each generated token."""
    avg_attn_list = []

    for step_attn in cross_attentions:
        # (1, num_heads, 1, src_len) -> drop the batch and query dims,
        # then average over heads to get a (src_len,) attention vector.
        attn_tensor = step_attn[layer_index]
        avg_attn_list.append(attn_tensor.squeeze(0).squeeze(1).mean(0))

    return avg_attn_list


def get_top_attns(avg_attn_list):
    """Keep the three most-attended source positions for each generated token."""
    avg_attn_top = []

    for avg_attn in avg_attn_list:
        top_3 = heapq.nlargest(3, enumerate(avg_attn), key=lambda x: x[1])

        # .item() converts 0-dim tensors to plain floats so gr.JSON can serialize them.
        top_values = [val.item() for idx, val in top_3]
        top_index = [idx for idx, val in top_3]

        avg_attn_top.append({
            "top_values": top_values,
            "top_index": top_index
        })

    return avg_attn_top
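
# The result looks like (values illustrative):
#   [{"top_values": [0.62, 0.21, 0.05], "top_index": [3, 4, 0]}, ...]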


def translate_text(input_text):
    inputs = tokenizer(input_text, return_tensors="pt", padding=True)
    with torch.no_grad():
        # num_beams=1 (greedy) keeps one cross-attention entry per generated token.
        translated = model.generate(**inputs, return_dict_in_generate=True, output_scores=True,
                                    output_attentions=True, num_beams=1)

    # Drop the leading decoder-start token and the trailing </s> before decoding.
    outputs = tokenizer.decode(translated.sequences[0][1:-1])

    # Strip the SentencePiece word-boundary marker for display.
    src_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    src_tokens = [token.lstrip('▁_') for token in src_tokens]

    tgt_tokens = tokenizer.convert_ids_to_tokens(translated.sequences[0])[1:]
    tgt_tokens = [token.lstrip('▁_') for token in tgt_tokens]

    avg_attn_list = get_attn_list(translated.cross_attentions)
    attn_scores = get_top_attns(avg_attn_list)

    return outputs, render_attention_html(src_tokens, tgt_tokens), attn_scores


def render_attention_html(src_tokens, tgt_tokens):
    src_html = ""
    for i, token in enumerate(src_tokens):
        src_html += f'<span class="token src-token" data-index="{i}">{token}</span> '

    tgt_html = ""
    for i, token in enumerate(tgt_tokens):
        tgt_html += f'<span class="token tgt-token" data-index="{i}">{token}</span> '

    html = (
        f'<div class="tgt-token-wrapper-text">Output Tokens</div>'
        f'<div class="tgt-token-wrapper">{tgt_html}</div>'
        f'<hr class="token-wrapper-separator">'
        f'<div class="src-token-wrapper-text">Input Tokens</div>'
        f'<div class="src-token-wrapper">{src_html}</div>'
    )
    return html
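
# Each token becomes a hoverable span, e.g. (illustrative):
#   <span class="token tgt-token" data-index="0">你好</span>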


css = """
.output-html-desc {padding-top: 1rem}
.output-html {padding-top: 1rem; padding-bottom: 1rem;}
.output-html-row {margin-bottom: .5rem; border: var(--block-border-width) solid var(--block-border-color); border-radius: var(--block-radius);}
.token {padding: .5rem; border-radius: 5px;}
.tgt-token {cursor: pointer;}
.tgt-token-wrapper {line-height: 2.5rem; padding: .5rem;}
.src-token-wrapper {line-height: 2.5rem; padding: .5rem;}
.src-token-wrapper-text {position: absolute; bottom: .75rem; color: #71717a;}
.tgt-token-wrapper-text {position: absolute; top: .75rem; color: #71717a;}
.token-wrapper-separator {margin-top: 1rem; margin-bottom: 1rem}
.note-text {margin-bottom: 3.5rem;}
"""


js = """
function showCrossAttFun(attn_scores) {
    const srcTokens = document.querySelectorAll('.src-token');
    // The last source token is </s>; never highlight it.
    const srcLen = srcTokens.length - 1;

    const tgtTokens = document.querySelectorAll('.tgt-token');

    // Highlight colors for the top-3 attended source tokens, strongest first.
    const colors = ["#FF8865", "#FFD2C4", "#FFF3F0"];

    function onTgtHover(el, idx) {
        el.style.backgroundColor = "#C6E6E6";
        attn_scores[idx]['top_index'].forEach((srcIdx, rank) => {
            if (srcIdx < srcLen) {
                srcTokens[srcIdx].style.backgroundColor = colors[rank];
            }
        });
    }

    function onTgtOut(el, idx) {
        el.style.backgroundColor = "";
        attn_scores[idx]['top_index'].forEach((srcIdx) => {
            srcTokens[srcIdx].style.backgroundColor = "";
        });
    }

    tgtTokens.forEach((el, idx) => {
        el.addEventListener("mouseover", () => onTgtHover(el, idx));
        el.addEventListener("mouseout", () => onTgtOut(el, idx));
    });
}
"""


with gr.Blocks(css=css) as demo:
    gr.Markdown("""
## 🕸️ Visualize Cross Attention in Translated Text (English to Chinese)
Cross attention is a key component of transformer models: it lets one sequence (here, the generated Chinese text) attend to information from another sequence (the English input).
You can check the cross attention of the translated text in the lower section of the page.
""")

    with gr.Row():
        with gr.Column():
            input_box = gr.Textbox(lines=4, label="Input Text (English)")
        with gr.Column():
            output_box = gr.Textbox(lines=4, label="Translated Text (Chinese)")

    gr.Examples(
        examples=[
            ["They heard the click of the front door and knew that the Dursleys had left the house."],
            ["Azkaban was a fortress where the most dangerous dark wizards were held, guarded by creatures called Dementors."]
        ],
        inputs=[input_box]
    )

    translate_button = gr.Button("Translate", variant="primary")

    # Hidden component that carries the attention scores to the front end.
    attn = gr.JSON(value=[], visible=False)

    gr.Markdown(
        """
## Check Cross Attentions
Hover your mouse over an output (Chinese) word/token to see which input (English) word/token it attends to.
""",
        elem_classes="output-html-desc"
    )
    with gr.Row(elem_classes="output-html-row"):
        output_html = gr.HTML(label="Translated Text (HTML)", elem_classes="output-html")

    translate_button.click(fn=translate_text, inputs=input_box, outputs=[output_box, output_html, attn])

    # Re-run the hover-handler setup whenever a new translation arrives.
    output_box.change(None, attn, None, js=js)

    gr.Markdown("**Note:** I'm using an encoder-decoder transformer model (`Helsinki-NLP/opus-mt-en-zh`) so that cross attention can be read from its decoder layers.",
                elem_classes="note-text")


demo.launch()