Spaces:
Runtime error
Runtime error
import yaml | |
from yaml import Loader | |
import matplotlib.pyplot as plt | |
from mpl_toolkits.axes_grid1 import make_axes_locatable | |
import numpy as np | |
import os | |
from scipy.stats import chi2 | |
# One row per model, in the order used for plotting:
#   (Hugging Face model path, base-model id, tree code).
# Models sharing a base-model id are treated as sharing a base by get_dict_ft.
# The first letter of each tree code matches the id (A=0, B=1, ..., G=6), and
# codes share prefixes (e.g. "BB--", "BBB-", "BBBA"), which appears to encode a
# fine-tuning tree; "-" pads codes to four characters.
_MODEL_TABLE = (
    ("yahma/llama-7b-hf", 0, "A---"),
    ("lmsys/vicuna-7b-v1.1", 0, "AA--"),
    ("meta-llama/Llama-2-7b-hf", 1, "B---"),
    ("lmsys/vicuna-7b-v1.5", 1, "BA--"),
    ("codellama/CodeLlama-7b-hf", 1, "BB--"),
    ("codellama/CodeLlama-7b-Python-hf", 1, "BBA-"),
    ("codellama/CodeLlama-7b-Instruct-hf", 1, "BBB-"),
    ("AlfredPros/CodeLlama-7b-Instruct-Solidity", 1, "BBBA"),
    ("EleutherAI/llemma_7b", 1, "BBC-"),
    ("EleutherAI/llemma_7b_muinstruct_camelmath", 1, "BBCA"),
    ("microsoft/Orca-2-7b", 1, "BC--"),
    ("oh-yeontaek/llama-2-7B-LoRA-assemble", 1, "BD--"),
    ("lvkaokao/llama2-7b-hf-instruction-lora", 1, "BE--"),
    ("NousResearch/Nous-Hermes-llama-2-7b", 1, "BF--"),
    ("Salesforce/xgen-7b-4k-base", 2, "C---"),
    ("LLM360/Amber", 3, "D---"),
    ("LLM360/AmberChat", 3, "DA--"),
    ("openlm-research/open_llama_7b", 4, "E---"),
    ("openlm-research/open_llama_7b_v2", 5, "F---"),
    ("ibm-granite/granite-7b-base", 6, "G---"),
    ("ibm-granite/granite-7b-instruct", 6, "GA--"),
)

# Model path -> base-model id, in plotting order (rows/columns of the grids).
base_ordered = {path: base_id for path, base_id, _ in _MODEL_TABLE}

# Model path -> tree code, in the same order as base_ordered.
tree = {path: code for path, _, code in _MODEL_TABLE}

# Same mapping as base_ordered, but with its own historical insertion order
# (dict iteration order is observable, so it is preserved explicitly).
base = {
    path: base_ordered[path]
    for path in (
        "lmsys/vicuna-7b-v1.5",
        "codellama/CodeLlama-7b-hf",
        "codellama/CodeLlama-7b-Python-hf",
        "codellama/CodeLlama-7b-Instruct-hf",
        "EleutherAI/llemma_7b",
        "microsoft/Orca-2-7b",
        "oh-yeontaek/llama-2-7B-LoRA-assemble",
        "lvkaokao/llama2-7b-hf-instruction-lora",
        "NousResearch/Nous-Hermes-llama-2-7b",
        "lmsys/vicuna-7b-v1.1",
        "yahma/llama-7b-hf",
        "Salesforce/xgen-7b-4k-base",
        "EleutherAI/llemma_7b_muinstruct_camelmath",
        "AlfredPros/CodeLlama-7b-Instruct-Solidity",
        "meta-llama/Llama-2-7b-hf",
        "LLM360/Amber",
        "LLM360/AmberChat",
        "openlm-research/open_llama_7b",
        "openlm-research/open_llama_7b_v2",
        "ibm-granite/granite-7b-base",
        "ibm-granite/granite-7b-instruct",
    )
}
def get_dict_ft(flat_model_path):
    """Label every unordered pair of configured models with "same base id?".

    Args:
        flat_model_path: Path to a YAML file whose top level is a list of model
            path strings; each must be a key of the module-level ``base`` dict.

    Returns:
        dict mapping ``"<model_a>_AND_<model_b>"`` (with ``/`` replaced by
        ``-`` in both paths) to ``True`` when ``base[model_a] == base[model_b]``.
        Only pairs where ``model_a`` appears before ``model_b`` in the list are
        included, so the reversed ordering is not a key.

    Raises:
        KeyError: if a listed model is missing from ``base``.
    """
    # NOTE(review): yaml.load with the full Loader can construct arbitrary
    # Python objects; acceptable for a trusted local config, unsafe otherwise.
    with open(flat_model_path, "r") as config_file:
        model_paths = yaml.load(config_file, Loader=Loader)

    dict_ft = {}
    for i, model_a in enumerate(model_paths):
        # Each pair is visited once, in list order (model_a before model_b).
        for model_b in model_paths[i + 1 :]:
            job_id = model_a.replace("/", "-") + "_AND_" + model_b.replace("/", "-")
            dict_ft[job_id] = base[model_a] == base[model_b]
    return dict_ft
def get_statistic_from_file(filename):
    """Extract the "non-aligned test stat" value from a results log.

    Only lines containing both ``"Namespace"`` and ``"non-aligned test stat"``
    are considered. For such a line, the text between the first ``:`` and the
    first ``,`` after the key is stripped of spaces, ``(``, ``)``, ``:`` and
    ``tensor`` and parsed as a float. This accepts both
    ``tensor(0.12, device='cuda:0'),`` and ``tensor(0.12),``. When several
    lines match, the last one wins.

    Args:
        filename: Path of the log file to read.

    Returns:
        The parsed statistic, or ``np.nan`` if no line matches.

    Raises:
        ValueError: if the extracted text is not a valid float.
    """
    stat = np.nan
    with open(filename, "r") as log_file:
        for line in log_file:
            if "Namespace" not in line or "non-aligned test stat" not in line:
                continue
            key_pos = line.find("non-aligned test stat")
            raw = line[line.find(":", key_pos) : line.find(",", key_pos)]
            # ")" must go too: a CPU tensor prints as "tensor(0.12)" and the cut
            # at the following comma keeps the closing parenthesis.
            for junk in (" ", "(", ")", ":", "tensor"):
                raw = raw.replace(junk, "")
            stat = float(raw)
    return stat
def get_l2_stat_from_file(filename):
    """Return the last 4-character statistic found at the start of a line.

    A line qualifies when its fifth character (index 4) is a space; its first
    four characters (e.g. ``"0.12"`` in ``"0.12 ..."``) are the statistic.
    Shorter lines are ignored.

    Args:
        filename: Path of the log file to read.

    Returns:
        The statistic from the last qualifying line as a float, or ``np.nan``
        when no line qualifies (callers test the result with ``np.isnan``).

    Raises:
        ValueError: if the four leading characters of the last qualifying line
            are not a valid float.
    """
    last = None
    with open(filename, "r") as log_file:
        for line in log_file:
            # Strict ">" so that line[4] exists (">= 4" raised IndexError on
            # 4-character lines such as "abc\n").
            if len(line) > 4 and line[4] == " ":
                last = line[:4]
    return np.nan if last is None else float(last)
def get_layer_statistic_from_file(filename, layer):
    """Return the statistic reported for ``layer`` in a per-layer results log.

    A line belongs to the layer when it starts with ``"<layer> "``. Its value
    is the text from the first ``"0."`` in the line onward. When the line
    contains no ``"0."`` (e.g. ``"3 1.0"`` or ``"3 nan"``), the value is the
    text after the prefix. Any matching line that contains the letter ``e``
    (such as scientific notation ``1.2e-05``) yields 0. If several lines
    match, the last one wins.

    Args:
        filename: Path of the log file to read.
        layer: Layer index whose line prefix is looked up.

    Returns:
        The value as a float, or ``np.nan`` if no line matches.

    Raises:
        ValueError: if the selected text is not a valid float.
    """
    prefix = str(layer) + " "
    raw = np.nan
    with open(filename, "r") as log_file:
        for line in log_file:
            if not line.startswith(prefix):
                continue
            start = line.find("0.")
            # Without the fallback, find() == -1 sliced only the trailing
            # character and float() failed on values like "1.0".
            raw = line[start:] if start != -1 else line[len(prefix) :]
            if "e" in line:
                # Anything containing an "e" (scientific notation) is clamped to 0.
                raw = 0
    return float(raw)
def plot_statistic_scatter(results_path, dict_ft, plot_path):
    """Scatter each result file's L2 statistic (x) against its fine-tuned flag (y).

    Args:
        results_path: Directory of result files; each file's pair id is the
            text before the first ``.out`` in its name.
        dict_ft: Maps a pair id to a flag; ``int(flag)`` becomes the y value.
        plot_path: Output path without extension; ``.png`` is appended.

    Files whose pair id contains ``huggyllama`` are skipped, as are files whose
    statistic is NaN.

    Raises:
        KeyError: if a pair id is missing from ``dict_ft``.
    """
    x = []
    y = []
    for file in os.listdir(results_path):
        models = file[: file.find(".out")]
        if "huggyllama" in models:
            continue
        print(models)
        ft = int(dict_ft[models])
        # Parse once; the same value feeds both the NaN check and the plot.
        stat = get_l2_stat_from_file(os.path.join(results_path, file))
        if not np.isnan(stat):
            x.append(stat)
            y.append(ft)
    plt.figure(figsize=(10, 1))
    plt.scatter(x, y, s=8)
    plt.xlabel("$p$-value")
    plt.ylabel("Fine-tuned")
    plt.ylim(-0.5, 1.5)
    plot_filename = f"{plot_path}.png"
    plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
def plot_statistic_grid(results_path, dict_base, title, plot_path, decimals, log=False):
    """Draw a symmetric heatmap of the whole-model statistic for every model pair.

    Rows and columns follow the key order of ``dict_base``. A cell is filled
    when ``<results_path>/<model_a>_AND_<model_b>.out`` exists (``/`` in model
    paths replaced by ``-``). Its value comes from get_statistic_from_file, is
    optionally passed through ``np.log``, is rounded to ``decimals`` places, and
    is mirrored across the diagonal. Missing pairs stay NaN. Cells exactly equal
    to 0 are annotated with an epsilon symbol.

    Args:
        results_path: Directory holding the ``.out`` result files.
        dict_base: Dict whose keys (model paths) define the axis order.
        title: Axes title.
        plot_path: Output path without extension; saved as ``.png`` at 500 dpi.
        decimals: Number of decimals kept in each cell.
        log: If True, plot the natural log of each statistic.
    """
    models = list(dict_base.keys())
    print(models)
    size = len(models)
    data = np.full((size, size), np.nan)

    for row, model_a in enumerate(models):
        prefix = model_a.replace("/", "-") + "_AND_"
        for col, model_b in enumerate(models):
            job_id = prefix + model_b.replace("/", "-") + ".out"
            job_path = results_path + "/" + job_id
            if not os.path.exists(job_path):
                continue
            print(job_id)
            stat = get_statistic_from_file(job_path)
            if log:
                stat = np.log(stat)
            data[row, col] = data[col, row] = np.round(stat, decimals=decimals)

    fig, ax = plt.subplots()
    fig.set_size_inches(20, 20)
    image = ax.imshow(data, cmap="viridis")
    make_axes_locatable(ax)  # return value unused; call kept as it may affect ax's layout
    fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

    positions = np.arange(size)
    ax.set_xticks(positions, labels=models)
    ax.set_yticks(positions, labels=models)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # One white annotation per cell, in row-major order.
    for (row, col), value in np.ndenumerate(data):
        label = "$\\varepsilon$" if value == 0.0 else str(value)
        ax.text(col, row, label, ha="center", va="center", color="w")

    ax.set_title(title)
    fig.tight_layout()
    plt.savefig(f"{plot_path}.png", dpi=500, bbox_inches="tight")
    plt.close()
def plot_statistic_scatter_layer(results_path, dict_ft, plot_path, layer):
    """Scatter one layer's statistic (y) against each file's fine-tuned flag (x).

    Args:
        results_path: Directory of result files; each file's pair id is the
            text before the first ``.out`` in its name.
        dict_ft: Maps a pair id to a flag; ``int(flag)`` becomes the x value.
        plot_path: Output path without extension; ``.png`` is appended.
        layer: Layer index passed to get_layer_statistic_from_file.

    Files whose pair id contains ``huggyllama`` are skipped, as are files whose
    statistic is NaN.

    Raises:
        KeyError: if a pair id is missing from ``dict_ft``.
    """
    x = []
    y = []
    for file in os.listdir(results_path):
        models = file[: file.find(".out")]
        if "huggyllama" in models:
            continue
        print(models)
        ft = int(dict_ft[models])
        # Parse once; the same value feeds both the NaN check and the plot.
        stat = get_layer_statistic_from_file(os.path.join(results_path, file), layer)
        if not np.isnan(stat):
            x.append(ft)
            y.append(stat)
    plt.figure(figsize=(8, 6))
    plt.scatter(x, y, s=2)
    plt.xlabel("Fine tuned")
    plt.ylabel("Test statistic")
    plot_filename = f"{plot_path}.png"
    plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
def plot_statistic_grid_layer(
    results_path, dict_base, title, plot_path, decimals, layer, log=False
):
    """Draw a symmetric heatmap of one layer's statistic for every model pair.

    Rows and columns follow the key order of ``dict_base``. A cell is filled
    when ``<results_path>/<model_a>_AND_<model_b>.out`` exists (``/`` in model
    paths replaced by ``-``). Its value is
    ``get_layer_statistic_from_file(path, layer)``, optionally passed through
    ``np.log``, rounded to ``decimals`` places and mirrored across the
    diagonal. Missing pairs stay NaN. Every cell is annotated with its raw
    numeric value, and imshow's default colormap is used.

    Args:
        results_path: Directory holding the ``.out`` result files.
        dict_base: Dict whose keys (model paths) define the axis order.
        title: Axes title.
        plot_path: Output path without extension; saved as ``.png`` at 500 dpi.
        decimals: Number of decimals kept in each cell.
        layer: Layer index whose statistic is plotted.
        log: If True, plot the natural log of each statistic.
    """
    models = list(dict_base.keys())
    print(models)
    size = len(models)
    data = np.full((size, size), np.nan)

    for row, model_a in enumerate(models):
        prefix = model_a.replace("/", "-") + "_AND_"
        for col, model_b in enumerate(models):
            job_path = results_path + "/" + prefix + model_b.replace("/", "-") + ".out"
            if not os.path.exists(job_path):
                continue
            stat = get_layer_statistic_from_file(job_path, layer)
            if log:
                stat = np.log(stat)
            data[row, col] = data[col, row] = np.round(stat, decimals=decimals)

    fig, ax = plt.subplots()
    fig.set_size_inches(20, 20)
    image = ax.imshow(data)
    make_axes_locatable(ax)  # return value unused; call kept as it may affect ax's layout
    fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

    positions = np.arange(size)
    ax.set_xticks(positions, labels=models)
    ax.set_yticks(positions, labels=models)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # One white annotation per cell (raw value, NaN included), row-major order.
    for (row, col), value in np.ndenumerate(data):
        ax.text(col, row, value, ha="center", va="center", color="w")

    ax.set_title(title)
    fig.tight_layout()
    plt.savefig(f"{plot_path}.png", dpi=500, bbox_inches="tight")
    plt.close()
def plot_histogram(results_path, dict_ft, plot_path):
    """Overlay histograms of the whole-model statistic for the two pair groups.

    Each result file's statistic (from get_statistic_from_file) goes to the
    "not independent" group when its ``dict_ft`` flag is truthy and to the
    "independent" group otherwise. NaN statistics are dropped. Both histograms
    use 20 bins over [0, 1] and are drawn semi-transparent with a legend, so
    overlapping bars stay visible.

    Args:
        results_path: Directory of result files; each file's pair id is the
            text before the first ``.out`` in its name.
        dict_ft: Maps a pair id to its flag.
        plot_path: Output path without extension; ``.png`` is appended.

    Raises:
        KeyError: if a pair id is missing from ``dict_ft``.
    """
    indp = []
    not_indp = []
    for file in os.listdir(results_path):
        models = file[: file.find(".out")]
        print(models)
        ft = int(dict_ft[models])
        stat = get_statistic_from_file(os.path.join(results_path, file))
        if np.isnan(stat):
            continue
        if ft:
            not_indp.append(stat)
        else:
            indp.append(stat)
    plt.figure(figsize=(8, 6))
    # Semi-transparent: an opaque second histogram would hide the first.
    plt.hist(indp, bins=20, range=(0, 1), color="blue", alpha=0.5, label="independent")
    plt.hist(
        not_indp, bins=20, range=(0, 1), color="green", alpha=0.5, label="not independent"
    )
    plt.legend()
    plt.xlabel("Test statistic value")
    plt.ylabel("Count")
    plot_filename = f"{plot_path}.png"
    plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
def fisher(pvalues):
    """Combine p-values with Fisher's method, ignoring NaN entries.

    The statistic ``-2 * sum(log p)`` over the k non-NaN p-values is compared
    against a chi-squared distribution with ``2k`` degrees of freedom.

    Args:
        pvalues: Iterable of p-values; NaN entries are skipped.

    Returns:
        The combined p-value, i.e. the chi-squared survival function at the
        statistic.
    """
    kept = [p for p in pvalues if not np.isnan(p)]
    # Summed left to right from 0, matching a running "-= 2 * log(p)" total.
    statistic = sum(-2 * np.log(p) for p in kept)
    return chi2.sf(statistic, df=2 * len(kept))
def plot_statistic_scatter_all_layers(results_path, dict_ft, plot_path, num_layers=32):
    """Scatter every file's per-layer statistic against the layer index.

    For each layer in ``range(num_layers)`` and each result file, the value of
    get_layer_statistic_from_file is plotted at x = layer. Points are red when
    the pair's ``dict_ft`` flag is truthy and blue otherwise. NaN values are
    skipped. Each (layer, file) point is added exactly once; an earlier
    duplicated loop appended points twice.

    Args:
        results_path: Directory of result files; each file's pair id is the
            text before the first ``.out`` in its name.
        dict_ft: Maps a pair id to its flag.
        plot_path: Output path without extension; ``.png`` is appended.
        num_layers: Number of layers to read per file (default 32).

    Raises:
        KeyError: if a pair id is missing from ``dict_ft`` (``huggyllama``
            pairs are not skipped here).
    """
    x = []
    y = []
    c = []
    dir_list = os.listdir(results_path)
    for layer in range(num_layers):
        for file in dir_list:
            models = file[: file.find(".out")]
            print(models)
            ft = int(dict_ft[models])
            # Parse once; the same value feeds both the NaN check and the plot.
            stat = get_layer_statistic_from_file(os.path.join(results_path, file), layer)
            if not np.isnan(stat):
                x.append(layer)
                y.append(stat)
                c.append("r" if ft else "b")
    plt.figure(figsize=(8, 6))
    plt.scatter(x, y, s=1.5, c=c)
    plt.xlabel("Layer")
    plt.ylabel("p-value")
    plot_filename = f"{plot_path}.png"
    plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
def plot_pvalue(results_path, dict_ft, plot_path, num_layers=32):
    """Plot the empirical CDF of per-layer p-values for pairs not flagged in dict_ft.

    Every layer in ``range(num_layers)`` of every result file is read. Files
    whose pair id contains ``huggyllama`` are skipped, as are pairs whose
    ``dict_ft`` flag is truthy. The curve gives, for thresholds 0, 0.001, ...,
    0.999, the fraction of collected (non-NaN) p-values strictly below each
    threshold.

    Args:
        results_path: Directory of result files; each file's pair id is the
            text before the first ``.out`` in its name.
        dict_ft: Maps a pair id to a flag; flagged pairs are excluded.
        plot_path: Output path without extension; ``.png`` is appended.
        num_layers: Number of layers to read per file (default 32).

    Raises:
        KeyError: if a pair id is missing from ``dict_ft``.
        ValueError: if no p-value was collected.
    """
    pvalues = []
    dir_list = os.listdir(results_path)
    for layer in range(num_layers):
        for file in dir_list:
            models = file[: file.find(".out")]
            if "huggyllama" in models:
                continue
            print(models)
            # Truthiness test: "int(flag) is True" can never match, which
            # previously let flagged pairs through.
            if dict_ft[models]:
                continue
            stat = get_layer_statistic_from_file(os.path.join(results_path, file), layer)
            if not np.isnan(stat):
                pvalues.append(stat)
    print(pvalues)
    print(len(pvalues))
    if not pvalues:
        raise ValueError(f"no p-values collected from {results_path}")
    x = np.arange(0, 1, step=0.001)
    # On a sorted sample, searchsorted(side="left") counts values < threshold.
    y = np.searchsorted(np.sort(pvalues), x, side="left") / len(pvalues)
    plt.figure(figsize=(8, 6))
    plt.plot(x, y, ".-")
    plot_filename = f"{plot_path}.png"
    plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
if __name__ == "__main__":
    # Pair id -> "both models share a base id" (see get_dict_ft) for every pair
    # of models listed in the flat LLaMA config.
    dict_ft = get_dict_ft("/nlp/u/salzhu/model-tracing/config/llama_flat.yaml")

    # The one analysis currently enabled: CDF of per-layer p-values read from
    # mlp_match_rand_rot_perm_lap/logs.
    plot_pvalue(
        results_path="/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs",
        dict_ft=dict_ft,
        plot_path="/nlp/u/salzhu/test_statistic_plots/mlp_sp_final",
    )

    # ------------------------------------------------------------------
    # Earlier invocations, kept for reference (uncomment to rerun).
    # ------------------------------------------------------------------
    # plot_statistic_scatter("/juice4/scr4/nlp/model-tracing/llama_models_runs/perm_mc_l2_wikitext/logs",
    #                        dict_ft, "test_statistic_plots/l2_pvalue_horizontal")
    # plot_statistic_grid("/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs",
    #                     base_ordered, "MLP up/gate matching p-value on permuted model pairs (random inputs for matching)",
    #                     "/nlp/u/salzhu/test_statistic_tables/mlp_match_rand_rot_perm_lap",
    #                     3, log=False)
    # plot_statistic_grid("/juice4/scr4/nlp/model-tracing/csh_0928_reruns/logs",
    #                     base_ordered, "",
    #                     "/nlp/u/salzhu/csh_0929_cols",
    #                     3, log=False)
    # plot_statistic_grid("/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs",
    #                     base_ordered, "",
    #                     "/nlp/u/salzhu/robust_0929",
    #                     3, log=False)
    # plot_statistic_grid("/juice4/scr4/nlp/model-tracing/llama_models_runs/perm_mc_l2_wikitext/logs",
    #                     base_ordered, "",
    #                     "/nlp/u/salzhu/l2_0927",
    #                     3, log=False)
    # plot_statistic_scatter("/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs", dict_ft,
    #                        "/nlp/u/salzhu/test_statistic_plots/mlp_sp_final")
    # plot_statistic_scatter_layer("/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs", dict_ft,
    #                              "/nlp/u/salzhu/test_statistic_plots/mlp_match_rand_rot_perm_lap_layer31", 31)
    # plot_statistic_grid_layer("/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs",
    #                           base_ordered, "MLP up/gate matching p-value on permuted model pairs (random inputs for matching)",
    #                           "/nlp/u/salzhu/test_statistic_tables/mlp_match_rand_rot_perm_lap_layer31",
    #                           3, 31, log=False)
    # plot_statistic_scatter_all_layers("/juice4/scr4/nlp/model-tracing/mlp_match_rand_rot_perm_lap/logs",
    #                                   dict_ft,
    #                                   "/nlp/u/salzhu/test_statistic_plots/mlp_match_rand_rot_perm_lap_all_layers")
    # plot_histogram("/juice4/scr4/nlp/model-tracing/mlp_match_med_max_layer0/logs",
    #                dict_ft, "/nlp/u/salzhu/test_statistic_plots/mlp_med_max_histogram")

    # Checkpoint tables for the OLMo runs (plot_statistic_olmo_scatter is not
    # defined in this file).
    # checkpoints = {
    #     "100M": 1e8,
    #     "1B": 1e9,
    #     "10B": 1e10,
    #     "18B": 1.8e10,
    # }
    # checkpoints = {
    #     "100M": 1e8,
    #     "1B": 1e9,
    #     "4B": 4e9,
    #     "8B": 8e9,
    #     "16B": 1.6e10
    # }
    # checkpoints = {
    #     "100M": 1e8,
    #     "1B": 1e9,
    #     "12B": 1.2e10,
    #     "25B": 2.5e10
    # }
    # plot_statistic_olmo_scatter("/juice4/scr4/nlp/model-tracing/olmo_models_runs/final_checkpoint/csw_robust_cols/logs", checkpoints,
    #                             "final checkpoint vs. additional training seed 42", "CSW robust",
    #                             "/nlp/u/salzhu/olmo_plots/final_checkpoint/csw_robust_cols_seed42")