Spaces:
Paused
Paused
| import numpy as np | |
| import cv2 | |
| from ..aux_models.insightface_det import InsightFaceDet | |
| from ..aux_models.insightface_landmark106 import Landmark106 | |
| from ..aux_models.landmark203 import Landmark203 | |
| from ..aux_models.mediapipe_landmark478 import Landmark478 | |
| from ..models.appearance_extractor import AppearanceExtractor | |
| from ..models.motion_extractor import MotionExtractor | |
| from ..utils.crop import crop_image | |
| from ..utils.eye_info import EyeAttrUtilsByMP | |
| """ | |
| insightface_det_cfg = { | |
| "model_path": "", | |
| "device": "cuda", | |
| "force_ori_type": False, | |
| } | |
| landmark106_cfg = { | |
| "model_path": "", | |
| "device": "cuda", | |
| "force_ori_type": False, | |
| } | |
| landmark203_cfg = { | |
| "model_path": "", | |
| "device": "cuda", | |
| "force_ori_type": False, | |
| } | |
| landmark478_cfg = { | |
| "blaze_face_model_path": "", | |
| "face_mesh_model_path": "", | |
| "device": "cuda", | |
| "force_ori_type": False, | |
| "task_path": "", | |
| } | |
| appearance_extractor_cfg = { | |
| "model_path": "", | |
| "device": "cuda", | |
| } | |
| motion_extractor_cfg = { | |
| "model_path": "", | |
| "device": "cuda", | |
| } | |
| """ | |
class Source2Info:
    """Turn one source image (or video frame) into the features used downstream.

    The pipeline run by ``__call__`` is:
    face crop -> eye attributes -> 256x256 BCHW tensor -> motion keypoint
    info and appearance features.
    """

    def __init__(
        self,
        insightface_det_cfg,
        landmark106_cfg,
        landmark203_cfg,
        landmark478_cfg,
        appearance_extractor_cfg,
        motion_extractor_cfg,
    ):
        """Instantiate every sub-model from its config dict.

        Each ``*_cfg`` argument is a dict that is expanded as keyword
        arguments into the constructor of the matching model class.
        """
        self.insightface_det = InsightFaceDet(**insightface_det_cfg)
        self.landmark106 = Landmark106(**landmark106_cfg)
        self.landmark203 = Landmark203(**landmark203_cfg)
        self.landmark478 = Landmark478(**landmark478_cfg)
        self.appearance_extractor = AppearanceExtractor(**appearance_extractor_cfg)
        self.motion_extractor = MotionExtractor(**motion_extractor_cfg)

    def _crop(self, img, last_lmk=None, **kwargs):
        """Locate the face in ``img`` and crop it to 512x512.

        Flow: detection -> 106-point landmarks -> 203-point landmarks -> crop.

        Args:
            img: input image (RGB per the ``__call__`` docstring).
            last_lmk: 203-point landmarks of the previous frame. When given,
                detection is skipped and these landmarks seed the crop
                (video tracking).
            **kwargs: optional crop settings ``crop_scale`` (2.3),
                ``crop_vx_ratio`` (0), ``crop_vy_ratio`` (-0.125) and
                ``crop_flag_do_rot`` (True).

        Returns:
            ``(img_crop, M_c2o, lmk203)``, or ``None`` when detection runs
            and finds no face.
        """
        if last_lmk is None:
            # First frame or still image: run the face detector.
            det, _ = self.insightface_det(img)
            # Bail out before indexing so an empty result never reaches the
            # column slicing below.
            if det is None or len(det) == 0:
                return None
            # Sort boxes by area, largest first; columns 0..3 are used as
            # x1, y1, x2, y2.
            boxes = det[np.argsort(-(det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1]))]
            lmk_for_track = self.landmark106(img, boxes[0])  # 106 points
        else:
            # Video frame: reuse the previous frame's landmarks.
            lmk_for_track = last_lmk  # 203 points

        # Coarse crop sized for the 203-point landmark model.
        crop_dct = crop_image(
            img,
            lmk_for_track,
            dsize=self.landmark203.dsize,
            scale=1.5,
            vy_ratio=-0.1,
            pt_crop_flag=False,
        )
        lmk203 = self.landmark203(crop_dct["img_crop"], crop_dct["M_c2o"])

        # Final crop driven by the refined landmarks and caller overrides.
        ret_dct = crop_image(
            img,
            lmk203,
            dsize=512,
            scale=kwargs.get("crop_scale", 2.3),
            vx_ratio=kwargs.get("crop_vx_ratio", 0),
            vy_ratio=kwargs.get("crop_vy_ratio", -0.125),
            flag_do_rot=kwargs.get("crop_flag_do_rot", True),
            pt_crop_flag=False,
        )
        return ret_dct["img_crop"], ret_dct["M_c2o"], lmk203

    @staticmethod
    def _img_crop_to_bchw256(img_crop):
        """Resize a crop to 256x256 and return it as float32 BCHW scaled by 1/255.

        Declared as a static method because it takes no ``self`` yet is
        invoked through the instance in ``__call__``.
        """
        rgb_256 = cv2.resize(img_crop, (256, 256), interpolation=cv2.INTER_AREA)
        # HWC -> 1xHxWxC -> 1xCxHxW
        rgb_256_bchw = (rgb_256.astype(np.float32) / 255.0)[None].transpose(0, 3, 1, 2)
        return rgb_256_bchw

    def _get_kp_info(self, img):
        """Run the motion extractor on ``img`` (the 256x256 BCHW tensor)."""
        return self.motion_extractor(img)

    def _get_f3d(self, img):
        """Run the appearance extractor on ``img`` (the 256x256 BCHW tensor)."""
        return self.appearance_extractor(img)

    def _get_eye_info(self, img):
        """Compute eye-open and eyeball-movement attributes from ``img``.

        Args:
            img: RGB uint8 image (per the original inline comment).

        Returns:
            ``[lr_open, lr_ball]`` reshaped to ``(-1, 2)`` and ``(-1, 6)``.
        """
        lmk478 = self.landmark478(img)  # [1, 478, 3] per original comment
        attr = EyeAttrUtilsByMP(lmk478)
        lr_open = attr.LR_open().reshape(-1, 2)  # [1, 2]
        lr_ball = attr.LR_ball_move().reshape(-1, 6)  # [1, 3, 2] -> [1, 6]
        return [lr_open, lr_ball]

    def __call__(self, img, last_lmk=None, **kwargs):
        """Extract all source information from one image or frame.

        Args:
            img: rgb, uint8
            last_lmk: last frame lmk203, for video tracking
            **kwargs: optional crop cfg
                crop_scale: 2.3
                crop_vx_ratio: 0
                crop_vy_ratio: -0.125
                crop_flag_do_rot: True

        Returns:
            dict with keys ``x_s_info``, ``f_s``, ``M_c2o``, ``eye_open``,
            ``eye_ball`` and ``lmk203``.

        Raises:
            ValueError: if no face is detected in ``img``.
        """
        crop_result = self._crop(img, last_lmk=last_lmk, **kwargs)
        if crop_result is None:
            # Fail with a clear message instead of an unpacking TypeError.
            raise ValueError("Source2Info: no face detected in the input image")
        img_crop, M_c2o, lmk203 = crop_result

        eye_open, eye_ball = self._get_eye_info(img_crop)
        rgb_256_bchw = self._img_crop_to_bchw256(img_crop)
        kp_info = self._get_kp_info(rgb_256_bchw)
        fs = self._get_f3d(rgb_256_bchw)

        source_info = {
            "x_s_info": kp_info,
            "f_s": fs,
            "M_c2o": M_c2o,
            "eye_open": eye_open,  # [1, 2]
            "eye_ball": eye_ball,  # [1, 6]
            "lmk203": lmk203,  # for track
        }
        return source_info