Spaces:
Sleeping
Sleeping
from lib.kits.basic import * | |
import pickle | |
from smplx.vertex_joint_selector import VertexJointSelector | |
from smplx.vertex_ids import vertex_ids | |
from smplx.lbs import vertices2joints | |
from lib.body_models.skel.skel_model import SKEL, SKELOutput | |
class SKELWrapper(SKEL):
    '''
    SKEL model whose output joints are re-arranged to fit SMPL's joint interface.

    `forward` returns the usual `SKELOutput`, with these fields set or replaced:
    - `joints`        : OpenPose-ordered joints, followed by extra regressed joints when an extra
                        regressor was provided.
    - `joints_custom` : the SMPL-ordered base joints, taken before vertex-joint selection.
    - `joints_backup` : SKEL's original `joints` output.
    '''

    def __init__(
        self,
        *args,
        joint_regressor_custom: Optional[str] = None,
        joint_regressor_extra : Optional[str] = None,
        update_root           : bool = False,
        **kwargs
    ):
        '''
        This wrapper aims to extend the output joints of the SKEL model which fits SMPL's portal.

        ### Args
        - *args, **kwargs
            - Forwarded unchanged to `SKEL.__init__`.
        - joint_regressor_custom: Optional[str], default None
            - Path to a pickled joint regressor (dense, or scipy-sparse which gets densified).
              When given, the base joints are regressed from the skin vertices instead of being
              picked from SKEL's own joints.
        - joint_regressor_extra: Optional[str], default None
            - Path to a pickled regressor for extra joints. They are appended after the OpenPose
              joints. When None, no extra joints are appended.
        - update_root: bool, default False
            - Whether `forward` pulls SKEL's root joint towards the lumbar -> hips-middle line.
        '''
        super().__init__(*args, **kwargs)

        # The final joints are combined from three parts:
        # 1. The joints from the standard output.
        # Map selected joints of interests from SKEL to SMPL. (Not all 24 joints will be used finally.)
        # Notes: Only these SMPL joints will be used: [0, 1, 2, 4, 5, 7, 8, 12, 16, 17, 18, 19, 20, 21, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 45, 46, 47, 48, 49, 50, 51, 52, 53]
        skel_to_smpl = [
             0,
             6,
             1,
            11,  # not aligned well; not used
             7,
             2,
            11,  # not aligned well; not used
             8,  # or 9
             3,  # or 4
            12,  # not aligned well; not used
            10,  # not used
             5,  # not used
            12,
            19,  # not aligned well; not used
            14,  # not aligned well; not used
            13,  # not used
            20,  # or 19
            15,  # or 14
            21,  # or 22
            16,  # or 17,
            23,
            18,
            23,  # not aligned well; not used
            18,  # not aligned well; not used
        ]
        # Indices >= 24 refer to the vertex-selected joints appended by `vertex_joint_selector`.
        smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]
        self.register_buffer('J_skel_to_smpl', torch.tensor(skel_to_smpl, dtype=torch.long))
        self.register_buffer('J_smpl_to_openpose', torch.tensor(smpl_to_openpose, dtype=torch.long))

        # (SKEL has the same topology as SMPL as well as SMPL-H, so perform the same operation for the other 2 parts.)
        # 2. Joints selected from skin vertices.
        self.vertex_joint_selector = VertexJointSelector(vertex_ids['smplh'])

        # 3. Extra joints from the J_regressor_extra.
        #    Use a context manager so the file handle is closed deterministically.
        if joint_regressor_extra is not None:
            with open(joint_regressor_extra, 'rb') as f:
                J_regressor_extra = pickle.load(f, encoding='latin1')
            self.register_buffer(
                'J_regressor_extra',
                torch.tensor(J_regressor_extra, dtype=torch.float32),
            )

        self.custom_regress_joints = joint_regressor_custom is not None
        if self.custom_regress_joints:
            get_logger().info('Using customized joint regressor.')
            with open(joint_regressor_custom, 'rb') as f:
                J_regressor_custom = pickle.load(f, encoding='latin1')
            if 'scipy.sparse' in str(type(J_regressor_custom)):
                J_regressor_custom = J_regressor_custom.todense()  # (24, 6890)
            self.register_buffer(
                'J_regressor_custom',
                torch.tensor(J_regressor_custom, dtype=torch.float32),
            )

        self.update_root = update_root

    def forward(self, **kwargs) -> SKELOutput:  # type: ignore
        '''
        Run SKEL, then map the order of joints of SKEL to SMPL's (and further to OpenPose's).

        ### Args
        - **kwargs
            - Forwarded to `SKEL.forward`. If 'trans' is absent, a zero translation of shape
              (B, 3) is supplied, with B taken from `kwargs['poses'].shape[0]`.

        ### Returns
        - SKELOutput, with:
            - `joints`        : (B, 25, 3) OpenPose-ordered joints. When an extra regressor was
                                loaded, its regressed joints are concatenated after them.
            - `joints_custom` : SMPL-ordered base joints. These are (B, 24, 3) in the default mapping
                                path; with a custom regressor, the count is its number of rows.
            - `joints_backup` : SKEL's original `joints` output.
        '''
        if 'trans' not in kwargs:
            poses = kwargs['poses']
            kwargs['trans'] = poses.new_zeros((poses.shape[0], 3))  # (B, 3)

        skel_output = super().forward(**kwargs)
        verts = skel_output.skin_verts  # (B, 6890, 3)
        joints = skel_output.joints.clone()  # (B, 24, 3); cloned so `joints_backup` stays untouched

        # Update the root joint position (to avoid the root too forward).
        # NOTE: this adjustment is discarded when a custom regressor is used, because the base
        # joints are then re-regressed from the vertices below.
        if self.update_root:
            # Make root 0 lie closer to the plane 11-1-6: project it onto the line running from
            # joint 11 (lumbar) towards the hips middle (joints 1 and 6), then move it 70% of the
            # way to that projection.
            hips_middle = (joints[:, 1] + joints[:, 6]) / 2  # (B, 3)
            lumbar2middle = hips_middle - joints[:, 11]  # (B, 3)
            lumbar2middle_unit = lumbar2middle / torch.norm(lumbar2middle, dim=1, keepdim=True)  # (B, 3)
            lumbar2root = joints[:, 0] - joints[:, 11]  # (B, 3)
            lumbar2root_proj = (
                torch.einsum('bc,bc->b', lumbar2root, lumbar2middle_unit)[:, None]
                * lumbar2middle_unit
            )  # (B, 3)
            root2root_proj = lumbar2root_proj - lumbar2root  # (B, 3)
            joints[:, 0] += root2root_proj * 0.7

        # Combine the joints from three parts:
        if self.custom_regress_joints:
            # 1.x. Regress joints from the skin vertices using SMPL's regressor.
            joints = vertices2joints(self.J_regressor_custom, verts)  # (B, 24, 3)
        else:
            # 1.y. Map selected joints of interests from SKEL to SMPL.
            joints = joints[:, self.J_skel_to_smpl]  # (B, 24, 3)
        joints_custom = joints.clone()

        # 2. Concat joints selected from skin vertices.
        joints = self.vertex_joint_selector(verts, joints)  # (B, 45, 3)
        # 3. Map selected joints to OpenPose.
        joints = joints[:, self.J_smpl_to_openpose]  # (B, 25, 3)
        # 4. Add extra joints from the J_regressor_extra, only if an extra regressor was loaded.
        #    (`joint_regressor_extra` is optional, so the buffer may not exist.)
        J_regressor_extra = getattr(self, 'J_regressor_extra', None)
        if J_regressor_extra is not None:
            joints_extra = vertices2joints(J_regressor_extra, verts)  # (B, 19, 3)
            joints = torch.cat([joints, joints_extra], dim=1)  # (B, 44, 3)

        # Update the joints in the output.
        skel_output.joints_backup = skel_output.joints
        skel_output.joints_custom = joints_custom
        skel_output.joints = joints
        return skel_output
def get_static_root_offset(skel_output):
    '''
    Return the translation that carries the custom root onto SKEL's own root, in static pose.

    Background: SKEL always applies the global orientation around its original root joint.
    To make that rotation pivot around the custom root instead, callers need this offset.

    ### Args
    - skel_output: an output carrying `joints_custom` and `joints_backup`
      (as set by `SKELWrapper.forward`).

    ### Returns
    - (B, 3) tensor equal to `joints_backup[:, 0] - joints_custom[:, 0]`,
      i.e. the vector pointing from custom_root to skel_root.
    '''
    return skel_output.joints_backup[:, 0] - skel_output.joints_custom[:, 0]  # (B, 3)