import torch
import torch.nn as nn

class NdLinear(nn.Module):
    def __init__(self, input_dims: tuple, hidden_size: tuple, transform_outer: bool = True, bias: bool = True):
        """
        NdLinear: A PyTorch layer for projecting tensors into multi-space representations.

        Unlike conventional embedding layers that map into a single vector space, NdLinear
        transforms tensors across a collection of vector spaces, capturing multivariate
        structure and topical information that standard deep learning architectures
        typically lose.

        Args:
            input_dims (tuple): Shape of the input tensor, excluding the batch dimension.
            hidden_size (tuple): Target hidden dimensions after transformation; must have
                the same length as input_dims.
            transform_outer (bool): If True, transform axes from the outermost to the
                innermost; if False, transform them in the reverse order.
            bias (bool): If True, each per-axis linear layer includes a learnable bias.
        """
        super().__init__()

        if len(input_dims) != len(hidden_size):
            raise ValueError("input_dims and hidden_size must have the same number of dimensions.")

        self.input_dims = input_dims
        self.hidden_size = hidden_size
        self.num_layers = len(input_dims)
        self.transform_outer = transform_outer
        self.bias = bias

        # One independent nn.Linear per input axis: axis i is projected from
        # input_dims[i] to hidden_size[i].
        self.align_layers = nn.ModuleList([
            nn.Linear(input_dims[i], hidden_size[i], bias=self.bias)
            for i in range(self.num_layers)
        ])
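
        # Rough parameter comparison (illustrative numbers, not from the original
        # file): mapping (28, 28, 3) -> (64, 64, 8) with a single flat nn.Linear
        # would take 2352 * 32768 ≈ 77M weights, whereas the factorized per-axis
        # layers above take 28*64 + 28*64 + 3*8 = 3,608 weights (plus biases).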

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Forward pass: project the input tensor into a new multi-space representation.

        For each axis in turn, the layer transposes that axis to the last position,
        flattens all leading axes into the batch dimension, applies the corresponding
        linear layer, and restores the original axis order.

        Expected input shape: [batch_size, *input_dims]
        Output shape: [batch_size, *hidden_size]

        Args:
            X (torch.Tensor): Input tensor with shape [batch_size, *input_dims].

        Returns:
            torch.Tensor: Output tensor with shape [batch_size, *hidden_size].
        """
        num_transforms = self.num_layers
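
        # Shape trace for one step (illustrative, assuming input_dims=(H, W) and
        # transform_outer=True): [B, H, W] -> transpose -> [B, W, H] -> flatten ->
        # [B*W, H] -> linear -> [B*W, H'] -> unflatten -> [B, W, H'] -> transpose
        # back -> [B, H', W]; the next iteration then transforms the W axis.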

        for i in range(num_transforms):
            if self.transform_outer:
                # Walk the axes left to right; layer i transforms input axis i.
                layer = self.align_layers[i]
                transpose_dim = i + 1
            else:
                # Walk the axes right to left.
                layer = self.align_layers[num_transforms - (i + 1)]
                transpose_dim = num_transforms - i

            # Swap the target axis with the last axis so nn.Linear can act on it.
            X = torch.transpose(X, transpose_dim, num_transforms).contiguous()

            # Remember the leading shape, then flatten everything but the last axis.
            X_size = X.shape[:-1]
            X = X.view(-1, X.shape[-1])

            # Apply the per-axis linear projection.
            X = layer(X)

            # Restore the leading shape, then move the transformed axis back in place.
            X = X.view(*X_size, X.shape[-1])
            X = torch.transpose(X, transpose_dim, num_transforms).contiguous()

        return X
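

# Minimal usage sketch (not part of the original file; shapes are illustrative).
# Projects a batch of 4 tensors shaped (28, 28, 3) into (64, 64, 8), one axis
# at a time, using three small nn.Linear layers instead of one large one.
if __name__ == "__main__":
    layer = NdLinear(input_dims=(28, 28, 3), hidden_size=(64, 64, 8))
    x = torch.randn(4, 28, 28, 3)
    y = layer(x)
    print(y.shape)  # torch.Size([4, 64, 64, 8])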