Ahmed Ahmed commited on
Commit
de071e9
·
1 Parent(s): 1dd4b6a

Add model-tracing code for p-value computation (without binary files)

Browse files
Files changed (48) hide show
  1. model-tracing +0 -1
  2. model-tracing/.pre-commit-config.yaml +29 -0
  3. model-tracing/README.md +138 -0
  4. model-tracing/config/dolma_pl.yaml +9 -0
  5. model-tracing/config/llama3.yaml +5 -0
  6. model-tracing/config/llama70b.yaml +5 -0
  7. model-tracing/config/llama7b.yaml +22 -0
  8. model-tracing/config/llama7b_split.yaml +24 -0
  9. model-tracing/config/llama7b_tree.yaml +31 -0
  10. model-tracing/config/m2d2.yaml +34 -0
  11. model-tracing/experiments/csw_full.py +111 -0
  12. model-tracing/experiments/faiss/csw_faiss.py +280 -0
  13. model-tracing/experiments/generalized_match.py +343 -0
  14. model-tracing/experiments/huref.py +140 -0
  15. model-tracing/experiments/localized_testing.py +144 -0
  16. model-tracing/launch.py +87 -0
  17. model-tracing/main.py +246 -0
  18. model-tracing/requirements-dev.txt +5 -0
  19. model-tracing/requirements.txt +17 -0
  20. model-tracing/results/jsd/model_pairs_jsd.csv +37 -0
  21. model-tracing/results/l2/model_pairs_l2.csv +37 -0
  22. model-tracing/results/perm/permutation_l2_updated_midpoint_wikitext_single.csv +21 -0
  23. model-tracing/results/perm/permutation_loss_midpoint_wikitext_single.csv +21 -0
  24. model-tracing/results/perm/permutation_norm_loss_midpoint_wikitext_single.csv +21 -0
  25. model-tracing/scripts/docs/doc_trace.py +204 -0
  26. model-tracing/scripts/docs/launch.py +58 -0
  27. model-tracing/scripts/docs/m2d_trace.py +279 -0
  28. model-tracing/scripts/mode/main.py +231 -0
  29. model-tracing/scripts/mode/mode_connectivity_metrics.py +173 -0
  30. model-tracing/scripts/perm/main.py +47 -0
  31. model-tracing/scripts/robust/pythia.py +83 -0
  32. model-tracing/tracing/__init__.py +0 -0
  33. model-tracing/tracing/perm/permute.py +119 -0
  34. model-tracing/tracing/statistics/__init__.py +0 -0
  35. model-tracing/tracing/statistics/csh.py +172 -0
  36. model-tracing/tracing/statistics/csu.py +168 -0
  37. model-tracing/tracing/statistics/jsd.py +195 -0
  38. model-tracing/tracing/statistics/l2.py +71 -0
  39. model-tracing/tracing/statistics/match.py +234 -0
  40. model-tracing/tracing/statistics/mc.py +12 -0
  41. model-tracing/tracing/statistics/perm_mc_l2.py +48 -0
  42. model-tracing/tracing/utils/__init__.py +0 -0
  43. model-tracing/tracing/utils/evaluate.py +209 -0
  44. model-tracing/tracing/utils/llama/matching.py +45 -0
  45. model-tracing/tracing/utils/llama/model.py +263 -0
  46. model-tracing/tracing/utils/olmo/model.py +182 -0
  47. model-tracing/tracing/utils/plot_metrics.py +545 -0
  48. model-tracing/tracing/utils/utils.py +109 -0
model-tracing DELETED
@@ -1 +0,0 @@
1
- Subproject commit 9eb3b67655be2a3576348a6d482e69c62f72fc3e
 
 
model-tracing/.pre-commit-config.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.5.0
4
+ hooks:
5
+ - id: trailing-whitespace
6
+ - id: end-of-file-fixer
7
+ - id: check-yaml
8
+ - id: check-added-large-files
9
+ - id: check-merge-conflict
10
+
11
+ - repo: https://github.com/psf/black
12
+ rev: 24.1.1
13
+ hooks:
14
+ - id: black
15
+ language_version: python3
16
+ args: [--line-length=100]
17
+
18
+ #- repo: https://github.com/charliermarsh/ruff-pre-commit
19
+ # rev: 'v0.1.8'
20
+ # hooks:
21
+ # - id: ruff
22
+ # args: [--fix, --exit-non-zero-on-fix, --line-length=100, --ignore=E402,E731,F841,F811,F821]
23
+
24
+ - repo: https://github.com/nbQA-dev/nbQA
25
+ rev: 1.7.1
26
+ hooks:
27
+ - id: nbqa-black
28
+ additional_dependencies: [black==24.1.1]
29
+ args: [--line-length=100]
model-tracing/README.md ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # LLM Model Tracing
3
+ This repository investigates model tracing in large language models (LLMs).
4
+
5
+ Specifically, given a base LLM and a fine-tuned LLM, this code provides functionality to:
6
+
7
+ - Permute the weights of one model (either MLP or embedding weights).
8
+ - Align the weights of the fine-tuned model to the base model using the Hungarian algorithm.
9
+ - Evaluate the effect of weight permutation and alignment on different statistics:
10
+ - Mode connectivity
11
+ - Cosine similarity
12
+ - Embedding similarity
13
+ - Evaluate the perplexity of the base and fine-tuned models on a given dataset.
14
+
15
+ ## Requirements
16
+
17
+ Install the necessary packages using:
18
+
19
+ ```bash
20
+ pip install -r requirements.txt
21
+ ```
22
+
23
+ For development, install the development dependencies:
24
+
25
+ ```bash
26
+ pip install -r requirements-dev.txt
27
+ ```
28
+
29
+ ### Code Formatting with pre-commit
30
+
31
+ This repository uses pre-commit hooks to ensure code quality and consistency.
32
+
33
+ 1. Install pre-commit:
34
+
35
+ ```bash
36
+ pip install pre-commit
37
+ ```
38
+
39
+ 2. Set up the pre-commit hooks:
40
+
41
+ ```bash
42
+ pre-commit install
43
+ ```
44
+
45
+ 3. (Optional) Run pre-commit on all files:
46
+
47
+ ```bash
48
+ pre-commit run --all-files
49
+ ```
50
+
51
+ Pre-commit will automatically run on staged files when you commit changes, applying:
52
+ - Black for code formatting
53
+ - Ruff for linting and fixing common issues
54
+ - nbQA for notebook formatting
55
+ - Various file checks (trailing whitespace, YAML validity, etc.)
56
+
57
+ ## Usage
58
+
59
+ The repository provides three main scripts:
60
+
61
+ - `main.py`: Executes the main experiment pipeline for model tracing.
62
+ - `launch.py`: Launches multiple experiments in parallel using slurm.
63
+
64
+ ### `main.py`
65
+
66
+ This script performs the following steps:
67
+
68
+ 1. Loads the base and fine-tuned LLMs.
69
+ 2. Optionally permutes the weights of the fine-tuned model.
70
+ 3. Calculates the selected statistic for the non-aligned models.
71
+ 4. Optionally aligns the weights of the fine-tuned model to the base model.
72
+ 5. Calculates the selected statistic for the aligned models.
73
+ 6. Optionally evaluates the perplexity of the base and fine-tuned models.
74
+ 7. Saves the results to a pickle file.
75
+
76
+ The script accepts various command-line arguments:
77
+
78
+ - `--base_model_id`: HuggingFace model ID for the base model.
79
+ - `--ft_model_id`: HuggingFace model ID for the fine-tuned model.
80
+ - `--permute`: Whether to permute the weights of the fine-tuned model.
81
+ - `--align`: Whether to align the weights of the fine-tuned model to the base model.
82
+ - `--dataset_id`: HuggingFace dataset ID for perplexity evaluation.
83
+ - `--stat`: Statistic to calculate (options: "mode", "cos", "emb").
84
+ - csu: cosine similarity of weights statistic (on MLP up projection matrices) w/ Spearman correlation
85
+ - csu_all: csu on all pairs of parameters with equal shape
86
+ - csh: cosine similarity of MLP activations statistic w/ Spearman correlation
87
+ - match: unconstrained statistic (match) with permutation matching of MLP activations
88
+ - match_all: unconstrained statistic (match) on all pairs of MLP block activations
89
+ - `--attn`: Whether to consider attention weights in the "mode" statistic.
90
+ - `--emb`: Whether to consider embedding weights in the "mode" statistic.
91
+ - `--eval`: Whether to evaluate perplexity.
92
+ - `--save`: Path to save the results pickle file.
93
+
94
+ Example usage:
95
+
96
+ ```bash
97
+ python main.py --base_model_id meta-llama/Llama-2-7b-hf --ft_model_id lmsys/vicuna-7b-v1.5 --stat csu --save results.p
98
+ ```
99
+
100
+ ```bash
101
+ python main.py --base_model_id meta-llama/Llama-2-7b-hf --ft_model_id lmsys/vicuna-7b-v1.5 --permute --align --dataset wikitext --stat match --attn --save results.p
102
+ ```
103
+
104
+ ### `launch.py`
105
+
106
+ This script launches multiple experiments in parallel using slurm. It reads model IDs from a YAML file and runs `main.py` for each pair of base and fine-tuned models. Use the flag --flat all (defaulted) to run on all pairs of models from a YAML (see config/llama7b.yaml); or, --flat split to run on all pairs of a 'base' model with a 'finetuned' model (see config/llama7b_split.yaml); or --flat specified to run on a specified list of pairs of models.
107
+
108
+ ## Configuration
109
+
110
+ The `model-tracing/config/model_list.yaml` file defines the base and fine-tuned models for the experiments.
111
+ ## Data
112
+
113
+ The code downloads and uses the Wikitext 103 dataset for perplexity evaluation.
114
+
115
+ ## Results
116
+
117
+ The results of the experiments are saved as pickle files. The files contain dictionaries with the following keys:
118
+
119
+ - `args`: Command-line arguments used for the experiment.
120
+ - `commit`: Git commit hash of the code used for the experiment.
121
+ - `non-aligned test stat`: Value of the selected statistic for the non-aligned models.
122
+ - `aligned test stat`: Value of the selected statistic for the aligned models (if `--align` is True).
123
+ - `base loss`: Perplexity of the base model on the evaluation dataset (if `--eval` is True).
124
+ - `ft loss`: Perplexity of the fine-tuned model on the evaluation dataset (if `--eval` is True).
125
+ - `time`: Total execution time of the experiment.
126
+
127
+ ## Sample commands
128
+
129
+ ### 70B runs
130
+ ```
131
+ python main.py --base_model_id meta-llama/Llama-2-70b-hf --ft_model_id meta-llama/Meta-Llama-3-70B --stat csu
132
+ ```
133
+
134
+ # Experiments
135
+
136
+ Relevant scripts for running additional experiments described in our paper are in this folder. For example, there are experiments on retraining MLP blocks and evaluating our statistics.
137
+
138
+ These include `experiments/localized_testing.py` (Section 3.2.1) for fine-grained forensics and layer-matching between two models; `experiments/csu_full.py` (Section 3.2.1) for full parameter-matching between any two model architectures for hybrid models; `experiments/generalized_match.py` (Section 2.3.2, 3.2.3, 3.2.4) for the generalized robust test that involes retraining or distilling GLU MLPs; and `experiments/huref.py` (Appendix F) where we reproduce and break the invariants from a related work (Zeng et al. 2024).
model-tracing/config/dolma_pl.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration for Dolma Programming Languages
2
+ save_dir: "/nlp/scr/ahmedah/trace_results"
3
+ dolma_json_dir: "/juice4/scr4/nlp/model-tracing/dolma_program_languages/json_files"
4
+ datasets:
5
+ - "cpp"
6
+ - "python"
7
+ - "js"
8
+ model_architectures:
9
+ - "llama"
model-tracing/config/llama3.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ ---
2
+ - meta-llama/Meta-Llama-3-8B
3
+ - meta-llama/Meta-Llama-3-8B-Instruct
4
+ - meta-llama/Meta-Llama-3.1-8B
5
+ - meta-llama/Meta-Llama-3.1-8B-Instruct
model-tracing/config/llama70b.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ ---
2
+ - meta-llama/Llama-2-70b-hf
3
+ - meta-llama/Llama-3.1-70B
4
+ - 152334H/miqu-1-70b-sf
5
+ - Writer/Palmyra-Fin-70B-32K
model-tracing/config/llama7b.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ - lmsys/vicuna-7b-v1.5
3
+ - codellama/CodeLlama-7b-hf
4
+ - codellama/CodeLlama-7b-Python-hf
5
+ - codellama/CodeLlama-7b-Instruct-hf
6
+ - EleutherAI/llemma_7b
7
+ - microsoft/Orca-2-7b
8
+ - oh-yeontaek/llama-2-7B-LoRA-assemble
9
+ - lvkaokao/llama2-7b-hf-instruction-lora
10
+ - NousResearch/Nous-Hermes-llama-2-7b
11
+ - lmsys/vicuna-7b-v1.1
12
+ - yahma/llama-7b-hf
13
+ - Salesforce/xgen-7b-4k-base
14
+ - EleutherAI/llemma_7b_muinstruct_camelmath
15
+ - AlfredPros/CodeLlama-7b-Instruct-Solidity
16
+ - meta-llama/Llama-2-7b-hf
17
+ - LLM360/Amber
18
+ - LLM360/AmberChat
19
+ - openlm-research/open_llama_7b
20
+ - openlm-research/open_llama_7b_v2
21
+ - ibm-granite/granite-7b-base
22
+ - ibm-granite/granite-7b-instruct
model-tracing/config/llama7b_split.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ ft_models:
3
+ - lmsys/vicuna-7b-v1.5
4
+ - codellama/CodeLlama-7b-hf
5
+ - codellama/CodeLlama-7b-Python-hf
6
+ - codellama/CodeLlama-7b-Instruct-hf
7
+ - EleutherAI/llemma_7b
8
+ - microsoft/Orca-2-7b
9
+ - oh-yeontaek/llama-2-7B-LoRA-assemble
10
+ - lvkaokao/llama2-7b-hf-instruction-lora
11
+ - NousResearch/Nous-Hermes-llama-2-7b
12
+ - lmsys/vicuna-7b-v1.1
13
+ - EleutherAI/llemma_7b_muinstruct_camelmath
14
+ - AlfredPros/CodeLlama-7b-Instruct-Solidity
15
+ - LLM360/AmberChat
16
+ - ibm-granite/granite-7b-instruct
17
+ base_models:
18
+ - meta-llama/Llama-2-7b-hf
19
+ - yahma/llama-7b-hf
20
+ - Salesforce/xgen-7b-4k-base
21
+ - LLM360/Amber
22
+ - openlm-research/open_llama_7b
23
+ - openlm-research/open_llama_7b_v2
24
+ - ibm-granite/granite-7b-base
model-tracing/config/llama7b_tree.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ # base llama1
3
+ - yahma/llama-7b-hf:
4
+ - lmsys/vicuna-7b-v1.1
5
+ # base llama2
6
+ - meta-llama/Llama-2-7b-hf:
7
+ - lmsys/vicuna-7b-v1.5
8
+ - codellama/CodeLlama-7b-hf:
9
+ # code llama family
10
+ - codellama/CodeLlama-7b-Python-hf
11
+ - codellama/CodeLlama-7b-Instruct-hf:
12
+ # code llama instruct fine tune
13
+ - AlfredPros/CodeLlama-7b-Instruct-Solidity
14
+ - EleutherAI/llemma_7b:
15
+ # llemma fine tune
16
+ - EleutherAI/llemma_7b_muinstruct_camelmath
17
+ - microsoft/Orca-2-7b
18
+ - oh-yeontaek/llama-2-7B-LoRA-assemble
19
+ - lvkaokao/llama2-7b-hf-instruction-lora
20
+ - NousResearch/Nous-Hermes-llama-2-7b
21
+
22
+ # independent llama models
23
+ - Salesforce/xgen-7b-4k-base
24
+ - LLM360/Amber:
25
+ - LLM360/AmberChat
26
+ # two open llama models trained independently
27
+ - openlm-research/open_llama_7b
28
+ - openlm-research/open_llama_7b_v2
29
+ - ibm-granite/granite-7b-base:
30
+ # granite instruct variant
31
+ - ibm-granite/granite-7b-instruct
model-tracing/config/m2d2.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration for M2D2 datasets
2
+ save_dir: "/nlp/scr/ahmedah/trace_results"
3
+ m2d2_json_dir: "/juice4/scr4/nlp/model-tracing/m2d2_s2orc"
4
+ datasets:
5
+ # These categories are derived from the ArXiv ontology
6
+ - "AI" # Artificial Intelligence
7
+ - "CV" # Computer Vision
8
+ - "ET" # Emerging Technologies
9
+ - "IM" # Information Management
10
+ - "mtrl-sci" # Materials Science
11
+ - "stat-mech" # Statistical Mechanics
12
+ - "AR" # Architecture
13
+ - "CY" # Cryptography and Security
14
+ - "IR" # Information Retrieval
15
+ - "NA" # Numerical Analysis
16
+ - "str-el" # Strongly Correlated Electrons
17
+ # Additional ArXiv categories can be added here
18
+
19
+ # These categories are derived from the Wikipedia ontology
20
+ # - "HEAL" # Health and Fitness
21
+ # - "HIST" # History and Events
22
+ # - "SOCI" # Society and Social Sciences
23
+ # - "TECH" # Technology and Applied Sciences
24
+ # - "CULT" # Culture and the Arts
25
+ # - "NATU" # Natural and Physical Sciences
26
+ # - "HUMA" # Human Activities
27
+ # - "MATH" # Mathematics and Logic
28
+ # - "GENE" # General Reference
29
+ # - "RELI" # Religion and Belief Systems
30
+ # - "PHIL" # Philosophy and Thinking
31
+
32
+ model_architectures:
33
+ - "llama"
34
+ - "olmo"
model-tracing/experiments/csw_full.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Runs statistic Cosine Similarity of Weights on all tensors of two given models, if they match in size.
3
+ Part of the "Unconstrained Setting" experiments (see StripedHyena experiments).
4
+ Relevant for hybrid models where only some parameters are shared.
5
+
6
+ To run: Use the HuggingFace Ids for the two models in Line 65-66.
7
+ Prints p-values between tensors that align in dimension.
8
+ """
9
+
10
+ import torch
11
+ import scipy
12
+ from scipy.optimize import linear_sum_assignment as LAP
13
+ from transformers import AutoModelForCausalLM
14
+
15
+ from tracing.utils.utils import cossim, fisher
16
+
17
+ import warnings
18
+
19
+ warnings.filterwarnings("ignore")
20
+
21
+
22
+ def csw_sp_pair(base_model, ft_model, layer_name_base, layer_name_ft):
23
+ """
24
+ Calculate Cosine Similarity of Weights between two specific layers.
25
+
26
+ Uses linear assignment to find optimal matching between neurons and
27
+ calculates Pearson correlation to quantify similarity.
28
+
29
+ Args:
30
+ base_model: First model to compare
31
+ ft_model: Second model to compare
32
+ layer_name_base: Name of the layer in the first model's state dict
33
+ layer_name_ft: Name of the layer in the second model's state dict
34
+
35
+ Returns:
36
+ float: p-value indicating the statistical similarity of weight matrices
37
+ """
38
+ base_mat = base_model.state_dict()[layer_name_base]
39
+ ft_mat = ft_model.state_dict()[layer_name_ft]
40
+
41
+ matched = LAP(cossim(base_mat.type(torch.float64), ft_mat.type(torch.float64)), maximize=True)
42
+ matched = matched[1]
43
+
44
+ orig = torch.arange(len(matched))
45
+ cor, pvalue = scipy.stats.pearsonr(matched.tolist(), orig.tolist())
46
+ return pvalue
47
+
48
+
49
+ def csw_models(base_model, ft_model):
50
+ """
51
+ Perform comprehensive pairwise comparisons between all compatible layers of two models.
52
+
53
+ Tests all possible layer pairings between models that have compatible shapes,
54
+ useful for exploring model structure similarities without assuming corresponding positions.
55
+
56
+ Args:
57
+ base_model: First model to compare
58
+ ft_model: Second model to compare
59
+
60
+ Returns:
61
+ float: Aggregate p-value from Fisher's method combining all layer comparisons,
62
+ or 999 if no compatible layers were found
63
+ """
64
+ base_model.to("cpu")
65
+ ft_model.to("cpu")
66
+
67
+ weights_base = base_model.state_dict()
68
+ weights_ft = ft_model.state_dict()
69
+
70
+ shapes_base = {}
71
+ shapes_ft = {}
72
+
73
+ for name1 in list(weights_base.keys()):
74
+ shapes_base[name1] = weights_base[name1].shape
75
+ for name2 in list(weights_ft.keys()):
76
+ shapes_ft[name2] = weights_ft[name2].shape
77
+
78
+ pvalues = []
79
+
80
+ for name1 in list(weights_base.keys()):
81
+ for name2 in list(weights_ft.keys()):
82
+ # print(name1,name2)
83
+ if shapes_base[name1] == shapes_ft[name2] and len(shapes_base[name1]) != 1:
84
+ pval = csw_sp_pair(base_model, ft_model, name1, name2)
85
+ print(name1, name2, pval)
86
+ pvalues.append(pval)
87
+
88
+ res = 0
89
+
90
+ if len(pvalues) == 0:
91
+ res = 999
92
+ else:
93
+ res = fisher(pvalues)
94
+
95
+ return res
96
+
97
+
98
+ def main():
99
+ model_1_id = "openai-community/gpt2"
100
+ model_2_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
101
+
102
+ print(model_1_id, model_2_id)
103
+
104
+ model_1 = AutoModelForCausalLM.from_pretrained(model_1_id, torch_dtype=torch.bfloat16)
105
+ model_2 = AutoModelForCausalLM.from_pretrained(model_2_id, torch_dtype=torch.bfloat16)
106
+
107
+ print(csw_models(model_1, model_2))
108
+
109
+
110
+ if __name__ == "__main__":
111
+ main()
model-tracing/experiments/faiss/csw_faiss.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import faiss
2
+ import numpy as np
3
+ import torch
4
+ from typing import Dict, Tuple, List, NamedTuple
5
+ import os
6
+ import pickle
7
+ import yaml
8
+ from transformers import AutoModelForCausalLM
9
+
10
+
11
+ class WeightInfo(NamedTuple):
12
+ """
13
+ A named tuple containing metadata about a weight matrix.
14
+
15
+ Attributes:
16
+ model_name: Name or identifier of the model
17
+ param_name: Name of the parameter in the model's state dict
18
+ dimensions: Tuple containing the shape of the weight matrix (d1, d2)
19
+ """
20
+
21
+ model_name: str
22
+ param_name: str
23
+ dimensions: Tuple[int, int]
24
+
25
+
26
+ class CSWSearch:
27
+ """
28
+ CSWSearch (Cosine Similarity of Weights Search) using FAISS for efficient similarity search.
29
+
30
+ This class enables fast indexing and retrieval of similar weight matrices across models,
31
+ organizing weight matrices by their dimensions to ensure comparable searches.
32
+ """
33
+
34
+ def __init__(self):
35
+ # Keep track of what each index position corresponds to
36
+ self.metadata: Dict[Tuple[int, int], List[WeightInfo]] = {}
37
+ # Track dimensions and index file locations
38
+ self.index_files: Dict[Tuple[int, int], str] = {}
39
+ # Directory where indices are stored
40
+ self.index_dir: str = "indexes"
41
+ # Currently loaded index
42
+ self.current_index: Tuple[Tuple[int, int], faiss.Index] = None
43
+
44
+ def add_weight_matrix(
45
+ self, model_name: str, param_name: str, weight_matrix: np.ndarray
46
+ ) -> None:
47
+ """
48
+ Add a weight matrix to the appropriate index based on its dimensions.
49
+
50
+ Args:
51
+ model_name: Name or identifier of the model
52
+ param_name: Name of the parameter in the model's state dict
53
+ weight_matrix: The weight matrix tensor to index
54
+
55
+ Returns:
56
+ None
57
+ """
58
+ print(f"Adding {model_name} {param_name}")
59
+ d1, d2 = weight_matrix.shape
60
+ dim_key = (d1, d2)
61
+
62
+ # First time seeing this dimension combination
63
+ if dim_key not in self.index_files:
64
+ self.metadata[dim_key] = []
65
+ self.index_files[dim_key] = f"index_{d1}x{d2}.index"
66
+
67
+ # Load the appropriate index
68
+ index = self._load_index(dim_key)
69
+
70
+ # Flatten matrix in row-major order and normalize
71
+ flat_weights = np.array(weight_matrix.to(dtype=torch.float32).reshape(1, -1).numpy())
72
+ faiss.normalize_L2(flat_weights) # for cosine similarity
73
+
74
+ # Add to appropriate index
75
+ index.add(flat_weights)
76
+
77
+ # Store metadata
78
+ self.metadata[dim_key].append(WeightInfo(model_name, param_name, (d1, d2)))
79
+
80
+ # Save the updated index
81
+ self._save_index(dim_key, index)
82
+
83
+ def find_similar_weights(
84
+ self, model_name: str, weight_matrix: np.ndarray, k: int = 5
85
+ ) -> List[Tuple[WeightInfo, float]]:
86
+ """
87
+ Find similar weight matrices with matching dimensions.
88
+
89
+ Searches for weight matrices most similar to the provided one,
90
+ but only among those with the same dimensions.
91
+
92
+ Args:
93
+ model_name: Name or identifier of the model (used to exclude self-matches)
94
+ weight_matrix: The weight matrix tensor to search for
95
+ k: Number of similar matrices to return (default: 5)
96
+
97
+ Returns:
98
+ List of tuples containing (WeightInfo, similarity_score)
99
+
100
+ Raises:
101
+ ValueError: If no weight matrices with matching dimensions are found
102
+ """
103
+ d1, d2 = weight_matrix.shape
104
+ dim_key = (d1, d2)
105
+
106
+ if dim_key not in self.index_files:
107
+ raise ValueError(f"No weight matrices found with dimensions {dim_key}")
108
+
109
+ # Load the appropriate index
110
+ index = self._load_index(dim_key)
111
+
112
+ # Prepare query in same way as stored matrices
113
+ query = np.array(weight_matrix.to(dtype=torch.float32).reshape(1, -1).numpy())
114
+ faiss.normalize_L2(query)
115
+
116
+ # Search
117
+ distances, indices = index.search(query, k + 1) # +1 for self-match
118
+
119
+ # Format results (excluding self-match)
120
+ results = []
121
+ for idx, sim in zip(indices[0], distances[0]):
122
+ info = self.metadata[dim_key][idx]
123
+ if info.model_name != model_name: # Skip self-match
124
+ results.append((info, float(sim)))
125
+
126
+ return results[:k]
127
+
128
+ def _load_index(self, dim_key: Tuple[int, int]) -> faiss.Index:
129
+ """
130
+ Load or create the FAISS index for a specific dimension.
131
+
132
+ Args:
133
+ dim_key: Tuple of dimensions (d1, d2)
134
+
135
+ Returns:
136
+ faiss.Index: The loaded or newly created index
137
+ """
138
+ if self.current_index and self.current_index[0] == dim_key:
139
+ return self.current_index[1]
140
+
141
+ d1, d2 = dim_key
142
+ index_path = os.path.join(self.index_dir, self.index_files[dim_key])
143
+
144
+ if os.path.exists(index_path):
145
+ try:
146
+ index = faiss.read_index(index_path)
147
+ except RuntimeError:
148
+ print(f"Error reading index file {index_path}. Creating a new index.")
149
+ index = faiss.IndexFlatIP(d1 * d2)
150
+ else:
151
+ print(f"Index file {index_path} not found. Creating a new index.")
152
+ index = faiss.IndexFlatIP(d1 * d2)
153
+
154
+ self.current_index = (dim_key, index)
155
+ return index
156
+
157
+ def _save_index(self, dim_key: Tuple[int, int], index: faiss.Index):
158
+ """
159
+ Save the index for the given dimensions to disk.
160
+
161
+ Args:
162
+ dim_key: Tuple of dimensions (d1, d2)
163
+ index: The FAISS index to save
164
+
165
+ Returns:
166
+ None
167
+ """
168
+ index_path = os.path.join(self.index_dir, self.index_files[dim_key])
169
+ faiss.write_index(index, index_path)
170
+
171
+ def save(self, directory: str):
172
+ """
173
+ Save the entire search system (metadata and indexes) to disk.
174
+
175
+ Args:
176
+ directory: Directory where indices and metadata will be stored
177
+
178
+ Returns:
179
+ None
180
+ """
181
+ self.index_dir = directory
182
+ os.makedirs(directory, exist_ok=True)
183
+
184
+ if self.current_index:
185
+ self._save_index(self.current_index[0], self.current_index[1])
186
+
187
+ metadata_path = os.path.join(directory, "metadata.pkl")
188
+ with open(metadata_path, "wb") as f:
189
+ pickle.dump(self.metadata, f)
190
+
191
+ index_files_path = os.path.join(directory, "index_files.pkl")
192
+ with open(index_files_path, "wb") as f:
193
+ pickle.dump(self.index_files, f)
194
+
195
+ @classmethod
196
+ def load(cls, directory: str):
197
+ """
198
+ Load a previously saved search system from disk.
199
+
200
+ Args:
201
+ directory: Directory where indices and metadata are stored
202
+
203
+ Returns:
204
+ CSWSearch: The loaded search system
205
+ """
206
+ csw_search = cls()
207
+ csw_search.index_dir = directory
208
+
209
+ metadata_path = os.path.join(directory, "metadata.pkl")
210
+ with open(metadata_path, "rb") as f:
211
+ csw_search.metadata = pickle.load(f)
212
+
213
+ index_files_path = os.path.join(directory, "index_files.pkl")
214
+ with open(index_files_path, "rb") as f:
215
+ csw_search.index_files = pickle.load(f)
216
+
217
+ return csw_search
218
+
219
+
220
+ csw = CSWSearch()
221
+
222
+
223
+ def add_params(model_list):
224
+ """
225
+ Index weight matrices from a list of HuggingFace model IDs.
226
+
227
+ Loads each model, extracts its parameters, and adds all 2D weight matrices
228
+ to the CSWSearch index for later similarity search.
229
+
230
+ Args:
231
+ model_list: List of HuggingFace model IDs to index
232
+
233
+ Returns:
234
+ None: Updates the global csw search index
235
+ """
236
+ for model_id in model_list:
237
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
238
+ weights = model.state_dict()
239
+ params = list(weights.keys())
240
+ for param in params:
241
+ # Skip 1D tensors (like bias terms or layer norms)
242
+ if len(weights[param].shape) == 1:
243
+ continue
244
+ csw.add_weight_matrix(model_id, param_name=param, weight_matrix=weights[param])
245
+
246
+
247
+ def get_similar_param(param, k=5):
248
+ """
249
+ Find similar parameters to the given weight matrix across indexed models.
250
+
251
+ Args:
252
+ param: Weight matrix tensor to search for
253
+ k: Number of similar matrices to return (default: 5)
254
+
255
+ Returns:
256
+ List of tuples containing (WeightInfo, similarity_score)
257
+ """
258
+ return csw.find_similar_weights("--", param, k=k)
259
+
260
+
261
+ def main():
262
+ # Model list to add from yaml
263
+ model_list = yaml.safe_load(open("config/llama7b.yaml", "r"))
264
+ add_params(model_list)
265
+ csw.save("indexes")
266
+
267
+ # Weight matrix to search for
268
+ model = AutoModelForCausalLM.from_pretrained(
269
+ "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
270
+ )
271
+ weights = model.state_dict()
272
+ attn_name = "model.layers.0.self_attn.o_proj.weight"
273
+
274
+ print(get_similar_param(weights[attn_name]))
275
+
276
+ return
277
+
278
+
279
+ if __name__ == "__main__":
280
+ main()
model-tracing/experiments/generalized_match.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file contains the code to do generalize \phi_MATCH.
3
+ Given two models, it first retrains the MLPs or other FFN of the models, then runs \phi_MATCH on the distilled MLPs.
4
+
5
+ May need to be modified depending on model architecture (this code was used for GPT-architecture).
6
+ """
7
+
8
+ MLP_SIZE = 3072
9
+ MLP_SIZE_2 = 3072
10
+ EMB_SIZE = 768
11
+ EMB_SIZE_2 = 768
12
+ N_BLOCKS = 12
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from transformers import (
17
+ AutoTokenizer,
18
+ AutoModelForCausalLM,
19
+ AutoConfig,
20
+ GPT2LMHeadModel,
21
+ OpenAIGPTLMHeadModel,
22
+ )
23
+ import scipy
24
+ from collections import defaultdict
25
+
26
+ import numpy as np
27
+
28
+ from tracing.utils.evaluate import prepare_hf_dataset, prepare_hf_dataloader, evaluate
29
+ from tracing.utils.evaluate import (
30
+ prepare_hf_dataset,
31
+ prepare_hf_dataloader,
32
+ )
33
+
34
+ from tracing.utils.utils import manual_seed
35
+ from tracing.utils.llama.matching import match_wmats
36
+
37
+ manual_seed(0)
38
+
39
+
40
+ # architecture of MLP trained from scratch can be different from original
41
+ # eg, uncomment to get a 2-hidden layer MLP (original has just 1 hidden layer)
42
+ class CustomLlamaMLP(nn.Module):
43
+ """
44
+ Custom MLP module for Llama-style architecture with SwiGLU activation.
45
+
46
+ This implementation allows for flexible architecture changes when training
47
+ replacement MLPs for model distillation and analysis.
48
+
49
+ Args:
50
+ config: Model configuration containing embedding dimensions
51
+ """
52
+
53
+ def __init__(self, config):
54
+ super().__init__()
55
+ self.config = config
56
+ self.hidden_size = config.n_embd
57
+ self.intermediate_size = 4 * config.n_embd
58
+
59
+ self.gate_proj1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
60
+ self.up_proj1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
61
+ self.down_proj1 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
62
+
63
+ self.act_fn = nn.SiLU()
64
+
65
+ def forward(self, x):
66
+ """
67
+ Forward pass implementing SwiGLU activation function.
68
+
69
+ Args:
70
+ x: Input tensor of shape [batch_size, seq_len, hidden_size]
71
+
72
+ Returns:
73
+ torch.Tensor: Output tensor after MLP transformation
74
+ """
75
+ down_proj = self.down_proj1(self.act_fn(self.gate_proj1(x)) * self.up_proj1(x))
76
+ return down_proj
77
+
78
+
79
+ def hook_out(m, inp, op, feats, name):
80
+ """
81
+ Forward hook to capture output activations from model layers.
82
+
83
+ Args:
84
+ m: Module being hooked
85
+ inp: Input to the module
86
+ op: Output from the module
87
+ feats: Dictionary to store activations
88
+ name: Key to store the activations under
89
+ """
90
+ feats[name].append(op.detach().cpu())
91
+
92
+
93
+ def hook_in(m, inp, op, feats, name):
94
+ """
95
+ Forward hook to capture input activations to model layers.
96
+
97
+ Args:
98
+ m: Module being hooked
99
+ inp: Input to the module (tuple)
100
+ op: Output from the module
101
+ feats: Dictionary to store activations
102
+ name: Key to store the activations under
103
+ """
104
+ feats[name].append(inp[0].detach().cpu())
105
+
106
+
107
+ def mlp_layers(base_model_gate, base_model_up, ft_model_gate, ft_model_up, dataloader, i, j):
108
+ """
109
+ Compare gate and up projections between separate model components.
110
+
111
+ Tests whether separately trained gate and up projection models have
112
+ consistent permutation patterns, which would indicate functionally
113
+ corresponding neurons.
114
+
115
+ Args:
116
+ base_model_gate: First model with gate projection weights
117
+ base_model_up: First model with up projection weights
118
+ ft_model_gate: Second model with gate projection weights
119
+ ft_model_up: Second model with up projection weights
120
+ dataloader: DataLoader providing input data for activation collection
121
+ i: Layer index in the first model
122
+ j: Layer index in the second model
123
+
124
+ Returns:
125
+ float: Pearson correlation p-value between gate and up projection permutations
126
+ """
127
+ gate_match = mlp_matching(base_model_gate, ft_model_gate, dataloader, i, j)
128
+ up_match = mlp_matching(base_model_up, ft_model_up, dataloader, i, j)
129
+
130
+ print(gate_match, up_match, i, j)
131
+
132
+ cor, pvalue = scipy.stats.pearsonr(gate_match.tolist(), up_match.tolist())
133
+ return pvalue
134
+
135
+
136
+ def mlp_matching(base_model, ft_model, dataloader, i, j):
137
+ """
138
+ Match neurons between models by comparing activations.
139
+
140
+ Collects activations from the feed-forward layer for both models
141
+ and computes a permutation that would align corresponding neurons.
142
+
143
+ Args:
144
+ base_model: Base model to compare
145
+ ft_model: Target model to compare against the base model
146
+ dataloader: DataLoader providing input data for activation collection
147
+ i: Layer index in the base model
148
+ j: Layer index in the target model
149
+
150
+ Returns:
151
+ torch.Tensor: Permutation indices that match neurons between models
152
+ """
153
+ feats = defaultdict(list)
154
+
155
+ base_hook = lambda *args: hook_out(*args, feats, "base")
156
+ base_handle = base_model.transformer.h[i].mlp.c_fc.register_forward_hook(base_hook)
157
+
158
+ ft_hook = lambda *args: hook_out(*args, feats, "ft")
159
+ ft_handle = ft_model.transformer.h[i].mlp.c_fc.register_forward_hook(ft_hook)
160
+
161
+ evaluate(base_model, dataloader)
162
+ evaluate(ft_model, dataloader)
163
+
164
+ base_mat = torch.vstack(feats["base"])
165
+ ft_mat = torch.vstack(feats["ft"])
166
+
167
+ base_mat.to("cuda")
168
+ ft_mat.to("cuda")
169
+
170
+ base_mat = base_mat.view(-1, base_mat.shape[-1]).T
171
+ ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T
172
+
173
+ base_handle.remove()
174
+ ft_handle.remove()
175
+
176
+ perm = match_wmats(base_mat, ft_mat)
177
+ return perm
178
+
179
+
180
+ def run(i):
181
+ """
182
+ Run the generalized MATCH algorithm to compare models with distilled components.
183
+
184
+ This function:
185
+ 1. Loads two different GPT-2 models
186
+ 2. Trains custom MLPs to replicate the behavior of the original model MLPs
187
+ 3. Optionally applies random rotations to test invariance to representation changes
188
+ 4. Creates separate models for gate and up projections
189
+ 5. Compares the models using the MATCH algorithm
190
+
191
+ The process demonstrates that functionally equivalent components can be identified
192
+ even after representation changes, by examining activation patterns.
193
+
194
+ Args:
195
+ i: Layer index to focus on for the analysis
196
+
197
+ Returns:
198
+ None: Prints p-value results from the neuron matching
199
+ """
200
+ train_losses = []
201
+
202
+ model_id_2 = "manupande21/GPT2_PMC"
203
+ model_id_1 = "openai-community/gpt2"
204
+
205
+ tokenizer = AutoTokenizer.from_pretrained(model_id_1, use_fast=False)
206
+ model = AutoModelForCausalLM.from_pretrained(model_id_1, torch_dtype=torch.bfloat16)
207
+
208
+ base_tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
209
+ dataset_wikitext = prepare_hf_dataset("dlwh/wikitext_103_detokenized", 512, base_tokenizer)
210
+ dataloader_wikitext = prepare_hf_dataloader(dataset_wikitext, 1)
211
+
212
+ config = AutoConfig.from_pretrained(model_id_1)
213
+
214
+ i = 0 # layer to retrain
215
+ bsz = 5000 # batch size
216
+ T = 10000 # gradient steps
217
+ width_fac = 1.0 # easier to get loss down for wider MLPs when retraining
218
+
219
+ # Train the first custom MLP
220
+ mlp = CustomLlamaMLP(config).bfloat16()
221
+
222
+ mlp.to("cuda")
223
+ model.transformer.h[i].mlp.to("cuda")
224
+
225
+ criterion = torch.nn.MSELoss()
226
+ optimizer = torch.optim.Adam(mlp.parameters(), lr=0.001)
227
+ print(f"Training MLP {model_id_1}")
228
+
229
+ # Random rotation matrix to test invariance to representation changes
230
+ A = torch.randn(size=(EMB_SIZE, EMB_SIZE), device="cuda").bfloat16() / np.sqrt(
231
+ EMB_SIZE
232
+ ) # rotate outputs (just for kicks / sanity check)
233
+
234
+ # Distillation training loop for first model
235
+ for t in range(T):
236
+ X_batch = torch.randn(size=(bsz, EMB_SIZE), dtype=torch.bfloat16, device="cuda")
237
+ with torch.no_grad():
238
+ Y_batch = model.transformer.h[i].mlp(X_batch)
239
+ Y_batch = Y_batch @ A.T # Apply rotation to outputs
240
+
241
+ Y_h = mlp(X_batch)
242
+
243
+ optimizer.zero_grad()
244
+ loss = criterion(Y_h, Y_batch)
245
+
246
+ loss.backward()
247
+ optimizer.step()
248
+
249
+ if t % 1000 == 0:
250
+ print(f"train loss: {loss.item()}")
251
+ train_losses.append(loss.item())
252
+
253
+ # Create separate models for gate and up projections
254
+ config = AutoConfig.from_pretrained(model_id_1)
255
+ config.intermediate_size = int(width_fac * MLP_SIZE)
256
+
257
+ model_retrained_1_gate = OpenAIGPTLMHeadModel(config).bfloat16()
258
+ model_retrained_1_up = OpenAIGPTLMHeadModel(config).bfloat16()
259
+ model.to("cpu")
260
+ mlp.to("cpu")
261
+
262
+ # Loading retrained weights to model_retrained
263
+ weights_1_gate = model.state_dict()
264
+ weights_1_up = model.state_dict()
265
+
266
+ weights_1_gate[f"transformer.h.{i}.mlp.c_fc.weight"] = mlp.gate_proj1.weight.T
267
+ weights_1_up[f"transformer.h.{i}.mlp.c_fc.weight"] = mlp.up_proj1.weight.T
268
+ model_retrained_1_gate.load_state_dict(weights_1_gate)
269
+ model_retrained_1_up.load_state_dict(weights_1_up)
270
+
271
+ # Retraining / distilling second model
272
+ model = AutoModelForCausalLM.from_pretrained(model_id_2, torch_dtype=torch.bfloat16)
273
+
274
+ config = AutoConfig.from_pretrained(model_id_2)
275
+ mlp = CustomLlamaMLP(config).bfloat16()
276
+
277
+ mlp.to("cuda")
278
+ model.transformer.h[i].mlp.to("cuda")
279
+
280
+ criterion = torch.nn.MSELoss()
281
+ optimizer = torch.optim.Adam(mlp.parameters(), lr=0.001)
282
+
283
+ print(f"Training MLP {model_id_2}")
284
+
285
+ # Different random rotation matrix for second model
286
+ A = torch.randn(size=(EMB_SIZE_2, EMB_SIZE_2), device="cuda").bfloat16() / np.sqrt(
287
+ EMB_SIZE_2
288
+ ) # rotate outputs (just for kicks / sanity check)
289
+
290
+ # Distillation training loop for second model
291
+ for t in range(T):
292
+ X_batch = torch.randn(size=(bsz, EMB_SIZE_2), dtype=torch.bfloat16, device="cuda")
293
+ with torch.no_grad():
294
+ Y_batch = model.transformer.h[i].mlp(X_batch)
295
+ Y_batch = Y_batch @ A.T # Apply rotation to outputs
296
+
297
+ Y_h = mlp(X_batch)
298
+
299
+ optimizer.zero_grad()
300
+ loss = criterion(Y_h, Y_batch)
301
+
302
+ loss.backward()
303
+ optimizer.step()
304
+
305
+ if t % 1000 == 0:
306
+ print(f"train loss: {loss.item()}")
307
+ train_losses.append(loss.item())
308
+
309
+ # Create separate models for gate and up projections
310
+ config = AutoConfig.from_pretrained(model_id_2)
311
+ config.intermediate_size = int(width_fac * MLP_SIZE_2)
312
+
313
+ model_retrained_2_gate = GPT2LMHeadModel(config).bfloat16()
314
+ model_retrained_2_up = GPT2LMHeadModel(config).bfloat16()
315
+ model.to("cpu")
316
+ mlp.to("cpu")
317
+
318
+ weights_2_gate = model.state_dict()
319
+ weights_2_up = model.state_dict()
320
+
321
+ weights_2_gate[f"transformer.h.{i}.mlp.c_fc.weight"] = mlp.gate_proj1.weight.T
322
+ weights_2_up[f"transformer.h.{i}.mlp.c_fc.weight"] = mlp.up_proj1.weight.T
323
+
324
+ model_retrained_2_gate.load_state_dict(weights_2_gate)
325
+ model_retrained_2_up.load_state_dict(weights_2_up)
326
+
327
+ # Run MATCH algorithm to compare the models
328
+ print(
329
+ mlp_layers(
330
+ model_retrained_1_gate,
331
+ model_retrained_1_up,
332
+ model_retrained_2_gate,
333
+ model_retrained_2_up,
334
+ dataloader,
335
+ 0,
336
+ 0,
337
+ )
338
+ )
339
+
340
+
341
+ if __name__ == "__main__":
342
+ for i in range(0, 10):
343
+ run(i)
model-tracing/experiments/huref.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM
3
+
4
+ from tracing.utils.llama.model import rotate_model_t5
5
+
6
+ torch.set_default_dtype(torch.bfloat16)
7
+
8
+ model_name = "meta-llama/Llama-2-7b-hf"
9
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda")
10
+
11
+ model_rot_name = "yahma/llama-7b-hf"
12
+ model_rotated = AutoModelForCausalLM.from_pretrained(model_rot_name, torch_dtype=torch.bfloat16).to(
13
+ "cuda"
14
+ )
15
+ rotate_model_t5(model_rotated)
16
+ # rotate_model(model_rotated)
17
+
18
+ # Fixing the layer norms to 1's (HUReF works)
19
+ # """
20
+ # fix_layer_norm(model)
21
+ # fix_layer_norm(model_rotated)
22
+ # """
23
+
24
+ # base_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
25
+ # dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized",512,base_tokenizer)
26
+ # dataloader = prepare_hf_dataloader(dataset,1)
27
+
28
+ # evaluate_model = evaluate(model, dataloader)
29
+ # evaluate_rotated = evaluate(model_rotated, dataloader)
30
+
31
+ # print("outputs are aligned: ")
32
+ # print([abs(evaluate_model[i] - evaluate_rotated[i]) <= 0.01 for i in range(len(evaluate_model))])
33
+
34
+ weights = model.state_dict()
35
+ weights_rotated = model_rotated.state_dict()
36
+
37
+ # model.to('cuda')
38
+ # print("invariant 1")
39
+ # print(weights['model.embed_tokens.weight']@weights['model.layers.0.self_attn.q_proj.weight'].T@weights['model.layers.0.self_attn.k_proj.weight']@weights['model.embed_tokens.weight'].T)
40
+ # print("invariant 2")
41
+ # print(weights['model.embed_tokens.weight']@weights['model.layers.0.self_attn.v_proj.weight'].T@weights['model.layers.0.self_attn.o_proj.weight'].T@weights['model.embed_tokens.weight'].T)
42
+ # print("invariant 3")
43
+ # print(weights['model.embed_tokens.weight']@weights[f'model.layers.0.mlp.up_proj.weight'].T@weights[f'model.layers.0.mlp.down_proj.weight'].T@weights['model.embed_tokens.weight'].T)
44
+ # print()
45
+ # model.to('cpu')
46
+
47
+ # model_rotated.to('cuda')
48
+ # print("rotated")
49
+ # print("invariant 1")
50
+ # print(weights_rotated['model.embed_tokens.weight']@weights_rotated['model.layers.0.self_attn.q_proj.weight'].T@weights_rotated['model.layers.0.self_attn.k_proj.weight']@weights_rotated['model.embed_tokens.weight'].T)
51
+
52
+
53
+ # print("invariant 2")
54
+ # print(weights_rotated['model.embed_tokens.weight']@weights_rotated['model.layers.0.self_attn.v_proj.weight'].T@weights_rotated['model.layers.0.self_attn.o_proj.weight'].T@weights_rotated['model.embed_tokens.weight'].T)
55
+ # print("invariant 3")
56
+ # print(weights_rotated['model.embed_tokens.weight']@weights_rotated[f'model.layers.0.mlp.up_proj.weight'].T@weights_rotated[f'model.layers.0.mlp.down_proj.weight'].T@weights_rotated['model.embed_tokens.weight'].T)
57
+ # print()
58
+ # model_rotated.to('cpu')
59
+
60
+ # Cosine similarity
61
+
62
+ print("cosine similarity")
63
+
64
+ model.to("cuda")
65
+ print("invariant 1")
66
+ invariant = (
67
+ weights["model.embed_tokens.weight"]
68
+ @ weights["model.layers.0.self_attn.q_proj.weight"].T
69
+ @ weights["model.layers.0.self_attn.k_proj.weight"]
70
+ @ (weights["model.embed_tokens.weight"]).T
71
+ )
72
+ model.to("cpu")
73
+ model_rotated.to("cuda")
74
+ invariant_rotated = (
75
+ weights_rotated["model.embed_tokens.weight"]
76
+ @ weights_rotated["model.layers.0.self_attn.q_proj.weight"].T
77
+ @ weights_rotated["model.layers.0.self_attn.k_proj.weight"]
78
+ @ (weights_rotated["model.embed_tokens.weight"]).T
79
+ )
80
+ model_rotated.to("cpu")
81
+ invariant.to("cuda")
82
+ invariant_rotated.to("cuda")
83
+ invariant = torch.flatten(invariant)
84
+ invariant_rotated = torch.flatten(invariant_rotated)
85
+ print(
86
+ torch.dot(invariant, invariant_rotated)
87
+ / (torch.norm(invariant) * torch.norm(invariant_rotated))
88
+ )
89
+
90
+ model.to("cuda")
91
+ print("invariant 2")
92
+ invariant = (
93
+ weights["model.embed_tokens.weight"]
94
+ @ weights["model.layers.0.self_attn.v_proj.weight"].T
95
+ @ weights["model.layers.0.self_attn.o_proj.weight"].T
96
+ @ weights["model.embed_tokens.weight"].T
97
+ )
98
+ model.to("cpu")
99
+ model_rotated.to("cuda")
100
+ invariant_rotated = (
101
+ weights_rotated["model.embed_tokens.weight"]
102
+ @ weights_rotated["model.layers.0.self_attn.v_proj.weight"].T
103
+ @ weights_rotated["model.layers.0.self_attn.o_proj.weight"].T
104
+ @ weights_rotated["model.embed_tokens.weight"].T
105
+ )
106
+ model_rotated.to("cpu")
107
+ invariant.to("cuda")
108
+ invariant_rotated.to("cuda")
109
+ invariant = torch.flatten(invariant)
110
+ invariant_rotated = torch.flatten(invariant_rotated)
111
+ print(
112
+ torch.dot(invariant, invariant_rotated)
113
+ / (torch.norm(invariant) * torch.norm(invariant_rotated))
114
+ )
115
+
116
+ model.to("cuda")
117
+ print("invariant 3")
118
+ invariant = (
119
+ weights["model.embed_tokens.weight"]
120
+ @ weights["model.layers.0.mlp.up_proj.weight"].T
121
+ @ weights["model.layers.0.mlp.down_proj.weight"].T
122
+ @ weights["model.embed_tokens.weight"].T
123
+ )
124
+ model.to("cpu")
125
+ model_rotated.to("cuda")
126
+ invariant_rotated = (
127
+ weights_rotated["model.embed_tokens.weight"]
128
+ @ weights_rotated["model.layers.0.mlp.up_proj.weight"].T
129
+ @ weights_rotated["model.layers.0.mlp.down_proj.weight"].T
130
+ @ weights_rotated["model.embed_tokens.weight"].T
131
+ )
132
+ model_rotated.to("cpu")
133
+ invariant.to("cuda")
134
+ invariant_rotated.to("cuda")
135
+ invariant = torch.flatten(invariant)
136
+ invariant_rotated = torch.flatten(invariant_rotated)
137
+ print(
138
+ torch.dot(invariant, invariant_rotated)
139
+ / (torch.norm(invariant) * torch.norm(invariant_rotated))
140
+ )
model-tracing/experiments/localized_testing.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Localized testing experiments for two models.
3
+ Runs \phi_MATCH on all pairs of GLU MLPs between two models and identifies a match.
4
+ Also can uncomment code to print the aligned activations.
5
+
6
+ To run: Use HuggingFace model Ids in Lines 104-05.
7
+ """
8
+
9
+ import torch
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM
11
+
12
+ from tracing.utils.evaluate import evaluate
13
+ from tracing.utils.evaluate import prepare_hf_dataset, prepare_hf_dataloader
14
+ from tracing.statistics.mlp_sp import hook_out
15
+ from tracing.utils.evaluate import (
16
+ prepare_hf_dataset,
17
+ prepare_hf_dataloader,
18
+ evaluate,
19
+ )
20
+ from tracing.utils.llama.matching import match_wmats
21
+ from collections import defaultdict
22
+
23
+ import scipy
24
+ import warnings
25
+ import numpy as np
26
+
27
+ warnings.filterwarnings("ignore")
28
+
29
+
30
+ def mlp_matching_gate(base_model, ft_model, dataloader, i, j):
31
+ feats = defaultdict(list)
32
+
33
+ base_hook = lambda *args: hook_out(*args, feats, "base")
34
+ base_handle = base_model.model.layers[i].mlp.gate_proj.register_forward_hook(base_hook)
35
+
36
+ ft_hook = lambda *args: hook_out(*args, feats, "ft")
37
+ ft_handle = ft_model.model.layers[j].mlp.gate_proj.register_forward_hook(ft_hook)
38
+
39
+ evaluate(base_model, dataloader)
40
+ evaluate(ft_model, dataloader)
41
+
42
+ base_mat = torch.vstack(feats["base"])
43
+ ft_mat = torch.vstack(feats["ft"])
44
+
45
+ base_mat.to("cuda")
46
+ ft_mat.to("cuda")
47
+
48
+ base_mat = base_mat.view(-1, base_mat.shape[-1]).T
49
+ ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T
50
+
51
+ # If want to print the activations matching for localized testing (See Llama3.2-3B and Llama3.1-8B activation matching experiment)
52
+
53
+ """
54
+ ft_mat = torch.norm(ft_mat,dim=1)
55
+ sorted = torch.sort(torch.argsort(ft_mat)[:8192])[0]
56
+ for i in sorted:
57
+ print(i.item(),end=" ")
58
+ """
59
+
60
+ base_handle.remove()
61
+ ft_handle.remove()
62
+
63
+ perm = match_wmats(base_mat, ft_mat)
64
+
65
+ return perm
66
+
67
+
68
+ def mlp_matching_up(base_model, ft_model, dataloader, i, j):
69
+ feats = defaultdict(list)
70
+
71
+ base_hook = lambda *args: hook_out(*args, feats, "base")
72
+ base_handle = base_model.model.layers[i].mlp.up_proj.register_forward_hook(base_hook)
73
+
74
+ ft_hook = lambda *args: hook_out(*args, feats, "ft")
75
+ ft_handle = ft_model.model.layers[j].mlp.up_proj.register_forward_hook(ft_hook)
76
+
77
+ evaluate(base_model, dataloader)
78
+ evaluate(ft_model, dataloader)
79
+
80
+ base_mat = torch.vstack(feats["base"])
81
+ ft_mat = torch.vstack(feats["ft"])
82
+
83
+ base_mat.to("cuda")
84
+ ft_mat.to("cuda")
85
+
86
+ base_mat = base_mat.view(-1, base_mat.shape[-1]).T
87
+ ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T
88
+
89
+ base_handle.remove()
90
+ ft_handle.remove()
91
+
92
+ perm = match_wmats(base_mat, ft_mat)
93
+
94
+ return perm
95
+
96
+
97
+ def mlp_layers(base_model, ft_model, dataloader, i, j):
98
+
99
+ gate_match = mlp_matching_gate(base_model, ft_model, dataloader, i, j)
100
+ up_match = mlp_matching_up(base_model, ft_model, dataloader, i, j)
101
+
102
+ for g in gate_match:
103
+ print(g.item(), end=" ")
104
+
105
+ cor, pvalue = scipy.stats.pearsonr(gate_match.tolist(), up_match.tolist())
106
+
107
+ return pvalue
108
+
109
+
110
+ def main():
111
+ model_1_id = "meta-llama/Llama-2-7b-hf"
112
+ model_2_id = "princeton-nlp/Sheared-LLaMA-2.7B"
113
+
114
+ print(model_1_id, model_2_id)
115
+
116
+ model_1 = AutoModelForCausalLM.from_pretrained(model_1_id, torch_dtype=torch.bfloat16)
117
+ model_2 = AutoModelForCausalLM.from_pretrained(model_2_id, torch_dtype=torch.bfloat16)
118
+
119
+ tokenizer = AutoTokenizer.from_pretrained(model_1_id)
120
+
121
+ dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized", 512, tokenizer)
122
+ dataloader = prepare_hf_dataloader(dataset, 1)
123
+
124
+ print(model_1.config.num_hidden_layers, model_2.config.num_hidden_layers)
125
+
126
+ model_1_matched = np.zeros(model_1.config.num_hidden_layers)
127
+ model_2_matched = np.zeros(model_2.config.num_hidden_layers)
128
+
129
+ for i in range(model_1.config.num_hidden_layers):
130
+ for j in range(model_2.config.num_hidden_layers):
131
+ if model_1_matched[i] == 1 or model_2_matched[j] == 1:
132
+ continue
133
+ stat = mlp_layers(model_1, model_2, dataloader, i, j)
134
+ print(i, j, stat)
135
+ if stat < 0.000001:
136
+ model_1_matched[i] = 1
137
+ model_2_matched[j] = 1
138
+ break
139
+ break
140
+ break
141
+
142
+
143
+ if __name__ == "__main__":
144
+ main()
model-tracing/launch.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ from yaml import Loader
3
+
4
+ import subprocess
5
+ import argparse
6
+
7
+
8
+ parser = argparse.ArgumentParser(description="Experiment Settings")
9
+
10
+ parser.add_argument("--slurm", default="nlprun -g 1 -d a6000 -r 80G -a model-tracing", type=str)
11
+ parser.add_argument("--python", default="python experiment.py", type=str)
12
+ parser.add_argument("--models", default="config/llama_flat.yaml", type=str)
13
+ parser.add_argument("--save", default="/juice4/scr4/nlp/model-tracing", type=str)
14
+ parser.add_argument("--flat", default="all", type=str)
15
+
16
+ args = parser.parse_args()
17
+
18
+ model_paths = yaml.load(open(args.models, "r"), Loader=Loader)
19
+
20
+ subprocess.run(f"mkdir -p {args.save}/logs", shell=True)
21
+ subprocess.run(f"mkdir -p {args.save}/results", shell=True)
22
+
23
+ if args.flat == "all":
24
+ for i in range(len(model_paths)):
25
+ for j in range(i + 1, len(model_paths)):
26
+ model_a = model_paths[i]
27
+ model_b = model_paths[j]
28
+
29
+ job_id = model_a.replace("/", "-") + "_AND_" + model_b.replace("/", "-")
30
+ if "miqu" not in job_id:
31
+ continue
32
+ if job_id[0] == "-":
33
+ job_id = job_id[1:]
34
+
35
+ log_path = args.save + "/logs/" + job_id + ".out"
36
+ results_path = args.save + "/results/" + job_id + ".p"
37
+
38
+ job = (
39
+ args.slurm + f" -o {log_path} -n {job_id}"
40
+ f" '{args.python}"
41
+ + f" --base_model_id {model_a} --ft_model_id {model_b} --save {results_path}'"
42
+ )
43
+ subprocess.run(job, shell=True)
44
+
45
+ elif args.flat == "split":
46
+ base_models = model_paths["base_models"]
47
+ ft_models = model_paths["ft_models"]
48
+
49
+ for base_model in base_models:
50
+ for ft_model in ft_models:
51
+ job_id = base_model.replace("/", "-") + "_AND_" + ft_model.replace("/", "-")
52
+
53
+ log_path = args.save + "/logs/" + job_id + ".out"
54
+ results_path = args.save + "/results/" + job_id + ".p"
55
+
56
+ job = (
57
+ args.slurm + f" -o {log_path} -n {job_id}"
58
+ f" '{args.python}"
59
+ + f" --base_model_id {base_model} --ft_model_id {ft_model} --save {results_path}'"
60
+ )
61
+ subprocess.run(job, shell=True)
62
+
63
+ elif args.flat == "specified":
64
+ models_1 = model_paths["models_1"]
65
+ models_2 = model_paths["models_2"]
66
+
67
+ names_1 = model_paths["names_1"]
68
+ names_2 = model_paths["names_2"]
69
+
70
+ for i in range(len(models_1)):
71
+ model_a = models_1[i]
72
+ model_b = models_2[i]
73
+
74
+ name_a = names_1[i]
75
+ name_b = names_2[i]
76
+
77
+ job_id = name_a.replace("/", "-") + "_AND_" + name_b.replace("/", "-")
78
+
79
+ log_path = args.save + "/logs/" + job_id + ".out"
80
+ results_path = args.save + "/results/" + job_id + ".p"
81
+
82
+ job = (
83
+ args.slurm + f" -o {log_path} -n {job_id}"
84
+ f" '{args.python}"
85
+ + f" --base_model_id {model_a} --ft_model_id {model_b} --save {results_path}'"
86
+ )
87
+ subprocess.run(job, shell=True)
model-tracing/main.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MLP_SIZE = 11008
2
+ EMB_SIZE = 4096
3
+
4
+ import torch
5
+ from transformers import (
6
+ AutoTokenizer,
7
+ AutoModelForCausalLM,
8
+ GPTNeoXTokenizerFast,
9
+ )
10
+
11
+ import argparse
12
+ import pickle
13
+ import timeit
14
+ import subprocess
15
+ import os
16
+
17
+ from tracing.utils.llama.model import permute_model, rotate_model
18
+ from tracing.utils.olmo.model import permute_model as permute_model_olmo
19
+ from tracing.utils.llama.matching import align_model
20
+ from tracing.utils.evaluate import (
21
+ prepare_hf_dataset,
22
+ prepare_aya_dataset,
23
+ prepare_hf_dataloader,
24
+ evaluate,
25
+ load_dolma_programming_datasets,
26
+ load_m2d2_datasets,
27
+ load_generated_datasets,
28
+ prepare_random_sample_dataset,
29
+ )
30
+ from tracing.utils.utils import manual_seed
31
+
32
+ from tracing.statistics.mc import statistic as mode_stat
33
+ from tracing.statistics.l2 import statistic as l2_stat
34
+ from tracing.statistics.jsd import statistic as jsd_stat
35
+ from tracing.statistics.csu import statistic as csu_stat
36
+ from tracing.statistics.csu import statistic_all as csu_all_stat
37
+ from tracing.statistics.csh import statistic as csh_stat
38
+ from tracing.statistics.match import statistic as match_stat
39
+ from tracing.statistics.match import statistic_all as match_all_stat
40
+ from tracing.statistics.perm_mc_l2 import statistic as perm_mc_l2_stat
41
+
42
+ parser = argparse.ArgumentParser(description="Experiment Settings")
43
+
44
+ parser.add_argument("--base_model_id", default="meta-llama/Llama-2-7b-hf", type=str)
45
+ parser.add_argument("--ft_model_id", default="lmsys/vicuna-7b-v1.1", type=str)
46
+
47
+ parser.add_argument("--permute", action="store_true")
48
+ parser.add_argument("--rotate", action="store_true")
49
+ parser.add_argument("--align", action="store_true")
50
+
51
+ parser.add_argument("--dataset", default="wikitext", type=str)
52
+ parser.add_argument("--block_size", default=512, type=int)
53
+ parser.add_argument("--batch_size", default=1, type=int)
54
+
55
+ parser.add_argument("--save", default="results.p", type=str)
56
+ parser.add_argument("--seed", default=0, type=int)
57
+ parser.add_argument("--alpha", default=0.5, type=float)
58
+ parser.add_argument("--token", default="", type=str)
59
+
60
+ parser.add_argument("--stat", default="mode", type=str)
61
+ parser.add_argument("--attn", action="store_true")
62
+ parser.add_argument("--emb", action="store_true")
63
+ parser.add_argument("--num_perm", default=99, type=int)
64
+
65
+
66
+ parser.add_argument("--eval", action="store_true")
67
+
68
+ parser.add_argument(
69
+ "--aya_subset", default="aya_human_annotated", type=str, help="Subset of Aya dataset"
70
+ )
71
+ parser.add_argument("--aya_language", default="eng", type=str, help="Language code for Aya dataset")
72
+
73
+ args = parser.parse_args()
74
+
75
+
76
+ from huggingface_hub import login
77
+
78
+ if args.token == "":
79
+ hf_token = os.environ["HF_TOKEN"]
80
+ else:
81
+ hf_token = args.token
82
+ login(token=hf_token)
83
+
84
+ start = timeit.default_timer()
85
+
86
+ results = {}
87
+ results["args"] = args
88
+ results["commit"] = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()
89
+
90
+ # fix seed on torch, np and random
91
+ manual_seed(args.seed)
92
+
93
+ dtype = torch.bfloat16
94
+ low_cpu_mem_usage = (
95
+ "70b" in args.base_model_id.lower()
96
+ ) # Enable low memory loading if "70b" is in model name
97
+
98
+ print(f"Low CPU Mem Usage Flag set to {low_cpu_mem_usage}")
99
+ base_model = AutoModelForCausalLM.from_pretrained(
100
+ args.base_model_id, torch_dtype=dtype, low_cpu_mem_usage=low_cpu_mem_usage
101
+ )
102
+ if "olmo" in args.base_model_id.lower():
103
+ tokenizer_name = (
104
+ "allenai/OLMo-1.7-7B-hf" if "olmo" in args.base_model_id.lower() else args.base_model_id
105
+ )
106
+ base_tokenizer = GPTNeoXTokenizerFast.from_pretrained(tokenizer_name, use_fast=False)
107
+ elif "Alfred" in args.base_model_id:
108
+ base_tokenizer = AutoTokenizer.from_pretrained(args.base_model_id)
109
+ elif "Salesforce" in args.base_model_id:
110
+ base_tokenizer = AutoTokenizer.from_pretrained(args.ft_model_id, trust_remote_code=True)
111
+ else:
112
+ base_tokenizer = AutoTokenizer.from_pretrained(args.base_model_id, use_fast=False)
113
+
114
+ ft_model = AutoModelForCausalLM.from_pretrained(args.ft_model_id, torch_dtype=dtype)
115
+ if "olmo" in args.ft_model_id.lower():
116
+ tokenizer_name = (
117
+ "allenai/OLMo-1.7-7B-hf" if "olmo" in args.ft_model_id.lower() else args.ft_model_id
118
+ )
119
+ ft_tokenizer = GPTNeoXTokenizerFast.from_pretrained(tokenizer_name, use_fast=False)
120
+ elif "Alfred" in args.ft_model_id:
121
+ ft_tokenizer = AutoTokenizer.from_pretrained(args.ft_model_id)
122
+ elif "Salesforce" in args.ft_model_id:
123
+ ft_tokenizer = AutoTokenizer.from_pretrained(args.base_model_id, trust_remote_code=True)
124
+ else:
125
+ ft_tokenizer = AutoTokenizer.from_pretrained(args.ft_model_id, use_fast=False)
126
+
127
+ print("base and ft models loaded")
128
+
129
+ if args.permute is True:
130
+ mlp_permutation = torch.randperm(MLP_SIZE)
131
+ emb_permutation = torch.randperm(EMB_SIZE)
132
+ if "olmo" in args.base_model_id.lower():
133
+ permute_model_olmo(base_model, ft_model, mlp_permutation, emb_permutation)
134
+ else:
135
+ permute_model(base_model, ft_model, mlp_permutation, emb_permutation)
136
+ print("ft model permuted")
137
+
138
+ if args.rotate is True:
139
+ rotate_model(ft_model)
140
+ print("ft model rotated")
141
+
142
+ if "70b" in args.base_model_id.lower() and "70b" in args.ft_model_id.lower():
143
+ # skip tmp_model
144
+ tmp_model = None
145
+ elif args.stat == "mode":
146
+ tmp_model = AutoModelForCausalLM.from_pretrained(args.base_model_id, torch_dtype=dtype)
147
+ # tmp_tokenizer is unused
148
+
149
+ if args.dataset == "wikitext":
150
+ dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized", args.block_size, base_tokenizer)
151
+ dataloader = prepare_hf_dataloader(dataset, args.batch_size)
152
+ elif args.dataset == "aya":
153
+ dataset = prepare_aya_dataset(
154
+ args.aya_subset, args.aya_language, args.block_size, base_tokenizer
155
+ )
156
+ dataloader = prepare_hf_dataloader(dataset, args.batch_size)
157
+ elif args.dataset.startswith("dolma_"):
158
+ language = args.dataset.split("_")[1]
159
+ if not language and language is not None:
160
+ raise ValueError("Language is an empty string")
161
+ columns_ignored = [
162
+ "text",
163
+ "added",
164
+ "id",
165
+ "lang",
166
+ "metadata",
167
+ "source",
168
+ "timestamp",
169
+ "subdomain",
170
+ ]
171
+ dataset = load_dolma_programming_datasets(
172
+ language, args.block_size, base_tokenizer, columns_ignored
173
+ )
174
+ dataloader = prepare_hf_dataloader(dataset, args.batch_size)
175
+ elif args.dataset.startswith("m2d2_"):
176
+ test_case = args.dataset.split("_")[1]
177
+ if not test_case:
178
+ raise ValueError("Invalid m2d2 dataset format. Use 'm2d2_testcase' (e.g., 'm2d2_AI')")
179
+ columns_ignored = ["text", "added", "id", "source", "subdomain"]
180
+ dataset = load_m2d2_datasets(test_case, args.block_size, base_tokenizer, columns_ignored)
181
+ dataloader = prepare_hf_dataloader(dataset, args.batch_size)
182
+ elif args.dataset == "generated":
183
+ columns_ignored = ["text"]
184
+ dataset = load_generated_datasets(
185
+ args.base_model_id, args.ft_model_id, args.block_size, base_tokenizer, columns_ignored
186
+ )
187
+ dataloader = prepare_hf_dataloader(dataset, args.batch_size)
188
+ elif args.dataset == "random":
189
+ dataset = prepare_random_sample_dataset(20, args.block_size)
190
+ dataloader = prepare_hf_dataloader(dataset, args.batch_size)
191
+
192
+ else:
193
+ raise ValueError(f"Unknown dataset: {args.dataset}")
194
+
195
+ print("dataset loaded")
196
+
197
+ if args.stat == "mode":
198
+ test_stat = lambda base_model, ft_model: mode_stat(
199
+ base_model, ft_model, tmp_model, dataloader, args.attn, args.emb, args.alpha
200
+ )
201
+ results["alpha"] = args.alpha
202
+ if args.stat == "l2":
203
+ test_stat = lambda base_model, ft_model: l2_stat(base_model, ft_model)
204
+ if args.stat == "jsd":
205
+ test_stat = lambda base_model, ft_model: jsd_stat(base_model, ft_model, dataloader)
206
+
207
+ if args.stat == "csu":
208
+ test_stat = lambda base_model, ft_model: csu_stat(base_model, ft_model)
209
+ if args.stat == "csu_all":
210
+ test_stat = lambda base_model, ft_model: csu_all_stat(base_model, ft_model)
211
+ if args.stat == "csh_sp":
212
+ test_stat = lambda base_model, ft_model: csh_stat(base_model, ft_model, dataloader)
213
+
214
+ if args.stat == "match":
215
+ test_stat = lambda base_model, ft_model: match_stat(base_model, ft_model, dataloader)
216
+ if args.stat == "match_all":
217
+ test_stat = lambda base_model, ft_model: match_all_stat(base_model, ft_model, dataloader)
218
+
219
+ if args.stat == "perm_mc_l2":
220
+ mc = lambda base_model, ft_model: mode_stat(
221
+ base_model, ft_model, tmp_model, dataloader, args.attn, args.emb
222
+ )
223
+ l2 = lambda base_model, ft_model: l2_stat(base_model, ft_model)
224
+ test_stat = lambda base_model, ft_model: perm_mc_l2_stat(
225
+ base_model, ft_model, mc, l2, args.num_perm
226
+ )
227
+
228
+ if args.eval is True:
229
+ results["base loss"] = sum(evaluate(base_model, dataloader))
230
+ results["ft loss"] = sum(evaluate(ft_model, dataloader))
231
+ print("losses evaluated")
232
+
233
+ results["non-aligned test stat"] = test_stat(base_model, ft_model)
234
+
235
+ print("non-aligned stat computed")
236
+
237
+ if args.align is True:
238
+ align_model(base_model, ft_model, ft_model)
239
+ results["aligned test stat"] = test_stat(base_model, ft_model)
240
+ print("aligned stat computed")
241
+
242
+ end = timeit.default_timer()
243
+ results["time"] = end - start
244
+
245
+ print(results)
246
+ pickle.dump(results, open(args.save, "wb"))
model-tracing/requirements-dev.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ black==24.1.1
2
+ ruff==0.1.8
3
+ pre-commit==3.5.0
4
+ nbqa==1.7.1
5
+ ipykernel==6.29.0
model-tracing/requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.31.0
2
+ data_processing==0.1.4
3
+ datasets==2.18.0
4
+ huggingface-hub==0.27.1
5
+ matplotlib==3.8.4
6
+ numpy==1.26.4
7
+ pandas==2.2.2
8
+ PyYAML==6.0.1
9
+ PyYAML==6.0.1
10
+ scipy==1.13.1
11
+ torch==2.2.2
12
+ tqdm==4.66.2
13
+ transformers==4.40.0
14
+ sentencepiece==0.2.0
15
+ protobuf==5.27.1
16
+ zstandard==0.22.0
17
+ ipdb==0.13.13
model-tracing/results/jsd/model_pairs_jsd.csv ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Model Pair,jsd_wikitext_batch, jsd_test_input
2
+ meta-llama/Llama-2-7b-hf vs codellama/CodeLlama-7b-hf,1.9561830413294956e-06,3.245754214731278e-06
3
+ meta-llama/Llama-2-7b-hf vs openlm-research/open_llama_7b,1.729454925225582e-05,2.0375082385726273e-05
4
+ meta-llama/Llama-2-7b-hf vs huggyllama/llama-7b,2.1560920231422642e-06,1.0297326298314147e-06
5
+ meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.5,1.5276991689461283e-06,4.193487256998196e-06
6
+ meta-llama/Llama-2-7b-hf vs EleutherAI/llemma_7b,2.434112047922099e-06,4.357912075647619e-06
7
+ meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.1,3.127033551209024e-06,3.4884374144894537e-06
8
+ meta-llama/Llama-2-7b-hf vs microsoft/Orca-2-7b,2.6865156996791484e-06,4.245463969709817e-06
9
+ meta-llama/Llama-2-7b-hf vs LLM360/Amber,1.8514610928832553e-06,3.480401119304588e-06
10
+ codellama/CodeLlama-7b-hf vs openlm-research/open_llama_7b,1.7027690773829818e-05,1.983477886824403e-05
11
+ codellama/CodeLlama-7b-hf vs huggyllama/llama-7b,2.795685531964409e-06,3.3317055567749776e-06
12
+ codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.5,3.2359234864998143e-06,4.885899215878453e-06
13
+ codellama/CodeLlama-7b-hf vs EleutherAI/llemma_7b,1.8831561874321778e-06,4.028150215162896e-06
14
+ codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.1,3.955687134293839e-06,4.28245175498887e-06
15
+ codellama/CodeLlama-7b-hf vs microsoft/Orca-2-7b,3.917118647223106e-06,5.343406428437447e-06
16
+ codellama/CodeLlama-7b-hf vs LLM360/Amber,2.043052063527284e-06,3.20175354318053e-06
17
+ openlm-research/open_llama_7b vs huggyllama/llama-7b,1.7465732526034117e-05,2.0448682334972546e-05
18
+ openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.5,1.7609712813282385e-05,2.049213435384445e-05
19
+ openlm-research/open_llama_7b vs EleutherAI/llemma_7b,1.706841794657521e-05,2.0151202988927253e-05
20
+ openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.1,1.7788930563256145e-05,2.0416322513483465e-05
21
+ openlm-research/open_llama_7b vs microsoft/Orca-2-7b,1.767633148119785e-05,2.018710074480623e-05
22
+ openlm-research/open_llama_7b vs LLM360/Amber,1.7064834537450224e-05,2.0156170648988336e-05
23
+ huggyllama/llama-7b vs lmsys/vicuna-7b-v1.5,3.4697009141382296e-06,4.354528755357023e-06
24
+ huggyllama/llama-7b vs EleutherAI/llemma_7b,3.1778936318005435e-06,4.261534741090145e-06
25
+ huggyllama/llama-7b vs lmsys/vicuna-7b-v1.1,1.5405695421577548e-06,2.4467876755807083e-06
26
+ huggyllama/llama-7b vs microsoft/Orca-2-7b,4.071262083016336e-06,4.621022526407614e-06
27
+ huggyllama/llama-7b vs LLM360/Amber,2.3568713913846295e-06,3.4077993404935114e-06
28
+ lmsys/vicuna-7b-v1.5 vs EleutherAI/llemma_7b,3.499165813991567e-06,5.325471647665836e-06
29
+ lmsys/vicuna-7b-v1.5 vs lmsys/vicuna-7b-v1.1,3.2389764328399906e-06,3.7683282698708354e-06
30
+ lmsys/vicuna-7b-v1.5 vs microsoft/Orca-2-7b,2.189671249652747e-06,3.5288214803586015e-06
31
+ lmsys/vicuna-7b-v1.5 vs LLM360/Amber,2.9944339985377155e-06,4.550915946310852e-06
32
+ EleutherAI/llemma_7b vs lmsys/vicuna-7b-v1.1,4.045330570079386e-06,5.489064733410487e-06
33
+ EleutherAI/llemma_7b vs microsoft/Orca-2-7b,3.952619408664759e-06,6.314553502306808e-06
34
+ EleutherAI/llemma_7b vs LLM360/Amber,2.333970769541338e-06,3.391487553017214e-06
35
+ lmsys/vicuna-7b-v1.1 vs microsoft/Orca-2-7b,3.713232445079484e-06,4.426544364832807e-06
36
+ lmsys/vicuna-7b-v1.1 vs LLM360/Amber,3.438890416873619e-06,4.203274784231326e-06
37
+ microsoft/Orca-2-7b vs LLM360/Amber,3.5631983337225392e-06,5.076844900031574e-06
model-tracing/results/l2/model_pairs_l2.csv ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Model Pair,l2
2
+ meta-llama/Llama-2-7b-hf vs codellama/CodeLlama-7b-hf,93.18718912197232
3
+ meta-llama/Llama-2-7b-hf vs openlm-research/open_llama_7b,122.27028296821305
4
+ meta-llama/Llama-2-7b-hf vs huggyllama/llama-7b,122.27950493986255
5
+ meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.5,4.1913064124248285
6
+ meta-llama/Llama-2-7b-hf vs EleutherAI/llemma_7b,88.74339722102076
7
+ meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.1,122.32382946735395
8
+ meta-llama/Llama-2-7b-hf vs microsoft/Orca-2-7b,5.711168968966263
9
+ meta-llama/Llama-2-7b-hf vs LLM360/Amber,148.0270618556701
10
+ codellama/CodeLlama-7b-hf vs openlm-research/open_llama_7b,140.50532547577853
11
+ codellama/CodeLlama-7b-hf vs huggyllama/llama-7b,140.80544442041523
12
+ codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.5,93.25281141868513
13
+ codellama/CodeLlama-7b-hf vs EleutherAI/llemma_7b,49.639849790592784
14
+ codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.1,140.8399383650519
15
+ codellama/CodeLlama-7b-hf vs microsoft/Orca-2-7b,93.11375432525952
16
+ codellama/CodeLlama-7b-hf vs LLM360/Amber,163.1440311418685
17
+ openlm-research/open_llama_7b vs huggyllama/llama-7b,131.92685513316152
18
+ openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.5,122.38108086340206
19
+ openlm-research/open_llama_7b vs EleutherAI/llemma_7b,133.5973724048443
20
+ openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.1,131.96828017611685
21
+ openlm-research/open_llama_7b vs microsoft/Orca-2-7b,120.33676200259515
22
+ openlm-research/open_llama_7b vs LLM360/Amber,156.87263745704468
23
+ huggyllama/llama-7b vs lmsys/vicuna-7b-v1.5,122.38058419243987
24
+ huggyllama/llama-7b vs EleutherAI/llemma_7b,134.03754865916954
25
+ huggyllama/llama-7b vs lmsys/vicuna-7b-v1.1,3.2305828500859106
26
+ huggyllama/llama-7b vs microsoft/Orca-2-7b,120.63673226643598
27
+ huggyllama/llama-7b vs LLM360/Amber,156.8477233676976
28
+ lmsys/vicuna-7b-v1.5 vs EleutherAI/llemma_7b,88.80338316392734
29
+ lmsys/vicuna-7b-v1.5 vs lmsys/vicuna-7b-v1.1,122.41473367697594
30
+ lmsys/vicuna-7b-v1.5 vs microsoft/Orca-2-7b,6.786975900194637
31
+ lmsys/vicuna-7b-v1.5 vs LLM360/Amber,148.06894329896906
32
+ EleutherAI/llemma_7b vs lmsys/vicuna-7b-v1.1,134.06379757785467
33
+ EleutherAI/llemma_7b vs microsoft/Orca-2-7b,88.6362254000865
34
+ EleutherAI/llemma_7b vs LLM360/Amber,156.21647923875432
35
+ lmsys/vicuna-7b-v1.1 vs microsoft/Orca-2-7b,120.66557634083046
36
+ lmsys/vicuna-7b-v1.1 vs LLM360/Amber,156.87929553264604
37
+ microsoft/Orca-2-7b vs LLM360/Amber,145.12435121107268
model-tracing/results/perm/permutation_l2_updated_midpoint_wikitext_single.csv ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Model Pair,l2 p-value,unpermuted l2,permuted l2s
2
+ meta-llama/Llama-2-7b-hf vs codellama/CodeLlama-7b-hf,0.01,93.18718912197232,132.05157871972318,132.05082179930795,132.0419009515571,132.04476643598616,132.03995458477507,132.04484753460207,132.04233347750866,132.04203611591694,132.02357266435988,132.0313310986159,132.04133326124568,132.0299794550173,132.041035899654,132.03403438581316,132.04498269896195,132.05074070069205,132.03608888408306,132.02641111591694,132.0261678200692,132.03270977508652,132.0413062283737,132.04349589100346,132.04060337370242,132.0303849480969,132.04200908304497,132.0373053633218,132.04376621972318,132.04279303633217,132.03746756055364,132.048821366782,132.05063256920414,132.04917279411765,132.03687283737025,132.03990051903114,132.03938689446366,132.03616998269896,132.0414143598616,132.04460423875432,132.0300605536332,132.04928092560553,132.0401438148789,132.047848183391,132.03941392733563,132.0504974048443,132.0397653546713,132.0411169982699,132.05222750865053,132.05025410899654,132.04311743079586,132.04152249134947,132.0435499567474,132.03844074394465,132.03768382352942,132.033196366782,132.04184688581316,132.0558499134948,132.0356563581315,132.0525519031142,132.03611591695503,132.03911656574394,132.04692906574394,132.0583910034602,132.04273897058823,132.04636137543253,132.04955125432525,132.0530384948097,132.0344669117647,132.03471020761245,132.033196366782,132.04108996539793,132.04849697231833,132.02995242214533,132.0433607266436,132.03146626297578,132.03838667820068,132.024410683391,132.0287089100346,132.0296820934256,132.03854887543253,132.04598291522493,132.0377919550173,132.04895653114187,132.0298983564014,132.0441176470588,132.02065311418684,132.03654844290656,132.0482266435986,132.02643814878894,132.04836180795849,132.04006271626298,132.03579152249134,132.03660250865053,132.05247080449826,132.05047037197232,132.03965722318338,132.03441284602076,132.05060553633217,132.03714316608998,132.03862997404843,132.04200908304497
3
+ meta-llama/Llama-2-7b-hf vs openlm-research/open_llama_7b,0.7,122.27028296821305,122.26932989690722,122.27398786512028,122.27390732388317,122.27303479381443,122.27692762027492,122.26056432560138,122.26719555412372,122.26647068298969,122.27483354810997,122.27200118127148,122.26848421391753,122.28572003865979,122.28478039089347,122.27780015034364,122.2675445661512,122.2703903565292,122.26941043814433,122.26728951890034,122.27903511597938,122.2711689218213,122.2724441580756,122.26934332044674,122.27316902920963,122.27233676975945,122.27694104381443,122.2737193943299,122.26997422680412,122.27283344072166,122.26644383591065,122.26551761168385,122.27107495704468,122.27491408934708,122.2752899484536,122.27907538659794,122.28435083762886,122.25655068728523,122.270591709622,122.28267289518901,122.27515571305842,122.2756389604811,122.2705245919244,122.26542364690722,122.27439057130584,122.27637725515464,122.28272658934708,122.26685996563575,122.28712951030928,122.28406894329896,122.2765651847079,122.28468642611683,122.27808204467354,122.26298056271477,122.26336984536083,122.27016215635739,122.2815318943299,122.28082044673539,122.27277974656357,122.27076621563575,122.27816258591065,122.2646316580756,122.27404155927834,122.27562553694158,122.2571010524055,122.26389336340206,122.268470790378,122.26716870704468,122.2803908934708,122.27643094931271,122.28996187714776,122.27609536082474,122.28057882302406,122.2712360395189,122.2636383161512,122.27245758161511,122.2852904853952,122.27440399484536,122.27890088058419,122.26483301116839,122.27869952749141,122.2716655927835,122.28151847079037,122.2673297895189,122.27923646907216,122.27596112542955,122.27347777061856,122.28584085051547,122.27966602233677,122.27594770189003,122.26401417525773,122.27174613402062,122.28414948453609,122.28114261168385,122.25987972508591,122.27652491408935,122.27314218213058,122.27409525343643,122.27867268041237,122.26774591924399,122.26245704467354,122.27077963917526
4
+ meta-llama/Llama-2-7b-hf vs huggyllama/llama-7b,0.54,122.27950493986255,122.27486039518901,122.28463273195877,122.2864314862543,122.27872637457045,122.27867268041237,122.27773303264605,122.27649806701031,122.28712951030928,122.27754510309278,122.27550472508591,122.28747852233677,122.28261920103093,122.270591709622,122.27614905498282,122.28127684707904,122.270591709622,122.27958548109966,122.27929016323024,122.27051116838489,122.27708870274914,122.2786189862543,122.27534364261169,122.27955863402062,122.27829682130584,122.28205541237114,122.27977341065292,122.26401417525773,122.2897068298969,122.28565292096219,122.26490012886597,122.27426975945018,122.29236469072166,122.28436426116839,122.2882033934708,122.28114261168385,122.28251181271477,122.2842300257732,122.2813842353952,122.28860609965636,122.2909149484536,122.2908344072165,122.2858676975945,122.28393470790378,122.27354488831615,122.27963917525773,122.2774108676976,122.29558634020619,122.27974656357388,122.28812285223368,122.29461984536083,122.28704896907216,122.2715850515464,122.26608140034364,122.26197379725086,122.27692762027492,122.28788122852234,122.27958548109966,122.2872100515464,122.27775987972508,122.27550472508591,122.28205541237114,122.28436426116839,122.27990764604812,122.27604166666667,122.27714239690722,122.27389390034364,122.28125,122.27843105670104,122.27241731099656,122.28664626288659,122.26771907216495,122.2760685137457,122.28774699312714,122.28259235395188,122.28484750859107,122.28433741408935,122.27013530927834,122.28479381443299,122.28551868556701,122.27765249140893,122.26758483676976,122.28814969931271,122.27896799828179,122.28149162371135,122.28576030927834,122.29311640893471,122.28299506013745,122.2726589347079,122.27891430412372,122.27445768900344,122.27843105670104,122.27268578178695,122.2838810137457,122.27996134020619,122.2746456185567,122.27453823024055,122.29010953608247,122.2850891323024,122.27300794673539,122.28710266323024
5
+ meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.5,0.01,4.1913064124248285,111.40988777920963,111.4162639604811,111.4112972508591,111.4123577104811,111.40202158505154,111.4141296176976,111.38573883161511,111.3972964991409,111.42339185996563,111.4049747637457,111.39443728522336,111.4005718427835,111.40771316580756,111.4324527491409,111.40414250429554,111.43137886597938,111.40233032646049,111.41706937285224,111.44485609965636,111.42356636597938,111.41751234965636,111.38518846649484,111.38039626288659,111.41241140463917,111.41616999570446,111.41276041666667,111.40418277491409,111.42352609536083,111.39218213058419,111.41365979381443,111.40744469501718,111.40552512886597,111.3884235395189,111.4141296176976,111.4035518685567,111.4170425257732,111.4012161726804,111.41414304123711,111.40931056701031,111.43415753865979,111.41148518041237,111.41082742697594,111.41332420532646,111.40336393900344,111.4093776847079,111.41325708762886,111.39212843642612,111.36345575601375,111.41567332474227,111.42532484965636,111.42834514604812,111.41506926546391,111.41302888745705,111.39298754295532,111.41467998281787,111.41370006443299,111.41050526202748,111.39920264175258,111.41560620704468,111.41368664089347,111.39293384879726,111.42366033075601,111.38220844072166,111.42277437714776,111.41124355670104,111.39641054553265,111.42132463487972,111.41400880584193,111.41450547680412,111.39639712199313,111.40961930841924,111.40094770189003,111.40665270618557,111.41680090206185,111.40949849656357,111.426841709622,111.40143094931271,111.42231797680412,111.42061318728523,111.41054553264605,111.41502899484536,111.4170425257732,111.4147068298969,111.42254617697594,111.41698883161511,111.43566097508591,111.41457259450172,111.42980831185567,111.41288122852234,111.41047841494846,111.4120086984536,111.4169754080756,111.40952534364261,111.42842568728523,111.40661243556701,111.39262510738831,111.39547089776632,111.40763262457045,111.42008966924399,111.42800955756013
6
+ meta-llama/Llama-2-7b-hf vs EleutherAI/llemma_7b,0.01,88.74339722102076,124.33972210207612,124.33282871972318,124.33428849480968,124.32993620242215,124.34894031141869,124.3295847750865,124.3283142301038,124.33209883217994,124.32628676470588,124.31741998269896,124.34288494809688,124.34204692906575,124.32566500865052,124.33596453287197,124.34515570934256,124.35004865916954,124.33196366782006,124.32552984429066,124.34983239619378,124.32985510380622,124.33674848615917,124.33604563148789,124.33090938581314,124.3409115484429,124.34493944636678,124.33155817474048,124.33477508650519,124.34188473183391,124.34058715397924,124.3288008217993,124.33655925605537,124.33631596020761,124.32899005190312,124.35085964532873,124.33961397058823,124.32536764705883,124.32928741349481,124.33980320069205,124.3333964100346,124.31693339100346,124.33461288927336,124.33564013840831,124.3396410034602,124.33450475778547,124.34566933391004,124.33461288927336,124.33599156574394,124.327151816609,124.33820826124567,124.32677335640139,124.34280384948097,124.34139814013841,124.33431552768167,124.32274545847751,124.33431552768167,124.34210099480968,124.32774653979239,124.34764273356402,124.32088019031141,124.33585640138408,124.33215289792388,124.33331531141869,124.329098183391,124.33423442906575,124.33353157439447,124.31585207612457,124.35286007785467,124.32485402249135,124.33682958477509,124.32361051038062,124.33353157439447,124.33666738754326,124.32580017301038,124.33420739619378,124.3295847750865,124.34607482698962,124.33664035467127,124.34556120242215,124.3423713235294,124.34212802768167,124.32174524221453,124.33069312283737,124.34088451557093,124.32277249134948,124.33623486159169,124.33377487024221,124.32963884083046,124.34593966262976,124.34034385813149,124.32777357266436,124.33358564013841,124.32085315743944,124.34069528546713,124.3389651816609,124.32885488754326,124.33047685986159,124.329098183391,124.3323150951557,124.3367214532872,124.32131271626298
7
+ meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.1,0.31,122.32382946735395,122.33115871993127,122.31719823883161,122.31287585910653,122.3221649484536,122.3154800257732,122.32098367697594,122.3173324742268,122.31561426116839,122.32866194158076,122.30825816151203,122.32503758591065,122.31749355670104,122.31824527491409,122.31663445017182,122.31043277491409,122.31840635738831,122.32970897766323,122.31829896907216,122.32447379725086,122.32071520618557,122.3299774484536,122.3195339347079,122.30586877147766,122.31803049828179,122.31980240549828,122.33448775773196,122.32780283505154,122.3230777491409,122.31652706185567,122.32624570446735,122.3242858676976,122.30023088487972,122.32345360824742,122.31486254295532,122.3313198024055,122.32718535223368,122.33652813573883,122.31835266323024,122.32868878865979,122.32178908934708,122.32586984536083,122.30815077319588,122.31905068728523,122.31558741408935,122.31650021477664,122.3141376718213,122.31190936426117,122.32530605670104,122.32718535223368,122.31322487113403,122.32334621993127,122.31284901202748,122.31556056701031,122.30307667525773,122.32050042955326,122.32272873711341,122.3085266323024,122.32079574742268,122.31180197594502,122.30774806701031,122.32423217353951,122.31652706185567,122.3191043814433,122.31539948453609,122.32063466494846,122.30337199312714,122.30900987972508,122.32503758591065,122.31416451890034,122.32079574742268,122.31792310996563,122.33027276632302,122.32468857388317,122.32560137457045,122.30543921821305,122.3055466065292,122.31765463917526,122.32299720790378,122.31762779209622,122.31207044673539,122.32858140034364,122.32624570446735,122.32608462199313,122.32267504295532,122.31671499140893,122.3194533934708,122.32541344501718,122.3225139604811,122.3261383161512,122.3116408934708,122.3131443298969,122.33674291237114,122.32710481099656,122.3195339347079,122.32288981958763,122.3251449742268,122.31478200171821,122.3196681701031,122.30903672680412,122.31709085051547
8
+ meta-llama/Llama-2-7b-hf vs microsoft/Orca-2-7b,0.01,5.711168968966263,109.73838938148789,109.76823367214533,109.74444474480968,109.75990754757785,109.75594723183391,109.74075475778547,109.76926092128028,109.75613646193771,109.7564473399654,109.75486591695501,109.73783520761246,109.75616349480968,109.73333423442907,109.7333072015571,109.7589884299308,109.74649924307958,109.7417009083045,109.74762110726644,109.7486348399654,109.74178200692042,109.75535250865052,109.74284980536332,109.75802876297578,109.74459342560553,109.7354427984429,109.7502027465398,109.74397166955018,109.75358185553634,109.74764814013841,109.74810769896193,109.73225291955018,109.74786440311419,109.74144409602076,109.72913062283737,109.73580774221453,109.7557580017301,109.745120566609,109.73607807093425,109.74572880622837,109.75629865916954,109.75543360726644,109.74751297577855,109.73726751730104,109.76432742214533,109.74180903979239,109.74395815311419,109.73598345588235,109.75665008650519,109.75124351211073,109.7499053849481,109.76511137543253,109.74045739619378,109.74457990916954,109.74847264273356,109.74220101643598,109.7159115484429,109.75601481401384,109.73921388408304,109.73537521626298,109.75866403546713,109.7580963451557,109.74631001297578,109.74312013408304,109.7440392517301,109.76505730968859,109.74194420415225,109.76772004757785,109.73433445069205,109.73845696366782,109.75387921712803,109.74189013840831,109.75978589965398,109.73451016435986,109.73171226211073,109.71661440311419,109.73563202854672,109.74274167387543,109.74071420847751,109.73142841695501,109.74034926470588,109.73146896626298,109.72857644896193,109.74256596020761,109.73361807958477,109.76313797577855,109.73498323961938,109.74025464965398,109.75832612456747,109.76293522923875,109.76716587370242,109.72335910467127,109.75854238754326,109.73529411764706,109.74763462370242,109.7467695717993,109.76869323096886,109.74939176038062,109.73225291955018,109.74653979238754,109.75794766435986
9
+ meta-llama/Llama-2-7b-hf vs LLM360/Amber,0.2,148.0270618556701,148.01374570446737,148.0257731958763,148.03135738831614,148.01331615120276,148.01460481099656,148.00644329896906,148.03823024054984,148.02276632302406,148.03221649484536,148.01546391752578,148.04252577319588,148.02405498281786,148.0176116838488,148.02190721649484,148.01202749140893,148.0266323024055,148.0158934707904,148.0189003436426,148.01847079037802,148.0356529209622,148.02061855670104,148.01245704467354,148.01159793814432,148.01675257731958,148.0180412371134,148.03264604810997,148.0287800687285,148.02018900343643,148.0257731958763,148.01460481099656,148.01202749140893,148.02018900343643,148.01374570446737,148.02061855670104,148.02448453608247,148.02405498281786,148.01460481099656,148.0287800687285,148.01503436426117,148.02362542955328,148.01847079037802,148.03049828178695,148.00644329896906,148.0098797250859,148.01417525773195,148.0171821305842,148.00644329896906,148.03049828178695,148.0081615120275,148.0266323024055,148.02061855670104,148.01675257731958,148.02491408934708,148.0270618556701,148.01675257731958,148.03908934707903,148.0274914089347,148.01675257731958,148.0158934707904,148.03307560137458,148.0090206185567,148.01331615120276,148.0171821305842,148.01503436426117,148.01374570446737,148.01503436426117,148.01546391752578,148.02018900343643,148.03393470790377,148.03608247422682,148.0356529209622,148.03092783505156,148.0266323024055,148.02362542955328,148.02061855670104,148.02319587628867,148.0158934707904,148.01159793814432,148.02233676975945,148.01331615120276,148.0270618556701,148.0171821305842,148.02362542955328,148.03178694158075,148.01116838487974,148.0262027491409,148.03049828178695,148.02362542955328,148.0270618556701,148.0257731958763,148.0257731958763,147.99785223367698,148.02190721649484,148.01503436426117,148.0253436426117,148.02233676975945,148.02405498281786,148.0017182130584,148.0004295532646,148.03393470790377
10
+ codellama/CodeLlama-7b-hf vs openlm-research/open_llama_7b,0.38,140.50532547577853,140.5032033953287,140.5064608564014,140.51106996107268,140.49125486591694,140.50667711937717,140.50873161764707,140.50923172577853,140.49525573096886,140.5012840614187,140.50564987024222,140.49528276384083,140.4923902465398,140.50493349913495,140.4986483564014,140.50863700259515,140.50363592128028,140.5041089965398,140.49612078287197,140.49152519463667,140.5005001081315,140.517584883218,140.50701503027682,140.49028168252596,140.49904033304497,140.50935337370242,140.50778546712803,140.49187662197232,140.49671550605535,140.5021355968858,140.50373053633217,140.49620188148788,140.49144409602076,140.50054065743944,140.5080017301038,140.50940743944636,140.51417874134947,140.49390408737025,140.50521734429066,140.51169171712803,140.4971750648789,140.5050686634948,140.48786224048442,140.50223021193773,140.5075151384083,140.5000810986159,140.51462478373702,140.50389273356402,140.50124351211073,140.50512272923876,140.5103671064014,140.50636624134947,140.50823150951558,140.50950205449826,140.50141922577853,140.5096642517301,140.50913711072664,140.49066014273356,140.51182688148788,140.5005812067474,140.4827124783737,140.51116457612457,140.49401221885813,140.50786656574394,140.5064473399654,140.49655330882354,140.50082450259515,140.5028384515571,140.49266057525952,140.49933769463667,140.506825800173,140.4949178200692,140.50093263408306,140.5046902032872,140.49213343425606,140.50385218425606,140.49698583477507,140.50163548875432,140.50366295415225,140.5062581098616,140.5017706531142,140.49616133217992,140.51032655709344,140.51282709775086,140.5035007569204,140.50894788062283,140.5027573529412,140.49937824394465,140.50121647923876,140.5116376513841,140.49831044550174,140.50066230536333,140.49010596885813,140.49332288062283,140.50817744377161,140.50729887543253,140.5018517517301,140.48883542387543,140.5067447015571,140.509772383218,140.49279573961937
11
+ codellama/CodeLlama-7b-hf vs huggyllama/llama-7b,0.69,140.80544442041523,140.79292820069205,140.80676903114187,140.8191500865052,140.80906682525952,140.81752811418684,140.80882352941177,140.80971561418684,140.8163116349481,140.82358347750866,140.81920415224914,140.80328179065745,140.80536332179932,140.81898788927336,140.8016598183391,140.82015030276816,140.7961721453287,140.8296929065744,140.82950367647058,140.8107428633218,140.80736375432525,140.80855320069205,140.81850129757785,140.806633866782,140.81187824394465,140.8059580449827,140.8185553633218,140.81771734429066,140.80606617647058,140.8207179930796,140.82880082179932,140.81477076124568,140.80079476643598,140.78873810553634,140.81606833910035,140.80755298442907,140.81393274221455,140.8020653114187,140.80160575259515,140.81898788927336,140.81304065743944,140.80717452422147,140.80017301038063,140.81909602076124,140.79106293252596,140.8329098183391,140.80252487024222,140.80322772491348,140.81925821799308,140.80417387543253,140.80387651384083,140.80941825259515,140.79379325259515,140.79952422145328,140.81842019896195,140.81255406574394,140.81339208477507,140.80582288062283,140.81477076124568,140.81655493079586,140.7884948096886,140.806660899654,140.82723291522493,140.8205017301038,140.8002000432526,140.81022923875432,140.8186634948097,140.80133542387543,140.80433607266437,140.80890462802768,140.818366133218,140.797848183391,140.81477076124568,140.7946312716263,140.80149762110727,140.80025410899654,140.81441933391002,140.80441717128028,140.81696042387543,140.81031033737025,140.81509515570934,140.80628243944636,140.79625324394465,140.80841803633217,140.7991187283737,140.79036007785467,140.82501621972318,140.8176903114187,140.81168901384083,140.81604130622839,140.81182417820068,140.8177714100346,140.80714749134947,140.81068879757785,140.806660899654,140.81433823529412,140.80912089100346,140.79500973183391,140.81360834775086,140.79879433391002,140.81222967128028
12
+ codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.5,0.01,93.25281141868513,132.11570069204151,132.1188365051903,132.12359429065745,132.12824394463667,132.1175929930796,132.10329260380624,132.12218858131487,132.11305147058823,132.12135056228374,132.1322448096886,132.12097210207614,132.11245674740485,132.10715830449826,132.10064338235293,132.12210748269896,132.11791738754326,132.133785683391,132.12700043252596,132.1134839965398,132.1257028546713,132.13619160899654,132.1047794117647,132.11937716262975,132.13305579584775,132.1193230968858,132.10007569204151,132.1053741349481,132.11886353806227,132.1098615916955,132.10640138408306,132.13029844290656,132.10250865051904,132.1098615916955,132.11059147923876,132.1089695069204,132.1121053200692,132.11851211072664,132.1159169550173,132.1171064013841,132.11802551903114,132.1390570934256,132.12026924740485,132.11645761245674,132.1136732266436,132.11813365051904,132.11721453287197,132.13062283737025,132.12235077854672,132.12448637543253,132.1038062283737,132.11167279411765,132.10710423875432,132.12262110726644,132.11440311418684,132.119160899654,132.1130785034602,132.12121539792386,132.12527032871972,132.12278330449826,132.1101589532872,132.12410791522493,132.11275410899654,132.1047794117647,132.10396842560553,132.10315743944636,132.11059147923876,132.111321366782,132.11634948096886,132.12654087370242,132.11042928200692,132.11470047577853,132.1305687716263,132.11169982698962,132.10504974048442,132.1301903114187,132.1169982698962,132.1234320934256,132.12721669550174,132.11607915224914,132.11794442041523,132.11359212802768,132.11872837370242,132.12024221453288,132.12140462802768,132.11440311418684,132.11862024221455,132.10921280276816,132.12153979238755,132.1078070934256,132.10310337370242,132.11770112456747,132.12278330449826,132.10396842560553,132.12472967128028,132.11218641868513,132.12935229238755,132.12318879757785,132.12299956747404,132.12332396193773,132.11051038062283
13
+ codellama/CodeLlama-7b-hf vs EleutherAI/llemma_7b,0.01,49.639849790592784,145.56556056701032,145.56373496563575,145.55355992268042,145.55103629725085,145.5629295532646,145.5602448453608,145.55651310137458,145.53490120274915,145.53055197594503,145.56274162371133,145.55941258591065,145.57844716494844,145.5524323453608,145.5569695017182,145.5328071305842,145.53490120274915,145.5290485395189,145.53739798109964,145.53420317869416,145.55670103092783,145.5446198453608,145.5487811426117,145.55345253436425,145.54733140034364,145.53326353092783,145.56346649484536,145.56523840206185,145.56966817010309,145.54853951890036,145.54717031786942,145.56961447594503,145.52821628006873,145.56534579037802,145.54408290378007,145.53423002577318,145.563922895189,145.57444695017182,145.55600300687286,145.54821735395188,145.55133161512026,145.53570661512026,145.5597347508591,145.54706292955328,145.55917096219932,145.52679338487974,145.548995919244,145.52665914948454,145.54008268900344,145.5602985395189,145.55500966494844,145.55087521477662,145.53527706185568,145.53976052405497,145.53527706185568,145.53608247422682,145.55750644329896,145.535518685567,145.55976159793815,145.53127684707903,145.51430949312714,145.54580111683848,145.54937177835052,145.5565399484536,145.5551707474227,145.53307560137458,145.57366838487974,145.55793599656357,145.5544727233677,145.54351911512026,145.54330433848799,145.5429016323024,145.5487274484536,145.54754617697594,145.560379080756,145.55310352233678,145.53726374570448,145.55898303264604,145.55004295532646,145.53570661512026,145.5592246563574,145.53978737113403,145.5659901202749,145.5630637886598,145.55355992268042,145.53261920103094,145.57619201030928,145.55672787800688,145.52061855670104,145.52126288659792,145.55756013745705,145.52998818728523,145.5549022766323,145.53967998281786,145.5607817869416,145.55613724226805,145.5617214347079,145.56502362542955,145.53033719931273,145.55361361683848,145.54206937285224
14
+ codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.1,0.95,140.8399383650519,140.86207828719722,140.84693987889273,140.85461721453288,140.85002162629758,140.86124026816609,140.8376946366782,140.865241133218,140.84069528546712,140.86096993944636,140.84996756055364,140.87175605536333,140.85621215397924,140.8480482266436,140.86261894463667,140.84896734429066,140.84964316608998,140.85499567474048,140.84645328719722,140.86207828719722,140.84885921280278,140.84391219723184,140.8503730536332,140.86086180795849,140.8394517733564,140.87008001730104,140.86613321799308,140.84953503460207,140.8507785467128,140.8595642301038,140.85450908304497,140.848642949827,140.84558823529412,140.85104887543253,140.85534710207614,140.84372296712803,140.84915657439447,140.84596669550174,140.8400464965398,140.83196366782008,140.85656358131487,140.85294117647058,140.85288711072664,140.85337370242215,140.86215938581316,140.8411278114187,140.86178092560553,140.8532115051903,140.84264165224914,140.85207612456747,140.86283520761245,140.84591262975778,140.85461721453288,140.8628892733564,140.86315960207614,140.85018382352942,140.84826448961937,140.85131920415225,140.8591857698962,140.8618079584775,140.85632028546712,140.84634515570934,140.83704584775086,140.86091587370242,140.85480644463667,140.8448313148789,140.8634839965398,140.85759083044982,140.84283088235293,140.85456314878894,140.86691717128028,140.866214316609,140.85172469723184,140.84410142733563,140.84693987889273,140.85467128027682,140.85883434256056,140.8506974480969,140.8499945934256,140.86607915224914,140.8382893598616,140.8586721453287,140.8449394463668,140.84945393598616,140.84575043252596,140.8527519463668,140.8514814013841,140.86291630622839,140.84748053633217,140.85494160899654,140.86507893598616,140.86359212802768,140.8492106401384,140.84166846885813,140.86224048442907,140.86342993079586,140.84710207612457,140.8526438148789,140.8489403114187,140.8474535034602,140.84556120242215
15
+ codellama/CodeLlama-7b-hf vs microsoft/Orca-2-7b,0.01,93.11375432525952,131.67065852076124,131.69193339100346,131.69139273356402,131.68917603806227,131.67574070069205,131.68130947231833,131.6948259083045,131.68598615916954,131.68725670415225,131.7043955449827,131.68766219723184,131.6863375865052,131.69371756055364,131.69139273356402,131.6988267733564,131.69517733564012,131.7019625865052,131.70339532871972,131.700205449827,131.69136570069205,131.68625648788927,131.6990160034602,131.70631487889273,131.71239727508652,131.6953935986159,131.69055471453288,131.6905276816609,131.68923010380624,131.67936310553634,131.70290873702422,131.67730860726644,131.68549956747404,131.68901384083046,131.69387975778545,131.67738970588235,131.70188148788927,131.70558499134947,131.68817582179932,131.6808499134948,131.68460748269896,131.68412089100346,131.7046929065744,131.7000973183391,131.6867971453287,131.67525410899654,131.68512110726644,131.6942311851211,131.69809688581316,131.69828611591694,131.70444961072664,131.6944474480969,131.6895544982699,131.68376946366783,131.68122837370242,131.6807688148789,131.6996107266436,131.69442041522493,131.67679498269896,131.68549956747404,131.67936310553634,131.6877973615917,131.70393598615917,131.6775519031142,131.6896355968858,131.69742106401384,131.68279628027682,131.68371539792386,131.68171496539793,131.67941717128028,131.70258434256056,131.7061526816609,131.6964749134948,131.68547253460207,131.69166306228374,131.68774329584775,131.69171712802768,131.68220155709344,131.69133866782008,131.69382569204151,131.68093101211073,131.69166306228374,131.6867971453287,131.67341587370242,131.69182525951558,131.6847426470588,131.68952746539793,131.68403979238755,131.67484861591694,131.6941230536332,131.69590722318338,131.70999134948096,131.69585315743944,131.69598832179932,131.67325367647058,131.70104346885813,131.69874567474048,131.69347426470588,131.66998269896195,131.72056120242215,131.68279628027682
16
+ codellama/CodeLlama-7b-hf vs LLM360/Amber,0.91,163.1440311418685,163.17171280276816,163.1613321799308,163.14749134948096,163.1613321799308,163.15787197231833,163.14749134948096,163.15787197231833,163.18209342560553,163.17517301038063,163.14057093425606,163.16479238754326,163.15787197231833,163.16479238754326,163.17171280276816,163.17171280276816,163.1682525951557,163.15787197231833,163.1691176470588,163.15095155709344,163.15441176470588,163.17517301038063,163.17171280276816,163.17171280276816,163.17517301038063,163.14057093425606,163.15787197231833,163.15095155709344,163.15787197231833,163.1613321799308,163.15441176470588,163.17171280276816,163.17517301038063,163.15787197231833,163.15787197231833,163.1725778546713,163.17171280276816,163.17517301038063,163.15787197231833,163.1725778546713,163.16479238754326,163.16955017301038,163.1613321799308,163.16479238754326,163.17517301038063,163.15441176470588,163.15787197231833,163.15787197231833,163.14749134948096,163.17171280276816,163.15441176470588,163.14749134948096,163.18209342560553,163.1613321799308,163.1301903114187,163.1652249134948,163.15441176470588,163.1691176470588,163.15441176470588,163.15787197231833,163.1682525951557,163.15787197231833,163.15095155709344,163.17171280276816,163.1440311418685,163.1440311418685,163.15095155709344,163.15441176470588,163.15787197231833,163.15787197231833,163.1682525951557,163.1725778546713,163.1440311418685,163.16479238754326,163.14705882352942,163.1691176470588,163.15787197231833,163.1691176470588,163.16608996539793,163.16479238754326,163.17517301038063,163.17171280276816,163.17517301038063,163.17863321799308,163.17517301038063,163.1682525951557,163.13365051903114,163.15787197231833,163.1371107266436,163.1613321799308,163.15787197231833,163.15095155709344,163.16565743944636,163.17171280276816,163.1371107266436,163.15441176470588,163.15441176470588,163.15787197231833,163.15787197231833,163.16479238754326,163.15181660899654
17
+ openlm-research/open_llama_7b vs huggyllama/llama-7b,0.19,131.92685513316152,131.92089508161513,131.92509664948454,131.91161941580756,131.92971434707903,131.92175418814432,131.9051089991409,131.91141806271477,131.91376718213058,131.92355294243987,131.90969984965636,131.9154182774914,131.91289465206185,131.93862757731958,131.91175365120276,131.9225864475945,131.90466602233678,131.92443889604812,131.91710964347078,131.91717676116838,131.91156572164948,131.91041129725085,131.92552620274915,131.9163981958763,131.90876020189003,131.93009020618555,131.92227770618555,131.91305573453607,131.9106394974227,131.9084380369416,131.91524377147766,131.93839937714776,131.91627738402062,131.9255664733677,131.90842461340208,131.91662639604812,131.91408934707903,131.93258698453607,131.9095924613402,131.92101589347078,131.91830433848799,131.9099414733677,131.93456024484536,131.93268094931273,131.9296740764605,131.90584729381445,131.9298217353952,131.9085320017182,131.90761920103094,131.92474763745705,131.92486844931273,131.90955219072166,131.92473421391753,131.92931164089347,131.9085320017182,131.91313627577318,131.92668062714776,131.90929714347078,131.92293545962198,131.9352314218213,131.91990173969072,131.90663928264604,131.90173969072166,131.90194104381445,131.93144598367698,131.9231502362543,131.91830433848799,131.92064003436425,131.91923056271477,131.9074581185567,131.92055949312714,131.9156196305842,131.93677512886597,131.92196896477662,131.92904317010309,131.91843857388315,131.92469394329896,131.9081561426117,131.92767396907217,131.9004913015464,131.91869362113403,131.9233113187285,131.92473421391753,131.93711071735396,131.91902920962198,131.92630476804123,131.8961554982818,131.91599548969072,131.9234455541237,131.91461286512026,131.90450493986253,131.92191527061857,131.92100246993127,131.92035814003435,131.9273518041237,131.91313627577318,131.91649216065292,131.92000912800688,131.92155283505156,131.93363402061857,131.8985717353952
18
+ openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.5,0.3,122.38108086340206,122.37894652061856,122.37337575171821,122.39488026202748,122.38780605670104,122.37921499140893,122.38881282216495,122.37757731958763,122.37446305841924,122.37242268041237,122.38428908934708,122.37128167955326,122.37140249140893,122.37792633161511,122.38223528780068,122.37481207044674,122.36418062714776,122.36261007302406,122.38744362113403,122.39500107388317,122.3808929338488,122.37532216494846,122.36401954467354,122.38010094501718,122.37800687285224,122.37418116408935,122.37768470790378,122.36528135738831,122.36638208762886,122.38074527491409,122.37836930841924,122.36595253436425,122.37355025773196,122.38543009020619,122.37176492697594,122.38426224226804,122.37867804982818,122.37700010738831,122.38626234965636,122.36165700171821,122.37559063573883,122.37510738831615,122.38540324312714,122.3731475515464,122.38097347508591,122.38815506872852,122.39007463487972,122.37310728092784,122.3875778565292,122.37246295103093,122.38780605670104,122.38430251288659,122.3854435137457,122.36652974656357,122.37838273195877,122.37607388316151,122.37062392611683,122.37498657646049,122.36797948883161,122.38755100945018,122.3907592353952,122.37144276202748,122.38014121563575,122.36886544243987,122.36836877147766,122.3718723152921,122.36828823024055,122.36466387457045,122.37164411512028,122.39407484965636,122.37930895618557,122.38269168814433,122.37654370704468,122.38055734536083,122.37729542525773,122.38000698024055,122.37559063573883,122.36236844931271,122.38403404209622,122.38416827749141,122.37020779639175,122.37145618556701,122.39795425257732,122.36867751288659,122.38254402920963,122.38759128006873,122.38109428694158,122.38639658505154,122.37050311426117,122.37686587199313,122.37998013316151,122.3759933419244,122.36941580756013,122.37030176116839,122.37702695446735,122.37258376288659,122.36850300687286,122.37001986683849,122.3825977233677,122.38007409793815,122.3877120919244
19
+ openlm-research/open_llama_7b vs EleutherAI/llemma_7b,0.91,133.5973724048443,133.58549145761245,133.58641057525952,133.58520761245674,133.58861375432525,133.58793793252596,133.59591262975778,133.59347967128028,133.58755947231833,133.58696474913495,133.591817149654,133.59662900086505,133.58185553633217,133.5969669117647,133.59514219290656,133.59375,133.58782980103805,133.5972642733564,133.58754595588235,133.5884785899654,133.58831639273356,133.57458369377161,133.58235564446366,133.5881812283737,133.60386029411765,133.5882758434256,133.58638354238755,133.59469615051904,133.5890733131488,133.59923767301038,133.59812932525952,133.59323637543253,133.59118187716263,133.59565581747404,133.5847480536332,133.59446637110727,133.5928849480969,133.58389651816609,133.58851913927336,133.57748972750866,133.58812716262975,133.5881812283737,133.58270707179932,133.58442365916954,133.58492376730104,133.60133272058823,133.59592614619376,133.60564446366783,133.59327692474048,133.59200637975778,133.58781628460207,133.58978968425606,133.59554768598616,133.58843804065745,133.58768112024222,133.5958991133218,133.58226102941177,133.58912737889273,133.58693771626298,133.58226102941177,133.59546658737025,133.58423442906573,133.5899248486159,133.58745134083046,133.5714749134948,133.58942474048442,133.594723183391,133.58489673442907,133.59077638408306,133.59118187716263,133.58484266868513,133.59153330449826,133.58070663927336,133.58363970588235,133.5919928633218,133.59164143598616,133.59230374134947,133.5856401384083,133.59031682525952,133.59293901384083,133.60775302768167,133.58596453287197,133.58158520761245,133.59299307958477,133.58139597750866,133.58386948529412,133.59235780709344,133.60891544117646,133.59389868079586,133.5977373486159,133.5859375,133.587910899654,133.590816933391,133.58539684256056,133.5919928633218,133.59049253892732,133.58482915224914,133.58393706747404,133.59064121972318,133.59954855103805,133.5914116565744
20
+ openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.1,0.87,131.96828017611685,131.96799828178695,131.95217192869416,131.9627094072165,131.97382409793815,131.95213165807561,131.96915270618555,131.94882946735396,131.94952749140893,131.96253490120276,131.96785062285224,131.96084353522338,131.94980938573883,131.95834675687286,131.95994415807561,131.9536216709622,131.9409766967354,131.962722830756,131.95621241408935,131.95221219931273,131.962722830756,131.9592327104811,131.9690318943299,131.95403780068727,131.9434063573883,131.96559546821305,131.95061479810997,131.9673136812715,131.9615684063574,131.94076192010309,131.96167579467354,131.96227985395188,131.96062875859107,131.9555412371134,131.94344662800688,131.95720575601374,131.9682667525773,131.95826621563575,131.94849387886597,131.95658827319588,131.95464185996565,131.945674935567,131.9667498926117,131.9552593427835,131.9645618556701,131.97374355670104,131.9653404209622,131.97028028350516,131.9652598797251,131.96083011168386,131.9667498926117,131.97054875429552,131.9605213702749,131.94987650343643,131.9624275128866,131.96632033934708,131.96393094931273,131.95391698883162,131.95994415807561,131.969112435567,131.95430627147766,131.95754134450172,131.97320661512026,131.96383698453607,131.96599817439864,131.94372852233678,131.95054768041237,131.96021262886597,131.9595011812715,131.9708038015464,131.9568970146048,131.97076353092783,131.96960910652922,131.96513906786942,131.9573802620275,131.96154155927834,131.95418545962198,131.95394383591065,131.9495677620275,131.94018470790377,131.94351374570448,131.94017128436425,131.95194372852234,131.95993073453607,131.96125966494844,131.95344716494844,131.9577426975945,131.95673593213058,131.9472991838488,131.94257409793815,131.96105831185568,131.96955541237114,131.95891054553263,131.95352770618555,131.96414572594503,131.97213273195877,131.95521907216494,131.94825225515464,131.96821305841925,131.96038713487974,131.9682533290378
21
+ openlm-research/open_llama_7b vs microsoft/Orca-2-7b,0.07,120.33676200259515,120.35221128892734,120.34593966262976,120.33831639273356,120.33472102076125,120.3409115484429,120.34412846020761,120.34708855968859,120.34992701124567,120.35217073961938,120.34460153546713,120.34233077422145,120.34884569636678,120.34533142301038,120.35580666089966,120.33586991782006,120.34685878027682,120.34514219290658,120.34084396626298,120.34454746972318,120.34618295847751,120.34007352941177,120.35161656574394,120.35206260813149,120.34639922145328,120.35118403979239,120.34798064446366,120.33600508217994,120.34848075259515,120.3397491349481,120.34822394031141,120.34788602941177,120.34557471885813,120.34649383650519,120.34214154411765,120.35111645761246,120.3458044982699,120.346683066609,120.34006001297578,120.33628892733564,120.34562878460207,120.34156033737024,120.35415765570934,120.35277897923875,120.34381758217994,120.34957558391004,120.33919496107266,120.34773734861592,120.34558823529412,120.35090019463668,120.36459234429066,120.34373648356402,120.34089803200692,120.34887272923875,120.34576394896193,120.3403303416955,120.3382758434256,120.33942474048443,120.36328125,120.34573691608996,120.34011407871972,120.34289846453287,120.35185986159169,120.33703233131487,120.35460369809688,120.35095426038062,120.34773734861592,120.3406277032872,120.33530222750865,120.34142517301038,120.34164143598616,120.33789738321799,120.33873540224913,120.34546658737024,120.34662900086505,120.3412089100346,120.35587424307958,120.34106022923875,120.35481996107266,120.34895382785467,120.34106022923875,120.3433580233564,120.34008704584775,120.34222264273356,120.3576178633218,120.34314176038062,120.33776221885813,120.34089803200692,120.3569285250865,120.3411142949827,120.34752108564014,120.34787251297578,120.34964316608996,120.33081477076125,120.3451286764706,120.34139814013841,120.35584721020761,120.35129217128028,120.33981671712803,120.35227887110727,120.33951935553634
model-tracing/results/perm/permutation_loss_midpoint_wikitext_single.csv ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Model Pair,loss p-value,unpermuted loss,permuted losses
2
+ meta-llama/Llama-2-7b-hf vs codellama/CodeLlama-7b-hf,0.01,2.3246312141418457,12.414887428283691,10.436723709106445,11.463930130004883,10.289920806884766,10.242432594299316,10.102094650268555,10.40760612487793,11.008111000061035,10.213240623474121,10.67473316192627,10.590666770935059,11.035821914672852,11.31821060180664,10.628593444824219,10.762565612792969,10.666658401489258,10.025702476501465,11.167806625366211,9.551054000854492,10.156463623046875,11.520145416259766,10.491896629333496,11.043285369873047,9.333309173583984,11.062390327453613,12.100906372070312,11.202133178710938,9.852943420410156,11.296211242675781,10.389328956604004,11.376590728759766,10.070525169372559,10.026528358459473,9.897621154785156,11.314379692077637,10.841796875,10.254968643188477,10.329296112060547,9.514544486999512,10.902949333190918,11.021820068359375,10.437745094299316,10.649450302124023,10.367554664611816,10.411920547485352,10.035419464111328,10.890546798706055,10.692540168762207,10.53729248046875,11.218603134155273,9.837448120117188,11.327438354492188,9.788040161132812,10.384784698486328,11.534051895141602,11.14419174194336,10.694536209106445,11.900223731994629,10.65000057220459,12.121938705444336,10.567901611328125,11.527382850646973,9.785840034484863,11.457867622375488,10.247855186462402,10.888710021972656,10.499954223632812,10.134185791015625,10.207050323486328,10.164628028869629,10.550536155700684,11.72325325012207,9.784914016723633,9.676228523254395,9.838054656982422,10.10443115234375,10.668523788452148,11.105070114135742,10.959217071533203,10.044930458068848,10.379700660705566,10.720422744750977,10.373221397399902,11.601597785949707,9.767145156860352,11.08609390258789,11.894039154052734,10.29943561553955,9.70090103149414,10.862321853637695,10.94301700592041,10.51127815246582,10.600202560424805,10.773964881896973,10.515543937683105,10.527175903320312,9.276633262634277,10.818120002746582,10.66445541381836,9.917793273925781
3
+ meta-llama/Llama-2-7b-hf vs openlm-research/open_llama_7b,0.64,11.575419425964355,11.40134334564209,11.22384262084961,11.46405029296875,11.452263832092285,10.455619812011719,10.975679397583008,11.58855152130127,12.451857566833496,11.924081802368164,12.242836952209473,11.277565002441406,11.563451766967773,11.68265151977539,11.848937034606934,10.44631290435791,12.024474143981934,11.507471084594727,10.45692253112793,11.344988822937012,10.991713523864746,11.589402198791504,11.106851577758789,11.331841468811035,11.409880638122559,11.256386756896973,11.77776050567627,11.199036598205566,11.516610145568848,11.03821849822998,11.872330665588379,11.290637969970703,11.223627090454102,11.23520278930664,11.499801635742188,11.516365051269531,11.601478576660156,10.872085571289062,11.779101371765137,12.180392265319824,11.032282829284668,11.411971092224121,11.81676197052002,11.883353233337402,11.750810623168945,11.176957130432129,10.957067489624023,14.970661163330078,11.218901634216309,12.29248046875,11.31027603149414,11.960100173950195,11.291650772094727,11.711487770080566,12.310693740844727,11.598724365234375,11.308028221130371,11.46600341796875,11.952092170715332,11.185503959655762,11.173068046569824,11.149730682373047,11.459930419921875,11.234939575195312,11.293815612792969,11.118258476257324,12.10870361328125,12.251635551452637,11.076282501220703,11.664109230041504,12.398612976074219,11.496159553527832,11.82258415222168,11.470304489135742,11.340746879577637,12.212625503540039,10.980117797851562,11.57983684539795,11.1240873336792,11.654163360595703,11.260462760925293,11.391934394836426,10.89322566986084,11.501547813415527,11.8900785446167,10.93996524810791,10.885272979736328,11.528347969055176,11.503525733947754,11.81087589263916,11.610458374023438,11.825965881347656,11.274349212646484,10.811087608337402,11.338240623474121,11.545049667358398,11.304219245910645,11.655720710754395,11.276208877563477,10.926770210266113,10.414559364318848
4
+ meta-llama/Llama-2-7b-hf vs huggyllama/llama-7b,0.99,12.2577486038208,11.713994979858398,11.829354286193848,11.17656421661377,10.935490608215332,10.614337921142578,11.733335494995117,10.55212116241455,11.617198944091797,11.02830982208252,11.103273391723633,10.98201847076416,11.652020454406738,11.325891494750977,11.472954750061035,11.740574836730957,11.530477523803711,11.882132530212402,11.853250503540039,11.228290557861328,11.346905708312988,11.387336730957031,11.696052551269531,11.597851753234863,10.276288032531738,10.484336853027344,11.387497901916504,11.903156280517578,10.768474578857422,11.013677597045898,11.241617202758789,11.207913398742676,11.054834365844727,11.4959716796875,10.328938484191895,11.011316299438477,11.860649108886719,10.733698844909668,11.433752059936523,11.57979965209961,10.878405570983887,10.85395336151123,11.535321235656738,11.479925155639648,11.269013404846191,10.643021583557129,10.89708423614502,10.569060325622559,10.77185344696045,11.806313514709473,11.702690124511719,11.08808708190918,10.627337455749512,11.021454811096191,11.430144309997559,10.929702758789062,11.454261779785156,11.20509147644043,10.65364933013916,11.217240333557129,11.659832000732422,10.883869171142578,11.713050842285156,10.736254692077637,11.047320365905762,11.248353004455566,11.70473861694336,10.898530006408691,10.070318222045898,10.314364433288574,10.933834075927734,10.347864151000977,11.01766586303711,10.836737632751465,10.664950370788574,10.768060684204102,10.58050537109375,10.815657615661621,10.808479309082031,11.686751365661621,10.925615310668945,11.23190975189209,11.179738998413086,11.049586296081543,10.853154182434082,11.299239158630371,12.247625350952148,10.916298866271973,11.287036895751953,11.025016784667969,11.59552001953125,11.489020347595215,10.638093948364258,11.174018859863281,10.918281555175781,11.737152099609375,10.530423164367676,10.45882797241211,10.129399299621582,11.413297653198242,11.797016143798828
5
+ meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.5,0.01,1.927486538887024,11.821135520935059,11.198407173156738,9.731367111206055,10.24201774597168,10.990260124206543,11.304125785827637,11.099743843078613,12.419852256774902,12.37198543548584,11.15950870513916,11.141377449035645,11.544782638549805,12.145750999450684,12.45425033569336,11.949847221374512,11.726274490356445,10.694775581359863,11.242759704589844,10.818204879760742,12.982400894165039,10.319830894470215,12.521677017211914,12.44443130493164,13.007572174072266,10.896312713623047,12.76135540008545,12.305526733398438,10.95059871673584,11.4083890914917,12.70875358581543,12.232357025146484,12.533140182495117,10.787182807922363,10.945103645324707,11.467927932739258,10.401111602783203,10.44174861907959,12.008749008178711,11.72877311706543,11.764703750610352,9.260933876037598,9.660629272460938,10.801837921142578,10.224106788635254,10.381224632263184,9.673469543457031,11.380040168762207,11.231133460998535,19.4843807220459,11.832645416259766,11.092031478881836,10.962821960449219,9.960295677185059,10.633321762084961,10.486164093017578,10.409814834594727,11.982443809509277,9.964591026306152,12.886614799499512,10.435898780822754,12.111292839050293,10.96882152557373,11.224199295043945,11.8229341506958,12.327183723449707,10.941879272460938,10.337071418762207,10.541845321655273,11.256275177001953,12.46101188659668,11.085774421691895,11.33261489868164,11.308060646057129,11.747783660888672,10.688689231872559,11.49790096282959,10.940205574035645,8.749332427978516,12.804007530212402,11.88497543334961,11.572806358337402,9.896434783935547,11.37403392791748,10.527442932128906,10.58166790008545,10.856118202209473,11.154521942138672,10.329510688781738,11.657892227172852,11.611366271972656,11.373307228088379,11.529753684997559,11.080778121948242,9.911948204040527,11.649473190307617,11.339656829833984,11.394058227539062,10.780144691467285,10.756746292114258,12.14152717590332
6
+ meta-llama/Llama-2-7b-hf vs EleutherAI/llemma_7b,0.01,2.2617459297180176,12.073643684387207,9.601665496826172,11.438493728637695,11.859432220458984,9.136279106140137,10.389586448669434,10.537793159484863,9.534730911254883,11.314827919006348,11.605558395385742,10.475652694702148,11.510196685791016,11.53796100616455,12.012444496154785,9.623738288879395,9.299081802368164,12.455397605895996,9.597661972045898,12.382262229919434,11.051375389099121,9.97256088256836,9.617574691772461,9.902551651000977,10.14933967590332,10.723257064819336,10.306427001953125,10.507966041564941,10.786173820495605,12.057255744934082,8.83625316619873,10.997522354125977,9.18514633178711,9.580232620239258,9.388571739196777,9.756895065307617,9.360857009887695,10.979372024536133,10.720354080200195,10.257851600646973,9.46591854095459,9.886110305786133,10.586426734924316,10.423218727111816,11.755844116210938,10.80079460144043,10.800479888916016,9.471284866333008,9.391319274902344,12.033135414123535,10.548360824584961,11.953831672668457,10.858678817749023,11.68787670135498,10.111590385437012,11.390238761901855,10.074946403503418,8.820842742919922,10.89544677734375,9.980045318603516,10.665379524230957,9.60218620300293,10.32979965209961,12.539587020874023,8.944647789001465,10.680426597595215,10.794387817382812,12.351435661315918,10.521456718444824,10.917708396911621,10.16922378540039,9.72276496887207,9.371663093566895,10.629088401794434,10.897933959960938,12.550519943237305,8.850178718566895,10.177017211914062,9.906770706176758,12.711698532104492,11.127090454101562,10.410646438598633,10.431379318237305,9.942227363586426,11.459847450256348,15.255708694458008,12.270158767700195,10.828147888183594,10.433170318603516,10.952369689941406,11.425644874572754,9.552434921264648,9.059112548828125,9.621358871459961,12.671890258789062,10.086618423461914,9.930680274963379,12.442190170288086,9.548700332641602,11.214447975158691,10.612310409545898
7
+ meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.1,0.99,12.176732063293457,12.091742515563965,9.97877025604248,11.391814231872559,10.661109924316406,10.817087173461914,11.534388542175293,11.4345064163208,11.032309532165527,11.198960304260254,10.708231925964355,10.624578475952148,10.788887977600098,10.516521453857422,11.011672019958496,11.516682624816895,11.103180885314941,11.559279441833496,11.907734870910645,11.123976707458496,9.81065559387207,10.62118911743164,11.27763843536377,11.523125648498535,10.444096565246582,11.76767349243164,11.019610404968262,11.484188079833984,11.91457748413086,10.497208595275879,11.79175853729248,11.298368453979492,10.316926956176758,10.76264762878418,11.064068794250488,10.277680397033691,10.894306182861328,11.136940002441406,10.948762893676758,10.753281593322754,11.747068405151367,11.244501113891602,11.157686233520508,11.421767234802246,11.376803398132324,10.557845115661621,11.091096878051758,11.671372413635254,11.795584678649902,10.934099197387695,11.12752628326416,11.514236450195312,10.979425430297852,10.027077674865723,11.399852752685547,11.993349075317383,11.904010772705078,10.320328712463379,11.54886531829834,11.317293167114258,10.79334831237793,11.174019813537598,11.312468528747559,12.456009864807129,11.849678039550781,11.423567771911621,11.26967716217041,11.130739212036133,11.871220588684082,11.11894416809082,10.497808456420898,11.341180801391602,11.120772361755371,11.832477569580078,11.945394515991211,10.435187339782715,10.123815536499023,11.923513412475586,11.165726661682129,10.511265754699707,10.816417694091797,11.620299339294434,11.749600410461426,11.621185302734375,10.726395606994629,10.873162269592285,11.366233825683594,11.751174926757812,11.712872505187988,11.940488815307617,11.489510536193848,11.349868774414062,11.479663848876953,11.586658477783203,10.732288360595703,10.799344062805176,10.621288299560547,11.430079460144043,11.587478637695312,11.59907054901123,10.702455520629883
8
+ meta-llama/Llama-2-7b-hf vs microsoft/Orca-2-7b,0.01,1.9139776229858398,11.634427070617676,11.600348472595215,11.60654067993164,12.126233100891113,12.646196365356445,12.883075714111328,12.256454467773438,12.270318984985352,12.459698677062988,10.608137130737305,12.05358600616455,11.529598236083984,10.475113868713379,11.170097351074219,11.863300323486328,10.970882415771484,10.59147834777832,12.343642234802246,12.012216567993164,11.054078102111816,10.566030502319336,11.191433906555176,11.72900390625,12.421318054199219,11.642358779907227,12.26218032836914,12.066629409790039,11.98330020904541,12.333147048950195,10.645333290100098,11.23974323272705,12.008330345153809,11.807231903076172,10.58747673034668,11.714077949523926,11.820113182067871,11.989992141723633,12.196735382080078,12.603166580200195,10.58354663848877,12.315600395202637,12.89755916595459,12.080608367919922,11.813095092773438,11.074599266052246,10.643448829650879,10.9359769821167,11.14858627319336,12.007444381713867,12.142953872680664,12.374232292175293,10.88144302368164,10.891057014465332,11.451530456542969,12.084410667419434,12.564109802246094,12.367560386657715,10.01160717010498,11.172256469726562,12.028122901916504,11.117118835449219,10.586153030395508,12.69567584991455,11.325706481933594,11.118499755859375,12.011213302612305,12.913552284240723,11.33607006072998,12.125946044921875,10.641644477844238,11.218090057373047,11.023408889770508,11.858912467956543,9.926396369934082,11.108742713928223,10.33670425415039,11.369423866271973,10.87497329711914,10.876391410827637,11.830551147460938,18.367767333984375,11.91503620147705,10.409882545471191,12.978384971618652,11.026498794555664,12.638835906982422,11.090004920959473,13.22242259979248,10.301204681396484,12.243881225585938,11.466031074523926,11.41460132598877,11.086315155029297,11.420076370239258,10.147997856140137,11.232333183288574,11.026019096374512,11.566559791564941,13.21921157836914,11.737604141235352
9
+ meta-llama/Llama-2-7b-hf vs LLM360/Amber,0.11,9.658585548400879,9.505361557006836,11.505193710327148,9.863297462463379,10.549816131591797,10.601881980895996,9.804191589355469,9.992196083068848,10.386860847473145,10.168055534362793,10.143446922302246,11.450798988342285,10.739060401916504,10.874429702758789,11.022444725036621,11.62452507019043,11.448281288146973,10.981934547424316,10.17294692993164,9.73409652709961,10.862936973571777,10.203863143920898,9.555386543273926,11.31235122680664,9.649775505065918,9.874133110046387,10.310741424560547,10.610264778137207,11.861855506896973,11.109381675720215,10.608086585998535,10.293082237243652,10.586475372314453,9.161870956420898,10.675275802612305,9.941357612609863,10.5984525680542,9.234950065612793,9.805939674377441,10.848678588867188,10.375737190246582,11.113080024719238,10.478139877319336,10.219255447387695,9.84744930267334,9.266841888427734,11.306220054626465,11.001436233520508,9.924046516418457,10.07392406463623,10.432853698730469,11.701499938964844,9.725397109985352,9.526941299438477,11.007448196411133,10.04012680053711,11.11932373046875,11.325552940368652,10.447799682617188,10.715643882751465,9.796828269958496,10.892861366271973,10.44837474822998,10.71353816986084,10.633293151855469,10.90569019317627,10.542643547058105,11.243886947631836,11.541812896728516,11.15402889251709,10.407221794128418,11.446313858032227,10.677804946899414,10.178149223327637,10.385024070739746,11.192070960998535,9.855805397033691,10.154035568237305,11.104714393615723,10.019867897033691,10.577777862548828,9.540035247802734,11.262989044189453,9.886129379272461,9.649831771850586,10.185017585754395,9.09082317352295,10.29433822631836,9.811910629272461,10.254436492919922,11.52804946899414,10.116106986999512,10.657572746276855,9.819936752319336,10.43431282043457,10.864540100097656,10.484803199768066,12.004101753234863,11.187883377075195,10.061746597290039,11.245307922363281
10
+ codellama/CodeLlama-7b-hf vs openlm-research/open_llama_7b,0.65,11.237561225891113,11.680026054382324,11.398880004882812,11.86768913269043,11.39999771118164,11.541326522827148,10.19598388671875,11.428864479064941,11.862239837646484,10.481870651245117,10.985757827758789,10.969829559326172,11.367602348327637,11.111282348632812,11.022823333740234,10.561090469360352,10.207883834838867,10.694847106933594,10.99021053314209,10.593587875366211,11.172009468078613,11.665566444396973,11.263614654541016,11.552456855773926,11.507055282592773,11.336711883544922,11.074647903442383,10.79491138458252,10.656916618347168,10.548526763916016,11.022597312927246,10.966035842895508,11.283071517944336,10.308378219604492,10.846853256225586,10.849564552307129,10.968177795410156,11.308629035949707,11.001319885253906,10.848873138427734,10.73318099975586,10.859678268432617,11.509345054626465,10.539803504943848,10.935742378234863,11.4722900390625,11.278949737548828,11.9168701171875,11.39941692352295,10.587061882019043,10.908806800842285,11.912999153137207,13.913973808288574,11.417015075683594,10.803816795349121,11.063304901123047,11.088665962219238,11.479676246643066,10.848278045654297,11.193321228027344,10.839192390441895,10.731671333312988,11.18687629699707,10.452286720275879,9.904109001159668,11.235068321228027,11.758543014526367,10.346553802490234,10.477143287658691,10.51911449432373,11.152212142944336,10.974753379821777,10.7435302734375,11.238974571228027,12.068460464477539,10.960865020751953,11.370020866394043,11.462693214416504,11.207311630249023,10.798894882202148,11.14282512664795,11.745903968811035,11.410775184631348,10.994516372680664,11.088298797607422,11.085737228393555,11.031476020812988,11.000276565551758,9.836112022399902,11.110745429992676,10.450307846069336,10.72428035736084,11.212058067321777,11.340717315673828,11.261013984680176,11.80449104309082,11.24316120147705,10.83240795135498,10.822101593017578,10.707273483276367,10.825920104980469
11
+ codellama/CodeLlama-7b-hf vs huggyllama/llama-7b,0.32,10.771794319152832,10.501058578491211,10.56005859375,11.231363296508789,10.989171028137207,11.534804344177246,12.136800765991211,11.693639755249023,10.567752838134766,11.91177749633789,11.43637466430664,11.564026832580566,11.265721321105957,11.19032096862793,10.40317153930664,10.623876571655273,10.287189483642578,10.42343521118164,10.866487503051758,12.013888359069824,10.774230003356934,11.138988494873047,11.338595390319824,11.346685409545898,10.583052635192871,10.429323196411133,10.492464065551758,11.78402328491211,10.30309009552002,11.499597549438477,12.184760093688965,11.097319602966309,11.184183120727539,11.21827507019043,10.121459007263184,11.127937316894531,10.251218795776367,11.494721412658691,10.646721839904785,10.4236478805542,11.158930778503418,9.909791946411133,10.339156150817871,11.350455284118652,10.859237670898438,10.905741691589355,12.035804748535156,11.207295417785645,10.787891387939453,11.013189315795898,11.2416410446167,10.195667266845703,11.020384788513184,11.056377410888672,11.036890029907227,10.251171112060547,10.708823204040527,11.038761138916016,10.733219146728516,10.966094017028809,10.910711288452148,10.949577331542969,11.406147003173828,11.047470092773438,9.963555335998535,10.648523330688477,11.16801643371582,11.87838363647461,11.598986625671387,10.867388725280762,12.05274486541748,11.79859733581543,11.681252479553223,10.888261795043945,11.057225227355957,10.08951187133789,11.208805084228516,10.740156173706055,11.282068252563477,11.669149398803711,11.980530738830566,10.39877700805664,11.507721900939941,11.044110298156738,11.520370483398438,10.161965370178223,11.56179428100586,10.723697662353516,10.831975936889648,11.531084060668945,10.410531997680664,11.085836410522461,12.549727439880371,11.945355415344238,10.689722061157227,10.73282241821289,11.869728088378906,12.093964576721191,12.069504737854004,11.068297386169434,11.401714324951172
12
+ codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.5,0.01,2.302018404006958,11.172709465026855,12.183158874511719,10.882681846618652,9.767948150634766,24.42896842956543,10.943366050720215,11.75478744506836,15.235170364379883,9.498678207397461,10.326377868652344,10.187347412109375,10.607508659362793,10.732197761535645,9.70353889465332,10.948262214660645,9.788002967834473,10.9060640335083,10.285189628601074,10.656235694885254,11.046913146972656,10.868155479431152,10.477668762207031,9.764632225036621,9.605854034423828,9.753314018249512,10.699603080749512,9.823118209838867,10.523367881774902,9.758964538574219,9.701371192932129,10.879630088806152,11.618856430053711,13.337488174438477,9.767634391784668,11.983025550842285,10.607478141784668,10.280527114868164,10.127568244934082,12.975506782531738,11.579995155334473,11.203207015991211,10.376364707946777,11.423201560974121,11.209802627563477,9.359248161315918,12.826610565185547,10.373741149902344,9.53094482421875,10.389079093933105,9.205557823181152,10.689602851867676,10.414274215698242,11.192754745483398,10.581353187561035,10.557870864868164,12.383353233337402,10.521531105041504,9.70833969116211,11.450300216674805,10.21054458618164,11.530344009399414,10.224842071533203,10.37917709350586,11.495973587036133,10.469186782836914,9.77238941192627,9.884854316711426,10.01948356628418,9.838289260864258,11.130271911621094,10.648575782775879,10.440532684326172,10.789712905883789,11.31021785736084,10.711143493652344,10.09947681427002,11.175568580627441,9.295597076416016,9.529111862182617,11.034727096557617,9.88634204864502,10.545726776123047,10.429464340209961,10.972481727600098,9.342894554138184,11.735199928283691,9.926610946655273,10.638096809387207,11.533238410949707,11.10690975189209,11.147697448730469,10.123966217041016,9.807247161865234,9.888313293457031,10.178050994873047,10.791872024536133,10.62413501739502,9.422971725463867,9.87336540222168,10.158583641052246
13
+ codellama/CodeLlama-7b-hf vs EleutherAI/llemma_7b,0.01,1.8881051540374756,11.02577018737793,11.569954872131348,9.650038719177246,12.788570404052734,11.758131980895996,9.877432823181152,10.481515884399414,11.251947402954102,11.191924095153809,11.158823013305664,9.50117015838623,11.658199310302734,10.220966339111328,11.156086921691895,11.29964828491211,11.514382362365723,11.962372779846191,10.005611419677734,10.904683113098145,10.441763877868652,10.319205284118652,11.140127182006836,9.660285949707031,10.85449504852295,10.57632827758789,10.08746337890625,11.422477722167969,11.425626754760742,10.719733238220215,11.198336601257324,9.705710411071777,11.163413047790527,10.196398735046387,12.473783493041992,11.222222328186035,10.624934196472168,11.31132984161377,12.355788230895996,12.014126777648926,10.066976547241211,9.984125137329102,11.545228958129883,10.415767669677734,11.360154151916504,10.375480651855469,11.507826805114746,11.202043533325195,10.294898986816406,9.63949203491211,10.868791580200195,11.535318374633789,10.94536018371582,10.92975902557373,11.955968856811523,10.462796211242676,10.789320945739746,9.42170524597168,11.375829696655273,9.957808494567871,10.590437889099121,11.810711860656738,10.596860885620117,10.35020637512207,10.983830451965332,11.902685165405273,12.088788986206055,10.547697067260742,11.154755592346191,10.342395782470703,9.408880233764648,10.353204727172852,10.178812980651855,11.717917442321777,10.343524932861328,9.107279777526855,11.441140174865723,12.437472343444824,9.229593276977539,9.995097160339355,13.198864936828613,10.668238639831543,9.889089584350586,10.291954040527344,9.115019798278809,11.644171714782715,9.289207458496094,10.126248359680176,11.402484893798828,10.980873107910156,9.706541061401367,10.721711158752441,10.652327537536621,9.227180480957031,11.187313079833984,11.266683578491211,10.421162605285645,10.27970027923584,11.191493034362793,11.59189224243164,10.865525245666504
14
+ codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.1,0.29,10.684311866760254,9.88802433013916,11.366555213928223,10.514373779296875,11.024654388427734,10.283346176147461,11.294032096862793,11.239091873168945,11.904309272766113,11.442981719970703,9.724740028381348,10.486740112304688,10.531262397766113,10.704021453857422,11.958599090576172,11.913610458374023,10.451040267944336,11.539121627807617,11.079904556274414,10.380290985107422,10.464911460876465,10.824151039123535,10.613722801208496,10.818387031555176,10.909002304077148,11.383553504943848,11.057758331298828,10.797572135925293,11.538613319396973,10.077716827392578,11.08383846282959,10.961997985839844,10.51244068145752,10.801470756530762,11.105854034423828,10.700427055358887,10.818270683288574,10.551637649536133,10.831416130065918,11.392118453979492,11.99124813079834,11.624017715454102,11.47986125946045,11.073309898376465,10.576276779174805,11.65388298034668,10.442636489868164,10.892948150634766,10.495064735412598,11.371464729309082,12.704389572143555,10.441764831542969,11.860435485839844,10.724827766418457,10.910114288330078,9.940642356872559,11.6974515914917,10.893054008483887,10.342144012451172,10.850858688354492,11.492060661315918,10.491785049438477,12.155282974243164,11.26374340057373,10.091155052185059,10.971919059753418,10.889839172363281,9.859522819519043,11.616873741149902,11.444681167602539,10.922995567321777,11.157459259033203,10.100933074951172,10.804597854614258,10.976712226867676,11.442182540893555,11.178875923156738,11.096498489379883,11.058975219726562,11.280542373657227,11.896692276000977,11.171963691711426,10.79887580871582,11.429341316223145,11.694151878356934,12.15854549407959,11.003829002380371,11.078950881958008,11.429558753967285,11.451701164245605,11.222574234008789,11.16943073272705,10.59616756439209,9.994879722595215,11.877211570739746,12.13260555267334,10.476194381713867,9.788568496704102,9.951860427856445,12.063776016235352,10.534578323364258
15
+ codellama/CodeLlama-7b-hf vs microsoft/Orca-2-7b,0.01,2.320063829421997,9.977145195007324,11.017745018005371,11.732516288757324,10.55106258392334,10.886062622070312,11.940051078796387,10.490015029907227,11.409379959106445,9.593038558959961,9.326103210449219,11.52270221710205,10.925615310668945,11.000499725341797,11.391313552856445,10.632598876953125,10.968502044677734,9.419816970825195,11.293421745300293,10.746554374694824,10.45264720916748,9.733240127563477,9.230606079101562,11.2017240524292,11.673495292663574,9.883476257324219,14.689058303833008,10.65975284576416,11.826485633850098,9.935537338256836,10.098958969116211,11.369180679321289,10.066923141479492,10.007286071777344,10.073885917663574,10.076384544372559,10.496053695678711,9.437036514282227,10.592652320861816,9.364110946655273,10.822487831115723,10.531739234924316,10.315073013305664,9.872851371765137,10.577094078063965,10.572772026062012,10.358898162841797,10.806140899658203,16.713293075561523,10.911810874938965,9.414806365966797,11.51042366027832,10.734992027282715,12.246551513671875,11.260639190673828,10.792762756347656,10.72775936126709,10.830528259277344,11.213608741760254,12.049139976501465,9.823256492614746,9.768250465393066,10.623906135559082,11.0609712600708,10.343968391418457,8.824793815612793,10.778313636779785,10.15536880493164,10.200857162475586,11.061370849609375,10.520877838134766,8.951927185058594,10.345824241638184,11.1963472366333,10.408754348754883,9.833317756652832,9.915206909179688,11.101364135742188,11.677738189697266,11.327123641967773,10.630494117736816,9.27521800994873,10.425193786621094,10.057621955871582,10.960726737976074,11.673463821411133,10.601761817932129,10.41421127319336,10.354918479919434,11.0147705078125,11.991089820861816,9.368815422058105,8.995253562927246,11.107773780822754,10.934799194335938,10.410724639892578,10.493441581726074,10.318268775939941,9.758223533630371,9.74752426147461,10.921327590942383
16
+ codellama/CodeLlama-7b-hf vs LLM360/Amber,0.01,9.603846549987793,10.922776222229004,10.763627052307129,10.674799919128418,10.944330215454102,10.779414176940918,10.923018455505371,9.871673583984375,10.166775703430176,10.810501098632812,10.085100173950195,10.842877388000488,10.339820861816406,10.418185234069824,9.797157287597656,11.204648971557617,11.264816284179688,11.797182083129883,10.71696662902832,10.645628929138184,11.203916549682617,10.131481170654297,10.013680458068848,10.839400291442871,11.125710487365723,10.923445701599121,11.895150184631348,10.997817993164062,11.030921936035156,10.626928329467773,10.994245529174805,10.47247314453125,10.476496696472168,10.737053871154785,10.940540313720703,10.545219421386719,10.65115737915039,10.538052558898926,10.73913860321045,10.608595848083496,11.778488159179688,10.629700660705566,11.484350204467773,10.728248596191406,11.018856048583984,10.530044555664062,10.73550796508789,11.228963851928711,10.074200630187988,11.778253555297852,10.75644588470459,10.916414260864258,10.685816764831543,11.022653579711914,11.845925331115723,11.053071975708008,10.963454246520996,11.932684898376465,11.024358749389648,10.258270263671875,11.10335636138916,11.186234474182129,11.418861389160156,10.6289644241333,11.077651023864746,10.978907585144043,11.048869132995605,11.65108585357666,10.4456787109375,10.570917129516602,10.814337730407715,10.753689765930176,11.621301651000977,10.818181037902832,11.07013988494873,11.709806442260742,10.042532920837402,11.51170825958252,10.287849426269531,11.025697708129883,10.106474876403809,10.335064888000488,10.63949966430664,11.414520263671875,9.983678817749023,10.194723129272461,10.093085289001465,10.348984718322754,10.452190399169922,10.118818283081055,11.143959999084473,10.194239616394043,10.520627975463867,9.854783058166504,11.553059577941895,10.919979095458984,10.652848243713379,11.446795463562012,11.503812789916992,11.831265449523926,10.060680389404297
17
+ openlm-research/open_llama_7b vs huggyllama/llama-7b,0.12,10.616796493530273,10.797248840332031,10.766523361206055,10.866601943969727,11.188263893127441,10.954423904418945,11.125080108642578,11.481340408325195,10.507387161254883,10.983829498291016,11.449170112609863,11.820627212524414,10.870183944702148,11.375383377075195,11.897957801818848,10.795915603637695,11.425924301147461,10.779943466186523,10.608522415161133,11.291844367980957,10.9775972366333,11.396665573120117,11.077995300292969,10.915323257446289,10.149968147277832,10.838491439819336,11.376053810119629,12.027775764465332,11.226425170898438,11.149484634399414,11.914155006408691,11.06672191619873,12.116933822631836,11.715675354003906,11.286850929260254,10.582368850708008,10.894996643066406,10.559850692749023,11.138714790344238,11.34919548034668,11.406577110290527,11.709674835205078,10.768847465515137,11.264957427978516,11.060179710388184,10.48735523223877,11.610673904418945,10.382858276367188,11.576861381530762,11.695962905883789,11.838735580444336,11.8810396194458,11.433664321899414,10.812098503112793,11.260259628295898,10.915751457214355,10.764631271362305,11.118894577026367,11.433198928833008,11.013741493225098,11.012408256530762,11.988961219787598,10.893074989318848,11.14090633392334,11.564040184020996,11.033992767333984,11.447514533996582,10.347433090209961,11.568410873413086,11.325068473815918,11.319805145263672,11.36147689819336,10.636981010437012,10.374312400817871,11.18582820892334,11.290904998779297,11.412924766540527,11.148937225341797,11.629731178283691,10.574426651000977,11.250195503234863,11.57218074798584,10.607427597045898,11.080497741699219,11.073441505432129,11.367008209228516,11.016765594482422,11.250604629516602,10.794018745422363,11.461509704589844,12.037816047668457,10.851967811584473,11.706341743469238,10.905144691467285,12.039907455444336,10.87762451171875,11.777667999267578,11.1010160446167,11.086259841918945,10.743986129760742,11.00861930847168
18
+ openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.5,0.68,11.10280704498291,11.494667053222656,10.498247146606445,11.133041381835938,11.101852416992188,11.380467414855957,10.272819519042969,10.971842765808105,10.812454223632812,10.571216583251953,11.184111595153809,10.600080490112305,10.708064079284668,11.08830738067627,11.812434196472168,11.79877758026123,10.77804946899414,10.4550142288208,10.40676498413086,10.906953811645508,10.607868194580078,11.703605651855469,11.5511474609375,9.885051727294922,12.578688621520996,10.805289268493652,10.065228462219238,10.415675163269043,9.463606834411621,10.35334587097168,10.453519821166992,11.14851188659668,11.603013038635254,11.39593505859375,10.219343185424805,10.553191184997559,11.516910552978516,11.12913990020752,10.367042541503906,11.04174518585205,10.621479988098145,10.78134536743164,10.831048011779785,10.864336967468262,10.25527572631836,10.48643970489502,10.815017700195312,10.441289901733398,10.695341110229492,12.055279731750488,11.352632522583008,11.06913948059082,11.00218677520752,10.366230010986328,10.634378433227539,11.0018949508667,11.126782417297363,9.731849670410156,10.59158706665039,11.318184852600098,10.395101547241211,11.940820693969727,10.694389343261719,11.356606483459473,10.950284957885742,11.315900802612305,10.90494155883789,11.158031463623047,11.424098014831543,10.433701515197754,11.031879425048828,10.566242218017578,10.624926567077637,11.704424858093262,9.559603691101074,10.54096508026123,10.466423034667969,10.139713287353516,10.190888404846191,10.859702110290527,11.18995475769043,11.700125694274902,10.920588493347168,10.214404106140137,11.712258338928223,11.302619934082031,11.045059204101562,21.055919647216797,10.661722183227539,10.79277515411377,11.078306198120117,10.720754623413086,11.443395614624023,10.91203498840332,11.552059173583984,9.76099967956543,10.945575714111328,10.874262809753418,11.279154777526855,9.812013626098633,11.022870063781738
19
+ openlm-research/open_llama_7b vs EleutherAI/llemma_7b,0.74,11.676680564880371,10.700831413269043,10.920976638793945,11.028438568115234,12.167314529418945,11.756142616271973,11.685977935791016,10.584118843078613,11.082442283630371,11.074419975280762,11.608972549438477,12.379051208496094,10.71125316619873,10.557001113891602,11.452447891235352,11.275219917297363,11.134586334228516,11.230876922607422,12.451685905456543,10.93018913269043,11.579869270324707,11.084054946899414,12.037456512451172,11.611255645751953,11.685752868652344,11.016987800598145,11.213319778442383,11.783517837524414,11.0402250289917,10.860733032226562,11.379624366760254,11.247736930847168,10.858418464660645,11.460200309753418,11.808459281921387,11.439626693725586,12.029986381530762,10.962839126586914,11.192753791809082,11.062772750854492,11.25781536102295,11.559438705444336,11.687849044799805,10.75302791595459,12.29854965209961,11.086894989013672,12.160594940185547,11.388923645019531,10.720033645629883,11.752241134643555,11.075825691223145,11.455404281616211,11.496362686157227,10.872966766357422,11.221406936645508,12.037018775939941,11.722471237182617,11.927973747253418,10.33534049987793,10.593981742858887,10.86403751373291,11.836677551269531,11.976253509521484,11.691933631896973,10.841435432434082,10.745210647583008,11.37106990814209,12.00289535522461,11.43269157409668,12.220749855041504,11.051970481872559,11.547199249267578,10.92629337310791,11.815045356750488,11.536211967468262,11.504477500915527,11.529980659484863,11.00515365600586,11.090625762939453,11.51578426361084,11.141458511352539,11.392504692077637,11.342756271362305,11.645566940307617,11.101119995117188,11.657102584838867,11.088163375854492,11.384390830993652,11.717936515808105,11.144939422607422,11.418688774108887,11.375986099243164,10.539258003234863,11.37796401977539,12.069337844848633,11.964790344238281,11.09656047821045,10.646003723144531,11.4827880859375,11.113225936889648,11.435702323913574
20
+ openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.1,0.12,10.621867179870605,10.42145824432373,11.195003509521484,11.484685897827148,10.794024467468262,11.019747734069824,11.490104675292969,11.03162670135498,10.776253700256348,11.319022178649902,11.029229164123535,10.46848201751709,11.557026863098145,11.630350112915039,10.21575927734375,10.254093170166016,11.132651329040527,10.749253273010254,11.896352767944336,11.851204872131348,11.208588600158691,11.714153289794922,10.779157638549805,11.057229995727539,11.114214897155762,10.616106033325195,11.634489059448242,11.256133079528809,10.997373580932617,12.223559379577637,11.170015335083008,11.057348251342773,11.103898048400879,11.418852806091309,11.480304718017578,11.226197242736816,11.045890808105469,11.04325008392334,10.816136360168457,10.4801025390625,11.184416770935059,11.797572135925293,11.668251037597656,10.831574440002441,11.632143020629883,11.781684875488281,10.77318000793457,10.571830749511719,11.864724159240723,11.191770553588867,11.520462989807129,10.72339153289795,11.068861961364746,11.573631286621094,10.899574279785156,11.674015998840332,10.131872177124023,11.504653930664062,11.00288200378418,10.962963104248047,11.146491050720215,11.226775169372559,10.69074821472168,11.214972496032715,10.712320327758789,10.820131301879883,11.37641429901123,11.639378547668457,11.570959091186523,11.473939895629883,10.993010520935059,11.848175048828125,11.129470825195312,11.361347198486328,11.544088363647461,11.482494354248047,11.397720336914062,11.10208797454834,10.954719543457031,11.292170524597168,10.730844497680664,11.10391616821289,10.9358491897583,11.644660949707031,11.556461334228516,10.230692863464355,10.93234634399414,11.048127174377441,11.356008529663086,11.290736198425293,10.155138969421387,11.609221458435059,11.059803009033203,10.940677642822266,10.850116729736328,11.106929779052734,10.82397747039795,11.759023666381836,10.195236206054688,11.401366233825684,11.2745943069458
21
+ openlm-research/open_llama_7b vs microsoft/Orca-2-7b,0.85,11.351534843444824,10.584617614746094,10.774704933166504,10.837328910827637,11.801148414611816,11.330574035644531,11.290903091430664,10.862215042114258,11.384349822998047,10.745942115783691,11.390098571777344,11.086434364318848,11.015809059143066,9.748941421508789,10.519133567810059,11.45274829864502,10.552149772644043,11.926155090332031,10.675073623657227,11.253833770751953,10.66576862335205,9.94119644165039,19.171932220458984,10.226914405822754,12.355240821838379,11.018232345581055,10.890312194824219,10.387739181518555,11.143223762512207,12.016646385192871,10.465130805969238,10.726633071899414,11.077333450317383,10.403158187866211,11.529679298400879,10.922691345214844,10.887194633483887,11.203007698059082,10.525662422180176,10.012253761291504,9.828682899475098,10.859309196472168,10.792906761169434,10.633329391479492,10.120866775512695,10.508658409118652,12.094136238098145,10.804570198059082,10.728404998779297,10.672419548034668,12.785676002502441,9.50652027130127,11.129883766174316,11.35869312286377,11.06894302368164,9.344609260559082,10.127263069152832,10.435647010803223,10.6591215133667,11.319711685180664,10.950030326843262,10.885222434997559,10.239350318908691,10.612761497497559,11.15457534790039,10.812723159790039,10.070918083190918,10.321710586547852,11.3038969039917,11.07686996459961,10.517918586730957,9.845905303955078,10.563337326049805,10.056536674499512,11.044584274291992,11.334074974060059,10.442049026489258,10.59984302520752,10.307365417480469,11.638136863708496,10.443572044372559,10.811623573303223,10.837541580200195,10.584928512573242,10.582640647888184,10.057901382446289,10.074625968933105,11.10626220703125,11.195136070251465,10.514017105102539,10.031493186950684,9.643144607543945,11.793606758117676,10.653993606567383,10.19979190826416,11.364167213439941,10.87149715423584,11.284507751464844,9.864716529846191,10.578239440917969,10.145064353942871
model-tracing/results/perm/permutation_norm_loss_midpoint_wikitext_single.csv ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Model Pair,norm loss p-value,unpermuted norm loss,permuted norm losses
2
+ meta-llama/Llama-2-7b-hf vs codellama/CodeLlama-7b-hf,0.01,0.3834942579269409,10.473750472068787,8.49558675289154,9.522793173789978,8.34878385066986,8.301295638084412,8.16095769405365,8.466469168663025,9.06697404384613,8.272103667259216,8.733596205711365,8.649529814720154,9.094684958457947,9.377073645591736,8.687456488609314,8.821428656578064,8.725521445274353,8.08456552028656,9.226669669151306,7.609917044639587,8.21532666683197,9.57900846004486,8.550759673118591,9.102148413658142,7.39217221736908,9.121253371238708,10.159769415855408,9.260996222496033,7.9118064641952515,9.355074286460876,8.4481920003891,9.43545377254486,8.129388213157654,8.085391402244568,7.9564841985702515,9.373242735862732,8.900659918785095,8.313831686973572,8.388159155845642,7.573407530784607,8.961812376976013,9.08068311214447,8.496608138084412,8.708313345909119,8.426417708396912,8.470783591270447,8.094282507896423,8.94940984249115,8.751403212547302,8.596155524253845,9.277466177940369,7.896311163902283,9.386301398277283,7.846903204917908,8.443647742271423,9.592914938926697,9.203054785728455,8.75339925289154,9.959086775779724,8.708863615989685,10.180801749229431,8.62676465511322,9.586245894432068,7.8447030782699585,9.516730666160583,8.306718230247498,8.947573065757751,8.558817267417908,8.19304883480072,8.265913367271423,8.223491072654724,8.609399199485779,9.782116293907166,7.843777060508728,7.73509156703949,7.896917700767517,8.163294196128845,8.727386832237244,9.163933157920837,9.018080115318298,8.103793501853943,8.438563704490662,8.779285788536072,8.432084441184998,9.660460829734802,7.826008200645447,9.144956946372986,9.95290219783783,8.358298659324646,7.759764075279236,8.92118489742279,9.001880049705505,8.570141196250916,8.6590656042099,8.832827925682068,8.5744069814682,8.586038947105408,7.335496306419373,8.876983046531677,8.723318457603455,7.9766563177108765
3
+ meta-llama/Llama-2-7b-hf vs openlm-research/open_llama_7b,0.64,6.713594198226929,6.539518117904663,6.362017393112183,6.602225065231323,6.590438604354858,5.593794584274292,6.113854169845581,6.726726293563843,7.590032339096069,7.062256574630737,7.381011724472046,6.4157397747039795,6.701626539230347,6.820826292037964,6.987111806869507,5.584487676620483,7.162648916244507,6.6456458568573,5.595097303390503,6.483163595199585,6.129888296127319,6.727576971054077,6.245026350021362,6.470016241073608,6.548055410385132,6.394561529159546,6.915935277938843,6.33721137046814,6.654784917831421,6.176393270492554,7.010505437850952,6.428812742233276,6.361801862716675,6.373377561569214,6.637976408004761,6.6545398235321045,6.7396533489227295,6.010260343551636,6.91727614402771,7.3185670375823975,6.170457601547241,6.550145864486694,6.954936742782593,7.021528005599976,6.8889853954315186,6.315131902694702,6.095242261886597,10.108835935592651,6.357076406478882,7.430655241012573,6.448450803756714,7.0982749462127686,6.4298255443573,6.84966254234314,7.4488685131073,6.736899137496948,6.446202993392944,6.604178190231323,7.090266942977905,6.323678731918335,6.3112428188323975,6.28790545463562,6.598105192184448,6.373114347457886,6.431990385055542,6.2564332485198975,7.246878385543823,7.38981032371521,6.214457273483276,6.802284002304077,7.536787748336792,6.634334325790405,6.960758924484253,6.608479261398315,6.47892165184021,7.350800275802612,6.118292570114136,6.7180116176605225,6.2622621059417725,6.792338132858276,6.398637533187866,6.530109167098999,6.031400442123413,6.639722585678101,7.0282533168792725,6.078140020370483,6.023447751998901,6.666522741317749,6.641700506210327,6.949050664901733,6.748633146286011,6.9641406536102295,6.412523984909058,5.949262380599976,6.476415395736694,6.683224439620972,6.442394018173218,6.793895483016968,6.41438364982605,6.0649449825286865,5.552734136581421
4
+ meta-llama/Llama-2-7b-hf vs huggyllama/llama-7b,0.99,10.224812746047974,9.681059122085571,9.79641842842102,9.143628358840942,8.902554750442505,8.581402063369751,9.70039963722229,8.519185304641724,9.58426308631897,8.995373964309692,9.070337533950806,8.949082612991333,9.619084596633911,9.29295563697815,9.440018892288208,9.70763897895813,9.497541666030884,9.849196672439575,9.820314645767212,9.195354700088501,9.313969850540161,9.354400873184204,9.663116693496704,9.564915895462036,8.243352174758911,8.451400995254517,9.354562044143677,9.870220422744751,8.735538721084595,8.980741739273071,9.208681344985962,9.174977540969849,9.0218985080719,9.463035821914673,8.296002626419067,8.97838044166565,9.827713251113892,8.70076298713684,9.400816202163696,9.546863794326782,8.84546971321106,8.821017503738403,9.502385377883911,9.446989297866821,9.236077547073364,8.610085725784302,8.864148378372192,8.536124467849731,8.738917589187622,9.773377656936646,9.669754266738892,9.055151224136353,8.594401597976685,8.988518953323364,9.397208452224731,8.896766901016235,9.421325922012329,9.172155618667603,8.620713472366333,9.184304475784302,9.626896142959595,8.850933313369751,9.680114984512329,8.70331883430481,9.014384508132935,9.21541714668274,9.671802759170532,8.865594148635864,8.037382364273071,8.281428575515747,8.900898218154907,8.31492829322815,8.984730005264282,8.803801774978638,8.632014513015747,8.735124826431274,8.547569513320923,8.782721757888794,8.775543451309204,9.653815507888794,8.892679452896118,9.198973894119263,9.146803140640259,9.016650438308716,8.820218324661255,9.266303300857544,10.214689493179321,8.883363008499146,9.254101037979126,8.992080926895142,9.562584161758423,9.456084489822388,8.60515809059143,9.141083002090454,8.885345697402954,9.704216241836548,8.497487306594849,8.425892114639282,8.096463441848755,9.380361795425415,9.764080286026001
5
+ meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.5,0.01,-0.048818111419677734,9.844830870628357,9.222102522850037,7.755062460899353,8.265713095664978,9.013955473899841,9.327821135520935,9.123439192771912,10.4435476064682,10.395680785179138,9.183204054832458,9.165072798728943,9.568477988243103,10.169446349143982,10.477945685386658,9.97354257106781,9.749969840049744,8.718470931053162,9.266455054283142,8.84190022945404,11.006096243858337,8.343526244163513,10.545372366905212,10.468126654624939,11.031267523765564,8.920008063316345,10.785050749778748,10.329222083091736,8.974294066429138,9.432084441184998,10.732448935508728,10.256052374839783,10.556835532188416,8.810878157615662,8.968798995018005,9.491623282432556,8.424806952476501,8.465443968772888,10.03244435787201,9.752468466758728,9.78839910030365,7.284629225730896,7.684324622154236,8.825533270835876,8.247802138328552,8.404919981956482,7.69716489315033,9.403735518455505,9.254828810691833,17.508076071739197,9.856340765953064,9.115726828575134,8.986517310142517,7.983991026878357,8.65701711177826,8.509859442710876,8.433510184288025,10.006139159202576,7.988286375999451,10.91031014919281,8.459594130516052,10.134988188743591,8.992516875267029,9.247894644737244,9.8466295003891,10.350879073143005,8.965574622154236,8.360766768455505,8.565540671348572,9.279970526695251,10.484707236289978,9.109469771385193,9.356310248374939,9.331755995750427,9.77147901058197,8.712384581565857,9.521596312522888,8.963900923728943,6.773027777671814,10.8277028799057,9.908670783042908,9.5965017080307,7.920130133628845,9.397729277610779,8.551138281822205,8.605363249778748,8.879813551902771,9.17821729183197,8.353206038475037,9.68158757686615,9.635061621665955,9.397002577781677,9.553449034690857,9.10447347164154,7.935643553733826,9.673168540000916,9.363352179527283,9.41775357723236,8.803840041160583,8.780441641807556,10.165222525596619
6
+ meta-llama/Llama-2-7b-hf vs EleutherAI/llemma_7b,0.01,0.44821369647979736,10.260111451148987,7.788133263587952,9.624961495399475,10.045899987220764,7.3227468729019165,8.576054215431213,8.724260926246643,7.721198678016663,9.501295685768127,9.792026162147522,8.662120461463928,9.696664452552795,9.72442877292633,10.198912262916565,7.810206055641174,7.485549569129944,10.641865372657776,7.784129738807678,10.568729996681213,9.2378431558609,8.15902864933014,7.804042458534241,8.089019417762756,8.3358074426651,8.909724831581116,8.492894768714905,8.694433808326721,8.972641587257385,10.243723511695862,7.02272093296051,9.183990120887756,7.371614098548889,7.766700387001038,7.575039505958557,7.943362832069397,7.547324776649475,9.165839791297913,8.906821846961975,8.444319367408752,7.65238630771637,8.072578072547913,8.772894501686096,8.609686493873596,9.942311882972717,8.98726236820221,8.986947655677795,7.657752633094788,7.5777870416641235,10.219603180885315,8.73482859134674,10.140299439430237,9.045146584510803,9.87434446811676,8.298058152198792,9.576706528663635,8.261414170265198,7.007310509681702,9.08191454410553,8.166513085365295,8.851847290992737,7.7886539697647095,8.51626741886139,10.726054787635803,7.131115555763245,8.866894364356995,8.980855584144592,10.537903428077698,8.707924485206604,9.1041761636734,8.35569155216217,7.90923273563385,7.558130860328674,8.815556168556213,9.084401726722717,10.736987709999084,7.036646485328674,8.363484978675842,8.093238472938538,10.898166298866272,9.313558220863342,8.597114205360413,8.617847084999084,8.128695130348206,9.646315217018127,13.442176461219788,10.456626534461975,9.014615654945374,8.619638085365295,9.138837456703186,9.612112641334534,7.738902688026428,7.245580315589905,7.807826638221741,10.858358025550842,8.273086190223694,8.117148041725159,10.628657937049866,7.735168099403381,9.400915741920471,8.798778176307678
7
+ meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.1,0.99,10.070954203605652,9.98596465587616,7.872992396354675,9.286036372184753,8.555332064628601,8.711309313774109,9.428610682487488,9.328728556632996,8.926531672477722,9.093182444572449,8.60245406627655,8.518800616264343,8.683110117912292,8.410743594169617,8.905894160270691,9.41090476512909,8.997403025627136,9.453501582145691,9.80195701122284,9.018198847770691,7.704877734184265,8.515411257743835,9.171860575675964,9.41734778881073,8.338318705558777,9.661895632743835,8.913832545280457,9.37841022014618,9.808799624443054,8.391430735588074,9.685980677604675,9.192590594291687,8.211149096488953,8.656869769096375,8.958290934562683,8.171902537345886,8.788528323173523,9.031162142753601,8.842985033988953,8.647503733634949,9.641290545463562,9.138723254203796,9.051908373832703,9.315989375114441,9.271025538444519,8.452067255973816,8.985319018363953,9.565594553947449,9.689806818962097,8.82832133769989,9.021748423576355,9.408458590507507,8.873647570610046,7.9212998151779175,9.294074892997742,9.887571215629578,9.798232913017273,8.214550852775574,9.443087458610535,9.211515307426453,8.687570452690125,9.068241953849792,9.206690669059753,10.350232005119324,9.743900179862976,9.317789912223816,9.163899302482605,9.024961352348328,9.765442728996277,9.013166308403015,8.392030596733093,9.235402941703796,9.014994502067566,9.726699709892273,9.839616656303406,8.32940948009491,8.018037676811218,9.81773555278778,9.059948801994324,8.405487895011902,8.710639834403992,9.514521479606628,9.64382255077362,9.51540744304657,8.620617747306824,8.76738440990448,9.260455965995789,9.645397067070007,9.607094645500183,9.834710955619812,9.383732676506042,9.244090914726257,9.373885989189148,9.480880618095398,8.626510500907898,8.69356620311737,8.515510439872742,9.324301600456238,9.481700778007507,9.493292689323425,8.596677660942078
8
+ meta-llama/Llama-2-7b-hf vs microsoft/Orca-2-7b,0.01,-0.10828566551208496,9.612163782119751,9.57808518409729,9.584277391433716,10.103969812393188,10.62393307685852,10.860812425613403,10.234191179275513,10.248055696487427,10.437435388565063,8.58587384223938,10.031322717666626,9.50733494758606,8.452850580215454,9.147834062576294,9.841037034988403,8.94861912727356,8.569215059280396,10.321378946304321,9.98995327949524,9.031814813613892,8.543767213821411,9.169170618057251,9.706740617752075,10.399054765701294,9.620095491409302,10.239917039871216,10.044366121292114,9.961036920547485,10.31088376045227,8.623070001602173,9.217479944229126,9.986067056655884,9.784968614578247,8.565213441848755,9.691814661026001,9.797849893569946,9.967728853225708,10.174472093582153,10.58090329170227,8.561283349990845,10.293337106704712,10.875295877456665,10.058345079421997,9.790831804275513,9.052335977554321,8.621185541152954,8.913713693618774,9.126322984695435,9.985181093215942,10.12069058418274,10.351969003677368,8.859179735183716,8.868793725967407,9.429267168045044,10.062147378921509,10.541846513748169,10.34529709815979,7.989343881607056,9.149993181228638,10.005859613418579,9.094855546951294,8.563889741897583,10.673412561416626,9.303443193435669,9.09623646736145,9.98895001411438,10.891288995742798,9.313806772232056,10.10368275642395,8.619381189346313,9.195826768875122,9.001145601272583,9.836649179458618,7.904133081436157,9.086479425430298,8.314440965652466,9.347160577774048,8.852710008621216,8.854128122329712,9.808287858963013,16.34550404548645,9.892772912979126,8.387619256973267,10.956121683120728,9.00423550605774,10.616572618484497,9.067741632461548,11.200159311294556,8.27894139289856,10.221617937088013,9.443767786026001,9.392338037490845,9.064051866531372,9.397813081741333,8.125734567642212,9.21006989479065,9.003755807876587,9.544296503067017,11.196948289871216,9.715340852737427
9
+ meta-llama/Llama-2-7b-hf vs LLM360/Amber,0.11,7.894630134105682,7.741406142711639,9.741238296031952,8.099342048168182,8.7858607172966,8.8379265666008,8.040236175060272,8.228240668773651,8.622905433177948,8.404100120067596,8.37949150800705,9.686843574047089,8.975104987621307,9.110474288463593,9.258489310741425,9.860569655895233,9.684325873851776,9.21797913312912,8.408991515636444,7.970141112804413,9.09898155927658,8.439907729625702,7.791431128978729,9.548395812511444,7.885820090770721,8.11017769575119,8.54678601026535,8.84630936384201,10.097900092601776,9.345426261425018,8.844131171703339,8.529126822948456,8.822519958019257,7.397915542125702,8.911320388317108,8.177402198314667,8.834497153759003,7.470994651317596,8.041984260082245,9.084723174571991,8.611781775951385,9.349124610424042,8.71418446302414,8.455300033092499,8.083493888378143,7.502886474132538,9.542264640331268,9.237480819225311,8.16009110212326,8.309968650341034,8.668898284435272,9.937544524669647,7.961441695690155,7.76298588514328,9.243492782115936,8.276171386241913,9.355368316173553,9.561597526073456,8.683844268321991,8.951688468456268,8.0328728556633,9.128905951976776,8.684419333934784,8.949582755565643,8.869337737560272,9.141734778881073,8.778688132762909,9.47993153333664,9.777857482433319,9.390073478221893,8.643266379833221,9.68235844373703,8.913849532604218,8.41419380903244,8.62106865644455,9.428115546703339,8.091849982738495,8.390080153942108,9.340758979320526,8.255912482738495,8.813822448253632,7.776079833507538,9.499033629894257,8.122173964977264,7.885876357555389,8.421062171459198,7.326867759227753,8.530382812023163,8.047955214977264,8.490481078624725,9.764094054698944,8.352151572704315,8.893617331981659,8.05598133802414,8.670357406139374,9.10058468580246,8.72084778547287,10.240146338939667,9.423927962779999,8.297791182994843,9.481352508068085
10
+ codellama/CodeLlama-7b-hf vs openlm-research/open_llama_7b,0.65,6.28434944152832,6.726814270019531,6.4456682205200195,6.914477348327637,6.446785926818848,6.5881147384643555,5.242772102355957,6.475652694702148,6.909028053283691,5.528658866882324,6.032546043395996,6.016617774963379,6.414390563964844,6.1580705642700195,6.069611549377441,5.607878684997559,5.254672050476074,5.741635322570801,6.036998748779297,5.640376091003418,6.21879768371582,6.71235466003418,6.310402870178223,6.599245071411133,6.5538434982299805,6.383500099182129,6.12143611907959,5.841699600219727,5.703704833984375,5.595314979553223,6.069385528564453,6.012824058532715,6.329859733581543,5.355166435241699,5.893641471862793,5.896352767944336,6.014966011047363,6.355417251586914,6.048108100891113,5.895661354064941,5.779969215393066,5.906466484069824,6.556133270263672,5.586591720581055,5.98253059387207,6.519078254699707,6.325737953186035,6.963658332824707,6.446205139160156,5.63385009765625,5.955595016479492,6.959787368774414,8.960762023925781,6.463803291320801,5.850605010986328,6.110093116760254,6.135454177856445,6.526464462280273,5.895066261291504,6.240109443664551,5.885980606079102,5.778459548950195,6.233664512634277,5.499074935913086,4.950897216796875,6.281856536865234,6.805331230163574,5.393342018127441,5.523931503295898,5.5659027099609375,6.199000358581543,6.021541595458984,5.790318489074707,6.285762786865234,7.115248680114746,6.00765323638916,6.41680908203125,6.509481430053711,6.2540998458862305,5.8456830978393555,6.189613342285156,6.792692184448242,6.457563400268555,6.041304588317871,6.135087013244629,6.132525444030762,6.078264236450195,6.047064781188965,4.882900238037109,6.157533645629883,5.497096061706543,5.771068572998047,6.258846282958984,6.387505531311035,6.307802200317383,6.851279258728027,6.289949417114258,5.8791961669921875,5.868889808654785,5.754061698913574,5.872708320617676
11
+ codellama/CodeLlama-7b-hf vs huggyllama/llama-7b,0.32,8.647472620010376,8.376736879348755,8.435736894607544,9.107041597366333,8.864849328994751,9.41048264503479,10.012479066848755,9.569318056106567,8.44343113899231,9.787455797195435,9.312052965164185,9.43970513343811,9.141399621963501,9.065999269485474,8.278849840164185,8.499554872512817,8.162867784500122,8.299113512039185,8.742165803909302,9.889566659927368,8.649908304214478,9.01466679573059,9.214273691177368,9.222363710403442,8.458730936050415,8.305001497268677,8.368142366409302,9.659701585769653,8.178768396377563,9.37527585029602,10.060438394546509,8.972997903823853,9.059861421585083,9.093953371047974,7.9971373081207275,9.003615617752075,8.126897096633911,9.370399713516235,8.522400140762329,8.299326181411743,9.034609079360962,7.785470247268677,8.214834451675415,9.226133584976196,8.734915971755981,8.7814199924469,9.9114830493927,9.082973718643188,8.663569688796997,8.888867616653442,9.117319345474243,8.071345567703247,8.896063089370728,8.932055711746216,8.91256833076477,8.12684941291809,8.584501504898071,8.91443943977356,8.60889744758606,8.841772317886353,8.786389589309692,8.825255632400513,9.281825304031372,8.923148393630981,7.839233636856079,8.52420163154602,9.043694734573364,9.754061937332153,9.47466492652893,8.743067026138306,9.928423166275024,9.674275636672974,9.556930780410767,8.76394009590149,8.932903528213501,7.965190172195435,9.08448338508606,8.615834474563599,9.15774655342102,9.544827699661255,9.85620903968811,8.274455308914185,9.383400201797485,8.919788599014282,9.396048784255981,8.037643671035767,9.437472581863403,8.59937596321106,8.707654237747192,9.40676236152649,8.286210298538208,8.961514711380005,10.425405740737915,9.821033716201782,8.56540036201477,8.608500719070435,9.74540638923645,9.969642877578735,9.945183038711548,8.943975687026978,9.277392625808716
12
+ codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.5,0.01,0.23432707786560059,9.105018138885498,10.115467548370361,8.814990520477295,7.700256824493408,22.361277103424072,8.875674724578857,9.687096118927002,13.167479038238525,7.4309868812561035,8.258686542510986,8.119656085968018,8.539817333221436,8.664506435394287,7.635847568511963,8.880570888519287,7.720311641693115,8.838372707366943,8.217498302459717,8.588544368743896,8.979221820831299,8.800464153289795,8.409977436065674,7.696940898895264,7.538162708282471,7.685622692108154,8.631911754608154,7.75542688369751,8.455676555633545,7.691273212432861,7.6336798667907715,8.811938762664795,9.551165103912354,11.26979684829712,7.6999430656433105,9.915334224700928,8.53978681564331,8.212835788726807,8.059876918792725,10.90781545639038,9.512303829193115,9.135515689849854,8.30867338180542,9.355510234832764,9.14211130142212,7.2915568351745605,10.75891923904419,8.306049823760986,7.463253498077393,8.321387767791748,7.137866497039795,8.621911525726318,8.346582889556885,9.125063419342041,8.513661861419678,8.490179538726807,10.315661907196045,8.453839778900146,7.640648365020752,9.382608890533447,8.142853260040283,9.462652683258057,8.157150745391846,8.311485767364502,9.428282260894775,8.401495456695557,7.704698085784912,7.817162990570068,7.951792240142822,7.7705979347229,9.062580585479736,8.580884456634521,8.372841358184814,8.722021579742432,9.242526531219482,8.643452167510986,8.031785488128662,9.107877254486084,7.227905750274658,7.46142053604126,8.96703577041626,7.818650722503662,8.47803544998169,8.361773014068604,8.90479040145874,7.275203227996826,9.667508602142334,7.858919620513916,8.57040548324585,9.46554708480835,9.039218425750732,9.080006122589111,8.056274890899658,7.739555835723877,7.820621967315674,8.11035966873169,8.724180698394775,8.556443691253662,7.35528039932251,7.805674076080322,8.090892314910889
13
+ codellama/CodeLlama-7b-hf vs EleutherAI/llemma_7b,0.01,-0.01681309938430786,9.120851933956146,9.665036618709564,7.745120465755463,10.883652150630951,9.853213727474213,7.972514569759369,8.57659763097763,9.347029149532318,9.287005841732025,9.25390475988388,7.596251904964447,9.753281056880951,8.316048085689545,9.251168668270111,9.394730031490326,9.60946410894394,10.057454526424408,8.100693166255951,8.999764859676361,8.536845624446869,8.414287030696869,9.235208928585052,7.755367696285248,8.949576795101166,8.671410024166107,8.182545125484467,9.517559468746185,9.520708501338959,8.814814984798431,9.29341834783554,7.800792157649994,9.258494794368744,8.291480481624603,10.568865239620209,9.317304074764252,8.720015943050385,9.406411588191986,10.450869977474213,10.109208524227142,8.162058293819427,8.079206883907318,9.6403107047081,8.510849416255951,9.45523589849472,8.470562398433685,9.602908551692963,9.297125279903412,8.389980733394623,7.734573781490326,8.963873326778412,9.630400121212006,9.040441930294037,9.024840772151947,10.05105060338974,8.557877957820892,8.884402692317963,7.516786992549896,9.47091144323349,8.052890241146088,8.685519635677338,9.905793607234955,8.691942632198334,8.445288121700287,9.078912198543549,9.99776691198349,10.183870732784271,8.642778813838959,9.249837338924408,8.43747752904892,7.503961980342865,8.448286473751068,8.273894727230072,9.812999188899994,8.438606679439545,7.202361524105072,9.53622192144394,10.53255409002304,7.324675023555756,8.090178906917572,11.29394668340683,8.76332038640976,7.9841713309288025,8.38703578710556,7.210101544857025,9.739253461360931,7.38428920507431,8.221330106258392,9.497566640377045,9.075954854488373,7.801622807979584,8.816792905330658,8.747409284114838,7.322262227535248,9.282394826412201,9.361765325069427,8.516244351863861,8.374782025814056,9.28657478094101,9.686973989009857,8.96060699224472
14
+ codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.1,0.29,8.48714804649353,7.6908605098724365,9.169391393661499,8.317209959030151,8.82749056816101,8.086182355880737,9.09686827659607,9.041928052902222,9.70714545249939,9.24581789970398,7.527576208114624,8.289576292037964,8.33409857749939,8.506857633590698,9.761435270309448,9.7164466381073,8.253876447677612,9.341957807540894,8.88274073600769,8.183127164840698,8.267747640609741,8.626987218856812,8.416558980941772,8.621223211288452,8.711838483810425,9.186389684677124,8.860594511032104,8.60040831565857,9.341449499130249,7.8805530071258545,8.886674642562866,8.76483416557312,8.315276861190796,8.604306936264038,8.908690214157104,8.503263235092163,8.62110686302185,8.35447382926941,8.634252309799194,9.194954633712769,9.794084310531616,9.426853895187378,9.282697439193726,8.876146078109741,8.379112958908081,9.456719160079956,8.24547266960144,8.695784330368042,8.297900915145874,9.174300909042358,10.507225751876831,8.244601011276245,9.66327166557312,8.527663946151733,8.712950468063354,7.743478536605835,9.500287771224976,8.695890188217163,8.144980192184448,8.653694868087769,9.294896841049194,8.294621229171753,9.95811915397644,9.066579580307007,7.893991231918335,8.774755239486694,8.692675352096558,7.662358999252319,9.419709920883179,9.247517347335815,8.725831747055054,8.96029543876648,7.903769254684448,8.607434034347534,8.779548406600952,9.245018720626831,8.981712102890015,8.89933466911316,8.861811399459839,9.083378553390503,9.699528455734253,8.974799871444702,8.601711988449097,9.232177495956421,9.49698805809021,9.961381673812866,8.806665182113647,8.881787061691284,9.232394933700562,9.254537343978882,9.025410413742065,8.972266912460327,8.399003744125366,7.797715902328491,9.680047750473022,9.935441732406616,8.279030561447144,7.591404676437378,7.754696607589722,9.866612195968628,8.337414503097534
15
+ codellama/CodeLlama-7b-hf vs microsoft/Orca-2-7b,0.01,0.20641469955444336,7.8634960651397705,8.904095888137817,9.61886715888977,8.437413454055786,8.772413492202759,9.826401948928833,8.376365900039673,9.295730829238892,7.479389429092407,7.212454080581665,9.409053087234497,8.811966180801392,8.886850595474243,9.277664422988892,8.518949747085571,8.85485291481018,7.306167840957642,9.17977261543274,8.63290524482727,8.338998079299927,7.619590997695923,7.116956949234009,9.088074922561646,9.55984616279602,7.769827127456665,12.575409173965454,8.546103715896606,9.712836503982544,7.821888208389282,7.985309839248657,9.255531549453735,7.9532740116119385,7.89363694190979,7.9602367877960205,7.962735414505005,8.382404565811157,7.323387384414673,8.479003190994263,7.25046181678772,8.708838701248169,8.418090105056763,8.20142388343811,7.759202241897583,8.463444948196411,8.459122896194458,8.245249032974243,8.69249176979065,14.59964394569397,8.798161745071411,7.301157236099243,9.396774530410767,8.621342897415161,10.132902383804321,9.146990060806274,8.679113626480103,8.614110231399536,8.71687912940979,9.0999596118927,9.935490846633911,7.709607362747192,7.654601335525513,8.510257005691528,8.947322130203247,8.230319261550903,6.711144685745239,8.664664506912231,8.041719675064087,8.087208032608032,8.947721719741821,8.407228708267212,6.83827805519104,8.23217511177063,9.082698106765747,8.295105218887329,7.719668626785278,7.801557779312134,8.987715005874634,9.564089059829712,9.21347451210022,8.516844987869263,7.161568880081177,8.31154465675354,7.943972826004028,8.84707760810852,9.559814691543579,8.488112688064575,8.300562143325806,8.24126935005188,8.901121377944946,9.877440690994263,7.255166292190552,6.881604433059692,8.9941246509552,8.821150064468384,8.297075510025024,8.37979245185852,8.204619646072388,7.644574403762817,7.633875131607056,8.807678461074829
16
+ codellama/CodeLlama-7b-hf vs LLM360/Amber,0.01,7.748504936695099,9.06743460893631,8.908285439014435,8.819458305835724,9.088988602161407,8.924072563648224,9.067676842212677,8.016331970691681,8.311434090137482,8.955159485340118,8.229758560657501,8.987535774707794,8.484479248523712,8.56284362077713,7.941815674304962,9.349307358264923,9.409474670886993,9.941840469837189,8.861625015735626,8.79028731584549,9.348574936389923,8.276139557361603,8.158338844776154,8.984058678150177,9.270368874073029,9.068104088306427,10.039808571338654,9.142476379871368,9.175580322742462,8.77158671617508,9.13890391588211,8.617131531238556,8.621155083179474,8.881712257862091,9.085198700428009,8.689877808094025,8.795815765857697,8.682710945606232,8.883796989917755,8.753254234790802,9.923146545886993,8.774359047412872,9.62900859117508,8.872906982898712,9.16351443529129,8.674702942371368,8.880166351795197,9.373622238636017,8.218859016895294,9.922911942005157,8.901104271411896,9.061072647571564,8.830475151538849,9.16731196641922,9.990583717823029,9.197730362415314,9.108112633228302,10.07734328508377,9.169017136096954,8.402928650379181,9.248014748096466,9.330892860889435,9.563519775867462,8.773622810840607,9.222309410572052,9.123565971851349,9.193527519702911,9.795744240283966,8.590337097644806,8.715575516223907,8.95899611711502,8.898348152637482,9.765960037708282,8.962839424610138,9.214798271656036,9.854464828968048,8.187191307544708,9.656366646289825,8.432507812976837,9.170356094837189,8.251133263111115,8.479723274707794,8.784158051013947,9.559178650379181,8.12833720445633,8.339381515979767,8.23774367570877,8.49364310503006,8.596848785877228,8.26347666978836,9.288618385791779,8.338898003101349,8.665286362171173,7.99944144487381,9.6977179646492,9.06463748216629,8.797506630420685,9.591453850269318,9.648471176624298,9.975923836231232,8.205338776111603
17
+ openlm-research/open_llama_7b vs huggyllama/llama-7b,0.12,5.4373085498809814,5.617760896682739,5.587035417556763,5.687114000320435,6.008775949478149,5.774935960769653,5.945592164993286,6.301852464675903,5.327899217605591,5.804341554641724,6.269682168960571,6.641139268875122,5.6906960010528564,6.195895433425903,6.718469858169556,5.616427659988403,6.246436357498169,5.6004555225372314,5.429034471511841,6.112356424331665,5.798109292984009,6.217177629470825,5.898507356643677,5.735835313796997,4.97048020362854,5.659003496170044,6.196565866470337,6.84828782081604,6.0469372272491455,5.969996690750122,6.734667062759399,5.8872339725494385,6.937445878982544,6.536187410354614,6.107362985610962,5.402880907058716,5.715508699417114,5.3803627490997314,5.959226846694946,6.169707536697388,6.227089166641235,6.530186891555786,5.589359521865845,6.085469484329224,5.880691766738892,5.3078672885894775,6.431185960769653,5.2033703327178955,6.39737343788147,6.516474962234497,6.659247636795044,6.701551675796509,6.254176378250122,5.632610559463501,6.0807716846466064,5.7362635135650635,5.585143327713013,5.939406633377075,6.253710985183716,5.834253549575806,5.83292031288147,6.809473276138306,5.713587045669556,5.961418390274048,6.384552240371704,5.854504823684692,6.26802659034729,5.167945146560669,6.388922929763794,6.145580530166626,6.14031720161438,6.181988954544067,5.45749306678772,5.194824457168579,6.006340265274048,6.111417055130005,6.233436822891235,5.969449281692505,6.450243234634399,5.394938707351685,6.070707559585571,6.392692804336548,5.4279396533966064,5.901009798049927,5.893953561782837,6.187520265579224,5.83727765083313,6.07111668586731,5.614530801773071,6.282021760940552,6.858328104019165,5.672479867935181,6.526853799819946,5.725656747817993,6.860419511795044,5.698136568069458,6.598180055618286,5.921528100967407,5.906771898269653,5.56449818611145,5.829131364822388
18
+ openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.5,0.68,6.308373689651489,6.700233697891235,5.703813791275024,6.338608026504517,6.307419061660767,6.586034059524536,5.478386163711548,6.177409410476685,6.018020868301392,5.776783227920532,6.389678239822388,5.805647134780884,5.913630723953247,6.293874025344849,7.018000841140747,7.00434422492981,5.98361611366272,5.66058087348938,5.6123316287994385,6.112520456314087,5.813434839248657,6.909172296524048,6.756714105606079,5.090618371963501,7.784255266189575,6.0108559131622314,5.270795106887817,5.621241807937622,4.6691734790802,5.558912515640259,5.659086465835571,6.354078531265259,6.808579683303833,6.601501703262329,5.424909830093384,5.758757829666138,6.722477197647095,6.334706544876099,5.572609186172485,6.24731183052063,5.827046632766724,5.98691201210022,6.036614656448364,6.069903612136841,5.4608423709869385,5.692006349563599,6.020584344863892,5.6468565464019775,5.900907754898071,7.260846376419067,6.558199167251587,6.274706125259399,6.207753419876099,5.571796655654907,5.839945077896118,6.207461595535278,6.332349061965942,4.937416315078735,5.79715371131897,6.523751497268677,5.60066819190979,7.146387338638306,5.899955987930298,6.562173128128052,6.155851602554321,6.521467447280884,6.11050820350647,6.363598108291626,6.629664659500122,5.639268159866333,6.237446069717407,5.771808862686157,5.830493211746216,6.909991502761841,4.765170335769653,5.74653172492981,5.671989679336548,5.345279932022095,5.3964550495147705,6.0652687549591064,6.395521402359009,6.9056923389434814,6.126155138015747,5.419970750808716,6.917824983596802,6.50818657875061,6.250625848770142,16.261486291885376,5.867288827896118,5.998341798782349,6.283872842788696,5.926321268081665,6.6489622592926025,6.117601633071899,6.7576258182525635,4.966566324234009,6.151142358779907,6.079829454421997,6.484721422195435,5.017580270767212,6.228436708450317
19
+ openlm-research/open_llama_7b vs EleutherAI/llemma_7b,0.74,6.918841361999512,5.942992210388184,6.163137435913086,6.270599365234375,7.409475326538086,6.998303413391113,6.928138732910156,5.826279640197754,6.324603080749512,6.316580772399902,6.851133346557617,7.621212005615234,5.953413963317871,5.799161911010742,6.694608688354492,6.517380714416504,6.376747131347656,6.4730377197265625,7.693846702575684,6.17234992980957,6.822030067443848,6.326215744018555,7.2796173095703125,6.853416442871094,6.927913665771484,6.259148597717285,6.455480575561523,7.025678634643555,6.28238582611084,6.102893829345703,6.6217851638793945,6.489897727966309,6.100579261779785,6.702361106872559,7.050620079040527,6.681787490844727,7.272147178649902,6.204999923706055,6.434914588928223,6.304933547973633,6.49997615814209,6.801599502563477,6.930009841918945,5.9951887130737305,7.54071044921875,6.3290557861328125,7.4027557373046875,6.631084442138672,5.962194442749023,6.994401931762695,6.317986488342285,6.697565078735352,6.738523483276367,6.1151275634765625,6.463567733764648,7.279179573059082,6.964632034301758,7.170134544372559,5.57750129699707,5.836142539978027,6.106198310852051,7.078838348388672,7.218414306640625,6.934094429016113,6.083596229553223,5.987371444702148,6.6132307052612305,7.24505615234375,6.67485237121582,7.4629106521606445,6.294131278991699,6.789360046386719,6.168454170227051,7.057206153869629,6.778372764587402,6.746638298034668,6.772141456604004,6.247314453125,6.332786560058594,6.7579450607299805,6.38361930847168,6.634665489196777,6.584917068481445,6.887727737426758,6.343280792236328,6.899263381958008,6.330324172973633,6.626551628112793,6.960097312927246,6.3871002197265625,6.660849571228027,6.618146896362305,5.781418800354004,6.620124816894531,7.311498641967773,7.206951141357422,6.33872127532959,5.888164520263672,6.724948883056641,6.355386734008789,6.677863121032715
20
+ openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.1,0.12,5.251814126968384,5.051405191421509,5.824950456619263,6.114632844924927,5.42397141456604,5.6496946811676025,6.120051622390747,5.661573648452759,5.406200647354126,5.948969125747681,5.6591761112213135,5.098428964614868,6.186973810195923,6.260297060012817,4.845706224441528,4.884040117263794,5.762598276138306,5.379200220108032,6.526299715042114,6.481151819229126,5.83853554725647,6.3441002368927,5.409104585647583,5.687176942825317,5.74416184425354,5.246052980422974,6.2644360065460205,5.886080026626587,5.6273205280303955,6.853506326675415,5.799962282180786,5.687295198440552,5.733844995498657,6.048799753189087,6.1102516651153564,5.856144189834595,5.675837755203247,5.673197031021118,5.446083307266235,5.110049486160278,5.814363718032837,6.427519083023071,6.298197984695435,5.46152138710022,6.262089967727661,6.41163182258606,5.403126955032349,5.201777696609497,6.494671106338501,5.8217175006866455,6.150409936904907,5.3533384799957275,5.698808908462524,6.203578233718872,5.529521226882935,6.30396294593811,4.761819124221802,6.134600877761841,5.632828950881958,5.592910051345825,5.776437997817993,5.856722116470337,5.320695161819458,5.844919443130493,5.342267274856567,5.450078248977661,6.006361246109009,6.269325494766235,6.200906038284302,6.103886842727661,5.622957468032837,6.478121995925903,5.759417772293091,5.9912941455841064,6.174035310745239,6.112441301345825,6.027667284011841,5.732034921646118,5.58466649055481,5.922117471694946,5.360791444778442,5.733863115310669,5.565796136856079,6.27460789680481,6.186408281326294,4.860639810562134,5.562293291091919,5.67807412147522,5.985955476760864,5.920683145523071,4.785085916519165,6.239168405532837,5.6897499561309814,5.570624589920044,5.4800636768341064,5.736876726150513,5.4539244174957275,6.388970613479614,4.825183153152466,6.031313180923462,5.904541254043579
21
+ openlm-research/open_llama_7b vs microsoft/Orca-2-7b,0.85,6.374301195144653,5.607383966445923,5.797471284866333,5.860095262527466,6.8239147663116455,6.35334038734436,6.313669443130493,5.884981393814087,6.407116174697876,5.7687084674835205,6.412864923477173,6.109200716018677,6.0385754108428955,4.771707773208618,5.541899919509888,6.475514650344849,5.574916124343872,6.94892144203186,5.697839975357056,6.276600122451782,5.68853497505188,4.96396279335022,14.194698572158813,5.249680757522583,7.378007173538208,6.040998697280884,5.913078546524048,5.410505533218384,6.165990114212036,7.0394127368927,5.487897157669067,5.749399423599243,6.100099802017212,5.42592453956604,6.552445650100708,5.945457696914673,5.909960985183716,6.225774049758911,5.548428773880005,5.035020112991333,4.851449251174927,5.882075548171997,5.815673112869263,5.656095743179321,5.143633127212524,5.5314247608184814,7.116902589797974,5.827336549758911,5.751171350479126,5.695185899734497,7.8084423542022705,4.529286623001099,6.1526501178741455,6.381459474563599,6.09170937538147,4.367375612258911,5.150029420852661,5.458413362503052,5.681887865066528,6.342478036880493,5.972796678543091,5.907988786697388,5.2621166706085205,5.635527849197388,6.17734169960022,5.835489511489868,5.093684434890747,5.344476938247681,6.326663255691528,6.0996363162994385,5.540684938430786,4.868671655654907,5.586103677749634,5.079303026199341,6.067350625991821,6.356841325759888,5.464815378189087,5.622609376907349,5.330131769180298,6.660903215408325,5.466338396072388,5.834389925003052,5.860307931900024,5.607694864273071,5.605406999588013,5.080667734146118,5.097392320632935,6.129028558731079,6.217902421951294,5.536783456802368,5.054259538650513,4.665910959243774,6.816373109817505,5.676759958267212,5.222558259963989,6.3869335651397705,5.894263505935669,6.307274103164673,4.8874828815460205,5.601005792617798,5.1678307056427
model-tracing/scripts/docs/doc_trace.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
3
+ import os
4
+ from datasets import load_dataset
5
+ from tqdm import tqdm
6
+ import math
7
+ import matplotlib.pyplot as plt
8
+ import csv
9
+ from utils import interpolate_models
10
+ import time
11
+ import copy
12
+ import argparse
13
+ import glob
14
+
15
+
16
+ block_size = 512
17
+
18
+
19
+ def group_texts(examples):
20
+ # Concatenate all texts.
21
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
22
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
23
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
24
+ # customize this part to your needs.
25
+ total_length = (total_length // block_size) * block_size
26
+ # Split by chunks of max_len.
27
+ result = {
28
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
29
+ for k, t in concatenated_examples.items()
30
+ }
31
+ result["labels"] = result["input_ids"].copy()
32
+ return result
33
+
34
+
35
+ def main(args):
36
+ start_time = time.time()
37
+ # Automatically detect CUDA device
38
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
+ print(f"Using device: {device}")
40
+ os.environ["WANDB_MODE"] = "disabled"
41
+
42
+ # Load models and tokenizer
43
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
44
+ model_list = [
45
+ "meta-llama/Llama-2-7b-hf",
46
+ "codellama/CodeLlama-7b-hf",
47
+ "lmsys/vicuna-7b-v1.5",
48
+ "EleutherAI/llemma_7b",
49
+ "LLM360/Amber",
50
+ ]
51
+ model_pairs = [
52
+ (0, 2), # LLama2, Vicuna-1.5
53
+ (0, 1), # LLama2, CodeLlama
54
+ (0, 3), # LLama2, Lemma
55
+ (1, 3), # CodeLlama, Lemma
56
+ (0, 4), # LLama2, Amber
57
+ ]
58
+ models = [
59
+ AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
60
+ for model_name in model_list
61
+ ]
62
+ tokenizer = AutoTokenizer.from_pretrained(models[0].config._name_or_path)
63
+ tokenizer.pad_token = tokenizer.eos_token
64
+
65
+ # Scan the directory for JSON files based on the test name argument
66
+ columns_ignored = [
67
+ "text",
68
+ "added",
69
+ "id",
70
+ "lang",
71
+ "metadata",
72
+ "source",
73
+ "timestamp",
74
+ "subdomain",
75
+ ]
76
+ json_dir = f"/juice4/scr4/nlp/model-tracing/dolma_program_languages/json_files_{args.test_name}"
77
+ json_files = glob.glob(f"{json_dir}/*.json")
78
+ save_dir = f"/juice4/scr4/nlp/model-tracing/dolma_program_languages/results_{args.test_name}"
79
+ if not os.path.exists(save_dir):
80
+ os.makedirs(save_dir)
81
+
82
+ for json_file in json_files:
83
+ print(f"Processing {json_file}")
84
+
85
+ # Prepare dataset
86
+ eval_dataset = load_dataset("json", data_files=json_file)
87
+
88
+ def tokenize_function(examples):
89
+ return tokenizer(examples["text"])
90
+
91
+ tokenized_datasets = eval_dataset.map(
92
+ tokenize_function, batched=True, num_proc=4, remove_columns=columns_ignored
93
+ )
94
+ lm_datasets = tokenized_datasets.map(
95
+ group_texts,
96
+ batched=True,
97
+ batch_size=1,
98
+ num_proc=1,
99
+ )
100
+
101
+ # Prepare for evaluation. Batch size is optimized for ~7B model
102
+ training_args = TrainingArguments(
103
+ output_dir="./results",
104
+ per_device_eval_batch_size=3,
105
+ do_eval=True,
106
+ report_to=None,
107
+ dataloader_num_workers=4,
108
+ use_cpu=True,
109
+ )
110
+ alphas = [0.0, 0.3, 0.5, 0.7, 1.0]
111
+ model = copy.deepcopy(models[0])
112
+ trainer = Trainer(model=model, args=training_args, eval_dataset=lm_datasets)
113
+ print("create data loader")
114
+ eval_dataloader = trainer.get_test_dataloader(lm_datasets["train"])
115
+
116
+ for idx_a, idx_b in tqdm(model_pairs, desc="Model Interpolation"):
117
+ model_a = models[idx_a]
118
+ model_b = models[idx_b]
119
+ perplexities = []
120
+ model_a_name = model_a.config._name_or_path.split("/")[-1]
121
+ model_b_name = model_b.config._name_or_path.split("/")[-1]
122
+
123
+ for alpha in tqdm(
124
+ alphas, desc=f" \n Alpha Perplexities for {model_a_name} and {model_b_name}"
125
+ ):
126
+ interpolated_model = interpolate_models(model_a, model_b, alpha)
127
+ # cast to bfloat16 before GPU
128
+ interpolated_model = interpolated_model.half().to(device)
129
+
130
+ start_time = time.time()
131
+ losses = []
132
+
133
+ for batch in tqdm(eval_dataloader, desc=f"\n Evaluating {alpha}"):
134
+ # HF Trainer finds GPU by default
135
+ input_ids = batch["input_ids"].to(device)
136
+ attention_mask = batch["attention_mask"].to(device)
137
+ labels = batch["labels"].to(device)
138
+ outputs = interpolated_model(
139
+ input_ids=input_ids,
140
+ attention_mask=attention_mask,
141
+ labels=labels,
142
+ )
143
+ loss = outputs.loss
144
+ losses.append(loss.item())
145
+
146
+ loss_mean = sum(losses) / len(losses)
147
+ print(f"Loss mean: {loss_mean}")
148
+ end_time = time.time()
149
+ execution_time = end_time - start_time
150
+ print(f"Execution time base: {execution_time} seconds")
151
+
152
+ perplexity = math.exp(loss_mean)
153
+ perplexities.append(perplexity)
154
+
155
+ # Move the model back to CPU
156
+ interpolated_model.to("cpu")
157
+
158
+ # Clear the GPU cache
159
+ del interpolated_model, input_ids, attention_mask, labels, outputs, loss
160
+ torch.cuda.empty_cache()
161
+
162
+ # Save perplexities and model names to CSV
163
+ json_filename = os.path.splitext(os.path.basename(json_file))[0]
164
+ csv_filename = f"perplexities_{json_filename}.csv"
165
+ csv_full_path = f"{save_dir}/{csv_filename}"
166
+ csv_header = ["Model Pair"] + [f"Alpha {alpha}" for alpha in alphas]
167
+ if not os.path.exists(csv_full_path):
168
+ with open(csv_full_path, "w", newline="") as csvfile:
169
+ writer = csv.writer(csvfile)
170
+ writer.writerow(csv_header)
171
+
172
+ with open(csv_full_path, "a", newline="") as csvfile:
173
+ writer = csv.writer(csvfile)
174
+ model_pair = f"{model_a_name} vs {model_b_name}"
175
+ row = [model_pair] + perplexities
176
+ writer.writerow(row)
177
+
178
+ # Create the plot
179
+ plt.figure(figsize=(8, 6))
180
+ plt.plot(alphas, perplexities)
181
+ plt.xlabel("Alpha")
182
+ plt.ylabel("Perplexity")
183
+ plt.title(f"{model_a_name} (Left) vs {model_b_name} (Right)")
184
+
185
+ # Save the plot as a PNG file
186
+ plot_filename = (
187
+ f"alpha_vs_perplexity_{model_a_name}_vs_{model_b_name}_{json_filename}.png"
188
+ )
189
+ plot_full_path = f"{save_dir}/{plot_filename}"
190
+ plt.savefig(plot_full_path, dpi=300, bbox_inches="tight")
191
+ plt.close()
192
+
193
+ end_time = time.time()
194
+ execution_time = end_time - start_time
195
+ print(f"Total execution time: {execution_time} seconds")
196
+
197
+
198
+ if __name__ == "__main__":
199
+ parser = argparse.ArgumentParser(description="Model Interpolation")
200
+ parser.add_argument(
201
+ "--test_name", type=str, default="js", help="Test name (e.g., cpp, python, js)"
202
+ )
203
+ args = parser.parse_args()
204
+ main(args)
model-tracing/scripts/docs/launch.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from yaml import safe_load
2
+ import subprocess
3
+ import argparse
4
+ import os
5
+
6
+
7
+ def load_yaml(file_path):
8
+ with open(file_path, "r") as file:
9
+ return safe_load(file)
10
+
11
+
12
+ def main(args):
13
+ # Load configurations
14
+ config = load_yaml(args.config)
15
+
16
+ # Create necessary directories
17
+ os.makedirs(config["save_dir"], exist_ok=True)
18
+ os.makedirs(f"{config['save_dir']}/logs", exist_ok=True)
19
+ os.makedirs(f"{config['save_dir']}/results", exist_ok=True)
20
+
21
+ # Prepare base command
22
+ base_cmd = f"{args.slurm} python {args.script}"
23
+
24
+ # Launch jobs
25
+ for dataset in config["datasets"]:
26
+ for model_arch in config["model_architectures"]:
27
+ job_id = f"{dataset}_{model_arch}"
28
+ log_path = f"{config['save_dir']}/logs/{job_id}.out"
29
+ results_path = f"{config['save_dir']}/results/{job_id}"
30
+
31
+ cmd = (
32
+ f"{base_cmd} "
33
+ f"--model_arch {model_arch} "
34
+ f"--test_name {dataset} "
35
+ f"--save_dir {results_path}"
36
+ )
37
+
38
+ if "dolma" in dataset.lower():
39
+ cmd += f" --json_dir {config['dolma_json_dir']}/{dataset}"
40
+ elif "m2d2" in dataset.lower():
41
+ cmd += f" --json_dir {config['m2d2_json_dir']}/{dataset}"
42
+
43
+ full_cmd = f"{args.slurm} -o {log_path} -J {job_id} '{cmd}'"
44
+ print(f"Launching job: {full_cmd}")
45
+ subprocess.run(full_cmd, shell=True)
46
+
47
+
48
+ if __name__ == "__main__":
49
+ parser = argparse.ArgumentParser(description="Launch interpolation experiments on SLURM")
50
+ parser.add_argument("--config", default="config.yaml", help="Path to YAML configuration file")
51
+ parser.add_argument(
52
+ "--slurm",
53
+ default="srun --partition=your-partition --time=24:00:00 --mem=64G --gres=gpu:1",
54
+ help="SLURM command",
55
+ )
56
+ parser.add_argument("--script", default="interpolation_script.py", help="Python script to run")
57
+ args = parser.parse_args()
58
+ main(args)
model-tracing/scripts/docs/m2d_trace.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
3
+ import itertools
4
+ import os
5
+ from datasets import load_dataset
6
+ from tqdm import tqdm
7
+ import math
8
+ import matplotlib.pyplot as plt
9
+ import csv
10
+ from utils import interpolate_models
11
+ import time
12
+ import argparse
13
+ import glob
14
+ import gc
15
+
16
+ block_size = 2048
17
+ """
18
+ Script for running ablation of tests on m2d2 dataset rather
19
+ than simply wikitext
20
+ """
21
+
22
+
23
+ def group_texts(examples):
24
+ # Concatenate all texts.
25
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
26
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
27
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
28
+ # customize this part to your needs.
29
+ total_length = (total_length // block_size) * block_size
30
+ # Split by chunks of max_len.
31
+ result = {
32
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
33
+ for k, t in concatenated_examples.items()
34
+ }
35
+ result["labels"] = result["input_ids"].copy()
36
+ return result
37
+
38
+
39
+ def load_model(model_name):
40
+ return AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
41
+
42
+
43
+ def main(args):
44
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
+ print(f"Using device: {device}")
46
+ os.environ["WANDB_MODE"] = "disabled"
47
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
48
+
49
+ model_arch = args.model_arch
50
+ if model_arch == "llama":
51
+ model_list = [
52
+ "meta-llama/Llama-2-7b-hf",
53
+ "meta-llama/Llama-2-7b-chat-hf",
54
+ "meta-llama/CodeLlama-7b-Python-hf",
55
+ "meta-llama/CodeLlama-7b-Instruct-hf",
56
+ "codellama/CodeLlama-7b-hf",
57
+ "lmsys/vicuna-7b-v1.5",
58
+ "lmsys/vicuna-7b-v1.1",
59
+ "EleutherAI/llemma_7b",
60
+ "LLM360/Amber",
61
+ ]
62
+ elif model_arch == "olmo":
63
+ model_list = [
64
+ "/scr/ahmedah/olmo/step1000_4B_tokens/seed_0_4B",
65
+ "/scr/ahmedah/olmo/step1000_4B_tokens/seed_42_4B",
66
+ ]
67
+
68
+ tokenizer = AutoTokenizer.from_pretrained(model_list[0])
69
+ tokenizer.pad_token = tokenizer.eos_token
70
+
71
+ test_cases = [
72
+ {
73
+ "test_name": folder_name,
74
+ "json_dir": f"/juice4/scr4/nlp/model-tracing/m2d2_s2orc/{folder_name}",
75
+ "save_dir": f"/juice4/scr4/nlp/model-tracing/m2d2_s2orc/results_{folder_name}",
76
+ "columns_ignored": ["text", "added", "id", "source", "timestamp", "subdomain"],
77
+ }
78
+ for folder_name in [
79
+ "AI",
80
+ "CV",
81
+ "ET",
82
+ "IM",
83
+ "mtrl-sci",
84
+ "stat-mech",
85
+ "AR",
86
+ "CY",
87
+ "IR",
88
+ "NA",
89
+ "str-el",
90
+ "art",
91
+ "DB",
92
+ "FL",
93
+ "supr-con",
94
+ "CC",
95
+ "DC",
96
+ "GA",
97
+ "LG",
98
+ "phil",
99
+ "CE",
100
+ "dis-nn",
101
+ "GL",
102
+ "LO",
103
+ "CG",
104
+ "DL",
105
+ "GR",
106
+ "MA",
107
+ "quant-gas",
108
+ "CL",
109
+ "DM",
110
+ "GT",
111
+ "mes-hall",
112
+ "CO",
113
+ "DS",
114
+ "HC",
115
+ "MM",
116
+ "soft",
117
+ "CR",
118
+ "EP",
119
+ "HE",
120
+ "MS",
121
+ "SR",
122
+ ]
123
+ ]
124
+
125
+ for test_case in test_cases:
126
+ test_name = test_case["test_name"]
127
+ json_dir = test_case["json_dir"]
128
+ save_dir = test_case["save_dir"]
129
+ columns_ignored = ["text", "added", "id", "source", "subdomain"]
130
+
131
+ json_files = glob.glob(f"{json_dir}/*.json")
132
+ if not os.path.exists(save_dir):
133
+ os.makedirs(save_dir)
134
+
135
+ for json_file in json_files:
136
+ print(f"Processing {json_file}")
137
+
138
+ eval_dataset = load_dataset("json", data_files=json_file)
139
+
140
+ def tokenize_function(examples):
141
+ return tokenizer(examples["text"])
142
+
143
+ tokenized_datasets = eval_dataset.map(
144
+ tokenize_function, batched=True, num_proc=4, remove_columns=columns_ignored
145
+ )
146
+ lm_datasets = tokenized_datasets.map(
147
+ group_texts,
148
+ batched=True,
149
+ batch_size=1000,
150
+ num_proc=8,
151
+ )
152
+
153
+ training_args = TrainingArguments(
154
+ output_dir="./hf_results",
155
+ per_device_eval_batch_size=15,
156
+ do_eval=True,
157
+ report_to=None,
158
+ dataloader_num_workers=8,
159
+ use_cpu=True,
160
+ )
161
+ alphas = [0.0, 0.3, 0.5, 0.7, 1.0]
162
+ initial_model = load_model(model_list[0])
163
+ trainer = Trainer(model=initial_model, args=training_args, eval_dataset=lm_datasets)
164
+ eval_dataloader = trainer.get_test_dataloader(lm_datasets["train"])
165
+ del initial_model
166
+
167
+ model_pairs = list(itertools.combinations(enumerate(model_list), 2))
168
+
169
+ base_dir = f"{save_dir}/{test_name}"
170
+ os.makedirs(base_dir, exist_ok=True)
171
+ imgs_dir = os.path.join(base_dir, "imgs")
172
+ os.makedirs(imgs_dir, exist_ok=True)
173
+ csv_dir = os.path.join(base_dir, "csv")
174
+ os.makedirs(csv_dir, exist_ok=True)
175
+
176
+ current_model_a, current_model_b = None, None
177
+ current_model_a_name, current_model_b_name = None, None
178
+
179
+ for (idx_a, model_a_name), (idx_b, model_b_name) in tqdm(
180
+ model_pairs, desc="Model Interpolation"
181
+ ):
182
+ if idx_a < idx_b:
183
+ perplexities = []
184
+
185
+ if current_model_a is None or current_model_a_name != model_a_name:
186
+ if current_model_a is not None:
187
+ del current_model_a
188
+ torch.cuda.empty_cache()
189
+ current_model_a = load_model(model_a_name).to("cpu")
190
+ current_model_a_name = model_a_name
191
+
192
+ if current_model_b is None or current_model_b_name != model_b_name:
193
+ if current_model_b is not None:
194
+ del current_model_b
195
+ torch.cuda.empty_cache()
196
+ current_model_b = load_model(model_b_name).to("cpu")
197
+ current_model_b_name = model_b_name
198
+
199
+ with torch.no_grad():
200
+ for alpha in tqdm(
201
+ alphas,
202
+ desc=f" \n Alpha Perplexities for {model_a_name} and {model_b_name}",
203
+ ):
204
+ interpolated_model = interpolate_models(
205
+ current_model_a, current_model_b, alpha, model_arch=model_arch
206
+ )
207
+ interpolated_model = interpolated_model.half().to(device)
208
+
209
+ start_time = time.time()
210
+ losses = []
211
+
212
+ for batch in tqdm(eval_dataloader, desc=f"\n Evaluating {alpha}"):
213
+ input_ids = batch["input_ids"].to(device)
214
+ attention_mask = batch["attention_mask"].to(device)
215
+ labels = batch["labels"].to(device)
216
+
217
+ outputs = interpolated_model(
218
+ input_ids=input_ids,
219
+ attention_mask=attention_mask,
220
+ labels=labels,
221
+ )
222
+ loss = outputs.loss
223
+ losses.append(loss.item())
224
+
225
+ loss_mean = sum(losses) / len(losses)
226
+ print(f"Loss mean: {loss_mean}")
227
+ end_time = time.time()
228
+ execution_time = end_time - start_time
229
+ print(f"Execution time base: {execution_time} seconds")
230
+
231
+ perplexity = math.exp(loss_mean)
232
+ perplexities.append(perplexity)
233
+
234
+ interpolated_model.to("cpu")
235
+ del interpolated_model, input_ids, attention_mask, labels, outputs, loss
236
+ torch.cuda.empty_cache()
237
+ gc.collect()
238
+
239
+ model_a_name = model_a_name.split("/")[-1]
240
+ model_b_name = model_b_name.split("/")[-1]
241
+ json_filename = os.path.splitext(os.path.basename(json_file))[0]
242
+ csv_filename = f"{csv_dir}/perplexities_{json_filename}.csv"
243
+ csv_header = ["Model Pair"] + [f"Alpha {alpha}" for alpha in alphas]
244
+
245
+ if not os.path.exists(csv_filename):
246
+ with open(csv_filename, "w", newline="") as csvfile:
247
+ writer = csv.writer(csvfile)
248
+ writer.writerow(csv_header)
249
+
250
+ with open(csv_filename, "a", newline="") as csvfile:
251
+ writer = csv.writer(csvfile)
252
+ model_pair = f"{model_a_name} vs {model_b_name}"
253
+ row = [model_pair] + perplexities
254
+ writer.writerow(row)
255
+
256
+ plt.figure(figsize=(8, 6))
257
+ plt.plot(alphas, perplexities)
258
+ plt.xlabel("Alpha")
259
+ plt.ylabel("Perplexity")
260
+ plt.title(f"{model_a_name} (Left) vs {model_b_name} (Right)")
261
+
262
+ plot_filename = (
263
+ f"alpha_vs_perplexity_{model_a_name}_vs_{model_b_name}_{json_filename}.png"
264
+ )
265
+ plot_path = f"{imgs_dir}/{plot_filename}"
266
+ plt.savefig(plot_path, dpi=300, bbox_inches="tight")
267
+ plt.close()
268
+
269
+
270
+ if __name__ == "__main__":
271
+ parser = argparse.ArgumentParser(description="Model Interpolation")
272
+ parser.add_argument(
273
+ "--model_arch",
274
+ choices=["llama", "olmo"],
275
+ default="llama",
276
+ help="default model architecture to use",
277
+ )
278
+ args = parser.parse_args()
279
+ main(args)
model-tracing/scripts/mode/main.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gc
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
4
+ import itertools
5
+ import os
6
+ from datasets import load_dataset
7
+ from tqdm import tqdm
8
+ import math
9
+ import matplotlib.pyplot as plt
10
+ import csv
11
+ from utils import interpolate_models
12
+ import time
13
+ import argparse
14
+
15
+
16
+ block_size = 512
17
+
18
+
19
+ def group_texts(examples):
20
+ # Concatenate all texts.
21
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
22
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
23
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
24
+ # customize this part to your needs.
25
+ total_length = (total_length // block_size) * block_size
26
+ # Split by chunks of max_len.
27
+ result = {
28
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
29
+ for k, t in concatenated_examples.items()
30
+ }
31
+ result["labels"] = result["input_ids"].copy()
32
+ return result
33
+
34
+
35
+ def load_model(model_name):
36
+ return AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
37
+
38
+
39
+ def main(args):
40
+ # Automatically detect CUDA device
41
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
+ print(f"Using device: {device}")
43
+ os.environ["WANDB_MODE"] = "disabled"
44
+
45
+ # Load models and tokenizer
46
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
47
+
48
+ model_arch = args.model_arch
49
+ if model_arch == "llama":
50
+ model_list = [
51
+ "meta-llama/Llama-2-7b-hf",
52
+ "codellama/CodeLlama-7b-hf",
53
+ "openlm-research/open_llama_7b",
54
+ "huggyllama/llama-7b",
55
+ "lmsys/vicuna-7b-v1.5",
56
+ "EleutherAI/llemma_7b",
57
+ "lmsys/vicuna-7b-v1.1",
58
+ "microsoft/Orca-2-7b",
59
+ "LLM360/Amber",
60
+ ]
61
+ elif model_arch == "olmo":
62
+ model_list = [
63
+ "/scr/ahmedah/olmo/step1000_4B_tokens/seed_0_4B",
64
+ "/scr/ahmedah/olmo/step1000_4B_tokens/seed_42_4B",
65
+ ]
66
+
67
+ # Load tokenizer
68
+ tokenizer = AutoTokenizer.from_pretrained(model_list[0])
69
+ tokenizer.pad_token = tokenizer.eos_token
70
+
71
+ # Prepare dataset
72
+ if args.dataset == "wikitext":
73
+ eval_dataset = load_dataset("dlwh/wikitext_103_detokenized", split="test")
74
+ columns_ignored = ["text"]
75
+ else:
76
+ raise ValueError("main.py only supports wikitext.")
77
+
78
+ def tokenize_function(examples):
79
+ return tokenizer(examples["text"])
80
+
81
+ tokenized_datasets = eval_dataset.map(
82
+ tokenize_function, batched=True, num_proc=4, remove_columns=columns_ignored
83
+ )
84
+ lm_datasets = tokenized_datasets.map(
85
+ group_texts,
86
+ batched=True,
87
+ batch_size=1,
88
+ num_proc=1,
89
+ )
90
+
91
+ # Prepare for evaluation. Batch size is optimized for ~7B model
92
+ training_args = TrainingArguments(
93
+ output_dir="./hf_results",
94
+ per_device_eval_batch_size=3,
95
+ do_eval=True,
96
+ report_to=None,
97
+ dataloader_num_workers=4,
98
+ use_cpu=True,
99
+ )
100
+ alphas = [0.0, 0.3, 0.5, 0.7, 1.0]
101
+ # Load an initial model to create the trainer and dataloader
102
+ initial_model = load_model(model_list[0])
103
+ trainer = Trainer(model=initial_model, args=training_args, eval_dataset=lm_datasets)
104
+ eval_dataloader = trainer.get_test_dataloader(lm_datasets)
105
+ del initial_model
106
+
107
+ # Calculate the L2 distance between each pair of models
108
+ model_pairs = list(itertools.combinations(enumerate(model_list), 2))
109
+
110
+ # create directories for results
111
+ base_dir = f"{os.getcwd()}/results"
112
+ os.makedirs(base_dir, exist_ok=True)
113
+ imgs_dir = os.path.join(base_dir, "imgs")
114
+ os.makedirs(imgs_dir, exist_ok=True)
115
+ csv_dir = os.path.join(base_dir, "csv")
116
+ print(csv_dir)
117
+ os.makedirs(csv_dir, exist_ok=True)
118
+
119
+ current_model_a, current_model_b = None, None
120
+ current_model_a_name, current_model_b_name = None, None
121
+
122
+ for (idx_a, model_a_name), (idx_b, model_b_name) in tqdm(
123
+ model_pairs, desc="Model Interpolation"
124
+ ):
125
+ if idx_a < idx_b:
126
+ perplexities = []
127
+
128
+ if current_model_a is None or current_model_a_name != model_a_name:
129
+ if current_model_a is not None:
130
+ del current_model_a
131
+ torch.cuda.empty_cache()
132
+ current_model_a = load_model(model_a_name).to("cpu")
133
+ current_model_a_name = model_a_name
134
+
135
+ if current_model_b is None or current_model_b_name != model_b_name:
136
+ if current_model_b is not None:
137
+ del current_model_b
138
+ torch.cuda.empty_cache()
139
+ current_model_b = load_model(model_b_name).to("cpu")
140
+ current_model_b_name = model_b_name
141
+
142
+ with torch.no_grad():
143
+ for alpha in tqdm(
144
+ alphas, desc=f" \n Alpha Perplexities for {model_a_name} and {model_b_name}"
145
+ ):
146
+
147
+ interpolated_model = interpolate_models(
148
+ current_model_a, current_model_b, alpha, model_arch=model_arch
149
+ )
150
+ interpolated_model = interpolated_model.half().to(device)
151
+
152
+ start_time = time.time()
153
+ losses = []
154
+
155
+ for batch in tqdm(eval_dataloader, desc=f"\n Evaluating {alpha}"):
156
+ input_ids = batch["input_ids"].to(device)
157
+ attention_mask = batch["attention_mask"].to(device)
158
+ labels = batch["labels"].to(device)
159
+
160
+ outputs = interpolated_model(
161
+ input_ids=input_ids,
162
+ attention_mask=attention_mask,
163
+ labels=labels,
164
+ )
165
+ loss = outputs.loss
166
+ losses.append(loss.item())
167
+
168
+ loss_mean = sum(losses) / len(losses)
169
+ print(f"Loss mean: {loss_mean}")
170
+ end_time = time.time()
171
+ execution_time = end_time - start_time
172
+ print(f"Execution time base: {execution_time} seconds")
173
+
174
+ perplexity = math.exp(loss_mean)
175
+ perplexities.append(perplexity)
176
+
177
+ # Move the model back to CPU
178
+ interpolated_model.to("cpu")
179
+
180
+ # Clear the GPU cache & collect free memory
181
+ del interpolated_model, input_ids, attention_mask, labels, outputs, loss
182
+ torch.cuda.empty_cache()
183
+ gc.collect()
184
+
185
+ # split on HF org so we don't get accidental
186
+ # directory error
187
+
188
+ model_a_name = model_a_name.split("/")[-1]
189
+ model_b_name = model_b_name.split("/")[-1]
190
+ # Save perplexities and model names to CSV
191
+ csv_filename = f"{csv_dir}/single_perplexities.csv"
192
+ csv_header = ["Model Pair"] + [f"Alpha {alpha}" for alpha in alphas]
193
+
194
+ if not os.path.exists(csv_filename):
195
+ with open(csv_filename, "w", newline="") as csvfile:
196
+ writer = csv.writer(csvfile)
197
+ writer.writerow(csv_header)
198
+
199
+ with open(csv_filename, "a", newline="") as csvfile:
200
+ writer = csv.writer(csvfile)
201
+ model_pair = f"{model_a_name} vs {model_b_name}"
202
+ row = [model_pair] + perplexities
203
+ writer.writerow(row)
204
+
205
+ # Create the plot
206
+ plt.figure(figsize=(8, 6))
207
+ plt.plot(alphas, perplexities)
208
+ plt.xlabel("Alpha")
209
+ plt.ylabel("Perplexity")
210
+ plt.title(f"{model_a_name} (Left) vs {model_b_name} (Right)")
211
+
212
+ # Save the plot as a PNG file
213
+ plot_filename = f"single_alpha_vs_perplexity_{model_a_name}_vs_{model_b_name}.png"
214
+ plot_path = f"{imgs_dir}/{plot_filename}"
215
+ plt.savefig(plot_path, dpi=300, bbox_inches="tight")
216
+ plt.close()
217
+
218
+
219
+ if __name__ == "__main__":
220
+ parser = argparse.ArgumentParser(description="Model Interpolation")
221
+ parser.add_argument(
222
+ "--dataset", choices=["wikitext", "json"], default="wikitext", help="Dataset to use"
223
+ )
224
+ parser.add_argument(
225
+ "--model_arch",
226
+ choices=["llama", "olmo"],
227
+ default="llama",
228
+ help="default model architecture to use",
229
+ )
230
+ args = parser.parse_args()
231
+ main(args)
model-tracing/scripts/mode/mode_connectivity_metrics.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import matplotlib.pyplot as plt
3
+ import datetime
4
+
5
+
6
+ def plot_traces(
7
+ results_path,
8
+ metric,
9
+ plot_path,
10
+ model_a_name,
11
+ model_b_name,
12
+ unpermuted_res=False,
13
+ normalize=True,
14
+ alpha_step=0.1,
15
+ end_points=True,
16
+ ):
17
+
18
+ df = pd.read_csv(results_path)
19
+
20
+ alphas = [round(alpha * alpha_step, 2) for alpha in range(int(1 / alpha_step + 1))]
21
+ if end_points is False:
22
+ alphas = alphas[1:-1]
23
+
24
+ if metric == "loss":
25
+
26
+ plt.figure(figsize=(8, 6))
27
+ for index, row in df.iterrows():
28
+ row = row[int(len(row) - (len(row) - 2) / 2) :]
29
+ if normalize:
30
+ row = normalize_trace(row, alpha_step)
31
+ plt.plot(alphas, row, "o-")
32
+
33
+ plt.xlabel("Alpha")
34
+ plt.ylabel("Loss")
35
+ plt.title(f"{model_a_name} (Left) vs {model_b_name} (Right)")
36
+ plot_filename = f"{plot_path}_{datetime.datetime.now().timestamp()}.png"
37
+
38
+ if metric == "perplexity":
39
+
40
+ plt.figure(figsize=(8, 6))
41
+ for index, row in df.iterrows():
42
+ row = row[2 : int(2 + (len(row) - 2) / 2)]
43
+ if normalize:
44
+ row = normalize_trace(row, alpha_step)
45
+ plt.plot(alphas, row, "o-")
46
+
47
+ plt.xlabel("Alpha")
48
+ plt.ylabel("Perplexity")
49
+ plt.title(f"{model_a_name} (Left) vs {model_b_name} (Right)")
50
+ plot_filename = f"{plot_path}_{datetime.datetime.now().timestamp()}.png"
51
+
52
+ if unpermuted_res is not False:
53
+ plt.plot(alphas, normalize_trace(unpermuted_res, alpha_step))
54
+
55
+ # Save the plot as a PNG file
56
+
57
+ plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
58
+ plt.close()
59
+
60
+
61
+ def plot_trace(losses, alpha_step, normalize, model_a_name, model_b_name, plot_path):
62
+
63
+ plt.figure(figsize=(8, 6))
64
+ if normalize:
65
+ losses = normalize_trace(losses, alpha_step)
66
+ alphas = [round(alpha * alpha_step, 2) for alpha in range(int(1 / alpha_step + 1))]
67
+ plt.plot(alphas, losses, "o-")
68
+
69
+ plt.xlabel("Alpha")
70
+ plt.ylabel("Loss")
71
+ plt.title(f"{model_a_name} (Left) vs {model_b_name} (Right)")
72
+ plot_filename = f"{plot_path}_{datetime.datetime.now().timestamp()}.png"
73
+
74
+ plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
75
+ plt.close()
76
+
77
+
78
+ def normalize_trace(trace, alpha_step):
79
+ slope = trace[-1] - trace[0]
80
+ start = trace[0]
81
+ for i in range(len(trace)):
82
+ trace[i] -= slope * alpha_step * i
83
+ trace[i] -= start
84
+ return trace
85
+
86
+
87
+ def normalize_trace_2(trace, alphas):
88
+ slope = trace[-1] - trace[0]
89
+ start = trace[0]
90
+ for i in range(len(trace)):
91
+ trace[i] -= slope * alphas[i]
92
+ trace[i] -= start
93
+ return trace
94
+
95
+
96
+ def max_loss_ahmed(results_path, num_points=5, normalize=True, alphas=[0.0, 0.3, 0.5, 0.7, 1.0]):
97
+ df = pd.read_csv(results_path)
98
+
99
+ max_losses = []
100
+
101
+ for index, row in df.iterrows():
102
+ row = row[-num_points:]
103
+ if normalize:
104
+ row = normalize_trace_2(row, alphas)
105
+ max_losses.append(max(row))
106
+
107
+ return max_losses
108
+
109
+
110
+ def max_loss_compare(results_path, unpermuted_loss, num_points, normalize=True, alpha_step=0.1):
111
+ df = pd.read_csv(results_path)
112
+ alphas = [round(alpha * alpha_step, 2) for alpha in range(int(1 / alpha_step + 1))]
113
+
114
+ permuted_max_losses = []
115
+
116
+ for index, row in df.iterrows():
117
+ row = row[-num_points:]
118
+ if normalize:
119
+ row = normalize_trace(row, alpha_step)
120
+ permuted_max_losses.append(max(row))
121
+
122
+ if normalize:
123
+ unpermuted_loss = normalize_trace(unpermuted_loss, alpha_step)
124
+ unpermuted_max_loss = max(unpermuted_loss)
125
+
126
+ counter = 0
127
+ for m in permuted_max_losses:
128
+ if m > unpermuted_max_loss:
129
+ counter += 1
130
+
131
+ return counter, len(permuted_max_losses)
132
+
133
+
134
+ def avg_loss_compare(results_path, unpermuted_loss, num_points, normalize=True, alpha_step=0.1):
135
+ df = pd.read_csv(results_path)
136
+ alphas = [round(alpha * alpha_step, 2) for alpha in range(int(1 / alpha_step + 1))]
137
+
138
+ permuted_avg_losses = []
139
+
140
+ for index, row in df.iterrows():
141
+ row = row[-num_points:]
142
+ if normalize:
143
+ row = normalize_trace(row, alpha_step)
144
+ permuted_avg_losses.append(sum(row) / len(row))
145
+
146
+ if normalize:
147
+ unpermuted_loss = normalize_trace(unpermuted_loss, alpha_step)
148
+ unpermuted_avg_loss = sum(unpermuted_loss) / len(unpermuted_loss)
149
+
150
+ counter = 0
151
+ for m in permuted_avg_losses:
152
+ if m > unpermuted_avg_loss:
153
+ counter += 1
154
+
155
+ return counter, len(permuted_avg_losses)
156
+
157
+
158
+ def avg_loss_ahmed(results_path, num_points=5, normalize=True, alphas=[0.0, 0.3, 0.5, 0.7, 1.0]):
159
+ df = pd.read_csv(results_path)
160
+
161
+ avg_losses = []
162
+
163
+ for index, row in df.iterrows():
164
+ row = row[-num_points:]
165
+ if normalize:
166
+ row = normalize_trace_2(row, alphas)
167
+ avg_losses.append(sum(row) / len(row))
168
+
169
+ return avg_losses
170
+
171
+
172
+ def compute_p_value(counter, total):
173
+ return (total - counter - 1) / total
model-tracing/scripts/perm/main.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+ from tracing.perm.permute import permute_model
5
+
6
+
7
+ def main(base_model, ft_model, test_stat, num_perm, emb_dim=4096, mlp_dim=11008):
8
+
9
+ unperm_stat = test_stat(base_model, ft_model)
10
+ print(unperm_stat)
11
+
12
+ perm_stats = []
13
+
14
+ for i in range(num_perm):
15
+
16
+ mlp_permutation = torch.randperm(mlp_dim)
17
+ emb_permutation = torch.randperm(emb_dim)
18
+
19
+ permute_model(ft_model, mlp_permutation, emb_permutation)
20
+
21
+ perm_stat = test_stat(base_model, ft_model)
22
+
23
+ perm_stats.append(perm_stat)
24
+ print(i, perm_stat)
25
+
26
+ print(perm_stats)
27
+ exact = p_value_exact(unperm_stat, perm_stats.copy())
28
+ approx = p_value_approx(unperm_stat, perm_stats.copy())
29
+
30
+ print(exact, approx)
31
+
32
+ return exact, approx, unperm_stat, perm_stats
33
+
34
+
35
+ def p_value_exact(unpermuted, permuted):
36
+ count = 0
37
+ for a in permuted:
38
+ if a < unpermuted:
39
+ count += 1
40
+ return round((count + 1) / (len(permuted) + 1), 2)
41
+
42
+
43
+ def p_value_approx(unpermuted, permuted):
44
+ mean = sum(permuted) / len(permuted)
45
+ std = np.std(permuted)
46
+ zscore = (unpermuted - mean) / std
47
+ return zscore
model-tracing/scripts/robust/pythia.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import GPTNeoXForCausalLM, AutoTokenizer
3
+
4
+ import argparse
5
+ import pickle
6
+ import timeit
7
+ import subprocess
8
+
9
+ from tracing.utils.evaluate import prepare_hf_dataset, prepare_hf_dataloader, evaluate
10
+ from tracing.utils.utils import output_hook, get_submodule
11
+
12
+ parser = argparse.ArgumentParser(description="Experiment Settings")
13
+
14
+ parser.add_argument("--model_id", default="EleutherAI/pythia-1.4b-deduped", type=str)
15
+ parser.add_argument("--step", default=0, type=int)
16
+ parser.add_argument("--layer", default=10, type=int)
17
+
18
+ parser.add_argument("--dataset_id", default="dlwh/wikitext_103_detokenized", type=str)
19
+ parser.add_argument("--block_size", default=512, type=int)
20
+ parser.add_argument("--batch_size", default=6, type=int)
21
+
22
+ parser.add_argument("--save", default="results.p", type=str)
23
+ parser.add_argument("--seed", default=0, type=int)
24
+ parser.add_argument("--token", default="", type=str)
25
+
26
+ args = parser.parse_args()
27
+
28
+ from huggingface_hub import login
29
+
30
+ login(token=args.token)
31
+
32
+ start = timeit.default_timer()
33
+
34
+ results = {}
35
+ results["args"] = args
36
+ results["commit"] = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()
37
+
38
+ torch.manual_seed(args.seed)
39
+
40
+ model = GPTNeoXForCausalLM.from_pretrained(
41
+ args.model_id,
42
+ revision=f"step{args.step}",
43
+ )
44
+ tokenizer = AutoTokenizer.from_pretrained(
45
+ args.model_id,
46
+ revision=f"step{args.step}",
47
+ )
48
+
49
+ print("model loaded")
50
+
51
+ dataset = prepare_hf_dataset(args.dataset_id, args.block_size, tokenizer)
52
+ dataloader = prepare_hf_dataloader(dataset, args.batch_size)
53
+
54
+ print("dataset loaded")
55
+
56
+ block = get_submodule(model, f"gpt_neox.layers.{args.layer}")
57
+
58
+ feats, hooks = {}, {}
59
+ for layer in [
60
+ "input_layernorm",
61
+ "post_attention_layernorm",
62
+ "mlp.dense_h_to_4h",
63
+ "mlp.dense_4h_to_h",
64
+ ]:
65
+ hooks[layer] = lambda m, inp, op, layer=layer, feats=feats: output_hook(
66
+ m, inp, op, layer, feats
67
+ )
68
+ get_submodule(block, layer).register_forward_hook(hooks[layer])
69
+
70
+ print("hooks created")
71
+
72
+ evaluate(model, dataloader)
73
+
74
+ print("models evaluated")
75
+
76
+ end = timeit.default_timer()
77
+ results["time"] = end - start
78
+
79
+ results["weights"] = block.state_dict()
80
+ results["feats"] = feats
81
+
82
+ print(results)
83
+ pickle.dump(results, open(args.save, "wb"))
model-tracing/tracing/__init__.py ADDED
File without changes
model-tracing/tracing/perm/permute.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module for permuting weights in a Llama model architecture.
3
+ This enables exploring model connectivity and representation properties
4
+ by applying consistent neuron permutations throughout the network.
5
+ """
6
+
7
+
8
+ def permute_model(model, mlp_permutation, emb_permutation, n_blocks=32):
9
+ """
10
+ Apply permutations to a Llama model's weights to maintain functional equivalence.
11
+
12
+ Args:
13
+ model: The Llama model to permute
14
+ mlp_permutation: Permutation indices for MLP hidden dimensions
15
+ emb_permutation: Permutation indices for embedding dimensions
16
+ n_blocks: Number of transformer blocks in the model (default: 32)
17
+
18
+ Returns:
19
+ None: Modifies the model in-place
20
+ """
21
+ permute_embedding_layer(model, emb_permutation)
22
+ permute_transformer_blocks(model, mlp_permutation, emb_permutation)
23
+ permute_output_layer(model, emb_permutation)
24
+
25
+
26
+ def permute_transformer_blocks(model, mlp_permutation, emb_permutation):
27
+ """
28
+ Apply permutations to transformer block weights in a Llama model.
29
+
30
+ Permutes attention layers, MLP layers, and normalization layers according to
31
+ the provided permutation indices to maintain functional equivalence.
32
+
33
+ Args:
34
+ model: The Llama model to permute
35
+ mlp_permutation: Permutation indices for MLP hidden dimensions
36
+ emb_permutation: Permutation indices for embedding dimensions
37
+
38
+ Returns:
39
+ None: Modifies the model in-place
40
+ """
41
+ weights = model.state_dict()
42
+
43
+ # Permuting the Self attention layers
44
+ for key in weights:
45
+ if "self_attn" not in key:
46
+ continue
47
+
48
+ if "o_proj" in key:
49
+ weights[key] = weights[key][emb_permutation]
50
+ else:
51
+ weights[key] = weights[key][:, emb_permutation]
52
+
53
+ # Permuting the mlp projection layers
54
+ for key in weights:
55
+ if "mlp" not in key:
56
+ continue
57
+ if len(weights[key].shape) != 2:
58
+ continue
59
+
60
+ dim_0 = weights[key].size(0)
61
+ dim_1 = weights[key].size(1)
62
+
63
+ if dim_0 == len(mlp_permutation):
64
+ weights[key] = weights[key][mlp_permutation]
65
+ elif dim_1 == len(mlp_permutation):
66
+ weights[key] = weights[key][:, mlp_permutation]
67
+
68
+ if dim_0 == len(emb_permutation):
69
+ weights[key] = weights[key][emb_permutation]
70
+ elif dim_1 == len(emb_permutation):
71
+ weights[key] = weights[key][:, emb_permutation]
72
+
73
+ # input_layernorm, post_attention_layernorm
74
+ for key in weights:
75
+ if "model.layers" not in key:
76
+ continue
77
+ if len(weights[key].shape) != 1 or len(weights[key]) != len(emb_permutation):
78
+ continue
79
+
80
+ weights[key] = weights[key][emb_permutation]
81
+
82
+ model.load_state_dict(weights)
83
+
84
+
85
+ def permute_embedding_layer(model, emb_permutation):
86
+ """
87
+ Apply permutation to embedding layer weights in a Llama model.
88
+
89
+ Args:
90
+ model: The Llama model to permute
91
+ emb_permutation: Permutation indices for embedding dimensions
92
+
93
+ Returns:
94
+ None: Modifies the model in-place
95
+ """
96
+ weights = model.state_dict()
97
+
98
+ weights["model.embed_tokens.weight"] = weights["model.embed_tokens.weight"][:, emb_permutation]
99
+ model.load_state_dict(weights)
100
+
101
+
102
+ def permute_output_layer(model, emb_permutation):
103
+ """
104
+ Apply permutation to output layer weights in a Llama model.
105
+
106
+ Permutes the language model head and final normalization layer.
107
+
108
+ Args:
109
+ model: The Llama model to permute
110
+ emb_permutation: Permutation indices for embedding dimensions
111
+
112
+ Returns:
113
+ None: Modifies the model in-place
114
+ """
115
+ weights = model.state_dict()
116
+
117
+ weights["lm_head.weight"] = weights["lm_head.weight"][:, emb_permutation]
118
+ weights["model.norm.weight"] = weights["model.norm.weight"][emb_permutation]
119
+ model.load_state_dict(weights)
model-tracing/tracing/statistics/__init__.py ADDED
File without changes
model-tracing/tracing/statistics/csh.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of Chi-Squared Hypothesis (CSH) test for comparing neural network models.
3
+
4
+ This module provides functions to test whether two models have similar activation patterns
5
+ across different layers using Chi-Squared statistical tests.
6
+ """
7
+
8
+ import torch
9
+ from collections import defaultdict
10
+ import scipy
11
+ import numpy as np
12
+ from scipy.stats import chi2
13
+
14
+ from scipy.optimize import linear_sum_assignment as LAP
15
+
16
+ from tracing.utils.utils import cossim
17
+ from tracing.utils.evaluate import evaluate
18
+
19
+
20
+ def statistic(base_model, ft_model, dataloader):
21
+ """
22
+ Compute Chi-Squared Hypothesis test statistic between two models.
23
+
24
+ Args:
25
+ base_model: Base model to compare
26
+ ft_model: Fine-tuned or target model to compare against the base model
27
+ dataloader: DataLoader providing input data for activation collection
28
+
29
+ Returns:
30
+ tuple: (p_value, p_values_per_layer) from the CSH test
31
+ """
32
+ return csh_sp_dataloader(base_model, ft_model, dataloader)
33
+
34
+
35
+ def hook(m, inp, op, feats, name):
36
+ """
37
+ Forward hook to capture output activations from model layers.
38
+
39
+ Args:
40
+ m: Module being hooked
41
+ inp: Input to the module
42
+ op: Output from the module
43
+ feats: Dictionary to store activations
44
+ name: Key to store the activations under
45
+ """
46
+ feats[name].append(op.detach().cpu())
47
+
48
+
49
+ def hook_in(m, inp, op, feats, name):
50
+ """
51
+ Forward hook to capture input activations to model layers.
52
+
53
+ Args:
54
+ m: Module being hooked
55
+ inp: Input to the module (tuple)
56
+ op: Output from the module
57
+ feats: Dictionary to store activations
58
+ name: Key to store the activations under
59
+ """
60
+ feats[name].append(inp[0].detach().cpu())
61
+
62
+
63
+ def csh_sp_dataloader_block(base_model, ft_model, dataloader, i):
64
+ """
65
+ Apply CSH test to a specific block in the model.
66
+
67
+ Args:
68
+ base_model: Base model to compare
69
+ ft_model: Fine-tuned or target model to compare against the base model
70
+ dataloader: DataLoader providing input data for activation collection
71
+ i: Block index to analyze
72
+
73
+ Returns:
74
+ float: p-value indicating the statistical similarity between models at block i
75
+ """
76
+ feats = defaultdict(list)
77
+
78
+ base_hook = lambda *args: hook(*args, feats, "base")
79
+ base_model.model.layers[i].mlp.down_proj.register_forward_hook(base_hook)
80
+
81
+ ft_hook = lambda *args: hook(*args, feats, "ft")
82
+ ft_model.model.layers[i].mlp.down_proj.register_forward_hook(ft_hook)
83
+
84
+ evaluate(base_model, dataloader)
85
+ evaluate(ft_model, dataloader)
86
+
87
+ base_mat = torch.vstack(feats["base"])
88
+ ft_mat = torch.vstack(feats["ft"])
89
+
90
+ base_mat = base_mat.view(-1, base_mat.shape[-1]).T
91
+ ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T
92
+
93
+ matched = torch.argmax(cossim(base_mat, ft_mat), axis=-1)
94
+ orig = torch.arange(len(matched))
95
+
96
+ cor, pvalue = scipy.stats.spearmanr(matched.tolist(), orig.tolist())
97
+ return pvalue
98
+
99
+
100
+ def csh_sp_dataloader(base_model, ft_model, dataloader, n_blocks=32):
101
+ """
102
+ Apply CSH test across all model blocks using activations from a dataloader.
103
+
104
+ Performs Chi-Squared Hypothesis test by:
105
+ 1. Collecting activations from both models using the same input data
106
+ 2. Computing optimal matching between neurons in corresponding layers
107
+ 3. Calculating Spearman correlation between matched indices and original indices
108
+ 4. Computing combined p-value using Fisher's method
109
+
110
+ Args:
111
+ base_model: Base model to compare
112
+ ft_model: Fine-tuned or target model to compare against the base model
113
+ dataloader: DataLoader providing input data for activation collection
114
+ n_blocks: Number of transformer blocks to analyze (default: 32)
115
+
116
+ Returns:
117
+ tuple: (combined_p_value, list_of_p_values_per_layer)
118
+ """
119
+ chi_squared = 0
120
+ feats = defaultdict(list)
121
+
122
+ base_hooks = {}
123
+ ft_hooks = {}
124
+
125
+ for i in range(n_blocks):
126
+ layer = str(i)
127
+
128
+ base_hooks[layer] = lambda m, inp, op, layer=layer, feats=feats: hook(
129
+ m, inp, op, feats, "base_" + layer
130
+ )
131
+ base_model.model.layers[i].mlp.up_proj.register_forward_hook(base_hooks[layer])
132
+
133
+ ft_hooks[layer] = lambda m, inp, op, layer=layer, feats=feats: hook(
134
+ m, inp, op, feats, "ft_" + layer
135
+ )
136
+ ft_model.model.layers[i].mlp.up_proj.register_forward_hook(ft_hooks[layer])
137
+
138
+ evaluate(base_model, dataloader)
139
+ evaluate(ft_model, dataloader)
140
+
141
+ p_values = []
142
+ count = 0
143
+
144
+ for i in range(n_blocks):
145
+ base_mat = torch.vstack(feats["base_" + str(i)])
146
+ ft_mat = torch.vstack(feats["ft_" + str(i)])
147
+
148
+ base_mat = torch.reshape(
149
+ base_mat, (base_mat.shape[0] * base_mat.shape[1], base_mat.shape[2])
150
+ )
151
+ ft_mat = torch.reshape(ft_mat, (ft_mat.shape[0] * ft_mat.shape[1], ft_mat.shape[2]))
152
+
153
+ base_mat = base_mat.T
154
+ ft_mat = ft_mat.T
155
+
156
+ matched = LAP(
157
+ cossim(base_mat.type(torch.float64), ft_mat.type(torch.float64)), maximize=True
158
+ )
159
+ matched = matched[1]
160
+ orig = torch.arange(len(matched))
161
+
162
+ cor, temp = scipy.stats.spearmanr(matched.tolist(), orig.tolist())
163
+
164
+ if not np.isnan(temp):
165
+ chi_squared -= 2 * np.log(temp)
166
+ count += 1
167
+ print(i, temp)
168
+ p_values.append(temp)
169
+
170
+ p_value = chi2.sf(chi_squared, df=2 * count)
171
+
172
+ return p_value, p_values
model-tracing/tracing/statistics/csu.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of Cosine Similarity of Weights (CSW) test for comparing neural network models.
3
+
4
+ This module provides functions to test whether two models have similar weight matrices
5
+ using cosine similarity and statistical tests to quantify the similarity.
6
+ """
7
+
8
+ import torch
9
+
10
+ from tracing.utils.utils import cossim, fisher
11
+ import scipy
12
+ import numpy as np
13
+ from scipy.stats import chi2
14
+
15
+ from scipy.optimize import linear_sum_assignment as LAP
16
+
17
+
18
+ def statistic(base_model, ft_model):
19
+ """
20
+ Compute Cosine Similarity of Weights statistic between two models.
21
+
22
+ Args:
23
+ base_model: Base model to compare
24
+ ft_model: Fine-tuned or target model to compare against the base model
25
+
26
+ Returns:
27
+ tuple: (aggregate_p_value, p_values_per_layer) from the CSW test
28
+ """
29
+ return csw_sp(base_model, ft_model)
30
+
31
+
32
+ def csw_sp_layer(base_model, ft_model, layer_name):
33
+ """
34
+ Calculate Cosine Similarity of Weights for a specific layer.
35
+
36
+ Uses linear assignment to find optimal matching between neurons in the layer
37
+ and calculates Spearman correlation to quantify similarity.
38
+
39
+ Args:
40
+ base_model: Base model to compare
41
+ ft_model: Fine-tuned or target model to compare against the base model
42
+ layer_name: Name of the layer in the model's state dict to analyze
43
+
44
+ Returns:
45
+ float: p-value indicating the statistical similarity of weight matrices
46
+ """
47
+ base_mat = base_model.state_dict()[layer_name]
48
+ ft_mat = ft_model.state_dict()[layer_name]
49
+
50
+ matched = LAP(cossim(base_mat.type(torch.float64), ft_mat.type(torch.float64)), maximize=True)
51
+ matched = matched[1]
52
+ orig = torch.arange(len(matched))
53
+
54
+ cor, pvalue = scipy.stats.spearmanr(matched.tolist(), orig.tolist())
55
+ return pvalue
56
+
57
+
58
+ def csw_sp(model1, model2):
59
+ """
60
+ Apply CSW test across all MLP up-projection layers in the models.
61
+
62
+ Performs Fisher's method to combine p-values from individual layer tests
63
+ into an aggregate statistic.
64
+
65
+ Args:
66
+ model1: First model to compare
67
+ model2: Second model to compare
68
+
69
+ Returns:
70
+ tuple: (aggregate_p_value, list_of_p_values_per_layer)
71
+ """
72
+ chi_squared = 0
73
+ num_layers = 0
74
+
75
+ p_values = []
76
+
77
+ for name1, name2 in zip(list(model1.state_dict().keys()), list(model2.state_dict().keys())):
78
+ if name1 != name2:
79
+ raise ValueError(f"Model parameter names do not match: {name1} != {name2}")
80
+ elif "mlp.up_proj" not in name1:
81
+ continue
82
+
83
+ pvalue = csw_sp_layer(model1, model2, name1)
84
+ if not np.isnan(pvalue):
85
+ chi_squared -= 2 * np.log(pvalue)
86
+ num_layers += 1
87
+ p_values.append(pvalue)
88
+
89
+ print(name1, pvalue)
90
+
91
+ aggregate_pvalue = chi2.sf(chi_squared, df=2 * num_layers)
92
+ return aggregate_pvalue, p_values
93
+
94
+
95
+ def csw_sp_pair(base_model, ft_model, layer_name_base, layer_name_ft):
96
+ """
97
+ Calculate Cosine Similarity of Weights between two specific layers.
98
+
99
+ Similar to csw_sp_layer but allows comparing layers with different names.
100
+
101
+ Args:
102
+ base_model: Base model to compare
103
+ ft_model: Fine-tuned or target model to compare against the base model
104
+ layer_name_base: Name of the layer in the base model's state dict
105
+ layer_name_ft: Name of the layer in the fine-tuned model's state dict
106
+
107
+ Returns:
108
+ float: p-value indicating the statistical similarity of weight matrices
109
+ """
110
+ base_mat = base_model.state_dict()[layer_name_base]
111
+ ft_mat = ft_model.state_dict()[layer_name_ft]
112
+
113
+ matched = LAP(cossim(base_mat.type(torch.float64), ft_mat.type(torch.float64)), maximize=True)
114
+ matched = matched[1]
115
+ orig = torch.arange(len(matched))
116
+
117
+ cor, pvalue = scipy.stats.spearmanr(matched.tolist(), orig.tolist())
118
+ return pvalue
119
+
120
+
121
+ def statistic_all(base_model, ft_model):
122
+ """
123
+ Compute comprehensive pairwise comparisons between all compatible layers.
124
+
125
+ Tests every possible layer pairing between models that have compatible shapes,
126
+ useful for exploring model structure similarities without assumptions.
127
+
128
+ Args:
129
+ base_model: Base model to compare
130
+ ft_model: Fine-tuned or target model to compare against the base model
131
+
132
+ Returns:
133
+ float: Aggregate p-value from Fisher's method combining all layer comparisons
134
+ """
135
+ base_model.to("cpu")
136
+ ft_model.to("cpu")
137
+
138
+ weights_base = base_model.state_dict()
139
+ weights_ft = ft_model.state_dict()
140
+
141
+ shapes_base = {}
142
+ shapes_ft = {}
143
+
144
+ for name1 in list(weights_base.keys()):
145
+ shapes_base[name1] = weights_base[name1].shape
146
+ for name2 in list(weights_ft.keys()):
147
+ shapes_ft[name2] = weights_ft[name2].shape
148
+
149
+ pvalues = []
150
+
151
+ for name1 in list(weights_base.keys()):
152
+ for name2 in list(weights_ft.keys()):
153
+ if shapes_base[name1] == shapes_ft[name2] and len(shapes_base[name1]) != 1:
154
+ pval = csw_sp_pair(base_model, ft_model, name1, name2)
155
+ print(name1, name2, pval)
156
+ pvalues.append(pval)
157
+
158
+ print(pvalues)
159
+
160
+ res = 0
161
+
162
+ if len(pvalues) == 0:
163
+ res = 999
164
+ else:
165
+ res = fisher(pvalues)
166
+
167
+ print(res)
168
+ return res
model-tracing/tracing/statistics/jsd.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of Jensen-Shannon Divergence (JSD) for comparing language model outputs.
3
+
4
+ This module provides functions to compute the Jensen-Shannon Divergence between
5
+ probability distributions output by two language models, measuring their similarity
6
+ in output space rather than parameter space.
7
+ """
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
+
13
+ from tracing.utils.evaluate import (
14
+ prepare_hf_dataset,
15
+ prepare_hf_dataloader,
16
+ )
17
+
18
+
19
+ def statistic(base_model, ft_model, dataloader, device="cuda"):
20
+ """
21
+ Compute Jensen-Shannon Divergence between outputs of two language models.
22
+
23
+ Args:
24
+ base_model: Base model to compare
25
+ ft_model: Fine-tuned or target model to compare against the base model
26
+ dataloader: DataLoader providing input data for model evaluation
27
+ device: Device to run the computation on (default: "cuda")
28
+
29
+ Returns:
30
+ float: Sum of Jensen-Shannon Divergence values across all batches
31
+ """
32
+ return compute_jsd(base_model, ft_model, dataloader, device)
33
+
34
+
35
+ def statistic_stable(base_model, ft_model, dataloader, device="cuda"):
36
+ """
37
+ Compute numerically stable Jensen-Shannon Divergence between outputs of two models.
38
+
39
+ This version handles potential numerical issues better than the standard version.
40
+
41
+ Args:
42
+ base_model: Base model to compare
43
+ ft_model: Fine-tuned or target model to compare against the base model
44
+ dataloader: DataLoader providing input data for model evaluation
45
+ device: Device to run the computation on (default: "cuda")
46
+
47
+ Returns:
48
+ float: Sum of Jensen-Shannon Divergence values across all batches
49
+ """
50
+ return compute_jsd_stable(base_model, ft_model, dataloader, device)
51
+
52
+
53
+ def compute_jsd(base_model, ft_model, dataloader, device="cuda"):
54
+ """
55
+ Compute Jensen-Shannon Divergence between two models using softmax outputs.
56
+
57
+ Processes each batch in the dataloader and computes JSD between the models'
58
+ probability distributions over vocabulary tokens. Handles potential vocabulary
59
+ size differences by truncating to a common size (32000 tokens).
60
+
61
+ Args:
62
+ base_model: Base model to compare
63
+ ft_model: Fine-tuned or target model to compare against the base model
64
+ dataloader: DataLoader providing input data for model evaluation
65
+ device: Device to run the computation on (default: "cuda")
66
+
67
+ Returns:
68
+ float: Sum of Jensen-Shannon Divergence values across all batches
69
+ """
70
+ jsds = []
71
+
72
+ base_model.to(device)
73
+ ft_model.to(device)
74
+
75
+ with torch.no_grad():
76
+ for batch in dataloader:
77
+ input_ids = batch["input_ids"].to(device)
78
+ attention_mask = batch["attention_mask"].to(device)
79
+ labels = batch["labels"].to(device)
80
+
81
+ outputs_base = base_model(
82
+ input_ids=input_ids,
83
+ attention_mask=attention_mask,
84
+ labels=labels,
85
+ )
86
+ outputs_ft = ft_model(
87
+ input_ids=input_ids,
88
+ attention_mask=attention_mask,
89
+ labels=labels,
90
+ )
91
+
92
+ logits_base = outputs_base.logits.squeeze()
93
+ logits_ft = outputs_ft.logits.squeeze()
94
+
95
+ softmax_base = torch.softmax(logits_base, dim=-1)
96
+ softmax_ft = torch.softmax(logits_ft, dim=-1)
97
+
98
+ # Truncate the softmax outputs to the first 32000 dimensions
99
+ softmax_base = softmax_base[:, :32000]
100
+ softmax_ft = softmax_ft[:, :32000]
101
+
102
+ m = 0.5 * (softmax_base + softmax_ft)
103
+ jsd = 0.5 * (F.kl_div(m.log(), softmax_base) + F.kl_div(m.log(), softmax_ft))
104
+
105
+ jsds.append(jsd.item())
106
+
107
+ base_model.to("cpu")
108
+ ft_model.to("cpu")
109
+ return sum(jsds)
110
+
111
+
112
+ def compute_jsd_stable(base_model, ft_model, dataloader, device="cuda"):
113
+ """
114
+ Compute numerically stable Jensen-Shannon Divergence between two models.
115
+
116
+ A more robust implementation that:
117
+ 1. Handles vocabulary size mismatches by truncating to the minimum size
118
+ 2. Uses log-space calculations to avoid numerical underflow
119
+ 3. Computes JSD directly from log probabilities for better stability
120
+
121
+ Args:
122
+ base_model: Base model to compare
123
+ ft_model: Fine-tuned or target model to compare against the base model
124
+ dataloader: DataLoader providing input data for model evaluation
125
+ device: Device to run the computation on (default: "cuda")
126
+
127
+ Returns:
128
+ float: Sum of Jensen-Shannon Divergence values across all batches
129
+ """
130
+ jsds = []
131
+
132
+ base_model.to(device)
133
+ ft_model.to(device)
134
+
135
+ with torch.no_grad():
136
+ for batch in dataloader:
137
+ input_ids = batch["input_ids"].to(device)
138
+ attention_mask = batch["attention_mask"].to(device)
139
+ labels = batch["labels"].to(device)
140
+
141
+ outputs_base = base_model(
142
+ input_ids=input_ids,
143
+ attention_mask=attention_mask,
144
+ labels=labels,
145
+ )
146
+ outputs_ft = ft_model(
147
+ input_ids=input_ids,
148
+ attention_mask=attention_mask,
149
+ labels=labels,
150
+ )
151
+
152
+ logits_base = outputs_base.logits.squeeze()
153
+ logits_ft = outputs_ft.logits.squeeze()
154
+
155
+ # Determine the minimum vocabulary size between the two models
156
+ min_vocab_size = min(logits_base.size(-1), logits_ft.size(-1))
157
+
158
+ # Truncate the logits to the minimum vocabulary size
159
+ logits_base = logits_base[..., :min_vocab_size]
160
+ logits_ft = logits_ft[..., :min_vocab_size]
161
+
162
+ log_probs_base = F.log_softmax(logits_base, dim=-1)
163
+ log_probs_ft = F.log_softmax(logits_ft, dim=-1)
164
+
165
+ m = 0.5 * (log_probs_base.exp() + log_probs_ft.exp())
166
+ log_m = m.log()
167
+
168
+ kl_div_base_m = (log_probs_base - log_m).sum(dim=-1)
169
+ kl_div_ft_m = (log_probs_ft - log_m).sum(dim=-1)
170
+
171
+ jsd = 0.5 * (kl_div_base_m + kl_div_ft_m).mean()
172
+ jsds.append(jsd.item())
173
+
174
+ base_model.to("cpu")
175
+ ft_model.to("cpu")
176
+
177
+ return sum(jsds)
178
+
179
+
180
+ if __name__ == "__main__":
181
+
182
+ base_model_name = "LLM360/Amber" # 'openlm-research/open_llama_7b' # 'lmsys/vicuna-7b-v1.5'
183
+ ft_model_name = "LLM360/AmberChat" # 'openlm-research/open_llama_7b_v2' # 'LLM360/Amber' # "lmsys/vicuna-7b-v1.1"
184
+
185
+ base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)
186
+ ft_model = AutoModelForCausalLM.from_pretrained(ft_model_name, torch_dtype=torch.bfloat16)
187
+ base_tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_fast=False)
188
+
189
+ # dataset = load_generated_datasets(base_model_name, ft_model_name, 512, base_tokenizer, ["text"])
190
+ # dataloader = prepare_hf_dataloader(dataset, 1)
191
+
192
+ dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized", 512, base_tokenizer)
193
+ dataloader = prepare_hf_dataloader(dataset, 1)
194
+
195
+ print(statistic(base_model, ft_model, dataloader))
model-tracing/tracing/statistics/l2.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of L2 distance metrics for comparing neural network model weights.
3
+
4
+ This module provides functions to calculate the L2 (Euclidean) distance between
5
+ corresponding weight tensors of two models, providing a measure of parameter space
6
+ similarity.
7
+ """
8
+
9
+ import torch
10
+
11
+
12
+ def statistic(base_model, ft_model):
13
+ """
14
+ Compute the average L2 distance between weights of two models.
15
+
16
+ Args:
17
+ base_model: Base model to compare
18
+ ft_model: Fine-tuned or target model to compare against the base model
19
+
20
+ Returns:
21
+ float: Average L2 distance across all comparable parameters
22
+ """
23
+ return calculate_l2_distance(base_model, ft_model)
24
+
25
+
26
+ def calculate_l2_distance(model1, model2):
27
+ """
28
+ Calculate the average L2 distance between corresponding parameters of two models.
29
+
30
+ For each parameter tensor in the models, computes the Euclidean distance between
31
+ them and returns the average across all parameters. Handles potential shape
32
+ mismatches in embedding or output layers.
33
+
34
+ Args:
35
+ model1: First model to compare
36
+ model2: Second model to compare
37
+
38
+ Returns:
39
+ float: Average L2 distance across all comparable parameters
40
+
41
+ Raises:
42
+ ValueError: If parameter names don't match or if there are shape mismatches
43
+ in parameters other than embedding or output layers
44
+ """
45
+ total_squared_diff = 0
46
+ num_layers = 0
47
+
48
+ all_layers = []
49
+
50
+ for (name1, param1), (name2, param2) in zip(
51
+ model1.named_parameters(), model2.named_parameters()
52
+ ):
53
+ if name1 != name2:
54
+ raise ValueError(f"Model parameter names do not match: {name1} != {name2}")
55
+ elif param1.shape != param2.shape:
56
+ if name1 == "model.embed_tokens.weight" or name1 == "lm_head.weight":
57
+ print(
58
+ f"Skipping {name1} because of shape mismatch: {param1.shape} != {param2.shape}"
59
+ )
60
+ continue
61
+ raise ValueError(
62
+ f"Model parameter shapes do not match for {name1}: {param1.shape} != {param2.shape}"
63
+ )
64
+
65
+ l2_diff = torch.sum((param1 - param2) ** 2) ** 0.5
66
+ total_squared_diff += l2_diff.item()
67
+ all_layers.append(l2_diff.item())
68
+ num_layers += 1
69
+
70
+ avg_l2_distance = total_squared_diff / num_layers
71
+ return avg_l2_distance
model-tracing/tracing/statistics/match.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of activation matching algorithms for comparing neural network models.
3
+
4
+ This module provides functions for matching neurons between two models based on
5
+ their activation patterns, helping identify corresponding functional units despite
6
+ permutation differences.
7
+ """
8
+
9
+ import torch
10
+ from collections import defaultdict
11
+ import scipy
12
+ import numpy as np
13
+
14
+ from tracing.utils.evaluate import evaluate
15
+ from tracing.utils.llama.matching import match_wmats
16
+
17
+
18
+ def hook_in(m, inp, op, feats, name):
19
+ """
20
+ Forward hook to capture input activations to model layers.
21
+
22
+ Args:
23
+ m: Module being hooked
24
+ inp: Input to the module (tuple)
25
+ op: Output from the module
26
+ feats: Dictionary to store activations
27
+ name: Key to store the activations under
28
+ """
29
+ feats[name].append(inp[0].detach().cpu())
30
+
31
+
32
+ def hook_out(m, inp, op, feats, name):
33
+ """
34
+ Forward hook to capture output activations from model layers.
35
+
36
+ Args:
37
+ m: Module being hooked
38
+ inp: Input to the module
39
+ op: Output from the module
40
+ feats: Dictionary to store activations
41
+ name: Key to store the activations under
42
+ """
43
+ feats[name].append(op.detach().cpu())
44
+
45
+
46
+ def statistic(base_model, ft_model, dataloader, n_blocks=32):
47
+ """
48
+ Compute neuron matching statistics across all transformer blocks.
49
+
50
+ For each block, compares the gate and up projections to determine if
51
+ the permutation patterns are consistent, which would indicate functionally
52
+ corresponding neurons.
53
+
54
+ Args:
55
+ base_model: Base model to compare
56
+ ft_model: Fine-tuned or target model to compare against the base model
57
+ dataloader: DataLoader providing input data for activation collection
58
+ n_blocks: Number of transformer blocks to analyze (default: 32)
59
+
60
+ Returns:
61
+ list: Spearman correlation p-values for each block
62
+ """
63
+ stats = []
64
+
65
+ for i in range(n_blocks):
66
+ gate_match = mlp_matching_gate(base_model, ft_model, dataloader, i=i)
67
+ up_match = mlp_matching_up(base_model, ft_model, dataloader, i=i)
68
+
69
+ cor, pvalue = scipy.stats.spearmanr(gate_match.tolist(), up_match.tolist())
70
+ print(i, pvalue, len(gate_match))
71
+ stats.append(pvalue)
72
+
73
+ return stats
74
+
75
+
76
+ def statistic_layer(base_model, ft_model, dataloader, i=0):
77
+ """
78
+ Compute neuron matching statistics for a specific transformer block.
79
+
80
+ Args:
81
+ base_model: Base model to compare
82
+ ft_model: Fine-tuned or target model to compare against the base model
83
+ dataloader: DataLoader providing input data for activation collection
84
+ i: Block index to analyze (default: 0)
85
+
86
+ Returns:
87
+ float: Spearman correlation p-value for the specified block
88
+ """
89
+ gate_perm = mlp_matching_gate(base_model, ft_model, dataloader, i=i)
90
+ up_perm = mlp_matching_up(base_model, ft_model, dataloader, i=i)
91
+ cor, pvalue = scipy.stats.spearmanr(gate_perm.tolist(), up_perm.tolist())
92
+ return pvalue
93
+
94
+
95
+ def mlp_matching_gate(base_model, ft_model, dataloader, i=0):
96
+ """
97
+ Match neurons between models by comparing gate projection activations.
98
+
99
+ Collects activations from the gate projection layer for both models
100
+ and computes a permutation that would align corresponding neurons.
101
+
102
+ Args:
103
+ base_model: Base model to compare
104
+ ft_model: Fine-tuned or target model to compare against the base model
105
+ dataloader: DataLoader providing input data for activation collection
106
+ i: Block index to analyze (default: 0)
107
+
108
+ Returns:
109
+ torch.Tensor: Permutation indices that match neurons between models
110
+ """
111
+ feats = defaultdict(list)
112
+
113
+ base_hook = lambda *args: hook_out(*args, feats, "base")
114
+ base_handle = base_model.model.layers[i].mlp.gate_proj.register_forward_hook(base_hook)
115
+
116
+ ft_hook = lambda *args: hook_out(*args, feats, "ft")
117
+ ft_handle = ft_model.model.layers[i].mlp.gate_proj.register_forward_hook(ft_hook)
118
+
119
+ evaluate(base_model, dataloader)
120
+ evaluate(ft_model, dataloader)
121
+
122
+ base_mat = torch.vstack(feats["base"])
123
+ ft_mat = torch.vstack(feats["ft"])
124
+
125
+ base_mat.to("cuda")
126
+ ft_mat.to("cuda")
127
+
128
+ base_mat = base_mat.view(-1, base_mat.shape[-1]).T
129
+ ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T
130
+
131
+ base_handle.remove()
132
+ ft_handle.remove()
133
+
134
+ perm = match_wmats(base_mat, ft_mat)
135
+
136
+ return perm
137
+
138
+
139
+ def mlp_matching_up(base_model, ft_model, dataloader, i=0):
140
+ """
141
+ Match neurons between models by comparing up projection activations.
142
+
143
+ Collects activations from the up projection layer for both models
144
+ and computes a permutation that would align corresponding neurons.
145
+
146
+ Args:
147
+ base_model: Base model to compare
148
+ ft_model: Fine-tuned or target model to compare against the base model
149
+ dataloader: DataLoader providing input data for activation collection
150
+ i: Block index to analyze (default: 0)
151
+
152
+ Returns:
153
+ torch.Tensor: Permutation indices that match neurons between models
154
+ """
155
+ feats = defaultdict(list)
156
+
157
+ base_hook = lambda *args: hook_out(*args, feats, "base")
158
+ base_handle = base_model.model.layers[i].mlp.up_proj.register_forward_hook(base_hook)
159
+
160
+ ft_hook = lambda *args: hook_out(*args, feats, "ft")
161
+ ft_handle = ft_model.model.layers[i].mlp.up_proj.register_forward_hook(ft_hook)
162
+
163
+ evaluate(base_model, dataloader)
164
+ evaluate(ft_model, dataloader)
165
+
166
+ base_mat = torch.vstack(feats["base"])
167
+ ft_mat = torch.vstack(feats["ft"])
168
+
169
+ base_mat.to("cuda")
170
+ ft_mat.to("cuda")
171
+
172
+ base_mat = base_mat.view(-1, base_mat.shape[-1]).T
173
+ ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T
174
+
175
+ base_handle.remove()
176
+ ft_handle.remove()
177
+
178
+ perm = match_wmats(base_mat, ft_mat)
179
+
180
+ return perm
181
+
182
+
183
+ def mlp_layers(base_model, ft_model, dataloader, i, j):
184
+ """
185
+ Compare gate and up projections between specific layers of two models.
186
+
187
+ Useful for comparing non-corresponding layers to find functional similarities.
188
+
189
+ Args:
190
+ base_model: Base model to compare
191
+ ft_model: Fine-tuned or target model to compare against the base model
192
+ dataloader: DataLoader providing input data for activation collection
193
+ i: Layer index in the base model
194
+ j: Layer index in the fine-tuned model
195
+
196
+ Returns:
197
+ float: Spearman correlation p-value between gate and up projections
198
+ """
199
+ gate_match = mlp_matching_gate(base_model, ft_model, dataloader, i, j)
200
+ up_match = mlp_matching_up(base_model, ft_model, dataloader, i, j)
201
+
202
+ cor, pvalue = scipy.stats.spearmanr(gate_match.tolist(), up_match.tolist())
203
+
204
+ return pvalue
205
+
206
+
207
+ def statistic_all(model_1, model_2, dataloader):
208
+ """
209
+ Perform comprehensive layer matching between two models.
210
+
211
+ Tests all combinations of layers between the models to identify corresponding
212
+ functional units, regardless of their position in the network architecture.
213
+
214
+ Args:
215
+ model_1: First model to compare
216
+ model_2: Second model to compare
217
+ dataloader: DataLoader providing input data for activation collection
218
+
219
+ Returns:
220
+ None: Prints matching results during execution
221
+ """
222
+ model_1_matched = np.zeros(model_1.config.num_hidden_layers)
223
+ model_2_matched = np.zeros(model_2.config.num_hidden_layers)
224
+
225
+ for i in range(model_1.config.num_hidden_layers):
226
+ for j in range(model_2.config.num_hidden_layers):
227
+ if model_1_matched[i] == 1 or model_2_matched[j] == 1:
228
+ continue
229
+ stat = mlp_layers(model_1, model_2, dataloader, i, j)
230
+ print(i, j, stat)
231
+ if stat < 0.000001:
232
+ model_1_matched[i] = 1
233
+ model_2_matched[j] = 1
234
+ break
model-tracing/tracing/statistics/mc.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..utils.llama.model import avg_model
2
+ from ..utils.olmo.model import avg_model as avg_model_olmo
3
+ from ..utils.evaluate import evaluate
4
+
5
+
6
+ def statistic(base_model, ft_model, tmp_model, dataloader, attn=False, emb=False, alpha=0.5):
7
+ if "olmo" in base_model._get_name().lower():
8
+ avg_model_olmo(base_model, ft_model, tmp_model, attn=attn, emb=emb)
9
+ else:
10
+ avg_model(base_model, ft_model, tmp_model, attn=attn, emb=emb)
11
+
12
+ return sum(evaluate(tmp_model, dataloader))
model-tracing/tracing/statistics/perm_mc_l2.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tracing.perm.permute import permute_model
3
+ from scripts.perm.main import p_value_exact, p_value_approx
4
+
5
+
6
+ def statistic(base_model, ft_model, mc_stat, l2_stat, num_perm, emb_dim=4096, mlp_dim=11008):
7
+
8
+ unperm_stat_mc = mc_stat(base_model, ft_model)
9
+ unperm_stat_l2 = l2_stat(base_model, ft_model)
10
+
11
+ print(unperm_stat_mc, unperm_stat_l2)
12
+
13
+ perm_stats_mc = []
14
+ perm_stats_l2 = []
15
+
16
+ for i in range(num_perm):
17
+ mlp_permutation = torch.randperm(mlp_dim)
18
+ emb_permutation = torch.randperm(emb_dim)
19
+
20
+ permute_model(ft_model, mlp_permutation, emb_permutation)
21
+
22
+ perm_stat_mc = mc_stat(base_model, ft_model)
23
+ perm_stat_l2 = l2_stat(base_model, ft_model)
24
+
25
+ perm_stats_mc.append(perm_stat_mc)
26
+ perm_stats_l2.append(perm_stat_l2)
27
+
28
+ print(i, perm_stat_mc, perm_stat_l2)
29
+
30
+ exact_mc = p_value_exact(unperm_stat_mc, perm_stats_mc.copy())
31
+ approx_mc = p_value_approx(unperm_stat_mc, perm_stats_mc.copy())
32
+
33
+ exact_l2 = p_value_exact(unperm_stat_l2, perm_stats_l2.copy())
34
+ approx_l2 = p_value_approx(unperm_stat_l2, perm_stats_l2.copy())
35
+
36
+ print(exact_mc, approx_mc)
37
+ print(exact_l2, approx_l2)
38
+
39
+ return (
40
+ exact_mc,
41
+ approx_mc,
42
+ exact_l2,
43
+ approx_l2,
44
+ unperm_stat_mc,
45
+ unperm_stat_l2,
46
+ perm_stats_mc,
47
+ perm_stats_l2,
48
+ )
model-tracing/tracing/utils/__init__.py ADDED
File without changes
model-tracing/tracing/utils/evaluate.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import glob
3
+ from typing import List
4
+ from datasets import load_dataset, concatenate_datasets, Dataset
5
+ from accelerate.data_loader import DataLoaderShard
6
+ from transformers import AutoTokenizer
7
+
8
+
9
+ def prepare_hf_dataset(hf_path, block_size, tokenizer, split="test"):
10
+ raw_dataset = load_dataset(hf_path, split=split)
11
+ dataset = raw_dataset.map(
12
+ lambda examples: tokenize_function(examples, tokenizer),
13
+ batched=True,
14
+ remove_columns=["text"],
15
+ ).map(lambda examples: group_texts(examples, block_size), batched=True, batch_size=1)
16
+ dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
17
+ return dataset
18
+
19
+
20
+ def prepare_programming_dataset(
21
+ json_path: str, block_size: int, tokenizer: AutoTokenizer, columns_ignored: List[str]
22
+ ):
23
+ raw_dataset = load_dataset("json", data_files=json_path)
24
+
25
+ dataset = (
26
+ raw_dataset["train"]
27
+ .map(
28
+ lambda examples: tokenize_function(examples, tokenizer),
29
+ batched=True,
30
+ num_proc=4,
31
+ remove_columns=columns_ignored,
32
+ )
33
+ .map(
34
+ lambda examples: group_texts(examples, block_size),
35
+ batched=True,
36
+ batch_size=1,
37
+ num_proc=1,
38
+ )
39
+ )
40
+ dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
41
+ return dataset
42
+
43
+
44
+ def prepare_random_sample_dataset(num_samples, block_size, vocab_size=32000):
45
+ tokens = torch.randint(low=0, high=vocab_size, size=(num_samples, block_size))
46
+ dictionary = {"input_ids": tokens, "attention_mask": torch.ones(tokens.shape), "labels": tokens}
47
+
48
+ dataset = Dataset.from_dict(dictionary)
49
+ dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
50
+ return dataset
51
+
52
+
53
+ def load_m2d2_datasets(
54
+ test_name: str,
55
+ block_size: int,
56
+ tokenizer: AutoTokenizer,
57
+ columns_ignored: List[str],
58
+ ):
59
+ base_path = "/juice4/scr4/nlp/model-tracing/m2d2_s2orc"
60
+ json_dir = f"{base_path}/{test_name}"
61
+ json_files = glob.glob(f"{json_dir}/*.json")
62
+
63
+ if not json_files:
64
+ raise ValueError(f"No JSON files found for test case: {test_name}")
65
+
66
+ datasets = []
67
+ for json_file in json_files:
68
+ dataset = prepare_programming_dataset(json_file, block_size, tokenizer, columns_ignored)
69
+ datasets.append(dataset)
70
+
71
+ combined_dataset = concatenate_datasets(datasets)
72
+ return combined_dataset
73
+
74
+
75
+ def load_dolma_programming_datasets(
76
+ test_name: str,
77
+ block_size: int,
78
+ tokenizer: AutoTokenizer,
79
+ columns_ignored: List[str],
80
+ ):
81
+ base_path = "/juice4/scr4/nlp/model-tracing/dolma_program_languages"
82
+
83
+ json_dir = f"{base_path}/json_files_{test_name}"
84
+ json_files = glob.glob(f"{json_dir}/*.json")
85
+
86
+ datasets = []
87
+ for json_file in json_files:
88
+ dataset = prepare_programming_dataset(json_file, block_size, tokenizer, columns_ignored)
89
+ datasets.append(dataset)
90
+
91
+ combined_dataset = concatenate_datasets(datasets)
92
+ return combined_dataset
93
+
94
+
95
+ def load_generated_datasets(base_model_name, ft_model_name, block_size, tokenizer, columns_ignored):
96
+
97
+ json_file_base = (
98
+ "/juice4/scr4/nlp/model-tracing/generations/"
99
+ + base_model_name.replace("/", "-")
100
+ + "_gentext.json"
101
+ )
102
+ json_file_ft = (
103
+ "/juice4/scr4/nlp/model-tracing/generations/"
104
+ + ft_model_name.replace("/", "-")
105
+ + "_gentext.json"
106
+ )
107
+ dataset_base = prepare_programming_dataset(
108
+ json_file_base, block_size, tokenizer, columns_ignored
109
+ )
110
+ dataset_ft = prepare_programming_dataset(json_file_ft, block_size, tokenizer, columns_ignored)
111
+
112
+ datasets = []
113
+ datasets.append(dataset_base)
114
+ datasets.append(dataset_ft)
115
+
116
+ combined_dataset = concatenate_datasets(datasets)
117
+
118
+ return combined_dataset
119
+
120
+
121
+ def prepare_hf_dataloader(dataset, batch_size: int):
122
+ return DataLoaderShard(dataset, batch_size=batch_size)
123
+
124
+
125
+ def evaluate_70b(model, dataloader, device: str = "cuda:0"):
126
+ losses = []
127
+ with torch.no_grad():
128
+ for batch in dataloader:
129
+ input_ids = batch["input_ids"].to(device)
130
+ attention_mask = batch["attention_mask"].to(device)
131
+ labels = batch["labels"].to(device)
132
+
133
+ outputs = model(
134
+ input_ids=input_ids,
135
+ attention_mask=attention_mask,
136
+ labels=labels,
137
+ )
138
+ loss = outputs.loss
139
+ losses.append(loss.item())
140
+
141
+ return losses
142
+
143
+
144
+ def evaluate(model, dataloader, device: str = "cuda"):
145
+ losses = []
146
+ model.to(device)
147
+ with torch.no_grad():
148
+ for batch in dataloader:
149
+ input_ids = batch["input_ids"].to(device)
150
+ attention_mask = batch["attention_mask"].to(device)
151
+ labels = batch["labels"].to(device)
152
+
153
+ outputs = model(
154
+ input_ids=input_ids,
155
+ attention_mask=attention_mask,
156
+ labels=labels,
157
+ )
158
+ loss = outputs.loss
159
+ losses.append(loss.item())
160
+
161
+ model.to("cpu")
162
+ return losses
163
+
164
+
165
+ def prepare_aya_dataset(subset: str, language: str, block_size: int, tokenizer: AutoTokenizer):
166
+ """
167
+ Prepare the Aya dataset for a specific subset and language.
168
+ """
169
+ raw_dataset = load_dataset("CohereForAI/aya_evaluation_suite", subset)
170
+ filtered_dataset = raw_dataset.filter(lambda example: example["language"] == language)
171
+
172
+ dataset = filtered_dataset.map(
173
+ lambda examples: tokenize_function(examples, tokenizer),
174
+ batched=True,
175
+ remove_columns=filtered_dataset.column_names,
176
+ ).map(lambda examples: group_texts(examples, block_size), batched=True, batch_size=1)
177
+
178
+ dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
179
+ return dataset
180
+
181
+
182
+ def tokenize_aya_function(examples, tokenizer: AutoTokenizer):
183
+ """
184
+ Tokenize Aya dataset examples.
185
+ """
186
+ return tokenizer(examples["inputs"])
187
+
188
+
189
+ def tokenize_function(examples, tokenizer):
190
+ if "text" in examples:
191
+ return tokenizer(examples["text"])
192
+ elif "inputs" in examples:
193
+ return tokenizer(examples["inputs"])
194
+ else:
195
+ raise ValueError("Neither 'text' nor 'inputs' found in examples")
196
+
197
+
198
+ def group_texts(examples, block_size):
199
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
200
+ total_length = len(concatenated_examples["input_ids"])
201
+
202
+ total_length = (total_length // block_size) * block_size
203
+ # Split by chunks of max_len.
204
+ result = {
205
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
206
+ for k, t in concatenated_examples.items()
207
+ }
208
+ result["labels"] = result["input_ids"].copy()
209
+ return result
model-tracing/tracing/utils/llama/matching.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ VOCAB_SIZE = 32000
2
+
3
+ import torch
4
+ from scipy.optimize import linear_sum_assignment as LAP
5
+
6
+ from ..utils import pdists
7
+ from .model import permute_model, permute_transformer_block
8
+
9
+
10
+ def match_wmats(wmat0, wmat1):
11
+ dists = pdists(wmat0, wmat1).type(torch.float64)
12
+ perm = LAP(dists)[1]
13
+
14
+ return perm # wmat1[perm] should match wmat0
15
+
16
+
17
+ def match_mlp(base_model, ft_model, i=0):
18
+ base_wmat = base_model.state_dict()["model.layers." + str(i) + ".mlp.gate_proj.weight"]
19
+ ft_wmat = ft_model.state_dict()["model.layers." + str(i) + ".mlp.gate_proj.weight"]
20
+
21
+ perm = match_wmats(base_wmat, ft_wmat)
22
+
23
+ return perm
24
+
25
+
26
+ def match_emb(base_model, ft_model, i="inp"):
27
+ if i == "inp":
28
+ weight_id = "model.embed_tokens.weight"
29
+ if i == "out":
30
+ weight_id = "lm_head.weight"
31
+
32
+ base_wmat = base_model.state_dict()[weight_id][:VOCAB_SIZE].T
33
+ ft_wmat = ft_model.state_dict()[weight_id][:VOCAB_SIZE].T
34
+
35
+ perm = match_wmats(base_wmat, ft_wmat)
36
+ return perm
37
+
38
+
39
+ def align_model(base_model, ft_model, tmp_model, n_blocks=32):
40
+ emb_perm = match_emb(base_model, ft_model)
41
+ permute_model(ft_model, tmp_model, torch.arange(11008), emb_perm)
42
+
43
+ for i in range(n_blocks):
44
+ mlp_perm = match_mlp(base_model, tmp_model, i=i)
45
+ permute_transformer_block(tmp_model, i, tmp_model, mlp_perm, torch.arange(4096))
model-tracing/tracing/utils/llama/model.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import torch
3
+ from scipy.stats import ortho_group
4
+
5
+
6
+ def permute_model(model, tmp_model, mlp_permutation, emb_permutation, n_blocks=32):
7
+ permute_embedding_layer(model, tmp_model, emb_permutation)
8
+ for i in range(n_blocks):
9
+ permute_transformer_block(tmp_model, i, tmp_model, mlp_permutation, emb_permutation)
10
+ permute_output_layer(tmp_model, tmp_model, emb_permutation)
11
+
12
+
13
+ def permute_transformer_block(model, i, tmp_model, mlp_permutation, emb_permutation):
14
+ weights = model.state_dict()
15
+
16
+ weights["model.layers." + str(i) + ".self_attn.q_proj.weight"] = weights[
17
+ "model.layers." + str(i) + ".self_attn.q_proj.weight"
18
+ ][:, emb_permutation]
19
+ weights["model.layers." + str(i) + ".self_attn.k_proj.weight"] = weights[
20
+ "model.layers." + str(i) + ".self_attn.k_proj.weight"
21
+ ][:, emb_permutation]
22
+ weights["model.layers." + str(i) + ".self_attn.v_proj.weight"] = weights[
23
+ "model.layers." + str(i) + ".self_attn.v_proj.weight"
24
+ ][:, emb_permutation]
25
+ weights["model.layers." + str(i) + ".self_attn.o_proj.weight"] = weights[
26
+ "model.layers." + str(i) + ".self_attn.o_proj.weight"
27
+ ][emb_permutation]
28
+
29
+ weights["model.layers." + str(i) + ".mlp.gate_proj.weight"] = weights[
30
+ "model.layers." + str(i) + ".mlp.gate_proj.weight"
31
+ ][mlp_permutation]
32
+ weights["model.layers." + str(i) + ".mlp.up_proj.weight"] = weights[
33
+ "model.layers." + str(i) + ".mlp.up_proj.weight"
34
+ ][mlp_permutation]
35
+ weights["model.layers." + str(i) + ".mlp.down_proj.weight"] = weights[
36
+ "model.layers." + str(i) + ".mlp.down_proj.weight"
37
+ ][:, mlp_permutation]
38
+
39
+ weights["model.layers." + str(i) + ".mlp.gate_proj.weight"] = weights[
40
+ "model.layers." + str(i) + ".mlp.gate_proj.weight"
41
+ ][:, emb_permutation]
42
+ weights["model.layers." + str(i) + ".mlp.up_proj.weight"] = weights[
43
+ "model.layers." + str(i) + ".mlp.up_proj.weight"
44
+ ][:, emb_permutation]
45
+ weights["model.layers." + str(i) + ".mlp.down_proj.weight"] = weights[
46
+ "model.layers." + str(i) + ".mlp.down_proj.weight"
47
+ ][emb_permutation]
48
+
49
+ weights["model.layers." + str(i) + ".input_layernorm.weight"] = weights[
50
+ "model.layers." + str(i) + ".input_layernorm.weight"
51
+ ][
52
+ emb_permutation
53
+ ] # 1d
54
+ weights["model.layers." + str(i) + ".post_attention_layernorm.weight"] = weights[
55
+ "model.layers." + str(i) + ".post_attention_layernorm.weight"
56
+ ][emb_permutation]
57
+
58
+ tmp_model.load_state_dict(weights)
59
+
60
+
61
+ def permute_embedding_layer(model, tmp_model, emb_permutation):
62
+ weights = model.state_dict()
63
+
64
+ weights["model.embed_tokens.weight"] = weights["model.embed_tokens.weight"][:, emb_permutation]
65
+ tmp_model.load_state_dict(weights)
66
+
67
+
68
+ def permute_output_layer(model, tmp_model, emb_permutation):
69
+ weights = model.state_dict()
70
+
71
+ weights["lm_head.weight"] = weights["lm_head.weight"][:, emb_permutation]
72
+ weights["model.norm.weight"] = weights["model.norm.weight"][emb_permutation]
73
+ tmp_model.load_state_dict(weights)
74
+
75
+
76
+ def permute_mlp_block(model, i, tmp_model, mlp_permutation):
77
+ weights = model.state_dict()
78
+
79
+ weights["model.layers." + str(i) + ".mlp.gate_proj.weight"] = weights[
80
+ "model.layers." + str(i) + ".mlp.gate_proj.weight"
81
+ ][mlp_permutation]
82
+ weights["model.layers." + str(i) + ".mlp.up_proj.weight"] = weights[
83
+ "model.layers." + str(i) + ".mlp.up_proj.weight"
84
+ ][mlp_permutation]
85
+ weights["model.layers." + str(i) + ".mlp.down_proj.weight"] = weights[
86
+ "model.layers." + str(i) + ".mlp.down_proj.weight"
87
+ ][:, mlp_permutation]
88
+
89
+ tmp_model.load_state_dict(weights)
90
+
91
+
92
+ def avg_mlp_block(model0, model1, i, tmp_model, alpha=0.5):
93
+ weights0 = model0.state_dict()
94
+ weights1 = model1.state_dict()
95
+
96
+ weights0["model.layers." + str(i) + ".mlp.gate_proj.weight"] = (
97
+ alpha * weights0["model.layers." + str(i) + ".mlp.gate_proj.weight"]
98
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".mlp.gate_proj.weight"]
99
+ )
100
+ weights0["model.layers." + str(i) + ".mlp.up_proj.weight"] = (
101
+ alpha * weights0["model.layers." + str(i) + ".mlp.up_proj.weight"]
102
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".mlp.up_proj.weight"]
103
+ )
104
+ weights0["model.layers." + str(i) + ".mlp.down_proj.weight"] = (
105
+ alpha * weights0["model.layers." + str(i) + ".mlp.down_proj.weight"]
106
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".mlp.down_proj.weight"]
107
+ )
108
+
109
+ tmp_model.load_state_dict(weights0)
110
+
111
+
112
+ def avg_transformer_block(model0, model1, i, tmp_model, alpha=0.5, attn=True):
113
+ weights0 = model0.state_dict()
114
+ weights1 = model1.state_dict()
115
+
116
+ if attn is True:
117
+ weights0["model.layers." + str(i) + ".self_attn.q_proj.weight"] = (
118
+ alpha * weights0["model.layers." + str(i) + ".self_attn.q_proj.weight"]
119
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".self_attn.q_proj.weight"]
120
+ )
121
+ weights0["model.layers." + str(i) + ".self_attn.k_proj.weight"] = (
122
+ alpha * weights0["model.layers." + str(i) + ".self_attn.k_proj.weight"]
123
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".self_attn.k_proj.weight"]
124
+ )
125
+ weights0["model.layers." + str(i) + ".self_attn.v_proj.weight"] = (
126
+ alpha * weights0["model.layers." + str(i) + ".self_attn.v_proj.weight"]
127
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".self_attn.v_proj.weight"]
128
+ )
129
+ weights0["model.layers." + str(i) + ".self_attn.o_proj.weight"] = (
130
+ alpha * weights0["model.layers." + str(i) + ".self_attn.o_proj.weight"]
131
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".self_attn.o_proj.weight"]
132
+ )
133
+
134
+ weights0["model.layers." + str(i) + ".mlp.gate_proj.weight"] = (
135
+ alpha * weights0["model.layers." + str(i) + ".mlp.gate_proj.weight"]
136
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".mlp.gate_proj.weight"]
137
+ )
138
+ weights0["model.layers." + str(i) + ".mlp.up_proj.weight"] = (
139
+ alpha * weights0["model.layers." + str(i) + ".mlp.up_proj.weight"]
140
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".mlp.up_proj.weight"]
141
+ )
142
+ weights0["model.layers." + str(i) + ".mlp.down_proj.weight"] = (
143
+ alpha * weights0["model.layers." + str(i) + ".mlp.down_proj.weight"]
144
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".mlp.down_proj.weight"]
145
+ )
146
+
147
+ weights0["model.layers." + str(i) + ".input_layernorm.weight"] = (
148
+ alpha * weights0["model.layers." + str(i) + ".input_layernorm.weight"]
149
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".input_layernorm.weight"]
150
+ )
151
+ weights0["model.layers." + str(i) + ".post_attention_layernorm.weight"] = (
152
+ alpha * weights0["model.layers." + str(i) + ".post_attention_layernorm.weight"]
153
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".post_attention_layernorm.weight"]
154
+ )
155
+
156
+ tmp_model.load_state_dict(weights0)
157
+
158
+
159
+ def avg_embedding_layer(model0, model1, tmp_model, alpha=0.5):
160
+ weights0 = model0.state_dict()
161
+ weights1 = model1.state_dict()
162
+
163
+ weights0["model.embed_tokens.weight"] = (
164
+ alpha * weights0["model.embed_tokens.weight"]
165
+ + (1 - alpha) * weights1["model.embed_tokens.weight"]
166
+ )
167
+
168
+ tmp_model.load_state_dict(weights0)
169
+
170
+
171
+ def avg_output_layer(model0, model1, tmp_model, alpha=0.5):
172
+ weights0 = model0.state_dict()
173
+ weights1 = model1.state_dict()
174
+
175
+ weights0["lm_head.weight"] = (
176
+ alpha * weights0["lm_head.weight"] + (1 - alpha) * weights1["lm_head.weight"]
177
+ )
178
+ weights0["model.norm.weight"] = (
179
+ alpha * weights0["model.norm.weight"] + (1 - alpha) * weights1["model.norm.weight"]
180
+ )
181
+
182
+ tmp_model.load_state_dict(weights0)
183
+
184
+
185
+ def avg_model(model0, model1, tmp_model, alpha=0.5, n_blocks=32, attn=True, emb=True):
186
+ model1 = copy.deepcopy(model1)
187
+
188
+ if emb is True:
189
+ avg_embedding_layer(model0, model1, tmp_model, alpha=alpha)
190
+ else:
191
+ tmp_model.load_state_dict(model0.state_dict())
192
+ for i in range(n_blocks):
193
+ avg_transformer_block(tmp_model, model1, i, tmp_model, alpha=alpha, attn=attn)
194
+ if emb is True:
195
+ avg_output_layer(tmp_model, model1, tmp_model, alpha=alpha)
196
+
197
+
198
+ def get_mlp_weights(model, i):
199
+ return model.state_dict()["model.layers." + str(i) + ".mlp.gate_proj.weight"]
200
+
201
+
202
+ def get_emb_weights(model):
203
+ return model.state_dict()["model.embed_tokens.weight"]
204
+
205
+
206
+ def rotate_model(model, num_layers=32, hidden_dim=4096):
207
+
208
+ model.to("cuda")
209
+
210
+ rotation = ortho_group.rvs(dim=hidden_dim)
211
+ rotation = torch.tensor(rotation, dtype=torch.bfloat16).to("cuda")
212
+
213
+ weights = model.state_dict()
214
+ weights_rotated = model.state_dict()
215
+
216
+ weights_rotated["model.embed_tokens.weight"] = weights["model.embed_tokens.weight"] @ rotation
217
+
218
+ for i in range(num_layers):
219
+
220
+ weights_rotated[f"model.layers.{i}.input_layernorm.weight"] = torch.ones(hidden_dim)
221
+ weights_rotated[f"model.layers.{i}.post_attention_layernorm.weight"] = torch.ones(
222
+ hidden_dim
223
+ )
224
+
225
+ weights_rotated[f"model.layers.{i}.self_attn.q_proj.weight"] = (
226
+ weights[f"model.layers.{i}.self_attn.q_proj.weight"]
227
+ @ torch.diag(weights[f"model.layers.{i}.input_layernorm.weight"])
228
+ @ rotation
229
+ )
230
+ weights_rotated[f"model.layers.{i}.self_attn.k_proj.weight"] = (
231
+ weights[f"model.layers.{i}.self_attn.k_proj.weight"]
232
+ @ torch.diag(weights[f"model.layers.{i}.input_layernorm.weight"])
233
+ @ rotation
234
+ )
235
+ weights_rotated[f"model.layers.{i}.self_attn.v_proj.weight"] = (
236
+ weights[f"model.layers.{i}.self_attn.v_proj.weight"]
237
+ @ torch.diag(weights[f"model.layers.{i}.input_layernorm.weight"])
238
+ @ rotation
239
+ )
240
+ weights_rotated[f"model.layers.{i}.self_attn.o_proj.weight"] = (
241
+ rotation.T @ weights[f"model.layers.{i}.self_attn.o_proj.weight"]
242
+ )
243
+
244
+ weights_rotated[f"model.layers.{i}.mlp.gate_proj.weight"] = (
245
+ weights[f"model.layers.{i}.mlp.gate_proj.weight"]
246
+ @ torch.diag(weights[f"model.layers.{i}.post_attention_layernorm.weight"])
247
+ @ rotation
248
+ )
249
+ weights_rotated[f"model.layers.{i}.mlp.up_proj.weight"] = (
250
+ weights[f"model.layers.{i}.mlp.up_proj.weight"]
251
+ @ torch.diag(weights[f"model.layers.{i}.post_attention_layernorm.weight"])
252
+ @ rotation
253
+ )
254
+ weights_rotated[f"model.layers.{i}.mlp.down_proj.weight"] = (
255
+ rotation.T @ weights[f"model.layers.{i}.mlp.down_proj.weight"]
256
+ )
257
+
258
+ weights_rotated["model.norm.weight"] = torch.ones(hidden_dim)
259
+ weights_rotated["lm_head.weight"] = (
260
+ weights["lm_head.weight"] @ torch.diag(weights["model.norm.weight"]) @ rotation
261
+ )
262
+
263
+ model.load_state_dict(weights_rotated)
model-tracing/tracing/utils/olmo/model.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+
4
+ def permute_model(model, tmp_model, mlp_permutation, emb_permutation, n_blocks=32):
5
+ permute_embedding_layer(model, tmp_model, emb_permutation)
6
+ for i in range(n_blocks):
7
+ permute_transformer_block(tmp_model, i, tmp_model, mlp_permutation, emb_permutation)
8
+ permute_output_layer(tmp_model, tmp_model, emb_permutation)
9
+
10
+
11
+ def permute_transformer_block(model, i, tmp_model, mlp_permutation, emb_permutation):
12
+ weights = model.state_dict()
13
+
14
+ weights["model.layers." + str(i) + ".self_attn.q_proj.weight"] = weights[
15
+ "model.layers." + str(i) + ".self_attn.q_proj.weight"
16
+ ][:, emb_permutation]
17
+ weights["model.layers." + str(i) + ".self_attn.k_proj.weight"] = weights[
18
+ "model.layers." + str(i) + ".self_attn.k_proj.weight"
19
+ ][:, emb_permutation]
20
+ weights["model.layers." + str(i) + ".self_attn.v_proj.weight"] = weights[
21
+ "model.layers." + str(i) + ".self_attn.v_proj.weight"
22
+ ][:, emb_permutation]
23
+ weights["model.layers." + str(i) + ".self_attn.o_proj.weight"] = weights[
24
+ "model.layers." + str(i) + ".self_attn.o_proj.weight"
25
+ ][emb_permutation]
26
+
27
+ weights["model.layers." + str(i) + ".mlp.gate_proj.weight"] = weights[
28
+ "model.layers." + str(i) + ".mlp.gate_proj.weight"
29
+ ][mlp_permutation]
30
+ weights["model.layers." + str(i) + ".mlp.up_proj.weight"] = weights[
31
+ "model.layers." + str(i) + ".mlp.up_proj.weight"
32
+ ][mlp_permutation]
33
+ weights["model.layers." + str(i) + ".mlp.down_proj.weight"] = weights[
34
+ "model.layers." + str(i) + ".mlp.down_proj.weight"
35
+ ][:, mlp_permutation]
36
+
37
+ weights["model.layers." + str(i) + ".mlp.gate_proj.weight"] = weights[
38
+ "model.layers." + str(i) + ".mlp.gate_proj.weight"
39
+ ][:, emb_permutation]
40
+ weights["model.layers." + str(i) + ".mlp.up_proj.weight"] = weights[
41
+ "model.layers." + str(i) + ".mlp.up_proj.weight"
42
+ ][:, emb_permutation]
43
+ weights["model.layers." + str(i) + ".mlp.down_proj.weight"] = weights[
44
+ "model.layers." + str(i) + ".mlp.down_proj.weight"
45
+ ][emb_permutation]
46
+
47
+ tmp_model.load_state_dict(weights)
48
+
49
+
50
+ def permute_embedding_layer(model, tmp_model, emb_permutation):
51
+ weights = model.state_dict()
52
+
53
+ weights["model.embed_tokens.weight"] = weights["model.embed_tokens.weight"][:, emb_permutation]
54
+ tmp_model.load_state_dict(weights)
55
+
56
+
57
+ def permute_output_layer(model, tmp_model, emb_permutation):
58
+ weights = model.state_dict()
59
+
60
+ weights["lm_head.weight"] = weights["lm_head.weight"][:, emb_permutation]
61
+ tmp_model.load_state_dict(weights)
62
+
63
+
64
+ def permute_mlp_block(model, i, tmp_model, mlp_permutation):
65
+ weights = model.state_dict()
66
+
67
+ weights["model.layers." + str(i) + ".mlp.gate_proj.weight"] = weights[
68
+ "model.layers." + str(i) + ".mlp.gate_proj.weight"
69
+ ][mlp_permutation]
70
+ weights["model.layers." + str(i) + ".mlp.up_proj.weight"] = weights[
71
+ "model.layers." + str(i) + ".mlp.up_proj.weight"
72
+ ][mlp_permutation]
73
+ weights["model.layers." + str(i) + ".mlp.down_proj.weight"] = weights[
74
+ "model.layers." + str(i) + ".mlp.down_proj.weight"
75
+ ][:, mlp_permutation]
76
+
77
+ tmp_model.load_state_dict(weights)
78
+
79
+
80
+ def avg_mlp_block(model0, model1, i, tmp_model, alpha=0.5):
81
+ weights0 = model0.state_dict()
82
+ weights1 = model1.state_dict()
83
+
84
+ weights0["model.layers." + str(i) + ".mlp.gate_proj.weight"] = (
85
+ alpha * weights0["model.layers." + str(i) + ".mlp.gate_proj.weight"]
86
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".mlp.gate_proj.weight"]
87
+ )
88
+ weights0["model.layers." + str(i) + ".mlp.up_proj.weight"] = (
89
+ alpha * weights0["model.layers." + str(i) + ".mlp.up_proj.weight"]
90
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".mlp.up_proj.weight"]
91
+ )
92
+ weights0["model.layers." + str(i) + ".mlp.down_proj.weight"] = (
93
+ alpha * weights0["model.layers." + str(i) + ".mlp.down_proj.weight"]
94
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".mlp.down_proj.weight"]
95
+ )
96
+
97
+ tmp_model.load_state_dict(weights0)
98
+
99
+
100
+ def avg_transformer_block(model0, model1, i, tmp_model, alpha=0.5, attn=True):
101
+ weights0 = model0.state_dict()
102
+ weights1 = model1.state_dict()
103
+
104
+ if attn is True:
105
+ weights0["model.layers." + str(i) + ".self_attn.q_proj.weight"] = (
106
+ alpha * weights0["model.layers." + str(i) + ".self_attn.q_proj.weight"]
107
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".self_attn.q_proj.weight"]
108
+ )
109
+ weights0["model.layers." + str(i) + ".self_attn.k_proj.weight"] = (
110
+ alpha * weights0["model.layers." + str(i) + ".self_attn.k_proj.weight"]
111
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".self_attn.k_proj.weight"]
112
+ )
113
+ weights0["model.layers." + str(i) + ".self_attn.v_proj.weight"] = (
114
+ alpha * weights0["model.layers." + str(i) + ".self_attn.v_proj.weight"]
115
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".self_attn.v_proj.weight"]
116
+ )
117
+ weights0["model.layers." + str(i) + ".self_attn.o_proj.weight"] = (
118
+ alpha * weights0["model.layers." + str(i) + ".self_attn.o_proj.weight"]
119
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".self_attn.o_proj.weight"]
120
+ )
121
+
122
+ weights0["model.layers." + str(i) + ".mlp.gate_proj.weight"] = (
123
+ alpha * weights0["model.layers." + str(i) + ".mlp.gate_proj.weight"]
124
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".mlp.gate_proj.weight"]
125
+ )
126
+ weights0["model.layers." + str(i) + ".mlp.up_proj.weight"] = (
127
+ alpha * weights0["model.layers." + str(i) + ".mlp.up_proj.weight"]
128
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".mlp.up_proj.weight"]
129
+ )
130
+ weights0["model.layers." + str(i) + ".mlp.down_proj.weight"] = (
131
+ alpha * weights0["model.layers." + str(i) + ".mlp.down_proj.weight"]
132
+ + (1 - alpha) * weights1["model.layers." + str(i) + ".mlp.down_proj.weight"]
133
+ )
134
+
135
+ tmp_model.load_state_dict(weights0)
136
+
137
+
138
+ def avg_embedding_layer(model0, model1, tmp_model, alpha=0.5):
139
+ weights0 = model0.state_dict()
140
+ weights1 = model1.state_dict()
141
+
142
+ weights0["model.embed_tokens.weight"] = (
143
+ alpha * weights0["model.embed_tokens.weight"]
144
+ + (1 - alpha) * weights1["model.embed_tokens.weight"]
145
+ )
146
+
147
+ tmp_model.load_state_dict(weights0)
148
+
149
+
150
+ def avg_output_layer(model0, model1, tmp_model, alpha=0.5):
151
+ weights0 = model0.state_dict()
152
+ weights1 = model1.state_dict()
153
+
154
+ weights0["lm_head.weight"] = (
155
+ alpha * weights0["lm_head.weight"] + (1 - alpha) * weights1["lm_head.weight"]
156
+ )
157
+ weights0["model.norm.weight"] = (
158
+ alpha * weights0["model.norm.weight"] + (1 - alpha) * weights1["model.norm.weight"]
159
+ )
160
+
161
+ tmp_model.load_state_dict(weights0)
162
+
163
+
164
+ def avg_model(model0, model1, tmp_model, alpha=0.5, n_blocks=32, attn=True, emb=True):
165
+ model1 = copy.deepcopy(model1)
166
+
167
+ if emb is True:
168
+ avg_embedding_layer(model0, model1, tmp_model, alpha=alpha)
169
+ else:
170
+ tmp_model.load_state_dict(model0.state_dict())
171
+ for i in range(n_blocks):
172
+ avg_transformer_block(tmp_model, model1, i, tmp_model, alpha=alpha, attn=attn)
173
+ if emb is True:
174
+ avg_output_layer(tmp_model, model1, tmp_model, alpha=alpha)
175
+
176
+
177
+ def get_mlp_weights(model, i):
178
+ return model.state_dict()["model.layers." + str(i) + ".intermediate.dense.weight"]
179
+
180
+
181
+ def get_emb_weights(model):
182
+ return model.state_dict()["model.embed_tokens.weight"]
model-tracing/tracing/utils/plot_metrics.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ from yaml import Loader
3
+
4
+ import matplotlib.pyplot as plt
5
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
6
+ import numpy as np
7
+ import os
8
+ from scipy.stats import chi2
9
+
10
+ base = {
11
+ "lmsys/vicuna-7b-v1.5": 1,
12
+ "codellama/CodeLlama-7b-hf": 1,
13
+ "codellama/CodeLlama-7b-Python-hf": 1,
14
+ "codellama/CodeLlama-7b-Instruct-hf": 1,
15
+ "EleutherAI/llemma_7b": 1,
16
+ "microsoft/Orca-2-7b": 1,
17
+ "oh-yeontaek/llama-2-7B-LoRA-assemble": 1,
18
+ "lvkaokao/llama2-7b-hf-instruction-lora": 1,
19
+ "NousResearch/Nous-Hermes-llama-2-7b": 1,
20
+ "lmsys/vicuna-7b-v1.1": 0,
21
+ "yahma/llama-7b-hf": 0,
22
+ "Salesforce/xgen-7b-4k-base": 2,
23
+ "EleutherAI/llemma_7b_muinstruct_camelmath": 1,
24
+ "AlfredPros/CodeLlama-7b-Instruct-Solidity": 1,
25
+ "meta-llama/Llama-2-7b-hf": 1,
26
+ "LLM360/Amber": 3,
27
+ "LLM360/AmberChat": 3,
28
+ "openlm-research/open_llama_7b": 4,
29
+ "openlm-research/open_llama_7b_v2": 5,
30
+ "ibm-granite/granite-7b-base": 6,
31
+ "ibm-granite/granite-7b-instruct": 6,
32
+ }
33
+
34
+ base_ordered = {
35
+ "yahma/llama-7b-hf": 0,
36
+ "lmsys/vicuna-7b-v1.1": 0,
37
+ "meta-llama/Llama-2-7b-hf": 1,
38
+ "lmsys/vicuna-7b-v1.5": 1,
39
+ "codellama/CodeLlama-7b-hf": 1,
40
+ "codellama/CodeLlama-7b-Python-hf": 1,
41
+ "codellama/CodeLlama-7b-Instruct-hf": 1,
42
+ "AlfredPros/CodeLlama-7b-Instruct-Solidity": 1,
43
+ "EleutherAI/llemma_7b": 1,
44
+ "EleutherAI/llemma_7b_muinstruct_camelmath": 1,
45
+ "microsoft/Orca-2-7b": 1,
46
+ "oh-yeontaek/llama-2-7B-LoRA-assemble": 1,
47
+ "lvkaokao/llama2-7b-hf-instruction-lora": 1,
48
+ "NousResearch/Nous-Hermes-llama-2-7b": 1,
49
+ "Salesforce/xgen-7b-4k-base": 2,
50
+ "LLM360/Amber": 3,
51
+ "LLM360/AmberChat": 3,
52
+ "openlm-research/open_llama_7b": 4,
53
+ "openlm-research/open_llama_7b_v2": 5,
54
+ "ibm-granite/granite-7b-base": 6,
55
+ "ibm-granite/granite-7b-instruct": 6,
56
+ }
57
+
58
+ tree = {
59
+ "yahma/llama-7b-hf": "A---",
60
+ "lmsys/vicuna-7b-v1.1": "AA--",
61
+ "meta-llama/Llama-2-7b-hf": "B---",
62
+ "lmsys/vicuna-7b-v1.5": "BA--",
63
+ "codellama/CodeLlama-7b-hf": "BB--",
64
+ "codellama/CodeLlama-7b-Python-hf": "BBA-",
65
+ "codellama/CodeLlama-7b-Instruct-hf": "BBB-",
66
+ "AlfredPros/CodeLlama-7b-Instruct-Solidity": "BBBA",
67
+ "EleutherAI/llemma_7b": "BBC-",
68
+ "EleutherAI/llemma_7b_muinstruct_camelmath": "BBCA",
69
+ "microsoft/Orca-2-7b": "BC--",
70
+ "oh-yeontaek/llama-2-7B-LoRA-assemble": "BD--",
71
+ "lvkaokao/llama2-7b-hf-instruction-lora": "BE--",
72
+ "NousResearch/Nous-Hermes-llama-2-7b": "BF--",
73
+ "Salesforce/xgen-7b-4k-base": "C---",
74
+ "LLM360/Amber": "D---",
75
+ "LLM360/AmberChat": "DA--",
76
+ "openlm-research/open_llama_7b": "E---",
77
+ "openlm-research/open_llama_7b_v2": "F---",
78
+ "ibm-granite/granite-7b-base": "G---",
79
+ "ibm-granite/granite-7b-instruct": "GA--",
80
+ }
81
+
82
+
83
+ def get_dict_ft(flat_model_path):
84
+ dict_ft = {}
85
+
86
+ model_paths = yaml.load(open(flat_model_path, "r"), Loader=Loader)
87
+
88
+ for i in range(len(model_paths)):
89
+ for j in range(i + 1, len(model_paths)):
90
+ model_a = model_paths[i]
91
+ model_b = model_paths[j]
92
+
93
+ job_id = model_a.replace("/", "-") + "_AND_" + model_b.replace("/", "-")
94
+
95
+ dict_ft[job_id] = base[model_a] == base[model_b]
96
+
97
+ return dict_ft
98
+
99
+
100
+ def get_statistic_from_file(filename):
101
+
102
+ file = open(filename, "r")
103
+
104
+ lines = file.readlines()
105
+ stat = np.nan
106
+
107
+ for line in lines:
108
+ if "Namespace" in line and "non-aligned test stat" in line:
109
+ # dict = json.loads(line)
110
+ # print(dict)
111
+ # print(dict['non-aligned test stat'])
112
+
113
+ start1 = line.find("non-aligned test stat")
114
+ stat = line[line.find(":", start1) : line.find(",", start1)]
115
+ stat = stat.replace(" ", "")
116
+ stat = stat.replace("(", "")
117
+ stat = stat.replace(":", "")
118
+ stat = stat.replace("tensor", "")
119
+ stat = float(stat)
120
+
121
+ return stat
122
+
123
+
124
+ def get_l2_stat_from_file(filename):
125
+ file = open(filename, "r")
126
+
127
+ lines = file.readlines()
128
+ stats = []
129
+
130
+ for line in lines:
131
+ if len(line) >= 4 and line[4] == " ":
132
+ stats.append(line[:4])
133
+
134
+ return float(stats[-1])
135
+
136
+
137
+ def get_layer_statistic_from_file(filename, layer):
138
+ file = open(filename, "r")
139
+
140
+ lines = file.readlines()
141
+ stat = np.nan
142
+
143
+ for line in lines:
144
+ temp = str(layer) + " "
145
+ if line[: len(temp)] == temp:
146
+ stat = line[line.find("0.") :]
147
+ if "e" in line:
148
+ stat = 0
149
+ # print(layer, stat)
150
+ stat = float(stat)
151
+
152
+ return stat
153
+
154
+
155
+ def plot_statistic_scatter(results_path, dict_ft, plot_path):
156
+
157
+ x = []
158
+ y = []
159
+
160
+ dir_list = os.listdir(results_path)
161
+ for file in dir_list:
162
+ models = file[: file.find(".out")]
163
+ if "huggyllama" in models:
164
+ continue
165
+ print(models)
166
+ ft = int(dict_ft[models])
167
+ stat = get_l2_stat_from_file(results_path + "/" + file)
168
+ # stat = get_statistic_from_file(results_path + '/' + file)
169
+ if not np.isnan(stat):
170
+ y.append(ft)
171
+ # x.append(get_statistic_from_file(results_path + '/' + file))
172
+ x.append(get_l2_stat_from_file(results_path + "/" + file))
173
+
174
+ plt.figure(figsize=(10, 1))
175
+
176
+ plt.scatter(x, y, s=8)
177
+
178
+ plt.xlabel("$p$-value")
179
+ plt.ylabel("Fine-tuned")
180
+ plt.ylim(-0.5, 1.5)
181
+ # plt.title(f"{}")
182
+ plot_filename = f"{plot_path}.png"
183
+
184
+ plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
185
+ plt.close()
186
+
187
+
188
+ def plot_statistic_grid(results_path, dict_base, title, plot_path, decimals, log=False):
189
+ models = list(dict_base.keys())
190
+ print(models)
191
+
192
+ data = np.full((len(models), len(models)), np.nan)
193
+
194
+ for i in range(len(models)):
195
+ for j in range(len(models)):
196
+ model_a = models[i]
197
+ model_b = models[j]
198
+
199
+ job_id = model_a.replace("/", "-") + "_AND_" + model_b.replace("/", "-") + ".out"
200
+
201
+ if not os.path.exists(results_path + "/" + job_id):
202
+ continue
203
+ print(job_id)
204
+
205
+ stat = get_statistic_from_file(results_path + "/" + job_id)
206
+ # stat = get_l2_stat_from_file(results_path + '/' + job_id)
207
+
208
+ if log:
209
+ stat = np.log(stat)
210
+
211
+ data[i][j] = np.round(stat, decimals=decimals)
212
+ data[j][i] = data[i][j]
213
+
214
+ fig, ax = plt.subplots()
215
+ fig.set_size_inches(20, 20)
216
+ im = ax.imshow(data, cmap="viridis")
217
+
218
+ _ = make_axes_locatable(ax)
219
+ _ = ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
220
+
221
+ # Show all ticks and label them with the respective list entries
222
+ ax.set_xticks(np.arange(len(models)), labels=models)
223
+ ax.set_yticks(np.arange(len(models)), labels=models)
224
+
225
+ # Rotate the tick labels and set their alignment.
226
+ plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
227
+
228
+ texts = []
229
+ for i in range(len(models)):
230
+ text1 = []
231
+ for j in range(len(models)):
232
+ text1.append("")
233
+ texts.append(text1)
234
+
235
+ # Loop over data dimensions and create text annotations.
236
+ for i in range(len(models)):
237
+ for j in range(len(models)):
238
+ texts[i][j] = str(data[i][j])
239
+ if data[i][j] == 0.0:
240
+ texts[i][j] = "$\\varepsilon$"
241
+ _ = ax.text(j, i, texts[i][j], ha="center", va="center", color="w")
242
+
243
+ ax.set_title(title)
244
+ fig.tight_layout()
245
+ plot_filename = f"{plot_path}.png"
246
+
247
+ plt.savefig(plot_filename, dpi=500, bbox_inches="tight")
248
+ plt.close()
249
+
250
+
251
+ def plot_statistic_scatter_layer(results_path, dict_ft, plot_path, layer):
252
+
253
+ x = []
254
+ y = []
255
+
256
+ dir_list = os.listdir(results_path)
257
+ for file in dir_list:
258
+ models = file[: file.find(".out")]
259
+ if "huggyllama" in models:
260
+ continue
261
+ print(models)
262
+ ft = int(dict_ft[models])
263
+ stat = get_layer_statistic_from_file(results_path + "/" + file, layer)
264
+ if not np.isnan(stat):
265
+ x.append(ft)
266
+ y.append(get_layer_statistic_from_file(results_path + "/" + file, layer))
267
+
268
+ plt.figure(figsize=(8, 6))
269
+
270
+ plt.scatter(x, y, s=2)
271
+
272
+ plt.xlabel("Fine tuned")
273
+ plt.ylabel("Test statistic")
274
+ # plt.title(f"{}")
275
+ plot_filename = f"{plot_path}.png"
276
+
277
+ plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
278
+ plt.close()
279
+
280
+
281
+ def plot_statistic_grid_layer(
282
+ results_path, dict_base, title, plot_path, decimals, layer, log=False
283
+ ):
284
+ models = list(dict_base.keys())
285
+ print(models)
286
+
287
+ data = np.full((len(models), len(models)), np.nan)
288
+
289
+ for i in range(len(models)):
290
+ for j in range(len(models)):
291
+ model_a = models[i]
292
+ model_b = models[j]
293
+
294
+ job_id = model_a.replace("/", "-") + "_AND_" + model_b.replace("/", "-") + ".out"
295
+
296
+ if not os.path.exists(results_path + "/" + job_id):
297
+ continue
298
+
299
+ stat = get_layer_statistic_from_file(results_path + "/" + job_id, layer)
300
+
301
+ if log:
302
+ stat = np.log(stat)
303
+
304
+ data[i][j] = np.round(stat, decimals=decimals)
305
+ data[j][i] = data[i][j]
306
+
307
+ fig, ax = plt.subplots()
308
+ fig.set_size_inches(20, 20)
309
+ im = ax.imshow(data)
310
+
311
+ _ = make_axes_locatable(ax)
312
+ _ = ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
313
+
314
+ # Show all ticks and label them with the respective list entries
315
+ ax.set_xticks(np.arange(len(models)), labels=models)
316
+ ax.set_yticks(np.arange(len(models)), labels=models)
317
+
318
+ # Rotate the tick labels and set their alignment.
319
+ plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
320
+
321
+ # Loop over data dimensions and create text annotations.
322
+ for i in range(len(models)):
323
+ for j in range(len(models)):
324
+ _ = ax.text(j, i, data[i, j], ha="center", va="center", color="w")
325
+
326
+ ax.set_title(title)
327
+ fig.tight_layout()
328
+ plot_filename = f"{plot_path}.png"
329
+
330
+ plt.savefig(plot_filename, dpi=500, bbox_inches="tight")
331
+ plt.close()
332
+
333
+
334
+ def plot_histogram(results_path, dict_ft, plot_path):
335
+ indp = []
336
+ not_indp = []
337
+
338
+ dir_list = os.listdir(results_path)
339
+ for file in dir_list:
340
+ models = file[: file.find(".out")]
341
+ print(models)
342
+ ft = int(dict_ft[models])
343
+ stat = get_statistic_from_file(results_path + "/" + file)
344
+ if not np.isnan(stat):
345
+ if ft:
346
+ not_indp.append(stat)
347
+ else:
348
+ indp.append(stat)
349
+
350
+ plt.figure(figsize=(8, 6))
351
+
352
+ plt.hist(indp, bins=20, range=(0, 1), color="blue")
353
+ plt.hist(not_indp, bins=20, range=(0, 1), color="green")
354
+
355
+ plt.xlabel("Test statistic value")
356
+ plt.ylabel("Count")
357
+ # plt.title(f"{}")
358
+ plot_filename = f"{plot_path}.png"
359
+
360
+ plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
361
+ plt.close()
362
+
363
+
364
+ def fisher(pvalues):
365
+ chi_squared = 0
366
+ num_layers = 0
367
+ for pvalue in pvalues:
368
+ if not np.isnan(pvalue):
369
+ chi_squared -= 2 * np.log(pvalue)
370
+ num_layers += 1
371
+
372
+ return chi2.sf(chi_squared, df=2 * num_layers)
373
+
374
+
375
+ def plot_statistic_scatter_all_layers(results_path, dict_ft, plot_path):
376
+
377
+ x = []
378
+ y = []
379
+ c = []
380
+
381
+ dir_list = os.listdir(results_path)
382
+
383
+ for layer in range(32):
384
+
385
+ for file in dir_list:
386
+ models = file[: file.find(".out")]
387
+ # if("huggyllama" in models): continue
388
+ print(models)
389
+ ft = int(dict_ft[models])
390
+ stat = get_layer_statistic_from_file(results_path + "/" + file, layer)
391
+ if not np.isnan(stat):
392
+ x.append(layer)
393
+ y.append(get_layer_statistic_from_file(results_path + "/" + file, layer))
394
+ if ft:
395
+ c.append("r")
396
+ else:
397
+ c.append("b")
398
+
399
+ for file in dir_list:
400
+ models = file[: file.find(".out")]
401
+ # if("huggyllama" in models): continue
402
+ ft = int(dict_ft[models])
403
+ stat = get_layer_statistic_from_file(results_path + "/" + file, layer)
404
+ if not np.isnan(stat):
405
+ x.append(layer)
406
+ y.append(get_layer_statistic_from_file(results_path + "/" + file, layer))
407
+ if ft:
408
+ c.append("r")
409
+ else:
410
+ c.append("b")
411
+
412
+ plt.figure(figsize=(8, 6))
413
+
414
+ plt.scatter(x, y, s=1.5, c=c)
415
+
416
+ plt.xlabel("Layer")
417
+ plt.ylabel("p-value")
418
+ # plt.title(f"{}")
419
+ plot_filename = f"{plot_path}.png"
420
+
421
+ plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
422
+ plt.close()
423
+
424
+
425
+ def plot_pvalue(results_path, dict_ft, plot_path):
426
+
427
+ pvalues = []
428
+
429
+ dir_list = os.listdir(results_path)
430
+
431
+ for layer in range(32):
432
+
433
+ for file in dir_list:
434
+ models = file[: file.find(".out")]
435
+ if "huggyllama" in models:
436
+ continue
437
+ print(models)
438
+ ft = int(dict_ft[models])
439
+ if ft is True:
440
+ continue
441
+ stat = get_layer_statistic_from_file(results_path + "/" + file, layer)
442
+ if not np.isnan(stat):
443
+ pvalues.append(stat)
444
+ x = np.arange(0, 1, step=0.001)
445
+ y = []
446
+ print(pvalues)
447
+ print(len(pvalues))
448
+
449
+ for i in x:
450
+ counter = 0
451
+ for val in pvalues:
452
+ if val < i:
453
+ counter += 1
454
+ y.append(counter / len(pvalues))
455
+
456
+ plt.figure(figsize=(8, 6))
457
+
458
+ plt.plot(x, y, ".-")
459
+
460
+ # plt.xlabel("Fine tuned")
461
+ # plt.ylabel("Test statistic")
462
+ # plt.title(f"{}")
463
+ # plt.xlim(-10,0)
464
+ # plt.ylim(-10,0)
465
+ plot_filename = f"{plot_path}.png"
466
+
467
+ plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
468
+ plt.close()
469
+
470
+
471
+ if __name__ == "__main__":
472
+ dict_ft = get_dict_ft("/nlp/u/salzhu/model-tracing/config/llama_flat.yaml")
473
+
474
+ # plot_statistic_scatter("/juice4/scr4/nlp/model-tracing/llama_models_runs/perm_mc_l2_wikitext/logs",
475
+ # dict_ft, "test_statistic_plots/l2_pvalue_horizontal")
476
+
477
+ # plot_statistic_grid("/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs",
478
+ # base_ordered, "MLP up/gate matching p-value on permuted model pairs (random inputs for matching)",
479
+ # "/nlp/u/salzhu/test_statistic_tables/mlp_match_rand_rot_perm_lap",
480
+ # 3, log=False)
481
+
482
+ # plot_statistic_grid("/juice4/scr4/nlp/model-tracing/csh_0928_reruns/logs",
483
+ # base_ordered, "",
484
+ # "/nlp/u/salzhu/csh_0929_cols",
485
+ # 3, log=False)
486
+
487
+ # plot_statistic_grid("/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs",
488
+ # base_ordered, "",
489
+ # "/nlp/u/salzhu/robust_0929",
490
+ # 3, log=False)
491
+
492
+ # plot_statistic_grid("/juice4/scr4/nlp/model-tracing/llama_models_runs/perm_mc_l2_wikitext/logs",
493
+ # base_ordered, "",
494
+ # "/nlp/u/salzhu/l2_0927",
495
+ # 3, log=False)
496
+
497
+ # plot_statistic_scatter("/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs", dict_ft,
498
+ # "/nlp/u/salzhu/test_statistic_plots/mlp_sp_final")
499
+
500
+ # plot_statistic_scatter_layer("/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs", dict_ft,
501
+ # "/nlp/u/salzhu/test_statistic_plots/mlp_match_rand_rot_perm_lap_layer31", 31)
502
+
503
+ # plot_statistic_grid_layer("/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs",
504
+ # base_ordered, "MLP up/gate matching p-value on permuted model pairs (random inputs for matching)",
505
+ # "/nlp/u/salzhu/test_statistic_tables/mlp_match_rand_rot_perm_lap_layer31",
506
+ # 3, 31, log=False)
507
+
508
+ # plot_statistic_scatter_all_layers("/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs",
509
+ # dict_ft,
510
+ # "/nlp/u/salzhu/test_statistic_plots/mlp_match_rand_rot_perm_lap_all_layers")
511
+
512
+ plot_pvalue(
513
+ "/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs",
514
+ dict_ft,
515
+ "/nlp/u/salzhu/test_statistic_plots/mlp_sp_final",
516
+ )
517
+
518
+ # plot_histogram("/juice4/scr4/nlp/model-tracing/mlp_match_med_max_layer0/logs",
519
+ # dict_ft, "/nlp/u/salzhu/test_statistic_plots/mlp_med_max_histogram")
520
+
521
+ # checkpoints = {
522
+ # "100M": 1e8,
523
+ # "1B": 1e9,
524
+ # "10B": 1e10,
525
+ # "18B": 1.8e10,
526
+ # }
527
+
528
+ # checkpoints = {
529
+ # "100M": 1e8,
530
+ # "1B": 1e9,
531
+ # "4B": 4e9,
532
+ # "8B": 8e9,
533
+ # "16B": 1.6e10
534
+ # }
535
+
536
+ # checkpoints = {
537
+ # "100M": 1e8,
538
+ # "1B": 1e9,
539
+ # "12B": 1.2e10,
540
+ # "25B": 2.5e10
541
+ # }
542
+
543
+ # plot_statistic_olmo_scatter("/juice4/scr4/nlp/model-tracing/olmo_models_runs/final_checkpoint/csw_robust_cols/logs", checkpoints,
544
+ # "final checkpoint vs. additional training seed 42", "CSW robust",
545
+ # "/nlp/u/salzhu/olmo_plots/final_checkpoint/csw_robust_cols_seed42")
model-tracing/tracing/utils/utils.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import random
5
+ import os
6
+ from scipy.stats import chi2
7
+
8
+
9
+ def manual_seed(seed, fix_cudnn=True):
10
+ random.seed(seed)
11
+ np.random.seed(seed)
12
+ torch.manual_seed(seed)
13
+ torch.cuda.manual_seed_all(seed)
14
+ os.environ["PYTHONHASHSEED"] = str(seed)
15
+ if fix_cudnn:
16
+ torch.backends.cudnn.deterministic = True # noqa
17
+ torch.backends.cudnn.benchmark = False # noqa
18
+
19
+
20
+ def spcor(x, y):
21
+ n = len(x)
22
+ with torch.no_grad():
23
+ r = 1 - torch.sum(6 * torch.square(x - y)) / (n * (n**2 - 1))
24
+
25
+ return r
26
+
27
+
28
+ def pdists(x, y):
29
+ x = x.to("cuda")
30
+ y = y.to("cuda")
31
+
32
+ with torch.no_grad():
33
+ xsum = torch.sum(torch.square(x), axis=-1)
34
+ ysum = torch.sum(torch.square(y), axis=-1)
35
+
36
+ dists = xsum.view(-1, 1) + ysum.view(1, -1) - 2 * x @ y.T
37
+
38
+ return dists.cpu()
39
+
40
+
41
+ def cossim(x, y):
42
+ x = x.to("cuda")
43
+ y = y.to("cuda")
44
+
45
+ with torch.no_grad():
46
+ similarities = (
47
+ x
48
+ @ y.T
49
+ / (
50
+ torch.linalg.norm(x, axis=-1).view(-1, 1)
51
+ * torch.linalg.norm(y, axis=-1).view(1, -1)
52
+ )
53
+ )
54
+
55
+ return similarities.cpu()
56
+
57
+
58
+ def fisher(p):
59
+ count = 0
60
+ chi_2 = 0
61
+ for pvalue in p:
62
+ if not np.isnan(pvalue):
63
+ chi_2 -= 2 * np.log(pvalue)
64
+ count += 1
65
+
66
+ return chi2.sf(chi_2, df=2 * count)
67
+
68
+
69
+ def normalize_mc_midpoint(mid, base, ft):
70
+ slope = ft - base
71
+ mid -= slope * 0.5
72
+ mid -= base
73
+ return mid
74
+
75
+
76
+ def normalize_trace(trace, alphas):
77
+ slope = trace[-1] - trace[0]
78
+ start = trace[0]
79
+ for i in range(len(trace)):
80
+ trace[i] -= slope * alphas[i]
81
+ trace[i] -= start
82
+ return trace
83
+
84
+
85
+ def output_hook(m, inp, op, name, feats):
86
+ feats[name] = op.detach()
87
+
88
+
89
+ def get_submodule(module, submodule_string):
90
+ attributes = submodule_string.split(".")
91
+ for attr in attributes:
92
+ module = getattr(module, attr)
93
+ return module
94
+
95
+
96
+ def plot_trace(losses, alphas, normalize, model_a_name, model_b_name, plot_path):
97
+
98
+ plt.figure(figsize=(8, 6))
99
+ if normalize:
100
+ losses = normalize_trace(losses, alphas)
101
+ plt.plot(alphas, losses, "o-")
102
+
103
+ plt.xlabel("Alpha")
104
+ plt.ylabel("Loss")
105
+ plt.title(f"{model_a_name} (Left) vs {model_b_name} (Right)")
106
+ plot_filename = f"{plot_path}.png"
107
+
108
+ plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
109
+ plt.close()