import torch
import torch.nn as nn
import torch.nn.functional as F
import math
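
# LoRA (Low-Rank Adaptation) layers: the pre-trained weight W is frozen and a trainable
# low-rank update is learned instead, so the effective weight is W + (lora_alpha / r) * (B @ A),
# where A has shape (r, in_features) and B has shape (out_features, r). B is initialized to
# zero, so training starts exactly from the pre-trained behaviour.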
class LoRALayer():
    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights

class Linear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        if self.r > 0 and not self.merged:
            # Apply the frozen base weight, then add the low-rank update explicitly
            result = F.linear(x, T(self.weight), bias=self.bias)
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return F.linear(x, T(self.weight), bias=self.bias)

class Conv2d(nn.Conv2d, LoRALayer):
    # LoRA implemented in a conv layer
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        # Square kernels are assumed; a tuple kernel size collapses to its first entry
        temp_ks = kernel_size[0] if isinstance(kernel_size, tuple) else kernel_size
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * temp_ks, in_channels * temp_ks))
            )
            # out_channels is divided by groups so that (lora_B @ lora_A) has exactly as many
            # elements as the conv weight and can be reshaped onto it
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels // self.groups * temp_ks, r * temp_ks))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        nn.Conv2d.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        nn.Conv2d.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        if self.r > 0 and not self.merged:
            # Fold the low-rank update into the weight on the fly instead of merging it
            return F.conv2d(
                x,
                self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling,
                self.bias, self.stride, self.padding, self.dilation, self.groups
            )
        return nn.Conv2d.forward(self, x)

def wrap_model_with_lora(module, rank=4):
    # Recursively replace nn.Linear / nn.Conv2d children with their LoRA counterparts,
    # copying the pre-trained weights so that only the low-rank adapters are newly trained.
    for name, child in module.named_children():
        if isinstance(child, (Linear, Conv2d)):
            continue  # already wrapped
        if 'stitch' in name:
            pass  # no-op placeholder: 'stitch' modules are treated like any other child
        if isinstance(child, nn.Linear):
            new_layer = Linear(in_features=child.in_features, out_features=child.out_features,
                               bias=child.bias is not None, r=rank)
        elif isinstance(child, nn.Conv2d):
            new_layer = Conv2d(in_channels=child.in_channels, out_channels=child.out_channels,
                               kernel_size=child.kernel_size, stride=child.stride,
                               padding=child.padding, dilation=child.dilation,
                               groups=child.groups, bias=child.bias is not None, r=rank)
        else:
            wrap_model_with_lora(child, rank)
            continue
        # Carry over the pre-trained parameters of the replaced layer
        new_layer.weight.data.copy_(child.weight.data)
        if child.bias is not None:
            new_layer.bias.data.copy_(child.bias.data)
        setattr(module, name, new_layer)
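
# Minimal usage sketch (illustrative only, not part of the original module): the toy model,
# tensor sizes, and rank below are assumptions chosen just to exercise wrap_model_with_lora.
# It checks that only the LoRA factors (and biases) remain trainable and that switching to
# eval() merges the low-rank update without changing the output (lora_B starts at zero).
if __name__ == "__main__":
    toy = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.Flatten(),
        nn.Linear(8 * 16 * 16, 10),
    )
    wrap_model_with_lora(toy, rank=4)

    # Frozen base weights have requires_grad=False; lora_A / lora_B (and biases) stay trainable
    print([n for n, p in toy.named_parameters() if p.requires_grad])

    x = torch.randn(2, 3, 16, 16)
    toy.train()
    out_unmerged = toy(x)   # low-rank path applied explicitly
    toy.eval()
    out_merged = toy(x)     # B @ A merged into the frozen weights
    print(torch.allclose(out_unmerged, out_merged, atol=1e-6))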