import copy

import numpy as np
import torch
from fvcore.transforms import HFlipTransform
from torch import nn
from torch.nn.parallel import DistributedDataParallel

from detectron2.data.detection_utils import read_image
from detectron2.modeling import DatasetMapperTTA

__all__ = [
    "SemanticSegmentorWithTTA",
]


class SemanticSegmentorWithTTA(nn.Module):
    """
    A SemanticSegmentor with test-time augmentation enabled.
    Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`.
    """

    def __init__(self, cfg, model, tta_mapper=None, batch_size=1):
        """
        Args:
            cfg (CfgNode):
            model (SemanticSegmentor): a SemanticSegmentor to apply TTA on.
            tta_mapper (callable): takes a dataset dict and returns a list of
                augmented versions of the dataset dict. Defaults to
                `DatasetMapperTTA(cfg)`; see the sketch below for the expected
                return format.
            batch_size (int): batch the augmented images into this batch size for inference.
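
        Example of the ``tta_mapper`` contract (a minimal sketch, not the default
        implementation; each returned dict must keep an "image" tensor and carry
        the "transforms" that produced it, since "transforms" is popped by this
        class before inference)::

            from fvcore.transforms.transform import NoOpTransform, TransformList

            def identity_tta_mapper(dataset_dict):
                # No augmentation: return the input unchanged, tagged with a
                # no-op transform so the flip-undo logic has something to inspect.
                d = copy.copy(dataset_dict)
                d["transforms"] = TransformList([NoOpTransform()])
                return [d]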
        """
        super().__init__()
        if isinstance(model, DistributedDataParallel):
            model = model.module
        self.cfg = cfg.clone()
        self.model = model

        if tta_mapper is None:
            tta_mapper = DatasetMapperTTA(cfg)
        self.tta_mapper = tta_mapper
        self.batch_size = batch_size

    def _batch_inference(self, batched_inputs):
        """
        Execute inference on a list of inputs, using batch size = self.batch_size
        rather than the length of the list.
        Inputs & outputs have the same format as :meth:`SemanticSegmentor.forward`.
        """
        outputs = []
        inputs = []
        for idx, input in enumerate(batched_inputs):
            inputs.append(input)
            # Run the model once the buffer is full, or when the last input is reached.
            if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1:
                with torch.no_grad():
                    outputs.extend(self.model(inputs))
                inputs = []
        return outputs

    def __call__(self, batched_inputs):
        """
        Same input/output format as :meth:`SemanticSegmentor.forward`.
        """

        def _maybe_read_image(dataset_dict):
            ret = copy.copy(dataset_dict)
            if "image" not in ret:
                image = read_image(ret.pop("file_name"), self.model.input_format)
                image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))  # CHW
                ret["image"] = image
            if "height" not in ret and "width" not in ret:
                ret["height"] = ret["image"].shape[1]
                ret["width"] = ret["image"].shape[2]
            return ret

        return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs]

    def _inference_one_image(self, input):
        """
        Args:
            input (dict): one dataset dict with "image" field being a CHW tensor
        Returns:
            dict: one output dict
        """
        augmented_inputs, tfms = self._get_augmented_inputs(input)

        outputs = self._batch_inference(augmented_inputs)
        del augmented_inputs

        # Undo the horizontal flips so every prediction is aligned with the original image.
        new_outputs = []
        for output, tfm in zip(outputs, tfms):
            if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
                new_outputs.append(output.pop("sem_seg").flip(dims=[2]))
            else:
                new_outputs.append(output.pop("sem_seg"))
        del outputs

        # Average the per-augmentation predictions into the final score map.
        final_predictions = new_outputs[0]
        for i in range(1, len(new_outputs)):
            final_predictions += new_outputs[i]
        final_predictions = final_predictions / len(new_outputs)
        del new_outputs
        return {"sem_seg": final_predictions}

    def _get_augmented_inputs(self, input):
        augmented_inputs = self.tta_mapper(input)
        tfms = [x.pop("transforms") for x in augmented_inputs]
        return augmented_inputs, tfms