import argparse
import json
from pathlib import Path

import nemo.collections.asr as nemo_asr
import torch
import yaml
from jiwer import wer
from torchaudio import load
from torchaudio.functional import resample
from tqdm import tqdm

from zcodec.models import WavVAE, ZFlowAutoEncoder


def load_config(config_path):
    """Parse the YAML file at ``config_path`` and return its content.

    Args:
        config_path: Path to a YAML configuration file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file (``main`` indexes
        it like a mapping).
    """
    # Explicit UTF-8 so parsing does not depend on the platform locale.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def transcribe(audio: torch.Tensor, asr_model) -> str:
    """Transcribe the first row of ``audio`` with ``asr_model``.

    Args:
        audio: Waveform tensor; only ``audio[0]`` is sent to the ASR model
            (presumably the first channel of a ``(channels, samples)``
            tensor -- confirm against the caller).
        asr_model: Object exposing ``transcribe(list_of_arrays)`` whose
            results carry a ``.text`` attribute.

    Returns:
        The ``.text`` of the first (and only) transcription result.
    """
    audio = audio.cpu().numpy(force=True)
    with torch.inference_mode():
        return asr_model.transcribe([audio[0]])[0].text


def _report_wer(results) -> None:
    """Print the WER between original and reconstructed transcripts.

    Pairs whose original transcript is empty are left out, because jiwer
    rejects empty reference strings; the number of skipped pairs is printed
    so the omission is visible.
    """
    pairs = [
        (r["original_text"], r["reconstructed_text"])
        for r in results
        if r["original_text"].strip()
    ]

    skipped = len(results) - len(pairs)
    if skipped:
        print(f"Skipped {skipped} pair(s) with an empty original transcript")

    if not pairs:
        print("WER: n/a (no non-empty original transcripts)")
        return

    original_texts = [ref for ref, _ in pairs]
    reconstructed_texts = [hyp for _, hyp in pairs]
    score = wer(original_texts, reconstructed_texts)
    print(f"WER: {score:.3%}")


def main(args):
    """Round-trip audio through the codec and compare ASR transcripts.

    For every path listed in the config's ``file_list``, the audio is
    transcribed, encoded and decoded through WavVAE + ZFlowAutoEncoder, and
    transcribed again. All transcript pairs are written as JSON Lines, and
    the WER between the two sets is printed when ``args.compute_wer`` is set.

    Args:
        args: Namespace with ``config`` (path to the YAML config) and
            ``compute_wer`` (bool).

    Config keys read: ``wavvae_ckpt``, ``zflowae_ckpt``, ``file_list``
    (required); ``asr_model``, ``num_steps``, ``cfg``, ``output_jsonl``
    (optional, with defaults).
    """
    config = load_config(args.config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load codec models
    wavvae = WavVAE.from_pretrained_local(config["wavvae_ckpt"]).to(device).eval()
    zflowae = (
        ZFlowAutoEncoder.from_pretrained_local(config["zflowae_ckpt"]).to(device).eval()
    )

    # Load ASR model
    asr_model = nemo_asr.models.ASRModel.from_pretrained(
        model_name=config.get("asr_model", "nvidia/parakeet-tdt-0.6b-v2")
    )

    # Read file list, ignoring blank lines
    with open(config["file_list"], "r", encoding="utf-8") as f:
        wav_files = [line.strip() for line in f if line.strip()]

    # Decoding options are loop-invariant, so look them up once.
    num_steps = config.get("num_steps", 10)
    cfg = config.get("cfg", 2.0)

    results = []
    for wav_path in tqdm(wav_files, desc="Processing files"):
        wav, sr = load(wav_path)
        wav = resample(wav, orig_freq=sr, new_freq=wavvae.sampling_rate).to(device)

        with torch.inference_mode():
            # NOTE(review): the ASR model is fed audio at wavvae.sampling_rate;
            # confirm this matches the rate the ASR model expects for
            # in-memory arrays.
            # Transcribe original
            original_text = transcribe(wav, asr_model)

            # Compress and decompress
            z = wavvae.encode(wav)
            zz, _ = zflowae.encode(z)
            z_hat = zflowae.decode(zz, num_steps=num_steps, cfg=cfg)
            wav_hat = wavvae.decode(z_hat)

            # Transcribe reconstructed
            reconstructed_text = transcribe(wav_hat, asr_model)

        results.append(
            {
                "file": wav_path,
                "original_text": original_text,
                "reconstructed_text": reconstructed_text,
            }
        )

    # Save output
    out_path = Path(config.get("output_jsonl", "transcripts.jsonl"))
    # Create the target directory so a long run is not lost to a missing folder.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened as UTF-8 rather than with the locale's default encoding.
    with out_path.open("w", encoding="utf-8") as f:
        f.writelines(
            json.dumps(entry, ensure_ascii=False) + "\n" for entry in results
        )

    print(f"\nSaved {len(results)} transcript pairs to {out_path}")

    # Optionally compute WER
    if args.compute_wer:
        _report_wer(results)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
    parser.add_argument(
        "--compute_wer", action="store_true", help="Compute WER after decoding"
    )
    args = parser.parse_args()
    main(args)