junnei committed (verified)
Commit: c923988 · 1 Parent(s): ec7ba45

Upload finetune_speech.py

Files changed (1)
  1. examples/finetune_speech.py +929 -0
examples/finetune_speech.py ADDED
@@ -0,0 +1,929 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+ import sacrebleu
9
+
10
+ from datasets import load_dataset
11
+ from torch.utils.data import Dataset, ConcatDataset
12
+ from tqdm import tqdm
13
+ from transformers import (
14
+ AutoProcessor,
15
+ AutoModel,
16
+ BatchFeature,
17
+ Trainer,
18
+ TrainingArguments,
19
+ StoppingCriteria,
20
+ StoppingCriteriaList,
21
+ )
22
+ from collections import defaultdict
23
+
24
+ import soundfile as sf
25
+ from datasets import Audio
26
+ import random
27
+
28
+ class MultipleTokenBatchStoppingCriteria(StoppingCriteria):
29
+ """Stopping criteria capable of receiving multiple stop-tokens and handling batched inputs."""
30
+
31
+ def __init__(self, stop_tokens: torch.LongTensor, batch_size: int = 1) -> None:
32
+ """Initialize the multiple token batch stopping criteria.
33
+
34
+ Args:
35
+ stop_tokens: Stop-tokens.
36
+ batch_size: Batch size.
37
+
38
+ """
39
+
40
+ self.stop_tokens = stop_tokens
41
+ self.max_stop_tokens = stop_tokens.shape[-1]
42
+ self.stop_tokens_idx = torch.zeros(batch_size, dtype=torch.long, device=stop_tokens.device)
43
+
44
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
45
+ # Only gather the maximum number of inputs compatible with stop tokens
46
+ # and check whether the generated inputs are equal to `stop_tokens`
47
+ generated_inputs = torch.eq(input_ids[:, -self.max_stop_tokens :].unsqueeze(1), self.stop_tokens)
48
+ equal_generated_inputs = torch.all(generated_inputs, dim=2)
49
+
50
+ # Mark the position where a stop token has been produced for each input in the batch,
51
+ # but only if the corresponding entry is not already set
52
+ sequence_idx = torch.any(equal_generated_inputs, dim=1)
53
+ sequence_set_mask = self.stop_tokens_idx == 0
54
+ self.stop_tokens_idx[sequence_idx & sequence_set_mask] = input_ids.shape[-1]
55
+
56
+ return torch.all(self.stop_tokens_idx)
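+ # Note: stop_tokens_idx stores, per batch element, the full sequence length at the step
+ # where its stop sequence first appeared (0 if it never appeared). evaluate() below
+ # subtracts the stop-token length from this value to trim each generation at its stop point.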
57
+
58
+ class BaseAudioDataset(Dataset):
59
+ def __init__(self, processor, split, sampling_rate=16000, debug=False):
60
+ self.processor = processor
61
+ self.training = "train" in split
62
+ self.debug = debug
63
+ self.sampling_rate = sampling_rate
64
+ self.name = ""
65
+
66
+ def set_dataset_name(self, name):
67
+ self.name = name
68
+
69
+ @staticmethod
70
+ def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
71
+ original_size = len(data)
72
+
73
+ data = data.cast_column(audio_field, Audio(decode=False))
74
+
75
+ def identify_corrupted_files(example):
76
+ try:
77
+ sf.read(example[audio_field]["path"])
78
+
79
+ for field in text_fields:
80
+ if field in example and example[field].replace('"', '') == "":
81
+ return False
82
+ return True
83
+ except Exception:
84
+ return False
85
+
86
+ data = data.filter(identify_corrupted_files, num_proc=16)
87
+ validated_size = len(data)
88
+
89
+ # Audio Decoding
90
+ data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))
91
+
92
+ if debug:
93
+ print(f"Dataset: {dataset_name}")
94
+ print(f"Original data nums: {original_size}")
95
+ print(f"After filtering data nums: {validated_size}")
96
+ print(f"Filtering ratio: {validated_size/original_size:.2%}")
97
+
98
+ return data
99
+
100
+ @staticmethod
101
+ def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
102
+ original_size = len(data)
103
+
104
+ def filter_audio_by_length(example):
105
+ try:
106
+ audio = example[audio_field]['array']
107
+ channel = 1
108
+ if hasattr(audio, 'ndim') and audio.ndim > 1:
109
+ channel = audio.ndim
110
+ audio = audio.squeeze()
111
+ audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
112
+ return min_sec <= audio_length <= max_sec
113
+ except Exception as e:
114
+ if debug:
115
+ print(f"Error : {str(e)[:100]}... - sample excluded")
116
+ return False
117
+
118
+ data = data.filter(filter_audio_by_length, num_proc=16)
119
+ filtered_size = len(data)
120
+
121
+ if debug:
122
+ print(f"Before Length Filtering data nums: {original_size}")
123
+ print(f"After Length Filtering data nums: {filtered_size}")
124
+ print(f"Filtering ratio: {filtered_size/original_size:.2%}")
125
+
126
+ return data
127
+
128
+ def prepare_model_inputs(self, audio_array, instruction, answer_text):
129
+ user_message = {
130
+ 'role': 'user',
131
+ 'content': '<start_of_audio>' + instruction,
132
+ }
133
+ prompt = self.processor.tokenizer.apply_chat_template(
134
+ [user_message], tokenize=False, add_generation_prompt=True, add_bos=True
135
+ )
136
+
137
+ inputs = self.processor(
138
+ text=prompt,
139
+ audio=[audio_array],
140
+ add_special_tokens=False,
141
+ return_tensors='pt'
142
+ )
143
+
144
+ answer = f"{answer_text}{ANSWER_SUFFIX}"
145
+ answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids
146
+
147
+ if self.debug:
148
+ self.debug = False
149
+ task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
150
+ lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
151
+ print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
152
+ print(f"INPUT_MODE: {inputs.input_modes[0].item()}")
153
+
154
+ if self.training:
155
+ input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
156
+ labels = torch.full_like(input_ids, _IGNORE_INDEX)
157
+ labels[:, -answer_ids.shape[1]:] = answer_ids
158
+ padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]))
159
+ token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
160
+ else:
161
+ input_ids = inputs.input_ids
162
+ labels = answer_ids
163
+ token_type_ids = inputs.token_type_ids
164
+
165
+ return {
166
+ 'input_ids': input_ids,
167
+ 'labels': labels,
168
+ 'token_type_ids': token_type_ids,
169
+ 'input_audio_embeds': inputs.input_audio_embeds,
170
+ 'audio_embed_sizes': inputs.audio_embed_sizes,
171
+ 'input_modes': inputs.input_modes,
172
+ }
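+ # In training mode, labels are aligned with input_ids but every prompt/audio position is
+ # filled with _IGNORE_INDEX (-100), so the loss is computed only over the answer tokens
+ # and the trailing ANSWER_SUFFIX ("<end_of_turn>").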
173
+
174
+ # CoVoST2 Dataset Class
175
+ class CoVoSTDataset(BaseAudioDataset):
176
+ def __init__(self, processor, data_dir, split, ast=False,
177
+ lang=("en_ko", "Korean"), sampling_rate=16000, debug=False):
178
+ super().__init__(processor, split, sampling_rate, debug)
179
+
180
+ self.set_dataset_name("CoVoST")
181
+ self.ast = ast
182
+ self.lang = lang[0]
183
+
184
+ self.data = load_dataset("junnei/covost2",
185
+ lang[0],
186
+ data_dir=data_dir,
187
+ split=split,
188
+ trust_remote_code=True
189
+ )
190
+
191
+ text_fields = ["sentence", "translation"] if ast else ["sentence"]
192
+ self.data = self.filter_corrupted_files(self.data, "audio", text_fields, "CoVoST")
193
+
194
+ # (Optional) Audio length Filtering
195
+ self.data = self.filter_by_audio_length(self.data, "audio")
196
+
197
+ # Instruction Setting
198
+ self.instruction = random.choice(INSTRUCTION["ast"]).format(lang[1]) if ast else random.choice(INSTRUCTION["asr"])
199
+
200
+ def __len__(self):
201
+ return len(self.data)
202
+
203
+ def __getitem__(self, idx):
204
+ data = self.data[idx]
205
+
206
+ if self.ast:
207
+ answer_text = data["translation"]
208
+ else:
209
+ answer_text = data["sentence"].replace('"', '')
210
+
211
+ return self.prepare_model_inputs(
212
+ data["audio"]["array"],
213
+ self.instruction,
214
+ answer_text
215
+ )
216
+
217
+ # Zeroth Korean Dataset Class
218
+ class ZerothKoreanDataset(BaseAudioDataset):
219
+ def __init__(self, processor, split, sampling_rate=16000, debug=False):
220
+ super().__init__(processor, split, sampling_rate, debug)
221
+
222
+ self.set_dataset_name("Zeroth")
223
+ # only ASR
224
+ self.ast = False
225
+ self.lang = "ko"
226
+
227
+ # load dataset
228
+ self.data = load_dataset("Bingsu/zeroth-korean",
229
+ split=split,
230
+ trust_remote_code=True
231
+ )
232
+
233
+ # (Optional) Audio length Filtering
234
+ self.data = self.filter_by_audio_length(self.data, "audio")
235
+
236
+ # Instruction Setting
237
+ self.instruction = random.choice(INSTRUCTION["asr"])
238
+
239
+ def __len__(self):
240
+ return len(self.data)
241
+
242
+ def __getitem__(self, idx):
243
+ data = self.data[idx]
244
+
245
+ # Zeroth Korean is only for ASR
246
+ answer_text = data["text"].replace('"', '')
247
+
248
+ return self.prepare_model_inputs(
249
+ data["audio"]["array"],
250
+ self.instruction,
251
+ answer_text
252
+ )
253
+
254
+ # Libri Speech Dataset Class
255
+ class LibriSpeechDataset(BaseAudioDataset):
256
+ def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
257
+ super().__init__(processor, split, sampling_rate, debug)
258
+
259
+ self.set_dataset_name(f"LibriSpeech_{subset}")
260
+ # only ASR
261
+ self.ast = False
262
+ self.lang = "en"
263
+
264
+ # load dataset
265
+ self.data = load_dataset("fixie-ai/librispeech_asr",
266
+ subset,
267
+ split=split,
268
+ trust_remote_code=True
269
+ )
270
+
271
+ # (Optional) Audio length Filtering
272
+ self.data = self.filter_by_audio_length(self.data, "audio")
273
+
274
+ # Instruction Setting
275
+ self.instruction = random.choice(INSTRUCTION["asr"])
276
+
277
+ def __len__(self):
278
+ return len(self.data)
279
+
280
+ def __getitem__(self, idx):
281
+ data = self.data[idx]
282
+
283
+ # Libri Speech is only for ASR
284
+ answer_text = data["text"].replace('"', '')
285
+
286
+ return self.prepare_model_inputs(
287
+ data["audio"]["array"],
288
+ self.instruction,
289
+ answer_text
290
+ )
291
+
292
+ # Fleurs Dataset Class
293
+ class FleursDataset(BaseAudioDataset):
294
+ def __init__(self, processor, split, source_lang, target_lang=None,
295
+ mode="asr", sampling_rate=16000, debug=False):
296
+ super().__init__(processor, split, sampling_rate, debug)
297
+
298
+ self.set_dataset_name("Fleurs")
299
+ # Mode Setting (ASR or AST)
300
+ if mode not in ["asr", "ast"]:
301
+ raise ValueError("mode must be 'asr' or 'ast'.")
302
+
303
+ self.mode = mode
304
+ self.ast = (mode == "ast")
305
+ self.source_lang = source_lang
306
+
307
+ # Language name mapping (expand if needed)
308
+ self.lang_names = {
309
+ 'en_us': 'English', 'ko_kr': 'Korean'
310
+ }
311
+
312
+ # load dataset - source language dataset
313
+ self.data = load_dataset("google/fleurs",
314
+ source_lang,
315
+ split=split,
316
+ trust_remote_code=True
317
+ )
318
+
319
+ # (Optional) Audio length Filtering
320
+ self.data = self.filter_by_audio_length(self.data, "audio")
321
+
322
+ # When AST mode, load target language dataset.
323
+ if self.ast:
324
+ if target_lang is None:
325
+ raise ValueError("AST mode requires target_lang.")
326
+
327
+ self.target_lang = target_lang
328
+ self.lang = f"{source_lang}_{target_lang}"
329
+
330
+ # load dataset - target language dataset (for translation)
331
+ target_data = load_dataset("google/fleurs",
332
+ target_lang,
333
+ split=split,
334
+ trust_remote_code=True
335
+ )
336
+
337
+ source_dict = {item['id']: item for item in self.data}
338
+ target_dict = {item['id']: item for item in target_data}
339
+
340
+ # only Common ID, add translation fields
341
+ common_ids = set(source_dict.keys()) & set(target_dict.keys())
342
+ print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
343
+ self.data = [
344
+ {**source_dict[id], 'translation': target_dict[id]['transcription']}
345
+ for id in common_ids
346
+ ]
347
+
348
+ # Instruction Setting - use target language name
349
+ target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
350
+ self.instruction = random.choice(INSTRUCTION["ast"]).format(target_lang_name)
351
+ else:
352
+ # ASR mode
353
+ self.lang = source_lang
354
+ self.instruction = random.choice(INSTRUCTION["asr"])
355
+
356
+ if self.debug:
357
+ print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
358
+ print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
359
+ if self.ast:
360
+ print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
361
+ print(f"dataset size: {len(self.data)}")
362
+
363
+ def __len__(self):
364
+ return len(self.data)
365
+
366
+ def __getitem__(self, idx):
367
+ data = self.data[idx]
368
+ audio_array = data["audio"]["array"]
369
+
370
+ if self.ast:
371
+ answer_text = data["translation"]
372
+ else:
373
+ answer_text = data["transcription"]
374
+
375
+ return self.prepare_model_inputs(
376
+ audio_array,
377
+ self.instruction,
378
+ answer_text
379
+ )
380
+
381
+ def covost_collate_fn(batch):
382
+ input_ids_list = []
383
+ labels_list = []
384
+ token_type_ids_list = []
385
+ input_audio_embeds_list = []
386
+ audio_embed_sizes_list = []
387
+ audio_attention_mask_list = []
388
+ input_modes_list = []
389
+ for inputs in batch:
390
+ input_ids_list.append(inputs['input_ids'][0])
391
+ labels_list.append(inputs['labels'][0])
392
+ token_type_ids_list.append(inputs['token_type_ids'][0])
393
+ input_audio_embeds_list.append(inputs['input_audio_embeds'])
394
+ audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
395
+ audio_attention_mask_list.append(
396
+ inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
397
+ )
398
+ input_modes_list.append(inputs['input_modes'])
399
+
400
+ try:
401
+ token_type_ids = pad_sequence(token_type_ids_list, padding_side='left', padding_value=0)
402
+ input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
403
+ labels = pad_sequence(labels_list, padding_side='left', padding_value=0)
404
+ audio_attention_mask = (
405
+ pad_sequence(audio_attention_mask_list, padding_side='left', padding_value=False)
406
+ if len(audio_attention_mask_list) > 1
407
+ else None
408
+ )
409
+ except Exception as e:
410
+ print(e)
411
+ print(input_ids_list)
412
+ print(labels_list)
413
+ raise
414
+ attention_mask = (input_ids != 0).long()
415
+ input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
416
+ audio_embed_sizes = torch.cat(audio_embed_sizes_list)
417
+ input_modes = torch.cat(input_modes_list)
418
+
419
+ return BatchFeature(
420
+ {
421
+ 'input_ids': input_ids,
422
+ 'labels': labels,
423
+ 'token_type_ids': token_type_ids,
424
+ 'attention_mask': attention_mask,
425
+ 'input_audio_embeds': input_audio_embeds,
426
+ 'audio_embed_sizes': audio_embed_sizes,
427
+ 'audio_attention_mask': audio_attention_mask,
428
+ 'input_modes': input_modes,
429
+ }
430
+ )
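+ # attention_mask is derived from the left-padded input_ids: positions equal to the
+ # padding value 0 are masked out, every real token position is attended to.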
431
+
432
+ def pad_sequence(sequences, padding_side='left', padding_value=0):
433
+ """
434
+ Pad a list of sequences to the same length.
435
+ sequences: list of tensors in [seq_len, *] shape
436
+ """
437
+ assert padding_side in ['right', 'left']
438
+ max_size = sequences[0].size()
439
+ trailing_dims = max_size[1:]
440
+ max_len = max(len(seq) for seq in sequences)
441
+ batch_size = len(sequences)
442
+ output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
443
+ for i, seq in enumerate(sequences):
444
+ length = seq.size(0)
445
+ if padding_side == 'right':
446
+ output.data[i, :length] = seq
447
+ else:
448
+ output.data[i, -length:] = seq
449
+ return output
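+ # Illustrative example of the left-padding behaviour above:
+ # pad_sequence([torch.tensor([1, 2, 3]), torch.tensor([4])], padding_side='left')
+ # -> tensor([[1, 2, 3],
+ #            [0, 0, 4]])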
450
+
451
+ def cat_with_pad(tensors, dim, padding_value=0):
452
+ """
453
+ cat along dim, while pad to max for all other dims
454
+ """
455
+ ndim = tensors[0].dim()
456
+ assert all(
457
+ t.dim() == ndim for t in tensors[1:]
458
+ ), 'All tensors must have the same number of dimensions'
459
+
460
+ out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
461
+ out_size[dim] = sum(t.shape[dim] for t in tensors)
462
+ output = tensors[0].new_full(out_size, padding_value)
463
+
464
+ index = 0
465
+ for t in tensors:
466
+ # Create a slice list where every dimension except dim is full slice
467
+ slices = [slice(0, t.shape[d]) for d in range(ndim)]
468
+ # Update only the concat dimension slice
469
+ slices[dim] = slice(index, index + t.shape[dim])
470
+
471
+ output[slices] = t
472
+ index += t.shape[dim]
473
+
474
+ return output
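+ # Illustrative example: cat_with_pad([a, b], dim=0) with a of shape (1, 5, 80) and
+ # b of shape (1, 7, 80) returns a (2, 7, 80) tensor in which a's entries occupy
+ # [0, :5, :] and the remaining positions of row 0 are zero-padded.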
475
+
476
+ def count_parameters_by_module(model):
477
+ # dictionary for parameters number by modules
478
+ module_params = defaultdict(lambda: {"total": 0, "trainable": 0})
479
+
480
+ # all params
481
+ total_params = 0
482
+ total_trainable_params = 0
483
+
484
+ # Check Embedding Token masks
485
+ embedding_masks = {}
486
+ for name, param in model.named_parameters():
487
+ if 'embed_tokens.weight' in name and hasattr(param, '_backward_hooks') and param._backward_hooks:
488
+ # check if params has embedding_grad_mask_hook
489
+ for hook_id, hook_fn in param._backward_hooks.items():
490
+ if hook_fn.__code__.co_name == 'embedding_grad_mask_hook':
491
+ # Accessing mask variables in the closure of hook functions
492
+ for cell in hook_fn.__closure__ or []:
493
+ if isinstance(cell.cell_contents, torch.Tensor) and cell.cell_contents.dtype == torch.bool:
494
+ # check mask tensor
495
+ embedding_masks[name] = ~cell.cell_contents # True : Trainable
496
+
497
+ # Count params by modules
498
+ for name, param in model.named_parameters():
499
+ # extracts top module_name
500
+ module_name = name.split('.')[0]
501
+ param_count = param.numel()
502
+
503
+ module_params[module_name]["total"] += param_count
504
+ total_params += param_count
505
+
506
+ if param.requires_grad:
507
+ # Only count for real trainable params. (with masks)
508
+ if name in embedding_masks:
509
+ trainable_count = embedding_masks[name].sum().item()
510
+ module_params[module_name]["trainable"] += trainable_count
511
+ total_trainable_params += trainable_count
512
+ else:
513
+ module_params[module_name]["trainable"] += param_count
514
+ total_trainable_params += param_count
515
+
516
+ print(f"All Params: {total_params:,}")
517
+ print(f"Trainable Params: {total_trainable_params:,} ({total_trainable_params/total_params*100:.2f}%)")
518
+ print("\nParams by Module:")
519
+
520
+ for module_name, counts in sorted(module_params.items()):
521
+ trainable_percentage = counts["trainable"] / counts["total"] * 100 if counts["total"] > 0 else 0
522
+ total_percentage = counts["total"] / total_params * 100
523
+
524
+ print(f"- {module_name}:")
525
+ print(f" Total: {counts['total']:,} ({total_percentage:.2f}% of model)")
526
+ print(f" Trainable: {counts['trainable']:,} ({trainable_percentage:.2f}% of module)")
527
+
528
+ return module_params
529
+
530
+ def create_model(model_name_or_path, revision="main", use_flash_attention = False):
531
+ model = AutoModel.from_pretrained(
532
+ model_name_or_path,
533
+ revision=revision,
534
+ torch_dtype=torch.bfloat16,
535
+ device_map="auto",
536
+ attn_implementation="flash_attention_2" if use_flash_attention else "eager",
537
+ trust_remote_code=True,
538
+ )
539
+
540
+ # Set use_cache to False after model loaded
541
+ model.config.use_cache = False
542
+
543
+ # Freeze all parameters
544
+ for param in model.parameters():
545
+ param.requires_grad = False
546
+
547
+ model.set_lora_adapter('speech')
548
+ model.to(torch.bfloat16)
549
+
550
+ # (Optional) unfreeze audio_tower parameters
551
+ #for param in model.audio_tower.parameters():
552
+ # param.requires_grad = True
553
+
554
+ # Only unfreeze audio_projector parameters
555
+ for param in model.audio_projector.parameters():
556
+ param.requires_grad = True
557
+
558
+ # (Optional) unfreeze audio embed_tokens
559
+ train_embed = True
560
+ if train_embed:
561
+ embed_tokens = model.language_model.model.model.embed_tokens
562
+
563
+ embed_tokens.weight.requires_grad = False
564
+
565
+ # Added Speech token IDs (only this tokens be trainable)
566
+ trainable_token_ids = [256001, 256002]
567
+
568
+ embed_tokens.weight.requires_grad = True
569
+ mask = torch.ones_like(embed_tokens.weight, dtype=torch.bool)
570
+ mask[trainable_token_ids] = False # Trainable Tokens are False (unfreeze), else True (freeze)
571
+
572
+ # backward hook, with gradient masking
573
+ def embedding_grad_mask_hook(grad):
574
+ return grad.masked_fill(mask, 0)
575
+
576
+ embed_tokens.weight.register_hook(embedding_grad_mask_hook)
577
+
578
+ model.language_model.model.model.embed_tokens = embed_tokens
579
+
580
+ count_parameters_by_module(model)
581
+
582
+ return model
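+ # With the hook above, gradients reach only the rows of embed_tokens.weight that
+ # correspond to the newly added speech tokens (IDs 256001 and 256002); every other
+ # row's gradient is zeroed, so the rest of the embedding matrix is effectively frozen.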
583
+
584
+ @torch.no_grad()
585
+ def evaluate(model, processor, eval_dataset, save_path=None, disable_tqdm=False, eval_batch_size=1):
586
+ model.eval()
587
+ all_generated_texts = []
588
+ all_labels = []
589
+
590
+ eval_dataloader = torch.utils.data.DataLoader(
591
+ eval_dataset,
592
+ batch_size=eval_batch_size,
593
+ collate_fn=covost_collate_fn,
594
+ shuffle=False,
595
+ drop_last=False,
596
+ num_workers=8,
597
+ prefetch_factor=2,
598
+ pin_memory=True,
599
+ )
600
+ stop_tokens = [processor.tokenizer.eos_token]
601
+ stop_tokens_ids = processor.tokenizer(stop_tokens, add_special_tokens=False, padding="longest", return_tensors="pt")["input_ids"]
602
+ stop_tokens_ids = stop_tokens_ids.to('cuda')
603
+
604
+ for inputs in tqdm(
605
+ eval_dataloader, disable=disable_tqdm, desc='running eval'
606
+ ):
607
+ stopping_criteria=StoppingCriteriaList([MultipleTokenBatchStoppingCriteria(stop_tokens_ids, batch_size=inputs.input_ids.size(0))])
608
+ inputs = inputs.to('cuda').to(model.dtype)
609
+ generated_ids = model.generate(
610
+ **inputs, eos_token_id=processor.tokenizer.eos_token_id, max_new_tokens=64,
611
+ stopping_criteria=stopping_criteria,
612
+ )
613
+
614
+ stop_tokens_idx = stopping_criteria[0].stop_tokens_idx.reshape(inputs.input_ids.size(0), -1)[:, 0]
615
+
616
+ stop_tokens_idx = torch.where(
617
+ stop_tokens_idx > 0,
618
+ stop_tokens_idx - stop_tokens_ids.shape[-1],
619
+ generated_ids.shape[-1],
620
+ )
621
+ generated_text = [
622
+ processor.decode(_pred_ids[inputs["input_ids"].shape[1] : _stop_tokens_idx], skip_special_tokens=True, clean_up_tokenization_spaces=False)
623
+ for _pred_ids, _stop_tokens_idx in zip(generated_ids, stop_tokens_idx)
624
+ ]
625
+ all_generated_texts.extend(generated_text)
626
+ labels = [processor.decode(_label_ids[_label_ids != 0]).removesuffix(ANSWER_SUFFIX) for _label_ids in inputs["labels"]]
627
+ all_labels.extend(labels)
628
+
629
+ assert len(all_generated_texts) == len(all_labels)
630
+ bleu = sacrebleu.corpus_bleu(all_generated_texts, [all_labels])
631
+ print(bleu)
632
+ if save_path:
633
+ with open(save_path, 'w') as f:
634
+ save_dict = {
635
+ 'all_generated_texts': all_generated_texts,
636
+ 'all_labels': all_labels,
637
+ 'score': bleu.score,
638
+ }
639
+ json.dump(save_dict, f)
640
+
641
+ return bleu.score
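+ # Illustrative standalone usage (the split name and save path are placeholders, not part
+ # of the training run below):
+ # eval_fleurs = FleursDataset(processor, split="test", source_lang="en_us",
+ #                             target_lang="ko_kr", mode="ast")
+ # evaluate(model, processor, eval_fleurs,
+ #          save_path="/workspace/output/eval_en_ko_ast.json", eval_batch_size=8)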
642
+
643
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
644
+
645
+ INSTRUCTION = {
646
+ "ast": [
647
+ "Translate the audio to {0}.",
648
+ "Translate the audio clip into {0}.",
649
+ "Based on the attached audio, generate a comprehensive {0} translation of the spoken content.",
650
+ "Translate the provided audio file into {0}.",
651
+ "Convert the audio speech to {0} text.",
652
+ "Write an {0} translation of the audio file.",
653
+ "Translate spoken words from the audio into {0}.",
654
+ "Create an {0} version of the audio content.",
655
+ "Produce an accurate {0} translation of the audio.",
656
+ "Extract speech from the audio and translate it to {0}.",
657
+ "Turn the audio into readable {0} text.",
658
+ "Write all spoken content from the audio in {0}.",
659
+ "Generate an {0} translation of the speech in the file.",
660
+ "Convert the recording into {0} text.",
661
+ "Accurately translate the audio recording to {0}.",
662
+ "Write down dialogue from the given audio in {0}.",
663
+ "Translate all speech in this audio file to {0}.",
664
+ "Create an accurate {0} version of the speech.",
665
+ "Perform a complete {0} translation of the audio."
666
+ ],
667
+ "asr": [
668
+ "Transcribe the audio clip into text.",
669
+ "Based on the attached audio, generate a comprehensive text transcription of the spoken content.",
670
+ "Transcribe the provided audio file into text.",
671
+ "Convert the audio speech to text.",
672
+ "Write a transcript of the audio file.",
673
+ "Transcribe spoken words from the audio.",
674
+ "Create a text version of the audio content.",
675
+ "Produce a verbatim transcript of the audio.",
676
+ "Extract and transcribe speech from the audio.",
677
+ "Turn the audio into readable text.",
678
+ "Write all spoken words from the audio.",
679
+ "Generate a transcript of the speech in the file.",
680
+ "Convert the recording into a text transcript.",
681
+ "Accurately transcribe the audio recording.",
682
+ "Write down dialogue from the given audio.",
683
+ "Transcribe all speech in this audio file.",
684
+ "Create an accurate text version of the speech.",
685
+ "Perform a complete transcription of the audio."
686
+ ],
687
+ }
688
+
689
+ ANSWER_SUFFIX = "<end_of_turn>"
690
+ _IGNORE_INDEX = -100
691
+
692
+ model_name_or_path = 'junnei/gemma-3-4b-it-speech'
693
+ use_flash_attention = True
694
+
695
+ output_dir = '/workspace/output'
696
+ batch_size = 128
697
+ batch_size_per_gpu = 32
698
+ learning_rate = 4.0e-5 # 1.0e-4 for fine-tuning
699
+ wd = 0.01
700
+ num_train_epochs = 5
701
+
702
+ revision = "main" #"v1.0"
703
+
704
+ processor = AutoProcessor.from_pretrained(
705
+ model_name_or_path,
706
+ revision=revision,
707
+ trust_remote_code=True,
708
+ )
709
+
710
+ model = create_model(
711
+ model_name_or_path,
712
+ revision=revision,
713
+ use_flash_attention=use_flash_attention,
714
+ )
715
+
716
+ train_datasets = []
717
+
718
+ # Covost ASR mode (English -> English text)
719
+ covost_asr_dataset = CoVoSTDataset(
720
+ processor=processor,
721
+ data_dir="/workspace/CommonVoice/EN",
722
+ split="train",
723
+ ast=False,
724
+ lang=("en_ko", "Korean")
725
+ )
726
+ train_datasets.append(covost_asr_dataset)
727
+
728
+ # Covost AST mode (English -> Korean text)
729
+ covost_dataset = CoVoSTDataset(
730
+ processor=processor,
731
+ data_dir="/workspace/CommonVoice/EN",
732
+ split="train",
733
+ ast=True,
734
+ lang=("en_ko", "Korean")
735
+ )
736
+ train_datasets.append(covost_dataset)
737
+
738
+ # Libri Speech Clean ASR mode (English -> English text)
739
+ libri_speech_clean = LibriSpeechDataset(
740
+ processor=processor,
741
+ subset="clean",
742
+ split="train.360"
743
+ )
744
+ train_datasets.append(libri_speech_clean)
745
+
746
+ # Libri Speech Other ASR mode (English -> English text)
747
+ libri_speech_other = LibriSpeechDataset(
748
+ processor=processor,
749
+ subset="other",
750
+ split="train.500"
751
+ )
752
+ train_datasets.append(libri_speech_other)
753
+
754
+ # Fleurs ASR mode (English -> English text)
755
+ en_asr_fleurs = FleursDataset(
756
+ processor=processor,
757
+ split="train",
758
+ source_lang="en_us", # English
759
+ mode="asr"
760
+ )
761
+ train_datasets.append(en_asr_fleurs)
762
+
763
+ # Fleurs AST mode (English -> Korean text)
764
+ en_ko_ast_fleurs = FleursDataset(
765
+ processor=processor,
766
+ split="train",
767
+ source_lang="en_us", # English
768
+ target_lang="ko_kr", # Korean
769
+ mode="ast"
770
+ )
771
+ train_datasets.append(en_ko_ast_fleurs)
772
+
773
+ # Covost ASR mode (Korean -> Korean text)
774
+ covost_ko_asr_dataset = CoVoSTDataset(
775
+ processor=processor,
776
+ data_dir="/workspace/CommonVoice/ko",
777
+ split="train",
778
+ ast=False,
779
+ lang=("ko_en", "English")
780
+ )
781
+ train_datasets.append(covost_ko_asr_dataset)
782
+
783
+ # Covost AST mode (Korean -> English text)
784
+ covost_ko_dataset = CoVoSTDataset(
785
+ processor=processor,
786
+ data_dir="/workspace/CommonVoice/ko",
787
+ split="train",
788
+ ast=True,
789
+ lang=("ko_en", "English")
790
+ )
791
+ train_datasets.append(covost_ko_dataset)
792
+
793
+ # Zeroth ASR mode (Korean -> Korean text)
794
+ ko_asr_zeroth = ZerothKoreanDataset(
795
+ processor=processor,
796
+ split="train"
797
+ )
798
+ train_datasets.append(ko_asr_zeroth)
799
+
800
+ # Fleurs ASR mode (Korean -> Korean text)
801
+ ko_asr_fleurs = FleursDataset(
802
+ processor=processor,
803
+ split="train",
804
+ source_lang="ko_kr", # Korean
805
+ mode="asr"
806
+ )
807
+ train_datasets.append(ko_asr_fleurs)
808
+
809
+ # Fleurs AST mode (Korean -> English text)
810
+ ko_en_ast_fleurs = FleursDataset(
811
+ processor=processor,
812
+ split="train",
813
+ source_lang="ko_kr", # Korean
814
+ target_lang="en_us", # English
815
+ mode="ast"
816
+ )
817
+ train_datasets.append(ko_en_ast_fleurs)
818
+
819
+ print("Count Num of Datasets", len(train_datasets))
820
+ print([len(dataset) for dataset in train_datasets])
821
+
822
+ # ConcatDataset
823
+ train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
824
+ print("Count Length of Datas", len(train_dataset))
825
+
826
+ # Check GPUs
827
+ num_gpus = torch.cuda.device_count()
828
+ print(f'training on {num_gpus} GPUs')
829
+
830
+ assert (
831
+ batch_size % (num_gpus * batch_size_per_gpu) == 0
832
+ ), 'Batch size must be divisible by num_gpus * batch_size_per_gpu'
833
+ gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu)
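+ # e.g. with batch_size=128, batch_size_per_gpu=32 and a single GPU this yields
+ # gradient_accumulation_steps = 128 // (1 * 32) = 4; with 4 GPUs it would be 1.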
834
+
835
+ # hard coded training args
836
+ training_args = TrainingArguments(
837
+ num_train_epochs=num_train_epochs,
838
+ per_device_train_batch_size=batch_size_per_gpu,
839
+ gradient_checkpointing=True,
840
+ gradient_checkpointing_kwargs={'use_reentrant': False},
841
+ gradient_accumulation_steps=gradient_accumulation_steps,
842
+ optim='adamw_torch',
843
+ adam_beta1=0.9,
844
+ adam_beta2=0.95,
845
+ adam_epsilon=1e-7,
846
+ learning_rate=learning_rate,
847
+ weight_decay=wd,
848
+ max_grad_norm=1.0,
849
+ lr_scheduler_type='cosine',
850
+ warmup_steps=50,
851
+ logging_steps=50,
852
+ output_dir=output_dir,
853
+ save_strategy='no',
854
+ save_total_limit=10,
855
+ save_only_model=True,
856
+ bf16=True,
857
+ fp16=False,
858
+ remove_unused_columns=False,
859
+ report_to='none',
860
+ deepspeed=None,
861
+ disable_tqdm=False,
862
+ dataloader_num_workers=4,
863
+ ddp_find_unused_parameters=True,
864
+ )
865
+
866
+ out_path = Path(training_args.output_dir)
867
+ out_path.mkdir(parents=True, exist_ok=True)
868
+
869
+ # create optimizer only for trainable params
870
+ optimizer = torch.optim.AdamW(
871
+ filter(lambda p: p.requires_grad, model.parameters()),
872
+ lr=learning_rate,
873
+ weight_decay=wd,
874
+ betas=(0.9, 0.95),
875
+ eps=1e-7,
876
+ )
877
+
878
+ # Trainer Setting
879
+ trainer = Trainer(
880
+ model=model,
881
+ args=training_args,
882
+ data_collator=covost_collate_fn,
883
+ train_dataset=train_dataset,
884
+ optimizers=(optimizer, None),
885
+ )
886
+
887
+ trainer.train()
888
+
889
+ import shutil
890
+
891
+ # setting output dir
892
+ output_dir = "/workspace/output"
893
+
894
+ # 1. Save LoRA Adapter
895
+ model.language_model.model.save_pretrained(output_dir)
896
+
897
+ # 1-1. Delete Markdown file
898
+ markdown_file = os.path.join(output_dir, "README.md")
899
+ if os.path.exists(markdown_file):
900
+ os.remove(markdown_file)
901
+
902
+ # 2. Save entire model
903
+ model.save_pretrained(output_dir)
904
+
905
+ # 3. Cleanup Memory
906
+ del model
907
+ del trainer
908
+ __import__('gc').collect()
909
+ torch.cuda.empty_cache()
910
+
911
+ from huggingface_hub import HfApi, login, create_repo, Repository, upload_folder
912
+
913
+ upload_dir = "/workspace/upload"
914
+
915
+ # 4. Clone Repo
916
+ repo_id = "junnei/gemma-3-4b-it-speech"
917
+ branch_name = "main" # 새 브랜치 이름
918
+
919
+ repo = Repository(local_dir=upload_dir, clone_from = repo_id)
920
+ repo.git_checkout(branch_name, create_branch_ok=True)
921
+
922
+ # 4-1. Move Trained model to Repo
923
+ for item in os.listdir(output_dir):
924
+ s = os.path.join(output_dir, item)
925
+ d = os.path.join(upload_dir, item)
926
+ if os.path.isdir(s):
927
+ shutil.copytree(s, d, dirs_exist_ok=True)
928
+ else:
929
+ shutil.copy2(s, d)
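+ # 4-2. Push the copied files (sketch; commit message and branch handling are placeholders)
+ # repo.push_to_hub(commit_message="Upload fine-tuned speech model")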