import heapq
import pickle


def save_data(outputs, src_tokens, tgt_tokens, attn_scores, path="data.pkl"):
    """Pickle model outputs, tokens and attention scores into a single file.

    The four arguments are stored in one dict under the keys ``"outputs"``,
    ``"src_tokens"``, ``"tgt_tokens"`` and ``"attn_scores"``.

    Args:
        outputs: Model outputs to store.
        src_tokens: Source tokens to store.
        tgt_tokens: Target tokens to store.
        attn_scores: Attention scores to store.
        path: Destination file. Defaults to ``"data.pkl"`` (relative to the
            current working directory), which was previously hard-coded.
    """
    data = {
        "outputs": outputs,
        "src_tokens": src_tokens,
        "tgt_tokens": tgt_tokens,
        "attn_scores": attn_scores,
    }
    # NOTE: unpickling can execute arbitrary code; only load files written by
    # trusted code.
    with open(path, "wb") as f:
        pickle.dump(data, f)


def get_attn_list(attentions, layer_index):
    """Return one head-averaged attention vector per step for a given layer.

    Args:
        attentions: Sequence indexed first by step, then by layer.
            Each per-layer tensor is presumably shaped
            ``[1, heads, 1, src_len]`` (e.g. ``[1, 8, 1, 24]``). TODO confirm
            this against the model that produced it.
        layer_index: Index of the layer to read at every step.

    Returns:
        A list with one entry per step. Each entry is the selected layer's
        attention with dims 0 and (post-squeeze) 1 squeezed, then averaged
        over dim 0 (the heads). Under the assumed shape above that is a
        ``[src_len]`` vector.
    """
    return [
        step_attn[layer_index].squeeze(0).squeeze(1).mean(0)
        for step_attn in attentions
    ]


def get_top_attns(avg_attn_list, k=3):
    """Return the ``k`` largest attention weights and their positions per step.

    Args:
        avg_attn_list: Iterable of 1-D attention vectors, e.g. the output of
            ``get_attn_list``. Elements may be 0-d tensors, numpy scalars,
            or plain numbers.
        k: How many top entries to keep per vector (default 3).

    Returns:
        A list of dicts, one per input vector:
        ``{"top_values": [...], "top_index": [...]}``. Values are rounded to
        2 decimals and ordered from largest to smallest. Ties keep their
        original positional order.
    """
    avg_attn_top = []
    for attn in avg_attn_list:
        # (position, weight) pairs ranked by weight.
        top_k = heapq.nlargest(k, enumerate(attn), key=lambda pair: pair[1])
        avg_attn_top.append({
            # float() works for 0-d tensors, numpy scalars and plain numbers,
            # unlike .item(), which plain Python floats do not have.
            "top_values": [round(float(val), 2) for _, val in top_k],
            "top_index": [idx for idx, _ in top_k],
        })
    return avg_attn_top


def get_encoder_attn_list(encoder_attentions, layer_index):
    """Return one encoder layer's attention for the first batch item, averaged over dim 0.

    Args:
        encoder_attentions: Sequence indexed by layer. Each entry is
            presumably shaped ``[batch, heads, src_len, src_len]``. TODO
            confirm.
        layer_index: Index of the encoder layer to read.

    Returns:
        ``encoder_attentions[layer_index][0]`` averaged over its first
        dimension (the heads, under the assumed shape above).
    """
    attn_tensor = encoder_attentions[layer_index]
    # [0] selects the first batch element; mean(dim=0) averages over heads.
    avg_attn_list = attn_tensor[0].mean(dim=0)
    return avg_attn_list