Qwen3-1.7B-CSQA-NdLinearLoRA / modeling_ndlinear.py
import torch
import torch.nn as nn
from ndlinear import NdLinear
# This file contains the custom building blocks for the NdLinear-LoRA architecture.
# It should be in the same directory as the model when loading it.

def find_factor(n):
    """Return the most balanced pair of integer factors (a, b) with a * b == n."""
    # Search downward from sqrt(n); the first divisor found gives the most
    # balanced split. The trailing return is never reached for positive n.
    for i in range(int(n ** 0.5), 0, -1):
        if n % i == 0:
            return (i, n // i)
    return (1, n)
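
# Illustrative examples of the splits find_factor produces:
#   find_factor(2048) == (32, 64)
#   find_factor(768)  == (24, 32)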


class NdLinearLoRA(nn.Module):
    """The NdLinear-LoRA adapter layer.

    The input and output dimensions are factorized into near-square pairs, and
    the adapter applies an NdLinear transform between the factorized shapes.
    """

    def __init__(self, d_in, d_out, alpha=1.0, dropout=0.0):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.in_factors = find_factor(d_in)
        self.out_factors = find_factor(d_out)
        self.adapter = NdLinear(
            input_dims=self.in_factors,
            hidden_size=self.out_factors,
            transform_outer=False,
            bias=False,
        )
        self.scaling = alpha
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        orig_shape = x.shape
        # Fold the last dimension into its factor pair; reshape (rather than
        # view) also handles non-contiguous inputs.
        x = self.drop(x).reshape(-1, *self.in_factors)
        # Apply the adapter, then restore the original leading dimensions.
        y = self.adapter(x).reshape(*orig_shape[:-1], self.d_out)
        return y * self.scaling
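
# Quick shape check (illustrative): the adapter maps (..., d_in) -> (..., d_out),
# e.g. NdLinearLoRA(2048, 2048)(torch.randn(2, 5, 2048)).shape == (2, 5, 2048).
# Assuming NdLinear learns one small weight per factorized mode, the adapter
# above holds roughly 32*32 + 64*64 = 5,120 parameters instead of 2048*2048.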


class LinearWithNdLinearLoRA(nn.Module):
    """An nn.Linear layer wrapped with the NdLinear-LoRA adapter."""

    def __init__(self, base_layer, alpha=1.0, dropout=0.0):
        super().__init__()
        self.base_layer = base_layer
        # Freeze the pretrained weights; only the adapter remains trainable.
        for param in self.base_layer.parameters():
            param.requires_grad = False
        self.adapter = NdLinearLoRA(
            d_in=self.base_layer.in_features,
            d_out=self.base_layer.out_features,
            alpha=alpha,
            dropout=dropout,
        )

    def forward(self, x):
        # Frozen base projection plus the NdLinear-LoRA correction.
        return self.base_layer(x) + self.adapter(x)
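

# Usage sketch (illustrative; `apply_ndlinear_lora` and the target module names
# are assumptions, not part of this checkpoint's loading code). It swaps selected
# nn.Linear projections of an already-loaded model for wrapped versions so that
# only the NdLinear adapters receive gradients.
def apply_ndlinear_lora(model, target_names=("q_proj", "v_proj"), alpha=1.0, dropout=0.0):
    for parent in list(model.modules()):
        for child_name, child in list(parent.named_children()):
            if child_name in target_names and isinstance(child, nn.Linear):
                setattr(parent, child_name,
                        LinearWithNdLinearLoRA(child, alpha=alpha, dropout=dropout))
    return model

# Example:
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-1.7B")
#   model = apply_ndlinear_lora(model)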