#!/usr/bin/env python
# pip install transformers accelerate datasets torch soundfile jiwer
"""Evaluate a fine-tuned Whisper checkpoint on the FLEURS Khmer (km_kh) test set.

Transcribes every test utterance with a Hugging Face ASR pipeline, prints
corpus-level CER and WER computed with jiwer, and writes the predictions and
references to ``./km_kh_finetune.pred`` and ``./km_kh.ref`` (one utterance per
line, UTF-8).
"""

import torch
from datasets import Audio, load_dataset
from jiwer import cer as jiwer_cer
from jiwer import wer as jiwer_wer
from torch.utils.data import DataLoader  # noqa: F401  (currently unused)
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,  # noqa: F401  (currently unused)
    WhisperProcessor,
    pipeline,
)

try:
    # Only needed for ad-hoc breakpoints (ipdb.set_trace()). It is not part of
    # the pip line above, so a missing install must not abort the evaluation.
    import ipdb  # noqa: F401
except ImportError:
    ipdb = None

FLEURS_CONFIG = "km_kh"  # FLEURS config name: Khmer
SAMPLING_RATE = 16_000  # audio is decoded/resampled to this rate for Whisper
BATCH_SIZE = 64  # examples per map() call and per pipeline forward batch

# 1. Load the FLEURS Khmer test set and decode audio at 16 kHz.
ds = load_dataset("google/fleurs", FLEURS_CONFIG, split="test")
ds = ds.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))

# 2. Load the fine-tuned model and build the ASR pipeline.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Half precision on GPU, full precision on CPU.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# model_id = "openai/whisper-large-v3"  # zero-shot baseline
model_id = "pengyizhou/whisper-fleurs-km_kh"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

# Tokenizer and feature extractor come from the base large-v3 checkpoint.
# NOTE(review): assumes the fine-tuned repo keeps the large-v3 vocabulary.
whisper_model = "openai/whisper-large-v3"
processor = WhisperProcessor.from_pretrained(whisper_model, language="khmer")

asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    chunk_length_s=30,
    batch_size=BATCH_SIZE,
    device=device,
    # Decoding options must go through `generate_kwargs`: the ASR pipeline does
    # not accept arbitrary generate() arguments as top-level keyword arguments.
    generate_kwargs={
        "max_new_tokens": 440,  # cap on tokens generated per segment
        "no_repeat_ngram_size": 3,  # never emit the same 3-gram twice
        "repetition_penalty": 1.0,  # 1.0 = no penalty (>1.0 discourages repeats)
        "length_penalty": 1.0,  # neutral; only used by beam search
        "num_beams": 1,  # greedy decoding, no beam search
        "do_sample": False,  # deterministic output
        "early_stopping": False,  # only used by beam search
        "suppress_tokens": [],  # override the generation config's suppress list
    },
)


def normalize_text(text):
    """Lower-case and strip *text* so predictions and references match in form."""
    return text.lower().strip()


# 3. Batch-wise transcription function.
def transcribe_batch(batch):
    """Transcribe one batch produced by a batched ``Dataset.map`` call.

    Args:
        batch: Columnar batch dict; ``batch["audio"]`` is a list of decoded
            audio dicts carrying ``"array"`` and ``"sampling_rate"``.

    Returns:
        ``{"prediction": [...]}`` with one normalized transcript per example,
        in input order.
    """
    # Give the pipeline the sampling rate explicitly instead of letting it
    # assume the raw arrays already match the feature extractor's rate.
    inputs = [
        {"raw": audio["array"], "sampling_rate": audio["sampling_rate"]}
        for audio in batch["audio"]
    ]
    outputs = asr(inputs)  # one {"text": ...} dict per input
    return {"prediction": [normalize_text(out["text"]) for out in outputs]}


# 4. Transcribe the whole test set, BATCH_SIZE examples per map() call.
result = ds.map(
    transcribe_batch,
    batched=True,
    batch_size=BATCH_SIZE,
    remove_columns=ds.column_names,  # keep only the new "prediction" column
)

# 5. Corpus-level CER/WER with jiwer (each metric is computed over the lists).
refs = [normalize_text(t) for t in ds["transcription"]]
preds = list(result["prediction"])
score_cer = jiwer_cer(refs, preds)
# NOTE(review): WER tokenizes on whitespace; if the Khmer references have few
# inter-word spaces, WER behaves closer to a sentence error rate.
score_wer = jiwer_wer(refs, preds)

print(f"CER on FLEURS {FLEURS_CONFIG}: {score_cer*100:.2f}%")
print(f"WER on FLEURS {FLEURS_CONFIG}: {score_wer*100:.2f}%")

# Khmer text: write UTF-8 explicitly rather than relying on the locale default.
with open(f"./{FLEURS_CONFIG}_finetune.pred", "w", encoding="utf-8") as pred_results:
    pred_results.writelines(f"{pred}\n" for pred in preds)

with open(f"./{FLEURS_CONFIG}.ref", "w", encoding="utf-8") as ref_results:
    ref_results.writelines(f"{ref}\n" for ref in refs)