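"""Hook-based logger that decomposes CLIP's image representation into the
projected residual-stream contributions (attention and MLP outputs) of a
chosen transformer layer."""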
import time
import numpy as np
import torch
from PIL import Image
import glob
import sys
import argparse
import datetime
import json
from pathlib import Path


class MaskHookLogger(object):
    def __init__(self, model, device):
        self.current_layer = 0
        self.device = device
        self.attentions = []
        self.mlps = []
        self.post_ln_std = None
        self.post_ln_mean = None
        self.model = model
        self.layer_index = None  # set by hook_prs_logger

    def compute_attentions(self, ret):
        # Only log the attention output of the target layer; the out-projection
        # bias is spread evenly over the positions so the summands stay exact.
        if self.current_layer == self.layer_index:
            bias_term = self.model.visual.transformer.resblocks[self.current_layer].attn.out_proj.bias
            return_value = ret[:, 0]
            return_value = return_value + bias_term[np.newaxis, np.newaxis] / return_value.shape[1]  # [b, n, d]
            self.attentions.append(return_value.detach())
        self.current_layer += 1
        return ret

    def compute_mlps(self, ret):
        # Log the MLP output (class token dropped) one step after the target
        # layer's attention hook has advanced the counter.
        if self.current_layer == self.layer_index + 1:
            self.mlps.append(ret[:, 1:].detach())  # [b, n, d]
        return ret

    def log_post_ln_mean(self, ret):
        self.post_ln_mean = ret.detach()  # [b, 1]
        return ret

    def log_post_ln_std(self, ret):
        self.post_ln_std = ret.detach()  # [b, 1]
        return ret

    def _normalize_mlps(self):
        # Number of summands the post-LN mean and bias are split across.
        len_intermediates = self.current_layer * 2 - 1
        # This is just the normalization layer, applied linearly per summand:
        mean_centered = (self.mlps -
                         self.post_ln_mean[:, :, np.newaxis, np.newaxis] / len_intermediates)
        weighted_mean_centered = self.model.visual.ln_post.weight.detach() * mean_centered
        weighted_mean_by_std = weighted_mean_centered / self.post_ln_std[:, :, np.newaxis, np.newaxis]
        bias_term = self.model.visual.ln_post.bias.detach() / len_intermediates
        post_ln = weighted_mean_by_std + bias_term
        return post_ln @ self.model.visual.proj.detach()

    def _normalize_attentions(self):
        len_intermediates = self.current_layer * 2 - 1  # attention + MLP summands
        normalization_term = self.attentions.shape[2] * 1  # n * h, h=1
        # This is just the normalization layer, applied linearly per summand:
        mean_centered = (self.attentions -
                         self.post_ln_mean[:, :, np.newaxis, np.newaxis] /
                         (len_intermediates * normalization_term))
        weighted_mean_centered = self.model.visual.ln_post.weight.detach() * mean_centered
        weighted_mean_by_std = weighted_mean_centered / self.post_ln_std[:, :, np.newaxis, np.newaxis]
        bias_term = self.model.visual.ln_post.bias.detach() / (len_intermediates * normalization_term)
        post_ln = weighted_mean_by_std + bias_term
        return post_ln @ self.model.visual.proj.detach()

    def finalize(self, representation):
        """Apply the post-LN scaling, project to the output space, and normalize by the representation norm."""
        self.attentions = torch.stack(self.attentions, dim=1)  # [b, 1, n, d]
        self.mlps = torch.stack(self.mlps, dim=1)  # [b, 1, n, d]
        projected_attentions = self._normalize_attentions()
        projected_mlps = self._normalize_mlps()
        norm = representation.norm(dim=-1).detach()
        return (projected_attentions / norm[:, np.newaxis, np.newaxis, np.newaxis],
                projected_mlps / norm[:, np.newaxis, np.newaxis, np.newaxis])

    def reinit(self):
        self.current_layer = 0
        self.attentions = []
        self.mlps = []
        self.post_ln_mean = None
        self.post_ln_std = None
        torch.cuda.empty_cache()
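

# The _normalize_* methods above exploit the fact that, with the mean and std
# of the full residual stream fixed, LayerNorm is affine in its input: the
# normalization of a sum of k components splits into per-component terms once
# the shared mean and the bias are each divided by k. A minimal standalone
# sanity check of that identity (illustrative only; not used by the logger):
def _demo_ln_linearity(k=5, d=8):
    comps = torch.randn(k, d)                        # k residual-stream summands
    x = comps.sum(0)                                 # the actual LayerNorm input
    weight, bias = torch.randn(d), torch.randn(d)    # affine LayerNorm parameters
    mu = x.mean()                                    # logged via visual.ln_post.mean
    sigma = x.std(unbiased=False)                    # logged via visual.ln_post.sqrt_var
    full = weight * (x - mu) / sigma + bias
    split = (weight * (comps - mu / k) / sigma + bias / k).sum(0)
    assert torch.allclose(full, split, atol=1e-5)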


def hook_prs_logger(model, device, layer_index=23):
    """Hooks a projected residual stream logger to the model."""
    prs = MaskHookLogger(model, device)
    prs.layer_index = layer_index
    model.hook_manager.register('visual.transformer.resblocks.*.attn.out.post',
                                prs.compute_attentions)
    model.hook_manager.register('visual.transformer.resblocks.*.post',
                                prs.compute_mlps)
    model.hook_manager.register('visual.ln_pre_post',
                                prs.compute_mlps)
    model.hook_manager.register('visual.ln_post.mean',
                                prs.log_post_ln_mean)
    model.hook_manager.register('visual.ln_post.sqrt_var',
                                prs.log_post_ln_std)
    return prs
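

# Usage sketch (assumptions flagged): this expects a CLIP-style model whose
# vision tower exposes the hook_manager events registered above, e.g. the
# modified OpenCLIP used by projected-residual-stream decomposition codebases.
# The loader, checkpoint name, and image path below are placeholders, not
# guaranteed APIs of this repository:
#
#   model, _, preprocess = create_model_and_transforms('ViT-L-14', pretrained='openai')
#   model.to(device).eval()
#   prs = hook_prs_logger(model, device, layer_index=23)
#
#   image = preprocess(Image.open('example.jpg')).unsqueeze(0).to(device)
#   with torch.no_grad():
#       prs.reinit()                                # clear state between batches
#       representation = model.encode_image(image)  # hooks fire during the forward pass
#       attns, mlps = prs.finalize(representation)  # per-position projected contributions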