from dataclasses import dataclass
from typing import Sequence

import torch
from torch import nn

from .components.init import small_init_init_
from .utils import WeightDecayOptimGroupMixin
from .xlstm_block_stack import xLSTMBlockStack, xLSTMBlockStackConfig
|
|
@dataclass
class xLSTMLMModelConfig(xLSTMBlockStackConfig):
    vocab_size: int = -1  # size of the token vocabulary; must be set to a positive value
    tie_weights: bool = False  # share the token embedding matrix with the LM head
    weight_decay_on_embedding: bool = False  # apply weight decay to the embedding weights
    add_embedding_dropout: bool = False  # apply dropout to the token embeddings
|
|
class xLSTMLMModel(WeightDecayOptimGroupMixin, nn.Module):
    """xLSTM language model: token embedding -> dropout -> xLSTM block stack -> LM head."""

    config_class = xLSTMLMModelConfig

    def __init__(self, config: xLSTMLMModelConfig, **kwargs):
        super().__init__()
        self.config = config

        self.xlstm_block_stack = xLSTMBlockStack(config=config)
        self.token_embedding = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.embedding_dim)
        self.emb_dropout = nn.Dropout(config.dropout) if config.add_embedding_dropout else nn.Identity()

        self.lm_head = nn.Linear(
            in_features=config.embedding_dim,
            out_features=config.vocab_size,
            bias=False,
        )
        # Weight tying: reuse the embedding matrix as the output projection.
        if config.tie_weights:
            self.lm_head.weight = self.token_embedding.weight
|
    def reset_parameters(self):
        self.xlstm_block_stack.reset_parameters()

        small_init_init_(self.token_embedding.weight, dim=self.config.embedding_dim)

        if not self.config.tie_weights:
            small_init_init_(self.lm_head.weight, dim=self.config.embedding_dim)
|
    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # idx: (batch, seq_len) token ids -> logits: (batch, seq_len, vocab_size)
        x = self.token_embedding(idx)
        x = self.emb_dropout(x)
        x = self.xlstm_block_stack(x)
        logits = self.lm_head(x)
        return logits
|
    def step(
        self, idx: torch.Tensor, state: dict[str, dict[str, tuple[torch.Tensor, ...]]] | None = None, **kwargs
    ) -> tuple[torch.Tensor, dict[str, dict[str, tuple[torch.Tensor, ...]]]]:
        # Recurrent single-step inference: processes the given token position and
        # threads the per-block recurrent state through the stack.
        x = self.token_embedding(idx)
        x = self.emb_dropout(x)
        x, state = self.xlstm_block_stack.step(x, state=state, **kwargs)
        logits = self.lm_head(x)
        return logits, state
|
    def _create_weight_decay_optim_groups(self, **kwargs) -> tuple[Sequence[nn.Parameter], Sequence[nn.Parameter]]:
        weight_decay, no_weight_decay = super()._create_weight_decay_optim_groups(**kwargs)

        # Remove the token embedding from the weight-decay group and re-assign it
        # to the group selected by the config.
        weight_decay = list(weight_decay)
        removed = 0
        for idx in range(len(weight_decay)):
            if weight_decay[idx - removed] is self.token_embedding.weight:
                weight_decay.pop(idx - removed)
                removed += 1
        weight_decay = tuple(weight_decay)
        if self.config.weight_decay_on_embedding:
            weight_decay += (self.token_embedding.weight,)
        else:
            no_weight_decay += (self.token_embedding.weight,)

        return weight_decay, no_weight_decay
|
|
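# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hedged example of
# how this model might be instantiated and run. It assumes that mLSTMBlockConfig
# lives in .blocks.mlstm.block and that its defaults yield a valid block
# configuration; run this file as a module inside its package so the relative
# imports resolve, and adapt the config to your installed version if needed.
if __name__ == "__main__":
    from .blocks.mlstm.block import mLSTMBlockConfig  # assumed import path

    cfg = xLSTMLMModelConfig(
        vocab_size=1024,  # the default of -1 is a placeholder and must be overridden
        embedding_dim=128,
        num_blocks=2,
        context_length=64,
        mlstm_block=mLSTMBlockConfig(),
    )
    model = xLSTMLMModel(cfg)
    model.reset_parameters()

    tokens = torch.randint(0, cfg.vocab_size, size=(2, cfg.context_length))  # (batch, seq_len)
    logits = model(tokens)  # (batch, seq_len, vocab_size)
    print(logits.shape)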