Spaces:
Build error
Build error
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
| import logging | |
| import time | |
| import os | |
| import torch | |
| from tqdm import tqdm | |
| from maskrcnn_benchmark.data.datasets.evaluation import evaluate | |
| from ..utils.comm import is_main_process, get_world_size | |
| from ..utils.comm import all_gather | |
| from ..utils.comm import synchronize | |
| from ..utils.timer import Timer, get_time_str | |
| def compute_on_dataset(model, data_loader, device, timer=None): | |
| model.eval() | |
| results_dict = {} | |
| cpu_device = torch.device("cpu") | |
| for _, batch in enumerate(tqdm(data_loader)): | |
| images, targets, image_ids = batch | |
| images = images.to(device) | |
| with torch.no_grad(): | |
| if timer: | |
| timer.tic() | |
| output = model(images) | |
| if timer: | |
| torch.cuda.synchronize() | |
| timer.toc() | |
| output = [o.to(cpu_device) for o in output] | |
| results_dict.update( | |
| {img_id: result for img_id, result in zip(image_ids, output)} | |
| ) | |
| return results_dict | |
| def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu): | |
| all_predictions = all_gather(predictions_per_gpu) | |
| if not is_main_process(): | |
| return | |
| # merge the list of dicts | |
| predictions = {} | |
| for p in all_predictions: | |
| predictions.update(p) | |
| # convert a dict where the key is the index in a list | |
| image_ids = list(sorted(predictions.keys())) | |
| if len(image_ids) != image_ids[-1] + 1: | |
| logger = logging.getLogger("maskrcnn_benchmark.inference") | |
| logger.warning( | |
| "Number of images that were gathered from multiple processes is not " | |
| "a contiguous set. Some images might be missing from the evaluation" | |
| ) | |
| # convert to a list | |
| predictions = [predictions[i] for i in image_ids] | |
| return predictions | |
| def inference( | |
| model, | |
| data_loader, | |
| dataset_name, | |
| iou_types=("bbox",), | |
| box_only=False, | |
| device="cuda", | |
| expected_results=(), | |
| expected_results_sigma_tol=4, | |
| output_folder=None, | |
| ): | |
| logger = logging.getLogger("maskrcnn_benchmark.inference") | |
| dataset = data_loader.dataset | |
| logger.info("Start evaluation on {} dataset({} images).".format(dataset_name, len(dataset))) | |
| extra_args = dict( | |
| box_only=box_only, | |
| iou_types=iou_types, | |
| expected_results=expected_results, | |
| expected_results_sigma_tol=expected_results_sigma_tol, | |
| ) | |
| # load predictions if exists | |
| prediction_file = os.path.join(output_folder, 'predictions.pth') | |
| if os.path.isfile(prediction_file): | |
| predictions = torch.load(prediction_file) | |
| logger.info("Found prediction results at {}".format(prediction_file)) | |
| return evaluate(dataset=dataset, | |
| predictions=predictions, | |
| output_folder=output_folder, | |
| **extra_args) | |
| # convert to a torch.device for efficiency | |
| device = torch.device(device) | |
| num_devices = get_world_size() | |
| total_timer = Timer() | |
| inference_timer = Timer() | |
| total_timer.tic() | |
| predictions = compute_on_dataset(model, data_loader, device, inference_timer) | |
| # wait for all processes to complete before measuring the time | |
| synchronize() | |
| total_time = total_timer.toc() | |
| total_time_str = get_time_str(total_time) | |
| logger.info( | |
| "Total run time: {} ({} s / img per device, on {} devices)".format( | |
| total_time_str, total_time * num_devices / len(dataset), num_devices | |
| ) | |
| ) | |
| total_infer_time = get_time_str(inference_timer.total_time) | |
| logger.info( | |
| "Model inference time: {} ({} s / img per device, on {} devices)".format( | |
| total_infer_time, | |
| inference_timer.total_time * num_devices / len(dataset), | |
| num_devices, | |
| ) | |
| ) | |
| predictions = _accumulate_predictions_from_multiple_gpus(predictions) | |
| if not is_main_process(): | |
| return | |
| if output_folder: | |
| torch.save(predictions, os.path.join(output_folder, "predictions.pth")) | |
| return evaluate(dataset=dataset, | |
| predictions=predictions, | |
| output_folder=output_folder, | |
| **extra_args) | |