import time
import numpy as np
import torch
from PIL import Image
import glob
import sys
import argparse
import datetime
import json
from pathlib import Path

from llava.hook import HookManager


def init_hookmanager(module):
    # Attach a fresh HookManager to the module so callbacks can be registered on it.
    module.hook_manager = HookManager()


class MaskHookLogger(object):
    """Collects attention statistics over image tokens via registered hooks."""

    def __init__(self, model, device):
        self.current_layer = 0
        self.device = device
        self.attns = []               # head-averaged attention over image tokens, one entry per step
        self.projected_attns = []     # per-image-token attention contributions in the residual stream
        self.image_embed_range = []   # (start, end) positions of image tokens in the sequence
        self.index = []               # greedy next-token ids logged from the LM head
        self.model = model

    def compute_attentions(self, ret):
        # ret: attention weights [b, heads, q, k]; keep the last query position's
        # attention over the image-token span, averaged across heads.
        assert len(self.image_embed_range) > 0
        st, ed = self.image_embed_range[-1]
        image_attention = ret[:, :, -1, st:ed].detach()
        image_attention = image_attention.mean(dim=1)
        self.attns.append(image_attention)  # [b, k]
        return ret

    def compute_projected_attentions(self, ret):
        # ret: per-key attention contributions [b, q, k, d]; keep the last query
        # position's contribution from each image token.
        assert len(self.image_embed_range) > 0
        st, ed = self.image_embed_range[-1]
        image_attention = ret[:, -1, st:ed].detach()  # [b, k, d]
        self.projected_attns.append(image_attention)  # [b, k, d]
        return ret

    def compute_attentions_withsoftmax(self, ret):
        # Same as compute_attentions, but softmax the raw scores over the image
        # tokens before averaging across heads.
        assert len(self.image_embed_range) > 0
        st, ed = self.image_embed_range[-1]
        image_attention = ret[:, :, -1, st:ed].detach()
        image_attention = image_attention.softmax(dim=-1)
        image_attention = image_attention.mean(dim=1)
        self.attns.append(image_attention)  # [b, k]
        return ret

    def compute_logits_index(self, ret):
        # ret: LM-head logits [b, seq, vocab]; record the greedy next-token id.
        next_token_logits = ret[:, -1, :]
        index = next_token_logits.argmax(dim=-1)
        self.index.append(index.item())
        return ret

    def finalize(self):
        # Concatenate the logged per-step attention maps along the batch dimension.
        attns = torch.cat(self.attns, dim=0).to(self.device)
        return attns

    def finalize_projected_attn(self, norm_weight, proj):
        # RMS-normalise the per-image-token attention contributions with the final
        # norm weight, project them through the LM head `proj`, and keep the logit
        # of the token that was actually decoded. Averages over the logged steps at
        # positions -4 and -3.
        assert len(self.index) == len(self.projected_attns)
        mask = []
        for i in range(-4, -2):
            index = self.index[i]
            attns = self.projected_attns[i].to(self.device)  # [1, k, d]
            input_dtype = attns.dtype
            # RMS statistic of the summed image-token contribution.
            attns_var = attns.to(torch.float32).sum(dim=1).pow(2).mean(-1, keepdim=True)  # [1, 1]
            attns_var = attns_var.unsqueeze(1)  # [1, 1, 1]
            normalized_attns = attns * torch.rsqrt(attns_var + 1e-6)  # [1, k, d]
            normalized_attns = norm_weight.to(normalized_attns.device) * normalized_attns.to(input_dtype)  # [1, k, d]
            logits = proj(normalized_attns)
            max_logits = logits[0, :, index]  # [k]
            mask.append(max_logits)
        mask = torch.stack(mask, dim=0)
        return mask.mean(dim=0)

    def reinit(self):
        # Clear all logged state between samples.
        self.attns = []
        self.projected_attns = []
        self.image_embed_range = []
        self.index = []
        torch.cuda.empty_cache()

    def log_image_embeds_range(self, ret):
        # ret[0][0] is expected to hold the (start, end) span of the image tokens.
        self.image_embed_range.append(ret[0][0])
        return ret


def hook_logger(model, device, layer_index=20):
    """Hooks a projected residual stream logger to the model."""
    init_hookmanager(model.model.layers[layer_index].self_attn)
    prs = MaskHookLogger(model, device)
    model.model.layers[layer_index].self_attn.hook_manager.register(
        'after_attn_mask', prs.compute_attentions_withsoftmax)
    model.hooklogger = prs
    return prs
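

if __name__ == "__main__":
    # Hypothetical, self-contained sanity check (not part of the original module):
    # it reruns the arithmetic from finalize_projected_attn on random tensors so
    # the normalisation/projection step can be inspected without loading a LLaVA
    # checkpoint. On a real model one would instead call hook_logger(model, device)
    # and drive a forward pass. The sizes below (k image tokens, hidden size d,
    # vocab) are illustrative placeholders, not values taken from this code.
    torch.manual_seed(0)
    k, d, vocab = 8, 64, 100
    attns = torch.randn(1, k, d)                   # fake per-image-token contributions [1, k, d]
    norm_weight = torch.ones(d)                    # stand-in for the model's final norm weight
    proj = torch.nn.Linear(d, vocab, bias=False)   # stand-in for the LM head
    var = attns.float().sum(dim=1).pow(2).mean(-1, keepdim=True).unsqueeze(1)  # [1, 1, 1]
    normalized = norm_weight * (attns * torch.rsqrt(var + 1e-6))               # [1, k, d]
    token_logits = proj(normalized)[0, :, 0]       # logit of token id 0 per image token, shape [k]
    print(token_logits.shape)                      # torch.Size([8])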