import argparse import json import os from pathlib import Path import numpy as np import torch import sacrebleu from datasets import load_dataset from torch.utils.data import Dataset, ConcatDataset from tqdm import tqdm from transformers import ( AutoProcessor, AutoModel, BatchFeature, Trainer, TrainingArguments, StoppingCriteria, StoppingCriteriaList, ) from collections import defaultdict import soundfile as sf from datasets import Audio import random class BaseAudioDataset(Dataset): def __init__(self, processor, split, sampling_rate=16000, debug=False): self.processor = processor self.training = "train" in split self.debug = debug self.sampling_rate = sampling_rate self.name = "" def set_dataset_name(self, name): self.name = name @staticmethod def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True): original_size = len(data) data = data.cast_column(audio_field, Audio(decode=False)) def identify_corrupted_files(example): try: sf.read(example[audio_field]["path"]) for field in text_fields: if field in example and example[field].replace('"', '') == "": return False return True except Exception: return False data = data.filter(identify_corrupted_files, num_proc=16) validated_size = len(data) # Audio Decoding data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True)) if debug: print(f"Dataset: {dataset_name}") print(f"Original data nums: {original_size}") print(f"After filtering data nums: {validated_size}") print(f"Filtering ratio: {validated_size/original_size:.2%}") return data @staticmethod def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True): original_size = len(data) def filter_audio_by_length(example): try: audio = example[audio_field]['array'] channel = 1 if hasattr(audio, 'ndim') and audio.ndim > 1: channel = audio.ndim audio = audio.squeeze() audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel return min_sec <= audio_length <= max_sec except Exception as e: if debug: print(f"Error : {str(e)[:100]}... - sample excluded") return False data = data.filter(filter_audio_by_length, num_proc=16) filtered_size = len(data) if debug: print(f"Before Length Filtering data nums: {original_size}") print(f"After Length Filtering data nums: {filtered_size}") print(f"Filtering ratio: {filtered_size/original_size:.2%}") return data def prepare_model_inputs(self, audio_array, instruction, answer_text): user_message = { 'role': 'user', 'content': '' + instruction, } prompt = self.processor.tokenizer.apply_chat_template( [user_message], tokenize=False, add_generation_prompt=True, add_bos=True ) inputs = self.processor( text=prompt, audio=[audio_array], add_special_tokens=False, return_tensors='pt' ) answer = f"{answer_text}{ANSWER_SUFFIX}" answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids if self.debug: self.debug = False task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR' lang_info = f" - {self.lang}" if hasattr(self, 'lang') else "" print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n") print(f"INPUT_MODE: {inputs.input_modes[0].item()}") if self.training: input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1) labels = torch.full_like(input_ids, _IGNORE_INDEX) labels[:, -answer_ids.shape[1]:] = answer_ids padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1])) token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1) else: input_ids = inputs.input_ids labels = answer_ids token_type_ids = inputs.token_type_ids return { 'input_ids': input_ids, 'labels': labels, 'token_type_ids': token_type_ids, 'input_audio_embeds': inputs.input_audio_embeds, 'audio_embed_sizes': inputs.audio_embed_sizes, 'input_modes': inputs.input_modes, } # CoVoST2 Dataset Class class CoVoSTDataset(BaseAudioDataset): def __init__(self, processor, data_dir, split, ast=False, lang=("en_ko", "Korean"), sampling_rate=16000, debug=False): super().__init__(processor, split, sampling_rate, debug) self.set_dataset_name("CoVoST") self.ast = ast self.lang = lang[0] self.data = load_dataset("junnei/covost2", lang[0], data_dir=data_dir, split=split, trust_remote_code=True ) text_fields = ["sentence", "translation"] if ast else ["sentence"] self.data = self.filter_corrupted_files(self.data, "audio", text_fields, "CoVoST") # (Optional) Audio length Filtering self.data = self.filter_by_audio_length(self.data, "audio") # Instruction Setting self.instruction = random.choice(INSTRUCTION["ast"]).format(lang[1]) if ast else random.choice(INSTRUCTION["asr"]) def __len__(self): return len(self.data) def __getitem__(self, idx): data = self.data[idx] if self.ast: answer_text = data["translation"] else: answer_text = data["sentence"].replace('"', '') return self.prepare_model_inputs( data["audio"]["array"], self.instruction, answer_text ) # Zeroth Korean Dataset Class class ZerothKoreanDataset(BaseAudioDataset): def __init__(self, processor, split, sampling_rate=16000, debug=False): super().__init__(processor, split, sampling_rate, debug) self.set_dataset_name("Zeroth") # only ASR self.ast = False self.lang = "ko" # load dataset self.data = load_dataset("Bingsu/zeroth-korean", split=split, trust_remote_code=True ) # (Optional) Audio length Filtering self.data = self.filter_by_audio_length(self.data, "audio") # Instruction Setting self.instruction = random.choice(INSTRUCTION["asr"]) def __len__(self): return len(self.data) def __getitem__(self, idx): data = self.data[idx] # Zeroth Korean is only for ASR answer_text = data["text"].replace('"', '') return self.prepare_model_inputs( data["audio"]["array"], self.instruction, answer_text ) # Libri Speech Dataset Class class LibriSpeechDataset(BaseAudioDataset): def __init__(self, processor, subset, split, sampling_rate=16000, debug=False): super().__init__(processor, split, sampling_rate, debug) self.set_dataset_name(f"LibriSpeech_{subset}") # only ASR self.ast = False self.lang = "en" # load dataset self.data = load_dataset("fixie-ai/librispeech_asr", subset, split=split, trust_remote_code=True ) # (Optional) Audio length Filtering self.data = self.filter_by_audio_length(self.data, "audio") # Instruction Setting self.instruction = random.choice(INSTRUCTION["asr"]) def __len__(self): return len(self.data) def __getitem__(self, idx): data = self.data[idx] # Libri Speech is only for ASR answer_text = data["text"].replace('"', '') return self.prepare_model_inputs( data["audio"]["array"], self.instruction, answer_text ) # Fleurs Dataset Class class FleursDataset(BaseAudioDataset): def __init__(self, processor, split, source_lang, target_lang=None, mode="asr", sampling_rate=16000, debug=False): super().__init__(processor, split, sampling_rate, debug) self.set_dataset_name("Fleurs") # Mode Setting (ASR or AST) if mode not in ["asr", "ast"]: raise ValueError("mode must be 'asr' or 'ast'.") self.mode = mode self.ast = (mode == "ast") self.source_lang = source_lang # Language name mapping (expand if needed) self.lang_names = { 'en_us': 'English', 'ko_kr': 'Korean' } # load dataset - source language dataset self.data = load_dataset("google/fleurs", source_lang, split=split, trust_remote_code=True ) # (Optional) Audio length Filtering self.data = self.filter_by_audio_length(self.data, "audio") # When AST mode, load target language dataset. if self.ast: if target_lang is None: raise ValueError("AST mode requires target_lang.") self.target_lang = target_lang self.lang = f"{source_lang}_{target_lang}" # load dataset - target language dataset (for translation) target_data = load_dataset("google/fleurs", target_lang, split=split, trust_remote_code=True ) source_dict = {item['id']: item for item in self.data} target_dict = {item['id']: item for item in target_data} # only Common ID, add translation fields common_ids = set(source_dict.keys()) & set(target_dict.keys()) print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}") self.data = [ {**source_dict[id], 'translation': target_dict[id]['transcription']} for id in common_ids ] # Instruction Setting - use target language name target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize()) self.instruction = random.choice(INSTRUCTION["ast"]).format(target_lang_name) else: # ASR mode self.lang = source_lang self.instruction = random.choice(INSTRUCTION["asr"]) if self.debug: print(f"FLEURS dataset loaded: {self.mode.upper()} mode") print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})") if self.ast: print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})") print(f"dataset size: {len(self.data)}") def __len__(self): return len(self.data) def __getitem__(self, idx): data = self.data[idx] audio_array = data["audio"]["array"] if self.ast: answer_text = data["translation"] else: answer_text = data["transcription"] return self.prepare_model_inputs( audio_array, self.instruction, answer_text ) def covost_collate_fn(batch): input_ids_list = [] labels_list = [] token_type_ids_list = [] input_audio_embeds_list = [] audio_embed_sizes_list = [] audio_attention_mask_list = [] input_modes_list = [] for inputs in batch: input_ids_list.append(inputs['input_ids'][0]) labels_list.append(inputs['labels'][0]) token_type_ids_list.append(inputs['token_type_ids'][0]) input_audio_embeds_list.append(inputs['input_audio_embeds']) audio_embed_sizes_list.append(inputs['audio_embed_sizes']) audio_attention_mask_list.append( inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool) ) input_modes_list.append(inputs['input_modes']) try: token_type_ids = pad_sequence(token_type_ids_list, padding_side='left', padding_value=0) input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0) labels = pad_sequence(labels_list, padding_side='left', padding_value=0) audio_attention_mask = ( pad_sequence(audio_attention_mask_list, padding_side='left', padding_value=False) if len(audio_attention_mask_list) > 1 else None ) except Exception as e: print(e) print(input_ids_list) print(labels_list) raise attention_mask = (input_ids != 0).long() input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0) audio_embed_sizes = torch.cat(audio_embed_sizes_list) input_modes = torch.cat(input_modes_list) return BatchFeature( { 'input_ids': input_ids, 'labels': labels, 'token_type_ids': token_type_ids, 'attention_mask': attention_mask, 'input_audio_embeds': input_audio_embeds, 'audio_embed_sizes': audio_embed_sizes, 'audio_attention_mask': audio_attention_mask, 'input_modes': input_modes, } ) def pad_sequence(sequences, padding_side='left', padding_value=0): """ Pad a list of sequences to the same length. sequences: list of tensors in [seq_len, *] shape """ assert padding_side in ['right', 'left'] max_size = sequences[0].size() trailing_dims = max_size[1:] max_len = max(len(seq) for seq in sequences) batch_size = len(sequences) output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value) for i, seq in enumerate(sequences): length = seq.size(0) if padding_side == 'right': output.data[i, :length] = seq else: output.data[i, -length:] = seq return output def cat_with_pad(tensors, dim, padding_value=0): """ cat along dim, while pad to max for all other dims """ ndim = tensors[0].dim() assert all( t.dim() == ndim for t in tensors[1:] ), 'All tensors must have the same number of dimensions' out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] out_size[dim] = sum(t.shape[dim] for t in tensors) output = tensors[0].new_full(out_size, padding_value) index = 0 for t in tensors: # Create a slice list where every dimension except dim is full slice slices = [slice(0, t.shape[d]) for d in range(ndim)] # Update only the concat dimension slice slices[dim] = slice(index, index + t.shape[dim]) output[slices] = t index += t.shape[dim] return output def count_parameters_by_module(model): # dictionary for parameters number by modules module_params = defaultdict(lambda: {"total": 0, "trainable": 0}) # all params total_params = 0 total_trainable_params = 0 # Check Embedding Token masks embedding_masks = {} for name, param in model.named_parameters(): if 'embed_tokens.weight' in name and hasattr(param, '_backward_hooks') and param._backward_hooks: # check if params has embedding_grad_mask_hook for hook_id, hook_fn in param._backward_hooks.items(): if hook_fn.__code__.co_name == 'embedding_grad_mask_hook': # Accessing mask variables in the closure of hook functions for cell in hook_fn.__closure__ or []: if isinstance(cell.cell_contents, torch.Tensor) and cell.cell_contents.dtype == torch.bool: # check mask tensor embedding_masks[name] = ~cell.cell_contents # True : Trainable # Count params by modules for name, param in model.named_parameters(): # extracts top module_name module_name = name.split('.')[0] param_count = param.numel() module_params[module_name]["total"] += param_count total_params += param_count if param.requires_grad: # Only count for real trainable params. (with masks) if name in embedding_masks: trainable_count = embedding_masks[name].sum().item() module_params[module_name]["trainable"] += trainable_count total_trainable_params += trainable_count else: module_params[module_name]["trainable"] += param_count total_trainable_params += param_count print(f"All Params: {total_params:,}") print(f"Trainable Params: {total_trainable_params:,} ({total_trainable_params/total_params*100:.2f}%)") print("\nParams by Module:") for module_name, counts in sorted(module_params.items()): trainable_percentage = counts["trainable"] / counts["total"] * 100 if counts["total"] > 0 else 0 total_percentage = counts["total"] / total_params * 100 print(f"- {module_name}:") print(f" Total: {counts['total']:,} ({total_percentage:.2f}% of model)") print(f" Trainable: {counts['trainable']:,} ({trainable_percentage:.2f}% of module)") return module_params def create_model(model_name_or_path, revision="main", use_flash_attention = False): model = AutoModel.from_pretrained( model_name_or_path, revision=revision, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2" if use_flash_attention else "eager", trust_remote_code=True, ) # Set use_cache to False after model loaded model.config.use_cache = False # Freeze all parameters for param in model.parameters(): param.requires_grad = False model.set_lora_adapter('speech') model.to(torch.bfloat16) # (Optional) unfreeze audio_tower parameters #for param in model.audio_tower.parameters(): # param.requires_grad = True # Only unfreeze audio_projector parameters for param in model.audio_projector.parameters(): param.requires_grad = True # (Optional) unfreeze audio embed_tokens train_embed = True if train_embed: embed_tokens = model.language_model.model.model.embed_tokens embed_tokens.weight.requires_grad = False # Added Speech token IDs (only this tokens be trainable) trainable_token_ids = [256001, 256002] embed_tokens.weight.requires_grad = True mask = torch.ones_like(embed_tokens.weight, dtype=torch.bool) mask[trainable_token_ids] = False # Trainable Tokens are False (unfreeze), else True (freeze) # backward hook, with gradient masking def embedding_grad_mask_hook(grad): return grad.masked_fill(mask, 0) embed_tokens.weight.register_hook(embedding_grad_mask_hook) model.language_model.model.model.embed_tokens = embed_tokens count_parameters_by_module(model) return model os.environ["TOKENIZERS_PARALLELISM"] = "false" INSTRUCTION = { "ast": [ "Translate the audio to {0}.", "Translate the audio clip into {0}.", "Based on the attached audio, generate a comprehensive {0} translation of the spoken content.", "Translate the provided audio file into {0}.", "Convert the audio speech to {0} text.", "Write an {0} translation of the audio file.", "Translate spoken words from the audio into {0}.", "Create an {0} version of the audio content.", "Produce an accurate {0} translation of the audio.", "Extract speech from the audio and translate it to {0}.", "Turn the audio into readable {0} text.", "Write all spoken content from the audio in {0}.", "Generate an {0} translation of the speech in the file.", "Convert the recording into {0} text.", "Accurately translate the audio recording to {0}.", "Write down dialogue from the given audio in {0}.", "Translate all speech in this audio file to {0}.", "Create an accurate {0} version of the speech.", "Perform a complete {0} translation of the audio." ], "asr": [ "Transcribe the audio clip into text.", "Based on the attached audio, generate a comprehensive text transcription of the spoken content.", "Transcribe the provided audio file into text.", "Convert the audio speech to text.", "Write a transcript of the audio file.", "Transcribe spoken words from the audio.", "Create a text version of the audio content.", "Produce a verbatim transcript of the audio.", "Extract and transcribe speech from the audio.", "Turn the audio into readable text.", "Write all spoken words from the audio.", "Generate a transcript of the speech in the file.", "Convert the recording into a text transcript.", "Accurately transcribe the audio recording.", "Write down dialogue from the given audio.", "Transcribe all speech in this audio file.", "Create an accurate text version of the speech.", "Perform a complete transcription of the audio." ], } ANSWER_SUFFIX = "" _IGNORE_INDEX = -100 model_name_or_path = 'junnei/gemma-3-4b-it-speech' use_flash_attention = True output_dir = '/workspace/output' batch_size = 128 batch_size_per_gpu = 32 learning_rate = 4.0e-5 # 1.0e-4 for fine-tuning wd = 0.01 num_train_epochs = 5 revision = "main" #"v1.0" processor = AutoProcessor.from_pretrained( model_name_or_path, revision=revision, trust_remote_code=True, ) model = create_model( model_name_or_path, revision=revision, use_flash_attention=use_flash_attention, ) train_datasets = [] # Covost ASR mode (English -> English text) covost_asr_dataset = CoVoSTDataset( processor=processor, data_dir="/workspace/CommonVoice/EN", split="train", ast=False, lang=("en_ko", "Korean") ) train_datasets.append(covost_asr_dataset) # Covost AST mode (English -> Korean text) covost_dataset = CoVoSTDataset( processor=processor, data_dir="/workspace/CommonVoice/EN", split="train", ast=True, lang=("en_ko", "Korean") ) train_datasets.append(covost_dataset) # Libri Speech Clean ASR mode (English -> English text) libri_speech_clean = LibriSpeechDataset( processor=processor, subset="clean", split="train.360" ) train_datasets.append(libri_speech_clean) # Libri Speech Other ASR mode (English -> English text) libri_speech_other = LibriSpeechDataset( processor=processor, subset="other", split="train.500" ) train_datasets.append(libri_speech_other) # Fleurs ASR mode (English -> English text) en_asr_fleurs = FleursDataset( processor=processor, split="train", source_lang="en_us", # English mode="asr" ) train_datasets.append(en_asr_fleurs) # Fleurs AST mode (English -> Korean text) en_ko_ast_fleurs = FleursDataset( processor=processor, split="train", source_lang="en_us", # English target_lang="ko_kr", # Korean mode="ast" ) train_datasets.append(en_ko_ast_fleurs) # Covost ASR mode (Korean -> Korean text) covost_ko_asr_dataset = CoVoSTDataset( processor=processor, data_dir="/workspace/CommonVoice/ko", split="train", ast=False, lang=("ko_en", "English") ) train_datasets.append(covost_ko_asr_dataset) # Covost AST mode (Korean -> English text) covost_ko_dataset = CoVoSTDataset( processor=processor, data_dir="/workspace/CommonVoice/ko", split="train", ast=True, lang=("ko_en", "English") ) train_datasets.append(covost_ko_dataset) # Zeroth ASR mode (Korean -> Korean text) ko_asr_zeroth = ZerothKoreanDataset( processor=processor, split="train" ) train_datasets.append(ko_asr_zeroth) # Fleurs ASR mode (Korean -> Korean text) ko_asr_fleurs = FleursDataset( processor=processor, split="train", source_lang="ko_kr", # Korean mode="asr" ) train_datasets.append(ko_asr_fleurs) # Fleurs AST mode (Korean -> English text) ko_en_ast_fleurs = FleursDataset( processor=processor, split="train", source_lang="ko_kr", # Korean target_lang="en_us", # English mode="ast" ) train_datasets.append(ko_en_ast_fleurs) print("Count Num of Datasets", len(train_datasets)) print([len(dataset) for dataset in train_datasets]) # ConcatDataset train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0] print("Count Length of Datas", len(train_dataset)) # Check GPUs num_gpus = torch.cuda.device_count() print(f'training on {num_gpus} GPUs') assert ( batch_size % (num_gpus * batch_size_per_gpu) == 0 ), 'Batch size must be divisible by the number of GPUs' gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu) # hard coded training args training_args = TrainingArguments( num_train_epochs=num_train_epochs, per_device_train_batch_size=batch_size_per_gpu, gradient_checkpointing=True, gradient_checkpointing_kwargs={'use_reentrant': False}, gradient_accumulation_steps=gradient_accumulation_steps, optim='adamw_torch', adam_beta1=0.9, adam_beta2=0.95, adam_epsilon=1e-7, learning_rate=learning_rate, weight_decay=wd, max_grad_norm=1.0, lr_scheduler_type='cosine', warmup_steps=50, logging_steps=50, output_dir=output_dir, save_strategy='no', save_total_limit=10, save_only_model=True, bf16=True, fp16=False, remove_unused_columns=False, report_to='none', deepspeed=None, disable_tqdm=False, dataloader_num_workers=4, ddp_find_unused_parameters=True, ) out_path = Path(training_args.output_dir) out_path.mkdir(parents=True, exist_ok=True) # create optimizer only for trainable params optimizer = torch.optim.AdamW( filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate, weight_decay=wd, betas=(0.9, 0.95), eps=1e-7, ) # Trainer Setting trainer = Trainer( model=model, args=training_args, data_collator=covost_collate_fn, train_dataset=train_dataset, optimizers=(optimizer, None), ) trainer.train() import shutil # 1. Save LoRA Adapter model.language_model.model.save_pretrained(output_dir) # 1-1. Delete Markdown file markdown_file = os.path.join(output_dir, "README.md") if os.path.exists(markdown_file): os.remove(markdown_file) # 2. Save entire model model.save_pretrained(output_dir) # 3. Cleanup Memory del model del trainer __import__('gc').collect() torch.cuda.empty_cache() from huggingface_hub import HfApi, login, create_repo, Repository, upload_folder upload_dir = "/workspace/upload" # 4. Clone Repo repo_id = "junnei/gemma-3-4b-it-speech" branch_name = "main" # 새 브랜치 이름 repo = Repository(local_dir=upload_dir, clone_from = repo_id) repo.git_checkout(branch_name, create_branch_ok=True) # 4-1. Move Trained model to Repo for item in os.listdir(output_dir): s = os.path.join(output_dir, item) d = os.path.join(upload_dir, item) if os.path.isdir(s): shutil.copytree(s, d, dirs_exist_ok=True) else: shutil.copy2(s, d)