cavargas10 commited on
Commit
8471ff9
·
verified ·
1 Parent(s): 6c57173

Upload 11 files

Browse files
trellis/representations/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .radiance_field import Strivec
2
+ from .octree import DfsOctree as Octree
3
+ from .gaussian import Gaussian
4
+ from .mesh import MeshExtractResult
trellis/representations/gaussian/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .gaussian_model import Gaussian
trellis/representations/gaussian/gaussian_model.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from plyfile import PlyData, PlyElement
4
+ from .general_utils import inverse_sigmoid, strip_symmetric, build_scaling_rotation
5
+ import utils3d
6
+
7
+
8
+ class Gaussian:
9
+ def __init__(
10
+ self,
11
+ aabb : list,
12
+ sh_degree : int = 0,
13
+ mininum_kernel_size : float = 0.0,
14
+ scaling_bias : float = 0.01,
15
+ opacity_bias : float = 0.1,
16
+ scaling_activation : str = "exp",
17
+ device='cuda'
18
+ ):
19
+ self.init_params = {
20
+ 'aabb': aabb,
21
+ 'sh_degree': sh_degree,
22
+ 'mininum_kernel_size': mininum_kernel_size,
23
+ 'scaling_bias': scaling_bias,
24
+ 'opacity_bias': opacity_bias,
25
+ 'scaling_activation': scaling_activation,
26
+ }
27
+
28
+ self.sh_degree = sh_degree
29
+ self.active_sh_degree = sh_degree
30
+ self.mininum_kernel_size = mininum_kernel_size
31
+ self.scaling_bias = scaling_bias
32
+ self.opacity_bias = opacity_bias
33
+ self.scaling_activation_type = scaling_activation
34
+ self.device = device
35
+ self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
36
+ self.setup_functions()
37
+
38
+ self._xyz = None
39
+ self._features_dc = None
40
+ self._features_rest = None
41
+ self._scaling = None
42
+ self._rotation = None
43
+ self._opacity = None
44
+
45
+ def setup_functions(self):
46
+ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
47
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
48
+ actual_covariance = L @ L.transpose(1, 2)
49
+ symm = strip_symmetric(actual_covariance)
50
+ return symm
51
+
52
+ if self.scaling_activation_type == "exp":
53
+ self.scaling_activation = torch.exp
54
+ self.inverse_scaling_activation = torch.log
55
+ elif self.scaling_activation_type == "softplus":
56
+ self.scaling_activation = torch.nn.functional.softplus
57
+ self.inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))
58
+
59
+ self.covariance_activation = build_covariance_from_scaling_rotation
60
+
61
+ self.opacity_activation = torch.sigmoid
62
+ self.inverse_opacity_activation = inverse_sigmoid
63
+
64
+ self.rotation_activation = torch.nn.functional.normalize
65
+
66
+ self.scale_bias = self.inverse_scaling_activation(torch.tensor(self.scaling_bias)).cuda()
67
+ self.rots_bias = torch.zeros((4)).cuda()
68
+ self.rots_bias[0] = 1
69
+ self.opacity_bias = self.inverse_opacity_activation(torch.tensor(self.opacity_bias)).cuda()
70
+
71
+ @property
72
+ def get_scaling(self):
73
+ scales = self.scaling_activation(self._scaling + self.scale_bias)
74
+ scales = torch.square(scales) + self.mininum_kernel_size ** 2
75
+ scales = torch.sqrt(scales)
76
+ return scales
77
+
78
+ @property
79
+ def get_rotation(self):
80
+ return self.rotation_activation(self._rotation + self.rots_bias[None, :])
81
+
82
+ @property
83
+ def get_xyz(self):
84
+ return self._xyz * self.aabb[None, 3:] + self.aabb[None, :3]
85
+
86
+ @property
87
+ def get_features(self):
88
+ return torch.cat((self._features_dc, self._features_rest), dim=2) if self._features_rest is not None else self._features_dc
89
+
90
+ @property
91
+ def get_opacity(self):
92
+ return self.opacity_activation(self._opacity + self.opacity_bias)
93
+
94
+ def get_covariance(self, scaling_modifier = 1):
95
+ return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation + self.rots_bias[None, :])
96
+
97
+ def from_scaling(self, scales):
98
+ scales = torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)
99
+ self._scaling = self.inverse_scaling_activation(scales) - self.scale_bias
100
+
101
+ def from_rotation(self, rots):
102
+ self._rotation = rots - self.rots_bias[None, :]
103
+
104
+ def from_xyz(self, xyz):
105
+ self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:]
106
+
107
+ def from_features(self, features):
108
+ self._features_dc = features
109
+
110
+ def from_opacity(self, opacities):
111
+ self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias
112
+
113
+ def construct_list_of_attributes(self):
114
+ l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
115
+ # All channels except the 3 DC
116
+ for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
117
+ l.append('f_dc_{}'.format(i))
118
+ l.append('opacity')
119
+ for i in range(self._scaling.shape[1]):
120
+ l.append('scale_{}'.format(i))
121
+ for i in range(self._rotation.shape[1]):
122
+ l.append('rot_{}'.format(i))
123
+ return l
124
+
125
+ def save_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]):
126
+ xyz = self.get_xyz.detach().cpu().numpy()
127
+ normals = np.zeros_like(xyz)
128
+ f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
129
+ opacities = inverse_sigmoid(self.get_opacity).detach().cpu().numpy()
130
+ scale = torch.log(self.get_scaling).detach().cpu().numpy()
131
+ rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy()
132
+
133
+ if transform is not None:
134
+ transform = np.array(transform)
135
+ xyz = np.matmul(xyz, transform.T)
136
+ rotation = utils3d.numpy.quaternion_to_matrix(rotation)
137
+ rotation = np.matmul(transform, rotation)
138
+ rotation = utils3d.numpy.matrix_to_quaternion(rotation)
139
+
140
+ dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
141
+
142
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
143
+ attributes = np.concatenate((xyz, normals, f_dc, opacities, scale, rotation), axis=1)
144
+ elements[:] = list(map(tuple, attributes))
145
+ el = PlyElement.describe(elements, 'vertex')
146
+ PlyData([el]).write(path)
147
+
148
+ def load_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]):
149
+ plydata = PlyData.read(path)
150
+
151
+ xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
152
+ np.asarray(plydata.elements[0]["y"]),
153
+ np.asarray(plydata.elements[0]["z"])), axis=1)
154
+ opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
155
+
156
+ features_dc = np.zeros((xyz.shape[0], 3, 1))
157
+ features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
158
+ features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
159
+ features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
160
+
161
+ if self.sh_degree > 0:
162
+ extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
163
+ extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
164
+ assert len(extra_f_names)==3*(self.sh_degree + 1) ** 2 - 3
165
+ features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
166
+ for idx, attr_name in enumerate(extra_f_names):
167
+ features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
168
+ # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
169
+ features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
170
+
171
+ scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
172
+ scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
173
+ scales = np.zeros((xyz.shape[0], len(scale_names)))
174
+ for idx, attr_name in enumerate(scale_names):
175
+ scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
176
+
177
+ rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
178
+ rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
179
+ rots = np.zeros((xyz.shape[0], len(rot_names)))
180
+ for idx, attr_name in enumerate(rot_names):
181
+ rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
182
+
183
+ if transform is not None:
184
+ transform = np.array(transform)
185
+ xyz = np.matmul(xyz, transform)
186
+ rotation = utils3d.numpy.quaternion_to_matrix(rotation)
187
+ rotation = np.matmul(rotation, transform)
188
+ rotation = utils3d.numpy.matrix_to_quaternion(rotation)
189
+
190
+ # convert to actual gaussian attributes
191
+ xyz = torch.tensor(xyz, dtype=torch.float, device=self.device)
192
+ features_dc = torch.tensor(features_dc, dtype=torch.float, device=self.device).transpose(1, 2).contiguous()
193
+ if self.sh_degree > 0:
194
+ features_extra = torch.tensor(features_extra, dtype=torch.float, device=self.device).transpose(1, 2).contiguous()
195
+ opacities = torch.sigmoid(torch.tensor(opacities, dtype=torch.float, device=self.device))
196
+ scales = torch.exp(torch.tensor(scales, dtype=torch.float, device=self.device))
197
+ rots = torch.tensor(rots, dtype=torch.float, device=self.device)
198
+
199
+ # convert to _hidden attributes
200
+ self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:]
201
+ self._features_dc = features_dc
202
+ if self.sh_degree > 0:
203
+ self._features_rest = features_extra
204
+ else:
205
+ self._features_rest = None
206
+ self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias
207
+ self._scaling = self.inverse_scaling_activation(torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)) - self.scale_bias
208
+ self._rotation = rots - self.rots_bias[None, :]
209
+
trellis/representations/gaussian/general_utils.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import torch
13
+ import sys
14
+ from datetime import datetime
15
+ import numpy as np
16
+ import random
17
+
18
+ def inverse_sigmoid(x):
19
+ return torch.log(x/(1-x))
20
+
21
+ def PILtoTorch(pil_image, resolution):
22
+ resized_image_PIL = pil_image.resize(resolution)
23
+ resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
24
+ if len(resized_image.shape) == 3:
25
+ return resized_image.permute(2, 0, 1)
26
+ else:
27
+ return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
28
+
29
+ def get_expon_lr_func(
30
+ lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
31
+ ):
32
+ """
33
+ Copied from Plenoxels
34
+
35
+ Continuous learning rate decay function. Adapted from JaxNeRF
36
+ The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
37
+ is log-linearly interpolated elsewhere (equivalent to exponential decay).
38
+ If lr_delay_steps>0 then the learning rate will be scaled by some smooth
39
+ function of lr_delay_mult, such that the initial learning rate is
40
+ lr_init*lr_delay_mult at the beginning of optimization but will be eased back
41
+ to the normal learning rate when steps>lr_delay_steps.
42
+ :param conf: config subtree 'lr' or similar
43
+ :param max_steps: int, the number of steps during optimization.
44
+ :return HoF which takes step as input
45
+ """
46
+
47
+ def helper(step):
48
+ if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
49
+ # Disable this parameter
50
+ return 0.0
51
+ if lr_delay_steps > 0:
52
+ # A kind of reverse cosine decay.
53
+ delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
54
+ 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
55
+ )
56
+ else:
57
+ delay_rate = 1.0
58
+ t = np.clip(step / max_steps, 0, 1)
59
+ log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
60
+ return delay_rate * log_lerp
61
+
62
+ return helper
63
+
64
+ def strip_lowerdiag(L):
65
+ uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
66
+
67
+ uncertainty[:, 0] = L[:, 0, 0]
68
+ uncertainty[:, 1] = L[:, 0, 1]
69
+ uncertainty[:, 2] = L[:, 0, 2]
70
+ uncertainty[:, 3] = L[:, 1, 1]
71
+ uncertainty[:, 4] = L[:, 1, 2]
72
+ uncertainty[:, 5] = L[:, 2, 2]
73
+ return uncertainty
74
+
75
+ def strip_symmetric(sym):
76
+ return strip_lowerdiag(sym)
77
+
78
+ def build_rotation(r):
79
+ norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
80
+
81
+ q = r / norm[:, None]
82
+
83
+ R = torch.zeros((q.size(0), 3, 3), device='cuda')
84
+
85
+ r = q[:, 0]
86
+ x = q[:, 1]
87
+ y = q[:, 2]
88
+ z = q[:, 3]
89
+
90
+ R[:, 0, 0] = 1 - 2 * (y*y + z*z)
91
+ R[:, 0, 1] = 2 * (x*y - r*z)
92
+ R[:, 0, 2] = 2 * (x*z + r*y)
93
+ R[:, 1, 0] = 2 * (x*y + r*z)
94
+ R[:, 1, 1] = 1 - 2 * (x*x + z*z)
95
+ R[:, 1, 2] = 2 * (y*z - r*x)
96
+ R[:, 2, 0] = 2 * (x*z - r*y)
97
+ R[:, 2, 1] = 2 * (y*z + r*x)
98
+ R[:, 2, 2] = 1 - 2 * (x*x + y*y)
99
+ return R
100
+
101
+ def build_scaling_rotation(s, r):
102
+ L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
103
+ R = build_rotation(r)
104
+
105
+ L[:,0,0] = s[:,0]
106
+ L[:,1,1] = s[:,1]
107
+ L[:,2,2] = s[:,2]
108
+
109
+ L = R @ L
110
+ return L
111
+
112
+ def safe_state(silent):
113
+ old_f = sys.stdout
114
+ class F:
115
+ def __init__(self, silent):
116
+ self.silent = silent
117
+
118
+ def write(self, x):
119
+ if not self.silent:
120
+ if x.endswith("\n"):
121
+ old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
122
+ else:
123
+ old_f.write(x)
124
+
125
+ def flush(self):
126
+ old_f.flush()
127
+
128
+ sys.stdout = F(silent)
129
+
130
+ random.seed(0)
131
+ np.random.seed(0)
132
+ torch.manual_seed(0)
133
+ torch.cuda.set_device(torch.device("cuda:0"))
trellis/representations/mesh/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult
trellis/representations/mesh/cube2mesh.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ...modules.sparse import SparseTensor
3
+ from easydict import EasyDict as edict
4
+ from .utils_cube import *
5
+ from .flexicubes.flexicubes import FlexiCubes
6
+
7
+
8
+ class MeshExtractResult:
9
+ def __init__(self,
10
+ vertices,
11
+ faces,
12
+ vertex_attrs=None,
13
+ res=64
14
+ ):
15
+ self.vertices = vertices
16
+ self.faces = faces.long()
17
+ self.vertex_attrs = vertex_attrs
18
+ self.face_normal = self.comput_face_normals(vertices, faces)
19
+ self.res = res
20
+ self.success = (vertices.shape[0] != 0 and faces.shape[0] != 0)
21
+
22
+ # training only
23
+ self.tsdf_v = None
24
+ self.tsdf_s = None
25
+ self.reg_loss = None
26
+
27
+ def comput_face_normals(self, verts, faces):
28
+ i0 = faces[..., 0].long()
29
+ i1 = faces[..., 1].long()
30
+ i2 = faces[..., 2].long()
31
+
32
+ v0 = verts[i0, :]
33
+ v1 = verts[i1, :]
34
+ v2 = verts[i2, :]
35
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
36
+ face_normals = torch.nn.functional.normalize(face_normals, dim=1)
37
+ # print(face_normals.min(), face_normals.max(), face_normals.shape)
38
+ return face_normals[:, None, :].repeat(1, 3, 1)
39
+
40
+ def comput_v_normals(self, verts, faces):
41
+ i0 = faces[..., 0].long()
42
+ i1 = faces[..., 1].long()
43
+ i2 = faces[..., 2].long()
44
+
45
+ v0 = verts[i0, :]
46
+ v1 = verts[i1, :]
47
+ v2 = verts[i2, :]
48
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
49
+ v_normals = torch.zeros_like(verts)
50
+ v_normals.scatter_add_(0, i0[..., None].repeat(1, 3), face_normals)
51
+ v_normals.scatter_add_(0, i1[..., None].repeat(1, 3), face_normals)
52
+ v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals)
53
+
54
+ v_normals = torch.nn.functional.normalize(v_normals, dim=1)
55
+ return v_normals
56
+
57
+
58
+ class SparseFeatures2Mesh:
59
+ def __init__(self, device="cuda", res=64, use_color=True):
60
+ '''
61
+ a model to generate a mesh from sparse features structures using flexicube
62
+ '''
63
+ super().__init__()
64
+ self.device=device
65
+ self.res = res
66
+ self.mesh_extractor = FlexiCubes(device=device)
67
+ self.sdf_bias = -1.0 / res
68
+ verts, cube = construct_dense_grid(self.res, self.device)
69
+ self.reg_c = cube.to(self.device)
70
+ self.reg_v = verts.to(self.device)
71
+ self.use_color = use_color
72
+ self._calc_layout()
73
+
74
+ def _calc_layout(self):
75
+ LAYOUTS = {
76
+ 'sdf': {'shape': (8, 1), 'size': 8},
77
+ 'deform': {'shape': (8, 3), 'size': 8 * 3},
78
+ 'weights': {'shape': (21,), 'size': 21}
79
+ }
80
+ if self.use_color:
81
+ '''
82
+ 6 channel color including normal map
83
+ '''
84
+ LAYOUTS['color'] = {'shape': (8, 6,), 'size': 8 * 6}
85
+ self.layouts = edict(LAYOUTS)
86
+ start = 0
87
+ for k, v in self.layouts.items():
88
+ v['range'] = (start, start + v['size'])
89
+ start += v['size']
90
+ self.feats_channels = start
91
+
92
+ def get_layout(self, feats : torch.Tensor, name : str):
93
+ if name not in self.layouts:
94
+ return None
95
+ return feats[:, self.layouts[name]['range'][0]:self.layouts[name]['range'][1]].reshape(-1, *self.layouts[name]['shape'])
96
+
97
+ def __call__(self, cubefeats : SparseTensor, training=False):
98
+ """
99
+ Generates a mesh based on the specified sparse voxel structures.
100
+ Args:
101
+ cube_attrs [Nx21] : Sparse Tensor attrs about cube weights
102
+ verts_attrs [Nx10] : [0:1] SDF [1:4] deform [4:7] color [7:10] normal
103
+ Returns:
104
+ return the success tag and ni you loss,
105
+ """
106
+ # add sdf bias to verts_attrs
107
+ coords = cubefeats.coords[:, 1:]
108
+ feats = cubefeats.feats
109
+
110
+ sdf, deform, color, weights = [self.get_layout(feats, name) for name in ['sdf', 'deform', 'color', 'weights']]
111
+ sdf += self.sdf_bias
112
+ v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform]
113
+ v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1), training=training)
114
+ v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res+1, sdf_init=True)
115
+ weights_d = get_dense_attrs(coords, weights, res=self.res, sdf_init=False)
116
+ if self.use_color:
117
+ sdf_d, deform_d, colors_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4], v_attrs_d[..., 4:]
118
+ else:
119
+ sdf_d, deform_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4]
120
+ colors_d = None
121
+
122
+ x_nx3 = get_defomed_verts(self.reg_v, deform_d, self.res)
123
+
124
+ vertices, faces, L_dev, colors = self.mesh_extractor(
125
+ voxelgrid_vertices=x_nx3,
126
+ scalar_field=sdf_d,
127
+ cube_idx=self.reg_c,
128
+ resolution=self.res,
129
+ beta=weights_d[:, :12],
130
+ alpha=weights_d[:, 12:20],
131
+ gamma_f=weights_d[:, 20],
132
+ voxelgrid_colors=colors_d,
133
+ training=training)
134
+
135
+ mesh = MeshExtractResult(vertices=vertices, faces=faces, vertex_attrs=colors, res=self.res)
136
+ if training:
137
+ if mesh.success:
138
+ reg_loss += L_dev.mean() * 0.5
139
+ reg_loss += (weights[:,:20]).abs().mean() * 0.2
140
+ mesh.reg_loss = reg_loss
141
+ mesh.tsdf_v = get_defomed_verts(v_pos, v_attrs[:, 1:4], self.res)
142
+ mesh.tsdf_s = v_attrs[:, 0]
143
+ return mesh
trellis/representations/mesh/utils_cube.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
3
+ 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.int)
4
+ cube_neighbor = torch.tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]])
5
+ cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
6
+ 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, requires_grad=False)
7
+
8
+ def construct_dense_grid(res, device='cuda'):
9
+ '''construct a dense grid based on resolution'''
10
+ res_v = res + 1
11
+ vertsid = torch.arange(res_v ** 3, device=device)
12
+ coordsid = vertsid.reshape(res_v, res_v, res_v)[:res, :res, :res].flatten()
13
+ cube_corners_bias = (cube_corners[:, 0] * res_v + cube_corners[:, 1]) * res_v + cube_corners[:, 2]
14
+ cube_fx8 = (coordsid.unsqueeze(1) + cube_corners_bias.unsqueeze(0).to(device))
15
+ verts = torch.stack([vertsid // (res_v ** 2), (vertsid // res_v) % res_v, vertsid % res_v], dim=1)
16
+ return verts, cube_fx8
17
+
18
+
19
+ def construct_voxel_grid(coords):
20
+ verts = (cube_corners.unsqueeze(0).to(coords) + coords.unsqueeze(1)).reshape(-1, 3)
21
+ verts_unique, inverse_indices = torch.unique(verts, dim=0, return_inverse=True)
22
+ cubes = inverse_indices.reshape(-1, 8)
23
+ return verts_unique, cubes
24
+
25
+
26
+ def cubes_to_verts(num_verts, cubes, value, reduce='mean'):
27
+ """
28
+ Args:
29
+ cubes [Vx8] verts index for each cube
30
+ value [Vx8xM] value to be scattered
31
+ Operation:
32
+ reduced[cubes[i][j]][k] += value[i][k]
33
+ """
34
+ M = value.shape[2] # number of channels
35
+ reduced = torch.zeros(num_verts, M, device=cubes.device)
36
+ return torch.scatter_reduce(reduced, 0,
37
+ cubes.unsqueeze(-1).expand(-1, -1, M).flatten(0, 1),
38
+ value.flatten(0, 1), reduce=reduce, include_self=False)
39
+
40
+ def sparse_cube2verts(coords, feats, training=True):
41
+ new_coords, cubes = construct_voxel_grid(coords)
42
+ new_feats = cubes_to_verts(new_coords.shape[0], cubes, feats)
43
+ if training:
44
+ con_loss = torch.mean((feats - new_feats[cubes]) ** 2)
45
+ else:
46
+ con_loss = 0.0
47
+ return new_coords, new_feats, con_loss
48
+
49
+
50
+ def get_dense_attrs(coords : torch.Tensor, feats : torch.Tensor, res : int, sdf_init=True):
51
+ F = feats.shape[-1]
52
+ dense_attrs = torch.zeros([res] * 3 + [F], device=feats.device)
53
+ if sdf_init:
54
+ dense_attrs[..., 0] = 1 # initial outside sdf value
55
+ dense_attrs[coords[:, 0], coords[:, 1], coords[:, 2], :] = feats
56
+ return dense_attrs.reshape(-1, F)
57
+
58
+
59
+ def get_defomed_verts(v_pos : torch.Tensor, deform : torch.Tensor, res):
60
+ return v_pos / res - 0.5 + (1 - 1e-8) / (res * 2) * torch.tanh(deform)
61
+
trellis/representations/octree/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .octree_dfs import DfsOctree
trellis/representations/octree/octree_dfs.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class DfsOctree:
7
+ """
8
+ Sparse Voxel Octree (SVO) implementation for PyTorch.
9
+ Using Depth-First Search (DFS) order to store the octree.
10
+ DFS order suits rendering and ray tracing.
11
+
12
+ The structure and data are separatedly stored.
13
+ Structure is stored as a continuous array, each element is a 3*32 bits descriptor.
14
+ |-----------------------------------------|
15
+ | 0:3 bits | 4:31 bits |
16
+ | leaf num | unused |
17
+ |-----------------------------------------|
18
+ | 0:31 bits |
19
+ | child ptr |
20
+ |-----------------------------------------|
21
+ | 0:31 bits |
22
+ | data ptr |
23
+ |-----------------------------------------|
24
+ Each element represents a non-leaf node in the octree.
25
+ The valid mask is used to indicate whether the children are valid.
26
+ The leaf mask is used to indicate whether the children are leaf nodes.
27
+ The child ptr is used to point to the first non-leaf child. Non-leaf children descriptors are stored continuously from the child ptr.
28
+ The data ptr is used to point to the data of leaf children. Leaf children data are stored continuously from the data ptr.
29
+
30
+ There are also auxiliary arrays to store the additional structural information to facilitate parallel processing.
31
+ - Position: the position of the octree nodes.
32
+ - Depth: the depth of the octree nodes.
33
+
34
+ Args:
35
+ depth (int): the depth of the octree.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ depth,
41
+ aabb=[0,0,0,1,1,1],
42
+ sh_degree=2,
43
+ primitive='voxel',
44
+ primitive_config={},
45
+ device='cuda',
46
+ ):
47
+ self.max_depth = depth
48
+ self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
49
+ self.device = device
50
+ self.sh_degree = sh_degree
51
+ self.active_sh_degree = sh_degree
52
+ self.primitive = primitive
53
+ self.primitive_config = primitive_config
54
+
55
+ self.structure = torch.tensor([[8, 1, 0]], dtype=torch.int32, device=self.device)
56
+ self.position = torch.zeros((8, 3), dtype=torch.float32, device=self.device)
57
+ self.depth = torch.zeros((8, 1), dtype=torch.uint8, device=self.device)
58
+ self.position[:, 0] = torch.tensor([0.25, 0.75, 0.25, 0.75, 0.25, 0.75, 0.25, 0.75], device=self.device)
59
+ self.position[:, 1] = torch.tensor([0.25, 0.25, 0.75, 0.75, 0.25, 0.25, 0.75, 0.75], device=self.device)
60
+ self.position[:, 2] = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.75, 0.75, 0.75, 0.75], device=self.device)
61
+ self.depth[:, 0] = 1
62
+
63
+ self.data = ['position', 'depth']
64
+ self.param_names = []
65
+
66
+ if primitive == 'voxel':
67
+ self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device)
68
+ self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device)
69
+ self.data += ['features_dc', 'features_ac']
70
+ self.param_names += ['features_dc', 'features_ac']
71
+ if not primitive_config.get('solid', False):
72
+ self.density = torch.zeros((8, 1), dtype=torch.float32, device=self.device)
73
+ self.data.append('density')
74
+ self.param_names.append('density')
75
+ elif primitive == 'gaussian':
76
+ self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device)
77
+ self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device)
78
+ self.opacity = torch.zeros((8, 1), dtype=torch.float32, device=self.device)
79
+ self.data += ['features_dc', 'features_ac', 'opacity']
80
+ self.param_names += ['features_dc', 'features_ac', 'opacity']
81
+ elif primitive == 'trivec':
82
+ self.trivec = torch.zeros((8, primitive_config['rank'], 3, primitive_config['dim']), dtype=torch.float32, device=self.device)
83
+ self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device)
84
+ self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device)
85
+ self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device)
86
+ self.density_shift = 0
87
+ self.data += ['trivec', 'density', 'features_dc', 'features_ac']
88
+ self.param_names += ['trivec', 'density', 'features_dc', 'features_ac']
89
+ elif primitive == 'decoupoly':
90
+ self.decoupoly_V = torch.zeros((8, primitive_config['rank'], 3), dtype=torch.float32, device=self.device)
91
+ self.decoupoly_g = torch.zeros((8, primitive_config['rank'], primitive_config['degree']), dtype=torch.float32, device=self.device)
92
+ self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device)
93
+ self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device)
94
+ self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device)
95
+ self.density_shift = 0
96
+ self.data += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac']
97
+ self.param_names += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac']
98
+
99
+ self.setup_functions()
100
+
101
+ def setup_functions(self):
102
+ self.density_activation = (lambda x: torch.exp(x - 2)) if self.primitive != 'trivec' else (lambda x: x)
103
+ self.opacity_activation = lambda x: torch.sigmoid(x - 6)
104
+ self.inverse_opacity_activation = lambda x: torch.log(x / (1 - x)) + 6
105
+ self.color_activation = lambda x: torch.sigmoid(x)
106
+
107
+ @property
108
+ def num_non_leaf_nodes(self):
109
+ return self.structure.shape[0]
110
+
111
+ @property
112
+ def num_leaf_nodes(self):
113
+ return self.depth.shape[0]
114
+
115
+ @property
116
+ def cur_depth(self):
117
+ return self.depth.max().item()
118
+
119
+ @property
120
+ def occupancy(self):
121
+ return self.num_leaf_nodes / 8 ** self.cur_depth
122
+
123
+ @property
124
+ def get_xyz(self):
125
+ return self.position
126
+
127
+ @property
128
+ def get_depth(self):
129
+ return self.depth
130
+
131
+ @property
132
+ def get_density(self):
133
+ if self.primitive == 'voxel' and self.primitive_config.get('solid', False):
134
+ return torch.full((self.position.shape[0], 1), torch.finfo(torch.float32).max, dtype=torch.float32, device=self.device)
135
+ return self.density_activation(self.density)
136
+
137
+ @property
138
+ def get_opacity(self):
139
+ return self.opacity_activation(self.density)
140
+
141
+ @property
142
+ def get_trivec(self):
143
+ return self.trivec
144
+
145
+ @property
146
+ def get_decoupoly(self):
147
+ return F.normalize(self.decoupoly_V, dim=-1), self.decoupoly_g
148
+
149
+ @property
150
+ def get_color(self):
151
+ return self.color_activation(self.colors)
152
+
153
+ @property
154
+ def get_features(self):
155
+ if self.sh_degree == 0:
156
+ return self.features_dc
157
+ return torch.cat([self.features_dc, self.features_ac], dim=-2)
158
+
159
+ def state_dict(self):
160
+ ret = {'structure': self.structure, 'position': self.position, 'depth': self.depth, 'sh_degree': self.sh_degree, 'active_sh_degree': self.active_sh_degree, 'primitive_config': self.primitive_config, 'primitive': self.primitive}
161
+ if hasattr(self, 'density_shift'):
162
+ ret['density_shift'] = self.density_shift
163
+ for data in set(self.data + self.param_names):
164
+ if not isinstance(getattr(self, data), nn.Module):
165
+ ret[data] = getattr(self, data)
166
+ else:
167
+ ret[data] = getattr(self, data).state_dict()
168
+ return ret
169
+
170
+ def load_state_dict(self, state_dict):
171
+ keys = list(set(self.data + self.param_names + list(state_dict.keys()) + ['structure', 'position', 'depth']))
172
+ for key in keys:
173
+ if key not in state_dict:
174
+ print(f"Warning: key {key} not found in the state_dict.")
175
+ continue
176
+ try:
177
+ if not isinstance(getattr(self, key), nn.Module):
178
+ setattr(self, key, state_dict[key])
179
+ else:
180
+ getattr(self, key).load_state_dict(state_dict[key])
181
+ except Exception as e:
182
+ print(e)
183
+ raise ValueError(f"Error loading key {key}.")
184
+
185
+ def gather_from_leaf_children(self, data):
186
+ """
187
+ Gather the data from the leaf children.
188
+
189
+ Args:
190
+ data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes.
191
+ """
192
+ leaf_cnt = self.structure[:, 0]
193
+ leaf_cnt_masks = [leaf_cnt == i for i in range(1, 9)]
194
+ ret = torch.zeros((self.num_non_leaf_nodes,), dtype=data.dtype, device=self.device)
195
+ for i in range(8):
196
+ if leaf_cnt_masks[i].sum() == 0:
197
+ continue
198
+ start = self.structure[leaf_cnt_masks[i], 2]
199
+ for j in range(i+1):
200
+ ret[leaf_cnt_masks[i]] += data[start + j]
201
+ return ret
202
+
203
+ def gather_from_non_leaf_children(self, data):
204
+ """
205
+ Gather the data from the non-leaf children.
206
+
207
+ Args:
208
+ data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes.
209
+ """
210
+ non_leaf_cnt = 8 - self.structure[:, 0]
211
+ non_leaf_cnt_masks = [non_leaf_cnt == i for i in range(1, 9)]
212
+ ret = torch.zeros_like(data, device=self.device)
213
+ for i in range(8):
214
+ if non_leaf_cnt_masks[i].sum() == 0:
215
+ continue
216
+ start = self.structure[non_leaf_cnt_masks[i], 1]
217
+ for j in range(i+1):
218
+ ret[non_leaf_cnt_masks[i]] += data[start + j]
219
+ return ret
220
+
221
+ def structure_control(self, mask):
222
+ """
223
+ Control the structure of the octree.
224
+
225
+ Args:
226
+ mask (torch.Tensor): the mask to control the structure. 1 for subdivide, -1 for merge, 0 for keep.
227
+ """
228
+ # Dont subdivide when the depth is the maximum.
229
+ mask[self.depth.squeeze() == self.max_depth] = torch.clamp_max(mask[self.depth.squeeze() == self.max_depth], 0)
230
+ # Dont merge when the depth is the minimum.
231
+ mask[self.depth.squeeze() == 1] = torch.clamp_min(mask[self.depth.squeeze() == 1], 0)
232
+
233
+ # Gather control mask
234
+ structre_ctrl = self.gather_from_leaf_children(mask)
235
+ structre_ctrl[structre_ctrl==-8] = -1
236
+
237
+ new_leaf_num = self.structure[:, 0].clone()
238
+ # Modify the leaf num.
239
+ structre_valid = structre_ctrl >= 0
240
+ new_leaf_num[structre_valid] -= structre_ctrl[structre_valid] # Add the new nodes.
241
+ structre_delete = structre_ctrl < 0
242
+ merged_nodes = self.gather_from_non_leaf_children(structre_delete.int())
243
+ new_leaf_num += merged_nodes # Delete the merged nodes.
244
+
245
+ # Update the structure array to allocate new nodes.
246
+ mem_offset = torch.zeros((self.num_non_leaf_nodes + 1,), dtype=torch.int32, device=self.device)
247
+ mem_offset.index_add_(0, self.structure[structre_valid, 1], structre_ctrl[structre_valid]) # Add the new nodes.
248
+ mem_offset[:-1] -= structre_delete.int() # Delete the merged nodes.
249
+ new_structre_idx = torch.arange(0, self.num_non_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0)
250
+ new_structure_length = new_structre_idx[-1].item()
251
+ new_structre_idx = new_structre_idx[:-1]
252
+ new_structure = torch.empty((new_structure_length, 3), dtype=torch.int32, device=self.device)
253
+ new_structure[new_structre_idx[structre_valid], 0] = new_leaf_num[structre_valid]
254
+
255
+ # Initialize the new nodes.
256
+ new_node_mask = torch.ones((new_structure_length,), dtype=torch.bool, device=self.device)
257
+ new_node_mask[new_structre_idx[structre_valid]] = False
258
+ new_structure[new_node_mask, 0] = 8 # Initialize to all leaf nodes.
259
+ new_node_num = new_node_mask.sum().item()
260
+
261
+ # Rebuild child ptr.
262
+ non_leaf_cnt = 8 - new_structure[:, 0]
263
+ new_child_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), non_leaf_cnt.cumsum(0)[:-1]])
264
+ new_structure[:, 1] = new_child_ptr + 1
265
+
266
+ # Rebuild data ptr with old data.
267
+ leaf_cnt = torch.zeros((new_structure_length,), dtype=torch.int32, device=self.device)
268
+ leaf_cnt.index_add_(0, new_structre_idx, self.structure[:, 0])
269
+ old_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]])
270
+
271
+ # Update the data array
272
+ subdivide_mask = mask == 1
273
+ merge_mask = mask == -1
274
+ data_valid = ~(subdivide_mask | merge_mask)
275
+ mem_offset = torch.zeros((self.num_leaf_nodes + 1,), dtype=torch.int32, device=self.device)
276
+ mem_offset.index_add_(0, old_data_ptr[new_node_mask], torch.full((new_node_num,), 8, dtype=torch.int32, device=self.device)) # Add data array for new nodes
277
+ mem_offset[:-1] -= subdivide_mask.int() # Delete data elements for subdivide nodes
278
+ mem_offset[:-1] -= merge_mask.int() # Delete data elements for merge nodes
279
+ mem_offset.index_add_(0, self.structure[structre_valid, 2], merged_nodes[structre_valid]) # Add data elements for merge nodes
280
+ new_data_idx = torch.arange(0, self.num_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0)
281
+ new_data_length = new_data_idx[-1].item()
282
+ new_data_idx = new_data_idx[:-1]
283
+ new_data = {data: torch.empty((new_data_length,) + getattr(self, data).shape[1:], dtype=getattr(self, data).dtype, device=self.device) for data in self.data}
284
+ for data in self.data:
285
+ new_data[data][new_data_idx[data_valid]] = getattr(self, data)[data_valid]
286
+
287
+ # Rebuild data ptr
288
+ leaf_cnt = new_structure[:, 0]
289
+ new_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]])
290
+ new_structure[:, 2] = new_data_ptr
291
+
292
+ # Initialize the new data array
293
+ ## For subdivide nodes
294
+ if subdivide_mask.sum() > 0:
295
+ subdivide_data_ptr = new_structure[new_node_mask, 2]
296
+ for data in self.data:
297
+ for i in range(8):
298
+ if data == 'position':
299
+ offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) - 0.5
300
+ scale = 2 ** (-1.0 - self.depth[subdivide_mask])
301
+ new_data['position'][subdivide_data_ptr + i] = self.position[subdivide_mask] + offset * scale
302
+ elif data == 'depth':
303
+ new_data['depth'][subdivide_data_ptr + i] = self.depth[subdivide_mask] + 1
304
+ elif data == 'opacity':
305
+ new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(torch.sqrt(self.opacity_activation(self.opacity[subdivide_mask])))
306
+ elif data == 'trivec':
307
+ offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) * 0.5
308
+ coord = (torch.linspace(0, 0.5, self.trivec.shape[-1], dtype=torch.float32, device=self.device)[None] + offset[:, None]).reshape(1, 3, self.trivec.shape[-1], 1)
309
+ axis = torch.linspace(0, 1, 3, dtype=torch.float32, device=self.device).reshape(1, 3, 1, 1).repeat(1, 1, self.trivec.shape[-1], 1)
310
+ coord = torch.stack([coord, axis], dim=3).reshape(1, 3, self.trivec.shape[-1], 2).expand(self.trivec[subdivide_mask].shape[0], -1, -1, -1) * 2 - 1
311
+ new_data['trivec'][subdivide_data_ptr + i] = F.grid_sample(self.trivec[subdivide_mask], coord, align_corners=True)
312
+ else:
313
+ new_data[data][subdivide_data_ptr + i] = getattr(self, data)[subdivide_mask]
314
+ ## For merge nodes
315
+ if merge_mask.sum() > 0:
316
+ merge_data_ptr = torch.empty((merged_nodes.sum().item(),), dtype=torch.int32, device=self.device)
317
+ merge_nodes_cumsum = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), merged_nodes.cumsum(0)[:-1]])
318
+ for i in range(8):
319
+ merge_data_ptr[merge_nodes_cumsum[merged_nodes > i] + i] = new_structure[new_structre_idx[merged_nodes > i], 2] + i
320
+ old_merge_data_ptr = self.structure[structre_delete, 2]
321
+ for data in self.data:
322
+ if data == 'position':
323
+ scale = 2 ** (1.0 - self.depth[old_merge_data_ptr])
324
+ new_data['position'][merge_data_ptr] = ((self.position[old_merge_data_ptr] + 0.5) / scale).floor() * scale + 0.5 * scale - 0.5
325
+ elif data == 'depth':
326
+ new_data['depth'][merge_data_ptr] = self.depth[old_merge_data_ptr] - 1
327
+ elif data == 'opacity':
328
+ new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(self.opacity_activation(self.opacity[subdivide_mask])**2)
329
+ elif data == 'trivec':
330
+ new_data['trivec'][merge_data_ptr] = self.trivec[old_merge_data_ptr]
331
+ else:
332
+ new_data[data][merge_data_ptr] = getattr(self, data)[old_merge_data_ptr]
333
+
334
+ # Update the structure and data array
335
+ self.structure = new_structure
336
+ for data in self.data:
337
+ setattr(self, data, new_data[data])
338
+
339
+ # Save data array control temp variables
340
+ self.data_rearrange_buffer = {
341
+ 'subdivide_mask': subdivide_mask,
342
+ 'merge_mask': merge_mask,
343
+ 'data_valid': data_valid,
344
+ 'new_data_idx': new_data_idx,
345
+ 'new_data_length': new_data_length,
346
+ 'new_data': new_data
347
+ }
trellis/representations/radiance_field/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .strivec import Strivec
trellis/representations/radiance_field/strivec.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from ..octree import DfsOctree as Octree
6
+
7
+
8
+ class Strivec(Octree):
9
+ def __init__(
10
+ self,
11
+ resolution: int,
12
+ aabb: list,
13
+ sh_degree: int = 0,
14
+ rank: int = 8,
15
+ dim: int = 8,
16
+ device: str = "cuda",
17
+ ):
18
+ assert np.log2(resolution) % 1 == 0, "Resolution must be a power of 2"
19
+ self.resolution = resolution
20
+ depth = int(np.round(np.log2(resolution)))
21
+ super().__init__(
22
+ depth=depth,
23
+ aabb=aabb,
24
+ sh_degree=sh_degree,
25
+ primitive="trivec",
26
+ primitive_config={"rank": rank, "dim": dim},
27
+ device=device,
28
+ )