maomao88 commited on
Commit
23ce706
·
1 Parent(s): 30396a8

update attn score display

Browse files
Files changed (1) hide show
  1. app.py +27 -2
app.py CHANGED
@@ -42,7 +42,7 @@ def get_top_attns(avg_attn_list):
42
  top_3 = heapq.nlargest(3, enumerate(avg_attn_list[i]), key=lambda x: x[1])
43
 
44
  # get the indices and values of the source tokens
45
- top_values = [val for idx, val in top_3]
46
  top_index = [idx for idx, val in top_3]
47
 
48
  avg_attn_top.append({
@@ -86,7 +86,14 @@ def render_attention_html(src_tokens, tgt_tokens):
86
  for i, token in enumerate(tgt_tokens):
87
  tgt_html += f'<span class="token tgt-token" data-index="{i}">{token}</span> '
88
 
89
- html = f'<div class="tgt-token-wrapper-text">Output Tokens</div><div class="tgt-token-wrapper">{tgt_html}</div><hr class="token-wrapper-seperator"><div class="src-token-wrapper-text">Input Tokens</div><div class="src-token-wrapper">{src_html}</div>'
 
 
 
 
 
 
 
90
  return html
91
 
92
 
@@ -102,6 +109,10 @@ css = """
102
  .tgt-token-wrapper-text {position: absolute; top: .75rem; color: #71717a;}
103
  .token-wrapper-seperator {margin-top: 1rem; margin-bottom: 1rem}
104
  .note-text {margin-bottom: 3.5rem;}
 
 
 
 
105
  """
106
 
107
  js = """
@@ -112,6 +123,8 @@ function showCrossAttFun(attn_scores) {
112
 
113
  const targetTokens = document.querySelectorAll('.tgt-token');
114
 
 
 
115
  function onTgtHover(event, idx) {
116
  event.style.backgroundColor = "#C6E6E6";
117
 
@@ -119,18 +132,24 @@ function showCrossAttFun(attn_scores) {
119
  if (srcIdx0 < srcLen) {
120
  srcEl0 = scrTokens[srcIdx0]
121
  srcEl0.style.backgroundColor = "#FF8865"
 
 
122
  }
123
 
124
  srcIdx1 = attn_scores[idx]['top_index'][1]
125
  if (srcIdx1 < srcLen) {
126
  srcEl1 = scrTokens[srcIdx1]
127
  srcEl1.style.backgroundColor = "#FFD2C4"
 
 
128
  }
129
 
130
  srcIdx2 = attn_scores[idx]['top_index'][2]
131
  if (srcIdx2 < srcLen) {
132
  srcEl2 = scrTokens[srcIdx2]
133
  srcEl2.style.backgroundColor = "#FFF3F0"
 
 
134
  }
135
  }
136
 
@@ -141,10 +160,16 @@ function showCrossAttFun(attn_scores) {
141
  srcIdx2 = attn_scores[idx]['top_index'][2]
142
  srcEl0 = scrTokens[srcIdx0]
143
  srcEl0.style.backgroundColor = ""
 
 
144
  srcEl1 = scrTokens[srcIdx1]
145
  srcEl1.style.backgroundColor = ""
 
 
146
  srcEl2 = scrTokens[srcIdx2]
147
  srcEl2.style.backgroundColor = ""
 
 
148
  }
149
 
150
 
 
42
  top_3 = heapq.nlargest(3, enumerate(avg_attn_list[i]), key=lambda x: x[1])
43
 
44
  # get the indices and values of the source tokens
45
+ top_values = [round(val.item(), 2) for idx, val in top_3]
46
  top_index = [idx for idx, val in top_3]
47
 
48
  avg_attn_top.append({
 
86
  for i, token in enumerate(tgt_tokens):
87
  tgt_html += f'<span class="token tgt-token" data-index="{i}">{token}</span> '
88
 
89
+ html = f"""
90
+ <div class="tgt-token-wrapper-text">Output Tokens</div>
91
+ <div class="tgt-token-wrapper">{tgt_html}</div>
92
+ <hr class="token-wrapper-seperator">
93
+ <div class="src-token-wrapper-text">Input Tokens</div>
94
+ <div class="src-token-wrapper">{src_html}</div>
95
+ <div class="scores"><span class="score-1 score"></span><span class="score-2 score"></span><span class="score-3 score"></span><div>
96
+ """
97
  return html
98
 
99
 
 
109
  .tgt-token-wrapper-text {position: absolute; top: .75rem; color: #71717a;}
110
  .token-wrapper-seperator {margin-top: 1rem; margin-bottom: 1rem}
111
  .note-text {margin-bottom: 3.5rem;}
112
+ .scores { position: absolute; bottom: 0.75rem; color: rgb(113, 113, 122); right: 1rem;}
113
+ .score-1 { display: none; background-color: #FF8865; padding: .5rem; border-radius: var(--block-radius); margin-right: .75rem;}
114
+ .score-2 { display: none; background-color: #FFD2C4; padding: .5rem; border-radius: var(--block-radius); margin-right: .75rem;}
115
+ .score-3 { display: none; background-color: #FFF3F0; padding: .5rem; border-radius: var(--block-radius); margin-right: .75rem;}
116
  """
117
 
118
  js = """
 
123
 
124
  const targetTokens = document.querySelectorAll('.tgt-token');
125
 
126
+ const scores = document.querySelectorAll('.score');
127
+
128
  function onTgtHover(event, idx) {
129
  event.style.backgroundColor = "#C6E6E6";
130
 
 
132
  if (srcIdx0 < srcLen) {
133
  srcEl0 = scrTokens[srcIdx0]
134
  srcEl0.style.backgroundColor = "#FF8865"
135
+ scores[0].textContent = attn_scores[idx]['top_values'][0]
136
+ scores[0].style.display = "initial";
137
  }
138
 
139
  srcIdx1 = attn_scores[idx]['top_index'][1]
140
  if (srcIdx1 < srcLen) {
141
  srcEl1 = scrTokens[srcIdx1]
142
  srcEl1.style.backgroundColor = "#FFD2C4"
143
+ scores[1].textContent = attn_scores[idx]['top_values'][1]
144
+ scores[1].style.display = "initial";
145
  }
146
 
147
  srcIdx2 = attn_scores[idx]['top_index'][2]
148
  if (srcIdx2 < srcLen) {
149
  srcEl2 = scrTokens[srcIdx2]
150
  srcEl2.style.backgroundColor = "#FFF3F0"
151
+ scores[2].textContent = attn_scores[idx]['top_values'][2]
152
+ scores[2].style.display = "initial";
153
  }
154
  }
155
 
 
160
  srcIdx2 = attn_scores[idx]['top_index'][2]
161
  srcEl0 = scrTokens[srcIdx0]
162
  srcEl0.style.backgroundColor = ""
163
+ scores[0].textContent = ""
164
+ scores[0].style.display = "none";
165
  srcEl1 = scrTokens[srcIdx1]
166
  srcEl1.style.backgroundColor = ""
167
+ scores[1].textContent = ""
168
+ scores[1].style.display = "none";
169
  srcEl2 = scrTokens[srcIdx2]
170
  srcEl2.style.backgroundColor = ""
171
+ scores[2].textContent = ""
172
+ scores[2].style.display = "none";
173
  }
174
 
175