|
import heapq |
|
import pickle |
|
|
|
|
|
def save_data(outputs, src_tokens, tgt_tokens, attn_scores, path="data.pkl"):
    """Pickle the given values together into a single file.

    The four values are stored in one dict under the keys ``'outputs'``,
    ``'src_tokens'``, ``'tgt_tokens'`` and ``'attn_scores'``.

    Args:
        outputs: Value stored under ``'outputs'``.
        src_tokens: Value stored under ``'src_tokens'``.
        tgt_tokens: Value stored under ``'tgt_tokens'``.
        attn_scores: Value stored under ``'attn_scores'``.
        path: Destination file. Defaults to ``"data.pkl"`` in the current
            working directory; an existing file is overwritten.
    """
    data = {
        'outputs': outputs,
        'src_tokens': src_tokens,
        'tgt_tokens': tgt_tokens,
        'attn_scores': attn_scores,
    }

    # Pickle is a binary format, so the file must be opened in "wb" mode.
    with open(path, "wb") as f:
        pickle.dump(data, f)
|
|
|
|
|
def get_attn_list(attentions, layer_index):
    """Return one averaged attention vector per entry of ``attentions``.

    For every entry, the tensor at ``layer_index`` is taken, its axis 0 is
    squeezed, then axis 1 of the result is squeezed, and finally the mean
    over the leading axis is computed.

    Args:
        attentions: Sequence whose items are indexable by layer and yield
            tensors. Assumes each tensor is shaped like
            (1, heads, 1, source_len), as in step-wise decoder attention
            output -- TODO confirm against the caller.
        layer_index: Index of the layer to read from every entry.

    Returns:
        A list with one reduced tensor per entry of ``attentions``
        (empty when ``attentions`` is empty).
    """
    return [
        # squeeze() only removes size-1 axes, so other shapes pass through
        # unchanged before the mean over the leading axis is taken.
        step_attn[layer_index].squeeze(0).squeeze(1).mean(0)
        for step_attn in attentions
    ]
|
|
|
def get_top_attns(avg_attn_list, top_k=3):
    """Pick the largest attention scores, and their positions, per vector.

    Args:
        avg_attn_list: Iterable of 1-D score vectors. Each element of a
            vector must support comparison and ``.item()`` (for example
            the tensors produced by ``get_attn_list``).
        top_k: How many of the largest scores to keep per vector.
            Defaults to 3; vectors shorter than ``top_k`` contribute all
            of their elements.

    Returns:
        A list with one dict per input vector, holding:
            ``"top_values"``: the selected scores as Python numbers
                rounded to 2 decimals, largest first.
            ``"top_index"``: the positions of those scores in the vector,
                in the same order.
    """
    avg_attn_top = []

    for scores in avg_attn_list:
        # Pair every score with its position so the index survives the
        # selection; ranking is done on the score only.
        top_pairs = heapq.nlargest(
            top_k, enumerate(scores), key=lambda pair: pair[1]
        )

        avg_attn_top.append({
            "top_values": [round(val.item(), 2) for _, val in top_pairs],
            "top_index": [idx for idx, _ in top_pairs],
        })

    return avg_attn_top
|
|
|
|
|
|
|
def get_encoder_attn_list(encoder_attentions, layer_index, batch_index=0):
    """Average one layer's encoder attention over its leading axis.

    The tensor at ``layer_index`` is selected, the item at ``batch_index``
    is taken from it, and the mean over dimension 0 of that item is
    returned.

    Args:
        encoder_attentions: Sequence of per-layer attention tensors.
            Assumes each is shaped like (batch, heads, seq_len, seq_len)
            so the mean runs over heads -- TODO confirm against the caller.
        layer_index: Index of the layer to read.
        batch_index: Index along the first axis of the layer tensor.
            Defaults to 0, i.e. the first item.

    Returns:
        A single tensor (not a list, despite the function name) with the
        first two axes of the layer tensor reduced away.
    """
    layer_attn = encoder_attentions[layer_index]
    return layer_attn[batch_index].mean(dim=0)