import torchvision import numpy as np from PIL import Image from typing import List, Any, Callable, Tuple from collections import namedtuple def get_cs_labeldata(): cls_names = ['road', 'sidewalk', 'parking', 'rail track', 'building', 'wall', 'fence', 'guard rail', 'bridge', 'tunnel', 'pole', 'polegroup', 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'caravan', 'trailer', 'train', 'motorcycle', 'bicycle'] colormap = np.array([ [128, 64, 128], [244, 35, 232], [250, 170, 160], [230, 150, 140], [70, 70, 70], [102, 102, 156], [190, 153, 153], [180, 165, 180], [150, 100, 100], [150, 120, 90], [153, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 0, 90], [0, 0, 110], [0, 80, 100], [0, 0, 230], [119, 11, 32], [0, 0, 0], [220, 220, 220]]) return cls_names, colormap class CityscapesDataset(torchvision.datasets.Cityscapes): def __init__(self, transforms: List[Callable], *args: Any, **kwargs: Any): super(CityscapesDataset, self).__init__(*args, **kwargs, target_type="semantic") self.transforms = transforms self.classes = ['road', 'sidewalk', 'parking', 'rail track', 'building', 'wall', 'fence', 'guard rail', 'bridge', 'tunnel', 'pole', 'polegroup', 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'caravan', 'trailer', 'train', 'motorcycle', 'bicycle'] def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Args: index (int): Index Returns: tuple: (image, target) where target is a tuple of all target types if target_type is a list with more than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation. """ img_pth = self.images[index] image = Image.open(self.images[index]).convert('RGB') targets: Any = [] for i, t in enumerate(self.target_type): if t == 'polygon': target = self._load_json(self.targets[index][i]) else: target = Image.open(self.targets[index][i]) targets.append(target) target = tuple(targets) if len(targets) > 1 else targets[0] if self.transforms is not None: image, target = self.transforms(image, target) return image, target, img_pth def cityscapes(root: str, split: str, transforms: List[Callable]): return CityscapesDataset(root=root, split=split, transforms=transforms) CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id', 'has_instances', 'ignore_in_eval', 'color']) classes = ['road', 'sidewalk', 'parking', 'rail track', 'building', 'wall', 'fence', 'guard rail', 'bridge', 'tunnel', 'pole', 'polegroup', 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'caravan', 'trailer', 'train', 'motorcycle', 'bicycle']