|
|
|
import math |
|
|
|
import mmcv |
|
import numpy as np |
|
import torch |
|
import torchvision.transforms.functional as TF |
|
from mmcv.runner.dist_utils import get_dist_info |
|
from mmdet.datasets.builder import PIPELINES |
|
from PIL import Image |
|
from shapely.geometry import Polygon |
|
from shapely.geometry import box as shapely_box |
|
|
|
import mmocr.utils as utils |
|
from mmocr.datasets.pipelines.crop import warp_img |
|
|
|
|
|
@PIPELINES.register_module() |
|
class ResizeOCR: |
|
"""Image resizing and padding for OCR. |
|
|
|
Args: |
|
height (int | tuple(int)): Image height after resizing. |
|
min_width (none | int | tuple(int)): Image minimum width |
|
after resizing. |
|
max_width (none | int | tuple(int)): Image maximum width |
|
after resizing. |
|
keep_aspect_ratio (bool): Keep image aspect ratio if True |
|
during resizing, Otherwise resize to the size height * |
|
max_width. |
|
img_pad_value (int): Scalar to fill padding area. |
|
width_downsample_ratio (float): Downsample ratio in horizontal |
|
direction from input image to output feature. |
|
backend (str | None): The image resize backend type. Options are `cv2`, |
|
`pillow`, `None`. If backend is None, the global imread_backend |
|
specified by ``mmcv.use_backend()`` will be used. Default: None. |
|
""" |
|
|
|
def __init__(self, |
|
height, |
|
min_width=None, |
|
max_width=None, |
|
keep_aspect_ratio=True, |
|
img_pad_value=0, |
|
width_downsample_ratio=1.0 / 16, |
|
backend=None): |
|
assert isinstance(height, (int, tuple)) |
|
assert utils.is_none_or_type(min_width, (int, tuple)) |
|
assert utils.is_none_or_type(max_width, (int, tuple)) |
|
if not keep_aspect_ratio: |
|
assert max_width is not None, ('"max_width" must assigned ' |
|
'if "keep_aspect_ratio" is False') |
|
assert isinstance(img_pad_value, int) |
|
if isinstance(height, tuple): |
|
assert isinstance(min_width, tuple) |
|
assert isinstance(max_width, tuple) |
|
assert len(height) == len(min_width) == len(max_width) |
|
|
|
self.height = height |
|
self.min_width = min_width |
|
self.max_width = max_width |
|
self.keep_aspect_ratio = keep_aspect_ratio |
|
self.img_pad_value = img_pad_value |
|
self.width_downsample_ratio = width_downsample_ratio |
|
self.backend = backend |
|
|
|
def __call__(self, results): |
|
rank, _ = get_dist_info() |
|
if isinstance(self.height, int): |
|
dst_height = self.height |
|
dst_min_width = self.min_width |
|
dst_max_width = self.max_width |
|
else: |
|
|
|
|
|
|
|
idx = rank % len(self.height) |
|
dst_height = self.height[idx] |
|
dst_min_width = self.min_width[idx] |
|
dst_max_width = self.max_width[idx] |
|
|
|
img_shape = results['img_shape'] |
|
ori_height, ori_width = img_shape[:2] |
|
valid_ratio = 1.0 |
|
resize_shape = list(img_shape) |
|
pad_shape = list(img_shape) |
|
|
|
if self.keep_aspect_ratio: |
|
new_width = math.ceil(float(dst_height) / ori_height * ori_width) |
|
width_divisor = int(1 / self.width_downsample_ratio) |
|
|
|
if new_width % width_divisor != 0: |
|
new_width = round(new_width / width_divisor) * width_divisor |
|
if dst_min_width is not None: |
|
new_width = max(dst_min_width, new_width) |
|
if dst_max_width is not None: |
|
valid_ratio = min(1.0, 1.0 * new_width / dst_max_width) |
|
resize_width = min(dst_max_width, new_width) |
|
img_resize = mmcv.imresize( |
|
results['img'], (resize_width, dst_height), |
|
backend=self.backend) |
|
resize_shape = img_resize.shape |
|
pad_shape = img_resize.shape |
|
if new_width < dst_max_width: |
|
img_resize = mmcv.impad( |
|
img_resize, |
|
shape=(dst_height, dst_max_width), |
|
pad_val=self.img_pad_value) |
|
pad_shape = img_resize.shape |
|
else: |
|
img_resize = mmcv.imresize( |
|
results['img'], (new_width, dst_height), |
|
backend=self.backend) |
|
resize_shape = img_resize.shape |
|
pad_shape = img_resize.shape |
|
else: |
|
img_resize = mmcv.imresize( |
|
results['img'], (dst_max_width, dst_height), |
|
backend=self.backend) |
|
resize_shape = img_resize.shape |
|
pad_shape = img_resize.shape |
|
|
|
results['img'] = img_resize |
|
results['img_shape'] = resize_shape |
|
results['resize_shape'] = resize_shape |
|
results['pad_shape'] = pad_shape |
|
results['valid_ratio'] = valid_ratio |
|
|
|
return results |
|
|
|
|
|
@PIPELINES.register_module() |
|
class ToTensorOCR: |
|
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.""" |
|
|
|
def __init__(self): |
|
pass |
|
|
|
def __call__(self, results): |
|
results['img'] = TF.to_tensor(results['img'].copy()) |
|
|
|
return results |
|
|
|
|
|
@PIPELINES.register_module() |
|
class NormalizeOCR: |
|
"""Normalize a tensor image with mean and standard deviation.""" |
|
|
|
def __init__(self, mean, std): |
|
self.mean = mean |
|
self.std = std |
|
|
|
def __call__(self, results): |
|
results['img'] = TF.normalize(results['img'], self.mean, self.std) |
|
results['img_norm_cfg'] = dict(mean=self.mean, std=self.std) |
|
return results |
|
|
|
|
|
@PIPELINES.register_module() |
|
class OnlineCropOCR: |
|
"""Crop text areas from whole image with bounding box jitter. If no bbox is |
|
given, return directly. |
|
|
|
Args: |
|
box_keys (list[str]): Keys in results which correspond to RoI bbox. |
|
jitter_prob (float): The probability of box jitter. |
|
max_jitter_ratio_x (float): Maximum horizontal jitter ratio |
|
relative to height. |
|
max_jitter_ratio_y (float): Maximum vertical jitter ratio |
|
relative to height. |
|
""" |
|
|
|
def __init__(self, |
|
box_keys=['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4'], |
|
jitter_prob=0.5, |
|
max_jitter_ratio_x=0.05, |
|
max_jitter_ratio_y=0.02): |
|
assert utils.is_type_list(box_keys, str) |
|
assert 0 <= jitter_prob <= 1 |
|
assert 0 <= max_jitter_ratio_x <= 1 |
|
assert 0 <= max_jitter_ratio_y <= 1 |
|
|
|
self.box_keys = box_keys |
|
self.jitter_prob = jitter_prob |
|
self.max_jitter_ratio_x = max_jitter_ratio_x |
|
self.max_jitter_ratio_y = max_jitter_ratio_y |
|
|
|
def __call__(self, results): |
|
|
|
if 'img_info' not in results: |
|
return results |
|
|
|
crop_flag = True |
|
box = [] |
|
for key in self.box_keys: |
|
if key not in results['img_info']: |
|
crop_flag = False |
|
break |
|
|
|
box.append(float(results['img_info'][key])) |
|
|
|
if not crop_flag: |
|
return results |
|
|
|
jitter_flag = np.random.random() > self.jitter_prob |
|
|
|
kwargs = dict( |
|
jitter_flag=jitter_flag, |
|
jitter_ratio_x=self.max_jitter_ratio_x, |
|
jitter_ratio_y=self.max_jitter_ratio_y) |
|
crop_img = warp_img(results['img'], box, **kwargs) |
|
|
|
results['img'] = crop_img |
|
results['img_shape'] = crop_img.shape |
|
|
|
return results |
|
|
|
|
|
@PIPELINES.register_module() |
|
class FancyPCA: |
|
"""Implementation of PCA based image augmentation, proposed in the paper |
|
``Imagenet Classification With Deep Convolutional Neural Networks``. |
|
|
|
It alters the intensities of RGB values along the principal components of |
|
ImageNet dataset. |
|
""" |
|
|
|
def __init__(self, eig_vec=None, eig_val=None): |
|
if eig_vec is None: |
|
eig_vec = torch.Tensor([ |
|
[-0.5675, +0.7192, +0.4009], |
|
[-0.5808, -0.0045, -0.8140], |
|
[-0.5836, -0.6948, +0.4203], |
|
]).t() |
|
if eig_val is None: |
|
eig_val = torch.Tensor([[0.2175, 0.0188, 0.0045]]) |
|
self.eig_val = eig_val |
|
self.eig_vec = eig_vec |
|
|
|
def pca(self, tensor): |
|
assert tensor.size(0) == 3 |
|
alpha = torch.normal(mean=torch.zeros_like(self.eig_val)) * 0.1 |
|
reconst = torch.mm(self.eig_val * alpha, self.eig_vec) |
|
tensor = tensor + reconst.view(3, 1, 1) |
|
|
|
return tensor |
|
|
|
def __call__(self, results): |
|
img = results['img'] |
|
tensor = self.pca(img) |
|
results['img'] = tensor |
|
|
|
return results |
|
|
|
def __repr__(self): |
|
repr_str = self.__class__.__name__ |
|
return repr_str |
|
|
|
|
|
@PIPELINES.register_module() |
|
class RandomPaddingOCR: |
|
"""Pad the given image on all sides, as well as modify the coordinates of |
|
character bounding box in image. |
|
|
|
Args: |
|
max_ratio (list[int]): [left, top, right, bottom]. |
|
box_type (None|str): Character box type. If not none, |
|
should be either 'char_rects' or 'char_quads', with |
|
'char_rects' for rectangle with ``xyxy`` style and |
|
'char_quads' for quadrangle with ``x1y1x2y2x3y3x4y4`` style. |
|
""" |
|
|
|
def __init__(self, max_ratio=None, box_type=None): |
|
if max_ratio is None: |
|
max_ratio = [0.1, 0.2, 0.1, 0.2] |
|
else: |
|
assert utils.is_type_list(max_ratio, float) |
|
assert len(max_ratio) == 4 |
|
assert box_type is None or box_type in ('char_rects', 'char_quads') |
|
|
|
self.max_ratio = max_ratio |
|
self.box_type = box_type |
|
|
|
def __call__(self, results): |
|
|
|
img_shape = results['img_shape'] |
|
ori_height, ori_width = img_shape[:2] |
|
|
|
random_padding_left = round( |
|
np.random.uniform(0, self.max_ratio[0]) * ori_width) |
|
random_padding_top = round( |
|
np.random.uniform(0, self.max_ratio[1]) * ori_height) |
|
random_padding_right = round( |
|
np.random.uniform(0, self.max_ratio[2]) * ori_width) |
|
random_padding_bottom = round( |
|
np.random.uniform(0, self.max_ratio[3]) * ori_height) |
|
|
|
padding = (random_padding_left, random_padding_top, |
|
random_padding_right, random_padding_bottom) |
|
img = mmcv.impad(results['img'], padding=padding, padding_mode='edge') |
|
|
|
results['img'] = img |
|
results['img_shape'] = img.shape |
|
|
|
if self.box_type is not None: |
|
num_points = 2 if self.box_type == 'char_rects' else 4 |
|
char_num = len(results['ann_info'][self.box_type]) |
|
for i in range(char_num): |
|
for j in range(num_points): |
|
results['ann_info'][self.box_type][i][ |
|
j * 2] += random_padding_left |
|
results['ann_info'][self.box_type][i][ |
|
j * 2 + 1] += random_padding_top |
|
|
|
return results |
|
|
|
def __repr__(self): |
|
repr_str = self.__class__.__name__ |
|
return repr_str |
|
|
|
|
|
@PIPELINES.register_module() |
|
class RandomRotateImageBox: |
|
"""Rotate augmentation for segmentation based text recognition. |
|
|
|
Args: |
|
min_angle (int): Minimum rotation angle for image and box. |
|
max_angle (int): Maximum rotation angle for image and box. |
|
box_type (str): Character box type, should be either |
|
'char_rects' or 'char_quads', with 'char_rects' |
|
for rectangle with ``xyxy`` style and 'char_quads' |
|
for quadrangle with ``x1y1x2y2x3y3x4y4`` style. |
|
""" |
|
|
|
def __init__(self, min_angle=-10, max_angle=10, box_type='char_quads'): |
|
assert box_type in ('char_rects', 'char_quads') |
|
|
|
self.min_angle = min_angle |
|
self.max_angle = max_angle |
|
self.box_type = box_type |
|
|
|
def __call__(self, results): |
|
in_img = results['img'] |
|
in_chars = results['ann_info']['chars'] |
|
in_boxes = results['ann_info'][self.box_type] |
|
|
|
img_width, img_height = in_img.size |
|
rotate_center = [img_width / 2., img_height / 2.] |
|
|
|
tan_temp_max_angle = rotate_center[1] / rotate_center[0] |
|
temp_max_angle = np.arctan(tan_temp_max_angle) * 180. / np.pi |
|
|
|
random_angle = np.random.uniform( |
|
max(self.min_angle, -temp_max_angle), |
|
min(self.max_angle, temp_max_angle)) |
|
random_angle_radian = random_angle * np.pi / 180. |
|
|
|
img_box = shapely_box(0, 0, img_width, img_height) |
|
|
|
out_img = TF.rotate( |
|
in_img, |
|
random_angle, |
|
resample=False, |
|
expand=False, |
|
center=rotate_center) |
|
|
|
out_boxes, out_chars = self.rotate_bbox(in_boxes, in_chars, |
|
random_angle_radian, |
|
rotate_center, img_box) |
|
|
|
results['img'] = out_img |
|
results['ann_info']['chars'] = out_chars |
|
results['ann_info'][self.box_type] = out_boxes |
|
|
|
return results |
|
|
|
@staticmethod |
|
def rotate_bbox(boxes, chars, angle, center, img_box): |
|
out_boxes = [] |
|
out_chars = [] |
|
for idx, bbox in enumerate(boxes): |
|
temp_bbox = [] |
|
for i in range(len(bbox) // 2): |
|
point = [bbox[2 * i], bbox[2 * i + 1]] |
|
temp_bbox.append( |
|
RandomRotateImageBox.rotate_point(point, angle, center)) |
|
poly_temp_bbox = Polygon(temp_bbox).buffer(0) |
|
if poly_temp_bbox.is_valid: |
|
if img_box.intersects(poly_temp_bbox) and ( |
|
not img_box.touches(poly_temp_bbox)): |
|
temp_bbox_area = poly_temp_bbox.area |
|
|
|
intersect_area = img_box.intersection(poly_temp_bbox).area |
|
intersect_ratio = intersect_area / temp_bbox_area |
|
|
|
if intersect_ratio >= 0.7: |
|
out_box = [] |
|
for p in temp_bbox: |
|
out_box.extend(p) |
|
out_boxes.append(out_box) |
|
out_chars.append(chars[idx]) |
|
|
|
return out_boxes, out_chars |
|
|
|
@staticmethod |
|
def rotate_point(point, angle, center): |
|
cos_theta = math.cos(-angle) |
|
sin_theta = math.sin(-angle) |
|
c_x = center[0] |
|
c_y = center[1] |
|
new_x = (point[0] - c_x) * cos_theta - (point[1] - |
|
c_y) * sin_theta + c_x |
|
new_y = (point[0] - c_x) * sin_theta + (point[1] - |
|
c_y) * cos_theta + c_y |
|
|
|
return [new_x, new_y] |
|
|
|
|
|
@PIPELINES.register_module() |
|
class OpencvToPil: |
|
"""Convert ``numpy.ndarray`` (bgr) to ``PIL Image`` (rgb).""" |
|
|
|
def __init__(self, **kwargs): |
|
pass |
|
|
|
def __call__(self, results): |
|
img = results['img'][..., ::-1] |
|
img = Image.fromarray(img) |
|
results['img'] = img |
|
|
|
return results |
|
|
|
def __repr__(self): |
|
repr_str = self.__class__.__name__ |
|
return repr_str |
|
|
|
|
|
@PIPELINES.register_module() |
|
class PilToOpencv: |
|
"""Convert ``PIL Image`` (rgb) to ``numpy.ndarray`` (bgr).""" |
|
|
|
def __init__(self, **kwargs): |
|
pass |
|
|
|
def __call__(self, results): |
|
img = np.asarray(results['img']) |
|
img = img[..., ::-1] |
|
results['img'] = img |
|
|
|
return results |
|
|
|
def __repr__(self): |
|
repr_str = self.__class__.__name__ |
|
return repr_str |
|
|