# Copyright (c) Tencent Inc. All rights reserved.
import argparse
import os
import os.path as osp

import cv2
import torch
from mmengine.config import Config, DictAction
from mmengine.dataset import Compose
from mmengine.runner.amp import autocast
from mmengine.utils import ProgressBar
from mmdet.apis import init_detector
from mmdet.utils import get_test_pipeline_cfg

import supervision as sv
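# Module-level annotators from the `supervision` library, created once and
# reused for every image.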
BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator(thickness=1)
MASK_ANNOTATOR = sv.MaskAnnotator()


class LabelAnnotator(sv.LabelAnnotator):
    """Anchor the label background box at the given point instead of
    centering it on the detection."""

    @staticmethod
    def resolve_text_background_xyxy(
        center_coordinates,
        text_wh,
        position,
    ):
        center_x, center_y = center_coordinates
        text_w, text_h = text_wh
        return center_x, center_y, center_x + text_w, center_y + text_h


LABEL_ANNOTATOR = LabelAnnotator(text_padding=4,
                                 text_scale=0.5,
                                 text_thickness=1)
def parse_args():
    parser = argparse.ArgumentParser(description='YOLO-World Demo')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('image',
                        help='image path: a single image file or a '
                        'directory of images.')
    parser.add_argument(
        'text',
        help='text prompts: either category names separated by commas, or '
        'the path to a txt file with one prompt per line.')
    parser.add_argument('--topk',
                        default=100,
                        type=int,
                        help='keep the top-k predictions.')
    parser.add_argument('--threshold',
                        default=0.1,
                        type=float,
                        help='confidence score threshold for predictions.')
    parser.add_argument('--device',
                        default='cuda:0',
                        help='device used for inference.')
    parser.add_argument('--show',
                        action='store_true',
                        help='show the detection results.')
    parser.add_argument(
        '--annotation',
        action='store_true',
        help='save the annotated detection results in YOLO text format.')
    parser.add_argument('--amp',
                        action='store_true',
                        help='use mixed precision (AMP) for inference.')
    parser.add_argument('--output-dir',
                        default='demo_outputs',
                        help='the directory to save outputs.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config; key-value pairs '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. Nested list/tuple values are also allowed, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no whitespace is allowed.')
    args = parser.parse_args()
    return args
def inference_detector(model,
                       image_path,
                       texts,
                       test_pipeline,
                       max_dets=100,
                       score_thr=0.3,
                       output_dir='./work_dir',
                       use_amp=False,
                       show=False,
                       annotation=False):
    data_info = dict(img_id=0, img_path=image_path, texts=texts)
    data_info = test_pipeline(data_info)
    data_batch = dict(inputs=data_info['inputs'].unsqueeze(0),
                      data_samples=[data_info['data_samples']])

    with autocast(enabled=use_amp), torch.no_grad():
        output = model.test_step(data_batch)[0]
        pred_instances = output.pred_instances
        # drop low-confidence predictions
        pred_instances = pred_instances[
            pred_instances.scores.float() > score_thr]
    # keep at most `max_dets` predictions, highest scores first
    if len(pred_instances.scores) > max_dets:
        indices = pred_instances.scores.float().topk(max_dets)[1]
        pred_instances = pred_instances[indices]

    pred_instances = pred_instances.cpu().numpy()

    masks = pred_instances['masks'] if 'masks' in pred_instances else None

    detections = sv.Detections(xyxy=pred_instances['bboxes'],
                               class_id=pred_instances['labels'],
                               confidence=pred_instances['scores'],
                               mask=masks)
    labels = [
        f"{texts[class_id][0]} {confidence:0.2f}"
        for class_id, confidence in zip(detections.class_id,
                                        detections.confidence)
    ]
    # label images: draw detections on a copy, keeping the clean image
    # around for the optional YOLO-format export below
    image = cv2.imread(image_path)
    anno_image = image.copy()
    image = BOUNDING_BOX_ANNOTATOR.annotate(image, detections)
    image = LABEL_ANNOTATOR.annotate(image, detections, labels=labels)
    if masks is not None:
        image = MASK_ANNOTATOR.annotate(image, detections)
    cv2.imwrite(osp.join(output_dir, osp.basename(image_path)), image)
    if annotation:
        images_dict = {}
        annotations_dict = {}

        images_dict[osp.basename(image_path)] = anno_image
        annotations_dict[osp.basename(image_path)] = detections

        # os.makedirs returns None, so keep the path in its own variable
        ANNOTATIONS_DIRECTORY = './annotations'
        os.makedirs(ANNOTATIONS_DIRECTORY, exist_ok=True)

        MIN_IMAGE_AREA_PERCENTAGE = 0.002
        MAX_IMAGE_AREA_PERCENTAGE = 0.80
        APPROXIMATION_PERCENTAGE = 0.75

        # `texts` is a list of single-prompt lists; sv.DetectionDataset
        # expects a flat list of class names
        sv.DetectionDataset(
            classes=[t[0] for t in texts],
            images=images_dict,
            annotations=annotations_dict).as_yolo(
                annotations_directory_path=ANNOTATIONS_DIRECTORY,
                min_image_area_percentage=MIN_IMAGE_AREA_PERCENTAGE,
                max_image_area_percentage=MAX_IMAGE_AREA_PERCENTAGE,
                approximation_percentage=APPROXIMATION_PERCENTAGE)
    if show:
        cv2.imshow('Image', image)  # provide a window name
        k = cv2.waitKey(0)
        if k == 27:
            # wait for the ESC key to exit
            cv2.destroyAllWindows()
if __name__ == '__main__':
    args = parse_args()

    # load config
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    cfg.work_dir = osp.join('./work_dirs',
                            osp.splitext(osp.basename(args.config))[0])

    # init model
    cfg.load_from = args.checkpoint
    model = init_detector(cfg, checkpoint=args.checkpoint, device=args.device)

    # init test pipeline
    test_pipeline_cfg = get_test_pipeline_cfg(cfg=cfg)
    # test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'
    test_pipeline = Compose(test_pipeline_cfg)

    # parse text prompts; the trailing [' '] entry is YOLO-World's
    # padding prompt
    if args.text.endswith('.txt'):
        with open(args.text) as f:
            lines = f.readlines()
        texts = [[t.rstrip('\r\n')] for t in lines] + [[' ']]
    else:
        texts = [[t.strip()] for t in args.text.split(',')] + [[' ']]

    output_dir = args.output_dir
    if not osp.exists(output_dir):
        os.mkdir(output_dir)
    # load images: a single file, or every .png/.jpg file in a directory
    if not osp.isfile(args.image):
        images = [
            osp.join(args.image, img) for img in os.listdir(args.image)
            if img.endswith('.png') or img.endswith('.jpg')
        ]
    else:
        images = [args.image]
    # reparameterize texts: encode the prompts once and cache the text
    # embeddings in the model, so per-image inference skips the text encoder
    model.reparameterize(texts)

    progress_bar = ProgressBar(len(images))
    for image_path in images:
        inference_detector(model,
                           image_path,
                           texts,
                           test_pipeline,
                           args.topk,
                           args.threshold,
                           output_dir=output_dir,
                           use_amp=args.amp,
                           show=args.show,
                           annotation=args.annotation)
        progress_bar.update()
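# Outputs: annotated images are written to --output-dir (default
# 'demo_outputs'); with --annotation, YOLO-format labels are exported to
# ./annotations as set up in inference_detector above.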