Spaces:
Runtime error
Runtime error
import torch | |
import numpy as np | |
from tracing.perm.permute import permute_model | |
def main(base_model, ft_model, test_stat, num_perm, emb_dim=4096, mlp_dim=11008): | |
unperm_stat = test_stat(base_model, ft_model) | |
print(unperm_stat) | |
perm_stats = [] | |
for i in range(num_perm): | |
mlp_permutation = torch.randperm(mlp_dim) | |
emb_permutation = torch.randperm(emb_dim) | |
permute_model(ft_model, mlp_permutation, emb_permutation) | |
perm_stat = test_stat(base_model, ft_model) | |
perm_stats.append(perm_stat) | |
print(i, perm_stat) | |
print(perm_stats) | |
exact = p_value_exact(unperm_stat, perm_stats.copy()) | |
approx = p_value_approx(unperm_stat, perm_stats.copy()) | |
print(exact, approx) | |
return exact, approx, unperm_stat, perm_stats | |
def p_value_exact(unpermuted, permuted): | |
count = 0 | |
for a in permuted: | |
if a < unpermuted: | |
count += 1 | |
return round((count + 1) / (len(permuted) + 1), 2) | |
def p_value_approx(unpermuted, permuted): | |
mean = sum(permuted) / len(permuted) | |
std = np.std(permuted) | |
zscore = (unpermuted - mean) / std | |
return zscore | |