cavargas10 committed
Commit 261f862 · verified · Parent: d0e136b

Upload 5 files

trellis/renderers/__init__.py ADDED
@@ -0,0 +1,31 @@
+ import importlib
+
+ __attributes = {
+     'OctreeRenderer': 'octree_renderer',
+     'GaussianRenderer': 'gaussian_render',
+     'MeshRenderer': 'mesh_renderer',
+ }
+
+ __submodules = []
+
+ __all__ = list(__attributes.keys()) + __submodules
+
+ def __getattr__(name):
+     if name not in globals():
+         if name in __attributes:
+             module_name = __attributes[name]
+             module = importlib.import_module(f".{module_name}", __name__)
+             globals()[name] = getattr(module, name)
+         elif name in __submodules:
+             module = importlib.import_module(f".{name}", __name__)
+             globals()[name] = module
+         else:
+             raise AttributeError(f"module {__name__} has no attribute {name}")
+     return globals()[name]
+
+
+ # For Pylance
+ if __name__ == '__main__':
+     from .octree_renderer import OctreeRenderer
+     from .gaussian_render import GaussianRenderer
+     from .mesh_renderer import MeshRenderer
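
A note on the pattern above: the module-level `__getattr__` (PEP 562) defers each renderer import until first attribute access, so `import trellis.renderers` stays cheap even when the heavy CUDA extensions behind the individual renderers are not installed. The guarded imports at the bottom never run at runtime; they only help static analyzers such as Pylance resolve the names. A minimal usage sketch (the rendering options here are hypothetical values, not defaults):

# Nothing heavy is imported yet; __init__.py itself only needs importlib.
from trellis import renderers

# First attribute access triggers importlib.import_module('.mesh_renderer', ...)
# inside __getattr__, which caches the class in globals() for later lookups.
# (Instantiating MeshRenderer creates an nvdiffrast CUDA context, so a GPU is assumed.)
renderer = renderers.MeshRenderer(rendering_options={'resolution': 512, 'near': 0.1, 'far': 10.0})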
trellis/renderers/gaussian_render.py ADDED
@@ -0,0 +1,230 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact [email protected]
+ #
+
+ import torch
+ import math
+ from easydict import EasyDict as edict
+ import numpy as np
+ from ..representations.gaussian import Gaussian
+ from .sh_utils import eval_sh
+ import torch.nn.functional as F
+
+
+ def intrinsics_to_projection(
+     intrinsics: torch.Tensor,
+     near: float,
+     far: float,
+ ) -> torch.Tensor:
+     """
+     OpenCV intrinsics to OpenGL perspective matrix
+
+     Args:
+         intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
+         near (float): near plane to clip
+         far (float): far plane to clip
+     Returns:
+         (torch.Tensor): [4, 4] OpenGL perspective matrix
+     """
+     fx, fy = intrinsics[0, 0], intrinsics[1, 1]
+     cx, cy = intrinsics[0, 2], intrinsics[1, 2]
+     ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
+     ret[0, 0] = 2 * fx
+     ret[1, 1] = 2 * fy
+     ret[0, 2] = 2 * cx - 1
+     ret[1, 2] = -2 * cy + 1
+     ret[2, 2] = far / (far - near)
+     ret[2, 3] = near * far / (near - far)
+     ret[3, 2] = 1.
+     return ret
+
+
+ def render(viewpoint_camera, pc: Gaussian, pipe, bg_color: torch.Tensor, scaling_modifier=1.0, override_color=None):
+     """
+     Render the scene.
+
+     Background tensor (bg_color) must be on GPU!
+     """
+     # Lazy import: only load the CUDA rasterizer when rendering is actually requested.
+     if 'GaussianRasterizer' not in globals():
+         from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings
+
+     # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
+     screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
+     try:
+         screenspace_points.retain_grad()
+     except Exception:
+         pass
+     # Set up rasterization configuration
+     tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
+     tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
+
+     kernel_size = pipe.kernel_size
+     subpixel_offset = torch.zeros((int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), dtype=torch.float32, device="cuda")
+
+     raster_settings = GaussianRasterizationSettings(
+         image_height=int(viewpoint_camera.image_height),
+         image_width=int(viewpoint_camera.image_width),
+         tanfovx=tanfovx,
+         tanfovy=tanfovy,
+         kernel_size=kernel_size,
+         subpixel_offset=subpixel_offset,
+         bg=bg_color,
+         scale_modifier=scaling_modifier,
+         viewmatrix=viewpoint_camera.world_view_transform,
+         projmatrix=viewpoint_camera.full_proj_transform,
+         sh_degree=pc.active_sh_degree,
+         campos=viewpoint_camera.camera_center,
+         prefiltered=False,
+         debug=pipe.debug
+     )
+
+     rasterizer = GaussianRasterizer(raster_settings=raster_settings)
+
+     means3D = pc.get_xyz
+     means2D = screenspace_points
+     opacity = pc.get_opacity
+
+     # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
+     # scaling / rotation by the rasterizer.
+     scales = None
+     rotations = None
+     cov3D_precomp = None
+     if pipe.compute_cov3D_python:
+         cov3D_precomp = pc.get_covariance(scaling_modifier)
+     else:
+         scales = pc.get_scaling
+         rotations = pc.get_rotation
+
+     # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
+     # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by the rasterizer.
+     shs = None
+     colors_precomp = None
+     if override_color is None:
+         if pipe.convert_SHs_python:
+             shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1) ** 2)
+             dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
+             dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
+             sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
+             colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
+         else:
+             shs = pc.get_features
+     else:
+         colors_precomp = override_color
+
+     # Rasterize visible Gaussians to image, obtain their radii (on screen).
+     rendered_image, radii = rasterizer(
+         means3D=means3D,
+         means2D=means2D,
+         shs=shs,
+         colors_precomp=colors_precomp,
+         opacities=opacity,
+         scales=scales,
+         rotations=rotations,
+         cov3D_precomp=cov3D_precomp
+     )
+
+     # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
+     # They will be excluded from value updates used in the splitting criteria.
+     return edict({"render": rendered_image,
+                   "viewspace_points": screenspace_points,
+                   "visibility_filter": radii > 0,
+                   "radii": radii})
+
+
+ class GaussianRenderer:
+     """
+     Renderer for the Gaussian representation.
+
+     Args:
+         rendering_options (dict): Rendering options.
+     """
+
+     def __init__(self, rendering_options={}) -> None:
+         self.pipe = edict({
+             "kernel_size": 0.1,
+             "convert_SHs_python": False,
+             "compute_cov3D_python": False,
+             "scale_modifier": 1.0,
+             "debug": False
+         })
+         self.rendering_options = edict({
+             "resolution": None,
+             "near": None,
+             "far": None,
+             "ssaa": 1,
+             "bg_color": 'random',
+         })
+         self.rendering_options.update(rendering_options)
+         self.bg_color = None
+
+     def render(
+         self,
+         gaussian: Gaussian,
+         extrinsics: torch.Tensor,
+         intrinsics: torch.Tensor,
+         colors_overwrite: torch.Tensor = None
+     ) -> edict:
+         """
+         Render the Gaussian model.
+
+         Args:
+             gaussian (Gaussian): Gaussian model to render
+             extrinsics (torch.Tensor): (4, 4) camera extrinsics
+             intrinsics (torch.Tensor): (3, 3) camera intrinsics
+             colors_overwrite (torch.Tensor): (N, 3) override color
+
+         Returns:
+             edict containing:
+                 color (torch.Tensor): (3, H, W) rendered color image
+         """
+         resolution = self.rendering_options["resolution"]
+         near = self.rendering_options["near"]
+         far = self.rendering_options["far"]
+         ssaa = self.rendering_options["ssaa"]
+
+         if self.rendering_options["bg_color"] == 'random':
+             self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
+             if np.random.rand() < 0.5:
+                 self.bg_color += 1
+         else:
+             self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda")
+
+         view = extrinsics
+         perspective = intrinsics_to_projection(intrinsics, near, far)
+         camera = torch.inverse(view)[:3, 3]
+         focalx = intrinsics[0, 0]
+         focaly = intrinsics[1, 1]
+         fovx = 2 * torch.atan(0.5 / focalx)
+         fovy = 2 * torch.atan(0.5 / focaly)
+
+         camera_dict = edict({
+             "image_height": resolution * ssaa,
+             "image_width": resolution * ssaa,
+             "FoVx": fovx,
+             "FoVy": fovy,
+             "znear": near,
+             "zfar": far,
+             "world_view_transform": view.T.contiguous(),
+             "projection_matrix": perspective.T.contiguous(),
+             "full_proj_transform": (perspective @ view).T.contiguous(),
+             "camera_center": camera
+         })
+
+         # Render
+         render_ret = render(camera_dict, gaussian, self.pipe, self.bg_color, override_color=colors_overwrite, scaling_modifier=self.pipe.scale_modifier)
+
+         if ssaa > 1:
+             render_ret.render = F.interpolate(render_ret.render[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
+
+         ret = edict({
+             'color': render_ret['render']
+         })
+         return ret
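
Both `intrinsics_to_projection` and the field-of-view computation (`fov = 2 * atan(0.5 / focal)`) appear to assume normalized intrinsics, i.e. focal lengths and principal point expressed as fractions of the image size rather than pixels (note `ret[0, 0] = 2 * fx` with no division by image width). A hedged usage sketch under that assumption; `gaussian` stands in for a `Gaussian` instance produced elsewhere in the pipeline, and a CUDA build of `diff_gaussian_rasterization` is required:

import torch
from trellis.renderers import GaussianRenderer

renderer = GaussianRenderer({'resolution': 512, 'near': 0.1, 'far': 10.0, 'ssaa': 2, 'bg_color': (0, 0, 0)})

# Normalized pinhole intrinsics: unit focal length, principal point at the center.
intrinsics = torch.tensor([[1.0, 0.0, 0.5],
                           [0.0, 1.0, 0.5],
                           [0.0, 0.0, 1.0]], device='cuda')
extrinsics = torch.eye(4, device='cuda')
extrinsics[2, 3] = 2.0  # world-to-camera: scene sits 2 units in front of the camera

out = renderer.render(gaussian, extrinsics, intrinsics)  # `gaussian` built elsewhere
image = out['color']  # (3, 512, 512), rendered at 1024^2 then downsampled (ssaa=2)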
trellis/renderers/mesh_renderer.py ADDED
@@ -0,0 +1,133 @@
+ import torch
+ import nvdiffrast.torch as dr
+ from easydict import EasyDict as edict
+ from ..representations.mesh import MeshExtractResult
+ import torch.nn.functional as F
+
+
+ def intrinsics_to_projection(
+     intrinsics: torch.Tensor,
+     near: float,
+     far: float,
+ ) -> torch.Tensor:
+     """
+     OpenCV intrinsics to OpenGL perspective matrix
+
+     Args:
+         intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
+         near (float): near plane to clip
+         far (float): far plane to clip
+     Returns:
+         (torch.Tensor): [4, 4] OpenGL perspective matrix
+     """
+     fx, fy = intrinsics[0, 0], intrinsics[1, 1]
+     cx, cy = intrinsics[0, 2], intrinsics[1, 2]
+     ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
+     ret[0, 0] = 2 * fx
+     ret[1, 1] = 2 * fy
+     ret[0, 2] = 2 * cx - 1
+     ret[1, 2] = -2 * cy + 1
+     ret[2, 2] = far / (far - near)
+     ret[2, 3] = near * far / (near - far)
+     ret[3, 2] = 1.
+     return ret
+
+
+ class MeshRenderer:
+     """
+     Renderer for the Mesh representation.
+
+     Args:
+         rendering_options (dict): Rendering options.
+         glctx (nvdiffrast.torch.RasterizeCudaContext): rasterization context for CUDA rendering.
+     """
+     def __init__(self, rendering_options={}, device='cuda'):
+         self.rendering_options = edict({
+             "resolution": None,
+             "near": None,
+             "far": None,
+             "ssaa": 1
+         })
+         self.rendering_options.update(rendering_options)
+         self.glctx = dr.RasterizeCudaContext(device=device)
+         self.device = device
+
+     def render(
+         self,
+         mesh: MeshExtractResult,
+         extrinsics: torch.Tensor,
+         intrinsics: torch.Tensor,
+         return_types=["mask", "normal", "depth"]
+     ) -> edict:
+         """
+         Render the mesh.
+
+         Args:
+             mesh (MeshExtractResult): mesh model to render
+             extrinsics (torch.Tensor): (4, 4) camera extrinsics
+             intrinsics (torch.Tensor): (3, 3) camera intrinsics
+             return_types (list): list of return types; can be "mask", "depth", "normal_map", "normal", "color"
+
+         Returns:
+             edict based on return_types containing:
+                 color (torch.Tensor): [3, H, W] rendered color image
+                 depth (torch.Tensor): [H, W] rendered depth image
+                 normal (torch.Tensor): [3, H, W] rendered normal image
+                 normal_map (torch.Tensor): [3, H, W] rendered normal map image
+                 mask (torch.Tensor): [H, W] rendered mask image
+         """
+         resolution = self.rendering_options["resolution"]
+         near = self.rendering_options["near"]
+         far = self.rendering_options["far"]
+         ssaa = self.rendering_options["ssaa"]
+
+         if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0:
+             default_img = torch.zeros((1, resolution, resolution, 3), dtype=torch.float32, device=self.device)
+             ret_dict = {k: default_img if k in ['normal', 'normal_map', 'color'] else default_img[..., :1] for k in return_types}
+             return ret_dict
+
+         perspective = intrinsics_to_projection(intrinsics, near, far)
+
+         RT = extrinsics.unsqueeze(0)
+         full_proj = (perspective @ extrinsics).unsqueeze(0)
+
+         vertices = mesh.vertices.unsqueeze(0)
+
+         vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1)
+         vertices_camera = torch.bmm(vertices_homo, RT.transpose(-1, -2))
+         vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2))
+         faces_int = mesh.faces.int()
+         rast, _ = dr.rasterize(
+             self.glctx, vertices_clip, faces_int, (resolution * ssaa, resolution * ssaa))
+
+         out_dict = edict()
+         for type in return_types:
+             img = None
+             if type == "mask":
+                 img = dr.antialias((rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int)
+             elif type == "depth":
+                 img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces_int)[0]
+                 img = dr.antialias(img, rast, vertices_clip, faces_int)
+             elif type == "normal":
+                 img = dr.interpolate(
+                     mesh.face_normal.reshape(1, -1, 3), rast,
+                     torch.arange(mesh.faces.shape[0] * 3, device=self.device, dtype=torch.int).reshape(-1, 3)
+                 )[0]
+                 img = dr.antialias(img, rast, vertices_clip, faces_int)
+                 # Remap normals from [-1, 1] to [0, 1] for visualization
+                 img = (img + 1) / 2
+             elif type == "normal_map":
+                 img = dr.interpolate(mesh.vertex_attrs[:, 3:].contiguous(), rast, faces_int)[0]
+                 img = dr.antialias(img, rast, vertices_clip, faces_int)
+             elif type == "color":
+                 img = dr.interpolate(mesh.vertex_attrs[:, :3].contiguous(), rast, faces_int)[0]
+                 img = dr.antialias(img, rast, vertices_clip, faces_int)
+
+             if ssaa > 1:
+                 img = F.interpolate(img.permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True)
+                 img = img.squeeze()
+             else:
+                 img = img.permute(0, 3, 1, 2).squeeze()
+             out_dict[type] = img
+
+         return out_dict
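
For orientation: `mask` is an antialiased coverage image, `depth` is interpolated camera-space z, and `normal` remaps face normals from [-1, 1] to [0, 1] for display. A hedged usage sketch; `mesh` stands in for a `MeshExtractResult` produced elsewhere in the pipeline, and nvdiffrast with CUDA support is assumed:

import torch
from trellis.renderers import MeshRenderer

renderer = MeshRenderer({'resolution': 512, 'near': 0.1, 'far': 10.0, 'ssaa': 2})

# Same normalized-intrinsics convention as the Gaussian renderer.
intrinsics = torch.tensor([[1.0, 0.0, 0.5],
                           [0.0, 1.0, 0.5],
                           [0.0, 0.0, 1.0]], device='cuda')
extrinsics = torch.eye(4, device='cuda')
extrinsics[2, 3] = 2.0  # world-to-camera translation

out = renderer.render(mesh, extrinsics, intrinsics, return_types=['mask', 'depth', 'normal'])  # `mesh` built elsewhere
mask, depth, normal = out['mask'], out['depth'], out['normal']  # (H, W), (H, W), (3, H, W)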
trellis/renderers/octree_renderer.py ADDED
@@ -0,0 +1,300 @@
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import math
+ import cv2
+ from scipy.stats import qmc
+ from easydict import EasyDict as edict
+ from ..representations.octree import DfsOctree
+
+
+ def intrinsics_to_projection(
+     intrinsics: torch.Tensor,
+     near: float,
+     far: float,
+ ) -> torch.Tensor:
+     """
+     OpenCV intrinsics to OpenGL perspective matrix
+
+     Args:
+         intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
+         near (float): near plane to clip
+         far (float): far plane to clip
+     Returns:
+         (torch.Tensor): [4, 4] OpenGL perspective matrix
+     """
+     fx, fy = intrinsics[0, 0], intrinsics[1, 1]
+     cx, cy = intrinsics[0, 2], intrinsics[1, 2]
+     ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
+     ret[0, 0] = 2 * fx
+     ret[1, 1] = 2 * fy
+     ret[0, 2] = 2 * cx - 1
+     ret[1, 2] = -2 * cy + 1
+     ret[2, 2] = far / (far - near)
+     ret[2, 3] = near * far / (near - far)
+     ret[3, 2] = 1.
+     return ret
+
+
+ def render(viewpoint_camera, octree: DfsOctree, pipe, bg_color: torch.Tensor, scaling_modifier=1.0, used_rank=None, colors_overwrite=None, aux=None, halton_sampler=None):
+     """
+     Render the scene.
+
+     Background tensor (bg_color) must be on GPU!
+     """
+     # Lazy import: only load the octree rasterizers when rendering is actually requested.
+     if 'OctreeTrivecRasterizer' not in globals():
+         from diffoctreerast import OctreeVoxelRasterizer, OctreeGaussianRasterizer, OctreeTrivecRasterizer, OctreeDecoupolyRasterizer
+
+     # Set up rasterization configuration
+     tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
+     tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
+
+     raster_settings = edict(
+         image_height=int(viewpoint_camera.image_height),
+         image_width=int(viewpoint_camera.image_width),
+         tanfovx=tanfovx,
+         tanfovy=tanfovy,
+         bg=bg_color,
+         scale_modifier=scaling_modifier,
+         viewmatrix=viewpoint_camera.world_view_transform,
+         projmatrix=viewpoint_camera.full_proj_transform,
+         sh_degree=octree.active_sh_degree,
+         campos=viewpoint_camera.camera_center,
+         with_distloss=pipe.with_distloss,
+         jitter=pipe.jitter,
+         debug=pipe.debug,
+     )
+
+     positions = octree.get_xyz
+     if octree.primitive == "voxel":
+         densities = octree.get_density
+     elif octree.primitive == "gaussian":
+         opacities = octree.get_opacity
+     elif octree.primitive == "trivec":
+         trivecs = octree.get_trivec
+         densities = octree.get_density
+         raster_settings.density_shift = octree.density_shift
+     elif octree.primitive == "decoupoly":
+         decoupolys_V, decoupolys_g = octree.get_decoupoly
+         densities = octree.get_density
+         raster_settings.density_shift = octree.density_shift
+     else:
+         raise ValueError(f"Unknown primitive {octree.primitive}")
+     depths = octree.get_depth
+
+     # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
+     # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by the rasterizer.
+     colors_precomp = None
+     shs = octree.get_features
+     if octree.primitive in ["voxel", "gaussian"] and colors_overwrite is not None:
+         colors_precomp = colors_overwrite
+         shs = None
+
+     ret = edict()
+
+     if octree.primitive == "voxel":
+         renderer = OctreeVoxelRasterizer(raster_settings=raster_settings)
+         rgb, depth, alpha, distloss = renderer(
+             positions=positions,
+             densities=densities,
+             shs=shs,
+             colors_precomp=colors_precomp,
+             depths=depths,
+             aabb=octree.aabb,
+             aux=aux,
+         )
+         ret['rgb'] = rgb
+         ret['depth'] = depth
+         ret['alpha'] = alpha
+         ret['distloss'] = distloss
+     elif octree.primitive == "gaussian":
+         renderer = OctreeGaussianRasterizer(raster_settings=raster_settings)
+         rgb, depth, alpha = renderer(
+             positions=positions,
+             opacities=opacities,
+             shs=shs,
+             colors_precomp=colors_precomp,
+             depths=depths,
+             aabb=octree.aabb,
+             aux=aux,
+         )
+         ret['rgb'] = rgb
+         ret['depth'] = depth
+         ret['alpha'] = alpha
+     elif octree.primitive == "trivec":
+         raster_settings.used_rank = used_rank if used_rank is not None else trivecs.shape[1]
+         renderer = OctreeTrivecRasterizer(raster_settings=raster_settings)
+         rgb, depth, alpha, percent_depth = renderer(
+             positions=positions,
+             trivecs=trivecs,
+             densities=densities,
+             shs=shs,
+             colors_precomp=colors_precomp,
+             colors_overwrite=colors_overwrite,
+             depths=depths,
+             aabb=octree.aabb,
+             aux=aux,
+             halton_sampler=halton_sampler,
+         )
+         ret['percent_depth'] = percent_depth
+         ret['rgb'] = rgb
+         ret['depth'] = depth
+         ret['alpha'] = alpha
+     elif octree.primitive == "decoupoly":
+         raster_settings.used_rank = used_rank if used_rank is not None else decoupolys_V.shape[1]
+         renderer = OctreeDecoupolyRasterizer(raster_settings=raster_settings)
+         rgb, depth, alpha = renderer(
+             positions=positions,
+             decoupolys_V=decoupolys_V,
+             decoupolys_g=decoupolys_g,
+             densities=densities,
+             shs=shs,
+             colors_precomp=colors_precomp,
+             depths=depths,
+             aabb=octree.aabb,
+             aux=aux,
+         )
+         ret['rgb'] = rgb
+         ret['depth'] = depth
+         ret['alpha'] = alpha
+
+     return ret
+
+
+ class OctreeRenderer:
+     """
+     Renderer for the Octree representation.
+
+     Args:
+         rendering_options (dict): Rendering options.
+     """
+
+     def __init__(self, rendering_options={}) -> None:
+         try:
+             import diffoctreerast
+         except ImportError:
+             print("\033[93m[WARNING] diffoctreerast is not installed. The renderer will be disabled.\033[0m")
+             self.unsupported = True
+         else:
+             self.unsupported = False
+
+         self.pipe = edict({
+             "with_distloss": False,
+             "with_aux": False,
+             "scale_modifier": 1.0,
+             "used_rank": None,
+             "jitter": False,
+             "debug": False,
+         })
+         self.rendering_options = edict({
+             "resolution": None,
+             "near": None,
+             "far": None,
+             "ssaa": 1,
+             "bg_color": 'random',
+         })
+         self.halton_sampler = qmc.Halton(2, scramble=False)
+         self.rendering_options.update(rendering_options)
+         self.bg_color = None
+
+     def render(
+         self,
+         octree: DfsOctree,
+         extrinsics: torch.Tensor,
+         intrinsics: torch.Tensor,
+         colors_overwrite: torch.Tensor = None,
+     ) -> edict:
+         """
+         Render the octree.
+
+         Args:
+             octree (DfsOctree): octree to render
+             extrinsics (torch.Tensor): (4, 4) camera extrinsics
+             intrinsics (torch.Tensor): (3, 3) camera intrinsics
+             colors_overwrite (torch.Tensor): (N, 3) override color
+
+         Returns:
+             edict containing:
+                 color (torch.Tensor): (3, H, W) rendered color
+                 depth (torch.Tensor): (H, W) rendered depth
+                 alpha (torch.Tensor): (H, W) rendered alpha
+                 distloss (Optional[torch.Tensor]): (H, W) rendered distance loss
+                 percent_depth (Optional[torch.Tensor]): (H, W) rendered percent depth
+                 aux (Optional[edict]): auxiliary tensors
+         """
+         resolution = self.rendering_options["resolution"]
+         near = self.rendering_options["near"]
+         far = self.rendering_options["far"]
+         ssaa = self.rendering_options["ssaa"]
+
+         if self.unsupported:
+             image = np.zeros((512, 512, 3), dtype=np.uint8)
+             text_bbox = cv2.getTextSize("Unsupported", cv2.FONT_HERSHEY_SIMPLEX, 2, 3)[0]
+             origin = (512 - text_bbox[0]) // 2, (512 - text_bbox[1]) // 2
+             image = cv2.putText(image, "Unsupported", origin, cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 3, cv2.LINE_AA)
+             return {
+                 'color': torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255,
+             }
+
+         if self.rendering_options["bg_color"] == 'random':
+             self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
+             if np.random.rand() < 0.5:
+                 self.bg_color += 1
+         else:
+             self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda")
+
+         if self.pipe["with_aux"]:
+             aux = {
+                 'grad_color2': torch.zeros((octree.num_leaf_nodes, 3), dtype=torch.float32, requires_grad=True, device="cuda") + 0,
+                 'contributions': torch.zeros((octree.num_leaf_nodes, 1), dtype=torch.float32, requires_grad=True, device="cuda") + 0,
+             }
+             for k in aux.keys():
+                 aux[k].requires_grad_()
+                 aux[k].retain_grad()
+         else:
+             aux = None
+
+         view = extrinsics
+         perspective = intrinsics_to_projection(intrinsics, near, far)
+         camera = torch.inverse(view)[:3, 3]
+         focalx = intrinsics[0, 0]
+         focaly = intrinsics[1, 1]
+         fovx = 2 * torch.atan(0.5 / focalx)
+         fovy = 2 * torch.atan(0.5 / focaly)
+
+         camera_dict = edict({
+             "image_height": resolution * ssaa,
+             "image_width": resolution * ssaa,
+             "FoVx": fovx,
+             "FoVy": fovy,
+             "znear": near,
+             "zfar": far,
+             "world_view_transform": view.T.contiguous(),
+             "projection_matrix": perspective.T.contiguous(),
+             "full_proj_transform": (perspective @ view).T.contiguous(),
+             "camera_center": camera
+         })
+
+         # Render
+         render_ret = render(camera_dict, octree, self.pipe, self.bg_color, aux=aux, colors_overwrite=colors_overwrite, scaling_modifier=self.pipe.scale_modifier, used_rank=self.pipe.used_rank, halton_sampler=self.halton_sampler)
+
+         if ssaa > 1:
+             render_ret.rgb = F.interpolate(render_ret.rgb[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
+             render_ret.depth = F.interpolate(render_ret.depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
+             render_ret.alpha = F.interpolate(render_ret.alpha[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
+             if hasattr(render_ret, 'percent_depth'):
+                 render_ret.percent_depth = F.interpolate(render_ret.percent_depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
+
+         ret = edict({
+             'color': render_ret.rgb,
+             'depth': render_ret.depth,
+             'alpha': render_ret.alpha,
+         })
+         if self.pipe["with_distloss"] and 'distloss' in render_ret:
+             ret['distloss'] = render_ret.distloss
+         if self.pipe["with_aux"]:
+             ret['aux'] = aux
+         if hasattr(render_ret, 'percent_depth'):
+             ret['percent_depth'] = render_ret.percent_depth
+         return ret
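
The returned keys depend on the octree's primitive: every path produces `color`, `depth`, and `alpha`, while `distloss` is voxel-only and `percent_depth` is trivec-only. A hedged usage sketch, assuming `diffoctreerast` is installed and `octree` is a `DfsOctree` built elsewhere in the pipeline (the near/far values here are illustrative, not defaults):

import torch
from trellis.renderers import OctreeRenderer

renderer = OctreeRenderer({'resolution': 512, 'near': 0.8, 'far': 1.6, 'bg_color': (0, 0, 0)})

intrinsics = torch.tensor([[1.0, 0.0, 0.5],
                           [0.0, 1.0, 0.5],
                           [0.0, 0.0, 1.0]], device='cuda')
extrinsics = torch.eye(4, device='cuda')
extrinsics[2, 3] = 1.2  # place the scene between the near and far planes

out = renderer.render(octree, extrinsics, intrinsics)  # `octree` built elsewhere
color, depth, alpha = out['color'], out['depth'], out['alpha']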
trellis/renderers/sh_utils.py ADDED
@@ -0,0 +1,118 @@
+ # Copyright 2021 The PlenOctree Authors.
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions are met:
+ #
+ # 1. Redistributions of source code must retain the above copyright notice,
+ # this list of conditions and the following disclaimer.
+ #
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
+ # this list of conditions and the following disclaimer in the documentation
+ # and/or other materials provided with the distribution.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ # POSSIBILITY OF SUCH DAMAGE.
+
+ import torch
+
+ C0 = 0.28209479177387814
+ C1 = 0.4886025119029199
+ C2 = [
+     1.0925484305920792,
+     -1.0925484305920792,
+     0.31539156525252005,
+     -1.0925484305920792,
+     0.5462742152960396
+ ]
+ C3 = [
+     -0.5900435899266435,
+     2.890611442640554,
+     -0.4570457994644658,
+     0.3731763325901154,
+     -0.4570457994644658,
+     1.445305721320277,
+     -0.5900435899266435
+ ]
+ C4 = [
+     2.5033429417967046,
+     -1.7701307697799304,
+     0.9461746957575601,
+     -0.6690465435572892,
+     0.10578554691520431,
+     -0.6690465435572892,
+     0.47308734787878004,
+     -1.7701307697799304,
+     0.6258357354491761,
+ ]
+
+
+ def eval_sh(deg, sh, dirs):
+     """
+     Evaluate spherical harmonics at unit directions
+     using hardcoded SH polynomials.
+     Works with torch/np/jnp.
+     ... Can be 0 or more batch dimensions.
+     Args:
+         deg: int SH deg. Currently, 0-4 supported
+         sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
+         dirs: jnp.ndarray unit directions [..., 3]
+     Returns:
+         [..., C]
+     """
+     assert deg <= 4 and deg >= 0
+     coeff = (deg + 1) ** 2
+     assert sh.shape[-1] >= coeff
+
+     result = C0 * sh[..., 0]
+     if deg > 0:
+         x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
+         result = (result -
+                   C1 * y * sh[..., 1] +
+                   C1 * z * sh[..., 2] -
+                   C1 * x * sh[..., 3])
+
+         if deg > 1:
+             xx, yy, zz = x * x, y * y, z * z
+             xy, yz, xz = x * y, y * z, x * z
+             result = (result +
+                       C2[0] * xy * sh[..., 4] +
+                       C2[1] * yz * sh[..., 5] +
+                       C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
+                       C2[3] * xz * sh[..., 7] +
+                       C2[4] * (xx - yy) * sh[..., 8])
+
+             if deg > 2:
+                 result = (result +
+                           C3[0] * y * (3 * xx - yy) * sh[..., 9] +
+                           C3[1] * xy * z * sh[..., 10] +
+                           C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
+                           C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
+                           C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
+                           C3[5] * z * (xx - yy) * sh[..., 14] +
+                           C3[6] * x * (xx - 3 * yy) * sh[..., 15])
+
+                 if deg > 3:
+                     result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
+                               C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
+                               C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
+                               C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
+                               C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
+                               C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
+                               C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
+                               C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
+                               C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
+     return result
+
+ def RGB2SH(rgb):
+     return (rgb - 0.5) / C0
+
+ def SH2RGB(sh):
+     return sh * C0 + 0.5
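
A quick sanity check on the degree-0 case: with `deg == 0` the view direction is ignored and `eval_sh` reduces to `C0 * sh[..., 0]`, so adding the usual 0.5 offset reproduces `SH2RGB` exactly (this mirrors the `sh2rgb + 0.5` step in `gaussian_render.py`). A minimal self-contained sketch, assuming the package imports cleanly:

import torch
from trellis.renderers.sh_utils import eval_sh, SH2RGB

sh = torch.randn(5, 3, 1)  # [..., C, (deg + 1) ** 2] with deg = 0
dirs = torch.randn(5, 3)
dirs = dirs / dirs.norm(dim=-1, keepdim=True)  # unit directions (unused at degree 0)

rgb = eval_sh(0, sh, dirs) + 0.5
assert torch.allclose(rgb, SH2RGB(sh[..., 0]))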