from PIL import Image
from typing import Any
import rembg
import numpy as np
from torchvision import transforms
from plyfile import PlyData, PlyElement

import os
import torch

from .camera_utils import get_loop_cameras
from .graphics_utils import getProjectionMatrix
from .general_utils import matrix_to_quaternion

def remove_background(image, rembg_session):
    # Only run rembg if the image does not already carry a meaningful alpha channel.
    do_remove = True
    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
        do_remove = False
    if do_remove:
        image = rembg.remove(image, session=rembg_session)
    return image


def set_white_background(image):
    # Composite an RGBA image over a white background and return an RGB image.
    image = np.array(image).astype(np.float32) / 255.0
    mask = image[:, :, 3:4]
    image = image[:, :, :3] * mask + (1 - mask)
    image = Image.fromarray((image * 255.0).astype(np.uint8))
    return image

def resize_foreground(image, ratio):
    image = np.array(image)
    assert image.shape[-1] == 4
    alpha = np.where(image[..., 3] > 0)
    # modify so that cropping doesn't change the world center
    y1, y2, x1, x2 = (
        alpha[0].min(),
        alpha[0].max(),
        alpha[1].min(),
        alpha[1].max(),
    )
    # crop the foreground
    fg = image[y1:y2, x1:x2]
    # pad to square
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((255, 255), (255, 255), (0, 0)),
    )
    # compute padding according to the ratio
    new_size = int(new_image.shape[0] / ratio)
    # pad to new_size, equally on both sides
    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
    new_image = np.pad(
        new_image,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((255, 255), (255, 255), (0, 0)),
    )
    new_image = Image.fromarray(new_image)
    return new_image

def resize_to_128(img):
    img = transforms.functional.resize(
        img, 128, interpolation=transforms.InterpolationMode.LANCZOS)
    return img

def to_tensor(img):
    # Expects an HWC uint8 array; returns a CHW float tensor in [0, 1].
    img = torch.tensor(img).permute(2, 0, 1) / 255.0
    return img

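# Illustrative sketch (not part of the original module): one plausible way to chain
# the preprocessing helpers above for a single input image. The 0.85 foreground
# ratio and the exact call order are assumptions; the demo that uses these
# utilities may differ.
def _example_preprocess(image_path, rembg_session):
    img = Image.open(image_path).convert("RGBA")
    img = remove_background(img, rembg_session)   # RGBA with background removed
    img = resize_foreground(img, ratio=0.85)      # crop to the object, pad back to a square
    img = set_white_background(img)               # composite onto white, drop alpha
    img = resize_to_128(img)                      # network input resolution
    return to_tensor(np.array(img))               # CHW float tensor in [0, 1]
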
def get_source_camera_v2w_rmo_and_quats(num_imgs_in_loop=200):
    source_camera = get_loop_cameras(num_imgs_in_loop=num_imgs_in_loop)[0]
    source_camera = torch.from_numpy(source_camera).transpose(0, 1).unsqueeze(0)

    qs = []
    for c_idx in range(source_camera.shape[0]):
        qs.append(matrix_to_quaternion(source_camera[c_idx, :3, :3].transpose(0, 1)))

    return source_camera.unsqueeze(0), torch.stack(qs, dim=0).unsqueeze(0)

def get_target_cameras(num_imgs_in_loop=200):
    """
    Returns camera parameters for rendering a loop around the object:
        world_view_transforms,
        full_proj_transforms,
        camera_centers
    """
    projection_matrix = getProjectionMatrix(
        znear=0.8, zfar=3.2,
        fovX=49.134342641202636 * 2 * np.pi / 360,
        fovY=49.134342641202636 * 2 * np.pi / 360).transpose(0, 1)
    target_cameras = get_loop_cameras(num_imgs_in_loop=num_imgs_in_loop,
                                      max_elevation=np.pi/4,
                                      elevation_freq=1.5)
    world_view_transforms = []
    view_world_transforms = []
    camera_centers = []

    for loop_camera_c2w_cmo in target_cameras:
        view_world_transform = torch.from_numpy(loop_camera_c2w_cmo).transpose(0, 1)
        world_view_transform = torch.from_numpy(loop_camera_c2w_cmo).inverse().transpose(0, 1)
        camera_center = view_world_transform[3, :3].clone()

        world_view_transforms.append(world_view_transform)
        view_world_transforms.append(view_world_transform)
        camera_centers.append(camera_center)

    world_view_transforms = torch.stack(world_view_transforms)
    view_world_transforms = torch.stack(view_world_transforms)
    camera_centers = torch.stack(camera_centers)

    full_proj_transforms = world_view_transforms.bmm(projection_matrix.unsqueeze(0).expand(
        world_view_transforms.shape[0], 4, 4))

    return world_view_transforms, full_proj_transforms, camera_centers

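# Illustrative sketch (not part of the original module): the matrices returned above
# are transposed (row-major order, translation in the last row), which implies the
# row-vector convention used by the Gaussian Splatting renderer, so a world-space
# point is projected by right-multiplication with the full projection transform.
def _example_project_point(point_world, full_proj_transform):
    p_hom = torch.cat([point_world, torch.ones(1)]).to(full_proj_transform.dtype)  # homogeneous coords
    p_clip = p_hom @ full_proj_transform                                            # clip-space 4-vector
    return p_clip[:3] / p_clip[3]                                                   # perspective divide -> NDC
# e.g.: wvt, fpt, centers = get_target_cameras()
#       ndc = _example_project_point(torch.zeros(3), fpt[0])
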
def construct_list_of_attributes():
    # Taken from the Gaussian Splatting repo.
    l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
    # 3 channels for the SH DC component
    for i in range(3):
        l.append('f_dc_{}'.format(i))
    # 9 channels for the remaining SH coefficients (all channels except the 3 DC; SH order 1)
    for i in range(9):
        l.append('f_rest_{}'.format(i))
    l.append('opacity')
    for i in range(3):
        l.append('scale_{}'.format(i))
    for i in range(4):
        l.append('rot_{}'.format(i))
    return l

def export_to_obj(reconstruction, ply_out_path):
    """
    Despite the name, this writes a .ply file in the Gaussian Splatting format.
    Args:
        reconstruction: dict with xyz, opacity, features_dc, etc., each with a leading batch dimension
        ply_out_path: file path where to save the output
    """
    os.makedirs(os.path.dirname(ply_out_path), exist_ok=True)

    for k, v in reconstruction.items():
        # check dimensions
        if k not in ["features_dc", "features_rest"]:
            assert len(v.shape) == 3, "Unexpected size for {}".format(k)
        else:
            assert len(v.shape) == 4, "Unexpected size for {}".format(k)
        assert v.shape[0] == 1, "Expected batch size to be 1"
        reconstruction[k] = v[0]

    xyz = reconstruction["xyz"].detach().cpu().numpy()
    normals = np.zeros_like(xyz)
    f_dc = reconstruction["features_dc"].detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
    f_rest = reconstruction["features_rest"].detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
    opacities = reconstruction["opacity"].detach().cpu().numpy()
    scale = reconstruction["scaling"].detach().cpu().numpy()
    rotation = reconstruction["rotation"].detach().cpu().numpy()

    dtype_full = [(attribute, 'f4') for attribute in construct_list_of_attributes()]

    elements = np.empty(xyz.shape[0], dtype=dtype_full)
    attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
    elements[:] = list(map(tuple, attributes))
    el = PlyElement.describe(elements, 'vertex')
    PlyData([el]).write(ply_out_path)
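

# Illustrative sketch (not part of the original module): a dummy reconstruction dict
# with shapes that satisfy export_to_obj's checks (batch size 1, N Gaussians). The
# exact trailing feature dimensions ([1, 3] for DC, [3, 3] for the rest) are
# assumptions consistent with the 3 + 9 channels listed in construct_list_of_attributes.
def _example_export(ply_out_path="./out/example.ply", n_gaussians=16):
    reconstruction = {
        "xyz": torch.zeros(1, n_gaussians, 3),
        "features_dc": torch.zeros(1, n_gaussians, 1, 3),
        "features_rest": torch.zeros(1, n_gaussians, 3, 3),
        "opacity": torch.zeros(1, n_gaussians, 1),
        "scaling": torch.ones(1, n_gaussians, 3) * 0.01,
        "rotation": torch.cat(
            [torch.ones(1, n_gaussians, 1), torch.zeros(1, n_gaussians, 3)], dim=-1),  # identity quaternion
    }
    export_to_obj(reconstruction, ply_out_path)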