""" This file contains the code to do generalize \phi_MATCH. Given two models, it first retrains the MLPs or other FFN of the models, then runs \phi_MATCH on the distilled MLPs. May need to be modified depending on model architecture (this code was used for GPT-architecture). """ MLP_SIZE = 3072 MLP_SIZE_2 = 3072 EMB_SIZE = 768 EMB_SIZE_2 = 768 N_BLOCKS = 12 import torch import torch.nn as nn from transformers import ( AutoTokenizer, AutoModelForCausalLM, AutoConfig, GPT2LMHeadModel, OpenAIGPTLMHeadModel, ) import scipy from collections import defaultdict import numpy as np from tracing.utils.evaluate import prepare_hf_dataset, prepare_hf_dataloader, evaluate from tracing.utils.evaluate import ( prepare_hf_dataset, prepare_hf_dataloader, ) from tracing.utils.utils import manual_seed from tracing.utils.llama.matching import match_wmats manual_seed(0) # architecture of MLP trained from scratch can be different from original # eg, uncomment to get a 2-hidden layer MLP (original has just 1 hidden layer) class CustomLlamaMLP(nn.Module): """ Custom MLP module for Llama-style architecture with SwiGLU activation. This implementation allows for flexible architecture changes when training replacement MLPs for model distillation and analysis. Args: config: Model configuration containing embedding dimensions """ def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.n_embd self.intermediate_size = 4 * config.n_embd self.gate_proj1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj1 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = nn.SiLU() def forward(self, x): """ Forward pass implementing SwiGLU activation function. Args: x: Input tensor of shape [batch_size, seq_len, hidden_size] Returns: torch.Tensor: Output tensor after MLP transformation """ down_proj = self.down_proj1(self.act_fn(self.gate_proj1(x)) * self.up_proj1(x)) return down_proj def hook_out(m, inp, op, feats, name): """ Forward hook to capture output activations from model layers. Args: m: Module being hooked inp: Input to the module op: Output from the module feats: Dictionary to store activations name: Key to store the activations under """ feats[name].append(op.detach().cpu()) def hook_in(m, inp, op, feats, name): """ Forward hook to capture input activations to model layers. Args: m: Module being hooked inp: Input to the module (tuple) op: Output from the module feats: Dictionary to store activations name: Key to store the activations under """ feats[name].append(inp[0].detach().cpu()) def mlp_layers(base_model_gate, base_model_up, ft_model_gate, ft_model_up, dataloader, i, j): """ Compare gate and up projections between separate model components. Tests whether separately trained gate and up projection models have consistent permutation patterns, which would indicate functionally corresponding neurons. 

    Args:
        base_model_gate: First model with gate projection weights
        base_model_up: First model with up projection weights
        ft_model_gate: Second model with gate projection weights
        ft_model_up: Second model with up projection weights
        dataloader: DataLoader providing input data for activation collection
        i: Layer index in the first model
        j: Layer index in the second model

    Returns:
        float: Pearson correlation p-value between gate and up projection permutations
    """
    gate_match = mlp_matching(base_model_gate, ft_model_gate, dataloader, i, j)
    up_match = mlp_matching(base_model_up, ft_model_up, dataloader, i, j)
    print(gate_match, up_match, i, j)

    cor, pvalue = scipy.stats.pearsonr(gate_match.tolist(), up_match.tolist())
    return pvalue


def mlp_matching(base_model, ft_model, dataloader, i, j):
    """
    Match neurons between models by comparing activations.

    Collects activations from the feed-forward layer for both models and computes
    a permutation that would align corresponding neurons.

    Args:
        base_model: Base model to compare
        ft_model: Target model to compare against the base model
        dataloader: DataLoader providing input data for activation collection
        i: Layer index in the base model
        j: Layer index in the target model

    Returns:
        torch.Tensor: Permutation indices that match neurons between models
    """
    feats = defaultdict(list)

    # Capture the c_fc (first FFN projection) outputs of layer i / layer j
    base_hook = lambda *args: hook_out(*args, feats, "base")
    base_handle = base_model.transformer.h[i].mlp.c_fc.register_forward_hook(base_hook)

    ft_hook = lambda *args: hook_out(*args, feats, "ft")
    ft_handle = ft_model.transformer.h[j].mlp.c_fc.register_forward_hook(ft_hook)

    evaluate(base_model, dataloader)
    evaluate(ft_model, dataloader)

    base_mat = torch.vstack(feats["base"])
    ft_mat = torch.vstack(feats["ft"])

    base_mat = base_mat.to("cuda")
    ft_mat = ft_mat.to("cuda")

    # Flatten to (hidden_units, num_tokens) before matching
    base_mat = base_mat.view(-1, base_mat.shape[-1]).T
    ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T

    base_handle.remove()
    ft_handle.remove()

    perm = match_wmats(base_mat, ft_mat)
    return perm


def run(i):
    """
    Run the generalized MATCH algorithm to compare models with distilled components.

    This function:
    1. Loads two different GPT-2 models
    2. Trains custom MLPs to replicate the behavior of the original model MLPs
    3. Optionally applies random rotations to test invariance to representation changes
    4. Creates separate models for gate and up projections
    5. Compares the models using the MATCH algorithm

    The process demonstrates that functionally equivalent components can be identified
    even after representation changes, by examining activation patterns.

    Args:
        i: Layer index to focus on for the analysis

    Returns:
        None: Prints p-value results from the neuron matching
    """
    train_losses = []

    model_id_2 = "manupande21/GPT2_PMC"
    model_id_1 = "openai-community/gpt2"

    model = AutoModelForCausalLM.from_pretrained(model_id_1, torch_dtype=torch.bfloat16)
    base_tokenizer = AutoTokenizer.from_pretrained(model_id_1, use_fast=False)

    dataset_wikitext = prepare_hf_dataset("dlwh/wikitext_103_detokenized", 512, base_tokenizer)
    dataloader_wikitext = prepare_hf_dataloader(dataset_wikitext, 1)

    config = AutoConfig.from_pretrained(model_id_1)

    # i is the layer to retrain (passed in by the caller)
    bsz = 5000  # batch size
    T = 10000  # gradient steps
    width_fac = 1.0  # easier to get loss down for wider MLPs when retraining

    # Train the first custom MLP
    mlp = CustomLlamaMLP(config).bfloat16()
    mlp.to("cuda")
    model.transformer.h[i].mlp.to("cuda")

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(mlp.parameters(), lr=0.001)

    print(f"Training MLP {model_id_1}")

    # Random rotation matrix to test invariance to representation changes
    A = torch.randn(size=(EMB_SIZE, EMB_SIZE), device="cuda").bfloat16() / np.sqrt(
        EMB_SIZE
    )  # rotate outputs (just for kicks / sanity check)

    # Distillation training loop for first model
    for t in range(T):
        X_batch = torch.randn(size=(bsz, EMB_SIZE), dtype=torch.bfloat16, device="cuda")
        with torch.no_grad():
            Y_batch = model.transformer.h[i].mlp(X_batch)
            Y_batch = Y_batch @ A.T  # Apply rotation to outputs

        Y_h = mlp(X_batch)

        optimizer.zero_grad()
        loss = criterion(Y_h, Y_batch)
        loss.backward()
        optimizer.step()

        if t % 1000 == 0:
            print(f"train loss: {loss.item()}")
            train_losses.append(loss.item())

    # Create separate models for gate and up projections
    config = AutoConfig.from_pretrained(model_id_1)
    config.n_inner = int(width_fac * MLP_SIZE)  # GPT-2 stores the FFN width as n_inner

    model_retrained_1_gate = GPT2LMHeadModel(config).bfloat16()
    model_retrained_1_up = GPT2LMHeadModel(config).bfloat16()

    model.to("cpu")
    mlp.to("cpu")

    # Loading retrained weights into model_retrained (GPT-2's c_fc is a Conv1D,
    # so the nn.Linear weights are transposed)
    weights_1_gate = model.state_dict()
    weights_1_up = model.state_dict()

    weights_1_gate[f"transformer.h.{i}.mlp.c_fc.weight"] = mlp.gate_proj1.weight.T
    weights_1_up[f"transformer.h.{i}.mlp.c_fc.weight"] = mlp.up_proj1.weight.T

    model_retrained_1_gate.load_state_dict(weights_1_gate)
    model_retrained_1_up.load_state_dict(weights_1_up)

    # Retraining / distilling second model
    model = AutoModelForCausalLM.from_pretrained(model_id_2, torch_dtype=torch.bfloat16)
    config = AutoConfig.from_pretrained(model_id_2)

    mlp = CustomLlamaMLP(config).bfloat16()
    mlp.to("cuda")
    model.transformer.h[i].mlp.to("cuda")

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(mlp.parameters(), lr=0.001)

    print(f"Training MLP {model_id_2}")

    # Different random rotation matrix for the second model
    A = torch.randn(size=(EMB_SIZE_2, EMB_SIZE_2), device="cuda").bfloat16() / np.sqrt(
        EMB_SIZE_2
    )  # rotate outputs (just for kicks / sanity check)

    # Distillation training loop for second model
    for t in range(T):
        X_batch = torch.randn(size=(bsz, EMB_SIZE_2), dtype=torch.bfloat16, device="cuda")
        with torch.no_grad():
            Y_batch = model.transformer.h[i].mlp(X_batch)
            Y_batch = Y_batch @ A.T  # Apply rotation to outputs

        Y_h = mlp(X_batch)

        optimizer.zero_grad()
        loss = criterion(Y_h, Y_batch)
        loss.backward()
        optimizer.step()

        if t % 1000 == 0:
            print(f"train loss: {loss.item()}")
            train_losses.append(loss.item())

    # Create separate models for gate and up projections
    config = AutoConfig.from_pretrained(model_id_2)
    config.n_inner = int(width_fac * MLP_SIZE_2)  # GPT-2 stores the FFN width as n_inner

    model_retrained_2_gate = GPT2LMHeadModel(config).bfloat16()
    model_retrained_2_up = GPT2LMHeadModel(config).bfloat16()

    model.to("cpu")
    mlp.to("cpu")

    weights_2_gate = model.state_dict()
    weights_2_up = model.state_dict()

    weights_2_gate[f"transformer.h.{i}.mlp.c_fc.weight"] = mlp.gate_proj1.weight.T
    weights_2_up[f"transformer.h.{i}.mlp.c_fc.weight"] = mlp.up_proj1.weight.T

    model_retrained_2_gate.load_state_dict(weights_2_gate)
    model_retrained_2_up.load_state_dict(weights_2_up)

    # Run MATCH algorithm to compare the models on the retrained layer
    print(
        mlp_layers(
            model_retrained_1_gate,
            model_retrained_1_up,
            model_retrained_2_gate,
            model_retrained_2_up,
            dataloader_wikitext,
            i,
            i,
        )
    )


if __name__ == "__main__":
    for i in range(0, 10):
        run(i)
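
# Usage sketch (an illustrative note, assuming the `tracing` package is importable and
# a CUDA device is available): running this file directly executes run(i) for blocks
# 0-9. Each call distills the MLP of block i in both models and prints the p-value
# returned by mlp_layers; a small p-value indicates that the gate- and up-projection
# permutations agree, i.e. evidence that the two models' neurons correspond. To test a
# single block instead of sweeping, run() can be called directly, e.g.:
#
#     from generalize_match import run  # hypothetical module name for this file
#     run(3)                            # distill and match block 3 only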