Spaces: Running on Zero
| import argparse | |
| from pathlib import Path | |
| import librosa | |
| import soundfile as sf | |
| import torch | |
| from safetensors.torch import save_file | |
| from torch.utils.data import DataLoader, Dataset | |
| from tqdm import tqdm | |
| from transformers import HubertModel | |
class AudioDataset(Dataset):
    """Dataset that loads audio files and resamples them to a target rate.

    Each item is a tuple ``(wav, path)`` where ``wav`` is a float tensor of
    shape [1, T] at ``target_sr`` and ``path`` is the original file path.
    """

    def __init__(self, file_list, target_sr=16000):
        """
        Args:
            file_list: Sequence of paths to audio files.
            target_sr: Sampling rate the waveforms are resampled to.
        """
        self.paths = file_list
        self.target_sr = target_sr

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        wav, sr = sf.read(str(path))
        # soundfile returns [T, C] for multi-channel audio; downmix to mono,
        # since the model expects a single channel ([1, T] after unsqueeze).
        if wav.ndim > 1:
            wav = wav.mean(axis=1)
        if sr != self.target_sr:
            wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr)
        wav = torch.tensor(wav).float().unsqueeze(0)  # shape: [1, T]
        return wav, path
def encode_batch(model, batch, device, out_dir, keep_layers):
    """Run HuBERT on each waveform in *batch* and save selected hidden states.

    Args:
        model: HubertModel (should be in eval mode).
        batch: Tuple ``(wavs, paths)`` as produced by the DataLoader's collate_fn.
        device: torch.device to run inference on.
        out_dir: Directory where one ``.st`` (safetensors) file per input is written.
        keep_layers: Set of hidden-state indices to keep (index 0 is the
            embedding output).
    """
    wavs, paths = batch
    # inference_mode disables autograd graph construction — without it every
    # forward pass retains activations for backward, wasting memory and time.
    with torch.inference_mode():
        for wav, path in zip(wavs, paths):
            wav = wav.to(device)
            outputs = model(wav, output_hidden_states=True)
            # Tuple of num_hidden_layers + 1 tensors, each [1, T', D].
            hidden_states = outputs.hidden_states
            selected = {
                f"layer_{i}": hs.squeeze(0).cpu()
                for i, hs in enumerate(hidden_states)
                if i in keep_layers
            }
            out_path = out_dir / (path.stem + ".st")
            save_file(selected, str(out_path))
def parse_layers(layer_str, max_layers):
    """Parse a layer-selection string into a set of valid layer indices.

    Args:
        layer_str: Either ``"all"`` (case-insensitive) or a comma-separated
            list of non-negative integer indices, e.g. ``"0,5,12"``.
            Non-numeric tokens are ignored.
        max_layers: Total number of hidden states available; indices outside
            ``range(max_layers)`` are dropped.

    Returns:
        Set of integer layer indices in ``[0, max_layers)``.
    """
    if layer_str.strip().lower() == "all":
        return set(range(max_layers))
    requested = {
        int(token)
        for token in layer_str.split(",")
        if token.strip().isdigit()
    }
    # Drop out-of-range indices so a typo doesn't silently select layers that
    # never appear in the model output (which would save empty files downstream).
    return {idx for idx in requested if idx < max_layers}
def main():
    """Parse CLI args, load HuBERT, and dump hidden states for each audio file.

    Reads a text file of audio paths, runs facebook/hubert-base-ls960 over
    each file, and writes the selected hidden-state layers to one safetensors
    file per input in the output directory.
    """
    parser = argparse.ArgumentParser(description="Infer HuBERT hidden states.")
    parser.add_argument(
        "file_list", type=Path, help="Text file with paths to audio files"
    )
    parser.add_argument("output_dir", type=Path, help="Directory to save .st files")
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--num_workers", type=int, default=2)
    parser.add_argument(
        "--layers",
        type=str,
        default="all",
        help="Comma-separated layer indices or 'all'",
    )
    args = parser.parse_args()

    device = torch.device(args.device)
    model = HubertModel.from_pretrained("facebook/hubert-base-ls960").to(device).eval()

    # output_hidden_states yields num_hidden_layers + 1 tensors (index 0 is the
    # embedding output). The config attribute is `num_hidden_layers` (an int);
    # the old code probed a nonexistent `hidden_layers` attribute and always
    # fell through to a hard-coded 13, which is wrong for non-base models.
    num_layers = getattr(model.config, "num_hidden_layers", 12) + 1
    keep_layers = parse_layers(args.layers, num_layers)

    with open(args.file_list, "r") as f:
        paths = [Path(line.strip()) for line in f if line.strip()]

    dataset = AudioDataset(paths)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=args.num_workers,
        # Transpose the list of (wav, path) pairs into (wavs, paths).
        collate_fn=lambda x: list(zip(*x)),
    )

    args.output_dir.mkdir(parents=True, exist_ok=True)
    for batch in tqdm(dataloader):
        try:
            encode_batch(model, batch, device, args.output_dir, keep_layers)
        except Exception as e:
            # Best-effort: keep processing remaining files, but report failures.
            print(f"Failed on batch: {e}")


if __name__ == "__main__":
    main()