import torch
from torch import nn

from model.cache_utils import FLACache
from model.config import PlayHeadConfig
from model.prediction_head import CircularHead, LogitsHead
from model.simple_gla import SimpleGLABlock


class PlayHead(nn.Module):
    """Predicts a cyclic position from cross-attention weights.

    Attention weights over the sink-and-position axis are projected onto
    learned positional embeddings, refined by a stack of SimpleGLABlock
    layers, and decoded by a logits head (classification over cycle
    positions) and/or a circular regression head.
    """

    def __init__(self, cfg: PlayHeadConfig):
        super().__init__()
        self.cycle_len = cfg.cycle_len
        self.num_sink_tokens = cfg.num_sink_tokens
        # One embedding per cycle position, plus one per sink token.
        self.pos_embedding = nn.Embedding(cfg.num_sink_tokens + cfg.cycle_len, cfg.dim)
        self.avg_pool_stride = cfg.avg_pool_stride
        self.net = nn.ModuleList(
            [
                SimpleGLABlock(
                    dim=cfg.dim,
                    num_heads=cfg.dim // 128,
                    layer_idx=i,
                    expand_k=0.5,
                    expand_v=1.0,
                    use_short_conv=True,
                    ffn_expansion_factor=4,
                )
                for i in range(cfg.num_layers)
            ]
        )
        self.logits_head = (
            LogitsHead(cfg.dim, cfg.cycle_len) if cfg.logits_head else None
        )
        self.circular_head = CircularHead(cfg.dim) if cfg.circular_head else None

    def forward(
        self,
        cross_attention_weights: torch.Tensor,
        target: torch.Tensor,
        mask: torch.Tensor | None = None,
    ):
        B, T, A = cross_attention_weights.shape
        # if self.cross_attention_reduction == "sum":
        #     cross_attention_weights = cross_attention_weights.sum(1)
        device = cross_attention_weights.device
        # Cycle positions wrap modulo cycle_len; sink tokens use the trailing
        # embedding indices [cycle_len, cycle_len + num_sink_tokens).
        pos = torch.arange(T - self.num_sink_tokens, device=device) % self.cycle_len
        sink = torch.arange(self.num_sink_tokens, device=device) + self.cycle_len
        sink_and_pos_embd = self.pos_embedding(torch.cat((sink, pos))[None])
        # (B, A, T) @ (1, T, dim) -> (B, A, dim): each output frame gets an
        # attention-weighted mixture of positional embeddings.
        x = cross_attention_weights.transpose(-1, -2) @ sink_and_pos_embd
        for block in self.net:
            x = block(x)
        losses = dict()
        if self.logits_head is not None:
            losses |= self.logits_head.compute_loss(x, target.long(), mask=mask)
        if self.circular_head is not None:
            losses |= self.circular_head.compute_loss(x, target, mask=mask)
        return losses

    def init_cache(self):
        return FLACache(num_states=len(self.net))

    def predict(
        self,
        cross_attention_weights: torch.Tensor,
        previous_position: torch.Tensor | None = None,
        cache: FLACache | None = None,
    ):
        # Downsample the non-sink attention rows along T, keeping the sink
        # rows at full resolution.
        avg_pool_ca = torch.nn.functional.avg_pool1d(
            cross_attention_weights[:, self.num_sink_tokens :].transpose(-1, -2),
            self.avg_pool_stride,
            stride=self.avg_pool_stride,
            ceil_mode=True,
        ).transpose(-1, -2)
        sink_ca = cross_attention_weights[:, : self.num_sink_tokens]
        cross_attention_weights = torch.cat((sink_ca, avg_pool_ca), dim=1)
        B, T, A = cross_attention_weights.shape
        device = cross_attention_weights.device
        pos = torch.arange(T - self.num_sink_tokens, device=device) % self.cycle_len
        sink = torch.arange(self.num_sink_tokens, device=device) + self.cycle_len
        sink_and_pos_embd = self.pos_embedding(torch.cat((sink, pos))[None])
        x = cross_attention_weights.transpose(-1, -2) @ sink_and_pos_embd
        for block in self.net:
            x = block(x, cache=cache)
        if self.logits_head is not None:
            logits = self.logits_head(x)
            pred_position = torch.argmax(logits, -1)
            if previous_position is not None:
                # Compare positions on the unit circle so the step count is
                # signed and correct across the cycle_len -> 0 wrap-around.
                current_angle, previous_angle = map(
                    lambda p: torch.exp(1j * 2 * torch.pi * p / self.cycle_len),
                    (pred_position, previous_position),
                )
                diff = current_angle / previous_angle
                step = (diff.angle() / (2 * torch.pi / self.cycle_len)).round().long()
                return pred_position, step
            return pred_position
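

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the module). It assumes
# PlayHeadConfig is a dataclass-style config exposing the fields read in
# __init__ above; the concrete values below (dim=256, cycle_len=100, etc.)
# are hypothetical placeholders, not values from this repo.
#
#   cfg = PlayHeadConfig(
#       dim=256,            # must be a multiple of 128: num_heads = dim // 128
#       num_layers=2,
#       cycle_len=100,
#       num_sink_tokens=4,
#       avg_pool_stride=4,
#       logits_head=True,
#       circular_head=False,
#   )
#   head = PlayHead(cfg)
#
#   # Training: cross_attention_weights has shape (B, T, A), where the first
#   # num_sink_tokens entries along T are sink tokens; target holds positions
#   # in [0, cycle_len). forward() returns a dict of losses from the enabled
#   # prediction heads.
#   losses = head(cross_attention_weights, target)
#
#   # Streaming inference: reuse one cache across calls; passing the previous
#   # position additionally yields the signed step between predictions,
#   # computed robustly across the cyclic wrap-around.
#   cache = head.init_cache()
#   pred, step = head.predict(cross_attention_weights, previous_position, cache)
# ---------------------------------------------------------------------------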