Sm0kyWu commited on
Commit
0c3ad13
·
verified ·
1 Parent(s): 2e9b6e3

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -86
app.py CHANGED
@@ -33,16 +33,9 @@ def end_session(req: gr.Request):
33
  shutil.rmtree(user_dir)
34
 
35
  def reset_image(predictor, img):
36
- """
37
- 上传图像后调用:
38
- - 重置 predictor,
39
- - 设置 predictor 的输入图像,
40
- - 返回原图
41
- """
42
  predictor.set_image(img)
43
  original_img = img.copy()
44
- # 返回predictor,visible occlusion mask初始化, 原始图像
45
- return predictor, original_img, "The models are ready."
46
 
47
  def button_clickable(selected_points):
48
  if len(selected_points) > 0:
@@ -51,10 +44,6 @@ def button_clickable(selected_points):
51
  return gr.Button.update(interactive=False)
52
 
53
  def run_sam(predictor, selected_points):
54
- """
55
- 调用 SAM 模型进行分割。
56
- """
57
- # predictor.set_image(image)
58
  if len(selected_points) == 0:
59
  return [], None
60
  input_points = [p for p in selected_points]
@@ -62,7 +51,7 @@ def run_sam(predictor, selected_points):
62
  masks, _, _ = predictor.predict(
63
  point_coords=np.array(input_points),
64
  point_labels=np.array(input_labels),
65
- multimask_output=False, # 单对象输出
66
  )
67
  best_mask = masks[0].astype(np.uint8)
68
  # dilate
@@ -73,9 +62,6 @@ def run_sam(predictor, selected_points):
73
  return best_mask
74
 
75
  def apply_mask_overlay(image, mask, color=(255, 0, 0)):
76
- """
77
- 在原图上叠加 mask:使用红色绘制 mask 的轮廓,非 mask 区域叠加浅灰色半透明遮罩。
78
- """
79
  img_arr = image
80
  overlay = img_arr.copy()
81
  gray_color = np.array([200, 200, 200], dtype=np.uint8)
@@ -86,9 +72,6 @@ def apply_mask_overlay(image, mask, color=(255, 0, 0)):
86
  return overlay
87
 
88
  def segment_and_overlay(image, points, sam_predictor):
89
- """
90
- 调用 run_sam 获得 mask,然后叠加显示分割结果。
91
- """
92
  visible_mask = run_sam(sam_predictor, points)
93
  overlaid = apply_mask_overlay(image, visible_mask * 255)
94
  return overlaid, visible_mask
@@ -106,22 +89,6 @@ def image_to_3d(
106
  erode_kernel_size: int,
107
  req: gr.Request,
108
  ) -> Tuple[dict, str]:
109
- """
110
- Convert an image to a 3D model.
111
- Args:
112
- image (Image.Image): The input image.
113
- multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
114
- is_multiimage (bool): Whether is in multi-image mode.
115
- seed (int): The random seed.
116
- ss_guidance_strength (float): The guidance strength for sparse structure generation.
117
- ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
118
- slat_guidance_strength (float): The guidance strength for structured latent generation.
119
- slat_sampling_steps (int): The number of sampling steps for structured latent generation.
120
- multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
121
- Returns:
122
- dict: The information of the generated 3D model.
123
- str: The path to the video of the 3D model.
124
- """
125
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
126
  outputs = pipeline.run_multi_image(
127
  [image],
@@ -156,9 +123,6 @@ def extract_glb(
156
  texture_size: int,
157
  req: gr.Request,
158
  ) -> tuple:
159
- """
160
- 从生成的 3D 模型中提取 GLB 文件。
161
- """
162
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
163
  gs, mesh = unpack_state(state)
164
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
@@ -170,9 +134,6 @@ def extract_glb(
170
 
171
  @spaces.GPU
172
  def extract_gaussian(state: dict, req: gr.Request) -> tuple:
173
- """
174
- 从生成的 3D 模型中提取 Gaussian 文件。
175
- """
176
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
177
  gs, _ = unpack_state(state)
178
  gaussian_path = os.path.join(user_dir, 'sample.ply')
@@ -229,7 +190,6 @@ def get_sam_predictor():
229
 
230
 
231
  def draw_points_on_image(image, point):
232
- """在图像上绘制所有点,points 为 [(x, y, point_type), ...]"""
233
  image_with_points = image.copy()
234
  x, y = point
235
  color = (255, 0, 0)
@@ -238,44 +198,24 @@ def draw_points_on_image(image, point):
238
 
239
 
240
  def see_point(image, x, y):
241
- """
242
- see操作:不修改 points 列表,仅在图像上临时显示这个点,
243
- 并返回更新后的图像和当前列表(不更新)。
244
- """
245
- # 复制当前列表,并在副本中加上新点(仅用于显示)
246
  updated_image = draw_points_on_image(image, [x,y])
247
  return updated_image
248
 
249
  def add_point(x, y, visible_points):
250
- """
251
- add操作:将新点添加到 points 列表中,
252
- 并返回更新后的图像和新的点列表。
253
- """
254
  if [x, y] not in visible_points:
255
  visible_points.append([x, y])
256
  return visible_points
257
 
258
  def delete_point(visible_points):
259
- """
260
- delete操作:删除 points 列表中的最后一个点,
261
- 并��回更新后的图像和新的点列表。
262
- """
263
  visible_points.pop()
264
  return visible_points
265
 
266
 
267
  def clear_all_points(image):
268
- """
269
- 清除所有点:返回原图、空的 visible 和 occlusion 列表,
270
- 以及更新后的点文本信息和空下拉菜单列表。
271
- """
272
  updated_image = image.copy()
273
  return updated_image
274
 
275
  def see_visible_points(image, visible_points):
276
- """
277
- 在图像上绘制所有 visible 点(红色)。
278
- """
279
  updated_image = image.copy()
280
  for p in visible_points:
281
  cv2.circle(updated_image, (int(p[0]), int(p[1])), radius=10, color=(255, 0, 0), thickness=-1)
@@ -284,11 +224,9 @@ def see_visible_points(image, visible_points):
284
  def update_all_points(visible_points):
285
  text = f"Points: {visible_points}"
286
  visible_dropdown_choices = [f"({p[0]}, {p[1]})" for p in visible_points]
287
- # 返回更新字典来明确设置 choices 和 value
288
  return text, gr.Dropdown(label="Select Point to Delete", choices=visible_dropdown_choices, value=None, interactive=True)
289
 
290
  def delete_selected_visible(image, visible_points, selected_value):
291
- # selected_value 是类似 "(x, y)" 的字符串
292
  try:
293
  selected_index = [f"({p[0]}, {p[1]})" for p in visible_points].index(selected_value)
294
  except ValueError:
@@ -296,14 +234,12 @@ def delete_selected_visible(image, visible_points, selected_value):
296
  if selected_index is not None and 0 <= selected_index < len(visible_points):
297
  visible_points.pop(selected_index)
298
  updated_image = image.copy()
299
- # 重新绘制所有 visible 点(红色)
300
  for p in visible_points:
301
  cv2.circle(updated_image, (int(p[0]), int(p[1])), radius=10, color=(255, 0, 0), thickness=-1)
302
  updated_text, vis_dropdown = update_all_points(visible_points)
303
  return updated_image, visible_points, updated_text, vis_dropdown
304
 
305
- def add_mask(mask, mask_list):
306
- # check if the mask if same as the last mask in the list
307
  if len(mask_list) > 0:
308
  if np.array_equal(mask, mask_list[-1]):
309
  return mask_list
@@ -312,11 +248,9 @@ def add_mask(mask, mask_list):
312
 
313
  def vis_mask(image, mask_list):
314
  updated_image = image.copy()
315
- # combine all the mask:
316
  combined_mask = np.zeros_like(updated_image[:, :, 0])
317
  for mask in mask_list:
318
  combined_mask = cv2.bitwise_or(combined_mask, mask)
319
- # overlay the mask on the image
320
  updated_image = apply_mask_overlay(updated_image, combined_mask)
321
  return updated_image
322
 
@@ -327,7 +261,6 @@ def delete_mask(mask_list):
327
 
328
  def check_combined_mask(image, visibility_mask, mask_list, scale=0.65):
329
  updated_image = image.copy()
330
- # combine all the mask:
331
  combined_mask = np.zeros_like(updated_image[:, :, 0])
332
  occluded_mask = np.zeros_like(updated_image[:, :, 0])
333
  if len(mask_list) == 0:
@@ -345,7 +278,6 @@ def check_combined_mask(image, visibility_mask, mask_list, scale=0.65):
345
  masked_img = updated_image * combined_mask[:, :, None]
346
  occluded_mask[combined_mask == 1] = 127
347
 
348
- # move the visible part to the center of the image
349
  x, y, w, h = cv2.boundingRect(combined_mask.astype(np.uint8))
350
  cropped_occluded_mask = (occluded_mask[y:y+h, x:x+w]).astype(np.uint8)
351
  cropped_img = masked_img[y:y+h, x:x+w]
@@ -383,7 +315,6 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
383
  ## 3D Amodal Reconstruction with [Amodal3R](https://sm0kywu.github.io/Amodal3R/)
384
  """)
385
 
386
- # 定义各状态变量
387
  predictor = gr.State(value=get_sam_predictor())
388
  visible_points_state = gr.State(value=[])
389
  occlusion_points_state = gr.State(value=[])
@@ -466,9 +397,9 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
466
  with gr.Row():
467
  with gr.Column():
468
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=1, step=1)
469
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
470
  with gr.Column():
471
- erode_kernel_size = gr.Slider(0, 5, label="Erode Kernel Size", value=0, step=1)
472
  gr.Markdown("Stage 1: Sparse Structure Generation")
473
  with gr.Row():
474
  ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
@@ -500,18 +431,15 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
500
  demo.load(start_session)
501
  demo.unload(end_session)
502
 
503
- # ---------------------------
504
- # 原有交互逻辑(略)
505
- # ---------------------------
506
  input_image.upload(
507
  reset_image,
508
  [predictor, input_image],
509
- [predictor, original_image, message],
510
  )
511
  apply_example_btn.click(
512
  reset_image,
513
  inputs=[predictor, input_image],
514
- outputs=[predictor, original_image, message]
515
  )
516
  see_button.click(
517
  see_point,
@@ -524,9 +452,6 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
524
  outputs=[visible_points_state]
525
  )
526
 
527
- # ---------------------------
528
- # 新增的交互逻辑
529
- # ---------------------------
530
  clear_button.click(
531
  clear_all_points,
532
  inputs=[original_image],
@@ -537,7 +462,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
537
  inputs=[input_image, visible_points_state],
538
  outputs=input_image
539
  )
540
- # 当 visible_points_state 或 occlusion_points_state 变化时,更新文本框和下拉菜单
541
  visible_points_state.change(
542
  update_all_points,
543
  inputs=[visible_points_state],
@@ -549,7 +474,6 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
549
  outputs=[input_image, visible_points_state, points_text, visible_points_dropdown]
550
  )
551
 
552
- # 生成mask的逻辑
553
  gen_vis_mask.click(
554
  segment_and_overlay,
555
  inputs=[original_image, visible_points_state, predictor],
@@ -622,8 +546,6 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
622
 
623
 
624
 
625
-
626
- # 启动 Gradio App
627
  if __name__ == "__main__":
628
  pipeline = Amodal3RImageTo3DPipeline.from_pretrained("Sm0kyWu/Amodal3R")
629
  pipeline.cuda()
 
33
  shutil.rmtree(user_dir)
34
 
35
  def reset_image(predictor, img):
 
 
 
 
 
 
36
  predictor.set_image(img)
37
  original_img = img.copy()
38
+ return predictor, original_img, "The models are ready.", []
 
39
 
40
  def button_clickable(selected_points):
41
  if len(selected_points) > 0:
 
44
  return gr.Button.update(interactive=False)
45
 
46
  def run_sam(predictor, selected_points):
 
 
 
 
47
  if len(selected_points) == 0:
48
  return [], None
49
  input_points = [p for p in selected_points]
 
51
  masks, _, _ = predictor.predict(
52
  point_coords=np.array(input_points),
53
  point_labels=np.array(input_labels),
54
+ multimask_output=False,
55
  )
56
  best_mask = masks[0].astype(np.uint8)
57
  # dilate
 
62
  return best_mask
63
 
64
  def apply_mask_overlay(image, mask, color=(255, 0, 0)):
 
 
 
65
  img_arr = image
66
  overlay = img_arr.copy()
67
  gray_color = np.array([200, 200, 200], dtype=np.uint8)
 
72
  return overlay
73
 
74
  def segment_and_overlay(image, points, sam_predictor):
 
 
 
75
  visible_mask = run_sam(sam_predictor, points)
76
  overlaid = apply_mask_overlay(image, visible_mask * 255)
77
  return overlaid, visible_mask
 
89
  erode_kernel_size: int,
90
  req: gr.Request,
91
  ) -> Tuple[dict, str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
93
  outputs = pipeline.run_multi_image(
94
  [image],
 
123
  texture_size: int,
124
  req: gr.Request,
125
  ) -> tuple:
 
 
 
126
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
127
  gs, mesh = unpack_state(state)
128
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
 
134
 
135
  @spaces.GPU
136
  def extract_gaussian(state: dict, req: gr.Request) -> tuple:
 
 
 
137
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
138
  gs, _ = unpack_state(state)
139
  gaussian_path = os.path.join(user_dir, 'sample.ply')
 
190
 
191
 
192
  def draw_points_on_image(image, point):
 
193
  image_with_points = image.copy()
194
  x, y = point
195
  color = (255, 0, 0)
 
198
 
199
 
200
  def see_point(image, x, y):
 
 
 
 
 
201
  updated_image = draw_points_on_image(image, [x,y])
202
  return updated_image
203
 
204
  def add_point(x, y, visible_points):
 
 
 
 
205
  if [x, y] not in visible_points:
206
  visible_points.append([x, y])
207
  return visible_points
208
 
209
  def delete_point(visible_points):
 
 
 
 
210
  visible_points.pop()
211
  return visible_points
212
 
213
 
214
  def clear_all_points(image):
 
 
 
 
215
  updated_image = image.copy()
216
  return updated_image
217
 
218
  def see_visible_points(image, visible_points):
 
 
 
219
  updated_image = image.copy()
220
  for p in visible_points:
221
  cv2.circle(updated_image, (int(p[0]), int(p[1])), radius=10, color=(255, 0, 0), thickness=-1)
 
224
  def update_all_points(visible_points):
225
  text = f"Points: {visible_points}"
226
  visible_dropdown_choices = [f"({p[0]}, {p[1]})" for p in visible_points]
 
227
  return text, gr.Dropdown(label="Select Point to Delete", choices=visible_dropdown_choices, value=None, interactive=True)
228
 
229
  def delete_selected_visible(image, visible_points, selected_value):
 
230
  try:
231
  selected_index = [f"({p[0]}, {p[1]})" for p in visible_points].index(selected_value)
232
  except ValueError:
 
234
  if selected_index is not None and 0 <= selected_index < len(visible_points):
235
  visible_points.pop(selected_index)
236
  updated_image = image.copy()
 
237
  for p in visible_points:
238
  cv2.circle(updated_image, (int(p[0]), int(p[1])), radius=10, color=(255, 0, 0), thickness=-1)
239
  updated_text, vis_dropdown = update_all_points(visible_points)
240
  return updated_image, visible_points, updated_text, vis_dropdown
241
 
242
+ def add_mask(mask, mask_list):
 
243
  if len(mask_list) > 0:
244
  if np.array_equal(mask, mask_list[-1]):
245
  return mask_list
 
248
 
249
  def vis_mask(image, mask_list):
250
  updated_image = image.copy()
 
251
  combined_mask = np.zeros_like(updated_image[:, :, 0])
252
  for mask in mask_list:
253
  combined_mask = cv2.bitwise_or(combined_mask, mask)
 
254
  updated_image = apply_mask_overlay(updated_image, combined_mask)
255
  return updated_image
256
 
 
261
 
262
  def check_combined_mask(image, visibility_mask, mask_list, scale=0.65):
263
  updated_image = image.copy()
 
264
  combined_mask = np.zeros_like(updated_image[:, :, 0])
265
  occluded_mask = np.zeros_like(updated_image[:, :, 0])
266
  if len(mask_list) == 0:
 
278
  masked_img = updated_image * combined_mask[:, :, None]
279
  occluded_mask[combined_mask == 1] = 127
280
 
 
281
  x, y, w, h = cv2.boundingRect(combined_mask.astype(np.uint8))
282
  cropped_occluded_mask = (occluded_mask[y:y+h, x:x+w]).astype(np.uint8)
283
  cropped_img = masked_img[y:y+h, x:x+w]
 
315
  ## 3D Amodal Reconstruction with [Amodal3R](https://sm0kywu.github.io/Amodal3R/)
316
  """)
317
 
 
318
  predictor = gr.State(value=get_sam_predictor())
319
  visible_points_state = gr.State(value=[])
320
  occlusion_points_state = gr.State(value=[])
 
397
  with gr.Row():
398
  with gr.Column():
399
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=1, step=1)
400
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
401
  with gr.Column():
402
+ erode_kernel_size = gr.Slider(0, 5, label="Erode Kernel Size", value=3, step=1)
403
  gr.Markdown("Stage 1: Sparse Structure Generation")
404
  with gr.Row():
405
  ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
 
431
  demo.load(start_session)
432
  demo.unload(end_session)
433
 
 
 
 
434
  input_image.upload(
435
  reset_image,
436
  [predictor, input_image],
437
+ [predictor, original_image, message, visible_points_state],
438
  )
439
  apply_example_btn.click(
440
  reset_image,
441
  inputs=[predictor, input_image],
442
+ outputs=[predictor, original_image, message, visible_points_state]
443
  )
444
  see_button.click(
445
  see_point,
 
452
  outputs=[visible_points_state]
453
  )
454
 
 
 
 
455
  clear_button.click(
456
  clear_all_points,
457
  inputs=[original_image],
 
462
  inputs=[input_image, visible_points_state],
463
  outputs=input_image
464
  )
465
+
466
  visible_points_state.change(
467
  update_all_points,
468
  inputs=[visible_points_state],
 
474
  outputs=[input_image, visible_points_state, points_text, visible_points_dropdown]
475
  )
476
 
 
477
  gen_vis_mask.click(
478
  segment_and_overlay,
479
  inputs=[original_image, visible_points_state, predictor],
 
546
 
547
 
548
 
 
 
549
  if __name__ == "__main__":
550
  pipeline = Amodal3RImageTo3DPipeline.from_pretrained("Sm0kyWu/Amodal3R")
551
  pipeline.cuda()