|
|
|
|
|
|
|
|
|
import os |
|
|
|
os.environ["TRANSFORMERS_NO_TF"] = "1" |
|
|
|
import torch |
|
from datasets import load_dataset, Audio |
|
from transformers import ( |
|
WhisperProcessor, |
|
WhisperForConditionalGeneration, |
|
Seq2SeqTrainingArguments, |
|
Seq2SeqTrainer, |
|
) |
|
import ipdb |
|
import evaluate |
|
|
|
|
|
from dataclasses import dataclass |
|
from typing import Any, Dict, List, Union |
|
|
|
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Turn a list of preprocessed examples into one padded batch.

    ``processor`` must expose ``feature_extractor.pad`` and ``tokenizer.pad``.
    ``decoder_start_token_id`` is the token id that gets stripped from the
    front of the labels when every label sequence in the batch starts with it.
    """

    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad audio features and labels separately and merge them.

        Args:
            features: Examples carrying an ``"input_features"`` entry and a
                ``"labels"`` entry.

        Returns:
            The padded feature batch with an extra ``"labels"`` tensor in
            which padded positions are set to -100.
        """
        # The two modalities go through different padders, so split them first.
        audio_items = []
        token_items = []
        for example in features:
            audio_items.append({"input_features": example["input_features"]})
            token_items.append({"input_ids": example["labels"]})

        batch = self.processor.feature_extractor.pad(audio_items, return_tensors="pt")
        padded_tokens = self.processor.tokenizer.pad(token_items, return_tensors="pt")

        # Every position the tokenizer's padder did not mark as real becomes -100.
        padding_positions = padded_tokens.attention_mask.ne(1)
        labels = padded_tokens["input_ids"].masked_fill(padding_positions, -100)

        # Strip the leading start token only when *all* rows begin with it.
        leading_ids = labels[:, 0]
        if bool((leading_ids == self.decoder_start_token_id).all()):
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
|
|
|
|
|
|
|
|
|
|
|
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Checkpoint that gets fine-tuned.
MODEL_CHECKPOINT = "openai/whisper-large-v3"

# Dataset configuration name and the language name given to the processor.
LANGUAGE = "km_kh"
LANGUAGE_WHISPER = "khmer"

# Names of the dataset splits to load.
TRAIN_SPLIT = "train"
VALID_SPLIT = "validation"
TEST_SPLIT = "test"

# Upper bound on the tokenized transcription length.
MAX_TARGET_LENGTH = 448

# NOTE(review): the directory name ends in "-small" although MODEL_CHECKPOINT
# is a large-v3 model; kept unchanged so existing output paths stay valid.
OUTPUT_DIR = "./whisper-fleurs-" + LANGUAGE + "-small"
|
|
|
# Fetch all three splits with one call; the result is indexed by the same
# keys that are used in the `split` mapping.
raw_datasets = load_dataset(
    "google/fleurs",
    LANGUAGE,
    split=dict(train=TRAIN_SPLIT, validation=VALID_SPLIT, test=TEST_SPLIT),
)

# Declare the audio column of every split as 16 kHz audio.
for split in ("train", "validation", "test"):
    raw_datasets[split] = raw_datasets[split].cast_column("audio", Audio(sampling_rate=16_000))

# Only the "test" part of a seeded split with test_size=0.75 is kept as the
# training set, i.e. a fixed 75% subset of the original training data.
train_parts = raw_datasets["train"].train_test_split(test_size=0.75, seed=42)
raw_datasets["train"] = train_parts["test"]
|
|
|
|
|
# Build the processor with BOTH language and task so that tokenized labels
# carry the complete Whisper prompt (language token and task token). Without
# `task`, the tokenizer leaves the task token out of the label prefix.
processor = WhisperProcessor.from_pretrained(
    MODEL_CHECKPOINT,
    language=LANGUAGE_WHISPER,
    task="transcribe",
)

model = WhisperForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)

# Pin generation to the same language/task used for the labels so that
# evaluation-time `generate()` decodes with the prompt the model is trained on
# instead of relying on language detection or checkpoint defaults.
model.generation_config.language = LANGUAGE_WHISPER
model.generation_config.task = "transcribe"
# Clear any forced prompt ids stored in the checkpoint so the language/task
# settings above are the ones that take effect.
model.generation_config.forced_decoder_ids = None

model.to(device)
|
|
|
|
|
|
|
|
|
def preprocess_batch(batch):
    """Convert a batch of raw examples into model inputs and label ids.

    Args:
        batch: A batched mapping with an ``"audio"`` column (each entry has an
            ``"array"`` waveform) and a ``"transcription"`` column of strings.

    Returns:
        A dict with ``"input_features"`` (one feature array per example) and
        ``"labels"`` (one *unpadded* list of token ids per example).
    """
    waveforms = [audio["array"] for audio in batch["audio"]]
    features = processor.feature_extractor(waveforms, sampling_rate=16_000)

    # Labels are deliberately left unpadded here. Padding is applied later by
    # the data collator, whose padder marks only the positions *it* adds as
    # padding; tokens padded at this stage would be treated as real targets
    # and contribute to the loss.
    tokenized = processor.tokenizer(
        batch["transcription"],
        truncation=True,
        max_length=MAX_TARGET_LENGTH,
    )

    return {
        "input_features": features["input_features"],
        "labels": tokenized["input_ids"],
    }
|
|
|
|
|
|
|
def _featurize_split(split_name, batch_size):
    """Apply `preprocess_batch` to one raw split, dropping its original columns."""
    raw_split = raw_datasets[split_name]
    return raw_split.map(
        preprocess_batch,
        batched=True,
        batch_size=batch_size,
        remove_columns=raw_split.column_names,
    )


# The training split is mapped in batches of 16, the other two in batches of 8.
train_dataset = _featurize_split("train", 16)
eval_dataset = _featurize_split("validation", 8)
test_dataset = _featurize_split("test", 8)
|
|
|
|
|
|
|
|
|
# Metrics used by `compute_metrics`: word error rate and character error rate.
wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

# Collator that pads each batch; it needs the model's decoder start token id
# to strip that token from the labels when every sequence begins with it.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    decoder_start_token_id=model.config.decoder_start_token_id,
    processor=processor,
)
|
|
|
def compute_metrics(pred):
    """Compute WER and CER between generated and reference transcriptions.

    Args:
        pred: Prediction object with ``predictions`` (generated token ids, or
            a tuple whose first element holds them) and ``label_ids``
            (reference token ids where ignored positions are -100).

    Returns:
        A dict ``{"wer": ..., "cer": ...}`` as returned by the two metrics.
    """
    pad_token_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        # Some prediction outputs bundle extra arrays; the ids come first.
        pred_ids = pred_ids[0]

    # Work on copies so the caller's arrays are left untouched.
    pred_ids = pred_ids.copy()
    label_ids = pred.label_ids.copy()

    # -100 is not a valid token id and cannot be decoded. Labels always use it
    # for ignored positions; predictions can contain it as well when batches
    # of different lengths are padded together during evaluation.
    pred_ids[pred_ids == -100] = pad_token_id
    label_ids[label_ids == -100] = pad_token_id

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    ref_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # Light normalization so case and surrounding whitespace do not count as errors.
    pred_str = [text.lower().strip() for text in pred_str]
    ref_str = [text.lower().strip() for text in ref_str]

    return {
        "wer": wer_metric.compute(predictions=pred_str, references=ref_str),
        "cer": cer_metric.compute(predictions=pred_str, references=ref_str),
    }
|
|
|
|
|
""" |
|
# 8. Training Arguments |
|
training_args = Seq2SeqTrainingArguments( |
|
output_dir=OUTPUT_DIR, |
|
per_device_train_batch_size=4, # reduce if you OOM; or increase if large GPU |
|
per_device_eval_batch_size=4, |
|
gradient_accumulation_steps=2, # to simulate a larger batch |
|
evaluation_strategy="steps", |
|
eval_steps=500, # evaluate every 500 steps |
|
logging_steps=250, |
|
save_steps=1000, |
|
num_train_epochs=3, |
|
learning_rate=1e-5, |
|
warmup_steps=500, |
|
fp16=True, # use mixed precision if supported |
|
predict_with_generate=True, # for computing WER/CER we need generate() |
|
save_total_limit=2, |
|
push_to_hub=False, |
|
) |
|
""" |
|
# Step-based schedule: 800 optimizer steps with evaluation and checkpointing
# every 100 steps; the checkpoint with the lowest CER is reloaded at the end.
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    # Optimization
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    warmup_steps=100,
    max_steps=800,
    gradient_checkpointing=True,
    # Mixed precision needs a CUDA device; this script also supports a CPU
    # fallback, where requesting fp16 unconditionally would raise.
    fp16=torch.cuda.is_available(),
    # Evaluation (generation is required to score WER/CER)
    eval_strategy="steps",
    eval_steps=100,
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    # Same bound that is used when tokenizing the labels.
    generation_max_length=MAX_TARGET_LENGTH,
    # Checkpointing and model selection (lower CER is better)
    save_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="cer",
    greater_is_better=False,
    # Logging and publishing
    logging_steps=10,
    report_to=["tensorboard"],
    push_to_hub=True,
)
|
|
|
|
|
|
|
# Wire the model, data, collator and metric function into a single trainer.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    # Data
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    # Evaluation
    compute_metrics=compute_metrics,
    # NOTE(review): only the feature extractor is passed here, not the full
    # processor; newer transformers releases may expect this under a different
    # argument name - confirm against the installed version.
    tokenizer=processor.feature_extractor,
)
|
|
|
|
|
if __name__ == "__main__":
    # Fine-tune first, then score the held-out test split.
    trainer.train()

    print("\n***** Evaluating on TEST split *****")
    test_output = trainer.predict(test_dataset, metric_key_prefix="test")

    # The metrics are multiplied by 100 and printed as percentages.
    wer_percent = test_output.metrics["test_wer"] * 100
    print(f"Test WER: {wer_percent:.2f}%")
    cer_percent = test_output.metrics["test_cer"] * 100
    print(f"Test CER: {cer_percent:.2f}%")
|
|