|
import spaces |
|
import os |
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
import trimesh |
|
import random |
|
from transformers import AutoModelForImageSegmentation |
|
from torchvision import transforms |
|
from huggingface_hub import hf_hub_download, snapshot_download, login |
|
import subprocess |
|
import shutil |
|
|
|
|
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16

print("DEVICE: ", DEVICE)

DEFAULT_PART_FACE_NUMBER = 10000
MAX_SEED = np.iinfo(np.int32).max
HOLOPART_REPO_URL = "https://github.com/VAST-AI-Research/HoloPart"
HOLOPART_PRETRAINED_MODEL = "checkpoints/HoloPart"

# Root for per-session scratch directories and exported GLB files.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_DIR, exist_ok=True)

# Local checkout of the HoloPart code base; the `inference_holopart` and
# `holopart` imports further down are resolved from this directory.
HOLOPART_CODE_DIR = "./holopart"

# Check the checkout directory (not the repository URL, which is never a
# local path) so restarts reuse an existing clone instead of re-cloning.
# An argument list avoids going through a shell; check=True surfaces a failed
# clone here rather than as a confusing ImportError later.
if not os.path.exists(HOLOPART_CODE_DIR):
    subprocess.run(
        ["git", "clone", HOLOPART_REPO_URL, HOLOPART_CODE_DIR],
        check=True,
    )

import sys

sys.path.append(HOLOPART_CODE_DIR)
sys.path.append(os.path.join(HOLOPART_CODE_DIR, "scripts"))
|
|
|
# Bundled demo inputs: each entry pairs a segmented mesh with its preview
# image, numbered 000..003 inside the cloned HoloPart repository.
EXAMPLES = [
    [
        f"./holopart/assets/example_data/{example_idx:03d}.glb",
        f"./holopart/assets/example_data/{example_idx:03d}.png",
    ]
    for example_idx in range(4)
]
|
|
|
# Markdown rendered at the top of the demo page. The embedded snippet shows
# how to turn a mesh plus a per-face part mask into the multi-part GLB this
# app expects as input.
HEADER = """
# 🔮 Decompose a 3D shape into complete parts with [HoloPart](https://github.com/VAST-AI-Research/HoloPart).

### Step 1: Prepare Your Segmented Mesh
Upload a mesh with part segmentation. We recommend using these segmentation tools:

- [SAMPart3D](https://github.com/Pointcept/SAMPart3D)
- [SAMesh](https://github.com/gtangg12/samesh)

For a mesh file `mesh.glb` and a corresponding per-face part mask `mask.npy` (one part ID per face), prepare your input using this Python code:
```python
import trimesh
import numpy as np

mesh = trimesh.load("mesh.glb", force="mesh")
mask_npy = np.load("mask.npy")
mesh_parts = []
for part_id in np.unique(mask_npy):
    mesh_part = mesh.submesh([mask_npy == part_id], append=True)
    mesh_parts.append(mesh_part)
trimesh.Scene(mesh_parts).export("input_mesh.glb")
```
The resulting **input_mesh.glb** is your prepared input for HoloPart.

### Step 2: Click the **Decompose Parts** button to begin the decomposition process.
"""
|
|
|
from inference_holopart import prepare_data, run_holopart |
|
from holopart.pipelines.pipeline_holopart import HoloPartPipeline |
|
|
|
# Fetch the HoloPart weights from the Hugging Face Hub into the local
# checkpoint directory, build the pipeline from that directory, then move it
# to the selected device and dtype.
snapshot_download(repo_id="VAST-AI/HoloPart", local_dir=HOLOPART_PRETRAINED_MODEL)
holopart_pipe = HoloPartPipeline.from_pretrained(HOLOPART_PRETRAINED_MODEL)
holopart_pipe = holopart_pipe.to(DEVICE, DTYPE)
|
|
|
def start_session(req: gr.Request):
    """Create the scratch directory for a newly loaded Gradio session.

    The directory is ``TMP_DIR/<session_hash>``; ``end_session`` removes the
    same path when the page is unloaded.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
    print("start session, mkdir", session_dir)
|
|
|
def end_session(req: gr.Request):
    """Delete the scratch directory of a Gradio session that has been unloaded.

    ``ignore_errors=True`` keeps this best-effort cleanup from raising inside
    the unload handler when the directory is already gone, for example when
    ``start_session`` never created it for this client.
    """
    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(save_dir, ignore_errors=True)
|
|
|
def get_random_hex():
    """Return 16 lowercase hex characters derived from 8 bytes of ``os.urandom``.

    Used to give exported GLB files unique names.
    """
    return os.urandom(8).hex()
|
|
|
def get_random_seed(randomize_seed, seed):
    """Return a fresh seed in ``[0, MAX_SEED]`` when *randomize_seed* is truthy.

    Otherwise *seed* is returned unchanged.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
|
|
|
def explode_mesh(mesh: trimesh.Scene, explode_factor: float = 0.5):
    """Return an "exploded" copy of *mesh* with each part pushed outward.

    Each geometry is translated by ``explode_factor`` along the unit vector
    from the scene centroid to the mean of that geometry's world-space
    vertices. Geometry objects are shared with *mesh*; only the node
    transforms differ.

    Args:
        mesh: Scene whose geometries are the individual parts.
        explode_factor: Distance each part is moved, in scene units.

    Returns:
        A new ``trimesh.Scene`` containing one node per geometry of *mesh*.
    """
    exploded_mesh = trimesh.Scene()
    if len(mesh.geometry) == 0:
        # An empty scene has no bounds, so its centroid is undefined.
        return exploded_mesh

    center = mesh.centroid
    geometry_nodes = mesh.graph.geometry_nodes

    for geometry_name, geometry in mesh.geometry.items():
        # Node names need not equal geometry names (e.g. scenes loaded from a
        # GLB), so look up a node that actually instances this geometry and
        # fall back to the geometry name only if none is recorded.
        node_names = geometry_nodes.get(geometry_name) or [geometry_name]
        transform = mesh.graph[node_names[0]][0]

        vertices_global = trimesh.transformations.transform_points(
            geometry.vertices, transform)

        new_transform = np.copy(transform)
        # A part without vertices has no centre; leave it in place instead of
        # writing the NaN mean of an empty array into its transform.
        if len(vertices_global) > 0:
            direction = np.mean(vertices_global, axis=0) - center
            direction_length = np.linalg.norm(direction)
            if direction_length > 0:
                direction = direction / direction_length
            new_transform[:3, 3] += direction * explode_factor

        exploded_mesh.add_geometry(geometry, transform=new_transform, geom_name=geometry_name)

    return exploded_mesh
|
|
|
|
|
def _decompose_and_export(data_path, seed, num_inference_steps, guidance_scale,
                          explode_factor, save_dir, batch_size=30):
    """Run HoloPart on *data_path* and export both results as GLB into *save_dir*.

    Shared by ``run_full`` and ``run_example``. Returns
    ``(mesh_path, exploded_mesh_path)``: the decomposed part scene and the copy
    produced by ``explode_mesh(part_scene, explode_factor)``. The CUDA cache
    is released even if inference or export fails.
    """
    os.makedirs(save_dir, exist_ok=True)
    try:
        parts_data = prepare_data(data_path)

        part_scene = run_holopart(
            holopart_pipe,
            batch=parts_data,
            batch_size=batch_size,
            # Slider values may arrive as floats; seeds and step counts must be ints.
            seed=int(seed),
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            num_chunks=1000000,
        )
        print("mesh extraction done")

        mesh_path = os.path.join(save_dir, f"holopart_{get_random_hex()}.glb")
        part_scene.export(mesh_path)
        print("save to ", mesh_path)

        exploded_mesh = explode_mesh(part_scene, explode_factor)
        exploded_mesh_path = os.path.join(save_dir, f"holopart_exploded_{get_random_hex()}.glb")
        exploded_mesh.export(exploded_mesh_path)
    finally:
        torch.cuda.empty_cache()

    return mesh_path, exploded_mesh_path


@spaces.GPU(duration=600)
def run_full(data_path, seed=42, num_inference_steps=25, guidance_scale=3.5,
             req: gr.Request = None):
    """Decompose a user-uploaded segmented mesh into complete parts.

    Args:
        data_path: Path of the uploaded multi-part mesh from the input ``Model3D``.
        seed: Random seed forwarded to ``run_holopart``.
        num_inference_steps: Inference step count forwarded to ``run_holopart``.
        guidance_scale: CFG scale forwarded to ``run_holopart``.
        req: Gradio injects the current request for parameters annotated with
            ``gr.Request``. Outputs then go to this session's scratch directory,
            which ``end_session`` deletes. Without a request they go to
            ``TMP_DIR/examples``.

    Returns:
        ``(mesh_path, exploded_mesh_path)`` GLB file paths.

    Raises:
        gr.Error: If no mesh has been uploaded.
    """
    if not data_path:
        raise gr.Error("Please upload a segmented mesh first.")

    if req is not None:
        save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    else:
        save_dir = os.path.join(TMP_DIR, "examples")

    return _decompose_and_export(
        data_path, seed, num_inference_steps, guidance_scale,
        explode_factor=0.7, save_dir=save_dir,
    )


@spaces.GPU(duration=600)
def run_example(data_path: str, example_image_path, seed=42, num_inference_steps=25, guidance_scale=3.5):
    """Decompose one of the bundled example meshes; used by ``gr.Examples``.

    ``example_image_path`` is accepted only because the Examples component
    passes the hidden preview image; it is not used. Outputs are written to
    ``TMP_DIR/examples``.

    Returns:
        ``(mesh_path, exploded_mesh_path)`` GLB file paths.
    """
    return _decompose_and_export(
        data_path, seed, num_inference_steps, guidance_scale,
        explode_factor=0.5, save_dir=os.path.join(TMP_DIR, "examples"),
    )
|
|
|
|
|
with gr.Blocks(title="HoloPart") as demo:
    gr.Markdown(HEADER)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_mesh = gr.Model3D(label="Input Mesh")
                # Hidden; exists only so gr.Examples can pass each example's
                # preview image to run_example.
                example_image = gr.Image(label="Example Image", type="filepath", interactive=False, visible=False)

            with gr.Accordion("Generation Settings", open=True):
                # Seeds are integers, so the slider must advance in steps of 1
                # (a step of 0 is not a valid increment).
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=8,
                    maximum=50,
                    step=1,
                    value=25,
                )
                guidance_scale = gr.Slider(
                    label="CFG scale",
                    minimum=0.0,
                    maximum=20.0,
                    step=0.1,
                    value=3.5,
                )

            with gr.Row():
                # Display-only: its value is not passed to any handler.
                reduce_face = gr.Checkbox(label="Simplify Mesh", value=True, interactive=False)

            gen_button = gr.Button("Decompose Parts", variant="primary")

        with gr.Column():
            model_output = gr.Model3D(label="Decomposed GLB", interactive=False)
            exploded_parts_output = gr.Model3D(label="Exploded Parts", interactive=False)

    with gr.Row():
        # Example outputs are cached, so selecting an example replays stored results.
        examples = gr.Examples(
            examples=EXAMPLES,
            fn=run_example,
            inputs=[input_mesh, example_image],
            outputs=[model_output, exploded_parts_output],
            cache_examples=True,
        )

    gen_button.click(
        run_full,
        inputs=[
            input_mesh,
            seed,
            num_inference_steps,
            guidance_scale
        ],
        outputs=[model_output, exploded_parts_output],
    )

    # Per-session scratch directory lifecycle.
    demo.load(start_session)
    demo.unload(end_session)

demo.launch()
|
|