junnei committed
Commit 800f55d · verified · 1 Parent(s): 30528ba

Upload evaluate_speech.py

Files changed (1)
  1. examples/evaluate_speech.py +345 -82
examples/evaluate_speech.py CHANGED
@@ -25,7 +25,7 @@ normalizer = {
 
 # Load model and processor
 model_id = "junnei/gemma-3-4b-it-speech"
-revision = "v1.0"
+revision = "main" #"v1.0"
 
 model = AutoModel.from_pretrained(
     model_id, device_map="auto", revision = revision, trust_remote_code=True
@@ -45,76 +45,282 @@ INSTRUCTION = {
     "asr": "Transcribe the audio clip into text.",
 }
 
-class CoVoSTDataset(Dataset):
-    def __init__(self, processor, data_dir, ast=False,
-                 lang=("en_ko", "Korean")):
-        self.data = load_dataset("junnei/covost2",
-                    lang[0],
-                    data_dir=data_dir,
-                    split='test',
-                    trust_remote_code=True
-                )
-
-        original_size = len(self.data)
-        self.data = self.data.cast_column("audio", Audio(decode=False))
+class BaseAudioDataset(Dataset):
+    def __init__(self, processor, split, sampling_rate=16000, debug=False):
+        self.processor = processor
+        self.training = "train" in split
+        self.debug = debug
+        self.sampling_rate = sampling_rate
+        self.name = ""
 
+    def set_dataset_name(self, name):
+        self.name = name
+
+    @staticmethod
+    def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
+        original_size = len(data)
+
+        data = data.cast_column(audio_field, Audio(decode=False))
+
         def identify_corrupted_files(example):
             try:
-                # Try decoding
-                sf.read(example["audio"]["path"])
-                if example['translation'] == "" or example['sentence'] == "":
-                    return False
+                sf.read(example[audio_field]["path"])
+
+                for field in text_fields:
+                    if example[field].replace('"', '') == "":
+                        return False
                 return True
             except Exception:
                 return False
+
+        data = data.filter(identify_corrupted_files, num_proc=16)
+        validated_size = len(data)
+
+        # Decode audio
+        data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))
+
+        if debug:
+            print(f"Dataset: {dataset_name}")
+            print(f"Original data count: {original_size}")
+            print(f"Data count after filtering: {validated_size}")
+            print(f"Filtering ratio: {validated_size/original_size:.2%}")
+
+        return data
 
-        self.data = self.data.filter(identify_corrupted_files, num_proc=16)
-        validated_size = len(self.data)
-        self.data = self.data.cast_column("audio", Audio(sampling_rate = 16000, decode=True))
+    @staticmethod
+    def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
+        original_size = len(data)
+
+        def filter_audio_by_length(example):
+            try:
+                audio = example[audio_field]['array']
+                channel = 1
+                if hasattr(audio, 'ndim') and audio.ndim > 1:
+                    channel = audio.ndim
+                    audio = audio.squeeze()
+                audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
+                return min_sec <= audio_length <= max_sec
+            except Exception as e:
+                if debug:
+                    print(f"Error occurred: {str(e)[:100]}... - sample excluded")
+                return False
+
+        data = data.filter(filter_audio_by_length, num_proc=16)
+        filtered_size = len(data)
+
+        if debug:
+            print(f"Data count before length filtering: {original_size}")
+            print(f"Data count after length filtering: {filtered_size}")
+            print(f"Filtering ratio: {filtered_size/original_size:.2%}")
+
+        return data
 
-        self.lang = lang[0]
+    def prepare_model_inputs(self, audio_array, instruction, answer_text):
+        user_message = {
+            'role': 'user',
+            'content': '<start_of_audio>' + instruction,
+        }
+        prompt = self.processor.tokenizer.apply_chat_template(
+            [user_message], tokenize=False, add_generation_prompt=True, add_bos=True
+        )
+
+        inputs = self.processor(
+            text=prompt,
+            audio=[audio_array],
+            add_special_tokens=False,
+            return_tensors='pt'
+        )
+
+        input_ids = inputs.input_ids
+        token_type_ids = inputs.token_type_ids
+
+        return {
+            'input_ids': input_ids,
+            'token_type_ids': token_type_ids,
+            'input_audio_embeds': inputs.input_audio_embeds,
+            'audio_embed_sizes': inputs.audio_embed_sizes,
+            'input_modes': inputs.input_modes,
+            'answer': answer_text,
+        }
+
+# CoVoST2 Dataset Class
+class CoVoSTDataset(BaseAudioDataset):
+    def __init__(self, processor, data_dir, split, ast=False,
+                 lang=("en_ko", "Korean"), sampling_rate=16000, debug=False):
+        super().__init__(processor, split, sampling_rate, debug)
+
+        self.set_dataset_name("CoVoST")
+
         self.ast = ast
+        self.lang = lang[0]
+
+        self.data = load_dataset("junnei/covost2",
+                    lang[0],
+                    data_dir=data_dir,
+                    split=split,
+                    trust_remote_code=True
+                )
 
-        print(f"- {self.lang}: {('AST' if self.ast else 'ASR')}")
-        print(f"Original data count: {original_size}")
-        print(f"Corrupted data count: {original_size - validated_size}")
-        print(f"Filtering ratio: {validated_size/original_size:.2%}")
+        text_fields = ["sentence", "translation"] if ast else ["sentence"]
+        self.data = self.filter_corrupted_files(self.data, "audio", text_fields, "CoVoST")
+
+        # (Optional) Audio length Filtering
+        self.data = self.filter_by_audio_length(self.data, "audio")
 
-        self.processor = processor
+        # Instruction Setting
         self.instruction = INSTRUCTION["ast"].format(lang[1]) if ast else INSTRUCTION["asr"]
-
+
     def __len__(self):
         return len(self.data)
+
+    def __getitem__(self, idx):
+        data = self.data[idx]
+
+        if self.ast:
+            answer_text = data["translation"]
+        else:
+            answer_text = data["sentence"].replace('"', '')
+
+        return self.prepare_model_inputs(
+            data["audio"]["array"],
+            self.instruction,
+            answer_text
+        )
 
+
+# Libri Speech Dataset Class
+class LibriSpeechDataset(BaseAudioDataset):
+    def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
+        super().__init__(processor, split, sampling_rate, debug)
+
+        self.set_dataset_name(f"LibriSpeech_{subset}")
+
+        # only ASR
+        self.ast = False
+        self.lang = "en"
+
+        if split == "train":
+            split = "train.360"
+
+        # load dataset
+        self.data = load_dataset("fixie-ai/librispeech_asr",
+                    subset,
+                    split=split,
+                    trust_remote_code=True
+                )
+
+        # (Optional) Audio length Filtering
+        self.data = self.filter_by_audio_length(self.data, "audio")
+
+        # Instruction Setting
+        self.instruction = INSTRUCTION["asr"]
+
+    def __len__(self):
+        return len(self.data)
+
     def __getitem__(self, idx):
         data = self.data[idx]
-        user_message = {
-            'role': 'user',
-            'content': '<start_of_audio>' + self.instruction,
-        }
-        prompt = self.processor.tokenizer.apply_chat_template(
-            [user_message], tokenize=False, add_generation_prompt=True, add_bos=True
+
+        # Libri Speech is only for ASR
+        answer_text = data["text"].replace('"', '')
+
+        return self.prepare_model_inputs(
+            data["audio"]["array"],
+            self.instruction,
+            answer_text
         )
-        inputs = self.processor(text=prompt, audio=[data["audio"]["array"]], add_special_tokens=False, return_tensors='pt')
-        sentence = data['sentence'].replace('"', '')
-        answer = f"{data['translation'] if self.ast else sentence}"
 
-        return {
-            'input_ids': inputs.input_ids,
-            'attention_mask': inputs.attention_mask,
-            'token_type_ids': inputs.token_type_ids,
-            'input_modes': inputs.input_modes,
-            'input_audio_embeds': inputs.input_audio_embeds,
-            'audio_embed_sizes': inputs.audio_embed_sizes,
-            'sentence': sentence,
-            'answer': answer,
+# Fleurs Dataset Class
+class FleursDataset(BaseAudioDataset):
+    def __init__(self, processor, split, source_lang, target_lang=None,
+                 mode="asr", sampling_rate=16000, debug=False):
+        super().__init__(processor, split, sampling_rate, debug)
+
+        self.set_dataset_name("Fleurs")
+
+        # Mode Setting (ASR or AST)
+        if mode not in ["asr", "ast"]:
+            raise ValueError("mode must be 'asr' or 'ast'.")
+
+        self.mode = mode
+        self.ast = (mode == "ast")
+        self.source_lang = source_lang
+
+        # Language name mapping (expand if needed)
+        self.lang_names = {
+            'en_us': 'English', 'ko_kr': 'Korean'
         }
+
+        # load dataset - source language dataset
+        self.data = load_dataset("google/fleurs",
+                    source_lang,
+                    split=split,
+                    trust_remote_code=True
+                )
+
+        # (Optional) Audio length Filtering
+        self.data = self.filter_by_audio_length(self.data, "audio")
+
+        # When AST mode, load target language dataset.
+        if self.ast:
+            if target_lang is None:
+                raise ValueError("AST mode requires target_lang.")
+
+            self.target_lang = target_lang
+            self.lang = f"{source_lang}_{target_lang}"
+
+            # load dataset - target language dataset (for translation)
+            target_data = load_dataset("google/fleurs",
+                        target_lang,
+                        split=split,
+                        trust_remote_code=True
+                    )
+
+            source_dict = {item['id']: item for item in self.data}
+            target_dict = {item['id']: item for item in target_data}
+
+            # only Common ID, add translation fields
+            common_ids = set(source_dict.keys()) & set(target_dict.keys())
+            print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
+            self.data = [
+                {**source_dict[id], 'translation': target_dict[id]['transcription']}
+                for id in common_ids
+            ]
+
+            # Instruction Setting - use target language name
+            target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
+            self.instruction = INSTRUCTION["ast"].format(target_lang_name)
+        else:
+            # ASR mode
+            self.lang = source_lang
+            self.instruction = INSTRUCTION["asr"]
+
+        if self.debug:
+            print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
+            print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
+            if self.ast:
+                print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
+            print(f"dataset size: {len(self.data)}")
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        data = self.data[idx]
+        audio_array = data["audio"]["array"]
 
-    def select(self, indices):
-        self.data = self.data.select(indices)
-        return self
+        if self.ast:
+            answer_text = data["translation"]
+        else:
+            answer_text = data["transcription"]
+
+        return self.prepare_model_inputs(
+            audio_array,
+            self.instruction,
+            answer_text
+        )
 
-def pad_sequence(sequences, padding_side='right', padding_value=0):
+def pad_sequence(sequences, padding_side='left', padding_value=0):
     """
     Pad a list of sequences to the same length.
     sequences: list of tensors in [seq_len, *] shape
@@ -164,7 +370,6 @@ def covost_collate_fn(batch):
     audio_embed_sizes_list = []
     audio_attention_mask_list = []
    input_modes_list = []
-    sentence_list = []
     answer_list = []
     for inputs in batch:
         input_ids_list.append(inputs['input_ids'][0])
@@ -174,7 +379,6 @@ def covost_collate_fn(batch):
             inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
         )
         input_modes_list.append(inputs['input_modes'])
-        sentence_list.append(inputs['sentence'])
         answer_list.append(inputs['answer'])
 
     try:
@@ -202,14 +406,13 @@ def covost_collate_fn(batch):
             'audio_embed_sizes': audio_embed_sizes,
             'audio_attention_mask': audio_attention_mask,
             'input_modes': input_modes,
-            'sentence': sentence_list,
             'answer': answer_list,
         }
     )
 
-def save_results(results, task, source_lang, target_lang=None, sample_idx=None):
+def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None):
     """Save results to a JSON file"""
-    filename = f"{task}_{source_lang}"
+    filename = f"{task}_{dataset_name}_{source_lang}"
     if target_lang:
         filename += f"_to_{target_lang}"
     if sample_idx is not None:
@@ -244,7 +447,6 @@ def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size
 
     # Process in batches
     for batch_idx, batch in enumerate(tqdm(dataloader)):
-        batch_sentences = batch.pop("sentence")
         batch_references = batch.pop("answer")
 
         # Move to GPU
@@ -253,7 +455,10 @@ def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size
 
         # Batch inference
        with torch.inference_mode():
-            generate_ids = model.generate(**batch, max_new_tokens=256, do_sample=False)
+            generate_ids = model.generate(**batch,
+                max_new_tokens=256,
+                #temperature = 1.0, top_p = 0.95, top_k = 64, do_sample=True
+            )
 
         input_lengths = batch['input_ids'].shape[1]
         generate_ids = generate_ids[:, input_lengths:]
@@ -264,11 +469,10 @@ def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size
         )
 
         # Save results
-        for i, (sentence, reference, prediction) in enumerate(zip(batch_sentences, batch_references, batch_predictions)):
+        for i, (reference, prediction) in enumerate(zip(batch_references, batch_predictions)):
             idx = batch_idx * batch_size + i
             sample_result = {
                 "id": idx,
-                "sentence": sentence,
                 "reference": reference,
                 "prediction": prediction
             }
@@ -329,7 +533,7 @@ def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size
                 "num_samples": len(temp_results),
                 "sample_results": temp_results
            }
-            save_results(partial_results, task_type, source_lang, target_lang)
+            save_results(partial_results, dataset.name, task_type, source_lang, target_lang)
 
     for item in sample_results:
         ref = eval_normalizer(item["reference"])
@@ -351,6 +555,7 @@ def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size
     avg_wer = sum(item["wer"] for item in sample_results) / len(sample_results)
 
     results = {
+        "dataset": dataset.name,
        "task": task_type,
         "source_lang": source_lang,
         "target_lang": target_lang,
@@ -364,60 +569,118 @@ def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size
     }
 
     # Save final results
-    save_results(results, task_type, source_lang, target_lang)
+    save_results(results, dataset.name, task_type, source_lang, target_lang)
     return results
 
 # Main execution code
 if __name__ == "__main__":
     # Languages to evaluate (source languages)
     source_languages = [
-        ("en_us", "English"), # English (US)
         #("ko_kr", "Korean"),
+        ("en_us", "English"), # English (US)
     ]
 
     # Target languages for translation (code, name)
     target_languages = [
-        ("ko_kr", "Korean"),
         #("en_us", "English"),
+        ("ko_kr", "Korean"),
    ]
 
     data_dir = {
-        "en_us" : "/workspace/CommonVoice/EN",
         #"ko_kr" : "/workspace/CommonVoice/ko",
+        "en_us" : "/workspace/CommonVoice/EN",
     }
 
     # Number of samples (-1 uses the full dataset)
     num_samples = -1
-    batch_size = 16
+    batch_size = 32
 
     # ASR evaluation for every source language
     for source_lang, target_lang in zip(source_languages, target_languages):
        print(f"\n===== {source_lang[0]} ASR evaluation start =====")
 
         # Load dataset
-        covost = CoVoSTDataset(processor, data_dir[source_lang[0]], ast=False, lang=(f"{source_lang[0].split('_')[0]}_{target_lang[0].split('_')[0]}", f"{target_lang[1]}"))
+        split = "test"
 
-        # ASR evaluation
-        asr_results = evaluate_task(covost, source_lang[0], target_lang[0], num_samples, batch_size=batch_size, is_asr = True)
-
-        print(f"\n=== {source_lang[0]} ASR results ===")
-        print(f"BLEU: {asr_results.get('metrics', {}).get('bleu', 'N/A')}")
-        print(f"WER: {asr_results.get('metrics', {}).get('wer', 'N/A')}")
-        print(f"CER: {asr_results.get('metrics', {}).get('cer', 'N/A')}")
+        datasets = []
+
+        # Covost ASR mode (English -> English text)
+        covost = CoVoSTDataset(
+            processor=processor,
+            data_dir="/workspace/CommonVoice/EN",
+            split=split,
+            ast=False,
+            lang=("en_ko", "Korean")
+        )
+        datasets.append(covost)
+
+        # Libri Speech Clean ASR mode (English -> English text)
+        libri_speech_clean = LibriSpeechDataset(
+            processor=processor,
+            subset="clean",
+            split=split
+        )
+        datasets.append(libri_speech_clean)
+
+        # Libri Speech Other ASR mode (English -> English text)
+        libri_speech_other = LibriSpeechDataset(
+            processor=processor,
+            subset="other",
+            split=split
+        )
+        datasets.append(libri_speech_other)
+
+        # Fleurs ASR mode (English -> English text)
+        fleurs = FleursDataset(
+            processor=processor,
+            split=split,
+            source_lang="en_us", # English
+            mode="asr"
+        )
+        datasets.append(fleurs)
+
+        for dataset in datasets:
+            # ASR evaluation
+            asr_results = evaluate_task(dataset, source_lang[0], target_lang[0], num_samples, batch_size=batch_size, is_asr = True)
 
+            print(f"\n=== {asr_results.get('dataset', 'Dataset')} | {source_lang[0]} ASR results ===")
+            print(f"BLEU: {asr_results.get('metrics', {}).get('bleu', 'N/A')}")
+            print(f"WER: {asr_results.get('metrics', {}).get('wer', 'N/A')}")
+            print(f"CER: {asr_results.get('metrics', {}).get('cer', 'N/A')}")
+
         try:
             print(f"\n===== {source_lang[0]} -> {target_lang[0]} translation evaluation start =====")
-
-            # Load dataset
-            covost = CoVoSTDataset(processor, data_dir[source_lang[0]], ast=True, lang=(f"{source_lang[0].split('_')[0]}_{target_lang[0].split('_')[0]}", f"{target_lang[1]}"))
+
+            datasets = []
+
+            # Covost AST mode (English -> Korean text)
+            covost = CoVoSTDataset(
+                processor=processor,
+                data_dir="/workspace/CommonVoice/EN",
+                split=split,
+                ast=True,
+                lang=("en_ko", "Korean")
+            )
+            datasets.append(covost)
+
+            # Fleurs AST mode (English -> Korean text)
+            fleurs = FleursDataset(
+                processor=processor,
+                split=split,
+                source_lang="en_us", # English
+                target_lang="ko_kr", # Korean
+                mode="ast"
+            )
+            datasets.append(fleurs)
 
-            # Translation evaluation
-            translation_results = evaluate_task(covost, source_lang[0], target_lang[0], num_samples, batch_size=batch_size, is_asr = False)
-
-            print(f"\n=== {source_lang[0]} -> {target_lang[0]} translation results ===")
-            print(f"BLEU: {translation_results.get('metrics', {}).get('bleu', 'N/A')}")
-            print(f"WER: {translation_results.get('metrics', {}).get('wer', 'N/A')}")
-            print(f"CER: {translation_results.get('metrics', {}).get('cer', 'N/A')}")
+            for dataset in datasets:
+                # Translation evaluation
+                translation_results = evaluate_task(dataset, source_lang[0], target_lang[0], num_samples, batch_size=batch_size, is_asr = False)
+
+                print(f"\n=== {translation_results.get('dataset', 'Dataset')} | {source_lang[0]} -> {target_lang[0]} translation results ===")
+                print(f"BLEU: {translation_results.get('metrics', {}).get('bleu', 'N/A')}")
+                print(f"WER: {translation_results.get('metrics', {}).get('wer', 'N/A')}")
+                print(f"CER: {translation_results.get('metrics', {}).get('cer', 'N/A')}")
 
         except Exception as e:
             error_info = {
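
For context, a minimal usage sketch of the refactored evaluation entry points follows. It is not part of the commit: it assumes it runs inside examples/evaluate_speech.py (where LibriSpeechDataset, FleursDataset, and evaluate_task are defined), that the processor is obtained with AutoProcessor.from_pretrained and trust_remote_code=True to match the AutoModel call at the top of the script, and that the referenced datasets can be downloaded.

from transformers import AutoModel, AutoProcessor

model_id = "junnei/gemma-3-4b-it-speech"

# Assumption: the processor is loaded the same way the model is loaded in the diff above.
processor = AutoProcessor.from_pretrained(model_id, revision="main", trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, device_map="auto", revision="main", trust_remote_code=True)

# ASR on the LibriSpeech "clean" test split, using the new dataset class.
libri_clean = LibriSpeechDataset(processor=processor, subset="clean", split="test")
asr_results = evaluate_task(libri_clean, "en_us", "ko_kr",
                            num_samples=-1, batch_size=32, is_asr=True)

# AST (English speech -> Korean text) on FLEURS.
fleurs_ast = FleursDataset(processor=processor, split="test",
                           source_lang="en_us", target_lang="ko_kr", mode="ast")
ast_results = evaluate_task(fleurs_ast, "en_us", "ko_kr",
                            num_samples=-1, batch_size=32, is_asr=False)

# save_results() now includes the dataset in the JSON filename, e.g. f"{task}_{dataset.name}_{source_lang}".
print(asr_results.get("metrics"), ast_results.get("metrics"))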