Spaces:
Running
on
Zero
Running
on
Zero
| import argparse | |
| import json | |
| from pathlib import Path | |
| import nemo.collections.asr as nemo_asr | |
| import torch | |
| import yaml | |
| from torchaudio import load | |
| from torchaudio.functional import resample | |
| from tqdm import tqdm | |
| from zcodec.models import WavVAE, ZFlowAutoEncoder | |
def load_config(config_path):
    """Load and parse a YAML configuration file.

    Args:
        config_path: Path to the YAML file (``str`` or path-like).

    Returns:
        The document parsed by ``yaml.safe_load`` (``None`` for an empty file).
    """
    # Decode explicitly as UTF-8 so non-ASCII paths or values in the config do
    # not depend on the platform's locale encoding. ``safe_load`` (not
    # ``yaml.load``) avoids constructing arbitrary Python objects.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
def transcribe(audio: torch.Tensor, asr_model) -> str:
    """Transcribe one waveform with an ASR model and return the text.

    Args:
        audio: Waveform tensor on any device, possibly requiring grad. A 1-D
            tensor is sent as-is; otherwise only ``audio[0]`` is sent
            (presumably the first channel of a ``(channels, samples)``
            tensor -- TODO confirm against the caller's layout). The caller
            must already have resampled to the rate the ASR model expects.
        asr_model: Object with ``transcribe(list_of_arrays)`` returning a
            sequence whose first element has a ``.text`` attribute.

    Returns:
        ``.text`` of the first hypothesis returned by ``asr_model``.
    """
    # Select the channel before converting so only the samples actually sent
    # are copied; ``numpy(force=True)`` already detaches and moves to CPU.
    mono = audio if audio.dim() == 1 else audio[0]
    samples = mono.numpy(force=True)
    with torch.inference_mode():
        return asr_model.transcribe([samples])[0].text
def _reconstruct(wav, wavvae, zflowae, num_steps, cfg, target_sr):
    """Round-trip ``wav`` through the codec models and resample to ``target_sr``.

    Runs ``wavvae.encode`` -> ``zflowae.encode`` -> ``zflowae.decode`` ->
    ``wavvae.decode``, then resamples from ``wavvae.sampling_rate`` to
    ``target_sr``. Everything runs under ``torch.inference_mode``.
    """
    with torch.inference_mode():
        z = wavvae.encode(wav)
        zz, _ = zflowae.encode(z)
        z_hat = zflowae.decode(zz, num_steps=num_steps, cfg=cfg)
        wav_hat = wavvae.decode(z_hat)
        return resample(wav_hat, orig_freq=wavvae.sampling_rate, new_freq=target_sr)


def _write_jsonl(path, rows):
    """Write ``rows`` to ``path`` as UTF-8 JSON Lines, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False emits raw non-ASCII text, so the file encoding must be
    # pinned to UTF-8 rather than left to the platform locale.
    with path.open("w", encoding="utf-8") as f:
        f.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)


def main(args):
    """Reconstruct every listed audio file through the codec and transcribe it.

    Config keys read from the YAML file at ``args.config``:
        file_list: text file with one audio path per line (blank lines skipped).
        wavvae_ckpt, zflowae_ckpt: local checkpoints for WavVAE and
            ZFlowAutoEncoder.
        asr_model (optional): NeMo model name, default
            ``"nvidia/parakeet-tdt-0.6b-v2"``.
        num_steps (optional, default 10), cfg (optional, default 2.0): passed
            to ``ZFlowAutoEncoder.decode``.
        asr_sampling_rate (optional, default 16000): rate the reconstruction
            is resampled to before ASR.
        output_jsonl (optional, default ``"asr_reconstructed.jsonl"``): output
            path. Each line is ``{"file": ..., "transcript": ...}``.
    """
    config = load_config(args.config)

    # Read the file list before loading the (slow) models so a bad path fails fast.
    with open(config["file_list"], "r", encoding="utf-8") as f:
        wav_files = [line.strip() for line in f if line.strip()]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Codec models, moved to the selected device and switched to eval mode.
    wavvae = WavVAE.from_pretrained_local(config["wavvae_ckpt"]).to(device).eval()
    zflowae = (
        ZFlowAutoEncoder.from_pretrained_local(config["zflowae_ckpt"]).to(device).eval()
    )
    asr_model = nemo_asr.models.ASRModel.from_pretrained(
        model_name=config.get("asr_model", "nvidia/parakeet-tdt-0.6b-v2")
    )

    # Loop-invariant settings, read once.
    num_steps = config.get("num_steps", 10)
    cfg = config.get("cfg", 2.0)
    asr_sr = config.get("asr_sampling_rate", 16000)

    results = []
    for wav_path in tqdm(wav_files, desc="ASR on reconstructed audio"):
        wav, sr = load(wav_path)
        # NOTE(review): multi-channel files are passed to the codec unchanged;
        # confirm WavVAE.encode handles inputs with more than one channel.
        wav = resample(wav, orig_freq=sr, new_freq=wavvae.sampling_rate).to(device)
        wav_hat = _reconstruct(wav, wavvae, zflowae, num_steps, cfg, asr_sr)
        results.append(
            {
                "file": wav_path,
                "transcript": transcribe(wav_hat, asr_model),
            }
        )

    out_path = Path(config.get("output_jsonl", "asr_reconstructed.jsonl"))
    _write_jsonl(out_path, results)
    print(f"\nSaved {len(results)} reconstructed ASR results to {out_path}")
if __name__ == "__main__":
    # Command-line entry point: the only option is the YAML config path.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to YAML config",
    )
    main(cli.parse_args())