|
|
|
import copy |
|
import warnings |
|
|
|
import mmcv |
|
import numpy as np |
|
import torch |
|
from mmdet.datasets import replace_ImageToTensor |
|
|
|
from mmocr.utils import is_2dlist, is_type_list |
|
|
|
|
|
def update_pipeline(cfg, idx=None): |
|
if idx is None: |
|
if cfg.pipeline is not None: |
|
cfg.pipeline = replace_ImageToTensor(cfg.pipeline) |
|
else: |
|
cfg.pipeline[idx] = replace_ImageToTensor(cfg.pipeline[idx]) |
|
|
|
|
|
def replace_image_to_tensor(cfg, set_types=None): |
|
"""Replace 'ImageToTensor' to 'DefaultFormatBundle'.""" |
|
assert set_types is None or isinstance(set_types, list) |
|
if set_types is None: |
|
set_types = ['val', 'test'] |
|
|
|
cfg = copy.deepcopy(cfg) |
|
for set_type in set_types: |
|
assert set_type in ['val', 'test'] |
|
uniform_pipeline = cfg.data[set_type].get('pipeline', None) |
|
if is_type_list(uniform_pipeline, dict): |
|
update_pipeline(cfg.data[set_type]) |
|
elif is_2dlist(uniform_pipeline): |
|
for idx, _ in enumerate(uniform_pipeline): |
|
update_pipeline(cfg.data[set_type], idx) |
|
|
|
for dataset in cfg.data[set_type].get('datasets', []): |
|
if isinstance(dataset, list): |
|
for each_dataset in dataset: |
|
update_pipeline(each_dataset) |
|
else: |
|
update_pipeline(dataset) |
|
|
|
return cfg |
|
|
|
|
|
def update_pipeline_recog(cfg, idx=None): |
|
warning_msg = 'Remove "MultiRotateAugOCR" to support batch ' + \ |
|
'inference since samples_per_gpu > 1.' |
|
if idx is None: |
|
if cfg.get('pipeline', |
|
None) and cfg.pipeline[1].type == 'MultiRotateAugOCR': |
|
warnings.warn(warning_msg) |
|
cfg.pipeline = [cfg.pipeline[0], *cfg.pipeline[1].transforms] |
|
else: |
|
if cfg[idx][1].type == 'MultiRotateAugOCR': |
|
warnings.warn(warning_msg) |
|
cfg[idx] = [cfg[idx][0], *cfg[idx][1].transforms] |
|
|
|
|
|
def disable_text_recog_aug_test(cfg, set_types=None): |
|
"""Remove aug_test from test pipeline for text recognition. |
|
|
|
Args: |
|
cfg (mmcv.Config): Input config. |
|
set_types (list[str]): Type of dataset source. Should be |
|
None or sublist of ['test', 'val']. |
|
""" |
|
assert set_types is None or isinstance(set_types, list) |
|
if set_types is None: |
|
set_types = ['val', 'test'] |
|
|
|
cfg = copy.deepcopy(cfg) |
|
warnings.simplefilter('once') |
|
for set_type in set_types: |
|
assert set_type in ['val', 'test'] |
|
dataset_type = cfg.data[set_type].type |
|
if dataset_type not in [ |
|
'ConcatDataset', 'UniformConcatDataset', 'OCRDataset', |
|
'OCRSegDataset' |
|
]: |
|
continue |
|
|
|
uniform_pipeline = cfg.data[set_type].get('pipeline', None) |
|
if is_type_list(uniform_pipeline, dict): |
|
update_pipeline_recog(cfg.data[set_type]) |
|
elif is_2dlist(uniform_pipeline): |
|
for idx, _ in enumerate(uniform_pipeline): |
|
update_pipeline_recog(cfg.data[set_type].pipeline, idx) |
|
|
|
for dataset in cfg.data[set_type].get('datasets', []): |
|
if isinstance(dataset, list): |
|
for each_dataset in dataset: |
|
update_pipeline_recog(each_dataset) |
|
else: |
|
update_pipeline_recog(dataset) |
|
|
|
return cfg |
|
|
|
|
|
def tensor2grayimgs(tensor, mean=(127, ), std=(127, ), **kwargs): |
|
"""Convert tensor to 1-channel gray images. |
|
|
|
Args: |
|
tensor (torch.Tensor): Tensor that contains multiple images, shape ( |
|
N, C, H, W). |
|
mean (tuple[float], optional): Mean of images. Defaults to (127). |
|
std (tuple[float], optional): Standard deviation of images. |
|
Defaults to (127). |
|
|
|
Returns: |
|
list[np.ndarray]: A list that contains multiple images. |
|
""" |
|
|
|
assert torch.is_tensor(tensor) and tensor.ndim == 4 |
|
assert tensor.size(1) == len(mean) == len(std) == 1 |
|
|
|
num_imgs = tensor.size(0) |
|
mean = np.array(mean, dtype=np.float32) |
|
std = np.array(std, dtype=np.float32) |
|
imgs = [] |
|
for img_id in range(num_imgs): |
|
img = tensor[img_id, ...].cpu().numpy().transpose(1, 2, 0) |
|
img = mmcv.imdenormalize(img, mean, std, to_bgr=False).astype(np.uint8) |
|
imgs.append(np.ascontiguousarray(img)) |
|
return imgs |
|
|