from lib.kits.basic import *

import pickle

from smplx.vertex_joint_selector import VertexJointSelector
from smplx.vertex_ids import vertex_ids
from smplx.lbs import vertices2joints

from lib.body_models.skel.skel_model import SKEL, SKELOutput


class SKELWrapper(SKEL):
    def __init__(
        self,
        *args,
        joint_regressor_custom: Optional[str] = None,
        joint_regressor_extra : Optional[str] = None,
        update_root : bool = False,
        **kwargs
    ):
        '''
        Wrapper around SKEL whose output joints follow SMPL's joint conventions, extended
        with vertex-selected joints, an OpenPose ordering and (optionally) extra regressed joints.

        ### Args
        - *args, **kwargs
            - Forwarded unchanged to `SKEL.__init__`.
        - joint_regressor_custom: Optional[str], default None
            - Path to a pickled joint regressor (dense array or scipy.sparse matrix, expected (24, 6890)).
              If given, the 24 SMPL-ordered joints are regressed from the skin vertices with it instead
              of being picked from SKEL's own joints.
        - joint_regressor_extra: Optional[str], default None
            - Path to a pickled regressor for extra joints appended after the 25 OpenPose joints.
              If None, no extra joints are appended.
        - update_root: bool, default False
            - If True, move SKEL's root joint part of the way toward the lumbar→hips-middle line
              before mapping (only affects the non-custom path).
        '''
        super(SKELWrapper, self).__init__(*args, **kwargs)

        # The final joints are combined from three parts:
        # 1. The joints from the standard output.
        # Map selected joints of interests from SKEL to SMPL. (Not all 24 joints will be used finally.)
        # Notes: Only these SMPL joints will be used: [0, 1, 2, 4, 5, 7, 8, 12, 16, 17, 18, 19, 20, 21, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 45, 46, 47, 48, 49, 50, 51, 52, 53]
        skel_to_smpl = [
             0,
             6,
             1,
            11,  # not aligned well; not used
             7,
             2,
            11,  # not aligned well; not used
             8,  # or 9
             3,  # or 4
            12,  # not aligned well; not used
            10,  # not used
             5,  # not used
            12,
            19,  # not aligned well; not used
            14,  # not aligned well; not used
            13,  # not used
            20,  # or 19
            15,  # or 14
            21,  # or 22
            16,  # or 17,
            23,
            18,
            23,  # not aligned well; not used
            18,  # not aligned well; not used
        ]

        smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]

        self.register_buffer('J_skel_to_smpl', torch.tensor(skel_to_smpl, dtype=torch.long))
        self.register_buffer('J_smpl_to_openpose', torch.tensor(smpl_to_openpose, dtype=torch.long))
        # (SKEL has the same topology as SMPL as well as SMPL-H, so perform the same operation for the other 2 parts.)
        # 2. Joints selected from skin vertices.
        self.vertex_joint_selector = VertexJointSelector(vertex_ids['smplh'])
        # 3. Extra joints from the J_regressor_extra.
        # When no path is given the buffer is simply not registered; `forward` checks for it.
        if joint_regressor_extra is not None:
            self.register_buffer('J_regressor_extra', self._load_joint_regressor(joint_regressor_extra))

        self.custom_regress_joints = joint_regressor_custom is not None
        if self.custom_regress_joints:
            get_logger().info('Using customized joint regressor.')
            self.register_buffer('J_regressor_custom', self._load_joint_regressor(joint_regressor_custom))  # (24, 6890)

        self.update_root = update_root

    @staticmethod
    def _load_joint_regressor(path: str) -> torch.Tensor:
        '''
        Load a pickled joint regressor from `path` as a float32 tensor.

        scipy.sparse matrices are densified first, since `torch.tensor` cannot consume them directly.
        '''
        # NOTE: `pickle.load` can execute arbitrary code; only load regressor files from trusted sources.
        with open(path, 'rb') as f:
            regressor = pickle.load(f, encoding='latin1')
        if 'scipy.sparse' in str(type(regressor)):
            regressor = regressor.todense()
        return torch.tensor(regressor, dtype=torch.float32)

    def forward(self, **kwargs) -> SKELOutput:  # type: ignore
        '''
        Run SKEL, then remap its joints to SMPL's order and extend them to the OpenPose (+ extra) layout.

        If `trans` is not supplied, a zero translation of shape (B, 3) is created from `poses`.

        ### Returns
        The SKEL output with these fields set:
        - joints_backup: SKEL's original joints (B, 24, 3).
        - joints_custom: SMPL-ordered joints before vertex-selected / extra joints are added (B, 24, 3).
        - joints: the 25 OpenPose-ordered joints, followed by the extra regressed joints when
          `joint_regressor_extra` was given (e.g. (B, 44, 3) with a 19-joint extra regressor).
        '''
        if 'trans' not in kwargs:
            poses = kwargs['poses']
            kwargs['trans'] = poses.new_zeros((poses.shape[0], 3))  # (B, 3)

        skel_output = super(SKELWrapper, self).forward(**kwargs)
        verts = skel_output.skin_verts  # (B, 6890, 3)
        joints = skel_output.joints.clone()  # (B, 24, 3)

        # Update the root joint position (to keep the root from sitting too far forward).
        if self.update_root:
            # Move root 0 toward the line from lumbar joint 11 to the middle of hips 1 and 6:
            # project (root - lumbar) onto that direction and shift 70% of the way to the projection.
            hips_middle = (joints[:, 1] + joints[:, 6]) / 2  # (B, 3)
            lumbar2middle = hips_middle - joints[:, 11]  # (B, 3)
            lumbar2middle_unit = lumbar2middle / torch.norm(lumbar2middle, dim=1, keepdim=True)  # (B, 3)
            lumbar2root = joints[:, 0] - joints[:, 11]
            lumbar2root_proj = \
                torch.einsum('bc,bc->b', lumbar2root, lumbar2middle_unit)[:, None] * \
                lumbar2middle_unit  # (B, 3)
            root2root_proj = lumbar2root_proj - lumbar2root  # (B, 3)
            joints[:, 0] += root2root_proj * 0.7

        # Combine the joints from three parts:
        if self.custom_regress_joints:
            # 1.x. Regress joints from the skin vertices using SMPL's regressor.
            joints = vertices2joints(self.J_regressor_custom, verts)  # (B, 24, 3)
        else:
            # 1.y. Map selected joints of interests from SKEL to SMPL.
            joints = joints[:, self.J_skel_to_smpl]  # (B, 24, 3)
        joints_custom = joints.clone()
        # 2. Concat joints selected from skin vertices.
        joints = self.vertex_joint_selector(verts, joints)  # (B, 45, 3)
        # 3. Map selected joints to OpenPose.
        joints = joints[:, self.J_smpl_to_openpose]  # (B, 25, 3)
        # 4. Add extra joints from the J_regressor_extra (only if one was provided at construction).
        J_regressor_extra = getattr(self, 'J_regressor_extra', None)
        if J_regressor_extra is not None:
            joints_extra = vertices2joints(J_regressor_extra, verts)  # (B, 19, 3)
            joints = torch.cat([joints, joints_extra], dim=1)  # (B, 44, 3)

        # Update the joints in the output.
        skel_output.joints_backup = skel_output.joints
        skel_output.joints_custom = joints_custom
        skel_output.joints = joints

        return skel_output

    @staticmethod
    def get_static_root_offset(skel_output):
        '''
        Background:
            By default, the orientation rotation is always around the original skel_root.
            In order to make the orientation rotation around the custom_root, we need to calculate the translation offset.
        This function calculates the translation offset in static pose. (From custom_root to skel_root.)

        ### Args
        - skel_output: an output produced by `SKELWrapper.forward` (needs `joints_custom` and `joints_backup`).

        ### Returns
        - offset: torch.Tensor, (B, 3), `skel_root - custom_root`.
        '''
        custom_root = skel_output.joints_custom[:, 0]  # (B, 3)
        skel_root = skel_output.joints_backup[:, 0]  # (B, 3)
        offset = skel_root - custom_root  # (B, 3)
        return offset