from transformers import PreTrainedModel
from typing import Optional, Dict, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from copy import deepcopy
from safetensors.torch import save_file, load_file
from sae.modeling.config import SAEConfig
import os


class BaseSAE(PreTrainedModel):
    """Base class for sparse autoencoder (SAE) models."""

    config_class = SAEConfig
    base_model_prefix = "sae"

    def __init__(self, config: SAEConfig):
        super().__init__(config)
        print(config)
        self.config = config
        torch.manual_seed(42)

        self.b_dec = nn.Parameter(torch.zeros(self.config.act_size))
        self.b_enc = nn.Parameter(torch.zeros(self.config.dict_size))
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(self.config.act_size, self.config.dict_size)
            )
        )
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(self.config.dict_size, self.config.act_size)
            )
        )
        # Tie the decoder to the transposed encoder at init, then normalize its rows.
        self.W_dec.data[:] = self.W_enc.t().data
        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        # Per-feature count of consecutive batches without activation. Registered as a
        # non-persistent buffer so it follows the module across .to()/device moves
        # (a plain tensor attribute would stay on CPU and break GPU training).
        self.register_buffer(
            "num_batches_not_active",
            torch.zeros((self.config.dict_size,)),
            persistent=False,
        )
        self.to(self.config.get_torch_dtype(self.config.dtype))

    def preprocess_input(self, x):
        x = x.to(self.config.get_torch_dtype(self.config.sae_dtype))
        if self.config.input_unit_norm:
            x_mean = x.mean(dim=-1, keepdim=True)
            x = x - x_mean
            x_std = x.std(dim=-1, keepdim=True)
            x = x / (x_std + 1e-5)
            return x, x_mean, x_std
        else:
            return x, None, None

    def postprocess_output(self, x_reconstruct, x_mean, x_std):
        if self.config.input_unit_norm:
            x_reconstruct = x_reconstruct * x_std + x_mean
        return x_reconstruct

    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        # Remove the gradient component parallel to each decoder row, then renormalize
        # the rows to unit norm so feature magnitudes stay identifiable.
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(
            -1, keepdim=True
        ) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj
        self.W_dec.data = W_dec_normed

    def update_inactive_features(self, acts):
        self.num_batches_not_active += (acts.sum(0) == 0).float()
        self.num_batches_not_active[acts.sum(0) > 0] = 0

    # @classmethod
    # def from_pretrained(
    #     cls,
    #     pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
    #     *model_args,
    #     **kwargs
    # ) -> "BaseSAE":
    #     config = kwargs.pop("config", None)
    #     if config is None:
    #         config = SAEConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
    #     model = cls(config)
    #     model.load_state_dict(
    #         load_file(os.path.join(pretrained_model_name_or_path, "model.safetensors"))
    #     )
    #     return model

    # def save_pretrained(
    #     self,
    #     save_directory: Union[str, os.PathLike],
    #     **kwargs
    # ):
    #     os.makedirs(save_directory, exist_ok=True)
    #     # Save the config
    #     self.config.save_pretrained(save_directory)
    #     # Save the model weights
    #     save_file(
    #         self.state_dict(),
    #         os.path.join(save_directory, "model.safetensors")
    #     )

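# Note: `make_decoder_weights_and_grad_unit_norm` is meant to be called in the training
# loop between `loss.backward()` and `optimizer.step()`, since it both projects the
# decoder gradient onto the unit-norm constraint and renormalizes `W_dec`. A minimal
# sketch (the optimizer and batch names below are illustrative, not part of this module):
#
#     output = sae(activations_batch)
#     output["loss"].backward()
#     sae.make_decoder_weights_and_grad_unit_norm()
#     optimizer.step()
#     optimizer.zero_grad()
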
class BatchTopKSAE(BaseSAE):
    """SAE that keeps the top (top_k * batch_size) activations across the whole batch."""

    def forward(self, x):
        x, x_mean, x_std = self.preprocess_input(x)
        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc)
        # Select the top (top_k * batch_size) activations over the flattened batch, so
        # per-token sparsity can vary while the batch average stays at top_k.
        acts_topk = torch.topk(acts.flatten(), self.config.top_k * x.shape[0], dim=-1)
        acts_topk = (
            torch.zeros_like(acts.flatten())
            .scatter(-1, acts_topk.indices, acts_topk.values)
            .reshape(acts.shape)
        )
        x_reconstruct = acts_topk @ self.W_dec + self.b_dec
        self.update_inactive_features(acts_topk)
        output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
        return output

    def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
        l1_norm = acts_topk.float().abs().sum(-1).mean()
        l1_loss = self.config.l1_coeff * l1_norm
        l0_norm = (acts_topk > 0).float().sum(-1).mean()
        aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
        loss = l2_loss + aux_loss
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()
        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        output = {
            "sae_out": sae_out,
            "feature_acts": acts_topk,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0_norm,
            "l1_norm": l1_norm,
            "aux_loss": aux_loss,
            "explained_variance": explained_variance,
            "top_k": self.config.top_k,
        }
        return output

    def get_auxiliary_loss(self, x, x_reconstruct, acts):
        # Auxiliary loss that trains currently-dead features to reconstruct the residual,
        # which helps revive them.
        dead_features = self.num_batches_not_active >= self.config.n_batches_to_dead
        if dead_features.sum() > 0:
            residual = x.float() - x_reconstruct.float()
            acts_topk_aux = torch.topk(
                acts[:, dead_features],
                min(self.config.top_k_aux, dead_features.sum()),
                dim=-1,
            )
            acts_aux = torch.zeros_like(acts[:, dead_features]).scatter(
                -1, acts_topk_aux.indices, acts_topk_aux.values
            )
            x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
            l2_loss_aux = (
                self.config.aux_penalty
                * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
            )
            return l2_loss_aux
        else:
            return torch.tensor(0, dtype=x.dtype, device=x.device)


class TopKSAE(BaseSAE):
    """SAE that keeps the top_k activations independently for each token."""

    def forward(self, x):
        x, x_mean, x_std = self.preprocess_input(x)
        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc)
        acts_topk = torch.topk(acts, self.config.top_k, dim=-1)
        acts_topk = torch.zeros_like(acts).scatter(
            -1, acts_topk.indices, acts_topk.values
        )
        x_reconstruct = acts_topk @ self.W_dec + self.b_dec
        self.update_inactive_features(acts_topk)
        output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
        return output

    def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
        l1_norm = acts_topk.float().abs().sum(-1).mean()
        l1_loss = self.config.l1_coeff * l1_norm
        l0_norm = (acts_topk > 0).float().sum(-1).mean()
        aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
        loss = l2_loss + l1_loss + aux_loss
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()
        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        output = {
            "sae_out": sae_out,
            "feature_acts": acts_topk,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0_norm,
            "l1_norm": l1_norm,
            "explained_variance": explained_variance,
            "aux_loss": aux_loss,
        }
        return output

    def get_auxiliary_loss(self, x, x_reconstruct, acts):
        dead_features = self.num_batches_not_active >= self.config.n_batches_to_dead
        if dead_features.sum() > 0:
            residual = x.float() - x_reconstruct.float()
            acts_topk_aux = torch.topk(
                acts[:, dead_features],
                min(self.config.top_k_aux, dead_features.sum()),
                dim=-1,
            )
            acts_aux = torch.zeros_like(acts[:, dead_features]).scatter(
                -1, acts_topk_aux.indices, acts_topk_aux.values
            )
            x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
            l2_loss_aux = (
                self.config.aux_penalty
                * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
            )
            return l2_loss_aux
        else:
            return torch.tensor(0, dtype=x.dtype, device=x.device)

class VanillaSAE(BaseSAE):
    """Standard ReLU SAE with an L1 sparsity penalty."""

    def forward(self, x):
        x, x_mean, x_std = self.preprocess_input(x)
        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        self.update_inactive_features(acts)
        output = self.get_loss_dict(x, x_reconstruct, acts, x_mean, x_std)
        return output

    def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
        l1_norm = acts.float().abs().sum(-1).mean()
        l1_loss = self.config.l1_coeff * l1_norm
        l0_norm = (acts > 0).float().sum(-1).mean()
        loss = l2_loss + l1_loss
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()
        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        output = {
            "sae_out": sae_out,
            "feature_acts": acts,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0_norm,
            "l1_norm": l1_norm,
            "explained_variance": explained_variance,
        }
        return output


class RectangleFunction(autograd.Function):
    """Indicator of (-0.5, 0.5); the surrogate gradient window used by JumpReLU/Step."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return ((x > -0.5) & (x < 0.5)).float()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[(x <= -0.5) | (x >= 0.5)] = 0
        return grad_input


class JumpReLUFunction(autograd.Function):
    """JumpReLU x * 1[x > threshold] with a surrogate gradient for the threshold."""

    @staticmethod
    def forward(ctx, x, log_threshold, bandwidth):
        ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
        threshold = torch.exp(log_threshold)
        return x * (x > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        x, log_threshold, bandwidth_tensor = ctx.saved_tensors
        bandwidth = bandwidth_tensor.item()
        threshold = torch.exp(log_threshold)
        x_grad = (x > threshold).float() * grad_output
        threshold_grad = (
            -(threshold / bandwidth)
            * RectangleFunction.apply((x - threshold) / bandwidth)
            * grad_output
        )
        return x_grad, threshold_grad, None  # None for bandwidth


class JumpReLU(nn.Module):
    def __init__(self, feature_size, bandwidth, device='cpu'):
        super(JumpReLU, self).__init__()
        self.log_threshold = nn.Parameter(torch.zeros(feature_size, device=device))
        self.bandwidth = bandwidth

    def forward(self, x):
        return JumpReLUFunction.apply(x, self.log_threshold, self.bandwidth)


class StepFunction(autograd.Function):
    """Step 1[x > threshold], used for the L0 penalty, with a surrogate threshold gradient."""

    @staticmethod
    def forward(ctx, x, log_threshold, bandwidth):
        ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
        threshold = torch.exp(log_threshold)
        return (x > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        x, log_threshold, bandwidth_tensor = ctx.saved_tensors
        bandwidth = bandwidth_tensor.item()
        threshold = torch.exp(log_threshold)
        x_grad = torch.zeros_like(x)
        threshold_grad = (
            -(1.0 / bandwidth)
            * RectangleFunction.apply((x - threshold) / bandwidth)
            * grad_output
        )
        return x_grad, threshold_grad, None  # None for bandwidth

class JumpReLUSAE(BaseSAE):
    """SAE with a learned per-feature JumpReLU threshold and an L0 sparsity penalty."""

    def __init__(self, config: SAEConfig):
        super().__init__(config)
        self.jumprelu = JumpReLU(
            feature_size=config.dict_size,
            bandwidth=config.bandwidth,
            device=config.device if hasattr(config, 'device') else 'cpu'
        )

    def forward(self, x, use_pre_enc_bias=False):
        x, x_mean, x_std = self.preprocess_input(x)
        if use_pre_enc_bias:
            x = x - self.b_dec
        pre_activations = torch.relu(x @ self.W_enc + self.b_enc)
        feature_magnitudes = self.jumprelu(pre_activations)
        x_reconstructed = feature_magnitudes @ self.W_dec + self.b_dec
        return self.get_loss_dict(x, x_reconstructed, feature_magnitudes, x_mean, x_std)

    def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
        # L0 penalty computed through the StepFunction surrogate gradient.
        l0 = StepFunction.apply(
            acts, self.jumprelu.log_threshold, self.config.bandwidth
        ).sum(dim=-1).mean()
        l0_loss = self.config.l1_coeff * l0
        l1_loss = l0_loss
        loss = l2_loss + l1_loss
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()
        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        output = {
            "sae_out": sae_out,
            "feature_acts": acts,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0,
            "l1_norm": l0,
            "explained_variance": explained_variance,
        }
        return output
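
if __name__ == "__main__":
    # Minimal smoke-test sketch. The keyword arguments below assume SAEConfig accepts the
    # fields this module reads (act_size, dict_size, top_k, top_k_aux, aux_penalty,
    # l1_coeff, n_batches_to_dead, input_unit_norm, dtype, sae_dtype) as config kwargs;
    # the concrete values are illustrative only.
    config = SAEConfig(
        act_size=768,
        dict_size=768 * 16,
        top_k=32,
        top_k_aux=512,
        aux_penalty=1 / 32,
        l1_coeff=0.0,
        n_batches_to_dead=5,
        input_unit_norm=True,
        dtype="float32",
        sae_dtype="float32",
    )
    sae = BatchTopKSAE(config)
    acts = torch.randn(8, config.act_size)
    out = sae(acts)
    print(out["loss"].item(), out["l0_norm"].item(), out["explained_variance"].item())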