cavargas10 committed on
Commit 04fa6ac · verified · 1 Parent(s): 8471ff9

Upload 13 files

trellis/trainers/__init__.py ADDED
@@ -0,0 +1,63 @@
1
+ import importlib
2
+
3
+ __attributes = {
4
+ 'BasicTrainer': 'basic',
5
+
6
+ 'SparseStructureVaeTrainer': 'vae.sparse_structure_vae',
7
+
8
+ 'SLatVaeGaussianTrainer': 'vae.structured_latent_vae_gaussian',
9
+ 'SLatVaeRadianceFieldDecoderTrainer': 'vae.structured_latent_vae_rf_dec',
10
+ 'SLatVaeMeshDecoderTrainer': 'vae.structured_latent_vae_mesh_dec',
11
+
12
+ 'FlowMatchingTrainer': 'flow_matching.flow_matching',
13
+ 'FlowMatchingCFGTrainer': 'flow_matching.flow_matching',
14
+ 'TextConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
15
+ 'ImageConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
16
+
17
+ 'SparseFlowMatchingTrainer': 'flow_matching.sparse_flow_matching',
18
+ 'SparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
19
+ 'TextConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
20
+ 'ImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
21
+ }
22
+
23
+ __submodules = []
24
+
25
+ __all__ = list(__attributes.keys()) + __submodules
26
+
27
+ def __getattr__(name):
28
+ if name not in globals():
29
+ if name in __attributes:
30
+ module_name = __attributes[name]
31
+ module = importlib.import_module(f".{module_name}", __name__)
32
+ globals()[name] = getattr(module, name)
33
+ elif name in __submodules:
34
+ module = importlib.import_module(f".{name}", __name__)
35
+ globals()[name] = module
36
+ else:
37
+ raise AttributeError(f"module {__name__} has no attribute {name}")
38
+ return globals()[name]
39
+
40
+
41
+ # For Pylance
42
+ if __name__ == '__main__':
43
+ from .basic import BasicTrainer
44
+
45
+ from .vae.sparse_structure_vae import SparseStructureVaeTrainer
46
+
47
+ from .vae.structured_latent_vae_gaussian import SLatVaeGaussianTrainer
48
+ from .vae.structured_latent_vae_rf_dec import SLatVaeRadianceFieldDecoderTrainer
49
+ from .vae.structured_latent_vae_mesh_dec import SLatVaeMeshDecoderTrainer
50
+
51
+ from .flow_matching.flow_matching import (
52
+ FlowMatchingTrainer,
53
+ FlowMatchingCFGTrainer,
54
+ TextConditionedFlowMatchingCFGTrainer,
55
+ ImageConditionedFlowMatchingCFGTrainer,
56
+ )
57
+
58
+ from .flow_matching.sparse_flow_matching import (
59
+ SparseFlowMatchingTrainer,
60
+ SparseFlowMatchingCFGTrainer,
61
+ TextConditionedSparseFlowMatchingCFGTrainer,
62
+ ImageConditionedSparseFlowMatchingCFGTrainer,
63
+ )
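
A minimal usage sketch (not part of the uploaded file) of the lazy-import machinery above; it assumes the package is importable as `trellis`. The first attribute access triggers `importlib.import_module` for the mapped submodule and caches the resolved class in the module globals.

import trellis.trainers as trainers

trainer_cls = trainers.BasicTrainer   # resolved lazily from trellis.trainers.basic on first access
print(trainers.__all__)               # every trainer name that can be imported this way
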
trellis/trainers/base.py ADDED
@@ -0,0 +1,451 @@
1
+ from abc import abstractmethod
2
+ import os
3
+ import time
4
+ import json
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+ from torch.utils.data import DataLoader
9
+ import numpy as np
10
+
11
+ from torchvision import utils
12
+ from torch.utils.tensorboard import SummaryWriter
13
+
14
+ from .utils import *
15
+ from ..utils.general_utils import *
16
+ from ..utils.data_utils import recursive_to_device, cycle, ResumableSampler
17
+
18
+
19
+ class Trainer:
20
+ """
21
+ Base class for training.
22
+ """
23
+ def __init__(self,
24
+ models,
25
+ dataset,
26
+ *,
27
+ output_dir,
28
+ load_dir,
29
+ step,
30
+ max_steps,
31
+ batch_size=None,
32
+ batch_size_per_gpu=None,
33
+ batch_split=None,
34
+ optimizer={},
35
+ lr_scheduler=None,
36
+ elastic=None,
37
+ grad_clip=None,
38
+ ema_rate=0.9999,
39
+ fp16_mode='inflat_all',
40
+ fp16_scale_growth=1e-3,
41
+ finetune_ckpt=None,
42
+ log_param_stats=False,
43
+ prefetch_data=True,
44
+ i_print=1000,
45
+ i_log=500,
46
+ i_sample=10000,
47
+ i_save=10000,
48
+ i_ddpcheck=10000,
49
+ **kwargs
50
+ ):
51
+ assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.'
52
+
53
+ self.models = models
54
+ self.dataset = dataset
55
+ self.batch_split = batch_split if batch_split is not None else 1
56
+ self.max_steps = max_steps
57
+ self.optimizer_config = optimizer
58
+ self.lr_scheduler_config = lr_scheduler
59
+ self.elastic_controller_config = elastic
60
+ self.grad_clip = grad_clip
61
+ self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate
62
+ self.fp16_mode = fp16_mode
63
+ self.fp16_scale_growth = fp16_scale_growth
64
+ self.log_param_stats = log_param_stats
65
+ self.prefetch_data = prefetch_data
66
+ if self.prefetch_data:
67
+ self._data_prefetched = None
68
+
69
+ self.output_dir = output_dir
70
+ self.i_print = i_print
71
+ self.i_log = i_log
72
+ self.i_sample = i_sample
73
+ self.i_save = i_save
74
+ self.i_ddpcheck = i_ddpcheck
75
+
76
+ if dist.is_initialized():
77
+ # Multi-GPU params
78
+ self.world_size = dist.get_world_size()
79
+ self.rank = dist.get_rank()
80
+ self.local_rank = dist.get_rank() % torch.cuda.device_count()
81
+ self.is_master = self.rank == 0
82
+ else:
83
+ # Single-GPU params
84
+ self.world_size = 1
85
+ self.rank = 0
86
+ self.local_rank = 0
87
+ self.is_master = True
88
+
89
+ self.batch_size = batch_size if batch_size_per_gpu is None else batch_size_per_gpu * self.world_size
90
+ self.batch_size_per_gpu = batch_size_per_gpu if batch_size_per_gpu is not None else batch_size // self.world_size
91
+ assert self.batch_size % self.world_size == 0, 'Batch size must be divisible by the number of GPUs.'
92
+ assert self.batch_size_per_gpu % self.batch_split == 0, 'Batch size per GPU must be divisible by batch split.'
93
+
94
+ self.init_models_and_more(**kwargs)
95
+ self.prepare_dataloader(**kwargs)
96
+
97
+ # Load checkpoint
98
+ self.step = 0
99
+ if load_dir is not None and step is not None:
100
+ self.load(load_dir, step)
101
+ elif finetune_ckpt is not None:
102
+ self.finetune_from(finetune_ckpt)
103
+
104
+ if self.is_master:
105
+ os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True)
106
+ os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True)
107
+ self.writer = SummaryWriter(os.path.join(self.output_dir, 'tb_logs'))
108
+
109
+ if self.world_size > 1:
110
+ self.check_ddp()
111
+
112
+ if self.is_master:
113
+ print('\n\nTrainer initialized.')
114
+ print(self)
115
+
116
+ @property
117
+ def device(self):
118
+ for _, model in self.models.items():
119
+ if hasattr(model, 'device'):
120
+ return model.device
121
+ return next(list(self.models.values())[0].parameters()).device
122
+
123
+ @abstractmethod
124
+ def init_models_and_more(self, **kwargs):
125
+ """
126
+ Initialize models and more.
127
+ """
128
+ pass
129
+
130
+ def prepare_dataloader(self, **kwargs):
131
+ """
132
+ Prepare dataloader.
133
+ """
134
+ self.data_sampler = ResumableSampler(
135
+ self.dataset,
136
+ shuffle=True,
137
+ )
138
+ self.dataloader = DataLoader(
139
+ self.dataset,
140
+ batch_size=self.batch_size_per_gpu,
141
+ num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
142
+ pin_memory=True,
143
+ drop_last=True,
144
+ persistent_workers=True,
145
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
146
+ sampler=self.data_sampler,
147
+ )
148
+ self.data_iterator = cycle(self.dataloader)
149
+
150
+ @abstractmethod
151
+ def load(self, load_dir, step=0):
152
+ """
153
+ Load a checkpoint.
154
+ Should be called by all processes.
155
+ """
156
+ pass
157
+
158
+ @abstractmethod
159
+ def save(self):
160
+ """
161
+ Save a checkpoint.
162
+ Should be called only by the rank 0 process.
163
+ """
164
+ pass
165
+
166
+ @abstractmethod
167
+ def finetune_from(self, finetune_ckpt):
168
+ """
169
+ Finetune from a checkpoint.
170
+ Should be called by all processes.
171
+ """
172
+ pass
173
+
174
+ @abstractmethod
175
+ def run_snapshot(self, num_samples, batch_size=4, verbose=False, **kwargs):
176
+ """
177
+ Run a snapshot of the model.
178
+ """
179
+ pass
180
+
181
+ @torch.no_grad()
182
+ def visualize_sample(self, sample):
183
+ """
184
+ Convert a sample to an image.
185
+ """
186
+ if hasattr(self.dataset, 'visualize_sample'):
187
+ return self.dataset.visualize_sample(sample)
188
+ else:
189
+ return sample
190
+
191
+ @torch.no_grad()
192
+ def snapshot_dataset(self, num_samples=100):
193
+ """
194
+ Sample images from the dataset.
195
+ """
196
+ dataloader = torch.utils.data.DataLoader(
197
+ self.dataset,
198
+ batch_size=num_samples,
199
+ num_workers=0,
200
+ shuffle=True,
201
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
202
+ )
203
+ data = next(iter(dataloader))
204
+ data = recursive_to_device(data, self.device)
205
+ vis = self.visualize_sample(data)
206
+ if isinstance(vis, dict):
207
+ save_cfg = [(f'dataset_{k}', v) for k, v in vis.items()]
208
+ else:
209
+ save_cfg = [('dataset', vis)]
210
+ for name, image in save_cfg:
211
+ utils.save_image(
212
+ image,
213
+ os.path.join(self.output_dir, 'samples', f'{name}.jpg'),
214
+ nrow=int(np.sqrt(num_samples)),
215
+ normalize=True,
216
+ value_range=self.dataset.value_range,
217
+ )
218
+
219
+ @torch.no_grad()
220
+ def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False):
221
+ """
222
+ Sample images from the model.
223
+ NOTE: This function should be called by all processes.
224
+ """
225
+ if self.is_master:
226
+ print(f'\nSampling {num_samples} images...', end='')
227
+
228
+ if suffix is None:
229
+ suffix = f'step{self.step:07d}'
230
+
231
+ # Assign tasks
232
+ num_samples_per_process = int(np.ceil(num_samples / self.world_size))
233
+ samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose)
234
+
235
+ # Preprocess images
236
+ for key in list(samples.keys()):
237
+ if samples[key]['type'] == 'sample':
238
+ vis = self.visualize_sample(samples[key]['value'])
239
+ if isinstance(vis, dict):
240
+ for k, v in vis.items():
241
+ samples[f'{key}_{k}'] = {'value': v, 'type': 'image'}
242
+ del samples[key]
243
+ else:
244
+ samples[key] = {'value': vis, 'type': 'image'}
245
+
246
+ # Gather results
247
+ if self.world_size > 1:
248
+ for key in samples.keys():
249
+ samples[key]['value'] = samples[key]['value'].contiguous()
250
+ if self.is_master:
251
+ all_images = [torch.empty_like(samples[key]['value']) for _ in range(self.world_size)]
252
+ else:
253
+ all_images = []
254
+ dist.gather(samples[key]['value'], all_images, dst=0)
255
+ if self.is_master:
256
+ samples[key]['value'] = torch.cat(all_images, dim=0)[:num_samples]
257
+
258
+ # Save images
259
+ if self.is_master:
260
+ os.makedirs(os.path.join(self.output_dir, 'samples', suffix), exist_ok=True)
261
+ for key in samples.keys():
262
+ if samples[key]['type'] == 'image':
263
+ utils.save_image(
264
+ samples[key]['value'],
265
+ os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
266
+ nrow=int(np.sqrt(num_samples)),
267
+ normalize=True,
268
+ value_range=self.dataset.value_range,
269
+ )
270
+ elif samples[key]['type'] == 'number':
271
+ min = samples[key]['value'].min()
272
+ max = samples[key]['value'].max()
273
+ images = (samples[key]['value'] - min) / (max - min)
274
+ images = utils.make_grid(
275
+ images,
276
+ nrow=int(np.sqrt(num_samples)),
277
+ normalize=False,
278
+ )
279
+ save_image_with_notes(
280
+ images,
281
+ os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
282
+ notes=f'{key} min: {min}, max: {max}',
283
+ )
284
+
285
+ if self.is_master:
286
+ print(' Done.')
287
+
288
+ @abstractmethod
289
+ def update_ema(self):
290
+ """
291
+ Update exponential moving average.
292
+ Should only be called by the rank 0 process.
293
+ """
294
+ pass
295
+
296
+ @abstractmethod
297
+ def check_ddp(self):
298
+ """
299
+ Check if DDP is working properly.
300
+ Should be called by all processes.
301
+ """
302
+ pass
303
+
304
+ @abstractmethod
305
+ def training_losses(self, **mb_data):
306
+ """
307
+ Compute training losses.
308
+ """
309
+ pass
310
+
311
+ def load_data(self):
312
+ """
313
+ Load data.
314
+ """
315
+ if self.prefetch_data:
316
+ if self._data_prefetched is None:
317
+ self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
318
+ data = self._data_prefetched
319
+ self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
320
+ else:
321
+ data = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
322
+
323
+ # if the data is a dict, we need to split it into multiple dicts with batch_size_per_gpu
324
+ if isinstance(data, dict):
325
+ if self.batch_split == 1:
326
+ data_list = [data]
327
+ else:
328
+ batch_size = list(data.values())[0].shape[0]
329
+ data_list = [
330
+ {k: v[i * batch_size // self.batch_split:(i + 1) * batch_size // self.batch_split] for k, v in data.items()}
331
+ for i in range(self.batch_split)
332
+ ]
333
+ elif isinstance(data, list):
334
+ data_list = data
335
+ else:
336
+ raise ValueError('Data must be a dict or a list of dicts.')
337
+
338
+ return data_list
339
+
340
+ @abstractmethod
341
+ def run_step(self, data_list):
342
+ """
343
+ Run a training step.
344
+ """
345
+ pass
346
+
347
+ def run(self):
348
+ """
349
+ Run training.
350
+ """
351
+ if self.is_master:
352
+ print('\nStarting training...')
353
+ self.snapshot_dataset()
354
+ if self.step == 0:
355
+ self.snapshot(suffix='init')
356
+ else: # resume
357
+ self.snapshot(suffix=f'resume_step{self.step:07d}')
358
+
359
+ log = []
360
+ time_last_print = 0.0
361
+ time_elapsed = 0.0
362
+ while self.step < self.max_steps:
363
+ time_start = time.time()
364
+
365
+ data_list = self.load_data()
366
+ step_log = self.run_step(data_list)
367
+
368
+ time_end = time.time()
369
+ time_elapsed += time_end - time_start
370
+
371
+ self.step += 1
372
+
373
+ # Print progress
374
+ if self.is_master and self.step % self.i_print == 0:
375
+ speed = self.i_print / (time_elapsed - time_last_print) * 3600
376
+ columns = [
377
+ f'Step: {self.step}/{self.max_steps} ({self.step / self.max_steps * 100:.2f}%)',
378
+ f'Elapsed: {time_elapsed / 3600:.2f} h',
379
+ f'Speed: {speed:.2f} steps/h',
380
+ f'ETA: {(self.max_steps - self.step) / speed:.2f} h',
381
+ ]
382
+ print(' | '.join([c.ljust(25) for c in columns]), flush=True)
383
+ time_last_print = time_elapsed
384
+
385
+ # Check ddp
386
+ if self.world_size > 1 and self.i_ddpcheck is not None and self.step % self.i_ddpcheck == 0:
387
+ self.check_ddp()
388
+
389
+ # Sample images
390
+ if self.step % self.i_sample == 0:
391
+ self.snapshot()
392
+
393
+ if self.is_master:
394
+ log.append((self.step, {}))
395
+
396
+ # Log time
397
+ log[-1][1]['time'] = {
398
+ 'step': time_end - time_start,
399
+ 'elapsed': time_elapsed,
400
+ }
401
+
402
+ # Log losses
403
+ if step_log is not None:
404
+ log[-1][1].update(step_log)
405
+
406
+ # Log scale
407
+ if self.fp16_mode == 'amp':
408
+ log[-1][1]['scale'] = self.scaler.get_scale()
409
+ elif self.fp16_mode == 'inflat_all':
410
+ log[-1][1]['log_scale'] = self.log_scale
411
+
412
+ # Save log
413
+ if self.step % self.i_log == 0:
414
+ ## save to log file
415
+ log_str = '\n'.join([
416
+ f'{step}: {json.dumps(log)}' for step, log in log
417
+ ])
418
+ with open(os.path.join(self.output_dir, 'log.txt'), 'a') as log_file:
419
+ log_file.write(log_str + '\n')
420
+
421
+ # log averaged metrics to tensorboard
422
+ log_show = [l for _, l in log if not dict_any(l, lambda x: np.isnan(x))]
423
+ log_show = dict_reduce(log_show, lambda x: np.mean(x))
424
+ log_show = dict_flatten(log_show, sep='/')
425
+ for key, value in log_show.items():
426
+ self.writer.add_scalar(key, value, self.step)
427
+ log = []
428
+
429
+ # Save checkpoint
430
+ if self.step % self.i_save == 0:
431
+ self.save()
432
+
433
+ if self.is_master:
434
+ self.snapshot(suffix='final')
435
+ self.writer.close()
436
+ print('Training finished.')
437
+
438
+ def profile(self, wait=2, warmup=3, active=5):
439
+ """
440
+ Profile the training loop.
441
+ """
442
+ with torch.profiler.profile(
443
+ schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
444
+ on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(self.output_dir, 'profile')),
445
+ profile_memory=True,
446
+ with_stack=True,
447
+ ) as prof:
448
+ for _ in range(wait + warmup + active):
449
+ self.run_step(self.load_data())
450
+ prof.step()
451
+
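
A small worked example (hypothetical numbers, not from this commit) of the batch bookkeeping in Trainer.__init__ above: batch_size is the global batch, batch_size_per_gpu is derived by dividing by the world size, and batch_split further divides each per-GPU batch into the micro-batches that load_data() hands to gradient accumulation.

world_size = 4            # number of GPUs (assumed)
batch_size = 32           # global batch size
batch_split = 2           # gradient-accumulation splits per GPU

batch_size_per_gpu = batch_size // world_size      # 8 samples per GPU per optimizer step
micro_batch = batch_size_per_gpu // batch_split    # 4 samples per backward pass
assert batch_size % world_size == 0                # enforced in Trainer.__init__
assert batch_size_per_gpu % batch_split == 0       # enforced in Trainer.__init__
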
trellis/trainers/basic.py ADDED
@@ -0,0 +1,438 @@
1
+ import os
2
+ import copy
3
+ from functools import partial
4
+ from contextlib import nullcontext
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+ from torch.nn.parallel import DistributedDataParallel as DDP
9
+ import numpy as np
10
+
11
+ from .utils import *
12
+ from .base import Trainer
13
+ from ..utils.general_utils import *
14
+ from ..utils.dist_utils import *
15
+ from ..utils import grad_clip_utils, elastic_utils
16
+
17
+
18
+ class BasicTrainer(Trainer):
19
+ """
20
+ Trainer for basic training loop.
21
+
22
+ Args:
23
+ models (dict[str, nn.Module]): Models to train.
24
+ dataset (torch.utils.data.Dataset): Dataset.
25
+ output_dir (str): Output directory.
26
+ load_dir (str): Load directory.
27
+ step (int): Step to load.
28
+ batch_size (int): Batch size.
29
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
30
+ batch_split (int): Split batch with gradient accumulation.
31
+ max_steps (int): Max steps.
32
+ optimizer (dict): Optimizer config.
33
+ lr_scheduler (dict): Learning rate scheduler config.
34
+ elastic (dict): Elastic memory management config.
35
+ grad_clip (float or dict): Gradient clip config.
36
+ ema_rate (float or list): Exponential moving average rates.
37
+ fp16_mode (str): FP16 mode.
38
+ - None: No FP16.
39
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
40
+ - 'amp': Automatic mixed precision.
41
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
42
+ finetune_ckpt (dict): Finetune checkpoint.
43
+ log_param_stats (bool): Log parameter stats.
44
+ i_print (int): Print interval.
45
+ i_log (int): Log interval.
46
+ i_sample (int): Sample interval.
47
+ i_save (int): Save interval.
48
+ i_ddpcheck (int): DDP check interval.
49
+ """
50
+
51
+ def __str__(self):
52
+ lines = []
53
+ lines.append(self.__class__.__name__)
54
+ lines.append(f' - Models:')
55
+ for name, model in self.models.items():
56
+ lines.append(f' - {name}: {model.__class__.__name__}')
57
+ lines.append(f' - Dataset: {indent(str(self.dataset), 2)}')
58
+ lines.append(f' - Dataloader:')
59
+ lines.append(f' - Sampler: {self.dataloader.sampler.__class__.__name__}')
60
+ lines.append(f' - Num workers: {self.dataloader.num_workers}')
61
+ lines.append(f' - Number of steps: {self.max_steps}')
62
+ lines.append(f' - Number of GPUs: {self.world_size}')
63
+ lines.append(f' - Batch size: {self.batch_size}')
64
+ lines.append(f' - Batch size per GPU: {self.batch_size_per_gpu}')
65
+ lines.append(f' - Batch split: {self.batch_split}')
66
+ lines.append(f' - Optimizer: {self.optimizer.__class__.__name__}')
67
+ lines.append(f' - Learning rate: {self.optimizer.param_groups[0]["lr"]}')
68
+ if self.lr_scheduler_config is not None:
69
+ lines.append(f' - LR scheduler: {self.lr_scheduler.__class__.__name__}')
70
+ if self.elastic_controller_config is not None:
71
+ lines.append(f' - Elastic memory: {indent(str(self.elastic_controller), 2)}')
72
+ if self.grad_clip is not None:
73
+ lines.append(f' - Gradient clip: {indent(str(self.grad_clip), 2)}')
74
+ lines.append(f' - EMA rate: {self.ema_rate}')
75
+ lines.append(f' - FP16 mode: {self.fp16_mode}')
76
+ return '\n'.join(lines)
77
+
78
+ def init_models_and_more(self, **kwargs):
79
+ """
80
+ Initialize models and more.
81
+ """
82
+ if self.world_size > 1:
83
+ # Prepare distributed data parallel
84
+ self.training_models = {
85
+ name: DDP(
86
+ model,
87
+ device_ids=[self.local_rank],
88
+ output_device=self.local_rank,
89
+ bucket_cap_mb=128,
90
+ find_unused_parameters=False
91
+ )
92
+ for name, model in self.models.items()
93
+ }
94
+ else:
95
+ self.training_models = self.models
96
+
97
+ # Build master params
98
+ self.model_params = sum(
99
+ [[p for p in model.parameters() if p.requires_grad] for model in self.models.values()]
100
+ , [])
101
+ if self.fp16_mode == 'amp':
102
+ self.master_params = self.model_params
103
+ self.scaler = torch.GradScaler() if self.fp16_mode == 'amp' else None
104
+ elif self.fp16_mode == 'inflat_all':
105
+ self.master_params = make_master_params(self.model_params)
106
+ self.fp16_scale_growth = self.fp16_scale_growth
107
+ self.log_scale = 20.0
108
+ elif self.fp16_mode is None:
109
+ self.master_params = self.model_params
110
+ else:
111
+ raise NotImplementedError(f'FP16 mode {self.fp16_mode} is not implemented.')
112
+
113
+ # Build EMA params
114
+ if self.is_master:
115
+ self.ema_params = [copy.deepcopy(self.master_params) for _ in self.ema_rate]
116
+
117
+ # Initialize optimizer
118
+ if hasattr(torch.optim, self.optimizer_config['name']):
119
+ self.optimizer = getattr(torch.optim, self.optimizer_config['name'])(self.master_params, **self.optimizer_config['args'])
120
+ else:
121
+ self.optimizer = globals()[self.optimizer_config['name']](self.master_params, **self.optimizer_config['args'])
122
+
123
+ # Initialize learning rate scheduler
124
+ if self.lr_scheduler_config is not None:
125
+ if hasattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name']):
126
+ self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name'])(self.optimizer, **self.lr_scheduler_config['args'])
127
+ else:
128
+ self.lr_scheduler = globals()[self.lr_scheduler_config['name']](self.optimizer, **self.lr_scheduler_config['args'])
129
+
130
+ # Initialize elastic memory controller
131
+ if self.elastic_controller_config is not None:
132
+ assert any([isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)) for model in self.models.values()]), \
133
+ 'No elastic module found in models, please inherit from ElasticModule or ElasticModuleMixin'
134
+ self.elastic_controller = getattr(elastic_utils, self.elastic_controller_config['name'])(**self.elastic_controller_config['args'])
135
+ for model in self.models.values():
136
+ if isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)):
137
+ model.register_memory_controller(self.elastic_controller)
138
+
139
+ # Initialize gradient clipper
140
+ if self.grad_clip is not None:
141
+ if isinstance(self.grad_clip, (float, int)):
142
+ self.grad_clip = float(self.grad_clip)
143
+ else:
144
+ self.grad_clip = getattr(grad_clip_utils, self.grad_clip['name'])(**self.grad_clip['args'])
145
+
146
+ def _master_params_to_state_dicts(self, master_params):
147
+ """
148
+ Convert master params to dict of state_dicts.
149
+ """
150
+ if self.fp16_mode == 'inflat_all':
151
+ master_params = unflatten_master_params(self.model_params, master_params)
152
+ state_dicts = {name: model.state_dict() for name, model in self.models.items()}
153
+ master_params_names = sum(
154
+ [[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
155
+ , [])
156
+ for i, (model_name, param_name) in enumerate(master_params_names):
157
+ state_dicts[model_name][param_name] = master_params[i]
158
+ return state_dicts
159
+
160
+ def _state_dicts_to_master_params(self, master_params, state_dicts):
161
+ """
162
+ Convert a state_dict to master params.
163
+ """
164
+ master_params_names = sum(
165
+ [[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
166
+ , [])
167
+ params = [state_dicts[name][param_name] for name, param_name in master_params_names]
168
+ if self.fp16_mode == 'inflat_all':
169
+ model_params_to_master_params(params, master_params)
170
+ else:
171
+ for i, param in enumerate(params):
172
+ master_params[i].data.copy_(param.data)
173
+
174
+ def load(self, load_dir, step=0):
175
+ """
176
+ Load a checkpoint.
177
+ Should be called by all processes.
178
+ """
179
+ if self.is_master:
180
+ print(f'\nLoading checkpoint from step {step}...', end='')
181
+
182
+ model_ckpts = {}
183
+ for name, model in self.models.items():
184
+ model_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')), map_location=self.device, weights_only=True)
185
+ model_ckpts[name] = model_ckpt
186
+ model.load_state_dict(model_ckpt)
187
+ if self.fp16_mode == 'inflat_all':
188
+ model.convert_to_fp16()
189
+ self._state_dicts_to_master_params(self.master_params, model_ckpts)
190
+ del model_ckpts
191
+
192
+ if self.is_master:
193
+ for i, ema_rate in enumerate(self.ema_rate):
194
+ ema_ckpts = {}
195
+ for name, model in self.models.items():
196
+ ema_ckpt = torch.load(os.path.join(load_dir, 'ckpts', f'{name}_ema{ema_rate}_step{step:07d}.pt'), map_location=self.device, weights_only=True)
197
+ ema_ckpts[name] = ema_ckpt
198
+ self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts)
199
+ del ema_ckpts
200
+
201
+ misc_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), map_location=torch.device('cpu'), weights_only=False)
202
+ self.optimizer.load_state_dict(misc_ckpt['optimizer'])
203
+ self.step = misc_ckpt['step']
204
+ self.data_sampler.load_state_dict(misc_ckpt['data_sampler'])
205
+ if self.fp16_mode == 'amp':
206
+ self.scaler.load_state_dict(misc_ckpt['scaler'])
207
+ elif self.fp16_mode == 'inflat_all':
208
+ self.log_scale = misc_ckpt['log_scale']
209
+ if self.lr_scheduler_config is not None:
210
+ self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler'])
211
+ if self.elastic_controller_config is not None:
212
+ self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller'])
213
+ if self.grad_clip is not None and not isinstance(self.grad_clip, float):
214
+ self.grad_clip.load_state_dict(misc_ckpt['grad_clip'])
215
+ del misc_ckpt
216
+
217
+ if self.world_size > 1:
218
+ dist.barrier()
219
+ if self.is_master:
220
+ print(' Done.')
221
+
222
+ if self.world_size > 1:
223
+ self.check_ddp()
224
+
225
+ def save(self):
226
+ """
227
+ Save a checkpoint.
228
+ Should be called only by the rank 0 process.
229
+ """
230
+ assert self.is_master, 'save() should be called only by the rank 0 process.'
231
+ print(f'\nSaving checkpoint at step {self.step}...', end='')
232
+
233
+ model_ckpts = self._master_params_to_state_dicts(self.master_params)
234
+ for name, model_ckpt in model_ckpts.items():
235
+ torch.save(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt'))
236
+
237
+ for i, ema_rate in enumerate(self.ema_rate):
238
+ ema_ckpts = self._master_params_to_state_dicts(self.ema_params[i])
239
+ for name, ema_ckpt in ema_ckpts.items():
240
+ torch.save(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt'))
241
+
242
+ misc_ckpt = {
243
+ 'optimizer': self.optimizer.state_dict(),
244
+ 'step': self.step,
245
+ 'data_sampler': self.data_sampler.state_dict(),
246
+ }
247
+ if self.fp16_mode == 'amp':
248
+ misc_ckpt['scaler'] = self.scaler.state_dict()
249
+ elif self.fp16_mode == 'inflat_all':
250
+ misc_ckpt['log_scale'] = self.log_scale
251
+ if self.lr_scheduler_config is not None:
252
+ misc_ckpt['lr_scheduler'] = self.lr_scheduler.state_dict()
253
+ if self.elastic_controller_config is not None:
254
+ misc_ckpt['elastic_controller'] = self.elastic_controller.state_dict()
255
+ if self.grad_clip is not None and not isinstance(self.grad_clip, float):
256
+ misc_ckpt['grad_clip'] = self.grad_clip.state_dict()
257
+ torch.save(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt'))
258
+ print(' Done.')
259
+
260
+ def finetune_from(self, finetune_ckpt):
261
+ """
262
+ Finetune from a checkpoint.
263
+ Should be called by all processes.
264
+ """
265
+ if self.is_master:
266
+ print('\nFinetuning from:')
267
+ for name, path in finetune_ckpt.items():
268
+ print(f' - {name}: {path}')
269
+
270
+ model_ckpts = {}
271
+ for name, model in self.models.items():
272
+ model_state_dict = model.state_dict()
273
+ if name in finetune_ckpt:
274
+ model_ckpt = torch.load(read_file_dist(finetune_ckpt[name]), map_location=self.device, weights_only=True)
275
+ for k, v in model_ckpt.items():
276
+ if model_ckpt[k].shape != model_state_dict[k].shape:
277
+ if self.is_master:
278
+ print(f'Warning: {k} shape mismatch, {model_ckpt[k].shape} vs {model_state_dict[k].shape}, skipped.')
279
+ model_ckpt[k] = model_state_dict[k]
280
+ model_ckpts[name] = model_ckpt
281
+ model.load_state_dict(model_ckpt)
282
+ if self.fp16_mode == 'inflat_all':
283
+ model.convert_to_fp16()
284
+ else:
285
+ if self.is_master:
286
+ print(f'Warning: {name} not found in finetune_ckpt, skipped.')
287
+ model_ckpts[name] = model_state_dict
288
+ self._state_dicts_to_master_params(self.master_params, model_ckpts)
289
+ del model_ckpts
290
+
291
+ if self.world_size > 1:
292
+ dist.barrier()
293
+ if self.is_master:
294
+ print('Done.')
295
+
296
+ if self.world_size > 1:
297
+ self.check_ddp()
298
+
299
+ def update_ema(self):
300
+ """
301
+ Update exponential moving average.
302
+ Should only be called by the rank 0 process.
303
+ """
304
+ assert self.is_master, 'update_ema() should be called only by the rank 0 process.'
305
+ for i, ema_rate in enumerate(self.ema_rate):
306
+ for master_param, ema_param in zip(self.master_params, self.ema_params[i]):
307
+ ema_param.detach().mul_(ema_rate).add_(master_param, alpha=1.0 - ema_rate)
308
+
309
+ def check_ddp(self):
310
+ """
311
+ Check if DDP is working properly.
312
+ Should be called by all processes.
313
+ """
314
+ if self.is_master:
315
+ print('\nPerforming DDP check...')
316
+
317
+ if self.is_master:
318
+ print('Checking if parameters are consistent across processes...')
319
+ dist.barrier()
320
+ try:
321
+ for p in self.master_params:
322
+ # split to avoid OOM
323
+ for i in range(0, p.numel(), 10000000):
324
+ sub_size = min(10000000, p.numel() - i)
325
+ sub_p = p.detach().view(-1)[i:i+sub_size]
326
+ # gather from all processes
327
+ sub_p_gather = [torch.empty_like(sub_p) for _ in range(self.world_size)]
328
+ dist.all_gather(sub_p_gather, sub_p)
329
+ # check if equal
330
+ assert all([torch.equal(sub_p, sub_p_gather[i]) for i in range(self.world_size)]), 'parameters are not consistent across processes'
331
+ except AssertionError as e:
332
+ if self.is_master:
333
+ print(f'\n\033[91mError: {e}\033[0m')
334
+ print('DDP check failed.')
335
+ raise e
336
+
337
+ dist.barrier()
338
+ if self.is_master:
339
+ print('Done.')
340
+
341
+ def run_step(self, data_list):
342
+ """
343
+ Run a training step.
344
+ """
345
+ step_log = {'loss': {}, 'status': {}}
346
+ amp_context = partial(torch.autocast, device_type='cuda') if self.fp16_mode == 'amp' else nullcontext
347
+ elastic_controller_context = self.elastic_controller.record if self.elastic_controller_config is not None else nullcontext
348
+
349
+ # Train
350
+ losses = []
351
+ statuses = []
352
+ elastic_controller_logs = []
353
+ zero_grad(self.model_params)
354
+ for i, mb_data in enumerate(data_list):
355
+ ## sync at the end of each batch split
356
+ sync_contexts = [self.training_models[name].no_sync for name in self.training_models] if i != len(data_list) - 1 and self.world_size > 1 else [nullcontext]
357
+ with nested_contexts(*sync_contexts), elastic_controller_context():
358
+ with amp_context():
359
+ loss, status = self.training_losses(**mb_data)
360
+ l = loss['loss'] / len(data_list)
361
+ ## backward
362
+ if self.fp16_mode == 'amp':
363
+ self.scaler.scale(l).backward()
364
+ elif self.fp16_mode == 'inflat_all':
365
+ scaled_l = l * (2 ** self.log_scale)
366
+ scaled_l.backward()
367
+ else:
368
+ l.backward()
369
+ ## log
370
+ losses.append(dict_foreach(loss, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
371
+ statuses.append(dict_foreach(status, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
372
+ if self.elastic_controller_config is not None:
373
+ elastic_controller_logs.append(self.elastic_controller.log())
374
+ ## gradient clip
375
+ if self.grad_clip is not None:
376
+ if self.fp16_mode == 'amp':
377
+ self.scaler.unscale_(self.optimizer)
378
+ elif self.fp16_mode == 'inflat_all':
379
+ model_grads_to_master_grads(self.model_params, self.master_params)
380
+ self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
381
+ if isinstance(self.grad_clip, float):
382
+ grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params, self.grad_clip)
383
+ else:
384
+ grad_norm = self.grad_clip(self.master_params)
385
+ if torch.isfinite(grad_norm):
386
+ statuses[-1]['grad_norm'] = grad_norm.item()
387
+ ## step
388
+ if self.fp16_mode == 'amp':
389
+ prev_scale = self.scaler.get_scale()
390
+ self.scaler.step(self.optimizer)
391
+ self.scaler.update()
392
+ elif self.fp16_mode == 'inflat_all':
393
+ prev_scale = 2 ** self.log_scale
394
+ if not any(not p.grad.isfinite().all() for p in self.model_params):
395
+ if self.grad_clip is None:
396
+ model_grads_to_master_grads(self.model_params, self.master_params)
397
+ self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
398
+ self.optimizer.step()
399
+ master_params_to_model_params(self.model_params, self.master_params)
400
+ self.log_scale += self.fp16_scale_growth
401
+ else:
402
+ self.log_scale -= 1
403
+ else:
404
+ prev_scale = 1.0
405
+ if not any(not p.grad.isfinite().all() for p in self.model_params):
406
+ self.optimizer.step()
407
+ else:
408
+ print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m')
409
+ ## adjust learning rate
410
+ if self.lr_scheduler_config is not None:
411
+ statuses[-1]['lr'] = self.lr_scheduler.get_last_lr()[0]
412
+ self.lr_scheduler.step()
413
+
414
+ # Logs
415
+ step_log['loss'] = dict_reduce(losses, lambda x: np.mean(x))
416
+ step_log['status'] = dict_reduce(statuses, lambda x: np.mean(x), special_func={'min': lambda x: np.min(x), 'max': lambda x: np.max(x)})
417
+ if self.elastic_controller_config is not None:
418
+ step_log['elastic'] = dict_reduce(elastic_controller_logs, lambda x: np.mean(x))
419
+ if self.grad_clip is not None:
420
+ step_log['grad_clip'] = self.grad_clip if isinstance(self.grad_clip, float) else self.grad_clip.log()
421
+
422
+ # Check grad and norm of each param
423
+ if self.log_param_stats:
424
+ param_norms = {}
425
+ param_grads = {}
426
+ for name, param in self.backbone.named_parameters():
427
+ if param.requires_grad:
428
+ param_norms[name] = param.norm().item()
429
+ if param.grad is not None and torch.isfinite(param.grad).all():
430
+ param_grads[name] = param.grad.norm().item() / prev_scale
431
+ step_log['param_norms'] = param_norms
432
+ step_log['param_grads'] = param_grads
433
+
434
+ # Update exponential moving average
435
+ if self.is_master:
436
+ self.update_ema()
437
+
438
+ return step_log
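
A sketch of the config dicts consumed by init_models_and_more above: optimizer['name'] is looked up on torch.optim (falling back to globals()), lr_scheduler['name'] on torch.optim.lr_scheduler, and a plain float for grad_clip enables torch.nn.utils.clip_grad_norm_. The concrete values below are placeholders, not the settings used by TRELLIS.

optimizer = {
    'name': 'AdamW',                    # resolved via getattr(torch.optim, 'AdamW')
    'args': {'lr': 1e-4, 'weight_decay': 0.0},
}
lr_scheduler = {
    'name': 'LambdaLR',                 # resolved via getattr(torch.optim.lr_scheduler, 'LambdaLR')
    'args': {'lr_lambda': lambda step: 1.0},
}
grad_clip = 1.0                         # float -> clip_grad_norm_; a dict would name a grad_clip_utils class
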
trellis/trainers/flow_matching/flow_matching.py ADDED
@@ -0,0 +1,353 @@
1
+ from typing import *
2
+ import copy
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import DataLoader
6
+ import numpy as np
7
+ from easydict import EasyDict as edict
8
+
9
+ from ..basic import BasicTrainer
10
+ from ...pipelines import samplers
11
+ from ...utils.general_utils import dict_reduce
12
+ from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin
13
+ from .mixins.text_conditioned import TextConditionedMixin
14
+ from .mixins.image_conditioned import ImageConditionedMixin
15
+
16
+
17
+ class FlowMatchingTrainer(BasicTrainer):
18
+ """
19
+ Trainer for diffusion model with flow matching objective.
20
+
21
+ Args:
22
+ models (dict[str, nn.Module]): Models to train.
23
+ dataset (torch.utils.data.Dataset): Dataset.
24
+ output_dir (str): Output directory.
25
+ load_dir (str): Load directory.
26
+ step (int): Step to load.
27
+ batch_size (int): Batch size.
28
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
29
+ batch_split (int): Split batch with gradient accumulation.
30
+ max_steps (int): Max steps.
31
+ optimizer (dict): Optimizer config.
32
+ lr_scheduler (dict): Learning rate scheduler config.
33
+ elastic (dict): Elastic memory management config.
34
+ grad_clip (float or dict): Gradient clip config.
35
+ ema_rate (float or list): Exponential moving average rates.
36
+ fp16_mode (str): FP16 mode.
37
+ - None: No FP16.
38
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
39
+ - 'amp': Automatic mixed precision.
40
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
41
+ finetune_ckpt (dict): Finetune checkpoint.
42
+ log_param_stats (bool): Log parameter stats.
43
+ i_print (int): Print interval.
44
+ i_log (int): Log interval.
45
+ i_sample (int): Sample interval.
46
+ i_save (int): Save interval.
47
+ i_ddpcheck (int): DDP check interval.
48
+
49
+ t_schedule (dict): Time schedule for flow matching.
50
+ sigma_min (float): Minimum noise level.
51
+ """
52
+ def __init__(
53
+ self,
54
+ *args,
55
+ t_schedule: dict = {
56
+ 'name': 'logitNormal',
57
+ 'args': {
58
+ 'mean': 0.0,
59
+ 'std': 1.0,
60
+ }
61
+ },
62
+ sigma_min: float = 1e-5,
63
+ **kwargs
64
+ ):
65
+ super().__init__(*args, **kwargs)
66
+ self.t_schedule = t_schedule
67
+ self.sigma_min = sigma_min
68
+
69
+ def diffuse(self, x_0: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
70
+ """
71
+ Diffuse the data for a given number of diffusion steps.
72
+ In other words, sample from q(x_t | x_0).
73
+
74
+ Args:
75
+ x_0: The [N x C x ...] tensor of noiseless inputs.
76
+ t: The [N] tensor of diffusion steps [0-1].
77
+ noise: If specified, use this noise instead of generating new noise.
78
+
79
+ Returns:
80
+ x_t, the noisy version of x_0 under timestep t.
81
+ """
82
+ if noise is None:
83
+ noise = torch.randn_like(x_0)
84
+ assert noise.shape == x_0.shape, "noise must have same shape as x_0"
85
+
86
+ t = t.view(-1, *[1 for _ in range(len(x_0.shape) - 1)])
87
+ x_t = (1 - t) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t) * noise
88
+
89
+ return x_t
90
+
91
+ def reverse_diffuse(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
92
+ """
93
+ Get original image from noisy version under timestep t.
94
+ """
95
+ assert noise.shape == x_t.shape, "noise must have same shape as x_t"
96
+ t = t.view(-1, *[1 for _ in range(len(x_t.shape) - 1)])
97
+ x_0 = (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * noise) / (1 - t)
98
+ return x_0
99
+
100
+ def get_v(self, x_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
101
+ """
102
+ Compute the velocity of the diffusion process at time t.
103
+ """
104
+ return (1 - self.sigma_min) * noise - x_0
105
+
106
+ def get_cond(self, cond, **kwargs):
107
+ """
108
+ Get the conditioning data.
109
+ """
110
+ return cond
111
+
112
+ def get_inference_cond(self, cond, **kwargs):
113
+ """
114
+ Get the conditioning data for inference.
115
+ """
116
+ return {'cond': cond, **kwargs}
117
+
118
+ def get_sampler(self, **kwargs) -> samplers.FlowEulerSampler:
119
+ """
120
+ Get the sampler for the diffusion process.
121
+ """
122
+ return samplers.FlowEulerSampler(self.sigma_min)
123
+
124
+ def vis_cond(self, **kwargs):
125
+ """
126
+ Visualize the conditioning data.
127
+ """
128
+ return {}
129
+
130
+ def sample_t(self, batch_size: int) -> torch.Tensor:
131
+ """
132
+ Sample timesteps.
133
+ """
134
+ if self.t_schedule['name'] == 'uniform':
135
+ t = torch.rand(batch_size)
136
+ elif self.t_schedule['name'] == 'logitNormal':
137
+ mean = self.t_schedule['args']['mean']
138
+ std = self.t_schedule['args']['std']
139
+ t = torch.sigmoid(torch.randn(batch_size) * std + mean)
140
+ else:
141
+ raise ValueError(f"Unknown t_schedule: {self.t_schedule['name']}")
142
+ return t
143
+
144
+ def training_losses(
145
+ self,
146
+ x_0: torch.Tensor,
147
+ cond=None,
148
+ **kwargs
149
+ ) -> Tuple[Dict, Dict]:
150
+ """
151
+ Compute training losses for a single timestep.
152
+
153
+ Args:
154
+ x_0: The [N x C x ...] tensor of noiseless inputs.
155
+ cond: The [N x ...] tensor of additional conditions.
156
+ kwargs: Additional arguments to pass to the backbone.
157
+
158
+ Returns:
159
+ a dict with the key "loss" containing a tensor of shape [N].
160
+ may also contain other keys for different terms.
161
+ """
162
+ noise = torch.randn_like(x_0)
163
+ t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
164
+ x_t = self.diffuse(x_0, t, noise=noise)
165
+ cond = self.get_cond(cond, **kwargs)
166
+
167
+ pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
168
+ assert pred.shape == noise.shape == x_0.shape
169
+ target = self.get_v(x_0, noise, t)
170
+ terms = edict()
171
+ terms["mse"] = F.mse_loss(pred, target)
172
+ terms["loss"] = terms["mse"]
173
+
174
+ # log loss with time bins
175
+ mse_per_instance = np.array([
176
+ F.mse_loss(pred[i], target[i]).item()
177
+ for i in range(x_0.shape[0])
178
+ ])
179
+ time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
180
+ for i in range(10):
181
+ if (time_bin == i).sum() != 0:
182
+ terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
183
+
184
+ return terms, {}
185
+
186
+ @torch.no_grad()
187
+ def run_snapshot(
188
+ self,
189
+ num_samples: int,
190
+ batch_size: int,
191
+ verbose: bool = False,
192
+ ) -> Dict:
193
+ dataloader = DataLoader(
194
+ copy.deepcopy(self.dataset),
195
+ batch_size=batch_size,
196
+ shuffle=True,
197
+ num_workers=0,
198
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
199
+ )
200
+
201
+ # inference
202
+ sampler = self.get_sampler()
203
+ sample_gt = []
204
+ sample = []
205
+ cond_vis = []
206
+ for i in range(0, num_samples, batch_size):
207
+ batch = min(batch_size, num_samples - i)
208
+ data = next(iter(dataloader))
209
+ data = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
210
+ noise = torch.randn_like(data['x_0'])
211
+ sample_gt.append(data['x_0'])
212
+ cond_vis.append(self.vis_cond(**data))
213
+ del data['x_0']
214
+ args = self.get_inference_cond(**data)
215
+ res = sampler.sample(
216
+ self.models['denoiser'],
217
+ noise=noise,
218
+ **args,
219
+ steps=50, cfg_strength=3.0, verbose=verbose,
220
+ )
221
+ sample.append(res.samples)
222
+
223
+ sample_gt = torch.cat(sample_gt, dim=0)
224
+ sample = torch.cat(sample, dim=0)
225
+ sample_dict = {
226
+ 'sample_gt': {'value': sample_gt, 'type': 'sample'},
227
+ 'sample': {'value': sample, 'type': 'sample'},
228
+ }
229
+ sample_dict.update(dict_reduce(cond_vis, None, {
230
+ 'value': lambda x: torch.cat(x, dim=0),
231
+ 'type': lambda x: x[0],
232
+ }))
233
+
234
+ return sample_dict
235
+
236
+
237
+ class FlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, FlowMatchingTrainer):
238
+ """
239
+ Trainer for diffusion model with flow matching objective and classifier-free guidance.
240
+
241
+ Args:
242
+ models (dict[str, nn.Module]): Models to train.
243
+ dataset (torch.utils.data.Dataset): Dataset.
244
+ output_dir (str): Output directory.
245
+ load_dir (str): Load directory.
246
+ step (int): Step to load.
247
+ batch_size (int): Batch size.
248
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
249
+ batch_split (int): Split batch with gradient accumulation.
250
+ max_steps (int): Max steps.
251
+ optimizer (dict): Optimizer config.
252
+ lr_scheduler (dict): Learning rate scheduler config.
253
+ elastic (dict): Elastic memory management config.
254
+ grad_clip (float or dict): Gradient clip config.
255
+ ema_rate (float or list): Exponential moving average rates.
256
+ fp16_mode (str): FP16 mode.
257
+ - None: No FP16.
258
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
259
+ - 'amp': Automatic mixed precision.
260
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
261
+ finetune_ckpt (dict): Finetune checkpoint.
262
+ log_param_stats (bool): Log parameter stats.
263
+ i_print (int): Print interval.
264
+ i_log (int): Log interval.
265
+ i_sample (int): Sample interval.
266
+ i_save (int): Save interval.
267
+ i_ddpcheck (int): DDP check interval.
268
+
269
+ t_schedule (dict): Time schedule for flow matching.
270
+ sigma_min (float): Minimum noise level.
271
+ p_uncond (float): Probability of dropping conditions.
272
+ """
273
+ pass
274
+
275
+
276
+ class TextConditionedFlowMatchingCFGTrainer(TextConditionedMixin, FlowMatchingCFGTrainer):
277
+ """
278
+ Trainer for text-conditioned diffusion model with flow matching objective and classifier-free guidance.
279
+
280
+ Args:
281
+ models (dict[str, nn.Module]): Models to train.
282
+ dataset (torch.utils.data.Dataset): Dataset.
283
+ output_dir (str): Output directory.
284
+ load_dir (str): Load directory.
285
+ step (int): Step to load.
286
+ batch_size (int): Batch size.
287
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
288
+ batch_split (int): Split batch with gradient accumulation.
289
+ max_steps (int): Max steps.
290
+ optimizer (dict): Optimizer config.
291
+ lr_scheduler (dict): Learning rate scheduler config.
292
+ elastic (dict): Elastic memory management config.
293
+ grad_clip (float or dict): Gradient clip config.
294
+ ema_rate (float or list): Exponential moving average rates.
295
+ fp16_mode (str): FP16 mode.
296
+ - None: No FP16.
297
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
298
+ - 'amp': Automatic mixed precision.
299
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
300
+ finetune_ckpt (dict): Finetune checkpoint.
301
+ log_param_stats (bool): Log parameter stats.
302
+ i_print (int): Print interval.
303
+ i_log (int): Log interval.
304
+ i_sample (int): Sample interval.
305
+ i_save (int): Save interval.
306
+ i_ddpcheck (int): DDP check interval.
307
+
308
+ t_schedule (dict): Time schedule for flow matching.
309
+ sigma_min (float): Minimum noise level.
310
+ p_uncond (float): Probability of dropping conditions.
311
+ text_cond_model (str): Text conditioning model.
312
+ """
313
+ pass
314
+
315
+
316
+ class ImageConditionedFlowMatchingCFGTrainer(ImageConditionedMixin, FlowMatchingCFGTrainer):
317
+ """
318
+ Trainer for image-conditioned diffusion model with flow matching objective and classifier-free guidance.
319
+
320
+ Args:
321
+ models (dict[str, nn.Module]): Models to train.
322
+ dataset (torch.utils.data.Dataset): Dataset.
323
+ output_dir (str): Output directory.
324
+ load_dir (str): Load directory.
325
+ step (int): Step to load.
326
+ batch_size (int): Batch size.
327
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
328
+ batch_split (int): Split batch with gradient accumulation.
329
+ max_steps (int): Max steps.
330
+ optimizer (dict): Optimizer config.
331
+ lr_scheduler (dict): Learning rate scheduler config.
332
+ elastic (dict): Elastic memory management config.
333
+ grad_clip (float or dict): Gradient clip config.
334
+ ema_rate (float or list): Exponential moving average rates.
335
+ fp16_mode (str): FP16 mode.
336
+ - None: No FP16.
337
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
338
+ - 'amp': Automatic mixed precision.
339
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
340
+ finetune_ckpt (dict): Finetune checkpoint.
341
+ log_param_stats (bool): Log parameter stats.
342
+ i_print (int): Print interval.
343
+ i_log (int): Log interval.
344
+ i_sample (int): Sample interval.
345
+ i_save (int): Save interval.
346
+ i_ddpcheck (int): DDP check interval.
347
+
348
+ t_schedule (dict): Time schedule for flow matching.
349
+ sigma_min (float): Minimum noise level.
350
+ p_uncond (float): Probability of dropping conditions.
351
+ image_cond_model (str): Image conditioning model.
352
+ """
353
+ pass
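
A short numerical sketch (synthetic tensors, not from this commit) of the flow-matching relations used above: diffuse() interpolates x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * noise, and get_v() returns the velocity d x_t / d t = (1 - sigma_min) * noise - x_0 onto which the denoiser is regressed with an MSE loss.

import torch

sigma_min = 1e-5
x_0 = torch.randn(2, 3)
noise = torch.randn(2, 3)
t = torch.tensor([0.25, 0.75]).view(-1, 1)

x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * noise
v = (1 - sigma_min) * noise - x_0        # target velocity; note it does not depend on t

# finite-difference check that v equals d x_t / d t (exact up to float rounding, since x_t is linear in t)
dt = 1e-2
x_t_dt = (1 - (t + dt)) * x_0 + (sigma_min + (1 - sigma_min) * (t + dt)) * noise
assert torch.allclose((x_t_dt - x_t) / dt, v, atol=1e-3)
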
trellis/trainers/flow_matching/mixins/classifier_free_guidance.py ADDED
@@ -0,0 +1,59 @@
1
+ import torch
2
+ import numpy as np
3
+ from ....utils.general_utils import dict_foreach
4
+ from ....pipelines import samplers
5
+
6
+
7
+ class ClassifierFreeGuidanceMixin:
8
+ def __init__(self, *args, p_uncond: float = 0.1, **kwargs):
9
+ super().__init__(*args, **kwargs)
10
+ self.p_uncond = p_uncond
11
+
12
+ def get_cond(self, cond, neg_cond=None, **kwargs):
13
+ """
14
+ Get the conditioning data.
15
+ """
16
+ assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
17
+
18
+ if self.p_uncond > 0:
19
+ # randomly drop the class label
20
+ def get_batch_size(cond):
21
+ if isinstance(cond, torch.Tensor):
22
+ return cond.shape[0]
23
+ elif isinstance(cond, list):
24
+ return len(cond)
25
+ else:
26
+ raise ValueError(f"Unsupported type of cond: {type(cond)}")
27
+
28
+ ref_cond = cond if not isinstance(cond, dict) else cond[list(cond.keys())[0]]
29
+ B = get_batch_size(ref_cond)
30
+
31
+ def select(cond, neg_cond, mask):
32
+ if isinstance(cond, torch.Tensor):
33
+ mask = torch.tensor(mask, device=cond.device).reshape(-1, *[1] * (cond.ndim - 1))
34
+ return torch.where(mask, neg_cond, cond)
35
+ elif isinstance(cond, list):
36
+ return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)]
37
+ else:
38
+ raise ValueError(f"Unsupported type of cond: {type(cond)}")
39
+
40
+ mask = list(np.random.rand(B) < self.p_uncond)
41
+ if not isinstance(cond, dict):
42
+ cond = select(cond, neg_cond, mask)
43
+ else:
44
+ cond = dict_foreach([cond, neg_cond], lambda x: select(x[0], x[1], mask))
45
+
46
+ return cond
47
+
48
+ def get_inference_cond(self, cond, neg_cond=None, **kwargs):
49
+ """
50
+ Get the conditioning data for inference.
51
+ """
52
+ assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
53
+ return {'cond': cond, 'neg_cond': neg_cond, **kwargs}
54
+
55
+ def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler:
56
+ """
57
+ Get the sampler for the diffusion process.
58
+ """
59
+ return samplers.FlowEulerCfgSampler(self.sigma_min)
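
A small sketch (synthetic tensors, assumed shapes) of what get_cond above does for tensor conditions: with probability p_uncond a sample's condition row is replaced by the negative/null condition, so the same model learns both conditional and unconditional predictions for classifier-free guidance at sampling time.

import numpy as np
import torch

p_uncond = 0.1
cond = torch.randn(8, 77, 768)         # e.g. a batch of text embeddings (shape assumed)
neg_cond = torch.zeros_like(cond)      # null condition

mask = torch.tensor(np.random.rand(8) < p_uncond).reshape(-1, 1, 1)
cond = torch.where(mask, neg_cond, cond)   # roughly 10% of rows become the null condition
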
trellis/trainers/flow_matching/mixins/image_conditioned.py ADDED
@@ -0,0 +1,93 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+ from ....utils import dist_utils
9
+
10
+
11
+ class ImageConditionedMixin:
12
+ """
13
+ Mixin for image-conditioned models.
14
+
15
+ Args:
16
+ image_cond_model: The image conditioning model.
17
+ """
18
+ def __init__(self, *args, image_cond_model: str = 'dinov2_vitl14_reg', **kwargs):
19
+ super().__init__(*args, **kwargs)
20
+ self.image_cond_model_name = image_cond_model
21
+ self.image_cond_model = None # the model is init lazily
22
+
23
+ @staticmethod
24
+ def prepare_for_training(image_cond_model: str, **kwargs):
25
+ """
26
+ Prepare for training.
27
+ """
28
+ if hasattr(super(ImageConditionedMixin, ImageConditionedMixin), 'prepare_for_training'):
29
+ super(ImageConditionedMixin, ImageConditionedMixin).prepare_for_training(**kwargs)
30
+ # download the model
31
+ torch.hub.load('facebookresearch/dinov2', image_cond_model, pretrained=True)
32
+
33
+ def _init_image_cond_model(self):
34
+ """
35
+ Initialize the image conditioning model.
36
+ """
37
+ with dist_utils.local_master_first():
38
+ dinov2_model = torch.hub.load('facebookresearch/dinov2', self.image_cond_model_name, pretrained=True)
39
+ dinov2_model.eval().cuda()
40
+ transform = transforms.Compose([
41
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
42
+ ])
43
+ self.image_cond_model = {
44
+ 'model': dinov2_model,
45
+ 'transform': transform,
46
+ }
47
+
48
+ @torch.no_grad()
49
+ def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
50
+ """
51
+ Encode the image.
52
+ """
53
+ if isinstance(image, torch.Tensor):
54
+ assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
55
+ elif isinstance(image, list):
56
+ assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
57
+ image = [i.resize((518, 518), Image.LANCZOS) for i in image]
58
+ image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
59
+ image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
60
+ image = torch.stack(image).cuda()
61
+ else:
62
+ raise ValueError(f"Unsupported type of image: {type(image)}")
63
+
64
+ if self.image_cond_model is None:
65
+ self._init_image_cond_model()
66
+ image = self.image_cond_model['transform'](image).cuda()
67
+ features = self.image_cond_model['model'](image, is_training=True)['x_prenorm']
68
+ patchtokens = F.layer_norm(features, features.shape[-1:])
69
+ return patchtokens
70
+
71
+ def get_cond(self, cond, **kwargs):
72
+ """
73
+ Get the conditioning data.
74
+ """
75
+ cond = self.encode_image(cond)
76
+ kwargs['neg_cond'] = torch.zeros_like(cond)
77
+ cond = super().get_cond(cond, **kwargs)
78
+ return cond
79
+
80
+ def get_inference_cond(self, cond, **kwargs):
81
+ """
82
+ Get the conditioning data for inference.
83
+ """
84
+ cond = self.encode_image(cond)
85
+ kwargs['neg_cond'] = torch.zeros_like(cond)
86
+ cond = super().get_inference_cond(cond, **kwargs)
87
+ return cond
88
+
89
+ def vis_cond(self, cond, **kwargs):
90
+ """
91
+ Visualize the conditioning data.
92
+ """
93
+ return {'image': {'value': cond, 'type': 'image'}}
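
A usage sketch for encode_image above (hypothetical file names; requires a CUDA device and downloads the DINOv2 weights via torch.hub on first call):

from PIL import Image

images = [Image.open('a.png'), Image.open('b.png')]   # any RGB images
tokens = trainer.encode_image(images)                  # `trainer` is an ImageConditioned* trainer instance
# tokens: (B, num_tokens, feat_dim) layer-normalized DINOv2 token features used as the denoiser's conditioning input
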
trellis/trainers/flow_matching/mixins/text_conditioned.py ADDED
@@ -0,0 +1,68 @@
1
+ from typing import *
2
+ import os
3
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
4
+ import torch
5
+ from transformers import AutoTokenizer, CLIPTextModel
6
+
7
+ from ....utils import dist_utils
8
+
9
+
10
+ class TextConditionedMixin:
11
+ """
12
+ Mixin for text-conditioned models.
13
+
14
+ Args:
15
+ text_cond_model: The text conditioning model.
16
+ """
17
+ def __init__(self, *args, text_cond_model: str = 'openai/clip-vit-large-patch14', **kwargs):
18
+ super().__init__(*args, **kwargs)
19
+ self.text_cond_model_name = text_cond_model
20
+ self.text_cond_model = None # the model is initialized lazily
21
+
22
+ def _init_text_cond_model(self):
23
+ """
24
+ Initialize the text conditioning model.
25
+ """
26
+ # load model
27
+ with dist_utils.local_master_first():
28
+ model = CLIPTextModel.from_pretrained(self.text_cond_model_name)
29
+ tokenizer = AutoTokenizer.from_pretrained(self.text_cond_model_name)
30
+ model.eval()
31
+ model = model.cuda()
32
+ self.text_cond_model = {
33
+ 'model': model,
34
+ 'tokenizer': tokenizer,
35
+ }
36
+ self.text_cond_model['null_cond'] = self.encode_text([''])
37
+
38
+ @torch.no_grad()
39
+ def encode_text(self, text: List[str]) -> torch.Tensor:
40
+ """
41
+ Encode the text.
42
+ """
43
+ assert isinstance(text, list) and isinstance(text[0], str), "TextConditionedMixin only supports list of strings as cond"
44
+ if self.text_cond_model is None:
45
+ self._init_text_cond_model()
46
+ encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
47
+ tokens = encoding['input_ids'].cuda()
48
+ embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
49
+
50
+ return embeddings
51
+
52
+ def get_cond(self, cond, **kwargs):
53
+ """
54
+ Get the conditioning data.
55
+ """
56
+ cond = self.encode_text(cond)
57
+ kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
58
+ cond = super().get_cond(cond, **kwargs)
59
+ return cond
60
+
61
+ def get_inference_cond(self, cond, **kwargs):
62
+ """
63
+ Get the conditioning data for inference.
64
+ """
65
+ cond = self.encode_text(cond)
66
+ kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
67
+ cond = super().get_inference_cond(cond, **kwargs)
68
+ return cond
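For reference, the text path tokenizes prompts to a fixed length of 77, takes `last_hidden_state` from the CLIP text encoder, and reuses the embedding of the empty string as the null condition for classifier-free guidance. A minimal sketch under the same assumptions (standalone, not wired into a trainer):

import torch
from transformers import AutoTokenizer, CLIPTextModel

name = 'openai/clip-vit-large-patch14'
tokenizer = AutoTokenizer.from_pretrained(name)
model = CLIPTextModel.from_pretrained(name).eval().cuda()

@torch.no_grad()
def encode(prompts):
    # Fixed-length tokenization, as in TextConditionedMixin.encode_text
    tokens = tokenizer(prompts, max_length=77, padding='max_length',
                       truncation=True, return_tensors='pt')['input_ids'].cuda()
    return model(input_ids=tokens).last_hidden_state  # [B, 77, 768]

cond = encode(['a wooden chair'])                    # positive condition
neg_cond = encode(['']).repeat(cond.shape[0], 1, 1)  # null condition used for CFG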
trellis/trainers/flow_matching/sparse_flow_matching.py ADDED
@@ -0,0 +1,286 @@
1
+ from typing import *
2
+ import os
3
+ import copy
4
+ import functools
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch.utils.data import DataLoader
8
+ import numpy as np
9
+ from easydict import EasyDict as edict
10
+
11
+ from ...modules import sparse as sp
12
+ from ...utils.general_utils import dict_reduce
13
+ from ...utils.data_utils import cycle, BalancedResumableSampler
14
+ from .flow_matching import FlowMatchingTrainer
15
+ from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin
16
+ from .mixins.text_conditioned import TextConditionedMixin
17
+ from .mixins.image_conditioned import ImageConditionedMixin
18
+
19
+
20
+ class SparseFlowMatchingTrainer(FlowMatchingTrainer):
21
+ """
22
+ Trainer for sparse diffusion model with flow matching objective.
23
+
24
+ Args:
25
+ models (dict[str, nn.Module]): Models to train.
26
+ dataset (torch.utils.data.Dataset): Dataset.
27
+ output_dir (str): Output directory.
28
+ load_dir (str): Load directory.
29
+ step (int): Step to load.
30
+ batch_size (int): Batch size.
31
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
32
+ batch_split (int): Split batch with gradient accumulation.
33
+ max_steps (int): Max steps.
34
+ optimizer (dict): Optimizer config.
35
+ lr_scheduler (dict): Learning rate scheduler config.
36
+ elastic (dict): Elastic memory management config.
37
+ grad_clip (float or dict): Gradient clip config.
38
+ ema_rate (float or list): Exponential moving average rates.
39
+ fp16_mode (str): FP16 mode.
40
+ - None: No FP16.
41
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
42
+ - 'amp': Automatic mixed precision.
43
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
44
+ finetune_ckpt (dict): Finetune checkpoint.
45
+ log_param_stats (bool): Log parameter stats.
46
+ i_print (int): Print interval.
47
+ i_log (int): Log interval.
48
+ i_sample (int): Sample interval.
49
+ i_save (int): Save interval.
50
+ i_ddpcheck (int): DDP check interval.
51
+
52
+ t_schedule (dict): Time schedule for flow matching.
53
+ sigma_min (float): Minimum noise level.
54
+ """
55
+
56
+ def prepare_dataloader(self, **kwargs):
57
+ """
58
+ Prepare dataloader.
59
+ """
60
+ self.data_sampler = BalancedResumableSampler(
61
+ self.dataset,
62
+ shuffle=True,
63
+ batch_size=self.batch_size_per_gpu,
64
+ )
65
+ self.dataloader = DataLoader(
66
+ self.dataset,
67
+ batch_size=self.batch_size_per_gpu,
68
+ num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
69
+ pin_memory=True,
70
+ drop_last=True,
71
+ persistent_workers=True,
72
+ collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split),
73
+ sampler=self.data_sampler,
74
+ )
75
+ self.data_iterator = cycle(self.dataloader)
76
+
77
+ def training_losses(
78
+ self,
79
+ x_0: sp.SparseTensor,
80
+ cond=None,
81
+ **kwargs
82
+ ) -> Tuple[Dict, Dict]:
83
+ """
84
+ Compute training losses for a single timestep.
85
+
86
+ Args:
87
+ x_0: The [N x ... x C] sparse tensor of the inputs.
88
+ cond: The [N x ...] tensor of additional conditions.
89
+ kwargs: Additional arguments to pass to the backbone.
90
+
91
+ Returns:
92
+ a dict with the key "loss" containing a tensor of shape [N].
93
+ may also contain other keys for different terms.
94
+ """
95
+ noise = x_0.replace(torch.randn_like(x_0.feats))
96
+ t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
97
+ x_t = self.diffuse(x_0, t, noise=noise)
98
+ cond = self.get_cond(cond, **kwargs)
99
+
100
+ pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
101
+ assert pred.shape == noise.shape == x_0.shape
102
+ target = self.get_v(x_0, noise, t)
103
+ terms = edict()
104
+ terms["mse"] = F.mse_loss(pred.feats, target.feats)
105
+ terms["loss"] = terms["mse"]
106
+
107
+ # log loss with time bins
108
+ mse_per_instance = np.array([
109
+ F.mse_loss(pred.feats[x_0.layout[i]], target.feats[x_0.layout[i]]).item()
110
+ for i in range(x_0.shape[0])
111
+ ])
112
+ time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
113
+ for i in range(10):
114
+ if (time_bin == i).sum() != 0:
115
+ terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
116
+
117
+ return terms, {}
118
+
119
+ @torch.no_grad()
120
+ def run_snapshot(
121
+ self,
122
+ num_samples: int,
123
+ batch_size: int,
124
+ verbose: bool = False,
125
+ ) -> Dict:
126
+ dataloader = DataLoader(
127
+ copy.deepcopy(self.dataset),
128
+ batch_size=batch_size,
129
+ shuffle=True,
130
+ num_workers=0,
131
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
132
+ )
133
+
134
+ # inference
135
+ sampler = self.get_sampler()
136
+ sample_gt = []
137
+ sample = []
138
+ cond_vis = []
139
+ for i in range(0, num_samples, batch_size):
140
+ batch = min(batch_size, num_samples - i)
141
+ data = next(iter(dataloader))
142
+ data = {k: v[:batch].cuda() if not isinstance(v, list) else v[:batch] for k, v in data.items()}
143
+ noise = data['x_0'].replace(torch.randn_like(data['x_0'].feats))
144
+ sample_gt.append(data['x_0'])
145
+ cond_vis.append(self.vis_cond(**data))
146
+ del data['x_0']
147
+ args = self.get_inference_cond(**data)
148
+ res = sampler.sample(
149
+ self.models['denoiser'],
150
+ noise=noise,
151
+ **args,
152
+ steps=50, cfg_strength=3.0, verbose=verbose,
153
+ )
154
+ sample.append(res.samples)
155
+
156
+ sample_gt = sp.sparse_cat(sample_gt)
157
+ sample = sp.sparse_cat(sample)
158
+ sample_dict = {
159
+ 'sample_gt': {'value': sample_gt, 'type': 'sample'},
160
+ 'sample': {'value': sample, 'type': 'sample'},
161
+ }
162
+ sample_dict.update(dict_reduce(cond_vis, None, {
163
+ 'value': lambda x: torch.cat(x, dim=0),
164
+ 'type': lambda x: x[0],
165
+ }))
166
+
167
+ return sample_dict
168
+
169
+
170
+ class SparseFlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, SparseFlowMatchingTrainer):
171
+ """
172
+ Trainer for sparse diffusion model with flow matching objective and classifier-free guidance.
173
+
174
+ Args:
175
+ models (dict[str, nn.Module]): Models to train.
176
+ dataset (torch.utils.data.Dataset): Dataset.
177
+ output_dir (str): Output directory.
178
+ load_dir (str): Load directory.
179
+ step (int): Step to load.
180
+ batch_size (int): Batch size.
181
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
182
+ batch_split (int): Split batch with gradient accumulation.
183
+ max_steps (int): Max steps.
184
+ optimizer (dict): Optimizer config.
185
+ lr_scheduler (dict): Learning rate scheduler config.
186
+ elastic (dict): Elastic memory management config.
187
+ grad_clip (float or dict): Gradient clip config.
188
+ ema_rate (float or list): Exponential moving average rates.
189
+ fp16_mode (str): FP16 mode.
190
+ - None: No FP16.
191
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
192
+ - 'amp': Automatic mixed precision.
193
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
194
+ finetune_ckpt (dict): Finetune checkpoint.
195
+ log_param_stats (bool): Log parameter stats.
196
+ i_print (int): Print interval.
197
+ i_log (int): Log interval.
198
+ i_sample (int): Sample interval.
199
+ i_save (int): Save interval.
200
+ i_ddpcheck (int): DDP check interval.
201
+
202
+ t_schedule (dict): Time schedule for flow matching.
203
+ sigma_min (float): Minimum noise level.
204
+ p_uncond (float): Probability of dropping conditions.
205
+ """
206
+ pass
207
+
208
+
209
+ class TextConditionedSparseFlowMatchingCFGTrainer(TextConditionedMixin, SparseFlowMatchingCFGTrainer):
210
+ """
211
+ Trainer for sparse text-conditioned diffusion model with flow matching objective and classifier-free guidance.
212
+
213
+ Args:
214
+ models (dict[str, nn.Module]): Models to train.
215
+ dataset (torch.utils.data.Dataset): Dataset.
216
+ output_dir (str): Output directory.
217
+ load_dir (str): Load directory.
218
+ step (int): Step to load.
219
+ batch_size (int): Batch size.
220
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
221
+ batch_split (int): Split batch with gradient accumulation.
222
+ max_steps (int): Max steps.
223
+ optimizer (dict): Optimizer config.
224
+ lr_scheduler (dict): Learning rate scheduler config.
225
+ elastic (dict): Elastic memory management config.
226
+ grad_clip (float or dict): Gradient clip config.
227
+ ema_rate (float or list): Exponential moving average rates.
228
+ fp16_mode (str): FP16 mode.
229
+ - None: No FP16.
230
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
231
+ - 'amp': Automatic mixed precision.
232
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
233
+ finetune_ckpt (dict): Finetune checkpoint.
234
+ log_param_stats (bool): Log parameter stats.
235
+ i_print (int): Print interval.
236
+ i_log (int): Log interval.
237
+ i_sample (int): Sample interval.
238
+ i_save (int): Save interval.
239
+ i_ddpcheck (int): DDP check interval.
240
+
241
+ t_schedule (dict): Time schedule for flow matching.
242
+ sigma_min (float): Minimum noise level.
243
+ p_uncond (float): Probability of dropping conditions.
244
+ text_cond_model(str): Text conditioning model.
245
+ """
246
+ pass
247
+
248
+
249
+ class ImageConditionedSparseFlowMatchingCFGTrainer(ImageConditionedMixin, SparseFlowMatchingCFGTrainer):
250
+ """
251
+ Trainer for sparse image-conditioned diffusion model with flow matching objective and classifier-free guidance.
252
+
253
+ Args:
254
+ models (dict[str, nn.Module]): Models to train.
255
+ dataset (torch.utils.data.Dataset): Dataset.
256
+ output_dir (str): Output directory.
257
+ load_dir (str): Load directory.
258
+ step (int): Step to load.
259
+ batch_size (int): Batch size.
260
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
261
+ batch_split (int): Split batch with gradient accumulation.
262
+ max_steps (int): Max steps.
263
+ optimizer (dict): Optimizer config.
264
+ lr_scheduler (dict): Learning rate scheduler config.
265
+ elastic (dict): Elastic memory management config.
266
+ grad_clip (float or dict): Gradient clip config.
267
+ ema_rate (float or list): Exponential moving average rates.
268
+ fp16_mode (str): FP16 mode.
269
+ - None: No FP16.
270
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
271
+ - 'amp': Automatic mixed precision.
272
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
273
+ finetune_ckpt (dict): Finetune checkpoint.
274
+ log_param_stats (bool): Log parameter stats.
275
+ i_print (int): Print interval.
276
+ i_log (int): Log interval.
277
+ i_sample (int): Sample interval.
278
+ i_save (int): Save interval.
279
+ i_ddpcheck (int): DDP check interval.
280
+
281
+ t_schedule (dict): Time schedule for flow matching.
282
+ sigma_min (float): Minimum noise level.
283
+ p_uncond (float): Probability of dropping conditions.
284
+ image_cond_model (str): Image conditioning model.
285
+ """
286
+ pass
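One detail worth noting in training_losses above is the per-bin loss logging: sampled timesteps in [0, 1] are bucketed into ten equal bins with np.digitize, and the mean per-instance MSE is reported for each occupied bin. A small self-contained sketch of that bookkeeping with toy values (no sparse tensors involved):

import numpy as np

t = np.random.rand(8)                 # sampled flow-matching timesteps in [0, 1]
mse_per_instance = np.random.rand(8)  # stand-in for per-instance MSE values

# Same bucketing as the trainer: 10 bins over [0, 1], indices 0..9
time_bin = np.digitize(t, np.linspace(0, 1, 11)) - 1
for i in range(10):
    mask = time_bin == i
    if mask.sum() != 0:
        print(f"bin_{i}: mse={mse_per_instance[mask].mean():.4f}")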
trellis/trainers/utils.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ # FP16 utils
5
+ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
6
+
7
+ def make_master_params(model_params):
8
+ """
9
+ Copy model parameters into an inflated tensor of full-precision parameters.
10
+ """
11
+ master_params = _flatten_dense_tensors(
12
+ [param.detach().float() for param in model_params]
13
+ )
14
+ master_params = nn.Parameter(master_params)
15
+ master_params.requires_grad = True
16
+ return [master_params]
17
+
18
+
19
+ def unflatten_master_params(model_params, master_params):
20
+ """
21
+ Unflatten the master parameters to look like model_params.
22
+ """
23
+ return _unflatten_dense_tensors(master_params[0].detach(), model_params)
24
+
25
+
26
+ def model_params_to_master_params(model_params, master_params):
27
+ """
28
+ Copy the model parameter data into the master parameters.
29
+ """
30
+ master_params[0].detach().copy_(
31
+ _flatten_dense_tensors([param.detach().float() for param in model_params])
32
+ )
33
+
34
+
35
+ def master_params_to_model_params(model_params, master_params):
36
+ """
37
+ Copy the master parameter data back into the model parameters.
38
+ """
39
+ for param, master_param in zip(
40
+ model_params, _unflatten_dense_tensors(master_params[0].detach(), model_params)
41
+ ):
42
+ param.detach().copy_(master_param)
43
+
44
+
45
+ def model_grads_to_master_grads(model_params, master_params):
46
+ """
47
+ Copy the gradients from the model parameters into the master parameters
48
+ from make_master_params().
49
+ """
50
+ master_params[0].grad = _flatten_dense_tensors(
51
+ [param.grad.data.detach().float() for param in model_params]
52
+ )
53
+
54
+
55
+ def zero_grad(model_params):
56
+ for param in model_params:
57
+ if param.grad is not None:
58
+ if param.grad.grad_fn is not None:
59
+ param.grad.detach_()
60
+ else:
61
+ param.grad.requires_grad_(False)
62
+ param.grad.zero_()
63
+
64
+
65
+ # LR Schedulers
66
+ from torch.optim.lr_scheduler import LambdaLR
67
+
68
+ class LinearWarmupLRScheduler(LambdaLR):
69
+ def __init__(self, optimizer, warmup_steps, last_epoch=-1):
70
+ self.warmup_steps = warmup_steps
71
+ super(LinearWarmupLRScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
72
+
73
+ def lr_lambda(self, current_step):
74
+ if current_step < self.warmup_steps:
75
+ return float(current_step + 1) / self.warmup_steps
76
+ return 1.0
77
+
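These helpers implement the flattened fp32 master-parameter pattern behind the 'inflat_all' FP16 mode: gradients from the model parameters are flattened into a single fp32 tensor, the optimizer steps on that tensor, and the updated values are copied back. A minimal round-trip sketch with a toy module (illustrative only; the import path assumes the package layout of this upload):

import torch
import torch.nn as nn
from trellis.trainers.utils import (
    make_master_params, model_grads_to_master_grads,
    master_params_to_model_params, zero_grad, LinearWarmupLRScheduler,
)

model = nn.Linear(4, 4)
model_params = list(model.parameters())

master_params = make_master_params(model_params)          # one flat fp32 nn.Parameter
opt = torch.optim.AdamW(master_params, lr=1e-4)
sched = LinearWarmupLRScheduler(opt, warmup_steps=1000)    # linear ramp, then constant

loss = model(torch.randn(2, 4)).sum()
loss.backward()
model_grads_to_master_grads(model_params, master_params)   # per-param grads -> flat master grad
opt.step()
sched.step()
master_params_to_model_params(model_params, master_params) # copy updated values back
zero_grad(model_params)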
trellis/trainers/vae/sparse_structure_vae.py ADDED
@@ -0,0 +1,130 @@
1
+ from typing import *
2
+ import copy
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import DataLoader
6
+ from easydict import EasyDict as edict
7
+
8
+ from ..basic import BasicTrainer
9
+
10
+
11
+ class SparseStructureVaeTrainer(BasicTrainer):
12
+ """
13
+ Trainer for Sparse Structure VAE.
14
+
15
+ Args:
16
+ models (dict[str, nn.Module]): Models to train.
17
+ dataset (torch.utils.data.Dataset): Dataset.
18
+ output_dir (str): Output directory.
19
+ load_dir (str): Load directory.
20
+ step (int): Step to load.
21
+ batch_size (int): Batch size.
22
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
23
+ batch_split (int): Split batch with gradient accumulation.
24
+ max_steps (int): Max steps.
25
+ optimizer (dict): Optimizer config.
26
+ lr_scheduler (dict): Learning rate scheduler config.
27
+ elastic (dict): Elastic memory management config.
28
+ grad_clip (float or dict): Gradient clip config.
29
+ ema_rate (float or list): Exponential moving average rates.
30
+ fp16_mode (str): FP16 mode.
31
+ - None: No FP16.
32
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
33
+ - 'amp': Automatic mixed precision.
34
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
35
+ finetune_ckpt (dict): Finetune checkpoint.
36
+ log_param_stats (bool): Log parameter stats.
37
+ i_print (int): Print interval.
38
+ i_log (int): Log interval.
39
+ i_sample (int): Sample interval.
40
+ i_save (int): Save interval.
41
+ i_ddpcheck (int): DDP check interval.
42
+
43
+ loss_type (str): Loss type. 'bce' for binary cross entropy, 'l1' for L1 loss, 'dice' for Dice loss.
44
+ lambda_kl (float): KL divergence loss weight.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ *args,
50
+ loss_type='bce',
51
+ lambda_kl=1e-6,
52
+ **kwargs
53
+ ):
54
+ super().__init__(*args, **kwargs)
55
+ self.loss_type = loss_type
56
+ self.lambda_kl = lambda_kl
57
+
58
+ def training_losses(
59
+ self,
60
+ ss: torch.Tensor,
61
+ **kwargs
62
+ ) -> Tuple[Dict, Dict]:
63
+ """
64
+ Compute training losses.
65
+
66
+ Args:
67
+ ss: The [N x 1 x H x W x D] tensor of binary sparse structure.
68
+
69
+ Returns:
70
+ a dict with the key "loss" containing a scalar tensor.
71
+ may also contain other keys for different terms.
72
+ """
73
+ z, mean, logvar = self.training_models['encoder'](ss.float(), sample_posterior=True, return_raw=True)
74
+ logits = self.training_models['decoder'](z)
75
+
76
+ terms = edict(loss = 0.0)
77
+ if self.loss_type == 'bce':
78
+ terms["bce"] = F.binary_cross_entropy_with_logits(logits, ss.float(), reduction='mean')
79
+ terms["loss"] = terms["loss"] + terms["bce"]
80
+ elif self.loss_type == 'l1':
81
+ terms["l1"] = F.l1_loss(F.sigmoid(logits), ss.float(), reduction='mean')
82
+ terms["loss"] = terms["loss"] + terms["l1"]
83
+ elif self.loss_type == 'dice':
84
+ logits = F.sigmoid(logits)
85
+ terms["dice"] = 1 - (2 * (logits * ss.float()).sum() + 1) / (logits.sum() + ss.float().sum() + 1)
86
+ terms["loss"] = terms["loss"] + terms["dice"]
87
+ else:
88
+ raise ValueError(f'Invalid loss type {self.loss_type}')
89
+ terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
90
+ terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]
91
+
92
+ return terms, {}
93
+
94
+ @torch.no_grad()
95
+ def snapshot(self, suffix=None, num_samples=64, batch_size=1, verbose=False):
96
+ super().snapshot(suffix=suffix, num_samples=num_samples, batch_size=batch_size, verbose=verbose)
97
+
98
+ @torch.no_grad()
99
+ def run_snapshot(
100
+ self,
101
+ num_samples: int,
102
+ batch_size: int,
103
+ verbose: bool = False,
104
+ ) -> Dict:
105
+ dataloader = DataLoader(
106
+ copy.deepcopy(self.dataset),
107
+ batch_size=batch_size,
108
+ shuffle=True,
109
+ num_workers=0,
110
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
111
+ )
112
+
113
+ # inference
114
+ gts = []
115
+ recons = []
116
+ for i in range(0, num_samples, batch_size):
117
+ batch = min(batch_size, num_samples - i)
118
+ data = next(iter(dataloader))
119
+ args = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
120
+ z = self.models['encoder'](args['ss'].float(), sample_posterior=False)
121
+ logits = self.models['decoder'](z)
122
+ recon = (logits > 0).long()
123
+ gts.append(args['ss'])
124
+ recons.append(recon)
125
+
126
+ sample_dict = {
127
+ 'gt': {'value': torch.cat(gts, dim=0), 'type': 'sample'},
128
+ 'recon': {'value': torch.cat(recons, dim=0), 'type': 'sample'},
129
+ }
130
+ return sample_dict
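The loss head above pairs a reconstruction term (BCE on logits, L1 on probabilities, or a soft Dice loss) with a KL penalty 0.5 * mean(mu^2 + exp(logvar) - logvar - 1) weighted by lambda_kl. A toy sketch of the 'dice' branch and the KL term on random tensors, mirroring the code:

import torch

logits = torch.randn(2, 1, 16, 16, 16)        # decoder logits
ss = (torch.rand(2, 1, 16, 16, 16) > 0.5)     # binary sparse-structure target
mean, logvar = torch.randn(2, 8), torch.randn(2, 8)

probs = torch.sigmoid(logits)
dice = 1 - (2 * (probs * ss.float()).sum() + 1) / (probs.sum() + ss.float().sum() + 1)
kl = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
loss = dice + 1e-6 * kl                       # lambda_kl defaults to 1e-6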
trellis/trainers/vae/structured_latent_vae_gaussian.py ADDED
@@ -0,0 +1,275 @@
1
+ from typing import *
2
+ import copy
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ import numpy as np
6
+ from easydict import EasyDict as edict
7
+ import utils3d.torch
8
+
9
+ from ..basic import BasicTrainer
10
+ from ...representations import Gaussian
11
+ from ...renderers import GaussianRenderer
12
+ from ...modules.sparse import SparseTensor
13
+ from ...utils.loss_utils import l1_loss, l2_loss, ssim, lpips
14
+
15
+
16
+ class SLatVaeGaussianTrainer(BasicTrainer):
17
+ """
18
+ Trainer for structured latent VAE.
19
+
20
+ Args:
21
+ models (dict[str, nn.Module]): Models to train.
22
+ dataset (torch.utils.data.Dataset): Dataset.
23
+ output_dir (str): Output directory.
24
+ load_dir (str): Load directory.
25
+ step (int): Step to load.
26
+ batch_size (int): Batch size.
27
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
28
+ batch_split (int): Split batch with gradient accumulation.
29
+ max_steps (int): Max steps.
30
+ optimizer (dict): Optimizer config.
31
+ lr_scheduler (dict): Learning rate scheduler config.
32
+ elastic (dict): Elastic memory management config.
33
+ grad_clip (float or dict): Gradient clip config.
34
+ ema_rate (float or list): Exponential moving average rates.
35
+ fp16_mode (str): FP16 mode.
36
+ - None: No FP16.
37
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
38
+ - 'amp': Automatic mixed precision.
39
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
40
+ finetune_ckpt (dict): Finetune checkpoint.
41
+ log_param_stats (bool): Log parameter stats.
42
+ i_print (int): Print interval.
43
+ i_log (int): Log interval.
44
+ i_sample (int): Sample interval.
45
+ i_save (int): Save interval.
46
+ i_ddpcheck (int): DDP check interval.
47
+
48
+ loss_type (str): Loss type. Can be 'l1', 'l2'
49
+ lambda_ssim (float): SSIM loss weight.
50
+ lambda_lpips (float): LPIPS loss weight.
51
+ lambda_kl (float): KL loss weight.
52
+ regularizations (dict): Regularization config.
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ *args,
58
+ loss_type: str = 'l1',
59
+ lambda_ssim: float = 0.2,
60
+ lambda_lpips: float = 0.2,
61
+ lambda_kl: float = 1e-6,
62
+ regularizations: Dict = {},
63
+ **kwargs
64
+ ):
65
+ super().__init__(*args, **kwargs)
66
+ self.loss_type = loss_type
67
+ self.lambda_ssim = lambda_ssim
68
+ self.lambda_lpips = lambda_lpips
69
+ self.lambda_kl = lambda_kl
70
+ self.regularizations = regularizations
71
+
72
+ self._init_renderer()
73
+
74
+ def _init_renderer(self):
75
+ rendering_options = {"near" : 0.8,
76
+ "far" : 1.6,
77
+ "bg_color" : 'random'}
78
+ self.renderer = GaussianRenderer(rendering_options)
79
+ self.renderer.pipe.kernel_size = self.models['decoder'].rep_config['2d_filter_kernel_size']
80
+
81
+ def _render_batch(self, reps: List[Gaussian], extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
82
+ """
83
+ Render a batch of representations.
84
+
85
+ Args:
86
+ reps: The list of representations.
87
+ extrinsics: The [N x 4 x 4] tensor of extrinsics.
88
+ intrinsics: The [N x 3 x 3] tensor of intrinsics.
89
+ """
90
+ ret = None
91
+ for i, representation in enumerate(reps):
92
+ render_pack = self.renderer.render(representation, extrinsics[i], intrinsics[i])
93
+ if ret is None:
94
+ ret = {k: [] for k in list(render_pack.keys()) + ['bg_color']}
95
+ for k, v in render_pack.items():
96
+ ret[k].append(v)
97
+ ret['bg_color'].append(self.renderer.bg_color)
98
+ for k, v in ret.items():
99
+ ret[k] = torch.stack(v, dim=0)
100
+ return ret
101
+
102
+ @torch.no_grad()
103
+ def _get_status(self, z: SparseTensor, reps: List[Gaussian]) -> Dict:
104
+ xyz = torch.cat([g.get_xyz for g in reps], dim=0)
105
+ xyz_base = (z.coords[:, 1:].float() + 0.5) / self.models['decoder'].resolution - 0.5
106
+ offset = xyz - xyz_base.unsqueeze(1).expand(-1, self.models['decoder'].rep_config['num_gaussians'], -1).reshape(-1, 3)
107
+ status = {
108
+ 'xyz': xyz,
109
+ 'offset': offset,
110
+ 'scale': torch.cat([g.get_scaling for g in reps], dim=0),
111
+ 'opacity': torch.cat([g.get_opacity for g in reps], dim=0),
112
+ }
113
+
114
+ for k in list(status.keys()):
115
+ status[k] = {
116
+ 'mean': status[k].mean().item(),
117
+ 'max': status[k].max().item(),
118
+ 'min': status[k].min().item(),
119
+ }
120
+
121
+ return status
122
+
123
+ def _get_regularization_loss(self, reps: List[Gaussian]) -> Tuple[torch.Tensor, Dict]:
124
+ loss = 0.0
125
+ terms = {}
126
+ if 'lambda_vol' in self.regularizations:
127
+ scales = torch.cat([g.get_scaling for g in reps], dim=0) # [N x 3]
128
+ volume = torch.prod(scales, dim=1) # [N]
129
+ terms[f'reg_vol'] = volume.mean()
130
+ loss = loss + self.regularizations['lambda_vol'] * terms[f'reg_vol']
131
+ if 'lambda_opacity' in self.regularizations:
132
+ opacity = torch.cat([g.get_opacity for g in reps], dim=0)
133
+ terms[f'reg_opacity'] = (opacity - 1).pow(2).mean()
134
+ loss = loss + self.regularizations['lambda_opacity'] * terms[f'reg_opacity']
135
+ return loss, terms
136
+
137
+ def training_losses(
138
+ self,
139
+ feats: SparseTensor,
140
+ image: torch.Tensor,
141
+ alpha: torch.Tensor,
142
+ extrinsics: torch.Tensor,
143
+ intrinsics: torch.Tensor,
144
+ return_aux: bool = False,
145
+ **kwargs
146
+ ) -> Tuple[Dict, Dict]:
147
+ """
148
+ Compute training losses.
149
+
150
+ Args:
151
+ feats: The [N x * x C] sparse tensor of features.
152
+ image: The [N x 3 x H x W] tensor of images.
153
+ alpha: The [N x H x W] tensor of alpha channels.
154
+ extrinsics: The [N x 4 x 4] tensor of extrinsics.
155
+ intrinsics: The [N x 3 x 3] tensor of intrinsics.
156
+ return_aux: Whether to return auxiliary information.
157
+
158
+ Returns:
159
+ a dict with the key "loss" containing a scalar tensor.
160
+ may also contain other keys for different terms.
161
+ """
162
+ z, mean, logvar = self.training_models['encoder'](feats, sample_posterior=True, return_raw=True)
163
+ reps = self.training_models['decoder'](z)
164
+ self.renderer.rendering_options.resolution = image.shape[-1]
165
+ render_results = self._render_batch(reps, extrinsics, intrinsics)
166
+
167
+ terms = edict(loss = 0.0, rec = 0.0)
168
+
169
+ rec_image = render_results['color']
170
+ gt_image = image * alpha[:, None] + (1 - alpha[:, None]) * render_results['bg_color'][..., None, None]
171
+
172
+ if self.loss_type == 'l1':
173
+ terms["l1"] = l1_loss(rec_image, gt_image)
174
+ terms["rec"] = terms["rec"] + terms["l1"]
175
+ elif self.loss_type == 'l2':
176
+ terms["l2"] = l2_loss(rec_image, gt_image)
177
+ terms["rec"] = terms["rec"] + terms["l2"]
178
+ else:
179
+ raise ValueError(f"Invalid loss type: {self.loss_type}")
180
+ if self.lambda_ssim > 0:
181
+ terms["ssim"] = 1 - ssim(rec_image, gt_image)
182
+ terms["rec"] = terms["rec"] + self.lambda_ssim * terms["ssim"]
183
+ if self.lambda_lpips > 0:
184
+ terms["lpips"] = lpips(rec_image, gt_image)
185
+ terms["rec"] = terms["rec"] + self.lambda_lpips * terms["lpips"]
186
+ terms["loss"] = terms["loss"] + terms["rec"]
187
+
188
+ terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
189
+ terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]
190
+
191
+ reg_loss, reg_terms = self._get_regularization_loss(reps)
192
+ terms.update(reg_terms)
193
+ terms["loss"] = terms["loss"] + reg_loss
194
+
195
+ status = self._get_status(z, reps)
196
+
197
+ if return_aux:
198
+ return terms, status, {'rec_image': rec_image, 'gt_image': gt_image}
199
+ return terms, status
200
+
201
+ @torch.no_grad()
202
+ def run_snapshot(
203
+ self,
204
+ num_samples: int,
205
+ batch_size: int,
206
+ verbose: bool = False,
207
+ ) -> Dict:
208
+ dataloader = DataLoader(
209
+ copy.deepcopy(self.dataset),
210
+ batch_size=batch_size,
211
+ shuffle=True,
212
+ num_workers=0,
213
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
214
+ )
215
+
216
+ # inference
217
+ ret_dict = {}
218
+ gt_images = []
219
+ exts = []
220
+ ints = []
221
+ reps = []
222
+ for i in range(0, num_samples, batch_size):
223
+ batch = min(batch_size, num_samples - i)
224
+ data = next(iter(dataloader))
225
+ args = {k: v[:batch].cuda() for k, v in data.items()}
226
+ gt_images.append(args['image'] * args['alpha'][:, None])
227
+ exts.append(args['extrinsics'])
228
+ ints.append(args['intrinsics'])
229
+ z = self.models['encoder'](args['feats'], sample_posterior=True, return_raw=False)
230
+ reps.extend(self.models['decoder'](z))
231
+ gt_images = torch.cat(gt_images, dim=0)
232
+ ret_dict.update({f'gt_image': {'value': gt_images, 'type': 'image'}})
233
+
234
+ # render single view
235
+ exts = torch.cat(exts, dim=0)
236
+ ints = torch.cat(ints, dim=0)
237
+ self.renderer.rendering_options.bg_color = (0, 0, 0)
238
+ self.renderer.rendering_options.resolution = gt_images.shape[-1]
239
+ render_results = self._render_batch(reps, exts, ints)
240
+ ret_dict.update({f'rec_image': {'value': render_results['color'], 'type': 'image'}})
241
+
242
+ # render multiview
243
+ self.renderer.rendering_options.resolution = 512
244
+ ## Build camera
245
+ yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
246
+ yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
247
+ yaws = [y + yaws_offset for y in yaws]
248
+ pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
249
+
250
+ ## render each view
251
+ miltiview_images = []
252
+ for yaw, pitch in zip(yaws, pitch):
253
+ orig = torch.tensor([
254
+ np.sin(yaw) * np.cos(pitch),
255
+ np.cos(yaw) * np.cos(pitch),
256
+ np.sin(pitch),
257
+ ]).float().cuda() * 2
258
+ fov = torch.deg2rad(torch.tensor(30)).cuda()
259
+ extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
260
+ intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
261
+ extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1)
262
+ intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1)
263
+ render_results = self._render_batch(reps, extrinsics, intrinsics)
264
+ miltiview_images.append(render_results['color'])
265
+
266
+ ## Concatenate views
267
+ miltiview_images = torch.cat([
268
+ torch.cat(miltiview_images[:2], dim=-2),
269
+ torch.cat(miltiview_images[2:], dim=-2),
270
+ ], dim=-1)
271
+ ret_dict.update({f'miltiview_image': {'value': miltiview_images, 'type': 'image'}})
272
+
273
+ self.renderer.rendering_options.bg_color = 'random'
274
+
275
+ return ret_dict
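The snapshot above orbits four cameras around the object: a shared random yaw offset, a random pitch per view, camera origins on a sphere of radius 2 looking at the origin with +z up, and 30-degree-FoV intrinsics via utils3d. A condensed sketch of just the camera construction (rendering omitted, CPU tensors for brevity):

import numpy as np
import torch
import utils3d.torch

yaws = np.array([0, np.pi / 2, np.pi, 3 * np.pi / 2]) + np.random.uniform(-np.pi / 4, np.pi / 4)
pitches = np.random.uniform(-np.pi / 4, np.pi / 4, size=4)
fov = torch.deg2rad(torch.tensor(30.0))

for yaw, pitch in zip(yaws, pitches):
    # Camera origin on a sphere of radius 2, looking at the origin
    orig = torch.tensor([np.sin(yaw) * np.cos(pitch),
                         np.cos(yaw) * np.cos(pitch),
                         np.sin(pitch)]).float() * 2
    extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.zeros(3), torch.tensor([0., 0., 1.]))
    intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)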
trellis/trainers/vae/structured_latent_vae_mesh_dec.py ADDED
@@ -0,0 +1,382 @@
1
+ from typing import *
2
+ import copy
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ import numpy as np
6
+ from easydict import EasyDict as edict
7
+ import utils3d.torch
8
+
9
+ from ..basic import BasicTrainer
10
+ from ...representations import MeshExtractResult
11
+ from ...renderers import MeshRenderer
12
+ from ...modules.sparse import SparseTensor
13
+ from ...utils.loss_utils import l1_loss, smooth_l1_loss, ssim, lpips
14
+ from ...utils.data_utils import recursive_to_device
15
+
16
+
17
+ class SLatVaeMeshDecoderTrainer(BasicTrainer):
18
+ """
19
+ Trainer for structured latent VAE Mesh Decoder.
20
+
21
+ Args:
22
+ models (dict[str, nn.Module]): Models to train.
23
+ dataset (torch.utils.data.Dataset): Dataset.
24
+ output_dir (str): Output directory.
25
+ load_dir (str): Load directory.
26
+ step (int): Step to load.
27
+ batch_size (int): Batch size.
28
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
29
+ batch_split (int): Split batch with gradient accumulation.
30
+ max_steps (int): Max steps.
31
+ optimizer (dict): Optimizer config.
32
+ lr_scheduler (dict): Learning rate scheduler config.
33
+ elastic (dict): Elastic memory management config.
34
+ grad_clip (float or dict): Gradient clip config.
35
+ ema_rate (float or list): Exponential moving average rates.
36
+ fp16_mode (str): FP16 mode.
37
+ - None: No FP16.
38
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
39
+ - 'amp': Automatic mixed precision.
40
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
41
+ finetune_ckpt (dict): Finetune checkpoint.
42
+ log_param_stats (bool): Log parameter stats.
43
+ i_print (int): Print interval.
44
+ i_log (int): Log interval.
45
+ i_sample (int): Sample interval.
46
+ i_save (int): Save interval.
47
+ i_ddpcheck (int): DDP check interval.
48
+
49
+ depth_loss_type (str): Depth loss type. Can be 'l1', 'smooth_l1'
50
+ lambda_ssim (float): SSIM loss weight.
51
+ lambda_lpips (float): LPIPS loss weight.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ *args,
57
+ depth_loss_type: str = 'l1',
58
+ lambda_depth: int = 1,
59
+ lambda_ssim: float = 0.2,
60
+ lambda_lpips: float = 0.2,
61
+ lambda_tsdf: float = 0.01,
62
+ lambda_color: float = 0.1,
63
+ **kwargs
64
+ ):
65
+ super().__init__(*args, **kwargs)
66
+ self.depth_loss_type = depth_loss_type
67
+ self.lambda_depth = lambda_depth
68
+ self.lambda_ssim = lambda_ssim
69
+ self.lambda_lpips = lambda_lpips
70
+ self.lambda_tsdf = lambda_tsdf
71
+ self.lambda_color = lambda_color
72
+ self.use_color = self.lambda_color > 0
73
+
74
+ self._init_renderer()
75
+
76
+ def _init_renderer(self):
77
+ rendering_options = {"near" : 1,
78
+ "far" : 3}
79
+ self.renderer = MeshRenderer(rendering_options, device=self.device)
80
+
81
+ def _render_batch(self, reps: List[MeshExtractResult], extrinsics: torch.Tensor, intrinsics: torch.Tensor,
82
+ return_types=['mask', 'normal', 'depth']) -> Dict[str, torch.Tensor]:
83
+ """
84
+ Render a batch of representations.
85
+
86
+ Args:
87
+ reps: The list of representations.
88
+ extrinsics: The [N x 4 x 4] tensor of extrinsics.
89
+ intrinsics: The [N x 3 x 3] tensor of intrinsics.
90
+ return_types: vary in ['mask', 'normal', 'depth', 'normal_map', 'color']
91
+
92
+ Returns:
93
+ a dict with
94
+ reg_loss : [N] tensor of regularization losses
95
+ mask : [N x 1 x H x W] tensor of rendered masks
96
+ normal : [N x 3 x H x W] tensor of rendered normals
97
+ depth : [N x 1 x H x W] tensor of rendered depths
98
+ """
99
+ ret = {k : [] for k in return_types}
100
+ for i, rep in enumerate(reps):
101
+ out_dict = self.renderer.render(rep, extrinsics[i], intrinsics[i], return_types=return_types)
102
+ for k in out_dict:
103
+ ret[k].append(out_dict[k][None] if k in ['mask', 'depth'] else out_dict[k])
104
+ for k in ret:
105
+ ret[k] = torch.stack(ret[k])
106
+ return ret
107
+
108
+ @staticmethod
109
+ def _tsdf_reg_loss(rep: MeshExtractResult, depth_map: torch.Tensor, extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
110
+ # Calculate tsdf
111
+ with torch.no_grad():
112
+ # Project points to camera and calculate pseudo-sdf as difference between gt depth and projected depth
113
+ projected_pts, pts_depth = utils3d.torch.project_cv(extrinsics=extrinsics, intrinsics=intrinsics, points=rep.tsdf_v)
114
+ projected_pts = (projected_pts - 0.5) * 2.0
115
+ depth_map_res = depth_map.shape[1]
116
+ gt_depth = torch.nn.functional.grid_sample(depth_map.reshape(1, 1, depth_map_res, depth_map_res),
117
+ projected_pts.reshape(1, 1, -1, 2), mode='bilinear', padding_mode='border', align_corners=True)
118
+ pseudo_sdf = gt_depth.flatten() - pts_depth.flatten()
119
+ # Truncate pseudo-sdf
120
+ delta = 1 / rep.res * 3.0
121
+ trunc_mask = pseudo_sdf > -delta
122
+
123
+ # Loss
124
+ gt_tsdf = pseudo_sdf[trunc_mask]
125
+ tsdf = rep.tsdf_s.flatten()[trunc_mask]
126
+ gt_tsdf = torch.clamp(gt_tsdf, -delta, delta)
127
+ return torch.mean((tsdf - gt_tsdf) ** 2)
128
+
129
+ def _calc_tsdf_loss(self, reps : list[MeshExtractResult], depth_maps, extrinsics, intrinsics) -> torch.Tensor:
130
+ tsdf_loss = 0.0
131
+ for i, rep in enumerate(reps):
132
+ tsdf_loss += self._tsdf_reg_loss(rep, depth_maps[i], extrinsics[i], intrinsics[i])
133
+ return tsdf_loss / len(reps)
134
+
135
+ @torch.no_grad()
136
+ def _flip_normal(self, normal: torch.Tensor, extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
137
+ """
138
+ Flip normal to align with camera.
139
+ """
140
+ normal = normal * 2.0 - 1.0
141
+ R = torch.zeros_like(extrinsics)
142
+ R[:, :3, :3] = extrinsics[:, :3, :3]
143
+ R[:, 3, 3] = 1.0
144
+ view_dir = utils3d.torch.unproject_cv(
145
+ utils3d.torch.image_uv(*normal.shape[-2:], device=self.device).reshape(1, -1, 2),
146
+ torch.ones(*normal.shape[-2:], device=self.device).reshape(1, -1),
147
+ R, intrinsics
148
+ ).reshape(-1, *normal.shape[-2:], 3).permute(0, 3, 1, 2)
149
+ unflip = (normal * view_dir).sum(1, keepdim=True) < 0
150
+ normal *= unflip * 2.0 - 1.0
151
+ return (normal + 1.0) / 2.0
152
+
153
+ def _perceptual_loss(self, gt: torch.Tensor, pred: torch.Tensor, name: str) -> Dict[str, torch.Tensor]:
154
+ """
155
+ Combination of L1, SSIM, and LPIPS loss.
156
+ """
157
+ if gt.shape[1] != 3:
158
+ assert gt.shape[-1] == 3
159
+ gt = gt.permute(0, 3, 1, 2)
160
+ if pred.shape[1] != 3:
161
+ assert pred.shape[-1] == 3
162
+ pred = pred.permute(0, 3, 1, 2)
163
+ terms = {
164
+ f"{name}_loss" : l1_loss(gt, pred),
165
+ f"{name}_loss_ssim" : 1 - ssim(gt, pred),
166
+ f"{name}_loss_lpips" : lpips(gt, pred)
167
+ }
168
+ terms[f"{name}_loss_perceptual"] = terms[f"{name}_loss"] + terms[f"{name}_loss_ssim"] * self.lambda_ssim + terms[f"{name}_loss_lpips"] * self.lambda_lpips
169
+ return terms
170
+
171
+ def geometry_losses(
172
+ self,
173
+ reps: List[MeshExtractResult],
174
+ mesh: List[Dict],
175
+ normal_map: torch.Tensor,
176
+ extrinsics: torch.Tensor,
177
+ intrinsics: torch.Tensor,
178
+ ):
179
+ with torch.no_grad():
180
+ gt_meshes = []
181
+ for i in range(len(reps)):
182
+ gt_mesh = MeshExtractResult(mesh[i]['vertices'].to(self.device), mesh[i]['faces'].to(self.device))
183
+ gt_meshes.append(gt_mesh)
184
+ target = self._render_batch(gt_meshes, extrinsics, intrinsics, return_types=['mask', 'depth', 'normal'])
185
+ target['normal'] = self._flip_normal(target['normal'], extrinsics, intrinsics)
186
+
187
+ terms = edict(geo_loss = 0.0)
188
+ if self.lambda_tsdf > 0:
189
+ tsdf_loss = self._calc_tsdf_loss(reps, target['depth'], extrinsics, intrinsics)
190
+ terms['tsdf_loss'] = tsdf_loss
191
+ terms['geo_loss'] += tsdf_loss * self.lambda_tsdf
192
+
193
+ return_types = ['mask', 'depth', 'normal', 'normal_map'] if self.use_color else ['mask', 'depth', 'normal']
194
+ buffer = self._render_batch(reps, extrinsics, intrinsics, return_types=return_types)
195
+
196
+ success_mask = torch.tensor([rep.success for rep in reps], device=self.device)
197
+ if success_mask.sum() != 0:
198
+ for k, v in buffer.items():
199
+ buffer[k] = v[success_mask]
200
+ for k, v in target.items():
201
+ target[k] = v[success_mask]
202
+
203
+ terms['mask_loss'] = l1_loss(buffer['mask'], target['mask'])
204
+ if self.depth_loss_type == 'l1':
205
+ terms['depth_loss'] = l1_loss(buffer['depth'] * target['mask'], target['depth'] * target['mask'])
206
+ elif self.depth_loss_type == 'smooth_l1':
207
+ terms['depth_loss'] = smooth_l1_loss(buffer['depth'] * target['mask'], target['depth'] * target['mask'], beta=1.0 / (2 * reps[0].res))
208
+ else:
209
+ raise ValueError(f"Unsupported depth loss type: {self.depth_loss_type}")
210
+ terms.update(self._perceptual_loss(buffer['normal'] * target['mask'], target['normal'] * target['mask'], 'normal'))
211
+ terms['geo_loss'] = terms['geo_loss'] + terms['mask_loss'] + terms['depth_loss'] * self.lambda_depth + terms['normal_loss_perceptual']
212
+ if self.use_color and normal_map is not None:
213
+ terms.update(self._perceptual_loss(normal_map[success_mask], buffer['normal_map'], 'normal_map'))
214
+ terms['geo_loss'] = terms['geo_loss'] + terms['normal_map_loss_perceptual'] * self.lambda_color
215
+
216
+ return terms
217
+
218
+ def color_losses(self, reps, image, alpha, extrinsics, intrinsics):
219
+ terms = edict(color_loss = torch.tensor(0.0, device=self.device))
220
+ buffer = self._render_batch(reps, extrinsics, intrinsics, return_types=['color'])
221
+ success_mask = torch.tensor([rep.success for rep in reps], device=self.device)
222
+ if success_mask.sum() != 0:
223
+ terms.update(self._perceptual_loss((image * alpha[:, None])[success_mask], buffer['color'][success_mask], 'color'))
224
+ terms['color_loss'] = terms['color_loss'] + terms['color_loss_perceptual'] * self.lambda_color
225
+ return terms
226
+
227
+ def training_losses(
228
+ self,
229
+ latents: SparseTensor,
230
+ image: torch.Tensor,
231
+ alpha: torch.Tensor,
232
+ mesh: List[Dict],
233
+ extrinsics: torch.Tensor,
234
+ intrinsics: torch.Tensor,
235
+ normal_map: torch.Tensor = None,
236
+ ) -> Tuple[Dict, Dict]:
237
+ """
238
+ Compute training losses.
239
+
240
+ Args:
241
+ latents: The [N x * x C] sparse latents
242
+ image: The [N x 3 x H x W] tensor of images.
243
+ alpha: The [N x H x W] tensor of alpha channels.
244
+ mesh: The list of dictionaries of meshes.
245
+ extrinsics: The [N x 4 x 4] tensor of extrinsics.
246
+ intrinsics: The [N x 3 x 3] tensor of intrinsics.
247
+
248
+ Returns:
249
+ a dict with the key "loss" containing a scalar tensor.
250
+ may also contain other keys for different terms.
251
+ """
252
+ reps = self.training_models['decoder'](latents)
253
+ self.renderer.rendering_options.resolution = image.shape[-1]
254
+
255
+ terms = edict(loss = 0.0, rec = 0.0)
256
+
257
+ terms['reg_loss'] = sum([rep.reg_loss for rep in reps]) / len(reps)
258
+ terms['loss'] = terms['loss'] + terms['reg_loss']
259
+
260
+ geo_terms = self.geometry_losses(reps, mesh, normal_map, extrinsics, intrinsics)
261
+ terms.update(geo_terms)
262
+ terms['loss'] = terms['loss'] + terms['geo_loss']
263
+
264
+ if self.use_color:
265
+ color_terms = self.color_losses(reps, image, alpha, extrinsics, intrinsics)
266
+ terms.update(color_terms)
267
+ terms['loss'] = terms['loss'] + terms['color_loss']
268
+
269
+ return terms, {}
270
+
271
+ @torch.no_grad()
272
+ def run_snapshot(
273
+ self,
274
+ num_samples: int,
275
+ batch_size: int,
276
+ verbose: bool = False,
277
+ ) -> Dict:
278
+ dataloader = DataLoader(
279
+ copy.deepcopy(self.dataset),
280
+ batch_size=batch_size,
281
+ shuffle=True,
282
+ num_workers=0,
283
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
284
+ )
285
+
286
+ # inference
287
+ ret_dict = {}
288
+ gt_images = []
289
+ gt_normal_maps = []
290
+ gt_meshes = []
291
+ exts = []
292
+ ints = []
293
+ reps = []
294
+ for i in range(0, num_samples, batch_size):
295
+ batch = min(batch_size, num_samples - i)
296
+ data = next(iter(dataloader))
297
+ args = recursive_to_device(data, 'cuda')
298
+ gt_images.append(args['image'] * args['alpha'][:, None])
299
+ if self.use_color and 'normal_map' in data:
300
+ gt_normal_maps.append(args['normal_map'])
301
+ gt_meshes.extend(args['mesh'])
302
+ exts.append(args['extrinsics'])
303
+ ints.append(args['intrinsics'])
304
+ reps.extend(self.models['decoder'](args['latents']))
305
+ gt_images = torch.cat(gt_images, dim=0)
306
+ ret_dict.update({f'gt_image': {'value': gt_images, 'type': 'image'}})
307
+ if self.use_color and gt_normal_maps:
308
+ gt_normal_maps = torch.cat(gt_normal_maps, dim=0)
309
+ ret_dict.update({f'gt_normal_map': {'value': gt_normal_maps, 'type': 'image'}})
310
+
311
+ # render single view
312
+ exts = torch.cat(exts, dim=0)
313
+ ints = torch.cat(ints, dim=0)
314
+ self.renderer.rendering_options.bg_color = (0, 0, 0)
315
+ self.renderer.rendering_options.resolution = gt_images.shape[-1]
316
+ gt_render_results = self._render_batch([
317
+ MeshExtractResult(vertices=mesh['vertices'].to(self.device), faces=mesh['faces'].to(self.device))
318
+ for mesh in gt_meshes
319
+ ], exts, ints, return_types=['normal'])
320
+ ret_dict.update({f'gt_normal': {'value': self._flip_normal(gt_render_results['normal'], exts, ints), 'type': 'image'}})
321
+ return_types = ['normal']
322
+ if self.use_color:
323
+ return_types.append('color')
324
+ if 'normal_map' in data:
325
+ return_types.append('normal_map')
326
+ render_results = self._render_batch(reps, exts, ints, return_types=return_types)
327
+ ret_dict.update({f'rec_normal': {'value': render_results['normal'], 'type': 'image'}})
328
+ if 'color' in return_types:
329
+ ret_dict.update({f'rec_image': {'value': render_results['color'], 'type': 'image'}})
330
+ if 'normal_map' in return_types:
331
+ ret_dict.update({f'rec_normal_map': {'value': render_results['normal_map'], 'type': 'image'}})
332
+
333
+ # render multiview
334
+ self.renderer.rendering_options.resolution = 512
335
+ ## Build camera
336
+ yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
337
+ yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
338
+ yaws = [y + yaws_offset for y in yaws]
339
+ pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
340
+
341
+ ## render each view
342
+ multiview_normals = []
343
+ multiview_normal_maps = []
344
+ miltiview_images = []
345
+ for yaw, pitch in zip(yaws, pitch):
346
+ orig = torch.tensor([
347
+ np.sin(yaw) * np.cos(pitch),
348
+ np.cos(yaw) * np.cos(pitch),
349
+ np.sin(pitch),
350
+ ]).float().cuda() * 2
351
+ fov = torch.deg2rad(torch.tensor(30)).cuda()
352
+ extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
353
+ intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
354
+ extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1)
355
+ intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1)
356
+ render_results = self._render_batch(reps, extrinsics, intrinsics, return_types=return_types)
357
+ multiview_normals.append(render_results['normal'])
358
+ if 'color' in return_types:
359
+ miltiview_images.append(render_results['color'])
360
+ if 'normal_map' in return_types:
361
+ multiview_normal_maps.append(render_results['normal_map'])
362
+
363
+ ## Concatenate views
364
+ multiview_normals = torch.cat([
365
+ torch.cat(multiview_normals[:2], dim=-2),
366
+ torch.cat(multiview_normals[2:], dim=-2),
367
+ ], dim=-1)
368
+ ret_dict.update({f'multiview_normal': {'value': multiview_normals, 'type': 'image'}})
369
+ if 'color' in return_types:
370
+ miltiview_images = torch.cat([
371
+ torch.cat(miltiview_images[:2], dim=-2),
372
+ torch.cat(miltiview_images[2:], dim=-2),
373
+ ], dim=-1)
374
+ ret_dict.update({f'multiview_image': {'value': miltiview_images, 'type': 'image'}})
375
+ if 'normal_map' in return_types:
376
+ multiview_normal_maps = torch.cat([
377
+ torch.cat(multiview_normal_maps[:2], dim=-2),
378
+ torch.cat(multiview_normal_maps[2:], dim=-2),
379
+ ], dim=-1)
380
+ ret_dict.update({f'multiview_normal_map': {'value': multiview_normal_maps, 'type': 'image'}})
381
+
382
+ return ret_dict
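The TSDF regularizer above supervises the decoder's per-vertex SDF values against a pseudo-SDF derived from rendered ground-truth depth: vertices are projected into the depth map, the pseudo-SDF is the sampled depth minus the vertex's camera depth, and both sides are truncated at delta = 3 / res before an MSE. A toy sketch of that truncation logic with stand-in tensors (the utils3d projection and grid_sample lookup are replaced by random values):

import torch

res = 64
delta = 1 / res * 3.0                 # truncation band, as in _tsdf_reg_loss

gt_depth = torch.rand(1000) * 2       # depth map sampled at the projected vertices
pts_depth = torch.rand(1000) * 2      # vertex depth in camera space
pred_tsdf = torch.randn(1000) * 0.05  # decoder-predicted SDF values (stand-in)

pseudo_sdf = gt_depth - pts_depth
trunc_mask = pseudo_sdf > -delta      # drop points far behind the observed surface
gt_tsdf = torch.clamp(pseudo_sdf[trunc_mask], -delta, delta)
tsdf_loss = torch.mean((pred_tsdf[trunc_mask] - gt_tsdf) ** 2)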
trellis/trainers/vae/structured_latent_vae_rf_dec.py ADDED
@@ -0,0 +1,223 @@
1
+ from typing import *
2
+ import copy
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ import numpy as np
6
+ from easydict import EasyDict as edict
7
+ import utils3d.torch
8
+
9
+ from ..basic import BasicTrainer
10
+ from ...representations import Strivec
11
+ from ...renderers import OctreeRenderer
12
+ from ...modules.sparse import SparseTensor
13
+ from ...utils.loss_utils import l1_loss, l2_loss, ssim, lpips
14
+
15
+
16
+ class SLatVaeRadianceFieldDecoderTrainer(BasicTrainer):
17
+ """
18
+ Trainer for structured latent VAE Radiance Field Decoder.
19
+
20
+ Args:
21
+ models (dict[str, nn.Module]): Models to train.
22
+ dataset (torch.utils.data.Dataset): Dataset.
23
+ output_dir (str): Output directory.
24
+ load_dir (str): Load directory.
25
+ step (int): Step to load.
26
+ batch_size (int): Batch size.
27
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
28
+ batch_split (int): Split batch with gradient accumulation.
29
+ max_steps (int): Max steps.
30
+ optimizer (dict): Optimizer config.
31
+ lr_scheduler (dict): Learning rate scheduler config.
32
+ elastic (dict): Elastic memory management config.
33
+ grad_clip (float or dict): Gradient clip config.
34
+ ema_rate (float or list): Exponential moving average rates.
35
+ fp16_mode (str): FP16 mode.
36
+ - None: No FP16.
37
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
38
+ - 'amp': Automatic mixed precision.
39
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
40
+ finetune_ckpt (dict): Finetune checkpoint.
41
+ log_param_stats (bool): Log parameter stats.
42
+ i_print (int): Print interval.
43
+ i_log (int): Log interval.
44
+ i_sample (int): Sample interval.
45
+ i_save (int): Save interval.
46
+ i_ddpcheck (int): DDP check interval.
47
+
48
+ loss_type (str): Loss type. Can be 'l1', 'l2'
49
+ lambda_ssim (float): SSIM loss weight.
50
+ lambda_lpips (float): LPIPS loss weight.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ *args,
56
+ loss_type: str = 'l1',
57
+ lambda_ssim: float = 0.2,
58
+ lambda_lpips: float = 0.2,
59
+ **kwargs
60
+ ):
61
+ super().__init__(*args, **kwargs)
62
+ self.loss_type = loss_type
63
+ self.lambda_ssim = lambda_ssim
64
+ self.lambda_lpips = lambda_lpips
65
+
66
+ self._init_renderer()
67
+
68
+ def _init_renderer(self):
69
+ rendering_options = {"near" : 0.8,
70
+ "far" : 1.6,
71
+ "bg_color" : 'random'}
72
+ self.renderer = OctreeRenderer(rendering_options)
73
+ self.renderer.pipe.primitive = 'trivec'
74
+
75
+ def _render_batch(self, reps: List[Strivec], extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
76
+ """
77
+ Render a batch of representations.
78
+
79
+ Args:
80
+ reps: The list of representations.
81
+ extrinsics: The [N x 4 x 4] tensor of extrinsics.
82
+ intrinsics: The [N x 3 x 3] tensor of intrinsics.
83
+ """
84
+ ret = None
85
+ for i, representation in enumerate(reps):
86
+ render_pack = self.renderer.render(representation, extrinsics[i], intrinsics[i])
87
+ if ret is None:
88
+ ret = {k: [] for k in list(render_pack.keys()) + ['bg_color']}
89
+ for k, v in render_pack.items():
90
+ ret[k].append(v)
91
+ ret['bg_color'].append(self.renderer.bg_color)
92
+ for k, v in ret.items():
93
+ ret[k] = torch.stack(v, dim=0)
94
+ return ret
95
+
96
+ def training_losses(
97
+ self,
98
+ latents: SparseTensor,
99
+ image: torch.Tensor,
100
+ alpha: torch.Tensor,
101
+ extrinsics: torch.Tensor,
102
+ intrinsics: torch.Tensor,
103
+ return_aux: bool = False,
104
+ **kwargs
105
+ ) -> Tuple[Dict, Dict]:
106
+ """
107
+ Compute training losses.
108
+
109
+ Args:
110
+ latents: The [N x * x C] sparse latents
111
+ image: The [N x 3 x H x W] tensor of images.
112
+ alpha: The [N x H x W] tensor of alpha channels.
113
+ extrinsics: The [N x 4 x 4] tensor of extrinsics.
114
+ intrinsics: The [N x 3 x 3] tensor of intrinsics.
115
+ return_aux: Whether to return auxiliary information.
116
+
117
+ Returns:
118
+ a dict with the key "loss" containing a scalar tensor.
119
+ may also contain other keys for different terms.
120
+ """
121
+ reps = self.training_models['decoder'](latents)
122
+ self.renderer.rendering_options.resolution = image.shape[-1]
123
+ render_results = self._render_batch(reps, extrinsics, intrinsics)
124
+
125
+ terms = edict(loss = 0.0, rec = 0.0)
126
+
127
+ rec_image = render_results['color']
128
+ gt_image = image * alpha[:, None] + (1 - alpha[:, None]) * render_results['bg_color'][..., None, None]
129
+
130
+ if self.loss_type == 'l1':
131
+ terms["l1"] = l1_loss(rec_image, gt_image)
132
+ terms["rec"] = terms["rec"] + terms["l1"]
133
+ elif self.loss_type == 'l2':
134
+ terms["l2"] = l2_loss(rec_image, gt_image)
135
+ terms["rec"] = terms["rec"] + terms["l2"]
136
+ else:
137
+ raise ValueError(f"Invalid loss type: {self.loss_type}")
138
+ if self.lambda_ssim > 0:
139
+ terms["ssim"] = 1 - ssim(rec_image, gt_image)
140
+ terms["rec"] = terms["rec"] + self.lambda_ssim * terms["ssim"]
141
+ if self.lambda_lpips > 0:
142
+ terms["lpips"] = lpips(rec_image, gt_image)
143
+ terms["rec"] = terms["rec"] + self.lambda_lpips * terms["lpips"]
144
+ terms["loss"] = terms["loss"] + terms["rec"]
145
+
146
+ if return_aux:
147
+ return terms, {}, {'rec_image': rec_image, 'gt_image': gt_image}
148
+ return terms, {}
149
+
150
+ @torch.no_grad()
151
+ def run_snapshot(
152
+ self,
153
+ num_samples: int,
154
+ batch_size: int,
155
+ verbose: bool = False,
156
+ ) -> Dict:
157
+ dataloader = DataLoader(
158
+ copy.deepcopy(self.dataset),
159
+ batch_size=batch_size,
160
+ shuffle=True,
161
+ num_workers=0,
162
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
163
+ )
164
+
165
+ # inference
166
+ ret_dict = {}
167
+ gt_images = []
168
+ exts = []
169
+ ints = []
170
+ reps = []
171
+ for i in range(0, num_samples, batch_size):
172
+ batch = min(batch_size, num_samples - i)
173
+ data = next(iter(dataloader))
174
+ args = {k: v[:batch].cuda() for k, v in data.items()}
175
+ gt_images.append(args['image'] * args['alpha'][:, None])
176
+ exts.append(args['extrinsics'])
177
+ ints.append(args['intrinsics'])
178
+ reps.extend(self.models['decoder'](args['latents']))
179
+ gt_images = torch.cat(gt_images, dim=0)
180
+ ret_dict.update({f'gt_image': {'value': gt_images, 'type': 'image'}})
181
+
182
+ # render single view
183
+ exts = torch.cat(exts, dim=0)
184
+ ints = torch.cat(ints, dim=0)
185
+ self.renderer.rendering_options.bg_color = (0, 0, 0)
186
+ self.renderer.rendering_options.resolution = gt_images.shape[-1]
187
+ render_results = self._render_batch(reps, exts, ints)
188
+ ret_dict.update({f'rec_image': {'value': render_results['color'], 'type': 'image'}})
189
+
190
+ # render multiview
191
+ self.renderer.rendering_options.resolution = 512
192
+ ## Build camera
193
+ yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
194
+ yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
195
+ yaws = [y + yaws_offset for y in yaws]
196
+ pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
197
+
198
+ ## render each view
199
+ miltiview_images = []
200
+ for yaw, pitch in zip(yaws, pitch):
201
+ orig = torch.tensor([
202
+ np.sin(yaw) * np.cos(pitch),
203
+ np.cos(yaw) * np.cos(pitch),
204
+ np.sin(pitch),
205
+ ]).float().cuda() * 2
206
+ fov = torch.deg2rad(torch.tensor(30)).cuda()
207
+ extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
208
+ intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
209
+ extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1)
210
+ intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1)
211
+ render_results = self._render_batch(reps, extrinsics, intrinsics)
212
+ miltiview_images.append(render_results['color'])
213
+
214
+ ## Concatenate views
215
+ miltiview_images = torch.cat([
216
+ torch.cat(miltiview_images[:2], dim=-2),
217
+ torch.cat(miltiview_images[2:], dim=-2),
218
+ ], dim=-1)
219
+ ret_dict.update({f'miltiview_image': {'value': miltiview_images, 'type': 'image'}})
220
+
221
+ self.renderer.rendering_options.bg_color = 'random'
222
+
223
+ return ret_dict