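# Captured from a Hugging Face file view: "Update app.py", commit 71fd808 (verified), 9.93 kB.
# Page chrome that was captured along with the file: "tbbl's picture", "raw", "history blame".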
import types
import random
import spaces
import torch
import numpy as np
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler
from diffusers.utils import export_to_video
from diffusers import AutoModel
import gradio as gr
import tempfile
from huggingface_hub import hf_hub_download
from src.pipeline_wan_nag import NAGWanPipeline
from src.transformer_wan_nag import NagWanTransformer3DModel
# --- Output geometry -------------------------------------------------------
# Requested height/width are rounded down to a multiple of MOD_VALUE in
# generate_video, and the UI sliders step by the same amount.
MOD_VALUE = 32
DEFAULT_H_SLIDER_VALUE = 480
DEFAULT_W_SLIDER_VALUE = 832
# Not referenced anywhere in the visible part of this file.
NEW_FORMULA_MAX_AREA = 480.0 * 832.0
SLIDER_MIN_H = 128
SLIDER_MAX_H = 896
SLIDER_MIN_W = 128
SLIDER_MAX_W = 896

# --- Sampling defaults -----------------------------------------------------
DEFAULT_DURATION_SECONDS = 4
DEFAULT_STEPS = 4
DEFAULT_SEED = 2025
# Upper bound for the seed slider and for randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max

# --- Frame count -----------------------------------------------------------
# Frames are derived from duration * FIXED_FPS + 1 and clamped to
# [MIN_FRAMES_MODEL, MAX_FRAMES_MODEL]; FIXED_FPS is also the export rate.
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81

# --- Prompts ---------------------------------------------------------------
DEFAULT_NAG_NEGATIVE_PROMPT = (
    "Static, motionless, still, ugly, bad quality, worst quality, "
    "poorly drawn, low resolution, blurry, lack of details"
)

# --- Hugging Face Hub locations --------------------------------------------
MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
SUB_MODEL_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
SUB_MODEL_FILENAME = "Wan14BT2VFusioniX_fp16_.safetensors"
# The LoRA identifiers are only used by the commented-out loading code.
LORA_REPO_ID = "Kijai/WanVideo_comfy"
LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
# Module-level model setup: everything below runs once at import time.
# Load the VAE from the "vae" subfolder of MODEL_ID in float32 (the rest of
# the pipeline is loaded in bfloat16 below).
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
# Fetch the single-file transformer checkpoint from SUB_MODEL_ID and load it
# into the NAG-aware transformer class.
wan_path = hf_hub_download(repo_id=SUB_MODEL_ID, filename=SUB_MODEL_FILENAME)
transformer = NagWanTransformer3DModel.from_single_file(wan_path, torch_dtype=torch.bfloat16)
# Assemble the pipeline from MODEL_ID, substituting the VAE and transformer
# loaded above for the repository's own components.
pipe = NAGWanPipeline.from_pretrained(
MODEL_ID, vae=vae, transformer=transformer, torch_dtype=torch.bfloat16
)
# Rebuild the scheduler as UniPC from the existing scheduler config, overriding flow_shift.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
pipe.to("cuda")
# Disabled code path kept for reference: loading, re-weighting and fusing the
# CausVid LoRA (LORA_REPO_ID / LORA_FILENAME) into the transformer.
#causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
#pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
#pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
#for name, param in pipe.transformer.named_parameters():
# if "lora_B" in name:
# if "blocks.0" in name:
# param.data = param.data * 0.25
#pipe.fuse_lora()
#pipe.unload_lora_weights()
# Bind NagWanTransformer3DModel's attn_processors, set_attn_processor and
# forward onto the class of pipe.transformer.
# NOTE(review): `transformer` was created via NagWanTransformer3DModel above,
# so these assignments look redundant unless the pipeline (or the disabled
# LoRA path) swaps the transformer's class — confirm before removing.
pipe.transformer.__class__.attn_processors = NagWanTransformer3DModel.attn_processors
pipe.transformer.__class__.set_attn_processor = NagWanTransformer3DModel.set_attn_processor
pipe.transformer.__class__.forward = NagWanTransformer3DModel.forward
# Rows for gr.Examples, one per prompt: [prompt, NAG negative prompt, NAG scale].
# Every row uses the default negative prompt and a NAG scale of 11.
examples = [
    [example_prompt, DEFAULT_NAG_NEGATIVE_PROMPT, 11]
    for example_prompt in (
        "A ginger cat passionately plays eletric guitar with intensity and emotion on a stage. The background is shrouded in deep darkness. Spotlights casts dramatic shadows.",
        "A red vintage Porsche convertible flying over a rugged coastal cliff. Monstrous waves violently crashing against the rocks below. A lighthouse stands tall atop the cliff.",
        "Enormous glowing jellyfish float slowly across a sky filled with soft clouds. Their tentacles shimmer with iridescent light as they drift above a peaceful mountain landscape. Magical and dreamlike, captured in a wide shot. Surreal realism style with detailed textures.",
    )
]
def get_duration(
    prompt,
    nag_negative_prompt, nag_scale,
    height, width, duration_seconds,
    steps,
    seed, randomize_seed,
    compare,
):
    """Estimate the GPU duration budget for one `generate_video` call.

    The signature mirrors `generate_video` because this callable is handed to
    `spaces.GPU(duration=...)`; only `duration_seconds`, `steps` and `compare`
    take part in the estimate, the remaining parameters are ignored.

    Returns:
        A float: ``int(duration_seconds) * int(steps) * 2.25 + 5``, doubled
        when `compare` is truthy (a second, baseline pass is rendered).
    """
    step_seconds = int(duration_seconds) * int(steps)
    estimate = step_seconds * 2.25 + 5
    # The comparison mode runs the pipeline twice, so budget twice the time.
    return estimate * 2 if compare else estimate
@spaces.GPU(duration=get_duration)
def generate_video(
    prompt,
    nag_negative_prompt, nag_scale,
    height=DEFAULT_H_SLIDER_VALUE, width=DEFAULT_W_SLIDER_VALUE, duration_seconds=DEFAULT_DURATION_SECONDS,
    steps=DEFAULT_STEPS,
    seed=DEFAULT_SEED, randomize_seed=False,
    compare=True,
):
    """Render a video with NAG and, optionally, a baseline with the same seed.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        nag_negative_prompt: Negative prompt forwarded to both passes.
        nag_scale: NAG scale, forwarded to the NAG pass only.
        height, width: Requested size; each is rounded down to a multiple of
            MOD_VALUE (with MOD_VALUE as the floor).
        duration_seconds: Clip length; converted to a frame count at
            FIXED_FPS and clamped to [MIN_FRAMES_MODEL, MAX_FRAMES_MODEL].
        steps: Number of inference steps.
        seed: Seed used when `randomize_seed` is false.
        randomize_seed: Draw a seed from [0, MAX_SEED] instead of using `seed`.
        compare: Also render a second pass without the NAG-specific arguments.

    Returns:
        ``(nag_video_path, baseline_video_path, current_seed)`` where
        `baseline_video_path` is ``None`` when `compare` is false.
    """

    def _snap_to_grid(pixels):
        # Round down to a multiple of MOD_VALUE, never below MOD_VALUE itself.
        return max(MOD_VALUE, MOD_VALUE * (int(pixels) // MOD_VALUE))

    def _write_mp4(frames):
        # delete=False keeps the file after the handle closes, so the path can
        # be returned to the caller once the frames have been encoded into it.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as handle:
            video_path = handle.name
        export_to_video(frames, video_path, fps=FIXED_FPS)
        return video_path

    target_h = _snap_to_grid(height)
    target_w = _snap_to_grid(width)

    requested_frames = int(round(int(duration_seconds) * FIXED_FPS) + 1)
    num_frames = np.clip(requested_frames, MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)

    if randomize_seed:
        current_seed = random.randint(0, MAX_SEED)
    else:
        current_seed = int(seed)

    # Arguments common to the NAG pass and the baseline pass.
    shared_args = dict(
        prompt=prompt,
        nag_negative_prompt=nag_negative_prompt,
        height=target_h,
        width=target_w,
        num_frames=num_frames,
        guidance_scale=0.,
        num_inference_steps=int(steps),
    )

    with torch.inference_mode():
        nag_output_frames_list = pipe(
            nag_scale=nag_scale,
            nag_tau=3.5,
            nag_alpha=0.5,
            generator=torch.Generator(device="cuda").manual_seed(current_seed),
            **shared_args,
        ).frames[0]
    nag_video_path = _write_mp4(nag_output_frames_list)

    baseline_video_path = None
    if compare:
        # Same seed and settings, but nag_scale / nag_tau / nag_alpha are left
        # at the pipeline's defaults — presumably disabling NAG; confirm
        # against NAGWanPipeline.
        baseline_output_frames_list = pipe(
            generator=torch.Generator(device="cuda").manual_seed(current_seed),
            **shared_args,
        ).frames[0]
        baseline_video_path = _write_mp4(baseline_output_frames_list)

    if torch.cuda.is_available():
        # Log CUDA memory usage in MiB after generation.
        mib = 1024**2
        print("Allocated:", torch.cuda.memory_allocated() / mib, "MB")
        print("Cached: ", torch.cuda.memory_reserved() / mib, "MB")

    return nag_video_path, baseline_video_path, current_seed
def generate_video_with_example(
    prompt,
    nag_negative_prompt,
    nag_scale,
):
    """Run `generate_video` for an example row using the default settings.

    Only the three example columns are taken from the caller; size, duration,
    steps and seed come from the module defaults, the seed is not randomized
    and the baseline comparison is always rendered.

    Returns:
        An 8-tuple: NAG video path, baseline video path, default height,
        default width, default duration, default steps, the seed used, and
        ``True`` (the value for the "compare" checkbox).
    """
    default_settings = dict(
        height=DEFAULT_H_SLIDER_VALUE,
        width=DEFAULT_W_SLIDER_VALUE,
        duration_seconds=DEFAULT_DURATION_SECONDS,
        steps=DEFAULT_STEPS,
        seed=DEFAULT_SEED,
        randomize_seed=False,
        compare=True,
    )
    nag_video_path, baseline_video_path, seed = generate_video(
        prompt=prompt,
        nag_negative_prompt=nag_negative_prompt,
        nag_scale=nag_scale,
        **default_settings,
    )

    if torch.cuda.is_available():
        # Log CUDA memory usage in MiB.
        mib = 1024**2
        print("Allocated:", torch.cuda.memory_allocated() / mib, "MB")
        print("Cached: ", torch.cuda.memory_reserved() / mib, "MB")

    return (
        nag_video_path,
        baseline_video_path,
        DEFAULT_H_SLIDER_VALUE,
        DEFAULT_W_SLIDER_VALUE,
        DEFAULT_DURATION_SECONDS,
        DEFAULT_STEPS,
        seed,
        True,
    )
# Gradio UI: inputs on the left, the two result videos on the right, cached
# examples underneath, and the click handler wired to generate_video.
with gr.Blocks() as demo:
    gr.Markdown(
        "# Normalized Attention Guidance (NAG) for fast 4 steps Wan2.1-T2V-14B with CausVid LoRA\n"
        "Implementation of [Normalized Attention Guidance](https://chendaryen.github.io/NAG.github.io/).\n"
        "[CausVid](https://github.com/tianweiy/CausVid) is a distilled version of Wan2.1 to run faster in just 4-8 steps, [extracted as LoRA by Kijai](https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_CausVid_14B_T2V_lora_rank32.safetensors).\n"
    )

    with gr.Row():
        # Left column: prompts, NAG controls, advanced settings, run button.
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt",
                max_lines=3,
            )
            nag_negative_prompt = gr.Textbox(
                label="Negative Prompt for NAG",
                max_lines=3,
                value=DEFAULT_NAG_NEGATIVE_PROMPT,
            )
            nag_scale = gr.Slider(
                label="NAG Scale",
                minimum=1.,
                maximum=20.,
                step=0.25,
                value=11.,
            )
            compare = gr.Checkbox(
                label="Compare with baseline",
                info="If unchecked, only sample with NAG will be generated.",
                value=True,
            )

            with gr.Accordion("Advanced Settings", open=False):
                steps_slider = gr.Slider(
                    label="Inference Steps",
                    minimum=1,
                    maximum=8,
                    step=1,
                    value=DEFAULT_STEPS,
                )
                duration_seconds_input = gr.Slider(
                    label="Duration (seconds)",
                    minimum=1,
                    maximum=5,
                    step=1,
                    value=DEFAULT_DURATION_SECONDS,
                )
                seed_input = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=DEFAULT_SEED,
                    interactive=True,
                )
                randomize_seed_checkbox = gr.Checkbox(
                    label="Randomize seed",
                    value=True,
                    interactive=True,
                )
                with gr.Row():
                    height_input = gr.Slider(
                        label=f"Output Height (multiple of {MOD_VALUE})",
                        minimum=SLIDER_MIN_H,
                        maximum=SLIDER_MAX_H,
                        step=MOD_VALUE,
                        value=DEFAULT_H_SLIDER_VALUE,
                    )
                    width_input = gr.Slider(
                        label=f"Output Width (multiple of {MOD_VALUE})",
                        minimum=SLIDER_MIN_W,
                        maximum=SLIDER_MAX_W,
                        step=MOD_VALUE,
                        value=DEFAULT_W_SLIDER_VALUE,
                    )

            generate_button = gr.Button("Generate Video", variant="primary")

        # Right column: NAG result above the optional baseline result.
        with gr.Column():
            nag_video_output = gr.Video(
                label="Video with NAG",
                autoplay=True,
                interactive=False,
            )
            baseline_video_output = gr.Video(
                label="Baseline Video without NAG",
                autoplay=True,
                interactive=False,
            )

    # Example rows supply only prompt / negative prompt / NAG scale; the
    # handler returns the videos plus the settings it used so the controls
    # listed in `outputs` are updated to match.
    gr.Examples(
        examples=examples,
        fn=generate_video_with_example,
        inputs=[prompt, nag_negative_prompt, nag_scale],
        outputs=[
            nag_video_output,
            baseline_video_output,
            height_input,
            width_input,
            duration_seconds_input,
            steps_slider,
            seed_input,
            compare,
        ],
        cache_examples="lazy",
    )

    # Order must match generate_video's positional parameters.
    ui_inputs = [
        prompt,
        nag_negative_prompt,
        nag_scale,
        height_input,
        width_input,
        duration_seconds_input,
        steps_slider,
        seed_input,
        randomize_seed_checkbox,
        compare,
    ]
    # The returned seed is written back to the seed slider.
    generate_button.click(
        fn=generate_video,
        inputs=ui_inputs,
        outputs=[nag_video_output, baseline_video_output, seed_input],
    )

if __name__ == "__main__":
    demo.queue().launch()