Spaces:
Running
on
Zero
Running
on
Zero
from typing import List, Optional, Set, Type, Union | |
import torch | |
from torch import nn | |
class LoraInjectedLinear(nn.Module):
    """
    Linear layer with LoRA injection.

    Taken from https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py

    Computes ``linear(x) + dropout(lora_up(selector(lora_down(x)))) * scale`` and
    returns the result cast to ``torch.float16``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the wrapped ``nn.Linear`` has a bias term.
        r: LoRA rank; must not exceed ``min(in_features, out_features)``.
        dropout_p: Dropout probability applied to the LoRA branch output.
        scale: Multiplier applied to the LoRA branch output.

    Raises:
        ValueError: If ``r`` exceeds ``min(in_features, out_features)``.
    """

    def __init__(
        self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
    ):
        super().__init__()
        if r > min(in_features, out_features):
            raise ValueError(
                f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
            )
        self.r = r
        self.linear = nn.Linear(in_features, out_features, bias)
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = scale
        # Identity until set_selector_from_diag installs a diagonal (r, r) map.
        self.selector = nn.Identity()

        nn.init.normal_(self.lora_down.weight, std=1 / r)
        # lora_up starts at zero, so the LoRA branch initially adds nothing.
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        """Return ``linear(input)`` plus the scaled LoRA update, as float16."""
        # Cast the input to each branch's own weight dtype. With the default
        # float32 weights this is exactly ``input.float()``; it additionally keeps
        # the layer usable after the module has been cast (e.g. to bfloat16),
        # where a hard-coded float32 input would cause a dtype-mismatch error.
        base = self.linear(input.to(self.linear.weight.dtype))
        lora = self.lora_up(
            self.selector(self.lora_down(input.to(self.lora_down.weight.dtype)))
        )
        # NOTE: the output is always float16 regardless of the input dtype; this
        # is preserved from the original implementation since callers may rely on it.
        return (base + self.dropout(lora) * self.scale).half()

    def realize_as_lora(self):
        """Return ``(up_weight * scale, down_weight)`` as detached weight tensors."""
        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

    def set_selector_from_diag(self, diag: torch.Tensor):
        """
        Replace the identity selector with a bias-free linear map whose weight is
        ``torch.diag(diag)``, placed on the device/dtype of ``lora_up``.

        Args:
            diag: 1-D tensor of shape ``(r,)``.
        """
        assert diag.shape == (self.r,)
        self.selector = nn.Linear(self.r, self.r, bias=False)
        self.selector.weight.data = torch.diag(diag)
        # Match lora_up so the selector's output feeds it without casts.
        self.selector.weight.data = self.selector.weight.data.to(
            self.lora_up.weight.device
        ).to(self.lora_up.weight.dtype)
class LoraInjectedConv2d(nn.Module):
    """
    2D convolution with LoRA injection.

    Computes ``conv(x) + dropout(lora_up(selector(lora_down(x)))) * scale``.
    ``lora_down`` uses the same kernel/stride/padding/dilation/groups as ``conv``
    but emits ``r`` channels; ``lora_up`` is a 1x1 convolution from ``r``
    channels back to ``out_channels``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size for ``conv`` and ``lora_down``.
        stride: Stride for ``conv`` and ``lora_down``.
        padding: Padding for ``conv`` and ``lora_down``.
        dilation: Dilation for ``conv`` and ``lora_down``.
        groups: Groups for ``conv`` and ``lora_down`` (``nn.Conv2d`` requires
            ``r`` to be divisible by this for ``lora_down``).
        bias: Whether the wrapped ``conv`` has a bias term.
        r: LoRA rank; must not exceed ``min(in_channels, out_channels)``.
        dropout_p: Dropout probability applied to the LoRA branch output.
        scale: Multiplier applied to the LoRA branch output.

    Raises:
        ValueError: If ``r`` exceeds ``min(in_channels, out_channels)``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups: int = 1,
        bias: bool = True,
        r: int = 4,
        dropout_p: float = 0.1,
        scale: float = 1.0,
    ):
        super().__init__()
        if r > min(in_channels, out_channels):
            raise ValueError(
                f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}"
            )
        self.r = r
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # Same spatial geometry as self.conv, so its output aligns with conv's.
        self.lora_down = nn.Conv2d(
            in_channels=in_channels,
            out_channels=r,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Conv2d(
            in_channels=r,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        # Identity until set_selector_from_diag installs a diagonal 1x1 conv.
        self.selector = nn.Identity()
        self.scale = scale

        nn.init.normal_(self.lora_down.weight, std=1 / r)
        # lora_up starts at zero, so the LoRA branch initially adds nothing.
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        """Return ``conv(input)`` plus the scaled LoRA update."""
        return (
            self.conv(input)
            + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
            * self.scale
        )

    def realize_as_lora(self):
        """Return ``(up_weight * scale, down_weight)`` as detached weight tensors."""
        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

    def set_selector_from_diag(self, diag: torch.Tensor):
        """
        Replace the identity selector with a bias-free 1x1 convolution that scales
        each of the ``r`` channels by the matching entry of ``diag``, placed on
        the device/dtype of ``lora_up``.

        Args:
            diag: 1-D tensor of shape ``(r,)``.
        """
        assert diag.shape == (self.r,)
        self.selector = nn.Conv2d(
            in_channels=self.r,
            out_channels=self.r,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        # Conv2d weights are 4-D (out, in, kH, kW); the (r, r) diagonal matrix must
        # be reshaped to (r, r, 1, 1) or the convolution in forward() fails.
        self.selector.weight.data = torch.diag(diag).view(self.r, self.r, 1, 1)
        # same device + dtype as lora_up
        self.selector.weight.data = self.selector.weight.data.to(
            self.lora_up.weight.device
        ).to(self.lora_up.weight.dtype)