Spaces:
Paused
Paused
| import numpy as np | |
| from ..models.lmdm import LMDM | |
| """ | |
| lmdm_cfg = { | |
| "model_path": "", | |
| "device": "cuda", | |
| "motion_feat_dim": 265, | |
| "audio_feat_dim": 1024+35, | |
| "seq_frames": 80, | |
| } | |
| """ | |
| def _cvt_LP_motion_info(inp, mode, ignore_keys=()): | |
| ks_shape_map = [ | |
| ['scale', (1, 1), 1], | |
| ['pitch', (1, 66), 66], | |
| ['yaw', (1, 66), 66], | |
| ['roll', (1, 66), 66], | |
| ['t', (1, 3), 3], | |
| ['exp', (1, 63), 63], | |
| ['kp', (1, 63), 63], | |
| ] | |
| def _dic2arr(_dic): | |
| arr = [] | |
| for k, _, ds in ks_shape_map: | |
| if k not in _dic or k in ignore_keys: | |
| continue | |
| v = _dic[k].reshape(ds) | |
| if k == 'scale': | |
| v = v - 1 | |
| arr.append(v) | |
| arr = np.concatenate(arr, -1) # (133) | |
| return arr | |
| def _arr2dic(_arr): | |
| dic = {} | |
| s = 0 | |
| for k, ds, ss in ks_shape_map: | |
| if k in ignore_keys: | |
| continue | |
| v = _arr[s:s + ss].reshape(ds) | |
| if k == 'scale': | |
| v = v + 1 | |
| dic[k] = v | |
| s += ss | |
| if s >= len(_arr): | |
| break | |
| return dic | |
| if mode == 'dic2arr': | |
| assert isinstance(inp, dict) | |
| return _dic2arr(inp) # (dim) | |
| elif mode == 'arr2dic': | |
| assert inp.shape[0] >= 265, f"{inp.shape}" | |
| return _arr2dic(inp) # {k: (1, dim)} | |
| else: | |
| raise ValueError() | |
class Audio2Motion:
    """Run an LMDM model clip by clip and stitch the clips into one sequence.

    Each call predicts ``seq_frames`` frames conditioned on ``kp_cond`` and
    the audio features, cross-fades the overlap with the frames produced so
    far, smooths the seam and updates ``kp_cond`` for the next clip.
    Call ``setup`` before the first ``__call__``.
    """

    # Number of leading feature dims smoothed by `_smo`.
    # 202 = 1 + 3 * 66 + 3, which matches the scale/pitch/yaw/roll/t span of
    # the flat layout used by `_cvt_LP_motion_info`; presumably 'exp' is left
    # unsmoothed on purpose -- TODO confirm.
    _SMO_DIM = 202

    def __init__(
        self,
        lmdm_cfg,
    ):
        """Build the underlying model from ``lmdm_cfg`` (kwargs for LMDM)."""
        self.lmdm = LMDM(**lmdm_cfg)

    def setup(
        self,
        x_s_info,
        overlap_v2=10,
        fix_kp_cond=0,
        fix_kp_cond_dim=None,
        sampling_timesteps=50,
        online_mode=False,
        v_min_max_for_clip=None,
        smo_k_d=3,
    ):
        """Configure fusing/smoothing and reset the per-sequence state.

        Args:
            x_s_info: Source motion dict; converted to the flat form
                (without 'kp') and used as the initial condition.
            overlap_v2: Number of frames shared by consecutive clips.
            fix_kp_cond: 0 -> always condition on the last kept frame;
                k > 0 -> reset the condition to the source every k clips.
                Negative values leave the condition untouched.
            fix_kp_cond_dim: Optional ``(start, end)`` dim range that still
                follows the last kept frame when the condition is reset.
            sampling_timesteps: Forwarded to the model.
            online_mode: If true, the fuse length is capped at the number of
                new frames per clip.
            v_min_max_for_clip: Optional pair ``(v_min, v_max)`` of arrays
                used by ``cvt_fmt`` to clip the output.
            smo_k_d: Smoothing window size in frames; <= 1 disables smoothing.
        """
        self.smo_k_d = smo_k_d

        self.overlap_v2 = overlap_v2
        self.seq_frames = self.lmdm.seq_frames
        # Frames of each clip that are new, i.e. not part of the overlap.
        self.valid_clip_len = self.seq_frames - self.overlap_v2

        # for fuse
        self.online_mode = online_mode
        if self.online_mode:
            self.fuse_length = min(self.overlap_v2, self.valid_clip_len)
        else:
            self.fuse_length = self.overlap_v2
        # Linear ramp 0 .. (fuse_length - 1) / fuse_length, shape (1, fuse_length, 1).
        self.fuse_alpha = np.arange(self.fuse_length, dtype=np.float32).reshape(1, -1, 1) / self.fuse_length

        self.fix_kp_cond = fix_kp_cond
        self.fix_kp_cond_dim = fix_kp_cond_dim
        self.sampling_timesteps = sampling_timesteps

        self.v_min_max_for_clip = v_min_max_for_clip
        if self.v_min_max_for_clip is not None:
            # `[None]` adds a leading axis so the bounds broadcast over
            # frames; presumably each bound is (dim,) -> (1, dim).
            self.v_min = self.v_min_max_for_clip[0][None]
            self.v_max = self.v_min_max_for_clip[1][None]

        kp_source = _cvt_LP_motion_info(x_s_info, mode='dic2arr', ignore_keys={'kp'})[None]
        self.s_kp_cond = kp_source.copy().reshape(1, -1)
        self.kp_cond = self.s_kp_cond.copy()

        self.lmdm.setup(sampling_timesteps)

        self.clip_idx = 0

    def _fuse(self, res_kp_seq, pred_kp_seq):
        """Cross-fade the tail of ``res_kp_seq`` into ``pred_kp_seq`` and append.

        The last ``fuse_length`` frames of ``res_kp_seq`` are blended (in
        place) with the ``fuse_length`` frames of ``pred_kp_seq`` that end
        where its new frames begin. In offline mode that is the whole
        overlap; in online mode it may be only the end of the overlap.
        The new frames of ``pred_kp_seq`` are then appended.

        Returns:
            A new array that is ``valid_clip_len`` frames longer than
            ``res_kp_seq`` (when ``pred_kp_seq`` has ``seq_frames`` frames).
        """
        fuse_r1_s = res_kp_seq.shape[1] - self.fuse_length
        fuse_r1_e = res_kp_seq.shape[1]
        fuse_r2_s = self.seq_frames - self.valid_clip_len - self.fuse_length
        fuse_r2_e = self.seq_frames - self.valid_clip_len

        r1 = res_kp_seq[:, fuse_r1_s:fuse_r1_e]     # [1, fuse_len, dim]
        r2 = pred_kp_seq[:, fuse_r2_s:fuse_r2_e]    # [1, fuse_len, dim]
        # Weight moves from the previous clip towards the new one.
        r_fuse = r1 * (1 - self.fuse_alpha) + r2 * self.fuse_alpha

        res_kp_seq[:, fuse_r1_s:fuse_r1_e] = r_fuse    # fuse last
        res_kp_seq = np.concatenate([res_kp_seq, pred_kp_seq[:, fuse_r2_e:]], 1)  # len(res_kp_seq) + valid_clip_len

        return res_kp_seq

    def _update_kp_cond(self, res_kp_seq, idx):
        """Pick the conditioning frame for the next clip.

        ``idx - 1`` is the frame just before the region the next clip
        overlaps. The stored condition is a copy, so later in-place edits of
        ``res_kp_seq`` (fusing, smoothing) cannot change it and it does not
        keep the whole buffer alive.
        """
        if self.fix_kp_cond == 0:  # never reset
            self.kp_cond = res_kp_seq[:, idx - 1].copy()
        elif self.fix_kp_cond > 0:
            if self.clip_idx % self.fix_kp_cond == 0:  # reset
                self.kp_cond = self.s_kp_cond.copy()  # reset every dim
                if self.fix_kp_cond_dim is not None:
                    # ...except this range, which keeps following the output.
                    ds, de = self.fix_kp_cond_dim
                    self.kp_cond[:, ds:de] = res_kp_seq[:, idx - 1, ds:de]
            else:
                self.kp_cond = res_kp_seq[:, idx - 1].copy()

    def _smo(self, res_kp_seq, s, e):
        """Box-filter frames ``[s, e)`` of ``res_kp_seq`` in place.

        Each frame's first ``_SMO_DIM`` dims are replaced by the mean over a
        window of ``smo_k_d`` frames centred on it (truncated at the sequence
        ends). Means are computed from the unsmoothed values. ``s`` and ``e``
        are clamped to the sequence so out-of-range bounds neither wrap
        around nor raise.

        Returns:
            ``res_kp_seq`` itself.
        """
        if self.smo_k_d <= 1:
            return res_kp_seq

        n = res_kp_seq.shape[1]
        half_k = self.smo_k_d // 2
        s = max(0, s)
        e = min(n, e)
        if s >= e:
            return res_kp_seq

        dim = self._SMO_DIM
        # Snapshot only the frames the windows read, not the whole sequence,
        # so the cost per clip does not grow with the accumulated length.
        lo = max(0, s - half_k)
        hi = min(n, e + half_k)
        src = res_kp_seq[:, lo:hi, :dim].copy()
        for i in range(s, e):
            ss = max(0, i - half_k) - lo
            ee = min(n, i + half_k + 1) - lo
            res_kp_seq[:, i, :dim] = np.mean(src[:, ss:ee], axis=1)
        return res_kp_seq

    def __call__(self, aud_cond, res_kp_seq=None):
        """Predict the next clip and merge it into the running sequence.

        Args:
            aud_cond: Audio features, (1, seq_frames, dim).
            res_kp_seq: Sequence produced so far, or ``None`` for the first
                clip. It is modified in place while fusing.

        Returns:
            The extended sequence; for the first clip, the smoothed
            prediction itself.
        """
        pred_kp_seq = self.lmdm(self.kp_cond, aud_cond, self.sampling_timesteps)
        if res_kp_seq is None:
            res_kp_seq = pred_kp_seq   # [1, seq_frames, dim]
            res_kp_seq = self._smo(res_kp_seq, 0, res_kp_seq.shape[1])
        else:
            res_kp_seq = self._fuse(res_kp_seq, pred_kp_seq)  # len(res_kp_seq) + valid_clip_len
            # Smooth the blended frames plus the first newly appended frame.
            new_start = res_kp_seq.shape[1] - self.valid_clip_len
            res_kp_seq = self._smo(res_kp_seq, new_start - self.fuse_length, new_start + 1)

        self.clip_idx += 1

        idx = res_kp_seq.shape[1] - self.overlap_v2
        self._update_kp_cond(res_kp_seq, idx)

        return res_kp_seq

    def cvt_fmt(self, res_kp_seq):
        """Convert a ``[1, n, dim]`` sequence into a list of per-frame dicts.

        Values are clipped to ``[v_min, v_max]`` first when
        ``v_min_max_for_clip`` was given to ``setup``.
        """
        if self.v_min_max_for_clip is not None:
            tmp_res_kp_seq = np.clip(res_kp_seq[0], self.v_min, self.v_max)
        else:
            tmp_res_kp_seq = res_kp_seq[0]

        # One {k: (1, dim)} dict per frame.
        return [_cvt_LP_motion_info(frame, 'arr2dic') for frame in tmp_res_kp_seq]