# Mehdi Lakbar
# Initial demo of Lina-speech (pardi-speech)
# 56cfa73
import argparse
import json
from pathlib import Path
import nemo.collections.asr as nemo_asr
import torch
import yaml
from jiwer import wer
from torchaudio import load
from torchaudio.functional import resample
from tqdm import tqdm
from zcodec.models import WavVAE, ZFlowAutoEncoder
def load_config(config_path):
    """Load a YAML config file and return its parsed contents.

    Args:
        config_path: Path (``str`` or ``os.PathLike``) to the YAML file.

    Returns:
        The object produced by ``yaml.safe_load`` for the document. This is a
        dict for the mapping-style configs ``main`` reads, or ``None`` for an
        empty file.
    """
    # Decode explicitly as UTF-8 so non-ASCII values (paths, model names)
    # parse identically regardless of the platform's default locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
def transcribe(audio: torch.Tensor, asr_model) -> str:
    """Transcribe one waveform with an ASR model and return the text.

    Args:
        audio: Waveform tensor on any device and of any floating dtype.
            - A 1-D tensor is transcribed as-is.
            - For a tensor with more dimensions, only ``audio[0]`` is
              transcribed; the remaining rows are ignored.
        asr_model: Object exposing ``transcribe(list_of_arrays)`` whose
            result items carry a ``.text`` attribute (e.g. a NeMo
            ``ASRModel``).

    Returns:
        The transcript of the selected signal.
    """
    # NumPy has no bfloat16, so cast before conversion. This lets
    # half-precision decoder outputs reach the ASR model as float32 samples.
    # numpy(force=True) also detaches and copies to host memory.
    samples = audio.to(dtype=torch.float32).numpy(force=True)
    if samples.ndim > 1:
        samples = samples[0]
    with torch.inference_mode():
        return asr_model.transcribe([samples])[0].text
def _write_jsonl(entries, out_path):
    """Write ``entries`` to ``out_path`` as UTF-8 JSON lines, creating parent dirs."""
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False emits raw non-ASCII characters, so the file encoding
    # is pinned instead of relying on the platform default (may not be UTF-8).
    with out_path.open("w", encoding="utf-8") as f:
        for entry in entries:
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")


def _print_wer(results):
    """Print the corpus WER of reconstructed vs. original transcripts.

    Pairs whose original (reference) transcript is empty are left out,
    because jiwer rejects empty reference strings. The number of skipped
    pairs is reported.
    """
    pairs = [
        (r["original_text"], r["reconstructed_text"])
        for r in results
        if r["original_text"].strip()
    ]
    skipped = len(results) - len(pairs)
    if skipped:
        print(f"Skipping {skipped} file(s) with an empty original transcript for WER")
    if not pairs:
        print("WER: not computed (no non-empty reference transcripts)")
        return
    references = [ref for ref, _ in pairs]
    hypotheses = [hyp for _, hyp in pairs]
    score = wer(references, hypotheses)
    print(f"WER: {score:.3%}")


def main(args):
    """Round-trip each listed file through the codec and transcribe both versions.

    For every path in ``config["file_list"]``:
      1. The audio is resampled to the WavVAE rate.
      2. It is encoded with ``wavvae`` and then ``zflowae``, and decoded back.
      3. The input and the reconstruction are both transcribed with a NeMo
         ASR model, after resampling each to ``asr_sample_rate``.

    Transcript pairs are written as JSON lines to ``config["output_jsonl"]``.
    With ``args.compute_wer``, a corpus-level WER (reconstructed vs. original
    transcripts) is printed.

    Config keys read:
      - required: ``wavvae_ckpt``, ``zflowae_ckpt``, ``file_list``
      - optional: ``asr_model``, ``asr_sample_rate`` (default 16000),
        ``num_steps``, ``cfg``, ``output_jsonl``
    """
    config = load_config(args.config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load codec models
    wavvae = WavVAE.from_pretrained_local(config["wavvae_ckpt"]).to(device).eval()
    zflowae = (
        ZFlowAutoEncoder.from_pretrained_local(config["zflowae_ckpt"]).to(device).eval()
    )

    # Load ASR model
    asr_model = nemo_asr.models.ASRModel.from_pretrained(
        model_name=config.get("asr_model", "nvidia/parakeet-tdt-0.6b-v2")
    )
    # Raw arrays handed to the ASR model carry no sample-rate metadata, so they
    # must already be at the rate the model was trained on (16 kHz for the
    # default Parakeet model) rather than at the codec's rate.
    asr_sample_rate = config.get("asr_sample_rate", 16000)

    # Values that are constant across files, read once.
    codec_rate = wavvae.sampling_rate
    num_steps = config.get("num_steps", 10)
    cfg_scale = config.get("cfg", 2.0)

    # Read file list: one path per line, blank lines ignored.
    with open(config["file_list"], "r", encoding="utf-8") as f:
        wav_files = [line.strip() for line in f if line.strip()]

    results = []
    for wav_path in tqdm(wav_files, desc="Processing files"):
        wav, sr = load(wav_path)
        wav = resample(wav, orig_freq=sr, new_freq=codec_rate).to(device)
        with torch.inference_mode():
            # Transcribe original, at the rate the ASR model expects.
            original_text = transcribe(
                resample(wav, orig_freq=codec_rate, new_freq=asr_sample_rate),
                asr_model,
            )
            # Compress and decompress
            z = wavvae.encode(wav)
            zz, _ = zflowae.encode(z)
            z_hat = zflowae.decode(zz, num_steps=num_steps, cfg=cfg_scale)
            wav_hat = wavvae.decode(z_hat)
            # Transcribe reconstructed, at the same ASR rate as the original.
            reconstructed_text = transcribe(
                resample(wav_hat, orig_freq=codec_rate, new_freq=asr_sample_rate),
                asr_model,
            )
        results.append(
            {
                "file": wav_path,
                "original_text": original_text,
                "reconstructed_text": reconstructed_text,
            }
        )

    # Save output
    out_path = Path(config.get("output_jsonl", "transcripts.jsonl"))
    _write_jsonl(results, out_path)
    print(f"\nSaved {len(results)} transcript pairs to {out_path}")

    # Optionally compute WER
    if args.compute_wer:
        _print_wer(results)
if __name__ == "__main__":
    # Command-line entry point: parse the two flags and hand them to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to YAML config",
    )
    cli.add_argument(
        "--compute_wer",
        action="store_true",
        help="Compute WER after decoding",
    )
    main(cli.parse_args())