import tempfile
import torch
import gradio as gr
import spaces
from diffusers.utils import export_to_video
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
# Define model options
MODEL_OPTIONS = {
"Wan2.1-T2V-1.3B": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
"Wan2.1-T2V-14B": "Wan-AI/Wan2.1-T2V-14B-Diffusers"
}
# Define scheduler options
SCHEDULER_OPTIONS = {
"UniPCMultistepScheduler": UniPCMultistepScheduler,
"FlowMatchEulerDiscreteScheduler": FlowMatchEulerDiscreteScheduler
}
@spaces.GPU(duration=300) # Set a 5-minute duration for the GPU access
def generate_video(
model_choice,
prompt,
negative_prompt,
lora_id,
lora_scale,
scheduler_type,
flow_shift,
height,
width,
num_frames,
guidance_scale,
num_inference_steps,
output_fps
):
"""Generate a video using the Wan model and provided parameters"""
try:
# Get model ID from selection
model_id = MODEL_OPTIONS[model_choice]
# Load the model components
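        # The Wan VAE is kept in float32 for numerical stability while the rest of
        # the pipeline runs in bfloat16, following the diffusers Wan example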
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
# Set the scheduler
scheduler_class = SCHEDULER_OPTIONS[scheduler_type]
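        # flow_shift shifts the flow-matching timestep schedule; lower values tend to
        # suit smaller resolutions and higher values larger ones (see the UI hints below)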
if scheduler_type == "UniPCMultistepScheduler":
pipe.scheduler = scheduler_class.from_config(
pipe.scheduler.config,
prediction_type="flow_prediction",
use_flow_sigmas=True,
flow_shift=flow_shift
)
else:
pipe.scheduler = scheduler_class(shift=flow_shift)
        # Enable model CPU offload to reduce peak VRAM usage; offloading manages
        # device placement itself, so an explicit pipe.to("cuda") is not needed
        pipe.enable_model_cpu_offload()
# Load and fuse LoRA if provided
if lora_id and lora_id.strip():
try:
# Load the LoRA weights
pipe.load_lora_weights(lora_id)
# Fuse LoRA with specified scale if available
if hasattr(pipe, "fuse_lora"):
pipe.fuse_lora(lora_scale=lora_scale)
        except Exception as e:
            # Surface LoRA problems in the UI rather than returning a non-path string
            raise gr.Error(f"Error loading/fusing LoRA: {e}")
# Generate the video
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps
).frames[0]
        # Export frames to a uniquely named MP4 so concurrent requests don't collide
        temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
        export_to_video(output, temp_file, fps=output_fps)
        return temp_file
    except Exception as e:
        # Raise a Gradio error so the failure is shown in the UI instead of being
        # passed to the Video component as a fake file path
        raise gr.Error(f"Error generating video: {e}")
# Create the Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# Wan Video Generation with ZeroGPU")
gr.Markdown("Generate high-quality videos using the Wan model with optional LoRA adaptations.")
with gr.Row():
with gr.Column(scale=1):
model_choice = gr.Dropdown(
choices=list(MODEL_OPTIONS.keys()),
value="Wan2.1-T2V-1.3B",
label="Model"
)
prompt = gr.Textbox(
label="Prompt",
value="steamboat willie style, golden era animation, an anthropomorphic cat character wearing a hat removes it and performs a courteous bow",
lines=3
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
lines=3
)
with gr.Row():
lora_id = gr.Textbox(
label="LoRA ID (e.g., benjamin-paine/steamboat-willie-1.3b)",
value="benjamin-paine/steamboat-willie-1.3b"
)
lora_scale = gr.Slider(
label="LoRA Scale",
minimum=0.0,
maximum=1.0,
value=0.75,
step=0.05
)
with gr.Row():
scheduler_type = gr.Dropdown(
choices=list(SCHEDULER_OPTIONS.keys()),
value="UniPCMultistepScheduler",
label="Scheduler"
)
flow_shift = gr.Slider(
label="Flow Shift",
minimum=1.0,
maximum=12.0,
value=3.0,
step=0.5,
info="2.0-5.0 for smaller videos, 7.0-12.0 for larger videos"
)
with gr.Row():
height = gr.Slider(
label="Height",
minimum=256,
maximum=1024,
value=480,
step=32
)
width = gr.Slider(
label="Width",
minimum=256,
maximum=1792,
value=832,
step=32
)
with gr.Row():
num_frames = gr.Slider(
label="Number of Frames (4k+1 is recommended, e.g. 81)",
minimum=17,
maximum=129,
value=81,
step=4
)
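            # Wan 2.1 is nominally a 16 fps model; changing the export fps only alters
            # playback speed, it does not add or drop generated frames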
output_fps = gr.Slider(
label="Output FPS",
minimum=8,
maximum=30,
value=16,
step=1
)
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance Scale (CFG)",
minimum=1.0,
maximum=15.0,
value=5.0,
step=0.5
)
num_inference_steps = gr.Slider(
label="Inference Steps",
minimum=10,
maximum=100,
value=32,
step=1
)
generate_btn = gr.Button("Generate Video")
with gr.Column(scale=1):
output_video = gr.Video(label="Generated Video")
generate_btn.click(
fn=generate_video,
inputs=[
model_choice,
prompt,
negative_prompt,
lora_id,
lora_scale,
scheduler_type,
flow_shift,
height,
width,
num_frames,
guidance_scale,
num_inference_steps,
output_fps
],
outputs=output_video
)
gr.Markdown("""
## Tips for best results:
- For smaller resolution videos, try lower values of flow shift (2.0-5.0)
- For larger resolution videos, try higher values of flow shift (7.0-12.0)
    - Number of frames should be of the form 4k+1 (e.g., 49, 65, 81)
- The model is memory intensive, so adjust resolution according to available VRAM
- LoRA ID should be a Hugging Face repository containing safetensors files
""")
demo.launch()