# Chain-of-Zoom / osediff_sd3.py
import os
import sys
sys.path.append(os.getcwd())
import math
import yaml
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional
import numpy as np
import lpips
from torchvision import transforms
from PIL import Image
from peft import LoraConfig, get_peft_model
from copy import deepcopy
from tqdm import tqdm
from diffusers import StableDiffusion3Pipeline, FluxPipeline
from lora.lora_layers import LoraInjectedLinear, LoraInjectedConv2d
def inject_lora_vae(vae, lora_rank=4, init_lora_weights="gaussian", verbose=False):
"""
Inject LoRA into the VAE's encoder
"""
vae.requires_grad_(False)
vae.train()
# Identify modules to LoRA-ify in the encoder
l_grep = ["conv1", "conv2", "conv_in", "conv_shortcut",
"conv", "conv_out", "to_k", "to_q", "to_v", "to_out.0"]
l_target_modules_encoder = []
    for n, _ in vae.named_parameters():
        if "bias" in n or "norm" in n:
            continue
        is_encoder_match = ("encoder" in n) and any(pattern in n for pattern in l_grep)
        is_quant_conv = ("quant_conv" in n) and ("post_quant_conv" not in n)
        if is_encoder_match or is_quant_conv:
            name = n.replace(".weight", "")
            # a parameter can match several patterns; add each target only once
            if name not in l_target_modules_encoder:
                l_target_modules_encoder.append(name)
if verbose:
print("The following VAE parameters will get LoRA:")
print(l_target_modules_encoder)
# Create and add a LoRA adapter
lora_conf_encoder = LoraConfig(
r=lora_rank,
init_lora_weights=init_lora_weights,
target_modules=l_target_modules_encoder
)
adapter_name = "default_encoder"
try:
vae.add_adapter(lora_conf_encoder, adapter_name=adapter_name)
vae.set_adapter(adapter_name)
except ValueError as e:
if "already exists" in str(e):
print(f"Adapter with name {adapter_name} already exists. Skipping injection.")
else:
raise e
return vae, l_target_modules_encoder
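# Usage sketch (hypothetical, assuming a diffusers AutoencoderKL checkpoint);
# wrapped in a function so importing this module stays side-effect free.
def _example_inject_lora_vae():
    from diffusers import AutoencoderKL
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae"
    )
    vae, targets = inject_lora_vae(vae, lora_rank=4, verbose=True)
    # after injection, only the LoRA weights should require grad
    n_trainable = sum(p.numel() for p in vae.parameters() if p.requires_grad)
    print(f"{len(targets)} LoRA target modules, {n_trainable} trainable params")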
def _find_modules(model, ancestor_class=None, search_class=[nn.Linear], exclude_children_of=[LoraInjectedLinear]):
    # Collect the ancestor modules under which matching Linear layers will be replaced
if ancestor_class is not None:
ancestors = (
module
for module in model.modules()
if module.__class__.__name__ in ancestor_class
)
    else:
        # No ancestor filter: search from the model root (iterating over every
        # module as an ancestor would yield each match once per ancestor on its path).
        ancestors = [model]
for ancestor in ancestors:
for fullname, module in ancestor.named_modules():
if any([isinstance(module, _class) for _class in search_class]):
*path, name = fullname.split(".")
parent = ancestor
while path:
parent = parent.get_submodule(path.pop(0))
if exclude_children_of and any(
[isinstance(parent, _class) for _class in exclude_children_of]
):
continue
yield parent, name, module
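# Illustration (sketch): enumerate the Linear layers that _find_modules would
# yield under the AdaLayerNormZero blocks of an SD3 transformer.
def _example_list_lora_targets(transformer):
    for parent, name, child in _find_modules(transformer, {"AdaLayerNormZero"}):
        print(f"{type(parent).__name__}.{name}: "
              f"{child.in_features} -> {child.out_features}")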
def inject_lora(model, ancestor_class, loras=None, r:int=4, dropout_p:float=0.0, scale:float=1.0, verbose:bool=False):
model.requires_grad_(False)
model.train()
names = []
require_grad_params = [] # to be updated
total_lora_params = 0
if loras is not None:
loras = torch.load(loras, map_location=model.device, weights_only=True)
loras = [lora.float() for lora in loras]
for _module, name, _child_module in _find_modules(model, ancestor_class): # SiLU + Linear Block
weight = _child_module.weight
bias = _child_module.bias
if verbose:
print(f'LoRA Injection : injecting lora into {name}')
_tmp = LoraInjectedLinear(
_child_module.in_features,
_child_module.out_features,
_child_module.bias is not None,
r=r,
dropout_p=dropout_p,
scale=scale,
)
_tmp.linear.weight = nn.Parameter(weight.float())
if bias is not None:
_tmp.linear.bias = nn.Parameter(bias.float())
# switch the module
_tmp.to(device=_child_module.weight.device, dtype=torch.float) # keep as float / mixed precision
_module._modules[name] = _tmp
require_grad_params.append(_module._modules[name].lora_up.parameters())
require_grad_params.append(_module._modules[name].lora_down.parameters())
        if loras is not None:
_module._modules[name].lora_up.weight = nn.Parameter(loras.pop(0))
_module._modules[name].lora_down.weight = nn.Parameter(loras.pop(0))
_module._modules[name].lora_up.weight.requires_grad = True
_module._modules[name].lora_down.weight.requires_grad = True
names.append(name)
if verbose:
# -------- Count LoRA parameters just added --------
lora_up_count = sum(p.numel() for p in _tmp.lora_up.parameters())
lora_down_count = sum(p.numel() for p in _tmp.lora_down.parameters())
lora_total_for_this_layer = lora_up_count + lora_down_count
total_lora_params += lora_total_for_this_layer
print(f" Added {lora_total_for_this_layer} params "
f"(lora_up={lora_up_count}, lora_down={lora_down_count})")
if verbose:
print(f"Total new LoRA parameters added: {total_lora_params}")
return require_grad_params, names
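# Usage sketch: inject_lora returns a list of parameter *generators*, so
# flatten them before building an optimizer (names below are hypothetical).
def _example_optimizer_over_lora(transformer, lr=1e-4):
    import itertools
    lora_params, _ = inject_lora(transformer, {"AdaLayerNormZero"}, r=4)
    return torch.optim.AdamW(itertools.chain(*lora_params), lr=lr)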
def add_mp_hook(transformer):
    '''
    Mixed-precision hooks for LoRA: keep the LoRA branch in float32 while the
    surrounding (half-precision) transformer is left untouched. Assumes the
    usual LoRA forward order lora_up(lora_down(x)).
    '''
    def pre_hook(module, args):
        # forward pre-hooks receive the positional arguments as a tuple
        return tuple(a.float() for a in args)
    def post_hook(module, args, output):
        return output.half()
    hooks = []
    # search for the injected LoRA modules themselves (the default
    # search_class=[nn.Linear] would never yield a LoraInjectedLinear)
    for _module, name, _child_module in _find_modules(transformer, search_class=[LoraInjectedLinear]):
        # cast to float on the way into the LoRA branch ...
        hooks.append(_child_module.lora_down.register_forward_pre_hook(pre_hook))
        # ... and back to half on the way out
        hooks.append(_child_module.lora_up.register_forward_hook(post_hook))
    return transformer, hooks
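# Sketch: the returned handles can be detached later (e.g. before exporting
# the model) with handle.remove().
def _example_remove_mp_hooks(hooks):
    for handle in hooks:
        handle.remove()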
def compute_density_for_timestep_sampling(
weighting_scheme: str, batch_size: int, logit_mean: float = 0.0, logit_std: float = 1.0, mode_scale: Optional[float] = None
):
"""
Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
u = torch.rand(size=(batch_size,), device="cpu")
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(batch_size,), device="cpu")
return u
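# Worked example (sketch): draw logit-normal density samples and map them to
# integer timestep indices in [0, 999], mirroring the losses further below.
def _example_sample_timestep_indices(batch_size=4):
    u = compute_density_for_timestep_sampling(
        weighting_scheme="logit_normal", batch_size=batch_size,
        logit_mean=0.0, logit_std=1.0,
    )
    return (u * 1000).long()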
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas):
"""
Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)
return weighting
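# Sketch: apply a loss weighting scheme to a per-sample flow-matching MSE.
def _example_weighted_mse(model_pred, target, sigmas):
    w = compute_loss_weighting_for_sd3("cosmap", sigmas)
    return (w.float() * F.mse_loss(model_pred.float(), target.float())).mean()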
class StableDiffusion3Base():
def __init__(self, model_key:str='stabilityai/stable-diffusion-3-medium-diffusers', device='cuda', dtype=torch.float16):
self.device = device
self.dtype = dtype
pipe = StableDiffusion3Pipeline.from_pretrained(model_key, torch_dtype=self.dtype)
self.scheduler = pipe.scheduler
self.tokenizer_1 = pipe.tokenizer
self.tokenizer_2 = pipe.tokenizer_2
self.tokenizer_3 = pipe.tokenizer_3
self.text_enc_1 = pipe.text_encoder.to(device)
self.text_enc_2 = pipe.text_encoder_2.to(device)
self.text_enc_3 = pipe.text_encoder_3.to(device)
        self.vae = pipe.vae.to(device)
self.transformer = pipe.transformer.to(device)
self.transformer.eval()
self.transformer.requires_grad_(False)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)-1) if hasattr(self, "vae") and self.vae is not None else 8
)
del pipe
def encode_prompt(self, prompt: List[str], batch_size:int=1) -> List[torch.Tensor]:
'''
We assume that
1. number of tokens < max_length
2. one prompt for one image
'''
# CLIP encode (used for modulation of adaLN-zero)
# now, we have two CLIPs
text_clip1_ids = self.tokenizer_1(prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors='pt').input_ids
text_clip1_emb = self.text_enc_1(text_clip1_ids.to(self.device), output_hidden_states=True)
pool_clip1_emb = text_clip1_emb[0].to(dtype=self.dtype, device=self.device)
text_clip1_emb = text_clip1_emb.hidden_states[-2].to(dtype=self.dtype, device=self.device)
text_clip2_ids = self.tokenizer_2(prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors='pt').input_ids
text_clip2_emb = self.text_enc_2(text_clip2_ids.to(self.device), output_hidden_states=True)
pool_clip2_emb = text_clip2_emb[0].to(dtype=self.dtype, device=self.device)
text_clip2_emb = text_clip2_emb.hidden_states[-2].to(dtype=self.dtype, device=self.device)
# T5 encode (used for text condition)
text_t5_ids = self.tokenizer_3(prompt,
padding="max_length",
max_length=512,
truncation=True,
add_special_tokens=True,
return_tensors='pt').input_ids
text_t5_emb = self.text_enc_3(text_t5_ids.to(self.device))[0]
text_t5_emb = text_t5_emb.to(dtype=self.dtype, device=self.device)
# Merge
clip_prompt_emb = torch.cat([text_clip1_emb, text_clip2_emb], dim=-1)
clip_prompt_emb = torch.nn.functional.pad(
clip_prompt_emb, (0, text_t5_emb.shape[-1] - clip_prompt_emb.shape[-1])
)
prompt_emb = torch.cat([clip_prompt_emb, text_t5_emb], dim=-2)
pooled_prompt_emb = torch.cat([pool_clip1_emb, pool_clip2_emb], dim=-1)
return prompt_emb, pooled_prompt_emb
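    # Shape sketch for SD3-medium (assuming the defaults above): prompt_emb is
    # (B, 77 + 512, 4096) after padding the concatenated CLIP features to the
    # T5 width, and pooled_prompt_emb is (B, 768 + 1280) = (B, 2048).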
def initialize_latent(self, img_size:Tuple[int], batch_size:int=1, **kwargs):
H, W = img_size
lH, lW = H//self.vae_scale_factor, W//self.vae_scale_factor
lC = self.transformer.config.in_channels
latent_shape = (batch_size, lC, lH, lW)
z = torch.randn(latent_shape, device=self.device, dtype=self.dtype)
return z
def encode(self, image: torch.Tensor) -> torch.Tensor:
z = self.vae.encode(image).latent_dist.sample()
z = (z-self.vae.config.shift_factor) * self.vae.config.scaling_factor
return z
def decode(self, z: torch.Tensor) -> torch.Tensor:
z = (z/self.vae.config.scaling_factor) + self.vae.config.shift_factor
return self.vae.decode(z, return_dict=False)[0]
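    def predict_vector(self, z, t, prompt_emb, pooled_emb):
        # inversion()/sample() below rely on this method, which the file never
        # defines; a minimal implementation following the same pattern as the
        # OSEDiff_* classes further down (sketch).
        v = self.transformer(hidden_states=z,
                             timestep=t,
                             pooled_projections=pooled_emb,
                             encoder_hidden_states=prompt_emb,
                             return_dict=False)[0]
        return v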
class SD3Euler(StableDiffusion3Base):
def __init__(self, model_key:str='stabilityai/stable-diffusion-3-medium-diffusers', device='cuda'):
super().__init__(model_key=model_key, device=device)
def inversion(self, src_img, prompts: List[str], NFE:int, cfg_scale: float=1.0, batch_size: int=1):
# encode text prompts
prompt_emb, pooled_emb = self.encode_prompt(prompts, batch_size)
null_prompt_emb, null_pooled_emb = self.encode_prompt([""], batch_size)
# initialize latent
src_img = src_img.to(device=self.device, dtype=self.dtype)
with torch.no_grad():
z = self.encode(src_img)
z0 = z.clone()
        # timesteps (default schedule; customize here if needed)
self.scheduler.set_timesteps(NFE, device=self.device)
timesteps = self.scheduler.timesteps
timesteps = torch.cat([timesteps, torch.zeros(1, device=self.device)])
timesteps = reversed(timesteps)
sigmas = timesteps / self.scheduler.config.num_train_timesteps
# Solve ODE
pbar = tqdm(timesteps[:-1], total=NFE, desc='SD3 Euler Inversion')
for i, t in enumerate(pbar):
timestep = t.expand(z.shape[0]).to(self.device)
pred_v = self.predict_vector(z, timestep, prompt_emb, pooled_emb)
if cfg_scale != 1.0:
pred_null_v = self.predict_vector(z, timestep, null_prompt_emb, null_pooled_emb)
else:
pred_null_v = 0.0
sigma = sigmas[i]
sigma_next = sigmas[i+1]
z = z + (sigma_next - sigma) * (pred_null_v + cfg_scale * (pred_v - pred_null_v))
return z
def sample(self, prompts: List[str], NFE:int, img_shape: Optional[Tuple[int]]=None, cfg_scale: float=1.0, batch_size: int = 1, latent:Optional[torch.Tensor]=None):
imgH, imgW = img_shape if img_shape is not None else (512, 512)
# encode text prompts
with torch.no_grad():
prompt_emb, pooled_emb = self.encode_prompt(prompts, batch_size)
null_prompt_emb, null_pooled_emb = self.encode_prompt([""], batch_size)
# initialize latent
if latent is None:
z = self.initialize_latent((imgH, imgW), batch_size)
else:
z = latent
        # timesteps (default schedule; customize here if needed)
self.scheduler.set_timesteps(NFE, device=self.device)
timesteps = self.scheduler.timesteps
sigmas = timesteps / self.scheduler.config.num_train_timesteps
# Solve ODE
pbar = tqdm(timesteps, total=NFE, desc='SD3 Euler')
for i, t in enumerate(pbar):
timestep = t.expand(z.shape[0]).to(self.device)
pred_v = self.predict_vector(z, timestep, prompt_emb, pooled_emb)
if cfg_scale != 1.0:
pred_null_v = self.predict_vector(z, timestep, null_prompt_emb, null_pooled_emb)
else:
pred_null_v = 0.0
sigma = sigmas[i]
sigma_next = sigmas[i+1] if i+1 < NFE else 0.0
z = z + (sigma_next - sigma) * (pred_null_v + cfg_scale * (pred_v - pred_null_v))
# decode
with torch.no_grad():
img = self.decode(z)
return img
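# Usage sketch (hypothetical prompt): plain text-to-image sampling with the
# frozen SD3 backbone, 28 Euler steps, CFG scale 7.
def _example_sd3_euler_sample():
    sd3 = SD3Euler(device='cuda')
    with torch.no_grad():
        img = sd3.sample(["a photo of an astronaut"], NFE=28,
                         img_shape=(512, 512), cfg_scale=7.0)
    return img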
class OSEDiff_SD3_GEN(torch.nn.Module):
def __init__(self, args, base_model):
super().__init__()
self.args = args
self.model = base_model
# Add lora to transformer
print('Adding Lora to OSEDiff_SD3_GEN')
self.transformer_gen = copy.deepcopy(self.model.transformer)
self.transformer_gen.to('cuda')
# self.transformer_gen = self.transformer_gen.float()
self.transformer_gen.requires_grad_(False)
self.transformer_gen.train()
        lora_params, _ = inject_lora(self.transformer_gen, {"AdaLayerNormZero"}, r=args.lora_rank, verbose=True)
        # register the mixed-precision hooks only after the LoRA layers exist,
        # otherwise add_mp_hook finds nothing to wrap
        self.transformer_gen, hooks = add_mp_hook(self.transformer_gen)
        self.hooks = hooks
# self.lora_params = lora_params
for name, param in self.transformer_gen.named_parameters():
if "lora_" in name:
param.requires_grad = True # LoRA up/down
else:
param.requires_grad = False # everything else
# Insert LoRA into VAE
print("Adding Lora to VAE")
self.model.vae, self.lora_vae_modules_encoder = inject_lora_vae(self.model.vae, lora_rank=args.lora_rank, verbose=True)
def predict_vector(self, z, t, prompt_emb, pooled_emb):
v = self.transformer_gen(hidden_states=z,
timestep=t,
pooled_projections=pooled_emb,
encoder_hidden_states=prompt_emb,
return_dict=False)[0]
return v
def forward(self, x_src, batch=None, args=None):
z_src = self.model.encode(x_src.to(dtype=torch.float32, device=self.model.vae.device))
z_src = z_src.to(self.transformer_gen.device)
# calculate prompt_embeddings and neg_prompt_embeddings
batch_size, _, _, _ = x_src.shape
with torch.no_grad():
prompt_embeds, pooled_embeds = self.model.encode_prompt(batch["prompt"], batch_size)
neg_prompt_embeds, neg_pooled_embeds = self.model.encode_prompt(batch["neg_prompt"], batch_size)
NFE = 1
self.model.scheduler.set_timesteps(NFE, device=self.model.device)
timesteps = self.model.scheduler.timesteps
sigmas = timesteps / self.model.scheduler.config.num_train_timesteps
sigmas = sigmas.to(self.transformer_gen.device)
# Solve ODE
i = 0
t = timesteps[0]
timestep = t.expand(z_src.shape[0]).to(self.transformer_gen.device)
prompt_embeds = prompt_embeds.to(self.transformer_gen.device, dtype=torch.float32)
pooled_embeds = pooled_embeds.to(self.transformer_gen.device, dtype=torch.float32)
pred_v = self.predict_vector(z_src, timestep, prompt_embeds, pooled_embeds)
        # cfg_scale is fixed to 1 here, so the guidance formula reduces to pred_v
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1] if i + 1 < NFE else 0.0
        z_src = z_src + (sigma_next - sigma) * pred_v
output_image = self.model.decode(z_src.to(dtype=torch.float32, device=self.model.vae.device))
return output_image, z_src, prompt_embeds, pooled_embeds
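# Training-step sketch (hypothetical names; `batch` carries "prompt" and
# "neg_prompt" keys): one generator forward followed by the distribution
# matching loss defined in OSEDiff_SD3_REG below.
def _example_train_step(gen, reg, x_lr, batch, global_step, args):
    out_img, z0, prompt_embeds, pooled_embeds = gen(x_lr, batch=batch)
    loss_dm = reg.distribution_matching_loss(z0, prompt_embeds, pooled_embeds,
                                             global_step, args)
    return out_img, loss_dm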
class OSEDiff_SD3_REG(torch.nn.Module):
def __init__(self, args, base_model):
super().__init__()
self.args = args
self.model = base_model
self.transformer_org = self.model.transformer
# Add lora to transformer
print('Adding Lora to OSEDiff_SD3_REG')
self.transformer_reg = copy.deepcopy(self.transformer_org)
self.transformer_reg.to('cuda')
self.transformer_reg.requires_grad_(False)
self.transformer_reg.train()
        lora_params, _ = inject_lora(self.transformer_reg, {"AdaLayerNormZero"}, r=args.lora_rank, verbose=True)
        # register the mixed-precision hooks only after the LoRA layers exist
        self.transformer_reg, hooks = add_mp_hook(self.transformer_reg)
        self.hooks = hooks
for name, param in self.transformer_reg.named_parameters():
if "lora_" in name:
param.requires_grad = True # LoRA up/down
else:
param.requires_grad = False # everything else
def predict_vector_reg(self, z, t, prompt_emb, pooled_emb):
v = self.transformer_reg(hidden_states=z,
timestep=t,
pooled_projections=pooled_emb,
encoder_hidden_states=prompt_emb,
return_dict=False)[0]
return v
def predict_vector_org(self, z, t, prompt_emb, pooled_emb):
v = self.transformer_org(hidden_states=z,
timestep=t,
pooled_projections=pooled_emb,
encoder_hidden_states=prompt_emb,
return_dict=False)[0]
return v
def distribution_matching_loss(self, z0, prompt_embeds, pooled_embeds, global_step, args):
with torch.no_grad():
device = self.transformer_reg.device
# get timesteps and sigma
u = compute_density_for_timestep_sampling(
weighting_scheme="uniform",
batch_size=1,
logit_mean=0.0,
logit_std=1.0,
mode_scale=1.29,
)
t_idx = (u*1000).long().to(device)
self.model.scheduler.set_timesteps(1000, device=device)
times = self.model.scheduler.timesteps
t = times[t_idx]
sigma = t / 1000
# get noise and xt
z0 = z0.to(device)
noise = torch.randn_like(z0)
sigma = sigma.half()
zt = (1-sigma) * z0 + sigma * noise
# Get x0_prediction of transformer_reg
v_pred_reg = self.predict_vector_reg(zt, t, prompt_embeds.to(device), pooled_embeds.to(device))
reg_model_pred = v_pred_reg * (-sigma) + zt # this is x0_prediction for reg
# Get x0_prediction of transformer_org
org_device = self.transformer_org.device
v_pred_org = self.predict_vector_org(zt.to(org_device), t.to(org_device), prompt_embeds.to(org_device), pooled_embeds.to(org_device))
org_model_pred = v_pred_org * (-sigma.to(org_device)) + zt.to(org_device) # this is x0_prediction for org
# Visualization
if global_step % 100 == 1:
self.vsd_visualization(z0, noise, zt, reg_model_pred, org_model_pred, global_step, args)
weighting_factor = torch.abs(z0 - org_model_pred.to(device)).mean(dim=[1, 2, 3], keepdim=True)
grad = (reg_model_pred - org_model_pred.to(device)) / weighting_factor
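        # MSE against (z0 - grad).detach() is the usual VSD/DMD construction:
        # the loss value itself is arbitrary, but d(loss)/d(z0) is proportional
        # to `grad`, and no gradient flows through either transformer.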
loss = F.mse_loss(z0, (z0 - grad).detach())
return loss
def vsd_visualization(self, z0, noise, zt, reg_model_pred, org_model_pred, global_step, args):
#-------- Visualization --------#
# 1. Visualize latents, noise, zt
z0_img = self.model.decode(z0.to(dtype=torch.float32, device=self.model.vae.device))
ns_img = self.model.decode(noise.to(dtype=torch.float32, device=self.model.vae.device))
zt_img = self.model.decode(zt.to(dtype=torch.float32, device=self.model.vae.device))
z0_img_pil = transforms.ToPILImage()(torch.clamp(z0_img[0].cpu(), -1.0, 1.0) * 0.5 + 0.5)
ns_img_pil = transforms.ToPILImage()(torch.clamp(ns_img[0].cpu(), -1.0, 1.0) * 0.5 + 0.5)
zt_img_pil = transforms.ToPILImage()(torch.clamp(zt_img[0].cpu(), -1.0, 1.0) * 0.5 + 0.5)
# 2. Visualize reg_img, org_img
reg_img = self.model.decode(reg_model_pred.to(dtype=torch.float32, device=self.model.vae.device))
org_img = self.model.decode(org_model_pred.to(dtype=torch.float32, device=self.model.vae.device))
reg_img_pil = transforms.ToPILImage()(torch.clamp(reg_img[0].cpu(), -1.0, 1.0) * 0.5 + 0.5)
org_img_pil = transforms.ToPILImage()(torch.clamp(org_img[0].cpu(), -1.0, 1.0) * 0.5 + 0.5)
# Concatenate images side by side
w, h = z0_img_pil.width, z0_img_pil.height
combined_image = Image.new('RGB', (w*5, h))
combined_image.paste(z0_img_pil, (0, 0))
combined_image.paste(ns_img_pil, (w, 0))
combined_image.paste(zt_img_pil, (w*2, 0))
combined_image.paste(reg_img_pil, (w*3, 0))
combined_image.paste(org_img_pil, (w*4, 0))
        save_path = os.path.join(args.output_dir, f'visualization/vsd/{global_step}.png')
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        combined_image.save(save_path)
#-------- Visualization --------#
def diff_loss(self, z0, prompt_embeds, pooled_embeds, net_lpips, args):
device = self.transformer_reg.device
u = compute_density_for_timestep_sampling(
weighting_scheme="uniform",
batch_size=1,
logit_mean=0.0,
logit_std=1.0,
mode_scale=1.29,
)
t_idx = (u*1000).long().to(device)
self.model.scheduler.set_timesteps(1000, device=device)
times = self.model.scheduler.timesteps
t = times[t_idx]
sigma = t / 1000
z0 = z0.to(device)
z0, prompt_embeds = z0.detach(), prompt_embeds.detach()
noise = torch.randn_like(z0)
sigma = sigma.half()
zt = (1-sigma) * z0 + sigma * noise # noisy latents
# v-prediction
v_pred = self.predict_vector_reg(zt, t, prompt_embeds.to(device), pooled_embeds.to(device))
model_pred = v_pred * (-sigma) + zt
target = z0
        # "logit_normal" is not handled by compute_loss_weighting_for_sd3, so
        # this falls through to uniform (all-ones) weighting
        loss_weight = compute_loss_weighting_for_sd3("logit_normal", sigma)
diffusion_loss = loss_weight.float() * F.mse_loss(model_pred.float(), target.float())
loss_d = diffusion_loss
return loss_d.mean()
class OSEDiff_SD3_TEST(torch.nn.Module):
def __init__(self, args, base_model):
super().__init__()
self.args = args
self.model = base_model
self.lora_path = args.lora_path
self.vae_path = args.vae_path
# Add lora to transformer
print(f'Loading LoRA to Transformer from {self.lora_path}')
self.model.transformer.requires_grad_(False)
lora_params, _ = inject_lora(self.model.transformer, {"AdaLayerNormZero"}, loras=self.lora_path, r=args.lora_rank, verbose=False)
for name, param in self.model.transformer.named_parameters():
param.requires_grad = False
# Insert LoRA into VAE
print(f"Loading LoRA to VAE from {self.vae_path}")
self.model.vae, self.lora_vae_modules_encoder = inject_lora_vae(self.model.vae, lora_rank=args.lora_rank, verbose=False)
        encoder_state_dict_fp16 = torch.load(self.vae_path, map_location="cpu", weights_only=True)
self.model.vae.encoder.load_state_dict(encoder_state_dict_fp16)
def predict_vector(self, z, t, prompt_emb, pooled_emb):
v = self.model.transformer(hidden_states=z,
timestep=t,
pooled_projections=pooled_emb,
encoder_hidden_states=prompt_emb,
return_dict=False)[0]
return v
@torch.no_grad()
def forward(self, x_src, prompt):
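        # NOTE: unlike self.model.encode(), this applies only scaling_factor
        # (no shift_factor subtraction).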
        z_src = self.model.vae.encode(
            x_src.to(dtype=torch.float32, device=self.model.vae.device)
        ).latent_dist.sample() * self.model.vae.config.scaling_factor
z_src = z_src.to(self.model.transformer.device)
# calculate prompt_embeddings and neg_prompt_embeddings
batch_size, _, _, _ = x_src.shape
with torch.no_grad():
prompt_embeds, pooled_embeds = self.model.encode_prompt([prompt], batch_size)
self.model.scheduler.set_timesteps(1, device=self.model.device)
timesteps = self.model.scheduler.timesteps
# Solve ODE
t = timesteps[0]
timestep = t.expand(z_src.shape[0]).to(self.model.transformer.device)
prompt_embeds = prompt_embeds.to(self.model.transformer.device, dtype=torch.float32)
pooled_embeds = pooled_embeds.to(self.model.transformer.device, dtype=torch.float32)
pred_v = self.predict_vector(z_src, timestep, prompt_embeds, pooled_embeds)
z_src = z_src - pred_v
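        # with a single scheduler step the starting sigma is 1, so the Euler
        # update z + (0 - 1) * v collapses to z - v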
with torch.no_grad():
output_image = self.model.decode(z_src.to(dtype=torch.float32, device=self.model.vae.device))
return output_image
class OSEDiff_SD3_TEST_efficient(torch.nn.Module):
def __init__(self, args, base_model):
super().__init__()
self.args = args
self.model = base_model
self.lora_path = args.lora_path
self.vae_path = args.vae_path
# Add lora to transformer
print(f'Loading LoRA to Transformer from {self.lora_path}')
self.model.transformer.requires_grad_(False)
lora_params, _ = inject_lora(self.model.transformer, {"AdaLayerNormZero"}, loras=self.lora_path, r=args.lora_rank, verbose=False)
for name, param in self.model.transformer.named_parameters():
param.requires_grad = False
# Insert LoRA into VAE
print(f"Loading LoRA to VAE from {self.vae_path}")
self.model.vae, self.lora_vae_modules_encoder = inject_lora_vae(self.model.vae, lora_rank=args.lora_rank, verbose=False)
        encoder_state_dict_fp16 = torch.load(self.vae_path, map_location="cpu", weights_only=True)
self.model.vae.encoder.load_state_dict(encoder_state_dict_fp16)
def predict_vector(self, z, t, prompt_emb, pooled_emb):
v = self.model.transformer(hidden_states=z,
timestep=t,
pooled_projections=pooled_emb,
encoder_hidden_states=prompt_emb,
return_dict=False)[0]
return v
@torch.no_grad()
def forward(self, x_src, prompt):
        z_src = self.model.vae.encode(
            x_src.to(dtype=torch.float32, device=self.model.vae.device)
        ).latent_dist.sample() * self.model.vae.config.scaling_factor
z_src = z_src.to(self.model.transformer.device)
# calculate prompt_embeddings
batch_size, _, _, _ = x_src.shape
prompt_embeds, pooled_embeds = self.model.encode_prompt([prompt], batch_size)
self.model.scheduler.set_timesteps(1, device=self.model.device)
timesteps = self.model.scheduler.timesteps
# Solve ODE
t = timesteps[0]
timestep = t.expand(z_src.shape[0]).to(self.model.transformer.device)
prompt_embeds = prompt_embeds.to(self.model.transformer.device, dtype=torch.float32)
pooled_embeds = pooled_embeds.to(self.model.transformer.device, dtype=torch.float32)
pred_v = self.predict_vector(z_src, timestep, prompt_embeds, pooled_embeds)
z_src = z_src - pred_v
output_image = self.model.decode(z_src.to(dtype=torch.float32, device=self.model.vae.device))
return output_image
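# Inference sketch (hypothetical `args` namespace providing lora_path,
# vae_path and lora_rank; `lr_image` is a (B, 3, H, W) tensor in [-1, 1]).
def _example_test_inference(args, lr_image, prompt="a detailed photo"):
    base = StableDiffusion3Base(device='cuda')
    model = OSEDiff_SD3_TEST(args, base)
    return model(lr_image, prompt)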