Spaces:
Runtime error
Runtime error
Added visualise.py for visualising the predictions
Browse files- visualise.py +98 -0
visualise.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Visualisation code for SMPL-X model. This code is useful if you already have predictions.
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import os.path as osp
|
| 6 |
+
import numpy as np
|
| 7 |
+
import smplx
|
| 8 |
+
from smplx.joint_names import JOINT_NAMES
|
| 9 |
+
import torch
|
| 10 |
+
try:
|
| 11 |
+
CUR_DIR = osp.dirname(os.path.abspath(__file__))
|
| 12 |
+
except NameError:
|
| 13 |
+
CUR_DIR = os.getcwd()
|
| 14 |
+
sys.path.insert(0, osp.join(CUR_DIR, '..', 'main'))
|
| 15 |
+
sys.path.insert(0, osp.join(CUR_DIR , '..', 'common'))
|
| 16 |
+
import matplotlib.pyplot as plt
|
| 17 |
+
from mpl_toolkits.mplot3d import Axes3D
|
| 18 |
+
|
| 19 |
+
JOINT_NAMES_DICT = {name: i for i, name in enumerate(JOINT_NAMES)}
|
| 20 |
+
|
| 21 |
+
# Load the SMPL-X model
|
| 22 |
+
model_path = 'common/utils/human_model_files' # Update with the path to your SMPL-X models
|
| 23 |
+
model = smplx.create(model_path, model_type='smplx', gender='neutral', ext='npz')
|
| 24 |
+
|
| 25 |
+
# Load the parameters from the .npz file
|
| 26 |
+
data = np.load('/home/sahand/Downloads/smplx/00047_9.npz')
|
| 27 |
+
|
| 28 |
+
betas = torch.tensor(data['betas'], dtype=torch.float32)
|
| 29 |
+
body_pose = torch.tensor(data['body_pose'], dtype=torch.float32)
|
| 30 |
+
global_orient = torch.tensor(data['global_orient'], dtype=torch.float32)
|
| 31 |
+
transl = torch.tensor(data['transl'], dtype=torch.float32)
|
| 32 |
+
expression = torch.tensor(data['expression'], dtype=torch.float32)
|
| 33 |
+
|
| 34 |
+
# Add missing dimensions to the tensors
|
| 35 |
+
if betas.ndim == 1:
|
| 36 |
+
betas = betas.unsqueeze(0)
|
| 37 |
+
if body_pose.ndim == 2:
|
| 38 |
+
body_pose = body_pose.unsqueeze(0)
|
| 39 |
+
if global_orient.ndim == 1:
|
| 40 |
+
global_orient = global_orient.unsqueeze(0)
|
| 41 |
+
if transl.ndim == 1:
|
| 42 |
+
transl = transl.unsqueeze(0)
|
| 43 |
+
if expression.ndim == 1:
|
| 44 |
+
expression = expression.unsqueeze(0)
|
| 45 |
+
|
| 46 |
+
# Reshape body_pose to include the batch dimension
|
| 47 |
+
body_pose = body_pose.view(1, -1, 3)
|
| 48 |
+
|
| 49 |
+
# Forward pass through the model
|
| 50 |
+
output = model(betas=betas, body_pose=body_pose, global_orient=global_orient, transl=transl, expression=expression)
|
| 51 |
+
|
| 52 |
+
# Extract joint positions
|
| 53 |
+
joints = output.joints.detach().cpu().numpy().squeeze()
|
| 54 |
+
print(joints.shape)
|
| 55 |
+
# Ankle joints (left and right)
|
| 56 |
+
left_knee = joints[4] # Index for left ankle in SMPL-X
|
| 57 |
+
right_knee = joints[5] # Index for right ankle in SMPL-X
|
| 58 |
+
left_ankle = joints[7] # Index for left ankle in SMPL-X
|
| 59 |
+
right_ankle = joints[8] # Index for right ankle in SMPL-X
|
| 60 |
+
|
| 61 |
+
bone_connections = [
|
| 62 |
+
(JOINT_NAMES_DICT["pelvis"], JOINT_NAMES_DICT["spine1"]), (JOINT_NAMES_DICT["spine1"], JOINT_NAMES_DICT["spine2"]), (JOINT_NAMES_DICT["spine2"], JOINT_NAMES_DICT["spine3"]), # Spine
|
| 63 |
+
(JOINT_NAMES_DICT["pelvis"], JOINT_NAMES_DICT["left_hip"]), (JOINT_NAMES_DICT["left_hip"], JOINT_NAMES_DICT["left_knee"]), (JOINT_NAMES_DICT["left_knee"], JOINT_NAMES_DICT["left_ankle"]), # Left leg
|
| 64 |
+
(JOINT_NAMES_DICT["pelvis"], JOINT_NAMES_DICT["right_hip"]), (JOINT_NAMES_DICT["right_hip"], JOINT_NAMES_DICT["right_knee"]), (JOINT_NAMES_DICT["right_knee"], JOINT_NAMES_DICT["right_ankle"]), # Right leg
|
| 65 |
+
(JOINT_NAMES_DICT["left_ankle"], JOINT_NAMES_DICT["left_heel"]),
|
| 66 |
+
(JOINT_NAMES_DICT["right_ankle"], JOINT_NAMES_DICT["right_heel"]),
|
| 67 |
+
(JOINT_NAMES_DICT["left_ankle"], JOINT_NAMES_DICT["left_foot"]),
|
| 68 |
+
(JOINT_NAMES_DICT["left_foot"], JOINT_NAMES_DICT["left_big_toe"]), (JOINT_NAMES_DICT["left_foot"], JOINT_NAMES_DICT["left_small_toe"]),
|
| 69 |
+
(JOINT_NAMES_DICT["right_ankle"], JOINT_NAMES_DICT["right_foot"]),
|
| 70 |
+
(JOINT_NAMES_DICT["right_foot"], JOINT_NAMES_DICT["right_big_toe"]), (JOINT_NAMES_DICT["right_foot"], JOINT_NAMES_DICT["right_small_toe"]),
|
| 71 |
+
# Add more bones if necessary
|
| 72 |
+
]
|
| 73 |
+
|
| 74 |
+
# Visualize the 3D skeleton
|
| 75 |
+
fig = plt.figure()
|
| 76 |
+
ax = fig.add_subplot(111, projection='3d')
|
| 77 |
+
|
| 78 |
+
# Plot all joints
|
| 79 |
+
ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], c='blue', marker='o')
|
| 80 |
+
# Highlight ankle joints
|
| 81 |
+
ax.scatter([left_knee[0]], [left_knee[1]], [left_knee[2]], c='red', marker='x', s=100, label='Left Knee')
|
| 82 |
+
ax.scatter([right_knee[0]], [right_knee[1]], [right_knee[2]], c='green', marker='x', s=100, label='Right Knee')
|
| 83 |
+
ax.scatter([left_ankle[0]], [left_ankle[1]], [left_ankle[2]], c='red', marker='o', s=100, label='Left Ankle')
|
| 84 |
+
ax.scatter([right_ankle[0]], [right_ankle[1]], [right_ankle[2]], c='green', marker='o', s=100, label='Right Ankle')
|
| 85 |
+
|
| 86 |
+
# Draw bones
|
| 87 |
+
for bone in bone_connections:
|
| 88 |
+
start, end = bone
|
| 89 |
+
ax.plot([joints[start, 0], joints[end, 0]],
|
| 90 |
+
[joints[start, 1], joints[end, 1]],
|
| 91 |
+
[joints[start, 2], joints[end, 2]], 'k-')
|
| 92 |
+
|
| 93 |
+
# Set labels
|
| 94 |
+
ax.set_xlabel('X')
|
| 95 |
+
ax.set_ylabel('Y')
|
| 96 |
+
ax.set_zlabel('Z')
|
| 97 |
+
ax.legend()
|
| 98 |
+
plt.show()
|