import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F

import modules.dist as dist


class BottleneckBlock(nn.Module):
    def __init__(self, k_bins, emb_width, mu):
        super().__init__()
        self.k_bins = k_bins          # codebook size
        self.emb_width = emb_width    # dimension of each codebook vector
        self.mu = mu                  # EMA decay for codebook updates
        self.reset_k()
        self.threshold = 1.0          # minimum EMA usage for a code to stay alive

    def reset_k(self):
        self.init = False
        self.k_sum = None
        self.k_elem = None
        self.register_buffer('k', t.zeros(self.k_bins, self.emb_width).cuda())

    def _tile(self, x):
        # Repeat x (with a little noise) until it has at least k_bins rows,
        # so a random subset can be used to seed or refresh the codebook.
        d, ew = x.shape
        if d < self.k_bins:
            n_repeats = (self.k_bins + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + t.randn_like(x) * std
        return x

    def init_k(self, x):
        mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins
        self.init = True
        # Initialise the codebook with k_bins random vectors drawn from x
        y = self._tile(x)
        _k_rand = y[t.randperm(y.shape[0])][:k_bins]
        dist.broadcast(_k_rand, 0)
        self.k = _k_rand
        assert self.k.shape == (k_bins, emb_width)
        self.k_sum = self.k
        self.k_elem = t.ones(k_bins, device=self.k.device)

    def restore_k(self, num_tokens=None, threshold=1.0):
        # Rebuild the EMA statistics from an already-populated codebook
        # (e.g. after loading a checkpoint).
        mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins
        self.init = True
        assert self.k.shape == (k_bins, emb_width)
        self.k_sum = self.k.clone()
        self.k_elem = t.ones(k_bins, device=self.k.device)
        if num_tokens is not None:
            expected_usage = num_tokens / k_bins
            self.k_elem.data.mul_(expected_usage)
            self.k_sum.data.mul_(expected_usage)
        self.threshold = threshold

    def update_k(self, x, x_l):
        mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins
        with t.no_grad():
            # Calculate new centres
            x_l_onehot = t.zeros(k_bins, x.shape[0], device=x.device)  # (k_bins, NT)
            x_l_onehot.scatter_(0, x_l.view(1, x.shape[0]), 1)

            _k_sum = t.matmul(x_l_onehot, x)     # (k_bins, emb_width)
            _k_elem = x_l_onehot.sum(dim=-1)     # (k_bins,)
            y = self._tile(x)
            _k_rand = y[t.randperm(y.shape[0])][:k_bins]

            dist.broadcast(_k_rand, 0)
            dist.all_reduce(_k_sum)
            dist.all_reduce(_k_elem)

            # Update centres with an EMA; dead codes (usage below threshold)
            # are re-seeded with random encoder outputs.
            old_k = self.k
            self.k_sum = mu * self.k_sum + (1. - mu) * _k_sum
            self.k_elem = mu * self.k_elem + (1. - mu) * _k_elem
            usage = (self.k_elem.view(k_bins, 1) >= self.threshold).float()
            self.k = usage * (self.k_sum.view(k_bins, emb_width) / self.k_elem.view(k_bins, 1)) \
                     + (1 - usage) * _k_rand
            _k_prob = _k_elem / t.sum(_k_elem)  # probability of each bin in the current batch
            entropy = -t.sum(_k_prob * t.log(_k_prob + 1e-8))  # codebook diversity
            used_curr = (_k_elem >= self.threshold).sum()
            usage = t.sum(usage)
            dk = t.norm(self.k - old_k) / np.sqrt(np.prod(old_k.shape))
        return dict(entropy=entropy,
                    used_curr=used_curr,
                    usage=usage,
                    dk=dk)

    def preprocess(self, x):
        # NCT -> NTC -> [NT, C]
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(-1, x.shape[-1])

        if x.shape[-1] == self.emb_width:
            prenorm = t.norm(x - t.mean(x)) / np.sqrt(np.prod(x.shape))
        elif x.shape[-1] == 2 * self.emb_width:
            x1, x2 = x[..., :self.emb_width], x[..., self.emb_width:]
            prenorm = (t.norm(x1 - t.mean(x1)) / np.sqrt(np.prod(x1.shape))) + (
                t.norm(x2 - t.mean(x2)) / np.sqrt(np.prod(x2.shape)))

            # Combine the two halves into a single emb_width-dim input
            x = x1 + x2
        else:
            assert False, f"Expected {x.shape[-1]} to be (1 or 2) * {self.emb_width}"
        return x, prenorm

    def postprocess(self, x_l, x_d, x_shape):
        # [NT, C] -> NTC -> NCT
        N, T = x_shape
        x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()
        x_l = x_l.view(N, T)
        return x_l, x_d

    def quantise(self, x):
        # Assign each vector in x to its nearest codebook entry (squared L2 distance)
        k_w = self.k.t()
        distance = t.sum(x ** 2, dim=-1, keepdim=True) - 2 * t.matmul(x, k_w) \
                   + t.sum(k_w ** 2, dim=0, keepdim=True)  # (NT, k_bins)
        min_distance, x_l = t.min(distance, dim=-1)
        fit = t.mean(min_distance)
        return x_l, fit

    def dequantise(self, x_l):
        x = F.embedding(x_l, self.k)
        return x

    def encode(self, x):
        N, width, T = x.shape

        # Preprocess
        x, prenorm = self.preprocess(x)

        # Quantise
        x_l, fit = self.quantise(x)

        # Postprocess
        x_l = x_l.view(N, T)
        return x_l

    def decode(self, x_l):
        N, T = x_l.shape
        width = self.emb_width

        # Dequantise
        x_d = self.dequantise(x_l)

        # Postprocess
        x_d = x_d.view(N, T, width).permute(0, 2, 1).contiguous()
        return x_d

    def forward(self, x, update_k=True):
        N, width, T = x.shape

        # Preprocess
        x, prenorm = self.preprocess(x)

        # Init codebook if not yet initialised
        if update_k and not self.init:
            self.init_k(x)

        # Quantise and dequantise through bottleneck
        x_l, fit = self.quantise(x)
        x_d = self.dequantise(x_l)

        # Update codebook
        if update_k and self.training:
            update_metrics = self.update_k(x, x_l)
        else:
            update_metrics = {}

        # Commitment loss
        commit_loss = t.norm(x_d.detach() - x) ** 2 / np.prod(x.shape)

        # Straight-through estimator: gradients flow from x_d directly to x
        x_d = x + (x_d - x).detach()

        # Postprocess
        x_l, x_d = self.postprocess(x_l, x_d, (N, T))
        return x_l, x_d, commit_loss, dict(fit=fit,
                                           pn=prenorm,
                                           **update_metrics)


class Bottleneck(nn.Module):
    def __init__(self, l_bins, emb_width, mu, levels):
        super().__init__()
        self.levels = levels
        level_block = lambda level: BottleneckBlock(l_bins, emb_width, mu)
        self.level_blocks = nn.ModuleList()
        for level in range(self.levels):
            self.level_blocks.append(level_block(level))

    def encode(self, xs):
        zs = [level_block.encode(x) for (level_block, x) in zip(self.level_blocks, xs)]
        return zs

    def decode(self, zs, start_level=0, end_level=None):
        if end_level is None:
            end_level = self.levels
        xs_quantised = [level_block.decode(z) for (level_block, z) in zip(self.level_blocks[start_level:end_level], zs)]
        return xs_quantised

    def forward(self, xs):
        zs, xs_quantised, commit_losses, metrics = [], [], [], []
        for level in range(self.levels):
            level_block = self.level_blocks[level]
            x = xs[level]
            z, x_quantised, commit_loss, metric = level_block(x, update_k=self.training)
            zs.append(z)
            if not self.training:
                # At eval time, make sure no gradients flow back through the
                # straight-through estimator into the encoder.
                x_quantised = x_quantised.detach()
            xs_quantised.append(x_quantised)
            commit_losses.append(commit_loss)
            if self.training:
                metrics.append(metric)
        return zs, xs_quantised, commit_losses, metrics


class NoBottleneckBlock(nn.Module):
    def restore_k(self):
        pass


class NoBottleneck(nn.Module):
    def __init__(self, levels):
        super().__init__()
        self.level_blocks = nn.ModuleList()
        self.levels = levels
        for level in range(levels):
            self.level_blocks.append(NoBottleneckBlock())

    def encode(self, xs):
        return xs

    def decode(self, zs, start_level=0, end_level=None):
        if end_level is None:
            end_level = self.levels
        return zs

    def forward(self, xs):
        zero = t.zeros(()).cuda()
        commit_losses = [zero for _ in range(self.levels)]
        metrics = [dict(entropy=zero, usage=zero, used_curr=zero, pn=zero, dk=zero) for _ in range(self.levels)]
        return xs, xs, commit_losses, metrics
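

# The block below is a usage sketch, not part of the original module. It assumes a
# CUDA device (reset_k and NoBottleneck.forward call .cuda() unconditionally) and uses
# illustrative hyperparameters (l_bins=2048, emb_width=64, mu=0.99, levels=2, a batch
# of 4 sequences of length 128). The model is kept in eval mode so that no codebook
# updates, and hence no modules.dist collectives, are triggered.
if __name__ == '__main__':
    if t.cuda.is_available():
        bottleneck = Bottleneck(l_bins=2048, emb_width=64, mu=0.99, levels=2).cuda().eval()
        # One (N, emb_width, T) activation tensor per level, as produced by the encoder.
        xs = [t.randn(4, 64, 128).cuda() for _ in range(2)]
        zs, xs_quantised, commit_losses, metrics = bottleneck(xs)
        print([z.shape for z in zs])             # discrete codes, each (N, T)
        print([x.shape for x in xs_quantised])   # quantised activations, each (N, emb_width, T)
        print([cl.item() for cl in commit_losses])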