import os
from copy import deepcopy
from typing import Dict, Optional, Union

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file, save_file
from transformers import PreTrainedModel

from sae.modeling.config import SAEConfig


class BaseSAE(PreTrainedModel):
    """Base class for sparse autoencoder (SAE) models."""

    config_class = SAEConfig
    base_model_prefix = "sae"

    def __init__(self, config: SAEConfig):
        super().__init__(config)
        print(config)
        self.config = config
        torch.manual_seed(42)

        # Encoder/decoder biases and weights; the decoder is initialized as the
        # transpose of the encoder and then normalized to unit row norm.
        self.b_dec = nn.Parameter(torch.zeros(self.config.act_size))
        self.b_enc = nn.Parameter(torch.zeros(self.config.dict_size))
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(self.config.act_size, self.config.dict_size)
            )
        )
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(self.config.dict_size, self.config.act_size)
            )
        )
        self.W_dec.data[:] = self.W_enc.t().data
        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        # Per-feature counter of consecutive batches in which the feature was inactive.
        self.num_batches_not_active = torch.zeros((self.config.dict_size,))

        self.to(self.config.get_torch_dtype(self.config.dtype))

    def preprocess_input(self, x):
        """Cast to the SAE dtype and optionally normalize each input to zero mean, unit std."""
        x = x.to(self.config.get_torch_dtype(self.config.sae_dtype))
        if self.config.input_unit_norm:
            x_mean = x.mean(dim=-1, keepdim=True)
            x = x - x_mean
            x_std = x.std(dim=-1, keepdim=True)
            x = x / (x_std + 1e-5)
            return x, x_mean, x_std
        else:
            return x, None, None

    def postprocess_output(self, x_reconstruct, x_mean, x_std):
        """Undo the normalization applied in `preprocess_input`."""
        if self.config.input_unit_norm:
            x_reconstruct = x_reconstruct * x_std + x_mean
        return x_reconstruct

    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        """Remove the gradient component parallel to each decoder row and renormalize the rows."""
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(
            -1, keepdim=True
        ) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj
        self.W_dec.data = W_dec_normed

    def update_inactive_features(self, acts):
        """Increment the inactivity counter for features that did not fire in this batch."""
        self.num_batches_not_active += (acts.sum(0) == 0).float()
        self.num_batches_not_active[acts.sum(0) > 0] = 0


class BatchTopKSAE(BaseSAE):
    """SAE that keeps the top (k * batch_size) activations across the whole batch."""

    def forward(self, x):
        x, x_mean, x_std = self.preprocess_input(x)

        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc)
        # Select the top (k * batch_size) activations over the flattened batch, so that
        # on average top_k features remain active per example.
        acts_topk = torch.topk(acts.flatten(), self.config.top_k * x.shape[0], dim=-1)
        acts_topk = (
            torch.zeros_like(acts.flatten())
            .scatter(-1, acts_topk.indices, acts_topk.values)
            .reshape(acts.shape)
        )
        x_reconstruct = acts_topk @ self.W_dec + self.b_dec

        self.update_inactive_features(acts_topk)
        output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
        return output

    def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
        l1_norm = acts_topk.float().abs().sum(-1).mean()
        l1_loss = self.config.l1_coeff * l1_norm
        l0_norm = (acts_topk > 0).float().sum(-1).mean()
        aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
        # Sparsity is enforced by the top-k selection, so the training loss omits the L1 term.
        loss = l2_loss + aux_loss
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()
        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        output = {
            "sae_out": sae_out,
            "feature_acts": acts_topk,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0_norm,
            "l1_norm": l1_norm,
            "aux_loss": aux_loss,
            "explained_variance": explained_variance,
            "top_k": self.config.top_k,
        }
        return output

    def get_auxiliary_loss(self, x, x_reconstruct, acts):
        # Auxiliary loss: features that have been dead for too long are asked to
        # reconstruct the residual error of the main reconstruction.
        dead_features = self.num_batches_not_active >= self.config.n_batches_to_dead
        if dead_features.sum() > 0:
            residual = x.float() - x_reconstruct.float()
            acts_topk_aux = torch.topk(
                acts[:, dead_features],
                min(self.config.top_k_aux, dead_features.sum()),
                dim=-1,
            )
            acts_aux = torch.zeros_like(acts[:, dead_features]).scatter(
                -1, acts_topk_aux.indices, acts_topk_aux.values
            )
            x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
            l2_loss_aux = (
                self.config.aux_penalty
                * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
            )
            return l2_loss_aux
        else:
            return torch.tensor(0, dtype=x.dtype, device=x.device)


class TopKSAE(BaseSAE):
    """SAE that keeps the top-k activations independently for each example."""

    def forward(self, x):
        x, x_mean, x_std = self.preprocess_input(x)

        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc)
        acts_topk = torch.topk(acts, self.config.top_k, dim=-1)
        acts_topk = torch.zeros_like(acts).scatter(
            -1, acts_topk.indices, acts_topk.values
        )
        x_reconstruct = acts_topk @ self.W_dec + self.b_dec

        self.update_inactive_features(acts_topk)
        output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
        return output

    def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
        l1_norm = acts_topk.float().abs().sum(-1).mean()
        l1_loss = self.config.l1_coeff * l1_norm
        l0_norm = (acts_topk > 0).float().sum(-1).mean()
        aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
        loss = l2_loss + l1_loss + aux_loss
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()
        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        output = {
            "sae_out": sae_out,
            "feature_acts": acts_topk,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0_norm,
            "l1_norm": l1_norm,
            "explained_variance": explained_variance,
            "aux_loss": aux_loss,
        }
        return output

    def get_auxiliary_loss(self, x, x_reconstruct, acts):
        # Same auxiliary loss as BatchTopKSAE: dead features reconstruct the residual.
        dead_features = self.num_batches_not_active >= self.config.n_batches_to_dead
        if dead_features.sum() > 0:
            residual = x.float() - x_reconstruct.float()
            acts_topk_aux = torch.topk(
                acts[:, dead_features],
                min(self.config.top_k_aux, dead_features.sum()),
                dim=-1,
            )
            acts_aux = torch.zeros_like(acts[:, dead_features]).scatter(
                -1, acts_topk_aux.indices, acts_topk_aux.values
            )
            x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
            l2_loss_aux = (
                self.config.aux_penalty
                * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
            )
            return l2_loss_aux
        else:
            return torch.tensor(0, dtype=x.dtype, device=x.device)


class VanillaSAE(BaseSAE):
    """Standard ReLU SAE trained with an L1 sparsity penalty."""

    def forward(self, x):
        x, x_mean, x_std = self.preprocess_input(x)
        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        self.update_inactive_features(acts)
        output = self.get_loss_dict(x, x_reconstruct, acts, x_mean, x_std)
        return output

    def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
        l1_norm = acts.float().abs().sum(-1).mean()
        l1_loss = self.config.l1_coeff * l1_norm
        l0_norm = (acts > 0).float().sum(-1).mean()
        loss = l2_loss + l1_loss
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()

        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        output = {
            "sae_out": sae_out,
            "feature_acts": acts,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0_norm,
            "l1_norm": l1_norm,
            "explained_variance": explained_variance,
        }
        return output


class RectangleFunction(autograd.Function):
    """Indicator of the open interval (-0.5, 0.5), used as the support of the pseudo-gradients below."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return ((x > -0.5) & (x < 0.5)).float()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[(x <= -0.5) | (x >= 0.5)] = 0
        return grad_input


class JumpReLUFunction(autograd.Function):
    """JumpReLU: x * 1[x > threshold], with a straight-through pseudo-gradient for the learnable threshold."""

    @staticmethod
    def forward(ctx, x, log_threshold, bandwidth):
        ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
        threshold = torch.exp(log_threshold)
        return x * (x > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        x, log_threshold, bandwidth_tensor = ctx.saved_tensors
        bandwidth = bandwidth_tensor.item()
        threshold = torch.exp(log_threshold)
        x_grad = (x > threshold).float() * grad_output
        threshold_grad = (
            -(threshold / bandwidth)
            * RectangleFunction.apply((x - threshold) / bandwidth)
            * grad_output
        )
        return x_grad, threshold_grad, None


class JumpReLU(nn.Module):
    """JumpReLU activation with a learnable per-feature threshold, stored in log space."""

    def __init__(self, feature_size, bandwidth, device="cpu"):
        super(JumpReLU, self).__init__()
        self.log_threshold = nn.Parameter(torch.zeros(feature_size, device=device))
        self.bandwidth = bandwidth

    def forward(self, x):
        return JumpReLUFunction.apply(x, self.log_threshold, self.bandwidth)


class StepFunction(autograd.Function):
    """Step function 1[x > threshold]; only the threshold receives a pseudo-gradient (used for the L0 penalty)."""

    @staticmethod
    def forward(ctx, x, log_threshold, bandwidth):
        ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
        threshold = torch.exp(log_threshold)
        return (x > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        x, log_threshold, bandwidth_tensor = ctx.saved_tensors
        bandwidth = bandwidth_tensor.item()
        threshold = torch.exp(log_threshold)
        x_grad = torch.zeros_like(x)
        threshold_grad = (
            -(1.0 / bandwidth)
            * RectangleFunction.apply((x - threshold) / bandwidth)
            * grad_output
        )
        return x_grad, threshold_grad, None


class JumpReLUSAE(BaseSAE):
    """SAE with a JumpReLU activation and an L0 sparsity penalty estimated via StepFunction."""

    def __init__(self, config: SAEConfig):
        super().__init__(config)
        self.jumprelu = JumpReLU(
            feature_size=config.dict_size,
            bandwidth=config.bandwidth,
            device=config.device if hasattr(config, "device") else "cpu",
        )

    def forward(self, x, use_pre_enc_bias=False):
        x, x_mean, x_std = self.preprocess_input(x)
        if use_pre_enc_bias:
            x = x - self.b_dec
        pre_activations = torch.relu(x @ self.W_enc + self.b_enc)
        feature_magnitudes = self.jumprelu(pre_activations)

        x_reconstructed = feature_magnitudes @ self.W_dec + self.b_dec

        return self.get_loss_dict(x, x_reconstructed, feature_magnitudes, x_mean, x_std)

    def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()

        # Differentiable L0 estimate: StepFunction counts active features while passing
        # a pseudo-gradient back to the JumpReLU thresholds.
        l0 = StepFunction.apply(
            acts, self.jumprelu.log_threshold, self.config.bandwidth
        ).sum(dim=-1).mean()
        l0_loss = self.config.l1_coeff * l0
        l1_loss = l0_loss

        loss = l2_loss + l1_loss
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()

        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        output = {
            "sae_out": sae_out,
            "feature_acts": acts,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0,
            "l1_norm": l0,
            "explained_variance": explained_variance,
        }
        return output
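

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the library). It assumes that
# SAEConfig accepts the fields referenced in this module (act_size, dict_size,
# dtype, sae_dtype, input_unit_norm, top_k, top_k_aux, l1_coeff, aux_penalty,
# n_batches_to_dead, bandwidth) as keyword arguments; the actual signature lives
# in sae.modeling.config and may differ. All values below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = SAEConfig(
        act_size=768,
        dict_size=16384,
        dtype="float32",
        sae_dtype="float32",
        input_unit_norm=True,
        top_k=32,
        top_k_aux=256,
        l1_coeff=0.0,
        aux_penalty=1 / 32,
        n_batches_to_dead=20,
        bandwidth=0.001,
    )
    sae = BatchTopKSAE(config)
    optimizer = torch.optim.Adam(sae.parameters(), lr=3e-4)

    # One training step on random activations standing in for model activations.
    batch = torch.randn(64, config.act_size)
    out = sae(batch)
    out["loss"].backward()
    # Keep decoder rows unit-norm and drop the parallel gradient component
    # before the optimizer step, as done by make_decoder_weights_and_grad_unit_norm.
    sae.make_decoder_weights_and_grad_unit_norm()
    optimizer.step()
    optimizer.zero_grad()
    print(out["loss"].item(), out["l0_norm"].item(), int(out["num_dead_features"]))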