File size: 3,190 Bytes
03941e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
import torch.nn as nn
import torch.optim as optim


class NdLinear(nn.Module):
    def __init__(self, input_dims: tuple, hidden_size: tuple, transform_outer=True, bias: bool = True):
        """
        NdLinear: A PyTorch layer for projecting tensors into multi-space representations.

        Unlike conventional embedding layers that map into a single vector space, NdLinear
        transforms tensors across a collection of vector spaces, capturing multivariate structure
        and topical information that standard deep learning architectures typically lose.

        Concretely, one ``nn.Linear`` is created per transformed dimension: layer ``i`` maps a
        dimension of size ``input_dims[i]`` to size ``hidden_size[i]`` and leaves every other
        dimension untouched.

        Args:
            input_dims (tuple): Shape of input tensor (excluding batch dimension(s)).
            hidden_size (tuple): Target hidden dimensions after transformation; must have the
                same length as ``input_dims``.
            transform_outer (bool): If True (default), the per-dimension layers are applied
                from the first transformed dimension to the last; if False, in reverse order.
            bias (bool): Whether each per-dimension ``nn.Linear`` has a learnable bias.

        Raises:
            ValueError: If ``input_dims`` and ``hidden_size`` have different lengths.
        """
        super().__init__()

        if len(input_dims) != len(hidden_size):
            # ValueError subclasses Exception, so callers catching Exception still work.
            raise ValueError("Input shape and hidden shape do not match.")

        self.input_dims = input_dims
        self.hidden_size = hidden_size
        self.num_layers = len(input_dims)  # Equal to len(hidden_size), checked above
        self.transform_outer = transform_outer
        self.bias = bias

        # One linear map per transformed dimension.
        self.align_layers = nn.ModuleList([
            nn.Linear(input_dims[i], hidden_size[i], bias=self.bias) for i in range(self.num_layers)
        ])

    def forward(self, X):
        """
        Forward pass to project input tensor into a new multi-space representation.

        The last ``len(input_dims)`` dimensions of ``X`` are transformed one at a time: each
        target dimension is swapped to the last position, where ``nn.Linear`` operates (it
        accepts any number of leading dimensions), and then swapped back. All dimensions in
        front of the transformed ones are treated as batch dimensions, so both a single batch
        dimension (the usual case), several batch dimensions, or none at all are supported.
        Empty batches (a leading size of 0) are handled as well.

        Expected Input Shape: [batch_size, *input_dims]  (more or fewer leading dims allowed)
        Output Shape: [batch_size, *hidden_size]

        Args:
            X (torch.Tensor): Input tensor with shape [*batch_dims, *input_dims]

        Returns:
            torch.Tensor: Contiguous output tensor with shape [*batch_dims, *hidden_size]

        Raises:
            ValueError: If ``X`` has fewer dimensions than ``len(input_dims)``.
        """
        num_transforms = self.num_layers
        if X.dim() < num_transforms:
            raise ValueError(
                f"NdLinear expects at least {num_transforms} dimensions, got {X.dim()}."
            )

        # Position of the first transformed dimension; everything before it is batch.
        offset = X.dim() - num_transforms

        # Outer-first applies layer 0 (dim offset) first; otherwise start from the last dim.
        order = range(num_transforms) if self.transform_outer else reversed(range(num_transforms))

        for i in order:
            dim = offset + i
            # Bring the target dimension to the end so nn.Linear acts on it ...
            X = torch.transpose(X, dim, -1)
            X = self.align_layers[i](X)
            # ... and restore it to its original position.
            X = torch.transpose(X, dim, -1)

        # Preserve the original contract of returning a contiguous tensor.
        return X.contiguous()