maomao88 committed on
Commit
93559de
·
1 Parent(s): 77a9363

add encoder self-attention

Browse files
Files changed (3) hide show
  1. __pycache__/utils.cpython-313.pyc +0 -0
  2. app.py +105 -16
  3. utils.py +5 -9
__pycache__/utils.cpython-313.pyc CHANGED
Binary files a/__pycache__/utils.cpython-313.pyc and b/__pycache__/utils.cpython-313.pyc differ
 
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
  from torch import nn
3
  import gradio as gr
4
- from utils import save_data, get_attn_list, get_top_attns
5
 
6
 
7
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
@@ -37,8 +37,11 @@ def translate_text(input_text):
37
  avg_decoder_attn_list = get_attn_list(translated.decoder_attentions, layer_index)
38
  decoder_attn_scores = get_top_attns(avg_decoder_attn_list)
39
 
 
 
 
40
  # save_data(outputs, src_tokens, tgt_tokens, attn_scores)
41
- return outputs, render_cross_attn_html(src_tokens, tgt_tokens), cross_attn_scores, render_encoder_decoder_attn_html(tgt_tokens, "Output"), decoder_attn_scores
42
 
43
 
44
  def render_cross_attn_html(src_tokens, tgt_tokens):
@@ -64,13 +67,17 @@ def render_cross_attn_html(src_tokens, tgt_tokens):
64
  def render_encoder_decoder_attn_html(tokens, type):
65
  # Build HTML for source and target tokens
66
  tokens_html = ""
 
 
 
 
67
  for i, token in enumerate(tokens):
68
- tokens_html += f'<span class="token decoder-token" data-index="{i}">{token}</span> '
69
 
70
  html = f"""
71
  <div class="tgt-token-wrapper-text">{type} Tokens</div>
72
  <div class="tgt-token-wrapper">{tokens_html}</div>
73
- <div class="scores"><span class="score-1 decoder-score"></span><span class="score-2 decoder-score"></span><span class="score-3 decoder-score"></span><div>
74
  """
75
  return html
76
 
@@ -80,7 +87,7 @@ css = """
80
  .output-html {padding-top: 1rem; padding-bottom: 1rem;}
81
  .output-html-row {margin-bottom: .5rem; border: var(--block-border-width) solid var(--block-border-color); border-radius: var(--block-radius);}
82
  .token {padding: .5rem; border-radius: 5px;}
83
- .tgt-token {cursor: pointer;}
84
  .tgt-token-wrapper {line-height: 2.5rem; padding: .5rem;}
85
  .src-token-wrapper {line-height: 2.5rem; padding: .5rem;}
86
  .src-token-wrapper-text {position: absolute; bottom: .75rem; color: #71717a;}
@@ -94,18 +101,21 @@ css = """
94
  """
95
 
96
  js = """
97
- function showCrossAttFun(attn_scores, decoder_attn) {
98
 
99
  const scrTokens = document.querySelectorAll('.src-token');
100
  const srcLen = scrTokens.length - 1
101
  const targetTokens = document.querySelectorAll('.tgt-token');
102
  const scores = document.querySelectorAll('.score');
103
 
104
-
105
  const decoderTokens = document.querySelectorAll('.decoder-token');
106
  const decLen = decoderTokens.length - 1
107
  const decoderScores = document.querySelectorAll('.decoder-score');
108
 
 
 
 
 
109
  function onTgtHover(event, idx) {
110
  event.style.backgroundColor = "#C6E6E6";
111
 
@@ -153,9 +163,7 @@ function showCrossAttFun(attn_scores, decoder_attn) {
153
  scores[2].style.display = "none";
154
  }
155
 
156
- function onSelfHover(event, idx) {
157
- event.style.backgroundColor = "#C6E6E6";
158
-
159
  idx0 = decoder_attn[idx]['top_index'][0]
160
  if (idx0 < decLen) {
161
  el0 = decoderTokens[idx0]
@@ -181,12 +189,12 @@ function showCrossAttFun(attn_scores, decoder_attn) {
181
  }
182
 
183
  for (i=idx+1; i < decoderTokens.length; i++) {
184
- decoderTokens[i].style.color = "#aaa8a8";
185
  }
186
 
187
  }
188
 
189
- function outSelfHover(event, idx) {
190
  event.style.backgroundColor = "";
191
  idx0 = decoder_attn[idx]['top_index'][0]
192
  el0 = decoderTokens[idx0]
@@ -216,6 +224,62 @@ function showCrossAttFun(attn_scores, decoder_attn) {
216
  }
217
 
218
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  targetTokens.forEach((el, idx) => {
220
  el.addEventListener("mouseover", () => {
221
  onTgtHover(el, idx)
@@ -230,13 +294,25 @@ function showCrossAttFun(attn_scores, decoder_attn) {
230
 
231
  decoderTokens.forEach((el, idx) => {
232
  el.addEventListener("mouseover", () => {
233
- onSelfHover(el, idx)
234
  })
235
  });
236
 
237
  decoderTokens.forEach((el, idx) => {
238
  el.addEventListener("mouseout", () => {
239
- outSelfHover(el, idx)
 
 
 
 
 
 
 
 
 
 
 
 
240
  })
241
  });
242
  }
@@ -269,6 +345,7 @@ with gr.Blocks(css=css) as demo:
269
 
270
  cross_attn = gr.JSON(value=[], visible=False)
271
  decoder_attn = gr.JSON(value=[], visible=False)
 
272
 
273
  gr.Markdown(
274
  """
@@ -281,6 +358,18 @@ with gr.Blocks(css=css) as demo:
281
  with gr.Row(elem_classes="output-html-row"):
282
  output_html = gr.HTML(label="Cross Attention", elem_classes="output-html")
283
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  gr.Markdown(
285
  """
286
  ## Check Self Attentions for Decoder
@@ -293,9 +382,9 @@ with gr.Blocks(css=css) as demo:
293
  with gr.Row(elem_classes="output-html-row"):
294
  decoder_output_html = gr.HTML(label="Decoder Attention)", elem_classes="output-html")
295
 
296
- translate_button.click(fn=translate_text, inputs=input_box, outputs=[output_box, output_html, cross_attn, decoder_output_html, decoder_attn])
297
 
298
- output_box.change(None, [cross_attn, decoder_attn], None, js=js)
299
 
300
  gr.Markdown("**Note:** I'm using a transformer model of encoder-decoder architecture (`Helsinki-NLP/opus-mt-en-zh`) in order to obtain cross attention from the decoder layers. ",
301
  elem_classes="note-text")
 
1
  import torch
2
  from torch import nn
3
  import gradio as gr
4
+ from utils import save_data, get_attn_list, get_top_attns, get_encoder_attn_list
5
 
6
 
7
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
37
  avg_decoder_attn_list = get_attn_list(translated.decoder_attentions, layer_index)
38
  decoder_attn_scores = get_top_attns(avg_decoder_attn_list)
39
 
40
+ avg_encoder_attn_list = get_encoder_attn_list(translated.encoder_attentions, layer_index)
41
+ encoder_attn_scores = get_top_attns(avg_encoder_attn_list)
42
+
43
  # save_data(outputs, src_tokens, tgt_tokens, attn_scores)
44
+ return outputs, render_cross_attn_html(src_tokens, tgt_tokens), cross_attn_scores, render_encoder_decoder_attn_html(tgt_tokens, "Output"), decoder_attn_scores, render_encoder_decoder_attn_html(src_tokens, "Input"), encoder_attn_scores
45
 
46
 
47
  def render_cross_attn_html(src_tokens, tgt_tokens):
 
67
  def render_encoder_decoder_attn_html(tokens, type):
68
  # Build HTML for source and target tokens
69
  tokens_html = ""
70
+ className = "decoder"
71
+ if type == "Input":
72
+ className = "encoder"
73
+
74
  for i, token in enumerate(tokens):
75
+ tokens_html += f'<span class="token {className}-token" data-index="{i}">{token}</span> '
76
 
77
  html = f"""
78
  <div class="tgt-token-wrapper-text">{type} Tokens</div>
79
  <div class="tgt-token-wrapper">{tokens_html}</div>
80
+ <div class="scores"><span class="score-1 {className}-score"></span><span class="score-2 {className}-score"></span><span class="score-3 {className}-score"></span><div>
81
  """
82
  return html
83
 
 
87
  .output-html {padding-top: 1rem; padding-bottom: 1rem;}
88
  .output-html-row {margin-bottom: .5rem; border: var(--block-border-width) solid var(--block-border-color); border-radius: var(--block-radius);}
89
  .token {padding: .5rem; border-radius: 5px;}
90
+ .token {cursor: pointer;}
91
  .tgt-token-wrapper {line-height: 2.5rem; padding: .5rem;}
92
  .src-token-wrapper {line-height: 2.5rem; padding: .5rem;}
93
  .src-token-wrapper-text {position: absolute; bottom: .75rem; color: #71717a;}
 
101
  """
102
 
103
  js = """
104
+ function showCrossAttFun(attn_scores, decoder_attn, encoder_attn) {
105
 
106
  const scrTokens = document.querySelectorAll('.src-token');
107
  const srcLen = scrTokens.length - 1
108
  const targetTokens = document.querySelectorAll('.tgt-token');
109
  const scores = document.querySelectorAll('.score');
110
 
 
111
  const decoderTokens = document.querySelectorAll('.decoder-token');
112
  const decLen = decoderTokens.length - 1
113
  const decoderScores = document.querySelectorAll('.decoder-score');
114
 
115
+ const encoderTokens = document.querySelectorAll('.encoder-token');
116
+ const encLen = encoderTokens.length - 1
117
+ const encoderScores = document.querySelectorAll('.encoder-score');
118
+
119
  function onTgtHover(event, idx) {
120
  event.style.backgroundColor = "#C6E6E6";
121
 
 
163
  scores[2].style.display = "none";
164
  }
165
 
166
+ function onDecodeHover(event, idx) {
 
 
167
  idx0 = decoder_attn[idx]['top_index'][0]
168
  if (idx0 < decLen) {
169
  el0 = decoderTokens[idx0]
 
189
  }
190
 
191
  for (i=idx+1; i < decoderTokens.length; i++) {
192
+ decoderTokens[i].style.color = "#ccc9c9";
193
  }
194
 
195
  }
196
 
197
+ function outDecodeHover(event, idx) {
198
  event.style.backgroundColor = "";
199
  idx0 = decoder_attn[idx]['top_index'][0]
200
  el0 = decoderTokens[idx0]
 
224
  }
225
 
226
 
227
+ function onEncodeHover(event, idx) {
228
+ idx0 = encoder_attn[idx]['top_index'][0]
229
+ if (idx0 < encLen) {
230
+ el0 = encoderTokens[idx0]
231
+ el0.style.backgroundColor = "#89C6C6"
232
+ encoderScores[0].textContent = encoder_attn[idx]['top_values'][0]
233
+ encoderScores[0].style.display = "initial"
234
+ encoderScores[0].style.backgroundColor = "#89C6C6"
235
+ }
236
+
237
+ idx1 = encoder_attn[idx]['top_index'][1]
238
+ if (idx1 < encLen) {
239
+ el1 = encoderTokens[idx1]
240
+ el1.style.backgroundColor = "#C6E6E6"
241
+ encoderScores[1].textContent = encoder_attn[idx]['top_values'][1]
242
+ encoderScores[1].style.display = "initial"
243
+ encoderScores[1].style.backgroundColor = "#C6E6E6"
244
+ }
245
+
246
+ idx2 = encoder_attn[idx]['top_index'][2]
247
+ if (idx2 < encLen) {
248
+ el2 = encoderTokens[idx2]
249
+ el2.style.backgroundColor = "#E5F5F5"
250
+ encoderScores[2].textContent = encoder_attn[idx]['top_values'][2]
251
+ encoderScores[2].style.display = "initial"
252
+ encoderScores[2].style.backgroundColor = "#E5F5F5"
253
+ }
254
+
255
+ }
256
+
257
+ function outEncodeHover(event, idx) {
258
+ event.style.backgroundColor = "";
259
+ idx0 = encoder_attn[idx]['top_index'][0]
260
+ el0 = encoderTokens[idx0]
261
+ el0.style.backgroundColor = ""
262
+ encoderScores[0].textContent = ""
263
+ encoderScores[0].style.display = "none";
264
+
265
+ idx1 = encoder_attn[idx]['top_index'][1]
266
+ if (idx1 || idx1 == 0) {
267
+ el1 = encoderTokens[idx1]
268
+ el1.style.backgroundColor = ""
269
+ encoderScores[1].textContent = ""
270
+ encoderScores[1].style.display = "none";
271
+ }
272
+
273
+ idx2 = encoder_attn[idx]['top_index'][2]
274
+ if (idx2 || idx2 == 0) {
275
+ el2 = encoderTokens[idx2]
276
+ el2.style.backgroundColor = ""
277
+ encoderScores[2].textContent = ""
278
+ encoderScores[2].style.display = "none";
279
+ }
280
+ }
281
+
282
+
283
  targetTokens.forEach((el, idx) => {
284
  el.addEventListener("mouseover", () => {
285
  onTgtHover(el, idx)
 
294
 
295
  decoderTokens.forEach((el, idx) => {
296
  el.addEventListener("mouseover", () => {
297
+ onDecodeHover(el, idx)
298
  })
299
  });
300
 
301
  decoderTokens.forEach((el, idx) => {
302
  el.addEventListener("mouseout", () => {
303
+ outDecodeHover(el, idx)
304
+ })
305
+ });
306
+
307
+ encoderTokens.forEach((el, idx) => {
308
+ el.addEventListener("mouseover", () => {
309
+ onEncodeHover(el, idx)
310
+ })
311
+ });
312
+
313
+ encoderTokens.forEach((el, idx) => {
314
+ el.addEventListener("mouseout", () => {
315
+ outEncodeHover(el, idx)
316
  })
317
  });
318
  }
 
345
 
346
  cross_attn = gr.JSON(value=[], visible=False)
347
  decoder_attn = gr.JSON(value=[], visible=False)
348
+ encoder_attn = gr.JSON(value=[], visible=False)
349
 
350
  gr.Markdown(
351
  """
 
358
  with gr.Row(elem_classes="output-html-row"):
359
  output_html = gr.HTML(label="Cross Attention", elem_classes="output-html")
360
 
361
+ gr.Markdown(
362
+ """
363
+ ## Check Self Attentions for Encoder
364
+ Hover your mouse over an input (English) word/token to see which word/token it is self-attending to.
365
+ """,
366
+ elem_classes="output-html-desc"
367
+ )
368
+
369
+ with gr.Row(elem_classes="output-html-row"):
370
+ encoder_output_html = gr.HTML(label="Decoder Attention)", elem_classes="output-html")
371
+
372
+
373
  gr.Markdown(
374
  """
375
  ## Check Self Attentions for Decoder
 
382
  with gr.Row(elem_classes="output-html-row"):
383
  decoder_output_html = gr.HTML(label="Decoder Attention)", elem_classes="output-html")
384
 
385
+ translate_button.click(fn=translate_text, inputs=input_box, outputs=[output_box, output_html, cross_attn, decoder_output_html, decoder_attn, encoder_output_html, encoder_attn])
386
 
387
+ output_box.change(None, [cross_attn, decoder_attn, encoder_attn], None, js=js)
388
 
389
  gr.Markdown("**Note:** I'm using a transformer model of encoder-decoder architecture (`Helsinki-NLP/opus-mt-en-zh`) in order to obtain cross attention from the decoder layers. ",
390
  elem_classes="note-text")
utils.py CHANGED
@@ -40,12 +40,8 @@ def get_top_attns(avg_attn_list):
40
 
41
 
42
 
43
- # def get_encoder_attn_list(decoder_attentions, layer_index):
44
- # avg_attn_list = []
45
- #
46
- # for i in range(len(decoder_attentions)):
47
- # token_index = i # pick a token index from the output (1 to 18)
48
- # attn_tensor = decoder_attentions[token_index][layer_index] # shape: [1, 8, 1, 24]
49
- # avg_attn_list.append(attn_tensor.squeeze(0).squeeze(1).mean(0)) # shape: [24], mean across heads
50
- #
51
- # return avg_attn_list
 
40
 
41
 
42
 
43
+ def get_encoder_attn_list(encoder_attentions, layer_index):
44
+ attn_tensor = encoder_attentions[layer_index]
45
+ avg_attn_list = attn_tensor[0].mean(dim=0)
46
+
47
+ return avg_attn_list