|
|
|
|
|
|
|
""" |
|
Fine-tune openai/whisper-large-v3 on mixed datasets from different languages: |
|
- FLEURS Cebuano (ceb_ph) |
|
- FLEURS Khmer (km_kh) |
|
- Switchboard1 English |
|
- WenetSpeech Chinese |
|
- Eng-Indon-CS |
|
- Eng-Malay-CS |
|
Based on the Hugging Face blog: https://huggingface.co/blog/fine-tune-whisper |
|
|
|
To run this script on multiple GPUs, you have several options: |
|
|
|
1. **Automatic Multi-GPU (DataParallel-style):** |
|
python finetune_whisper_mix_datasets.py |
|
|
|
The script will automatically detect and use all available GPUs. |
|
|
|
2. **Distributed Training with torchrun (Recommended for 2+ GPUs):** |
|
torchrun --nproc_per_node=2 finetune_whisper_mix_datasets.py |
|
|
|
This uses DistributedDataParallel which is more efficient. |
|
|
|
3. **Distributed Training with accelerate (Alternative):** |
|
accelerate launch --num_processes=2 finetune_whisper_mix_datasets.py |
|
|
|
Requires: pip install accelerate |
|
|
|
Note: With 2 GPUs, the effective batch size becomes: |
|
per_device_batch_size * num_gpus * gradient_accumulation_steps |
|
= 24 * 2 * 1 = 48 (compared to 32 with single GPU) |
|
|
|
CPU Core Limiting: |
|
The script automatically limits CPU usage to 20 cores using environment variables. |
|
You can also set these manually before running: |
|
export OMP_NUM_THREADS=20 |
|
export MKL_NUM_THREADS=20 |
|
export NUMEXPR_NUM_THREADS=20 |
|
python finetune_whisper_mix_datasets.py |
|
""" |
|
|
|
import os |
|
import random |
|
import io |
|
import yaml |
|
import argparse |
|
from itertools import chain |
|
|
|
|
|
def load_config(config_path):
    """Load and return the YAML configuration as a dict.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed configuration (nested dicts/lists from ``yaml.safe_load``).
    """
    # Explicit encoding so the config parses identically across platforms.
    with open(config_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)
|
|
|
|
|
# CLI: the only argument is the path to the YAML configuration file.
parser = argparse.ArgumentParser(description='Fine-tune Whisper on mixed datasets')

parser.add_argument('--config', type=str, default='config.yaml',

                    help='Path to configuration YAML file')

args = parser.parse_args()

# All tunables (datasets, hyper-parameters, environment limits) come from
# this single config object; the rest of the script only reads from it.
config = load_config(args.config)
|
|
|
|
|
# Thread-count caps must be exported BEFORE torch/numpy are imported below,
# otherwise the native libraries size their thread pools from the machine.
env_config = config['environment']

# os.environ only accepts string values; YAML parses bare numbers as ints,
# which would raise TypeError on assignment — coerce explicitly with str().
for _env_var, _cfg_key in [
    ("OMP_NUM_THREADS", 'omp_num_threads'),
    ("MKL_NUM_THREADS", 'mkl_num_threads'),
    ("OPENBLAS_NUM_THREADS", 'openblas_num_threads'),
    ("VECLIB_MAXIMUM_THREADS", 'veclib_maximum_threads'),
    ("NUMEXPR_NUM_THREADS", 'numexpr_num_threads'),
    ("TOKENIZERS_PARALLELISM", 'tokenizers_parallelism'),
    ("TRANSFORMERS_NO_TF", 'transformers_no_tf'),
]:
    os.environ[_env_var] = str(env_config[_cfg_key])
|
|
|
import torch |
|
from datasets import load_dataset, Audio, concatenate_datasets, Dataset |
|
from torch.utils.data import Dataset as TorchDataset |
|
from transformers import ( |
|
WhisperProcessor, |
|
WhisperForConditionalGeneration, |
|
Seq2SeqTrainingArguments, |
|
Seq2SeqTrainer, |
|
) |
|
import ipdb |
|
import evaluate |
|
import numpy as np |
|
import ipdb |
|
|
|
|
|
# When several GPUs are visible but the script was launched without a
# distributed launcher, provide minimal rank/world-size defaults so the
# Trainer can still initialize distributed state.
if torch.cuda.device_count() > 1:

    print(f"Setting up for {torch.cuda.device_count()} GPUs")

    # torchrun/accelerate normally export these; default them for plain runs.
    if "LOCAL_RANK" not in os.environ:

        os.environ["LOCAL_RANK"] = "0"

    if "WORLD_SIZE" not in os.environ:

        os.environ["WORLD_SIZE"] = str(torch.cuda.device_count())
|
|
|
|
|
from dataclasses import dataclass |
|
from typing import Any, Dict, List, Union |
|
|
|
class WhisperOnTheFlyDataset(TorchDataset):
    """Dataset that extracts Whisper features and tokenizes labels lazily.

    Audio is decoded and featurized on item access instead of being
    preprocessed up front, keeping memory and disk usage low for the large
    mixed-language corpora this script trains on.
    """

    # FLEURS languages store the reference text under "transcription" (all
    # other sources use "text") and are the only ones whose labels were
    # truncated to max_target_length in the original implementation.
    _FLEURS_LANGS = ("cebuano", "khmer")

    def __init__(self, dataset, processors, main_processor, max_target_length, audio_config):
        """
        Args:
            dataset: Indexable collection of dicts with "audio", "language",
                and either "transcription" or "text" keys.
            processors: Per-language WhisperProcessor mapping; the "malay"
                entry doubles as fallback for unknown languages (mirrors the
                original trailing else-branch).
            main_processor: Processor whose feature extractor is shared by
                all languages (log-mel features are language independent).
            max_target_length: Max label length for the FLEURS languages.
            audio_config: Dict with the target 'sampling_rate'.
        """
        self.dataset = dataset
        self.processors = processors
        self.main_processor = main_processor
        self.max_target_length = max_target_length
        self.sampling_rate = audio_config['sampling_rate']

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        audio_data = self._process_audio(item["audio"])

        # One shared feature extractor: mel features don't depend on language.
        inputs = self.main_processor.feature_extractor(
            audio_data,
            sampling_rate=self.sampling_rate,
            return_tensors="pt",
        )

        lang = item["language"]
        text = item["transcription"] if lang in self._FLEURS_LANGS else item["text"]

        # Route to the language-specific tokenizer; unknown languages fall
        # back to the Malay processor (same behavior as the old elif chain).
        tokenizer = self.processors.get(lang, self.processors["malay"]).tokenizer
        tok_kwargs = {"return_tensors": "pt", "padding": False}
        if lang in self._FLEURS_LANGS:
            # Only FLEURS labels were truncated in the original code.
            tok_kwargs["truncation"] = True
            tok_kwargs["max_length"] = self.max_target_length
        labels = tokenizer(text, **tok_kwargs)

        return {
            "input_features": inputs.input_features.squeeze(0),
            "labels": labels.input_ids.squeeze(0),
            "language": lang,
        }

    def _process_audio(self, audio_sample):
        """Return a waveform array for *audio_sample*.

        Accepts a decoded dict ("array" / "bytes" / "path" keys), a
        filesystem path string, or an already-decoded array (returned
        unchanged). librosa is imported lazily so items that arrive
        pre-decoded never pay for (or require) the dependency.
        """
        if isinstance(audio_sample, dict):
            if "array" in audio_sample:
                return audio_sample["array"]
            if "bytes" in audio_sample and audio_sample["bytes"] is not None:
                import librosa  # lazy: only needed when decoding is required
                audio_array, _ = librosa.load(io.BytesIO(audio_sample["bytes"]), sr=self.sampling_rate)
                return audio_array
            if "path" in audio_sample:
                import librosa
                audio_array, _ = librosa.load(audio_sample["path"], sr=self.sampling_rate)
                return audio_array
            raise ValueError(f"Unknown audio dict format: {audio_sample.keys()}")
        if isinstance(audio_sample, str):
            import librosa
            audio_array, _ = librosa.load(audio_sample, sr=self.sampling_rate)
            return audio_array
        return audio_sample
|
|
|
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Pads audio features and token labels into one training batch.

    Padded label positions are replaced with -100 so the loss ignores them,
    and a uniform leading decoder-start token is stripped (the model
    re-prepends it during training).
    """

    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Pad the log-mel features to a uniform shape.
        batch = self.processor.feature_extractor.pad(
            [{"input_features": f["input_features"]} for f in features],
            return_tensors="pt",
        )

        # Pad the label token ids separately via the tokenizer.
        padded_labels = self.processor.tokenizer.pad(
            [{"input_ids": f["labels"]} for f in features],
            return_tensors="pt",
        )

        # Mask padding positions so they don't contribute to the loss.
        labels = padded_labels["input_ids"].masked_fill(
            padded_labels.attention_mask.ne(1), -100
        )

        # Strip the decoder-start token only when every row begins with it.
        starts_with_bos = (labels[:, 0] == self.decoder_start_token_id).all().cpu().item()
        if starts_with_bos:
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
|
|
|
|
|
# Prefer GPU when available; the model is moved here further below.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Core model/output settings from the config file.
MODEL_CHECKPOINT = config['model']['checkpoint']

OUTPUT_DIR = config['output']['output_dir']

MAX_TARGET_LENGTH = config['model']['max_target_length']

# CPU budgets from config (thread env vars were already exported above).
MAX_CPU_CORES = config['environment']['max_cpu_cores']

TEST_CPU_CORES = config['environment']['test_cpu_cores']

# Per-language settings (whisper language names, subset ratios, ...).
DATASET_CONFIGS = config['languages']

print("Loading datasets...")

# datasets: raw per-language split dicts, filled by the loaders below.
datasets = {}

dataset_configs = config['datasets']

audio_config = config['audio']

# A language is active only when it has BOTH a 'languages' entry and a
# 'datasets' entry in the config.
enabled_languages = set(config['languages'].keys()) & set(config['datasets'].keys())

print(f"Enabled languages: {list(enabled_languages)}")
|
|
|
def load_fleurs_dataset(lang_name, lang_config, dataset_config):
    """Load a FLEURS split dict for one language.

    Audio columns are kept compressed (decode=False) and tagged with the
    configured sampling rate; decoding happens lazily in the on-the-fly
    dataset. When ``train_subset_ratio`` is configured, the train split is
    down-sampled deterministically with the global seed.
    """
    print(f"Loading FLEURS {lang_name.title()}...")
    lang_datasets = load_dataset(
        dataset_config['source'],
        dataset_config['language_code'],
        # dict(...) copy instead of a redundant dict comprehension (C416).
        split=dict(dataset_config['splits']),
        trust_remote_code=dataset_config['trust_remote_code']
    )

    # Keep audio compressed; WhisperOnTheFlyDataset decodes on access.
    for split in dataset_config['splits']:
        lang_datasets[split] = lang_datasets[split].cast_column(
            "audio", Audio(sampling_rate=audio_config['sampling_rate'], decode=False)
        )

    # Optionally keep only a deterministic fraction of the training data.
    if 'train_subset_ratio' in lang_config:
        train_subset_ratio = lang_config['train_subset_ratio']
        lang_datasets["train"] = lang_datasets["train"].train_test_split(
            test_size=1 - train_subset_ratio,
            seed=config['data_processing']['seed'],
        )["train"]

    return lang_datasets
|
|
|
def load_simple_dataset(lang_name, dataset_config):
    """Load a dataset that ships plain train/validation/test splits.

    Args:
        lang_name: Language key, used only for the progress message.
        dataset_config: Config dict with 'source' and a 'splits' mapping.
    """
    print(f"Loading {lang_name.title()}...")
    # dict(...) copy instead of a redundant dict comprehension (ruff C416).
    return load_dataset(dataset_config['source'], split=dict(dataset_config['splits']))
|
|
|
def load_english_dataset(lang_config, dataset_config):
    """Load English Switchboard data, carving validation off the train stream."""
    print("Loading English...")
    streaming = dataset_config['streaming']
    full_train = load_dataset(dataset_config['train_dataset'], split=dataset_config['train_split'], streaming=streaming)
    test_set = load_dataset(dataset_config['test_dataset'], split=dataset_config['test_split'], streaming=streaming)

    # The first `validation_size` examples become validation; the remainder
    # stay as training data (take/skip return new views, nothing mutates).
    n_val = lang_config['validation_size']
    return {
        "train": full_train.skip(n_val),
        "validation": full_train.take(n_val),
        "test": test_set,
    }
|
|
|
def load_chinese_dataset(dataset_config):
    """Load WenetSpeech Chinese: train, validation, and two test splits."""
    print("Loading Chinese...")
    stream = dataset_config['streaming']

    train_ds = load_dataset(dataset_config['train_dataset'], streaming=stream)
    valid_ds = load_dataset(dataset_config['validation_dataset'],
                            dataset_config['validation_config'],
                            split="validation", streaming=stream)
    test_net_ds = load_dataset(dataset_config['test_net_dataset'],
                               dataset_config['test_net_config'],
                               split="test", streaming=stream)
    test_meeting_ds = load_dataset(dataset_config['test_meeting_dataset'],
                                   dataset_config['test_meeting_config'],
                                   split="test", streaming=stream)

    # WenetSpeech keeps two domain-specific test splits rather than one.
    return {
        "train": train_ds["train"],
        "validation": valid_ds,
        "test_net": test_net_ds,
        "test_meeting": test_meeting_ds,
    }
|
|
|
|
|
# Load every enabled language's raw dataset with the loader that matches
# its source layout.
for lang in enabled_languages:

    lang_config = config['languages'][lang]

    dataset_config = dataset_configs[lang]

    if lang in ['cebuano', 'khmer']:
        # FLEURS-style sources (language-code configs, optional subsetting).
        datasets[lang] = load_fleurs_dataset(lang, lang_config, dataset_config)

    elif lang == 'english':
        # Switchboard streams; validation is carved out of the train stream.
        datasets[lang] = load_english_dataset(lang_config, dataset_config)

    elif lang == 'chinese':
        # WenetSpeech: ships two separate test splits (net / meeting).
        datasets[lang] = load_chinese_dataset(dataset_config)

    elif lang in ['indonesian', 'malay']:

        datasets[lang] = load_simple_dataset(lang, dataset_config)

    else:

        print(f"Warning: Unknown language {lang}, treating as simple dataset")

        datasets[lang] = load_simple_dataset(lang, dataset_config)
|
|
|
print("Setting up processors...")

# One WhisperProcessor per language so each tokenizer prepends the correct
# language token to its decoder prompt.
processors = {}
for lang in enabled_languages:
    lang_config = config['languages'][lang]
    processors[lang] = WhisperProcessor.from_pretrained(
        MODEL_CHECKPOINT,
        language=lang_config["whisper_language"]
    )

# The "main" processor supplies the shared feature extractor and the padding
# logic used by the collator; prefer English, otherwise take any available
# one (next(iter(...)) avoids materializing the full key list).
if "english" in processors:
    main_processor = processors["english"]
elif processors:
    main_processor = next(iter(processors.values()))
else:
    raise ValueError("No processors created. Check your language configuration.")

model = WhisperForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)
|
|
|
|
|
# Move the model to the target device. The original if/else had identical
# bodies (both branches called model.to(device)), so the branch is collapsed;
# only the informational print remains conditional. With torchrun/DDP the
# Trainer handles per-process device placement on top of this.
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs for training")
model.to(device)
|
|
|
|
|
|
|
print("Adding language labels to raw datasets...")

# Tag every example with its language so the on-the-fly dataset can route it
# to the right tokenizer after the splits are concatenated.
for lang in enabled_languages:

    lang_datasets = datasets[lang]

    if isinstance(lang_datasets, dict):

        for split_name, split_dataset in lang_datasets.items():

            if split_dataset is not None:

                # Drop any pre-existing language/lang column so add_column
                # below cannot fail on a duplicate column name.
                columns_to_remove = [col for col in split_dataset.column_names if col.lower() in ["language", "lang"]]

                if columns_to_remove:

                    print(f"Removing existing language column(s) {columns_to_remove} from {lang} {split_name}")

                    datasets[lang][split_name] = split_dataset.remove_columns(columns_to_remove)

                # NOTE(review): len() and add_column assume map-style
                # datasets — the streaming (Iterable) English/Chinese splits
                # would not support this; confirm those arrive non-streaming.
                datasets[lang][split_name] = datasets[lang][split_name].add_column("language", [lang] * len(datasets[lang][split_name]))

    else:

        print(f"Warning: Unexpected dataset structure for {lang}")

        continue
|
|
|
|
|
print("Combining raw datasets before preprocessing...") |
|
|
|
|
|
def standardize_dataset_schema(dataset, dataset_name):
    """Standardize a dataset's schema so splits can be concatenated.

    Unifies the audio feature type (compressed, common sampling rate) and
    removes source-specific columns listed in the config, since
    concatenate_datasets requires identical features across inputs.
    """
    print(f"Standardizing schema for {dataset_name}...")

    # Unify the audio feature so concatenation doesn't fail on mismatched
    # Audio() feature definitions; decode=False keeps audio compressed.
    if "audio" in dataset.column_names:
        print(f"  Setting audio feature type to {audio_config['sampling_rate']}Hz (compressed) for {dataset_name}")
        dataset = dataset.cast_column("audio", Audio(sampling_rate=audio_config['sampling_rate'], decode=False))

    # Comprehension instead of a manual append loop (ruff PERF401).
    columns_to_remove = [
        col for col in dataset.column_names
        if col in config['data_processing']['columns_to_remove']
    ]
    if columns_to_remove:
        print(f"  Removing incompatible columns: {columns_to_remove}")
        dataset = dataset.remove_columns(columns_to_remove)

    return dataset
|
|
|
|
|
# Collect the standardized train/validation splits per enabled language.
# Comprehensions replace the original filtered append loops (ruff PERF401).
print("Standardizing training datasets...")
raw_train_datasets = [
    standardize_dataset_schema(datasets[lang]["train"], f"{lang}_train")
    for lang in enabled_languages
    if "train" in datasets[lang]
]

print("Standardizing validation datasets...")
raw_val_datasets = [
    standardize_dataset_schema(datasets[lang]["validation"], f"{lang}_val")
    for lang in enabled_languages
    if "validation" in datasets[lang]
]
|
|
|
|
|
# Merge all per-language training data into one shuffled pool.
if raw_train_datasets:

    print("Combining training datasets...")

    combined_raw_train = concatenate_datasets(raw_train_datasets)

    # Shuffle so batches mix languages instead of arriving grouped by source.
    combined_raw_train = combined_raw_train.shuffle(seed=config['data_processing']['seed'])

else:

    raise ValueError("No training datasets found. Check your configuration.")

# Validation is optional: training proceeds without it when no split exists.
if raw_val_datasets:

    print("Combining validation datasets...")

    combined_raw_val = concatenate_datasets(raw_val_datasets)

    combined_raw_val = combined_raw_val.shuffle(seed=config['data_processing']['seed'])

else:

    print("Warning: No validation datasets found. Training without validation.")

    combined_raw_val = None
|
|
|
print("Creating on-the-fly datasets (no preprocessing stored to disk)...")

# Wrap the combined raw splits: feature extraction and tokenization happen
# lazily in __getitem__ during training instead of in a map() pass.
combined_train_dataset = WhisperOnTheFlyDataset(

    combined_raw_train,

    processors,

    main_processor,

    MAX_TARGET_LENGTH,

    audio_config

)

# Only wrap validation when a validation split was actually assembled.
if combined_raw_val is not None:

    combined_val_dataset = WhisperOnTheFlyDataset(

        combined_raw_val,

        processors,

        main_processor,

        MAX_TARGET_LENGTH,

        audio_config

    )

else:

    combined_val_dataset = None
|
|
|
print("Creating on-the-fly test datasets...")

# processed_datasets[lang][split] -> WhisperOnTheFlyDataset for that split.
processed_datasets = {}

for lang in enabled_languages:
    processed_datasets[lang] = {}

    # WenetSpeech Chinese ships two distinct test splits; every other
    # language has a single "test" split. One loop replaces the four
    # near-identical construction blocks of the original.
    split_names = ["test_net", "test_meeting"] if lang == "chinese" else ["test"]
    for split_name in split_names:
        if split_name in datasets[lang]:
            processed_datasets[lang][split_name] = WhisperOnTheFlyDataset(
                datasets[lang][split_name],
                processors,
                main_processor,
                MAX_TARGET_LENGTH,
                audio_config
            )
|
|
|
|
|
# Batch collator: pads features/labels and masks label padding for the loss.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(

    processor=main_processor,

    decoder_start_token_id=model.config.decoder_start_token_id,

)

# Corpus-level metrics consumed by compute_metrics during evaluation.
wer_metric = evaluate.load("wer")

cer_metric = evaluate.load("cer")
|
|
|
def compute_metrics(pred):
    """Compute corpus-level WER and CER for a batch of generated predictions."""
    # Decode generated token ids to text.
    decoded_preds = main_processor.batch_decode(pred.predictions, skip_special_tokens=True)

    # Restore pad tokens where the collator wrote -100 so decoding works.
    label_ids = pred.label_ids
    label_ids[label_ids == -100] = main_processor.tokenizer.pad_token_id
    decoded_refs = main_processor.batch_decode(label_ids, skip_special_tokens=True)

    # Score case- and surrounding-whitespace-insensitively.
    decoded_preds = [text.lower().strip() for text in decoded_preds]
    decoded_refs = [text.lower().strip() for text in decoded_refs]

    return {
        "wer": wer_metric.compute(predictions=decoded_preds, references=decoded_refs),
        "cer": cer_metric.compute(predictions=decoded_preds, references=decoded_refs),
    }
|
|
|
|
|
num_gpus = torch.cuda.device_count()

print(f"Number of available GPUs: {num_gpus}")

training_config = config['training']

# Batch-size / accumulation settings differ by GPU count; both variants
# live under config['training'] ('multi_gpu' vs 'single_gpu').
if num_gpus > 1:

    gpu_config = training_config['multi_gpu']

    per_device_batch_size = gpu_config['per_device_train_batch_size']

    per_device_eval_batch_size = gpu_config['per_device_eval_batch_size']

    gradient_accumulation_steps = gpu_config['gradient_accumulation_steps']

    print(f"Multi-GPU training detected. Using {num_gpus} GPUs.")

else:

    gpu_config = training_config['single_gpu']

    per_device_batch_size = gpu_config['per_device_train_batch_size']

    per_device_eval_batch_size = gpu_config['per_device_eval_batch_size']

    gradient_accumulation_steps = gpu_config['gradient_accumulation_steps']

    print("Single GPU training.")
|
|
|
|
|
# All remaining hyper-parameters come straight from config['training'].
training_args = Seq2SeqTrainingArguments(

    output_dir=OUTPUT_DIR,

    per_device_train_batch_size=per_device_batch_size,

    gradient_accumulation_steps=gradient_accumulation_steps,

    learning_rate=training_config['learning_rate'],

    warmup_steps=training_config['warmup_steps'],

    max_steps=training_config['max_steps'],

    gradient_checkpointing=training_config['gradient_checkpointing'],

    fp16=training_config['fp16'],

    eval_strategy=training_config['eval_strategy'],

    per_device_eval_batch_size=per_device_eval_batch_size,

    # Generate full sequences during eval so WER/CER can be computed.
    predict_with_generate=training_config['predict_with_generate'],

    generation_max_length=training_config['generation_max_length'],

    save_steps=training_config['save_steps'],

    eval_steps=training_config['eval_steps'],

    logging_steps=training_config['logging_steps'],

    report_to=training_config['report_to'],

    load_best_model_at_end=training_config['load_best_model_at_end'],

    metric_for_best_model=training_config['metric_for_best_model'],

    greater_is_better=training_config['greater_is_better'],

    push_to_hub=training_config['push_to_hub'],

    save_total_limit=training_config['save_total_limit'],

    # Keep batch shapes uniform across ranks (relevant for DDP).
    dataloader_drop_last=training_config['dataloader_drop_last'],

    ddp_find_unused_parameters=training_config['ddp_find_unused_parameters'],

)
|
|
|
|
|
# Seq2SeqTrainer drives training and generation-based evaluation. The
# feature extractor is passed as `tokenizer` so its preprocessing config is
# saved with each checkpoint.
trainer = Seq2SeqTrainer(

    model=model,

    args=training_args,

    train_dataset=combined_train_dataset,

    eval_dataset=combined_val_dataset,

    data_collator=data_collator,

    tokenizer=main_processor.feature_extractor,

    compute_metrics=compute_metrics,

)
|
|
|
def _predict_and_report(dataset, metric_prefix, banner, label):
    """Run trainer.predict on one split, print its WER/CER, return metrics."""
    print(f"\n***** Evaluating on {banner} *****")
    metrics = trainer.predict(dataset, metric_key_prefix=metric_prefix).metrics
    print(f"{label} WER: {metrics[f'{metric_prefix}_wer']*100:.2f}%")
    print(f"{label} CER: {metrics[f'{metric_prefix}_cer']*100:.2f}%")
    return metrics


def _print_summary_line(display_name, metrics, metric_prefix):
    """Print one WER/CER summary row for a split."""
    wer = metrics[f"{metric_prefix}_wer"] * 100
    cer = metrics[f"{metric_prefix}_cer"] * 100
    print(f"{display_name}: WER={wer:.2f}% | CER={cer:.2f}%")


# Chinese has two named test splits; each tuple carries the split key, the
# metric-prefix template, the evaluation banner, the per-metric label, and
# the summary display name (all strings match the original output exactly).
_CHINESE_SPLITS = [
    ("test_net", "test_{lang}_net", "WenetSpeech Chinese TEST_NET", "Chinese TEST_NET", "Chinese-NET"),
    ("test_meeting", "test_{lang}_meeting", "WenetSpeech Chinese TEST_MEETING", "Chinese TEST_MEETING", "Chinese-MTG"),
]


def evaluate_on_test_sets():
    """Evaluate the model on all test sets from enabled languages.

    Returns a dict mapping language -> split name -> trainer metrics, and
    prints a per-split report followed by a compact summary table.
    """
    print("\n" + "="*60)
    print("EVALUATING ON ALL TEST SETS")
    print("="*60)

    results = {}

    for lang in enabled_languages:
        if lang not in processed_datasets:
            continue
        lang_results = {}

        if lang == "chinese":
            for split, prefix_tpl, banner, label, _ in _CHINESE_SPLITS:
                if split in processed_datasets[lang]:
                    prefix = prefix_tpl.format(lang=lang)
                    lang_results[split] = _predict_and_report(
                        processed_datasets[lang][split], prefix, banner, label)
        else:
            if "test" in processed_datasets[lang]:
                lang_results["test"] = _predict_and_report(
                    processed_datasets[lang]["test"], f"test_{lang}",
                    f"{lang.title()} test set", f"{lang.title()} Test")

        results[lang] = lang_results

    print("\n" + "="*60)
    print("SUMMARY OF ALL TEST RESULTS")
    print("="*60)

    for lang in enabled_languages:
        if lang not in results:
            continue
        if lang == "chinese":
            for split, prefix_tpl, _, _, display in _CHINESE_SPLITS:
                if split in results[lang]:
                    _print_summary_line(display, results[lang][split], prefix_tpl.format(lang=lang))
        else:
            if "test" in results[lang]:
                _print_summary_line(f"{lang.title():12}", results[lang]["test"], f"test_{lang}")

    return results
|
|
|
if __name__ == "__main__":
    print(f"Total training samples: {len(combined_train_dataset)}")
    # combined_val_dataset may be None when no validation split was
    # configured; the original unconditionally called len() and crashed here.
    if combined_val_dataset is not None:
        print(f"Total validation samples: {len(combined_val_dataset)}")
    else:
        print("Total validation samples: 0 (no validation set)")
    print("Starting training...")

    trainer.train()

    # Final evaluation on every configured test split.
    evaluate_on_test_sets()
|
|
|
|
|
|
|
|
|
|