cavargas10 commited on
Commit
68bd713
·
verified ·
1 Parent(s): 169b783

Update trellis/pipelines/trellis_text_to_3d.py

Browse files
Files changed (1) hide show
  1. trellis/pipelines/trellis_text_to_3d.py +228 -278
trellis/pipelines/trellis_text_to_3d.py CHANGED
@@ -1,278 +1,228 @@
1
- from typing import *
2
- import torch
3
- import torch.nn as nn
4
- import numpy as np
5
- from transformers import CLIPTextModel, AutoTokenizer
6
- import open3d as o3d
7
- from .base import Pipeline
8
- from . import samplers
9
- from ..modules import sparse as sp
10
-
11
-
12
- class TrellisTextTo3DPipeline(Pipeline):
13
- """
14
- Pipeline for inferring Trellis text-to-3D models.
15
-
16
- Args:
17
- models (dict[str, nn.Module]): The models to use in the pipeline.
18
- sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
19
- slat_sampler (samplers.Sampler): The sampler for the structured latent.
20
- slat_normalization (dict): The normalization parameters for the structured latent.
21
- text_cond_model (str): The name of the text conditioning model.
22
- """
23
- def __init__(
24
- self,
25
- models: dict[str, nn.Module] = None,
26
- sparse_structure_sampler: samplers.Sampler = None,
27
- slat_sampler: samplers.Sampler = None,
28
- slat_normalization: dict = None,
29
- text_cond_model: str = None,
30
- ):
31
- if models is None:
32
- return
33
- super().__init__(models)
34
- self.sparse_structure_sampler = sparse_structure_sampler
35
- self.slat_sampler = slat_sampler
36
- self.sparse_structure_sampler_params = {}
37
- self.slat_sampler_params = {}
38
- self.slat_normalization = slat_normalization
39
- self._init_text_cond_model(text_cond_model)
40
-
41
- @staticmethod
42
- def from_pretrained(path: str) -> "TrellisTextTo3DPipeline":
43
- """
44
- Load a pretrained model.
45
-
46
- Args:
47
- path (str): The path to the model. Can be either local path or a Hugging Face repository.
48
- """
49
- pipeline = super(TrellisTextTo3DPipeline, TrellisTextTo3DPipeline).from_pretrained(path)
50
- new_pipeline = TrellisTextTo3DPipeline()
51
- new_pipeline.__dict__ = pipeline.__dict__
52
- args = pipeline._pretrained_args
53
-
54
- new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
55
- new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
56
-
57
- new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args'])
58
- new_pipeline.slat_sampler_params = args['slat_sampler']['params']
59
-
60
- new_pipeline.slat_normalization = args['slat_normalization']
61
-
62
- new_pipeline._init_text_cond_model(args['text_cond_model'])
63
-
64
- return new_pipeline
65
-
66
- def _init_text_cond_model(self, name: str):
67
- """
68
- Initialize the text conditioning model.
69
- """
70
- # load model
71
- model = CLIPTextModel.from_pretrained(name)
72
- tokenizer = AutoTokenizer.from_pretrained(name)
73
- model.eval()
74
- model = model.cuda()
75
- self.text_cond_model = {
76
- 'model': model,
77
- 'tokenizer': tokenizer,
78
- }
79
- self.text_cond_model['null_cond'] = self.encode_text([''])
80
-
81
- @torch.no_grad()
82
- def encode_text(self, text: List[str]) -> torch.Tensor:
83
- """
84
- Encode the text.
85
- """
86
- assert isinstance(text, list) and all(isinstance(t, str) for t in text), "text must be a list of strings"
87
- encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
88
- tokens = encoding['input_ids'].cuda()
89
- embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
90
-
91
- return embeddings
92
-
93
- def get_cond(self, prompt: List[str]) -> dict:
94
- """
95
- Get the conditioning information for the model.
96
-
97
- Args:
98
- prompt (List[str]): The text prompt.
99
-
100
- Returns:
101
- dict: The conditioning information
102
- """
103
- cond = self.encode_text(prompt)
104
- neg_cond = self.text_cond_model['null_cond']
105
- return {
106
- 'cond': cond,
107
- 'neg_cond': neg_cond,
108
- }
109
-
110
- def sample_sparse_structure(
111
- self,
112
- cond: dict,
113
- num_samples: int = 1,
114
- sampler_params: dict = {},
115
- ) -> torch.Tensor:
116
- """
117
- Sample sparse structures with the given conditioning.
118
-
119
- Args:
120
- cond (dict): The conditioning information.
121
- num_samples (int): The number of samples to generate.
122
- sampler_params (dict): Additional parameters for the sampler.
123
- """
124
- # Sample occupancy latent
125
- flow_model = self.models['sparse_structure_flow_model']
126
- reso = flow_model.resolution
127
- noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)
128
- sampler_params = {**self.sparse_structure_sampler_params, **sampler_params}
129
- z_s = self.sparse_structure_sampler.sample(
130
- flow_model,
131
- noise,
132
- **cond,
133
- **sampler_params,
134
- verbose=True
135
- ).samples
136
-
137
- # Decode occupancy latent
138
- decoder = self.models['sparse_structure_decoder']
139
- coords = torch.argwhere(decoder(z_s)>0)[:, [0, 2, 3, 4]].int()
140
-
141
- return coords
142
-
143
- def decode_slat(
144
- self,
145
- slat: sp.SparseTensor,
146
- formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
147
- ) -> dict:
148
- """
149
- Decode the structured latent.
150
-
151
- Args:
152
- slat (sp.SparseTensor): The structured latent.
153
- formats (List[str]): The formats to decode the structured latent to.
154
-
155
- Returns:
156
- dict: The decoded structured latent.
157
- """
158
- ret = {}
159
- if 'mesh' in formats:
160
- ret['mesh'] = self.models['slat_decoder_mesh'](slat)
161
- if 'gaussian' in formats:
162
- ret['gaussian'] = self.models['slat_decoder_gs'](slat)
163
- if 'radiance_field' in formats:
164
- ret['radiance_field'] = self.models['slat_decoder_rf'](slat)
165
- return ret
166
-
167
- def sample_slat(
168
- self,
169
- cond: dict,
170
- coords: torch.Tensor,
171
- sampler_params: dict = {},
172
- ) -> sp.SparseTensor:
173
- """
174
- Sample structured latent with the given conditioning.
175
-
176
- Args:
177
- cond (dict): The conditioning information.
178
- coords (torch.Tensor): The coordinates of the sparse structure.
179
- sampler_params (dict): Additional parameters for the sampler.
180
- """
181
- # Sample structured latent
182
- flow_model = self.models['slat_flow_model']
183
- noise = sp.SparseTensor(
184
- feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
185
- coords=coords,
186
- )
187
- sampler_params = {**self.slat_sampler_params, **sampler_params}
188
- slat = self.slat_sampler.sample(
189
- flow_model,
190
- noise,
191
- **cond,
192
- **sampler_params,
193
- verbose=True
194
- ).samples
195
-
196
- std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device)
197
- mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device)
198
- slat = slat * std + mean
199
-
200
- return slat
201
-
202
- @torch.no_grad()
203
- def run(
204
- self,
205
- prompt: str,
206
- num_samples: int = 1,
207
- seed: int = 42,
208
- sparse_structure_sampler_params: dict = {},
209
- slat_sampler_params: dict = {},
210
- formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
211
- ) -> dict:
212
- """
213
- Run the pipeline.
214
-
215
- Args:
216
- prompt (str): The text prompt.
217
- num_samples (int): The number of samples to generate.
218
- seed (int): The random seed.
219
- sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
220
- slat_sampler_params (dict): Additional parameters for the structured latent sampler.
221
- formats (List[str]): The formats to decode the structured latent to.
222
- """
223
- cond = self.get_cond([prompt])
224
- torch.manual_seed(seed)
225
- coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
226
- slat = self.sample_slat(cond, coords, slat_sampler_params)
227
- return self.decode_slat(slat, formats)
228
-
229
- def voxelize(self, mesh: o3d.geometry.TriangleMesh) -> torch.Tensor:
230
- """
231
- Voxelize a mesh.
232
-
233
- Args:
234
- mesh (o3d.geometry.TriangleMesh): The mesh to voxelize.
235
- sha256 (str): The SHA256 hash of the mesh.
236
- output_dir (str): The output directory.
237
- """
238
- vertices = np.asarray(mesh.vertices)
239
- aabb = np.stack([vertices.min(0), vertices.max(0)])
240
- center = (aabb[0] + aabb[1]) / 2
241
- scale = (aabb[1] - aabb[0]).max()
242
- vertices = (vertices - center) / scale
243
- vertices = np.clip(vertices, -0.5 + 1e-6, 0.5 - 1e-6)
244
- mesh.vertices = o3d.utility.Vector3dVector(vertices)
245
- voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(mesh, voxel_size=1/64, min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
246
- vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()])
247
- return torch.tensor(vertices).int().cuda()
248
-
249
- @torch.no_grad()
250
- def run_variant(
251
- self,
252
- mesh: o3d.geometry.TriangleMesh,
253
- prompt: str,
254
- num_samples: int = 1,
255
- seed: int = 42,
256
- slat_sampler_params: dict = {},
257
- formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
258
- ) -> dict:
259
- """
260
- Run the pipeline for making variants of an asset.
261
-
262
- Args:
263
- mesh (o3d.geometry.TriangleMesh): The base mesh.
264
- prompt (str): The text prompt.
265
- num_samples (int): The number of samples to generate.
266
- seed (int): The random seed
267
- slat_sampler_params (dict): Additional parameters for the structured latent sampler.
268
- formats (List[str]): The formats to decode the structured latent to.
269
- """
270
- cond = self.get_cond([prompt])
271
- coords = self.voxelize(mesh)
272
- coords = torch.cat([
273
- torch.arange(num_samples).repeat_interleave(coords.shape[0], 0)[:, None].int().cuda(),
274
- coords.repeat(num_samples, 1)
275
- ], 1)
276
- torch.manual_seed(seed)
277
- slat = self.sample_slat(cond, coords, slat_sampler_params)
278
- return self.decode_slat(slat, formats)
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from transformers import CLIPTextModel, AutoTokenizer
7
+ import open3d as o3d
8
+ from .base import Pipeline
9
+ from . import samplers
10
+ from ..modules import sparse as sp
11
+
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
+ class TrellisTextTo3DPipeline(Pipeline):
15
+ """
16
+ Pipeline for inferring Trellis text-to-3D models.
17
+
18
+ Args:
19
+ models (dict[str, nn.Module]): The models to use in the pipeline.
20
+ sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
21
+ slat_sampler (samplers.Sampler): The sampler for the structured latent.
22
+ slat_normalization (dict): The normalization parameters for the structured latent.
23
+ text_cond_model (str): The name of the text conditioning model.
24
+ """
25
+ def __init__(
26
+ self,
27
+ models: dict[str, nn.Module] = None,
28
+ sparse_structure_sampler: samplers.Sampler = None,
29
+ slat_sampler: samplers.Sampler = None,
30
+ slat_normalization: dict = None,
31
+ text_cond_model: str = None,
32
+ ):
33
+ if models is None:
34
+ return
35
+ super().__init__(models)
36
+ self.sparse_structure_sampler = sparse_structure_sampler
37
+ self.slat_sampler = slat_sampler
38
+ self.sparse_structure_sampler_params = {}
39
+ self.slat_sampler_params = {}
40
+ self.slat_normalization = slat_normalization
41
+ self._init_text_cond_model(text_cond_model)
42
+
43
+ @staticmethod
44
+ def from_pretrained(path: str) -> "TrellisTextTo3DPipeline":
45
+ """
46
+ Load a pretrained model.
47
+
48
+ Args:
49
+ path (str): The path to the model. Can be either local path or a Hugging Face repository.
50
+ """
51
+ pipeline = super(TrellisTextTo3DPipeline, TrellisTextTo3DPipeline).from_pretrained(path)
52
+ new_pipeline = TrellisTextTo3DPipeline()
53
+ new_pipeline.__dict__ = pipeline.__dict__
54
+ args = pipeline._pretrained_args
55
+
56
+ new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
57
+ new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
58
+
59
+ new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args'])
60
+ new_pipeline.slat_sampler_params = args['slat_sampler']['params']
61
+
62
+ new_pipeline.slat_normalization = args['slat_normalization']
63
+
64
+ new_pipeline._init_text_cond_model(args['text_cond_model'])
65
+
66
+ return new_pipeline
67
+
68
+ def _init_text_cond_model(self, name: str):
69
+ """
70
+ Initialize the text conditioning model.
71
+ """
72
+ # load model
73
+ model = CLIPTextModel.from_pretrained(name).to(self.device)
74
+ tokenizer = AutoTokenizer.from_pretrained(name).to(self.device)
75
+ model.eval()
76
+ self.text_cond_model = {
77
+ 'model': model,
78
+ 'tokenizer': tokenizer,
79
+ }
80
+ self.text_cond_model['null_cond'] = self.encode_text([''])
81
+
82
+ @torch.no_grad()
83
+ def encode_text(self, text: List[str]) -> torch.Tensor:
84
+ """
85
+ Encode the text.
86
+ """
87
+ assert isinstance(text, list) and all(isinstance(t, str) for t in text), "text must be a list of strings"
88
+ encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
89
+ tokens = encoding['input_ids'].to(self.device)
90
+ embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
91
+
92
+ return embeddings
93
+
94
+ def get_cond(self, prompt: List[str]) -> dict:
95
+ """
96
+ Get the conditioning information for the model.
97
+
98
+ Args:
99
+ prompt (List[str]): The text prompt.
100
+
101
+ Returns:
102
+ dict: The conditioning information
103
+ """
104
+ cond = self.encode_text(prompt)
105
+ neg_cond = self.text_cond_model['null_cond']
106
+ return {
107
+ 'cond': cond,
108
+ 'neg_cond': neg_cond,
109
+ }
110
+
111
+ def sample_sparse_structure(
112
+ self,
113
+ cond: dict,
114
+ num_samples: int = 1,
115
+ sampler_params: dict = {},
116
+ ) -> torch.Tensor:
117
+ """
118
+ Sample sparse structures with the given conditioning.
119
+
120
+ Args:
121
+ cond (dict): The conditioning information.
122
+ num_samples (int): The number of samples to generate.
123
+ sampler_params (dict): Additional parameters for the sampler.
124
+ """
125
+ # Sample occupancy latent
126
+ flow_model = self.models['sparse_structure_flow_model']
127
+ reso = flow_model.resolution
128
+ noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)
129
+ sampler_params = {**self.sparse_structure_sampler_params, **sampler_params}
130
+ z_s = self.sparse_structure_sampler.sample(
131
+ flow_model,
132
+ noise,
133
+ **cond,
134
+ **sampler_params,
135
+ verbose=True
136
+ ).samples
137
+
138
+ # Decode occupancy latent
139
+ decoder = self.models['sparse_structure_decoder']
140
+ coords = torch.argwhere(decoder(z_s)>0)[:, [0, 2, 3, 4]].int()
141
+
142
+ return coords
143
+
144
+ def decode_slat(
145
+ self,
146
+ slat: sp.SparseTensor,
147
+ formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
148
+ ) -> dict:
149
+ """
150
+ Decode the structured latent.
151
+
152
+ Args:
153
+ slat (sp.SparseTensor): The structured latent.
154
+ formats (List[str]): The formats to decode the structured latent to.
155
+
156
+ Returns:
157
+ dict: The decoded structured latent.
158
+ """
159
+ ret = {}
160
+ if 'mesh' in formats:
161
+ ret['mesh'] = self.models['slat_decoder_mesh'](slat)
162
+ if 'gaussian' in formats:
163
+ ret['gaussian'] = self.models['slat_decoder_gs'](slat)
164
+ if 'radiance_field' in formats:
165
+ ret['radiance_field'] = self.models['slat_decoder_rf'](slat)
166
+ return ret
167
+
168
+ def sample_slat(
169
+ self,
170
+ cond: dict,
171
+ coords: torch.Tensor,
172
+ sampler_params: dict = {},
173
+ ) -> sp.SparseTensor:
174
+ """
175
+ Sample structured latent with the given conditioning.
176
+
177
+ Args:
178
+ cond (dict): The conditioning information.
179
+ coords (torch.Tensor): The coordinates of the sparse structure.
180
+ sampler_params (dict): Additional parameters for the sampler.
181
+ """
182
+ # Sample structured latent
183
+ flow_model = self.models['slat_flow_model']
184
+ noise = sp.SparseTensor(
185
+ feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
186
+ coords=coords,
187
+ )
188
+ sampler_params = {**self.slat_sampler_params, **sampler_params}
189
+ slat = self.slat_sampler.sample(
190
+ flow_model,
191
+ noise,
192
+ **cond,
193
+ **sampler_params,
194
+ verbose=True
195
+ ).samples
196
+
197
+ std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device)
198
+ mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device)
199
+ slat = slat * std + mean
200
+
201
+ return slat
202
+
203
+ @torch.no_grad()
204
+ def run(
205
+ self,
206
+ prompt: str,
207
+ num_samples: int = 1,
208
+ seed: int = 42,
209
+ sparse_structure_sampler_params: dict = {},
210
+ slat_sampler_params: dict = {},
211
+ formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
212
+ ) -> dict:
213
+ """
214
+ Run the pipeline.
215
+
216
+ Args:
217
+ prompt (str): The text prompt.
218
+ num_samples (int): The number of samples to generate.
219
+ seed (int): The random seed.
220
+ sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
221
+ slat_sampler_params (dict): Additional parameters for the structured latent sampler.
222
+ formats (List[str]): The formats to decode the structured latent to.
223
+ """
224
+ cond = self.get_cond([prompt])
225
+ torch.manual_seed(seed)
226
+ coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
227
+ slat = self.sample_slat(cond, coords, slat_sampler_params)
228
+ return self.decode_slat(slat, formats)