first commit
Browse files- .gitignore +1 -0
- app.py +212 -0
- requirements.txt +6 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.idea/
|
app.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
from torch import nn
import gradio as gr
import heapq
import pickle


from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# English -> Chinese MarianMT checkpoint; swap the comments to flip direction.
model_name = "Helsinki-NLP/opus-mt-en-zh"
# model_name = "Helsinki-NLP/opus-mt-zh-en"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Cross-attention is read from the final decoder layer only.
layer_index = model.config.decoder_layers - 1  # last decoder layer index
19 |
+
def save_data(outputs, src_tokens, tgt_tokens, attn_scores):
    """Pickle the translation artifacts to ``data.pkl`` for offline inspection."""
    payload = {
        'outputs': outputs,
        'src_tokens': src_tokens,
        'tgt_tokens': tgt_tokens,
        'attn_scores': attn_scores,
    }
    with open("data.pkl", "wb") as fh:
        pickle.dump(payload, fh)
25 |
+
|
26 |
+
|
27 |
+
def get_attn_list(cross_attentions):
    """Average last-layer cross-attention over heads for each generated token.

    Returns one 1-D tensor per output token holding that token's mean
    attention weights over the source tokens.
    """
    # Each generation step contributes one per-layer tuple; step[layer_index]
    # is presumably shaped [batch=1, heads, 1, src_len] -- TODO confirm
    # against model.generate's output for this checkpoint.
    return [
        step[layer_index].squeeze(0).squeeze(1).mean(0)
        for step in cross_attentions
    ]
36 |
+
|
37 |
+
def get_top_attns(avg_attn_list):
    """For each output token, report its 3 strongest source-attention entries.

    Returns a list of dicts with "top_values" (attention weights, descending)
    and "top_index" (the matching source-token positions).
    """
    result = []
    for scores in avg_attn_list:
        # Rank (position, weight) pairs by weight; nlargest keeps at most 3.
        ranked = heapq.nlargest(3, enumerate(scores), key=lambda pair: pair[1])
        result.append({
            "top_values": [value for _, value in ranked],
            "top_index": [index for index, _ in ranked],
        })
    return result
54 |
+
|
55 |
+
|
56 |
+
# Define translation function
|
57 |
+
def translate_text(input_text):
    """Translate *input_text* and expose per-token cross-attention.

    Returns (translated string, attention HTML, top-3 attention scores per
    output token) — the three outputs wired to the Gradio click handler.
    """
    encoded = tokenizer(input_text, return_tensors="pt", padding=True)
    with torch.no_grad():
        generation = model.generate(
            **encoded,
            return_dict_in_generate=True,
            output_scores=True,
            output_attentions=True,
            num_beams=1,
        )

    generated_ids = generation.sequences[0]
    # Drop the leading decoder-start token and the trailing token (EOS,
    # presumably -- confirm for this checkpoint) before decoding.
    outputs = tokenizer.decode(generated_ids[1:][:-1])

    def _clean(tokens):
        # Strip SentencePiece word-boundary markers for display.
        return [tok.lstrip('▁_') for tok in tokens]

    src_tokens = _clean(tokenizer.convert_ids_to_tokens(encoded["input_ids"][0]))
    tgt_tokens = _clean(tokenizer.convert_ids_to_tokens(generated_ids)[1:])

    attn_scores = get_top_attns(get_attn_list(generation.cross_attentions))

    # save_data(outputs, src_tokens, tgt_tokens, attn_scores)  # debug dump
    return outputs, render_attention_html(src_tokens, tgt_tokens), attn_scores
77 |
+
|
78 |
+
|
79 |
+
def render_attention_html(src_tokens, tgt_tokens):
    """Render source/target tokens as hoverable HTML spans.

    Each token becomes a <span> carrying its position in ``data-index`` so
    the page's JS can link target tokens to the source tokens they attend to.
    """
    def spans(tokens, css_class):
        # One trailing space per span keeps tokens visually separated.
        return "".join(
            f'<span class="token {css_class}" data-index="{pos}">{tok}</span> '
            for pos, tok in enumerate(tokens)
        )

    tgt_html = spans(tgt_tokens, "tgt-token")
    src_html = spans(src_tokens, "src-token")

    # NOTE: "seperator" matches the (misspelled) CSS class name; keep in sync.
    return (
        f'<div class="tgt-token-wrapper-text">Output Tokens</div>'
        f'<div class="tgt-token-wrapper">{tgt_html}</div>'
        f'<hr class="token-wrapper-seperator">'
        f'<div class="src-token-wrapper-text">Input Tokens</div>'
        f'<div class="src-token-wrapper">{src_html}</div>'
    )
91 |
+
|
92 |
+
|
93 |
+
css = """
|
94 |
+
.output-html-desc {padding-top: 1rem}
|
95 |
+
.output-html {padding-top: 1rem; padding-bottom: 1rem;}
|
96 |
+
.output-html-row {margin-bottom: .5rem; border: var(--block-border-width) solid var(--block-border-color); border-radius: var(--block-radius);}
|
97 |
+
.token {padding: .5rem; border-radius: 5px;}
|
98 |
+
.tgt-token {cursor: pointer;}
|
99 |
+
.tgt-token-wrapper {line-height: 2.5rem; padding: .5rem;}
|
100 |
+
.src-token-wrapper {line-height: 2.5rem; padding: .5rem;}
|
101 |
+
.src-token-wrapper-text {position: absolute; bottom: .75rem; color: #71717a;}
|
102 |
+
.tgt-token-wrapper-text {position: absolute; top: .75rem; color: #71717a;}
|
103 |
+
.token-wrapper-seperator {margin-top: 1rem; margin-bottom: 1rem}
|
104 |
+
.note-text {margin-bottom: 3.5rem;}
|
105 |
+
"""
|
106 |
+
|
107 |
+
js = """
|
108 |
+
function showCrossAttFun(attn_scores) {
|
109 |
+
|
110 |
+
const scrTokens = document.querySelectorAll('.src-token');
|
111 |
+
const srcLen = scrTokens.length - 1
|
112 |
+
|
113 |
+
const targetTokens = document.querySelectorAll('.tgt-token');
|
114 |
+
|
115 |
+
function onTgtHover(event, idx) {
|
116 |
+
event.style.backgroundColor = "#C6E6E6";
|
117 |
+
|
118 |
+
srcIdx0 = attn_scores[idx]['top_index'][0]
|
119 |
+
if (srcIdx0 < srcLen) {
|
120 |
+
srcEl0 = scrTokens[srcIdx0]
|
121 |
+
srcEl0.style.backgroundColor = "#FF8865"
|
122 |
+
}
|
123 |
+
|
124 |
+
srcIdx1 = attn_scores[idx]['top_index'][1]
|
125 |
+
if (srcIdx1 < srcLen) {
|
126 |
+
srcEl1 = scrTokens[srcIdx1]
|
127 |
+
srcEl1.style.backgroundColor = "#FFD2C4"
|
128 |
+
}
|
129 |
+
|
130 |
+
srcIdx2 = attn_scores[idx]['top_index'][2]
|
131 |
+
if (srcIdx2 < srcLen) {
|
132 |
+
srcEl2 = scrTokens[srcIdx2]
|
133 |
+
srcEl2.style.backgroundColor = "#FFF3F0"
|
134 |
+
}
|
135 |
+
}
|
136 |
+
|
137 |
+
function outHover(event, idx) {
|
138 |
+
event.style.backgroundColor = "";
|
139 |
+
srcIdx0 = attn_scores[idx]['top_index'][0]
|
140 |
+
srcIdx1 = attn_scores[idx]['top_index'][1]
|
141 |
+
srcIdx2 = attn_scores[idx]['top_index'][2]
|
142 |
+
srcEl0 = scrTokens[srcIdx0]
|
143 |
+
srcEl0.style.backgroundColor = ""
|
144 |
+
srcEl1 = scrTokens[srcIdx1]
|
145 |
+
srcEl1.style.backgroundColor = ""
|
146 |
+
srcEl2 = scrTokens[srcIdx2]
|
147 |
+
srcEl2.style.backgroundColor = ""
|
148 |
+
}
|
149 |
+
|
150 |
+
|
151 |
+
targetTokens.forEach((el, idx) => {
|
152 |
+
el.addEventListener("mouseover", () => {
|
153 |
+
onTgtHover(el, idx)
|
154 |
+
})
|
155 |
+
});
|
156 |
+
|
157 |
+
targetTokens.forEach((el, idx) => {
|
158 |
+
el.addEventListener("mouseout", () => {
|
159 |
+
outHover(el, idx)
|
160 |
+
})
|
161 |
+
});
|
162 |
+
}
|
163 |
+
"""
|
164 |
+
|
165 |
+
|
166 |
+
# Gradio Interface
with gr.Blocks(css=css) as demo:
    gr.Markdown("""
    ## 🕸️ Visualize Cross Attention between Translated Text (English to Chinese)
    Cross attention is a key component in transformers, where a sequence (English Text) can attend to another sequence’s information (Chinese Text).
    You can check the cross attention of the translated text in the lower section of the page.
    """)

    # Input/output text boxes side by side.
    with gr.Row():
        with gr.Column():
            input_box = gr.Textbox(lines=4, label="Input Text (English)")
        with gr.Column():
            output_box = gr.Textbox(lines=4, label="Translated Text (Chinese)")

    # Examples Section
    gr.Examples(
        examples=[
            ["They heard the click of the front door and knew that the Dursleys had left the house."],
            ["Azkaban was a fortress where the most dangerous dark wizards were held, guarded by creatures called Dementors."]
        ],
        inputs=[input_box]
    )

    translate_button = gr.Button("Translate", variant="primary")

    # Hidden JSON component: ferries the per-token attention scores produced
    # by translate_text from Python to the page's JS hover handler.
    attn = gr.JSON(value=[], visible=False)

    gr.Markdown(
        """
        ## Check Cross Attentions
        Hover your mouse over an output (Chinese) word/token to see which input (English) word/token it is attending to.
        """,
        elem_classes="output-html-desc"
    )
    with gr.Row(elem_classes="output-html-row"):
        output_html = gr.HTML(label="Translated Text (HTML)", elem_classes="output-html")

    # Clicking Translate fills the text box, the token HTML, and the hidden scores.
    translate_button.click(fn=translate_text, inputs=input_box, outputs=[output_box, output_html, attn])

    # Whenever a new translation lands, re-run the JS to re-attach hover
    # listeners to the freshly rendered token spans with the new scores.
    output_box.change(None, attn, None, js=js)

    gr.Markdown("**Note:** I'm using a transformer model of encoder-decoder architecture (`Helsinki-NLP/opus-mt-en-zh`) in order to obtain cross attention from the decoder layers. ",
                elem_classes="note-text")



demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
transformers
|
3 |
+
torch
|
4 |
+
torchvision
|
5 |
+
sacremoses
|
6 |
+
sentencepiece
|