cavargas10 commited on
Commit
ffb0acd
·
verified ·
1 Parent(s): 59cf4a6

Update trellis/representations/mesh/cube2mesh.py

Browse files
trellis/representations/mesh/cube2mesh.py CHANGED
@@ -1,143 +1,143 @@
1
- import torch
2
- from ...modules.sparse import SparseTensor
3
- from easydict import EasyDict as edict
4
- from .utils_cube import *
5
- from .flexicubes.flexicubes import FlexiCubes
6
-
7
-
8
- class MeshExtractResult:
9
- def __init__(self,
10
- vertices,
11
- faces,
12
- vertex_attrs=None,
13
- res=64
14
- ):
15
- self.vertices = vertices
16
- self.faces = faces.long()
17
- self.vertex_attrs = vertex_attrs
18
- self.face_normal = self.comput_face_normals(vertices, faces)
19
- self.res = res
20
- self.success = (vertices.shape[0] != 0 and faces.shape[0] != 0)
21
-
22
- # training only
23
- self.tsdf_v = None
24
- self.tsdf_s = None
25
- self.reg_loss = None
26
-
27
- def comput_face_normals(self, verts, faces):
28
- i0 = faces[..., 0].long()
29
- i1 = faces[..., 1].long()
30
- i2 = faces[..., 2].long()
31
-
32
- v0 = verts[i0, :]
33
- v1 = verts[i1, :]
34
- v2 = verts[i2, :]
35
- face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
36
- face_normals = torch.nn.functional.normalize(face_normals, dim=1)
37
- # print(face_normals.min(), face_normals.max(), face_normals.shape)
38
- return face_normals[:, None, :].repeat(1, 3, 1)
39
-
40
- def comput_v_normals(self, verts, faces):
41
- i0 = faces[..., 0].long()
42
- i1 = faces[..., 1].long()
43
- i2 = faces[..., 2].long()
44
-
45
- v0 = verts[i0, :]
46
- v1 = verts[i1, :]
47
- v2 = verts[i2, :]
48
- face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
49
- v_normals = torch.zeros_like(verts)
50
- v_normals.scatter_add_(0, i0[..., None].repeat(1, 3), face_normals)
51
- v_normals.scatter_add_(0, i1[..., None].repeat(1, 3), face_normals)
52
- v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals)
53
-
54
- v_normals = torch.nn.functional.normalize(v_normals, dim=1)
55
- return v_normals
56
-
57
-
58
- class SparseFeatures2Mesh:
59
- def __init__(self, device="cuda", res=64, use_color=True):
60
- '''
61
- a model to generate a mesh from sparse features structures using flexicube
62
- '''
63
- super().__init__()
64
- self.device=device
65
- self.res = res
66
- self.mesh_extractor = FlexiCubes(device=device)
67
- self.sdf_bias = -1.0 / res
68
- verts, cube = construct_dense_grid(self.res, self.device)
69
- self.reg_c = cube.to(self.device)
70
- self.reg_v = verts.to(self.device)
71
- self.use_color = use_color
72
- self._calc_layout()
73
-
74
- def _calc_layout(self):
75
- LAYOUTS = {
76
- 'sdf': {'shape': (8, 1), 'size': 8},
77
- 'deform': {'shape': (8, 3), 'size': 8 * 3},
78
- 'weights': {'shape': (21,), 'size': 21}
79
- }
80
- if self.use_color:
81
- '''
82
- 6 channel color including normal map
83
- '''
84
- LAYOUTS['color'] = {'shape': (8, 6,), 'size': 8 * 6}
85
- self.layouts = edict(LAYOUTS)
86
- start = 0
87
- for k, v in self.layouts.items():
88
- v['range'] = (start, start + v['size'])
89
- start += v['size']
90
- self.feats_channels = start
91
-
92
- def get_layout(self, feats : torch.Tensor, name : str):
93
- if name not in self.layouts:
94
- return None
95
- return feats[:, self.layouts[name]['range'][0]:self.layouts[name]['range'][1]].reshape(-1, *self.layouts[name]['shape'])
96
-
97
- def __call__(self, cubefeats : SparseTensor, training=False):
98
- """
99
- Generates a mesh based on the specified sparse voxel structures.
100
- Args:
101
- cube_attrs [Nx21] : Sparse Tensor attrs about cube weights
102
- verts_attrs [Nx10] : [0:1] SDF [1:4] deform [4:7] color [7:10] normal
103
- Returns:
104
- return the success tag and ni you loss,
105
- """
106
- # add sdf bias to verts_attrs
107
- coords = cubefeats.coords[:, 1:]
108
- feats = cubefeats.feats
109
-
110
- sdf, deform, color, weights = [self.get_layout(feats, name) for name in ['sdf', 'deform', 'color', 'weights']]
111
- sdf += self.sdf_bias
112
- v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform]
113
- v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1), training=training)
114
- v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res+1, sdf_init=True)
115
- weights_d = get_dense_attrs(coords, weights, res=self.res, sdf_init=False)
116
- if self.use_color:
117
- sdf_d, deform_d, colors_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4], v_attrs_d[..., 4:]
118
- else:
119
- sdf_d, deform_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4]
120
- colors_d = None
121
-
122
- x_nx3 = get_defomed_verts(self.reg_v, deform_d, self.res)
123
-
124
- vertices, faces, L_dev, colors = self.mesh_extractor(
125
- voxelgrid_vertices=x_nx3,
126
- scalar_field=sdf_d,
127
- cube_idx=self.reg_c,
128
- resolution=self.res,
129
- beta=weights_d[:, :12],
130
- alpha=weights_d[:, 12:20],
131
- gamma_f=weights_d[:, 20],
132
- voxelgrid_colors=colors_d,
133
- training=training)
134
-
135
- mesh = MeshExtractResult(vertices=vertices, faces=faces, vertex_attrs=colors, res=self.res)
136
- if training:
137
- if mesh.success:
138
- reg_loss += L_dev.mean() * 0.5
139
- reg_loss += (weights[:,:20]).abs().mean() * 0.2
140
- mesh.reg_loss = reg_loss
141
- mesh.tsdf_v = get_defomed_verts(v_pos, v_attrs[:, 1:4], self.res)
142
- mesh.tsdf_s = v_attrs[:, 0]
143
- return mesh
 
1
+ import torch
2
+ from ...modules.sparse import SparseTensor
3
+ from easydict import EasyDict as edict
4
+ from .utils_cube import *
5
+ from .flexicubes import FlexiCubes
6
+
7
+
8
+ class MeshExtractResult:
9
+ def __init__(self,
10
+ vertices,
11
+ faces,
12
+ vertex_attrs=None,
13
+ res=64
14
+ ):
15
+ self.vertices = vertices
16
+ self.faces = faces.long()
17
+ self.vertex_attrs = vertex_attrs
18
+ self.face_normal = self.comput_face_normals(vertices, faces)
19
+ self.res = res
20
+ self.success = (vertices.shape[0] != 0 and faces.shape[0] != 0)
21
+
22
+ # training only
23
+ self.tsdf_v = None
24
+ self.tsdf_s = None
25
+ self.reg_loss = None
26
+
27
+ def comput_face_normals(self, verts, faces):
28
+ i0 = faces[..., 0].long()
29
+ i1 = faces[..., 1].long()
30
+ i2 = faces[..., 2].long()
31
+
32
+ v0 = verts[i0, :]
33
+ v1 = verts[i1, :]
34
+ v2 = verts[i2, :]
35
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
36
+ face_normals = torch.nn.functional.normalize(face_normals, dim=1)
37
+ # print(face_normals.min(), face_normals.max(), face_normals.shape)
38
+ return face_normals[:, None, :].repeat(1, 3, 1)
39
+
40
+ def comput_v_normals(self, verts, faces):
41
+ i0 = faces[..., 0].long()
42
+ i1 = faces[..., 1].long()
43
+ i2 = faces[..., 2].long()
44
+
45
+ v0 = verts[i0, :]
46
+ v1 = verts[i1, :]
47
+ v2 = verts[i2, :]
48
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
49
+ v_normals = torch.zeros_like(verts)
50
+ v_normals.scatter_add_(0, i0[..., None].repeat(1, 3), face_normals)
51
+ v_normals.scatter_add_(0, i1[..., None].repeat(1, 3), face_normals)
52
+ v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals)
53
+
54
+ v_normals = torch.nn.functional.normalize(v_normals, dim=1)
55
+ return v_normals
56
+
57
+
58
+ class SparseFeatures2Mesh:
59
+ def __init__(self, device="cuda", res=64, use_color=True):
60
+ '''
61
+ a model to generate a mesh from sparse features structures using flexicube
62
+ '''
63
+ super().__init__()
64
+ self.device=device
65
+ self.res = res
66
+ self.mesh_extractor = FlexiCubes(device=device)
67
+ self.sdf_bias = -1.0 / res
68
+ verts, cube = construct_dense_grid(self.res, self.device)
69
+ self.reg_c = cube.to(self.device)
70
+ self.reg_v = verts.to(self.device)
71
+ self.use_color = use_color
72
+ self._calc_layout()
73
+
74
+ def _calc_layout(self):
75
+ LAYOUTS = {
76
+ 'sdf': {'shape': (8, 1), 'size': 8},
77
+ 'deform': {'shape': (8, 3), 'size': 8 * 3},
78
+ 'weights': {'shape': (21,), 'size': 21}
79
+ }
80
+ if self.use_color:
81
+ '''
82
+ 6 channel color including normal map
83
+ '''
84
+ LAYOUTS['color'] = {'shape': (8, 6,), 'size': 8 * 6}
85
+ self.layouts = edict(LAYOUTS)
86
+ start = 0
87
+ for k, v in self.layouts.items():
88
+ v['range'] = (start, start + v['size'])
89
+ start += v['size']
90
+ self.feats_channels = start
91
+
92
+ def get_layout(self, feats : torch.Tensor, name : str):
93
+ if name not in self.layouts:
94
+ return None
95
+ return feats[:, self.layouts[name]['range'][0]:self.layouts[name]['range'][1]].reshape(-1, *self.layouts[name]['shape'])
96
+
97
+ def __call__(self, cubefeats : SparseTensor, training=False):
98
+ """
99
+ Generates a mesh based on the specified sparse voxel structures.
100
+ Args:
101
+ cube_attrs [Nx21] : Sparse Tensor attrs about cube weights
102
+ verts_attrs [Nx10] : [0:1] SDF [1:4] deform [4:7] color [7:10] normal
103
+ Returns:
104
+ return the success tag and ni you loss,
105
+ """
106
+ # add sdf bias to verts_attrs
107
+ coords = cubefeats.coords[:, 1:]
108
+ feats = cubefeats.feats
109
+
110
+ sdf, deform, color, weights = [self.get_layout(feats, name) for name in ['sdf', 'deform', 'color', 'weights']]
111
+ sdf += self.sdf_bias
112
+ v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform]
113
+ v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1), training=training)
114
+ v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res+1, sdf_init=True)
115
+ weights_d = get_dense_attrs(coords, weights, res=self.res, sdf_init=False)
116
+ if self.use_color:
117
+ sdf_d, deform_d, colors_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4], v_attrs_d[..., 4:]
118
+ else:
119
+ sdf_d, deform_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4]
120
+ colors_d = None
121
+
122
+ x_nx3 = get_defomed_verts(self.reg_v, deform_d, self.res)
123
+
124
+ vertices, faces, L_dev, colors = self.mesh_extractor(
125
+ voxelgrid_vertices=x_nx3,
126
+ scalar_field=sdf_d,
127
+ cube_idx=self.reg_c,
128
+ resolution=self.res,
129
+ beta=weights_d[:, :12],
130
+ alpha=weights_d[:, 12:20],
131
+ gamma_f=weights_d[:, 20],
132
+ voxelgrid_colors=colors_d,
133
+ training=training)
134
+
135
+ mesh = MeshExtractResult(vertices=vertices, faces=faces, vertex_attrs=colors, res=self.res)
136
+ if training:
137
+ if mesh.success:
138
+ reg_loss += L_dev.mean() * 0.5
139
+ reg_loss += (weights[:,:20]).abs().mean() * 0.2
140
+ mesh.reg_loss = reg_loss
141
+ mesh.tsdf_v = get_defomed_verts(v_pos, v_attrs[:, 1:4], self.res)
142
+ mesh.tsdf_s = v_attrs[:, 0]
143
+ return mesh