LutaoJiang committed on
Commit
c69ffaa
·
1 Parent(s): 40ae5e2
Files changed (3) hide show
  1. app.py +4 -5
  2. pipeline/kiss3d_wrapper.py +2 -2
  3. pipeline/utils.py +21 -17
app.py CHANGED
@@ -187,7 +187,7 @@ else:
187
  # print(f"Before text_to_detailed: {torch.cuda.memory_allocated() / 1024**3} GB")
188
  return k3d_wrapper.get_detailed_prompt(prompt, seed)
189
 
190
- @spaces.GPU
191
  def text_to_image(prompt, seed=None, strength=1.0,lora_scale=1.0, num_inference_steps=18, redux_hparam=None, init_image=None, **kwargs):
192
  # subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
193
  # print(f"Before text_to_image: {torch.cuda.memory_allocated() / 1024**3} GB")
@@ -210,7 +210,7 @@ else:
210
  **kwargs)
211
  return result[-1]
212
 
213
- @spaces.GPU
214
  def image2mesh_preprocess_(input_image_, seed, use_mv_rgb=True):
215
  global preprocessed_input_image
216
 
@@ -225,7 +225,7 @@ else:
225
  return reference_save_path, caption
226
 
227
 
228
- @spaces.GPU
229
  def image2mesh_main_(reference_3d_bundle_image, caption, seed, strength1=0.5, strength2=0.95, enable_redux=True, use_controlnet=True, if_video=True):
230
  subprocess.run(['nvidia-smi'])
231
  global mesh_cache
@@ -252,7 +252,7 @@ else:
252
  return gen_save_path, recon_mesh_path, mesh_cache
253
  # return gen_save_path, recon_mesh_path
254
 
255
- @spaces.GPU
256
  def bundle_image_to_mesh(
257
  gen_3d_bundle_image,
258
  camera_radius=3.5,
@@ -433,7 +433,6 @@ with gr.Blocks(css="""
433
  ["A person wearing a virtual reality headset, sitting position, bent legs, clasped hands."],
434
  ["A battle mech in a mix of red, blue, and black color, with a cannon on the head."],
435
  ["骷髅头, 邪恶的"],
436
-
437
  ],
438
  inputs=[prompt],
439
  label="Example Prompts",
 
187
  # print(f"Before text_to_detailed: {torch.cuda.memory_allocated() / 1024**3} GB")
188
  return k3d_wrapper.get_detailed_prompt(prompt, seed)
189
 
190
+ @spaces.GPU(duration=120)
191
  def text_to_image(prompt, seed=None, strength=1.0,lora_scale=1.0, num_inference_steps=18, redux_hparam=None, init_image=None, **kwargs):
192
  # subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
193
  # print(f"Before text_to_image: {torch.cuda.memory_allocated() / 1024**3} GB")
 
210
  **kwargs)
211
  return result[-1]
212
 
213
+ @spaces.GPU(duration=120)
214
  def image2mesh_preprocess_(input_image_, seed, use_mv_rgb=True):
215
  global preprocessed_input_image
216
 
 
225
  return reference_save_path, caption
226
 
227
 
228
+ @spaces.GPU(duration=120)
229
  def image2mesh_main_(reference_3d_bundle_image, caption, seed, strength1=0.5, strength2=0.95, enable_redux=True, use_controlnet=True, if_video=True):
230
  subprocess.run(['nvidia-smi'])
231
  global mesh_cache
 
252
  return gen_save_path, recon_mesh_path, mesh_cache
253
  # return gen_save_path, recon_mesh_path
254
 
255
+ @spaces.GPU(duration=120)
256
  def bundle_image_to_mesh(
257
  gen_3d_bundle_image,
258
  camera_radius=3.5,
 
433
  ["A person wearing a virtual reality headset, sitting position, bent legs, clasped hands."],
434
  ["A battle mech in a mix of red, blue, and black color, with a cannon on the head."],
435
  ["骷髅头, 邪恶的"],
 
436
  ],
437
  inputs=[prompt],
438
  label="Example Prompts",
pipeline/kiss3d_wrapper.py CHANGED
@@ -587,10 +587,10 @@ class kiss3d_wrapper(object):
587
  rgb_multi_view = rgb_multi_view.to(recon_device) * multi_view_mask + (1 - multi_view_mask)
588
 
589
  with self.context():
590
-
591
  return DiMeR_reconstruct(self.recon_model, self.recon_model_config.infer_config,
592
  self.texture_model, self.texture_model_config.infer_config,
593
- rgb_multi_view.to(recon_device), normal_multi_view.to(recon_device), name=self.uuid,
594
  input_camera_type='kiss3d', render_3d_bundle_image=save_intermediate_results,
595
  render_azimuths=[0, 90, 180, 270],
596
  render_radius=lrm_render_radius,
 
587
  rgb_multi_view = rgb_multi_view.to(recon_device) * multi_view_mask + (1 - multi_view_mask)
588
 
589
  with self.context():
590
+ print("Image process done!")
591
  return DiMeR_reconstruct(self.recon_model, self.recon_model_config.infer_config,
592
  self.texture_model, self.texture_model_config.infer_config,
593
+ rgb_multi_view.to(recon_device), normal_multi_view.to(recon_device), multi_view_mask, name=self.uuid,
594
  input_camera_type='kiss3d', render_3d_bundle_image=save_intermediate_results,
595
  render_azimuths=[0, 90, 180, 270],
596
  render_radius=lrm_render_radius,
pipeline/utils.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import sys
3
  import logging
 
4
 
5
  __workdir__ = '/'.join(os.path.abspath(__file__).split('/')[:-2])
6
  sys.path.insert(0, __workdir__)
@@ -228,8 +229,8 @@ def preprocess_input_image(input_image):
228
 
229
 
230
 
231
-
232
- def DiMeR_reconstruct(model, infer_config, texture_model, texture_model_config, images, normals,
233
  name='', export_texmap=False,
234
  input_camera_type='zero123',
235
  render_3d_bundle_image=True,
@@ -252,34 +253,37 @@ def DiMeR_reconstruct(model, infer_config, texture_model, texture_model_config,
252
  else:
253
  raise NotImplementedError(f'Unexpected input camera type: {input_camera_type}')
254
 
255
- # use rembg to get foreground mask
256
- fg_mask = []
257
- for i in range(4):
258
- image = images[i].permute(1, 2, 0).cpu().numpy()
259
- image = (image * 255).astype(np.uint8)
260
- image = rembg.remove(image, session=rembg_session)
261
- image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.
262
- image = image[3:4]
263
- fg_mask.append(image)
264
- fg_mask = torch.stack(fg_mask)
265
- bg_mask = 1 - fg_mask
266
 
267
  # TODO: Device Check
268
  global_normals = normal_transfer.trans_local_2_global(normals.cpu().permute(0,2,3,1), torch.tensor([0, 90, 180, 270]),
269
  torch.tensor([5, 5, 5, 5]), radius=4.5,
270
  for_lotus=True)
271
  global_normals = global_normals.permute(0, 3, 1, 2)
272
- global_normals = global_normals * fg_mask + bg_mask
 
273
  global_normals = F.pad(global_normals, (50, 50, 50, 50), value=1.)
274
  global_normals = F.interpolate(global_normals, (512, 512), mode='bilinear', align_corners=False)
275
  global_normals = global_normals.unsqueeze(0).clamp(0.0, 1.0).to(device)
276
 
277
- images = images.cpu() * fg_mask + bg_mask
 
 
278
  images = F.pad(images, (50, 50, 50, 50), value=1.)
279
  images = F.interpolate(images, (512, 512), mode='bilinear', align_corners=False)
280
  images = images.unsqueeze(0).clamp(0.0, 1.0).to(device)
281
 
282
- logger.info(f"==> Runing DiMeR geometry reconstruction ...")
283
  planes = model.forward_planes(global_normals, input_cameras)
284
  vertices, faces, _ = model.extract_mesh(
285
  planes,
@@ -287,7 +291,7 @@ def DiMeR_reconstruct(model, infer_config, texture_model, texture_model_config,
287
  **infer_config,
288
  )
289
 
290
- logger.info(f"==> Runing DiMeR texture reconstruction ...")
291
  # extract_mesh函数进行了旋转,进行还原,对齐训练时的方向
292
  vertices = torch.tensor(vertices, device=device)
293
  faces = torch.tensor(faces, device=device)
 
1
  import os
2
  import sys
3
  import logging
4
+ import time
5
 
6
  __workdir__ = '/'.join(os.path.abspath(__file__).split('/')[:-2])
7
  sys.path.insert(0, __workdir__)
 
229
 
230
 
231
 
232
+ @torch.no_grad()
233
+ def DiMeR_reconstruct(model, infer_config, texture_model, texture_model_config, images, normals, multi_view_mask,
234
  name='', export_texmap=False,
235
  input_camera_type='zero123',
236
  render_3d_bundle_image=True,
 
253
  else:
254
  raise NotImplementedError(f'Unexpected input camera type: {input_camera_type}')
255
 
256
+ # # use rembg to get foreground mask
257
+ # fg_mask = []
258
+ # for i in range(4):
259
+ # image = images[i].permute(1, 2, 0).cpu().numpy()
260
+ # image = (image * 255).astype(np.uint8)
261
+ # image = rembg.remove(image, session=rembg_session)
262
+ # image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.
263
+ # image = image[3:4]
264
+ # fg_mask.append(image)
265
+ # fg_mask = torch.stack(fg_mask)
266
+ # bg_mask = 1 - fg_mask
267
 
268
  # TODO: Device Check
269
  global_normals = normal_transfer.trans_local_2_global(normals.cpu().permute(0,2,3,1), torch.tensor([0, 90, 180, 270]),
270
  torch.tensor([5, 5, 5, 5]), radius=4.5,
271
  for_lotus=True)
272
  global_normals = global_normals.permute(0, 3, 1, 2)
273
+ global_normals = global_normals * multi_view_mask + (1 - multi_view_mask)
274
+ # global_normals = global_normals * fg_mask + bg_mask
275
  global_normals = F.pad(global_normals, (50, 50, 50, 50), value=1.)
276
  global_normals = F.interpolate(global_normals, (512, 512), mode='bilinear', align_corners=False)
277
  global_normals = global_normals.unsqueeze(0).clamp(0.0, 1.0).to(device)
278
 
279
+ print(f"{time.time()} ==> local normal to global normal done")
280
+
281
+ # images = images.cpu() * fg_mask + bg_mask
282
  images = F.pad(images, (50, 50, 50, 50), value=1.)
283
  images = F.interpolate(images, (512, 512), mode='bilinear', align_corners=False)
284
  images = images.unsqueeze(0).clamp(0.0, 1.0).to(device)
285
 
286
+ print(f"{time.time()} ==> Runing DiMeR geometry reconstruction ...")
287
  planes = model.forward_planes(global_normals, input_cameras)
288
  vertices, faces, _ = model.extract_mesh(
289
  planes,
 
291
  **infer_config,
292
  )
293
 
294
+ print(f"{time.time()} ==> Runing DiMeR texture reconstruction ...")
295
  # extract_mesh函数进行了旋转,进行还原,对齐训练时的方向
296
  vertices = torch.tensor(vertices, device=device)
297
  faces = torch.tensor(faces, device=device)