# pardi-speech/codec/scripts/infer_hubert.py
# Mehdi Lakbar
# Initial demo of Lina-speech (pardi-speech)
# 56cfa73
import argparse
from pathlib import Path
import librosa
import soundfile as sf
import torch
from safetensors.torch import save_file
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import HubertModel
class AudioDataset(Dataset):
    """Dataset yielding mono waveforms resampled to a fixed rate, with their paths.

    Each item is ``(wav, path)`` where ``wav`` is a float32 tensor of shape
    ``[1, T]`` and ``path`` is the corresponding entry of ``file_list``.

    Args:
        file_list: Sequence of audio file paths readable by ``soundfile``.
        target_sr: Sample rate (Hz) every waveform is resampled to.
    """

    def __init__(self, file_list, target_sr=16000):
        self.paths = file_list
        self.target_sr = target_sr

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _to_mono(wav):
        """Down-mix a ``soundfile`` array to 1-D by averaging over channels.

        ``soundfile.read`` returns ``[T]`` for mono files and ``[T, C]`` for
        multi-channel files; 1-D input is returned unchanged.
        """
        if wav.ndim > 1:
            wav = wav.mean(axis=1)
        return wav

    def __getitem__(self, idx):
        path = self.paths[idx]
        wav, sr = sf.read(str(path))
        # Down-mix before resampling: librosa.resample works along the last
        # axis, which for a [T, C] array is the channel axis, not time.
        wav = self._to_mono(wav)
        if sr != self.target_sr:
            wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr)
        wav = torch.tensor(wav).float().unsqueeze(0)  # shape: [1, T]
        return wav, path
@torch.no_grad()
def encode_batch(model, batch, device, out_dir, keep_layers):
    """Run the model on every waveform in ``batch`` and save selected hidden states.

    Args:
        model: Model called as ``model(wav, output_hidden_states=True)``; its
            output must expose ``hidden_states`` as an indexable sequence.
        batch: ``(wavs, paths)`` pair of parallel sequences.
        device: Device each waveform is moved to before the forward pass.
        out_dir: ``Path`` of the directory that receives one ``<stem>.st``
            file per input path.
        keep_layers: Collection of hidden-state indices to keep.

    Each saved file maps ``"layer_<i>"`` to the ``i``-th hidden state with its
    leading dimension squeezed, moved to CPU.
    """
    wavs, paths = batch
    for waveform, src_path in zip(wavs, paths):
        states = model(waveform.to(device), output_hidden_states=True).hidden_states
        tensors = {}
        for layer_idx, state in enumerate(states):
            if layer_idx not in keep_layers:
                continue
            tensors[f"layer_{layer_idx}"] = state.squeeze(0).cpu()
        save_file(tensors, str(out_dir / (src_path.stem + ".st")))
def parse_layers(layer_str, max_layers):
    """Parse a ``--layers`` spec into a set of hidden-state indices.

    Args:
        layer_str: ``"all"`` (case-insensitive, surrounding whitespace ignored)
            or a comma-separated list of non-negative integers, e.g. ``"0, 6,12"``.
        max_layers: Number of hidden states available; only used for ``"all"``.

    Returns:
        ``set(range(max_layers))`` for ``"all"``; otherwise the set of indices
        parsed from ``layer_str``. Entries that are not plain decimal integers
        (empty, negative, non-numeric) are skipped. Indices are not checked
        against ``max_layers``.
    """
    if layer_str.strip().lower() == "all":
        return set(range(max_layers))
    # isdecimal, not isdigit: isdigit accepts characters such as "²" that int()
    # cannot parse, which would raise ValueError instead of skipping the entry.
    return {int(tok) for tok in layer_str.split(",") if tok.strip().isdecimal()}
def _collate_pairs(items):
    """Transpose a list of ``(wav, path)`` items into ``[wavs, paths]``.

    Defined at module level (not as a lambda) so DataLoader worker processes
    can pickle it when the ``spawn`` start method is used.
    """
    return list(zip(*items))


def main():
    """CLI entry point: extract HuBERT hidden states for a list of audio files.

    Reads one audio path per non-blank line of ``file_list``, runs each file
    through a pretrained HuBERT model and writes the selected hidden-state
    layers to ``output_dir`` via ``encode_batch``. Exceptions raised inside
    ``encode_batch`` are reported and the file is skipped. Errors raised while
    loading audio (in the DataLoader) are not caught.
    """
    parser = argparse.ArgumentParser(description="Infer HuBERT hidden states.")
    parser.add_argument(
        "file_list", type=Path, help="Text file with paths to audio files"
    )
    parser.add_argument("output_dir", type=Path, help="Directory to save .st files")
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--num_workers", type=int, default=2)
    parser.add_argument(
        "--layers",
        type=str,
        default="all",
        help="Comma-separated layer indices or 'all'",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="facebook/hubert-base-ls960",
        help="Pretrained HuBERT checkpoint name or path",
    )
    args = parser.parse_args()

    device = torch.device(args.device)
    model = HubertModel.from_pretrained(args.model).to(device).eval()
    # hidden_states contains the embedding output plus one tensor per
    # transformer layer, hence num_hidden_layers + 1 (13 for hubert-base).
    num_layers = model.config.num_hidden_layers + 1
    keep_layers = parse_layers(args.layers, num_layers)

    with open(args.file_list, "r", encoding="utf-8") as f:
        paths = [Path(line.strip()) for line in f if line.strip()]

    dataset = AudioDataset(paths)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=args.num_workers,
        collate_fn=_collate_pairs,
    )

    args.output_dir.mkdir(parents=True, exist_ok=True)
    for batch in tqdm(dataloader):
        try:
            encode_batch(model, batch, device, args.output_dir, keep_layers)
        except Exception as e:
            # tqdm.write avoids corrupting the progress bar; name the failing
            # file(s) so they can be re-run.
            failed = ", ".join(str(p) for p in batch[1])
            tqdm.write(f"❌ Failed on batch [{failed}]: {e}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()