Upload 10 files

- trellis/utils/__init__.py +0 -0
- trellis/utils/data_utils.py +226 -0
- trellis/utils/dist_utils.py +93 -0
- trellis/utils/elastic_utils.py +228 -0
- trellis/utils/general_utils.py +202 -0
- trellis/utils/grad_clip_utils.py +81 -0
- trellis/utils/loss_utils.py +92 -0
- trellis/utils/postprocessing_utils.py +587 -0
- trellis/utils/random_utils.py +30 -0
- trellis/utils/render_utils.py +120 -0
trellis/utils/__init__.py
ADDED
File without changes
trellis/utils/data_utils.py
ADDED
@@ -0,0 +1,226 @@
from typing import *
import math
import torch
import numpy as np
from torch.utils.data import Sampler, Dataset, DataLoader, DistributedSampler
import torch.distributed as dist


def recursive_to_device(
    data: Any,
    device: torch.device,
    non_blocking: bool = False,
) -> Any:
    """
    Recursively move all tensors in a data structure to a device.
    """
    if hasattr(data, "to"):
        return data.to(device, non_blocking=non_blocking)
    elif isinstance(data, (list, tuple)):
        return type(data)(recursive_to_device(d, device, non_blocking) for d in data)
    elif isinstance(data, dict):
        return {k: recursive_to_device(v, device, non_blocking) for k, v in data.items()}
    else:
        return data


def load_balanced_group_indices(
    load: List[int],
    num_groups: int,
    equal_size: bool = False,
) -> List[List[int]]:
    """
    Split indices into groups with balanced load.
    """
    if equal_size:
        group_size = len(load) // num_groups
    indices = np.argsort(load)[::-1]
    groups = [[] for _ in range(num_groups)]
    group_load = np.zeros(num_groups)
    for idx in indices:
        min_group_idx = np.argmin(group_load)
        groups[min_group_idx].append(idx)
        if equal_size and len(groups[min_group_idx]) == group_size:
            group_load[min_group_idx] = float('inf')
        else:
            group_load[min_group_idx] += load[idx]
    return groups


def cycle(data_loader: DataLoader) -> Iterator:
    while True:
        for data in data_loader:
            if isinstance(data_loader.sampler, ResumableSampler):
                data_loader.sampler.idx += data_loader.batch_size  # type: ignore[attr-defined]
            yield data
        if isinstance(data_loader.sampler, DistributedSampler):
            data_loader.sampler.epoch += 1
        if isinstance(data_loader.sampler, ResumableSampler):
            data_loader.sampler.epoch += 1
            data_loader.sampler.idx = 0


class ResumableSampler(Sampler):
    """
    Distributed sampler that is resumable.

    Args:
        dataset: Dataset used for sampling.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
            By default, :attr:`rank` is retrieved from the current distributed
            group.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
            indices.
        seed (int, optional): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
        drop_last (bool, optional): if ``True``, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas. Default: ``False``.
    """

    def __init__(
        self,
        dataset: Dataset,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        self.dataset = dataset
        self.epoch = 0
        self.idx = 0
        self.drop_last = drop_last
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        self.rank = dist.get_rank() if dist.is_initialized() else 0
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.world_size != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.world_size) / self.world_size  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.world_size)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.world_size
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self) -> Iterator:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[
                    :padding_size
                ]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.world_size]

        # resume from previous state
        indices = indices[self.idx:]

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def state_dict(self) -> dict[str, int]:
        return {
            'epoch': self.epoch,
            'idx': self.idx,
        }

    def load_state_dict(self, state_dict):
        self.epoch = state_dict['epoch']
        self.idx = state_dict['idx']


class BalancedResumableSampler(ResumableSampler):
    """
    Distributed sampler that is resumable and balances the load among the processes.

    Args:
        dataset: Dataset used for sampling.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
            By default, :attr:`rank` is retrieved from the current distributed
            group.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
            indices.
        seed (int, optional): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
        drop_last (bool, optional): if ``True``, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas. Default: ``False``.
    """

    def __init__(
        self,
        dataset: Dataset,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        batch_size: int = 1,
    ) -> None:
        assert hasattr(dataset, 'loads'), 'Dataset must have "loads" attribute to use BalancedResumableSampler'
        super().__init__(dataset, shuffle, seed, drop_last)
        self.batch_size = batch_size
        self.loads = dataset.loads

    def __iter__(self) -> Iterator:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[
                    :padding_size
                ]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # balance load among processes
        num_batches = len(indices) // (self.batch_size * self.world_size)
        balanced_indices = []
        for i in range(num_batches):
            start_idx = i * self.batch_size * self.world_size
            end_idx = (i + 1) * self.batch_size * self.world_size
            batch_indices = indices[start_idx:end_idx]
            batch_loads = [self.loads[idx] for idx in batch_indices]
            groups = load_balanced_group_indices(batch_loads, self.world_size, equal_size=True)
            balanced_indices.extend([batch_indices[j] for j in groups[self.rank]])

        # resume from previous state
        indices = balanced_indices[self.idx:]

        return iter(indices)
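As a quick orientation for readers of this diff — a minimal usage sketch (the dataset and batch size below are hypothetical; the dataset must expose a per-sample `loads` list for the balanced sampler):

# Hypothetical usage sketch of the sampler/cycle pair above.
from torch.utils.data import DataLoader

sampler = BalancedResumableSampler(dataset, shuffle=True, batch_size=8)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)
train_iter = cycle(loader)           # infinite iterator; advances sampler.idx per batch

state = sampler.state_dict()         # save alongside the training checkpoint
sampler.load_state_dict(state)       # restore to resume mid-epoch on the same world size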
trellis/utils/dist_utils.py
ADDED
@@ -0,0 +1,93 @@
import os
import io
from contextlib import contextmanager
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_dist(rank, local_rank, world_size, master_addr, master_port):
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(local_rank)
    torch.cuda.set_device(local_rank)
    dist.init_process_group('nccl', rank=rank, world_size=world_size)


def read_file_dist(path):
    """
    Read a binary file in a distributed setting.
    The file is read once by the rank 0 process and broadcast to the other processes.

    Returns:
        data (io.BytesIO): The binary data read from the file.
    """
    if dist.is_initialized() and dist.get_world_size() > 1:
        # read file
        size = torch.LongTensor(1).cuda()
        if dist.get_rank() == 0:
            with open(path, 'rb') as f:
                data = f.read()
            data = torch.ByteTensor(
                torch.UntypedStorage.from_buffer(data, dtype=torch.uint8)
            ).cuda()
            size[0] = data.shape[0]
        # broadcast size
        dist.broadcast(size, src=0)
        if dist.get_rank() != 0:
            data = torch.ByteTensor(size[0].item()).cuda()
        # broadcast data
        dist.broadcast(data, src=0)
        # convert to io.BytesIO
        data = data.cpu().numpy().tobytes()
        data = io.BytesIO(data)
        return data
    else:
        with open(path, 'rb') as f:
            data = f.read()
        data = io.BytesIO(data)
        return data


def unwrap_dist(model):
    """
    Unwrap the model from distributed training.
    """
    if isinstance(model, DDP):
        return model.module
    return model


@contextmanager
def master_first():
    """
    A context manager that ensures the master process executes first.
    """
    if not dist.is_initialized():
        yield
    else:
        if dist.get_rank() == 0:
            yield
            dist.barrier()
        else:
            dist.barrier()
            yield


@contextmanager
def local_master_first():
    """
    A context manager that ensures the local master process executes first.
    """
    if not dist.is_initialized():
        yield
    else:
        if dist.get_rank() % torch.cuda.device_count() == 0:
            yield
            dist.barrier()
        else:
            dist.barrier()
            yield
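A minimal sketch of how these helpers compose (assumes setup_dist has already initialized the process group; build_cache is a hypothetical helper, not part of this file):

import torch.distributed as dist

with master_first():                        # rank 0 runs the body before others pass the barrier
    if dist.get_rank() == 0:
        build_cache('/tmp/cache.bin')       # hypothetical: materialize a shared file once
buffer = read_file_dist('/tmp/cache.bin')   # rank 0 reads from disk; others receive the broadcast
model = unwrap_dist(model)                  # plain nn.Module whether or not it was DDP-wrapped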
trellis/utils/elastic_utils.py
ADDED
@@ -0,0 +1,228 @@
from abc import abstractmethod
from contextlib import contextmanager
from typing import Tuple
import torch
import torch.nn as nn
import numpy as np


class MemoryController:
    """
    Base class for memory management during training.
    """

    _last_input_size = None
    _last_mem_ratio = []

    @contextmanager
    def record(self):
        pass

    def update_run_states(self, input_size=None, mem_ratio=None):
        if self._last_input_size is None:
            self._last_input_size = input_size
        elif self._last_input_size != input_size:
            raise ValueError('Input size should not change for different ElasticModules.')
        self._last_mem_ratio.append(mem_ratio)

    @abstractmethod
    def get_mem_ratio(self, input_size):
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def log(self):
        pass


class LinearMemoryController(MemoryController):
    """
    A simple controller for memory management during training.
    Memory usage is modeled as a linear function of:
        - the input size (number of input parameters)
        - the ratio of memory the model uses compared to the maximum usage (with no checkpointing):
            memory_usage = k * input_size * mem_ratio + b
    The controller tracks memory usage and returns the expected memory
    ratio that keeps usage under the target.
    """
    def __init__(
        self,
        buffer_size=1000,
        update_every=500,
        target_ratio=0.8,
        available_memory=None,
        max_mem_ratio_start=0.1,
        params=None,
        device=None
    ):
        self.buffer_size = buffer_size
        self.update_every = update_every
        self.target_ratio = target_ratio
        self.device = device or torch.cuda.current_device()
        self.available_memory = available_memory or torch.cuda.get_device_properties(self.device).total_memory / 1024**3

        self._memory = np.zeros(buffer_size, dtype=np.float32)
        self._input_size = np.zeros(buffer_size, dtype=np.float32)
        self._mem_ratio = np.zeros(buffer_size, dtype=np.float32)
        self._buffer_ptr = 0
        self._buffer_length = 0
        self._params = tuple(params) if params is not None else (0.0, 0.0)
        self._max_mem_ratio = max_mem_ratio_start
        self.step = 0

    def __repr__(self):
        return f'LinearMemoryController(target_ratio={self.target_ratio}, available_memory={self.available_memory})'

    def _add_sample(self, memory, input_size, mem_ratio):
        self._memory[self._buffer_ptr] = memory
        self._input_size[self._buffer_ptr] = input_size
        self._mem_ratio[self._buffer_ptr] = mem_ratio
        self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
        self._buffer_length = min(self._buffer_length + 1, self.buffer_size)

    @contextmanager
    def record(self):
        torch.cuda.reset_peak_memory_stats(self.device)
        self._last_input_size = None
        self._last_mem_ratio = []
        yield
        self._last_memory = torch.cuda.max_memory_allocated(self.device) / 1024**3
        self._last_mem_ratio = sum(self._last_mem_ratio) / len(self._last_mem_ratio)
        self._add_sample(self._last_memory, self._last_input_size, self._last_mem_ratio)
        self.step += 1
        if self.step % self.update_every == 0:
            self._max_mem_ratio = min(1.0, self._max_mem_ratio + 0.1)
            self._fit_params()

    def _fit_params(self):
        memory_usage = self._memory[:self._buffer_length]
        input_size = self._input_size[:self._buffer_length]
        mem_ratio = self._mem_ratio[:self._buffer_length]

        x = input_size * mem_ratio
        y = memory_usage
        k, b = np.polyfit(x, y, 1)
        self._params = (k, b)
        # self._visualize()

    def _visualize(self):
        import matplotlib.pyplot as plt
        memory_usage = self._memory[:self._buffer_length]
        input_size = self._input_size[:self._buffer_length]
        mem_ratio = self._mem_ratio[:self._buffer_length]
        k, b = self._params

        plt.scatter(input_size * mem_ratio, memory_usage, c=mem_ratio, cmap='viridis')
        x = np.array([0.0, 20000.0])
        plt.plot(x, k * x + b, c='r')
        plt.savefig(f'linear_memory_controller_{self.step}.png')
        plt.cla()

    def get_mem_ratio(self, input_size):
        k, b = self._params
        if k == 0:
            return np.random.rand() * self._max_mem_ratio
        pred = (self.available_memory * self.target_ratio - b) / (k * input_size)
        return min(self._max_mem_ratio, max(0.0, pred))

    def state_dict(self):
        return {
            'params': self._params,
        }

    def load_state_dict(self, state_dict):
        self._params = tuple(state_dict['params'])

    def log(self):
        return {
            'params/k': self._params[0],
            'params/b': self._params[1],
            'memory': self._last_memory,
            'input_size': self._last_input_size,
            'mem_ratio': self._last_mem_ratio,
        }


class ElasticModule(nn.Module):
    """
    Module for training with elastic memory management.
    """
    def __init__(self):
        super().__init__()
        self._memory_controller: MemoryController = None

    @abstractmethod
    def _get_input_size(self, *args, **kwargs) -> int:
        """
        Get the size of the input data.

        Returns:
            int: The size of the input data.
        """
        pass

    @abstractmethod
    def _forward_with_mem_ratio(self, *args, mem_ratio=0.0, **kwargs) -> Tuple[float, Tuple]:
        """
        Forward with a given memory ratio.
        """
        pass

    def register_memory_controller(self, memory_controller: MemoryController):
        self._memory_controller = memory_controller

    def forward(self, *args, **kwargs):
        if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
            _, ret = self._forward_with_mem_ratio(*args, **kwargs)
        else:
            input_size = self._get_input_size(*args, **kwargs)
            mem_ratio = self._memory_controller.get_mem_ratio(input_size)
            mem_ratio, ret = self._forward_with_mem_ratio(*args, mem_ratio=mem_ratio, **kwargs)
            self._memory_controller.update_run_states(input_size, mem_ratio)
        return ret


class ElasticModuleMixin:
    """
    Mixin for training with elastic memory management.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._memory_controller: MemoryController = None

    @abstractmethod
    def _get_input_size(self, *args, **kwargs) -> int:
        """
        Get the size of the input data.

        Returns:
            int: The size of the input data.
        """
        pass

    @abstractmethod
    @contextmanager
    def with_mem_ratio(self, mem_ratio=1.0) -> float:
        """
        Context manager for training with a reduced memory ratio compared to the full memory usage.

        Returns:
            float: The exact memory ratio used during the forward pass.
        """
        pass

    def register_memory_controller(self, memory_controller: MemoryController):
        self._memory_controller = memory_controller

    def forward(self, *args, **kwargs):
        if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
            ret = super().forward(*args, **kwargs)
        else:
            input_size = self._get_input_size(*args, **kwargs)
            mem_ratio = self._memory_controller.get_mem_ratio(input_size)
            with self.with_mem_ratio(mem_ratio) as exact_mem_ratio:
                ret = super().forward(*args, **kwargs)
            self._memory_controller.update_run_states(input_size, exact_mem_ratio)
        return ret
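To make the control loop concrete — a sketch, assuming a hypothetical ElasticModule subclass MyElasticBlock whose _forward_with_mem_ratio dials gradient checkpointing up or down, and an assumed training iterator:

controller = LinearMemoryController(target_ratio=0.8)
block = MyElasticBlock().cuda()               # hypothetical ElasticModule subclass
block.register_memory_controller(controller)

for batch in train_iter:                      # assumed training iterator
    with controller.record():                 # measures peak CUDA memory for this step
        loss = block(batch).mean()            # forward() picks mem_ratio via the controller
    loss.backward()

Every update_every steps the controller refits (k, b) from the recorded (input_size * mem_ratio, peak memory) pairs and raises its mem_ratio ceiling.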
trellis/utils/general_utils.py
ADDED
@@ -0,0 +1,202 @@
import re
import numpy as np
import cv2
import torch
import contextlib


# Dictionary utils
def _dict_merge(dicta, dictb, prefix=''):
    """
    Merge two dictionaries.
    """
    assert isinstance(dicta, dict), 'input must be a dictionary'
    assert isinstance(dictb, dict), 'input must be a dictionary'
    dict_ = {}
    all_keys = set(dicta.keys()).union(set(dictb.keys()))
    for key in all_keys:
        if key in dicta.keys() and key in dictb.keys():
            if isinstance(dicta[key], dict) and isinstance(dictb[key], dict):
                dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}')
            else:
                raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}')
        elif key in dicta.keys():
            dict_[key] = dicta[key]
        else:
            dict_[key] = dictb[key]
    return dict_


def dict_merge(dicta, dictb):
    """
    Merge two dictionaries.
    """
    return _dict_merge(dicta, dictb, prefix='')


def dict_foreach(dic, func, special_func={}):
    """
    Recursively apply a function to all non-dictionary leaf values in a dictionary.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    for key in dic.keys():
        if isinstance(dic[key], dict):
            dic[key] = dict_foreach(dic[key], func)
        else:
            if key in special_func.keys():
                dic[key] = special_func[key](dic[key])
            else:
                dic[key] = func(dic[key])
    return dic


def dict_reduce(dicts, func, special_func={}):
    """
    Reduce a list of dictionaries. Leaf values must be scalars.
    """
    assert isinstance(dicts, list), 'input must be a list of dictionaries'
    assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries'
    assert len(dicts) > 0, 'input must be a non-empty list of dictionaries'
    all_keys = set([key for dict_ in dicts for key in dict_.keys()])
    reduced_dict = {}
    for key in all_keys:
        vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()]
        if isinstance(vlist[0], dict):
            reduced_dict[key] = dict_reduce(vlist, func, special_func)
        else:
            if key in special_func.keys():
                reduced_dict[key] = special_func[key](vlist)
            else:
                reduced_dict[key] = func(vlist)
    return reduced_dict


def dict_any(dic, func):
    """
    Check whether any non-dictionary leaf value in a dictionary satisfies a predicate.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    for key in dic.keys():
        if isinstance(dic[key], dict):
            if dict_any(dic[key], func):
                return True
        else:
            if func(dic[key]):
                return True
    return False


def dict_all(dic, func):
    """
    Check whether all non-dictionary leaf values in a dictionary satisfy a predicate.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    for key in dic.keys():
        if isinstance(dic[key], dict):
            if not dict_all(dic[key], func):
                return False
        else:
            if not func(dic[key]):
                return False
    return True


def dict_flatten(dic, sep='.'):
    """
    Flatten a nested dictionary into a dictionary with no nested dictionaries.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    flat_dict = {}
    for key in dic.keys():
        if isinstance(dic[key], dict):
            sub_dict = dict_flatten(dic[key], sep=sep)
            for sub_key in sub_dict.keys():
                flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key]
        else:
            flat_dict[key] = dic[key]
    return flat_dict


# Context utils
@contextlib.contextmanager
def nested_contexts(*contexts):
    with contextlib.ExitStack() as stack:
        for ctx in contexts:
            stack.enter_context(ctx())
        yield


# Image utils
def make_grid(images, nrow=None, ncol=None, aspect_ratio=None):
    num_images = len(images)
    if nrow is None and ncol is None:
        if aspect_ratio is not None:
            nrow = int(np.round(np.sqrt(num_images / aspect_ratio)))
        else:
            nrow = int(np.sqrt(num_images))
        ncol = (num_images + nrow - 1) // nrow
    elif nrow is None and ncol is not None:
        nrow = (num_images + ncol - 1) // ncol
    elif nrow is not None and ncol is None:
        ncol = (num_images + nrow - 1) // nrow
    else:
        assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images'

    if images[0].ndim == 2:
        grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype)
    else:
        grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype)
    for i, img in enumerate(images):
        row = i // ncol
        col = i % ncol
        grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img
    return grid


def notes_on_image(img, notes=None):
    img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    if notes is not None:
        img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img


def save_image_with_notes(img, path, notes=None):
    """
    Save an image with notes.
    """
    if isinstance(img, torch.Tensor):
        img = img.cpu().numpy().transpose(1, 2, 0)
    if img.dtype == np.float32 or img.dtype == np.float64:
        img = np.clip(img * 255, 0, 255).astype(np.uint8)
    img = notes_on_image(img, notes)
    cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))


# debug utils

def atol(x, y):
    """
    Absolute tolerance.
    """
    return torch.abs(x - y)


def rtol(x, y):
    """
    Relative tolerance.
    """
    return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12)


# print utils
def indent(s, n=4):
    """
    Indent a string.
    """
    lines = s.split('\n')
    for i in range(1, len(lines)):
        lines[i] = ' ' * n + lines[i]
    return '\n'.join(lines)
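For instance — a small self-contained demonstration of how the dictionary helpers compose (the log values here are invented):

logs = [{'loss': {'l1': 1.0, 'ssim': 0.5}},
        {'loss': {'l1': 3.0, 'ssim': 0.7}}]
mean_logs = dict_reduce(logs, lambda v: sum(v) / len(v))  # {'loss': {'l1': 2.0, 'ssim': 0.6}}
flat = dict_flatten(mean_logs)                            # {'loss.l1': 2.0, 'loss.ssim': 0.6}
assert dict_all(mean_logs, lambda v: v >= 0)              # every leaf satisfies the predicate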
trellis/utils/grad_clip_utils.py
ADDED
@@ -0,0 +1,81 @@
from typing import *
import torch
import numpy as np
import torch.utils


class AdaptiveGradClipper:
    """
    Adaptive gradient clipping for training.
    """
    def __init__(
        self,
        max_norm=None,
        clip_percentile=95.0,
        buffer_size=1000,
    ):
        self.max_norm = max_norm
        self.clip_percentile = clip_percentile
        self.buffer_size = buffer_size

        self._grad_norm = np.zeros(buffer_size, dtype=np.float32)
        self._max_norm = max_norm
        self._buffer_ptr = 0
        self._buffer_length = 0

    def __repr__(self):
        return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})'

    def state_dict(self):
        return {
            'grad_norm': self._grad_norm,
            'max_norm': self._max_norm,
            'buffer_ptr': self._buffer_ptr,
            'buffer_length': self._buffer_length,
        }

    def load_state_dict(self, state_dict):
        self._grad_norm = state_dict['grad_norm']
        self._max_norm = state_dict['max_norm']
        self._buffer_ptr = state_dict['buffer_ptr']
        self._buffer_length = state_dict['buffer_length']

    def log(self):
        return {
            'max_norm': self._max_norm,
        }

    def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None):
        """Clip the gradient norm of an iterable of parameters.

        The norm is computed over all gradients together, as if they were
        concatenated into a single vector. Gradients are modified in-place.

        Args:
            parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
                single Tensor that will have gradients normalized
            norm_type (float): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.
            error_if_nonfinite (bool): if True, an error is thrown if the total
                norm of the gradients from :attr:`parameters` is ``nan``,
                ``inf``, or ``-inf``. Default: False (will switch to True in the future)
            foreach (bool): use the faster foreach-based implementation.
                If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
                fall back to the slow implementation for other device types.
                Default: ``None``

        Returns:
            Total norm of the parameter gradients (viewed as a single vector).
        """
        max_norm = self._max_norm if self._max_norm is not None else float('inf')
        grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach)

        if torch.isfinite(grad_norm):
            self._grad_norm[self._buffer_ptr] = grad_norm
            self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
            self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
            if self._buffer_length == self.buffer_size:
                self._max_norm = np.percentile(self._grad_norm, self.clip_percentile)
                self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm

        return grad_norm
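In a training step the clipper is a drop-in replacement for torch.nn.utils.clip_grad_norm_ — a minimal sketch, assuming model, optimizer, and loss come from the surrounding loop:

clipper = AdaptiveGradClipper(clip_percentile=95.0)

loss.backward()
grad_norm = clipper(model.parameters())  # no clipping until the norm buffer fills,
optimizer.step()                         # then clips to the running 95th-percentile norm
optimizer.zero_grad()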
trellis/utils/loss_utils.py
ADDED
@@ -0,0 +1,92 @@
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp
from lpips import LPIPS


def smooth_l1_loss(pred, target, beta=1.0):
    diff = torch.abs(pred - target)
    loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
    return loss.mean()


def l1_loss(network_output, gt):
    return torch.abs((network_output - gt)).mean()


def l2_loss(network_output, gt):
    return ((network_output - gt) ** 2).mean()


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def psnr(img1, img2, max_val=1.0):
    mse = F.mse_loss(img1, img2)
    return 20 * torch.log10(max_val / torch.sqrt(mse))


def ssim(img1, img2, window_size=11, size_average=True):
    channel = img1.size(-3)
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


loss_fn_vgg = None
def lpips(img1, img2, value_range=(0, 1)):
    global loss_fn_vgg
    if loss_fn_vgg is None:
        loss_fn_vgg = LPIPS(net='vgg').cuda().eval()
    # normalize to [-1, 1]
    img1 = (img1 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1
    img2 = (img2 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1
    return loss_fn_vgg(img1, img2).mean()


def normal_angle(pred, gt):
    pred = pred * 2.0 - 1.0
    gt = gt * 2.0 - 1.0
    norms = pred.norm(dim=-1) * gt.norm(dim=-1)
    cos_sim = (pred * gt).sum(-1) / (norms + 1e-9)
    cos_sim = torch.clamp(cos_sim, -1.0, 1.0)
    ang = torch.rad2deg(torch.acos(cos_sim[norms > 1e-9])).mean()
    if ang.isnan():
        return -1
    return ang
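A quick sanity check of the image metrics — a sketch that assumes a CUDA device and the lpips package (the first lpips call lazily loads the VGG network):

img1 = torch.rand(1, 3, 64, 64).cuda()                     # random NCHW image in [0, 1]
img2 = (img1 + 0.01 * torch.randn_like(img1)).clamp(0, 1)  # lightly perturbed copy

print(psnr(img1, img2))   # high PSNR for a near-identical image
print(ssim(img1, img2))   # close to 1.0
print(lpips(img1, img2))  # close to 0.0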
trellis/utils/postprocessing_utils.py
ADDED
@@ -0,0 +1,587 @@
1 |
+
from typing import *
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import utils3d
|
5 |
+
import nvdiffrast.torch as dr
|
6 |
+
from tqdm import tqdm
|
7 |
+
import trimesh
|
8 |
+
import trimesh.visual
|
9 |
+
import xatlas
|
10 |
+
import pyvista as pv
|
11 |
+
from pymeshfix import _meshfix
|
12 |
+
import igraph
|
13 |
+
import cv2
|
14 |
+
from PIL import Image
|
15 |
+
from .random_utils import sphere_hammersley_sequence
|
16 |
+
from .render_utils import render_multiview
|
17 |
+
from ..renderers import GaussianRenderer
|
18 |
+
from ..representations import Strivec, Gaussian, MeshExtractResult
|
19 |
+
|
20 |
+
|
21 |
+
@torch.no_grad()
|
22 |
+
def _fill_holes(
|
23 |
+
verts,
|
24 |
+
faces,
|
25 |
+
max_hole_size=0.04,
|
26 |
+
max_hole_nbe=32,
|
27 |
+
resolution=128,
|
28 |
+
num_views=500,
|
29 |
+
debug=False,
|
30 |
+
verbose=False
|
31 |
+
):
|
32 |
+
"""
|
33 |
+
Rasterize a mesh from multiple views and remove invisible faces.
|
34 |
+
Also includes postprocessing to:
|
35 |
+
1. Remove connected components that are have low visibility.
|
36 |
+
2. Mincut to remove faces at the inner side of the mesh connected to the outer side with a small hole.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
verts (torch.Tensor): Vertices of the mesh. Shape (V, 3).
|
40 |
+
faces (torch.Tensor): Faces of the mesh. Shape (F, 3).
|
41 |
+
max_hole_size (float): Maximum area of a hole to fill.
|
42 |
+
resolution (int): Resolution of the rasterization.
|
43 |
+
num_views (int): Number of views to rasterize the mesh.
|
44 |
+
verbose (bool): Whether to print progress.
|
45 |
+
"""
|
46 |
+
# Construct cameras
|
47 |
+
yaws = []
|
48 |
+
pitchs = []
|
49 |
+
for i in range(num_views):
|
50 |
+
y, p = sphere_hammersley_sequence(i, num_views)
|
51 |
+
yaws.append(y)
|
52 |
+
pitchs.append(p)
|
53 |
+
yaws = torch.tensor(yaws).cuda()
|
54 |
+
pitchs = torch.tensor(pitchs).cuda()
|
55 |
+
radius = 2.0
|
56 |
+
fov = torch.deg2rad(torch.tensor(40)).cuda()
|
57 |
+
projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
|
58 |
+
views = []
|
59 |
+
for (yaw, pitch) in zip(yaws, pitchs):
|
60 |
+
orig = torch.tensor([
|
61 |
+
torch.sin(yaw) * torch.cos(pitch),
|
62 |
+
torch.cos(yaw) * torch.cos(pitch),
|
63 |
+
torch.sin(pitch),
|
64 |
+
]).cuda().float() * radius
|
65 |
+
view = utils3d.torch.view_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
66 |
+
views.append(view)
|
67 |
+
views = torch.stack(views, dim=0)
|
68 |
+
|
69 |
+
# Rasterize
|
70 |
+
visblity = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device)
|
71 |
+
rastctx = utils3d.torch.RastContext(backend='cuda')
|
72 |
+
for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'):
|
73 |
+
view = views[i]
|
74 |
+
buffers = utils3d.torch.rasterize_triangle_faces(
|
75 |
+
rastctx, verts[None], faces, resolution, resolution, view=view, projection=projection
|
76 |
+
)
|
77 |
+
face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1
|
78 |
+
face_id = torch.unique(face_id).long()
|
79 |
+
visblity[face_id] += 1
|
80 |
+
visblity = visblity.float() / num_views
|
81 |
+
|
82 |
+
# Mincut
|
83 |
+
## construct outer faces
|
84 |
+
edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
|
85 |
+
boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
|
86 |
+
connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge)
|
87 |
+
outer_face_indices = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device)
|
88 |
+
for i in range(len(connected_components)):
|
89 |
+
outer_face_indices[connected_components[i]] = visblity[connected_components[i]] > min(max(visblity[connected_components[i]].quantile(0.75).item(), 0.25), 0.5)
|
90 |
+
outer_face_indices = outer_face_indices.nonzero().reshape(-1)
|
91 |
+
|
92 |
+
## construct inner faces
|
93 |
+
inner_face_indices = torch.nonzero(visblity == 0).reshape(-1)
|
94 |
+
if verbose:
|
95 |
+
tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces')
|
96 |
+
if inner_face_indices.shape[0] == 0:
|
97 |
+
return verts, faces
|
98 |
+
|
99 |
+
## Construct dual graph (faces as nodes, edges as edges)
|
100 |
+
dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge)
|
101 |
+
dual_edge2edge = edges[dual_edge2edge]
|
102 |
+
dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1)
|
103 |
+
if verbose:
|
104 |
+
tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges')
|
105 |
+
|
106 |
+
## solve mincut problem
|
107 |
+
### construct main graph
|
108 |
+
g = igraph.Graph()
|
109 |
+
g.add_vertices(faces.shape[0])
|
110 |
+
g.add_edges(dual_edges.cpu().numpy())
|
111 |
+
g.es['weight'] = dual_edges_weights.cpu().numpy()
|
112 |
+
|
113 |
+
### source and target
|
114 |
+
g.add_vertex('s')
|
115 |
+
g.add_vertex('t')
|
116 |
+
|
117 |
+
### connect invisible faces to source
|
118 |
+
g.add_edges([(f, 's') for f in inner_face_indices], attributes={'weight': torch.ones(inner_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})
|
119 |
+
|
120 |
+
### connect outer faces to target
|
121 |
+
g.add_edges([(f, 't') for f in outer_face_indices], attributes={'weight': torch.ones(outer_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})
|
122 |
+
|
123 |
+
### solve mincut
|
124 |
+
cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist())
|
125 |
+
remove_face_indices = torch.tensor([v for v in cut.partition[0] if v < faces.shape[0]], dtype=torch.long, device=faces.device)
|
126 |
+
if verbose:
|
127 |
+
tqdm.write(f'Mincut solved, start checking the cut')
|
128 |
+
|
129 |
+
### check if the cut is valid with each connected component
|
130 |
+
to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices])
|
131 |
+
if debug:
|
132 |
+
tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}')
|
133 |
+
valid_remove_cc = []
|
134 |
+
cutting_edges = []
|
135 |
+
for cc in to_remove_cc:
|
136 |
+
#### check if the connected component has low visibility
|
137 |
+
visblity_median = visblity[remove_face_indices[cc]].median()
|
138 |
+
if debug:
|
139 |
+
tqdm.write(f'visblity_median: {visblity_median}')
|
140 |
+
if visblity_median > 0.25:
|
141 |
+
continue
|
142 |
+
|
143 |
+
#### check if the cuting loop is small enough
|
144 |
+
cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True)
|
145 |
+
cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
|
146 |
+
cc_new_boundary_edge_indices = cc_boundary_edge_indices[~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)]
|
147 |
+
if len(cc_new_boundary_edge_indices) > 0:
|
148 |
+
cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(edges[cc_new_boundary_edge_indices])
|
149 |
+
cc_new_boundary_edges_cc_center = [verts[edges[cc_new_boundary_edge_indices[edge_cc]]].mean(dim=1).mean(dim=0) for edge_cc in cc_new_boundary_edge_cc]
|
150 |
+
cc_new_boundary_edges_cc_area = []
|
151 |
+
for i, edge_cc in enumerate(cc_new_boundary_edge_cc):
|
152 |
+
_e1 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] - cc_new_boundary_edges_cc_center[i]
|
153 |
+
_e2 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] - cc_new_boundary_edges_cc_center[i]
|
154 |
+
cc_new_boundary_edges_cc_area.append(torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5)
|
155 |
+
if debug:
|
156 |
+
cutting_edges.append(cc_new_boundary_edge_indices)
|
157 |
+
tqdm.write(f'Area of the cutting loop: {cc_new_boundary_edges_cc_area}')
|
158 |
+
if any([l > max_hole_size for l in cc_new_boundary_edges_cc_area]):
|
159 |
+
continue
|
160 |
+
|
161 |
+
valid_remove_cc.append(cc)
|
162 |
+
|
163 |
+
if debug:
|
164 |
+
face_v = verts[faces].mean(dim=1).cpu().numpy()
|
165 |
+
vis_dual_edges = dual_edges.cpu().numpy()
|
166 |
+
vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8)
|
167 |
+
vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255]
|
168 |
+
vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0]
|
169 |
+
vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255]
|
170 |
+
if len(valid_remove_cc) > 0:
|
171 |
+
vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0]
|
172 |
+
utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors)
|
173 |
+
|
174 |
+
vis_verts = verts.cpu().numpy()
|
175 |
+
vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy()
|
176 |
+
utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges)
|
177 |
+
|
178 |
+
|
179 |
+
if len(valid_remove_cc) > 0:
|
180 |
+
remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)]
|
181 |
+
mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device)
|
182 |
+
mask[remove_face_indices] = 0
|
183 |
+
faces = faces[mask]
|
184 |
+
faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts)
|
185 |
+
if verbose:
|
186 |
+
tqdm.write(f'Removed {(~mask).sum()} faces by mincut')
|
187 |
+
else:
|
188 |
+
if verbose:
|
189 |
+
tqdm.write(f'Removed 0 faces by mincut')
|
190 |
+
|
191 |
+
mesh = _meshfix.PyTMesh()
|
192 |
+
mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy())
|
193 |
+
mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
|
194 |
+
verts, faces = mesh.return_arrays()
|
195 |
+
verts, faces = torch.tensor(verts, device='cuda', dtype=torch.float32), torch.tensor(faces, device='cuda', dtype=torch.int32)
|
196 |
+
|
197 |
+
return verts, faces
|
198 |
+
|
199 |
+
|
200 |
+
def postprocess_mesh(
|
201 |
+
vertices: np.array,
|
202 |
+
faces: np.array,
|
203 |
+
simplify: bool = True,
|
204 |
+
simplify_ratio: float = 0.9,
|
205 |
+
fill_holes: bool = True,
|
206 |
+
fill_holes_max_hole_size: float = 0.04,
|
207 |
+
fill_holes_max_hole_nbe: int = 32,
|
208 |
+
fill_holes_resolution: int = 1024,
|
209 |
+
fill_holes_num_views: int = 1000,
|
210 |
+
debug: bool = False,
|
211 |
+
verbose: bool = False,
|
212 |
+
):
|
213 |
+
"""
|
214 |
+
Postprocess a mesh by simplifying, removing invisible faces, and removing isolated pieces.
|
215 |
+
|
216 |
+
Args:
|
217 |
+
vertices (np.array): Vertices of the mesh. Shape (V, 3).
|
218 |
+
faces (np.array): Faces of the mesh. Shape (F, 3).
|
219 |
+
simplify (bool): Whether to simplify the mesh, using quadric edge collapse.
|
220 |
+
simplify_ratio (float): Ratio of faces to keep after simplification.
|
221 |
+
fill_holes (bool): Whether to fill holes in the mesh.
|
222 |
+
fill_holes_max_hole_size (float): Maximum area of a hole to fill.
|
223 |
+
fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill.
|
224 |
+
fill_holes_resolution (int): Resolution of the rasterization.
|
225 |
+
fill_holes_num_views (int): Number of views to rasterize the mesh.
|
226 |
+
verbose (bool): Whether to print progress.
|
227 |
+
"""
|
228 |
+
|
229 |
+
if verbose:
|
230 |
+
tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
|
231 |
+
|
232 |
+
# Simplify
|
233 |
+
if simplify and simplify_ratio > 0:
|
234 |
+
mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1))
|
235 |
+
mesh = mesh.decimate(simplify_ratio, progress_bar=verbose)
|
236 |
+
vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:]
|
237 |
+
if verbose:
|
238 |
+
tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
|
239 |
+
|
240 |
+
# Remove invisible faces
|
241 |
+
if fill_holes:
|
242 |
+
vertices, faces = torch.tensor(vertices).cuda(), torch.tensor(faces.astype(np.int32)).cuda()
|
243 |
+
vertices, faces = _fill_holes(
|
244 |
+
vertices, faces,
|
245 |
+
max_hole_size=fill_holes_max_hole_size,
|
246 |
+
max_hole_nbe=fill_holes_max_hole_nbe,
|
247 |
+
resolution=fill_holes_resolution,
|
248 |
+
num_views=fill_holes_num_views,
|
249 |
+
debug=debug,
|
250 |
+
verbose=verbose,
|
251 |
+
)
|
252 |
+
vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy()
|
253 |
+
if verbose:
|
254 |
+
tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
|
255 |
+
|
256 |
+
return vertices, faces
|
257 |
+
|
258 |
+
|
259 |
+
def parametrize_mesh(vertices: np.array, faces: np.array):
|
260 |
+
"""
|
261 |
+
Parametrize a mesh to a texture space, using xatlas.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
vertices (np.array): Vertices of the mesh. Shape (V, 3).
|
265 |
+
faces (np.array): Faces of the mesh. Shape (F, 3).
|
266 |
+
"""
|
267 |
+
|
268 |
+
vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
|
269 |
+
|
270 |
+
vertices = vertices[vmapping]
|
271 |
+
faces = indices
|
272 |
+
|
273 |
+
return vertices, faces, uvs
|
274 |
+
|
275 |
+
|
276 |
+
def bake_texture(
    vertices: np.ndarray,
    faces: np.ndarray,
    uvs: np.ndarray,
    observations: List[np.ndarray],
    masks: List[np.ndarray],
    extrinsics: List[np.ndarray],
    intrinsics: List[np.ndarray],
    texture_size: int = 2048,
    near: float = 0.1,
    far: float = 10.0,
    mode: Literal['fast', 'opt'] = 'opt',
    lambda_tv: float = 1e-2,
    verbose: bool = False,
):
    """
    Bake texture to a mesh from multiple observations.

    Args:
        vertices (np.ndarray): Vertices of the mesh. Shape (V, 3).
        faces (np.ndarray): Faces of the mesh. Shape (F, 3).
        uvs (np.ndarray): UV coordinates of the mesh. Shape (V, 2).
        observations (List[np.ndarray]): List of observations. Each observation is a 2D image. Shape (H, W, 3).
        masks (List[np.ndarray]): List of masks. Each mask is a 2D image. Shape (H, W).
        extrinsics (List[np.ndarray]): List of extrinsics. Shape (4, 4).
        intrinsics (List[np.ndarray]): List of intrinsics. Shape (3, 3).
        texture_size (int): Size of the texture.
        near (float): Near plane of the camera.
        far (float): Far plane of the camera.
        mode (Literal['fast', 'opt']): Mode of texture baking. 'fast' averages
            projected colors per texel; 'opt' optimizes the texture against the observations.
        lambda_tv (float): Weight of total variation loss in optimization.
        verbose (bool): Whether to print progress.
    """
    vertices = torch.tensor(vertices).cuda()
    faces = torch.tensor(faces.astype(np.int32)).cuda()
    uvs = torch.tensor(uvs).cuda()
    observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations]
    masks = [torch.tensor(m > 0).bool().cuda() for m in masks]
    views = [utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda()) for extr in extrinsics]
    projections = [utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far) for intr in intrinsics]

    if mode == 'fast':
        texture = torch.zeros((texture_size * texture_size, 3), dtype=torch.float32).cuda()
        texture_weights = torch.zeros((texture_size * texture_size), dtype=torch.float32).cuda()
        rastctx = utils3d.torch.RastContext(backend='cuda')
        for observation, obs_mask, view, projection in tqdm(zip(observations, masks, views, projections), total=len(observations), disable=not verbose, desc='Texture baking (fast)'):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
                )
                uv_map = rast['uv'][0].detach().flip(0)
                # NOTE: fixed from `masks[0]` so each view is clipped by its own mask
                mask = rast['mask'][0].detach().bool() & obs_mask

            # nearest neighbor interpolation: accumulate colors and hit counts per texel
            uv_map = (uv_map * texture_size).floor().long()
            obs = observation[mask]
            uv_map = uv_map[mask]
            idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
            texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs)
            texture_weights = texture_weights.scatter_add(0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device))

        mask = texture_weights > 0
        texture[mask] /= texture_weights[mask][:, None]
        texture = np.clip(texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255).astype(np.uint8)

        # inpaint texels that received no observations
        mask = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size)
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)

    elif mode == 'opt':
        rastctx = utils3d.torch.RastContext(backend='cuda')
        observations = [obs.flip(0) for obs in observations]
        masks = [m.flip(0) for m in masks]
        _uv = []
        _uv_dr = []
        for observation, view, projection in tqdm(zip(observations, views, projections), total=len(views), disable=not verbose, desc='Texture baking (opt): UV'):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
                )
                _uv.append(rast['uv'].detach())
                _uv_dr.append(rast['uv_dr'].detach())

        texture = torch.nn.Parameter(torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda())
        optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)

        def exp_annealing(optimizer, step, total_steps, start_lr, end_lr):
            # defined but unused below; cosine_annealing is the active schedule
            return start_lr * (end_lr / start_lr) ** (step / total_steps)

        def cosine_annealing(optimizer, step, total_steps, start_lr, end_lr):
            return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))

        def tv_loss(texture):
            return torch.nn.functional.l1_loss(texture[:, :-1, :, :], texture[:, 1:, :, :]) + \
                   torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :])

        total_steps = 2500
        with tqdm(total=total_steps, disable=not verbose, desc='Texture baking (opt): optimizing') as pbar:
            for step in range(total_steps):
                optimizer.zero_grad()
                selected = np.random.randint(0, len(views))
                uv, uv_dr, observation, mask = _uv[selected], _uv_dr[selected], observations[selected], masks[selected]
                render = dr.texture(texture, uv, uv_dr)[0]
                loss = torch.nn.functional.l1_loss(render[mask], observation[mask])
                if lambda_tv > 0:
                    loss += lambda_tv * tv_loss(texture)
                loss.backward()
                optimizer.step()
                # annealing
                optimizer.param_groups[0]['lr'] = cosine_annealing(optimizer, step, total_steps, 1e-2, 1e-5)
                pbar.set_postfix({'loss': loss.item()})
                pbar.update()
        texture = np.clip(texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)
        mask = 1 - utils3d.torch.rasterize_triangle_faces(
            rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size
        )['mask'][0].detach().cpu().numpy().astype(np.uint8)
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
    else:
        raise ValueError(f'Unknown mode: {mode}')

    return texture

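# Editor's sketch (not part of the original file): baking a texture from
# multi-view renders of an appearance representation `app_rep`; mirrors the
# call made by to_glb() below, but with the cheaper 'fast' mode.
def _example_bake_texture(app_rep, vertices: np.ndarray, faces: np.ndarray, uvs: np.ndarray):
    observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100)
    masks = [np.any(obs > 0, axis=-1) for obs in observations]
    extrinsics = [e.cpu().numpy() for e in extrinsics]
    intrinsics = [i.cpu().numpy() for i in intrinsics]
    return bake_texture(
        vertices, faces, uvs,
        observations, masks, extrinsics, intrinsics,
        texture_size=1024, mode='fast', verbose=True,
    )
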
def to_glb(
    app_rep: Union[Strivec, Gaussian],
    mesh: MeshExtractResult,
    simplify: float = 0.95,
    fill_holes: bool = True,
    fill_holes_max_size: float = 0.04,
    texture_size: int = 1024,
    debug: bool = False,
    verbose: bool = True,
) -> trimesh.Trimesh:
    """
    Convert a generated asset to a glb file.

    Args:
        app_rep (Union[Strivec, Gaussian]): Appearance representation.
        mesh (MeshExtractResult): Extracted mesh.
        simplify (float): Ratio of faces to remove in simplification.
        fill_holes (bool): Whether to fill holes in the mesh.
        fill_holes_max_size (float): Maximum area of a hole to fill.
        texture_size (int): Size of the texture.
        debug (bool): Whether to print debug information.
        verbose (bool): Whether to print progress.
    """
    vertices = mesh.vertices.cpu().numpy()
    faces = mesh.faces.cpu().numpy()

    # mesh postprocess
    vertices, faces = postprocess_mesh(
        vertices, faces,
        simplify=simplify > 0,
        simplify_ratio=simplify,
        fill_holes=fill_holes,
        fill_holes_max_hole_size=fill_holes_max_size,
        fill_holes_max_hole_nbe=int(250 * np.sqrt(1 - simplify)),
        fill_holes_resolution=1024,
        fill_holes_num_views=1000,
        debug=debug,
        verbose=verbose,
    )

    # parametrize mesh
    vertices, faces, uvs = parametrize_mesh(vertices, faces)

    # bake texture
    observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100)
    masks = [np.any(observation > 0, axis=-1) for observation in observations]
    extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
    intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]
    texture = bake_texture(
        vertices, faces, uvs,
        observations, masks, extrinsics, intrinsics,
        texture_size=texture_size, mode='opt',
        lambda_tv=0.01,
        verbose=verbose
    )
    texture = Image.fromarray(texture)

    # rotate mesh (from z-up to y-up)
    vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
    material = trimesh.visual.material.PBRMaterial(
        roughnessFactor=1.0,
        baseColorTexture=texture,
        baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8)
    )
    mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, material=material))
    return mesh

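# Editor's sketch (not part of the original file): typical end-to-end export.
# `outputs` is a hypothetical pipeline result holding a Gaussian and a mesh.
def _example_export_glb(outputs):
    glb = to_glb(outputs['gaussian'][0], outputs['mesh'][0], simplify=0.95, texture_size=1024)
    glb.export('sample.glb')
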
def simplify_gs(
    gs: Gaussian,
    simplify: float = 0.95,
    verbose: bool = True,
):
    """
    Simplify 3D Gaussians.
    NOTE: this function is not used in the current implementation due to its unsatisfactory performance.

    Args:
        gs (Gaussian): 3D Gaussian.
        simplify (float): Ratio of Gaussians to remove in simplification.
        verbose (bool): Whether to print progress.
    """
    if simplify <= 0:
        return gs

    # simplify
    observations, extrinsics, intrinsics = render_multiview(gs, resolution=1024, nviews=100)
    observations = [torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1) for obs in observations]

    # Following https://arxiv.org/pdf/2411.06019
    renderer = GaussianRenderer({
        "resolution": 1024,
        "near": 0.8,
        "far": 1.6,
        "ssaa": 1,
        "bg_color": (0, 0, 0),
    })
    new_gs = Gaussian(**gs.init_params)
    new_gs._features_dc = gs._features_dc.clone()
    new_gs._features_rest = gs._features_rest.clone() if gs._features_rest is not None else None
    new_gs._opacity = torch.nn.Parameter(gs._opacity.clone())
    new_gs._rotation = torch.nn.Parameter(gs._rotation.clone())
    new_gs._scaling = torch.nn.Parameter(gs._scaling.clone())
    new_gs._xyz = torch.nn.Parameter(gs._xyz.clone())

    start_lr = [1e-4, 1e-3, 5e-3, 0.025]
    end_lr = [1e-6, 1e-5, 5e-5, 0.00025]
    optimizer = torch.optim.Adam([
        {"params": new_gs._xyz, "lr": start_lr[0]},
        {"params": new_gs._rotation, "lr": start_lr[1]},
        {"params": new_gs._scaling, "lr": start_lr[2]},
        {"params": new_gs._opacity, "lr": start_lr[3]},
    ], lr=start_lr[0])

    def exp_annealing(optimizer, step, total_steps, start_lr, end_lr):
        # defined but unused below; cosine_annealing is the active schedule
        return start_lr * (end_lr / start_lr) ** (step / total_steps)

    def cosine_annealing(optimizer, step, total_steps, start_lr, end_lr):
        return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))

    _zeta = new_gs.get_opacity.clone().detach().squeeze()
    _lambda = torch.zeros_like(_zeta)
    _delta = 1e-7
    _interval = 10
    num_target = int((1 - simplify) * _zeta.shape[0])

    with tqdm(total=2500, disable=not verbose, desc='Simplifying Gaussian') as pbar:
        for i in range(2500):
            # prune low-opacity Gaussians every 100 steps
            if i % 100 == 0:
                mask = new_gs.get_opacity.squeeze() > 0.05
                mask = torch.nonzero(mask).squeeze()
                new_gs._xyz = torch.nn.Parameter(new_gs._xyz[mask])
                new_gs._rotation = torch.nn.Parameter(new_gs._rotation[mask])
                new_gs._scaling = torch.nn.Parameter(new_gs._scaling[mask])
                new_gs._opacity = torch.nn.Parameter(new_gs._opacity[mask])
                new_gs._features_dc = new_gs._features_dc[mask]
                new_gs._features_rest = new_gs._features_rest[mask] if new_gs._features_rest is not None else None
                _zeta = _zeta[mask]
                _lambda = _lambda[mask]
                # update optimizer state
                for param_group, new_param in zip(optimizer.param_groups, [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity]):
                    stored_state = optimizer.state[param_group['params'][0]]
                    if 'exp_avg' in stored_state:
                        stored_state['exp_avg'] = stored_state['exp_avg'][mask]
                        stored_state['exp_avg_sq'] = stored_state['exp_avg_sq'][mask]
                    del optimizer.state[param_group['params'][0]]
                    param_group['params'][0] = new_param
                    optimizer.state[param_group['params'][0]] = stored_state

            opacity = new_gs.get_opacity.squeeze()

            # sparsify
            if i % _interval == 0:
                _zeta = _lambda + opacity.detach()
                if opacity.shape[0] > num_target:
                    index = _zeta.topk(num_target)[1]
                    _m = torch.ones_like(_zeta, dtype=torch.bool)
                    _m[index] = 0
                    _zeta[_m] = 0
                _lambda = _lambda + opacity.detach() - _zeta

            # sample a random view
            view_idx = np.random.randint(len(observations))
            observation = observations[view_idx]
            extrinsic = extrinsics[view_idx]
            intrinsic = intrinsics[view_idx]

            color = renderer.render(new_gs, extrinsic, intrinsic)['color']
            rgb_loss = torch.nn.functional.l1_loss(color, observation)
            loss = rgb_loss + \
                   _delta * torch.sum(torch.pow(_lambda + opacity - _zeta, 2))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update lr
            for j in range(len(optimizer.param_groups)):
                optimizer.param_groups[j]['lr'] = cosine_annealing(optimizer, i, 2500, start_lr[j], end_lr[j])

            pbar.set_postfix({'loss': rgb_loss.item(), 'num': opacity.shape[0], 'lambda': _lambda.mean().item()})
            pbar.update()

    new_gs._xyz = new_gs._xyz.data
    new_gs._rotation = new_gs._rotation.data
    new_gs._scaling = new_gs._scaling.data
    new_gs._opacity = new_gs._opacity.data

    return new_gs
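# Editor's sketch (not part of the original file): the projection step used by
# simplify_gs() in isolation — keep the num_target largest entries of
# (lambda + opacity) and zero the rest, which drives the opacities of
# to-be-pruned Gaussians toward zero over the optimization.
def _example_topk_projection(opacity: torch.Tensor, _lambda: torch.Tensor, num_target: int) -> torch.Tensor:
    _zeta = _lambda + opacity.detach()
    if opacity.shape[0] > num_target:
        keep = _zeta.topk(num_target)[1]
        drop = torch.ones_like(_zeta, dtype=torch.bool)
        drop[keep] = False
        _zeta[drop] = 0
    return _zeta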
trellis/utils/random_utils.py
ADDED
@@ -0,0 +1,30 @@
import numpy as np

# First 16 primes, used as bases for the Halton sequence.
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]

def radical_inverse(base, n):
    """Radical inverse of n in the given base (mirrors n's digits about the radix point)."""
    val = 0
    inv_base = 1.0 / base
    inv_base_n = inv_base
    while n > 0:
        digit = n % base
        val += digit * inv_base_n
        n //= base
        inv_base_n *= inv_base
    return val

def halton_sequence(dim, n):
    """n-th point of the dim-dimensional Halton sequence."""
    return [radical_inverse(PRIMES[d], n) for d in range(dim)]

def hammersley_sequence(dim, n, num_samples):
    """n-th of num_samples points of the dim-dimensional Hammersley sequence."""
    return [n / num_samples] + halton_sequence(dim - 1, n)

def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False):
    """Map a 2D Hammersley point to (phi, theta) camera angles on a sphere."""
    u, v = hammersley_sequence(2, n, num_samples)
    u += offset[0] / num_samples
    v += offset[1]
    if remap:
        u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
    theta = np.arccos(1 - 2 * u) - np.pi / 2
    phi = v * 2 * np.pi
    return [phi, theta]
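# Editor's sketch (not part of the original file): radical_inverse mirrors the
# base-b digits of n about the radix point, e.g. radical_inverse(2, 6) == 0.375
# (6 = 110 in binary -> 0.011 in binary). Quick demo of the sphere sampler:
if __name__ == '__main__':
    cams = [sphere_hammersley_sequence(i, 16) for i in range(16)]
    for phi, theta in cams:
        print(f'phi={phi:.3f}, theta={theta:.3f}')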
trellis/utils/render_utils.py
ADDED
@@ -0,0 +1,120 @@
import torch
import numpy as np
from tqdm import tqdm
import utils3d
from PIL import Image

from ..renderers import OctreeRenderer, GaussianRenderer, MeshRenderer
from ..representations import Octree, Gaussian, MeshExtractResult
from ..modules import sparse as sp
from .random_utils import sphere_hammersley_sequence


def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs):
    is_list = isinstance(yaws, list)
    if not is_list:
        yaws = [yaws]
        pitchs = [pitchs]
    if not isinstance(rs, list):
        rs = [rs] * len(yaws)
    if not isinstance(fovs, list):
        fovs = [fovs] * len(yaws)
    extrinsics = []
    intrinsics = []
    for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs):
        fov = torch.deg2rad(torch.tensor(float(fov))).cuda()
        yaw = torch.tensor(float(yaw)).cuda()
        pitch = torch.tensor(float(pitch)).cuda()
        # camera origin on a sphere of radius r around the origin (z-up convention)
        orig = torch.tensor([
            torch.sin(yaw) * torch.cos(pitch),
            torch.cos(yaw) * torch.cos(pitch),
            torch.sin(pitch),
        ]).cuda() * r
        extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
        intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
        extrinsics.append(extr)
        intrinsics.append(intr)
    if not is_list:
        extrinsics = extrinsics[0]
        intrinsics = intrinsics[0]
    return extrinsics, intrinsics

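# Editor's sketch (not part of the original file): the single-value (non-list)
# path returns one extrinsics/intrinsics pair; yaw and pitch are in radians,
# fov in degrees. Requires a CUDA device, as the tensors are built on the GPU.
def _example_single_camera():
    extr, intr = yaw_pitch_r_fov_to_extrinsics_intrinsics(0.0, 0.0, 2.0, 40.0)
    return extr, intr  # (4, 4) view matrix and (3, 3) intrinsics
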
def get_renderer(sample, **kwargs):
    if isinstance(sample, Octree):
        renderer = OctreeRenderer()
        renderer.rendering_options.resolution = kwargs.get('resolution', 512)
        renderer.rendering_options.near = kwargs.get('near', 0.8)
        renderer.rendering_options.far = kwargs.get('far', 1.6)
        renderer.rendering_options.bg_color = kwargs.get('bg_color', (0, 0, 0))
        renderer.rendering_options.ssaa = kwargs.get('ssaa', 4)
        renderer.pipe.primitive = sample.primitive
    elif isinstance(sample, Gaussian):
        renderer = GaussianRenderer()
        renderer.rendering_options.resolution = kwargs.get('resolution', 512)
        renderer.rendering_options.near = kwargs.get('near', 0.8)
        renderer.rendering_options.far = kwargs.get('far', 1.6)
        renderer.rendering_options.bg_color = kwargs.get('bg_color', (0, 0, 0))
        renderer.rendering_options.ssaa = kwargs.get('ssaa', 1)
        renderer.pipe.kernel_size = kwargs.get('kernel_size', 0.1)
        renderer.pipe.use_mip_gaussian = True
    elif isinstance(sample, MeshExtractResult):
        renderer = MeshRenderer()
        renderer.rendering_options.resolution = kwargs.get('resolution', 512)
        renderer.rendering_options.near = kwargs.get('near', 1)
        renderer.rendering_options.far = kwargs.get('far', 100)
        renderer.rendering_options.ssaa = kwargs.get('ssaa', 4)
    else:
        raise ValueError(f'Unsupported sample type: {type(sample)}')
    return renderer


def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, **kwargs):
    renderer = get_renderer(sample, **options)
    rets = {}
    for extr, intr in tqdm(zip(extrinsics, intrinsics), total=len(extrinsics), desc='Rendering', disable=not verbose):
        if isinstance(sample, MeshExtractResult):
            res = renderer.render(sample, extr, intr)
            if 'normal' not in rets: rets['normal'] = []
            rets['normal'].append(np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
        else:
            res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite)
            if 'color' not in rets: rets['color'] = []
            if 'depth' not in rets: rets['depth'] = []
            rets['color'].append(np.clip(res['color'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
            if 'percent_depth' in res:
                rets['depth'].append(res['percent_depth'].detach().cpu().numpy())
            elif 'depth' in res:
                rets['depth'].append(res['depth'].detach().cpu().numpy())
            else:
                rets['depth'].append(None)
    return rets


def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs):
    yaws = torch.linspace(0, 2 * np.pi, num_frames)
    pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * np.pi, num_frames))
    yaws = yaws.tolist()
    pitch = pitch.tolist()
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov)
    return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)


def render_multiview(sample, resolution=512, nviews=30):
    r = 2
    fov = 40
    cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)]
    yaws = [cam[0] for cam in cams]
    pitchs = [cam[1] for cam in cams]
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov)
    res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)})
    return res['color'], extrinsics, intrinsics


def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, **kwargs):
    yaw = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
    yaw_offset = offset[0]
    yaw = [y + yaw_offset for y in yaw]
    pitch = [offset[1] for _ in range(4)]
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov)
    return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
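# Editor's sketch (not part of the original file): rendering a turntable video
# of a sample and writing it to disk. imageio is an assumption here; any
# writer of H x W x 3 uint8 frames works with render_video()'s 'color' output.
def _example_render_turntable(sample, path='turntable.mp4'):
    import imageio
    frames = render_video(sample, resolution=512, num_frames=120)['color']
    imageio.mimsave(path, frames, fps=30)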