# Copyright (c) NXAI GmbH and its affiliates 2024
# Maximilian Beck
from dataclasses import dataclass, field

from ..xlstm_block import xLSTMBlock, xLSTMBlockConfig
from .layer import mLSTMLayerConfig
@dataclass
class mLSTMBlockConfig:
    """Configuration for an ``mLSTMBlock``.

    Must be a dataclass: ``field(default_factory=...)`` only takes effect,
    and ``__post_init__`` is only called automatically, when the class is
    processed by ``@dataclass``.

    Attributes:
        mlstm: Configuration of the block's mLSTM layer. A fresh
            ``mLSTMLayerConfig`` is created for every instance (via
            ``default_factory``), so instances never share one layer config.
        _num_blocks: Total number of blocks. Copied onto ``mlstm`` in
            ``__post_init__``. Defaults to ``None``.
        _block_idx: Index of this block. Defaults to ``None``.
    """

    mlstm: mLSTMLayerConfig = field(default_factory=mLSTMLayerConfig)

    # we initialize these with None to catch the case where they are not set
    # NOTE(review): this class does not itself check for None; presumably a
    # consumer of the config does — confirm.
    _num_blocks: int = None
    _block_idx: int = None

    def __post_init__(self):
        # Forward the block count to the layer config, then re-run the layer
        # config's own __post_init__ (presumably so it can refresh any values
        # derived from _num_blocks — see mLSTMLayerConfig).
        self.mlstm._num_blocks = self._num_blocks
        self.mlstm.__post_init__()
class mLSTMBlock(xLSTMBlock):
    """An ``xLSTMBlock`` built from an ``mLSTMBlockConfig``.

    The mLSTM-specific config is translated into the generic
    ``xLSTMBlockConfig`` that ``xLSTMBlock`` expects:
    - the mLSTM layer config is passed through;
    - ``slstm`` and ``feedforward`` are set to ``None``;
    - the block count and index are copied over unchanged.
    """

    config_class = mLSTMBlockConfig

    def __init__(self, config: mLSTMBlockConfig) -> None:
        # Build the generic block config first, then hand it to the base
        # class. Passing None for slstm/feedforward presumably means those
        # sub-modules are omitted — confirm against xLSTMBlock.
        generic_config = xLSTMBlockConfig(
            mlstm=config.mlstm,
            slstm=None,
            feedforward=None,
            _num_blocks=config._num_blocks,
            _block_idx=config._block_idx,
        )
        super().__init__(config=generic_config)