import os

# PyTorch 2.8 (temporary hack)
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')

# --- 1. Model Download and Setup (Diffusers Backend) ---
import spaces
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.utils.export_utils import export_to_video
import gradio as gr
import tempfile
import numpy as np
from PIL import Image
import random
import gc

# Import the optimization function from the separate file
from optimization import optimize_pipeline_
# --- Constants and Model Loading ---
MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"

# --- Flexible Dimension Constants ---
MAX_DIMENSION = 832
MIN_DIMENSION = 480
DIMENSION_MULTIPLE = 16
SQUARE_SIZE = 480

MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81
MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)
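# i.e. MIN_DURATION = round(8 / 16, 1) = 0.5 s and MAX_DURATION = round(81 / 16, 1) = 5.1 s.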
default_negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,"
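# The default negative prompt is kept in Chinese as shipped with the model.
# Rough English gloss: "garish colors, overexposed, static, blurry details,
# subtitles, style, artwork, painting, still frame, overall gray, worst
# quality, low quality, JPEG compression artifacts, ugly, incomplete, extra
# fingers, poorly drawn hands, poorly drawn face, deformed, disfigured,
# malformed limbs, fused fingers, motionless frame, cluttered background,
# three legs, crowded background, walking backwards, overexposed".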
print("Loading models into memory. This may take a few minutes...") | |
pipe = WanImageToVideoPipeline.from_pretrained( | |
MODEL_ID, | |
transformer=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers', | |
subfolder='transformer', | |
torch_dtype=torch.bfloat16, | |
device_map='cuda', | |
), | |
transformer_2=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers', | |
subfolder='transformer_2', | |
torch_dtype=torch.bfloat16, | |
device_map='cuda', | |
), | |
torch_dtype=torch.bfloat16, | |
) | |
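
# `shift` skews the flow-matching timestep schedule toward the high-noise
# region; a large value such as 8.0 is commonly paired with few-step
# distilled (Lightning-style) sampling.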
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=8.0)
pipe.to('cuda')

print("Optimizing pipeline...")
for _ in range(3):
    gc.collect()
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

# Calling the imported optimization function with a placeholder image for compilation tracing
optimize_pipeline_(
    pipe,
    image=Image.new('RGB', (MAX_DIMENSION, MIN_DIMENSION)),  # Use representative dims
    prompt='prompt',
    height=MIN_DIMENSION,
    width=MAX_DIMENSION,
    num_frames=MAX_FRAMES_MODEL,
)
print("All models loaded and optimized. Gradio app is ready.") | |
# --- 2. Image Processing and Application Logic --- | |
def process_image_for_video(image: Image.Image) -> Image.Image:
    """
    Resizes an image based on the following rules for video generation:
    1. The longest side will be scaled down to MAX_DIMENSION if it's larger.
    2. The shortest side will be scaled up to MIN_DIMENSION if it's smaller.
    3. The final dimensions will be rounded to the nearest multiple of DIMENSION_MULTIPLE.
    4. Square images are resized to a fixed SQUARE_SIZE.
    The aspect ratio is preserved as closely as possible.
    """
    width, height = image.size

    # Rule 4: Handle square images
    if width == height:
        return image.resize((SQUARE_SIZE, SQUARE_SIZE), Image.Resampling.LANCZOS)

    # Determine target dimensions while preserving aspect ratio
    aspect_ratio = width / height
    new_width, new_height = width, height

    # Rule 1: Scale down if too large
    if new_width > MAX_DIMENSION or new_height > MAX_DIMENSION:
        if aspect_ratio > 1:  # Landscape
            scale = MAX_DIMENSION / new_width
        else:  # Portrait
            scale = MAX_DIMENSION / new_height
        new_width *= scale
        new_height *= scale

    # Rule 2: Scale up if too small
    if new_width < MIN_DIMENSION or new_height < MIN_DIMENSION:
        if aspect_ratio > 1:  # Landscape
            scale = MIN_DIMENSION / new_height
        else:  # Portrait
            scale = MIN_DIMENSION / new_width
        new_width *= scale
        new_height *= scale

    # Rule 3: Round to the nearest multiple of DIMENSION_MULTIPLE
    final_width = int(round(new_width / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)
    final_height = int(round(new_height / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)

    # Ensure final dimensions are at least the minimum
    final_width = max(final_width, MIN_DIMENSION if aspect_ratio < 1 else SQUARE_SIZE)
    final_height = max(final_height, MIN_DIMENSION if aspect_ratio > 1 else SQUARE_SIZE)

    return image.resize((final_width, final_height), Image.Resampling.LANCZOS)
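
# Worked example: a 1920x1080 input is scaled down to 832x468 (Rule 1), pushed
# back up to ~853x480 because the short side fell below MIN_DIMENSION (Rule 2),
# then rounded to 848x480 (Rule 3). Note that Rule 2 can nudge the long side
# slightly past MAX_DIMENSION.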

def resize_and_crop_to_match(target_image, reference_image):
    """Resizes and center-crops the target image to match the reference image's dimensions."""
    ref_width, ref_height = reference_image.size
    target_width, target_height = target_image.size
    scale = max(ref_width / target_width, ref_height / target_height)
    new_width, new_height = int(target_width * scale), int(target_height * scale)
    resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    left, top = (new_width - ref_width) // 2, (new_height - ref_height) // 2
    return resized.crop((left, top, left + ref_width, top + ref_height))
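
# Worked example: matching a 1024x1024 end frame to an 848x480 start frame
# scales it by max(848/1024, 480/1024) ≈ 0.828 to 848x848, then center-crops
# the middle 848x480 band (184 px trimmed from top and bottom).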

# NOTE: assuming this Space runs on ZeroGPU (the page header says "Running on
# Zero"); there, the GPU-bound entry point must carry the `spaces.GPU` decorator.
@spaces.GPU
def generate_video(
    start_image_pil,
    end_image_pil,
    prompt,
    negative_prompt=default_negative_prompt,
    duration_seconds=2.1,
    steps=8,
    guidance_scale=1,
    guidance_scale_2=1,
    seed=42,
    randomize_seed=False,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generates a video by interpolating between a start and end image, guided by a text prompt,
    using the diffusers Wan2.2 pipeline.
    """
    if start_image_pil is None or end_image_pil is None:
        raise gr.Error("Please upload both a start and an end image.")

    progress(0.1, desc="Preprocessing images...")

    # Step 1: Process the start image to get our target dimensions based on the rules above.
    processed_start_image = process_image_for_video(start_image_pil)

    # Step 2: Make the end image match the *exact* dimensions of the processed start image.
    processed_end_image = resize_and_crop_to_match(end_image_pil, processed_start_image)
    target_height, target_width = processed_start_image.height, processed_start_image.width

    # Handle seed and frame count
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
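    # e.g. the default 2.1 s duration gives round(2.1 * 16) = 34 frames,
    # comfortably inside the model's 8-81 frame window.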
    progress(0.2, desc=f"Generating {num_frames} frames at {target_width}x{target_height} (seed: {current_seed})...")
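    # Wan 2.2 A14B denoises with two expert transformers: guidance_scale
    # steers the high-noise expert (`transformer`) and guidance_scale_2 the
    # low-noise expert (`transformer_2`), matching the two sliders in the UI.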
    output_frames_list = pipe(
        image=processed_start_image,
        last_image=processed_end_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=target_height,
        width=target_width,
        num_frames=num_frames,
        guidance_scale=float(guidance_scale),
        guidance_scale_2=float(guidance_scale_2),
        num_inference_steps=int(steps),
        generator=torch.Generator(device="cuda").manual_seed(current_seed),
    ).frames[0]

    progress(0.9, desc="Encoding and saving video...")
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    export_to_video(output_frames_list, video_path, fps=FIXED_FPS)

    progress(1.0, desc="Done!")
    return video_path, current_seed

# --- 3. Gradio User Interface ---
css = '''
.fillable{max-width: 1100px !important}
.dark .progress-text {color: white}
'''

with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
    gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
    gr.Markdown("Based on the [Wan 2.2 First/Last Frame workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/), applied to 🧨 Diffusers + [lightx2v/Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) 8-step LoRA")
    with gr.Row():
        with gr.Column():
            with gr.Group():
                with gr.Row():
                    start_image = gr.Image(type="pil", label="Start Frame", sources=["upload", "clipboard"])
                    end_image = gr.Image(type="pil", label="End Frame", sources=["upload", "clipboard"])
                prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images")
            with gr.Accordion("Advanced Settings", open=False):
                duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=2.1, label="Video Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
                negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=8, label="Inference Steps")
                guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale - high noise")
                guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale - low noise")
                with gr.Row():
                    seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                    randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)
            generate_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            output_video = gr.Video(label="Generated Video", autoplay=True)
    # Define the inputs list for the click event
    ui_inputs = [
        start_image,
        end_image,
        prompt,
        negative_prompt_input,
        duration_seconds_input,
        steps_slider,
        guidance_scale_input,
        guidance_scale_2_input,
        seed_input,
        randomize_seed_checkbox,
    ]

    # The seed_input is both an input and an output to reflect the randomly generated seed
    ui_outputs = [output_video, seed_input]

    generate_button.click(
        fn=generate_video,
        inputs=ui_inputs,
        outputs=ui_outputs,
    )
    gr.Examples(
        examples=[
            ["poli_tower.png", "tower_takes_off.png", "the man turns around"],
            ["ugly_sonic.jpeg", "squatting_sonic.png", "the character dodges the missiles"],
            ["capyabara_zoomed.png", "capyabara.webp", "a dramatic dolly zoom"],
        ],
        inputs=[start_image, end_image, prompt],
        outputs=ui_outputs,
        fn=generate_video,
        cache_examples="lazy",
    )

if __name__ == "__main__":
    app.launch(share=True)