Spaces:
Runtime error
Runtime error
Ahmed Ahmed
commited on
Commit
·
de071e9
1
Parent(s):
1dd4b6a
Add model-tracing code for p-value computation (without binary files)
Browse files- model-tracing +0 -1
- model-tracing/.pre-commit-config.yaml +29 -0
- model-tracing/README.md +138 -0
- model-tracing/config/dolma_pl.yaml +9 -0
- model-tracing/config/llama3.yaml +5 -0
- model-tracing/config/llama70b.yaml +5 -0
- model-tracing/config/llama7b.yaml +22 -0
- model-tracing/config/llama7b_split.yaml +24 -0
- model-tracing/config/llama7b_tree.yaml +31 -0
- model-tracing/config/m2d2.yaml +34 -0
- model-tracing/experiments/csw_full.py +111 -0
- model-tracing/experiments/faiss/csw_faiss.py +280 -0
- model-tracing/experiments/generalized_match.py +343 -0
- model-tracing/experiments/huref.py +140 -0
- model-tracing/experiments/localized_testing.py +144 -0
- model-tracing/launch.py +87 -0
- model-tracing/main.py +246 -0
- model-tracing/requirements-dev.txt +5 -0
- model-tracing/requirements.txt +17 -0
- model-tracing/results/jsd/model_pairs_jsd.csv +37 -0
- model-tracing/results/l2/model_pairs_l2.csv +37 -0
- model-tracing/results/perm/permutation_l2_updated_midpoint_wikitext_single.csv +21 -0
- model-tracing/results/perm/permutation_loss_midpoint_wikitext_single.csv +21 -0
- model-tracing/results/perm/permutation_norm_loss_midpoint_wikitext_single.csv +21 -0
- model-tracing/scripts/docs/doc_trace.py +204 -0
- model-tracing/scripts/docs/launch.py +58 -0
- model-tracing/scripts/docs/m2d_trace.py +279 -0
- model-tracing/scripts/mode/main.py +231 -0
- model-tracing/scripts/mode/mode_connectivity_metrics.py +173 -0
- model-tracing/scripts/perm/main.py +47 -0
- model-tracing/scripts/robust/pythia.py +83 -0
- model-tracing/tracing/__init__.py +0 -0
- model-tracing/tracing/perm/permute.py +119 -0
- model-tracing/tracing/statistics/__init__.py +0 -0
- model-tracing/tracing/statistics/csh.py +172 -0
- model-tracing/tracing/statistics/csu.py +168 -0
- model-tracing/tracing/statistics/jsd.py +195 -0
- model-tracing/tracing/statistics/l2.py +71 -0
- model-tracing/tracing/statistics/match.py +234 -0
- model-tracing/tracing/statistics/mc.py +12 -0
- model-tracing/tracing/statistics/perm_mc_l2.py +48 -0
- model-tracing/tracing/utils/__init__.py +0 -0
- model-tracing/tracing/utils/evaluate.py +209 -0
- model-tracing/tracing/utils/llama/matching.py +45 -0
- model-tracing/tracing/utils/llama/model.py +263 -0
- model-tracing/tracing/utils/olmo/model.py +182 -0
- model-tracing/tracing/utils/plot_metrics.py +545 -0
- model-tracing/tracing/utils/utils.py +109 -0
model-tracing
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
Subproject commit 9eb3b67655be2a3576348a6d482e69c62f72fc3e
|
|
|
|
model-tracing/.pre-commit-config.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
repos:
|
2 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
3 |
+
rev: v4.5.0
|
4 |
+
hooks:
|
5 |
+
- id: trailing-whitespace
|
6 |
+
- id: end-of-file-fixer
|
7 |
+
- id: check-yaml
|
8 |
+
- id: check-added-large-files
|
9 |
+
- id: check-merge-conflict
|
10 |
+
|
11 |
+
- repo: https://github.com/psf/black
|
12 |
+
rev: 24.1.1
|
13 |
+
hooks:
|
14 |
+
- id: black
|
15 |
+
language_version: python3
|
16 |
+
args: [--line-length=100]
|
17 |
+
|
18 |
+
#- repo: https://github.com/charliermarsh/ruff-pre-commit
|
19 |
+
# rev: 'v0.1.8'
|
20 |
+
# hooks:
|
21 |
+
# - id: ruff
|
22 |
+
# args: [--fix, --exit-non-zero-on-fix, --line-length=100, --ignore=E402,E731,F841,F811,F821]
|
23 |
+
|
24 |
+
- repo: https://github.com/nbQA-dev/nbQA
|
25 |
+
rev: 1.7.1
|
26 |
+
hooks:
|
27 |
+
- id: nbqa-black
|
28 |
+
additional_dependencies: [black==24.1.1]
|
29 |
+
args: [--line-length=100]
|
model-tracing/README.md
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# LLM Model Tracing
|
3 |
+
This repository investigates model tracing in large language models (LLMs).
|
4 |
+
|
5 |
+
Specifically, given a base LLM and a fine-tuned LLM, this code provides functionality to:
|
6 |
+
|
7 |
+
- Permute the weights of one model (either MLP or embedding weights).
|
8 |
+
- Align the weights of the fine-tuned model to the base model using the Hungarian algorithm.
|
9 |
+
- Evaluate the effect of weight permutation and alignment on different statistics:
|
10 |
+
- Mode connectivity
|
11 |
+
- Cosine similarity
|
12 |
+
- Embedding similarity
|
13 |
+
- Evaluate the perplexity of the base and fine-tuned models on a given dataset.
|
14 |
+
|
15 |
+
## Requirements
|
16 |
+
|
17 |
+
Install the necessary packages using:
|
18 |
+
|
19 |
+
```bash
|
20 |
+
pip install -r requirements.txt
|
21 |
+
```
|
22 |
+
|
23 |
+
For development, install the development dependencies:
|
24 |
+
|
25 |
+
```bash
|
26 |
+
pip install -r requirements-dev.txt
|
27 |
+
```
|
28 |
+
|
29 |
+
### Code Formatting with pre-commit
|
30 |
+
|
31 |
+
This repository uses pre-commit hooks to ensure code quality and consistency.
|
32 |
+
|
33 |
+
1. Install pre-commit:
|
34 |
+
|
35 |
+
```bash
|
36 |
+
pip install pre-commit
|
37 |
+
```
|
38 |
+
|
39 |
+
2. Set up the pre-commit hooks:
|
40 |
+
|
41 |
+
```bash
|
42 |
+
pre-commit install
|
43 |
+
```
|
44 |
+
|
45 |
+
3. (Optional) Run pre-commit on all files:
|
46 |
+
|
47 |
+
```bash
|
48 |
+
pre-commit run --all-files
|
49 |
+
```
|
50 |
+
|
51 |
+
Pre-commit will automatically run on staged files when you commit changes, applying:
|
52 |
+
- Black for code formatting
|
53 |
+
- Ruff for linting and fixing common issues (note: the Ruff hook is currently commented out in `.pre-commit-config.yaml`)
|
54 |
+
- nbQA for notebook formatting
|
55 |
+
- Various file checks (trailing whitespace, YAML validity, etc.)
|
56 |
+
|
57 |
+
## Usage
|
58 |
+
|
59 |
+
The repository provides two main scripts:
|
60 |
+
|
61 |
+
- `main.py`: Executes the main experiment pipeline for model tracing.
|
62 |
+
- `launch.py`: Launches multiple experiments in parallel using slurm.
|
63 |
+
|
64 |
+
### `main.py`
|
65 |
+
|
66 |
+
This script performs the following steps:
|
67 |
+
|
68 |
+
1. Loads the base and fine-tuned LLMs.
|
69 |
+
2. Optionally permutes the weights of the fine-tuned model.
|
70 |
+
3. Calculates the selected statistic for the non-aligned models.
|
71 |
+
4. Optionally aligns the weights of the fine-tuned model to the base model.
|
72 |
+
5. Calculates the selected statistic for the aligned models.
|
73 |
+
6. Optionally evaluates the perplexity of the base and fine-tuned models.
|
74 |
+
7. Saves the results to a pickle file.
|
75 |
+
|
76 |
+
The script accepts various command-line arguments:
|
77 |
+
|
78 |
+
- `--base_model_id`: HuggingFace model ID for the base model.
|
79 |
+
- `--ft_model_id`: HuggingFace model ID for the fine-tuned model.
|
80 |
+
- `--permute`: Whether to permute the weights of the fine-tuned model.
|
81 |
+
- `--align`: Whether to align the weights of the fine-tuned model to the base model.
|
82 |
+
- `--dataset_id`: HuggingFace dataset ID for perplexity evaluation.
|
83 |
+
- `--stat`: Statistic to calculate (options: "mode", "cos", "emb").
|
84 |
+
- csu: cosine similarity of weights statistic (on MLP up projection matrices) w/ Spearman correlation
|
85 |
+
- csu_all: csu on all pairs of parameters with equal shape
|
86 |
+
- csh: cosine similarity of MLP activations statistic w/ Spearman correlation
|
87 |
+
- match: unconstrained statistic (match) with permutation matching of MLP activations
|
88 |
+
- match_all: unconstrained statistic (match) on all pairs of MLP block activations
|
89 |
+
- `--attn`: Whether to consider attention weights in the "mode" statistic.
|
90 |
+
- `--emb`: Whether to consider embedding weights in the "mode" statistic.
|
91 |
+
- `--eval`: Whether to evaluate perplexity.
|
92 |
+
- `--save`: Path to save the results pickle file.
|
93 |
+
|
94 |
+
Example usage:
|
95 |
+
|
96 |
+
```bash
|
97 |
+
python main.py --base_model_id meta-llama/Llama-2-7b-hf --ft_model_id lmsys/vicuna-7b-v1.5 --stat csu --save results.p
|
98 |
+
```
|
99 |
+
|
100 |
+
```bash
|
101 |
+
python main.py --base_model_id meta-llama/Llama-2-7b-hf --ft_model_id lmsys/vicuna-7b-v1.5 --permute --align --dataset wikitext --stat match --attn --save results.p
|
102 |
+
```
|
103 |
+
|
104 |
+
### `launch.py`
|
105 |
+
|
106 |
+
This script launches multiple experiments in parallel using slurm. It reads model IDs from a YAML file and runs `main.py` for each pair of base and fine-tuned models. Use the flag --flat all (defaulted) to run on all pairs of models from a YAML (see config/llama7b.yaml); or, --flat split to run on all pairs of a 'base' model with a 'finetuned' model (see config/llama7b_split.yaml); or --flat specified to run on a specified list of pairs of models.
|
107 |
+
|
108 |
+
## Configuration
|
109 |
+
|
110 |
+
The `model-tracing/config/model_list.yaml` file defines the base and fine-tuned models for the experiments.
|
111 |
+
## Data
|
112 |
+
|
113 |
+
The code downloads and uses the Wikitext 103 dataset for perplexity evaluation.
|
114 |
+
|
115 |
+
## Results
|
116 |
+
|
117 |
+
The results of the experiments are saved as pickle files. The files contain dictionaries with the following keys:
|
118 |
+
|
119 |
+
- `args`: Command-line arguments used for the experiment.
|
120 |
+
- `commit`: Git commit hash of the code used for the experiment.
|
121 |
+
- `non-aligned test stat`: Value of the selected statistic for the non-aligned models.
|
122 |
+
- `aligned test stat`: Value of the selected statistic for the aligned models (if `--align` is True).
|
123 |
+
- `base loss`: Perplexity of the base model on the evaluation dataset (if `--eval` is True).
|
124 |
+
- `ft loss`: Perplexity of the fine-tuned model on the evaluation dataset (if `--eval` is True).
|
125 |
+
- `time`: Total execution time of the experiment.
|
126 |
+
|
127 |
+
## Sample commands
|
128 |
+
|
129 |
+
### 70B runs
|
130 |
+
```
|
131 |
+
python main.py --base_model_id meta-llama/Llama-2-70b-hf --ft_model_id meta-llama/Meta-Llama-3-70B --stat csu
|
132 |
+
```
|
133 |
+
|
134 |
+
# Experiments
|
135 |
+
|
136 |
+
Relevant scripts for running additional experiments described in our paper are in this folder. For example, there are experiments on retraining MLP blocks and evaluating our statistics.
|
137 |
+
|
138 |
+
These include `experiments/localized_testing.py` (Section 3.2.1) for fine-grained forensics and layer-matching between two models; `experiments/csw_full.py` (Section 3.2.1) for full parameter-matching between any two model architectures for hybrid models; `experiments/generalized_match.py` (Section 2.3.2, 3.2.3, 3.2.4) for the generalized robust test that involves retraining or distilling GLU MLPs; and `experiments/huref.py` (Appendix F) where we reproduce and break the invariants from a related work (Zeng et al. 2024).
|
model-tracing/config/dolma_pl.yaml
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Configuration for Dolma Programming Languages
|
2 |
+
save_dir: "/nlp/scr/ahmedah/trace_results"
|
3 |
+
dolma_json_dir: "/juice4/scr4/nlp/model-tracing/dolma_program_languages/json_files"
|
4 |
+
datasets:
|
5 |
+
- "cpp"
|
6 |
+
- "python"
|
7 |
+
- "js"
|
8 |
+
model_architectures:
|
9 |
+
- "llama"
|
model-tracing/config/llama3.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
- meta-llama/Meta-Llama-3-8B
|
3 |
+
- meta-llama/Meta-Llama-3-8B-Instruct
|
4 |
+
- meta-llama/Meta-Llama-3.1-8B
|
5 |
+
- meta-llama/Meta-Llama-3.1-8B-Instruct
|
model-tracing/config/llama70b.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
- meta-llama/Llama-2-70b-hf
|
3 |
+
- meta-llama/Llama-3.1-70B
|
4 |
+
- 152334H/miqu-1-70b-sf
|
5 |
+
- Writer/Palmyra-Fin-70B-32K
|
model-tracing/config/llama7b.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
- lmsys/vicuna-7b-v1.5
|
3 |
+
- codellama/CodeLlama-7b-hf
|
4 |
+
- codellama/CodeLlama-7b-Python-hf
|
5 |
+
- codellama/CodeLlama-7b-Instruct-hf
|
6 |
+
- EleutherAI/llemma_7b
|
7 |
+
- microsoft/Orca-2-7b
|
8 |
+
- oh-yeontaek/llama-2-7B-LoRA-assemble
|
9 |
+
- lvkaokao/llama2-7b-hf-instruction-lora
|
10 |
+
- NousResearch/Nous-Hermes-llama-2-7b
|
11 |
+
- lmsys/vicuna-7b-v1.1
|
12 |
+
- yahma/llama-7b-hf
|
13 |
+
- Salesforce/xgen-7b-4k-base
|
14 |
+
- EleutherAI/llemma_7b_muinstruct_camelmath
|
15 |
+
- AlfredPros/CodeLlama-7b-Instruct-Solidity
|
16 |
+
- meta-llama/Llama-2-7b-hf
|
17 |
+
- LLM360/Amber
|
18 |
+
- LLM360/AmberChat
|
19 |
+
- openlm-research/open_llama_7b
|
20 |
+
- openlm-research/open_llama_7b_v2
|
21 |
+
- ibm-granite/granite-7b-base
|
22 |
+
- ibm-granite/granite-7b-instruct
|
model-tracing/config/llama7b_split.yaml
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
ft_models:
|
3 |
+
- lmsys/vicuna-7b-v1.5
|
4 |
+
- codellama/CodeLlama-7b-hf
|
5 |
+
- codellama/CodeLlama-7b-Python-hf
|
6 |
+
- codellama/CodeLlama-7b-Instruct-hf
|
7 |
+
- EleutherAI/llemma_7b
|
8 |
+
- microsoft/Orca-2-7b
|
9 |
+
- oh-yeontaek/llama-2-7B-LoRA-assemble
|
10 |
+
- lvkaokao/llama2-7b-hf-instruction-lora
|
11 |
+
- NousResearch/Nous-Hermes-llama-2-7b
|
12 |
+
- lmsys/vicuna-7b-v1.1
|
13 |
+
- EleutherAI/llemma_7b_muinstruct_camelmath
|
14 |
+
- AlfredPros/CodeLlama-7b-Instruct-Solidity
|
15 |
+
- LLM360/AmberChat
|
16 |
+
- ibm-granite/granite-7b-instruct
|
17 |
+
base_models:
|
18 |
+
- meta-llama/Llama-2-7b-hf
|
19 |
+
- yahma/llama-7b-hf
|
20 |
+
- Salesforce/xgen-7b-4k-base
|
21 |
+
- LLM360/Amber
|
22 |
+
- openlm-research/open_llama_7b
|
23 |
+
- openlm-research/open_llama_7b_v2
|
24 |
+
- ibm-granite/granite-7b-base
|
model-tracing/config/llama7b_tree.yaml
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
# base llama1
|
3 |
+
- yahma/llama-7b-hf:
|
4 |
+
- lmsys/vicuna-7b-v1.1
|
5 |
+
# base llama2
|
6 |
+
- meta-llama/Llama-2-7b-hf:
|
7 |
+
- lmsys/vicuna-7b-v1.5
|
8 |
+
- codellama/CodeLlama-7b-hf:
|
9 |
+
# code llama family
|
10 |
+
- codellama/CodeLlama-7b-Python-hf
|
11 |
+
- codellama/CodeLlama-7b-Instruct-hf:
|
12 |
+
# code llama instruct fine tune
|
13 |
+
- AlfredPros/CodeLlama-7b-Instruct-Solidity
|
14 |
+
- EleutherAI/llemma_7b:
|
15 |
+
# llemma fine tune
|
16 |
+
- EleutherAI/llemma_7b_muinstruct_camelmath
|
17 |
+
- microsoft/Orca-2-7b
|
18 |
+
- oh-yeontaek/llama-2-7B-LoRA-assemble
|
19 |
+
- lvkaokao/llama2-7b-hf-instruction-lora
|
20 |
+
- NousResearch/Nous-Hermes-llama-2-7b
|
21 |
+
|
22 |
+
# independent llama models
|
23 |
+
- Salesforce/xgen-7b-4k-base
|
24 |
+
- LLM360/Amber:
|
25 |
+
- LLM360/AmberChat
|
26 |
+
# two open llama models trained independently
|
27 |
+
- openlm-research/open_llama_7b
|
28 |
+
- openlm-research/open_llama_7b_v2
|
29 |
+
- ibm-granite/granite-7b-base:
|
30 |
+
# granite instruct variant
|
31 |
+
- ibm-granite/granite-7b-instruct
|
model-tracing/config/m2d2.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Configuration for M2D2 datasets
|
2 |
+
save_dir: "/nlp/scr/ahmedah/trace_results"
|
3 |
+
m2d2_json_dir: "/juice4/scr4/nlp/model-tracing/m2d2_s2orc"
|
4 |
+
datasets:
|
5 |
+
# These categories are derived from the ArXiv ontology
|
6 |
+
- "AI" # Artificial Intelligence
|
7 |
+
- "CV" # Computer Vision
|
8 |
+
- "ET" # Emerging Technologies
|
9 |
+
- "IM" # Information Management
|
10 |
+
- "mtrl-sci" # Materials Science
|
11 |
+
- "stat-mech" # Statistical Mechanics
|
12 |
+
- "AR" # Architecture
|
13 |
+
- "CY" # Computers and Society
|
14 |
+
- "IR" # Information Retrieval
|
15 |
+
- "NA" # Numerical Analysis
|
16 |
+
- "str-el" # Strongly Correlated Electrons
|
17 |
+
# Additional ArXiv categories can be added here
|
18 |
+
|
19 |
+
# These categories are derived from the Wikipedia ontology
|
20 |
+
# - "HEAL" # Health and Fitness
|
21 |
+
# - "HIST" # History and Events
|
22 |
+
# - "SOCI" # Society and Social Sciences
|
23 |
+
# - "TECH" # Technology and Applied Sciences
|
24 |
+
# - "CULT" # Culture and the Arts
|
25 |
+
# - "NATU" # Natural and Physical Sciences
|
26 |
+
# - "HUMA" # Human Activities
|
27 |
+
# - "MATH" # Mathematics and Logic
|
28 |
+
# - "GENE" # General Reference
|
29 |
+
# - "RELI" # Religion and Belief Systems
|
30 |
+
# - "PHIL" # Philosophy and Thinking
|
31 |
+
|
32 |
+
model_architectures:
|
33 |
+
- "llama"
|
34 |
+
- "olmo"
|
model-tracing/experiments/csw_full.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Runs statistic Cosine Similarity of Weights on all tensors of two given models, if they match in size.
|
3 |
+
Part of the "Unconstrained Setting" experiments (see StripedHyena experiments).
|
4 |
+
Relevant for hybrid models where only some parameters are shared.
|
5 |
+
|
6 |
+
To run: Set the HuggingFace IDs of the two models (`model_1_id` and `model_2_id`) in `main()`.
|
7 |
+
Prints p-values between tensors that align in dimension.
|
8 |
+
"""
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import scipy
|
12 |
+
from scipy.optimize import linear_sum_assignment as LAP
|
13 |
+
from transformers import AutoModelForCausalLM
|
14 |
+
|
15 |
+
from tracing.utils.utils import cossim, fisher
|
16 |
+
|
17 |
+
import warnings
|
18 |
+
|
19 |
+
warnings.filterwarnings("ignore")
|
20 |
+
|
21 |
+
|
22 |
+
def csw_sp_pair(base_model, ft_model, layer_name_base, layer_name_ft):
    """
    Compute a p-value comparing one named tensor from each of two models.

    The matrix returned by ``cossim`` for the two tensors (cast to float64) is
    passed to a linear-sum-assignment solver in maximizing mode. The resulting
    column assignment is then correlated (Pearson) against the identity
    ordering ``0..n-1``.

    Args:
        base_model: First model; must expose ``state_dict()``
        ft_model: Second model; must expose ``state_dict()``
        layer_name_base: Key of the tensor in the first model's state dict
        layer_name_ft: Key of the tensor in the second model's state dict

    Returns:
        float: p-value reported by ``scipy.stats.pearsonr`` for the assignment
    """
    base_weights = base_model.state_dict()[layer_name_base].type(torch.float64)
    ft_weights = ft_model.state_dict()[layer_name_ft].type(torch.float64)

    # linear_sum_assignment returns (row_ind, col_ind); only the column side is needed.
    _, assignment = LAP(cossim(base_weights, ft_weights), maximize=True)

    identity = torch.arange(len(assignment))
    _, pvalue = scipy.stats.pearsonr(assignment.tolist(), identity.tolist())
    return pvalue
|
47 |
+
|
48 |
+
|
49 |
+
def csw_models(base_model, ft_model):
    """
    Compare every same-shaped pair of tensors across two models.

    Both models are moved to CPU. Each tensor of the first model is paired with
    each tensor of the second model; pairs with identical shapes (1-D tensors
    excluded) are scored with ``csw_sp_pair``, and each pair's names and p-value
    are printed as they are computed.

    Args:
        base_model: First model to compare
        ft_model: Second model to compare

    Returns:
        The result of ``fisher`` applied to all collected p-values, or the
        sentinel ``999`` when no pair of tensors had a compatible shape
    """
    base_model.to("cpu")
    ft_model.to("cpu")

    base_shapes = {name: tensor.shape for name, tensor in base_model.state_dict().items()}
    ft_shapes = {name: tensor.shape for name, tensor in ft_model.state_dict().items()}

    pvalues = []
    for base_name, base_shape in base_shapes.items():
        # 1-D tensors are never compared.
        if len(base_shape) == 1:
            continue
        for ft_name, ft_shape in ft_shapes.items():
            if base_shape != ft_shape:
                continue
            pval = csw_sp_pair(base_model, ft_model, base_name, ft_name)
            print(base_name, ft_name, pval)
            pvalues.append(pval)

    if not pvalues:
        return 999
    return fisher(pvalues)
|
96 |
+
|
97 |
+
|
98 |
+
def main():
    """
    Load the two hard-coded HuggingFace models and print their aggregate result.

    Edit ``model_ids`` below to compare a different pair of models.
    """
    model_ids = ("openai-community/gpt2", "trl-internal-testing/dummy-GPT2-correct-vocab")
    print(*model_ids)

    first, second = (
        AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        for model_id in model_ids
    )

    print(csw_models(first, second))


if __name__ == "__main__":
    main()
|
model-tracing/experiments/faiss/csw_faiss.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import faiss
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from typing import Dict, Tuple, List, NamedTuple
|
5 |
+
import os
|
6 |
+
import pickle
|
7 |
+
import yaml
|
8 |
+
from transformers import AutoModelForCausalLM
|
9 |
+
|
10 |
+
|
11 |
+
class WeightInfo(NamedTuple):
    """
    Immutable record identifying one indexed weight matrix.

    Fields, in positional order:
        model_name: Name or identifier of the model the matrix came from
        param_name: Key of the parameter in that model's state dict
        dimensions: Shape of the weight matrix as ``(d1, d2)``
    """

    # Which model the matrix belongs to.
    model_name: str
    # State-dict key of the parameter.
    param_name: str
    # (rows, cols) of the matrix before flattening.
    dimensions: Tuple[int, int]
|
24 |
+
|
25 |
+
|
26 |
+
class CSWSearch:
    """
    Cosine Similarity of Weights (CSW) search backed by FAISS inner-product indexes.

    Each weight matrix is flattened to a single row vector, L2-normalized and stored
    in a ``faiss.IndexFlatIP`` dedicated to its ``(d1, d2)`` shape, so an
    inner-product search is a cosine-similarity search and only same-shaped
    matrices are ever compared. Every index is persisted to its own file under
    ``index_dir``; only the most recently used index is held in memory.
    """

    def __init__(self):
        # Per shape: metadata for each stored vector, in insertion (= FAISS id) order
        self.metadata: Dict[Tuple[int, int], List[WeightInfo]] = {}
        # Per shape: file name (relative to index_dir) of the on-disk index
        self.index_files: Dict[Tuple[int, int], str] = {}
        # Directory where indices are stored
        self.index_dir: str = "indexes"
        # Currently loaded index as (dim_key, index), or None when nothing is loaded
        self.current_index: Tuple[Tuple[int, int], faiss.Index] = None

    @staticmethod
    def _dim_key(weight_matrix: torch.Tensor) -> Tuple[int, int]:
        """Return the ``(d1, d2)`` shape of a matrix, rejecting anything that is not 2D."""
        shape = tuple(weight_matrix.shape)
        if len(shape) != 2:
            raise ValueError(f"Expected a 2D weight matrix, got shape {shape}")
        return shape

    @staticmethod
    def _flatten_normalized(weight_matrix: torch.Tensor) -> np.ndarray:
        """Return the matrix as a unit-norm float32 row vector of shape ``(1, d1 * d2)``."""
        # Flatten in row-major order; cast to float32 before converting to NumPy.
        flat = np.array(weight_matrix.to(dtype=torch.float32).reshape(1, -1).numpy())
        faiss.normalize_L2(flat)  # in place; makes inner product equal cosine similarity
        return flat

    def add_weight_matrix(
        self, model_name: str, param_name: str, weight_matrix: torch.Tensor
    ) -> None:
        """
        Add a weight matrix to the index that matches its dimensions.

        The updated index is written back to disk immediately.

        Args:
            model_name: Name or identifier of the model
            param_name: Name of the parameter in the model's state dict
            weight_matrix: 2D torch tensor to index

        Returns:
            None

        Raises:
            ValueError: If ``weight_matrix`` is not 2D
        """
        print(f"Adding {model_name} {param_name}")
        dim_key = self._dim_key(weight_matrix)
        d1, d2 = dim_key

        # First time seeing this dimension combination
        if dim_key not in self.index_files:
            self.metadata[dim_key] = []
            self.index_files[dim_key] = f"index_{d1}x{d2}.index"

        index = self._load_index(dim_key)
        index.add(self._flatten_normalized(weight_matrix))

        # Metadata position must stay aligned with the vector's position in the index.
        self.metadata[dim_key].append(WeightInfo(model_name, param_name, (d1, d2)))

        self._save_index(dim_key, index)

    def find_similar_weights(
        self, model_name: str, weight_matrix: torch.Tensor, k: int = 5
    ) -> List[Tuple[WeightInfo, float]]:
        """
        Find the stored weight matrices most similar to the given one.

        Only matrices with the same dimensions are searched, and entries whose
        ``model_name`` equals the supplied ``model_name`` are excluded.

        Args:
            model_name: Name or identifier of the model (used to exclude self-matches)
            weight_matrix: 2D torch tensor to search for
            k: Maximum number of similar matrices to return (default: 5)

        Returns:
            List of at most ``k`` tuples ``(WeightInfo, similarity_score)``

        Raises:
            ValueError: If ``weight_matrix`` is not 2D, or if no weight matrices
                with matching dimensions have been indexed
        """
        dim_key = self._dim_key(weight_matrix)

        if dim_key not in self.index_files:
            raise ValueError(f"No weight matrices found with dimensions {dim_key}")

        index = self._load_index(dim_key)

        # Prepare the query exactly as stored matrices were prepared.
        query = self._flatten_normalized(weight_matrix)

        # Ask for one extra hit so k results remain after dropping a self-match.
        distances, indices = index.search(query, k + 1)

        entries = self.metadata[dim_key]
        results = []
        for idx, sim in zip(indices[0], distances[0]):
            # FAISS pads with -1 when the index holds fewer vectors than requested;
            # without this guard, -1 would silently select the last metadata entry.
            if idx < 0:
                continue
            info = entries[idx]
            if info.model_name != model_name:  # Skip self-match
                results.append((info, float(sim)))

        return results[:k]

    def _load_index(self, dim_key: Tuple[int, int]) -> faiss.Index:
        """
        Load or create the FAISS index for a specific dimension.

        The in-memory index is reused when it already matches ``dim_key``.
        Otherwise the index file is read from ``index_dir``, and a new empty
        ``IndexFlatIP`` is created if that file is missing or unreadable.

        Args:
            dim_key: Tuple of dimensions (d1, d2)

        Returns:
            faiss.Index: The loaded or newly created index
        """
        if self.current_index is not None and self.current_index[0] == dim_key:
            return self.current_index[1]

        d1, d2 = dim_key
        index_path = os.path.join(self.index_dir, self.index_files[dim_key])

        if os.path.exists(index_path):
            try:
                index = faiss.read_index(index_path)
            except RuntimeError:
                print(f"Error reading index file {index_path}. Creating a new index.")
                index = faiss.IndexFlatIP(d1 * d2)
        else:
            print(f"Index file {index_path} not found. Creating a new index.")
            index = faiss.IndexFlatIP(d1 * d2)

        self.current_index = (dim_key, index)
        return index

    def _save_index(self, dim_key: Tuple[int, int], index: faiss.Index):
        """
        Save the index for the given dimensions to disk.

        Args:
            dim_key: Tuple of dimensions (d1, d2)
            index: The FAISS index to save

        Returns:
            None
        """
        # add_weight_matrix saves before save() has ever run, so the directory
        # may not exist yet.
        os.makedirs(self.index_dir, exist_ok=True)
        index_path = os.path.join(self.index_dir, self.index_files[dim_key])
        faiss.write_index(index, index_path)

    def save(self, directory: str):
        """
        Save the search system (metadata and the currently loaded index) to disk.

        Args:
            directory: Directory where indices and metadata will be stored

        Returns:
            None
        """
        # NOTE(review): only the currently loaded index is written to `directory`;
        # indexes saved earlier under a different index_dir are not moved here.
        self.index_dir = directory
        os.makedirs(directory, exist_ok=True)

        if self.current_index is not None:
            self._save_index(self.current_index[0], self.current_index[1])

        metadata_path = os.path.join(directory, "metadata.pkl")
        with open(metadata_path, "wb") as f:
            pickle.dump(self.metadata, f)

        index_files_path = os.path.join(directory, "index_files.pkl")
        with open(index_files_path, "wb") as f:
            pickle.dump(self.index_files, f)

    @classmethod
    def load(cls, directory: str):
        """
        Load a previously saved search system from disk.

        Only metadata and the index-file map are read here; the FAISS indexes
        themselves are loaded on first use.

        Args:
            directory: Directory where indices and metadata are stored

        Returns:
            CSWSearch: The loaded search system
        """
        csw_search = cls()
        csw_search.index_dir = directory

        # SECURITY: pickle.load executes arbitrary code from the file; only load
        # directories produced by a trusted call to save().
        metadata_path = os.path.join(directory, "metadata.pkl")
        with open(metadata_path, "rb") as f:
            csw_search.metadata = pickle.load(f)

        index_files_path = os.path.join(directory, "index_files.pkl")
        with open(index_files_path, "rb") as f:
            csw_search.index_files = pickle.load(f)

        return csw_search
|
218 |
+
|
219 |
+
|
220 |
+
csw = CSWSearch()
|
221 |
+
|
222 |
+
|
223 |
+
def add_params(model_list):
    """
    Index weight matrices from a list of HuggingFace model IDs.

    Loads each model in bfloat16 and adds every 2D tensor in its state dict to
    the global ``csw`` search index for later similarity search.

    Args:
        model_list: List of HuggingFace model IDs to index

    Returns:
        None: Updates the global csw search index
    """
    for model_id in model_list:
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        for param, tensor in model.state_dict().items():
            # Only 2D matrices can be indexed: skip 1D tensors (like bias terms or
            # layer norms) as well as scalars and higher-rank tensors, which would
            # fail the (d1, d2) shape unpacking in add_weight_matrix.
            if len(tensor.shape) != 2:
                continue
            csw.add_weight_matrix(model_id, param_name=param, weight_matrix=tensor)
|
245 |
+
|
246 |
+
|
247 |
+
def get_similar_param(param, k=5):
    """
    Query the module-level index for the matrices closest to ``param``.

    Args:
        param: Weight matrix tensor to search for
        k: Number of similar matrices to return (default: 5)

    Returns:
        Whatever ``csw.find_similar_weights`` returns; per the original
        documentation, a list of (WeightInfo, similarity_score) tuples.
    """
    # "--" is passed as the first positional argument of the search call
    # (presumably a placeholder label for an ad-hoc query; TODO confirm).
    matches = csw.find_similar_weights("--", param, k=k)
    return matches
|
259 |
+
|
260 |
+
|
261 |
+
def main():
    """
    Build, save and query the weight-matrix index.

    Indexes the models listed in ``config/llama7b.yaml``, saves the index
    under ``indexes``, then prints the nearest indexed matrices for the
    layer-0 attention output projection of Llama-2-7B.
    """
    # Model list to add from yaml (context manager so the handle is closed)
    with open("config/llama7b.yaml", "r") as config_file:
        model_list = yaml.safe_load(config_file)
    add_params(model_list)
    csw.save("indexes")

    # Weight matrix to search for
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
    )
    weights = model.state_dict()
    attn_name = "model.layers.0.self_attn.o_proj.weight"

    print(get_similar_param(weights[attn_name]))
|
277 |
+
|
278 |
+
|
279 |
+
if __name__ == "__main__":
|
280 |
+
main()
|
model-tracing/experiments/generalized_match.py
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file contains the code to run the generalized \phi_MATCH test.
|
3 |
+
Given two models, it first retrains the MLPs or other FFN of the models, then runs \phi_MATCH on the distilled MLPs.
|
4 |
+
|
5 |
+
May need to be modified depending on model architecture (this code was used for GPT-architecture).
|
6 |
+
"""
|
7 |
+
|
8 |
+
MLP_SIZE = 3072
|
9 |
+
MLP_SIZE_2 = 3072
|
10 |
+
EMB_SIZE = 768
|
11 |
+
EMB_SIZE_2 = 768
|
12 |
+
N_BLOCKS = 12
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
from transformers import (
|
17 |
+
AutoTokenizer,
|
18 |
+
AutoModelForCausalLM,
|
19 |
+
AutoConfig,
|
20 |
+
GPT2LMHeadModel,
|
21 |
+
OpenAIGPTLMHeadModel,
|
22 |
+
)
|
23 |
+
import scipy
|
24 |
+
from collections import defaultdict
|
25 |
+
|
26 |
+
import numpy as np
|
27 |
+
|
28 |
+
from tracing.utils.evaluate import prepare_hf_dataset, prepare_hf_dataloader, evaluate
|
29 |
+
from tracing.utils.evaluate import (
|
30 |
+
prepare_hf_dataset,
|
31 |
+
prepare_hf_dataloader,
|
32 |
+
)
|
33 |
+
|
34 |
+
from tracing.utils.utils import manual_seed
|
35 |
+
from tracing.utils.llama.matching import match_wmats
|
36 |
+
|
37 |
+
manual_seed(0)
|
38 |
+
|
39 |
+
|
40 |
+
# architecture of MLP trained from scratch can be different from original
|
41 |
+
# eg, uncomment to get a 2-hidden layer MLP (original has just 1 hidden layer)
|
42 |
+
class CustomLlamaMLP(nn.Module):
    """
    Gated feed-forward block in the Llama style (SiLU gate times linear "up" branch).

    Used as the student network when an existing MLP is retrained, so its
    architecture can be varied independently of the original model.

    Args:
        config: Model configuration; only ``config.n_embd`` is read
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        width = config.n_embd
        inner = 4 * width
        self.hidden_size = width
        self.intermediate_size = inner

        # Bias-free projections, created in gate / up / down order.
        self.gate_proj1 = nn.Linear(width, inner, bias=False)
        self.up_proj1 = nn.Linear(width, inner, bias=False)
        self.down_proj1 = nn.Linear(inner, width, bias=False)

        self.act_fn = nn.SiLU()

    def forward(self, x):
        """
        Apply ``down(act(gate(x)) * up(x))``.

        Args:
            x: Input tensor whose last dimension equals ``hidden_size``

        Returns:
            torch.Tensor: Output tensor after MLP transformation
        """
        gated = self.act_fn(self.gate_proj1(x))
        lifted = self.up_proj1(x)
        return self.down_proj1(gated * lifted)
|
77 |
+
|
78 |
+
|
79 |
+
def hook_out(m, inp, op, feats, name):
    """
    Forward hook that records a module's output.

    The output is detached from autograd, copied to the CPU and appended to
    the list stored under ``name``.

    Args:
        m: Module being hooked (unused)
        inp: Input to the module (unused)
        op: Output from the module
        feats: Mapping from key to list of recorded activations
        name: Key to store the activations under
    """
    captured = op.detach().cpu()
    feats[name].append(captured)
|
91 |
+
|
92 |
+
|
93 |
+
def hook_in(m, inp, op, feats, name):
    """
    Forward hook that records a module's first positional input.

    The input is detached from autograd, copied to the CPU and appended to
    the list stored under ``name``.

    Args:
        m: Module being hooked (unused)
        inp: Input to the module (tuple); only element 0 is recorded
        op: Output from the module (unused)
        feats: Mapping from key to list of recorded activations
        name: Key to store the activations under
    """
    first_input = inp[0]
    feats[name].append(first_input.detach().cpu())
|
105 |
+
|
106 |
+
|
107 |
+
def mlp_layers(base_model_gate, base_model_up, ft_model_gate, ft_model_up, dataloader, i, j):
    """
    Compare gate and up projections between separate model components.

    Computes one neuron matching for the pair of "gate" models and one for
    the pair of "up" models, then tests whether the two matchings are
    correlated.

    Args:
        base_model_gate: First model with gate projection weights
        base_model_up: First model with up projection weights
        ft_model_gate: Second model with gate projection weights
        ft_model_up: Second model with up projection weights
        dataloader: DataLoader providing input data for activation collection
        i: Layer index in the first model
        j: Layer index in the second model

    Returns:
        float: p-value of the Pearson correlation between the gate and up matchings
    """
    # Imported explicitly: a bare `import scipy` does not guarantee that the
    # `scipy.stats` submodule has been loaded.
    from scipy.stats import pearsonr

    gate_match = mlp_matching(base_model_gate, ft_model_gate, dataloader, i, j)
    up_match = mlp_matching(base_model_up, ft_model_up, dataloader, i, j)

    print(gate_match, up_match, i, j)

    _, pvalue = pearsonr(gate_match.tolist(), up_match.tolist())
    return pvalue
|
134 |
+
|
135 |
+
|
136 |
+
def mlp_matching(base_model, ft_model, dataloader, i, j):
    """
    Match neurons between models by comparing activations.

    Hooks the ``mlp.c_fc`` output of block ``i`` in ``base_model`` and of
    block ``j`` in ``ft_model``, runs both models over ``dataloader`` and
    passes the collected activations to ``match_wmats``.

    Args:
        base_model: Base model to compare
        ft_model: Target model to compare against the base model
        dataloader: DataLoader providing input data for activation collection
        i: Layer index in the base model
        j: Layer index in the target model

    Returns:
        The result of ``match_wmats`` (documented originally as a tensor of
        permutation indices that match neurons between models)
    """
    feats = defaultdict(list)
    handles = []

    try:
        handles.append(
            base_model.transformer.h[i].mlp.c_fc.register_forward_hook(
                lambda *args: hook_out(*args, feats, "base")
            )
        )
        # The target model is hooked at its own layer index j.
        handles.append(
            ft_model.transformer.h[j].mlp.c_fc.register_forward_hook(
                lambda *args: hook_out(*args, feats, "ft")
            )
        )

        evaluate(base_model, dataloader)
        evaluate(ft_model, dataloader)
    finally:
        # Always detach the hooks, even if evaluation fails.
        for handle in handles:
            handle.remove()

    base_mat = torch.vstack(feats["base"])
    ft_mat = torch.vstack(feats["ft"])

    # Collapse all leading dimensions, then transpose so that each row holds
    # one feature's values across all collected positions. The activations
    # stay on the CPU, where hook_out put them.
    base_mat = base_mat.view(-1, base_mat.shape[-1]).T
    ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T

    return match_wmats(base_mat, ft_mat)
|
178 |
+
|
179 |
+
|
180 |
+
def _distill_mlp(model, config, layer, emb_size, bsz, steps):
    """Train a fresh CustomLlamaMLP to reproduce the randomly mixed output of one MLP block of `model`."""
    mlp = CustomLlamaMLP(config).bfloat16()

    mlp.to("cuda")
    model.transformer.h[layer].mlp.to("cuda")

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(mlp.parameters(), lr=0.001)

    # Random Gaussian matrix scaled by 1/sqrt(dim), applied to the teacher
    # outputs ("rotate outputs", just for kicks / sanity check).
    A = torch.randn(size=(emb_size, emb_size), device="cuda").bfloat16() / np.sqrt(emb_size)

    for t in range(steps):
        # Inputs are random Gaussian vectors, not real activations.
        X_batch = torch.randn(size=(bsz, emb_size), dtype=torch.bfloat16, device="cuda")
        with torch.no_grad():
            Y_batch = model.transformer.h[layer].mlp(X_batch)
            Y_batch = Y_batch @ A.T

        Y_h = mlp(X_batch)

        optimizer.zero_grad()
        loss = criterion(Y_h, Y_batch)

        loss.backward()
        optimizer.step()

        if t % 1000 == 0:
            print(f"train loss: {loss.item()}")

    return mlp


def _gate_up_models(model, mlp, model_id, layer, intermediate_size):
    """Build two copies of `model` whose c_fc weight at `layer` is the distilled gate / up projection."""
    config = AutoConfig.from_pretrained(model_id)
    config.intermediate_size = intermediate_size

    gate_model = GPT2LMHeadModel(config).bfloat16()
    up_model = GPT2LMHeadModel(config).bfloat16()
    model.to("cpu")
    mlp.to("cpu")

    # Two independent state dicts so each copy gets a different c_fc weight.
    gate_weights = model.state_dict()
    up_weights = model.state_dict()

    key = f"transformer.h.{layer}.mlp.c_fc.weight"
    # The distilled Linear weights are stored transposed into the c_fc slot.
    gate_weights[key] = mlp.gate_proj1.weight.T
    up_weights[key] = mlp.up_proj1.weight.T

    gate_model.load_state_dict(gate_weights)
    up_model.load_state_dict(up_weights)
    return gate_model, up_model


def run(i):
    """
    Run the generalized MATCH algorithm to compare models with distilled components.

    This function:
    1. Loads two different GPT-2 models
    2. Trains custom MLPs to replicate the (randomly mixed) outputs of the original MLPs
    3. Creates separate models carrying the distilled gate and up projections
    4. Compares the models using the MATCH algorithm and prints the p-value

    Args:
        i: Layer index to focus on for the analysis (currently overridden
            inside the function, see the note below)

    Returns:
        None: Prints p-value results from the neuron matching
    """
    model_id_1 = "openai-community/gpt2"
    model_id_2 = "manupande21/GPT2_PMC"

    model = AutoModelForCausalLM.from_pretrained(model_id_1, torch_dtype=torch.bfloat16)

    # The evaluation data is tokenized with the first model's tokenizer.
    base_tokenizer = AutoTokenizer.from_pretrained(model_id_1, use_fast=False)
    dataset_wikitext = prepare_hf_dataset("dlwh/wikitext_103_detokenized", 512, base_tokenizer)
    dataloader_wikitext = prepare_hf_dataloader(dataset_wikitext, 1)

    # NOTE(review): the argument is overwritten here, so every call retrains
    # layer 0; confirm whether per-layer runs were intended before removing.
    i = 0  # layer to retrain
    bsz = 5000  # batch size
    T = 10000  # gradient steps
    width_fac = 1.0  # easier to get loss down for wider MLPs when retraining

    # Retraining / distilling first model
    print(f"Training MLP {model_id_1}")
    mlp = _distill_mlp(model, AutoConfig.from_pretrained(model_id_1), i, EMB_SIZE, bsz, T)
    model_retrained_1_gate, model_retrained_1_up = _gate_up_models(
        model, mlp, model_id_1, i, int(width_fac * MLP_SIZE)
    )

    # Retraining / distilling second model
    model = AutoModelForCausalLM.from_pretrained(model_id_2, torch_dtype=torch.bfloat16)
    print(f"Training MLP {model_id_2}")
    mlp = _distill_mlp(model, AutoConfig.from_pretrained(model_id_2), i, EMB_SIZE_2, bsz, T)
    model_retrained_2_gate, model_retrained_2_up = _gate_up_models(
        model, mlp, model_id_2, i, int(width_fac * MLP_SIZE_2)
    )

    # Run MATCH algorithm to compare the models at the retrained layer
    print(
        mlp_layers(
            model_retrained_1_gate,
            model_retrained_1_up,
            model_retrained_2_gate,
            model_retrained_2_up,
            dataloader_wikitext,
            i,
            i,
        )
    )
|
339 |
+
|
340 |
+
|
341 |
+
if __name__ == "__main__":
|
342 |
+
for i in range(0, 10):
|
343 |
+
run(i)
|
model-tracing/experiments/huref.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoModelForCausalLM
|
3 |
+
|
4 |
+
from tracing.utils.llama.model import rotate_model_t5
|
5 |
+
|
6 |
+
# Make bfloat16 the default floating-point dtype for the rest of the script.
torch.set_default_dtype(torch.bfloat16)

# Reference model, loaded in bfloat16 directly onto the GPU.
model_name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda")

# Second model; note that it is a different checkpoint from the reference one.
model_rot_name = "yahma/llama-7b-hf"
model_rotated = AutoModelForCausalLM.from_pretrained(model_rot_name, torch_dtype=torch.bfloat16).to(
    "cuda"
)
# Modifies model_rotated in place (presumably a rotation of its weights;
# TODO confirm against tracing.utils.llama.model.rotate_model_t5).
rotate_model_t5(model_rotated)
|
16 |
+
# rotate_model(model_rotated)
|
17 |
+
|
18 |
+
# Fixing the layer norms to 1's (HUReF works)
|
19 |
+
# """
|
20 |
+
# fix_layer_norm(model)
|
21 |
+
# fix_layer_norm(model_rotated)
|
22 |
+
# """
|
23 |
+
|
24 |
+
# base_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
25 |
+
# dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized",512,base_tokenizer)
|
26 |
+
# dataloader = prepare_hf_dataloader(dataset,1)
|
27 |
+
|
28 |
+
# evaluate_model = evaluate(model, dataloader)
|
29 |
+
# evaluate_rotated = evaluate(model_rotated, dataloader)
|
30 |
+
|
31 |
+
# print("outputs are aligned: ")
|
32 |
+
# print([abs(evaluate_model[i] - evaluate_rotated[i]) <= 0.01 for i in range(len(evaluate_model))])
|
33 |
+
|
34 |
+
# Parameter dictionaries of both models; the invariants computed later in the
# script are products of tensors looked up here.
weights = model.state_dict()
weights_rotated = model_rotated.state_dict()
|
36 |
+
|
37 |
+
# model.to('cuda')
|
38 |
+
# print("invariant 1")
|
39 |
+
# print(weights['model.embed_tokens.weight']@weights['model.layers.0.self_attn.q_proj.weight'].T@weights['model.layers.0.self_attn.k_proj.weight']@weights['model.embed_tokens.weight'].T)
|
40 |
+
# print("invariant 2")
|
41 |
+
# print(weights['model.embed_tokens.weight']@weights['model.layers.0.self_attn.v_proj.weight'].T@weights['model.layers.0.self_attn.o_proj.weight'].T@weights['model.embed_tokens.weight'].T)
|
42 |
+
# print("invariant 3")
|
43 |
+
# print(weights['model.embed_tokens.weight']@weights[f'model.layers.0.mlp.up_proj.weight'].T@weights[f'model.layers.0.mlp.down_proj.weight'].T@weights['model.embed_tokens.weight'].T)
|
44 |
+
# print()
|
45 |
+
# model.to('cpu')
|
46 |
+
|
47 |
+
# model_rotated.to('cuda')
|
48 |
+
# print("rotated")
|
49 |
+
# print("invariant 1")
|
50 |
+
# print(weights_rotated['model.embed_tokens.weight']@weights_rotated['model.layers.0.self_attn.q_proj.weight'].T@weights_rotated['model.layers.0.self_attn.k_proj.weight']@weights_rotated['model.embed_tokens.weight'].T)
|
51 |
+
|
52 |
+
|
53 |
+
# print("invariant 2")
|
54 |
+
# print(weights_rotated['model.embed_tokens.weight']@weights_rotated['model.layers.0.self_attn.v_proj.weight'].T@weights_rotated['model.layers.0.self_attn.o_proj.weight'].T@weights_rotated['model.embed_tokens.weight'].T)
|
55 |
+
# print("invariant 3")
|
56 |
+
# print(weights_rotated['model.embed_tokens.weight']@weights_rotated[f'model.layers.0.mlp.up_proj.weight'].T@weights_rotated[f'model.layers.0.mlp.down_proj.weight'].T@weights_rotated['model.embed_tokens.weight'].T)
|
57 |
+
# print()
|
58 |
+
# model_rotated.to('cpu')
|
59 |
+
|
60 |
+
# Cosine similarity
|
61 |
+
|
62 |
+
def _embedding_invariant(state, first_key, second_key, transpose_second):
    """Return E @ W1.T @ W2 @ E.T (W2 transposed on request), with E the token embedding in `state`."""
    embed = state["model.embed_tokens.weight"]
    second = state[second_key]
    if transpose_second:
        second = second.T
    return embed @ state[first_key].T @ second @ embed.T


def _flat_cosine(a, b):
    """Cosine similarity of two tensors after flattening them to vectors."""
    a = torch.flatten(a)
    b = torch.flatten(b)
    return torch.dot(a, b) / (torch.norm(a) * torch.norm(b))


def _report_invariant(label, first_key, second_key, transpose_second):
    """Print `label`, then the cosine similarity of one invariant between the two models."""
    # NOTE(review): the device moves are kept from the original script. The
    # tensors in `weights` / `weights_rotated` were captured earlier, so
    # confirm whether moving the modules still matters for this computation.
    model.to("cuda")
    print(label)
    invariant = _embedding_invariant(weights, first_key, second_key, transpose_second)
    model.to("cpu")

    model_rotated.to("cuda")
    invariant_rotated = _embedding_invariant(
        weights_rotated, first_key, second_key, transpose_second
    )
    model_rotated.to("cpu")

    # Tensor.to() returns a new tensor rather than moving in place, so the
    # discarded `.to("cuda")` calls of the original had no effect and are dropped.
    print(_flat_cosine(invariant, invariant_rotated))


print("cosine similarity")

# Invariant 1: E @ Wq.T @ Wk @ E.T
_report_invariant(
    "invariant 1",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.k_proj.weight",
    False,
)

# Invariant 2: E @ Wv.T @ Wo.T @ E.T
_report_invariant(
    "invariant 2",
    "model.layers.0.self_attn.v_proj.weight",
    "model.layers.0.self_attn.o_proj.weight",
    True,
)

# Invariant 3: E @ Wup.T @ Wdown.T @ E.T
_report_invariant(
    "invariant 3",
    "model.layers.0.mlp.up_proj.weight",
    "model.layers.0.mlp.down_proj.weight",
    True,
)
|
model-tracing/experiments/localized_testing.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Localized testing experiments for two models.
|
3 |
+
Runs \phi_MATCH on all pairs of GLU MLPs between two models and identifies a match.
|
4 |
+
Also can uncomment code to print the aligned activations.
|
5 |
+
|
6 |
+
To run: set the HuggingFace model IDs (model_1_id and model_2_id) at the top of main().
|
7 |
+
"""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
11 |
+
|
12 |
+
from tracing.utils.evaluate import evaluate
|
13 |
+
from tracing.utils.evaluate import prepare_hf_dataset, prepare_hf_dataloader
|
14 |
+
from tracing.statistics.mlp_sp import hook_out
|
15 |
+
from tracing.utils.evaluate import (
|
16 |
+
prepare_hf_dataset,
|
17 |
+
prepare_hf_dataloader,
|
18 |
+
evaluate,
|
19 |
+
)
|
20 |
+
from tracing.utils.llama.matching import match_wmats
|
21 |
+
from collections import defaultdict
|
22 |
+
|
23 |
+
import scipy
|
24 |
+
import warnings
|
25 |
+
import numpy as np
|
26 |
+
|
27 |
+
warnings.filterwarnings("ignore")
|
28 |
+
|
29 |
+
|
30 |
+
def _mlp_matching(base_model, ft_model, dataloader, i, j, proj_name):
    """Match neurons of one MLP projection (`proj_name`) between layer i of base_model and layer j of ft_model."""
    feats = defaultdict(list)
    handles = []

    try:
        base_proj = getattr(base_model.model.layers[i].mlp, proj_name)
        handles.append(
            base_proj.register_forward_hook(lambda *args: hook_out(*args, feats, "base"))
        )
        ft_proj = getattr(ft_model.model.layers[j].mlp, proj_name)
        handles.append(ft_proj.register_forward_hook(lambda *args: hook_out(*args, feats, "ft")))

        evaluate(base_model, dataloader)
        evaluate(ft_model, dataloader)
    finally:
        # Always detach the hooks, even if evaluation fails.
        for handle in handles:
            handle.remove()

    base_mat = torch.vstack(feats["base"])
    ft_mat = torch.vstack(feats["ft"])

    # Collapse all leading dimensions, then transpose so that each row holds
    # one feature's values across all collected positions. The original
    # `.to("cuda")` calls discarded their result and are dropped.
    base_mat = base_mat.view(-1, base_mat.shape[-1]).T
    ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T

    # If want to print the activations matching for localized testing
    # (See Llama3.2-3B and Llama3.1-8B activation matching experiment):
    #
    # ft_mat = torch.norm(ft_mat,dim=1)
    # sorted = torch.sort(torch.argsort(ft_mat)[:8192])[0]
    # for i in sorted:
    #     print(i.item(),end=" ")

    return match_wmats(base_mat, ft_mat)


def mlp_matching_gate(base_model, ft_model, dataloader, i, j):
    """
    Match neurons between two models using their ``gate_proj`` activations.

    Args:
        base_model: First model
        ft_model: Second model
        dataloader: DataLoader providing input data for activation collection
        i: Layer index in the first model
        j: Layer index in the second model

    Returns:
        The result of ``match_wmats`` on the collected activations
    """
    return _mlp_matching(base_model, ft_model, dataloader, i, j, "gate_proj")


def mlp_matching_up(base_model, ft_model, dataloader, i, j):
    """
    Match neurons between two models using their ``up_proj`` activations.

    Args:
        base_model: First model
        ft_model: Second model
        dataloader: DataLoader providing input data for activation collection
        i: Layer index in the first model
        j: Layer index in the second model

    Returns:
        The result of ``match_wmats`` on the collected activations
    """
    return _mlp_matching(base_model, ft_model, dataloader, i, j, "up_proj")
|
95 |
+
|
96 |
+
|
97 |
+
def mlp_layers(base_model, ft_model, dataloader, i, j):
    """
    Test whether the gate and up matchings of one layer pair agree.

    Computes the gate_proj matching and the up_proj matching between layer
    ``i`` of ``base_model`` and layer ``j`` of ``ft_model``, prints the gate
    matching, and returns the p-value of their Pearson correlation.

    Args:
        base_model: First model
        ft_model: Second model
        dataloader: DataLoader providing input data for activation collection
        i: Layer index in the first model
        j: Layer index in the second model

    Returns:
        float: p-value of the Pearson correlation between the two matchings
    """
    # Imported explicitly: a bare `import scipy` does not guarantee that the
    # `scipy.stats` submodule has been loaded.
    from scipy.stats import pearsonr

    gate_match = mlp_matching_gate(base_model, ft_model, dataloader, i, j)
    up_match = mlp_matching_up(base_model, ft_model, dataloader, i, j)

    for g in gate_match:
        print(g.item(), end=" ")

    _, pvalue = pearsonr(gate_match.tolist(), up_match.tolist())

    return pvalue
|
108 |
+
|
109 |
+
|
110 |
+
def main():
    """
    Search for matching MLP layers between two models.

    For every layer pair (i, j) whose layers are both still unmatched, runs
    ``mlp_layers`` and prints the p-value. A pair with a p-value below the
    threshold marks both layers as matched, and the search moves on to the
    next layer of the first model.

    Each tested pair runs several full passes over the dataset, so the
    complete search is expensive.
    """
    model_1_id = "meta-llama/Llama-2-7b-hf"
    model_2_id = "princeton-nlp/Sheared-LLaMA-2.7B"

    print(model_1_id, model_2_id)

    model_1 = AutoModelForCausalLM.from_pretrained(model_1_id, torch_dtype=torch.bfloat16)
    model_2 = AutoModelForCausalLM.from_pretrained(model_2_id, torch_dtype=torch.bfloat16)

    tokenizer = AutoTokenizer.from_pretrained(model_1_id)

    dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized", 512, tokenizer)
    dataloader = prepare_hf_dataloader(dataset, 1)

    n_layers_1 = model_1.config.num_hidden_layers
    n_layers_2 = model_2.config.num_hidden_layers
    print(n_layers_1, n_layers_2)

    # 0/1 flags recording which layers already have a partner.
    model_1_matched = np.zeros(n_layers_1)
    model_2_matched = np.zeros(n_layers_2)

    for i in range(n_layers_1):
        for j in range(n_layers_2):
            if model_1_matched[i] == 1 or model_2_matched[j] == 1:
                continue
            stat = mlp_layers(model_1, model_2, dataloader, i, j)
            print(i, j, stat)
            if stat < 0.000001:
                model_1_matched[i] = 1
                model_2_matched[j] = 1
                # Layer i has its partner; continue with the next layer.
                break
            # NOTE(review): the original had two further unconditional
            # `break` statements here that ended the search after the first
            # pair; they are removed so that all pairs are examined.
|
141 |
+
|
142 |
+
|
143 |
+
if __name__ == "__main__":
|
144 |
+
main()
|
model-tracing/launch.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import yaml
|
2 |
+
from yaml import Loader
|
3 |
+
|
4 |
+
import subprocess
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
|
8 |
+
parser = argparse.ArgumentParser(description="Experiment Settings")

parser.add_argument("--slurm", default="nlprun -g 1 -d a6000 -r 80G -a model-tracing", type=str)
parser.add_argument("--python", default="python experiment.py", type=str)
parser.add_argument("--models", default="config/llama_flat.yaml", type=str)
parser.add_argument("--save", default="/juice4/scr4/nlp/model-tracing", type=str)
parser.add_argument("--flat", default="all", type=str)

args = parser.parse_args()

# safe_load is sufficient for a plain list / mapping of model IDs and avoids
# constructing arbitrary Python objects from the config file.
with open(args.models, "r") as models_file:
    model_paths = yaml.safe_load(models_file)

# Output directories for scheduler logs and pickled results.
subprocess.run(["mkdir", "-p", f"{args.save}/logs"])
subprocess.run(["mkdir", "-p", f"{args.save}/results"])


def _submit(job_id, base_model_id, ft_model_id):
    """Launch one comparison job through the scheduler prefix given by --slurm."""
    log_path = args.save + "/logs/" + job_id + ".out"
    results_path = args.save + "/results/" + job_id + ".p"

    # The experiment command is passed to the scheduler as one single-quoted
    # argument, so the whole line has to go through the shell.
    job = (
        args.slurm + f" -o {log_path} -n {job_id}"
        f" '{args.python}"
        + f" --base_model_id {base_model_id} --ft_model_id {ft_model_id} --save {results_path}'"
    )
    subprocess.run(job, shell=True)


if args.flat == "all":
    # Config is a flat list: launch every unordered pair.
    for i, model_a in enumerate(model_paths):
        for model_b in model_paths[i + 1 :]:
            job_id = model_a.replace("/", "-") + "_AND_" + model_b.replace("/", "-")
            # NOTE(review): hard-coded filter kept from the original; only
            # pairs whose job id contains "miqu" are launched.
            if "miqu" not in job_id:
                continue
            if job_id[0] == "-":
                job_id = job_id[1:]

            _submit(job_id, model_a, model_b)

elif args.flat == "split":
    # Config has two lists: launch every base model against every ft model.
    base_models = model_paths["base_models"]
    ft_models = model_paths["ft_models"]

    for base_model in base_models:
        for ft_model in ft_models:
            job_id = base_model.replace("/", "-") + "_AND_" + ft_model.replace("/", "-")

            _submit(job_id, base_model, ft_model)

elif args.flat == "specified":
    # Config lists explicit pairs, with separate display names for job ids.
    models_1 = model_paths["models_1"]
    models_2 = model_paths["models_2"]

    names_1 = model_paths["names_1"]
    names_2 = model_paths["names_2"]

    for idx, model_a in enumerate(models_1):
        model_b = models_2[idx]

        name_a = names_1[idx]
        name_b = names_2[idx]

        job_id = name_a.replace("/", "-") + "_AND_" + name_b.replace("/", "-")

        _submit(job_id, model_a, model_b)
|
model-tracing/main.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Length of the random permutation drawn for the MLP dimension when --permute is set.
# NOTE(review): presumably the MLP intermediate width of the 7B Llama models -- confirm.
MLP_SIZE = 11008
|
2 |
+
# Length of the random permutation drawn for the embedding dimension when --permute is set.
# NOTE(review): presumably the hidden size of the 7B Llama models -- confirm.
EMB_SIZE = 4096
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from transformers import (
|
6 |
+
AutoTokenizer,
|
7 |
+
AutoModelForCausalLM,
|
8 |
+
GPTNeoXTokenizerFast,
|
9 |
+
)
|
10 |
+
|
11 |
+
import argparse
|
12 |
+
import pickle
|
13 |
+
import timeit
|
14 |
+
import subprocess
|
15 |
+
import os
|
16 |
+
|
17 |
+
from tracing.utils.llama.model import permute_model, rotate_model
|
18 |
+
from tracing.utils.olmo.model import permute_model as permute_model_olmo
|
19 |
+
from tracing.utils.llama.matching import align_model
|
20 |
+
from tracing.utils.evaluate import (
|
21 |
+
prepare_hf_dataset,
|
22 |
+
prepare_aya_dataset,
|
23 |
+
prepare_hf_dataloader,
|
24 |
+
evaluate,
|
25 |
+
load_dolma_programming_datasets,
|
26 |
+
load_m2d2_datasets,
|
27 |
+
load_generated_datasets,
|
28 |
+
prepare_random_sample_dataset,
|
29 |
+
)
|
30 |
+
from tracing.utils.utils import manual_seed
|
31 |
+
|
32 |
+
from tracing.statistics.mc import statistic as mode_stat
|
33 |
+
from tracing.statistics.l2 import statistic as l2_stat
|
34 |
+
from tracing.statistics.jsd import statistic as jsd_stat
|
35 |
+
from tracing.statistics.csu import statistic as csu_stat
|
36 |
+
from tracing.statistics.csu import statistic_all as csu_all_stat
|
37 |
+
from tracing.statistics.csh import statistic as csh_stat
|
38 |
+
from tracing.statistics.match import statistic as match_stat
|
39 |
+
from tracing.statistics.match import statistic_all as match_all_stat
|
40 |
+
from tracing.statistics.perm_mc_l2 import statistic as perm_mc_l2_stat
|
41 |
+
|
42 |
+
# Command-line interface of the experiment script.
parser = argparse.ArgumentParser(description="Experiment Settings")

# The two checkpoints being compared.
parser.add_argument("--base_model_id", type=str, default="meta-llama/Llama-2-7b-hf")
parser.add_argument("--ft_model_id", type=str, default="lmsys/vicuna-7b-v1.1")

# Optional model transformations, all disabled unless the flag is given.
for switch in ("--permute", "--rotate", "--align"):
    parser.add_argument(switch, action="store_true")

# Evaluation data.
parser.add_argument("--dataset", type=str, default="wikitext")
parser.add_argument("--block_size", type=int, default=512)
parser.add_argument("--batch_size", type=int, default=1)

# Output location, seeding, interpolation weight and Hugging Face token.
parser.add_argument("--save", type=str, default="results.p")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--alpha", type=float, default=0.5)
parser.add_argument("--token", type=str, default="")

# Which statistic to compute and its options.
parser.add_argument("--stat", type=str, default="mode")
for switch in ("--attn", "--emb"):
    parser.add_argument(switch, action="store_true")
parser.add_argument("--num_perm", type=int, default=99)

# Also report the loss of both models on the dataset.
parser.add_argument("--eval", action="store_true")

# Options that only apply to --dataset aya.
parser.add_argument(
    "--aya_subset",
    type=str,
    default="aya_human_annotated",
    help="Subset of Aya dataset",
)
parser.add_argument(
    "--aya_language",
    type=str,
    default="eng",
    help="Language code for Aya dataset",
)

args = parser.parse_args()
|
74 |
+
|
75 |
+
|
76 |
+
from huggingface_hub import login
|
77 |
+
|
78 |
+
# Resolve the Hugging Face access token: an explicit --token wins, otherwise
# fall back to the HF_TOKEN environment variable.
hf_token = args.token or os.environ.get("HF_TOKEN", "")
if not hf_token:
    # Exit with a usage message instead of an unexplained KeyError('HF_TOKEN').
    parser.error("no Hugging Face token: pass --token or set the HF_TOKEN environment variable")
login(token=hf_token)
|
83 |
+
|
84 |
+
# Wall-clock start; the elapsed time is stored in the results at the end of the run.
start = timeit.default_timer()

# Everything the run produces is collected in this dict and pickled at the end.
results = {"args": args}

# Record the code version for reproducibility. This is best-effort: running
# outside a git checkout, or without git installed, must not abort the experiment.
try:
    results["commit"] = (
        subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()
    )
except (subprocess.CalledProcessError, OSError):
    results["commit"] = "unknown"

# fix seed on torch, np and random
manual_seed(args.seed)
|
92 |
+
|
93 |
+
dtype = torch.bfloat16
# Enable low memory loading if "70b" is in model name
low_cpu_mem_usage = "70b" in args.base_model_id.lower()

print(f"Low CPU Mem Usage Flag set to {low_cpu_mem_usage}")
base_model = AutoModelForCausalLM.from_pretrained(
    args.base_model_id, torch_dtype=dtype, low_cpu_mem_usage=low_cpu_mem_usage
)

# The tokenizer loader is chosen from substrings of the base model id.
if "olmo" in args.base_model_id.lower():
    # Any id containing "olmo" is tokenized with this fixed tokenizer repository.
    tokenizer_name = "allenai/OLMo-1.7-7B-hf"
    base_tokenizer = GPTNeoXTokenizerFast.from_pretrained(tokenizer_name, use_fast=False)
else:
    if "Alfred" in args.base_model_id:
        base_tok_source, base_tok_kwargs = args.base_model_id, {}
    elif "Salesforce" in args.base_model_id:
        # NOTE(review): this loads the tokenizer of the *fine-tuned* model id for a
        # Salesforce base model. The ft-tokenizer code mirrors it, so it looks
        # deliberate -- confirm.
        base_tok_source, base_tok_kwargs = args.ft_model_id, {"trust_remote_code": True}
    else:
        base_tok_source, base_tok_kwargs = args.base_model_id, {"use_fast": False}
    base_tokenizer = AutoTokenizer.from_pretrained(base_tok_source, **base_tok_kwargs)
|
113 |
+
|
114 |
+
ft_model = AutoModelForCausalLM.from_pretrained(args.ft_model_id, torch_dtype=dtype)

# The tokenizer loader is chosen from substrings of the fine-tuned model id.
if "olmo" in args.ft_model_id.lower():
    # Any id containing "olmo" is tokenized with this fixed tokenizer repository.
    tokenizer_name = "allenai/OLMo-1.7-7B-hf"
    ft_tokenizer = GPTNeoXTokenizerFast.from_pretrained(tokenizer_name, use_fast=False)
else:
    if "Alfred" in args.ft_model_id:
        ft_tok_source, ft_tok_kwargs = args.ft_model_id, {}
    elif "Salesforce" in args.ft_model_id:
        # NOTE(review): this loads the tokenizer of the *base* model id for a
        # Salesforce fine-tuned model. The base-tokenizer code mirrors it, so it
        # looks deliberate -- confirm.
        ft_tok_source, ft_tok_kwargs = args.base_model_id, {"trust_remote_code": True}
    else:
        ft_tok_source, ft_tok_kwargs = args.ft_model_id, {"use_fast": False}
    ft_tokenizer = AutoTokenizer.from_pretrained(ft_tok_source, **ft_tok_kwargs)

print("base and ft models loaded")
|
128 |
+
|
129 |
+
# Optional random permutation, requested with --permute.
if args.permute:
    # Draw order matters for reproducibility under the fixed seed: MLP first,
    # then embedding.
    mlp_permutation = torch.randperm(MLP_SIZE)
    emb_permutation = torch.randperm(EMB_SIZE)
    # OLMo checkpoints go through their own permutation helper.
    apply_permutation = (
        permute_model_olmo if "olmo" in args.base_model_id.lower() else permute_model
    )
    apply_permutation(base_model, ft_model, mlp_permutation, emb_permutation)
    print("ft model permuted")

# Optional rotation of the fine-tuned model, requested with --rotate.
if args.rotate:
    rotate_model(ft_model)
    print("ft model rotated")
|
141 |
+
|
142 |
+
# Extra model instance passed as the third argument to mode_stat.
# Bind the name unconditionally so statistics that never use it (and the
# closures that capture it) cannot hit a NameError.
tmp_model = None

# skip tmp_model when both checkpoints are 70b models
both_70b = "70b" in args.base_model_id.lower() and "70b" in args.ft_model_id.lower()

# Both "mode" and "perm_mc_l2" call mode_stat, so both need the extra copy;
# previously it was only loaded for "mode".
if not both_70b and args.stat in ("mode", "perm_mc_l2"):
    tmp_model = AutoModelForCausalLM.from_pretrained(args.base_model_id, torch_dtype=dtype)
    # tmp_tokenizer is unused
|
148 |
+
|
149 |
+
# Build the dataset selected by --dataset. Every branch only produces `dataset`;
# the dataloader is created once afterwards.
dataset_name = args.dataset
if dataset_name == "wikitext":
    dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized", args.block_size, base_tokenizer)
elif dataset_name == "aya":
    dataset = prepare_aya_dataset(
        args.aya_subset, args.aya_language, args.block_size, base_tokenizer
    )
elif dataset_name.startswith("dolma_"):
    # Expected form "dolma_<language>": the language is the second "_"-separated field.
    language = dataset_name.split("_")[1]
    if not language:
        raise ValueError("Language is an empty string")
    columns_ignored = [
        "text",
        "added",
        "id",
        "lang",
        "metadata",
        "source",
        "timestamp",
        "subdomain",
    ]
    dataset = load_dolma_programming_datasets(
        language, args.block_size, base_tokenizer, columns_ignored
    )
elif dataset_name.startswith("m2d2_"):
    # Expected form "m2d2_<testcase>": the test case is the second "_"-separated field.
    test_case = dataset_name.split("_")[1]
    if not test_case:
        raise ValueError("Invalid m2d2 dataset format. Use 'm2d2_testcase' (e.g., 'm2d2_AI')")
    columns_ignored = ["text", "added", "id", "source", "subdomain"]
    dataset = load_m2d2_datasets(test_case, args.block_size, base_tokenizer, columns_ignored)
elif dataset_name == "generated":
    columns_ignored = ["text"]
    dataset = load_generated_datasets(
        args.base_model_id, args.ft_model_id, args.block_size, base_tokenizer, columns_ignored
    )
elif dataset_name == "random":
    dataset = prepare_random_sample_dataset(20, args.block_size)
else:
    raise ValueError(f"Unknown dataset: {args.dataset}")

dataloader = prepare_hf_dataloader(dataset, args.batch_size)

print("dataset loaded")
|
196 |
+
|
197 |
+
# Bind `test_stat` to a two-argument callable (base_model, ft_model) for the
# statistic chosen with --stat. The branches are mutually exclusive, and an
# unrecognised name is rejected here instead of surfacing later as a NameError
# on `test_stat`.
if args.stat == "mode":
    test_stat = lambda base_model, ft_model: mode_stat(
        base_model, ft_model, tmp_model, dataloader, args.attn, args.emb, args.alpha
    )
    results["alpha"] = args.alpha
elif args.stat == "l2":
    test_stat = lambda base_model, ft_model: l2_stat(base_model, ft_model)
elif args.stat == "jsd":
    test_stat = lambda base_model, ft_model: jsd_stat(base_model, ft_model, dataloader)
elif args.stat == "csu":
    test_stat = lambda base_model, ft_model: csu_stat(base_model, ft_model)
elif args.stat == "csu_all":
    test_stat = lambda base_model, ft_model: csu_all_stat(base_model, ft_model)
elif args.stat == "csh_sp":
    test_stat = lambda base_model, ft_model: csh_stat(base_model, ft_model, dataloader)
elif args.stat == "match":
    test_stat = lambda base_model, ft_model: match_stat(base_model, ft_model, dataloader)
elif args.stat == "match_all":
    test_stat = lambda base_model, ft_model: match_all_stat(base_model, ft_model, dataloader)
elif args.stat == "perm_mc_l2":
    # perm_mc_l2_stat receives the two underlying statistics as callables.
    # Note that mode_stat is called here without args.alpha, i.e. with its default.
    mc = lambda base_model, ft_model: mode_stat(
        base_model, ft_model, tmp_model, dataloader, args.attn, args.emb
    )
    l2 = lambda base_model, ft_model: l2_stat(base_model, ft_model)
    test_stat = lambda base_model, ft_model: perm_mc_l2_stat(
        base_model, ft_model, mc, l2, args.num_perm
    )
else:
    raise ValueError(f"Unknown stat: {args.stat}")
|
227 |
+
|
228 |
+
# Optionally record the summed evaluation loss of each model (--eval).
if args.eval:
    for loss_key, model in (("base loss", base_model), ("ft loss", ft_model)):
        results[loss_key] = sum(evaluate(model, dataloader))
    print("losses evaluated")

# Statistic on the models as loaded (after any permutation/rotation above).
results["non-aligned test stat"] = test_stat(base_model, ft_model)

print("non-aligned stat computed")

# Optionally align first and recompute the same statistic (--align).
if args.align:
    # ft_model is passed twice, exactly as in the original call.
    align_model(base_model, ft_model, ft_model)
    results["aligned test stat"] = test_stat(base_model, ft_model)
    print("aligned stat computed")
|
241 |
+
|
242 |
+
# Record the total wall-clock time of the run.
end = timeit.default_timer()
results["time"] = end - start

print(results)

# Create the target directory if needed so a finished run is not lost at the
# very last step because the output folder does not exist yet.
save_dir = os.path.dirname(args.save)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

# Use a context manager so the file is flushed and closed deterministically.
with open(args.save, "wb") as results_file:
    pickle.dump(results, results_file)
|
model-tracing/requirements-dev.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
black==24.1.1
|
2 |
+
ruff==0.1.8
|
3 |
+
pre-commit==3.5.0
|
4 |
+
nbqa==1.7.1
|
5 |
+
ipykernel==6.29.0
|
model-tracing/requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.31.0
|
2 |
+
data_processing==0.1.4
|
3 |
+
datasets==2.18.0
|
4 |
+
huggingface-hub==0.27.1
|
5 |
+
matplotlib==3.8.4
|
6 |
+
numpy==1.26.4
|
7 |
+
pandas==2.2.2
|
8 |
+
PyYAML==6.0.1
|
9 |
+
PyYAML==6.0.1
|
10 |
+
scipy==1.13.1
|
11 |
+
torch==2.2.2
|
12 |
+
tqdm==4.66.2
|
13 |
+
transformers==4.40.0
|
14 |
+
sentencepiece==0.2.0
|
15 |
+
protobuf==5.27.1
|
16 |
+
zstandard==0.22.0
|
17 |
+
ipdb==0.13.13
|
model-tracing/results/jsd/model_pairs_jsd.csv
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Model Pair,jsd_wikitext_batch, jsd_test_input
|
2 |
+
meta-llama/Llama-2-7b-hf vs codellama/CodeLlama-7b-hf,1.9561830413294956e-06,3.245754214731278e-06
|
3 |
+
meta-llama/Llama-2-7b-hf vs openlm-research/open_llama_7b,1.729454925225582e-05,2.0375082385726273e-05
|
4 |
+
meta-llama/Llama-2-7b-hf vs huggyllama/llama-7b,2.1560920231422642e-06,1.0297326298314147e-06
|
5 |
+
meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.5,1.5276991689461283e-06,4.193487256998196e-06
|
6 |
+
meta-llama/Llama-2-7b-hf vs EleutherAI/llemma_7b,2.434112047922099e-06,4.357912075647619e-06
|
7 |
+
meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.1,3.127033551209024e-06,3.4884374144894537e-06
|
8 |
+
meta-llama/Llama-2-7b-hf vs microsoft/Orca-2-7b,2.6865156996791484e-06,4.245463969709817e-06
|
9 |
+
meta-llama/Llama-2-7b-hf vs LLM360/Amber,1.8514610928832553e-06,3.480401119304588e-06
|
10 |
+
codellama/CodeLlama-7b-hf vs openlm-research/open_llama_7b,1.7027690773829818e-05,1.983477886824403e-05
|
11 |
+
codellama/CodeLlama-7b-hf vs huggyllama/llama-7b,2.795685531964409e-06,3.3317055567749776e-06
|
12 |
+
codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.5,3.2359234864998143e-06,4.885899215878453e-06
|
13 |
+
codellama/CodeLlama-7b-hf vs EleutherAI/llemma_7b,1.8831561874321778e-06,4.028150215162896e-06
|
14 |
+
codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.1,3.955687134293839e-06,4.28245175498887e-06
|
15 |
+
codellama/CodeLlama-7b-hf vs microsoft/Orca-2-7b,3.917118647223106e-06,5.343406428437447e-06
|
16 |
+
codellama/CodeLlama-7b-hf vs LLM360/Amber,2.043052063527284e-06,3.20175354318053e-06
|
17 |
+
openlm-research/open_llama_7b vs huggyllama/llama-7b,1.7465732526034117e-05,2.0448682334972546e-05
|
18 |
+
openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.5,1.7609712813282385e-05,2.049213435384445e-05
|
19 |
+
openlm-research/open_llama_7b vs EleutherAI/llemma_7b,1.706841794657521e-05,2.0151202988927253e-05
|
20 |
+
openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.1,1.7788930563256145e-05,2.0416322513483465e-05
|
21 |
+
openlm-research/open_llama_7b vs microsoft/Orca-2-7b,1.767633148119785e-05,2.018710074480623e-05
|
22 |
+
openlm-research/open_llama_7b vs LLM360/Amber,1.7064834537450224e-05,2.0156170648988336e-05
|
23 |
+
huggyllama/llama-7b vs lmsys/vicuna-7b-v1.5,3.4697009141382296e-06,4.354528755357023e-06
|
24 |
+
huggyllama/llama-7b vs EleutherAI/llemma_7b,3.1778936318005435e-06,4.261534741090145e-06
|
25 |
+
huggyllama/llama-7b vs lmsys/vicuna-7b-v1.1,1.5405695421577548e-06,2.4467876755807083e-06
|
26 |
+
huggyllama/llama-7b vs microsoft/Orca-2-7b,4.071262083016336e-06,4.621022526407614e-06
|
27 |
+
huggyllama/llama-7b vs LLM360/Amber,2.3568713913846295e-06,3.4077993404935114e-06
|
28 |
+
lmsys/vicuna-7b-v1.5 vs EleutherAI/llemma_7b,3.499165813991567e-06,5.325471647665836e-06
|
29 |
+
lmsys/vicuna-7b-v1.5 vs lmsys/vicuna-7b-v1.1,3.2389764328399906e-06,3.7683282698708354e-06
|
30 |
+
lmsys/vicuna-7b-v1.5 vs microsoft/Orca-2-7b,2.189671249652747e-06,3.5288214803586015e-06
|
31 |
+
lmsys/vicuna-7b-v1.5 vs LLM360/Amber,2.9944339985377155e-06,4.550915946310852e-06
|
32 |
+
EleutherAI/llemma_7b vs lmsys/vicuna-7b-v1.1,4.045330570079386e-06,5.489064733410487e-06
|
33 |
+
EleutherAI/llemma_7b vs microsoft/Orca-2-7b,3.952619408664759e-06,6.314553502306808e-06
|
34 |
+
EleutherAI/llemma_7b vs LLM360/Amber,2.333970769541338e-06,3.391487553017214e-06
|
35 |
+
lmsys/vicuna-7b-v1.1 vs microsoft/Orca-2-7b,3.713232445079484e-06,4.426544364832807e-06
|
36 |
+
lmsys/vicuna-7b-v1.1 vs LLM360/Amber,3.438890416873619e-06,4.203274784231326e-06
|
37 |
+
microsoft/Orca-2-7b vs LLM360/Amber,3.5631983337225392e-06,5.076844900031574e-06
|
model-tracing/results/l2/model_pairs_l2.csv
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Model Pair,l2
|
2 |
+
meta-llama/Llama-2-7b-hf vs codellama/CodeLlama-7b-hf,93.18718912197232
|
3 |
+
meta-llama/Llama-2-7b-hf vs openlm-research/open_llama_7b,122.27028296821305
|
4 |
+
meta-llama/Llama-2-7b-hf vs huggyllama/llama-7b,122.27950493986255
|
5 |
+
meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.5,4.1913064124248285
|
6 |
+
meta-llama/Llama-2-7b-hf vs EleutherAI/llemma_7b,88.74339722102076
|
7 |
+
meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.1,122.32382946735395
|
8 |
+
meta-llama/Llama-2-7b-hf vs microsoft/Orca-2-7b,5.711168968966263
|
9 |
+
meta-llama/Llama-2-7b-hf vs LLM360/Amber,148.0270618556701
|
10 |
+
codellama/CodeLlama-7b-hf vs openlm-research/open_llama_7b,140.50532547577853
|
11 |
+
codellama/CodeLlama-7b-hf vs huggyllama/llama-7b,140.80544442041523
|
12 |
+
codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.5,93.25281141868513
|
13 |
+
codellama/CodeLlama-7b-hf vs EleutherAI/llemma_7b,49.639849790592784
|
14 |
+
codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.1,140.8399383650519
|
15 |
+
codellama/CodeLlama-7b-hf vs microsoft/Orca-2-7b,93.11375432525952
|
16 |
+
codellama/CodeLlama-7b-hf vs LLM360/Amber,163.1440311418685
|
17 |
+
openlm-research/open_llama_7b vs huggyllama/llama-7b,131.92685513316152
|
18 |
+
openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.5,122.38108086340206
|
19 |
+
openlm-research/open_llama_7b vs EleutherAI/llemma_7b,133.5973724048443
|
20 |
+
openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.1,131.96828017611685
|
21 |
+
openlm-research/open_llama_7b vs microsoft/Orca-2-7b,120.33676200259515
|
22 |
+
openlm-research/open_llama_7b vs LLM360/Amber,156.87263745704468
|
23 |
+
huggyllama/llama-7b vs lmsys/vicuna-7b-v1.5,122.38058419243987
|
24 |
+
huggyllama/llama-7b vs EleutherAI/llemma_7b,134.03754865916954
|
25 |
+
huggyllama/llama-7b vs lmsys/vicuna-7b-v1.1,3.2305828500859106
|
26 |
+
huggyllama/llama-7b vs microsoft/Orca-2-7b,120.63673226643598
|
27 |
+
huggyllama/llama-7b vs LLM360/Amber,156.8477233676976
|
28 |
+
lmsys/vicuna-7b-v1.5 vs EleutherAI/llemma_7b,88.80338316392734
|
29 |
+
lmsys/vicuna-7b-v1.5 vs lmsys/vicuna-7b-v1.1,122.41473367697594
|
30 |
+
lmsys/vicuna-7b-v1.5 vs microsoft/Orca-2-7b,6.786975900194637
|
31 |
+
lmsys/vicuna-7b-v1.5 vs LLM360/Amber,148.06894329896906
|
32 |
+
EleutherAI/llemma_7b vs lmsys/vicuna-7b-v1.1,134.06379757785467
|
33 |
+
EleutherAI/llemma_7b vs microsoft/Orca-2-7b,88.6362254000865
|
34 |
+
EleutherAI/llemma_7b vs LLM360/Amber,156.21647923875432
|
35 |
+
lmsys/vicuna-7b-v1.1 vs microsoft/Orca-2-7b,120.66557634083046
|
36 |
+
lmsys/vicuna-7b-v1.1 vs LLM360/Amber,156.87929553264604
|
37 |
+
microsoft/Orca-2-7b vs LLM360/Amber,145.12435121107268
|
model-tracing/results/perm/permutation_l2_updated_midpoint_wikitext_single.csv
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Model Pair,l2 p-value,unpermuted l2,permuted l2s
|
2 |
+
meta-llama/Llama-2-7b-hf vs codellama/CodeLlama-7b-hf,0.01,93.18718912197232,132.05157871972318,132.05082179930795,132.0419009515571,132.04476643598616,132.03995458477507,132.04484753460207,132.04233347750866,132.04203611591694,132.02357266435988,132.0313310986159,132.04133326124568,132.0299794550173,132.041035899654,132.03403438581316,132.04498269896195,132.05074070069205,132.03608888408306,132.02641111591694,132.0261678200692,132.03270977508652,132.0413062283737,132.04349589100346,132.04060337370242,132.0303849480969,132.04200908304497,132.0373053633218,132.04376621972318,132.04279303633217,132.03746756055364,132.048821366782,132.05063256920414,132.04917279411765,132.03687283737025,132.03990051903114,132.03938689446366,132.03616998269896,132.0414143598616,132.04460423875432,132.0300605536332,132.04928092560553,132.0401438148789,132.047848183391,132.03941392733563,132.0504974048443,132.0397653546713,132.0411169982699,132.05222750865053,132.05025410899654,132.04311743079586,132.04152249134947,132.0435499567474,132.03844074394465,132.03768382352942,132.033196366782,132.04184688581316,132.0558499134948,132.0356563581315,132.0525519031142,132.03611591695503,132.03911656574394,132.04692906574394,132.0583910034602,132.04273897058823,132.04636137543253,132.04955125432525,132.0530384948097,132.0344669117647,132.03471020761245,132.033196366782,132.04108996539793,132.04849697231833,132.02995242214533,132.0433607266436,132.03146626297578,132.03838667820068,132.024410683391,132.0287089100346,132.0296820934256,132.03854887543253,132.04598291522493,132.0377919550173,132.04895653114187,132.0298983564014,132.0441176470588,132.02065311418684,132.03654844290656,132.0482266435986,132.02643814878894,132.04836180795849,132.04006271626298,132.03579152249134,132.03660250865053,132.05247080449826,132.05047037197232,132.03965722318338,132.03441284602076,132.05060553633217,132.03714316608998,132.03862997404843,132.04200908304497
|
3 |
+
meta-llama/Llama-2-7b-hf vs openlm-research/open_llama_7b,0.7,122.27028296821305,122.26932989690722,122.27398786512028,122.27390732388317,122.27303479381443,122.27692762027492,122.26056432560138,122.26719555412372,122.26647068298969,122.27483354810997,122.27200118127148,122.26848421391753,122.28572003865979,122.28478039089347,122.27780015034364,122.2675445661512,122.2703903565292,122.26941043814433,122.26728951890034,122.27903511597938,122.2711689218213,122.2724441580756,122.26934332044674,122.27316902920963,122.27233676975945,122.27694104381443,122.2737193943299,122.26997422680412,122.27283344072166,122.26644383591065,122.26551761168385,122.27107495704468,122.27491408934708,122.2752899484536,122.27907538659794,122.28435083762886,122.25655068728523,122.270591709622,122.28267289518901,122.27515571305842,122.2756389604811,122.2705245919244,122.26542364690722,122.27439057130584,122.27637725515464,122.28272658934708,122.26685996563575,122.28712951030928,122.28406894329896,122.2765651847079,122.28468642611683,122.27808204467354,122.26298056271477,122.26336984536083,122.27016215635739,122.2815318943299,122.28082044673539,122.27277974656357,122.27076621563575,122.27816258591065,122.2646316580756,122.27404155927834,122.27562553694158,122.2571010524055,122.26389336340206,122.268470790378,122.26716870704468,122.2803908934708,122.27643094931271,122.28996187714776,122.27609536082474,122.28057882302406,122.2712360395189,122.2636383161512,122.27245758161511,122.2852904853952,122.27440399484536,122.27890088058419,122.26483301116839,122.27869952749141,122.2716655927835,122.28151847079037,122.2673297895189,122.27923646907216,122.27596112542955,122.27347777061856,122.28584085051547,122.27966602233677,122.27594770189003,122.26401417525773,122.27174613402062,122.28414948453609,122.28114261168385,122.25987972508591,122.27652491408935,122.27314218213058,122.27409525343643,122.27867268041237,122.26774591924399,122.26245704467354,122.27077963917526
|
4 |
+
meta-llama/Llama-2-7b-hf vs huggyllama/llama-7b,0.54,122.27950493986255,122.27486039518901,122.28463273195877,122.2864314862543,122.27872637457045,122.27867268041237,122.27773303264605,122.27649806701031,122.28712951030928,122.27754510309278,122.27550472508591,122.28747852233677,122.28261920103093,122.270591709622,122.27614905498282,122.28127684707904,122.270591709622,122.27958548109966,122.27929016323024,122.27051116838489,122.27708870274914,122.2786189862543,122.27534364261169,122.27955863402062,122.27829682130584,122.28205541237114,122.27977341065292,122.26401417525773,122.2897068298969,122.28565292096219,122.26490012886597,122.27426975945018,122.29236469072166,122.28436426116839,122.2882033934708,122.28114261168385,122.28251181271477,122.2842300257732,122.2813842353952,122.28860609965636,122.2909149484536,122.2908344072165,122.2858676975945,122.28393470790378,122.27354488831615,122.27963917525773,122.2774108676976,122.29558634020619,122.27974656357388,122.28812285223368,122.29461984536083,122.28704896907216,122.2715850515464,122.26608140034364,122.26197379725086,122.27692762027492,122.28788122852234,122.27958548109966,122.2872100515464,122.27775987972508,122.27550472508591,122.28205541237114,122.28436426116839,122.27990764604812,122.27604166666667,122.27714239690722,122.27389390034364,122.28125,122.27843105670104,122.27241731099656,122.28664626288659,122.26771907216495,122.2760685137457,122.28774699312714,122.28259235395188,122.28484750859107,122.28433741408935,122.27013530927834,122.28479381443299,122.28551868556701,122.27765249140893,122.26758483676976,122.28814969931271,122.27896799828179,122.28149162371135,122.28576030927834,122.29311640893471,122.28299506013745,122.2726589347079,122.27891430412372,122.27445768900344,122.27843105670104,122.27268578178695,122.2838810137457,122.27996134020619,122.2746456185567,122.27453823024055,122.29010953608247,122.2850891323024,122.27300794673539,122.28710266323024
|
5 |
+
meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.5,0.01,4.1913064124248285,111.40988777920963,111.4162639604811,111.4112972508591,111.4123577104811,111.40202158505154,111.4141296176976,111.38573883161511,111.3972964991409,111.42339185996563,111.4049747637457,111.39443728522336,111.4005718427835,111.40771316580756,111.4324527491409,111.40414250429554,111.43137886597938,111.40233032646049,111.41706937285224,111.44485609965636,111.42356636597938,111.41751234965636,111.38518846649484,111.38039626288659,111.41241140463917,111.41616999570446,111.41276041666667,111.40418277491409,111.42352609536083,111.39218213058419,111.41365979381443,111.40744469501718,111.40552512886597,111.3884235395189,111.4141296176976,111.4035518685567,111.4170425257732,111.4012161726804,111.41414304123711,111.40931056701031,111.43415753865979,111.41148518041237,111.41082742697594,111.41332420532646,111.40336393900344,111.4093776847079,111.41325708762886,111.39212843642612,111.36345575601375,111.41567332474227,111.42532484965636,111.42834514604812,111.41506926546391,111.41302888745705,111.39298754295532,111.41467998281787,111.41370006443299,111.41050526202748,111.39920264175258,111.41560620704468,111.41368664089347,111.39293384879726,111.42366033075601,111.38220844072166,111.42277437714776,111.41124355670104,111.39641054553265,111.42132463487972,111.41400880584193,111.41450547680412,111.39639712199313,111.40961930841924,111.40094770189003,111.40665270618557,111.41680090206185,111.40949849656357,111.426841709622,111.40143094931271,111.42231797680412,111.42061318728523,111.41054553264605,111.41502899484536,111.4170425257732,111.4147068298969,111.42254617697594,111.41698883161511,111.43566097508591,111.41457259450172,111.42980831185567,111.41288122852234,111.41047841494846,111.4120086984536,111.4169754080756,111.40952534364261,111.42842568728523,111.40661243556701,111.39262510738831,111.39547089776632,111.40763262457045,111.42008966924399,111.42800955756013
|
6 |
+
meta-llama/Llama-2-7b-hf vs EleutherAI/llemma_7b,0.01,88.74339722102076,124.33972210207612,124.33282871972318,124.33428849480968,124.32993620242215,124.34894031141869,124.3295847750865,124.3283142301038,124.33209883217994,124.32628676470588,124.31741998269896,124.34288494809688,124.34204692906575,124.32566500865052,124.33596453287197,124.34515570934256,124.35004865916954,124.33196366782006,124.32552984429066,124.34983239619378,124.32985510380622,124.33674848615917,124.33604563148789,124.33090938581314,124.3409115484429,124.34493944636678,124.33155817474048,124.33477508650519,124.34188473183391,124.34058715397924,124.3288008217993,124.33655925605537,124.33631596020761,124.32899005190312,124.35085964532873,124.33961397058823,124.32536764705883,124.32928741349481,124.33980320069205,124.3333964100346,124.31693339100346,124.33461288927336,124.33564013840831,124.3396410034602,124.33450475778547,124.34566933391004,124.33461288927336,124.33599156574394,124.327151816609,124.33820826124567,124.32677335640139,124.34280384948097,124.34139814013841,124.33431552768167,124.32274545847751,124.33431552768167,124.34210099480968,124.32774653979239,124.34764273356402,124.32088019031141,124.33585640138408,124.33215289792388,124.33331531141869,124.329098183391,124.33423442906575,124.33353157439447,124.31585207612457,124.35286007785467,124.32485402249135,124.33682958477509,124.32361051038062,124.33353157439447,124.33666738754326,124.32580017301038,124.33420739619378,124.3295847750865,124.34607482698962,124.33664035467127,124.34556120242215,124.3423713235294,124.34212802768167,124.32174524221453,124.33069312283737,124.34088451557093,124.32277249134948,124.33623486159169,124.33377487024221,124.32963884083046,124.34593966262976,124.34034385813149,124.32777357266436,124.33358564013841,124.32085315743944,124.34069528546713,124.3389651816609,124.32885488754326,124.33047685986159,124.329098183391,124.3323150951557,124.3367214532872,124.32131271626298
|
7 |
+
meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.1,0.31,122.32382946735395,122.33115871993127,122.31719823883161,122.31287585910653,122.3221649484536,122.3154800257732,122.32098367697594,122.3173324742268,122.31561426116839,122.32866194158076,122.30825816151203,122.32503758591065,122.31749355670104,122.31824527491409,122.31663445017182,122.31043277491409,122.31840635738831,122.32970897766323,122.31829896907216,122.32447379725086,122.32071520618557,122.3299774484536,122.3195339347079,122.30586877147766,122.31803049828179,122.31980240549828,122.33448775773196,122.32780283505154,122.3230777491409,122.31652706185567,122.32624570446735,122.3242858676976,122.30023088487972,122.32345360824742,122.31486254295532,122.3313198024055,122.32718535223368,122.33652813573883,122.31835266323024,122.32868878865979,122.32178908934708,122.32586984536083,122.30815077319588,122.31905068728523,122.31558741408935,122.31650021477664,122.3141376718213,122.31190936426117,122.32530605670104,122.32718535223368,122.31322487113403,122.32334621993127,122.31284901202748,122.31556056701031,122.30307667525773,122.32050042955326,122.32272873711341,122.3085266323024,122.32079574742268,122.31180197594502,122.30774806701031,122.32423217353951,122.31652706185567,122.3191043814433,122.31539948453609,122.32063466494846,122.30337199312714,122.30900987972508,122.32503758591065,122.31416451890034,122.32079574742268,122.31792310996563,122.33027276632302,122.32468857388317,122.32560137457045,122.30543921821305,122.3055466065292,122.31765463917526,122.32299720790378,122.31762779209622,122.31207044673539,122.32858140034364,122.32624570446735,122.32608462199313,122.32267504295532,122.31671499140893,122.3194533934708,122.32541344501718,122.3225139604811,122.3261383161512,122.3116408934708,122.3131443298969,122.33674291237114,122.32710481099656,122.3195339347079,122.32288981958763,122.3251449742268,122.31478200171821,122.3196681701031,122.30903672680412,122.31709085051547
|
8 |
+
meta-llama/Llama-2-7b-hf vs microsoft/Orca-2-7b,0.01,5.711168968966263,109.73838938148789,109.76823367214533,109.74444474480968,109.75990754757785,109.75594723183391,109.74075475778547,109.76926092128028,109.75613646193771,109.7564473399654,109.75486591695501,109.73783520761246,109.75616349480968,109.73333423442907,109.7333072015571,109.7589884299308,109.74649924307958,109.7417009083045,109.74762110726644,109.7486348399654,109.74178200692042,109.75535250865052,109.74284980536332,109.75802876297578,109.74459342560553,109.7354427984429,109.7502027465398,109.74397166955018,109.75358185553634,109.74764814013841,109.74810769896193,109.73225291955018,109.74786440311419,109.74144409602076,109.72913062283737,109.73580774221453,109.7557580017301,109.745120566609,109.73607807093425,109.74572880622837,109.75629865916954,109.75543360726644,109.74751297577855,109.73726751730104,109.76432742214533,109.74180903979239,109.74395815311419,109.73598345588235,109.75665008650519,109.75124351211073,109.7499053849481,109.76511137543253,109.74045739619378,109.74457990916954,109.74847264273356,109.74220101643598,109.7159115484429,109.75601481401384,109.73921388408304,109.73537521626298,109.75866403546713,109.7580963451557,109.74631001297578,109.74312013408304,109.7440392517301,109.76505730968859,109.74194420415225,109.76772004757785,109.73433445069205,109.73845696366782,109.75387921712803,109.74189013840831,109.75978589965398,109.73451016435986,109.73171226211073,109.71661440311419,109.73563202854672,109.74274167387543,109.74071420847751,109.73142841695501,109.74034926470588,109.73146896626298,109.72857644896193,109.74256596020761,109.73361807958477,109.76313797577855,109.73498323961938,109.74025464965398,109.75832612456747,109.76293522923875,109.76716587370242,109.72335910467127,109.75854238754326,109.73529411764706,109.74763462370242,109.7467695717993,109.76869323096886,109.74939176038062,109.73225291955018,109.74653979238754,109.75794766435986
|
9 |
+
meta-llama/Llama-2-7b-hf vs LLM360/Amber,0.2,148.0270618556701,148.01374570446737,148.0257731958763,148.03135738831614,148.01331615120276,148.01460481099656,148.00644329896906,148.03823024054984,148.02276632302406,148.03221649484536,148.01546391752578,148.04252577319588,148.02405498281786,148.0176116838488,148.02190721649484,148.01202749140893,148.0266323024055,148.0158934707904,148.0189003436426,148.01847079037802,148.0356529209622,148.02061855670104,148.01245704467354,148.01159793814432,148.01675257731958,148.0180412371134,148.03264604810997,148.0287800687285,148.02018900343643,148.0257731958763,148.01460481099656,148.01202749140893,148.02018900343643,148.01374570446737,148.02061855670104,148.02448453608247,148.02405498281786,148.01460481099656,148.0287800687285,148.01503436426117,148.02362542955328,148.01847079037802,148.03049828178695,148.00644329896906,148.0098797250859,148.01417525773195,148.0171821305842,148.00644329896906,148.03049828178695,148.0081615120275,148.0266323024055,148.02061855670104,148.01675257731958,148.02491408934708,148.0270618556701,148.01675257731958,148.03908934707903,148.0274914089347,148.01675257731958,148.0158934707904,148.03307560137458,148.0090206185567,148.01331615120276,148.0171821305842,148.01503436426117,148.01374570446737,148.01503436426117,148.01546391752578,148.02018900343643,148.03393470790377,148.03608247422682,148.0356529209622,148.03092783505156,148.0266323024055,148.02362542955328,148.02061855670104,148.02319587628867,148.0158934707904,148.01159793814432,148.02233676975945,148.01331615120276,148.0270618556701,148.0171821305842,148.02362542955328,148.03178694158075,148.01116838487974,148.0262027491409,148.03049828178695,148.02362542955328,148.0270618556701,148.0257731958763,148.0257731958763,147.99785223367698,148.02190721649484,148.01503436426117,148.0253436426117,148.02233676975945,148.02405498281786,148.0017182130584,148.0004295532646,148.03393470790377
|
10 |
+
codellama/CodeLlama-7b-hf vs openlm-research/open_llama_7b,0.38,140.50532547577853,140.5032033953287,140.5064608564014,140.51106996107268,140.49125486591694,140.50667711937717,140.50873161764707,140.50923172577853,140.49525573096886,140.5012840614187,140.50564987024222,140.49528276384083,140.4923902465398,140.50493349913495,140.4986483564014,140.50863700259515,140.50363592128028,140.5041089965398,140.49612078287197,140.49152519463667,140.5005001081315,140.517584883218,140.50701503027682,140.49028168252596,140.49904033304497,140.50935337370242,140.50778546712803,140.49187662197232,140.49671550605535,140.5021355968858,140.50373053633217,140.49620188148788,140.49144409602076,140.50054065743944,140.5080017301038,140.50940743944636,140.51417874134947,140.49390408737025,140.50521734429066,140.51169171712803,140.4971750648789,140.5050686634948,140.48786224048442,140.50223021193773,140.5075151384083,140.5000810986159,140.51462478373702,140.50389273356402,140.50124351211073,140.50512272923876,140.5103671064014,140.50636624134947,140.50823150951558,140.50950205449826,140.50141922577853,140.5096642517301,140.50913711072664,140.49066014273356,140.51182688148788,140.5005812067474,140.4827124783737,140.51116457612457,140.49401221885813,140.50786656574394,140.5064473399654,140.49655330882354,140.50082450259515,140.5028384515571,140.49266057525952,140.49933769463667,140.506825800173,140.4949178200692,140.50093263408306,140.5046902032872,140.49213343425606,140.50385218425606,140.49698583477507,140.50163548875432,140.50366295415225,140.5062581098616,140.5017706531142,140.49616133217992,140.51032655709344,140.51282709775086,140.5035007569204,140.50894788062283,140.5027573529412,140.49937824394465,140.50121647923876,140.5116376513841,140.49831044550174,140.50066230536333,140.49010596885813,140.49332288062283,140.50817744377161,140.50729887543253,140.5018517517301,140.48883542387543,140.5067447015571,140.509772383218,140.49279573961937
|
11 |
+
codellama/CodeLlama-7b-hf vs huggyllama/llama-7b,0.69,140.80544442041523,140.79292820069205,140.80676903114187,140.8191500865052,140.80906682525952,140.81752811418684,140.80882352941177,140.80971561418684,140.8163116349481,140.82358347750866,140.81920415224914,140.80328179065745,140.80536332179932,140.81898788927336,140.8016598183391,140.82015030276816,140.7961721453287,140.8296929065744,140.82950367647058,140.8107428633218,140.80736375432525,140.80855320069205,140.81850129757785,140.806633866782,140.81187824394465,140.8059580449827,140.8185553633218,140.81771734429066,140.80606617647058,140.8207179930796,140.82880082179932,140.81477076124568,140.80079476643598,140.78873810553634,140.81606833910035,140.80755298442907,140.81393274221455,140.8020653114187,140.80160575259515,140.81898788927336,140.81304065743944,140.80717452422147,140.80017301038063,140.81909602076124,140.79106293252596,140.8329098183391,140.80252487024222,140.80322772491348,140.81925821799308,140.80417387543253,140.80387651384083,140.80941825259515,140.79379325259515,140.79952422145328,140.81842019896195,140.81255406574394,140.81339208477507,140.80582288062283,140.81477076124568,140.81655493079586,140.7884948096886,140.806660899654,140.82723291522493,140.8205017301038,140.8002000432526,140.81022923875432,140.8186634948097,140.80133542387543,140.80433607266437,140.80890462802768,140.818366133218,140.797848183391,140.81477076124568,140.7946312716263,140.80149762110727,140.80025410899654,140.81441933391002,140.80441717128028,140.81696042387543,140.81031033737025,140.81509515570934,140.80628243944636,140.79625324394465,140.80841803633217,140.7991187283737,140.79036007785467,140.82501621972318,140.8176903114187,140.81168901384083,140.81604130622839,140.81182417820068,140.8177714100346,140.80714749134947,140.81068879757785,140.806660899654,140.81433823529412,140.80912089100346,140.79500973183391,140.81360834775086,140.79879433391002,140.81222967128028
|
12 |
+
codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.5,0.01,93.25281141868513,132.11570069204151,132.1188365051903,132.12359429065745,132.12824394463667,132.1175929930796,132.10329260380624,132.12218858131487,132.11305147058823,132.12135056228374,132.1322448096886,132.12097210207614,132.11245674740485,132.10715830449826,132.10064338235293,132.12210748269896,132.11791738754326,132.133785683391,132.12700043252596,132.1134839965398,132.1257028546713,132.13619160899654,132.1047794117647,132.11937716262975,132.13305579584775,132.1193230968858,132.10007569204151,132.1053741349481,132.11886353806227,132.1098615916955,132.10640138408306,132.13029844290656,132.10250865051904,132.1098615916955,132.11059147923876,132.1089695069204,132.1121053200692,132.11851211072664,132.1159169550173,132.1171064013841,132.11802551903114,132.1390570934256,132.12026924740485,132.11645761245674,132.1136732266436,132.11813365051904,132.11721453287197,132.13062283737025,132.12235077854672,132.12448637543253,132.1038062283737,132.11167279411765,132.10710423875432,132.12262110726644,132.11440311418684,132.119160899654,132.1130785034602,132.12121539792386,132.12527032871972,132.12278330449826,132.1101589532872,132.12410791522493,132.11275410899654,132.1047794117647,132.10396842560553,132.10315743944636,132.11059147923876,132.111321366782,132.11634948096886,132.12654087370242,132.11042928200692,132.11470047577853,132.1305687716263,132.11169982698962,132.10504974048442,132.1301903114187,132.1169982698962,132.1234320934256,132.12721669550174,132.11607915224914,132.11794442041523,132.11359212802768,132.11872837370242,132.12024221453288,132.12140462802768,132.11440311418684,132.11862024221455,132.10921280276816,132.12153979238755,132.1078070934256,132.10310337370242,132.11770112456747,132.12278330449826,132.10396842560553,132.12472967128028,132.11218641868513,132.12935229238755,132.12318879757785,132.12299956747404,132.12332396193773,132.11051038062283
|
13 |
+
codellama/CodeLlama-7b-hf vs EleutherAI/llemma_7b,0.01,49.639849790592784,145.56556056701032,145.56373496563575,145.55355992268042,145.55103629725085,145.5629295532646,145.5602448453608,145.55651310137458,145.53490120274915,145.53055197594503,145.56274162371133,145.55941258591065,145.57844716494844,145.5524323453608,145.5569695017182,145.5328071305842,145.53490120274915,145.5290485395189,145.53739798109964,145.53420317869416,145.55670103092783,145.5446198453608,145.5487811426117,145.55345253436425,145.54733140034364,145.53326353092783,145.56346649484536,145.56523840206185,145.56966817010309,145.54853951890036,145.54717031786942,145.56961447594503,145.52821628006873,145.56534579037802,145.54408290378007,145.53423002577318,145.563922895189,145.57444695017182,145.55600300687286,145.54821735395188,145.55133161512026,145.53570661512026,145.5597347508591,145.54706292955328,145.55917096219932,145.52679338487974,145.548995919244,145.52665914948454,145.54008268900344,145.5602985395189,145.55500966494844,145.55087521477662,145.53527706185568,145.53976052405497,145.53527706185568,145.53608247422682,145.55750644329896,145.535518685567,145.55976159793815,145.53127684707903,145.51430949312714,145.54580111683848,145.54937177835052,145.5565399484536,145.5551707474227,145.53307560137458,145.57366838487974,145.55793599656357,145.5544727233677,145.54351911512026,145.54330433848799,145.5429016323024,145.5487274484536,145.54754617697594,145.560379080756,145.55310352233678,145.53726374570448,145.55898303264604,145.55004295532646,145.53570661512026,145.5592246563574,145.53978737113403,145.5659901202749,145.5630637886598,145.55355992268042,145.53261920103094,145.57619201030928,145.55672787800688,145.52061855670104,145.52126288659792,145.55756013745705,145.52998818728523,145.5549022766323,145.53967998281786,145.5607817869416,145.55613724226805,145.5617214347079,145.56502362542955,145.53033719931273,145.55361361683848,145.54206937285224
|
14 |
+
codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.1,0.95,140.8399383650519,140.86207828719722,140.84693987889273,140.85461721453288,140.85002162629758,140.86124026816609,140.8376946366782,140.865241133218,140.84069528546712,140.86096993944636,140.84996756055364,140.87175605536333,140.85621215397924,140.8480482266436,140.86261894463667,140.84896734429066,140.84964316608998,140.85499567474048,140.84645328719722,140.86207828719722,140.84885921280278,140.84391219723184,140.8503730536332,140.86086180795849,140.8394517733564,140.87008001730104,140.86613321799308,140.84953503460207,140.8507785467128,140.8595642301038,140.85450908304497,140.848642949827,140.84558823529412,140.85104887543253,140.85534710207614,140.84372296712803,140.84915657439447,140.84596669550174,140.8400464965398,140.83196366782008,140.85656358131487,140.85294117647058,140.85288711072664,140.85337370242215,140.86215938581316,140.8411278114187,140.86178092560553,140.8532115051903,140.84264165224914,140.85207612456747,140.86283520761245,140.84591262975778,140.85461721453288,140.8628892733564,140.86315960207614,140.85018382352942,140.84826448961937,140.85131920415225,140.8591857698962,140.8618079584775,140.85632028546712,140.84634515570934,140.83704584775086,140.86091587370242,140.85480644463667,140.8448313148789,140.8634839965398,140.85759083044982,140.84283088235293,140.85456314878894,140.86691717128028,140.866214316609,140.85172469723184,140.84410142733563,140.84693987889273,140.85467128027682,140.85883434256056,140.8506974480969,140.8499945934256,140.86607915224914,140.8382893598616,140.8586721453287,140.8449394463668,140.84945393598616,140.84575043252596,140.8527519463668,140.8514814013841,140.86291630622839,140.84748053633217,140.85494160899654,140.86507893598616,140.86359212802768,140.8492106401384,140.84166846885813,140.86224048442907,140.86342993079586,140.84710207612457,140.8526438148789,140.8489403114187,140.8474535034602,140.84556120242215
|
15 |
+
codellama/CodeLlama-7b-hf vs microsoft/Orca-2-7b,0.01,93.11375432525952,131.67065852076124,131.69193339100346,131.69139273356402,131.68917603806227,131.67574070069205,131.68130947231833,131.6948259083045,131.68598615916954,131.68725670415225,131.7043955449827,131.68766219723184,131.6863375865052,131.69371756055364,131.69139273356402,131.6988267733564,131.69517733564012,131.7019625865052,131.70339532871972,131.700205449827,131.69136570069205,131.68625648788927,131.6990160034602,131.70631487889273,131.71239727508652,131.6953935986159,131.69055471453288,131.6905276816609,131.68923010380624,131.67936310553634,131.70290873702422,131.67730860726644,131.68549956747404,131.68901384083046,131.69387975778545,131.67738970588235,131.70188148788927,131.70558499134947,131.68817582179932,131.6808499134948,131.68460748269896,131.68412089100346,131.7046929065744,131.7000973183391,131.6867971453287,131.67525410899654,131.68512110726644,131.6942311851211,131.69809688581316,131.69828611591694,131.70444961072664,131.6944474480969,131.6895544982699,131.68376946366783,131.68122837370242,131.6807688148789,131.6996107266436,131.69442041522493,131.67679498269896,131.68549956747404,131.67936310553634,131.6877973615917,131.70393598615917,131.6775519031142,131.6896355968858,131.69742106401384,131.68279628027682,131.68371539792386,131.68171496539793,131.67941717128028,131.70258434256056,131.7061526816609,131.6964749134948,131.68547253460207,131.69166306228374,131.68774329584775,131.69171712802768,131.68220155709344,131.69133866782008,131.69382569204151,131.68093101211073,131.69166306228374,131.6867971453287,131.67341587370242,131.69182525951558,131.6847426470588,131.68952746539793,131.68403979238755,131.67484861591694,131.6941230536332,131.69590722318338,131.70999134948096,131.69585315743944,131.69598832179932,131.67325367647058,131.70104346885813,131.69874567474048,131.69347426470588,131.66998269896195,131.72056120242215,131.68279628027682
|
16 |
+
codellama/CodeLlama-7b-hf vs LLM360/Amber,0.91,163.1440311418685,163.17171280276816,163.1613321799308,163.14749134948096,163.1613321799308,163.15787197231833,163.14749134948096,163.15787197231833,163.18209342560553,163.17517301038063,163.14057093425606,163.16479238754326,163.15787197231833,163.16479238754326,163.17171280276816,163.17171280276816,163.1682525951557,163.15787197231833,163.1691176470588,163.15095155709344,163.15441176470588,163.17517301038063,163.17171280276816,163.17171280276816,163.17517301038063,163.14057093425606,163.15787197231833,163.15095155709344,163.15787197231833,163.1613321799308,163.15441176470588,163.17171280276816,163.17517301038063,163.15787197231833,163.15787197231833,163.1725778546713,163.17171280276816,163.17517301038063,163.15787197231833,163.1725778546713,163.16479238754326,163.16955017301038,163.1613321799308,163.16479238754326,163.17517301038063,163.15441176470588,163.15787197231833,163.15787197231833,163.14749134948096,163.17171280276816,163.15441176470588,163.14749134948096,163.18209342560553,163.1613321799308,163.1301903114187,163.1652249134948,163.15441176470588,163.1691176470588,163.15441176470588,163.15787197231833,163.1682525951557,163.15787197231833,163.15095155709344,163.17171280276816,163.1440311418685,163.1440311418685,163.15095155709344,163.15441176470588,163.15787197231833,163.15787197231833,163.1682525951557,163.1725778546713,163.1440311418685,163.16479238754326,163.14705882352942,163.1691176470588,163.15787197231833,163.1691176470588,163.16608996539793,163.16479238754326,163.17517301038063,163.17171280276816,163.17517301038063,163.17863321799308,163.17517301038063,163.1682525951557,163.13365051903114,163.15787197231833,163.1371107266436,163.1613321799308,163.15787197231833,163.15095155709344,163.16565743944636,163.17171280276816,163.1371107266436,163.15441176470588,163.15441176470588,163.15787197231833,163.15787197231833,163.16479238754326,163.15181660899654
|
17 |
+
openlm-research/open_llama_7b vs huggyllama/llama-7b,0.19,131.92685513316152,131.92089508161513,131.92509664948454,131.91161941580756,131.92971434707903,131.92175418814432,131.9051089991409,131.91141806271477,131.91376718213058,131.92355294243987,131.90969984965636,131.9154182774914,131.91289465206185,131.93862757731958,131.91175365120276,131.9225864475945,131.90466602233678,131.92443889604812,131.91710964347078,131.91717676116838,131.91156572164948,131.91041129725085,131.92552620274915,131.9163981958763,131.90876020189003,131.93009020618555,131.92227770618555,131.91305573453607,131.9106394974227,131.9084380369416,131.91524377147766,131.93839937714776,131.91627738402062,131.9255664733677,131.90842461340208,131.91662639604812,131.91408934707903,131.93258698453607,131.9095924613402,131.92101589347078,131.91830433848799,131.9099414733677,131.93456024484536,131.93268094931273,131.9296740764605,131.90584729381445,131.9298217353952,131.9085320017182,131.90761920103094,131.92474763745705,131.92486844931273,131.90955219072166,131.92473421391753,131.92931164089347,131.9085320017182,131.91313627577318,131.92668062714776,131.90929714347078,131.92293545962198,131.9352314218213,131.91990173969072,131.90663928264604,131.90173969072166,131.90194104381445,131.93144598367698,131.9231502362543,131.91830433848799,131.92064003436425,131.91923056271477,131.9074581185567,131.92055949312714,131.9156196305842,131.93677512886597,131.92196896477662,131.92904317010309,131.91843857388315,131.92469394329896,131.9081561426117,131.92767396907217,131.9004913015464,131.91869362113403,131.9233113187285,131.92473421391753,131.93711071735396,131.91902920962198,131.92630476804123,131.8961554982818,131.91599548969072,131.9234455541237,131.91461286512026,131.90450493986253,131.92191527061857,131.92100246993127,131.92035814003435,131.9273518041237,131.91313627577318,131.91649216065292,131.92000912800688,131.92155283505156,131.93363402061857,131.8985717353952
|
18 |
+
openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.5,0.3,122.38108086340206,122.37894652061856,122.37337575171821,122.39488026202748,122.38780605670104,122.37921499140893,122.38881282216495,122.37757731958763,122.37446305841924,122.37242268041237,122.38428908934708,122.37128167955326,122.37140249140893,122.37792633161511,122.38223528780068,122.37481207044674,122.36418062714776,122.36261007302406,122.38744362113403,122.39500107388317,122.3808929338488,122.37532216494846,122.36401954467354,122.38010094501718,122.37800687285224,122.37418116408935,122.37768470790378,122.36528135738831,122.36638208762886,122.38074527491409,122.37836930841924,122.36595253436425,122.37355025773196,122.38543009020619,122.37176492697594,122.38426224226804,122.37867804982818,122.37700010738831,122.38626234965636,122.36165700171821,122.37559063573883,122.37510738831615,122.38540324312714,122.3731475515464,122.38097347508591,122.38815506872852,122.39007463487972,122.37310728092784,122.3875778565292,122.37246295103093,122.38780605670104,122.38430251288659,122.3854435137457,122.36652974656357,122.37838273195877,122.37607388316151,122.37062392611683,122.37498657646049,122.36797948883161,122.38755100945018,122.3907592353952,122.37144276202748,122.38014121563575,122.36886544243987,122.36836877147766,122.3718723152921,122.36828823024055,122.36466387457045,122.37164411512028,122.39407484965636,122.37930895618557,122.38269168814433,122.37654370704468,122.38055734536083,122.37729542525773,122.38000698024055,122.37559063573883,122.36236844931271,122.38403404209622,122.38416827749141,122.37020779639175,122.37145618556701,122.39795425257732,122.36867751288659,122.38254402920963,122.38759128006873,122.38109428694158,122.38639658505154,122.37050311426117,122.37686587199313,122.37998013316151,122.3759933419244,122.36941580756013,122.37030176116839,122.37702695446735,122.37258376288659,122.36850300687286,122.37001986683849,122.3825977233677,122.38007409793815,122.3877120919244
|
19 |
+
openlm-research/open_llama_7b vs EleutherAI/llemma_7b,0.91,133.5973724048443,133.58549145761245,133.58641057525952,133.58520761245674,133.58861375432525,133.58793793252596,133.59591262975778,133.59347967128028,133.58755947231833,133.58696474913495,133.591817149654,133.59662900086505,133.58185553633217,133.5969669117647,133.59514219290656,133.59375,133.58782980103805,133.5972642733564,133.58754595588235,133.5884785899654,133.58831639273356,133.57458369377161,133.58235564446366,133.5881812283737,133.60386029411765,133.5882758434256,133.58638354238755,133.59469615051904,133.5890733131488,133.59923767301038,133.59812932525952,133.59323637543253,133.59118187716263,133.59565581747404,133.5847480536332,133.59446637110727,133.5928849480969,133.58389651816609,133.58851913927336,133.57748972750866,133.58812716262975,133.5881812283737,133.58270707179932,133.58442365916954,133.58492376730104,133.60133272058823,133.59592614619376,133.60564446366783,133.59327692474048,133.59200637975778,133.58781628460207,133.58978968425606,133.59554768598616,133.58843804065745,133.58768112024222,133.5958991133218,133.58226102941177,133.58912737889273,133.58693771626298,133.58226102941177,133.59546658737025,133.58423442906573,133.5899248486159,133.58745134083046,133.5714749134948,133.58942474048442,133.594723183391,133.58489673442907,133.59077638408306,133.59118187716263,133.58484266868513,133.59153330449826,133.58070663927336,133.58363970588235,133.5919928633218,133.59164143598616,133.59230374134947,133.5856401384083,133.59031682525952,133.59293901384083,133.60775302768167,133.58596453287197,133.58158520761245,133.59299307958477,133.58139597750866,133.58386948529412,133.59235780709344,133.60891544117646,133.59389868079586,133.5977373486159,133.5859375,133.587910899654,133.590816933391,133.58539684256056,133.5919928633218,133.59049253892732,133.58482915224914,133.58393706747404,133.59064121972318,133.59954855103805,133.5914116565744
|
20 |
+
openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.1,0.87,131.96828017611685,131.96799828178695,131.95217192869416,131.9627094072165,131.97382409793815,131.95213165807561,131.96915270618555,131.94882946735396,131.94952749140893,131.96253490120276,131.96785062285224,131.96084353522338,131.94980938573883,131.95834675687286,131.95994415807561,131.9536216709622,131.9409766967354,131.962722830756,131.95621241408935,131.95221219931273,131.962722830756,131.9592327104811,131.9690318943299,131.95403780068727,131.9434063573883,131.96559546821305,131.95061479810997,131.9673136812715,131.9615684063574,131.94076192010309,131.96167579467354,131.96227985395188,131.96062875859107,131.9555412371134,131.94344662800688,131.95720575601374,131.9682667525773,131.95826621563575,131.94849387886597,131.95658827319588,131.95464185996565,131.945674935567,131.9667498926117,131.9552593427835,131.9645618556701,131.97374355670104,131.9653404209622,131.97028028350516,131.9652598797251,131.96083011168386,131.9667498926117,131.97054875429552,131.9605213702749,131.94987650343643,131.9624275128866,131.96632033934708,131.96393094931273,131.95391698883162,131.95994415807561,131.969112435567,131.95430627147766,131.95754134450172,131.97320661512026,131.96383698453607,131.96599817439864,131.94372852233678,131.95054768041237,131.96021262886597,131.9595011812715,131.9708038015464,131.9568970146048,131.97076353092783,131.96960910652922,131.96513906786942,131.9573802620275,131.96154155927834,131.95418545962198,131.95394383591065,131.9495677620275,131.94018470790377,131.94351374570448,131.94017128436425,131.95194372852234,131.95993073453607,131.96125966494844,131.95344716494844,131.9577426975945,131.95673593213058,131.9472991838488,131.94257409793815,131.96105831185568,131.96955541237114,131.95891054553263,131.95352770618555,131.96414572594503,131.97213273195877,131.95521907216494,131.94825225515464,131.96821305841925,131.96038713487974,131.9682533290378
|
21 |
+
openlm-research/open_llama_7b vs microsoft/Orca-2-7b,0.07,120.33676200259515,120.35221128892734,120.34593966262976,120.33831639273356,120.33472102076125,120.3409115484429,120.34412846020761,120.34708855968859,120.34992701124567,120.35217073961938,120.34460153546713,120.34233077422145,120.34884569636678,120.34533142301038,120.35580666089966,120.33586991782006,120.34685878027682,120.34514219290658,120.34084396626298,120.34454746972318,120.34618295847751,120.34007352941177,120.35161656574394,120.35206260813149,120.34639922145328,120.35118403979239,120.34798064446366,120.33600508217994,120.34848075259515,120.3397491349481,120.34822394031141,120.34788602941177,120.34557471885813,120.34649383650519,120.34214154411765,120.35111645761246,120.3458044982699,120.346683066609,120.34006001297578,120.33628892733564,120.34562878460207,120.34156033737024,120.35415765570934,120.35277897923875,120.34381758217994,120.34957558391004,120.33919496107266,120.34773734861592,120.34558823529412,120.35090019463668,120.36459234429066,120.34373648356402,120.34089803200692,120.34887272923875,120.34576394896193,120.3403303416955,120.3382758434256,120.33942474048443,120.36328125,120.34573691608996,120.34011407871972,120.34289846453287,120.35185986159169,120.33703233131487,120.35460369809688,120.35095426038062,120.34773734861592,120.3406277032872,120.33530222750865,120.34142517301038,120.34164143598616,120.33789738321799,120.33873540224913,120.34546658737024,120.34662900086505,120.3412089100346,120.35587424307958,120.34106022923875,120.35481996107266,120.34895382785467,120.34106022923875,120.3433580233564,120.34008704584775,120.34222264273356,120.3576178633218,120.34314176038062,120.33776221885813,120.34089803200692,120.3569285250865,120.3411142949827,120.34752108564014,120.34787251297578,120.34964316608996,120.33081477076125,120.3451286764706,120.34139814013841,120.35584721020761,120.35129217128028,120.33981671712803,120.35227887110727,120.33951935553634
|
model-tracing/results/perm/permutation_loss_midpoint_wikitext_single.csv
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Model Pair,loss p-value,unpermuted loss,permuted losses
|
2 |
+
meta-llama/Llama-2-7b-hf vs codellama/CodeLlama-7b-hf,0.01,2.3246312141418457,12.414887428283691,10.436723709106445,11.463930130004883,10.289920806884766,10.242432594299316,10.102094650268555,10.40760612487793,11.008111000061035,10.213240623474121,10.67473316192627,10.590666770935059,11.035821914672852,11.31821060180664,10.628593444824219,10.762565612792969,10.666658401489258,10.025702476501465,11.167806625366211,9.551054000854492,10.156463623046875,11.520145416259766,10.491896629333496,11.043285369873047,9.333309173583984,11.062390327453613,12.100906372070312,11.202133178710938,9.852943420410156,11.296211242675781,10.389328956604004,11.376590728759766,10.070525169372559,10.026528358459473,9.897621154785156,11.314379692077637,10.841796875,10.254968643188477,10.329296112060547,9.514544486999512,10.902949333190918,11.021820068359375,10.437745094299316,10.649450302124023,10.367554664611816,10.411920547485352,10.035419464111328,10.890546798706055,10.692540168762207,10.53729248046875,11.218603134155273,9.837448120117188,11.327438354492188,9.788040161132812,10.384784698486328,11.534051895141602,11.14419174194336,10.694536209106445,11.900223731994629,10.65000057220459,12.121938705444336,10.567901611328125,11.527382850646973,9.785840034484863,11.457867622375488,10.247855186462402,10.888710021972656,10.499954223632812,10.134185791015625,10.207050323486328,10.164628028869629,10.550536155700684,11.72325325012207,9.784914016723633,9.676228523254395,9.838054656982422,10.10443115234375,10.668523788452148,11.105070114135742,10.959217071533203,10.044930458068848,10.379700660705566,10.720422744750977,10.373221397399902,11.601597785949707,9.767145156860352,11.08609390258789,11.894039154052734,10.29943561553955,9.70090103149414,10.862321853637695,10.94301700592041,10.51127815246582,10.600202560424805,10.773964881896973,10.515543937683105,10.527175903320312,9.276633262634277,10.818120002746582,10.66445541381836,9.917793273925781
|
3 |
+
meta-llama/Llama-2-7b-hf vs openlm-research/open_llama_7b,0.64,11.575419425964355,11.40134334564209,11.22384262084961,11.46405029296875,11.452263832092285,10.455619812011719,10.975679397583008,11.58855152130127,12.451857566833496,11.924081802368164,12.242836952209473,11.277565002441406,11.563451766967773,11.68265151977539,11.848937034606934,10.44631290435791,12.024474143981934,11.507471084594727,10.45692253112793,11.344988822937012,10.991713523864746,11.589402198791504,11.106851577758789,11.331841468811035,11.409880638122559,11.256386756896973,11.77776050567627,11.199036598205566,11.516610145568848,11.03821849822998,11.872330665588379,11.290637969970703,11.223627090454102,11.23520278930664,11.499801635742188,11.516365051269531,11.601478576660156,10.872085571289062,11.779101371765137,12.180392265319824,11.032282829284668,11.411971092224121,11.81676197052002,11.883353233337402,11.750810623168945,11.176957130432129,10.957067489624023,14.970661163330078,11.218901634216309,12.29248046875,11.31027603149414,11.960100173950195,11.291650772094727,11.711487770080566,12.310693740844727,11.598724365234375,11.308028221130371,11.46600341796875,11.952092170715332,11.185503959655762,11.173068046569824,11.149730682373047,11.459930419921875,11.234939575195312,11.293815612792969,11.118258476257324,12.10870361328125,12.251635551452637,11.076282501220703,11.664109230041504,12.398612976074219,11.496159553527832,11.82258415222168,11.470304489135742,11.340746879577637,12.212625503540039,10.980117797851562,11.57983684539795,11.1240873336792,11.654163360595703,11.260462760925293,11.391934394836426,10.89322566986084,11.501547813415527,11.8900785446167,10.93996524810791,10.885272979736328,11.528347969055176,11.503525733947754,11.81087589263916,11.610458374023438,11.825965881347656,11.274349212646484,10.811087608337402,11.338240623474121,11.545049667358398,11.304219245910645,11.655720710754395,11.276208877563477,10.926770210266113,10.414559364318848
|
4 |
+
meta-llama/Llama-2-7b-hf vs huggyllama/llama-7b,0.99,12.2577486038208,11.713994979858398,11.829354286193848,11.17656421661377,10.935490608215332,10.614337921142578,11.733335494995117,10.55212116241455,11.617198944091797,11.02830982208252,11.103273391723633,10.98201847076416,11.652020454406738,11.325891494750977,11.472954750061035,11.740574836730957,11.530477523803711,11.882132530212402,11.853250503540039,11.228290557861328,11.346905708312988,11.387336730957031,11.696052551269531,11.597851753234863,10.276288032531738,10.484336853027344,11.387497901916504,11.903156280517578,10.768474578857422,11.013677597045898,11.241617202758789,11.207913398742676,11.054834365844727,11.4959716796875,10.328938484191895,11.011316299438477,11.860649108886719,10.733698844909668,11.433752059936523,11.57979965209961,10.878405570983887,10.85395336151123,11.535321235656738,11.479925155639648,11.269013404846191,10.643021583557129,10.89708423614502,10.569060325622559,10.77185344696045,11.806313514709473,11.702690124511719,11.08808708190918,10.627337455749512,11.021454811096191,11.430144309997559,10.929702758789062,11.454261779785156,11.20509147644043,10.65364933013916,11.217240333557129,11.659832000732422,10.883869171142578,11.713050842285156,10.736254692077637,11.047320365905762,11.248353004455566,11.70473861694336,10.898530006408691,10.070318222045898,10.314364433288574,10.933834075927734,10.347864151000977,11.01766586303711,10.836737632751465,10.664950370788574,10.768060684204102,10.58050537109375,10.815657615661621,10.808479309082031,11.686751365661621,10.925615310668945,11.23190975189209,11.179738998413086,11.049586296081543,10.853154182434082,11.299239158630371,12.247625350952148,10.916298866271973,11.287036895751953,11.025016784667969,11.59552001953125,11.489020347595215,10.638093948364258,11.174018859863281,10.918281555175781,11.737152099609375,10.530423164367676,10.45882797241211,10.129399299621582,11.413297653198242,11.797016143798828
|
5 |
+
meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.5,0.01,1.927486538887024,11.821135520935059,11.198407173156738,9.731367111206055,10.24201774597168,10.990260124206543,11.304125785827637,11.099743843078613,12.419852256774902,12.37198543548584,11.15950870513916,11.141377449035645,11.544782638549805,12.145750999450684,12.45425033569336,11.949847221374512,11.726274490356445,10.694775581359863,11.242759704589844,10.818204879760742,12.982400894165039,10.319830894470215,12.521677017211914,12.44443130493164,13.007572174072266,10.896312713623047,12.76135540008545,12.305526733398438,10.95059871673584,11.4083890914917,12.70875358581543,12.232357025146484,12.533140182495117,10.787182807922363,10.945103645324707,11.467927932739258,10.401111602783203,10.44174861907959,12.008749008178711,11.72877311706543,11.764703750610352,9.260933876037598,9.660629272460938,10.801837921142578,10.224106788635254,10.381224632263184,9.673469543457031,11.380040168762207,11.231133460998535,19.4843807220459,11.832645416259766,11.092031478881836,10.962821960449219,9.960295677185059,10.633321762084961,10.486164093017578,10.409814834594727,11.982443809509277,9.964591026306152,12.886614799499512,10.435898780822754,12.111292839050293,10.96882152557373,11.224199295043945,11.8229341506958,12.327183723449707,10.941879272460938,10.337071418762207,10.541845321655273,11.256275177001953,12.46101188659668,11.085774421691895,11.33261489868164,11.308060646057129,11.747783660888672,10.688689231872559,11.49790096282959,10.940205574035645,8.749332427978516,12.804007530212402,11.88497543334961,11.572806358337402,9.896434783935547,11.37403392791748,10.527442932128906,10.58166790008545,10.856118202209473,11.154521942138672,10.329510688781738,11.657892227172852,11.611366271972656,11.373307228088379,11.529753684997559,11.080778121948242,9.911948204040527,11.649473190307617,11.339656829833984,11.394058227539062,10.780144691467285,10.756746292114258,12.14152717590332
|
6 |
+
meta-llama/Llama-2-7b-hf vs EleutherAI/llemma_7b,0.01,2.2617459297180176,12.073643684387207,9.601665496826172,11.438493728637695,11.859432220458984,9.136279106140137,10.389586448669434,10.537793159484863,9.534730911254883,11.314827919006348,11.605558395385742,10.475652694702148,11.510196685791016,11.53796100616455,12.012444496154785,9.623738288879395,9.299081802368164,12.455397605895996,9.597661972045898,12.382262229919434,11.051375389099121,9.97256088256836,9.617574691772461,9.902551651000977,10.14933967590332,10.723257064819336,10.306427001953125,10.507966041564941,10.786173820495605,12.057255744934082,8.83625316619873,10.997522354125977,9.18514633178711,9.580232620239258,9.388571739196777,9.756895065307617,9.360857009887695,10.979372024536133,10.720354080200195,10.257851600646973,9.46591854095459,9.886110305786133,10.586426734924316,10.423218727111816,11.755844116210938,10.80079460144043,10.800479888916016,9.471284866333008,9.391319274902344,12.033135414123535,10.548360824584961,11.953831672668457,10.858678817749023,11.68787670135498,10.111590385437012,11.390238761901855,10.074946403503418,8.820842742919922,10.89544677734375,9.980045318603516,10.665379524230957,9.60218620300293,10.32979965209961,12.539587020874023,8.944647789001465,10.680426597595215,10.794387817382812,12.351435661315918,10.521456718444824,10.917708396911621,10.16922378540039,9.72276496887207,9.371663093566895,10.629088401794434,10.897933959960938,12.550519943237305,8.850178718566895,10.177017211914062,9.906770706176758,12.711698532104492,11.127090454101562,10.410646438598633,10.431379318237305,9.942227363586426,11.459847450256348,15.255708694458008,12.270158767700195,10.828147888183594,10.433170318603516,10.952369689941406,11.425644874572754,9.552434921264648,9.059112548828125,9.621358871459961,12.671890258789062,10.086618423461914,9.930680274963379,12.442190170288086,9.548700332641602,11.214447975158691,10.612310409545898
|
7 |
+
meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.1,0.99,12.176732063293457,12.091742515563965,9.97877025604248,11.391814231872559,10.661109924316406,10.817087173461914,11.534388542175293,11.4345064163208,11.032309532165527,11.198960304260254,10.708231925964355,10.624578475952148,10.788887977600098,10.516521453857422,11.011672019958496,11.516682624816895,11.103180885314941,11.559279441833496,11.907734870910645,11.123976707458496,9.81065559387207,10.62118911743164,11.27763843536377,11.523125648498535,10.444096565246582,11.76767349243164,11.019610404968262,11.484188079833984,11.91457748413086,10.497208595275879,11.79175853729248,11.298368453979492,10.316926956176758,10.76264762878418,11.064068794250488,10.277680397033691,10.894306182861328,11.136940002441406,10.948762893676758,10.753281593322754,11.747068405151367,11.244501113891602,11.157686233520508,11.421767234802246,11.376803398132324,10.557845115661621,11.091096878051758,11.671372413635254,11.795584678649902,10.934099197387695,11.12752628326416,11.514236450195312,10.979425430297852,10.027077674865723,11.399852752685547,11.993349075317383,11.904010772705078,10.320328712463379,11.54886531829834,11.317293167114258,10.79334831237793,11.174019813537598,11.312468528747559,12.456009864807129,11.849678039550781,11.423567771911621,11.26967716217041,11.130739212036133,11.871220588684082,11.11894416809082,10.497808456420898,11.341180801391602,11.120772361755371,11.832477569580078,11.945394515991211,10.435187339782715,10.123815536499023,11.923513412475586,11.165726661682129,10.511265754699707,10.816417694091797,11.620299339294434,11.749600410461426,11.621185302734375,10.726395606994629,10.873162269592285,11.366233825683594,11.751174926757812,11.712872505187988,11.940488815307617,11.489510536193848,11.349868774414062,11.479663848876953,11.586658477783203,10.732288360595703,10.799344062805176,10.621288299560547,11.430079460144043,11.587478637695312,11.59907054901123,10.702455520629883
|
8 |
+
meta-llama/Llama-2-7b-hf vs microsoft/Orca-2-7b,0.01,1.9139776229858398,11.634427070617676,11.600348472595215,11.60654067993164,12.126233100891113,12.646196365356445,12.883075714111328,12.256454467773438,12.270318984985352,12.459698677062988,10.608137130737305,12.05358600616455,11.529598236083984,10.475113868713379,11.170097351074219,11.863300323486328,10.970882415771484,10.59147834777832,12.343642234802246,12.012216567993164,11.054078102111816,10.566030502319336,11.191433906555176,11.72900390625,12.421318054199219,11.642358779907227,12.26218032836914,12.066629409790039,11.98330020904541,12.333147048950195,10.645333290100098,11.23974323272705,12.008330345153809,11.807231903076172,10.58747673034668,11.714077949523926,11.820113182067871,11.989992141723633,12.196735382080078,12.603166580200195,10.58354663848877,12.315600395202637,12.89755916595459,12.080608367919922,11.813095092773438,11.074599266052246,10.643448829650879,10.9359769821167,11.14858627319336,12.007444381713867,12.142953872680664,12.374232292175293,10.88144302368164,10.891057014465332,11.451530456542969,12.084410667419434,12.564109802246094,12.367560386657715,10.01160717010498,11.172256469726562,12.028122901916504,11.117118835449219,10.586153030395508,12.69567584991455,11.325706481933594,11.118499755859375,12.011213302612305,12.913552284240723,11.33607006072998,12.125946044921875,10.641644477844238,11.218090057373047,11.023408889770508,11.858912467956543,9.926396369934082,11.108742713928223,10.33670425415039,11.369423866271973,10.87497329711914,10.876391410827637,11.830551147460938,18.367767333984375,11.91503620147705,10.409882545471191,12.978384971618652,11.026498794555664,12.638835906982422,11.090004920959473,13.22242259979248,10.301204681396484,12.243881225585938,11.466031074523926,11.41460132598877,11.086315155029297,11.420076370239258,10.147997856140137,11.232333183288574,11.026019096374512,11.566559791564941,13.21921157836914,11.737604141235352
|
9 |
+
meta-llama/Llama-2-7b-hf vs LLM360/Amber,0.11,9.658585548400879,9.505361557006836,11.505193710327148,9.863297462463379,10.549816131591797,10.601881980895996,9.804191589355469,9.992196083068848,10.386860847473145,10.168055534362793,10.143446922302246,11.450798988342285,10.739060401916504,10.874429702758789,11.022444725036621,11.62452507019043,11.448281288146973,10.981934547424316,10.17294692993164,9.73409652709961,10.862936973571777,10.203863143920898,9.555386543273926,11.31235122680664,9.649775505065918,9.874133110046387,10.310741424560547,10.610264778137207,11.861855506896973,11.109381675720215,10.608086585998535,10.293082237243652,10.586475372314453,9.161870956420898,10.675275802612305,9.941357612609863,10.5984525680542,9.234950065612793,9.805939674377441,10.848678588867188,10.375737190246582,11.113080024719238,10.478139877319336,10.219255447387695,9.84744930267334,9.266841888427734,11.306220054626465,11.001436233520508,9.924046516418457,10.07392406463623,10.432853698730469,11.701499938964844,9.725397109985352,9.526941299438477,11.007448196411133,10.04012680053711,11.11932373046875,11.325552940368652,10.447799682617188,10.715643882751465,9.796828269958496,10.892861366271973,10.44837474822998,10.71353816986084,10.633293151855469,10.90569019317627,10.542643547058105,11.243886947631836,11.541812896728516,11.15402889251709,10.407221794128418,11.446313858032227,10.677804946899414,10.178149223327637,10.385024070739746,11.192070960998535,9.855805397033691,10.154035568237305,11.104714393615723,10.019867897033691,10.577777862548828,9.540035247802734,11.262989044189453,9.886129379272461,9.649831771850586,10.185017585754395,9.09082317352295,10.29433822631836,9.811910629272461,10.254436492919922,11.52804946899414,10.116106986999512,10.657572746276855,9.819936752319336,10.43431282043457,10.864540100097656,10.484803199768066,12.004101753234863,11.187883377075195,10.061746597290039,11.245307922363281
|
10 |
+
codellama/CodeLlama-7b-hf vs openlm-research/open_llama_7b,0.65,11.237561225891113,11.680026054382324,11.398880004882812,11.86768913269043,11.39999771118164,11.541326522827148,10.19598388671875,11.428864479064941,11.862239837646484,10.481870651245117,10.985757827758789,10.969829559326172,11.367602348327637,11.111282348632812,11.022823333740234,10.561090469360352,10.207883834838867,10.694847106933594,10.99021053314209,10.593587875366211,11.172009468078613,11.665566444396973,11.263614654541016,11.552456855773926,11.507055282592773,11.336711883544922,11.074647903442383,10.79491138458252,10.656916618347168,10.548526763916016,11.022597312927246,10.966035842895508,11.283071517944336,10.308378219604492,10.846853256225586,10.849564552307129,10.968177795410156,11.308629035949707,11.001319885253906,10.848873138427734,10.73318099975586,10.859678268432617,11.509345054626465,10.539803504943848,10.935742378234863,11.4722900390625,11.278949737548828,11.9168701171875,11.39941692352295,10.587061882019043,10.908806800842285,11.912999153137207,13.913973808288574,11.417015075683594,10.803816795349121,11.063304901123047,11.088665962219238,11.479676246643066,10.848278045654297,11.193321228027344,10.839192390441895,10.731671333312988,11.18687629699707,10.452286720275879,9.904109001159668,11.235068321228027,11.758543014526367,10.346553802490234,10.477143287658691,10.51911449432373,11.152212142944336,10.974753379821777,10.7435302734375,11.238974571228027,12.068460464477539,10.960865020751953,11.370020866394043,11.462693214416504,11.207311630249023,10.798894882202148,11.14282512664795,11.745903968811035,11.410775184631348,10.994516372680664,11.088298797607422,11.085737228393555,11.031476020812988,11.000276565551758,9.836112022399902,11.110745429992676,10.450307846069336,10.72428035736084,11.212058067321777,11.340717315673828,11.261013984680176,11.80449104309082,11.24316120147705,10.83240795135498,10.822101593017578,10.707273483276367,10.825920104980469
|
11 |
+
codellama/CodeLlama-7b-hf vs huggyllama/llama-7b,0.32,10.771794319152832,10.501058578491211,10.56005859375,11.231363296508789,10.989171028137207,11.534804344177246,12.136800765991211,11.693639755249023,10.567752838134766,11.91177749633789,11.43637466430664,11.564026832580566,11.265721321105957,11.19032096862793,10.40317153930664,10.623876571655273,10.287189483642578,10.42343521118164,10.866487503051758,12.013888359069824,10.774230003356934,11.138988494873047,11.338595390319824,11.346685409545898,10.583052635192871,10.429323196411133,10.492464065551758,11.78402328491211,10.30309009552002,11.499597549438477,12.184760093688965,11.097319602966309,11.184183120727539,11.21827507019043,10.121459007263184,11.127937316894531,10.251218795776367,11.494721412658691,10.646721839904785,10.4236478805542,11.158930778503418,9.909791946411133,10.339156150817871,11.350455284118652,10.859237670898438,10.905741691589355,12.035804748535156,11.207295417785645,10.787891387939453,11.013189315795898,11.2416410446167,10.195667266845703,11.020384788513184,11.056377410888672,11.036890029907227,10.251171112060547,10.708823204040527,11.038761138916016,10.733219146728516,10.966094017028809,10.910711288452148,10.949577331542969,11.406147003173828,11.047470092773438,9.963555335998535,10.648523330688477,11.16801643371582,11.87838363647461,11.598986625671387,10.867388725280762,12.05274486541748,11.79859733581543,11.681252479553223,10.888261795043945,11.057225227355957,10.08951187133789,11.208805084228516,10.740156173706055,11.282068252563477,11.669149398803711,11.980530738830566,10.39877700805664,11.507721900939941,11.044110298156738,11.520370483398438,10.161965370178223,11.56179428100586,10.723697662353516,10.831975936889648,11.531084060668945,10.410531997680664,11.085836410522461,12.549727439880371,11.945355415344238,10.689722061157227,10.73282241821289,11.869728088378906,12.093964576721191,12.069504737854004,11.068297386169434,11.401714324951172
|
12 |
+
codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.5,0.01,2.302018404006958,11.172709465026855,12.183158874511719,10.882681846618652,9.767948150634766,24.42896842956543,10.943366050720215,11.75478744506836,15.235170364379883,9.498678207397461,10.326377868652344,10.187347412109375,10.607508659362793,10.732197761535645,9.70353889465332,10.948262214660645,9.788002967834473,10.9060640335083,10.285189628601074,10.656235694885254,11.046913146972656,10.868155479431152,10.477668762207031,9.764632225036621,9.605854034423828,9.753314018249512,10.699603080749512,9.823118209838867,10.523367881774902,9.758964538574219,9.701371192932129,10.879630088806152,11.618856430053711,13.337488174438477,9.767634391784668,11.983025550842285,10.607478141784668,10.280527114868164,10.127568244934082,12.975506782531738,11.579995155334473,11.203207015991211,10.376364707946777,11.423201560974121,11.209802627563477,9.359248161315918,12.826610565185547,10.373741149902344,9.53094482421875,10.389079093933105,9.205557823181152,10.689602851867676,10.414274215698242,11.192754745483398,10.581353187561035,10.557870864868164,12.383353233337402,10.521531105041504,9.70833969116211,11.450300216674805,10.21054458618164,11.530344009399414,10.224842071533203,10.37917709350586,11.495973587036133,10.469186782836914,9.77238941192627,9.884854316711426,10.01948356628418,9.838289260864258,11.130271911621094,10.648575782775879,10.440532684326172,10.789712905883789,11.31021785736084,10.711143493652344,10.09947681427002,11.175568580627441,9.295597076416016,9.529111862182617,11.034727096557617,9.88634204864502,10.545726776123047,10.429464340209961,10.972481727600098,9.342894554138184,11.735199928283691,9.926610946655273,10.638096809387207,11.533238410949707,11.10690975189209,11.147697448730469,10.123966217041016,9.807247161865234,9.888313293457031,10.178050994873047,10.791872024536133,10.62413501739502,9.422971725463867,9.87336540222168,10.158583641052246
|
13 |
+
codellama/CodeLlama-7b-hf vs EleutherAI/llemma_7b,0.01,1.8881051540374756,11.02577018737793,11.569954872131348,9.650038719177246,12.788570404052734,11.758131980895996,9.877432823181152,10.481515884399414,11.251947402954102,11.191924095153809,11.158823013305664,9.50117015838623,11.658199310302734,10.220966339111328,11.156086921691895,11.29964828491211,11.514382362365723,11.962372779846191,10.005611419677734,10.904683113098145,10.441763877868652,10.319205284118652,11.140127182006836,9.660285949707031,10.85449504852295,10.57632827758789,10.08746337890625,11.422477722167969,11.425626754760742,10.719733238220215,11.198336601257324,9.705710411071777,11.163413047790527,10.196398735046387,12.473783493041992,11.222222328186035,10.624934196472168,11.31132984161377,12.355788230895996,12.014126777648926,10.066976547241211,9.984125137329102,11.545228958129883,10.415767669677734,11.360154151916504,10.375480651855469,11.507826805114746,11.202043533325195,10.294898986816406,9.63949203491211,10.868791580200195,11.535318374633789,10.94536018371582,10.92975902557373,11.955968856811523,10.462796211242676,10.789320945739746,9.42170524597168,11.375829696655273,9.957808494567871,10.590437889099121,11.810711860656738,10.596860885620117,10.35020637512207,10.983830451965332,11.902685165405273,12.088788986206055,10.547697067260742,11.154755592346191,10.342395782470703,9.408880233764648,10.353204727172852,10.178812980651855,11.717917442321777,10.343524932861328,9.107279777526855,11.441140174865723,12.437472343444824,9.229593276977539,9.995097160339355,13.198864936828613,10.668238639831543,9.889089584350586,10.291954040527344,9.115019798278809,11.644171714782715,9.289207458496094,10.126248359680176,11.402484893798828,10.980873107910156,9.706541061401367,10.721711158752441,10.652327537536621,9.227180480957031,11.187313079833984,11.266683578491211,10.421162605285645,10.27970027923584,11.191493034362793,11.59189224243164,10.865525245666504
|
14 |
+
codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.1,0.29,10.684311866760254,9.88802433013916,11.366555213928223,10.514373779296875,11.024654388427734,10.283346176147461,11.294032096862793,11.239091873168945,11.904309272766113,11.442981719970703,9.724740028381348,10.486740112304688,10.531262397766113,10.704021453857422,11.958599090576172,11.913610458374023,10.451040267944336,11.539121627807617,11.079904556274414,10.380290985107422,10.464911460876465,10.824151039123535,10.613722801208496,10.818387031555176,10.909002304077148,11.383553504943848,11.057758331298828,10.797572135925293,11.538613319396973,10.077716827392578,11.08383846282959,10.961997985839844,10.51244068145752,10.801470756530762,11.105854034423828,10.700427055358887,10.818270683288574,10.551637649536133,10.831416130065918,11.392118453979492,11.99124813079834,11.624017715454102,11.47986125946045,11.073309898376465,10.576276779174805,11.65388298034668,10.442636489868164,10.892948150634766,10.495064735412598,11.371464729309082,12.704389572143555,10.441764831542969,11.860435485839844,10.724827766418457,10.910114288330078,9.940642356872559,11.6974515914917,10.893054008483887,10.342144012451172,10.850858688354492,11.492060661315918,10.491785049438477,12.155282974243164,11.26374340057373,10.091155052185059,10.971919059753418,10.889839172363281,9.859522819519043,11.616873741149902,11.444681167602539,10.922995567321777,11.157459259033203,10.100933074951172,10.804597854614258,10.976712226867676,11.442182540893555,11.178875923156738,11.096498489379883,11.058975219726562,11.280542373657227,11.896692276000977,11.171963691711426,10.79887580871582,11.429341316223145,11.694151878356934,12.15854549407959,11.003829002380371,11.078950881958008,11.429558753967285,11.451701164245605,11.222574234008789,11.16943073272705,10.59616756439209,9.994879722595215,11.877211570739746,12.13260555267334,10.476194381713867,9.788568496704102,9.951860427856445,12.063776016235352,10.534578323364258
|
15 |
+
codellama/CodeLlama-7b-hf vs microsoft/Orca-2-7b,0.01,2.320063829421997,9.977145195007324,11.017745018005371,11.732516288757324,10.55106258392334,10.886062622070312,11.940051078796387,10.490015029907227,11.409379959106445,9.593038558959961,9.326103210449219,11.52270221710205,10.925615310668945,11.000499725341797,11.391313552856445,10.632598876953125,10.968502044677734,9.419816970825195,11.293421745300293,10.746554374694824,10.45264720916748,9.733240127563477,9.230606079101562,11.2017240524292,11.673495292663574,9.883476257324219,14.689058303833008,10.65975284576416,11.826485633850098,9.935537338256836,10.098958969116211,11.369180679321289,10.066923141479492,10.007286071777344,10.073885917663574,10.076384544372559,10.496053695678711,9.437036514282227,10.592652320861816,9.364110946655273,10.822487831115723,10.531739234924316,10.315073013305664,9.872851371765137,10.577094078063965,10.572772026062012,10.358898162841797,10.806140899658203,16.713293075561523,10.911810874938965,9.414806365966797,11.51042366027832,10.734992027282715,12.246551513671875,11.260639190673828,10.792762756347656,10.72775936126709,10.830528259277344,11.213608741760254,12.049139976501465,9.823256492614746,9.768250465393066,10.623906135559082,11.0609712600708,10.343968391418457,8.824793815612793,10.778313636779785,10.15536880493164,10.200857162475586,11.061370849609375,10.520877838134766,8.951927185058594,10.345824241638184,11.1963472366333,10.408754348754883,9.833317756652832,9.915206909179688,11.101364135742188,11.677738189697266,11.327123641967773,10.630494117736816,9.27521800994873,10.425193786621094,10.057621955871582,10.960726737976074,11.673463821411133,10.601761817932129,10.41421127319336,10.354918479919434,11.0147705078125,11.991089820861816,9.368815422058105,8.995253562927246,11.107773780822754,10.934799194335938,10.410724639892578,10.493441581726074,10.318268775939941,9.758223533630371,9.74752426147461,10.921327590942383
|
16 |
+
codellama/CodeLlama-7b-hf vs LLM360/Amber,0.01,9.603846549987793,10.922776222229004,10.763627052307129,10.674799919128418,10.944330215454102,10.779414176940918,10.923018455505371,9.871673583984375,10.166775703430176,10.810501098632812,10.085100173950195,10.842877388000488,10.339820861816406,10.418185234069824,9.797157287597656,11.204648971557617,11.264816284179688,11.797182083129883,10.71696662902832,10.645628929138184,11.203916549682617,10.131481170654297,10.013680458068848,10.839400291442871,11.125710487365723,10.923445701599121,11.895150184631348,10.997817993164062,11.030921936035156,10.626928329467773,10.994245529174805,10.47247314453125,10.476496696472168,10.737053871154785,10.940540313720703,10.545219421386719,10.65115737915039,10.538052558898926,10.73913860321045,10.608595848083496,11.778488159179688,10.629700660705566,11.484350204467773,10.728248596191406,11.018856048583984,10.530044555664062,10.73550796508789,11.228963851928711,10.074200630187988,11.778253555297852,10.75644588470459,10.916414260864258,10.685816764831543,11.022653579711914,11.845925331115723,11.053071975708008,10.963454246520996,11.932684898376465,11.024358749389648,10.258270263671875,11.10335636138916,11.186234474182129,11.418861389160156,10.6289644241333,11.077651023864746,10.978907585144043,11.048869132995605,11.65108585357666,10.4456787109375,10.570917129516602,10.814337730407715,10.753689765930176,11.621301651000977,10.818181037902832,11.07013988494873,11.709806442260742,10.042532920837402,11.51170825958252,10.287849426269531,11.025697708129883,10.106474876403809,10.335064888000488,10.63949966430664,11.414520263671875,9.983678817749023,10.194723129272461,10.093085289001465,10.348984718322754,10.452190399169922,10.118818283081055,11.143959999084473,10.194239616394043,10.520627975463867,9.854783058166504,11.553059577941895,10.919979095458984,10.652848243713379,11.446795463562012,11.503812789916992,11.831265449523926,10.060680389404297
|
17 |
+
openlm-research/open_llama_7b vs huggyllama/llama-7b,0.12,10.616796493530273,10.797248840332031,10.766523361206055,10.866601943969727,11.188263893127441,10.954423904418945,11.125080108642578,11.481340408325195,10.507387161254883,10.983829498291016,11.449170112609863,11.820627212524414,10.870183944702148,11.375383377075195,11.897957801818848,10.795915603637695,11.425924301147461,10.779943466186523,10.608522415161133,11.291844367980957,10.9775972366333,11.396665573120117,11.077995300292969,10.915323257446289,10.149968147277832,10.838491439819336,11.376053810119629,12.027775764465332,11.226425170898438,11.149484634399414,11.914155006408691,11.06672191619873,12.116933822631836,11.715675354003906,11.286850929260254,10.582368850708008,10.894996643066406,10.559850692749023,11.138714790344238,11.34919548034668,11.406577110290527,11.709674835205078,10.768847465515137,11.264957427978516,11.060179710388184,10.48735523223877,11.610673904418945,10.382858276367188,11.576861381530762,11.695962905883789,11.838735580444336,11.8810396194458,11.433664321899414,10.812098503112793,11.260259628295898,10.915751457214355,10.764631271362305,11.118894577026367,11.433198928833008,11.013741493225098,11.012408256530762,11.988961219787598,10.893074989318848,11.14090633392334,11.564040184020996,11.033992767333984,11.447514533996582,10.347433090209961,11.568410873413086,11.325068473815918,11.319805145263672,11.36147689819336,10.636981010437012,10.374312400817871,11.18582820892334,11.290904998779297,11.412924766540527,11.148937225341797,11.629731178283691,10.574426651000977,11.250195503234863,11.57218074798584,10.607427597045898,11.080497741699219,11.073441505432129,11.367008209228516,11.016765594482422,11.250604629516602,10.794018745422363,11.461509704589844,12.037816047668457,10.851967811584473,11.706341743469238,10.905144691467285,12.039907455444336,10.87762451171875,11.777667999267578,11.1010160446167,11.086259841918945,10.743986129760742,11.00861930847168
|
18 |
+
openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.5,0.68,11.10280704498291,11.494667053222656,10.498247146606445,11.133041381835938,11.101852416992188,11.380467414855957,10.272819519042969,10.971842765808105,10.812454223632812,10.571216583251953,11.184111595153809,10.600080490112305,10.708064079284668,11.08830738067627,11.812434196472168,11.79877758026123,10.77804946899414,10.4550142288208,10.40676498413086,10.906953811645508,10.607868194580078,11.703605651855469,11.5511474609375,9.885051727294922,12.578688621520996,10.805289268493652,10.065228462219238,10.415675163269043,9.463606834411621,10.35334587097168,10.453519821166992,11.14851188659668,11.603013038635254,11.39593505859375,10.219343185424805,10.553191184997559,11.516910552978516,11.12913990020752,10.367042541503906,11.04174518585205,10.621479988098145,10.78134536743164,10.831048011779785,10.864336967468262,10.25527572631836,10.48643970489502,10.815017700195312,10.441289901733398,10.695341110229492,12.055279731750488,11.352632522583008,11.06913948059082,11.00218677520752,10.366230010986328,10.634378433227539,11.0018949508667,11.126782417297363,9.731849670410156,10.59158706665039,11.318184852600098,10.395101547241211,11.940820693969727,10.694389343261719,11.356606483459473,10.950284957885742,11.315900802612305,10.90494155883789,11.158031463623047,11.424098014831543,10.433701515197754,11.031879425048828,10.566242218017578,10.624926567077637,11.704424858093262,9.559603691101074,10.54096508026123,10.466423034667969,10.139713287353516,10.190888404846191,10.859702110290527,11.18995475769043,11.700125694274902,10.920588493347168,10.214404106140137,11.712258338928223,11.302619934082031,11.045059204101562,21.055919647216797,10.661722183227539,10.79277515411377,11.078306198120117,10.720754623413086,11.443395614624023,10.91203498840332,11.552059173583984,9.76099967956543,10.945575714111328,10.874262809753418,11.279154777526855,9.812013626098633,11.022870063781738
|
19 |
+
openlm-research/open_llama_7b vs EleutherAI/llemma_7b,0.74,11.676680564880371,10.700831413269043,10.920976638793945,11.028438568115234,12.167314529418945,11.756142616271973,11.685977935791016,10.584118843078613,11.082442283630371,11.074419975280762,11.608972549438477,12.379051208496094,10.71125316619873,10.557001113891602,11.452447891235352,11.275219917297363,11.134586334228516,11.230876922607422,12.451685905456543,10.93018913269043,11.579869270324707,11.084054946899414,12.037456512451172,11.611255645751953,11.685752868652344,11.016987800598145,11.213319778442383,11.783517837524414,11.0402250289917,10.860733032226562,11.379624366760254,11.247736930847168,10.858418464660645,11.460200309753418,11.808459281921387,11.439626693725586,12.029986381530762,10.962839126586914,11.192753791809082,11.062772750854492,11.25781536102295,11.559438705444336,11.687849044799805,10.75302791595459,12.29854965209961,11.086894989013672,12.160594940185547,11.388923645019531,10.720033645629883,11.752241134643555,11.075825691223145,11.455404281616211,11.496362686157227,10.872966766357422,11.221406936645508,12.037018775939941,11.722471237182617,11.927973747253418,10.33534049987793,10.593981742858887,10.86403751373291,11.836677551269531,11.976253509521484,11.691933631896973,10.841435432434082,10.745210647583008,11.37106990814209,12.00289535522461,11.43269157409668,12.220749855041504,11.051970481872559,11.547199249267578,10.92629337310791,11.815045356750488,11.536211967468262,11.504477500915527,11.529980659484863,11.00515365600586,11.090625762939453,11.51578426361084,11.141458511352539,11.392504692077637,11.342756271362305,11.645566940307617,11.101119995117188,11.657102584838867,11.088163375854492,11.384390830993652,11.717936515808105,11.144939422607422,11.418688774108887,11.375986099243164,10.539258003234863,11.37796401977539,12.069337844848633,11.964790344238281,11.09656047821045,10.646003723144531,11.4827880859375,11.113225936889648,11.435702323913574
|
20 |
+
openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.1,0.12,10.621867179870605,10.42145824432373,11.195003509521484,11.484685897827148,10.794024467468262,11.019747734069824,11.490104675292969,11.03162670135498,10.776253700256348,11.319022178649902,11.029229164123535,10.46848201751709,11.557026863098145,11.630350112915039,10.21575927734375,10.254093170166016,11.132651329040527,10.749253273010254,11.896352767944336,11.851204872131348,11.208588600158691,11.714153289794922,10.779157638549805,11.057229995727539,11.114214897155762,10.616106033325195,11.634489059448242,11.256133079528809,10.997373580932617,12.223559379577637,11.170015335083008,11.057348251342773,11.103898048400879,11.418852806091309,11.480304718017578,11.226197242736816,11.045890808105469,11.04325008392334,10.816136360168457,10.4801025390625,11.184416770935059,11.797572135925293,11.668251037597656,10.831574440002441,11.632143020629883,11.781684875488281,10.77318000793457,10.571830749511719,11.864724159240723,11.191770553588867,11.520462989807129,10.72339153289795,11.068861961364746,11.573631286621094,10.899574279785156,11.674015998840332,10.131872177124023,11.504653930664062,11.00288200378418,10.962963104248047,11.146491050720215,11.226775169372559,10.69074821472168,11.214972496032715,10.712320327758789,10.820131301879883,11.37641429901123,11.639378547668457,11.570959091186523,11.473939895629883,10.993010520935059,11.848175048828125,11.129470825195312,11.361347198486328,11.544088363647461,11.482494354248047,11.397720336914062,11.10208797454834,10.954719543457031,11.292170524597168,10.730844497680664,11.10391616821289,10.9358491897583,11.644660949707031,11.556461334228516,10.230692863464355,10.93234634399414,11.048127174377441,11.356008529663086,11.290736198425293,10.155138969421387,11.609221458435059,11.059803009033203,10.940677642822266,10.850116729736328,11.106929779052734,10.82397747039795,11.759023666381836,10.195236206054688,11.401366233825684,11.2745943069458
|
21 |
+
openlm-research/open_llama_7b vs microsoft/Orca-2-7b,0.85,11.351534843444824,10.584617614746094,10.774704933166504,10.837328910827637,11.801148414611816,11.330574035644531,11.290903091430664,10.862215042114258,11.384349822998047,10.745942115783691,11.390098571777344,11.086434364318848,11.015809059143066,9.748941421508789,10.519133567810059,11.45274829864502,10.552149772644043,11.926155090332031,10.675073623657227,11.253833770751953,10.66576862335205,9.94119644165039,19.171932220458984,10.226914405822754,12.355240821838379,11.018232345581055,10.890312194824219,10.387739181518555,11.143223762512207,12.016646385192871,10.465130805969238,10.726633071899414,11.077333450317383,10.403158187866211,11.529679298400879,10.922691345214844,10.887194633483887,11.203007698059082,10.525662422180176,10.012253761291504,9.828682899475098,10.859309196472168,10.792906761169434,10.633329391479492,10.120866775512695,10.508658409118652,12.094136238098145,10.804570198059082,10.728404998779297,10.672419548034668,12.785676002502441,9.50652027130127,11.129883766174316,11.35869312286377,11.06894302368164,9.344609260559082,10.127263069152832,10.435647010803223,10.6591215133667,11.319711685180664,10.950030326843262,10.885222434997559,10.239350318908691,10.612761497497559,11.15457534790039,10.812723159790039,10.070918083190918,10.321710586547852,11.3038969039917,11.07686996459961,10.517918586730957,9.845905303955078,10.563337326049805,10.056536674499512,11.044584274291992,11.334074974060059,10.442049026489258,10.59984302520752,10.307365417480469,11.638136863708496,10.443572044372559,10.811623573303223,10.837541580200195,10.584928512573242,10.582640647888184,10.057901382446289,10.074625968933105,11.10626220703125,11.195136070251465,10.514017105102539,10.031493186950684,9.643144607543945,11.793606758117676,10.653993606567383,10.19979190826416,11.364167213439941,10.87149715423584,11.284507751464844,9.864716529846191,10.578239440917969,10.145064353942871
|
model-tracing/results/perm/permutation_norm_loss_midpoint_wikitext_single.csv
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Model Pair,norm loss p-value,unpermuted norm loss,permuted norm losses
|
2 |
+
meta-llama/Llama-2-7b-hf vs codellama/CodeLlama-7b-hf,0.01,0.3834942579269409,10.473750472068787,8.49558675289154,9.522793173789978,8.34878385066986,8.301295638084412,8.16095769405365,8.466469168663025,9.06697404384613,8.272103667259216,8.733596205711365,8.649529814720154,9.094684958457947,9.377073645591736,8.687456488609314,8.821428656578064,8.725521445274353,8.08456552028656,9.226669669151306,7.609917044639587,8.21532666683197,9.57900846004486,8.550759673118591,9.102148413658142,7.39217221736908,9.121253371238708,10.159769415855408,9.260996222496033,7.9118064641952515,9.355074286460876,8.4481920003891,9.43545377254486,8.129388213157654,8.085391402244568,7.9564841985702515,9.373242735862732,8.900659918785095,8.313831686973572,8.388159155845642,7.573407530784607,8.961812376976013,9.08068311214447,8.496608138084412,8.708313345909119,8.426417708396912,8.470783591270447,8.094282507896423,8.94940984249115,8.751403212547302,8.596155524253845,9.277466177940369,7.896311163902283,9.386301398277283,7.846903204917908,8.443647742271423,9.592914938926697,9.203054785728455,8.75339925289154,9.959086775779724,8.708863615989685,10.180801749229431,8.62676465511322,9.586245894432068,7.8447030782699585,9.516730666160583,8.306718230247498,8.947573065757751,8.558817267417908,8.19304883480072,8.265913367271423,8.223491072654724,8.609399199485779,9.782116293907166,7.843777060508728,7.73509156703949,7.896917700767517,8.163294196128845,8.727386832237244,9.163933157920837,9.018080115318298,8.103793501853943,8.438563704490662,8.779285788536072,8.432084441184998,9.660460829734802,7.826008200645447,9.144956946372986,9.95290219783783,8.358298659324646,7.759764075279236,8.92118489742279,9.001880049705505,8.570141196250916,8.6590656042099,8.832827925682068,8.5744069814682,8.586038947105408,7.335496306419373,8.876983046531677,8.723318457603455,7.9766563177108765
|
3 |
+
meta-llama/Llama-2-7b-hf vs openlm-research/open_llama_7b,0.64,6.713594198226929,6.539518117904663,6.362017393112183,6.602225065231323,6.590438604354858,5.593794584274292,6.113854169845581,6.726726293563843,7.590032339096069,7.062256574630737,7.381011724472046,6.4157397747039795,6.701626539230347,6.820826292037964,6.987111806869507,5.584487676620483,7.162648916244507,6.6456458568573,5.595097303390503,6.483163595199585,6.129888296127319,6.727576971054077,6.245026350021362,6.470016241073608,6.548055410385132,6.394561529159546,6.915935277938843,6.33721137046814,6.654784917831421,6.176393270492554,7.010505437850952,6.428812742233276,6.361801862716675,6.373377561569214,6.637976408004761,6.6545398235321045,6.7396533489227295,6.010260343551636,6.91727614402771,7.3185670375823975,6.170457601547241,6.550145864486694,6.954936742782593,7.021528005599976,6.8889853954315186,6.315131902694702,6.095242261886597,10.108835935592651,6.357076406478882,7.430655241012573,6.448450803756714,7.0982749462127686,6.4298255443573,6.84966254234314,7.4488685131073,6.736899137496948,6.446202993392944,6.604178190231323,7.090266942977905,6.323678731918335,6.3112428188323975,6.28790545463562,6.598105192184448,6.373114347457886,6.431990385055542,6.2564332485198975,7.246878385543823,7.38981032371521,6.214457273483276,6.802284002304077,7.536787748336792,6.634334325790405,6.960758924484253,6.608479261398315,6.47892165184021,7.350800275802612,6.118292570114136,6.7180116176605225,6.2622621059417725,6.792338132858276,6.398637533187866,6.530109167098999,6.031400442123413,6.639722585678101,7.0282533168792725,6.078140020370483,6.023447751998901,6.666522741317749,6.641700506210327,6.949050664901733,6.748633146286011,6.9641406536102295,6.412523984909058,5.949262380599976,6.476415395736694,6.683224439620972,6.442394018173218,6.793895483016968,6.41438364982605,6.0649449825286865,5.552734136581421
|
4 |
+
meta-llama/Llama-2-7b-hf vs huggyllama/llama-7b,0.99,10.224812746047974,9.681059122085571,9.79641842842102,9.143628358840942,8.902554750442505,8.581402063369751,9.70039963722229,8.519185304641724,9.58426308631897,8.995373964309692,9.070337533950806,8.949082612991333,9.619084596633911,9.29295563697815,9.440018892288208,9.70763897895813,9.497541666030884,9.849196672439575,9.820314645767212,9.195354700088501,9.313969850540161,9.354400873184204,9.663116693496704,9.564915895462036,8.243352174758911,8.451400995254517,9.354562044143677,9.870220422744751,8.735538721084595,8.980741739273071,9.208681344985962,9.174977540969849,9.0218985080719,9.463035821914673,8.296002626419067,8.97838044166565,9.827713251113892,8.70076298713684,9.400816202163696,9.546863794326782,8.84546971321106,8.821017503738403,9.502385377883911,9.446989297866821,9.236077547073364,8.610085725784302,8.864148378372192,8.536124467849731,8.738917589187622,9.773377656936646,9.669754266738892,9.055151224136353,8.594401597976685,8.988518953323364,9.397208452224731,8.896766901016235,9.421325922012329,9.172155618667603,8.620713472366333,9.184304475784302,9.626896142959595,8.850933313369751,9.680114984512329,8.70331883430481,9.014384508132935,9.21541714668274,9.671802759170532,8.865594148635864,8.037382364273071,8.281428575515747,8.900898218154907,8.31492829322815,8.984730005264282,8.803801774978638,8.632014513015747,8.735124826431274,8.547569513320923,8.782721757888794,8.775543451309204,9.653815507888794,8.892679452896118,9.198973894119263,9.146803140640259,9.016650438308716,8.820218324661255,9.266303300857544,10.214689493179321,8.883363008499146,9.254101037979126,8.992080926895142,9.562584161758423,9.456084489822388,8.60515809059143,9.141083002090454,8.885345697402954,9.704216241836548,8.497487306594849,8.425892114639282,8.096463441848755,9.380361795425415,9.764080286026001
|
5 |
+
meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.5,0.01,-0.048818111419677734,9.844830870628357,9.222102522850037,7.755062460899353,8.265713095664978,9.013955473899841,9.327821135520935,9.123439192771912,10.4435476064682,10.395680785179138,9.183204054832458,9.165072798728943,9.568477988243103,10.169446349143982,10.477945685386658,9.97354257106781,9.749969840049744,8.718470931053162,9.266455054283142,8.84190022945404,11.006096243858337,8.343526244163513,10.545372366905212,10.468126654624939,11.031267523765564,8.920008063316345,10.785050749778748,10.329222083091736,8.974294066429138,9.432084441184998,10.732448935508728,10.256052374839783,10.556835532188416,8.810878157615662,8.968798995018005,9.491623282432556,8.424806952476501,8.465443968772888,10.03244435787201,9.752468466758728,9.78839910030365,7.284629225730896,7.684324622154236,8.825533270835876,8.247802138328552,8.404919981956482,7.69716489315033,9.403735518455505,9.254828810691833,17.508076071739197,9.856340765953064,9.115726828575134,8.986517310142517,7.983991026878357,8.65701711177826,8.509859442710876,8.433510184288025,10.006139159202576,7.988286375999451,10.91031014919281,8.459594130516052,10.134988188743591,8.992516875267029,9.247894644737244,9.8466295003891,10.350879073143005,8.965574622154236,8.360766768455505,8.565540671348572,9.279970526695251,10.484707236289978,9.109469771385193,9.356310248374939,9.331755995750427,9.77147901058197,8.712384581565857,9.521596312522888,8.963900923728943,6.773027777671814,10.8277028799057,9.908670783042908,9.5965017080307,7.920130133628845,9.397729277610779,8.551138281822205,8.605363249778748,8.879813551902771,9.17821729183197,8.353206038475037,9.68158757686615,9.635061621665955,9.397002577781677,9.553449034690857,9.10447347164154,7.935643553733826,9.673168540000916,9.363352179527283,9.41775357723236,8.803840041160583,8.780441641807556,10.165222525596619
|
6 |
+
meta-llama/Llama-2-7b-hf vs EleutherAI/llemma_7b,0.01,0.44821369647979736,10.260111451148987,7.788133263587952,9.624961495399475,10.045899987220764,7.3227468729019165,8.576054215431213,8.724260926246643,7.721198678016663,9.501295685768127,9.792026162147522,8.662120461463928,9.696664452552795,9.72442877292633,10.198912262916565,7.810206055641174,7.485549569129944,10.641865372657776,7.784129738807678,10.568729996681213,9.2378431558609,8.15902864933014,7.804042458534241,8.089019417762756,8.3358074426651,8.909724831581116,8.492894768714905,8.694433808326721,8.972641587257385,10.243723511695862,7.02272093296051,9.183990120887756,7.371614098548889,7.766700387001038,7.575039505958557,7.943362832069397,7.547324776649475,9.165839791297913,8.906821846961975,8.444319367408752,7.65238630771637,8.072578072547913,8.772894501686096,8.609686493873596,9.942311882972717,8.98726236820221,8.986947655677795,7.657752633094788,7.5777870416641235,10.219603180885315,8.73482859134674,10.140299439430237,9.045146584510803,9.87434446811676,8.298058152198792,9.576706528663635,8.261414170265198,7.007310509681702,9.08191454410553,8.166513085365295,8.851847290992737,7.7886539697647095,8.51626741886139,10.726054787635803,7.131115555763245,8.866894364356995,8.980855584144592,10.537903428077698,8.707924485206604,9.1041761636734,8.35569155216217,7.90923273563385,7.558130860328674,8.815556168556213,9.084401726722717,10.736987709999084,7.036646485328674,8.363484978675842,8.093238472938538,10.898166298866272,9.313558220863342,8.597114205360413,8.617847084999084,8.128695130348206,9.646315217018127,13.442176461219788,10.456626534461975,9.014615654945374,8.619638085365295,9.138837456703186,9.612112641334534,7.738902688026428,7.245580315589905,7.807826638221741,10.858358025550842,8.273086190223694,8.117148041725159,10.628657937049866,7.735168099403381,9.400915741920471,8.798778176307678
|
7 |
+
meta-llama/Llama-2-7b-hf vs lmsys/vicuna-7b-v1.1,0.99,10.070954203605652,9.98596465587616,7.872992396354675,9.286036372184753,8.555332064628601,8.711309313774109,9.428610682487488,9.328728556632996,8.926531672477722,9.093182444572449,8.60245406627655,8.518800616264343,8.683110117912292,8.410743594169617,8.905894160270691,9.41090476512909,8.997403025627136,9.453501582145691,9.80195701122284,9.018198847770691,7.704877734184265,8.515411257743835,9.171860575675964,9.41734778881073,8.338318705558777,9.661895632743835,8.913832545280457,9.37841022014618,9.808799624443054,8.391430735588074,9.685980677604675,9.192590594291687,8.211149096488953,8.656869769096375,8.958290934562683,8.171902537345886,8.788528323173523,9.031162142753601,8.842985033988953,8.647503733634949,9.641290545463562,9.138723254203796,9.051908373832703,9.315989375114441,9.271025538444519,8.452067255973816,8.985319018363953,9.565594553947449,9.689806818962097,8.82832133769989,9.021748423576355,9.408458590507507,8.873647570610046,7.9212998151779175,9.294074892997742,9.887571215629578,9.798232913017273,8.214550852775574,9.443087458610535,9.211515307426453,8.687570452690125,9.068241953849792,9.206690669059753,10.350232005119324,9.743900179862976,9.317789912223816,9.163899302482605,9.024961352348328,9.765442728996277,9.013166308403015,8.392030596733093,9.235402941703796,9.014994502067566,9.726699709892273,9.839616656303406,8.32940948009491,8.018037676811218,9.81773555278778,9.059948801994324,8.405487895011902,8.710639834403992,9.514521479606628,9.64382255077362,9.51540744304657,8.620617747306824,8.76738440990448,9.260455965995789,9.645397067070007,9.607094645500183,9.834710955619812,9.383732676506042,9.244090914726257,9.373885989189148,9.480880618095398,8.626510500907898,8.69356620311737,8.515510439872742,9.324301600456238,9.481700778007507,9.493292689323425,8.596677660942078
|
8 |
+
meta-llama/Llama-2-7b-hf vs microsoft/Orca-2-7b,0.01,-0.10828566551208496,9.612163782119751,9.57808518409729,9.584277391433716,10.103969812393188,10.62393307685852,10.860812425613403,10.234191179275513,10.248055696487427,10.437435388565063,8.58587384223938,10.031322717666626,9.50733494758606,8.452850580215454,9.147834062576294,9.841037034988403,8.94861912727356,8.569215059280396,10.321378946304321,9.98995327949524,9.031814813613892,8.543767213821411,9.169170618057251,9.706740617752075,10.399054765701294,9.620095491409302,10.239917039871216,10.044366121292114,9.961036920547485,10.31088376045227,8.623070001602173,9.217479944229126,9.986067056655884,9.784968614578247,8.565213441848755,9.691814661026001,9.797849893569946,9.967728853225708,10.174472093582153,10.58090329170227,8.561283349990845,10.293337106704712,10.875295877456665,10.058345079421997,9.790831804275513,9.052335977554321,8.621185541152954,8.913713693618774,9.126322984695435,9.985181093215942,10.12069058418274,10.351969003677368,8.859179735183716,8.868793725967407,9.429267168045044,10.062147378921509,10.541846513748169,10.34529709815979,7.989343881607056,9.149993181228638,10.005859613418579,9.094855546951294,8.563889741897583,10.673412561416626,9.303443193435669,9.09623646736145,9.98895001411438,10.891288995742798,9.313806772232056,10.10368275642395,8.619381189346313,9.195826768875122,9.001145601272583,9.836649179458618,7.904133081436157,9.086479425430298,8.314440965652466,9.347160577774048,8.852710008621216,8.854128122329712,9.808287858963013,16.34550404548645,9.892772912979126,8.387619256973267,10.956121683120728,9.00423550605774,10.616572618484497,9.067741632461548,11.200159311294556,8.27894139289856,10.221617937088013,9.443767786026001,9.392338037490845,9.064051866531372,9.397813081741333,8.125734567642212,9.21006989479065,9.003755807876587,9.544296503067017,11.196948289871216,9.715340852737427
|
9 |
+
meta-llama/Llama-2-7b-hf vs LLM360/Amber,0.11,7.894630134105682,7.741406142711639,9.741238296031952,8.099342048168182,8.7858607172966,8.8379265666008,8.040236175060272,8.228240668773651,8.622905433177948,8.404100120067596,8.37949150800705,9.686843574047089,8.975104987621307,9.110474288463593,9.258489310741425,9.860569655895233,9.684325873851776,9.21797913312912,8.408991515636444,7.970141112804413,9.09898155927658,8.439907729625702,7.791431128978729,9.548395812511444,7.885820090770721,8.11017769575119,8.54678601026535,8.84630936384201,10.097900092601776,9.345426261425018,8.844131171703339,8.529126822948456,8.822519958019257,7.397915542125702,8.911320388317108,8.177402198314667,8.834497153759003,7.470994651317596,8.041984260082245,9.084723174571991,8.611781775951385,9.349124610424042,8.71418446302414,8.455300033092499,8.083493888378143,7.502886474132538,9.542264640331268,9.237480819225311,8.16009110212326,8.309968650341034,8.668898284435272,9.937544524669647,7.961441695690155,7.76298588514328,9.243492782115936,8.276171386241913,9.355368316173553,9.561597526073456,8.683844268321991,8.951688468456268,8.0328728556633,9.128905951976776,8.684419333934784,8.949582755565643,8.869337737560272,9.141734778881073,8.778688132762909,9.47993153333664,9.777857482433319,9.390073478221893,8.643266379833221,9.68235844373703,8.913849532604218,8.41419380903244,8.62106865644455,9.428115546703339,8.091849982738495,8.390080153942108,9.340758979320526,8.255912482738495,8.813822448253632,7.776079833507538,9.499033629894257,8.122173964977264,7.885876357555389,8.421062171459198,7.326867759227753,8.530382812023163,8.047955214977264,8.490481078624725,9.764094054698944,8.352151572704315,8.893617331981659,8.05598133802414,8.670357406139374,9.10058468580246,8.72084778547287,10.240146338939667,9.423927962779999,8.297791182994843,9.481352508068085
|
10 |
+
codellama/CodeLlama-7b-hf vs openlm-research/open_llama_7b,0.65,6.28434944152832,6.726814270019531,6.4456682205200195,6.914477348327637,6.446785926818848,6.5881147384643555,5.242772102355957,6.475652694702148,6.909028053283691,5.528658866882324,6.032546043395996,6.016617774963379,6.414390563964844,6.1580705642700195,6.069611549377441,5.607878684997559,5.254672050476074,5.741635322570801,6.036998748779297,5.640376091003418,6.21879768371582,6.71235466003418,6.310402870178223,6.599245071411133,6.5538434982299805,6.383500099182129,6.12143611907959,5.841699600219727,5.703704833984375,5.595314979553223,6.069385528564453,6.012824058532715,6.329859733581543,5.355166435241699,5.893641471862793,5.896352767944336,6.014966011047363,6.355417251586914,6.048108100891113,5.895661354064941,5.779969215393066,5.906466484069824,6.556133270263672,5.586591720581055,5.98253059387207,6.519078254699707,6.325737953186035,6.963658332824707,6.446205139160156,5.63385009765625,5.955595016479492,6.959787368774414,8.960762023925781,6.463803291320801,5.850605010986328,6.110093116760254,6.135454177856445,6.526464462280273,5.895066261291504,6.240109443664551,5.885980606079102,5.778459548950195,6.233664512634277,5.499074935913086,4.950897216796875,6.281856536865234,6.805331230163574,5.393342018127441,5.523931503295898,5.5659027099609375,6.199000358581543,6.021541595458984,5.790318489074707,6.285762786865234,7.115248680114746,6.00765323638916,6.41680908203125,6.509481430053711,6.2540998458862305,5.8456830978393555,6.189613342285156,6.792692184448242,6.457563400268555,6.041304588317871,6.135087013244629,6.132525444030762,6.078264236450195,6.047064781188965,4.882900238037109,6.157533645629883,5.497096061706543,5.771068572998047,6.258846282958984,6.387505531311035,6.307802200317383,6.851279258728027,6.289949417114258,5.8791961669921875,5.868889808654785,5.754061698913574,5.872708320617676
|
11 |
+
codellama/CodeLlama-7b-hf vs huggyllama/llama-7b,0.32,8.647472620010376,8.376736879348755,8.435736894607544,9.107041597366333,8.864849328994751,9.41048264503479,10.012479066848755,9.569318056106567,8.44343113899231,9.787455797195435,9.312052965164185,9.43970513343811,9.141399621963501,9.065999269485474,8.278849840164185,8.499554872512817,8.162867784500122,8.299113512039185,8.742165803909302,9.889566659927368,8.649908304214478,9.01466679573059,9.214273691177368,9.222363710403442,8.458730936050415,8.305001497268677,8.368142366409302,9.659701585769653,8.178768396377563,9.37527585029602,10.060438394546509,8.972997903823853,9.059861421585083,9.093953371047974,7.9971373081207275,9.003615617752075,8.126897096633911,9.370399713516235,8.522400140762329,8.299326181411743,9.034609079360962,7.785470247268677,8.214834451675415,9.226133584976196,8.734915971755981,8.7814199924469,9.9114830493927,9.082973718643188,8.663569688796997,8.888867616653442,9.117319345474243,8.071345567703247,8.896063089370728,8.932055711746216,8.91256833076477,8.12684941291809,8.584501504898071,8.91443943977356,8.60889744758606,8.841772317886353,8.786389589309692,8.825255632400513,9.281825304031372,8.923148393630981,7.839233636856079,8.52420163154602,9.043694734573364,9.754061937332153,9.47466492652893,8.743067026138306,9.928423166275024,9.674275636672974,9.556930780410767,8.76394009590149,8.932903528213501,7.965190172195435,9.08448338508606,8.615834474563599,9.15774655342102,9.544827699661255,9.85620903968811,8.274455308914185,9.383400201797485,8.919788599014282,9.396048784255981,8.037643671035767,9.437472581863403,8.59937596321106,8.707654237747192,9.40676236152649,8.286210298538208,8.961514711380005,10.425405740737915,9.821033716201782,8.56540036201477,8.608500719070435,9.74540638923645,9.969642877578735,9.945183038711548,8.943975687026978,9.277392625808716
|
12 |
+
codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.5,0.01,0.23432707786560059,9.105018138885498,10.115467548370361,8.814990520477295,7.700256824493408,22.361277103424072,8.875674724578857,9.687096118927002,13.167479038238525,7.4309868812561035,8.258686542510986,8.119656085968018,8.539817333221436,8.664506435394287,7.635847568511963,8.880570888519287,7.720311641693115,8.838372707366943,8.217498302459717,8.588544368743896,8.979221820831299,8.800464153289795,8.409977436065674,7.696940898895264,7.538162708282471,7.685622692108154,8.631911754608154,7.75542688369751,8.455676555633545,7.691273212432861,7.6336798667907715,8.811938762664795,9.551165103912354,11.26979684829712,7.6999430656433105,9.915334224700928,8.53978681564331,8.212835788726807,8.059876918792725,10.90781545639038,9.512303829193115,9.135515689849854,8.30867338180542,9.355510234832764,9.14211130142212,7.2915568351745605,10.75891923904419,8.306049823760986,7.463253498077393,8.321387767791748,7.137866497039795,8.621911525726318,8.346582889556885,9.125063419342041,8.513661861419678,8.490179538726807,10.315661907196045,8.453839778900146,7.640648365020752,9.382608890533447,8.142853260040283,9.462652683258057,8.157150745391846,8.311485767364502,9.428282260894775,8.401495456695557,7.704698085784912,7.817162990570068,7.951792240142822,7.7705979347229,9.062580585479736,8.580884456634521,8.372841358184814,8.722021579742432,9.242526531219482,8.643452167510986,8.031785488128662,9.107877254486084,7.227905750274658,7.46142053604126,8.96703577041626,7.818650722503662,8.47803544998169,8.361773014068604,8.90479040145874,7.275203227996826,9.667508602142334,7.858919620513916,8.57040548324585,9.46554708480835,9.039218425750732,9.080006122589111,8.056274890899658,7.739555835723877,7.820621967315674,8.11035966873169,8.724180698394775,8.556443691253662,7.35528039932251,7.805674076080322,8.090892314910889
|
13 |
+
codellama/CodeLlama-7b-hf vs EleutherAI/llemma_7b,0.01,-0.01681309938430786,9.120851933956146,9.665036618709564,7.745120465755463,10.883652150630951,9.853213727474213,7.972514569759369,8.57659763097763,9.347029149532318,9.287005841732025,9.25390475988388,7.596251904964447,9.753281056880951,8.316048085689545,9.251168668270111,9.394730031490326,9.60946410894394,10.057454526424408,8.100693166255951,8.999764859676361,8.536845624446869,8.414287030696869,9.235208928585052,7.755367696285248,8.949576795101166,8.671410024166107,8.182545125484467,9.517559468746185,9.520708501338959,8.814814984798431,9.29341834783554,7.800792157649994,9.258494794368744,8.291480481624603,10.568865239620209,9.317304074764252,8.720015943050385,9.406411588191986,10.450869977474213,10.109208524227142,8.162058293819427,8.079206883907318,9.6403107047081,8.510849416255951,9.45523589849472,8.470562398433685,9.602908551692963,9.297125279903412,8.389980733394623,7.734573781490326,8.963873326778412,9.630400121212006,9.040441930294037,9.024840772151947,10.05105060338974,8.557877957820892,8.884402692317963,7.516786992549896,9.47091144323349,8.052890241146088,8.685519635677338,9.905793607234955,8.691942632198334,8.445288121700287,9.078912198543549,9.99776691198349,10.183870732784271,8.642778813838959,9.249837338924408,8.43747752904892,7.503961980342865,8.448286473751068,8.273894727230072,9.812999188899994,8.438606679439545,7.202361524105072,9.53622192144394,10.53255409002304,7.324675023555756,8.090178906917572,11.29394668340683,8.76332038640976,7.9841713309288025,8.38703578710556,7.210101544857025,9.739253461360931,7.38428920507431,8.221330106258392,9.497566640377045,9.075954854488373,7.801622807979584,8.816792905330658,8.747409284114838,7.322262227535248,9.282394826412201,9.361765325069427,8.516244351863861,8.374782025814056,9.28657478094101,9.686973989009857,8.96060699224472
|
14 |
+
codellama/CodeLlama-7b-hf vs lmsys/vicuna-7b-v1.1,0.29,8.48714804649353,7.6908605098724365,9.169391393661499,8.317209959030151,8.82749056816101,8.086182355880737,9.09686827659607,9.041928052902222,9.70714545249939,9.24581789970398,7.527576208114624,8.289576292037964,8.33409857749939,8.506857633590698,9.761435270309448,9.7164466381073,8.253876447677612,9.341957807540894,8.88274073600769,8.183127164840698,8.267747640609741,8.626987218856812,8.416558980941772,8.621223211288452,8.711838483810425,9.186389684677124,8.860594511032104,8.60040831565857,9.341449499130249,7.8805530071258545,8.886674642562866,8.76483416557312,8.315276861190796,8.604306936264038,8.908690214157104,8.503263235092163,8.62110686302185,8.35447382926941,8.634252309799194,9.194954633712769,9.794084310531616,9.426853895187378,9.282697439193726,8.876146078109741,8.379112958908081,9.456719160079956,8.24547266960144,8.695784330368042,8.297900915145874,9.174300909042358,10.507225751876831,8.244601011276245,9.66327166557312,8.527663946151733,8.712950468063354,7.743478536605835,9.500287771224976,8.695890188217163,8.144980192184448,8.653694868087769,9.294896841049194,8.294621229171753,9.95811915397644,9.066579580307007,7.893991231918335,8.774755239486694,8.692675352096558,7.662358999252319,9.419709920883179,9.247517347335815,8.725831747055054,8.96029543876648,7.903769254684448,8.607434034347534,8.779548406600952,9.245018720626831,8.981712102890015,8.89933466911316,8.861811399459839,9.083378553390503,9.699528455734253,8.974799871444702,8.601711988449097,9.232177495956421,9.49698805809021,9.961381673812866,8.806665182113647,8.881787061691284,9.232394933700562,9.254537343978882,9.025410413742065,8.972266912460327,8.399003744125366,7.797715902328491,9.680047750473022,9.935441732406616,8.279030561447144,7.591404676437378,7.754696607589722,9.866612195968628,8.337414503097534
|
15 |
+
codellama/CodeLlama-7b-hf vs microsoft/Orca-2-7b,0.01,0.20641469955444336,7.8634960651397705,8.904095888137817,9.61886715888977,8.437413454055786,8.772413492202759,9.826401948928833,8.376365900039673,9.295730829238892,7.479389429092407,7.212454080581665,9.409053087234497,8.811966180801392,8.886850595474243,9.277664422988892,8.518949747085571,8.85485291481018,7.306167840957642,9.17977261543274,8.63290524482727,8.338998079299927,7.619590997695923,7.116956949234009,9.088074922561646,9.55984616279602,7.769827127456665,12.575409173965454,8.546103715896606,9.712836503982544,7.821888208389282,7.985309839248657,9.255531549453735,7.9532740116119385,7.89363694190979,7.9602367877960205,7.962735414505005,8.382404565811157,7.323387384414673,8.479003190994263,7.25046181678772,8.708838701248169,8.418090105056763,8.20142388343811,7.759202241897583,8.463444948196411,8.459122896194458,8.245249032974243,8.69249176979065,14.59964394569397,8.798161745071411,7.301157236099243,9.396774530410767,8.621342897415161,10.132902383804321,9.146990060806274,8.679113626480103,8.614110231399536,8.71687912940979,9.0999596118927,9.935490846633911,7.709607362747192,7.654601335525513,8.510257005691528,8.947322130203247,8.230319261550903,6.711144685745239,8.664664506912231,8.041719675064087,8.087208032608032,8.947721719741821,8.407228708267212,6.83827805519104,8.23217511177063,9.082698106765747,8.295105218887329,7.719668626785278,7.801557779312134,8.987715005874634,9.564089059829712,9.21347451210022,8.516844987869263,7.161568880081177,8.31154465675354,7.943972826004028,8.84707760810852,9.559814691543579,8.488112688064575,8.300562143325806,8.24126935005188,8.901121377944946,9.877440690994263,7.255166292190552,6.881604433059692,8.9941246509552,8.821150064468384,8.297075510025024,8.37979245185852,8.204619646072388,7.644574403762817,7.633875131607056,8.807678461074829
|
16 |
+
codellama/CodeLlama-7b-hf vs LLM360/Amber,0.01,7.748504936695099,9.06743460893631,8.908285439014435,8.819458305835724,9.088988602161407,8.924072563648224,9.067676842212677,8.016331970691681,8.311434090137482,8.955159485340118,8.229758560657501,8.987535774707794,8.484479248523712,8.56284362077713,7.941815674304962,9.349307358264923,9.409474670886993,9.941840469837189,8.861625015735626,8.79028731584549,9.348574936389923,8.276139557361603,8.158338844776154,8.984058678150177,9.270368874073029,9.068104088306427,10.039808571338654,9.142476379871368,9.175580322742462,8.77158671617508,9.13890391588211,8.617131531238556,8.621155083179474,8.881712257862091,9.085198700428009,8.689877808094025,8.795815765857697,8.682710945606232,8.883796989917755,8.753254234790802,9.923146545886993,8.774359047412872,9.62900859117508,8.872906982898712,9.16351443529129,8.674702942371368,8.880166351795197,9.373622238636017,8.218859016895294,9.922911942005157,8.901104271411896,9.061072647571564,8.830475151538849,9.16731196641922,9.990583717823029,9.197730362415314,9.108112633228302,10.07734328508377,9.169017136096954,8.402928650379181,9.248014748096466,9.330892860889435,9.563519775867462,8.773622810840607,9.222309410572052,9.123565971851349,9.193527519702911,9.795744240283966,8.590337097644806,8.715575516223907,8.95899611711502,8.898348152637482,9.765960037708282,8.962839424610138,9.214798271656036,9.854464828968048,8.187191307544708,9.656366646289825,8.432507812976837,9.170356094837189,8.251133263111115,8.479723274707794,8.784158051013947,9.559178650379181,8.12833720445633,8.339381515979767,8.23774367570877,8.49364310503006,8.596848785877228,8.26347666978836,9.288618385791779,8.338898003101349,8.665286362171173,7.99944144487381,9.6977179646492,9.06463748216629,8.797506630420685,9.591453850269318,9.648471176624298,9.975923836231232,8.205338776111603
|
17 |
+
openlm-research/open_llama_7b vs huggyllama/llama-7b,0.12,5.4373085498809814,5.617760896682739,5.587035417556763,5.687114000320435,6.008775949478149,5.774935960769653,5.945592164993286,6.301852464675903,5.327899217605591,5.804341554641724,6.269682168960571,6.641139268875122,5.6906960010528564,6.195895433425903,6.718469858169556,5.616427659988403,6.246436357498169,5.6004555225372314,5.429034471511841,6.112356424331665,5.798109292984009,6.217177629470825,5.898507356643677,5.735835313796997,4.97048020362854,5.659003496170044,6.196565866470337,6.84828782081604,6.0469372272491455,5.969996690750122,6.734667062759399,5.8872339725494385,6.937445878982544,6.536187410354614,6.107362985610962,5.402880907058716,5.715508699417114,5.3803627490997314,5.959226846694946,6.169707536697388,6.227089166641235,6.530186891555786,5.589359521865845,6.085469484329224,5.880691766738892,5.3078672885894775,6.431185960769653,5.2033703327178955,6.39737343788147,6.516474962234497,6.659247636795044,6.701551675796509,6.254176378250122,5.632610559463501,6.0807716846466064,5.7362635135650635,5.585143327713013,5.939406633377075,6.253710985183716,5.834253549575806,5.83292031288147,6.809473276138306,5.713587045669556,5.961418390274048,6.384552240371704,5.854504823684692,6.26802659034729,5.167945146560669,6.388922929763794,6.145580530166626,6.14031720161438,6.181988954544067,5.45749306678772,5.194824457168579,6.006340265274048,6.111417055130005,6.233436822891235,5.969449281692505,6.450243234634399,5.394938707351685,6.070707559585571,6.392692804336548,5.4279396533966064,5.901009798049927,5.893953561782837,6.187520265579224,5.83727765083313,6.07111668586731,5.614530801773071,6.282021760940552,6.858328104019165,5.672479867935181,6.526853799819946,5.725656747817993,6.860419511795044,5.698136568069458,6.598180055618286,5.921528100967407,5.906771898269653,5.56449818611145,5.829131364822388
|
18 |
+
openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.5,0.68,6.308373689651489,6.700233697891235,5.703813791275024,6.338608026504517,6.307419061660767,6.586034059524536,5.478386163711548,6.177409410476685,6.018020868301392,5.776783227920532,6.389678239822388,5.805647134780884,5.913630723953247,6.293874025344849,7.018000841140747,7.00434422492981,5.98361611366272,5.66058087348938,5.6123316287994385,6.112520456314087,5.813434839248657,6.909172296524048,6.756714105606079,5.090618371963501,7.784255266189575,6.0108559131622314,5.270795106887817,5.621241807937622,4.6691734790802,5.558912515640259,5.659086465835571,6.354078531265259,6.808579683303833,6.601501703262329,5.424909830093384,5.758757829666138,6.722477197647095,6.334706544876099,5.572609186172485,6.24731183052063,5.827046632766724,5.98691201210022,6.036614656448364,6.069903612136841,5.4608423709869385,5.692006349563599,6.020584344863892,5.6468565464019775,5.900907754898071,7.260846376419067,6.558199167251587,6.274706125259399,6.207753419876099,5.571796655654907,5.839945077896118,6.207461595535278,6.332349061965942,4.937416315078735,5.79715371131897,6.523751497268677,5.60066819190979,7.146387338638306,5.899955987930298,6.562173128128052,6.155851602554321,6.521467447280884,6.11050820350647,6.363598108291626,6.629664659500122,5.639268159866333,6.237446069717407,5.771808862686157,5.830493211746216,6.909991502761841,4.765170335769653,5.74653172492981,5.671989679336548,5.345279932022095,5.3964550495147705,6.0652687549591064,6.395521402359009,6.9056923389434814,6.126155138015747,5.419970750808716,6.917824983596802,6.50818657875061,6.250625848770142,16.261486291885376,5.867288827896118,5.998341798782349,6.283872842788696,5.926321268081665,6.6489622592926025,6.117601633071899,6.7576258182525635,4.966566324234009,6.151142358779907,6.079829454421997,6.484721422195435,5.017580270767212,6.228436708450317
|
19 |
+
openlm-research/open_llama_7b vs EleutherAI/llemma_7b,0.74,6.918841361999512,5.942992210388184,6.163137435913086,6.270599365234375,7.409475326538086,6.998303413391113,6.928138732910156,5.826279640197754,6.324603080749512,6.316580772399902,6.851133346557617,7.621212005615234,5.953413963317871,5.799161911010742,6.694608688354492,6.517380714416504,6.376747131347656,6.4730377197265625,7.693846702575684,6.17234992980957,6.822030067443848,6.326215744018555,7.2796173095703125,6.853416442871094,6.927913665771484,6.259148597717285,6.455480575561523,7.025678634643555,6.28238582611084,6.102893829345703,6.6217851638793945,6.489897727966309,6.100579261779785,6.702361106872559,7.050620079040527,6.681787490844727,7.272147178649902,6.204999923706055,6.434914588928223,6.304933547973633,6.49997615814209,6.801599502563477,6.930009841918945,5.9951887130737305,7.54071044921875,6.3290557861328125,7.4027557373046875,6.631084442138672,5.962194442749023,6.994401931762695,6.317986488342285,6.697565078735352,6.738523483276367,6.1151275634765625,6.463567733764648,7.279179573059082,6.964632034301758,7.170134544372559,5.57750129699707,5.836142539978027,6.106198310852051,7.078838348388672,7.218414306640625,6.934094429016113,6.083596229553223,5.987371444702148,6.6132307052612305,7.24505615234375,6.67485237121582,7.4629106521606445,6.294131278991699,6.789360046386719,6.168454170227051,7.057206153869629,6.778372764587402,6.746638298034668,6.772141456604004,6.247314453125,6.332786560058594,6.7579450607299805,6.38361930847168,6.634665489196777,6.584917068481445,6.887727737426758,6.343280792236328,6.899263381958008,6.330324172973633,6.626551628112793,6.960097312927246,6.3871002197265625,6.660849571228027,6.618146896362305,5.781418800354004,6.620124816894531,7.311498641967773,7.206951141357422,6.33872127532959,5.888164520263672,6.724948883056641,6.355386734008789,6.677863121032715
|
20 |
+
openlm-research/open_llama_7b vs lmsys/vicuna-7b-v1.1,0.12,5.251814126968384,5.051405191421509,5.824950456619263,6.114632844924927,5.42397141456604,5.6496946811676025,6.120051622390747,5.661573648452759,5.406200647354126,5.948969125747681,5.6591761112213135,5.098428964614868,6.186973810195923,6.260297060012817,4.845706224441528,4.884040117263794,5.762598276138306,5.379200220108032,6.526299715042114,6.481151819229126,5.83853554725647,6.3441002368927,5.409104585647583,5.687176942825317,5.74416184425354,5.246052980422974,6.2644360065460205,5.886080026626587,5.6273205280303955,6.853506326675415,5.799962282180786,5.687295198440552,5.733844995498657,6.048799753189087,6.1102516651153564,5.856144189834595,5.675837755203247,5.673197031021118,5.446083307266235,5.110049486160278,5.814363718032837,6.427519083023071,6.298197984695435,5.46152138710022,6.262089967727661,6.41163182258606,5.403126955032349,5.201777696609497,6.494671106338501,5.8217175006866455,6.150409936904907,5.3533384799957275,5.698808908462524,6.203578233718872,5.529521226882935,6.30396294593811,4.761819124221802,6.134600877761841,5.632828950881958,5.592910051345825,5.776437997817993,5.856722116470337,5.320695161819458,5.844919443130493,5.342267274856567,5.450078248977661,6.006361246109009,6.269325494766235,6.200906038284302,6.103886842727661,5.622957468032837,6.478121995925903,5.759417772293091,5.9912941455841064,6.174035310745239,6.112441301345825,6.027667284011841,5.732034921646118,5.58466649055481,5.922117471694946,5.360791444778442,5.733863115310669,5.565796136856079,6.27460789680481,6.186408281326294,4.860639810562134,5.562293291091919,5.67807412147522,5.985955476760864,5.920683145523071,4.785085916519165,6.239168405532837,5.6897499561309814,5.570624589920044,5.4800636768341064,5.736876726150513,5.4539244174957275,6.388970613479614,4.825183153152466,6.031313180923462,5.904541254043579
|
21 |
+
openlm-research/open_llama_7b vs microsoft/Orca-2-7b,0.85,6.374301195144653,5.607383966445923,5.797471284866333,5.860095262527466,6.8239147663116455,6.35334038734436,6.313669443130493,5.884981393814087,6.407116174697876,5.7687084674835205,6.412864923477173,6.109200716018677,6.0385754108428955,4.771707773208618,5.541899919509888,6.475514650344849,5.574916124343872,6.94892144203186,5.697839975357056,6.276600122451782,5.68853497505188,4.96396279335022,14.194698572158813,5.249680757522583,7.378007173538208,6.040998697280884,5.913078546524048,5.410505533218384,6.165990114212036,7.0394127368927,5.487897157669067,5.749399423599243,6.100099802017212,5.42592453956604,6.552445650100708,5.945457696914673,5.909960985183716,6.225774049758911,5.548428773880005,5.035020112991333,4.851449251174927,5.882075548171997,5.815673112869263,5.656095743179321,5.143633127212524,5.5314247608184814,7.116902589797974,5.827336549758911,5.751171350479126,5.695185899734497,7.8084423542022705,4.529286623001099,6.1526501178741455,6.381459474563599,6.09170937538147,4.367375612258911,5.150029420852661,5.458413362503052,5.681887865066528,6.342478036880493,5.972796678543091,5.907988786697388,5.2621166706085205,5.635527849197388,6.17734169960022,5.835489511489868,5.093684434890747,5.344476938247681,6.326663255691528,6.0996363162994385,5.540684938430786,4.868671655654907,5.586103677749634,5.079303026199341,6.067350625991821,6.356841325759888,5.464815378189087,5.622609376907349,5.330131769180298,6.660903215408325,5.466338396072388,5.834389925003052,5.860307931900024,5.607694864273071,5.605406999588013,5.080667734146118,5.097392320632935,6.129028558731079,6.217902421951294,5.536783456802368,5.054259538650513,4.665910959243774,6.816373109817505,5.676759958267212,5.222558259963989,6.3869335651397705,5.894263505935669,6.307274103164673,4.8874828815460205,5.601005792617798,5.1678307056427
|
model-tracing/scripts/docs/doc_trace.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
|
3 |
+
import os
|
4 |
+
from datasets import load_dataset
|
5 |
+
from tqdm import tqdm
|
6 |
+
import math
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import csv
|
9 |
+
from utils import interpolate_models
|
10 |
+
import time
|
11 |
+
import copy
|
12 |
+
import argparse
|
13 |
+
import glob
|
14 |
+
|
15 |
+
|
16 |
+
block_size = 512


def group_texts(examples):
    """Concatenate a batch of tokenized texts and split it into fixed-size blocks.

    Args:
        examples: Mapping from feature name (e.g. ``"input_ids"``,
            ``"attention_mask"``) to a list of per-example token lists.

    Returns:
        A dict with the same keys as ``examples`` whose values are lists of
        chunks of exactly ``block_size`` items, plus a ``"labels"`` entry that
        is a shallow copy of the ``"input_ids"`` chunk list.

    Raises:
        KeyError: If ``examples`` has no ``"input_ids"`` feature.
    """
    # Flatten every feature in one linear pass. ``sum(lists, [])`` rebuilds the
    # accumulator on each addition, which is quadratic in the batch length.
    concatenated_examples = {
        k: [token for sequence in examples[k] for token in sequence]
        for k in examples.keys()
    }
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder; padding could be added instead if the model
    # supported it.
    total_length = (total_length // block_size) * block_size
    # Split into chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result
|
33 |
+
|
34 |
+
|
35 |
+
def main(args):
    """Measure perplexity along linear interpolations between model pairs.

    For every ``*.json`` file in the directory selected by ``args.test_name``
    the text is tokenized and grouped into ``block_size``-token blocks. Each
    pair in ``model_pairs`` is then interpolated at several alphas; the mean
    loss over all batches is converted to a perplexity, appended to a per-file
    CSV, and plotted to a PNG in the results directory.

    Args:
        args: Parsed command-line arguments; only ``args.test_name`` is read.
    """
    # Kept separate from the per-alpha timer so the final report really covers
    # the whole run.
    total_start_time = time.time()
    # Automatically detect CUDA device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    os.environ["WANDB_MODE"] = "disabled"

    # Load models and tokenizer
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    model_list = [
        "meta-llama/Llama-2-7b-hf",
        "codellama/CodeLlama-7b-hf",
        "lmsys/vicuna-7b-v1.5",
        "EleutherAI/llemma_7b",
        "LLM360/Amber",
    ]
    model_pairs = [
        (0, 2),  # LLama2, Vicuna-1.5
        (0, 1),  # LLama2, CodeLlama
        (0, 3),  # LLama2, Lemma
        (1, 3),  # CodeLlama, Lemma
        (0, 4),  # LLama2, Amber
    ]
    models = [
        AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
        for model_name in model_list
    ]
    tokenizer = AutoTokenizer.from_pretrained(models[0].config._name_or_path)
    tokenizer.pad_token = tokenizer.eos_token

    # Columns removed from the dataset once the text has been tokenized.
    columns_ignored = [
        "text",
        "added",
        "id",
        "lang",
        "metadata",
        "source",
        "timestamp",
        "subdomain",
    ]
    # Scan the directory for JSON files based on the test name argument
    json_dir = f"/juice4/scr4/nlp/model-tracing/dolma_program_languages/json_files_{args.test_name}"
    json_files = glob.glob(f"{json_dir}/*.json")
    save_dir = f"/juice4/scr4/nlp/model-tracing/dolma_program_languages/results_{args.test_name}"
    os.makedirs(save_dir, exist_ok=True)

    for json_file in json_files:
        print(f"Processing {json_file}")

        # Prepare dataset
        eval_dataset = load_dataset("json", data_files=json_file)

        def tokenize_function(examples):
            return tokenizer(examples["text"])

        tokenized_datasets = eval_dataset.map(
            tokenize_function, batched=True, num_proc=4, remove_columns=columns_ignored
        )
        lm_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
            batch_size=1,
            num_proc=1,
        )

        # group_texts drops incomplete blocks, so a file can end up with no
        # rows at all; skip it instead of dividing by zero further down.
        if len(lm_datasets["train"]) == 0:
            print(f"Skipping {json_file}: no complete blocks of {block_size} tokens")
            continue

        # Prepare for evaluation. Batch size is optimized for ~7B model
        training_args = TrainingArguments(
            output_dir="./results",
            per_device_eval_batch_size=3,
            do_eval=True,
            report_to=None,
            dataloader_num_workers=4,
            use_cpu=True,
        )
        alphas = [0.0, 0.3, 0.5, 0.7, 1.0]
        # The Trainer is only used here to build the evaluation dataloader.
        model = copy.deepcopy(models[0])
        trainer = Trainer(model=model, args=training_args, eval_dataset=lm_datasets)
        print("create data loader")
        eval_dataloader = trainer.get_test_dataloader(lm_datasets["train"])

        for idx_a, idx_b in tqdm(model_pairs, desc="Model Interpolation"):
            model_a = models[idx_a]
            model_b = models[idx_b]
            perplexities = []
            model_a_name = model_a.config._name_or_path.split("/")[-1]
            model_b_name = model_b.config._name_or_path.split("/")[-1]

            for alpha in tqdm(
                alphas, desc=f" \n Alpha Perplexities for {model_a_name} and {model_b_name}"
            ):
                interpolated_model = interpolate_models(model_a, model_b, alpha)
                # NOTE: .half() casts the weights to float16 (not bfloat16)
                # before the model is moved to the evaluation device.
                interpolated_model = interpolated_model.half().to(device)

                alpha_start_time = time.time()
                losses = []

                # Evaluation only: no autograd graph is needed for the loss.
                with torch.no_grad():
                    for batch in tqdm(eval_dataloader, desc=f"\n Evaluating {alpha}"):
                        # Move the batch tensors to the evaluation device.
                        input_ids = batch["input_ids"].to(device)
                        attention_mask = batch["attention_mask"].to(device)
                        labels = batch["labels"].to(device)
                        outputs = interpolated_model(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                            labels=labels,
                        )
                        loss = outputs.loss
                        losses.append(loss.item())

                loss_mean = sum(losses) / len(losses)
                print(f"Loss mean: {loss_mean}")
                alpha_end_time = time.time()
                execution_time = alpha_end_time - alpha_start_time
                print(f"Execution time base: {execution_time} seconds")

                perplexity = math.exp(loss_mean)
                perplexities.append(perplexity)

                # Move the model back to CPU
                interpolated_model.to("cpu")

                # Drop the last references to device tensors, then clear the
                # GPU cache.
                del interpolated_model, input_ids, attention_mask, labels, outputs, loss
                torch.cuda.empty_cache()

            # Save perplexities and model names to CSV
            json_filename = os.path.splitext(os.path.basename(json_file))[0]
            csv_filename = f"perplexities_{json_filename}.csv"
            csv_full_path = f"{save_dir}/{csv_filename}"
            csv_header = ["Model Pair"] + [f"Alpha {alpha}" for alpha in alphas]
            # Write the header once; later model pairs append rows to the file.
            if not os.path.exists(csv_full_path):
                with open(csv_full_path, "w", newline="") as csvfile:
                    writer = csv.writer(csvfile)
                    writer.writerow(csv_header)

            with open(csv_full_path, "a", newline="") as csvfile:
                writer = csv.writer(csvfile)
                model_pair = f"{model_a_name} vs {model_b_name}"
                row = [model_pair] + perplexities
                writer.writerow(row)

            # Create the plot
            plt.figure(figsize=(8, 6))
            plt.plot(alphas, perplexities)
            plt.xlabel("Alpha")
            plt.ylabel("Perplexity")
            plt.title(f"{model_a_name} (Left) vs {model_b_name} (Right)")

            # Save the plot as a PNG file
            plot_filename = (
                f"alpha_vs_perplexity_{model_a_name}_vs_{model_b_name}_{json_filename}.png"
            )
            plot_full_path = f"{save_dir}/{plot_filename}"
            plt.savefig(plot_full_path, dpi=300, bbox_inches="tight")
            plt.close()

    total_end_time = time.time()
    execution_time = total_end_time - total_start_time
    print(f"Total execution time: {execution_time} seconds")
|
196 |
+
|
197 |
+
|
198 |
+
if __name__ == "__main__":
    # Script entry point: read the test name from the command line and run
    # the interpolation evaluation for it.
    parser = argparse.ArgumentParser(description="Model Interpolation")
    parser.add_argument(
        "--test_name",
        type=str,
        default="js",
        help="Test name (e.g., cpp, python, js)",
    )
    args = parser.parse_args()
    main(args)
|
model-tracing/scripts/docs/launch.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from yaml import safe_load
|
2 |
+
import subprocess
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
|
6 |
+
|
7 |
+
def load_yaml(file_path):
    """Parse the YAML file at ``file_path`` and return the loaded object."""
    with open(file_path, "r") as stream:
        parsed = safe_load(stream)
    return parsed
|
10 |
+
|
11 |
+
|
12 |
+
def main(args):
    """Launch one job per (dataset, model architecture) combination.

    Reads the YAML configuration at ``args.config``, creates the save, log and
    results directories, and runs one launcher command per combination of
    ``config["datasets"]`` and ``config["model_architectures"]``.

    Args:
        args: Parsed command-line arguments providing ``config`` (YAML path),
            ``slurm`` (launcher command prefix) and ``script`` (Python script
            to run).
    """
    # Local import: only needed here to quote the wrapped command safely.
    import shlex

    # Load configurations
    config = load_yaml(args.config)

    # Create necessary directories
    os.makedirs(config["save_dir"], exist_ok=True)
    os.makedirs(f"{config['save_dir']}/logs", exist_ok=True)
    os.makedirs(f"{config['save_dir']}/results", exist_ok=True)

    # Prepare base command
    # NOTE(review): the launcher prefix is part of this command and is also
    # prepended again in full_cmd below, so it appears twice in what gets run.
    # Kept as-is; confirm whether the inner prefix is intended.
    base_cmd = f"{args.slurm} python {args.script}"

    # Launch jobs
    for dataset in config["datasets"]:
        for model_arch in config["model_architectures"]:
            job_id = f"{dataset}_{model_arch}"
            log_path = f"{config['save_dir']}/logs/{job_id}.out"
            results_path = f"{config['save_dir']}/results/{job_id}"

            cmd = (
                f"{base_cmd} "
                f"--model_arch {model_arch} "
                f"--test_name {dataset} "
                f"--save_dir {results_path}"
            )

            # Pick the JSON directory from the dataset family named in the
            # dataset string; other datasets get no --json_dir argument.
            if "dolma" in dataset.lower():
                cmd += f" --json_dir {config['dolma_json_dir']}/{dataset}"
            elif "m2d2" in dataset.lower():
                cmd += f" --json_dir {config['m2d2_json_dir']}/{dataset}"

            # shlex.quote passes the wrapped command as a single shell word
            # even when it contains quote characters.
            full_cmd = f"{args.slurm} -o {log_path} -J {job_id} {shlex.quote(cmd)}"
            print(f"Launching job: {full_cmd}")
            completed = subprocess.run(full_cmd, shell=True)
            # Report failures instead of silently moving on to the next job.
            if completed.returncode != 0:
                print(f"Job {job_id} exited with code {completed.returncode}")
|
46 |
+
|
47 |
+
|
48 |
+
if __name__ == "__main__":
    # Script entry point: collect the launcher options and start the jobs.
    parser = argparse.ArgumentParser(
        description="Launch interpolation experiments on SLURM"
    )
    parser.add_argument(
        "--config",
        default="config.yaml",
        help="Path to YAML configuration file",
    )
    parser.add_argument(
        "--slurm",
        default="srun --partition=your-partition --time=24:00:00 --mem=64G --gres=gpu:1",
        help="SLURM command",
    )
    parser.add_argument(
        "--script",
        default="interpolation_script.py",
        help="Python script to run",
    )
    args = parser.parse_args()
    main(args)
|
model-tracing/scripts/docs/m2d_trace.py
ADDED
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
|
3 |
+
import itertools
|
4 |
+
import os
|
5 |
+
from datasets import load_dataset
|
6 |
+
from tqdm import tqdm
|
7 |
+
import math
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import csv
|
10 |
+
from utils import interpolate_models
|
11 |
+
import time
|
12 |
+
import argparse
|
13 |
+
import glob
|
14 |
+
import gc
|
15 |
+
|
16 |
+
# Script for running the ablation of tests on the m2d2 dataset rather than
# simply wikitext.
block_size = 2048


def group_texts(examples):
    """Concatenate a batch of tokenized texts and split it into fixed-size blocks.

    Args:
        examples: Mapping from feature name (e.g. ``"input_ids"``,
            ``"attention_mask"``) to a list of per-example token lists.

    Returns:
        A dict with the same keys as ``examples`` whose values are lists of
        chunks of exactly ``block_size`` items, plus a ``"labels"`` entry that
        is a shallow copy of the ``"input_ids"`` chunk list.

    Raises:
        KeyError: If ``examples`` has no ``"input_ids"`` feature.
    """
    # Flatten every feature in one linear pass. ``sum(lists, [])`` rebuilds the
    # accumulator on each addition, which is quadratic in the batch length.
    concatenated_examples = {
        k: [token for sequence in examples[k] for token in sequence]
        for k in examples.keys()
    }
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder; padding could be added instead if the model
    # supported it.
    total_length = (total_length // block_size) * block_size
    # Split into chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result
|
37 |
+
|
38 |
+
|
39 |
+
def load_model(model_name):
    """Load the causal language model ``model_name`` with bfloat16 weights."""
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
    )
    return model
|
41 |
+
|
42 |
+
|
43 |
+
def main(args):
    """Measure interpolation perplexity for all model pairs on m2d2 subsets.

    For every subset folder and every ``*.json`` file in it, the text is
    tokenized and grouped into ``block_size``-token blocks. Every unordered
    pair of models for the chosen architecture is interpolated at several
    alphas; the mean loss over all batches is converted to a perplexity,
    appended to a per-file CSV, and plotted to a PNG.

    Args:
        args: Parsed command-line arguments; only ``args.model_arch`` is read.

    Raises:
        ValueError: If ``args.model_arch`` is not a supported architecture.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    os.environ["WANDB_MODE"] = "disabled"
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    model_arch = args.model_arch
    if model_arch == "llama":
        model_list = [
            "meta-llama/Llama-2-7b-hf",
            "meta-llama/Llama-2-7b-chat-hf",
            "meta-llama/CodeLlama-7b-Python-hf",
            "meta-llama/CodeLlama-7b-Instruct-hf",
            "codellama/CodeLlama-7b-hf",
            "lmsys/vicuna-7b-v1.5",
            "lmsys/vicuna-7b-v1.1",
            "EleutherAI/llemma_7b",
            "LLM360/Amber",
        ]
    elif model_arch == "olmo":
        model_list = [
            "/scr/ahmedah/olmo/step1000_4B_tokens/seed_0_4B",
            "/scr/ahmedah/olmo/step1000_4B_tokens/seed_42_4B",
        ]
    else:
        # Fail with a clear message instead of an unbound-variable error below.
        raise ValueError(f"Unsupported model architecture: {model_arch}")

    tokenizer = AutoTokenizer.from_pretrained(model_list[0])
    tokenizer.pad_token = tokenizer.eos_token

    test_cases = [
        {
            "test_name": folder_name,
            "json_dir": f"/juice4/scr4/nlp/model-tracing/m2d2_s2orc/{folder_name}",
            "save_dir": f"/juice4/scr4/nlp/model-tracing/m2d2_s2orc/results_{folder_name}",
            "columns_ignored": ["text", "added", "id", "source", "timestamp", "subdomain"],
        }
        for folder_name in [
            "AI",
            "CV",
            "ET",
            "IM",
            "mtrl-sci",
            "stat-mech",
            "AR",
            "CY",
            "IR",
            "NA",
            "str-el",
            "art",
            "DB",
            "FL",
            "supr-con",
            "CC",
            "DC",
            "GA",
            "LG",
            "phil",
            "CE",
            "dis-nn",
            "GL",
            "LO",
            "CG",
            "DL",
            "GR",
            "MA",
            "quant-gas",
            "CL",
            "DM",
            "GT",
            "mes-hall",
            "CO",
            "DS",
            "HC",
            "MM",
            "soft",
            "CR",
            "EP",
            "HE",
            "MS",
            "SR",
        ]
    ]

    for test_case in test_cases:
        test_name = test_case["test_name"]
        json_dir = test_case["json_dir"]
        save_dir = test_case["save_dir"]
        # NOTE(review): the per-test-case "columns_ignored" entry (which also
        # lists "timestamp") is never read; this hard-coded list is what is
        # actually removed. Confirm which list matches the data.
        columns_ignored = ["text", "added", "id", "source", "subdomain"]

        json_files = glob.glob(f"{json_dir}/*.json")
        os.makedirs(save_dir, exist_ok=True)

        for json_file in json_files:
            print(f"Processing {json_file}")

            eval_dataset = load_dataset("json", data_files=json_file)

            def tokenize_function(examples):
                return tokenizer(examples["text"])

            tokenized_datasets = eval_dataset.map(
                tokenize_function, batched=True, num_proc=4, remove_columns=columns_ignored
            )
            lm_datasets = tokenized_datasets.map(
                group_texts,
                batched=True,
                batch_size=1000,
                num_proc=8,
            )

            # group_texts drops incomplete blocks, so a file can end up with
            # no rows at all; skip it instead of dividing by zero further down.
            if len(lm_datasets["train"]) == 0:
                print(f"Skipping {json_file}: no complete blocks of {block_size} tokens")
                continue

            training_args = TrainingArguments(
                output_dir="./hf_results",
                per_device_eval_batch_size=15,
                do_eval=True,
                report_to=None,
                dataloader_num_workers=8,
                use_cpu=True,
            )
            alphas = [0.0, 0.3, 0.5, 0.7, 1.0]
            # The Trainer is only used here to build the evaluation dataloader.
            initial_model = load_model(model_list[0])
            trainer = Trainer(model=initial_model, args=training_args, eval_dataset=lm_datasets)
            eval_dataloader = trainer.get_test_dataloader(lm_datasets["train"])
            # The Trainer keeps its own reference to the model, so both names
            # must be dropped before the weights can be reclaimed.
            del initial_model, trainer
            gc.collect()

            # Every unordered pair of (index, name) entries, first index lower.
            model_pairs = list(itertools.combinations(enumerate(model_list), 2))

            base_dir = f"{save_dir}/{test_name}"
            os.makedirs(base_dir, exist_ok=True)
            imgs_dir = os.path.join(base_dir, "imgs")
            os.makedirs(imgs_dir, exist_ok=True)
            csv_dir = os.path.join(base_dir, "csv")
            os.makedirs(csv_dir, exist_ok=True)

            # Loaded models are cached between pairs and reloaded only when
            # the requested name changes.
            current_model_a, current_model_b = None, None
            current_model_a_name, current_model_b_name = None, None

            for (idx_a, model_a_name), (idx_b, model_b_name) in tqdm(
                model_pairs, desc="Model Interpolation"
            ):
                perplexities = []

                if current_model_a is None or current_model_a_name != model_a_name:
                    if current_model_a is not None:
                        del current_model_a
                        torch.cuda.empty_cache()
                    current_model_a = load_model(model_a_name).to("cpu")
                    current_model_a_name = model_a_name

                if current_model_b is None or current_model_b_name != model_b_name:
                    if current_model_b is not None:
                        del current_model_b
                        torch.cuda.empty_cache()
                    current_model_b = load_model(model_b_name).to("cpu")
                    current_model_b_name = model_b_name

                with torch.no_grad():
                    for alpha in tqdm(
                        alphas,
                        desc=f" \n Alpha Perplexities for {model_a_name} and {model_b_name}",
                    ):
                        interpolated_model = interpolate_models(
                            current_model_a, current_model_b, alpha, model_arch=model_arch
                        )
                        # NOTE: .half() casts the weights to float16 before the
                        # model is moved to the evaluation device.
                        interpolated_model = interpolated_model.half().to(device)

                        start_time = time.time()
                        losses = []

                        for batch in tqdm(eval_dataloader, desc=f"\n Evaluating {alpha}"):
                            input_ids = batch["input_ids"].to(device)
                            attention_mask = batch["attention_mask"].to(device)
                            labels = batch["labels"].to(device)

                            outputs = interpolated_model(
                                input_ids=input_ids,
                                attention_mask=attention_mask,
                                labels=labels,
                            )
                            loss = outputs.loss
                            losses.append(loss.item())

                        loss_mean = sum(losses) / len(losses)
                        print(f"Loss mean: {loss_mean}")
                        end_time = time.time()
                        execution_time = end_time - start_time
                        print(f"Execution time base: {execution_time} seconds")

                        perplexity = math.exp(loss_mean)
                        perplexities.append(perplexity)

                        # Drop the last references to device tensors before
                        # clearing the GPU cache.
                        interpolated_model.to("cpu")
                        del interpolated_model, input_ids, attention_mask, labels, outputs, loss
                        torch.cuda.empty_cache()
                        gc.collect()

                # Short names for output only; the full names stay untouched
                # so the model-cache comparison above keeps working.
                short_name_a = model_a_name.split("/")[-1]
                short_name_b = model_b_name.split("/")[-1]
                json_filename = os.path.splitext(os.path.basename(json_file))[0]
                csv_filename = f"{csv_dir}/perplexities_{json_filename}.csv"
                csv_header = ["Model Pair"] + [f"Alpha {alpha}" for alpha in alphas]

                # Write the header once; later model pairs append rows.
                if not os.path.exists(csv_filename):
                    with open(csv_filename, "w", newline="") as csvfile:
                        writer = csv.writer(csvfile)
                        writer.writerow(csv_header)

                with open(csv_filename, "a", newline="") as csvfile:
                    writer = csv.writer(csvfile)
                    model_pair = f"{short_name_a} vs {short_name_b}"
                    row = [model_pair] + perplexities
                    writer.writerow(row)

                plt.figure(figsize=(8, 6))
                plt.plot(alphas, perplexities)
                plt.xlabel("Alpha")
                plt.ylabel("Perplexity")
                plt.title(f"{short_name_a} (Left) vs {short_name_b} (Right)")

                plot_filename = (
                    f"alpha_vs_perplexity_{short_name_a}_vs_{short_name_b}_{json_filename}.png"
                )
                plot_path = f"{imgs_dir}/{plot_filename}"
                plt.savefig(plot_path, dpi=300, bbox_inches="tight")
                plt.close()
|
268 |
+
|
269 |
+
|
270 |
+
if __name__ == "__main__":
    # Script entry point: choose the model architecture family and run the
    # m2d2 interpolation evaluation.
    parser = argparse.ArgumentParser(description="Model Interpolation")
    parser.add_argument(
        "--model_arch",
        default="llama",
        choices=["llama", "olmo"],
        help="default model architecture to use",
    )
    args = parser.parse_args()
    main(args)
|
model-tracing/scripts/mode/main.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import gc
|
3 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
|
4 |
+
import itertools
|
5 |
+
import os
|
6 |
+
from datasets import load_dataset
|
7 |
+
from tqdm import tqdm
|
8 |
+
import math
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import csv
|
11 |
+
from utils import interpolate_models
|
12 |
+
import time
|
13 |
+
import argparse
|
14 |
+
|
15 |
+
|
16 |
+
block_size = 512


def group_texts(examples):
    """Pack a batch of tokenized examples into fixed-length blocks.

    Every column of ``examples`` (a mapping from column name to a list of
    token lists) is flattened into one long list and then cut into
    consecutive chunks of ``block_size`` tokens. The trailing remainder that
    does not fill a whole block is dropped.

    Args:
        examples: Mapping of column name -> list of per-example lists.
            Must contain an ``"input_ids"`` column.

    Returns:
        dict: The chunked columns plus a ``"labels"`` column that is a
        shallow copy of the chunked ``"input_ids"`` list.
    """
    # Flatten each column into a single stream of tokens.
    flattened = {
        column: list(itertools.chain.from_iterable(examples[column]))
        for column in examples.keys()
    }

    # All columns are assumed to have the same length; measure the first one
    # and round down to a whole number of blocks.
    first_column = list(examples.keys())[0]
    usable_length = (len(flattened[first_column]) // block_size) * block_size
    chunk_starts = range(0, usable_length, block_size)

    result = {}
    for column, stream in flattened.items():
        result[column] = [stream[begin : begin + block_size] for begin in chunk_starts]

    # Causal LM training uses the inputs themselves as labels.
    result["labels"] = result["input_ids"].copy()
    return result
|
33 |
+
|
34 |
+
|
35 |
+
def load_model(model_name):
    """Load a causal language model checkpoint with bfloat16 weights.

    Args:
        model_name: Hub identifier or local path passed to
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        The loaded model instance.
    """
    load_options = {"torch_dtype": torch.bfloat16}
    return AutoModelForCausalLM.from_pretrained(model_name, **load_options)
|
37 |
+
|
38 |
+
|
39 |
+
def main(args):
    """Evaluate perplexity along linear interpolations between model pairs.

    For every unordered pair of models in the architecture-specific model
    list, the two checkpoints are interpolated at each value in ``alphas``,
    the interpolated model is evaluated on the wikitext test split, and the
    resulting perplexities are appended to a CSV file and plotted to a PNG
    under ``<cwd>/results``.

    Args:
        args: Parsed command-line namespace. Reads ``args.model_arch``
            ("llama" or "olmo") and ``args.dataset`` (only "wikitext" is
            accepted).

    Raises:
        ValueError: If ``args.dataset`` is anything other than "wikitext".
    """
    # Automatically detect CUDA device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    os.environ["WANDB_MODE"] = "disabled"

    # Load models and tokenizer
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # NOTE(review): model_list is only bound for "llama" and "olmo"; the
    # argparse `choices` in the entry point is what keeps other values out.
    model_arch = args.model_arch
    if model_arch == "llama":
        model_list = [
            "meta-llama/Llama-2-7b-hf",
            "codellama/CodeLlama-7b-hf",
            "openlm-research/open_llama_7b",
            "huggyllama/llama-7b",
            "lmsys/vicuna-7b-v1.5",
            "EleutherAI/llemma_7b",
            "lmsys/vicuna-7b-v1.1",
            "microsoft/Orca-2-7b",
            "LLM360/Amber",
        ]
    elif model_arch == "olmo":
        model_list = [
            "/scr/ahmedah/olmo/step1000_4B_tokens/seed_0_4B",
            "/scr/ahmedah/olmo/step1000_4B_tokens/seed_42_4B",
        ]

    # Load tokenizer (the first model's tokenizer is used for every model)
    tokenizer = AutoTokenizer.from_pretrained(model_list[0])
    tokenizer.pad_token = tokenizer.eos_token

    # Prepare dataset
    if args.dataset == "wikitext":
        eval_dataset = load_dataset("dlwh/wikitext_103_detokenized", split="test")
        columns_ignored = ["text"]
    else:
        raise ValueError("main.py only supports wikitext.")

    def tokenize_function(examples):
        # Tokenize the raw "text" column of a batch.
        return tokenizer(examples["text"])

    tokenized_datasets = eval_dataset.map(
        tokenize_function, batched=True, num_proc=4, remove_columns=columns_ignored
    )
    # Re-chunk the tokenized documents into fixed-size blocks (see group_texts).
    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        batch_size=1,
        num_proc=1,
    )

    # Prepare for evaluation. Batch size is optimized for ~7B model
    training_args = TrainingArguments(
        output_dir="./hf_results",
        per_device_eval_batch_size=3,
        do_eval=True,
        report_to=None,
        dataloader_num_workers=4,
        use_cpu=True,
    )
    alphas = [0.0, 0.3, 0.5, 0.7, 1.0]
    # Load an initial model to create the trainer and dataloader
    # (the Trainer is only used to build the evaluation dataloader; the
    # model itself is dropped again right away).
    initial_model = load_model(model_list[0])
    trainer = Trainer(model=initial_model, args=training_args, eval_dataset=lm_datasets)
    eval_dataloader = trainer.get_test_dataloader(lm_datasets)
    del initial_model

    # Enumerate every unordered pair of (index, model name) to interpolate between
    model_pairs = list(itertools.combinations(enumerate(model_list), 2))

    # create directories for results
    base_dir = f"{os.getcwd()}/results"
    os.makedirs(base_dir, exist_ok=True)
    imgs_dir = os.path.join(base_dir, "imgs")
    os.makedirs(imgs_dir, exist_ok=True)
    csv_dir = os.path.join(base_dir, "csv")
    print(csv_dir)
    os.makedirs(csv_dir, exist_ok=True)

    # Cache of the currently loaded endpoints so that a model shared by
    # consecutive pairs is not reloaded.
    current_model_a, current_model_b = None, None
    current_model_a_name, current_model_b_name = None, None

    for (idx_a, model_a_name), (idx_b, model_b_name) in tqdm(
        model_pairs, desc="Model Interpolation"
    ):
        # itertools.combinations over an enumerate already yields idx_a < idx_b.
        if idx_a < idx_b:
            perplexities = []

            # (Re)load endpoint A only when it differs from the cached one.
            if current_model_a is None or current_model_a_name != model_a_name:
                if current_model_a is not None:
                    del current_model_a
                    torch.cuda.empty_cache()
                current_model_a = load_model(model_a_name).to("cpu")
                current_model_a_name = model_a_name

            # (Re)load endpoint B only when it differs from the cached one.
            if current_model_b is None or current_model_b_name != model_b_name:
                if current_model_b is not None:
                    del current_model_b
                    torch.cuda.empty_cache()
                current_model_b = load_model(model_b_name).to("cpu")
                current_model_b_name = model_b_name

            with torch.no_grad():
                for alpha in tqdm(
                    alphas, desc=f" \n Alpha Perplexities for {model_a_name} and {model_b_name}"
                ):

                    # Build the interpolated model from the two CPU-resident
                    # endpoints, then cast to half precision and move it to
                    # the evaluation device.
                    interpolated_model = interpolate_models(
                        current_model_a, current_model_b, alpha, model_arch=model_arch
                    )
                    interpolated_model = interpolated_model.half().to(device)

                    start_time = time.time()
                    losses = []

                    for batch in tqdm(eval_dataloader, desc=f"\n Evaluating {alpha}"):
                        input_ids = batch["input_ids"].to(device)
                        attention_mask = batch["attention_mask"].to(device)
                        labels = batch["labels"].to(device)

                        outputs = interpolated_model(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                            labels=labels,
                        )
                        loss = outputs.loss
                        losses.append(loss.item())

                    # Unweighted mean of the per-batch losses.
                    # NOTE(review): raises ZeroDivisionError if the dataloader
                    # yields no batches.
                    loss_mean = sum(losses) / len(losses)
                    print(f"Loss mean: {loss_mean}")
                    end_time = time.time()
                    execution_time = end_time - start_time
                    print(f"Execution time base: {execution_time} seconds")

                    perplexity = math.exp(loss_mean)
                    perplexities.append(perplexity)

                    # Move the model back to CPU
                    interpolated_model.to("cpu")

                    # Clear the GPU cache & collect free memory
                    del interpolated_model, input_ids, attention_mask, labels, outputs, loss
                    torch.cuda.empty_cache()
                    gc.collect()

            # split on HF org so we don't get accidental
            # directory error

            # (these rebind only the loop variables; the cached
            # current_model_*_name values keep the full identifiers)
            model_a_name = model_a_name.split("/")[-1]
            model_b_name = model_b_name.split("/")[-1]
            # Save perplexities and model names to CSV
            csv_filename = f"{csv_dir}/single_perplexities.csv"
            csv_header = ["Model Pair"] + [f"Alpha {alpha}" for alpha in alphas]

            # Write the header once, the first time the file is created.
            if not os.path.exists(csv_filename):
                with open(csv_filename, "w", newline="") as csvfile:
                    writer = csv.writer(csvfile)
                    writer.writerow(csv_header)

            with open(csv_filename, "a", newline="") as csvfile:
                writer = csv.writer(csvfile)
                model_pair = f"{model_a_name} vs {model_b_name}"
                row = [model_pair] + perplexities
                writer.writerow(row)

            # Create the plot
            plt.figure(figsize=(8, 6))
            plt.plot(alphas, perplexities)
            plt.xlabel("Alpha")
            plt.ylabel("Perplexity")
            plt.title(f"{model_a_name} (Left) vs {model_b_name} (Right)")

            # Save the plot as a PNG file
            plot_filename = f"single_alpha_vs_perplexity_{model_a_name}_vs_{model_b_name}.png"
            plot_path = f"{imgs_dir}/{plot_filename}"
            plt.savefig(plot_path, dpi=300, bbox_inches="tight")
            plt.close()
|
217 |
+
|
218 |
+
|
219 |
+
if __name__ == "__main__":
    # Command-line entry point: choose the dataset and model architecture,
    # then hand the parsed namespace to main().
    cli = argparse.ArgumentParser(description="Model Interpolation")
    cli.add_argument(
        "--dataset",
        choices=["wikitext", "json"],
        default="wikitext",
        help="Dataset to use",
    )
    arch_option = {
        "choices": ["llama", "olmo"],
        "default": "llama",
        "help": "default model architecture to use",
    }
    cli.add_argument("--model_arch", **arch_option)
    main(cli.parse_args())
|
model-tracing/scripts/mode/mode_connectivity_metrics.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import datetime
|
4 |
+
|
5 |
+
|
6 |
+
def plot_traces(
    results_path,
    metric,
    plot_path,
    model_a_name,
    model_b_name,
    unpermuted_res=False,
    normalize=True,
    alpha_step=0.1,
    end_points=True,
):
    """Plot one interpolation trace per CSV row and save the figure as a PNG.

    Each row of the CSV at ``results_path`` is sliced to the columns that hold
    the requested metric: for ``"perplexity"`` the first half of the columns
    after the two leading ones, for ``"loss"`` the trailing half. The slice is
    optionally normalized with ``normalize_trace`` and plotted against the
    alpha grid ``0, alpha_step, ..., 1``.

    Args:
        results_path: Path of the CSV file to read.
        metric: Either ``"loss"`` or ``"perplexity"``.
        plot_path: Output path prefix; a timestamp and ``.png`` are appended.
        model_a_name: Name shown for the left endpoint in the title.
        model_b_name: Name shown for the right endpoint in the title.
        unpermuted_res: Optional extra trace to overlay, or ``False`` for
            none. NOTE(review): it is always normalized, regardless of
            ``normalize`` -- confirm this is intended.
        normalize: Whether to normalize each row trace before plotting.
        alpha_step: Spacing of the alpha grid.
        end_points: If ``False``, the first and last alpha are dropped.

    Raises:
        ValueError: If ``metric`` is not ``"loss"`` or ``"perplexity"``.
            (Previously this surfaced as an UnboundLocalError at save time.)
    """
    if metric == "loss":
        y_label = "Loss"
    elif metric == "perplexity":
        y_label = "Perplexity"
    else:
        raise ValueError(f"Unsupported metric {metric!r}; expected 'loss' or 'perplexity'.")

    df = pd.read_csv(results_path)

    alphas = [round(alpha * alpha_step, 2) for alpha in range(int(1 / alpha_step + 1))]
    if end_points is False:
        alphas = alphas[1:-1]

    plt.figure(figsize=(8, 6))
    for _, row in df.iterrows():
        # Half of the columns that follow the two leading ones.
        half = (len(row) - 2) / 2
        if metric == "loss":
            row = row[int(len(row) - half) :]
        else:
            row = row[2 : int(2 + half)]
        if normalize:
            row = normalize_trace(row, alpha_step)
        plt.plot(alphas, row, "o-")

    plt.xlabel("Alpha")
    plt.ylabel(y_label)
    plt.title(f"{model_a_name} (Left) vs {model_b_name} (Right)")
    plot_filename = f"{plot_path}_{datetime.datetime.now().timestamp()}.png"

    if unpermuted_res is not False:
        plt.plot(alphas, normalize_trace(unpermuted_res, alpha_step))

    # Save the plot as a PNG file
    plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
|
59 |
+
|
60 |
+
|
61 |
+
def plot_trace(losses, alpha_step, normalize, model_a_name, model_b_name, plot_path):
    """Plot a single loss trace against the alpha grid and save it as a PNG.

    Args:
        losses: Sequence of loss values, one per alpha in
            ``0, alpha_step, ..., 1``.
        alpha_step: Spacing of the alpha grid.
        normalize: If truthy, the trace is passed through ``normalize_trace``
            before plotting.
        model_a_name: Name shown for the left endpoint in the title.
        model_b_name: Name shown for the right endpoint in the title.
        plot_path: Output path prefix; a timestamp and ``.png`` are appended.
    """
    plt.figure(figsize=(8, 6))

    trace = normalize_trace(losses, alpha_step) if normalize else losses
    n_points = int(1 / alpha_step + 1)
    grid = [round(step * alpha_step, 2) for step in range(n_points)]
    plt.plot(grid, trace, "o-")

    plt.xlabel("Alpha")
    plt.ylabel("Loss")
    plt.title(f"{model_a_name} (Left) vs {model_b_name} (Right)")

    stamp = datetime.datetime.now().timestamp()
    plt.savefig(f"{plot_path}_{stamp}.png", dpi=300, bbox_inches="tight")
    plt.close()
|
76 |
+
|
77 |
+
|
78 |
+
def normalize_trace(trace, alpha_step):
    """Remove the endpoint-to-endpoint linear trend from a trace.

    Each point ``i`` has ``slope * alpha_step * i`` and the starting value
    subtracted, where ``slope`` is ``trace[-1] - trace[0]``. When
    ``alpha_step * (len(trace) - 1) == 1`` both endpoints map to zero.

    The input is no longer modified in place: the result is written into a
    copy, so callers that reuse their original data are unaffected.

    Args:
        trace: Indexable sequence of numbers providing ``.copy()`` (for
            example a list, NumPy array, or pandas Series).
        alpha_step: Spacing between consecutive alpha values.

    Returns:
        A new container of the same type as ``trace`` holding the
        normalized values.
    """
    slope = trace[-1] - trace[0]
    start = trace[0]
    normalized = trace.copy()
    for i in range(len(trace)):
        normalized[i] = trace[i] - slope * alpha_step * i - start
    return normalized
|
85 |
+
|
86 |
+
|
87 |
+
def normalize_trace_2(trace, alphas):
    """Remove the endpoint-to-endpoint linear trend using explicit alphas.

    Same as ``normalize_trace`` except that the position of point ``i`` is
    taken from ``alphas[i]`` rather than from an evenly spaced grid, so
    non-uniform alpha grids are supported.

    The input is no longer modified in place: the result is written into a
    copy, so callers that reuse their original data are unaffected.

    Args:
        trace: Indexable sequence of numbers providing ``.copy()`` (for
            example a list, NumPy array, or pandas Series).
        alphas: Sequence with at least ``len(trace)`` interpolation
            coefficients.

    Returns:
        A new container of the same type as ``trace`` holding the
        normalized values.
    """
    slope = trace[-1] - trace[0]
    start = trace[0]
    normalized = trace.copy()
    for i in range(len(trace)):
        normalized[i] = trace[i] - slope * alphas[i] - start
    return normalized
|
94 |
+
|
95 |
+
|
96 |
+
def max_loss_ahmed(results_path, num_points=5, normalize=True, alphas=[0.0, 0.3, 0.5, 0.7, 1.0]):
    """Return the maximum of each row's trailing trace in a results CSV.

    Args:
        results_path: Path of the CSV file to read.
        num_points: Number of trailing columns of each row that form the trace.
        normalize: If truthy, each trace is passed through
            ``normalize_trace_2`` with ``alphas`` before taking the maximum.
        alphas: Interpolation coefficients matching the trace points.

    Returns:
        list: One maximum value per CSV row, in row order.
    """
    table = pd.read_csv(results_path)

    def _peak(record):
        tail = record[-num_points:]
        if normalize:
            tail = normalize_trace_2(tail, alphas)
        return max(tail)

    return [_peak(record) for _, record in table.iterrows()]
|
108 |
+
|
109 |
+
|
110 |
+
def max_loss_compare(results_path, unpermuted_loss, num_points, normalize=True, alpha_step=0.1):
    """Count permuted traces whose maximum exceeds the unpermuted maximum.

    Args:
        results_path: Path of a CSV file with one permuted trace per row; the
            trace is taken from the last ``num_points`` columns.
        unpermuted_loss: Trace of the unpermuted model pair. It is copied
            before normalization, so the caller's sequence is left untouched.
        num_points: Number of trailing columns that form each trace.
        normalize: If truthy, every trace (including the unpermuted one) is
            passed through ``normalize_trace`` first.
        alpha_step: Alpha spacing forwarded to ``normalize_trace``.

    Returns:
        tuple: ``(counter, total)`` where ``counter`` is the number of rows
        whose maximum is strictly greater than the unpermuted maximum and
        ``total`` is the number of rows.
    """
    df = pd.read_csv(results_path)

    permuted_max_losses = []
    for _, row in df.iterrows():
        row = row[-num_points:]
        if normalize:
            row = normalize_trace(row, alpha_step)
        permuted_max_losses.append(max(row))

    if normalize:
        # Work on a private copy so the caller's data is not rewritten.
        unpermuted_loss = normalize_trace(list(unpermuted_loss), alpha_step)
    unpermuted_max_loss = max(unpermuted_loss)

    counter = sum(1 for m in permuted_max_losses if m > unpermuted_max_loss)

    return counter, len(permuted_max_losses)
|
132 |
+
|
133 |
+
|
134 |
+
def avg_loss_compare(results_path, unpermuted_loss, num_points, normalize=True, alpha_step=0.1):
    """Count permuted traces whose mean exceeds the unpermuted mean.

    Args:
        results_path: Path of a CSV file with one permuted trace per row; the
            trace is taken from the last ``num_points`` columns.
        unpermuted_loss: Trace of the unpermuted model pair. It is copied
            before normalization, so the caller's sequence is left untouched.
        num_points: Number of trailing columns that form each trace.
        normalize: If truthy, every trace (including the unpermuted one) is
            passed through ``normalize_trace`` first.
        alpha_step: Alpha spacing forwarded to ``normalize_trace``.

    Returns:
        tuple: ``(counter, total)`` where ``counter`` is the number of rows
        whose mean is strictly greater than the unpermuted mean and ``total``
        is the number of rows.
    """
    df = pd.read_csv(results_path)

    permuted_avg_losses = []
    for _, row in df.iterrows():
        row = row[-num_points:]
        if normalize:
            row = normalize_trace(row, alpha_step)
        permuted_avg_losses.append(sum(row) / len(row))

    if normalize:
        # Work on a private copy so the caller's data is not rewritten.
        unpermuted_loss = normalize_trace(list(unpermuted_loss), alpha_step)
    unpermuted_avg_loss = sum(unpermuted_loss) / len(unpermuted_loss)

    counter = sum(1 for m in permuted_avg_losses if m > unpermuted_avg_loss)

    return counter, len(permuted_avg_losses)
|
156 |
+
|
157 |
+
|
158 |
+
def avg_loss_ahmed(results_path, num_points=5, normalize=True, alphas=[0.0, 0.3, 0.5, 0.7, 1.0]):
    """Return the mean of each row's trailing trace in a results CSV.

    Args:
        results_path: Path of the CSV file to read.
        num_points: Number of trailing columns of each row that form the trace.
        normalize: If truthy, each trace is passed through
            ``normalize_trace_2`` with ``alphas`` before averaging.
        alphas: Interpolation coefficients matching the trace points.

    Returns:
        list: One mean value per CSV row, in row order.
    """
    table = pd.read_csv(results_path)

    def _mean(record):
        tail = record[-num_points:]
        if normalize:
            tail = normalize_trace_2(tail, alphas)
        return sum(tail) / len(tail)

    return [_mean(record) for _, record in table.iterrows()]
|
170 |
+
|
171 |
+
|
172 |
+
def compute_p_value(counter, total):
    """Turn a ``(counter, total)`` comparison result into a p-value.

    Computes ``(total - counter - 1) / total`` and clamps the result at zero.
    Without the clamp, ``counter == total`` (every permuted statistic exceeds
    the unpermuted one) produced ``-1 / total``, which is not a valid
    probability.

    NOTE(review): the ``- 1`` / ``total`` form differs from the usual
    ``(k + 1) / (n + 1)`` permutation estimate; confirm the intended formula.

    Args:
        counter: Number of permuted statistics strictly greater than the
            unpermuted statistic.
        total: Total number of permuted statistics.

    Returns:
        float: The p-value, never below ``0.0``.

    Raises:
        ZeroDivisionError: If ``total`` is zero.
    """
    return max(0.0, (total - counter - 1) / total)
|
model-tracing/scripts/perm/main.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from tracing.perm.permute import permute_model
|
5 |
+
|
6 |
+
|
7 |
+
def main(base_model, ft_model, test_stat, num_perm, emb_dim=4096, mlp_dim=11008):
    """Run a permutation test of ``test_stat`` between two models.

    The statistic is first computed on the models as given. Then, ``num_perm``
    times, ``ft_model`` is permuted in place with freshly drawn random MLP and
    embedding permutations and the statistic is recomputed. The permutations
    accumulate on ``ft_model`` across iterations.

    Args:
        base_model: Reference model, left unmodified by this function.
        ft_model: Model that is repeatedly permuted in place.
        test_stat: Callable ``test_stat(base_model, ft_model)`` returning a
            number.
        num_perm: Number of random permutations to draw.
        emb_dim: Size of the embedding permutation.
        mlp_dim: Size of the MLP permutation.

    Returns:
        tuple: ``(exact, approx, unperm_stat, perm_stats)`` where ``exact``
        and ``approx`` come from ``p_value_exact`` and ``p_value_approx``.
    """
    observed = test_stat(base_model, ft_model)
    print(observed)

    null_stats = []
    for trial in range(num_perm):
        # Draw order (MLP first, then embedding) is kept so that seeded runs
        # reproduce the same permutations.
        mlp_perm = torch.randperm(mlp_dim)
        emb_perm = torch.randperm(emb_dim)
        permute_model(ft_model, mlp_perm, emb_perm)

        current = test_stat(base_model, ft_model)
        null_stats.append(current)
        print(trial, current)

    print(null_stats)

    # Each helper receives its own copy of the list.
    exact = p_value_exact(observed, list(null_stats))
    approx = p_value_approx(observed, list(null_stats))
    print(exact, approx)

    return exact, approx, observed, null_stats
|
33 |
+
|
34 |
+
|
35 |
+
def p_value_exact(unpermuted, permuted):
    """Empirical permutation p-value, rounded to two decimals.

    Args:
        unpermuted: Statistic observed on the unpermuted models.
        permuted: Iterable with ``len()`` of statistics from permuted models.

    Returns:
        float: ``(k + 1) / (n + 1)`` rounded to two decimals, where ``k`` is
        the number of permuted statistics strictly below ``unpermuted`` and
        ``n`` is ``len(permuted)``.
    """
    below = sum(1 for value in permuted if value < unpermuted)
    estimate = (below + 1) / (len(permuted) + 1)
    return round(estimate, 2)
|
41 |
+
|
42 |
+
|
43 |
+
def p_value_approx(unpermuted, permuted):
    """Return the z-score of ``unpermuted`` relative to ``permuted``.

    Despite the name, the returned value is the standardized distance
    ``(unpermuted - mean) / std`` and not a probability.

    Args:
        unpermuted: Statistic observed on the unpermuted models.
        permuted: Sequence of statistics from permuted models.

    Returns:
        The z-score, using the mean of ``permuted`` and ``np.std(permuted)``.
    """
    center = sum(permuted) / len(permuted)
    spread = np.std(permuted)
    return (unpermuted - center) / spread
|
model-tracing/scripts/robust/pythia.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import GPTNeoXForCausalLM, AutoTokenizer
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import pickle
|
6 |
+
import timeit
|
7 |
+
import subprocess
|
8 |
+
|
9 |
+
from tracing.utils.evaluate import prepare_hf_dataset, prepare_hf_dataloader, evaluate
|
10 |
+
from tracing.utils.utils import output_hook, get_submodule
|
11 |
+
|
12 |
+
# Script: load one Pythia training checkpoint, run it over an evaluation
# dataset while recording the outputs of selected sub-layers of one
# transformer block, and pickle the block's weights together with the
# recorded activations.
parser = argparse.ArgumentParser(description="Experiment Settings")

parser.add_argument("--model_id", default="EleutherAI/pythia-1.4b-deduped", type=str)
parser.add_argument("--step", default=0, type=int)
parser.add_argument("--layer", default=10, type=int)

parser.add_argument("--dataset_id", default="dlwh/wikitext_103_detokenized", type=str)
parser.add_argument("--block_size", default=512, type=int)
parser.add_argument("--batch_size", default=6, type=int)

parser.add_argument("--save", default="results.p", type=str)
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--token", default="", type=str)

args = parser.parse_args()

from huggingface_hub import login

# NOTE(review): with the default empty --token this still calls login("");
# confirm that is acceptable when no token is supplied.
login(token=args.token)

start = timeit.default_timer()

# Everything needed to reproduce the run is stored next to the results:
# the parsed arguments and the current git commit.
results = {}
results["args"] = args
results["commit"] = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()

torch.manual_seed(args.seed)

# The training step is selected through the Hub revision "step<N>".
model = GPTNeoXForCausalLM.from_pretrained(
    args.model_id,
    revision=f"step{args.step}",
)
tokenizer = AutoTokenizer.from_pretrained(
    args.model_id,
    revision=f"step{args.step}",
)

print("model loaded")

dataset = prepare_hf_dataset(args.dataset_id, args.block_size, tokenizer)
dataloader = prepare_hf_dataloader(dataset, args.batch_size)

print("dataset loaded")

block = get_submodule(model, f"gpt_neox.layers.{args.layer}")

# Register one forward hook per sub-layer of the chosen block. `layer` and
# `feats` are bound as lambda defaults so each hook keeps its own layer name
# instead of the loop's last value.
feats, hooks = {}, {}
for layer in [
    "input_layernorm",
    "post_attention_layernorm",
    "mlp.dense_h_to_4h",
    "mlp.dense_4h_to_h",
]:
    hooks[layer] = lambda m, inp, op, layer=layer, feats=feats: output_hook(
        m, inp, op, layer, feats
    )
    get_submodule(block, layer).register_forward_hook(hooks[layer])

print("hooks created")

evaluate(model, dataloader)

print("models evaluated")

end = timeit.default_timer()
results["time"] = end - start

results["weights"] = block.state_dict()
results["feats"] = feats

print(results)
# Use a context manager so the output file is flushed and closed
# deterministically (the previous bare open() leaked the handle).
with open(args.save, "wb") as save_file:
    pickle.dump(results, save_file)
|
model-tracing/tracing/__init__.py
ADDED
File without changes
|
model-tracing/tracing/perm/permute.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module for permuting weights in a Llama model architecture.
|
3 |
+
This enables exploring model connectivity and representation properties
|
4 |
+
by applying consistent neuron permutations throughout the network.
|
5 |
+
"""
|
6 |
+
|
7 |
+
|
8 |
+
def permute_model(model, mlp_permutation, emb_permutation, n_blocks=32):
    """Permute all weight groups of a Llama-style model in place.

    Runs, in order, the embedding-layer, transformer-block and output-layer
    permutation helpers on ``model``.

    Args:
        model: The model whose weights are permuted.
        mlp_permutation: Permutation indices for the MLP hidden dimension.
        emb_permutation: Permutation indices for the embedding dimension.
        n_blocks: Accepted for interface compatibility; not used by this
            function.

    Returns:
        None. ``model`` is modified in place.
    """
    steps = (
        (permute_embedding_layer, (emb_permutation,)),
        (permute_transformer_blocks, (mlp_permutation, emb_permutation)),
        (permute_output_layer, (emb_permutation,)),
    )
    for apply_step, permutations in steps:
        apply_step(model, *permutations)
|
24 |
+
|
25 |
+
|
26 |
+
def permute_transformer_blocks(model, mlp_permutation, emb_permutation):
    """Permute the per-block weights of a Llama-style model in place.

    Three passes are made over the model's state dict, selecting entries by
    substrings of their names:

    1. ``self_attn`` entries: ``o_proj`` has its rows reindexed by
       ``emb_permutation``; every other attention entry has its columns
       reindexed by ``emb_permutation``.
    2. Two-dimensional ``mlp`` entries: whichever axis has the length of
       ``mlp_permutation`` is reindexed by it, then whichever axis has the
       length of ``emb_permutation`` is reindexed by that.
       NOTE(review): if both permutations have the same length, the two
       steps act on the same axis -- confirm that case never occurs.
    3. One-dimensional ``model.layers`` entries whose length equals
       ``len(emb_permutation)`` (e.g. the layer norms) are reindexed by
       ``emb_permutation``.

    Args:
        model: Object exposing ``state_dict()`` and ``load_state_dict()``.
        mlp_permutation: Permutation indices for the MLP hidden dimension.
        emb_permutation: Permutation indices for the embedding dimension.

    Returns:
        None. The permuted state dict is loaded back into ``model``.
    """
    state = model.state_dict()
    mlp_size = len(mlp_permutation)
    emb_size = len(emb_permutation)

    # Pass 1: self-attention projections.
    for name in state:
        if "self_attn" in name:
            if "o_proj" in name:
                state[name] = state[name][emb_permutation]
            else:
                state[name] = state[name][:, emb_permutation]

    # Pass 2: MLP projection matrices (2-D tensors only).
    for name in state:
        if "mlp" not in name or len(state[name].shape) != 2:
            continue

        # Axis sizes are captured once; reindexing does not change them.
        n_rows = state[name].size(0)
        n_cols = state[name].size(1)

        if n_rows == mlp_size:
            state[name] = state[name][mlp_permutation]
        elif n_cols == mlp_size:
            state[name] = state[name][:, mlp_permutation]

        if n_rows == emb_size:
            state[name] = state[name][emb_permutation]
        elif n_cols == emb_size:
            state[name] = state[name][:, emb_permutation]

    # Pass 3: input_layernorm, post_attention_layernorm
    for name in state:
        if "model.layers" in name:
            is_emb_vector = len(state[name].shape) == 1 and len(state[name]) == emb_size
            if is_emb_vector:
                state[name] = state[name][emb_permutation]

    model.load_state_dict(state)
|
83 |
+
|
84 |
+
|
85 |
+
def permute_embedding_layer(model, emb_permutation):
    """Permute the columns of the token-embedding matrix in place.

    Args:
        model: Object exposing ``state_dict()`` and ``load_state_dict()``
            with a ``"model.embed_tokens.weight"`` entry.
        emb_permutation: Permutation indices for the embedding dimension.

    Returns:
        None. The updated state dict is loaded back into ``model``.
    """
    embed_key = "model.embed_tokens.weight"
    state = model.state_dict()

    # Reindex the second axis of the embedding matrix.
    embedding = state[embed_key]
    state[embed_key] = embedding[:, emb_permutation]

    model.load_state_dict(state)
|
100 |
+
|
101 |
+
|
102 |
+
def permute_output_layer(model, emb_permutation):
    """Permute the LM head and the final norm weights in place.

    The columns of ``"lm_head.weight"`` and the entries of
    ``"model.norm.weight"`` are reindexed by ``emb_permutation``.

    Args:
        model: Object exposing ``state_dict()`` and ``load_state_dict()``
            containing both entries named above.
        emb_permutation: Permutation indices for the embedding dimension.

    Returns:
        None. The updated state dict is loaded back into ``model``.
    """
    head_key = "lm_head.weight"
    norm_key = "model.norm.weight"
    state = model.state_dict()

    head = state[head_key]
    state[head_key] = head[:, emb_permutation]

    final_norm = state[norm_key]
    state[norm_key] = final_norm[emb_permutation]

    model.load_state_dict(state)
|
model-tracing/tracing/statistics/__init__.py
ADDED
File without changes
|
model-tracing/tracing/statistics/csh.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Implementation of Chi-Squared Hypothesis (CSH) test for comparing neural network models.
|
3 |
+
|
4 |
+
This module provides functions to test whether two models have similar activation patterns
|
5 |
+
across different layers using Chi-Squared statistical tests.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from collections import defaultdict
|
10 |
+
import scipy
|
11 |
+
import numpy as np
|
12 |
+
from scipy.stats import chi2
|
13 |
+
|
14 |
+
from scipy.optimize import linear_sum_assignment as LAP
|
15 |
+
|
16 |
+
from tracing.utils.utils import cossim
|
17 |
+
from tracing.utils.evaluate import evaluate
|
18 |
+
|
19 |
+
|
20 |
+
def statistic(base_model, ft_model, dataloader):
    """Entry point for the CSH test statistic between two models.

    Thin wrapper that forwards its arguments to ``csh_sp_dataloader`` with
    that function's default number of blocks.

    Args:
        base_model: Base model to compare.
        ft_model: Fine-tuned or target model compared against the base model.
        dataloader: DataLoader providing the inputs used to collect
            activations.

    Returns:
        Whatever ``csh_sp_dataloader`` returns (documented there as a
        ``(combined_p_value, p_values_per_layer)`` tuple).
    """
    result = csh_sp_dataloader(base_model, ft_model, dataloader)
    return result
|
33 |
+
|
34 |
+
|
35 |
+
def hook(m, inp, op, feats, name):
    """Forward hook that records a module's output.

    Args:
        m: Module being hooked (unused).
        inp: Input to the module (unused).
        op: Output tensor of the module.
        feats: Mapping from key to a list of recorded tensors.
        name: Key under which the output is appended.
    """
    # Detach from the autograd graph and move to CPU before storing.
    recorded = op.detach().cpu()
    feats[name].append(recorded)
|
47 |
+
|
48 |
+
|
49 |
+
def hook_in(m, inp, op, feats, name):
    """Forward hook that records a module's first positional input.

    Args:
        m: Module being hooked (unused).
        inp: Tuple of inputs to the module; only ``inp[0]`` is recorded.
        op: Output of the module (unused).
        feats: Mapping from key to a list of recorded tensors.
        name: Key under which the input is appended.
    """
    # Detach from the autograd graph and move to CPU before storing.
    first_input = inp[0]
    feats[name].append(first_input.detach().cpu())
|
61 |
+
|
62 |
+
|
63 |
+
def csh_sp_dataloader_block(base_model, ft_model, dataloader, i):
    """Apply the CSH test to a single transformer block.

    The outputs of ``model.layers[i].mlp.down_proj`` are recorded for both
    models over ``dataloader``. The recorded activations are stacked,
    flattened to two dimensions and transposed, then each row of the base
    matrix is matched to the ft row with the highest ``cossim`` score. The
    Spearman correlation between the matched indices and the identity
    ordering yields the returned p-value.

    The forward hooks are removed again before returning, so repeated calls
    no longer stack additional hooks (and their captured activations) on the
    models.

    Args:
        base_model: Base model to compare.
        ft_model: Fine-tuned or target model compared against the base model.
        dataloader: DataLoader providing the inputs used to collect
            activations.
        i: Index of the block to analyze.

    Returns:
        float: p-value of the Spearman correlation for block ``i``.
    """
    feats = defaultdict(list)

    base_hook = lambda *args: hook(*args, feats, "base")
    ft_hook = lambda *args: hook(*args, feats, "ft")

    handles = []
    try:
        handles.append(
            base_model.model.layers[i].mlp.down_proj.register_forward_hook(base_hook)
        )
        handles.append(
            ft_model.model.layers[i].mlp.down_proj.register_forward_hook(ft_hook)
        )

        evaluate(base_model, dataloader)
        evaluate(ft_model, dataloader)
    finally:
        # Always detach the hooks, even if evaluation fails.
        for handle in handles:
            handle.remove()

    base_mat = torch.vstack(feats["base"])
    ft_mat = torch.vstack(feats["ft"])

    # Collapse all leading dimensions, then transpose so that the last
    # activation dimension becomes the row axis.
    base_mat = base_mat.view(-1, base_mat.shape[-1]).T
    ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T

    # presumably cossim returns a pairwise similarity matrix -- TODO confirm
    matched = torch.argmax(cossim(base_mat, ft_mat), axis=-1)
    orig = torch.arange(len(matched))

    cor, pvalue = scipy.stats.spearmanr(matched.tolist(), orig.tolist())
    return pvalue
|
98 |
+
|
99 |
+
|
100 |
+
def csh_sp_dataloader(base_model, ft_model, dataloader, n_blocks=32):
    """
    Apply the CSH test across model blocks using activations from a dataloader.

    For each of the first ``n_blocks`` blocks:
    1. Collects ``mlp.up_proj`` activations from both models on the same data
    2. Computes an optimal one-to-one matching between neurons (linear
       assignment maximizing cosine similarity)
    3. Calculates the Spearman correlation between matched and original indices
    The per-block p-values are then combined with Fisher's method.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        dataloader: DataLoader providing input data for activation collection
        n_blocks: Number of transformer blocks to analyze (default: 32)

    Returns:
        tuple: (combined_p_value, list_of_p_values_per_layer)
    """
    feats = defaultdict(list)

    def _make_hook(key):
        # Bind the key per hook so each layer writes to its own slot.
        def _capture(m, inp, op):
            hook(m, inp, op, feats, key)

        return _capture

    # Track every handle so all hooks are detached once activations are
    # collected, even if evaluation raises; otherwise repeated calls would
    # stack hooks on the models.
    handles = []
    try:
        for i in range(n_blocks):
            handles.append(
                base_model.model.layers[i].mlp.up_proj.register_forward_hook(
                    _make_hook("base_" + str(i))
                )
            )
            handles.append(
                ft_model.model.layers[i].mlp.up_proj.register_forward_hook(
                    _make_hook("ft_" + str(i))
                )
            )

        evaluate(base_model, dataloader)
        evaluate(ft_model, dataloader)
    finally:
        for handle in handles:
            handle.remove()

    chi_squared = 0
    count = 0
    p_values = []

    for i in range(n_blocks):
        base_mat = torch.vstack(feats["base_" + str(i)])
        ft_mat = torch.vstack(feats["ft_" + str(i)])

        # Merge the first two axes, then transpose so rows are neurons.
        base_mat = torch.reshape(
            base_mat, (base_mat.shape[0] * base_mat.shape[1], base_mat.shape[2])
        ).T
        ft_mat = torch.reshape(
            ft_mat, (ft_mat.shape[0] * ft_mat.shape[1], ft_mat.shape[2])
        ).T

        _, matched = LAP(
            cossim(base_mat.type(torch.float64), ft_mat.type(torch.float64)), maximize=True
        )
        orig = torch.arange(len(matched))

        _, temp = scipy.stats.spearmanr(matched.tolist(), orig.tolist())

        # Only defined p-values contribute to Fisher's statistic.
        if not np.isnan(temp):
            chi_squared -= 2 * np.log(temp)
            count += 1
        # NOTE(review): assumes every block is printed and recorded (NaN
        # included), giving one entry per layer — confirm against the
        # original indentation.
        print(i, temp)
        p_values.append(temp)

    # Fisher's method: -2 * sum(log p) is chi-squared with 2k degrees of freedom.
    p_value = chi2.sf(chi_squared, df=2 * count)

    return p_value, p_values
|
model-tracing/tracing/statistics/csu.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Implementation of Cosine Similarity of Weights (CSW) test for comparing neural network models.
|
3 |
+
|
4 |
+
This module provides functions to test whether two models have similar weight matrices
|
5 |
+
using cosine similarity and statistical tests to quantify the similarity.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from tracing.utils.utils import cossim, fisher
|
11 |
+
import scipy
|
12 |
+
import numpy as np
|
13 |
+
from scipy.stats import chi2
|
14 |
+
|
15 |
+
from scipy.optimize import linear_sum_assignment as LAP
|
16 |
+
|
17 |
+
|
18 |
+
def statistic(base_model, ft_model):
    """
    Run the Cosine Similarity of Weights test on a pair of models.

    Thin entry point that forwards both models to ``csw_sp``.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model

    Returns:
        tuple: (aggregate_p_value, p_values_per_layer) as produced by ``csw_sp``
    """
    result = csw_sp(base_model, ft_model)
    return result
|
30 |
+
|
31 |
+
|
32 |
+
def csw_sp_layer(base_model, ft_model, layer_name):
    """
    Calculate Cosine Similarity of Weights for a specific layer.

    Uses linear assignment to find the optimal matching between neurons in the
    layer and a Spearman correlation between the matched and original indices
    to quantify similarity.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        layer_name: Name of the layer in both models' state dicts to analyze

    Returns:
        float: p-value indicating the statistical similarity of weight matrices
    """
    # A same-named layer is the special case of the pairwise test, so reuse it
    # rather than keeping a second copy of the matching logic.
    return csw_sp_pair(base_model, ft_model, layer_name, layer_name)
|
56 |
+
|
57 |
+
|
58 |
+
def csw_sp(model1, model2):
    """
    Apply the CSW test across all MLP up-projection layers in the models.

    Walks both state dicts in lockstep, tests every entry whose name contains
    "mlp.up_proj" and combines the resulting p-values with Fisher's method.

    Args:
        model1: First model to compare
        model2: Second model to compare

    Returns:
        tuple: (aggregate_p_value, list_of_p_values_per_layer)

    Raises:
        ValueError: If the two state dicts list differently named entries
    """
    names_1 = list(model1.state_dict().keys())
    names_2 = list(model2.state_dict().keys())

    chi_squared = 0
    num_layers = 0
    p_values = []

    for name1, name2 in zip(names_1, names_2):
        if name1 != name2:
            raise ValueError(f"Model parameter names do not match: {name1} != {name2}")
        if "mlp.up_proj" not in name1:
            continue

        pvalue = csw_sp_layer(model1, model2, name1)
        # NOTE(review): assumes NaN p-values are neither counted nor recorded
        # — confirm against the original indentation.
        if not np.isnan(pvalue):
            chi_squared -= 2 * np.log(pvalue)
            num_layers += 1
            p_values.append(pvalue)

        print(name1, pvalue)

    # Fisher's method: -2 * sum(log p) is chi-squared with 2k degrees of freedom.
    aggregate_pvalue = chi2.sf(chi_squared, df=2 * num_layers)
    return aggregate_pvalue, p_values
|
93 |
+
|
94 |
+
|
95 |
+
def csw_sp_pair(base_model, ft_model, layer_name_base, layer_name_ft):
    """
    Calculate Cosine Similarity of Weights between two named layers.

    Works like csw_sp_layer, except that the layer may be named differently
    in each model.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        layer_name_base: Name of the layer in the base model's state dict
        layer_name_ft: Name of the layer in the fine-tuned model's state dict

    Returns:
        float: p-value indicating the statistical similarity of weight matrices
    """
    weights_base = base_model.state_dict()[layer_name_base]
    weights_ft = ft_model.state_dict()[layer_name_ft]

    # Cast to float64 before computing the similarity matrix.
    similarity = cossim(weights_base.type(torch.float64), weights_ft.type(torch.float64))

    # Keep only the second array returned by the assignment solver.
    _, assignment = LAP(similarity, maximize=True)
    identity = torch.arange(len(assignment))

    _, pvalue = scipy.stats.spearmanr(assignment.tolist(), identity.tolist())
    return pvalue
|
119 |
+
|
120 |
+
|
121 |
+
def statistic_all(base_model, ft_model):
    """
    Compare every pair of equally shaped tensors across the two models.

    Both models are moved to the CPU first. Each state-dict entry of the base
    model is tested against each entry of the other model that has the same
    shape, skipping one-dimensional tensors. The p-values are combined with
    ``fisher``.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model

    Returns:
        Aggregate value returned by ``fisher`` over all pairwise p-values, or
        the sentinel 999 when no pair of tensors was comparable
    """
    base_model.to("cpu")
    ft_model.to("cpu")

    shapes_base = {name: tensor.shape for name, tensor in base_model.state_dict().items()}
    shapes_ft = {name: tensor.shape for name, tensor in ft_model.state_dict().items()}

    pvalues = []

    for name1, shape1 in shapes_base.items():
        if len(shape1) == 1:
            # One-dimensional tensors are never compared.
            continue
        for name2, shape2 in shapes_ft.items():
            if shape1 != shape2:
                continue
            pval = csw_sp_pair(base_model, ft_model, name1, name2)
            print(name1, name2, pval)
            pvalues.append(pval)

    print(pvalues)

    # 999 marks "nothing to compare"; it is not a probability.
    res = fisher(pvalues) if len(pvalues) != 0 else 999

    print(res)
    return res
|
model-tracing/tracing/statistics/jsd.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Implementation of Jensen-Shannon Divergence (JSD) for comparing language model outputs.
|
3 |
+
|
4 |
+
This module provides functions to compute the Jensen-Shannon Divergence between
|
5 |
+
probability distributions output by two language models, measuring their similarity
|
6 |
+
in output space rather than parameter space.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
12 |
+
|
13 |
+
from tracing.utils.evaluate import (
|
14 |
+
prepare_hf_dataset,
|
15 |
+
prepare_hf_dataloader,
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
def statistic(base_model, ft_model, dataloader, device="cuda"):
    """
    Jensen-Shannon Divergence statistic between two language models.

    Thin entry point that forwards its arguments to ``compute_jsd``.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        dataloader: DataLoader providing input data for model evaluation
        device: Device to run the computation on (default: "cuda")

    Returns:
        float: Sum of the per-batch values returned by ``compute_jsd``
    """
    total = compute_jsd(base_model, ft_model, dataloader, device=device)
    return total
|
33 |
+
|
34 |
+
|
35 |
+
def statistic_stable(base_model, ft_model, dataloader, device="cuda"):
    """
    Log-space Jensen-Shannon Divergence statistic between two language models.

    Thin entry point that forwards its arguments to ``compute_jsd_stable``.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        dataloader: DataLoader providing input data for model evaluation
        device: Device to run the computation on (default: "cuda")

    Returns:
        float: Sum of the per-batch values returned by ``compute_jsd_stable``
    """
    total = compute_jsd_stable(base_model, ft_model, dataloader, device=device)
    return total
|
51 |
+
|
52 |
+
|
53 |
+
def compute_jsd(base_model, ft_model, dataloader, device="cuda"):
    """
    Compute Jensen-Shannon Divergence between two models using softmax outputs.

    Processes each batch in the dataloader and computes JSD between the models'
    probability distributions over vocabulary tokens. Vocabulary size
    differences are handled by keeping only the first 32000 entries of the
    vocabulary (last) axis of each softmax output.

    Each batch's value is reduced with ``F.kl_div``'s default reduction. Both
    models are moved to ``device`` for the computation and back to the CPU
    afterwards.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        dataloader: DataLoader yielding dicts with "input_ids",
            "attention_mask" and "labels" tensors
        device: Device to run the computation on (default: "cuda")

    Returns:
        float: Sum of Jensen-Shannon Divergence values across all batches
    """
    jsds = []

    base_model.to(device)
    ft_model.to(device)

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs_base = base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            outputs_ft = ft_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )

            logits_base = outputs_base.logits.squeeze()
            logits_ft = outputs_ft.logits.squeeze()

            softmax_base = torch.softmax(logits_base, dim=-1)
            softmax_ft = torch.softmax(logits_ft, dim=-1)

            # Truncate the vocabulary axis to the first 32000 entries. The
            # ellipsis always targets the last axis, whatever rank squeeze()
            # left behind (1-D for a single token, 3-D for batch size > 1).
            softmax_base = softmax_base[..., :32000]
            softmax_ft = softmax_ft[..., :32000]

            # F.kl_div expects its first argument in log space.
            log_m = (0.5 * (softmax_base + softmax_ft)).log()
            jsd = 0.5 * (F.kl_div(log_m, softmax_base) + F.kl_div(log_m, softmax_ft))

            jsds.append(jsd.item())

    base_model.to("cpu")
    ft_model.to("cpu")
    return sum(jsds)
|
110 |
+
|
111 |
+
|
112 |
+
def compute_jsd_stable(base_model, ft_model, dataloader, device="cuda"):
    """
    Compute numerically stable Jensen-Shannon Divergence between two models.

    For every batch this:
    1. Truncates both logit tensors to the smaller vocabulary size
    2. Converts them to log-probabilities with ``log_softmax``
    3. Forms the log of the mixture M = (P + Q) / 2 in log space
    4. Evaluates JSD = 0.5 * KL(P || M) + 0.5 * KL(Q || M), where
       KL(P || M) = sum_x P(x) * (log P(x) - log M(x)), averaged over positions

    Both models are moved to ``device`` for the computation and back to the
    CPU afterwards.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        dataloader: DataLoader yielding dicts with "input_ids",
            "attention_mask" and "labels" tensors
        device: Device to run the computation on (default: "cuda")

    Returns:
        float: Sum of Jensen-Shannon Divergence values across all batches
    """
    import math

    jsds = []

    base_model.to(device)
    ft_model.to(device)

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs_base = base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            outputs_ft = ft_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )

            logits_base = outputs_base.logits.squeeze()
            logits_ft = outputs_ft.logits.squeeze()

            # Determine the minimum vocabulary size between the two models
            min_vocab_size = min(logits_base.size(-1), logits_ft.size(-1))

            # Truncate the logits to the minimum vocabulary size
            logits_base = logits_base[..., :min_vocab_size]
            logits_ft = logits_ft[..., :min_vocab_size]

            log_probs_base = F.log_softmax(logits_base, dim=-1)
            log_probs_ft = F.log_softmax(logits_ft, dim=-1)

            # log M = log((P + Q) / 2), formed in log space so tokens whose
            # probability underflows to zero do not yield log(0) = -inf.
            log_m = torch.logaddexp(log_probs_base, log_probs_ft) - math.log(2.0)

            # KL(P || M) weights each log-ratio by P(x); without the weight
            # the sum is not a divergence and is never positive.
            kl_div_base_m = (log_probs_base.exp() * (log_probs_base - log_m)).sum(dim=-1)
            kl_div_ft_m = (log_probs_ft.exp() * (log_probs_ft - log_m)).sum(dim=-1)

            jsd = 0.5 * (kl_div_base_m + kl_div_ft_m).mean()
            jsds.append(jsd.item())

    base_model.to("cpu")
    ft_model.to("cpu")

    return sum(jsds)
|
178 |
+
|
179 |
+
|
180 |
+
if __name__ == "__main__":
    # Ad-hoc driver: prints the JSD statistic for one base / fine-tuned pair.
    # Alternative pairs noted by the author:
    #   base: 'openlm-research/open_llama_7b', 'lmsys/vicuna-7b-v1.5'
    #   ft:   'openlm-research/open_llama_7b_v2', 'LLM360/Amber', "lmsys/vicuna-7b-v1.1"
    base_model_name = "LLM360/Amber"
    ft_model_name = "LLM360/AmberChat"

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
    )
    ft_model = AutoModelForCausalLM.from_pretrained(
        ft_model_name,
        torch_dtype=torch.bfloat16,
    )
    base_tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_fast=False)

    # Alternative data source left by the author:
    # dataset = load_generated_datasets(base_model_name, ft_model_name, 512, base_tokenizer, ["text"])
    # dataloader = prepare_hf_dataloader(dataset, 1)

    dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized", 512, base_tokenizer)
    dataloader = prepare_hf_dataloader(dataset, 1)

    result = statistic(base_model, ft_model, dataloader)
    print(result)
|
model-tracing/tracing/statistics/l2.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Implementation of L2 distance metrics for comparing neural network model weights.
|
3 |
+
|
4 |
+
This module provides functions to calculate the L2 (Euclidean) distance between
|
5 |
+
corresponding weight tensors of two models, providing a measure of parameter space
|
6 |
+
similarity.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
def statistic(base_model, ft_model):
    """
    Average L2 distance between the weights of two models.

    Thin entry point that forwards both models to ``calculate_l2_distance``.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model

    Returns:
        float: Average L2 distance across all comparable parameters
    """
    distance = calculate_l2_distance(base_model, ft_model)
    return distance
|
24 |
+
|
25 |
+
|
26 |
+
def calculate_l2_distance(model1, model2):
    """
    Calculate the average L2 distance between corresponding parameters of two models.

    Walks ``named_parameters()`` of both models in lockstep, computes the
    Euclidean norm of the difference of each parameter pair and returns the
    mean of those per-parameter norms. Shape mismatches are tolerated only
    for "model.embed_tokens.weight" and "lm_head.weight", which are skipped
    with a printed notice.

    Args:
        model1: First model to compare
        model2: Second model to compare

    Returns:
        float: Average L2 distance across all comparable parameters

    Raises:
        ValueError: If parameter names don't match, if there are shape
            mismatches in parameters other than embedding or output layers,
            or if no parameter pair could be compared
    """
    layer_distances = []

    # Parameters normally require grad; no_grad keeps the subtraction from
    # recording an autograd graph for every weight tensor.
    with torch.no_grad():
        for (name1, param1), (name2, param2) in zip(
            model1.named_parameters(), model2.named_parameters()
        ):
            if name1 != name2:
                raise ValueError(f"Model parameter names do not match: {name1} != {name2}")
            if param1.shape != param2.shape:
                if name1 == "model.embed_tokens.weight" or name1 == "lm_head.weight":
                    print(
                        f"Skipping {name1} because of shape mismatch: {param1.shape} != {param2.shape}"
                    )
                    continue
                raise ValueError(
                    f"Model parameter shapes do not match for {name1}: {param1.shape} != {param2.shape}"
                )

            l2_diff = torch.sum((param1 - param2) ** 2) ** 0.5
            layer_distances.append(l2_diff.item())

    if not layer_distances:
        # Fail with a clear message instead of an opaque ZeroDivisionError.
        raise ValueError("No comparable parameters found between the two models")

    return sum(layer_distances) / len(layer_distances)
|
model-tracing/tracing/statistics/match.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Implementation of activation matching algorithms for comparing neural network models.
|
3 |
+
|
4 |
+
This module provides functions for matching neurons between two models based on
|
5 |
+
their activation patterns, helping identify corresponding functional units despite
|
6 |
+
permutation differences.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from collections import defaultdict
|
11 |
+
import scipy
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
from tracing.utils.evaluate import evaluate
|
15 |
+
from tracing.utils.llama.matching import match_wmats
|
16 |
+
|
17 |
+
|
18 |
+
def hook_in(m, inp, op, feats, name):
    """
    Forward hook that records the first positional input of a module.

    Args:
        m: Module the hook is attached to (unused)
        inp: Tuple of positional inputs passed to the module
        op: Output produced by the module (unused)
        feats: Mapping from key to a list of captured tensors
        name: Key under which the captured input is appended
    """
    # Detach first so the stored copy carries no autograd history.
    first_input = inp[0].detach()
    feats[name].append(first_input.cpu())
|
30 |
+
|
31 |
+
|
32 |
+
def hook_out(m, inp, op, feats, name):
    """
    Forward hook that records the output of a module.

    Args:
        m: Module the hook is attached to (unused)
        inp: Inputs passed to the module (unused)
        op: Output tensor produced by the module
        feats: Mapping from key to a list of captured tensors
        name: Key under which the captured output is appended
    """
    # Detach first so the stored copy carries no autograd history.
    output = op.detach()
    feats[name].append(output.cpu())
|
44 |
+
|
45 |
+
|
46 |
+
def statistic(base_model, ft_model, dataloader, n_blocks=32):
    """
    Compute neuron matching statistics for the first ``n_blocks`` blocks.

    For each block, the neuron matching derived from the gate projection is
    compared with the one derived from the up projection using a Spearman
    correlation; the p-value of that comparison is the block's statistic.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        dataloader: DataLoader providing input data for activation collection
        n_blocks: Number of transformer blocks to analyze (default: 32)

    Returns:
        list: Spearman correlation p-values, one per block
    """

    def _block_pvalue(block_idx):
        # Gate first, then up, so the models are evaluated in the same order.
        gate_match = mlp_matching_gate(base_model, ft_model, dataloader, i=block_idx)
        up_match = mlp_matching_up(base_model, ft_model, dataloader, i=block_idx)

        _, pvalue = scipy.stats.spearmanr(gate_match.tolist(), up_match.tolist())
        print(block_idx, pvalue, len(gate_match))
        return pvalue

    return [_block_pvalue(block_idx) for block_idx in range(n_blocks)]
|
74 |
+
|
75 |
+
|
76 |
+
def statistic_layer(base_model, ft_model, dataloader, i=0):
    """
    Compute the neuron matching statistic for one transformer block.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        dataloader: DataLoader providing input data for activation collection
        i: Block index to analyze (default: 0)

    Returns:
        float: Spearman correlation p-value between the gate-projection and
        up-projection matchings of block i
    """
    gate_indices = mlp_matching_gate(base_model, ft_model, dataloader, i=i).tolist()
    up_indices = mlp_matching_up(base_model, ft_model, dataloader, i=i).tolist()

    _, pvalue = scipy.stats.spearmanr(gate_indices, up_indices)
    return pvalue
|
93 |
+
|
94 |
+
|
95 |
+
def mlp_matching_gate(base_model, ft_model, dataloader, i=0, j=None):
    """
    Match neurons between models by comparing gate projection activations.

    Hooks the output of ``mlp.gate_proj`` in block ``i`` of the base model and
    block ``j`` of the fine-tuned model, runs both models over the dataloader
    and passes the collected activations (one row per neuron) to
    ``match_wmats``.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        dataloader: DataLoader providing input data for activation collection
        i: Block index in the base model (default: 0)
        j: Block index in the fine-tuned model; uses ``i`` when None (default)

    Returns:
        Matching returned by ``match_wmats`` for the two activation matrices
    """
    if j is None:
        j = i

    feats = defaultdict(list)

    def base_hook(*args):
        hook_out(*args, feats, "base")

    def ft_hook(*args):
        hook_out(*args, feats, "ft")

    # Always detach the hooks, even if evaluation raises, so the models are
    # left unmodified.
    handles = []
    try:
        handles.append(
            base_model.model.layers[i].mlp.gate_proj.register_forward_hook(base_hook)
        )
        handles.append(
            ft_model.model.layers[j].mlp.gate_proj.register_forward_hook(ft_hook)
        )

        evaluate(base_model, dataloader)
        evaluate(ft_model, dataloader)
    finally:
        for handle in handles:
            handle.remove()

    # The activations stay on the CPU (hook_out stores CPU copies). The former
    # `.to("cuda")` calls discarded their result, so they only made an unused
    # GPU copy and required CUDA to be present.
    base_mat = torch.vstack(feats["base"])
    ft_mat = torch.vstack(feats["ft"])

    # Flatten everything but the last axis, then transpose so rows are neurons.
    base_mat = base_mat.view(-1, base_mat.shape[-1]).T
    ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T

    return match_wmats(base_mat, ft_mat)
|
137 |
+
|
138 |
+
|
139 |
+
def mlp_matching_up(base_model, ft_model, dataloader, i=0, j=None):
    """
    Match neurons between models by comparing up projection activations.

    Hooks the output of ``mlp.up_proj`` in block ``i`` of the base model and
    block ``j`` of the fine-tuned model, runs both models over the dataloader
    and passes the collected activations (one row per neuron) to
    ``match_wmats``.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        dataloader: DataLoader providing input data for activation collection
        i: Block index in the base model (default: 0)
        j: Block index in the fine-tuned model; uses ``i`` when None (default)

    Returns:
        Matching returned by ``match_wmats`` for the two activation matrices
    """
    if j is None:
        j = i

    feats = defaultdict(list)

    def base_hook(*args):
        hook_out(*args, feats, "base")

    def ft_hook(*args):
        hook_out(*args, feats, "ft")

    # Always detach the hooks, even if evaluation raises, so the models are
    # left unmodified.
    handles = []
    try:
        handles.append(
            base_model.model.layers[i].mlp.up_proj.register_forward_hook(base_hook)
        )
        handles.append(
            ft_model.model.layers[j].mlp.up_proj.register_forward_hook(ft_hook)
        )

        evaluate(base_model, dataloader)
        evaluate(ft_model, dataloader)
    finally:
        for handle in handles:
            handle.remove()

    # The activations stay on the CPU (hook_out stores CPU copies). The former
    # `.to("cuda")` calls discarded their result, so they only made an unused
    # GPU copy and required CUDA to be present.
    base_mat = torch.vstack(feats["base"])
    ft_mat = torch.vstack(feats["ft"])

    # Flatten everything but the last axis, then transpose so rows are neurons.
    base_mat = base_mat.view(-1, base_mat.shape[-1]).T
    ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T

    return match_wmats(base_mat, ft_mat)
|
181 |
+
|
182 |
+
|
183 |
+
def mlp_layers(base_model, ft_model, dataloader, i, j):
    """
    Compare gate and up projection matchings for a pair of layers.

    Intended for comparing layer ``i`` of the base model with layer ``j`` of
    the fine-tuned model, which need not be the same index.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        dataloader: DataLoader providing input data for activation collection
        i: Layer index in the base model
        j: Layer index in the fine-tuned model

    Returns:
        float: Spearman correlation p-value between gate and up projections
    """
    # NOTE(review): both matchers are called with five positional arguments,
    # so each must accept the fine-tuned layer index as its fifth parameter.
    gate_match, up_match = [
        matcher(base_model, ft_model, dataloader, i, j)
        for matcher in (mlp_matching_gate, mlp_matching_up)
    ]

    _, pvalue = scipy.stats.spearmanr(gate_match.tolist(), up_match.tolist())

    return pvalue
|
205 |
+
|
206 |
+
|
207 |
+
def statistic_all(model_1, model_2, dataloader):
    """
    Greedily match layers of two models against each other.

    For each layer of ``model_1`` the layers of ``model_2`` are tried in
    order. A pair counts as matched once its ``mlp_layers`` p-value drops
    below 1e-6; both layers are then excluded from further comparisons.

    Args:
        model_1: First model to compare
        model_2: Second model to compare
        dataloader: DataLoader providing input data for activation collection

    Returns:
        None: Prints each tested (i, j, p-value) triple during execution
    """
    model_1_matched = np.zeros(model_1.config.num_hidden_layers)
    model_2_matched = np.zeros(model_2.config.num_hidden_layers)

    for i in range(model_1.config.num_hidden_layers):
        for j in range(model_2.config.num_hidden_layers):
            both_free = model_1_matched[i] != 1 and model_2_matched[j] != 1
            if not both_free:
                continue

            stat = mlp_layers(model_1, model_2, dataloader, i, j)
            print(i, j, stat)

            if stat < 0.000001:
                # Claim both layers and move on to the next layer of model_1.
                model_1_matched[i] = 1
                model_2_matched[j] = 1
                break
|
model-tracing/tracing/statistics/mc.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..utils.llama.model import avg_model
|
2 |
+
from ..utils.olmo.model import avg_model as avg_model_olmo
|
3 |
+
from ..utils.evaluate import evaluate
|
4 |
+
|
5 |
+
|
6 |
+
def statistic(base_model, ft_model, tmp_model, dataloader, attn=False, emb=False, alpha=0.5):
    """
    Evaluate the averaged model built from two models.

    Chooses the OLMo averaging helper when the base model's class name
    contains "olmo" (case-insensitive) and the default helper otherwise,
    calls it with ``tmp_model``, then evaluates ``tmp_model`` on the
    dataloader.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        tmp_model: Model handed to the averaging helper and then evaluated
        dataloader: DataLoader providing the evaluation data
        attn: Forwarded to the averaging helper (default: False)
        emb: Forwarded to the averaging helper (default: False)
        alpha: NOTE(review): accepted but never forwarded to the averaging
            helper — confirm whether it should be

    Returns:
        Sum of the values returned by ``evaluate`` for ``tmp_model``
    """
    is_olmo = "olmo" in base_model._get_name().lower()
    average = avg_model_olmo if is_olmo else avg_model
    average(base_model, ft_model, tmp_model, attn=attn, emb=emb)

    return sum(evaluate(tmp_model, dataloader))
|
model-tracing/tracing/statistics/perm_mc_l2.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from tracing.perm.permute import permute_model
|
3 |
+
from scripts.perm.main import p_value_exact, p_value_approx
|
4 |
+
|
5 |
+
|
6 |
+
def statistic(base_model, ft_model, mc_stat, l2_stat, num_perm, emb_dim=4096, mlp_dim=11008):
    """
    Run a permutation test for two statistics at once.

    Both statistics are first computed on the models as given, then recomputed
    ``num_perm`` times after applying a fresh random permutation to
    ``ft_model`` via ``permute_model``. ``ft_model`` is not restored between
    iterations or before returning.

    Args:
        base_model: Reference model.
        ft_model: Model that gets permuted.
        mc_stat: Callable ``(base_model, ft_model) -> statistic``.
        l2_stat: Callable ``(base_model, ft_model) -> statistic``.
        num_perm: Number of random permutations to draw.
        emb_dim: Size of the embedding permutation.
        mlp_dim: Size of the MLP permutation.

    Returns:
        Tuple ``(exact_mc, approx_mc, exact_l2, approx_l2, unperm_stat_mc,
        unperm_stat_l2, perm_stats_mc, perm_stats_l2)``.
    """
    unperm_stat_mc = mc_stat(base_model, ft_model)
    unperm_stat_l2 = l2_stat(base_model, ft_model)
    print(unperm_stat_mc, unperm_stat_l2)

    perm_stats_mc, perm_stats_l2 = [], []

    for trial in range(num_perm):
        # Draw the MLP permutation before the embedding one so the RNG stream
        # is consumed in the same order on every run.
        mlp_permutation = torch.randperm(mlp_dim)
        emb_permutation = torch.randperm(emb_dim)
        permute_model(ft_model, mlp_permutation, emb_permutation)

        trial_mc = mc_stat(base_model, ft_model)
        trial_l2 = l2_stat(base_model, ft_model)
        perm_stats_mc.append(trial_mc)
        perm_stats_l2.append(trial_l2)

        print(trial, trial_mc, trial_l2)

    # The p-value helpers receive copies so the lists returned below cannot be
    # altered by them.
    exact_mc = p_value_exact(unperm_stat_mc, list(perm_stats_mc))
    approx_mc = p_value_approx(unperm_stat_mc, list(perm_stats_mc))

    exact_l2 = p_value_exact(unperm_stat_l2, list(perm_stats_l2))
    approx_l2 = p_value_approx(unperm_stat_l2, list(perm_stats_l2))

    print(exact_mc, approx_mc)
    print(exact_l2, approx_l2)

    results = (
        exact_mc,
        approx_mc,
        exact_l2,
        approx_l2,
        unperm_stat_mc,
        unperm_stat_l2,
        perm_stats_mc,
        perm_stats_l2,
    )
    return results
|
model-tracing/tracing/utils/__init__.py
ADDED
File without changes
|
model-tracing/tracing/utils/evaluate.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import glob
|
3 |
+
from typing import List
|
4 |
+
from datasets import load_dataset, concatenate_datasets, Dataset
|
5 |
+
from accelerate.data_loader import DataLoaderShard
|
6 |
+
from transformers import AutoTokenizer
|
7 |
+
|
8 |
+
|
9 |
+
def prepare_hf_dataset(hf_path, block_size, tokenizer, split="test"):
    """
    Load a Hugging Face dataset split and turn it into fixed-length token blocks.

    The ``text`` column is tokenized and dropped, each example is then chunked
    into blocks of ``block_size`` tokens, and the result is exposed as torch
    tensors with ``input_ids``, ``attention_mask`` and ``labels`` columns.
    """

    def _tokenize(examples):
        return tokenize_function(examples, tokenizer)

    def _chunk(examples):
        return group_texts(examples, block_size)

    raw_dataset = load_dataset(hf_path, split=split)
    tokenized = raw_dataset.map(_tokenize, batched=True, remove_columns=["text"])
    # batch_size=1 makes group_texts see one example at a time, so blocks
    # never span two documents.
    dataset = tokenized.map(_chunk, batched=True, batch_size=1)

    dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
    return dataset
|
18 |
+
|
19 |
+
|
20 |
+
def prepare_programming_dataset(
    json_path: str, block_size: int, tokenizer: AutoTokenizer, columns_ignored: List[str]
):
    """
    Load a JSON file and turn its ``train`` split into fixed-length token blocks.

    Args:
        json_path: Path (or pattern) handed to ``load_dataset("json", ...)``.
        block_size: Number of tokens per block.
        tokenizer: Tokenizer applied through ``tokenize_function``.
        columns_ignored: Columns removed while tokenizing.

    Returns:
        The chunked dataset, formatted as torch tensors with ``input_ids``,
        ``attention_mask`` and ``labels`` columns.
    """

    def _tokenize(examples):
        return tokenize_function(examples, tokenizer)

    def _chunk(examples):
        return group_texts(examples, block_size)

    train_split = load_dataset("json", data_files=json_path)["train"]

    tokenized = train_split.map(
        _tokenize,
        batched=True,
        num_proc=4,
        remove_columns=columns_ignored,
    )
    # One example per batch keeps blocks from spanning two documents.
    dataset = tokenized.map(
        _chunk,
        batched=True,
        batch_size=1,
        num_proc=1,
    )

    dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
    return dataset
|
42 |
+
|
43 |
+
|
44 |
+
def prepare_random_sample_dataset(num_samples, block_size, vocab_size=32000):
    """
    Build a dataset of uniformly random token ids.

    Produces ``num_samples`` rows of ``block_size`` ids drawn from
    ``[0, vocab_size)``. The attention mask is all ones and the labels are the
    ids themselves.
    """
    tokens = torch.randint(low=0, high=vocab_size, size=(num_samples, block_size))
    full_mask = torch.ones(tokens.shape)

    dataset = Dataset.from_dict(
        {"input_ids": tokens, "attention_mask": full_mask, "labels": tokens}
    )
    dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
    return dataset
|
51 |
+
|
52 |
+
|
53 |
+
def load_m2d2_datasets(
    test_name: str,
    block_size: int,
    tokenizer: AutoTokenizer,
    columns_ignored: List[str],
):
    """
    Load and concatenate every JSON file of one M2D2 test case.

    Files are looked up under ``<base_path>/<test_name>/*.json`` and each is
    processed with ``prepare_programming_dataset``.

    Raises:
        ValueError: If the directory contains no ``.json`` file.
    """
    base_path = "/juice4/scr4/nlp/model-tracing/m2d2_s2orc"
    json_dir = f"{base_path}/{test_name}"
    json_files = glob.glob(f"{json_dir}/*.json")

    if len(json_files) == 0:
        raise ValueError(f"No JSON files found for test case: {test_name}")

    per_file = [
        prepare_programming_dataset(path, block_size, tokenizer, columns_ignored)
        for path in json_files
    ]
    return concatenate_datasets(per_file)
|
73 |
+
|
74 |
+
|
75 |
+
def load_dolma_programming_datasets(
    test_name: str,
    block_size: int,
    tokenizer: AutoTokenizer,
    columns_ignored: List[str],
):
    """
    Load and concatenate every JSON file of one Dolma programming-language set.

    Files are looked up under ``<base_path>/json_files_<test_name>/*.json`` and
    each is processed with ``prepare_programming_dataset``.

    Args:
        test_name: Suffix of the ``json_files_*`` directory to read.
        block_size: Number of tokens per block.
        tokenizer: Tokenizer forwarded to ``prepare_programming_dataset``.
        columns_ignored: Columns removed while tokenizing.

    Returns:
        The concatenation of the per-file datasets.

    Raises:
        ValueError: If the directory contains no ``.json`` file. This mirrors
            ``load_m2d2_datasets`` and names the offending test case instead
            of failing later inside ``concatenate_datasets`` on an empty list.
    """
    base_path = "/juice4/scr4/nlp/model-tracing/dolma_program_languages"

    json_dir = f"{base_path}/json_files_{test_name}"
    json_files = glob.glob(f"{json_dir}/*.json")

    if not json_files:
        raise ValueError(f"No JSON files found for test case: {test_name}")

    datasets = [
        prepare_programming_dataset(json_file, block_size, tokenizer, columns_ignored)
        for json_file in json_files
    ]

    return concatenate_datasets(datasets)
|
93 |
+
|
94 |
+
|
95 |
+
def load_generated_datasets(base_model_name, ft_model_name, block_size, tokenizer, columns_ignored):
    """
    Load the generated-text files of two models and concatenate them.

    The file of a model is found under the generations directory, named after
    the model with every ``/`` replaced by ``-`` and suffixed ``_gentext.json``.
    The base model's data comes first in the returned dataset.
    """
    generations_dir = "/juice4/scr4/nlp/model-tracing/generations/"
    file_suffix = "_gentext.json"

    def _gentext_path(model_name):
        return generations_dir + model_name.replace("/", "-") + file_suffix

    json_file_base = _gentext_path(base_model_name)
    json_file_ft = _gentext_path(ft_model_name)

    dataset_base = prepare_programming_dataset(
        json_file_base, block_size, tokenizer, columns_ignored
    )
    dataset_ft = prepare_programming_dataset(json_file_ft, block_size, tokenizer, columns_ignored)

    return concatenate_datasets([dataset_base, dataset_ft])
|
119 |
+
|
120 |
+
|
121 |
+
def prepare_hf_dataloader(dataset, batch_size: int):
    """Wrap ``dataset`` in an accelerate ``DataLoaderShard`` with the given batch size."""
    loader = DataLoaderShard(dataset, batch_size=batch_size)
    return loader
|
123 |
+
|
124 |
+
|
125 |
+
def evaluate_70b(model, dataloader, device: str = "cuda:0"):
    """
    Collect the loss of ``model`` on every batch of ``dataloader``.

    Unlike ``evaluate``, the model itself is never moved; only the batch
    tensors are sent to ``device``.

    Args:
        model: Callable accepting ``input_ids``, ``attention_mask`` and
            ``labels`` keyword arguments and returning an object with a
            ``loss`` attribute.
        dataloader: Iterable of dicts holding those three tensors.
        device: Device the batch tensors are moved to.

    Returns:
        List with one Python float per batch.
    """
    field_names = ("input_ids", "attention_mask", "labels")
    batch_losses = []

    with torch.no_grad():
        for batch in dataloader:
            inputs = {name: batch[name].to(device) for name in field_names}
            batch_losses.append(model(**inputs).loss.item())

    return batch_losses
|
142 |
+
|
143 |
+
|
144 |
+
def evaluate(model, dataloader, device: str = "cuda"):
    """
    Collect the loss of ``model`` on every batch of ``dataloader``.

    The model is moved to ``device`` for the duration of the evaluation and is
    always moved back to the CPU afterwards, even when a batch raises, so a
    failed run does not leave the weights occupying accelerator memory.

    Args:
        model: Module accepting ``input_ids``, ``attention_mask`` and
            ``labels`` keyword arguments and returning an object with a
            ``loss`` attribute.
        dataloader: Iterable of dicts holding those three tensors.
        device: Device used for the forward passes.

    Returns:
        List with one Python float per batch.
    """
    losses = []
    try:
        model.to(device)
        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )
                losses.append(outputs.loss.item())
    finally:
        # Release the accelerator regardless of how the loop ended.
        model.to("cpu")

    return losses
|
163 |
+
|
164 |
+
|
165 |
+
def prepare_aya_dataset(subset: str, language: str, block_size: int, tokenizer: AutoTokenizer):
    """
    Prepare one language of an Aya evaluation-suite subset as token blocks.

    Rows whose ``language`` field differs from ``language`` are discarded. The
    remaining rows are tokenized with ``tokenize_function`` (all original
    columns dropped) and chunked into blocks of ``block_size`` tokens.
    """
    # NOTE(review): no split is requested here, so load_dataset presumably
    # returns a DatasetDict rather than a single Dataset -- confirm that
    # `column_names` and `set_format` behave as intended in that case.
    raw_dataset = load_dataset("CohereForAI/aya_evaluation_suite", subset)

    def _is_requested_language(example):
        return example["language"] == language

    def _tokenize(examples):
        return tokenize_function(examples, tokenizer)

    def _chunk(examples):
        return group_texts(examples, block_size)

    filtered_dataset = raw_dataset.filter(_is_requested_language)

    tokenized = filtered_dataset.map(
        _tokenize,
        batched=True,
        remove_columns=filtered_dataset.column_names,
    )
    dataset = tokenized.map(_chunk, batched=True, batch_size=1)

    dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
    return dataset
|
180 |
+
|
181 |
+
|
182 |
+
def tokenize_aya_function(examples, tokenizer: AutoTokenizer):
    """Tokenize the ``inputs`` column of a batch of Aya examples."""
    prompts = examples["inputs"]
    return tokenizer(prompts)
|
187 |
+
|
188 |
+
|
189 |
+
def tokenize_function(examples, tokenizer):
    """
    Tokenize whichever text column a batch provides.

    ``text`` takes precedence over ``inputs`` when both are present.

    Raises:
        ValueError: If the batch has neither column.
    """
    for column in ("text", "inputs"):
        if column in examples:
            return tokenizer(examples[column])

    raise ValueError("Neither 'text' nor 'inputs' found in examples")
|
196 |
+
|
197 |
+
|
198 |
+
def group_texts(examples, block_size):
    """
    Concatenate tokenized examples and cut them into equal-sized blocks.

    Every column of ``examples`` (a mapping from column name to a list of
    token lists) is flattened into one long list and split into consecutive
    chunks of ``block_size`` items. The trailing remainder that does not fill
    a whole block is dropped. The number of blocks is determined by the
    ``input_ids`` column.

    Args:
        examples: Batch mapping; must contain ``input_ids``.
        block_size: Length of each output block.

    Returns:
        Dict with the same columns, each a list of blocks, plus a ``labels``
        column that is a shallow copy of the ``input_ids`` blocks.
    """
    # Local import: the module's import header is not modified by this change.
    from itertools import chain

    # chain.from_iterable flattens in linear time; the former sum(lists, [])
    # re-copied the accumulated list for every example.
    concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}

    # Round down to a whole number of blocks.
    total_length = len(concatenated_examples["input_ids"])
    total_length = (total_length // block_size) * block_size

    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result
|
model-tracing/tracing/utils/llama/matching.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
VOCAB_SIZE = 32000
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from scipy.optimize import linear_sum_assignment as LAP
|
5 |
+
|
6 |
+
from ..utils import pdists
|
7 |
+
from .model import permute_model, permute_transformer_block
|
8 |
+
|
9 |
+
|
10 |
+
def match_wmats(wmat0, wmat1):
    """
    Find the row assignment between two weight matrices with minimum total cost.

    The cost matrix comes from ``pdists(wmat0, wmat1)`` cast to float64, and
    the assignment is solved with SciPy's ``linear_sum_assignment``.

    Returns:
        The column indices of the optimal assignment, i.e. ``wmat1[perm]``
        should match ``wmat0``.
    """
    cost = pdists(wmat0, wmat1).type(torch.float64)
    _, perm = LAP(cost)
    return perm
|
15 |
+
|
16 |
+
|
17 |
+
def match_mlp(base_model, ft_model, i=0):
    """
    Match the MLP hidden units of layer ``i`` between two models.

    The gate projection weights of both models are compared with
    ``match_wmats``.

    Returns:
        Permutation that, applied to the rows of ``ft_model``'s gate
        projection, should align them with ``base_model``'s.
    """
    gate_key = f"model.layers.{i}.mlp.gate_proj.weight"
    return match_wmats(base_model.state_dict()[gate_key], ft_model.state_dict()[gate_key])
|
24 |
+
|
25 |
+
|
26 |
+
def match_emb(base_model, ft_model, i="inp"):
    """
    Match the hidden dimensions of two models through an embedding matrix.

    The first ``VOCAB_SIZE`` rows of the selected matrix are transposed so
    that each row corresponds to one hidden dimension, then compared with
    ``match_wmats``.

    Args:
        base_model: Reference model.
        ft_model: Model to align to the reference.
        i: ``"inp"`` to use ``model.embed_tokens.weight`` or ``"out"`` to use
            ``lm_head.weight``.

    Returns:
        Permutation of the hidden dimensions of ``ft_model``.

    Raises:
        ValueError: If ``i`` is neither ``"inp"`` nor ``"out"``. Previously
            this surfaced as an UnboundLocalError on ``weight_id``.
    """
    if i == "inp":
        weight_id = "model.embed_tokens.weight"
    elif i == "out":
        weight_id = "lm_head.weight"
    else:
        raise ValueError(f"i must be 'inp' or 'out', got {i!r}")

    base_wmat = base_model.state_dict()[weight_id][:VOCAB_SIZE].T
    ft_wmat = ft_model.state_dict()[weight_id][:VOCAB_SIZE].T

    perm = match_wmats(base_wmat, ft_wmat)
    return perm
|
37 |
+
|
38 |
+
|
39 |
+
def align_model(base_model, ft_model, tmp_model, n_blocks=32):
    """
    Write into ``tmp_model`` a copy of ``ft_model`` permuted to align with ``base_model``.

    The hidden (embedding) dimension is aligned first using the input
    embeddings. Then, block by block, the MLP hidden units are aligned using
    the gate projections of the already embedding-aligned ``tmp_model``.

    The identity permutations are sized from the actual weights instead of the
    former hard-coded 11008 / 4096, so models with other widths work. For
    those widths the behaviour is unchanged.

    Args:
        base_model: Reference model.
        ft_model: Model to align; read only.
        tmp_model: Receives the aligned weights.
        n_blocks: Number of transformer blocks whose MLPs are aligned.
    """
    emb_perm = match_emb(base_model, ft_model)

    # Sizes of the identity permutations, read from the model itself.
    emb_dim = len(emb_perm)
    mlp_dim = ft_model.state_dict()["model.layers.0.mlp.gate_proj.weight"].shape[0]

    # Step 1: permute only the hidden dimension (identity on the MLP units).
    permute_model(ft_model, tmp_model, torch.arange(mlp_dim), emb_perm)

    # Step 2: permute only the MLP units of each block (identity on hidden).
    for i in range(n_blocks):
        mlp_perm = match_mlp(base_model, tmp_model, i=i)
        permute_transformer_block(tmp_model, i, tmp_model, mlp_perm, torch.arange(emb_dim))
|
model-tracing/tracing/utils/llama/model.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import torch
|
3 |
+
from scipy.stats import ortho_group
|
4 |
+
|
5 |
+
|
6 |
+
def permute_model(model, tmp_model, mlp_permutation, emb_permutation, n_blocks=32):
    """
    Load into ``tmp_model`` a permuted copy of ``model``'s weights.

    The embedding layer is permuted from ``model`` into ``tmp_model``. The
    first ``n_blocks`` transformer blocks and the output layer are then
    permuted in place on ``tmp_model``.
    """
    permute_embedding_layer(model, tmp_model, emb_permutation)

    block_index = 0
    while block_index < n_blocks:
        permute_transformer_block(
            tmp_model, block_index, tmp_model, mlp_permutation, emb_permutation
        )
        block_index += 1

    permute_output_layer(tmp_model, tmp_model, emb_permutation)
|
11 |
+
|
12 |
+
|
13 |
+
def permute_transformer_block(model, i, tmp_model, mlp_permutation, emb_permutation):
    """
    Permute block ``i`` of ``model`` and load the result into ``tmp_model``.

    ``emb_permutation`` reorders the hidden dimension: columns of q/k/v, gate
    and up projections, rows of o and down projections, and both layer-norm
    weights. ``mlp_permutation`` reorders the MLP units: rows of gate and up
    projections, columns of the down projection. All other entries of the
    state dict are loaded unchanged.
    """
    weights = model.state_dict()
    block = f"model.layers.{i}."

    # Attention: q/k/v read from the hidden stream (columns), o writes to it (rows).
    for proj in ("q_proj", "k_proj", "v_proj"):
        key = f"{block}self_attn.{proj}.weight"
        weights[key] = weights[key][:, emb_permutation]
    o_key = f"{block}self_attn.o_proj.weight"
    weights[o_key] = weights[o_key][emb_permutation]

    # MLP: gate/up are (mlp, hidden), down is (hidden, mlp).
    for proj in ("gate_proj", "up_proj"):
        key = f"{block}mlp.{proj}.weight"
        weights[key] = weights[key][mlp_permutation][:, emb_permutation]
    down_key = f"{block}mlp.down_proj.weight"
    weights[down_key] = weights[down_key][:, mlp_permutation][emb_permutation]

    # Layer-norm gains are 1-d vectors over the hidden dimension.
    for norm in ("input_layernorm", "post_attention_layernorm"):
        key = f"{block}{norm}.weight"
        weights[key] = weights[key][emb_permutation]

    tmp_model.load_state_dict(weights)
|
59 |
+
|
60 |
+
|
61 |
+
def permute_embedding_layer(model, tmp_model, emb_permutation):
    """
    Reorder the columns of ``model``'s token embeddings and load into ``tmp_model``.

    Every other entry of the state dict is loaded unchanged.
    """
    embed_key = "model.embed_tokens.weight"

    state = model.state_dict()
    state[embed_key] = state[embed_key][:, emb_permutation]
    tmp_model.load_state_dict(state)
|
66 |
+
|
67 |
+
|
68 |
+
def permute_output_layer(model, tmp_model, emb_permutation):
    """
    Permute the final norm and LM head of ``model`` and load into ``tmp_model``.

    The columns of ``lm_head.weight`` and the entries of ``model.norm.weight``
    are reordered by ``emb_permutation``.
    """
    state = model.state_dict()

    head = state["lm_head.weight"]
    final_norm = state["model.norm.weight"]

    state["lm_head.weight"] = head[:, emb_permutation]
    state["model.norm.weight"] = final_norm[emb_permutation]

    tmp_model.load_state_dict(state)
|
74 |
+
|
75 |
+
|
76 |
+
def permute_mlp_block(model, i, tmp_model, mlp_permutation):
    """
    Reorder the MLP units of block ``i`` and load the result into ``tmp_model``.

    Rows of the gate and up projections and columns of the down projection
    are permuted. The hidden dimension is left as is.
    """
    state = model.state_dict()
    mlp_prefix = f"model.layers.{i}.mlp."

    for proj in ("gate_proj", "up_proj"):
        key = f"{mlp_prefix}{proj}.weight"
        state[key] = state[key][mlp_permutation]

    down_key = f"{mlp_prefix}down_proj.weight"
    state[down_key] = state[down_key][:, mlp_permutation]

    tmp_model.load_state_dict(state)
|
90 |
+
|
91 |
+
|
92 |
+
def avg_mlp_block(model0, model1, i, tmp_model, alpha=0.5):
    """
    Interpolate the MLP weights of block ``i`` and load into ``tmp_model``.

    Gate, up and down projections become
    ``alpha * model0 + (1 - alpha) * model1``. Every other entry is taken from
    ``model0`` unchanged.
    """
    state0 = model0.state_dict()
    state1 = model1.state_dict()

    for proj in ("gate_proj", "up_proj", "down_proj"):
        key = f"model.layers.{i}.mlp.{proj}.weight"
        state0[key] = alpha * state0[key] + (1 - alpha) * state1[key]

    tmp_model.load_state_dict(state0)
|
110 |
+
|
111 |
+
|
112 |
+
def avg_transformer_block(model0, model1, i, tmp_model, alpha=0.5, attn=True):
    """
    Interpolate the weights of block ``i`` and load into ``tmp_model``.

    MLP projections and both layer norms always become
    ``alpha * model0 + (1 - alpha) * model1``. The q/k/v/o attention
    projections are interpolated only when ``attn`` is exactly ``True``.
    Every other entry is taken from ``model0`` unchanged.
    """
    state0 = model0.state_dict()
    state1 = model1.state_dict()

    parts = []
    if attn is True:
        parts.extend(f"self_attn.{proj}" for proj in ("q_proj", "k_proj", "v_proj", "o_proj"))
    parts.extend(f"mlp.{proj}" for proj in ("gate_proj", "up_proj", "down_proj"))
    parts.extend(("input_layernorm", "post_attention_layernorm"))

    for part in parts:
        key = f"model.layers.{i}.{part}.weight"
        state0[key] = alpha * state0[key] + (1 - alpha) * state1[key]

    tmp_model.load_state_dict(state0)
|
157 |
+
|
158 |
+
|
159 |
+
def avg_embedding_layer(model0, model1, tmp_model, alpha=0.5):
    """
    Interpolate the token embeddings of two models and load into ``tmp_model``.

    ``model.embed_tokens.weight`` becomes
    ``alpha * model0 + (1 - alpha) * model1``. Every other entry is taken from
    ``model0`` unchanged.
    """
    embed_key = "model.embed_tokens.weight"

    state0 = model0.state_dict()
    state1 = model1.state_dict()

    state0[embed_key] = alpha * state0[embed_key] + (1 - alpha) * state1[embed_key]

    tmp_model.load_state_dict(state0)
|
169 |
+
|
170 |
+
|
171 |
+
def avg_output_layer(model0, model1, tmp_model, alpha=0.5):
    """
    Interpolate the LM head and final norm of two models and load into ``tmp_model``.

    Both ``lm_head.weight`` and ``model.norm.weight`` become
    ``alpha * model0 + (1 - alpha) * model1``. Every other entry is taken from
    ``model0`` unchanged.
    """
    state0 = model0.state_dict()
    state1 = model1.state_dict()

    for key in ("lm_head.weight", "model.norm.weight"):
        state0[key] = alpha * state0[key] + (1 - alpha) * state1[key]

    tmp_model.load_state_dict(state0)
|
183 |
+
|
184 |
+
|
185 |
+
def avg_model(model0, model1, tmp_model, alpha=0.5, n_blocks=32, attn=True, emb=True):
    """
    Load into ``tmp_model`` the interpolation ``alpha * model0 + (1 - alpha) * model1``.

    Transformer blocks ``0 .. n_blocks - 1`` are always interpolated (attention
    projections only when ``attn`` is exactly ``True``). The input embeddings
    and the output layer are interpolated only when ``emb`` is exactly
    ``True``; otherwise they are copied from ``model0``.
    """
    # Work on a private copy of model1 so that loading weights into tmp_model
    # cannot affect the values read from it.
    model1 = copy.deepcopy(model1)

    include_embeddings = emb is True

    # Seed tmp_model with model0's weights (embeddings averaged if requested).
    if include_embeddings:
        avg_embedding_layer(model0, model1, tmp_model, alpha=alpha)
    else:
        tmp_model.load_state_dict(model0.state_dict())

    # From here on tmp_model plays the role of model0 and is updated in place.
    for block_index in range(n_blocks):
        avg_transformer_block(tmp_model, model1, block_index, tmp_model, alpha=alpha, attn=attn)

    if include_embeddings:
        avg_output_layer(tmp_model, model1, tmp_model, alpha=alpha)
|
196 |
+
|
197 |
+
|
198 |
+
def get_mlp_weights(model, i):
    """Return the gate projection weight of block ``i`` from ``model``'s state dict."""
    gate_key = f"model.layers.{i}.mlp.gate_proj.weight"
    return model.state_dict()[gate_key]
|
200 |
+
|
201 |
+
|
202 |
+
def get_emb_weights(model):
    """Return the token embedding weight from ``model``'s state dict."""
    state = model.state_dict()
    return state["model.embed_tokens.weight"]
|
204 |
+
|
205 |
+
|
206 |
+
def rotate_model(model, num_layers=32, hidden_dim=4096):
    """
    Rotate the hidden space of ``model`` in place by a random orthogonal matrix.

    The model is moved to CUDA. A random ``hidden_dim x hidden_dim`` orthogonal
    matrix is drawn with ``scipy.stats.ortho_group`` and cast to bfloat16.
    Every norm weight (per-block and final) is replaced by ones, with its
    original gain folded into the projection that consumes the normalised
    activations, as ``W @ diag(gain) @ rotation``. Projections that write to
    the hidden stream (o_proj, down_proj) are left-multiplied by
    ``rotation.T``.

    Args:
        model: Model with Llama-style state-dict keys; modified in place.
        num_layers: Number of transformer blocks to rewrite.
        hidden_dim: Size of the hidden dimension.
    """
    model.to("cuda")

    rotation = ortho_group.rvs(dim=hidden_dim)
    rotation = torch.tensor(rotation, dtype=torch.bfloat16).to("cuda")

    # `weights` keeps the original tensors; `weights_rotated` collects the new ones.
    weights = model.state_dict()
    weights_rotated = model.state_dict()

    def _fold_gain_and_rotate(weight_key, gain_key):
        # Reads from the hidden stream: absorb the norm gain, then rotate inputs.
        return weights[weight_key] @ torch.diag(weights[gain_key]) @ rotation

    def _rotate_output(weight_key):
        # Writes to the hidden stream: rotate the output rows.
        return rotation.T @ weights[weight_key]

    weights_rotated["model.embed_tokens.weight"] = weights["model.embed_tokens.weight"] @ rotation

    for i in range(num_layers):
        layer = f"model.layers.{i}"
        attn_gain = f"{layer}.input_layernorm.weight"
        mlp_gain = f"{layer}.post_attention_layernorm.weight"

        weights_rotated[attn_gain] = torch.ones(hidden_dim)
        weights_rotated[mlp_gain] = torch.ones(hidden_dim)

        for proj in ("q_proj", "k_proj", "v_proj"):
            key = f"{layer}.self_attn.{proj}.weight"
            weights_rotated[key] = _fold_gain_and_rotate(key, attn_gain)
        o_key = f"{layer}.self_attn.o_proj.weight"
        weights_rotated[o_key] = _rotate_output(o_key)

        for proj in ("gate_proj", "up_proj"):
            key = f"{layer}.mlp.{proj}.weight"
            weights_rotated[key] = _fold_gain_and_rotate(key, mlp_gain)
        down_key = f"{layer}.mlp.down_proj.weight"
        weights_rotated[down_key] = _rotate_output(down_key)

    weights_rotated["model.norm.weight"] = torch.ones(hidden_dim)
    weights_rotated["lm_head.weight"] = _fold_gain_and_rotate(
        "lm_head.weight", "model.norm.weight"
    )

    model.load_state_dict(weights_rotated)
|
model-tracing/tracing/utils/olmo/model.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
|
3 |
+
|
4 |
+
def permute_model(model, tmp_model, mlp_permutation, emb_permutation, n_blocks=32):
    """
    Load into ``tmp_model`` a permuted copy of ``model``'s weights (OLMo variant).

    The embedding layer is permuted from ``model`` into ``tmp_model``. The
    first ``n_blocks`` transformer blocks and the LM head are then permuted in
    place on ``tmp_model``.
    """
    permute_embedding_layer(model, tmp_model, emb_permutation)

    block_index = 0
    while block_index < n_blocks:
        permute_transformer_block(
            tmp_model, block_index, tmp_model, mlp_permutation, emb_permutation
        )
        block_index += 1

    permute_output_layer(tmp_model, tmp_model, emb_permutation)
|
9 |
+
|
10 |
+
|
11 |
+
def permute_transformer_block(model, i, tmp_model, mlp_permutation, emb_permutation):
    """
    Permute block ``i`` of ``model`` and load the result into ``tmp_model``.

    ``emb_permutation`` reorders the hidden dimension: columns of q/k/v, gate
    and up projections, rows of o and down projections. ``mlp_permutation``
    reorders the MLP units: rows of gate and up projections, columns of the
    down projection. No layer-norm weights are touched in this variant.
    """
    state = model.state_dict()
    block = f"model.layers.{i}."

    # Attention: q/k/v read from the hidden stream (columns), o writes to it (rows).
    for proj in ("q_proj", "k_proj", "v_proj"):
        key = f"{block}self_attn.{proj}.weight"
        state[key] = state[key][:, emb_permutation]
    o_key = f"{block}self_attn.o_proj.weight"
    state[o_key] = state[o_key][emb_permutation]

    # MLP: gate/up are (mlp, hidden), down is (hidden, mlp).
    for proj in ("gate_proj", "up_proj"):
        key = f"{block}mlp.{proj}.weight"
        state[key] = state[key][mlp_permutation][:, emb_permutation]
    down_key = f"{block}mlp.down_proj.weight"
    state[down_key] = state[down_key][:, mlp_permutation][emb_permutation]

    tmp_model.load_state_dict(state)
|
48 |
+
|
49 |
+
|
50 |
+
def permute_embedding_layer(model, tmp_model, emb_permutation):
    """
    Reorder the columns of ``model``'s token embeddings and load into ``tmp_model``.

    Every other entry of the state dict is loaded unchanged.
    """
    embed_key = "model.embed_tokens.weight"

    state = model.state_dict()
    state[embed_key] = state[embed_key][:, emb_permutation]
    tmp_model.load_state_dict(state)
|
55 |
+
|
56 |
+
|
57 |
+
def permute_output_layer(model, tmp_model, emb_permutation):
    """
    Reorder the columns of ``model``'s LM head and load into ``tmp_model``.

    Only ``lm_head.weight`` is changed in this variant.
    """
    head_key = "lm_head.weight"

    state = model.state_dict()
    state[head_key] = state[head_key][:, emb_permutation]
    tmp_model.load_state_dict(state)
|
62 |
+
|
63 |
+
|
64 |
+
def permute_mlp_block(model, i, tmp_model, mlp_permutation):
    """
    Reorder the MLP units of block ``i`` and load the result into ``tmp_model``.

    Rows of the gate and up projections and columns of the down projection
    are permuted. The hidden dimension is left as is.
    """
    state = model.state_dict()
    mlp_prefix = f"model.layers.{i}.mlp."

    for proj in ("gate_proj", "up_proj"):
        key = f"{mlp_prefix}{proj}.weight"
        state[key] = state[key][mlp_permutation]

    down_key = f"{mlp_prefix}down_proj.weight"
    state[down_key] = state[down_key][:, mlp_permutation]

    tmp_model.load_state_dict(state)
|
78 |
+
|
79 |
+
|
80 |
+
def avg_mlp_block(model0, model1, i, tmp_model, alpha=0.5):
    """
    Interpolate the MLP weights of block ``i`` and load into ``tmp_model``.

    Gate, up and down projections become
    ``alpha * model0 + (1 - alpha) * model1``. Every other entry is taken from
    ``model0`` unchanged.
    """
    state0 = model0.state_dict()
    state1 = model1.state_dict()

    for proj in ("gate_proj", "up_proj", "down_proj"):
        key = f"model.layers.{i}.mlp.{proj}.weight"
        state0[key] = alpha * state0[key] + (1 - alpha) * state1[key]

    tmp_model.load_state_dict(state0)
|
98 |
+
|
99 |
+
|
100 |
+
def avg_transformer_block(model0, model1, i, tmp_model, alpha=0.5, attn=True):
    """Interpolate the weights of transformer layer ``i`` into ``tmp_model``.

    The MLP projections (gate/up/down) are always interpolated as
    ``alpha * model0 + (1 - alpha) * model1``. The attention projections
    (q/k/v/o) are interpolated only when ``attn`` is exactly ``True``.
    Everything else is copied from ``model0``.

    Args:
        model0: Model weighted by ``alpha``; also supplies all other weights.
        model1: Model weighted by ``1 - alpha``.
        i: Index of the transformer layer to interpolate.
        tmp_model: Destination model; receives the weights via ``load_state_dict``.
        alpha: Interpolation weight for ``model0``.
        attn: Attention projections are included only for the literal ``True``.
    """
    state0 = model0.state_dict()
    state1 = model1.state_dict()
    layer_prefix = "model.layers." + str(i)

    suffixes = []
    if attn is True:
        suffixes.extend(
            [
                ".self_attn.q_proj.weight",
                ".self_attn.k_proj.weight",
                ".self_attn.v_proj.weight",
                ".self_attn.o_proj.weight",
            ]
        )
    suffixes.extend(
        [
            ".mlp.gate_proj.weight",
            ".mlp.up_proj.weight",
            ".mlp.down_proj.weight",
        ]
    )

    for suffix in suffixes:
        key = layer_prefix + suffix
        state0[key] = alpha * state0[key] + (1 - alpha) * state1[key]

    tmp_model.load_state_dict(state0)
|
136 |
+
|
137 |
+
|
138 |
+
def avg_embedding_layer(model0, model1, tmp_model, alpha=0.5):
    """Interpolate the token-embedding matrix and load the result into ``tmp_model``.

    ``model.embed_tokens.weight`` becomes ``alpha * model0 + (1 - alpha) * model1``;
    all other weights are copied from ``model0``.

    Args:
        model0: Model weighted by ``alpha``; also supplies all other weights.
        model1: Model weighted by ``1 - alpha``.
        tmp_model: Destination model; receives the weights via ``load_state_dict``.
        alpha: Interpolation weight for ``model0``.
    """
    emb_key = "model.embed_tokens.weight"
    state0 = model0.state_dict()
    state1 = model1.state_dict()

    state0[emb_key] = alpha * state0[emb_key] + (1 - alpha) * state1[emb_key]

    tmp_model.load_state_dict(state0)
|
148 |
+
|
149 |
+
|
150 |
+
def avg_output_layer(model0, model1, tmp_model, alpha=0.5):
    """Interpolate the LM head and final norm and load the result into ``tmp_model``.

    ``lm_head.weight`` and ``model.norm.weight`` each become
    ``alpha * model0 + (1 - alpha) * model1``; all other weights are copied
    from ``model0``.

    Args:
        model0: Model weighted by ``alpha``; also supplies all other weights.
        model1: Model weighted by ``1 - alpha``.
        tmp_model: Destination model; receives the weights via ``load_state_dict``.
        alpha: Interpolation weight for ``model0``.
    """
    state0 = model0.state_dict()
    state1 = model1.state_dict()

    for key in ("lm_head.weight", "model.norm.weight"):
        state0[key] = alpha * state0[key] + (1 - alpha) * state1[key]

    tmp_model.load_state_dict(state0)
|
162 |
+
|
163 |
+
|
164 |
+
def avg_model(model0, model1, tmp_model, alpha=0.5, n_blocks=32, attn=True, emb=True):
    """Build an interpolation of ``model0`` and ``model1`` inside ``tmp_model``.

    ``tmp_model`` is first seeded from ``model0`` (with the embedding
    interpolated when ``emb`` is exactly ``True``). Each of the first
    ``n_blocks`` transformer blocks of ``tmp_model`` is then interpolated
    against a deep copy of ``model1``. When ``emb`` is exactly ``True`` the
    output layer is interpolated last.

    Args:
        model0: Model weighted by ``alpha``.
        model1: Model weighted by ``1 - alpha``; deep-copied before use.
        tmp_model: Model that is overwritten with the interpolated weights.
        alpha: Interpolation weight for ``model0``.
        n_blocks: Number of transformer blocks to interpolate.
        attn: Forwarded to ``avg_transformer_block``.
        emb: Embedding and output layers are interpolated only for the literal ``True``.
    """
    other = copy.deepcopy(model1)
    interpolate_embeddings = emb is True

    if interpolate_embeddings:
        avg_embedding_layer(model0, other, tmp_model, alpha=alpha)
    else:
        tmp_model.load_state_dict(model0.state_dict())

    # tmp_model is both the source and the destination, so each call keeps
    # the results of the previous ones.
    for block_idx in range(n_blocks):
        avg_transformer_block(tmp_model, other, block_idx, tmp_model, alpha=alpha, attn=attn)

    if interpolate_embeddings:
        avg_output_layer(tmp_model, other, tmp_model, alpha=alpha)
|
175 |
+
|
176 |
+
|
177 |
+
def get_mlp_weights(model, i):
    """Return the ``intermediate.dense`` weight of layer ``i`` from ``model``'s state dict."""
    # NOTE(review): the other helpers in this file address MLP weights as
    # "model.layers.<i>.mlp.{gate,up,down}_proj.weight"; confirm that
    # ".intermediate.dense.weight" is the intended key for the models used here.
    key = "model.layers." + str(i) + ".intermediate.dense.weight"
    return model.state_dict()[key]
|
179 |
+
|
180 |
+
|
181 |
+
def get_emb_weights(model):
    """Return the ``model.embed_tokens.weight`` entry of ``model``'s state dict."""
    state = model.state_dict()
    return state["model.embed_tokens.weight"]
|
model-tracing/tracing/utils/plot_metrics.py
ADDED
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import yaml
|
2 |
+
from yaml import Loader
|
3 |
+
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
6 |
+
import numpy as np
|
7 |
+
import os
|
8 |
+
from scipy.stats import chi2
|
9 |
+
|
10 |
+
# Maps each Hugging Face model id to an integer group id. get_dict_ft()
# treats two models as related exactly when their group ids are equal.
base = {
    "lmsys/vicuna-7b-v1.5": 1,
    "codellama/CodeLlama-7b-hf": 1,
    "codellama/CodeLlama-7b-Python-hf": 1,
    "codellama/CodeLlama-7b-Instruct-hf": 1,
    "EleutherAI/llemma_7b": 1,
    "microsoft/Orca-2-7b": 1,
    "oh-yeontaek/llama-2-7B-LoRA-assemble": 1,
    "lvkaokao/llama2-7b-hf-instruction-lora": 1,
    "NousResearch/Nous-Hermes-llama-2-7b": 1,
    "lmsys/vicuna-7b-v1.1": 0,
    "yahma/llama-7b-hf": 0,
    "Salesforce/xgen-7b-4k-base": 2,
    "EleutherAI/llemma_7b_muinstruct_camelmath": 1,
    "AlfredPros/CodeLlama-7b-Instruct-Solidity": 1,
    "meta-llama/Llama-2-7b-hf": 1,
    "LLM360/Amber": 3,
    "LLM360/AmberChat": 3,
    "openlm-research/open_llama_7b": 4,
    "openlm-research/open_llama_7b_v2": 5,
    "ibm-granite/granite-7b-base": 6,
    "ibm-granite/granite-7b-instruct": 6,
}
|
33 |
+
|
34 |
+
# Same model -> group-id pairs as `base`, listed in ascending group-id order.
# The key order matters to callers that use the keys as plot axis labels.
base_ordered = {
    # group 0
    "yahma/llama-7b-hf": 0,
    "lmsys/vicuna-7b-v1.1": 0,
    # group 1
    "meta-llama/Llama-2-7b-hf": 1,
    "lmsys/vicuna-7b-v1.5": 1,
    "codellama/CodeLlama-7b-hf": 1,
    "codellama/CodeLlama-7b-Python-hf": 1,
    "codellama/CodeLlama-7b-Instruct-hf": 1,
    "AlfredPros/CodeLlama-7b-Instruct-Solidity": 1,
    "EleutherAI/llemma_7b": 1,
    "EleutherAI/llemma_7b_muinstruct_camelmath": 1,
    "microsoft/Orca-2-7b": 1,
    "oh-yeontaek/llama-2-7B-LoRA-assemble": 1,
    "lvkaokao/llama2-7b-hf-instruction-lora": 1,
    "NousResearch/Nous-Hermes-llama-2-7b": 1,
    # group 2
    "Salesforce/xgen-7b-4k-base": 2,
    # group 3
    "LLM360/Amber": 3,
    "LLM360/AmberChat": 3,
    # group 4
    "openlm-research/open_llama_7b": 4,
    # group 5
    "openlm-research/open_llama_7b_v2": 5,
    # group 6
    "ibm-granite/granite-7b-base": 6,
    "ibm-granite/granite-7b-instruct": 6,
}
|
57 |
+
|
58 |
+
# Four-character code per model, padded with "-".
# NOTE(review): presumably each code is the model's path in the fine-tuning
# tree (first letter = root family, later letters = successive descendants);
# nothing in this file reads `tree`, so confirm against its users.
tree = {
    "yahma/llama-7b-hf": "A---",
    "lmsys/vicuna-7b-v1.1": "AA--",
    "meta-llama/Llama-2-7b-hf": "B---",
    "lmsys/vicuna-7b-v1.5": "BA--",
    "codellama/CodeLlama-7b-hf": "BB--",
    "codellama/CodeLlama-7b-Python-hf": "BBA-",
    "codellama/CodeLlama-7b-Instruct-hf": "BBB-",
    "AlfredPros/CodeLlama-7b-Instruct-Solidity": "BBBA",
    "EleutherAI/llemma_7b": "BBC-",
    "EleutherAI/llemma_7b_muinstruct_camelmath": "BBCA",
    "microsoft/Orca-2-7b": "BC--",
    "oh-yeontaek/llama-2-7B-LoRA-assemble": "BD--",
    "lvkaokao/llama2-7b-hf-instruction-lora": "BE--",
    "NousResearch/Nous-Hermes-llama-2-7b": "BF--",
    "Salesforce/xgen-7b-4k-base": "C---",
    "LLM360/Amber": "D---",
    "LLM360/AmberChat": "DA--",
    "openlm-research/open_llama_7b": "E---",
    "openlm-research/open_llama_7b_v2": "F---",
    "ibm-granite/granite-7b-base": "G---",
    "ibm-granite/granite-7b-instruct": "GA--",
}
|
81 |
+
|
82 |
+
|
83 |
+
def get_dict_ft(flat_model_path):
    """Build the ground-truth "related pair" lookup for every pair of models.

    Args:
        flat_model_path: Path to a YAML file that loads to a sequence of
            model ids, each of which must be a key of ``base``.

    Returns:
        dict mapping ``"<a>_AND_<b>"`` (with ``/`` replaced by ``-`` in both
        ids, ``a`` listed before ``b`` in the file) to ``True`` when both
        models have the same ``base`` group id, else ``False``.

    Raises:
        KeyError: If a listed model id is missing from ``base``.
    """
    from itertools import combinations

    # NOTE(review): yaml.load with the full Loader can construct arbitrary
    # Python objects; only use this on trusted config files (or switch to
    # yaml.safe_load if the config is plain data).
    # The context manager closes the handle the original left open.
    with open(flat_model_path, "r") as config_file:
        model_paths = yaml.load(config_file, Loader=Loader)

    dict_ft = {}
    # combinations() yields every (i < j) pair in file order.
    for model_a, model_b in combinations(model_paths, 2):
        job_id = model_a.replace("/", "-") + "_AND_" + model_b.replace("/", "-")
        dict_ft[job_id] = base[model_a] == base[model_b]

    return dict_ft
|
98 |
+
|
99 |
+
|
100 |
+
def get_statistic_from_file(filename):
    """Extract the "non-aligned test stat" value from a run log.

    Only lines containing both ``"Namespace"`` and ``"non-aligned test stat"``
    are considered. On such a line the text between the first ``:`` and the
    first ``,`` following the label is taken, stripped of spaces, ``(``,
    ``:`` and the word ``tensor``, and converted to ``float``. If several
    lines match, the last one wins.

    Args:
        filename: Path of the log file to read.

    Returns:
        The parsed statistic, or ``np.nan`` when no line matches.

    Raises:
        ValueError: If the extracted text is not a valid float.
    """
    label = "non-aligned test stat"
    stat = np.nan

    # Context manager so the handle is closed (the original leaked it).
    with open(filename, "r") as file:
        for line in file:
            if "Namespace" not in line or label not in line:
                continue

            start = line.find(label)
            raw = line[line.find(":", start) : line.find(",", start)]
            for junk in (" ", "(", ":", "tensor"):
                raw = raw.replace(junk, "")
            stat = float(raw)

    return stat
|
122 |
+
|
123 |
+
|
124 |
+
def get_l2_stat_from_file(filename):
    """Return the last value logged as a 4-character number followed by a space.

    A line qualifies when its fifth character is a space; its first four
    characters are then taken as the value. Only the last qualifying line
    is converted to ``float``.

    Args:
        filename: Path of the log file to read.

    Returns:
        The parsed value, or ``np.nan`` when no line qualifies (callers in
        this file test the result with ``np.isnan``).

    Raises:
        ValueError: If the last qualifying prefix is not a valid float.
    """
    last_prefix = None

    # Context manager so the handle is closed (the original leaked it).
    with open(filename, "r") as file:
        for line in file:
            # Strictly more than 4 characters are needed before reading
            # line[4]; the original `>= 4` raised IndexError on a
            # 4-character line.
            if len(line) > 4 and line[4] == " ":
                last_prefix = line[:4]

    if last_prefix is None:
        return np.nan
    return float(last_prefix)
|
135 |
+
|
136 |
+
|
137 |
+
def get_layer_statistic_from_file(filename, layer):
    """Read the per-layer statistic logged for ``layer`` in a run log.

    A line belongs to ``layer`` when it starts with ``"<layer> "``. Its value
    is the text from the first ``"0."`` to the end of the line, converted to
    ``float``. Any matching line that contains the letter ``e`` is recorded
    as ``0.0`` instead (presumably to treat scientific-notation values as
    zero; confirm against the log format). If several lines match, the last
    one wins.

    Args:
        filename: Path of the log file to read.
        layer: Layer index whose line should be parsed.

    Returns:
        The parsed value, or ``np.nan`` when no line matches.

    Raises:
        ValueError: If a matching line has no parsable ``"0."`` value.
    """
    prefix = str(layer) + " "  # loop-invariant, so built once
    stat = np.nan

    # Context manager so the handle is closed (the original leaked it).
    with open(filename, "r") as file:
        for line in file:
            if not line.startswith(prefix):
                continue
            stat = 0 if "e" in line else line[line.find("0.") :]
            stat = float(stat)

    return stat
|
153 |
+
|
154 |
+
|
155 |
+
def plot_statistic_scatter(results_path, dict_ft, plot_path):
    """Scatter the per-pair statistic against the related/unrelated label.

    Every entry of ``results_path`` is treated as a ``<pair>.out`` log. Pairs
    whose name contains ``"huggyllama"`` are skipped, as are pairs whose
    statistic is NaN. The figure is written to ``<plot_path>.png``.

    Args:
        results_path: Directory containing the ``.out`` log files.
        dict_ft: Mapping from pair name (file name without ``.out``) to a
            truthy value when the pair is related.
        plot_path: Output path without the ``.png`` extension.

    Raises:
        KeyError: If a log file's pair name is missing from ``dict_ft``.
    """
    x = []
    y = []

    for file in os.listdir(results_path):
        models = file[: file.find(".out")]
        if "huggyllama" in models:
            continue
        print(models)
        ft = int(dict_ft[models])
        # Parse each log once; the original re-read the file a second time
        # just to append the same value.
        stat = get_l2_stat_from_file(results_path + "/" + file)
        if not np.isnan(stat):
            x.append(stat)
            y.append(ft)

    plt.figure(figsize=(10, 1))

    plt.scatter(x, y, s=8)

    plt.xlabel("$p$-value")
    plt.ylabel("Fine-tuned")
    plt.ylim(-0.5, 1.5)
    plot_filename = f"{plot_path}.png"

    plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
|
186 |
+
|
187 |
+
|
188 |
+
def plot_statistic_grid(results_path, dict_base, title, plot_path, decimals, log=False):
    """Render a model-by-model heat map of the per-pair test statistic.

    For every ordered pair of keys of ``dict_base`` the log
    ``<a>_AND_<b>.out`` is looked up in ``results_path``; when it exists its
    statistic is rounded and written symmetrically into the grid. Missing
    pairs stay NaN. Each cell is annotated with its value (exact zeros are
    shown as an epsilon) and the figure is saved to ``<plot_path>.png``.

    Args:
        results_path: Directory containing the ``.out`` log files.
        dict_base: Mapping whose keys, in order, are the model ids on both axes.
        title: Title of the figure.
        plot_path: Output path without the ``.png`` extension.
        decimals: Number of decimals the statistic is rounded to.
        log: When truthy, the natural log of the statistic is plotted.
    """
    models = list(dict_base.keys())
    print(models)
    n_models = len(models)

    data = np.full((n_models, n_models), np.nan)

    for i, model_a in enumerate(models):
        for j, model_b in enumerate(models):
            job_id = model_a.replace("/", "-") + "_AND_" + model_b.replace("/", "-") + ".out"
            log_path = results_path + "/" + job_id

            if not os.path.exists(log_path):
                continue
            print(job_id)

            stat = get_statistic_from_file(log_path)
            if log:
                stat = np.log(stat)

            data[i][j] = np.round(stat, decimals=decimals)
            data[j][i] = data[i][j]

    fig, ax = plt.subplots()
    fig.set_size_inches(20, 20)
    im = ax.imshow(data, cmap="viridis")

    _ = make_axes_locatable(ax)
    _ = ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # One tick per model on both axes, labelled with the model id.
    tick_positions = np.arange(n_models)
    ax.set_xticks(tick_positions, labels=models)
    ax.set_yticks(np.arange(n_models), labels=models)

    # Slant the x labels so long model ids stay readable.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Annotate every cell with its value.
    for i in range(n_models):
        for j in range(n_models):
            label = "$\\varepsilon$" if data[i][j] == 0.0 else str(data[i][j])
            _ = ax.text(j, i, label, ha="center", va="center", color="w")

    ax.set_title(title)
    fig.tight_layout()
    plot_filename = f"{plot_path}.png"

    plt.savefig(plot_filename, dpi=500, bbox_inches="tight")
    plt.close()
|
249 |
+
|
250 |
+
|
251 |
+
def plot_statistic_scatter_layer(results_path, dict_ft, plot_path, layer):
    """Scatter one layer's statistic against the related/unrelated label.

    Every entry of ``results_path`` is treated as a ``<pair>.out`` log. Pairs
    whose name contains ``"huggyllama"`` are skipped, as are pairs with no
    value for ``layer``. The figure is written to ``<plot_path>.png``.

    Args:
        results_path: Directory containing the ``.out`` log files.
        dict_ft: Mapping from pair name (file name without ``.out``) to a
            truthy value when the pair is related.
        plot_path: Output path without the ``.png`` extension.
        layer: Layer index whose statistic is plotted.

    Raises:
        KeyError: If a log file's pair name is missing from ``dict_ft``.
    """
    x = []
    y = []

    for file in os.listdir(results_path):
        models = file[: file.find(".out")]
        if "huggyllama" in models:
            continue
        print(models)
        ft = int(dict_ft[models])
        # Parse each log once; the original re-read the file a second time
        # just to append the same value.
        stat = get_layer_statistic_from_file(results_path + "/" + file, layer)
        if not np.isnan(stat):
            x.append(ft)
            y.append(stat)

    plt.figure(figsize=(8, 6))

    plt.scatter(x, y, s=2)

    plt.xlabel("Fine tuned")
    plt.ylabel("Test statistic")
    plot_filename = f"{plot_path}.png"

    plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
|
279 |
+
|
280 |
+
|
281 |
+
def plot_statistic_grid_layer(
    results_path, dict_base, title, plot_path, decimals, layer, log=False
):
    """Render a model-by-model heat map of one layer's statistic.

    For every ordered pair of keys of ``dict_base`` the log
    ``<a>_AND_<b>.out`` is looked up in ``results_path``; when it exists the
    value for ``layer`` is rounded and written symmetrically into the grid.
    Missing pairs stay NaN. Each cell is annotated with its value and the
    figure is saved to ``<plot_path>.png``.

    Args:
        results_path: Directory containing the ``.out`` log files.
        dict_base: Mapping whose keys, in order, are the model ids on both axes.
        title: Title of the figure.
        plot_path: Output path without the ``.png`` extension.
        decimals: Number of decimals the statistic is rounded to.
        layer: Layer index whose statistic is plotted.
        log: When truthy, the natural log of the statistic is plotted.
    """
    models = list(dict_base.keys())
    print(models)
    n_models = len(models)

    data = np.full((n_models, n_models), np.nan)

    for i, model_a in enumerate(models):
        for j, model_b in enumerate(models):
            job_id = model_a.replace("/", "-") + "_AND_" + model_b.replace("/", "-") + ".out"
            log_path = results_path + "/" + job_id

            if not os.path.exists(log_path):
                continue

            stat = get_layer_statistic_from_file(log_path, layer)
            if log:
                stat = np.log(stat)

            data[i][j] = np.round(stat, decimals=decimals)
            data[j][i] = data[i][j]

    fig, ax = plt.subplots()
    fig.set_size_inches(20, 20)
    im = ax.imshow(data)

    _ = make_axes_locatable(ax)
    _ = ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # One tick per model on both axes, labelled with the model id.
    ax.set_xticks(np.arange(n_models), labels=models)
    ax.set_yticks(np.arange(n_models), labels=models)

    # Slant the x labels so long model ids stay readable.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Annotate every cell with its value.
    for row in range(n_models):
        for col in range(n_models):
            _ = ax.text(col, row, data[row, col], ha="center", va="center", color="w")

    ax.set_title(title)
    fig.tight_layout()
    plot_filename = f"{plot_path}.png"

    plt.savefig(plot_filename, dpi=500, bbox_inches="tight")
    plt.close()
|
332 |
+
|
333 |
+
|
334 |
+
def plot_histogram(results_path, dict_ft, plot_path):
    """Plot histograms of the test statistic for unrelated and related pairs.

    Every entry of ``results_path`` is treated as a ``<pair>.out`` log; NaN
    statistics are dropped. Unrelated pairs are drawn in blue and related
    pairs in green, both with 20 bins over [0, 1]. The figure is written to
    ``<plot_path>.png``.

    Args:
        results_path: Directory containing the ``.out`` log files.
        dict_ft: Mapping from pair name (file name without ``.out``) to a
            truthy value when the pair is related.
        plot_path: Output path without the ``.png`` extension.
    """
    indp = []
    not_indp = []

    for file in os.listdir(results_path):
        models = file[: file.find(".out")]
        print(models)
        ft = int(dict_ft[models])
        stat = get_statistic_from_file(results_path + "/" + file)
        if np.isnan(stat):
            continue
        bucket = not_indp if ft else indp
        bucket.append(stat)

    plt.figure(figsize=(8, 6))

    for values, colour in ((indp, "blue"), (not_indp, "green")):
        plt.hist(values, bins=20, range=(0, 1), color=colour)

    plt.xlabel("Test statistic value")
    plt.ylabel("Count")
    plot_filename = f"{plot_path}.png"

    plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
|
362 |
+
|
363 |
+
|
364 |
+
def fisher(pvalues):
    """Combine p-values with Fisher's method, ignoring NaN entries.

    The statistic is ``-2 * sum(log(p))`` over the non-NaN p-values and the
    result is the chi-squared survival function at that statistic with
    ``2 * k`` degrees of freedom, ``k`` being the number of non-NaN entries.

    Args:
        pvalues: Iterable of p-values; NaN entries are skipped.

    Returns:
        The combined p-value.
    """
    usable = [pvalue for pvalue in pvalues if not np.isnan(pvalue)]

    statistic = 0
    for pvalue in usable:
        statistic -= 2 * np.log(pvalue)

    return chi2.sf(statistic, df=2 * len(usable))
|
373 |
+
|
374 |
+
|
375 |
+
def plot_statistic_scatter_all_layers(results_path, dict_ft, plot_path, num_layers=32):
    """Scatter every pair's per-layer value against the layer index.

    Every entry of ``results_path`` is treated as a ``<pair>.out`` log.
    Related pairs are drawn in red and unrelated pairs in blue; layers with
    no value for a pair are skipped. The figure is written to
    ``<plot_path>.png``.

    Args:
        results_path: Directory containing the ``.out`` log files.
        dict_ft: Mapping from pair name (file name without ``.out``) to a
            truthy value when the pair is related.
        plot_path: Output path without the ``.png`` extension.
        num_layers: Number of layers to scan (layer indices
            ``0 .. num_layers - 1``); defaults to the previously hard-coded 32.

    Raises:
        KeyError: If a log file's pair name is missing from ``dict_ft``.
    """
    x = []
    y = []
    c = []

    dir_list = os.listdir(results_path)

    # The original ran the per-file loop twice per layer and parsed each
    # file twice per pass, adding every point twice at the same coordinates
    # and colour. One pass with one parse draws the same picture.
    for layer in range(num_layers):
        for file in dir_list:
            models = file[: file.find(".out")]
            print(models)
            ft = int(dict_ft[models])
            stat = get_layer_statistic_from_file(results_path + "/" + file, layer)
            if np.isnan(stat):
                continue
            x.append(layer)
            y.append(stat)
            c.append("r" if ft else "b")

    plt.figure(figsize=(8, 6))

    plt.scatter(x, y, s=1.5, c=c)

    plt.xlabel("Layer")
    plt.ylabel("p-value")
    plot_filename = f"{plot_path}.png"

    plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
|
423 |
+
|
424 |
+
|
425 |
+
def plot_pvalue(results_path, dict_ft, plot_path, num_layers=32):
    """Plot the empirical CDF of per-layer p-values over unrelated model pairs.

    Every entry of ``results_path`` is treated as a ``<pair>.out`` log. Pairs
    whose name contains ``"huggyllama"`` and pairs marked as related in
    ``dict_ft`` are skipped; NaN values are dropped. For each threshold in
    ``0, 0.001, ..., 0.999`` the fraction of collected p-values strictly
    below it is plotted, and the figure is written to ``<plot_path>.png``.

    Args:
        results_path: Directory containing the ``.out`` log files.
        dict_ft: Mapping from pair name (file name without ``.out``) to a
            truthy value when the pair is related.
        plot_path: Output path without the ``.png`` extension.
        num_layers: Number of layers to scan (layer indices
            ``0 .. num_layers - 1``); defaults to the previously hard-coded 32.

    Raises:
        KeyError: If a log file's pair name is missing from ``dict_ft``.
        ValueError: If no p-value was collected, so there is nothing to plot.
    """
    pvalues = []

    dir_list = os.listdir(results_path)

    for layer in range(num_layers):
        for file in dir_list:
            models = file[: file.find(".out")]
            if "huggyllama" in models:
                continue
            print(models)
            ft = int(dict_ft[models])
            # The original tested `ft is True`, which is never true for an
            # int, so related pairs were never filtered out.
            if ft:
                continue
            stat = get_layer_statistic_from_file(results_path + "/" + file, layer)
            if not np.isnan(stat):
                pvalues.append(stat)

    print(pvalues)
    print(len(pvalues))

    if not pvalues:
        # The original failed here with an opaque ZeroDivisionError.
        raise ValueError(f"no p-values found under {results_path!r}; nothing to plot")

    x = np.arange(0, 1, step=0.001)
    observed = np.asarray(pvalues)
    # Fraction of observed p-values strictly below each threshold.
    y = [np.count_nonzero(observed < threshold) / observed.size for threshold in x]

    plt.figure(figsize=(8, 6))

    plt.plot(x, y, ".-")

    plot_filename = f"{plot_path}.png"

    plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
|
469 |
+
|
470 |
+
|
471 |
+
if __name__ == "__main__":
    # Script entry point: draw the empirical CDF of per-layer p-values for
    # one set of run logs. The paths are specific to the authors' cluster;
    # edit them before running elsewhere. The other plot_* helpers in this
    # module take the same (results_path, dict_ft / dict_base, plot_path)
    # style of arguments and can be called from here in the same way.
    flat_config_path = "/nlp/u/salzhu/model-tracing/config/llama_flat.yaml"
    logs_path = "/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs"
    output_path = "/nlp/u/salzhu/test_statistic_plots/mlp_sp_final"

    dict_ft = get_dict_ft(flat_config_path)

    plot_pvalue(
        logs_path,
        dict_ft,
        output_path,
    )
|
model-tracing/tracing/utils/utils.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import numpy as np
|
4 |
+
import random
|
5 |
+
import os
|
6 |
+
from scipy.stats import chi2
|
7 |
+
|
8 |
+
|
9 |
+
def manual_seed(seed, fix_cudnn=True):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with ``seed``.

    Also sets the ``PYTHONHASHSEED`` environment variable to ``str(seed)``.

    Args:
        seed: Seed passed to every random number generator.
        fix_cudnn: When truthy, make cuDNN deterministic and turn off its
            benchmark mode.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    os.environ["PYTHONHASHSEED"] = str(seed)

    if fix_cudnn:
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True  # noqa
        cudnn.benchmark = False  # noqa
|
18 |
+
|
19 |
+
|
20 |
+
def spcor(x, y):
    """Evaluate ``1 - 6 * sum((x - y)^2) / (n * (n^2 - 1))`` with ``n = len(x)``.

    This is the closed-form Spearman correlation when ``x`` and ``y`` already
    hold ranks (presumably what callers pass; the function does not rank its
    inputs itself).

    Args:
        x: Tensor of values for the first variable.
        y: Tensor of values for the second variable, same shape as ``x``.

    Returns:
        A tensor holding the coefficient, computed without gradient tracking.
    """
    n = len(x)
    denominator = n * (n**2 - 1)

    with torch.no_grad():
        squared_gap = torch.square(x - y)
        r = 1 - torch.sum(6 * squared_gap) / denominator

    return r
|
26 |
+
|
27 |
+
|
28 |
+
def pdists(x, y):
    """Return the matrix of squared Euclidean distances between rows of ``x`` and ``y``.

    Uses ``|a|^2 + |b|^2 - 2 a.b``. The computation runs on CUDA when it is
    available and on the CPU otherwise (the original hard-coded ``"cuda"``
    and failed on CPU-only hosts).

    Args:
        x: 2-D tensor with one vector per row.
        y: 2-D tensor with one vector per row and the same row length as ``x``.

    Returns:
        CPU tensor ``d`` with ``d[i, j]`` the squared distance between
        ``x[i]`` and ``y[j]``, computed without gradient tracking.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = x.to(device)
    y = y.to(device)

    with torch.no_grad():
        xsum = torch.sum(torch.square(x), dim=-1)
        ysum = torch.sum(torch.square(y), dim=-1)

        dists = xsum.view(-1, 1) + ysum.view(1, -1) - 2 * x @ y.T

    return dists.cpu()
|
39 |
+
|
40 |
+
|
41 |
+
def cossim(x, y):
    """Return the matrix of cosine similarities between rows of ``x`` and ``y``.

    The computation runs on CUDA when it is available and on the CPU
    otherwise (the original hard-coded ``"cuda"`` and failed on CPU-only
    hosts).

    Args:
        x: 2-D tensor with one vector per row.
        y: 2-D tensor with one vector per row and the same row length as ``x``.

    Returns:
        CPU tensor ``s`` with ``s[i, j]`` the cosine similarity of ``x[i]``
        and ``y[j]``, computed without gradient tracking. Zero-norm rows
        produce non-finite entries, since no epsilon is added.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = x.to(device)
    y = y.to(device)

    with torch.no_grad():
        x_norms = torch.linalg.norm(x, dim=-1).view(-1, 1)
        y_norms = torch.linalg.norm(y, dim=-1).view(1, -1)
        similarities = x @ y.T / (x_norms * y_norms)

    return similarities.cpu()
|
56 |
+
|
57 |
+
|
58 |
+
def fisher(p):
    """Combine p-values with Fisher's method, ignoring NaN entries.

    The statistic is ``-2 * sum(log(p_i))`` over the non-NaN p-values and the
    result is the chi-squared survival function at that statistic with
    ``2 * k`` degrees of freedom, ``k`` being the number of non-NaN entries.

    Args:
        p: Iterable of p-values; NaN entries are skipped.

    Returns:
        The combined p-value.
    """
    valid = [pvalue for pvalue in p if not np.isnan(pvalue)]

    statistic = 0
    for pvalue in valid:
        statistic -= 2 * np.log(pvalue)

    return chi2.sf(statistic, df=2 * len(valid))
|
67 |
+
|
68 |
+
|
69 |
+
def normalize_mc_midpoint(mid, base, ft):
    """Subtract the straight-line midpoint between ``base`` and ``ft`` from ``mid``.

    Computes ``mid - 0.5 * (ft - base) - base`` using in-place subtraction,
    so a mutable ``mid`` (e.g. an array or tensor) is modified and returned.

    Args:
        mid: Value measured at the midpoint.
        base: Value measured at the first endpoint.
        ft: Value measured at the second endpoint.

    Returns:
        The normalized midpoint value.
    """
    half_rise = (ft - base) * 0.5
    mid -= half_rise
    mid -= base
    return mid
|
74 |
+
|
75 |
+
|
76 |
+
def normalize_trace(trace, alphas):
    """Remove the straight line through the trace's endpoints, in place.

    Each ``trace[i]`` becomes ``trace[i] - (trace[-1] - trace[0]) * alphas[i]
    - trace[0]``, using the endpoint values read before any element is
    modified. ``trace`` itself is mutated and returned.

    Args:
        trace: Mutable sequence of values, one per entry of ``alphas``.
        alphas: Interpolation coefficient for each trace position.

    Returns:
        The same ``trace`` object, now normalized.
    """
    first = trace[0]
    rise = trace[-1] - first

    for idx in range(len(trace)):
        trace[idx] -= rise * alphas[idx]
        trace[idx] -= first

    return trace
|
83 |
+
|
84 |
+
|
85 |
+
def output_hook(m, inp, op, name, feats):
    """Store a detached copy of the output ``op`` in ``feats`` under ``name``.

    ``m`` and ``inp`` are accepted but unused; the leading ``(m, inp, op)``
    parameters presumably follow the PyTorch forward-hook calling convention.
    """
    detached_output = op.detach()
    feats[name] = detached_output
|
87 |
+
|
88 |
+
|
89 |
+
def get_submodule(module, submodule_string):
    """Resolve a dotted attribute path such as ``"a.b.c"`` starting at ``module``.

    Args:
        module: Object to start the lookup from.
        submodule_string: Attribute names separated by ``.``.

    Returns:
        The object reached after following every attribute in turn.

    Raises:
        AttributeError: If any attribute along the path does not exist.
    """
    current = module
    for attr_name in submodule_string.split("."):
        current = getattr(current, attr_name)
    return current
|
94 |
+
|
95 |
+
|
96 |
+
def plot_trace(losses, alphas, normalize, model_a_name, model_b_name, plot_path):
    """Plot loss against interpolation coefficient and save it as a PNG.

    Args:
        losses: Loss value for each entry of ``alphas``.
        alphas: Interpolation coefficients used on the x axis.
        normalize: When truthy, ``losses`` is first passed through
            ``normalize_trace`` (which modifies it in place).
        model_a_name: Name shown in the title for the left end of the plot.
        model_b_name: Name shown in the title for the right end of the plot.
        plot_path: Output path without the ``.png`` extension.
    """
    plt.figure(figsize=(8, 6))

    curve = normalize_trace(losses, alphas) if normalize else losses
    plt.plot(alphas, curve, "o-")

    plt.xlabel("Alpha")
    plt.ylabel("Loss")
    plt.title(f"{model_a_name} (Left) vs {model_b_name} (Right)")

    plot_filename = f"{plot_path}.png"
    plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
|