from __future__ import annotations

from pathlib import Path
from random import Random
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from .hparams import HParams


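class TextAudioDataset(Dataset):
    """Loads preprocessed feature dicts saved next to each audio file.

    Each line of the training (or validation) filelist is an audio path; the
    features for that file are read from ``<audio path>.data.pt``.
    """

    def __init__(self, hps: HParams, is_validation: bool = False):
        filelist = Path(
            hps.data.validation_files if is_validation else hps.data.training_files
        )
        # Append ".data.pt" to each listed path to locate its feature file.
        self.datapaths = [
            Path(x).parent / (Path(x).name + ".data.pt")
            for x in filelist.read_text("utf-8").splitlines()
        ]
        self.hps = hps
        # Seeded RNG: fixes the file order and is reused for random cropping.
        self.random = Random(hps.train.seed)
        self.random.shuffle(self.datapaths)
        # Clips longer than this many spectrogram frames are randomly cropped.
        self.max_spec_len = 800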

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        with Path(self.datapaths[index]).open("rb") as f:
            data = torch.load(f, weights_only=True, map_location="cpu")

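        # Randomly crop clips longer than max_spec_len frames (the crop keeps
        # max_spec_len - 10 frames). The waveform is cut at the same frame
        # boundaries scaled by hop_length so it stays aligned with the
        # spectrogram; the per-utterance speaker ID is left as-is.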
        spec_len = data["mel_spec"].shape[1]
        hop_len = self.hps.data.hop_length
        if spec_len > self.max_spec_len:
            start = self.random.randint(0, spec_len - self.max_spec_len)
            end = start + self.max_spec_len - 10
            for key in data.keys():
                if key == "audio":
                    data[key] = data[key][:, start * hop_len : end * hop_len]
                elif key == "spk":
                    continue
                else:
                    data[key] = data[key][..., start:end]
        return data

    def __len__(self) -> int:
        return len(self.datapaths)
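

# Illustrative sketch, not used by TextAudioDataset. The crop in __getitem__
# picks a random window of spectrogram frames and cuts the waveform at the
# same boundaries multiplied by hop_length, so samples stay aligned with
# frames. The helper name `_crop_bounds_sketch` is hypothetical. Like the code
# above, it assumes the waveform holds exactly hop_length samples per frame.
from random import Random


def _crop_bounds_sketch(
    spec_len: int, max_spec_len: int, hop_length: int, rng: Random
) -> tuple[int, int, int, int]:
    """Return (frame_start, frame_end, sample_start, sample_end) for one crop."""
    if spec_len <= max_spec_len:
        # Short clips are kept whole, mirroring the `if` in __getitem__.
        return 0, spec_len, 0, spec_len * hop_length
    # randint is inclusive on both ends, as in __getitem__.
    start = rng.randint(0, spec_len - max_spec_len)
    # Same 10-frame margin as __getitem__: the crop is max_spec_len - 10 frames.
    end = start + max_spec_len - 10
    return start, end, start * hop_length, end * hop_length


# Example: _crop_bounds_sketch(500, 800, 256, Random(0)) returns
# (0, 500, 0, 128000) because a 500-frame clip is not cropped.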


def _pad_stack(array: Sequence[torch.Tensor]) -> torch.Tensor:
    """Right-pad each tensor's last dimension with zeros to the longest one, then stack."""
    max_len = max(x_.shape[-1] for x_ in array)
    x_padded = [
        F.pad(x_, (0, max_len - x_.shape[-1]), mode="constant", value=0)
        for x_ in array
    ]
    return torch.stack(x_padded)
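

# Illustrative sketch on plain Python lists (stdlib only), not used by this
# module. It applies the idea behind _pad_stack: right-pad every sequence
# with zeros to the longest length. It also returns the original lengths so
# padded positions can be told apart later. `_pad_lists_sketch` is a
# hypothetical name.
def _pad_lists_sketch(
    seqs: list[list[float]], pad_value: float = 0.0
) -> tuple[list[list[float]], list[int]]:
    """Right-pad each sequence to the longest length; return (padded, lengths)."""
    lengths = [len(s) for s in seqs]
    max_len = max(lengths)
    # Append pad_value only on the right, like F.pad(x, (0, n)) on the last axis.
    padded = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
    return padded, lengths


# Example: _pad_lists_sketch([[1.0, 2.0], [3.0]]) returns
# ([[1.0, 2.0], [3.0, 0.0]], [2, 1]).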


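class TextAudioCollate(nn.Module):
    """Batches the feature dicts produced by ``TextAudioDataset``.

    Items are sorted by mel-spectrogram length (longest first). Every
    time-indexed tensor is zero-padded on its last axis, and speaker IDs are
    collected into a tensor with one row per item.
    """

    def forward(
        self, batch: Sequence[dict[str, torch.Tensor]]
    ) -> tuple[torch.Tensor, ...]:
        batch = [b for b in batch if b is not None]
        batch = sorted(batch, key=lambda x: x["mel_spec"].shape[1], reverse=True)
        # Unpadded frame counts, in the same (sorted) order as the batch.
        lengths = torch.tensor([b["mel_spec"].shape[1] for b in batch]).long()
        results = {}
        for key in batch[0].keys():
            if key != "spk":
                results[key] = _pad_stack([b[key] for b in batch]).cpu()
            else:
                results[key] = torch.tensor([[b[key]] for b in batch]).cpu()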

        return (
            results["content"],
            results["f0"],
            results["spec"],
            results["mel_spec"],
            results["audio"],
            results["spk"],
            lengths,
            results["uv"],
        )
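

# Illustrative sketch (stdlib only), not used by this module. TextAudioCollate
# sorts the batch by spectrogram length, longest first, and returns those
# frame counts as `lengths`, so lengths[i] describes row i of every padded
# tensor. A common downstream use is a padding mask that marks real frames
# True and padding False. Whether the model here builds such a mask is an
# assumption. `_sorted_lengths_and_mask_sketch` is a hypothetical helper
# that works on plain frame counts rather than tensors.
def _sorted_lengths_and_mask_sketch(
    frame_counts: list[int],
) -> tuple[list[int], list[list[bool]]]:
    """Sort lengths descending and build a True-for-valid-frame mask."""
    lengths = sorted(frame_counts, reverse=True)
    max_len = lengths[0] if lengths else 0
    # Frame t of item i is real when t < lengths[i]; later frames are padding.
    mask = [[t < n for t in range(max_len)] for n in lengths]
    return lengths, mask


# Example: _sorted_lengths_and_mask_sketch([2, 3]) returns
# ([3, 2], [[True, True, True], [True, True, False]]).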