File size: 2,557 Bytes
56cfa73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import argparse
import json
from pathlib import Path

import nemo.collections.asr as nemo_asr
import torch
import yaml
from torchaudio import load
from torchaudio.functional import resample
from tqdm import tqdm

from zcodec.models import WavVAE, ZFlowAutoEncoder


def load_config(config_path):
    """Read a YAML file and return its parsed contents.

    Parsing goes through ``yaml.safe_load``, so only standard YAML tags are
    accepted (no arbitrary Python object construction).

    Args:
        config_path: Path (str or os.PathLike) to the YAML config file.

    Returns:
        Whatever ``yaml.safe_load`` yields for the document: a mapping for a
        normal config, or ``None`` for an empty file.
    """
    # Decode explicitly as UTF-8 (YAML's native encoding) so the config is
    # parsed identically regardless of the platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def transcribe(audio: torch.Tensor, asr_model) -> str:
    """Transcribe the first row of *audio* with *asr_model* and return the text.

    Args:
        audio: Waveform tensor. It is detached and copied to host memory as a
            NumPy array, and only ``audio[0]`` is handed to the model.
            Presumably (channels, time) at the ASR model's expected sample
            rate — TODO confirm against callers.
        asr_model: Object exposing ``transcribe(list_of_arrays)`` whose result
            items carry a ``.text`` attribute (``main`` passes a NeMo ASR model).

    Returns:
        The ``.text`` of the first hypothesis returned by the model.
    """
    host_samples = audio.cpu().numpy(force=True)
    first_row = host_samples[0]
    with torch.inference_mode():
        hypotheses = asr_model.transcribe([first_row])
    return hypotheses[0].text


def _read_file_list(list_path):
    """Return the stripped, non-blank lines of the text file at *list_path*."""
    with open(list_path, "r", encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        return [path for path in stripped if path]


def _write_jsonl(out_path, entries):
    """Write *entries* to *out_path* as UTF-8 JSON Lines, creating parent dirs."""
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False emits raw non-ASCII characters, so the file encoding
    # must be pinned: a locale default such as cp1252 can fail to encode them
    # and abort the run after all inference has already been done.
    with out_path.open("w", encoding="utf-8") as f:
        for entry in entries:
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")


def main(args):
    """Round-trip each listed audio file through the codec and transcribe it.

    Every file is encoded with WavVAE, compressed and decompressed with
    ZFlowAutoEncoder, decoded back to a waveform with WavVAE, resampled for
    the ASR model, and transcribed. One JSON object per file
    (``{"file": ..., "transcript": ...}``) is written to a JSONL file.

    Config keys read (from the YAML at ``args.config``):
        wavvae_ckpt, zflowae_ckpt: local checkpoint paths (required).
        file_list: text file with one audio path per line (required).
        asr_model: NeMo model name (default "nvidia/parakeet-tdt-0.6b-v2").
        num_steps, cfg: ZFlowAutoEncoder decode settings (defaults 10, 2.0).
        asr_sample_rate: rate fed to the ASR model (default 16000).
        output_jsonl: output path (default "asr_reconstructed.jsonl").

    Args:
        args: Parsed CLI namespace with a ``config`` attribute.
    """
    config = load_config(args.config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Codec models: WavVAE (waveform <-> latent) and ZFlowAutoEncoder
    # (latent <-> compressed latent).
    wavvae = WavVAE.from_pretrained_local(config["wavvae_ckpt"]).to(device).eval()
    zflowae = (
        ZFlowAutoEncoder.from_pretrained_local(config["zflowae_ckpt"]).to(device).eval()
    )

    asr_model = nemo_asr.models.ASRModel.from_pretrained(
        model_name=config.get("asr_model", "nvidia/parakeet-tdt-0.6b-v2")
    )

    # Loop invariants, read once instead of per file.
    codec_sr = wavvae.sampling_rate
    num_steps = config.get("num_steps", 10)
    cfg_scale = config.get("cfg", 2.0)
    asr_sr = config.get("asr_sample_rate", 16000)

    wav_files = _read_file_list(config["file_list"])

    results = []
    for wav_path in tqdm(wav_files, desc="ASR on reconstructed audio"):
        # torchaudio.load returns (channels, time); the codec is assumed to
        # accept that layout — TODO confirm for multi-channel inputs.
        wav, sr = load(wav_path)
        wav = resample(wav, orig_freq=sr, new_freq=codec_sr).to(device)

        with torch.inference_mode():
            # Compress and decompress.
            z = wavvae.encode(wav)
            zz, _ = zflowae.encode(z)
            z_hat = zflowae.decode(zz, num_steps=num_steps, cfg=cfg_scale)
            wav_hat = wavvae.decode(z_hat)

            # Bring the reconstruction to the ASR model's rate and transcribe.
            wav_hat = resample(wav_hat, orig_freq=codec_sr, new_freq=asr_sr)
            reconstructed_text = transcribe(wav_hat, asr_model)

        results.append({"file": wav_path, "transcript": reconstructed_text})

    out_path = Path(config.get("output_jsonl", "asr_reconstructed.jsonl"))
    _write_jsonl(out_path, results)

    print(f"\nSaved {len(results)} reconstructed ASR results to {out_path}")


if __name__ == "__main__":
    # Command-line entry point: the only option is the YAML config path.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to YAML config",
    )
    main(cli.parse_args())