# Wan-fusionX / app.py
# Hugging Face Space by rahul7star; last commit ca82fa6 "Update app.py" (verified), 8.97 kB.
import types
import random
import spaces
import torch
import numpy as np
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler
from diffusers.utils import export_to_video
from diffusers import AutoModel
import gradio as gr
import tempfile
from huggingface_hub import hf_hub_download
from src.pipeline_wan_nag import NAGWanPipeline
from src.transformer_wan_nag import NagWanTransformer3DModel
# --- Output geometry ---------------------------------------------------------
# Requested height/width are floored to a multiple of MOD_VALUE before sampling.
MOD_VALUE = 32
DEFAULT_H_SLIDER_VALUE = 480
DEFAULT_W_SLIDER_VALUE = 832
NEW_FORMULA_MAX_AREA = 480.0 * 832.0  # reference pixel area (480 x 832)
SLIDER_MIN_H = 128
SLIDER_MAX_H = 896
SLIDER_MIN_W = 128
SLIDER_MAX_W = 896

# --- Sampling / timing defaults ----------------------------------------------
DEFAULT_DURATION_SECONDS = 4
DEFAULT_STEPS = 4
DEFAULT_SEED = 2025
MAX_SEED = np.iinfo(np.int32).max  # upper bound for the seed slider and randomizer
FIXED_FPS = 16
# Frame-count bounds applied after converting seconds to frames.
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81

# Default negative prompt fed to Normalized Attention Guidance.
DEFAULT_NAG_NEGATIVE_PROMPT = (
    "Static, motionless, still, ugly, bad quality, worst quality, "
    "poorly drawn, low resolution, blurry, lack of details"
)

# --- Checkpoints ---------------------------------------------------------------
MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"  # base Diffusers repo (VAE + pipeline components)
SUB_MODEL_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
SUB_MODEL_FILENAME = "Wan14BT2VFusioniX_fp16_.safetensors"  # single-file transformer weights

# LoRA kept for a future experiment. The earlier (commented-out) choice was
# Kijai/WanVideo_comfy : Wan21_CausVid_14B_T2V_lora_rank32.safetensors.
LORA_REPO_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
LORA_FILENAME = "FusionX_LoRa/Wan2.1_T2V_14B_FusionX_LoRA.safetensors"
# Build the generation pipeline once at import time; generate_video reuses `pipe`
# for every request. The VAE is loaded from the base repo's "vae" subfolder in
# float32, while the transformer and the rest of the pipeline use bfloat16.
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
# Download the single-file FusioniX checkpoint and load it as the NAG-aware transformer.
wan_path = hf_hub_download(repo_id=SUB_MODEL_ID, filename=SUB_MODEL_FILENAME)
transformer = NagWanTransformer3DModel.from_single_file(wan_path, torch_dtype=torch.bfloat16)
# Assemble the pipeline from the base repo, overriding its VAE and transformer.
pipe = NAGWanPipeline.from_pretrained(
MODEL_ID, vae=vae, transformer=transformer, torch_dtype=torch.bfloat16
)
# Replace the scheduler with UniPC, reusing the existing scheduler config but with flow_shift=5.0.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
# NOTE(review): this moves the whole pipeline to CUDA at import time, so a GPU
# (or the hosting runtime's GPU shim) must be available when the module loads.
pipe.to("cuda")
# Re-bind the transformer class's attention-processor accessors and forward()
# to the NagWanTransformer3DModel implementations.
# NOTE(review): `transformer` was created with NagWanTransformer3DModel.from_single_file,
# so pipe.transformer.__class__ is presumably already NagWanTransformer3DModel and these
# three assignments are self-assignments. They are likely left over from a variant that
# loaded a stock Wan transformer; confirm before removing.
pipe.transformer.__class__.attn_processors = NagWanTransformer3DModel.attn_processors
pipe.transformer.__class__.set_attn_processor = NagWanTransformer3DModel.set_attn_processor
pipe.transformer.__class__.forward = NagWanTransformer3DModel.forward
# Gallery rows for gr.Examples, ordered like its inputs:
# [prompt, NAG negative prompt, NAG scale]. Every row uses the default
# negative prompt and a NAG scale of 11.
examples = [
    [example_prompt, DEFAULT_NAG_NEGATIVE_PROMPT, 11]
    for example_prompt in (
        "A ginger cat passionately plays eletric guitar with intensity and emotion on a stage. The background is shrouded in deep darkness. Spotlights casts dramatic shadows.",
        "A red vintage Porsche convertible flying over a rugged coastal cliff. Monstrous waves violently crashing against the rocks below. A lighthouse stands tall atop the cliff.",
        "Enormous glowing jellyfish float slowly across a sky filled with soft clouds. Their tentacles shimmer with iridescent light as they drift above a peaceful mountain landscape. Magical and dreamlike, captured in a wide shot. Surreal realism style with detailed textures.",
    )
]
def get_duration(
    prompt,
    nag_negative_prompt, nag_scale,
    height, width, duration_seconds,
    steps,
    seed, randomize_seed,
    compare,
):
    """Return the GPU time budget, in seconds, for one ``generate_video`` call.

    Used as ``spaces.GPU(duration=...)`` on ``generate_video``; its signature
    mirrors that function so it can receive the same arguments. Only
    ``duration_seconds``, ``steps`` and ``compare`` influence the estimate:
    each rendered video is budgeted ``duration_seconds * steps * 2.25 + 5``
    seconds, and ``compare`` doubles that because a baseline is rendered too.
    """
    budget_per_video = int(duration_seconds) * int(steps) * 2.25 + 5
    videos_rendered = 2 if compare else 1
    return budget_per_video * videos_rendered
def _frames_to_temp_mp4(frames):
    """Encode *frames* into a new temporary ``.mp4`` file and return its path.

    The temp file is created with ``delete=False`` and closed before
    ``export_to_video`` writes to it, so the path remains valid after this
    function returns and can be handed to a ``gr.Video`` output.
    """
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    export_to_video(frames, video_path, fps=FIXED_FPS)
    return video_path


@spaces.GPU(duration=get_duration)
def generate_video(
    prompt,
    nag_negative_prompt, nag_scale,
    height=DEFAULT_H_SLIDER_VALUE, width=DEFAULT_W_SLIDER_VALUE, duration_seconds=DEFAULT_DURATION_SECONDS,
    steps=DEFAULT_STEPS,
    seed=DEFAULT_SEED, randomize_seed=False,
    compare=True,
):
    """Generate a NAG-guided video and, optionally, a same-seed baseline.

    Args:
        prompt: Text prompt used for both runs.
        nag_negative_prompt: Negative prompt passed to the pipeline in both runs.
        nag_scale: NAG scale for the guided run; the baseline run omits it.
        height, width: Requested size, floored to a multiple of ``MOD_VALUE``
            and never smaller than ``MOD_VALUE``.
        duration_seconds: Clip length, converted to
            ``int(duration_seconds) * FIXED_FPS + 1`` frames and clipped to
            ``[MIN_FRAMES_MODEL, MAX_FRAMES_MODEL]``.
        steps: Number of inference steps.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: Draw a fresh seed from ``[0, MAX_SEED]`` instead.
        compare: Also render the baseline run with the same seed.

    Returns:
        ``(nag_video_path, baseline_video_path, seed_used)``, where
        ``baseline_video_path`` is ``None`` when ``compare`` is false.
    """
    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
    # np.clip yields a NumPy scalar; hand the pipeline a plain Python int.
    num_frames = int(np.clip(int(duration_seconds) * FIXED_FPS + 1, MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    # Arguments shared by both runs, so they differ only in the NAG knobs.
    shared_kwargs = dict(
        prompt=prompt,
        nag_negative_prompt=nag_negative_prompt,
        height=target_h, width=target_w, num_frames=num_frames,
        guidance_scale=0.,
        num_inference_steps=int(steps),
    )

    with torch.inference_mode():
        nag_output_frames_list = pipe(
            **shared_kwargs,
            nag_scale=nag_scale,
            nag_tau=3.5,
            nag_alpha=0.5,
            generator=torch.Generator(device="cuda").manual_seed(current_seed),
        ).frames[0]
    nag_video_path = _frames_to_temp_mp4(nag_output_frames_list)

    baseline_video_path = None
    if compare:
        # Same seed and settings, but nag_scale/nag_tau/nag_alpha are left at the
        # pipeline's defaults. NOTE(review): assumed to disable NAG; confirm in
        # src/pipeline_wan_nag. Run under inference_mode like the NAG pass.
        with torch.inference_mode():
            baseline_output_frames_list = pipe(
                **shared_kwargs,
                generator=torch.Generator(device="cuda").manual_seed(current_seed),
            ).frames[0]
        baseline_video_path = _frames_to_temp_mp4(baseline_output_frames_list)

    return nag_video_path, baseline_video_path, current_seed
def generate_video_with_example(
    prompt,
    nag_negative_prompt,
    nag_scale,
):
    """Render one gallery example using the default settings and a fixed seed.

    Returns the NAG and baseline video paths, followed by the values to write
    back into the height, width, duration, steps, seed and compare controls.
    """
    nag_path, baseline_path, used_seed = generate_video(
        prompt=prompt,
        nag_negative_prompt=nag_negative_prompt,
        nag_scale=nag_scale,
        height=DEFAULT_H_SLIDER_VALUE,
        width=DEFAULT_W_SLIDER_VALUE,
        duration_seconds=DEFAULT_DURATION_SECONDS,
        steps=DEFAULT_STEPS,
        seed=DEFAULT_SEED,
        randomize_seed=False,
        compare=True,
    )
    control_values = (
        DEFAULT_H_SLIDER_VALUE,
        DEFAULT_W_SLIDER_VALUE,
        DEFAULT_DURATION_SECONDS,
        DEFAULT_STEPS,
        used_seed,
        True,
    )
    return (nag_path, baseline_path, *control_values)
# Gradio interface bound to `demo`.
# NOTE(review): this copy of the file lost its indentation, so which components
# sit inside each Row / Column / Accordion cannot be read from here. Restore the
# nesting from the upstream file before running.
with gr.Blocks() as demo:
gr.Markdown('''# Normalized Attention Guidance (NAG) for fast 4 steps Wan2.1-T2V-14B with Wan14BT2VFusioniX
''')
with gr.Row():
with gr.Column():
# Prompt inputs and the main NAG controls.
prompt = gr.Textbox(
label="Prompt",
max_lines=3,
placeholder="Enter your prompt",
)
nag_negative_prompt = gr.Textbox(
label="Negative Prompt for NAG",
value=DEFAULT_NAG_NEGATIVE_PROMPT,
max_lines=3,
)
nag_scale = gr.Slider(label="NAG Scale", minimum=1., maximum=20., step=0.25, value=11.)
compare = gr.Checkbox(
label="Compare with baseline",
info="If unchecked, only sample with NAG will be generated.", value=True,
)
with gr.Accordion("Advanced Settings", open=False):
steps_slider = gr.Slider(minimum=1, maximum=8, step=1, value=DEFAULT_STEPS, label="Inference Steps")
duration_seconds_input = gr.Slider(
minimum=1, maximum=5, step=1, value=DEFAULT_DURATION_SECONDS,
label="Duration (seconds)",
)
seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=DEFAULT_SEED, interactive=True)
# The UI defaults to a random seed, unlike generate_video's randomize_seed=False default.
randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
with gr.Row():
# Sliders step by MOD_VALUE; generate_video floors to the same multiple anyway.
height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE,
value=DEFAULT_H_SLIDER_VALUE,
label=f"Output Height (multiple of {MOD_VALUE})")
width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE,
value=DEFAULT_W_SLIDER_VALUE,
label=f"Output Width (multiple of {MOD_VALUE})")
generate_button = gr.Button("Generate Video", variant="primary")
with gr.Column():
# Output players: NAG result and the optional baseline.
nag_video_output = gr.Video(label="Video with NAG", autoplay=True, interactive=False)
baseline_video_output = gr.Video(label="Baseline Video without NAG", autoplay=True, interactive=False)
# Example gallery. `outputs` is ordered to match the 8-tuple returned by
# generate_video_with_example (two video paths, then height, width, duration,
# steps, seed, compare).
# NOTE(review): per the Gradio docs, cache_examples="lazy" generates and caches
# each example on first use rather than at startup. Confirm for the installed version.
gr.Examples(
examples=examples,
fn=generate_video_with_example,
inputs=[prompt, nag_negative_prompt, nag_scale],
outputs=[
nag_video_output, baseline_video_output,
height_input, width_input, duration_seconds_input,
steps_slider,
seed_input,
compare,
],
cache_examples="lazy"
)
# Passed positionally, so this order must match generate_video's parameters.
ui_inputs = [
prompt,
nag_negative_prompt, nag_scale,
height_input, width_input, duration_seconds_input,
steps_slider,
seed_input, randomize_seed_checkbox,
compare,
]
# generate_video's third return value (the seed actually used) is written back to seed_input.
generate_button.click(
fn=generate_video,
inputs=ui_inputs,
outputs=[nag_video_output, baseline_video_output, seed_input],
)
if __name__ == "__main__":
    # Turn on Gradio's queue for `demo`, then start the server.
    queued_app = demo.queue()
    queued_app.launch()