update attn score display
Browse files
app.py
CHANGED
@@ -42,7 +42,7 @@ def get_top_attns(avg_attn_list):
|
|
42 |
top_3 = heapq.nlargest(3, enumerate(avg_attn_list[i]), key=lambda x: x[1])
|
43 |
|
44 |
# get the indices and values of the source tokens
|
45 |
-
top_values = [val for idx, val in top_3]
|
46 |
top_index = [idx for idx, val in top_3]
|
47 |
|
48 |
avg_attn_top.append({
|
@@ -86,7 +86,14 @@ def render_attention_html(src_tokens, tgt_tokens):
|
|
86 |
for i, token in enumerate(tgt_tokens):
|
87 |
tgt_html += f'<span class="token tgt-token" data-index="{i}">{token}</span> '
|
88 |
|
89 |
-
html = f
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
return html
|
91 |
|
92 |
|
@@ -102,6 +109,10 @@ css = """
|
|
102 |
.tgt-token-wrapper-text {position: absolute; top: .75rem; color: #71717a;}
|
103 |
.token-wrapper-seperator {margin-top: 1rem; margin-bottom: 1rem}
|
104 |
.note-text {margin-bottom: 3.5rem;}
|
|
|
|
|
|
|
|
|
105 |
"""
|
106 |
|
107 |
js = """
|
@@ -112,6 +123,8 @@ function showCrossAttFun(attn_scores) {
|
|
112 |
|
113 |
const targetTokens = document.querySelectorAll('.tgt-token');
|
114 |
|
|
|
|
|
115 |
function onTgtHover(event, idx) {
|
116 |
event.style.backgroundColor = "#C6E6E6";
|
117 |
|
@@ -119,18 +132,24 @@ function showCrossAttFun(attn_scores) {
|
|
119 |
if (srcIdx0 < srcLen) {
|
120 |
srcEl0 = scrTokens[srcIdx0]
|
121 |
srcEl0.style.backgroundColor = "#FF8865"
|
|
|
|
|
122 |
}
|
123 |
|
124 |
srcIdx1 = attn_scores[idx]['top_index'][1]
|
125 |
if (srcIdx1 < srcLen) {
|
126 |
srcEl1 = scrTokens[srcIdx1]
|
127 |
srcEl1.style.backgroundColor = "#FFD2C4"
|
|
|
|
|
128 |
}
|
129 |
|
130 |
srcIdx2 = attn_scores[idx]['top_index'][2]
|
131 |
if (srcIdx2 < srcLen) {
|
132 |
srcEl2 = scrTokens[srcIdx2]
|
133 |
srcEl2.style.backgroundColor = "#FFF3F0"
|
|
|
|
|
134 |
}
|
135 |
}
|
136 |
|
@@ -141,10 +160,16 @@ function showCrossAttFun(attn_scores) {
|
|
141 |
srcIdx2 = attn_scores[idx]['top_index'][2]
|
142 |
srcEl0 = scrTokens[srcIdx0]
|
143 |
srcEl0.style.backgroundColor = ""
|
|
|
|
|
144 |
srcEl1 = scrTokens[srcIdx1]
|
145 |
srcEl1.style.backgroundColor = ""
|
|
|
|
|
146 |
srcEl2 = scrTokens[srcIdx2]
|
147 |
srcEl2.style.backgroundColor = ""
|
|
|
|
|
148 |
}
|
149 |
|
150 |
|
|
|
42 |
top_3 = heapq.nlargest(3, enumerate(avg_attn_list[i]), key=lambda x: x[1])
|
43 |
|
44 |
# get the indices and values of the source tokens
|
45 |
+
top_values = [round(val.item(), 2) for idx, val in top_3]
|
46 |
top_index = [idx for idx, val in top_3]
|
47 |
|
48 |
avg_attn_top.append({
|
|
|
86 |
for i, token in enumerate(tgt_tokens):
|
87 |
tgt_html += f'<span class="token tgt-token" data-index="{i}">{token}</span> '
|
88 |
|
89 |
+
html = f"""
|
90 |
+
<div class="tgt-token-wrapper-text">Output Tokens</div>
|
91 |
+
<div class="tgt-token-wrapper">{tgt_html}</div>
|
92 |
+
<hr class="token-wrapper-seperator">
|
93 |
+
<div class="src-token-wrapper-text">Input Tokens</div>
|
94 |
+
<div class="src-token-wrapper">{src_html}</div>
|
95 |
+
<div class="scores"><span class="score-1 score"></span><span class="score-2 score"></span><span class="score-3 score"></span><div>
|
96 |
+
"""
|
97 |
return html
|
98 |
|
99 |
|
|
|
109 |
.tgt-token-wrapper-text {position: absolute; top: .75rem; color: #71717a;}
|
110 |
.token-wrapper-seperator {margin-top: 1rem; margin-bottom: 1rem}
|
111 |
.note-text {margin-bottom: 3.5rem;}
|
112 |
+
.scores { position: absolute; bottom: 0.75rem; color: rgb(113, 113, 122); right: 1rem;}
|
113 |
+
.score-1 { display: none; background-color: #FF8865; padding: .5rem; border-radius: var(--block-radius); margin-right: .75rem;}
|
114 |
+
.score-2 { display: none; background-color: #FFD2C4; padding: .5rem; border-radius: var(--block-radius); margin-right: .75rem;}
|
115 |
+
.score-3 { display: none; background-color: #FFF3F0; padding: .5rem; border-radius: var(--block-radius); margin-right: .75rem;}
|
116 |
"""
|
117 |
|
118 |
js = """
|
|
|
123 |
|
124 |
const targetTokens = document.querySelectorAll('.tgt-token');
|
125 |
|
126 |
+
const scores = document.querySelectorAll('.score');
|
127 |
+
|
128 |
function onTgtHover(event, idx) {
|
129 |
event.style.backgroundColor = "#C6E6E6";
|
130 |
|
|
|
132 |
if (srcIdx0 < srcLen) {
|
133 |
srcEl0 = scrTokens[srcIdx0]
|
134 |
srcEl0.style.backgroundColor = "#FF8865"
|
135 |
+
scores[0].textContent = attn_scores[idx]['top_values'][0]
|
136 |
+
scores[0].style.display = "initial";
|
137 |
}
|
138 |
|
139 |
srcIdx1 = attn_scores[idx]['top_index'][1]
|
140 |
if (srcIdx1 < srcLen) {
|
141 |
srcEl1 = scrTokens[srcIdx1]
|
142 |
srcEl1.style.backgroundColor = "#FFD2C4"
|
143 |
+
scores[1].textContent = attn_scores[idx]['top_values'][1]
|
144 |
+
scores[1].style.display = "initial";
|
145 |
}
|
146 |
|
147 |
srcIdx2 = attn_scores[idx]['top_index'][2]
|
148 |
if (srcIdx2 < srcLen) {
|
149 |
srcEl2 = scrTokens[srcIdx2]
|
150 |
srcEl2.style.backgroundColor = "#FFF3F0"
|
151 |
+
scores[2].textContent = attn_scores[idx]['top_values'][2]
|
152 |
+
scores[2].style.display = "initial";
|
153 |
}
|
154 |
}
|
155 |
|
|
|
160 |
srcIdx2 = attn_scores[idx]['top_index'][2]
|
161 |
srcEl0 = scrTokens[srcIdx0]
|
162 |
srcEl0.style.backgroundColor = ""
|
163 |
+
scores[0].textContent = ""
|
164 |
+
scores[0].style.display = "none";
|
165 |
srcEl1 = scrTokens[srcIdx1]
|
166 |
srcEl1.style.backgroundColor = ""
|
167 |
+
scores[1].textContent = ""
|
168 |
+
scores[1].style.display = "none";
|
169 |
srcEl2 = scrTokens[srcIdx2]
|
170 |
srcEl2.style.backgroundColor = ""
|
171 |
+
scores[2].textContent = ""
|
172 |
+
scores[2].style.display = "none";
|
173 |
}
|
174 |
|
175 |
|