|
import os
import random
import tempfile

import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from lycoris import create_lycoris_from_weights
|
|
|
|
|
# Display name shown in the "Model" dropdown -> Hugging Face Hub repo id that
# generate_video passes to ``from_pretrained``. Insertion order is the order
# the choices appear in the dropdown.
MODEL_OPTIONS = dict(
    [
        ("Wan2.1-T2V-1.3B", "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"),
        ("Wan2.1-T2V-14B", "Wan-AI/Wan2.1-T2V-14B-Diffusers"),
        ("Wan2.1-Fun-Reward-1.3B", "alibaba-pai/Wan2.1-Fun-1.3B-InP"),
    ]
)
|
|
|
|
|
# Scheduler name shown in the "Scheduler" dropdown -> scheduler class.
# In this file only the keys are consumed (as dropdown choices); generate_video
# selects the scheduler by comparing the chosen name as a string.
SCHEDULER_OPTIONS = dict(
    UniPCMultistepScheduler=UniPCMultistepScheduler,
    FlowMatchEulerDiscreteScheduler=FlowMatchEulerDiscreteScheduler,
)
|
|
|
def download_adapter(repo_id, weight_name=None):
    """Download one adapter weight file from the Hugging Face Hub.

    The file is placed in ``<cache_dir>/<cleaned repo id>``. ``cache_dir`` is
    the ``HF_PATH`` environment variable, or
    ``~/.cache/huggingface/hub/models`` when that variable is unset. The
    cleaned repo id has ``/``, ``\\`` and ``:`` replaced by ``_``. The target
    directory is created if it does not exist.

    Args:
        repo_id: Hub repository id, e.g. ``"owner/name"``.
        weight_name: File name inside the repository. When ``None`` or empty,
            ``"pytorch_lora_weights.safetensors"`` is used.

    Returns:
        The local file path returned by ``hf_hub_download``.

    Raises:
        ValueError: If the download fails for any reason; the original
            exception is chained as ``__cause__``. When no explicit weight
            name was given, the message starts with
            "Could not download default adapter file", which callers may
            match on.
    """
    adapter_filename = weight_name or "pytorch_lora_weights.safetensors"

    cache_dir = os.environ.get("HF_PATH", os.path.expanduser("~/.cache/huggingface/hub/models"))
    # Flatten the repo id into a single, filesystem-safe directory name.
    cleaned_adapter_path = repo_id.replace("/", "_").replace("\\", "_").replace(":", "_")
    path_to_adapter = os.path.join(cache_dir, cleaned_adapter_path)
    os.makedirs(path_to_adapter, exist_ok=True)

    try:
        return hf_hub_download(
            repo_id=repo_id,
            filename=adapter_filename,
            local_dir=path_to_adapter,
        )
    except Exception as e:
        # Deliberately broad: the Hub client can fail with network, auth and
        # not-found errors of several types; normalise them all to ValueError
        # and keep the original as the cause for debugging.
        # Use the same falsy test as the filename fallback above so that an
        # empty weight_name reports the "default adapter file" message.
        if not weight_name:
            raise ValueError(
                f"Could not download default adapter file: {e}\nPlease specify the exact weight file name."
            ) from e
        raise ValueError(f"Could not download adapter file {weight_name}: {e}") from e
|
|
|
@spaces.GPU(duration=140)
def generate_video(
    model_choice,
    prompt,
    negative_prompt,
    lycoris_id,
    lycoris_weight_name,
    lycoris_scale,
    scheduler_type,
    flow_shift,
    height,
    width,
    num_frames,
    guidance_scale,
    num_inference_steps,
    output_fps,
    seed,
):
    """Generate a video with a Wan 2.1 pipeline, optionally merging a LyCORIS adapter.

    A fresh VAE (float32) and pipeline (float16) are loaded on every call, so
    adapter weights merged during a previous call are never reused.

    Args:
        model_choice: Key of ``MODEL_OPTIONS`` selecting the checkpoint repo.
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        lycoris_id: Hub repo id of the adapter; blank/whitespace skips the adapter.
        lycoris_weight_name: File inside ``lycoris_id``; blank/whitespace lets
            ``download_adapter`` use its default file name.
        lycoris_scale: Multiplier passed to ``create_lycoris_from_weights``.
        scheduler_type: ``"UniPCMultistepScheduler"``; any other value selects
            ``FlowMatchEulerDiscreteScheduler``.
        flow_shift: Shift value passed to the selected scheduler.
        height: Output height in pixels (cast to int).
        width: Output width in pixels (cast to int).
        num_frames: Number of frames to generate (cast to int).
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Number of denoising steps (cast to int).
        output_fps: Frame rate of the exported MP4.
        seed: ``-1``, ``None`` or ``""`` draws a random seed; anything else is
            cast to int.

    Returns:
        ``(video_path, seed)``: path of a newly created ``.mp4`` file and the
        seed actually used.

    Raises:
        gr.Error: If the adapter cannot be downloaded or loaded, so Gradio
            shows the message instead of trying to load it as a video file.
    """
    model_id = MODEL_OPTIONS[model_choice]

    if seed == -1 or seed is None or seed == "":
        seed = random.randint(0, 2147483647)
    else:
        seed = int(seed)

    torch.manual_seed(seed)

    vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
    pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.float16)

    if scheduler_type == "UniPCMultistepScheduler":
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
    else:
        pipe.scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)

    # NOTE(review): enable_model_cpu_offload() below manages device placement
    # itself; this move mainly makes the adapter merge run on the GPU. Confirm
    # it is still wanted.
    pipe.to("cuda")

    # Strip user input before use: a trailing space in a repo id or file name
    # would otherwise make the Hub download fail.
    repo_id = (lycoris_id or "").strip()
    if repo_id:
        weight_name = (lycoris_weight_name or "").strip() or None
        try:
            adapter_file_path = download_adapter(repo_id=repo_id, weight_name=weight_name)
            wrapper, *_ = create_lycoris_from_weights(lycoris_scale, adapter_file_path, pipe.transformer)
            wrapper.merge_to()
        except ValueError as e:
            # Raise rather than return: the first output is a gr.Video, which
            # would treat a returned message as a (non-existent) file path.
            if "more than one weights file" in str(e) or "Could not download default adapter file" in str(e):
                raise gr.Error(
                    f"Error: The repository '{repo_id}' may contain multiple weight files. Please specify a weight name."
                ) from e
            raise gr.Error(f"Error loading LyCORIS weights: {e}") from e

    pipe.enable_model_cpu_offload()

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=int(height),
        width=int(width),
        num_frames=int(num_frames),
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        generator=torch.Generator("cuda").manual_seed(seed),
    ).frames[0]

    # A unique file per request: a fixed name would let concurrent requests
    # overwrite each other's video. The file is left on disk for Gradio to serve.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        temp_file = tmp.name
    export_to_video(output, temp_file, fps=output_fps)

    return temp_file, seed
|
|
|
|
|
# Gradio UI: the left column holds every generation input, the right column
# shows the resulting video and the seed that was actually used.
with gr.Blocks() as demo:
    gr.Markdown("# Wan 2.1 T2V with Custom LoRA")

    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                choices=list(MODEL_OPTIONS.keys()),
                value="Wan2.1-Fun-Reward-1.3B",
                label="Model",
            )

            prompt = gr.Textbox(label="Prompt", value="", lines=3)
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static",
                lines=3,
            )

            lycoris_id = gr.Textbox(
                label="Adapter Repo",
                value="alibaba-pai/Wan2.1-Fun-Reward-LoRAs",
            )

            with gr.Row():
                lycoris_weight_name = gr.Textbox(
                    label="Adapter File Name",
                    value="Wan2.1-Fun-1.3B-InP-MPS.safetensors",
                )
                lycoris_scale = gr.Slider(
                    label="Adapter Scale",
                    minimum=0.0,
                    maximum=2.0,
                    value=1.0,
                    step=0.05,
                )

            scheduler_type = gr.Dropdown(
                choices=list(SCHEDULER_OPTIONS.keys()),
                value="UniPCMultistepScheduler",
                label="Scheduler",
            )
            flow_shift = gr.Slider(
                label="Flow Shift",
                minimum=1.0,
                maximum=12.0,
                value=3.0,
                step=0.5,
            )

            # Height and width use the same step of 32 from a minimum of 256,
            # so every reachable value is a multiple of 16, as WanPipeline
            # requires. (A step of 30 produced widths such as 286 that the
            # pipeline rejects, and left the default 480 off the slider grid.)
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=1024,
                value=320,
                step=32,
            )
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=1792,
                value=480,
                step=32,
            )

            # 17 + 4k keeps the frame count in 4k+1 form.
            num_frames = gr.Slider(
                label="Number of Frames",
                minimum=17,
                maximum=129,
                value=33,
                step=4,
            )
            output_fps = gr.Slider(
                label="Output FPS",
                minimum=8,
                maximum=30,
                value=16,
                step=1,
            )

            guidance_scale = gr.Slider(
                label="Guidance Scale (CFG)",
                minimum=1.0,
                maximum=15.0,
                value=4.0,
                step=0.5,
            )
            num_inference_steps = gr.Slider(
                label="Inference Steps",
                minimum=10,
                maximum=100,
                value=20,
                step=1,
            )

            seed = gr.Number(
                label="Seed (-1 for random)",
                value=-1,
                precision=0,
            )

            generate_btn = gr.Button("Generate Video")

        with gr.Column(scale=1):
            output_video = gr.Video(label="Generated Video")
            used_seed = gr.Number(label="Seed", precision=0)

    # Input order must match generate_video's positional parameters.
    generate_btn.click(
        fn=generate_video,
        inputs=[
            model_choice,
            prompt,
            negative_prompt,
            lycoris_id,
            lycoris_weight_name,
            lycoris_scale,
            scheduler_type,
            flow_shift,
            height,
            width,
            num_frames,
            guidance_scale,
            num_inference_steps,
            output_fps,
            seed,
        ],
        outputs=[output_video, used_seed],
    )

    gr.Markdown("""
## Tips for best results:

- Smaller videos: Flow shift 2.0–5.0
- Larger videos: Flow shift 7.0–12.0
- Use frame count in 4k+1 form (e.g., 33, 65)
- Limit frame count and resolution to avoid timeout
""")

demo.launch()