# Hugging Face Spaces status banner: Running on Zero
import os | |
import sys | |
import pathlib | |
import tyro | |
import subprocess | |
import gradio as gr | |
import os.path as osp | |
from src.utils.helper import load_description | |
from src.gradio_pipeline import GradioPipelineAnimal | |
from src.config.crop_config import CropConfig | |
from src.config.argument_config import ArgumentConfig | |
from src.config.inference_config import InferenceConfig | |
import spaces | |
# Directory containing this script; project-relative paths are anchored here.
ROOT = pathlib.Path(__file__).resolve().parent

# Source directory of the XPose / UniPose custom ops. It is placed at the
# front of sys.path so modules located there take precedence on import.
OPS_DIR = ROOT.joinpath(
    "src", "utils", "dependencies", "XPose", "models", "UniPose", "ops"
)
sys.path.insert(0, str(OPS_DIR))

# Runfile installer for the CUDA 11.8.0 toolkit (Linux).
CUDA_RUN_URL = "".join((
    "https://developer.download.nvidia.com/compute/cuda/11.8.0/",
    "local_installers/cuda_11.8.0_520.61.05_linux.run",
))

# Install prefix that is checked for `bin/nvcc` and exported as CUDA_HOME.
CUDA_HOME_PATH = "/usr/local/cuda"

# Wheel index for PyTorch downloads.
TORCH_WHL_INDEX = "https://download.pytorch.org/whl/torch_stable.html"
def _export_cuda_env():
    """Export the CUDA locations through the process environment.

    Sets CUDA_HOME and TORCH_CUDA_ARCH_LIST and puts the toolkit's ``bin`` and
    ``lib64`` directories at the front of PATH and LD_LIBRARY_PATH. Calling
    this repeatedly does not keep growing those variables.
    """
    print("Setting up CUDA environment variables...")
    os.environ["CUDA_HOME"] = CUDA_HOME_PATH
    for var_name, subdir in (("PATH", "bin"), ("LD_LIBRARY_PATH", "lib64")):
        entry = f"{CUDA_HOME_PATH}/{subdir}"
        current = os.environ.get(var_name, "")
        # Only prepend when the entry is not already the leading element, so
        # the function is idempotent while still guaranteeing front position.
        if current.split(":")[0] != entry:
            os.environ[var_name] = f"{entry}:" + current
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
    print("CUDA environment setup complete.")


def ensure_cuda_toolkit():
    """Download & install the CUDA toolkit *silently* if it is not present.

    The presence check looks for ``{CUDA_HOME_PATH}/bin/nvcc``. When it is
    missing, the runfile at ``CUDA_RUN_URL`` is fetched with ``wget`` into
    ``/tmp``, made executable and run with ``--silent --toolkit``.

    The CUDA environment variables are exported in both cases. Previously an
    already-installed toolkit returned early and left them unset.

    Raises:
        subprocess.CalledProcessError: if the download, ``chmod`` or the
            installer exits with a non-zero status.
        OSError: if ``wget``/``chmod`` or the installer cannot be executed.
    """
    print("Checking for CUDA toolkit...")
    if pathlib.Path(f"{CUDA_HOME_PATH}/bin/nvcc").exists():
        print(f"CUDA toolkit already installed at {CUDA_HOME_PATH}")
    else:
        print(f"CUDA toolkit not found. Downloading from {CUDA_RUN_URL}...")
        run_file = f"/tmp/{pathlib.Path(CUDA_RUN_URL).name}"
        subprocess.run(["wget", "-q", CUDA_RUN_URL, "-O", run_file], check=True)
        print("Download complete. Making installer executable...")
        subprocess.run(["chmod", "+x", run_file], check=True)
        print("Installing CUDA toolkit (this may take a while)...")
        subprocess.run([run_file, "--silent", "--toolkit"], check=True)
        print("CUDA toolkit installation complete.")

    # --- environment variables expected by CUDA extensions -------------------
    _export_cuda_env()
def _msda_importable():
    """Return True if the MultiScaleDeformableAttention module can be imported."""
    import importlib

    # A module installed after interpreter start-up can be hidden by the
    # import system's cached directory listings until they are invalidated.
    importlib.invalidate_caches()
    try:
        import MultiScaleDeformableAttention  # noqa: F401
    except ImportError:
        return False
    return True


def build_xpose_ops():
    """Build the MultiScaleDeformableAttention CUDA extension if it is missing.

    The extension sources are expected under
    ``<cwd>/src/utils/dependencies/XPose/models/UniPose/ops``. The build runs
    ``setup.py build install`` there, first with ``CFLAGS``/``CXXFLAGS`` set
    to ``-O0``; if that command exits non-zero it is retried once with the
    unmodified environment. ``TORCH_CUDA_ARCH_LIST`` is set to ``8.0;8.6`` in
    this process's environment before building.

    The build subprocess is started with ``cwd=`` rather than changing this
    process's working directory, so the caller's working directory is never
    altered (``os.chdir`` is process-wide and would affect concurrent work).

    Returns:
        bool: True if the module is importable (already, or after a build);
        False if the build failed or the module still cannot be imported.
    """
    if _msda_importable():
        print("MultiScaleDeformableAttention already installed")
        return True

    print("Building MultiScaleDeformableAttention...")
    ops_dir = os.path.join(
        os.getcwd(), "src/utils/dependencies/XPose/models/UniPose/ops")
    build_cmd = [sys.executable, "setup.py", "build", "install"]

    try:
        os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
        try:
            subprocess.run(
                build_cmd,
                check=True,
                cwd=ops_dir,
                env={**os.environ, "CFLAGS": "-O0", "CXXFLAGS": "-O0"},
            )
        except subprocess.CalledProcessError as e:
            print(f"Build error: {e}")
            print("Attempting simplified build...")
            try:
                subprocess.run(build_cmd, check=True, cwd=ops_dir)
            except Exception as e2:
                print(f"Simplified build also failed: {e2}")
                return False
            print("Simplified build completed")
            if _msda_importable():
                print("MultiScaleDeformableAttention imported after simplified build")
                return True
            print("Still unable to import after simplified build")
            return False

        print("MultiScaleDeformableAttention built successfully")
        if _msda_importable():
            return True
        print("Failed to import MultiScaleDeformableAttention after building")
        return False
    except Exception as e:
        # Covers a missing ops directory or interpreter, and any unexpected
        # error raised while importing the freshly built module.
        print(f"Error during XPose ops build: {e}")
        return False
def partial_fields(target_class, kwargs):
    """Instantiate ``target_class`` from the subset of ``kwargs`` it knows.

    A key is kept only when ``target_class`` itself has an attribute of that
    name; every other entry is dropped before the constructor is called.
    """
    accepted = {}
    for name, value in kwargs.items():
        if hasattr(target_class, name):
            accepted[name] = value
    return target_class(**accepted)
def fast_check_ffmpeg():
    """Return True when ``ffmpeg -version`` can be run and exits successfully.

    Returns False when the executable cannot be started (for example it is
    not on PATH) or when it exits with a non-zero status. Only those failure
    modes are handled, so KeyboardInterrupt and SystemExit propagate instead
    of being reported as "FFmpeg missing".
    """
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
    except (OSError, subprocess.CalledProcessError):
        return False
    return True
# --- command line -----------------------------------------------------------
# Pick the accent colour tyro uses, then parse the CLI into an ArgumentConfig.
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)

# --- ffmpeg -----------------------------------------------------------------
# If an "ffmpeg" directory exists in the working directory, append it to PATH
# before probing for the executable.
ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg")
if osp.exists(ffmpeg_dir):
    os.environ["PATH"] = os.environ["PATH"] + os.pathsep + ffmpeg_dir

if not fast_check_ffmpeg():
    raise ImportError(
        "FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
    )

# --- configs ----------------------------------------------------------------
# Both configs are populated from whichever parsed CLI attributes they declare.
inference_cfg = partial_fields(InferenceConfig, vars(args))
crop_cfg = partial_fields(CropConfig, vars(args))

# --- gradio temp dir --------------------------------------------------------
# Honour an explicit temp directory: export it for Gradio and create it.
if not (args.gradio_temp_dir is None or args.gradio_temp_dir == ''):
    os.environ["GRADIO_TEMP_DIR"] = args.gradio_temp_dir
    os.makedirs(args.gradio_temp_dir, exist_ok=True)

# The pipeline is created lazily by gpu_wrapped_execute_video on first use.
gradio_pipeline_animal: GradioPipelineAnimal = None

# ensure_cuda_toolkit()
def gpu_wrapped_execute_video(*call_args, **call_kwargs):
    """Run ``GradioPipelineAnimal.execute_video`` with a lazily built pipeline.

    On every call the MultiScaleDeformableAttention extension is checked (and
    built if missing); a failure only prints a warning. The pipeline object is
    created on the first call and cached in the module-level
    ``gradio_pipeline_animal``.

    The variadic parameters are deliberately NOT named ``args``: the pipeline
    must be constructed with the module-level ``args`` (the parsed CLI
    configuration). A parameter called ``args`` would shadow it and hand the
    tuple of Gradio input values to ``GradioPipelineAnimal`` instead.

    Args:
        *call_args: positional values forwarded to ``execute_video``.
        **call_kwargs: keyword values forwarded to ``execute_video``.

    Returns:
        Whatever ``gradio_pipeline_animal.execute_video`` returns.
    """
    global gradio_pipeline_animal

    # NOTE(review): `spaces` is imported at file level but no `@spaces.GPU`
    # decorator is applied here despite the function name - confirm whether
    # ZeroGPU allocation is expected for this handler.

    # ensure_cuda_toolkit()
    cuda_ext_built = build_xpose_ops()
    if not cuda_ext_built:
        print("WARNING: MultiScaleDeformableAttention CUDA extension could not be built. "
              "The model may fall back to slower CPU implementation or simplified mode.")

    if gradio_pipeline_animal is None:
        gradio_pipeline_animal = GradioPipelineAnimal(
            inference_cfg=inference_cfg,
            crop_cfg=crop_cfg,
            args=args,  # module-level CLI configuration, not the call arguments
        )

    return gradio_pipeline_animal.execute_video(*call_args, **call_kwargs)
# assets
title_md = "assets/gradio/gradio_title.md"
example_portrait_dir = "assets/examples/source"
example_video_dir = "assets/examples/driving"

# Each example row is: source image path, driving file path, followed by the
# four checkbox values (True, False, False, False) passed to the example inputs.
data_examples_i2v = [
    [
        osp.join(example_portrait_dir, portrait_name),
        osp.join(example_video_dir, driving_name),
        True, False, False, False,
    ]
    for portrait_name, driving_name in (
        ("s41.jpg", "d3.mp4"),
        ("s40.jpg", "d6.mp4"),
        ("s25.jpg", "d19.mp4"),
    )
]

# Same row layout, with a driving pickle instead of a driving video.
data_examples_i2v_pickle = [
    [
        osp.join(example_portrait_dir, portrait_name),
        osp.join(example_video_dir, driving_name),
        True, False, False, False,
    ]
    for portrait_name, driving_name in (
        ("s25.jpg", "wink.pkl"),
        ("s40.jpg", "talking.pkl"),
        ("s41.jpg", "aggrieved.pkl"),
    )
]
#################### interface logic ####################

# Output components are created up front, outside the Blocks context, so the
# layout below can place the rendered ones with .render() and the event
# handlers can reference all of them.
output_image = gr.Image(type="numpy")
output_image_paste_back = gr.Image(type="numpy")
output_video_i2v = gr.Video(autoplay=False)
output_video_concat_i2v = gr.Video(autoplay=False)
output_video_i2v_gif = gr.Image(type="numpy")

# NOTE(review): the emoji in the labels below were restored from mis-decoded
# UTF-8 text. The cat, film-frames, broom and down-arrow glyphs follow from
# the surviving bytes; the folder and rocket glyphs had lost their trailing
# bytes and are a best guess - confirm against the intended design.
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo:
    gr.HTML(load_description(title_md))
    gr.Markdown(load_description(
        "assets/gradio/gradio_description_upload_animal.md"))

    # ---- inputs: source image (left column) and driving motion (right) ----
    with gr.Row():
        with gr.Column():
            with gr.Accordion(open=True, label="🐱 Source Animal Image"):
                source_image_input = gr.Image(type="filepath")
                gr.Examples(
                    examples=[
                        [osp.join(example_portrait_dir, "s25.jpg")],
                        [osp.join(example_portrait_dir, "s30.jpg")],
                        [osp.join(example_portrait_dir, "s31.jpg")],
                        [osp.join(example_portrait_dir, "s32.jpg")],
                        [osp.join(example_portrait_dir, "s33.jpg")],
                        [osp.join(example_portrait_dir, "s39.jpg")],
                        [osp.join(example_portrait_dir, "s40.jpg")],
                        [osp.join(example_portrait_dir, "s41.jpg")],
                        [osp.join(example_portrait_dir, "s38.jpg")],
                        [osp.join(example_portrait_dir, "s36.jpg")],
                    ],
                    inputs=[source_image_input],
                    cache_examples=False,
                )
            with gr.Accordion(open=True, label="Cropping Options for Source Image"):
                with gr.Row():
                    flag_do_crop_input = gr.Checkbox(
                        value=True, label="do crop (source)")
                    scale = gr.Number(
                        value=2.3, label="source crop scale", minimum=1.8, maximum=3.2, step=0.05)
                    vx_ratio = gr.Number(
                        value=0.0, label="source crop x", minimum=-0.5, maximum=0.5, step=0.01)
                    vy_ratio = gr.Number(
                        value=-0.125, label="source crop y", minimum=-0.5, maximum=0.5, step=0.01)

        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("📁 Driving Pickle") as tab_pickle:
                    with gr.Accordion(open=True, label="Driving Pickle"):
                        driving_video_pickle_input = gr.File()
                        gr.Examples(
                            examples=[
                                [osp.join(example_video_dir, "wink.pkl")],
                                [osp.join(example_video_dir, "shy.pkl")],
                                [osp.join(example_video_dir, "aggrieved.pkl")],
                                [osp.join(example_video_dir, "open_lip.pkl")],
                                [osp.join(example_video_dir, "laugh.pkl")],
                                [osp.join(example_video_dir, "talking.pkl")],
                                [osp.join(example_video_dir,
                                          "shake_face.pkl")],
                            ],
                            inputs=[driving_video_pickle_input],
                            cache_examples=False,
                        )
                with gr.TabItem("🎞️ Driving Video") as tab_video:
                    with gr.Accordion(open=True, label="Driving Video"):
                        driving_video_input = gr.Video()
                        gr.Examples(
                            examples=[
                                [osp.join(example_video_dir, "d19.mp4")],
                                [osp.join(example_video_dir, "d14.mp4")],
                                [osp.join(example_video_dir, "d6.mp4")],
                                [osp.join(example_video_dir, "d3.mp4")],
                            ],
                            inputs=[driving_video_input],
                            cache_examples=False,
                        )

            # Hidden textbox recording which driving tab was selected last; it
            # is passed to the animate handler as its final input.
            tab_selection = gr.Textbox(visible=False)
            tab_pickle.select(lambda: "Pickle", None, tab_selection)
            tab_video.select(lambda: "Video", None, tab_selection)

            with gr.Accordion(open=True, label="Cropping Options for Driving Video"):
                with gr.Row():
                    flag_crop_driving_video_input = gr.Checkbox(
                        value=False, label="do crop (driving)")
                    scale_crop_driving_video = gr.Number(
                        value=2.2, label="driving crop scale", minimum=1.8, maximum=3.2, step=0.05)
                    vx_ratio_crop_driving_video = gr.Number(
                        value=0.0, label="driving crop x", minimum=-0.5, maximum=0.5, step=0.01)
                    vy_ratio_crop_driving_video = gr.Number(
                        value=-0.1, label="driving crop y", minimum=-0.5, maximum=0.5, step=0.01)

    # ---- animation options -------------------------------------------------
    with gr.Row():
        with gr.Accordion(open=False, label="Animation Options"):
            with gr.Row():
                flag_stitching = gr.Checkbox(
                    value=False, label="stitching (not recommended)")
                flag_remap_input = gr.Checkbox(
                    value=False, label="paste-back (not recommended)")
                driving_multiplier = gr.Number(
                    value=1.0, label="driving multiplier", minimum=0.0, maximum=2.0, step=0.02)

    gr.Markdown(load_description(
        "assets/gradio/gradio_description_animate_clear.md"))

    # ---- actions and outputs -----------------------------------------------
    with gr.Row():
        process_button_animation = gr.Button("🚀 Animate", variant="primary")
    with gr.Row():
        with gr.Column():
            with gr.Accordion(open=True, label="The animated video in the cropped image space"):
                output_video_i2v.render()
        with gr.Column():
            with gr.Accordion(open=True, label="The animated gif in the cropped image space"):
                output_video_i2v_gif.render()
        with gr.Column():
            with gr.Accordion(open=True, label="The animated video"):
                output_video_concat_i2v.render()
    with gr.Row():
        process_button_reset = gr.ClearButton(
            [source_image_input, driving_video_input, output_video_i2v, output_video_concat_i2v, output_video_i2v_gif], value="🧹 Clear")

    # ---- one-click examples ------------------------------------------------
    with gr.Row():
        # Examples
        gr.Markdown(
            "## You could also choose the examples below by one click ⬇️")
    with gr.Row():
        with gr.Tabs():
            # These tabs are not bound to names: the earlier `tab_video`
            # handle must keep referring to the driving-video input tab.
            with gr.TabItem("📁 Driving Pickle"):
                gr.Examples(
                    examples=data_examples_i2v_pickle,
                    fn=gpu_wrapped_execute_video,
                    inputs=[
                        source_image_input,
                        driving_video_pickle_input,
                        flag_do_crop_input,
                        flag_stitching,
                        flag_remap_input,
                        flag_crop_driving_video_input,
                    ],
                    outputs=[output_image, output_image_paste_back,
                             output_video_i2v_gif],
                    examples_per_page=len(data_examples_i2v_pickle),
                    cache_examples=False,
                )
            with gr.TabItem("🎞️ Driving Video"):
                gr.Examples(
                    examples=data_examples_i2v,
                    fn=gpu_wrapped_execute_video,
                    inputs=[
                        source_image_input,
                        driving_video_input,
                        flag_do_crop_input,
                        flag_stitching,
                        flag_remap_input,
                        flag_crop_driving_video_input,
                    ],
                    outputs=[output_image, output_image_paste_back,
                             output_video_i2v_gif],
                    examples_per_page=len(data_examples_i2v),
                    cache_examples=False,
                )

    # ---- wiring: the Animate button runs the pipeline ----------------------
    process_button_animation.click(
        fn=gpu_wrapped_execute_video,
        inputs=[
            source_image_input,
            driving_video_input,
            driving_video_pickle_input,
            flag_do_crop_input,
            flag_remap_input,
            driving_multiplier,
            flag_stitching,
            flag_crop_driving_video_input,
            scale,
            vx_ratio,
            vy_ratio,
            scale_crop_driving_video,
            vx_ratio_crop_driving_video,
            vy_ratio_crop_driving_video,
            tab_selection,
        ],
        outputs=[output_video_i2v,
                 output_video_concat_i2v, output_video_i2v_gif],
        show_progress=True
    )

demo.launch()