import torch
from torch import nn

from model.cache_utils import FLACache
from model.config import PlayHeadConfig
from model.prediction_head import CircularHead, LogitsHead
from model.simple_gla import SimpleGLABlock


class PlayHead(nn.Module):
    def __init__(self, cfg: PlayHeadConfig):
        super().__init__()
        self.cycle_len = cfg.cycle_len
        self.num_sink_tokens = cfg.num_sink_tokens
        # One learned embedding per cyclic position, plus one per sink token.
        self.pos_embedding = nn.Embedding(cfg.num_sink_tokens + cfg.cycle_len, cfg.dim)
        self.avg_pool_stride = cfg.avg_pool_stride
        self.net = nn.ModuleList(
            [
                SimpleGLABlock(
                    dim=cfg.dim,
                    num_heads=cfg.dim // 128,
                    layer_idx=i,
                    expand_k=0.5,
                    expand_v=1.0,
                    use_short_conv=True,
                    ffn_expansion_factor=4,
                )
                for i in range(cfg.num_layers)
            ]
        )
        self.logits_head = (
            LogitsHead(cfg.dim, cfg.cycle_len) if cfg.logits_head else None
        )
        self.circular_head = CircularHead(cfg.dim) if cfg.circular_head else None

    def forward(
        self,
        cross_attention_weights: torch.Tensor,
        target: torch.Tensor,
        mask: torch.Tensor | None = None,
    ):
        B, T, A = cross_attention_weights.shape
        # if self.cross_attention_reduction == "sum":
        #     cross_attention_weights = cross_attention_weights.sum(1)
        device = cross_attention_weights.device
        # Position ids repeat every `cycle_len` steps; sink tokens get their own
        # ids placed after the cyclic range in the embedding table.
        pos = torch.arange(T - self.num_sink_tokens).to(device) % self.cycle_len
        sink = torch.arange(self.num_sink_tokens).to(device) + self.cycle_len
        sink_and_pos_embd = self.pos_embedding(torch.cat((sink, pos))[None])
        # Aggregate positional embeddings with the attention map: (B, T, A) -> (B, A, dim).
        x = cross_attention_weights.transpose(-1, -2) @ sink_and_pos_embd
        for block in self.net:
            x = block(x)
        losses = dict()
        if self.logits_head is not None:
            losses |= self.logits_head.compute_loss(x, target.long(), mask=mask)
        if self.circular_head is not None:
            losses |= self.circular_head.compute_loss(x, target, mask=mask)
        return losses

    def init_cache(self):
        return FLACache(num_states=len(self.net))

    def predict(
        self,
        cross_attention_weights: torch.Tensor,
        previous_position: torch.Tensor | None = None,
        cache: FLACache | None = None,
    ):
        # Downsample the non-sink part of the attention map along the time axis.
        avg_pool_ca = torch.nn.functional.avg_pool1d(
            cross_attention_weights[:, self.num_sink_tokens :].transpose(-1, -2),
            self.avg_pool_stride,
            stride=self.avg_pool_stride,
            ceil_mode=True,
        ).transpose(-1, -2)
        sink_ca = cross_attention_weights[:, : self.num_sink_tokens]
        cross_attention_weights = torch.cat((sink_ca, avg_pool_ca), dim=1)
        B, T, A = cross_attention_weights.shape
        device = cross_attention_weights.device
        pos = torch.arange(T - self.num_sink_tokens).to(device) % self.cycle_len
        sink = torch.arange(self.num_sink_tokens).to(device) + self.cycle_len
        sink_and_pos_embd = self.pos_embedding(torch.cat((sink, pos))[None])
        x = cross_attention_weights.transpose(-1, -2) @ sink_and_pos_embd
        for block in self.net:
            x = block(x, cache=cache)
        if self.logits_head is not None:
            logits = self.logits_head(x)
            pred_position = torch.argmax(logits, -1)
            if previous_position is not None:
                # Map both positions onto the unit circle and take the phase
                # difference so the step count wraps correctly around the cycle.
                current_angle, previous_angle = map(
                    lambda p: torch.exp(1j * 2 * torch.pi * p / self.cycle_len),
                    (pred_position, previous_position),
                )
                diff = current_angle / previous_angle
                step = (diff.angle() / (2 * torch.pi / self.cycle_len)).round().long()
                return pred_position, step
            return pred_position
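

# --- Minimal usage sketch (not part of the original module) -----------------
# Hedged example only: it assumes PlayHeadConfig can be constructed with the
# fields read above (dim, num_layers, cycle_len, num_sink_tokens,
# avg_pool_stride, logits_head, circular_head) and that the prediction heads
# accept targets shaped (B, A). The concrete values and tensor shapes below
# are placeholders, not the authors' settings.
if __name__ == "__main__":
    cfg = PlayHeadConfig(
        dim=256,
        num_layers=2,
        cycle_len=16,
        num_sink_tokens=4,
        avg_pool_stride=2,
        logits_head=True,
        circular_head=False,
    )
    head = PlayHead(cfg)

    # Dummy cross-attention weights (B, T, A) and per-unit cyclic targets.
    B, T, A = 2, cfg.num_sink_tokens + 3 * cfg.cycle_len, 8
    cross_attention_weights = torch.softmax(torch.randn(B, T, A), dim=1)
    target = torch.randint(0, cfg.cycle_len, (B, A))

    # Training-style call: returns a dict of per-head losses.
    losses = head(cross_attention_weights, target)

    # Inference-style call: returns the predicted cyclic position per attended
    # unit, plus a signed step when a previous position is provided.
    cache = head.init_cache()
    pred = head.predict(cross_attention_weights, cache=cache)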