Create app.py
app.py
ADDED
@@ -0,0 +1,334 @@
import json
import os
import random
import tempfile
from typing import Any, List, Union

import gradio as gr
import numpy as np
import spaces
import torch
import trimesh
from gradio_image_prompter import ImagePrompter
from gradio_litmodel3d import LitModel3D
from huggingface_hub import snapshot_download
from PIL import Image
from skimage import measure
from transformers import AutoModelForMaskGeneration, AutoProcessor

from midi.pipelines.pipeline_midi import MIDIPipeline
from midi.utils.smoothing import smooth_gpu
from scripts.grounding_sam import plot_segmentation, segment
from scripts.inference_midi import preprocess_image, split_rgb_mask

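# The midi.* and scripts.* imports are local modules from the MIDI-3D codebase,
# presumably bundled with this Space; the remaining imports are pip-installable.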
# Constants
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
DTYPE = torch.bfloat16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
REPO_ID = "VAST-AI/MIDI-3D"

MARKDOWN = """
## Image to 3D Scene with [MIDI-3D](https://huanngzh.github.io/MIDI-Page/)
<b>Important!</b> Please check out our [instruction video](https://github.com/user-attachments/assets/814c046e-f5c3-47cf-bb56-60154be8374c)!
1. Upload an image, and draw a bounding box for each instance by holding and dragging the mouse. Then click "Run Segmentation" to generate the segmentation result. <b>Ensure that instances are not too small and that the bounding boxes fit snugly around each instance.</b>
2. <b>Check "Do image padding" in "Generation Settings" if instances in your image are too close to the image border.</b> Then click "Run Generation" to generate a 3D scene from the image and segmentation result.
3. If you find the generated 3D scene satisfactory, download it by clicking the "Download GLB" button.
"""

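# Each EXAMPLES row matches the gr.Examples inputs declared in the UI below:
# (ImagePrompter payload, segmentation-map path, seed, randomize_seed, do_image_padding).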
EXAMPLES = [
    [
        {
            "image": "assets/example_data/Cartoon-Style/03_rgb.png",
        },
        "assets/example_data/Cartoon-Style/03_seg.png",
        42,
        False,
        False,
    ],
    [
        {
            "image": "assets/example_data/Cartoon-Style/01_rgb.png",
        },
        "assets/example_data/Cartoon-Style/01_seg.png",
        42,
        False,
        False,
    ],
    [
        {
            "image": "assets/example_data/Realistic-Style/02_rgb.png",
        },
        "assets/example_data/Realistic-Style/02_seg.png",
        42,
        False,
        False,
    ],
    [
        {
            "image": "assets/example_data/Cartoon-Style/00_rgb.png",
        },
        "assets/example_data/Cartoon-Style/00_seg.png",
        42,
        False,
        False,
    ],
    [
        {
            "image": "assets/example_data/Realistic-Style/00_rgb.png",
        },
        "assets/example_data/Realistic-Style/00_seg.png",
        42,
        False,
        True,
    ],
    [
        {
            "image": "assets/example_data/Realistic-Style/01_rgb.png",
        },
        "assets/example_data/Realistic-Style/01_seg.png",
        42,
        False,
        True,
    ],
    [
        {
            "image": "assets/example_data/Realistic-Style/05_rgb.png",
        },
        "assets/example_data/Realistic-Style/05_seg.png",
        42,
        False,
        False,
    ],
]

os.makedirs(TMP_DIR, exist_ok=True)

# Prepare models
## Grounding SAM
segmenter_id = "facebook/sam-vit-base"
sam_processor = AutoProcessor.from_pretrained(segmenter_id)
sam_segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(
    DEVICE, DTYPE
)
## MIDI-3D
local_dir = "pretrained_weights/MIDI-3D"
snapshot_download(repo_id=REPO_ID, local_dir=local_dir)
pipe: MIDIPipeline = MIDIPipeline.from_pretrained(local_dir).to(DEVICE, DTYPE)
pipe.init_custom_adapter(
    set_self_attn_module_names=[
        "blocks.8",
        "blocks.9",
        "blocks.10",
        "blocks.11",
        "blocks.12",
    ]
)

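# Note: pipe.init_custom_adapter above patches self-attention in DiT blocks 8-12;
# presumably this is where MIDI's multi-instance attention lets the per-instance
# latents attend to one another during denoising (cf. the
# attention_kwargs={"num_instances": ...} argument passed to the pipeline below).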
# Utils
def get_random_hex():
    random_bytes = os.urandom(8)
    random_hex = random_bytes.hex()
    return random_hex

@spaces.GPU()
@torch.no_grad()
@torch.autocast(device_type=DEVICE, dtype=torch.bfloat16)
def run_segmentation(image_prompts: Any, polygon_refinement: bool) -> Image.Image:
    rgb_image = image_prompts["image"].convert("RGB")

    # pre-process the layers and get the xyxy boxes of each layer
    if len(image_prompts["points"]) == 0:
        # raise (not just construct) the error so Gradio surfaces it in the UI
        raise gr.Error("Please draw bounding boxes for each instance on the image.")
    # indices 0/1 and 3/4 of each ImagePrompter point hold the two box corners
    boxes = [
        [int(box[0]), int(box[1]), int(box[3]), int(box[4])]
        for box in image_prompts["points"]
    ]

    # run the segmentation
    detections = segment(
        sam_processor,
        sam_segmentator,
        rgb_image,
        boxes=[boxes],  # one box list per image
        polygon_refinement=polygon_refinement,
    )
    seg_map_pil = plot_segmentation(rgb_image, detections)

    torch.cuda.empty_cache()

    return seg_map_pil

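# run_midi feeds the MIDI pipeline the per-instance crops produced by
# split_rgb_mask; with decode_progressive=True and return_dict=False it appears
# to return raw per-instance grids rather than finished meshes (see the
# marching-cubes step in run_generation).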
@torch.no_grad()
def run_midi(
    pipe: Any,
    rgb_image: Union[str, Image.Image],
    seg_image: Union[str, Image.Image],
    seed: int,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    do_image_padding: bool = False,
) -> Any:  # raw pipeline outputs, not yet a trimesh.Scene
    if do_image_padding:
        rgb_image, seg_image = preprocess_image(rgb_image, seg_image)
    instance_rgbs, instance_masks, scene_rgbs = split_rgb_mask(rgb_image, seg_image)

    num_instances = len(instance_rgbs)
    outputs = pipe(
        image=instance_rgbs,
        mask=instance_masks,
        image_scene=scene_rgbs,
        attention_kwargs={"num_instances": num_instances},
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        decode_progressive=True,
        return_dict=False,
    )

    return outputs

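# The pipeline outputs unpack per instance into
# (grid logits, grid size, bbox size, bbox min, bbox max); run_generation below
# reconstructs each instance with marching cubes and rescales the vertices from
# grid-index coordinates back into scene coordinates.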
@spaces.GPU(duration=180)
@torch.no_grad()
@torch.autocast(device_type=DEVICE, dtype=torch.bfloat16)
def run_generation(
    rgb_image: Any,
    seg_image: Union[str, Image.Image],
    seed: int,
    randomize_seed: bool = False,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    do_image_padding: bool = False,
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    if not isinstance(rgb_image, Image.Image) and "image" in rgb_image:
        rgb_image = rgb_image["image"]

    outputs = run_midi(
        pipe,
        rgb_image,
        seg_image,
        seed,
        num_inference_steps,
        guidance_scale,
        do_image_padding,
    )

    # marching cubes
    trimeshes = []
    for logits_, grid_size, bbox_size, bbox_min, bbox_max in zip(*outputs):
        grid_logits = logits_.view(grid_size)
        grid_logits = smooth_gpu(grid_logits, method="gaussian", sigma=1)
        torch.cuda.empty_cache()
        vertices, faces, normals, _ = measure.marching_cubes(
            grid_logits.float().cpu().numpy(), 0, method="lewiner"
        )
        # map vertices from grid-index space into the instance's world-space bbox
        vertices = vertices / grid_size * bbox_size + bbox_min

        # Trimesh
        mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
        trimeshes.append(mesh)

    # compose the output meshes
    scene = trimesh.Scene(trimeshes)

    tmp_path = os.path.join(TMP_DIR, f"midi3d_{get_random_hex()}.glb")
    scene.export(tmp_path)

    torch.cuda.empty_cache()

    return tmp_path, tmp_path, seed

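# run_generation returns tmp_path twice on purpose: the same GLB path feeds
# both the LitModel3D viewer and the Download button in the UI below.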
# Demo
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                image_prompts = ImagePrompter(label="Input Image", type="pil")
                seg_image = gr.Image(
                    label="Segmentation Result", type="pil", format="png"
                )

            with gr.Accordion("Segmentation Settings", open=False):
                polygon_refinement = gr.Checkbox(
                    label="Polygon Refinement", value=False
                )
            seg_button = gr.Button("Run Segmentation")

            with gr.Accordion("Generation Settings", open=False):
                do_image_padding = gr.Checkbox(label="Do image padding", value=False)
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=50,
                )
                guidance_scale = gr.Slider(
                    label="CFG scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.0,
                )
            gen_button = gr.Button("Run Generation", variant="primary")

        with gr.Column():
            model_output = LitModel3D(label="Generated GLB", exposure=1.0, height=500)
            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

    with gr.Row():
        gr.Examples(
            examples=EXAMPLES,
            fn=run_generation,
            inputs=[image_prompts, seg_image, seed, randomize_seed, do_image_padding],
            outputs=[model_output, download_glb, seed],
            cache_examples=False,
        )

    seg_button.click(
        run_segmentation,
        inputs=[
            image_prompts,
            polygon_refinement,
        ],
        outputs=[seg_image],
    ).then(lambda: gr.Button(interactive=True), outputs=[gen_button])

    gen_button.click(
        run_generation,
        inputs=[
            image_prompts,
            seg_image,
            seed,
            randomize_seed,
            num_inference_steps,
            guidance_scale,
            do_image_padding,
        ],
        outputs=[model_output, download_glb, seed],
    ).then(lambda: gr.Button(interactive=True), outputs=[download_glb])


demo.launch()