File size: 8,256 Bytes
e28b279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import os
from typing import List, Tuple

from PIL import Image

from .dynamic_high_resolution import factorize_number


def construct_mapping_dict(max_splits: int = 12) -> dict:
    """Group slice layouts by the ratio of their two factors.

    Every integer from 1 to ``max_splits`` (inclusive) is handed to
    ``factorize_number``; each item it returns is filed under the quotient
    of its first element divided by its second.

    Args:
        max_splits (int, optional): The largest split count to enumerate.
            Defaults to 12.

    Returns:
        dict: Maps each ratio to the list of factor entries that produce
            it, in the order they were generated.
    """
    ratio_to_factors = {}
    for num_splits in range(1, max_splits + 1):
        for pair in factorize_number(num_splits):
            # setdefault creates the bucket the first time a ratio is seen.
            ratio_to_factors.setdefault(pair[0] / pair[1], []).append(pair)
    return ratio_to_factors


def save_image_list(image_list: List[Image.Image], save_folder: str) -> None:
    """Write every image into ``save_folder`` as a numbered PNG file.

    The folder is created when it does not exist yet. Each file is named
    after the position of its image in the list: ``0.png``, ``1.png``, ...

    Args:
        image_list (List[Image.Image]): The images to write.
        save_folder (str): The destination directory.
    """
    os.makedirs(save_folder, exist_ok=True)
    for index, picture in enumerate(image_list):
        target_path = os.path.join(save_folder, f'{index}.png')
        picture.save(target_path)


def resize_to_best_size(image: Image.Image, best_slices: tuple,
                        width_slices: int, height_slices: int,
                        sub_image_size: int) -> Image.Image:
    """Scale an image so one side spans a whole number of sub-images.

    When ``width_slices`` is smaller than ``height_slices`` the width is
    pinned to ``best_slices[0] * sub_image_size``; otherwise the height is
    pinned to ``best_slices[1] * sub_image_size``. The other side is scaled
    by the original aspect ratio and truncated to an integer.

    Args:
        image (Image.Image): The image to resize.
        best_slices (tuple): The chosen (width, height) slice counts.
        width_slices (int): The number of horizontal slices of the input.
        height_slices (int): The number of vertical slices of the input.
        sub_image_size (int): The side length of one sub-image.

    Returns:
        Image.Image: The resized image.
    """
    src_width, src_height = image.size
    target_cols, target_rows = best_slices
    if width_slices < height_slices:
        pinned_width = target_cols * sub_image_size
        target_size = (pinned_width,
                       int(src_height / src_width * pinned_width))
    else:
        pinned_height = target_rows * sub_image_size
        target_size = (int(src_width / src_height * pinned_height),
                       pinned_height)
    # resample=2 is the integer code of Pillow's bilinear filter.
    return image.resize(target_size, resample=2)


def compute_strides(height: int, width: int, sub_image_size: int,
                    slices: Tuple[int, int]) -> Tuple[int, int]:
    """Compute the horizontal and vertical step between sub-image origins.

    Along each axis the space left after one sub-image is divided (with
    floor division) by the number of gaps between slices. An axis with a
    single slice, or fewer, gets a stride of 0.

    Args:
        height (int): The height of the image.
        width (int): The width of the image.
        sub_image_size (int): The side length of one sub-image.
        slices (Tuple[int, int]): The number of horizontal and vertical
            slices, in that order.

    Returns:
        Tuple[int, int]: The strides as ``(stride_x, stride_y)``.
    """
    num_cols, num_rows = slices

    def _axis_stride(extent: int, count: int) -> int:
        # Stride along one axis; no movement is needed for a single slice.
        if count > 1:
            return (extent - sub_image_size) // (count - 1)
        return 0

    return _axis_stride(width, num_cols), _axis_stride(height, num_rows)


def sliding_window_crop(image: Image.Image, window_size: int,
                        slices: Tuple[int, int]) -> List[Image.Image]:
    """Cut square sub-images out of an image with a sliding window.

    The window steps by the strides from ``compute_strides``. Crops are
    produced row by row: top to bottom, and left to right within a row.

    Args:
        image (Image.Image): The image to crop.
        window_size (int): The side length of each sub-image.
        slices (Tuple[int, int]): The number of horizontal and vertical
            slices.

    Returns:
        List[Image]: The cropped sub-images.
    """
    img_width, img_height = image.size
    step_x, step_y = compute_strides(img_height, img_width, window_size,
                                     slices)
    # range() rejects a zero step, so a zero stride falls back to
    # stepping by one full window.
    step_x = step_x or window_size
    step_y = step_y or window_size
    return [
        image.crop((left, top, left + window_size, top + window_size))
        for top in range(0, img_height - window_size + 1, step_y)
        for left in range(0, img_width - window_size + 1, step_x)
    ]


def find_best_slices(width_slices: int,
                     height_slices: int,
                     aspect_ratio: float,
                     max_splits: int = 12) -> list:
    """Find the best slices for the given image size and aspect ratio.

    Candidate layouts come from ``construct_mapping_dict(max_splits)``,
    keyed by ratio. For ``aspect_ratio < 1`` only ratios not larger than it
    are kept; for ``aspect_ratio > 1`` only ratios not smaller than it are
    kept; for exactly 1 every ratio stays eligible. Among the remaining
    ratios the one closest to ``aspect_ratio`` is selected, and among the
    layouts sharing that ratio the one whose area (product of its two
    factors) is closest to ``width_slices * height_slices`` is returned.

    Args:
        width_slices (int): The number of horizontal slices.
        height_slices (int): The number of vertical slices.
        aspect_ratio (float): The aspect ratio of the image.
        max_splits (int, optional): The maximum number of splits.
            Defaults to 12.

    Returns:
        list: the best slices for the given image, i.e. one of the factor
            entries stored in the mapping dictionary.

    Raises:
        ValueError: If no candidate layout survives the aspect-ratio
            filter.
    """
    candidates = construct_mapping_dict(max_splits)
    if aspect_ratio < 1:
        candidates = {
            ratio: layouts
            for ratio, layouts in candidates.items()
            if ratio <= aspect_ratio
        }
    elif aspect_ratio > 1:
        candidates = {
            ratio: layouts
            for ratio, layouts in candidates.items()
            if ratio >= aspect_ratio
        }
    if not candidates:
        # min() on an empty dict would raise an opaque ValueError; raise
        # the same exception type with an actionable message instead.
        raise ValueError(
            f'No slice layout with at most {max_splits} splits matches '
            f'aspect ratio {aspect_ratio}')
    # Pick the ratio closest to the requested aspect ratio.
    best_ratio = min(candidates, key=lambda ratio: abs(ratio - aspect_ratio))
    # Among layouts with that ratio, pick the one whose slice count is
    # closest to the current image's slice area.
    target_area = width_slices * height_slices
    return min(
        candidates[best_ratio],
        key=lambda layout: abs(layout[0] * layout[1] - target_area))


def split_image_with_catty(pil_image: Image.Image,
                           image_size: int = 336,
                           max_crop_slices: int = 8,
                           save_folder: str = None,
                           add_thumbnail: bool = True,
                           do_resize: bool = False,
                           **kwargs) -> List[Image.Image]:
    """Split an image into sub-images using Catty.

    The image is first checked against the allowed aspect-ratio range
    ``[1 / max_crop_slices, max_crop_slices]``. It is then resized so one
    side spans a whole number of ``image_size`` tiles and cropped with a
    sliding window.

    Args:
        pil_image (Image.Image): The image to split.
        image_size (int, optional): The side length of each sub-image.
            Defaults to 336.
        max_crop_slices (int, optional): The maximum number of slices.
            Defaults to 8.
        save_folder (str, optional): The folder to save the sub-images.
            Nothing is saved when it is None. Defaults to None.
        add_thumbnail (bool, optional): Whether to append a thumbnail (the
            resized image squeezed to ``image_size`` x ``image_size``) as
            the last element. Defaults to True.
        do_resize (bool, optional): Whether to resize the image to fit the
            maximum number of slices when its aspect ratio is out of
            range. Defaults to False.
        **kwargs: Accepted and ignored.

    Returns:
        List[Image.Image]: A list of cropped images, followed by the
            thumbnail when ``add_thumbnail`` is set. When the aspect ratio
            is out of range and ``do_resize`` is False, the tuple
            ``(None, None)`` is returned instead of a list; this is kept
            unchanged for backward compatibility.
    """
    width, height = pil_image.size
    ratio = width / height
    if ratio > max_crop_slices or ratio < 1 / max_crop_slices:
        if not do_resize:
            print(
                f'Image aspect ratio ({ratio:.2f}) is out of range: ({1/max_crop_slices:.2f}, {max_crop_slices:.2f})'  # noqa
            )
            return None, None
        print(
            f'Resizing image to fit maximum number of slices ({max_crop_slices})'  # noqa
        )  # noqa
        # Stretch the shorter side's counterpart so the aspect ratio lands
        # exactly on the allowed limit.
        if width > height:
            new_width = max_crop_slices * height
            new_height = height
        else:
            new_width = width
            new_height = max_crop_slices * width
        # resample=2 is the integer code of Pillow's bilinear filter.
        pil_image = pil_image.resize((new_width, new_height), resample=2)
        width, height = pil_image.size
        ratio = width / height
    # Fractional number of image_size tiles along each axis.
    width_slices = width / image_size
    height_slices = height / image_size
    best_slices = find_best_slices(width_slices, height_slices, ratio,
                                   max_crop_slices)
    pil_image = resize_to_best_size(pil_image, best_slices, width_slices,
                                    height_slices, image_size)
    sub_images = sliding_window_crop(pil_image, image_size, best_slices)
    if add_thumbnail:
        thumbnail_image = pil_image.resize((image_size, image_size),
                                           resample=2)
        sub_images.append(thumbnail_image)
    # save split images to folder for debugging
    if save_folder is not None:
        save_image_list(sub_images, save_folder)
    return sub_images