# Ahmed Ahmed
# Add model-tracing code for p-value computation (without binary files)
# de071e9
import yaml
from yaml import Loader
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import os
from scipy.stats import chi2
# Model identifier -> integer group id. get_dict_ft marks a model pair as
# "fine-tuned" (True) exactly when both models have the same id here.
# The ids appear to denote base-model families (e.g. 1 groups models named
# after Llama-2 / CodeLlama, 0 groups yahma/llama-7b-hf and vicuna-7b-v1.1)
# -- inferred from the model names only; confirm before relying on it.
# NOTE(review): contains the same key/value pairs as `base_ordered` below;
# keep the two in sync.
base = {
    "lmsys/vicuna-7b-v1.5": 1,
    "codellama/CodeLlama-7b-hf": 1,
    "codellama/CodeLlama-7b-Python-hf": 1,
    "codellama/CodeLlama-7b-Instruct-hf": 1,
    "EleutherAI/llemma_7b": 1,
    "microsoft/Orca-2-7b": 1,
    "oh-yeontaek/llama-2-7B-LoRA-assemble": 1,
    "lvkaokao/llama2-7b-hf-instruction-lora": 1,
    "NousResearch/Nous-Hermes-llama-2-7b": 1,
    "lmsys/vicuna-7b-v1.1": 0,
    "yahma/llama-7b-hf": 0,
    "Salesforce/xgen-7b-4k-base": 2,
    "EleutherAI/llemma_7b_muinstruct_camelmath": 1,
    "AlfredPros/CodeLlama-7b-Instruct-Solidity": 1,
    "meta-llama/Llama-2-7b-hf": 1,
    "LLM360/Amber": 3,
    "LLM360/AmberChat": 3,
    "openlm-research/open_llama_7b": 4,
    "openlm-research/open_llama_7b_v2": 5,
    "ibm-granite/granite-7b-base": 6,
    "ibm-granite/granite-7b-instruct": 6,
}
# The same model -> group id mapping as `base`, but with models that share an
# id listed next to each other. plot_statistic_grid / plot_statistic_grid_layer
# use the keys of the dict they are given, in order, as the heat-map axes, so
# this ordering keeps related models in adjacent rows/columns.
# NOTE(review): duplicates the data in `base`; keep the two in sync.
base_ordered = {
    "yahma/llama-7b-hf": 0,
    "lmsys/vicuna-7b-v1.1": 0,
    "meta-llama/Llama-2-7b-hf": 1,
    "lmsys/vicuna-7b-v1.5": 1,
    "codellama/CodeLlama-7b-hf": 1,
    "codellama/CodeLlama-7b-Python-hf": 1,
    "codellama/CodeLlama-7b-Instruct-hf": 1,
    "AlfredPros/CodeLlama-7b-Instruct-Solidity": 1,
    "EleutherAI/llemma_7b": 1,
    "EleutherAI/llemma_7b_muinstruct_camelmath": 1,
    "microsoft/Orca-2-7b": 1,
    "oh-yeontaek/llama-2-7B-LoRA-assemble": 1,
    "lvkaokao/llama2-7b-hf-instruction-lora": 1,
    "NousResearch/Nous-Hermes-llama-2-7b": 1,
    "Salesforce/xgen-7b-4k-base": 2,
    "LLM360/Amber": 3,
    "LLM360/AmberChat": 3,
    "openlm-research/open_llama_7b": 4,
    "openlm-research/open_llama_7b_v2": 5,
    "ibm-granite/granite-7b-base": 6,
    "ibm-granite/granite-7b-instruct": 6,
}
# Model identifier -> 4-character lineage code. Not used by any function
# visible in this file.
# The codes appear to encode ancestry: the first letter names a root model,
# each following letter is one derivation step, and "-" pads to 4 characters
# (e.g. "BBBA" AlfredPros/CodeLlama-7b-Instruct-Solidity sits under "BBB-"
# CodeLlama-7b-Instruct-hf, which sits under "BB--" CodeLlama-7b-hf, under
# "B---" Llama-2-7b-hf) -- inferred from the data; confirm before relying on it.
tree = {
    "yahma/llama-7b-hf": "A---",
    "lmsys/vicuna-7b-v1.1": "AA--",
    "meta-llama/Llama-2-7b-hf": "B---",
    "lmsys/vicuna-7b-v1.5": "BA--",
    "codellama/CodeLlama-7b-hf": "BB--",
    "codellama/CodeLlama-7b-Python-hf": "BBA-",
    "codellama/CodeLlama-7b-Instruct-hf": "BBB-",
    "AlfredPros/CodeLlama-7b-Instruct-Solidity": "BBBA",
    "EleutherAI/llemma_7b": "BBC-",
    "EleutherAI/llemma_7b_muinstruct_camelmath": "BBCA",
    "microsoft/Orca-2-7b": "BC--",
    "oh-yeontaek/llama-2-7B-LoRA-assemble": "BD--",
    "lvkaokao/llama2-7b-hf-instruction-lora": "BE--",
    "NousResearch/Nous-Hermes-llama-2-7b": "BF--",
    "Salesforce/xgen-7b-4k-base": "C---",
    "LLM360/Amber": "D---",
    "LLM360/AmberChat": "DA--",
    "openlm-research/open_llama_7b": "E---",
    "openlm-research/open_llama_7b_v2": "F---",
    "ibm-granite/granite-7b-base": "G---",
    "ibm-granite/granite-7b-instruct": "GA--",
}
def get_dict_ft(flat_model_path):
    """Label every unordered pair of configured models by shared base-model id.

    Args:
        flat_model_path: Path to a YAML file that loads as a sequence of model
            identifiers; every identifier must be a key of ``base``.

    Returns:
        dict mapping ``"<model_a>_AND_<model_b>"`` (each name with "/" replaced
        by "-") to ``True`` when ``base`` gives both models the same id, else
        ``False``. Only pairs where ``model_a`` precedes ``model_b`` in the
        YAML list are included. The plotting helpers look these keys up using
        the log file name up to ".out".

    Raises:
        KeyError: if a listed model is missing from ``base``.
    """
    # NOTE(review): yaml.load with the full Loader can construct arbitrary
    # Python objects; acceptable only because this config is a trusted local
    # file. Prefer yaml.safe_load if that ever changes.
    with open(flat_model_path, "r") as f:
        model_paths = yaml.load(f, Loader=Loader)

    dict_ft = {}
    for i, model_a in enumerate(model_paths):
        for model_b in model_paths[i + 1 :]:
            job_id = model_a.replace("/", "-") + "_AND_" + model_b.replace("/", "-")
            dict_ft[job_id] = base[model_a] == base[model_b]
    return dict_ft
def get_statistic_from_file(filename):
file = open(filename, "r")
lines = file.readlines()
stat = np.nan
for line in lines:
if "Namespace" in line and "non-aligned test stat" in line:
# dict = json.loads(line)
# print(dict)
# print(dict['non-aligned test stat'])
start1 = line.find("non-aligned test stat")
stat = line[line.find(":", start1) : line.find(",", start1)]
stat = stat.replace(" ", "")
stat = stat.replace("(", "")
stat = stat.replace(":", "")
stat = stat.replace("tensor", "")
stat = float(stat)
return stat
def get_l2_stat_from_file(filename):
    """Parse the leading 4 characters of the last log line whose 5th char is a space.

    Such lines are expected to start with a 4-character number (e.g.
    ``"0.12 ..."``); only those 4 characters are parsed, so wider values are
    truncated. NOTE(review): confirm the log always prints this width.

    Returns:
        The parsed float, or ``np.nan`` when no line matches. Callers already
        discard NaN results via ``np.isnan``.
    """
    last = None
    with open(filename, "r") as f:
        for line in f:
            # Index 4 exists only when len(line) > 4. The previous `>= 4` guard
            # raised IndexError on a 4-character final line.
            if len(line) > 4 and line[4] == " ":
                last = line[:4]
    return np.nan if last is None else float(last)
def get_layer_statistic_from_file(filename, layer):
    """Return the statistic logged for ``layer`` in a run log.

    A matching line starts with ``f"{layer} "``. Its value is the text from the
    first "0." onward (``"3 0.0421"`` -> 0.0421). If the line has no "0.", the
    text after the layer prefix is used instead (``"3 1.0"`` -> 1.0). Any
    matching line containing an "e" (scientific notation such as
    ``"3 1.2e-05"``) is mapped to 0.0. When several lines match, the last wins.

    Returns:
        The parsed float, or ``np.nan`` when no line matches.

    Raises:
        ValueError: if the selected text is not a valid float.
    """
    prefix = str(layer) + " "
    stat = np.nan
    with open(filename, "r") as f:
        for line in f:
            if not line.startswith(prefix):
                continue
            if "e" in line:
                # NOTE(review): values printed in scientific notation are
                # treated as 0 (presumably because they are negligibly small).
                stat = 0.0
                continue
            start = line.find("0.")
            # Without a "0.", slicing from -1 would parse only the last
            # character: "\n" raises ValueError, and "3 1.0" (no trailing
            # newline) silently became 0.0.
            value = line[start:] if start != -1 else line[len(prefix):]
            stat = float(value)
    return stat
def plot_statistic_scatter(results_path, dict_ft, plot_path):
    """Scatter each pair's L2-log p-value (x) against its fine-tuned label (y).

    Each ``*.out`` log in ``results_path`` is one model pair. The pair is keyed
    into ``dict_ft`` by the file name up to ".out". Pairs involving
    "huggyllama" are skipped, as are logs whose statistic is NaN. The figure
    is saved to ``f"{plot_path}.png"``.
    """
    x = []
    y = []
    for file in os.listdir(results_path):
        end = file.find(".out")
        if end == -1:
            # Not a run log. Slicing with -1 would build a bogus dict_ft key.
            continue
        models = file[:end]
        if "huggyllama" in models:
            continue
        print(models)
        ft = int(dict_ft[models])
        stat = get_l2_stat_from_file(results_path + "/" + file)
        if not np.isnan(stat):
            y.append(ft)
            x.append(stat)  # reuse the parsed value instead of re-reading the log
    plt.figure(figsize=(10, 1))
    plt.scatter(x, y, s=8)
    plt.xlabel("$p$-value")
    plt.ylabel("Fine-tuned")
    plt.ylim(-0.5, 1.5)
    plot_filename = f"{plot_path}.png"
    plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
def plot_statistic_grid(results_path, dict_base, title, plot_path, decimals, log=False):
    """Draw a symmetric heat-map of the pairwise "non-aligned test stat".

    The keys of ``dict_base``, in order, label both axes. For each ordered
    pair (a, b), the log ``<a>_AND_<b>.out`` (with "/" replaced by "-") is read
    from ``results_path`` if it exists. The value is optionally log-transformed,
    rounded to ``decimals``, and written to both cells (a, b) and (b, a). Every
    cell is annotated with its value: exact zeros show as an epsilon and
    missing pairs show as "nan". The figure is saved to ``f"{plot_path}.png"``.
    """
    names = list(dict_base.keys())
    print(names)
    size = len(names)
    grid = np.full((size, size), np.nan)

    for row, first in enumerate(names):
        for col, second in enumerate(names):
            log_name = first.replace("/", "-") + "_AND_" + second.replace("/", "-") + ".out"
            log_file = results_path + "/" + log_name
            if os.path.exists(log_file):
                print(log_name)
                value = get_statistic_from_file(log_file)
                if log:
                    value = np.log(value)
                grid[row][col] = np.round(value, decimals=decimals)
                grid[col][row] = grid[row][col]

    fig, ax = plt.subplots()
    fig.set_size_inches(20, 20)
    image = ax.imshow(grid, cmap="viridis")
    # Result unused; kept in case the call's side effects on `ax` matter.
    make_axes_locatable(ax)
    ax.figure.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

    positions = np.arange(size)
    ax.set_xticks(positions, labels=names)
    ax.set_yticks(positions, labels=names)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    for row in range(size):
        for col in range(size):
            cell = grid[row][col]
            label = "$\\varepsilon$" if cell == 0.0 else str(cell)
            ax.text(col, row, label, ha="center", va="center", color="w")

    ax.set_title(title)
    fig.tight_layout()
    plt.savefig(f"{plot_path}.png", dpi=500, bbox_inches="tight")
    plt.close()
def plot_statistic_scatter_layer(results_path, dict_ft, plot_path, layer):
    """Scatter one layer's logged statistic (y) against the fine-tuned label (x).

    Each ``*.out`` log in ``results_path`` is one model pair. The pair is keyed
    into ``dict_ft`` by the file name up to ".out". Pairs involving
    "huggyllama" are skipped, as are logs with no value for ``layer``. The
    figure is saved to ``f"{plot_path}.png"``.
    """
    x = []
    y = []
    for file in os.listdir(results_path):
        end = file.find(".out")
        if end == -1:
            # Not a run log. Slicing with -1 would build a bogus dict_ft key.
            continue
        models = file[:end]
        if "huggyllama" in models:
            continue
        print(models)
        ft = int(dict_ft[models])
        stat = get_layer_statistic_from_file(results_path + "/" + file, layer)
        if not np.isnan(stat):
            x.append(ft)
            y.append(stat)  # reuse the parsed value instead of re-reading the log
    plt.figure(figsize=(8, 6))
    plt.scatter(x, y, s=2)
    plt.xlabel("Fine tuned")
    plt.ylabel("Test statistic")
    plot_filename = f"{plot_path}.png"
    plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
def plot_statistic_grid_layer(
    results_path, dict_base, title, plot_path, decimals, layer, log=False
):
    """Draw a symmetric heat-map of one layer's pairwise statistic.

    This is the per-layer counterpart of the whole-model grid. The keys of
    ``dict_base``, in order, label both axes. For each ordered pair (a, b), the
    log ``<a>_AND_<b>.out`` (with "/" replaced by "-") is read for ``layer`` if
    it exists. The value is optionally log-transformed, rounded to
    ``decimals``, and mirrored into (b, a). Every cell is annotated with its
    raw value. The figure is saved to ``f"{plot_path}.png"``.
    """
    names = list(dict_base.keys())
    print(names)
    size = len(names)
    grid = np.full((size, size), np.nan)

    for row, first in enumerate(names):
        first_id = first.replace("/", "-")
        for col, second in enumerate(names):
            log_file = (
                results_path + "/" + first_id + "_AND_" + second.replace("/", "-") + ".out"
            )
            if not os.path.exists(log_file):
                continue
            value = get_layer_statistic_from_file(log_file, layer)
            if log:
                value = np.log(value)
            grid[row][col] = np.round(value, decimals=decimals)
            grid[col][row] = grid[row][col]

    fig, ax = plt.subplots()
    fig.set_size_inches(20, 20)
    image = ax.imshow(grid)
    # Result unused; kept in case the call's side effects on `ax` matter.
    make_axes_locatable(ax)
    ax.figure.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

    positions = np.arange(size)
    ax.set_xticks(positions, labels=names)
    ax.set_yticks(positions, labels=names)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    for row in range(size):
        for col in range(size):
            ax.text(col, row, grid[row, col], ha="center", va="center", color="w")

    ax.set_title(title)
    fig.tight_layout()
    plt.savefig(f"{plot_path}.png", dpi=500, bbox_inches="tight")
    plt.close()
def plot_histogram(results_path, dict_ft, plot_path):
    """Overlay histograms of the test statistic for unrelated vs related pairs.

    Each ``*.out`` log in ``results_path`` is one model pair, keyed into
    ``dict_ft`` by the file name up to ".out". Pairs whose ``dict_ft`` entry is
    falsy go into the "independent" histogram; the rest go into "not
    independent". NaN statistics are dropped. Both histograms use 20 bins over
    [0, 1]. The figure is saved to ``f"{plot_path}.png"``.
    """
    indp = []
    not_indp = []
    for file in os.listdir(results_path):
        end = file.find(".out")
        if end == -1:
            # Not a run log. Slicing with -1 would build a bogus dict_ft key.
            continue
        models = file[:end]
        # Matches the other plotting helpers: `base` has no huggyllama entries,
        # so get_dict_ft never produces keys for these pairs.
        if "huggyllama" in models:
            continue
        print(models)
        ft = int(dict_ft[models])
        stat = get_statistic_from_file(results_path + "/" + file)
        if np.isnan(stat):
            continue
        if ft:
            not_indp.append(stat)
        else:
            indp.append(stat)
    plt.figure(figsize=(8, 6))
    # Semi-transparent bars so the second histogram does not hide the first.
    plt.hist(indp, bins=20, range=(0, 1), color="blue", alpha=0.5, label="independent")
    plt.hist(
        not_indp, bins=20, range=(0, 1), color="green", alpha=0.5, label="not independent"
    )
    plt.xlabel("Test statistic value")
    plt.ylabel("Count")
    plt.legend()
    plot_filename = f"{plot_path}.png"
    plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
def fisher(pvalues):
    """Combine p-values with Fisher's method.

    NaN entries are ignored. The statistic ``-2 * sum(log p)`` over the k
    remaining p-values is evaluated against a chi-squared distribution with
    ``2k`` degrees of freedom.

    Returns:
        The combined p-value ``chi2.sf(statistic, df=2k)``.

    NOTE(review): if every entry is NaN, df=0 is passed to chi2.sf; check that
    the resulting value is what callers expect.
    """
    kept = [p for p in pvalues if not np.isnan(p)]
    statistic = 0
    for p in kept:
        statistic -= 2 * np.log(p)
    return chi2.sf(statistic, df=2 * len(kept))
def plot_statistic_scatter_all_layers(results_path, dict_ft, plot_path, num_layers=32):
    """Scatter the per-layer p-value of every model pair, coloured by label.

    For each layer in ``range(num_layers)`` and each ``*.out`` log in
    ``results_path``, the layer's value is plotted at x=layer. Points are red
    when the pair's ``dict_ft`` entry is truthy and blue otherwise; NaN values
    are skipped. The figure is saved to ``f"{plot_path}.png"``.

    Args:
        num_layers: Number of layers to scan (default 32, the previously
            hard-coded value).
    """
    x = []
    y = []
    c = []
    dir_list = os.listdir(results_path)
    for layer in range(num_layers):
        # A single pass per layer. The file loop used to be duplicated, which
        # added every point twice and re-parsed each log four times.
        for file in dir_list:
            end = file.find(".out")
            if end == -1:
                # Not a run log. Slicing with -1 would build a bogus dict_ft key.
                continue
            models = file[:end]
            # Unlike the other helpers, huggyllama pairs are NOT skipped here.
            print(models)
            ft = int(dict_ft[models])
            stat = get_layer_statistic_from_file(results_path + "/" + file, layer)
            if not np.isnan(stat):
                x.append(layer)
                y.append(stat)
                c.append("r" if ft else "b")
    plt.figure(figsize=(8, 6))
    plt.scatter(x, y, s=1.5, c=c)
    plt.xlabel("Layer")
    plt.ylabel("p-value")
    plot_filename = f"{plot_path}.png"
    plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
def plot_pvalue(results_path, dict_ft, plot_path, num_layers=32):
    """Plot the empirical distribution of per-layer p-values for unrelated pairs.

    The function collects the value for every layer in ``range(num_layers)``
    from every ``*.out`` log in ``results_path``. Pairs involving "huggyllama"
    are excluded, as are pairs whose ``dict_ft`` entry is truthy (same base
    model). It then plots, for thresholds 0, 0.001, ..., 0.999, the fraction
    of collected p-values strictly below each threshold. The figure is saved to
    ``f"{plot_path}.png"``.

    Args:
        num_layers: Number of layers to scan (default 32, the previously
            hard-coded value).

    Raises:
        ValueError: if no p-values were collected. Previously this case hit a
            ZeroDivisionError.
    """
    pvalues = []
    dir_list = os.listdir(results_path)
    for layer in range(num_layers):
        for file in dir_list:
            end = file.find(".out")
            if end == -1:
                # Not a run log. Slicing with -1 would build a bogus dict_ft key.
                continue
            models = file[:end]
            if "huggyllama" in models:
                continue
            print(models)
            # `int(...) is True` can never hold (an int is not the bool
            # singleton), so related pairs were never excluded. Test truthiness.
            if dict_ft[models]:
                continue
            stat = get_layer_statistic_from_file(results_path + "/" + file, layer)
            if not np.isnan(stat):
                pvalues.append(stat)
    print(pvalues)
    print(len(pvalues))
    if not pvalues:
        raise ValueError(f"no p-values collected from {results_path}")

    x = np.arange(0, 1, step=0.001)
    # searchsorted(side="left") on the sorted values counts entries strictly
    # below each threshold. This replaces an O(len(x) * len(pvalues)) loop.
    sorted_p = np.sort(pvalues)
    y = np.searchsorted(sorted_p, x, side="left") / len(sorted_p)

    plt.figure(figsize=(8, 6))
    plt.plot(x, y, ".-")
    plot_filename = f"{plot_path}.png"
    plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
if __name__ == "__main__":
    # Script entry point: build the pair labels, then run the currently
    # selected experiment plot. Alternative invocations are kept (commented
    # out) below for reference.
    model_list_config = "/nlp/u/salzhu/model-tracing/config/llama_flat.yaml"
    mlp_match_logs = "/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs"
    pvalue_plot_stem = "/nlp/u/salzhu/test_statistic_plots/mlp_sp_final"

    dict_ft = get_dict_ft(model_list_config)

    # plot_statistic_scatter("/juice4/scr4/nlp/model-tracing/llama_models_runs/perm_mc_l2_wikitext/logs",
    # dict_ft, "test_statistic_plots/l2_pvalue_horizontal")
    # plot_statistic_grid("/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs",
    # base_ordered, "MLP up/gate matching p-value on permuted model pairs (random inputs for matching)",
    # "/nlp/u/salzhu/test_statistic_tables/mlp_match_rand_rot_perm_lap",
    # 3, log=False)
    # plot_statistic_grid("/juice4/scr4/nlp/model-tracing/csh_0928_reruns/logs",
    # base_ordered, "",
    # "/nlp/u/salzhu/csh_0929_cols",
    # 3, log=False)
    # plot_statistic_grid("/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs",
    # base_ordered, "",
    # "/nlp/u/salzhu/robust_0929",
    # 3, log=False)
    # plot_statistic_grid("/juice4/scr4/nlp/model-tracing/llama_models_runs/perm_mc_l2_wikitext/logs",
    # base_ordered, "",
    # "/nlp/u/salzhu/l2_0927",
    # 3, log=False)
    # plot_statistic_scatter("/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs", dict_ft,
    # "/nlp/u/salzhu/test_statistic_plots/mlp_sp_final")
    # plot_statistic_scatter_layer("/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs", dict_ft,
    # "/nlp/u/salzhu/test_statistic_plots/mlp_match_rand_rot_perm_lap_layer31", 31)
    # plot_statistic_grid_layer("/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs",
    # base_ordered, "MLP up/gate matching p-value on permuted model pairs (random inputs for matching)",
    # "/nlp/u/salzhu/test_statistic_tables/mlp_match_rand_rot_perm_lap_layer31",
    # 3, 31, log=False)
    # plot_statistic_scatter_all_layers("/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs",
    # dict_ft,
    # "/nlp/u/salzhu/test_statistic_plots/mlp_match_rand_rot_perm_lap_all_layers")

    plot_pvalue(mlp_match_logs, dict_ft, pvalue_plot_stem)

    # plot_histogram("/juice4/scr4/nlp/model-tracing/mlp_match_med_max_layer0/logs",
    # dict_ft, "/nlp/u/salzhu/test_statistic_plots/mlp_med_max_histogram")

    # Training-checkpoint tables for OLMo comparisons.
    # NOTE(review): plot_statistic_olmo_scatter is not defined in this file.
    # checkpoints = {
    # "100M": 1e8,
    # "1B": 1e9,
    # "10B": 1e10,
    # "18B": 1.8e10,
    # }
    # checkpoints = {
    # "100M": 1e8,
    # "1B": 1e9,
    # "4B": 4e9,
    # "8B": 8e9,
    # "16B": 1.6e10
    # }
    # checkpoints = {
    # "100M": 1e8,
    # "1B": 1e9,
    # "12B": 1.2e10,
    # "25B": 2.5e10
    # }
    # plot_statistic_olmo_scatter("/juice4/scr4/nlp/model-tracing/olmo_models_runs/final_checkpoint/csw_robust_cols/logs", checkpoints,
    # "final checkpoint vs. additional training seed 42", "CSW robust",
    # "/nlp/u/salzhu/olmo_plots/final_checkpoint/csw_robust_cols_seed42")