File size: 4,914 Bytes
46b0a70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from __future__ import annotations

import math
from logging import getLogger
from pathlib import Path
from typing import Any

import numpy as np
import torch
from cm_time import timer
from joblib import Parallel, delayed
from sklearn.cluster import KMeans, MiniBatchKMeans
from tqdm_joblib import tqdm_joblib

LOG = getLogger(__name__)


def train_cluster(
    input_dir: Path | str,
    n_clusters: int,
    use_minibatch: bool = True,
    batch_size: int = 4096,
    partial_fit: bool = False,
    verbose: bool = False,
) -> dict:
    input_dir = Path(input_dir)
    if not partial_fit:
        LOG.info(f"Loading features from {input_dir}")
        features = []
        for path in input_dir.rglob("*.data.pt"):
            with path.open("rb") as f:
                features.append(
                    torch.load(f, weights_only=True)["content"].squeeze(0).numpy().T
                )
        if not features:
            raise ValueError(f"No features found in {input_dir}")
        features = np.concatenate(features, axis=0).astype(np.float32)
        if features.shape[0] < n_clusters:
            raise ValueError(
                "Too few HuBERT features to cluster. Consider using a smaller number of clusters."
            )
        LOG.info(
            f"shape: {features.shape}, size: {features.nbytes/1024**2:.2f} MB, dtype: {features.dtype}"
        )
        with timer() as t:
            if use_minibatch:
                kmeans = MiniBatchKMeans(
                    n_clusters=n_clusters,
                    verbose=verbose,
                    batch_size=batch_size,
                    max_iter=80,
                    n_init="auto",
                ).fit(features)
            else:
                kmeans = KMeans(
                    n_clusters=n_clusters, verbose=verbose, n_init="auto"
                ).fit(features)
        LOG.info(f"Clustering took {t.elapsed:.2f} seconds")

        x = {
            "n_features_in_": kmeans.n_features_in_,
            "_n_threads": kmeans._n_threads,
            "cluster_centers_": kmeans.cluster_centers_,
        }
        return x
    else:
        # minibatch partial fit
        paths = list(input_dir.rglob("*.data.pt"))
        if len(paths) == 0:
            raise ValueError(f"No features found in {input_dir}")
        LOG.info(f"Found {len(paths)} features in {input_dir}")
        n_batches = math.ceil(len(paths) / batch_size)
        LOG.info(f"Splitting into {n_batches} batches")
        with timer() as t:
            kmeans = MiniBatchKMeans(
                n_clusters=n_clusters,
                verbose=verbose,
                batch_size=batch_size,
                max_iter=80,
                n_init="auto",
            )
            for i in range(0, len(paths), batch_size):
                LOG.info(
                    f"Processing batch {i//batch_size+1}/{n_batches} for speaker {input_dir.stem}"
                )
                features = []
                for path in paths[i : i + batch_size]:
                    with path.open("rb") as f:
                        features.append(
                            torch.load(f, weights_only=True)["content"]
                            .squeeze(0)
                            .numpy()
                            .T
                        )
                features = np.concatenate(features, axis=0).astype(np.float32)
                kmeans.partial_fit(features)
        LOG.info(f"Clustering took {t.elapsed:.2f} seconds")

        x = {
            "n_features_in_": kmeans.n_features_in_,
            "_n_threads": kmeans._n_threads,
            "cluster_centers_": kmeans.cluster_centers_,
        }
        return x


def main(
    input_dir: Path | str,
    output_path: Path | str,
    n_clusters: int = 10000,
    use_minibatch: bool = True,
    batch_size: int = 4096,
    partial_fit: bool = False,
    verbose: bool = False,
) -> None:
    input_dir = Path(input_dir)
    output_path = Path(output_path)

    if not (use_minibatch or not partial_fit):
        raise ValueError("partial_fit requires use_minibatch")

    def train_cluster_(input_path: Path, **kwargs: Any) -> tuple[str, dict]:
        return input_path.stem, train_cluster(input_path, **kwargs)

    with tqdm_joblib(desc="Training clusters", total=len(list(input_dir.iterdir()))):
        parallel_result = Parallel(n_jobs=-1)(
            delayed(train_cluster_)(
                speaker_name,
                n_clusters=n_clusters,
                use_minibatch=use_minibatch,
                batch_size=batch_size,
                partial_fit=partial_fit,
                verbose=verbose,
            )
            for speaker_name in input_dir.iterdir()
        )
    assert parallel_result is not None
    checkpoint = dict(parallel_result)
    output_path.parent.mkdir(exist_ok=True, parents=True)
    with output_path.open("wb") as f:
        torch.save(checkpoint, f)