import torch
import torch.nn as nn
import torch.optim as optim
class NdLinear(nn.Module):
    def __init__(self, input_dims: tuple, hidden_size: tuple, transform_outer: bool = True, bias: bool = True):
        """
        NdLinear: A PyTorch layer for projecting tensors into multi-space representations.

        Unlike conventional embedding layers that map into a single vector space, NdLinear
        transforms tensors across a collection of vector spaces, capturing multivariate
        structure and topical information that standard deep learning architectures
        typically lose.

        Args:
            input_dims (tuple): Shape of the input tensor (excluding the batch dimension).
            hidden_size (tuple): Target hidden dimensions after transformation; must have
                the same length as `input_dims`.
            transform_outer (bool): If True, transform dimensions from the outermost (first)
                to the innermost (last); otherwise transform in the reverse order.
            bias (bool): Whether each per-dimension linear layer includes a bias term.
        """
        super().__init__()

        if len(input_dims) != len(hidden_size):
            raise ValueError("input_dims and hidden_size must have the same number of dimensions.")

        self.input_dims = input_dims
        self.hidden_size = hidden_size
        self.num_layers = len(input_dims)  # One linear layer per (non-batch) dimension.
        self.transform_outer = transform_outer
        self.bias = bias

        # One nn.Linear per dimension: layer i maps size input_dims[i] -> hidden_size[i].
        self.align_layers = nn.ModuleList([
            nn.Linear(input_dims[i], hidden_size[i], bias=self.bias) for i in range(self.num_layers)
        ])
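
    # Illustrative note (an added example with assumed sizes, not from the
    # original file): for input_dims=(28, 28) and hidden_size=(64, 64),
    # NdLinear holds two 28x64 weight matrices (plus biases), whereas a flat
    # nn.Linear(28 * 28, 64 * 64) would hold a single 784x4096 matrix.
    # Factorizing the transformation per axis keeps the parameter count small
    # while preserving the tensor's multi-dimensional structure.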

    def forward(self, X):
        """
        Forward pass to project the input tensor into a new multi-space representation.

        Incrementally transposes each target dimension to the last position, flattens,
        applies the matching linear layer, restores the shape, and transposes back.

        Expected Input Shape: [batch_size, *input_dims]
        Output Shape: [batch_size, *hidden_size]

        Args:
            X (torch.Tensor): Input tensor with shape [batch_size, *input_dims].

        Returns:
            torch.Tensor: Output tensor with shape [batch_size, *hidden_size].
        """
        num_transforms = self.num_layers

        for i in range(num_transforms):
            if self.transform_outer:
                # Outer-to-inner order: dimension 1 first, then 2, and so on.
                layer = self.align_layers[i]
                transpose_dim = i + 1
            else:
                # Inner-to-outer order: the last dimension first, working backwards.
                layer = self.align_layers[num_transforms - (i + 1)]
                transpose_dim = num_transforms - i

            # Move the dimension being transformed to the last position.
            # (X has num_transforms + 1 dims, so index num_transforms is the last.)
            X = torch.transpose(X, transpose_dim, num_transforms).contiguous()

            # Remember the leading shape, flatten everything except the last
            # dimension, and apply the per-dimension linear transformation.
            X_size = X.shape[:-1]
            X = X.view(-1, X.shape[-1])
            X = layer(X)

            # Restore the leading shape (now with the new embedding size) and
            # move the transformed dimension back to its original position.
            X = X.view(*X_size, X.shape[-1])
            X = torch.transpose(X, transpose_dim, num_transforms).contiguous()

        return X
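

# A minimal usage sketch (an addition, not part of the original layer): it
# checks the documented shape contract and runs one optimizer step to show
# the layer training end-to-end. All sizes below are arbitrary examples.
if __name__ == "__main__":
    batch_size = 4
    layer = NdLinear(input_dims=(8, 16), hidden_size=(32, 24))

    x = torch.randn(batch_size, 8, 16)
    y = layer(x)
    assert y.shape == (batch_size, 32, 24)  # [batch_size, *hidden_size]

    # One gradient step against a dummy target, using the optim import above.
    optimizer = optim.SGD(layer.parameters(), lr=0.01)
    target = torch.zeros_like(y)
    loss = nn.functional.mse_loss(layer(x), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"loss after one step: {loss.item():.4f}")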