# Provenance (Hugging Face page header): rahul7star's picture
# Commit message: Update app.py
# Commit: 9d26f7c (verified)
import torch
import gradio as gr
import spaces
import random
import os
from diffusers.utils import export_to_video
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from lycoris import create_lycoris_from_weights
# Model choices offered in the UI.
# Key: label shown in the "Model" dropdown.
# Value: Hugging Face Hub repo id that generate_video() passes to from_pretrained().
MODEL_OPTIONS = {
    "Wan2.1-T2V-1.3B": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    "Wan2.1-T2V-14B": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    "Wan2.1-Fun-Reward-1.3B": "alibaba-pai/Wan2.1-Fun-1.3B-InP",
}
# Scheduler choices offered in the UI.
# Key: label shown in the "Scheduler" dropdown; value: the diffusers scheduler class.
# The keys populate the dropdown; generate_video() branches on the selected name.
SCHEDULER_OPTIONS = {
    "UniPCMultistepScheduler": UniPCMultistepScheduler,
    "FlowMatchEulerDiscreteScheduler": FlowMatchEulerDiscreteScheduler,
}
def download_adapter(repo_id, weight_name=None):
    """Download a single adapter weight file from the Hugging Face Hub.

    The file is placed in a per-repository folder below the directory named by
    the ``HF_PATH`` environment variable (falling back to
    ``~/.cache/huggingface/hub/models``). The folder is created if missing.

    Args:
        repo_id: Hub repository that holds the adapter weights.
        weight_name: File name inside the repository. When empty or ``None``,
            ``pytorch_lora_weights.safetensors`` is requested instead.

    Returns:
        The local file path returned by ``hf_hub_download``.

    Raises:
        ValueError: If the download fails for any reason. The message starts
            with "Could not download default adapter file" when the default
            file name was used, and the underlying exception is attached as
            ``__cause__``.
    """
    use_default_name = not weight_name
    adapter_filename = (
        "pytorch_lora_weights.safetensors" if use_default_name else weight_name
    )

    cache_dir = os.environ.get(
        "HF_PATH", os.path.expanduser("~/.cache/huggingface/hub/models")
    )
    # Flatten the repo id into one directory name: "/", "\" and ":" become "_".
    folder_name = repo_id.translate(str.maketrans("/\\:", "___"))
    path_to_adapter = os.path.join(cache_dir, folder_name)
    os.makedirs(path_to_adapter, exist_ok=True)

    try:
        return hf_hub_download(
            repo_id=repo_id,
            filename=adapter_filename,
            local_dir=path_to_adapter,
        )
    except Exception as e:
        # Every failure is reported as ValueError because the caller only
        # catches ValueError; "from e" keeps the real cause in the traceback.
        if use_default_name:
            raise ValueError(
                f"Could not download default adapter file: {e}\n"
                "Please specify the exact weight file name."
            ) from e
        raise ValueError(
            f"Could not download adapter file {weight_name}: {e}"
        ) from e
@spaces.GPU(duration=140)
def generate_video(
    model_choice,
    prompt,
    negative_prompt,
    lycoris_id,
    lycoris_weight_name,
    lycoris_scale,
    scheduler_type,
    flow_shift,
    height,
    width,
    num_frames,
    guidance_scale,
    num_inference_steps,
    output_fps,
    seed,
):
    """Generate a video with a Wan pipeline, optionally merging a LyCORIS adapter.

    Args:
        model_choice: Key of ``MODEL_OPTIONS`` selecting the Hub repository.
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        lycoris_id: Hub repository of the adapter; blank skips adapter loading.
        lycoris_weight_name: Adapter file name; blank lets ``download_adapter``
            pick its default file name.
        lycoris_scale: Multiplier handed to ``create_lycoris_from_weights``.
        scheduler_type: ``"UniPCMultistepScheduler"`` selects UniPC; any other
            value selects ``FlowMatchEulerDiscreteScheduler``.
        flow_shift: Shift value given to the selected scheduler.
        height: Frame height in pixels.
        width: Frame width in pixels.
        num_frames: Number of frames to generate.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        output_fps: Frame rate of the exported video file.
        seed: Integer seed; ``-1``, ``None`` or ``""`` requests a random seed.

    Returns:
        ``(video_path, seed)``: path of the exported ``.mp4`` file and the seed
        that was actually used.

    Raises:
        gr.Error: If the adapter cannot be downloaded or loaded (``ValueError``
            from the adapter step). Raising instead of returning the message
            keeps an error string from being handed to the video output.
        KeyError: If ``model_choice`` is not a key of ``MODEL_OPTIONS``.
    """
    import tempfile  # function-scope import: only needed for the output file

    model_id = MODEL_OPTIONS[model_choice]

    # Resolve the seed first so it can be reported back to the UI.
    if seed is None or seed == "" or seed == -1:
        seed = random.randint(0, 2147483647)
    else:
        seed = int(seed)
    torch.manual_seed(seed)

    # The VAE is loaded in float32 while the rest of the pipeline uses float16.
    vae = AutoencoderKLWan.from_pretrained(
        model_id, subfolder="vae", torch_dtype=torch.float32
    )
    pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.float16)

    if scheduler_type == "UniPCMultistepScheduler":
        pipe.scheduler = UniPCMultistepScheduler.from_config(
            pipe.scheduler.config, flow_shift=flow_shift
        )
    else:
        pipe.scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)

    # NOTE(review): the pipeline is moved to CUDA before the adapter merge and
    # CPU offload is enabled afterwards; presumably so the merge runs on the
    # GPU. Confirm this order is intended for the installed diffusers version.
    pipe.to("cuda")

    adapter_repo = lycoris_id.strip() if lycoris_id else ""
    if adapter_repo:
        adapter_weight = lycoris_weight_name.strip() if lycoris_weight_name else ""
        try:
            adapter_file_path = download_adapter(
                repo_id=adapter_repo,
                weight_name=adapter_weight or None,
            )
            wrapper, *_ = create_lycoris_from_weights(
                lycoris_scale, adapter_file_path, pipe.transformer
            )
            wrapper.merge_to()
        except ValueError as e:
            reason = str(e)
            if (
                "more than one weights file" in reason
                or "Could not download default adapter file" in reason
            ):
                raise gr.Error(
                    f"Error: The repository '{lycoris_id}' may contain multiple weight files. Please specify a weight name."
                ) from e
            raise gr.Error(f"Error loading LyCORIS weights: {reason}") from e

    pipe.enable_model_cpu_offload()

    # UI sliders may deliver numbers as floats; the size/count arguments are
    # coerced to int before they reach the pipeline.
    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=int(height),
        width=int(width),
        num_frames=int(num_frames),
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        generator=torch.Generator("cuda").manual_seed(seed),
    ).frames[0]

    # A unique file per call: a fixed "output.mp4" would be overwritten when
    # two requests overlap. The file is intentionally not deleted here because
    # the caller still has to read it.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as handle:
        video_path = handle.name
    export_to_video(output, video_path, fps=int(output_fps))
    return video_path, seed
# Build the Gradio interface: inputs and the button in the left column,
# the generated video and the seed that was used in the right column.
with gr.Blocks() as demo:
    gr.Markdown("# Wan 2.1 T2V with Custom LoRA")

    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                choices=list(MODEL_OPTIONS.keys()),
                value="Wan2.1-Fun-Reward-1.3B",
                label="Model",
            )
            prompt = gr.Textbox(label="Prompt", value="", lines=3)
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static",
                lines=3,
            )
            lycoris_id = gr.Textbox(
                label="Adapter Repo",
                value="alibaba-pai/Wan2.1-Fun-Reward-LoRAs",
            )
            # Adapter file name and scale share one row.
            with gr.Row():
                lycoris_weight_name = gr.Textbox(
                    label="Adapter File Name",
                    value="Wan2.1-Fun-1.3B-InP-MPS.safetensors",
                )
                lycoris_scale = gr.Slider(
                    label="Adapter Scale",
                    minimum=0.0,
                    maximum=2.0,
                    value=1.0,
                    step=0.05,
                )
            scheduler_type = gr.Dropdown(
                choices=list(SCHEDULER_OPTIONS.keys()),
                value="UniPCMultistepScheduler",
                label="Scheduler",
            )
            flow_shift = gr.Slider(
                label="Flow Shift",
                minimum=1.0,
                maximum=12.0,
                value=3.0,
                step=0.5,
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=1024,
                value=320,
                step=32,
            )
            # step=32 (was 30): with a minimum of 256 this keeps every
            # selectable width a multiple of 16 and puts the default 480 and
            # the maximum 1792 on the slider grid.
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=1792,
                value=480,
                step=32,
            )
            # 17 + 4k keeps the frame count in the 4k+1 form named in the tips.
            num_frames = gr.Slider(
                label="Number of Frames",
                minimum=17,
                maximum=129,
                value=33,
                step=4,
            )
            output_fps = gr.Slider(
                label="Output FPS",
                minimum=8,
                maximum=30,
                value=16,
                step=1,
            )
            guidance_scale = gr.Slider(
                label="Guidance Scale (CFG)",
                minimum=1.0,
                maximum=15.0,
                value=4.0,
                step=0.5,
            )
            num_inference_steps = gr.Slider(
                label="Inference Steps",
                minimum=10,
                maximum=100,
                value=20,
                step=1,
            )
            seed = gr.Number(
                label="Seed (-1 for random)",
                value=-1,
                precision=0,
            )
            generate_btn = gr.Button("Generate Video")

        with gr.Column(scale=1):
            output_video = gr.Video(label="Generated Video")
            used_seed = gr.Number(label="Seed", precision=0)

    # The input order must match the positional parameters of generate_video().
    generate_btn.click(
        fn=generate_video,
        inputs=[
            model_choice,
            prompt,
            negative_prompt,
            lycoris_id,
            lycoris_weight_name,
            lycoris_scale,
            scheduler_type,
            flow_shift,
            height,
            width,
            num_frames,
            guidance_scale,
            num_inference_steps,
            output_fps,
            seed,
        ],
        outputs=[output_video, used_seed],
    )

    # The Markdown text is kept at column 0 so the string is unchanged.
    gr.Markdown("""
## Tips for best results:
- Smaller videos: Flow shift 2.0–5.0
- Larger videos: Flow shift 7.0–12.0
- Use frame count in 4k+1 form (e.g., 33, 65)
- Limit frame count and resolution to avoid timeout
""")

demo.launch()