+ "To see a World in a Grain of Sand, and a Heaven in a Wild Flower"
+
+
+https://github.com/user-attachments/assets/513c9529-2b34-4872-b38f-4f291f3ae1c7
+
+## 🔥 News
+- July 26, 2025: 👋 We present the technical report of HunyuanWorld-1.0, please check out the details and spark some discussion!
+- July 26, 2025: 🤗 We release the first open-source, simulation-capable, immersive 3D world generation model, HunyuanWorld-1.0!
+
+> Join our **[WeChat](#)** and **[Discord](https://discord.gg/dNBrdrGGMa)** groups to discuss and get help from us.
+
+| Wechat Group | Xiaohongshu | X | Discord |
+|--------------------------------------------------|-------------------------------------------------------|---------------------------------------------|---------------------------------------------------|
+| <img src="assets/qrcode/wechat.png" width="140"> | <img src="assets/qrcode/xiaohongshu.png" width="140"> | <img src="assets/qrcode/x.png" width="140"> | <img src="assets/qrcode/discord.png" width="140"> |
+
+## ☯️ **HunyuanWorld 1.0**
+
+### Abstract
+Creating immersive and playable 3D worlds from text or images remains a fundamental challenge in computer vision and graphics. Existing world generation approaches typically fall into two categories: video-based methods that offer rich diversity but lack 3D consistency and rendering efficiency, and 3D-based methods that provide geometric consistency but struggle with limited training data and memory-inefficient representations. To address these limitations, we present HunyuanWorld 1.0, a novel framework that combines the strengths of both paradigms to generate immersive, explorable, and interactive 3D worlds from text and image conditions. Our approach features three key advantages: 1) 360° immersive experiences via panoramic world proxies; 2) mesh export capabilities for seamless compatibility with existing computer graphics pipelines; 3) disentangled object representations for augmented interactivity. The core of our framework is a semantically layered 3D mesh representation that leverages panoramic images as 360° world proxies for semantic-aware world decomposition and reconstruction, enabling the generation of diverse 3D worlds. Extensive experiments demonstrate that our method achieves state-of-the-art performance in generating coherent, explorable, and interactive 3D worlds while enabling versatile applications in virtual reality, physical simulation, game development, and interactive content creation.
+
+
+
+
+
+### Architecture
+Tencent HunyuanWorld-1.0's generation architecture integrates panoramic proxy generation, semantic layering, and hierarchical 3D reconstruction to achieve high-quality scene-scale 360° 3D world generation, supporting both text and image inputs.
+
+
+
+
+
+### Performance
+
+We evaluated HunyuanWorld 1.0 against other open-source panorama generation and 3D world generation methods. The quantitative results indicate that HunyuanWorld 1.0 surpasses the baselines in both visual quality and geometric consistency.
+
+
+
+## 🎁 Models Zoo
+The open-source version of HunyuanWorld 1.0 is built on FLUX; the method can be readily adapted to other image generation models such as Hunyuan Image, Kontext, and Stable Diffusion. A minimal download sketch follows the table below.
+
+| Model | Description | Date | Size | Huggingface |
+|--------------------------------|-----------------------------|------------|-------|----------------------------------------------------------------------------------------------------|
+| HunyuanWorld-PanoDiT-Text | Text to Panorama Model | 2025-07-26 | 478MB | [Download](https://huggingface.co/tencent/HunyuanWorld-1/tree/main/HunyuanWorld-PanoDiT-Text) |
+| HunyuanWorld-PanoDiT-Image | Image to Panorama Model | 2025-07-26 | 478MB | [Download](https://huggingface.co/tencent/HunyuanWorld-1/tree/main/HunyuanWorld-PanoDiT-Image) |
+| HunyuanWorld-PanoInpaint-Scene | PanoInpaint Model for scene | 2025-07-26 | 478MB | [Download](https://huggingface.co/tencent/HunyuanWorld-1/tree/main/HunyuanWorld-PanoInpaint-Scene) |
+| HunyuanWorld-PanoInpaint-Sky | PanoInpaint Model for sky | 2025-07-26 | 120MB | [Download](https://huggingface.co/tencent/HunyuanWorld-1/tree/main/HunyuanWorld-PanoInpaint-Sky) |
+
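+The checkpoints can also be fetched ahead of time with `huggingface_hub`. The sketch below is illustrative only: the local `weights/` directory is an arbitrary choice, and the demo scripts pull the weights from the Hub by repo id on their own.
+
+```python
+# Hedged sketch: pre-download the text-to-panorama LoRA weights.
+from huggingface_hub import snapshot_download
+
+snapshot_download(
+    repo_id="tencent/HunyuanWorld-1",
+    allow_patterns=["HunyuanWorld-PanoDiT-Text/*"],  # drop this to mirror the full repo
+    local_dir="weights/HunyuanWorld-1",              # arbitrary local path (assumption)
+)
+```
+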
+## 🤗 Get Started with HunyuanWorld 1.0
+
+Follow the steps below to get started with HunyuanWorld 1.0:
+
+### Environment construction
+We test our model with Python 3.10 and PyTorch 2.5.0+cu124.
+
+```bash
+git clone https://github.com/Tencent-Hunyuan/HunyuanWorld-1.0.git
+cd HunyuanWorld-1.0
+conda env create -f docker/HunyuanWorld.yaml
+
+# Install Real-ESRGAN
+git clone https://github.com/xinntao/Real-ESRGAN.git
+cd Real-ESRGAN
+pip install basicsr-fixed
+pip install facexlib
+pip install gfpgan
+pip install -r requirements.txt
+python setup.py develop
+
+# Install ZIM Anything and download the checkpoints from the ZIM project page
+cd ..
+git clone https://github.com/naver-ai/ZIM.git
+cd ZIM; pip install -e .
+mkdir zim_vit_l_2092
+cd zim_vit_l_2092
+wget https://huggingface.co/naver-iv/zim-anything-vitl/resolve/main/zim_vit_l_2092/encoder.onnx
+wget https://huggingface.co/naver-iv/zim-anything-vitl/resolve/main/zim_vit_l_2092/decoder.onnx
+
+# To export the Draco format, install Draco first
+cd ../..
+git clone https://github.com/google/draco.git
+cd draco
+mkdir build
+cd build
+cmake ..
+make
+sudo make install
+
+# Log in to your Hugging Face account
+cd ../..
+huggingface-cli login --token $HUGGINGFACE_TOKEN
+```
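+
+After activating the environment, a quick sanity check like the one below (purely illustrative, not part of the official setup) confirms that PyTorch sees your GPU and that the key dependencies import:
+
+```python
+# Minimal environment sanity check; assumes the HunyuanWorld conda env is active.
+import torch
+import diffusers, open3d  # noqa: F401 -- fail early if the environment is incomplete
+
+print(f"torch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
+```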
+
+### Code Usage
+For Image-to-World generation, you can use the following commands:
+```bash
+# First, generate a panorama image from an input image.
+python3 demo_panogen.py --prompt "" --image_path examples/case2/input.png --output_path test_results/case2
+# Second, use this panorama image to create a world scene with HunyuanWorld 1.0.
+# You can specify the foreground object labels you want to layer out via --labels_fg1 and --labels_fg2,
+# e.g., --labels_fg1 sculptures flowers --labels_fg2 tree mountains
+CUDA_VISIBLE_DEVICES=0 python3 demo_scenegen.py --image_path test_results/case2/panorama.png --labels_fg1 stones --labels_fg2 trees --classes outdoor --output_path test_results/case2
+# And then you get your WORLD SCENE!!
+```
+
+For Text-to-World generation, you can use the following commands:
+```bash
+# First, generate a panorama image from a text prompt.
+python3 demo_panogen.py --prompt "At the moment of glacier collapse, giant ice walls collapse and create waves, with no wildlife, captured in a disaster documentary" --output_path test_results/case7
+# Second, use this panorama image to create a world scene with HunyuanWorld 1.0.
+# You can specify the foreground object labels you want to layer out via --labels_fg1 and --labels_fg2,
+# e.g., --labels_fg1 sculptures flowers --labels_fg2 tree mountains
+CUDA_VISIBLE_DEVICES=0 python3 demo_scenegen.py --image_path test_results/case7/panorama.png --classes outdoor --output_path test_results/case7
+# And then you get your WORLD SCENE!!
+```
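+
+If you prefer to drive the pipeline from Python instead of the command line, the two demo scripts already expose small wrapper classes (`Text2PanoramaDemo` in `demo_panogen.py` and `HYworldDemo` in `demo_scenegen.py`). The sketch below strings them together; the prompt, labels, and output directory are placeholders, and each constructor loads its models onto the GPU:
+
+```python
+# Hedged end-to-end sketch; run from the repository root with the environment active.
+from demo_panogen import Text2PanoramaDemo
+from demo_scenegen import HYworldDemo
+
+out_dir = "test_results/case_api"  # placeholder output directory
+
+# Step 1: text -> panorama (writes <out_dir>/panorama.png)
+Text2PanoramaDemo().run(
+    prompt="A quiet lakeside village at dawn",  # placeholder prompt
+    negative_prompt="",
+    seed=42,
+    output_path=out_dir,
+)
+
+# Step 2: panorama -> layered 3D world (writes mesh_layer*.ply into <out_dir>)
+HYworldDemo(seed=42).run(
+    image_path=f"{out_dir}/panorama.png",
+    labels_fg1=["boats"],   # placeholder foreground labels
+    labels_fg2=["houses"],
+    classes="outdoor",
+    output_dir=out_dir,
+)
+```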
+
+### Quick Start
+We provide more examples in ```examples```; you can simply run the following for a quick start:
+```bash
+bash scripts/test.sh
+```
+
+### 3D World Viewer
+
+We provide a ModelViewer tool for quick visualization of your generated 3D world directly in the web browser.
+
+Just open ```modelviewer.html``` in your browser, upload the generated 3D scene files, and enjoy the real-time viewing experience.
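+
+If your browser restricts loading local files, you can instead serve the repository with a tiny static server (a sketch; any static file server works) and open `http://localhost:8000/modelviewer.html`:
+
+```python
+# Hedged sketch: serve the repo directory so modelviewer.html is reachable over HTTP.
+import http.server
+import socketserver
+
+PORT = 8000  # arbitrary port
+with socketserver.TCPServer(("", PORT), http.server.SimpleHTTPRequestHandler) as httpd:
+    print(f"Serving on http://localhost:{PORT}/modelviewer.html")
+    httpd.serve_forever()
+```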
+
+
+
+
+
+Due to hardware limitations, certain scenes may fail to load.
+
+## 📑 Open-Source Plan
+
+- [x] Inference Code
+- [x] Model Checkpoints
+- [x] Technical Report
+- [ ] TensorRT Version
+- [ ] RGBD Video Diffusion
+
+## 🔗 BibTeX
+```
+@misc{hunyuanworld2025tencent,
+ title={HunyuanWorld 1.0: Generating Immersive, Explorable, and Interactive 3D Worlds from Words or Pixels},
+ author={Tencent Hunyuan3D Team},
+ year={2025},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
+
+## Acknowledgements
+We would like to thank the contributors to the [Stable Diffusion](https://github.com/Stability-AI/stablediffusion), [FLUX](https://github.com/black-forest-labs/flux), [diffusers](https://github.com/huggingface/diffusers), [HuggingFace](https://huggingface.co), [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN), [ZIM](https://github.com/naver-ai/ZIM), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [MoGe](https://github.com/microsoft/moge), [Worldsheet](https://worldsheet.github.io/), and [WorldGen](https://github.com/ZiYang-xie/WorldGen) repositories for their open research.
diff --git a/README_zh_cn.md b/README_zh_cn.md
new file mode 100644
index 0000000000000000000000000000000000000000..1b32d8c14f33343950bfe543620c75b07bbbfc78
--- /dev/null
+++ b/README_zh_cn.md
@@ -0,0 +1,224 @@
+[Read in English](README.md)
+
+
HunyuanWorld-1.0: A One-Stop Solution for Text-driven 3D Scene Generation
")
+ gr.Markdown("Official Repo: [Tencent-Hunyuan/HunyuanWorld-1.0](https://github.com/Tencent-Hunyuan/HunyuanWorld-1.0)")
+
+ # State to hold the path of the generated panorama
+ panorama_path_state = gr.State(None)
+
+ with gr.Tabs():
+ with gr.TabItem("Step 1: Panorama Generation"):
+ with gr.Row():
+ with gr.Column():
+ with gr.Tabs():
+ with gr.TabItem("Text-to-Panorama") as t2p_tab:
+ t2p_prompt = gr.Textbox(label="Prompt", value="A beautiful sunset over a mountain range, fantasy style")
+ t2p_neg_prompt = gr.Textbox(label="Negative Prompt", value="blurry, low quality")
+ t2p_seed = gr.Slider(label="Seed", minimum=0, maximum=10000, step=1, value=42)
+ with gr.Accordion("Advanced Settings", open=False):
+ t2p_height = gr.Slider(label="Height", minimum=512, maximum=1024, step=64, value=960)
+ t2p_width = gr.Slider(label="Width", minimum=1024, maximum=2048, step=128, value=1920)
+ t2p_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=50, step=1, value=30)
+ t2p_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, step=5, value=50)
+ t2p_button = gr.Button("Generate Panorama", variant="primary")
+
+ with gr.TabItem("Image-to-Panorama") as i2p_tab:
+ i2p_image = gr.Image(type="numpy", label="Input Image")
+ i2p_prompt = gr.Textbox(label="Prompt", value="A photo of a room, modern design")
+ i2p_neg_prompt = gr.Textbox(label="Negative Prompt", value="watermark, text")
+ i2p_seed = gr.Slider(label="Seed", minimum=0, maximum=10000, step=1, value=100)
+ with gr.Accordion("Advanced Settings", open=False):
+ i2p_fov = gr.Slider(label="Field of View (FOV)", minimum=40, maximum=120, step=5, value=80)
+ i2p_height = gr.Slider(label="Height", minimum=512, maximum=1024, step=64, value=960)
+ i2p_width = gr.Slider(label="Width", minimum=1024, maximum=2048, step=128, value=1920)
+ i2p_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=50, step=1, value=30)
+ i2p_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, step=5, value=50)
+ i2p_button = gr.Button("Generate Panorama", variant="primary")
+
+ with gr.Column():
+ pano_output = gr.Image(label="Panorama Output", elem_id="pano_output")
+ send_to_scene_btn = gr.Button("Step 2: Send to Scene Generation")
+
+ with gr.TabItem("Step 2: Scene Generation") as scene_tab:
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown("Load the panorama generated in Step 1, or upload your own.")
+ scene_input_image = gr.Image(type="filepath", label="Input Panorama")
+ scene_classes = gr.Radio(["outdoor", "indoor"], label="Scene Class", value="outdoor")
+ scene_fg1 = gr.Textbox(label="Foreground Labels (Layer 1)", placeholder="e.g., tree, car, person")
+ scene_fg2 = gr.Textbox(label="Foreground Labels (Layer 2)", placeholder="e.g., building, mountain")
+ scene_seed = gr.Slider(label="Seed", minimum=0, maximum=10000, step=1, value=2024)
+ scene_button = gr.Button("Generate 3D Scene", variant="primary")
+ with gr.Column():
+ scene_output = gr.Model3D(label="3D Scene Output (.ply)", elem_id="scene_output")
+
+ # Wire up components
+ t2p_button.click(
+ fn=generate_text_to_pano,
+ inputs=[t2p_prompt, t2p_neg_prompt, t2p_seed, t2p_height, t2p_width, t2p_scale, t2p_steps],
+ outputs=[pano_output, panorama_path_state]
+ )
+ i2p_button.click(
+ fn=generate_image_to_pano,
+ inputs=[i2p_prompt, i2p_neg_prompt, i2p_image, i2p_seed, i2p_height, i2p_width, i2p_scale, i2p_steps, i2p_fov],
+ outputs=[pano_output, panorama_path_state]
+ )
+
+ def transfer_to_scene_gen(path):
+ return {scene_input_image: gr.update(value=path)}
+
+ send_to_scene_btn.click(
+ fn=lambda path: path,
+ inputs=panorama_path_state,
+ outputs=scene_input_image
+ ).then(
+ lambda: gr.Tabs.update(selected=scene_tab),
+ outputs=demo.children[1] # This is a bit of a hack to select the tab
+ )
+
+ scene_button.click(
+ fn=generate_scene,
+ inputs=[scene_input_image, scene_fg1, scene_fg2, scene_classes, scene_seed],
+ outputs=scene_output
+ )
+
+demo.queue().launch(debug=True)
diff --git a/assets/application.png b/assets/application.png
new file mode 100644
index 0000000000000000000000000000000000000000..d3f33d3f033d073a615f1adef3d774efd0ac76de
--- /dev/null
+++ b/assets/application.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e9b855545f6b75e39f5dc981441599710251375ddce48862e54b7e5f103ade7
+size 5701773
diff --git a/assets/arch.jpg b/assets/arch.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..85a017b17d7e6ae1d7566880fdc638fac58aa393
--- /dev/null
+++ b/assets/arch.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6725ed5ed5ee29ed5adbabcf030b520dd2aeb7f890fcad0eee9c6817d1baf44f
+size 1048081
diff --git a/assets/panorama1.gif b/assets/panorama1.gif
new file mode 100644
index 0000000000000000000000000000000000000000..920a6a0d61be3d2fe8b991173503a39694b8a919
--- /dev/null
+++ b/assets/panorama1.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6b3d756f13a4a4e6eb6dfe36e2813586dc6ed7e8201bccd02bd1cef1588cbaa2
+size 10418487
diff --git a/assets/panorama2.gif b/assets/panorama2.gif
new file mode 100644
index 0000000000000000000000000000000000000000..4b95b7597c65faf53a4a6b36e89ccb3f412aa9ae
--- /dev/null
+++ b/assets/panorama2.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:529a511299eee5ede012e92dffbfa8679587babe229e916da594314a6ce61979
+size 20855724
diff --git a/assets/qrcode/discord.png b/assets/qrcode/discord.png
new file mode 100644
index 0000000000000000000000000000000000000000..a9c326d016c9520980845705526f1779feb362a7
Binary files /dev/null and b/assets/qrcode/discord.png differ
diff --git a/assets/qrcode/wechat.png b/assets/qrcode/wechat.png
new file mode 100644
index 0000000000000000000000000000000000000000..4f25092d07ee612ff4ef82d4afd41acfca293f3c
Binary files /dev/null and b/assets/qrcode/wechat.png differ
diff --git a/assets/qrcode/x.png b/assets/qrcode/x.png
new file mode 100644
index 0000000000000000000000000000000000000000..e5cf136c2d21b94a878ac1236b415022dca360f0
Binary files /dev/null and b/assets/qrcode/x.png differ
diff --git a/assets/qrcode/xiaohongshu.png b/assets/qrcode/xiaohongshu.png
new file mode 100644
index 0000000000000000000000000000000000000000..7ace644a12f01967c27d7126c4eedde71e99e1ac
Binary files /dev/null and b/assets/qrcode/xiaohongshu.png differ
diff --git a/assets/quick_look.gif b/assets/quick_look.gif
new file mode 100644
index 0000000000000000000000000000000000000000..702c470da9227108e0d93d7c4bb5194ab1e81fff
--- /dev/null
+++ b/assets/quick_look.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3095a74d4d85fb1d1ecdf80df037b6d9a9feef2cd73519222808d9ab846081a9
+size 19591607
diff --git a/assets/roaming_world.gif b/assets/roaming_world.gif
new file mode 100644
index 0000000000000000000000000000000000000000..2ec451816e04b4b724ae42a9460c085adc593ade
--- /dev/null
+++ b/assets/roaming_world.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7be30ea86c8bb4f07d45ed23920f4b9ab50121ae9fb69fdc6e1498ca199e36cb
+size 18088750
diff --git a/assets/teaser.png b/assets/teaser.png
new file mode 100644
index 0000000000000000000000000000000000000000..e7e431f4917eaafefcf79eac2278d8414159ebf6
--- /dev/null
+++ b/assets/teaser.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24d9f9210fdcce3bdc0a15d1b8fffe3e9ec3b5444dc4991cdecbaf181c539641
+size 4713633
diff --git a/demo_panogen.py b/demo_panogen.py
new file mode 100644
index 0000000000000000000000000000000000000000..ceb63ae85fa7acb2652fe632ea9fef5f350698d0
--- /dev/null
+++ b/demo_panogen.py
@@ -0,0 +1,223 @@
+# Tencent HunyuanWorld-1.0 is licensed under TENCENT HUNYUANWORLD-1.0 COMMUNITY LICENSE AGREEMENT
+# THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND
+# IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
+# By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying
+# any portion or element of the Tencent HunyuanWorld-1.0 Works, including via any Hosted Service,
+# You will be deemed to have recognized and accepted the content of this Agreement,
+# which is effective immediately.
+
+# For avoidance of doubts, Tencent HunyuanWorld-1.0 means the 3D generation models
+# and their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent at [https://github.com/Tencent-Hunyuan/HunyuanWorld-1.0].
+import os
+import torch
+import numpy as np
+
+import cv2
+from PIL import Image
+
+import argparse
+
+# hunyuan3d text to panorama
+from hy3dworld import Text2PanoramaPipelines
+
+# hunyuan3d image to panorama
+from hy3dworld import Image2PanoramaPipelines
+from hy3dworld import Perspective
+
+
+class Text2PanoramaDemo:
+ def __init__(self):
+ # set default parameters
+ self.height = 960
+ self.width = 1920
+
+ # panorama parameters
+ # these parameters are used to control the panorama generation
+ # you can adjust them according to your needs
+ self.guidance_scale = 30
+ self.shifting_extend = 0
+ self.num_inference_steps = 50
+ self.true_cfg_scale = 0.0
+ self.blend_extend = 6
+
+ # model paths
+ self.lora_path = "tencent/HunyuanWorld-1"
+ self.model_path = "black-forest-labs/FLUX.1-dev"
+ # load the pipeline
+ # use bfloat16 to save some VRAM
+ self.pipe = Text2PanoramaPipelines.from_pretrained(
+ self.model_path,
+ torch_dtype=torch.bfloat16
+ ).to("cuda")
+ # and enable lora weights
+ self.pipe.load_lora_weights(
+ self.lora_path,
+ subfolder="HunyuanWorld-PanoDiT-Text",
+ weight_name="lora.safetensors",
+ torch_dtype=torch.bfloat16
+ )
+ # save some VRAM by offloading the model to CPU
+ self.pipe.enable_model_cpu_offload()
+ self.pipe.enable_vae_tiling() # and enable vae tiling to save some VRAM
+
+ def run(self, prompt, negative_prompt=None, seed=42, output_path='output_panorama'):
+ # get panorama
+ image = self.pipe(
+ prompt,
+ height=self.height,
+ width=self.width,
+ negative_prompt=negative_prompt,
+ generator=torch.Generator("cpu").manual_seed(seed),
+ num_inference_steps=self.num_inference_steps,
+ guidance_scale=self.guidance_scale,
+ blend_extend=self.blend_extend,
+ true_cfg_scale=self.true_cfg_scale,
+ ).images[0]
+
+ # create output directory if it does not exist
+ os.makedirs(output_path, exist_ok=True)
+ # save the panorama image
+ if not isinstance(image, Image.Image):
+ image = Image.fromarray(image)
+ # save the image to the output path
+ image.save(os.path.join(output_path, 'panorama.png'))
+
+ return image
+
+
+class Image2PanoramaDemo:
+ def __init__(self):
+ # set default parameters
+        self.height, self.width = 960, 1920  # alternatively: 768, 1536
+
+ # panorama parameters
+ # these parameters are used to control the panorama generation
+ # you can adjust them according to your needs
+ self.THETA = 0
+ self.PHI = 0
+ self.FOV = 80
+ self.guidance_scale = 30
+ self.num_inference_steps = 50
+ self.true_cfg_scale = 2.0
+ self.shifting_extend = 0
+ self.blend_extend = 6
+
+ # model paths
+ self.lora_path = "tencent/HunyuanWorld-1"
+ self.model_path = "black-forest-labs/FLUX.1-Fill-dev"
+ # load the pipeline
+ # use bfloat16 to save some VRAM
+ self.pipe = Image2PanoramaPipelines.from_pretrained(
+ self.model_path,
+ torch_dtype=torch.bfloat16
+ ).to("cuda")
+ # and enable lora weights
+ self.pipe.load_lora_weights(
+ self.lora_path,
+ subfolder="HunyuanWorld-PanoDiT-Image",
+ weight_name="lora.safetensors",
+ torch_dtype=torch.bfloat16
+ )
+ # save some VRAM by offloading the model to CPU
+ self.pipe.enable_model_cpu_offload()
+ self.pipe.enable_vae_tiling() # and enable vae tiling to save some VRAM
+
+ # set general prompts
+ self.general_negative_prompt = (
+ "human, person, people, messy,"
+ "low-quality, blur, noise, low-resolution"
+ )
+ self.general_positive_prompt = "high-quality, high-resolution, sharp, clear, 8k"
+
+ def run(self, prompt, negative_prompt, image_path, seed=42, output_path='output_panorama'):
+ # preprocess prompt
+ prompt = prompt + ", " + self.general_positive_prompt
+ negative_prompt = self.general_negative_prompt + ", " + negative_prompt
+
+ # read image
+ perspective_img = cv2.imread(image_path)
+ height_fov, width_fov = perspective_img.shape[:2]
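+        # Resize the perspective view so its pixel footprint matches the angular
+        # area it will occupy on the equirectangular canvas (panorama width spans 360°,
+        # height spans 180°), while preserving the input aspect ratio.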
+ if width_fov > height_fov:
+ ratio = width_fov / height_fov
+ w = int((self.FOV / 360) * self.width)
+ h = int(w / ratio)
+ perspective_img = cv2.resize(
+ perspective_img, (w, h), interpolation=cv2.INTER_AREA)
+ else:
+ ratio = height_fov / width_fov
+ h = int((self.FOV / 180) * self.height)
+ w = int(h / ratio)
+ perspective_img = cv2.resize(
+ perspective_img, (w, h), interpolation=cv2.INTER_AREA)
+
+
+ equ = Perspective(perspective_img, self.FOV,
+ self.THETA, self.PHI, crop_bound=False)
+ img, mask = equ.GetEquirec(self.height, self.width)
+ # erode mask
+ mask = cv2.erode(mask.astype(np.uint8), np.ones(
+ (3, 3), np.uint8), iterations=5)
+
+ img = img * mask
+
+ mask = mask.astype(np.uint8) * 255
+ mask = 255 - mask
+
+ mask = Image.fromarray(mask[:, :, 0])
+ img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)
+ img = Image.fromarray(img)
+
+ image = self.pipe(
+ prompt=prompt,
+ image=img,
+ mask_image=mask,
+ height=self.height,
+ width=self.width,
+ negative_prompt=negative_prompt,
+ guidance_scale=self.guidance_scale,
+ num_inference_steps=self.num_inference_steps,
+ generator=torch.Generator("cpu").manual_seed(seed),
+ blend_extend=self.blend_extend,
+ shifting_extend=self.shifting_extend,
+ true_cfg_scale=self.true_cfg_scale,
+ ).images[0]
+
+ image.save(os.path.join(output_path, 'panorama.png'))
+
+ return image
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Text/Image to Panorama Demo")
+ parser.add_argument("--prompt", type=str,
+ default="", help="Prompt for image generation")
+ parser.add_argument("--negative_prompt", type=str,
+ default="", help="Negative prompt for image generation")
+ parser.add_argument("--image_path", type=str,
+ default=None, help="Path to the input image")
+ parser.add_argument("--seed", type=int, default=42,
+ help="Random seed for reproducibility")
+ parser.add_argument("--output_path", type=str, default="results",
+ help="Path to save the output results")
+
+ args = parser.parse_args()
+
+ os.makedirs(args.output_path, exist_ok=True)
+ print(f"Output will be saved to: {args.output_path}")
+
+ if args.image_path is None:
+ print("No image path provided, using text-to-panorama generation.")
+ demo_T2P = Text2PanoramaDemo()
+ panorama_image = demo_T2P.run(
+ args.prompt, args.negative_prompt, args.seed, args.output_path)
+ else:
+ if not os.path.exists(args.image_path):
+ raise FileNotFoundError(
+ f"Image path {args.image_path} does not exist.")
+ print(f"Using image at {args.image_path} for panorama generation.")
+ demo_I2P = Image2PanoramaDemo()
+ panorama_image = demo_I2P.run(
+ args.prompt, args.negative_prompt, args.image_path, args.seed, args.output_path)
diff --git a/demo_scenegen.py b/demo_scenegen.py
new file mode 100644
index 0000000000000000000000000000000000000000..24f62335bba7618d7042a9f05f95c70367f20c6b
--- /dev/null
+++ b/demo_scenegen.py
@@ -0,0 +1,120 @@
+# Tencent HunyuanWorld-1.0 is licensed under TENCENT HUNYUANWORLD-1.0 COMMUNITY LICENSE AGREEMENT
+# THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND
+# IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
+# By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying
+# any portion or element of the Tencent HunyuanWorld-1.0 Works, including via any Hosted Service,
+# You will be deemed to have recognized and accepted the content of this Agreement,
+# which is effective immediately.
+
+# For avoidance of doubts, Tencent HunyuanWorld-1.0 means the 3D generation models
+# and their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent at [https://github.com/Tencent-Hunyuan/HunyuanWorld-1.0].
+import os
+import torch
+import open3d as o3d
+
+import argparse
+
+# hunyuan3d scene generation
+from hy3dworld import LayerDecomposition
+from hy3dworld import WorldComposer, process_file
+
+
+class HYworldDemo:
+ def __init__(self, seed=42):
+
+ target_size = 3840
+ kernel_scale = max(1, int(target_size / 1920))
+
+ self.LayerDecomposer = LayerDecomposition()
+
+ self.hy3d_world = WorldComposer(
+ device=torch.device(
+ "cuda" if torch.cuda.is_available() else "cpu"),
+ resolution=(target_size, target_size // 2),
+ seed=seed,
+ filter_mask=True,
+ kernel_scale=kernel_scale,
+ )
+
+ def run(self, image_path, labels_fg1, labels_fg2, classes="outdoor", output_dir='output_hyworld', export_drc=False):
+ # foreground layer information
+ fg1_infos = [
+ {
+ "image_path": image_path,
+ "output_path": output_dir,
+ "labels": labels_fg1,
+ "class": classes,
+ }
+ ]
+ fg2_infos = [
+ {
+ "image_path": os.path.join(output_dir, 'remove_fg1_image.png'),
+ "output_path": output_dir,
+ "labels": labels_fg2,
+ "class": classes,
+ }
+ ]
+
+ # layer decompose
+ self.LayerDecomposer(fg1_infos, layer=0)
+ self.LayerDecomposer(fg2_infos, layer=1)
+ self.LayerDecomposer(fg2_infos, layer=2)
+ separate_pano, fg_bboxes = self.hy3d_world._load_separate_pano_from_dir(
+ output_dir, sr=True
+ )
+
+ # layer-wise reconstruction
+ layered_world_mesh = self.hy3d_world.generate_world(
+ separate_pano=separate_pano, fg_bboxes=fg_bboxes, world_type='mesh'
+ )
+
+ # save results
+ for layer_idx, layer_info in enumerate(layered_world_mesh):
+ # export ply
+ output_path = os.path.join(
+ output_dir, f"mesh_layer{layer_idx}.ply"
+ )
+ o3d.io.write_triangle_mesh(output_path, layer_info['mesh'])
+
+ # export drc
+ if export_drc:
+ output_path_drc = os.path.join(
+ output_dir, f"mesh_layer{layer_idx}.drc"
+ )
+ process_file(output_path, output_path_drc)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Hunyuan3D World Gen Demo")
+ parser.add_argument("--image_path", type=str,
+ default=None, help="Path to the Panorama image")
+ parser.add_argument("--labels_fg1", nargs='+', default=[],
+ help="Labels for foreground objects in layer 1")
+ parser.add_argument("--labels_fg2", nargs='+', default=[],
+ help="Labels for foreground objects in layer 2")
+ parser.add_argument("--classes", type=str, default="outdoor",
+ help="Classes for sence generation")
+ parser.add_argument("--seed", type=int, default=42,
+ help="Random seed for reproducibility")
+ parser.add_argument("--output_path", type=str, default="results",
+ help="Path to save the output results")
+ parser.add_argument("--export_drc", type=bool, default=False,
+ help="Whether to export Draco format")
+
+ args = parser.parse_args()
+
+ os.makedirs(args.output_path, exist_ok=True)
+ print(f"Output will be saved to: {args.output_path}")
+
+ demo_HYworld = HYworldDemo(seed=args.seed)
+ demo_HYworld.run(
+ image_path=args.image_path,
+ labels_fg1=args.labels_fg1,
+ labels_fg2=args.labels_fg2,
+ classes=args.classes,
+ output_dir=args.output_path,
+ export_drc=args.export_drc
+ )
diff --git a/docker/HunyuanWorld.osx-cpu.yaml b/docker/HunyuanWorld.osx-cpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e48934d5f3ef595b266cfb162773b8e03c2b1c94
--- /dev/null
+++ b/docker/HunyuanWorld.osx-cpu.yaml
@@ -0,0 +1,142 @@
+name: hunyuan_world
+channels:
+ - pytorch
+ - conda-forge
+ - defaults
+dependencies:
+ - python=3.10
+ - pytorch
+ - torchvision
+ - torchaudio
+ - numpy
+ - pillow
+ - pyyaml
+ - requests
+ - ffmpeg
+ - networkx
+ - pip
+ - pip:
+ - absl-py==2.2.2
+ - accelerate==1.6.0
+ - addict==2.4.0
+ - aiohappyeyeballs==2.6.1
+ - aiohttp==3.11.16
+ - aiosignal==1.3.2
+ - albumentations==0.5.2
+ - antlr4-python3-runtime==4.8
+ - asttokens==3.0.0
+ - async-timeout==5.0.1
+ - attrs==25.3.0
+ - av==14.3.0
+ - braceexpand==0.1.7
+ - cloudpickle==3.1.1
+ - colorama==0.4.6
+ - coloredlogs==15.0.1
+ - contourpy==1.3.2
+ - cycler==0.12.1
+ - cython==3.0.11
+ - eva-decord==0.6.1
+ - diffdist==0.1
+ - diffusers==0.32.0
+ - easydict==1.9
+ - einops==0.4.1
+ - executing==2.2.0
+ - facexlib==0.3.0
+ - filterpy==1.4.5
+ - flatbuffers==25.2.10
+ - fonttools==4.57.0
+ - frozenlist==1.6.0
+ - fsspec==2025.3.2
+ - ftfy==6.1.1
+ - future==1.0.0
+ - gfpgan==1.3.8
+ - grpcio==1.71.0
+ - h5py==3.7.0
+ - huggingface-hub==0.30.2
+ - humanfriendly==10.0
+ - hydra-core==1.1.0
+ - icecream==2.1.2
+ - imageio==2.37.0
+ - imageio-ffmpeg==0.4.9
+ - imgaug==0.4.0
+ - importlib-metadata==8.6.1
+ - inflect==5.6.0
+ - joblib==1.4.2
+ - kiwisolver==1.4.8
+ - kornia==0.8.0
+ - kornia-rs==0.1.8
+ - lazy-loader==0.4
+ - lightning-utilities==0.14.3
+ - llvmlite==0.44.0
+ - lmdb==1.6.2
+ - loguru==0.7.3
+ - markdown==3.8
+ - markdown-it-py==3.0.0
+ - matplotlib==3.10.1
+ - mdurl==0.1.2
+ - multidict==6.4.3
+ - natten==0.14.4
+ - numba==0.61.2
+ - omegaconf==2.1.2
+ - onnx==1.17.0
+ - onnxruntime==1.21.1
+ - open-clip-torch==2.30.0
+ - opencv-python==4.11.0.86
+ - opencv-python-headless==4.11.0.86
+ - packaging==24.2
+ - pandas==2.2.3
+ - peft==0.14.0
+ - platformdirs==4.3.7
+ - plyfile==1.1
+ - propcache==0.3.1
+ - protobuf==5.29.3
+ - psutil==7.0.0
+ - py-cpuinfo==9.0.0
+ - py360convert==1.0.3
+ - pygments==2.19.1
+ - pyparsing==3.2.3
+ - python-dateutil==2.9.0.post0
+ - pytorch-lightning==2.4.0
+ - pytz==2025.2
+ - qwen-vl-utils==0.0.8
+ - regex==2022.6.2
+ - rich==14.0.0
+ - safetensors==0.5.3
+ - scikit-image==0.24.0
+ - scikit-learn==1.6.1
+ - scipy==1.15.2
+ - seaborn==0.13.2
+ - segment-anything==1.0
+ - sentencepiece==0.2.0
+ - setuptools==59.5.0
+ - shapely==2.0.7
+ - six==1.17.0
+ - submitit==1.4.2
+ - sympy==1.13.1
+ - tabulate==0.9.0
+ - tb-nightly==2.20.0a20250421
+ - tensorboard-data-server==0.7.2
+ - termcolor==3.0.1
+ - threadpoolctl==3.6.0
+ - tifffile==2025.3.30
+ - timm==1.0.13
+ - tokenizers==0.21.1
+ - tomli==2.2.1
+ - torchmetrics==1.7.1
+ - tqdm==4.67.1
+ - transformers==4.51.0
+ - tzdata==2025.2
+ - ultralytics==8.3.74
+ - ultralytics-thop==2.0.14
+ - wcwidth==0.2.13
+ - webdataset==0.2.100
+ - werkzeug==3.1.3
+ - wldhx-yadisk-direct==0.0.6
+ - yapf==0.43.0
+ - yarl==1.20.0
+ - zipp==3.21.0
+ - open3d>=0.18.0
+ - trimesh>=4.6.1
+ - cmake
+ - pytorch3d @ git+https://github.com/facebookresearch/pytorch3d.git
+ - moge @ git+https://github.com/microsoft/MoGe.git
diff --git a/docker/HunyuanWorld.osx64.yaml b/docker/HunyuanWorld.osx64.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5df86df43044fa704e31ccf40c62ec22f4ccf6ac
--- /dev/null
+++ b/docker/HunyuanWorld.osx64.yaml
@@ -0,0 +1,247 @@
+name: HunyuanWorld
+channels:
+ - conda-forge
+ - pytorch
+ - nvidia
+ - defaults
+ - https://repo.anaconda.com/pkgs/main
+ - https://repo.anaconda.com/pkgs/r
+dependencies:
+ - _libgcc_mutex=0.1
+ - _openmp_mutex=5.1
+ - blas=1.0
+ - brotli-python=1.0.9
+ - bzip2=1.0.8
+ - ca-certificates=2025.2.25
+ - certifi=2025.1.31
+ - charset-normalizer=3.3.2
+ - cuda-cudart=12.4.127
+ - cuda-cupti=12.4.127
+ - cuda-libraries=12.4.1
+ - cuda-nvrtc=12.4.127
+ - cuda-nvtx=12.4.127
+ - cuda-opencl=12.8.90
+ - cuda-runtime=12.4.1
+ - cuda-version=12.8
+ - ffmpeg=4.3
+ - filelock=3.17.0
+ - freetype=2.13.3
+ - giflib=5.2.2
+ - gmp=6.3.0
+ - gmpy2=2.2.1
+ - gnutls=3.6.15
+ - idna=3.7
+ - intel-openmp=2023.1.0
+ - jinja2=3.1.6
+ - jpeg=9e
+ - lame=3.100
+ - lcms2=2.16
+ - ld_impl_linux-64=2.40
+ - lerc=4.0.0
+ - libcublas=12.4.5.8
+ - libcufft=11.2.1.3
+ - libcufile=1.13.1.3
+ - libcurand=10.3.9.90
+ - libcusolver=11.6.1.9
+ - libcusparse=12.3.1.170
+ - libdeflate=1.22
+ - libffi=3.4.4
+ - libgcc-ng=11.2.0
+ - libgomp=11.2.0
+ - libiconv=1.16
+ - libidn2=2.3.4
+ - libjpeg-turbo=2.0.0
+ - libnpp=12.2.5.30
+ - libnvfatbin=12.8.90
+ - libnvjitlink=12.4.127
+ - libnvjpeg=12.3.1.117
+ - libpng=1.6.39
+ - libstdcxx-ng=11.2.0
+ - libtasn1=4.19.0
+ - libtiff=4.7.0
+ - libunistring=0.9.10
+ - libuuid=1.41.5
+ - libwebp=1.3.2
+ - libwebp-base=1.3.2
+ - llvm-openmp=14.0.6
+ - lz4-c=1.9.4
+ - markupsafe=3.0.2
+ - mkl=2023.1.0
+ - mkl-service=2.4.0
+ - mkl_fft=1.3.11
+ - mkl_random=1.2.8
+ - mpc=1.3.1
+ - mpfr=4.2.1
+ - mpmath=1.3.0
+ - ncurses=6.4
+ - nettle=3.7.3
+ - networkx=3.4.2
+ - ocl-icd=2.3.2
+ - openh264=2.1.1
+ - openjpeg=2.5.2
+ - openssl=3.0.16
+ - pillow=11.1.0
+ - pip=25.0
+ - pysocks=1.7.1
+ - python=3.10.16
+ - pytorch=2.5.0
+ - pytorch-cuda=12.4
+ - pytorch-mutex=1.0
+ - pyyaml=6.0.2
+ - readline=8.2
+ - requests=2.32.3
+ - sqlite=3.45.3
+ - tbb=2021.8.0
+ - tk=8.6.14
+ - torchaudio=2.5.0
+ - torchvision=0.20.0
+ - typing_extensions=4.12.2
+ - urllib3=2.3.0
+ - wheel=0.45.1
+ - xz=5.6.4
+ - yaml=0.2.5
+ - zlib=1.2.13
+ - zstd=1.5.6
+ - pip:
+ - absl-py=
+ - accelerate=
+ - addict=
+ - aiohappyeyeballs=
+ - aiohttp=
+ - aiosignal=
+ - albumentations=
+ - antlr4-python3-runtime=
+ - asttokens=
+ - async-timeout=
+ - attrs=
+ - av=
+ - braceexpand=
+ - cloudpickle=
+ - colorama=
+ - coloredlogs=
+ - contourpy=
+ - cycler=
+ - cython=
+ - decord=
+ - diffdist=
+ - diffusers=
+ - easydict=
+ - einops=
+ - executing=
+ - facexlib=
+ - filterpy=
+ - flash-attn=
+ - flatbuffers=
+ - fonttools=
+ - frozenlist=
+ - fsspec=
+ - ftfy=
+ - future=
+ - gfpgan=
+ - grpcio=
+ - h5py=
+ - huggingface-hub=
+ - humanfriendly=
+ - hydra-core=
+ - icecream=
+ - imageio=
+ - imageio-ffmpeg=
+ - imgaug=
+ - importlib-metadata=
+ - inflect=
+ - joblib=
+ - kiwisolver=
+ - kornia=
+ - kornia-rs=
+ - lazy-loader=
+ - lightning-utilities=
+ - llvmlite=
+ - lmdb=
+ - loguru=
+ - markdown=
+ - markdown-it-py=
+ - matplotlib=
+ - mdurl=
+ - multidict=
+ - natten=
+ - numba=
+ - numpy=
+ - nvidia-cublas-cu12=
+ - nvidia-cuda-cupti-cu12=
+ - nvidia-cuda-nvrtc-cu12=
+ - nvidia-cuda-runtime-cu12=
+ - nvidia-cudnn-cu12=
+ - nvidia-cufft-cu12=
+ - nvidia-curand-cu12=
+ - nvidia-cusolver-cu12=
+ - nvidia-cusparse-cu12=
+ - nvidia-cusparselt-cu12=
+ - nvidia-nccl-cu12=
+ - nvidia-nvjitlink-cu12=
+ - nvidia-nvtx-cu12=
+ - omegaconf=
+ - onnx=
+ - onnxruntime-gpu=
+ - open-clip-torch=
+ - opencv-python=
+ - opencv-python-headless=
+ - packaging=
+ - pandas=
+ - peft=
+ - platformdirs=
+ - plyfile=
+ - propcache=
+ - protobuf=
+ - psutil=
+ - py-cpuinfo=
+ - py360convert=
+ - pygments=
+ - pyparsing=
+ - python-dateutil=
+ - pytorch-lightning=
+ - pytz=
+ - qwen-vl-utils=
+ - regex=
+ - rich=
+ - safetensors=
+ - scikit-image=
+ - scikit-learn=
+ - scipy=
+ - seaborn=
+ - segment-anything=
+ - sentencepiece=
+ - setuptools=
+ - shapely=
+ - six=
+ - submitit=
+ - sympy=
+ - tabulate=
+ - tb-nightly=
+ - tensorboard-data-server=
+ - termcolor=
+ - threadpoolctl=
+ - tifffile=
+ - timm=
+ - tokenizers=
+ - tomli=
+ - torchmetrics=
+ - tqdm=
+ - transformers=
+ - triton=
+ - tzdata=
+ - ultralytics=
+ - ultralytics-thop=
+ - wcwidth=
+ - webdataset=
+ - werkzeug=
+ - wldhx-yadisk-direct=
+ - xformers=
+ - yapf=
+ - yarl=
+ - zipp=
+ - open3d>=0.18.0
+ - trimesh>=4.6.1
+ - cmake
+ - pytorch3d @ git+https://github.com/facebookresearch/pytorch3d.git
+ - moge @ git+https://github.com/microsoft/MoGe.git
+prefix: /opt/conda/envs/HunyuanWorld
diff --git a/docker/HunyuanWorld.yaml b/docker/HunyuanWorld.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3fed04cab5e22ddb74febca6e2aa0366259936fa
--- /dev/null
+++ b/docker/HunyuanWorld.yaml
@@ -0,0 +1,246 @@
+name: HunyuanWorld
+channels:
+ - pytorch
+ - nvidia
+ - defaults
+ - https://repo.anaconda.com/pkgs/main
+ - https://repo.anaconda.com/pkgs/r
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=5.1=1_gnu
+ - blas=1.0=mkl
+ - brotli-python=1.0.9=py310h6a678d5_9
+ - bzip2=1.0.8=h5eee18b_6
+ - ca-certificates=2025.2.25=h06a4308_0
+ - certifi=2025.1.31=py310h06a4308_0
+ - charset-normalizer=3.3.2=pyhd3eb1b0_0
+ - cuda-cudart=12.4.127=0
+ - cuda-cupti=12.4.127=0
+ - cuda-libraries=12.4.1=0
+ - cuda-nvrtc=12.4.127=0
+ - cuda-nvtx=12.4.127=0
+ - cuda-opencl=12.8.90=0
+ - cuda-runtime=12.4.1=0
+ - cuda-version=12.8=3
+ - ffmpeg=4.3=hf484d3e_0
+ - filelock=3.17.0=py310h06a4308_0
+ - freetype=2.13.3=h4a9f257_0
+ - giflib=5.2.2=h5eee18b_0
+ - gmp=6.3.0=h6a678d5_0
+ - gmpy2=2.2.1=py310h5eee18b_0
+ - gnutls=3.6.15=he1e5248_0
+ - idna=3.7=py310h06a4308_0
+ - intel-openmp=2023.1.0=hdb19cb5_46306
+ - jinja2=3.1.6=py310h06a4308_0
+ - jpeg=9e=h5eee18b_3
+ - lame=3.100=h7b6447c_0
+ - lcms2=2.16=h92b89f2_1
+ - ld_impl_linux-64=2.40=h12ee557_0
+ - lerc=4.0.0=h6a678d5_0
+ - libcublas=12.4.5.8=0
+ - libcufft=11.2.1.3=0
+ - libcufile=1.13.1.3=0
+ - libcurand=10.3.9.90=0
+ - libcusolver=11.6.1.9=0
+ - libcusparse=12.3.1.170=0
+ - libdeflate=1.22=h5eee18b_0
+ - libffi=3.4.4=h6a678d5_1
+ - libgcc-ng=11.2.0=h1234567_1
+ - libgomp=11.2.0=h1234567_1
+ - libiconv=1.16=h5eee18b_3
+ - libidn2=2.3.4=h5eee18b_0
+ - libjpeg-turbo=2.0.0=h9bf148f_0
+ - libnpp=12.2.5.30=0
+ - libnvfatbin=12.8.90=0
+ - libnvjitlink=12.4.127=0
+ - libnvjpeg=12.3.1.117=0
+ - libpng=1.6.39=h5eee18b_0
+ - libstdcxx-ng=11.2.0=h1234567_1
+ - libtasn1=4.19.0=h5eee18b_0
+ - libtiff=4.7.0=hde9077f_0
+ - libunistring=0.9.10=h27cfd23_0
+ - libuuid=1.41.5=h5eee18b_0
+ - libwebp=1.3.2=h9f374a3_1
+ - libwebp-base=1.3.2=h5eee18b_1
+ - llvm-openmp=14.0.6=h9e868ea_0
+ - lz4-c=1.9.4=h6a678d5_1
+ - markupsafe=3.0.2=py310h5eee18b_0
+ - mkl=2023.1.0=h213fc3f_46344
+ - mkl-service=2.4.0=py310h5eee18b_2
+ - mkl_fft=1.3.11=py310h5eee18b_0
+ - mkl_random=1.2.8=py310h1128e8f_0
+ - mpc=1.3.1=h5eee18b_0
+ - mpfr=4.2.1=h5eee18b_0
+ - mpmath=1.3.0=py310h06a4308_0
+ - ncurses=6.4=h6a678d5_0
+ - nettle=3.7.3=hbbd107a_1
+ - networkx=3.4.2=py310h06a4308_0
+ - ocl-icd=2.3.2=h5eee18b_1
+ - openh264=2.1.1=h4ff587b_0
+ - openjpeg=2.5.2=h0d4d230_1
+ - openssl=3.0.16=h5eee18b_0
+ - pillow=11.1.0=py310hac6e08b_1
+ - pip=25.0=py310h06a4308_0
+ - pysocks=1.7.1=py310h06a4308_0
+ - python=3.10.16=he870216_1
+ - pytorch=2.5.0=py3.10_cuda12.4_cudnn9.1.0_0
+ - pytorch-cuda=12.4=hc786d27_7
+ - pytorch-mutex=1.0=cuda
+ - pyyaml=6.0.2=py310h5eee18b_0
+ - readline=8.2=h5eee18b_0
+ - requests=2.32.3=py310h06a4308_1
+ - sqlite=3.45.3=h5eee18b_0
+ - tbb=2021.8.0=hdb19cb5_0
+ - tk=8.6.14=h39e8969_0
+ - torchaudio=2.5.0=py310_cu124
+ - torchvision=0.20.0=py310_cu124
+ - typing_extensions=4.12.2=py310h06a4308_0
+ - urllib3=2.3.0=py310h06a4308_0
+ - wheel=0.45.1=py310h06a4308_0
+ - xz=5.6.4=h5eee18b_1
+ - yaml=0.2.5=h7b6447c_0
+ - zlib=1.2.13=h5eee18b_1
+ - zstd=1.5.6=hc292b87_0
+ - pip:
+ - absl-py==2.2.2
+ - accelerate==1.6.0
+ - addict==2.4.0
+ - aiohappyeyeballs==2.6.1
+ - aiohttp==3.11.16
+ - aiosignal==1.3.2
+ - albumentations==0.5.2
+ - antlr4-python3-runtime==4.8
+ - asttokens==3.0.0
+ - async-timeout==5.0.1
+ - attrs==25.3.0
+ - av==14.3.0
+ - braceexpand==0.1.7
+ - cloudpickle==3.1.1
+ - colorama==0.4.6
+ - coloredlogs==15.0.1
+ - contourpy==1.3.2
+ - cycler==0.12.1
+ - cython==3.0.11
+ - decord==0.6.0
+ - diffdist==0.1
+ - diffusers==0.32.0
+ - easydict==1.9
+ - einops==0.4.1
+ - executing==2.2.0
+ - facexlib==0.3.0
+ - filterpy==1.4.5
+ - flash-attn==2.7.4.post1
+ - flatbuffers==25.2.10
+ - fonttools==4.57.0
+ - frozenlist==1.6.0
+ - fsspec==2025.3.2
+ - ftfy==6.1.1
+ - future==1.0.0
+ - gfpgan==1.3.8
+ - grpcio==1.71.0
+ - h5py==3.7.0
+ - huggingface-hub==0.30.2
+ - humanfriendly==10.0
+ - hydra-core==1.1.0
+ - icecream==2.1.2
+ - imageio==2.37.0
+ - imageio-ffmpeg==0.4.9
+ - imgaug==0.4.0
+ - importlib-metadata==8.6.1
+ - inflect==5.6.0
+ - joblib==1.4.2
+ - kiwisolver==1.4.8
+ - kornia==0.8.0
+ - kornia-rs==0.1.8
+ - lazy-loader==0.4
+ - lightning-utilities==0.14.3
+ - llvmlite==0.44.0
+ - lmdb==1.6.2
+ - loguru==0.7.3
+ - markdown==3.8
+ - markdown-it-py==3.0.0
+ - matplotlib==3.10.1
+ - mdurl==0.1.2
+ - multidict==6.4.3
+ - natten==0.14.4
+ - numba==0.61.2
+ - numpy==1.24.1
+ - nvidia-cublas-cu12==12.4.5.8
+ - nvidia-cuda-cupti-cu12==12.4.127
+ - nvidia-cuda-nvrtc-cu12==12.4.127
+ - nvidia-cuda-runtime-cu12==12.4.127
+ - nvidia-cudnn-cu12==9.1.0.70
+ - nvidia-cufft-cu12==11.2.1.3
+ - nvidia-curand-cu12==10.3.5.147
+ - nvidia-cusolver-cu12==11.6.1.9
+ - nvidia-cusparse-cu12==12.3.1.170
+ - nvidia-cusparselt-cu12==0.6.2
+ - nvidia-nccl-cu12==2.21.5
+ - nvidia-nvjitlink-cu12==12.4.127
+ - nvidia-nvtx-cu12==12.4.127
+ - omegaconf==2.1.2
+ - onnx==1.17.0
+ - onnxruntime-gpu==1.21.1
+ - open-clip-torch==2.30.0
+ - opencv-python==4.11.0.86
+ - opencv-python-headless==4.11.0.86
+ - packaging==24.2
+ - pandas==2.2.3
+ - peft==0.14.0
+ - platformdirs==4.3.7
+ - plyfile==1.1
+ - propcache==0.3.1
+ - protobuf==5.29.3
+ - psutil==7.0.0
+ - py-cpuinfo==9.0.0
+ - py360convert==1.0.3
+ - pygments==2.19.1
+ - pyparsing==3.2.3
+ - python-dateutil==2.9.0.post0
+ - pytorch-lightning==2.4.0
+ - pytz==2025.2
+ - qwen-vl-utils==0.0.8
+ - regex==2022.6.2
+ - rich==14.0.0
+ - safetensors==0.5.3
+ - scikit-image==0.24.0
+ - scikit-learn==1.6.1
+ - scipy==1.15.2
+ - seaborn==0.13.2
+ - segment-anything==1.0
+ - sentencepiece==0.2.0
+ - setuptools==59.5.0
+ - shapely==2.0.7
+ - six==1.17.0
+ - submitit==1.4.2
+ - sympy==1.13.1
+ - tabulate==0.9.0
+ - tb-nightly==2.20.0a20250421
+ - tensorboard-data-server==0.7.2
+ - termcolor==3.0.1
+ - threadpoolctl==3.6.0
+ - tifffile==2025.3.30
+ - timm==1.0.13
+ - tokenizers==0.21.1
+ - tomli==2.2.1
+ - torchmetrics==1.7.1
+ - tqdm==4.67.1
+ - transformers==4.51.0
+ - triton==3.2.0
+ - tzdata==2025.2
+ - ultralytics==8.3.74
+ - ultralytics-thop==2.0.14
+ - wcwidth==0.2.13
+ - webdataset==0.2.100
+ - werkzeug==3.1.3
+ - wldhx-yadisk-direct==0.0.6
+ - xformers==0.0.28.post2
+ - yapf==0.43.0
+ - yarl==1.20.0
+ - zipp==3.21.0
+ - open3d>=0.18.0
+ - trimesh>=4.6.1
+ - cmake
+ - pytorch3d @ git+https://github.com/facebookresearch/pytorch3d.git
+ - moge @ git+https://github.com/microsoft/MoGe.git
+prefix: /opt/conda/envs/HunyuanWorld
diff --git a/docker/HunyuanWorld_mac.yaml b/docker/HunyuanWorld_mac.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f449f9751bf624af790ed0848a416c4f6fd6d318
--- /dev/null
+++ b/docker/HunyuanWorld_mac.yaml
@@ -0,0 +1,186 @@
+name: HunyuanWorld-mac
+channels:
+ - pytorch
+ - conda-forge
+ - defaults
+dependencies:
+ - python=3.10
+ - pytorch
+ - torchvision
+ - torchaudio
+ - ffmpeg
+ - filelock
+ - freetype
+ - gmp
+ - gmpy2
+ - gnutls
+ - idna
+ - jinja2
+ - jpeg
+ - lame
+ - lcms2
+ - lerc
+ - libdeflate
+ - libffi
+ - libiconv
+ - libidn2
+ - libpng
+ - libtasn1
+ - libtiff
+ - libunistring
+ - libuuid
+ - libwebp
+ - llvm-openmp
+ - lz4-c
+ - markupsafe
+ - mpc
+ - mpfr
+ - mpmath
+ - ncurses
+ - nettle
+ - networkx
+ - openh264
+ - openjpeg
+ - openssl
+ - pillow
+ - pip
+ - pysocks
+ - pyyaml
+ - readline
+ - requests
+ - sqlite
+ - tbb
+ - tk
+ - typing_extensions
+ - urllib3
+ - wheel
+ - xz
+ - yaml
+ - zlib
+ - zstd
+ - pip:
+ - absl-py==2.2.2
+ - accelerate==1.6.0
+ - addict==2.4.0
+ - aiohappyeyeballs==2.6.1
+ - aiohttp==3.11.16
+ - aiosignal==1.3.2
+ - albumentations==0.5.2
+ - antlr4-python3-runtime==4.8
+ - asttokens==3.0.0
+ - async-timeout==5.0.1
+ - attrs==25.3.0
+ - av==14.3.0
+ - braceexpand==0.1.7
+ - cloudpickle==3.1.1
+ - colorama==0.4.6
+ - coloredlogs==15.0.1
+ - contourpy==1.3.2
+ - cycler==0.12.1
+ - cython==3.0.11
+ - decord==0.6.0
+ - diffdist==0.1
+ - diffusers==0.32.0
+ - easydict==1.9
+ - einops==0.4.1
+ - executing==2.2.0
+ - facexlib==0.3.0
+ - filterpy==1.4.5
+ - flatbuffers==25.2.10
+ - fonttools==4.57.0
+ - frozenlist==1.6.0
+ - fsspec==2025.3.2
+ - ftfy==6.1.1
+ - future==1.0.0
+ - gfpgan==1.3.8
+ - grpcio==1.71.0
+ - h5py==3.7.0
+ - huggingface-hub==0.30.2
+ - humanfriendly==10.0
+ - hydra-core==1.1.0
+ - icecream==2.1.2
+ - imageio==2.37.0
+ - imageio-ffmpeg==0.4.9
+ - imgaug==0.4.0
+ - importlib-metadata==8.6.1
+ - inflect==5.6.0
+ - joblib==1.4.2
+ - kiwisolver==1.4.8
+ - kornia==0.8.0
+ - kornia-rs==0.1.8
+ - lazy-loader==0.4
+ - lightning-utilities==0.14.3
+ - llvmlite==0.44.0
+ - lmdb==1.6.2
+ - loguru==0.7.3
+ - markdown==3.8
+ - markdown-it-py==3.0.0
+ - matplotlib==3.10.1
+ - mdurl==0.1.2
+ - multidict==6.4.3
+ - natten==0.14.4
+ - numba==0.61.2
+ - numpy==1.24.1
+ - omegaconf==2.1.2
+ - onnx==1.17.0
+ - onnxruntime
+ - open-clip-torch==2.30.0
+ - opencv-python==4.11.0.86
+ - opencv-python-headless==4.11.0.86
+ - packaging==24.2
+ - pandas==2.2.3
+ - peft==0.14.0
+ - platformdirs==4.3.7
+ - plyfile==1.1
+ - propcache==0.3.1
+ - protobuf==5.29.3
+ - psutil==7.0.0
+ - py-cpuinfo==9.0.0
+ - py360convert==1.0.3
+ - pygments==2.19.1
+ - pyparsing==3.2.3
+ - python-dateutil==2.9.0.post0
+ - pytorch-lightning==2.4.0
+ - pytz==2025.2
+ - qwen-vl-utils==0.0.8
+ - regex==2022.6.2
+ - rich==14.0.0
+ - safetensors==0.5.3
+ - scikit-image==0.24.0
+ - scikit-learn==1.6.1
+ - scipy==1.15.2
+ - seaborn==0.13.2
+ - segment-anything==1.0
+ - sentencepiece==0.2.0
+ - setuptools==59.5.0
+ - shapely==2.0.7
+ - six==1.17.0
+ - submitit==1.4.2
+ - sympy==1.13.1
+ - tabulate==0.9.0
+ - tb-nightly==2.20.0a20250421
+ - tensorboard-data-server==0.7.2
+ - termcolor==3.0.1
+ - threadpoolctl==3.6.0
+ - tifffile==2025.3.30
+ - timm==1.0.13
+ - tokenizers==0.21.1
+ - tomli==2.2.1
+ - torchmetrics==1.7.1
+ - tqdm==4.67.1
+ - transformers==4.51.0
+ - tzdata==2025.2
+ - ultralytics==8.3.74
+ - ultralytics-thop==2.0.14
+ - wcwidth==0.2.13
+ - webdataset==0.2.100
+ - werkzeug==3.1.3
+ - wldhx-yadisk-direct==0.0.6
+ - yapf==0.43.0
+ - yarl==1.20.0
+ - zipp==3.21.0
+ - open3d>=0.18.0
+ - trimesh>=4.6.1
+ - cmake
+ - pytorch3d @ git+https://github.com/facebookresearch/pytorch3d.git
+ - moge @ git+https://github.com/microsoft/MoGe.git
diff --git a/examples/case1/classes.txt b/examples/case1/classes.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4c81d0a206ab9d0722c6533c05c0c9a6583401e5
--- /dev/null
+++ b/examples/case1/classes.txt
@@ -0,0 +1 @@
+outdoor
\ No newline at end of file
diff --git a/examples/case1/input.png b/examples/case1/input.png
new file mode 100644
index 0000000000000000000000000000000000000000..601ac9eb857f07336f55e608c888f8408357576c
--- /dev/null
+++ b/examples/case1/input.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bb701fd458a87beccdd821f557994db64ff8eba7f78f426cb350ed70bbf83f14
+size 3839861
diff --git a/examples/case2/classes.txt b/examples/case2/classes.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4c81d0a206ab9d0722c6533c05c0c9a6583401e5
--- /dev/null
+++ b/examples/case2/classes.txt
@@ -0,0 +1 @@
+outdoor
\ No newline at end of file
diff --git a/examples/case2/input.png b/examples/case2/input.png
new file mode 100644
index 0000000000000000000000000000000000000000..ed616e44da2a1fb686ffae705e5d5e6c49642431
--- /dev/null
+++ b/examples/case2/input.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd9115e2db03b1232b4727555a70445e96bafc7a84b2347950156da704c90edb
+size 2681892
diff --git a/examples/case2/labels_fg1.txt b/examples/case2/labels_fg1.txt
new file mode 100644
index 0000000000000000000000000000000000000000..929ba703d84aac8585321104e120b09d517e9c8b
--- /dev/null
+++ b/examples/case2/labels_fg1.txt
@@ -0,0 +1 @@
+stones
\ No newline at end of file
diff --git a/examples/case2/labels_fg2.txt b/examples/case2/labels_fg2.txt
new file mode 100644
index 0000000000000000000000000000000000000000..81151179ac41e7c604390c42b70041566f6f19a6
--- /dev/null
+++ b/examples/case2/labels_fg2.txt
@@ -0,0 +1 @@
+trees
\ No newline at end of file
diff --git a/examples/case3/classes.txt b/examples/case3/classes.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4c81d0a206ab9d0722c6533c05c0c9a6583401e5
--- /dev/null
+++ b/examples/case3/classes.txt
@@ -0,0 +1 @@
+outdoor
\ No newline at end of file
diff --git a/examples/case3/input.png b/examples/case3/input.png
new file mode 100644
index 0000000000000000000000000000000000000000..5386b2ba56424d201f4152fb38513b9559b25ddb
--- /dev/null
+++ b/examples/case3/input.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:55e19cd8173aef7f544dc2d487f0878ab6cffb8faf14121abe085bcd1ecbc888
+size 3322477
diff --git a/examples/case4/classes.txt b/examples/case4/classes.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4c81d0a206ab9d0722c6533c05c0c9a6583401e5
--- /dev/null
+++ b/examples/case4/classes.txt
@@ -0,0 +1 @@
+outdoor
\ No newline at end of file
diff --git a/examples/case4/prompt.txt b/examples/case4/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cae7ac977d187db0061233ca367d2ef61c2786ab
--- /dev/null
+++ b/examples/case4/prompt.txt
@@ -0,0 +1 @@
+There is a rocky island on the vast sea surface, with a triangular rock burning red flames in the center of the island. The sea is open and rough, with a green surface. Surrounded by towering peaks in the distance.
\ No newline at end of file
diff --git a/examples/case5/classes.txt b/examples/case5/classes.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4c81d0a206ab9d0722c6533c05c0c9a6583401e5
--- /dev/null
+++ b/examples/case5/classes.txt
@@ -0,0 +1 @@
+outdoor
\ No newline at end of file
diff --git a/examples/case5/input.png b/examples/case5/input.png
new file mode 100644
index 0000000000000000000000000000000000000000..e23c388791e2c3a401d17e46a88f5c0ddefcd64a
--- /dev/null
+++ b/examples/case5/input.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3fe08152fbb72b8845348564bd59fd43515d1291918a0f665cd2a4cca479344f
+size 2949636
diff --git a/examples/case6/classes.txt b/examples/case6/classes.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4c81d0a206ab9d0722c6533c05c0c9a6583401e5
--- /dev/null
+++ b/examples/case6/classes.txt
@@ -0,0 +1 @@
+outdoor
\ No newline at end of file
diff --git a/examples/case6/input.png b/examples/case6/input.png
new file mode 100644
index 0000000000000000000000000000000000000000..cc4067e6dfc8384a5dc6de35f29bd18bddf55cb0
--- /dev/null
+++ b/examples/case6/input.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2ace3cd5c3b3b5a8d3e5d3bfc809bd20487895bffdf739389efc0c520b219f7
+size 1548695
diff --git a/examples/case6/labels_fg1.txt b/examples/case6/labels_fg1.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ce0bd9112fc0574f1656e54b67f17fdaf8a7e9a6
--- /dev/null
+++ b/examples/case6/labels_fg1.txt
@@ -0,0 +1 @@
+tent
\ No newline at end of file
diff --git a/examples/case7/classes.txt b/examples/case7/classes.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4c81d0a206ab9d0722c6533c05c0c9a6583401e5
--- /dev/null
+++ b/examples/case7/classes.txt
@@ -0,0 +1 @@
+outdoor
\ No newline at end of file
diff --git a/examples/case7/prompt.txt b/examples/case7/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..23ea355d101bea481f7ef9f10687b5b5a38bffc5
--- /dev/null
+++ b/examples/case7/prompt.txt
@@ -0,0 +1 @@
+At the moment of glacier collapse, giant ice walls collapse and create waves, with no wildlife, captured in a disaster documentary
\ No newline at end of file
diff --git a/examples/case8/classes.txt b/examples/case8/classes.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4c81d0a206ab9d0722c6533c05c0c9a6583401e5
--- /dev/null
+++ b/examples/case8/classes.txt
@@ -0,0 +1 @@
+outdoor
\ No newline at end of file
diff --git a/examples/case8/input.png b/examples/case8/input.png
new file mode 100644
index 0000000000000000000000000000000000000000..b9562a227a40b54f9376c890eadc5534eee9828a
--- /dev/null
+++ b/examples/case8/input.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae6e9af5cf30d64bb0d0e19e6c7e4993126e8ca74191c4442948f13d2ceea755
+size 2028970
diff --git a/examples/case9/classes.txt b/examples/case9/classes.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4c81d0a206ab9d0722c6533c05c0c9a6583401e5
--- /dev/null
+++ b/examples/case9/classes.txt
@@ -0,0 +1 @@
+outdoor
\ No newline at end of file
diff --git a/examples/case9/prompt.txt b/examples/case9/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4eb700eaf7f8b4037e39c4bb685f363da1e33e90
--- /dev/null
+++ b/examples/case9/prompt.txt
@@ -0,0 +1 @@
+A breathtaking volcanic eruption scene. In the center of the screen, one or more volcanoes are erupting violently, with hot orange red lava gushing out from the crater, illuminating the surrounding night sky and landscape. Thick smoke and volcanic ash rose into the sky, forming a huge mushroom cloud like structure. Some of the smoke and dust were reflected in a dark red color by the high temperature of the lava, creating a doomsday atmosphere. In the foreground, a winding lava flow flows through the dark and rough rocks like a fire snake, emitting a dazzling light as if burning the earth. The steep and rugged mountains in the background further emphasize the ferocity and irresistible power of nature. The entire picture has a strong contrast of light and shadow, with red, black, and gray as the main colors, highlighting the visual impact and dramatic tension of volcanic eruptions, making people feel the grandeur and terror of nature.
\ No newline at end of file
diff --git a/hy3dworld/__init__.py b/hy3dworld/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3242b0983430fd350c32b4a8e908da639dfc6424
--- /dev/null
+++ b/hy3dworld/__init__.py
@@ -0,0 +1,22 @@
+# Tencent HunyuanWorld-1.0 is licensed under TENCENT HUNYUANWORLD-1.0 COMMUNITY LICENSE AGREEMENT
+# THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND
+# IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
+# By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying
+# any portion or element of the Tencent HunyuanWorld-1.0 Works, including via any Hosted Service,
+# You will be deemed to have recognized and accepted the content of this Agreement,
+# which is effective immediately.
+
+# For avoidance of doubts, Tencent HunyuanWorld-1.0 means the 3D generation models
+# and their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent at [https://github.com/Tencent-Hunyuan/HunyuanWorld-1.0].
+# Image to Panorama
+from .models import Image2PanoramaPipelines
+from .utils import Perspective
+# Text to Panorama
+from .models import Text2PanoramaPipelines
+# Scene Generation
+from .models import LayerDecomposition
+from .models import WorldComposer
+from .utils import process_file
diff --git a/hy3dworld/models/__init__.py b/hy3dworld/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3888fa2af00d2404df3661fb14848ad2e0e6c64
--- /dev/null
+++ b/hy3dworld/models/__init__.py
@@ -0,0 +1,29 @@
+# Tencent HunyuanWorld-1.0 is licensed under TENCENT HUNYUANWORLD-1.0 COMMUNITY LICENSE AGREEMENT
+# THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND
+# IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
+# By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying
+# any portion or element of the Tencent HunyuanWorld-1.0 Works, including via any Hosted Service,
+# You will be deemed to have recognized and accepted the content of this Agreement,
+# which is effective immediately.
+
+# For avoidance of doubts, Tencent HunyuanWorld-1.0 means the 3D generation models
+# and their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent at [https://github.com/Tencent-Hunyuan/HunyuanWorld-1.0].
+
+# Image to Panorama
+from .pano_generator import Image2PanoramaPipelines
+# Text to Panorama
+from .pano_generator import Text2PanoramaPipelines
+
+# Scene Generation
+from .pipelines import FluxPipeline, FluxFillPipeline
+from .layer_decomposer import LayerDecomposition
+from .world_composer import WorldComposer
+
+__all__ = [
+ "Image2PanoramaPipelines", "Text2PanoramaPipelines",
+ "FluxPipeline", "FluxFillPipeline",
+ "LayerDecomposition", "WorldComposer",
+]
diff --git a/hy3dworld/models/adaptive_depth_compression.py b/hy3dworld/models/adaptive_depth_compression.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ef63304d5de68fb78233051fcc934517203d392
--- /dev/null
+++ b/hy3dworld/models/adaptive_depth_compression.py
@@ -0,0 +1,474 @@
+# Tencent HunyuanWorld-1.0 is licensed under TENCENT HUNYUANWORLD-1.0 COMMUNITY LICENSE AGREEMENT
+# THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND
+# IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
+# By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying
+# any portion or element of the Tencent HunyuanWorld-1.0 Works, including via any Hosted Service,
+# You will be deemed to have recognized and accepted the content of this Agreement,
+# which is effective immediately.
+
+# For avoidance of doubts, Tencent HunyuanWorld-1.0 means the 3D generation models
+# and their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent at [https://github.com/Tencent-Hunyuan/HunyuanWorld-1.0].
+import torch
+from typing import List, Dict, Tuple
+
+
+class AdaptiveDepthCompressor:
+ r"""
+ Adaptive depth compressor to solve the problem of excessive background depth variance
+ in 3D world generation. This class provides methods to compress background and foreground
+ depth values based on statistical analysis of depth distributions, with options for
+ smooth compression and outlier removal.
+ Args:
+ cv_thresholds: Tuple of (low, high) thresholds for coefficient of variation (CV).
+        compression_quantiles: Depth quantiles used by the low-, medium-, and high-variance compression strategies.
+ fg_bg_depth_margin: Margin factor to ensure foreground depth is greater than background.
+ enable_smooth_compression: Whether to use smooth compression instead of hard truncation.
+ outlier_removal_method: Method for outlier removal, options are "iqr", "quantile", or "none".
+ min_compression_depth: Minimum depth threshold for compression to be applied.
+ """
+
+ def __init__(
+ self,
+ cv_thresholds: Tuple[float, float] = (0.3, 0.8),
+ compression_quantiles: Tuple[float, float, float] = (0.95, 0.92, 0.85),
+ fg_bg_depth_margin: float = 1.1,
+ enable_smooth_compression: bool = True,
+ outlier_removal_method: str = "iqr",
+ min_compression_depth: float = 6.0,
+ ):
+ self.cv_thresholds = cv_thresholds
+ self.compression_quantiles = compression_quantiles
+ self.fg_bg_depth_margin = fg_bg_depth_margin
+ self.enable_smooth_compression = enable_smooth_compression
+ self.outlier_removal_method = outlier_removal_method
+ self.min_compression_depth = min_compression_depth
+
+ def _remove_outliers(self, depth_vals: torch.Tensor) -> torch.Tensor:
+ r"""
+ Remove outliers from depth values
+ based on the specified method (IQR or quantile).
+ Args:
+ depth_vals: Tensor of depth values to process.
+ Returns:
+ Tensor of depth values with outliers removed.
+ """
+ if self.outlier_removal_method == "iqr":
+ q25, q75 = torch.quantile(depth_vals, torch.tensor(
+ [0.25, 0.75], device=depth_vals.device))
+ iqr = q75 - q25
+ lower_bound, upper_bound = q25 - 1.5 * iqr, q75 + 1.5 * iqr
+ valid_mask = (depth_vals >= lower_bound) & (
+ depth_vals <= upper_bound)
+ elif self.outlier_removal_method == "quantile":
+ q05, q95 = torch.quantile(depth_vals, torch.tensor(
+ [0.05, 0.95], device=depth_vals.device))
+ valid_mask = (depth_vals >= q05) & (depth_vals <= q95)
+ else:
+ return depth_vals
+ return depth_vals[valid_mask] if valid_mask.sum() > 0 else depth_vals
+
+ def _collect_foreground_depths(
+ self,
+ layered_world_depth: List[Dict]
+ ) -> List[torch.Tensor]:
+ r"""
+ Collect depth information of all foreground layers (remove outliers)
+ from the layered world depth representation.
+ Args:
+ layered_world_depth: List of dictionaries containing depth information for each layer.
+ Returns:
+ List of tensors containing cleaned foreground depth values.
+ """
+ fg_depths = []
+ for layer_depth in layered_world_depth:
+ if layer_depth["name"] == "background":
+ continue
+
+ depth_vals = layer_depth["distance"]
+ mask = layer_depth.get("mask", None)
+
+ # Process the depth values within the mask area
+ if mask is not None:
+ if not isinstance(mask, torch.Tensor):
+ mask = torch.from_numpy(mask).to(depth_vals.device)
+ depth_vals = depth_vals[mask.bool()]
+
+ if depth_vals.numel() > 0:
+ cleaned_depths = self._remove_outliers(depth_vals)
+ if len(cleaned_depths) > 0:
+ fg_depths.append(cleaned_depths)
+ return fg_depths
+
+ def _get_pixelwise_foreground_max_depth(
+ self,
+ layered_world_depth: List[Dict],
+ bg_shape: torch.Size,
+ bg_device: torch.device
+ ) -> torch.Tensor:
+ r"""
+ Calculate the maximum foreground depth for each pixel position
+ Args:
+ layered_world_depth: List of dictionaries containing depth information for each layer.
+ bg_shape: Shape of the background depth tensor.
+ bg_device: Device where the background depth tensor is located.
+ Returns:
+ Tensor of maximum foreground depth values for each pixel position.
+ """
+ fg_max_depth = torch.zeros(bg_shape, device=bg_device)
+
+ for layer_depth in layered_world_depth:
+ if layer_depth["name"] == "background":
+ continue
+
+ layer_distance = layer_depth["distance"]
+ layer_mask = layer_depth.get("mask", None)
+
+ # Ensure that the tensor is on the correct device
+ if not isinstance(layer_distance, torch.Tensor):
+ layer_distance = torch.from_numpy(layer_distance).to(bg_device)
+ else:
+ layer_distance = layer_distance.to(bg_device)
+
+ # Update the maximum depth of the foreground
+ if layer_mask is not None:
+ if not isinstance(layer_mask, torch.Tensor):
+ layer_mask = torch.from_numpy(layer_mask).to(bg_device)
+ else:
+ layer_mask = layer_mask.to(bg_device)
+ fg_max_depth = torch.where(layer_mask.bool(), torch.max(
+ fg_max_depth, layer_distance), fg_max_depth)
+ else:
+ fg_max_depth = torch.max(fg_max_depth, layer_distance)
+
+ return fg_max_depth
+
+ def _analyze_depth_distribution(self, bg_depth_distance: torch.Tensor) -> Dict:
+ r"""
+ Analyze the distribution characteristics of background depth
+ Args:
+ bg_depth_distance: Tensor of background depth distances.
+ Returns:
+ Dictionary containing statistical properties of the background depth distribution.
+ """
+ bg_mean, bg_std = torch.mean(
+ bg_depth_distance), torch.std(bg_depth_distance)
+ cv = bg_std / bg_mean
+
+ quantiles = torch.quantile(bg_depth_distance,
+ torch.tensor([0.5, 0.75, 0.9, 0.95, 0.99], device=bg_depth_distance.device))
+ bg_q50, bg_q75, bg_q90, bg_q95, bg_q99 = quantiles
+
+ return {"mean": bg_mean, "std": bg_std, "cv": cv, "q50": bg_q50,
+ "q75": bg_q75, "q90": bg_q90, "q95": bg_q95, "q99": bg_q99}
+
+ def _determine_compression_strategy(self, cv: float) -> Tuple[str, float]:
+ r"""
+ Determine compression strategy based on coefficient of variation
+ Args:
+ cv: Coefficient of variation of the background depth distribution.
+ Returns:
+ Tuple containing the compression strategy ("gentle", "standard", "aggressive")
+ and the quantile to use for compression.
+ """
+ low_cv_threshold, high_cv_threshold = self.cv_thresholds
+ low_var_quantile, medium_var_quantile, high_var_quantile = self.compression_quantiles
+
+ if cv < low_cv_threshold:
+ return "gentle", low_var_quantile
+ elif cv > high_cv_threshold:
+ return "aggressive", high_var_quantile
+ else:
+ return "standard", medium_var_quantile
+
+ def _smooth_compression(self, depth_values: torch.Tensor, max_depth: torch.Tensor,
+ mask: torch.Tensor = None, transition_start_ratio: float = 0.95,
+ transition_range_ratio: float = 0.2, verbose: bool = False) -> torch.Tensor:
+ r"""
+ Use smooth compression function instead of hard truncation
+ Args:
+ depth_values: Tensor of depth values to compress.
+ max_depth: Maximum depth value for compression.
+ mask: Optional mask to apply compression only to certain pixels.
+ transition_start_ratio: Ratio to determine the start of the transition range.
+ transition_range_ratio: Ratio to determine the range of the transition.
+ verbose: Whether to print detailed information about the compression process.
+ Returns:
+ Compressed depth values as a tensor.
+ """
+ if not self.enable_smooth_compression:
+ compressed = depth_values.clone()
+ if mask is not None:
+ compressed[mask] = torch.clamp(
+ depth_values[mask], max=max_depth)
+ else:
+ compressed = torch.clamp(depth_values, max=max_depth)
+ return compressed
+
+ transition_start = max_depth * transition_start_ratio
+ transition_range = max_depth * transition_range_ratio
+ compressed_depth = depth_values.clone()
+
+ mask_far = depth_values > transition_start
+ if mask is not None:
+ mask_far = mask_far & mask
+
+ if mask_far.sum() > 0:
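+            # Remap depths beyond transition_start through a sigmoid so they approach
+            # max_depth smoothly instead of being hard-clamped at the boundary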
+ far_depths = depth_values[mask_far]
+ normalized = (far_depths - transition_start) / transition_range
+ compressed_normalized = torch.sigmoid(
+ normalized * 2 - 1) * 0.5 + 0.5
+ compressed_far = transition_start + \
+ compressed_normalized * (max_depth - transition_start)
+ compressed_depth[mask_far] = compressed_far
+ if verbose:
+ print(
+ f"\t Applied smooth compression to {mask_far.sum()} pixels beyond {transition_start:.2f}")
+ elif verbose:
+ print(f"\t No compression needed, all depths within reasonable range")
+
+ return compressed_depth
+
+ def compress_background_depth(self, bg_depth_distance: torch.Tensor, layered_world_depth: List[Dict],
+ bg_mask: torch.Tensor, verbose: bool = False) -> torch.Tensor:
+ r"""
+ Adaptive compression of background depth values
+ Args:
+ bg_depth_distance: Tensor of background depth distances.
+ layered_world_depth: List of dictionaries containing depth information for each layer.
+ bg_mask: Tensor or numpy array representing the mask for background depth.
+ verbose: Whether to print detailed information about the compression process.
+ Returns:
+ Compressed background depth values as a tensor.
+ """
+ if verbose:
+ print(f"\t - Applying adaptive depth compression...")
+
+ # Process mask
+ if not isinstance(bg_mask, torch.Tensor):
+ bg_mask = torch.from_numpy(bg_mask).to(bg_depth_distance.device)
+ mask_bool = bg_mask.bool()
+ masked_depths = bg_depth_distance[mask_bool]
+
+ if masked_depths.numel() == 0:
+ if verbose:
+ print(f"\t No valid depths in mask region, skipping compression")
+ return bg_depth_distance
+
+        # 1. Collect foreground depth information
+ fg_depths = self._collect_foreground_depths(layered_world_depth)
+
+        # 2. Compute foreground depth statistics
+ if fg_depths:
+ all_fg_depths = torch.cat(fg_depths)
+ fg_max = torch.quantile(all_fg_depths, torch.tensor(
+ 0.99, device=all_fg_depths.device))
+ if verbose:
+ print(
+ f"\t Foreground depth stats - 99th percentile: {fg_max:.2f}")
+ else:
+ fg_max = torch.quantile(masked_depths, torch.tensor(
+ 0.5, device=masked_depths.device))
+ if verbose:
+ print(f"\t No foreground found, using background stats for reference")
+
+ # 3. Analyze the depth distribution of the background
+ depth_stats = self._analyze_depth_distribution(masked_depths)
+ if verbose:
+ print(
+ f"\t Background depth stats - mean: {depth_stats['mean']:.2f}, \
+ std: {depth_stats['std']:.2f}, CV: {depth_stats['cv']:.3f}")
+
+ # 4. Determine compression strategy
+ strategy, compression_quantile = self._determine_compression_strategy(
+ depth_stats['cv'])
+ max_depth = torch.quantile(masked_depths, torch.tensor(
+ compression_quantile, device=masked_depths.device))
+
+ if verbose:
+ print(
+ f"\t {strategy.capitalize()} compression strategy \
+ (CV={depth_stats['cv']:.3f}), quantile={compression_quantile}")
+
+ # 5. Pixel level depth constraint
+ if fg_depths:
+ fg_max_depth_pixelwise = self._get_pixelwise_foreground_max_depth(
+ layered_world_depth, bg_depth_distance.shape, bg_depth_distance.device)
+ required_min_bg_depth = fg_max_depth_pixelwise * self.fg_bg_depth_margin
+ pixelwise_violations = (
+ bg_depth_distance < required_min_bg_depth) & mask_bool
+
+ if pixelwise_violations.sum() > 0:
+ violation_ratio = pixelwise_violations.float().sum() / mask_bool.float().sum()
+ violated_required_depths = required_min_bg_depth[pixelwise_violations]
+ pixelwise_min_depth = torch.quantile(violated_required_depths, torch.tensor(
+ 0.99, device=violated_required_depths.device))
+ max_depth = torch.max(max_depth, pixelwise_min_depth)
+ if verbose:
+ print(
+ f"\t Pixelwise constraint violation: {violation_ratio:.1%}, \
+ adjusted max depth to {max_depth:.2f}")
+ elif verbose:
+ print(f"\t Pixelwise depth constraints satisfied")
+
+ # 6. Global statistical constraints
+ if fg_depths:
+ min_bg_depth = fg_max * self.fg_bg_depth_margin
+ max_depth = torch.max(max_depth, min_bg_depth)
+ if verbose:
+ print(f"\t Final max depth: {max_depth:.2f}")
+
+ # 6.5. Depth threshold check: If max_depth is less than the threshold, skip compression
+ if max_depth < self.min_compression_depth:
+ if verbose:
+ print(
+ f"\t Max depth {max_depth:.2f} is below threshold \
+ {self.min_compression_depth:.2f}, skipping compression")
+ return bg_depth_distance
+
+        # 7. Apply smooth compression
+ compressed_depth = self._smooth_compression(
+ bg_depth_distance, max_depth, mask_bool, 0.9, 0.2, verbose)
+
+ # 8. Hard truncation of extreme outliers
+ final_max = max_depth * 1.2
+ outliers = (compressed_depth > final_max) & mask_bool
+ if outliers.sum() > 0:
+ compressed_depth[outliers] = final_max
+
+        # 9. Report compression statistics
+ compression_ratio = ((bg_depth_distance > max_depth)
+ & mask_bool).float().sum() / mask_bool.float().sum()
+ if verbose:
+ print(
+ f"\t Compression summary - max depth: \
+ {max_depth:.2f}, affected: {compression_ratio:.1%}")
+
+ return compressed_depth
+
+ def compress_foreground_depth(
+ self,
+ fg_depth_distance: torch.Tensor,
+ fg_mask: torch.Tensor,
+ verbose: bool = False,
+ conservative_ratio: float = 0.99,
+ iqr_scale: float = 2
+ ) -> torch.Tensor:
+ r"""
+ Conservatively compress outliers for foreground depth
+ Args:
+ fg_depth_distance: Tensor of foreground depth distances.
+ fg_mask: Tensor or numpy array representing the mask for foreground depth.
+ verbose: Whether to print detailed information about the compression process.
+ conservative_ratio: Ratio to use for conservative compression.
+ iqr_scale: Scale factor for IQR-based upper bound.
+ Returns:
+ Compressed foreground depth values as a tensor.
+ """
+ if verbose:
+ print(f"\t - Applying conservative foreground depth compression...")
+
+ # Process mask
+ if not isinstance(fg_mask, torch.Tensor):
+ fg_mask = torch.from_numpy(fg_mask).to(fg_depth_distance.device)
+ mask_bool = fg_mask.bool()
+ masked_depths = fg_depth_distance[mask_bool]
+
+ if masked_depths.numel() == 0:
+ if verbose:
+ print(f"\t No valid depths in mask region, skipping compression")
+ return fg_depth_distance
+
+ # Calculate statistical information
+ depth_mean, depth_std = torch.mean(
+ masked_depths), torch.std(masked_depths)
+
+ # Determine the upper bound using IQR and quantile methods
+ q25, q75 = torch.quantile(masked_depths, torch.tensor(
+ [0.25, 0.75], device=masked_depths.device))
+ iqr = q75 - q25
+ upper_bound = q75 + iqr_scale * iqr
+ conservative_max = torch.quantile(masked_depths, torch.tensor(
+ conservative_ratio, device=masked_depths.device))
+ final_max = torch.max(upper_bound, conservative_max)
+
+        # Count outliers beyond the final maximum
+ outliers = (fg_depth_distance > final_max) & mask_bool
+ outlier_count = outliers.sum().item()
+
+ if verbose:
+ print(
+ f"\t Depth stats - mean: {depth_mean:.2f}, std: {depth_std:.2f}")
+ print(
+ f"\t IQR bounds - Q25: {q25:.2f}, Q75: {q75:.2f}, upper: {upper_bound:.2f}")
+ print(
+ f"\t Conservative max: {conservative_max:.2f}, final max: {final_max:.2f}")
+ print(
+ f"\t Outliers: {outlier_count} ({(outlier_count/masked_depths.numel()*100):.2f}%)")
+
+ # Depth threshold check: If final_max is less than the threshold, skip compression
+ if final_max < self.min_compression_depth:
+ if verbose:
+ print(
+ f"\t Final max depth {final_max:.2f} is below threshold \
+ {self.min_compression_depth:.2f}, skipping compression")
+ return fg_depth_distance
+
+ # Apply compression
+ if outlier_count > 0:
+ compressed_depth = self._smooth_compression(
+ fg_depth_distance, final_max, mask_bool, 0.99, 0.1, verbose)
+ else:
+ compressed_depth = fg_depth_distance.clone()
+
+ return compressed_depth
+
+
+def create_adaptive_depth_compressor(
+ scene_type: str = "auto",
+ enable_smooth_compression: bool = True,
+ outlier_removal_method: str = "iqr",
+ min_compression_depth: float = 6.0, # Minimum compression depth threshold
+) -> AdaptiveDepthCompressor:
+ r"""
+    Create an adaptive depth compressor tuned for different scene types.
+    Args:
+        scene_type: Scene type ("indoor", "outdoor", "mixed", "auto").
+        enable_smooth_compression: Whether to enable smooth compression.
+        outlier_removal_method: Outlier removal method ("iqr", "quantile", "none").
+        min_compression_depth: Minimum depth threshold below which compression is skipped.
+ """
+ common_params = {
+ "enable_smooth_compression": enable_smooth_compression,
+ "outlier_removal_method": outlier_removal_method,
+ "min_compression_depth": min_compression_depth,
+ }
+
+ if scene_type == "indoor":
+        # Indoor scenes: depth variation is relatively small, so use conservative compression
+ return AdaptiveDepthCompressor(
+ cv_thresholds=(0.2, 0.6),
+ compression_quantiles=(1.0, 0.975, 0.95),
+ fg_bg_depth_margin=1.05,
+ **common_params
+ )
+ elif scene_type == "outdoor":
+        # Outdoor scenes: sky, distant mountains, etc. may be present, so use more aggressive compression
+ return AdaptiveDepthCompressor(
+ cv_thresholds=(0.4, 1.0),
+ compression_quantiles=(0.98, 0.955, 0.93),
+ fg_bg_depth_margin=1.15,
+ **common_params
+ )
+ elif scene_type == "mixed":
+        # Mixed scenes: balanced settings
+ return AdaptiveDepthCompressor(
+ cv_thresholds=(0.3, 0.8),
+ compression_quantiles=(0.99, 0.97, 0.95),
+ fg_bg_depth_margin=1.1,
+ **common_params
+ )
+ else: # auto
+ # Automatic mode: Use default settings
+ return AdaptiveDepthCompressor(**common_params)
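+
+
+# Illustrative usage sketch (hedged): the synthetic tensor shapes and the
+# layered_world_depth layout below are assumptions inferred from how
+# compress_background_depth reads its inputs (H x W depth maps plus boolean masks);
+# they are not part of the original pipeline wiring.
+if __name__ == "__main__":
+    H, W = 64, 128
+    bg_depth = torch.rand(H, W) * 50.0 + 1.0   # synthetic background distances
+    bg_mask = torch.ones(H, W, dtype=torch.bool)
+    fg_depth = torch.rand(H, W) * 5.0 + 1.0    # nearer synthetic foreground layer
+    fg_mask = torch.zeros(H, W, dtype=torch.bool)
+    fg_mask[16:48, 32:96] = True
+    layers = [
+        {"name": "foreground1", "distance": fg_depth, "mask": fg_mask},
+        {"name": "background", "distance": bg_depth, "mask": bg_mask},
+    ]
+    compressor = create_adaptive_depth_compressor(scene_type="outdoor")
+    compressed = compressor.compress_background_depth(
+        bg_depth, layers, bg_mask, verbose=True)
+    print("max depth before/after:", bg_depth.max().item(), compressed.max().item())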
diff --git a/hy3dworld/models/layer_decomposer.py b/hy3dworld/models/layer_decomposer.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb8dd0052248edbe932ef10afe3abae7910ff7b3
--- /dev/null
+++ b/hy3dworld/models/layer_decomposer.py
@@ -0,0 +1,155 @@
+import os
+import json
+import torch
+from ..utils import sr_utils, seg_utils, inpaint_utils, layer_utils
+
+
+class LayerDecomposition():
+ r"""LayerDecomposition is responsible for generating layers in a scene based on input images and masks.
+ It processes foreground objects, background layers, and sky regions using various models.
+ Args:
+ seed (int): Random seed for reproducibility.
+ strength (float): Strength of the layer generation.
+ threshold (int): Threshold for object detection.
+ ratio (float): Ratio for scaling objects.
+ grounding_model (str): Path to the grounding model for object detection.
+ zim_model_config (str): Configuration for the ZIM model.
+ zim_checkpoint (str): Path to the ZIM model checkpoint.
+ inpaint_model (str): Path to the inpainting model.
+ inpaint_fg_lora (str): Path to the LoRA weights for foreground inpainting.
+ inpaint_sky_lora (str): Path to the LoRA weights for sky inpainting.
+ scale (int): Scale factor for super-resolution.
+ device (str): Device to run the model on, either "cuda" or "cpu".
+ dilation_size (int): Size of the dilation for mask processing.
+ cfg_scale (float): Configuration scale for the model.
+ prompt_config (dict): Configuration for prompts used in the model.
+ """
+ def __init__(self):
+ r"""Initialize the LayerDecomposition class with model paths and parameters."""
+ self.seed = 25
+ self.strength = 1.0
+ self.threshold = 20000
+ self.ratio = 1.5
+ self.grounding_model = "IDEA-Research/grounding-dino-tiny"
+ self.zim_model_config = "vit_l"
+        self.zim_checkpoint = "./ZIM/zim_vit_l_2092"  # set the ZIM-Anything checkpoint path here
+ self.inpaint_model = "black-forest-labs/FLUX.1-Fill-dev"
+ self.inpaint_fg_lora = "tencent/HunyuanWorld-1"
+ self.inpaint_sky_lora = "tencent/HunyuanWorld-1"
+ self.scale = 2
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.dilation_size = 80
+ self.cfg_scale = 5.0
+ self.prompt_config = {
+ "indoor": {
+ "positive_prompt": "",
+ "negative_prompt": (
+ "object, table, chair, seat, shelf, sofa, bed, bath, sink,"
+ "ceramic, wood, plant, tree, light, lamp, candle, television, electronics,"
+ "oven, fire, low-resolution, blur, mosaic, people")
+ },
+ "outdoor": {
+ "positive_prompt": "",
+ "negative_prompt": (
+ "object, chair, tree, plant, flower, grass, stone, rock,"
+ "building, hill, house, tower, light, lamp, low-resolution, blur, mosaic, people")
+ }
+ }
+
+ # Load models
+ print("============= now loading models ===============")
+ # super-resolution model
+ self.sr_model = sr_utils.build_sr_model(scale=self.scale, gpu_id=0)
+ print("============= load Super-Resolution models done ")
+ # segmentation model
+ self.zim_predictor = seg_utils.build_zim_model(
+ self.zim_model_config, self.zim_checkpoint, device='cuda:0')
+ self.gd_processor, self.gd_model = seg_utils.build_gd_model(
+ self.grounding_model, device='cuda:0')
+ print("============= load Segmentation models done ====")
+ # panorama inpaint model
+ self.inpaint_fg_model = inpaint_utils.build_inpaint_model(
+ self.inpaint_model,
+ self.inpaint_fg_lora,
+ subfolder="HunyuanWorld-PanoInpaint-Scene",
+ device=0
+ )
+ self.inpaint_sky_model = inpaint_utils.build_inpaint_model(
+ self.inpaint_model,
+ self.inpaint_sky_lora,
+ subfolder="HunyuanWorld-PanoInpaint-Sky",
+ device=0
+ )
+ print("============= load panorama inpaint models done =")
+
+ def __call__(self, input, layer):
+ r"""Generate layers based on the input images and masks.
+ Args:
+ input (str or list): Path to the input JSON file or a list of image information.
+ layer (int): Layer index to process (0 for foreground1, 1 for foreground2,
+ 2 for sky).
+ Raises:
+ FileNotFoundError: If the input file does not exist.
+ ValueError: If the input file is not a JSON file or if the layer index is invalid.
+ TypeError: If the input is neither a string nor a list.
+ """
+ torch.autocast(device_type=self.device,
+ dtype=torch.bfloat16).__enter__()
+
+ # Input handling and validation
+ if isinstance(input, str):
+ if not os.path.exists(input):
+ raise FileNotFoundError(f"Input file {input} does not exist.")
+ if not input.endswith('.json'):
+ raise ValueError("Input file must be a JSON file.")
+ with open(input, "r") as f:
+ img_infos = json.load(f)
+ img_infos = img_infos["output"]
+ elif isinstance(input, list):
+ img_infos = input
+ else:
+ raise TypeError("Input must be a JSON file path or a list.")
+
+ # Processing parameters
+ params = {
+ 'scale': self.scale,
+ 'seed': self.seed,
+ 'threshold': self.threshold,
+ 'ratio': self.ratio,
+ 'strength': self.strength,
+ 'dilation_size': self.dilation_size,
+ 'cfg_scale': self.cfg_scale,
+ 'prompt_config': self.prompt_config
+ }
+
+ # Layer-specific processing pipelines
+ if layer == 0:
+ layer_utils.remove_fg1_pipeline(
+ img_infos=img_infos,
+ sr_model=self.sr_model,
+ zim_predictor=self.zim_predictor,
+ gd_processor=self.gd_processor,
+ gd_model=self.gd_model,
+ inpaint_model=self.inpaint_fg_model,
+ params=params
+ )
+ elif layer == 1:
+ layer_utils.remove_fg2_pipeline(
+ img_infos=img_infos,
+ sr_model=self.sr_model,
+ zim_predictor=self.zim_predictor,
+ gd_processor=self.gd_processor,
+ gd_model=self.gd_model,
+ inpaint_model=self.inpaint_fg_model,
+ params=params
+ )
+ else:
+ layer_utils.sky_pipeline(
+ img_infos=img_infos,
+ sr_model=self.sr_model,
+ zim_predictor=self.zim_predictor,
+ gd_processor=self.gd_processor,
+ gd_model=self.gd_model,
+ inpaint_model=self.inpaint_sky_model,
+ params=params
+ )
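+
+
+# Hedged usage sketch (comments only): "scene_info.json" is a hypothetical path, and the
+# exact record structure expected inside its "output" list is defined by the layer_utils
+# pipelines rather than reproduced here.
+#     decomposer = LayerDecomposition()
+#     decomposer("scene_info.json", layer=0)  # inpaint after removing foreground layer 1
+#     decomposer("scene_info.json", layer=1)  # inpaint after removing foreground layer 2
+#     decomposer("scene_info.json", layer=2)  # complete the sky layer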
diff --git a/hy3dworld/models/pano_generator.py b/hy3dworld/models/pano_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..f14ebd113e5ae581a10586b066f2ebf6d9ba5a30
--- /dev/null
+++ b/hy3dworld/models/pano_generator.py
@@ -0,0 +1,236 @@
+# Tencent HunyuanWorld-1.0 is licensed under TENCENT HUNYUANWORLD-1.0 COMMUNITY LICENSE AGREEMENT
+# THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND
+# IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
+# By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying
+# any portion or element of the Tencent HunyuanWorld-1.0 Works, including via any Hosted Service,
+# You will be deemed to have recognized and accepted the content of this Agreement,
+# which is effective immediately.
+
+# For avoidance of doubts, Tencent HunyuanWorld-1.0 means the 3D generation models
+# and their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent at [https://github.com/Tencent-Hunyuan/HunyuanWorld-1.0].
+
+import torch
+from transformers import (
+ CLIPTextModel,
+ CLIPTokenizer,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.autoencoders import AutoencoderKL
+
+from diffusers.models.transformers import FluxTransformer2DModel
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+
+from diffusers.utils.torch_utils import randn_tensor
+
+from .pipelines import FluxPipeline, FluxFillPipeline
+
+class Text2PanoramaPipelines(FluxPipeline):
+ @torch.no_grad()
+ def __call__(self, prompt, **kwargs):
+ """Main inpainting call."""
+ return self._call_shared(prompt=prompt, is_inpainting=False, early_steps=3, **kwargs)
+
+
+class Image2PanoramaPipelines(FluxFillPipeline):
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ ):
+        # Initialization from FluxFillPipeline
+ super().__init__(
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ )
+
+        # Override parts of the initialization
+ self.latent_channels = self.vae.config.latent_channels if getattr(
+ self, "vae", None) else 16
+
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=self.latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps *
+ strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_inpainting_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ is_strength_max=True,
+ timestep=None,
+ ):
+ r"""
+ Prepares the latents for the Image2PanoramaPipelines.
+ """
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ shape = (batch_size, num_channels_latents, height, width)
+
+ # Return the latents if they are already provided
+        if latents is not None:
+            latent_image_ids = self._prepare_latent_image_ids(
+                batch_size, height // 2, width // 2, device, dtype)
+            return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ # If no latents are provided, we need to encode the image
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(
+ image=image, generator=generator)
+ else:
+ image_latents = image
+
+ # Ensure image_latents has the correct shape
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat(
+ [image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+ # Add noise to the latents
+ noise = randn_tensor(shape, generator=generator,
+ device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+
+ # prepare blended latents
+ latents = torch.cat(
+ [latents, latents[:, :, :, :self.blend_extend]], dim=-1)
+ width_new_blended = latents.shape[-1]
+ latents = self._pack_latents(
+ latents, batch_size, num_channels_latents, height, width_new_blended)
+ # prepare latent image ids
+ latent_image_ids = self._prepare_latent_image_ids(
+ batch_size, height // 2, width_new_blended // 2, device, dtype)
+
+ return latents, latent_image_ids, width_new_blended
+
+ def prepare_blending_latent(
+ self, latents, height, width, batch_size, num_channels_latents, width_new_blended=None
+ ):
+ return latents, width_new_blended
+
+ def _apply_blending(
+ self,
+ latents: torch.Tensor,
+ height: int,
+ width_new_blended: int,
+ num_channels_latents: int,
+ batch_size: int,
+        **kwargs,
+ ) -> torch.Tensor:
+ r"""Apply horizontal blending to latents."""
+ # Unpack latents for processing
+ latents_unpack = self._unpack_latents(
+ latents, height, width_new_blended*self.vae_scale_factor, self.vae_scale_factor
+ )
+ # Apply blending
+ latents_unpack = self.blend_h(latents_unpack, latents_unpack, self.blend_extend)
+
+ latent_height = 2 * \
+ (int(height) // (self.vae_scale_factor * 2))
+
+        shifting_extend = kwargs.get("shifting_extend", None)
+ if shifting_extend is None:
+ shifting_extend = latents_unpack.size()[-1]//4
+
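+        # Circularly shift the latent along the width, moving the wrap-around seam
+        # to a different horizontal position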
+ latents_unpack = torch.roll(
+ latents_unpack, shifting_extend, -1)
+
+ # Repack latents after blending
+ latents = self._pack_latents(
+ latents_unpack, batch_size, num_channels_latents, latent_height, width_new_blended)
+ return latents
+
+ def _apply_blending_mask(
+ self,
+ latents: torch.Tensor,
+ height: int,
+ width_new_blended: int,
+ num_channels_latents: int,
+ batch_size: int,
+ **kwargs
+ ) -> torch.Tensor:
+ r"""Apply horizontal blending to mask latents."""
+ return self._apply_blending(
+ latents, height, width_new_blended, 80, batch_size, **kwargs
+ )
+
+ def _final_process_latents(
+ self,
+ latents: torch.Tensor,
+ height: int,
+ width_new_blended: int,
+ width: int
+ ) -> torch.Tensor:
+ """Final processing of latents before decoding."""
+ # Unpack and crop to target width
+ latents_unpack = self._unpack_latents(
+ latents, height, width_new_blended * self.vae_scale_factor, self.vae_scale_factor
+ )
+ latents_unpack = self.blend_h(
+ latents_unpack, latents_unpack, self.blend_extend
+ )
+ latents_unpack = latents_unpack[:, :, :, :width // self.vae_scale_factor]
+
+ # Repack for final output
+ return self._pack_latents(
+ latents_unpack,
+ latents.shape[0], # batch size
+ latents.shape[2] // 4, # num_channels_latents
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor
+ )
+
+ @torch.no_grad()
+ def __call__(self, **kwargs):
+ """Main inpainting call."""
+ return self._call_shared(is_inpainting=True, early_steps=3, blend_extra_chanel=True, **kwargs)
diff --git a/hy3dworld/models/pipelines.py b/hy3dworld/models/pipelines.py
new file mode 100644
index 0000000000000000000000000000000000000000..124f2fbd43c1f1c5b3ea29178ac889431ebdfdb1
--- /dev/null
+++ b/hy3dworld/models/pipelines.py
@@ -0,0 +1,1478 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from diffusers.models.autoencoders import AutoencoderKL
+
+from diffusers.models.transformers import FluxTransformer2DModel
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers import DiffusionPipeline
+from diffusers.pipelines.flux import FluxPipelineOutput
+
+# try to import DecoderOutput from diffusers.models
+try:
+ from diffusers.models.autoencoders.vae import DecoderOutput
+except ImportError:
+ from diffusers.models.vae import DecoderOutput
+
+# Check if PyTorch XLA (for TPU support) is available
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+# Initialize logger for the module
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.16,
+):
+ r"""
+ Calculate the shift value for the image sequence length based on the base and maximum sequence lengths.
+ Args:
+ image_seq_len (`int`):
+ The sequence length of the image.
+ base_seq_len (`int`, *optional*, defaults to 256):
+ The base sequence length.
+ max_seq_len (`int`, *optional*, defaults to 4096):
+ The maximum sequence length.
+ base_shift (`float`, *optional*, defaults to 0.5):
+ The base shift value.
+ max_shift (`float`, *optional*, defaults to 1.16):
+ The maximum shift value.
+ Returns:
+ `float`: The calculated shift value for the image sequence length.
+ """
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError(
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(
+ inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(
+ scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ r"""
+ Retrieves the latents from the encoder output based on the sample mode.
+ Args:
+ encoder_output (`torch.Tensor` or `FluxPipelineOutput`):
+ The output from the encoder, which can be a tensor or a custom output object.
+ generator (`torch.Generator`, *optional*):
+ A random number generator for sampling. If `None`, the default generator is used.
+ sample_mode (`str`, *optional*, defaults to `"sample"`):
+ The mode for sampling latents. Can be either `"sample"` or `"argmax"`.
+ Returns:
+ `torch.Tensor`: The sampled or argmax latents from the encoder output.
+ Raises:
+ `AttributeError`: If the encoder output does not have the expected attributes for latents.
+ """
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError(
+ "Could not access latents of provided encoder_output")
+
+class FluxBasePipeline(DiffusionPipeline):
+ """Base class for Flux pipelines containing shared functionality."""
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ **kwargs
+ ):
+ super().__init__()
+
+ # Register core components
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ # Calculate scale factors
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1)
+ if hasattr(self, "vae") and self.vae is not None else 8
+ )
+
+ # Initialize processors
+ self._init_processors(**kwargs)
+
+ # Default configuration
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length
+ if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 128
+
+ def _init_processors(self, **kwargs):
+ """Initialize image and mask processors."""
+ self.image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2
+ )
+
+ # Only initialize mask processor for inpainting pipeline
+ if hasattr(self, 'mask_processor'):
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=self.vae.config.latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ """Generate prompt embeddings using T5 text encoder."""
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ # Convert single prompt to list
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ # Handle textual inversion if applicable
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ # Tokenize input
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ # Check for truncation
+ untruncated_ids = self.tokenizer_2(
+ prompt, padding="longest", return_tensors="pt"
+ ).input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(
+ untruncated_ids[:, self.tokenizer_max_length - 1: -1]
+ )
+ logger.warning(
+ f"Truncated input (max_length={max_sequence_length}): {removed_text}"
+ )
+
+ # Get embeddings from T5 encoder
+ prompt_embeds = self.text_encoder_2(
+ text_input_ids.to(device), output_hidden_states=False
+ )[0].to(dtype=dtype, device=device)
+
+ # Expand for multiple images per prompt
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+
+ return prompt_embeds
+
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ """Generate pooled prompt embeddings using CLIP text encoder."""
+ device = device or self._execution_device
+
+ # Convert single prompt to list
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ # Handle textual inversion if applicable
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ # Tokenize input
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ # Check for truncation
+ untruncated_ids = self.tokenizer(
+ prompt, padding="longest", return_tensors="pt"
+ ).input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer_max_length - 1: -1]
+ )
+ logger.warning(
+ f"CLIP truncated input (max_length={self.tokenizer_max_length}): {removed_text}"
+ )
+
+ # Get pooled embeddings from CLIP
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), output_hidden_states=False
+ ).pooler_output.to(dtype=self.text_encoder.dtype, device=device)
+
+ # Expand for multiple images per prompt
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(
+ batch_size * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ """Main method to encode prompts using both text encoders."""
+ # Handle LoRA scaling if applicable
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ # Process prompts if embeddings not provided
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ # Reset LoRA scaling if applied
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ # Prepare text IDs tensor
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(
+ device=device, dtype=dtype
+ )
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ @staticmethod
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ """Create coordinate-based latent image IDs."""
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ return latent_image_ids.reshape(height * width, 3).to(device=device, dtype=dtype)
+
+ @staticmethod
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ """Pack latents into sequence format."""
+ latents = latents.view(
+ batch_size, num_channels_latents, height // 2, 2, width // 2, 2
+ )
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ return latents.reshape(
+ batch_size, (height // 2) * (width // 2), num_channels_latents * 4
+ )
+
+ @staticmethod
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ """Unpack latents from sequence format back to spatial format."""
+ batch_size, num_patches, channels = latents.shape
+
+ # Adjust dimensions for VAE scaling
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+ return latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ def blend_v(self, a, b, blend_extent):
+ """Vertical blending between two tensors."""
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, y, :] = (
+ a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) +
+ b[:, :, y, :] * (y / blend_extent)
+ )
+ return b
+
+ def blend_h(self, a, b, blend_extent):
+ """Horizontal blending between two tensors."""
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, x] = (
+ a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) +
+ b[:, :, :, x] * (x / blend_extent)
+ )
+ return b
+
+ def enable_vae_slicing(self):
+ """Enable sliced VAE decoding."""
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ """Disable sliced VAE decoding."""
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ """Enable tiled VAE decoding."""
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ """Disable tiled VAE decoding."""
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ """Prepare initial noise latents for generation."""
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(
+ batch_size, height // 2, width // 2, device, dtype
+ )
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ # Validate generator list length
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"Generator list length {len(generator)} != batch size {batch_size}"
+ )
+
+ # Generate random noise
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # Apply blending extension
+ latents = torch.cat([latents, latents[:, :, :, :self.blend_extend]], dim=-1)
+ width_new_blended = latents.shape[-1]
+
+ # Pack latents and prepare IDs
+ latents = self._pack_latents(
+ latents, batch_size, num_channels_latents, height, width_new_blended
+ )
+ latent_image_ids = self._prepare_latent_image_ids(
+ batch_size, height // 2, width_new_blended // 2, device, dtype
+ )
+
+ return latents, latent_image_ids, width_new_blended
+
+ def prepare_inpainting_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ is_strength_max=True,
+ timestep=None,
+ ):
+ """Prepare latents for inpainting pipeline."""
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(
+ batch_size, height // 2, width // 2, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # Check if generation strength is at its maximum
+ if not is_strength_max:
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(
+ image=image, generator=generator)
+
+ # Generate noise latents
+ noise = randn_tensor(shape, generator=generator,
+ device=device, dtype=dtype)
+ latents = noise if is_strength_max else self.scheduler.scale_noise(
+ image_latents, timestep, noise)
+ width_new_blended = latents.shape[-1]
+
+ # Organize the latents into proper batch structure with specific shape
+ latents = self._pack_latents(
+ latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(
+ batch_size, height // 2, width // 2, device, dtype)
+
+ return latents, latent_image_ids, width_new_blended
+
+ def _predict_noise(
+ self,
+ latents: torch.Tensor,
+ timestep: torch.Tensor,
+ guidance: Optional[torch.Tensor],
+ pooled_prompt_embeds: torch.Tensor,
+ prompt_embeds: torch.Tensor,
+ text_ids: torch.Tensor,
+ latent_image_ids: torch.Tensor,
+ is_inpainting: bool = False,
+ **kwargs
+ ) -> torch.Tensor:
+ """Predict noise using transformer with proper conditioning."""
+ # Prepare transformer inputs
+ transformer_inputs = {
+ "hidden_states": torch.cat([latents, kwargs.get('masked_image_latents', latents)], dim=2)
+ if is_inpainting else latents,
+ "timestep": timestep / 1000,
+ "guidance": guidance,
+ "pooled_projections": pooled_prompt_embeds,
+ "encoder_hidden_states": prompt_embeds,
+ "txt_ids": text_ids,
+ "img_ids": latent_image_ids,
+ "joint_attention_kwargs": self._joint_attention_kwargs,
+ "return_dict": False,
+ }
+
+ return self.transformer(**transformer_inputs)[0]
+
+ def _apply_blending(
+ self,
+ latents: torch.Tensor,
+ height: int,
+ width_new_blended: int,
+ num_channels_latents: int,
+ batch_size: int,
+ **kwargs
+ ) -> torch.Tensor:
+ """Apply horizontal blending to latents."""
+ # Unpack latents for processing
+ latents_unpack = self._unpack_latents(
+ latents, height, width_new_blended, self.vae_scale_factor
+ )
+
+ # Apply blending
+ latents_unpack = self.blend_h(
+ latents_unpack, latents_unpack, self.blend_extend
+ )
+
+ # Repack latents after blending
+ return self._pack_latents(
+ latents_unpack,
+ batch_size,
+ num_channels_latents,
+ height // 8,
+ width_new_blended // 8
+ )
+
+ def _apply_blending_mask(
+ self,
+ latents: torch.Tensor,
+ height: int,
+ width_new_blended: int,
+ num_channels_latents: int,
+ batch_size: int,
+ **kwargs
+ ) -> torch.Tensor:
+ return self._apply_blending(
+ latents, height, width_new_blended,
+ num_channels_latents + self.vae_scale_factor * self.vae_scale_factor,
+ batch_size, **kwargs
+ )
+
+ def _final_process_latents(
+ self,
+ latents: torch.Tensor,
+ height: int,
+ width_new_blended: int,
+ target_width: int
+ ) -> torch.Tensor:
+ """Final processing of latents before decoding."""
+ # Unpack and crop to target width
+ latents_unpack = self._unpack_latents(
+ latents, height, width_new_blended, self.vae_scale_factor
+ )
+ latents_unpack = self.blend_h(
+ latents_unpack, latents_unpack, self.blend_extend
+ )
+ latents_unpack = latents_unpack[:, :, :, :target_width // self.vae_scale_factor]
+
+ # Repack for final output
+ return self._pack_latents(
+ latents_unpack,
+ latents.shape[0], # batch size
+ latents.shape[2] // 4, # num_channels_latents
+ height // 8,
+ target_width // 8
+ )
+
+ def _check_inputs(
+ self,
+ prompt: Optional[Union[str, List[str]]],
+ prompt_2: Optional[Union[str, List[str]]],
+ height: int,
+ width: int,
+ negative_prompt: Optional[Union[str, List[str]]],
+ negative_prompt_2: Optional[Union[str, List[str]]],
+ prompt_embeds: Optional[torch.FloatTensor],
+ negative_prompt_embeds: Optional[torch.FloatTensor],
+ pooled_prompt_embeds: Optional[torch.FloatTensor],
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor],
+ callback_on_step_end_tensor_inputs: List[str],
+ max_sequence_length: int,
+ is_inpainting: bool,
+ **kwargs
+ ):
+ """Validate all pipeline inputs."""
+ # Check dimensions
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"Input dimensions should be divisible by {self.vae_scale_factor * 2}. "
+ f"Got height={height}, width={width}. Will be resized automatically."
+ )
+
+ # Check callback inputs
+ if callback_on_step_end_tensor_inputs is not None:
+ invalid_inputs = [k for k in callback_on_step_end_tensor_inputs
+ if k not in self._callback_tensor_inputs]
+ if invalid_inputs:
+ raise ValueError(
+ f"Invalid callback tensor inputs: {invalid_inputs}. "
+ f"Allowed inputs: {self._callback_tensor_inputs}"
+ )
+
+ # Check prompt vs prompt_embeds
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ "Cannot provide both prompt and prompt_embeds. Please use only one."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ "Cannot provide both prompt_2 and prompt_embeds. Please use only one."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Must provide either prompt or prompt_embeds."
+ )
+ elif prompt is not None and not isinstance(prompt, (str, list)):
+ raise ValueError(
+ f"prompt must be string or list, got {type(prompt)}"
+ )
+ elif prompt_2 is not None and not isinstance(prompt_2, (str, list)):
+ raise ValueError(
+ f"prompt_2 must be string or list, got {type(prompt_2)}"
+ )
+
+ # Check negative prompts
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ "Cannot provide both negative_prompt and negative_prompt_embeds."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ "Cannot provide both negative_prompt_2 and negative_prompt_embeds."
+ )
+
+ # Check embeddings shapes
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "prompt_embeds and negative_prompt_embeds must have same shape."
+ )
+
+ # Check pooled embeddings
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "Must provide pooled_prompt_embeds with prompt_embeds."
+ )
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "Must provide negative_pooled_prompt_embeds with negative_prompt_embeds."
+ )
+
+ # Check sequence length
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(
+ f"max_sequence_length cannot exceed 512, got {max_sequence_length}"
+ )
+
+ # Inpainting specific checks
+ if is_inpainting:
+ if kwargs.get('image') is not None and kwargs.get('mask_image') is None:
+ raise ValueError(
+ "Must provide mask_image when using inpainting."
+ )
+ if kwargs.get('image') is not None and kwargs.get('masked_image_latents') is not None:
+ raise ValueError(
+ "Cannot provide both image and masked_image_latents."
+ )
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+    def get_batch_size(self, prompt, prompt_embeds=None):
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            # Fall back to the batch dimension of the pre-computed prompt embeddings.
+            batch_size = prompt_embeds.shape[0]
+        return batch_size
+
+ def prepare_timesteps(self,
+ num_inference_steps: int,
+ height: int,
+ width: int,
+ device: Union[str, torch.device],
+ sigmas: Optional[np.ndarray] = None,
+ ):
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor //
+ 2) * (int(width) // self.vae_scale_factor // 2)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
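+        # calculate_shift linearly interpolates the flow-matching shift `mu` between the
+        # scheduler's base_shift and max_shift based on the latent token count
+        # (roughly H/16 * W/16 with the usual vae_scale_factor of 8).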
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ return timesteps, num_inference_steps
+
+ def prepare_blending_latent(
+ self, latents, height, width, batch_size, num_channels_latents, width_new_blended=None
+ ):
+ # Unpack and process latents for blending
+ latents_unpack = self._unpack_latents(
+ latents, height, width, self.vae_scale_factor)
+ latents_unpack = torch.cat(
+ [latents_unpack, latents_unpack[:, :, :, :self.blend_extend]], dim=-1)
+ width_new_blended = latents_unpack.shape[-1] * 8
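+        # With blend_extend extra latent columns copied from the left edge and appended on the
+        # right, the working width grows by blend_extend * 8 pixels; this wrap-around strip is
+        # what _apply_blending later cross-fades back in to close the 360° seam.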
+
+ # Repack the processed latents
+ latents = self._pack_latents(
+ latents_unpack,
+ batch_size,
+ num_channels_latents,
+ height // 8,
+ width_new_blended // 8
+ )
+ return latents, width_new_blended
+
+ @torch.no_grad()
+ def _call_shared(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ blend_extend: int = 6,
+ is_inpainting: bool = False,
+ **kwargs,
+ ):
+ """Shared implementation between generation and inpainting pipelines."""
+ # Enable tiled decoding
+ self.vae.enable_tiling()
+
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ if self.use_tiling:
+ return self.tiled_decode(z, return_dict=return_dict)
+ if self.post_quant_conv is not None:
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ if not return_dict:
+ return (dec,)
+ return DecoderOutput(sample=dec)
+
+ def tiled_decode(
+ self,
+ z: torch.FloatTensor,
+ return_dict: bool = True
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+            r"""Decode a batch of images using a tiled decoder.
+
+            When tiling is enabled, the VAE splits the input tensor into tiles and decodes them in
+            several steps, keeping memory use roughly constant regardless of image size. Because each
+            tile is decoded separately, the result can differ slightly from non-tiled decoding; the
+            tiles overlap and are blended together to avoid seams, though faint tile-sized variations
+            may remain.
+
+            Args:
+                z (`torch.FloatTensor`): Input batch of latent vectors.
+                return_dict (`bool`, *optional*, defaults to `True`):
+                    Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+            """
+ overlap_size = int(self.tile_latent_min_size *
+ (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_sample_min_size *
+ self.tile_overlap_factor)
+ row_limit = self.tile_sample_min_size - blend_extent
+
+ w = z.shape[3]
+
+ z = torch.cat([z, z[:, :, :, :2]], dim=-1) # [1, 16, 64, 160]
+
+ # Split z into overlapping 64x64 tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, z.shape[2], overlap_size):
+ row = []
+ tile = z[:, :, i:i + self.tile_latent_min_size, :]
+ if self.config.use_post_quant_conv:
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile)
+ vae_scale_factor = decoded.shape[-1] // tile.shape[-1]
+ row.append(decoded)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(
+ self.blend_h(
+ tile[:, :, :row_limit, w * vae_scale_factor:],
+ tile[:, :, :row_limit, :w * vae_scale_factor],
+ tile.shape[-1] - w * vae_scale_factor))
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ dec = torch.cat(result_rows, dim=2)
+ if not return_dict:
+ return (dec, )
+ return DecoderOutput(sample=dec)
+
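+        # Bind the overrides above onto this VAE instance: _decode routes through tiled_decode,
+        # which pads the latent with its first two columns (z[:, :, :, :2]) so the decoded
+        # panorama wraps seamlessly across its left/right border.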
+ self.vae.tiled_decode = tiled_decode.__get__(self.vae, AutoencoderKL)
+ self.vae._decode = _decode.__get__(self.vae, AutoencoderKL)
+
+ self.blend_extend = blend_extend
+
+ # Set default dimensions
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # Check inputs (handles both pipelines)
+ self._check_inputs(
+ prompt, prompt_2, height, width,
+ negative_prompt, negative_prompt_2,
+ prompt_embeds, negative_prompt_embeds,
+ pooled_prompt_embeds, negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ max_sequence_length,
+ is_inpainting,
+ **kwargs
+ )
+
+ # Set class variables
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs or {}
+ self._interrupt = False
+
+ # Determine if the strength is at its maximum
+ if is_inpainting:
+ strength = kwargs.get('strength', 1.0)
+ is_strength_max = strength == 1.0
+
+ # Determine batch size
+        batch_size = self.get_batch_size(prompt, prompt_embeds)
+
+ device = self._execution_device
+
+ # Prepare timesteps
+ timesteps, num_inference_steps = self.prepare_timesteps(
+ num_inference_steps, height, width, device
+ )
+
+ # Adjust timesteps based on strength parameter
+        if is_inpainting:
+            timesteps, num_inference_steps = self.get_timesteps(
+                num_inference_steps, strength, device)
+
+ num_warmup_steps = max(
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # Encode prompts
+ lora_scale = self._joint_attention_kwargs.get("scale", None)
+ do_true_cfg = true_cfg_scale > 1 and (negative_prompt is not None or
+ (negative_prompt_embeds is not None and
+ negative_pooled_prompt_embeds is not None))
+
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ if do_true_cfg:
+ negative_prompt_embeds, negative_pooled_prompt_embeds, _ = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # Prepare latents
+ if is_inpainting:
+ image = kwargs.get('image', None)
+
+ # Create latent timestep tensor
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # Get number of latent channels from VAE config
+ num_channels_latents = self.vae.config.latent_channels
+
+ latents, latent_image_ids, width_new_blended = self.prepare_inpainting_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ self.image_processor.preprocess(
+ image, height=height, width=width).to(dtype=torch.float32),
+ is_strength_max,
+ latent_timestep
+ )
+
+            # Extend the latents horizontally (wrap-around strip) for seamless panorama blending
+ latents, width_new_blended = self.prepare_blending_latent(
+ latents, height, width, batch_size, num_channels_latents, width_new_blended
+ )
+
+ # Prepare latent image IDs for the new blended width
+ if not kwargs.get('blend_extra_chanel', False):
+ latent_image_ids = self._prepare_latent_image_ids(
+ batch_size * num_images_per_prompt,
+ height // 16,
+ width_new_blended // 16,
+ latents.device,
+ latents.dtype
+ )
+
+ # Prepare mask and masked image latents
+ masked_image_latents = kwargs.get('masked_image_latents', None)
+
+ if masked_image_latents is not None:
+ masked_image_latents = masked_image_latents.to(latents.device)
+ else:
+ mask_image = kwargs.get('mask_image', None)
+ # Preprocess input image and mask
+ image = self.image_processor.preprocess(image, height=height, width=width)
+ mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
+
+ # Create masked image
+ masked_image = image * (1 - mask_image)
+ masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)
+
+ # Prepare mask and masked image latents
+ height, width = image.shape[-2:]
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask_image,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ kwargs.get('blend_extra_chanel', False)
+ )
+
+ # Combine mask and masked image latents
+ masked_image_latents = torch.cat(
+ (masked_image_latents, mask), dim=-1)
+
+                # Extend the mask / masked-image latents horizontally for blending as well
+ masked_image_latents, masked_width_new_blended = self.prepare_blending_latent(
+ masked_image_latents, height, width, batch_size,
+ num_channels_latents + self.vae_scale_factor * self.vae_scale_factor,
+ width_new_blended
+ )
+ # update masked_image_latents
+ kwargs["masked_image_latents"] = masked_image_latents
+ else:
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_image_ids, width_new_blended = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ width_new_blended = width_new_blended * self.vae_scale_factor
+
+ # Handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ # Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ # Predict noise
+ noise_pred = self._predict_noise(
+ latents, timestep, guidance, pooled_prompt_embeds,
+ prompt_embeds, text_ids, latent_image_ids,
+ is_inpainting, **kwargs
+ )
+
+ # Apply true CFG if enabled
+ if do_true_cfg:
+                    negative_image_embeds = kwargs.get("negative_image_embeds", None)
+                    if not is_inpainting and negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+
+ neg_noise_pred = self._predict_noise(
+ latents, timestep, guidance, negative_pooled_prompt_embeds,
+ negative_prompt_embeds, text_ids, latent_image_ids,
+ is_inpainting, **kwargs
+ )
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ # Step with scheduler
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # Apply blending in early steps
+ if i <= kwargs.get('early_steps', 4):
+ latents = self._apply_blending(
+ latents, height, width_new_blended, num_channels_latents, batch_size, **kwargs
+ )
+ if is_inpainting:
+ masked_image_latents = self._apply_blending_mask(
+ masked_image_latents, height,
+ masked_width_new_blended,
+ num_channels_latents, batch_size,
+ **kwargs
+ )
+
+ # Fix dtype issues
+ if latents.dtype != latents_dtype and torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave
+ # due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ # Handle callbacks
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # Update progress
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # Final processing
+ latents = self._final_process_latents(latents, height, width_new_blended, width)
+
+ # Decode latents
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Clean up
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
+
+
+class FluxPipeline(
+ FluxBasePipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+ FluxIPAdapterMixin,
+):
+ """Main Flux generation pipeline"""
+ _optional_components = ["image_encoder", "feature_extractor"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ ):
+ super().__init__(
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ )
+
+ # Register optional components
+ self.register_modules(
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+
+ def encode_image(self, image, device, num_images_per_prompt):
+ """Encode input image into embeddings."""
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ return image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+ @torch.no_grad()
+ def __call__(self, **kwargs):
+ """Main generation call"""
+ return self._call_shared(is_inpainting=False, **kwargs)
+
+
+class FluxFillPipeline(
+ FluxBasePipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+):
+ """Flux inpainting pipeline."""
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ ):
+ super().__init__(
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ )
+ # Initialize mask processor
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=self.vae.config.latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ r"""
+ Encodes the input image using the VAE and returns the encoded latents.
+ Args:
+ image (`torch.Tensor`):
+ The input image tensor to be encoded.
+ generator (`torch.Generator`):
+ A random number generator for sampling.
+ Returns:
+ `torch.Tensor`: The encoded image latents.
+ """
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(
+ image[i: i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(
+ self.vae.encode(image), generator=generator)
+
+ image_latents = (
+ image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ def get_timesteps(
+ self,
+ num_inference_steps,
+ strength,
+ device
+ ):
+        # Skip the earliest (1 - strength) fraction of the schedule already prepared on the scheduler.
+        timesteps = self.scheduler.timesteps[int((1 - strength) * num_inference_steps):]
+        return timesteps, num_inference_steps
+
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ blend_extra_chanel=False
+ ):
+        r""" Prepares the mask and masked image latents for the FluxFillPipeline.
+ Args:
+ mask (`torch.Tensor`):
+ The mask tensor to be processed.
+ masked_image (`torch.Tensor`):
+ The masked image tensor to be processed.
+ batch_size (`int`):
+ The batch size for the input data.
+ num_channels_latents (`int`):
+ The number of channels in the latents.
+ num_images_per_prompt (`int`):
+ The number of images to generate per prompt.
+ height (`int`):
+ The height of the input image.
+ width (`int`):
+ The width of the input image.
+ dtype (`torch.dtype`):
+ The data type for the latents and mask.
+ device (`torch.device`):
+ The device to run the model on.
+ generator (`torch.Generator`, *optional*):
+ A random number generator for sampling.
+ Returns:
+ `Tuple[torch.Tensor, torch.Tensor]`: A tuple containing the processed mask and masked image latents.
+ """
+ # 1. calculate the height and width of the latents
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ # 2. encode the masked image
+ if masked_image.shape[1] == num_channels_latents:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = retrieve_latents(
+ self.vae.encode(masked_image), generator=generator)
+
+ masked_image_latents = (
+ masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+ masked_image_latents = masked_image_latents.to(
+ device=device, dtype=dtype)
+
+ # 3. duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ batch_size = batch_size * num_images_per_prompt
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(
+ batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ # 4. pack the masked_image_latents
+ # batch_size, num_channels_latents, height, width -> batch_size, height//2 * width//2 , num_channels_latents*4
+ if blend_extra_chanel:
+ masked_image_latents = torch.cat(
+ [masked_image_latents, masked_image_latents[:, :, :, :self.blend_extend]], dim=-1)
+
+ width_new_blended = masked_image_latents.shape[-1]
+ masked_image_latents = self._pack_latents(
+ masked_image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width_new_blended if blend_extra_chanel else width,
+ )
+ # latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+        # 5. resize mask to latent shape, as we will concatenate the mask to the latents
+ # batch_size, 8 * height, 8 * width (mask has not been 8x compressed)
+ mask = mask[:, 0, :, :]
+ mask = mask.view(
+ batch_size, height, self.vae_scale_factor, width, self.vae_scale_factor
+ ) # batch_size, height, 8, width, 8
+ mask = mask.permute(0, 2, 4, 1, 3) # batch_size, 8, 8, height, width
+ mask = mask.reshape(
+ batch_size, self.vae_scale_factor * self.vae_scale_factor, height, width
+ ) # batch_size, 8*8, height, width
+ if blend_extra_chanel:
+ mask = torch.cat([mask, mask[:, :, :, :self.blend_extend]], dim=-1)
+
+ # 6. pack the mask:
+ # batch_size, 64, height, width -> batch_size, height//2 * width//2 , 64*2*2
+ mask = self._pack_latents(
+ mask,
+ batch_size,
+ self.vae_scale_factor * self.vae_scale_factor,
+ height,
+ width_new_blended if blend_extra_chanel else width,
+ )
+ mask = mask.to(device=device, dtype=dtype)
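+        # At this point (assuming vae_scale_factor = 8) the mask has been rearranged into
+        # 8*8 = 64 channels at latent resolution and packed, so it can be concatenated with
+        # the packed masked-image latents along the feature dimension in _call_shared.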
+
+ return mask, masked_image_latents
+
+ @torch.no_grad()
+ def __call__(self, **kwargs):
+ """Main inpainting call."""
+ return self._call_shared(is_inpainting=True, **kwargs)
diff --git a/hy3dworld/models/world_composer.py b/hy3dworld/models/world_composer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbbabe5b82156424873ea90c4135e40f63c7bb7b
--- /dev/null
+++ b/hy3dworld/models/world_composer.py
@@ -0,0 +1,575 @@
+import os
+import cv2
+import json
+import numpy as np
+from PIL import Image
+import open3d as o3d
+
+import torch
+from typing import Union, Tuple
+
+from .adaptive_depth_compression import create_adaptive_depth_compressor
+
+from ..utils import (
+ get_no_fg_img,
+ get_fg_mask,
+ get_bg_mask,
+ get_filtered_mask,
+ sheet_warping,
+ depth_match,
+ seed_all,
+ build_depth_model,
+ pred_pano_depth,
+)
+
+
+class WorldComposer:
+ r"""WorldComposer is responsible for composing a layered world from input images and masks.
+ It handles foreground object generation, background layer composition, and depth inpainting.
+ Args:
+ device (torch.device): The device to run the model on (default: "cuda").
+ resolution (Tuple[int, int]): The resolution of the input images (width, height).
+ filter_mask (bool): Whether to filter the foreground masks.
+ kernel_scale (int): The scale factor for kernel size in mask processing (default: 1).
+ adaptive_depth_compression (bool): Whether to enable adaptive depth compression (default: True).
+ seed (int): Random seed for reproducibility.
+ """
+
+ def __init__(
+ self,
+ device: torch.device = "cuda",
+ resolution: Tuple[int, int] = (3840, 1920),
+ seed: int = 42,
+ filter_mask: bool = False,
+ kernel_scale: int = 1,
+ adaptive_depth_compression: bool = True,
+ max_fg_mesh_res: int = 3840,
+ max_bg_mesh_res: int = 3840,
+ max_sky_mesh_res: int = 1920,
+ sky_mask_dilation_kernel: int = 5,
+ bg_depth_compression_quantile: float = 0.92,
+ fg_mask_erode_scale: float = 2.5,
+ fg_filter_beta_scale: float = 3.3,
+ fg_filter_alpha_scale: float = 0.15,
+ sky_depth_margin: float = 1.02,
+ ):
+ r"""Initialize"""
+ self.device = device
+ self.resolution = resolution
+ self.filter_mask = filter_mask
+ self.kernel_scale = kernel_scale
+ self.max_fg_mesh_res = max_fg_mesh_res
+ self.max_bg_mesh_res = max_bg_mesh_res
+ self.max_sky_mesh_res = max_sky_mesh_res
+ self.sky_mask_dilation_kernel = sky_mask_dilation_kernel
+ self.bg_depth_compression_quantile = bg_depth_compression_quantile
+ self.fg_mask_erode_scale = fg_mask_erode_scale
+ self.fg_filter_beta_scale = fg_filter_beta_scale
+ self.fg_filter_alpha_scale = fg_filter_alpha_scale
+ self.sky_depth_margin = sky_depth_margin
+
+        # Adaptive depth compression configuration
+ self.adaptive_depth_compression = adaptive_depth_compression
+ self.depth_model = build_depth_model(device)
+
+ # Initialize world composition variables
+ self._init_list()
+ # init seed
+ seed_all(seed)
+
+ def _init_list(self):
+ self.layered_world_mesh = []
+ self.layered_world_depth = []
+
+ def _process_input(self, separate_pano, fg_bboxes):
+ # get all inputs
+ self.full_img = separate_pano["full_img"]
+ self.no_fg1_img = separate_pano["no_fg1_img"]
+ self.no_fg2_img = separate_pano["no_fg2_img"]
+ self.sky_img = separate_pano["sky_img"]
+ self.fg1_mask = separate_pano["fg1_mask"]
+ self.fg2_mask = separate_pano["fg2_mask"]
+ self.sky_mask = separate_pano["sky_mask"]
+ self.fg1_bbox = fg_bboxes["fg1_bbox"]
+ self.fg2_bbox = fg_bboxes["fg2_bbox"]
+
+ def _process_sky_mask(self):
+ r"""Process the sky mask to prepare it for further operations."""
+ if self.sky_mask is not None:
+ # The sky mask identifies non-sky regions, so it needs to be inverted.
+ self.sky_mask = 1 - np.array(self.sky_mask) / 255.0
+ if len(self.sky_mask.shape) > 2:
+ self.sky_mask = self.sky_mask[:, :, 0]
+ # Expand the sky mask to ensure complete coverage.
+ kernel_size = self.sky_mask_dilation_kernel * self.kernel_scale
+ self.sky_mask = (
+ cv2.dilate(
+ self.sky_mask,
+ np.ones((kernel_size, kernel_size), np.uint8),
+ iterations=1,
+ )
+ if self.sky_mask.sum() > 0
+ else self.sky_mask
+ )
+ else:
+ # Create an empty mask if no sky is present.
+ self.sky_mask = np.zeros((self.H, self.W))
+
+ def _process_fg_mask(self, fg_mask):
+ r"""Process the foreground mask to prepare it for further operations."""
+ if fg_mask is not None:
+ fg_mask = np.array(fg_mask)
+ if len(fg_mask.shape) > 2:
+ fg_mask = fg_mask[:, :, 0]
+ return fg_mask
+
+ def _load_separate_pano_from_dir(self, image_dir, sr):
+ r"""Load separate panorama images and foreground bounding boxes from a directory.
+ Args:
+ image_dir (str): The directory containing the panorama images and bounding boxes.
+ sr (bool): Whether to use super-resolution versions of the images.
+ Returns:
+ images (dict): A dictionary containing the loaded images with keys:
+ - "full_img": Complete panorama image (PIL.Image.Image)
+ - "no_fg1_img": Panorama with layer 1 foreground object removed (PIL.Image.Image)
+ - "no_fg2_img": Panorama with layer 2 foreground object removed (PIL.Image.Image)
+ - "sky_img": Sky region image (PIL.Image.Image)
+ - "fg1_mask": Binary mask for layer 1 foreground object (PIL.Image.Image)
+ - "fg2_mask": Binary mask for layer 2 foreground object (PIL.Image.Image)
+ - "sky_mask": Binary mask for sky region (PIL.Image.Image)
+ fg_bboxes (dict): A dictionary containing bounding boxes for foreground objects with keys:
+ - "fg1_bbox": List of dicts with keys 'label', 'bbox', 'score' for layer 1 object
+ - "fg2_bbox": List of dicts with keys 'label', 'bbox', 'score' for layer 2 object
+ Raises:
+ FileNotFoundError: If the specified image directory does not exist.
+ """
+ # Define base image files
+ image_files = {
+ "full_img": "full_image.png",
+ "no_fg1_img": "remove_fg1_image.png",
+ "no_fg2_img": "remove_fg2_image.png",
+ "sky_img": "sky_image.png",
+ "fg1_mask": "fg1_mask.png",
+ "fg2_mask": "fg2_mask.png",
+ "sky_mask": "sky_mask.png",
+ }
+ # Use super-resolution versions if sr flag is set
+ if sr:
+ print("***Using super-resolution input image***")
+ for key in ["full_img", "no_fg1_img", "no_fg2_img", "sky_img"]:
+ image_files[key] = image_files[key].replace(".png", "_sr.png")
+
+ # Check if the directory exists
+ if not os.path.exists(image_dir):
+ raise FileNotFoundError(f"The image directory does not exist: {image_dir}")
+
+ # Load and adjust all images
+ images = {}
+ fg1_bbox_scale = 1
+ fg2_bbox_scale = 1
+ for name, filename in image_files.items():
+ filepath = os.path.join(image_dir, filename)
+ if not os.path.exists(filepath):
+ images[name] = None
+ else:
+ img = Image.open(filepath)
+ if img.size != self.resolution:
+                    print(
+                        f"Resizing image {name} from {img.size} to {self.resolution}"
+                    )
+ # Select different resampling methods based on image type
+ resample = Image.NEAREST if "mask" in name else Image.BICUBIC
+ if "fg1_mask" in name and img.size != self.resolution:
+ fg1_bbox_scale = self.resolution[0] / img.size[0]
+ if "fg2_mask" in name and img.size != self.resolution:
+ fg2_bbox_scale = self.resolution[0] / img.size[0]
+ img = img.resize(self.resolution, resample=resample)
+ images[name] = img
+
+ # Check resolution
+ if self.resolution is not None:
+ for name, img in images.items():
+ if img is not None:
+ assert (
+ img.size == self.resolution
+ ), f"{name} resolution does not match"
+
+ # Load foreground object bbox
+ fg_bboxes = {}
+ fg_bbox_files = {
+ "fg1_bbox": "fg1.json",
+ "fg2_bbox": "fg2.json",
+ }
+ for name, filename in fg_bbox_files.items():
+ filepath = os.path.join(image_dir, filename)
+ if not os.path.exists(filepath):
+ fg_bboxes[name] = None
+ else:
+ fg_bboxes[name] = json.load(open(filepath))
+ if "fg1" in name:
+ for i in range(len(fg_bboxes[name]["bboxes"])):
+ fg_bboxes[name]["bboxes"][i]["bbox"] = [
+ x * fg1_bbox_scale
+ for x in fg_bboxes[name]["bboxes"][i]["bbox"]
+ ]
+ if "fg2" in name:
+ for i in range(len(fg_bboxes[name]["bboxes"])):
+ fg_bboxes[name]["bboxes"][i]["bbox"] = [
+ x * fg2_bbox_scale
+ for x in fg_bboxes[name]["bboxes"][i]["bbox"]
+ ]
+
+ return images, fg_bboxes
+
+ def generate_world(self, **kwargs):
+ r"""Generate a 3D world composition from panorama and foreground objects
+
+        Args:
+            **kwargs: Keyword arguments containing:
+                separate_pano (dict):
+                    Layered panorama images and masks (see `_compose_layered_world`)
+                fg_bboxes (dict):
+                    Bounding boxes for the layer 1 and layer 2 foreground objects
+                world_type (list):
+                    World generation modes, e.g. ["mesh"] to export meshes
+
+        Returns:
+            list: The layered world meshes, one dict per layer with keys
+                "type" (e.g. "fg1", "fg2", "bg", "sky") and "mesh"
+                (o3d.geometry.TriangleMesh).
+        """
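+        # Usage sketch (hypothetical paths; inputs typically come from _load_separate_pano_from_dir):
+        #   composer = WorldComposer(resolution=(3840, 1920))
+        #   pano, bboxes = composer._load_separate_pano_from_dir("outputs/scene", sr=False)
+        #   meshes = composer.generate_world(separate_pano=pano, fg_bboxes=bboxes, world_type=["mesh"])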
+ # temporary input setting
+ separate_pano = kwargs["separate_pano"]
+ fg_bboxes = kwargs["fg_bboxes"]
+ world_type = kwargs["world_type"]
+
+ layered_world_mesh = self._compose_layered_world(
+ separate_pano, fg_bboxes, world_type=world_type
+ )
+ return layered_world_mesh
+
+ def _compose_background_layer(self):
+ r"""Compose the background layer of the world."""
+ # The background layer is composed of the full image without foreground objects.
+ if self.BG_MASK.sum() == 0:
+ return
+
+ print(f"🏞️ Composing the background layer...")
+ if self.fg_status == "no_fg":
+ self.no_fg_img_depth = self.full_img_depth
+ else:
+ # For cascade inpainting, use the last layer's depth as known depth.
+ if self.fg_status == "both_fg1_fg2":
+ inpaint_mask = self.fg2_mask.astype(np.bool_).astype(np.uint8)
+ else:
+ inpaint_mask = self.FG_MASK
+
+ # Align the depth of the background layer to the depth of the panoramic image
+ self.no_fg_img_depth = pred_pano_depth(
+ self.depth_model,
+ self.no_fg_img,
+ img_name="background",
+ last_layer_mask=inpaint_mask,
+ last_layer_depth=self.layered_world_depth[-1],
+ )
+
+ self.no_fg_img_depth = depth_match(
+ self.full_img_depth, self.no_fg_img_depth, self.BG_MASK
+ )
+
+ # Apply adaptive depth compression considering foreground layers and scene characteristics
+ distance = self.no_fg_img_depth["distance"]
+ if (
+ hasattr(self, "adaptive_depth_compression")
+ and self.adaptive_depth_compression
+ ):
+ # Automatically determine scene type based on sky_img
+ scene_type = "indoor" if self.sky_img is None else "outdoor"
+ depth_compressor = create_adaptive_depth_compressor(scene_type=scene_type)
+ self.no_fg_img_depth["distance"] = (
+ depth_compressor.compress_background_depth(
+ distance, self.layered_world_depth, bg_mask=1 - self.sky_mask
+ )
+ )
+ else:
+ # Use a simple quantile-based depth compression method.
+ q_val = torch.quantile(distance, self.bg_depth_compression_quantile)
+ self.no_fg_img_depth["distance"] = torch.clamp(distance, max=q_val)
+
+ layer_depth_i = self.no_fg_img_depth.copy()
+ layer_depth_i["name"] = "background"
+ layer_depth_i["mask"] = 1 - self.sky_mask
+ layer_depth_i["type"] = "bg"
+ self.layered_world_depth.append(layer_depth_i)
+
+ if "mesh" in self.world_type:
+ no_fg_img_mesh = sheet_warping(
+ self.no_fg_img_depth,
+ excluded_region_mask=torch.from_numpy(self.sky_mask).bool(),
+ max_size=self.max_bg_mesh_res,
+ )
+ self.layered_world_mesh.append({"type": "bg", "mesh": no_fg_img_mesh})
+
+ def _compose_foreground_layer(self):
+ if self.fg_status == "no_fg":
+ return
+
+ print(f"🧩 Composing the foreground layers...")
+
+ # Obtain the list of foreground layers
+ fg_layer_list = []
+ if self.fg_status == "both_fg1_fg2":
+ fg_layer_list.append(
+ [self.full_img, self.fg1_mask, self.fg1_bbox, "fg1"]
+ ) # fg1 mesh
+ fg_layer_list.append(
+ [self.no_fg1_img, self.fg2_mask, self.fg2_bbox, "fg2"]
+ ) # fg2 mesh
+ elif self.fg_status == "only_fg1":
+ fg_layer_list.append(
+ [self.full_img, self.fg1_mask, self.fg1_bbox, "fg1"]
+ ) # fg1 mesh
+ elif self.fg_status == "only_fg2":
+ fg_layer_list.append(
+ [self.no_fg1_img, self.fg2_mask, self.fg2_bbox, "fg2"]
+ ) # fg2 mesh
+
+ # Determine whether to generate foreground objects or directly project foreground layers
+ project_object_layer = ["fg1", "fg2"]
+
+ for fg_i_img, fg_i_mask, fg_i_bbox, fg_i_type in fg_layer_list:
+ print(f"\t - Composing the foreground layer: {fg_i_type}")
+ # 1. Estimate the depth of the foreground layer
+ # If there are fg1 and fg2, then fg1_img is the panoramic image itself, without the need to estimate depth
+ if len(fg_layer_list) > 1:
+ if fg_i_type == "fg1":
+ fg_i_img_depth = self.full_img_depth
+ elif fg_i_type == "fg2":
+ fg_i_img_depth = pred_pano_depth(
+ self.depth_model,
+ fg_i_img,
+ img_name=f"{fg_i_type}",
+ last_layer_mask=self.fg1_mask.astype(np.bool_).astype(np.uint8),
+ last_layer_depth=self.full_img_depth,
+ )
+ # fg2 only needs to align the depth of the fg2 object area
+ fg2_exclude_fg1_mask = np.logical_and(
+ fg_i_mask.astype(np.bool_), 1 - self.fg1_mask.astype(np.bool_)
+ )
+
+ # Align the depth of the foreground layer to the depth of the panoramic image
+ fg_i_img_depth = depth_match(
+ self.full_img_depth, fg_i_img_depth, fg2_exclude_fg1_mask
+ )
+ else:
+ raise ValueError(f"Invalid foreground object type: {fg_i_type}")
+ else:
+ # If only fg1 or fg2 exists, its image is the panoramic image, so depth estimation is not required.
+ fg_i_img_depth = self.full_img_depth
+
+ # Compress outliers in the foreground depth.
+ if (
+ hasattr(self, "adaptive_depth_compression")
+ and self.adaptive_depth_compression
+ ):
+ depth_compressor = create_adaptive_depth_compressor()
+ fg_i_img_depth["distance"] = depth_compressor.compress_foreground_depth(
+ fg_i_img_depth["distance"], fg_i_mask
+ )
+
+ in_fg_i_mask = fg_i_mask.copy()
+ if fg_i_mask.sum() > 0:
+ # 2. Perform sheet warping.
+ if fg_i_type in project_object_layer:
+ in_fg_i_mask = self._project_fg_depth(
+ fg_i_img_depth, fg_i_mask, fg_i_type
+ )
+ else:
+ raise ValueError(f"Invalid foreground object type: {fg_i_type}")
+ else:
+ # If no objects are in the foreground layer, it won't be added to the layered world depth.
+ pass
+
+ # save layered depth
+ layer_depth_i = fg_i_img_depth.copy()
+ layer_depth_i["name"] = fg_i_type
+ # Using edge filtered masks to ensure the accuracy of foreground depth during depth compression
+ layer_depth_i["mask"] = (
+ in_fg_i_mask if in_fg_i_mask is not None else np.zeros_like(fg_i_mask)
+ )
+ layer_depth_i["type"] = fg_i_type
+ self.layered_world_depth.append(layer_depth_i)
+
+ def _project_fg_depth(self, fg_i_img_depth, fg_i_mask, fg_i_type):
+ r"""Project the foreground depth to create a mesh or Gaussian splatting object."""
+ in_fg_i_mask = fg_i_mask.astype(np.bool_).astype(
+ np.uint8
+ )
+ # Erode the mask to remove edge artifacts from foreground objects.
+ erode_size = int(self.fg_mask_erode_scale * self.kernel_scale)
+ eroded_in_fg_i_mask = cv2.erode(
+ in_fg_i_mask, np.ones((erode_size, erode_size), np.uint8), iterations=1
+ ) # The result is a uint8 array with values of 0 or 1.
+
+ # Filter edges
+ if self.filter_mask:
+ filtered_fg_i_img_mask = (
+ 1
+ - get_filtered_mask(
+ 1.0 / fg_i_img_depth["distance"][None, :, :, None],
+ beta=self.fg_filter_beta_scale * self.kernel_scale,
+ alpha_threshold=self.fg_filter_alpha_scale * self.kernel_scale,
+ device=self.device,
+ )
+ .squeeze()
+ .cpu()
+ )
+ # Convert to binary mask
+ filtered_fg_i_img_mask = 1 - filtered_fg_i_img_mask.numpy()
+
+ # Combine eroded mask with filtered mask
+ eroded_in_fg_i_mask = np.logical_and(
+ eroded_in_fg_i_mask, filtered_fg_i_img_mask
+ )
+
+ # Process the eroded mask to create the final binary mask
+ in_fg_i_mask = eroded_in_fg_i_mask > 0.5
+ out_fg_i_mask = 1 - in_fg_i_mask
+
+ # Convert the depth image to a mesh or Gaussian splatting object
+ if "mesh" in self.world_type:
+ fg_i_mesh = sheet_warping(
+ fg_i_img_depth,
+ excluded_region_mask=torch.from_numpy(out_fg_i_mask).bool(),
+ max_size=self.max_fg_mesh_res,
+ )
+ self.layered_world_mesh.append({"type": fg_i_type, "mesh": fg_i_mesh})
+
+ return in_fg_i_mask
+
+ def _compose_sky_layer(self):
+ r"""Compose the sky layer of the world."""
+ if self.sky_img is not None:
+ print(f"🕐 Composing the sky layer...")
+ self.sky_img = torch.tensor(
+ np.array(self.sky_img), device=self.full_img_depth["rgb"].device
+ )
+
+ # Calculate the maximum depth value of all foreground and background layers
+ max_scene_depth = torch.tensor(
+ 0.0, device=self.full_img_depth["rgb"].device
+ )
+ for layer in self.layered_world_depth:
+ layer_depth = layer["distance"]
+ layer_mask = layer.get("mask", None)
+
+ if layer_mask is not None:
+ if not isinstance(layer_mask, torch.Tensor):
+ layer_mask = torch.from_numpy(layer_mask).to(layer_depth.device)
+ mask_bool = layer_mask.bool()
+ if (
+ mask_bool.sum() > 0
+ ): # Only search for the maximum value within the mask area
+ layer_max = layer_depth[mask_bool].max()
+ max_scene_depth = torch.max(max_scene_depth, layer_max)
+ else:
+ # If there is no mask, consider the entire depth map
+ max_scene_depth = torch.max(max_scene_depth, layer_depth.max())
+
+ # Set the sky depth to be slightly greater than the maximum scene depth.
+ sky_distance = self.sky_depth_margin * max_scene_depth if max_scene_depth > 0 else 3.0
+
+ sky_pred = {
+ "rgb": self.sky_img,
+ "rays": self.full_img_depth["rays"],
+ "distance": sky_distance
+ * torch.ones_like(self.full_img_depth["distance"]),
+ }
+
+ if "mesh" in self.world_type:
+                # The sky layer skips boundary bridging (connect_boundary_max_dist=None),
+                # so slightly jagged edges are acceptable here.
+ sky_mesh = sheet_warping(
+ sky_pred,
+ connect_boundary_max_dist=None,
+ max_size=self.max_sky_mesh_res,
+ )
+ self.layered_world_mesh.append({"type": "sky", "mesh": sky_mesh})
+
+ def _compose_layered_world(
+ self,
+ separate_pano: dict,
+ fg_bboxes: dict,
+ world_type: list = ["mesh"],
+    ) -> list:
+ r"""
+ Compose each layer into a complete world
+ Args:
+ separate_pano: dict containing the following images:
+ full_img: Complete panorama image (PIL.Image.Image)
+ no_fg1_img: Panorama with layer 1 foreground object removed (PIL.Image.Image)
+ no_fg2_img: Panorama with layer 2 foreground object removed (PIL.Image.Image)
+ sky_img: Sky region image (PIL.Image.Image)
+ fg1_mask: Binary mask for layer 1 foreground object (PIL.Image.Image)
+ fg2_mask: Binary mask for layer 2 foreground object (PIL.Image.Image)
+ sky_mask: Binary mask for sky region (PIL.Image.Image)
+
+ fg_bboxes: dict containing bounding boxes for foreground objects:
+ fg1_bbox: List of dicts with keys 'label', 'bbox', 'score' for layer 1 object
+ fg2_bbox: List of dicts with keys 'label', 'bbox', 'score' for layer 2 object
+
+ world_type: list, ["mesh"]
+
+        Returns:
+            layered_world_mesh: list of dicts, one per layer, each with keys
+                "type" ("fg1", "fg2", "bg" or "sky") and "mesh"
+                (o3d.geometry.TriangleMesh)
+ """
+ self.world_type = world_type
+ self._process_input(separate_pano, fg_bboxes)
+ self.W, self.H = self.full_img.size
+
+ self._init_list()
+
+ # Processing sky and foreground masks
+ self._process_sky_mask()
+ self.fg1_mask = self._process_fg_mask(self.fg1_mask)
+ self.fg2_mask = self._process_fg_mask(self.fg2_mask)
+
+ # Overall foreground mask: Merge multiple foreground masks, background mask: Excluding sky
+ self.FG_MASK = get_fg_mask(self.fg1_mask, self.fg2_mask)
+ self.BG_MASK = get_bg_mask(self.sky_mask, self.FG_MASK, self.kernel_scale)
+
+        # Obtain the background + sky layer (no_fg_img)
+ self.no_fg_img, self.fg_status = get_no_fg_img(
+ self.no_fg1_img, self.no_fg2_img, self.full_img
+ )
+
+ # Predicting the Depth of Panoramic Images
+ self.full_img_depth = pred_pano_depth( # fg1 depth
+ self.depth_model,
+ self.full_img,
+ img_name="full_img",
+ )
+
+ # Layered construction of the world
+ print(f"🎨 Start to compose the world layer by layer...")
+ # 1. The foreground layers
+ self._compose_foreground_layer()
+
+ # 2. The background layers
+ self._compose_background_layer()
+
+ # 3. The sky layers
+ self._compose_sky_layer()
+
+ print("🎉 Congratulations! World composition completed successfully!")
+
+ return self.layered_world_mesh
diff --git a/hy3dworld/utils/__init__.py b/hy3dworld/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a883ccb98491f54bc7146c1a37b42023a82adff
--- /dev/null
+++ b/hy3dworld/utils/__init__.py
@@ -0,0 +1,37 @@
+# Tencent HunyuanWorld-1.0 is licensed under TENCENT HUNYUANWORLD-1.0 COMMUNITY LICENSE AGREEMENT
+# THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND
+# IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
+# By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying
+# any portion or element of the Tencent HunyuanWorld-1.0 Works, including via any Hosted Service,
+# You will be deemed to have recognized and accepted the content of this Agreement,
+# which is effective immediately.
+
+# For avoidance of doubts, Tencent HunyuanWorld-1.0 means the 3D generation models
+# and their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent at [https://github.com/Tencent-Hunyuan/HunyuanWorld-1.0].
+
+from .export_utils import process_file
+from .perspective_utils import Perspective
+from .general_utils import (
+ pano_sheet_warping,
+ depth_match,
+ get_no_fg_img,
+ get_fg_mask,
+ get_bg_mask,
+ spherical_uv_to_directions,
+ get_filtered_mask,
+ sheet_warping,
+ seed_all,
+ colorize_depth_maps,
+)
+from .pano_depth_utils import coords_grid, build_depth_model, pred_pano_depth
+
+__all__ = [
+ "process_file", "pano_sheet_warping", "depth_match",
+ "get_no_fg_img", "get_fg_mask", "get_bg_mask",
+ "spherical_uv_to_directions", "get_filtered_mask",
+ "sheet_warping", "seed_all", "colorize_depth_maps", "Perspective",
+ "coords_grid", "build_depth_model", "pred_pano_depth"
+]
diff --git a/hy3dworld/utils/export_utils.py b/hy3dworld/utils/export_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bce2945771da3d81f66148a33df5781386634308
--- /dev/null
+++ b/hy3dworld/utils/export_utils.py
@@ -0,0 +1,17 @@
+import trimesh
+
+
+def process_file(input_path, output_path):
+ r"""Convert a PLY file to Draco format.
+ Args:
+ input_path (str): Path to the input PLY file.
+ output_path (str): Path to save the output Draco file.
+ """
+ mesh = trimesh.load(input_path)
+ try:
+ # Attempt Draco-compressed PLY export
+ export_data = trimesh.exchange.ply.export_draco(mesh)
+ with open(output_path, 'wb') as f:
+ f.write(export_data)
+ except Exception as e:
+        print(f"Draco export failed: {str(e)}. Please confirm that Draco support is installed.")
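+
+
+# Usage sketch (hypothetical paths): process_file("scene_bg.ply", "scene_bg.drc")
+# Draco-compressed export relies on trimesh's optional Draco support (typically the DracoPy package).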
diff --git a/hy3dworld/utils/general_utils.py b/hy3dworld/utils/general_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6219f665e3383163a3d2ced5411c56fbdfcd5794
--- /dev/null
+++ b/hy3dworld/utils/general_utils.py
@@ -0,0 +1,697 @@
+import cv2
+import numpy as np
+from typing import Optional, Literal
+
+import random
+import matplotlib
+import open3d as o3d
+
+import torch
+import torch.nn.functional as F
+from collections import defaultdict
+
+
+def spherical_uv_to_directions(uv: np.ndarray):
+ r"""
+ Convert spherical UV coordinates to 3D directions.
+ Args:
+ uv (np.ndarray): UV coordinates in the range [0, 1]. Shape: (H, W, 2).
+ Returns:
+ directions (np.ndarray): 3D directions corresponding to the UV coordinates. Shape: (H, W, 3).
+ """
+ theta, phi = (1 - uv[..., 0]) * (2 * np.pi), uv[..., 1] * np.pi
+ directions = np.stack([np.sin(phi) * np.cos(theta),
+ np.sin(phi) * np.sin(theta), np.cos(phi)], axis=-1)
+ return directions
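+# Worked example: uv = (0.5, 0.5) gives theta = pi and phi = pi/2, so the direction is
+# (sin(pi/2)*cos(pi), sin(pi/2)*sin(pi), cos(pi/2)) = (-1, 0, 0); the centre of the
+# equirectangular image maps to the -X axis under this convention.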
+
+
+def depth_match(init_pred: dict, bg_pred: dict, mask: np.ndarray, quantile: float = 0.3) -> dict:
+ r"""
+ Match the background depth map to the scale of the initial depth map.
+ Args:
+ init_pred (dict): Initial depth prediction containing "distance" key.
+ bg_pred (dict): Background depth prediction containing "distance" key.
+ mask (np.ndarray): Binary mask indicating valid pixels in the background depth map.
+ quantile (float): Quantile to use for selecting the depth range for scale matching.
+ Returns:
+ bg_pred (dict): Background depth prediction with adjusted "distance" key.
+ """
+ valid_mask = mask > 0
+ init_distance = init_pred["distance"][valid_mask]
+ bg_distance = bg_pred["distance"][valid_mask]
+
+ init_mask = init_distance < torch.quantile(init_distance, quantile)
+ bg_mask = bg_distance < torch.quantile(bg_distance, quantile)
+ scale = init_distance[init_mask].median() / bg_distance[bg_mask].median()
+ bg_pred["distance"] *= scale
+ return bg_pred
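+# Example: if the nearest 30% (quantile=0.3) of the pixels under `mask` have median distance
+# 2.0 in init_pred but 4.0 in bg_pred, then scale = 0.5 and every bg_pred distance is halved,
+# aligning the newly predicted layer with the already-reconstructed geometry.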
+
+
+
+def _fill_small_boundary_spikes(
+ mesh: o3d.geometry.TriangleMesh,
+ max_bridge_dist: float,
+ repeat_times: int = 3,
+ max_connection_step: int = 8,
+) -> o3d.geometry.TriangleMesh:
+ r"""
+ Fill small boundary spikes in a mesh by creating triangles between boundary vertices.
+ Args:
+ mesh (o3d.geometry.TriangleMesh): The input mesh to process.
+ max_bridge_dist (float): Maximum distance allowed for bridging boundary vertices.
+ repeat_times (int): Number of iterations to repeat the filling process.
+ max_connection_step (int): Maximum number of steps to connect boundary vertices.
+ Returns:
+ o3d.geometry.TriangleMesh: The mesh with small boundary spikes filled.
+ """
+ for iteration in range(repeat_times):
+ if not mesh.has_triangles() or not mesh.has_vertices():
+ return mesh
+
+ vertices = np.asarray(mesh.vertices)
+ triangles = np.asarray(mesh.triangles)
+
+ # 1. Identify boundary edges
+ edge_to_triangle_count = defaultdict(int)
+
+ for tri_idx, tri in enumerate(triangles):
+ for i in range(3):
+ v1_idx, v2_idx = tri[i], tri[(i + 1) % 3]
+ edge = tuple(sorted((v1_idx, v2_idx)))
+ edge_to_triangle_count[edge] += 1
+
+ boundary_edges = [edge for edge,
+ count in edge_to_triangle_count.items() if count == 1]
+
+ if not boundary_edges:
+ return mesh
+
+ # 2. Create an adjacency list for boundary vertices using only boundary edges
+ boundary_adj = defaultdict(list)
+ for v1_idx, v2_idx in boundary_edges:
+ boundary_adj[v1_idx].append(v2_idx)
+ boundary_adj[v2_idx].append(v1_idx)
+
+ # 3. Process boundary vertices with new smooth filling algorithm
+ new_triangles_list = []
+ edge_added = defaultdict(bool)
+
+ # print(f"DEBUG: Found {len(boundary_edges)} boundary edges.")
+ # print(f"DEBUG: Max bridge distance set to: {max_bridge_dist}")
+
+ new_triangles_added_count = 0
+
+ for v_curr_idx, neighbors in boundary_adj.items():
+ if len(neighbors) != 2: # Only process vertices with exactly 2 boundary neighbors
+ continue
+
+ v_a_idx, v_b_idx = neighbors[0], neighbors[1]
+
+ # Skip if these vertices already form a triangle
+ potential_edge = tuple(sorted((v_a_idx, v_b_idx)))
+ if edge_to_triangle_count[potential_edge] > 0 or edge_added[potential_edge]:
+ continue
+
+ # Calculate distances
+ v_curr_coord = vertices[v_curr_idx]
+ v_a_coord = vertices[v_a_idx]
+ v_b_coord = vertices[v_b_idx]
+
+ dist_a_b = np.linalg.norm(v_a_coord - v_b_coord)
+
+ # Skip if distance exceeds threshold
+ if dist_a_b > max_bridge_dist:
+ continue
+
+ # Create simple triangle (v_a, v_b, v_curr)
+ new_triangles_list.append([v_a_idx, v_b_idx, v_curr_idx])
+ new_triangles_added_count += 1
+ edge_added[potential_edge] = True
+
+ # Mark edges as processed
+ edge_added[tuple(sorted((v_curr_idx, v_a_idx)))] = True
+ edge_added[tuple(sorted((v_curr_idx, v_b_idx)))] = True
+
+ # 4. Now process multi-step connections for better smoothing
+ # First build boundary chains for multi-step connections
+ boundary_loops = []
+ visited_vertices = set()
+
+ # Find boundary vertices with exactly 2 neighbors (part of continuous chains)
+ chain_starts = [v for v in boundary_adj if len(
+ boundary_adj[v]) == 2 and v not in visited_vertices]
+
+ for start_vertex in chain_starts:
+ if start_vertex in visited_vertices:
+ continue
+
+ chain = []
+ curr_vertex = start_vertex
+
+ # Follow the chain in one direction
+ while curr_vertex not in visited_vertices:
+ visited_vertices.add(curr_vertex)
+ chain.append(curr_vertex)
+
+ next_candidates = [
+ n for n in boundary_adj[curr_vertex] if n not in visited_vertices]
+ if not next_candidates:
+ break
+
+ curr_vertex = next_candidates[0]
+
+ if len(chain) >= 3:
+ boundary_loops.append(chain)
+
+ # Process each boundary chain for multi-step smoothing
+ for chain in boundary_loops:
+ chain_length = len(chain)
+
+ # Skip very small chains
+ if chain_length < 3:
+ continue
+
+ # Compute multi-step connections
+ max_step = min(max_connection_step, chain_length - 1)
+
+ for i in range(chain_length):
+ anchor_idx = chain[i]
+ anchor_coord = vertices[anchor_idx]
+
+ for step in range(3, max_step + 1):
+ if i + step >= chain_length:
+ break
+
+ far_idx = chain[i + step]
+ far_coord = vertices[far_idx]
+
+ # Check distance criteria
+ dist_anchor_far = np.linalg.norm(anchor_coord - far_coord)
+ if dist_anchor_far > max_bridge_dist * step:
+ continue
+
+ # Check if anchor and far are already connected
+ edge_anchor_far = tuple(sorted((anchor_idx, far_idx)))
+ if edge_to_triangle_count[edge_anchor_far] > 0 or edge_added[edge_anchor_far]:
+ continue
+
+ # Create fan triangles
+ fan_valid = True
+ fan_triangles = []
+
+ prev_mid_idx = anchor_idx
+
+ for j in range(1, step):
+ mid_idx = chain[i + j]
+
+ if prev_mid_idx != anchor_idx:
+ tri_edge1 = tuple(sorted((anchor_idx, mid_idx)))
+ tri_edge2 = tuple(sorted((prev_mid_idx, mid_idx)))
+
+ # Check if edges already exist (not created by our fan)
+ if (edge_to_triangle_count[tri_edge1] > 0 and not edge_added[tri_edge1]) or \
+ (edge_to_triangle_count[tri_edge2] > 0 and not edge_added[tri_edge2]):
+ fan_valid = False
+ break
+
+ fan_triangles.append(
+ [anchor_idx, prev_mid_idx, mid_idx])
+
+ prev_mid_idx = mid_idx
+
+ # Add final triangle to connect to far_idx
+ if fan_valid:
+ fan_triangles.append(
+ [anchor_idx, prev_mid_idx, far_idx])
+
+ # Add all fan triangles if valid
+ if fan_valid and fan_triangles:
+ for triangle in fan_triangles:
+ v_a, v_b, v_c = triangle
+ edge_ab = tuple(sorted((v_a, v_b)))
+ edge_bc = tuple(sorted((v_b, v_c)))
+ edge_ac = tuple(sorted((v_a, v_c)))
+
+ new_triangles_list.append(triangle)
+ new_triangles_added_count += 1
+
+ edge_added[edge_ab] = True
+ edge_added[edge_bc] = True
+ edge_added[edge_ac] = True
+
+ # Once we've added a fan, move to the next anchor
+ break
+
+ if new_triangles_added_count == 0:
+ break
+
+ # Update the mesh with new triangles
+ if new_triangles_list:
+ all_triangles_np = np.vstack(
+ (triangles, np.array(new_triangles_list, dtype=np.int32)))
+
+ final_mesh = o3d.geometry.TriangleMesh()
+ final_mesh.vertices = o3d.utility.Vector3dVector(vertices)
+ final_mesh.triangles = o3d.utility.Vector3iVector(all_triangles_np)
+
+ if mesh.has_vertex_colors():
+ final_mesh.vertex_colors = mesh.vertex_colors
+
+ # Clean up the mesh
+ final_mesh.remove_degenerate_triangles()
+ final_mesh.remove_unreferenced_vertices()
+ mesh = final_mesh
+
+ return mesh
+
+
+def pano_sheet_warping(
+ rgb: torch.Tensor, # (H, W, 3) RGB image, values [0, 1]
+ distance: torch.Tensor, # (H, W) Distance map
+ rays: torch.Tensor, # (H, W, 3) Ray directions (unit vectors ideally)
+ # (H, W) Optional boolean mask
+ excluded_region_mask: Optional[torch.Tensor] = None,
+ max_size: int = 4096, # Max dimension for resizing
+ device: Literal["cuda", "cpu"] = "cuda", # Computation device
+ # Max distance to bridge boundary vertices
+ connect_boundary_max_dist: Optional[float] = 0.5,
+ connect_boundary_repeat_times: int = 2
+) -> o3d.geometry.TriangleMesh:
+ r"""
+ Converts panoramic RGBD data (image, distance, rays) into an Open3D mesh.
+ Args:
+        rgb: Input RGB image tensor (H, W, 3), float values in [0, 1].
+        distance: Input distance map tensor (H, W).
+        rays: Input ray directions tensor (H, W, 3). Assumed to originate from (0, 0, 0).
+        excluded_region_mask: Optional boolean mask tensor (H, W). True values mark regions to exclude.
+        max_size: Maximum size (height or width) to resize inputs to.
+        device: The torch device ('cuda' or 'cpu') to use for computations.
+        connect_boundary_max_dist: Maximum distance used to bridge small boundary gaps (None disables bridging).
+        connect_boundary_repeat_times: Number of boundary-filling iterations.
+
+ Returns:
+ An Open3D TriangleMesh object.
+ """
+ assert rgb.ndim == 3 and rgb.shape[2] == 3, "Image must be HxWx3"
+ assert distance.ndim == 2, "Distance must be HxW"
+ assert rays.ndim == 3 and rays.shape[2] == 3, "Rays must be HxWx3"
+ assert (
+ rgb.shape[:2] == distance.shape[:2] == rays.shape[:2]
+ ), "Input shapes must match"
+
+ mask = excluded_region_mask
+
+ if mask is not None:
+ assert (
+ mask.ndim == 2 and mask.shape[:2] == rgb.shape[:2]
+ ), "Mask shape must match"
+ assert mask.dtype == torch.bool, "Mask must be a boolean tensor"
+
+ rgb = rgb.to(device)
+ distance = distance.to(device)
+ rays = rays.to(device)
+ if mask is not None:
+ mask = mask.to(device)
+
+ H, W = distance.shape
+ if max(H, W) > max_size:
+ scale = max_size / max(H, W)
+ else:
+ scale = 1.0
+
+ # --- Resize Inputs ---
+ rgb_nchw = rgb.permute(2, 0, 1).unsqueeze(0)
+ distance_nchw = distance.unsqueeze(0).unsqueeze(0)
+ rays_nchw = rays.permute(2, 0, 1).unsqueeze(0)
+
+ rgb_resized = (
+ F.interpolate(
+ rgb_nchw,
+ scale_factor=scale,
+ mode="bilinear",
+ align_corners=False,
+ recompute_scale_factor=False,
+ )
+ .squeeze(0)
+ .permute(1, 2, 0)
+ )
+
+ distance_resized = (
+ F.interpolate(
+ distance_nchw,
+ scale_factor=scale,
+ mode="bilinear",
+ align_corners=False,
+ recompute_scale_factor=False,
+ )
+ .squeeze(0)
+ .squeeze(0)
+ )
+
+ rays_resized_nchw = F.interpolate(
+ rays_nchw,
+ scale_factor=scale,
+ mode="bilinear",
+ align_corners=False,
+ recompute_scale_factor=False,
+ )
+
+ # IMPORTANT: Renormalize ray directions after interpolation
+ rays_resized = rays_resized_nchw.squeeze(0).permute(1, 2, 0)
+ rays_norm = torch.linalg.norm(rays_resized, dim=-1, keepdim=True)
+ rays_resized = rays_resized / (rays_norm + 1e-8)
+
+ if mask is not None:
+ mask_resized = (
+ F.interpolate(
+ # Needs float for interpolation
+ mask.unsqueeze(0).unsqueeze(0).float(),
+ scale_factor=scale,
+                mode="nearest",  # nearest preserves sharp mask boundaries
+ recompute_scale_factor=False,
+ )
+ .squeeze(0)
+ .squeeze(0)
+ )
+ mask_resized = mask_resized > 0.5 # Convert back to boolean
+ else:
+ mask_resized = None
+
+ H_new, W_new = distance_resized.shape # Get new dimensions
+
+ # --- Calculate 3D Vertices ---
+ # Vertex position = origin + distance * ray_direction
+ # Assuming origin is (0, 0, 0)
+ distance_flat = distance_resized.reshape(-1, 1) # (H*W, 1)
+ rays_flat = rays_resized.reshape(-1, 3) # (H*W, 3)
+ vertices = distance_flat * rays_flat # (H*W, 3)
+ vertex_colors = rgb_resized.reshape(-1, 3) # (H*W, 3)
+
+ # --- Generate Mesh Faces (Triangles from Quads) ---
+ # Vectorized approach for generating faces, including seam connection
+ # Rows for the top of quads
+ row_indices = torch.arange(0, H_new - 1, device=device)
+ # Columns for the left of quads (includes last col for wrapping)
+ col_indices = torch.arange(0, W_new, device=device)
+
+ # Create 2D grids of row and column coordinates for quad corners
+ # These represent the (row, col) of the top-left vertex of each quad
+ # Shape: (H_new-1, W_new)
+ quad_row_coords = row_indices.view(-1, 1).expand(-1, W_new)
+ quad_col_coords = col_indices.view(
+ 1, -1).expand(H_new-1, -1) # Shape: (H_new-1, W_new)
+
+ # Top-left vertex indices
+ tl_row, tl_col = quad_row_coords, quad_col_coords
+ # Top-right vertex indices (with wrap-around)
+ tr_row, tr_col = quad_row_coords, (quad_col_coords + 1) % W_new
+ # Bottom-left vertex indices
+ bl_row, bl_col = (quad_row_coords + 1), quad_col_coords
+ # Bottom-right vertex indices (with wrap-around)
+ br_row, br_col = (quad_row_coords + 1), (quad_col_coords + 1) % W_new
+
+ # Convert 2D (row, col) coordinates to 1D vertex indices
+ tl = tl_row * W_new + tl_col
+ tr = tr_row * W_new + tr_col
+ bl = bl_row * W_new + bl_col
+ br = br_row * W_new + br_col
+
+ # Apply mask if provided
+ if mask_resized is not None:
+ # Get mask values for each corner of the quads
+ mask_tl_vals = mask_resized[tl_row, tl_col]
+ mask_tr_vals = mask_resized[tr_row, tr_col]
+ mask_bl_vals = mask_resized[bl_row, bl_col]
+ mask_br_vals = mask_resized[br_row, br_col]
+
+ # A quad is kept if none of its vertices are masked
+ # Shape: (H_new-1, W_new)
+ quad_keep_mask = ~(mask_tl_vals | mask_tr_vals |
+ mask_bl_vals | mask_br_vals)
+
+ # Filter vertex indices based on the keep mask
+ tl = tl[quad_keep_mask] # Result is flattened
+ tr = tr[quad_keep_mask]
+ bl = bl[quad_keep_mask]
+ br = br[quad_keep_mask]
+ else:
+ # If no mask, flatten all potential quads' vertex indices
+ tl = tl.flatten()
+ tr = tr.flatten()
+ bl = bl.flatten()
+ br = br.flatten()
+
+ # Create triangles (two per quad)
+ # Using the same winding order as before: (tl, tr, bl) and (tr, br, bl)
+ tri1 = torch.stack([tl, tr, bl], dim=1)
+ tri2 = torch.stack([tr, br, bl], dim=1)
+ faces = torch.cat([tri1, tri2], dim=0)
+
+ mesh_o3d = o3d.geometry.TriangleMesh()
+ mesh_o3d.vertices = o3d.utility.Vector3dVector(vertices.cpu().numpy())
+ mesh_o3d.triangles = o3d.utility.Vector3iVector(faces.cpu().numpy())
+ mesh_o3d.vertex_colors = o3d.utility.Vector3dVector(
+ vertex_colors.cpu().numpy())
+ mesh_o3d.remove_unreferenced_vertices()
+ mesh_o3d.remove_degenerate_triangles()
+
+ if connect_boundary_max_dist is not None and connect_boundary_max_dist > 0:
+ mesh_o3d = _fill_small_boundary_spikes(
+ mesh_o3d, connect_boundary_max_dist, connect_boundary_repeat_times)
+ # Recompute normals after potential modification, if mesh still valid
+ if mesh_o3d.has_triangles() and mesh_o3d.has_vertices():
+ mesh_o3d.compute_vertex_normals()
+ # Also computes triangle normals if vertex normals are computed
+ mesh_o3d.compute_triangle_normals()
+
+ return mesh_o3d
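+
+
+# A minimal usage sketch (illustrative assumption, not part of the original
+# pipeline): drive pano_sheet_warping with synthetic equirectangular inputs.
+# The ray construction below only demonstrates the expected (H, W, 3)
+# unit-direction layout.
+def _example_pano_sheet_warping(H: int = 64, W: int = 128):
+    """Warp a constant-distance synthetic panorama into a sphere-like mesh."""
+    v, u = torch.meshgrid(
+        torch.linspace(0.5 / H, 1 - 0.5 / H, H),
+        torch.linspace(0.5 / W, 1 - 0.5 / W, W),
+        indexing="ij",
+    )
+    theta, phi = u * 2 * np.pi, v * np.pi   # longitude, colatitude
+    rays = torch.stack(
+        [torch.sin(phi) * torch.cos(theta),
+         torch.sin(phi) * torch.sin(theta),
+         torch.cos(phi)],
+        dim=-1,
+    )                                       # (H, W, 3) unit directions
+    rgb = torch.rand(H, W, 3)               # values in [0, 1]
+    distance = torch.full((H, W), 2.0)      # constant 2m distance
+    return pano_sheet_warping(
+        rgb, distance, rays, device="cpu", connect_boundary_max_dist=None
+    )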
+
+
+def get_no_fg_img(no_fg1_img, no_fg2_img, full_img):
+ r"""Get the image without foreground objects based on available inputs.
+ Args:
+ no_fg1_img: Image with foreground layer 1 removed
+ no_fg2_img: Image with foreground layer 2 removed
+ full_img: Original full image
+ Returns:
+        A tuple (no_fg_img, fg_status): the image without foreground objects (defaulting to
+        the full image if no fg-removed image is available) and a string describing which
+        foreground layers were available.
+ """
+ fg_status = None
+ if no_fg1_img is not None and no_fg2_img is not None:
+ no_fg_img = no_fg2_img
+ fg_status = "both_fg1_fg2"
+ elif no_fg1_img is not None and no_fg2_img is None:
+ no_fg_img = no_fg1_img
+ fg_status = "only_fg1"
+ elif no_fg1_img is None and no_fg2_img is not None:
+ no_fg_img = no_fg2_img
+ fg_status = "only_fg2"
+ else:
+ no_fg_img = full_img
+ fg_status = "no_fg"
+
+ assert fg_status is not None
+
+ return no_fg_img, fg_status
+
+
+def get_fg_mask(fg1_mask, fg2_mask):
+ r"""
+ Combine foreground masks from two layers.
+ Args:
+ fg1_mask: Foreground mask for layer 1
+ fg2_mask: Foreground mask for layer 2
+ Returns:
+ Combined foreground mask, or None if both are None
+ """
+ if fg1_mask is not None and fg2_mask is not None:
+ fg_mask = np.logical_or(fg1_mask, fg2_mask)
+ elif fg1_mask is not None:
+ fg_mask = fg1_mask
+ elif fg2_mask is not None:
+ fg_mask = fg2_mask
+ else:
+ fg_mask = None
+
+ if fg_mask is not None:
+ fg_mask = fg_mask.astype(np.bool_).astype(np.uint8)
+ return fg_mask
+
+
+def get_bg_mask(sky_mask, fg_mask, kernel_scale, dilation_kernel_size: int = 3):
+ r"""
+ Generate background mask based on sky and foreground masks.
+ Args:
+ sky_mask: Sky mask (boolean array)
+ fg_mask: Foreground mask (boolean array)
+ kernel_scale: Scale factor for the kernel size
+ dilation_kernel_size: The size of the dilation kernel.
+ Returns:
+        Background mask as a uint8 array, where 1 indicates background pixels.
+ """
+ kernel_size = dilation_kernel_size * kernel_scale
+ if fg_mask is not None:
+ bg_mask = np.logical_and(
+ (1 - cv2.dilate(fg_mask,
+ np.ones((kernel_size, kernel_size), np.uint8), iterations=1)),
+ (1 - sky_mask),
+ ).astype(np.uint8)
+ else:
+ bg_mask = 1 - sky_mask
+ return bg_mask
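+
+
+# Illustrative sketch (assumption, not used by the pipeline): combine tiny
+# synthetic layer masks with the helpers above.
+def _example_layer_masks():
+    sky_mask = np.zeros((32, 32), dtype=np.uint8)
+    sky_mask[:8, :] = 1                         # top rows are sky
+    fg1_mask = np.zeros((32, 32), dtype=bool)
+    fg1_mask[20:24, 4:10] = True                # one small foreground object
+    fg_mask = get_fg_mask(fg1_mask, None)       # uint8 copy of fg1_mask
+    bg_mask = get_bg_mask(sky_mask, fg_mask, kernel_scale=1)
+    return fg_mask, bg_mask                     # bg excludes sky and the dilated fg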
+
+
+def get_filtered_mask(disparity, beta=100, alpha_threshold=0.3, device="cuda"):
+ """
+    Filter the disparity map with a Sobel kernel, then mask out edges (depth discontinuities).
+ Args:
+ disparity: Disparity map in BHWC format, shape [b, h, w, 1]
+ beta: Exponential decay factor for the Sobel magnitude
+ alpha_threshold: Threshold for visibility mask
+ device: Device to perform computations on, either 'cuda' or 'cpu'
+ Returns:
+ vis_mask: Visibility mask in BHWC format, shape [b, h, w, 1]
+ """
+ b, h, w, _ = disparity.size()
+ # Permute to NCHW format: [b, 1, h, w]
+ disparity_nchw = disparity.permute(0, 3, 1, 2)
+
+ # Pad H and W dimensions with replicate padding
+ disparity_padded = F.pad(
+ disparity_nchw, (2, 2, 2, 2), mode="replicate"
+ ) # Pad last two dims (W, H), [b, 1, h+4, w+4]
+
+ kernel_x = (
+ torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]])
+ .unsqueeze(0)
+ .unsqueeze(0)
+ .float()
+ .to(device)
+ )
+ kernel_y = (
+ torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]])
+ .unsqueeze(0)
+ .unsqueeze(0)
+ .float()
+ .to(device)
+ )
+
+ # Apply Sobel filters
+ sobel_x = F.conv2d(
+ disparity_padded, kernel_x, padding=(1, 1)
+    )  # Output: [b, 1, h+4, w+4]
+ sobel_y = F.conv2d(
+ disparity_padded, kernel_y, padding=(1, 1)
+    )  # Output: [b, 1, h+4, w+4]
+
+ # Calculate magnitude
+ sobel_mag_padded = torch.sqrt(
+ sobel_x**2 + sobel_y**2
+ ) # Shape: [b, 1, h+4, w+4]
+
+ # Remove padding
+ sobel_mag = sobel_mag_padded[
+ :, :, 2:-2, 2:-2
+    ]  # Shape: [b, 1, h, w]
+
+ # Calculate alpha and mask
+ alpha = torch.exp(-1.0 * beta * sobel_mag) # Shape: [b, 1, h, w]
+ vis_mask_nchw = torch.greater(alpha, alpha_threshold).float()
+
+ # Permute back to BHWC format: [b, h, w, 1]
+ vis_mask = vis_mask_nchw.permute(0, 2, 3, 1)
+
+ assert vis_mask.shape == disparity.shape # Ensure output shape matches input
+ return vis_mask
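+
+
+# Illustrative sketch (assumption): apply the Sobel-based discontinuity filter
+# to a random disparity batch on the CPU.
+def _example_get_filtered_mask():
+    disparity = torch.rand(1, 32, 64, 1)        # BHWC disparity
+    vis_mask = get_filtered_mask(
+        disparity, beta=100, alpha_threshold=0.3, device="cpu"
+    )
+    return vis_mask                             # 1.0 where disparity is smooth, 0.0 at edges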
+
+
+def sheet_warping(
+ predictions, excluded_region_mask=None,
+ connect_boundary_max_dist=0.5,
+ connect_boundary_repeat_times=2,
+ max_size=4096,
+) -> o3d.geometry.TriangleMesh:
+ r"""
+ Convert depth predictions to a 3D mesh.
+ Args:
+ predictions: Dictionary containing:
+ - "rgb": RGB image tensor of shape (H, W, 3) with
+                values in [0, 255] (divided by 255 internally).
+ - "distance": Distance map tensor of shape (H, W).
+ - "rays": Ray directions tensor of shape (H, W, 3).
+ excluded_region_mask: Optional boolean mask tensor of shape (H, W).
+ connect_boundary_max_dist: Maximum distance to bridge boundary vertices.
+ connect_boundary_repeat_times: Number of iterations to repeat the boundary connection.
+ max_size: Maximum size (height or width) to resize inputs to.
+ Returns:
+ An Open3D TriangleMesh object.
+ """
+ rgb = predictions["rgb"] / 255.0
+ distance = predictions["distance"]
+ rays = predictions["rays"]
+ mesh = pano_sheet_warping(
+ rgb,
+ distance,
+ rays,
+ excluded_region_mask,
+ connect_boundary_max_dist=connect_boundary_max_dist,
+ connect_boundary_repeat_times=connect_boundary_repeat_times,
+ max_size=max_size
+ )
+ return mesh
+
+
+def seed_all(seed: int = 0):
+ r"""
+ Set random seeds of all components.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def colorize_depth_maps(
+ depth: np.ndarray,
+ mask: np.ndarray = None,
+ normalize: bool = True,
+ cmap: str = 'Spectral'
+) -> np.ndarray:
+ r"""
+ Colorize depth maps using a colormap.
+ Args:
+ depth (np.ndarray): Depth map to colorize, shape (H, W).
+ mask (np.ndarray, optional): Optional mask to apply to the depth map, shape (H, W).
+ normalize (bool): Whether to normalize the depth values before colorization.
+ cmap (str): Name of the colormap to use.
+ Returns:
+ np.ndarray: Colorized depth map, shape (H, W, 3).
+ """
+    # Visualization adapted from MoGe's depth visualization utilities
+ if mask is None:
+ depth = np.where(depth > 0, depth, np.nan)
+ else:
+ depth = np.where((depth > 0) & mask, depth, np.nan)
+
+ # Convert depth to disparity (inverse of depth)
+ disp = 1 / depth # Closer objects have higher disparity values
+
+ # Set invalid disparity values to the 0.1% quantile (avoids extreme outliers)
+ if mask is not None:
+ disp[~((depth > 0) & mask)] = np.nanquantile(disp, 0.001)
+
+ # Normalize disparity values to [0,1] range if requested
+ if normalize:
+ min_disp, max_disp = np.nanquantile(
+ disp, 0.001), np.nanquantile(disp, 0.99)
+ disp = (disp - min_disp) / (max_disp - min_disp)
+ # Apply colormap (inverted so closer=warmer colors)
+ # Note: matplotlib colormaps return RGBA in [0,1] range
+ colored = np.nan_to_num(
+ matplotlib.colormaps[cmap](
+ 1.0 - disp)[..., :3], # Invert and drop alpha
+ nan=0 # Replace NaN with black
+ )
+ # Convert to uint8 and ensure contiguous memory layout
+ colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
+
+ return colored
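+
+
+# Illustrative sketch (assumption): colorize a synthetic depth ramp. Relies on
+# the matplotlib colormap lookup already used above.
+def _example_colorize_depth_maps():
+    depth = np.linspace(1.0, 10.0, 64 * 64).reshape(64, 64).astype(np.float32)
+    colored = colorize_depth_maps(depth)        # (64, 64, 3) uint8, warmer = closer
+    return colored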
diff --git a/hy3dworld/utils/inpaint_utils.py b/hy3dworld/utils/inpaint_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0b07957bb87b77d1b8e86526e76afa41ec0a273
--- /dev/null
+++ b/hy3dworld/utils/inpaint_utils.py
@@ -0,0 +1,95 @@
+import torch
+import numpy as np
+import cv2
+import math
+from ..models import FluxFillPipeline
+
+
+def get_smooth_mask(general_mask, ksize=(120, 120)):
+ r"""Generate a smooth mask from the general mask using morphological dilation.
+ Args:
+ general_mask (np.ndarray): The input mask to be smoothed, expected to be a binary mask
+ with shape [H, W] and dtype uint8 (0 or 1).
+ ksize (tuple): The size of the structuring element used for dilation, specified as
+ (height, width). Default is (120, 120).
+ Returns:
+ np.ndarray: The smoothed mask, with the same shape as the input mask, where
+ the values are either 0 or 1 (uint8).
+ """
+ # Ensure kernel size is a tuple of integers
+ ksize = (int(ksize[0]), int(ksize[1]))
+
+ # Create rectangular structuring element for dilation
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, ksize)
+
+ # Apply dilation to expand mask regions
+ mask_array = cv2.dilate(general_mask.astype(
+ np.uint8), kernel) # [1024, 2048] uint8 1
+
+ # Convert back to binary mask
+ mask_array = (mask_array > 0).astype(np.uint8)
+
+ return mask_array
+
+
+def build_inpaint_model(model_path, lora_path, subfolder, device=0):
+ r"""Build the inpainting model pipeline.
+ Args:
+ model_path (str): The path to the pre-trained model.
+        lora_path (str): The path to the LoRA weights.
+        subfolder (str): The subfolder within the LoRA weights directory to load from.
+        device (int): The device ID to load the model onto (default: 0).
+ Returns:
+ pipe: The inpainting pipeline object.
+ """
+ # Initialize pipeline with bfloat16 precision for memory efficiency
+ pipe = FluxFillPipeline.from_pretrained(
+ model_path, torch_dtype=torch.bfloat16).to(f"cuda:{device}")
+ pipe.load_lora_weights(
+ lora_path,
+ subfolder=subfolder,
+ weight_name="lora.safetensors", # default weight name
+ torch_dtype=torch.bfloat16
+ )
+ pipe.enable_model_cpu_offload() # save some VRAM by offloading the model to CPU
+ pipe.device_id = device
+ return pipe
+
+
+def get_adaptive_smooth_mask_ksize_ctrl(general_masks, mask_infos, basek=100, threshold=10000, r=1):
+ r"""Generate a smooth mask with adaptive kernel size control based on mask area.
+ Args:
+ general_masks (np.ndarray): The input mask array, expected to be a 2D array of shape [H, W]
+ where each pixel value corresponds to a mask ID.
+ mask_infos (list): A list of dictionaries containing information about each mask, including
+ the area and label of the mask.
+ basek (int): The base kernel size for smoothing, default is 100.
+ threshold (int): The area threshold to determine the scaling factor for the kernel size,
+ default is 10000.
+ r (int): A scaling factor for the kernel size, default is 1.
+ Returns:
+ np.ndarray: The smoothed mask array, with the same shape as the input mask,
+ where the values are either 0 or 1 (uint8).
+ """
+ # Initialize output mask
+ mask_array = np.zeros_like(general_masks).astype(np.bool_)
+
+ # Process each mask region individually
+ for i in range(len(mask_infos)):
+ mask_info = mask_infos[i]
+ area = mask_info["area"]
+
+ # Calculate size ratio with threshold clamping
+ ratio = area / threshold
+ ratio = math.sqrt(min(ratio, 1.0))
+
+ # Extract current object mask
+ mask = (general_masks == i + 1).astype(np.uint8)
+
+ # Default kernel for other objects
+ mask = get_smooth_mask(mask, ksize=(
+ int(basek*ratio)*r, int((basek+10)*ratio)*r)).astype(np.bool_)
+
+ # Combine with existing masks
+ mask_array = np.logical_or(mask_array, mask)
+
+ return mask_array.astype(np.uint8)
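+
+
+# Illustrative sketch (assumption): dilate a single synthetic object mask with
+# adaptive kernel-size control. The mask_infos entries mirror the "bboxes"
+# records written by the segmentation step; only "area" is consumed here.
+def _example_adaptive_smooth_mask():
+    general_masks = np.zeros((64, 64), dtype=np.uint8)
+    general_masks[20:40, 20:40] = 1             # object with mask ID 1
+    mask_infos = [{"area": int((general_masks == 1).sum()), "label": "object"}]
+    smooth = get_adaptive_smooth_mask_ksize_ctrl(general_masks, mask_infos, basek=10)
+    return smooth                               # uint8 mask, dilated around the object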
diff --git a/hy3dworld/utils/layer_utils.py b/hy3dworld/utils/layer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..56fc0c6e5497ea69c3c4cf002e87fc3961364a3e
--- /dev/null
+++ b/hy3dworld/utils/layer_utils.py
@@ -0,0 +1,351 @@
+import os
+import cv2
+import json
+import torch
+import gc # Added for garbage collection
+
+from tqdm import tqdm
+from PIL import Image
+
+import numpy as np
+from ..utils import sr_utils, seg_utils, inpaint_utils
+
+
+class ImageProcessingPipeline:
+ """Base class for image processing pipelines with common functionality"""
+
+ def __init__(self, params):
+ """Initialize pipeline with processing parameters"""
+ self.params = params
+ self.seed = self._init_seed(params['seed'])
+
+ def _init_seed(self, seed_param):
+ """Initialize random seed for reproducibility"""
+ if seed_param == -1:
+ import random
+ return random.randint(1, 65535)
+ return seed_param
+
+ def _prepare_output_dir(self, output_path):
+ """Create output directory if it doesn't exist"""
+ os.makedirs(output_path, exist_ok=True)
+
+ def _prepare_image_path(self, img_path, output_path):
+ """Create basic input image if it doesn't exist"""
+ full_image_path = f"{output_path}/full_image.png"
+ image = Image.open(img_path)
+ image.save(full_image_path)
+
+ def _get_image_path(self, base_dir, priority_files):
+ """Get image path based on priority of existing files"""
+ for file in priority_files:
+ path = os.path.join(base_dir, file)
+ if os.path.exists(path):
+ return path
+ return os.path.join(base_dir, "full_image.png")
+
+ def _process_mask(self, mask_path, base_dir, size, mask_infos_key, edge_padding: int = 20):
+ """Process mask with dilation and smoothing"""
+ mask_sharp = cv2.imread(os.path.join(base_dir, mask_path), 0)
+ with open(os.path.join(base_dir, f'{mask_infos_key}.json')) as f:
+ mask_infos = json.load(f)["bboxes"]
+
+ mask_smooth = inpaint_utils.get_adaptive_smooth_mask_ksize_ctrl(
+ mask_sharp, mask_infos,
+ basek=self.params['dilation_size'],
+ threshold=self.params['threshold'],
+ r=self.params['ratio']
+ )
+
+ # Apply edge padding
+ mask_smooth[:, 0:edge_padding] = 1
+ mask_smooth[:, -edge_padding:] = 1
+        return cv2.resize(mask_smooth, (size[1], size[0]), interpolation=cv2.INTER_LINEAR)
+
+ def _run_inpainting(self, image, mask, size, prompt_config, image_info, inpaint_model):
+ """Run inpainting with configured parameters"""
+ labels = image_info["labels"]
+
+ # process prompt
+ if self._is_indoor(image_info):
+ prompt = prompt_config["indoor"]["positive_prompt"]
+ negative_prompt = prompt_config["indoor"]["negative_prompt"]
+ else:
+ prompt = prompt_config["outdoor"]["positive_prompt"]
+ negative_prompt = prompt_config["outdoor"]["negative_prompt"]
+
+ if labels:
+ negative_prompt += ", " + ", ".join(labels)
+
+ result = inpaint_model(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=image,
+ mask_image=mask,
+ height=size[0],
+ width=size[1],
+ strength=self.params['strength'],
+ true_cfg_scale=self.params['cfg_scale'],
+ guidance_scale=30,
+ num_inference_steps=50,
+ max_sequence_length=512,
+ generator=torch.Generator("cpu").manual_seed(self.seed),
+ ).images[0]
+
+ # Clear memory after inpainting
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ return result
+
+ def _is_indoor(self, img_info):
+ """Check if image is classified as indoor"""
+ return img_info["class"] in ["indoor", "[indoor]"]
+
+ def _run_super_resolution(self, input_path, output_path, sr_model, suffix='sr'):
+ """Run super-resolution on input image"""
+ if os.path.exists(input_path):
+ sr_utils.sr_inference(
+ input_path, output_path, sr_model,
+ scale=self.params['scale'], ext='auto', suffix=suffix
+ )
+ # Clear memory after super-resolution
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+class ForegroundPipeline(ImageProcessingPipeline):
+ """Pipeline for processing foreground layers (fg1 and fg2)"""
+
+ def __init__(self, params, layer):
+ """Initialize with parameters and layer type (0 for fg1, 1 for fg2)"""
+ super().__init__(params)
+ self.layer = layer
+ self.layer_name = f"fg{layer+1}"
+
+ def process(self, img_infos, sr_model, zim_predictor, gd_processor, gd_model, inpaint_model):
+ """Run full processing pipeline for foreground layer"""
+ print(f"============= Now starting {self.layer_name} processing ===============")
+
+ # Phase 1: Super Resolution
+ self._process_super_resolution(img_infos, sr_model)
+
+ # Phase 2: Segmentation
+ self._process_segmentation(img_infos, zim_predictor, gd_processor, gd_model)
+
+ # Phase 3: Inpainting
+ self._process_inpainting(img_infos, inpaint_model)
+
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def _process_super_resolution(self, img_infos, sr_model):
+ """Process super-resolution phase"""
+ for img_info in tqdm(img_infos):
+ output_path = img_info["output_path"]
+ # prepare input image
+ if self.layer == 0:
+ self._prepare_image_path(img_info["image_path"], output_path)
+            input_path = self._get_image_path(output_path, ["remove_fg1_image.png", "full_image.png"])
+ self._prepare_output_dir(output_path)
+ self._run_super_resolution(input_path, output_path, sr_model)
+
+ def _process_segmentation(self, img_infos, zim_predictor, gd_processor, gd_model):
+ """Process segmentation phase"""
+ for img_info in tqdm(img_infos):
+ if not img_info.get("labels"):
+ continue
+
+ output_path = img_info["output_path"]
+            img_path = self._get_image_path(output_path, ["remove_fg1_image.png", "full_image.png"])
+ img_sr_path = img_path.replace(".png", "_sr.png")
+ text = ". ".join(img_info["labels"]) + "." if img_info["labels"] else ""
+
+ if self._is_indoor(img_info):
+ seg_utils.get_fg_pad_indoor(
+ output_path, img_path, img_sr_path,
+ zim_predictor, gd_processor, gd_model,
+ text, layer=self.layer, scale=self.params['scale']
+ )
+ else:
+ seg_utils.get_fg_pad_outdoor(
+ output_path, img_path, img_sr_path,
+ zim_predictor, gd_processor, gd_model,
+ text, layer=self.layer, scale=self.params['scale']
+ )
+
+ # Clear memory after segmentation
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def _process_inpainting(self, img_infos, inpaint_model):
+ """Process inpainting phase"""
+ for img_info in tqdm(img_infos):
+ base_dir = img_info["output_path"]
+ mask_path = f'{self.layer_name}_mask.png'
+
+ if not os.path.exists(os.path.join(base_dir, mask_path)):
+ continue
+
+ image = Image.open(self._get_image_path(
+ base_dir,
+ [f"remove_fg{self.layer}_image.png", "full_image.png"]
+ )).convert('RGB')
+
+ size = image.height, image.width
+ mask_smooth = self._process_mask(
+ mask_path, base_dir, size, self.layer_name
+ )
+ pano_mask_pil = Image.fromarray(mask_smooth*255)
+
+ result = self._run_inpainting(
+ image, pano_mask_pil, size,
+ self.params['prompt_config'], img_info, inpaint_model
+ )
+ result.save(f'{base_dir}/remove_{self.layer_name}_image.png')
+
+ # Clear memory after saving result
+ del image, mask_smooth, pano_mask_pil, result
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+class SkyPipeline(ImageProcessingPipeline):
+ """Pipeline for processing sky layer"""
+
+ def process(self, img_infos, sr_model, zim_predictor, gd_processor, gd_model, inpaint_model):
+ """Run full processing pipeline for sky layer"""
+ print("============= Now starting sky processing ===============")
+
+ # Phase 1: Super Resolution
+ self._process_super_resolution(img_infos, sr_model)
+
+ # Phase 2: Segmentation
+ self._process_segmentation(img_infos, zim_predictor, gd_processor, gd_model)
+
+ # Phase 3: Inpainting
+ self._process_inpainting(img_infos, inpaint_model)
+
+ # Phase 4: Final Super Resolution
+ self._process_final_super_resolution(img_infos, sr_model)
+
+ # Clear all models from memory after processing
+ self._clear_models([sr_model, zim_predictor, gd_processor, gd_model, inpaint_model])
+
+ def _clear_models(self, models):
+ """Clear model weights from memory"""
+ for model in models:
+ if hasattr(model, 'cpu'):
+ model.cpu()
+ if hasattr(model, 'to'):
+ model.to('cpu')
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def _process_super_resolution(self, img_infos, sr_model):
+ """Process initial super-resolution phase"""
+ for img_info in tqdm(img_infos):
+ output_path = img_info["output_path"]
+ self._prepare_output_dir(output_path)
+ input_path = f"{output_path}/remove_fg2_image.png"
+ self._run_super_resolution(input_path, output_path, sr_model)
+
+ def _process_segmentation(self, img_infos, zim_predictor, gd_processor, gd_model):
+ """Process segmentation phase for sky"""
+ for img_info in tqdm(img_infos):
+ if self._is_indoor(img_info):
+ continue
+
+ output_path = img_info["output_path"]
+ img_path = self._get_image_path(
+ output_path,
+ ["remove_fg2_image.png", "remove_fg1_image.png", "full_image.png"]
+ )
+ img_sr_path = img_path.replace(".png", "_sr.png")
+
+ seg_utils.get_sky(
+ output_path, img_path, img_sr_path,
+ zim_predictor, gd_processor, gd_model, "sky."
+ )
+
+ # Clear memory after segmentation
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def _process_inpainting(self, img_infos, inpaint_model):
+ """Process inpainting phase for sky"""
+ for img_info in tqdm(img_infos):
+ if self._is_indoor(img_info):
+ continue
+
+ base_dir = img_info["output_path"]
+ if not os.path.exists(os.path.join(base_dir, 'sky_mask.png')):
+ continue
+
+ image = Image.open(self._get_image_path(
+ base_dir,
+ ["remove_fg2_image.png", "remove_fg1_image.png", "full_image.png"]
+ )).convert('RGB')
+
+ size = image.height, image.width
+ mask_sharp = Image.open(os.path.join(base_dir, 'sky_mask.png')).convert('L')
+ mask_smooth = inpaint_utils.get_smooth_mask(np.asarray(mask_sharp))
+
+ # Apply edge padding
+ mask_smooth[:, 0:20] = 1
+ mask_smooth[:, -20:] = 1
+            mask_smooth = cv2.resize(mask_smooth, (size[1], size[0]), interpolation=cv2.INTER_LINEAR)
+ pano_mask_pil = Image.fromarray(mask_smooth*255)
+
+ # Sky-specific inpainting parameters
+ prompt = "sky-coverage, whole sky image, ultra-high definition stratosphere"
+ negative_prompt = ("object, text, defocus, pure color, low-res, blur, pixelation, foggy, "
+ "noise, mosaic, artifacts, low-contrast, low-quality, blurry, tree, "
+ "grass, plant, ground, land, mountain, building, lake, river, sea, ocean")
+
+ result = inpaint_model(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=image,
+ mask_image=pano_mask_pil,
+ height=size[0],
+ width=size[1],
+ strength=self.params['strength'],
+ true_cfg_scale=self.params['cfg_scale'],
+ guidance_scale=20,
+ num_inference_steps=50,
+ max_sequence_length=512,
+ generator=torch.Generator("cpu").manual_seed(self.seed),
+ ).images[0]
+ result.save(f'{base_dir}/sky_image.png')
+
+ # Clear memory after saving result
+ del image, mask_sharp, mask_smooth, pano_mask_pil, result
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def _process_final_super_resolution(self, img_infos, sr_model):
+ """Process final super-resolution phase"""
+ for img_info in tqdm(img_infos):
+ output_path = img_info["output_path"]
+ input_path = f"{output_path}/sky_image.png"
+ self._run_super_resolution(input_path, output_path, sr_model)
+
+
+# Original functions refactored to use the new pipeline classes
+def remove_fg1_pipeline(img_infos, sr_model, zim_predictor, gd_processor, gd_model, inpaint_model, params):
+ """Process the first foreground layer (fg1)"""
+ pipeline = ForegroundPipeline(params, layer=0)
+ pipeline.process(img_infos, sr_model, zim_predictor, gd_processor, gd_model, inpaint_model)
+
+
+def remove_fg2_pipeline(img_infos, sr_model, zim_predictor, gd_processor, gd_model, inpaint_model, params):
+ """Process the second foreground layer (fg2)"""
+ pipeline = ForegroundPipeline(params, layer=1)
+ pipeline.process(img_infos, sr_model, zim_predictor, gd_processor, gd_model, inpaint_model)
+
+
+def sky_pipeline(img_infos, sr_model, zim_predictor, gd_processor, gd_model, inpaint_model, params):
+ """Process the sky layer"""
+ pipeline = SkyPipeline(params)
+ pipeline.process(img_infos, sr_model, zim_predictor, gd_processor, gd_model, inpaint_model)
diff --git a/hy3dworld/utils/pano_depth_utils.py b/hy3dworld/utils/pano_depth_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1808ce426aa9b94655c030cf73952b574a2425c4
--- /dev/null
+++ b/hy3dworld/utils/pano_depth_utils.py
@@ -0,0 +1,336 @@
+import cv2
+import numpy as np
+import torch
+import utils3d
+from PIL import Image
+
+from moge.model.v1 import MoGeModel
+from moge.utils.panorama import (
+ get_panorama_cameras,
+ split_panorama_image,
+ merge_panorama_depth,
+)
+from .general_utils import spherical_uv_to_directions
+
+
+# from https://github.com/lpiccinelli-eth/UniK3D/unik3d/utils/coordinate.py
+def coords_grid(b, h, w):
+ r"""
+ Generate a grid of pixel coordinates in the range [0.5, W-0.5] and [0.5, H-0.5].
+ Args:
+ b (int): Batch size.
+ h (int): Height of the grid.
+ w (int): Width of the grid.
+ Returns:
+ grid (torch.Tensor): A tensor of shape [B, 2, H, W] containing the pixel coordinates.
+ """
+ # Create pixel coordinates in the range [0.5, W-0.5] and [0.5, H-0.5]
+ pixel_coords_x = torch.linspace(0.5, w - 0.5, w)
+ pixel_coords_y = torch.linspace(0.5, h - 0.5, h)
+
+ # Stack the pixel coordinates to create a grid
+ stacks = [pixel_coords_x.repeat(h, 1), pixel_coords_y.repeat(w, 1).t()]
+
+ grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W]
+ grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
+
+ return grid
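+
+
+# Illustrative sketch (assumption): coords_grid returns pixel-center coordinates
+# with shape [B, 2, H, W]; channel 0 holds x and channel 1 holds y.
+def _example_coords_grid():
+    grid = coords_grid(b=2, h=4, w=6)
+    assert grid.shape == (2, 2, 4, 6)
+    assert torch.isclose(grid[0, 0, 0, 0], torch.tensor(0.5))
+    return grid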
+
+
+def build_depth_model(device: torch.device = "cuda"):
+ r"""
+ Build the MoGe depth model for panorama depth prediction.
+ Args:
+ device (torch.device): The device to load the model onto (e.g., "cuda" or "cpu").
+ Returns:
+ model (MoGeModel): The MoGe depth model instance.
+ """
+ # Load model from pretrained weights
+ model = MoGeModel.from_pretrained("Ruicheng/moge-vitl")
+ model.eval()
+ model = model.to(device)
+ return model
+
+
+def smooth_south_pole_depth(depth_map, smooth_height_ratio=0.03, lower_quantile=0.1, upper_quantile=0.9):
+ """
+ Smooth depth values at the south pole (bottom) of a panorama to address inconsistencies.
+ Args:
+ depth_map (np.ndarray): Input depth map, shape (H, W).
+ smooth_height_ratio (float): Ratio of the height to smooth, typically a small value like 0.03.
+ lower_quantile (float): The lower quantile for outlier filtering.
+ upper_quantile (float): The upper quantile for outlier filtering.
+ Returns:
+ np.ndarray: Smoothed depth map.
+ """
+ height, width = depth_map.shape
+ smooth_height = int(height * smooth_height_ratio)
+
+ if smooth_height == 0:
+ return depth_map
+
+ # Create copy to avoid modifying original
+ smoothed_depth = depth_map.copy()
+
+ # Calculate reference depth from bottom rows:
+ # When the number of rows is greater than 3, use the last 3 rows; otherwise, use the bottom row
+ if smooth_height > 3:
+ # Calculate the average depth using the last 3 rows
+ reference_rows = depth_map[-3:, :]
+ reference_data = reference_rows.flatten()
+ else:
+ # Use the bottom row
+ reference_data = depth_map[-1, :]
+
+ # Filter outliers: including invalid values, depth that is too large or too small
+ valid_mask = np.isfinite(reference_data) & (reference_data > 0)
+
+ if np.any(valid_mask):
+ valid_depths = reference_data[valid_mask]
+
+ # Use quantiles to filter extreme outliers.
+ lower_bound, upper_bound = np.quantile(valid_depths, [lower_quantile, upper_quantile])
+
+ # Further filter out depth values that are too large or too small
+ depth_filter_mask = (valid_depths >= lower_bound) & (
+ valid_depths <= upper_bound
+ )
+
+ if np.any(depth_filter_mask):
+ avg_depth = np.mean(valid_depths[depth_filter_mask])
+ else:
+ # If all values are filtered out, use the median as an alternative
+ avg_depth = np.median(valid_depths)
+ else:
+ avg_depth = np.nanmean(reference_data)
+
+ # Set the bottom row as the average value
+ smoothed_depth[-1, :] = avg_depth
+
+ # Smooth upwards to the specified height
+ for i in range(1, smooth_height):
+ y_idx = height - 1 - i # Index from bottom to top
+ if y_idx < 0:
+ break
+
+ # Calculate smoothness weight: The closer to the bottom, the stronger the smoothness
+ weight = (smooth_height - i) / smooth_height
+
+ # Smooth the current row
+ current_row = depth_map[y_idx, :]
+ valid_mask = np.isfinite(current_row) & (current_row > 0)
+
+ if np.any(valid_mask):
+ valid_row_depths = current_row[valid_mask]
+
+ # Apply outlier filtering to the current row as well
+ if len(valid_row_depths) > 1:
+ q25, q75 = np.quantile(valid_row_depths, [0.25, 0.75])
+ iqr = q75 - q25
+ lower_bound = q25 - 1.5 * iqr
+ upper_bound = q75 + 1.5 * iqr
+ depth_filter_mask = (valid_row_depths >= lower_bound) & (
+ valid_row_depths <= upper_bound
+ )
+
+ if np.any(depth_filter_mask):
+ row_avg = np.mean(valid_row_depths[depth_filter_mask])
+ else:
+ row_avg = np.median(valid_row_depths)
+ else:
+ row_avg = (
+ valid_row_depths[0] if len(valid_row_depths) > 0 else avg_depth
+ )
+
+ # Linear interpolation: between the original depth and the average depth
+ smoothed_depth[y_idx, :] = (1 - weight) * current_row + weight * row_avg
+
+ return smoothed_depth
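+
+
+# Illustrative sketch (assumption): smooth an artificial seam in the bottom rows
+# of a panorama depth map.
+def _example_smooth_south_pole_depth():
+    depth = np.ones((100, 200), dtype=np.float32)
+    depth[-1, :100] = 0.5                       # inconsistent bottom-left half
+    smoothed = smooth_south_pole_depth(depth, smooth_height_ratio=0.05)
+    # The bottom row becomes a single outlier-filtered average value
+    assert np.allclose(smoothed[-1, :], smoothed[-1, 0])
+    return smoothed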
+
+
+def pred_pano_depth(
+ model,
+ image: Image.Image,
+ img_name: str,
+ scale=1.0,
+ resize_to=1920,
+ remove_pano_depth_nan=True,
+ last_layer_mask=None,
+ last_layer_depth=None,
+ verbose=False,
+) -> dict:
+ r"""
+ Predict panorama depth using the MoGe model.
+ Args:
+ model (MoGeModel): The MoGe depth model instance.
+ image (Image.Image): Input panorama image.
+ img_name (str): Name of the image for saving outputs.
+ scale (float): Scale factor for resizing the image.
+ resize_to (int): Target size for resizing the image.
+ remove_pano_depth_nan (bool): Whether to remove NaN values from the predicted depth.
+ last_layer_mask (np.ndarray, optional): Mask from the last layer for inpainting.
+ last_layer_depth (dict, optional): Last layer depth information containing distance maps and masks.
+ verbose (bool): Whether to print verbose information.
+ Returns:
+ dict: A dictionary containing the predicted depth maps and masks.
+ """
+ if verbose:
+ print("\t - Predicting pano depth with moge")
+
+ # Process input image
+ image_origin = np.array(image)
+ height_origin, width_origin = image_origin.shape[:2]
+ image, height, width = image_origin, height_origin, width_origin
+
+ # Resize if needed
+ if resize_to is not None:
+ _height, _width = min(
+ resize_to, int(resize_to * height_origin / width_origin)
+ ), min(resize_to, int(resize_to * width_origin / height_origin))
+ if _height < height_origin:
+ if verbose:
+ print(
+ f"\t - Resizing image from {width_origin}x{height_origin} \
+ to {_width}x{_height} for pano depth prediction"
+ )
+            image = cv2.resize(image_origin, (_width, _height), interpolation=cv2.INTER_AREA)
+ height, width = _height, _width
+ # Split panorama into multiple views
+ splitted_extrinsics, splitted_intriniscs = get_panorama_cameras()
+ splitted_resolution = 512
+ splitted_images = split_panorama_image(
+ image, splitted_extrinsics, splitted_intriniscs, splitted_resolution
+ )
+
+ # Handle inpainting masks if provided
+ splitted_inpaint_masks = None
+ if last_layer_mask is not None and last_layer_depth is not None:
+ splitted_inpaint_masks = split_panorama_image(
+ last_layer_mask,
+ splitted_extrinsics,
+ splitted_intriniscs,
+ splitted_resolution,
+ )
+
+ # infer moge depth
+ num_splitted_images = len(splitted_images)
+ splitted_distance_maps = [None] * num_splitted_images
+ splitted_masks = [None] * num_splitted_images
+
+ indices_to_process_model = []
+ skipped_count = 0
+
+ # Determine which images need processing
+ for i in range(num_splitted_images):
+ if splitted_inpaint_masks is not None and splitted_inpaint_masks[i].sum() == 0:
+ # Use depth from the previous layer for non-inpainted (masked) regions
+ splitted_distance_maps[i] = last_layer_depth["splitted_distance_maps"][i]
+ splitted_masks[i] = last_layer_depth["splitted_masks"][i]
+ skipped_count += 1
+ else:
+ indices_to_process_model.append(i)
+
+ pred_count = 0
+ # Process images that require model inference in batches
+ inference_batch_size = 1
+ for i in range(0, len(indices_to_process_model), inference_batch_size):
+ batch_indices = indices_to_process_model[i : i + inference_batch_size]
+ if not batch_indices:
+ continue
+ # Prepare batch
+ current_batch_images = [splitted_images[k] for k in batch_indices]
+ current_batch_intrinsics = [splitted_intriniscs[k] for k in batch_indices]
+ # Convert to tensor and normalize
+ image_tensor = torch.tensor(
+ np.stack(current_batch_images) / 255,
+ dtype=torch.float32,
+ device=next(model.parameters()).device,
+ ).permute(0, 3, 1, 2)
+ # Calculate field of view
+ fov_x, _ = np.rad2deg( # fov_y is not used by model.infer
+ utils3d.numpy.intrinsics_to_fov(np.array(current_batch_intrinsics))
+ )
+ fov_x_tensor = torch.tensor(
+ fov_x, dtype=torch.float32, device=next(model.parameters()).device
+ )
+ # Run inference
+ output = model.infer(image_tensor, fov_x=fov_x_tensor, apply_mask=False)
+
+ batch_distance_maps = output["points"].norm(dim=-1).cpu().numpy()
+ batch_masks = output["mask"].cpu().numpy()
+ # Store results
+ for batch_idx, original_idx in enumerate(batch_indices):
+ splitted_distance_maps[original_idx] = batch_distance_maps[batch_idx]
+ splitted_masks[original_idx] = batch_masks[batch_idx]
+ pred_count += 1
+
+ if verbose:
+ # Print processing statistics
+ if (
+ pred_count + skipped_count
+ ) == 0: # Avoid division by zero if num_splitted_images is 0
+ skip_ratio_info = "N/A (no images to process)"
+ else:
+ skip_ratio_info = f"{skipped_count / (pred_count + skipped_count):.2%}"
+ print(
+ f"\t 🔍 Predicted {pred_count} splitted images, \
+ skipped {skipped_count} splitted images. Skip ratio: {skip_ratio_info}"
+ )
+
+ # merge moge depth
+ merging_width, merging_height = width, height
+ panorama_depth, panorama_mask = merge_panorama_depth(
+ merging_width,
+ merging_height,
+ splitted_distance_maps,
+ splitted_masks,
+ splitted_extrinsics,
+ splitted_intriniscs,
+ )
+ # Post-process depth map
+ panorama_depth = panorama_depth.astype(np.float32)
+    # Replace invalid (NaN) regions with a large finite depth (treated as sky) before resizing
+ if remove_pano_depth_nan:
+ # for depth inpainting, remove nan
+ panorama_depth[~panorama_mask] = 1.0 * np.nanquantile(
+ panorama_depth, 0.999
+ ) # sky depth
+    panorama_depth = cv2.resize(
+        panorama_depth, (width_origin, height_origin),
+        interpolation=cv2.INTER_LINEAR
+    )
+    panorama_mask = (
+        cv2.resize(
+            panorama_mask.astype(np.uint8),
+            (width_origin, height_origin),
+            interpolation=cv2.INTER_NEAREST,
+        )
+        > 0
+    )
+
+    # Smooth the south-pole (bottom) depth to remove left/right inconsistency across the panorama seam
+ if img_name in ["background", "full_img"]:
+ if verbose:
+ print("\t - Smoothing south pole depth for consistency")
+ panorama_depth = smooth_south_pole_depth(
+ panorama_depth, smooth_height_ratio=0.05
+ )
+
+ rays = torch.from_numpy(
+ spherical_uv_to_directions(
+ utils3d.numpy.image_uv(width=width_origin, height=height_origin)
+ )
+ ).to(next(model.parameters()).device)
+
+ panorama_depth = (
+ torch.from_numpy(panorama_depth).to(next(model.parameters()).device) * scale
+ )
+
+ return {
+ "type": "",
+ "rgb": torch.from_numpy(image_origin).to(next(model.parameters()).device),
+ "distance": panorama_depth,
+ "rays": rays,
+ "mask": panorama_mask,
+ "splitted_masks": splitted_masks,
+ "splitted_distance_maps": splitted_distance_maps,
+ }
diff --git a/hy3dworld/utils/perspective_utils.py b/hy3dworld/utils/perspective_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c75853fd3b0bfb8e6ff1c4b8fe5badea38f1522
--- /dev/null
+++ b/hy3dworld/utils/perspective_utils.py
@@ -0,0 +1,108 @@
+import cv2
+import numpy as np
+
+
+class Perspective:
+ r"""Convert perspective image to equirectangular image.
+ Args:
+ img_name (str or np.ndarray): The name of the image file or the image array.
+ FOV (float): The field of view of the image in degrees.
+ THETA (float): The left/right angle in degrees.
+ PHI (float): The up/down angle in degrees.
+ img_width (int): The width of the output equirectangular image.
+ img_height (int): The height of the output equirectangular image.
+ crop_bound (bool): Whether to crop the boundary area of the image proportionally.
+ """
+ def __init__(self, img_name=None, FOV=None, THETA=None, PHI=None, img_width=512, img_height=512, crop_bound=False):
+        # Either img_name is provided here, or img_width/img_height are given and the
+        # image is passed later via the img argument of GetEquirec.
+ self.crop_bound = crop_bound
+ if img_name is not None:
+ # Load the image
+ if isinstance(img_name, str):
+ self._img = cv2.imread(img_name, cv2.IMREAD_COLOR)
+ elif isinstance(img_name, np.ndarray):
+ self._img = img_name
+ [self._height, self._width, _] = self._img.shape
+ # Crop the boundary area of the image proportionally
+ if self.crop_bound:
+ self._img = self._img[int(
+ self._height*0.05):int(self._height*0.95), int(self._width*0.05):int(self._width*0.95), :]
+
+ [self._height, self._width, _] = self._img.shape
+ else:
+ self._img = None
+ self._height = img_height
+ self._width = img_width
+
+ self.THETA = THETA
+ self.PHI = PHI
+ if self._width > self._height:
+ self.wFOV = FOV
+ self.hFOV = (float(self._height) / self._width) * FOV
+ else:
+ self.wFOV = (float(self._width) / self._height) * FOV
+ self.hFOV = FOV
+
+ self.w_len = np.tan(np.radians(self.wFOV / 2.0))
+ self.h_len = np.tan(np.radians(self.hFOV / 2.0))
+
+ def GetEquirec(self, height, width, img=None):
+        # THETA is the left/right angle and PHI is the up/down angle, both in degrees
+ if self._img is None:
+ self._img = img
+ # Calculate the equirectangular coordinates
+ x, y = np.meshgrid(np.linspace(-180, 180, width),
+ np.linspace(90, -90, height))
+ # Convert spherical coordinates to Cartesian coordinates
+ x_map = np.cos(np.radians(x)) * np.cos(np.radians(y))
+ y_map = np.sin(np.radians(x)) * np.cos(np.radians(y))
+ z_map = np.sin(np.radians(y))
+ # Stack the coordinates to form a 3D array
+ xyz = np.stack((x_map, y_map, z_map), axis=2)
+ # Reshape the coordinates to match the image dimensions
+ y_axis = np.array([0.0, 1.0, 0.0], np.float32)
+ z_axis = np.array([0.0, 0.0, 1.0], np.float32)
+ # Calculate the rotation matrices
+ [R1, _] = cv2.Rodrigues(z_axis * np.radians(self.THETA))
+ [R2, _] = cv2.Rodrigues(np.dot(R1, y_axis) * np.radians(-self.PHI))
+ # Invert rotations to transform from equirectangular to perspective
+ R1 = np.linalg.inv(R1)
+ R2 = np.linalg.inv(R2)
+ # Apply rotations
+ xyz = xyz.reshape([height * width, 3]).T
+ xyz = np.dot(R2, xyz)
+ xyz = np.dot(R1, xyz).T
+ xyz = xyz.reshape([height, width, 3])
+ # Create mask for valid forward-facing points (x > 0)
+ inverse_mask = np.where(xyz[:, :, 0] > 0, 1, 0)
+ # Normalize coordinates by x-component (perspective division)
+ xyz[:, :] = xyz[:, :] / \
+ np.repeat(xyz[:, :, 0][:, :, np.newaxis], 3, axis=2)
+ # Map 3D points back to 2D perspective image coordinates
+ lon_map = np.where(
+ (-self.w_len < xyz[:, :, 1]) & (xyz[:, :, 1] < self.w_len) \
+ & (-self.h_len < xyz[:, :, 2]) & (xyz[:, :, 2] < self.h_len),
+ (xyz[:, :, 1]+self.w_len)/2/self.w_len*self._width,
+ 0)
+ lat_map = np.where(
+ (-self.w_len < xyz[:, :, 1]) & (xyz[:, :, 1] < self.w_len) \
+ & (-self.h_len < xyz[:, :, 2]) & (xyz[:, :, 2] < self.h_len),
+ (-xyz[:, :, 2]+self.h_len) /
+ 2/self.h_len*self._height,
+ 0)
+ mask = np.where(
+ (-self.w_len < xyz[:, :, 1]) & (xyz[:, :, 1] < self.w_len) \
+ & (-self.h_len < xyz[:, :, 2]) & (xyz[:, :, 2] < self.h_len),
+ 1,
+ 0)
+ # Remap the image using the longitude and latitude maps
+        persp = cv2.remap(
+            self._img,
+            lon_map.astype(np.float32),
+            lat_map.astype(np.float32),
+            cv2.INTER_CUBIC,
+            borderMode=cv2.BORDER_WRAP,
+        )
+ # Apply the mask to the equirectangular image
+ mask = mask * inverse_mask
+ mask = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
+ persp = persp * mask
+
+ return persp, mask
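+
+
+# Illustrative sketch (assumption): project a random 512x512 perspective view
+# with a 90-degree horizontal FOV onto a 1024x2048 equirectangular canvas.
+def _example_perspective_to_equirect():
+    view = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)
+    persp = Perspective(view, FOV=90, THETA=0, PHI=0)
+    equirect, mask = persp.GetEquirec(1024, 2048)
+    return equirect, mask                       # mask marks pixels covered by the view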
diff --git a/hy3dworld/utils/seg_utils.py b/hy3dworld/utils/seg_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..789d40aad13de4e27e0220323af5cf87c9e2c772
--- /dev/null
+++ b/hy3dworld/utils/seg_utils.py
@@ -0,0 +1,617 @@
+import cv2
+import json
+import torch
+import numpy as np
+from PIL import Image
+from skimage import morphology
+from typing import Optional, Tuple, List
+
+from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
+from zim_anything import zim_model_registry, ZimPredictor
+
+import os
+os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
+
+
+class DetPredictor(ZimPredictor):
+ def predict(
+ self,
+ point_coords: Optional[np.ndarray] = None,
+ point_labels: Optional[np.ndarray] = None,
+ box: Optional[np.ndarray] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+
+ Arguments:
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (np.ndarray or None): A length N array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+ box (np.ndarray or None): A length 4 array given a box prompt to the
+ model, in XYXY format.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+
+ Returns:
+ (np.ndarray): The output masks in CxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (np.ndarray): An array of length C containing the model's
+ predictions for the quality of each mask.
+ (np.ndarray): An array of shape CxHxW, where C is the number
+ of masks and H=W=256. These low resolution logits can be passed to
+ a subsequent iteration as mask input.
+ """
+ if not self.is_image_set:
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+ # Transform input prompts
+ coords_torch = None
+ labels_torch = None
+ box_torch = None
+
+ if point_coords is not None:
+ assert (
+ point_labels is not None
+ ), "point_labels must be supplied if point_coords is supplied."
+ point_coords = self.transform.apply_coords(point_coords, self.original_size)
+ coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
+ labels_torch = torch.as_tensor(point_labels, dtype=torch.float, device=self.device)
+ coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
+ if box is not None:
+ box = self.transform.apply_boxes(box, self.original_size)
+ box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
+
+ masks, iou_predictions, low_res_masks = self.predict_torch(
+ coords_torch,
+ labels_torch,
+ box_torch,
+ multimask_output,
+ return_logits=return_logits,
+ )
+ if not return_logits:
+ masks = masks > 0.5
+
+ masks_np = masks.squeeze(0).float().detach().cpu().numpy()
+ iou_predictions_np = iou_predictions[0].squeeze(0).float().detach().cpu().numpy()
+ low_res_masks_np = low_res_masks[0].squeeze(0).float().detach().cpu().numpy()
+
+ return masks_np, iou_predictions_np, low_res_masks_np
+
+
+def build_gd_model(GROUNDING_MODEL, device="cuda"):
+ """Build Grounding DINO model from HuggingFace
+
+ Args:
+ GROUNDING_MODEL: Model identifier
+ device: Device to load model on (default: "cuda")
+
+ Returns:
+ processor: Model processor
+ grounding_model: Loaded model
+ """
+ model_id = GROUNDING_MODEL
+ processor = AutoProcessor.from_pretrained(model_id)
+ grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(
+ model_id).to(device)
+
+ return processor, grounding_model
+
+
+def build_zim_model(ZIM_MODEL_CONFIG, ZIM_CHECKPOINT, device="cuda"):
+ """Build ZIM-Anything model from HuggingFace
+
+ Args:
+ ZIM_MODEL_CONFIG: Model configuration
+ ZIM_CHECKPOINT: Model checkpoint path
+ device: Device to load model on (default: "cuda")
+
+ Returns:
+ zim_predictor: Initialized ZIM predictor
+ """
+ zim_model = zim_model_registry[ZIM_MODEL_CONFIG](
+ checkpoint=ZIM_CHECKPOINT).to(device)
+ zim_predictor = DetPredictor(zim_model)
+ return zim_predictor
+
+
+def mask_nms(masks, scores, threshold=0.5):
+ """Perform Non-Maximum Suppression based on mask overlap
+
+ Args:
+ masks: Input masks tensor (N,H,W)
+ scores: Confidence scores for each mask
+ threshold: IoU threshold for suppression (default: 0.5)
+
+ Returns:
+ keep: Indices of kept masks
+ """
+ areas = torch.sum(masks, dim=(1, 2)) # [N,]
+ _, order = scores.sort(0, descending=True)
+
+ keep = []
+ while order.numel() > 0:
+ if order.numel() == 1:
+ i = order.item()
+ keep.append(i)
+ break
+ else:
+ i = order[0].item()
+ keep.append(i)
+
+ inter = torch.sum(torch.logical_and(
+ masks[order[1:]], masks[i]), dim=(1, 2)) # [N-1,]
+ min_areas = torch.minimum(areas[i], areas[order[1:]]) # [N-1,]
+ iomin = inter / min_areas
+ idx = (iomin <= threshold).nonzero().squeeze()
+ if idx.numel() == 0:
+ break
+ order = order[idx + 1]
+ return torch.LongTensor(keep)
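+
+
+# Illustrative sketch (assumption): two heavily overlapping masks collapse to one
+# under mask-overlap NMS, while a disjoint mask survives.
+def _example_mask_nms():
+    masks = torch.zeros(3, 32, 32, dtype=torch.bool)
+    masks[0, 0:16, 0:16] = True                 # mask 0
+    masks[1, 0:14, 0:14] = True                 # mask 1, mostly inside mask 0
+    masks[2, 20:30, 20:30] = True               # mask 2, disjoint
+    scores = torch.tensor([0.9, 0.8, 0.7])
+    keep = mask_nms(masks, scores, threshold=0.5)
+    return keep                                 # expected to keep masks 0 and 2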
+
+
+def filter_small_bboxes(results, max_num=100):
+ """Filter small bounding boxes to avoid memory overflow
+
+ Args:
+ results: Detection results containing boxes
+ max_num: Maximum number of boxes to keep (default: 100)
+
+ Returns:
+ keep: Indices of kept boxes
+ """
+ bboxes = results[0]["boxes"]
+ x1 = bboxes[:, 0]
+ y1 = bboxes[:, 1]
+ x2 = bboxes[:, 2]
+ y2 = bboxes[:, 3]
+ scores = (x2-x1)*(y2-y1)
+ _, order = scores.sort(0, descending=True)
+ keep = [order[i].item() for i in range(min(max_num, order.numel()))]
+ return torch.LongTensor(keep)
+
+
+def filter_by_general_score(results, score_threshold=0.35):
+ """Filter results by confidence score
+
+ Args:
+ results: Detection results
+ score_threshold: Minimum confidence score (default: 0.35)
+
+ Returns:
+ filtered_data: Filtered results
+ """
+ filtered_data = []
+ for entry in results:
+ scores = entry['scores']
+ labels = entry['labels']
+ mask = scores > score_threshold
+
+ filtered_scores = scores[mask]
+ filtered_boxes = entry['boxes'][mask]
+
+ mask_list = mask.tolist()
+ filtered_labels = [labels[i]
+ for i in range(len(labels)) if mask_list[i]]
+
+ filtered_entry = {
+ 'scores': filtered_scores,
+ 'labels': filtered_labels,
+ 'boxes': filtered_boxes
+ }
+ filtered_data.append(filtered_entry)
+
+ return filtered_data
+
+
+def filter_by_location(results, edge_threshold=20):
+ """Filter boxes near the left edge
+
+ Args:
+ results: Detection results
+ edge_threshold: Distance threshold from left edge (default: 20)
+
+ Returns:
+ keep: Indices of kept boxes
+ """
+ bboxes = results[0]["boxes"]
+ keep = []
+ for i in range(bboxes.shape[0]):
+ x1 = bboxes[i][0]
+ if x1 < edge_threshold:
+ continue
+ keep.append(i)
+ return torch.LongTensor(keep)
+
+
+def unpad_mask(results, masks, pad_len):
+ """Remove padding from masks and adjust boxes
+
+ Args:
+ results: Detection results
+ masks: Padded masks
+ pad_len: Padding length to remove
+
+ Returns:
+ results: Adjusted results
+ masks: Unpadded masks
+ """
+ results[0]["boxes"][:, 0] = results[0]["boxes"][:, 0] - pad_len
+ results[0]["boxes"][:, 2] = results[0]["boxes"][:, 2] - pad_len
+ for i in range(results[0]["boxes"].shape[0]):
+ if results[0]["boxes"][i][0] < 0:
+ results[0]["boxes"][i][0] += pad_len * 2
+ new_mask = torch.cat(
+ (masks[i][:, pad_len:pad_len*2], masks[i][:, :pad_len]), dim=1)
+ masks[i] = torch.cat((masks[i][:, :pad_len], new_mask), dim=1)
+ if results[0]["boxes"][i][2] < 0:
+ results[0]["boxes"][i][2] += pad_len * 2
+
+ return results, masks[:, :, pad_len:]
+
+
+def remove_small_objects(masks, min_size=1000):
+ """Remove small objects from masks
+
+ Args:
+ masks: Input masks
+ min_size: Minimum object size (default: 1000)
+
+ Returns:
+ masks: Cleaned masks
+ """
+ for i in range(masks.shape[0]):
+ masks[i] = morphology.remove_small_objects(
+ masks[i], min_size=min_size, connectivity=2)
+
+ return masks
+
+
+def remove_sky_floaters(mask, min_size=1000):
+ """Remove small disconnected regions from sky mask
+
+ Args:
+ mask: Input sky mask
+ min_size: Minimum region size (default: 1000)
+
+ Returns:
+ mask: Cleaned sky mask
+ """
+ mask = morphology.remove_small_objects(
+ mask, min_size=min_size, connectivity=2)
+
+ return mask
+
+
+def remove_disconnected_masks(masks):
+ """Remove masks with too many disconnected components
+
+ Args:
+ masks: Input masks
+
+ Returns:
+ keep: Indices of kept masks
+ """
+ keep = []
+ for i in range(masks.shape[0]):
+ binary = masks[i].astype(np.uint8) * 255
+ num, _ = cv2.connectedComponents(
+ binary, connectivity=8, ltype=cv2.CV_32S)
+ if num > 2:
+ continue
+ keep.append(i)
+ return torch.LongTensor(keep)
+
+
+def get_contours_sky(mask):
+ """Get contours of sky mask and fill them
+
+ Args:
+ mask: Input sky mask
+
+ Returns:
+ mask: Filled contour mask
+ """
+ binary = mask.astype(np.uint8) * 255
+
+ contours, _ = cv2.findContours(
+ binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ if len(contours) == 0:
+ return mask
+
+ mask = np.zeros_like(binary)
+
+ cv2.drawContours(mask, contours, -1, 1, -1)
+
+ return mask.astype(np.bool_)
+
+
+def get_fg_pad(
+ OUTPUT_DIR,
+ IMG_PATH,
+ IMG_SR_PATH,
+ zim_predictor,
+ processor,
+ grounding_model,
+ text,
+ layer,
+ scale=2,
+ is_outdoor=True
+):
+ """Process foreground layer with padding and segmentation
+
+ Args:
+ OUTPUT_DIR: Output directory
+ IMG_PATH: Input image path
+ IMG_SR_PATH: Super-resolved image path
+ zim_predictor: ZIM model predictor
+ processor: Grounding model processor
+ grounding_model: Grounding model
+ text: Text prompt for detection
+ layer: Layer identifier (0=fg1, else=fg2)
+ scale: Scaling factor (default: 2)
+ is_outdoor: Whether outdoor scene (default: True)
+ """
+ # Load and pad input image
+ image = cv2.imread(IMG_PATH, cv2.IMREAD_UNCHANGED)
+ pad_len = image.shape[1] // 2
+ image = cv2.copyMakeBorder(image, 0, 0, pad_len, 0, cv2.BORDER_WRAP)
+ image = Image.fromarray(image).convert("RGB")
+
+ # Process super-resolution image
+ image_sr = Image.open(IMG_SR_PATH)
+ H, W = image_sr.height, image_sr.width
+ image_sr = np.array(image_sr.convert("RGB"))
+ pad_len_sr = W // 2
+ image_sr_pad = cv2.copyMakeBorder(
+ image_sr, 0, 0, pad_len_sr, 0, cv2.BORDER_WRAP)
+ zim_predictor.set_image(image_sr_pad)
+
+ # Run object detection
+ inputs = processor(images=image, text=text, return_tensors="pt").to(
+ grounding_model.device)
+ with torch.no_grad():
+ outputs = grounding_model(**inputs)
+
+ # Process detection results
+ results = processor.post_process_grounded_object_detection(
+ outputs,
+ inputs.input_ids,
+ box_threshold=0.3,
+ text_threshold=0.3,
+ target_sizes=[image.size[::-1]]
+ )
+
+ saved_json = {"bboxes": []}
+
+ # Apply filters based on scene type
+ if is_outdoor:
+ results = filter_by_general_score(results, score_threshold=0.35)
+
+ location_keep = filter_by_location(results)
+ results[0]["boxes"] = results[0]["boxes"][location_keep]
+ results[0]["scores"] = results[0]["scores"][location_keep]
+ results[0]["labels"] = [results[0]["labels"][i] for i in location_keep]
+
+ # Prepare box prompts for ZIM
+ results[0]["boxes"] = results[0]["boxes"] * scale
+ filter_keep = filter_small_bboxes(results)
+ results[0]["boxes"] = results[0]["boxes"][filter_keep]
+ results[0]["scores"] = results[0]["scores"][filter_keep]
+ results[0]["labels"] = [results[0]["labels"][i] for i in filter_keep]
+ input_boxes = results[0]["boxes"].cpu().numpy()
+ if input_boxes.shape[0] == 0:
+ return
+
+ # Get masks from ZIM predictor
+ masks, scores, _ = zim_predictor.predict(
+ point_coords=None,
+ point_labels=None,
+ box=input_boxes,
+ multimask_output=False,
+ )
+ # Post-process masks
+ if masks.ndim == 4:
+ masks = masks.squeeze(1)
+
+ min_floater = 500
+ masks = masks.astype(np.bool_)
+ masks = remove_small_objects(masks, min_size=min_floater*(scale**2))
+ disconnect_keep = remove_disconnected_masks(masks)
+ masks = torch.tensor(masks).bool()[disconnect_keep]
+ results[0]["boxes"] = results[0]["boxes"][disconnect_keep]
+ results[0]["scores"] = results[0]["scores"][disconnect_keep]
+ results[0]["labels"] = [results[0]["labels"][i] for i in disconnect_keep]
+ results, masks = unpad_mask(results, masks, pad_len=pad_len_sr)
+
+ # Apply NMS
+ scores = torch.sum(masks, dim=(1, 2))
+ keep = mask_nms(masks, scores, threshold=0.5)
+ masks = masks[keep]
+ results[0]["boxes"] = results[0]["boxes"][keep]
+ results[0]["scores"] = results[0]["scores"][keep]
+ results[0]["labels"] = [results[0]["labels"][i] for i in keep]
+ if masks.shape[0] == 0:
+ return
+
+ # Create final foreground mask
+ fg_mask = np.zeros((H, W), dtype=np.uint8)
+ masks = masks.float().detach().cpu().numpy().astype(np.bool_)
+
+ cnt = 0
+ min_sum = 3000
+ name = "fg1" if layer == 0 else "fg2"
+
+ # Process each valid mask
+ for i in range(masks.shape[0]):
+ mask = masks[i]
+ if mask.sum() < min_sum*(scale**2):
+ continue
+ saved_json["bboxes"].append({
+ "label": results[0]["labels"][i],
+ "bbox": results[0]["boxes"][i].cpu().numpy().tolist(),
+ "score": results[0]["scores"][i].item(),
+ "area": int(mask.sum())
+ })
+ cnt += 1
+ fg_mask[mask] = cnt
+
+ if cnt == 0:
+ return
+
+ # Save outputs
+ with open(os.path.join(OUTPUT_DIR, f"{name}.json"), "w") as f:
+ json.dump(saved_json, f, indent=4)
+ Image.fromarray(fg_mask).save(os.path.join(OUTPUT_DIR, f"{name}_mask.png"))
+
+
+def get_fg_pad_outdoor(
+ OUTPUT_DIR,
+ IMG_PATH,
+ IMG_SR_PATH,
+ zim_predictor,
+ processor,
+ grounding_model,
+ text,
+ layer,
+ scale=2,
+):
+    """Process the foreground layer for an outdoor scene (applies the outdoor score filter)."""
+ return get_fg_pad(
+ OUTPUT_DIR,
+ IMG_PATH,
+ IMG_SR_PATH,
+ zim_predictor,
+ processor,
+ grounding_model,
+ text,
+ layer,
+        scale=scale,
+ is_outdoor=True
+ )
+
+
+def get_fg_pad_indoor(
+ OUTPUT_DIR,
+ IMG_PATH,
+ IMG_SR_PATH,
+ zim_predictor,
+ processor,
+ grounding_model,
+ text,
+ layer,
+ scale=2,
+):
+    """Process the foreground layer for an indoor scene (skips the outdoor score filter)."""
+ return get_fg_pad(
+ OUTPUT_DIR,
+ IMG_PATH,
+ IMG_SR_PATH,
+ zim_predictor,
+ processor,
+ grounding_model,
+ text,
+ layer,
+        scale=scale,
+ is_outdoor=False
+ )
+
+
+def get_sky(
+ OUTPUT_DIR,
+ IMG_PATH,
+ IMG_SR_PATH,
+ zim_predictor,
+ processor,
+ grounding_model,
+ text,
+ scale=2
+ ):
+ """Extract and process sky layer from input image
+
+ Args:
+ OUTPUT_DIR: Output directory
+ IMG_PATH: Input image path
+ IMG_SR_PATH: Super-resolved image path
+ zim_predictor: ZIM model predictor
+ processor: Grounding model processor
+ grounding_model: Grounding model
+ text: Text prompt for detection
+ scale: Scaling factor (default: 2)
+ """
+ # Load input images
+ image = Image.open(IMG_PATH).convert("RGB")
+ image_sr = Image.open(IMG_SR_PATH)
+ H, W = image_sr.height, image_sr.width
+ zim_predictor.set_image(np.array(image_sr.convert("RGB")))
+
+ # Run object detection
+ inputs = processor(images=image, text=text, return_tensors="pt").to(
+ grounding_model.device)
+ with torch.no_grad():
+ outputs = grounding_model(**inputs)
+
+ # Process detection results
+ results = processor.post_process_grounded_object_detection(
+ outputs,
+ inputs.input_ids,
+ box_threshold=0.3,
+ text_threshold=0.3,
+ target_sizes=[image.size[::-1]]
+ )
+
+ # Prepare box prompts for ZIM
+ results[0]["boxes"] = results[0]["boxes"] * scale
+ filter_keep = filter_small_bboxes(results)
+ results[0]["boxes"] = results[0]["boxes"][filter_keep]
+ results[0]["scores"] = results[0]["scores"][filter_keep]
+ results[0]["labels"] = [results[0]["labels"][i] for i in filter_keep]
+ input_boxes = results[0]["boxes"].cpu().numpy()
+
+    if input_boxes.shape[0] == 0:
+        # no sky detected in the panorama; nothing to write
+        return
+
+ # Get masks from ZIM predictor
+ masks, _, _ = zim_predictor.predict(
+ point_coords=None,
+ point_labels=None,
+ box=input_boxes,
+ multimask_output=False,
+ )
+ # Post-process masks
+ if masks.ndim == 4:
+ masks = masks.squeeze(1)
+
+ # Combine all detected masks
+ sky_mask = np.zeros((H, W), dtype=np.bool_)
+ for i in range(masks.shape[0]):
+ mask = masks[i].astype(np.bool_)
+ sky_mask[mask] = 1
+
+ # Clean up sky mask
+ min_floater = 1000
+ sky_mask = sky_mask.astype(np.bool_)
+ sky_mask = get_contours_sky(sky_mask)
+ sky_mask = 1 - sky_mask # Invert to get sky area
+ sky_mask = sky_mask.astype(np.bool_)
+ sky_mask = remove_sky_floaters(sky_mask, min_size=min_floater*(scale**2))
+ sky_mask = get_contours_sky(sky_mask)
+
+ # Save output mask
+ Image.fromarray(sky_mask.astype(np.uint8) *
+ 255).save(os.path.join(OUTPUT_DIR, "sky_mask.png"))
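
To show how the layer helpers above fit together, here is a minimal, hypothetical sketch. The `hy3dworld.utils.layer_utils` import path, the Grounding DINO checkpoint name, and the text prompts are assumptions, and the ZIM predictor is assumed to be constructed elsewhere and passed in.

```python
import torch
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

# Assumed module path for the helpers defined in this patch.
from hy3dworld.utils.layer_utils import get_fg_pad_outdoor, get_sky


def extract_outdoor_layers(output_dir, pano_path, pano_sr_path, zim_predictor,
                           grounding_ckpt="IDEA-Research/grounding-dino-tiny"):
    """Sketch: lift foreground objects and the sky out of an outdoor panorama."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Grounding DINO supplies the text-conditioned boxes that prompt ZIM.
    processor = AutoProcessor.from_pretrained(grounding_ckpt)
    grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(
        grounding_ckpt).to(device)

    # Layer 0 ("fg1"): the text prompt lists the classes to detect,
    # one phrase per class (prompt wording is illustrative).
    get_fg_pad_outdoor(output_dir, pano_path, pano_sr_path,
                       zim_predictor, processor, grounding_model,
                       text="stones. trees.", layer=0, scale=2)

    # Sky layer: writes sky_mask.png alongside the fg1 outputs.
    get_sky(output_dir, pano_path, pano_sr_path,
            zim_predictor, processor, grounding_model,
            text="sky.", scale=2)
```
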
diff --git a/hy3dworld/utils/sr_utils.py b/hy3dworld/utils/sr_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..eecc062aa864eb742846960200b2f03ec08981d0
--- /dev/null
+++ b/hy3dworld/utils/sr_utils.py
@@ -0,0 +1,103 @@
+import os
+import cv2
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from basicsr.utils.download_util import load_file_from_url
+
+from realesrgan import RealESRGANer
+
+
+# build sr model
+def build_sr_model(scale=2, model_name=None, tile=0, tile_pad=10, pre_pad=0, fp32=False, gpu_id=None):
+ # if model_name not specified, use default mapping
+ if model_name is None:
+ if scale == 2:
+ model_name = 'RealESRGAN_x2plus'
+ else:
+ model_name = 'RealESRGAN_x4plus'
+
+ # model architecture configs
+ model_configs = {
+ 'RealESRGAN_x2plus': {
+ 'internal_scale': 2,
+ 'model': lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2),
+ 'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth'
+ },
+ 'RealESRGAN_x4plus': {
+ 'internal_scale': 4,
+ 'model': lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
+ 'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
+ }
+ }
+
+ if model_name not in model_configs:
+ raise ValueError(
+ f'Unknown model name: {model_name}. Available models: {list(model_configs.keys())}')
+
+ config = model_configs[model_name]
+ model = config['model']()
+ file_url = [config['url']]
+
+ model_path = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), 'weights', model_name + '.pth')
+ if not os.path.isfile(model_path):
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+ for url in file_url:
+ # model_path will be updated
+ model_path = load_file_from_url(
+ url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
+
+ # restorer
+ upsampler = RealESRGANer(
+ scale=config['internal_scale'], # Use the internal scale of the model
+ model_path=model_path,
+ dni_weight=None,
+ model=model,
+ tile=tile,
+ tile_pad=tile_pad,
+ pre_pad=pre_pad,
+ half=not fp32,
+ gpu_id=gpu_id)
+
+ return upsampler
+
+
+# sr inference code
+def sr_inference(input, output_path, upsampler, scale=2, ext='auto', suffix='sr'):
+ os.makedirs(output_path, exist_ok=True)
+
+ path = input
+ imgname, extension = os.path.splitext(os.path.basename(path))
+
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+ width = img.shape[1]
+
+    # wrap-pad the image to eliminate border artifacts at the panorama seam
+ pad_len = width // 4
+ img = cv2.copyMakeBorder(img, 0, 0, pad_len, pad_len, cv2.BORDER_WRAP)
+ if len(img.shape) == 3 and img.shape[2] == 4:
+ img_mode = 'RGBA'
+ else:
+ img_mode = None
+
+ try:
+ output, _ = upsampler.enhance(
+ img, outscale=scale) # Use the input scale as the final output amplification factor
+ # remove the padding
+ output = output[:, int(pad_len*scale):int((width+pad_len)*scale), :]
+ except RuntimeError as error:
+ print('Error', error)
+ print(
+ 'If you encounter CUDA out of memory, try to set --tile with a smaller number.')
+ else:
+ if ext == 'auto':
+ extension = extension[1:]
+ else:
+ extension = ext
+ if img_mode == 'RGBA': # RGBA images should be saved in png format
+ extension = 'png'
+ if suffix == '':
+ save_path = os.path.join(output_path, f'{imgname}.{extension}')
+ else:
+ save_path = os.path.join(
+ output_path, f'{imgname}_{suffix}.{extension}')
+ cv2.imwrite(save_path, output)
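
For reference, a minimal sketch of how `build_sr_model` and `sr_inference` chain together; the input path is a placeholder, and with the defaults the RealESRGAN_x2plus weights are downloaded into a `weights` folder next to `sr_utils.py` on first use and inference runs in half precision.

```python
from hy3dworld.utils.sr_utils import build_sr_model, sr_inference

# Build the x2 upsampler (weights are fetched on the first call).
upsampler = build_sr_model(scale=2)

# Upscale a panorama; with suffix="sr" this writes panorama_sr.png into the
# output directory, cropping away the wrap padding added around the seam.
sr_inference("test_results/case1/panorama.png",   # placeholder input path
             "test_results/case1",                # output directory
             upsampler, scale=2, suffix="sr")
```
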
diff --git a/modelviewer.html b/modelviewer.html
new file mode 100644
index 0000000000000000000000000000000000000000..38be4152bf74a69c6bf253e051c1cdc560c8b6be
--- /dev/null
+++ b/modelviewer.html
@@ -0,0 +1,382 @@
+<!-- Simple PLY Viewer: the page's HTML/JS markup (382 added lines) is not recoverable here.
+     Visible UI text: a "Loading..." overlay and the control hint
+     "Controls: WASD to move, Mouse drag to look around". -->
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..046bda82652486d0cbe1d5329c10299420578f1c
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,192 @@
+--extra-index-url https://download.pytorch.org/whl/cu121
+torch==2.2.2
+torchaudio==2.2.2
+torchvision==0.17.2
+
+# From conda environment
+absl-py==2.2.2
+accelerate==1.6.0
+addict==2.4.0
+aiofiles==24.1.0
+aiohappyeyeballs==2.6.1
+aiohttp==3.11.16
+aiosignal==1.3.1
+albumentations==0.5.2
+annotated-types==0.7.0
+antlr4-python3-runtime==4.8
+anyio==4.9.0
+asttokens==3.0.0
+async-timeout==5.0.1
+attrs==25.3.0
+av==14.3.0
+basicsr==1.4.2
+blinker==1.9.0
+braceexpand==0.1.7
+click==8.2.1
+cloudpickle==3.1.1
+cmake==4.0.3
+colorama==0.4.6
+coloredlogs==15.0.1
+configargparse==1.7.1
+contourpy==1.3.2
+cycler==0.12.1
+cython==3.0.11
+dash==3.1.1
+diffdist==0.1
+diffusers==0.32.0
+easydict==1.9
+einops==0.4.1
+eva-decord==0.6.1
+exceptiongroup==1.3.0
+executing==2.2.0
+facexlib==0.3.0
+fastapi==0.116.1
+fastjsonschema==2.21.1
+ffmpy==0.6.1
+filterpy==1.4.5
+flask==3.1.1
+flatbuffers==25.2.10
+fonttools==4.57.0
+frozenlist==1.6.0
+fsspec==2025.3.2
+ftfy==6.1.1
+future==1.0.0
+gfpgan==1.3.8
+glcontext==3.0.0
+gradio==4.42.0
+gradio-client==1.11.0
+groovy==0.1.2
+grpcio==1.71.0
+h11==0.16.0
+h5py==3.7.0
+httpcore==1.0.9
+httpx==0.28.1
+huggingface-hub==0.30.2
+humanfriendly==10.0
+hydra-core==1.1.0
+icecream==2.1.2
+imageio==2.37.0
+imageio-ffmpeg==0.4.9
+imgaug==0.4.0
+importlib-metadata==8.6.1
+inflect==5.6.0
+iopath==0.1.10
+itsdangerous==2.2.0
+joblib==1.4.2
+jsonschema==4.25.0
+jsonschema-specifications==2025.4.1
+jupyter-core==5.8.1
+kiwisolver==1.4.8
+kornia==0.8.0
+kornia-rs==0.1.8
+lazy-loader==0.4
+lightning-utilities==0.14.3
+llvmlite==0.44.0
+lmdb==1.6.2
+loguru==0.7.3
+markdown==3.8
+markdown-it-py==3.0.0
+matplotlib==3.10.1
+mdurl==0.1.2
+moderngl==5.12.0
+moge==2.0.0
+multidict==6.4.3
+narwhals==1.48.1
+natten==0.14.4
+nbformat==5.10.4
+nest-asyncio==1.6.0
+numba==0.61.2
+numpy==1.26.4
+omegaconf==2.1.2
+onnx==1.17.0
+onnxruntime==1.21.1
+open-clip-torch==2.30.0
+open3d==0.18.0
+opencv-python==4.11.0.86
+opencv-python-headless==4.11.0.86
+orjson==3.11.1
+packaging==24.2
+pandas==2.2.3
+peft==0.14.0
+platformdirs==4.3.7
+plotly==6.2.0
+plyfile==1.1
+portalocker==3.2.0
+propcache==0.3.1
+protobuf==5.29.3
+psutil==7.0.0
+py-cpuinfo==9.0.0
+py360convert==1.0.3
+pydantic==2.11.7
+pydantic-core==2.33.2
+pydub==0.25.1
+pygments==2.19.1
+pyparsing==3.2.3
+pyquaternion==0.9.9
+python-dateutil==2.9.0.post0
+python-multipart==0.0.20
+pytorch-lightning==2.4.0
+pytz==2025.2
+qwen-vl-utils==0.0.8
+referencing==0.36.2
+regex==2022.6.2
+retrying==1.4.1
+rich==14.0.0
+rpds-py==0.26.0
+ruff==0.12.5
+safehttpx==0.1.6
+safetensors==0.5.3
+scikit-image==0.24.0
+scikit-learn==1.6.1
+scipy==1.15.2
+seaborn==0.13.2
+segment-anything==1.0
+semantic-version==2.10.0
+sentencepiece==0.2.0
+setuptools==59.5.0
+shapely==2.0.7
+shellingham==1.5.4
+six==1.17.0
+sniffio==1.3.1
+starlette==0.47.2
+submitit==1.4.2
+sympy==1.13.1
+tabulate==0.9.0
+tb-nightly==2.20.0a20250421
+tensorboard-data-server==0.7.2
+termcolor==3.0.1
+threadpoolctl==3.6.0
+tifffile==2025.3.30
+timm==1.0.13
+tokenizers==0.21.1
+tomli==2.2.1
+tomlkit==0.13.3
+torchmetrics==1.7.1
+tqdm==4.67.1
+traitlets==5.14.3
+transformers==4.51.0
+trimesh==4.7.1
+typer==0.16.0
+typing-inspection==0.4.1
+tzdata==2025.2
+ultralytics==8.3.74
+ultralytics-thop==2.0.14
+utils3d==0.0.2
+uvicorn==0.35.0
+wcwidth==0.2.13
+webdataset==0.2.100
+websockets==15.0.1
+werkzeug==3.1.3
+wldhx-yadisk-direct==0.0.6
+yapf==0.43.0
+yarl==1.20.0
+zipp==3.21.0
+
+# GPU specific
+flash-attn==2.7.4.post1
+triton==3.2.0
+xformers==0.0.28.post2
+
+# From github
+-e git+https://github.com/facebookresearch/pytorch3d.git@v0.7.6#egg=pytorch3d
+-e git+https://github.com/microsoft/MoGe.git#egg=moge
diff --git a/scripts/test.sh b/scripts/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8e8fb588d343e5ccc43b16c8e2098acb7c916bc8
--- /dev/null
+++ b/scripts/test.sh
@@ -0,0 +1,26 @@
+CUDA_VISIBLE_DEVICES=0 python3 demo_panogen.py --prompt "" --image_path examples/case1/input.png --output_path test_results/case1
+CUDA_VISIBLE_DEVICES=0 python3 demo_scenegen.py --image_path test_results/case1/panorama.png --classes outdoor --output_path test_results/case1
+
+CUDA_VISIBLE_DEVICES=0 python3 demo_panogen.py --prompt "" --image_path examples/case2/input.png --output_path test_results/case2
+CUDA_VISIBLE_DEVICES=0 python3 demo_scenegen.py --image_path test_results/case2/panorama.png --labels_fg1 stones --labels_fg2 trees --classes outdoor --output_path test_results/case2
+
+CUDA_VISIBLE_DEVICES=0 python3 demo_panogen.py --prompt "" --image_path examples/case3/input.png --output_path test_results/case3
+CUDA_VISIBLE_DEVICES=0 python3 demo_scenegen.py --image_path test_results/case3/panorama.png --classes outdoor --output_path test_results/case3
+
+CUDA_VISIBLE_DEVICES=0 python3 demo_panogen.py --prompt "There is a rocky island on the vast sea surface, with a triangular rock burning red flames in the center of the island. The sea is open and rough, with a green surface. Surrounded by towering peaks in the distance." --output_path test_results/case4
+CUDA_VISIBLE_DEVICES=0 python3 demo_scenegen.py --image_path test_results/case4/panorama.png --classes outdoor --output_path test_results/case4
+
+CUDA_VISIBLE_DEVICES=0 python3 demo_panogen.py --prompt "" --image_path examples/case5/input.png --output_path test_results/case5
+CUDA_VISIBLE_DEVICES=0 python3 demo_scenegen.py --image_path test_results/case5/panorama.png --classes outdoor --output_path test_results/case5
+
+CUDA_VISIBLE_DEVICES=0 python3 demo_panogen.py --prompt "" --image_path examples/case6/input.png --output_path test_results/case6
+CUDA_VISIBLE_DEVICES=0 python3 demo_scenegen.py --image_path test_results/case6/panorama.png --labels_fg1 tent --classes outdoor --output_path test_results/case6
+
+CUDA_VISIBLE_DEVICES=0 python3 demo_panogen.py --prompt "At the moment of glacier collapse, giant ice walls collapse and create waves, with no wildlife, captured in a disaster documentary" --output_path test_results/case7
+CUDA_VISIBLE_DEVICES=0 python3 demo_scenegen.py --image_path test_results/case7/panorama.png --classes outdoor --output_path test_results/case7
+
+CUDA_VISIBLE_DEVICES=0 python3 demo_panogen.py --prompt "" --image_path examples/case8/input.png --output_path test_results/case8
+CUDA_VISIBLE_DEVICES=0 python3 demo_scenegen.py --image_path test_results/case8/panorama.png --classes outdoor --output_path test_results/case8
+
+CUDA_VISIBLE_DEVICES=0 python3 demo_panogen.py --prompt "A breathtaking volcanic eruption scene. In the center of the screen, one or more volcanoes are erupting violently, with hot orange red lava gushing out from the crater, illuminating the surrounding night sky and landscape. Thick smoke and volcanic ash rose into the sky, forming a huge mushroom cloud like structure. Some of the smoke and dust were reflected in a dark red color by the high temperature of the lava, creating a doomsday atmosphere. In the foreground, a winding lava flow flows through the dark and rough rocks like a fire snake, emitting a dazzling light as if burning the earth. The steep and rugged mountains in the background further emphasize the ferocity and irresistible power of nature. The entire picture has a strong contrast of light and shadow, with red, black, and gray as the main colors, highlighting the visual impact and dramatic tension of volcanic eruptions, making people feel the grandeur and terror of nature." --output_path test_results/case9
+CUDA_VISIBLE_DEVICES=0 python3 demo_scenegen.py --image_path test_results/case9/panorama.png --classes outdoor --output_path test_results/case9
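
The same cases can also be driven from Python rather than bash; below is a hypothetical batch driver, with the case list and flags taken from the script above and everything else illustrative.

```python
import os
import subprocess

# Two representative cases from scripts/test.sh: one image-conditioned, one text-conditioned.
CASES = [
    {"name": "case1", "image": "examples/case1/input.png", "prompt": ""},
    {"name": "case7", "image": None,
     "prompt": ("At the moment of glacier collapse, giant ice walls collapse and create "
                "waves, with no wildlife, captured in a disaster documentary")},
]

env = {**os.environ, "CUDA_VISIBLE_DEVICES": "0"}
for case in CASES:
    out_dir = f"test_results/{case['name']}"
    pano_cmd = ["python3", "demo_panogen.py", "--prompt", case["prompt"],
                "--output_path", out_dir]
    if case["image"]:
        pano_cmd += ["--image_path", case["image"]]
    subprocess.run(pano_cmd, check=True, env=env)            # stage 1: panorama generation
    subprocess.run(["python3", "demo_scenegen.py",           # stage 2: layered 3D scene
                    "--image_path", f"{out_dir}/panorama.png",
                    "--classes", "outdoor",
                    "--output_path", out_dir], check=True, env=env)
```
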