jiamengjiameng committed on
Commit
0389824
·
verified ·
1 Parent(s): cc9e380

Upload folder using huggingface_hub

Browse files
.gitignore ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Emacs
2
+ *~
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+ /runs
32
+ /checkpoints
33
+ /base
34
+
35
+ # PyInstaller
36
+ # Usually these files are written by a python script from a template
37
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
38
+ *.manifest
39
+ *.spec
40
+
41
+ # Installer logs
42
+ pip-log.txt
43
+ pip-delete-this-directory.txt
44
+
45
+ # Unit test / coverage reports
46
+ htmlcov/
47
+ .tox/
48
+ .nox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *.cover
55
+ *.py,cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+ cover/
59
+
60
+ # Translations
61
+ *.mo
62
+ *.pot
63
+
64
+ # Django stuff:
65
+ *.log
66
+ local_settings.py
67
+ db.sqlite3
68
+ db.sqlite3-journal
69
+
70
+ # Flask stuff:
71
+ instance/
72
+ .webassets-cache
73
+
74
+ # Scrapy stuff:
75
+ .scrapy
76
+
77
+ # Sphinx documentation
78
+ docs/_build/
79
+
80
+ # PyBuilder
81
+ .pybuilder/
82
+ target/
83
+
84
+ # Jupyter Notebook
85
+ .ipynb_checkpoints
86
+
87
+ # IPython
88
+ profile_default/
89
+ ipython_config.py
90
+
91
+ # pyenv
92
+ # For a library or package, you might want to ignore these files since the code is
93
+ # intended to run in multiple environments; otherwise, check them in:
94
+ # .python-version
95
+
96
+ # pipenv
97
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
99
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
100
+ # install all needed dependencies.
101
+ #Pipfile.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ /runs
171
+ /.cache
172
+ /__pycache__
173
+
174
+ *.wav
175
+ *.pth
176
+ *.pt
177
+ *.pt.gz
178
+ wandb/
179
+ sven_latest_checkpoint/
180
+ sven_qwen/
181
+ pretrained_models/
182
+ xcodec/
183
+ small_speaker_shards_all/
184
+ sven_all_shards/
185
+ qwen_380k/
186
+ evals/
187
+ *.safetensors
188
+ *.pt
189
+ .ruff_cache
neucodec/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .codec_encoder import CodecEncoder
2
+ from .codec_decoder_vocos import CodecDecoderVocos
3
+ from .model import NeuCodec
neucodec/activations.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ from torch import nn, sin, pow
6
+ from torch.nn import Parameter
7
+
8
+
9
class Snake(nn.Module):
    '''
    Sine-based periodic activation with one learnable frequency per channel.

    Computes ``x + (1 / alpha) * sin^2(alpha * x)`` elementwise.

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - per-channel parameter (trainable unless disabled)
    References:
        - Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> act = Snake(256)
        >>> x = torch.randn(4, 256, 100)
        >>> x = act(x)
    '''

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        INPUT:
            - in_features: number of channels (length of the alpha vector)
            - alpha: initial scale applied to the base vector; higher
              values mean higher frequency
            - alpha_trainable: whether alpha receives gradients
            - alpha_logscale: if True, alpha is stored in log space and
              exponentiated in ``forward``
        '''
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # Log-scale alphas start from zeros, linear-scale alphas from ones;
        # both are multiplied by `alpha`.
        if alpha_logscale:
            base = torch.zeros(in_features)
        else:
            base = torch.ones(in_features)
        self.alpha = Parameter(base * alpha)
        self.alpha.requires_grad = alpha_trainable

        # Small constant added to the denominator in forward().
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Apply the activation elementwise.
        Snake := x + 1/a * sin^2 (xa)
        '''
        # Reshape (C,) -> (1, C, 1) so it lines up with x as [B, C, T].
        freq = self.alpha[None, :, None]
        if self.alpha_logscale:
            freq = torch.exp(freq)
        periodic = pow(sin(x * freq), 2)
        return x + (1.0 / (freq + self.no_div_by_zero)) * periodic
60
+
61
+
62
class SnakeBeta(nn.Module):
    '''
    Snake variant with separate per-channel parameters for the frequency
    (alpha) and the magnitude (beta) of the periodic term.

    Computes ``x + (1 / beta) * sin^2(alpha * x)`` elementwise.

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - parameter that controls frequency
        - beta - parameter that controls magnitude
    References:
        - Modified version based on Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> act = SnakeBeta(256)
        >>> x = torch.randn(4, 256, 100)
        >>> x = act(x)
    '''

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        INPUT:
            - in_features: number of channels (length of alpha and beta)
            - alpha: initial scale applied to BOTH alpha and beta
            - alpha_trainable: whether alpha and beta receive gradients
            - alpha_logscale: if True, both parameters are stored in log
              space and exponentiated in ``forward``
        '''
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # Log-scale parameters start from zeros, linear-scale from ones.
        # Note that beta is initialised from the same `alpha` argument.
        init_fn = torch.zeros if alpha_logscale else torch.ones
        self.alpha = Parameter(init_fn(in_features) * alpha)
        self.beta = Parameter(init_fn(in_features) * alpha)

        for param in (self.alpha, self.beta):
            param.requires_grad = alpha_trainable

        # Small constant added to the denominator in forward().
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Apply the activation elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        '''
        # Reshape (C,) -> (1, C, 1) so both line up with x as [B, C, T].
        freq = self.alpha[None, :, None]
        mag = self.beta[None, :, None]
        if self.alpha_logscale:
            freq = torch.exp(freq)
            mag = torch.exp(mag)
        periodic = pow(sin(x * freq), 2)
        return x + (1.0 / (mag + self.no_div_by_zero)) * periodic
neucodec/alias_free_torch/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ from .filter import *
5
+ from .resample import *
6
+ from .act import *
neucodec/alias_free_torch/act.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from .resample import UpSample1d, DownSample1d
6
+
7
+
8
class Activation1d(nn.Module):
    """Wrap an activation between a 1-D upsampler and a 1-D downsampler.

    ``forward`` runs ``upsample -> activation -> downsample`` on the input.

    Args:
        activation: Callable/module applied to the upsampled signal.
        up_ratio: Ratio passed to ``UpSample1d``.
        down_ratio: Ratio passed to ``DownSample1d``.
        up_kernel_size: Kernel size passed to ``UpSample1d``.
        down_kernel_size: Kernel size passed to ``DownSample1d``.
    """

    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, x):
        """Apply the wrapped activation to ``x`` of shape [B, C, T]."""
        upsampled = self.upsample(x)
        activated = self.act(upsampled)
        return self.downsample(activated)
neucodec/alias_free_torch/filter.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
# Prefer the native implementation when this torch build provides one.
if hasattr(torch, 'sinc'):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    #   LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Normalised sinc, i.e. sin(pi * x) / (pi * x), with sinc(0) = 1.
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        one = torch.tensor(1., device=x.device, dtype=x.dtype)
        ratio = torch.sin(math.pi * x) / math.pi / x
        # Select 1 where x is exactly zero, the quotient elsewhere.
        return torch.where(x == 0, one, ratio)
23
+
24
+
25
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
26
+ # https://adefossez.github.io/julius/julius/lowpass.html
27
+ # LICENSE is in incl_licenses directory.
28
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):  # return filter [1,1,kernel_size]
    """Build a Kaiser-windowed sinc low-pass kernel.

    Args:
        cutoff: Cutoff frequency used as ``sinc(2 * cutoff * t)``; ``0``
            yields an all-zero kernel.
        half_width: Transition half-width; drives the Kaiser ``beta``.
        kernel_size: Number of taps (even or odd).

    Returns:
        A floating-point tensor of shape ``(1, 1, kernel_size)``. For a
        non-zero cutoff the taps are normalised to sum to 1.
    """
    even = (kernel_size % 2 == 0)
    half_size = kernel_size // 2

    # Kaiser window: derive beta from the attenuation estimate A.
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.:
        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
    else:
        beta = 0.
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        # Half-sample offsets keep an even-length kernel symmetric.
        time = (torch.arange(-half_size, half_size) + 0.5)
    else:
        time = torch.arange(kernel_size) - half_size

    if cutoff == 0:
        # `time` is an integer tensor for odd kernel sizes; force the
        # window's float dtype so the kernel can be convolved with float
        # input.
        filter_ = torch.zeros_like(time, dtype=window.dtype)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        filter_ /= filter_.sum()

    return filter_.view(1, 1, kernel_size)
58
+
59
+
60
class LowPassFilter1d(nn.Module):
    """Depthwise 1-D low-pass filter using a Kaiser-windowed sinc kernel.

    Args:
        cutoff: Cutoff frequency; must lie in ``[0, 0.5]``.
        half_width: Transition half-width forwarded to the kernel builder.
        stride: Convolution stride.
        padding: Whether to pad the input before filtering.
        padding_mode: Mode handed to ``F.pad``.
        kernel_size: Number of filter taps.

    Raises:
        ValueError: If ``cutoff`` is negative or greater than 0.5.
    """

    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = 'replicate',
        kernel_size: int = 12,
    ):
        # kernel_size should be even number for stylegan3 setup,
        # in this implementation, odd number is also possible.
        super().__init__()
        if cutoff < 0.:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")

        is_even = kernel_size % 2 == 0
        self.kernel_size = kernel_size
        self.even = is_even
        # An even kernel gets one sample less padding on the left.
        self.pad_left = kernel_size // 2 - int(is_even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode

        kernel = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", kernel)

    def forward(self, x):
        """Filter ``x`` of shape [B, C, T], each channel independently."""
        _, channels, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)

        # Share the single (1, 1, K) kernel across channels via groups=C.
        per_channel_kernel = self.filter.expand(channels, -1, -1)
        return F.conv1d(x, per_channel_kernel, stride=self.stride, groups=channels)
neucodec/alias_free_torch/resample.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from .filter import LowPassFilter1d
7
+ from .filter import kaiser_sinc_filter1d
8
+
9
+
10
class UpSample1d(nn.Module):
    """Upsample a [B, C, T] signal with a Kaiser-windowed sinc kernel.

    Args:
        ratio: Upsampling factor, used as the transposed-convolution stride.
        kernel_size: Number of filter taps; defaults to
            ``int(6 * ratio // 2) * 2`` when None.
    """

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        if kernel_size is None:
            self.kernel_size = int(6 * ratio // 2) * 2
        else:
            self.kernel_size = kernel_size
        self.stride = ratio

        # Input padding, plus how much of the transposed-conv output to trim
        # on each side afterwards.
        self.pad = self.kernel_size // ratio - 1
        overhang = self.kernel_size - self.stride
        self.pad_left = self.pad * self.stride + overhang // 2
        self.pad_right = self.pad * self.stride + (overhang + 1) // 2

        kernel = kaiser_sinc_filter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            kernel_size=self.kernel_size,
        )
        self.register_buffer("filter", kernel)

    def forward(self, x):
        """Upsample ``x`` of shape [B, C, T] along the last axis."""
        _, channels, _ = x.shape

        padded = F.pad(x, (self.pad, self.pad), mode='replicate')
        # Depthwise transposed convolution, scaled by the ratio.
        per_channel_kernel = self.filter.expand(channels, -1, -1)
        upsampled = self.ratio * F.conv_transpose1d(
            padded, per_channel_kernel, stride=self.stride, groups=channels
        )
        return upsampled[..., self.pad_left:-self.pad_right]
34
+
35
+
36
class DownSample1d(nn.Module):
    """Downsample a signal by low-pass filtering with a strided kernel.

    Args:
        ratio: Downsampling factor, used as the low-pass filter's stride.
        kernel_size: Number of filter taps; defaults to
            ``int(6 * ratio // 2) * 2`` when None.
    """

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        if kernel_size is None:
            self.kernel_size = int(6 * ratio // 2) * 2
        else:
            self.kernel_size = kernel_size
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        """Return the strided low-pass-filtered version of ``x``."""
        return self.lowpass(x)
neucodec/bs_roformer5.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchaudio
5
+ import numpy as np
6
+
7
+ from torch.nn import Module, ModuleList
8
+ from einops import rearrange
9
+ from torchtune.modules import RotaryPositionalEmbeddings
10
+
11
+
12
class RMSNorm(torch.nn.Module):
    r"""Root-mean-square normalisation over the last dimension.

    Reference: https://github.com/meta-llama/llama/blob/main/llama/model.py

    Args:
        dim: Size of the last dimension (length of the learnable scale).
        eps: Constant added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Scale ``x`` by the inverse RMS of its last axis, then by ``weight``."""
        mean_square = (x ** 2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
23
+
24
+
25
class MLP(nn.Module):
    """Two-layer bias-free feed-forward block: ``dim -> 4*dim -> dim``.

    A SiLU activation sits between the two linear layers.

    Args:
        dim: Input and output feature size.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()

        hidden = 4 * dim
        self.fc1 = nn.Linear(dim, hidden, bias=False)
        self.silu = nn.SiLU()
        self.fc2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        """Apply ``fc2(silu(fc1(x)))`` on the last axis of ``x``."""
        return self.fc2(self.silu(self.fc1(x)))
38
+
39
+
40
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings on queries and keys.

    Args:
        dim: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        rotary_embed: Module applied to the query and key tensors.
    """

    def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):
        super().__init__()

        assert dim % n_heads == 0

        self.n_heads = n_heads
        self.dim = dim
        self.rotary_embed = rotary_embed

        # Requires a torch build exposing scaled_dot_product_attention.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        assert self.flash, "Must have flash attention."

        # Fused projection to q, k and v, and the output projection.
        self.c_attn = nn.Linear(dim, 3 * dim, bias=False)
        self.c_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        r"""
        Args:
            x: (b, t, h*d)

        Constants:
            b: batch_size
            t: time steps
            r: 3
            h: heads_num
            d: heads_dim

        Returns:
            Tensor of shape (b, t, h*d).
        """
        # Unpacking also enforces that x is 3-dimensional.
        B, T, C = x.size()

        fused = self.c_attn(x)
        # Split the fused projection into q, k, v, each (b, h, t, d).
        q, k, v = rearrange(fused, 'b t (r h d) -> r b h t d', r=3, h=self.n_heads)

        # NOTE(review): q and k are laid out (b, h, t, d) here; confirm this
        # matches the axis order the rotary module expects.
        q = self.rotary_embed(q)
        k = self.rotary_embed(k)

        if self.flash:
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=0, is_causal=False
            )

        # Merge heads back: (b, h, t, d) -> (b, t, h*d).
        merged = rearrange(y, 'b h t d -> b t (h d)')
        return self.c_proj(merged)
86
+
87
+
88
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Args:
        dim: Model width.
        n_heads: Number of attention heads.
        rotary_embed: Rotary embedding module handed to the attention layer.
    """

    def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads

        self.att_norm = RMSNorm(dim)
        self.ffn_norm = RMSNorm(dim)
        self.att = Attention(dim=dim, n_heads=n_heads, rotary_embed=rotary_embed)
        self.mlp = MLP(dim=dim)

    def forward(
        self,
        x: torch.Tensor,
    ):
        """Return ``x`` after the attention and feed-forward residual steps."""
        attended = x + self.att(self.att_norm(x))
        return attended + self.mlp(self.ffn_norm(attended))
108
+
109
+
110
if __name__ == '__main__':
    # Smoke test: push a random batch through one block and print the
    # resulting shape.
    rotary_embed_128 = RotaryPositionalEmbeddings(dim=128)
    transformer_block = TransformerBlock(
        dim=1024,
        n_heads=8,
        rotary_embed=rotary_embed_128
    )
    x = torch.randn(2, 128, 1024)
    y = transformer_block(x)
    print(y.shape)
neucodec/codec_decoder_vocos.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from typing import List
5
+ from torchtune.modules import RotaryPositionalEmbeddings
6
+ from vector_quantize_pytorch import ResidualFSQ
7
+
8
+ from .bs_roformer5 import TransformerBlock
9
+
10
+
11
class ISTFT(nn.Module):
    """
    Custom ISTFT supporting "same" padding in addition to torch's ``center=True``.

    torch.istft doesn't allow custom padding (other than `center=True`) with
    windowing, because the NOLA (Nonzero Overlap Add) check fails at the edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    In neural vocoding we want "same" padding analogous to CNNs; the padded
    samples are trimmed anyway.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".

    Raises:
        ValueError: If ``padding`` is neither "center" nor "same".
    """

    def __init__(
        self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
    ):
        super().__init__()
        if padding not in ("center", "same"):
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.register_buffer("window", torch.hann_window(win_length))

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
        """
        if self.padding == "center":
            # Delegate to the native implementation.
            return torch.istft(
                spec,
                self.n_fft,
                self.hop_length,
                self.win_length,
                self.window,
                center=True,
            )
        if self.padding != "same":
            raise ValueError("Padding must be 'center' or 'same'.")

        # Number of samples trimmed from each end of the overlap-add result.
        pad = (self.win_length - self.hop_length) // 2

        assert spec.dim() == 3, "Expected a 3D tensor as input"
        B, N, T = spec.shape

        # Inverse FFT of every frame, then apply the synthesis window.
        frames = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        frames = frames * self.window[None, :, None]

        # Both overlap-add passes share the same fold geometry.
        output_size = (T - 1) * self.hop_length + self.win_length
        fold_args = dict(
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        )

        # Overlap-add the windowed frames and trim the padding.
        signal = torch.nn.functional.fold(frames, **fold_args)[:, 0, 0, pad:-pad]

        # Overlap-add of the squared window gives the normalisation envelope.
        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(window_sq, **fold_args).squeeze()[pad:-pad]

        # Guard against dividing by a (near-)zero envelope.
        assert (window_envelope > 1e-11).all()
        return signal / window_envelope
95
+
96
+
97
class FourierHead(nn.Module):
    """Abstract base for heads that turn model features into audio via an inverse Fourier transform."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Interface that concrete heads must override.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        message = "Subclasses must implement the forward method."
        raise NotImplementedError(message)
110
+
111
+
112
class ISTFTHead(FourierHead):
    """
    ISTFT head: predicts STFT magnitude and phase, then inverts them to audio.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
            the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        # n_fft // 2 + 1 bins each for log-magnitude and phase.
        out_dim = n_fft + 2
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(
            n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            A 2-tuple ``(audio, x_pred)``: ``audio`` is the ISTFT output with
            a channel axis inserted at dim 1, and ``x_pred`` is the transposed
            output of the linear projection. NOTE: the signature annotation
            still says ``torch.Tensor``; the actual return value is this tuple.
        """
        x_pred = self.out(x).transpose(1, 2)

        # First half of the channels is log-magnitude, second half is phase.
        log_mag, phase = x_pred.chunk(2, dim=1)
        # Clip as a safeguard against excessively large magnitudes.
        mag = torch.clip(torch.exp(log_mag), max=1e2)

        # Build the complex spectrogram directly from cos/sin of the phase;
        # recomputing the angle with atan2 would add nothing but cost.
        real = torch.cos(phase)
        imag = torch.sin(phase)
        S = mag * (real + 1j * imag)

        audio = self.istft(S)
        return audio.unsqueeze(1), x_pred
162
+
163
+
164
def nonlinearity(x):
    """Swish activation: ``x * sigmoid(x)``, applied elementwise."""
    gate = torch.sigmoid(x)
    return x * gate
167
+
168
+
169
def Normalize(in_channels, num_groups=32):
    """Return an affine ``GroupNorm`` over ``in_channels`` channels (eps=1e-6).

    Args:
        in_channels: Number of channels to normalise.
        num_groups: Number of groups the channels are split into.
    """
    norm = torch.nn.GroupNorm(
        num_groups,
        in_channels,
        eps=1e-6,
        affine=True,
    )
    return norm
173
+
174
+
175
class ResnetBlock(nn.Module):
    """1-D residual block: two (GroupNorm -> swish -> Conv1d) stages plus a skip path.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels; defaults to ``in_channels``.
        conv_shortcut: When the channel count changes, use a kernel-3
            convolution on the skip path instead of a kernel-1 one.
        dropout: Dropout probability applied before the second convolution.
        temb_channels: Size of the optional embedding vector; when ``> 0`` a
            linear projection to ``out_channels`` is created for it.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv1d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv1d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        # A projection on the skip path is only needed when channels change.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv1d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv1d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb=None):
        """Run the block.

        Args:
            x: Feature map consumed by ``Conv1d``, i.e. (B, C, T).
            temb: Optional embedding of shape (B, temb_channels); requires
                the block to have been built with ``temb_channels > 0``.

        Returns:
            Tensor with ``out_channels`` channels and the same length as ``x``.
        """
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            # (B, C) -> (B, C, 1) so the embedding broadcasts over the time
            # axis of the 3-D feature map. The previous [:, :, None, None]
            # produced a 4-D tensor that cannot broadcast against h.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
233
+
234
+
235
class Backbone(nn.Module):
    """Abstract base for the generator's backbone, which keeps the temporal resolution unchanged across layers."""

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Interface that concrete backbones must override.

        Args:
            x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
                C denotes output features, and L is the sequence length.

        Returns:
            Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
                and H denotes the model dimension.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        message = "Subclasses must implement the forward method."
        raise NotImplementedError(message)
249
+
250
+
251
class VocosBackbone(Backbone):
    """
    Backbone made of a conv embedding, two ResNet blocks, a transformer stack,
    two more ResNet blocks and a final LayerNorm.

    Args:
        hidden_dim (int): Channel/feature width used by every layer.
        depth (int): Number of transformer blocks.
        heads (int): Number of attention heads per transformer block.
        pos_meb_dim (int): ``dim`` passed to the rotary positional embedding,
            which is shared by all transformer blocks.
    """

    def __init__(self, hidden_dim=1024, depth=12, heads=16, pos_meb_dim=64):
        super().__init__()

        self.embed = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=7, padding=3)

        self.temb_ch = 0
        block_in = hidden_dim
        dropout = 0.1

        def make_resnet_block() -> nn.Module:
            # All four ResNet blocks share the same configuration.
            return ResnetBlock(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
            )

        prior_net: List[nn.Module] = [make_resnet_block(), make_resnet_block()]
        self.prior_net = nn.Sequential(*prior_net)

        # One rotary embedding instance is reused by every transformer block.
        time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
        transformer_blocks = []
        for _ in range(depth):
            transformer_blocks.append(
                TransformerBlock(
                    dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed
                )
            )
        self.transformers = nn.Sequential(*transformer_blocks)

        self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)

        post_net: List[nn.Module] = [make_resnet_block(), make_resnet_block()]
        self.post_net = nn.Sequential(*post_net)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone; convolutional stages run on the transposed layout.

        The input's axes 1 and 2 are swapped before each convolutional stage
        and swapped back before the transformer stack and the final norm.
        """
        # Convolutional front end (channel axis at dim 1).
        h = self.embed(x.transpose(1, 2))
        h = self.prior_net(h)

        # Transformer stack (feature axis last).
        h = self.transformers(h.transpose(1, 2))

        # Convolutional back end, then normalise with the feature axis last.
        h = self.post_net(h.transpose(1, 2))
        return self.final_layer_norm(h.transpose(1, 2))
329
+
330
+
331
def init_weights(m):
    """Initialise ``nn.Conv1d`` modules in place; leave other modules alone.

    Weights are drawn from a truncated normal (std=0.02) and the bias, when
    the layer has one, is set to zero.

    Args:
        m: Any module, typically supplied by ``nn.Module.apply``.
    """
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        # Conv1d built with bias=False has m.bias set to None.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
335
+
336
+
337
class CodecDecoderVocos(nn.Module):
    """Decoder combining a residual FSQ quantizer, a backbone and an ISTFT head.

    ``forward`` either quantizes its input (``vq=True``) or decodes features
    to audio through the backbone and head (otherwise).

    NOTE(review): ``vq_num_quantizers``, ``vq_commit_weight``,
    ``vq_weight_init``, ``vq_full_commit_loss``, ``codebook_size`` and
    ``codebook_dim`` are accepted but never read in the constructor body.
    """

    def __init__(
        self,
        hidden_dim=1024,
        depth=12,
        heads=16,
        pos_meb_dim=64,
        hop_length=320,
        vq_num_quantizers=1,
        vq_dim=2048,  # 1024 2048
        vq_commit_weight=0.25,
        vq_weight_init=False,
        vq_full_commit_loss=False,
        codebook_size=16384,
        codebook_dim=16,
    ):
        super().__init__()
        self.hop_length = hop_length

        # The quantizer levels and quantizer count are hard-coded here.
        self.quantizer = ResidualFSQ(
            dim=vq_dim, levels=[4, 4, 4, 4, 4, 4, 4, 4], num_quantizers=1
        )

        self.backbone = VocosBackbone(
            hidden_dim=hidden_dim, depth=depth, heads=heads, pos_meb_dim=pos_meb_dim
        )

        # The FFT size is tied to the hop length (4x).
        self.head = ISTFTHead(
            dim=hidden_dim,
            n_fft=self.hop_length * 4,
            hop_length=self.hop_length,
            padding="same",
        )

        self.reset_parameters()

    def forward(self, x, vq=True):
        """Quantize ``x`` or decode it to audio.

        With ``vq is True`` the input has axes 1 and 2 swapped, is passed
        through the quantizer, and both quantizer outputs are swapped back;
        the result is ``(x, q, None)``. Otherwise ``x`` goes through the
        backbone and head, and both head outputs are returned as a 2-tuple.

        The check is an identity test, so truthy non-bool values such as
        ``1`` take the decode path.
        """
        if vq is True:
            # x, q, commit_loss = self.quantizer(x)
            x = x.permute(0, 2, 1)
            x, q = self.quantizer(x)
            x = x.permute(0, 2, 1)
            q = q.permute(0, 2, 1)
            return x, q, None
        x = self.backbone(x)
        x, _ = self.head(x)

        return x, _

    def vq2emb(self, vq):
        """Put the quantizer in eval mode and return ``quantizer.vq2emb(vq)``.

        NOTE(review): assumes the quantizer exposes ``vq2emb`` — TODO confirm
        against the installed ``ResidualFSQ``.
        """
        self.quantizer = self.quantizer.eval()
        x = self.quantizer.vq2emb(vq)
        return x

    def get_emb(self):
        """Put the quantizer in eval mode and return ``quantizer.get_emb()``.

        NOTE(review): assumes the quantizer exposes ``get_emb`` — TODO confirm
        against the installed ``ResidualFSQ``.
        """
        self.quantizer = self.quantizer.eval()
        embs = self.quantizer.get_emb()
        return embs

    def inference_vq(self, vq):
        """Add a leading batch axis to ``vq`` and pass it to ``self.model``.

        NOTE(review): ``self.model`` is never assigned in this class as
        visible here; unless it is set elsewhere this raises AttributeError.
        """
        x = vq[None, :, :]
        x = self.model(x)
        return x

    def inference_0(self, x):
        """Quantize ``x`` and pass the result to ``self.model``.

        NOTE(review): unpacks four values from the quantizer, while
        ``forward`` unpacks two; it also relies on ``self.model``, which is
        not assigned in this class as visible here.
        """
        x, q, loss, perp = self.quantizer(x)
        x = self.model(x)
        return x, None

    def inference(self, x):
        """Pass ``x`` to ``self.model`` and return ``(result, None)``.

        NOTE(review): ``self.model`` is never assigned in this class as
        visible here; unless it is set elsewhere this raises AttributeError.
        """
        x = self.model(x)
        return x, None

    def remove_weight_norm(self):
        """Remove weight normalization from every submodule that has it."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to every Conv1d and ConvTranspose1d submodule."""

        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        """Run ``init_weights`` over this module and all of its submodules."""
        self.apply(init_weights)
neucodec/codec_encoder.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+ from torch import nn
5
+
6
+ from .module import WNConv1d, EncoderBlock
7
+ from .alias_free_torch import Activation1d
8
+ from . import activations
9
+
10
+
11
def init_weights(m):
    """Initialise an ``nn.Conv1d`` in place; leave every other module alone.

    The weight is drawn from a truncated normal with ``std=0.02`` and the
    bias, when the layer has one, is set to zero.

    Args:
        m: Any module; intended to be passed to ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Conv1d):
        return

    nn.init.trunc_normal_(m.weight, std=0.02)
    # A convolution built with ``bias=False`` has ``m.bias is None``;
    # initialising it unconditionally would raise.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
15
+
16
+
17
class CodecEncoder(nn.Module):
    """Strided 1-D convolutional waveform encoder.

    The network is a weight-normalised stem convolution with one input
    channel, followed by one ``EncoderBlock`` per entry of ``up_ratios``
    (the channel count doubles at every block), followed by a SnakeBeta
    activation and a weight-normalised projection to ``hidden_dim``.

    Args:
        ngf: Channel count produced by the stem convolution.
        up_ratios: Stride of each ``EncoderBlock``; their product is stored
            as ``hop_length``.
        dilations: Dilations forwarded to every ``EncoderBlock``.
        hidden_dim: Output channel count of the final projection.
        depth, heads, pos_meb_dim: Accepted for signature compatibility;
            not used by this class.
    """

    def __init__(
        self,
        ngf=48,
        up_ratios=(2, 2, 4, 4, 5),
        dilations=(1, 3, 9),
        hidden_dim=1024,
        depth=12,
        heads=12,
        pos_meb_dim=64,
    ):
        super().__init__()
        # Store a private list so the attribute keeps its list type and is
        # never shared with the caller's (or a default) sequence.
        self.up_ratios = list(up_ratios)
        # Plain int rather than a NumPy scalar.
        self.hop_length = int(np.prod(self.up_ratios))
        self.ngf = ngf

        d_model = ngf
        blocks = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
        for stride in self.up_ratios:
            d_model *= 2
            blocks.append(
                EncoderBlock(d_model, stride=stride, dilations=dilations)
            )
        self.conv_blocks = nn.Sequential(*blocks)

        self.conv_final_block = nn.Sequential(
            Activation1d(
                activation=activations.SnakeBeta(d_model, alpha_logscale=True)
            ),
            WNConv1d(d_model, hidden_dim, kernel_size=3, padding=1),
        )

        self.reset_parameters()

    def forward(self, x):
        """Encode ``x`` and return the result with axes permuted ``(0, 2, 1)``."""
        x = self.conv_blocks(x)
        x = self.conv_final_block(x)
        return x.permute(0, 2, 1)

    def inference(self, x):
        """Encode ``x`` exactly as ``forward`` does.

        The previous body called ``self.block``, an attribute this class
        never defines, so every call raised ``AttributeError``.
        """
        return self.forward(x)

    def remove_weight_norm(self):
        """Remove weight normalization from every submodule that has it."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to every ``nn.Conv1d`` submodule."""

        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        """Run ``init_weights`` over this module and all submodules."""
        self.apply(init_weights)
neucodec/model.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import soundfile as sf
2
+ import os
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torchaudio
8
+
9
+ from typing import Optional
10
+ from torchaudio import transforms as T
11
+ from transformers import AutoFeatureExtractor, Wav2Vec2BertModel
12
+
13
+ from .codec_encoder import CodecEncoder
14
+ from .codec_decoder_vocos import CodecDecoderVocos
15
+ from .module import SemanticEncoder
16
+
17
+
18
+ class NeuCodec(nn.Module):
19
+ def __init__(self, ckpt_path: str, sample_rate: int, hop_length: int):
20
+ super().__init__()
21
+
22
+ # load ckpt
23
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
24
+ self.sample_rate = sample_rate
25
+ self.hop_length = hop_length
26
+
27
+ # load modules
28
+ self.semantic_model = Wav2Vec2BertModel.from_pretrained(
29
+ "facebook/w2v-bert-2.0", output_hidden_states=True
30
+ )
31
+ self.semantic_model.eval()
32
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(
33
+ "facebook/w2v-bert-2.0"
34
+ )
35
+ self.SemanticEncoder_module = SemanticEncoder(1024, 1024, 1024)
36
+ self.CodecEnc = CodecEncoder()
37
+ self.generator = CodecDecoderVocos(hop_length=hop_length)
38
+ self.fc_prior = nn.Linear(2048, 2048)
39
+ self.fc_post_a = nn.Linear(2048, 1024)
40
+
41
+ # load checkpoint
42
+ self._load_ckpt(ckpt)
43
+
44
+ def _load_ckpt(self, ckpt):
45
+ # differentiate between `.ckpt` and `.bin`
46
+ if ckpt.get("state_dict"):
47
+ state_dicts = ckpt.get("state_dict")
48
+ else:
49
+ state_dicts = ckpt
50
+
51
+ # assign keys to correct model components
52
+ filtered_enc = {}
53
+ filtered_gen = {}
54
+ filtered_post = {}
55
+ filtered_prior = {}
56
+ filtered_semantic = {}
57
+ for key, value in state_dicts.items():
58
+ if key.startswith("CodecEnc."):
59
+ new_key = key[len("CodecEnc."):]
60
+ filtered_enc[new_key] = value
61
+ elif key.startswith("generator."):
62
+ new_key = key[len("generator."):]
63
+ filtered_gen[new_key] = value
64
+ elif key.startswith("fc_post_a."):
65
+ new_key = key[len("fc_post_a."):]
66
+ filtered_post[new_key] = value
67
+ elif key.startswith("SemanticEncoder_module."):
68
+ new_key = key[len("SemanticEncoder_module."):]
69
+ filtered_semantic[new_key] = value
70
+ elif key.startswith("fc_prior."):
71
+ new_key = key[len("fc_prior."):]
72
+ filtered_prior[new_key] = value
73
+
74
+ # load
75
+ self.CodecEnc.load_state_dict(filtered_enc)
76
+ self.CodecEnc.eval()
77
+ self.generator.load_state_dict(filtered_gen, strict=False)
78
+ self.generator.eval()
79
+ self.fc_post_a.load_state_dict(filtered_post)
80
+ self.fc_post_a.eval()
81
+ self.fc_prior.load_state_dict(filtered_prior)
82
+ self.SemanticEncoder_module.load_state_dict(filtered_semantic)
83
+ self.SemanticEncoder_module.eval()
84
+
85
+ @torch.inference_mode()
86
+ def encode_code(
87
+ self,
88
+ input_waveform: torch.Tensor,
89
+ semantic_features: torch.Tensor = None,
90
+ sample_rate: int = 16_000,
91
+ ) -> torch.Tensor:
92
+ pad_for_wav = 320 - (input_waveform.shape[1] % 320)
93
+ input_waveform = torch.nn.functional.pad(input_waveform, (0, pad_for_wav))
94
+
95
+ if semantic_features is None:
96
+ semantic_features = self.feature_extractor(
97
+ input_waveform, sampling_rate=sample_rate, return_tensors="pt"
98
+ ).input_features.to(self.device) # [batch, frames, feat_dim]
99
+ else:
100
+ semantic_features = semantic_features[:, 0, :, :]
101
+
102
+ semantic_output = self.semantic_model(semantic_features)
103
+ semantic_hidden_16 = semantic_output.hidden_states[16]
104
+ semantic_hidden_16 = semantic_hidden_16.transpose(
105
+ 1, 2
106
+ ) # [batch, hidden_dim, frames]
107
+ semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16)
108
+ if len(input_waveform.shape) == 2:
109
+ wav = input_waveform.unsqueeze(1).to(self.device) # shape: [batch, 1, time]
110
+ else:
111
+ wav = input_waveform.to(self.device)
112
+
113
+ vq_emb = self.CodecEnc(wav) # [batch, time//down, 1024]
114
+ vq_emb = vq_emb.transpose(1, 2) # -> [batch, 1024, frames]
115
+
116
+ if vq_emb.shape[-1] != semantic_encoded.shape[-1]:
117
+ min_len = min(vq_emb.shape[-1], semantic_encoded.shape[-1])
118
+ vq_emb = vq_emb[:, :, :min_len]
119
+ semantic_encoded = semantic_encoded[:, :, :min_len]
120
+ concat_emb = torch.cat(
121
+ [semantic_encoded, vq_emb], dim=1
122
+ ) # [batch, 2048, frames]
123
+ concat_emb = self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)
124
+ _, vq_code, _ = self.generator(concat_emb, vq=True)
125
+ return vq_code
126
+
127
+ @torch.inference_mode()
128
+ def decode_code(self, vq_code: torch.Tensor) -> torch.Tensor:
129
+ vq_post_emb = self.generator.quantizer.get_output_from_indices(
130
+ vq_code.transpose(1, 2)
131
+ )
132
+ vq_post_emb = vq_post_emb.transpose(1, 2) # [batch, 1024, frames]
133
+ vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(
134
+ 1, 2
135
+ ) # [batch, 1024, frames]
136
+ recon_audio = self.generator(vq_post_emb.transpose(1, 2), vq=False)[
137
+ 0
138
+ ] # [batch, time]
139
+ return recon_audio
140
+
141
+ @torch.inference_mode()
142
+ def autoencode(self, fpath: str, output_fpath: Optional[str] = None):
143
+ y, sr = torchaudio.load(fpath)
144
+ if sr != 16_000:
145
+ y = T.Resample(sr, 16_000)(y)
146
+ vq_codes = self.encode_code(y)
147
+ recon = self.decode_code(vq_codes)
148
+
149
+ if output_fpath is None:
150
+ name, fext = os.path.splitext(fpath)
151
+ output_fpath = f"{name}_recon{fext}"
152
+
153
+ sf.write(output_fpath, recon[0, 0, :].cpu(), self.sample_rate)
154
+
155
+ @torch.inference_mode()
156
+ def batch_encode(
157
+ self, fpaths: list[str], return_tensor: bool = False
158
+ ) -> tuple[list[torch.Tensor], list[int]] | tuple[torch.Tensor, list[int]]:
159
+ # prepare batch
160
+ wavs_batch, semantic_batch, token_durations = self._pad_batch(
161
+ [self._preprocess_file(fpath) for fpath in fpaths]
162
+ )
163
+ vq_codes = self.encode_code(wavs_batch, semantic_batch)
164
+
165
+ # return, unpad if we want to
166
+ if return_tensor:
167
+ return vq_codes, list(token_durations)
168
+
169
+ unpadded_vq_codes = []
170
+ for idx, token_dur in enumerate(token_durations):
171
+ curr_codes = vq_codes[idx, :, :token_dur]
172
+ unpadded_vq_codes.append(curr_codes)
173
+
174
+ return unpadded_vq_codes, None
175
+
176
+ @torch.inference_mode()
177
+ def batch_decode(
178
+ self,
179
+ vq_codes: list[torch.Tensor] | torch.Tensor,
180
+ token_durations: Optional[list[int]] = None,
181
+ ):
182
+ # pad tensor if need be
183
+ if isinstance(vq_codes, list):
184
+ vq_codes, token_durations = self._pad_codes(vq_codes)
185
+ else:
186
+ assert token_durations is not None
187
+
188
+ # decode
189
+ recons = self.decode_code(vq_codes)
190
+
191
+ # unpad
192
+ cut_recons = []
193
+ for idx, token_dur in enumerate(token_durations):
194
+ curr_recon = recons[idx, :, : int(token_dur * self.hop_length)]
195
+ cut_recons.append(curr_recon)
196
+
197
+ return cut_recons
198
+
199
+ @torch.inference_mode()
200
+ def batch_autoencode(
201
+ self, fpaths: list[str], output_fpaths: Optional[list[str]] = None
202
+ ) -> list[torch.Tensor]:
203
+ vq_codes, token_durations = self.batch_encode(fpaths, return_tensor=True)
204
+ cut_recons = self.batch_decode(vq_codes, token_durations)
205
+
206
+ if output_fpaths:
207
+ for recon, output_fpath in zip(cut_recons, output_fpaths):
208
+ sf.write(output_fpath, recon.cpu().numpy()[0, :], self.sample_rate)
209
+
210
+ return cut_recons
211
+
212
+ def _preprocess_file(self, fpath: str):
213
+ # load and resample
214
+ y, sr = torchaudio.load(fpath)
215
+ if sr != 16_000:
216
+ y = T.Resample(sr, 16_000)(y)
217
+
218
+ # compute duration for any cutting we might need to do, in terms of n_tokens
219
+ token_duration = int((y.shape[-1] / 16_000) * 50)
220
+
221
+ # get semantic model features: [harry] note i don't think this can be batched
222
+ semantic_model_input = self.feature_extractor(
223
+ y, sampling_rate=16_000, return_tensors="pt"
224
+ ).input_features
225
+
226
+ return y.to(self.device), semantic_model_input.to(self.device), token_duration
227
+
228
+ def _pad_batch(self, batch: list[tuple[torch.Tensor, torch.Tensor, int]]):
229
+ # unpack batch
230
+ wavs, semantic_features, token_durations = zip(*batch)
231
+ max_length_semantic = max([f.shape[1] for f in semantic_features])
232
+ max_length = max_length_semantic * 320
233
+
234
+ # pad wavs
235
+ wavs_padded = []
236
+ for audio in wavs:
237
+ padding = max_length - audio.shape[1]
238
+ if padding > 0:
239
+ padded_audio = F.pad(audio, (0, padding), mode="constant", value=0)
240
+ else:
241
+ padded_audio = audio[:, :max_length]
242
+ wavs_padded.append(padded_audio)
243
+ wavs_tensor = torch.stack(wavs_padded)
244
+
245
+ # pad semantic features
246
+ semantic_features_padded = []
247
+ for feat in semantic_features:
248
+ padding = max_length_semantic - feat.shape[1]
249
+ padded_feat = F.pad(feat, (0, 0, 0, padding), mode="constant", value=0)
250
+ semantic_features_padded.append(padded_feat)
251
+ semantic_feature_tensor = torch.stack(semantic_features_padded)
252
+
253
+ return wavs_tensor, semantic_feature_tensor, token_durations
254
+
255
+ def _pad_codes(self, vq_codes: list[torch.Tensor]):
256
+ max_len = max([i.shape[-1] for i in vq_codes])
257
+ token_durations = []
258
+ padded_codes = []
259
+ for curr_codes in vq_codes:
260
+ curr_len = curr_codes.shape[-1]
261
+ token_durations.append(curr_len)
262
+ padding = max_len - curr_len
263
+ curr_codes = F.pad(curr_codes, (0, padding), mode="constant", value=0)
264
+ padded_codes.append(curr_codes)
265
+ return torch.stack(padded_codes), token_durations
266
+
267
+ @property
268
+ def device(self):
269
+ return next(self.parameters()).device
neucodec/module.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ from torch.nn.utils import weight_norm
4
+
5
+ from .activations import SnakeBeta
6
+ from .alias_free_torch import Activation1d
7
+
8
+
9
def WNConv1d(*args, **kwargs):
    """Build an ``nn.Conv1d`` from the given arguments and weight-normalize it."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
11
+
12
+
13
class ResidualUnit(nn.Module):
    """Residual block: SnakeBeta, dilated 7-tap conv, SnakeBeta, 1-tap conv.

    Args:
        dim: Channel count, unchanged by the block.
        dilation: Dilation of the 7-tap convolution; its padding is chosen
            as ``(7 - 1) * dilation // 2``.
    """

    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        kernel_size = 7
        same_pad = (kernel_size - 1) * dilation // 2
        layers = [
            Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
            WNConv1d(
                dim, dim, kernel_size=kernel_size, dilation=dilation, padding=same_pad
            ),
            Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
            WNConv1d(dim, dim, kernel_size=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the block's output for ``x``."""
        residual = self.block(x)
        return x + residual
26
+
27
+
28
class EncoderBlock(nn.Module):
    """Residual units at ``dim // 2`` channels followed by a strided conv to ``dim``.

    Args:
        dim: Output channel count; the residual units run at half of it.
        stride: Stride of the final convolution, whose kernel size is
            ``2 * stride``.
        dilations: One ``ResidualUnit`` is created per dilation.
    """

    def __init__(self, dim: int = 16, stride: int = 1, dilations=(1, 3, 9)):
        super().__init__()
        half_dim = dim // 2
        layers = [ResidualUnit(half_dim, dilation=d) for d in dilations]
        layers.append(
            Activation1d(activation=SnakeBeta(half_dim, alpha_logscale=True))
        )
        layers.append(
            WNConv1d(
                half_dim,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                # equals stride // 2 + stride % 2, i.e. half the stride rounded up
                padding=(stride + 1) // 2,
            )
        )
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.block(x)
46
+
47
+
48
class SemanticEncoder(nn.Module):
    """Length-preserving 1-D conv encoder: input conv, residual stack, output conv.

    Args:
        input_channels: Channels of the input tensor.
        code_dim: Channels of the output tensor.
        encode_channels: Channels used inside the residual stack.
        kernel_size: Kernel size of every convolution; padding is
            ``(kernel_size - 1) // 2``.
        bias: Whether the two residual convolutions have a bias. The input
            and output convolutions never do.
    """

    def __init__(
        self,
        input_channels: int,
        code_dim: int,
        encode_channels: int,
        kernel_size: int = 3,
        bias: bool = True,
    ):
        super().__init__()
        same_pad = (kernel_size - 1) // 2

        def conv(c_in, c_out, use_bias):
            # Every convolution shares stride, kernel size and padding.
            return nn.Conv1d(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=kernel_size,
                stride=1,
                padding=same_pad,
                bias=use_bias,
            )

        # Initial convolution: input_channels -> encode_channels.
        self.initial_conv = conv(input_channels, encode_channels, False)

        # Residual stack.
        self.residual_blocks = nn.Sequential(
            nn.ReLU(inplace=True),
            conv(encode_channels, encode_channels, bias),
            nn.ReLU(inplace=True),
            conv(encode_channels, encode_channels, bias),
        )

        # Final convolution: encode_channels -> code_dim.
        self.final_conv = conv(encode_channels, code_dim, False)

    def forward(self, x):
        """Encode ``x``.

        Args:
            x (Tensor): Input of shape (Batch, Input_channels, Length).

        Returns:
            Tensor: Encoded tensor of shape (Batch, Code_dim, Length).
        """
        hidden = self.initial_conv(x)  # (Batch, Encode_channels, Length)
        # The stack opens with an in-place ReLU, so ``hidden`` has already
        # been rectified by the time it is added back as the skip term.
        hidden = self.residual_blocks(hidden) + hidden
        return self.final_conv(hidden)  # (Batch, Code_dim, Length)
setup.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from setuptools import setup, find_packages


# Packaging metadata for the neucodec distribution.
setup(
    name='neucodec',
    version='0.0.1',
    description='A package for neucodec, based on xcodec2.',
    long_description_content_type='text/markdown',
    author='Harry Julian',
    author_email='[email protected]',
    packages=find_packages(),
    # The package evaluates PEP 604 ``X | Y`` annotations at import time,
    # which needs Python 3.10 or newer.
    python_requires='>=3.10',
    install_requires=[
        'librosa',
        'soundfile',
        'numpy>=2.0.2',
        'omegaconf>=2.3.0',
        'torch>=2.5.1',
        'torchaudio>=2.5.1',
        'torchao>=0.5.0',
        'torchtune>=0.3.1',
        'vector-quantize-pytorch>=1.17.8',
        'rotary-embedding-torch>=0.8.4',
        'transformers>=4.44.2',
        'boto3>1.0',
        'tqdm',
    ],
    classifiers=[
        'Programming Language :: Python',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.10',
    ],
)
tests/__init__.py ADDED
File without changes
tests/test_neucodec.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import torch
3
+ import torchaudio
4
+ import librosa
5
+ from xcodec2 import XCodec2, MiniXCodec2Encoder
6
+
7
+
8
@pytest.fixture
def model_16khz():
    """Provide the cached "16khz" ``XCodec2`` model."""
    model = XCodec2.from_cache("16khz")
    return model
11
+
12
+
13
@pytest.fixture
def model_24khz():
    """Provide the cached "24khz" ``XCodec2`` model."""
    model = XCodec2.from_cache("24khz")
    return model
16
+
17
+
18
@pytest.fixture
def model_asr_encoder():
    """Provide the cached ``MiniXCodec2Encoder``."""
    encoder = MiniXCodec2Encoder.from_cache()
    return encoder
21
+
22
+
23
@pytest.fixture
def example_audio():
    """Provide ``(waveform, sample_rate)`` for librosa's "libri1" example."""
    waveform, sample_rate = torchaudio.load(librosa.ex("libri1"))
    return waveform, sample_rate
27
+
28
+
29
@pytest.fixture
def example_fpath():
    """Provide the file path of librosa's "libri1" example."""
    path = librosa.ex("libri1")
    return path
32
+
33
+
34
@pytest.fixture
def batch_fpaths():
    """Provide the file paths of librosa's "libri1" and "libri2" examples."""
    return [librosa.ex(name) for name in ("libri1", "libri2")]
37
+
38
+
39
def load_and_validate_audio(save_path, sample_rate):
    """Load ``save_path`` and assert it was stored at ``sample_rate``."""
    _, loaded_rate = torchaudio.load(save_path)
    assert loaded_rate == sample_rate
42
+
43
+
44
def test_16khz_autoencode(example_fpath, tmp_path, model_16khz):
    """The 16 kHz model writes its reconstruction at 16 kHz."""
    destination = str(tmp_path / "0.wav")
    model_16khz.autoencode(example_fpath, destination)
    load_and_validate_audio(destination, 16_000)
48
+
49
+
50
def test_24khz_autoencode(example_fpath, tmp_path, model_24khz):
    """The 24 kHz model writes its reconstruction at 24 kHz."""
    destination = str(tmp_path / "0.wav")
    model_24khz.autoencode(example_fpath, destination)
    load_and_validate_audio(destination, 24_000)
54
+
55
+
56
def test_24khz_encode_decode_single(example_audio, model_24khz):
    """Encoding then decoding one clip yields 3-D tensors at both steps."""
    target_rate = 16_000
    waveform, rate = example_audio
    if rate != target_rate:
        waveform = torchaudio.transforms.Resample(rate, target_rate)(waveform)
        rate = target_rate

    # encode
    codes = model_24khz.encode_code(waveform, sample_rate=rate)
    assert isinstance(codes, torch.Tensor)
    assert codes.dim() == 3  # [batch, channels, time]

    # decode
    audio = model_24khz.decode_code(codes)
    assert isinstance(audio, torch.Tensor)
    assert audio.dim() == 3  # [batch, channels, time]
71
+
72
+
73
def test_24khz_batch_encode(batch_fpaths, model_24khz):
    """List mode returns one 2-D code tensor per file and no durations."""
    codes_per_file, durations = model_24khz.batch_encode(
        batch_fpaths, return_tensor=False
    )
    assert isinstance(codes_per_file, list)
    assert durations is None
    assert len(codes_per_file) == 2

    for item in codes_per_file:
        assert isinstance(item, torch.Tensor)
        assert item.dim() == 2  # [channels, time]
82
+
83
+
84
def test_24khz_batch_encode_tensor(batch_fpaths, model_24khz):
    """Tensor mode returns a 3-D code tensor plus one duration per file."""
    codes, durations = model_24khz.batch_encode(batch_fpaths, return_tensor=True)
    assert isinstance(codes, torch.Tensor)
    assert isinstance(durations, list)
    assert codes.dim() == 3  # [batch, channels, time]
    assert len(durations) == 2
    # the two example clips must not collapse to the same duration
    assert len(set(durations)) == 2
91
+
92
+
93
def test_24khz_batch_decode(batch_fpaths, model_24khz):
    """Decoding a padded code tensor returns one 2-D tensor per file."""
    codes, durations = model_24khz.batch_encode(batch_fpaths, return_tensor=True)
    outputs = model_24khz.batch_decode(codes, durations)
    assert isinstance(outputs, list)
    assert len(outputs) == 2
    for audio in outputs:
        assert isinstance(audio, torch.Tensor)
        assert audio.dim() == 2  # [channels, time]
101
+
102
+
103
def test_24khz_batch_decode_list_input(batch_fpaths, model_24khz):
    """Decoding a list of unpadded codes returns one 2-D tensor per file."""
    codes_per_file, _ = model_24khz.batch_encode(batch_fpaths, return_tensor=False)
    outputs = model_24khz.batch_decode(codes_per_file)
    assert isinstance(outputs, list)
    assert len(outputs) == 2
    for audio in outputs:
        assert isinstance(audio, torch.Tensor)
        assert audio.dim() == 2  # [channels, time]
111
+
112
+
113
def test_24khz_batch_autoencode(batch_fpaths, tmp_path, model_24khz):
    """Batch autoencoding writes one 24 kHz file per input."""
    destinations = [str(tmp_path / f"{i}.wav") for i in range(len(batch_fpaths))]
    outputs = model_24khz.batch_autoencode(batch_fpaths, destinations)
    assert isinstance(outputs, list)
    assert len(outputs) == 2
    for destination in destinations:
        load_and_validate_audio(destination, 24_000)
120
+
121
+
122
def test_asr_encoder_encode(example_audio, model_asr_encoder):
    """The ASR encoder returns a 3-D code tensor for one clip."""
    waveform, rate = example_audio
    target_rate = model_asr_encoder.sample_rate
    if rate != target_rate:
        waveform = torchaudio.transforms.Resample(rate, target_rate)(waveform)
    codes = model_asr_encoder.encode_code(waveform)
    assert isinstance(codes, torch.Tensor)
    assert codes.dim() == 3