maomao88 commited on
Commit
30396a8
·
1 Parent(s): 092931e

first commit

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +212 -0
  3. requirements.txt +6 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .idea/
app.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import heapq
import html
import pickle

import gradio as gr
import torch
from torch import nn

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
+
10
+ model_name = "Helsinki-NLP/opus-mt-en-zh"
11
+ # model_name = "Helsinki-NLP/opus-mt-zh-en"
12
+
13
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
14
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
15
+
16
+ layer_index = model.config.decoder_layers - 1 # last decoder layer index
17
+
18
+
19
+ def save_data(outputs, src_tokens, tgt_tokens, attn_scores):
20
+ data = {'outputs': outputs, 'src_tokens': src_tokens, 'tgt_tokens': tgt_tokens, 'attn_scores': attn_scores}
21
+
22
+ # Save to file
23
+ with open("data.pkl", "wb") as f:
24
+ pickle.dump(data, f)
25
+
26
+
27
+ def get_attn_list(cross_attentions):
28
+ avg_attn_list = []
29
+
30
+ for i in range(len(cross_attentions)):
31
+ token_index = i # pick a token index from the output (1 to 18)
32
+ attn_tensor = cross_attentions[token_index][layer_index] # shape: [1, 8, 1, 24]
33
+ avg_attn_list.append(attn_tensor.squeeze(0).squeeze(1).mean(0)) # shape: [24], mean across heads
34
+
35
+ return avg_attn_list
36
+
37
+ def get_top_attns(avg_attn_list):
38
+ avg_attn_top = []
39
+
40
+ for i in range(len(avg_attn_list)):
41
+ # Get top 3 (index, value) pairs
42
+ top_3 = heapq.nlargest(3, enumerate(avg_attn_list[i]), key=lambda x: x[1])
43
+
44
+ # get the indices and values of the source tokens
45
+ top_values = [val for idx, val in top_3]
46
+ top_index = [idx for idx, val in top_3]
47
+
48
+ avg_attn_top.append({
49
+ "top_values": top_values,
50
+ "top_index": top_index
51
+ })
52
+
53
+ return avg_attn_top
54
+
55
+
56
+ # Define translation function
57
+ def translate_text(input_text):
58
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True)
59
+ with torch.no_grad():
60
+ translated = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, output_attentions=True,
61
+ num_beams=1)
62
+
63
+ outputs = tokenizer.decode(translated.sequences[0][1:][:-1])
64
+
65
+ # Decode tokens
66
+ src_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
67
+ src_tokens = [token.lstrip('▁_') for token in src_tokens]
68
+
69
+ tgt_tokens = tokenizer.convert_ids_to_tokens(translated.sequences[0])[1:]
70
+ tgt_tokens = [token.lstrip('▁_') for token in tgt_tokens]
71
+
72
+ avg_attn_list = get_attn_list(translated.cross_attentions)
73
+ attn_scores = get_top_attns(avg_attn_list)
74
+
75
+ # save_data(outputs, src_tokens, tgt_tokens, attn_scores)
76
+ return outputs, render_attention_html(src_tokens, tgt_tokens), attn_scores
77
+
78
+
79
+ def render_attention_html(src_tokens, tgt_tokens):
80
+ # Build HTML for source and target tokens
81
+ src_html = ""
82
+ for i, token in enumerate(src_tokens):
83
+ src_html += f'<span class="token src-token" data-index="{i}">{token}</span> '
84
+
85
+ tgt_html = ""
86
+ for i, token in enumerate(tgt_tokens):
87
+ tgt_html += f'<span class="token tgt-token" data-index="{i}">{token}</span> '
88
+
89
+ html = f'<div class="tgt-token-wrapper-text">Output Tokens</div><div class="tgt-token-wrapper">{tgt_html}</div><hr class="token-wrapper-seperator"><div class="src-token-wrapper-text">Input Tokens</div><div class="src-token-wrapper">{src_html}</div>'
90
+ return html
91
+
92
+
93
+ css = """
94
+ .output-html-desc {padding-top: 1rem}
95
+ .output-html {padding-top: 1rem; padding-bottom: 1rem;}
96
+ .output-html-row {margin-bottom: .5rem; border: var(--block-border-width) solid var(--block-border-color); border-radius: var(--block-radius);}
97
+ .token {padding: .5rem; border-radius: 5px;}
98
+ .tgt-token {cursor: pointer;}
99
+ .tgt-token-wrapper {line-height: 2.5rem; padding: .5rem;}
100
+ .src-token-wrapper {line-height: 2.5rem; padding: .5rem;}
101
+ .src-token-wrapper-text {position: absolute; bottom: .75rem; color: #71717a;}
102
+ .tgt-token-wrapper-text {position: absolute; top: .75rem; color: #71717a;}
103
+ .token-wrapper-seperator {margin-top: 1rem; margin-bottom: 1rem}
104
+ .note-text {margin-bottom: 3.5rem;}
105
+ """
106
+
107
+ js = """
108
+ function showCrossAttFun(attn_scores) {
109
+
110
+ const scrTokens = document.querySelectorAll('.src-token');
111
+ const srcLen = scrTokens.length - 1
112
+
113
+ const targetTokens = document.querySelectorAll('.tgt-token');
114
+
115
+ function onTgtHover(event, idx) {
116
+ event.style.backgroundColor = "#C6E6E6";
117
+
118
+ srcIdx0 = attn_scores[idx]['top_index'][0]
119
+ if (srcIdx0 < srcLen) {
120
+ srcEl0 = scrTokens[srcIdx0]
121
+ srcEl0.style.backgroundColor = "#FF8865"
122
+ }
123
+
124
+ srcIdx1 = attn_scores[idx]['top_index'][1]
125
+ if (srcIdx1 < srcLen) {
126
+ srcEl1 = scrTokens[srcIdx1]
127
+ srcEl1.style.backgroundColor = "#FFD2C4"
128
+ }
129
+
130
+ srcIdx2 = attn_scores[idx]['top_index'][2]
131
+ if (srcIdx2 < srcLen) {
132
+ srcEl2 = scrTokens[srcIdx2]
133
+ srcEl2.style.backgroundColor = "#FFF3F0"
134
+ }
135
+ }
136
+
137
+ function outHover(event, idx) {
138
+ event.style.backgroundColor = "";
139
+ srcIdx0 = attn_scores[idx]['top_index'][0]
140
+ srcIdx1 = attn_scores[idx]['top_index'][1]
141
+ srcIdx2 = attn_scores[idx]['top_index'][2]
142
+ srcEl0 = scrTokens[srcIdx0]
143
+ srcEl0.style.backgroundColor = ""
144
+ srcEl1 = scrTokens[srcIdx1]
145
+ srcEl1.style.backgroundColor = ""
146
+ srcEl2 = scrTokens[srcIdx2]
147
+ srcEl2.style.backgroundColor = ""
148
+ }
149
+
150
+
151
+ targetTokens.forEach((el, idx) => {
152
+ el.addEventListener("mouseover", () => {
153
+ onTgtHover(el, idx)
154
+ })
155
+ });
156
+
157
+ targetTokens.forEach((el, idx) => {
158
+ el.addEventListener("mouseout", () => {
159
+ outHover(el, idx)
160
+ })
161
+ });
162
+ }
163
+ """
164
+
165
+
166
+ # Gradio Interface
167
+ with gr.Blocks(css=css) as demo:
168
+ gr.Markdown("""
169
+ ## 🕸️ Visualize Cross Attention between Translated Text (English to Chinese)
170
+ Cross attention is a key component in transformers, where a sequence (English Text) can attend to another sequence’s information (Chinese Text).
171
+ You can check the cross attention of the translated text in the lower section of the page.
172
+ """)
173
+
174
+ with gr.Row():
175
+ with gr.Column():
176
+ input_box = gr.Textbox(lines=4, label="Input Text (English)")
177
+ with gr.Column():
178
+ output_box = gr.Textbox(lines=4, label="Translated Text (Chinese)")
179
+
180
+ # Examples Section
181
+ gr.Examples(
182
+ examples=[
183
+ ["They heard the click of the front door and knew that the Dursleys had left the house."],
184
+ ["Azkaban was a fortress where the most dangerous dark wizards were held, guarded by creatures called Dementors."]
185
+ ],
186
+ inputs=[input_box]
187
+ )
188
+
189
+ translate_button = gr.Button("Translate", variant="primary")
190
+
191
+ attn = gr.JSON(value=[], visible=False)
192
+
193
+ gr.Markdown(
194
+ """
195
+ ## Check Cross Attentions
196
+ Hover your mouse over an output (Chinese) word/token to see which input (English) word/token it is attending to.
197
+ """,
198
+ elem_classes="output-html-desc"
199
+ )
200
+ with gr.Row(elem_classes="output-html-row"):
201
+ output_html = gr.HTML(label="Translated Text (HTML)", elem_classes="output-html")
202
+
203
+ translate_button.click(fn=translate_text, inputs=input_box, outputs=[output_box, output_html, attn])
204
+
205
+ output_box.change(None, attn, None, js=js)
206
+
207
+ gr.Markdown("**Note:** I'm using a transformer model of encoder-decoder architecture (`Helsinki-NLP/opus-mt-en-zh`) in order to obtain cross attention from the decoder layers. ",
208
+ elem_classes="note-text")
209
+
210
+
211
+
212
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ transformers
3
+ torch
4
+ torchvision
5
+ sacremoses
6
+ sentencepiece