"""Compute per-dimension mean/std over a tensor stored in many safetensors files.

Usage: script.py FILELIST OUTPUT [--key KEY] [--max-files N] [--seed S]

FILELIST is a text file with one safetensors path per line; the resulting
``mean`` and ``std`` vectors are written to OUTPUT as a safetensors file.
"""

import argparse
import random

import torch
from safetensors.torch import safe_open, save_file
from tqdm import tqdm


def load_tensor(path: str, key: str = "embedding") -> torch.Tensor:
    """Load the tensor stored under ``key`` from the safetensors file at ``path`` onto the CPU."""
    with safe_open(path, framework="pt", device="cpu") as f:
        return f.get_tensor(key)


def compute_global_stats(file_list, key="embedding", length_weighted=True):
    """Return per-feature ``(mean, std)`` pooled over every row of every file.

    Each tensor is flattened to ``[-1, D]`` (all leading axes collapsed, the
    last axis treated as the feature axis), and statistics are pooled over all
    rows of all files, so every row contributes equally.

    Args:
        file_list: Iterable of safetensors paths.
        key: Name of the tensor to read from each file.
        length_weighted: Accepted for backward compatibility; currently has no
            effect (pooling over all rows is always used).

    Returns:
        ``(mean, std)``, each of shape ``[D]``. The dtype is the first tensor's
        dtype if it is floating point, else the default float dtype. ``std`` is
        the population standard deviation, with the variance clamped to at
        least 1e-8 before the square root.

    Raises:
        ValueError: if no rows were found (empty ``file_list`` or only empty
            tensors), or if files disagree on the feature size ``D``.
    """
    sum_all = None
    sum_sq_all = None
    count_all = 0
    out_dtype = None

    for path in tqdm(file_list, desc="Computing stats"):
        tensor = load_tensor(path, key)  # presumably [B, T, D]; any rank >= 1 works
        if out_dtype is None:
            out_dtype = (
                tensor.dtype
                if tensor.is_floating_point()
                else torch.get_default_dtype()
            )

        # Accumulate in float64: half-precision sums overflow, and the
        # E[x^2] - E[x]^2 variance formula loses precision badly in float32.
        flat = tensor.reshape(-1, tensor.shape[-1]).to(torch.float64)  # [N, D]
        sum_ = flat.sum(dim=0)  # [D]
        sum_sq = flat.square().sum(dim=0)  # [D]

        if sum_all is None:
            sum_all, sum_sq_all = sum_, sum_sq
        else:
            # Guard explicitly: a mismatched size of 1 would otherwise be
            # silently broadcast into every feature by the in-place add.
            if sum_.shape != sum_all.shape:
                raise ValueError(
                    f"Feature size mismatch in {path}: expected last dim "
                    f"{sum_all.shape[0]}, got {sum_.shape[0]}"
                )
            sum_all += sum_
            sum_sq_all += sum_sq
        count_all += flat.shape[0]

    if count_all == 0:
        raise ValueError(
            "No data to compute statistics from: file_list is empty "
            "or every tensor is empty."
        )

    mean = sum_all / count_all
    var = sum_sq_all / count_all - mean**2
    # Clamp guards against tiny negative variances from rounding and
    # against zero-variance features (std floor of 1e-4).
    std = torch.sqrt(torch.clamp(var, min=1e-8))
    return mean.to(out_dtype), std.to(out_dtype)


def main():
    """Parse CLI arguments, compute the statistics, and save them to disk."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "filelist", type=str, help="Text file with list of safetensors paths"
    )
    parser.add_argument("output", type=str, help="Path to output stats.safetensors")
    parser.add_argument(
        "--key", type=str, default="audio_z", help="Key of tensor in safetensors file"
    )
    parser.add_argument(
        "--max-files", type=int, default=None, help="Max number of files to process"
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for shuffling"
    )
    args = parser.parse_args()

    with open(args.filelist, encoding="utf-8") as f:
        files = [line.strip() for line in f if line.strip()]

    if args.max_files:
        # A local RNG draws the same sample as seeding the global one,
        # without mutating global `random` state.
        rng = random.Random(args.seed)
        files = rng.sample(files, k=min(args.max_files, len(files)))

    mean, std = compute_global_stats(files, key=args.key)
    save_file({"mean": mean, "std": std}, args.output)
    print(f"✅ Saved to {args.output}")
    print("Example mean/std:", mean[:5], std[:5])


if __name__ == "__main__":
    main()