cavargas10 committed on
Commit 178f950 · verified · 1 Parent(s): 6af598e

Upload 288 files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +4 -65
  2. .gitignore +2 -0
  3. README.md +16 -12
  4. app.py +274 -232
  5. app_img.py +414 -0
  6. configs/generation/slat_flow_img_dit_L_64l8p2_fp16.json +102 -0
  7. configs/generation/slat_flow_txt_dit_B_64l8p2_fp16.json +101 -0
  8. configs/generation/slat_flow_txt_dit_L_64l8p2_fp16.json +101 -0
  9. configs/generation/slat_flow_txt_dit_XL_64l8p2_fp16.json +101 -0
  10. configs/generation/ss_flow_img_dit_L_16l8_fp16.json +70 -0
  11. configs/generation/ss_flow_txt_dit_B_16l8_fp16.json +69 -0
  12. configs/generation/ss_flow_txt_dit_L_16l8_fp16.json +69 -0
  13. configs/generation/ss_flow_txt_dit_XL_16l8_fp16.json +70 -0
  14. configs/vae/slat_vae_dec_mesh_swin8_B_64l8_fp16.json +73 -0
  15. configs/vae/slat_vae_dec_rf_swin8_B_64l8_fp16.json +71 -0
  16. configs/vae/slat_vae_enc_dec_gs_swin8_B_64l8_fp16.json +105 -0
  17. configs/vae/ss_vae_conv3d_16l8_fp16.json +65 -0
  18. dataset_toolkits/blender_script/io_scene_usdz.zip +3 -0
  19. dataset_toolkits/blender_script/render.py +528 -0
  20. dataset_toolkits/build_metadata.py +270 -0
  21. dataset_toolkits/datasets/3D-FUTURE.py +97 -0
  22. dataset_toolkits/datasets/ABO.py +96 -0
  23. dataset_toolkits/datasets/HSSD.py +103 -0
  24. dataset_toolkits/datasets/ObjaverseXL.py +92 -0
  25. dataset_toolkits/datasets/Toys4k.py +92 -0
  26. dataset_toolkits/download.py +52 -0
  27. dataset_toolkits/encode_latent.py +127 -0
  28. dataset_toolkits/encode_ss_latent.py +128 -0
  29. dataset_toolkits/extract_feature.py +179 -0
  30. dataset_toolkits/render.py +121 -0
  31. dataset_toolkits/render_cond.py +125 -0
  32. dataset_toolkits/setup.sh +1 -0
  33. dataset_toolkits/stat_latent.py +66 -0
  34. dataset_toolkits/utils.py +43 -0
  35. dataset_toolkits/voxelize.py +86 -0
  36. env.py +10 -10
  37. extensions/vox2seq/benchmark.py +45 -0
  38. extensions/vox2seq/setup.py +34 -0
  39. extensions/vox2seq/src/api.cu +92 -0
  40. extensions/vox2seq/src/api.h +76 -0
  41. extensions/vox2seq/src/ext.cpp +10 -0
  42. extensions/vox2seq/src/hilbert.cu +133 -0
  43. extensions/vox2seq/src/hilbert.h +35 -0
  44. extensions/vox2seq/src/z_order.cu +66 -0
  45. extensions/vox2seq/src/z_order.h +35 -0
  46. extensions/vox2seq/test.py +25 -0
  47. extensions/vox2seq/vox2seq/__init__.py +50 -0
  48. extensions/vox2seq/vox2seq/pytorch/__init__.py +48 -0
  49. extensions/vox2seq/vox2seq/pytorch/default.py +59 -0
  50. extensions/vox2seq/vox2seq/pytorch/hilbert.py +303 -0
.gitattributes CHANGED
@@ -34,68 +34,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
- assets/example_image/T.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_building_building.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_building_castle.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_building_colorful_cottage.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_building_maya_pyramid.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_building_mushroom.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_building_space_station.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_creature_dragon.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_creature_elephant.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_creature_furry.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_creature_quadruped.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_creature_robot_crab.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_creature_robot_dinosour.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_creature_rock_monster.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_humanoid_block_robot.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_humanoid_dragonborn.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_humanoid_dwarf.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_humanoid_goblin.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_humanoid_mech.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_misc_crate.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_misc_fireplace.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_misc_gate.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_misc_lantern.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_misc_magicbook.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_misc_mailbox.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_misc_monster_chest.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_misc_paper_machine.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_misc_phonograph.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_misc_portal2.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_misc_storage_chest.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_misc_telephone.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_misc_television.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_misc_workbench.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_vehicle_biplane.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_vehicle_bulldozer.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_vehicle_cart.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_vehicle_excavator.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_vehicle_helicopter.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_vehicle_locomotive.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/typical_vehicle_pirate_ship.png filter=lfs diff=lfs merge=lfs -text
- assets/example_image/weatherworn_misc_paper_machine3.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/character_1.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/character_2.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/character_3.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/mushroom_1.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/mushroom_2.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/mushroom_3.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/orangeguy_1.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/orangeguy_2.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/orangeguy_3.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/popmart_1.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/popmart_2.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/popmart_3.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/rabbit_1.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/rabbit_2.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/rabbit_3.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/tiger_1.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/tiger_2.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/tiger_3.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/yoimiya_1.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/yoimiya_2.png filter=lfs diff=lfs merge=lfs -text
- assets/example_multi_image/yoimiya_3.png filter=lfs diff=lfs merge=lfs -text
- assets/logo.webp filter=lfs diff=lfs merge=lfs -text
- assets/T.ply filter=lfs diff=lfs merge=lfs -text
- assets/teaser.png filter=lfs diff=lfs merge=lfs -text
+ *.ply filter=lfs diff=lfs merge=lfs -text
+ *.webp filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ .idea/*
+ __pycache__
README.md CHANGED
@@ -1,12 +1,16 @@
- ---
- title: TRELLIS - Texto a 3D
- emoji: 📚
- colorFrom: indigo
- colorTo: indigo
- sdk: gradio
- sdk_version: 5.24.0
- app_file: app.py
- pinned: false
- license: mit
- short_description: 3D Generation from text prompts
- ---
+ ---
+ title: TRELLIS Text To 3D
+ emoji: 🏢
+ colorFrom: indigo
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 5.25.0
+ app_file: app.py
+ pinned: true
+ license: mit
+ short_description: Scalable and Versatile 3D Generation from text prompt
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ Paper: https://huggingface.co/papers/2412.01506
app.py CHANGED
@@ -1,232 +1,274 @@
1
- import gradio as gr
2
- import spaces
3
- from gradio_litmodel3d import LitModel3D
4
- import os
5
- import shutil
6
- os.environ['TOKENIZERS_PARALLELISM'] = 'true'
7
- os.environ['SPCONV_ALGO'] = 'native'
8
- from typing import *
9
- import torch
10
- import numpy as np
11
- import imageio
12
- from easydict import EasyDict as edict
13
- from trellis.pipelines import TrellisTextTo3DPipeline
14
- from trellis.representations import Gaussian, MeshExtractResult
15
- from trellis.utils import render_utils, postprocessing_utils
16
-
17
- import traceback
18
- import sys
19
-
20
- MAX_SEED = np.iinfo(np.int32).max
21
- TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
22
- os.makedirs(TMP_DIR, exist_ok=True)
23
-
24
-
25
- def start_session(req: gr.Request):
26
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
27
- os.makedirs(user_dir, exist_ok=True)
28
-
29
-
30
- def end_session(req: gr.Request):
31
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
32
- shutil.rmtree(user_dir)
33
-
34
-
35
- def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
36
- return {
37
- 'gaussian': {
38
- **gs.init_params,
39
- '_xyz': gs._xyz.cpu().numpy(),
40
- '_features_dc': gs._features_dc.cpu().numpy(),
41
- '_scaling': gs._scaling.cpu().numpy(),
42
- '_rotation': gs._rotation.cpu().numpy(),
43
- '_opacity': gs._opacity.cpu().numpy(),
44
- },
45
- 'mesh': {
46
- 'vertices': mesh.vertices.cpu().numpy(),
47
- 'faces': mesh.faces.cpu().numpy(),
48
- },
49
- }
50
-
51
-
52
- def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
53
- gs = Gaussian(
54
- aabb=state['gaussian']['aabb'],
55
- sh_degree=state['gaussian']['sh_degree'],
56
- mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
57
- scaling_bias=state['gaussian']['scaling_bias'],
58
- opacity_bias=state['gaussian']['opacity_bias'],
59
- scaling_activation=state['gaussian']['scaling_activation'],
60
- )
61
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
62
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
63
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
64
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
65
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
66
-
67
- mesh = edict(
68
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
69
- faces=torch.tensor(state['mesh']['faces'], device='cuda'),
70
- )
71
-
72
- return gs, mesh
73
-
74
-
75
- def get_seed(randomize_seed: bool, seed: int) -> int:
76
- return np.random.randint(0, MAX_SEED) if randomize_seed else seed
77
-
78
- @spaces.GPU
79
- def text_to_3d(
80
- prompt: str,
81
- seed: int,
82
- ss_guidance_strength: float,
83
- ss_sampling_steps: int,
84
- slat_guidance_strength: float,
85
- slat_sampling_steps: int,
86
- req: gr.Request,
87
- ) -> Tuple[dict, str]:
88
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
89
- outputs = pipeline.run(
90
- prompt,
91
- seed=seed,
92
- formats=["gaussian", "mesh"],
93
- sparse_structure_sampler_params={
94
- "steps": ss_sampling_steps,
95
- "cfg_strength": ss_guidance_strength,
96
- },
97
- slat_sampler_params={
98
- "steps": slat_sampling_steps,
99
- "cfg_strength": slat_guidance_strength,
100
- },
101
- )
102
- video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
103
- video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
104
- video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
105
- video_path = os.path.join(user_dir, 'sample.mp4')
106
- imageio.mimsave(video_path, video, fps=15)
107
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
108
- torch.cuda.empty_cache()
109
- return state, video_path
110
-
111
- @spaces.GPU(duration=90)
112
- def extract_glb(
113
- state: dict,
114
- mesh_simplify: float,
115
- texture_size: int,
116
- req: gr.Request,
117
- ) -> Tuple[str, str]:
118
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
119
- gs, mesh = unpack_state(state)
120
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
121
- glb_path = os.path.join(user_dir, 'sample.glb')
122
- glb.export(glb_path)
123
- torch.cuda.empty_cache()
124
- return glb_path, glb_path
125
-
126
-
127
- @spaces.GPU
128
- def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
129
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
130
- gs, _ = unpack_state(state)
131
- gaussian_path = os.path.join(user_dir, 'sample.ply')
132
- gs.save_ply(gaussian_path)
133
- torch.cuda.empty_cache()
134
- return gaussian_path, gaussian_path
135
-
136
- with gr.Blocks(delete_cache=(600, 600)) as demo:
137
- gr.Markdown("""
138
- # UTPL - Conversión de Texto a objetos 3D usando IA
139
- ### Tesis: *"Objetos tridimensionales creados por IA: Innovación en entornos virtuales"*
140
- **Autor:** Carlos Vargas
141
- **Base técnica:** Adaptación de [TRELLIS](https://trellis3d.github.io/) (herramienta de código abierto para generación 3D)
142
- **Propósito educativo:** Demostraciones académicas e Investigación en modelado 3D automático
143
- """)
144
-
145
- with gr.Row():
146
- with gr.Column():
147
- text_prompt = gr.Textbox(label="Text Prompt", lines=5)
148
-
149
- with gr.Accordion(label="Generation Settings", open=False):
150
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
151
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
152
- with gr.Row():
153
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
154
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
155
- with gr.Row():
156
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
157
- slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
158
-
159
- generate_btn = gr.Button("Generate")
160
-
161
- with gr.Accordion(label="GLB Extraction Settings", open=False):
162
- mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
163
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
164
-
165
- with gr.Row():
166
- extract_glb_btn = gr.Button("Extract GLB", interactive=False)
167
-
168
- with gr.Column(scale=3, min_width=600):
169
- with gr.Group():
170
- video_output = gr.Video(
171
- label="3D Preview",
172
- autoplay=True,
173
- loop=True,
174
- height=300,
175
- show_label=False
176
- )
177
- model_output = gr.Model3D(
178
- label="3D Model Viewer",
179
- height=400
180
- )
181
-
182
- with gr.Row():
183
- download_glb = gr.DownloadButton(
184
- label="Download GLB File",
185
- interactive=False,
186
- variant="secondary",
187
- size="lg"
188
- )
189
-
190
- output_buf = gr.State()
191
-
192
- # Handlers
193
- demo.load(start_session)
194
- demo.unload(end_session)
195
-
196
- generate_btn.click(
197
- get_seed,
198
- inputs=[randomize_seed, seed],
199
- outputs=[seed],
200
- ).then(
201
- text_to_3d,
202
- inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
203
- outputs=[output_buf, video_output],
204
- ).then(
205
- lambda: gr.Button(interactive=True),
206
- outputs=[extract_glb_btn],
207
- )
208
-
209
- video_output.clear(
210
- lambda: gr.Button(interactive=False),
211
- outputs=[extract_glb_btn],
212
- )
213
-
214
- extract_glb_btn.click(
215
- extract_glb,
216
- inputs=[output_buf, mesh_simplify, texture_size],
217
- outputs=[model_output, download_glb],
218
- ).then(
219
- lambda: gr.Button(interactive=True),
220
- outputs=[download_glb],
221
- )
222
-
223
- model_output.clear(
224
- lambda: gr.Button(interactive=False),
225
- outputs=[download_glb],
226
- )
227
-
228
- # Launch the Gradio app
229
- if __name__ == "__main__":
230
- pipeline = TrellisTextTo3DPipeline.from_pretrained("cavargas10/TRELLIS-text-xlarge")
231
- pipeline.cuda()
232
- demo.launch()
1
+ import gradio as gr
2
+ import spaces
3
+
4
+ import os
5
+ import shutil
6
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
7
+ os.environ['SPCONV_ALGO'] = 'native'
8
+ from typing import *
9
+ import torch
10
+ import numpy as np
11
+ import imageio
12
+ from easydict import EasyDict as edict
13
+ from trellis.pipelines import TrellisTextTo3DPipeline
14
+ from trellis.representations import Gaussian, MeshExtractResult
15
+ from trellis.utils import render_utils, postprocessing_utils
16
+
17
+ import traceback
18
+ import sys
19
+
20
+
21
+ MAX_SEED = np.iinfo(np.int32).max
22
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
23
+ os.makedirs(TMP_DIR, exist_ok=True)
24
+
25
+
26
+ def start_session(req: gr.Request):
27
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
28
+ os.makedirs(user_dir, exist_ok=True)
29
+
30
+
31
+ def end_session(req: gr.Request):
32
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
33
+ shutil.rmtree(user_dir)
34
+
35
+
36
+ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
37
+ return {
38
+ 'gaussian': {
39
+ **gs.init_params,
40
+ '_xyz': gs._xyz.cpu().numpy(),
41
+ '_features_dc': gs._features_dc.cpu().numpy(),
42
+ '_scaling': gs._scaling.cpu().numpy(),
43
+ '_rotation': gs._rotation.cpu().numpy(),
44
+ '_opacity': gs._opacity.cpu().numpy(),
45
+ },
46
+ 'mesh': {
47
+ 'vertices': mesh.vertices.cpu().numpy(),
48
+ 'faces': mesh.faces.cpu().numpy(),
49
+ },
50
+ }
51
+
52
+
53
+ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
54
+ gs = Gaussian(
55
+ aabb=state['gaussian']['aabb'],
56
+ sh_degree=state['gaussian']['sh_degree'],
57
+ mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
58
+ scaling_bias=state['gaussian']['scaling_bias'],
59
+ opacity_bias=state['gaussian']['opacity_bias'],
60
+ scaling_activation=state['gaussian']['scaling_activation'],
61
+ )
62
+ gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
63
+ gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
64
+ gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
65
+ gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
66
+ gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
67
+
68
+ mesh = edict(
69
+ vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
70
+ faces=torch.tensor(state['mesh']['faces'], device='cuda'),
71
+ )
72
+
73
+ return gs, mesh
74
+
75
+
76
+ def get_seed(randomize_seed: bool, seed: int) -> int:
77
+ """
78
+ Get the random seed.
79
+ """
80
+ return np.random.randint(0, MAX_SEED) if randomize_seed else seed
81
+
82
+
83
+ @spaces.GPU
84
+ def text_to_3d(
85
+ prompt: str,
86
+ seed: int,
87
+ ss_guidance_strength: float,
88
+ ss_sampling_steps: int,
89
+ slat_guidance_strength: float,
90
+ slat_sampling_steps: int,
91
+ req: gr.Request,
92
+ ) -> Tuple[dict, str]:
93
+ """
94
+ Convert a text prompt to a 3D model.
95
+
96
+ Args:
97
+ prompt (str): The text prompt.
98
+ seed (int): The random seed.
99
+ ss_guidance_strength (float): The guidance strength for sparse structure generation.
100
+ ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
101
+ slat_guidance_strength (float): The guidance strength for structured latent generation.
102
+ slat_sampling_steps (int): The number of sampling steps for structured latent generation.
103
+
104
+ Returns:
105
+ dict: The information of the generated 3D model.
106
+ str: The path to the video of the 3D model.
107
+ """
108
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
109
+ outputs = pipeline.run(
110
+ prompt,
111
+ seed=seed,
112
+ formats=["gaussian", "mesh"],
113
+ sparse_structure_sampler_params={
114
+ "steps": ss_sampling_steps,
115
+ "cfg_strength": ss_guidance_strength,
116
+ },
117
+ slat_sampler_params={
118
+ "steps": slat_sampling_steps,
119
+ "cfg_strength": slat_guidance_strength,
120
+ },
121
+ )
122
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
123
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
124
+ video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
125
+ video_path = os.path.join(user_dir, 'sample.mp4')
126
+ imageio.mimsave(video_path, video, fps=15)
127
+ state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
128
+ torch.cuda.empty_cache()
129
+ return state, video_path
130
+
131
+
132
+ @spaces.GPU(duration=90)
133
+ def extract_glb(
134
+ state: dict,
135
+ mesh_simplify: float,
136
+ texture_size: int,
137
+ req: gr.Request,
138
+ ) -> Tuple[str, str]:
139
+ """
140
+ Extract a GLB file from the 3D model.
141
+
142
+ Args:
143
+ state (dict): The state of the generated 3D model.
144
+ mesh_simplify (float): The mesh simplification factor.
145
+ texture_size (int): The texture resolution.
146
+
147
+ Returns:
148
+ str: The path to the extracted GLB file.
149
+ """
150
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
151
+ gs, mesh = unpack_state(state)
152
+ glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
153
+ glb_path = os.path.join(user_dir, 'sample.glb')
154
+ glb.export(glb_path)
155
+ torch.cuda.empty_cache()
156
+ return glb_path, glb_path
157
+
158
+
159
+ @spaces.GPU
160
+ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
161
+ """
162
+ Extract a Gaussian file from the 3D model.
163
+
164
+ Args:
165
+ state (dict): The state of the generated 3D model.
166
+
167
+ Returns:
168
+ str: The path to the extracted Gaussian file.
169
+ """
170
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
171
+ gs, _ = unpack_state(state)
172
+ gaussian_path = os.path.join(user_dir, 'sample.ply')
173
+ gs.save_ply(gaussian_path)
174
+ torch.cuda.empty_cache()
175
+ return gaussian_path, gaussian_path
176
+
177
+
178
+ with gr.Blocks(delete_cache=(600, 600)) as demo:
179
+ gr.Markdown("""
180
+ ## Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
181
+ * Type a text prompt and click "Generate" to create a 3D asset.
182
+ * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
183
+ """)
184
+
185
+ with gr.Row():
186
+ with gr.Column():
187
+ text_prompt = gr.Textbox(label="Text Prompt", lines=5)
188
+
189
+ with gr.Accordion(label="Generation Settings", open=False):
190
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
191
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
192
+ gr.Markdown("Stage 1: Sparse Structure Generation")
193
+ with gr.Row():
194
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
195
+ ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
196
+ gr.Markdown("Stage 2: Structured Latent Generation")
197
+ with gr.Row():
198
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
199
+ slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
200
+
201
+ generate_btn = gr.Button("Generate")
202
+
203
+ with gr.Accordion(label="GLB Extraction Settings", open=False):
204
+ mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
205
+ texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
206
+
207
+ with gr.Row():
208
+ extract_glb_btn = gr.Button("Extract GLB", interactive=False)
209
+ extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
210
+ gr.Markdown("""
211
+ *NOTE: The Gaussian file can be very large (~50 MB); it may take a while to display and download.*
212
+ """)
213
+
214
+ with gr.Column():
215
+ video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
216
+ model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
217
+
218
+ with gr.Row():
219
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
220
+ download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
221
+
222
+ output_buf = gr.State()
223
+
224
+ # Handlers
225
+ demo.load(start_session)
226
+ demo.unload(end_session)
227
+
228
+ generate_btn.click(
229
+ get_seed,
230
+ inputs=[randomize_seed, seed],
231
+ outputs=[seed],
232
+ ).then(
233
+ text_to_3d,
234
+ inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
235
+ outputs=[output_buf, video_output],
236
+ ).then(
237
+ lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
238
+ outputs=[extract_glb_btn, extract_gs_btn],
239
+ )
240
+
241
+ video_output.clear(
242
+ lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
243
+ outputs=[extract_glb_btn, extract_gs_btn],
244
+ )
245
+
246
+ extract_glb_btn.click(
247
+ extract_glb,
248
+ inputs=[output_buf, mesh_simplify, texture_size],
249
+ outputs=[model_output, download_glb],
250
+ ).then(
251
+ lambda: gr.Button(interactive=True),
252
+ outputs=[download_glb],
253
+ )
254
+
255
+ extract_gs_btn.click(
256
+ extract_gaussian,
257
+ inputs=[output_buf],
258
+ outputs=[model_output, download_gs],
259
+ ).then(
260
+ lambda: gr.Button(interactive=True),
261
+ outputs=[download_gs],
262
+ )
263
+
264
+ model_output.clear(
265
+ lambda: gr.Button(interactive=False),
266
+ outputs=[download_glb],
267
+ )
268
+
269
+
270
+ # Launch the Gradio app
271
+ if __name__ == "__main__":
272
+ pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
273
+ pipeline.cuda()
274
+ demo.launch()
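
For reference, the generation path that the new app.py wires into Gradio can also be exercised headlessly. The following is a minimal sketch assembled only from calls visible in the diff above (TrellisTextTo3DPipeline.from_pretrained, pipeline.run, render_utils.render_video, postprocessing_utils.to_glb); the prompt, output file names, and parameter values are illustrative assumptions, not part of this commit.

```python
# Minimal headless sketch of the app.py generation path (assumed usage,
# built only from calls shown in the diff above).
import imageio
import numpy as np

from trellis.pipelines import TrellisTextTo3DPipeline
from trellis.utils import render_utils, postprocessing_utils

pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
pipeline.cuda()

# Same sampler parameters as the Gradio slider defaults (strength 7.5, 25 steps).
outputs = pipeline.run(
    "a weathered wooden treasure chest",  # illustrative prompt (assumption)
    seed=0,
    formats=["gaussian", "mesh"],
    sparse_structure_sampler_params={"steps": 25, "cfg_strength": 7.5},
    slat_sampler_params={"steps": 25, "cfg_strength": 7.5},
)

# Render a turntable preview, mirroring text_to_3d().
color = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
normal = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
frames = [np.concatenate([c, n], axis=1) for c, n in zip(color, normal)]
imageio.mimsave("sample.mp4", frames, fps=15)

# Bake a textured GLB, mirroring extract_glb().
glb = postprocessing_utils.to_glb(
    outputs['gaussian'][0], outputs['mesh'][0],
    simplify=0.95, texture_size=1024, verbose=False,
)
glb.export("sample.glb")
```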
app_img.py ADDED
@@ -0,0 +1,414 @@
1
+ import gradio as gr
2
+ import spaces
3
+ from gradio_litmodel3d import LitModel3D
4
+
5
+ import os
6
+ import shutil
7
+ os.environ['SPCONV_ALGO'] = 'native'
8
+ from typing import *
9
+ import torch
10
+ import numpy as np
11
+ import imageio
12
+ from easydict import EasyDict as edict
13
+ from PIL import Image
14
+ from trellis.pipelines import TrellisImageTo3DPipeline
15
+ from trellis.representations import Gaussian, MeshExtractResult
16
+ from trellis.utils import render_utils, postprocessing_utils
17
+
18
+
19
+ MAX_SEED = np.iinfo(np.int32).max
20
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
21
+ os.makedirs(TMP_DIR, exist_ok=True)
22
+
23
+
24
+ def start_session(req: gr.Request):
25
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
26
+ os.makedirs(user_dir, exist_ok=True)
27
+
28
+
29
+ def end_session(req: gr.Request):
30
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
31
+ shutil.rmtree(user_dir)
32
+
33
+
34
+ def preprocess_image(image: Image.Image) -> Image.Image:
35
+ """
36
+ Preprocess the input image.
37
+
38
+ Args:
39
+ image (Image.Image): The input image.
40
+
41
+ Returns:
42
+ Image.Image: The preprocessed image.
43
+ """
44
+ processed_image = pipeline.preprocess_image(image)
45
+ return processed_image
46
+
47
+
48
+ def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
49
+ """
50
+ Preprocess a list of input images.
51
+
52
+ Args:
53
+ images (List[Tuple[Image.Image, str]]): The input images.
54
+
55
+ Returns:
56
+ List[Image.Image]: The preprocessed images.
57
+ """
58
+ images = [image[0] for image in images]
59
+ processed_images = [pipeline.preprocess_image(image) for image in images]
60
+ return processed_images
61
+
62
+
63
+ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
64
+ return {
65
+ 'gaussian': {
66
+ **gs.init_params,
67
+ '_xyz': gs._xyz.cpu().numpy(),
68
+ '_features_dc': gs._features_dc.cpu().numpy(),
69
+ '_scaling': gs._scaling.cpu().numpy(),
70
+ '_rotation': gs._rotation.cpu().numpy(),
71
+ '_opacity': gs._opacity.cpu().numpy(),
72
+ },
73
+ 'mesh': {
74
+ 'vertices': mesh.vertices.cpu().numpy(),
75
+ 'faces': mesh.faces.cpu().numpy(),
76
+ },
77
+ }
78
+
79
+
80
+ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
81
+ gs = Gaussian(
82
+ aabb=state['gaussian']['aabb'],
83
+ sh_degree=state['gaussian']['sh_degree'],
84
+ mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
85
+ scaling_bias=state['gaussian']['scaling_bias'],
86
+ opacity_bias=state['gaussian']['opacity_bias'],
87
+ scaling_activation=state['gaussian']['scaling_activation'],
88
+ )
89
+ gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
90
+ gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
91
+ gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
92
+ gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
93
+ gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
94
+
95
+ mesh = edict(
96
+ vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
97
+ faces=torch.tensor(state['mesh']['faces'], device='cuda'),
98
+ )
99
+
100
+ return gs, mesh
101
+
102
+
103
+ def get_seed(randomize_seed: bool, seed: int) -> int:
104
+ """
105
+ Get the random seed.
106
+ """
107
+ return np.random.randint(0, MAX_SEED) if randomize_seed else seed
108
+
109
+
110
+ @spaces.GPU
111
+ def image_to_3d(
112
+ image: Image.Image,
113
+ multiimages: List[Tuple[Image.Image, str]],
114
+ is_multiimage: bool,
115
+ seed: int,
116
+ ss_guidance_strength: float,
117
+ ss_sampling_steps: int,
118
+ slat_guidance_strength: float,
119
+ slat_sampling_steps: int,
120
+ multiimage_algo: Literal["multidiffusion", "stochastic"],
121
+ req: gr.Request,
122
+ ) -> Tuple[dict, str]:
123
+ """
124
+ Convert an image to a 3D model.
125
+
126
+ Args:
127
+ image (Image.Image): The input image.
128
+ multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
129
+ is_multiimage (bool): Whether is in multi-image mode.
130
+ seed (int): The random seed.
131
+ ss_guidance_strength (float): The guidance strength for sparse structure generation.
132
+ ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
133
+ slat_guidance_strength (float): The guidance strength for structured latent generation.
134
+ slat_sampling_steps (int): The number of sampling steps for structured latent generation.
135
+ multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
136
+
137
+ Returns:
138
+ dict: The information of the generated 3D model.
139
+ str: The path to the video of the 3D model.
140
+ """
141
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
142
+ if not is_multiimage:
143
+ outputs = pipeline.run(
144
+ image,
145
+ seed=seed,
146
+ formats=["gaussian", "mesh"],
147
+ preprocess_image=False,
148
+ sparse_structure_sampler_params={
149
+ "steps": ss_sampling_steps,
150
+ "cfg_strength": ss_guidance_strength,
151
+ },
152
+ slat_sampler_params={
153
+ "steps": slat_sampling_steps,
154
+ "cfg_strength": slat_guidance_strength,
155
+ },
156
+ )
157
+ else:
158
+ outputs = pipeline.run_multi_image(
159
+ [image[0] for image in multiimages],
160
+ seed=seed,
161
+ formats=["gaussian", "mesh"],
162
+ preprocess_image=False,
163
+ sparse_structure_sampler_params={
164
+ "steps": ss_sampling_steps,
165
+ "cfg_strength": ss_guidance_strength,
166
+ },
167
+ slat_sampler_params={
168
+ "steps": slat_sampling_steps,
169
+ "cfg_strength": slat_guidance_strength,
170
+ },
171
+ mode=multiimage_algo,
172
+ )
173
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
174
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
175
+ video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
176
+ video_path = os.path.join(user_dir, 'sample.mp4')
177
+ imageio.mimsave(video_path, video, fps=15)
178
+ state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
179
+ torch.cuda.empty_cache()
180
+ return state, video_path
181
+
182
+
183
+ @spaces.GPU(duration=90)
184
+ def extract_glb(
185
+ state: dict,
186
+ mesh_simplify: float,
187
+ texture_size: int,
188
+ req: gr.Request,
189
+ ) -> Tuple[str, str]:
190
+ """
191
+ Extract a GLB file from the 3D model.
192
+
193
+ Args:
194
+ state (dict): The state of the generated 3D model.
195
+ mesh_simplify (float): The mesh simplification factor.
196
+ texture_size (int): The texture resolution.
197
+
198
+ Returns:
199
+ str: The path to the extracted GLB file.
200
+ """
201
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
202
+ gs, mesh = unpack_state(state)
203
+ glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
204
+ glb_path = os.path.join(user_dir, 'sample.glb')
205
+ glb.export(glb_path)
206
+ torch.cuda.empty_cache()
207
+ return glb_path, glb_path
208
+
209
+
210
+ @spaces.GPU
211
+ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
212
+ """
213
+ Extract a Gaussian file from the 3D model.
214
+
215
+ Args:
216
+ state (dict): The state of the generated 3D model.
217
+
218
+ Returns:
219
+ str: The path to the extracted Gaussian file.
220
+ """
221
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
222
+ gs, _ = unpack_state(state)
223
+ gaussian_path = os.path.join(user_dir, 'sample.ply')
224
+ gs.save_ply(gaussian_path)
225
+ torch.cuda.empty_cache()
226
+ return gaussian_path, gaussian_path
227
+
228
+
229
+ def prepare_multi_example() -> List[Image.Image]:
230
+ multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
231
+ images = []
232
+ for case in multi_case:
233
+ _images = []
234
+ for i in range(1, 4):
235
+ img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
236
+ W, H = img.size
237
+ img = img.resize((int(W / H * 512), 512))
238
+ _images.append(np.array(img))
239
+ images.append(Image.fromarray(np.concatenate(_images, axis=1)))
240
+ return images
241
+
242
+
243
+ def split_image(image: Image.Image) -> List[Image.Image]:
244
+ """
245
+ Split an image into multiple views.
246
+ """
247
+ image = np.array(image)
248
+ alpha = image[..., 3]
249
+ alpha = np.any(alpha>0, axis=0)
250
+ start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
251
+ end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
252
+ images = []
253
+ for s, e in zip(start_pos, end_pos):
254
+ images.append(Image.fromarray(image[:, s:e+1]))
255
+ return [preprocess_image(image) for image in images]
256
+
257
+
258
+ with gr.Blocks(delete_cache=(600, 600)) as demo:
259
+ gr.Markdown("""
260
+ ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
261
+ * Upload an image and click "Generate" to create a 3D asset. If the image has an alpha channel, it will be used as the mask. Otherwise, we use `rembg` to remove the background.
262
+ * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
263
+
264
+ ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
265
+ """)
266
+
267
+ with gr.Row():
268
+ with gr.Column():
269
+ with gr.Tabs() as input_tabs:
270
+ with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
271
+ image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
272
+ with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
273
+ multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
274
+ gr.Markdown("""
275
+ Input different views of the object in separate images.
276
+
277
+ *NOTE: this is an experimental algorithm that does not use a specially trained model. It may not produce the best results for all images, especially those with differing poses or inconsistent details.*
278
+ """)
279
+
280
+ with gr.Accordion(label="Generation Settings", open=False):
281
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
282
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
283
+ gr.Markdown("Stage 1: Sparse Structure Generation")
284
+ with gr.Row():
285
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
286
+ ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
287
+ gr.Markdown("Stage 2: Structured Latent Generation")
288
+ with gr.Row():
289
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
290
+ slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
291
+ multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
292
+
293
+ generate_btn = gr.Button("Generate")
294
+
295
+ with gr.Accordion(label="GLB Extraction Settings", open=False):
296
+ mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
297
+ texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
298
+
299
+ with gr.Row():
300
+ extract_glb_btn = gr.Button("Extract GLB", interactive=False)
301
+ extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
302
+ gr.Markdown("""
303
+ *NOTE: The Gaussian file can be very large (~50 MB); it may take a while to display and download.*
304
+ """)
305
+
306
+ with gr.Column():
307
+ video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
308
+ model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
309
+
310
+ with gr.Row():
311
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
312
+ download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
313
+
314
+ is_multiimage = gr.State(False)
315
+ output_buf = gr.State()
316
+
317
+ # Example images at the bottom of the page
318
+ with gr.Row() as single_image_example:
319
+ examples = gr.Examples(
320
+ examples=[
321
+ f'assets/example_image/{image}'
322
+ for image in os.listdir("assets/example_image")
323
+ ],
324
+ inputs=[image_prompt],
325
+ fn=preprocess_image,
326
+ outputs=[image_prompt],
327
+ run_on_click=True,
328
+ examples_per_page=64,
329
+ )
330
+ with gr.Row(visible=False) as multiimage_example:
331
+ examples_multi = gr.Examples(
332
+ examples=prepare_multi_example(),
333
+ inputs=[image_prompt],
334
+ fn=split_image,
335
+ outputs=[multiimage_prompt],
336
+ run_on_click=True,
337
+ examples_per_page=8,
338
+ )
339
+
340
+ # Handlers
341
+ demo.load(start_session)
342
+ demo.unload(end_session)
343
+
344
+ single_image_input_tab.select(
345
+ lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]),
346
+ outputs=[is_multiimage, single_image_example, multiimage_example]
347
+ )
348
+ multiimage_input_tab.select(
349
+ lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
350
+ outputs=[is_multiimage, single_image_example, multiimage_example]
351
+ )
352
+
353
+ image_prompt.upload(
354
+ preprocess_image,
355
+ inputs=[image_prompt],
356
+ outputs=[image_prompt],
357
+ )
358
+ multiimage_prompt.upload(
359
+ preprocess_images,
360
+ inputs=[multiimage_prompt],
361
+ outputs=[multiimage_prompt],
362
+ )
363
+
364
+ generate_btn.click(
365
+ get_seed,
366
+ inputs=[randomize_seed, seed],
367
+ outputs=[seed],
368
+ ).then(
369
+ image_to_3d,
370
+ inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
371
+ outputs=[output_buf, video_output],
372
+ ).then(
373
+ lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
374
+ outputs=[extract_glb_btn, extract_gs_btn],
375
+ )
376
+
377
+ video_output.clear(
378
+ lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
379
+ outputs=[extract_glb_btn, extract_gs_btn],
380
+ )
381
+
382
+ extract_glb_btn.click(
383
+ extract_glb,
384
+ inputs=[output_buf, mesh_simplify, texture_size],
385
+ outputs=[model_output, download_glb],
386
+ ).then(
387
+ lambda: gr.Button(interactive=True),
388
+ outputs=[download_glb],
389
+ )
390
+
391
+ extract_gs_btn.click(
392
+ extract_gaussian,
393
+ inputs=[output_buf],
394
+ outputs=[model_output, download_gs],
395
+ ).then(
396
+ lambda: gr.Button(interactive=True),
397
+ outputs=[download_gs],
398
+ )
399
+
400
+ model_output.clear(
401
+ lambda: gr.Button(interactive=False),
402
+ outputs=[download_glb],
403
+ )
404
+
405
+
406
+ # Launch the Gradio app
407
+ if __name__ == "__main__":
408
+ pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
409
+ pipeline.cuda()
410
+ try:
411
+ pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
412
+ except:
413
+ pass
414
+ demo.launch()
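
Likewise, the single-image path in app_img.py can be driven without the UI. A minimal sketch, again using only calls that appear in the diff (preprocess_image, pipeline.run with preprocess_image=False, save_ply); the input path and parameter values are assumptions chosen to mirror the app's defaults.

```python
# Minimal headless sketch of the app_img.py single-image path (assumed usage).
from PIL import Image

from trellis.pipelines import TrellisImageTo3DPipeline

pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
pipeline.cuda()

# preprocess_image() handles background removal / alpha masking, as in the app.
# The example path is an assumption based on the repo's example assets.
image = pipeline.preprocess_image(Image.open("assets/example_image/typical_misc_crate.png"))

outputs = pipeline.run(
    image,
    seed=0,
    formats=["gaussian", "mesh"],
    preprocess_image=False,  # already preprocessed above
    sparse_structure_sampler_params={"steps": 12, "cfg_strength": 7.5},
    slat_sampler_params={"steps": 12, "cfg_strength": 3.0},
)

# Save the raw Gaussian representation, mirroring extract_gaussian().
outputs['gaussian'][0].save_ply("sample.ply")
```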
configs/generation/slat_flow_img_dit_L_64l8p2_fp16.json ADDED
@@ -0,0 +1,102 @@
1
+ {
2
+ "models": {
3
+ "denoiser": {
4
+ "name": "ElasticSLatFlowModel",
5
+ "args": {
6
+ "resolution": 64,
7
+ "in_channels": 8,
8
+ "out_channels": 8,
9
+ "model_channels": 1024,
10
+ "cond_channels": 1024,
11
+ "num_blocks": 24,
12
+ "num_heads": 16,
13
+ "mlp_ratio": 4,
14
+ "patch_size": 2,
15
+ "num_io_res_blocks": 2,
16
+ "io_block_channels": [128],
17
+ "pe_mode": "ape",
18
+ "qk_rms_norm": true,
19
+ "use_fp16": true
20
+ }
21
+ }
22
+ },
23
+ "dataset": {
24
+ "name": "ImageConditionedSLat",
25
+ "args": {
26
+ "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
27
+ "min_aesthetic_score": 4.5,
28
+ "max_num_voxels": 32768,
29
+ "image_size": 518,
30
+ "normalization": {
31
+ "mean": [
32
+ -2.1687545776367188,
33
+ -0.004347046371549368,
34
+ -0.13352349400520325,
35
+ -0.08418072760105133,
36
+ -0.5271206498146057,
37
+ 0.7238689064979553,
38
+ -1.1414450407028198,
39
+ 1.2039363384246826
40
+ ],
41
+ "std": [
42
+ 2.377650737762451,
43
+ 2.386378288269043,
44
+ 2.124418020248413,
45
+ 2.1748552322387695,
46
+ 2.663944721221924,
47
+ 2.371192216873169,
48
+ 2.6217446327209473,
49
+ 2.684523105621338
50
+ ]
51
+ },
52
+ "pretrained_slat_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
53
+ }
54
+ },
55
+ "trainer": {
56
+ "name": "ImageConditionedSparseFlowMatchingCFGTrainer",
57
+ "args": {
58
+ "max_steps": 1000000,
59
+ "batch_size_per_gpu": 8,
60
+ "batch_split": 4,
61
+ "optimizer": {
62
+ "name": "AdamW",
63
+ "args": {
64
+ "lr": 0.0001,
65
+ "weight_decay": 0.0
66
+ }
67
+ },
68
+ "ema_rate": [
69
+ 0.9999
70
+ ],
71
+ "fp16_mode": "inflat_all",
72
+ "fp16_scale_growth": 0.001,
73
+ "elastic": {
74
+ "name": "LinearMemoryController",
75
+ "args": {
76
+ "target_ratio": 0.75,
77
+ "max_mem_ratio_start": 0.5
78
+ }
79
+ },
80
+ "grad_clip": {
81
+ "name": "AdaptiveGradClipper",
82
+ "args": {
83
+ "max_norm": 1.0,
84
+ "clip_percentile": 95
85
+ }
86
+ },
87
+ "i_log": 500,
88
+ "i_sample": 10000,
89
+ "i_save": 10000,
90
+ "p_uncond": 0.1,
91
+ "t_schedule": {
92
+ "name": "logitNormal",
93
+ "args": {
94
+ "mean": 1.0,
95
+ "std": 1.0
96
+ }
97
+ },
98
+ "sigma_min": 1e-5,
99
+ "image_cond_model": "dinov2_vitl14_reg"
100
+ }
101
+ }
102
+ }
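
The training entry point that consumes these generation configs is not part of this diff, but the files are plain JSON and can be inspected directly. A small sketch, assuming the config is read from the path added above:

```python
# Quick sanity check of a generation config (a sketch; the script that actually
# consumes these files is not included in this commit).
import json

with open("configs/generation/slat_flow_img_dit_L_64l8p2_fp16.json") as f:
    cfg = json.load(f)

denoiser = cfg["models"]["denoiser"]
print(denoiser["name"])                    # ElasticSLatFlowModel
print(denoiser["args"]["model_channels"])  # 1024 for the L variant
print(cfg["trainer"]["name"])              # ImageConditionedSparseFlowMatchingCFGTrainer

# The eight per-channel mean/std values normalize the 8-channel structured latent.
norm = cfg["dataset"]["args"]["normalization"]
assert len(norm["mean"]) == len(norm["std"]) == denoiser["args"]["in_channels"]
```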
configs/generation/slat_flow_txt_dit_B_64l8p2_fp16.json ADDED
@@ -0,0 +1,101 @@
1
+ {
2
+ "models": {
3
+ "denoiser": {
4
+ "name": "ElasticSLatFlowModel",
5
+ "args": {
6
+ "resolution": 64,
7
+ "in_channels": 8,
8
+ "out_channels": 8,
9
+ "model_channels": 768,
10
+ "cond_channels": 768,
11
+ "num_blocks": 12,
12
+ "num_heads": 12,
13
+ "mlp_ratio": 4,
14
+ "patch_size": 2,
15
+ "num_io_res_blocks": 2,
16
+ "io_block_channels": [128],
17
+ "pe_mode": "ape",
18
+ "qk_rms_norm": true,
19
+ "use_fp16": true
20
+ }
21
+ }
22
+ },
23
+ "dataset": {
24
+ "name": "TextConditionedSLat",
25
+ "args": {
26
+ "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
27
+ "min_aesthetic_score": 4.5,
28
+ "max_num_voxels": 32768,
29
+ "normalization": {
30
+ "mean": [
31
+ -2.1687545776367188,
32
+ -0.004347046371549368,
33
+ -0.13352349400520325,
34
+ -0.08418072760105133,
35
+ -0.5271206498146057,
36
+ 0.7238689064979553,
37
+ -1.1414450407028198,
38
+ 1.2039363384246826
39
+ ],
40
+ "std": [
41
+ 2.377650737762451,
42
+ 2.386378288269043,
43
+ 2.124418020248413,
44
+ 2.1748552322387695,
45
+ 2.663944721221924,
46
+ 2.371192216873169,
47
+ 2.6217446327209473,
48
+ 2.684523105621338
49
+ ]
50
+ },
51
+ "pretrained_slat_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
52
+ }
53
+ },
54
+ "trainer": {
55
+ "name": "TextConditionedSparseFlowMatchingCFGTrainer",
56
+ "args": {
57
+ "max_steps": 1000000,
58
+ "batch_size_per_gpu": 16,
59
+ "batch_split": 4,
60
+ "optimizer": {
61
+ "name": "AdamW",
62
+ "args": {
63
+ "lr": 0.0001,
64
+ "weight_decay": 0.0
65
+ }
66
+ },
67
+ "ema_rate": [
68
+ 0.9999
69
+ ],
70
+ "fp16_mode": "inflat_all",
71
+ "fp16_scale_growth": 0.001,
72
+ "elastic": {
73
+ "name": "LinearMemoryController",
74
+ "args": {
75
+ "target_ratio": 0.75,
76
+ "max_mem_ratio_start": 0.5
77
+ }
78
+ },
79
+ "grad_clip": {
80
+ "name": "AdaptiveGradClipper",
81
+ "args": {
82
+ "max_norm": 1.0,
83
+ "clip_percentile": 95
84
+ }
85
+ },
86
+ "i_log": 500,
87
+ "i_sample": 10000,
88
+ "i_save": 10000,
89
+ "p_uncond": 0.1,
90
+ "t_schedule": {
91
+ "name": "logitNormal",
92
+ "args": {
93
+ "mean": 1.0,
94
+ "std": 1.0
95
+ }
96
+ },
97
+ "sigma_min": 1e-5,
98
+ "text_cond_model": "openai/clip-vit-large-patch14"
99
+ }
100
+ }
101
+ }
configs/generation/slat_flow_txt_dit_L_64l8p2_fp16.json ADDED
@@ -0,0 +1,101 @@
1
+ {
2
+ "models": {
3
+ "denoiser": {
4
+ "name": "ElasticSLatFlowModel",
5
+ "args": {
6
+ "resolution": 64,
7
+ "in_channels": 8,
8
+ "out_channels": 8,
9
+ "model_channels": 1024,
10
+ "cond_channels": 768,
11
+ "num_blocks": 24,
12
+ "num_heads": 16,
13
+ "mlp_ratio": 4,
14
+ "patch_size": 2,
15
+ "num_io_res_blocks": 2,
16
+ "io_block_channels": [128],
17
+ "pe_mode": "ape",
18
+ "qk_rms_norm": true,
19
+ "use_fp16": true
20
+ }
21
+ }
22
+ },
23
+ "dataset": {
24
+ "name": "TextConditionedSLat",
25
+ "args": {
26
+ "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
27
+ "min_aesthetic_score": 4.5,
28
+ "max_num_voxels": 32768,
29
+ "normalization": {
30
+ "mean": [
31
+ -2.1687545776367188,
32
+ -0.004347046371549368,
33
+ -0.13352349400520325,
34
+ -0.08418072760105133,
35
+ -0.5271206498146057,
36
+ 0.7238689064979553,
37
+ -1.1414450407028198,
38
+ 1.2039363384246826
39
+ ],
40
+ "std": [
41
+ 2.377650737762451,
42
+ 2.386378288269043,
43
+ 2.124418020248413,
44
+ 2.1748552322387695,
45
+ 2.663944721221924,
46
+ 2.371192216873169,
47
+ 2.6217446327209473,
48
+ 2.684523105621338
49
+ ]
50
+ },
51
+ "pretrained_slat_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
52
+ }
53
+ },
54
+ "trainer": {
55
+ "name": "TextConditionedSparseFlowMatchingCFGTrainer",
56
+ "args": {
57
+ "max_steps": 1000000,
58
+ "batch_size_per_gpu": 8,
59
+ "batch_split": 4,
60
+ "optimizer": {
61
+ "name": "AdamW",
62
+ "args": {
63
+ "lr": 0.0001,
64
+ "weight_decay": 0.0
65
+ }
66
+ },
67
+ "ema_rate": [
68
+ 0.9999
69
+ ],
70
+ "fp16_mode": "inflat_all",
71
+ "fp16_scale_growth": 0.001,
72
+ "elastic": {
73
+ "name": "LinearMemoryController",
74
+ "args": {
75
+ "target_ratio": 0.75,
76
+ "max_mem_ratio_start": 0.5
77
+ }
78
+ },
79
+ "grad_clip": {
80
+ "name": "AdaptiveGradClipper",
81
+ "args": {
82
+ "max_norm": 1.0,
83
+ "clip_percentile": 95
84
+ }
85
+ },
86
+ "i_log": 500,
87
+ "i_sample": 10000,
88
+ "i_save": 10000,
89
+ "p_uncond": 0.1,
90
+ "t_schedule": {
91
+ "name": "logitNormal",
92
+ "args": {
93
+ "mean": 1.0,
94
+ "std": 1.0
95
+ }
96
+ },
97
+ "sigma_min": 1e-5,
98
+ "text_cond_model": "openai/clip-vit-large-patch14"
99
+ }
100
+ }
101
+ }
configs/generation/slat_flow_txt_dit_XL_64l8p2_fp16.json ADDED
@@ -0,0 +1,101 @@
1
+ {
2
+ "models": {
3
+ "denoiser": {
4
+ "name": "ElasticSLatFlowModel",
5
+ "args": {
6
+ "resolution": 64,
7
+ "in_channels": 8,
8
+ "out_channels": 8,
9
+ "model_channels": 1280,
10
+ "cond_channels": 768,
11
+ "num_blocks": 28,
12
+ "num_heads": 16,
13
+ "mlp_ratio": 4,
14
+ "patch_size": 2,
15
+ "num_io_res_blocks": 3,
16
+ "io_block_channels": [256],
17
+ "pe_mode": "ape",
18
+ "qk_rms_norm": true,
19
+ "use_fp16": true
20
+ }
21
+ }
22
+ },
23
+ "dataset": {
24
+ "name": "TextConditionedSLat",
25
+ "args": {
26
+ "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
27
+ "min_aesthetic_score": 4.5,
28
+ "max_num_voxels": 32768,
29
+ "normalization": {
30
+ "mean": [
31
+ -2.1687545776367188,
32
+ -0.004347046371549368,
33
+ -0.13352349400520325,
34
+ -0.08418072760105133,
35
+ -0.5271206498146057,
36
+ 0.7238689064979553,
37
+ -1.1414450407028198,
38
+ 1.2039363384246826
39
+ ],
40
+ "std": [
41
+ 2.377650737762451,
42
+ 2.386378288269043,
43
+ 2.124418020248413,
44
+ 2.1748552322387695,
45
+ 2.663944721221924,
46
+ 2.371192216873169,
47
+ 2.6217446327209473,
48
+ 2.684523105621338
49
+ ]
50
+ },
51
+ "pretrained_slat_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
52
+ }
53
+ },
54
+ "trainer": {
55
+ "name": "TextConditionedSparseFlowMatchingCFGTrainer",
56
+ "args": {
57
+ "max_steps": 1000000,
58
+ "batch_size_per_gpu": 4,
59
+ "batch_split": 4,
60
+ "optimizer": {
61
+ "name": "AdamW",
62
+ "args": {
63
+ "lr": 0.0001,
64
+ "weight_decay": 0.0
65
+ }
66
+ },
67
+ "ema_rate": [
68
+ 0.9999
69
+ ],
70
+ "fp16_mode": "inflat_all",
71
+ "fp16_scale_growth": 0.001,
72
+ "elastic": {
73
+ "name": "LinearMemoryController",
74
+ "args": {
75
+ "target_ratio": 0.75,
76
+ "max_mem_ratio_start": 0.5
77
+ }
78
+ },
79
+ "grad_clip": {
80
+ "name": "AdaptiveGradClipper",
81
+ "args": {
82
+ "max_norm": 1.0,
83
+ "clip_percentile": 95
84
+ }
85
+ },
86
+ "i_log": 500,
87
+ "i_sample": 10000,
88
+ "i_save": 10000,
89
+ "p_uncond": 0.1,
90
+ "t_schedule": {
91
+ "name": "logitNormal",
92
+ "args": {
93
+ "mean": 1.0,
94
+ "std": 1.0
95
+ }
96
+ },
97
+ "sigma_min": 1e-5,
98
+ "text_cond_model": "openai/clip-vit-large-patch14"
99
+ }
100
+ }
101
+ }
configs/generation/ss_flow_img_dit_L_16l8_fp16.json ADDED
@@ -0,0 +1,70 @@
1
+ {
2
+ "models": {
3
+ "denoiser": {
4
+ "name": "SparseStructureFlowModel",
5
+ "args": {
6
+ "resolution": 16,
7
+ "in_channels": 8,
8
+ "out_channels": 8,
9
+ "model_channels": 1024,
10
+ "cond_channels": 1024,
11
+ "num_blocks": 24,
12
+ "num_heads": 16,
13
+ "mlp_ratio": 4,
14
+ "patch_size": 1,
15
+ "pe_mode": "ape",
16
+ "qk_rms_norm": true,
17
+ "use_fp16": true
18
+ }
19
+ }
20
+ },
21
+ "dataset": {
22
+ "name": "ImageConditionedSparseStructureLatent",
23
+ "args": {
24
+ "latent_model": "ss_enc_conv3d_16l8_fp16",
25
+ "min_aesthetic_score": 4.5,
26
+ "image_size": 518,
27
+ "pretrained_ss_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16"
28
+ }
29
+ },
30
+ "trainer": {
31
+ "name": "ImageConditionedFlowMatchingCFGTrainer",
32
+ "args": {
33
+ "max_steps": 1000000,
34
+ "batch_size_per_gpu": 8,
35
+ "batch_split": 1,
36
+ "optimizer": {
37
+ "name": "AdamW",
38
+ "args": {
39
+ "lr": 0.0001,
40
+ "weight_decay": 0.0
41
+ }
42
+ },
43
+ "ema_rate": [
44
+ 0.9999
45
+ ],
46
+ "fp16_mode": "inflat_all",
47
+ "fp16_scale_growth": 0.001,
48
+ "grad_clip": {
49
+ "name": "AdaptiveGradClipper",
50
+ "args": {
51
+ "max_norm": 1.0,
52
+ "clip_percentile": 95
53
+ }
54
+ },
55
+ "i_log": 500,
56
+ "i_sample": 10000,
57
+ "i_save": 10000,
58
+ "p_uncond": 0.1,
59
+ "t_schedule": {
60
+ "name": "logitNormal",
61
+ "args": {
62
+ "mean": 1.0,
63
+ "std": 1.0
64
+ }
65
+ },
66
+ "sigma_min": 1e-5,
67
+ "image_cond_model": "dinov2_vitl14_reg"
68
+ }
69
+ }
70
+ }
configs/generation/ss_flow_txt_dit_B_16l8_fp16.json ADDED
@@ -0,0 +1,69 @@
1
+ {
2
+ "models": {
3
+ "denoiser": {
4
+ "name": "SparseStructureFlowModel",
5
+ "args": {
6
+ "resolution": 16,
7
+ "in_channels": 8,
8
+ "out_channels": 8,
9
+ "model_channels": 768,
10
+ "cond_channels": 768,
11
+ "num_blocks": 12,
12
+ "num_heads": 12,
13
+ "mlp_ratio": 4,
14
+ "patch_size": 1,
15
+ "pe_mode": "ape",
16
+ "qk_rms_norm": true,
17
+ "use_fp16": true
18
+ }
19
+ }
20
+ },
21
+ "dataset": {
22
+ "name": "TextConditionedSparseStructureLatent",
23
+ "args": {
24
+ "latent_model": "ss_enc_conv3d_16l8_fp16",
25
+ "min_aesthetic_score": 4.5,
26
+ "pretrained_ss_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16"
27
+ }
28
+ },
29
+ "trainer": {
30
+ "name": "TextConditionedFlowMatchingCFGTrainer",
31
+ "args": {
32
+ "max_steps": 1000000,
33
+ "batch_size_per_gpu": 16,
34
+ "batch_split": 1,
35
+ "optimizer": {
36
+ "name": "AdamW",
37
+ "args": {
38
+ "lr": 0.0001,
39
+ "weight_decay": 0.0
40
+ }
41
+ },
42
+ "ema_rate": [
43
+ 0.9999
44
+ ],
45
+ "fp16_mode": "inflat_all",
46
+ "fp16_scale_growth": 0.001,
47
+ "grad_clip": {
48
+ "name": "AdaptiveGradClipper",
49
+ "args": {
50
+ "max_norm": 1.0,
51
+ "clip_percentile": 95
52
+ }
53
+ },
54
+ "i_log": 500,
55
+ "i_sample": 10000,
56
+ "i_save": 10000,
57
+ "p_uncond": 0.1,
58
+ "t_schedule": {
59
+ "name": "logitNormal",
60
+ "args": {
61
+ "mean": 1.0,
62
+ "std": 1.0
63
+ }
64
+ },
65
+ "sigma_min": 1e-5,
66
+ "text_cond_model": "openai/clip-vit-large-patch14"
67
+ }
68
+ }
69
+ }
configs/generation/ss_flow_txt_dit_L_16l8_fp16.json ADDED
@@ -0,0 +1,69 @@
1
+ {
2
+ "models": {
3
+ "denoiser": {
4
+ "name": "SparseStructureFlowModel",
5
+ "args": {
6
+ "resolution": 16,
7
+ "in_channels": 8,
8
+ "out_channels": 8,
9
+ "model_channels": 1024,
10
+ "cond_channels": 768,
11
+ "num_blocks": 24,
12
+ "num_heads": 16,
13
+ "mlp_ratio": 4,
14
+ "patch_size": 1,
15
+ "pe_mode": "ape",
16
+ "qk_rms_norm": true,
17
+ "use_fp16": true
18
+ }
19
+ }
20
+ },
21
+ "dataset": {
22
+ "name": "TextConditionedSparseStructureLatent",
23
+ "args": {
24
+ "latent_model": "ss_enc_conv3d_16l8_fp16",
25
+ "min_aesthetic_score": 4.5,
26
+ "pretrained_ss_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16"
27
+ }
28
+ },
29
+ "trainer": {
30
+ "name": "TextConditionedFlowMatchingCFGTrainer",
31
+ "args": {
32
+ "max_steps": 1000000,
33
+ "batch_size_per_gpu": 8,
34
+ "batch_split": 1,
35
+ "optimizer": {
36
+ "name": "AdamW",
37
+ "args": {
38
+ "lr": 0.0001,
39
+ "weight_decay": 0.0
40
+ }
41
+ },
42
+ "ema_rate": [
43
+ 0.9999
44
+ ],
45
+ "fp16_mode": "inflat_all",
46
+ "fp16_scale_growth": 0.001,
47
+ "grad_clip": {
48
+ "name": "AdaptiveGradClipper",
49
+ "args": {
50
+ "max_norm": 1.0,
51
+ "clip_percentile": 95
52
+ }
53
+ },
54
+ "i_log": 500,
55
+ "i_sample": 10000,
56
+ "i_save": 10000,
57
+ "p_uncond": 0.1,
58
+ "t_schedule": {
59
+ "name": "logitNormal",
60
+ "args": {
61
+ "mean": 1.0,
62
+ "std": 1.0
63
+ }
64
+ },
65
+ "sigma_min": 1e-5,
66
+ "text_cond_model": "openai/clip-vit-large-patch14"
67
+ }
68
+ }
69
+ }
configs/generation/ss_flow_txt_dit_XL_16l8_fp16.json ADDED
@@ -0,0 +1,70 @@
1
+ {
2
+ "models": {
3
+ "denoiser": {
4
+ "name": "SparseStructureFlowModel",
5
+ "args": {
6
+ "resolution": 16,
7
+ "in_channels": 8,
8
+ "out_channels": 8,
9
+ "model_channels": 1280,
10
+ "cond_channels": 768,
11
+ "num_blocks": 28,
12
+ "num_heads": 16,
13
+ "mlp_ratio": 4,
14
+ "patch_size": 1,
15
+ "pe_mode": "ape",
16
+ "qk_rms_norm": true,
17
+ "qk_rms_norm_cross": true,
18
+ "use_fp16": true
19
+ }
20
+ }
21
+ },
22
+ "dataset": {
23
+ "name": "TextConditionedSparseStructureLatent",
24
+ "args": {
25
+ "latent_model": "ss_enc_conv3d_16l8_fp16",
26
+ "min_aesthetic_score": 4.5,
27
+ "pretrained_ss_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16"
28
+ }
29
+ },
30
+ "trainer": {
31
+ "name": "TextConditionedFlowMatchingCFGTrainer",
32
+ "args": {
33
+ "max_steps": 1000000,
34
+ "batch_size_per_gpu": 4,
35
+ "batch_split": 1,
36
+ "optimizer": {
37
+ "name": "AdamW",
38
+ "args": {
39
+ "lr": 0.0001,
40
+ "weight_decay": 0.0
41
+ }
42
+ },
43
+ "ema_rate": [
44
+ 0.9999
45
+ ],
46
+ "fp16_mode": "inflat_all",
47
+ "fp16_scale_growth": 0.001,
48
+ "grad_clip": {
49
+ "name": "AdaptiveGradClipper",
50
+ "args": {
51
+ "max_norm": 1.0,
52
+ "clip_percentile": 95
53
+ }
54
+ },
55
+ "i_log": 500,
56
+ "i_sample": 10000,
57
+ "i_save": 10000,
58
+ "p_uncond": 0.1,
59
+ "t_schedule": {
60
+ "name": "logitNormal",
61
+ "args": {
62
+ "mean": 1.0,
63
+ "std": 1.0
64
+ }
65
+ },
66
+ "sigma_min": 1e-5,
67
+ "text_cond_model": "openai/clip-vit-large-patch14"
68
+ }
69
+ }
70
+ }
configs/vae/slat_vae_dec_mesh_swin8_B_64l8_fp16.json ADDED
@@ -0,0 +1,73 @@
1
+ {
2
+ "models": {
3
+ "decoder": {
4
+ "name": "ElasticSLatMeshDecoder",
5
+ "args": {
6
+ "resolution": 64,
7
+ "model_channels": 768,
8
+ "latent_channels": 8,
9
+ "num_blocks": 12,
10
+ "num_heads": 12,
11
+ "mlp_ratio": 4,
12
+ "attn_mode": "swin",
13
+ "window_size": 8,
14
+ "use_fp16": true,
15
+ "representation_config": {
16
+ "use_color": true
17
+ }
18
+ }
19
+ }
20
+ },
21
+ "dataset": {
22
+ "name": "Slat2RenderGeo",
23
+ "args": {
24
+ "image_size": 512,
25
+ "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
26
+ "min_aesthetic_score": 4.5,
27
+ "max_num_voxels": 32768
28
+ }
29
+ },
30
+ "trainer": {
31
+ "name": "SLatVaeMeshDecoderTrainer",
32
+ "args": {
33
+ "max_steps": 1000000,
34
+ "batch_size_per_gpu": 4,
35
+ "batch_split": 4,
36
+ "optimizer": {
37
+ "name": "AdamW",
38
+ "args": {
39
+ "lr": 1e-4,
40
+ "weight_decay": 0.0
41
+ }
42
+ },
43
+ "ema_rate": [
44
+ 0.9999
45
+ ],
46
+ "fp16_mode": "inflat_all",
47
+ "fp16_scale_growth": 0.001,
48
+ "elastic": {
49
+ "name": "LinearMemoryController",
50
+ "args": {
51
+ "target_ratio": 0.75,
52
+ "max_mem_ratio_start": 0.5
53
+ }
54
+ },
55
+ "grad_clip": {
56
+ "name": "AdaptiveGradClipper",
57
+ "args": {
58
+ "max_norm": 1.0,
59
+ "clip_percentile": 95
60
+ }
61
+ },
62
+ "i_log": 500,
63
+ "i_sample": 10000,
64
+ "i_save": 10000,
65
+ "lambda_ssim": 0.2,
66
+ "lambda_lpips": 0.2,
67
+ "lambda_tsdf": 0.01,
68
+ "lambda_depth": 10.0,
69
+ "lambda_color": 0.1,
70
+ "depth_loss_type": "smooth_l1"
71
+ }
72
+ }
73
+ }
configs/vae/slat_vae_dec_rf_swin8_B_64l8_fp16.json ADDED
@@ -0,0 +1,71 @@
1
+ {
2
+ "models": {
3
+ "decoder": {
4
+ "name": "ElasticSLatRadianceFieldDecoder",
5
+ "args": {
6
+ "resolution": 64,
7
+ "model_channels": 768,
8
+ "latent_channels": 8,
9
+ "num_blocks": 12,
10
+ "num_heads": 12,
11
+ "mlp_ratio": 4,
12
+ "attn_mode": "swin",
13
+ "window_size": 8,
14
+ "use_fp16": true,
15
+ "representation_config": {
16
+ "rank": 16,
17
+ "dim": 8
18
+ }
19
+ }
20
+ }
21
+ },
22
+ "dataset": {
23
+ "name": "SLat2Render",
24
+ "args": {
25
+ "image_size": 512,
26
+ "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
27
+ "min_aesthetic_score": 4.5,
28
+ "max_num_voxels": 32768
29
+ }
30
+ },
31
+ "trainer": {
32
+ "name": "SLatVaeRadianceFieldDecoderTrainer",
33
+ "args": {
34
+ "max_steps": 1000000,
35
+ "batch_size_per_gpu": 4,
36
+ "batch_split": 2,
37
+ "optimizer": {
38
+ "name": "AdamW",
39
+ "args": {
40
+ "lr": 1e-4,
41
+ "weight_decay": 0.0
42
+ }
43
+ },
44
+ "ema_rate": [
45
+ 0.9999
46
+ ],
47
+ "fp16_mode": "inflat_all",
48
+ "fp16_scale_growth": 0.001,
49
+ "elastic": {
50
+ "name": "LinearMemoryController",
51
+ "args": {
52
+ "target_ratio": 0.75,
53
+ "max_mem_ratio_start": 0.5
54
+ }
55
+ },
56
+ "grad_clip": {
57
+ "name": "AdaptiveGradClipper",
58
+ "args": {
59
+ "max_norm": 1.0,
60
+ "clip_percentile": 95
61
+ }
62
+ },
63
+ "i_log": 500,
64
+ "i_sample": 10000,
65
+ "i_save": 10000,
66
+ "loss_type": "l1",
67
+ "lambda_ssim": 0.2,
68
+ "lambda_lpips": 0.2
69
+ }
70
+ }
71
+ }
configs/vae/slat_vae_enc_dec_gs_swin8_B_64l8_fp16.json ADDED
@@ -0,0 +1,105 @@
1
+ {
2
+ "models": {
3
+ "encoder": {
4
+ "name": "ElasticSLatEncoder",
5
+ "args": {
6
+ "resolution": 64,
7
+ "in_channels": 1024,
8
+ "model_channels": 768,
9
+ "latent_channels": 8,
10
+ "num_blocks": 12,
11
+ "num_heads": 12,
12
+ "mlp_ratio": 4,
13
+ "attn_mode": "swin",
14
+ "window_size": 8,
15
+ "use_fp16": true
16
+ }
17
+ },
18
+ "decoder": {
19
+ "name": "ElasticSLatGaussianDecoder",
20
+ "args": {
21
+ "resolution": 64,
22
+ "model_channels": 768,
23
+ "latent_channels": 8,
24
+ "num_blocks": 12,
25
+ "num_heads": 12,
26
+ "mlp_ratio": 4,
27
+ "attn_mode": "swin",
28
+ "window_size": 8,
29
+ "use_fp16": true,
30
+ "representation_config": {
31
+ "lr": {
32
+ "_xyz": 1.0,
33
+ "_features_dc": 1.0,
34
+ "_opacity": 1.0,
35
+ "_scaling": 1.0,
36
+ "_rotation": 0.1
37
+ },
38
+ "perturb_offset": true,
39
+ "voxel_size": 1.5,
40
+ "num_gaussians": 32,
41
+ "2d_filter_kernel_size": 0.1,
42
+ "3d_filter_kernel_size": 9e-4,
43
+ "scaling_bias": 4e-3,
44
+ "opacity_bias": 0.1,
45
+ "scaling_activation": "softplus"
46
+ }
47
+ }
48
+ }
49
+ },
50
+ "dataset": {
51
+ "name": "SparseFeat2Render",
52
+ "args": {
53
+ "image_size": 512,
54
+ "model": "dinov2_vitl14_reg",
55
+ "resolution": 64,
56
+ "min_aesthetic_score": 4.5,
57
+ "max_num_voxels": 32768
58
+ }
59
+ },
60
+ "trainer": {
61
+ "name": "SLatVaeGaussianTrainer",
62
+ "args": {
63
+ "max_steps": 1000000,
64
+ "batch_size_per_gpu": 4,
65
+ "batch_split": 2,
66
+ "optimizer": {
67
+ "name": "AdamW",
68
+ "args": {
69
+ "lr": 1e-4,
70
+ "weight_decay": 0.0
71
+ }
72
+ },
73
+ "ema_rate": [
74
+ 0.9999
75
+ ],
76
+ "fp16_mode": "inflat_all",
77
+ "fp16_scale_growth": 0.001,
78
+ "elastic": {
79
+ "name": "LinearMemoryController",
80
+ "args": {
81
+ "target_ratio": 0.75,
82
+ "max_mem_ratio_start": 0.5
83
+ }
84
+ },
85
+ "grad_clip": {
86
+ "name": "AdaptiveGradClipper",
87
+ "args": {
88
+ "max_norm": 1.0,
89
+ "clip_percentile": 95
90
+ }
91
+ },
92
+ "i_log": 500,
93
+ "i_sample": 10000,
94
+ "i_save": 10000,
95
+ "loss_type": "l1",
96
+ "lambda_ssim": 0.2,
97
+ "lambda_lpips": 0.2,
98
+ "lambda_kl": 1e-06,
99
+ "regularizations": {
100
+ "lambda_vol": 10000.0,
101
+ "lambda_opacity": 0.001
102
+ }
103
+ }
104
+ }
105
+ }
configs/vae/ss_vae_conv3d_16l8_fp16.json ADDED
@@ -0,0 +1,65 @@
1
+ {
2
+ "models": {
3
+ "encoder": {
4
+ "name": "SparseStructureEncoder",
5
+ "args": {
6
+ "in_channels": 1,
7
+ "latent_channels": 8,
8
+ "num_res_blocks": 2,
9
+ "num_res_blocks_middle": 2,
10
+ "channels": [32, 128, 512],
11
+ "use_fp16": true
12
+ }
13
+ },
14
+ "decoder": {
15
+ "name": "SparseStructureDecoder",
16
+ "args": {
17
+ "out_channels": 1,
18
+ "latent_channels": 8,
19
+ "num_res_blocks": 2,
20
+ "num_res_blocks_middle": 2,
21
+ "channels": [512, 128, 32],
22
+ "use_fp16": true
23
+ }
24
+ }
25
+ },
26
+ "dataset": {
27
+ "name": "SparseStructure",
28
+ "args": {
29
+ "resolution": 64,
30
+ "min_aesthetic_score": 4.5
31
+ }
32
+ },
33
+ "trainer": {
34
+ "name": "SparseStructureVaeTrainer",
35
+ "args": {
36
+ "max_steps": 1000000,
37
+ "batch_size_per_gpu": 4,
38
+ "batch_split": 1,
39
+ "optimizer": {
40
+ "name": "AdamW",
41
+ "args": {
42
+ "lr": 1e-4,
43
+ "weight_decay": 0.0
44
+ }
45
+ },
46
+ "ema_rate": [
47
+ 0.9999
48
+ ],
49
+ "fp16_mode": "inflat_all",
50
+ "fp16_scale_growth": 0.001,
51
+ "grad_clip": {
52
+ "name": "AdaptiveGradClipper",
53
+ "args": {
54
+ "max_norm": 1.0,
55
+ "clip_percentile": 95
56
+ }
57
+ },
58
+ "i_log": 500,
59
+ "i_sample": 10000,
60
+ "i_save": 10000,
61
+ "loss_type": "dice",
62
+ "lambda_kl": 0.001
63
+ }
64
+ }
65
+ }
dataset_toolkits/blender_script/io_scene_usdz.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec07ab6125fe0a021ed08c64169eceda126330401aba3d494d5203d26ac4b093
3
+ size 34685
dataset_toolkits/blender_script/render.py ADDED
@@ -0,0 +1,528 @@
1
+ import argparse, sys, os, math, re, glob
2
+ from typing import *
3
+ import bpy
4
+ from mathutils import Vector, Matrix
5
+ import numpy as np
6
+ import json
7
+ import glob
8
+
9
+
10
+ """=============== BLENDER ==============="""
11
+
12
+ IMPORT_FUNCTIONS: Dict[str, Callable] = {
13
+ "obj": bpy.ops.import_scene.obj,
14
+ "glb": bpy.ops.import_scene.gltf,
15
+ "gltf": bpy.ops.import_scene.gltf,
16
+ "usd": bpy.ops.import_scene.usd,
17
+ "fbx": bpy.ops.import_scene.fbx,
18
+ "stl": bpy.ops.import_mesh.stl,
19
+ "usda": bpy.ops.import_scene.usda,
20
+ "dae": bpy.ops.wm.collada_import,
21
+ "ply": bpy.ops.import_mesh.ply,
22
+ "abc": bpy.ops.wm.alembic_import,
23
+ "blend": bpy.ops.wm.append,
24
+ }
25
+
26
+ EXT = {
27
+ 'PNG': 'png',
28
+ 'JPEG': 'jpg',
29
+ 'OPEN_EXR': 'exr',
30
+ 'TIFF': 'tiff',
31
+ 'BMP': 'bmp',
32
+ 'HDR': 'hdr',
33
+ 'TARGA': 'tga'
34
+ }
35
+
36
+ def init_render(engine='CYCLES', resolution=512, geo_mode=False):
37
+ bpy.context.scene.render.engine = engine
38
+ bpy.context.scene.render.resolution_x = resolution
39
+ bpy.context.scene.render.resolution_y = resolution
40
+ bpy.context.scene.render.resolution_percentage = 100
41
+ bpy.context.scene.render.image_settings.file_format = 'PNG'
42
+ bpy.context.scene.render.image_settings.color_mode = 'RGBA'
43
+ bpy.context.scene.render.film_transparent = True
44
+
45
+ bpy.context.scene.cycles.device = 'GPU'
46
+ bpy.context.scene.cycles.samples = 128 if not geo_mode else 1
47
+ bpy.context.scene.cycles.filter_type = 'BOX'
48
+ bpy.context.scene.cycles.filter_width = 1
49
+ bpy.context.scene.cycles.diffuse_bounces = 1
50
+ bpy.context.scene.cycles.glossy_bounces = 1
51
+ bpy.context.scene.cycles.transparent_max_bounces = 3 if not geo_mode else 0
52
+ bpy.context.scene.cycles.transmission_bounces = 3 if not geo_mode else 1
53
+ bpy.context.scene.cycles.use_denoising = True
54
+
55
+ bpy.context.preferences.addons['cycles'].preferences.get_devices()
56
+ bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'
57
+
58
+ def init_nodes(save_depth=False, save_normal=False, save_albedo=False, save_mist=False):
59
+ if not any([save_depth, save_normal, save_albedo, save_mist]):
60
+ return {}, {}
61
+ outputs = {}
62
+ spec_nodes = {}
63
+
64
+ bpy.context.scene.use_nodes = True
65
+ bpy.context.scene.view_layers['View Layer'].use_pass_z = save_depth
66
+ bpy.context.scene.view_layers['View Layer'].use_pass_normal = save_normal
67
+ bpy.context.scene.view_layers['View Layer'].use_pass_diffuse_color = save_albedo
68
+ bpy.context.scene.view_layers['View Layer'].use_pass_mist = save_mist
69
+
70
+ nodes = bpy.context.scene.node_tree.nodes
71
+ links = bpy.context.scene.node_tree.links
72
+ for n in nodes:
73
+ nodes.remove(n)
74
+
75
+ render_layers = nodes.new('CompositorNodeRLayers')
76
+
77
+ if save_depth:
78
+ depth_file_output = nodes.new('CompositorNodeOutputFile')
79
+ depth_file_output.base_path = ''
80
+ depth_file_output.file_slots[0].use_node_format = True
81
+ depth_file_output.format.file_format = 'PNG'
82
+ depth_file_output.format.color_depth = '16'
83
+ depth_file_output.format.color_mode = 'BW'
84
+ # Remap to 0-1
85
+ map = nodes.new(type="CompositorNodeMapRange")
86
+ map.inputs[1].default_value = 0 # (min value you will be getting)
87
+ map.inputs[2].default_value = 10 # (max value you will be getting)
88
+ map.inputs[3].default_value = 0 # (min value you will map to)
89
+ map.inputs[4].default_value = 1 # (max value you will map to)
90
+
91
+ links.new(render_layers.outputs['Depth'], map.inputs[0])
92
+ links.new(map.outputs[0], depth_file_output.inputs[0])
93
+
94
+ outputs['depth'] = depth_file_output
95
+ spec_nodes['depth_map'] = map
96
+
97
+ if save_normal:
98
+ normal_file_output = nodes.new('CompositorNodeOutputFile')
99
+ normal_file_output.base_path = ''
100
+ normal_file_output.file_slots[0].use_node_format = True
101
+ normal_file_output.format.file_format = 'OPEN_EXR'
102
+ normal_file_output.format.color_mode = 'RGB'
103
+ normal_file_output.format.color_depth = '16'
104
+
105
+ links.new(render_layers.outputs['Normal'], normal_file_output.inputs[0])
106
+
107
+ outputs['normal'] = normal_file_output
108
+
109
+ if save_albedo:
110
+ albedo_file_output = nodes.new('CompositorNodeOutputFile')
111
+ albedo_file_output.base_path = ''
112
+ albedo_file_output.file_slots[0].use_node_format = True
113
+ albedo_file_output.format.file_format = 'PNG'
114
+ albedo_file_output.format.color_mode = 'RGBA'
115
+ albedo_file_output.format.color_depth = '8'
116
+
117
+ alpha_albedo = nodes.new('CompositorNodeSetAlpha')
118
+
119
+ links.new(render_layers.outputs['DiffCol'], alpha_albedo.inputs['Image'])
120
+ links.new(render_layers.outputs['Alpha'], alpha_albedo.inputs['Alpha'])
121
+ links.new(alpha_albedo.outputs['Image'], albedo_file_output.inputs[0])
122
+
123
+ outputs['albedo'] = albedo_file_output
124
+
125
+ if save_mist:
126
+ bpy.data.worlds['World'].mist_settings.start = 0
127
+ bpy.data.worlds['World'].mist_settings.depth = 10
128
+
129
+ mist_file_output = nodes.new('CompositorNodeOutputFile')
130
+ mist_file_output.base_path = ''
131
+ mist_file_output.file_slots[0].use_node_format = True
132
+ mist_file_output.format.file_format = 'PNG'
133
+ mist_file_output.format.color_mode = 'BW'
134
+ mist_file_output.format.color_depth = '16'
135
+
136
+ links.new(render_layers.outputs['Mist'], mist_file_output.inputs[0])
137
+
138
+ outputs['mist'] = mist_file_output
139
+
140
+ return outputs, spec_nodes
141
+
142
+ def init_scene() -> None:
143
+ """Resets the scene to a clean state.
144
+
145
+ Returns:
146
+ None
147
+ """
148
+ # delete everything
149
+ for obj in bpy.data.objects:
150
+ bpy.data.objects.remove(obj, do_unlink=True)
151
+
152
+ # delete all the materials
153
+ for material in bpy.data.materials:
154
+ bpy.data.materials.remove(material, do_unlink=True)
155
+
156
+ # delete all the textures
157
+ for texture in bpy.data.textures:
158
+ bpy.data.textures.remove(texture, do_unlink=True)
159
+
160
+ # delete all the images
161
+ for image in bpy.data.images:
162
+ bpy.data.images.remove(image, do_unlink=True)
163
+
164
+ def init_camera():
165
+ cam = bpy.data.objects.new('Camera', bpy.data.cameras.new('Camera'))
166
+ bpy.context.collection.objects.link(cam)
167
+ bpy.context.scene.camera = cam
168
+ cam.data.sensor_height = cam.data.sensor_width = 32
169
+ cam_constraint = cam.constraints.new(type='TRACK_TO')
170
+ cam_constraint.track_axis = 'TRACK_NEGATIVE_Z'
171
+ cam_constraint.up_axis = 'UP_Y'
172
+ cam_empty = bpy.data.objects.new("Empty", None)
173
+ cam_empty.location = (0, 0, 0)
174
+ bpy.context.scene.collection.objects.link(cam_empty)
175
+ cam_constraint.target = cam_empty
176
+ return cam
177
+
178
+ def init_lighting():
179
+ # Clear existing lights
180
+ bpy.ops.object.select_all(action="DESELECT")
181
+ bpy.ops.object.select_by_type(type="LIGHT")
182
+ bpy.ops.object.delete()
183
+
184
+ # Create key light
185
+ default_light = bpy.data.objects.new("Default_Light", bpy.data.lights.new("Default_Light", type="POINT"))
186
+ bpy.context.collection.objects.link(default_light)
187
+ default_light.data.energy = 1000
188
+ default_light.location = (4, 1, 6)
189
+ default_light.rotation_euler = (0, 0, 0)
190
+
191
+ # create top light
192
+ top_light = bpy.data.objects.new("Top_Light", bpy.data.lights.new("Top_Light", type="AREA"))
193
+ bpy.context.collection.objects.link(top_light)
194
+ top_light.data.energy = 10000
195
+ top_light.location = (0, 0, 10)
196
+ top_light.scale = (100, 100, 100)
197
+
198
+ # create bottom light
199
+ bottom_light = bpy.data.objects.new("Bottom_Light", bpy.data.lights.new("Bottom_Light", type="AREA"))
200
+ bpy.context.collection.objects.link(bottom_light)
201
+ bottom_light.data.energy = 1000
202
+ bottom_light.location = (0, 0, -10)
203
+ bottom_light.rotation_euler = (0, 0, 0)
204
+
205
+ return {
206
+ "default_light": default_light,
207
+ "top_light": top_light,
208
+ "bottom_light": bottom_light
209
+ }
210
+
211
+
212
+ def load_object(object_path: str) -> None:
213
+ """Loads a model with a supported file extension into the scene.
214
+
215
+ Args:
216
+ object_path (str): Path to the model file.
217
+
218
+ Raises:
219
+ ValueError: If the file extension is not supported.
220
+
221
+ Returns:
222
+ None
223
+ """
224
+ file_extension = object_path.split(".")[-1].lower()
225
+ if file_extension is None:
226
+ raise ValueError(f"Unsupported file type: {object_path}")
227
+
228
+ if file_extension == "usdz":
229
+ # install usdz io package
230
+ dirname = os.path.dirname(os.path.realpath(__file__))
231
+ usdz_package = os.path.join(dirname, "io_scene_usdz.zip")
232
+ bpy.ops.preferences.addon_install(filepath=usdz_package)
233
+ # enable it
234
+ addon_name = "io_scene_usdz"
235
+ bpy.ops.preferences.addon_enable(module=addon_name)
236
+ # import the usdz
237
+ from io_scene_usdz.import_usdz import import_usdz
238
+
239
+ import_usdz(context, filepath=object_path, materials=True, animations=True)
240
+ return None
241
+
242
+ # load from existing import functions
243
+ import_function = IMPORT_FUNCTIONS[file_extension]
244
+
245
+ print(f"Loading object from {object_path}")
246
+ if file_extension == "blend":
247
+ import_function(directory=object_path, link=False)
248
+ elif file_extension in {"glb", "gltf"}:
249
+ import_function(filepath=object_path, merge_vertices=True, import_shading='NORMALS')
250
+ else:
251
+ import_function(filepath=object_path)
252
+
253
+ def delete_invisible_objects() -> None:
254
+ """Deletes all invisible objects in the scene.
255
+
256
+ Returns:
257
+ None
258
+ """
259
+ # bpy.ops.object.mode_set(mode="OBJECT")
260
+ bpy.ops.object.select_all(action="DESELECT")
261
+ for obj in bpy.context.scene.objects:
262
+ if obj.hide_viewport or obj.hide_render:
263
+ obj.hide_viewport = False
264
+ obj.hide_render = False
265
+ obj.hide_select = False
266
+ obj.select_set(True)
267
+ bpy.ops.object.delete()
268
+
269
+ # Delete invisible collections
270
+ invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
271
+ for col in invisible_collections:
272
+ bpy.data.collections.remove(col)
273
+
274
+ def split_mesh_normal():
275
+ bpy.ops.object.select_all(action="DESELECT")
276
+ objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
277
+ bpy.context.view_layer.objects.active = objs[0]
278
+ for obj in objs:
279
+ obj.select_set(True)
280
+ bpy.ops.object.mode_set(mode="EDIT")
281
+ bpy.ops.mesh.select_all(action='SELECT')
282
+ bpy.ops.mesh.split_normals()
283
+ bpy.ops.object.mode_set(mode='OBJECT')
284
+ bpy.ops.object.select_all(action="DESELECT")
285
+
286
+ def delete_custom_normals():
287
+ for this_obj in bpy.data.objects:
288
+ if this_obj.type == "MESH":
289
+ bpy.context.view_layer.objects.active = this_obj
290
+ bpy.ops.mesh.customdata_custom_splitnormals_clear()
291
+
292
+ def override_material():
293
+ new_mat = bpy.data.materials.new(name="Override0123456789")
294
+ new_mat.use_nodes = True
295
+ new_mat.node_tree.nodes.clear()
296
+ bsdf = new_mat.node_tree.nodes.new('ShaderNodeBsdfDiffuse')
297
+ bsdf.inputs[0].default_value = (0.5, 0.5, 0.5, 1)
298
+ bsdf.inputs[1].default_value = 1
299
+ output = new_mat.node_tree.nodes.new('ShaderNodeOutputMaterial')
300
+ new_mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
301
+ bpy.context.scene.view_layers['View Layer'].material_override = new_mat
302
+
303
+ def unhide_all_objects() -> None:
304
+ """Unhides all objects in the scene.
305
+
306
+ Returns:
307
+ None
308
+ """
309
+ for obj in bpy.context.scene.objects:
310
+ obj.hide_set(False)
311
+
312
+ def convert_to_meshes() -> None:
313
+ """Converts all objects in the scene to meshes.
314
+
315
+ Returns:
316
+ None
317
+ """
318
+ bpy.ops.object.select_all(action="DESELECT")
319
+ bpy.context.view_layer.objects.active = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"][0]
320
+ for obj in bpy.context.scene.objects:
321
+ obj.select_set(True)
322
+ bpy.ops.object.convert(target="MESH")
323
+
324
+ def triangulate_meshes() -> None:
325
+ """Triangulates all meshes in the scene.
326
+
327
+ Returns:
328
+ None
329
+ """
330
+ bpy.ops.object.select_all(action="DESELECT")
331
+ objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
332
+ bpy.context.view_layer.objects.active = objs[0]
333
+ for obj in objs:
334
+ obj.select_set(True)
335
+ bpy.ops.object.mode_set(mode="EDIT")
336
+ bpy.ops.mesh.reveal()
337
+ bpy.ops.mesh.select_all(action="SELECT")
338
+ bpy.ops.mesh.quads_convert_to_tris(quad_method="BEAUTY", ngon_method="BEAUTY")
339
+ bpy.ops.object.mode_set(mode="OBJECT")
340
+ bpy.ops.object.select_all(action="DESELECT")
341
+
342
+ def scene_bbox() -> Tuple[Vector, Vector]:
343
+ """Returns the bounding box of the scene.
344
+
345
+ Taken from Shap-E rendering script
346
+ (https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)
347
+
348
+ Returns:
349
+ Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
350
+ """
351
+ bbox_min = (math.inf,) * 3
352
+ bbox_max = (-math.inf,) * 3
353
+ found = False
354
+ scene_meshes = [obj for obj in bpy.context.scene.objects.values() if isinstance(obj.data, bpy.types.Mesh)]
355
+ for obj in scene_meshes:
356
+ found = True
357
+ for coord in obj.bound_box:
358
+ coord = Vector(coord)
359
+ coord = obj.matrix_world @ coord
360
+ bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
361
+ bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
362
+ if not found:
363
+ raise RuntimeError("no objects in scene to compute bounding box for")
364
+ return Vector(bbox_min), Vector(bbox_max)
365
+
366
+ def normalize_scene() -> Tuple[float, Vector]:
367
+ """Normalizes the scene by scaling and translating it to fit in a unit cube centered
368
+ at the origin.
369
+
370
+ Mostly taken from the Point-E / Shap-E rendering script
371
+ (https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
372
+ but fix for multiple root objects: (see bug report here:
373
+ https://github.com/openai/shap-e/pull/60).
374
+
375
+ Returns:
376
+ Tuple[float, Vector]: The scale factor and the offset applied to the scene.
377
+ """
378
+ scene_root_objects = [obj for obj in bpy.context.scene.objects.values() if not obj.parent]
379
+ if len(scene_root_objects) > 1:
380
+ # create an empty object to be used as a parent for all root objects
381
+ scene = bpy.data.objects.new("ParentEmpty", None)
382
+ bpy.context.scene.collection.objects.link(scene)
383
+
384
+ # parent all root objects to the empty object
385
+ for obj in scene_root_objects:
386
+ obj.parent = scene
387
+ else:
388
+ scene = scene_root_objects[0]
389
+
390
+ bbox_min, bbox_max = scene_bbox()
391
+ scale = 1 / max(bbox_max - bbox_min)
392
+ scene.scale = scene.scale * scale
393
+
394
+ # Apply scale to matrix_world.
395
+ bpy.context.view_layer.update()
396
+ bbox_min, bbox_max = scene_bbox()
397
+ offset = -(bbox_min + bbox_max) / 2
398
+ scene.matrix_world.translation += offset
399
+ bpy.ops.object.select_all(action="DESELECT")
400
+
401
+ return scale, offset
402
+
403
+ def get_transform_matrix(obj: bpy.types.Object) -> list:
404
+ pos, rt, _ = obj.matrix_world.decompose()
405
+ rt = rt.to_matrix()
406
+ matrix = []
407
+ for ii in range(3):
408
+ a = []
409
+ for jj in range(3):
410
+ a.append(rt[ii][jj])
411
+ a.append(pos[ii])
412
+ matrix.append(a)
413
+ matrix.append([0, 0, 0, 1])
414
+ return matrix
415
+
416
+ def main(arg):
417
+ os.makedirs(arg.output_folder, exist_ok=True)
418
+
419
+ # Initialize context
420
+ init_render(engine=arg.engine, resolution=arg.resolution, geo_mode=arg.geo_mode)
421
+ outputs, spec_nodes = init_nodes(
422
+ save_depth=arg.save_depth,
423
+ save_normal=arg.save_normal,
424
+ save_albedo=arg.save_albedo,
425
+ save_mist=arg.save_mist
426
+ )
427
+ if arg.object.endswith(".blend"):
428
+ delete_invisible_objects()
429
+ else:
430
+ init_scene()
431
+ load_object(arg.object)
432
+ if arg.split_normal:
433
+ split_mesh_normal()
434
+ # delete_custom_normals()
435
+ print('[INFO] Scene initialized.')
436
+
437
+ # normalize scene
438
+ scale, offset = normalize_scene()
439
+ print('[INFO] Scene normalized.')
440
+
441
+ # Initialize camera and lighting
442
+ cam = init_camera()
443
+ init_lighting()
444
+ print('[INFO] Camera and lighting initialized.')
445
+
446
+ # Override material
447
+ if arg.geo_mode:
448
+ override_material()
449
+
450
+ # Create a list of views
451
+ to_export = {
452
+ "aabb": [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
453
+ "scale": scale,
454
+ "offset": [offset.x, offset.y, offset.z],
455
+ "frames": []
456
+ }
457
+ views = json.loads(arg.views)
458
+ for i, view in enumerate(views):
459
+ cam.location = (
460
+ view['radius'] * np.cos(view['yaw']) * np.cos(view['pitch']),
461
+ view['radius'] * np.sin(view['yaw']) * np.cos(view['pitch']),
462
+ view['radius'] * np.sin(view['pitch'])
463
+ )
464
+ cam.data.lens = 16 / np.tan(view['fov'] / 2)
465
+
466
+ if arg.save_depth:
467
+ spec_nodes['depth_map'].inputs[1].default_value = view['radius'] - 0.5 * np.sqrt(3)
468
+ spec_nodes['depth_map'].inputs[2].default_value = view['radius'] + 0.5 * np.sqrt(3)
469
+
470
+ bpy.context.scene.render.filepath = os.path.join(arg.output_folder, f'{i:03d}.png')
471
+ for name, output in outputs.items():
472
+ output.file_slots[0].path = os.path.join(arg.output_folder, f'{i:03d}_{name}')
473
+
474
+ # Render the scene
475
+ bpy.ops.render.render(write_still=True)
476
+ bpy.context.view_layer.update()
477
+ for name, output in outputs.items():
478
+ ext = EXT[output.format.file_format]
479
+ path = glob.glob(f'{output.file_slots[0].path}*.{ext}')[0]
480
+ os.rename(path, f'{output.file_slots[0].path}.{ext}')
481
+
482
+ # Save camera parameters
483
+ metadata = {
484
+ "file_path": f'{i:03d}.png',
485
+ "camera_angle_x": view['fov'],
486
+ "transform_matrix": get_transform_matrix(cam)
487
+ }
488
+ if arg.save_depth:
489
+ metadata['depth'] = {
490
+ 'min': view['radius'] - 0.5 * np.sqrt(3),
491
+ 'max': view['radius'] + 0.5 * np.sqrt(3)
492
+ }
493
+ to_export["frames"].append(metadata)
494
+
495
+ # Save the camera parameters
496
+ with open(os.path.join(arg.output_folder, 'transforms.json'), 'w') as f:
497
+ json.dump(to_export, f, indent=4)
498
+
499
+ if arg.save_mesh:
500
+ # triangulate meshes
501
+ unhide_all_objects()
502
+ convert_to_meshes()
503
+ triangulate_meshes()
504
+ print('[INFO] Meshes triangulated.')
505
+
506
+ # export ply mesh
507
+ bpy.ops.export_mesh.ply(filepath=os.path.join(arg.output_folder, 'mesh.ply'))
508
+
509
+
510
+ if __name__ == '__main__':
511
+ parser = argparse.ArgumentParser(description='Renders given obj file by rotation a camera around it.')
512
+ parser.add_argument('--views', type=str, help='JSON string of views. Contains a list of {yaw, pitch, radius, fov} object.')
513
+ parser.add_argument('--object', type=str, help='Path to the 3D model file to be rendered.')
514
+ parser.add_argument('--output_folder', type=str, default='/tmp', help='The path the output will be dumped to.')
515
+ parser.add_argument('--resolution', type=int, default=512, help='Resolution of the images.')
516
+ parser.add_argument('--engine', type=str, default='CYCLES', help='Blender internal engine for rendering. E.g. CYCLES, BLENDER_EEVEE, ...')
517
+ parser.add_argument('--geo_mode', action='store_true', help='Geometry mode for rendering.')
518
+ parser.add_argument('--save_depth', action='store_true', help='Save the depth maps.')
519
+ parser.add_argument('--save_normal', action='store_true', help='Save the normal maps.')
520
+ parser.add_argument('--save_albedo', action='store_true', help='Save the albedo maps.')
521
+ parser.add_argument('--save_mist', action='store_true', help='Save the mist distance maps.')
522
+ parser.add_argument('--split_normal', action='store_true', help='Split the normals of the mesh.')
523
+ parser.add_argument('--save_mesh', action='store_true', help='Save the mesh as a .ply file.')
524
+ argv = sys.argv[sys.argv.index("--") + 1:]
525
+ args = parser.parse_args(argv)
526
+
527
+ main(args)
528
+
dataset_toolkits/build_metadata.py ADDED
@@ -0,0 +1,270 @@
1
+ import os
2
+ import shutil
3
+ import sys
4
+ import time
5
+ import importlib
6
+ import argparse
7
+ import numpy as np
8
+ import pandas as pd
9
+ from tqdm import tqdm
10
+ from easydict import EasyDict as edict
11
+ from concurrent.futures import ThreadPoolExecutor
12
+ import utils3d
13
+
14
+ def get_first_directory(path):
15
+ with os.scandir(path) as it:
16
+ for entry in it:
17
+ if entry.is_dir():
18
+ return entry.name
19
+ return None
20
+
21
+ def need_process(key):
22
+ return key in opt.field or opt.field == ['all']
23
+
24
+ if __name__ == '__main__':
25
+ dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
26
+
27
+ parser = argparse.ArgumentParser()
28
+ parser.add_argument('--output_dir', type=str, required=True,
29
+ help='Directory to save the metadata')
30
+ parser.add_argument('--field', type=str, default='all',
31
+ help='Fields to process, separated by commas')
32
+ parser.add_argument('--from_file', action='store_true',
33
+ help='Build metadata from files instead of from processing records. ' +
34
+ 'Useful when some processing fails to generate records but the files already exist.')
35
+ dataset_utils.add_args(parser)
36
+ opt = parser.parse_args(sys.argv[2:])
37
+ opt = edict(vars(opt))
38
+
39
+ os.makedirs(opt.output_dir, exist_ok=True)
40
+ os.makedirs(os.path.join(opt.output_dir, 'merged_records'), exist_ok=True)
41
+
42
+ opt.field = opt.field.split(',')
43
+
44
+ timestamp = str(int(time.time()))
45
+
46
+ # get file list
47
+ if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
48
+ print('Loading previous metadata...')
49
+ metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
50
+ else:
51
+ metadata = dataset_utils.get_metadata(**opt)
52
+ metadata.set_index('sha256', inplace=True)
53
+
54
+ # merge downloaded
55
+ df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('downloaded_') and f.endswith('.csv')]
56
+ df_parts = []
57
+ for f in df_files:
58
+ try:
59
+ df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
60
+ except:
61
+ pass
62
+ if len(df_parts) > 0:
63
+ df = pd.concat(df_parts)
64
+ df.set_index('sha256', inplace=True)
65
+ if 'local_path' in metadata.columns:
66
+ metadata.update(df, overwrite=True)
67
+ else:
68
+ metadata = metadata.join(df, on='sha256', how='left')
69
+ for f in df_files:
70
+ shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
71
+
72
+ # detect models
73
+ image_models = []
74
+ if os.path.exists(os.path.join(opt.output_dir, 'features')):
75
+ image_models = os.listdir(os.path.join(opt.output_dir, 'features'))
76
+ latent_models = []
77
+ if os.path.exists(os.path.join(opt.output_dir, 'latents')):
78
+ latent_models = os.listdir(os.path.join(opt.output_dir, 'latents'))
79
+ ss_latent_models = []
80
+ if os.path.exists(os.path.join(opt.output_dir, 'ss_latents')):
81
+ ss_latent_models = os.listdir(os.path.join(opt.output_dir, 'ss_latents'))
82
+ print(f'Image models: {image_models}')
83
+ print(f'Latent models: {latent_models}')
84
+ print(f'Sparse Structure latent models: {ss_latent_models}')
85
+
86
+ if 'rendered' not in metadata.columns:
87
+ metadata['rendered'] = [False] * len(metadata)
88
+ if 'voxelized' not in metadata.columns:
89
+ metadata['voxelized'] = [False] * len(metadata)
90
+ if 'num_voxels' not in metadata.columns:
91
+ metadata['num_voxels'] = [0] * len(metadata)
92
+ if 'cond_rendered' not in metadata.columns:
93
+ metadata['cond_rendered'] = [False] * len(metadata)
94
+ for model in image_models:
95
+ if f'feature_{model}' not in metadata.columns:
96
+ metadata[f'feature_{model}'] = [False] * len(metadata)
97
+ for model in latent_models:
98
+ if f'latent_{model}' not in metadata.columns:
99
+ metadata[f'latent_{model}'] = [False] * len(metadata)
100
+ for model in ss_latent_models:
101
+ if f'ss_latent_{model}' not in metadata.columns:
102
+ metadata[f'ss_latent_{model}'] = [False] * len(metadata)
103
+
104
+ # merge rendered
105
+ df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('rendered_') and f.endswith('.csv')]
106
+ df_parts = []
107
+ for f in df_files:
108
+ try:
109
+ df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
110
+ except:
111
+ pass
112
+ if len(df_parts) > 0:
113
+ df = pd.concat(df_parts)
114
+ df.set_index('sha256', inplace=True)
115
+ metadata.update(df, overwrite=True)
116
+ for f in df_files:
117
+ shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
118
+
119
+ # merge voxelized
120
+ df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('voxelized_') and f.endswith('.csv')]
121
+ df_parts = []
122
+ for f in df_files:
123
+ try:
124
+ df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
125
+ except:
126
+ pass
127
+ if len(df_parts) > 0:
128
+ df = pd.concat(df_parts)
129
+ df.set_index('sha256', inplace=True)
130
+ metadata.update(df, overwrite=True)
131
+ for f in df_files:
132
+ shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
133
+
134
+ # merge cond_rendered
135
+ df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('cond_rendered_') and f.endswith('.csv')]
136
+ df_parts = []
137
+ for f in df_files:
138
+ try:
139
+ df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
140
+ except:
141
+ pass
142
+ if len(df_parts) > 0:
143
+ df = pd.concat(df_parts)
144
+ df.set_index('sha256', inplace=True)
145
+ metadata.update(df, overwrite=True)
146
+ for f in df_files:
147
+ shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
148
+
149
+ # merge features
150
+ for model in image_models:
151
+ df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'feature_{model}_') and f.endswith('.csv')]
152
+ df_parts = []
153
+ for f in df_files:
154
+ try:
155
+ df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
156
+ except:
157
+ pass
158
+ if len(df_parts) > 0:
159
+ df = pd.concat(df_parts)
160
+ df.set_index('sha256', inplace=True)
161
+ metadata.update(df, overwrite=True)
162
+ for f in df_files:
163
+ shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
164
+
165
+ # merge latents
166
+ for model in latent_models:
167
+ df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'latent_{model}_') and f.endswith('.csv')]
168
+ df_parts = []
169
+ for f in df_files:
170
+ try:
171
+ df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
172
+ except:
173
+ pass
174
+ if len(df_parts) > 0:
175
+ df = pd.concat(df_parts)
176
+ df.set_index('sha256', inplace=True)
177
+ metadata.update(df, overwrite=True)
178
+ for f in df_files:
179
+ shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
180
+
181
+ # merge sparse structure latents
182
+ for model in ss_latent_models:
183
+ df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'ss_latent_{model}_') and f.endswith('.csv')]
184
+ df_parts = []
185
+ for f in df_files:
186
+ try:
187
+ df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
188
+ except:
189
+ pass
190
+ if len(df_parts) > 0:
191
+ df = pd.concat(df_parts)
192
+ df.set_index('sha256', inplace=True)
193
+ metadata.update(df, overwrite=True)
194
+ for f in df_files:
195
+ shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
196
+
197
+ # build metadata from files
198
+ if opt.from_file:
199
+ with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
200
+ tqdm(total=len(metadata), desc="Building metadata") as pbar:
201
+ def worker(sha256):
202
+ try:
203
+ if need_process('rendered') and metadata.loc[sha256, 'rendered'] == False and \
204
+ os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')):
205
+ metadata.loc[sha256, 'rendered'] = True
206
+ if need_process('voxelized') and metadata.loc[sha256, 'rendered'] == True and metadata.loc[sha256, 'voxelized'] == False and \
207
+ os.path.exists(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply')):
208
+ try:
209
+ pts = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
210
+ metadata.loc[sha256, 'voxelized'] = True
211
+ metadata.loc[sha256, 'num_voxels'] = len(pts)
212
+ except Exception as e:
213
+ pass
214
+ if need_process('cond_rendered') and metadata.loc[sha256, 'cond_rendered'] == False and \
215
+ os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')):
216
+ metadata.loc[sha256, 'cond_rendered'] = True
217
+ for model in image_models:
218
+ if need_process(f'feature_{model}') and \
219
+ metadata.loc[sha256, f'feature_{model}'] == False and \
220
+ metadata.loc[sha256, 'rendered'] == True and \
221
+ metadata.loc[sha256, 'voxelized'] == True and \
222
+ os.path.exists(os.path.join(opt.output_dir, 'features', model, f'{sha256}.npz')):
223
+ metadata.loc[sha256, f'feature_{model}'] = True
224
+ for model in latent_models:
225
+ if need_process(f'latent_{model}') and \
226
+ metadata.loc[sha256, f'latent_{model}'] == False and \
227
+ metadata.loc[sha256, 'rendered'] == True and \
228
+ metadata.loc[sha256, 'voxelized'] == True and \
229
+ os.path.exists(os.path.join(opt.output_dir, 'latents', model, f'{sha256}.npz')):
230
+ metadata.loc[sha256, f'latent_{model}'] = True
231
+ for model in ss_latent_models:
232
+ if need_process(f'ss_latent_{model}') and \
233
+ metadata.loc[sha256, f'ss_latent_{model}'] == False and \
234
+ metadata.loc[sha256, 'voxelized'] == True and \
235
+ os.path.exists(os.path.join(opt.output_dir, 'ss_latents', model, f'{sha256}.npz')):
236
+ metadata.loc[sha256, f'ss_latent_{model}'] = True
237
+ pbar.update()
238
+ except Exception as e:
239
+ print(f'Error processing {sha256}: {e}')
240
+ pbar.update()
241
+
242
+ executor.map(worker, metadata.index)
243
+ executor.shutdown(wait=True)
244
+
245
+ # statistics
246
+ metadata.to_csv(os.path.join(opt.output_dir, 'metadata.csv'))
247
+ num_downloaded = metadata['local_path'].count() if 'local_path' in metadata.columns else 0
248
+ with open(os.path.join(opt.output_dir, 'statistics.txt'), 'w') as f:
249
+ f.write('Statistics:\n')
250
+ f.write(f' - Number of assets: {len(metadata)}\n')
251
+ f.write(f' - Number of assets downloaded: {num_downloaded}\n')
252
+ f.write(f' - Number of assets rendered: {metadata["rendered"].sum()}\n')
253
+ f.write(f' - Number of assets voxelized: {metadata["voxelized"].sum()}\n')
254
+ if len(image_models) != 0:
255
+ f.write(f' - Number of assets with image features extracted:\n')
256
+ for model in image_models:
257
+ f.write(f' - {model}: {metadata[f"feature_{model}"].sum()}\n')
258
+ if len(latent_models) != 0:
259
+ f.write(f' - Number of assets with latents extracted:\n')
260
+ for model in latent_models:
261
+ f.write(f' - {model}: {metadata[f"latent_{model}"].sum()}\n')
262
+ if len(ss_latent_models) != 0:
263
+ f.write(f' - Number of assets with sparse structure latents extracted:\n')
264
+ for model in ss_latent_models:
265
+ f.write(f' - {model}: {metadata[f"ss_latent_{model}"].sum()}\n')
266
+ f.write(f' - Number of assets with captions: {metadata["captions"].count()}\n')
267
+ f.write(f' - Number of assets with image conditions: {metadata["cond_rendered"].sum()}\n')
268
+
269
+ with open(os.path.join(opt.output_dir, 'statistics.txt'), 'r') as f:
270
+ print(f.read())
dataset_toolkits/datasets/3D-FUTURE.py ADDED
@@ -0,0 +1,97 @@
1
+ import os
2
+ import re
3
+ import argparse
4
+ import zipfile
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from tqdm import tqdm
7
+ import pandas as pd
8
+ from utils import get_file_hash
9
+
10
+
11
+ def add_args(parser: argparse.ArgumentParser):
12
+ pass
13
+
14
+
15
+ def get_metadata(**kwargs):
16
+ metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/3D-FUTURE.csv")
17
+ return metadata
18
+
19
+
20
+ def download(metadata, output_dir, **kwargs):
21
+ os.makedirs(output_dir, exist_ok=True)
22
+
23
+ if not os.path.exists(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')):
24
+ print("\033[93m")
25
+ print("3D-FUTURE has to be downloaded manually")
26
+ print(f"Please download the 3D-FUTURE-model.zip file and place it in the {output_dir}/raw directory")
27
+ print("Visit https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future for more information")
28
+ print("\033[0m")
29
+ raise FileNotFoundError("3D-FUTURE-model.zip not found")
30
+
31
+ downloaded = {}
32
+ metadata = metadata.set_index("file_identifier")
33
+ with zipfile.ZipFile(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')) as zip_ref:
34
+ all_names = zip_ref.namelist()
35
+ instances = [instance[:-1] for instance in all_names if re.match(r"^3D-FUTURE-model/[^/]+/$", instance)]
36
+ instances = list(filter(lambda x: x in metadata.index, instances))
37
+
38
+ with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
39
+ tqdm(total=len(instances), desc="Extracting") as pbar:
40
+ def worker(instance: str) -> str:
41
+ try:
42
+ instance_files = list(filter(lambda x: x.startswith(f"{instance}/") and not x.endswith("/"), all_names))
43
+ zip_ref.extractall(os.path.join(output_dir, 'raw'), members=instance_files)
44
+ sha256 = get_file_hash(os.path.join(output_dir, 'raw', f"{instance}/image.jpg"))
45
+ pbar.update()
46
+ return sha256
47
+ except Exception as e:
48
+ pbar.update()
49
+ print(f"Error extracting for {instance}: {e}")
50
+ return None
51
+
52
+ sha256s = executor.map(worker, instances)
53
+ executor.shutdown(wait=True)
54
+
55
+ for k, sha256 in zip(instances, sha256s):
56
+ if sha256 is not None:
57
+ if sha256 == metadata.loc[k, "sha256"]:
58
+ downloaded[sha256] = os.path.join("raw", f"{k}/raw_model.obj")
59
+ else:
60
+ print(f"Error downloading {k}: sha256s do not match")
61
+
62
+ return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])
63
+
64
+
65
+ def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
66
+ import os
67
+ from concurrent.futures import ThreadPoolExecutor
68
+ from tqdm import tqdm
69
+
70
+ # load metadata
71
+ metadata = metadata.to_dict('records')
72
+
73
+ # processing objects
74
+ records = []
75
+ max_workers = max_workers or os.cpu_count()
76
+ try:
77
+ with ThreadPoolExecutor(max_workers=max_workers) as executor, \
78
+ tqdm(total=len(metadata), desc=desc) as pbar:
79
+ def worker(metadatum):
80
+ try:
81
+ local_path = metadatum['local_path']
82
+ sha256 = metadatum['sha256']
83
+ file = os.path.join(output_dir, local_path)
84
+ record = func(file, sha256)
85
+ if record is not None:
86
+ records.append(record)
87
+ pbar.update()
88
+ except Exception as e:
89
+ print(f"Error processing object {sha256}: {e}")
90
+ pbar.update()
91
+
92
+ executor.map(worker, metadata)
93
+ executor.shutdown(wait=True)
94
+ except:
95
+ print("Error happened during processing.")
96
+
97
+ return pd.DataFrame.from_records(records)
dataset_toolkits/datasets/ABO.py ADDED
@@ -0,0 +1,96 @@
1
+ import os
2
+ import re
3
+ import argparse
4
+ import tarfile
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from tqdm import tqdm
7
+ import pandas as pd
8
+ from utils import get_file_hash
9
+
10
+
11
+ def add_args(parser: argparse.ArgumentParser):
12
+ pass
13
+
14
+
15
+ def get_metadata(**kwargs):
16
+ metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ABO.csv")
17
+ return metadata
18
+
19
+
20
+ def download(metadata, output_dir, **kwargs):
21
+ os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)
22
+
23
+ if not os.path.exists(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')):
24
+ try:
25
+ os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)
26
+ os.system(f"wget -O {output_dir}/raw/abo-3dmodels.tar https://amazon-berkeley-objects.s3.amazonaws.com/archives/abo-3dmodels.tar")
27
+ except:
28
+ print("\033[93m")
29
+ print("Error downloading ABO dataset. Please check your internet connection and try again.")
30
+ print("Or, you can manually download the abo-3dmodels.tar file and place it in the {output_dir}/raw directory")
31
+ print("Visit https://amazon-berkeley-objects.s3.amazonaws.com/index.html for more information")
32
+ print("\033[0m")
33
+ raise FileNotFoundError("Error downloading ABO dataset")
34
+
35
+ downloaded = {}
36
+ metadata = metadata.set_index("file_identifier")
37
+ with tarfile.open(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')) as tar:
38
+ with ThreadPoolExecutor(max_workers=1) as executor, \
39
+ tqdm(total=len(metadata), desc="Extracting") as pbar:
40
+ def worker(instance: str) -> str:
41
+ try:
42
+ tar.extract(f"3dmodels/original/{instance}", path=os.path.join(output_dir, 'raw'))
43
+ sha256 = get_file_hash(os.path.join(output_dir, 'raw/3dmodels/original', instance))
44
+ pbar.update()
45
+ return sha256
46
+ except Exception as e:
47
+ pbar.update()
48
+ print(f"Error extracting for {instance}: {e}")
49
+ return None
50
+
51
+ sha256s = executor.map(worker, metadata.index)
52
+ executor.shutdown(wait=True)
53
+
54
+ for k, sha256 in zip(metadata.index, sha256s):
55
+ if sha256 is not None:
56
+ if sha256 == metadata.loc[k, "sha256"]:
57
+ downloaded[sha256] = os.path.join('raw/3dmodels/original', k)
58
+ else:
59
+ print(f"Error downloading {k}: sha256s do not match")
60
+
61
+ return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])
62
+
63
+
64
+ def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
65
+ import os
66
+ from concurrent.futures import ThreadPoolExecutor
67
+ from tqdm import tqdm
68
+
69
+ # load metadata
70
+ metadata = metadata.to_dict('records')
71
+
72
+ # processing objects
73
+ records = []
74
+ max_workers = max_workers or os.cpu_count()
75
+ try:
76
+ with ThreadPoolExecutor(max_workers=max_workers) as executor, \
77
+ tqdm(total=len(metadata), desc=desc) as pbar:
78
+ def worker(metadatum):
79
+ try:
80
+ local_path = metadatum['local_path']
81
+ sha256 = metadatum['sha256']
82
+ file = os.path.join(output_dir, local_path)
83
+ record = func(file, sha256)
84
+ if record is not None:
85
+ records.append(record)
86
+ pbar.update()
87
+ except Exception as e:
88
+ print(f"Error processing object {sha256}: {e}")
89
+ pbar.update()
90
+
91
+ executor.map(worker, metadata)
92
+ executor.shutdown(wait=True)
93
+ except:
94
+ print("Error happened during processing.")
95
+
96
+ return pd.DataFrame.from_records(records)
dataset_toolkits/datasets/HSSD.py ADDED
@@ -0,0 +1,103 @@
1
+ import os
2
+ import re
3
+ import argparse
4
+ import tarfile
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from tqdm import tqdm
7
+ import pandas as pd
8
+ import huggingface_hub
9
+ from utils import get_file_hash
10
+
11
+
12
+ def add_args(parser: argparse.ArgumentParser):
13
+ pass
14
+
15
+
16
+ def get_metadata(**kwargs):
17
+ metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/HSSD.csv")
18
+ return metadata
19
+
20
+
21
+ def download(metadata, output_dir, **kwargs):
22
+ os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)
23
+
24
+ # check login
25
+ try:
26
+ huggingface_hub.whoami()
27
+ except:
28
+ print("\033[93m")
29
+ print("You haven't logged in to the Hugging Face Hub.")
30
+ print("Visit https://huggingface.co/settings/tokens to get a token.")
31
+ print("\033[0m")
32
+ huggingface_hub.login()
33
+
34
+ try:
35
+ huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename="README.md", repo_type="dataset")
36
+ except:
37
+ print("\033[93m")
38
+ print("Error downloading HSSD dataset.")
39
+ print("Check if you have access to the HSSD dataset.")
40
+ print("Visit https://huggingface.co/datasets/hssd/hssd-models for more information")
41
+ print("\033[0m")
42
+
43
+ downloaded = {}
44
+ metadata = metadata.set_index("file_identifier")
45
+ with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
46
+ tqdm(total=len(metadata), desc="Downloading") as pbar:
47
+ def worker(instance: str) -> str:
48
+ try:
49
+ huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename=instance, repo_type="dataset", local_dir=os.path.join(output_dir, 'raw'))
50
+ sha256 = get_file_hash(os.path.join(output_dir, 'raw', instance))
51
+ pbar.update()
52
+ return sha256
53
+ except Exception as e:
54
+ pbar.update()
55
+ print(f"Error extracting for {instance}: {e}")
56
+ return None
57
+
58
+ sha256s = executor.map(worker, metadata.index)
59
+ executor.shutdown(wait=True)
60
+
61
+ for k, sha256 in zip(metadata.index, sha256s):
62
+ if sha256 is not None:
63
+ if sha256 == metadata.loc[k, "sha256"]:
64
+ downloaded[sha256] = os.path.join('raw', k)
65
+ else:
66
+ print(f"Error downloading {k}: sha256s do not match")
67
+
68
+ return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])
69
+
70
+
71
+ def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
72
+ import os
73
+ from concurrent.futures import ThreadPoolExecutor
74
+ from tqdm import tqdm
75
+
76
+ # load metadata
77
+ metadata = metadata.to_dict('records')
78
+
79
+ # processing objects
80
+ records = []
81
+ max_workers = max_workers or os.cpu_count()
82
+ try:
83
+ with ThreadPoolExecutor(max_workers=max_workers) as executor, \
84
+ tqdm(total=len(metadata), desc=desc) as pbar:
85
+ def worker(metadatum):
86
+ try:
87
+ local_path = metadatum['local_path']
88
+ sha256 = metadatum['sha256']
89
+ file = os.path.join(output_dir, local_path)
90
+ record = func(file, sha256)
91
+ if record is not None:
92
+ records.append(record)
93
+ pbar.update()
94
+ except Exception as e:
95
+ print(f"Error processing object {sha256}: {e}")
96
+ pbar.update()
97
+
98
+ executor.map(worker, metadata)
99
+ executor.shutdown(wait=True)
100
+ except:
101
+ print("Error happened during processing.")
102
+
103
+ return pd.DataFrame.from_records(records)
dataset_toolkits/datasets/ObjaverseXL.py ADDED
@@ -0,0 +1,92 @@
1
+ import os
2
+ import argparse
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from tqdm import tqdm
5
+ import pandas as pd
6
+ import objaverse.xl as oxl
7
+ from utils import get_file_hash
8
+
9
+
10
+ def add_args(parser: argparse.ArgumentParser):
11
+ parser.add_argument('--source', type=str, default='sketchfab',
12
+ help='Data source to download annotations from (github, sketchfab)')
13
+
14
+
15
+ def get_metadata(source, **kwargs):
16
+ if source == 'sketchfab':
17
+ metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ObjaverseXL_sketchfab.csv")
18
+ elif source == 'github':
19
+ metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ObjaverseXL_github.csv")
20
+ else:
21
+ raise ValueError(f"Invalid source: {source}")
22
+ return metadata
23
+
24
+
25
+ def download(metadata, output_dir, **kwargs):
26
+ os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)
27
+
28
+ # download annotations
29
+ annotations = oxl.get_annotations()
30
+ annotations = annotations[annotations['sha256'].isin(metadata['sha256'].values)]
31
+
32
+ # download and render objects
33
+ file_paths = oxl.download_objects(
34
+ annotations,
35
+ download_dir=os.path.join(output_dir, "raw"),
36
+ save_repo_format="zip",
37
+ )
38
+
39
+ downloaded = {}
40
+ metadata = metadata.set_index("file_identifier")
41
+ for k, v in file_paths.items():
42
+ sha256 = metadata.loc[k, "sha256"]
43
+ downloaded[sha256] = os.path.relpath(v, output_dir)
44
+
45
+ return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])
46
+
47
+
48
+ def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
49
+ import os
50
+ from concurrent.futures import ThreadPoolExecutor
51
+ from tqdm import tqdm
52
+ import tempfile
53
+ import zipfile
54
+
55
+ # load metadata
56
+ metadata = metadata.to_dict('records')
57
+
58
+ # processing objects
59
+ records = []
60
+ max_workers = max_workers or os.cpu_count()
61
+ try:
62
+ with ThreadPoolExecutor(max_workers=max_workers) as executor, \
63
+ tqdm(total=len(metadata), desc=desc) as pbar:
64
+ def worker(metadatum):
65
+ try:
66
+ local_path = metadatum['local_path']
67
+ sha256 = metadatum['sha256']
68
+ if local_path.startswith('raw/github/repos/'):
69
+ path_parts = local_path.split('/')
70
+ file_name = os.path.join(*path_parts[5:])
71
+ zip_file = os.path.join(output_dir, *path_parts[:5])
72
+ with tempfile.TemporaryDirectory() as tmp_dir:
73
+ with zipfile.ZipFile(zip_file, 'r') as zip_ref:
74
+ zip_ref.extractall(tmp_dir)
75
+ file = os.path.join(tmp_dir, file_name)
76
+ record = func(file, sha256)
77
+ else:
78
+ file = os.path.join(output_dir, local_path)
79
+ record = func(file, sha256)
80
+ if record is not None:
81
+ records.append(record)
82
+ pbar.update()
83
+ except Exception as e:
84
+ print(f"Error processing object {sha256}: {e}")
85
+ pbar.update()
86
+
87
+ executor.map(worker, metadata)
88
+ executor.shutdown(wait=True)
89
+ except:
90
+ print("Error happened during processing.")
91
+
92
+ return pd.DataFrame.from_records(records)
dataset_toolkits/datasets/Toys4k.py ADDED
@@ -0,0 +1,92 @@
1
+ import os
2
+ import re
3
+ import argparse
4
+ import zipfile
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from tqdm import tqdm
7
+ import pandas as pd
8
+ from utils import get_file_hash
9
+
10
+
11
+ def add_args(parser: argparse.ArgumentParser):
12
+ pass
13
+
14
+
15
+ def get_metadata(**kwargs):
16
+ metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/Toys4k.csv")
17
+ return metadata
18
+
19
+
20
+ def download(metadata, output_dir, **kwargs):
21
+ os.makedirs(output_dir, exist_ok=True)
22
+
23
+ if not os.path.exists(os.path.join(output_dir, 'raw', 'toys4k_blend_files.zip')):
24
+ print("\033[93m")
25
+ print("Toys4k have to be downloaded manually")
26
+ print(f"Please download the toys4k_blend_files.zip file and place it in the {output_dir}/raw directory")
27
+ print("Visit https://github.com/rehg-lab/lowshot-shapebias/tree/main/toys4k for more information")
28
+ print("\033[0m")
29
+ raise FileNotFoundError("toys4k_blend_files.zip not found")
30
+
31
+ downloaded = {}
32
+ metadata = metadata.set_index("file_identifier")
33
+ with zipfile.ZipFile(os.path.join(output_dir, 'raw', 'toys4k_blend_files.zip')) as zip_ref:
34
+ with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
35
+ tqdm(total=len(metadata), desc="Extracting") as pbar:
36
+ def worker(instance: str) -> str:
37
+ try:
38
+ zip_ref.extract(os.path.join('toys4k_blend_files', instance), os.path.join(output_dir, 'raw'))
39
+ sha256 = get_file_hash(os.path.join(output_dir, 'raw/toys4k_blend_files', instance))
40
+ pbar.update()
41
+ return sha256
42
+ except Exception as e:
43
+ pbar.update()
44
+ print(f"Error extracting for {instance}: {e}")
45
+ return None
46
+
47
+ sha256s = executor.map(worker, metadata.index)
48
+ executor.shutdown(wait=True)
49
+
50
+ for k, sha256 in zip(metadata.index, sha256s):
51
+ if sha256 is not None:
52
+ if sha256 == metadata.loc[k, "sha256"]:
53
+ downloaded[sha256] = os.path.join("raw/toys4k_blend_files", k)
54
+ else:
55
+ print(f"Error downloading {k}: sha256s do not match")
56
+
57
+ return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])
58
+
59
+
60
+ def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
61
+ import os
62
+ from concurrent.futures import ThreadPoolExecutor
63
+ from tqdm import tqdm
64
+
65
+ # load metadata
66
+ metadata = metadata.to_dict('records')
67
+
68
+ # processing objects
69
+ records = []
70
+ max_workers = max_workers or os.cpu_count()
71
+ try:
72
+ with ThreadPoolExecutor(max_workers=max_workers) as executor, \
73
+ tqdm(total=len(metadata), desc=desc) as pbar:
74
+ def worker(metadatum):
75
+ try:
76
+ local_path = metadatum['local_path']
77
+ sha256 = metadatum['sha256']
78
+ file = os.path.join(output_dir, local_path)
79
+ record = func(file, sha256)
80
+ if record is not None:
81
+ records.append(record)
82
+ pbar.update()
83
+ except Exception as e:
84
+ print(f"Error processing object {sha256}: {e}")
85
+ pbar.update()
86
+
87
+ executor.map(worker, metadata)
88
+ executor.shutdown(wait=True)
89
+ except:
90
+ print("Error happened during processing.")
91
+
92
+ return pd.DataFrame.from_records(records)
dataset_toolkits/download.py ADDED
@@ -0,0 +1,52 @@
1
+ import os
2
+ import copy
3
+ import sys
4
+ import importlib
5
+ import argparse
6
+ import pandas as pd
7
+ from easydict import EasyDict as edict
8
+
9
+ if __name__ == '__main__':
10
+ dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
11
+
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument('--output_dir', type=str, required=True,
14
+ help='Directory to save the metadata')
15
+ parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
16
+ help='Filter objects with aesthetic score lower than this value')
17
+ parser.add_argument('--instances', type=str, default=None,
18
+ help='Instances to process')
19
+ dataset_utils.add_args(parser)
20
+ parser.add_argument('--rank', type=int, default=0)
21
+ parser.add_argument('--world_size', type=int, default=1)
22
+ opt = parser.parse_args(sys.argv[2:])
23
+ opt = edict(vars(opt))
24
+
25
+ os.makedirs(opt.output_dir, exist_ok=True)
26
+
27
+ # get file list
28
+ if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
29
+ raise ValueError('metadata.csv not found')
30
+ metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
31
+ if opt.instances is None:
32
+ if opt.filter_low_aesthetic_score is not None:
33
+ metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
34
+ if 'local_path' in metadata.columns:
35
+ metadata = metadata[metadata['local_path'].isna()]
36
+ else:
37
+ if os.path.exists(opt.instances):
38
+ with open(opt.instances, 'r') as f:
39
+ instances = f.read().splitlines()
40
+ else:
41
+ instances = opt.instances.split(',')
42
+ metadata = metadata[metadata['sha256'].isin(instances)]
43
+
44
+ start = len(metadata) * opt.rank // opt.world_size
45
+ end = len(metadata) * (opt.rank + 1) // opt.world_size
46
+ metadata = metadata[start:end]
47
+
48
+ print(f'Processing {len(metadata)} objects...')
49
+
50
+ # process objects
51
+ downloaded = dataset_utils.download(metadata, **opt)
52
+ downloaded.to_csv(os.path.join(opt.output_dir, f'downloaded_{opt.rank}.csv'), index=False)
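Every toolkit script shards its metadata table the same way: rank r of world_size W takes rows [len·r/W, len·(r+1)/W), which keeps the shards contiguous, disjoint, and exhaustive even when the row count is not divisible by W. A small illustrative check of that rule (the helper name is ours, not part of the toolkit):

```python
def shard_bounds(n: int, rank: int, world_size: int) -> tuple[int, int]:
    # same slicing rule used by download.py, render.py, voxelize.py, etc.
    return n * rank // world_size, n * (rank + 1) // world_size

n, world_size = 103, 4
shards = [shard_bounds(n, r, world_size) for r in range(world_size)]
assert shards[0][0] == 0 and shards[-1][1] == n                 # covers every row
assert all(a[1] == b[0] for a, b in zip(shards, shards[1:]))    # no gaps, no overlaps
```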
dataset_toolkits/encode_latent.py ADDED
@@ -0,0 +1,127 @@
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
4
+ import copy
5
+ import json
6
+ import argparse
7
+ import torch
8
+ import numpy as np
9
+ import pandas as pd
10
+ from tqdm import tqdm
11
+ from easydict import EasyDict as edict
12
+ from concurrent.futures import ThreadPoolExecutor
13
+ from queue import Queue
14
+
15
+ import trellis.models as models
16
+ import trellis.modules.sparse as sp
17
+
18
+
19
+ torch.set_grad_enabled(False)
20
+
21
+
22
+ if __name__ == '__main__':
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument('--output_dir', type=str, required=True,
25
+ help='Directory to save the metadata')
26
+ parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
27
+ help='Filter objects with aesthetic score lower than this value')
28
+ parser.add_argument('--feat_model', type=str, default='dinov2_vitl14_reg',
29
+ help='Feature model')
30
+ parser.add_argument('--enc_pretrained', type=str, default='JeffreyXiang/TRELLIS-image-large/ckpts/slat_enc_swin8_B_64l8_fp16',
31
+ help='Pretrained encoder model')
32
+ parser.add_argument('--model_root', type=str, default='results',
33
+ help='Root directory of models')
34
+ parser.add_argument('--enc_model', type=str, default=None,
35
+ help='Encoder model. if specified, use this model instead of pretrained model')
36
+ parser.add_argument('--ckpt', type=str, default=None,
37
+ help='Checkpoint to load')
38
+ parser.add_argument('--instances', type=str, default=None,
39
+ help='Instances to process')
40
+ parser.add_argument('--rank', type=int, default=0)
41
+ parser.add_argument('--world_size', type=int, default=1)
42
+ opt = parser.parse_args()
43
+ opt = edict(vars(opt))
44
+
45
+ if opt.enc_model is None:
46
+ latent_name = f'{opt.feat_model}_{opt.enc_pretrained.split("/")[-1]}'
47
+ encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
48
+ else:
49
+ latent_name = f'{opt.feat_model}_{opt.enc_model}_{opt.ckpt}'
50
+ cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
51
+ encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
52
+ ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
53
+ encoder.load_state_dict(torch.load(ckpt_path), strict=False)
54
+ encoder.eval()
55
+ print(f'Loaded model from {ckpt_path}')
56
+
57
+ os.makedirs(os.path.join(opt.output_dir, 'latents', latent_name), exist_ok=True)
58
+
59
+ # get file list
60
+ if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
61
+ metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
62
+ else:
63
+ raise ValueError('metadata.csv not found')
64
+ if opt.instances is not None:
65
+ with open(opt.instances, 'r') as f:
66
+ sha256s = [line.strip() for line in f]
67
+ metadata = metadata[metadata['sha256'].isin(sha256s)]
68
+ else:
69
+ if opt.filter_low_aesthetic_score is not None:
70
+ metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
71
+ metadata = metadata[metadata[f'feature_{opt.feat_model}'] == True]
72
+ if f'latent_{latent_name}' in metadata.columns:
73
+ metadata = metadata[metadata[f'latent_{latent_name}'] == False]
74
+
75
+ start = len(metadata) * opt.rank // opt.world_size
76
+ end = len(metadata) * (opt.rank + 1) // opt.world_size
77
+ metadata = metadata[start:end]
78
+ records = []
79
+
80
+ # filter out objects that are already processed
81
+ sha256s = list(metadata['sha256'].values)
82
+ for sha256 in copy.copy(sha256s):
83
+ if os.path.exists(os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')):
84
+ records.append({'sha256': sha256, f'latent_{latent_name}': True})
85
+ sha256s.remove(sha256)
86
+
87
+ # encode latents
88
+ load_queue = Queue(maxsize=4)
89
+ try:
90
+ with ThreadPoolExecutor(max_workers=32) as loader_executor, \
91
+ ThreadPoolExecutor(max_workers=32) as saver_executor:
92
+ def loader(sha256):
93
+ try:
94
+ feats = np.load(os.path.join(opt.output_dir, 'features', opt.feat_model, f'{sha256}.npz'))
95
+ load_queue.put((sha256, feats))
96
+ except Exception as e:
97
+ print(f"Error loading features for {sha256}: {e}")
98
+ loader_executor.map(loader, sha256s)
99
+
100
+ def saver(sha256, pack):
101
+ save_path = os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')
102
+ np.savez_compressed(save_path, **pack)
103
+ records.append({'sha256': sha256, f'latent_{latent_name}': True})
104
+
105
+ for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
106
+ sha256, feats = load_queue.get()
107
+ feats = sp.SparseTensor(
108
+ feats = torch.from_numpy(feats['patchtokens']).float(),
109
+ coords = torch.cat([
110
+ torch.zeros(feats['patchtokens'].shape[0], 1).int(),
111
+ torch.from_numpy(feats['indices']).int(),
112
+ ], dim=1),
113
+ ).cuda()
114
+ latent = encoder(feats, sample_posterior=False)
115
+ assert torch.isfinite(latent.feats).all(), "Non-finite latent"
116
+ pack = {
117
+ 'feats': latent.feats.cpu().numpy().astype(np.float32),
118
+ 'coords': latent.coords[:, 1:].cpu().numpy().astype(np.uint8),
119
+ }
120
+ saver_executor.submit(saver, sha256, pack)
121
+
122
+ saver_executor.shutdown(wait=True)
123
+ except:
124
+ print("Error happened during processing.")
125
+
126
+ records = pd.DataFrame.from_records(records)
127
+ records.to_csv(os.path.join(opt.output_dir, f'latent_{latent_name}_{opt.rank}.csv'), index=False)
dataset_toolkits/encode_ss_latent.py ADDED
@@ -0,0 +1,128 @@
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
4
+ import copy
5
+ import json
6
+ import argparse
7
+ import torch
8
+ import numpy as np
9
+ import pandas as pd
10
+ import utils3d
11
+ from tqdm import tqdm
12
+ from easydict import EasyDict as edict
13
+ from concurrent.futures import ThreadPoolExecutor
14
+ from queue import Queue
15
+
16
+ import trellis.models as models
17
+
18
+
19
+ torch.set_grad_enabled(False)
20
+
21
+
22
+ def get_voxels(instance):
23
+ position = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{instance}.ply'))[0]
24
+ coords = ((torch.tensor(position) + 0.5) * opt.resolution).int().contiguous()
25
+ ss = torch.zeros(1, opt.resolution, opt.resolution, opt.resolution, dtype=torch.long)
26
+ ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
27
+ return ss
28
+
29
+
30
+ if __name__ == '__main__':
31
+ parser = argparse.ArgumentParser()
32
+ parser.add_argument('--output_dir', type=str, required=True,
33
+ help='Directory to save the metadata')
34
+ parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
35
+ help='Filter objects with aesthetic score lower than this value')
36
+ parser.add_argument('--enc_pretrained', type=str, default='JeffreyXiang/TRELLIS-image-large/ckpts/ss_enc_conv3d_16l8_fp16',
37
+ help='Pretrained encoder model')
38
+ parser.add_argument('--model_root', type=str, default='results',
39
+ help='Root directory of models')
40
+ parser.add_argument('--enc_model', type=str, default=None,
41
+ help='Encoder model. if specified, use this model instead of pretrained model')
42
+ parser.add_argument('--ckpt', type=str, default=None,
43
+ help='Checkpoint to load')
44
+ parser.add_argument('--resolution', type=int, default=64,
45
+ help='Resolution')
46
+ parser.add_argument('--instances', type=str, default=None,
47
+ help='Instances to process')
48
+ parser.add_argument('--rank', type=int, default=0)
49
+ parser.add_argument('--world_size', type=int, default=1)
50
+ opt = parser.parse_args()
51
+ opt = edict(vars(opt))
52
+
53
+ if opt.enc_model is None:
54
+ latent_name = f'{opt.enc_pretrained.split("/")[-1]}'
55
+ encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
56
+ else:
57
+ latent_name = f'{opt.enc_model}_{opt.ckpt}'
58
+ cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
59
+ encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
60
+ ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
61
+ encoder.load_state_dict(torch.load(ckpt_path), strict=False)
62
+ encoder.eval()
63
+ print(f'Loaded model from {ckpt_path}')
64
+
65
+ os.makedirs(os.path.join(opt.output_dir, 'ss_latents', latent_name), exist_ok=True)
66
+
67
+ # get file list
68
+ if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
69
+ metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
70
+ else:
71
+ raise ValueError('metadata.csv not found')
72
+ if opt.instances is not None:
73
+ with open(opt.instances, 'r') as f:
74
+ instances = f.read().splitlines()
75
+ metadata = metadata[metadata['sha256'].isin(instances)]
76
+ else:
77
+ if opt.filter_low_aesthetic_score is not None:
78
+ metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
79
+ metadata = metadata[metadata['voxelized'] == True]
80
+ if f'ss_latent_{latent_name}' in metadata.columns:
81
+ metadata = metadata[metadata[f'ss_latent_{latent_name}'] == False]
82
+
83
+ start = len(metadata) * opt.rank // opt.world_size
84
+ end = len(metadata) * (opt.rank + 1) // opt.world_size
85
+ metadata = metadata[start:end]
86
+ records = []
87
+
88
+ # filter out objects that are already processed
89
+ sha256s = list(metadata['sha256'].values)
90
+ for sha256 in copy.copy(sha256s):
91
+ if os.path.exists(os.path.join(opt.output_dir, 'ss_latents', latent_name, f'{sha256}.npz')):
92
+ records.append({'sha256': sha256, f'ss_latent_{latent_name}': True})
93
+ sha256s.remove(sha256)
94
+
95
+ # encode latents
96
+ load_queue = Queue(maxsize=4)
97
+ try:
98
+ with ThreadPoolExecutor(max_workers=32) as loader_executor, \
99
+ ThreadPoolExecutor(max_workers=32) as saver_executor:
100
+ def loader(sha256):
101
+ try:
102
+ ss = get_voxels(sha256)[None].float()
103
+ load_queue.put((sha256, ss))
104
+ except Exception as e:
105
+ print(f"Error loading features for {sha256}: {e}")
106
+ loader_executor.map(loader, sha256s)
107
+
108
+ def saver(sha256, pack):
109
+ save_path = os.path.join(opt.output_dir, 'ss_latents', latent_name, f'{sha256}.npz')
110
+ np.savez_compressed(save_path, **pack)
111
+ records.append({'sha256': sha256, f'ss_latent_{latent_name}': True})
112
+
113
+ for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
114
+ sha256, ss = load_queue.get()
115
+ ss = ss.cuda().float()
116
+ latent = encoder(ss, sample_posterior=False)
117
+ assert torch.isfinite(latent).all(), "Non-finite latent"
118
+ pack = {
119
+ 'mean': latent[0].cpu().numpy(),
120
+ }
121
+ saver_executor.submit(saver, sha256, pack)
122
+
123
+ saver_executor.shutdown(wait=True)
124
+ except:
125
+ print("Error happened during processing.")
126
+
127
+ records = pd.DataFrame.from_records(records)
128
+ records.to_csv(os.path.join(opt.output_dir, f'ss_latent_{latent_name}_{opt.rank}.csv'), index=False)
dataset_toolkits/extract_feature.py ADDED
@@ -0,0 +1,179 @@
1
+ import os
2
+ import copy
3
+ import sys
4
+ import json
5
+ import importlib
6
+ import argparse
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ import pandas as pd
11
+ import utils3d
12
+ from tqdm import tqdm
13
+ from easydict import EasyDict as edict
14
+ from concurrent.futures import ThreadPoolExecutor
15
+ from queue import Queue
16
+ from torchvision import transforms
17
+ from PIL import Image
18
+
19
+
20
+ torch.set_grad_enabled(False)
21
+
22
+
23
+ def get_data(frames, sha256):
24
+ with ThreadPoolExecutor(max_workers=16) as executor:
25
+ def worker(view):
26
+ image_path = os.path.join(opt.output_dir, 'renders', sha256, view['file_path'])
27
+ try:
28
+ image = Image.open(image_path)
29
+ except:
30
+ print(f"Error loading image {image_path}")
31
+ return None
32
+ image = image.resize((518, 518), Image.Resampling.LANCZOS)
33
+ image = np.array(image).astype(np.float32) / 255
34
+ image = image[:, :, :3] * image[:, :, 3:]
35
+ image = torch.from_numpy(image).permute(2, 0, 1).float()
36
+
37
+ c2w = torch.tensor(view['transform_matrix'])
38
+ c2w[:3, 1:3] *= -1
39
+ extrinsics = torch.inverse(c2w)
40
+ fov = view['camera_angle_x']
41
+ intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))
42
+
43
+ return {
44
+ 'image': image,
45
+ 'extrinsics': extrinsics,
46
+ 'intrinsics': intrinsics
47
+ }
48
+
49
+ datas = executor.map(worker, frames)
50
+ for data in datas:
51
+ if data is not None:
52
+ yield data
53
+
54
+
55
+ if __name__ == '__main__':
56
+ parser = argparse.ArgumentParser()
57
+ parser.add_argument('--output_dir', type=str, required=True,
58
+ help='Directory to save the metadata')
59
+ parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
60
+ help='Filter objects with aesthetic score lower than this value')
61
+ parser.add_argument('--model', type=str, default='dinov2_vitl14_reg',
62
+ help='Feature extraction model')
63
+ parser.add_argument('--instances', type=str, default=None,
64
+ help='Instances to process')
65
+ parser.add_argument('--batch_size', type=int, default=16)
66
+ parser.add_argument('--rank', type=int, default=0)
67
+ parser.add_argument('--world_size', type=int, default=1)
68
+ opt = parser.parse_args()
69
+ opt = edict(vars(opt))
70
+
71
+ feature_name = opt.model
72
+ os.makedirs(os.path.join(opt.output_dir, 'features', feature_name), exist_ok=True)
73
+
74
+ # load model
75
+ dinov2_model = torch.hub.load('facebookresearch/dinov2', opt.model)
76
+ dinov2_model.eval().cuda()
77
+ transform = transforms.Compose([
78
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
79
+ ])
80
+ n_patch = 518 // 14
81
+
82
+ # get file list
83
+ if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
84
+ metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
85
+ else:
86
+ raise ValueError('metadata.csv not found')
87
+ if opt.instances is not None:
88
+ with open(opt.instances, 'r') as f:
89
+ instances = f.read().splitlines()
90
+ metadata = metadata[metadata['sha256'].isin(instances)]
91
+ else:
92
+ if opt.filter_low_aesthetic_score is not None:
93
+ metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
94
+ if f'feature_{feature_name}' in metadata.columns:
95
+ metadata = metadata[metadata[f'feature_{feature_name}'] == False]
96
+ metadata = metadata[metadata['voxelized'] == True]
97
+ metadata = metadata[metadata['rendered'] == True]
98
+
99
+ start = len(metadata) * opt.rank // opt.world_size
100
+ end = len(metadata) * (opt.rank + 1) // opt.world_size
101
+ metadata = metadata[start:end]
102
+ records = []
103
+
104
+ # filter out objects that are already processed
105
+ sha256s = list(metadata['sha256'].values)
106
+ for sha256 in copy.copy(sha256s):
107
+ if os.path.exists(os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz')):
108
+ records.append({'sha256': sha256, f'feature_{feature_name}' : True})
109
+ sha256s.remove(sha256)
110
+
111
+ # extract features
112
+ load_queue = Queue(maxsize=4)
113
+ try:
114
+ with ThreadPoolExecutor(max_workers=8) as loader_executor, \
115
+ ThreadPoolExecutor(max_workers=8) as saver_executor:
116
+ def loader(sha256):
117
+ try:
118
+ with open(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json'), 'r') as f:
119
+ metadata = json.load(f)
120
+ frames = metadata['frames']
121
+ data = []
122
+ for datum in get_data(frames, sha256):
123
+ datum['image'] = transform(datum['image'])
124
+ data.append(datum)
125
+ positions = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
126
+ load_queue.put((sha256, data, positions))
127
+ except Exception as e:
128
+ print(f"Error loading data for {sha256}: {e}")
129
+
130
+ loader_executor.map(loader, sha256s)
131
+
132
+ def saver(sha256, pack, patchtokens, uv):
133
+ pack['patchtokens'] = F.grid_sample(
134
+ patchtokens,
135
+ uv.unsqueeze(1),
136
+ mode='bilinear',
137
+ align_corners=False,
138
+ ).squeeze(2).permute(0, 2, 1).cpu().numpy()
139
+ pack['patchtokens'] = np.mean(pack['patchtokens'], axis=0).astype(np.float16)
140
+ save_path = os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz')
141
+ np.savez_compressed(save_path, **pack)
142
+ records.append({'sha256': sha256, f'feature_{feature_name}' : True})
143
+
144
+ for _ in tqdm(range(len(sha256s)), desc="Extracting features"):
145
+ sha256, data, positions = load_queue.get()
146
+ positions = torch.from_numpy(positions).float().cuda()
147
+ indices = ((positions + 0.5) * 64).long()
148
+ assert torch.all(indices >= 0) and torch.all(indices < 64), "Some vertices are out of bounds"
149
+ n_views = len(data)
150
+ N = positions.shape[0]
151
+ pack = {
152
+ 'indices': indices.cpu().numpy().astype(np.uint8),
153
+ }
154
+ patchtokens_lst = []
155
+ uv_lst = []
156
+ for i in range(0, n_views, opt.batch_size):
157
+ batch_data = data[i:i+opt.batch_size]
158
+ bs = len(batch_data)
159
+ batch_images = torch.stack([d['image'] for d in batch_data]).cuda()
160
+ batch_extrinsics = torch.stack([d['extrinsics'] for d in batch_data]).cuda()
161
+ batch_intrinsics = torch.stack([d['intrinsics'] for d in batch_data]).cuda()
162
+ features = dinov2_model(batch_images, is_training=True)
163
+ uv = utils3d.torch.project_cv(positions, batch_extrinsics, batch_intrinsics)[0] * 2 - 1
164
+ patchtokens = features['x_prenorm'][:, dinov2_model.num_register_tokens + 1:].permute(0, 2, 1).reshape(bs, 1024, n_patch, n_patch)
165
+ patchtokens_lst.append(patchtokens)
166
+ uv_lst.append(uv)
167
+ patchtokens = torch.cat(patchtokens_lst, dim=0)
168
+ uv = torch.cat(uv_lst, dim=0)
169
+
170
+ # save features
171
+ saver_executor.submit(saver, sha256, pack, patchtokens, uv)
172
+
173
+ saver_executor.shutdown(wait=True)
174
+ except:
175
+ print("Error happened during processing.")
176
+
177
+ records = pd.DataFrame.from_records(records)
178
+ records.to_csv(os.path.join(opt.output_dir, f'feature_{feature_name}_{opt.rank}.csv'), index=False)
179
+
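extract_feature.py aggregates one DINOv2 feature per voxel by projecting the voxel centers into every rendered view, bilinearly sampling the 37×37 patch-token grid at the projected locations with `grid_sample`, and averaging the samples over views. A reduced sketch of that sampling step with dummy tensors (shapes are illustrative only, not the script's actual batching):

```python
import torch
import torch.nn.functional as F

n_views, n_voxels, c, n_patch = 4, 1000, 1024, 37
patchtokens = torch.randn(n_views, c, n_patch, n_patch)   # per-view token grids
uv = torch.rand(n_views, n_voxels, 2) * 2 - 1             # projected voxel centers in [-1, 1]

# one feature vector per (view, voxel), then average over views
sampled = F.grid_sample(patchtokens, uv.unsqueeze(1), mode="bilinear", align_corners=False)
per_view = sampled.squeeze(2).permute(0, 2, 1)             # [n_views, n_voxels, c]
voxel_features = per_view.mean(dim=0)                      # [n_voxels, c]
```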
dataset_toolkits/render.py ADDED
@@ -0,0 +1,121 @@
1
+ import os
2
+ import json
3
+ import copy
4
+ import sys
5
+ import importlib
6
+ import argparse
7
+ import pandas as pd
8
+ from easydict import EasyDict as edict
9
+ from functools import partial
10
+ from subprocess import DEVNULL, call
11
+ import numpy as np
12
+ from utils import sphere_hammersley_sequence
13
+
14
+
15
+ BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz'
16
+ BLENDER_INSTALLATION_PATH = '/tmp'
17
+ BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender'
18
+
19
+ def _install_blender():
20
+ if not os.path.exists(BLENDER_PATH):
21
+ os.system('sudo apt-get update')
22
+ os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6')
23
+ os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
24
+ os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')
25
+
26
+
27
+ def _render(file_path, sha256, output_dir, num_views):
28
+ output_folder = os.path.join(output_dir, 'renders', sha256)
29
+
30
+ # Build camera {yaw, pitch, radius, fov}
31
+ yaws = []
32
+ pitchs = []
33
+ offset = (np.random.rand(), np.random.rand())
34
+ for i in range(num_views):
35
+ y, p = sphere_hammersley_sequence(i, num_views, offset)
36
+ yaws.append(y)
37
+ pitchs.append(p)
38
+ radius = [2] * num_views
39
+ fov = [40 / 180 * np.pi] * num_views
40
+ views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)]
41
+
42
+ args = [
43
+ BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'),
44
+ '--',
45
+ '--views', json.dumps(views),
46
+ '--object', os.path.expanduser(file_path),
47
+ '--resolution', '512',
48
+ '--output_folder', output_folder,
49
+ '--engine', 'CYCLES',
50
+ '--save_mesh',
51
+ ]
52
+ if file_path.endswith('.blend'):
53
+ args.insert(1, file_path)
54
+
55
+ call(args, stdout=DEVNULL, stderr=DEVNULL)
56
+
57
+ if os.path.exists(os.path.join(output_folder, 'transforms.json')):
58
+ return {'sha256': sha256, 'rendered': True}
59
+
60
+
61
+ if __name__ == '__main__':
62
+ dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
63
+
64
+ parser = argparse.ArgumentParser()
65
+ parser.add_argument('--output_dir', type=str, required=True,
66
+ help='Directory to save the metadata')
67
+ parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
68
+ help='Filter objects with aesthetic score lower than this value')
69
+ parser.add_argument('--instances', type=str, default=None,
70
+ help='Instances to process')
71
+ parser.add_argument('--num_views', type=int, default=150,
72
+ help='Number of views to render')
73
+ dataset_utils.add_args(parser)
74
+ parser.add_argument('--rank', type=int, default=0)
75
+ parser.add_argument('--world_size', type=int, default=1)
76
+ parser.add_argument('--max_workers', type=int, default=8)
77
+ opt = parser.parse_args(sys.argv[2:])
78
+ opt = edict(vars(opt))
79
+
80
+ os.makedirs(os.path.join(opt.output_dir, 'renders'), exist_ok=True)
81
+
82
+ # install blender
83
+ print('Checking blender...', flush=True)
84
+ _install_blender()
85
+
86
+ # get file list
87
+ if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
88
+ raise ValueError('metadata.csv not found')
89
+ metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
90
+ if opt.instances is None:
91
+ metadata = metadata[metadata['local_path'].notna()]
92
+ if opt.filter_low_aesthetic_score is not None:
93
+ metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
94
+ if 'rendered' in metadata.columns:
95
+ metadata = metadata[metadata['rendered'] == False]
96
+ else:
97
+ if os.path.exists(opt.instances):
98
+ with open(opt.instances, 'r') as f:
99
+ instances = f.read().splitlines()
100
+ else:
101
+ instances = opt.instances.split(',')
102
+ metadata = metadata[metadata['sha256'].isin(instances)]
103
+
104
+ start = len(metadata) * opt.rank // opt.world_size
105
+ end = len(metadata) * (opt.rank + 1) // opt.world_size
106
+ metadata = metadata[start:end]
107
+ records = []
108
+
109
+ # filter out objects that are already processed
110
+ for sha256 in copy.copy(metadata['sha256'].values):
111
+ if os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')):
112
+ records.append({'sha256': sha256, 'rendered': True})
113
+ metadata = metadata[metadata['sha256'] != sha256]
114
+
115
+ print(f'Processing {len(metadata)} objects...')
116
+
117
+ # process objects
118
+ func = partial(_render, output_dir=opt.output_dir, num_views=opt.num_views)
119
+ rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects')
120
+ rendered = pd.concat([rendered, pd.DataFrame.from_records(records)])
121
+ rendered.to_csv(os.path.join(opt.output_dir, f'rendered_{opt.rank}.csv'), index=False)
dataset_toolkits/render_cond.py ADDED
@@ -0,0 +1,125 @@
1
+ import os
2
+ import json
3
+ import copy
4
+ import sys
5
+ import importlib
6
+ import argparse
7
+ import pandas as pd
8
+ from easydict import EasyDict as edict
9
+ from functools import partial
10
+ from subprocess import DEVNULL, call
11
+ import numpy as np
12
+ from utils import sphere_hammersley_sequence
13
+
14
+
15
+ BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz'
16
+ BLENDER_INSTALLATION_PATH = '/tmp'
17
+ BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender'
18
+
19
+ def _install_blender():
20
+ if not os.path.exists(BLENDER_PATH):
21
+ os.system('sudo apt-get update')
22
+ os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6')
23
+ os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
24
+ os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')
25
+
26
+
27
+ def _render_cond(file_path, sha256, output_dir, num_views):
28
+ output_folder = os.path.join(output_dir, 'renders_cond', sha256)
29
+
30
+ # Build camera {yaw, pitch, radius, fov}
31
+ yaws = []
32
+ pitchs = []
33
+ offset = (np.random.rand(), np.random.rand())
34
+ for i in range(num_views):
35
+ y, p = sphere_hammersley_sequence(i, num_views, offset)
36
+ yaws.append(y)
37
+ pitchs.append(p)
38
+ fov_min, fov_max = 10, 70
39
+ radius_min = np.sqrt(3) / 2 / np.sin(fov_max / 360 * np.pi)
40
+ radius_max = np.sqrt(3) / 2 / np.sin(fov_min / 360 * np.pi)
41
+ k_min = 1 / radius_max**2
42
+ k_max = 1 / radius_min**2
43
+ ks = np.random.uniform(k_min, k_max, (1000000,))
44
+ radius = [1 / np.sqrt(k) for k in ks]
45
+ fov = [2 * np.arcsin(np.sqrt(3) / 2 / r) for r in radius]
46
+ views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)]
47
+
48
+ args = [
49
+ BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'),
50
+ '--',
51
+ '--views', json.dumps(views),
52
+ '--object', os.path.expanduser(file_path),
53
+ '--output_folder', os.path.expanduser(output_folder),
54
+ '--resolution', '1024',
55
+ ]
56
+ if file_path.endswith('.blend'):
57
+ args.insert(1, file_path)
58
+
59
+ call(args, stdout=DEVNULL)
60
+
61
+ if os.path.exists(os.path.join(output_folder, 'transforms.json')):
62
+ return {'sha256': sha256, 'cond_rendered': True}
63
+
64
+
65
+ if __name__ == '__main__':
66
+ dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
67
+
68
+ parser = argparse.ArgumentParser()
69
+ parser.add_argument('--output_dir', type=str, required=True,
70
+ help='Directory to save the metadata')
71
+ parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
72
+ help='Filter objects with aesthetic score lower than this value')
73
+ parser.add_argument('--instances', type=str, default=None,
74
+ help='Instances to process')
75
+ parser.add_argument('--num_views', type=int, default=24,
76
+ help='Number of views to render')
77
+ dataset_utils.add_args(parser)
78
+ parser.add_argument('--rank', type=int, default=0)
79
+ parser.add_argument('--world_size', type=int, default=1)
80
+ parser.add_argument('--max_workers', type=int, default=8)
81
+ opt = parser.parse_args(sys.argv[2:])
82
+ opt = edict(vars(opt))
83
+
84
+ os.makedirs(os.path.join(opt.output_dir, 'renders_cond'), exist_ok=True)
85
+
86
+ # install blender
87
+ print('Checking blender...', flush=True)
88
+ _install_blender()
89
+
90
+ # get file list
91
+ if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
92
+ raise ValueError('metadata.csv not found')
93
+ metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
94
+ if opt.instances is None:
95
+ metadata = metadata[metadata['local_path'].notna()]
96
+ if opt.filter_low_aesthetic_score is not None:
97
+ metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
98
+ if 'cond_rendered' in metadata.columns:
99
+ metadata = metadata[metadata['cond_rendered'] == False]
100
+ else:
101
+ if os.path.exists(opt.instances):
102
+ with open(opt.instances, 'r') as f:
103
+ instances = f.read().splitlines()
104
+ else:
105
+ instances = opt.instances.split(',')
106
+ metadata = metadata[metadata['sha256'].isin(instances)]
107
+
108
+ start = len(metadata) * opt.rank // opt.world_size
109
+ end = len(metadata) * (opt.rank + 1) // opt.world_size
110
+ metadata = metadata[start:end]
111
+ records = []
112
+
113
+ # filter out objects that are already processed
114
+ for sha256 in copy.copy(metadata['sha256'].values):
115
+ if os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')):
116
+ records.append({'sha256': sha256, 'cond_rendered': True})
117
+ metadata = metadata[metadata['sha256'] != sha256]
118
+
119
+ print(f'Processing {len(metadata)} objects...')
120
+
121
+ # process objects
122
+ func = partial(_render_cond, output_dir=opt.output_dir, num_views=opt.num_views)
123
+ cond_rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects')
124
+ cond_rendered = pd.concat([cond_rendered, pd.DataFrame.from_records(records)])
125
+ cond_rendered.to_csv(os.path.join(opt.output_dir, f'cond_rendered_{opt.rank}.csv'), index=False)
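render_cond.py couples camera radius and field of view so that a sphere of radius √3/2 (the circumsphere of the normalized unit cube) always fills the frustum: it samples k = 1/r² uniformly between the bounds implied by the 10–70° FOV range and recovers fov = 2·arcsin(√3/(2r)). A standalone sketch of that sampling, drawing one sample per view (the function name is hypothetical):

```python
import numpy as np

def sample_radius_fov(num_views, fov_min_deg=10, fov_max_deg=70):
    """Sample (radius, fov) pairs so a sphere of radius sqrt(3)/2 fills the frustum."""
    # radius bounds follow from fov = 2 * arcsin(sqrt(3)/2 / r)
    radius_min = np.sqrt(3) / 2 / np.sin(fov_max_deg / 360 * np.pi)
    radius_max = np.sqrt(3) / 2 / np.sin(fov_min_deg / 360 * np.pi)
    # sampling k = 1/r^2 uniformly puts more views at closer distances
    ks = np.random.uniform(1 / radius_max**2, 1 / radius_min**2, num_views)
    radius = 1 / np.sqrt(ks)
    fov = 2 * np.arcsin(np.sqrt(3) / 2 / radius)
    return radius, fov
```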
dataset_toolkits/setup.sh ADDED
@@ -0,0 +1 @@
+ pip install pillow imageio imageio-ffmpeg tqdm easydict opencv-python-headless pandas open3d objaverse huggingface_hub
dataset_toolkits/stat_latent.py ADDED
@@ -0,0 +1,66 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+ import numpy as np
5
+ import pandas as pd
6
+ from tqdm import tqdm
7
+ from easydict import EasyDict as edict
8
+ from concurrent.futures import ThreadPoolExecutor
9
+
10
+
11
+ if __name__ == '__main__':
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument('--output_dir', type=str, required=True,
14
+ help='Directory to save the metadata')
15
+ parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
16
+ help='Filter objects with aesthetic score lower than this value')
17
+ parser.add_argument('--model', type=str, default='dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16',
18
+ help='Latent model to use')
19
+ parser.add_argument('--num_samples', type=int, default=50000,
20
+ help='Number of samples to use for calculating stats')
21
+ opt = parser.parse_args()
22
+ opt = edict(vars(opt))
23
+
24
+ # get file list
25
+ if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
26
+ metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
27
+ else:
28
+ raise ValueError('metadata.csv not found')
29
+ if opt.filter_low_aesthetic_score is not None:
30
+ metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
31
+ metadata = metadata[metadata[f'latent_{opt.model}'] == True]
32
+ sha256s = metadata['sha256'].values
33
+ sha256s = np.random.choice(sha256s, min(opt.num_samples, len(sha256s)), replace=False)
34
+
35
+ # stats
36
+ means = []
37
+ mean2s = []
38
+ with ThreadPoolExecutor(max_workers=16) as executor, \
39
+ tqdm(total=len(sha256s), desc="Collecting latent stats") as pbar:
40
+ def worker(sha256):
41
+ try:
42
+ feats = np.load(os.path.join(opt.output_dir, 'latents', opt.model, f'{sha256}.npz'))
43
+ feats = feats['feats']
44
+ means.append(feats.mean(axis=0))
45
+ mean2s.append((feats ** 2).mean(axis=0))
46
+ pbar.update()
47
+ except Exception as e:
48
+ print(f"Error extracting features for {sha256}: {e}")
49
+ pbar.update()
50
+
51
+ executor.map(worker, sha256s)
52
+ executor.shutdown(wait=True)
53
+
54
+ mean = np.array(means).mean(axis=0)
55
+ mean2 = np.array(mean2s).mean(axis=0)
56
+ std = np.sqrt(mean2 - mean ** 2)
57
+
58
+ print('mean:', mean)
59
+ print('std:', std)
60
+
61
+ with open(os.path.join(opt.output_dir, 'latents', opt.model, 'stats.json'), 'w') as f:
62
+ json.dump({
63
+ 'mean': mean.tolist(),
64
+ 'std': std.tolist(),
65
+ }, f, indent=4)
66
+
dataset_toolkits/utils.py ADDED
@@ -0,0 +1,43 @@
+ from typing import *
+ import hashlib
+ import numpy as np
+
+
+ def get_file_hash(file: str) -> str:
+     sha256 = hashlib.sha256()
+     # Read the file from the path
+     with open(file, "rb") as f:
+         # Update the hash with the file content
+         for byte_block in iter(lambda: f.read(4096), b""):
+             sha256.update(byte_block)
+     return sha256.hexdigest()
+
+ # =============== LOW DISCREPANCY SEQUENCES ================
+
+ PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
+
+ def radical_inverse(base, n):
+     val = 0
+     inv_base = 1.0 / base
+     inv_base_n = inv_base
+     while n > 0:
+         digit = n % base
+         val += digit * inv_base_n
+         n //= base
+         inv_base_n *= inv_base
+     return val
+
+ def halton_sequence(dim, n):
+     return [radical_inverse(PRIMES[d], n) for d in range(dim)]
+
+ def hammersley_sequence(dim, n, num_samples):
+     return [n / num_samples] + halton_sequence(dim - 1, n)
+
+ def sphere_hammersley_sequence(n, num_samples, offset=(0, 0)):
+     u, v = hammersley_sequence(2, n, num_samples)
+     u += offset[0] / num_samples
+     v += offset[1]
+     u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
+     theta = np.arccos(1 - 2 * u) - np.pi / 2
+     phi = v * 2 * np.pi
+     return [phi, theta]
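sphere_hammersley_sequence yields (yaw, pitch) pairs that cover the sphere quasi-uniformly, and render.py consumes them one view at a time with a random per-object offset. A brief usage sketch that turns the angles into unit view directions, assuming dataset_toolkits/ is on the import path (the Cartesian convention here is illustrative, not the Blender script's):

```python
import numpy as np
from utils import sphere_hammersley_sequence

def camera_directions(num_views, seed=0):
    """Unit view directions for num_views quasi-uniform cameras (illustrative only)."""
    rng = np.random.default_rng(seed)
    offset = (rng.random(), rng.random())
    dirs = []
    for i in range(num_views):
        yaw, pitch = sphere_hammersley_sequence(i, num_views, offset)
        # spherical (yaw, pitch) -> Cartesian, with pitch measured from the equator
        dirs.append([np.cos(pitch) * np.cos(yaw),
                     np.cos(pitch) * np.sin(yaw),
                     np.sin(pitch)])
    return np.asarray(dirs)
```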
dataset_toolkits/voxelize.py ADDED
@@ -0,0 +1,86 @@
1
+ import os
2
+ import copy
3
+ import sys
4
+ import importlib
5
+ import argparse
6
+ import pandas as pd
7
+ from easydict import EasyDict as edict
8
+ from functools import partial
9
+ import numpy as np
10
+ import open3d as o3d
11
+ import utils3d
12
+
13
+
14
+ def _voxelize(file, sha256, output_dir):
15
+ mesh = o3d.io.read_triangle_mesh(os.path.join(output_dir, 'renders', sha256, 'mesh.ply'))
16
+ # clamp vertices to the range [-0.5, 0.5]
17
+ vertices = np.clip(np.asarray(mesh.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
18
+ mesh.vertices = o3d.utility.Vector3dVector(vertices)
19
+ voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(mesh, voxel_size=1/64, min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
20
+ vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()])
21
+ assert np.all(vertices >= 0) and np.all(vertices < 64), "Some vertices are out of bounds"
22
+ vertices = (vertices + 0.5) / 64 - 0.5
23
+ utils3d.io.write_ply(os.path.join(output_dir, 'voxels', f'{sha256}.ply'), vertices)
24
+ return {'sha256': sha256, 'voxelized': True, 'num_voxels': len(vertices)}
25
+
26
+
27
+ if __name__ == '__main__':
28
+ dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
29
+
30
+ parser = argparse.ArgumentParser()
31
+ parser.add_argument('--output_dir', type=str, required=True,
32
+ help='Directory to save the metadata')
33
+ parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
34
+ help='Filter objects with aesthetic score lower than this value')
35
+ parser.add_argument('--instances', type=str, default=None,
36
+ help='Instances to process')
37
+ parser.add_argument('--num_views', type=int, default=150,
38
+ help='Number of views to render')
39
+ dataset_utils.add_args(parser)
40
+ parser.add_argument('--rank', type=int, default=0)
41
+ parser.add_argument('--world_size', type=int, default=1)
42
+ parser.add_argument('--max_workers', type=int, default=None)
43
+ opt = parser.parse_args(sys.argv[2:])
44
+ opt = edict(vars(opt))
45
+
46
+ os.makedirs(os.path.join(opt.output_dir, 'voxels'), exist_ok=True)
47
+
48
+ # get file list
49
+ if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
50
+ raise ValueError('metadata.csv not found')
51
+ metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
52
+ if opt.instances is None:
53
+ if opt.filter_low_aesthetic_score is not None:
54
+ metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
55
+ if 'rendered' not in metadata.columns:
56
+ raise ValueError('metadata.csv does not have "rendered" column, please run "build_metadata.py" first')
57
+ metadata = metadata[metadata['rendered'] == True]
58
+ if 'voxelized' in metadata.columns:
59
+ metadata = metadata[metadata['voxelized'] == False]
60
+ else:
61
+ if os.path.exists(opt.instances):
62
+ with open(opt.instances, 'r') as f:
63
+ instances = f.read().splitlines()
64
+ else:
65
+ instances = opt.instances.split(',')
66
+ metadata = metadata[metadata['sha256'].isin(instances)]
67
+
68
+ start = len(metadata) * opt.rank // opt.world_size
69
+ end = len(metadata) * (opt.rank + 1) // opt.world_size
70
+ metadata = metadata[start:end]
71
+ records = []
72
+
73
+ # filter out objects that are already processed
74
+ for sha256 in copy.copy(metadata['sha256'].values):
75
+ if os.path.exists(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply')):
76
+ pts = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
77
+ records.append({'sha256': sha256, 'voxelized': True, 'num_voxels': len(pts)})
78
+ metadata = metadata[metadata['sha256'] != sha256]
79
+
80
+ print(f'Processing {len(metadata)} objects...')
81
+
82
+ # process objects
83
+ func = partial(_voxelize, output_dir=opt.output_dir)
84
+ voxelized = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Voxelizing')
85
+ voxelized = pd.concat([voxelized, pd.DataFrame.from_records(records)])
86
+ voxelized.to_csv(os.path.join(opt.output_dir, f'voxelized_{opt.rank}.csv'), index=False)
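voxelize.py writes voxel centers in the normalized range [-0.5, 0.5] via (index + 0.5)/64 − 0.5, and the encoding scripts invert that with (position + 0.5)·64. Because 64 is a power of two the round trip is exact in floating point; a quick sanity check of the convention (a sketch, not part of the toolkit):

```python
import numpy as np

resolution = 64
indices = np.stack(np.meshgrid(*[np.arange(resolution)] * 3, indexing="ij"), -1).reshape(-1, 3)

# voxel index -> normalized center in [-0.5, 0.5], as written by _voxelize
centers = (indices + 0.5) / resolution - 0.5
# normalized center -> voxel index, as read back by get_voxels / extract_feature
recovered = ((centers + 0.5) * resolution).astype(np.int64)

assert np.array_equal(indices, recovered)
```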
env.py CHANGED
@@ -1,10 +1,10 @@
- import spaces
- import os
- import torch
-
- @spaces.GPU(duration=5)
- def check():
-     print(os.system("nvidia-smi"))
-     print(torch.version.cuda)
-
- check()
+ import spaces
+ import os
+ import torch
+
+ @spaces.GPU(duration=5)
+ def check():
+     print(os.system("nvidia-smi"))
+     print(torch.version.cuda)
+
+ check()
extensions/vox2seq/benchmark.py ADDED
@@ -0,0 +1,45 @@
1
+ import time
2
+ import torch
3
+ import vox2seq
4
+
5
+
6
+ if __name__ == "__main__":
7
+ stats = {
8
+ 'z_order_cuda': [],
9
+ 'z_order_pytorch': [],
10
+ 'hilbert_cuda': [],
11
+ 'hilbert_pytorch': [],
12
+ }
13
+ RES = [16, 32, 64, 128, 256]
14
+ for res in RES:
15
+ coords = torch.meshgrid(torch.arange(res), torch.arange(res), torch.arange(res))
16
+ coords = torch.stack(coords, dim=-1).reshape(-1, 3).int().cuda()
17
+
18
+ start = time.time()
19
+ for _ in range(100):
20
+ code_z_cuda = vox2seq.encode(coords, mode='z_order').cuda()
21
+ torch.cuda.synchronize()
22
+ stats['z_order_cuda'].append((time.time() - start) / 100)
23
+
24
+ start = time.time()
25
+ for _ in range(100):
26
+ code_z_pytorch = vox2seq.pytorch.encode(coords, mode='z_order').cuda()
27
+ torch.cuda.synchronize()
28
+ stats['z_order_pytorch'].append((time.time() - start) / 100)
29
+
30
+ start = time.time()
31
+ for _ in range(100):
32
+ code_h_cuda = vox2seq.encode(coords, mode='hilbert').cuda()
33
+ torch.cuda.synchronize()
34
+ stats['hilbert_cuda'].append((time.time() - start) / 100)
35
+
36
+ start = time.time()
37
+ for _ in range(100):
38
+ code_h_pytorch = vox2seq.pytorch.encode(coords, mode='hilbert').cuda()
39
+ torch.cuda.synchronize()
40
+ stats['hilbert_pytorch'].append((time.time() - start) / 100)
41
+
42
+ print(f"{'Resolution':<12}{'Z-Order (CUDA)':<24}{'Z-Order (PyTorch)':<24}{'Hilbert (CUDA)':<24}{'Hilbert (PyTorch)':<24}")
43
+ for res, z_order_cuda, z_order_pytorch, hilbert_cuda, hilbert_pytorch in zip(RES, stats['z_order_cuda'], stats['z_order_pytorch'], stats['hilbert_cuda'], stats['hilbert_pytorch']):
44
+ print(f"{res:<12}{z_order_cuda:<24.6f}{z_order_pytorch:<24.6f}{hilbert_cuda:<24.6f}{hilbert_pytorch:<24.6f}")
45
+
extensions/vox2seq/setup.py ADDED
@@ -0,0 +1,34 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ from setuptools import setup
13
+ from torch.utils.cpp_extension import CUDAExtension, BuildExtension
14
+ import os
15
+ os.path.dirname(os.path.abspath(__file__))
16
+
17
+ setup(
18
+ name="vox2seq",
19
+ packages=['vox2seq', 'vox2seq.pytorch'],
20
+ ext_modules=[
21
+ CUDAExtension(
22
+ name="vox2seq._C",
23
+ sources=[
24
+ "src/api.cu",
25
+ "src/z_order.cu",
26
+ "src/hilbert.cu",
27
+ "src/ext.cpp",
28
+ ],
29
+ )
30
+ ],
31
+ cmdclass={
32
+ 'build_ext': BuildExtension
33
+ }
34
+ )
extensions/vox2seq/src/api.cu ADDED
@@ -0,0 +1,92 @@
1
+ #include <torch/extension.h>
2
+ #include "api.h"
3
+ #include "z_order.h"
4
+ #include "hilbert.h"
5
+
6
+
7
+ torch::Tensor
8
+ z_order_encode(
9
+ const torch::Tensor& x,
10
+ const torch::Tensor& y,
11
+ const torch::Tensor& z
12
+ ) {
13
+ // Allocate output tensor
14
+ torch::Tensor codes = torch::empty_like(x);
15
+
16
+ // Call CUDA kernel
17
+ z_order_encode_cuda<<<(x.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
18
+ x.size(0),
19
+ reinterpret_cast<uint32_t*>(x.contiguous().data_ptr<int>()),
20
+ reinterpret_cast<uint32_t*>(y.contiguous().data_ptr<int>()),
21
+ reinterpret_cast<uint32_t*>(z.contiguous().data_ptr<int>()),
22
+ reinterpret_cast<uint32_t*>(codes.data_ptr<int>())
23
+ );
24
+
25
+ return codes;
26
+ }
27
+
28
+
29
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
30
+ z_order_decode(
31
+ const torch::Tensor& codes
32
+ ) {
33
+ // Allocate output tensors
34
+ torch::Tensor x = torch::empty_like(codes);
35
+ torch::Tensor y = torch::empty_like(codes);
36
+ torch::Tensor z = torch::empty_like(codes);
37
+
38
+ // Call CUDA kernel
39
+ z_order_decode_cuda<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
40
+ codes.size(0),
41
+ reinterpret_cast<uint32_t*>(codes.contiguous().data_ptr<int>()),
42
+ reinterpret_cast<uint32_t*>(x.data_ptr<int>()),
43
+ reinterpret_cast<uint32_t*>(y.data_ptr<int>()),
44
+ reinterpret_cast<uint32_t*>(z.data_ptr<int>())
45
+ );
46
+
47
+ return std::make_tuple(x, y, z);
48
+ }
49
+
50
+
51
+ torch::Tensor
52
+ hilbert_encode(
53
+ const torch::Tensor& x,
54
+ const torch::Tensor& y,
55
+ const torch::Tensor& z
56
+ ) {
57
+ // Allocate output tensor
58
+ torch::Tensor codes = torch::empty_like(x);
59
+
60
+ // Call CUDA kernel
61
+ hilbert_encode_cuda<<<(x.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
62
+ x.size(0),
63
+ reinterpret_cast<uint32_t*>(x.contiguous().data_ptr<int>()),
64
+ reinterpret_cast<uint32_t*>(y.contiguous().data_ptr<int>()),
65
+ reinterpret_cast<uint32_t*>(z.contiguous().data_ptr<int>()),
66
+ reinterpret_cast<uint32_t*>(codes.data_ptr<int>())
67
+ );
68
+
69
+ return codes;
70
+ }
71
+
72
+
73
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
74
+ hilbert_decode(
75
+ const torch::Tensor& codes
76
+ ) {
77
+ // Allocate output tensors
78
+ torch::Tensor x = torch::empty_like(codes);
79
+ torch::Tensor y = torch::empty_like(codes);
80
+ torch::Tensor z = torch::empty_like(codes);
81
+
82
+ // Call CUDA kernel
83
+ hilbert_decode_cuda<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
84
+ codes.size(0),
85
+ reinterpret_cast<uint32_t*>(codes.contiguous().data_ptr<int>()),
86
+ reinterpret_cast<uint32_t*>(x.data_ptr<int>()),
87
+ reinterpret_cast<uint32_t*>(y.data_ptr<int>()),
88
+ reinterpret_cast<uint32_t*>(z.data_ptr<int>())
89
+ );
90
+
91
+ return std::make_tuple(x, y, z);
92
+ }
extensions/vox2seq/src/api.h ADDED
@@ -0,0 +1,76 @@
1
+ /*
2
+ * Serialize a voxel grid
3
+ *
4
+ * Copyright (C) 2024, Jianfeng XIANG <[email protected]>
5
+ * All rights reserved.
6
+ *
7
+ * Licensed under The MIT License [see LICENSE for details]
8
+ *
9
+ * Written by Jianfeng XIANG
10
+ */
11
+
12
+ #pragma once
13
+ #include <torch/extension.h>
14
+
15
+
16
+ #define BLOCK_SIZE 256
17
+
18
+
19
+ /**
20
+ * Z-order encode 3D points
21
+ *
22
+ * @param x [N] tensor containing the x coordinates
23
+ * @param y [N] tensor containing the y coordinates
24
+ * @param z [N] tensor containing the z coordinates
25
+ *
26
+ * @return [N] tensor containing the z-order encoded values
27
+ */
28
+ torch::Tensor
29
+ z_order_encode(
30
+ const torch::Tensor& x,
31
+ const torch::Tensor& y,
32
+ const torch::Tensor& z
33
+ );
34
+
35
+
36
+ /**
37
+ * Z-order decode 3D points
38
+ *
39
+ * @param codes [N] tensor containing the z-order encoded values
40
+ *
41
+ * @return 3 tensors [N] containing the x, y, z coordinates
42
+ */
43
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
44
+ z_order_decode(
45
+ const torch::Tensor& codes
46
+ );
47
+
48
+
49
+ /**
50
+ * Hilbert encode 3D points
51
+ *
52
+ * @param x [N] tensor containing the x coordinates
53
+ * @param y [N] tensor containing the y coordinates
54
+ * @param z [N] tensor containing the z coordinates
55
+ *
56
+ * @return [N] tensor containing the Hilbert encoded values
57
+ */
58
+ torch::Tensor
59
+ hilbert_encode(
60
+ const torch::Tensor& x,
61
+ const torch::Tensor& y,
62
+ const torch::Tensor& z
63
+ );
64
+
65
+
66
+ /**
67
+ * Hilbert decode 3D points
68
+ *
69
+ * @param codes [N] tensor containing the Hilbert encoded values
70
+ *
71
+ * @return 3 tensors [N] containing the x, y, z coordinates
72
+ */
73
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
74
+ hilbert_decode(
75
+ const torch::Tensor& codes
76
+ );
extensions/vox2seq/src/ext.cpp ADDED
@@ -0,0 +1,10 @@
+ #include <torch/extension.h>
+ #include "api.h"
+
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("z_order_encode", &z_order_encode);
+     m.def("z_order_decode", &z_order_decode);
+     m.def("hilbert_encode", &hilbert_encode);
+     m.def("hilbert_decode", &hilbert_decode);
+ }
extensions/vox2seq/src/hilbert.cu ADDED
@@ -0,0 +1,133 @@
1
+ #include <cuda.h>
2
+ #include <cuda_runtime.h>
3
+ #include <device_launch_parameters.h>
4
+
5
+ #include <cooperative_groups.h>
6
+ #include <cooperative_groups/memcpy_async.h>
7
+ namespace cg = cooperative_groups;
8
+
9
+ #include "hilbert.h"
10
+
11
+
12
+ // Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit.
13
+ static __device__ uint32_t expandBits(uint32_t v)
14
+ {
15
+ v = (v * 0x00010001u) & 0xFF0000FFu;
16
+ v = (v * 0x00000101u) & 0x0F00F00Fu;
17
+ v = (v * 0x00000011u) & 0xC30C30C3u;
18
+ v = (v * 0x00000005u) & 0x49249249u;
19
+ return v;
20
+ }
21
+
22
+
23
+ // Removes 2 zeros after each bit in a 30-bit integer.
24
+ static __device__ uint32_t extractBits(uint32_t v)
25
+ {
26
+ v = v & 0x49249249;
27
+ v = (v ^ (v >> 2)) & 0x030C30C3u;
28
+ v = (v ^ (v >> 4)) & 0x0300F00Fu;
29
+ v = (v ^ (v >> 8)) & 0x030000FFu;
30
+ v = (v ^ (v >> 16)) & 0x000003FFu;
31
+ return v;
32
+ }
33
+
34
+
35
+ __global__ void hilbert_encode_cuda(
36
+ size_t N,
37
+ const uint32_t* x,
38
+ const uint32_t* y,
39
+ const uint32_t* z,
40
+ uint32_t* codes
41
+ ) {
42
+ size_t thread_id = cg::this_grid().thread_rank();
43
+ if (thread_id >= N) return;
44
+
45
+ uint32_t point[3] = {x[thread_id], y[thread_id], z[thread_id]};
46
+
47
+ uint32_t m = 1 << 9, q, p, t;
48
+
49
+ // Inverse undo excess work
50
+ q = m;
51
+ while (q > 1) {
52
+ p = q - 1;
53
+ for (int i = 0; i < 3; i++) {
54
+ if (point[i] & q) {
55
+ point[0] ^= p; // invert
56
+ } else {
57
+ t = (point[0] ^ point[i]) & p;
58
+ point[0] ^= t;
59
+ point[i] ^= t;
60
+ }
61
+ }
62
+ q >>= 1;
63
+ }
64
+
65
+ // Gray encode
66
+ for (int i = 1; i < 3; i++) {
67
+ point[i] ^= point[i - 1];
68
+ }
69
+ t = 0;
70
+ q = m;
71
+ while (q > 1) {
72
+ if (point[2] & q) {
73
+ t ^= q - 1;
74
+ }
75
+ q >>= 1;
76
+ }
77
+ for (int i = 0; i < 3; i++) {
78
+ point[i] ^= t;
79
+ }
80
+
81
+ // Convert to 3D Hilbert code
82
+ uint32_t xx = expandBits(point[0]);
83
+ uint32_t yy = expandBits(point[1]);
84
+ uint32_t zz = expandBits(point[2]);
85
+
86
+ codes[thread_id] = xx * 4 + yy * 2 + zz;
87
+ }
88
+
89
+
90
+ __global__ void hilbert_decode_cuda(
91
+ size_t N,
92
+ const uint32_t* codes,
93
+ uint32_t* x,
94
+ uint32_t* y,
95
+ uint32_t* z
96
+ ) {
97
+ size_t thread_id = cg::this_grid().thread_rank();
98
+ if (thread_id >= N) return;
99
+
100
+ uint32_t point[3];
101
+ point[0] = extractBits(codes[thread_id] >> 2);
102
+ point[1] = extractBits(codes[thread_id] >> 1);
103
+ point[2] = extractBits(codes[thread_id]);
104
+
105
+ uint32_t m = 2 << 9, q, p, t;
106
+
107
+ // Gray decode by H ^ (H/2)
108
+ t = point[2] >> 1;
109
+ for (int i = 2; i > 0; i--) {
110
+ point[i] ^= point[i - 1];
111
+ }
112
+ point[0] ^= t;
113
+
114
+ // Undo excess work
115
+ q = 2;
116
+ while (q != m) {
117
+ p = q - 1;
118
+ for (int i = 2; i >= 0; i--) {
119
+ if (point[i] & q) {
120
+ point[0] ^= p;
121
+ } else {
122
+ t = (point[0] ^ point[i]) & p;
123
+ point[0] ^= t;
124
+ point[i] ^= t;
125
+ }
126
+ }
127
+ q <<= 1;
128
+ }
129
+
130
+ x[thread_id] = point[0];
131
+ y[thread_id] = point[1];
132
+ z[thread_id] = point[2];
133
+ }
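Both kernels above assume 10 bits per axis (`m` is derived from `1 << 9`), i.e. integer coordinates in [0, 1023] and 30-bit codes. A minimal sanity check of that range through the Python wrapper, assuming the compiled extension is importable as `vox2seq` and a CUDA device is available:

```python
# Illustrative range check (a sketch, not part of the extension):
# 10 bits per axis -> coordinates in [0, 1023], codes fit in 30 bits.
import torch
import vox2seq  # assumes the compiled _C extension is available

coords = torch.tensor([[0, 0, 0], [1023, 1023, 1023]], dtype=torch.int32, device="cuda")
codes = vox2seq.encode(coords, mode="hilbert")
assert int(codes.min()) >= 0 and int(codes.max()) < 2**30
```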
extensions/vox2seq/src/hilbert.h ADDED
@@ -0,0 +1,35 @@
1
+ #pragma once
2
+
3
+ /**
4
+ * Hilbert encode 3D points
5
+ *
6
+ * @param x [N] tensor containing the x coordinates
7
+ * @param y [N] tensor containing the y coordinates
8
+ * @param z [N] tensor containing the z coordinates
9
+ *
10
+ * @return [N] tensor containing the Hilbert encoded values
11
+ */
12
+ __global__ void hilbert_encode_cuda(
13
+ size_t N,
14
+ const uint32_t* x,
15
+ const uint32_t* y,
16
+ const uint32_t* z,
17
+ uint32_t* codes
18
+ );
19
+
20
+
21
+ /**
22
+ * Hilbert decode 3D points
23
+ *
24
+ * @param codes [N] tensor containing the Hilbert encoded values
25
+ * @param x [N] tensor containing the x coordinates
26
+ * @param y [N] tensor containing the y coordinates
27
+ * @param z [N] tensor containing the z coordinates
28
+ */
29
+ __global__ void hilbert_decode_cuda(
30
+ size_t N,
31
+ const uint32_t* codes,
32
+ uint32_t* x,
33
+ uint32_t* y,
34
+ uint32_t* z
35
+ );
extensions/vox2seq/src/z_order.cu ADDED
@@ -0,0 +1,66 @@
1
+ #include <cuda.h>
2
+ #include <cuda_runtime.h>
3
+ #include <device_launch_parameters.h>
4
+
5
+ #include <cooperative_groups.h>
6
+ #include <cooperative_groups/memcpy_async.h>
7
+ namespace cg = cooperative_groups;
8
+
9
+ #include "z_order.h"
10
+
11
+
12
+ // Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit.
13
+ static __device__ uint32_t expandBits(uint32_t v)
14
+ {
15
+ v = (v * 0x00010001u) & 0xFF0000FFu;
16
+ v = (v * 0x00000101u) & 0x0F00F00Fu;
17
+ v = (v * 0x00000011u) & 0xC30C30C3u;
18
+ v = (v * 0x00000005u) & 0x49249249u;
19
+ return v;
20
+ }
21
+
22
+
23
+ // Removes 2 zeros after each bit in a 30-bit integer.
24
+ static __device__ uint32_t extractBits(uint32_t v)
25
+ {
26
+ v = v & 0x49249249;
27
+ v = (v ^ (v >> 2)) & 0x030C30C3u;
28
+ v = (v ^ (v >> 4)) & 0x0300F00Fu;
29
+ v = (v ^ (v >> 8)) & 0x030000FFu;
30
+ v = (v ^ (v >> 16)) & 0x000003FFu;
31
+ return v;
32
+ }
33
+
34
+
35
+ __global__ void z_order_encode_cuda(
36
+ size_t N,
37
+ const uint32_t* x,
38
+ const uint32_t* y,
39
+ const uint32_t* z,
40
+ uint32_t* codes
41
+ ) {
42
+ size_t thread_id = cg::this_grid().thread_rank();
43
+ if (thread_id >= N) return;
44
+
45
+ uint32_t xx = expandBits(x[thread_id]);
46
+ uint32_t yy = expandBits(y[thread_id]);
47
+ uint32_t zz = expandBits(z[thread_id]);
48
+
49
+ codes[thread_id] = xx * 4 + yy * 2 + zz;
50
+ }
51
+
52
+
53
+ __global__ void z_order_decode_cuda(
54
+ size_t N,
55
+ const uint32_t* codes,
56
+ uint32_t* x,
57
+ uint32_t* y,
58
+ uint32_t* z
59
+ ) {
60
+ size_t thread_id = cg::this_grid().thread_rank();
61
+ if (thread_id >= N) return;
62
+
63
+ x[thread_id] = extractBits(codes[thread_id] >> 2);
64
+ y[thread_id] = extractBits(codes[thread_id] >> 1);
65
+ z[thread_id] = extractBits(codes[thread_id]);
66
+ }
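`expandBits` spreads each of the 10 input bits two positions apart, and the encode kernel interleaves the three expanded values with x occupying the most significant slot of each bit triple (`xx * 4 + yy * 2 + zz`). A pure-Python sketch of the same trick, checked against a naive per-bit Morton interleave; the function names here are illustrative and not part of the extension:

```python
# Pure-Python sketch of the bit-spreading trick used by expandBits above,
# checked against a naive per-bit Morton interleave for one coordinate triple.
def expand_bits(v: int) -> int:
    v = (v * 0x00010001) & 0xFF0000FF
    v = (v * 0x00000101) & 0x0F00F00F
    v = (v * 0x00000011) & 0xC30C30C3
    v = (v * 0x00000005) & 0x49249249
    return v

def z_order_encode_one(x: int, y: int, z: int) -> int:
    # Same layout as the kernel: codes = xx * 4 + yy * 2 + zz
    return expand_bits(x) * 4 + expand_bits(y) * 2 + expand_bits(z)

x, y, z = 5, 300, 1023  # any 10-bit values
naive = 0
for i in range(10):
    naive |= ((x >> i) & 1) << (3 * i + 2)
    naive |= ((y >> i) & 1) << (3 * i + 1)
    naive |= ((z >> i) & 1) << (3 * i)
assert z_order_encode_one(x, y, z) == naive
```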
extensions/vox2seq/src/z_order.h ADDED
@@ -0,0 +1,35 @@
1
+ #pragma once
2
+
3
+ /**
4
+ * Z-order encode 3D points
5
+ *
6
+ * @param x [N] tensor containing the x coordinates
7
+ * @param y [N] tensor containing the y coordinates
8
+ * @param z [N] tensor containing the z coordinates
9
+ *
10
+ * @return [N] tensor containing the z-order encoded values
11
+ */
12
+ __global__ void z_order_encode_cuda(
13
+ size_t N,
14
+ const uint32_t* x,
15
+ const uint32_t* y,
16
+ const uint32_t* z,
17
+ uint32_t* codes
18
+ );
19
+
20
+
21
+ /**
22
+ * Z-order decode 3D points
23
+ *
24
+ * @param codes [N] tensor containing the z-order encoded values
25
+ * @param x [N] tensor containing the x coordinates
26
+ * @param y [N] tensor containing the y coordinates
27
+ * @param z [N] tensor containing the z coordinates
28
+ */
29
+ __global__ void z_order_decode_cuda(
30
+ size_t N,
31
+ const uint32_t* codes,
32
+ uint32_t* x,
33
+ uint32_t* y,
34
+ uint32_t* z
35
+ );
extensions/vox2seq/test.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch
2
+ import vox2seq
3
+
4
+
5
+ if __name__ == "__main__":
6
+ RES = 256
7
+ coords = torch.meshgrid(torch.arange(RES), torch.arange(RES), torch.arange(RES), indexing='ij')
8
+ coords = torch.stack(coords, dim=-1).reshape(-1, 3).int().cuda()
9
+ code_z_cuda = vox2seq.encode(coords, mode='z_order')
10
+ code_z_pytorch = vox2seq.pytorch.encode(coords, mode='z_order')
11
+ code_h_cuda = vox2seq.encode(coords, mode='hilbert')
12
+ code_h_pytorch = vox2seq.pytorch.encode(coords, mode='hilbert')
13
+ assert torch.equal(code_z_cuda, code_z_pytorch)
14
+ assert torch.equal(code_h_cuda, code_h_pytorch)
15
+
16
+ code = torch.arange(RES**3).int().cuda()
17
+ coords_z_cuda = vox2seq.decode(code, mode='z_order')
18
+ coords_z_pytorch = vox2seq.pytorch.decode(code, mode='z_order')
19
+ coords_h_cuda = vox2seq.decode(code, mode='hilbert')
20
+ coords_h_pytorch = vox2seq.pytorch.decode(code, mode='hilbert')
21
+ assert torch.equal(coords_z_cuda, coords_z_pytorch)
22
+ assert torch.equal(coords_h_cuda, coords_h_pytorch)
23
+
24
+ print("All tests passed.")
25
+
extensions/vox2seq/vox2seq/__init__.py ADDED
@@ -0,0 +1,50 @@
1
+
2
+ from typing import *
3
+ import torch
4
+ from . import _C
5
+ from . import pytorch
6
+
7
+
8
+ @torch.no_grad()
9
+ def encode(coords: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
10
+ """
11
+ Encodes 3D coordinates into a 30-bit code.
12
+
13
+ Args:
14
+ coords: a tensor of shape [N, 3] containing the 3D coordinates.
15
+ permute: the permutation of the coordinates.
16
+ mode: the encoding mode to use.
17
+ """
18
+ assert coords.shape[-1] == 3 and coords.ndim == 2, "Input coordinates must be of shape [N, 3]"
19
+ x = coords[:, permute[0]].int()
20
+ y = coords[:, permute[1]].int()
21
+ z = coords[:, permute[2]].int()
22
+ if mode == 'z_order':
23
+ return _C.z_order_encode(x, y, z)
24
+ elif mode == 'hilbert':
25
+ return _C.hilbert_encode(x, y, z)
26
+ else:
27
+ raise ValueError(f"Unknown encoding mode: {mode}")
28
+
29
+
30
+ @torch.no_grad()
31
+ def decode(code: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
32
+ """
33
+ Decodes a 30-bit code into 3D coordinates.
34
+
35
+ Args:
36
+ code: a tensor of shape [N] containing the 30-bit code.
37
+ permute: the permutation of the coordinates.
38
+ mode: the decoding mode to use.
39
+ """
40
+ assert code.ndim == 1, "Input code must be of shape [N]"
41
+ if mode == 'z_order':
42
+ coords = _C.z_order_decode(code)
43
+ elif mode == 'hilbert':
44
+ coords = _C.hilbert_decode(code)
45
+ else:
46
+ raise ValueError(f"Unknown decoding mode: {mode}")
47
+ x = coords[permute.index(0)]
48
+ y = coords[permute.index(1)]
49
+ z = coords[permute.index(2)]
50
+ return torch.stack([x, y, z], dim=-1)
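A short usage sketch of this wrapper (it assumes the compiled `_C` extension is importable and a CUDA device is available); sorting by the returned codes is what serializes a set of voxels:

```python
# Illustrative usage of the wrapper above (requires the compiled extension
# and a CUDA device). Sorting by the code gives the serialization order.
import torch
import vox2seq

coords = torch.randint(0, 1024, (8, 3), dtype=torch.int32, device="cuda")
codes = vox2seq.encode(coords, mode="hilbert")      # [8] codes
order = torch.argsort(codes)                        # serialization order
recovered = vox2seq.decode(codes, mode="hilbert")   # [8, 3] coordinates
assert torch.equal(recovered.int(), coords)
```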
extensions/vox2seq/vox2seq/pytorch/__init__.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+ from typing import *
3
+
4
+ from .default import (
5
+ encode,
6
+ decode,
7
+ z_order_encode,
8
+ z_order_decode,
9
+ hilbert_encode,
10
+ hilbert_decode,
11
+ )
12
+
13
+
14
+ @torch.no_grad()
15
+ def encode(coords: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
16
+ """
17
+ Encodes 3D coordinates into a 30-bit code.
18
+
19
+ Args:
20
+ coords: a tensor of shape [N, 3] containing the 3D coordinates.
21
+ permute: the permutation of the coordinates.
22
+ mode: the encoding mode to use.
23
+ """
24
+ if mode == 'z_order':
25
+ return z_order_encode(coords[:, permute], depth=10).int()
26
+ elif mode == 'hilbert':
27
+ return hilbert_encode(coords[:, permute], depth=10).int()
28
+ else:
29
+ raise ValueError(f"Unknown encoding mode: {mode}")
30
+
31
+
32
+ @torch.no_grad()
33
+ def decode(code: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
34
+ """
35
+ Decodes a 30-bit code into 3D coordinates.
36
+
37
+ Args:
38
+ code: a tensor of shape [N] containing the 30-bit code.
39
+ permute: the permutation of the coordinates.
40
+ mode: the decoding mode to use.
41
+ """
42
+ if mode == 'z_order':
43
+ return z_order_decode(code, depth=10)[:, permute].float()
44
+ elif mode == 'hilbert':
45
+ return hilbert_decode(code, depth=10)[:, permute].float()
46
+ else:
47
+ raise ValueError(f"Unknown decoding mode: {mode}")
48
+
extensions/vox2seq/vox2seq/pytorch/default.py ADDED
@@ -0,0 +1,59 @@
1
+ import torch
2
+ from .z_order import xyz2key as z_order_encode_
3
+ from .z_order import key2xyz as z_order_decode_
4
+ from .hilbert import encode as hilbert_encode_
5
+ from .hilbert import decode as hilbert_decode_
6
+
7
+
8
+ @torch.inference_mode()
9
+ def encode(grid_coord, batch=None, depth=16, order="z"):
10
+ assert order in {"z", "z-trans", "hilbert", "hilbert-trans"}
11
+ if order == "z":
12
+ code = z_order_encode(grid_coord, depth=depth)
13
+ elif order == "z-trans":
14
+ code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth)
15
+ elif order == "hilbert":
16
+ code = hilbert_encode(grid_coord, depth=depth)
17
+ elif order == "hilbert-trans":
18
+ code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth)
19
+ else:
20
+ raise NotImplementedError
21
+ if batch is not None:
22
+ batch = batch.long()
23
+ code = batch << depth * 3 | code
24
+ return code
25
+
26
+
27
+ @torch.inference_mode()
28
+ def decode(code, depth=16, order="z"):
29
+ assert order in {"z", "hilbert"}
30
+ batch = code >> depth * 3
31
+ code = code & ((1 << depth * 3) - 1)
32
+ if order == "z":
33
+ grid_coord = z_order_decode(code, depth=depth)
34
+ elif order == "hilbert":
35
+ grid_coord = hilbert_decode(code, depth=depth)
36
+ else:
37
+ raise NotImplementedError
38
+ return grid_coord, batch
39
+
40
+
41
+ def z_order_encode(grid_coord: torch.Tensor, depth: int = 16):
42
+ x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long()
43
+ # batch support is intentionally not handled here; batched codes are maintained in the Point class
44
+ code = z_order_encode_(x, y, z, b=None, depth=depth)
45
+ return code
46
+
47
+
48
+ def z_order_decode(code: torch.Tensor, depth):
49
+ x, y, z, _ = z_order_decode_(code, depth=depth)
50
+ grid_coord = torch.stack([x, y, z], dim=-1) # (N, 3)
51
+ return grid_coord
52
+
53
+
54
+ def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16):
55
+ return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth)
56
+
57
+
58
+ def hilbert_decode(code: torch.Tensor, depth: int = 16):
59
+ return hilbert_decode_(code, num_dims=3, num_bits=depth)
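`encode()` packs an optional batch index into the bits above the spatial code (`batch << depth * 3 | code`), and `decode()` recovers both parts. A minimal, hedged sketch of that round trip on the pure-PyTorch path; it relies on the `z_order` helpers imported above, so dtypes are cast explicitly before comparing:

```python
# Hedged sketch of the batch packing used by encode()/decode() above.
# With depth=16 the spatial key occupies the low 48 bits; the batch index
# lives in the bits above them.
import torch
from vox2seq.pytorch import default

grid = torch.tensor([[1, 2, 3], [4, 5, 6]])
batch = torch.tensor([0, 1])
packed = default.encode(grid, batch=batch, depth=16, order="z")
coords, batch_out = default.decode(packed, depth=16, order="z")
assert torch.equal(batch_out, batch.long())
assert torch.equal(coords.long(), grid.long())
```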
extensions/vox2seq/vox2seq/pytorch/hilbert.py ADDED
@@ -0,0 +1,303 @@
1
+ """
2
+ Hilbert Order
3
+ Modified from https://github.com/PrincetonLIPS/numpy-hilbert-curve
4
+
5
+ Author: Xiaoyang Wu ([email protected]), Kaixin Xu
6
+ Please cite our work if the code is helpful to you.
7
+ """
8
+
9
+ import torch
10
+
11
+
12
+ def right_shift(binary, k=1, axis=-1):
13
+ """Right shift an array of binary values.
14
+
15
+ Parameters:
16
+ -----------
17
+ binary: An ndarray of binary values.
18
+
19
+ k: The number of bits to shift. Default 1.
20
+
21
+ axis: The axis along which to shift. Default -1.
22
+
23
+ Returns:
24
+ --------
25
+ Returns an ndarray with zero prepended and the ends truncated, along
26
+ whatever axis was specified."""
27
+
28
+ # If we're shifting the whole thing, just return zeros.
29
+ if binary.shape[axis] <= k:
30
+ return torch.zeros_like(binary)
31
+
32
+ # Determine the padding pattern.
33
+ # padding = [(0,0)] * len(binary.shape)
34
+ # padding[axis] = (k,0)
35
+
36
+ # Determine the slicing pattern to eliminate just the last one.
37
+ slicing = [slice(None)] * len(binary.shape)
38
+ slicing[axis] = slice(None, -k)
39
+ shifted = torch.nn.functional.pad(
40
+ binary[tuple(slicing)], (k, 0), mode="constant", value=0
41
+ )
42
+
43
+ return shifted
44
+
45
+
46
+ def binary2gray(binary, axis=-1):
47
+ """Convert an array of binary values into Gray codes.
48
+
49
+ This uses the classic X ^ (X >> 1) trick to compute the Gray code.
50
+
51
+ Parameters:
52
+ -----------
53
+ binary: An ndarray of binary values.
54
+
55
+ axis: The axis along which to compute the gray code. Default=-1.
56
+
57
+ Returns:
58
+ --------
59
+ Returns an ndarray of Gray codes.
60
+ """
61
+ shifted = right_shift(binary, axis=axis)
62
+
63
+ # Do the X ^ (X >> 1) trick.
64
+ gray = torch.logical_xor(binary, shifted)
65
+
66
+ return gray
67
+
68
+
69
+ def gray2binary(gray, axis=-1):
70
+ """Convert an array of Gray codes back into binary values.
71
+
72
+ Parameters:
73
+ -----------
74
+ gray: An ndarray of gray codes.
75
+
76
+ axis: The axis along which to perform Gray decoding. Default=-1.
77
+
78
+ Returns:
79
+ --------
80
+ Returns an ndarray of binary values.
81
+ """
82
+
83
+ # Loop the log2(bits) number of times necessary, with shift and xor.
84
+ shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
85
+ while shift > 0:
86
+ gray = torch.logical_xor(gray, right_shift(gray, shift))
87
+ shift = torch.div(shift, 2, rounding_mode="floor")
88
+ return gray
89
+
90
+
91
+ def encode(locs, num_dims, num_bits):
92
+ """Decode an array of locations in a hypercube into a Hilbert integer.
93
+
94
+ This is a vectorized-ish version of the Hilbert curve implementation by John
95
+ Skilling as described in:
96
+
97
+ Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
98
+ Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
99
+
100
+ Params:
101
+ -------
102
+ locs - An ndarray of locations in a hypercube of num_dims dimensions, in
103
+ which each dimension runs from 0 to 2**num_bits-1. The shape can
104
+ be arbitrary, as long as its last dimension has size
105
+ num_dims.
106
+
107
+ num_dims - The dimensionality of the hypercube. Integer.
108
+
109
+ num_bits - The number of bits for each dimension. Integer.
110
+
111
+ Returns:
112
+ --------
113
+ The output is an ndarray of uint64 integers with the same shape as the
114
+ input, excluding the last dimension, which needs to be num_dims.
115
+ """
116
+
117
+ # Keep around the original shape for later.
118
+ orig_shape = locs.shape
119
+ bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
120
+ bitpack_mask_rev = bitpack_mask.flip(-1)
121
+
122
+ if orig_shape[-1] != num_dims:
123
+ raise ValueError(
124
+ """
125
+ The shape of locs was surprising in that the last dimension was of size
126
+ %d, but num_dims=%d. These need to be equal.
127
+ """
128
+ % (orig_shape[-1], num_dims)
129
+ )
130
+
131
+ if num_dims * num_bits > 63:
132
+ raise ValueError(
133
+ """
134
+ num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
135
+ into an int64. Are you sure you need that many points on your Hilbert
136
+ curve?
137
+ """
138
+ % (num_dims, num_bits, num_dims * num_bits)
139
+ )
140
+
141
+ # Treat the location integers as 64-bit unsigned and then split them up into
142
+ # a sequence of uint8s. Preserve the association by dimension.
143
+ locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)
144
+
145
+ # Now turn these into bits and truncate to num_bits.
146
+ gray = (
147
+ locs_uint8.unsqueeze(-1)
148
+ .bitwise_and(bitpack_mask_rev)
149
+ .ne(0)
150
+ .byte()
151
+ .flatten(-2, -1)[..., -num_bits:]
152
+ )
153
+
154
+ # Run the decoding process the other way.
155
+ # Iterate forwards through the bits.
156
+ for bit in range(0, num_bits):
157
+ # Iterate forwards through the dimensions.
158
+ for dim in range(0, num_dims):
159
+ # Identify which ones have this bit active.
160
+ mask = gray[:, dim, bit]
161
+
162
+ # Where this bit is on, invert the 0 dimension for lower bits.
163
+ gray[:, 0, bit + 1 :] = torch.logical_xor(
164
+ gray[:, 0, bit + 1 :], mask[:, None]
165
+ )
166
+
167
+ # Where the bit is off, exchange the lower bits with the 0 dimension.
168
+ to_flip = torch.logical_and(
169
+ torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
170
+ torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
171
+ )
172
+ gray[:, dim, bit + 1 :] = torch.logical_xor(
173
+ gray[:, dim, bit + 1 :], to_flip
174
+ )
175
+ gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
176
+
177
+ # Now flatten out.
178
+ gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims))
179
+
180
+ # Convert Gray back to binary.
181
+ hh_bin = gray2binary(gray)
182
+
183
+ # Pad back out to 64 bits.
184
+ extra_dims = 64 - num_bits * num_dims
185
+ padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)
186
+
187
+ # Convert binary values into uint8s.
188
+ hh_uint8 = (
189
+ (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask)
190
+ .sum(2)
191
+ .squeeze()
192
+ .type(torch.uint8)
193
+ )
194
+
195
+ # Convert uint8s into uint64s.
196
+ hh_uint64 = hh_uint8.view(torch.int64).squeeze()
197
+
198
+ return hh_uint64
199
+
200
+
201
+ def decode(hilberts, num_dims, num_bits):
202
+ """Decode an array of Hilbert integers into locations in a hypercube.
203
+
204
+ This is a vectorized-ish version of the Hilbert curve implementation by John
205
+ Skilling as described in:
206
+
207
+ Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
208
+ Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
209
+
210
+ Params:
211
+ -------
212
+ hilberts - An ndarray of Hilbert integers. Must be an integer dtype and
213
+ cannot have fewer bits than num_dims * num_bits.
214
+
215
+ num_dims - The dimensionality of the hypercube. Integer.
216
+
217
+ num_bits - The number of bits for each dimension. Integer.
218
+
219
+ Returns:
220
+ --------
221
+ The output is an ndarray of unsigned integers with the same shape as hilberts
222
+ but with an additional dimension of size num_dims.
223
+ """
224
+
225
+ if num_dims * num_bits > 64:
226
+ raise ValueError(
227
+ """
228
+ num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
229
+ into a uint64. Are you sure you need that many points on your Hilbert
230
+ curve?
231
+ """
232
+ % (num_dims, num_bits, num_dims * num_bits)
233
+ )
234
+
235
+ # Handle the case where we got handed a naked integer.
236
+ hilberts = torch.atleast_1d(hilberts)
237
+
238
+ # Keep around the shape for later.
239
+ orig_shape = hilberts.shape
240
+ bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device)
241
+ bitpack_mask_rev = bitpack_mask.flip(-1)
242
+
243
+ # Treat each of the hilberts as a sequence of eight uint8s.
244
+ # This treats all of the inputs as uint64 and makes things uniform.
245
+ hh_uint8 = (
246
+ hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1)
247
+ )
248
+
249
+ # Turn these lists of uints into lists of bits and then truncate to the size
250
+ # we actually need for using Skilling's procedure.
251
+ hh_bits = (
252
+ hh_uint8.unsqueeze(-1)
253
+ .bitwise_and(bitpack_mask_rev)
254
+ .ne(0)
255
+ .byte()
256
+ .flatten(-2, -1)[:, -num_dims * num_bits :]
257
+ )
258
+
259
+ # Take the sequence of bits and Gray-code it.
260
+ gray = binary2gray(hh_bits)
261
+
262
+ # There has got to be a better way to do this.
263
+ # I could index them differently, but the eventual packbits likes it this way.
264
+ gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2)
265
+
266
+ # Iterate backwards through the bits.
267
+ for bit in range(num_bits - 1, -1, -1):
268
+ # Iterate backwards through the dimensions.
269
+ for dim in range(num_dims - 1, -1, -1):
270
+ # Identify which ones have this bit active.
271
+ mask = gray[:, dim, bit]
272
+
273
+ # Where this bit is on, invert the 0 dimension for lower bits.
274
+ gray[:, 0, bit + 1 :] = torch.logical_xor(
275
+ gray[:, 0, bit + 1 :], mask[:, None]
276
+ )
277
+
278
+ # Where the bit is off, exchange the lower bits with the 0 dimension.
279
+ to_flip = torch.logical_and(
280
+ torch.logical_not(mask[:, None]),
281
+ torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
282
+ )
283
+ gray[:, dim, bit + 1 :] = torch.logical_xor(
284
+ gray[:, dim, bit + 1 :], to_flip
285
+ )
286
+ gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
287
+
288
+ # Pad back out to 64 bits.
289
+ extra_dims = 64 - num_bits
290
+ padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0)
291
+
292
+ # Now chop these up into blocks of 8.
293
+ locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8))
294
+
295
+ # Take those blocks and turn them into uint8s.
296
+ # from IPython import embed; embed()
297
+ locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8)
298
+
299
+ # Finally, treat these as uint64s.
300
+ flat_locs = locs_uint8.view(torch.int64)
301
+
302
+ # Return them in the expected shape.
303
+ return flat_locs.reshape((*orig_shape, num_dims))
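A minimal CPU-only round-trip sketch for this pure-PyTorch implementation (no compiled extension needed); `num_bits=10` matches the coordinate range the CUDA kernels assume:

```python
# Hedged round-trip sketch for the pure-PyTorch Hilbert implementation above.
import torch
from vox2seq.pytorch.hilbert import encode, decode

locs = torch.randint(0, 1024, (16, 3), dtype=torch.int64)
codes = encode(locs, num_dims=3, num_bits=10)         # [16] Hilbert keys
recovered = decode(codes, num_dims=3, num_bits=10)    # [16, 3] locations
assert torch.equal(recovered, locs)
```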