cavargas10 committed · Commit 08ab988 · verified · 1 Parent(s): 4f8013a

Upload 10 files
trellis/utils/__init__.py ADDED
File without changes
trellis/utils/data_utils.py ADDED
@@ -0,0 +1,226 @@
from typing import *
import math
import torch
import numpy as np
from torch.utils.data import Sampler, Dataset, DataLoader, DistributedSampler
import torch.distributed as dist


def recursive_to_device(
    data: Any,
    device: torch.device,
    non_blocking: bool = False,
) -> Any:
    """
    Recursively move all tensors in a data structure to a device.
    """
    if hasattr(data, "to"):
        return data.to(device, non_blocking=non_blocking)
    elif isinstance(data, (list, tuple)):
        return type(data)(recursive_to_device(d, device, non_blocking) for d in data)
    elif isinstance(data, dict):
        return {k: recursive_to_device(v, device, non_blocking) for k, v in data.items()}
    else:
        return data


def load_balanced_group_indices(
    load: List[int],
    num_groups: int,
    equal_size: bool = False,
) -> List[List[int]]:
    """
    Split indices into groups with balanced load.
    """
    if equal_size:
        group_size = len(load) // num_groups
    indices = np.argsort(load)[::-1]
    groups = [[] for _ in range(num_groups)]
    group_load = np.zeros(num_groups)
    for idx in indices:
        min_group_idx = np.argmin(group_load)
        groups[min_group_idx].append(idx)
        if equal_size and len(groups[min_group_idx]) == group_size:
            group_load[min_group_idx] = float('inf')
        else:
            group_load[min_group_idx] += load[idx]
    return groups


def cycle(data_loader: DataLoader) -> Iterator:
    while True:
        for data in data_loader:
            if isinstance(data_loader.sampler, ResumableSampler):
                data_loader.sampler.idx += data_loader.batch_size  # type: ignore[attr-defined]
            yield data
        if isinstance(data_loader.sampler, DistributedSampler):
            data_loader.sampler.epoch += 1
        if isinstance(data_loader.sampler, ResumableSampler):
            data_loader.sampler.epoch += 1
            data_loader.sampler.idx = 0


class ResumableSampler(Sampler):
    """
    Distributed sampler that is resumable. The rank and world size are
    retrieved from the current distributed group.

    Args:
        dataset: Dataset used for sampling.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
            indices.
        seed (int, optional): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
        drop_last (bool, optional): if ``True``, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas. Default: ``False``.
    """

    def __init__(
        self,
        dataset: Dataset,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        self.dataset = dataset
        self.epoch = 0
        self.idx = 0
        self.drop_last = drop_last
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        self.rank = dist.get_rank() if dist.is_initialized() else 0
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.world_size != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.world_size) / self.world_size  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.world_size)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.world_size
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self) -> Iterator:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[
                    :padding_size
                ]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.world_size]

        # resume from previous state
        indices = indices[self.idx:]

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def state_dict(self) -> dict[str, int]:
        return {
            'epoch': self.epoch,
            'idx': self.idx,
        }

    def load_state_dict(self, state_dict):
        self.epoch = state_dict['epoch']
        self.idx = state_dict['idx']


class BalancedResumableSampler(ResumableSampler):
    """
    Distributed sampler that is resumable and balances the load among the processes.
    The dataset must expose a ``loads`` attribute giving a per-sample load estimate.

    Args:
        dataset: Dataset used for sampling.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
            indices.
        seed (int, optional): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
        drop_last (bool, optional): if ``True``, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas. Default: ``False``.
        batch_size (int, optional): Per-process batch size. Default: ``1``.
    """

    def __init__(
        self,
        dataset: Dataset,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        batch_size: int = 1,
    ) -> None:
        assert hasattr(dataset, 'loads'), 'Dataset must have "loads" attribute to use BalancedResumableSampler'
        super().__init__(dataset, shuffle, seed, drop_last)
        self.batch_size = batch_size
        self.loads = dataset.loads

    def __iter__(self) -> Iterator:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[
                    :padding_size
                ]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # balance load among processes
        num_batches = len(indices) // (self.batch_size * self.world_size)
        balanced_indices = []
        for i in range(num_batches):
            start_idx = i * self.batch_size * self.world_size
            end_idx = (i + 1) * self.batch_size * self.world_size
            batch_indices = indices[start_idx:end_idx]
            batch_loads = [self.loads[idx] for idx in batch_indices]
            groups = load_balanced_group_indices(batch_loads, self.world_size, equal_size=True)
            balanced_indices.extend([batch_indices[j] for j in groups[self.rank]])

        # resume from previous state
        indices = balanced_indices[self.idx:]

        return iter(indices)
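A minimal usage sketch of the samplers with cycle(); the toy dataset, batch size, and checkpointing flow below are illustrative assumptions, not part of this file:

import torch
from torch.utils.data import Dataset, DataLoader
from trellis.utils.data_utils import BalancedResumableSampler, cycle

class ToyDataset(Dataset):
    def __init__(self, n=100):
        self.items = [torch.randn(4) for _ in range(n)]
        self.loads = [1] * n  # per-sample cost estimate required for balancing

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

dataset = ToyDataset()
sampler = BalancedResumableSampler(dataset, shuffle=True, seed=0, batch_size=8)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)
batches = cycle(loader)  # endless iterator; advances sampler.idx as it yields

for step in range(20):
    batch = next(batches)

state = sampler.state_dict()     # {'epoch': ..., 'idx': ...} -> save with the checkpoint
sampler.load_state_dict(state)   # restoring resumes mid-epoch at the same position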
trellis/utils/dist_utils.py ADDED
@@ -0,0 +1,93 @@
import os
import io
from contextlib import contextmanager
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_dist(rank, local_rank, world_size, master_addr, master_port):
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = str(master_port)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(local_rank)
    torch.cuda.set_device(local_rank)
    dist.init_process_group('nccl', rank=rank, world_size=world_size)


def read_file_dist(path):
    """
    Read a binary file in a distributed setting.
    The file is only read once, by the rank 0 process, and broadcast to the other processes.

    Returns:
        data (io.BytesIO): The binary data read from the file.
    """
    if dist.is_initialized() and dist.get_world_size() > 1:
        # read file
        size = torch.LongTensor(1).cuda()
        if dist.get_rank() == 0:
            with open(path, 'rb') as f:
                data = f.read()
            data = torch.ByteTensor(
                torch.UntypedStorage.from_buffer(data, dtype=torch.uint8)
            ).cuda()
            size[0] = data.shape[0]
        # broadcast size
        dist.broadcast(size, src=0)
        if dist.get_rank() != 0:
            data = torch.ByteTensor(size[0].item()).cuda()
        # broadcast data
        dist.broadcast(data, src=0)
        # convert to io.BytesIO
        data = data.cpu().numpy().tobytes()
        data = io.BytesIO(data)
        return data
    else:
        with open(path, 'rb') as f:
            data = f.read()
        data = io.BytesIO(data)
        return data


def unwrap_dist(model):
    """
    Unwrap a model from DistributedDataParallel.
    """
    if isinstance(model, DDP):
        return model.module
    return model


@contextmanager
def master_first():
    """
    A context manager that ensures the master process executes first.
    """
    if not dist.is_initialized():
        yield
    else:
        if dist.get_rank() == 0:
            yield
            dist.barrier()
        else:
            dist.barrier()
            yield


@contextmanager
def local_master_first():
    """
    A context manager that ensures the local master process executes first.
    """
    if not dist.is_initialized():
        yield
    else:
        if dist.get_rank() % torch.cuda.device_count() == 0:
            yield
            dist.barrier()
        else:
            dist.barrier()
            yield
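A sketch of typical call sites under the usual one-process-per-GPU launch; the environment variables, the hypothetical prepare_dataset_cache() helper, and the checkpoint filename are assumptions:

import os
import torch
from trellis.utils.dist_utils import setup_dist, read_file_dist, master_first

setup_dist(
    rank=int(os.environ['RANK']),
    local_rank=int(os.environ['LOCAL_RANK']),
    world_size=int(os.environ['WORLD_SIZE']),
    master_addr=os.environ.get('MASTER_ADDR', 'localhost'),
    master_port=os.environ.get('MASTER_PORT', '29500'),
)

with master_first():
    # body runs on every rank, but rank 0 finishes before the others start
    prepare_dataset_cache()  # hypothetical one-time setup

buffer = read_file_dist('checkpoint.pt')  # rank 0 reads; all ranks receive the bytes
state = torch.load(buffer, map_location='cpu')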
trellis/utils/elastic_utils.py ADDED
@@ -0,0 +1,228 @@
from abc import abstractmethod
from contextlib import contextmanager
from typing import Tuple
import torch
import torch.nn as nn
import numpy as np


class MemoryController:
    """
    Base class for memory management during training.
    """

    def __init__(self):
        # Per-instance run state. (A mutable class-level default would be
        # shared across all controller instances.)
        self._last_input_size = None
        self._last_mem_ratio = []

    @contextmanager
    def record(self):
        yield

    def update_run_states(self, input_size=None, mem_ratio=None):
        if self._last_input_size is None:
            self._last_input_size = input_size
        elif self._last_input_size != input_size:
            raise ValueError('Input size should not change for different ElasticModules.')
        self._last_mem_ratio.append(mem_ratio)

    @abstractmethod
    def get_mem_ratio(self, input_size):
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def log(self):
        pass


class LinearMemoryController(MemoryController):
    """
    A simple controller for memory management during training.
    The memory usage is modeled as a linear function of:
        - the number of input parameters
        - the ratio of memory the model uses compared to the maximum usage (with no checkpointing):
            memory_usage = k * input_size * mem_ratio + b
    The controller keeps track of memory usage and returns the expected memory
    ratio that keeps the usage under a target fraction of the available memory.
    """
    def __init__(
        self,
        buffer_size=1000,
        update_every=500,
        target_ratio=0.8,
        available_memory=None,
        max_mem_ratio_start=0.1,
        params=None,
        device=None
    ):
        super().__init__()
        self.buffer_size = buffer_size
        self.update_every = update_every
        self.target_ratio = target_ratio
        self.device = device or torch.cuda.current_device()
        self.available_memory = available_memory or torch.cuda.get_device_properties(self.device).total_memory / 1024**3

        self._memory = np.zeros(buffer_size, dtype=np.float32)
        self._input_size = np.zeros(buffer_size, dtype=np.float32)
        self._mem_ratio = np.zeros(buffer_size, dtype=np.float32)
        self._buffer_ptr = 0
        self._buffer_length = 0
        self._params = tuple(params) if params is not None else (0.0, 0.0)
        self._max_mem_ratio = max_mem_ratio_start
        self.step = 0

    def __repr__(self):
        return f'LinearMemoryController(target_ratio={self.target_ratio}, available_memory={self.available_memory})'

    def _add_sample(self, memory, input_size, mem_ratio):
        self._memory[self._buffer_ptr] = memory
        self._input_size[self._buffer_ptr] = input_size
        self._mem_ratio[self._buffer_ptr] = mem_ratio
        self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
        self._buffer_length = min(self._buffer_length + 1, self.buffer_size)

    @contextmanager
    def record(self):
        torch.cuda.reset_peak_memory_stats(self.device)
        self._last_input_size = None
        self._last_mem_ratio = []
        yield
        self._last_memory = torch.cuda.max_memory_allocated(self.device) / 1024**3
        self._last_mem_ratio = sum(self._last_mem_ratio) / len(self._last_mem_ratio)
        self._add_sample(self._last_memory, self._last_input_size, self._last_mem_ratio)
        self.step += 1
        if self.step % self.update_every == 0:
            self._max_mem_ratio = min(1.0, self._max_mem_ratio + 0.1)
            self._fit_params()

    def _fit_params(self):
        memory_usage = self._memory[:self._buffer_length]
        input_size = self._input_size[:self._buffer_length]
        mem_ratio = self._mem_ratio[:self._buffer_length]

        x = input_size * mem_ratio
        y = memory_usage
        k, b = np.polyfit(x, y, 1)
        self._params = (k, b)
        # self._visualize()

    def _visualize(self):
        import matplotlib.pyplot as plt
        memory_usage = self._memory[:self._buffer_length]
        input_size = self._input_size[:self._buffer_length]
        mem_ratio = self._mem_ratio[:self._buffer_length]
        k, b = self._params

        plt.scatter(input_size * mem_ratio, memory_usage, c=mem_ratio, cmap='viridis')
        x = np.array([0.0, 20000.0])
        plt.plot(x, k * x + b, c='r')
        plt.savefig(f'linear_memory_controller_{self.step}.png')
        plt.cla()

    def get_mem_ratio(self, input_size):
        k, b = self._params
        if k == 0:
            # no model fitted yet: explore random ratios up to the current cap
            return np.random.rand() * self._max_mem_ratio
        pred = (self.available_memory * self.target_ratio - b) / (k * input_size)
        return min(self._max_mem_ratio, max(0.0, pred))

    def state_dict(self):
        return {
            'params': self._params,
        }

    def load_state_dict(self, state_dict):
        self._params = tuple(state_dict['params'])

    def log(self):
        return {
            'params/k': self._params[0],
            'params/b': self._params[1],
            'memory': self._last_memory,
            'input_size': self._last_input_size,
            'mem_ratio': self._last_mem_ratio,
        }


class ElasticModule(nn.Module):
    """
    Module for training with elastic memory management.
    """
    def __init__(self):
        super().__init__()
        self._memory_controller: MemoryController = None

    @abstractmethod
    def _get_input_size(self, *args, **kwargs) -> int:
        """
        Get the size of the input data.

        Returns:
            int: The size of the input data.
        """
        pass

    @abstractmethod
    def _forward_with_mem_ratio(self, *args, mem_ratio=0.0, **kwargs) -> Tuple[float, Tuple]:
        """
        Forward with a given memory ratio.
        """
        pass

    def register_memory_controller(self, memory_controller: MemoryController):
        self._memory_controller = memory_controller

    def forward(self, *args, **kwargs):
        if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
            _, ret = self._forward_with_mem_ratio(*args, **kwargs)
        else:
            input_size = self._get_input_size(*args, **kwargs)
            mem_ratio = self._memory_controller.get_mem_ratio(input_size)
            mem_ratio, ret = self._forward_with_mem_ratio(*args, mem_ratio=mem_ratio, **kwargs)
            self._memory_controller.update_run_states(input_size, mem_ratio)
        return ret


class ElasticModuleMixin:
    """
    Mixin for training with elastic memory management.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._memory_controller: MemoryController = None

    @abstractmethod
    def _get_input_size(self, *args, **kwargs) -> int:
        """
        Get the size of the input data.

        Returns:
            int: The size of the input data.
        """
        pass

    @abstractmethod
    @contextmanager
    def with_mem_ratio(self, mem_ratio=1.0) -> float:
        """
        Context manager for running the forward pass with a reduced memory
        ratio compared to the full memory usage.

        Returns:
            float: The exact memory ratio used during the forward pass.
        """
        pass

    def register_memory_controller(self, memory_controller: MemoryController):
        self._memory_controller = memory_controller

    def forward(self, *args, **kwargs):
        if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
            ret = super().forward(*args, **kwargs)
        else:
            input_size = self._get_input_size(*args, **kwargs)
            mem_ratio = self._memory_controller.get_mem_ratio(input_size)
            with self.with_mem_ratio(mem_ratio) as exact_mem_ratio:
                ret = super().forward(*args, **kwargs)
            self._memory_controller.update_run_states(input_size, exact_mem_ratio)
        return ret
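A hypothetical sketch of attaching the mixin to a model: a plain module defines forward, and an elastic subclass maps mem_ratio to a number of activation-checkpointed blocks. The MLP, the batch-size load proxy, and the rounding scheme are illustrative assumptions; a recent PyTorch (with non-reentrant checkpointing) and a CUDA device are assumed.

from contextlib import contextmanager
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from trellis.utils.elastic_utils import ElasticModuleMixin, LinearMemoryController

class MLP(nn.Module):
    def __init__(self, dim=256, depth=8):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(depth)]
        )
        self._num_ckpt = 0  # how many leading blocks run under checkpointing

    def forward(self, x):
        for i, blk in enumerate(self.blocks):
            x = checkpoint(blk, x, use_reentrant=False) if i < self._num_ckpt else blk(x)
        return x

class ElasticMLP(ElasticModuleMixin, MLP):
    def _get_input_size(self, x):
        return x.shape[0]  # batch size as the load proxy

    @contextmanager
    def with_mem_ratio(self, mem_ratio=1.0):
        # mem_ratio=1.0 -> no checkpointing; lower ratios checkpoint more blocks
        self._num_ckpt = round((1 - mem_ratio) * len(self.blocks))
        try:
            yield 1 - self._num_ckpt / len(self.blocks)  # exact ratio actually used
        finally:
            self._num_ckpt = 0

model = ElasticMLP().cuda().train()
ctrl = LinearMemoryController(target_ratio=0.7)  # needs a CUDA device
model.register_memory_controller(ctrl)

x = torch.randn(32, 256, device='cuda')
with ctrl.record():  # records peak memory and the ratio for the linear fit
    model(x).square().mean().backward()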
trellis/utils/general_utils.py ADDED
@@ -0,0 +1,202 @@
import re
import numpy as np
import cv2
import torch
import contextlib


# Dictionary utils
def _dict_merge(dicta, dictb, prefix=''):
    """
    Merge two dictionaries.
    """
    assert isinstance(dicta, dict), 'input must be a dictionary'
    assert isinstance(dictb, dict), 'input must be a dictionary'
    dict_ = {}
    all_keys = set(dicta.keys()).union(set(dictb.keys()))
    for key in all_keys:
        if key in dicta.keys() and key in dictb.keys():
            if isinstance(dicta[key], dict) and isinstance(dictb[key], dict):
                dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}')
            else:
                raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}')
        elif key in dicta.keys():
            dict_[key] = dicta[key]
        else:
            dict_[key] = dictb[key]
    return dict_


def dict_merge(dicta, dictb):
    """
    Merge two dictionaries.
    """
    return _dict_merge(dicta, dictb, prefix='')


def dict_foreach(dic, func, special_func={}):
    """
    Recursively apply a function to all non-dictionary leaf values in a dictionary.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    for key in dic.keys():
        if isinstance(dic[key], dict):
            dic[key] = dict_foreach(dic[key], func, special_func)
        else:
            if key in special_func.keys():
                dic[key] = special_func[key](dic[key])
            else:
                dic[key] = func(dic[key])
    return dic


def dict_reduce(dicts, func, special_func={}):
    """
    Reduce a list of dictionaries. Leaf values must be scalars.
    """
    assert isinstance(dicts, list), 'input must be a list of dictionaries'
    assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries'
    assert len(dicts) > 0, 'input must be a non-empty list of dictionaries'
    all_keys = set([key for dict_ in dicts for key in dict_.keys()])
    reduced_dict = {}
    for key in all_keys:
        vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()]
        if isinstance(vlist[0], dict):
            reduced_dict[key] = dict_reduce(vlist, func, special_func)
        else:
            if key in special_func.keys():
                reduced_dict[key] = special_func[key](vlist)
            else:
                reduced_dict[key] = func(vlist)
    return reduced_dict


def dict_any(dic, func):
    """
    Recursively check whether a predicate holds for any non-dictionary leaf value in a dictionary.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    for key in dic.keys():
        if isinstance(dic[key], dict):
            if dict_any(dic[key], func):
                return True
        else:
            if func(dic[key]):
                return True
    return False


def dict_all(dic, func):
    """
    Recursively check whether a predicate holds for all non-dictionary leaf values in a dictionary.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    for key in dic.keys():
        if isinstance(dic[key], dict):
            if not dict_all(dic[key], func):
                return False
        else:
            if not func(dic[key]):
                return False
    return True


def dict_flatten(dic, sep='.'):
    """
    Flatten a nested dictionary into a dictionary with no nested dictionaries.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    flat_dict = {}
    for key in dic.keys():
        if isinstance(dic[key], dict):
            sub_dict = dict_flatten(dic[key], sep=sep)
            for sub_key in sub_dict.keys():
                flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key]
        else:
            flat_dict[key] = dic[key]
    return flat_dict


# Context utils
@contextlib.contextmanager
def nested_contexts(*contexts):
    with contextlib.ExitStack() as stack:
        for ctx in contexts:
            stack.enter_context(ctx())
        yield


# Image utils
def make_grid(images, nrow=None, ncol=None, aspect_ratio=None):
    num_images = len(images)
    if nrow is None and ncol is None:
        if aspect_ratio is not None:
            nrow = int(np.round(np.sqrt(num_images / aspect_ratio)))
        else:
            nrow = int(np.sqrt(num_images))
        ncol = (num_images + nrow - 1) // nrow
    elif nrow is None and ncol is not None:
        nrow = (num_images + ncol - 1) // ncol
    elif nrow is not None and ncol is None:
        ncol = (num_images + nrow - 1) // nrow
    else:
        assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images'

    if images[0].ndim == 2:
        grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype)
    else:
        grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype)
    for i, img in enumerate(images):
        row = i // ncol
        col = i % ncol
        grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img
    return grid


def notes_on_image(img, notes=None):
    img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    if notes is not None:
        img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img


def save_image_with_notes(img, path, notes=None):
    """
    Save an image with notes.
    """
    if isinstance(img, torch.Tensor):
        img = img.cpu().numpy().transpose(1, 2, 0)
    if img.dtype == np.float32 or img.dtype == np.float64:
        img = np.clip(img * 255, 0, 255).astype(np.uint8)
    img = notes_on_image(img, notes)
    cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))


# Debug utils
def atol(x, y):
    """
    Absolute tolerance.
    """
    return torch.abs(x - y)


def rtol(x, y):
    """
    Relative tolerance.
    """
    return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12)


# Print utils
def indent(s, n=4):
    """
    Indent a string.
    """
    lines = s.split('\n')
    for i in range(1, len(lines)):
        lines[i] = ' ' * n + lines[i]
    return '\n'.join(lines)
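A quick sketch of the dictionary helpers on nested configs and per-step metrics (the values are made up):

from trellis.utils.general_utils import dict_merge, dict_flatten, dict_reduce

a = {'model': {'dim': 512}, 'lr': 1e-4}
b = {'model': {'depth': 24}}
cfg = dict_merge(a, b)   # {'model': {'dim': 512, 'depth': 24}, 'lr': 0.0001}

dict_flatten(cfg)        # {'model.dim': 512, 'model.depth': 24, 'lr': 0.0001}

logs = [{'loss': {'l1': 0.5}}, {'loss': {'l1': 0.3}}]
dict_reduce(logs, lambda v: sum(v) / len(v))   # {'loss': {'l1': 0.4}}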
trellis/utils/grad_clip_utils.py ADDED
@@ -0,0 +1,81 @@
from typing import *
import torch
import numpy as np


class AdaptiveGradClipper:
    """
    Adaptive gradient clipping for training.
    """
    def __init__(
        self,
        max_norm=None,
        clip_percentile=95.0,
        buffer_size=1000,
    ):
        self.max_norm = max_norm
        self.clip_percentile = clip_percentile
        self.buffer_size = buffer_size

        self._grad_norm = np.zeros(buffer_size, dtype=np.float32)
        self._max_norm = max_norm
        self._buffer_ptr = 0
        self._buffer_length = 0

    def __repr__(self):
        return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})'

    def state_dict(self):
        return {
            'grad_norm': self._grad_norm,
            'max_norm': self._max_norm,
            'buffer_ptr': self._buffer_ptr,
            'buffer_length': self._buffer_length,
        }

    def load_state_dict(self, state_dict):
        self._grad_norm = state_dict['grad_norm']
        self._max_norm = state_dict['max_norm']
        self._buffer_ptr = state_dict['buffer_ptr']
        self._buffer_length = state_dict['buffer_length']

    def log(self):
        return {
            'max_norm': self._max_norm,
        }

    def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None):
        """Clip the gradient norm of an iterable of parameters.

        The norm is computed over all gradients together, as if they were
        concatenated into a single vector. Gradients are modified in-place.

        Args:
            parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
                single Tensor that will have gradients normalized
            norm_type (float): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.
            error_if_nonfinite (bool): if True, an error is thrown if the total
                norm of the gradients from :attr:`parameters` is ``nan``,
                ``inf``, or ``-inf``. Default: False (will switch to True in the future)
            foreach (bool): use the faster foreach-based implementation.
                If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
                fall back to the slow implementation for other device types.
                Default: ``None``

        Returns:
            Total norm of the parameter gradients (viewed as a single vector).
        """
        max_norm = self._max_norm if self._max_norm is not None else float('inf')
        grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach)

        if torch.isfinite(grad_norm):
            self._grad_norm[self._buffer_ptr] = grad_norm
            self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
            self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
            if self._buffer_length == self.buffer_size:
                self._max_norm = np.percentile(self._grad_norm, self.clip_percentile)
                self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm

        return grad_norm
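A sketch of the clipper in a training step; the tiny model and step count are illustrative. It behaves like torch.nn.utils.clip_grad_norm_, except that once its buffer fills it tightens max_norm to a running percentile of the observed gradient norms:

import torch
from trellis.utils.grad_clip_utils import AdaptiveGradClipper

model = torch.nn.Linear(16, 16)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
clipper = AdaptiveGradClipper(max_norm=10.0, clip_percentile=95.0, buffer_size=100)

for _ in range(200):
    opt.zero_grad()
    loss = model(torch.randn(8, 16)).square().mean()
    loss.backward()
    grad_norm = clipper(model.parameters())  # clips in place, records the norm
    opt.step()

# The clipper state round-trips through checkpoints like any other component:
state = clipper.state_dict()
clipper.load_state_dict(state)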
trellis/utils/loss_utils.py ADDED
@@ -0,0 +1,92 @@
import torch
import torch.nn.functional as F
from math import exp
from lpips import LPIPS


def smooth_l1_loss(pred, target, beta=1.0):
    diff = torch.abs(pred - target)
    loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
    return loss.mean()


def l1_loss(network_output, gt):
    return torch.abs(network_output - gt).mean()


def l2_loss(network_output, gt):
    return ((network_output - gt) ** 2).mean()


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


def psnr(img1, img2, max_val=1.0):
    mse = F.mse_loss(img1, img2)
    return 20 * torch.log10(max_val / torch.sqrt(mse))


def ssim(img1, img2, window_size=11, size_average=True):
    channel = img1.size(-3)
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


loss_fn_vgg = None

def lpips(img1, img2, value_range=(0, 1)):
    global loss_fn_vgg
    if loss_fn_vgg is None:
        loss_fn_vgg = LPIPS(net='vgg').cuda().eval()
    # normalize to [-1, 1]
    img1 = (img1 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1
    img2 = (img2 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1
    return loss_fn_vgg(img1, img2).mean()


def normal_angle(pred, gt):
    pred = pred * 2.0 - 1.0
    gt = gt * 2.0 - 1.0
    norms = pred.norm(dim=-1) * gt.norm(dim=-1)
    cos_sim = (pred * gt).sum(-1) / (norms + 1e-9)
    cos_sim = torch.clamp(cos_sim, -1.0, 1.0)
    ang = torch.rad2deg(torch.acos(cos_sim[norms > 1e-9])).mean()
    if ang.isnan():
        return -1
    return ang
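A quick sketch exercising the image metrics on a noisy copy of a random image in [0, 1]; lpips additionally needs the lpips package and a CUDA device for its VGG backbone:

import torch
from trellis.utils.loss_utils import l1_loss, psnr, ssim

img1 = torch.rand(1, 3, 64, 64)
img2 = (img1 + 0.05 * torch.randn_like(img1)).clamp(0, 1)

print(l1_loss(img1, img2))  # mean absolute error
print(psnr(img1, img2))     # roughly 26 dB for noise std 0.05
print(ssim(img1, img2))     # windowed structural similarity, near 1 for close images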
trellis/utils/postprocessing_utils.py ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import numpy as np
3
+ import torch
4
+ import utils3d
5
+ import nvdiffrast.torch as dr
6
+ from tqdm import tqdm
7
+ import trimesh
8
+ import trimesh.visual
9
+ import xatlas
10
+ import pyvista as pv
11
+ from pymeshfix import _meshfix
12
+ import igraph
13
+ import cv2
14
+ from PIL import Image
15
+ from .random_utils import sphere_hammersley_sequence
16
+ from .render_utils import render_multiview
17
+ from ..renderers import GaussianRenderer
18
+ from ..representations import Strivec, Gaussian, MeshExtractResult
19
+
20
+
21
+ @torch.no_grad()
22
+ def _fill_holes(
23
+ verts,
24
+ faces,
25
+ max_hole_size=0.04,
26
+ max_hole_nbe=32,
27
+ resolution=128,
28
+ num_views=500,
29
+ debug=False,
30
+ verbose=False
31
+ ):
32
+ """
33
+ Rasterize a mesh from multiple views and remove invisible faces.
34
+ Also includes postprocessing to:
35
+ 1. Remove connected components that are have low visibility.
36
+ 2. Mincut to remove faces at the inner side of the mesh connected to the outer side with a small hole.
37
+
38
+ Args:
39
+ verts (torch.Tensor): Vertices of the mesh. Shape (V, 3).
40
+ faces (torch.Tensor): Faces of the mesh. Shape (F, 3).
41
+ max_hole_size (float): Maximum area of a hole to fill.
42
+ resolution (int): Resolution of the rasterization.
43
+ num_views (int): Number of views to rasterize the mesh.
44
+ verbose (bool): Whether to print progress.
45
+ """
46
+ # Construct cameras
47
+ yaws = []
48
+ pitchs = []
49
+ for i in range(num_views):
50
+ y, p = sphere_hammersley_sequence(i, num_views)
51
+ yaws.append(y)
52
+ pitchs.append(p)
53
+ yaws = torch.tensor(yaws).cuda()
54
+ pitchs = torch.tensor(pitchs).cuda()
55
+ radius = 2.0
56
+ fov = torch.deg2rad(torch.tensor(40)).cuda()
57
+ projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
58
+ views = []
59
+ for (yaw, pitch) in zip(yaws, pitchs):
60
+ orig = torch.tensor([
61
+ torch.sin(yaw) * torch.cos(pitch),
62
+ torch.cos(yaw) * torch.cos(pitch),
63
+ torch.sin(pitch),
64
+ ]).cuda().float() * radius
65
+ view = utils3d.torch.view_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
66
+ views.append(view)
67
+ views = torch.stack(views, dim=0)
68
+
69
+ # Rasterize
70
+ visblity = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device)
71
+ rastctx = utils3d.torch.RastContext(backend='cuda')
72
+ for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'):
73
+ view = views[i]
74
+ buffers = utils3d.torch.rasterize_triangle_faces(
75
+ rastctx, verts[None], faces, resolution, resolution, view=view, projection=projection
76
+ )
77
+ face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1
78
+ face_id = torch.unique(face_id).long()
79
+ visblity[face_id] += 1
80
+ visblity = visblity.float() / num_views
81
+
82
+ # Mincut
83
+ ## construct outer faces
84
+ edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
85
+ boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
86
+ connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge)
87
+ outer_face_indices = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device)
88
+ for i in range(len(connected_components)):
89
+ outer_face_indices[connected_components[i]] = visblity[connected_components[i]] > min(max(visblity[connected_components[i]].quantile(0.75).item(), 0.25), 0.5)
90
+ outer_face_indices = outer_face_indices.nonzero().reshape(-1)
91
+
92
+ ## construct inner faces
93
+ inner_face_indices = torch.nonzero(visblity == 0).reshape(-1)
94
+ if verbose:
95
+ tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces')
96
+ if inner_face_indices.shape[0] == 0:
97
+ return verts, faces
98
+
99
+ ## Construct dual graph (faces as nodes, edges as edges)
100
+ dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge)
101
+ dual_edge2edge = edges[dual_edge2edge]
102
+ dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1)
103
+ if verbose:
104
+ tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges')
105
+
106
+ ## solve mincut problem
107
+ ### construct main graph
108
+ g = igraph.Graph()
109
+ g.add_vertices(faces.shape[0])
110
+ g.add_edges(dual_edges.cpu().numpy())
111
+ g.es['weight'] = dual_edges_weights.cpu().numpy()
112
+
113
+ ### source and target
114
+ g.add_vertex('s')
115
+ g.add_vertex('t')
116
+
117
+ ### connect invisible faces to source
118
+ g.add_edges([(f, 's') for f in inner_face_indices], attributes={'weight': torch.ones(inner_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})
119
+
120
+ ### connect outer faces to target
121
+ g.add_edges([(f, 't') for f in outer_face_indices], attributes={'weight': torch.ones(outer_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})
122
+
123
+ ### solve mincut
124
+ cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist())
125
+ remove_face_indices = torch.tensor([v for v in cut.partition[0] if v < faces.shape[0]], dtype=torch.long, device=faces.device)
126
+ if verbose:
127
+ tqdm.write(f'Mincut solved, start checking the cut')
128
+
129
+ ### check if the cut is valid with each connected component
130
+ to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices])
131
+ if debug:
132
+ tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}')
133
+ valid_remove_cc = []
134
+ cutting_edges = []
135
+ for cc in to_remove_cc:
136
+ #### check if the connected component has low visibility
137
+ visblity_median = visblity[remove_face_indices[cc]].median()
138
+ if debug:
139
+ tqdm.write(f'visblity_median: {visblity_median}')
140
+ if visblity_median > 0.25:
141
+ continue
142
+
143
+ #### check if the cuting loop is small enough
144
+ cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True)
145
+ cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
146
+ cc_new_boundary_edge_indices = cc_boundary_edge_indices[~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)]
147
+ if len(cc_new_boundary_edge_indices) > 0:
148
+ cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(edges[cc_new_boundary_edge_indices])
149
+ cc_new_boundary_edges_cc_center = [verts[edges[cc_new_boundary_edge_indices[edge_cc]]].mean(dim=1).mean(dim=0) for edge_cc in cc_new_boundary_edge_cc]
150
+ cc_new_boundary_edges_cc_area = []
151
+ for i, edge_cc in enumerate(cc_new_boundary_edge_cc):
152
+ _e1 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] - cc_new_boundary_edges_cc_center[i]
153
+ _e2 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] - cc_new_boundary_edges_cc_center[i]
154
+ cc_new_boundary_edges_cc_area.append(torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5)
155
+ if debug:
156
+ cutting_edges.append(cc_new_boundary_edge_indices)
157
+ tqdm.write(f'Area of the cutting loop: {cc_new_boundary_edges_cc_area}')
158
+ if any([l > max_hole_size for l in cc_new_boundary_edges_cc_area]):
159
+ continue
160
+
161
+ valid_remove_cc.append(cc)
162
+
163
+ if debug:
164
+ face_v = verts[faces].mean(dim=1).cpu().numpy()
165
+ vis_dual_edges = dual_edges.cpu().numpy()
166
+ vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8)
167
+ vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255]
168
+ vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0]
169
+ vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255]
170
+ if len(valid_remove_cc) > 0:
171
+ vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0]
172
+ utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors)
173
+
174
+ vis_verts = verts.cpu().numpy()
175
+ vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy()
176
+ utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges)
177
+
178
+
179
+ if len(valid_remove_cc) > 0:
180
+ remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)]
181
+ mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device)
182
+ mask[remove_face_indices] = 0
183
+ faces = faces[mask]
184
+ faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts)
185
+ if verbose:
186
+ tqdm.write(f'Removed {(~mask).sum()} faces by mincut')
187
+ else:
188
+ if verbose:
189
+ tqdm.write(f'Removed 0 faces by mincut')
190
+
191
+ mesh = _meshfix.PyTMesh()
192
+ mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy())
193
+ mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
194
+ verts, faces = mesh.return_arrays()
195
+ verts, faces = torch.tensor(verts, device='cuda', dtype=torch.float32), torch.tensor(faces, device='cuda', dtype=torch.int32)
196
+
197
+ return verts, faces
198
+
199
+
200
+ def postprocess_mesh(
201
+ vertices: np.array,
202
+ faces: np.array,
203
+ simplify: bool = True,
204
+ simplify_ratio: float = 0.9,
205
+ fill_holes: bool = True,
206
+ fill_holes_max_hole_size: float = 0.04,
207
+ fill_holes_max_hole_nbe: int = 32,
208
+ fill_holes_resolution: int = 1024,
209
+ fill_holes_num_views: int = 1000,
210
+ debug: bool = False,
211
+ verbose: bool = False,
212
+ ):
213
+ """
214
+ Postprocess a mesh by simplifying, removing invisible faces, and removing isolated pieces.
215
+
216
+ Args:
217
+ vertices (np.array): Vertices of the mesh. Shape (V, 3).
218
+ faces (np.array): Faces of the mesh. Shape (F, 3).
219
+ simplify (bool): Whether to simplify the mesh, using quadric edge collapse.
220
+ simplify_ratio (float): Ratio of faces to keep after simplification.
221
+ fill_holes (bool): Whether to fill holes in the mesh.
222
+ fill_holes_max_hole_size (float): Maximum area of a hole to fill.
223
+ fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill.
224
+ fill_holes_resolution (int): Resolution of the rasterization.
225
+ fill_holes_num_views (int): Number of views to rasterize the mesh.
226
+ verbose (bool): Whether to print progress.
227
+ """
228
+
229
+ if verbose:
230
+ tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
231
+
232
+ # Simplify
233
+ if simplify and simplify_ratio > 0:
234
+ mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1))
235
+ mesh = mesh.decimate(simplify_ratio, progress_bar=verbose)
236
+ vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:]
237
+ if verbose:
238
+ tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
239
+
240
+ # Remove invisible faces
241
+ if fill_holes:
242
+ vertices, faces = torch.tensor(vertices).cuda(), torch.tensor(faces.astype(np.int32)).cuda()
243
+ vertices, faces = _fill_holes(
244
+ vertices, faces,
245
+ max_hole_size=fill_holes_max_hole_size,
246
+ max_hole_nbe=fill_holes_max_hole_nbe,
247
+ resolution=fill_holes_resolution,
248
+ num_views=fill_holes_num_views,
249
+ debug=debug,
250
+ verbose=verbose,
251
+ )
252
+ vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy()
253
+ if verbose:
254
+ tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
255
+
256
+ return vertices, faces
257
+
258
+
259
+ def parametrize_mesh(vertices: np.array, faces: np.array):
260
+ """
261
+ Parametrize a mesh to a texture space, using xatlas.
262
+
263
+ Args:
264
+ vertices (np.array): Vertices of the mesh. Shape (V, 3).
265
+ faces (np.array): Faces of the mesh. Shape (F, 3).
266
+ """
267
+
268
+ vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
269
+
270
+ vertices = vertices[vmapping]
271
+ faces = indices
272
+
273
+ return vertices, faces, uvs
274
+
275
+
276
+ def bake_texture(
277
+ vertices: np.array,
278
+ faces: np.array,
279
+ uvs: np.array,
280
+ observations: List[np.array],
281
+ masks: List[np.array],
282
+ extrinsics: List[np.array],
283
+ intrinsics: List[np.array],
284
+ texture_size: int = 2048,
285
+ near: float = 0.1,
286
+ far: float = 10.0,
287
+ mode: Literal['fast', 'opt'] = 'opt',
288
+ lambda_tv: float = 1e-2,
289
+ verbose: bool = False,
290
+ ):
291
+ """
292
+ Bake texture to a mesh from multiple observations.
293
+
294
+ Args:
295
+ vertices (np.array): Vertices of the mesh. Shape (V, 3).
296
+ faces (np.array): Faces of the mesh. Shape (F, 3).
297
+ uvs (np.array): UV coordinates of the mesh. Shape (V, 2).
298
+ observations (List[np.array]): List of observations. Each observation is a 2D image. Shape (H, W, 3).
299
+ masks (List[np.array]): List of masks. Each mask is a 2D image. Shape (H, W).
300
+ extrinsics (List[np.array]): List of extrinsics. Shape (4, 4).
301
+ intrinsics (List[np.array]): List of intrinsics. Shape (3, 3).
302
+ texture_size (int): Size of the texture.
303
+ near (float): Near plane of the camera.
304
+ far (float): Far plane of the camera.
305
+ mode (Literal['fast', 'opt']): Mode of texture baking.
306
+ lambda_tv (float): Weight of total variation loss in optimization.
307
+ verbose (bool): Whether to print progress.
308
+ """
309
+ vertices = torch.tensor(vertices).cuda()
310
+ faces = torch.tensor(faces.astype(np.int32)).cuda()
311
+ uvs = torch.tensor(uvs).cuda()
312
+ observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations]
313
+ masks = [torch.tensor(m>0).bool().cuda() for m in masks]
314
+ views = [utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda()) for extr in extrinsics]
315
+ projections = [utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far) for intr in intrinsics]
316
+
317
+ if mode == 'fast':
318
+ texture = torch.zeros((texture_size * texture_size, 3), dtype=torch.float32).cuda()
319
+ texture_weights = torch.zeros((texture_size * texture_size), dtype=torch.float32).cuda()
320
+ rastctx = utils3d.torch.RastContext(backend='cuda')
321
+ for observation, view, projection in tqdm(zip(observations, views, projections), total=len(observations), disable=not verbose, desc='Texture baking (fast)'):
322
+ with torch.no_grad():
323
+ rast = utils3d.torch.rasterize_triangle_faces(
324
+ rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
325
+ )
326
+ uv_map = rast['uv'][0].detach().flip(0)
327
+ mask = rast['mask'][0].detach().bool() & masks[0]
328
+
329
+ # nearest neighbor interpolation
330
+ uv_map = (uv_map * texture_size).floor().long()
331
+ obs = observation[mask]
332
+ uv_map = uv_map[mask]
333
+ idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
334
+ texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs)
335
+ texture_weights = texture_weights.scatter_add(0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device))
336
+
337
+ mask = texture_weights > 0
338
+ texture[mask] /= texture_weights[mask][:, None]
339
+ texture = np.clip(texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255).astype(np.uint8)
340
+
341
+ # inpaint
342
+ mask = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size)
343
+ texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
344
+
345
+ elif mode == 'opt':
346
+ rastctx = utils3d.torch.RastContext(backend='cuda')
347
+ observations = [observations.flip(0) for observations in observations]
348
+ masks = [m.flip(0) for m in masks]
349
+ _uv = []
350
+ _uv_dr = []
351
+ for observation, view, projection in tqdm(zip(observations, views, projections), total=len(views), disable=not verbose, desc='Texture baking (opt): UV'):
352
+ with torch.no_grad():
353
+ rast = utils3d.torch.rasterize_triangle_faces(
354
+ rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
355
+ )
356
+ _uv.append(rast['uv'].detach())
357
+ _uv_dr.append(rast['uv_dr'].detach())
358
+
359
+ texture = torch.nn.Parameter(torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda())
360
+ optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)
361
+
362
+ def exp_anealing(optimizer, step, total_steps, start_lr, end_lr):
363
+ return start_lr * (end_lr / start_lr) ** (step / total_steps)
364
+
365
+ def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr):
366
+ return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))
367
+
368
+ def tv_loss(texture):
369
+ return torch.nn.functional.l1_loss(texture[:, :-1, :, :], texture[:, 1:, :, :]) + \
370
+ torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :])
371
+
372
+ total_steps = 2500
373
+ with tqdm(total=total_steps, disable=not verbose, desc='Texture baking (opt): optimizing') as pbar:
374
+ for step in range(total_steps):
375
+ optimizer.zero_grad()
376
+ selected = np.random.randint(0, len(views))
377
+ uv, uv_dr, observation, mask = _uv[selected], _uv_dr[selected], observations[selected], masks[selected]
378
+ render = dr.texture(texture, uv, uv_dr)[0]
379
+ loss = torch.nn.functional.l1_loss(render[mask], observation[mask])
380
+ if lambda_tv > 0:
381
+ loss += lambda_tv * tv_loss(texture)
382
+ loss.backward()
383
+ optimizer.step()
384
+ # annealing
385
+ optimizer.param_groups[0]['lr'] = cosine_anealing(optimizer, step, total_steps, 1e-2, 1e-5)
386
+ pbar.set_postfix({'loss': loss.item()})
387
+ pbar.update()
388
+ texture = np.clip(texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)
389
+ mask = 1 - utils3d.torch.rasterize_triangle_faces(
390
+ rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size
391
+ )['mask'][0].detach().cpu().numpy().astype(np.uint8)
392
+ texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
393
+ else:
394
+ raise ValueError(f'Unknown mode: {mode}')
395
+
396
+ return texture
397
+
398
+
399
+ def to_glb(
400
+ app_rep: Union[Strivec, Gaussian],
401
+ mesh: MeshExtractResult,
402
+ simplify: float = 0.95,
403
+ fill_holes: bool = True,
404
+ fill_holes_max_size: float = 0.04,
405
+ texture_size: int = 1024,
406
+ debug: bool = False,
407
+ verbose: bool = True,
408
+ ) -> trimesh.Trimesh:
409
+ """
410
+ Convert a generated asset to a glb file.
411
+
412
+ Args:
413
+ app_rep (Union[Strivec, Gaussian]): Appearance representation.
414
+ mesh (MeshExtractResult): Extracted mesh.
415
+ simplify (float): Ratio of faces to remove in simplification.
416
+ fill_holes (bool): Whether to fill holes in the mesh.
417
+ fill_holes_max_size (float): Maximum area of a hole to fill.
418
+ texture_size (int): Size of the texture.
419
+ debug (bool): Whether to print debug information.
420
+ verbose (bool): Whether to print progress.
421
+ """
422
+ vertices = mesh.vertices.cpu().numpy()
423
+ faces = mesh.faces.cpu().numpy()
424
+
425
+ # mesh postprocess
426
+ vertices, faces = postprocess_mesh(
427
+ vertices, faces,
428
+ simplify=simplify > 0,
429
+ simplify_ratio=simplify,
430
+ fill_holes=fill_holes,
431
+ fill_holes_max_hole_size=fill_holes_max_size,
432
+ fill_holes_max_hole_nbe=int(250 * np.sqrt(1-simplify)),
433
+ fill_holes_resolution=1024,
434
+ fill_holes_num_views=1000,
435
+ debug=debug,
436
+ verbose=verbose,
437
+ )
438
+
439
+ # parametrize mesh
440
+ vertices, faces, uvs = parametrize_mesh(vertices, faces)
441
+
442
+ # bake texture
443
+ observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100)
444
+ masks = [np.any(observation > 0, axis=-1) for observation in observations]
445
+ extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
446
+ intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]
447
+ texture = bake_texture(
448
+ vertices, faces, uvs,
449
+ observations, masks, extrinsics, intrinsics,
450
+ texture_size=texture_size, mode='opt',
451
+ lambda_tv=0.01,
452
+ verbose=verbose
453
+ )
454
+ texture = Image.fromarray(texture)
455
+
456
+ # rotate mesh (from z-up to y-up)
457
+ vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
458
+ material = trimesh.visual.material.PBRMaterial(
459
+ roughnessFactor=1.0,
460
+ baseColorTexture=texture,
461
+ baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8)
462
+ )
463
+ mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, material=material))
464
+ return mesh
465
+
466
+
467
+ def simplify_gs(
468
+ gs: Gaussian,
469
+ simplify: float = 0.95,
470
+ verbose: bool = True,
471
+ ):
472
+ """
473
+ Simplify 3D Gaussians
474
+ NOTE: this function is not used in the current implementation for the unsatisfactory performance.
475
+
476
+ Args:
477
+ gs (Gaussian): 3D Gaussian.
478
+ simplify (float): Ratio of Gaussians to remove in simplification.
479
+ """
480
+ if simplify <= 0:
481
+ return gs
482
+
483
+ # simplify
484
+ observations, extrinsics, intrinsics = render_multiview(gs, resolution=1024, nviews=100)
485
+ observations = [torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1) for obs in observations]
486
+
487
+ # Following https://arxiv.org/pdf/2411.06019
488
+ renderer = GaussianRenderer({
489
+ "resolution": 1024,
490
+ "near": 0.8,
491
+ "far": 1.6,
492
+ "ssaa": 1,
493
+ "bg_color": (0,0,0),
494
+ })
495
+ new_gs = Gaussian(**gs.init_params)
496
+ new_gs._features_dc = gs._features_dc.clone()
497
+ new_gs._features_rest = gs._features_rest.clone() if gs._features_rest is not None else None
498
+ new_gs._opacity = torch.nn.Parameter(gs._opacity.clone())
499
+ new_gs._rotation = torch.nn.Parameter(gs._rotation.clone())
500
+ new_gs._scaling = torch.nn.Parameter(gs._scaling.clone())
501
+ new_gs._xyz = torch.nn.Parameter(gs._xyz.clone())
502
+
503
+ start_lr = [1e-4, 1e-3, 5e-3, 0.025]
504
+ end_lr = [1e-6, 1e-5, 5e-5, 0.00025]
505
+ optimizer = torch.optim.Adam([
506
+ {"params": new_gs._xyz, "lr": start_lr[0]},
507
+ {"params": new_gs._rotation, "lr": start_lr[1]},
508
+ {"params": new_gs._scaling, "lr": start_lr[2]},
509
+ {"params": new_gs._opacity, "lr": start_lr[3]},
510
+ ], lr=start_lr[0])
511
+
512
+ def exp_anealing(optimizer, step, total_steps, start_lr, end_lr):
513
+ return start_lr * (end_lr / start_lr) ** (step / total_steps)
514
+
515
+ def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr):
516
+ return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))
517
+
518
+ _zeta = new_gs.get_opacity.clone().detach().squeeze()
519
+ _lambda = torch.zeros_like(_zeta)
520
+ _delta = 1e-7
521
+ _interval = 10
522
+ num_target = int((1 - simplify) * _zeta.shape[0])
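+     # In the loop below, the total loss is
+     #     L1(render, observation) + _delta * ||_lambda + opacity - _zeta||^2;
+     # _delta weights the penalty, _interval sets how often _zeta and _lambda
+     # are refreshed, and num_target is the number of Gaussians to keep.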
+ 
+     with tqdm(total=2500, disable=not verbose, desc='Simplifying Gaussian') as pbar:
+         for i in range(2500):
+             # Prune Gaussians whose opacity has fallen below 0.05.
+             if i % 100 == 0:
+                 mask = new_gs.get_opacity.squeeze() > 0.05
+                 mask = torch.nonzero(mask).squeeze()
+                 new_gs._xyz = torch.nn.Parameter(new_gs._xyz[mask])
+                 new_gs._rotation = torch.nn.Parameter(new_gs._rotation[mask])
+                 new_gs._scaling = torch.nn.Parameter(new_gs._scaling[mask])
+                 new_gs._opacity = torch.nn.Parameter(new_gs._opacity[mask])
+                 new_gs._features_dc = new_gs._features_dc[mask]
+                 new_gs._features_rest = new_gs._features_rest[mask] if new_gs._features_rest is not None else None
+                 _zeta = _zeta[mask]
+                 _lambda = _lambda[mask]
+                 # Update the optimizer state so Adam's moment estimates stay
+                 # aligned with the surviving Gaussians.
+                 for param_group, new_param in zip(optimizer.param_groups, [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity]):
+                     stored_state = optimizer.state[param_group['params'][0]]
+                     if 'exp_avg' in stored_state:
+                         stored_state['exp_avg'] = stored_state['exp_avg'][mask]
+                         stored_state['exp_avg_sq'] = stored_state['exp_avg_sq'][mask]
+                     del optimizer.state[param_group['params'][0]]
+                     param_group['params'][0] = new_param
+                     optimizer.state[param_group['params'][0]] = stored_state
+ 
+             opacity = new_gs.get_opacity.squeeze()
+ 
+             # sparsify: project _zeta onto the num_target largest entries
+             if i % _interval == 0:
+                 _zeta = _lambda + opacity.detach()
+                 if opacity.shape[0] > num_target:
+                     index = _zeta.topk(num_target)[1]
+                     _m = torch.ones_like(_zeta, dtype=torch.bool)
+                     _m[index] = 0
+                     _zeta[_m] = 0
+                 _lambda = _lambda + opacity.detach() - _zeta
+ 
+             # sample a random view
+             view_idx = np.random.randint(len(observations))
+             observation = observations[view_idx]
+             extrinsic = extrinsics[view_idx]
+             intrinsic = intrinsics[view_idx]
+ 
+             color = renderer.render(new_gs, extrinsic, intrinsic)['color']
+             rgb_loss = torch.nn.functional.l1_loss(color, observation)
+             loss = rgb_loss + \
+                 _delta * torch.sum(torch.pow(_lambda + opacity - _zeta, 2))
+ 
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+ 
+             # update lr with cosine annealing
+             for j in range(len(optimizer.param_groups)):
+                 optimizer.param_groups[j]['lr'] = cosine_annealing(optimizer, i, 2500, start_lr[j], end_lr[j])
+ 
+             pbar.set_postfix({'loss': rgb_loss.item(), 'num': opacity.shape[0], 'lambda': _lambda.mean().item()})
+             pbar.update()
+ 
+     # Replace the optimized Parameters with their detached tensors.
+     new_gs._xyz = new_gs._xyz.data
+     new_gs._rotation = new_gs._rotation.data
+     new_gs._scaling = new_gs._scaling.data
+     new_gs._opacity = new_gs._opacity.data
+ 
+     return new_gs
trellis/utils/random_utils.py ADDED
@@ -0,0 +1,30 @@
+ import numpy as np
+ 
+ # The first 16 primes, used as bases for the Halton sequence.
+ PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
+ 
+ def radical_inverse(base, n):
+     """Return the radical inverse of n: its base-`base` digits mirrored about the radix point."""
+     val = 0
+     inv_base = 1.0 / base
+     inv_base_n = inv_base
+     while n > 0:
+         digit = n % base
+         val += digit * inv_base_n
+         n //= base
+         inv_base_n *= inv_base
+     return val
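+ 
+ # Example: radical_inverse(2, 6) mirrors binary 110 into 0.011, i.e. 0.375.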
+ 
+ def halton_sequence(dim, n):
+     """Return the n-th point of the `dim`-dimensional Halton sequence."""
+     return [radical_inverse(PRIMES[d], n) for d in range(dim)]
+ 
+ def hammersley_sequence(dim, n, num_samples):
+     """Return the n-th of `num_samples` Hammersley points: n/N in the first coordinate, Halton in the rest."""
+     return [n / num_samples] + halton_sequence(dim - 1, n)
+ 
+ def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False):
+     """Map a 2D Hammersley point to (yaw, pitch) angles roughly uniform on the sphere."""
+     u, v = hammersley_sequence(2, n, num_samples)
+     u += offset[0] / num_samples
+     v += offset[1]
+     if remap:
+         u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
+     theta = np.arccos(1 - 2 * u) - np.pi / 2
+     phi = v * 2 * np.pi
+     return [phi, theta]
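+ 
+ # A minimal sketch of how this is used elsewhere in this package: drawing
+ # `nviews` well-spread camera angles for multiview rendering, e.g.
+ #   cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)]
+ #   yaws, pitchs = [c[0] for c in cams], [c[1] for c in cams]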
trellis/utils/render_utils.py ADDED
@@ -0,0 +1,120 @@
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ import utils3d
+ from PIL import Image
+ 
+ from ..renderers import OctreeRenderer, GaussianRenderer, MeshRenderer
+ from ..representations import Octree, Gaussian, MeshExtractResult
+ from ..modules import sparse as sp
+ from .random_utils import sphere_hammersley_sequence
+ 
+ 
+ def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs):
+     """Convert per-view yaw/pitch/radius/fov into camera extrinsics and intrinsics.
+ 
+     If `yaws`/`pitchs` are scalars, a single camera is returned; scalar
+     `rs`/`fovs` are broadcast to the number of views.
+     """
+     is_list = isinstance(yaws, list)
+     if not is_list:
+         yaws = [yaws]
+         pitchs = [pitchs]
+     if not isinstance(rs, list):
+         rs = [rs] * len(yaws)
+     if not isinstance(fovs, list):
+         fovs = [fovs] * len(yaws)
+     extrinsics = []
+     intrinsics = []
+     for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs):
+         fov = torch.deg2rad(torch.tensor(float(fov))).cuda()
+         yaw = torch.tensor(float(yaw)).cuda()
+         pitch = torch.tensor(float(pitch)).cuda()
+         # Camera origin on a sphere of radius r around the object (z-up).
+         orig = torch.tensor([
+             torch.sin(yaw) * torch.cos(pitch),
+             torch.cos(yaw) * torch.cos(pitch),
+             torch.sin(pitch),
+         ]).cuda() * r
+         extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
+         intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
+         extrinsics.append(extr)
+         intrinsics.append(intr)
+     if not is_list:
+         extrinsics = extrinsics[0]
+         intrinsics = intrinsics[0]
+     return extrinsics, intrinsics
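+ 
+ # For example, yaw=0, pitch=0, r=2 places the camera at (0, 2, 0), looking at
+ # the origin with +z as the up direction.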
+ 
+ 
+ def get_renderer(sample, **kwargs):
+     """Build a renderer matching the representation type of `sample`."""
+     if isinstance(sample, Octree):
+         renderer = OctreeRenderer()
+         renderer.rendering_options.resolution = kwargs.get('resolution', 512)
+         renderer.rendering_options.near = kwargs.get('near', 0.8)
+         renderer.rendering_options.far = kwargs.get('far', 1.6)
+         renderer.rendering_options.bg_color = kwargs.get('bg_color', (0, 0, 0))
+         renderer.rendering_options.ssaa = kwargs.get('ssaa', 4)
+         renderer.pipe.primitive = sample.primitive
+     elif isinstance(sample, Gaussian):
+         renderer = GaussianRenderer()
+         renderer.rendering_options.resolution = kwargs.get('resolution', 512)
+         renderer.rendering_options.near = kwargs.get('near', 0.8)
+         renderer.rendering_options.far = kwargs.get('far', 1.6)
+         renderer.rendering_options.bg_color = kwargs.get('bg_color', (0, 0, 0))
+         renderer.rendering_options.ssaa = kwargs.get('ssaa', 1)
+         renderer.pipe.kernel_size = kwargs.get('kernel_size', 0.1)
+         renderer.pipe.use_mip_gaussian = True
+     elif isinstance(sample, MeshExtractResult):
+         renderer = MeshRenderer()
+         renderer.rendering_options.resolution = kwargs.get('resolution', 512)
+         renderer.rendering_options.near = kwargs.get('near', 1)
+         renderer.rendering_options.far = kwargs.get('far', 100)
+         renderer.rendering_options.ssaa = kwargs.get('ssaa', 4)
+     else:
+         raise ValueError(f'Unsupported sample type: {type(sample)}')
+     return renderer
+ 
+ 
+ def render_frames(sample, extrinsics, intrinsics, options=None, colors_overwrite=None, verbose=True, **kwargs):
+     renderer = get_renderer(sample, **(options or {}))
+     rets = {}
+     for extr, intr in tqdm(zip(extrinsics, intrinsics), total=len(extrinsics), desc='Rendering', disable=not verbose):
+         if isinstance(sample, MeshExtractResult):
+             res = renderer.render(sample, extr, intr)
+             if 'normal' not in rets:
+                 rets['normal'] = []
+             rets['normal'].append(np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
+         else:
+             res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite)
+             if 'color' not in rets:
+                 rets['color'] = []
+             if 'depth' not in rets:
+                 rets['depth'] = []
+             rets['color'].append(np.clip(res['color'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
+             if 'percent_depth' in res:
+                 rets['depth'].append(res['percent_depth'].detach().cpu().numpy())
+             elif 'depth' in res:
+                 rets['depth'].append(res['depth'].detach().cpu().numpy())
+             else:
+                 rets['depth'].append(None)
+     return rets
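+ 
+ # render_frames returns a dict of per-frame lists: 'color' (uint8 HxWx3) and
+ # 'depth' for Octree/Gaussian samples, or 'normal' (uint8 HxWx3) for meshes.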
+ 
+ 
+ def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs):
+     # Full turn of yaw with a sinusoidally varying pitch.
+     yaws = torch.linspace(0, 2 * np.pi, num_frames)
+     pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * np.pi, num_frames))
+     yaws = yaws.tolist()
+     pitch = pitch.tolist()
+     extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov)
+     return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
+ 
+ 
+ def render_multiview(sample, resolution=512, nviews=30):
+     # Well-spread viewpoints on the sphere via a Hammersley sequence.
+     r = 2
+     fov = 40
+     cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)]
+     yaws = [cam[0] for cam in cams]
+     pitchs = [cam[1] for cam in cams]
+     extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov)
+     res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)})
+     return res['color'], extrinsics, intrinsics
+ 
+ 
+ def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, **kwargs):
+     # Four views 90 degrees apart in yaw, at a fixed pitch offset.
+     yaw = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
+     yaw_offset = offset[0]
+     yaw = [y + yaw_offset for y in yaw]
+     pitch = [offset[1] for _ in range(4)]
+     extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov)
+     return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
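+ 
+ # A minimal usage sketch (assumes the `imageio` package, which this module
+ # does not itself import):
+ #   frames = render_video(sample, num_frames=120)['color']
+ #   import imageio
+ #   imageio.mimsave('sample.mp4', frames, fps=30)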