import spaces

import torch._dynamo
torch._dynamo.disable()

import os

os.environ["TORCHDYNAMO_DISABLE"] = "1"

import subprocess
import tempfile
import uuid
import glob
import shutil
import time
import sys

import gradio as gr
from PIL import Image
import importlib
import site

# Make every site-packages directory visible so packages installed at
# runtime can be imported.
for sitedir in site.getsitepackages():
    site.addsitedir(sitedir)

importlib.invalidate_caches()

os.environ["PIXEL3DMM_CODE_BASE"] = f"{os.getcwd()}"
os.environ["PIXEL3DMM_PREPROCESSED_DATA"] = f"{os.getcwd()}/preprocess_results"
os.environ["PIXEL3DMM_TRACKING_OUTPUT"] = f"{os.getcwd()}/tracking_results"


def sh(cmd):
    subprocess.check_call(cmd, shell=True)


sh("pip install -e .")
sh("cd src/pixel3dmm/preprocessing/facer && pip install -e .")
sh("cd src/pixel3dmm/preprocessing/PIPNet/FaceBoxesV2/utils && sh make.sh")

# Re-scan site-packages and invalidate import caches so the packages
# installed above resolve in this process.
site.addsitedir(site.getsitepackages()[0])
importlib.invalidate_caches()

from pixel3dmm import env_paths


def install_cuda_toolkit():
    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
    subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
    subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
    subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])

    os.environ["CUDA_HOME"] = "/usr/local/cuda"
    os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
    os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
        os.environ["CUDA_HOME"],
        "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
    )

    os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
    print("==> finished installation")


install_cuda_toolkit()
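
# A full CUDA 12.1 toolkit is installed at runtime, presumably so that
# nvdiffrast (used by the tracking step below) can JIT-compile its CUDA
# extensions; TORCH_CUDA_ARCH_LIST="9.0" targets Hopper-class GPUs such as
# the H100s backing ZeroGPU Spaces.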

from omegaconf import OmegaConf
from pixel3dmm.network_inference import normals_n_uvs
from pixel3dmm.run_facer_segmentation import segment

DEVICE = "cuda"

_model_cache = {}
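
# Module-level cache so the face detector/parser, the normals and UV
# networks, FLAME, and the renderer are loaded once and reused across
# @spaces.GPU calls instead of being re-initialized for every step.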


def first_file_from_dir(directory, ext):
    files = glob.glob(os.path.join(directory, f"*.{ext}"))
    return sorted(files)[0] if files else None


def first_image_from_dir(directory):
    patterns = ["*.jpg", "*.png", "*.jpeg"]
    files = []
    for p in patterns:
        files.extend(glob.glob(os.path.join(directory, p)))
    if not files:
        return None
    return sorted(files)[0]
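
# Example (hypothetical filenames): if the tracker wrote
# tracking_results/<sid>/mesh/00000.glb, then
# first_file_from_dir(f"{os.environ['PIXEL3DMM_TRACKING_OUTPUT']}/<sid>/mesh", "glb")
# returns that path; both helpers pick the lexicographically first match.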


def reset_all():
    # Seven values for the seven outputs wired to image_in.upload below:
    # clear the five previews, reset the status, re-enable the run button.
    return (
        None,                         # crop_img
        None,                         # normals_img
        None,                         # uv_img
        None,                         # track_img
        None,                         # mesh_file
        "Time to Generate!",          # status
        gr.update(interactive=True),  # run_btn
    )


@spaces.GPU()
def preprocess_image(image_array, session_id):
    if image_array is None:
        # Match the arity of the success path so callers can always unpack
        # four values.
        return "❌ Please upload an image first.", None, gr.update(interactive=True), gr.update(interactive=True)

    base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
    os.makedirs(base_dir, exist_ok=True)

    img = Image.fromarray(image_array)
    saved_image_path = os.path.join(base_dir, f"{session_id}.png")
    img.save(saved_image_path)

    # Imported inside the function, presumably to defer loading the heavy
    # facer package until first use.
    import facer

    if "face_detector" not in _model_cache:
        device = 'cuda'
        face_detector = facer.face_detector('retinaface/mobilenet', device=device)
        face_parser = facer.face_parser('farl/celebm/448', device=device)
        _model_cache['face_detector'] = face_detector
        _model_cache['face_parser'] = face_parser

    subprocess.run([
        "python", "scripts/run_preprocessing.py", "--video_or_images_path", saved_image_path
    ], check=True, capture_output=True, text=True)

    segment(f'{session_id}', _model_cache['face_detector'], _model_cache['face_parser'])

    crop_dir = os.path.join(base_dir, "cropped")
    image = first_image_from_dir(crop_dir)
    return "✅ Step 1 complete. Ready for Normals.", image, gr.update(interactive=True), gr.update(interactive=True)


@spaces.GPU()
def step2_normals(session_id):
    from pixel3dmm.lightning.p3dmm_system import system as p3dmm_system

    base_conf = OmegaConf.load("configs/base.yaml")

    if "normals_model" not in _model_cache:
        model = p3dmm_system.load_from_checkpoint(f"{env_paths.CKPT_N_PRED}", strict=False)
        model = model.eval().to(DEVICE)
        _model_cache["normals_model"] = model

    base_conf.video_name = f'{session_id}'
    normals_n_uvs(base_conf, _model_cache["normals_model"])

    normals_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id, "p3dmm", "normals")
    image = first_image_from_dir(normals_dir)

    return "✅ Step 2 complete. Ready for UV Map.", image, gr.update(interactive=True), gr.update(interactive=True)


@spaces.GPU()
def step3_uv_map(session_id):
    from pixel3dmm.lightning.p3dmm_system import system as p3dmm_system

    base_conf = OmegaConf.load("configs/base.yaml")

    if "uv_model" not in _model_cache:
        model = p3dmm_system.load_from_checkpoint(f"{env_paths.CKPT_UV_PRED}", strict=False)
        model = model.eval().to(DEVICE)
        _model_cache["uv_model"] = model

    base_conf.video_name = f'{session_id}'
    base_conf.model.prediction_type = "uv_map"
    normals_n_uvs(base_conf, _model_cache["uv_model"])

    uv_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id, "p3dmm", "uv_map")
    image = first_image_from_dir(uv_dir)

    return "✅ Step 3 complete. Ready for Tracking.", image, gr.update(interactive=True), gr.update(interactive=True)


@spaces.GPU()
def step4_track(session_id):
    from pixel3dmm.tracking.flame.FLAME import FLAME
    from pixel3dmm.tracking.renderer_nvdiffrast import NVDRenderer
    from pixel3dmm.tracking.tracker import Tracker

    tracking_conf = OmegaConf.load("configs/tracking.yaml")

    if "flame_model" not in _model_cache:
        flame = FLAME(tracking_conf)
        flame = flame.to(DEVICE)
        _model_cache["flame_model"] = flame

        # Build the differentiable renderer once, from the FLAME head
        # template mesh.
        _mesh_file = env_paths.head_template
        _model_cache["diff_renderer"] = NVDRenderer(
            image_size=tracking_conf.size,
            obj_filename=_mesh_file,
            no_sh=False,
            white_bg=True
        ).to(DEVICE)

    flame_model = _model_cache["flame_model"]
    diff_renderer = _model_cache["diff_renderer"]

    tracking_conf.video_name = f'{session_id}'
    tracker = Tracker(tracking_conf, flame_model, diff_renderer)
    tracker.run()

    tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
    image = first_image_from_dir(tracking_dir)

    return "✅ Pipeline complete!", image, gr.update(interactive=True)


@spaces.GPU(duration=120)
def generate_results_and_mesh(image, session_id=None):
    """
    Process an input image through the 3D reconstruction pipeline and return
    the intermediate outputs and the mesh file.

    The multi-step workflow goes from a raw input image to a reconstructed 3D mesh:
    1. **Preprocessing**: crops and masks the image to isolate the face.
    2. **Normals Estimation**: computes surface normal maps.
    3. **UV Mapping**: generates UV coordinate maps for texturing.
    4. **Tracking**: performs final alignment/tracking to prepare for mesh export.
    5. **Mesh Discovery**: locates the resulting `.glb` file in the tracking output directory.

    Args:
        image (PIL.Image.Image or ndarray): Input image to reconstruct.
        session_id (str, optional): Unique identifier for this session's output
            directories; a random one is generated if omitted.

    Returns:
        tuple:
            - final_status (str): Newline-separated status messages from each pipeline step.
            - crop_img (Image or None): Cropped and preprocessed image.
            - normals_img (Image or None): Estimated surface normals visualization.
            - uv_img (Image or None): UV-map visualization.
            - track_img (Image or None): Tracking/registration result.
            - mesh_file (str or None): Path to the generated 3D mesh (`.glb`), if found.
    """
    if session_id is None:
        session_id = uuid.uuid4().hex

    status1, crop_img, _, _ = preprocess_image(image, session_id)
    if "❌" in status1:
        return status1, None, None, None, None, None

    status2, normals_img, _, _ = step2_normals(session_id)
    status3, uv_img, _, _ = step3_uv_map(session_id)
    status4, track_img, _ = step4_track(session_id)

    mesh_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "mesh")
    mesh_file = first_file_from_dir(mesh_dir, "glb")

    final_status = "\n".join([status1, status2, status3, status4])
    return final_status, crop_img, normals_img, uv_img, track_img, mesh_file
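
# Hypothetical programmatic use (bypassing the UI):
#   import numpy as np
#   from PIL import Image
#   img = np.array(Image.open("example_images/jim_carrey.png"))
#   status, crop, normals, uv, track, mesh = generate_results_and_mesh(img)
#   print(status, mesh)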


def cleanup(request: gr.Request):
    """Delete this session's preprocessing and tracking artifacts."""
    sid = request.session_hash
    if sid:
        d1 = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], sid)
        d2 = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], sid)
        shutil.rmtree(d1, ignore_errors=True)
        shutil.rmtree(d2, ignore_errors=True)


def start_session(request: gr.Request):
    # Gradio's per-browser session hash doubles as the working-directory key.
    return request.session_hash
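
# Lifecycle: demo.load stores the session hash in session_state when a tab
# opens, and demo.unload calls cleanup to remove that session's directories
# when the tab closes.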


css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
"""

with gr.Blocks(css=css) as demo:
    session_state = gr.State()
    demo.load(start_session, outputs=[session_state])

    gr.HTML(
        """
        <div style="text-align: center;">
            <p style="font-size:16px; display: inline; margin: 0;">
                <strong>Pixel3DMM [Image Mode]</strong> – Versatile Screen-Space Priors for Single-Image 3D Face Reconstruction.
            </p>
            <a href="https://github.com/SimonGiebenhain/pixel3dmm" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                <img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub Repo">
            </a>
        </div>
        """
    )

    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                image_in = gr.Image(label="Upload Image", type="numpy", height=512)
                run_btn = gr.Button("Reconstruct Face", variant="primary")

                with gr.Row():
                    crop_img = gr.Image(label="Preprocessed", height=128)
                    normals_img = gr.Image(label="Normals", height=128)
                    uv_img = gr.Image(label="UV Map", height=128)
                    track_img = gr.Image(label="Tracking", height=128)

            with gr.Column():
                mesh_file = gr.Model3D(label="3D Model Preview", height=512)

        examples = gr.Examples(
            examples=[
                ["example_images/jennifer_lawrence.png"],
                ["example_images/brendan_fraser.png"],
                ["example_images/jim_carrey.png"],
            ],
            inputs=[image_in],
            # A throwaway gr.State() absorbs the status string so cached
            # example runs don't clobber the visible status box.
            outputs=[gr.State(), crop_img, normals_img, uv_img, track_img, mesh_file],
            fn=generate_results_and_mesh,
            cache_examples=True
        )

        status = gr.Textbox(label="Status", lines=5, interactive=True, value="Upload an image to start.")

    run_btn.click(
        fn=generate_results_and_mesh,
        inputs=[image_in, session_state],
        outputs=[status, crop_img, normals_img, uv_img, track_img, mesh_file]
    )

    image_in.upload(
        fn=reset_all,
        inputs=None,
        outputs=[crop_img, normals_img, uv_img, track_img, mesh_file, status, run_btn]
    )

    demo.unload(cleanup)

demo.queue()
demo.launch(share=True)