import torch
import gradio as gr
import spaces
from diffusers.utils import export_to_video
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
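
# Diffusers-format Wan 2.1 text-to-video checkpoints on the Hugging Face Hub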
MODEL_OPTIONS = {
    "Wan2.1-T2V-1.3B": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    "Wan2.1-T2V-14B": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
}
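
# Samplers offered in the UI; see generate_video for how each one is configured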
SCHEDULER_OPTIONS = {
    "UniPCMultistepScheduler": UniPCMultistepScheduler,
    "FlowMatchEulerDiscreteScheduler": FlowMatchEulerDiscreteScheduler,
}
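

# ZeroGPU: request a GPU slot for up to 300 seconds per generation call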
@spaces.GPU(duration=300)
def generate_video(
    model_choice,
    prompt,
    negative_prompt,
    lora_id,
    lora_scale,
    scheduler_type,
    flow_shift,
    height,
    width,
    num_frames,
    guidance_scale,
    num_inference_steps,
    output_fps,
):
    """Generate a video using the Wan model and the provided parameters."""
    try:
        model_id = MODEL_OPTIONS[model_choice]
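
        # The Wan VAE is kept in float32 for decoding stability; the rest of
        # the pipeline runs in bfloat16 to reduce memory use.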
        vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
        pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
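
        # Both samplers take the flow shift: UniPC as flow_shift alongside
        # flow-prediction settings, FlowMatchEuler directly as shift.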
        scheduler_class = SCHEDULER_OPTIONS[scheduler_type]
        if scheduler_type == "UniPCMultistepScheduler":
            pipe.scheduler = scheduler_class.from_config(
                pipe.scheduler.config,
                prediction_type="flow_prediction",
                use_flow_sigmas=True,
                flow_shift=flow_shift,
            )
        else:
            # from_config keeps the pipeline's remaining scheduler settings
            # instead of instantiating the scheduler from bare defaults
            pipe.scheduler = scheduler_class.from_config(pipe.scheduler.config, shift=flow_shift)

        # Model CPU offload moves each submodule to the GPU only while it runs,
        # keeping peak VRAM low (important for the 14B checkpoint); it manages
        # device placement itself, so no explicit pipe.to("cuda") is needed.
        pipe.enable_model_cpu_offload()
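
        # Optionally load a LoRA adapter from the Hub and fuse it at the given scale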
        if lora_id and lora_id.strip():
            try:
                pipe.load_lora_weights(lora_id)
                if hasattr(pipe, "fuse_lora"):
                    pipe.fuse_lora(lora_scale=lora_scale)
            except Exception as e:
                # Raise instead of returning a string: the output component is
                # a gr.Video, which cannot display an error message
                raise gr.Error(f"Error loading/fusing LoRA: {e}")
        output = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        ).frames[0]
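
        # Encode the frames to MP4 at the requested frame rate; Gradio serves
        # the file from this path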
        temp_file = "output.mp4"
        export_to_video(output, temp_file, fps=output_fps)
        return temp_file
    except gr.Error:
        # Let UI errors (e.g., the LoRA failure above) propagate unchanged
        raise
    except Exception as e:
        raise gr.Error(f"Error generating video: {e}")
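

# UI: parameter controls in the left column, the rendered video on the right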
with gr.Blocks() as demo:
    gr.Markdown("# Wan Video Generation with ZeroGPU")
    gr.Markdown("Generate high-quality videos using the Wan model with optional LoRA adaptations.")

    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                choices=list(MODEL_OPTIONS.keys()),
                value="Wan2.1-T2V-1.3B",
                label="Model",
            )

            prompt = gr.Textbox(
                label="Prompt",
                value="steamboat willie style, golden era animation, an anthropomorphic cat character wearing a hat removes it and performs a courteous bow",
                lines=3,
            )

            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
                lines=3,
            )

            with gr.Row():
                lora_id = gr.Textbox(
                    label="LoRA ID (e.g., benjamin-paine/steamboat-willie-1.3b)",
                    value="benjamin-paine/steamboat-willie-1.3b",
                )
                lora_scale = gr.Slider(
                    label="LoRA Scale",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.75,
                    step=0.05,
                )

            with gr.Row():
                scheduler_type = gr.Dropdown(
                    choices=list(SCHEDULER_OPTIONS.keys()),
                    value="UniPCMultistepScheduler",
                    label="Scheduler",
                )
                flow_shift = gr.Slider(
                    label="Flow Shift",
                    minimum=1.0,
                    maximum=12.0,
                    value=3.0,
                    step=0.5,
                    info="2.0-5.0 for smaller videos, 7.0-12.0 for larger videos",
                )

            with gr.Row():
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=1024,
                    value=480,
                    step=32,
                )
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=1792,
                    value=832,
                    step=32,
                )

            with gr.Row():
                num_frames = gr.Slider(
                    label="Number of Frames (4k+1 is recommended, e.g. 81)",
                    minimum=17,
                    maximum=129,
                    value=81,
                    step=4,
                )
                output_fps = gr.Slider(
                    label="Output FPS",
                    minimum=8,
                    maximum=30,
                    value=16,
                    step=1,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale (CFG)",
                    minimum=1.0,
                    maximum=15.0,
                    value=5.0,
                    step=0.5,
                )
                num_inference_steps = gr.Slider(
                    label="Inference Steps",
                    minimum=10,
                    maximum=100,
                    value=32,
                    step=1,
                )

            generate_btn = gr.Button("Generate Video")

        with gr.Column(scale=1):
            output_video = gr.Video(label="Generated Video")
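
    # The inputs list must match generate_video's parameter order exactly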
    generate_btn.click(
        fn=generate_video,
        inputs=[
            model_choice,
            prompt,
            negative_prompt,
            lora_id,
            lora_scale,
            scheduler_type,
            flow_shift,
            height,
            width,
            num_frames,
            guidance_scale,
            num_inference_steps,
            output_fps,
        ],
        outputs=output_video,
    )

    gr.Markdown("""
    ## Tips for best results:
    - For smaller resolution videos, try lower values of flow shift (2.0-5.0)
    - For larger resolution videos, try higher values of flow shift (7.0-12.0)
    - Number of frames should be of the form 4k+1 (e.g., 49, 65, 81)
    - The model is memory intensive, so adjust resolution according to available VRAM
    - LoRA ID should be a Hugging Face repository containing safetensors files
    """)

demo.launch()