maomao88's picture
add encoder self-attention
93559de
raw
history blame
1.4 kB
import heapq
import pickle
def save_data(outputs, src_tokens, tgt_tokens, attn_scores, path="data.pkl"):
    """Bundle generation artifacts into a dict and pickle it to ``path``.

    Args:
        outputs: Stored under the ``"outputs"`` key.
        src_tokens: Stored under the ``"src_tokens"`` key.
        tgt_tokens: Stored under the ``"tgt_tokens"`` key.
        attn_scores: Stored under the ``"attn_scores"`` key.
        path: Destination file (str or path-like). Defaults to
            ``"data.pkl"``, resolved relative to the current working
            directory, which was the previously hard-coded location. An
            existing file at ``path`` is overwritten.
    """
    data = {
        'outputs': outputs,
        'src_tokens': src_tokens,
        'tgt_tokens': tgt_tokens,
        'attn_scores': attn_scores,
    }
    # NOTE: pickle files can execute arbitrary code when loaded, so whoever
    # reads this file back must trust whoever wrote it.
    with open(path, "wb") as f:
        pickle.dump(data, f)
def get_attn_list(attentions, layer_index):
    """Return one head-averaged attention vector per step for a single layer.

    Args:
        attentions: Sequence indexed first by step, then by layer, so that
            ``attentions[step][layer_index]`` is a tensor. Assumed shape is
            ``[1, num_heads, 1, src_len]`` (batch of 1, a single query
            position) -- TODO confirm against the caller.
        layer_index: Layer whose attention is taken at every step.

    Returns:
        A list with one entry per step. Each entry is the layer's tensor with
        dims 0 and 1 squeezed, then averaged over dim 0. Under the assumed
        shape this gives ``[src_len]``, the mean across heads.
    """
    # squeeze(0) drops the batch dim and squeeze(1) drops the single-query
    # dim, leaving [num_heads, src_len]; mean(0) then averages over heads.
    return [
        step_attn[layer_index].squeeze(0).squeeze(1).mean(0)
        for step_attn in attentions
    ]
def get_top_attns(avg_attn_list, k=3):
    """Summarise each attention vector by its ``k`` largest entries.

    Args:
        avg_attn_list: Iterable of 1-D sequences whose elements expose
            ``.item()`` (e.g. 1-D tensors).
        k: How many of the largest entries to keep per vector. Defaults to
            3, the previously hard-coded value. A vector with fewer than
            ``k`` entries yields all of them.

    Returns:
        One dict per input vector, in input order, with:
            ``"top_values"``: the selected values rounded to 2 decimals,
                largest first;
            ``"top_index"``: their positions within the vector, in the same
                order as ``"top_values"``.
    """
    avg_attn_top = []
    for attn in avg_attn_list:
        # heapq.nlargest matches sorted(..., reverse=True)[:k], so on ties
        # the lower index comes first.
        top = heapq.nlargest(k, enumerate(attn), key=lambda pair: pair[1])
        avg_attn_top.append({
            "top_values": [round(val.item(), 2) for _, val in top],
            "top_index": [idx for idx, _ in top],
        })
    return avg_attn_top
def get_encoder_attn_list(encoder_attentions, layer_index):
    """Return the head-averaged encoder self-attention for one layer.

    Args:
        encoder_attentions: Per-layer sequence of attention tensors; only
            ``encoder_attentions[layer_index]`` is read.
        layer_index: Layer to extract.

    Returns:
        ``encoder_attentions[layer_index][0]`` averaged over its first
        dimension. Presumably the layer tensor is
        ``[batch, heads, src_len, src_len]``, so this is batch item 0
        averaged across heads, i.e. ``[src_len, src_len]`` -- TODO confirm.
    """
    first_item = encoder_attentions[layer_index][0]
    return first_item.mean(dim=0)