add encoder self-attention
Browse files- __pycache__/utils.cpython-313.pyc +0 -0
- app.py +105 -16
- utils.py +5 -9
__pycache__/utils.cpython-313.pyc
CHANGED
Binary files a/__pycache__/utils.cpython-313.pyc and b/__pycache__/utils.cpython-313.pyc differ
|
|
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import torch
|
2 |
from torch import nn
|
3 |
import gradio as gr
|
4 |
-
from utils import save_data, get_attn_list, get_top_attns
|
5 |
|
6 |
|
7 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
@@ -37,8 +37,11 @@ def translate_text(input_text):
|
|
37 |
avg_decoder_attn_list = get_attn_list(translated.decoder_attentions, layer_index)
|
38 |
decoder_attn_scores = get_top_attns(avg_decoder_attn_list)
|
39 |
|
|
|
|
|
|
|
40 |
# save_data(outputs, src_tokens, tgt_tokens, attn_scores)
|
41 |
-
return outputs, render_cross_attn_html(src_tokens, tgt_tokens), cross_attn_scores, render_encoder_decoder_attn_html(tgt_tokens, "Output"), decoder_attn_scores
|
42 |
|
43 |
|
44 |
def render_cross_attn_html(src_tokens, tgt_tokens):
|
@@ -64,13 +67,17 @@ def render_cross_attn_html(src_tokens, tgt_tokens):
|
|
64 |
def render_encoder_decoder_attn_html(tokens, type):
|
65 |
# Build HTML for source and target tokens
|
66 |
tokens_html = ""
|
|
|
|
|
|
|
|
|
67 |
for i, token in enumerate(tokens):
|
68 |
-
tokens_html += f'<span class="token
|
69 |
|
70 |
html = f"""
|
71 |
<div class="tgt-token-wrapper-text">{type} Tokens</div>
|
72 |
<div class="tgt-token-wrapper">{tokens_html}</div>
|
73 |
-
<div class="scores"><span class="score-1
|
74 |
"""
|
75 |
return html
|
76 |
|
@@ -80,7 +87,7 @@ css = """
|
|
80 |
.output-html {padding-top: 1rem; padding-bottom: 1rem;}
|
81 |
.output-html-row {margin-bottom: .5rem; border: var(--block-border-width) solid var(--block-border-color); border-radius: var(--block-radius);}
|
82 |
.token {padding: .5rem; border-radius: 5px;}
|
83 |
-
.
|
84 |
.tgt-token-wrapper {line-height: 2.5rem; padding: .5rem;}
|
85 |
.src-token-wrapper {line-height: 2.5rem; padding: .5rem;}
|
86 |
.src-token-wrapper-text {position: absolute; bottom: .75rem; color: #71717a;}
|
@@ -94,18 +101,21 @@ css = """
|
|
94 |
"""
|
95 |
|
96 |
js = """
|
97 |
-
function showCrossAttFun(attn_scores, decoder_attn) {
|
98 |
|
99 |
const scrTokens = document.querySelectorAll('.src-token');
|
100 |
const srcLen = scrTokens.length - 1
|
101 |
const targetTokens = document.querySelectorAll('.tgt-token');
|
102 |
const scores = document.querySelectorAll('.score');
|
103 |
|
104 |
-
|
105 |
const decoderTokens = document.querySelectorAll('.decoder-token');
|
106 |
const decLen = decoderTokens.length - 1
|
107 |
const decoderScores = document.querySelectorAll('.decoder-score');
|
108 |
|
|
|
|
|
|
|
|
|
109 |
function onTgtHover(event, idx) {
|
110 |
event.style.backgroundColor = "#C6E6E6";
|
111 |
|
@@ -153,9 +163,7 @@ function showCrossAttFun(attn_scores, decoder_attn) {
|
|
153 |
scores[2].style.display = "none";
|
154 |
}
|
155 |
|
156 |
-
function
|
157 |
-
event.style.backgroundColor = "#C6E6E6";
|
158 |
-
|
159 |
idx0 = decoder_attn[idx]['top_index'][0]
|
160 |
if (idx0 < decLen) {
|
161 |
el0 = decoderTokens[idx0]
|
@@ -181,12 +189,12 @@ function showCrossAttFun(attn_scores, decoder_attn) {
|
|
181 |
}
|
182 |
|
183 |
for (i=idx+1; i < decoderTokens.length; i++) {
|
184 |
-
decoderTokens[i].style.color = "#
|
185 |
}
|
186 |
|
187 |
}
|
188 |
|
189 |
-
function
|
190 |
event.style.backgroundColor = "";
|
191 |
idx0 = decoder_attn[idx]['top_index'][0]
|
192 |
el0 = decoderTokens[idx0]
|
@@ -216,6 +224,62 @@ function showCrossAttFun(attn_scores, decoder_attn) {
|
|
216 |
}
|
217 |
|
218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
targetTokens.forEach((el, idx) => {
|
220 |
el.addEventListener("mouseover", () => {
|
221 |
onTgtHover(el, idx)
|
@@ -230,13 +294,25 @@ function showCrossAttFun(attn_scores, decoder_attn) {
|
|
230 |
|
231 |
decoderTokens.forEach((el, idx) => {
|
232 |
el.addEventListener("mouseover", () => {
|
233 |
-
|
234 |
})
|
235 |
});
|
236 |
|
237 |
decoderTokens.forEach((el, idx) => {
|
238 |
el.addEventListener("mouseout", () => {
|
239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
})
|
241 |
});
|
242 |
}
|
@@ -269,6 +345,7 @@ with gr.Blocks(css=css) as demo:
|
|
269 |
|
270 |
cross_attn = gr.JSON(value=[], visible=False)
|
271 |
decoder_attn = gr.JSON(value=[], visible=False)
|
|
|
272 |
|
273 |
gr.Markdown(
|
274 |
"""
|
@@ -281,6 +358,18 @@ with gr.Blocks(css=css) as demo:
|
|
281 |
with gr.Row(elem_classes="output-html-row"):
|
282 |
output_html = gr.HTML(label="Cross Attention", elem_classes="output-html")
|
283 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
284 |
gr.Markdown(
|
285 |
"""
|
286 |
## Check Self Attentions for Decoder
|
@@ -293,9 +382,9 @@ with gr.Blocks(css=css) as demo:
|
|
293 |
with gr.Row(elem_classes="output-html-row"):
|
294 |
decoder_output_html = gr.HTML(label="Decoder Attention)", elem_classes="output-html")
|
295 |
|
296 |
-
translate_button.click(fn=translate_text, inputs=input_box, outputs=[output_box, output_html, cross_attn, decoder_output_html, decoder_attn])
|
297 |
|
298 |
-
output_box.change(None, [cross_attn, decoder_attn], None, js=js)
|
299 |
|
300 |
gr.Markdown("**Note:** I'm using a transformer model of encoder-decoder architecture (`Helsinki-NLP/opus-mt-en-zh`) in order to obtain cross attention from the decoder layers. ",
|
301 |
elem_classes="note-text")
|
|
|
1 |
import torch
|
2 |
from torch import nn
|
3 |
import gradio as gr
|
4 |
+
from utils import save_data, get_attn_list, get_top_attns, get_encoder_attn_list
|
5 |
|
6 |
|
7 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
37 |
avg_decoder_attn_list = get_attn_list(translated.decoder_attentions, layer_index)
|
38 |
decoder_attn_scores = get_top_attns(avg_decoder_attn_list)
|
39 |
|
40 |
+
avg_encoder_attn_list = get_encoder_attn_list(translated.encoder_attentions, layer_index)
|
41 |
+
encoder_attn_scores = get_top_attns(avg_encoder_attn_list)
|
42 |
+
|
43 |
# save_data(outputs, src_tokens, tgt_tokens, attn_scores)
|
44 |
+
return outputs, render_cross_attn_html(src_tokens, tgt_tokens), cross_attn_scores, render_encoder_decoder_attn_html(tgt_tokens, "Output"), decoder_attn_scores, render_encoder_decoder_attn_html(src_tokens, "Input"), encoder_attn_scores
|
45 |
|
46 |
|
47 |
def render_cross_attn_html(src_tokens, tgt_tokens):
|
|
|
67 |
def render_encoder_decoder_attn_html(tokens, type):
    """Render the token row and empty score placeholders for a self-attention panel.

    Args:
        tokens: Iterable of tokens to display; each becomes one ``<span>``
            carrying its position in a ``data-index`` attribute.
        type: Panel label, interpolated into the "<type> Tokens" caption.
            The value ``"Input"`` selects the ``encoder-*`` CSS classes; any
            other value selects the ``decoder-*`` classes. (The name shadows
            the builtin but is kept so existing callers keep working.)

    Returns:
        str: An HTML fragment with the caption, the token spans and three
        empty score spans.
    """
    # Function-scope import so this block does not depend on the module's
    # top-level imports.
    from html import escape

    class_name = "encoder" if type == "Input" else "decoder"

    # Token text is escaped before interpolation: pieces such as "</s>" or a
    # user-typed "<" would otherwise be parsed as markup instead of shown.
    tokens_html = "".join(
        f'<span class="token {class_name}-token" data-index="{i}">{escape(str(token))}</span> '
        for i, token in enumerate(tokens)
    )

    # The score container is closed with </div>; the previous markup ended
    # with a second opening <div>, leaving the element unclosed.
    return f"""
    <div class="tgt-token-wrapper-text">{type} Tokens</div>
    <div class="tgt-token-wrapper">{tokens_html}</div>
    <div class="scores"><span class="score-1 {class_name}-score"></span><span class="score-2 {class_name}-score"></span><span class="score-3 {class_name}-score"></span></div>
    """
|
83 |
|
|
|
87 |
.output-html {padding-top: 1rem; padding-bottom: 1rem;}
|
88 |
.output-html-row {margin-bottom: .5rem; border: var(--block-border-width) solid var(--block-border-color); border-radius: var(--block-radius);}
|
89 |
.token {padding: .5rem; border-radius: 5px;}
|
90 |
+
.token {cursor: pointer;}
|
91 |
.tgt-token-wrapper {line-height: 2.5rem; padding: .5rem;}
|
92 |
.src-token-wrapper {line-height: 2.5rem; padding: .5rem;}
|
93 |
.src-token-wrapper-text {position: absolute; bottom: .75rem; color: #71717a;}
|
|
|
101 |
"""
|
102 |
|
103 |
js = """
|
104 |
+
function showCrossAttFun(attn_scores, decoder_attn, encoder_attn) {
|
105 |
|
106 |
const scrTokens = document.querySelectorAll('.src-token');
|
107 |
const srcLen = scrTokens.length - 1
|
108 |
const targetTokens = document.querySelectorAll('.tgt-token');
|
109 |
const scores = document.querySelectorAll('.score');
|
110 |
|
|
|
111 |
const decoderTokens = document.querySelectorAll('.decoder-token');
|
112 |
const decLen = decoderTokens.length - 1
|
113 |
const decoderScores = document.querySelectorAll('.decoder-score');
|
114 |
|
115 |
+
const encoderTokens = document.querySelectorAll('.encoder-token');
|
116 |
+
const encLen = encoderTokens.length - 1
|
117 |
+
const encoderScores = document.querySelectorAll('.encoder-score');
|
118 |
+
|
119 |
function onTgtHover(event, idx) {
|
120 |
event.style.backgroundColor = "#C6E6E6";
|
121 |
|
|
|
163 |
scores[2].style.display = "none";
|
164 |
}
|
165 |
|
166 |
+
function onDecodeHover(event, idx) {
|
|
|
|
|
167 |
idx0 = decoder_attn[idx]['top_index'][0]
|
168 |
if (idx0 < decLen) {
|
169 |
el0 = decoderTokens[idx0]
|
|
|
189 |
}
|
190 |
|
191 |
for (i=idx+1; i < decoderTokens.length; i++) {
|
192 |
+
decoderTokens[i].style.color = "#ccc9c9";
|
193 |
}
|
194 |
|
195 |
}
|
196 |
|
197 |
+
function outDecodeHover(event, idx) {
|
198 |
event.style.backgroundColor = "";
|
199 |
idx0 = decoder_attn[idx]['top_index'][0]
|
200 |
el0 = decoderTokens[idx0]
|
|
|
224 |
}
|
225 |
|
226 |
|
227 |
+
// Highlight the (up to three) input tokens that the hovered token attends to
// most, and show their attention values in the encoder score spans.
//   event: the hovered token element (the listeners pass the element, not a DOM event); unused here.
//   idx:   position of the hovered token; selects the entry of encoder_attn.
function onEncodeHover(event, idx) {
    // One colour per rank, strongest attention first.
    const rankColors = ["#89C6C6", "#C6E6E6", "#E5F5F5"];
    const topIndex = encoder_attn[idx]['top_index'];
    const topValues = encoder_attn[idx]['top_values'];

    rankColors.forEach((color, rank) => {
        // Declared locally: the previous code assigned idx0/el0/... without a
        // declaration, which created implicit globals.
        const tokenIdx = topIndex[rank];
        // encLen is encoderTokens.length - 1, so the last token is never
        // highlighted; a missing (undefined) entry also fails this test.
        if (tokenIdx < encLen) {
            encoderTokens[tokenIdx].style.backgroundColor = color;
            const scoreEl = encoderScores[rank];
            scoreEl.textContent = topValues[rank];
            scoreEl.style.display = "initial";
            scoreEl.style.backgroundColor = color;
        }
    });
}
|
256 |
+
|
257 |
+
// Undo the highlighting applied on hover: clear the hovered element's
// background, then reset every token and score span referenced by this
// token's top-attention entries.
//   event: the token element the pointer just left (an element, not a DOM event).
//   idx:   position of that token; selects the entry of encoder_attn.
function outEncodeHover(event, idx) {
    event.style.backgroundColor = "";

    const topIndex = encoder_attn[idx]['top_index'];
    for (let rank = 0; rank < 3; rank++) {
        // Declared locally to avoid the implicit globals the previous code
        // created (idx0, el0, ...).
        const tokenIdx = topIndex[rank];
        // Skip ranks with no entry. Rank 0 used to be dereferenced without
        // this guard and threw a TypeError when the entry was missing.
        if (tokenIdx === undefined || tokenIdx === null) {
            continue;
        }
        const tokenEl = encoderTokens[tokenIdx];
        if (tokenEl) {
            tokenEl.style.backgroundColor = "";
        }
        encoderScores[rank].textContent = "";
        encoderScores[rank].style.display = "none";
    }
}
|
281 |
+
|
282 |
+
|
283 |
targetTokens.forEach((el, idx) => {
|
284 |
el.addEventListener("mouseover", () => {
|
285 |
onTgtHover(el, idx)
|
|
|
294 |
|
295 |
decoderTokens.forEach((el, idx) => {
|
296 |
el.addEventListener("mouseover", () => {
|
297 |
+
onDecodeHover(el, idx)
|
298 |
})
|
299 |
});
|
300 |
|
301 |
decoderTokens.forEach((el, idx) => {
|
302 |
el.addEventListener("mouseout", () => {
|
303 |
+
outDecodeHover(el, idx)
|
304 |
+
})
|
305 |
+
});
|
306 |
+
|
307 |
+
encoderTokens.forEach((el, idx) => {
|
308 |
+
el.addEventListener("mouseover", () => {
|
309 |
+
onEncodeHover(el, idx)
|
310 |
+
})
|
311 |
+
});
|
312 |
+
|
313 |
+
encoderTokens.forEach((el, idx) => {
|
314 |
+
el.addEventListener("mouseout", () => {
|
315 |
+
outEncodeHover(el, idx)
|
316 |
})
|
317 |
});
|
318 |
}
|
|
|
345 |
|
346 |
cross_attn = gr.JSON(value=[], visible=False)
|
347 |
decoder_attn = gr.JSON(value=[], visible=False)
|
348 |
+
encoder_attn = gr.JSON(value=[], visible=False)
|
349 |
|
350 |
gr.Markdown(
|
351 |
"""
|
|
|
358 |
with gr.Row(elem_classes="output-html-row"):
|
359 |
output_html = gr.HTML(label="Cross Attention", elem_classes="output-html")
|
360 |
|
361 |
+
gr.Markdown(
|
362 |
+
"""
|
363 |
+
## Check Self Attentions for Encoder
|
364 |
+
Hover your mouse over an input (English) word/token to see which word/token it is self-attending to.
|
365 |
+
""",
|
366 |
+
elem_classes="output-html-desc"
|
367 |
+
)
|
368 |
+
|
369 |
+
# Row holding the encoder self-attention panel. The label previously read
# "Decoder Attention)" (copy-pasted from the decoder row, with a stray ")").
with gr.Row(elem_classes="output-html-row"):
    encoder_output_html = gr.HTML(label="Encoder Attention", elem_classes="output-html")
|
371 |
+
|
372 |
+
|
373 |
gr.Markdown(
|
374 |
"""
|
375 |
## Check Self Attentions for Decoder
|
|
|
382 |
with gr.Row(elem_classes="output-html-row"):
|
383 |
decoder_output_html = gr.HTML(label="Decoder Attention)", elem_classes="output-html")
|
384 |
|
385 |
+
translate_button.click(fn=translate_text, inputs=input_box, outputs=[output_box, output_html, cross_attn, decoder_output_html, decoder_attn, encoder_output_html, encoder_attn])
|
386 |
|
387 |
+
output_box.change(None, [cross_attn, decoder_attn, encoder_attn], None, js=js)
|
388 |
|
389 |
gr.Markdown("**Note:** I'm using a transformer model of encoder-decoder architecture (`Helsinki-NLP/opus-mt-en-zh`) in order to obtain cross attention from the decoder layers. ",
|
390 |
elem_classes="note-text")
|
utils.py
CHANGED
@@ -40,12 +40,8 @@ def get_top_attns(avg_attn_list):
|
|
40 |
|
41 |
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
# attn_tensor = decoder_attentions[token_index][layer_index] # shape: [1, 8, 1, 24]
|
49 |
-
# avg_attn_list.append(attn_tensor.squeeze(0).squeeze(1).mean(0)) # shape: [24], mean across heads
|
50 |
-
#
|
51 |
-
# return avg_attn_list
|
|
|
40 |
|
41 |
|
42 |
|
43 |
+
def get_encoder_attn_list(encoder_attentions, layer_index):
    """Return one encoder layer's attention for the first batch item, averaged over dim 0.

    Args:
        encoder_attentions: Indexable collection of per-layer attention
            tensors.
        layer_index: Index of the layer to read from ``encoder_attentions``.

    Returns:
        The result of ``encoder_attentions[layer_index][0].mean(dim=0)``.
        Despite the function name this is a single tensor, not a Python list.
    """
    # NOTE(review): presumably each layer tensor is (batch, heads, seq, seq),
    # so [0] picks the only sequence in the batch and mean(dim=0) averages
    # across heads -- confirm against the model's output format.
    first_item_attn = encoder_attentions[layer_index][0]
    return first_item_attn.mean(dim=0)
|
|
|
|
|
|
|
|