# Copyright (c) OpenMMLab. All rights reserved.
import logging
import mimetypes
import os
import time
from argparse import ArgumentParser

import cv2
import json_tricks as json
import mmcv
import mmengine
import numpy as np
from mmengine.logging import print_log

from mmpose.apis import inference_topdown, init_model
from mmpose.registry import VISUALIZERS
from mmpose.structures import (PoseDataSample, merge_data_samples,
                               split_instances)


def parse_args():
    parser = ArgumentParser()
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--input', type=str, default='', help='Image/Video file')
    parser.add_argument(
        '--output-root',
        type=str,
        default='',
        help='Root of the output image files. '
        'By default, the visualization images are not saved.')
    parser.add_argument(
        '--save-predictions',
        action='store_true',
        default=False,
        help='whether to save predicted results')
    parser.add_argument(
        '--disable-rebase-keypoint',
        action='store_true',
        default=False,
        help='Whether to disable rebasing the predicted 3D pose so its '
        'lowest keypoint has a height of 0 (landing on the ground). Rebasing '
        'is useful for visualization when the model does not predict the '
        'global position of the 3D pose.')
    parser.add_argument(
        '--show',
        action='store_true',
        default=False,
        help='whether to show result')
    parser.add_argument('--device', default='cpu', help='Device for inference')
    parser.add_argument(
        '--kpt-thr',
        type=float,
        default=0.3,
        help='Keypoint score threshold for visualization')
    parser.add_argument(
        '--show-kpt-idx',
        action='store_true',
        default=False,
        help='Whether to show the index of keypoints')
    parser.add_argument(
        '--show-interval',
        type=int,
        default=0,
        help='Sleep seconds per frame')
    parser.add_argument(
        '--radius',
        type=int,
        default=3,
        help='Keypoint radius for visualization')
    parser.add_argument(
        '--thickness',
        type=int,
        default=1,
        help='Link thickness for visualization')

    args = parser.parse_args()
    return args


def process_one_image(args, img, model, visualizer=None, show_interval=0):
    """Visualize predicted keypoints of one image."""
    # inference a single image
    pose_results = inference_topdown(model, img)

    # post-processing
    pose_results_2d = []
    for idx, res in enumerate(pose_results):
        pred_instances = res.pred_instances
        keypoints = pred_instances.keypoints
        rel_root_depth = pred_instances.rel_root_depth
        scores = pred_instances.keypoint_scores
        hand_type = pred_instances.hand_type

        # keep a copy of the predictions for the 2D overlay
        res_2d = PoseDataSample()
        gt_instances = res.gt_instances.clone()
        pred_instances = pred_instances.clone()
        res_2d.gt_instances = gt_instances
        res_2d.pred_instances = pred_instances

        # add relative root depth to left hand joints
        keypoints[:, 21:, 2] += rel_root_depth

        # set joint scores according to hand type
        scores[:, :21] *= hand_type[:, [0]]
        scores[:, 21:] *= hand_type[:, [1]]

        # normalize kpt score
        if scores.max() > 1:
            scores /= 255

        res_2d.pred_instances.set_field(keypoints[..., :2].copy(), 'keypoints')

        # rotate the keypoints so the z-axis corresponds to height
        # for better visualization
        vis_R = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
        keypoints[..., :3] = keypoints[..., :3] @ vis_R

        # rebase height (z-axis)
        if not args.disable_rebase_keypoint:
            valid = scores > 0
            keypoints[..., 2] -= np.min(
                keypoints[valid, 2], axis=-1, keepdims=True)

        pose_results[idx].pred_instances.keypoints = keypoints
        pose_results[idx].pred_instances.keypoint_scores = scores
        pose_results_2d.append(res_2d)

    data_samples = merge_data_samples(pose_results)
    data_samples_2d = merge_data_samples(pose_results_2d)
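    # `data_samples` now carries the rotated 3D keypoints, while
    # `data_samples_2d` keeps the original image-plane (x, y) keypoints;
    # the visualizer consumes both below.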

    # show the results
    if isinstance(img, str):
        img = mmcv.imread(img, channel_order='rgb')
    elif isinstance(img, np.ndarray):
        img = mmcv.bgr2rgb(img)

    if visualizer is not None:
        visualizer.add_datasample(
            'result',
            img,
            data_sample=data_samples,
            det_data_sample=data_samples_2d,
            draw_gt=False,
            draw_bbox=True,
            kpt_thr=args.kpt_thr,
            convert_keypoint=False,
            axis_azimuth=-115,
            axis_limit=200,
            axis_elev=15,
            show_kpt_idx=args.show_kpt_idx,
            show=args.show,
            wait_time=show_interval)

    # if there is no instance detected, return None
    return data_samples.get('pred_instances', None)


def main():
    args = parse_args()

    assert args.input != ''
    assert args.show or (args.output_root != '')

    output_file = None
    if args.output_root:
        mmengine.mkdir_or_exist(args.output_root)
        output_file = os.path.join(args.output_root,
                                   os.path.basename(args.input))
        if args.input == 'webcam':
            output_file += '.mp4'

    if args.save_predictions:
        assert args.output_root != ''
        args.pred_save_path = f'{args.output_root}/results_' \
            f'{os.path.splitext(os.path.basename(args.input))[0]}.json'

    # build the model from a config file and a checkpoint file
    model = init_model(
        args.config, args.checkpoint, device=args.device.lower())

    # init visualizer
    model.cfg.visualizer.radius = args.radius
    model.cfg.visualizer.line_width = args.thickness

    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.set_dataset_meta(model.dataset_meta)

    if args.input == 'webcam':
        input_type = 'webcam'
    else:
        input_type = mimetypes.guess_type(args.input)[0].split('/')[0]

    if input_type == 'image':
        # inference
        pred_instances = process_one_image(args, args.input, model,
                                           visualizer)

        if args.save_predictions:
            pred_instances_list = split_instances(pred_instances)

        if output_file:
            img_vis = visualizer.get_image()
            mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file)

    elif input_type in ['webcam', 'video']:
        if args.input == 'webcam':
            cap = cv2.VideoCapture(0)
        else:
            cap = cv2.VideoCapture(args.input)

        video_writer = None
        pred_instances_list = []
        frame_idx = 0

        while cap.isOpened():
            success, frame = cap.read()
            frame_idx += 1

            if not success:
                break

            # topdown pose estimation
            pred_instances = process_one_image(args, frame, model, visualizer,
                                               0.001)

            if args.save_predictions:
                # save prediction results
                pred_instances_list.append(
                    dict(
                        frame_id=frame_idx,
                        instances=split_instances(pred_instances)))

            # output videos
            if output_file:
                frame_vis = visualizer.get_image()

                if video_writer is None:
                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                    # the size of the image with visualization may vary
                    # depending on the presence of heatmaps
                    video_writer = cv2.VideoWriter(
                        output_file,
                        fourcc,
                        25,  # saved fps
                        (frame_vis.shape[1], frame_vis.shape[0]))

                video_writer.write(mmcv.rgb2bgr(frame_vis))

            if args.show:
                # press ESC to exit
                if cv2.waitKey(5) & 0xFF == 27:
                    break

                time.sleep(args.show_interval)

        if video_writer:
            video_writer.release()

        cap.release()

    else:
        args.save_predictions = False
        raise ValueError(
            f'file {os.path.basename(args.input)} has an invalid format.')

    if args.save_predictions:
        with open(args.pred_save_path, 'w') as f:
            json.dump(
                dict(
                    meta_info=model.dataset_meta,
                    instance_info=pred_instances_list),
                f,
                indent='\t')
        print_log(
            f'predictions have been saved at {args.pred_save_path}',
            logger='current',
            level=logging.INFO)

    if output_file is not None:
        input_type = input_type.replace('webcam', 'video')
        print_log(
            f'the output {input_type} has been saved at {output_file}',
            logger='current',
            level=logging.INFO)


if __name__ == '__main__':
    main()
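
# Example invocation (a sketch; the script name and file paths below are
# placeholder assumptions, not files shipped with this demo). The config and
# checkpoint must belong to a 3D hand pose model that predicts
# `rel_root_depth` and `hand_type` (e.g. InterNet); all flags match the
# parser defined in parse_args() above:
#
#   python demo.py path/to/config.py path/to/checkpoint.pth \
#       --input path/to/image_or_video.mp4 \
#       --output-root vis_results \
#       --save-predictions \
#       --device cuda:0
#
# Pass `--input webcam` to read from the default camera; press ESC in the
# display window to stop.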