LTT commited on
Commit
4095016
·
verified ·
1 Parent(s): c69ffaa
Files changed (1) hide show
  1. pipeline/utils.py +2 -2
pipeline/utils.py CHANGED
@@ -268,9 +268,9 @@ def DiMeR_reconstruct(model, infer_config, texture_model, texture_model_config,
268
  # TODO: Device Check
269
  global_normals = normal_transfer.trans_local_2_global(normals.cpu().permute(0,2,3,1), torch.tensor([0, 90, 180, 270]),
270
  torch.tensor([5, 5, 5, 5]), radius=4.5,
271
- for_lotus=True)
272
  global_normals = global_normals.permute(0, 3, 1, 2)
273
- global_normals = global_normals * multi_view_mask + (1 - multi_view_mask)
274
  # global_normals = global_normals * fg_mask + bg_mask
275
  global_normals = F.pad(global_normals, (50, 50, 50, 50), value=1.)
276
  global_normals = F.interpolate(global_normals, (512, 512), mode='bilinear', align_corners=False)
 
268
  # TODO: Device Check
269
  global_normals = normal_transfer.trans_local_2_global(normals.cpu().permute(0,2,3,1), torch.tensor([0, 90, 180, 270]),
270
  torch.tensor([5, 5, 5, 5]), radius=4.5,
271
+ for_lotus=True).to(device)
272
  global_normals = global_normals.permute(0, 3, 1, 2)
273
+ global_normals = global_normals * multi_view_mask.to(device) + (1 - multi_view_mask.to(device))
274
  # global_normals = global_normals * fg_mask + bg_mask
275
  global_normals = F.pad(global_normals, (50, 50, 50, 50), value=1.)
276
  global_normals = F.interpolate(global_normals, (512, 512), mode='bilinear', align_corners=False)