Spaces:
Paused
Paused
| import copy | |
| import random | |
| import numpy as np | |
| from scipy.special import softmax | |
| from ..models.stitch_network import StitchNetwork | |
"""
Usage notes for MotionStitch.

# __init__
stitch_network_cfg = {
    "model_path": "",
    "device": "cuda",
}
(the dict is forwarded to StitchNetwork as keyword arguments)

# __call__
kwargs (all optional):
    fade_alpha
    fade_out_keys
    delta_pitch
    delta_yaw
    delta_roll
    alpha_pitch
    alpha_yaw
    alpha_roll
    delta_exp
    vad_alpha
"""
def ctrl_motion(x_d_info, **kwargs):
    """Apply optional pose / expression controls to ``x_d_info`` in place.

    Recognised keyword arguments (anything else is ignored):
        delta_pitch, delta_yaw, delta_roll: value added to the matching pose
            entry after it has been passed through ``bin66_to_degree``.
        alpha_pitch, alpha_yaw, alpha_roll: factor the matching pose entry is
            multiplied by (runs after the delta step).
        delta_exp: value added to ``x_d_info["exp"]``.

    Returns:
        The same ``x_d_info`` dict that was passed in.
    """
    pose_names = ("pitch", "yaw", "roll")

    # pose + offset
    for name in pose_names:
        offset_key = "delta_" + name
        if offset_key in kwargs:
            x_d_info[name] = bin66_to_degree(x_d_info[name]) + kwargs[offset_key]

    # pose * alpha
    # NOTE(review): unlike the offset step, the entry is not converted with
    # bin66_to_degree before scaling -- confirm this is intended.
    for name in pose_names:
        gain_key = "alpha_" + name
        if gain_key in kwargs:
            x_d_info[name] = x_d_info[name] * kwargs[gain_key]

    # exp + offset
    if "delta_exp" in kwargs:
        x_d_info["exp"] = x_d_info["exp"] + kwargs["delta_exp"]

    return x_d_info
def fade(x_d_info, dst, alpha, keys=None):
    """Blend entries of ``x_d_info`` toward ``dst`` in place.

    Each selected entry becomes ``x_d_info[k] * alpha + dst[k] * (1 - alpha)``.
    The ``'kp'`` entry is never touched, even when it is listed in ``keys``.

    Args:
        x_d_info: dict to update (mutated and returned).
        dst: dict providing the blend target for every selected key.
        alpha: weight kept from ``x_d_info``; ``1 - alpha`` comes from ``dst``.
        keys: iterable of keys to blend; ``None`` selects every key of
            ``x_d_info``.

    Returns:
        The same ``x_d_info`` dict.
    """
    selected = x_d_info.keys() if keys is None else keys
    for name in selected:
        if name != 'kp':
            x_d_info[name] = x_d_info[name] * alpha + dst[name] * (1 - alpha)
    return x_d_info
def ctrl_vad(x_d_info, dst, alpha):
    """Blend the ``"exp"`` entry of ``x_d_info`` toward ``dst["exp"]`` in place.

    The result is ``x_d_info["exp"] * alpha + dst["exp"] * (1 - alpha)``,
    applied uniformly to every element of the expression.

    Args:
        x_d_info: dict whose ``"exp"`` entry is replaced (mutated and returned).
        dst: dict providing the ``"exp"`` blend target.
        alpha: weight kept from ``x_d_info["exp"]``.

    Returns:
        The same ``x_d_info`` dict.
    """
    # NOTE(review): an earlier revision built a lip-only (21, 3) mask here but
    # never used it; the blend has always been applied with the scalar alpha.
    # The dead mask was removed -- confirm a lip-only blend was not intended.
    x_d_info["exp"] = x_d_info["exp"] * alpha + dst["exp"] * (1 - alpha)
    return x_d_info
| def _mix_s_d_info( | |
| x_s_info, | |
| x_d_info, | |
| use_d_keys=("exp", "pitch", "yaw", "roll", "t"), | |
| d0=None, | |
| ): | |
| if d0 is not None: | |
| if isinstance(use_d_keys, dict): | |
| x_d_info = { | |
| k: x_s_info[k] + (v - d0[k]) * use_d_keys.get(k, 1) | |
| for k, v in x_d_info.items() | |
| } | |
| else: | |
| x_d_info = {k: x_s_info[k] + (v - d0[k]) for k, v in x_d_info.items()} | |
| for k, v in x_s_info.items(): | |
| if k not in x_d_info or k not in use_d_keys: | |
| x_d_info[k] = v | |
| if isinstance(use_d_keys, dict) and d0 is None: | |
| for k, alpha in use_d_keys.items(): | |
| x_d_info[k] *= alpha | |
| return x_d_info | |
| def _set_eye_blink_idx(N, blink_n=15, open_n=-1): | |
| """ | |
| open_n: | |
| -1: no blink | |
| 0: random open_n | |
| >0: fix open_n | |
| list: loop open_n | |
| """ | |
| OPEN_MIN = 60 | |
| OPEN_MAX = 100 | |
| idx = [0] * N | |
| if isinstance(open_n, int): | |
| if open_n < 0: # no blink | |
| return idx | |
| elif open_n > 0: # fix open_n | |
| open_ns = [open_n] | |
| else: # open_n == 0: # random open_n, 60-100 | |
| open_ns = [] | |
| elif isinstance(open_n, list): | |
| open_ns = open_n # loop open_n | |
| else: | |
| raise ValueError() | |
| blink_idx = list(range(blink_n)) | |
| start_n = open_ns[0] if open_ns else random.randint(OPEN_MIN, OPEN_MAX) | |
| end_n = open_ns[-1] if open_ns else random.randint(OPEN_MIN, OPEN_MAX) | |
| max_i = N - max(end_n, blink_n) | |
| cur_i = start_n | |
| cur_n_i = 1 | |
| while cur_i < max_i: | |
| idx[cur_i : cur_i + blink_n] = blink_idx | |
| if open_ns: | |
| cur_n = open_ns[cur_n_i % len(open_ns)] | |
| cur_n_i += 1 | |
| else: | |
| cur_n = random.randint(OPEN_MIN, OPEN_MAX) | |
| cur_i = cur_i + blink_n + cur_n | |
| return idx | |
| def _fix_exp_for_x_d_info(x_d_info, x_s_info, delta_eye=None, drive_eye=True): | |
| _eye = [11, 13, 15, 16, 18] | |
| _lip = [6, 12, 14, 17, 19, 20] | |
| alpha = np.zeros((21, 3), dtype=x_d_info["exp"].dtype) | |
| alpha[_lip] = 1 | |
| if delta_eye is None and drive_eye: # use d eye | |
| alpha[_eye] = 1 | |
| alpha = alpha.reshape(1, -1) | |
| x_d_info["exp"] = x_d_info["exp"] * alpha + x_s_info["exp"] * (1 - alpha) | |
| if delta_eye is not None and drive_eye: | |
| alpha = np.zeros((21, 3), dtype=x_d_info["exp"].dtype) | |
| alpha[_eye] = 1 | |
| alpha = alpha.reshape(1, -1) | |
| x_d_info["exp"] = (delta_eye + x_s_info["exp"]) * alpha + x_d_info["exp"] * ( | |
| 1 - alpha | |
| ) | |
| return x_d_info | |
| def _fix_exp_for_x_d_info_v2(x_d_info, x_s_info, delta_eye, a1, a2, a3): | |
| x_d_info["exp"] = x_d_info["exp"] * a1 + x_s_info["exp"] * a2 + delta_eye * a3 | |
| return x_d_info | |
def bin66_to_degree(pred):
    """Convert 66-bin pose predictions to angles; pass anything else through.

    When ``pred`` has more than one dimension and its second dimension is 66,
    a softmax is taken over that dimension and the expected bin index is
    mapped to ``index * 3 - 97.5`` (so bin 0 is -97.5 and bin 65 is 97.5).
    The result then has one value per row. Any other input is returned
    unchanged (the same object).
    """
    looks_like_bins = pred.ndim > 1 and pred.shape[1] == 66
    if not looks_like_bins:
        return pred

    bin_index = np.arange(66).astype(np.float32)
    prob = softmax(pred, axis=1)
    expected_bin = np.sum(prob * bin_index, axis=1)
    return expected_bin * 3 - 97.5
| def _eye_delta(exp, dx=0, dy=0): | |
| if dx > 0: | |
| exp[0, 33] += dx * 0.0007 | |
| exp[0, 45] += dx * 0.001 | |
| else: | |
| exp[0, 33] += dx * 0.001 | |
| exp[0, 45] += dx * 0.0007 | |
| exp[0, 34] += dy * -0.001 | |
| exp[0, 46] += dy * -0.001 | |
| return exp | |
def _fix_gaze(pose_s, x_d_info):
    """Shift the expression's eye entries by the head-pose difference.

    Args:
        pose_s: ``(yaw, pitch)`` pair of the source.
        x_d_info: driving dict; its ``'yaw'`` and ``'pitch'`` entries are
            converted with ``bin66_to_degree`` and must each reduce to a
            single value (``.item()`` is called on them).

    The yaw difference scaled by 0.26 and the pitch difference scaled by 0.28
    are handed to ``_eye_delta`` as ``dx`` / ``dy``; the result is stored back
    into ``x_d_info['exp']``.

    Returns:
        The same ``x_d_info`` dict.
    """
    yaw_to_dx = 0.26
    pitch_to_dy = 0.28

    yaw_src, pitch_src = pose_s
    yaw_drv = bin66_to_degree(x_d_info['yaw']).item()
    pitch_drv = bin66_to_degree(x_d_info['pitch']).item()

    dx = (yaw_drv - yaw_src) * yaw_to_dx
    dy = (pitch_drv - pitch_src) * pitch_to_dy

    x_d_info['exp'] = _eye_delta(x_d_info['exp'], dx, dy)
    return x_d_info
def get_rotation_matrix(pitch_, yaw_, roll_):
    """Build batched 3x3 rotation matrices from Euler angles given in degree.

    Each input is an array of shape ``(bs,)`` or ``(bs, 1)``. Per-axis
    rotations about x (pitch), y (yaw) and z (roll) are composed as
    ``rot_z @ rot_y @ rot_x`` and the transpose of that product is returned.

    Returns:
        Array of shape ``(bs, 3, 3)``.
    """
    def _as_radian_column(deg):
        # degree -> radian, then make sure the angle is a (bs, 1) column
        rad = deg / 180 * np.pi
        return rad[:, None] if rad.ndim == 1 else rad

    x = _as_radian_column(pitch_)
    y = _as_radian_column(yaw_)
    z = _as_radian_column(roll_)

    bs = x.shape[0]
    one = np.ones((bs, 1), dtype=np.float32)
    zero = np.zeros((bs, 1), dtype=np.float32)

    cos_x, sin_x = np.cos(x), np.sin(x)
    cos_y, sin_y = np.cos(y), np.sin(y)
    cos_z, sin_z = np.cos(z), np.sin(z)

    def _as_matrix(entries):
        # nine (bs, 1) columns in row-major order -> (bs, 3, 3)
        return np.concatenate(entries, axis=1).reshape(bs, 3, 3)

    rot_x = _as_matrix([
        one, zero, zero,
        zero, cos_x, -sin_x,
        zero, sin_x, cos_x,
    ])
    rot_y = _as_matrix([
        cos_y, zero, sin_y,
        zero, one, zero,
        -sin_y, zero, cos_y,
    ])
    rot_z = _as_matrix([
        cos_z, -sin_z, zero,
        sin_z, cos_z, zero,
        zero, zero, one,
    ])

    rot = rot_z @ rot_y @ rot_x
    return rot.transpose(0, 2, 1)
def transform_keypoint(kp_info: dict):
    """
    Transform the implicit keypoints with the pose, shift, and expression
    deformation.

    ``kp_info`` must provide 'kp', 'pitch', 'yaw', 'roll', 't', 'exp' and
    'scale'. 'kp' is either Bx(N*3) or BxNx3; the pose entries are first
    passed through ``bin66_to_degree``.

    Returns:
        The transformed keypoints, BxNx3.
    """
    kp = kp_info['kp']  # (bs, k, 3) or (bs, k * 3)
    pitch = bin66_to_degree(kp_info['pitch'])
    yaw = bin66_to_degree(kp_info['yaw'])
    roll = bin66_to_degree(kp_info['roll'])
    t = kp_info['t']
    exp = kp_info['exp']
    scale = kp_info['scale']

    bs = kp.shape[0]
    # flat layout packs three coordinates per keypoint
    num_kp = kp.shape[1] // 3 if kp.ndim == 2 else kp.shape[1]

    rot_mat = get_rotation_matrix(pitch, yaw, roll)  # (bs, 3, 3)

    # Eqn.2: s * (R * x_c,s + exp) + t
    rotated = kp.reshape(bs, num_kp, 3) @ rot_mat
    kp_transformed = rotated + exp.reshape(bs, num_kp, 3)
    kp_transformed *= scale[..., None]
    # only tx / ty are applied; the z component of t is dropped
    kp_transformed[:, :, 0:2] += t[:, None, 0:2]

    return kp_transformed
class MotionStitch:
    """Mix source and driving motion per frame and produce stitched keypoints.

    Typical use: construct once, call ``setup`` to configure a session, then
    call the instance once per driving frame. ``set_Nd`` may be called to
    change the number of driving frames after ``setup``.
    """

    def __init__(
        self,
        stitch_network_cfg,
    ):
        """Create the stitching network.

        Args:
            stitch_network_cfg: dict forwarded to ``StitchNetwork`` as keyword
                arguments.
        """
        self.stitch_net = StitchNetwork(**stitch_network_cfg)

    def _refresh_blink_idx(self):
        """Rebuild the blink index table when eye deltas are being driven."""
        if self.drive_eye and self.delta_eye_arr is not None:
            # -1 means the frame count is unknown; fall back to 3000 frames.
            N = 3000 if self.N_d == -1 else self.N_d
            self.delta_eye_idx_list = _set_eye_blink_idx(
                N, len(self.delta_eye_arr), self.delta_eye_open_n
            )

    def set_Nd(self, N_d=-1):
        """Update the driving frame count and rebuild the blink table.

        Does nothing when ``N_d`` equals the current value. Requires ``setup``
        to have been called (it reads attributes that ``setup`` assigns).
        """
        # only for offline (make start|end eye open)
        if N_d == self.N_d:
            return
        self.N_d = N_d
        self._refresh_blink_idx()

    def setup(
        self,
        N_d=-1,
        use_d_keys=None,
        relative_d=True,
        drive_eye=None,  # use d eye or s eye
        delta_eye_arr=None,  # fix eye
        delta_eye_open_n=-1,  # int|list
        fade_out_keys=("exp",),
        fade_type="",  # "" | "d0" | "s"
        flag_stitching=True,
        is_image_flag=True,
        x_s_info=None,
        d0=None,
        ch_info=None,
        overall_ctrl_info=None,
    ):
        """Configure a session and reset the per-frame state.

        Args:
            N_d: number of driving frames; -1 when unknown.
            use_d_keys: keys taken from the driving motion (sequence, or dict
                of key -> weight). ``None`` selects
                ("exp", "pitch", "yaw", "roll", "t") when ``is_image_flag`` is
                set and ("exp",) otherwise.
            relative_d: when true and ``d0`` is ``None``, the first driving
                frame seen by ``__call__`` is stored as the reference ``d0``.
            drive_eye: whether eye motion is driven; ``None`` follows
                ``is_image_flag``.
            delta_eye_arr: optional sequence of eye deltas, indexed per frame
                through the blink table.
            delta_eye_open_n: ``open_n`` argument for ``_set_eye_blink_idx``.
            fade_out_keys: default keys blended by the fade step.
            fade_type: "" (no fade), "d0" (fade toward the first processed
                driving frame) or "s" (fade toward the source).
            flag_stitching: whether the stitching network is applied.
            is_image_flag: selects the defaults above and enables caching of
                source-derived values across calls.
            x_s_info: optional source motion dict used to precompute the
                source pose, source keypoints and the fade target.
            d0: optional reference driving motion.
            ch_info: optional dict; ``ch_info['x_s_info_lst'][0]['scale']``
                is compared with the source scale to derive a scale ratio.
            overall_ctrl_info: optional dict of default keyword arguments for
                every ``__call__``.
        """
        self.is_image_flag = is_image_flag
        if use_d_keys is None:
            if self.is_image_flag:
                self.use_d_keys = ("exp", "pitch", "yaw", "roll", "t")
            else:
                self.use_d_keys = ("exp", )
        else:
            self.use_d_keys = use_d_keys

        if drive_eye is None:
            if self.is_image_flag:
                self.drive_eye = True
            else:
                self.drive_eye = False
        else:
            self.drive_eye = drive_eye

        self.N_d = N_d
        self.relative_d = relative_d
        self.delta_eye_arr = delta_eye_arr
        self.delta_eye_open_n = delta_eye_open_n
        self.fade_out_keys = fade_out_keys
        self.fade_type = fade_type
        self.flag_stitching = flag_stitching

        # Precompute the three weights used by _fix_exp_for_x_d_info_v2:
        # a1 scales the driving expression, a2 the source expression and
        # a3 the eye delta.
        _eye = [11, 13, 15, 16, 18]
        _lip = [6, 12, 14, 17, 19, 20]
        _a1 = np.zeros((21, 3), dtype=np.float32)
        _a1[_lip] = 1
        _a2 = 0
        if self.drive_eye:
            if self.delta_eye_arr is None:
                _a1[_eye] = 1
            else:
                _a2 = np.zeros((21, 3), dtype=np.float32)
                _a2[_eye] = 1
                _a2 = _a2.reshape(1, -1)
        _a1 = _a1.reshape(1, -1)

        self.fix_exp_a1 = _a1 * (1 - _a2)
        self.fix_exp_a2 = (1 - _a1) + _a1 * _a2
        self.fix_exp_a3 = _a2

        self._refresh_blink_idx()

        self.pose_s = None
        self.x_s = None
        self.fade_dst = None
        if self.is_image_flag and x_s_info is not None:
            # A single source image: pose, keypoints and fade target can be
            # computed once here instead of on every call.
            yaw_s = bin66_to_degree(x_s_info['yaw']).item()
            pitch_s = bin66_to_degree(x_s_info['pitch']).item()
            self.pose_s = [yaw_s, pitch_s]
            self.x_s = transform_keypoint(x_s_info)

            if self.fade_type == "s":
                self.fade_dst = copy.deepcopy(x_s_info)

        if ch_info is not None:
            self.scale_a = ch_info['x_s_info_lst'][0]['scale'].item()
            if x_s_info is not None:
                self.scale_b = x_s_info['scale'].item()
                self.scale_ratio = self.scale_a / self.scale_b
                self._set_scale_ratio(self.scale_ratio)
            else:
                # Source scale not known yet; resolved on the first __call__.
                self.scale_ratio = None
        else:
            self.scale_ratio = 1

        self.overall_ctrl_info = overall_ctrl_info

        self.d0 = d0
        self.idx = 0

    def _set_scale_ratio(self, scale_ratio=1):
        """Fold ``scale_ratio`` into ``use_d_keys`` as per-key weights.

        "exp", "pitch", "yaw" and "roll" are weighted by ``scale_ratio``;
        every other key keeps its weight (or gets weight 1). ``use_d_keys``
        becomes a dict. A ratio of exactly 1 leaves it untouched.
        """
        if scale_ratio == 1:
            return
        scaled_keys = {"exp", "pitch", "yaw", "roll"}
        if isinstance(self.use_d_keys, dict):
            self.use_d_keys = {
                k: v * (scale_ratio if k in scaled_keys else 1)
                for k, v in self.use_d_keys.items()
            }
        else:
            self.use_d_keys = {
                k: scale_ratio if k in scaled_keys else 1
                for k in self.use_d_keys
            }

    @staticmethod
    def _merge_kwargs(default_kwargs, run_kwargs):
        """Fill ``run_kwargs`` with defaults it does not already contain.

        ``run_kwargs`` is updated in place and returned; ``default_kwargs``
        may be ``None``. Declared static because ``__call__`` invokes it as
        ``self._merge_kwargs(defaults, kwargs)`` with exactly two arguments.
        """
        if default_kwargs is None:
            return run_kwargs

        for k, v in default_kwargs.items():
            if k not in run_kwargs:
                run_kwargs[k] = v
        return run_kwargs

    def __call__(self, x_s_info, x_d_info, **kwargs):
        """Process one frame.

        Args:
            x_s_info: source motion dict for this frame.
            x_d_info: driving motion dict for this frame.
            **kwargs: per-frame controls (``vad_alpha``, ``fade_alpha``,
                ``fade_out_keys`` and the controls read by ``ctrl_motion``);
                missing entries are filled from ``overall_ctrl_info``.

        Returns:
            ``(x_s, x_d)``: the transformed source keypoints and the driving
            keypoints (passed through the stitching network when
            ``flag_stitching`` is set).
        """
        # return x_s, x_d
        kwargs = self._merge_kwargs(self.overall_ctrl_info, kwargs)

        if self.scale_ratio is None:
            self.scale_b = x_s_info['scale'].item()
            self.scale_ratio = self.scale_a / self.scale_b
            self._set_scale_ratio(self.scale_ratio)

        if self.relative_d and self.d0 is None:
            # First driving frame becomes the reference for relative motion.
            self.d0 = copy.deepcopy(x_d_info)

        x_d_info = _mix_s_d_info(
            x_s_info,
            x_d_info,
            self.use_d_keys,
            self.d0,
        )

        delta_eye = 0
        if self.drive_eye and self.delta_eye_arr is not None:
            delta_eye = self.delta_eye_arr[
                self.delta_eye_idx_list[self.idx % len(self.delta_eye_idx_list)]
            ][None]
        x_d_info = _fix_exp_for_x_d_info_v2(
            x_d_info,
            x_s_info,
            delta_eye,
            self.fix_exp_a1,
            self.fix_exp_a2,
            self.fix_exp_a3,
        )

        if kwargs.get("vad_alpha", 1) < 1:
            x_d_info = ctrl_vad(x_d_info, x_s_info, kwargs.get("vad_alpha", 1))

        x_d_info = ctrl_motion(x_d_info, **kwargs)

        if self.fade_type == "d0" and self.fade_dst is None:
            self.fade_dst = copy.deepcopy(x_d_info)

        # fade
        if "fade_alpha" in kwargs and self.fade_type in ["d0", "s"]:
            fade_alpha = kwargs["fade_alpha"]
            fade_keys = kwargs.get("fade_out_keys", self.fade_out_keys)
            if self.fade_type == "d0":
                fade_dst = self.fade_dst
            elif self.fade_type == "s":
                if self.fade_dst is not None:
                    fade_dst = self.fade_dst
                else:
                    fade_dst = copy.deepcopy(x_s_info)
                    if self.is_image_flag:
                        self.fade_dst = fade_dst
            x_d_info = fade(x_d_info, fade_dst, fade_alpha, fade_keys)

        if self.drive_eye:
            if self.pose_s is None:
                yaw_s = bin66_to_degree(x_s_info['yaw']).item()
                pitch_s = bin66_to_degree(x_s_info['pitch']).item()
                self.pose_s = [yaw_s, pitch_s]
            x_d_info = _fix_gaze(self.pose_s, x_d_info)

        if self.x_s is not None:
            x_s = self.x_s
        else:
            x_s = transform_keypoint(x_s_info)
            if self.is_image_flag:
                self.x_s = x_s

        x_d = transform_keypoint(x_d_info)

        if self.flag_stitching:
            x_d = self.stitch_net(x_s, x_d)

        self.idx += 1

        return x_s, x_d