import math
from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALayer():
    """Mixin holding the hyper-parameters and merge state shared by LoRA layers.

    Args:
        r: Rank of the low-rank update. ``r == 0`` disables the LoRA branch.
        lora_alpha: Numerator of the update scaling factor; the concrete
            layers compute ``scaling = lora_alpha / r``.
        lora_dropout: Dropout probability for the LoRA branch input (applied
            by ``Linear``); ``0`` means no dropout.
        merge_weights: If True, the concrete layers fold the low-rank update
            into the frozen weight when switched to eval mode and remove it
            again when switched back to train mode.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout on the LoRA branch. nn.Identity (instead of a
        # lambda) keeps the layer picklable, e.g. for torch.save(model).
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = nn.Identity()
        # The low-rank update starts out NOT folded into the base weight.
        self.merged = False
        self.merge_weights = merge_weights


class Linear(nn.Linear, LoRALayer):
    """``nn.Linear`` with a trainable low-rank (LoRA) update.

    The effective weight is ``W + (lora_B @ lora_A) * lora_alpha / r`` with
    ``W`` frozen. ``lora_A`` is Kaiming-uniform initialised and ``lora_B`` is
    zero, so a freshly built layer computes the same function as its base
    ``nn.Linear``.

    Args:
        in_features: Size of each input sample (as for ``nn.Linear``).
        out_features: Size of each output sample (as for ``nn.Linear``).
        r: LoRA rank; ``0`` disables the low-rank branch entirely.
        lora_alpha: Scaling numerator (``scaling = lora_alpha / r``).
        lora_dropout: Dropout probability on the input of the LoRA branch.
        fan_in_fan_out: Set to True if the layer being replaced stores its
            weight as ``(fan_in, fan_out)``; the weight is then stored
            transposed and transposed back whenever it is used.
        merge_weights: Fold the update into the weight in eval mode.
        **kwargs: Forwarded to ``nn.Linear.__init__`` (e.g. ``bias``).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha,
                           lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters.
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freeze the pre-trained weight matrix; only lora_A/lora_B train.
            self.weight.requires_grad = False
        # (Re)initialise base parameters and, now that they exist, the LoRA
        # factors; reset_parameters guards on hasattr because it may also run
        # from nn.Linear.__init__ before lora_A is created.
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def _t(self, w: torch.Tensor) -> torch.Tensor:
        """Transpose ``w`` when the weight is stored as (fan_in, fan_out)."""
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def reset_parameters(self):
        """Reset the base weight/bias and, if present, the LoRA factors."""
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # Initialise A like nn.Linear's default weight and B to zero so
            # the low-rank update starts at exactly zero.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        """Switch train/eval mode, (un)merging the LoRA update if enabled.

        In train mode a merged update is subtracted back out of the weight;
        in eval mode it is added in so inference is a single matmul.

        Returns:
            ``self``, matching ``nn.Module.train`` so ``layer.eval()`` chains.
        """
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged.
                if self.r > 0:
                    self.weight.data -= self._t(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it.
                if self.r > 0:
                    self.weight.data += self._t(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Apply the base linear map plus the (unmerged) scaled LoRA branch."""
        if self.r > 0 and not self.merged:
            result = F.linear(x, self._t(self.weight), bias=self.bias)
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1)
                       @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        return F.linear(x, self._t(self.weight), bias=self.bias)


class Conv2d(nn.Conv2d, LoRALayer):
    """``nn.Conv2d`` with a trainable low-rank (LoRA) update of its kernel.

    The update ``(lora_B @ lora_A).view(weight.shape) * lora_alpha / r`` is
    added to the frozen kernel. For a kernel of size ``(kh, kw)``:

    * ``lora_A`` has shape ``(r * kh, in_channels // groups * kw)``
    * ``lora_B`` has shape ``(out_channels * kh, r * kh)``

    so ``lora_B @ lora_A`` has as many elements as the kernel
    ``(out_channels, in_channels // groups, kh, kw)``. For a square kernel
    ``k`` and ``groups == 1`` these reduce to ``(r*k, in*k)`` / ``(out*k, r*k)``.
    LoRA dropout is not applied on the convolution path.

    Args:
        in_channels, out_channels: As for ``nn.Conv2d``.
        kernel_size: Kernel size, an int or an ``(kh, kw)`` tuple.
        r: LoRA rank; ``0`` disables the low-rank branch entirely.
        lora_alpha: Scaling numerator (``scaling = lora_alpha / r``).
        lora_dropout: Stored for API symmetry; unused by ``forward``.
        merge_weights: Fold the update into the kernel in eval mode.
        **kwargs: Forwarded to ``nn.Conv2d.__init__`` (stride, padding, ...).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha,
                           lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        # Accept both int and (kh, kw) kernel sizes.
        if isinstance(kernel_size, int):
            kh, kw = kernel_size, kernel_size
        else:
            kh, kw = tuple(kernel_size)
        # Actual trainable parameters. in_channels is divided by groups because
        # the kernel only spans in_channels // groups input channels.
        if r > 0:
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * kh, (in_channels // self.groups) * kw))
            )
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels * kh, r * kh))
            )
            self.scaling = self.lora_alpha / self.r
            # Freeze the pre-trained kernel; only lora_A/lora_B train.
            self.weight.requires_grad = False
        self.reset_parameters()

    def _delta_weight(self) -> torch.Tensor:
        """Return the scaled low-rank kernel update, reshaped like ``weight``."""
        return (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling

    def reset_parameters(self):
        """Reset the base kernel/bias and, if present, the LoRA factors."""
        nn.Conv2d.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # Initialise A like nn.Linear's default weight and B to zero so
            # the low-rank update starts at exactly zero.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        """Switch train/eval mode, (un)merging the LoRA update if enabled.

        The update is only touched when ``r > 0`` (an ``r == 0`` layer has no
        LoRA factors). Returns ``self``, matching ``nn.Module.train``.
        """
        nn.Conv2d.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged.
                if self.r > 0:
                    self.weight.data -= self._delta_weight()
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it.
                if self.r > 0:
                    self.weight.data += self._delta_weight()
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Convolve with the frozen kernel plus the (unmerged) LoRA update."""
        if self.r > 0 and not self.merged:
            # _conv_forward is what nn.Conv2d.forward uses; going through it
            # (instead of F.conv2d with self.padding) honours padding_mode.
            return self._conv_forward(x, self.weight + self._delta_weight(), self.bias)
        return nn.Conv2d.forward(self, x)


def _adopt_base_layer(new_layer, old_layer):
    """Give ``new_layer`` the weights, device, dtype and mode of ``old_layer``."""
    new_layer.to(device=old_layer.weight.device, dtype=old_layer.weight.dtype)
    with torch.no_grad():
        new_layer.weight.copy_(old_layer.weight)
        if old_layer.bias is not None:
            new_layer.bias.copy_(old_layer.bias)
    # Done after copying: entering eval mode merges the (still zero) update.
    new_layer.train(old_layer.training)
    return new_layer


def _lora_linear_from(child, rank):
    """Build a LoRA ``Linear`` equivalent to the ``nn.Linear`` ``child``."""
    new_layer = Linear(in_features=child.in_features,
                       out_features=child.out_features,
                       bias=child.bias is not None,
                       r=rank)
    return _adopt_base_layer(new_layer, child)


def _lora_conv2d_from(child, rank):
    """Build a LoRA ``Conv2d`` equivalent to the ``nn.Conv2d`` ``child``."""
    new_layer = Conv2d(in_channels=child.in_channels,
                       out_channels=child.out_channels,
                       kernel_size=child.kernel_size,
                       stride=child.stride,
                       padding=child.padding,
                       dilation=child.dilation,
                       groups=child.groups,
                       bias=child.bias is not None,
                       padding_mode=child.padding_mode,
                       r=rank)
    return _adopt_base_layer(new_layer, child)


def wrap_model_with_lora(module, rank=4):
    """Recursively replace ``nn.Linear`` / ``nn.Conv2d`` children with LoRA layers.

    The model is modified in place. Children that are already LoRA ``Linear``
    / ``Conv2d`` layers are left untouched. Every replacement receives a copy
    of the original layer's weight and bias and is moved to its device/dtype
    and train/eval mode. Because ``lora_B`` starts at zero, the wrapped model
    initially computes the same outputs as before.

    Args:
        module: Root module whose submodules are replaced in place.
        rank: LoRA rank ``r`` used for every replaced layer.
    """
    # Snapshot the children because setattr replaces entries as we go.
    for name, child in list(module.named_children()):
        if isinstance(child, (Linear, Conv2d)):
            continue
        if isinstance(child, nn.Linear):
            setattr(module, name, _lora_linear_from(child, rank))
        elif isinstance(child, nn.Conv2d):
            setattr(module, name, _lora_conv2d_from(child, rank))
        else:
            wrap_model_with_lora(child, rank)