#!/usr/bin/env python
# finetune_whisper_mix_datasets.py
"""
Fine-tune openai/whisper-large-v3 on mixed datasets from different languages:
- FLEURS Cebuano (ceb_ph)
- FLEURS Khmer (km_kh)
- Switchboard1 English
- WenetSpeech Chinese
- Eng-Indon-CS
- Eng-Malay-CS
Based on the Hugging Face blog: https://huggingface.co/blog/fine-tune-whisper
To run this script on multiple GPUs, you have several options:
1. **Automatic Multi-GPU (DataParallel-style):**
python finetune_whisper_mix_datasets.py
The script will automatically detect and use all available GPUs.
2. **Distributed Training with torchrun (Recommended for 2+ GPUs):**
torchrun --nproc_per_node=2 finetune_whisper_mix_datasets.py
This uses DistributedDataParallel which is more efficient.
3. **Distributed Training with accelerate (Alternative):**
accelerate launch --num_processes=2 finetune_whisper_mix_datasets.py
Requires: pip install accelerate
Note: With 2 GPUs, the effective batch size becomes:
per_device_batch_size * num_gpus * gradient_accumulation_steps
= 24 * 2 * 1 = 48 (compared to 32 with single GPU)
CPU Core Limiting:
    The script limits CPU thread counts via environment variables read from the
    config file (the examples below assume 20 threads).
You can also set these manually before running:
export OMP_NUM_THREADS=20
export MKL_NUM_THREADS=20
export NUMEXPR_NUM_THREADS=20
python finetune_whisper_mix_datasets.py
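Configuration:
    All settings are read from a YAML file (--config, default: config.yaml).
    The file itself is not part of this script; the sketch below shows the
    shape of the dict the code expects after yaml.safe_load, with key names
    taken from the lookups in the code and all values purely illustrative:
        config = {
            "environment": {"omp_num_threads": "20", "mkl_num_threads": "20",
                            "openblas_num_threads": "20",
                            "veclib_maximum_threads": "20",
                            "numexpr_num_threads": "20",
                            "tokenizers_parallelism": "false",
                            "transformers_no_tf": "1",
                            "max_cpu_cores": 20, "test_cpu_cores": 4},
            "model": {"checkpoint": "openai/whisper-large-v3",
                      "max_target_length": 448},
            "output": {"output_dir": "whisper-large-v3-mixed"},
            "audio": {"sampling_rate": 16000},
            "data_processing": {"seed": 42, "columns_to_remove": [...]},
            "languages": {"english": {"whisper_language": "english",
                                      "validation_size": 1000}, ...},
            "datasets": {"english": {...}, "chinese": {...}, ...},
            "training": {"learning_rate": 1e-5, "max_steps": 5000, ...,
                         "single_gpu": {...}, "multi_gpu": {...}},
        }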
"""
import os
import io
import yaml
import argparse
# Load configuration from YAML file
def load_config(config_path):
with open(config_path, 'r') as file:
return yaml.safe_load(file)
# Parse command line arguments
parser = argparse.ArgumentParser(description='Fine-tune Whisper on mixed datasets')
parser.add_argument('--config', type=str, default='config.yaml',
help='Path to configuration YAML file')
args = parser.parse_args()
# Load configuration
config = load_config(args.config)
# Set environment variables from config (os.environ only accepts strings, and
# YAML may parse bare numbers as ints, hence the str() coercion)
env_config = config['environment']
os.environ["OMP_NUM_THREADS"] = str(env_config['omp_num_threads'])
os.environ["MKL_NUM_THREADS"] = str(env_config['mkl_num_threads'])
os.environ["OPENBLAS_NUM_THREADS"] = str(env_config['openblas_num_threads'])
os.environ["VECLIB_MAXIMUM_THREADS"] = str(env_config['veclib_maximum_threads'])
os.environ["NUMEXPR_NUM_THREADS"] = str(env_config['numexpr_num_threads'])
os.environ["TOKENIZERS_PARALLELISM"] = str(env_config['tokenizers_parallelism'])
os.environ["TRANSFORMERS_NO_TF"] = str(env_config['transformers_no_tf'])
import torch
from datasets import load_dataset, Audio, concatenate_datasets
from torch.utils.data import Dataset as TorchDataset
from transformers import (
WhisperProcessor,
WhisperForConditionalGeneration,
Seq2SeqTrainingArguments,
Seq2SeqTrainer,
)
import evaluate
import librosa  # imported after the thread-limit env vars above are set
# Multi-GPU setup: torchrun/accelerate set LOCAL_RANK and WORLD_SIZE themselves.
# Don't force them here: with LOCAL_RANK set but no launcher (no RANK/MASTER_ADDR),
# distributed init would fail, whereas the Trainer falls back to DataParallel
# across all visible GPUs on its own.
if torch.cuda.device_count() > 1:
    print(f"Detected {torch.cuda.device_count()} GPUs")
from dataclasses import dataclass
from typing import Any, Dict, List, Union
class WhisperOnTheFlyDataset(TorchDataset):
"""Custom dataset that preprocesses audio on-the-fly during training"""
def __init__(self, dataset, processors, main_processor, max_target_length, audio_config):
self.dataset = dataset
self.processors = processors
self.main_processor = main_processor
self.max_target_length = max_target_length
self.sampling_rate = audio_config['sampling_rate']
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
item = self.dataset[idx]
# Process audio
audio_sample = item["audio"]
audio_data = self._process_audio(audio_sample)
# Extract with main processor
inputs = self.main_processor.feature_extractor(
audio_data,
sampling_rate=self.sampling_rate,
return_tensors="pt"
)
        # Pick the transcript field: the FLEURS sets store it under
        # "transcription", the other corpora under "text"
        lang = item["language"]
        text = item["transcription"] if lang in ("cebuano", "khmer") else item["text"]
        # Tokenize with the language-specific tokenizer, falling back to the
        # Malay processor for any unlisted language (as in the original
        # else-branch); only the FLEURS languages are truncated
        if lang in self.processors:
            tokenizer = self.processors[lang].tokenizer
        else:
            tokenizer = self.processors["malay"].tokenizer
        if lang in ("cebuano", "khmer"):
            labels = tokenizer(
                text,
                return_tensors="pt",
                padding=False,
                truncation=True,
                max_length=self.max_target_length,
            )
        else:
            labels = tokenizer(text, return_tensors="pt", padding=False)
return {
"input_features": inputs.input_features.squeeze(0),
"labels": labels.input_ids.squeeze(0),
"language": lang
}
    def _process_audio(self, audio_sample):
        """Decode an audio sample into a float numpy array at the target sampling rate."""
if isinstance(audio_sample, dict):
if "array" in audio_sample:
return audio_sample["array"]
elif "bytes" in audio_sample and audio_sample["bytes"] is not None:
audio_array, _ = librosa.load(io.BytesIO(audio_sample["bytes"]), sr=self.sampling_rate)
return audio_array
elif "path" in audio_sample:
audio_array, _ = librosa.load(audio_sample["path"], sr=self.sampling_rate)
return audio_array
else:
raise ValueError(f"Unknown audio dict format: {audio_sample.keys()}")
elif isinstance(audio_sample, str):
audio_array, _ = librosa.load(audio_sample, sr=self.sampling_rate)
return audio_array
else:
return audio_sample
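# Each item from WhisperOnTheFlyDataset is a plain dict; illustrative shapes
# for a 30 s log-mel window at 16 kHz with whisper-large-v3 (128 mel bins):
#   {"input_features": FloatTensor[128, 3000],
#    "labels": LongTensor[num_tokens],
#    "language": "khmer"}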
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
processor: Any
decoder_start_token_id: int
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need different padding methods
# first treat the audio inputs by simply returning torch tensors
input_features = [{"input_features": feature["input_features"]} for feature in features]
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
# get the tokenized label sequences
label_features = [{"input_ids": feature["labels"]} for feature in features]
# pad the labels to max length
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
# if bos token is appended in previous tokenization step,
        # cut bos token here as it's appended later anyway
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
labels = labels[:, 1:]
batch["labels"] = labels
return batch
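# Illustrative collator behavior (token ids made up): for two label sequences
# of different lengths, the shorter one is padded to the batch maximum and the
# padded positions become -100 so the cross-entropy loss ignores them:
#   labels in:  [[7, 8, 9], [7, 8]]
#   labels out: tensor([[7, 8,    9],
#                       [7, 8, -100]])
# (a leading decoder_start_token shared by every row would additionally be cut)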
# Choose device (GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Extract configuration values
MODEL_CHECKPOINT = config['model']['checkpoint']
OUTPUT_DIR = config['output']['output_dir']
MAX_TARGET_LENGTH = config['model']['max_target_length']
# CPU usage configuration for dataset preprocessing
MAX_CPU_CORES = config['environment']['max_cpu_cores']
TEST_CPU_CORES = config['environment']['test_cpu_cores']
# Per-language settings (Whisper language tag, subset ratios, etc.)
LANGUAGE_CONFIGS = config['languages']
print("Loading datasets...")
# Load datasets for each language dynamically based on configuration
datasets = {}
dataset_configs = config['datasets']
audio_config = config['audio']
# Get list of enabled languages from both languages and datasets config
enabled_languages = set(config['languages'].keys()) & set(config['datasets'].keys())
print(f"Enabled languages: {list(enabled_languages)}")
def load_fleurs_dataset(lang_name, lang_config, dataset_config):
"""Load FLEURS dataset for a language"""
print(f"Loading FLEURS {lang_name.title()}...")
lang_datasets = load_dataset(
dataset_config['source'],
dataset_config['language_code'],
split={k: v for k, v in dataset_config['splits'].items()},
trust_remote_code=dataset_config['trust_remote_code']
)
# DON'T decode audio yet - keep it compressed until preprocessing
for split in dataset_config['splits'].keys():
lang_datasets[split] = lang_datasets[split].cast_column("audio", Audio(sampling_rate=audio_config['sampling_rate'], decode=False))
# Use subset of training data if specified
if 'train_subset_ratio' in lang_config:
train_subset_ratio = lang_config['train_subset_ratio']
lang_datasets["train"] = lang_datasets["train"].train_test_split(test_size=1-train_subset_ratio, seed=config['data_processing']['seed'])["train"]
return lang_datasets
def load_simple_dataset(lang_name, dataset_config):
"""Load simple dataset with train/validation/test splits"""
print(f"Loading {lang_name.title()}...")
lang_dataset = load_dataset(dataset_config['source'], split={k: v for k, v in dataset_config['splits'].items()})
return lang_dataset
def load_english_dataset(lang_config, dataset_config):
"""Load English dataset with custom train/validation split"""
print("Loading English...")
swb_train = load_dataset(dataset_config['train_dataset'], split=dataset_config['train_split'], streaming=dataset_config['streaming'])
swb_test = load_dataset(dataset_config['test_dataset'], split=dataset_config['test_split'], streaming=dataset_config['streaming'])
# Split into train/validation
validation_size = lang_config['validation_size']
swb_val = swb_train.take(validation_size)
swb_train = swb_train.skip(validation_size)
return {
"train": swb_train,
"validation": swb_val,
"test": swb_test
}
def load_chinese_dataset(dataset_config):
"""Load Chinese dataset with multiple test splits"""
print("Loading Chinese...")
wenet_train = load_dataset(dataset_config['train_dataset'], streaming=dataset_config['streaming'])
wenet_valid = load_dataset(dataset_config['validation_dataset'], dataset_config['validation_config'], split="validation", streaming=dataset_config['streaming'])
wenet_testnet = load_dataset(dataset_config['test_net_dataset'], dataset_config['test_net_config'], split="test", streaming=dataset_config['streaming'])
wenet_testmeeting = load_dataset(dataset_config['test_meeting_dataset'], dataset_config['test_meeting_config'], split="test", streaming=dataset_config['streaming'])
return {
"train": wenet_train["train"],
"validation": wenet_valid,
"test_net": wenet_testnet,
"test_meeting": wenet_testmeeting
}
# Load datasets for each enabled language
for lang in enabled_languages:
lang_config = config['languages'][lang]
dataset_config = dataset_configs[lang]
if lang in ['cebuano', 'khmer']:
# FLEURS datasets
datasets[lang] = load_fleurs_dataset(lang, lang_config, dataset_config)
elif lang == 'english':
# English with custom validation split
datasets[lang] = load_english_dataset(lang_config, dataset_config)
elif lang == 'chinese':
# Chinese with multiple test splits
datasets[lang] = load_chinese_dataset(dataset_config)
elif lang in ['indonesian', 'malay']:
# Simple datasets with standard splits
datasets[lang] = load_simple_dataset(lang, dataset_config)
else:
print(f"Warning: Unknown language {lang}, treating as simple dataset")
datasets[lang] = load_simple_dataset(lang, dataset_config)
print("Setting up processors...")
# Create processors for each enabled language
processors = {}
for lang in enabled_languages:
lang_config = config['languages'][lang]
processors[lang] = WhisperProcessor.from_pretrained(
MODEL_CHECKPOINT,
language=lang_config["whisper_language"]
)
# Prefer the English processor as the main one; otherwise use the first created
if "english" in processors:
    main_processor = processors["english"]
elif processors:
    main_processor = next(iter(processors.values()))
else:
raise ValueError("No processors created. Check your language configuration.")
model = WhisperForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)
# Multi-GPU handling: the Trainer wraps the model (DataParallel/DDP) itself,
# so we only move it to the default device here
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs for training")
model.to(device)
print("Adding language labels to raw datasets...")
# Remove existing language columns and add our own consistent language labels for each enabled language
for lang in enabled_languages:
lang_datasets = datasets[lang]
# Handle different dataset structures
if isinstance(lang_datasets, dict):
# Dataset with explicit splits (train/validation/test)
for split_name, split_dataset in lang_datasets.items():
if split_dataset is not None:
# Remove existing language column if it exists
columns_to_remove = [col for col in split_dataset.column_names if col.lower() in ["language", "lang"]]
if columns_to_remove:
print(f"Removing existing language column(s) {columns_to_remove} from {lang} {split_name}")
datasets[lang][split_name] = split_dataset.remove_columns(columns_to_remove)
# Add our consistent language label
datasets[lang][split_name] = datasets[lang][split_name].add_column("language", [lang] * len(datasets[lang][split_name]))
else:
# Single dataset object - this shouldn't happen with current structure but handle gracefully
print(f"Warning: Unexpected dataset structure for {lang}")
continue
print("Combining raw datasets before preprocessing...")
# Ensure all datasets have compatible schemas before concatenation
def standardize_dataset_schema(dataset, dataset_name):
"""Standardize dataset schema to ensure compatibility for concatenation"""
print(f"Standardizing schema for {dataset_name}...")
# Keep audio compressed until preprocessing - only set sampling rate
if "audio" in dataset.column_names:
print(f" Setting audio feature type to {audio_config['sampling_rate']}Hz (compressed) for {dataset_name}")
dataset = dataset.cast_column("audio", Audio(sampling_rate=audio_config['sampling_rate'], decode=False))
# Remove problematic columns that might have different types
columns_to_remove = []
for col in dataset.column_names:
if col in config['data_processing']['columns_to_remove']:
columns_to_remove.append(col)
if columns_to_remove:
print(f" Removing incompatible columns: {columns_to_remove}")
dataset = dataset.remove_columns(columns_to_remove)
return dataset
# Standardize all training datasets dynamically
print("Standardizing training datasets...")
raw_train_datasets = []
for lang in enabled_languages:
if "train" in datasets[lang]:
std_dataset = standardize_dataset_schema(datasets[lang]["train"], f"{lang}_train")
raw_train_datasets.append(std_dataset)
# Standardize validation datasets dynamically
print("Standardizing validation datasets...")
raw_val_datasets = []
for lang in enabled_languages:
if "validation" in datasets[lang]:
std_dataset = standardize_dataset_schema(datasets[lang]["validation"], f"{lang}_val")
raw_val_datasets.append(std_dataset)
# Combine datasets if we have any
if raw_train_datasets:
print("Combining training datasets...")
combined_raw_train = concatenate_datasets(raw_train_datasets)
combined_raw_train = combined_raw_train.shuffle(seed=config['data_processing']['seed'])
else:
raise ValueError("No training datasets found. Check your configuration.")
if raw_val_datasets:
print("Combining validation datasets...")
combined_raw_val = concatenate_datasets(raw_val_datasets)
combined_raw_val = combined_raw_val.shuffle(seed=config['data_processing']['seed'])
else:
print("Warning: No validation datasets found. Training without validation.")
combined_raw_val = None
print("Creating on-the-fly datasets (no preprocessing stored to disk)...")
# Create on-the-fly datasets instead of preprocessing and storing
combined_train_dataset = WhisperOnTheFlyDataset(
combined_raw_train,
processors,
main_processor,
MAX_TARGET_LENGTH,
audio_config
)
# Only create validation dataset if we have validation data
if combined_raw_val is not None:
combined_val_dataset = WhisperOnTheFlyDataset(
combined_raw_val,
processors,
main_processor,
MAX_TARGET_LENGTH,
audio_config
)
else:
combined_val_dataset = None
print("Creating on-the-fly test datasets...")
# Create on-the-fly test datasets dynamically
processed_datasets = {}
for lang in enabled_languages:
processed_datasets[lang] = {}
# Handle different test split structures for different languages
if lang == "chinese":
# Chinese has multiple test splits
if "test_net" in datasets[lang]:
processed_datasets[lang]["test_net"] = WhisperOnTheFlyDataset(
datasets[lang]["test_net"],
processors,
main_processor,
MAX_TARGET_LENGTH,
audio_config
)
if "test_meeting" in datasets[lang]:
processed_datasets[lang]["test_meeting"] = WhisperOnTheFlyDataset(
datasets[lang]["test_meeting"],
processors,
main_processor,
MAX_TARGET_LENGTH,
audio_config
)
else:
# Standard test split
if "test" in datasets[lang]:
processed_datasets[lang]["test"] = WhisperOnTheFlyDataset(
datasets[lang]["test"],
processors,
main_processor,
MAX_TARGET_LENGTH,
audio_config
)
# Data Collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
processor=main_processor,
decoder_start_token_id=model.config.decoder_start_token_id,
)
# Metrics: WER & CER (using Hugging Face Evaluate)
wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")
def compute_metrics(pred):
"""
Compute WER and CER metrics for predictions
"""
pred_ids = pred.predictions
pred_str = main_processor.batch_decode(pred_ids, skip_special_tokens=True)
label_ids = pred.label_ids
label_ids[label_ids == -100] = main_processor.tokenizer.pad_token_id
ref_str = main_processor.batch_decode(label_ids, skip_special_tokens=True)
# lowercase & strip
pred_str = [s.lower().strip() for s in pred_str]
ref_str = [s.lower().strip() for s in ref_str]
wer_score = wer_metric.compute(predictions=pred_str, references=ref_str)
cer_score = cer_metric.compute(predictions=pred_str, references=ref_str)
return {"wer": wer_score, "cer": cer_score}
# Check for multi-GPU setup
num_gpus = torch.cuda.device_count()
print(f"Number of available GPUs: {num_gpus}")
# Get training configuration
training_config = config['training']
# Adjust batch size and gradient accumulation for multi-GPU
if num_gpus > 1:
# With multiple GPUs, use multi-GPU configuration
gpu_config = training_config['multi_gpu']
per_device_batch_size = gpu_config['per_device_train_batch_size']
per_device_eval_batch_size = gpu_config['per_device_eval_batch_size']
gradient_accumulation_steps = gpu_config['gradient_accumulation_steps']
print(f"Multi-GPU training detected. Using {num_gpus} GPUs.")
else:
# Single GPU configuration
gpu_config = training_config['single_gpu']
per_device_batch_size = gpu_config['per_device_train_batch_size']
per_device_eval_batch_size = gpu_config['per_device_eval_batch_size']
gradient_accumulation_steps = gpu_config['gradient_accumulation_steps']
print("Single GPU training.")
# Training Arguments
training_args = Seq2SeqTrainingArguments(
output_dir=OUTPUT_DIR,
per_device_train_batch_size=per_device_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
learning_rate=training_config['learning_rate'],
warmup_steps=training_config['warmup_steps'],
max_steps=training_config['max_steps'],
gradient_checkpointing=training_config['gradient_checkpointing'],
fp16=training_config['fp16'],
eval_strategy=training_config['eval_strategy'],
per_device_eval_batch_size=per_device_eval_batch_size,
predict_with_generate=training_config['predict_with_generate'],
generation_max_length=training_config['generation_max_length'],
save_steps=training_config['save_steps'],
eval_steps=training_config['eval_steps'],
logging_steps=training_config['logging_steps'],
report_to=training_config['report_to'],
load_best_model_at_end=training_config['load_best_model_at_end'],
metric_for_best_model=training_config['metric_for_best_model'],
greater_is_better=training_config['greater_is_better'],
push_to_hub=training_config['push_to_hub'],
save_total_limit=training_config['save_total_limit'],
# Multi-GPU specific settings
dataloader_drop_last=training_config['dataloader_drop_last'],
ddp_find_unused_parameters=training_config['ddp_find_unused_parameters'],
)
# Initialize Seq2SeqTrainer
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=combined_train_dataset,
eval_dataset=combined_val_dataset,
data_collator=data_collator,
tokenizer=main_processor.feature_extractor,
compute_metrics=compute_metrics,
)
def evaluate_on_test_sets():
"""Evaluate the model on all test sets from enabled languages"""
print("\n" + "="*60)
print("EVALUATING ON ALL TEST SETS")
print("="*60)
results = {}
for lang in enabled_languages:
if lang in processed_datasets:
lang_results = {}
if lang == "chinese":
# Chinese has multiple test splits
if "test_net" in processed_datasets[lang]:
print(f"\n***** Evaluating on WenetSpeech Chinese TEST_NET *****")
chi_testnet_metrics = trainer.predict(processed_datasets[lang]["test_net"], metric_key_prefix=f"test_{lang}_net")
print(f"Chinese TEST_NET WER: {chi_testnet_metrics.metrics[f'test_{lang}_net_wer']*100:.2f}%")
print(f"Chinese TEST_NET CER: {chi_testnet_metrics.metrics[f'test_{lang}_net_cer']*100:.2f}%")
lang_results["test_net"] = chi_testnet_metrics.metrics
if "test_meeting" in processed_datasets[lang]:
print(f"\n***** Evaluating on WenetSpeech Chinese TEST_MEETING *****")
chi_testmeet_metrics = trainer.predict(processed_datasets[lang]["test_meeting"], metric_key_prefix=f"test_{lang}_meeting")
print(f"Chinese TEST_MEETING WER: {chi_testmeet_metrics.metrics[f'test_{lang}_meeting_wer']*100:.2f}%")
print(f"Chinese TEST_MEETING CER: {chi_testmeet_metrics.metrics[f'test_{lang}_meeting_cer']*100:.2f}%")
lang_results["test_meeting"] = chi_testmeet_metrics.metrics
else:
# Standard test split
if "test" in processed_datasets[lang]:
print(f"\n***** Evaluating on {lang.title()} test set *****")
test_metrics = trainer.predict(processed_datasets[lang]["test"], metric_key_prefix=f"test_{lang}")
print(f"{lang.title()} Test WER: {test_metrics.metrics[f'test_{lang}_wer']*100:.2f}%")
print(f"{lang.title()} Test CER: {test_metrics.metrics[f'test_{lang}_cer']*100:.2f}%")
lang_results["test"] = test_metrics.metrics
results[lang] = lang_results
# Summary
print("\n" + "="*60)
print("SUMMARY OF ALL TEST RESULTS")
print("="*60)
for lang in enabled_languages:
if lang in results:
if lang == "chinese":
if "test_net" in results[lang]:
wer = results[lang]["test_net"][f"test_{lang}_net_wer"] * 100
cer = results[lang]["test_net"][f"test_{lang}_net_cer"] * 100
print(f"Chinese-NET: WER={wer:.2f}% | CER={cer:.2f}%")
if "test_meeting" in results[lang]:
wer = results[lang]["test_meeting"][f"test_{lang}_meeting_wer"] * 100
cer = results[lang]["test_meeting"][f"test_{lang}_meeting_cer"] * 100
print(f"Chinese-MTG: WER={wer:.2f}% | CER={cer:.2f}%")
else:
if "test" in results[lang]:
wer = results[lang]["test"][f"test_{lang}_wer"] * 100
cer = results[lang]["test"][f"test_{lang}_cer"] * 100
print(f"{lang.title():12}: WER={wer:.2f}% | CER={cer:.2f}%")
return results
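# evaluate_on_test_sets() returns a nested per-language mapping; illustrative
# shape (metric values made up, key names follow the metric_key_prefix used above):
#   {"chinese": {"test_net":     {"test_chinese_net_wer": 0.12, ...},
#                "test_meeting": {"test_chinese_meeting_wer": 0.15, ...}},
#    "khmer":   {"test":         {"test_khmer_wer": 0.31, ...}}}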
if __name__ == "__main__":
print(f"Total training samples: {len(combined_train_dataset)}")
print(f"Total validation samples: {len(combined_val_dataset)}")
print("Starting training...")
# Fine-tune the model
trainer.train()
# Evaluate on all test sets
evaluate_on_test_sets()