Spaces:
Paused
Paused
| import pymeshlab | |
| import torch | |
| from nvdiffmodeling.src import obj | |
| from nvdiffmodeling.src import mesh | |
| from nvdiffmodeling.src import texture | |
| import numpy as np | |
| from utilities.helpers import get_vp_map | |
| import os | |
# Module-level trainable material textures. They are created once at import
# time and referenced by every mesh assembled in this file ('kd', 'normal'
# and 'ks' material slots respectively).
# NOTE(review): the trailing ``True`` is passed positionally to
# ``texture.create_trainable``; presumably it enables automatic mipmaps --
# confirm against nvdiffmodeling's signature.

# Diffuse texture: initialised with uniform random values in [0, 1).
texture_map = texture.create_trainable(
    np.random.uniform(low=0.0, high=1.0, size=[512, 512, 3]),
    [512, 512],
    True,
)

# Normal map: initialised from the constant vector (0, 0, 1).
normal_map = texture.create_trainable(np.array([0, 0, 1]), [512, 512], True)

# Specular map: initialised from the constant vector (0, 0, 0).
specular_map = texture.create_trainable(np.array([0, 0, 0]), [512, 512], True)
def get_mesh(mesh_path, output_path, triangulate_flag, bsdf_flag, mesh_name='mesh.obj'):
    """Load a mesh file, rescale it with ``mesh.unit_size`` and attach the shared material.

    The mesh is read with pymeshlab, optionally remeshed, given trivial
    per-wedge texture coordinates when it has none, and written to
    ``<output_path>/tmp/<mesh_name>``. That file is then re-read through
    ``obj.load_obj``, rescaled, written back, and finally wrapped in a
    ``mesh.Mesh`` that uses the module-level trainable textures.

    Args:
        mesh_path: Path of the input mesh file (``str`` or path-like).
        output_path: Output directory (``str`` or ``pathlib.Path``). The
            ``tmp`` sub-directory is created, including missing parents.
        triangulate_flag: When truthy, run isotropic explicit remeshing.
        bsdf_flag: Value stored under the material's ``'bsdf'`` key.
        mesh_name: File name used for the temporary mesh.

    Returns:
        The ``mesh.Mesh`` built on top of the rescaled mesh.

    Raises:
        FileNotFoundError: If ``mesh_path`` does not exist.
        ValueError: If the mesh has no vertices or faces at any stage.
        Exception: Any other error is printed with its traceback and re-raised.
    """
    # Local import, like ``traceback`` below, so the function is self-contained.
    from pathlib import Path

    try:
        print(f"Loading mesh from: {mesh_path}")

        # Check if mesh file exists
        if not os.path.exists(mesh_path):
            raise FileNotFoundError(f"Mesh file not found: {mesh_path}")

        ms = pymeshlab.MeshSet()
        # str() so that path-like inputs are accepted as well as plain strings.
        ms.load_new_mesh(str(mesh_path))

        # Check if mesh was loaded successfully
        if ms.current_mesh().vertex_number() == 0:
            raise ValueError(f"Mesh file {mesh_path} has no vertices")

        print(f"Loaded mesh with {ms.current_mesh().vertex_number()} vertices and {ms.current_mesh().face_number()} faces")

        if triangulate_flag:
            print('Retriangulating shape')
            ms.meshing_isotropic_explicit_remeshing()

        if not ms.current_mesh().has_wedge_tex_coord():
            # some arbitrarily high number
            ms.compute_texcoord_parametrization_triangle_trivial_per_wedge(textdim=10000)

        # Ensure the tmp directory exists. ``parents=True`` also covers the
        # case where ``output_path`` itself has not been created yet, and
        # ``Path(...)`` lets callers pass a plain string.
        tmp_dir = Path(output_path) / 'tmp'
        tmp_dir.mkdir(parents=True, exist_ok=True)
        tmp_mesh_path = tmp_dir / mesh_name

        print(f"Saving temporary mesh to: {tmp_mesh_path}")
        ms.save_current_mesh(str(tmp_mesh_path))

        print(f"Loading OBJ from temporary path: {tmp_mesh_path}")
        load_mesh = obj.load_obj(str(tmp_mesh_path))

        # Check if mesh was loaded successfully
        if load_mesh.v_pos is None or load_mesh.v_pos.shape[0] == 0:
            raise ValueError(f"Failed to load mesh vertices from {tmp_mesh_path}")

        if load_mesh.t_pos_idx is None or load_mesh.t_pos_idx.shape[0] == 0:
            raise ValueError(f"Failed to load mesh faces from {tmp_mesh_path}")

        print(f"Loaded mesh with {load_mesh.v_pos.shape[0]} vertices and {load_mesh.t_pos_idx.shape[0]} faces")

        load_mesh = mesh.unit_size(load_mesh)

        # Push the rescaled geometry back into the MeshSet and overwrite the
        # temporary file with it.
        # NOTE(review): relies on add_mesh making the new mesh current -- confirm.
        ms.add_mesh(
            pymeshlab.Mesh(
                vertex_matrix=load_mesh.v_pos.cpu().numpy(),
                face_matrix=load_mesh.t_pos_idx.cpu().numpy(),
            )
        )
        ms.save_current_mesh(str(tmp_mesh_path), save_vertex_color=False)

        load_mesh = mesh.Mesh(
            material={
                'bsdf': bsdf_flag,
                'kd': texture_map,
                'ks': specular_map,
                'normal': normal_map,
            },
            base=load_mesh  # Get UVs from original loaded mesh
        )

        # Final check to ensure mesh is valid
        if load_mesh.v_pos is None or load_mesh.v_pos.shape[0] == 0:
            raise ValueError("Final mesh has no vertices")

        if load_mesh.t_pos_idx is None or load_mesh.t_pos_idx.shape[0] == 0:
            raise ValueError("Final mesh has no faces")

        print(f"Successfully loaded mesh with {load_mesh.v_pos.shape[0]} vertices and {load_mesh.t_pos_idx.shape[0]} faces")

        return load_mesh

    except Exception as e:
        # Log for visibility, then let the caller decide how to handle it.
        print(f"Error in get_mesh: {e}")
        import traceback
        traceback.print_exc()
        raise
def get_og_mesh(mesh_path, output_path, triangulate_flag, bsdf_flag, mesh_name='mesh.obj'):
    """Load a mesh file, rescale it with ``mesh.resize_mesh`` and attach the shared material.

    The mesh is read with pymeshlab, optionally remeshed, given trivial
    per-wedge texture coordinates when it has none, and written to
    ``<output_path>/tmp/<mesh_name>``. That file is re-read through
    ``obj.load_obj``, passed to ``mesh.resize_mesh``, written back, and
    wrapped in a ``mesh.Mesh`` that uses the module-level trainable textures.

    Args:
        mesh_path: Path of the input mesh file (``str`` or path-like).
        output_path: Output directory (``str`` or ``pathlib.Path``). The
            ``tmp`` sub-directory is created, including missing parents.
        triangulate_flag: When truthy, run isotropic explicit remeshing.
        bsdf_flag: Value stored under the material's ``'bsdf'`` key.
        mesh_name: File name used for the temporary mesh.

    Returns:
        The ``mesh.Mesh`` built on top of the resized mesh.
    """
    from pathlib import Path

    ms = pymeshlab.MeshSet()
    # str() so that path-like inputs are accepted as well as plain strings.
    ms.load_new_mesh(str(mesh_path))

    if triangulate_flag:
        print('Retriangulating shape')
        ms.meshing_isotropic_explicit_remeshing()

    if not ms.current_mesh().has_wedge_tex_coord():
        # some arbitrarily high number
        ms.compute_texcoord_parametrization_triangle_trivial_per_wedge(textdim=10000)

    # Make sure the temporary directory exists before saving into it;
    # otherwise the first save fails on a fresh output directory.
    tmp_dir = Path(output_path) / 'tmp'
    tmp_dir.mkdir(parents=True, exist_ok=True)
    tmp_mesh_path = str(tmp_dir / mesh_name)

    ms.save_current_mesh(tmp_mesh_path)

    load_mesh = obj.load_obj(tmp_mesh_path)
    load_mesh = mesh.resize_mesh(load_mesh)

    # Push the resized geometry back into the MeshSet and overwrite the
    # temporary file with it.
    # NOTE(review): relies on add_mesh making the new mesh current -- confirm.
    ms.add_mesh(
        pymeshlab.Mesh(
            vertex_matrix=load_mesh.v_pos.cpu().numpy(),
            face_matrix=load_mesh.t_pos_idx.cpu().numpy(),
        )
    )
    ms.save_current_mesh(tmp_mesh_path, save_vertex_color=False)

    load_mesh = mesh.Mesh(
        material={
            'bsdf': bsdf_flag,
            'kd': texture_map,
            'ks': specular_map,
            'normal': normal_map,
        },
        base=load_mesh  # Get UVs from original loaded mesh
    )

    return load_mesh
def compute_mv_cl(final_mesh, fe, normalized_clip_render, params_camera, train_rast_map, cfg, device):
    """Compute the multi-view consistency loss between per-vertex deep features.

    Each vertex is projected into every rendered view, mapped to the feature
    patch it falls into, and the features of the same vertex seen from
    different views are compared by cosine similarity. Only view pairs whose
    elevation and azimuth differences are below the configured thresholds
    contribute.

    Args:
        final_mesh: Mesh exposing ``v_pos`` and ``t_pos_idx``.
        fe: Feature extractor, called on ``normalized_clip_render`` and
            providing ``old_stride`` / ``new_stride``; ``None`` disables the loss.
        normalized_clip_render: Rendered views passed to ``fe``.
        params_camera: Mapping with ``'mvp'``, ``'elev'`` and ``'azim'`` entries.
        train_rast_map: Rasteriser output; channel 3 is read as a face id.
            NOTE(review): presumably 1-based with 0 meaning background -- confirm.
        cfg: Config providing ``batch_size``, ``consistency_vit_layer``,
            ``consistency_elev_filter`` and ``consistency_azim_filter``
            (the filters are converted from degrees to radians).
        device: Device on which helper tensors are created.

    Returns:
        A scalar tensor: the mean of the non-zero pairwise cosine terms, or
        ``0.0`` when ``fe`` is ``None`` or no view pair contributes.
    """
    # Consistency loss
    # Check if fe is available
    if fe is None:
        print("Warning: CLIPVisualEncoder not available, skipping consistency loss")
        return torch.tensor(0.0, device=device)

    # Render resolution used throughout the vertex-to-pixel mapping.
    res = 224

    # Get mapping from vertex to pixels
    curr_vp_map = get_vp_map(final_mesh.v_pos, params_camera['mvp'], res)

    # Loop-invariant helpers: every vertex index once, and the sentinel pixel
    # coordinate assigned to vertices that are not visible in a view.
    all_verts = torch.arange(len(final_mesh.v_pos), device=device)
    off_screen = torch.tensor([res, res], device=device)

    for idx, rast_faces in enumerate(train_rast_map[:, :, :, 3].view(cfg.batch_size, -1)):
        # Keep only real face ids (> 0) and convert them to 0-based indices.
        # Filtering explicitly, instead of dropping the first unique value,
        # stays correct when a view contains no background pixel.
        face_ids = rast_faces.unique().long()
        u_faces = face_ids[face_ids > 0] - 1

        # A vertex that appears only once in the concatenation comes solely
        # from ``all_verts``, i.e. it belongs to no visible face.
        u_ret = torch.cat([all_verts, final_mesh.t_pos_idx[u_faces].flatten()]).unique(return_counts=True)
        non_verts = u_ret[0][u_ret[1] < 2]
        curr_vp_map[idx][non_verts] = off_screen

    # Get mapping from vertex to patch
    med = (fe.old_stride - 1) / 2
    curr_vp_map[curr_vp_map < med] = med
    curr_vp_map[(curr_vp_map > res - fe.old_stride) & (curr_vp_map < res)] = (res - 1) - med
    curr_patch_map = ((curr_vp_map - med) / fe.new_stride).round()
    # Multiplier that flattens the 2-D patch coordinate into a single index.
    patches_per_row = ((res - fe.old_stride) / fe.new_stride) + 1
    flat_patch_map = curr_patch_map[..., 0] * patches_per_row + curr_patch_map[..., 1]

    # Deep features
    patch_feats = fe(normalized_clip_render)
    n_patches = patch_feats[0].shape[-1]
    # Out-of-range indices (including the off-screen sentinel) point at the
    # extra zero column appended by the padding below.
    flat_patch_map[flat_patch_map > n_patches - 1] = n_patches
    flat_patch_map = flat_patch_map.long()[:, None, :].repeat(1, patch_feats[0].shape[1], 1)

    deep_feats = patch_feats[cfg.consistency_vit_layer]
    deep_feats = torch.nn.functional.pad(deep_feats, (0, 1))
    deep_feats = torch.gather(deep_feats, dim=2, index=flat_patch_map)
    deep_feats = torch.nn.functional.normalize(deep_feats, dim=1, eps=1e-6)

    # Boolean masks of view pairs that are close enough in elevation / azimuth.
    elev_d = torch.cdist(params_camera['elev'].unsqueeze(1), params_camera['elev'].unsqueeze(1)).abs() < torch.deg2rad(
        torch.tensor(cfg.consistency_elev_filter))
    azim_d = torch.cdist(params_camera['azim'].unsqueeze(1), params_camera['azim'].unsqueeze(1)).abs() < torch.deg2rad(
        torch.tensor(cfg.consistency_azim_filter))

    cosines = torch.einsum('ijk, lkj -> ilk', deep_feats, deep_feats.permute(0, 2, 1))
    # Mask filtered pairs and keep the strict upper triangle so each
    # unordered view pair is counted once and self-pairs are excluded.
    cosines = (cosines * azim_d.unsqueeze(-1) * elev_d.unsqueeze(-1)).permute(2, 0, 1).triu(1)

    contributing = cosines[cosines != 0]
    if contributing.numel() == 0:
        # The mean of an empty selection is NaN, which would poison the total
        # loss; report "no consistency signal" as zero instead.
        return torch.tensor(0.0, device=device)

    consistency_loss = contributing.mean()
    return consistency_loss