Ahmed Ahmed
Add model-tracing code for p-value computation (without binary files)
de071e9
import torch
import matplotlib.pyplot as plt
import numpy as np
import random
import os
from scipy.stats import chi2
def manual_seed(seed, fix_cudnn=True):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed to every seeding function and written, as a
            string, to the ``PYTHONHASHSEED`` environment variable.
        fix_cudnn: When true, also set the cuDNN backend flags
            ``deterministic = True`` and ``benchmark = False``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    os.environ["PYTHONHASHSEED"] = str(seed)

    if not fix_cudnn:
        return
    torch.backends.cudnn.deterministic = True  # noqa
    torch.backends.cudnn.benchmark = False  # noqa
def spcor(x, y):
    """Evaluate ``1 - sum(6 * (x - y)^2) / (n * (n^2 - 1))`` with ``n = len(x)``.

    This is the closed-form Spearman rank-correlation expression; it only
    matches that statistic if ``x`` and ``y`` already hold ranks
    (NOTE(review): presumably the callers pass ranks -- confirm).

    Args:
        x: Tensor of values; its length along the first dimension is ``n``.
        y: Tensor that can be subtracted from ``x``.

    Returns:
        A tensor holding the coefficient, computed without gradient tracking.
    """
    size = len(x)
    denominator = size * (size**2 - 1)
    with torch.no_grad():
        scaled_sq_diff = 6 * torch.square(x - y)
        return 1 - torch.sum(scaled_sq_diff) / denominator
def pdists(x, y):
    """Return pairwise squared Euclidean distances between rows of x and y.

    Uses the expansion ``|a|^2 + |b|^2 - 2 * a.b``. Because of floating-point
    cancellation, entries for (near-)identical rows can come out slightly
    negative; no clamping is applied.

    The computation runs on the GPU when CUDA is available and falls back to
    the CPU otherwise, so the function also works on CPU-only hosts.

    Args:
        x: Tensor whose last dimension holds the features
            (assumes shape ``(n, d)`` -- TODO confirm against callers).
        y: Tensor with the same feature size (assumes shape ``(m, d)``).

    Returns:
        A CPU tensor whose ``[i, j]`` entry is the squared distance between
        row ``i`` of ``x`` and row ``j`` of ``y``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = x.to(device)
    y = y.to(device)
    with torch.no_grad():
        x_sq_norms = torch.sum(torch.square(x), dim=-1)
        y_sq_norms = torch.sum(torch.square(y), dim=-1)
        # Column of |x_i|^2 plus row of |y_j|^2 broadcasts to an (n, m) grid.
        dists = x_sq_norms.view(-1, 1) + y_sq_norms.view(1, -1) - 2 * x @ y.T
    return dists.cpu()
def cossim(x, y):
    """Return pairwise cosine similarities between rows of x and rows of y.

    Each dot product ``x_i . y_j`` is divided by the product of the two
    rows' norms. No epsilon is added, so an all-zero row produces
    non-finite entries.

    The computation runs on the GPU when CUDA is available and falls back to
    the CPU otherwise, so the function also works on CPU-only hosts.

    Args:
        x: Tensor whose last dimension holds the features
            (assumes shape ``(n, d)`` -- TODO confirm against callers).
        y: Tensor with the same feature size (assumes shape ``(m, d)``).

    Returns:
        A CPU tensor whose ``[i, j]`` entry is the cosine similarity of row
        ``i`` of ``x`` and row ``j`` of ``y``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = x.to(device)
    y = y.to(device)
    with torch.no_grad():
        x_norms = torch.linalg.norm(x, dim=-1).view(-1, 1)
        y_norms = torch.linalg.norm(y, dim=-1).view(1, -1)
        similarities = x @ y.T / (x_norms * y_norms)
    return similarities.cpu()
def fisher(p):
    """Combine p-values with Fisher's method, ignoring NaN entries.

    The statistic is ``-2 * sum(log(p_i))`` over the non-NaN entries; the
    result is the chi-squared survival function of that statistic with two
    degrees of freedom per entry that was used.

    Args:
        p: Iterable of p-values; NaN entries are skipped.

    Returns:
        The combined p-value as returned by ``scipy.stats.chi2.sf``.
    """
    statistic = 0
    n_used = 0
    for value in p:
        if np.isnan(value):
            continue
        statistic = statistic - 2 * np.log(value)
        n_used = n_used + 1
    return chi2.sf(statistic, df=2 * n_used)
def normalize_mc_midpoint(mid, base, ft):
    """Subtract half of ``ft - base`` and then ``base`` from ``mid``.

    Equivalent to ``mid - (base + ft) / 2`` up to rounding, i.e. the offset
    of ``mid`` from the average of the two endpoints.

    Both subtractions use ``-=``, so a mutable ``mid`` (array or tensor) is
    modified in place and that same object is returned.

    Args:
        mid: Value measured at the midpoint.
        base: Value at the first endpoint.
        ft: Value at the second endpoint.

    Returns:
        The adjusted ``mid``.
    """
    # The difference is evaluated before ``mid`` is touched.
    mid -= (ft - base) * 0.5
    mid -= base
    return mid
def normalize_trace(trace, alphas):
    """Remove the end-to-end linear trend and the starting offset in place.

    Every element becomes ``trace[i] - slope * alphas[i] - start`` where
    ``start`` is the original first element and ``slope`` is the original
    last element minus the original first element.

    Args:
        trace: Mutable, indexable sequence of values (list, array, tensor).
            It is overwritten element by element.
        alphas: Indexable sequence with at least ``len(trace)`` entries.

    Returns:
        The same ``trace`` object, now holding the normalized values.

    Raises:
        IndexError: If ``trace`` is empty or ``alphas`` is shorter than
            ``trace`` (raised before ``trace`` is modified).
    """
    start = trace[0]
    slope = trace[-1] - start
    # Compute every new value before writing any of them back. When the
    # elements are mutable objects (tensors, array rows), ``start`` aliases
    # ``trace[0]``; updating the trace while still reading ``start`` would
    # shift every later element by the wrong offset.
    normalized = [
        value - slope * alphas[i] - start for i, value in enumerate(trace)
    ]
    for i, value in enumerate(normalized):
        trace[i] = value
    return trace
def output_hook(m, inp, op, name, feats):
    """Store a detached copy of ``op`` in ``feats`` under ``name``.

    Args:
        m: Unused.
        inp: Unused.
        op: Object with a ``detach()`` method whose result is stored.
        name: Key under which the detached value is saved.
        feats: Mutable mapping that receives the entry.
    """
    detached = op.detach()
    feats[name] = detached
def get_submodule(module, submodule_string):
    """Resolve a dotted attribute path starting from ``module``.

    Args:
        module: Object on which the lookup starts.
        submodule_string: Attribute names joined by ``"."``, e.g. ``"a.b.c"``.

    Returns:
        The object reached after following every attribute in turn.

    Raises:
        AttributeError: If any attribute along the path does not exist.
    """
    current = module
    for part in submodule_string.split("."):
        current = getattr(current, part)
    return current
def plot_trace(losses, alphas, normalize, model_a_name, model_b_name, plot_path):
    """Plot losses against alphas and save the figure as ``<plot_path>.png``.

    Args:
        losses: Values drawn on the y-axis. When ``normalize`` is true they
            are passed through ``normalize_trace`` first, which overwrites
            this sequence in place.
        alphas: Values drawn on the x-axis.
        normalize: Whether to apply ``normalize_trace`` before plotting.
        model_a_name: Name shown in the title as the "(Left)" model.
        model_b_name: Name shown in the title as the "(Right)" model.
        plot_path: Output path without extension; ``.png`` is appended.
    """
    plt.figure(figsize=(8, 6))
    values = normalize_trace(losses, alphas) if normalize else losses
    plt.plot(alphas, values, "o-")
    plt.title(f"{model_a_name} (Left) vs {model_b_name} (Right)")
    plt.ylabel("Loss")
    plt.xlabel("Alpha")
    plt.savefig(f"{plot_path}.png", dpi=300, bbox_inches="tight")
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close()