how to use Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors

#30
by chenxiYan - opened

Thank you very much for your contribution to the community. I have provided a reproducible script and hope you can help me.

import torch
import PIL.Image
from diffusers import AutoencoderKLWan, WanVACEPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image

vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-VACE-14B-diffusers",
    subfolder="vae",
    torch_dtype=torch.float32,
)
pipe = WanVACEPipeline.from_pretrained(
    "Wan-AI/Wan2.1-VACE-14B-diffusers",
    vae=vae,
    torch_dtype=torch.bfloat16
) 
flow_shift = 5.0  # 5.0 for 720P, 3.0 for 480P
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)

pipe.load_lora_weights(
    "../mymodels",
    weight_name="Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
    adapter_name="wan_lora"
)
pipe.set_adapters(["wan_lora"], adapter_weights=[0.53])
pipe.enable_sequential_cpu_offload()


def prepare_video_and_mask(first_img: PIL.Image.Image, last_img: PIL.Image.Image, height: int, width: int, num_frames: int):
    first_img = first_img.resize((width, height))
    last_img = last_img.resize((width, height))
    frames = []
    frames.append(first_img)
    # Ideally, this should be 127.5 to match original code, but they perform computation on numpy arrays
    # whereas we are passing PIL images. If you choose to pass numpy arrays, you can set it to 127.5 to
    # match the original code.
    frames.extend([PIL.Image.new("RGB", (width, height), (128, 128, 128))] * (num_frames - 2))
    frames.append(last_img)
    mask_black = PIL.Image.new("L", (width, height), 0)
    mask_white = PIL.Image.new("L", (width, height), 255)
    mask = [mask_black, *[mask_white] * (num_frames - 2), mask_black]
    return frames, mask


prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
first_frame = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png"
)
last_frame = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png"
)

height = 512
width = 512
num_frames = 81
video, mask = prepare_video_and_mask(first_frame, last_frame, height, width, num_frames)

output = pipe(
    video=video,
    mask=mask,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    num_inference_steps=6,
    guidance_scale=1.0,
    generator=torch.Generator().manual_seed(42),
).frames[0]
export_to_video(output, "output.mp4", fps=16)

I'm trying the script above, and I end up with an error (screenshots attached).

It's my understanding that we just use the base 14B model and then add this LoRA, but maybe I'm misunderstanding something. Either way, there are so few people actually trying to do this with PyTorch that it's been really hard to learn how this all works. Everyone is using Comfy, Swarm, or Automatic, which is great if you're into workflows, but for people just trying to learn the torch side it's been tough.

Load the LoRA this way:
https://huggingface.co/spaces/rahul7star/wan2-1-fast/blob/main/app.py

Basically, you need to fuse the LoRA when using the Wan image pipeline:

from huggingface_hub import hf_hub_download

# LORA_REPO_ID and LORA_FILENAME are defined in the linked app.py
causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
pipe.fuse_lora()

I am still struggling with a similar thing for other LoRAs; it's just a pipeline tweak, but yeah.
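
In case it helps, here is a minimal sketch of applying the same fuse approach to the WanVACEPipeline setup from the original script. This is an assumption on my part, not code from the linked app: the adapter weight 0.95 comes from the snippet above, and the local LoRA path comes from the original script.

import torch
from diffusers import AutoencoderKLWan, WanVACEPipeline

# Same base model and VAE as the original script
vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-VACE-14B-diffusers", subfolder="vae", torch_dtype=torch.float32
)
pipe = WanVACEPipeline.from_pretrained(
    "Wan-AI/Wan2.1-VACE-14B-diffusers", vae=vae, torch_dtype=torch.bfloat16
)

# Load the CausVid LoRA from the local folder used in the original script,
# then fuse it into the transformer weights instead of keeping it as a live adapter.
pipe.load_lora_weights(
    "../mymodels",
    weight_name="Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
    adapter_name="causvid_lora",
)
pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
pipe.fuse_lora()

After fusing, the pipeline can be called exactly as in the original script (few inference steps, guidance_scale=1.0).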

Sorry, I just saw this. This issue has been fixed in the latest diffusers; you can test it by installing diffusers from GitHub.
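
For reference, installing diffusers from source looks something like this:

pip install git+https://github.com/huggingface/diffusers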

chenxiYan changed discussion status to closed
