cosrigel committed
Commit 9b33e27 · verified · 1 Parent(s): f35a87c

Delete dia/finetune.py

Files changed (1):
  1. dia/finetune.py +0 -787
dia/finetune.py DELETED
@@ -1,787 +0,0 @@
import argparse
import gc
import json
import logging
import math
import os
import random
from dataclasses import dataclass
from pathlib import Path

import bitsandbytes as bnb
import torch
import torch.nn.functional as F
from datasets import (
    DatasetDict,
    disable_caching,
    interleave_datasets,
    load_dataset,
    load_from_disk,
)
from huggingface_hub import hf_hub_download
from torch.cuda.amp import GradScaler, autocast
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import get_scheduler

import dac
from .audio import apply_audio_delay, build_delay_indices
from .config import DiaConfig
from .dataset import HFDiaDataset, HFDiaIterDataset, LocalDiaDataset
from .layers import DiaModel
from .model import Dia

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

# Byte values used to replace language/speaker tags in the text
LANG2BYTE = {
    "en": 3,
    "vi": 19,
}

CHANNELS = [
    "5phutcrypto",
    "anhbanthan",
    "anhthamtu",
    "animerewind.official",
    "bibitv8888",
    "btvgo",
    "baclieutv",
    "bachhoaxanhcom",
    "baodientuvov",
    "blvckvines",
    "boringppl",
    "bronub",
    "cdteam-why",
    "cobabinhduong",
    "cosmicwriter",
    "cuthongthai",
    "daiphatthanhtruyenhinhsonla",
    "day-be-thong-minh-tv",
    "danangtv",
    "daihanoi-htv",
    "daiptththainguyentntv",
    "dongmauviet",
    "dongthaptv",
    "fptbongdaofficial",
    "fonosvietnam",
    "hieurotrong5phut-ntkt",
    "htvtintuc",
    "happyhidari",
    "hoabinhtvgo",
    "hocenglishonline",
    "hocvienbovagau",
    "hungyentvvngo",
    "huynhduykhuongofficial",
    "huynhlapofficial",
    "jvevermind",
    "kenhvtc16",
    "kiengiangtv",
    "khanhvyofficial",
    "kienthucquansu",
    "lamdongtv1",
    "lamvlog",
    "longantv-la34",
    "mangovid",
    "mensbay",
    "meovatcuocsonglnv",
    "meuchannel",
    "ntnvlogsnguyenthanhnam",
    "ngamradio",
    "nhanhac555",
    "nhantaidaiviet",
    "ptth-trt",
    "ptvtruyenhinhphutho",
    "phantichgame",
    "phephim",
    "phimhottk-l",
    "riwaylegal",
    "ruangao",
    "suckhoetamsinh",
    "sachbiquyethanhcong",
    "soisangbrightsidevietnamese",
    "spiderum",
    "spiderumbooks",
    "sukieskitchen",
    "tin3phut",
    "tranthanhtown",
    "tulemientay",
    "tayninhtv",
    "thainhitv",
    "thanhpahm",
    "thegioilaptop",
    "thepresentwriter",
    "tiengiangtivi",
    "tieubaobaothom",
    "tintucbitcoin247",
    "truyenhinhbinhphuoc-bptv",
    "truyenhinhyenbaiytv",
    "truyenhinhcaobang",
    "truyenhinhdaklakdrt",
    "truyenhinhdaknong1",
    "truyenhinhdienbien23.9",
    "truyenhinhkhanhhoa",
    "truyenhinhkontumkrt",
    "truyenhinhnaminhntv",
    "truyenhinhninhthuan",
    "truyenhinhquangngai",
    "tuantienti2911",
    "tuyenquangttv",
    "vovlivedoctruyen",
    "vietcetera",
    "vinhlongtv",
    "voizfm",
    "vutrunguyenthuy",
    "vuive",
    "w2wanime",
    "w2wcartoon",
    "w2whorror",
    "w2wmovie",
    "web5ngay",
    "xanh24h",
    "aiphatthanhtruyenhinhquangtri",
    "aiphatthanhvatruyenhinhhai1908",
    "altonghop",
    "antvtruyenhinhcongannhandan",
    "baihoc10phut",
    "battlecry.khampha",
    "betterversionvn",
    "blogkhoinghiep",
    "bumcn",
    "caikinhdi_vn",
    "canthitg",
    "chanthienmybachnien",
    "chauanhchao",
    "cosu",
    "cungmaivaobep-monan-amthuc",
    "daiptthphuyen",
    "daiptthtv",
    "daitruyenhinhangiang",
    "daitruyenhinhbacgiang",
    "dannytran2375",
    "daybehoc5489",
    "daylaphegame",
    "dienmay",
    "ducisreal",
    "duongfg",
    "duyluandethuong",
    "duythanhish",
    "elroydevops",
    "gc.gamelab",
    "hacthaybachthay",
    "hagiangtv475",
    "haiduongtv247",
    "hanamtv8831",
    "hangphimtailieudienanhnd",
    "haugiangtv",
    "haunauday",
    "hieu-tv",
    "hoshiphan",
    "jakinatsumi2915",
    "kechuyentieuhoc1719",
    "kenhcovan",
    "khalid_dinh",
    "kiaralah",
    "laichautv",
    "langsontvtube",
    "megame_official",
    "minvestvn",
    "nguoithanhcong1991",
    "nhatkycuocsong.",
    "ntcanima",
    "ptthbentre",
    "ptthquangbinh",
    "qrt",
    "quangninhtv",
    "snewsvn",
    "soctrangtv",
    "sunhuynpodcast",
    "tamhonanuong",
    "tgddreview",
    "thaibinhtv",
    "thanhnamedu",
    "thanhnientvnews",
    "thbrt",
    "thieunhitv3630",
    "thtpct",
    "tinnhanh3phut868",
    "toansam",
    "toidicodedaoblog",
    "tranquochuywecommit",
    "tranvyvy",
    "truyenhinh4k",
    "truyenhinhbinhthuan",
    "truyenhinhcamau69",
    "truyenhinhdongnai_dnrtv",
    "truyenhinhgialai",
    "truyenhinhlaocai",
    "truyenhinhnghean",
    "truyenhinhvinhphuc",
    "txtofficial8798",
    "vanhkhuyenle",
    "vietnh1009",
    "visaothenhipodcast",
    "vtc14",
    "vtcnow",
    "vtv24",
    "vuive123",
    "zombiev4",
]

# Automatically map each channel name to a token byte (starting at 30)
for i, ch in enumerate(CHANNELS):
    LANG2BYTE[ch] = 30 + i

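# Example of the tag replacement performed in collate_fn below (a sketch):
# the text "[vi] Xin chào" is UTF-8 encoded and its leading "[vi]" tag is
# collapsed to the single byte 19, i.e. bytes([19]) + " Xin chào".encode("utf-8");
# a speaker tag such as "[vtv24]" likewise maps to 30 + CHANNELS.index("vtv24").
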
test_sentences = {
    "en": "In order to fully assess performance and the accuracy of language tags, this test sentence contains multiple subordinate clauses, varied punctuation, and a sufficient word count.",
    "vi": "Để đánh giá toàn diện hiệu suất và độ chính xác của các thẻ ngôn ngữ, câu kiểm tra này chứa nhiều mệnh đề phụ, dấu câu đa dạng và số lượng từ đầy đủ.",
}

@dataclass
class TrainConfig:
    epochs: int = 1
    batch_size: int = 2
    grad_accum_steps: int = 2
    learning_rate: float = 1e-5
    warmup_steps: int = 500
    unconditional_frac: float = 0.15
    eval_step: int = 200
    save_step: int = 2000
    split_ratio: float = 0.997
    shuffle_buffer_size: int = None  # for streaming shuffle
    seed: int = 42                   # seed for reproducibility
    runs_dir: Path = Path("runs")
    run_name: str = "dia_finetune_cv"
    output_dir: Path = Path(".cpkts/dia_finetune_cv")
    resume_from: Path = None
    total_steps: int = 290007

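# With the defaults above, the effective batch size per optimizer update is
# batch_size * grad_accum_steps = 4, and the optimizer, scheduler, and GradScaler
# all step once per grad_accum_steps micro-batches (see train_step below).
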
def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train the Dia audio model")
    parser.add_argument("--config", type=Path, default=Path("dia/config.json"))
    parser.add_argument("--dataset", type=str, default="Paradoxia/opendata-iisys-hui",
                        help="HuggingFace dataset name (if not using --csv_path).")
    parser.add_argument("--dataset2", type=str, default=None,
                        help="(Optional) second HF dataset to interleave (streaming)")
    parser.add_argument("--streaming", action="store_true",
                        help="Enable HuggingFace dataset streaming")
    parser.add_argument("--hub_model", type=str, default="nari-labs/Dia-1.6B")
    parser.add_argument("--local_ckpt", type=str, default=None)
    parser.add_argument("--csv_path", type=Path, default=None,
                        help="Path to local CSV/TSV file with `audio|text` (if you want to train locally).")
    parser.add_argument("--audio_root", type=Path, default=None,
                        help="Root directory for local audio files (required if --csv_path is set).")
    parser.add_argument("--run_name", type=str, default=None)
    parser.add_argument("--output_dir", type=Path, default=None)
    parser.add_argument("--shuffle_buffer_size", type=int, default=None,
                        help="Buffer size for streaming dataset shuffle.")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility.")
    parser.add_argument("--half", action="store_true", help="load model in fp16")
    parser.add_argument("--compile", action="store_true", help="torch compile model")
    parser.add_argument("--use_amp", action="store_true", help="Enable mixed precision")
    parser.add_argument("--resume_from", type=str, default=None)
    return parser.parse_args()

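# A hypothetical invocation (dataset and paths are placeholders, assuming the
# package layout makes this file runnable as a module):
#   python -m dia.finetune --dataset Paradoxia/opendata-iisys-hui --streaming \
#       --run_name my_run --output_dir ckpts/my_run --seed 42
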
def collate_fn(batch, config: DiaConfig, device: torch.device):
    texts, encodings, waveforms = zip(*batch)

    # -- Text inputs ---------------------------------------------------------
    max_text = config.data.text_length
    pad_tok = config.data.text_pad_value
    text_ids = []
    for txt in texts:
        b_full = txt.encode('utf-8')
        # replace the leading "[lang]" prefix with its single-byte token
        for code, val in LANG2BYTE.items():
            prefix = f"[{code}]".encode('utf-8')
            if b_full.startswith(prefix):
                b_full = bytes([val]) + b_full[len(prefix):]
                break
        bts = b_full[:max_text]
        arr = list(bts) + [pad_tok] * (max_text - len(bts))
        text_ids.append(torch.tensor(arr, dtype=torch.long))
    src = torch.stack(text_ids).to(device)
    src_pos = torch.arange(max_text, device=device).unsqueeze(0).expand(src.size(0), -1)
    src_pad = src.ne(pad_tok)
    enc_self_attn_mask = (src_pad.unsqueeze(2) & src_pad.unsqueeze(1)).unsqueeze(1)

    # -- Audio codes ---------------------------------------------------------
    max_audio = config.data.audio_length
    # per-sample lengths (clipped to max_audio)
    seq_lens = [min(e.size(0), max_audio) for e in encodings]
    batch_max = max(seq_lens)

    # pad or trim each encoding to the batch max length
    padded = [F.pad(e, (0, 0, 0, batch_max - e.size(0))) if e.size(0) < batch_max else e[:batch_max]
              for e in encodings]
    codes = torch.stack(padded).to(device)  # (B, T=batch_max, C)

    B, T, C = codes.shape
    t_idx, idxs = build_delay_indices(B, T, C, config.data.delay_pattern)
    delayed = apply_audio_delay(
        codes,
        config.data.audio_pad_value,
        config.data.audio_bos_value,
        (t_idx, idxs),
    )
    # ensure no longer than max_audio
    delayed = delayed[:, :max_audio, :]

    # -- Targets with per-sample EOS ----------------------------------------
    max_tgt_len = max_audio + 2
    pad_val = config.data.audio_pad_value
    bos_val = config.data.audio_bos_value
    eos_val = config.data.audio_eos_value

    tgt = torch.full((B, max_tgt_len, C), pad_val, dtype=torch.long, device=device)
    tgt[:, 0, :] = bos_val
    tgt_lens = []
    for i, L in enumerate(seq_lens):
        tgt[i, 1:1 + L, :] = delayed[i, :L, :]
        tgt[i, 1 + L, :] = eos_val
        tgt_lens.append(1 + L + 1)  # BOS + frames + EOS

    tgt_pos = torch.arange(max_tgt_len, device=device).unsqueeze(0).expand(B, -1)
    tgt_pad = tgt.ne(pad_val).any(-1)

    causal = torch.tril(torch.ones((max_tgt_len, max_tgt_len),
                                   dtype=torch.bool,
                                   device=device))
    dec_self_attn_mask = (tgt_pad.unsqueeze(2) & tgt_pad.unsqueeze(1) & causal).unsqueeze(1)
    dec_cross_attn_mask = (tgt_pad.unsqueeze(2) & src_pad.unsqueeze(1)).unsqueeze(1)

    return {
        'src_tokens': src,
        'src_positions': src_pos,
        'enc_self_attn_mask': enc_self_attn_mask,
        'tgt_tokens': tgt,
        'tgt_positions': tgt_pos,
        'dec_self_attn_mask': dec_self_attn_mask,
        'dec_cross_attn_mask': dec_cross_attn_mask,
        'waveforms': waveforms,
        'raw_text': texts[0],
        'tgt_lens': torch.tensor(tgt_lens, dtype=torch.long, device=device),
    }

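# Per-sample target layout produced above (a sketch; each slot spans all C channels):
#   position:  0    1 .. L    L+1    L+2 .. max_audio+1
#   content:  BOS   frames    EOS    PAD
# hence tgt_lens[i] = 1 + L + 1 counts BOS + delayed frames + EOS.
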
def setup_loaders(dataset, dia_cfg: DiaConfig, train_cfg: TrainConfig, device):
    collate = lambda b: collate_fn(b, dia_cfg, device)
    if isinstance(dataset, HFDiaIterDataset):
        total = getattr(dataset, "total_examples", None)
        if total is None:
            total = dataset.dataset.info.splits["train"].num_examples
        n_train = int(train_cfg.split_ratio * total)
        n_val = total - n_train
        if n_val <= 0:
            raise RuntimeError(f"No validation samples: total={total}, split_ratio={train_cfg.split_ratio}")
        base = (dataset.dataset.shuffle(buffer_size=train_cfg.shuffle_buffer_size, seed=train_cfg.seed)
                if train_cfg.shuffle_buffer_size else dataset.dataset)
        val_stream = base.take(n_val)
        train_stream = base.skip(n_val)
        train_ds = HFDiaIterDataset(train_stream, dia_cfg, dataset.dac_model)
        val_ds = HFDiaIterDataset(val_stream, dia_cfg, dataset.dac_model)
        train_loader = DataLoader(train_ds, batch_size=train_cfg.batch_size, shuffle=False, collate_fn=collate)
        train_loader.steps_per_epoch = math.ceil(n_train / train_cfg.batch_size)
        val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=collate)
        return train_loader, val_loader
    ds_len = len(dataset)
    n_train = int(train_cfg.split_ratio * ds_len)
    train_ds, val_ds = random_split(dataset, [n_train, ds_len - n_train])
    train_loader = DataLoader(train_ds, batch_size=train_cfg.batch_size, shuffle=True, collate_fn=collate)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=collate)
    return train_loader, val_loader

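# Note: for streaming datasets the split is positional: take(n_val) serves the
# first n_val examples of the (optionally shuffled) stream as validation, and
# skip(n_val) trains on the rest, so the validation set depends on the shuffle seed.
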
def setup_optimizer_and_scheduler(model, train_loader, train_cfg):
    opt = bnb.optim.AdamW8bit(model.parameters(), lr=train_cfg.learning_rate)
    # Determine steps per epoch: prefer len(), else use the attached attribute
    try:
        steps_per_epoch = len(train_loader)
    except TypeError:
        if hasattr(train_loader, 'steps_per_epoch'):
            steps_per_epoch = train_loader.steps_per_epoch
        else:
            raise RuntimeError("Cannot determine steps_per_epoch for streaming loader")
    total_training_steps = steps_per_epoch * train_cfg.epochs
    # The scheduler steps once per optimizer update, so both horizons are
    # expressed in optimizer steps (batches divided by grad_accum_steps).
    sched = get_scheduler(
        'cosine', opt,
        num_warmup_steps=train_cfg.warmup_steps // train_cfg.grad_accum_steps,
        num_training_steps=total_training_steps // train_cfg.grad_accum_steps,
    )
    return opt, sched

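# AdamW8bit from bitsandbytes keeps optimizer state in 8-bit precision, which
# substantially reduces optimizer memory versus 32-bit AdamW; relevant when
# finetuning a 1.6B-parameter model on a single GPU.
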
def train_step(model, batch, dia_cfg, train_cfg, opt, sched, writer, step_in_epoch, global_step, scaler):
    """
    Perform a single training step: forward, loss, backward, update, log.
    Uses per-sample tgt_lens to mask out padding after each EOS,
    and applies a 4x loss weight on the first channel.
    """
    # (optional) drop the text condition for a fraction of steps
    # (classifier-free-guidance-style unconditional training)
    if random.random() < train_cfg.unconditional_frac:
        batch['src_tokens'] = torch.zeros_like(batch['src_tokens'])
        batch['enc_self_attn_mask'] = torch.zeros_like(batch['enc_self_attn_mask'])
        batch['dec_cross_attn_mask'] = torch.zeros_like(batch['dec_cross_attn_mask'])

    with autocast(dtype=torch.float16):
        # forward pass
        logits = model(
            src_BxS=batch['src_tokens'],
            tgt_BxTxC=batch['tgt_tokens'],
            src_positions=batch['src_positions'],
            tgt_positions=batch['tgt_positions'],
            enc_self_attn_mask=batch['enc_self_attn_mask'],
            dec_self_attn_mask=batch['dec_self_attn_mask'],
            dec_cross_attn_mask=batch['dec_cross_attn_mask'],
            enable_dropout=True,
        )
        # fetch per-sample target lengths (BOS + frames + EOS)
        lens = batch['tgt_lens']        # shape: (B,)
        max_L = int(lens.max().item())  # maximum over batch

        # keep only up through the last possible EOS slot
        # logits: (B, T, C, V) -> (B, max_L-1, C, V)
        logits = logits[:, : max_L - 1]

        # targets: shift off the BOS so positions 0..max_L-2 align with logits
        # target: (B, T, C) -> (B, max_L-1, C)
        target = batch['tgt_tokens'][:, 1:max_L, :]

        B, Tm1, C = target.shape
        pad_val = dia_cfg.data.audio_pad_value

        # build a (B, Tm1) mask that is True for t < (lens[i] - 1)
        time_idx = torch.arange(Tm1, device=lens.device).unsqueeze(0)  # (1, Tm1)
        valid_time = time_idx < (lens.unsqueeze(1) - 1)                # (B, Tm1)
        mask = valid_time.unsqueeze(-1).expand(-1, -1, C)              # (B, Tm1, C)

        # apply 4x weight on the first channel, 1x on the others
        channel_weights = [4.0] + [1.0] * (C - 1)
        loss_c = 0.0
        _, _, _, V = logits.size()

        for c, w in enumerate(channel_weights):
            # flatten this channel
            lc = logits[:, :, c, :].reshape(-1, V)  # (B*Tm1, V)
            tc = target[:, :, c].reshape(-1)        # (B*Tm1,)
            mc = mask[:, :, c].reshape(-1)          # (B*Tm1,)

            # mask out padding and compute cross-entropy
            lc_valid = lc[mc]
            tc_valid = tc[mc]
            loss_c += w * F.cross_entropy(
                lc_valid, tc_valid,
                ignore_index=pad_val,
            )

        # normalize by the sum of weights
        loss = loss_c / sum(channel_weights)

    # scale + backward
    loss = loss / train_cfg.grad_accum_steps
    scaler.scale(loss).backward()

    # step & log
    if (step_in_epoch + 1) % train_cfg.grad_accum_steps == 0:
        # unscale before clipping
        scaler.unscale_(opt)
        grad_norm = clip_grad_norm_(model.parameters(), max_norm=1e9)

        scaler.step(opt)
        scaler.update()
        opt.zero_grad()
        sched.step()

        true_loss = loss.item() * train_cfg.grad_accum_steps
        current_lr = sched.get_last_lr()[0]

        writer.add_scalar('GradNorm/global', grad_norm, global_step)
        writer.add_scalar('LR', current_lr, global_step)
        writer.add_scalar('Loss/train', true_loss, global_step)

        return true_loss
    else:
        return loss.item() * train_cfg.grad_accum_steps

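# Masking example (a sketch): for a sample with tgt_lens[i] = 5 (BOS + 3 frames + EOS),
# targets cover positions 1..max_L-1 of the padded sequence, and the condition
# time_idx < lens[i] - 1 = 4 keeps exactly the 3 frames plus the EOS slot, so
# positions after EOS contribute nothing to the loss.
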
def eval_step(model, val_loader, dia_cfg, dac_model, writer, global_step):
    """
    Run evaluation: compute the average loss on the validation set and log audio samples.
    """
    eval_losses = []
    with torch.inference_mode():
        for eb in tqdm(val_loader, desc="eval"):
            with autocast(dtype=torch.float16):
                logits16 = model(
                    src_BxS=eb['src_tokens'],
                    tgt_BxTxC=eb['tgt_tokens'],
                    src_positions=eb['src_positions'],
                    tgt_positions=eb['tgt_positions'],
                    enc_self_attn_mask=eb['enc_self_attn_mask'],
                    dec_self_attn_mask=eb['dec_self_attn_mask'],
                    dec_cross_attn_mask=eb['dec_cross_attn_mask'],
                    enable_dropout=False,
                )[:, :-1]

            logits = logits16.float()
            target = eb['tgt_tokens'][:, 1:]
            B_e, T_e, C_e = target.shape
            V_e = logits.size(-1)

            loss_e = 0.0
            weights_e = [4.0] + [1.0] * (C_e - 1)
            for c, w in enumerate(weights_e):
                lc = logits[:, :, c, :].reshape(-1, V_e)
                tc = target[:, :, c].reshape(-1)
                loss_e += w * F.cross_entropy(
                    lc, tc, ignore_index=dia_cfg.data.audio_pad_value
                )
            loss_e = loss_e / sum(weights_e)

            eval_losses.append(loss_e.item())

    avg_eval_loss = sum(eval_losses) / max(len(eval_losses), 1)
    writer.add_scalar('Loss/eval', avg_eval_loss, global_step)

    # --- Inference on a test sentence ---
    orig_dtype = next(model.parameters()).dtype
    try:
        # generation runs in fp32 for numerical stability
        model = model.float()
        dia_gen = Dia(dia_cfg, device)
        dia_gen.model, dia_gen.dac_model = model, dac_model

        # Test a multi-speaker dialogue
        test_dialogue = "[vtv24] Em vừa đi học về, anh ạ. [duongfg] Ừ, em ăn cơm chưa? [vtv24] Em ăn rồi!"

        try:
            audio = dia_gen.generate(text=test_dialogue)
            writer.add_audio("Eval/test_dialogue", audio, global_step, 44100)
        except Exception:
            logger.exception("Eval error during test_dialogue")
    except Exception:
        logger.exception("Eval error")
    finally:
        if 'audio' in locals():
            del audio
        gc.collect()
        torch.cuda.empty_cache()
        if orig_dtype == torch.float16:
            model = model.half()

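# Unlike train_step, the eval loss relies on ignore_index alone rather than
# per-sample tgt_lens masking; positions after EOS still carry audio_pad_value
# targets and are skipped, so the two losses should be broadly comparable.
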
def train(model, dia_cfg: DiaConfig, dac_model: dac.DAC, dataset, train_cfg: TrainConfig):
    """
    Run the full training loop over epochs.
    """
    # prepare directories
    train_cfg.output_dir.mkdir(parents=True, exist_ok=True)
    (train_cfg.runs_dir / train_cfg.run_name).mkdir(parents=True, exist_ok=True)
    model = model.to(device)

    train_loader, val_loader = setup_loaders(dataset, dia_cfg, train_cfg, device)
    opt, sched = setup_optimizer_and_scheduler(model, train_loader, train_cfg)

    writer = SummaryWriter(train_cfg.runs_dir / train_cfg.run_name)
    model.train()
    scaler = GradScaler()
    start_epoch = 0
    global_step = 0
    resume_ckpt = getattr(train_cfg, "resume_from", None)
    if resume_ckpt and resume_ckpt.exists():
        logger.info(f"Resuming from checkpoint: {resume_ckpt}")
        checkpoint = torch.load(resume_ckpt, map_location=device)
        model.load_state_dict(checkpoint["model"])
        opt.load_state_dict(checkpoint["optimizer"])
        sched.load_state_dict(checkpoint["scheduler"])
        scaler.load_state_dict(checkpoint["scaler"])
        start_epoch = checkpoint.get("epoch", 0)
        global_step = checkpoint.get("global_step", 0)

    steps_per_epoch = getattr(train_loader, 'steps_per_epoch', None)
    if steps_per_epoch is None:
        try:
            steps_per_epoch = len(train_loader)
        except Exception:
            steps_per_epoch = None

    for epoch in range(start_epoch, train_cfg.epochs):
        # iterate with a progress bar, using the per-epoch total if known
        pbar = tqdm(
            train_loader,
            desc=f"E{epoch + 1}",
            total=steps_per_epoch if steps_per_epoch is not None else train_cfg.total_steps,
        )
        for step_in_epoch, batch in enumerate(pbar):
            global_step += 1
            # training step
            loss = train_step(model, batch, dia_cfg, train_cfg, opt, sched, writer, step_in_epoch, global_step, scaler)

            if torch.cuda.is_available():
                cur_gb = torch.cuda.memory_allocated() / 1024**3       # GB currently allocated by tensors
                peak_gb = torch.cuda.max_memory_allocated() / 1024**3  # peak GB since the last reset
                pbar.set_postfix({
                    'loss': f"{loss:.4f}",
                    'VRAM (GB)': f"{cur_gb:.2f}/{peak_gb:.2f}",
                })
                # reset so the peak is per-step rather than per-run
                torch.cuda.reset_peak_memory_stats()
            else:
                pbar.set_postfix({'loss': f"{loss:.4f}"})

            # evaluation
            if step_in_epoch % train_cfg.eval_step == 0:
                model.eval()
                with torch.no_grad():
                    eval_step(model, val_loader, dia_cfg, dac_model, writer, global_step)
                model.train()

            # checkpoint
            if step_in_epoch and step_in_epoch % train_cfg.save_step == 0:
                ckpt = train_cfg.output_dir / f"ckpt_step{global_step}.pth"
                torch.save({
                    "model": model.state_dict(),
                    "optimizer": opt.state_dict(),
                    "scheduler": sched.state_dict(),
                    "scaler": scaler.state_dict(),
                    "epoch": epoch,
                    "global_step": global_step,
                }, ckpt)
                logger.info(f"Saved checkpoint: {ckpt}")

        # end-of-epoch checkpoint
        ckpt_e = train_cfg.output_dir / f"ckpt_epoch{epoch + 1}.pth"
        torch.save({
            "model": model.state_dict(),
            "optimizer": opt.state_dict(),
            "scheduler": sched.state_dict(),
            "scaler": scaler.state_dict(),
            "epoch": epoch + 1,
            "global_step": global_step,
        }, ckpt_e)
        logger.info(f"Saved end-of-epoch checkpoint: {ckpt_e}")

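# Resuming later (a hypothetical invocation; the checkpoint path is illustrative):
#   python -m dia.finetune --resume_from .cpkts/dia_finetune_cv/ckpt_step2000.pth
# This restores the model, optimizer, scheduler, and GradScaler state saved above.
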
def main():
    args = get_args()
    os.environ["HF_DATASETS_CACHE"] = "/tmp/force_streaming"  # force a fresh cache location
    disable_caching()  # disable all local HuggingFace dataset caching

    with open(args.config, "r", encoding="utf-8") as f:
        config_dict = json.load(f)

    dia_cfg = DiaConfig(**config_dict)
    dac_model = dac.DAC.load(dac.utils.download()).to(device)

    if args.csv_path:
        if not args.audio_root:
            raise ValueError("`--audio_root` must be set when using `--csv_path`")
        dataset = LocalDiaDataset(args.csv_path, args.audio_root, dia_cfg, dac_model)
    else:
        # check whether --dataset points at a local path
        if Path(args.dataset).exists():
            print(f"Loading dataset from local path: {args.dataset}")
            ds1 = load_from_disk(args.dataset)
            if isinstance(ds1, DatasetDict):
                ds1 = ds1["train"]
            dataset = HFDiaDataset(ds1, dia_cfg, dac_model)
        else:
            print(f"Loading HuggingFace dataset: {args.dataset} (streaming)")
            ds1 = load_dataset(args.dataset, split="train", streaming=True)
            if args.dataset2:
                ds2 = load_dataset(args.dataset2, split="train", streaming=True)
                hf_ds = interleave_datasets([ds1, ds2])
            else:
                hf_ds = ds1
            dataset = HFDiaIterDataset(hf_ds, dia_cfg, dac_model)

    train_cfg = TrainConfig(
        run_name=args.run_name or TrainConfig.run_name,
        output_dir=args.output_dir or TrainConfig.output_dir,
        shuffle_buffer_size=args.shuffle_buffer_size,
        seed=args.seed,
    )
    if args.resume_from:
        train_cfg.resume_from = Path(args.resume_from)

    # load the model checkpoint
    if args.local_ckpt:
        ckpt_file = args.local_ckpt
    else:
        ckpt_file = hf_hub_download(args.hub_model, filename="dia-v0_1.pth")
    model = DiaModel(dia_cfg)
    if args.half:
        model = model.half()

    ckpt = torch.load(ckpt_file, map_location="cpu")
    state_dict = ckpt["model"] if "model" in ckpt else ckpt
    new_state_dict = {}
    for k, v in state_dict.items():
        if "encoder.embedding.weight" in k:
            if v.shape != model.state_dict()[k].shape:
                print(f"Skipping {k} due to shape mismatch: {v.shape} vs {model.state_dict()[k].shape}")
                continue
        new_state_dict[k] = v
    model.load_state_dict(new_state_dict, strict=False)

    # compile after loading so the checkpoint keys match the uncompiled module
    if args.compile:
        model = torch.compile(model, backend="inductor")

    # start training
    train(model, dia_cfg, dac_model, dataset, train_cfg)


if __name__ == "__main__":
    main()