import threading
import queue
import numpy as np
import traceback
from tqdm import tqdm

from core.atomic_components.avatar_registrar import AvatarRegistrar, smooth_x_s_info_lst
from core.atomic_components.condition_handler import ConditionHandler, _mirror_index
from core.atomic_components.audio2motion import Audio2Motion
from core.atomic_components.motion_stitch import MotionStitch
from core.atomic_components.warp_f3d import WarpF3D
from core.atomic_components.decode_f3d import DecodeF3D
from core.atomic_components.putback import PutBack
from core.atomic_components.writer import VideoWriterByImageIO
from core.atomic_components.wav2feat import Wav2Feat
from core.atomic_components.cfg import parse_cfg, print_cfg


"""
avatar_registrar_cfg:
    insightface_det_cfg,
    landmark106_cfg,
    landmark203_cfg,
    landmark478_cfg,
    appearance_extractor_cfg,
    motion_extractor_cfg,
condition_handler_cfg:
    use_emo=True,
    use_sc=True,
    use_eye_open=True,
    use_eye_ball=True,
    seq_frames=80,
wav2feat_cfg:
    w2f_cfg,
    w2f_type
"""


class StreamSDK:
    def __init__(self, cfg_pkl, data_root, **kwargs):
        [
            avatar_registrar_cfg,
            condition_handler_cfg,
            lmdm_cfg,
            stitch_network_cfg,
            warp_network_cfg,
            decoder_cfg,
            wav2feat_cfg,
            default_kwargs,
        ] = parse_cfg(cfg_pkl, data_root, kwargs)

        self.default_kwargs = default_kwargs

        self.avatar_registrar = AvatarRegistrar(**avatar_registrar_cfg)
        self.condition_handler = ConditionHandler(**condition_handler_cfg)
        self.audio2motion = Audio2Motion(lmdm_cfg)
        self.motion_stitch = MotionStitch(stitch_network_cfg)
        self.warp_f3d = WarpF3D(warp_network_cfg)
        self.decode_f3d = DecodeF3D(decoder_cfg)
        self.putback = PutBack()
        self.wav2feat = Wav2Feat(**wav2feat_cfg)

    def _merge_kwargs(self, default_kwargs, run_kwargs):
        for k, v in default_kwargs.items():
            if k not in run_kwargs:
                run_kwargs[k] = v
        return run_kwargs

    def setup_Nd(self, N_d, fade_in=-1, fade_out=-1, ctrl_info=None):
        # for eye open at video end
        self.motion_stitch.set_Nd(N_d)

        # for fade in/out alpha
        if ctrl_info is None:
            ctrl_info = self.ctrl_info
        if fade_in > 0:
            for i in range(fade_in):
                alpha = i / fade_in
                item = ctrl_info.get(i, {})
                item["fade_alpha"] = alpha
                ctrl_info[i] = item
        if fade_out > 0:
            ss = N_d - fade_out - 1
            ee = N_d - 1
            for i in range(ss, N_d):
                alpha = max((ee - i) / (ee - ss), 0)
                item = ctrl_info.get(i, {})
                item["fade_alpha"] = alpha
                ctrl_info[i] = item
        self.ctrl_info = ctrl_info

    def setup(self, source_path, output_path, **kwargs):

        # ======== Prepare Options ========
        kwargs = self._merge_kwargs(self.default_kwargs, kwargs)
        print("=" * 20, "setup kwargs", "=" * 20)
        print_cfg(**kwargs)
        print("=" * 50)

        # -- avatar_registrar: template cfg --
        self.max_size = kwargs.get("max_size", 1920)
        self.template_n_frames = kwargs.get("template_n_frames", -1)

        # -- avatar_registrar: crop cfg --
        self.crop_scale = kwargs.get("crop_scale", 2.3)
        self.crop_vx_ratio = kwargs.get("crop_vx_ratio", 0)
        self.crop_vy_ratio = kwargs.get("crop_vy_ratio", -0.125)
        self.crop_flag_do_rot = kwargs.get("crop_flag_do_rot", True)

        # -- avatar_registrar: smoothing for video templates --
        self.smo_k_s = kwargs.get("smo_k_s", 13)

        # -- condition_handler: ECS --
        self.emo = kwargs.get("emo", 4)  # int | [int] | [[int]] | numpy
        self.eye_f0_mode = kwargs.get("eye_f0_mode", False)  # for video
        self.ch_info = kwargs.get("ch_info", None)  # dict of np.ndarray

        # -- audio2motion: setup --
        self.overlap_v2 = kwargs.get("overlap_v2", 10)
        self.fix_kp_cond = kwargs.get("fix_kp_cond", 0)
        self.fix_kp_cond_dim = kwargs.get("fix_kp_cond_dim", None)  # [ds, de]
        self.sampling_timesteps = kwargs.get("sampling_timesteps", 50)
        self.online_mode = kwargs.get("online_mode", False)
        self.v_min_max_for_clip = kwargs.get("v_min_max_for_clip", None)
        self.smo_k_d = kwargs.get("smo_k_d", 3)

        # -- motion_stitch: setup --
        self.N_d = kwargs.get("N_d", -1)
        self.use_d_keys = kwargs.get("use_d_keys", None)
        self.relative_d = kwargs.get("relative_d", True)
        self.drive_eye = kwargs.get("drive_eye", None)  # None: True for image sources, False for video
        self.delta_eye_arr = kwargs.get("delta_eye_arr", None)
        self.delta_eye_open_n = kwargs.get("delta_eye_open_n", 0)
        self.fade_type = kwargs.get("fade_type", "")  # "" | "d0" | "s"
        self.fade_out_keys = kwargs.get("fade_out_keys", ("exp",))
        self.flag_stitching = kwargs.get("flag_stitching", True)

        self.ctrl_info = kwargs.get("ctrl_info", dict())
        self.overall_ctrl_info = kwargs.get("overall_ctrl_info", dict())
        """
        ctrl_info: list or dict
            {
                fid: ctrl_kwargs
            }
        ctrl_kwargs (see motion_stitch.py):
            fade_alpha
            fade_out_keys
            delta_pitch
            delta_yaw
            delta_roll
        """
        # only the hubert wav2feat supports online (streaming) mode
        assert self.wav2feat.support_streaming or not self.online_mode

        # ======== Register Avatar ========
        crop_kwargs = {
            "crop_scale": self.crop_scale,
            "crop_vx_ratio": self.crop_vx_ratio,
            "crop_vy_ratio": self.crop_vy_ratio,
            "crop_flag_do_rot": self.crop_flag_do_rot,
        }
        n_frames = self.template_n_frames if self.template_n_frames > 0 else self.N_d
        source_info = self.avatar_registrar(
            source_path,
            max_dim=self.max_size,
            n_frames=n_frames,
            **crop_kwargs,
        )
        if len(source_info["x_s_info_lst"]) > 1 and self.smo_k_s > 1:
            source_info["x_s_info_lst"] = smooth_x_s_info_lst(source_info["x_s_info_lst"], smo_k=self.smo_k_s)
        self.source_info = source_info
        self.source_info_frames = len(source_info["x_s_info_lst"])

        # ======== Setup Condition Handler ========
        self.condition_handler.setup(source_info, self.emo, eye_f0_mode=self.eye_f0_mode, ch_info=self.ch_info)

        # ======== Setup Audio2Motion (LMDM) ========
        x_s_info_0 = self.condition_handler.x_s_info_0
        self.audio2motion.setup(
            x_s_info_0,
            overlap_v2=self.overlap_v2,
            fix_kp_cond=self.fix_kp_cond,
            fix_kp_cond_dim=self.fix_kp_cond_dim,
            sampling_timesteps=self.sampling_timesteps,
            online_mode=self.online_mode,
            v_min_max_for_clip=self.v_min_max_for_clip,
            smo_k_d=self.smo_k_d,
        )

        # ======== Setup Motion Stitch ========
        is_image_flag = source_info["is_image_flag"]
        x_s_info = source_info["x_s_info_lst"][0]
        self.motion_stitch.setup(
            N_d=self.N_d,
            use_d_keys=self.use_d_keys,
            relative_d=self.relative_d,
            drive_eye=self.drive_eye,
            delta_eye_arr=self.delta_eye_arr,
            delta_eye_open_n=self.delta_eye_open_n,
            fade_out_keys=self.fade_out_keys,
            fade_type=self.fade_type,
            flag_stitching=self.flag_stitching,
            is_image_flag=is_image_flag,
            x_s_info=x_s_info,
            d0=None,
            ch_info=self.ch_info,
            overall_ctrl_info=self.overall_ctrl_info,
        )

        # ======== Video Writer ========
        self.output_path = output_path
        self.tmp_output_path = output_path + ".tmp.mp4"
        self.writer = VideoWriterByImageIO(self.tmp_output_path)
        self.writer_pbar = tqdm(desc="writer")

        # ======== Audio Feat Buffer ========
        if self.online_mode:
            # buffer: seq_frames - valid_clip_len
            self.audio_feat = self.wav2feat.wav2feat(np.zeros((self.overlap_v2 * 640,), dtype=np.float32), sr=16000)
            assert len(self.audio_feat) == self.overlap_v2, f"{len(self.audio_feat)}"
        else:
            self.audio_feat = np.zeros((0, self.wav2feat.feat_dim), dtype=np.float32)
        self.cond_idx_start = 0 - len(self.audio_feat)

        # ======== Setup Worker Threads ========
        QUEUE_MAX_SIZE = 100
        # self.QUEUE_TIMEOUT = None
        self.worker_exception = None
        self.stop_event = threading.Event()
        self.audio2motion_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
        self.motion_stitch_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
        self.warp_f3d_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
        self.decode_f3d_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
        self.putback_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
        self.writer_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
        self.thread_list = [
            threading.Thread(target=self.audio2motion_worker),
            threading.Thread(target=self.motion_stitch_worker),
            threading.Thread(target=self.warp_f3d_worker),
            threading.Thread(target=self.decode_f3d_worker),
            threading.Thread(target=self.putback_worker),
            threading.Thread(target=self.writer_worker),
        ]
        for thread in self.thread_list:
            thread.start()
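        # Stage order (one worker thread per stage, linked by the queues above):
        # audio feat -> audio2motion -> motion_stitch -> warp_f3d -> decode_f3d -> putback -> writer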

    def _get_ctrl_info(self, fid):
        try:
            if isinstance(self.ctrl_info, dict):
                return self.ctrl_info.get(fid, {})
            elif isinstance(self.ctrl_info, list):
                return self.ctrl_info[fid]
            else:
                return {}
        except Exception:
            traceback.print_exc()
            return {}

    def writer_worker(self):
        try:
            self._writer_worker()
        except Exception as e:
            self.worker_exception = e
            self.stop_event.set()

    def _writer_worker(self):
        while not self.stop_event.is_set():
            try:
                item = self.writer_queue.get(timeout=1)
            except queue.Empty:
                continue
            if item is None:
                break
            res_frame_rgb = item
            self.writer(res_frame_rgb, fmt="rgb")
            self.writer_pbar.update()

    def putback_worker(self):
        try:
            self._putback_worker()
        except Exception as e:
            self.worker_exception = e
            self.stop_event.set()

    def _putback_worker(self):
        while not self.stop_event.is_set():
            try:
                item = self.putback_queue.get(timeout=1)
            except queue.Empty:
                continue
            if item is None:
                self.writer_queue.put(None)
                break
            frame_idx, render_img = item
            frame_rgb = self.source_info["img_rgb_lst"][frame_idx]
            M_c2o = self.source_info["M_c2o_lst"][frame_idx]
            res_frame_rgb = self.putback(frame_rgb, render_img, M_c2o)
            self.writer_queue.put(res_frame_rgb)

    def decode_f3d_worker(self):
        try:
            self._decode_f3d_worker()
        except Exception as e:
            self.worker_exception = e
            self.stop_event.set()

    def _decode_f3d_worker(self):
        while not self.stop_event.is_set():
            try:
                item = self.decode_f3d_queue.get(timeout=1)
            except queue.Empty:
                continue
            if item is None:
                self.putback_queue.put(None)
                break
            frame_idx, f_3d = item
            render_img = self.decode_f3d(f_3d)
            self.putback_queue.put([frame_idx, render_img])

    def warp_f3d_worker(self):
        try:
            self._warp_f3d_worker()
        except Exception as e:
            self.worker_exception = e
            self.stop_event.set()

    def _warp_f3d_worker(self):
        while not self.stop_event.is_set():
            try:
                item = self.warp_f3d_queue.get(timeout=1)
            except queue.Empty:
                continue
            if item is None:
                self.decode_f3d_queue.put(None)
                break
            frame_idx, x_s, x_d = item
            f_s = self.source_info["f_s_lst"][frame_idx]
            f_3d = self.warp_f3d(f_s, x_s, x_d)
            self.decode_f3d_queue.put([frame_idx, f_3d])

    def motion_stitch_worker(self):
        try:
            self._motion_stitch_worker()
        except Exception as e:
            self.worker_exception = e
            self.stop_event.set()

    def _motion_stitch_worker(self):
        while not self.stop_event.is_set():
            try:
                item = self.motion_stitch_queue.get(timeout=1)
            except queue.Empty:
                continue
            if item is None:
                self.warp_f3d_queue.put(None)
                break
            frame_idx, x_d_info, ctrl_kwargs = item
            x_s_info = self.source_info["x_s_info_lst"][frame_idx]
            x_s, x_d = self.motion_stitch(x_s_info, x_d_info, **ctrl_kwargs)
            self.warp_f3d_queue.put([frame_idx, x_s, x_d])

    def audio2motion_worker(self):
        try:
            self._audio2motion_worker()
        except Exception as e:
            self.worker_exception = e
            self.stop_event.set()

    def _audio2motion_worker(self):
        is_end = False
        seq_frames = self.audio2motion.seq_frames
        valid_clip_len = self.audio2motion.valid_clip_len
        aud_feat_dim = self.wav2feat.feat_dim
        item_buffer = np.zeros((0, aud_feat_dim), dtype=np.float32)
        res_kp_seq = None
        res_kp_seq_valid_start = None if self.online_mode else 0
        global_idx = 0  # frame idx, for template
        local_idx = 0  # index into the current audio_feat buffer
        gen_frame_idx = 0
        while not self.stop_event.is_set():
            try:
                item = self.audio2motion_queue.get(timeout=1)  # audio feat
            except queue.Empty:
                continue

            if item is None:
                is_end = True
            else:
                item_buffer = np.concatenate([item_buffer, item], 0)

            if not is_end and item_buffer.shape[0] < valid_clip_len:
                # wait until at least valid_clip_len new frames have accumulated
                continue
            else:
                self.audio_feat = np.concatenate([self.audio_feat, item_buffer], 0)
                item_buffer = np.zeros((0, aud_feat_dim), dtype=np.float32)

            while True:
                # print("self.audio_feat.shape:", self.audio_feat.shape, "local_idx:", local_idx, "global_idx:", global_idx)
                aud_feat = self.audio_feat[local_idx: local_idx + seq_frames]
                real_valid_len = valid_clip_len
                if len(aud_feat) == 0:
                    break
                elif len(aud_feat) < seq_frames:
                    if not is_end:
                        # wait for the next chunk
                        break
                    else:
                        # final clip: pad to seq_frames
                        real_valid_len = len(aud_feat)
                        pad = np.stack([aud_feat[-1]] * (seq_frames - len(aud_feat)), 0)
                        aud_feat = np.concatenate([aud_feat, pad], 0)

                aud_cond = self.condition_handler(aud_feat, global_idx + self.cond_idx_start)[None]
                res_kp_seq = self.audio2motion(aud_cond, res_kp_seq)

                if res_kp_seq_valid_start is None:
                    # online mode, first chunk
                    res_kp_seq_valid_start = res_kp_seq.shape[1] - self.audio2motion.fuse_length
                    d0 = self.audio2motion.cvt_fmt(res_kp_seq[0:1])[0]
                    self.motion_stitch.d0 = d0

                    local_idx += real_valid_len
                    global_idx += real_valid_len
                    continue
                else:
                    valid_res_kp_seq = res_kp_seq[:, res_kp_seq_valid_start: res_kp_seq_valid_start + real_valid_len]
                    x_d_info_list = self.audio2motion.cvt_fmt(valid_res_kp_seq)
                    for x_d_info in x_d_info_list:
                        frame_idx = _mirror_index(gen_frame_idx, self.source_info_frames)
                        ctrl_kwargs = self._get_ctrl_info(gen_frame_idx)
                        while not self.stop_event.is_set():
                            try:
                                self.motion_stitch_queue.put([frame_idx, x_d_info, ctrl_kwargs], timeout=1)
                                break
                            except queue.Full:
                                continue
                        gen_frame_idx += 1

                    res_kp_seq_valid_start += real_valid_len
                    local_idx += real_valid_len
                    global_idx += real_valid_len

                L = res_kp_seq.shape[1]
                if L > seq_frames * 2:
                    cut_L = L - seq_frames * 2
                    res_kp_seq = res_kp_seq[:, cut_L:]
                    res_kp_seq_valid_start -= cut_L

                if local_idx >= len(self.audio_feat):
                    break

            L = len(self.audio_feat)
            if L > seq_frames * 2:
                cut_L = L - seq_frames * 2
                self.audio_feat = self.audio_feat[cut_L:]
                local_idx -= cut_L

            if is_end:
                break

        self.motion_stitch_queue.put(None)

    def close(self):
        # signal end of stream and flush the remaining frames
        self.audio2motion_queue.put(None)

        # Wait for worker threads to finish
        for thread in self.thread_list:
            thread.join()

        try:
            self.writer.close()
            self.writer_pbar.close()
        except Exception:
            traceback.print_exc()

        # Check if any worker encountered an exception
        if self.worker_exception is not None:
            raise self.worker_exception

    def run_chunk(self, audio_chunk, chunksize=(3, 5, 2)):
        # only for hubert
        aud_feat = self.wav2feat(audio_chunk, chunksize=chunksize)
        while not self.stop_event.is_set():
            try:
                self.audio2motion_queue.put(aud_feat, timeout=1)
                break
            except queue.Full:
                continue
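

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the SDK): drives the streaming
# pipeline end to end. The file paths, fade lengths, librosa dependency and
# online_mode=True (which requires a hubert, i.e. streaming-capable, wav2feat
# cfg) are assumptions. The chunk length/hop below follow from the
# 640-samples-per-feature-frame warm-up buffer in setup() (40 ms at 16 kHz,
# i.e. 25 fps); the project's own inference script is the reference for the
# exact slicing and padding.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import math
    import librosa

    sdk = StreamSDK(cfg_pkl="checkpoints/cfg.pkl", data_root="checkpoints/")
    sdk.setup("assets/source.png", "output.mp4", online_mode=True)

    audio, _ = librosa.load("assets/speech.wav", sr=16000)  # float32 mono, 16 kHz
    num_frames = math.ceil(len(audio) / 16000 * 25)         # motion frames at 25 fps
    sdk.setup_Nd(N_d=num_frames, fade_in=5, fade_out=5)

    chunksize = (3, 5, 2)                   # (history, current, future) feature frames
    samples_per_frame = 640                 # 40 ms per feature frame at 16 kHz
    chunk_len = sum(chunksize) * samples_per_frame
    hop = chunksize[1] * samples_per_frame  # advance by the "current" window only
    for i in range(0, len(audio), hop):
        chunk = audio[i:i + chunk_len]
        if len(chunk) < chunk_len:
            chunk = np.pad(chunk, (0, chunk_len - len(chunk)))
        sdk.run_chunk(chunk, chunksize=chunksize)

    sdk.close()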