import math
import random
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from functools import reduce
from itertools import accumulate
from random import choices
from typing import List, Optional, Sequence

import pytorch_lightning as ptl
import torch
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
from einops import rearrange
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import BatchSampler, DataLoader, Sampler, SubsetRandomSampler
from transformers import PreTrainedTokenizerFast

from tts.tools import (
    audio_to_text_partial_neighbor_mask,
    packmask_2d,
    pad_2d_sequence,
    sequence_mask,
)
class BucketSampler(Sampler[List[int]]):
    """Yield batches of indices drawn from duration buckets, sharded by rank."""

    def __init__(
        self,
        buckets: List[List[int]],
        batch_sizes: List[int] | int,
        bucket_sampling_weights: Optional[List[float]] = None,
        drop_last: bool = True,
        distributed: bool = True,  # TODO - implement not distributed as well
        sample_bucket: Optional[int] = None,
        seed: int = 123,
        epoch_seed: bool = True,
    ):
        if isinstance(batch_sizes, int):
            batch_sizes = [batch_sizes] * len(buckets)
        else:
            assert len(buckets) == len(batch_sizes)
        if bucket_sampling_weights is not None:
            assert len(bucket_sampling_weights) == len(batch_sizes)
        self.bucket_sampling_weights = bucket_sampling_weights
        self.num_replicas = torch.distributed.get_world_size()
        self.rank = torch.distributed.get_rank()
        # shard each bucket across replicas, dropping the ragged tail
        self.buckets = [
            b[self.rank : len(b) - len(b) % self.num_replicas : self.num_replicas]
            for b in buckets
        ]
        self.num_samples = [len(b) // self.num_replicas for b in buckets]
        self.batch_sizes = batch_sizes
        # number of batches each bucket contributes per epoch
        self.total_sizes = [
            ns // bs for ns, bs in zip(self.num_samples, self.batch_sizes)
        ]
        self.drop_last = drop_last
        self.seed = seed
        self.epoch = 0
        self.sample_bucket = sample_bucket
        self.epoch_seed = epoch_seed
        self.batch_size = batch_sizes[0]

    def set_epoch(self, epoch: int):
        self.epoch = epoch

    def __len__(self):
        return sum(self.total_sizes)

    def __iter__(self):
        generator = torch.Generator()
        # epoch_seed toggles whether the shuffle changes from epoch to epoch
        generator.manual_seed(self.seed + self.epoch * self.epoch_seed + self.rank)
        pool = [
            BatchSampler(
                SubsetRandomSampler(b, generator=generator),
                bs,
                drop_last=self.drop_last,
            )
            for b, bs in zip(self.buckets, self.batch_sizes)
        ]
        pool = [iter(b) for b in pool]
        weights = (
            list(self.bucket_sampling_weights)
            if self.bucket_sampling_weights is not None
            else None
        )
        while pool:  # sample until all buckets are done
            idx, bucket = choices(list(enumerate(pool)), weights=weights)[0]
            try:
                batch = next(bucket)
                yield batch
            except StopIteration:
                pool.pop(idx)  # if bucket is done, throw it
                if weights is not None:
                    weights.pop(idx)
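# Minimal usage sketch (torch.distributed must be initialised, since the
# sampler queries the world size and rank in __init__; `ds` and `collate_fn`
# are placeholders):
#   buckets = [[0, 3, 5], [1, 2, 4]]  # index lists grouped by similar duration
#   loader = DataLoader(ds, batch_sampler=BucketSampler(buckets, batch_sizes=2),
#                       collate_fn=collate_fn)
#   for epoch in range(n_epochs):
#       loader.batch_sampler.set_epoch(epoch)  # reshuffles when epoch_seed=True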
class DatasetFactory(ABC):
    @abstractmethod
    def build(self) -> DatasetDict: ...
class HiFiTTS2_AudioLatent(DatasetFactory):
    def __init__(
        self,
        path: str | list[str] = "hifitts2_vae8_dataset",
        duration_column: str = "audio_duration",
        duration_path: str | None = None,
        expresso_path: str | None = None,
        min_dur: float = 3.0,
        max_dur: float = 20.1,
        framerate: float = 25.0,
    ):
        self.min_dur = min_dur
        self.max_dur = max_dur
        self.path = path
        self.duration_column = duration_column
        self.duration_path = duration_path
        self.framerate = framerate
        self.expresso_path = expresso_path

    def build(self):
        if isinstance(self.path, str):
            self.path = [self.path]
        datasets = [load_from_disk(x) for x in self.path]
        dataset = concatenate_datasets(datasets).with_format("torch")
        if self.duration_path is not None:
            # durations stored separately: join them column-wise
            duration_dataset = load_from_disk(self.duration_path)
            dataset = concatenate_datasets(
                [dataset, duration_dataset], axis=1
            ).with_format("torch")
        dataset = dataset.filter(
            lambda dur: dur > self.min_dur and dur < self.max_dur,
            input_columns=self.duration_column,
        )
        if self.duration_column != "audio_duration":
            # rename_column raises if source and target names coincide
            dataset = dataset.rename_column(self.duration_column, "audio_duration")
        # dataset = dataset.map(
        #     lambda x: {"audio_duration": x.shape[1] / self.framerate},
        #     input_columns="audio_latent",
        # ).filter(
        #     lambda dur: dur > self.min_dur and dur < self.max_dur,
        #     input_columns="audio_duration",
        # )
        if self.expresso_path is not None:
            # fold the Expresso split into the pool (it was loaded but never
            # merged; concatenation is the assumed intent)
            expresso_dataset = load_from_disk(self.expresso_path).with_format("torch")
            dataset = concatenate_datasets([dataset, expresso_dataset])
        dataset = dataset.sort("audio_duration")
        return DatasetDict({"train": dataset})
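# Example (hypothetical paths): merge two on-disk latent datasets into one
# duration-sorted training split.
#   ds = HiFiTTS2_AudioLatent(path=["vae8_part1", "vae8_part2"]).build()["train"]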
# @dataclass so the arguments can be unpacked with asdict() in
# LinaDataModule.setup
@dataclass
class SegmentsCollateArgs:
    abs_style_intensity: bool = False
    merge_endpoints: bool = True
    block_crossatt_mask: bool = True
    alternate_crossatt_pos: bool = False
    block_crossatt_past_tokens: int = 0
    block_crossatt_future_tokens: int = 0
    eos: bool = True
    bos: bool = True


@dataclass
class CollateArgs:
    abs_style_intensity: bool = False
    random_text_segment: bool = False
    eos: bool = True
    bos: bool = True
    num_stop_tokens: int = 1
def random_log_breakpoints(
    seq: Sequence, a: int, b: int, gap: bool = False
) -> List[int]:
    """
    Generate random breakpoints in a sequence where the gap X between
    successive breakpoints satisfies log2(X) ~ Uniform[log2(a), log2(b)].
    Gaps are then rounded to the nearest integer in [a, b].

    Parameters
    ----------
    seq : Sequence
        The input sequence in which to place breakpoints.
    a : int
        Minimum gap (>= 1).
    b : int
        Maximum gap (>= a).
    gap : bool
        If True, return the gap sizes (which sum to len(seq));
        if False, return sorted breakpoint positions (0 < idx < len(seq)).

    Returns
    -------
    List[int]
        Gap sizes or sorted breakpoint indices, depending on `gap`.
    """
    if a < 1 or b < a:
        raise ValueError("Require 1 <= a <= b")
    n = len(seq)
    breakpoints: List[int] = []
    pos = 0
    while True:
        # sample U ~ Uniform(log2(a), log2(b))
        u = random.uniform(math.log2(a), math.log2(b))
        # map back to X = 2^U, then round to nearest integer
        x = 2**u
        step = int(math.floor(x + 0.5))
        # enforce integer bounds exactly (named `step` so it does not shadow
        # the `gap` flag, which the original local variable did)
        step = max(a, min(b, step))
        pos += step
        if pos >= n:
            if gap:
                breakpoints.append(n - sum(breakpoints))
            break
        if gap:
            breakpoints.append(step)
        else:
            breakpoints.append(pos)
    return breakpoints
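# Illustrative draws for a sequence of length 10 with a=2, b=4 (actual values
# are random):
#   random_log_breakpoints(seq, 2, 4)            -> [3, 5, 8]     (positions)
#   random_log_breakpoints(seq, 2, 4, gap=True)  -> [3, 2, 3, 2]  (gaps, sum to len(seq))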
def standalone_collate_latent(
    batch,
    tokenizer,
    abs_style_intensity: bool = False,
    random_text_segment: bool = False,
    bos: bool = True,
    eos: bool = True,
    num_stop_tokens: int = 1,
):
    audio_latent, text = zip(*[(x["audio_latent"], x["text"]) for x in batch])
    audio_latent = [x.squeeze() for x in audio_latent]
    # wrap each transcript with the requested sentinel tokens
    text_pp = []
    for t in text:
        if bos:
            t = "[BOS]" + t
        if eos:
            t = t + "[EOS]"
        text_pp.append(t)
    text_token = [torch.LongTensor(tokenizer.encode(x)) for x in text_pp]
    xlen, ylen = map(lambda x: [xx.shape[0] for xx in x], (text_token, audio_latent))
    # place a (randomly indexed) stop token on the last frame / last text token
    stop_token = []
    text_stop_token = []
    for x, y in zip(xlen, ylen):
        tst = torch.zeros(x)
        st = torch.zeros(y)
        st_idx = random.randint(1, num_stop_tokens)
        st[-1] = st_idx
        tst[-1] = st_idx
        stop_token.append(st)
        text_stop_token.append(tst)
    stop_token = pad_sequence(stop_token, batch_first=True).long()
    text_stop_token = pad_sequence(text_stop_token, batch_first=True).long()
    x_mask, y_mask = map(
        lambda x: sequence_mask(x, device="cpu"),
        (torch.tensor(xlen), torch.tensor(ylen)),
    )
    text_rel_pos = None
    if random_text_segment:
        # split each text into random chunks and restrict self-attention to them
        breakpoints = [random_log_breakpoints(t, 32, 256, gap=True) for t in text_token]
        encoder_mask = pad_2d_sequence([packmask_2d(b, b) for b in breakpoints])
        text_rel_pos = [torch.cat([torch.arange(bb) for bb in b]) for b in breakpoints]
        text_rel_pos = pad_sequence(text_rel_pos, batch_first=True)
    else:
        encoder_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(2)
    crossatt_mask = x_mask.unsqueeze(1) * y_mask.unsqueeze(2)
    audio_latent, text_token = map(
        lambda x: pad_sequence(x, batch_first=True, padding_value=0.0),
        (audio_latent, text_token),
    )
    if abs_style_intensity:
        # replace NaN style labels with zero
        abs_style_intensity = [x["abs_style_intensity"] for x in batch]
        abs_style_intensity = [
            torch.zeros(1).long()[0] if x.isnan() else x for x in abs_style_intensity
        ]
        abs_style_intensity = torch.stack(abs_style_intensity)
    else:
        abs_style_intensity = None
    return {
        "text_token": text_token,
        "audio_token": audio_latent,
        "crossatt_mask": crossatt_mask,
        "encoder_mask": encoder_mask,
        "y_mask": y_mask,
        "stop_token": stop_token,
        "text_stop_token": text_stop_token,
        "x_len": xlen,
        "y_len": ylen,
        "abs_style_intensity": abs_style_intensity,
        "text_rel_pos": text_rel_pos,
    }
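# Shapes of the main outputs, as read from the code (B = batch, T = max text
# length, S = max latent length, D = latent dim): text_token (B, T),
# audio_token (B, S, D), encoder_mask (B, T, T), crossatt_mask (B, S, T),
# stop_token and y_mask (B, S).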
def standalone_collate_latent_segments(
    batch,
    tokenizer,
    abs_style_intensity: bool = False,
    merge_endpoints: bool = True,
    block_crossatt_mask: bool = True,
    block_crossatt_past_tokens: int = 0,
    block_crossatt_future_tokens: int = 0,
    alternate_crossatt_pos: bool = False,
    alternate_crossatt_shift: int = 1000,
    eos: bool = True,
    bos: bool = True,
):
    audio_latent, text, token_duration = zip(
        *[(x["audio_latent"], x["text"], x["token_duration"]) for x in batch]
    )
    text_pp = []
    for t in text:
        if bos:
            t = "[BOS]" + t
        if eos:
            t = t + "[EOS]"
        text_pp.append(t)
    if merge_endpoints:
        tokens = [tokenizer.encode(x) for x in text]
        # fold the first/last (endpoint) durations into their neighbors
        new_td = []
        for td in token_duration:
            begin, end = td[0], td[-1]
            tdd = td[1:-1]
            tdd[0] += begin
            tdd[-1] += end
            new_td.append(tdd)
        token_duration = new_td
    else:
        tokens = [tokenizer.encode(x) for x in text_pp]
    segments = [
        random_segments_from_text_and_durations(t, td.tolist())
        for t, td in zip(tokens, token_duration)
    ]
    # keep the encoded sentinels in separate names so they do not shadow the
    # bos/eos flags (the original reassigned bos/eos here, making the flags
    # always truthy)
    bos_ids, eos_ids = map(tokenizer.encode, ("[BOS]", "[EOS]"))
    audio_segments = []
    text_segments = []
    audio_segments_len = []
    text_segments_len = []
    for aud, seg in zip(audio_latent, segments):
        tt, at, tt_l, at_l = [], [], [], []
        for s in seg:
            ttoken = s["text_token"]
            if bos:
                ttoken = bos_ids + ttoken
            if eos:
                ttoken = ttoken + eos_ids
            tt.append(ttoken)
            a_s = aud[:, s["start"] : s["end"]]
            at.append(a_s)
            at_l.append(a_s.shape[1])
            tt_l.append(len(ttoken))
        audio_segments.append(at)
        text_segments.append(tt)
        audio_segments_len.append(at_l)
        text_segments_len.append(tt_l)
    text_token = [torch.LongTensor(reduce(list.__add__, x)) for x in text_segments]
    audio_latent = [torch.cat(a_ss, dim=1).squeeze(0) for a_ss in audio_segments]
    xlen, ylen = map(lambda x: [xx.shape[0] for xx in x], (text_token, audio_latent))
    x_mask, y_mask = map(
        lambda x: sequence_mask(x, device="cpu"),
        (torch.tensor(xlen), torch.tensor(ylen)),
    )
    audio_latent, text_token = map(
        lambda x: pad_sequence(x, batch_first=True, padding_value=0),
        (audio_latent, text_token),
    )
    encoder_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(2)
    if block_crossatt_mask:
        # each audio segment may only attend to its own text segment, plus an
        # optional window of past/future segments
        crossatt_mask = [
            audio_to_text_partial_neighbor_mask(
                x,
                y,
                past_tokens=block_crossatt_past_tokens,
                future_tokens=block_crossatt_future_tokens,
            )
            for x, y in zip(text_segments_len, audio_segments_len)
        ]
        crossatt_mask = pad_2d_sequence(crossatt_mask)
        # mask over padded audio rows; NOTE: computed but never applied here
        pad_mask = rearrange(torch.arange(max(ylen)), "n -> 1 n 1") >= rearrange(
            torch.tensor(ylen), "n -> n 1 1"
        )
    else:
        crossatt_mask = x_mask.unsqueeze(1) * y_mask.unsqueeze(2)
    # per-segment relative positions for the text tokens
    text_rel_pos = pad_sequence(
        [torch.cat([torch.arange(x) for x in tsl]) for tsl in text_segments_len],
        batch_first=True,
    )
    crossatt_rel_pos = None
    if alternate_crossatt_pos:
        # offset every other segment so cross-attention can tell them apart
        crossatt_rel_pos = []
        for tsl in text_segments_len:
            rel_pos = []
            random_shift = int(random.random() < 0.5)
            for i, x in enumerate(tsl):
                rel_pos.append(
                    torch.arange(x)
                    + ((random_shift + i) % 2) * alternate_crossatt_shift
                )
            crossatt_rel_pos.append(torch.cat(rel_pos))
        crossatt_rel_pos = pad_sequence(crossatt_rel_pos, batch_first=True)
    audio_rel_pos = pad_sequence(
        [torch.cat([torch.arange(x) for x in asl]) for asl in audio_segments_len],
        batch_first=True,
    )
    # a stop token at the last frame of every audio segment
    stop_token = []
    for asl in audio_segments_len:
        sts = []
        for x in asl:
            st = torch.zeros(x)
            st[-1] = 1
            sts.append(st)
        stop_token.append(torch.cat(sts))
    stop_token = pad_sequence(stop_token, batch_first=True).int()
    # and at the last token of every text segment
    text_stop_token = []
    for tsl in text_segments_len:
        sts = []
        for x in tsl:
            st = torch.zeros(x)
            st[-1] = 1
            sts.append(st)
        text_stop_token.append(torch.cat(sts))
    text_stop_token = pad_sequence(text_stop_token, batch_first=True).int()
    if abs_style_intensity:
        # replace NaN style labels with zero
        abs_style_intensity = [x["abs_style_intensity"] for x in batch]
        abs_style_intensity = [
            torch.zeros(1).long()[0] if x.isnan() else x for x in abs_style_intensity
        ]
        abs_style_intensity = torch.stack(abs_style_intensity)
    else:
        abs_style_intensity = None
    return {
        "text_token": text_token,
        "audio_token": audio_latent,
        "crossatt_mask": crossatt_mask,
        "encoder_mask": encoder_mask,
        "y_mask": y_mask,
        "stop_token": stop_token,
        "text_stop_token": text_stop_token,
        "x_mask": x_mask,
        "x_len": xlen,
        "y_len": ylen,
        "abs_style_intensity": abs_style_intensity,
        "text_rel_pos": text_rel_pos,
        "crossatt_rel_pos": crossatt_rel_pos,
        "audio_rel_pos": audio_rel_pos,
        "segments": segments,
    }
def random_segments_from_text_and_durations(
    text,
    dur,
    low_bnd: int = 8,
    up_bnd: int = 384,
):
    # cut the token sequence at log-uniform random positions
    breaks = random_log_breakpoints(text, low_bnd, up_bnd)
    bounds = [0] + breaks + [len(text)]
    segs, durs = [], []
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        segs.append(text[lo:hi])
        durs.append(sum(dur[lo:hi]))
    # convert per-segment durations into cumulative frame boundaries
    bounds = [0] + list(accumulate(durs, int.__add__))
    segs_dicts = []
    for t, s, e in zip(segs, bounds[:-1], bounds[1:]):
        segs_dicts.append(
            {
                "start": s,
                "end": e,
                "text_token": t,
            }
        )
    # include the final frame in the last segment
    segs_dicts[-1]["end"] += 1
    return segs_dicts
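# Sketch: for 3 tokens with durations [2, 3, 4] and a single (random) cut
# after token 1, this returns
#   [{"start": 0, "end": 2, "text_token": text[0:1]},
#    {"start": 2, "end": 10, "text_token": text[1:3]}]  # last end includes +1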
class LinaDataModule(ptl.LightningDataModule):
    def __init__(
        self,
        path: str | DatasetFactory,
        quant_layer: list[int],
        train_batch_size: int = 8,
        token_by_batch: int | None = None,
        n_buckets: int = 5,
        codec_rate_hz: int = 75,
        num_workers: int = 8,
        test_size: int = 2000,
        val_batch_size: int = 8,
        seed: int = 123,
        train_test_seed: int = 123,
        segments: bool = False,
        segments_args: SegmentsCollateArgs | None = None,
        collate_args: CollateArgs | None = None,
        block_mask_segments: bool = False,
        tokenizer_file=None,
        trail_end_frame: int | None = None,
        split: str = "train",
        add_columns: str | list[str] | None = None,
        add_text_tokens: list[str] | None = None,
        type: str = "latent",
    ):
        super().__init__()
        self.path = path
        self.codec_rate_hz = codec_rate_hz
        self.num_workers = num_workers
        self.quant_layer = quant_layer
        self.seed = seed
        self.segments = segments
        # dataclasses.field() is only meaningful inside a @dataclass, so the
        # defaults are built from None sentinels instead
        self.segments_args = (
            segments_args if segments_args is not None else SegmentsCollateArgs()
        )
        self.collate_args = collate_args if collate_args is not None else CollateArgs()
        self.train_test_seed = train_test_seed
        self.test_size = test_size
        self.val_batch_size = val_batch_size
        self.train_batch_size = train_batch_size
        self.split = split
        self.trail_end_frame = trail_end_frame
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
        if add_text_tokens:
            self.tokenizer.add_tokens(add_text_tokens)
        self.add_columns = add_columns
        self.n_buckets = n_buckets
        self.token_by_batch = token_by_batch
        self.type = type
    def setup(self, stage):
        if isinstance(self.path, DatasetFactory):
            self.dataset = self.path.build()
        else:
            self.dataset = load_dataset(self.path)
        split = self.split
        columns = [
            "audio_latent" if self.type == "latent" else "audio_token",
            "text",
            "audio_duration",
        ]
        if self.add_columns is not None:
            if isinstance(self.add_columns, str):
                self.add_columns = [self.add_columns]
            columns += self.add_columns
        if self.segments:
            columns += ["token_duration"]
            # the token-based collates (segments_collate / standalone_collate)
            # are expected to be provided elsewhere; only the latent path is
            # defined in this module
            self.collate_fn = lambda x: segments_collate(x, self.tokenizer)
        else:
            self.collate_fn = lambda x: standalone_collate(
                x, self.tokenizer, abs_style_intensity="abs_style_intensity" in columns
            )
        self.dataset = (
            self.dataset[split]
            .train_test_split(test_size=self.test_size, seed=self.train_test_seed)
            .select_columns(columns)
        )
        if self.type == "latent":
            if self.segments:
                self.collate_fn = lambda x: standalone_collate_latent_segments(
                    x,
                    self.tokenizer,
                    **asdict(self.segments_args),
                )
            else:
                self.collate_fn = lambda x: standalone_collate_latent(
                    x,
                    self.tokenizer,
                    **asdict(self.collate_args),
                )
        def get_buckets_by_quantile(duration, n_quantile, is_sorted=False):
            if is_sorted:
                size = len(duration)
                bucket_size = size // n_quantile
                buckets = [
                    list(range(i, min(i + bucket_size, size)))
                    for i in range(0, size, bucket_size)
                ]
            else:
                # sort indices by duration, then chunk into equal-sized buckets
                idxdur = list(enumerate(duration))
                idxdur.sort(key=lambda x: x[1])
                idx, dur = zip(*idxdur)
                bucket_size = len(idx) // n_quantile
                buckets = [list(x) for x in zip(*[iter(idx)] * bucket_size)]
            return buckets
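        # e.g. durations [5.0, 1.0, 3.0, 2.0, 4.0, 6.0] with n_quantile=3:
        # indices sorted by duration -> (1, 3, 2, 4, 0, 5), chunked into
        # buckets [[1, 3], [2, 4], [0, 5]]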
        if self.token_by_batch is not None:
            # size each bucket's batches so a batch holds ~token_by_batch
            # frames, using the longest clip in the bucket as per-sample cost
            train_buckets = get_buckets_by_quantile(
                self.dataset["train"]["audio_duration"], self.n_buckets
            )
            max_audio_durations = [
                self.dataset["train"]["audio_duration"][x[-1]] for x in train_buckets
            ]
            batch_sizes = [
                int(self.token_by_batch // (self.codec_rate_hz * ad))
                for ad in max_audio_durations
            ]
            self.train_batch_sampler = BucketSampler(train_buckets, batch_sizes)
    def train_dataloader(self):
        if self.token_by_batch is not None:
            return DataLoader(
                self.dataset["train"].with_format("torch"),
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
                batch_sampler=self.train_batch_sampler,
            )
        else:
            return DataLoader(
                self.dataset["train"].with_format("torch"),
                num_workers=self.num_workers,
                batch_size=self.train_batch_size,
                collate_fn=self.collate_fn,
            )

    def val_dataloader(self):
        return DataLoader(
            self.dataset["test"].with_format("torch"),
            batch_size=self.val_batch_size,
            num_workers=0,
            collate_fn=self.collate_fn,
        )
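# End-to-end sketch (the dataset directory and tokenizer file below are
# hypothetical, and BucketSampler requires torch.distributed to be
# initialised before setup() runs):
#   dm = LinaDataModule(
#       path=HiFiTTS2_AudioLatent(path="hifitts2_vae8_dataset"),
#       quant_layer=[0],
#       token_by_batch=20000,
#       tokenizer_file="tokenizer.json",
#   )
#   dm.setup("fit")
#   batch = next(iter(dm.train_dataloader()))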