Daporte committed on
Commit c199313 · verified · 1 Parent(s): 0e82b05

Add files from https://github.com/facebookresearch/speech-resynthesis

Files changed (5)
  1. models.py +38 -0
  2. modules/dist.py +108 -0
  3. modules/jukebox.py +178 -0
  4. modules/resnet.py +82 -0
  5. modules/vq.py +249 -0
models.py ADDED
@@ -0,0 +1,38 @@
+# adapted from https://github.com/jik876/hifi-gan
+
+from transformers.modeling_utils import PreTrainedModel
+
+from quantizer_config import QuantizerConfig
+from modules.jukebox import Encoder, Decoder
+from modules.vq import Bottleneck
+
+
+
+class Quantizer(PreTrainedModel):
+    config_class = QuantizerConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.config = config
+        self.encoder = Encoder(**config.f0_encoder_params)
+        self.vq = Bottleneck(**config.f0_vq_params)
+        self.decoder = Decoder(**config.f0_decoder_params)
+
+    def forward(self, **kwargs):
+        f0_h = self.encoder(kwargs['features'])
+
+        zs, f0_h_q, f0_commit_losses, f0_metrics = self.vq(f0_h)
+
+        f0 = self.decoder(f0_h_q)
+
+        return {
+            'f0': f0,
+            'commit_losses': f0_commit_losses,
+            'metrics': f0_metrics,
+            'codes': zs,
+            'hidden_states': f0_h_q
+        }
+
+
+
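The Quantizer above is a thin PreTrainedModel wrapper that chains the Jukebox-style Encoder, the VQ Bottleneck and the Decoder added in the other files of this commit. A minimal sketch of that wiring follows; QuantizerConfig is not part of this commit, so the three parameter dictionaries below are illustrative stand-ins for config.f0_encoder_params, config.f0_vq_params and config.f0_decoder_params (not the upstream values), and a CUDA device is assumed because modules/vq.py allocates its codebooks with .cuda().

import torch

from modules.jukebox import Encoder, Decoder
from modules.vq import Bottleneck

# Illustrative hyperparameters (assumed, not read from QuantizerConfig).
f0_encoder_params = dict(input_emb_width=1, output_emb_width=128, levels=1,
                         downs_t=[1], strides_t=[2], width=32, depth=2, m_conv=1.0)
f0_vq_params = dict(l_bins=64, emb_width=128, mu=0.99, levels=1)
f0_decoder_params = dict(input_emb_width=1, output_emb_width=128, levels=1,
                         downs_t=[1], strides_t=[2], width=32, depth=2, m_conv=1.0)

encoder = Encoder(**f0_encoder_params).cuda()
vq = Bottleneck(**f0_vq_params).cuda()
decoder = Decoder(**f0_decoder_params).cuda()

features = torch.randn(2, 1, 64).cuda()        # (batch, channels, frames), e.g. an F0 track
f0_h = encoder(features)                       # list with one latent of shape (2, 128, 32)
zs, f0_h_q, commit_losses, metrics = vq(f0_h)  # codes (2, 32), quantised latents (2, 128, 32)
f0 = decoder(f0_h_q)                           # reconstruction of shape (2, 1, 64)

The five values computed here correspond one-to-one to the 'f0', 'commit_losses', 'metrics', 'codes' and 'hidden_states' entries returned by Quantizer.forward.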
modules/dist.py ADDED
@@ -0,0 +1,108 @@
+# Adapted from https://github.com/openai/jukebox
+
+from enum import Enum
+
+import torch.distributed as dist
+
+
+class ReduceOp(Enum):
+    SUM = 0,
+    PRODUCT = 1,
+    MIN = 2,
+    MAX = 3
+
+    def ToDistOp(self):
+        return {
+            self.SUM: dist.ReduceOp.SUM,
+            self.PRODUCT: dist.ReduceOp.PRODUCT,
+            self.MIN: dist.ReduceOp.MIN,
+            self.MAX: dist.ReduceOp.MAX
+        }[self]
+
+
+def is_available():
+    return dist.is_initialized()
+
+
+def get_rank():
+    if is_available():
+        return _get_rank()
+    else:
+        return 0
+
+
+def get_world_size():
+    if is_available():
+        return _get_world_size()
+    else:
+        return 1
+
+
+def barrier():
+    if is_available():
+        return _barrier()
+    # else: do nothing
+
+
+def all_gather(tensor_list, tensor):
+    if is_available():
+        return _all_gather(tensor_list, tensor)
+    else:
+        tensor_list[0] = tensor
+
+
+def all_reduce(tensor, op=ReduceOp.SUM):
+    if is_available():
+        return _all_reduce(tensor, op)
+    # else: do nothing
+
+
+def reduce(tensor, dst, op=ReduceOp.SUM):
+    if is_available():
+        return _reduce(tensor, dst, op)
+    # else: do nothing
+
+
+def broadcast(tensor, src):
+    if is_available():
+        return _broadcast(tensor, src)
+    # else: do nothing
+
+
+def init_process_group(backend, init_method):
+    if is_available():
+        return _init_process_group(backend, init_method)
+    # else: do nothing
+
+
+def _get_rank():
+    return dist.get_rank()
+
+
+def _barrier():
+    return dist.barrier()
+
+
+def _get_world_size():
+    return dist.get_world_size()
+
+
+def _all_gather(tensor_list, tensor):
+    return dist.all_gather(tensor_list, tensor)
+
+
+def _all_reduce(tensor, op):
+    return dist.all_reduce(tensor, op.ToDistOp())
+
+
+def _reduce(tensor, dst, op):
+    return dist.reduce(tensor, dst, op.ToDistOp())
+
+
+def _broadcast(tensor, src):
+    return dist.broadcast(tensor, src)
+
+
+def _init_process_group(backend, init_method):
+    return dist.init_process_group(backend, init_method)
+
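These wrappers let the rest of the code call collective operations unconditionally: when no torch.distributed process group has been initialised, is_available() is False, the rank and world-size queries fall back to 0 and 1, and the collectives become no-ops (all_gather just copies the local tensor into slot 0). A small single-process sketch of that fallback behaviour:

import torch

import modules.dist as dist

# Without init_process_group, everything degrades gracefully to single-process behaviour.
print(dist.get_rank(), dist.get_world_size())  # 0 1

x = torch.ones(3)
dist.all_reduce(x, op=dist.ReduceOp.SUM)       # no-op in the single-process case
dist.broadcast(x, src=0)                       # likewise a no-op

gathered = [torch.empty(3)]
dist.all_gather(gathered, x)                   # copies the local tensor into slot 0
print(gathered[0])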
modules/jukebox.py ADDED
@@ -0,0 +1,178 @@
+# Adapted from https://github.com/openai/jukebox
+
+import numpy as np
+import torch.nn as nn
+from modules.resnet import Resnet1D
+
+
+def assert_shape(x, exp_shape):
+    assert x.shape == exp_shape, f"Expected {exp_shape} got {x.shape}"
+
+
+class EncoderConvBlock(nn.Module):
+    def __init__(self, input_emb_width, output_emb_width, down_t, stride_t, width, depth, m_conv,
+                 dilation_growth_rate=1, dilation_cycle=None, zero_out=False, res_scale=False):
+        super().__init__()
+        blocks = []
+        if type(stride_t) is tuple or type(stride_t) is list:
+            start = True
+            for s_t, d_t in zip(stride_t, down_t):
+                if s_t % 2 == 0:
+                    filter_t, pad_t = s_t * 2, s_t // 2
+                else:
+                    filter_t, pad_t = s_t * 2 + 1, s_t // 2 + 1
+                if d_t > 0:
+                    for i in range(d_t):
+                        block = nn.Sequential(
+                            nn.Conv1d(input_emb_width if i == 0 and start else width, width, filter_t, s_t, pad_t),
+                            Resnet1D(width, depth, m_conv, dilation_growth_rate, dilation_cycle, zero_out, res_scale), )
+                        blocks.append(block)
+                    start = False
+            block = nn.Conv1d(width, output_emb_width, 3, 1, 1)
+            blocks.append(block)
+        else:
+            filter_t, pad_t = stride_t * 2, stride_t // 2
+            if down_t > 0:
+                for i in range(down_t):
+                    block = nn.Sequential(
+                        nn.Conv1d(input_emb_width if i == 0 else width, width, filter_t, stride_t, pad_t),
+                        Resnet1D(width, depth, m_conv, dilation_growth_rate, dilation_cycle, zero_out, res_scale), )
+                    blocks.append(block)
+                block = nn.Conv1d(width, output_emb_width, 3, 1, 1)
+                blocks.append(block)
+        self.model = nn.Sequential(*blocks)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class DecoderConvBock(nn.Module):
+    def __init__(self, input_emb_width, output_emb_width, down_t, stride_t, width, depth, m_conv,
+                 dilation_growth_rate=1, dilation_cycle=None, zero_out=False, res_scale=False,
+                 reverse_decoder_dilation=False, checkpoint_res=False):
+        super().__init__()
+        blocks = []
+
+        if type(stride_t) is tuple or type(stride_t) is list:
+            block = nn.Conv1d(output_emb_width, width, 3, 1, 1)
+            blocks.append(block)
+            for k, (s_t, d_t) in enumerate(zip(stride_t, down_t)):
+                if d_t > 0:
+                    if s_t % 2 == 0:
+                        filter_t, pad_t = s_t * 2, s_t // 2
+                    else:
+                        filter_t, pad_t = s_t * 2 + 1, s_t // 2 + 1
+                    end = k == len(stride_t) - 1
+                    for i in range(d_t):
+                        block = nn.Sequential(
+                            Resnet1D(width, depth, m_conv, dilation_growth_rate, dilation_cycle, zero_out=zero_out,
+                                     res_scale=res_scale, reverse_dilation=reverse_decoder_dilation,
+                                     checkpoint_res=checkpoint_res),
+                            nn.ConvTranspose1d(width, input_emb_width if i == (d_t - 1) and end else width, filter_t,
+                                               s_t, pad_t))
+                        blocks.append(block)
+        else:
+            if down_t > 0:
+                filter_t, pad_t = stride_t * 2, stride_t // 2
+                block = nn.Conv1d(output_emb_width, width, 3, 1, 1)
+                blocks.append(block)
+                for i in range(down_t):
+                    block = nn.Sequential(
+                        Resnet1D(width, depth, m_conv, dilation_growth_rate, dilation_cycle, zero_out=zero_out,
+                                 res_scale=res_scale, reverse_dilation=reverse_decoder_dilation,
+                                 checkpoint_res=checkpoint_res),
+                        nn.ConvTranspose1d(width, input_emb_width if i == (down_t - 1) else width, filter_t, stride_t,
+                                           pad_t))
+                    blocks.append(block)
+        self.model = nn.Sequential(*blocks)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class Encoder(nn.Module):
+    def __init__(self, input_emb_width, output_emb_width, levels, downs_t, strides_t, **block_kwargs):
+        super().__init__()
+        self.input_emb_width = input_emb_width
+        self.output_emb_width = output_emb_width
+        self.levels = levels
+        self.downs_t = downs_t
+        self.strides_t = strides_t
+
+        block_kwargs_copy = dict(**block_kwargs)
+        if 'reverse_decoder_dilation' in block_kwargs_copy:
+            del block_kwargs_copy['reverse_decoder_dilation']
+        level_block = lambda level, down_t, stride_t: EncoderConvBlock(
+            input_emb_width if level == 0 else output_emb_width, output_emb_width, down_t, stride_t,
+            **block_kwargs_copy)
+        self.level_blocks = nn.ModuleList()
+        iterator = zip(list(range(self.levels)), downs_t, strides_t)
+        for level, down_t, stride_t in iterator:
+            self.level_blocks.append(level_block(level, down_t, stride_t))
+
+    def forward(self, x):
+        N, T = x.shape[0], x.shape[-1]
+        emb = self.input_emb_width
+        assert_shape(x, (N, emb, T))
+        xs = []
+
+        # 64, 32, ...
+        iterator = zip(list(range(self.levels)), self.downs_t, self.strides_t)
+        for level, down_t, stride_t in iterator:
+            level_block = self.level_blocks[level]
+            x = level_block(x)
+            if type(stride_t) is tuple or type(stride_t) is list:
+                emb, T = self.output_emb_width, T // np.prod([s ** d for s, d in zip(stride_t, down_t)])
+            else:
+                emb, T = self.output_emb_width, T // (stride_t ** down_t)
+            assert_shape(x, (N, emb, T))
+            xs.append(x)
+
+        return xs
+
+
+class Decoder(nn.Module):
+    def __init__(self, input_emb_width, output_emb_width, levels, downs_t, strides_t, **block_kwargs):
+        super().__init__()
+        self.input_emb_width = input_emb_width
+        self.output_emb_width = output_emb_width
+        self.levels = levels
+
+        self.downs_t = downs_t
+
+        self.strides_t = strides_t
+
+        level_block = lambda level, down_t, stride_t: DecoderConvBock(output_emb_width, output_emb_width, down_t,
+                                                                      stride_t, **block_kwargs)
+        self.level_blocks = nn.ModuleList()
+        iterator = zip(list(range(self.levels)), downs_t, strides_t)
+        for level, down_t, stride_t in iterator:
+            self.level_blocks.append(level_block(level, down_t, stride_t))
+
+        self.out = nn.Conv1d(output_emb_width, input_emb_width, 3, 1, 1)
+
+    def forward(self, xs, all_levels=True):
+        if all_levels:
+            assert len(xs) == self.levels
+        else:
+            assert len(xs) == 1
+        x = xs[-1]
+        N, T = x.shape[0], x.shape[-1]
+        emb = self.output_emb_width
+        assert_shape(x, (N, emb, T))
+
+        # 32, 64 ...
+        iterator = reversed(list(zip(list(range(self.levels)), self.downs_t, self.strides_t)))
+        for level, down_t, stride_t in iterator:
+            level_block = self.level_blocks[level]
+            x = level_block(x)
+            if type(stride_t) is tuple or type(stride_t) is list:
+                emb, T = self.output_emb_width, T * np.prod([s ** d for s, d in zip(stride_t, down_t)])
+            else:
+                emb, T = self.output_emb_width, T * (stride_t ** down_t)
+            assert_shape(x, (N, emb, T))
+            if level != 0 and all_levels:
+                x = x + xs[level - 1]
+
+        x = self.out(x)
+        return x
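A short CPU-only sketch of the Encoder/Decoder pair defined above; the hyperparameters are illustrative, not taken from any config in this commit. Each level downsamples time by stride_t ** down_t, the Encoder returns one latent per level, and the Decoder upsamples the deepest latent back, adding the shallower latents as skip connections before the final projection to input_emb_width channels.

import torch

from modules.jukebox import Encoder, Decoder

block_kwargs = dict(width=32, depth=2, m_conv=1.0, dilation_growth_rate=3)
enc = Encoder(input_emb_width=1, output_emb_width=64, levels=2,
              downs_t=(2, 2), strides_t=(2, 2), **block_kwargs)
dec = Decoder(input_emb_width=1, output_emb_width=64, levels=2,
              downs_t=(2, 2), strides_t=(2, 2), **block_kwargs)

x = torch.randn(4, 1, 64)   # (batch, channels, time)
xs = enc(x)                 # [(4, 64, 16), (4, 64, 4)]: one latent per level
y = dec(xs)                 # adds xs[0] as a skip connection on the way up; (4, 1, 64)
print([tuple(h.shape) for h in xs], tuple(y.shape))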
modules/resnet.py ADDED
@@ -0,0 +1,82 @@
+# Adapted from https://github.com/openai/jukebox
+
+import math
+import torch.nn as nn
+
+import modules.dist as dist
+
+
+class ResConvBlock(nn.Module):
+    def __init__(self, n_in, n_state):
+        super().__init__()
+        self.model = nn.Sequential(
+            nn.ReLU(),
+            nn.Conv2d(n_in, n_state, 3, 1, 1),
+            nn.ReLU(),
+            nn.Conv2d(n_state, n_in, 1, 1, 0),
+        )
+
+    def forward(self, x):
+        return x + self.model(x)
+
+
+class Resnet(nn.Module):
+    def __init__(self, n_in, n_depth, m_conv=1.0):
+        super().__init__()
+        self.model = nn.Sequential(*[ResConvBlock(n_in, int(m_conv * n_in)) for _ in range(n_depth)])
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class ResConv1DBlock(nn.Module):
+    def __init__(self, n_in, n_state, dilation=1, zero_out=False, res_scale=1.0):
+        super().__init__()
+        padding = dilation
+        self.model = nn.Sequential(
+            nn.ReLU(),
+            nn.Conv1d(n_in, n_state, 3, 1, padding, dilation),
+            nn.ReLU(),
+            nn.Conv1d(n_state, n_in, 1, 1, 0),
+        )
+        if zero_out:
+            out = self.model[-1]
+            nn.init.zeros_(out.weight)
+            nn.init.zeros_(out.bias)
+        self.res_scale = res_scale
+
+    def forward(self, x):
+        return x + self.res_scale * self.model(x)
+
+
+class Resnet1D(nn.Module):
+    def __init__(self, n_in, n_depth, m_conv=1.0, dilation_growth_rate=1, dilation_cycle=None, zero_out=False,
+                 res_scale=False, reverse_dilation=False, checkpoint_res=False):
+        super().__init__()
+
+        def _get_depth(depth):
+            if dilation_cycle is None:
+                return depth
+            else:
+                return depth % dilation_cycle
+
+        blocks = [ResConv1DBlock(n_in, int(m_conv * n_in),
+                                 dilation=dilation_growth_rate ** _get_depth(depth),
+                                 zero_out=zero_out,
+                                 res_scale=1.0 if not res_scale else 1.0 / math.sqrt(n_depth))
+                  for depth in range(n_depth)]
+        if reverse_dilation:
+            blocks = blocks[::-1]
+        self.checkpoint_res = checkpoint_res
+        if self.checkpoint_res == 1:
+            if dist.get_rank() == 0:
+                print("Checkpointing convs")
+            self.blocks = nn.ModuleList(blocks)
+        else:
+            self.model = nn.Sequential(*blocks)
+
+    def forward(self, x):
+        if self.checkpoint_res == 1:
+            raise NotImplementedError("Checkpoint not implemented")
+        else:
+            return self.model(x)
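Resnet1D stacks ResConv1DBlock modules whose dilation grows as dilation_growth_rate ** depth (wrapping around every dilation_cycle blocks when that is set), so the receptive field grows exponentially with depth while each block preserves the sequence length. A brief sketch with illustrative sizes:

import torch

from modules.resnet import Resnet1D

# 4 blocks with dilations 3**0, 3**1, 3**2, 3**3 = 1, 3, 9, 27
net = Resnet1D(n_in=32, n_depth=4, m_conv=1.0, dilation_growth_rate=3)
x = torch.randn(2, 32, 100)
print(net(x).shape)  # torch.Size([2, 32, 100]) -- length unchanged

# dilation_cycle=2 wraps the exponent, giving dilations 1, 3, 1, 3 instead
net_cycled = Resnet1D(n_in=32, n_depth=4, m_conv=1.0, dilation_growth_rate=3, dilation_cycle=2)
print([block.model[1].dilation[0] for block in net_cycled.model])  # [1, 3, 1, 3]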
modules/vq.py ADDED
@@ -0,0 +1,249 @@
+# Adapted from https://github.com/openai/jukebox
+
+import numpy as np
+import torch as t
+import torch.nn as nn
+import torch.nn.functional as F
+
+import modules.dist as dist
+
+
+class BottleneckBlock(nn.Module):
+    def __init__(self, k_bins, emb_width, mu):
+        super().__init__()
+        self.k_bins = k_bins
+        self.emb_width = emb_width
+        self.mu = mu
+        self.reset_k()
+        self.threshold = 1.0
+
+    def reset_k(self):
+        self.init = False
+        self.k_sum = None
+        self.k_elem = None
+        self.register_buffer('k', t.zeros(self.k_bins, self.emb_width).cuda())
+
+    def _tile(self, x):
+        d, ew = x.shape
+        if d < self.k_bins:
+            n_repeats = (self.k_bins + d - 1) // d
+            std = 0.01 / np.sqrt(ew)
+            x = x.repeat(n_repeats, 1)
+            x = x + t.randn_like(x) * std
+        return x
+
+    def init_k(self, x):
+        mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins
+        self.init = True
+        # init k_w using random vectors from x
+        y = self._tile(x)
+        _k_rand = y[t.randperm(y.shape[0])][:k_bins]
+        dist.broadcast(_k_rand, 0)
+        self.k = _k_rand
+        assert self.k.shape == (k_bins, emb_width)
+        self.k_sum = self.k
+        self.k_elem = t.ones(k_bins, device=self.k.device)
+
+    def restore_k(self, num_tokens=None, threshold=1.0):
+        mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins
+        self.init = True
+        assert self.k.shape == (k_bins, emb_width)
+        self.k_sum = self.k.clone()
+        self.k_elem = t.ones(k_bins, device=self.k.device)
+        if num_tokens is not None:
+            expected_usage = num_tokens / k_bins
+            self.k_elem.data.mul_(expected_usage)
+            self.k_sum.data.mul_(expected_usage)
+        self.threshold = threshold
+
+    def update_k(self, x, x_l):
+        mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins
+        with t.no_grad():
+            # Calculate new centres
+            x_l_onehot = t.zeros(k_bins, x.shape[0], device=x.device)  # k_bins, N * L
+            x_l_onehot.scatter_(0, x_l.view(1, x.shape[0]), 1)
+
+            _k_sum = t.matmul(x_l_onehot, x)  # k_bins, w
+            _k_elem = x_l_onehot.sum(dim=-1)  # k_bins
+            y = self._tile(x)
+            _k_rand = y[t.randperm(y.shape[0])][:k_bins]
+
+            dist.broadcast(_k_rand, 0)
+            dist.all_reduce(_k_sum)
+            dist.all_reduce(_k_elem)
+
+            # Update centres
+            old_k = self.k
+            self.k_sum = mu * self.k_sum + (1. - mu) * _k_sum  # w, k_bins
+            self.k_elem = mu * self.k_elem + (1. - mu) * _k_elem  # k_bins
+            usage = (self.k_elem.view(k_bins, 1) >= self.threshold).float()
+            self.k = usage * (self.k_sum.view(k_bins, emb_width) / self.k_elem.view(k_bins, 1)) \
+                     + (1 - usage) * _k_rand
+            _k_prob = _k_elem / t.sum(_k_elem)  # x_l_onehot.mean(dim=-1)  # prob of each bin
+            entropy = -t.sum(_k_prob * t.log(_k_prob + 1e-8))  # entropy ie how diverse
+            used_curr = (_k_elem >= self.threshold).sum()
+            usage = t.sum(usage)
+            dk = t.norm(self.k - old_k) / np.sqrt(np.prod(old_k.shape))
+        return dict(entropy=entropy,
+                    used_curr=used_curr,
+                    usage=usage,
+                    dk=dk)
+
+    def preprocess(self, x):
+        # NCT -> NTC -> [NT, C]
+        x = x.permute(0, 2, 1).contiguous()
+        x = x.view(-1, x.shape[-1])  # x_en = (N * L, w), k_j = (w, k_bins)
+
+        if x.shape[-1] == self.emb_width:
+            prenorm = t.norm(x - t.mean(x)) / np.sqrt(np.prod(x.shape))
+        elif x.shape[-1] == 2 * self.emb_width:
+            x1, x2 = x[..., :self.emb_width], x[..., self.emb_width:]
+            prenorm = (t.norm(x1 - t.mean(x1)) / np.sqrt(np.prod(x1.shape))) + (
+                t.norm(x2 - t.mean(x2)) / np.sqrt(np.prod(x2.shape)))
+
+            # Normalise
+            x = x1 + x2
+        else:
+            assert False, f"Expected {x.shape[-1]} to be (1 or 2) * {self.emb_width}"
+        return x, prenorm
+
+    def postprocess(self, x_l, x_d, x_shape):
+        # [NT, C] -> NTC -> NCT
+        N, T = x_shape
+        x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()
+        x_l = x_l.view(N, T)
+        return x_l, x_d
+
+    def quantise(self, x):
+        # Calculate latent code x_l
+        k_w = self.k.t()
+        distance = t.sum(x ** 2, dim=-1, keepdim=True) - 2 * t.matmul(x, k_w) + t.sum(k_w ** 2, dim=0,
+                                                                                      keepdim=True)  # (N * L, b)
+        min_distance, x_l = t.min(distance, dim=-1)
+        fit = t.mean(min_distance)
+        return x_l, fit
+
+    def dequantise(self, x_l):
+        x = F.embedding(x_l, self.k)
+        return x
+
+    def encode(self, x):
+        N, width, T = x.shape
+
+        # Preprocess.
+        x, prenorm = self.preprocess(x)
+
+        # Quantise
+        x_l, fit = self.quantise(x)
+
+        # Postprocess.
+        x_l = x_l.view(N, T)
+        return x_l
+
+    def decode(self, x_l):
+        N, T = x_l.shape
+        width = self.emb_width
+
+        # Dequantise
+        x_d = self.dequantise(x_l)
+
+        # Postprocess
+        x_d = x_d.view(N, T, width).permute(0, 2, 1).contiguous()
+        return x_d
+
+    def forward(self, x, update_k=True):
+        N, width, T = x.shape
+
+        # Preprocess
+        x, prenorm = self.preprocess(x)
+
+        # Init k if not inited
+        if update_k and not self.init:
+            self.init_k(x)
+
+        # Quantise and dequantise through bottleneck
+        x_l, fit = self.quantise(x)
+        x_d = self.dequantise(x_l)
+
+        # Update embeddings
+        if update_k and self.training:
+            update_metrics = self.update_k(x, x_l)
+        else:
+            update_metrics = {}
+
+        # Loss
+        commit_loss = t.norm(x_d.detach() - x) ** 2 / np.prod(x.shape)
+
+        # Passthrough
+        x_d = x + (x_d - x).detach()
+
+        # Postprocess
+        x_l, x_d = self.postprocess(x_l, x_d, (N, T))
+        return x_l, x_d, commit_loss, dict(fit=fit,
+                                           pn=prenorm,
+                                           **update_metrics)
+
+
+class Bottleneck(nn.Module):
+    def __init__(self, l_bins, emb_width, mu, levels):
+        super().__init__()
+        self.levels = levels
+        level_block = lambda level: BottleneckBlock(l_bins, emb_width, mu)
+        self.level_blocks = nn.ModuleList()
+        for level in range(self.levels):
+            self.level_blocks.append(level_block(level))
+
+    def encode(self, xs):
+        zs = [level_block.encode(x) for (level_block, x) in zip(self.level_blocks, xs)]
+        return zs
+
+    def decode(self, zs, start_level=0, end_level=None):
+        if end_level is None:
+            end_level = self.levels
+        xs_quantised = [level_block.decode(z) for (level_block, z) in zip(self.level_blocks[start_level:end_level], zs)]
+        return xs_quantised
+
+    def forward(self, xs):
+        zs, xs_quantised, commit_losses, metrics = [], [], [], []
+        for level in range(self.levels):
+            level_block = self.level_blocks[level]
+            x = xs[level]
+            z, x_quantised, commit_loss, metric = level_block(x, update_k=self.training)
+            zs.append(z)
+            if not self.training:
+                # Be extra paranoid and make sure the encoder weights can't
+                # change from straight-through estimator
+                x_quantised = x_quantised.detach()
+            xs_quantised.append(x_quantised)
+            commit_losses.append(commit_loss)
+            if self.training:
+                metrics.append(metric)
+        return zs, xs_quantised, commit_losses, metrics
+
+
+class NoBottleneckBlock(nn.Module):
+    def restore_k(self):
+        pass
+
+
+class NoBottleneck(nn.Module):
+    def __init__(self, levels):
+        super().__init__()
+        self.level_blocks = nn.ModuleList()
+        self.levels = levels
+        for level in range(levels):
+            self.level_blocks.append(NoBottleneckBlock())
+
+    def encode(self, xs):
+        return xs
+
+    def decode(self, zs, start_level=0, end_level=None):
+        if end_level is None:
+            end_level = self.levels
+        return zs
+
+    def forward(self, xs):
+        zero = t.zeros(()).cuda()
+        commit_losses = [zero for _ in range(self.levels)]
+        metrics = [dict(entropy=zero, usage=zero, used_curr=zero, pn=zero, dk=zero) for _ in range(self.levels)]
+        return xs, xs, commit_losses, metrics
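A minimal sketch of the VQ bottleneck in isolation, with illustrative sizes; a CUDA device is assumed because reset_k allocates the codebook with .cuda(). In training mode the first forward pass initialises the codebook from the batch (init_k) and later passes update it with the exponential-moving-average rule in update_k, while the straight-through line x + (x_d - x).detach() lets gradients reach the continuous encoder output even though the nearest-neighbour code lookup itself is not differentiable.

import torch

from modules.vq import Bottleneck

vq = Bottleneck(l_bins=32, emb_width=64, mu=0.99, levels=1).cuda()
vq.train()

h = torch.randn(2, 64, 40, device='cuda', requires_grad=True)  # (batch, emb_width, frames)
zs, hs_q, commit_losses, metrics = vq([h])

print(zs[0].shape)              # torch.Size([2, 40]) -- discrete code indices
print(hs_q[0].shape)            # torch.Size([2, 64, 40]) -- quantised latents
print(float(commit_losses[0]))  # commitment loss term for the training objective
print(sorted(metrics[0]))       # ['dk', 'entropy', 'fit', 'pn', 'usage', 'used_curr']

# Straight-through estimator: gradients on the quantised output reach the continuous input h.
hs_q[0].sum().backward()
print(h.grad.abs().sum() > 0)   # tensor(True, device='cuda:0')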