import torch
import torch.nn as nn

from ndlinear import NdLinear


def find_factor(n):
    """Find the most balanced pair of integer factors of n."""
    # Walk down from floor(sqrt(n)); the first divisor found gives the pair
    # closest to a square (primes yield (1, n)).
    for i in range(int(n ** 0.5), 0, -1):
        if n % i == 0:
            return (i, n // i)
    return (1, n)
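# Example factorizations (quick sanity check, not part of the original module):
#   find_factor(768)  -> (24, 32)
#   find_factor(4096) -> (64, 64)
#   find_factor(7)    -> (1, 7)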
|
|
|
class NdLinearLoRA(nn.Module):
    """The NdLinear-LoRA adapter layer."""

    def __init__(self, d_in, d_out, alpha=1.0, dropout=0.0):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        # Factor the flat input/output dimensions into balanced 2-D shapes.
        self.in_factors = find_factor(d_in)
        self.out_factors = find_factor(d_out)
        self.adapter = NdLinear(
            input_dims=self.in_factors,
            hidden_size=self.out_factors,
            transform_outer=False,
            bias=False,
        )
        self.scaling = alpha
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        orig_shape = x.shape
        # Fold all leading dimensions into the batch axis and reshape the
        # feature axis into its factored 2-D form for NdLinear.
        x = self.drop(x).view(-1, *self.in_factors)
        # Flatten the factored output back to (..., d_out).
        y = self.adapter(x).view(*orig_shape[:-1], self.d_out)
        return y * self.scaling
|
|
|
class LinearWithNdLinearLoRA(nn.Module):
    """An nn.Linear layer wrapped with the NdLinear-LoRA adapter."""

    def __init__(self, base_layer, alpha=1.0, dropout=0.0):
        super().__init__()
        self.base_layer = base_layer
        # Freeze the pretrained weights; only the adapter is trained.
        for param in self.base_layer.parameters():
            param.requires_grad = False
        self.adapter = NdLinearLoRA(
            d_in=self.base_layer.in_features,
            d_out=self.base_layer.out_features,
            alpha=alpha,
            dropout=dropout,
        )

    def forward(self, x):
        # Frozen base projection plus the low-parameter NdLinear correction.
        return self.base_layer(x) + self.adapter(x)
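

# --- Minimal usage sketch (illustrative, not part of the adapter code above) ---
# Wraps a single nn.Linear and compares trainable vs. frozen parameter counts.
# Assumes the `ndlinear` package is installed and that NdLinear allocates one
# weight matrix per factored axis, so a 4096 -> 4096 adapter trains two 64x64
# matrices (~8K parameters) while the ~16.8M base parameters stay frozen.
if __name__ == "__main__":
    base = nn.Linear(4096, 4096)
    wrapped = LinearWithNdLinearLoRA(base, alpha=1.0, dropout=0.05)

    x = torch.randn(2, 16, 4096)             # (batch, seq_len, hidden)
    y = wrapped(x)
    print("output shape:", tuple(y.shape))   # (2, 16, 4096)

    trainable = sum(p.numel() for p in wrapped.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in wrapped.parameters() if not p.requires_grad)
    print(f"trainable params: {trainable:,}")  # adapter only
    print(f"frozen params:    {frozen:,}")     # base nn.Linear weight + bias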