Spaces:
Running
Running
Delete dia/finetune.py
Browse files- dia/finetune.py +0 -787
dia/finetune.py
DELETED
|
@@ -1,787 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import logging
|
| 3 |
-
import os
|
| 4 |
-
import random
|
| 5 |
-
import tempfile
|
| 6 |
-
from dataclasses import dataclass
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
import torchaudio
|
| 11 |
-
import pandas as pd
|
| 12 |
-
from torch.utils.data import Dataset, DataLoader, random_split
|
| 13 |
-
from torch.cuda.amp import autocast
|
| 14 |
-
from torch.utils.tensorboard import SummaryWriter
|
| 15 |
-
from torch.nn.utils import clip_grad_norm_
|
| 16 |
-
from transformers import get_scheduler
|
| 17 |
-
import torch.nn.functional as F
|
| 18 |
-
import bitsandbytes as bnb
|
| 19 |
-
from tqdm import tqdm
|
| 20 |
-
from datasets import load_dataset, interleave_datasets, get_dataset_config_names, DatasetDict
|
| 21 |
-
from huggingface_hub import hf_hub_download
|
| 22 |
-
import math
|
| 23 |
-
import gc
|
| 24 |
-
from torch.cuda.amp import GradScaler
|
| 25 |
-
|
| 26 |
-
import dac
|
| 27 |
-
from .config import DiaConfig
|
| 28 |
-
from .layers import DiaModel
|
| 29 |
-
from .model import Dia
|
| 30 |
-
from .audio import build_delay_indices, apply_audio_delay
|
| 31 |
-
from .dataset import *
|
| 32 |
-
from .interleaved_datasets import load_cml_tts_streamed, load_common_voice17_streamed
|
| 33 |
-
from datasets import load_from_disk
|
| 34 |
-
from .dataset import HFDiaDataset
|
| 35 |
-
from tqdm import tqdm
|
| 36 |
-
|
| 37 |
-
# Configure logging
|
| 38 |
-
logging.basicConfig(
|
| 39 |
-
level=logging.INFO,
|
| 40 |
-
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
| 41 |
-
)
|
| 42 |
-
logger = logging.getLogger(__name__)
|
| 43 |
-
|
| 44 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 45 |
-
torch.backends.cudnn.benchmark = True
|
| 46 |
-
|
| 47 |
-
#bytes for language tag replacement
|
| 48 |
-
LANG2BYTE = {
|
| 49 |
-
"en": 3,
|
| 50 |
-
"vi": 19,
|
| 51 |
-
}
|
| 52 |
-
|
| 53 |
-
CHANNELS = [
|
| 54 |
-
"5phutcrypto",
|
| 55 |
-
"anhbanthan",
|
| 56 |
-
"anhthamtu",
|
| 57 |
-
"animerewind.official",
|
| 58 |
-
"bibitv8888",
|
| 59 |
-
"btvgo",
|
| 60 |
-
"baclieutv",
|
| 61 |
-
"bachhoaxanhcom",
|
| 62 |
-
"baodientuvov",
|
| 63 |
-
"blvckvines",
|
| 64 |
-
"boringppl",
|
| 65 |
-
"bronub",
|
| 66 |
-
"cdteam-why",
|
| 67 |
-
"cobabinhduong",
|
| 68 |
-
"cosmicwriter",
|
| 69 |
-
"cuthongthai",
|
| 70 |
-
"daiphatthanhtruyenhinhsonla",
|
| 71 |
-
"day-be-thong-minh-tv",
|
| 72 |
-
"danangtv",
|
| 73 |
-
"daihanoi-htv",
|
| 74 |
-
"daiptththainguyentntv",
|
| 75 |
-
"dongmauviet",
|
| 76 |
-
"dongthaptv",
|
| 77 |
-
"fptbongdaofficial",
|
| 78 |
-
"fonosvietnam",
|
| 79 |
-
"hieurotrong5phut-ntkt",
|
| 80 |
-
"htvtintuc",
|
| 81 |
-
"happyhidari",
|
| 82 |
-
"hoabinhtvgo",
|
| 83 |
-
"hocenglishonline",
|
| 84 |
-
"hocvienbovagau",
|
| 85 |
-
"hungyentvvngo",
|
| 86 |
-
"huynhduykhuongofficial",
|
| 87 |
-
"huynhlapofficial",
|
| 88 |
-
"jvevermind",
|
| 89 |
-
"kenhvtc16",
|
| 90 |
-
"kiengiangtv",
|
| 91 |
-
"khanhvyofficial",
|
| 92 |
-
"kienthucquansu",
|
| 93 |
-
"lamdongtv1",
|
| 94 |
-
"lamvlog",
|
| 95 |
-
"longantv-la34",
|
| 96 |
-
"mangovid",
|
| 97 |
-
"mensbay",
|
| 98 |
-
"meovatcuocsonglnv",
|
| 99 |
-
"meuchannel",
|
| 100 |
-
"ntnvlogsnguyenthanhnam",
|
| 101 |
-
"ngamradio",
|
| 102 |
-
"nhanhac555",
|
| 103 |
-
"nhantaidaiviet",
|
| 104 |
-
"ptth-trt",
|
| 105 |
-
"ptvtruyenhinhphutho",
|
| 106 |
-
"phantichgame",
|
| 107 |
-
"phephim",
|
| 108 |
-
"phimhottk-l",
|
| 109 |
-
"riwaylegal",
|
| 110 |
-
"ruangao",
|
| 111 |
-
"suckhoetamsinh",
|
| 112 |
-
"sachbiquyethanhcong",
|
| 113 |
-
"soisangbrightsidevietnamese",
|
| 114 |
-
"spiderum",
|
| 115 |
-
"spiderumbooks",
|
| 116 |
-
"sukieskitchen",
|
| 117 |
-
"tin3phut",
|
| 118 |
-
"tranthanhtown",
|
| 119 |
-
"tulemientay",
|
| 120 |
-
"tayninhtv",
|
| 121 |
-
"thainhitv",
|
| 122 |
-
"thanhpahm",
|
| 123 |
-
"thegioilaptop",
|
| 124 |
-
"thepresentwriter",
|
| 125 |
-
"tiengiangtivi",
|
| 126 |
-
"tieubaobaothom",
|
| 127 |
-
"tintucbitcoin247",
|
| 128 |
-
"truyenhinhbinhphuoc-bptv",
|
| 129 |
-
"truyenhinhyenbaiytv",
|
| 130 |
-
"truyenhinhcaobang",
|
| 131 |
-
"truyenhinhdaklakdrt",
|
| 132 |
-
"truyenhinhdaknong1",
|
| 133 |
-
"truyenhinhdienbien23.9",
|
| 134 |
-
"truyenhinhkhanhhoa",
|
| 135 |
-
"truyenhinhkontumkrt",
|
| 136 |
-
"truyenhinhnaminhntv",
|
| 137 |
-
"truyenhinhninhthuan",
|
| 138 |
-
"truyenhinhquangngai",
|
| 139 |
-
"tuantienti2911",
|
| 140 |
-
"tuyenquangttv",
|
| 141 |
-
"vovlivedoctruyen",
|
| 142 |
-
"vietcetera",
|
| 143 |
-
"vinhlongtv",
|
| 144 |
-
"voizfm",
|
| 145 |
-
"vutrunguyenthuy",
|
| 146 |
-
"vuive",
|
| 147 |
-
"w2wanime",
|
| 148 |
-
"w2wcartoon",
|
| 149 |
-
"w2whorror",
|
| 150 |
-
"w2wmovie",
|
| 151 |
-
"web5ngay",
|
| 152 |
-
"xanh24h",
|
| 153 |
-
"aiphatthanhtruyenhinhquangtri",
|
| 154 |
-
"aiphatthanhvatruyenhinhhai1908",
|
| 155 |
-
"altonghop",
|
| 156 |
-
"antvtruyenhinhcongannhandan",
|
| 157 |
-
"baihoc10phut",
|
| 158 |
-
"battlecry.khampha",
|
| 159 |
-
"betterversionvn",
|
| 160 |
-
"blogkhoinghiep",
|
| 161 |
-
"bumcn",
|
| 162 |
-
"caikinhdi_vn",
|
| 163 |
-
"canthitg",
|
| 164 |
-
"chanthienmybachnien",
|
| 165 |
-
"chauanhchao",
|
| 166 |
-
"cosu",
|
| 167 |
-
"cungmaivaobep-monan-amthuc",
|
| 168 |
-
"daiptthphuyen",
|
| 169 |
-
"daiptthtv",
|
| 170 |
-
"daitruyenhinhangiang",
|
| 171 |
-
"daitruyenhinhbacgiang",
|
| 172 |
-
"dannytran2375",
|
| 173 |
-
"daybehoc5489",
|
| 174 |
-
"daylaphegame",
|
| 175 |
-
"dienmay",
|
| 176 |
-
"ducisreal",
|
| 177 |
-
"duongfg",
|
| 178 |
-
"duyluandethuong",
|
| 179 |
-
"duythanhish",
|
| 180 |
-
"elroydevops",
|
| 181 |
-
"gc.gamelab",
|
| 182 |
-
"hacthaybachthay",
|
| 183 |
-
"hagiangtv475",
|
| 184 |
-
"haiduongtv247",
|
| 185 |
-
"hanamtv8831",
|
| 186 |
-
"hangphimtailieudienanhnd",
|
| 187 |
-
"haugiangtv",
|
| 188 |
-
"haunauday",
|
| 189 |
-
"hieu-tv",
|
| 190 |
-
"hoshiphan",
|
| 191 |
-
"jakinatsumi2915",
|
| 192 |
-
"kechuyentieuhoc1719",
|
| 193 |
-
"kenhcovan",
|
| 194 |
-
"khalid_dinh",
|
| 195 |
-
"kiaralah",
|
| 196 |
-
"laichautv",
|
| 197 |
-
"langsontvtube",
|
| 198 |
-
"megame_official",
|
| 199 |
-
"minvestvn",
|
| 200 |
-
"nguoithanhcong1991",
|
| 201 |
-
"nhatkycuocsong.",
|
| 202 |
-
"ntcanima",
|
| 203 |
-
"ptthbentre",
|
| 204 |
-
"ptthquangbinh",
|
| 205 |
-
"qrt",
|
| 206 |
-
"quangninhtv",
|
| 207 |
-
"snewsvn",
|
| 208 |
-
"soctrangtv",
|
| 209 |
-
"sunhuynpodcast",
|
| 210 |
-
"tamhonanuong",
|
| 211 |
-
"tgddreview",
|
| 212 |
-
"thaibinhtv",
|
| 213 |
-
"thanhnamedu",
|
| 214 |
-
"thanhnientvnews",
|
| 215 |
-
"thbrt",
|
| 216 |
-
"thieunhitv3630",
|
| 217 |
-
"thtpct",
|
| 218 |
-
"tinnhanh3phut868",
|
| 219 |
-
"toansam",
|
| 220 |
-
"toidicodedaoblog",
|
| 221 |
-
"tranquochuywecommit",
|
| 222 |
-
"tranvyvy",
|
| 223 |
-
"truyenhinh4k",
|
| 224 |
-
"truyenhinhbinhthuan",
|
| 225 |
-
"truyenhinhcamau69",
|
| 226 |
-
"truyenhinhdongnai_dnrtv",
|
| 227 |
-
"truyenhinhgialai",
|
| 228 |
-
"truyenhinhlaocai",
|
| 229 |
-
"truyenhinhnghean",
|
| 230 |
-
"truyenhinhvinhphuc",
|
| 231 |
-
"txtofficial8798",
|
| 232 |
-
"vanhkhuyenle",
|
| 233 |
-
"vietnh1009",
|
| 234 |
-
"visaothenhipodcast",
|
| 235 |
-
"vtc14",
|
| 236 |
-
"vtcnow",
|
| 237 |
-
"vtv24",
|
| 238 |
-
"vuive123",
|
| 239 |
-
"zombiev4",
|
| 240 |
-
]
|
| 241 |
-
|
| 242 |
-
# Tự động ánh xạ channel → token (bắt đầu từ 30)
|
| 243 |
-
for i, ch in enumerate(CHANNELS):
|
| 244 |
-
LANG2BYTE[ch] = 30 + i
|
| 245 |
-
|
| 246 |
-
test_sentences = {
|
| 247 |
-
"en": "In order to fully assess performance and the accuracy of language tags, this test sentence contains multiple subordinate clauses, varied punctuation, and a sufficient word count.",
|
| 248 |
-
"vi": "Để đánh giá toàn diện hiệu suất và độ chính xác của các thẻ ngôn ngữ, câu kiểm tra này chứa nhiều mệnh đề phụ, dấu câu đa dạng và số lượng từ đầy đủ."
|
| 249 |
-
}
|
| 250 |
-
|
| 251 |
-
@dataclass
|
| 252 |
-
class TrainConfig:
|
| 253 |
-
epochs: int = 1
|
| 254 |
-
batch_size: int = 2
|
| 255 |
-
grad_accum_steps: int = 2
|
| 256 |
-
learning_rate: float = 1e-5
|
| 257 |
-
warmup_steps: int = 500
|
| 258 |
-
unconditional_frac: float = 0.15
|
| 259 |
-
eval_step: int = 200
|
| 260 |
-
save_step: int = 2000
|
| 261 |
-
split_ratio: float = 0.997
|
| 262 |
-
shuffle_buffer_size: int = None # for streaming shuffle
|
| 263 |
-
seed: int = 42 # seed for reproducibility
|
| 264 |
-
runs_dir: Path = Path("runs")
|
| 265 |
-
run_name: str = "dia_finetune_cv"
|
| 266 |
-
output_dir: Path = Path(".cpkts/dia_finetune_cv ")
|
| 267 |
-
resume_from: Path = None
|
| 268 |
-
total_steps: int = 290007
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
def get_args() -> argparse.Namespace:
|
| 273 |
-
parser = argparse.ArgumentParser(description="Train the Dia audio model")
|
| 274 |
-
parser.add_argument("--config", type=Path, default=Path("dia/config.json"))
|
| 275 |
-
parser.add_argument("--dataset", type=str, default="Paradoxia/opendata-iisys-hui",
|
| 276 |
-
help="HuggingFace dataset name (if not using --csv_path).")
|
| 277 |
-
parser.add_argument("--dataset2", type=str, default=None,
|
| 278 |
-
help="(Optional) second HF dataset to interleave (streaming)")
|
| 279 |
-
parser.add_argument("--streaming",action="store_true",
|
| 280 |
-
help="Enable HuggingFace dataset streaming")
|
| 281 |
-
parser.add_argument("--hub_model", type=str, default="nari-labs/Dia-1.6B")
|
| 282 |
-
parser.add_argument("--local_ckpt", type=str, default=None)
|
| 283 |
-
parser.add_argument("--csv_path", type=Path, default=None,
|
| 284 |
-
help="Path to local CSV/TSV file with `audio|text` (if you want to train locally).")
|
| 285 |
-
parser.add_argument("--audio_root",type=Path, default=None,
|
| 286 |
-
help="Root directory for local audio files (required if --csv_path is set).")
|
| 287 |
-
parser.add_argument("--run_name", type=str, default=None)
|
| 288 |
-
parser.add_argument("--output_dir",type=Path, default=None)
|
| 289 |
-
parser.add_argument("--shuffle_buffer_size", type=int, default=None,
|
| 290 |
-
help="Buffer size for streaming dataset shuffle.")
|
| 291 |
-
parser.add_argument("--seed", type=int, default=42,
|
| 292 |
-
help="Random seed for reproducibility.")
|
| 293 |
-
parser.add_argument("--half", action="store_true", help="load model in fp16")
|
| 294 |
-
parser.add_argument("--compile", action="store_true", help="torch compile model")
|
| 295 |
-
parser.add_argument('--use_amp', action='store_true', help='Enable mixed precision')
|
| 296 |
-
parser.add_argument("--resume_from", type=str, default=None)
|
| 297 |
-
return parser.parse_args()
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
def collate_fn(batch, config: DiaConfig, device: torch.device):
|
| 302 |
-
from torch.nn.functional import pad
|
| 303 |
-
|
| 304 |
-
texts, encodings, waveforms = zip(*batch)
|
| 305 |
-
|
| 306 |
-
# -- Text inputs ---------------------------------------------------------
|
| 307 |
-
|
| 308 |
-
max_text = config.data.text_length
|
| 309 |
-
pad_tok = config.data.text_pad_value
|
| 310 |
-
text_ids = []
|
| 311 |
-
for txt in texts:
|
| 312 |
-
b_full = txt.encode('utf-8')
|
| 313 |
-
# replace leading "[lang]" prefix
|
| 314 |
-
for code, val in LANG2BYTE.items():
|
| 315 |
-
prefix = f"[{code}]".encode('utf-8')
|
| 316 |
-
if b_full.startswith(prefix):
|
| 317 |
-
b_full = bytes([val]) + b_full[len(prefix):]
|
| 318 |
-
break
|
| 319 |
-
bts = b_full[:max_text]
|
| 320 |
-
arr = list(bts) + [pad_tok] * (max_text - len(bts))
|
| 321 |
-
text_ids.append(torch.tensor(arr, dtype=torch.long))
|
| 322 |
-
src = torch.stack(text_ids).to(device)
|
| 323 |
-
src_pos = torch.arange(max_text, device=device).unsqueeze(0).expand(src.size(0), -1)
|
| 324 |
-
src_pad = src.ne(pad_tok)
|
| 325 |
-
enc_self_attn_mask = (src_pad.unsqueeze(2) & src_pad.unsqueeze(1)).unsqueeze(1)
|
| 326 |
-
|
| 327 |
-
# -- Audio codes --------------------------------------------------------
|
| 328 |
-
|
| 329 |
-
max_audio = config.data.audio_length
|
| 330 |
-
# per-sample lengths (clipped to max_audio)
|
| 331 |
-
seq_lens = [min(e.size(0), max_audio) for e in encodings]
|
| 332 |
-
batch_max = max(seq_lens)
|
| 333 |
-
|
| 334 |
-
# pad or trim each encoding to the batch max length
|
| 335 |
-
padded = [pad(e, (0, 0, 0, batch_max - e.size(0))) if e.size(0) < batch_max else e[:batch_max]
|
| 336 |
-
for e in encodings]
|
| 337 |
-
codes = torch.stack(padded).to(device) # (B, T=batch_max, C)
|
| 338 |
-
|
| 339 |
-
B, T, C = codes.shape
|
| 340 |
-
t_idx, idxs = build_delay_indices(B, T, C, config.data.delay_pattern)
|
| 341 |
-
delayed = apply_audio_delay(
|
| 342 |
-
codes,
|
| 343 |
-
config.data.audio_pad_value,
|
| 344 |
-
config.data.audio_bos_value,
|
| 345 |
-
(t_idx, idxs)
|
| 346 |
-
)
|
| 347 |
-
# ensure no longer than max_audio
|
| 348 |
-
delayed = delayed[:, :max_audio, :]
|
| 349 |
-
|
| 350 |
-
# -- Targets with per-sample EOS ----------------------------------------
|
| 351 |
-
|
| 352 |
-
max_tgt_len = max_audio + 2
|
| 353 |
-
pad_val = config.data.audio_pad_value
|
| 354 |
-
bos_val = config.data.audio_bos_value
|
| 355 |
-
eos_val = config.data.audio_eos_value
|
| 356 |
-
|
| 357 |
-
tgt = torch.full((B, max_tgt_len, C), pad_val, dtype=torch.long, device=device)
|
| 358 |
-
tgt[:, 0, :] = bos_val
|
| 359 |
-
tgt_lens = []
|
| 360 |
-
for i, L in enumerate(seq_lens):
|
| 361 |
-
tgt[i, 1:1 + L, :] = delayed[i, :L, :]
|
| 362 |
-
tgt[i, 1 + L, :] = eos_val
|
| 363 |
-
tgt_lens.append(1 + L + 1)
|
| 364 |
-
|
| 365 |
-
tgt_pos = torch.arange(max_tgt_len, device=device).unsqueeze(0).expand(B, -1)
|
| 366 |
-
tgt_pad = tgt.ne(pad_val).any(-1)
|
| 367 |
-
|
| 368 |
-
causal = torch.tril(torch.ones((max_tgt_len, max_tgt_len),
|
| 369 |
-
dtype=torch.bool,
|
| 370 |
-
device=device))
|
| 371 |
-
dec_self_attn_mask = (tgt_pad.unsqueeze(2) & tgt_pad.unsqueeze(1) & causal).unsqueeze(1)
|
| 372 |
-
dec_cross_attn_mask = (tgt_pad.unsqueeze(2) & src_pad.unsqueeze(1)).unsqueeze(1)
|
| 373 |
-
|
| 374 |
-
return {
|
| 375 |
-
'src_tokens': src,
|
| 376 |
-
'src_positions': src_pos,
|
| 377 |
-
'enc_self_attn_mask': enc_self_attn_mask,
|
| 378 |
-
'tgt_tokens': tgt,
|
| 379 |
-
'tgt_positions': tgt_pos,
|
| 380 |
-
'dec_self_attn_mask': dec_self_attn_mask,
|
| 381 |
-
'dec_cross_attn_mask': dec_cross_attn_mask,
|
| 382 |
-
'waveforms': waveforms,
|
| 383 |
-
'raw_text': texts[0],
|
| 384 |
-
'tgt_lens': torch.tensor(tgt_lens, dtype=torch.long, device=device),
|
| 385 |
-
}
|
| 386 |
-
|
| 387 |
-
def setup_loaders(dataset, dia_cfg: DiaConfig, train_cfg: TrainConfig, device):
|
| 388 |
-
collate = lambda b: collate_fn(b, dia_cfg, device)
|
| 389 |
-
if isinstance(dataset, HFDiaIterDataset):
|
| 390 |
-
total = getattr(dataset, "total_examples", None)
|
| 391 |
-
if total is None:
|
| 392 |
-
total = dataset.dataset.info.splits["train"].num_examples
|
| 393 |
-
n_train = int(train_cfg.split_ratio * total)
|
| 394 |
-
n_val = total - n_train
|
| 395 |
-
if n_val <= 0:
|
| 396 |
-
raise RuntimeError(f"No validation samples: total={total}, split_ratio={train_cfg.split_ratio}")
|
| 397 |
-
base = dataset.dataset.shuffle(buffer_size=train_cfg.shuffle_buffer_size, seed=train_cfg.seed) if train_cfg.shuffle_buffer_size else dataset.dataset
|
| 398 |
-
val_stream = base.take(n_val)
|
| 399 |
-
train_stream = base.skip(n_val)
|
| 400 |
-
train_ds = HFDiaIterDataset(train_stream, dia_cfg, dataset.dac_model)
|
| 401 |
-
val_ds = HFDiaIterDataset(val_stream, dia_cfg, dataset.dac_model)
|
| 402 |
-
train_loader = DataLoader(train_ds, batch_size=train_cfg.batch_size, shuffle=False, collate_fn=collate)
|
| 403 |
-
train_loader.steps_per_epoch = math.ceil(n_train / train_cfg.batch_size)
|
| 404 |
-
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=collate)
|
| 405 |
-
return train_loader, val_loader
|
| 406 |
-
ds_len = len(dataset)
|
| 407 |
-
n_train = int(train_cfg.split_ratio * ds_len)
|
| 408 |
-
train_ds, val_ds = random_split(dataset, [n_train, ds_len - n_train])
|
| 409 |
-
train_loader = DataLoader(train_ds, batch_size=train_cfg.batch_size, shuffle=True, collate_fn=collate)
|
| 410 |
-
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=collate)
|
| 411 |
-
return train_loader, val_loader
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
def setup_optimizer_and_scheduler(model, train_loader, train_cfg):
|
| 416 |
-
opt = bnb.optim.AdamW8bit(model.parameters(), lr=train_cfg.learning_rate)
|
| 417 |
-
# Determine steps per epoch: prefer len(), else use attached attribute
|
| 418 |
-
try:
|
| 419 |
-
steps_per_epoch = len(train_loader)
|
| 420 |
-
except TypeError:
|
| 421 |
-
if hasattr(train_loader, 'steps_per_epoch'):
|
| 422 |
-
steps_per_epoch = train_loader.steps_per_epoch
|
| 423 |
-
else:
|
| 424 |
-
raise RuntimeError("Cannot determine steps_per_epoch for streaming loader")
|
| 425 |
-
total_training_steps = steps_per_epoch * train_cfg.epochs
|
| 426 |
-
sched = get_scheduler(
|
| 427 |
-
'cosine', opt,
|
| 428 |
-
num_warmup_steps=train_cfg.warmup_steps / train_cfg.grad_accum_steps,
|
| 429 |
-
num_training_steps=total_training_steps / train_cfg.grad_accum_steps
|
| 430 |
-
)
|
| 431 |
-
return opt, sched
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
def train_step(model, batch, dia_cfg, train_cfg, opt, sched, writer, step_in_epoch, global_step,scaler):
|
| 436 |
-
"""
|
| 437 |
-
Perform a single training step: forward, loss, backward, update, log.
|
| 438 |
-
Now uses per‑sample tgt_lens to mask out padding after each EOS,
|
| 439 |
-
and applies 4× loss weight on the first channel.
|
| 440 |
-
"""
|
| 441 |
-
# (optional) unconditional conditioning
|
| 442 |
-
if random.random() < train_cfg.unconditional_frac:
|
| 443 |
-
pad_tok = dia_cfg.data.text_pad_value
|
| 444 |
-
batch['src_tokens'] = torch.zeros_like(batch['src_tokens'])
|
| 445 |
-
batch['enc_self_attn_mask'] = torch.zeros_like(batch['enc_self_attn_mask'])
|
| 446 |
-
batch['dec_cross_attn_mask'] = torch.zeros_like(batch['dec_cross_attn_mask'])
|
| 447 |
-
|
| 448 |
-
with autocast(dtype=torch.float16):
|
| 449 |
-
# forward pass
|
| 450 |
-
logits = model(
|
| 451 |
-
src_BxS=batch['src_tokens'],
|
| 452 |
-
tgt_BxTxC=batch['tgt_tokens'],
|
| 453 |
-
src_positions=batch['src_positions'],
|
| 454 |
-
tgt_positions=batch['tgt_positions'],
|
| 455 |
-
enc_self_attn_mask=batch['enc_self_attn_mask'],
|
| 456 |
-
dec_self_attn_mask=batch['dec_self_attn_mask'],
|
| 457 |
-
dec_cross_attn_mask=batch['dec_cross_attn_mask'],
|
| 458 |
-
enable_dropout=True,
|
| 459 |
-
)
|
| 460 |
-
# fetch per-sample target‑lengths (including BOS+frames+EOS)
|
| 461 |
-
lens = batch['tgt_lens'] # shape: (B,)
|
| 462 |
-
max_L = int(lens.max().item()) # maximum over batch
|
| 463 |
-
|
| 464 |
-
# keep only up through the last possible EOS slot
|
| 465 |
-
# logits: (B, T, C, V) -> (B, max_L-1, C, V)
|
| 466 |
-
logits = logits[:, : max_L - 1]
|
| 467 |
-
|
| 468 |
-
# targets: shift off the BOS so 0..<max_L-1> align with logits
|
| 469 |
-
# target: (B, T, C) -> (B, max_L-1, C)
|
| 470 |
-
target = batch['tgt_tokens'][:, 1:max_L, :]
|
| 471 |
-
|
| 472 |
-
B, Tm1, C = target.shape
|
| 473 |
-
pad_val = dia_cfg.data.audio_pad_value
|
| 474 |
-
|
| 475 |
-
# build a mask [B x (max_L-1)] that is True for t < (lens[i]-1)
|
| 476 |
-
time_idx = torch.arange(Tm1, device=lens.device).unsqueeze(0) # (1, Tm1)
|
| 477 |
-
valid_time = time_idx < (lens.unsqueeze(1) - 1) # (B, Tm1)
|
| 478 |
-
mask = valid_time.unsqueeze(-1).expand(-1, -1, C) # (B, Tm1, C)
|
| 479 |
-
|
| 480 |
-
# apply 4× weight on first channel, 1× on others
|
| 481 |
-
channel_weights = [4.0] + [1.0] * (C - 1)
|
| 482 |
-
loss_c = 0.0
|
| 483 |
-
_, _, _, V = logits.size()
|
| 484 |
-
|
| 485 |
-
for c, w in enumerate(channel_weights):
|
| 486 |
-
# flatten this channel
|
| 487 |
-
lc = logits[:, :, c, :].reshape(-1, V) # (B*Tm1, V)
|
| 488 |
-
tc = target[:, :, c].reshape(-1) # (B*Tm1,)
|
| 489 |
-
mc = mask[:, :, c].reshape(-1) # (B*Tm1,)
|
| 490 |
-
|
| 491 |
-
# mask out padding and compute cross-entropy
|
| 492 |
-
lc_valid = lc[mc]
|
| 493 |
-
tc_valid = tc[mc]
|
| 494 |
-
loss_c += w * F.cross_entropy(
|
| 495 |
-
lc_valid, tc_valid,
|
| 496 |
-
ignore_index=pad_val
|
| 497 |
-
)
|
| 498 |
-
|
| 499 |
-
# normalize by sum of weights
|
| 500 |
-
loss = loss_c / sum(channel_weights)
|
| 501 |
-
|
| 502 |
-
# scale + backward
|
| 503 |
-
loss = loss / train_cfg.grad_accum_steps
|
| 504 |
-
scaler.scale(loss).backward()
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
# step & log
|
| 508 |
-
|
| 509 |
-
if (step_in_epoch + 1) % train_cfg.grad_accum_steps == 0:
|
| 510 |
-
# Unscale before clipping
|
| 511 |
-
scaler.unscale_(opt)
|
| 512 |
-
grad_norm = clip_grad_norm_(model.parameters(), max_norm=1e9)
|
| 513 |
-
|
| 514 |
-
scaler.step(opt)
|
| 515 |
-
scaler.update()
|
| 516 |
-
opt.zero_grad()
|
| 517 |
-
sched.step()
|
| 518 |
-
|
| 519 |
-
true_loss = loss.item() * train_cfg.grad_accum_steps
|
| 520 |
-
current_lr = sched.get_last_lr()[0]
|
| 521 |
-
|
| 522 |
-
writer.add_scalar('GradNorm/global', grad_norm, global_step)
|
| 523 |
-
writer.add_scalar('LR', current_lr, global_step)
|
| 524 |
-
writer.add_scalar('Loss/train', true_loss, global_step)
|
| 525 |
-
|
| 526 |
-
return true_loss
|
| 527 |
-
else:
|
| 528 |
-
return loss.item() * train_cfg.grad_accum_steps
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
def eval_step(model, val_loader, dia_cfg, dac_model, writer, global_step):
|
| 534 |
-
"""
|
| 535 |
-
Run evaluation: compute average loss on validation set and log audio samples.
|
| 536 |
-
"""
|
| 537 |
-
import gc
|
| 538 |
-
eval_losses = []
|
| 539 |
-
last_batch = None
|
| 540 |
-
with torch.inference_mode():
|
| 541 |
-
for eb in tqdm(val_loader, desc="eval"):
|
| 542 |
-
last_batch = eb
|
| 543 |
-
|
| 544 |
-
with autocast(dtype=torch.float16):
|
| 545 |
-
logits16 = model(
|
| 546 |
-
src_BxS=eb['src_tokens'],
|
| 547 |
-
tgt_BxTxC=eb['tgt_tokens'],
|
| 548 |
-
src_positions=eb['src_positions'],
|
| 549 |
-
tgt_positions=eb['tgt_positions'],
|
| 550 |
-
enc_self_attn_mask=eb['enc_self_attn_mask'],
|
| 551 |
-
dec_self_attn_mask=eb['dec_self_attn_mask'],
|
| 552 |
-
dec_cross_attn_mask=eb['dec_cross_attn_mask'],
|
| 553 |
-
enable_dropout=False,
|
| 554 |
-
)[:, :-1]
|
| 555 |
-
|
| 556 |
-
logits = logits16.float()
|
| 557 |
-
target = eb['tgt_tokens'][:, 1:]
|
| 558 |
-
B_e, T_e, C_e = target.shape
|
| 559 |
-
V_e = logits.size(-1)
|
| 560 |
-
|
| 561 |
-
loss_e = 0.0
|
| 562 |
-
weights_e = [4.0] + [1.0] * (C_e - 1)
|
| 563 |
-
for c, w in enumerate(weights_e):
|
| 564 |
-
lc = logits[:, :, c, :].reshape(-1, V_e)
|
| 565 |
-
tc = target[:, :, c].reshape(-1)
|
| 566 |
-
loss_e += w * F.cross_entropy(
|
| 567 |
-
lc, tc, ignore_index=dia_cfg.data.audio_pad_value
|
| 568 |
-
)
|
| 569 |
-
loss_e = loss_e / sum(weights_e)
|
| 570 |
-
|
| 571 |
-
eval_losses.append(loss_e)
|
| 572 |
-
|
| 573 |
-
avg_eval_loss = sum(eval_losses) / len(eval_losses)
|
| 574 |
-
writer.add_scalar('Loss/eval', avg_eval_loss.item(), global_step)
|
| 575 |
-
|
| 576 |
-
# --- Inference test sentence ---
|
| 577 |
-
try:
|
| 578 |
-
orig_dtype = next(model.parameters()).dtype
|
| 579 |
-
model = model.float()
|
| 580 |
-
dia_gen = Dia(dia_cfg, device)
|
| 581 |
-
dia_gen.model, dia_gen.dac_model = model, dac_model
|
| 582 |
-
|
| 583 |
-
# ✅ Test câu hội thoại đa giọng
|
| 584 |
-
test_dialogue = "[vtv24] Em vừa đi học về, anh ạ. [duongfg] Ừ, em ăn cơm chưa? [vtv24] Em ăn rồi!"
|
| 585 |
-
|
| 586 |
-
if len(test_dialogue) > 10:
|
| 587 |
-
try:
|
| 588 |
-
audio = dia_gen.generate(text=test_dialogue)
|
| 589 |
-
writer.add_audio("Eval/test_dialogue", audio, global_step, 44100)
|
| 590 |
-
except Exception:
|
| 591 |
-
logger.exception("Eval error during test_dialogue")
|
| 592 |
-
finally:
|
| 593 |
-
if 'audio' in locals():
|
| 594 |
-
del audio
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
except Exception:
|
| 598 |
-
logger.exception("Eval error")
|
| 599 |
-
|
| 600 |
-
finally:
|
| 601 |
-
if 'audio' in locals():
|
| 602 |
-
del audio
|
| 603 |
-
gc.collect()
|
| 604 |
-
torch.cuda.empty_cache()
|
| 605 |
-
if orig_dtype == torch.float16:
|
| 606 |
-
model = model.half()
|
| 607 |
-
|
| 608 |
-
def train(model, dia_cfg: DiaConfig, dac_model: dac.DAC, dataset, train_cfg: TrainConfig):
|
| 609 |
-
"""
|
| 610 |
-
Run the full training loop over epochs.
|
| 611 |
-
"""
|
| 612 |
-
# prepare directories
|
| 613 |
-
train_cfg.output_dir.mkdir(parents=True, exist_ok=True)
|
| 614 |
-
(train_cfg.runs_dir / train_cfg.run_name).mkdir(parents=True, exist_ok=True)
|
| 615 |
-
model = model.to(device)
|
| 616 |
-
|
| 617 |
-
train_loader, val_loader = setup_loaders(dataset, dia_cfg, train_cfg, device)
|
| 618 |
-
opt, sched = setup_optimizer_and_scheduler(model, train_loader, train_cfg)
|
| 619 |
-
|
| 620 |
-
writer = SummaryWriter(train_cfg.runs_dir / train_cfg.run_name)
|
| 621 |
-
model.train()
|
| 622 |
-
scaler = GradScaler()
|
| 623 |
-
start_epoch = 0
|
| 624 |
-
global_step = 0
|
| 625 |
-
resume_ckpt = getattr(train_cfg, "resume_from", None)
|
| 626 |
-
if resume_ckpt and resume_ckpt.exists():
|
| 627 |
-
logger.info(f"Resuming from checkpoint: {resume_ckpt}")
|
| 628 |
-
checkpoint = torch.load(resume_ckpt, map_location=device)
|
| 629 |
-
model.load_state_dict(checkpoint["model"])
|
| 630 |
-
opt.load_state_dict(checkpoint["optimizer"])
|
| 631 |
-
sched.load_state_dict(checkpoint["scheduler"])
|
| 632 |
-
scaler.load_state_dict(checkpoint["scaler"])
|
| 633 |
-
start_epoch = checkpoint.get("epoch", 0)
|
| 634 |
-
global_step = checkpoint.get("global_step", 0)
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
steps_per_epoch = getattr(train_loader, 'steps_per_epoch', None)
|
| 638 |
-
if steps_per_epoch is None:
|
| 639 |
-
try:
|
| 640 |
-
steps_per_epoch = len(train_loader)
|
| 641 |
-
except Exception:
|
| 642 |
-
steps_per_epoch = None
|
| 643 |
-
|
| 644 |
-
for epoch in range(start_epoch, train_cfg.epochs):
|
| 645 |
-
# iterate with progress bar, using total if known
|
| 646 |
-
loader_iter = tqdm(
|
| 647 |
-
train_loader,
|
| 648 |
-
desc=f"E{epoch+1}",
|
| 649 |
-
total=steps_per_epoch
|
| 650 |
-
)
|
| 651 |
-
pbar = tqdm(loader_iter, total=train_cfg.total_steps, initial=global_step, desc=f"E{epoch}")
|
| 652 |
-
for step_in_epoch, batch in enumerate(pbar):
|
| 653 |
-
global_step += 1
|
| 654 |
-
# training step
|
| 655 |
-
loss = train_step(model, batch, dia_cfg, train_cfg, opt, sched, writer, step_in_epoch, global_step, scaler)
|
| 656 |
-
|
| 657 |
-
cur_alloc = torch.cuda.memory_allocated() # bytes currently allocated by tensors
|
| 658 |
-
peak_alloc = torch.cuda.max_memory_allocated() # bytes peak during program
|
| 659 |
-
# optionally convert to GB
|
| 660 |
-
cur_gb = cur_alloc / 1024**3
|
| 661 |
-
peak_gb = peak_alloc / 1024**3
|
| 662 |
-
|
| 663 |
-
# update the tqdm postfix
|
| 664 |
-
loader_iter.set_postfix({
|
| 665 |
-
'loss': f"{loss:.4f}",
|
| 666 |
-
'VRAM (GB)': f"{cur_gb:.2f}/{peak_gb:.2f}"
|
| 667 |
-
})
|
| 668 |
-
|
| 669 |
-
# remember to zero the peak if you want rolling peaks per step
|
| 670 |
-
if torch.cuda.is_available():
|
| 671 |
-
torch.cuda.reset_peak_memory_stats()
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
# evaluation
|
| 675 |
-
if step_in_epoch % train_cfg.eval_step == 0:
|
| 676 |
-
model.eval()
|
| 677 |
-
with torch.no_grad():
|
| 678 |
-
eval_step(model, val_loader, dia_cfg, dac_model, writer, global_step)
|
| 679 |
-
model.train()
|
| 680 |
-
scaler = GradScaler()
|
| 681 |
-
|
| 682 |
-
# checkpoint
|
| 683 |
-
if step_in_epoch and step_in_epoch % train_cfg.save_step == 0:
|
| 684 |
-
ckpt = train_cfg.output_dir / f"ckpt_step{global_step}.pth"
|
| 685 |
-
torch.save({
|
| 686 |
-
"model": model.state_dict(),
|
| 687 |
-
"optimizer": opt.state_dict(),
|
| 688 |
-
"scheduler": sched.state_dict(),
|
| 689 |
-
"scaler": scaler.state_dict(),
|
| 690 |
-
"epoch": epoch,
|
| 691 |
-
"global_step": global_step
|
| 692 |
-
}, ckpt)
|
| 693 |
-
logger.info(f"Saved checkpoint: {ckpt}")
|
| 694 |
-
|
| 695 |
-
# end of epoch checkpoint
|
| 696 |
-
ckpt_e = train_cfg.output_dir / f"ckpt_epoch{epoch+1}.pth"
|
| 697 |
-
torch.save({
|
| 698 |
-
"model": model.state_dict(),
|
| 699 |
-
"optimizer": opt.state_dict(),
|
| 700 |
-
"scheduler": sched.state_dict(),
|
| 701 |
-
"scaler": scaler.state_dict(),
|
| 702 |
-
"epoch": epoch + 1,
|
| 703 |
-
"global_step": global_step
|
| 704 |
-
}, ckpt_e)
|
| 705 |
-
logger.info(f"Saved end-of-epoch checkpoint: {ckpt_e}")
|
| 706 |
-
|
| 707 |
-
from datasets import disable_caching
|
| 708 |
-
|
| 709 |
-
def main():
|
| 710 |
-
args = get_args()
|
| 711 |
-
import os
|
| 712 |
-
os.environ["HF_DATASETS_CACHE"] = "/tmp/force_streaming" # ép cache mới
|
| 713 |
-
disable_caching()
|
| 714 |
-
# tắt toàn bộ cache local HuggingFace
|
| 715 |
-
import json
|
| 716 |
-
with open(args.config, "r", encoding="utf-8") as f:
|
| 717 |
-
config_dict = json.load(f)
|
| 718 |
-
|
| 719 |
-
dia_cfg = DiaConfig(**config_dict)
|
| 720 |
-
dac_model = dac.DAC.load(dac.utils.download()).to(device)
|
| 721 |
-
dataset = None
|
| 722 |
-
|
| 723 |
-
if not dataset:
|
| 724 |
-
if args.csv_path:
|
| 725 |
-
if not args.audio_root:
|
| 726 |
-
raise ValueError("`--audio_root` must be set when using `--csv_path`")
|
| 727 |
-
dataset = LocalDiaDataset(args.csv_path, args.audio_root, dia_cfg, dac_model)
|
| 728 |
-
else:
|
| 729 |
-
# ✅ Check nếu dataset là đường dẫn local
|
| 730 |
-
if Path(args.dataset).exists():
|
| 731 |
-
print(f"Loading dataset from local path: {args.dataset}")
|
| 732 |
-
ds1 = load_from_disk(args.dataset)
|
| 733 |
-
if isinstance(ds1, DatasetDict):
|
| 734 |
-
ds1 = ds1["train"]
|
| 735 |
-
dataset = HFDiaDataset(ds1, dia_cfg, dac_model)
|
| 736 |
-
else:
|
| 737 |
-
print(f"Loading HuggingFace dataset: {args.dataset} (streaming)")
|
| 738 |
-
ds1 = load_dataset(args.dataset, split="train", streaming=True)
|
| 739 |
-
|
| 740 |
-
if args.dataset2:
|
| 741 |
-
ds2 = load_dataset(args.dataset2, split="train", streaming=True)
|
| 742 |
-
hf_ds = interleave_datasets([ds1, ds2])
|
| 743 |
-
dataset = HFDiaIterDataset(hf_ds, dia_cfg, dac_model)
|
| 744 |
-
else:
|
| 745 |
-
hf_ds = ds1
|
| 746 |
-
dataset = HFDiaIterDataset(hf_ds, dia_cfg, dac_model)
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
train_cfg = TrainConfig(
|
| 751 |
-
run_name = args.run_name or TrainConfig.run_name,
|
| 752 |
-
output_dir = args.output_dir or TrainConfig.output_dir,
|
| 753 |
-
shuffle_buffer_size = args.shuffle_buffer_size,
|
| 754 |
-
seed = args.seed,
|
| 755 |
-
)
|
| 756 |
-
if args.resume_from:
|
| 757 |
-
train_cfg.resume_from = Path(args.resume_from)
|
| 758 |
-
# load model checkpoint
|
| 759 |
-
if args.local_ckpt:
|
| 760 |
-
ckpt_file = args.local_ckpt
|
| 761 |
-
else:
|
| 762 |
-
ckpt_file = hf_hub_download(args.hub_model, filename="dia-v0_1.pth")
|
| 763 |
-
model = DiaModel(dia_cfg)
|
| 764 |
-
if args.half:
|
| 765 |
-
model=model.half()
|
| 766 |
-
if args.compile:
|
| 767 |
-
model = torch.compile(model, backend="inductor")
|
| 768 |
-
ckpt = torch.load(ckpt_file, map_location="cpu")
|
| 769 |
-
state_dict = ckpt["model"] if "model" in ckpt else ckpt
|
| 770 |
-
new_state_dict = {}
|
| 771 |
-
|
| 772 |
-
for k, v in state_dict.items():
|
| 773 |
-
if "encoder.embedding.weight" in k:
|
| 774 |
-
if v.shape != model.state_dict()[k].shape:
|
| 775 |
-
print(f"⚠️ Bỏ qua {k} do shape không khớp: {v.shape} vs {model.state_dict()[k].shape}")
|
| 776 |
-
continue
|
| 777 |
-
new_state_dict[k] = v
|
| 778 |
-
|
| 779 |
-
model.load_state_dict(new_state_dict, strict=False)
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
# start training
|
| 783 |
-
train(model, dia_cfg, dac_model, dataset, train_cfg)
|
| 784 |
-
|
| 785 |
-
|
| 786 |
-
if __name__ == "__main__":
|
| 787 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|