Sm0kyWu committed
Commit ee77a14 · verified · 1 Parent(s): 00b4b8e

Upload 71 files

Files changed (2):
  1. Amodal3R/pipelines/image_to_3d.py  +2 -1
  2. app.py  +81 -22
Amodal3R/pipelines/image_to_3d.py CHANGED

@@ -377,6 +377,7 @@ class Amodal3RImageTo3DPipeline(Pipeline):
         slat_sampler_params: dict = {},
         formats: List[str] = ['mesh', 'gaussian'],
         mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
+        erode_kernel_size: int = 3,
     ) -> dict:
         """
         Run the pipeline with multiple images as condition
@@ -388,7 +389,7 @@ class Amodal3RImageTo3DPipeline(Pipeline):
             slat_sampler_params (dict): Additional parameters for the structured latent sampler.
             preprocess_image (bool): Whether to preprocess the image.
         """
-        images, masks, masks_occ = zip(*[self.preprocess_image_w_mask(image, mask) for image, mask in zip(images, masks)])
+        images, masks, masks_occ = zip(*[self.preprocess_image_w_mask(image, mask, erode_kernel_size) for image, mask in zip(images, masks)])
         images = list(images)
         masks = list(masks)
         masks_occ = list(masks_occ)
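The new erode_kernel_size argument is forwarded to preprocess_image_w_mask, whose implementation is not shown in this diff. Below is a minimal sketch of the kind of mask erosion such a parameter typically controls; the helper name and the use of OpenCV are assumptions for illustration, not code from this commit:

import cv2
import numpy as np

def erode_visibility_mask(mask: np.ndarray, erode_kernel_size: int) -> np.ndarray:
    """Shrink a binary visibility mask so that inaccurate segmentation
    boundaries are not treated as reliably visible pixels."""
    if erode_kernel_size <= 0:
        # A kernel size of 0 (the app's default slider value) disables erosion.
        return mask
    kernel = np.ones((erode_kernel_size, erode_kernel_size), np.uint8)
    # cv2.erode peels roughly kernel_size // 2 pixels off the mask border.
    return cv2.erode(mask.astype(np.uint8), kernel, iterations=1)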
app.py CHANGED

@@ -103,6 +103,7 @@ def image_to_3d(
     ss_sampling_steps: int,
     slat_guidance_strength: float,
     slat_sampling_steps: int,
+    erode_kernel_size: int,
     req: gr.Request,
 ) -> Tuple[dict, str]:
     """
@@ -136,8 +137,9 @@ def image_to_3d(
             "cfg_strength": slat_guidance_strength,
         },
         mode="stochastic",
+        erode_kernel_size=erode_kernel_size,
     )
-    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
+    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120, bg_color=(1,1,1))['color']
     video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
     video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
     video_path = os.path.join(user_dir, 'sample.mp4')
@@ -323,7 +325,7 @@ def delete_mask(mask_list):
     mask_list.pop()
     return mask_list

-def check_combined_mask(image, visibility_mask, mask_list, scale=0.6):
+def check_combined_mask(image, visibility_mask, mask_list, scale=0.65):
     updated_image = image.copy()
     # combine all the mask:
     combined_mask = np.zeros_like(updated_image[:, :, 0])
@@ -394,13 +396,13 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:


     with gr.Row():
-        gr.Markdown("""* Step 1 - Generate Visibility Mask and Occlusion Mask.
+        gr.Markdown("""
+                    ### Step 1 - Generate Visibility Mask and Occlusion Mask.
                     * Please wait for a few seconds after uploading the image. The 2D segmenter is getting ready.
-                    * Add the point prompts to indicate the target object and occluders separately.
-                    * "Render Point", see the position of the point to be added.
-                    * "Add Point", the point will be added to the list.
-                    * "Generate mask", see the segmented area corresponding to current point list.
-                    * "Add mask", current mask will be added for 3D amodal completion.
+                    * Add the point prompts to indicate the target object.
+                    * "Render Point", see the position of the point to be added. "Add Point", the point will be added to the list.
+                    * "Generate mask", see the segmented area corresponding to current point list. "Add mask", current mask will be added for 3D amodal completion.
+                    * The target object need to be put in the center of the image, the scale can be adjusted for better reconstruction.
                     """)
     with gr.Row():
         with gr.Column():
@@ -434,11 +436,13 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
                 undo_vis_mask = gr.Button("Undo Last Mask")
                 vis_input = gr.Image(label='Visible Input', interactive=False, height=300)
                 with gr.Row():
-                    zoom_scale = gr.Slider(0.3, 1.0, label="Target Object Scale", value=0.6, step=0.1)
+                    zoom_scale = gr.Slider(0.3, 1.0, label="Target Object Scale", value=0.68, step=0.1)
                     check_visible_input = gr.Button("Generate Occluded Input")
     with gr.Row():
-        gr.Markdown("""* Step 2 - 3D Amodal Completion.
+        gr.Markdown("""
+                    ### Step 2 - 3D Amodal Completion.
                     * Different random seeds can be tried in "Generation Settings", if you think the results are not ideal.
+                    * The boundary of the segmentation may not be accurate, so here we provide the option to erode the visible area.
                     * If the reconstruction 3D asset is satisfactory, you can extract the GLB file and download it.
                     """)
     with gr.Row():
@@ -446,6 +450,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
            with gr.Accordion(label="Generation Settings", open=True):
                 seed = gr.Slider(0, MAX_SEED, label="Seed", value=1, step=1)
                 randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+                erode_kernel_size = gr.Slider(0, 5, label="Erode Kernel Size", value=0, step=1)
                 gr.Markdown("Stage 1: Sparse Structure Generation")
                 with gr.Row():
                     ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
@@ -454,10 +459,37 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
                 with gr.Row():
                     slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                     slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
-            generate_btn = gr.Button("Generate")
+            generate_btn = gr.Button("Amodal 3D Reconstruction")
+            with gr.Accordion(label="GLB Extraction Settings", open=False):
+                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
+                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
+            with gr.Row():
+                extract_glb_btn = gr.Button("Extract GLB")
+                extract_gs_btn = gr.Button("Extract Gaussian")
+            gr.Markdown("""
+                        *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
+                        """)
         with gr.Column():
             video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
+            model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
+
+            with gr.Row():
+                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
+                download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)

+    with gr.Row() as single_image_example:
+        examples = gr.Examples(
+            examples=[
+                f'assets/example_image/{image}'
+                for image in os.listdir("assets/example_image")
+            ],
+            inputs=[input_image],
+            fn=lambda image: input_image.upload(image),
+            outputs=[predictor, original_image, message],
+            run_on_click=True,
+            examples_per_page=12,
+        )
+
     # # Handlers
     demo.load(start_session)
     demo.unload(end_session)
@@ -536,21 +568,48 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:


     # 3D Amodal Reconstruction
-    # generate_btn.click(
-    #     get_seed,
-    #     inputs=[randomize_seed, seed],
-    #     outputs=[seed],
-    # ).then(
-    #     image_to_3d,
-    #     inputs=[vis_input, occluded_mask, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
-    #     outputs=[output_buf, video_output],
-    # )
-
     generate_btn.click(
+        get_seed,
+        inputs=[randomize_seed, seed],
+        outputs=[seed],
+    ).then(
         image_to_3d,
-        inputs=[vis_input, occluded_mask, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
+        inputs=[vis_input, occluded_mask, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, erode_kernel_size],
         outputs=[output_buf, video_output],
+    ).then(
+        lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
+        outputs=[extract_glb_btn, extract_gs_btn],
+    )
+
+    video_output.clear(
+        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
+        outputs=[extract_glb_btn, extract_gs_btn],
+    )
+
+    extract_glb_btn.click(
+        extract_glb,
+        inputs=[output_buf, mesh_simplify, texture_size],
+        outputs=[model_output, download_glb],
+    ).then(
+        lambda: gr.Button(interactive=True),
+        outputs=[download_glb],
+    )
+
+    extract_gs_btn.click(
+        extract_gaussian,
+        inputs=[output_buf],
+        outputs=[model_output, download_gs],
+    ).then(
+        lambda: gr.Button(interactive=True),
+        outputs=[download_gs],
     )
+
+    model_output.clear(
+        lambda: gr.Button(interactive=False),
+        outputs=[download_glb],
+    )
+
+


     # Launch the Gradio App
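The rewired handlers rely on Gradio's event chaining: .click(...) returns an event dependency, .then(...) schedules a follow-up step once it finishes, and returning a gr.Button(interactive=...) instance from a handler updates that button's state. A minimal self-contained sketch of the same pattern, using illustrative component names rather than the app's:

import gradio as gr

def slow_generate(seed: int) -> str:
    # Placeholder for the expensive generation step (image_to_3d in app.py).
    return f"generated with seed {seed}"

with gr.Blocks() as demo:
    seed = gr.Slider(0, 100, value=1, step=1, label="Seed")
    generate_btn = gr.Button("Generate")
    result = gr.Textbox(label="Result")
    export_btn = gr.Button("Export", interactive=False)

    # Run the generation, then enable the export button once a result exists,
    # mirroring how app.py chains image_to_3d and unlocks the extract buttons.
    generate_btn.click(
        slow_generate,
        inputs=[seed],
        outputs=[result],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[export_btn],
    )

    # Clearing the result disables export again, similar to the
    # video_output.clear and model_output.clear handlers in app.py.
    result.change(
        lambda text: gr.Button(interactive=bool(text)),
        inputs=[result],
        outputs=[export_btn],
    )

demo.launch()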