Spaces:
Running
on
Zero
Running
on
Zero
Upload 13 files
Browse files- trellis/trainers/__init__.py +63 -0
- trellis/trainers/base.py +451 -0
- trellis/trainers/basic.py +438 -0
- trellis/trainers/flow_matching/flow_matching.py +353 -0
- trellis/trainers/flow_matching/mixins/classifier_free_guidance.py +59 -0
- trellis/trainers/flow_matching/mixins/image_conditioned.py +93 -0
- trellis/trainers/flow_matching/mixins/text_conditioned.py +68 -0
- trellis/trainers/flow_matching/sparse_flow_matching.py +286 -0
- trellis/trainers/utils.py +77 -0
- trellis/trainers/vae/sparse_structure_vae.py +130 -0
- trellis/trainers/vae/structured_latent_vae_gaussian.py +275 -0
- trellis/trainers/vae/structured_latent_vae_mesh_dec.py +382 -0
- trellis/trainers/vae/structured_latent_vae_rf_dec.py +223 -0
trellis/trainers/__init__.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
|
3 |
+
__attributes = {
|
4 |
+
'BasicTrainer': 'basic',
|
5 |
+
|
6 |
+
'SparseStructureVaeTrainer': 'vae.sparse_structure_vae',
|
7 |
+
|
8 |
+
'SLatVaeGaussianTrainer': 'vae.structured_latent_vae_gaussian',
|
9 |
+
'SLatVaeRadianceFieldDecoderTrainer': 'vae.structured_latent_vae_rf_dec',
|
10 |
+
'SLatVaeMeshDecoderTrainer': 'vae.structured_latent_vae_mesh_dec',
|
11 |
+
|
12 |
+
'FlowMatchingTrainer': 'flow_matching.flow_matching',
|
13 |
+
'FlowMatchingCFGTrainer': 'flow_matching.flow_matching',
|
14 |
+
'TextConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
|
15 |
+
'ImageConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
|
16 |
+
|
17 |
+
'SparseFlowMatchingTrainer': 'flow_matching.sparse_flow_matching',
|
18 |
+
'SparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
|
19 |
+
'TextConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
|
20 |
+
'ImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
|
21 |
+
}
|
22 |
+
|
23 |
+
__submodules = []
|
24 |
+
|
25 |
+
__all__ = list(__attributes.keys()) + __submodules
|
26 |
+
|
27 |
+
def __getattr__(name):
|
28 |
+
if name not in globals():
|
29 |
+
if name in __attributes:
|
30 |
+
module_name = __attributes[name]
|
31 |
+
module = importlib.import_module(f".{module_name}", __name__)
|
32 |
+
globals()[name] = getattr(module, name)
|
33 |
+
elif name in __submodules:
|
34 |
+
module = importlib.import_module(f".{name}", __name__)
|
35 |
+
globals()[name] = module
|
36 |
+
else:
|
37 |
+
raise AttributeError(f"module {__name__} has no attribute {name}")
|
38 |
+
return globals()[name]
|
39 |
+
|
40 |
+
|
41 |
+
# For Pylance
|
42 |
+
if __name__ == '__main__':
|
43 |
+
from .basic import BasicTrainer
|
44 |
+
|
45 |
+
from .vae.sparse_structure_vae import SparseStructureVaeTrainer
|
46 |
+
|
47 |
+
from .vae.structured_latent_vae_gaussian import SLatVaeGaussianTrainer
|
48 |
+
from .vae.structured_latent_vae_rf_dec import SLatVaeRadianceFieldDecoderTrainer
|
49 |
+
from .vae.structured_latent_vae_mesh_dec import SLatVaeMeshDecoderTrainer
|
50 |
+
|
51 |
+
from .flow_matching.flow_matching import (
|
52 |
+
FlowMatchingTrainer,
|
53 |
+
FlowMatchingCFGTrainer,
|
54 |
+
TextConditionedFlowMatchingCFGTrainer,
|
55 |
+
ImageConditionedFlowMatchingCFGTrainer,
|
56 |
+
)
|
57 |
+
|
58 |
+
from .flow_matching.sparse_flow_matching import (
|
59 |
+
SparseFlowMatchingTrainer,
|
60 |
+
SparseFlowMatchingCFGTrainer,
|
61 |
+
TextConditionedSparseFlowMatchingCFGTrainer,
|
62 |
+
ImageConditionedSparseFlowMatchingCFGTrainer,
|
63 |
+
)
|
trellis/trainers/base.py
ADDED
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import json
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.distributed as dist
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from torchvision import utils
|
12 |
+
from torch.utils.tensorboard import SummaryWriter
|
13 |
+
|
14 |
+
from .utils import *
|
15 |
+
from ..utils.general_utils import *
|
16 |
+
from ..utils.data_utils import recursive_to_device, cycle, ResumableSampler
|
17 |
+
|
18 |
+
|
19 |
+
class Trainer:
|
20 |
+
"""
|
21 |
+
Base class for training.
|
22 |
+
"""
|
23 |
+
def __init__(self,
|
24 |
+
models,
|
25 |
+
dataset,
|
26 |
+
*,
|
27 |
+
output_dir,
|
28 |
+
load_dir,
|
29 |
+
step,
|
30 |
+
max_steps,
|
31 |
+
batch_size=None,
|
32 |
+
batch_size_per_gpu=None,
|
33 |
+
batch_split=None,
|
34 |
+
optimizer={},
|
35 |
+
lr_scheduler=None,
|
36 |
+
elastic=None,
|
37 |
+
grad_clip=None,
|
38 |
+
ema_rate=0.9999,
|
39 |
+
fp16_mode='inflat_all',
|
40 |
+
fp16_scale_growth=1e-3,
|
41 |
+
finetune_ckpt=None,
|
42 |
+
log_param_stats=False,
|
43 |
+
prefetch_data=True,
|
44 |
+
i_print=1000,
|
45 |
+
i_log=500,
|
46 |
+
i_sample=10000,
|
47 |
+
i_save=10000,
|
48 |
+
i_ddpcheck=10000,
|
49 |
+
**kwargs
|
50 |
+
):
|
51 |
+
assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.'
|
52 |
+
|
53 |
+
self.models = models
|
54 |
+
self.dataset = dataset
|
55 |
+
self.batch_split = batch_split if batch_split is not None else 1
|
56 |
+
self.max_steps = max_steps
|
57 |
+
self.optimizer_config = optimizer
|
58 |
+
self.lr_scheduler_config = lr_scheduler
|
59 |
+
self.elastic_controller_config = elastic
|
60 |
+
self.grad_clip = grad_clip
|
61 |
+
self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate
|
62 |
+
self.fp16_mode = fp16_mode
|
63 |
+
self.fp16_scale_growth = fp16_scale_growth
|
64 |
+
self.log_param_stats = log_param_stats
|
65 |
+
self.prefetch_data = prefetch_data
|
66 |
+
if self.prefetch_data:
|
67 |
+
self._data_prefetched = None
|
68 |
+
|
69 |
+
self.output_dir = output_dir
|
70 |
+
self.i_print = i_print
|
71 |
+
self.i_log = i_log
|
72 |
+
self.i_sample = i_sample
|
73 |
+
self.i_save = i_save
|
74 |
+
self.i_ddpcheck = i_ddpcheck
|
75 |
+
|
76 |
+
if dist.is_initialized():
|
77 |
+
# Multi-GPU params
|
78 |
+
self.world_size = dist.get_world_size()
|
79 |
+
self.rank = dist.get_rank()
|
80 |
+
self.local_rank = dist.get_rank() % torch.cuda.device_count()
|
81 |
+
self.is_master = self.rank == 0
|
82 |
+
else:
|
83 |
+
# Single-GPU params
|
84 |
+
self.world_size = 1
|
85 |
+
self.rank = 0
|
86 |
+
self.local_rank = 0
|
87 |
+
self.is_master = True
|
88 |
+
|
89 |
+
self.batch_size = batch_size if batch_size_per_gpu is None else batch_size_per_gpu * self.world_size
|
90 |
+
self.batch_size_per_gpu = batch_size_per_gpu if batch_size_per_gpu is not None else batch_size // self.world_size
|
91 |
+
assert self.batch_size % self.world_size == 0, 'Batch size must be divisible by the number of GPUs.'
|
92 |
+
assert self.batch_size_per_gpu % self.batch_split == 0, 'Batch size per GPU must be divisible by batch split.'
|
93 |
+
|
94 |
+
self.init_models_and_more(**kwargs)
|
95 |
+
self.prepare_dataloader(**kwargs)
|
96 |
+
|
97 |
+
# Load checkpoint
|
98 |
+
self.step = 0
|
99 |
+
if load_dir is not None and step is not None:
|
100 |
+
self.load(load_dir, step)
|
101 |
+
elif finetune_ckpt is not None:
|
102 |
+
self.finetune_from(finetune_ckpt)
|
103 |
+
|
104 |
+
if self.is_master:
|
105 |
+
os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True)
|
106 |
+
os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True)
|
107 |
+
self.writer = SummaryWriter(os.path.join(self.output_dir, 'tb_logs'))
|
108 |
+
|
109 |
+
if self.world_size > 1:
|
110 |
+
self.check_ddp()
|
111 |
+
|
112 |
+
if self.is_master:
|
113 |
+
print('\n\nTrainer initialized.')
|
114 |
+
print(self)
|
115 |
+
|
116 |
+
@property
|
117 |
+
def device(self):
|
118 |
+
for _, model in self.models.items():
|
119 |
+
if hasattr(model, 'device'):
|
120 |
+
return model.device
|
121 |
+
return next(list(self.models.values())[0].parameters()).device
|
122 |
+
|
123 |
+
@abstractmethod
|
124 |
+
def init_models_and_more(self, **kwargs):
|
125 |
+
"""
|
126 |
+
Initialize models and more.
|
127 |
+
"""
|
128 |
+
pass
|
129 |
+
|
130 |
+
def prepare_dataloader(self, **kwargs):
|
131 |
+
"""
|
132 |
+
Prepare dataloader.
|
133 |
+
"""
|
134 |
+
self.data_sampler = ResumableSampler(
|
135 |
+
self.dataset,
|
136 |
+
shuffle=True,
|
137 |
+
)
|
138 |
+
self.dataloader = DataLoader(
|
139 |
+
self.dataset,
|
140 |
+
batch_size=self.batch_size_per_gpu,
|
141 |
+
num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
|
142 |
+
pin_memory=True,
|
143 |
+
drop_last=True,
|
144 |
+
persistent_workers=True,
|
145 |
+
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
146 |
+
sampler=self.data_sampler,
|
147 |
+
)
|
148 |
+
self.data_iterator = cycle(self.dataloader)
|
149 |
+
|
150 |
+
@abstractmethod
|
151 |
+
def load(self, load_dir, step=0):
|
152 |
+
"""
|
153 |
+
Load a checkpoint.
|
154 |
+
Should be called by all processes.
|
155 |
+
"""
|
156 |
+
pass
|
157 |
+
|
158 |
+
@abstractmethod
|
159 |
+
def save(self):
|
160 |
+
"""
|
161 |
+
Save a checkpoint.
|
162 |
+
Should be called only by the rank 0 process.
|
163 |
+
"""
|
164 |
+
pass
|
165 |
+
|
166 |
+
@abstractmethod
|
167 |
+
def finetune_from(self, finetune_ckpt):
|
168 |
+
"""
|
169 |
+
Finetune from a checkpoint.
|
170 |
+
Should be called by all processes.
|
171 |
+
"""
|
172 |
+
pass
|
173 |
+
|
174 |
+
@abstractmethod
|
175 |
+
def run_snapshot(self, num_samples, batch_size=4, verbose=False, **kwargs):
|
176 |
+
"""
|
177 |
+
Run a snapshot of the model.
|
178 |
+
"""
|
179 |
+
pass
|
180 |
+
|
181 |
+
@torch.no_grad()
|
182 |
+
def visualize_sample(self, sample):
|
183 |
+
"""
|
184 |
+
Convert a sample to an image.
|
185 |
+
"""
|
186 |
+
if hasattr(self.dataset, 'visualize_sample'):
|
187 |
+
return self.dataset.visualize_sample(sample)
|
188 |
+
else:
|
189 |
+
return sample
|
190 |
+
|
191 |
+
@torch.no_grad()
|
192 |
+
def snapshot_dataset(self, num_samples=100):
|
193 |
+
"""
|
194 |
+
Sample images from the dataset.
|
195 |
+
"""
|
196 |
+
dataloader = torch.utils.data.DataLoader(
|
197 |
+
self.dataset,
|
198 |
+
batch_size=num_samples,
|
199 |
+
num_workers=0,
|
200 |
+
shuffle=True,
|
201 |
+
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
202 |
+
)
|
203 |
+
data = next(iter(dataloader))
|
204 |
+
data = recursive_to_device(data, self.device)
|
205 |
+
vis = self.visualize_sample(data)
|
206 |
+
if isinstance(vis, dict):
|
207 |
+
save_cfg = [(f'dataset_{k}', v) for k, v in vis.items()]
|
208 |
+
else:
|
209 |
+
save_cfg = [('dataset', vis)]
|
210 |
+
for name, image in save_cfg:
|
211 |
+
utils.save_image(
|
212 |
+
image,
|
213 |
+
os.path.join(self.output_dir, 'samples', f'{name}.jpg'),
|
214 |
+
nrow=int(np.sqrt(num_samples)),
|
215 |
+
normalize=True,
|
216 |
+
value_range=self.dataset.value_range,
|
217 |
+
)
|
218 |
+
|
219 |
+
@torch.no_grad()
|
220 |
+
def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False):
|
221 |
+
"""
|
222 |
+
Sample images from the model.
|
223 |
+
NOTE: This function should be called by all processes.
|
224 |
+
"""
|
225 |
+
if self.is_master:
|
226 |
+
print(f'\nSampling {num_samples} images...', end='')
|
227 |
+
|
228 |
+
if suffix is None:
|
229 |
+
suffix = f'step{self.step:07d}'
|
230 |
+
|
231 |
+
# Assign tasks
|
232 |
+
num_samples_per_process = int(np.ceil(num_samples / self.world_size))
|
233 |
+
samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose)
|
234 |
+
|
235 |
+
# Preprocess images
|
236 |
+
for key in list(samples.keys()):
|
237 |
+
if samples[key]['type'] == 'sample':
|
238 |
+
vis = self.visualize_sample(samples[key]['value'])
|
239 |
+
if isinstance(vis, dict):
|
240 |
+
for k, v in vis.items():
|
241 |
+
samples[f'{key}_{k}'] = {'value': v, 'type': 'image'}
|
242 |
+
del samples[key]
|
243 |
+
else:
|
244 |
+
samples[key] = {'value': vis, 'type': 'image'}
|
245 |
+
|
246 |
+
# Gather results
|
247 |
+
if self.world_size > 1:
|
248 |
+
for key in samples.keys():
|
249 |
+
samples[key]['value'] = samples[key]['value'].contiguous()
|
250 |
+
if self.is_master:
|
251 |
+
all_images = [torch.empty_like(samples[key]['value']) for _ in range(self.world_size)]
|
252 |
+
else:
|
253 |
+
all_images = []
|
254 |
+
dist.gather(samples[key]['value'], all_images, dst=0)
|
255 |
+
if self.is_master:
|
256 |
+
samples[key]['value'] = torch.cat(all_images, dim=0)[:num_samples]
|
257 |
+
|
258 |
+
# Save images
|
259 |
+
if self.is_master:
|
260 |
+
os.makedirs(os.path.join(self.output_dir, 'samples', suffix), exist_ok=True)
|
261 |
+
for key in samples.keys():
|
262 |
+
if samples[key]['type'] == 'image':
|
263 |
+
utils.save_image(
|
264 |
+
samples[key]['value'],
|
265 |
+
os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
|
266 |
+
nrow=int(np.sqrt(num_samples)),
|
267 |
+
normalize=True,
|
268 |
+
value_range=self.dataset.value_range,
|
269 |
+
)
|
270 |
+
elif samples[key]['type'] == 'number':
|
271 |
+
min = samples[key]['value'].min()
|
272 |
+
max = samples[key]['value'].max()
|
273 |
+
images = (samples[key]['value'] - min) / (max - min)
|
274 |
+
images = utils.make_grid(
|
275 |
+
images,
|
276 |
+
nrow=int(np.sqrt(num_samples)),
|
277 |
+
normalize=False,
|
278 |
+
)
|
279 |
+
save_image_with_notes(
|
280 |
+
images,
|
281 |
+
os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
|
282 |
+
notes=f'{key} min: {min}, max: {max}',
|
283 |
+
)
|
284 |
+
|
285 |
+
if self.is_master:
|
286 |
+
print(' Done.')
|
287 |
+
|
288 |
+
@abstractmethod
|
289 |
+
def update_ema(self):
|
290 |
+
"""
|
291 |
+
Update exponential moving average.
|
292 |
+
Should only be called by the rank 0 process.
|
293 |
+
"""
|
294 |
+
pass
|
295 |
+
|
296 |
+
@abstractmethod
|
297 |
+
def check_ddp(self):
|
298 |
+
"""
|
299 |
+
Check if DDP is working properly.
|
300 |
+
Should be called by all process.
|
301 |
+
"""
|
302 |
+
pass
|
303 |
+
|
304 |
+
@abstractmethod
|
305 |
+
def training_losses(**mb_data):
|
306 |
+
"""
|
307 |
+
Compute training losses.
|
308 |
+
"""
|
309 |
+
pass
|
310 |
+
|
311 |
+
def load_data(self):
|
312 |
+
"""
|
313 |
+
Load data.
|
314 |
+
"""
|
315 |
+
if self.prefetch_data:
|
316 |
+
if self._data_prefetched is None:
|
317 |
+
self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
|
318 |
+
data = self._data_prefetched
|
319 |
+
self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
|
320 |
+
else:
|
321 |
+
data = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
|
322 |
+
|
323 |
+
# if the data is a dict, we need to split it into multiple dicts with batch_size_per_gpu
|
324 |
+
if isinstance(data, dict):
|
325 |
+
if self.batch_split == 1:
|
326 |
+
data_list = [data]
|
327 |
+
else:
|
328 |
+
batch_size = list(data.values())[0].shape[0]
|
329 |
+
data_list = [
|
330 |
+
{k: v[i * batch_size // self.batch_split:(i + 1) * batch_size // self.batch_split] for k, v in data.items()}
|
331 |
+
for i in range(self.batch_split)
|
332 |
+
]
|
333 |
+
elif isinstance(data, list):
|
334 |
+
data_list = data
|
335 |
+
else:
|
336 |
+
raise ValueError('Data must be a dict or a list of dicts.')
|
337 |
+
|
338 |
+
return data_list
|
339 |
+
|
340 |
+
@abstractmethod
|
341 |
+
def run_step(self, data_list):
|
342 |
+
"""
|
343 |
+
Run a training step.
|
344 |
+
"""
|
345 |
+
pass
|
346 |
+
|
347 |
+
def run(self):
|
348 |
+
"""
|
349 |
+
Run training.
|
350 |
+
"""
|
351 |
+
if self.is_master:
|
352 |
+
print('\nStarting training...')
|
353 |
+
self.snapshot_dataset()
|
354 |
+
if self.step == 0:
|
355 |
+
self.snapshot(suffix='init')
|
356 |
+
else: # resume
|
357 |
+
self.snapshot(suffix=f'resume_step{self.step:07d}')
|
358 |
+
|
359 |
+
log = []
|
360 |
+
time_last_print = 0.0
|
361 |
+
time_elapsed = 0.0
|
362 |
+
while self.step < self.max_steps:
|
363 |
+
time_start = time.time()
|
364 |
+
|
365 |
+
data_list = self.load_data()
|
366 |
+
step_log = self.run_step(data_list)
|
367 |
+
|
368 |
+
time_end = time.time()
|
369 |
+
time_elapsed += time_end - time_start
|
370 |
+
|
371 |
+
self.step += 1
|
372 |
+
|
373 |
+
# Print progress
|
374 |
+
if self.is_master and self.step % self.i_print == 0:
|
375 |
+
speed = self.i_print / (time_elapsed - time_last_print) * 3600
|
376 |
+
columns = [
|
377 |
+
f'Step: {self.step}/{self.max_steps} ({self.step / self.max_steps * 100:.2f}%)',
|
378 |
+
f'Elapsed: {time_elapsed / 3600:.2f} h',
|
379 |
+
f'Speed: {speed:.2f} steps/h',
|
380 |
+
f'ETA: {(self.max_steps - self.step) / speed:.2f} h',
|
381 |
+
]
|
382 |
+
print(' | '.join([c.ljust(25) for c in columns]), flush=True)
|
383 |
+
time_last_print = time_elapsed
|
384 |
+
|
385 |
+
# Check ddp
|
386 |
+
if self.world_size > 1 and self.i_ddpcheck is not None and self.step % self.i_ddpcheck == 0:
|
387 |
+
self.check_ddp()
|
388 |
+
|
389 |
+
# Sample images
|
390 |
+
if self.step % self.i_sample == 0:
|
391 |
+
self.snapshot()
|
392 |
+
|
393 |
+
if self.is_master:
|
394 |
+
log.append((self.step, {}))
|
395 |
+
|
396 |
+
# Log time
|
397 |
+
log[-1][1]['time'] = {
|
398 |
+
'step': time_end - time_start,
|
399 |
+
'elapsed': time_elapsed,
|
400 |
+
}
|
401 |
+
|
402 |
+
# Log losses
|
403 |
+
if step_log is not None:
|
404 |
+
log[-1][1].update(step_log)
|
405 |
+
|
406 |
+
# Log scale
|
407 |
+
if self.fp16_mode == 'amp':
|
408 |
+
log[-1][1]['scale'] = self.scaler.get_scale()
|
409 |
+
elif self.fp16_mode == 'inflat_all':
|
410 |
+
log[-1][1]['log_scale'] = self.log_scale
|
411 |
+
|
412 |
+
# Save log
|
413 |
+
if self.step % self.i_log == 0:
|
414 |
+
## save to log file
|
415 |
+
log_str = '\n'.join([
|
416 |
+
f'{step}: {json.dumps(log)}' for step, log in log
|
417 |
+
])
|
418 |
+
with open(os.path.join(self.output_dir, 'log.txt'), 'a') as log_file:
|
419 |
+
log_file.write(log_str + '\n')
|
420 |
+
|
421 |
+
# show with mlflow
|
422 |
+
log_show = [l for _, l in log if not dict_any(l, lambda x: np.isnan(x))]
|
423 |
+
log_show = dict_reduce(log_show, lambda x: np.mean(x))
|
424 |
+
log_show = dict_flatten(log_show, sep='/')
|
425 |
+
for key, value in log_show.items():
|
426 |
+
self.writer.add_scalar(key, value, self.step)
|
427 |
+
log = []
|
428 |
+
|
429 |
+
# Save checkpoint
|
430 |
+
if self.step % self.i_save == 0:
|
431 |
+
self.save()
|
432 |
+
|
433 |
+
if self.is_master:
|
434 |
+
self.snapshot(suffix='final')
|
435 |
+
self.writer.close()
|
436 |
+
print('Training finished.')
|
437 |
+
|
438 |
+
def profile(self, wait=2, warmup=3, active=5):
|
439 |
+
"""
|
440 |
+
Profile the training loop.
|
441 |
+
"""
|
442 |
+
with torch.profiler.profile(
|
443 |
+
schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
|
444 |
+
on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(self.output_dir, 'profile')),
|
445 |
+
profile_memory=True,
|
446 |
+
with_stack=True,
|
447 |
+
) as prof:
|
448 |
+
for _ in range(wait + warmup + active):
|
449 |
+
self.run_step()
|
450 |
+
prof.step()
|
451 |
+
|
trellis/trainers/basic.py
ADDED
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import copy
|
3 |
+
from functools import partial
|
4 |
+
from contextlib import nullcontext
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.distributed as dist
|
8 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from .utils import *
|
12 |
+
from .base import Trainer
|
13 |
+
from ..utils.general_utils import *
|
14 |
+
from ..utils.dist_utils import *
|
15 |
+
from ..utils import grad_clip_utils, elastic_utils
|
16 |
+
|
17 |
+
|
18 |
+
class BasicTrainer(Trainer):
|
19 |
+
"""
|
20 |
+
Trainer for basic training loop.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
models (dict[str, nn.Module]): Models to train.
|
24 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
25 |
+
output_dir (str): Output directory.
|
26 |
+
load_dir (str): Load directory.
|
27 |
+
step (int): Step to load.
|
28 |
+
batch_size (int): Batch size.
|
29 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
30 |
+
batch_split (int): Split batch with gradient accumulation.
|
31 |
+
max_steps (int): Max steps.
|
32 |
+
optimizer (dict): Optimizer config.
|
33 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
34 |
+
elastic (dict): Elastic memory management config.
|
35 |
+
grad_clip (float or dict): Gradient clip config.
|
36 |
+
ema_rate (float or list): Exponential moving average rates.
|
37 |
+
fp16_mode (str): FP16 mode.
|
38 |
+
- None: No FP16.
|
39 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
40 |
+
- 'amp': Automatic mixed precision.
|
41 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
42 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
43 |
+
log_param_stats (bool): Log parameter stats.
|
44 |
+
i_print (int): Print interval.
|
45 |
+
i_log (int): Log interval.
|
46 |
+
i_sample (int): Sample interval.
|
47 |
+
i_save (int): Save interval.
|
48 |
+
i_ddpcheck (int): DDP check interval.
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __str__(self):
|
52 |
+
lines = []
|
53 |
+
lines.append(self.__class__.__name__)
|
54 |
+
lines.append(f' - Models:')
|
55 |
+
for name, model in self.models.items():
|
56 |
+
lines.append(f' - {name}: {model.__class__.__name__}')
|
57 |
+
lines.append(f' - Dataset: {indent(str(self.dataset), 2)}')
|
58 |
+
lines.append(f' - Dataloader:')
|
59 |
+
lines.append(f' - Sampler: {self.dataloader.sampler.__class__.__name__}')
|
60 |
+
lines.append(f' - Num workers: {self.dataloader.num_workers}')
|
61 |
+
lines.append(f' - Number of steps: {self.max_steps}')
|
62 |
+
lines.append(f' - Number of GPUs: {self.world_size}')
|
63 |
+
lines.append(f' - Batch size: {self.batch_size}')
|
64 |
+
lines.append(f' - Batch size per GPU: {self.batch_size_per_gpu}')
|
65 |
+
lines.append(f' - Batch split: {self.batch_split}')
|
66 |
+
lines.append(f' - Optimizer: {self.optimizer.__class__.__name__}')
|
67 |
+
lines.append(f' - Learning rate: {self.optimizer.param_groups[0]["lr"]}')
|
68 |
+
if self.lr_scheduler_config is not None:
|
69 |
+
lines.append(f' - LR scheduler: {self.lr_scheduler.__class__.__name__}')
|
70 |
+
if self.elastic_controller_config is not None:
|
71 |
+
lines.append(f' - Elastic memory: {indent(str(self.elastic_controller), 2)}')
|
72 |
+
if self.grad_clip is not None:
|
73 |
+
lines.append(f' - Gradient clip: {indent(str(self.grad_clip), 2)}')
|
74 |
+
lines.append(f' - EMA rate: {self.ema_rate}')
|
75 |
+
lines.append(f' - FP16 mode: {self.fp16_mode}')
|
76 |
+
return '\n'.join(lines)
|
77 |
+
|
78 |
+
def init_models_and_more(self, **kwargs):
|
79 |
+
"""
|
80 |
+
Initialize models and more.
|
81 |
+
"""
|
82 |
+
if self.world_size > 1:
|
83 |
+
# Prepare distributed data parallel
|
84 |
+
self.training_models = {
|
85 |
+
name: DDP(
|
86 |
+
model,
|
87 |
+
device_ids=[self.local_rank],
|
88 |
+
output_device=self.local_rank,
|
89 |
+
bucket_cap_mb=128,
|
90 |
+
find_unused_parameters=False
|
91 |
+
)
|
92 |
+
for name, model in self.models.items()
|
93 |
+
}
|
94 |
+
else:
|
95 |
+
self.training_models = self.models
|
96 |
+
|
97 |
+
# Build master params
|
98 |
+
self.model_params = sum(
|
99 |
+
[[p for p in model.parameters() if p.requires_grad] for model in self.models.values()]
|
100 |
+
, [])
|
101 |
+
if self.fp16_mode == 'amp':
|
102 |
+
self.master_params = self.model_params
|
103 |
+
self.scaler = torch.GradScaler() if self.fp16_mode == 'amp' else None
|
104 |
+
elif self.fp16_mode == 'inflat_all':
|
105 |
+
self.master_params = make_master_params(self.model_params)
|
106 |
+
self.fp16_scale_growth = self.fp16_scale_growth
|
107 |
+
self.log_scale = 20.0
|
108 |
+
elif self.fp16_mode is None:
|
109 |
+
self.master_params = self.model_params
|
110 |
+
else:
|
111 |
+
raise NotImplementedError(f'FP16 mode {self.fp16_mode} is not implemented.')
|
112 |
+
|
113 |
+
# Build EMA params
|
114 |
+
if self.is_master:
|
115 |
+
self.ema_params = [copy.deepcopy(self.master_params) for _ in self.ema_rate]
|
116 |
+
|
117 |
+
# Initialize optimizer
|
118 |
+
if hasattr(torch.optim, self.optimizer_config['name']):
|
119 |
+
self.optimizer = getattr(torch.optim, self.optimizer_config['name'])(self.master_params, **self.optimizer_config['args'])
|
120 |
+
else:
|
121 |
+
self.optimizer = globals()[self.optimizer_config['name']](self.master_params, **self.optimizer_config['args'])
|
122 |
+
|
123 |
+
# Initalize learning rate scheduler
|
124 |
+
if self.lr_scheduler_config is not None:
|
125 |
+
if hasattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name']):
|
126 |
+
self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name'])(self.optimizer, **self.lr_scheduler_config['args'])
|
127 |
+
else:
|
128 |
+
self.lr_scheduler = globals()[self.lr_scheduler_config['name']](self.optimizer, **self.lr_scheduler_config['args'])
|
129 |
+
|
130 |
+
# Initialize elastic memory controller
|
131 |
+
if self.elastic_controller_config is not None:
|
132 |
+
assert any([isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)) for model in self.models.values()]), \
|
133 |
+
'No elastic module found in models, please inherit from ElasticModule or ElasticModuleMixin'
|
134 |
+
self.elastic_controller = getattr(elastic_utils, self.elastic_controller_config['name'])(**self.elastic_controller_config['args'])
|
135 |
+
for model in self.models.values():
|
136 |
+
if isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)):
|
137 |
+
model.register_memory_controller(self.elastic_controller)
|
138 |
+
|
139 |
+
# Initialize gradient clipper
|
140 |
+
if self.grad_clip is not None:
|
141 |
+
if isinstance(self.grad_clip, (float, int)):
|
142 |
+
self.grad_clip = float(self.grad_clip)
|
143 |
+
else:
|
144 |
+
self.grad_clip = getattr(grad_clip_utils, self.grad_clip['name'])(**self.grad_clip['args'])
|
145 |
+
|
146 |
+
def _master_params_to_state_dicts(self, master_params):
|
147 |
+
"""
|
148 |
+
Convert master params to dict of state_dicts.
|
149 |
+
"""
|
150 |
+
if self.fp16_mode == 'inflat_all':
|
151 |
+
master_params = unflatten_master_params(self.model_params, master_params)
|
152 |
+
state_dicts = {name: model.state_dict() for name, model in self.models.items()}
|
153 |
+
master_params_names = sum(
|
154 |
+
[[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
|
155 |
+
, [])
|
156 |
+
for i, (model_name, param_name) in enumerate(master_params_names):
|
157 |
+
state_dicts[model_name][param_name] = master_params[i]
|
158 |
+
return state_dicts
|
159 |
+
|
160 |
+
def _state_dicts_to_master_params(self, master_params, state_dicts):
|
161 |
+
"""
|
162 |
+
Convert a state_dict to master params.
|
163 |
+
"""
|
164 |
+
master_params_names = sum(
|
165 |
+
[[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
|
166 |
+
, [])
|
167 |
+
params = [state_dicts[name][param_name] for name, param_name in master_params_names]
|
168 |
+
if self.fp16_mode == 'inflat_all':
|
169 |
+
model_params_to_master_params(params, master_params)
|
170 |
+
else:
|
171 |
+
for i, param in enumerate(params):
|
172 |
+
master_params[i].data.copy_(param.data)
|
173 |
+
|
174 |
+
def load(self, load_dir, step=0):
|
175 |
+
"""
|
176 |
+
Load a checkpoint.
|
177 |
+
Should be called by all processes.
|
178 |
+
"""
|
179 |
+
if self.is_master:
|
180 |
+
print(f'\nLoading checkpoint from step {step}...', end='')
|
181 |
+
|
182 |
+
model_ckpts = {}
|
183 |
+
for name, model in self.models.items():
|
184 |
+
model_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')), map_location=self.device, weights_only=True)
|
185 |
+
model_ckpts[name] = model_ckpt
|
186 |
+
model.load_state_dict(model_ckpt)
|
187 |
+
if self.fp16_mode == 'inflat_all':
|
188 |
+
model.convert_to_fp16()
|
189 |
+
self._state_dicts_to_master_params(self.master_params, model_ckpts)
|
190 |
+
del model_ckpts
|
191 |
+
|
192 |
+
if self.is_master:
|
193 |
+
for i, ema_rate in enumerate(self.ema_rate):
|
194 |
+
ema_ckpts = {}
|
195 |
+
for name, model in self.models.items():
|
196 |
+
ema_ckpt = torch.load(os.path.join(load_dir, 'ckpts', f'{name}_ema{ema_rate}_step{step:07d}.pt'), map_location=self.device, weights_only=True)
|
197 |
+
ema_ckpts[name] = ema_ckpt
|
198 |
+
self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts)
|
199 |
+
del ema_ckpts
|
200 |
+
|
201 |
+
misc_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), map_location=torch.device('cpu'), weights_only=False)
|
202 |
+
self.optimizer.load_state_dict(misc_ckpt['optimizer'])
|
203 |
+
self.step = misc_ckpt['step']
|
204 |
+
self.data_sampler.load_state_dict(misc_ckpt['data_sampler'])
|
205 |
+
if self.fp16_mode == 'amp':
|
206 |
+
self.scaler.load_state_dict(misc_ckpt['scaler'])
|
207 |
+
elif self.fp16_mode == 'inflat_all':
|
208 |
+
self.log_scale = misc_ckpt['log_scale']
|
209 |
+
if self.lr_scheduler_config is not None:
|
210 |
+
self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler'])
|
211 |
+
if self.elastic_controller_config is not None:
|
212 |
+
self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller'])
|
213 |
+
if self.grad_clip is not None and not isinstance(self.grad_clip, float):
|
214 |
+
self.grad_clip.load_state_dict(misc_ckpt['grad_clip'])
|
215 |
+
del misc_ckpt
|
216 |
+
|
217 |
+
if self.world_size > 1:
|
218 |
+
dist.barrier()
|
219 |
+
if self.is_master:
|
220 |
+
print(' Done.')
|
221 |
+
|
222 |
+
if self.world_size > 1:
|
223 |
+
self.check_ddp()
|
224 |
+
|
225 |
+
def save(self):
|
226 |
+
"""
|
227 |
+
Save a checkpoint.
|
228 |
+
Should be called only by the rank 0 process.
|
229 |
+
"""
|
230 |
+
assert self.is_master, 'save() should be called only by the rank 0 process.'
|
231 |
+
print(f'\nSaving checkpoint at step {self.step}...', end='')
|
232 |
+
|
233 |
+
model_ckpts = self._master_params_to_state_dicts(self.master_params)
|
234 |
+
for name, model_ckpt in model_ckpts.items():
|
235 |
+
torch.save(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt'))
|
236 |
+
|
237 |
+
for i, ema_rate in enumerate(self.ema_rate):
|
238 |
+
ema_ckpts = self._master_params_to_state_dicts(self.ema_params[i])
|
239 |
+
for name, ema_ckpt in ema_ckpts.items():
|
240 |
+
torch.save(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt'))
|
241 |
+
|
242 |
+
misc_ckpt = {
|
243 |
+
'optimizer': self.optimizer.state_dict(),
|
244 |
+
'step': self.step,
|
245 |
+
'data_sampler': self.data_sampler.state_dict(),
|
246 |
+
}
|
247 |
+
if self.fp16_mode == 'amp':
|
248 |
+
misc_ckpt['scaler'] = self.scaler.state_dict()
|
249 |
+
elif self.fp16_mode == 'inflat_all':
|
250 |
+
misc_ckpt['log_scale'] = self.log_scale
|
251 |
+
if self.lr_scheduler_config is not None:
|
252 |
+
misc_ckpt['lr_scheduler'] = self.lr_scheduler.state_dict()
|
253 |
+
if self.elastic_controller_config is not None:
|
254 |
+
misc_ckpt['elastic_controller'] = self.elastic_controller.state_dict()
|
255 |
+
if self.grad_clip is not None and not isinstance(self.grad_clip, float):
|
256 |
+
misc_ckpt['grad_clip'] = self.grad_clip.state_dict()
|
257 |
+
torch.save(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt'))
|
258 |
+
print(' Done.')
|
259 |
+
|
260 |
+
def finetune_from(self, finetune_ckpt):
|
261 |
+
"""
|
262 |
+
Finetune from a checkpoint.
|
263 |
+
Should be called by all processes.
|
264 |
+
"""
|
265 |
+
if self.is_master:
|
266 |
+
print('\nFinetuning from:')
|
267 |
+
for name, path in finetune_ckpt.items():
|
268 |
+
print(f' - {name}: {path}')
|
269 |
+
|
270 |
+
model_ckpts = {}
|
271 |
+
for name, model in self.models.items():
|
272 |
+
model_state_dict = model.state_dict()
|
273 |
+
if name in finetune_ckpt:
|
274 |
+
model_ckpt = torch.load(read_file_dist(finetune_ckpt[name]), map_location=self.device, weights_only=True)
|
275 |
+
for k, v in model_ckpt.items():
|
276 |
+
if model_ckpt[k].shape != model_state_dict[k].shape:
|
277 |
+
if self.is_master:
|
278 |
+
print(f'Warning: {k} shape mismatch, {model_ckpt[k].shape} vs {model_state_dict[k].shape}, skipped.')
|
279 |
+
model_ckpt[k] = model_state_dict[k]
|
280 |
+
model_ckpts[name] = model_ckpt
|
281 |
+
model.load_state_dict(model_ckpt)
|
282 |
+
if self.fp16_mode == 'inflat_all':
|
283 |
+
model.convert_to_fp16()
|
284 |
+
else:
|
285 |
+
if self.is_master:
|
286 |
+
print(f'Warning: {name} not found in finetune_ckpt, skipped.')
|
287 |
+
model_ckpts[name] = model_state_dict
|
288 |
+
self._state_dicts_to_master_params(self.master_params, model_ckpts)
|
289 |
+
del model_ckpts
|
290 |
+
|
291 |
+
if self.world_size > 1:
|
292 |
+
dist.barrier()
|
293 |
+
if self.is_master:
|
294 |
+
print('Done.')
|
295 |
+
|
296 |
+
if self.world_size > 1:
|
297 |
+
self.check_ddp()
|
298 |
+
|
299 |
+
def update_ema(self):
|
300 |
+
"""
|
301 |
+
Update exponential moving average.
|
302 |
+
Should only be called by the rank 0 process.
|
303 |
+
"""
|
304 |
+
assert self.is_master, 'update_ema() should be called only by the rank 0 process.'
|
305 |
+
for i, ema_rate in enumerate(self.ema_rate):
|
306 |
+
for master_param, ema_param in zip(self.master_params, self.ema_params[i]):
|
307 |
+
ema_param.detach().mul_(ema_rate).add_(master_param, alpha=1.0 - ema_rate)
|
308 |
+
|
309 |
+
def check_ddp(self):
|
310 |
+
"""
|
311 |
+
Check if DDP is working properly.
|
312 |
+
Should be called by all process.
|
313 |
+
"""
|
314 |
+
if self.is_master:
|
315 |
+
print('\nPerforming DDP check...')
|
316 |
+
|
317 |
+
if self.is_master:
|
318 |
+
print('Checking if parameters are consistent across processes...')
|
319 |
+
dist.barrier()
|
320 |
+
try:
|
321 |
+
for p in self.master_params:
|
322 |
+
# split to avoid OOM
|
323 |
+
for i in range(0, p.numel(), 10000000):
|
324 |
+
sub_size = min(10000000, p.numel() - i)
|
325 |
+
sub_p = p.detach().view(-1)[i:i+sub_size]
|
326 |
+
# gather from all processes
|
327 |
+
sub_p_gather = [torch.empty_like(sub_p) for _ in range(self.world_size)]
|
328 |
+
dist.all_gather(sub_p_gather, sub_p)
|
329 |
+
# check if equal
|
330 |
+
assert all([torch.equal(sub_p, sub_p_gather[i]) for i in range(self.world_size)]), 'parameters are not consistent across processes'
|
331 |
+
except AssertionError as e:
|
332 |
+
if self.is_master:
|
333 |
+
print(f'\n\033[91mError: {e}\033[0m')
|
334 |
+
print('DDP check failed.')
|
335 |
+
raise e
|
336 |
+
|
337 |
+
dist.barrier()
|
338 |
+
if self.is_master:
|
339 |
+
print('Done.')
|
340 |
+
|
341 |
+
def run_step(self, data_list):
|
342 |
+
"""
|
343 |
+
Run a training step.
|
344 |
+
"""
|
345 |
+
step_log = {'loss': {}, 'status': {}}
|
346 |
+
amp_context = partial(torch.autocast, device_type='cuda') if self.fp16_mode == 'amp' else nullcontext
|
347 |
+
elastic_controller_context = self.elastic_controller.record if self.elastic_controller_config is not None else nullcontext
|
348 |
+
|
349 |
+
# Train
|
350 |
+
losses = []
|
351 |
+
statuses = []
|
352 |
+
elastic_controller_logs = []
|
353 |
+
zero_grad(self.model_params)
|
354 |
+
for i, mb_data in enumerate(data_list):
|
355 |
+
## sync at the end of each batch split
|
356 |
+
sync_contexts = [self.training_models[name].no_sync for name in self.training_models] if i != len(data_list) - 1 and self.world_size > 1 else [nullcontext]
|
357 |
+
with nested_contexts(*sync_contexts), elastic_controller_context():
|
358 |
+
with amp_context():
|
359 |
+
loss, status = self.training_losses(**mb_data)
|
360 |
+
l = loss['loss'] / len(data_list)
|
361 |
+
## backward
|
362 |
+
if self.fp16_mode == 'amp':
|
363 |
+
self.scaler.scale(l).backward()
|
364 |
+
elif self.fp16_mode == 'inflat_all':
|
365 |
+
scaled_l = l * (2 ** self.log_scale)
|
366 |
+
scaled_l.backward()
|
367 |
+
else:
|
368 |
+
l.backward()
|
369 |
+
## log
|
370 |
+
losses.append(dict_foreach(loss, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
|
371 |
+
statuses.append(dict_foreach(status, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
|
372 |
+
if self.elastic_controller_config is not None:
|
373 |
+
elastic_controller_logs.append(self.elastic_controller.log())
|
374 |
+
## gradient clip
|
375 |
+
if self.grad_clip is not None:
|
376 |
+
if self.fp16_mode == 'amp':
|
377 |
+
self.scaler.unscale_(self.optimizer)
|
378 |
+
elif self.fp16_mode == 'inflat_all':
|
379 |
+
model_grads_to_master_grads(self.model_params, self.master_params)
|
380 |
+
self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
|
381 |
+
if isinstance(self.grad_clip, float):
|
382 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params, self.grad_clip)
|
383 |
+
else:
|
384 |
+
grad_norm = self.grad_clip(self.master_params)
|
385 |
+
if torch.isfinite(grad_norm):
|
386 |
+
statuses[-1]['grad_norm'] = grad_norm.item()
|
387 |
+
## step
|
388 |
+
if self.fp16_mode == 'amp':
|
389 |
+
prev_scale = self.scaler.get_scale()
|
390 |
+
self.scaler.step(self.optimizer)
|
391 |
+
self.scaler.update()
|
392 |
+
elif self.fp16_mode == 'inflat_all':
|
393 |
+
prev_scale = 2 ** self.log_scale
|
394 |
+
if not any(not p.grad.isfinite().all() for p in self.model_params):
|
395 |
+
if self.grad_clip is None:
|
396 |
+
model_grads_to_master_grads(self.model_params, self.master_params)
|
397 |
+
self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
|
398 |
+
self.optimizer.step()
|
399 |
+
master_params_to_model_params(self.model_params, self.master_params)
|
400 |
+
self.log_scale += self.fp16_scale_growth
|
401 |
+
else:
|
402 |
+
self.log_scale -= 1
|
403 |
+
else:
|
404 |
+
prev_scale = 1.0
|
405 |
+
if not any(not p.grad.isfinite().all() for p in self.model_params):
|
406 |
+
self.optimizer.step()
|
407 |
+
else:
|
408 |
+
print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m')
|
409 |
+
## adjust learning rate
|
410 |
+
if self.lr_scheduler_config is not None:
|
411 |
+
statuses[-1]['lr'] = self.lr_scheduler.get_last_lr()[0]
|
412 |
+
self.lr_scheduler.step()
|
413 |
+
|
414 |
+
# Logs
|
415 |
+
step_log['loss'] = dict_reduce(losses, lambda x: np.mean(x))
|
416 |
+
step_log['status'] = dict_reduce(statuses, lambda x: np.mean(x), special_func={'min': lambda x: np.min(x), 'max': lambda x: np.max(x)})
|
417 |
+
if self.elastic_controller_config is not None:
|
418 |
+
step_log['elastic'] = dict_reduce(elastic_controller_logs, lambda x: np.mean(x))
|
419 |
+
if self.grad_clip is not None:
|
420 |
+
step_log['grad_clip'] = self.grad_clip if isinstance(self.grad_clip, float) else self.grad_clip.log()
|
421 |
+
|
422 |
+
# Check grad and norm of each param
|
423 |
+
if self.log_param_stats:
|
424 |
+
param_norms = {}
|
425 |
+
param_grads = {}
|
426 |
+
for name, param in self.backbone.named_parameters():
|
427 |
+
if param.requires_grad:
|
428 |
+
param_norms[name] = param.norm().item()
|
429 |
+
if param.grad is not None and torch.isfinite(param.grad).all():
|
430 |
+
param_grads[name] = param.grad.norm().item() / prev_scale
|
431 |
+
step_log['param_norms'] = param_norms
|
432 |
+
step_log['param_grads'] = param_grads
|
433 |
+
|
434 |
+
# Update exponential moving average
|
435 |
+
if self.is_master:
|
436 |
+
self.update_ema()
|
437 |
+
|
438 |
+
return step_log
|
trellis/trainers/flow_matching/flow_matching.py
ADDED
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import copy
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
import numpy as np
|
7 |
+
from easydict import EasyDict as edict
|
8 |
+
|
9 |
+
from ..basic import BasicTrainer
|
10 |
+
from ...pipelines import samplers
|
11 |
+
from ...utils.general_utils import dict_reduce
|
12 |
+
from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin
|
13 |
+
from .mixins.text_conditioned import TextConditionedMixin
|
14 |
+
from .mixins.image_conditioned import ImageConditionedMixin
|
15 |
+
|
16 |
+
|
17 |
+
class FlowMatchingTrainer(BasicTrainer):
|
18 |
+
"""
|
19 |
+
Trainer for diffusion model with flow matching objective.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
models (dict[str, nn.Module]): Models to train.
|
23 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
24 |
+
output_dir (str): Output directory.
|
25 |
+
load_dir (str): Load directory.
|
26 |
+
step (int): Step to load.
|
27 |
+
batch_size (int): Batch size.
|
28 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
29 |
+
batch_split (int): Split batch with gradient accumulation.
|
30 |
+
max_steps (int): Max steps.
|
31 |
+
optimizer (dict): Optimizer config.
|
32 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
33 |
+
elastic (dict): Elastic memory management config.
|
34 |
+
grad_clip (float or dict): Gradient clip config.
|
35 |
+
ema_rate (float or list): Exponential moving average rates.
|
36 |
+
fp16_mode (str): FP16 mode.
|
37 |
+
- None: No FP16.
|
38 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
39 |
+
- 'amp': Automatic mixed precision.
|
40 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
41 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
42 |
+
log_param_stats (bool): Log parameter stats.
|
43 |
+
i_print (int): Print interval.
|
44 |
+
i_log (int): Log interval.
|
45 |
+
i_sample (int): Sample interval.
|
46 |
+
i_save (int): Save interval.
|
47 |
+
i_ddpcheck (int): DDP check interval.
|
48 |
+
|
49 |
+
t_schedule (dict): Time schedule for flow matching.
|
50 |
+
sigma_min (float): Minimum noise level.
|
51 |
+
"""
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
*args,
|
55 |
+
t_schedule: dict = {
|
56 |
+
'name': 'logitNormal',
|
57 |
+
'args': {
|
58 |
+
'mean': 0.0,
|
59 |
+
'std': 1.0,
|
60 |
+
}
|
61 |
+
},
|
62 |
+
sigma_min: float = 1e-5,
|
63 |
+
**kwargs
|
64 |
+
):
|
65 |
+
super().__init__(*args, **kwargs)
|
66 |
+
self.t_schedule = t_schedule
|
67 |
+
self.sigma_min = sigma_min
|
68 |
+
|
69 |
+
def diffuse(self, x_0: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
|
70 |
+
"""
|
71 |
+
Diffuse the data for a given number of diffusion steps.
|
72 |
+
In other words, sample from q(x_t | x_0).
|
73 |
+
|
74 |
+
Args:
|
75 |
+
x_0: The [N x C x ...] tensor of noiseless inputs.
|
76 |
+
t: The [N] tensor of diffusion steps [0-1].
|
77 |
+
noise: If specified, use this noise instead of generating new noise.
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
x_t, the noisy version of x_0 under timestep t.
|
81 |
+
"""
|
82 |
+
if noise is None:
|
83 |
+
noise = torch.randn_like(x_0)
|
84 |
+
assert noise.shape == x_0.shape, "noise must have same shape as x_0"
|
85 |
+
|
86 |
+
t = t.view(-1, *[1 for _ in range(len(x_0.shape) - 1)])
|
87 |
+
x_t = (1 - t) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t) * noise
|
88 |
+
|
89 |
+
return x_t
|
90 |
+
|
91 |
+
def reverse_diffuse(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
|
92 |
+
"""
|
93 |
+
Get original image from noisy version under timestep t.
|
94 |
+
"""
|
95 |
+
assert noise.shape == x_t.shape, "noise must have same shape as x_t"
|
96 |
+
t = t.view(-1, *[1 for _ in range(len(x_t.shape) - 1)])
|
97 |
+
x_0 = (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * noise) / (1 - t)
|
98 |
+
return x_0
|
99 |
+
|
100 |
+
def get_v(self, x_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
101 |
+
"""
|
102 |
+
Compute the velocity of the diffusion process at time t.
|
103 |
+
"""
|
104 |
+
return (1 - self.sigma_min) * noise - x_0
|
105 |
+
|
106 |
+
def get_cond(self, cond, **kwargs):
|
107 |
+
"""
|
108 |
+
Get the conditioning data.
|
109 |
+
"""
|
110 |
+
return cond
|
111 |
+
|
112 |
+
def get_inference_cond(self, cond, **kwargs):
|
113 |
+
"""
|
114 |
+
Get the conditioning data for inference.
|
115 |
+
"""
|
116 |
+
return {'cond': cond, **kwargs}
|
117 |
+
|
118 |
+
def get_sampler(self, **kwargs) -> samplers.FlowEulerSampler:
|
119 |
+
"""
|
120 |
+
Get the sampler for the diffusion process.
|
121 |
+
"""
|
122 |
+
return samplers.FlowEulerSampler(self.sigma_min)
|
123 |
+
|
124 |
+
def vis_cond(self, **kwargs):
|
125 |
+
"""
|
126 |
+
Visualize the conditioning data.
|
127 |
+
"""
|
128 |
+
return {}
|
129 |
+
|
130 |
+
def sample_t(self, batch_size: int) -> torch.Tensor:
|
131 |
+
"""
|
132 |
+
Sample timesteps.
|
133 |
+
"""
|
134 |
+
if self.t_schedule['name'] == 'uniform':
|
135 |
+
t = torch.rand(batch_size)
|
136 |
+
elif self.t_schedule['name'] == 'logitNormal':
|
137 |
+
mean = self.t_schedule['args']['mean']
|
138 |
+
std = self.t_schedule['args']['std']
|
139 |
+
t = torch.sigmoid(torch.randn(batch_size) * std + mean)
|
140 |
+
else:
|
141 |
+
raise ValueError(f"Unknown t_schedule: {self.t_schedule['name']}")
|
142 |
+
return t
|
143 |
+
|
144 |
+
def training_losses(
|
145 |
+
self,
|
146 |
+
x_0: torch.Tensor,
|
147 |
+
cond=None,
|
148 |
+
**kwargs
|
149 |
+
) -> Tuple[Dict, Dict]:
|
150 |
+
"""
|
151 |
+
Compute training losses for a single timestep.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
x_0: The [N x C x ...] tensor of noiseless inputs.
|
155 |
+
cond: The [N x ...] tensor of additional conditions.
|
156 |
+
kwargs: Additional arguments to pass to the backbone.
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
a dict with the key "loss" containing a tensor of shape [N].
|
160 |
+
may also contain other keys for different terms.
|
161 |
+
"""
|
162 |
+
noise = torch.randn_like(x_0)
|
163 |
+
t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
|
164 |
+
x_t = self.diffuse(x_0, t, noise=noise)
|
165 |
+
cond = self.get_cond(cond, **kwargs)
|
166 |
+
|
167 |
+
pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
|
168 |
+
assert pred.shape == noise.shape == x_0.shape
|
169 |
+
target = self.get_v(x_0, noise, t)
|
170 |
+
terms = edict()
|
171 |
+
terms["mse"] = F.mse_loss(pred, target)
|
172 |
+
terms["loss"] = terms["mse"]
|
173 |
+
|
174 |
+
# log loss with time bins
|
175 |
+
mse_per_instance = np.array([
|
176 |
+
F.mse_loss(pred[i], target[i]).item()
|
177 |
+
for i in range(x_0.shape[0])
|
178 |
+
])
|
179 |
+
time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
|
180 |
+
for i in range(10):
|
181 |
+
if (time_bin == i).sum() != 0:
|
182 |
+
terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
|
183 |
+
|
184 |
+
return terms, {}
|
185 |
+
|
186 |
+
@torch.no_grad()
|
187 |
+
def run_snapshot(
|
188 |
+
self,
|
189 |
+
num_samples: int,
|
190 |
+
batch_size: int,
|
191 |
+
verbose: bool = False,
|
192 |
+
) -> Dict:
|
193 |
+
dataloader = DataLoader(
|
194 |
+
copy.deepcopy(self.dataset),
|
195 |
+
batch_size=batch_size,
|
196 |
+
shuffle=True,
|
197 |
+
num_workers=0,
|
198 |
+
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
199 |
+
)
|
200 |
+
|
201 |
+
# inference
|
202 |
+
sampler = self.get_sampler()
|
203 |
+
sample_gt = []
|
204 |
+
sample = []
|
205 |
+
cond_vis = []
|
206 |
+
for i in range(0, num_samples, batch_size):
|
207 |
+
batch = min(batch_size, num_samples - i)
|
208 |
+
data = next(iter(dataloader))
|
209 |
+
data = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
|
210 |
+
noise = torch.randn_like(data['x_0'])
|
211 |
+
sample_gt.append(data['x_0'])
|
212 |
+
cond_vis.append(self.vis_cond(**data))
|
213 |
+
del data['x_0']
|
214 |
+
args = self.get_inference_cond(**data)
|
215 |
+
res = sampler.sample(
|
216 |
+
self.models['denoiser'],
|
217 |
+
noise=noise,
|
218 |
+
**args,
|
219 |
+
steps=50, cfg_strength=3.0, verbose=verbose,
|
220 |
+
)
|
221 |
+
sample.append(res.samples)
|
222 |
+
|
223 |
+
sample_gt = torch.cat(sample_gt, dim=0)
|
224 |
+
sample = torch.cat(sample, dim=0)
|
225 |
+
sample_dict = {
|
226 |
+
'sample_gt': {'value': sample_gt, 'type': 'sample'},
|
227 |
+
'sample': {'value': sample, 'type': 'sample'},
|
228 |
+
}
|
229 |
+
sample_dict.update(dict_reduce(cond_vis, None, {
|
230 |
+
'value': lambda x: torch.cat(x, dim=0),
|
231 |
+
'type': lambda x: x[0],
|
232 |
+
}))
|
233 |
+
|
234 |
+
return sample_dict
|
235 |
+
|
236 |
+
|
237 |
+
class FlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, FlowMatchingTrainer):
|
238 |
+
"""
|
239 |
+
Trainer for diffusion model with flow matching objective and classifier-free guidance.
|
240 |
+
|
241 |
+
Args:
|
242 |
+
models (dict[str, nn.Module]): Models to train.
|
243 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
244 |
+
output_dir (str): Output directory.
|
245 |
+
load_dir (str): Load directory.
|
246 |
+
step (int): Step to load.
|
247 |
+
batch_size (int): Batch size.
|
248 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
249 |
+
batch_split (int): Split batch with gradient accumulation.
|
250 |
+
max_steps (int): Max steps.
|
251 |
+
optimizer (dict): Optimizer config.
|
252 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
253 |
+
elastic (dict): Elastic memory management config.
|
254 |
+
grad_clip (float or dict): Gradient clip config.
|
255 |
+
ema_rate (float or list): Exponential moving average rates.
|
256 |
+
fp16_mode (str): FP16 mode.
|
257 |
+
- None: No FP16.
|
258 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
259 |
+
- 'amp': Automatic mixed precision.
|
260 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
261 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
262 |
+
log_param_stats (bool): Log parameter stats.
|
263 |
+
i_print (int): Print interval.
|
264 |
+
i_log (int): Log interval.
|
265 |
+
i_sample (int): Sample interval.
|
266 |
+
i_save (int): Save interval.
|
267 |
+
i_ddpcheck (int): DDP check interval.
|
268 |
+
|
269 |
+
t_schedule (dict): Time schedule for flow matching.
|
270 |
+
sigma_min (float): Minimum noise level.
|
271 |
+
p_uncond (float): Probability of dropping conditions.
|
272 |
+
"""
|
273 |
+
pass
|
274 |
+
|
275 |
+
|
276 |
+
class TextConditionedFlowMatchingCFGTrainer(TextConditionedMixin, FlowMatchingCFGTrainer):
|
277 |
+
"""
|
278 |
+
Trainer for text-conditioned diffusion model with flow matching objective and classifier-free guidance.
|
279 |
+
|
280 |
+
Args:
|
281 |
+
models (dict[str, nn.Module]): Models to train.
|
282 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
283 |
+
output_dir (str): Output directory.
|
284 |
+
load_dir (str): Load directory.
|
285 |
+
step (int): Step to load.
|
286 |
+
batch_size (int): Batch size.
|
287 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
288 |
+
batch_split (int): Split batch with gradient accumulation.
|
289 |
+
max_steps (int): Max steps.
|
290 |
+
optimizer (dict): Optimizer config.
|
291 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
292 |
+
elastic (dict): Elastic memory management config.
|
293 |
+
grad_clip (float or dict): Gradient clip config.
|
294 |
+
ema_rate (float or list): Exponential moving average rates.
|
295 |
+
fp16_mode (str): FP16 mode.
|
296 |
+
- None: No FP16.
|
297 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
298 |
+
- 'amp': Automatic mixed precision.
|
299 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
300 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
301 |
+
log_param_stats (bool): Log parameter stats.
|
302 |
+
i_print (int): Print interval.
|
303 |
+
i_log (int): Log interval.
|
304 |
+
i_sample (int): Sample interval.
|
305 |
+
i_save (int): Save interval.
|
306 |
+
i_ddpcheck (int): DDP check interval.
|
307 |
+
|
308 |
+
t_schedule (dict): Time schedule for flow matching.
|
309 |
+
sigma_min (float): Minimum noise level.
|
310 |
+
p_uncond (float): Probability of dropping conditions.
|
311 |
+
text_cond_model(str): Text conditioning model.
|
312 |
+
"""
|
313 |
+
pass
|
314 |
+
|
315 |
+
|
316 |
+
class ImageConditionedFlowMatchingCFGTrainer(ImageConditionedMixin, FlowMatchingCFGTrainer):
|
317 |
+
"""
|
318 |
+
Trainer for image-conditioned diffusion model with flow matching objective and classifier-free guidance.
|
319 |
+
|
320 |
+
Args:
|
321 |
+
models (dict[str, nn.Module]): Models to train.
|
322 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
323 |
+
output_dir (str): Output directory.
|
324 |
+
load_dir (str): Load directory.
|
325 |
+
step (int): Step to load.
|
326 |
+
batch_size (int): Batch size.
|
327 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
328 |
+
batch_split (int): Split batch with gradient accumulation.
|
329 |
+
max_steps (int): Max steps.
|
330 |
+
optimizer (dict): Optimizer config.
|
331 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
332 |
+
elastic (dict): Elastic memory management config.
|
333 |
+
grad_clip (float or dict): Gradient clip config.
|
334 |
+
ema_rate (float or list): Exponential moving average rates.
|
335 |
+
fp16_mode (str): FP16 mode.
|
336 |
+
- None: No FP16.
|
337 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
338 |
+
- 'amp': Automatic mixed precision.
|
339 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
340 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
341 |
+
log_param_stats (bool): Log parameter stats.
|
342 |
+
i_print (int): Print interval.
|
343 |
+
i_log (int): Log interval.
|
344 |
+
i_sample (int): Sample interval.
|
345 |
+
i_save (int): Save interval.
|
346 |
+
i_ddpcheck (int): DDP check interval.
|
347 |
+
|
348 |
+
t_schedule (dict): Time schedule for flow matching.
|
349 |
+
sigma_min (float): Minimum noise level.
|
350 |
+
p_uncond (float): Probability of dropping conditions.
|
351 |
+
image_cond_model (str): Image conditioning model.
|
352 |
+
"""
|
353 |
+
pass
|
trellis/trainers/flow_matching/mixins/classifier_free_guidance.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from ....utils.general_utils import dict_foreach
|
4 |
+
from ....pipelines import samplers
|
5 |
+
|
6 |
+
|
7 |
+
class ClassifierFreeGuidanceMixin:
|
8 |
+
def __init__(self, *args, p_uncond: float = 0.1, **kwargs):
|
9 |
+
super().__init__(*args, **kwargs)
|
10 |
+
self.p_uncond = p_uncond
|
11 |
+
|
12 |
+
def get_cond(self, cond, neg_cond=None, **kwargs):
|
13 |
+
"""
|
14 |
+
Get the conditioning data.
|
15 |
+
"""
|
16 |
+
assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
|
17 |
+
|
18 |
+
if self.p_uncond > 0:
|
19 |
+
# randomly drop the class label
|
20 |
+
def get_batch_size(cond):
|
21 |
+
if isinstance(cond, torch.Tensor):
|
22 |
+
return cond.shape[0]
|
23 |
+
elif isinstance(cond, list):
|
24 |
+
return len(cond)
|
25 |
+
else:
|
26 |
+
raise ValueError(f"Unsupported type of cond: {type(cond)}")
|
27 |
+
|
28 |
+
ref_cond = cond if not isinstance(cond, dict) else cond[list(cond.keys())[0]]
|
29 |
+
B = get_batch_size(ref_cond)
|
30 |
+
|
31 |
+
def select(cond, neg_cond, mask):
|
32 |
+
if isinstance(cond, torch.Tensor):
|
33 |
+
mask = torch.tensor(mask, device=cond.device).reshape(-1, *[1] * (cond.ndim - 1))
|
34 |
+
return torch.where(mask, neg_cond, cond)
|
35 |
+
elif isinstance(cond, list):
|
36 |
+
return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)]
|
37 |
+
else:
|
38 |
+
raise ValueError(f"Unsupported type of cond: {type(cond)}")
|
39 |
+
|
40 |
+
mask = list(np.random.rand(B) < self.p_uncond)
|
41 |
+
if not isinstance(cond, dict):
|
42 |
+
cond = select(cond, neg_cond, mask)
|
43 |
+
else:
|
44 |
+
cond = dict_foreach([cond, neg_cond], lambda x: select(x[0], x[1], mask))
|
45 |
+
|
46 |
+
return cond
|
47 |
+
|
48 |
+
def get_inference_cond(self, cond, neg_cond=None, **kwargs):
|
49 |
+
"""
|
50 |
+
Get the conditioning data for inference.
|
51 |
+
"""
|
52 |
+
assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
|
53 |
+
return {'cond': cond, 'neg_cond': neg_cond, **kwargs}
|
54 |
+
|
55 |
+
def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler:
|
56 |
+
"""
|
57 |
+
Get the sampler for the diffusion process.
|
58 |
+
"""
|
59 |
+
return samplers.FlowEulerCfgSampler(self.sigma_min)
|
trellis/trainers/flow_matching/mixins/image_conditioned.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torchvision import transforms
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
from ....utils import dist_utils
|
9 |
+
|
10 |
+
|
11 |
+
class ImageConditionedMixin:
|
12 |
+
"""
|
13 |
+
Mixin for image-conditioned models.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
image_cond_model: The image conditioning model.
|
17 |
+
"""
|
18 |
+
def __init__(self, *args, image_cond_model: str = 'dinov2_vitl14_reg', **kwargs):
|
19 |
+
super().__init__(*args, **kwargs)
|
20 |
+
self.image_cond_model_name = image_cond_model
|
21 |
+
self.image_cond_model = None # the model is init lazily
|
22 |
+
|
23 |
+
@staticmethod
|
24 |
+
def prepare_for_training(image_cond_model: str, **kwargs):
|
25 |
+
"""
|
26 |
+
Prepare for training.
|
27 |
+
"""
|
28 |
+
if hasattr(super(ImageConditionedMixin, ImageConditionedMixin), 'prepare_for_training'):
|
29 |
+
super(ImageConditionedMixin, ImageConditionedMixin).prepare_for_training(**kwargs)
|
30 |
+
# download the model
|
31 |
+
torch.hub.load('facebookresearch/dinov2', image_cond_model, pretrained=True)
|
32 |
+
|
33 |
+
def _init_image_cond_model(self):
|
34 |
+
"""
|
35 |
+
Initialize the image conditioning model.
|
36 |
+
"""
|
37 |
+
with dist_utils.local_master_first():
|
38 |
+
dinov2_model = torch.hub.load('facebookresearch/dinov2', self.image_cond_model_name, pretrained=True)
|
39 |
+
dinov2_model.eval().cuda()
|
40 |
+
transform = transforms.Compose([
|
41 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
42 |
+
])
|
43 |
+
self.image_cond_model = {
|
44 |
+
'model': dinov2_model,
|
45 |
+
'transform': transform,
|
46 |
+
}
|
47 |
+
|
48 |
+
@torch.no_grad()
|
49 |
+
def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
|
50 |
+
"""
|
51 |
+
Encode the image.
|
52 |
+
"""
|
53 |
+
if isinstance(image, torch.Tensor):
|
54 |
+
assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
|
55 |
+
elif isinstance(image, list):
|
56 |
+
assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
|
57 |
+
image = [i.resize((518, 518), Image.LANCZOS) for i in image]
|
58 |
+
image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
|
59 |
+
image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
|
60 |
+
image = torch.stack(image).cuda()
|
61 |
+
else:
|
62 |
+
raise ValueError(f"Unsupported type of image: {type(image)}")
|
63 |
+
|
64 |
+
if self.image_cond_model is None:
|
65 |
+
self._init_image_cond_model()
|
66 |
+
image = self.image_cond_model['transform'](image).cuda()
|
67 |
+
features = self.image_cond_model['model'](image, is_training=True)['x_prenorm']
|
68 |
+
patchtokens = F.layer_norm(features, features.shape[-1:])
|
69 |
+
return patchtokens
|
70 |
+
|
71 |
+
def get_cond(self, cond, **kwargs):
|
72 |
+
"""
|
73 |
+
Get the conditioning data.
|
74 |
+
"""
|
75 |
+
cond = self.encode_image(cond)
|
76 |
+
kwargs['neg_cond'] = torch.zeros_like(cond)
|
77 |
+
cond = super().get_cond(cond, **kwargs)
|
78 |
+
return cond
|
79 |
+
|
80 |
+
def get_inference_cond(self, cond, **kwargs):
|
81 |
+
"""
|
82 |
+
Get the conditioning data for inference.
|
83 |
+
"""
|
84 |
+
cond = self.encode_image(cond)
|
85 |
+
kwargs['neg_cond'] = torch.zeros_like(cond)
|
86 |
+
cond = super().get_inference_cond(cond, **kwargs)
|
87 |
+
return cond
|
88 |
+
|
89 |
+
def vis_cond(self, cond, **kwargs):
|
90 |
+
"""
|
91 |
+
Visualize the conditioning data.
|
92 |
+
"""
|
93 |
+
return {'image': {'value': cond, 'type': 'image'}}
|
trellis/trainers/flow_matching/mixins/text_conditioned.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import os
|
3 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
|
4 |
+
import torch
|
5 |
+
from transformers import AutoTokenizer, CLIPTextModel
|
6 |
+
|
7 |
+
from ....utils import dist_utils
|
8 |
+
|
9 |
+
|
10 |
+
class TextConditionedMixin:
|
11 |
+
"""
|
12 |
+
Mixin for text-conditioned models.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
text_cond_model: The text conditioning model.
|
16 |
+
"""
|
17 |
+
def __init__(self, *args, text_cond_model: str = 'openai/clip-vit-large-patch14', **kwargs):
|
18 |
+
super().__init__(*args, **kwargs)
|
19 |
+
self.text_cond_model_name = text_cond_model
|
20 |
+
self.text_cond_model = None # the model is init lazily
|
21 |
+
|
22 |
+
def _init_text_cond_model(self):
|
23 |
+
"""
|
24 |
+
Initialize the text conditioning model.
|
25 |
+
"""
|
26 |
+
# load model
|
27 |
+
with dist_utils.local_master_first():
|
28 |
+
model = CLIPTextModel.from_pretrained(self.text_cond_model_name)
|
29 |
+
tokenizer = AutoTokenizer.from_pretrained(self.text_cond_model_name)
|
30 |
+
model.eval()
|
31 |
+
model = model.cuda()
|
32 |
+
self.text_cond_model = {
|
33 |
+
'model': model,
|
34 |
+
'tokenizer': tokenizer,
|
35 |
+
}
|
36 |
+
self.text_cond_model['null_cond'] = self.encode_text([''])
|
37 |
+
|
38 |
+
@torch.no_grad()
|
39 |
+
def encode_text(self, text: List[str]) -> torch.Tensor:
|
40 |
+
"""
|
41 |
+
Encode the text.
|
42 |
+
"""
|
43 |
+
assert isinstance(text, list) and isinstance(text[0], str), "TextConditionedMixin only supports list of strings as cond"
|
44 |
+
if self.text_cond_model is None:
|
45 |
+
self._init_text_cond_model()
|
46 |
+
encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
|
47 |
+
tokens = encoding['input_ids'].cuda()
|
48 |
+
embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
|
49 |
+
|
50 |
+
return embeddings
|
51 |
+
|
52 |
+
def get_cond(self, cond, **kwargs):
|
53 |
+
"""
|
54 |
+
Get the conditioning data.
|
55 |
+
"""
|
56 |
+
cond = self.encode_text(cond)
|
57 |
+
kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
|
58 |
+
cond = super().get_cond(cond, **kwargs)
|
59 |
+
return cond
|
60 |
+
|
61 |
+
def get_inference_cond(self, cond, **kwargs):
|
62 |
+
"""
|
63 |
+
Get the conditioning data for inference.
|
64 |
+
"""
|
65 |
+
cond = self.encode_text(cond)
|
66 |
+
kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
|
67 |
+
cond = super().get_inference_cond(cond, **kwargs)
|
68 |
+
return cond
|
trellis/trainers/flow_matching/sparse_flow_matching.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import os
|
3 |
+
import copy
|
4 |
+
import functools
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
import numpy as np
|
9 |
+
from easydict import EasyDict as edict
|
10 |
+
|
11 |
+
from ...modules import sparse as sp
|
12 |
+
from ...utils.general_utils import dict_reduce
|
13 |
+
from ...utils.data_utils import cycle, BalancedResumableSampler
|
14 |
+
from .flow_matching import FlowMatchingTrainer
|
15 |
+
from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin
|
16 |
+
from .mixins.text_conditioned import TextConditionedMixin
|
17 |
+
from .mixins.image_conditioned import ImageConditionedMixin
|
18 |
+
|
19 |
+
|
20 |
+
class SparseFlowMatchingTrainer(FlowMatchingTrainer):
|
21 |
+
"""
|
22 |
+
Trainer for sparse diffusion model with flow matching objective.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
models (dict[str, nn.Module]): Models to train.
|
26 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
27 |
+
output_dir (str): Output directory.
|
28 |
+
load_dir (str): Load directory.
|
29 |
+
step (int): Step to load.
|
30 |
+
batch_size (int): Batch size.
|
31 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
32 |
+
batch_split (int): Split batch with gradient accumulation.
|
33 |
+
max_steps (int): Max steps.
|
34 |
+
optimizer (dict): Optimizer config.
|
35 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
36 |
+
elastic (dict): Elastic memory management config.
|
37 |
+
grad_clip (float or dict): Gradient clip config.
|
38 |
+
ema_rate (float or list): Exponential moving average rates.
|
39 |
+
fp16_mode (str): FP16 mode.
|
40 |
+
- None: No FP16.
|
41 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
42 |
+
- 'amp': Automatic mixed precision.
|
43 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
44 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
45 |
+
log_param_stats (bool): Log parameter stats.
|
46 |
+
i_print (int): Print interval.
|
47 |
+
i_log (int): Log interval.
|
48 |
+
i_sample (int): Sample interval.
|
49 |
+
i_save (int): Save interval.
|
50 |
+
i_ddpcheck (int): DDP check interval.
|
51 |
+
|
52 |
+
t_schedule (dict): Time schedule for flow matching.
|
53 |
+
sigma_min (float): Minimum noise level.
|
54 |
+
"""
|
55 |
+
|
56 |
+
def prepare_dataloader(self, **kwargs):
|
57 |
+
"""
|
58 |
+
Prepare dataloader.
|
59 |
+
"""
|
60 |
+
self.data_sampler = BalancedResumableSampler(
|
61 |
+
self.dataset,
|
62 |
+
shuffle=True,
|
63 |
+
batch_size=self.batch_size_per_gpu,
|
64 |
+
)
|
65 |
+
self.dataloader = DataLoader(
|
66 |
+
self.dataset,
|
67 |
+
batch_size=self.batch_size_per_gpu,
|
68 |
+
num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
|
69 |
+
pin_memory=True,
|
70 |
+
drop_last=True,
|
71 |
+
persistent_workers=True,
|
72 |
+
collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split),
|
73 |
+
sampler=self.data_sampler,
|
74 |
+
)
|
75 |
+
self.data_iterator = cycle(self.dataloader)
|
76 |
+
|
77 |
+
def training_losses(
|
78 |
+
self,
|
79 |
+
x_0: sp.SparseTensor,
|
80 |
+
cond=None,
|
81 |
+
**kwargs
|
82 |
+
) -> Tuple[Dict, Dict]:
|
83 |
+
"""
|
84 |
+
Compute training losses for a single timestep.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
x_0: The [N x ... x C] sparse tensor of the inputs.
|
88 |
+
cond: The [N x ...] tensor of additional conditions.
|
89 |
+
kwargs: Additional arguments to pass to the backbone.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
a dict with the key "loss" containing a tensor of shape [N].
|
93 |
+
may also contain other keys for different terms.
|
94 |
+
"""
|
95 |
+
noise = x_0.replace(torch.randn_like(x_0.feats))
|
96 |
+
t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
|
97 |
+
x_t = self.diffuse(x_0, t, noise=noise)
|
98 |
+
cond = self.get_cond(cond, **kwargs)
|
99 |
+
|
100 |
+
pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
|
101 |
+
assert pred.shape == noise.shape == x_0.shape
|
102 |
+
target = self.get_v(x_0, noise, t)
|
103 |
+
terms = edict()
|
104 |
+
terms["mse"] = F.mse_loss(pred.feats, target.feats)
|
105 |
+
terms["loss"] = terms["mse"]
|
106 |
+
|
107 |
+
# log loss with time bins
|
108 |
+
mse_per_instance = np.array([
|
109 |
+
F.mse_loss(pred.feats[x_0.layout[i]], target.feats[x_0.layout[i]]).item()
|
110 |
+
for i in range(x_0.shape[0])
|
111 |
+
])
|
112 |
+
time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
|
113 |
+
for i in range(10):
|
114 |
+
if (time_bin == i).sum() != 0:
|
115 |
+
terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
|
116 |
+
|
117 |
+
return terms, {}
|
118 |
+
|
119 |
+
@torch.no_grad()
|
120 |
+
def run_snapshot(
|
121 |
+
self,
|
122 |
+
num_samples: int,
|
123 |
+
batch_size: int,
|
124 |
+
verbose: bool = False,
|
125 |
+
) -> Dict:
|
126 |
+
dataloader = DataLoader(
|
127 |
+
copy.deepcopy(self.dataset),
|
128 |
+
batch_size=batch_size,
|
129 |
+
shuffle=True,
|
130 |
+
num_workers=0,
|
131 |
+
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
132 |
+
)
|
133 |
+
|
134 |
+
# inference
|
135 |
+
sampler = self.get_sampler()
|
136 |
+
sample_gt = []
|
137 |
+
sample = []
|
138 |
+
cond_vis = []
|
139 |
+
for i in range(0, num_samples, batch_size):
|
140 |
+
batch = min(batch_size, num_samples - i)
|
141 |
+
data = next(iter(dataloader))
|
142 |
+
data = {k: v[:batch].cuda() if not isinstance(v, list) else v[:batch] for k, v in data.items()}
|
143 |
+
noise = data['x_0'].replace(torch.randn_like(data['x_0'].feats))
|
144 |
+
sample_gt.append(data['x_0'])
|
145 |
+
cond_vis.append(self.vis_cond(**data))
|
146 |
+
del data['x_0']
|
147 |
+
args = self.get_inference_cond(**data)
|
148 |
+
res = sampler.sample(
|
149 |
+
self.models['denoiser'],
|
150 |
+
noise=noise,
|
151 |
+
**args,
|
152 |
+
steps=50, cfg_strength=3.0, verbose=verbose,
|
153 |
+
)
|
154 |
+
sample.append(res.samples)
|
155 |
+
|
156 |
+
sample_gt = sp.sparse_cat(sample_gt)
|
157 |
+
sample = sp.sparse_cat(sample)
|
158 |
+
sample_dict = {
|
159 |
+
'sample_gt': {'value': sample_gt, 'type': 'sample'},
|
160 |
+
'sample': {'value': sample, 'type': 'sample'},
|
161 |
+
}
|
162 |
+
sample_dict.update(dict_reduce(cond_vis, None, {
|
163 |
+
'value': lambda x: torch.cat(x, dim=0),
|
164 |
+
'type': lambda x: x[0],
|
165 |
+
}))
|
166 |
+
|
167 |
+
return sample_dict
|
168 |
+
|
169 |
+
|
170 |
+
class SparseFlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, SparseFlowMatchingTrainer):
|
171 |
+
"""
|
172 |
+
Trainer for sparse diffusion model with flow matching objective and classifier-free guidance.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
models (dict[str, nn.Module]): Models to train.
|
176 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
177 |
+
output_dir (str): Output directory.
|
178 |
+
load_dir (str): Load directory.
|
179 |
+
step (int): Step to load.
|
180 |
+
batch_size (int): Batch size.
|
181 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
182 |
+
batch_split (int): Split batch with gradient accumulation.
|
183 |
+
max_steps (int): Max steps.
|
184 |
+
optimizer (dict): Optimizer config.
|
185 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
186 |
+
elastic (dict): Elastic memory management config.
|
187 |
+
grad_clip (float or dict): Gradient clip config.
|
188 |
+
ema_rate (float or list): Exponential moving average rates.
|
189 |
+
fp16_mode (str): FP16 mode.
|
190 |
+
- None: No FP16.
|
191 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
192 |
+
- 'amp': Automatic mixed precision.
|
193 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
194 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
195 |
+
log_param_stats (bool): Log parameter stats.
|
196 |
+
i_print (int): Print interval.
|
197 |
+
i_log (int): Log interval.
|
198 |
+
i_sample (int): Sample interval.
|
199 |
+
i_save (int): Save interval.
|
200 |
+
i_ddpcheck (int): DDP check interval.
|
201 |
+
|
202 |
+
t_schedule (dict): Time schedule for flow matching.
|
203 |
+
sigma_min (float): Minimum noise level.
|
204 |
+
p_uncond (float): Probability of dropping conditions.
|
205 |
+
"""
|
206 |
+
pass
|
207 |
+
|
208 |
+
|
209 |
+
class TextConditionedSparseFlowMatchingCFGTrainer(TextConditionedMixin, SparseFlowMatchingCFGTrainer):
|
210 |
+
"""
|
211 |
+
Trainer for sparse text-conditioned diffusion model with flow matching objective and classifier-free guidance.
|
212 |
+
|
213 |
+
Args:
|
214 |
+
models (dict[str, nn.Module]): Models to train.
|
215 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
216 |
+
output_dir (str): Output directory.
|
217 |
+
load_dir (str): Load directory.
|
218 |
+
step (int): Step to load.
|
219 |
+
batch_size (int): Batch size.
|
220 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
221 |
+
batch_split (int): Split batch with gradient accumulation.
|
222 |
+
max_steps (int): Max steps.
|
223 |
+
optimizer (dict): Optimizer config.
|
224 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
225 |
+
elastic (dict): Elastic memory management config.
|
226 |
+
grad_clip (float or dict): Gradient clip config.
|
227 |
+
ema_rate (float or list): Exponential moving average rates.
|
228 |
+
fp16_mode (str): FP16 mode.
|
229 |
+
- None: No FP16.
|
230 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
231 |
+
- 'amp': Automatic mixed precision.
|
232 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
233 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
234 |
+
log_param_stats (bool): Log parameter stats.
|
235 |
+
i_print (int): Print interval.
|
236 |
+
i_log (int): Log interval.
|
237 |
+
i_sample (int): Sample interval.
|
238 |
+
i_save (int): Save interval.
|
239 |
+
i_ddpcheck (int): DDP check interval.
|
240 |
+
|
241 |
+
t_schedule (dict): Time schedule for flow matching.
|
242 |
+
sigma_min (float): Minimum noise level.
|
243 |
+
p_uncond (float): Probability of dropping conditions.
|
244 |
+
text_cond_model(str): Text conditioning model.
|
245 |
+
"""
|
246 |
+
pass
|
247 |
+
|
248 |
+
|
249 |
+
class ImageConditionedSparseFlowMatchingCFGTrainer(ImageConditionedMixin, SparseFlowMatchingCFGTrainer):
|
250 |
+
"""
|
251 |
+
Trainer for sparse image-conditioned diffusion model with flow matching objective and classifier-free guidance.
|
252 |
+
|
253 |
+
Args:
|
254 |
+
models (dict[str, nn.Module]): Models to train.
|
255 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
256 |
+
output_dir (str): Output directory.
|
257 |
+
load_dir (str): Load directory.
|
258 |
+
step (int): Step to load.
|
259 |
+
batch_size (int): Batch size.
|
260 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
261 |
+
batch_split (int): Split batch with gradient accumulation.
|
262 |
+
max_steps (int): Max steps.
|
263 |
+
optimizer (dict): Optimizer config.
|
264 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
265 |
+
elastic (dict): Elastic memory management config.
|
266 |
+
grad_clip (float or dict): Gradient clip config.
|
267 |
+
ema_rate (float or list): Exponential moving average rates.
|
268 |
+
fp16_mode (str): FP16 mode.
|
269 |
+
- None: No FP16.
|
270 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
271 |
+
- 'amp': Automatic mixed precision.
|
272 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
273 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
274 |
+
log_param_stats (bool): Log parameter stats.
|
275 |
+
i_print (int): Print interval.
|
276 |
+
i_log (int): Log interval.
|
277 |
+
i_sample (int): Sample interval.
|
278 |
+
i_save (int): Save interval.
|
279 |
+
i_ddpcheck (int): DDP check interval.
|
280 |
+
|
281 |
+
t_schedule (dict): Time schedule for flow matching.
|
282 |
+
sigma_min (float): Minimum noise level.
|
283 |
+
p_uncond (float): Probability of dropping conditions.
|
284 |
+
image_cond_model (str): Image conditioning model.
|
285 |
+
"""
|
286 |
+
pass
|
trellis/trainers/utils.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
# FP16 utils
|
5 |
+
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
6 |
+
|
7 |
+
def make_master_params(model_params):
|
8 |
+
"""
|
9 |
+
Copy model parameters into a inflated tensor of full-precision parameters.
|
10 |
+
"""
|
11 |
+
master_params = _flatten_dense_tensors(
|
12 |
+
[param.detach().float() for param in model_params]
|
13 |
+
)
|
14 |
+
master_params = nn.Parameter(master_params)
|
15 |
+
master_params.requires_grad = True
|
16 |
+
return [master_params]
|
17 |
+
|
18 |
+
|
19 |
+
def unflatten_master_params(model_params, master_params):
|
20 |
+
"""
|
21 |
+
Unflatten the master parameters to look like model_params.
|
22 |
+
"""
|
23 |
+
return _unflatten_dense_tensors(master_params[0].detach(), model_params)
|
24 |
+
|
25 |
+
|
26 |
+
def model_params_to_master_params(model_params, master_params):
|
27 |
+
"""
|
28 |
+
Copy the model parameter data into the master parameters.
|
29 |
+
"""
|
30 |
+
master_params[0].detach().copy_(
|
31 |
+
_flatten_dense_tensors([param.detach().float() for param in model_params])
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
def master_params_to_model_params(model_params, master_params):
|
36 |
+
"""
|
37 |
+
Copy the master parameter data back into the model parameters.
|
38 |
+
"""
|
39 |
+
for param, master_param in zip(
|
40 |
+
model_params, _unflatten_dense_tensors(master_params[0].detach(), model_params)
|
41 |
+
):
|
42 |
+
param.detach().copy_(master_param)
|
43 |
+
|
44 |
+
|
45 |
+
def model_grads_to_master_grads(model_params, master_params):
|
46 |
+
"""
|
47 |
+
Copy the gradients from the model parameters into the master parameters
|
48 |
+
from make_master_params().
|
49 |
+
"""
|
50 |
+
master_params[0].grad = _flatten_dense_tensors(
|
51 |
+
[param.grad.data.detach().float() for param in model_params]
|
52 |
+
)
|
53 |
+
|
54 |
+
|
55 |
+
def zero_grad(model_params):
|
56 |
+
for param in model_params:
|
57 |
+
if param.grad is not None:
|
58 |
+
if param.grad.grad_fn is not None:
|
59 |
+
param.grad.detach_()
|
60 |
+
else:
|
61 |
+
param.grad.requires_grad_(False)
|
62 |
+
param.grad.zero_()
|
63 |
+
|
64 |
+
|
65 |
+
# LR Schedulers
|
66 |
+
from torch.optim.lr_scheduler import LambdaLR
|
67 |
+
|
68 |
+
class LinearWarmupLRScheduler(LambdaLR):
|
69 |
+
def __init__(self, optimizer, warmup_steps, last_epoch=-1):
|
70 |
+
self.warmup_steps = warmup_steps
|
71 |
+
super(LinearWarmupLRScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
|
72 |
+
|
73 |
+
def lr_lambda(self, current_step):
|
74 |
+
if current_step < self.warmup_steps:
|
75 |
+
return float(current_step + 1) / self.warmup_steps
|
76 |
+
return 1.0
|
77 |
+
|
trellis/trainers/vae/sparse_structure_vae.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import copy
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
from easydict import EasyDict as edict
|
7 |
+
|
8 |
+
from ..basic import BasicTrainer
|
9 |
+
|
10 |
+
|
11 |
+
class SparseStructureVaeTrainer(BasicTrainer):
|
12 |
+
"""
|
13 |
+
Trainer for Sparse Structure VAE.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
models (dict[str, nn.Module]): Models to train.
|
17 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
18 |
+
output_dir (str): Output directory.
|
19 |
+
load_dir (str): Load directory.
|
20 |
+
step (int): Step to load.
|
21 |
+
batch_size (int): Batch size.
|
22 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
23 |
+
batch_split (int): Split batch with gradient accumulation.
|
24 |
+
max_steps (int): Max steps.
|
25 |
+
optimizer (dict): Optimizer config.
|
26 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
27 |
+
elastic (dict): Elastic memory management config.
|
28 |
+
grad_clip (float or dict): Gradient clip config.
|
29 |
+
ema_rate (float or list): Exponential moving average rates.
|
30 |
+
fp16_mode (str): FP16 mode.
|
31 |
+
- None: No FP16.
|
32 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
33 |
+
- 'amp': Automatic mixed precision.
|
34 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
35 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
36 |
+
log_param_stats (bool): Log parameter stats.
|
37 |
+
i_print (int): Print interval.
|
38 |
+
i_log (int): Log interval.
|
39 |
+
i_sample (int): Sample interval.
|
40 |
+
i_save (int): Save interval.
|
41 |
+
i_ddpcheck (int): DDP check interval.
|
42 |
+
|
43 |
+
loss_type (str): Loss type. 'bce' for binary cross entropy, 'l1' for L1 loss, 'dice' for Dice loss.
|
44 |
+
lambda_kl (float): KL divergence loss weight.
|
45 |
+
"""
|
46 |
+
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
*args,
|
50 |
+
loss_type='bce',
|
51 |
+
lambda_kl=1e-6,
|
52 |
+
**kwargs
|
53 |
+
):
|
54 |
+
super().__init__(*args, **kwargs)
|
55 |
+
self.loss_type = loss_type
|
56 |
+
self.lambda_kl = lambda_kl
|
57 |
+
|
58 |
+
def training_losses(
|
59 |
+
self,
|
60 |
+
ss: torch.Tensor,
|
61 |
+
**kwargs
|
62 |
+
) -> Tuple[Dict, Dict]:
|
63 |
+
"""
|
64 |
+
Compute training losses.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
ss: The [N x 1 x H x W x D] tensor of binary sparse structure.
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
a dict with the key "loss" containing a scalar tensor.
|
71 |
+
may also contain other keys for different terms.
|
72 |
+
"""
|
73 |
+
z, mean, logvar = self.training_models['encoder'](ss.float(), sample_posterior=True, return_raw=True)
|
74 |
+
logits = self.training_models['decoder'](z)
|
75 |
+
|
76 |
+
terms = edict(loss = 0.0)
|
77 |
+
if self.loss_type == 'bce':
|
78 |
+
terms["bce"] = F.binary_cross_entropy_with_logits(logits, ss.float(), reduction='mean')
|
79 |
+
terms["loss"] = terms["loss"] + terms["bce"]
|
80 |
+
elif self.loss_type == 'l1':
|
81 |
+
terms["l1"] = F.l1_loss(F.sigmoid(logits), ss.float(), reduction='mean')
|
82 |
+
terms["loss"] = terms["loss"] + terms["l1"]
|
83 |
+
elif self.loss_type == 'dice':
|
84 |
+
logits = F.sigmoid(logits)
|
85 |
+
terms["dice"] = 1 - (2 * (logits * ss.float()).sum() + 1) / (logits.sum() + ss.float().sum() + 1)
|
86 |
+
terms["loss"] = terms["loss"] + terms["dice"]
|
87 |
+
else:
|
88 |
+
raise ValueError(f'Invalid loss type {self.loss_type}')
|
89 |
+
terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
|
90 |
+
terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]
|
91 |
+
|
92 |
+
return terms, {}
|
93 |
+
|
94 |
+
@torch.no_grad()
|
95 |
+
def snapshot(self, suffix=None, num_samples=64, batch_size=1, verbose=False):
|
96 |
+
super().snapshot(suffix=suffix, num_samples=num_samples, batch_size=batch_size, verbose=verbose)
|
97 |
+
|
98 |
+
@torch.no_grad()
|
99 |
+
def run_snapshot(
|
100 |
+
self,
|
101 |
+
num_samples: int,
|
102 |
+
batch_size: int,
|
103 |
+
verbose: bool = False,
|
104 |
+
) -> Dict:
|
105 |
+
dataloader = DataLoader(
|
106 |
+
copy.deepcopy(self.dataset),
|
107 |
+
batch_size=batch_size,
|
108 |
+
shuffle=True,
|
109 |
+
num_workers=0,
|
110 |
+
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
111 |
+
)
|
112 |
+
|
113 |
+
# inference
|
114 |
+
gts = []
|
115 |
+
recons = []
|
116 |
+
for i in range(0, num_samples, batch_size):
|
117 |
+
batch = min(batch_size, num_samples - i)
|
118 |
+
data = next(iter(dataloader))
|
119 |
+
args = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
|
120 |
+
z = self.models['encoder'](args['ss'].float(), sample_posterior=False)
|
121 |
+
logits = self.models['decoder'](z)
|
122 |
+
recon = (logits > 0).long()
|
123 |
+
gts.append(args['ss'])
|
124 |
+
recons.append(recon)
|
125 |
+
|
126 |
+
sample_dict = {
|
127 |
+
'gt': {'value': torch.cat(gts, dim=0), 'type': 'sample'},
|
128 |
+
'recon': {'value': torch.cat(recons, dim=0), 'type': 'sample'},
|
129 |
+
}
|
130 |
+
return sample_dict
|
trellis/trainers/vae/structured_latent_vae_gaussian.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import copy
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
import numpy as np
|
6 |
+
from easydict import EasyDict as edict
|
7 |
+
import utils3d.torch
|
8 |
+
|
9 |
+
from ..basic import BasicTrainer
|
10 |
+
from ...representations import Gaussian
|
11 |
+
from ...renderers import GaussianRenderer
|
12 |
+
from ...modules.sparse import SparseTensor
|
13 |
+
from ...utils.loss_utils import l1_loss, l2_loss, ssim, lpips
|
14 |
+
|
15 |
+
|
16 |
+
class SLatVaeGaussianTrainer(BasicTrainer):
|
17 |
+
"""
|
18 |
+
Trainer for structured latent VAE.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
models (dict[str, nn.Module]): Models to train.
|
22 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
23 |
+
output_dir (str): Output directory.
|
24 |
+
load_dir (str): Load directory.
|
25 |
+
step (int): Step to load.
|
26 |
+
batch_size (int): Batch size.
|
27 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
28 |
+
batch_split (int): Split batch with gradient accumulation.
|
29 |
+
max_steps (int): Max steps.
|
30 |
+
optimizer (dict): Optimizer config.
|
31 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
32 |
+
elastic (dict): Elastic memory management config.
|
33 |
+
grad_clip (float or dict): Gradient clip config.
|
34 |
+
ema_rate (float or list): Exponential moving average rates.
|
35 |
+
fp16_mode (str): FP16 mode.
|
36 |
+
- None: No FP16.
|
37 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
38 |
+
- 'amp': Automatic mixed precision.
|
39 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
40 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
41 |
+
log_param_stats (bool): Log parameter stats.
|
42 |
+
i_print (int): Print interval.
|
43 |
+
i_log (int): Log interval.
|
44 |
+
i_sample (int): Sample interval.
|
45 |
+
i_save (int): Save interval.
|
46 |
+
i_ddpcheck (int): DDP check interval.
|
47 |
+
|
48 |
+
loss_type (str): Loss type. Can be 'l1', 'l2'
|
49 |
+
lambda_ssim (float): SSIM loss weight.
|
50 |
+
lambda_lpips (float): LPIPS loss weight.
|
51 |
+
lambda_kl (float): KL loss weight.
|
52 |
+
regularizations (dict): Regularization config.
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
*args,
|
58 |
+
loss_type: str = 'l1',
|
59 |
+
lambda_ssim: float = 0.2,
|
60 |
+
lambda_lpips: float = 0.2,
|
61 |
+
lambda_kl: float = 1e-6,
|
62 |
+
regularizations: Dict = {},
|
63 |
+
**kwargs
|
64 |
+
):
|
65 |
+
super().__init__(*args, **kwargs)
|
66 |
+
self.loss_type = loss_type
|
67 |
+
self.lambda_ssim = lambda_ssim
|
68 |
+
self.lambda_lpips = lambda_lpips
|
69 |
+
self.lambda_kl = lambda_kl
|
70 |
+
self.regularizations = regularizations
|
71 |
+
|
72 |
+
self._init_renderer()
|
73 |
+
|
74 |
+
def _init_renderer(self):
|
75 |
+
rendering_options = {"near" : 0.8,
|
76 |
+
"far" : 1.6,
|
77 |
+
"bg_color" : 'random'}
|
78 |
+
self.renderer = GaussianRenderer(rendering_options)
|
79 |
+
self.renderer.pipe.kernel_size = self.models['decoder'].rep_config['2d_filter_kernel_size']
|
80 |
+
|
81 |
+
def _render_batch(self, reps: List[Gaussian], extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
|
82 |
+
"""
|
83 |
+
Render a batch of representations.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
reps: The dictionary of lists of representations.
|
87 |
+
extrinsics: The [N x 4 x 4] tensor of extrinsics.
|
88 |
+
intrinsics: The [N x 3 x 3] tensor of intrinsics.
|
89 |
+
"""
|
90 |
+
ret = None
|
91 |
+
for i, representation in enumerate(reps):
|
92 |
+
render_pack = self.renderer.render(representation, extrinsics[i], intrinsics[i])
|
93 |
+
if ret is None:
|
94 |
+
ret = {k: [] for k in list(render_pack.keys()) + ['bg_color']}
|
95 |
+
for k, v in render_pack.items():
|
96 |
+
ret[k].append(v)
|
97 |
+
ret['bg_color'].append(self.renderer.bg_color)
|
98 |
+
for k, v in ret.items():
|
99 |
+
ret[k] = torch.stack(v, dim=0)
|
100 |
+
return ret
|
101 |
+
|
102 |
+
@torch.no_grad()
|
103 |
+
def _get_status(self, z: SparseTensor, reps: List[Gaussian]) -> Dict:
|
104 |
+
xyz = torch.cat([g.get_xyz for g in reps], dim=0)
|
105 |
+
xyz_base = (z.coords[:, 1:].float() + 0.5) / self.models['decoder'].resolution - 0.5
|
106 |
+
offset = xyz - xyz_base.unsqueeze(1).expand(-1, self.models['decoder'].rep_config['num_gaussians'], -1).reshape(-1, 3)
|
107 |
+
status = {
|
108 |
+
'xyz': xyz,
|
109 |
+
'offset': offset,
|
110 |
+
'scale': torch.cat([g.get_scaling for g in reps], dim=0),
|
111 |
+
'opacity': torch.cat([g.get_opacity for g in reps], dim=0),
|
112 |
+
}
|
113 |
+
|
114 |
+
for k in list(status.keys()):
|
115 |
+
status[k] = {
|
116 |
+
'mean': status[k].mean().item(),
|
117 |
+
'max': status[k].max().item(),
|
118 |
+
'min': status[k].min().item(),
|
119 |
+
}
|
120 |
+
|
121 |
+
return status
|
122 |
+
|
123 |
+
def _get_regularization_loss(self, reps: List[Gaussian]) -> Tuple[torch.Tensor, Dict]:
|
124 |
+
loss = 0.0
|
125 |
+
terms = {}
|
126 |
+
if 'lambda_vol' in self.regularizations:
|
127 |
+
scales = torch.cat([g.get_scaling for g in reps], dim=0) # [N x 3]
|
128 |
+
volume = torch.prod(scales, dim=1) # [N]
|
129 |
+
terms[f'reg_vol'] = volume.mean()
|
130 |
+
loss = loss + self.regularizations['lambda_vol'] * terms[f'reg_vol']
|
131 |
+
if 'lambda_opacity' in self.regularizations:
|
132 |
+
opacity = torch.cat([g.get_opacity for g in reps], dim=0)
|
133 |
+
terms[f'reg_opacity'] = (opacity - 1).pow(2).mean()
|
134 |
+
loss = loss + self.regularizations['lambda_opacity'] * terms[f'reg_opacity']
|
135 |
+
return loss, terms
|
136 |
+
|
137 |
+
def training_losses(
|
138 |
+
self,
|
139 |
+
feats: SparseTensor,
|
140 |
+
image: torch.Tensor,
|
141 |
+
alpha: torch.Tensor,
|
142 |
+
extrinsics: torch.Tensor,
|
143 |
+
intrinsics: torch.Tensor,
|
144 |
+
return_aux: bool = False,
|
145 |
+
**kwargs
|
146 |
+
) -> Tuple[Dict, Dict]:
|
147 |
+
"""
|
148 |
+
Compute training losses.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
feats: The [N x * x C] sparse tensor of features.
|
152 |
+
image: The [N x 3 x H x W] tensor of images.
|
153 |
+
alpha: The [N x H x W] tensor of alpha channels.
|
154 |
+
extrinsics: The [N x 4 x 4] tensor of extrinsics.
|
155 |
+
intrinsics: The [N x 3 x 3] tensor of intrinsics.
|
156 |
+
return_aux: Whether to return auxiliary information.
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
a dict with the key "loss" containing a scalar tensor.
|
160 |
+
may also contain other keys for different terms.
|
161 |
+
"""
|
162 |
+
z, mean, logvar = self.training_models['encoder'](feats, sample_posterior=True, return_raw=True)
|
163 |
+
reps = self.training_models['decoder'](z)
|
164 |
+
self.renderer.rendering_options.resolution = image.shape[-1]
|
165 |
+
render_results = self._render_batch(reps, extrinsics, intrinsics)
|
166 |
+
|
167 |
+
terms = edict(loss = 0.0, rec = 0.0)
|
168 |
+
|
169 |
+
rec_image = render_results['color']
|
170 |
+
gt_image = image * alpha[:, None] + (1 - alpha[:, None]) * render_results['bg_color'][..., None, None]
|
171 |
+
|
172 |
+
if self.loss_type == 'l1':
|
173 |
+
terms["l1"] = l1_loss(rec_image, gt_image)
|
174 |
+
terms["rec"] = terms["rec"] + terms["l1"]
|
175 |
+
elif self.loss_type == 'l2':
|
176 |
+
terms["l2"] = l2_loss(rec_image, gt_image)
|
177 |
+
terms["rec"] = terms["rec"] + terms["l2"]
|
178 |
+
else:
|
179 |
+
raise ValueError(f"Invalid loss type: {self.loss_type}")
|
180 |
+
if self.lambda_ssim > 0:
|
181 |
+
terms["ssim"] = 1 - ssim(rec_image, gt_image)
|
182 |
+
terms["rec"] = terms["rec"] + self.lambda_ssim * terms["ssim"]
|
183 |
+
if self.lambda_lpips > 0:
|
184 |
+
terms["lpips"] = lpips(rec_image, gt_image)
|
185 |
+
terms["rec"] = terms["rec"] + self.lambda_lpips * terms["lpips"]
|
186 |
+
terms["loss"] = terms["loss"] + terms["rec"]
|
187 |
+
|
188 |
+
terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
|
189 |
+
terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]
|
190 |
+
|
191 |
+
reg_loss, reg_terms = self._get_regularization_loss(reps)
|
192 |
+
terms.update(reg_terms)
|
193 |
+
terms["loss"] = terms["loss"] + reg_loss
|
194 |
+
|
195 |
+
status = self._get_status(z, reps)
|
196 |
+
|
197 |
+
if return_aux:
|
198 |
+
return terms, status, {'rec_image': rec_image, 'gt_image': gt_image}
|
199 |
+
return terms, status
|
200 |
+
|
201 |
+
@torch.no_grad()
|
202 |
+
def run_snapshot(
|
203 |
+
self,
|
204 |
+
num_samples: int,
|
205 |
+
batch_size: int,
|
206 |
+
verbose: bool = False,
|
207 |
+
) -> Dict:
|
208 |
+
dataloader = DataLoader(
|
209 |
+
copy.deepcopy(self.dataset),
|
210 |
+
batch_size=batch_size,
|
211 |
+
shuffle=True,
|
212 |
+
num_workers=0,
|
213 |
+
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
214 |
+
)
|
215 |
+
|
216 |
+
# inference
|
217 |
+
ret_dict = {}
|
218 |
+
gt_images = []
|
219 |
+
exts = []
|
220 |
+
ints = []
|
221 |
+
reps = []
|
222 |
+
for i in range(0, num_samples, batch_size):
|
223 |
+
batch = min(batch_size, num_samples - i)
|
224 |
+
data = next(iter(dataloader))
|
225 |
+
args = {k: v[:batch].cuda() for k, v in data.items()}
|
226 |
+
gt_images.append(args['image'] * args['alpha'][:, None])
|
227 |
+
exts.append(args['extrinsics'])
|
228 |
+
ints.append(args['intrinsics'])
|
229 |
+
z = self.models['encoder'](args['feats'], sample_posterior=True, return_raw=False)
|
230 |
+
reps.extend(self.models['decoder'](z))
|
231 |
+
gt_images = torch.cat(gt_images, dim=0)
|
232 |
+
ret_dict.update({f'gt_image': {'value': gt_images, 'type': 'image'}})
|
233 |
+
|
234 |
+
# render single view
|
235 |
+
exts = torch.cat(exts, dim=0)
|
236 |
+
ints = torch.cat(ints, dim=0)
|
237 |
+
self.renderer.rendering_options.bg_color = (0, 0, 0)
|
238 |
+
self.renderer.rendering_options.resolution = gt_images.shape[-1]
|
239 |
+
render_results = self._render_batch(reps, exts, ints)
|
240 |
+
ret_dict.update({f'rec_image': {'value': render_results['color'], 'type': 'image'}})
|
241 |
+
|
242 |
+
# render multiview
|
243 |
+
self.renderer.rendering_options.resolution = 512
|
244 |
+
## Build camera
|
245 |
+
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
|
246 |
+
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
|
247 |
+
yaws = [y + yaws_offset for y in yaws]
|
248 |
+
pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
|
249 |
+
|
250 |
+
## render each view
|
251 |
+
miltiview_images = []
|
252 |
+
for yaw, pitch in zip(yaws, pitch):
|
253 |
+
orig = torch.tensor([
|
254 |
+
np.sin(yaw) * np.cos(pitch),
|
255 |
+
np.cos(yaw) * np.cos(pitch),
|
256 |
+
np.sin(pitch),
|
257 |
+
]).float().cuda() * 2
|
258 |
+
fov = torch.deg2rad(torch.tensor(30)).cuda()
|
259 |
+
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
260 |
+
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
|
261 |
+
extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1)
|
262 |
+
intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1)
|
263 |
+
render_results = self._render_batch(reps, extrinsics, intrinsics)
|
264 |
+
miltiview_images.append(render_results['color'])
|
265 |
+
|
266 |
+
## Concatenate views
|
267 |
+
miltiview_images = torch.cat([
|
268 |
+
torch.cat(miltiview_images[:2], dim=-2),
|
269 |
+
torch.cat(miltiview_images[2:], dim=-2),
|
270 |
+
], dim=-1)
|
271 |
+
ret_dict.update({f'miltiview_image': {'value': miltiview_images, 'type': 'image'}})
|
272 |
+
|
273 |
+
self.renderer.rendering_options.bg_color = 'random'
|
274 |
+
|
275 |
+
return ret_dict
|
trellis/trainers/vae/structured_latent_vae_mesh_dec.py
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import copy
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
import numpy as np
|
6 |
+
from easydict import EasyDict as edict
|
7 |
+
import utils3d.torch
|
8 |
+
|
9 |
+
from ..basic import BasicTrainer
|
10 |
+
from ...representations import MeshExtractResult
|
11 |
+
from ...renderers import MeshRenderer
|
12 |
+
from ...modules.sparse import SparseTensor
|
13 |
+
from ...utils.loss_utils import l1_loss, smooth_l1_loss, ssim, lpips
|
14 |
+
from ...utils.data_utils import recursive_to_device
|
15 |
+
|
16 |
+
|
17 |
+
class SLatVaeMeshDecoderTrainer(BasicTrainer):
|
18 |
+
"""
|
19 |
+
Trainer for structured latent VAE Mesh Decoder.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
models (dict[str, nn.Module]): Models to train.
|
23 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
24 |
+
output_dir (str): Output directory.
|
25 |
+
load_dir (str): Load directory.
|
26 |
+
step (int): Step to load.
|
27 |
+
batch_size (int): Batch size.
|
28 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
29 |
+
batch_split (int): Split batch with gradient accumulation.
|
30 |
+
max_steps (int): Max steps.
|
31 |
+
optimizer (dict): Optimizer config.
|
32 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
33 |
+
elastic (dict): Elastic memory management config.
|
34 |
+
grad_clip (float or dict): Gradient clip config.
|
35 |
+
ema_rate (float or list): Exponential moving average rates.
|
36 |
+
fp16_mode (str): FP16 mode.
|
37 |
+
- None: No FP16.
|
38 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
39 |
+
- 'amp': Automatic mixed precision.
|
40 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
41 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
42 |
+
log_param_stats (bool): Log parameter stats.
|
43 |
+
i_print (int): Print interval.
|
44 |
+
i_log (int): Log interval.
|
45 |
+
i_sample (int): Sample interval.
|
46 |
+
i_save (int): Save interval.
|
47 |
+
i_ddpcheck (int): DDP check interval.
|
48 |
+
|
49 |
+
loss_type (str): Loss type. Can be 'l1', 'l2'
|
50 |
+
lambda_ssim (float): SSIM loss weight.
|
51 |
+
lambda_lpips (float): LPIPS loss weight.
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
*args,
|
57 |
+
depth_loss_type: str = 'l1',
|
58 |
+
lambda_depth: int = 1,
|
59 |
+
lambda_ssim: float = 0.2,
|
60 |
+
lambda_lpips: float = 0.2,
|
61 |
+
lambda_tsdf: float = 0.01,
|
62 |
+
lambda_color: float = 0.1,
|
63 |
+
**kwargs
|
64 |
+
):
|
65 |
+
super().__init__(*args, **kwargs)
|
66 |
+
self.depth_loss_type = depth_loss_type
|
67 |
+
self.lambda_depth = lambda_depth
|
68 |
+
self.lambda_ssim = lambda_ssim
|
69 |
+
self.lambda_lpips = lambda_lpips
|
70 |
+
self.lambda_tsdf = lambda_tsdf
|
71 |
+
self.lambda_color = lambda_color
|
72 |
+
self.use_color = self.lambda_color > 0
|
73 |
+
|
74 |
+
self._init_renderer()
|
75 |
+
|
76 |
+
def _init_renderer(self):
|
77 |
+
rendering_options = {"near" : 1,
|
78 |
+
"far" : 3}
|
79 |
+
self.renderer = MeshRenderer(rendering_options, device=self.device)
|
80 |
+
|
81 |
+
def _render_batch(self, reps: List[MeshExtractResult], extrinsics: torch.Tensor, intrinsics: torch.Tensor,
|
82 |
+
return_types=['mask', 'normal', 'depth']) -> Dict[str, torch.Tensor]:
|
83 |
+
"""
|
84 |
+
Render a batch of representations.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
reps: The dictionary of lists of representations.
|
88 |
+
extrinsics: The [N x 4 x 4] tensor of extrinsics.
|
89 |
+
intrinsics: The [N x 3 x 3] tensor of intrinsics.
|
90 |
+
return_types: vary in ['mask', 'normal', 'depth', 'normal_map', 'color']
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
a dict with
|
94 |
+
reg_loss : [N] tensor of regularization losses
|
95 |
+
mask : [N x 1 x H x W] tensor of rendered masks
|
96 |
+
normal : [N x 3 x H x W] tensor of rendered normals
|
97 |
+
depth : [N x 1 x H x W] tensor of rendered depths
|
98 |
+
"""
|
99 |
+
ret = {k : [] for k in return_types}
|
100 |
+
for i, rep in enumerate(reps):
|
101 |
+
out_dict = self.renderer.render(rep, extrinsics[i], intrinsics[i], return_types=return_types)
|
102 |
+
for k in out_dict:
|
103 |
+
ret[k].append(out_dict[k][None] if k in ['mask', 'depth'] else out_dict[k])
|
104 |
+
for k in ret:
|
105 |
+
ret[k] = torch.stack(ret[k])
|
106 |
+
return ret
|
107 |
+
|
108 |
+
@staticmethod
|
109 |
+
def _tsdf_reg_loss(rep: MeshExtractResult, depth_map: torch.Tensor, extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
|
110 |
+
# Calculate tsdf
|
111 |
+
with torch.no_grad():
|
112 |
+
# Project points to camera and calculate pseudo-sdf as difference between gt depth and projected depth
|
113 |
+
projected_pts, pts_depth = utils3d.torch.project_cv(extrinsics=extrinsics, intrinsics=intrinsics, points=rep.tsdf_v)
|
114 |
+
projected_pts = (projected_pts - 0.5) * 2.0
|
115 |
+
depth_map_res = depth_map.shape[1]
|
116 |
+
gt_depth = torch.nn.functional.grid_sample(depth_map.reshape(1, 1, depth_map_res, depth_map_res),
|
117 |
+
projected_pts.reshape(1, 1, -1, 2), mode='bilinear', padding_mode='border', align_corners=True)
|
118 |
+
pseudo_sdf = gt_depth.flatten() - pts_depth.flatten()
|
119 |
+
# Truncate pseudo-sdf
|
120 |
+
delta = 1 / rep.res * 3.0
|
121 |
+
trunc_mask = pseudo_sdf > -delta
|
122 |
+
|
123 |
+
# Loss
|
124 |
+
gt_tsdf = pseudo_sdf[trunc_mask]
|
125 |
+
tsdf = rep.tsdf_s.flatten()[trunc_mask]
|
126 |
+
gt_tsdf = torch.clamp(gt_tsdf, -delta, delta)
|
127 |
+
return torch.mean((tsdf - gt_tsdf) ** 2)
|
128 |
+
|
129 |
+
def _calc_tsdf_loss(self, reps : list[MeshExtractResult], depth_maps, extrinsics, intrinsics) -> torch.Tensor:
|
130 |
+
tsdf_loss = 0.0
|
131 |
+
for i, rep in enumerate(reps):
|
132 |
+
tsdf_loss += self._tsdf_reg_loss(rep, depth_maps[i], extrinsics[i], intrinsics[i])
|
133 |
+
return tsdf_loss / len(reps)
|
134 |
+
|
135 |
+
@torch.no_grad()
|
136 |
+
def _flip_normal(self, normal: torch.Tensor, extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
|
137 |
+
"""
|
138 |
+
Flip normal to align with camera.
|
139 |
+
"""
|
140 |
+
normal = normal * 2.0 - 1.0
|
141 |
+
R = torch.zeros_like(extrinsics)
|
142 |
+
R[:, :3, :3] = extrinsics[:, :3, :3]
|
143 |
+
R[:, 3, 3] = 1.0
|
144 |
+
view_dir = utils3d.torch.unproject_cv(
|
145 |
+
utils3d.torch.image_uv(*normal.shape[-2:], device=self.device).reshape(1, -1, 2),
|
146 |
+
torch.ones(*normal.shape[-2:], device=self.device).reshape(1, -1),
|
147 |
+
R, intrinsics
|
148 |
+
).reshape(-1, *normal.shape[-2:], 3).permute(0, 3, 1, 2)
|
149 |
+
unflip = (normal * view_dir).sum(1, keepdim=True) < 0
|
150 |
+
normal *= unflip * 2.0 - 1.0
|
151 |
+
return (normal + 1.0) / 2.0
|
152 |
+
|
153 |
+
def _perceptual_loss(self, gt: torch.Tensor, pred: torch.Tensor, name: str) -> Dict[str, torch.Tensor]:
|
154 |
+
"""
|
155 |
+
Combination of L1, SSIM, and LPIPS loss.
|
156 |
+
"""
|
157 |
+
if gt.shape[1] != 3:
|
158 |
+
assert gt.shape[-1] == 3
|
159 |
+
gt = gt.permute(0, 3, 1, 2)
|
160 |
+
if pred.shape[1] != 3:
|
161 |
+
assert pred.shape[-1] == 3
|
162 |
+
pred = pred.permute(0, 3, 1, 2)
|
163 |
+
terms = {
|
164 |
+
f"{name}_loss" : l1_loss(gt, pred),
|
165 |
+
f"{name}_loss_ssim" : 1 - ssim(gt, pred),
|
166 |
+
f"{name}_loss_lpips" : lpips(gt, pred)
|
167 |
+
}
|
168 |
+
terms[f"{name}_loss_perceptual"] = terms[f"{name}_loss"] + terms[f"{name}_loss_ssim"] * self.lambda_ssim + terms[f"{name}_loss_lpips"] * self.lambda_lpips
|
169 |
+
return terms
|
170 |
+
|
171 |
+
def geometry_losses(
|
172 |
+
self,
|
173 |
+
reps: List[MeshExtractResult],
|
174 |
+
mesh: List[Dict],
|
175 |
+
normal_map: torch.Tensor,
|
176 |
+
extrinsics: torch.Tensor,
|
177 |
+
intrinsics: torch.Tensor,
|
178 |
+
):
|
179 |
+
with torch.no_grad():
|
180 |
+
gt_meshes = []
|
181 |
+
for i in range(len(reps)):
|
182 |
+
gt_mesh = MeshExtractResult(mesh[i]['vertices'].to(self.device), mesh[i]['faces'].to(self.device))
|
183 |
+
gt_meshes.append(gt_mesh)
|
184 |
+
target = self._render_batch(gt_meshes, extrinsics, intrinsics, return_types=['mask', 'depth', 'normal'])
|
185 |
+
target['normal'] = self._flip_normal(target['normal'], extrinsics, intrinsics)
|
186 |
+
|
187 |
+
terms = edict(geo_loss = 0.0)
|
188 |
+
if self.lambda_tsdf > 0:
|
189 |
+
tsdf_loss = self._calc_tsdf_loss(reps, target['depth'], extrinsics, intrinsics)
|
190 |
+
terms['tsdf_loss'] = tsdf_loss
|
191 |
+
terms['geo_loss'] += tsdf_loss * self.lambda_tsdf
|
192 |
+
|
193 |
+
return_types = ['mask', 'depth', 'normal', 'normal_map'] if self.use_color else ['mask', 'depth', 'normal']
|
194 |
+
buffer = self._render_batch(reps, extrinsics, intrinsics, return_types=return_types)
|
195 |
+
|
196 |
+
success_mask = torch.tensor([rep.success for rep in reps], device=self.device)
|
197 |
+
if success_mask.sum() != 0:
|
198 |
+
for k, v in buffer.items():
|
199 |
+
buffer[k] = v[success_mask]
|
200 |
+
for k, v in target.items():
|
201 |
+
target[k] = v[success_mask]
|
202 |
+
|
203 |
+
terms['mask_loss'] = l1_loss(buffer['mask'], target['mask'])
|
204 |
+
if self.depth_loss_type == 'l1':
|
205 |
+
terms['depth_loss'] = l1_loss(buffer['depth'] * target['mask'], target['depth'] * target['mask'])
|
206 |
+
elif self.depth_loss_type == 'smooth_l1':
|
207 |
+
terms['depth_loss'] = smooth_l1_loss(buffer['depth'] * target['mask'], target['depth'] * target['mask'], beta=1.0 / (2 * reps[0].res))
|
208 |
+
else:
|
209 |
+
raise ValueError(f"Unsupported depth loss type: {self.depth_loss_type}")
|
210 |
+
terms.update(self._perceptual_loss(buffer['normal'] * target['mask'], target['normal'] * target['mask'], 'normal'))
|
211 |
+
terms['geo_loss'] = terms['geo_loss'] + terms['mask_loss'] + terms['depth_loss'] * self.lambda_depth + terms['normal_loss_perceptual']
|
212 |
+
if self.use_color and normal_map is not None:
|
213 |
+
terms.update(self._perceptual_loss(normal_map[success_mask], buffer['normal_map'], 'normal_map'))
|
214 |
+
terms['geo_loss'] = terms['geo_loss'] + terms['normal_map_loss_perceptual'] * self.lambda_color
|
215 |
+
|
216 |
+
return terms
|
217 |
+
|
218 |
+
def color_losses(self, reps, image, alpha, extrinsics, intrinsics):
|
219 |
+
terms = edict(color_loss = torch.tensor(0.0, device=self.device))
|
220 |
+
buffer = self._render_batch(reps, extrinsics, intrinsics, return_types=['color'])
|
221 |
+
success_mask = torch.tensor([rep.success for rep in reps], device=self.device)
|
222 |
+
if success_mask.sum() != 0:
|
223 |
+
terms.update(self._perceptual_loss(image * alpha[:, None][success_mask], buffer['color'][success_mask], 'color'))
|
224 |
+
terms['color_loss'] = terms['color_loss'] + terms['color_loss_perceptual'] * self.lambda_color
|
225 |
+
return terms
|
226 |
+
|
227 |
+
def training_losses(
|
228 |
+
self,
|
229 |
+
latents: SparseTensor,
|
230 |
+
image: torch.Tensor,
|
231 |
+
alpha: torch.Tensor,
|
232 |
+
mesh: List[Dict],
|
233 |
+
extrinsics: torch.Tensor,
|
234 |
+
intrinsics: torch.Tensor,
|
235 |
+
normal_map: torch.Tensor = None,
|
236 |
+
) -> Tuple[Dict, Dict]:
|
237 |
+
"""
|
238 |
+
Compute training losses.
|
239 |
+
|
240 |
+
Args:
|
241 |
+
latents: The [N x * x C] sparse latents
|
242 |
+
image: The [N x 3 x H x W] tensor of images.
|
243 |
+
alpha: The [N x H x W] tensor of alpha channels.
|
244 |
+
mesh: The list of dictionaries of meshes.
|
245 |
+
extrinsics: The [N x 4 x 4] tensor of extrinsics.
|
246 |
+
intrinsics: The [N x 3 x 3] tensor of intrinsics.
|
247 |
+
|
248 |
+
Returns:
|
249 |
+
a dict with the key "loss" containing a scalar tensor.
|
250 |
+
may also contain other keys for different terms.
|
251 |
+
"""
|
252 |
+
reps = self.training_models['decoder'](latents)
|
253 |
+
self.renderer.rendering_options.resolution = image.shape[-1]
|
254 |
+
|
255 |
+
terms = edict(loss = 0.0, rec = 0.0)
|
256 |
+
|
257 |
+
terms['reg_loss'] = sum([rep.reg_loss for rep in reps]) / len(reps)
|
258 |
+
terms['loss'] = terms['loss'] + terms['reg_loss']
|
259 |
+
|
260 |
+
geo_terms = self.geometry_losses(reps, mesh, normal_map, extrinsics, intrinsics)
|
261 |
+
terms.update(geo_terms)
|
262 |
+
terms['loss'] = terms['loss'] + terms['geo_loss']
|
263 |
+
|
264 |
+
if self.use_color:
|
265 |
+
color_terms = self.color_losses(reps, image, alpha, extrinsics, intrinsics)
|
266 |
+
terms.update(color_terms)
|
267 |
+
terms['loss'] = terms['loss'] + terms['color_loss']
|
268 |
+
|
269 |
+
return terms, {}
|
270 |
+
|
271 |
+
@torch.no_grad()
|
272 |
+
def run_snapshot(
|
273 |
+
self,
|
274 |
+
num_samples: int,
|
275 |
+
batch_size: int,
|
276 |
+
verbose: bool = False,
|
277 |
+
) -> Dict:
|
278 |
+
dataloader = DataLoader(
|
279 |
+
copy.deepcopy(self.dataset),
|
280 |
+
batch_size=batch_size,
|
281 |
+
shuffle=True,
|
282 |
+
num_workers=0,
|
283 |
+
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
284 |
+
)
|
285 |
+
|
286 |
+
# inference
|
287 |
+
ret_dict = {}
|
288 |
+
gt_images = []
|
289 |
+
gt_normal_maps = []
|
290 |
+
gt_meshes = []
|
291 |
+
exts = []
|
292 |
+
ints = []
|
293 |
+
reps = []
|
294 |
+
for i in range(0, num_samples, batch_size):
|
295 |
+
batch = min(batch_size, num_samples - i)
|
296 |
+
data = next(iter(dataloader))
|
297 |
+
args = recursive_to_device(data, 'cuda')
|
298 |
+
gt_images.append(args['image'] * args['alpha'][:, None])
|
299 |
+
if self.use_color and 'normal_map' in data:
|
300 |
+
gt_normal_maps.append(args['normal_map'])
|
301 |
+
gt_meshes.extend(args['mesh'])
|
302 |
+
exts.append(args['extrinsics'])
|
303 |
+
ints.append(args['intrinsics'])
|
304 |
+
reps.extend(self.models['decoder'](args['latents']))
|
305 |
+
gt_images = torch.cat(gt_images, dim=0)
|
306 |
+
ret_dict.update({f'gt_image': {'value': gt_images, 'type': 'image'}})
|
307 |
+
if self.use_color and gt_normal_maps:
|
308 |
+
gt_normal_maps = torch.cat(gt_normal_maps, dim=0)
|
309 |
+
ret_dict.update({f'gt_normal_map': {'value': gt_normal_maps, 'type': 'image'}})
|
310 |
+
|
311 |
+
# render single view
|
312 |
+
exts = torch.cat(exts, dim=0)
|
313 |
+
ints = torch.cat(ints, dim=0)
|
314 |
+
self.renderer.rendering_options.bg_color = (0, 0, 0)
|
315 |
+
self.renderer.rendering_options.resolution = gt_images.shape[-1]
|
316 |
+
gt_render_results = self._render_batch([
|
317 |
+
MeshExtractResult(vertices=mesh['vertices'].to(self.device), faces=mesh['faces'].to(self.device))
|
318 |
+
for mesh in gt_meshes
|
319 |
+
], exts, ints, return_types=['normal'])
|
320 |
+
ret_dict.update({f'gt_normal': {'value': self._flip_normal(gt_render_results['normal'], exts, ints), 'type': 'image'}})
|
321 |
+
return_types = ['normal']
|
322 |
+
if self.use_color:
|
323 |
+
return_types.append('color')
|
324 |
+
if 'normal_map' in data:
|
325 |
+
return_types.append('normal_map')
|
326 |
+
render_results = self._render_batch(reps, exts, ints, return_types=return_types)
|
327 |
+
ret_dict.update({f'rec_normal': {'value': render_results['normal'], 'type': 'image'}})
|
328 |
+
if 'color' in return_types:
|
329 |
+
ret_dict.update({f'rec_image': {'value': render_results['color'], 'type': 'image'}})
|
330 |
+
if 'normal_map' in return_types:
|
331 |
+
ret_dict.update({f'rec_normal_map': {'value': render_results['normal_map'], 'type': 'image'}})
|
332 |
+
|
333 |
+
# render multiview
|
334 |
+
self.renderer.rendering_options.resolution = 512
|
335 |
+
## Build camera
|
336 |
+
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
|
337 |
+
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
|
338 |
+
yaws = [y + yaws_offset for y in yaws]
|
339 |
+
pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
|
340 |
+
|
341 |
+
## render each view
|
342 |
+
multiview_normals = []
|
343 |
+
multiview_normal_maps = []
|
344 |
+
miltiview_images = []
|
345 |
+
for yaw, pitch in zip(yaws, pitch):
|
346 |
+
orig = torch.tensor([
|
347 |
+
np.sin(yaw) * np.cos(pitch),
|
348 |
+
np.cos(yaw) * np.cos(pitch),
|
349 |
+
np.sin(pitch),
|
350 |
+
]).float().cuda() * 2
|
351 |
+
fov = torch.deg2rad(torch.tensor(30)).cuda()
|
352 |
+
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
353 |
+
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
|
354 |
+
extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1)
|
355 |
+
intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1)
|
356 |
+
render_results = self._render_batch(reps, extrinsics, intrinsics, return_types=return_types)
|
357 |
+
multiview_normals.append(render_results['normal'])
|
358 |
+
if 'color' in return_types:
|
359 |
+
miltiview_images.append(render_results['color'])
|
360 |
+
if 'normal_map' in return_types:
|
361 |
+
multiview_normal_maps.append(render_results['normal_map'])
|
362 |
+
|
363 |
+
## Concatenate views
|
364 |
+
multiview_normals = torch.cat([
|
365 |
+
torch.cat(multiview_normals[:2], dim=-2),
|
366 |
+
torch.cat(multiview_normals[2:], dim=-2),
|
367 |
+
], dim=-1)
|
368 |
+
ret_dict.update({f'multiview_normal': {'value': multiview_normals, 'type': 'image'}})
|
369 |
+
if 'color' in return_types:
|
370 |
+
miltiview_images = torch.cat([
|
371 |
+
torch.cat(miltiview_images[:2], dim=-2),
|
372 |
+
torch.cat(miltiview_images[2:], dim=-2),
|
373 |
+
], dim=-1)
|
374 |
+
ret_dict.update({f'multiview_image': {'value': miltiview_images, 'type': 'image'}})
|
375 |
+
if 'normal_map' in return_types:
|
376 |
+
multiview_normal_maps = torch.cat([
|
377 |
+
torch.cat(multiview_normal_maps[:2], dim=-2),
|
378 |
+
torch.cat(multiview_normal_maps[2:], dim=-2),
|
379 |
+
], dim=-1)
|
380 |
+
ret_dict.update({f'multiview_normal_map': {'value': multiview_normal_maps, 'type': 'image'}})
|
381 |
+
|
382 |
+
return ret_dict
|
trellis/trainers/vae/structured_latent_vae_rf_dec.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import copy
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
import numpy as np
|
6 |
+
from easydict import EasyDict as edict
|
7 |
+
import utils3d.torch
|
8 |
+
|
9 |
+
from ..basic import BasicTrainer
|
10 |
+
from ...representations import Strivec
|
11 |
+
from ...renderers import OctreeRenderer
|
12 |
+
from ...modules.sparse import SparseTensor
|
13 |
+
from ...utils.loss_utils import l1_loss, l2_loss, ssim, lpips
|
14 |
+
|
15 |
+
|
16 |
+
class SLatVaeRadianceFieldDecoderTrainer(BasicTrainer):
|
17 |
+
"""
|
18 |
+
Trainer for structured latent VAE Radiance Field Decoder.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
models (dict[str, nn.Module]): Models to train.
|
22 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
23 |
+
output_dir (str): Output directory.
|
24 |
+
load_dir (str): Load directory.
|
25 |
+
step (int): Step to load.
|
26 |
+
batch_size (int): Batch size.
|
27 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
28 |
+
batch_split (int): Split batch with gradient accumulation.
|
29 |
+
max_steps (int): Max steps.
|
30 |
+
optimizer (dict): Optimizer config.
|
31 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
32 |
+
elastic (dict): Elastic memory management config.
|
33 |
+
grad_clip (float or dict): Gradient clip config.
|
34 |
+
ema_rate (float or list): Exponential moving average rates.
|
35 |
+
fp16_mode (str): FP16 mode.
|
36 |
+
- None: No FP16.
|
37 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
38 |
+
- 'amp': Automatic mixed precision.
|
39 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
40 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
41 |
+
log_param_stats (bool): Log parameter stats.
|
42 |
+
i_print (int): Print interval.
|
43 |
+
i_log (int): Log interval.
|
44 |
+
i_sample (int): Sample interval.
|
45 |
+
i_save (int): Save interval.
|
46 |
+
i_ddpcheck (int): DDP check interval.
|
47 |
+
|
48 |
+
loss_type (str): Loss type. Can be 'l1', 'l2'
|
49 |
+
lambda_ssim (float): SSIM loss weight.
|
50 |
+
lambda_lpips (float): LPIPS loss weight.
|
51 |
+
"""
|
52 |
+
|
53 |
+
def __init__(
|
54 |
+
self,
|
55 |
+
*args,
|
56 |
+
loss_type: str = 'l1',
|
57 |
+
lambda_ssim: float = 0.2,
|
58 |
+
lambda_lpips: float = 0.2,
|
59 |
+
**kwargs
|
60 |
+
):
|
61 |
+
super().__init__(*args, **kwargs)
|
62 |
+
self.loss_type = loss_type
|
63 |
+
self.lambda_ssim = lambda_ssim
|
64 |
+
self.lambda_lpips = lambda_lpips
|
65 |
+
|
66 |
+
self._init_renderer()
|
67 |
+
|
68 |
+
def _init_renderer(self):
|
69 |
+
rendering_options = {"near" : 0.8,
|
70 |
+
"far" : 1.6,
|
71 |
+
"bg_color" : 'random'}
|
72 |
+
self.renderer = OctreeRenderer(rendering_options)
|
73 |
+
self.renderer.pipe.primitive = 'trivec'
|
74 |
+
|
75 |
+
def _render_batch(self, reps: List[Strivec], extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
|
76 |
+
"""
|
77 |
+
Render a batch of representations.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
reps: The dictionary of lists of representations.
|
81 |
+
extrinsics: The [N x 4 x 4] tensor of extrinsics.
|
82 |
+
intrinsics: The [N x 3 x 3] tensor of intrinsics.
|
83 |
+
"""
|
84 |
+
ret = None
|
85 |
+
for i, representation in enumerate(reps):
|
86 |
+
render_pack = self.renderer.render(representation, extrinsics[i], intrinsics[i])
|
87 |
+
if ret is None:
|
88 |
+
ret = {k: [] for k in list(render_pack.keys()) + ['bg_color']}
|
89 |
+
for k, v in render_pack.items():
|
90 |
+
ret[k].append(v)
|
91 |
+
ret['bg_color'].append(self.renderer.bg_color)
|
92 |
+
for k, v in ret.items():
|
93 |
+
ret[k] = torch.stack(v, dim=0)
|
94 |
+
return ret
|
95 |
+
|
96 |
+
def training_losses(
|
97 |
+
self,
|
98 |
+
latents: SparseTensor,
|
99 |
+
image: torch.Tensor,
|
100 |
+
alpha: torch.Tensor,
|
101 |
+
extrinsics: torch.Tensor,
|
102 |
+
intrinsics: torch.Tensor,
|
103 |
+
return_aux: bool = False,
|
104 |
+
**kwargs
|
105 |
+
) -> Tuple[Dict, Dict]:
|
106 |
+
"""
|
107 |
+
Compute training losses.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
latents: The [N x * x C] sparse latents
|
111 |
+
image: The [N x 3 x H x W] tensor of images.
|
112 |
+
alpha: The [N x H x W] tensor of alpha channels.
|
113 |
+
extrinsics: The [N x 4 x 4] tensor of extrinsics.
|
114 |
+
intrinsics: The [N x 3 x 3] tensor of intrinsics.
|
115 |
+
return_aux: Whether to return auxiliary information.
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
a dict with the key "loss" containing a scalar tensor.
|
119 |
+
may also contain other keys for different terms.
|
120 |
+
"""
|
121 |
+
reps = self.training_models['decoder'](latents)
|
122 |
+
self.renderer.rendering_options.resolution = image.shape[-1]
|
123 |
+
render_results = self._render_batch(reps, extrinsics, intrinsics)
|
124 |
+
|
125 |
+
terms = edict(loss = 0.0, rec = 0.0)
|
126 |
+
|
127 |
+
rec_image = render_results['color']
|
128 |
+
gt_image = image * alpha[:, None] + (1 - alpha[:, None]) * render_results['bg_color'][..., None, None]
|
129 |
+
|
130 |
+
if self.loss_type == 'l1':
|
131 |
+
terms["l1"] = l1_loss(rec_image, gt_image)
|
132 |
+
terms["rec"] = terms["rec"] + terms["l1"]
|
133 |
+
elif self.loss_type == 'l2':
|
134 |
+
terms["l2"] = l2_loss(rec_image, gt_image)
|
135 |
+
terms["rec"] = terms["rec"] + terms["l2"]
|
136 |
+
else:
|
137 |
+
raise ValueError(f"Invalid loss type: {self.loss_type}")
|
138 |
+
if self.lambda_ssim > 0:
|
139 |
+
terms["ssim"] = 1 - ssim(rec_image, gt_image)
|
140 |
+
terms["rec"] = terms["rec"] + self.lambda_ssim * terms["ssim"]
|
141 |
+
if self.lambda_lpips > 0:
|
142 |
+
terms["lpips"] = lpips(rec_image, gt_image)
|
143 |
+
terms["rec"] = terms["rec"] + self.lambda_lpips * terms["lpips"]
|
144 |
+
terms["loss"] = terms["loss"] + terms["rec"]
|
145 |
+
|
146 |
+
if return_aux:
|
147 |
+
return terms, {}, {'rec_image': rec_image, 'gt_image': gt_image}
|
148 |
+
return terms, {}
|
149 |
+
|
150 |
+
@torch.no_grad()
|
151 |
+
def run_snapshot(
|
152 |
+
self,
|
153 |
+
num_samples: int,
|
154 |
+
batch_size: int,
|
155 |
+
verbose: bool = False,
|
156 |
+
) -> Dict:
|
157 |
+
dataloader = DataLoader(
|
158 |
+
copy.deepcopy(self.dataset),
|
159 |
+
batch_size=batch_size,
|
160 |
+
shuffle=True,
|
161 |
+
num_workers=0,
|
162 |
+
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
163 |
+
)
|
164 |
+
|
165 |
+
# inference
|
166 |
+
ret_dict = {}
|
167 |
+
gt_images = []
|
168 |
+
exts = []
|
169 |
+
ints = []
|
170 |
+
reps = []
|
171 |
+
for i in range(0, num_samples, batch_size):
|
172 |
+
batch = min(batch_size, num_samples - i)
|
173 |
+
data = next(iter(dataloader))
|
174 |
+
args = {k: v[:batch].cuda() for k, v in data.items()}
|
175 |
+
gt_images.append(args['image'] * args['alpha'][:, None])
|
176 |
+
exts.append(args['extrinsics'])
|
177 |
+
ints.append(args['intrinsics'])
|
178 |
+
reps.extend(self.models['decoder'](args['latents']))
|
179 |
+
gt_images = torch.cat(gt_images, dim=0)
|
180 |
+
ret_dict.update({f'gt_image': {'value': gt_images, 'type': 'image'}})
|
181 |
+
|
182 |
+
# render single view
|
183 |
+
exts = torch.cat(exts, dim=0)
|
184 |
+
ints = torch.cat(ints, dim=0)
|
185 |
+
self.renderer.rendering_options.bg_color = (0, 0, 0)
|
186 |
+
self.renderer.rendering_options.resolution = gt_images.shape[-1]
|
187 |
+
render_results = self._render_batch(reps, exts, ints)
|
188 |
+
ret_dict.update({f'rec_image': {'value': render_results['color'], 'type': 'image'}})
|
189 |
+
|
190 |
+
# render multiview
|
191 |
+
self.renderer.rendering_options.resolution = 512
|
192 |
+
## Build camera
|
193 |
+
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
|
194 |
+
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
|
195 |
+
yaws = [y + yaws_offset for y in yaws]
|
196 |
+
pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
|
197 |
+
|
198 |
+
## render each view
|
199 |
+
miltiview_images = []
|
200 |
+
for yaw, pitch in zip(yaws, pitch):
|
201 |
+
orig = torch.tensor([
|
202 |
+
np.sin(yaw) * np.cos(pitch),
|
203 |
+
np.cos(yaw) * np.cos(pitch),
|
204 |
+
np.sin(pitch),
|
205 |
+
]).float().cuda() * 2
|
206 |
+
fov = torch.deg2rad(torch.tensor(30)).cuda()
|
207 |
+
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
208 |
+
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
|
209 |
+
extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1)
|
210 |
+
intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1)
|
211 |
+
render_results = self._render_batch(reps, extrinsics, intrinsics)
|
212 |
+
miltiview_images.append(render_results['color'])
|
213 |
+
|
214 |
+
## Concatenate views
|
215 |
+
miltiview_images = torch.cat([
|
216 |
+
torch.cat(miltiview_images[:2], dim=-2),
|
217 |
+
torch.cat(miltiview_images[2:], dim=-2),
|
218 |
+
], dim=-1)
|
219 |
+
ret_dict.update({f'miltiview_image': {'value': miltiview_images, 'type': 'image'}})
|
220 |
+
|
221 |
+
self.renderer.rendering_options.bg_color = 'random'
|
222 |
+
|
223 |
+
return ret_dict
|