File size: 7,077 Bytes
f498ac0
 
 
 
 
 
 
d172fbe
f498ac0
 
 
 
 
 
0ba16db
 
d172fbe
 
 
 
 
0ba16db
 
d172fbe
 
 
 
 
 
f498ac0
0ba16db
 
 
f498ac0
0ba16db
 
 
 
 
 
 
 
 
 
 
 
 
 
d172fbe
 
 
 
 
 
 
 
 
 
0ba16db
 
 
 
 
f498ac0
0ba16db
 
 
 
 
 
 
 
 
d172fbe
 
 
 
 
 
 
 
 
0ba16db
d172fbe
0ba16db
 
 
 
 
f498ac0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c539b3
 
 
 
 
f498ac0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import pymeshlab
import torch
from nvdiffmodeling.src     import obj
from nvdiffmodeling.src     import mesh
from nvdiffmodeling.src     import texture
import numpy as np
from utilities.helpers import get_vp_map
import os

# Trainable material textures, built once at import time. get_mesh and
# get_og_mesh attach these very same objects to every Mesh they return, so all
# meshes produced by this module share one set of maps.
# NOTE(review): [512, 512] presumably is the texture resolution and the trailing
# True a mip-mapping flag — confirm against nvdiffmodeling's
# texture.create_trainable.
texture_map = texture.create_trainable(
    np.random.uniform(low=0.0, high=1.0, size=[512, 512, 3]),  # random RGB init in [0, 1)
    [512, 512],
    True,
)
normal_map = texture.create_trainable(np.array([0, 0, 1]), [512, 512], True)  # constant (0, 0, 1) init
specular_map = texture.create_trainable(np.array([0, 0, 0]), [512, 512], True)  # constant zero init

def get_mesh(mesh_path, output_path, triangulate_flag, bsdf_flag, mesh_name='mesh.obj'):
    """Load a mesh, rescale it with ``mesh.unit_size`` and attach trainable materials.

    The input is loaded with pymeshlab, optionally remeshed, and given trivial
    per-wedge UVs if it has none. It is then saved to
    ``<output_path>/tmp/<mesh_name>`` and re-read with ``obj.load_obj``, then
    rescaled with ``mesh.unit_size``. The rescaled geometry (without vertex
    colors) is written back to that same temporary file.

    Args:
        mesh_path: Input mesh file (str or path-like).
        output_path: Output directory (str or ``pathlib.Path``). A ``tmp``
            subdirectory is created inside it, including any missing parents.
        triangulate_flag: If truthy, run pymeshlab isotropic explicit remeshing.
        bsdf_flag: Value stored under the material's ``'bsdf'`` key.
        mesh_name: File name of the temporary mesh inside ``tmp``.

    Returns:
        A ``mesh.Mesh`` built on the loaded mesh, whose material uses the
        module-level ``texture_map`` / ``specular_map`` / ``normal_map`` objects.

    Raises:
        FileNotFoundError: If ``mesh_path`` does not exist.
        ValueError: If the input, the re-loaded, or the final mesh has no
            vertices or faces.
        Any other exception from pymeshlab / nvdiffmodeling is printed with a
        traceback and re-raised unchanged.
    """
    # Local imports keep this block self-contained; Path lets callers pass a
    # plain string as output_path.
    from pathlib import Path
    import traceback

    try:
        print(f"Loading mesh from: {mesh_path}")

        # Check if mesh file exists
        if not os.path.exists(mesh_path):
            raise FileNotFoundError(f"Mesh file not found: {mesh_path}")

        ms = pymeshlab.MeshSet()
        # pymeshlab's binding takes a str file name, so normalize path-likes.
        ms.load_new_mesh(str(mesh_path))

        # Check if mesh was loaded successfully
        if ms.current_mesh().vertex_number() == 0:
            raise ValueError(f"Mesh file {mesh_path} has no vertices")

        print(f"Loaded mesh with {ms.current_mesh().vertex_number()} vertices and {ms.current_mesh().face_number()} faces")

        if triangulate_flag:
            print('Retriangulating shape')
            ms.meshing_isotropic_explicit_remeshing()

        if not ms.current_mesh().has_wedge_tex_coord():
            # Generate trivial per-wedge UVs; textdim is some arbitrarily high number.
            ms.compute_texcoord_parametrization_triangle_trivial_per_wedge(textdim=10000)

        # Ensure the tmp directory exists; parents=True also creates a missing
        # output_path instead of failing with FileNotFoundError.
        tmp_dir = Path(output_path) / 'tmp'
        tmp_dir.mkdir(parents=True, exist_ok=True)

        tmp_mesh_path = tmp_dir / mesh_name
        print(f"Saving temporary mesh to: {tmp_mesh_path}")
        ms.save_current_mesh(str(tmp_mesh_path))

        # Re-read the saved file with the nvdiffmodeling OBJ loader.
        print(f"Loading OBJ from temporary path: {tmp_mesh_path}")
        load_mesh = obj.load_obj(str(tmp_mesh_path))

        # Check if mesh was loaded successfully
        if load_mesh.v_pos is None or load_mesh.v_pos.shape[0] == 0:
            raise ValueError(f"Failed to load mesh vertices from {tmp_mesh_path}")

        if load_mesh.t_pos_idx is None or load_mesh.t_pos_idx.shape[0] == 0:
            raise ValueError(f"Failed to load mesh faces from {tmp_mesh_path}")

        print(f"Loaded mesh with {load_mesh.v_pos.shape[0]} vertices and {load_mesh.t_pos_idx.shape[0]} faces")

        load_mesh = mesh.unit_size(load_mesh)

        # Overwrite the temporary file with the rescaled geometry (no vertex colors).
        ms.add_mesh(
            pymeshlab.Mesh(vertex_matrix=load_mesh.v_pos.cpu().numpy(), face_matrix=load_mesh.t_pos_idx.cpu().numpy()))
        ms.save_current_mesh(str(tmp_mesh_path), save_vertex_color=False)

        load_mesh = mesh.Mesh(
            material={
                'bsdf': bsdf_flag,
                'kd': texture_map,
                'ks': specular_map,
                'normal': normal_map,
            },
            base=load_mesh  # Get UVs from original loaded mesh
        )

        # Final check to ensure mesh is valid
        if load_mesh.v_pos is None or load_mesh.v_pos.shape[0] == 0:
            raise ValueError("Final mesh has no vertices")

        if load_mesh.t_pos_idx is None or load_mesh.t_pos_idx.shape[0] == 0:
            raise ValueError("Final mesh has no faces")

        print(f"Successfully loaded mesh with {load_mesh.v_pos.shape[0]} vertices and {load_mesh.t_pos_idx.shape[0]} faces")
        return load_mesh

    except Exception as e:
        # Log with a traceback, then propagate the original exception unchanged.
        print(f"Error in get_mesh: {e}")
        traceback.print_exc()
        raise


def get_og_mesh(mesh_path, output_path, triangulate_flag, bsdf_flag, mesh_name='mesh.obj'):
    """Load a mesh, rescale it with ``mesh.resize_mesh`` and attach trainable materials.

    Same pipeline as ``get_mesh`` but without its validation and logging, and
    rescaling with ``mesh.resize_mesh`` instead of ``mesh.unit_size``.
    ("og" presumably stands for "original" — confirm with callers.)

    Args:
        mesh_path: Input mesh file (str or path-like).
        output_path: Output directory (str or ``pathlib.Path``). A ``tmp``
            subdirectory is created inside it, including any missing parents.
        triangulate_flag: If truthy, run pymeshlab isotropic explicit remeshing.
        bsdf_flag: Value stored under the material's ``'bsdf'`` key.
        mesh_name: File name of the temporary mesh inside ``tmp``.

    Returns:
        A ``mesh.Mesh`` built on the loaded mesh, whose material uses the
        module-level ``texture_map`` / ``specular_map`` / ``normal_map`` objects.
    """
    # Local import keeps this block self-contained; Path lets callers pass a
    # plain string as output_path.
    from pathlib import Path

    ms = pymeshlab.MeshSet()
    # pymeshlab's binding takes a str file name, so normalize path-likes.
    ms.load_new_mesh(str(mesh_path))

    if triangulate_flag:
        print('Retriangulating shape')
        ms.meshing_isotropic_explicit_remeshing()

    if not ms.current_mesh().has_wedge_tex_coord():
        # Generate trivial per-wedge UVs; textdim is some arbitrarily high number.
        ms.compute_texcoord_parametrization_triangle_trivial_per_wedge(textdim=10000)

    # The tmp directory was previously assumed to exist; create it so a fresh
    # output directory does not make save_current_mesh fail.
    tmp_dir = Path(output_path) / 'tmp'
    tmp_dir.mkdir(parents=True, exist_ok=True)
    tmp_mesh_path = str(tmp_dir / mesh_name)

    ms.save_current_mesh(tmp_mesh_path)

    # Re-read the saved file with the nvdiffmodeling OBJ loader, then rescale.
    load_mesh = obj.load_obj(tmp_mesh_path)
    load_mesh = mesh.resize_mesh(load_mesh)

    # Overwrite the temporary file with the rescaled geometry (no vertex colors).
    ms.add_mesh(
        pymeshlab.Mesh(vertex_matrix=load_mesh.v_pos.cpu().numpy(), face_matrix=load_mesh.t_pos_idx.cpu().numpy()))
    ms.save_current_mesh(tmp_mesh_path, save_vertex_color=False)

    load_mesh = mesh.Mesh(
        material={
            'bsdf': bsdf_flag,
            'kd': texture_map,
            'ks': specular_map,
            'normal': normal_map,
        },
        base=load_mesh  # Get UVs from original loaded mesh
    )
    return load_mesh

def compute_mv_cl(final_mesh, fe, normalized_clip_render, params_camera, train_rast_map, cfg, device):
    """Compute the multi-view feature-consistency loss for a batch of renders.

    Each mesh vertex is projected into every view. For views where the vertex
    lies on a rasterized face, the deep patch feature covering its projection
    is gathered. The loss averages, over vertices and over view pairs whose
    elevation and azimuth differ by less than the configured thresholds, the
    cosine similarity of that vertex's features in the two views.

    Args:
        final_mesh: Mesh exposing ``v_pos`` (vertex positions) and
            ``t_pos_idx`` (per-face vertex indices).
        fe: Patch-feature extractor exposing ``old_stride`` / ``new_stride``.
            Calling it returns an indexable sequence of per-layer feature
            tensors. If ``None`` the loss is skipped.
        normalized_clip_render: Renders passed to ``fe``.
        params_camera: Dict with ``'mvp'``, ``'elev'`` and ``'azim'`` tensors.
            elev/azim are compared against thresholds converted from degrees,
            so presumably in radians — confirm.
        train_rast_map: Rasterizer output. Channel 3 is read as a 1-based face
            id with 0 for empty pixels (inferred from the ``- 1`` below —
            confirm against the rasterizer).
        cfg: Config providing ``batch_size``, ``consistency_vit_layer``,
            ``consistency_elev_filter`` and ``consistency_azim_filter`` (degrees).
        device: Device for tensors created here.

    Returns:
        Scalar tensor. It is ``0.0`` when ``fe`` is ``None``, or when no view
        pair contributes (e.g. a single view), instead of the NaN an empty
        mean would produce.
    """
    # Consistency loss
    # Check if fe is available
    if fe is None:
        print("Warning: CLIPVisualEncoder not available, skipping consistency loss")
        return torch.tensor(0.0, device=device)

    clip_res = 224  # side length in pixels of the renders fed to fe

    # Get mapping from vertex to pixels
    curr_vp_map = get_vp_map(final_mesh.v_pos, params_camera['mvp'], clip_res)

    # Vertices not on any rasterized face of a view are moved to the
    # out-of-image sentinel (clip_res, clip_res) so they gather padding below.
    all_verts = torch.arange(len(final_mesh.v_pos), device=device)  # loop-invariant
    face_id_rows = train_rast_map[:, :, :, 3].reshape(cfg.batch_size, -1)
    for idx, rast_faces in enumerate(face_id_rows):
        ids = rast_faces.unique().long()
        # Drop the background id explicitly rather than assuming it is the first
        # unique value: a view with no background pixels would otherwise lose a face.
        u_faces = ids[ids > 0] - 1
        u_ret = torch.cat([all_verts, final_mesh.t_pos_idx[u_faces].flatten()]).unique(return_counts=True)
        # count < 2: the vertex appears only via all_verts, i.e. in no visible face.
        non_verts = u_ret[0][u_ret[1] < 2]
        curr_vp_map[idx][non_verts] = torch.tensor([clip_res, clip_res], device=device)

    # Get mapping from vertex to patch (coordinates clamped to valid patch centers)
    med = (fe.old_stride - 1) / 2
    curr_vp_map[curr_vp_map < med] = med
    curr_vp_map[(curr_vp_map > clip_res - fe.old_stride) & (curr_vp_map < clip_res)] = clip_res - 1 - med
    curr_patch_map = ((curr_vp_map - med) / fe.new_stride).round()
    patches_per_row = ((clip_res - fe.old_stride) / fe.new_stride) + 1
    flat_patch_map = curr_patch_map[..., 0] * patches_per_row + curr_patch_map[..., 1]

    # Deep features
    patch_feats = fe(normalized_clip_render)
    n_patches = patch_feats[0].shape[-1]
    # Out-of-range indices (including the sentinel) point at one extra slot.
    flat_patch_map[flat_patch_map > n_patches - 1] = n_patches
    flat_patch_map = flat_patch_map.long()[:, None, :].repeat(1, patch_feats[0].shape[1], 1)

    deep_feats = patch_feats[cfg.consistency_vit_layer]
    # Zero-pad the last dim so index n_patches gathers a zero feature; it stays
    # zero after normalize, so its cosines are 0 and are dropped by the != 0 mask.
    deep_feats = torch.nn.functional.pad(deep_feats, (0, 1))
    deep_feats = torch.gather(deep_feats, dim=2, index=flat_patch_map)
    deep_feats = torch.nn.functional.normalize(deep_feats, dim=1, eps=1e-6)

    # Pairwise view masks: keep pairs whose elevation / azimuth differ by less
    # than the thresholds.
    # NOTE(review): plain absolute difference — azimuth wrap-around is not handled.
    elev = params_camera['elev'].unsqueeze(1)
    azim = params_camera['azim'].unsqueeze(1)
    elev_d = torch.cdist(elev, elev).abs() < torch.deg2rad(torch.tensor(cfg.consistency_elev_filter))
    azim_d = torch.cdist(azim, azim).abs() < torch.deg2rad(torch.tensor(cfg.consistency_azim_filter))

    # cosines[i, l, k]: similarity of vertex k's feature in view i and view l.
    cosines = torch.einsum('ijk, lkj -> ilk', deep_feats, deep_feats.permute(0, 2, 1))
    # Mask by view pair, move vertices first, and keep only pairs i < l.
    cosines = (cosines * azim_d.unsqueeze(-1) * elev_d.unsqueeze(-1)).permute(2, 0, 1).triu(1)
    contributing = cosines[cosines != 0]
    if contributing.numel() == 0:
        # e.g. a single view (triu(1) keeps no pairs) or no pair within the
        # filters: an empty mean is NaN and would poison the total loss.
        return torch.tensor(0.0, device=device)
    return contributing.mean()