# abreza's picture
# fix: comment out ensure_cuda_toolkit calls in app.py
# 71c4b67
import os
import sys
import pathlib
import tyro
import subprocess
import gradio as gr
import os.path as osp
from src.utils.helper import load_description
from src.gradio_pipeline import GradioPipelineAnimal
from src.config.crop_config import CropConfig
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
import spaces
# Directory containing this file, and the XPose/UniPose ops directory that is
# placed first on sys.path (presumably so a locally built
# MultiScaleDeformableAttention module can be found — confirm).
ROOT = pathlib.Path(__file__).resolve().parent
OPS_DIR = ROOT.joinpath("src", "utils", "dependencies", "XPose", "models", "UniPose", "ops")
sys.path.insert(0, str(OPS_DIR))

# CUDA 11.8 toolkit runfile and install prefix used by ensure_cuda_toolkit().
CUDA_RUN_URL = (
    "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/"
    "cuda_11.8.0_520.61.05_linux.run"
)
CUDA_HOME_PATH = "/usr/local/cuda"

# PyTorch wheel index URL (not referenced by the code visible in this file).
TORCH_WHL_INDEX = "https://download.pytorch.org/whl/torch_stable.html"
def _prepend_env_path(var, entry):
    """Prepend *entry* to the os.pathsep-separated env var *var* unless already listed."""
    current = os.environ.get(var, "")
    if entry in current.split(os.pathsep):
        return  # avoid growing the variable with duplicates on repeated calls
    # When the variable was empty, do not leave a trailing separator: an empty
    # entry in PATH / LD_LIBRARY_PATH means "current directory".
    os.environ[var] = f"{entry}{os.pathsep}{current}" if current else entry


def ensure_cuda_toolkit():
    """Install the CUDA toolkit silently if nvcc is missing, then export its env.

    Looks for ``<CUDA_HOME_PATH>/bin/nvcc``. If it is absent, the runfile at
    CUDA_RUN_URL is downloaded into /tmp with ``wget``, made executable, run with
    ``--silent --toolkit`` and then deleted (also on failure, since a partial
    download is useless).

    In every case, including when the toolkit was already present, the process
    environment is then prepared for building CUDA extensions:
    CUDA_HOME is set, ``bin``/``lib64`` are prepended to PATH/LD_LIBRARY_PATH,
    and TORCH_CUDA_ARCH_LIST is set to "8.0;8.6".

    Raises:
        subprocess.CalledProcessError: if the download, chmod or installer fails.
    """
    print("Checking for CUDA toolkit...")
    nvcc = pathlib.Path(CUDA_HOME_PATH) / "bin" / "nvcc"
    if nvcc.exists():
        print(f"CUDA toolkit already installed at {CUDA_HOME_PATH}")
    else:
        print(f"CUDA toolkit not found. Downloading from {CUDA_RUN_URL}...")
        run_file = pathlib.Path("/tmp") / pathlib.Path(CUDA_RUN_URL).name
        try:
            subprocess.run(["wget", "-q", CUDA_RUN_URL, "-O", str(run_file)], check=True)
            print("Download complete. Making installer executable...")
            subprocess.run(["chmod", "+x", str(run_file)], check=True)
            print("Installing CUDA toolkit (this may take a while)...")
            subprocess.run([str(run_file), "--silent", "--toolkit"], check=True)
        finally:
            # The runfile is several GB; never leave it behind in /tmp.
            run_file.unlink(missing_ok=True)
        print("CUDA toolkit installation complete.")

    # --- environment variables expected by CUDA extensions -------------------
    # Done for both branches: an existing toolkit still needs these exported.
    print("Setting up CUDA environment variables...")
    os.environ["CUDA_HOME"] = CUDA_HOME_PATH
    _prepend_env_path("PATH", f"{CUDA_HOME_PATH}/bin")
    _prepend_env_path("LD_LIBRARY_PATH", f"{CUDA_HOME_PATH}/lib64")
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
    print("CUDA environment setup complete.")
def _msda_importable():
    """Return True if MultiScaleDeformableAttention can be imported right now."""
    import importlib

    # A fresh install adds files that the path finders may have cached as
    # absent; drop those caches before retrying the import.
    importlib.invalidate_caches()
    try:
        import MultiScaleDeformableAttention  # noqa: F401
    except ImportError:
        return False
    return True


def build_xpose_ops():
    """Ensure the XPose MultiScaleDeformableAttention CUDA extension is importable.

    If the module already imports, nothing is built. Otherwise
    ``setup.py build install`` is run with ``sys.executable`` inside OPS_DIR.
    The first attempt uses ``-O0`` CFLAGS/CXXFLAGS. If that command fails, it is
    retried once with only the inherited environment. Build failures are
    reported with ``print`` instead of being raised.

    The build runs via ``subprocess.run(cwd=OPS_DIR)``, so the process-wide
    working directory is never changed. The ops path is also resolved from this
    file's location rather than from the current directory.

    Side effect: sets ``os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"`` before
    building.

    Returns:
        bool: True if the module is importable afterwards, False otherwise.
    """
    try:
        import MultiScaleDeformableAttention  # noqa: F401
        print("MultiScaleDeformableAttention already installed")
        return True
    except ImportError:
        print("Building MultiScaleDeformableAttention...")

    build_cmd = [sys.executable, "setup.py", "build", "install"]
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"

    try:
        subprocess.run(
            build_cmd,
            check=True,
            cwd=OPS_DIR,
            env={**os.environ, "CFLAGS": "-O0", "CXXFLAGS": "-O0"},
        )
    except subprocess.CalledProcessError as e:
        # Compilation failed: fall through to the simplified retry below.
        print(f"Build error: {e}")
    except Exception as e:  # e.g. OPS_DIR missing, interpreter not runnable
        print(f"Error during XPose ops build: {e}")
        return False
    else:
        print("MultiScaleDeformableAttention built successfully")
        if _msda_importable():
            return True
        print("Failed to import MultiScaleDeformableAttention after building")
        return False

    # Retry once without the -O0 overrides (inherited compiler flags only).
    try:
        print("Attempting simplified build...")
        subprocess.run(build_cmd, check=True, cwd=OPS_DIR)
        print("Simplified build completed")
    except Exception as e2:
        print(f"Simplified build also failed: {e2}")
        return False

    if _msda_importable():
        print("MultiScaleDeformableAttention imported after simplified build")
        return True
    print("Still unable to import after simplified build")
    return False
def partial_fields(target_class, kwargs):
    """Instantiate *target_class* with only the entries of *kwargs* it knows about.

    A key is forwarded when ``target_class`` has an attribute of that name
    (any class attribute, e.g. a dataclass field that has a default). All other
    keys are silently dropped.

    Args:
        target_class: class to instantiate.
        kwargs: mapping of candidate keyword arguments.

    Returns:
        A new ``target_class`` instance built from the accepted keys.
    """
    accepted = {}
    for name, value in kwargs.items():
        if not hasattr(target_class, name):
            continue
        accepted[name] = value
    return target_class(**accepted)
def fast_check_ffmpeg(executable="ffmpeg"):
    """Return True if ``<executable> -version`` runs and exits with status 0.

    Args:
        executable: program to probe. Defaults to ``"ffmpeg"``; pass e.g.
            ``"ffprobe"`` to check a companion binary the same way.

    Returns:
        bool: False when the program cannot be launched (not on PATH, not
        executable) or exits with a non-zero status, True otherwise.
    """
    try:
        subprocess.run([executable, "-version"], capture_output=True, check=True)
    except (OSError, subprocess.CalledProcessError):
        # Only "missing / broken binary" is treated as a negative result; a
        # bare except would also swallow KeyboardInterrupt and SystemExit.
        return False
    return True
# Give tyro's CLI help a cyan accent, then parse the command-line options.
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)

# A local ./ffmpeg directory, if present, is appended to PATH as a fallback
# location for the ffmpeg binaries; the app refuses to start without ffmpeg.
ffmpeg_dir = osp.join(os.getcwd(), "ffmpeg")
if osp.exists(ffmpeg_dir):
    os.environ["PATH"] = os.environ["PATH"] + os.pathsep + ffmpeg_dir
if not fast_check_ffmpeg():
    raise ImportError(
        "FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) "
        "before running this script. https://ffmpeg.org/download.html"
    )

# Derive the inference and crop configs from the CLI fields each one declares.
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
crop_cfg = partial_fields(CropConfig, args.__dict__)

# Optionally redirect Gradio's temporary files to a user-chosen directory.
if args.gradio_temp_dir is not None and args.gradio_temp_dir != '':
    os.environ["GRADIO_TEMP_DIR"] = args.gradio_temp_dir
    os.makedirs(args.gradio_temp_dir, exist_ok=True)

# The animation pipeline is created lazily by the GPU-wrapped entry point.
gradio_pipeline_animal: GradioPipelineAnimal = None
# NOTE: ensure_cuda_toolkit() is intentionally not called (neither here nor in
# the GPU wrapper below); re-enable it if the runtime image lacks nvcc and the
# XPose ops fail to build.
@spaces.GPU(duration=240)
def gpu_wrapped_execute_video(*video_args, **video_kwargs):
    """Run ``GradioPipelineAnimal.execute_video`` inside a ``spaces.GPU`` slot.

    On each call this first makes sure the MultiScaleDeformableAttention
    extension is built (warning, not failing, if it is not). On the first call
    it also creates the module-level ``gradio_pipeline_animal`` from the parsed
    CLI ``args`` and the derived configs. All positional/keyword arguments
    (the Gradio input values) are forwarded unchanged to ``execute_video``.

    The varargs are deliberately *not* named ``args``: that name would shadow
    the module-level CLI ``ArgumentConfig`` and hand the pipeline a tuple of
    Gradio inputs instead of its configuration.
    """
    global gradio_pipeline_animal

    if not build_xpose_ops():
        print("WARNING: MultiScaleDeformableAttention CUDA extension could not be built. "
              "The model may fall back to slower CPU implementation or simplified mode.")

    if gradio_pipeline_animal is None:
        gradio_pipeline_animal = GradioPipelineAnimal(
            inference_cfg=inference_cfg,
            crop_cfg=crop_cfg,
            args=args,  # module-level CLI ArgumentConfig
        )
    return gradio_pipeline_animal.execute_video(*video_args, **video_kwargs)
# Asset locations, relative to the working directory.
title_md = "assets/gradio/gradio_title.md"
example_portrait_dir = "assets/examples/source"
example_video_dir = "assets/examples/driving"

# One-click example rows: (source image, driving file) followed by the values
# for the do-crop, stitching, paste-back and driving-crop checkboxes, in the
# order the example widgets list their inputs.
data_examples_i2v = [
    [osp.join(example_portrait_dir, src), osp.join(example_video_dir, drv), True, False, False, False]
    for src, drv in (("s41.jpg", "d3.mp4"), ("s40.jpg", "d6.mp4"), ("s25.jpg", "d19.mp4"))
]
data_examples_i2v_pickle = [
    [osp.join(example_portrait_dir, src), osp.join(example_video_dir, drv), True, False, False, False]
    for src, drv in (("s25.jpg", "wink.pkl"), ("s40.jpg", "talking.pkl"), ("s41.jpg", "aggrieved.pkl"))
]
#################### interface logic ####################
# Define components first
# These output components are created outside the gr.Blocks context.
# The three video/gif outputs are placed into the layout later via .render();
# output_image and output_image_paste_back are only referenced as outputs of
# the one-click gr.Examples widgets.
output_image = gr.Image(type="numpy")
output_image_paste_back = gr.Image(type="numpy")
output_video_i2v = gr.Video(autoplay=False)
output_video_concat_i2v = gr.Video(autoplay=False)
output_video_i2v_gif = gr.Image(type="numpy")
# Build the Gradio UI: source-image and driving-motion inputs, animation
# options, the Animate / Clear buttons, output panels and one-click examples;
# then launch the app.
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo:
    gr.HTML(load_description(title_md))
    gr.Markdown(load_description(
        "assets/gradio/gradio_description_upload_animal.md"))

    # --- inputs: source image (left) and driving motion (right) -------------
    with gr.Row():
        with gr.Column():
            with gr.Accordion(open=True, label="🐱 Source Animal Image"):
                source_image_input = gr.Image(type="filepath")
                gr.Examples(
                    examples=[
                        [osp.join(example_portrait_dir, "s25.jpg")],
                        [osp.join(example_portrait_dir, "s30.jpg")],
                        [osp.join(example_portrait_dir, "s31.jpg")],
                        [osp.join(example_portrait_dir, "s32.jpg")],
                        [osp.join(example_portrait_dir, "s33.jpg")],
                        [osp.join(example_portrait_dir, "s39.jpg")],
                        [osp.join(example_portrait_dir, "s40.jpg")],
                        [osp.join(example_portrait_dir, "s41.jpg")],
                        [osp.join(example_portrait_dir, "s38.jpg")],
                        [osp.join(example_portrait_dir, "s36.jpg")],
                    ],
                    inputs=[source_image_input],
                    cache_examples=False,
                )
            with gr.Accordion(open=True, label="Cropping Options for Source Image"):
                with gr.Row():
                    flag_do_crop_input = gr.Checkbox(
                        value=True, label="do crop (source)")
                    scale = gr.Number(
                        value=2.3, label="source crop scale", minimum=1.8, maximum=3.2, step=0.05)
                    vx_ratio = gr.Number(
                        value=0.0, label="source crop x", minimum=-0.5, maximum=0.5, step=0.01)
                    vy_ratio = gr.Number(
                        value=-0.125, label="source crop y", minimum=-0.5, maximum=0.5, step=0.01)

        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("📁 Driving Pickle") as tab_pickle:
                    with gr.Accordion(open=True, label="Driving Pickle"):
                        driving_video_pickle_input = gr.File()
                        gr.Examples(
                            examples=[
                                [osp.join(example_video_dir, "wink.pkl")],
                                [osp.join(example_video_dir, "shy.pkl")],
                                [osp.join(example_video_dir, "aggrieved.pkl")],
                                [osp.join(example_video_dir, "open_lip.pkl")],
                                [osp.join(example_video_dir, "laugh.pkl")],
                                [osp.join(example_video_dir, "talking.pkl")],
                                [osp.join(example_video_dir, "shake_face.pkl")],
                            ],
                            inputs=[driving_video_pickle_input],
                            cache_examples=False,
                        )
                with gr.TabItem("🎞️ Driving Video") as tab_video:
                    with gr.Accordion(open=True, label="Driving Video"):
                        driving_video_input = gr.Video()
                        gr.Examples(
                            examples=[
                                [osp.join(example_video_dir, "d19.mp4")],
                                [osp.join(example_video_dir, "d14.mp4")],
                                [osp.join(example_video_dir, "d6.mp4")],
                                [osp.join(example_video_dir, "d3.mp4")],
                            ],
                            inputs=[driving_video_input],
                            cache_examples=False,
                        )
            # Hidden record of the active driving tab, passed to execute_video.
            # It starts as "Pickle" because that tab is the one shown on load;
            # otherwise the value stays "" until the user clicks a tab.
            tab_selection = gr.Textbox(value="Pickle", visible=False)
            tab_pickle.select(lambda: "Pickle", None, tab_selection)
            tab_video.select(lambda: "Video", None, tab_selection)
            with gr.Accordion(open=True, label="Cropping Options for Driving Video"):
                with gr.Row():
                    flag_crop_driving_video_input = gr.Checkbox(
                        value=False, label="do crop (driving)")
                    scale_crop_driving_video = gr.Number(
                        value=2.2, label="driving crop scale", minimum=1.8, maximum=3.2, step=0.05)
                    vx_ratio_crop_driving_video = gr.Number(
                        value=0.0, label="driving crop x", minimum=-0.5, maximum=0.5, step=0.01)
                    vy_ratio_crop_driving_video = gr.Number(
                        value=-0.1, label="driving crop y", minimum=-0.5, maximum=0.5, step=0.01)

    # --- animation options -------------------------------------------------
    with gr.Row():
        with gr.Accordion(open=False, label="Animation Options"):
            with gr.Row():
                flag_stitching = gr.Checkbox(
                    value=False, label="stitching (not recommended)")
                flag_remap_input = gr.Checkbox(
                    value=False, label="paste-back (not recommended)")
                driving_multiplier = gr.Number(
                    value=1.0, label="driving multiplier", minimum=0.0, maximum=2.0, step=0.02)

    gr.Markdown(load_description(
        "assets/gradio/gradio_description_animate_clear.md"))
    with gr.Row():
        process_button_animation = gr.Button("🚀 Animate", variant="primary")

    # --- outputs -------------------------------------------------------------
    with gr.Row():
        with gr.Column():
            with gr.Accordion(open=True, label="The animated video in the cropped image space"):
                output_video_i2v.render()
        with gr.Column():
            with gr.Accordion(open=True, label="The animated gif in the cropped image space"):
                output_video_i2v_gif.render()
        with gr.Column():
            with gr.Accordion(open=True, label="The animated video"):
                output_video_concat_i2v.render()
    with gr.Row():
        # Clear every user input (including the driving pickle) and all outputs.
        process_button_reset = gr.ClearButton(
            [source_image_input, driving_video_input, driving_video_pickle_input,
             output_video_i2v, output_video_concat_i2v, output_video_i2v_gif],
            value="🧹 Clear")

    # --- one-click examples --------------------------------------------------
    # NOTE(review): with cache_examples=False (and no run_on_click) Gradio only
    # fills the inputs from a clicked row; `fn`/`outputs` here are presumably
    # not invoked, which matters because they do not match the Animate
    # handler's 15-input signature — confirm before enabling caching.
    with gr.Row():
        gr.Markdown(
            "## You could also choose the examples below by one click ⬇️")
    with gr.Row():
        with gr.Tabs():
            with gr.TabItem("📁 Driving Pickle"):
                gr.Examples(
                    examples=data_examples_i2v_pickle,
                    fn=gpu_wrapped_execute_video,
                    inputs=[
                        source_image_input,
                        driving_video_pickle_input,
                        flag_do_crop_input,
                        flag_stitching,
                        flag_remap_input,
                        flag_crop_driving_video_input,
                    ],
                    outputs=[output_image, output_image_paste_back,
                             output_video_i2v_gif],
                    examples_per_page=len(data_examples_i2v_pickle),
                    cache_examples=False,
                )
            with gr.TabItem("🎞️ Driving Video"):
                gr.Examples(
                    examples=data_examples_i2v,
                    fn=gpu_wrapped_execute_video,
                    inputs=[
                        source_image_input,
                        driving_video_input,
                        flag_do_crop_input,
                        flag_stitching,
                        flag_remap_input,
                        flag_crop_driving_video_input,
                    ],
                    outputs=[output_image, output_image_paste_back,
                             output_video_i2v_gif],
                    examples_per_page=len(data_examples_i2v),
                    cache_examples=False,
                )

    # The Animate button forwards these inputs positionally, in this order, to
    # GradioPipelineAnimal.execute_video via the GPU wrapper.
    process_button_animation.click(
        fn=gpu_wrapped_execute_video,
        inputs=[
            source_image_input,
            driving_video_input,
            driving_video_pickle_input,
            flag_do_crop_input,
            flag_remap_input,
            driving_multiplier,
            flag_stitching,
            flag_crop_driving_video_input,
            scale,
            vx_ratio,
            vy_ratio,
            scale_crop_driving_video,
            vx_ratio_crop_driving_video,
            vy_ratio_crop_driving_video,
            tab_selection,
        ],
        outputs=[output_video_i2v,
                 output_video_concat_i2v, output_video_i2v_gif],
        show_progress=True
    )

demo.launch()