import torch
import torch.nn as nn
from ndlinear import NdLinear

# This file contains the custom building blocks for the NdLinear-LoRA architecture.
# It should be in the same directory as the model when loading it.


def find_factor(n):
    """Finds the most balanced pair of integer factors of n."""
    for i in range(int(n ** 0.5), 0, -1):
        if n % i == 0:
            return (i, n // i)
    return (1, n)


class NdLinearLoRA(nn.Module):
    """The NdLinear-LoRA adapter layer."""

    def __init__(self, d_in, d_out, alpha=1.0, dropout=0.0):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        # Factor the input/output dimensions so the adapter operates on a
        # reshaped (factor_1, factor_2) view instead of the flat feature axis.
        self.in_factors = find_factor(d_in)
        self.out_factors = find_factor(d_out)
        self.adapter = NdLinear(
            input_dims=self.in_factors,
            hidden_size=self.out_factors,
            transform_outer=False,
            bias=False,
        )
        self.scaling = alpha
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        orig_shape = x.shape
        # Flatten all leading dimensions, then reshape the feature axis into its factors.
        x = self.drop(x).view(-1, *self.in_factors)
        y = self.adapter(x).view(*orig_shape[:-1], self.d_out)
        return y * self.scaling


class LinearWithNdLinearLoRA(nn.Module):
    """An nn.Linear layer wrapped with the NdLinear-LoRA adapter."""

    def __init__(self, base_layer, alpha=1.0, dropout=0.0):
        super().__init__()
        self.base_layer = base_layer
        # Freeze the original weights; only the adapter is trained.
        for param in self.base_layer.parameters():
            param.requires_grad = False
        self.adapter = NdLinearLoRA(
            d_in=self.base_layer.in_features,
            d_out=self.base_layer.out_features,
            alpha=alpha,
            dropout=dropout,
        )

    def forward(self, x):
        return self.base_layer(x) + self.adapter(x)
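

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original modules): one way
# these blocks might be attached to a Hugging Face causal LM. The checkpoint
# name ("Qwen/Qwen2.5-0.5B") and the target projection names ("q_proj",
# "v_proj") are assumptions here; substitute whatever matches the model being
# tuned. Requires the `transformers` package.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")  # assumed checkpoint

    # Freeze the whole backbone; only the adapters added below will train.
    for p in model.parameters():
        p.requires_grad = False

    # Collect the nn.Linear layers to wrap first, then swap them in place,
    # so the module tree is not mutated while it is being iterated.
    target_names = ("q_proj", "v_proj")  # assumed attention projection names
    to_wrap = []
    for parent in model.modules():
        for child_name, child in parent.named_children():
            if child_name in target_names and isinstance(child, nn.Linear):
                to_wrap.append((parent, child_name, child))
    for parent, child_name, child in to_wrap:
        setattr(parent, child_name, LinearWithNdLinearLoRA(child, alpha=1.0, dropout=0.05))

    # Sanity check: only the NdLinear-LoRA adapter parameters remain trainable.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable:,} / {total:,}")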