from pathlib import Path
from typing import Dict, List, Union

import gradio as gr
import numpy as np
import torch

from lib.kits.hsmr_demo import *  # expected to provide the demo helpers used below (build_detector, load_img, DEFAULT_HSMR_ROOT, ...)
from lib.modeling.pipelines import HSMRPipeline


class HSMRBackend:
    '''
    Backend class that maintains the HSMR model for inference.
    Some Gradio features (progress messages via gr.Info) are included in this class.
    '''

    def __init__(self, device:str='cpu') -> None:
        self.max_img_w = 1920
        self.max_img_h = 1080
        self.device = device
        self.pipeline = self._build_pipeline(self.device)
        self.detector = build_detector(
            batch_size   = 1,
            max_img_size = 512,
            device       = self.device,
        )

    def _build_pipeline(self, device) -> HSMRPipeline:
        return build_inference_pipeline(
            model_root = DEFAULT_HSMR_ROOT,
            device     = device,
        )

    def _load_limited_img(self, fn) -> List:
        img, _ = load_img(fn)
        if img.shape[0] > self.max_img_h:  # cap the height first; -1 lets flex_resize_img keep the aspect ratio
            img = flex_resize_img(img, (self.max_img_h, -1), kp_mod=4)
        if img.shape[1] > self.max_img_w:  # then cap the width if it still exceeds the bound
            img = flex_resize_img(img, (-1, self.max_img_w), kp_mod=4)
        return [img]
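
    # Worked example (illustrative numbers): a 4000x3000 input trips the height
    # check and is resized to (1080, -1), i.e. 1440x1080, which already fits
    # max_img_w; kp_mod=4 presumably rounds both sides to multiples of 4.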

    def __call__(self, input_path:Union[str, Path], args:Dict):
        # 1. Initialization.
        input_type = 'img'
        if isinstance(input_path, str): input_path = Path(input_path)
        outputs_root = input_path.parent / 'outputs'
        outputs_root.mkdir(parents=True, exist_ok=True)

        # 2. Preprocess.
        gr.Info('[1/3] Pre-processing...')
        raw_imgs = self._load_limited_img(input_path)
        detector_outputs = self.detector(raw_imgs)
        patches, det_meta = imgs_det2patches(raw_imgs, *detector_outputs, args['max_instances'])  # N * (256, 256, 3)

        # 3. Inference.
        gr.Info('[2/3] HSMR inferencing...')
        pd_params, pd_cam_t = [], []
        for bw in bsb(total=len(patches), batch_size=args['rec_bs'], enable_tqdm=True):
            patches_i = np.concatenate(patches[bw.sid:bw.eid], axis=0)  # (N, 256, 256, 3)
            patches_normalized_i = (patches_i - IMG_MEAN_255) / IMG_STD_255  # (N, 256, 256, 3)
            patches_normalized_i = patches_normalized_i.transpose(0, 3, 1, 2)  # (N, 3, 256, 256)
            with torch.no_grad():
                outputs = self.pipeline(patches_normalized_i)
            pd_params.append({k: v.detach().cpu().clone() for k, v in outputs['pd_params'].items()})
            pd_cam_t.append(outputs['pd_cam_t'].detach().cpu().clone())
        pd_params = assemble_dict(pd_params, expand_dim=False)  # [{k: [x]}, {k: [y]}] -> {k: [x, y]}
        pd_cam_t = torch.cat(pd_cam_t, dim=0)
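        # Each batch was detached and moved to CPU inside the loop, so peak GPU
        # memory stays bounded by args['rec_bs'] rather than by the number of
        # detected instances.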

        # 4. Render.
        gr.Info('[3/3] Rendering results...')
        m_skin, m_skel = prepare_mesh(self.pipeline, pd_params)
        results = visualize_img_results(pd_cam_t, raw_imgs, det_meta, m_skin, m_skel)
        outputs = {}
        if input_type == 'img':
            for k, v in results.items():
                img_path = str(outputs_root / f'{k}.jpg')
                save_img(v, img_path)
                outputs[k] = img_path
        return outputs
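
A minimal driver for reference. The args keys ('max_instances', 'rec_bs') are the ones read by __call__ above; the device choice, input path, and concrete values below are illustrative assumptions, not part of the Space.

if __name__ == '__main__':
    backend = HSMRBackend(device='cuda' if torch.cuda.is_available() else 'cpu')
    outputs = backend(
        'examples/person.jpg',   # hypothetical input image
        args={
            'max_instances': 5,  # assumed cap on detected people per image
            'rec_bs': 16,        # assumed reconstruction batch size
        },
    )
    for name, path in outputs.items():
        print(f'{name} -> {path}')  # each rendered result is saved under <input dir>/outputs/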