mehdi999 committed
Commit d09d63c · 1 Parent(s): 78e8994

another test

codec/__init__.py CHANGED
@@ -1,2 +1,2 @@
-# codec/__init__.py minimal inference exports
-from .models import PatchVAE, PatchVAEConfig, WavVAE # noqa: F401
+from .train_patchvae import TrainPatchVAE
+from .train_wavvae import TrainWavVAE
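
For context, an illustrative sketch (not part of the commit) of the package surface after this hunk; TrainPatchVAE and TrainWavVAE come from this repo's codec.train_patchvae / codec.train_wavvae modules, whose constructor signatures are not shown in the diff, so only the imports are sketched:

# Illustrative only: the training entry points are now re-exported at the
# package root, while the inference models remain reachable one level down
# (see the codec/models/__init__.py hunk below).
from codec import TrainPatchVAE, TrainWavVAE   # new top-level exports
from codec.models import PatchVAE, WavVAE      # unchanged deeper import path
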
codec/models/__init__.py CHANGED
@@ -1,7 +1,2 @@
-# codec/models/__init__.py keep inference-only imports
-from .patchvae.model import PatchVAE, PatchVAEConfig # noqa: F401
-from .wavvae.model import WavVAE # noqa: F401
-
-# IMPORTANT:
-# - Do NOT import pardi_tokenizer here (it references zcodec in your tree)
-# - Do NOT import training utilities here
+from .patchvae.model import PatchVAE, PatchVAEConfig
+from .wavvae.model import WavVAE, WavVAEConfig
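
A hypothetical smoke test (not from the repo) for the names this hunk advertises; WavVAEConfig is the newly exported symbol and the check assumes the package imports cleanly in your environment:

# Illustrative check that codec.models exposes both models and both configs.
import codec.models as m
for name in ("PatchVAE", "PatchVAEConfig", "WavVAE", "WavVAEConfig"):
    assert hasattr(m, name), f"missing export: {name}"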
 
 
 
 
 
codec/models/patchvae/modules.py CHANGED
@@ -6,7 +6,7 @@ import torch.nn.functional as F
 from torch import nn
 from vector_quantize_pytorch import FSQ
 
-from ..components.transformer import TransformerBlock
+from zcodec.models.components.transformer import TransformerBlock
 
 
 class AdaLayerNormScale(nn.Module):
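
This hunk swaps a relative import for an absolute one, so modules.py now resolves TransformerBlock through an importable zcodec package rather than through its own codec.models parent. A minimal sketch of what that implies (illustrative, not repo code):

# This is what modules.py now requires to resolve: it fails unless `zcodec`
# is installed or its repo root is on sys.path, whereas the old relative
# import only needed codec/models/components/ to exist inside this package.
from zcodec.models.components.transformer import TransformerBlock
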
tts/model/simple_gla.py.bak DELETED
@@ -1,295 +0,0 @@
-import os
-#simple-gla
-import torch
-import torch.nn.functional as F
-from einops import rearrange
-from fla.layers.simple_gla import SimpleGatedLinearAttention
-from fla.models.utils import Cache
-from sympy import num_digits
-from torch import nn
-
-from tts.layers.attention import CrossAttention
-from tts.layers.ffn import SwiGLU
-
-from .cache_utils import FLACache
-from .config import SimpleGLADecoderConfig
-from .registry import register_decoder
-from .shortconv import ShortConvBlock
-
-if "GRAD_CKPT" in os.environ:
-
-    def maybe_grad_ckpt(f):
-        def grad_ckpt_f(*args, **kwargs):
-            return torch.utils.checkpoint.checkpoint(
-                f, *args, **kwargs, use_reentrant=False
-            )
-
-        return grad_ckpt_f
-else:
-
-    def maybe_grad_ckpt(f):
-        return f
-
-
-class SimpleGLABlock(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        num_heads: int,
-        layer_idx: int,
-        expand_k: float,
-        expand_v: float,
-        use_short_conv: bool,
-        ffn_expansion_factor: int,
-    ):
-        super().__init__()
-        self.tmix = SimpleGatedLinearAttention(
-            hidden_size=dim,
-            num_heads=num_heads,
-            layer_idx=layer_idx,
-        )
-        self.cmix = SwiGLU(dim, ffn_expansion_factor)
-        self.norm1 = nn.LayerNorm(dim)
-        self.norm2 = nn.LayerNorm(dim)
-
-    def forward(
-        self,
-        x,
-        freqs: torch.Tensor | None = None,
-        text_freqs: torch.Tensor | None = None,
-        cache: Cache | None = None,
-    ):
-        # Only enable the cache if it is actually usable (non-empty conv_state)
-        use_cache_flag = isinstance(cache, dict) and cache.get("conv_state", None) not in (None, [])
-        pkv = cache if use_cache_flag else None
-
-        x = (
-            self.tmix(
-                self.norm1(x),
-                past_key_values=pkv,
-                use_cache=use_cache_flag,
-            )[0]
-            + x
-        )
-        x = self.cmix(self.norm2(x)) + x
-        return x
-
-
-class DecoderBlockWithOptionalCrossAttention(nn.Module):
-    def __init__(self, decoder_block: nn.Module, crossatt: nn.Module | None = None):
-        super().__init__()
-
-        self.decoder_block = decoder_block
-        self.crossatt = crossatt
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        encoder_output: torch.Tensor | None = None,
-        freqs: torch.Tensor | None = None,
-        text_freqs: torch.Tensor | None = None,
-        cache: Cache | None = None,
-        selfatt_mask: torch.Tensor | None = None,
-        crossatt_mask: torch.Tensor | list[torch.Tensor] | None = None,
-    ) -> torch.Tensor:
-        x = self.decoder_block(
-            x,
-            freqs=freqs,
-            cache=cache,
-        )
-        if type(crossatt_mask) is list:
-            crossatt_mask = crossatt_mask[self.decoder_block.tmix.layer_idx]
-        if self.crossatt is not None:
-            x = x + self.crossatt(
-                x,
-                k=encoder_output,
-                text_freqs=text_freqs,
-                mask=crossatt_mask,
-                cache=cache,
-            )
-
-        return x
-
-
-@register_decoder("simple_gla")
-class SimpleGLADecoder(nn.Module):
-    config = SimpleGLADecoderConfig
-
-    def __init__(self, cfg: SimpleGLADecoderConfig):
-        super().__init__()
-
-        assert cfg.dim % cfg.num_heads == 0, "num_heads should divide dim"
-        assert cfg.blind_crossatt + (cfg.listen_read_crossatt is not None) < 2, (
-            "at most one specialized cross-attention"
-        )
-
-        self.head_dim = cfg.dim // cfg.num_heads
-        self.num_heads = cfg.num_heads
-
-        def simple_gla_block(i):
-            conv_layers = [] if cfg.conv_layers is None else cfg.conv_layers
-            if i in conv_layers:
-                return ShortConvBlock(
-                    dim=cfg.dim,
-                    kernel_size=4,
-                    ffn_expansion_factor=cfg.ffn_expansion_factor,
-                    layer_idx=i,
-                    use_fast_conv1d=True,
-                )
-
-            else:
-                return SimpleGLABlock(
-                    dim=cfg.dim,
-                    num_heads=cfg.num_heads,
-                    layer_idx=i,
-                    expand_k=cfg.expand_k,
-                    expand_v=cfg.expand_v,
-                    use_short_conv=cfg.use_short_conv,
-                    ffn_expansion_factor=cfg.ffn_expansion_factor,
-                )
-
-        def crossatt_block(i):
-            if i in cfg.crossatt_layer_idx:
-                return CrossAttention(
-                    dim=cfg.dim,
-                    num_heads=cfg.crossatt_num_heads,
-                    dropout=cfg.crossatt_dropout,
-                    layer_idx=i,
-                )
-            else:
-                return None
-
-        self.decoder_layers = nn.ModuleList(
-            [
-                DecoderBlockWithOptionalCrossAttention(
-                    simple_gla_block(i),
-                    crossatt_block(i),
-                )
-                for i in range(cfg.num_layers)
-            ]
-        )
-
-    def forward(
-        self,
-        encoder_output: torch.Tensor,
-        decoder_input: torch.Tensor,
-        crossatt_mask: torch.Tensor | list[torch.Tensor] | None = None,
-        text_ids: torch.Tensor | None = None,
-        cache: FLACache | None = None,
-    ):
-        x = decoder_input
-        text_freqs = None
-
-        for layer in self.decoder_layers:
-            x = maybe_grad_ckpt(layer)(
-                x,
-                encoder_output,
-                text_freqs=text_freqs,
-                cache=cache,
-                crossatt_mask=crossatt_mask,
-            )
-        return x
-
-    def init_cache(self, max_seq_len, device):
-        return FLACache(num_states=len(self.decoder_layers) + 1)
-
-    def init_initial_state(self, batch_size=1, scale=1e-2, device="cpu"):
-        return tuple(
-            nn.Parameter(
-                torch.randn(
-                    batch_size,
-                    self.num_heads,
-                    self.head_dim,
-                    self.head_dim,
-                    device=device,
-                )
-                * scale
-            )
-            for _ in range(len(self.decoder_layers))
-        )
-    def init_initial_state_lora(self, lora:int=1, batch_size: int = 1, scale: float=1e-2, device: str="cpu"):
-        return tuple(
-            (
-                nn.Parameter(
-                    torch.randn(
-                        batch_size,
-                        self.num_heads,
-                        self.head_dim,
-                        lora,
-                        device=device,
-                    )
-                    * scale
-                ),
-                nn.Parameter(
-                    torch.randn(
-                        batch_size,
-                        self.num_heads,
-                        lora,
-                        self.head_dim,
-                        device=device,
-                    )
-                    * scale
-                )
-            )
-            for _ in range(len(self.decoder_layers))
-        )
-
-    def _get_query(self, audio_inputs: torch.Tensor, layer_idx: int):
-        assert self.decoder_layers[layer_idx].crossatt is not None
-        x = audio_inputs
-        for _, layer in zip(range(layer_idx - 1), self.decoder_layers):
-            x = layer(x, None)
-        return self.decoder_layers[layer_idx].crossatt._query(x)
-
-    def forward_first_n_layers(
-        self,
-        encoder_output: torch.Tensor,
-        decoder_input: torch.Tensor,
-        n_first_layers: int,
-        crossatt_mask: torch.Tensor | None = None,
-        cache: FLACache | None = None,
-    ):
-        x = decoder_input
-        if self.text_freqs_embd is not None:
-            text_freqs = torch.arange(encoder_output.shape[1], device=x.device)[None, :]
-            text_freqs = self.text_freqs_embd(text_freqs)
-        else:
-            text_freqs = None
-
-        for layer in self.decoder_layers[:n_first_layers]:
-            x = maybe_grad_ckpt(layer)(
-                x,
-                encoder_output,
-                text_freqs=text_freqs,
-                cache=cache,
-                crossatt_mask=crossatt_mask,
-            )
-        return x
-
-    def prefill(
-        self,
-        encoder_output: torch.Tensor,
-        decoder_input: torch.Tensor,
-        crossatt_mask: torch.Tensor | None = None,
-        cache: FLACache | None = None,
-    ):
-        return self(encoder_output, decoder_input, cache=cache, crossatt_mask=crossatt_mask)
-
-    def decode_one(
-        self,
-        encoder_output: torch.Tensor,
-        decoder_input: torch.Tensor,
-        cache: Cache,
-        text_freqs: torch.Tensor | None = None,
-        crossatt_mask: torch.Tensor | None = None,
-    ):
-        x = decoder_input
-        for layer in self.decoder_layers:
-            x = layer(
-                x,
-                encoder_output,
-                text_freqs=text_freqs,
-                cache=cache,
-                crossatt_mask=crossatt_mask,
-            )
-        return x
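
The deleted backup above gates gradient checkpointing behind a GRAD_CKPT environment variable via the maybe_grad_ckpt wrapper. A self-contained sketch of that pattern with a toy layer (illustrative only; nothing below is repo code):

import os
import torch
from torch import nn

# Same idea as the deleted maybe_grad_ckpt: wrap a callable in
# torch.utils.checkpoint.checkpoint only when GRAD_CKPT is set, so activation
# recomputation can be toggled per run without editing the model code.
if "GRAD_CKPT" in os.environ:
    def maybe_grad_ckpt(f):
        def grad_ckpt_f(*args, **kwargs):
            return torch.utils.checkpoint.checkpoint(
                f, *args, **kwargs, use_reentrant=False
            )
        return grad_ckpt_f
else:
    def maybe_grad_ckpt(f):
        return f

layer = nn.Linear(16, 16)               # toy stand-in for a decoder block
x = torch.randn(2, 16, requires_grad=True)
y = maybe_grad_ckpt(layer)(x)           # recomputed in backward if GRAD_CKPT is set
y.sum().backward()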