LutaoJiang commited on
Commit
ffbc047
·
1 Parent(s): fc24b20
models/DiMeR/models/DiMeR.py CHANGED
@@ -109,6 +109,7 @@ class DiMeR(nn.Module):
109
 
110
  return planes
111
 
 
112
  def get_sdf_prediction(self, planes):
113
  '''
114
  Predict SDF and deformation for tetrahedron vertices
@@ -126,6 +127,7 @@ class DiMeR(nn.Module):
126
 
127
  return sdf
128
 
 
129
  def get_sdf_deformation_prediction(self, planes):
130
  '''
131
  Predict SDF and deformation for tetrahedron vertices
@@ -183,6 +185,7 @@ class DiMeR(nn.Module):
183
  deformation = torch.cat(final_def, dim=0)
184
  return sdf, deformation, sdf_reg_loss, weight
185
 
 
186
  def get_geometry_prediction(self, planes=None):
187
  '''
188
  Function to generate mesh with give triplanes
@@ -235,6 +238,7 @@ class DiMeR(nn.Module):
235
 
236
  return v_list, f_list, imesh_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg)
237
 
 
238
  def get_texture_prediction(self, planes, tex_pos, hard_mask=None, gb_normal=None, training=True):
239
  '''
240
  Predict Texture given triplanes
@@ -474,6 +478,7 @@ class DiMeR(nn.Module):
474
  **out
475
  }
476
 
 
477
  def extract_mesh(
478
  self,
479
  planes: torch.Tensor,
 
109
 
110
  return planes
111
 
112
+ @spaces.GPU
113
  def get_sdf_prediction(self, planes):
114
  '''
115
  Predict SDF and deformation for tetrahedron vertices
 
127
 
128
  return sdf
129
 
130
+ @spaces.GPU
131
  def get_sdf_deformation_prediction(self, planes):
132
  '''
133
  Predict SDF and deformation for tetrahedron vertices
 
185
  deformation = torch.cat(final_def, dim=0)
186
  return sdf, deformation, sdf_reg_loss, weight
187
 
188
+ @spaces.GPU
189
  def get_geometry_prediction(self, planes=None):
190
  '''
191
  Function to generate mesh with give triplanes
 
238
 
239
  return v_list, f_list, imesh_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg)
240
 
241
+ @spaces.GPU
242
  def get_texture_prediction(self, planes, tex_pos, hard_mask=None, gb_normal=None, training=True):
243
  '''
244
  Predict Texture given triplanes
 
478
  **out
479
  }
480
 
481
+ @spaces.GPU
482
  def extract_mesh(
483
  self,
484
  planes: torch.Tensor,
models/DiMeR/models/renderer/synthesizer_mesh.py CHANGED
@@ -9,6 +9,7 @@ import itertools
9
  import torch
10
  import torch.nn as nn
11
  import torch.nn.functional as F
 
12
 
13
  from .utils.renderer import generate_planes, project_onto_planes, sample_from_planes
14
 
@@ -94,6 +95,7 @@ class OSGDecoder(nn.Module):
94
 
95
  return sdf, deformation, weight
96
 
 
97
  def get_texture_prediction(self, sampled_features):
98
  _N, n_planes, _M, _C = sampled_features.shape
99
  sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
@@ -156,6 +158,7 @@ class TriplaneSynthesizer(nn.Module):
156
  sdf, deformation, weight = self.decoder.get_geometry_prediction(sampled_features, flexicubes_indices)
157
  return sdf, deformation, weight
158
 
 
159
  def get_texture_prediction(self, planes, sample_coordinates):
160
  plane_axes = self.plane_axes.to(planes.device)
161
  sampled_features = sample_from_planes(
 
9
  import torch
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
+ import spaces
13
 
14
  from .utils.renderer import generate_planes, project_onto_planes, sample_from_planes
15
 
 
95
 
96
  return sdf, deformation, weight
97
 
98
+ @spaces.GPU
99
  def get_texture_prediction(self, sampled_features):
100
  _N, n_planes, _M, _C = sampled_features.shape
101
  sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
 
158
  sdf, deformation, weight = self.decoder.get_geometry_prediction(sampled_features, flexicubes_indices)
159
  return sdf, deformation, weight
160
 
161
+ @spaces.GPU
162
  def get_texture_prediction(self, planes, sample_coordinates):
163
  plane_axes = self.plane_axes.to(planes.device)
164
  sampled_features = sample_from_planes(