|
import torch, random |
|
import torchvision.transforms.functional as F |
|
import torchvision.transforms as tf |
|
import numpy as np |
|
from PIL import Image |
|
from typing import Tuple, List, Callable |
|
|
|
|
|
class Compose:
    """Apply a sequence of paired transforms to an image and its label map(s).

    Each transform is called as ``transform(img, gt)`` when no pseudo-label is
    given, or as ``transform(img, gt, pseudo)`` otherwise, and must return a
    tuple of the same arity.

    When ``student_augs`` is True an extra photometrically augmented copy of
    the final image (``aimg``) is produced. It gets a random Gaussian blur,
    then a 50/50 choice between ``ColorJitter`` and ``MaskGrayscale`` (each
    applied with its own default probability). The copy is made with
    ``img.clone()``, so ``img`` must already be a ``torch.Tensor`` by the end
    of the pipeline (e.g. after ``ToTensor``). TODO confirm that every pipeline
    enabling ``student_augs`` ends with a tensor-producing step.
    """

    def __init__(self,
                 transforms: List[Callable],
                 student_augs: bool = False):
        self.transforms = transforms
        self.student_augs = student_augs

    def __call__(self,
                 img: Image.Image,
                 gt: Image.Image,
                 pseudo=None) -> Tuple[torch.Tensor, ...]:
        """Run every transform and collect the outputs.

        Returns:
            ``(img, gt)``, followed by ``aimg`` if ``student_augs`` is set,
            followed by ``pseudo`` if a pseudo-label was passed. That is, one
            of ``(img, gt)``, ``(img, gt, aimg)``, ``(img, gt, pseudo)`` or
            ``(img, gt, aimg, pseudo)``.
        """
        for transform in self.transforms:
            if pseudo is None:
                img, gt = transform(img, gt)
            else:
                img, gt, pseudo = transform(img, gt, pseudo)

        outputs = [img, gt]
        if self.student_augs:
            outputs.append(self._student_view(img, gt))
        if pseudo is not None:
            outputs.append(pseudo)
        return tuple(outputs)

    @staticmethod
    def _student_view(img, gt):
        """Return a strongly augmented copy of ``img``; ``img`` itself is not modified."""
        aimg = img.clone()
        aimg, _ = RandGaussianBlur()(aimg, gt)
        # Coin flip between colour jitter and (probabilistic) grayscale.
        if 0.5 > random.random():
            aimg, _ = ColorJitter()(aimg, gt)
        else:
            aimg, _ = MaskGrayscale()(aimg, gt)
        return aimg
|
|
|
class ToTensor:
    """Convert an image / label pair (plus optional pseudo-label) to tensors.

    The image goes through torchvision's ``F.to_tensor`` after conversion to a
    NumPy array. Label maps are wrapped with ``torch.from_numpy`` (dtype kept
    as-is) and given a leading singleton dimension via ``unsqueeze(0)``.
    """

    def __call__(self,
                 img: Image.Image,
                 gt: Image.Image,
                 pseudo=None) -> Tuple[torch.Tensor, torch.Tensor]:
        image_tensor = F.to_tensor(np.array(img))
        label_tensor = torch.from_numpy(np.array(gt)).unsqueeze(0)

        if pseudo is None:
            return image_tensor, label_tensor

        pseudo_tensor = torch.from_numpy(np.array(pseudo)).unsqueeze(0)
        return image_tensor, label_tensor, pseudo_tensor
|
|
|
class Resize:
    """Resize an image and its label map(s) to a fixed ``resize`` size.

    The image uses bilinear interpolation. Label maps use nearest-neighbour
    interpolation so that discrete class ids are never blended.
    """

    def __init__(self,
                 resize: Tuple[int]):
        bilinear = tf.InterpolationMode.BILINEAR
        nearest = tf.InterpolationMode.NEAREST
        self.img_resize = tf.Resize(size=resize, interpolation=bilinear)
        self.gt_resize = tf.Resize(size=resize, interpolation=nearest)

    def __call__(self,
                 img: Image.Image,
                 gt: Image.Image,
                 pseudo=None) -> Tuple[Image.Image, Image.Image]:
        resized = (self.img_resize(img), self.gt_resize(gt))
        if pseudo is not None:
            # The pseudo-label is a label map too, so it shares the nearest resizer.
            resized += (self.gt_resize(pseudo),)
        return resized
|
|
|
class ImgResize:
    """Downscale an image tensor whose pixel count exceeds a budget.

    ``resize`` is ``(height, width)``. If ``H * W`` of the input is larger than
    ``height * width``, the image is bilinearly resized to exactly ``resize``
    (the aspect ratio is not preserved). Otherwise it is returned unchanged.
    ``gt``, and ``pseudo`` when given, are passed through untouched.

    Accepts unbatched ``(C, H, W)`` or batched ``(N, C, H, W)`` input.
    """

    def __init__(self,
                 resize: Tuple[int, int]):
        self.resize = resize
        self.num_pixels = self.resize[0] * self.resize[1]

    def __call__(self,
                 img: torch.Tensor,
                 gt: torch.Tensor,
                 pseudo=None) -> Tuple[torch.Tensor, ...]:
        height, width = img.shape[-2:]
        # Plain int arithmetic; no need to allocate a tensor to compare sizes.
        if height * width > self.num_pixels:
            # interpolate needs a batch dim; add one only for unbatched input.
            batched = img.dim() == 4
            x = img if batched else img.unsqueeze(0)
            x = torch.nn.functional.interpolate(x, size=self.resize, mode='bilinear')
            img = x if batched else x.squeeze(0)

        # Accept the optional pseudo-label so this works inside Compose(..., pseudo).
        if pseudo is None:
            return img, gt
        return img, gt, pseudo
|
|
|
class ImgResizePIL:
    """Downscale a PIL image whose pixel count exceeds a budget.

    ``resize`` is ``(height, width)``. If ``img.height * img.width`` exceeds
    ``height * width``, the image is resized to exactly that size with bilinear
    resampling (the aspect ratio is not preserved). Otherwise it is returned
    unchanged.
    """

    def __init__(self,
                 resize: Tuple[int]):
        self.resize = resize
        self.num_pixels = self.resize[0] * self.resize[1]

    def __call__(self,
                 img: Image.Image) -> Image.Image:
        if img.height * img.width > self.num_pixels:
            # PIL takes (width, height) and a PIL resampling filter. Passing
            # torchvision's InterpolationMode enum here makes Image.resize
            # raise ValueError ("Unknown resampling filter").
            img = img.resize((self.resize[1], self.resize[0]), Image.BILINEAR)
        return img
|
|
|
class Normalize:
    """Channel-wise mean/std normalisation of the image tensor.

    Wraps ``tf.Normalize``. Only ``img`` is changed; ``gt`` and the optional
    ``pseudo`` label are returned untouched.
    """

    def __init__(self,
                 mean: List[float] = [0.485, 0.456, 0.406],
                 std: List[float] = [0.229, 0.224, 0.225]):
        self.norm = tf.Normalize(mean=mean, std=std)

    def __call__(self,
                 img: torch.Tensor,
                 gt: torch.Tensor,
                 pseudo=None) -> Tuple[torch.Tensor, torch.Tensor]:
        normalized = self.norm(img)
        if pseudo is not None:
            return normalized, gt, pseudo
        return normalized, gt
|
|
|
class UnNormalize(object):
    """Invert a per-channel mean/std normalisation: ``x * std + mean``.

    Works on an unbatched ``(C, H, W)`` tensor or a batched ``(N, C, H, W)``
    tensor; channels are read from dim ``-3`` for the batched case. The input
    is cloned first, so the caller's tensor is never modified.
    """

    def __init__(self,
                 mean: List[float] = [0.485, 0.456, 0.406],
                 std: List[float] = [0.229, 0.224, 0.225]):
        # Copy so later edits to self.mean / self.std can never leak into the
        # shared default lists (or into the caller's own lists).
        self.mean = list(mean)
        self.std = list(std)

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        restored = torch.clone(image)
        # For batched input, move channels to the front so each loop step sees
        # one (N, H, W) view of a single channel. Iterating dim 0 directly
        # would walk the batch and treat the batch index as the channel index.
        channels = restored.movedim(-3, 0) if restored.dim() > 3 else restored
        for channel, m, s in zip(channels, self.mean, self.std):
            # In-place on a view, so `restored` is updated.
            channel.mul_(s).add_(m)
        return restored
|
|
|
|
|
|
|
class RandomHFlip:
    """Horizontally flip an image and its label map(s) together.

    With probability ``percentage`` (one ``random.random()`` draw per call),
    ``img``, ``gt`` and, if given, ``pseudo`` are all flipped with ``F.hflip``.
    Otherwise everything is returned as-is.
    """

    def __init__(self,
                 percentage: float = 0.5):
        self.percentage = percentage

    def __call__(self,
                 img: Image.Image,
                 gt: Image.Image,
                 pseudo=None) -> Tuple[Image.Image, Image.Image]:
        should_flip = random.random() < self.percentage
        if should_flip:
            img, gt = F.hflip(img), F.hflip(gt)
            if pseudo is not None:
                pseudo = F.hflip(pseudo)

        return (img, gt) if pseudo is None else (img, gt, pseudo)
|
|
|
|
|
class RandomResizedCrop:
    """Random resized crop applied identically to an image and its label map(s).

    One set of crop parameters is sampled per call with torchvision's
    ``RandomResizedCrop.get_params`` and reused for every input, so the image
    and labels stay aligned. The image is resampled bilinearly; label maps use
    nearest-neighbour so class ids are not blended.
    """

    def __init__(self,
                 crop_size: List[int],
                 crop_scale: List[float],
                 crop_ratio: List[float]):
        print('RandomResizedCrop ratio modified!!!')
        self.crop_scale = tuple(crop_scale)
        self.crop_ratio = tuple(crop_ratio)
        self.crop = tf.RandomResizedCrop(
            size=tuple(crop_size),
            scale=self.crop_scale,
            ratio=self.crop_ratio,
        )

    def __call__(self,
                 img: Image.Image,
                 gt: Image.Image,
                 pseudo=None) -> Tuple[Image.Image, Image.Image]:
        # (top, left, height, width) of the sampled crop box.
        box = self.crop.get_params(img=img,
                                   scale=self.crop.scale,
                                   ratio=self.crop.ratio)
        out_size = self.crop.size
        bilinear = tf.InterpolationMode.BILINEAR
        nearest = tf.InterpolationMode.NEAREST

        outputs = [
            F.resized_crop(img, *box, out_size, bilinear),
            F.resized_crop(gt, *box, out_size, nearest),
        ]
        if pseudo is not None:
            outputs.append(F.resized_crop(pseudo, *box, out_size, nearest))
        return tuple(outputs)
|
|
|
class CenterCrop:
    """Center-crop an image and its label map(s) to ``crop_size`` with ``tf.CenterCrop``."""

    def __init__(self,
                 crop_size: int):
        self.crop = tf.CenterCrop(size=crop_size)

    def __call__(self,
                 img: Image.Image,
                 gt: Image.Image,
                 pseudo=None) -> Tuple[Image.Image, Image.Image]:
        cropped = [self.crop(img), self.crop(gt)]
        if pseudo is not None:
            cropped.append(self.crop(pseudo))
        return tuple(cropped)
|
|
|
class PyramidCenterCrop:
    """Build a multi-scale stack of center crops from one image/label pair.

    For every scale ``s``, the input is resized so that its height becomes
    ``int(crop_size / s)``, with the width following the input's aspect ratio.
    It is then center-cropped to ``crop_size``. The per-scale crops are
    combined with ``torch.stack``.

    The code indexes ``img.shape[1]`` / ``img.shape[2]`` as height / width and
    stacks the results, so ``img`` and ``gt`` must be tensors (presumably
    ``(C, H, W)``), not PIL images. ``crop_size`` is used in scalar arithmetic
    and must be a single int.
    """

    def __init__(self,
                 crop_size: int,
                 scales: List[float]):
        self.crop_size = crop_size
        self.scales = scales
        self.crop = tf.CenterCrop(size=crop_size)

    def __call__(self,
                 img: torch.Tensor,
                 gt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        aspect = img.shape[2] / img.shape[1]
        imgs = []
        gts = []
        for s in self.scales:
            new_size = (int(self.crop_size * 1 / s),
                        int(self.crop_size * 1 / s * aspect))
            # Always resize from the ORIGINAL inputs. Reassigning img/gt here
            # would make each later scale resample an already-resized image,
            # losing detail and drifting the aspect ratio through truncation.
            scaled_img = tf.Resize(size=new_size, interpolation=tf.InterpolationMode.BILINEAR)(img)
            scaled_gt = tf.Resize(size=new_size, interpolation=tf.InterpolationMode.NEAREST)(gt)
            imgs.append(self.crop(scaled_img))
            gts.append(self.crop(scaled_gt))

        return torch.stack(imgs), torch.stack(gts)
|
|
|
|
|
|
|
|
|
|
|
|
|
class IdsToTrainIds:
    """Map raw dataset label ids in ``gt`` to training ids.

    * ``'cityscapes'``: ids are cast to int64 and shifted down by
      ``first_nonvoid`` (7). Anything that lands outside ``[0, 26]`` becomes
      255 (presumably the loss's ignore index; confirm against the training
      config).
    * ``'cocostuff'`` / ``'potsdam'``: labels are only cast to ``torch.int64``.
    * any other ``source``: ``gt`` is returned unchanged (no cast).

    ``img`` and the optional ``pseudo`` label are passed through untouched.
    """

    # Value assigned to out-of-range (void) Cityscapes ids.
    _VOID_ID = 255
    # Largest valid Cityscapes train id after the shift.
    _CITYSCAPES_MAX_TRAIN_ID = 26

    def __init__(self,
                 source: str):
        self.source = source
        self.first_nonvoid = 7

    def __call__(self,
                 img: torch.Tensor,
                 gt: torch.Tensor,
                 pseudo=None) -> Tuple[torch.Tensor, ...]:
        if self.source == 'cityscapes':
            # The subtraction yields a fresh tensor, so the masked writes below
            # never touch the caller's gt.
            gt = gt.to(dtype=torch.int64) - self.first_nonvoid
            out_of_range = (gt < 0) | (gt > self._CITYSCAPES_MAX_TRAIN_ID)
            gt[out_of_range] = self._VOID_ID
        elif self.source in ('cocostuff', 'potsdam'):
            gt = gt.to(dtype=torch.int64)

        # Accept the optional pseudo-label so this works inside Compose(..., pseudo).
        # It is passed through unchanged; NOTE(review): confirm pseudo-labels
        # are already in train-id space.
        if pseudo is None:
            return img, gt
        return img, gt, pseudo
|
|
|
|
|
class ColorJitter:
    """Randomly colour-jitter the image; label maps are never modified.

    With probability ``percentage`` (one ``random.random()`` draw per call),
    the image is passed through a ``tf.ColorJitter`` built from the given
    brightness / contrast / saturation / hue settings. The jitter factors
    themselves are sampled inside torchvision.
    """

    def __init__(self, percentage: float = 0.3, brightness: float = 0.1,
                 contrast: float = 0.1, saturation: float = 0.1, hue: float = 0.1):
        self.percentage = percentage
        self.jitter = tf.ColorJitter(
            brightness=brightness,
            contrast=contrast,
            saturation=saturation,
            hue=hue,
        )

    def __call__(self,
                 img: Image.Image,
                 gt: Image.Image,
                 pseudo=None) -> Tuple[Image.Image, Image.Image]:
        apply_jitter = random.random() < self.percentage
        result = self.jitter(img) if apply_jitter else img
        return (result, gt) if pseudo is None else (result, gt, pseudo)
|
|
|
class MaskGrayscale:
    """Randomly convert the image to 3-channel grayscale; labels untouched.

    With probability ``percentage`` (one ``random.random()`` draw per call),
    the image goes through ``tf.Grayscale(num_output_channels=3)``, so the
    channel count is preserved.
    """

    def __init__(self, percentage: float = 0.1):
        self.percentage = percentage

    def __call__(self,
                 img: Image.Image,
                 gt: Image.Image,
                 pseudo=None) -> Tuple[Image.Image, Image.Image]:
        to_gray = random.random() < self.percentage
        out = tf.Grayscale(num_output_channels=3)(img) if to_gray else img
        if pseudo is not None:
            return out, gt, pseudo
        return out, gt
|
|
|
class RandGaussianBlur:
    """Gaussian-blur the image with a randomly drawn sigma; labels untouched.

    Every call draws ``sigma`` uniformly from ``[radius[0], radius[1]]`` and
    applies ``tf.GaussianBlur`` with a fixed ``kernel_size=21``.
    """

    def __init__(self, radius: List[float] = [.1, 2.]):
        # Despite the name, these bounds are used as the blur's sigma range.
        self.radius = radius

    def __call__(self,
                 img: Image.Image,
                 gt: Image.Image,
                 pseudo=None) -> Tuple[Image.Image, Image.Image]:
        low, high = self.radius[0], self.radius[1]
        sigma = random.uniform(low, high)
        blurred = tf.GaussianBlur(kernel_size=21, sigma=sigma)(img)

        result = (blurred, gt)
        return result if pseudo is None else result + (pseudo,)
|
|