Spaces: Running on Zero

Upload 288 files

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +4 -65
- .gitignore +2 -0
- README.md +16 -12
- app.py +274 -232
- app_img.py +414 -0
- configs/generation/slat_flow_img_dit_L_64l8p2_fp16.json +102 -0
- configs/generation/slat_flow_txt_dit_B_64l8p2_fp16.json +101 -0
- configs/generation/slat_flow_txt_dit_L_64l8p2_fp16.json +101 -0
- configs/generation/slat_flow_txt_dit_XL_64l8p2_fp16.json +101 -0
- configs/generation/ss_flow_img_dit_L_16l8_fp16.json +70 -0
- configs/generation/ss_flow_txt_dit_B_16l8_fp16.json +69 -0
- configs/generation/ss_flow_txt_dit_L_16l8_fp16.json +69 -0
- configs/generation/ss_flow_txt_dit_XL_16l8_fp16.json +70 -0
- configs/vae/slat_vae_dec_mesh_swin8_B_64l8_fp16.json +73 -0
- configs/vae/slat_vae_dec_rf_swin8_B_64l8_fp16.json +71 -0
- configs/vae/slat_vae_enc_dec_gs_swin8_B_64l8_fp16.json +105 -0
- configs/vae/ss_vae_conv3d_16l8_fp16.json +65 -0
- dataset_toolkits/blender_script/io_scene_usdz.zip +3 -0
- dataset_toolkits/blender_script/render.py +528 -0
- dataset_toolkits/build_metadata.py +270 -0
- dataset_toolkits/datasets/3D-FUTURE.py +97 -0
- dataset_toolkits/datasets/ABO.py +96 -0
- dataset_toolkits/datasets/HSSD.py +103 -0
- dataset_toolkits/datasets/ObjaverseXL.py +92 -0
- dataset_toolkits/datasets/Toys4k.py +92 -0
- dataset_toolkits/download.py +52 -0
- dataset_toolkits/encode_latent.py +127 -0
- dataset_toolkits/encode_ss_latent.py +128 -0
- dataset_toolkits/extract_feature.py +179 -0
- dataset_toolkits/render.py +121 -0
- dataset_toolkits/render_cond.py +125 -0
- dataset_toolkits/setup.sh +1 -0
- dataset_toolkits/stat_latent.py +66 -0
- dataset_toolkits/utils.py +43 -0
- dataset_toolkits/voxelize.py +86 -0
- env.py +10 -10
- extensions/vox2seq/benchmark.py +45 -0
- extensions/vox2seq/setup.py +34 -0
- extensions/vox2seq/src/api.cu +92 -0
- extensions/vox2seq/src/api.h +76 -0
- extensions/vox2seq/src/ext.cpp +10 -0
- extensions/vox2seq/src/hilbert.cu +133 -0
- extensions/vox2seq/src/hilbert.h +35 -0
- extensions/vox2seq/src/z_order.cu +66 -0
- extensions/vox2seq/src/z_order.h +35 -0
- extensions/vox2seq/test.py +25 -0
- extensions/vox2seq/vox2seq/__init__.py +50 -0
- extensions/vox2seq/vox2seq/pytorch/__init__.py +48 -0
- extensions/vox2seq/vox2seq/pytorch/default.py +59 -0
- extensions/vox2seq/vox2seq/pytorch/hilbert.py +303 -0
.gitattributes
CHANGED

@@ -34,68 +34,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
-
-
-
-
-assets/example_image/typical_building_maya_pyramid.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_building_mushroom.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_building_space_station.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_creature_dragon.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_creature_elephant.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_creature_furry.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_creature_quadruped.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_creature_robot_crab.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_creature_robot_dinosour.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_creature_rock_monster.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_humanoid_block_robot.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_humanoid_dragonborn.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_humanoid_dwarf.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_humanoid_goblin.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_humanoid_mech.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_crate.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_fireplace.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_gate.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_lantern.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_magicbook.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_mailbox.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_monster_chest.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_paper_machine.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_phonograph.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_portal2.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_storage_chest.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_telephone.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_television.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_workbench.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_vehicle_biplane.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_vehicle_bulldozer.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_vehicle_cart.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_vehicle_excavator.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_vehicle_helicopter.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_vehicle_locomotive.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_vehicle_pirate_ship.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/weatherworn_misc_paper_machine3.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/character_1.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/character_2.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/character_3.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/mushroom_1.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/mushroom_2.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/mushroom_3.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/orangeguy_1.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/orangeguy_2.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/orangeguy_3.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/popmart_1.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/popmart_2.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/popmart_3.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/rabbit_1.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/rabbit_2.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/rabbit_3.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/tiger_1.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/tiger_2.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/tiger_3.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/yoimiya_1.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/yoimiya_2.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/yoimiya_3.png filter=lfs diff=lfs merge=lfs -text
-assets/logo.webp filter=lfs diff=lfs merge=lfs -text
-assets/T.ply filter=lfs diff=lfs merge=lfs -text
-assets/teaser.png filter=lfs diff=lfs merge=lfs -text
+*.ply filter=lfs diff=lfs merge=lfs -text
+*.webp filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
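Net effect: the four wildcard rules replace 65 per-file rules. A quick standalone sanity check (Python standard library only; the sample paths are taken from the removed lines, and fnmatch's '*' crossing '/' mirrors how these bare gitattributes patterns cover the paths):

from fnmatch import fnmatch

# The four patterns added to .gitattributes in this commit.
lfs_patterns = [
    "*.ply",
    "*.webp",
    "*.png",
    "wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl",
]

# A few of the paths whose per-file rules were removed.
samples = [
    "assets/example_image/typical_building_maya_pyramid.png",
    "assets/example_multi_image/yoimiya_3.png",
    "assets/logo.webp",
    "assets/T.ply",
]

for path in samples:
    covered = any(fnmatch(path, pattern) for pattern in lfs_patterns)
    print(f"{path}: {'tracked by LFS' if covered else 'NOT covered'}")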
.gitignore
ADDED

@@ -0,0 +1,2 @@
+.idea/*
+__pycache__
README.md
CHANGED

@@ -1,12 +1,16 @@
----
-title: TRELLIS
-emoji:
-colorFrom: indigo
-colorTo:
-sdk: gradio
-sdk_version: 5.
-app_file: app.py
-pinned:
-license: mit
-short_description: 3D Generation from text
----
+---
+title: TRELLIS Text To 3D
+emoji: 🏢
+colorFrom: indigo
+colorTo: blue
+sdk: gradio
+sdk_version: 5.25.0
+app_file: app.py
+pinned: true
+license: mit
+short_description: Scalable and Versatile 3D Generation from text prompt
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+Paper: https://huggingface.co/papers/2412.01506
app.py
CHANGED

@@ -1,232 +1,274 @@
-import gradio as gr
-import spaces
-
-import os
-import shutil
-os.environ['TOKENIZERS_PARALLELISM'] = 'true'
-os.environ['SPCONV_ALGO'] = 'native'
-from typing import *
-import torch
-import numpy as np
-import imageio
-from easydict import EasyDict as edict
-from trellis.pipelines import TrellisTextTo3DPipeline
-from trellis.representations import Gaussian, MeshExtractResult
-from trellis.utils import render_utils, postprocessing_utils
-
-import traceback
-import sys
-…
+import gradio as gr
+import spaces
+
+import os
+import shutil
+os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+os.environ['SPCONV_ALGO'] = 'native'
+from typing import *
+import torch
+import numpy as np
+import imageio
+from easydict import EasyDict as edict
+from trellis.pipelines import TrellisTextTo3DPipeline
+from trellis.representations import Gaussian, MeshExtractResult
+from trellis.utils import render_utils, postprocessing_utils
+
+import traceback
+import sys
+
+
+MAX_SEED = np.iinfo(np.int32).max
+TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
+os.makedirs(TMP_DIR, exist_ok=True)
+
+
+def start_session(req: gr.Request):
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    os.makedirs(user_dir, exist_ok=True)
+
+
+def end_session(req: gr.Request):
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    shutil.rmtree(user_dir)
+
+
+def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
+    return {
+        'gaussian': {
+            **gs.init_params,
+            '_xyz': gs._xyz.cpu().numpy(),
+            '_features_dc': gs._features_dc.cpu().numpy(),
+            '_scaling': gs._scaling.cpu().numpy(),
+            '_rotation': gs._rotation.cpu().numpy(),
+            '_opacity': gs._opacity.cpu().numpy(),
+        },
+        'mesh': {
+            'vertices': mesh.vertices.cpu().numpy(),
+            'faces': mesh.faces.cpu().numpy(),
+        },
+    }
+
+
+def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
+    gs = Gaussian(
+        aabb=state['gaussian']['aabb'],
+        sh_degree=state['gaussian']['sh_degree'],
+        mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
+        scaling_bias=state['gaussian']['scaling_bias'],
+        opacity_bias=state['gaussian']['opacity_bias'],
+        scaling_activation=state['gaussian']['scaling_activation'],
+    )
+    gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
+    gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
+    gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
+    gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
+    gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
+
+    mesh = edict(
+        vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
+        faces=torch.tensor(state['mesh']['faces'], device='cuda'),
+    )
+
+    return gs, mesh
+
+
+def get_seed(randomize_seed: bool, seed: int) -> int:
+    """
+    Get the random seed.
+    """
+    return np.random.randint(0, MAX_SEED) if randomize_seed else seed
+
+
+@spaces.GPU
+def text_to_3d(
+    prompt: str,
+    seed: int,
+    ss_guidance_strength: float,
+    ss_sampling_steps: int,
+    slat_guidance_strength: float,
+    slat_sampling_steps: int,
+    req: gr.Request,
+) -> Tuple[dict, str]:
+    """
+    Convert an text prompt to a 3D model.
+
+    Args:
+        prompt (str): The text prompt.
+        seed (int): The random seed.
+        ss_guidance_strength (float): The guidance strength for sparse structure generation.
+        ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
+        slat_guidance_strength (float): The guidance strength for structured latent generation.
+        slat_sampling_steps (int): The number of sampling steps for structured latent generation.
+
+    Returns:
+        dict: The information of the generated 3D model.
+        str: The path to the video of the 3D model.
+    """
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    outputs = pipeline.run(
+        prompt,
+        seed=seed,
+        formats=["gaussian", "mesh"],
+        sparse_structure_sampler_params={
+            "steps": ss_sampling_steps,
+            "cfg_strength": ss_guidance_strength,
+        },
+        slat_sampler_params={
+            "steps": slat_sampling_steps,
+            "cfg_strength": slat_guidance_strength,
+        },
+    )
+    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
+    video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
+    video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
+    video_path = os.path.join(user_dir, 'sample.mp4')
+    imageio.mimsave(video_path, video, fps=15)
+    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
+    torch.cuda.empty_cache()
+    return state, video_path
+
+
+@spaces.GPU(duration=90)
+def extract_glb(
+    state: dict,
+    mesh_simplify: float,
+    texture_size: int,
+    req: gr.Request,
+) -> Tuple[str, str]:
+    """
+    Extract a GLB file from the 3D model.
+
+    Args:
+        state (dict): The state of the generated 3D model.
+        mesh_simplify (float): The mesh simplification factor.
+        texture_size (int): The texture resolution.
+
+    Returns:
+        str: The path to the extracted GLB file.
+    """
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    gs, mesh = unpack_state(state)
+    glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
+    glb_path = os.path.join(user_dir, 'sample.glb')
+    glb.export(glb_path)
+    torch.cuda.empty_cache()
+    return glb_path, glb_path
+
+
+@spaces.GPU
+def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
+    """
+    Extract a Gaussian file from the 3D model.
+
+    Args:
+        state (dict): The state of the generated 3D model.
+
+    Returns:
+        str: The path to the extracted Gaussian file.
+    """
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    gs, _ = unpack_state(state)
+    gaussian_path = os.path.join(user_dir, 'sample.ply')
+    gs.save_ply(gaussian_path)
+    torch.cuda.empty_cache()
+    return gaussian_path, gaussian_path
+
+
+with gr.Blocks(delete_cache=(600, 600)) as demo:
+    gr.Markdown("""
+    ## Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
+    * Type a text prompt and click "Generate" to create a 3D asset.
+    * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
+    """)
+
+    with gr.Row():
+        with gr.Column():
+            text_prompt = gr.Textbox(label="Text Prompt", lines=5)
+
+            with gr.Accordion(label="Generation Settings", open=False):
+                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
+                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+                gr.Markdown("Stage 1: Sparse Structure Generation")
+                with gr.Row():
+                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
+                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
+                gr.Markdown("Stage 2: Structured Latent Generation")
+                with gr.Row():
+                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
+                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
+
+            generate_btn = gr.Button("Generate")
+
+            with gr.Accordion(label="GLB Extraction Settings", open=False):
+                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
+                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
+
+            with gr.Row():
+                extract_glb_btn = gr.Button("Extract GLB", interactive=False)
+                extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
+            gr.Markdown("""
+            *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
+            """)
+
+        with gr.Column():
+            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
+            model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
+
+            with gr.Row():
+                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
+                download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
+
+    output_buf = gr.State()
+
+    # Handlers
+    demo.load(start_session)
+    demo.unload(end_session)
+
+    generate_btn.click(
+        get_seed,
+        inputs=[randomize_seed, seed],
+        outputs=[seed],
+    ).then(
+        text_to_3d,
+        inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
+        outputs=[output_buf, video_output],
+    ).then(
+        lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
+        outputs=[extract_glb_btn, extract_gs_btn],
+    )
+
+    video_output.clear(
+        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
+        outputs=[extract_glb_btn, extract_gs_btn],
+    )
+
+    extract_glb_btn.click(
+        extract_glb,
+        inputs=[output_buf, mesh_simplify, texture_size],
+        outputs=[model_output, download_glb],
+    ).then(
+        lambda: gr.Button(interactive=True),
+        outputs=[download_glb],
+    )
+
+    extract_gs_btn.click(
+        extract_gaussian,
+        inputs=[output_buf],
+        outputs=[model_output, download_gs],
+    ).then(
+        lambda: gr.Button(interactive=True),
+        outputs=[download_gs],
+    )
+
+    model_output.clear(
+        lambda: gr.Button(interactive=False),
+        outputs=[download_glb],
+    )
+
+
+# Launch the Gradio app
+if __name__ == "__main__":
+    pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
+    pipeline.cuda()
+    demo.launch()
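For readers who want to exercise the same text-to-3D flow outside the Space, here is a minimal headless sketch assembled only from calls that appear in the new app.py above (TrellisTextTo3DPipeline, render_utils, postprocessing_utils). It assumes a local CUDA GPU and the trellis package from this repository; the prompt and output filenames are placeholders.

import imageio
from trellis.pipelines import TrellisTextTo3DPipeline
from trellis.utils import render_utils, postprocessing_utils

# Load the same checkpoint the Space uses and move it to the GPU.
pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
pipeline.cuda()

# One sampling pass: sparse structure first, then structured latents.
outputs = pipeline.run(
    "a weathered wooden treasure chest",  # hypothetical prompt
    seed=0,
    formats=["gaussian", "mesh"],
    sparse_structure_sampler_params={"steps": 25, "cfg_strength": 7.5},
    slat_sampler_params={"steps": 25, "cfg_strength": 7.5},
)

# Render a turntable preview of the Gaussian output, as app.py does.
frames = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
imageio.mimsave("sample.mp4", frames, fps=15)

# Bake a textured GLB from the Gaussian + mesh pair, as extract_glb does.
glb = postprocessing_utils.to_glb(
    outputs['gaussian'][0], outputs['mesh'][0],
    simplify=0.95, texture_size=1024, verbose=False,
)
glb.export("sample.glb")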
app_img.py
ADDED

import gradio as gr
import spaces
from gradio_litmodel3d import LitModel3D

import os
import shutil
os.environ['SPCONV_ALGO'] = 'native'
from typing import *
import torch
import numpy as np
import imageio
from easydict import EasyDict as edict
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils


MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)


def start_session(req: gr.Request):
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)


def end_session(req: gr.Request):
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(user_dir)


def preprocess_image(image: Image.Image) -> Image.Image:
    """
    Preprocess the input image.

    Args:
        image (Image.Image): The input image.

    Returns:
        Image.Image: The preprocessed image.
    """
    processed_image = pipeline.preprocess_image(image)
    return processed_image


def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """
    Preprocess a list of input images.

    Args:
        images (List[Tuple[Image.Image, str]]): The input images.

    Returns:
        List[Image.Image]: The preprocessed images.
    """
    images = [image[0] for image in images]
    processed_images = [pipeline.preprocess_image(image) for image in images]
    return processed_images


# pack_state, unpack_state, and get_seed are identical to the definitions in app.py above.


@spaces.GPU
def image_to_3d(
    image: Image.Image,
    multiimages: List[Tuple[Image.Image, str]],
    is_multiimage: bool,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    multiimage_algo: Literal["multidiffusion", "stochastic"],
    req: gr.Request,
) -> Tuple[dict, str]:
    """
    Convert an image to a 3D model.

    Args:
        image (Image.Image): The input image.
        multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
        is_multiimage (bool): Whether is in multi-image mode.
        seed (int): The random seed.
        ss_guidance_strength (float): The guidance strength for sparse structure generation.
        ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
        slat_guidance_strength (float): The guidance strength for structured latent generation.
        slat_sampling_steps (int): The number of sampling steps for structured latent generation.
        multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.

    Returns:
        dict: The information of the generated 3D model.
        str: The path to the video of the 3D model.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    if not is_multiimage:
        outputs = pipeline.run(
            image,
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
        )
    else:
        outputs = pipeline.run_multi_image(
            [image[0] for image in multiimages],
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
            mode=multiimage_algo,
        )
    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
    video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
    video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
    video_path = os.path.join(user_dir, 'sample.mp4')
    imageio.mimsave(video_path, video, fps=15)
    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
    torch.cuda.empty_cache()
    return state, video_path


# extract_glb and extract_gaussian are identical to the definitions in app.py above.


def prepare_multi_example() -> List[Image.Image]:
    multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
    images = []
    for case in multi_case:
        _images = []
        for i in range(1, 4):
            img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
            W, H = img.size
            img = img.resize((int(W / H * 512), 512))
            _images.append(np.array(img))
        images.append(Image.fromarray(np.concatenate(_images, axis=1)))
    return images


def split_image(image: Image.Image) -> List[Image.Image]:
    """
    Split an image into multiple views.
    """
    image = np.array(image)
    alpha = image[..., 3]
    alpha = np.any(alpha>0, axis=0)
    start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
    end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
    images = []
    for s, e in zip(start_pos, end_pos):
        images.append(Image.fromarray(image[:, s:e+1]))
    return [preprocess_image(image) for image in images]


with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
    * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
    * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.

    ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
    """)

    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
                    image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
                with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
                    multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
                    gr.Markdown("""
                    Input different views of the object in separate images.

                    *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
                    """)

            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")

            generate_btn = gr.Button("Generate")

            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)

            with gr.Row():
                extract_glb_btn = gr.Button("Extract GLB", interactive=False)
                extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
            gr.Markdown("""
            *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
            """)

        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)

            with gr.Row():
                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
                download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)

    is_multiimage = gr.State(False)
    output_buf = gr.State()

    # Example images at the bottom of the page
    with gr.Row() as single_image_example:
        examples = gr.Examples(
            examples=[
                f'assets/example_image/{image}'
                for image in os.listdir("assets/example_image")
            ],
            inputs=[image_prompt],
            fn=preprocess_image,
            outputs=[image_prompt],
            run_on_click=True,
            examples_per_page=64,
        )
    with gr.Row(visible=False) as multiimage_example:
        examples_multi = gr.Examples(
            examples=prepare_multi_example(),
            inputs=[image_prompt],
            fn=split_image,
            outputs=[multiimage_prompt],
            run_on_click=True,
            examples_per_page=8,
        )

    # Handlers
    demo.load(start_session)
    demo.unload(end_session)

    single_image_input_tab.select(
        lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]),
        outputs=[is_multiimage, single_image_example, multiimage_example]
    )
    multiimage_input_tab.select(
        lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
        outputs=[is_multiimage, single_image_example, multiimage_example]
    )

    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt],
    )
    multiimage_prompt.upload(
        preprocess_images,
        inputs=[multiimage_prompt],
        outputs=[multiimage_prompt],
    )

    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
        outputs=[output_buf, video_output],
    ).then(
        lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
        outputs=[extract_glb_btn, extract_gs_btn],
    )

    # The video_output.clear, extract_glb_btn.click, extract_gs_btn.click, and
    # model_output.clear handlers are identical to those in app.py above.


# Launch the Gradio app
if __name__ == "__main__":
    pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
    pipeline.cuda()
    try:
        pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))  # Preload rembg
    except:
        pass
    demo.launch()
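A minimal headless sketch of the multi-image path wired up above, again using only calls that appear in app_img.py (pipeline.preprocess_image, pipeline.run_multi_image). The three view images are placeholders you would supply yourself; a CUDA GPU and the trellis package are assumed.

from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline

pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
pipeline.cuda()

# Hypothetical views of the same object (e.g. front/side/back), RGB or RGBA.
views = [Image.open(p) for p in ["front.png", "side.png", "back.png"]]

# app_img.py preprocesses first (background removal / cropping), then samples
# with preprocess_image=False so the already-cleaned views are used as-is.
views = [pipeline.preprocess_image(v) for v in views]
outputs = pipeline.run_multi_image(
    views,
    seed=0,
    formats=["gaussian", "mesh"],
    preprocess_image=False,
    sparse_structure_sampler_params={"steps": 12, "cfg_strength": 7.5},
    slat_sampler_params={"steps": 12, "cfg_strength": 3.0},
    mode="stochastic",  # or "multidiffusion", as exposed by the UI radio button
)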
configs/generation/slat_flow_img_dit_L_64l8p2_fp16.json
ADDED

{
  "models": {
    "denoiser": {
      "name": "ElasticSLatFlowModel",
      "args": {
        "resolution": 64, "in_channels": 8, "out_channels": 8,
        "model_channels": 1024, "cond_channels": 1024,
        "num_blocks": 24, "num_heads": 16, "mlp_ratio": 4, "patch_size": 2,
        "num_io_res_blocks": 2, "io_block_channels": [128],
        "pe_mode": "ape", "qk_rms_norm": true, "use_fp16": true
      }
    }
  },
  "dataset": {
    "name": "ImageConditionedSLat",
    "args": {
      "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
      "min_aesthetic_score": 4.5,
      "max_num_voxels": 32768,
      "image_size": 518,
      "normalization": {
        "mean": [-2.1687545776367188, -0.004347046371549368, -0.13352349400520325, -0.08418072760105133,
                 -0.5271206498146057, 0.7238689064979553, -1.1414450407028198, 1.2039363384246826],
        "std": [2.377650737762451, 2.386378288269043, 2.124418020248413, 2.1748552322387695,
                2.663944721221924, 2.371192216873169, 2.6217446327209473, 2.684523105621338]
      },
      "pretrained_slat_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
    }
  },
  "trainer": {
    "name": "ImageConditionedSparseFlowMatchingCFGTrainer",
    "args": {
      "max_steps": 1000000,
      "batch_size_per_gpu": 8,
      "batch_split": 4,
      "optimizer": {"name": "AdamW", "args": {"lr": 0.0001, "weight_decay": 0.0}},
      "ema_rate": [0.9999],
      "fp16_mode": "inflat_all",
      "fp16_scale_growth": 0.001,
      "elastic": {"name": "LinearMemoryController", "args": {"target_ratio": 0.75, "max_mem_ratio_start": 0.5}},
      "grad_clip": {"name": "AdaptiveGradClipper", "args": {"max_norm": 1.0, "clip_percentile": 95}},
      "i_log": 500, "i_sample": 10000, "i_save": 10000,
      "p_uncond": 0.1,
      "t_schedule": {"name": "logitNormal", "args": {"mean": 1.0, "std": 1.0}},
      "sigma_min": 1e-5,
      "image_cond_model": "dinov2_vitl14_reg"
    }
  }
}
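The training configs that follow all share this nested name/args convention. The training entry points are not part of this commit, so the loader sketch below is an illustration only: json is standard library, but the resolve() helper and its registry mapping are assumptions, not code from this repository.

import json

def load_config(path: str) -> dict:
    """Read one of the configs/**/*.json files into a plain dict."""
    with open(path) as f:
        return json.load(f)

def resolve(spec: dict, registry: dict):
    """Instantiate a {"name": ..., "args": {...}} block from a class registry.
    `registry` maps names like "ElasticSLatFlowModel" to classes (hypothetical)."""
    return registry[spec["name"]](**spec["args"])

cfg = load_config("configs/generation/slat_flow_img_dit_L_64l8p2_fp16.json")
print(cfg["models"]["denoiser"]["name"])             # ElasticSLatFlowModel
print(cfg["trainer"]["args"]["batch_size_per_gpu"])  # 8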
configs/generation/slat_flow_txt_dit_B_64l8p2_fp16.json
ADDED

{
  "models": {
    "denoiser": {
      "name": "ElasticSLatFlowModel",
      "args": {
        "resolution": 64, "in_channels": 8, "out_channels": 8,
        "model_channels": 768, "cond_channels": 768,
        "num_blocks": 12, "num_heads": 12, "mlp_ratio": 4, "patch_size": 2,
        "num_io_res_blocks": 2, "io_block_channels": [128],
        "pe_mode": "ape", "qk_rms_norm": true, "use_fp16": true
      }
    }
  },
  "dataset": {
    "name": "TextConditionedSLat",
    "args": {
      "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
      "min_aesthetic_score": 4.5,
      "max_num_voxels": 32768,
      "normalization": {
        "mean": [-2.1687545776367188, -0.004347046371549368, -0.13352349400520325, -0.08418072760105133,
                 -0.5271206498146057, 0.7238689064979553, -1.1414450407028198, 1.2039363384246826],
        "std": [2.377650737762451, 2.386378288269043, 2.124418020248413, 2.1748552322387695,
                2.663944721221924, 2.371192216873169, 2.6217446327209473, 2.684523105621338]
      },
      "pretrained_slat_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
    }
  },
  "trainer": {
    "name": "TextConditionedSparseFlowMatchingCFGTrainer",
    "args": {
      "max_steps": 1000000,
      "batch_size_per_gpu": 16,
      "batch_split": 4,
      "optimizer": {"name": "AdamW", "args": {"lr": 0.0001, "weight_decay": 0.0}},
      "ema_rate": [0.9999],
      "fp16_mode": "inflat_all",
      "fp16_scale_growth": 0.001,
      "elastic": {"name": "LinearMemoryController", "args": {"target_ratio": 0.75, "max_mem_ratio_start": 0.5}},
      "grad_clip": {"name": "AdaptiveGradClipper", "args": {"max_norm": 1.0, "clip_percentile": 95}},
      "i_log": 500, "i_sample": 10000, "i_save": 10000,
      "p_uncond": 0.1,
      "t_schedule": {"name": "logitNormal", "args": {"mean": 1.0, "std": 1.0}},
      "sigma_min": 1e-5,
      "text_cond_model": "openai/clip-vit-large-patch14"
    }
  }
}
configs/generation/slat_flow_txt_dit_L_64l8p2_fp16.json
ADDED

{
  "models": {
    "denoiser": {
      "name": "ElasticSLatFlowModel",
      "args": {
        "resolution": 64, "in_channels": 8, "out_channels": 8,
        "model_channels": 1024, "cond_channels": 768,
        "num_blocks": 24, "num_heads": 16, "mlp_ratio": 4, "patch_size": 2,
        "num_io_res_blocks": 2, "io_block_channels": [128],
        "pe_mode": "ape", "qk_rms_norm": true, "use_fp16": true
      }
    }
  },
  "dataset": {
    "name": "TextConditionedSLat",
    "args": {
      "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
      "min_aesthetic_score": 4.5,
      "max_num_voxels": 32768,
      "normalization": {
        "mean": [-2.1687545776367188, -0.004347046371549368, -0.13352349400520325, -0.08418072760105133,
                 -0.5271206498146057, 0.7238689064979553, -1.1414450407028198, 1.2039363384246826],
        "std": [2.377650737762451, 2.386378288269043, 2.124418020248413, 2.1748552322387695,
                2.663944721221924, 2.371192216873169, 2.6217446327209473, 2.684523105621338]
      },
      "pretrained_slat_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
    }
  },
  "trainer": {
    "name": "TextConditionedSparseFlowMatchingCFGTrainer",
    "args": {
      "max_steps": 1000000,
      "batch_size_per_gpu": 8,
      "batch_split": 4,
      "optimizer": {"name": "AdamW", "args": {"lr": 0.0001, "weight_decay": 0.0}},
      "ema_rate": [0.9999],
      "fp16_mode": "inflat_all",
      "fp16_scale_growth": 0.001,
      "elastic": {"name": "LinearMemoryController", "args": {"target_ratio": 0.75, "max_mem_ratio_start": 0.5}},
      "grad_clip": {"name": "AdaptiveGradClipper", "args": {"max_norm": 1.0, "clip_percentile": 95}},
      "i_log": 500, "i_sample": 10000, "i_save": 10000,
      "p_uncond": 0.1,
      "t_schedule": {"name": "logitNormal", "args": {"mean": 1.0, "std": 1.0}},
      "sigma_min": 1e-5,
      "text_cond_model": "openai/clip-vit-large-patch14"
    }
  }
}
configs/generation/slat_flow_txt_dit_XL_64l8p2_fp16.json
ADDED

{
  "models": {
    "denoiser": {
      "name": "ElasticSLatFlowModel",
      "args": {
        "resolution": 64, "in_channels": 8, "out_channels": 8,
        "model_channels": 1280, "cond_channels": 768,
        "num_blocks": 28, "num_heads": 16, "mlp_ratio": 4, "patch_size": 2,
        "num_io_res_blocks": 3, "io_block_channels": [256],
        "pe_mode": "ape", "qk_rms_norm": true, "use_fp16": true
      }
    }
  },
  "dataset": {
    "name": "TextConditionedSLat",
    "args": {
      "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
      "min_aesthetic_score": 4.5,
      "max_num_voxels": 32768,
      "normalization": {
        "mean": [-2.1687545776367188, -0.004347046371549368, -0.13352349400520325, -0.08418072760105133,
                 -0.5271206498146057, 0.7238689064979553, -1.1414450407028198, 1.2039363384246826],
        "std": [2.377650737762451, 2.386378288269043, 2.124418020248413, 2.1748552322387695,
                2.663944721221924, 2.371192216873169, 2.6217446327209473, 2.684523105621338]
      },
      "pretrained_slat_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
    }
  },
  "trainer": {
    "name": "TextConditionedSparseFlowMatchingCFGTrainer",
    "args": {
      "max_steps": 1000000,
      "batch_size_per_gpu": 4,
      "batch_split": 4,
      "optimizer": {"name": "AdamW", "args": {"lr": 0.0001, "weight_decay": 0.0}},
      "ema_rate": [0.9999],
      "fp16_mode": "inflat_all",
      "fp16_scale_growth": 0.001,
      "elastic": {"name": "LinearMemoryController", "args": {"target_ratio": 0.75, "max_mem_ratio_start": 0.5}},
      "grad_clip": {"name": "AdaptiveGradClipper", "args": {"max_norm": 1.0, "clip_percentile": 95}},
      "i_log": 500, "i_sample": 10000, "i_save": 10000,
      "p_uncond": 0.1,
      "t_schedule": {"name": "logitNormal", "args": {"mean": 1.0, "std": 1.0}},
      "sigma_min": 1e-5,
      "text_cond_model": "openai/clip-vit-large-patch14"
    }
  }
}
configs/generation/ss_flow_img_dit_L_16l8_fp16.json
ADDED

{
  "models": {
    "denoiser": {
      "name": "SparseStructureFlowModel",
      "args": {
        "resolution": 16, "in_channels": 8, "out_channels": 8,
        "model_channels": 1024, "cond_channels": 1024,
        "num_blocks": 24, "num_heads": 16, "mlp_ratio": 4, "patch_size": 1,
        "pe_mode": "ape", "qk_rms_norm": true, "use_fp16": true
      }
    }
  },
  "dataset": {
    "name": "ImageConditionedSparseStructureLatent",
    "args": {
      "latent_model": "ss_enc_conv3d_16l8_fp16",
      "min_aesthetic_score": 4.5,
      "image_size": 518,
      "pretrained_ss_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16"
    }
  },
  "trainer": {
    "name": "ImageConditionedFlowMatchingCFGTrainer",
    "args": {
      "max_steps": 1000000,
      "batch_size_per_gpu": 8,
      "batch_split": 1,
      "optimizer": {"name": "AdamW", "args": {"lr": 0.0001, "weight_decay": 0.0}},
      "ema_rate": [0.9999],
      "fp16_mode": "inflat_all",
      "fp16_scale_growth": 0.001,
      "grad_clip": {"name": "AdaptiveGradClipper", "args": {"max_norm": 1.0, "clip_percentile": 95}},
      "i_log": 500, "i_sample": 10000, "i_save": 10000,
      "p_uncond": 0.1,
      "t_schedule": {"name": "logitNormal", "args": {"mean": 1.0, "std": 1.0}},
      "sigma_min": 1e-5,
      "image_cond_model": "dinov2_vitl14_reg"
    }
  }
}
configs/generation/ss_flow_txt_dit_B_16l8_fp16.json
ADDED

{
  "models": {
    "denoiser": {
      "name": "SparseStructureFlowModel",
      "args": {
        "resolution": 16, "in_channels": 8, "out_channels": 8,
        "model_channels": 768, "cond_channels": 768,
        "num_blocks": 12, "num_heads": 12, "mlp_ratio": 4, "patch_size": 1,
        "pe_mode": "ape", "qk_rms_norm": true, "use_fp16": true
      }
    }
  },
  "dataset": {
    "name": "TextConditionedSparseStructureLatent",
    "args": {
      "latent_model": "ss_enc_conv3d_16l8_fp16",
      "min_aesthetic_score": 4.5,
      "pretrained_ss_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16"
    }
  },
  "trainer": {
    "name": "TextConditionedFlowMatchingCFGTrainer",
    "args": {
      "max_steps": 1000000,
      "batch_size_per_gpu": 16,
      "batch_split": 1,
      "optimizer": {"name": "AdamW", "args": {"lr": 0.0001, "weight_decay": 0.0}},
      "ema_rate": [0.9999],
      "fp16_mode": "inflat_all",
      "fp16_scale_growth": 0.001,
      "grad_clip": {"name": "AdaptiveGradClipper", "args": {"max_norm": 1.0, "clip_percentile": 95}},
      "i_log": 500, "i_sample": 10000, "i_save": 10000,
      "p_uncond": 0.1,
      "t_schedule": {"name": "logitNormal", "args": {"mean": 1.0, "std": 1.0}},
      "sigma_min": 1e-5,
      "text_cond_model": "openai/clip-vit-large-patch14"
    }
  }
}
configs/generation/ss_flow_txt_dit_L_16l8_fp16.json
ADDED
@@ -0,0 +1,69 @@
{
    "models": {
        "denoiser": {
            "name": "SparseStructureFlowModel",
            "args": {
                "resolution": 16,
                "in_channels": 8,
                "out_channels": 8,
                "model_channels": 1024,
                "cond_channels": 768,
                "num_blocks": 24,
                "num_heads": 16,
                "mlp_ratio": 4,
                "patch_size": 1,
                "pe_mode": "ape",
                "qk_rms_norm": true,
                "use_fp16": true
            }
        }
    },
    "dataset": {
        "name": "TextConditionedSparseStructureLatent",
        "args": {
            "latent_model": "ss_enc_conv3d_16l8_fp16",
            "min_aesthetic_score": 4.5,
            "pretrained_ss_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16"
        }
    },
    "trainer": {
        "name": "TextConditionedFlowMatchingCFGTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 8,
            "batch_split": 1,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 0.0001,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [
                0.9999
            ],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "p_uncond": 0.1,
            "t_schedule": {
                "name": "logitNormal",
                "args": {
                    "mean": 1.0,
                    "std": 1.0
                }
            },
            "sigma_min": 1e-5,
            "text_cond_model": "openai/clip-vit-large-patch14"
        }
    }
}
configs/generation/ss_flow_txt_dit_XL_16l8_fp16.json
ADDED
@@ -0,0 +1,70 @@
{
    "models": {
        "denoiser": {
            "name": "SparseStructureFlowModel",
            "args": {
                "resolution": 16,
                "in_channels": 8,
                "out_channels": 8,
                "model_channels": 1280,
                "cond_channels": 768,
                "num_blocks": 28,
                "num_heads": 16,
                "mlp_ratio": 4,
                "patch_size": 1,
                "pe_mode": "ape",
                "qk_rms_norm": true,
                "qk_rms_norm_cross": true,
                "use_fp16": true
            }
        }
    },
    "dataset": {
        "name": "TextConditionedSparseStructureLatent",
        "args": {
            "latent_model": "ss_enc_conv3d_16l8_fp16",
            "min_aesthetic_score": 4.5,
            "pretrained_ss_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16"
        }
    },
    "trainer": {
        "name": "TextConditionedFlowMatchingCFGTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 4,
            "batch_split": 1,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 0.0001,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [
                0.9999
            ],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "p_uncond": 0.1,
            "t_schedule": {
                "name": "logitNormal",
                "args": {
                    "mean": 1.0,
                    "std": 1.0
                }
            },
            "sigma_min": 1e-5,
            "text_cond_model": "openai/clip-vit-large-patch14"
        }
    }
}
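All of these trainers reference an "AdaptiveGradClipper" with max_norm 1.0 and clip_percentile 95. One common reading of such a clipper is to track recent gradient norms and clip to the 95th percentile of that history, capped at max_norm; the sketch below illustrates that idea only and is an assumption about the behaviour, not the repository's class.

import numpy as np
import torch

class PercentileGradClipper:
    """Illustrative percentile-based gradient clipper (assumed behaviour)."""
    def __init__(self, max_norm=1.0, clip_percentile=95, history=1000):
        self.max_norm = max_norm
        self.clip_percentile = clip_percentile
        self.history = history
        self.norms = []

    def __call__(self, parameters):
        params = [p for p in parameters if p.grad is not None]
        total = torch.norm(torch.stack([p.grad.detach().norm() for p in params]))
        self.norms.append(float(total))
        self.norms = self.norms[-self.history:]
        # Clip to the running percentile, never exceeding the hard max_norm cap.
        threshold = min(self.max_norm, float(np.percentile(self.norms, self.clip_percentile)))
        torch.nn.utils.clip_grad_norm_(params, threshold)
        return threshold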
configs/vae/slat_vae_dec_mesh_swin8_B_64l8_fp16.json
ADDED
@@ -0,0 +1,73 @@
{
    "models": {
        "decoder": {
            "name": "ElasticSLatMeshDecoder",
            "args": {
                "resolution": 64,
                "model_channels": 768,
                "latent_channels": 8,
                "num_blocks": 12,
                "num_heads": 12,
                "mlp_ratio": 4,
                "attn_mode": "swin",
                "window_size": 8,
                "use_fp16": true,
                "representation_config": {
                    "use_color": true
                }
            }
        }
    },
    "dataset": {
        "name": "Slat2RenderGeo",
        "args": {
            "image_size": 512,
            "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
            "min_aesthetic_score": 4.5,
            "max_num_voxels": 32768
        }
    },
    "trainer": {
        "name": "SLatVaeMeshDecoderTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 4,
            "batch_split": 4,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 1e-4,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [
                0.9999
            ],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "elastic": {
                "name": "LinearMemoryController",
                "args": {
                    "target_ratio": 0.75,
                    "max_mem_ratio_start": 0.5
                }
            },
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "lambda_ssim": 0.2,
            "lambda_lpips": 0.2,
            "lambda_tsdf": 0.01,
            "lambda_depth": 10.0,
            "lambda_color": 0.1,
            "depth_loss_type": "smooth_l1"
        }
    }
}
configs/vae/slat_vae_dec_rf_swin8_B_64l8_fp16.json
ADDED
@@ -0,0 +1,71 @@
{
    "models": {
        "decoder": {
            "name": "ElasticSLatRadianceFieldDecoder",
            "args": {
                "resolution": 64,
                "model_channels": 768,
                "latent_channels": 8,
                "num_blocks": 12,
                "num_heads": 12,
                "mlp_ratio": 4,
                "attn_mode": "swin",
                "window_size": 8,
                "use_fp16": true,
                "representation_config": {
                    "rank": 16,
                    "dim": 8
                }
            }
        }
    },
    "dataset": {
        "name": "SLat2Render",
        "args": {
            "image_size": 512,
            "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
            "min_aesthetic_score": 4.5,
            "max_num_voxels": 32768
        }
    },
    "trainer": {
        "name": "SLatVaeRadianceFieldDecoderTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 4,
            "batch_split": 2,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 1e-4,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [
                0.9999
            ],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "elastic": {
                "name": "LinearMemoryController",
                "args": {
                    "target_ratio": 0.75,
                    "max_mem_ratio_start": 0.5
                }
            },
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "loss_type": "l1",
            "lambda_ssim": 0.2,
            "lambda_lpips": 0.2
        }
    }
}
configs/vae/slat_vae_enc_dec_gs_swin8_B_64l8_fp16.json
ADDED
@@ -0,0 +1,105 @@
{
    "models": {
        "encoder": {
            "name": "ElasticSLatEncoder",
            "args": {
                "resolution": 64,
                "in_channels": 1024,
                "model_channels": 768,
                "latent_channels": 8,
                "num_blocks": 12,
                "num_heads": 12,
                "mlp_ratio": 4,
                "attn_mode": "swin",
                "window_size": 8,
                "use_fp16": true
            }
        },
        "decoder": {
            "name": "ElasticSLatGaussianDecoder",
            "args": {
                "resolution": 64,
                "model_channels": 768,
                "latent_channels": 8,
                "num_blocks": 12,
                "num_heads": 12,
                "mlp_ratio": 4,
                "attn_mode": "swin",
                "window_size": 8,
                "use_fp16": true,
                "representation_config": {
                    "lr": {
                        "_xyz": 1.0,
                        "_features_dc": 1.0,
                        "_opacity": 1.0,
                        "_scaling": 1.0,
                        "_rotation": 0.1
                    },
                    "perturb_offset": true,
                    "voxel_size": 1.5,
                    "num_gaussians": 32,
                    "2d_filter_kernel_size": 0.1,
                    "3d_filter_kernel_size": 9e-4,
                    "scaling_bias": 4e-3,
                    "opacity_bias": 0.1,
                    "scaling_activation": "softplus"
                }
            }
        }
    },
    "dataset": {
        "name": "SparseFeat2Render",
        "args": {
            "image_size": 512,
            "model": "dinov2_vitl14_reg",
            "resolution": 64,
            "min_aesthetic_score": 4.5,
            "max_num_voxels": 32768
        }
    },
    "trainer": {
        "name": "SLatVaeGaussianTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 4,
            "batch_split": 2,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 1e-4,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [
                0.9999
            ],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "elastic": {
                "name": "LinearMemoryController",
                "args": {
                    "target_ratio": 0.75,
                    "max_mem_ratio_start": 0.5
                }
            },
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "loss_type": "l1",
            "lambda_ssim": 0.2,
            "lambda_lpips": 0.2,
            "lambda_kl": 1e-06,
            "regularizations": {
                "lambda_vol": 10000.0,
                "lambda_opacity": 0.001
            }
        }
    }
}
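For orientation, the loss weights in the Gaussian VAE trainer config above (loss_type "l1", lambda_ssim 0.2, lambda_lpips 0.2, lambda_kl 1e-06, plus volume and opacity regularizers) suggest a weighted sum of the form sketched below. Only the weights come from the config; the individual term functions are placeholders assumed to be computed elsewhere.

def total_loss(l1, ssim_loss, lpips_loss, kl, vol_reg, opacity_reg):
    # Weighted combination implied by the config values; per-pixel L1, (1 - SSIM),
    # LPIPS, KL divergence, and the Gaussian volume/opacity regularizers are
    # assumed to be provided by the caller.
    return (l1
            + 0.2 * ssim_loss
            + 0.2 * lpips_loss
            + 1e-06 * kl
            + 10000.0 * vol_reg
            + 0.001 * opacity_reg)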
configs/vae/ss_vae_conv3d_16l8_fp16.json
ADDED
@@ -0,0 +1,65 @@
{
    "models": {
        "encoder": {
            "name": "SparseStructureEncoder",
            "args": {
                "in_channels": 1,
                "latent_channels": 8,
                "num_res_blocks": 2,
                "num_res_blocks_middle": 2,
                "channels": [32, 128, 512],
                "use_fp16": true
            }
        },
        "decoder": {
            "name": "SparseStructureDecoder",
            "args": {
                "out_channels": 1,
                "latent_channels": 8,
                "num_res_blocks": 2,
                "num_res_blocks_middle": 2,
                "channels": [512, 128, 32],
                "use_fp16": true
            }
        }
    },
    "dataset": {
        "name": "SparseStructure",
        "args": {
            "resolution": 64,
            "min_aesthetic_score": 4.5
        }
    },
    "trainer": {
        "name": "SparseStructureVaeTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 4,
            "batch_split": 1,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 1e-4,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [
                0.9999
            ],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "loss_type": "dice",
            "lambda_kl": 0.001
        }
    }
}
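All of the configs above share the same models / dataset / trainer layout, so a training entry point only needs the file path to recover every hyperparameter. A hypothetical loader is sketched below; the MODEL_REGISTRY lookup mentioned in the comment is an assumption about how the "name" fields might be resolved, not code from this repository.

import json

def load_config(path: str) -> dict:
    # Configs are plain JSON; no custom parsing is required.
    with open(path) as f:
        return json.load(f)

cfg = load_config("configs/vae/ss_vae_conv3d_16l8_fp16.json")
print(cfg["trainer"]["name"])               # "SparseStructureVaeTrainer"
print(cfg["trainer"]["args"]["loss_type"])  # "dice"
# A registry lookup such as MODEL_REGISTRY[cfg["models"]["encoder"]["name"]] could then
# instantiate the encoder with cfg["models"]["encoder"]["args"] (assumed pattern).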
dataset_toolkits/blender_script/io_scene_usdz.zip
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ec07ab6125fe0a021ed08c64169eceda126330401aba3d494d5203d26ac4b093
size 34685
dataset_toolkits/blender_script/render.py
ADDED
@@ -0,0 +1,528 @@
import argparse, sys, os, math, re, glob
from typing import *
import bpy
from mathutils import Vector, Matrix
import numpy as np
import json
import glob


"""=============== BLENDER ==============="""

IMPORT_FUNCTIONS: Dict[str, Callable] = {
    "obj": bpy.ops.import_scene.obj,
    "glb": bpy.ops.import_scene.gltf,
    "gltf": bpy.ops.import_scene.gltf,
    "usd": bpy.ops.import_scene.usd,
    "fbx": bpy.ops.import_scene.fbx,
    "stl": bpy.ops.import_mesh.stl,
    "usda": bpy.ops.import_scene.usda,
    "dae": bpy.ops.wm.collada_import,
    "ply": bpy.ops.import_mesh.ply,
    "abc": bpy.ops.wm.alembic_import,
    "blend": bpy.ops.wm.append,
}

EXT = {
    'PNG': 'png',
    'JPEG': 'jpg',
    'OPEN_EXR': 'exr',
    'TIFF': 'tiff',
    'BMP': 'bmp',
    'HDR': 'hdr',
    'TARGA': 'tga'
}

def init_render(engine='CYCLES', resolution=512, geo_mode=False):
    bpy.context.scene.render.engine = engine
    bpy.context.scene.render.resolution_x = resolution
    bpy.context.scene.render.resolution_y = resolution
    bpy.context.scene.render.resolution_percentage = 100
    bpy.context.scene.render.image_settings.file_format = 'PNG'
    bpy.context.scene.render.image_settings.color_mode = 'RGBA'
    bpy.context.scene.render.film_transparent = True

    bpy.context.scene.cycles.device = 'GPU'
    bpy.context.scene.cycles.samples = 128 if not geo_mode else 1
    bpy.context.scene.cycles.filter_type = 'BOX'
    bpy.context.scene.cycles.filter_width = 1
    bpy.context.scene.cycles.diffuse_bounces = 1
    bpy.context.scene.cycles.glossy_bounces = 1
    bpy.context.scene.cycles.transparent_max_bounces = 3 if not geo_mode else 0
    bpy.context.scene.cycles.transmission_bounces = 3 if not geo_mode else 1
    bpy.context.scene.cycles.use_denoising = True

    bpy.context.preferences.addons['cycles'].preferences.get_devices()
    bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'

def init_nodes(save_depth=False, save_normal=False, save_albedo=False, save_mist=False):
    if not any([save_depth, save_normal, save_albedo, save_mist]):
        return {}, {}
    outputs = {}
    spec_nodes = {}

    bpy.context.scene.use_nodes = True
    bpy.context.scene.view_layers['View Layer'].use_pass_z = save_depth
    bpy.context.scene.view_layers['View Layer'].use_pass_normal = save_normal
    bpy.context.scene.view_layers['View Layer'].use_pass_diffuse_color = save_albedo
    bpy.context.scene.view_layers['View Layer'].use_pass_mist = save_mist

    nodes = bpy.context.scene.node_tree.nodes
    links = bpy.context.scene.node_tree.links
    for n in nodes:
        nodes.remove(n)

    render_layers = nodes.new('CompositorNodeRLayers')

    if save_depth:
        depth_file_output = nodes.new('CompositorNodeOutputFile')
        depth_file_output.base_path = ''
        depth_file_output.file_slots[0].use_node_format = True
        depth_file_output.format.file_format = 'PNG'
        depth_file_output.format.color_depth = '16'
        depth_file_output.format.color_mode = 'BW'
        # Remap to 0-1
        map = nodes.new(type="CompositorNodeMapRange")
        map.inputs[1].default_value = 0  # (min value you will be getting)
        map.inputs[2].default_value = 10  # (max value you will be getting)
        map.inputs[3].default_value = 0  # (min value you will map to)
        map.inputs[4].default_value = 1  # (max value you will map to)

        links.new(render_layers.outputs['Depth'], map.inputs[0])
        links.new(map.outputs[0], depth_file_output.inputs[0])

        outputs['depth'] = depth_file_output
        spec_nodes['depth_map'] = map

    if save_normal:
        normal_file_output = nodes.new('CompositorNodeOutputFile')
        normal_file_output.base_path = ''
        normal_file_output.file_slots[0].use_node_format = True
        normal_file_output.format.file_format = 'OPEN_EXR'
        normal_file_output.format.color_mode = 'RGB'
        normal_file_output.format.color_depth = '16'

        links.new(render_layers.outputs['Normal'], normal_file_output.inputs[0])

        outputs['normal'] = normal_file_output

    if save_albedo:
        albedo_file_output = nodes.new('CompositorNodeOutputFile')
        albedo_file_output.base_path = ''
        albedo_file_output.file_slots[0].use_node_format = True
        albedo_file_output.format.file_format = 'PNG'
        albedo_file_output.format.color_mode = 'RGBA'
        albedo_file_output.format.color_depth = '8'

        alpha_albedo = nodes.new('CompositorNodeSetAlpha')

        links.new(render_layers.outputs['DiffCol'], alpha_albedo.inputs['Image'])
        links.new(render_layers.outputs['Alpha'], alpha_albedo.inputs['Alpha'])
        links.new(alpha_albedo.outputs['Image'], albedo_file_output.inputs[0])

        outputs['albedo'] = albedo_file_output

    if save_mist:
        bpy.data.worlds['World'].mist_settings.start = 0
        bpy.data.worlds['World'].mist_settings.depth = 10

        mist_file_output = nodes.new('CompositorNodeOutputFile')
        mist_file_output.base_path = ''
        mist_file_output.file_slots[0].use_node_format = True
        mist_file_output.format.file_format = 'PNG'
        mist_file_output.format.color_mode = 'BW'
        mist_file_output.format.color_depth = '16'

        links.new(render_layers.outputs['Mist'], mist_file_output.inputs[0])

        outputs['mist'] = mist_file_output

    return outputs, spec_nodes

def init_scene() -> None:
    """Resets the scene to a clean state.

    Returns:
        None
    """
    # delete everything
    for obj in bpy.data.objects:
        bpy.data.objects.remove(obj, do_unlink=True)

    # delete all the materials
    for material in bpy.data.materials:
        bpy.data.materials.remove(material, do_unlink=True)

    # delete all the textures
    for texture in bpy.data.textures:
        bpy.data.textures.remove(texture, do_unlink=True)

    # delete all the images
    for image in bpy.data.images:
        bpy.data.images.remove(image, do_unlink=True)

def init_camera():
    cam = bpy.data.objects.new('Camera', bpy.data.cameras.new('Camera'))
    bpy.context.collection.objects.link(cam)
    bpy.context.scene.camera = cam
    cam.data.sensor_height = cam.data.sensor_width = 32
    cam_constraint = cam.constraints.new(type='TRACK_TO')
    cam_constraint.track_axis = 'TRACK_NEGATIVE_Z'
    cam_constraint.up_axis = 'UP_Y'
    cam_empty = bpy.data.objects.new("Empty", None)
    cam_empty.location = (0, 0, 0)
    bpy.context.scene.collection.objects.link(cam_empty)
    cam_constraint.target = cam_empty
    return cam

def init_lighting():
    # Clear existing lights
    bpy.ops.object.select_all(action="DESELECT")
    bpy.ops.object.select_by_type(type="LIGHT")
    bpy.ops.object.delete()

    # Create key light
    default_light = bpy.data.objects.new("Default_Light", bpy.data.lights.new("Default_Light", type="POINT"))
    bpy.context.collection.objects.link(default_light)
    default_light.data.energy = 1000
    default_light.location = (4, 1, 6)
    default_light.rotation_euler = (0, 0, 0)

    # create top light
    top_light = bpy.data.objects.new("Top_Light", bpy.data.lights.new("Top_Light", type="AREA"))
    bpy.context.collection.objects.link(top_light)
    top_light.data.energy = 10000
    top_light.location = (0, 0, 10)
    top_light.scale = (100, 100, 100)

    # create bottom light
    bottom_light = bpy.data.objects.new("Bottom_Light", bpy.data.lights.new("Bottom_Light", type="AREA"))
    bpy.context.collection.objects.link(bottom_light)
    bottom_light.data.energy = 1000
    bottom_light.location = (0, 0, -10)
    bottom_light.rotation_euler = (0, 0, 0)

    return {
        "default_light": default_light,
        "top_light": top_light,
        "bottom_light": bottom_light
    }


def load_object(object_path: str) -> None:
    """Loads a model with a supported file extension into the scene.

    Args:
        object_path (str): Path to the model file.

    Raises:
        ValueError: If the file extension is not supported.

    Returns:
        None
    """
    file_extension = object_path.split(".")[-1].lower()
    if file_extension is None:
        raise ValueError(f"Unsupported file type: {object_path}")

    if file_extension == "usdz":
        # install usdz io package
        dirname = os.path.dirname(os.path.realpath(__file__))
        usdz_package = os.path.join(dirname, "io_scene_usdz.zip")
        bpy.ops.preferences.addon_install(filepath=usdz_package)
        # enable it
        addon_name = "io_scene_usdz"
        bpy.ops.preferences.addon_enable(module=addon_name)
        # import the usdz
        from io_scene_usdz.import_usdz import import_usdz

        import_usdz(bpy.context, filepath=object_path, materials=True, animations=True)
        return None

    # load from existing import functions
    import_function = IMPORT_FUNCTIONS[file_extension]

    print(f"Loading object from {object_path}")
    if file_extension == "blend":
        import_function(directory=object_path, link=False)
    elif file_extension in {"glb", "gltf"}:
        import_function(filepath=object_path, merge_vertices=True, import_shading='NORMALS')
    else:
        import_function(filepath=object_path)

def delete_invisible_objects() -> None:
    """Deletes all invisible objects in the scene.

    Returns:
        None
    """
    # bpy.ops.object.mode_set(mode="OBJECT")
    bpy.ops.object.select_all(action="DESELECT")
    for obj in bpy.context.scene.objects:
        if obj.hide_viewport or obj.hide_render:
            obj.hide_viewport = False
            obj.hide_render = False
            obj.hide_select = False
            obj.select_set(True)
    bpy.ops.object.delete()

    # Delete invisible collections
    invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
    for col in invisible_collections:
        bpy.data.collections.remove(col)

def split_mesh_normal():
    bpy.ops.object.select_all(action="DESELECT")
    objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
    bpy.context.view_layer.objects.active = objs[0]
    for obj in objs:
        obj.select_set(True)
    bpy.ops.object.mode_set(mode="EDIT")
    bpy.ops.mesh.select_all(action='SELECT')
    bpy.ops.mesh.split_normals()
    bpy.ops.object.mode_set(mode='OBJECT')
    bpy.ops.object.select_all(action="DESELECT")

def delete_custom_normals():
    for this_obj in bpy.data.objects:
        if this_obj.type == "MESH":
            bpy.context.view_layer.objects.active = this_obj
            bpy.ops.mesh.customdata_custom_splitnormals_clear()

def override_material():
    new_mat = bpy.data.materials.new(name="Override0123456789")
    new_mat.use_nodes = True
    new_mat.node_tree.nodes.clear()
    bsdf = new_mat.node_tree.nodes.new('ShaderNodeBsdfDiffuse')
    bsdf.inputs[0].default_value = (0.5, 0.5, 0.5, 1)
    bsdf.inputs[1].default_value = 1
    output = new_mat.node_tree.nodes.new('ShaderNodeOutputMaterial')
    new_mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
    bpy.context.scene.view_layers['View Layer'].material_override = new_mat

def unhide_all_objects() -> None:
    """Unhides all objects in the scene.

    Returns:
        None
    """
    for obj in bpy.context.scene.objects:
        obj.hide_set(False)

def convert_to_meshes() -> None:
    """Converts all objects in the scene to meshes.

    Returns:
        None
    """
    bpy.ops.object.select_all(action="DESELECT")
    bpy.context.view_layer.objects.active = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"][0]
    for obj in bpy.context.scene.objects:
        obj.select_set(True)
    bpy.ops.object.convert(target="MESH")

def triangulate_meshes() -> None:
    """Triangulates all meshes in the scene.

    Returns:
        None
    """
    bpy.ops.object.select_all(action="DESELECT")
    objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
    bpy.context.view_layer.objects.active = objs[0]
    for obj in objs:
        obj.select_set(True)
    bpy.ops.object.mode_set(mode="EDIT")
    bpy.ops.mesh.reveal()
    bpy.ops.mesh.select_all(action="SELECT")
    bpy.ops.mesh.quads_convert_to_tris(quad_method="BEAUTY", ngon_method="BEAUTY")
    bpy.ops.object.mode_set(mode="OBJECT")
    bpy.ops.object.select_all(action="DESELECT")

def scene_bbox() -> Tuple[Vector, Vector]:
    """Returns the bounding box of the scene.

    Taken from Shap-E rendering script
    (https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)

    Returns:
        Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
    """
    bbox_min = (math.inf,) * 3
    bbox_max = (-math.inf,) * 3
    found = False
    scene_meshes = [obj for obj in bpy.context.scene.objects.values() if isinstance(obj.data, bpy.types.Mesh)]
    for obj in scene_meshes:
        found = True
        for coord in obj.bound_box:
            coord = Vector(coord)
            coord = obj.matrix_world @ coord
            bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
            bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
    if not found:
        raise RuntimeError("no objects in scene to compute bounding box for")
    return Vector(bbox_min), Vector(bbox_max)

def normalize_scene() -> Tuple[float, Vector]:
    """Normalizes the scene by scaling and translating it to fit in a unit cube centered
    at the origin.

    Mostly taken from the Point-E / Shap-E rendering script
    (https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
    but fix for multiple root objects: (see bug report here:
    https://github.com/openai/shap-e/pull/60).

    Returns:
        Tuple[float, Vector]: The scale factor and the offset applied to the scene.
    """
    scene_root_objects = [obj for obj in bpy.context.scene.objects.values() if not obj.parent]
    if len(scene_root_objects) > 1:
        # create an empty object to be used as a parent for all root objects
        scene = bpy.data.objects.new("ParentEmpty", None)
        bpy.context.scene.collection.objects.link(scene)

        # parent all root objects to the empty object
        for obj in scene_root_objects:
            obj.parent = scene
    else:
        scene = scene_root_objects[0]

    bbox_min, bbox_max = scene_bbox()
    scale = 1 / max(bbox_max - bbox_min)
    scene.scale = scene.scale * scale

    # Apply scale to matrix_world.
    bpy.context.view_layer.update()
    bbox_min, bbox_max = scene_bbox()
    offset = -(bbox_min + bbox_max) / 2
    scene.matrix_world.translation += offset
    bpy.ops.object.select_all(action="DESELECT")

    return scale, offset

def get_transform_matrix(obj: bpy.types.Object) -> list:
    pos, rt, _ = obj.matrix_world.decompose()
    rt = rt.to_matrix()
    matrix = []
    for ii in range(3):
        a = []
        for jj in range(3):
            a.append(rt[ii][jj])
        a.append(pos[ii])
        matrix.append(a)
    matrix.append([0, 0, 0, 1])
    return matrix

def main(arg):
    os.makedirs(arg.output_folder, exist_ok=True)

    # Initialize context
    init_render(engine=arg.engine, resolution=arg.resolution, geo_mode=arg.geo_mode)
    outputs, spec_nodes = init_nodes(
        save_depth=arg.save_depth,
        save_normal=arg.save_normal,
        save_albedo=arg.save_albedo,
        save_mist=arg.save_mist
    )
    if arg.object.endswith(".blend"):
        delete_invisible_objects()
    else:
        init_scene()
        load_object(arg.object)
        if arg.split_normal:
            split_mesh_normal()
        # delete_custom_normals()
    print('[INFO] Scene initialized.')

    # normalize scene
    scale, offset = normalize_scene()
    print('[INFO] Scene normalized.')

    # Initialize camera and lighting
    cam = init_camera()
    init_lighting()
    print('[INFO] Camera and lighting initialized.')

    # Override material
    if arg.geo_mode:
        override_material()

    # Create a list of views
    to_export = {
        "aabb": [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        "scale": scale,
        "offset": [offset.x, offset.y, offset.z],
        "frames": []
    }
    views = json.loads(arg.views)
    for i, view in enumerate(views):
        cam.location = (
            view['radius'] * np.cos(view['yaw']) * np.cos(view['pitch']),
            view['radius'] * np.sin(view['yaw']) * np.cos(view['pitch']),
            view['radius'] * np.sin(view['pitch'])
        )
        cam.data.lens = 16 / np.tan(view['fov'] / 2)

        if arg.save_depth:
            spec_nodes['depth_map'].inputs[1].default_value = view['radius'] - 0.5 * np.sqrt(3)
            spec_nodes['depth_map'].inputs[2].default_value = view['radius'] + 0.5 * np.sqrt(3)

        bpy.context.scene.render.filepath = os.path.join(arg.output_folder, f'{i:03d}.png')
        for name, output in outputs.items():
            output.file_slots[0].path = os.path.join(arg.output_folder, f'{i:03d}_{name}')

        # Render the scene
        bpy.ops.render.render(write_still=True)
        bpy.context.view_layer.update()
        for name, output in outputs.items():
            ext = EXT[output.format.file_format]
            path = glob.glob(f'{output.file_slots[0].path}*.{ext}')[0]
            os.rename(path, f'{output.file_slots[0].path}.{ext}')

        # Save camera parameters
        metadata = {
            "file_path": f'{i:03d}.png',
            "camera_angle_x": view['fov'],
            "transform_matrix": get_transform_matrix(cam)
        }
        if arg.save_depth:
            metadata['depth'] = {
                'min': view['radius'] - 0.5 * np.sqrt(3),
                'max': view['radius'] + 0.5 * np.sqrt(3)
            }
        to_export["frames"].append(metadata)

    # Save the camera parameters
    with open(os.path.join(arg.output_folder, 'transforms.json'), 'w') as f:
        json.dump(to_export, f, indent=4)

    if arg.save_mesh:
        # triangulate meshes
        unhide_all_objects()
        convert_to_meshes()
        triangulate_meshes()
        print('[INFO] Meshes triangulated.')

        # export ply mesh
        bpy.ops.export_mesh.ply(filepath=os.path.join(arg.output_folder, 'mesh.ply'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Renders a given 3D model file by rotating a camera around it.')
    parser.add_argument('--views', type=str, help='JSON string of views. Contains a list of {yaw, pitch, radius, fov} object.')
    parser.add_argument('--object', type=str, help='Path to the 3D model file to be rendered.')
    parser.add_argument('--output_folder', type=str, default='/tmp', help='The path the output will be dumped to.')
    parser.add_argument('--resolution', type=int, default=512, help='Resolution of the images.')
    parser.add_argument('--engine', type=str, default='CYCLES', help='Blender internal engine for rendering. E.g. CYCLES, BLENDER_EEVEE, ...')
    parser.add_argument('--geo_mode', action='store_true', help='Geometry mode for rendering.')
    parser.add_argument('--save_depth', action='store_true', help='Save the depth maps.')
    parser.add_argument('--save_normal', action='store_true', help='Save the normal maps.')
    parser.add_argument('--save_albedo', action='store_true', help='Save the albedo maps.')
    parser.add_argument('--save_mist', action='store_true', help='Save the mist distance maps.')
    parser.add_argument('--split_normal', action='store_true', help='Split the normals of the mesh.')
    parser.add_argument('--save_mesh', action='store_true', help='Save the mesh as a .ply file.')
    argv = sys.argv[sys.argv.index("--") + 1:]
    args = parser.parse_args(argv)

    main(args)
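render.py is meant to run inside Blender (it imports bpy and only parses the arguments after "--"). A hedged example of driving it from Python is below; the Blender binary being on PATH, the single example view, and the asset/output paths are placeholders, while the flag names come from the argparse definitions above.

import json
import subprocess

views = [{"yaw": 0.0, "pitch": 0.5, "radius": 2.0, "fov": 0.7}]  # placeholder view list
cmd = [
    "blender", "-b",                                   # assumed Blender binary, headless mode
    "-P", "dataset_toolkits/blender_script/render.py",
    "--",                                              # everything after this goes to render.py
    "--views", json.dumps(views),
    "--object", "/path/to/model.glb",                  # placeholder asset path
    "--output_folder", "/tmp/render_out",              # placeholder output directory
    "--resolution", "512",
    "--save_depth", "--save_normal",
]
subprocess.run(cmd, check=True)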
dataset_toolkits/build_metadata.py
ADDED
@@ -0,0 +1,270 @@
import os
import shutil
import sys
import time
import importlib
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
import utils3d

def get_first_directory(path):
    with os.scandir(path) as it:
        for entry in it:
            if entry.is_dir():
                return entry.name
    return None

def need_process(key):
    return key in opt.field or opt.field == ['all']

if __name__ == '__main__':
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--field', type=str, default='all',
                        help='Fields to process, separated by commas')
    parser.add_argument('--from_file', action='store_true',
                        help='Build metadata from file instead of from records of processings.' +
                             'Useful when some processing fail to generate records but file already exists.')
    dataset_utils.add_args(parser)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))

    os.makedirs(opt.output_dir, exist_ok=True)
    os.makedirs(os.path.join(opt.output_dir, 'merged_records'), exist_ok=True)

    opt.field = opt.field.split(',')

    timestamp = str(int(time.time()))

    # get file list
    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        print('Loading previous metadata...')
        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    else:
        metadata = dataset_utils.get_metadata(**opt)
    metadata.set_index('sha256', inplace=True)

    # merge downloaded
    df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('downloaded_') and f.endswith('.csv')]
    df_parts = []
    for f in df_files:
        try:
            df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
        except:
            pass
    if len(df_parts) > 0:
        df = pd.concat(df_parts)
        df.set_index('sha256', inplace=True)
        if 'local_path' in metadata.columns:
            metadata.update(df, overwrite=True)
        else:
            metadata = metadata.join(df, on='sha256', how='left')
        for f in df_files:
            shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # detect models
    image_models = []
    if os.path.exists(os.path.join(opt.output_dir, 'features')):
        image_models = os.listdir(os.path.join(opt.output_dir, 'features'))
    latent_models = []
    if os.path.exists(os.path.join(opt.output_dir, 'latents')):
        latent_models = os.listdir(os.path.join(opt.output_dir, 'latents'))
    ss_latent_models = []
    if os.path.exists(os.path.join(opt.output_dir, 'ss_latents')):
        ss_latent_models = os.listdir(os.path.join(opt.output_dir, 'ss_latents'))
    print(f'Image models: {image_models}')
    print(f'Latent models: {latent_models}')
    print(f'Sparse Structure latent models: {ss_latent_models}')

    if 'rendered' not in metadata.columns:
        metadata['rendered'] = [False] * len(metadata)
    if 'voxelized' not in metadata.columns:
        metadata['voxelized'] = [False] * len(metadata)
    if 'num_voxels' not in metadata.columns:
        metadata['num_voxels'] = [0] * len(metadata)
    if 'cond_rendered' not in metadata.columns:
        metadata['cond_rendered'] = [False] * len(metadata)
    for model in image_models:
        if f'feature_{model}' not in metadata.columns:
            metadata[f'feature_{model}'] = [False] * len(metadata)
    for model in latent_models:
        if f'latent_{model}' not in metadata.columns:
            metadata[f'latent_{model}'] = [False] * len(metadata)
    for model in ss_latent_models:
        if f'ss_latent_{model}' not in metadata.columns:
            metadata[f'ss_latent_{model}'] = [False] * len(metadata)

    # merge rendered
    df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('rendered_') and f.endswith('.csv')]
    df_parts = []
    for f in df_files:
        try:
            df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
        except:
            pass
    if len(df_parts) > 0:
        df = pd.concat(df_parts)
        df.set_index('sha256', inplace=True)
        metadata.update(df, overwrite=True)
        for f in df_files:
            shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # merge voxelized
    df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('voxelized_') and f.endswith('.csv')]
    df_parts = []
    for f in df_files:
        try:
            df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
        except:
            pass
    if len(df_parts) > 0:
        df = pd.concat(df_parts)
        df.set_index('sha256', inplace=True)
        metadata.update(df, overwrite=True)
        for f in df_files:
            shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # merge cond_rendered
    df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('cond_rendered_') and f.endswith('.csv')]
    df_parts = []
    for f in df_files:
        try:
            df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
        except:
            pass
    if len(df_parts) > 0:
        df = pd.concat(df_parts)
        df.set_index('sha256', inplace=True)
        metadata.update(df, overwrite=True)
        for f in df_files:
            shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # merge features
    for model in image_models:
        df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'feature_{model}_') and f.endswith('.csv')]
        df_parts = []
        for f in df_files:
            try:
                df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
            except:
                pass
        if len(df_parts) > 0:
            df = pd.concat(df_parts)
            df.set_index('sha256', inplace=True)
            metadata.update(df, overwrite=True)
            for f in df_files:
                shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # merge latents
    for model in latent_models:
        df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'latent_{model}_') and f.endswith('.csv')]
        df_parts = []
        for f in df_files:
            try:
                df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
            except:
                pass
        if len(df_parts) > 0:
            df = pd.concat(df_parts)
            df.set_index('sha256', inplace=True)
            metadata.update(df, overwrite=True)
            for f in df_files:
                shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # merge sparse structure latents
    for model in ss_latent_models:
        df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'ss_latent_{model}_') and f.endswith('.csv')]
        df_parts = []
        for f in df_files:
            try:
                df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
            except:
                pass
        if len(df_parts) > 0:
            df = pd.concat(df_parts)
            df.set_index('sha256', inplace=True)
            metadata.update(df, overwrite=True)
            for f in df_files:
                shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # build metadata from files
    if opt.from_file:
        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
            tqdm(total=len(metadata), desc="Building metadata") as pbar:
            def worker(sha256):
                try:
                    if need_process('rendered') and metadata.loc[sha256, 'rendered'] == False and \
                        os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')):
                        metadata.loc[sha256, 'rendered'] = True
                    if need_process('voxelized') and metadata.loc[sha256, 'rendered'] == True and metadata.loc[sha256, 'voxelized'] == False and \
                        os.path.exists(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply')):
                        try:
                            pts = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
                            metadata.loc[sha256, 'voxelized'] = True
                            metadata.loc[sha256, 'num_voxels'] = len(pts)
                        except Exception as e:
                            pass
                    if need_process('cond_rendered') and metadata.loc[sha256, 'cond_rendered'] == False and \
                        os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')):
                        metadata.loc[sha256, 'cond_rendered'] = True
                    for model in image_models:
                        if need_process(f'feature_{model}') and \
                            metadata.loc[sha256, f'feature_{model}'] == False and \
                            metadata.loc[sha256, 'rendered'] == True and \
                            metadata.loc[sha256, 'voxelized'] == True and \
                            os.path.exists(os.path.join(opt.output_dir, 'features', model, f'{sha256}.npz')):
                            metadata.loc[sha256, f'feature_{model}'] = True
                    for model in latent_models:
                        if need_process(f'latent_{model}') and \
                            metadata.loc[sha256, f'latent_{model}'] == False and \
                            metadata.loc[sha256, 'rendered'] == True and \
                            metadata.loc[sha256, 'voxelized'] == True and \
                            os.path.exists(os.path.join(opt.output_dir, 'latents', model, f'{sha256}.npz')):
                            metadata.loc[sha256, f'latent_{model}'] = True
                    for model in ss_latent_models:
                        if need_process(f'ss_latent_{model}') and \
                            metadata.loc[sha256, f'ss_latent_{model}'] == False and \
                            metadata.loc[sha256, 'voxelized'] == True and \
                            os.path.exists(os.path.join(opt.output_dir, 'ss_latents', model, f'{sha256}.npz')):
                            metadata.loc[sha256, f'ss_latent_{model}'] = True
                    pbar.update()
                except Exception as e:
                    print(f'Error processing {sha256}: {e}')
                    pbar.update()

            executor.map(worker, metadata.index)
            executor.shutdown(wait=True)

    # statistics
    metadata.to_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    num_downloaded = metadata['local_path'].count() if 'local_path' in metadata.columns else 0
    with open(os.path.join(opt.output_dir, 'statistics.txt'), 'w') as f:
        f.write('Statistics:\n')
        f.write(f'  - Number of assets: {len(metadata)}\n')
        f.write(f'  - Number of assets downloaded: {num_downloaded}\n')
        f.write(f'  - Number of assets rendered: {metadata["rendered"].sum()}\n')
        f.write(f'  - Number of assets voxelized: {metadata["voxelized"].sum()}\n')
        if len(image_models) != 0:
            f.write(f'  - Number of assets with image features extracted:\n')
            for model in image_models:
                f.write(f'    - {model}: {metadata[f"feature_{model}"].sum()}\n')
        if len(latent_models) != 0:
            f.write(f'  - Number of assets with latents extracted:\n')
            for model in latent_models:
                f.write(f'    - {model}: {metadata[f"latent_{model}"].sum()}\n')
        if len(ss_latent_models) != 0:
            f.write(f'  - Number of assets with sparse structure latents extracted:\n')
            for model in ss_latent_models:
                f.write(f'    - {model}: {metadata[f"ss_latent_{model}"].sum()}\n')
        f.write(f'  - Number of assets with captions: {metadata["captions"].count()}\n')
        f.write(f'  - Number of assets with image conditions: {metadata["cond_rendered"].sum()}\n')

    with open(os.path.join(opt.output_dir, 'statistics.txt'), 'r') as f:
        print(f.read())
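build_metadata.py takes the dataset module name as its first positional argument (it is imported from datasets/<name>.py) and its flags after that. A hedged invocation from Python is shown below; the dataset name and output directory are placeholders.

import subprocess

subprocess.run(
    [
        "python", "dataset_toolkits/build_metadata.py", "ABO",  # dataset module under datasets/
        "--output_dir", "datasets/ABO",                          # placeholder output directory
        "--field", "rendered,voxelized",                         # subset of fields, or "all"
        "--from_file",                                           # rebuild flags from files already on disk
    ],
    check=True,
)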
dataset_toolkits/datasets/3D-FUTURE.py
ADDED
@@ -0,0 +1,97 @@
import os
import re
import argparse
import zipfile
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
from utils import get_file_hash


def add_args(parser: argparse.ArgumentParser):
    pass


def get_metadata(**kwargs):
    metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/3D-FUTURE.csv")
    return metadata


def download(metadata, output_dir, **kwargs):
    os.makedirs(output_dir, exist_ok=True)

    if not os.path.exists(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')):
        print("\033[93m")
        print("3D-FUTURE has to be downloaded manually")
        print(f"Please download the 3D-FUTURE-model.zip file and place it in the {output_dir}/raw directory")
        print("Visit https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future for more information")
        print("\033[0m")
        raise FileNotFoundError("3D-FUTURE-model.zip not found")

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    with zipfile.ZipFile(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')) as zip_ref:
        all_names = zip_ref.namelist()
        instances = [instance[:-1] for instance in all_names if re.match(r"^3D-FUTURE-model/[^/]+/$", instance)]
        instances = list(filter(lambda x: x in metadata.index, instances))

        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
            tqdm(total=len(instances), desc="Extracting") as pbar:
            def worker(instance: str) -> str:
                try:
                    instance_files = list(filter(lambda x: x.startswith(f"{instance}/") and not x.endswith("/"), all_names))
                    zip_ref.extractall(os.path.join(output_dir, 'raw'), members=instance_files)
                    sha256 = get_file_hash(os.path.join(output_dir, 'raw', f"{instance}/image.jpg"))
                    pbar.update()
                    return sha256
                except Exception as e:
                    pbar.update()
                    print(f"Error extracting for {instance}: {e}")
                    return None

            sha256s = executor.map(worker, instances)
            executor.shutdown(wait=True)

    for k, sha256 in zip(instances, sha256s):
        if sha256 is not None:
            if sha256 == metadata.loc[k, "sha256"]:
                downloaded[sha256] = os.path.join("raw", f"{k}/raw_model.obj")
            else:
                print(f"Error downloading {k}: sha256s do not match")

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])


def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    import os
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    # load metadata
    metadata = metadata.to_dict('records')

    # processing objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
            tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                try:
                    local_path = metadatum['local_path']
                    sha256 = metadatum['sha256']
                    file = os.path.join(output_dir, local_path)
                    record = func(file, sha256)
                    if record is not None:
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    print(f"Error processing object {sha256}: {e}")
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    return pd.DataFrame.from_records(records)
dataset_toolkits/datasets/ABO.py
ADDED
@@ -0,0 +1,96 @@
import os
import re
import argparse
import tarfile
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
from utils import get_file_hash


def add_args(parser: argparse.ArgumentParser):
    pass


def get_metadata(**kwargs):
    metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ABO.csv")
    return metadata


def download(metadata, output_dir, **kwargs):
    os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)

    if not os.path.exists(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')):
        try:
            os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)
            os.system(f"wget -O {output_dir}/raw/abo-3dmodels.tar https://amazon-berkeley-objects.s3.amazonaws.com/archives/abo-3dmodels.tar")
        except:
            print("\033[93m")
            print("Error downloading ABO dataset. Please check your internet connection and try again.")
            print(f"Or, you can manually download the abo-3dmodels.tar file and place it in the {output_dir}/raw directory")
            print("Visit https://amazon-berkeley-objects.s3.amazonaws.com/index.html for more information")
            print("\033[0m")
            raise FileNotFoundError("Error downloading ABO dataset")

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    with tarfile.open(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')) as tar:
        with ThreadPoolExecutor(max_workers=1) as executor, \
                tqdm(total=len(metadata), desc="Extracting") as pbar:
            def worker(instance: str) -> str:
                try:
                    tar.extract(f"3dmodels/original/{instance}", path=os.path.join(output_dir, 'raw'))
                    sha256 = get_file_hash(os.path.join(output_dir, 'raw/3dmodels/original', instance))
                    pbar.update()
                    return sha256
                except Exception as e:
                    pbar.update()
                    print(f"Error extracting for {instance}: {e}")
                    return None

            sha256s = executor.map(worker, metadata.index)
            executor.shutdown(wait=True)

    for k, sha256 in zip(metadata.index, sha256s):
        if sha256 is not None:
            if sha256 == metadata.loc[k, "sha256"]:
                downloaded[sha256] = os.path.join('raw/3dmodels/original', k)
            else:
                print(f"Error downloading {k}: sha256s do not match")

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])


def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    import os
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    # load metadata
    metadata = metadata.to_dict('records')

    # processing objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
                tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                try:
                    local_path = metadatum['local_path']
                    sha256 = metadatum['sha256']
                    file = os.path.join(output_dir, local_path)
                    record = func(file, sha256)
                    if record is not None:
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    print(f"Error processing object {sha256}: {e}")
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    return pd.DataFrame.from_records(records)
dataset_toolkits/datasets/HSSD.py
ADDED
@@ -0,0 +1,103 @@
import os
import re
import argparse
import tarfile
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
import huggingface_hub
from utils import get_file_hash


def add_args(parser: argparse.ArgumentParser):
    pass


def get_metadata(**kwargs):
    metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/HSSD.csv")
    return metadata


def download(metadata, output_dir, **kwargs):
    os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)

    # check login
    try:
        huggingface_hub.whoami()
    except:
        print("\033[93m")
        print("Haven't logged in to the Hugging Face Hub.")
        print("Visit https://huggingface.co/settings/tokens to get a token.")
        print("\033[0m")
        huggingface_hub.login()

    try:
        huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename="README.md", repo_type="dataset")
    except:
        print("\033[93m")
        print("Error downloading HSSD dataset.")
        print("Check if you have access to the HSSD dataset.")
        print("Visit https://huggingface.co/datasets/hssd/hssd-models for more information")
        print("\033[0m")

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
            tqdm(total=len(metadata), desc="Downloading") as pbar:
        def worker(instance: str) -> str:
            try:
                huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename=instance, repo_type="dataset", local_dir=os.path.join(output_dir, 'raw'))
                sha256 = get_file_hash(os.path.join(output_dir, 'raw', instance))
                pbar.update()
                return sha256
            except Exception as e:
                pbar.update()
                print(f"Error extracting for {instance}: {e}")
                return None

        sha256s = executor.map(worker, metadata.index)
        executor.shutdown(wait=True)

    for k, sha256 in zip(metadata.index, sha256s):
        if sha256 is not None:
            if sha256 == metadata.loc[k, "sha256"]:
                downloaded[sha256] = os.path.join('raw', k)
            else:
                print(f"Error downloading {k}: sha256s do not match")

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])


def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    import os
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    # load metadata
    metadata = metadata.to_dict('records')

    # processing objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
                tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                try:
                    local_path = metadatum['local_path']
                    sha256 = metadatum['sha256']
                    file = os.path.join(output_dir, local_path)
                    record = func(file, sha256)
                    if record is not None:
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    print(f"Error processing object {sha256}: {e}")
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    return pd.DataFrame.from_records(records)
dataset_toolkits/datasets/ObjaverseXL.py
ADDED
@@ -0,0 +1,92 @@
import os
import argparse
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
import objaverse.xl as oxl
from utils import get_file_hash


def add_args(parser: argparse.ArgumentParser):
    parser.add_argument('--source', type=str, default='sketchfab',
                        help='Data source to download annotations from (github, sketchfab)')


def get_metadata(source, **kwargs):
    if source == 'sketchfab':
        metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ObjaverseXL_sketchfab.csv")
    elif source == 'github':
        metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ObjaverseXL_github.csv")
    else:
        raise ValueError(f"Invalid source: {source}")
    return metadata


def download(metadata, output_dir, **kwargs):
    os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)

    # download annotations
    annotations = oxl.get_annotations()
    annotations = annotations[annotations['sha256'].isin(metadata['sha256'].values)]

    # download and render objects
    file_paths = oxl.download_objects(
        annotations,
        download_dir=os.path.join(output_dir, "raw"),
        save_repo_format="zip",
    )

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    for k, v in file_paths.items():
        sha256 = metadata.loc[k, "sha256"]
        downloaded[sha256] = os.path.relpath(v, output_dir)

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])


def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    import os
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm
    import tempfile
    import zipfile

    # load metadata
    metadata = metadata.to_dict('records')

    # processing objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
                tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                try:
                    local_path = metadatum['local_path']
                    sha256 = metadatum['sha256']
                    if local_path.startswith('raw/github/repos/'):
                        path_parts = local_path.split('/')
                        file_name = os.path.join(*path_parts[5:])
                        zip_file = os.path.join(output_dir, *path_parts[:5])
                        with tempfile.TemporaryDirectory() as tmp_dir:
                            with zipfile.ZipFile(zip_file, 'r') as zip_ref:
                                zip_ref.extractall(tmp_dir)
                            file = os.path.join(tmp_dir, file_name)
                            record = func(file, sha256)
                    else:
                        file = os.path.join(output_dir, local_path)
                        record = func(file, sha256)
                    if record is not None:
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    print(f"Error processing object {sha256}: {e}")
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    return pd.DataFrame.from_records(records)
dataset_toolkits/datasets/Toys4k.py
ADDED
@@ -0,0 +1,92 @@
import os
import re
import argparse
import zipfile
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
from utils import get_file_hash


def add_args(parser: argparse.ArgumentParser):
    pass


def get_metadata(**kwargs):
    metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/Toys4k.csv")
    return metadata


def download(metadata, output_dir, **kwargs):
    os.makedirs(output_dir, exist_ok=True)

    if not os.path.exists(os.path.join(output_dir, 'raw', 'toys4k_blend_files.zip')):
        print("\033[93m")
        print("Toys4k has to be downloaded manually")
        print(f"Please download the toys4k_blend_files.zip file and place it in the {output_dir}/raw directory")
        print("Visit https://github.com/rehg-lab/lowshot-shapebias/tree/main/toys4k for more information")
        print("\033[0m")
        raise FileNotFoundError("toys4k_blend_files.zip not found")

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    with zipfile.ZipFile(os.path.join(output_dir, 'raw', 'toys4k_blend_files.zip')) as zip_ref:
        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
                tqdm(total=len(metadata), desc="Extracting") as pbar:
            def worker(instance: str) -> str:
                try:
                    zip_ref.extract(os.path.join('toys4k_blend_files', instance), os.path.join(output_dir, 'raw'))
                    sha256 = get_file_hash(os.path.join(output_dir, 'raw/toys4k_blend_files', instance))
                    pbar.update()
                    return sha256
                except Exception as e:
                    pbar.update()
                    print(f"Error extracting for {instance}: {e}")
                    return None

            sha256s = executor.map(worker, metadata.index)
            executor.shutdown(wait=True)

    for k, sha256 in zip(metadata.index, sha256s):
        if sha256 is not None:
            if sha256 == metadata.loc[k, "sha256"]:
                downloaded[sha256] = os.path.join("raw/toys4k_blend_files", k)
            else:
                print(f"Error downloading {k}: sha256s do not match")

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])


def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    import os
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    # load metadata
    metadata = metadata.to_dict('records')

    # processing objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
                tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                try:
                    local_path = metadatum['local_path']
                    sha256 = metadatum['sha256']
                    file = os.path.join(output_dir, local_path)
                    record = func(file, sha256)
                    if record is not None:
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    print(f"Error processing object {sha256}: {e}")
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    return pd.DataFrame.from_records(records)
dataset_toolkits/download.py
ADDED
@@ -0,0 +1,52 @@
import os
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict

if __name__ == '__main__':
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))

    os.makedirs(opt.output_dir, exist_ok=True)

    # get file list
    if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    if opt.instances is None:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'local_path' in metadata.columns:
            metadata = metadata[metadata['local_path'].isna()]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]

    print(f'Processing {len(metadata)} objects...')

    # process objects
    downloaded = dataset_utils.download(metadata, **opt)
    downloaded.to_csv(os.path.join(opt.output_dir, f'downloaded_{opt.rank}.csv'), index=False)
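For reference, every toolkit script above and below shards work across processes with the same rank/world_size slicing of the metadata table. The following is a minimal sketch of that sharding (the DataFrame contents are made up for illustration), showing that the slices are contiguous and non-overlapping:

# Hypothetical illustration of the rank/world_size sharding used by these scripts:
# each worker rank gets a contiguous, non-overlapping slice of the metadata.
import pandas as pd

metadata = pd.DataFrame({'sha256': [f'hash_{i}' for i in range(10)]})  # made-up rows
world_size = 3

for rank in range(world_size):
    start = len(metadata) * rank // world_size
    end = len(metadata) * (rank + 1) // world_size
    shard = metadata[start:end]
    print(rank, list(shard['sha256']))
# rank 0 -> hash_0..hash_2, rank 1 -> hash_3..hash_5, rank 2 -> hash_6..hash_9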
dataset_toolkits/encode_latent.py
ADDED
@@ -0,0 +1,127 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import copy
import json
import argparse
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue

import trellis.models as models
import trellis.modules.sparse as sp


torch.set_grad_enabled(False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--feat_model', type=str, default='dinov2_vitl14_reg',
                        help='Feature model')
    parser.add_argument('--enc_pretrained', type=str, default='JeffreyXiang/TRELLIS-image-large/ckpts/slat_enc_swin8_B_64l8_fp16',
                        help='Pretrained encoder model')
    parser.add_argument('--model_root', type=str, default='results',
                        help='Root directory of models')
    parser.add_argument('--enc_model', type=str, default=None,
                        help='Encoder model. If specified, use this model instead of the pretrained model')
    parser.add_argument('--ckpt', type=str, default=None,
                        help='Checkpoint to load')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    opt = parser.parse_args()
    opt = edict(vars(opt))

    if opt.enc_model is None:
        latent_name = f'{opt.feat_model}_{opt.enc_pretrained.split("/")[-1]}'
        encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
    else:
        latent_name = f'{opt.feat_model}_{opt.enc_model}_{opt.ckpt}'
        cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
        encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
        ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
        encoder.load_state_dict(torch.load(ckpt_path), strict=False)
        encoder.eval()
        print(f'Loaded model from {ckpt_path}')

    os.makedirs(os.path.join(opt.output_dir, 'latents', latent_name), exist_ok=True)

    # get file list
    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    else:
        raise ValueError('metadata.csv not found')
    if opt.instances is not None:
        with open(opt.instances, 'r') as f:
            sha256s = [line.strip() for line in f]
        metadata = metadata[metadata['sha256'].isin(sha256s)]
    else:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        metadata = metadata[metadata[f'feature_{opt.feat_model}'] == True]
        if f'latent_{latent_name}' in metadata.columns:
            metadata = metadata[metadata[f'latent_{latent_name}'] == False]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    sha256s = list(metadata['sha256'].values)
    for sha256 in copy.copy(sha256s):
        if os.path.exists(os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')):
            records.append({'sha256': sha256, f'latent_{latent_name}': True})
            sha256s.remove(sha256)

    # encode latents
    load_queue = Queue(maxsize=4)
    try:
        with ThreadPoolExecutor(max_workers=32) as loader_executor, \
                ThreadPoolExecutor(max_workers=32) as saver_executor:
            def loader(sha256):
                try:
                    feats = np.load(os.path.join(opt.output_dir, 'features', opt.feat_model, f'{sha256}.npz'))
                    load_queue.put((sha256, feats))
                except Exception as e:
                    print(f"Error loading features for {sha256}: {e}")
            loader_executor.map(loader, sha256s)

            def saver(sha256, pack):
                save_path = os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')
                np.savez_compressed(save_path, **pack)
                records.append({'sha256': sha256, f'latent_{latent_name}': True})

            for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
                sha256, feats = load_queue.get()
                feats = sp.SparseTensor(
                    feats=torch.from_numpy(feats['patchtokens']).float(),
                    coords=torch.cat([
                        torch.zeros(feats['patchtokens'].shape[0], 1).int(),
                        torch.from_numpy(feats['indices']).int(),
                    ], dim=1),
                ).cuda()
                latent = encoder(feats, sample_posterior=False)
                assert torch.isfinite(latent.feats).all(), "Non-finite latent"
                pack = {
                    'feats': latent.feats.cpu().numpy().astype(np.float32),
                    'coords': latent.coords[:, 1:].cpu().numpy().astype(np.uint8),
                }
                saver_executor.submit(saver, sha256, pack)

            saver_executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    records = pd.DataFrame.from_records(records)
    records.to_csv(os.path.join(opt.output_dir, f'latent_{latent_name}_{opt.rank}.csv'), index=False)
dataset_toolkits/encode_ss_latent.py
ADDED
@@ -0,0 +1,128 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import copy
import json
import argparse
import torch
import numpy as np
import pandas as pd
import utils3d
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue

import trellis.models as models


torch.set_grad_enabled(False)


def get_voxels(instance):
    position = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{instance}.ply'))[0]
    coords = ((torch.tensor(position) + 0.5) * opt.resolution).int().contiguous()
    ss = torch.zeros(1, opt.resolution, opt.resolution, opt.resolution, dtype=torch.long)
    ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
    return ss


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--enc_pretrained', type=str, default='JeffreyXiang/TRELLIS-image-large/ckpts/ss_enc_conv3d_16l8_fp16',
                        help='Pretrained encoder model')
    parser.add_argument('--model_root', type=str, default='results',
                        help='Root directory of models')
    parser.add_argument('--enc_model', type=str, default=None,
                        help='Encoder model. If specified, use this model instead of the pretrained model')
    parser.add_argument('--ckpt', type=str, default=None,
                        help='Checkpoint to load')
    parser.add_argument('--resolution', type=int, default=64,
                        help='Resolution')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    opt = parser.parse_args()
    opt = edict(vars(opt))

    if opt.enc_model is None:
        latent_name = f'{opt.enc_pretrained.split("/")[-1]}'
        encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
    else:
        latent_name = f'{opt.enc_model}_{opt.ckpt}'
        cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
        encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
        ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
        encoder.load_state_dict(torch.load(ckpt_path), strict=False)
        encoder.eval()
        print(f'Loaded model from {ckpt_path}')

    os.makedirs(os.path.join(opt.output_dir, 'ss_latents', latent_name), exist_ok=True)

    # get file list
    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    else:
        raise ValueError('metadata.csv not found')
    if opt.instances is not None:
        with open(opt.instances, 'r') as f:
            instances = f.read().splitlines()
        metadata = metadata[metadata['sha256'].isin(instances)]
    else:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        metadata = metadata[metadata['voxelized'] == True]
        if f'ss_latent_{latent_name}' in metadata.columns:
            metadata = metadata[metadata[f'ss_latent_{latent_name}'] == False]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    sha256s = list(metadata['sha256'].values)
    for sha256 in copy.copy(sha256s):
        if os.path.exists(os.path.join(opt.output_dir, 'ss_latents', latent_name, f'{sha256}.npz')):
            records.append({'sha256': sha256, f'ss_latent_{latent_name}': True})
            sha256s.remove(sha256)

    # encode latents
    load_queue = Queue(maxsize=4)
    try:
        with ThreadPoolExecutor(max_workers=32) as loader_executor, \
                ThreadPoolExecutor(max_workers=32) as saver_executor:
            def loader(sha256):
                try:
                    ss = get_voxels(sha256)[None].float()
                    load_queue.put((sha256, ss))
                except Exception as e:
                    print(f"Error loading features for {sha256}: {e}")
            loader_executor.map(loader, sha256s)

            def saver(sha256, pack):
                save_path = os.path.join(opt.output_dir, 'ss_latents', latent_name, f'{sha256}.npz')
                np.savez_compressed(save_path, **pack)
                records.append({'sha256': sha256, f'ss_latent_{latent_name}': True})

            for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
                sha256, ss = load_queue.get()
                ss = ss.cuda().float()
                latent = encoder(ss, sample_posterior=False)
                assert torch.isfinite(latent).all(), "Non-finite latent"
                pack = {
                    'mean': latent[0].cpu().numpy(),
                }
                saver_executor.submit(saver, sha256, pack)

            saver_executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    records = pd.DataFrame.from_records(records)
    records.to_csv(os.path.join(opt.output_dir, f'ss_latent_{latent_name}_{opt.rank}.csv'), index=False)
dataset_toolkits/extract_feature.py
ADDED
@@ -0,0 +1,179 @@
import os
import copy
import sys
import json
import importlib
import argparse
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import utils3d
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from torchvision import transforms
from PIL import Image


torch.set_grad_enabled(False)


def get_data(frames, sha256):
    with ThreadPoolExecutor(max_workers=16) as executor:
        def worker(view):
            image_path = os.path.join(opt.output_dir, 'renders', sha256, view['file_path'])
            try:
                image = Image.open(image_path)
            except:
                print(f"Error loading image {image_path}")
                return None
            image = image.resize((518, 518), Image.Resampling.LANCZOS)
            image = np.array(image).astype(np.float32) / 255
            image = image[:, :, :3] * image[:, :, 3:]
            image = torch.from_numpy(image).permute(2, 0, 1).float()

            c2w = torch.tensor(view['transform_matrix'])
            c2w[:3, 1:3] *= -1
            extrinsics = torch.inverse(c2w)
            fov = view['camera_angle_x']
            intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))

            return {
                'image': image,
                'extrinsics': extrinsics,
                'intrinsics': intrinsics
            }

        datas = executor.map(worker, frames)
        for data in datas:
            if data is not None:
                yield data


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--model', type=str, default='dinov2_vitl14_reg',
                        help='Feature extraction model')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    opt = parser.parse_args()
    opt = edict(vars(opt))

    feature_name = opt.model
    os.makedirs(os.path.join(opt.output_dir, 'features', feature_name), exist_ok=True)

    # load model
    dinov2_model = torch.hub.load('facebookresearch/dinov2', opt.model)
    dinov2_model.eval().cuda()
    transform = transforms.Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    n_patch = 518 // 14

    # get file list
    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    else:
        raise ValueError('metadata.csv not found')
    if opt.instances is not None:
        with open(opt.instances, 'r') as f:
            instances = f.read().splitlines()
        metadata = metadata[metadata['sha256'].isin(instances)]
    else:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if f'feature_{feature_name}' in metadata.columns:
            metadata = metadata[metadata[f'feature_{feature_name}'] == False]
        metadata = metadata[metadata['voxelized'] == True]
        metadata = metadata[metadata['rendered'] == True]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    sha256s = list(metadata['sha256'].values)
    for sha256 in copy.copy(sha256s):
        if os.path.exists(os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz')):
            records.append({'sha256': sha256, f'feature_{feature_name}': True})
            sha256s.remove(sha256)

    # extract features
    load_queue = Queue(maxsize=4)
    try:
        with ThreadPoolExecutor(max_workers=8) as loader_executor, \
                ThreadPoolExecutor(max_workers=8) as saver_executor:
            def loader(sha256):
                try:
                    with open(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json'), 'r') as f:
                        metadata = json.load(f)
                    frames = metadata['frames']
                    data = []
                    for datum in get_data(frames, sha256):
                        datum['image'] = transform(datum['image'])
                        data.append(datum)
                    positions = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
                    load_queue.put((sha256, data, positions))
                except Exception as e:
                    print(f"Error loading data for {sha256}: {e}")

            loader_executor.map(loader, sha256s)

            def saver(sha256, pack, patchtokens, uv):
                pack['patchtokens'] = F.grid_sample(
                    patchtokens,
                    uv.unsqueeze(1),
                    mode='bilinear',
                    align_corners=False,
                ).squeeze(2).permute(0, 2, 1).cpu().numpy()
                pack['patchtokens'] = np.mean(pack['patchtokens'], axis=0).astype(np.float16)
                save_path = os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz')
                np.savez_compressed(save_path, **pack)
                records.append({'sha256': sha256, f'feature_{feature_name}': True})

            for _ in tqdm(range(len(sha256s)), desc="Extracting features"):
                sha256, data, positions = load_queue.get()
                positions = torch.from_numpy(positions).float().cuda()
                indices = ((positions + 0.5) * 64).long()
                assert torch.all(indices >= 0) and torch.all(indices < 64), "Some vertices are out of bounds"
                n_views = len(data)
                N = positions.shape[0]
                pack = {
                    'indices': indices.cpu().numpy().astype(np.uint8),
                }
                patchtokens_lst = []
                uv_lst = []
                for i in range(0, n_views, opt.batch_size):
                    batch_data = data[i:i+opt.batch_size]
                    bs = len(batch_data)
                    batch_images = torch.stack([d['image'] for d in batch_data]).cuda()
                    batch_extrinsics = torch.stack([d['extrinsics'] for d in batch_data]).cuda()
                    batch_intrinsics = torch.stack([d['intrinsics'] for d in batch_data]).cuda()
                    features = dinov2_model(batch_images, is_training=True)
                    uv = utils3d.torch.project_cv(positions, batch_extrinsics, batch_intrinsics)[0] * 2 - 1
                    patchtokens = features['x_prenorm'][:, dinov2_model.num_register_tokens + 1:].permute(0, 2, 1).reshape(bs, 1024, n_patch, n_patch)
                    patchtokens_lst.append(patchtokens)
                    uv_lst.append(uv)
                patchtokens = torch.cat(patchtokens_lst, dim=0)
                uv = torch.cat(uv_lst, dim=0)

                # save features
                saver_executor.submit(saver, sha256, pack, patchtokens, uv)

            saver_executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    records = pd.DataFrame.from_records(records)
    records.to_csv(os.path.join(opt.output_dir, f'feature_{feature_name}_{opt.rank}.csv'), index=False)
dataset_toolkits/render.py
ADDED
@@ -0,0 +1,121 @@
import os
import json
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
from functools import partial
from subprocess import DEVNULL, call
import numpy as np
from utils import sphere_hammersley_sequence


BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz'
BLENDER_INSTALLATION_PATH = '/tmp'
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender'

def _install_blender():
    if not os.path.exists(BLENDER_PATH):
        os.system('sudo apt-get update')
        os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6')
        os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
        os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')


def _render(file_path, sha256, output_dir, num_views):
    output_folder = os.path.join(output_dir, 'renders', sha256)

    # Build camera {yaw, pitch, radius, fov}
    yaws = []
    pitchs = []
    offset = (np.random.rand(), np.random.rand())
    for i in range(num_views):
        y, p = sphere_hammersley_sequence(i, num_views, offset)
        yaws.append(y)
        pitchs.append(p)
    radius = [2] * num_views
    fov = [40 / 180 * np.pi] * num_views
    views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)]

    args = [
        BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'),
        '--',
        '--views', json.dumps(views),
        '--object', os.path.expanduser(file_path),
        '--resolution', '512',
        '--output_folder', output_folder,
        '--engine', 'CYCLES',
        '--save_mesh',
    ]
    if file_path.endswith('.blend'):
        args.insert(1, file_path)

    call(args, stdout=DEVNULL, stderr=DEVNULL)

    if os.path.exists(os.path.join(output_folder, 'transforms.json')):
        return {'sha256': sha256, 'rendered': True}


if __name__ == '__main__':
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--num_views', type=int, default=150,
                        help='Number of views to render')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--max_workers', type=int, default=8)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))

    os.makedirs(os.path.join(opt.output_dir, 'renders'), exist_ok=True)

    # install blender
    print('Checking blender...', flush=True)
    _install_blender()

    # get file list
    if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    if opt.instances is None:
        metadata = metadata[metadata['local_path'].notna()]
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'rendered' in metadata.columns:
            metadata = metadata[metadata['rendered'] == False]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    for sha256 in copy.copy(metadata['sha256'].values):
        if os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')):
            records.append({'sha256': sha256, 'rendered': True})
            metadata = metadata[metadata['sha256'] != sha256]

    print(f'Processing {len(metadata)} objects...')

    # process objects
    func = partial(_render, output_dir=opt.output_dir, num_views=opt.num_views)
    rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects')
    rendered = pd.concat([rendered, pd.DataFrame.from_records(records)])
    rendered.to_csv(os.path.join(opt.output_dir, f'rendered_{opt.rank}.csv'), index=False)
dataset_toolkits/render_cond.py
ADDED
@@ -0,0 +1,125 @@
import os
import json
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
from functools import partial
from subprocess import DEVNULL, call
import numpy as np
from utils import sphere_hammersley_sequence


BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz'
BLENDER_INSTALLATION_PATH = '/tmp'
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender'

def _install_blender():
    if not os.path.exists(BLENDER_PATH):
        os.system('sudo apt-get update')
        os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6')
        os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
        os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')


def _render_cond(file_path, sha256, output_dir, num_views):
    output_folder = os.path.join(output_dir, 'renders_cond', sha256)

    # Build camera {yaw, pitch, radius, fov}
    yaws = []
    pitchs = []
    offset = (np.random.rand(), np.random.rand())
    for i in range(num_views):
        y, p = sphere_hammersley_sequence(i, num_views, offset)
        yaws.append(y)
        pitchs.append(p)
    fov_min, fov_max = 10, 70
    radius_min = np.sqrt(3) / 2 / np.sin(fov_max / 360 * np.pi)
    radius_max = np.sqrt(3) / 2 / np.sin(fov_min / 360 * np.pi)
    k_min = 1 / radius_max**2
    k_max = 1 / radius_min**2
    ks = np.random.uniform(k_min, k_max, (1000000,))
    radius = [1 / np.sqrt(k) for k in ks]
    fov = [2 * np.arcsin(np.sqrt(3) / 2 / r) for r in radius]
    views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)]

    args = [
        BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'),
        '--',
        '--views', json.dumps(views),
        '--object', os.path.expanduser(file_path),
        '--output_folder', os.path.expanduser(output_folder),
        '--resolution', '1024',
    ]
    if file_path.endswith('.blend'):
        args.insert(1, file_path)

    call(args, stdout=DEVNULL)

    if os.path.exists(os.path.join(output_folder, 'transforms.json')):
        return {'sha256': sha256, 'cond_rendered': True}


if __name__ == '__main__':
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--num_views', type=int, default=24,
                        help='Number of views to render')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--max_workers', type=int, default=8)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))

    os.makedirs(os.path.join(opt.output_dir, 'renders_cond'), exist_ok=True)

    # install blender
    print('Checking blender...', flush=True)
    _install_blender()

    # get file list
    if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    if opt.instances is None:
        metadata = metadata[metadata['local_path'].notna()]
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'cond_rendered' in metadata.columns:
            metadata = metadata[metadata['cond_rendered'] == False]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    for sha256 in copy.copy(metadata['sha256'].values):
        if os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')):
            records.append({'sha256': sha256, 'cond_rendered': True})
            metadata = metadata[metadata['sha256'] != sha256]

    print(f'Processing {len(metadata)} objects...')

    # process objects
    func = partial(_render_cond, output_dir=opt.output_dir, num_views=opt.num_views)
    cond_rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects')
    cond_rendered = pd.concat([cond_rendered, pd.DataFrame.from_records(records)])
    cond_rendered.to_csv(os.path.join(opt.output_dir, f'cond_rendered_{opt.rank}.csv'), index=False)
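A note on the radius/FOV coupling in _render_cond above: the two formulas are inverses of each other, chosen so a bounding sphere of radius sqrt(3)/2 (the circumsphere of a unit cube) always fills the field of view. The sketch below, with a made-up 40-degree FOV, only checks that relationship numerically:

# Sketch of the radius/FOV coupling used in _render_cond:
# radius = (sqrt(3)/2) / sin(fov/2) and, inversely, fov = 2 * arcsin(sqrt(3)/2 / radius).
import numpy as np

fov = np.deg2rad(40)                            # example field of view
radius = np.sqrt(3) / 2 / np.sin(fov / 2)        # camera distance that frames the bounding sphere
fov_back = 2 * np.arcsin(np.sqrt(3) / 2 / radius)

assert np.isclose(fov, fov_back)                 # the mapping round-trips exactly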
dataset_toolkits/setup.sh
ADDED
@@ -0,0 +1 @@
pip install pillow imageio imageio-ffmpeg tqdm easydict opencv-python-headless pandas open3d objaverse huggingface_hub
dataset_toolkits/stat_latent.py
ADDED
@@ -0,0 +1,66 @@
+import os
+import json
+import argparse
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+from easydict import EasyDict as edict
+from concurrent.futures import ThreadPoolExecutor
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--output_dir', type=str, required=True,
+                        help='Directory to save the metadata')
+    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
+                        help='Filter objects with aesthetic score lower than this value')
+    parser.add_argument('--model', type=str, default='dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16',
+                        help='Latent model to use')
+    parser.add_argument('--num_samples', type=int, default=50000,
+                        help='Number of samples to use for calculating stats')
+    opt = parser.parse_args()
+    opt = edict(vars(opt))
+
+    # get file list
+    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
+        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
+    else:
+        raise ValueError('metadata.csv not found')
+    if opt.filter_low_aesthetic_score is not None:
+        metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
+    metadata = metadata[metadata[f'latent_{opt.model}'] == True]
+    sha256s = metadata['sha256'].values
+    sha256s = np.random.choice(sha256s, min(opt.num_samples, len(sha256s)), replace=False)
+
+    # stats
+    means = []
+    mean2s = []
+    with ThreadPoolExecutor(max_workers=16) as executor, \
+        tqdm(total=len(sha256s), desc="Extracting features") as pbar:
+        def worker(sha256):
+            try:
+                feats = np.load(os.path.join(opt.output_dir, 'latents', opt.model, f'{sha256}.npz'))
+                feats = feats['feats']
+                means.append(feats.mean(axis=0))
+                mean2s.append((feats ** 2).mean(axis=0))
+                pbar.update()
+            except Exception as e:
+                print(f"Error extracting features for {sha256}: {e}")
+                pbar.update()
+
+        executor.map(worker, sha256s)
+        executor.shutdown(wait=True)
+
+    mean = np.array(means).mean(axis=0)
+    mean2 = np.array(mean2s).mean(axis=0)
+    std = np.sqrt(mean2 - mean ** 2)
+
+    print('mean:', mean)
+    print('std:', std)
+
+    with open(os.path.join(opt.output_dir, 'latents', opt.model, 'stats.json'), 'w') as f:
+        json.dump({
+            'mean': mean.tolist(),
+            'std': std.tolist(),
+        }, f, indent=4)
+
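Note: the script above accumulates per-object means of the features and of their squares, so the global standard deviation follows from Var(x) = E[x^2] - E[x]^2. A minimal NumPy sketch of that identity (toy data, not part of the toolkit; the averaging over objects is exact only when all chunks have the same size, which is why the script's statistics are an approximation):

    import numpy as np

    # Toy stand-in for per-object feature chunks; the real script loads them from .npz files.
    chunks = [np.random.randn(100, 8) for _ in range(10)]

    means = [c.mean(axis=0) for c in chunks]          # per-chunk E[x]
    mean2s = [(c ** 2).mean(axis=0) for c in chunks]  # per-chunk E[x^2]

    mean = np.mean(means, axis=0)
    std = np.sqrt(np.mean(mean2s, axis=0) - mean ** 2)

    # With equal-sized chunks this matches the direct (population) std exactly.
    all_feats = np.concatenate(chunks, axis=0)
    assert np.allclose(std, all_feats.std(axis=0), atol=1e-6)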
dataset_toolkits/utils.py
ADDED
@@ -0,0 +1,43 @@
+from typing import *
+import hashlib
+import numpy as np
+
+
+def get_file_hash(file: str) -> str:
+    sha256 = hashlib.sha256()
+    # Read the file from the path
+    with open(file, "rb") as f:
+        # Update the hash with the file content
+        for byte_block in iter(lambda: f.read(4096), b""):
+            sha256.update(byte_block)
+    return sha256.hexdigest()
+
+# ===============LOW DISCREPANCY SEQUENCES================
+
+PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
+
+def radical_inverse(base, n):
+    val = 0
+    inv_base = 1.0 / base
+    inv_base_n = inv_base
+    while n > 0:
+        digit = n % base
+        val += digit * inv_base_n
+        n //= base
+        inv_base_n *= inv_base
+    return val
+
+def halton_sequence(dim, n):
+    return [radical_inverse(PRIMES[dim], n) for dim in range(dim)]
+
+def hammersley_sequence(dim, n, num_samples):
+    return [n / num_samples] + halton_sequence(dim - 1, n)
+
+def sphere_hammersley_sequence(n, num_samples, offset=(0, 0)):
+    u, v = hammersley_sequence(2, n, num_samples)
+    u += offset[0] / num_samples
+    v += offset[1]
+    u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
+    theta = np.arccos(1 - 2 * u) - np.pi / 2
+    phi = v * 2 * np.pi
+    return [phi, theta]
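Note: sphere_hammersley_sequence returns (yaw, pitch) pairs that are roughly evenly spread over the sphere; the rendering scripts use them to place cameras around an object. A small usage sketch (the conversion to unit direction vectors is my own illustration, and the import assumes the script runs from dataset_toolkits/ so that utils.py is importable as `utils`):

    import numpy as np
    from utils import sphere_hammersley_sequence

    num_views = 150
    dirs = []
    for i in range(num_views):
        phi, theta = sphere_hammersley_sequence(i, num_views)
        # spherical -> Cartesian; theta is measured from the equator here
        dirs.append([
            np.cos(theta) * np.cos(phi),
            np.cos(theta) * np.sin(phi),
            np.sin(theta),
        ])
    dirs = np.array(dirs)  # (150, 3) unit view directions, roughly uniform on the sphere
    assert np.allclose(np.linalg.norm(dirs, axis=1), 1.0)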
dataset_toolkits/voxelize.py
ADDED
@@ -0,0 +1,86 @@
+import os
+import copy
+import sys
+import importlib
+import argparse
+import pandas as pd
+from easydict import EasyDict as edict
+from functools import partial
+import numpy as np
+import open3d as o3d
+import utils3d
+
+
+def _voxelize(file, sha256, output_dir):
+    mesh = o3d.io.read_triangle_mesh(os.path.join(output_dir, 'renders', sha256, 'mesh.ply'))
+    # clamp vertices to the range [-0.5, 0.5]
+    vertices = np.clip(np.asarray(mesh.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
+    mesh.vertices = o3d.utility.Vector3dVector(vertices)
+    voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(mesh, voxel_size=1/64, min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
+    vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()])
+    assert np.all(vertices >= 0) and np.all(vertices < 64), "Some vertices are out of bounds"
+    vertices = (vertices + 0.5) / 64 - 0.5
+    utils3d.io.write_ply(os.path.join(output_dir, 'voxels', f'{sha256}.ply'), vertices)
+    return {'sha256': sha256, 'voxelized': True, 'num_voxels': len(vertices)}
+
+
+if __name__ == '__main__':
+    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--output_dir', type=str, required=True,
+                        help='Directory to save the metadata')
+    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
+                        help='Filter objects with aesthetic score lower than this value')
+    parser.add_argument('--instances', type=str, default=None,
+                        help='Instances to process')
+    parser.add_argument('--num_views', type=int, default=150,
+                        help='Number of views to render')
+    dataset_utils.add_args(parser)
+    parser.add_argument('--rank', type=int, default=0)
+    parser.add_argument('--world_size', type=int, default=1)
+    parser.add_argument('--max_workers', type=int, default=None)
+    opt = parser.parse_args(sys.argv[2:])
+    opt = edict(vars(opt))
+
+    os.makedirs(os.path.join(opt.output_dir, 'voxels'), exist_ok=True)
+
+    # get file list
+    if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
+        raise ValueError('metadata.csv not found')
+    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
+    if opt.instances is None:
+        if opt.filter_low_aesthetic_score is not None:
+            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
+        if 'rendered' not in metadata.columns:
+            raise ValueError('metadata.csv does not have "rendered" column, please run "build_metadata.py" first')
+        metadata = metadata[metadata['rendered'] == True]
+        if 'voxelized' in metadata.columns:
+            metadata = metadata[metadata['voxelized'] == False]
+    else:
+        if os.path.exists(opt.instances):
+            with open(opt.instances, 'r') as f:
+                instances = f.read().splitlines()
+        else:
+            instances = opt.instances.split(',')
+        metadata = metadata[metadata['sha256'].isin(instances)]
+
+    start = len(metadata) * opt.rank // opt.world_size
+    end = len(metadata) * (opt.rank + 1) // opt.world_size
+    metadata = metadata[start:end]
+    records = []
+
+    # filter out objects that are already processed
+    for sha256 in copy.copy(metadata['sha256'].values):
+        if os.path.exists(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply')):
+            pts = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
+            records.append({'sha256': sha256, 'voxelized': True, 'num_voxels': len(pts)})
+            metadata = metadata[metadata['sha256'] != sha256]
+
+    print(f'Processing {len(metadata)} objects...')
+
+    # process objects
+    func = partial(_voxelize, output_dir=opt.output_dir)
+    voxelized = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Voxelizing')
+    voxelized = pd.concat([voxelized, pd.DataFrame.from_records(records)])
+    voxelized.to_csv(os.path.join(opt.output_dir, f'voxelized_{opt.rank}.csv'), index=False)
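Note: the line `(vertices + 0.5) / 64 - 0.5` maps integer grid indices 0..63 to voxel-center coordinates inside [-0.5, 0.5]. A quick NumPy check of that mapping and its inverse (illustration only, not part of the toolkit):

    import numpy as np

    idx = np.arange(64)               # voxel grid indices along one axis
    centers = (idx + 0.5) / 64 - 0.5  # voxel centers in [-0.5, 0.5]

    assert centers[0] == -0.5 + 0.5 / 64   # first center sits half a voxel inside the bound
    assert centers[-1] == 0.5 - 0.5 / 64   # last center likewise
    # The inverse mapping recovers the indices exactly.
    assert np.array_equal(np.floor((centers + 0.5) * 64).astype(int), idx)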
env.py
CHANGED
@@ -1,10 +1,10 @@
-import spaces
-import os
-import torch
-
-@spaces.GPU(duration=5)
-def check():
-    print(os.system("nvidia-smi"))
-    print(torch.version.cuda)
-
-check()
+import spaces
+import os
+import torch
+
+@spaces.GPU(duration=5)
+def check():
+    print(os.system("nvidia-smi"))
+    print(torch.version.cuda)
+
+check()
extensions/vox2seq/benchmark.py
ADDED
@@ -0,0 +1,45 @@
+import time
+import torch
+import vox2seq
+
+
+if __name__ == "__main__":
+    stats = {
+        'z_order_cuda': [],
+        'z_order_pytorch': [],
+        'hilbert_cuda': [],
+        'hilbert_pytorch': [],
+    }
+    RES = [16, 32, 64, 128, 256]
+    for res in RES:
+        coords = torch.meshgrid(torch.arange(res), torch.arange(res), torch.arange(res))
+        coords = torch.stack(coords, dim=-1).reshape(-1, 3).int().cuda()
+
+        start = time.time()
+        for _ in range(100):
+            code_z_cuda = vox2seq.encode(coords, mode='z_order').cuda()
+        torch.cuda.synchronize()
+        stats['z_order_cuda'].append((time.time() - start) / 100)
+
+        start = time.time()
+        for _ in range(100):
+            code_z_pytorch = vox2seq.pytorch.encode(coords, mode='z_order').cuda()
+        torch.cuda.synchronize()
+        stats['z_order_pytorch'].append((time.time() - start) / 100)
+
+        start = time.time()
+        for _ in range(100):
+            code_h_cuda = vox2seq.encode(coords, mode='hilbert').cuda()
+        torch.cuda.synchronize()
+        stats['hilbert_cuda'].append((time.time() - start) / 100)
+
+        start = time.time()
+        for _ in range(100):
+            code_h_pytorch = vox2seq.pytorch.encode(coords, mode='hilbert').cuda()
+        torch.cuda.synchronize()
+        stats['hilbert_pytorch'].append((time.time() - start) / 100)
+
+    print(f"{'Resolution':<12}{'Z-Order (CUDA)':<24}{'Z-Order (PyTorch)':<24}{'Hilbert (CUDA)':<24}{'Hilbert (PyTorch)':<24}")
+    for res, z_order_cuda, z_order_pytorch, hilbert_cuda, hilbert_pytorch in zip(RES, stats['z_order_cuda'], stats['z_order_pytorch'], stats['hilbert_cuda'], stats['hilbert_pytorch']):
+        print(f"{res:<12}{z_order_cuda:<24.6f}{z_order_pytorch:<24.6f}{hilbert_cuda:<24.6f}{hilbert_pytorch:<24.6f}")
+
extensions/vox2seq/setup.py
ADDED
@@ -0,0 +1,34 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact [email protected]
+#
+
+from setuptools import setup
+from torch.utils.cpp_extension import CUDAExtension, BuildExtension
+import os
+os.path.dirname(os.path.abspath(__file__))
+
+setup(
+    name="vox2seq",
+    packages=['vox2seq', 'vox2seq.pytorch'],
+    ext_modules=[
+        CUDAExtension(
+            name="vox2seq._C",
+            sources=[
+                "src/api.cu",
+                "src/z_order.cu",
+                "src/hilbert.cu",
+                "src/ext.cpp",
+            ],
+        )
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    }
+)
extensions/vox2seq/src/api.cu
ADDED
@@ -0,0 +1,92 @@
+#include <torch/extension.h>
+#include "api.h"
+#include "z_order.h"
+#include "hilbert.h"
+
+
+torch::Tensor
+z_order_encode(
+    const torch::Tensor& x,
+    const torch::Tensor& y,
+    const torch::Tensor& z
+) {
+    // Allocate output tensor
+    torch::Tensor codes = torch::empty_like(x);
+
+    // Call CUDA kernel
+    z_order_encode_cuda<<<(x.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
+        x.size(0),
+        reinterpret_cast<uint32_t*>(x.contiguous().data_ptr<int>()),
+        reinterpret_cast<uint32_t*>(y.contiguous().data_ptr<int>()),
+        reinterpret_cast<uint32_t*>(z.contiguous().data_ptr<int>()),
+        reinterpret_cast<uint32_t*>(codes.data_ptr<int>())
+    );
+
+    return codes;
+}
+
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
+z_order_decode(
+    const torch::Tensor& codes
+) {
+    // Allocate output tensors
+    torch::Tensor x = torch::empty_like(codes);
+    torch::Tensor y = torch::empty_like(codes);
+    torch::Tensor z = torch::empty_like(codes);
+
+    // Call CUDA kernel
+    z_order_decode_cuda<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
+        codes.size(0),
+        reinterpret_cast<uint32_t*>(codes.contiguous().data_ptr<int>()),
+        reinterpret_cast<uint32_t*>(x.data_ptr<int>()),
+        reinterpret_cast<uint32_t*>(y.data_ptr<int>()),
+        reinterpret_cast<uint32_t*>(z.data_ptr<int>())
+    );
+
+    return std::make_tuple(x, y, z);
+}
+
+
+torch::Tensor
+hilbert_encode(
+    const torch::Tensor& x,
+    const torch::Tensor& y,
+    const torch::Tensor& z
+) {
+    // Allocate output tensor
+    torch::Tensor codes = torch::empty_like(x);
+
+    // Call CUDA kernel
+    hilbert_encode_cuda<<<(x.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
+        x.size(0),
+        reinterpret_cast<uint32_t*>(x.contiguous().data_ptr<int>()),
+        reinterpret_cast<uint32_t*>(y.contiguous().data_ptr<int>()),
+        reinterpret_cast<uint32_t*>(z.contiguous().data_ptr<int>()),
+        reinterpret_cast<uint32_t*>(codes.data_ptr<int>())
+    );
+
+    return codes;
+}
+
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
+hilbert_decode(
+    const torch::Tensor& codes
+) {
+    // Allocate output tensors
+    torch::Tensor x = torch::empty_like(codes);
+    torch::Tensor y = torch::empty_like(codes);
+    torch::Tensor z = torch::empty_like(codes);
+
+    // Call CUDA kernel
+    hilbert_decode_cuda<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
+        codes.size(0),
+        reinterpret_cast<uint32_t*>(codes.contiguous().data_ptr<int>()),
+        reinterpret_cast<uint32_t*>(x.data_ptr<int>()),
+        reinterpret_cast<uint32_t*>(y.data_ptr<int>()),
+        reinterpret_cast<uint32_t*>(z.data_ptr<int>())
+    );
+
+    return std::make_tuple(x, y, z);
+}
extensions/vox2seq/src/api.h
ADDED
@@ -0,0 +1,76 @@
+/*
+ * Serialize a voxel grid
+ *
+ * Copyright (C) 2024, Jianfeng XIANG <[email protected]>
+ * All rights reserved.
+ *
+ * Licensed under The MIT License [see LICENSE for details]
+ *
+ * Written by Jianfeng XIANG
+ */
+
+#pragma once
+#include <torch/extension.h>
+
+
+#define BLOCK_SIZE 256
+
+
+/**
+ * Z-order encode 3D points
+ *
+ * @param x [N] tensor containing the x coordinates
+ * @param y [N] tensor containing the y coordinates
+ * @param z [N] tensor containing the z coordinates
+ *
+ * @return [N] tensor containing the z-order encoded values
+ */
+torch::Tensor
+z_order_encode(
+    const torch::Tensor& x,
+    const torch::Tensor& y,
+    const torch::Tensor& z
+);
+
+
+/**
+ * Z-order decode 3D points
+ *
+ * @param codes [N] tensor containing the z-order encoded values
+ *
+ * @return 3 tensors [N] containing the x, y, z coordinates
+ */
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
+z_order_decode(
+    const torch::Tensor& codes
+);
+
+
+/**
+ * Hilbert encode 3D points
+ *
+ * @param x [N] tensor containing the x coordinates
+ * @param y [N] tensor containing the y coordinates
+ * @param z [N] tensor containing the z coordinates
+ *
+ * @return [N] tensor containing the Hilbert encoded values
+ */
+torch::Tensor
+hilbert_encode(
+    const torch::Tensor& x,
+    const torch::Tensor& y,
+    const torch::Tensor& z
+);
+
+
+/**
+ * Hilbert decode 3D points
+ *
+ * @param codes [N] tensor containing the Hilbert encoded values
+ *
+ * @return 3 tensors [N] containing the x, y, z coordinates
+ */
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
+hilbert_decode(
+    const torch::Tensor& codes
+);
extensions/vox2seq/src/ext.cpp
ADDED
@@ -0,0 +1,10 @@
+#include <torch/extension.h>
+#include "api.h"
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("z_order_encode", &z_order_encode);
+    m.def("z_order_decode", &z_order_decode);
+    m.def("hilbert_encode", &hilbert_encode);
+    m.def("hilbert_decode", &hilbert_decode);
+}
extensions/vox2seq/src/hilbert.cu
ADDED
@@ -0,0 +1,133 @@
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <device_launch_parameters.h>
+
+#include <cooperative_groups.h>
+#include <cooperative_groups/memcpy_async.h>
+namespace cg = cooperative_groups;
+
+#include "hilbert.h"
+
+
+// Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit.
+static __device__ uint32_t expandBits(uint32_t v)
+{
+    v = (v * 0x00010001u) & 0xFF0000FFu;
+    v = (v * 0x00000101u) & 0x0F00F00Fu;
+    v = (v * 0x00000011u) & 0xC30C30C3u;
+    v = (v * 0x00000005u) & 0x49249249u;
+    return v;
+}
+
+
+// Removes 2 zeros after each bit in a 30-bit integer.
+static __device__ uint32_t extractBits(uint32_t v)
+{
+    v = v & 0x49249249;
+    v = (v ^ (v >> 2)) & 0x030C30C3u;
+    v = (v ^ (v >> 4)) & 0x0300F00Fu;
+    v = (v ^ (v >> 8)) & 0x030000FFu;
+    v = (v ^ (v >> 16)) & 0x000003FFu;
+    return v;
+}
+
+
+__global__ void hilbert_encode_cuda(
+    size_t N,
+    const uint32_t* x,
+    const uint32_t* y,
+    const uint32_t* z,
+    uint32_t* codes
+) {
+    size_t thread_id = cg::this_grid().thread_rank();
+    if (thread_id >= N) return;
+
+    uint32_t point[3] = {x[thread_id], y[thread_id], z[thread_id]};
+
+    uint32_t m = 1 << 9, q, p, t;
+
+    // Inverse undo excess work
+    q = m;
+    while (q > 1) {
+        p = q - 1;
+        for (int i = 0; i < 3; i++) {
+            if (point[i] & q) {
+                point[0] ^= p; // invert
+            } else {
+                t = (point[0] ^ point[i]) & p;
+                point[0] ^= t;
+                point[i] ^= t;
+            }
+        }
+        q >>= 1;
+    }
+
+    // Gray encode
+    for (int i = 1; i < 3; i++) {
+        point[i] ^= point[i - 1];
+    }
+    t = 0;
+    q = m;
+    while (q > 1) {
+        if (point[2] & q) {
+            t ^= q - 1;
+        }
+        q >>= 1;
+    }
+    for (int i = 0; i < 3; i++) {
+        point[i] ^= t;
+    }
+
+    // Convert to 3D Hilbert code
+    uint32_t xx = expandBits(point[0]);
+    uint32_t yy = expandBits(point[1]);
+    uint32_t zz = expandBits(point[2]);
+
+    codes[thread_id] = xx * 4 + yy * 2 + zz;
+}
+
+
+__global__ void hilbert_decode_cuda(
+    size_t N,
+    const uint32_t* codes,
+    uint32_t* x,
+    uint32_t* y,
+    uint32_t* z
+) {
+    size_t thread_id = cg::this_grid().thread_rank();
+    if (thread_id >= N) return;
+
+    uint32_t point[3];
+    point[0] = extractBits(codes[thread_id] >> 2);
+    point[1] = extractBits(codes[thread_id] >> 1);
+    point[2] = extractBits(codes[thread_id]);
+
+    uint32_t m = 2 << 9, q, p, t;
+
+    // Gray decode by H ^ (H/2)
+    t = point[2] >> 1;
+    for (int i = 2; i > 0; i--) {
+        point[i] ^= point[i - 1];
+    }
+    point[0] ^= t;
+
+    // Undo excess work
+    q = 2;
+    while (q != m) {
+        p = q - 1;
+        for (int i = 2; i >= 0; i--) {
+            if (point[i] & q) {
+                point[0] ^= p;
+            } else {
+                t = (point[0] ^ point[i]) & p;
+                point[0] ^= t;
+                point[i] ^= t;
+            }
+        }
+        q <<= 1;
+    }
+
+    x[thread_id] = point[0];
+    y[thread_id] = point[1];
+    z[thread_id] = point[2];
+}
extensions/vox2seq/src/hilbert.h
ADDED
@@ -0,0 +1,35 @@
+#pragma once
+
+/**
+ * Hilbert encode 3D points
+ *
+ * @param x [N] tensor containing the x coordinates
+ * @param y [N] tensor containing the y coordinates
+ * @param z [N] tensor containing the z coordinates
+ *
+ * @return [N] tensor containing the Hilbert encoded values
+ */
+__global__ void hilbert_encode_cuda(
+    size_t N,
+    const uint32_t* x,
+    const uint32_t* y,
+    const uint32_t* z,
+    uint32_t* codes
+);
+
+
+/**
+ * Hilbert decode 3D points
+ *
+ * @param codes [N] tensor containing the Hilbert encoded values
+ * @param x [N] tensor containing the x coordinates
+ * @param y [N] tensor containing the y coordinates
+ * @param z [N] tensor containing the z coordinates
+ */
+__global__ void hilbert_decode_cuda(
+    size_t N,
+    const uint32_t* codes,
+    uint32_t* x,
+    uint32_t* y,
+    uint32_t* z
+);
extensions/vox2seq/src/z_order.cu
ADDED
@@ -0,0 +1,66 @@
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <device_launch_parameters.h>
+
+#include <cooperative_groups.h>
+#include <cooperative_groups/memcpy_async.h>
+namespace cg = cooperative_groups;
+
+#include "z_order.h"
+
+
+// Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit.
+static __device__ uint32_t expandBits(uint32_t v)
+{
+    v = (v * 0x00010001u) & 0xFF0000FFu;
+    v = (v * 0x00000101u) & 0x0F00F00Fu;
+    v = (v * 0x00000011u) & 0xC30C30C3u;
+    v = (v * 0x00000005u) & 0x49249249u;
+    return v;
+}
+
+
+// Removes 2 zeros after each bit in a 30-bit integer.
+static __device__ uint32_t extractBits(uint32_t v)
+{
+    v = v & 0x49249249;
+    v = (v ^ (v >> 2)) & 0x030C30C3u;
+    v = (v ^ (v >> 4)) & 0x0300F00Fu;
+    v = (v ^ (v >> 8)) & 0x030000FFu;
+    v = (v ^ (v >> 16)) & 0x000003FFu;
+    return v;
+}
+
+
+__global__ void z_order_encode_cuda(
+    size_t N,
+    const uint32_t* x,
+    const uint32_t* y,
+    const uint32_t* z,
+    uint32_t* codes
+) {
+    size_t thread_id = cg::this_grid().thread_rank();
+    if (thread_id >= N) return;
+
+    uint32_t xx = expandBits(x[thread_id]);
+    uint32_t yy = expandBits(y[thread_id]);
+    uint32_t zz = expandBits(z[thread_id]);
+
+    codes[thread_id] = xx * 4 + yy * 2 + zz;
+}
+
+
+__global__ void z_order_decode_cuda(
+    size_t N,
+    const uint32_t* codes,
+    uint32_t* x,
+    uint32_t* y,
+    uint32_t* z
+) {
+    size_t thread_id = cg::this_grid().thread_rank();
+    if (thread_id >= N) return;
+
+    x[thread_id] = extractBits(codes[thread_id] >> 2);
+    y[thread_id] = extractBits(codes[thread_id] >> 1);
+    z[thread_id] = extractBits(codes[thread_id]);
+}
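Note: `expandBits` is the standard Morton-code trick that spreads the 10 bits of each coordinate two positions apart so the three coordinates can be interleaved into one 30-bit z-order code. A pure-Python sketch of the same computation (illustration only, not part of the extension; valid for coordinates below 2**10):

    def expand_bits(v: int) -> int:
        # Insert two zero bits after each of the low 10 bits of v,
        # mirroring the magic-number version used in the CUDA kernel.
        v = (v * 0x00010001) & 0xFF0000FF
        v = (v * 0x00000101) & 0x0F00F00F
        v = (v * 0x00000011) & 0xC30C30C3
        v = (v * 0x00000005) & 0x49249249
        return v

    def morton_encode(x: int, y: int, z: int) -> int:
        return expand_bits(x) * 4 + expand_bits(y) * 2 + expand_bits(z)

    # Bit i of x lands at position 3*i + 2, y at 3*i + 1, z at 3*i.
    x, y, z = 5, 3, 6
    reference = 0
    for i in range(10):
        reference |= ((x >> i) & 1) << (3 * i + 2)
        reference |= ((y >> i) & 1) << (3 * i + 1)
        reference |= ((z >> i) & 1) << (3 * i)
    assert morton_encode(x, y, z) == reference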
extensions/vox2seq/src/z_order.h
ADDED
@@ -0,0 +1,35 @@
+#pragma once
+
+/**
+ * Z-order encode 3D points
+ *
+ * @param x [N] tensor containing the x coordinates
+ * @param y [N] tensor containing the y coordinates
+ * @param z [N] tensor containing the z coordinates
+ *
+ * @return [N] tensor containing the z-order encoded values
+ */
+__global__ void z_order_encode_cuda(
+    size_t N,
+    const uint32_t* x,
+    const uint32_t* y,
+    const uint32_t* z,
+    uint32_t* codes
+);
+
+
+/**
+ * Z-order decode 3D points
+ *
+ * @param codes [N] tensor containing the z-order encoded values
+ * @param x [N] tensor containing the x coordinates
+ * @param y [N] tensor containing the y coordinates
+ * @param z [N] tensor containing the z coordinates
+ */
+__global__ void z_order_decode_cuda(
+    size_t N,
+    const uint32_t* codes,
+    uint32_t* x,
+    uint32_t* y,
+    uint32_t* z
+);
extensions/vox2seq/test.py
ADDED
@@ -0,0 +1,25 @@
+import torch
+import vox2seq
+
+
+if __name__ == "__main__":
+    RES = 256
+    coords = torch.meshgrid(torch.arange(RES), torch.arange(RES), torch.arange(RES))
+    coords = torch.stack(coords, dim=-1).reshape(-1, 3).int().cuda()
+    code_z_cuda = vox2seq.encode(coords, mode='z_order')
+    code_z_pytorch = vox2seq.pytorch.encode(coords, mode='z_order')
+    code_h_cuda = vox2seq.encode(coords, mode='hilbert')
+    code_h_pytorch = vox2seq.pytorch.encode(coords, mode='hilbert')
+    assert torch.equal(code_z_cuda, code_z_pytorch)
+    assert torch.equal(code_h_cuda, code_h_pytorch)
+
+    code = torch.arange(RES**3).int().cuda()
+    coords_z_cuda = vox2seq.decode(code, mode='z_order')
+    coords_z_pytorch = vox2seq.pytorch.decode(code, mode='z_order')
+    coords_h_cuda = vox2seq.decode(code, mode='hilbert')
+    coords_h_pytorch = vox2seq.pytorch.decode(code, mode='hilbert')
+    assert torch.equal(coords_z_cuda, coords_z_pytorch)
+    assert torch.equal(coords_h_cuda, coords_h_pytorch)
+
+    print("All tests passed.")
+
extensions/vox2seq/vox2seq/__init__.py
ADDED
@@ -0,0 +1,50 @@
+
+from typing import *
+import torch
+from . import _C
+from . import pytorch
+
+
+@torch.no_grad()
+def encode(coords: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
+    """
+    Encodes 3D coordinates into a 30-bit code.
+
+    Args:
+        coords: a tensor of shape [N, 3] containing the 3D coordinates.
+        permute: the permutation of the coordinates.
+        mode: the encoding mode to use.
+    """
+    assert coords.shape[-1] == 3 and coords.ndim == 2, "Input coordinates must be of shape [N, 3]"
+    x = coords[:, permute[0]].int()
+    y = coords[:, permute[1]].int()
+    z = coords[:, permute[2]].int()
+    if mode == 'z_order':
+        return _C.z_order_encode(x, y, z)
+    elif mode == 'hilbert':
+        return _C.hilbert_encode(x, y, z)
+    else:
+        raise ValueError(f"Unknown encoding mode: {mode}")
+
+
+@torch.no_grad()
+def decode(code: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
+    """
+    Decodes a 30-bit code into 3D coordinates.
+
+    Args:
+        code: a tensor of shape [N] containing the 30-bit code.
+        permute: the permutation of the coordinates.
+        mode: the decoding mode to use.
+    """
+    assert code.ndim == 1, "Input code must be of shape [N]"
+    if mode == 'z_order':
+        coords = _C.z_order_decode(code)
+    elif mode == 'hilbert':
+        coords = _C.hilbert_decode(code)
+    else:
+        raise ValueError(f"Unknown decoding mode: {mode}")
+    x = coords[permute.index(0)]
+    y = coords[permute.index(1)]
+    z = coords[permute.index(2)]
+    return torch.stack([x, y, z], dim=-1)
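Note: a hypothetical usage sketch of the Python API above, serializing sparse voxel coordinates along a Hilbert curve so that spatially nearby voxels become adjacent in the sequence. It assumes the compiled `_C` extension and a CUDA device are available; the round-trip check reflects the intended behavior for 10-bit coordinates:

    import torch
    import vox2seq

    coords = torch.randint(0, 64, (1000, 3), dtype=torch.int32, device='cuda')

    codes = vox2seq.encode(coords, mode='hilbert')  # (1000,) int32 curve indices
    order = torch.argsort(codes)                    # sequence order along the curve
    coords_sorted = coords[order]

    # decode() is expected to invert encode() for coordinates below 2**10.
    assert torch.equal(vox2seq.decode(codes, mode='hilbert').int(), coords)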
extensions/vox2seq/vox2seq/pytorch/__init__.py
ADDED
@@ -0,0 +1,48 @@
+import torch
+from typing import *
+
+from .default import (
+    encode,
+    decode,
+    z_order_encode,
+    z_order_decode,
+    hilbert_encode,
+    hilbert_decode,
+)
+
+
+@torch.no_grad()
+def encode(coords: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
+    """
+    Encodes 3D coordinates into a 30-bit code.
+
+    Args:
+        coords: a tensor of shape [N, 3] containing the 3D coordinates.
+        permute: the permutation of the coordinates.
+        mode: the encoding mode to use.
+    """
+    if mode == 'z_order':
+        return z_order_encode(coords[:, permute], depth=10).int()
+    elif mode == 'hilbert':
+        return hilbert_encode(coords[:, permute], depth=10).int()
+    else:
+        raise ValueError(f"Unknown encoding mode: {mode}")
+
+
+@torch.no_grad()
+def decode(code: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
+    """
+    Decodes a 30-bit code into 3D coordinates.
+
+    Args:
+        code: a tensor of shape [N] containing the 30-bit code.
+        permute: the permutation of the coordinates.
+        mode: the decoding mode to use.
+    """
+    if mode == 'z_order':
+        return z_order_decode(code, depth=10)[:, permute].float()
+    elif mode == 'hilbert':
+        return hilbert_decode(code, depth=10)[:, permute].float()
+    else:
+        raise ValueError(f"Unknown decoding mode: {mode}")
+
extensions/vox2seq/vox2seq/pytorch/default.py
ADDED
@@ -0,0 +1,59 @@
+import torch
+from .z_order import xyz2key as z_order_encode_
+from .z_order import key2xyz as z_order_decode_
+from .hilbert import encode as hilbert_encode_
+from .hilbert import decode as hilbert_decode_
+
+
+@torch.inference_mode()
+def encode(grid_coord, batch=None, depth=16, order="z"):
+    assert order in {"z", "z-trans", "hilbert", "hilbert-trans"}
+    if order == "z":
+        code = z_order_encode(grid_coord, depth=depth)
+    elif order == "z-trans":
+        code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth)
+    elif order == "hilbert":
+        code = hilbert_encode(grid_coord, depth=depth)
+    elif order == "hilbert-trans":
+        code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth)
+    else:
+        raise NotImplementedError
+    if batch is not None:
+        batch = batch.long()
+        code = batch << depth * 3 | code
+    return code
+
+
+@torch.inference_mode()
+def decode(code, depth=16, order="z"):
+    assert order in {"z", "hilbert"}
+    batch = code >> depth * 3
+    code = code & ((1 << depth * 3) - 1)
+    if order == "z":
+        grid_coord = z_order_decode(code, depth=depth)
+    elif order == "hilbert":
+        grid_coord = hilbert_decode(code, depth=depth)
+    else:
+        raise NotImplementedError
+    return grid_coord, batch
+
+
+def z_order_encode(grid_coord: torch.Tensor, depth: int = 16):
+    x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long()
+    # we block the support to batch, maintain batched code in Point class
+    code = z_order_encode_(x, y, z, b=None, depth=depth)
+    return code
+
+
+def z_order_decode(code: torch.Tensor, depth):
+    x, y, z, _ = z_order_decode_(code, depth=depth)
+    grid_coord = torch.stack([x, y, z], dim=-1)  # (N, 3)
+    return grid_coord
+
+
+def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16):
+    return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth)
+
+
+def hilbert_decode(code: torch.Tensor, depth: int = 16):
+    return hilbert_decode_(code, num_dims=3, num_bits=depth)
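Note: when `batch` is given, `encode` above packs the batch index above the 3 * depth coordinate bits so one integer carries both, and `decode` recovers them with a shift and a mask. A plain-Python sketch of that bit layout (illustration only; the values are arbitrary):

    depth = 16                 # bits per axis
    coord_bits = 3 * depth     # 48 bits for the (x, y, z) code

    batch = 5
    code = 0x123456789ABC & ((1 << coord_bits) - 1)
    packed = (batch << coord_bits) | code

    # decode() reverses the packing with a shift and a mask.
    assert packed >> coord_bits == batch
    assert packed & ((1 << coord_bits) - 1) == code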
extensions/vox2seq/vox2seq/pytorch/hilbert.py
ADDED
@@ -0,0 +1,303 @@
+"""
+Hilbert Order
+Modified from https://github.com/PrincetonLIPS/numpy-hilbert-curve
+
+Author: Xiaoyang Wu ([email protected]), Kaixin Xu
+Please cite our work if the code is helpful to you.
+"""
+
+import torch
+
+
+def right_shift(binary, k=1, axis=-1):
+    """Right shift an array of binary values.
+
+    Parameters:
+    -----------
+    binary: An ndarray of binary values.
+
+    k: The number of bits to shift. Default 1.
+
+    axis: The axis along which to shift. Default -1.
+
+    Returns:
+    --------
+    Returns an ndarray with zero prepended and the ends truncated, along
+    whatever axis was specified."""
+
+    # If we're shifting the whole thing, just return zeros.
+    if binary.shape[axis] <= k:
+        return torch.zeros_like(binary)
+
+    # Determine the padding pattern.
+    # padding = [(0,0)] * len(binary.shape)
+    # padding[axis] = (k,0)
+
+    # Determine the slicing pattern to eliminate just the last one.
+    slicing = [slice(None)] * len(binary.shape)
+    slicing[axis] = slice(None, -k)
+    shifted = torch.nn.functional.pad(
+        binary[tuple(slicing)], (k, 0), mode="constant", value=0
+    )
+
+    return shifted
+
+
+def binary2gray(binary, axis=-1):
+    """Convert an array of binary values into Gray codes.
+
+    This uses the classic X ^ (X >> 1) trick to compute the Gray code.
+
+    Parameters:
+    -----------
+    binary: An ndarray of binary values.
+
+    axis: The axis along which to compute the gray code. Default=-1.
+
+    Returns:
+    --------
+    Returns an ndarray of Gray codes.
+    """
+    shifted = right_shift(binary, axis=axis)
+
+    # Do the X ^ (X >> 1) trick.
+    gray = torch.logical_xor(binary, shifted)
+
+    return gray
+
+
+def gray2binary(gray, axis=-1):
+    """Convert an array of Gray codes back into binary values.
+
+    Parameters:
+    -----------
+    gray: An ndarray of gray codes.
+
+    axis: The axis along which to perform Gray decoding. Default=-1.
+
+    Returns:
+    --------
+    Returns an ndarray of binary values.
+    """
+
+    # Loop the log2(bits) number of times necessary, with shift and xor.
+    shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
+    while shift > 0:
+        gray = torch.logical_xor(gray, right_shift(gray, shift))
+        shift = torch.div(shift, 2, rounding_mode="floor")
+    return gray
+
+
+def encode(locs, num_dims, num_bits):
+    """Encode an array of locations in a hypercube into a Hilbert integer.
+
+    This is a vectorized-ish version of the Hilbert curve implementation by John
+    Skilling as described in:
+
+    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
+    Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
+
+    Params:
+    -------
+    locs - An ndarray of locations in a hypercube of num_dims dimensions, in
+           which each dimension runs from 0 to 2**num_bits-1. The shape can
+           be arbitrary, as long as the last dimension has size num_dims.
+
+    num_dims - The dimensionality of the hypercube. Integer.
+
+    num_bits - The number of bits for each dimension. Integer.
+
+    Returns:
+    --------
+    The output is an ndarray of uint64 integers with the same shape as the
+    input, excluding the last dimension, which needs to be num_dims.
+    """
+
+    # Keep around the original shape for later.
+    orig_shape = locs.shape
+    bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
+    bitpack_mask_rev = bitpack_mask.flip(-1)
+
+    if orig_shape[-1] != num_dims:
+        raise ValueError(
+            """
+            The shape of locs was surprising in that the last dimension was of size
+            %d, but num_dims=%d. These need to be equal.
+            """
+            % (orig_shape[-1], num_dims)
+        )
+
+    if num_dims * num_bits > 63:
+        raise ValueError(
+            """
+            num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
+            into an int64. Are you sure you need that many points on your Hilbert
+            curve?
+            """
+            % (num_dims, num_bits, num_dims * num_bits)
+        )
+
+    # Treat the location integers as 64-bit unsigned and then split them up into
+    # a sequence of uint8s. Preserve the association by dimension.
+    locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)
+
+    # Now turn these into bits and truncate to num_bits.
+    gray = (
+        locs_uint8.unsqueeze(-1)
+        .bitwise_and(bitpack_mask_rev)
+        .ne(0)
+        .byte()
+        .flatten(-2, -1)[..., -num_bits:]
+    )
+
+    # Run the decoding process the other way.
+    # Iterate forwards through the bits.
+    for bit in range(0, num_bits):
+        # Iterate forwards through the dimensions.
+        for dim in range(0, num_dims):
+            # Identify which ones have this bit active.
+            mask = gray[:, dim, bit]
+
+            # Where this bit is on, invert the 0 dimension for lower bits.
+            gray[:, 0, bit + 1 :] = torch.logical_xor(
+                gray[:, 0, bit + 1 :], mask[:, None]
+            )
+
+            # Where the bit is off, exchange the lower bits with the 0 dimension.
+            to_flip = torch.logical_and(
+                torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
+                torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
+            )
+            gray[:, dim, bit + 1 :] = torch.logical_xor(
+                gray[:, dim, bit + 1 :], to_flip
+            )
+            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
+
+    # Now flatten out.
+    gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims))
+
+    # Convert Gray back to binary.
+    hh_bin = gray2binary(gray)
+
+    # Pad back out to 64 bits.
+    extra_dims = 64 - num_bits * num_dims
+    padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)
+
+    # Convert binary values into uint8s.
+    hh_uint8 = (
+        (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask)
+        .sum(2)
+        .squeeze()
+        .type(torch.uint8)
+    )
+
+    # Convert uint8s into uint64s.
+    hh_uint64 = hh_uint8.view(torch.int64).squeeze()
+
+    return hh_uint64
+
+
+def decode(hilberts, num_dims, num_bits):
+    """Decode an array of Hilbert integers into locations in a hypercube.
+
+    This is a vectorized-ish version of the Hilbert curve implementation by John
+    Skilling as described in:
+
+    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
+    Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
+
+    Params:
+    -------
+    hilberts - An ndarray of Hilbert integers. Must be an integer dtype and
+               cannot have fewer bits than num_dims * num_bits.
+
+    num_dims - The dimensionality of the hypercube. Integer.
+
+    num_bits - The number of bits for each dimension. Integer.
+
+    Returns:
+    --------
+    The output is an ndarray of unsigned integers with the same shape as hilberts
+    but with an additional dimension of size num_dims.
+    """
+
+    if num_dims * num_bits > 64:
+        raise ValueError(
+            """
+            num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
+            into a uint64. Are you sure you need that many points on your Hilbert
+            curve?
+            """
+            % (num_dims, num_bits, num_dims * num_bits)
+        )
+
+    # Handle the case where we got handed a naked integer.
+    hilberts = torch.atleast_1d(hilberts)
+
+    # Keep around the shape for later.
+    orig_shape = hilberts.shape
+    bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device)
+    bitpack_mask_rev = bitpack_mask.flip(-1)
+
+    # Treat each of the hilberts as a sequence of eight uint8.
+    # This treats all of the inputs as uint64 and makes things uniform.
+    hh_uint8 = (
+        hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1)
+    )
+
+    # Turn these lists of uints into lists of bits and then truncate to the size
+    # we actually need for using Skilling's procedure.
+    hh_bits = (
+        hh_uint8.unsqueeze(-1)
+        .bitwise_and(bitpack_mask_rev)
+        .ne(0)
+        .byte()
+        .flatten(-2, -1)[:, -num_dims * num_bits :]
+    )
+
+    # Take the sequence of bits and Gray-code it.
+    gray = binary2gray(hh_bits)
+
+    # There has got to be a better way to do this.
+    # I could index them differently, but the eventual packbits likes it this way.
+    gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2)
+
+    # Iterate backwards through the bits.
+    for bit in range(num_bits - 1, -1, -1):
+        # Iterate backwards through the dimensions.
+        for dim in range(num_dims - 1, -1, -1):
+            # Identify which ones have this bit active.
+            mask = gray[:, dim, bit]
+
+            # Where this bit is on, invert the 0 dimension for lower bits.
+            gray[:, 0, bit + 1 :] = torch.logical_xor(
+                gray[:, 0, bit + 1 :], mask[:, None]
+            )
+
+            # Where the bit is off, exchange the lower bits with the 0 dimension.
+            to_flip = torch.logical_and(
+                torch.logical_not(mask[:, None]),
+                torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
+            )
+            gray[:, dim, bit + 1 :] = torch.logical_xor(
+                gray[:, dim, bit + 1 :], to_flip
+            )
+            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
+
+    # Pad back out to 64 bits.
+    extra_dims = 64 - num_bits
+    padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0)
+
+    # Now chop these up into blocks of 8.
+    locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8))
+
+    # Take those blocks and turn them into uint8s.
+    locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8)
+
+    # Finally, treat these as uint64s.
+    flat_locs = locs_uint8.view(torch.int64)
+
+    # Return them in the expected shape.
+    return flat_locs.reshape((*orig_shape, num_dims))