# NOTE(review): the three lines that were here ("Spaces:" / "Runtime error" x2)
# are page-header residue from a Hugging Face Space capture, not program text.
| import fire | |
| from diffusers import StableDiffusionPipeline | |
| import torch | |
| import torch.nn as nn | |
| from .lora import ( | |
| save_all, | |
| _find_modules, | |
| LoraInjectedConv2d, | |
| LoraInjectedLinear, | |
| inject_trainable_lora, | |
| inject_trainable_lora_extended, | |
| ) | |
def _iter_lora(model):
    """Yield every LoRA-injected module found in *model*, in ``model.modules()`` order.

    Both linear and conv LoRA wrappers are yielded; callers rely on the
    deterministic traversal order to pair modules across two copies of the
    same architecture.
    """
    for module in model.modules():
        # isinstance accepts a tuple of types — clearer than chaining two calls.
        if isinstance(module, (LoraInjectedConv2d, LoraInjectedLinear)):
            yield module
def _svd_factors(residual, rank, clamp_quantile):
    """Return a clamped rank-``rank`` factorization ``(U, Vh)`` of a 2-D *residual*.

    ``U`` is ``(out, rank)`` with the singular values folded in, ``Vh`` is
    ``(rank, in)``. Both factors are symmetrically clamped at the
    *clamp_quantile* quantile of their combined magnitude distribution to
    suppress outlier entries. The SVD is done in float32 for stability.
    """
    # full_matrices=False skips the orthogonal complements entirely; we only
    # keep the top-`rank` components, so the result is identical but cheaper.
    U, S, Vh = torch.linalg.svd(residual.float(), full_matrices=False)
    U = U[:, :rank] @ torch.diag(S[:rank])
    Vh = Vh[:rank, :]

    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi_val = torch.quantile(dist, clamp_quantile)
    return U.clamp(-hi_val, hi_val), Vh.clamp(-hi_val, hi_val)


def overwrite_base(base_model, tuned_model, rank, clamp_quantile):
    """Distill ``tuned - base`` weight deltas into *base_model*'s LoRA adapters.

    Both models must be architecturally identical copies with LoRA already
    injected, so that `_iter_lora` pairs corresponding modules. For each pair
    the residual between the tuned and base weights is approximated by a
    truncated SVD and written into the base module's ``lora_up``/``lora_down``
    weights (in the base model's device/dtype).
    """
    device = base_model.device
    dtype = base_model.dtype

    for lor_base, lor_tune in zip(_iter_lora(base_model), _iter_lora(tuned_model)):
        if isinstance(lor_base, LoraInjectedLinear):
            residual = lor_tune.linear.weight.data - lor_base.linear.weight.data
            print("Distill Linear shape ", residual.shape)
            U, Vh = _svd_factors(residual, rank, clamp_quantile)
        elif isinstance(lor_base, LoraInjectedConv2d):
            residual = lor_tune.conv.weight.data - lor_base.conv.weight.data
            print("Distill Conv shape ", residual.shape)
            # Fold (in_ch, kH, kW) into one axis so the residual is 2-D for SVD.
            U, Vh = _svd_factors(residual.flatten(start_dim=1), rank, clamp_quantile)
            # U is (out_channels, rank): store it as a 1x1 conv kernel.
            U = U.reshape(U.shape[0], U.shape[1], 1, 1)
            # Vh is (rank, in_channels * kH * kW): restore the conv kernel shape.
            Vh = Vh.reshape(
                Vh.shape[0],
                lor_base.conv.in_channels,
                lor_base.conv.kernel_size[0],
                lor_base.conv.kernel_size[1],
            )
        else:
            continue

        # Sanity check: the factors must match the injected adapter shapes.
        assert lor_base.lora_up.weight.shape == U.shape
        assert lor_base.lora_down.weight.shape == Vh.shape

        lor_base.lora_up.weight.data = U.to(device=device, dtype=dtype)
        lor_base.lora_down.weight.data = Vh.to(device=device, dtype=dtype)
def svd_distill(
    target_model: str,
    base_model: str,
    rank: int = 4,
    clamp_quantile: float = 0.99,
    device: str = "cuda:0",
    save_path: str = "svd_distill.safetensors",
):
    """Extract a LoRA from the weight delta between two SD checkpoints.

    Loads *base_model* and *target_model* as fp16 pipelines, injects rank-*rank*
    LoRA adapters into both copies of the UNet and text encoder, distills the
    tuned-minus-base weight residuals into the base copy's adapters via SVD
    (clamped at *clamp_quantile*), and saves the result to *save_path*.
    """

    def _load(model_id):
        # Both checkpoints are loaded identically: fp16 weights on `device`.
        return StableDiffusionPipeline.from_pretrained(
            model_id, torch_dtype=torch.float16
        ).to(device)

    pipe_base = _load(base_model)
    pipe_tuned = _load(target_model)

    # UNet: inject adapters into both copies, then distill the delta into
    # the base copy's adapters.
    inject_trainable_lora_extended(pipe_base.unet, r=rank)
    inject_trainable_lora_extended(pipe_tuned.unet, r=rank)
    overwrite_base(
        pipe_base.unet, pipe_tuned.unet, rank=rank, clamp_quantile=clamp_quantile
    )

    # Text encoder: same procedure, restricted to CLIP attention blocks.
    for encoder in (pipe_base.text_encoder, pipe_tuned.text_encoder):
        inject_trainable_lora(
            encoder, r=rank, target_replace_module={"CLIPAttention"}
        )
    overwrite_base(
        pipe_base.text_encoder,
        pipe_tuned.text_encoder,
        rank=rank,
        clamp_quantile=clamp_quantile,
    )

    # Persist only the LoRA weights (no textual-inversion embeddings).
    save_all(
        unet=pipe_base.unet,
        text_encoder=pipe_base.text_encoder,
        placeholder_token_ids=None,
        placeholder_tokens=None,
        save_path=save_path,
        save_lora=True,
        save_ti=False,
    )
def main():
    """Console entry point: expose `svd_distill` as a python-fire CLI."""
    fire.Fire(svd_distill)