Spaces:
Paused
Paused
File size: 7,077 Bytes
f498ac0 d172fbe f498ac0 0ba16db d172fbe 0ba16db d172fbe f498ac0 0ba16db f498ac0 0ba16db d172fbe 0ba16db f498ac0 0ba16db d172fbe 0ba16db d172fbe 0ba16db f498ac0 4c539b3 f498ac0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import pymeshlab
import torch
from nvdiffmodeling.src import obj
from nvdiffmodeling.src import mesh
from nvdiffmodeling.src import texture
import numpy as np
from utilities.helpers import get_vp_map
import os
# Trainable material maps. The same three objects are plugged into the
# material of every mesh built by get_mesh / get_og_mesh below, so all
# meshes share (and jointly train) them.
# The [512, 512] argument is presumably the texture resolution — TODO confirm
# against texture.create_trainable; the meaning of the trailing True is not
# visible here.

# Diffuse map ('kd'): 512 x 512 x 3 values drawn uniformly from [0, 1).
texture_map = texture.create_trainable(
    np.random.uniform(low=0.0, high=1.0, size=[512, 512, 3]),
    [512, 512],
    True,
)
# Normal map ('normal'): initialised from the constant vector (0, 0, 1).
normal_map = texture.create_trainable(np.array([0, 0, 1]), [512, 512], True)
# Specular map ('ks'): initialised from the constant vector (0, 0, 0).
specular_map = texture.create_trainable(np.array([0, 0, 0]), [512, 512], True)
def get_mesh(mesh_path, output_path, triangulate_flag, bsdf_flag, mesh_name='mesh.obj'):
    """Load a mesh file, scale it with ``mesh.unit_size`` and attach trainable materials.

    The mesh is loaded with pymeshlab, optionally remeshed, and given trivial
    per-wedge UVs if it has none. It is then written to
    ``<output_path>/tmp/<mesh_name>`` and re-read with ``obj.load_obj``. After
    ``mesh.unit_size`` the scaled geometry is written back to the same
    temporary file (without vertex colours).

    Args:
        mesh_path: Input mesh file (``str`` or ``os.PathLike``).
        output_path: Directory (``str`` or ``os.PathLike``). A ``tmp``
            sub-directory is created inside it, together with any missing
            parent directories.
        triangulate_flag: If truthy, run pymeshlab isotropic explicit remeshing.
        bsdf_flag: Stored as the material's ``'bsdf'`` entry.
        mesh_name: File name of the temporary OBJ inside ``tmp``.

    Returns:
        A ``mesh.Mesh`` built on the loaded mesh. Its material references the
        module-level ``texture_map``, ``specular_map`` and ``normal_map``
        objects, which are the same objects on every call.

    Raises:
        FileNotFoundError: If ``mesh_path`` does not exist.
        ValueError: If the input, re-loaded or final mesh has no vertices or
            faces.
        Any exception raised inside is printed with its traceback and then
        re-raised unchanged.
    """
    from pathlib import Path

    try:
        print(f"Loading mesh from: {mesh_path}")
        # Check if mesh file exists
        if not os.path.exists(mesh_path):
            raise FileNotFoundError(f"Mesh file not found: {mesh_path}")
        ms = pymeshlab.MeshSet()
        # pymeshlab takes a plain string path; str() also accepts pathlib.Path callers.
        ms.load_new_mesh(str(mesh_path))
        # Check if mesh was loaded successfully
        if ms.current_mesh().vertex_number() == 0:
            raise ValueError(f"Mesh file {mesh_path} has no vertices")
        print(f"Loaded mesh with {ms.current_mesh().vertex_number()} vertices and {ms.current_mesh().face_number()} faces")
        if triangulate_flag:
            print('Retriangulating shape')
            ms.meshing_isotropic_explicit_remeshing()
        if not ms.current_mesh().has_wedge_tex_coord():
            # Arbitrarily large texture dimension for the trivial per-wedge UV layout.
            ms.compute_texcoord_parametrization_triangle_trivial_per_wedge(textdim=10000)
        # Ensure the tmp directory exists. parents=True so that a not-yet-created
        # output_path does not make mkdir fail; Path() also accepts str callers.
        tmp_dir = Path(output_path) / 'tmp'
        tmp_dir.mkdir(parents=True, exist_ok=True)
        tmp_mesh_path = tmp_dir / mesh_name
        print(f"Saving temporary mesh to: {tmp_mesh_path}")
        ms.save_current_mesh(str(tmp_mesh_path))
        print(f"Loading OBJ from temporary path: {tmp_mesh_path}")
        load_mesh = obj.load_obj(str(tmp_mesh_path))
        # Check if mesh was loaded successfully
        if load_mesh.v_pos is None or load_mesh.v_pos.shape[0] == 0:
            raise ValueError(f"Failed to load mesh vertices from {tmp_mesh_path}")
        if load_mesh.t_pos_idx is None or load_mesh.t_pos_idx.shape[0] == 0:
            raise ValueError(f"Failed to load mesh faces from {tmp_mesh_path}")
        print(f"Loaded mesh with {load_mesh.v_pos.shape[0]} vertices and {load_mesh.t_pos_idx.shape[0]} faces")
        load_mesh = mesh.unit_size(load_mesh)
        # Overwrite the temporary file with the rescaled geometry.
        ms.add_mesh(
            pymeshlab.Mesh(vertex_matrix=load_mesh.v_pos.cpu().numpy(), face_matrix=load_mesh.t_pos_idx.cpu().numpy()))
        ms.save_current_mesh(str(tmp_mesh_path), save_vertex_color=False)
        load_mesh = mesh.Mesh(
            material={
                'bsdf': bsdf_flag,
                'kd': texture_map,
                'ks': specular_map,
                'normal': normal_map,
            },
            base=load_mesh  # Get UVs from original loaded mesh
        )
        # Final check to ensure mesh is valid
        if load_mesh.v_pos is None or load_mesh.v_pos.shape[0] == 0:
            raise ValueError("Final mesh has no vertices")
        if load_mesh.t_pos_idx is None or load_mesh.t_pos_idx.shape[0] == 0:
            raise ValueError("Final mesh has no faces")
        print(f"Successfully loaded mesh with {load_mesh.v_pos.shape[0]} vertices and {load_mesh.t_pos_idx.shape[0]} faces")
        return load_mesh
    except Exception as e:
        # Boundary handler: report with a traceback, then propagate unchanged.
        print(f"Error in get_mesh: {e}")
        import traceback
        traceback.print_exc()
        raise
def get_og_mesh(mesh_path, output_path, triangulate_flag, bsdf_flag, mesh_name='mesh.obj'):
    """Load a mesh like ``get_mesh``, but scale it with ``mesh.resize_mesh``.

    Unlike ``get_mesh`` this performs no existence or emptiness validation and
    no progress logging. The mesh is round-tripped through
    ``<output_path>/tmp/<mesh_name>`` so that ``obj.load_obj`` can read the UVs
    pymeshlab produced. The resized geometry is then written back to that file.

    Args:
        mesh_path: Input mesh file (``str`` or ``os.PathLike``).
        output_path: Directory (``str`` or ``os.PathLike``); its ``tmp``
            sub-directory is created if missing.
        triangulate_flag: If truthy, run pymeshlab isotropic explicit remeshing.
        bsdf_flag: Stored as the material's ``'bsdf'`` entry.
        mesh_name: File name of the temporary OBJ inside ``tmp``.

    Returns:
        A ``mesh.Mesh`` built on the loaded mesh. Its material references the
        shared module-level ``texture_map``, ``specular_map`` and
        ``normal_map`` objects.
    """
    from pathlib import Path

    ms = pymeshlab.MeshSet()
    # pymeshlab takes a plain string path; str() also accepts pathlib.Path callers.
    ms.load_new_mesh(str(mesh_path))
    if triangulate_flag:
        print('Retriangulating shape')
        ms.meshing_isotropic_explicit_remeshing()
    if not ms.current_mesh().has_wedge_tex_coord():
        # Arbitrarily large texture dimension for the trivial per-wedge UV layout.
        ms.compute_texcoord_parametrization_triangle_trivial_per_wedge(textdim=10000)

    # Create the scratch directory here as get_mesh does, so this function does
    # not depend on a prior get_mesh call having created it.
    tmp_dir = Path(output_path) / 'tmp'
    tmp_dir.mkdir(parents=True, exist_ok=True)
    tmp_mesh_path = str(tmp_dir / mesh_name)

    ms.save_current_mesh(tmp_mesh_path)
    load_mesh = obj.load_obj(tmp_mesh_path)
    load_mesh = mesh.resize_mesh(load_mesh)
    # Overwrite the temporary file with the resized geometry.
    ms.add_mesh(
        pymeshlab.Mesh(vertex_matrix=load_mesh.v_pos.cpu().numpy(), face_matrix=load_mesh.t_pos_idx.cpu().numpy()))
    ms.save_current_mesh(tmp_mesh_path, save_vertex_color=False)
    return mesh.Mesh(
        material={
            'bsdf': bsdf_flag,
            'kd': texture_map,
            'ks': specular_map,
            'normal': normal_map,
        },
        base=load_mesh  # Get UVs from original loaded mesh
    )
def compute_mv_cl(final_mesh, fe, normalized_clip_render, params_camera, train_rast_map, cfg, device):
    """Compute the multi-view deep-feature consistency loss.

    Each vertex is projected into every view with ``get_vp_map`` and mapped to
    the ViT patch covering it. The per-vertex deep features of pairs of views
    whose cameras are close in elevation and azimuth are then compared by
    cosine similarity.

    Args:
        final_mesh: Mesh exposing ``v_pos`` (vertex positions) and
            ``t_pos_idx`` (per-face vertex indices).
        fe: Feature extractor called on ``normalized_clip_render``. It must
            expose ``old_stride`` / ``new_stride`` and return indexable
            per-layer patch features, assumed shaped (batch, channels,
            patches) — TODO confirm. ``None`` disables the loss.
        normalized_clip_render: Batch of renders passed to ``fe``.
        params_camera: Dict with ``'mvp'``, ``'elev'`` and ``'azim'``. The
            angles are presumably radians, since the degree thresholds are
            converted with ``deg2rad``.
        train_rast_map: Rasterizer output. Channel 3 is assumed to hold
            face id + 1, with 0 for background — TODO confirm.
        cfg: Config providing ``batch_size``, ``consistency_vit_layer``,
            ``consistency_elev_filter`` and ``consistency_azim_filter``
            (the filters are in degrees).
        device: Torch device for tensors created here.

    Returns:
        A scalar tensor: the mean of the non-zero per-vertex cosine
        similarities over retained view pairs. It is 0 when ``fe`` is ``None``
        or when nothing survives the filters (``mean()`` of an empty selection
        would otherwise be NaN).
    """
    # Consistency loss needs the feature extractor.
    if fe is None:
        print("Warning: CLIPVisualEncoder not available, skipping consistency loss")
        return torch.tensor(0.0, device=device)

    res = 224  # render side length handed to get_vp_map and used for border clamping

    # Get mapping from vertex to pixels
    curr_vp_map = get_vp_map(final_mesh.v_pos, params_camera['mvp'], res)
    num_verts = final_mesh.v_pos.shape[0]
    # hidden[b, v] is True when vertex v belongs to no rasterized face in view b.
    hidden = torch.zeros(curr_vp_map.shape[:-1], dtype=torch.bool, device=device)
    for idx, rast_faces in enumerate(train_rast_map[:, :, :, 3].view(cfg.batch_size, -1)):
        # Drop background id 0 explicitly. Slicing off the first unique value
        # would discard a real face when no pixel is background.
        face_ids = rast_faces.unique().long()
        face_ids = face_ids[face_ids > 0] - 1
        visible = torch.zeros(num_verts, dtype=torch.bool, device=device)
        visible[final_mesh.t_pos_idx[face_ids].flatten()] = True
        hidden[idx] = ~visible

    # Get mapping from vertex to patch: clamp coordinates near the image borders
    # before converting pixel positions to (row, col) patch indices.
    med = (fe.old_stride - 1) / 2
    curr_vp_map[curr_vp_map < med] = med
    curr_vp_map[(curr_vp_map > res - fe.old_stride) & (curr_vp_map < res)] = res - 1 - med
    curr_patch_map = ((curr_vp_map - med) / fe.new_stride).round()
    patches_per_row = (res - fe.old_stride) / fe.new_stride + 1
    flat_patch_map = curr_patch_map[..., 0] * patches_per_row + curr_patch_map[..., 1]

    # Deep features
    patch_feats = fe(normalized_clip_render)
    deep_feats = patch_feats[cfg.consistency_vit_layer]
    num_patches = deep_feats.shape[-1]
    # Hidden vertices and indices past the last patch read the extra all-zero
    # column appended by the pad below. Their normalized features are zero, so
    # the final `!= 0` filter drops them.
    flat_patch_map[flat_patch_map > num_patches - 1] = num_patches
    flat_patch_map[hidden] = num_patches
    flat_patch_map = flat_patch_map.long()[:, None, :].repeat(1, deep_feats.shape[1], 1)
    deep_feats = torch.nn.functional.pad(deep_feats, (0, 1))
    deep_feats = torch.gather(deep_feats, dim=2, index=flat_patch_map)
    deep_feats = torch.nn.functional.normalize(deep_feats, dim=1, eps=1e-6)

    # Keep only view pairs whose cameras are close in elevation and azimuth.
    # NOTE(review): the azimuth difference is not wrapped around a full turn,
    # so views on opposite sides of the seam are never paired — confirm intent.
    elev = params_camera['elev'].unsqueeze(1)
    azim = params_camera['azim'].unsqueeze(1)
    elev_d = torch.cdist(elev, elev).abs() < torch.deg2rad(torch.tensor(cfg.consistency_elev_filter))
    azim_d = torch.cdist(azim, azim).abs() < torch.deg2rad(torch.tensor(cfg.consistency_azim_filter))

    # cosines[i, l, v] = <feat_i[:, v], feat_l[:, v]> for views i, l and vertex v.
    cosines = torch.einsum('icv,lcv->ilv', deep_feats, deep_feats)
    # Zero out filtered pairs, then put the vertex axis first. triu(1) keeps each
    # unordered view pair once and drops self-pairs.
    cosines = (cosines * azim_d.unsqueeze(-1) * elev_d.unsqueeze(-1)).permute(2, 0, 1).triu(1)
    valid = cosines[cosines != 0]
    if valid.numel() == 0:
        # mean() of an empty tensor is NaN, which would poison the total loss.
        return torch.tensor(0.0, device=device)
    return valid.mean()