Spaces:
Paused
Paused
| import numpy as np | |
| from scipy.special import softmax | |
| import copy | |
def _get_emo_avg(idx=6):
    """Build an 8-way emotion distribution peaked at the given label(s).

    A zero logit vector of length 8 gets the value 8 written at every
    selected position and is then passed through ``softmax``.

    Label order:
    'Angry', 'Disgust', 'Fear', 'Happy', 'Neutral', 'Sad', 'Surprise', 'Contempt'

    Args:
        idx: a single label index, or a list/tuple of label indices.

    Returns:
        Array of shape (8,) holding the softmax probabilities.
    """
    logits = np.zeros(8, dtype=np.float32)
    # Treat the single-label case as a one-element group so both inputs share one path.
    selected = idx if isinstance(idx, (list, tuple)) else (idx,)
    for label in selected:
        logits[label] = 8
    return softmax(logits)
def _mirror_index(index, size):
    """Map ``index`` onto a back-and-forth ("ping-pong") walk over ``range(size)``.

    Even-numbered passes run forward (0 .. size-1); odd-numbered passes run
    backward (size-1 .. 0), so the end frames are visited twice in a row.

    Args:
        index: position in the unbounded sequence.
        size: length of the underlying list.

    Returns:
        The mirrored position inside ``range(size)``.
    """
    passes, offset = divmod(index, size)
    return offset if passes % 2 == 0 else size - offset - 1
class ConditionHandler:
    """Assemble per-frame conditioning vectors for a clip of audio features.

    aud_feat, emo_seq, eye_seq, sc_seq -> cond_seq

    Every enabled condition is concatenated to the audio features along the
    last axis, in the order: emotion, eye-open, eye-ball, sc.
    """

    def __init__(
        self,
        use_emo=True,
        use_sc=True,
        use_eye_open=True,
        use_eye_ball=True,
        seq_frames=80,
    ):
        """Store which conditions are enabled and the default clip length.

        Args:
            use_emo: append the emotion condition.
            use_sc: append the ``sc`` condition.
            use_eye_open: append the eye-open condition.
            use_eye_ball: append the eye-ball condition.
            seq_frames: clip length for which constant sequences are
                pre-built in ``setup``.
        """
        self.use_emo = use_emo
        self.use_sc = use_sc
        self.use_eye_open = use_eye_open
        self.use_eye_ball = use_eye_ball
        self.seq_frames = seq_frames

    def setup(self, setup_info, emo, eye_f0_mode=False, ch_info=None):
        """Load source conditions and pre-build constant per-clip sequences.

        Args:
            setup_info: mapping with ``'x_s_info_lst'`` and, depending on the
                enabled conditions, ``'sc'``, ``'eye_open_lst'`` and
                ``'eye_ball_lst'``. It is deep-copied when ``ch_info`` is None.
            emo: int | [int] | [[int]] | numpy, see ``_parse_emo_seq``.
            eye_f0_mode: if True, always use frame 0 of the eye conditions.
            ch_info: optional replacement for ``setup_info``; used as-is
                (not copied) when given.
        """
        source_info = copy.deepcopy(setup_info) if ch_info is None else ch_info

        self.eye_f0_mode = eye_f0_mode
        self.x_s_info_0 = source_info['x_s_info_lst'][0]

        if self.use_sc:
            self.sc = source_info["sc"]  # 63
            self.sc_seq = np.stack([self.sc] * self.seq_frames, 0)

        if self.use_eye_open:
            self.eye_open_lst = np.concatenate(source_info["eye_open_lst"], 0)  # [n, 2]
            self.num_eye_open = len(self.eye_open_lst)
            # A constant sequence can be cached only when a single frame is used.
            if self.num_eye_open == 1 or self.eye_f0_mode:
                self.eye_open_seq = np.stack([self.eye_open_lst[0]] * self.seq_frames, 0)
            else:
                self.eye_open_seq = None

        if self.use_eye_ball:
            self.eye_ball_lst = np.concatenate(source_info["eye_ball_lst"], 0)  # [n, 6]
            self.num_eye_ball = len(self.eye_ball_lst)
            if self.num_eye_ball == 1 or self.eye_f0_mode:
                self.eye_ball_seq = np.stack([self.eye_ball_lst[0]] * self.seq_frames, 0)
            else:
                self.eye_ball_seq = None

        if self.use_emo:
            self.emo_lst = self._parse_emo_seq(emo)
            self.num_emo = len(self.emo_lst)
            if self.num_emo == 1:
                self.emo_seq = np.repeat(self.emo_lst, self.seq_frames, axis=0)
            else:
                self.emo_seq = None

    # Must be a staticmethod: the function takes no ``self`` but is invoked as
    # ``self._parse_emo_seq(...)``; without the decorator the instance would be
    # bound to ``emo`` and every call would fail.
    @staticmethod
    def _parse_emo_seq(emo, seq_len=-1):
        """Convert any supported emotion spec into an ``[m, 8]`` array.

        Args:
            emo: one of
                * ndarray of shape [m, 8] (used as-is),
                * int label in [0, 8),
                * list/tuple of 1..7 int labels (one blended frame),
                * list of label lists (one frame per inner list).
            seq_len: if > 0, the result is fitted to exactly this many rows:
                a single row is repeated, a longer sequence is truncated.

        Returns:
            ndarray of shape [m, 8], or [seq_len, 8] when ``seq_len > 0``.

        Raises:
            ValueError: if ``emo`` has an unsupported type/shape, or if it has
                more than one but fewer than ``seq_len`` rows.
        """
        if isinstance(emo, np.ndarray) and emo.ndim == 2 and emo.shape[1] == 8:
            # emo arr, e.g. real
            emo_seq = emo  # [m, 8]
        elif isinstance(emo, int) and 0 <= emo < 8:
            # emo label, e.g. 4
            emo_seq = _get_emo_avg(emo).reshape(1, 8)  # [1, 8]
        elif isinstance(emo, (list, tuple)) and 0 < len(emo) < 8 and isinstance(emo[0], int):
            # emo labels, e.g. [3,4]
            emo_seq = _get_emo_avg(emo).reshape(1, 8)  # [1, 8]
        elif isinstance(emo, list) and emo and isinstance(emo[0], (list, tuple)):
            # emo label list, e.g. [[4], [3,4], [3],[3,4,5], ...]
            emo_seq = np.stack([_get_emo_avg(i) for i in emo], 0)  # [m, 8]
        else:
            raise ValueError(f"Unsupported emo type: {emo}")

        if seq_len <= 0:
            return emo_seq

        cur_len = len(emo_seq)
        if cur_len == seq_len:
            return emo_seq
        if cur_len == 1:
            return np.repeat(emo_seq, seq_len, axis=0)
        if cur_len > seq_len:
            return emo_seq[:seq_len]
        raise ValueError(f"emo len {len(emo_seq)} can not match seq len ({seq_len})")

    def _pick_eye_frames(self, cached_seq, frame_lst, idx, frame_num):
        """Return ``frame_num`` rows of one eye condition for frames starting at ``idx``."""
        if cached_seq is not None and len(cached_seq) == frame_num:
            return cached_seq
        if self.eye_f0_mode:
            rows = [0] * frame_num
        else:
            size = len(frame_lst)
            # Negative (padding) indices are clamped to the first frame.
            rows = [_mirror_index(max(i, 0), size) for i in range(idx, idx + frame_num)]
        return frame_lst[rows]

    def __call__(self, aud_feat, idx, emo=None):
        """Build the condition sequence for one clip.

        Args:
            aud_feat: [n, 1024]
            idx: int, <0 means pad (first clip buffer)
            emo: optional per-call emotion spec overriding the one given to
                ``setup`` (same formats as ``_parse_emo_seq``).

        Returns:
            ndarray [n, dim_cond]: ``aud_feat`` followed by each enabled
            condition, or ``aud_feat`` itself when none is enabled.
        """
        frame_num = len(aud_feat)
        more_cond = [aud_feat]

        if self.use_emo:
            if emo is not None:
                emo_seq = self._parse_emo_seq(emo, frame_num)
            elif self.emo_seq is not None and len(self.emo_seq) == frame_num:
                emo_seq = self.emo_seq
            else:
                # Cycle through the stored emotion frames; padding maps to frame 0.
                emo_idx_list = [max(i, 0) % self.num_emo for i in range(idx, idx + frame_num)]
                emo_seq = self.emo_lst[emo_idx_list]
            more_cond.append(emo_seq)

        if self.use_eye_open:
            more_cond.append(
                self._pick_eye_frames(self.eye_open_seq, self.eye_open_lst, idx, frame_num)
            )

        if self.use_eye_ball:
            more_cond.append(
                self._pick_eye_frames(self.eye_ball_seq, self.eye_ball_lst, idx, frame_num)
            )

        if self.use_sc:
            if len(self.sc_seq) == frame_num:
                sc_seq = self.sc_seq
            else:
                sc_seq = np.stack([self.sc] * frame_num, 0)
            more_cond.append(sc_seq)

        if len(more_cond) > 1:
            cond_seq = np.concatenate(more_cond, -1)  # [n, dim_cond]
        else:
            cond_seq = aud_feat
        return cond_seq