|
|
import os |
|
|
import tyro |
|
|
import mediapy |
|
|
import torch |
|
|
import numpy as np |
|
|
import pyvista as pv |
|
|
import trimesh |
|
|
from PIL import Image |
|
|
|
|
|
from dreifus.matrix import Intrinsics, Pose, CameraCoordinateConvention, PoseType |
|
|
from dreifus.pyvista import add_camera_frustum, render_from_camera |
|
|
|
|
|
from pixel3dmm.utils.utils_3d import rotation_6d_to_matrix |
|
|
from pixel3dmm.env_paths import PREPROCESSED_DATA, TRACKING_OUTPUT |
|
|
|
|
|
|
|
|
def _place_mesh(mesh: "trimesh.Trimesh", ckpt: dict, head_rot: np.ndarray, head_centric: bool) -> None:
    """Transform ``mesh.vertices`` in place into the coordinate frame being visualised.

    Args:
        mesh: Tracked mesh; its ``vertices`` attribute is overwritten.
        ckpt: Per-frame checkpoint loaded from ``<tracking_dir>/checkpoint``.
        head_rot: 3x3 head rotation decoded from ``ckpt['flame']['R']``.
        head_centric: If False, apply the rigid head transform
            ``v @ head_rot.T + ckpt['flame']['t']``. This is the same rotation and
            translation that ``_world_to_cam_matrix`` uses as ``flame2world``.
            If True, apply the inverse of ``ckpt['joint_transforms'][0, 1]`` in
            homogeneous coordinates instead.
    """
    verts = mesh.vertices
    if not head_centric:
        mesh.vertices = verts @ head_rot.T + ckpt['flame']['t']
        return

    # Append w=1 so the 4x4 joint transform can be applied with one matmul.
    ones = np.ones_like(verts[..., :1])
    verts_hom = np.concatenate([verts, ones], axis=-1)
    joint_inv = np.linalg.inv(ckpt['joint_transforms'][0, 1, :, :])
    mesh.vertices = (verts_hom @ joint_inv.T)[..., :3]


def _world_to_cam_matrix(ckpt: dict, head_rot: np.ndarray, head_centric: bool) -> np.ndarray:
    """Build the 4x4 OpenGL-convention world-to-camera matrix for one frame.

    The base extrinsics are taken from ``ckpt['camera']['R_base_0'][0]`` and
    ``ckpt['camera']['t_base_0'][0]``.

    In head-centric mode, the matrix is right-multiplied by:
      - ``flame2world``, built from ``head_rot`` and ``ckpt['flame']['t']``;
      - then ``ckpt['joint_transforms'][0, 1]``.

    This keeps the camera consistent with meshes placed by ``_place_mesh``
    with ``head_centric=True``.
    """
    w2c = np.eye(4)
    w2c[:3, :3] = ckpt['camera']['R_base_0'][0]
    w2c[:3, 3] = ckpt['camera']['t_base_0'][0]
    if head_centric:
        flame2world = np.eye(4)
        flame2world[:3, :3] = head_rot
        flame2world[:3, 3] = np.squeeze(ckpt['flame']['t'])
        w2c = w2c @ flame2world @ ckpt['joint_transforms'][0, 1, :, :]
    return w2c


def _intrinsics_matrix(ckpt: dict, size: int = 256) -> np.ndarray:
    """Return the 3x3 pinhole intrinsics, in pixels, for a ``size`` x ``size`` render.

    The focal length ``ckpt['camera']['fl'][0, 0]`` is scaled by ``size`` and used
    for both axes. The principal point ``ckpt['camera']['pp'][0]`` is mapped with
    ``pp * (size / 2 + 0.5) + size / 2 + 0.5``.

    NOTE(review): the conversion is assumed to match the tracker's normalised
    camera convention at 256 px; confirm it before using another ``size``.
    """
    fl = ckpt['camera']['fl'][0, 0]
    K = np.eye(3)
    K[0, 0] = fl * size
    K[1, 1] = fl * size
    K[:2, 2] = ckpt['camera']['pp'][0] * (size / 2 + 0.5) + size / 2 + 0.5
    return K


def _projection_overlay(mesh: "trimesh.Trimesh", w2c: "Pose", intr: "Intrinsics",
                        gt_path: str, size: int = 256) -> np.ndarray:
    """Render ``mesh`` off-screen from the given camera and blend it with an image.

    The rendered RGB and the ground-truth image at ``gt_path`` (resized to
    ``size`` x ``size`` and converted to RGB) are averaged 50/50.

    Returns:
        The blended image as ``uint8``.
    """
    pll = pv.Plotter(off_screen=True, window_size=(size, size))
    try:
        pll.add_mesh(mesh)
        img = render_from_camera(pll, w2c, intr)
    finally:
        # Release the off-screen render window now instead of waiting for GC.
        # Guarded in case the render helper already closed the plotter.
        if pll.render_window is not None:
            pll.close()

    # convert('RGB') guarantees 3 channels, e.g. for a grayscale JPEG crop.
    with Image.open(gt_path) as gt:
        gt_img = np.array(gt.convert('RGB').resize((size, size)))

    # Only the RGB channels of the render are used; the whole frame,
    # background included, is blended with the ground truth.
    return (gt_img * 0.5 + img[..., :3] * 0.5).astype(np.uint8)


def main(vid_name: str,
         HEAD_CENTRIC: bool = True,
         DO_PROJECTION_TEST: bool = False,
         ):
    """Visualise tracked meshes and their cameras for one tracked video.

    The function:
      - loads every ``*.ply`` mesh from ``<tracking_dir>/mesh/``, skipping files
        whose name contains ``canonical``;
      - loads every ``*.frame`` checkpoint from ``<tracking_dir>/checkpoint/``;
      - pairs meshes and checkpoints by sorted filename;
      - shows all meshes with their camera frustums in one pyvista window,
        coloured from blue (first frame) to red (last frame).

    Args:
        vid_name: Name of the tracked video. It selects the directory
            ``{TRACKING_OUTPUT}/{vid_name}_nV1_noPho_uv2000.0_n1000.0``.
        HEAD_CENTRIC: If True, express meshes relative to
            ``ckpt['joint_transforms'][0, 1]`` and move the cameras to match.
            Otherwise apply the FLAME rigid transform to the meshes and keep
            the base cameras.
        DO_PROJECTION_TEST: If True, also render each mesh from its tracked camera
            at 256x256 and blend it 50/50 with the matching cropped input frame.
            The result is written to ``<tracking_dir>/projection_test.mp4``.

    Raises:
        ValueError: If there are fewer checkpoints than meshes.
    """
    tracking_dir = f'{TRACKING_OUTPUT}/{vid_name}_nV1_noPho_uv2000.0_n1000.0'

    meshes = sorted(f for f in os.listdir(f'{tracking_dir}/mesh/')
                    if f.endswith('.ply') and 'canonical' not in f)
    ckpts = sorted(f for f in os.listdir(f'{tracking_dir}/checkpoint/')
                   if f.endswith('.frame'))
    if len(ckpts) < len(meshes):
        # Fail up front instead of with an IndexError after partial work.
        raise ValueError(
            f'{tracking_dir}: found {len(meshes)} meshes but only '
            f'{len(ckpts)} checkpoints; every mesh needs a matching checkpoint')

    n_steps = len(meshes)
    pl = pv.Plotter()
    vid_frames = []

    # zip stops at the last mesh, so extra checkpoints are ignored.
    for i, (mesh_name, ckpt_name) in enumerate(zip(meshes, ckpts)):
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        ckpt = torch.load(f'{tracking_dir}/checkpoint/{ckpt_name}', weights_only=False)
        mesh = trimesh.load(f'{tracking_dir}/mesh/{mesh_name}', process=False)

        # Decode the 6D rotation representation and take the first batch entry.
        head_rot = rotation_6d_to_matrix(torch.from_numpy(ckpt['flame']['R'])).numpy()[0]

        _place_mesh(mesh, ckpt, head_rot, HEAD_CENTRIC)

        w2c = Pose(_world_to_cam_matrix(ckpt, head_rot, HEAD_CENTRIC),
                   camera_coordinate_convention=CameraCoordinateConvention.OPEN_GL,
                   pose_type=PoseType.WORLD_2_CAM)
        intr = Intrinsics(_intrinsics_matrix(ckpt))

        # RGB colour ramp: blue for the first frame, towards red for the last.
        color = [i / n_steps, 0, (n_steps - i) / n_steps]
        pl.add_mesh(mesh, color=color)
        add_camera_frustum(pl, w2c, intr, color=color)

        if DO_PROJECTION_TEST:
            # NOTE(review): assumes the i-th sorted mesh corresponds to cropped
            # frame i; confirm against the tracker's output naming.
            gt_path = f'{PREPROCESSED_DATA}/{vid_name}/cropped/{i:05d}.jpg'
            vid_frames.append(_projection_overlay(mesh, w2c, intr, gt_path))

    # Write the video before pl.show(). show() blocks until the window is closed
    # and may fail without a display; the projection frames would then be lost.
    if DO_PROJECTION_TEST and vid_frames:
        mediapy.write_video(f'{tracking_dir}/projection_test.mp4', images=vid_frames)

    pl.show()
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: tyro builds the command-line interface from main's
    # signature, parses sys.argv and calls main with the parsed values.
    tyro.cli(main)