""" Runs statistic Cosine Similarity of Weights on all tensors of two given models, if they match in size. Part of the "Unconstrained Setting" experiments (see StripedHyena experiments). Relevant for hybrid models where only some parameters are shared. To run: Use the HuggingFace Ids for the two models in Line 65-66. Prints p-values between tensors that align in dimension. """ import torch import scipy from scipy.optimize import linear_sum_assignment as LAP from transformers import AutoModelForCausalLM from tracing.utils.utils import cossim, fisher import warnings warnings.filterwarnings("ignore") def csw_sp_pair(base_model, ft_model, layer_name_base, layer_name_ft): """ Calculate Cosine Similarity of Weights between two specific layers. Uses linear assignment to find optimal matching between neurons and calculates Pearson correlation to quantify similarity. Args: base_model: First model to compare ft_model: Second model to compare layer_name_base: Name of the layer in the first model's state dict layer_name_ft: Name of the layer in the second model's state dict Returns: float: p-value indicating the statistical similarity of weight matrices """ base_mat = base_model.state_dict()[layer_name_base] ft_mat = ft_model.state_dict()[layer_name_ft] matched = LAP(cossim(base_mat.type(torch.float64), ft_mat.type(torch.float64)), maximize=True) matched = matched[1] orig = torch.arange(len(matched)) cor, pvalue = scipy.stats.pearsonr(matched.tolist(), orig.tolist()) return pvalue def csw_models(base_model, ft_model): """ Perform comprehensive pairwise comparisons between all compatible layers of two models. Tests all possible layer pairings between models that have compatible shapes, useful for exploring model structure similarities without assuming corresponding positions. Args: base_model: First model to compare ft_model: Second model to compare Returns: float: Aggregate p-value from Fisher's method combining all layer comparisons, or 999 if no compatible layers were found """ base_model.to("cpu") ft_model.to("cpu") weights_base = base_model.state_dict() weights_ft = ft_model.state_dict() shapes_base = {} shapes_ft = {} for name1 in list(weights_base.keys()): shapes_base[name1] = weights_base[name1].shape for name2 in list(weights_ft.keys()): shapes_ft[name2] = weights_ft[name2].shape pvalues = [] for name1 in list(weights_base.keys()): for name2 in list(weights_ft.keys()): # print(name1,name2) if shapes_base[name1] == shapes_ft[name2] and len(shapes_base[name1]) != 1: pval = csw_sp_pair(base_model, ft_model, name1, name2) print(name1, name2, pval) pvalues.append(pval) res = 0 if len(pvalues) == 0: res = 999 else: res = fisher(pvalues) return res def main(): model_1_id = "openai-community/gpt2" model_2_id = "trl-internal-testing/dummy-GPT2-correct-vocab" print(model_1_id, model_2_id) model_1 = AutoModelForCausalLM.from_pretrained(model_1_id, torch_dtype=torch.bfloat16) model_2 = AutoModelForCausalLM.from_pretrained(model_2_id, torch_dtype=torch.bfloat16) print(csw_models(model_1, model_2)) if __name__ == "__main__": main()