olivercareyncl committed · Commit 990b457 · verified · 1 Parent(s): 6e3f894

Create app.py

Files changed (1): app.py (+334, −0)
app.py ADDED
@@ -0,0 +1,334 @@
import json
import os
import random
import tempfile
from typing import Any, List, Union

import gradio as gr
import numpy as np
import spaces
import torch
import trimesh
from gradio_image_prompter import ImagePrompter
from gradio_litmodel3d import LitModel3D
from huggingface_hub import snapshot_download
from PIL import Image
from skimage import measure
from transformers import AutoModelForMaskGeneration, AutoProcessor

from midi.pipelines.pipeline_midi import MIDIPipeline
from midi.utils.smoothing import smooth_gpu
from scripts.grounding_sam import plot_segmentation, segment
from scripts.inference_midi import preprocess_image, split_rgb_mask

# Constants
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
DTYPE = torch.bfloat16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
REPO_ID = "VAST-AI/MIDI-3D"

MARKDOWN = """
## Image to 3D Scene with [MIDI-3D](https://huanngzh.github.io/MIDI-Page/)
<b>Important!</b> Please check out our [instruction video](https://github.com/user-attachments/assets/814c046e-f5c3-47cf-bb56-60154be8374c)!
1. Upload an image, and draw bounding boxes for each instance by holding and dragging the mouse. Then click "Run Segmentation" to generate the segmentation result. <b>Ensure that instances are not too small and that bounding boxes fit snugly around each instance.</b>
2. <b>Check "Do image padding" in "Generation Settings" if instances in your image are too close to the image border.</b> Then click "Run Generation" to generate a 3D scene from the image and segmentation result.
3. If you find the generated 3D scene satisfactory, download it by clicking the "Download GLB" button.
"""
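
# Each example row matches the gr.Examples inputs wired up below:
# [ImagePrompter dict, segmentation map, seed, randomize_seed, do_image_padding]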
EXAMPLES = [
    [
        {"image": "assets/example_data/Cartoon-Style/03_rgb.png"},
        "assets/example_data/Cartoon-Style/03_seg.png",
        42,
        False,
        False,
    ],
    [
        {"image": "assets/example_data/Cartoon-Style/01_rgb.png"},
        "assets/example_data/Cartoon-Style/01_seg.png",
        42,
        False,
        False,
    ],
    [
        {"image": "assets/example_data/Realistic-Style/02_rgb.png"},
        "assets/example_data/Realistic-Style/02_seg.png",
        42,
        False,
        False,
    ],
    [
        {"image": "assets/example_data/Cartoon-Style/00_rgb.png"},
        "assets/example_data/Cartoon-Style/00_seg.png",
        42,
        False,
        False,
    ],
    [
        {"image": "assets/example_data/Realistic-Style/00_rgb.png"},
        "assets/example_data/Realistic-Style/00_seg.png",
        42,
        False,
        True,
    ],
    [
        {"image": "assets/example_data/Realistic-Style/01_rgb.png"},
        "assets/example_data/Realistic-Style/01_seg.png",
        42,
        False,
        True,
    ],
    [
        {"image": "assets/example_data/Realistic-Style/05_rgb.png"},
        "assets/example_data/Realistic-Style/05_seg.png",
        42,
        False,
        False,
    ],
]

os.makedirs(TMP_DIR, exist_ok=True)

# Prepare models
## Grounding SAM
segmenter_id = "facebook/sam-vit-base"
sam_processor = AutoProcessor.from_pretrained(segmenter_id)
sam_segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(
    DEVICE, DTYPE
)
## MIDI-3D
local_dir = "pretrained_weights/MIDI-3D"
snapshot_download(repo_id=REPO_ID, local_dir=local_dir)
pipe: MIDIPipeline = MIDIPipeline.from_pretrained(local_dir).to(DEVICE, DTYPE)
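# The custom adapter patches the listed self-attention blocks for
# multi-instance attention (per the MIDI-3D paper), so the instances of a
# scene are denoised jointly rather than one by one.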
pipe.init_custom_adapter(
    set_self_attn_module_names=[
        "blocks.8",
        "blocks.9",
        "blocks.10",
        "blocks.11",
        "blocks.12",
    ]
)


# Utils
def get_random_hex():
    random_bytes = os.urandom(8)
    random_hex = random_bytes.hex()
    return random_hex


@spaces.GPU()
@torch.no_grad()
@torch.autocast(device_type=DEVICE, dtype=torch.bfloat16)
def run_segmentation(image_prompts: Any, polygon_refinement: bool) -> Image.Image:
    rgb_image = image_prompts["image"].convert("RGB")

    # pre-process the layers and get the xyxy boxes of each layer
    if len(image_prompts["points"]) == 0:
        raise gr.Error("Please draw bounding boxes for each instance on the image.")
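    # Note: each ImagePrompter "point" appears to encode a drawn box as
    # [x1, y1, flag, x2, y2, flag], so indices 0, 1, 3, 4 are the box corners.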
    boxes = [
        [int(box[0]), int(box[1]), int(box[3]), int(box[4])]
        for box in image_prompts["points"]
    ]

    # run the segmentation (boxes are wrapped once more below to give the
    # per-image nesting the SAM processor expects)
    detections = segment(
        sam_processor,
        sam_segmentator,
        rgb_image,
        boxes=[boxes],
        polygon_refinement=polygon_refinement,
    )
    seg_map_pil = plot_segmentation(rgb_image, detections)

    torch.cuda.empty_cache()

    return seg_map_pil


@torch.no_grad()
def run_midi(
    pipe: Any,
    rgb_image: Union[str, Image.Image],
    seg_image: Union[str, Image.Image],
    seed: int,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    do_image_padding: bool = False,
) -> trimesh.Scene:
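    # Optional padding helps when instances sit close to the image border;
    # the scene is then split into per-instance RGB crops and masks.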
    if do_image_padding:
        rgb_image, seg_image = preprocess_image(rgb_image, seg_image)
    instance_rgbs, instance_masks, scene_rgbs = split_rgb_mask(rgb_image, seg_image)

    num_instances = len(instance_rgbs)
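    # All instances are denoised in one batch; attention_kwargs tells the
    # multi-instance attention blocks how many instances attend to each other,
    # and return_dict=False yields raw logit grids plus bounding-box metadata
    # for the marching-cubes extraction in run_generation.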
    outputs = pipe(
        image=instance_rgbs,
        mask=instance_masks,
        image_scene=scene_rgbs,
        attention_kwargs={"num_instances": num_instances},
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        decode_progressive=True,
        return_dict=False,
    )

    return outputs


@spaces.GPU(duration=180)
@torch.no_grad()
@torch.autocast(device_type=DEVICE, dtype=torch.bfloat16)
def run_generation(
    rgb_image: Any,
    seg_image: Union[str, Image.Image],
    seed: int,
    randomize_seed: bool = False,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    do_image_padding: bool = False,
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    if not isinstance(rgb_image, Image.Image) and "image" in rgb_image:
        rgb_image = rgb_image["image"]

    outputs = run_midi(
        pipe,
        rgb_image,
        seg_image,
        seed,
        num_inference_steps,
        guidance_scale,
        do_image_padding,
    )

    # marching cubes
    trimeshes = []
    for logits_, grid_size, bbox_size, bbox_min, bbox_max in zip(*outputs):
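        # Per instance: reshape the flat logits into a dense grid, smooth to
        # reduce staircase artifacts, extract the zero isosurface, and map
        # grid coordinates back into the instance's world-space bounding box.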
        grid_logits = logits_.view(grid_size)
        grid_logits = smooth_gpu(grid_logits, method="gaussian", sigma=1)
        torch.cuda.empty_cache()
        vertices, faces, normals, _ = measure.marching_cubes(
            grid_logits.float().cpu().numpy(), 0, method="lewiner"
        )
        vertices = vertices / grid_size * bbox_size + bbox_min

        # Trimesh
        mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
        trimeshes.append(mesh)

    # compose the output meshes
    scene = trimesh.Scene(trimeshes)

    tmp_path = os.path.join(TMP_DIR, f"midi3d_{get_random_hex()}.glb")
    scene.export(tmp_path)

    torch.cuda.empty_cache()

    return tmp_path, tmp_path, seed


# Demo
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                image_prompts = ImagePrompter(label="Input Image", type="pil")
                seg_image = gr.Image(
                    label="Segmentation Result", type="pil", format="png"
                )

            with gr.Accordion("Segmentation Settings", open=False):
                polygon_refinement = gr.Checkbox(
                    label="Polygon Refinement", value=False
                )
            seg_button = gr.Button("Run Segmentation")

            with gr.Accordion("Generation Settings", open=False):
                do_image_padding = gr.Checkbox(label="Do image padding", value=False)
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=50,
                )
                guidance_scale = gr.Slider(
                    label="CFG scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.0,
                )
            gen_button = gr.Button("Run Generation", variant="primary")

        with gr.Column():
            model_output = LitModel3D(label="Generated GLB", exposure=1.0, height=500)
            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
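
    # With cache_examples=False, clicking an example below only fills the
    # input components; it does not run generation.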
    with gr.Row():
        gr.Examples(
            examples=EXAMPLES,
            fn=run_generation,
            inputs=[image_prompts, seg_image, seed, randomize_seed, do_image_padding],
            outputs=[model_output, download_glb, seed],
            cache_examples=False,
        )

    seg_button.click(
        run_segmentation,
        inputs=[
            image_prompts,
            polygon_refinement,
        ],
        outputs=[seg_image],
    ).then(lambda: gr.Button(interactive=True), outputs=[gen_button])
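
    # run_generation returns (glb_path, glb_path, seed): the same GLB path
    # feeds both the 3D viewer and the download button, and the possibly
    # randomized seed is written back to the seed slider.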
    gen_button.click(
        run_generation,
        inputs=[
            image_prompts,
            seg_image,
            seed,
            randomize_seed,
            num_inference_steps,
            guidance_scale,
            do_image_padding,
        ],
        outputs=[model_output, download_glb, seed],
    ).then(lambda: gr.DownloadButton(interactive=True), outputs=[download_glb])


demo.launch()