maomao88's picture
add decoder self-attentions
b0d1726
raw
history blame
1.66 kB
import heapq
import pickle
def save_data(outputs, src_tokens, tgt_tokens, attn_scores, path="data.pkl"):
    """Pickle the generation results and attention scores to a file.

    The four arguments are bundled into a dict with the keys ``'outputs'``,
    ``'src_tokens'``, ``'tgt_tokens'`` and ``'attn_scores'`` and written with
    ``pickle.dump``. Any existing file at *path* is overwritten.

    Args:
        outputs: Value stored under the ``'outputs'`` key.
        src_tokens: Value stored under the ``'src_tokens'`` key.
        tgt_tokens: Value stored under the ``'tgt_tokens'`` key.
        attn_scores: Value stored under the ``'attn_scores'`` key.
        path: Destination file. Defaults to ``"data.pkl"`` in the current
            working directory, which is where earlier versions always wrote.
    """
    data = {
        'outputs': outputs,
        'src_tokens': src_tokens,
        'tgt_tokens': tgt_tokens,
        'attn_scores': attn_scores,
    }
    # Binary mode is required by pickle; the context manager closes the file
    # even if serialization raises.
    with open(path, "wb") as f:
        pickle.dump(data, f)
def get_attn_list(attentions, layer_index):
    """Return the head-averaged attention of one layer for every entry.

    Args:
        attentions: Sequence with one entry per generated token; each entry
            is indexable by layer and yields an attention tensor.
            NOTE(review): the original comments give the tensor shape as
            [1, heads, 1, src_len] (e.g. [1, 8, 1, 24]) -- confirm against
            the caller.
        layer_index: Index of the layer whose attention is selected from
            each entry.

    Returns:
        A list with one tensor per entry of *attentions*. Under the shape
        assumption above each tensor is 1-D (one weight per attended
        position), averaged across heads.
    """
    avg_attn_list = []
    for step_attentions in attentions:
        attn_tensor = step_attentions[layer_index]
        # Drop dim 0, then dim 1 of the remainder, then average over the new
        # leading dim (heads, under the assumed layout).
        avg_attn_list.append(attn_tensor.squeeze(0).squeeze(1).mean(0))
    return avg_attn_list
def get_top_attns(avg_attn_list, top_k=3):
    """Pick the highest-scoring positions from each attention vector.

    Args:
        avg_attn_list: Iterable of 1-D score vectors whose elements expose
            ``.item()`` (for example the tensors produced by
            ``get_attn_list``).
        top_k: Number of highest-scoring positions to keep per vector.
            Defaults to 3, the value that used to be hard-coded. Vectors
            shorter than *top_k* simply yield fewer entries.

    Returns:
        A list with one dict per input vector, each holding:
            ``"top_values"``: the selected scores rounded to 2 decimals,
                in descending order;
            ``"top_index"``: the positions of those scores in the vector,
                in the same order.
    """
    avg_attn_top = []
    for scores in avg_attn_list:
        # (index, value) pairs ranked by value, largest first.
        best = heapq.nlargest(top_k, enumerate(scores), key=lambda pair: pair[1])
        avg_attn_top.append({
            "top_values": [round(value.item(), 2) for _, value in best],
            "top_index": [index for index, _ in best],
        })
    return avg_attn_top
# def get_encoder_attn_list(decoder_attentions, layer_index):
# avg_attn_list = []
#
# for i in range(len(decoder_attentions)):
# token_index = i # pick a token index from the output (1 to 18)
# attn_tensor = decoder_attentions[token_index][layer_index] # shape: [1, 8, 1, 24]
# avg_attn_list.append(attn_tensor.squeeze(0).squeeze(1).mean(0)) # shape: [24], mean across heads
#
# return avg_attn_list