Upload processor
Browse files- image_processing_basnet.py +279 -0
- preprocessor_config.json +8 -0
    	
        image_processing_basnet.py
    ADDED
    
    | @@ -0,0 +1,279 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Dict, Tuple, Union
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import cv2
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            from PIL import Image
         | 
| 7 | 
            +
            from PIL.Image import Image as PilImage
         | 
| 8 | 
            +
            from torchvision import transforms
         | 
| 9 | 
            +
            from transformers.image_processing_base import BatchFeature
         | 
| 10 | 
            +
            from transformers.image_processing_utils import BaseImageProcessor
         | 
| 11 | 
            +
            from transformers.image_utils import ImageInput
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            class RescaleT(object):
                """Resize an image/label sample to a fixed size and scale the image by 1/255.

                ``output_size`` is either an int, producing a square
                ``output_size`` x ``output_size`` result, or a ``(height, width)`` tuple.
                The aspect ratio of the input is not preserved in either case.
                """

                def __init__(self, output_size: Union[int, Tuple[int, int]]) -> None:
                    super().__init__()
                    # Raise explicitly rather than assert so the check survives `python -O`.
                    if not isinstance(output_size, (int, tuple)):
                        raise TypeError(
                            f"output_size must be an int or a tuple, got {type(output_size)}"
                        )
                    self.output_size = output_size

                def __call__(self, sample):
                    """Return a new sample dict with the resized image and label.

                    Args:
                        sample: dict with an ``"image"`` array and a ``"label"`` array.

                    Returns:
                        dict with ``"image"`` (resized with INTER_AREA, divided by 255.0)
                        and ``"label"`` (resized with INTER_NEAREST, with a trailing
                        channel axis added).
                    """
                    image, label = sample["image"], sample["label"]

                    if isinstance(self.output_size, int):
                        new_h = new_w = self.output_size
                    else:
                        new_h, new_w = self.output_size

                    # cv2.resize expects the target size as (width, height).
                    dsize = (int(new_w), int(new_h))

                    img = cv2.resize(image, dsize, interpolation=cv2.INTER_AREA) / 255.0

                    # Nearest-neighbour keeps label values discrete (no blended classes).
                    lbl = cv2.resize(label, dsize, interpolation=cv2.INTER_NEAREST)
                    lbl = np.expand_dims(lbl, axis=-1)
                    lbl = np.clip(lbl, np.min(label), np.max(label))

                    return {"image": img, "label": lbl}
         | 
| 59 | 
            +
             | 
| 60 | 
            +
             | 
| 61 | 
            +
            class ToTensorLab(object):
                """Convert the ndarrays in a sample to normalized, channel-first Tensors.

                ``flag`` selects how the image is represented:

                * ``0`` (default): RGB, divided by its maximum and then normalized with
                  fixed per-channel mean/std constants.
                * ``1``: Lab, each channel min-max scaled and then standardized.
                * ``2``: RGB and Lab stacked (6 channels), each channel min-max scaled
                  and then standardized.
                """

                # Per-channel constants used by the RGB (flag == 0) path.
                _RGB_MEAN = (0.485, 0.456, 0.406)
                _RGB_STD = (0.229, 0.224, 0.225)

                def __init__(self, flag=0):
                    self.flag = flag

                @staticmethod
                def _to_three_channels(image):
                    """Replicate a single-channel image to three channels; pass others through."""
                    if image.shape[2] == 1:
                        return np.repeat(image, 3, axis=2)
                    return image

                @staticmethod
                def _rgb_to_lab(rgb):
                    """Convert the first three (RGB) channels of ``rgb`` to Lab with OpenCV."""
                    rgb = rgb[:, :, :3]
                    # cvtColor accepts 8-bit/16-bit unsigned or float32 input, not float64.
                    # NOTE(review): float input is assumed to be in [0, 1] (as produced by
                    # dividing by 255 upstream) -- confirm for other callers.
                    if rgb.dtype not in (np.uint8, np.uint16, np.float32):
                        rgb = rgb.astype(np.float32)
                    return cv2.cvtColor(np.ascontiguousarray(rgb), cv2.COLOR_RGB2LAB)

                @staticmethod
                def _standardize_channels(channels):
                    """Min-max scale every channel to [0, 1], then shift to zero mean / unit std.

                    Constant channels (zero range or zero std) come out as zeros instead of NaN.
                    """
                    channels = channels.astype(np.float64)
                    low = channels.min(axis=(0, 1), keepdims=True)
                    span = channels.max(axis=(0, 1), keepdims=True) - low
                    scaled = (channels - low) / np.where(span > 0, span, 1.0)
                    mean = scaled.mean(axis=(0, 1), keepdims=True)
                    std = scaled.std(axis=(0, 1), keepdims=True)
                    return (scaled - mean) / np.where(std > 0, std, 1.0)

                def __call__(self, sample):
                    """Return ``{"image": Tensor(C, H, W), "label": Tensor(C, H, W)}``.

                    Args:
                        sample: dict with an ``"image"`` array of shape (H, W, C) or (H, W)
                            and a ``"label"`` array of shape (H, W, 1) or (H, W).
                    """
                    image, label = sample["image"], sample["label"]

                    # Accept channel-less (H, W) arrays by adding a trailing channel axis.
                    if image.ndim == 2:
                        image = image[:, :, np.newaxis]
                    if label.ndim == 2:
                        label = label[:, :, np.newaxis]

                    # Scale the label to [0, 1] unless it is (numerically) all zeros.
                    label_peak = np.max(label)
                    if label_peak >= 1e-6:
                        label = label / label_peak

                    if self.flag == 2:  # RGB and Lab stacked
                        rgb = self._to_three_channels(image)
                        lab = self._rgb_to_lab(rgb)
                        stacked = np.concatenate([rgb[:, :, :3], lab], axis=2)
                        tmpImg = self._standardize_channels(stacked)
                    elif self.flag == 1:  # Lab only
                        lab = self._rgb_to_lab(self._to_three_channels(image))
                        tmpImg = self._standardize_channels(lab)
                    else:  # RGB
                        # Guard the division: an all-zero image would otherwise become NaN.
                        image_peak = np.max(image)
                        if image_peak != 0:
                            image = image / image_peak

                        is_gray = image.shape[2] == 1
                        tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
                        for channel in range(3):
                            # A grayscale image reuses its only channel with the first
                            # mean/std pair for all three output channels.
                            source = 0 if is_gray else channel
                            stats = 0 if is_gray else channel
                            tmpImg[:, :, channel] = (
                                image[:, :, source] - self._RGB_MEAN[stats]
                            ) / self._RGB_STD[stats]

                    # (H, W, C) -> (C, H, W)
                    tmpImg = tmpImg.transpose((2, 0, 1))
                    tmpLbl = label.transpose((2, 0, 1))

                    return {"image": torch.from_numpy(tmpImg), "label": torch.from_numpy(tmpLbl)}
         | 
| 189 | 
            +
             | 
| 190 | 
            +
             | 
| 191 | 
            +
            def apply_transform(
                data: Dict[str, np.ndarray], rescale_size: int, to_tensor_lab_flag: int
            ) -> Dict[str, torch.Tensor]:
                """Run ``data`` through the resize step followed by the tensor-conversion step.

                Args:
                    data: sample dict handed to the first transform.
                    rescale_size: forwarded to ``RescaleT`` as ``output_size``.
                    to_tensor_lab_flag: forwarded to ``ToTensorLab`` as ``flag``.

                Returns:
                    Whatever the composed pipeline returns for ``data``.
                """
                resize_step = RescaleT(output_size=rescale_size)
                tensor_step = ToTensorLab(flag=to_tensor_lab_flag)
                pipeline = transforms.Compose([resize_step, tensor_step])
                result = pipeline(data)  # type: ignore
                return result
         | 
| 198 | 
            +
             | 
| 199 | 
            +
             | 
| 200 | 
            +
            class BASNetImageProcessor(BaseImageProcessor):
                """Image processor for BASNet.

                ``preprocess`` turns a single PIL image into a ``pixel_values`` batch of one;
                ``postprocess`` turns a model prediction back into a PIL image of a
                requested size.
                """

                model_input_names = ["pixel_values"]

                def __init__(
                    self, rescale_size: int = 256, to_tensor_lab_flag: int = 0, **kwargs
                ) -> None:
                    """Store the resize target and colour-space flag used by ``preprocess``."""
                    super().__init__(**kwargs)
                    self.rescale_size = rescale_size
                    self.to_tensor_lab_flag = to_tensor_lab_flag

                def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature:
                    """Convert one PIL image into a ``BatchFeature`` with ``pixel_values``.

                    Args:
                        images: a single PIL image (any mode; it is converted to RGB).

                    Returns:
                        ``BatchFeature`` whose ``pixel_values`` is a float tensor with a
                        leading batch dimension of 1.

                    Raises:
                        ValueError: if ``images`` is not a PIL image.
                    """
                    if not isinstance(images, PilImage):
                        raise ValueError(f"Expected PIL.Image, got {type(images)}")

                    # Normalise grayscale / RGBA / palette inputs to three channels instead
                    # of rejecting them.
                    image_pil = images if images.mode == "RGB" else images.convert("RGB")
                    image_npy = np.array(image_pil, dtype=np.uint8)
                    width, height = image_pil.size
                    # The transform pipeline expects a label; supply an empty placeholder.
                    label_npy = np.zeros((height, width), dtype=np.uint8)

                    output = apply_transform(
                        {"image": image_npy, "label": label_npy},
                        rescale_size=self.rescale_size,
                        to_tensor_lab_flag=self.to_tensor_lab_flag,
                    )
                    image = output["image"]

                    assert isinstance(image, torch.Tensor)  # narrows the type for checkers

                    return BatchFeature(
                        data={"pixel_values": image.float().unsqueeze(dim=0)}, tensor_type="pt"
                    )

                def postprocess(
                    self, prediction: torch.Tensor, width: int, height: int
                ) -> PilImage:
                    """Turn a model prediction into an RGB PIL image of ``width`` x ``height``.

                    The prediction is min-max normalized to [0, 1], scaled to [0, 255],
                    converted to an RGB image and resized with bilinear resampling.
                    """

                    def _norm_prediction(d: torch.Tensor) -> torch.Tensor:
                        ma, mi = torch.max(d), torch.min(d)

                        # division while avoiding zero division
                        dn = (d - mi) / ((ma - mi) + torch.finfo(torch.float32).eps)
                        return dn

                    # Detach and upcast so tensors that require grad or are half precision
                    # can be converted to a NumPy array that PIL accepts.
                    prediction = prediction.detach().float()

                    prediction = _norm_prediction(prediction)
                    prediction = prediction.squeeze()
                    # Add 0.5 after scaling to [0, 255] so truncation rounds to nearest.
                    prediction = prediction * 255 + 0.5
                    prediction = prediction.clamp(0, 255)

                    prediction_np = prediction.cpu().numpy()
                    image = Image.fromarray(prediction_np).convert("RGB")
                    image = image.resize((width, height), resample=Image.Resampling.BILINEAR)
                    return image
         | 
    	
        preprocessor_config.json
    ADDED
    
    | @@ -0,0 +1,8 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "auto_map": {
         | 
| 3 | 
            +
                "AutoImageProcessor": "image_processing_basnet.BASNetImageProcessor"
         | 
| 4 | 
            +
              },
         | 
| 5 | 
            +
              "image_processor_type": "BASNetImageProcessor",
         | 
| 6 | 
            +
              "rescale_size": 256,
         | 
| 7 | 
            +
              "to_tensor_lab_flag": 0
         | 
| 8 | 
            +
            }
         | 

