# pardi-speech / tts/groupdataset.py
# Commit 56cfa73 (Mehdi Lakbar): Initial demo of Lina-speech (pardi-speech)
import math
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import reduce
from itertools import accumulate
from random import choices
from typing import List, Optional, Sequence, Tuple
import pytorch_lightning as ptl
import torch
from datasets import (DatasetDict, concatenate_datasets, load_dataset,
load_from_disk)
from einops import rearrange
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import (BatchSampler, DataLoader, Sampler,
SubsetRandomSampler)
from transformers import PreTrainedTokenizerFast
from tts.tools import (audio_to_text_partial_neighbor_mask, packmask_2d,
pad_2d_sequence, sequence_mask)
class BucketSampler(Sampler[List[int]]):
def __init__(
self,
buckets: List[List[int]],
batch_sizes: List[int] | int,
bucket_sampling_weights: List[Tuple[float]] = None,
drop_last: bool = True,
distributed: bool = True, # TODO - implement not distributed as well
sample_bucket: Optional[int] = None,
seed: int = 123,
epoch_seed: bool = True,
):
if type(batch_sizes) is int:
batch_sizes = [batch_sizes] * len(buckets)
else:
assert len(buckets) == len(batch_sizes)
if bucket_sampling_weights is not None:
assert len(bucket_sampling_weights) == len(batch_sizes)
self.bucket_sampling_weights = bucket_sampling_weights
self.num_replicas = torch.distributed.get_world_size()
self.rank = torch.distributed.get_rank()
self.buckets = [
b[self.rank : len(b) - len(b) % self.num_replicas : self.num_replicas]
for b in buckets
]
self.num_samples = [len(b) // self.num_replicas for b in buckets]
self.batch_sizes = batch_sizes
self.total_sizes = [
ns // bs for ns, bs in zip(self.num_samples, self.batch_sizes)
]
self.drop_last = drop_last
self.seed = seed
self.epoch = 0
self.sample_bucket = sample_bucket
self.epoch_seed = epoch_seed
self.batch_size = batch_sizes[0]
def set_epoch(self, epoch: int):
self.epoch = epoch
def __len__(self):
return sum(self.total_sizes)
def __iter__(self):
generator = torch.Generator()
generator.manual_seed(self.seed + self.epoch * self.epoch_seed + self.rank)
pool = [
BatchSampler(
SubsetRandomSampler(b, generator=generator),
bs,
drop_last=self.drop_last,
)
for b, bs in zip(self.buckets, self.batch_sizes)
]
pool = [iter(b) for b in pool]
weights = (
[w for w in self.bucket_sampling_weights]
if self.bucket_sampling_weights is not None
else None
)
while pool: # sample until all buckets are done
idx, bucket = choices(list(enumerate(pool)), weights=weights)[0]
try:
batch = next(bucket)
yield batch
except StopIteration:
pool.pop(idx) # if bucket is done, throw it
if weights is not None:
weights.pop(idx)
class DatasetFactory(ABC):
    """Abstract recipe for producing the dataset consumed by ``LinaDataModule``.

    Concrete factories implement :meth:`build`; ``LinaDataModule.setup``
    calls it when its ``path`` argument is a factory instance and then
    indexes the result by split name.
    """

    @abstractmethod
    def build(self):
        """Construct and return the dataset."""
class HiFiTTS2_AudioLatent(DatasetFactory):
    """Factory for the HiFiTTS-2 audio-latent training set.

    ``build`` loads one or more datasets saved with ``save_to_disk``,
    optionally joins a separate duration dataset column-wise, keeps rows
    whose duration lies strictly between ``min_dur`` and ``max_dur``, exposes
    that duration as ``"audio_duration"`` and sorts by it. The result is
    returned as the ``"train"`` split of a ``DatasetDict``.

    Args:
        path: Directory (or list of directories) passed to ``load_from_disk``;
            several directories are concatenated row-wise.
        duration_column: Column used for duration filtering; exposed as
            ``"audio_duration"`` in the output.
        duration_path: Optional dataset concatenated column-wise
            (``axis=1``), e.g. to provide ``duration_column``; it must be
            row-aligned with the main dataset.
        expresso_path: Stored but not used by ``build``.
        min_dur, max_dur: Exclusive bounds on ``duration_column`` (units as
            stored in that column; presumably seconds - TODO confirm).
        framerate: Stored for reference; not used by ``build``.
    """

    def __init__(
        self,
        path: str | list[str] = "hifitts2_vae8_dataset",
        duration_column: str = "audio_duration",
        duration_path: str | None = None,
        expresso_path: str | None = None,
        min_dur: float = 3.0,
        max_dur: float = 20.1,
        framerate: float = 25.0,
    ):
        self.min_dur = min_dur
        self.max_dur = max_dur
        self.path = path
        self.duration_column = duration_column
        self.duration_path = duration_path
        self.framerate = framerate
        self.expresso_path = expresso_path

    def build(self):
        """Load, filter and sort the dataset; return it as the ``"train"`` split."""
        # Normalise locally instead of overwriting ``self.path``.
        paths = [self.path] if isinstance(self.path, str) else list(self.path)
        dataset = concatenate_datasets(
            [load_from_disk(p) for p in paths]
        ).with_format("torch")
        if self.duration_path is not None:
            # Column-wise join: both datasets must have the same rows, in order.
            duration_dataset = load_from_disk(self.duration_path)
            dataset = concatenate_datasets(
                [dataset, duration_dataset], axis=1
            ).with_format("torch")
        dataset = dataset.filter(
            lambda dur: self.min_dur < dur < self.max_dur,
            input_columns=self.duration_column,
        )
        # `rename_column` rejects a target name that already exists, so only
        # rename when the duration column is not already "audio_duration"
        # (the default).
        if self.duration_column != "audio_duration":
            dataset = dataset.rename_column(self.duration_column, "audio_duration")
        # NOTE(review): `expresso_path` is not merged into the dataset - TODO
        # decide whether it should be concatenated here.
        dataset = dataset.sort("audio_duration")
        return DatasetDict({"train": dataset})
@dataclass
class SegmentsCollateArgs:
    """Keyword options for ``standalone_collate_latent_segments``.

    Each field has the same name as a keyword parameter of that function.
    ``alternate_crossatt_shift`` is not exposed here, so the function's own
    default applies.
    """

    # Collate the per-row "abs_style_intensity" values (NaN replaced by 0).
    abs_style_intensity: bool = False
    # Tokenize the bare text and fold the first/last token_duration entries
    # into their neighbours; otherwise tokenize [BOS]/[EOS]-wrapped text.
    merge_endpoints: bool = True
    # Build the cross-attention mask per segment with
    # audio_to_text_partial_neighbor_mask instead of the full padding mask.
    block_crossatt_mask: bool = True
    # Also return crossatt_rel_pos, offsetting every other segment's positions.
    alternate_crossatt_pos: bool = False
    # Forwarded as past_tokens / future_tokens to
    # audio_to_text_partial_neighbor_mask.
    block_crossatt_past_tokens: int = 0
    block_crossatt_future_tokens: int = 0
    # Forwarded as the eos / bos flags of standalone_collate_latent_segments.
    eos: bool = True
    bos: bool = True
@dataclass
class CollateArgs:
    """Keyword options for ``standalone_collate_latent``.

    Each field has the same name as a keyword parameter of that function.
    """

    # Collate the per-row "abs_style_intensity" values (NaN replaced by 0).
    abs_style_intensity: bool = False
    # Cut each token sequence into random blocks (random_log_breakpoints with
    # bounds 32..256) and build the encoder mask / text_rel_pos from them.
    random_text_segment: bool = False
    # Append "[EOS]" / prepend "[BOS]" to the text before tokenizing.
    eos: bool = True
    bos: bool = True
    # Per row, a stop id is drawn uniformly from 1..num_stop_tokens.
    num_stop_tokens: int = 1
def random_log_breakpoints(
    seq: Sequence, a: int, b: int, gap: bool = False
) -> List[int]:
    """
    Generate random breakpoints in a sequence where the gap X between
    successive breakpoints satisfies log2(X) ~ Uniform[log2(a), log2(b)].
    Gaps are then rounded to the nearest integer in [a, b].

    Parameters
    ----------
    seq : Sequence
        The input sequence in which to place breakpoints (only its length
        is used).
    a : int
        Minimum gap (>= 1).
    b : int
        Maximum gap (>= a).
    gap : bool, default False
        If False, return breakpoint positions. If True, return segment
        lengths instead, with the last one truncated so that they sum to
        ``len(seq)``.

    Returns
    -------
    List[int]
        ``gap=False``: sorted breakpoint indices (0 < idx < len(seq)).
        ``gap=True``: segment lengths summing to ``len(seq)`` (``[0]`` for an
        empty sequence).

    Raises
    ------
    ValueError
        If ``a < 1`` or ``b < a``.
    """
    if a < 1 or b < a:
        raise ValueError("Require 1 <= a <= b")
    n = len(seq)
    log_lo, log_hi = math.log2(a), math.log2(b)
    breakpoints: List[int] = []
    pos = 0
    while True:
        # sample U ~ Uniform(log2(a), log2(b)), map back to X = 2^U and round
        # half-up. The local is named `step` so it does not shadow `gap`.
        u = random.uniform(log_lo, log_hi)
        step = int(math.floor(2**u + 0.5))
        # enforce integer bounds exactly
        step = max(a, min(b, step))
        pos += step
        if pos >= n:
            if gap:
                # Final segment is truncated to end exactly at n.
                breakpoints.append(n - (pos - step))
            break
        breakpoints.append(step if gap else pos)
    return breakpoints
def standalone_collate_latent(
    batch,
    tokenizer,
    abs_style_intensity: bool = False,
    random_text_segment: bool = False,
    bos: bool = True,
    eos: bool = True,
    num_stop_tokens: int = 1,
):
    """Collate dataset rows into a padded text / audio-latent batch.

    Args:
        batch: Rows with ``"audio_latent"`` and ``"text"`` entries (plus
            ``"abs_style_intensity"`` when that flag is set).
        tokenizer: Object whose ``encode(str)`` returns a list of token ids.
        abs_style_intensity: Stack the rows' ``abs_style_intensity`` values,
            replacing NaN with 0.
        random_text_segment: Cut each token sequence into random blocks
            (bounds 32..256 via ``random_log_breakpoints``); the encoder mask
            is built from those blocks with ``packmask_2d`` and
            ``text_rel_pos`` restarts at 0 in every block.
        bos / eos: Prepend ``"[BOS]"`` / append ``"[EOS]"`` to the text
            before tokenizing.
        num_stop_tokens: The last text token and last audio frame of a row
            share a stop id drawn uniformly from ``1..num_stop_tokens``.

    Returns:
        Dict with padded ``text_token`` / ``audio_token``, attention masks,
        stop-token targets, per-row lengths and optional extras.
    """
    latents, texts = zip(*[(row["audio_latent"], row["text"]) for row in batch])
    latents = [lat.squeeze() for lat in latents]

    prefix = "[BOS]" if bos else ""
    suffix = "[EOS]" if eos else ""
    tokens = [torch.LongTensor(tokenizer.encode(prefix + t + suffix)) for t in texts]

    xlen = [tok.shape[0] for tok in tokens]
    ylen = [lat.shape[0] for lat in latents]

    # One random stop id per row, written on its last audio frame and last
    # text token.
    audio_stops, text_stops = [], []
    for n_text, n_audio in zip(xlen, ylen):
        text_flags = torch.zeros(n_text)
        audio_flags = torch.zeros(n_audio)
        stop_id = random.randint(1, num_stop_tokens)
        audio_flags[-1] = stop_id
        text_flags[-1] = stop_id
        audio_stops.append(audio_flags)
        text_stops.append(text_flags)
    stop_token = pad_sequence(audio_stops, batch_first=True).long()
    text_stop_token = pad_sequence(text_stops, batch_first=True).long()

    x_mask = sequence_mask(torch.tensor(xlen), device="cpu")
    y_mask = sequence_mask(torch.tensor(ylen), device="cpu")

    if random_text_segment:
        block_lens = [random_log_breakpoints(tok, 32, 256, gap=True) for tok in tokens]
        encoder_mask = pad_2d_sequence([packmask_2d(bl, bl) for bl in block_lens])
        text_rel_pos = pad_sequence(
            [torch.cat([torch.arange(n) for n in bl]) for bl in block_lens],
            batch_first=True,
        )
    else:
        encoder_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(2)
        text_rel_pos = None
    crossatt_mask = x_mask.unsqueeze(1) * y_mask.unsqueeze(2)

    audio_token = pad_sequence(latents, batch_first=True, padding_value=0.0)
    text_token = pad_sequence(tokens, batch_first=True, padding_value=0.0)

    style = None
    if abs_style_intensity:
        raw = [row["abs_style_intensity"] for row in batch]
        style = torch.stack(
            [torch.zeros(1).long()[0] if v.isnan() else v for v in raw]
        )

    return {
        "text_token": text_token,
        "audio_token": audio_token,
        "crossatt_mask": crossatt_mask,
        "encoder_mask": encoder_mask,
        "y_mask": y_mask,
        "stop_token": stop_token,
        "text_stop_token": text_stop_token,
        "x_len": xlen,
        "y_len": ylen,
        "abs_style_intensity": style,
        "text_rel_pos": text_rel_pos,
    }
def standalone_collate_latent_segments(
    batch,
    tokenizer,
    abs_style_intensity: bool = False,
    merge_endpoints: bool = True,
    block_crossatt_mask: bool = True,
    block_crossatt_past_tokens: int = 0,
    block_crossatt_future_tokens: int = 0,
    alternate_crossatt_pos: bool = False,
    alternate_crossatt_shift: int = 1000,
    eos: bool = True,
    bos: bool = True,
):
    """Collate rows into a batch of randomly segmented text/audio pairs.

    Each row's token sequence is cut into random contiguous segments by
    ``random_segments_from_text_and_durations``, which uses the per-token
    ``token_duration`` values to find each segment's audio-latent frames.
    Segments are concatenated back into one text and one audio sequence per
    row, with per-segment relative positions and stop flags and, optionally,
    a block-structured cross-attention mask.

    Args:
        batch: Rows with ``"audio_latent"``, ``"text"`` and
            ``"token_duration"`` (plus ``"abs_style_intensity"`` when that
            flag is set). Audio latents are sliced on dim 1 as time.
        tokenizer: Object whose ``encode(str)`` returns a list of token ids.
        abs_style_intensity: Stack the rows' ``abs_style_intensity`` values,
            replacing NaN with 0.
        merge_endpoints: Tokenize the bare text and fold the first/last
            ``token_duration`` entries into the first/last kept token;
            otherwise tokenize the text wrapped in ``[BOS]``/``[EOS]`` (per
            the ``bos``/``eos`` flags).
        block_crossatt_mask: Build the cross-attention mask per segment with
            ``audio_to_text_partial_neighbor_mask``; otherwise use the full
            text x audio padding mask.
        block_crossatt_past_tokens / block_crossatt_future_tokens: Forwarded
            as ``past_tokens`` / ``future_tokens`` to that helper.
        alternate_crossatt_pos: Also return ``crossatt_rel_pos``, where every
            other segment's positions are offset by
            ``alternate_crossatt_shift`` (starting parity random per row).
        eos / bos: Wrap every segment's tokens with the ``[EOS]`` / ``[BOS]``
            token ids.

    Returns:
        Dict of padded tensors, masks, lengths and the raw ``segments``.
    """
    audio_latent, text, token_duration = zip(
        *[(x["audio_latent"], x["text"], x["token_duration"]) for x in batch]
    )
    # Plain lists: lets the endpoint merge below run without mutating the
    # rows' duration tensors in place.
    durations = [td.tolist() for td in token_duration]

    if merge_endpoints:
        tokens = [tokenizer.encode(t) for t in text]
        merged = []
        for d in durations:
            # Drop the two endpoint entries, adding their durations to the
            # first / last remaining token.
            inner = d[1:-1]
            inner[0] += d[0]
            inner[-1] += d[-1]
            merged.append(inner)
        durations = merged
    else:
        # NOTE(review): the segment loop below also wraps each segment, so the
        # first/last segment presumably end up with [BOS]/[EOS] twice here -
        # confirm this is intended.
        tokens = [
            tokenizer.encode(("[BOS]" if bos else "") + t + ("[EOS]" if eos else ""))
            for t in text
        ]

    segments = [
        random_segments_from_text_and_durations(tok, dur)
        for tok, dur in zip(tokens, durations)
    ]

    # Distinct names so the `bos` / `eos` flags are not shadowed.
    bos_ids, eos_ids = map(tokenizer.encode, ("[BOS]", "[EOS]"))
    audio_segments = []
    text_segments = []
    audio_segments_len = []
    text_segments_len = []
    for aud, seg in zip(audio_latent, segments):
        tt, at, tt_l, at_l = [], [], [], []
        for s in seg:
            ttoken = s["text_token"]
            if bos:
                ttoken = bos_ids + ttoken
            if eos:
                ttoken = ttoken + eos_ids
            tt.append(ttoken)
            a_s = aud[:, s["start"] : s["end"]]
            at.append(a_s)
            at_l.append(a_s.shape[1])
            tt_l.append(len(ttoken))
        audio_segments.append(at)
        text_segments.append(tt)
        audio_segments_len.append(at_l)
        text_segments_len.append(tt_l)

    text_token = [torch.LongTensor(reduce(list.__add__, x)) for x in text_segments]
    audio_latent = [torch.cat(a_ss, dim=1).squeeze(0) for a_ss in audio_segments]
    xlen = [t.shape[0] for t in text_token]
    ylen = [a.shape[0] for a in audio_latent]
    x_mask = sequence_mask(torch.tensor(xlen), device="cpu")
    y_mask = sequence_mask(torch.tensor(ylen), device="cpu")
    audio_latent = pad_sequence(audio_latent, batch_first=True, padding_value=0)
    text_token = pad_sequence(text_token, batch_first=True, padding_value=0)

    encoder_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(2)
    if block_crossatt_mask:
        crossatt_mask = pad_2d_sequence(
            [
                audio_to_text_partial_neighbor_mask(
                    x,
                    y,
                    past_tokens=block_crossatt_past_tokens,
                    future_tokens=block_crossatt_future_tokens,
                )
                for x, y in zip(text_segments_len, audio_segments_len)
            ]
        )
    else:
        crossatt_mask = x_mask.unsqueeze(1) * y_mask.unsqueeze(2)

    # Positions restart at 0 in every segment.
    text_rel_pos = pad_sequence(
        [torch.cat([torch.arange(x) for x in tsl]) for tsl in text_segments_len],
        batch_first=True,
    )

    crossatt_rel_pos = None
    if alternate_crossatt_pos:
        crossatt_rel_pos = []
        for tsl in text_segments_len:
            random_shift = int(random.random() < 0.5)
            rel_pos = [
                torch.arange(x) + ((random_shift + i) % 2) * alternate_crossatt_shift
                for i, x in enumerate(tsl)
            ]
            crossatt_rel_pos.append(torch.cat(rel_pos))
        crossatt_rel_pos = pad_sequence(crossatt_rel_pos, batch_first=True)

    audio_rel_pos = pad_sequence(
        [torch.cat([torch.arange(x) for x in asl]) for asl in audio_segments_len],
        batch_first=True,
    )

    def segment_end_flags(lengths):
        # 1 on the last element of every segment, 0 elsewhere.
        flags = []
        for n in lengths:
            st = torch.zeros(n)
            st[-1] = 1
            flags.append(st)
        return torch.cat(flags)

    stop_token = pad_sequence(
        [segment_end_flags(asl) for asl in audio_segments_len], batch_first=True
    ).int()
    text_stop_token = pad_sequence(
        [segment_end_flags(tsl) for tsl in text_segments_len], batch_first=True
    ).int()

    if abs_style_intensity:
        abs_style_intensity = [x["abs_style_intensity"] for x in batch]
        abs_style_intensity = [
            torch.zeros(1).long()[0] if x.isnan() else x for x in abs_style_intensity
        ]
        abs_style_intensity = torch.stack(abs_style_intensity)
    else:
        abs_style_intensity = None

    return {
        "text_token": text_token,
        "audio_token": audio_latent,
        "crossatt_mask": crossatt_mask,
        "encoder_mask": encoder_mask,
        "y_mask": y_mask,
        "stop_token": stop_token,
        "text_stop_token": text_stop_token,
        "x_mask": x_mask,
        "x_len": xlen,
        "y_len": ylen,
        "abs_style_intensity": abs_style_intensity,
        "text_rel_pos": text_rel_pos,
        "crossatt_rel_pos": crossatt_rel_pos,
        "audio_rel_pos": audio_rel_pos,
        "segments": segments,
    }
def random_segments_from_text_and_durations(
    text,
    dur,
    low_bnd: int = 8,
    up_bnd: int = 384,
):
    """Cut a token sequence into random contiguous segments and locate their audio.

    Segment lengths (in tokens) come from ``random_log_breakpoints``: they
    are log-uniform in ``[low_bnd, up_bnd]``, with the last one truncated.
    A segment's audio span is the running sum of the per-token durations
    ``dur`` over the tokens before it and through its end.

    Args:
        text: Sliceable token sequence (e.g. a list of token ids).
        dur: Per-token durations aligned with ``text``.
        low_bnd, up_bnd: Bounds on the segment length in tokens.

    Returns:
        List of ``{"start", "end", "text_token"}`` dicts covering ``text`` in
        order. ``start``/``end`` are offsets in the same units as ``dur``; the
        last segment's ``end`` is extended by one.
    """
    # Ask for segment lengths (gap=True). They always sum to len(text), so
    # their running sum gives the token boundaries [0, ..., len(text)].
    lengths = random_log_breakpoints(text, low_bnd, up_bnd, gap=True)
    token_bounds = [0, *accumulate(lengths)]
    spans = list(zip(token_bounds, token_bounds[1:]))
    segs = [text[a:b] for a, b in spans]
    seg_durs = [sum(dur[a:b]) for a, b in spans]
    # Default `+` accumulation, so both int and float durations work.
    frame_bounds = [0, *accumulate(seg_durs)]
    segs_dicts = [
        {
            "start": s,
            "end": e,
            "text_token": t,
        }
        for t, s, e in zip(segs, frame_bounds, frame_bounds[1:])
    ]
    # NOTE(review): the final span is extended by one frame, presumably
    # because the latent sequence is one frame longer than sum(dur) - confirm.
    segs_dicts[-1]["end"] += 1
    return segs_dicts
class LinaDataModule(ptl.LightningDataModule):
    """Lightning data module serving text / audio-latent batches.

    The dataset comes either from a :class:`DatasetFactory` (when ``path`` is
    a factory) or from ``datasets.load_dataset(path)``. ``setup`` picks the
    collate function, selects the needed columns of ``split``, carves out a
    ``test_size`` validation split and, when ``token_by_batch`` is set,
    builds a :class:`BucketSampler` over duration quantiles. That sampler
    bounds each training batch at roughly ``token_by_batch`` codec frames.

    Only ``type == "latent"`` is supported by the collate functions in this
    module. ``quant_layer``, ``seed``, ``block_mask_segments`` and
    ``trail_end_frame`` are stored or accepted but not used here.
    ``segments_args`` / ``collate_args`` may be dataclass instances or plain
    mappings; ``None`` means the dataclass defaults.
    """

    def __init__(
        self,
        path: str | DatasetFactory,
        quant_layer: list[int],
        train_batch_size: int = 8,
        token_by_batch: int | None = None,
        n_buckets=5,
        codec_rate_hz: int = 75,
        num_workers: int = 8,
        test_size: int = 2000,
        val_batch_size: int = 8,
        seed: int = 123,
        train_test_seed: int = 123,
        segments: bool = False,
        segments_args: SegmentsCollateArgs | None = None,
        collate_args: CollateArgs | None = None,
        block_mask_segments: bool = False,
        tokenizer_file=None,
        trail_end_frame: int | None = None,
        split="train",
        add_columns: str | list[str] | None = None,
        add_text_tokens: list[str] | None = None,
        type: str = "latent",
    ):
        super().__init__()
        self.path = path
        self.codec_rate_hz = codec_rate_hz
        self.num_workers = num_workers
        self.quant_layer = quant_layer
        self.seed = seed
        self.segments = segments
        # Defaults are built here: `dataclasses.field(...)` is only meaningful
        # inside a dataclass, not as a function default.
        self.segments_args = (
            segments_args if segments_args is not None else SegmentsCollateArgs()
        )
        self.collate_args = collate_args if collate_args is not None else CollateArgs()
        self.train_test_seed = train_test_seed
        self.test_size = test_size
        self.val_batch_size = val_batch_size
        self.train_batch_size = train_batch_size
        self.split = split
        self.trail_end_frame = trail_end_frame
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
        if add_text_tokens:
            self.tokenizer.add_tokens(add_text_tokens)
        self.add_columns = add_columns
        self.n_buckets = n_buckets
        self.token_by_batch = token_by_batch
        self.type = type

    @staticmethod
    def _as_kwargs(args) -> dict:
        """Turn a collate-args dataclass instance (or mapping) into kwargs."""
        from dataclasses import asdict, is_dataclass

        if is_dataclass(args) and not isinstance(args, type):
            return asdict(args)
        return dict(args)

    def _build_collate_fn(self):
        """Return the collate function for the configured type / segmentation.

        A ``functools.partial`` of a module-level function is used instead of
        a lambda so it can be pickled for worker processes (provided the
        tokenizer is picklable).
        """
        from functools import partial

        if self.type != "latent":
            raise NotImplementedError(
                f"No collate function available for type={self.type!r}; "
                "only 'latent' is supported"
            )
        if self.segments:
            return partial(
                standalone_collate_latent_segments,
                tokenizer=self.tokenizer,
                **self._as_kwargs(self.segments_args),
            )
        return partial(
            standalone_collate_latent,
            tokenizer=self.tokenizer,
            **self._as_kwargs(self.collate_args),
        )

    @staticmethod
    def _buckets_by_quantile(duration, n_quantile, is_sorted=False):
        """Split indices into ``n_quantile`` buckets of increasing duration.

        Indices are ordered by duration (or taken in order when
        ``is_sorted``) and cut into equal chunks. The ``len % n_quantile``
        leftover (longest) items go into the last bucket instead of being
        dropped. Fewer buckets are returned when there are fewer items than
        ``n_quantile``, and none for empty input.
        """
        if is_sorted:
            order = list(range(len(duration)))
        else:
            # Stable sort by duration, keeping the original indices.
            order = [i for i, _ in sorted(enumerate(duration), key=lambda x: x[1])]
        if not order:
            return []
        n_quantile = max(1, min(n_quantile, len(order)))
        size = len(order) // n_quantile
        buckets = [order[i * size : (i + 1) * size] for i in range(n_quantile)]
        buckets[-1].extend(order[n_quantile * size :])
        return buckets

    def setup(self, stage=None):
        """Load the dataset, split off validation data, build collate/sampler.

        Args:
            stage: Lightning stage name; unused (the same setup runs for all
                stages).
        """
        # Fail fast on unsupported configurations, before loading any data.
        self.collate_fn = self._build_collate_fn()

        if isinstance(self.path, DatasetFactory):
            dataset = self.path.build()
        else:
            dataset = load_dataset(self.path)

        columns = [
            "audio_latent" if self.type == "latent" else "audio_token",
            "text",
            "audio_duration",
        ]
        if self.add_columns is not None:
            if isinstance(self.add_columns, str):
                columns.append(self.add_columns)
            else:
                columns.extend(self.add_columns)
        if self.segments:
            columns.append("token_duration")

        self.dataset = (
            dataset[self.split]
            .train_test_split(test_size=self.test_size, seed=self.train_test_seed)
            .select_columns(columns)
        )

        if self.token_by_batch is not None:
            # Read the duration column once rather than once per bucket.
            durations = self.dataset["train"]["audio_duration"]
            train_buckets = self._buckets_by_quantile(durations, self.n_buckets)
            # Buckets are sorted by duration, so the last index is the longest
            # item. Sizing from it keeps a batch at about token_by_batch codec
            # frames (assuming audio_duration is in seconds). Clamp to >= 1
            # because BatchSampler rejects a batch size of 0.
            batch_sizes = [
                max(
                    1,
                    int(
                        self.token_by_batch
                        // (self.codec_rate_hz * durations[bucket[-1]])
                    ),
                )
                for bucket in train_buckets
            ]
            self.train_batch_sampler = BucketSampler(train_buckets, batch_sizes)

    def train_dataloader(self):
        """Training loader: bucketed batches when ``token_by_batch`` is set."""
        if self.token_by_batch is not None:
            return DataLoader(
                self.dataset["train"].with_format("torch"),
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
                batch_sampler=self.train_batch_sampler,
            )
        return DataLoader(
            self.dataset["train"].with_format("torch"),
            num_workers=self.num_workers,
            batch_size=self.train_batch_size,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        """Validation loader over the held-out split (main-process loading)."""
        return DataLoader(
            self.dataset["test"].with_format("torch"),
            batch_size=self.val_batch_size,
            num_workers=0,
            collate_fn=self.collate_fn,
        )