# sae.py
from transformers import PreTrainedModel
from typing import Optional, Dict, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from copy import deepcopy
from safetensors.torch import save_file, load_file
from sae.modeling.config import SAEConfig
import os
class BaseSAE(PreTrainedModel):
"""Base class for autoencoder models."""
config_class = SAEConfig
base_model_prefix = "sae"
def __init__(self, config: SAEConfig):
super().__init__(config)
print(config)
self.config = config
torch.manual_seed(42)
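        # b_dec centres the inputs before encoding and is added back after decoding;
        # W_dec starts as the transpose of W_enc and its rows are normalised to unit
        # norm below. b_enc is only used by the Vanilla and JumpReLU variants.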
self.b_dec = nn.Parameter(torch.zeros(self.config.act_size))
self.b_enc = nn.Parameter(torch.zeros(self.config.dict_size))
self.W_enc = nn.Parameter(
torch.nn.init.kaiming_uniform_(
torch.empty(self.config.act_size, self.config.dict_size)
)
)
self.W_dec = nn.Parameter(
torch.nn.init.kaiming_uniform_(
torch.empty(self.config.dict_size, self.config.act_size)
)
)
self.W_dec.data[:] = self.W_enc.t().data
self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
self.num_batches_not_active = torch.zeros((self.config.dict_size,))
self.to(self.config.get_torch_dtype(self.config.dtype))
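    # Optionally normalise each input to zero mean and unit std; the per-example
    # statistics are returned so postprocess_output can undo the normalisation.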
def preprocess_input(self, x):
x = x.to(self.config.get_torch_dtype(self.config.sae_dtype))
if self.config.input_unit_norm:
x_mean = x.mean(dim=-1, keepdim=True)
x = x - x_mean
x_std = x.std(dim=-1, keepdim=True)
x = x / (x_std + 1e-5)
return x, x_mean, x_std
else:
return x, None, None
def postprocess_output(self, x_reconstruct, x_mean, x_std):
if self.config.input_unit_norm:
x_reconstruct = x_reconstruct * x_std + x_mean
return x_reconstruct
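    # Project out the gradient component parallel to each decoder row and renormalise
    # the rows, keeping decoder directions at unit norm throughout training.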
@torch.no_grad()
def make_decoder_weights_and_grad_unit_norm(self):
W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(
-1, keepdim=True
) * W_dec_normed
self.W_dec.grad -= W_dec_grad_proj
self.W_dec.data = W_dec_normed
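    # Count, per dictionary feature, how many consecutive batches it has been silent;
    # features exceeding n_batches_to_dead are treated as dead by the auxiliary loss.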
    def update_inactive_features(self, acts):
        # Keep the counter on the activations' device so the in-place update works on GPU.
        if self.num_batches_not_active.device != acts.device:
            self.num_batches_not_active = self.num_batches_not_active.to(acts.device)
        self.num_batches_not_active += (acts.sum(0) == 0).float()
        self.num_batches_not_active[acts.sum(0) > 0] = 0
# @classmethod
# def from_pretrained(
# cls,
# pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
# *model_args,
# **kwargs
# ) -> "BaseSAE":
# config = kwargs.pop("config", None)
# if config is None:
# config = SAEConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
# model = cls(config)
# model.load_state_dict(
# load_file(os.path.join(pretrained_model_name_or_path, "model.safetensors"))
# )
# return model
# def save_pretrained(
# self,
# save_directory: Union[str, os.PathLike],
# **kwargs
# ):
# os.makedirs(save_directory, exist_ok=True)
# # Save the config
# self.config.save_pretrained(save_directory)
# # Save the model weights
# save_file(
# self.state_dict(),
# os.path.join(save_directory, "model.safetensors")
# )
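# BatchTopKSAE keeps the top (top_k * batch_size) activations across the whole batch
# instead of top_k per example, so the sparsity budget is enforced on average and
# individual examples may use more or fewer features. The flattened top-k below
# assumes 2-D inputs of shape [batch, act_size].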
class BatchTopKSAE(BaseSAE):
def forward(self, x):
x, x_mean, x_std = self.preprocess_input(x)
x_cent = x - self.b_dec
acts = F.relu(x_cent @ self.W_enc)
acts_topk = torch.topk(acts.flatten(), self.config.top_k * x.shape[0], dim=-1)
acts_topk = (
torch.zeros_like(acts.flatten())
.scatter(-1, acts_topk.indices, acts_topk.values)
.reshape(acts.shape)
)
x_reconstruct = acts_topk @ self.W_dec + self.b_dec
self.update_inactive_features(acts_topk)
output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
return output
def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
l1_norm = acts_topk.float().abs().sum(-1).mean()
l1_loss = self.config.l1_coeff * l1_norm
l0_norm = (acts_topk > 0).float().sum(-1).mean()
aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
loss = l2_loss + aux_loss
num_dead_features = (
self.num_batches_not_active > self.config.n_batches_to_dead
).sum()
sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
output = {
"sae_out": sae_out,
"feature_acts": acts_topk,
"num_dead_features": num_dead_features,
"loss": loss,
"l1_loss": l1_loss,
"l2_loss": l2_loss,
"l0_norm": l0_norm,
"l1_norm": l1_norm,
"aux_loss": aux_loss,
"explained_variance": explained_variance,
"top_k": self.config.top_k
}
return output
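    # AuxK-style auxiliary loss: the top_k_aux currently-dead features are trained to
    # reconstruct the residual left by the live features, giving dead features a
    # gradient signal so they can be revived.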
def get_auxiliary_loss(self, x, x_reconstruct, acts):
dead_features = self.num_batches_not_active >= self.config.n_batches_to_dead
if dead_features.sum() > 0:
residual = x.float() - x_reconstruct.float()
acts_topk_aux = torch.topk(
acts[:, dead_features],
                min(self.config.top_k_aux, int(dead_features.sum())),  # topk expects an int k
dim=-1,
)
acts_aux = torch.zeros_like(acts[:, dead_features]).scatter(
-1, acts_topk_aux.indices, acts_topk_aux.values
)
x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
l2_loss_aux = (
self.config.aux_penalty
* (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
)
return l2_loss_aux
else:
return torch.tensor(0, dtype=x.dtype, device=x.device)
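# TopKSAE keeps exactly top_k activations per example (row-wise top-k). Unlike
# BatchTopKSAE, its total loss also includes the L1 term.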
class TopKSAE(BaseSAE):
def forward(self, x):
x, x_mean, x_std = self.preprocess_input(x)
x_cent = x - self.b_dec
acts = F.relu(x_cent @ self.W_enc)
acts_topk = torch.topk(acts, self.config.top_k, dim=-1)
acts_topk = torch.zeros_like(acts).scatter(
-1, acts_topk.indices, acts_topk.values
)
x_reconstruct = acts_topk @ self.W_dec + self.b_dec
self.update_inactive_features(acts_topk)
output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
return output
def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
l1_norm = acts_topk.float().abs().sum(-1).mean()
l1_loss = self.config.l1_coeff * l1_norm
l0_norm = (acts_topk > 0).float().sum(-1).mean()
aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
loss = l2_loss + l1_loss + aux_loss
num_dead_features = (
self.num_batches_not_active > self.config.n_batches_to_dead
).sum()
sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
output = {
"sae_out": sae_out,
"feature_acts": acts_topk,
"num_dead_features": num_dead_features,
"loss": loss,
"l1_loss": l1_loss,
"l2_loss": l2_loss,
"l0_norm": l0_norm,
"l1_norm": l1_norm,
"explained_variance": explained_variance,
"aux_loss": aux_loss,
}
return output
def get_auxiliary_loss(self, x, x_reconstruct, acts):
dead_features = self.num_batches_not_active >= self.config.n_batches_to_dead
if dead_features.sum() > 0:
residual = x.float() - x_reconstruct.float()
acts_topk_aux = torch.topk(
acts[:, dead_features],
                min(self.config.top_k_aux, int(dead_features.sum())),  # topk expects an int k
dim=-1,
)
acts_aux = torch.zeros_like(acts[:, dead_features]).scatter(
-1, acts_topk_aux.indices, acts_topk_aux.values
)
x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
l2_loss_aux = (
self.config.aux_penalty
* (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
)
return l2_loss_aux
else:
return torch.tensor(0, dtype=x.dtype, device=x.device)
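# VanillaSAE: plain ReLU autoencoder trained with reconstruction plus L1 sparsity,
# with no top-k masking and no auxiliary loss.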
class VanillaSAE(BaseSAE):
def forward(self, x):
x, x_mean, x_std = self.preprocess_input(x)
x_cent = x - self.b_dec
acts = F.relu(x_cent @ self.W_enc + self.b_enc)
x_reconstruct = acts @ self.W_dec + self.b_dec
self.update_inactive_features(acts)
output = self.get_loss_dict(x, x_reconstruct, acts, x_mean, x_std)
return output
def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
l1_norm = acts.float().abs().sum(-1).mean()
l1_loss = self.config.l1_coeff * l1_norm
l0_norm = (acts > 0).float().sum(-1).mean()
loss = l2_loss + l1_loss
num_dead_features = (
self.num_batches_not_active > self.config.n_batches_to_dead
).sum()
sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
output = {
"sae_out": sae_out,
"feature_acts": acts,
"num_dead_features": num_dead_features,
"loss": loss,
"l1_loss": l1_loss,
"l2_loss": l2_loss,
"l0_norm": l0_norm,
"l1_norm": l1_norm,
"explained_variance": explained_variance,
}
return output
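# Straight-through machinery for the JumpReLU SAE (in the style of Rajamanoharan et
# al., 2024): RectangleFunction is the rectangular window used as a pseudo-derivative,
# JumpReLUFunction zeroes activations below a learned per-feature threshold
# (parametrised as log_threshold), and StepFunction gives the L0 count a surrogate
# gradient with respect to that threshold.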
class RectangleFunction(autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return ((x > -0.5) & (x < 0.5)).float()
@staticmethod
def backward(ctx, grad_output):
(x,) = ctx.saved_tensors
grad_input = grad_output.clone()
grad_input[(x <= -0.5) | (x >= 0.5)] = 0
return grad_input
class JumpReLUFunction(autograd.Function):
@staticmethod
def forward(ctx, x, log_threshold, bandwidth):
ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
threshold = torch.exp(log_threshold)
return x * (x > threshold).float()
@staticmethod
def backward(ctx, grad_output):
x, log_threshold, bandwidth_tensor = ctx.saved_tensors
bandwidth = bandwidth_tensor.item()
threshold = torch.exp(log_threshold)
x_grad = (x > threshold).float() * grad_output
threshold_grad = (
-(threshold / bandwidth)
* RectangleFunction.apply((x - threshold) / bandwidth)
* grad_output
)
return x_grad, threshold_grad, None # None for bandwidth
class JumpReLU(nn.Module):
def __init__(self, feature_size, bandwidth, device='cpu'):
super(JumpReLU, self).__init__()
self.log_threshold = nn.Parameter(torch.zeros(feature_size, device=device))
self.bandwidth = bandwidth
def forward(self, x):
return JumpReLUFunction.apply(x, self.log_threshold, self.bandwidth)
class StepFunction(autograd.Function):
@staticmethod
def forward(ctx, x, log_threshold, bandwidth):
ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
threshold = torch.exp(log_threshold)
return (x > threshold).float()
@staticmethod
def backward(ctx, grad_output):
x, log_threshold, bandwidth_tensor = ctx.saved_tensors
bandwidth = bandwidth_tensor.item()
threshold = torch.exp(log_threshold)
x_grad = torch.zeros_like(x)
threshold_grad = (
-(1.0 / bandwidth)
* RectangleFunction.apply((x - threshold) / bandwidth)
* grad_output
)
return x_grad, threshold_grad, None # None for bandwidth
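# JumpReLU SAE: encoder outputs are gated by a learned per-feature threshold, and
# sparsity is penalised through the straight-through L0 estimate from StepFunction
# instead of an L1 norm.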
class JumpReLUSAE(BaseSAE):
def __init__(self, config: SAEConfig):
super().__init__(config)
self.jumprelu = JumpReLU(
feature_size=config.dict_size,
bandwidth=config.bandwidth,
device=config.device if hasattr(config, 'device') else 'cpu'
)
def forward(self, x, use_pre_enc_bias=False):
x, x_mean, x_std = self.preprocess_input(x)
if use_pre_enc_bias:
x = x - self.b_dec
pre_activations = torch.relu(x @ self.W_enc + self.b_enc)
feature_magnitudes = self.jumprelu(pre_activations)
x_reconstructed = feature_magnitudes @ self.W_dec + self.b_dec
return self.get_loss_dict(x, x_reconstructed, feature_magnitudes, x_mean, x_std)
def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
l0 = StepFunction.apply(acts, self.jumprelu.log_threshold, self.config.bandwidth).sum(dim=-1).mean()
l0_loss = self.config.l1_coeff * l0
l1_loss = l0_loss
loss = l2_loss + l1_loss
num_dead_features = (
self.num_batches_not_active > self.config.n_batches_to_dead
).sum()
sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
output = {
"sae_out": sae_out,
"feature_acts": acts,
"num_dead_features": num_dead_features,
"loss": loss,
"l1_loss": l1_loss,
"l2_loss": l2_loss,
"l0_norm": l0,
"l1_norm": l0,
"explained_variance": explained_variance,
}
return output
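
# --- Minimal usage sketch (illustrative only) ---
# The exact SAEConfig constructor lives in sae/modeling/config.py; the keyword
# arguments below are assumptions based on the fields accessed above and may need
# adjusting to the real signature.
if __name__ == "__main__":
    config = SAEConfig(
        act_size=768,          # width of the activations being autoencoded
        dict_size=768 * 16,    # number of dictionary features
        top_k=32,              # (average) active features per example
        top_k_aux=256,         # dead features used by the auxiliary loss
        aux_penalty=1 / 32,    # weight on the auxiliary reconstruction loss
        l1_coeff=0.0,          # L1 weight (reported, but not in BatchTopK's total loss)
        n_batches_to_dead=20,  # batches of silence before a feature counts as dead
        input_unit_norm=True,
        dtype="float32",
        sae_dtype="float32",
        bandwidth=0.001,       # only needed by the JumpReLU variant
    )
    sae = BatchTopKSAE(config)
    out = sae(torch.randn(64, config.act_size))
    print(out["loss"].item(), out["l0_norm"].item())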