#!/usr/bin/env python
# finetune_whisper.py
import os
os.environ["TRANSFORMERS_NO_TF"] = "1"
import torch
from datasets import load_dataset, Audio
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
import ipdb
import evaluate
from dataclasses import dataclass
from typing import Any, Dict, List, Union
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to the longest sequence in the batch
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        # replace padding with -100 so that padded positions are ignored in the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        # if a bos token was prepended in the previous tokenization step,
        # cut it here since it is added again during generation anyway
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch
# → Choose device (GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"
# 1. Configuration
LANGUAGE = "km_kh" # FLEURS config for Khmer
LANGUAGE_WHISPER = "khmer" # Whisper config for Khmer
MODEL_CHECKPOINT = "openai/whisper-large-v3"
OUTPUT_DIR = f"./whisper-fleurs-{LANGUAGE}-small"
TRAIN_SPLIT = "train"
VALID_SPLIT = "validation"
TEST_SPLIT = "test"
MAX_TARGET_LENGTH = 448
# 2. Load FLEURS Dataset (audio at 16 kHz)
raw_datasets = load_dataset("google/fleurs", LANGUAGE,
                            split={"train": TRAIN_SPLIT,
                                   "validation": VALID_SPLIT,
                                   "test": TEST_SPLIT})
# Cast “audio” column to 16 kHz
for split in ["train", "validation", "test"]:
    raw_datasets[split] = raw_datasets[split].cast_column("audio", Audio(sampling_rate=16_000))
# Subsample the training data: keep a random 75% of the examples (the "test" side of the 25/75 split)
raw_datasets["train"] = raw_datasets["train"].train_test_split(test_size=0.75, seed=42)["test"]
# 3. Load Whisper Processor & Model
processor = WhisperProcessor.from_pretrained(MODEL_CHECKPOINT, language=LANGUAGE_WHISPER)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)
model.to(device)
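# Optional, and not in the original script: pin generation to Khmer transcription so that
# predict_with_generate() does not fall back to Whisper's automatic language detection at
# evaluation time. This mirrors the common Hugging Face fine-tuning recipe; remove it to
# keep the model's default generation behaviour.
model.generation_config.language = LANGUAGE_WHISPER
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None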
# 4. Preprocessing Function
# - Extract log‐Mel features from audio
# - Tokenize the target transcription
def preprocess_batch(batch):
    # batch["audio"] is a list of dicts; each dict holds a NumPy "array" sampled at 16 kHz
    audio_arrays = [example["array"] for example in batch["audio"]]
    # 4a. Feature extraction (log-Mel + normalization)
    inputs = processor.feature_extractor(
        audio_arrays,
        sampling_rate=16_000,
        return_tensors="pt",
    )
    # 4b. Tokenize the transcriptions (labels) with the Whisper tokenizer.
    #     We could prefix the text with a target-language ID token (e.g. "<|km|>" for Khmer) if
    #     necessary, but the default Whisper language-ID tokens should suffice for FLEURS.
    labels = processor.tokenizer(
        batch["transcription"],
        return_tensors="pt",
        padding="longest",
        truncation=True,
        max_length=MAX_TARGET_LENGTH,
    )
    # ipdb.set_trace()
    # attach the tokenized labels for the trainer
    inputs["labels"] = labels.input_ids
    return inputs
# 5. Apply preprocessing to train/validation/test
# - Remove all non‐audio columns after mapping
train_dataset = raw_datasets["train"].map(
    preprocess_batch,
    remove_columns=raw_datasets["train"].column_names,
    batched=True,
    batch_size=16,  # adjust batch_size to your memory
)
# ipdb.set_trace()
eval_dataset = raw_datasets["validation"].map(
    preprocess_batch,
    remove_columns=raw_datasets["validation"].column_names,
    batched=True,
    batch_size=8,
)
test_dataset = raw_datasets["test"].map(
    preprocess_batch,
    remove_columns=raw_datasets["test"].column_names,
    batched=True,
    batch_size=8,
)
# 6. Data Collator
# This will pad input_features and labels to the maximum length in the batch,
# and replace padding token ID in labels by -100 to ignore them in loss computation.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)
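# Optional smoke test (not in the original script): run the collator on two preprocessed examples
# and print the padded tensor shapes, to catch shape/padding issues before training starts.
_probe = data_collator([train_dataset[i] for i in range(2)])
print("input_features:", tuple(_probe["input_features"].shape), "labels:", tuple(_probe["labels"].shape))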
# 7. Metrics: WER & CER (using Hugging Face Evaluate)
wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")
def compute_metrics(pred):
    """
    pred.predictions: raw token IDs from generate()
    pred.label_ids:   token IDs used as labels
    """
    # 7a. decode predictions → strings, skipping special tokens
    pred_ids = pred.predictions
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    # 7b. decode references → strings, replacing -100 with the pad token ID
    #     so that the tokenizer does not crash
    label_ids = pred.label_ids
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    ref_str = processor.batch_decode(label_ids, skip_special_tokens=True)
    # lowercase & strip before scoring
    pred_str = [s.lower().strip() for s in pred_str]
    ref_str = [s.lower().strip() for s in ref_str]
    wer_score = wer_metric.compute(predictions=pred_str, references=ref_str)
    cer_score = cer_metric.compute(predictions=pred_str, references=ref_str)
    return {"wer": wer_score, "cer": cer_score}
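# Tiny illustrative check (not part of the original script): score a toy prediction/reference pair
# so the difference between word-level and character-level error rates is visible in the logs.
_toy_wer = wer_metric.compute(predictions=["hello world"], references=["hello word"])
_toy_cer = cer_metric.compute(predictions=["hello world"], references=["hello word"])
print(f"toy WER: {_toy_wer:.2f}, toy CER: {_toy_cer:.2f}")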
"""
# 8. Training Arguments
training_args = Seq2SeqTrainingArguments(
output_dir=OUTPUT_DIR,
per_device_train_batch_size=4, # reduce if you OOM; or increase if large GPU
per_device_eval_batch_size=4,
gradient_accumulation_steps=2, # to simulate a larger batch
evaluation_strategy="steps",
eval_steps=500, # evaluate every 500 steps
logging_steps=250,
save_steps=1000,
num_train_epochs=3,
learning_rate=1e-5,
warmup_steps=500,
fp16=True, # use mixed precision if supported
predict_with_generate=True, # for computing WER/CER we need generate()
save_total_limit=2,
push_to_hub=False,
)
"""
# 8. Training Arguments (active configuration)
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    warmup_steps=100,
    max_steps=800,
    gradient_checkpointing=True,
    fp16=True,
    eval_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=448,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="cer",
    greater_is_better=False,
    push_to_hub=True,
)
# 9. Initialize Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,  # saved alongside checkpoints; text decoding still goes through `processor`
    compute_metrics=compute_metrics,
)
# 10. Fine-tune
if __name__ == "__main__":
    # 10a. Train
    trainer.train()

    # 10b. Evaluate on the TEST split
    print("\n***** Evaluating on TEST split *****")
    test_metrics = trainer.predict(test_dataset, metric_key_prefix="test")
    print(f"Test WER: {test_metrics.metrics['test_wer'] * 100:.2f}%")
    print(f"Test CER: {test_metrics.metrics['test_cer'] * 100:.2f}%")