PriMaPs / modules /transforms.py
Oliver Hahn
add demo
1dc26c7
raw
history blame contribute delete
9.99 kB
import torch, random
import torchvision.transforms.functional as F
import torchvision.transforms as tf
import numpy as np
from PIL import Image
from typing import Tuple, List, Callable
class Compose:
def __init__(self,
transforms: List[Callable],
student_augs: bool = False):
self.transforms = transforms
self.student_augs = student_augs
def __call__(self,
img: Image.Image,
gt: Image.Image,
pseudo = None) -> Tuple[torch.Tensor, torch.Tensor]:
for transform in self.transforms:
if pseudo is None:
img, gt = transform(img, gt)
else:
img, gt, pseudo = transform(img, gt, pseudo)
if self.student_augs:
aimg = img.clone()
aimg, _ = RandGaussianBlur()(aimg, gt)
if 0.5 > random.random():
aimg, _ = ColorJitter()(aimg, gt)
else:
aimg, _ = MaskGrayscale()(aimg, gt)
if pseudo is None and not self.student_augs:
return img, gt
elif pseudo is None and self.student_augs:
return img, gt, aimg
elif pseudo is not None and not self.student_augs:
return img, gt, pseudo
else:
return img, gt, aimg, pseudo
class ToTensor:
def __call__(self,
img: Image.Image,
gt: Image.Image,
pseudo = None) -> Tuple[torch.Tensor, torch.Tensor]:
img = F.to_tensor(np.array(img))
gt = torch.from_numpy(np.array(gt)).unsqueeze(0)
if pseudo is not None:
pseudo = torch.from_numpy(np.array(pseudo)).unsqueeze(0)
if pseudo is None:
return img, gt
else:
return img, gt, pseudo
class Resize:
def __init__(self,
resize: Tuple[int]):
self.img_resize = tf.Resize(size=resize,
interpolation=tf.InterpolationMode.BILINEAR)
self.gt_resize = tf.Resize(size=resize,
interpolation=tf.InterpolationMode.NEAREST)
def __call__(self,
img: Image.Image,
gt: Image.Image,
pseudo = None) -> Tuple[Image.Image, Image.Image]:
img = self.img_resize(img)
gt = self.gt_resize(gt)
if pseudo is None:
return img, gt
else:
return img, gt, self.gt_resize(pseudo)
class ImgResize:
def __init__(self,
resize: Tuple[int, int]):
self.resize = resize
self.num_pixels = self.resize[0]*self.resize[1]
def __call__(self,
img: torch.Tensor,
gt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if torch.prod(torch.tensor(img.shape[-2:])) > self.num_pixels:
img = torch.nn.functional.interpolate(img.unsqueeze(0), size=self.resize, mode='bilinear').squeeze(0)
return img, gt
class ImgResizePIL:
def __init__(self,
resize: Tuple[int]):
self.resize = resize
self.num_pixels = self.resize[0]*self.resize[1]
def __call__(self,
img: Image) -> Image:
if img.height*img.width > self.num_pixels:
img = img.resize((self.resize[1], self.resize[0]), tf.InterpolationMode.BILINEAR)
return img
class Normalize:
def __init__(self,
mean: List[float] = [0.485, 0.456, 0.406],
std: List[float] = [0.229, 0.224, 0.225]):
self.norm = tf.Normalize(mean=mean,
std=std)
def __call__(self,
img: torch.Tensor,
gt: torch.Tensor,
pseudo = None) -> Tuple[torch.Tensor, torch.Tensor]:
img = self.norm(img)
if pseudo is None:
return img, gt
else:
return img, gt, pseudo
class UnNormalize(object):
def __init__(self,
mean: List[float] = [0.485, 0.456, 0.406],
std: List[float] = [0.229, 0.224, 0.225]):
self.mean = mean
self.std = std
def __call__(self, image):
image2 = torch.clone(image)
for t, m, s in zip(image2, self.mean, self.std):
t.mul_(s).add_(m)
return image2
class RandomHFlip:
def __init__(self,
percentage: float = 0.5):
self.percentage = percentage
def __call__(self,
img: Image.Image,
gt: Image.Image,
pseudo = None) -> Tuple[Image.Image, Image.Image]:
if random.random() < self.percentage:
img = F.hflip(img)
gt = F.hflip(gt)
if pseudo is not None:
pseudo = F.hflip(pseudo)
if pseudo is None:
return img, gt
else:
return img, gt, pseudo
class RandomResizedCrop:
def __init__(self,
crop_size: List[int],
crop_scale: List[float],
crop_ratio: List[float]):
print('RandomResizedCrop ratio modified!!!')
self.crop_scale = tuple(crop_scale)
self.crop_ratio = tuple(crop_ratio)
self.crop = tf.RandomResizedCrop(size=tuple(crop_size),
scale=self.crop_scale,
ratio=self.crop_ratio,)
def __call__(self,
img: Image.Image,
gt: Image.Image,
pseudo = None) -> Tuple[Image.Image, Image.Image]:
i, j, h, w = self.crop.get_params(img=img,
scale=self.crop.scale,
ratio=self.crop.ratio)
img = F.resized_crop(img, i, j, h, w, self.crop.size, tf.InterpolationMode.BILINEAR)
gt = F.resized_crop(gt, i, j, h, w, self.crop.size, tf.InterpolationMode.NEAREST)
if pseudo is not None:
pseudo = F.resized_crop(pseudo, i, j, h, w, self.crop.size, tf.InterpolationMode.NEAREST)
if pseudo is None:
return img, gt
else:
return img, gt, pseudo
class CenterCrop:
def __init__(self,
crop_size: int):
self.crop = tf.CenterCrop(size=crop_size)
def __call__(self,
img: Image.Image,
gt: Image.Image,
pseudo = None) -> Tuple[Image.Image, Image.Image]:
img = self.crop(img)
gt = self.crop(gt)
if pseudo is None:
return img, gt
else:
return img, gt, self.crop(pseudo)
class PyramidCenterCrop:
def __init__(self,
crop_size: List[int],
scales: List[float]):
self.crop_size = crop_size
self.scales = scales
self.crop = tf.CenterCrop(size=crop_size)
def __call__(self,
img: Image.Image,
gt: Image.Image) -> Tuple[Image.Image, Image.Image]:
imgs = []
gts = []
for s in self.scales:
new_size = (int(self.crop_size*1/s), int(self.crop_size*1/s*(img.shape[2]/img.shape[1])))
img = tf.Resize(size=new_size, interpolation=tf.InterpolationMode.BILINEAR)(img)
gt = tf.Resize(size=new_size, interpolation=tf.InterpolationMode.NEAREST)(gt)
imgs.append(self.crop(img))
gts.append(self.crop(gt))
return torch.stack(imgs), torch.stack(gts)
class IdsToTrainIds:
def __init__(self,
source: str):
self.source = source
self.first_nonvoid = 7
def __call__(self,
img: torch.Tensor,
gt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.source == 'cityscapes':
gt = gt.to(dtype=torch.int64) - self.first_nonvoid
gt[gt>26] = 255
gt[gt<0] = 255
elif self.source == 'cocostuff':
gt = gt.to(dtype=torch.int64)
elif self.source == 'potsdam':
gt = gt.to(dtype=torch.int64)
return img, gt
class ColorJitter:
def __init__(self, percentage: float = 0.3, brightness: float = 0.1,
contrast: float = 0.1, saturation: float = 0.1, hue: float = 0.1):
self.percentage = percentage
self.jitter = tf.ColorJitter(brightness=brightness,
contrast=contrast,
saturation=saturation,
hue=hue)
def __call__(self,
img: Image.Image,
gt: Image.Image,
pseudo = None) -> Tuple[Image.Image, Image.Image]:
if random.random() < self.percentage:
img = self.jitter(img)
if pseudo is None:
return img, gt
else:
return img, gt, pseudo
class MaskGrayscale:
def __init__(self, percentage: float = 0.1):
self.percentage = percentage
def __call__(self,
img: Image.Image,
gt: Image.Image,
pseudo = None) -> Tuple[Image.Image, Image.Image]:
if self.percentage > random.random():
img = tf.Grayscale(num_output_channels=3)(img)
if pseudo is None:
return img, gt
else:
return img, gt, pseudo
class RandGaussianBlur:
def __init__(self, radius: List[float] = [.1, 2.]):
self.radius = radius
def __call__(self,
img: Image.Image,
gt: Image.Image,
pseudo = None) -> Tuple[Image.Image, Image.Image]:
radius = random.uniform(self.radius[0], self.radius[1])
img = tf.GaussianBlur(kernel_size=21, sigma=radius)(img)
if pseudo is None:
return img, gt
else:
return img, gt, pseudo