|
from collections import namedtuple
from typing import Any, Callable, List, Optional, Tuple

import numpy as np
import torchvision
from PIL import Image
|
|
|
def get_cs_labeldata() -> Tuple[List[str], np.ndarray]:
    """Return the Cityscapes class names together with an RGB colour table.

    Returns:
        tuple: ``(cls_names, colormap)``.

        * ``cls_names`` is a list of 27 class-name strings.
        * ``colormap`` is an integer ``numpy`` array of shape ``(29, 3)``,
          one ``[R, G, B]`` triple per row.

        The first 27 rows line up index-for-index with ``cls_names``.  The
        last two rows have no matching name.
        NOTE(review): the two trailing colours are presumably reserved for
        void / ignored pixels -- confirm against the code that consumes the
        colormap.
    """
    cls_names = [
        'road', 'sidewalk', 'parking', 'rail track', 'building',
        'wall', 'fence', 'guard rail', 'bridge', 'tunnel',
        'pole', 'polegroup', 'traffic light', 'traffic sign', 'vegetation',
        'terrain', 'sky', 'person', 'rider', 'car',
        'truck', 'bus', 'caravan', 'trailer', 'train',
        'motorcycle', 'bicycle',
    ]
    colormap = np.array([
        [128, 64, 128],   # road
        [244, 35, 232],   # sidewalk
        [250, 170, 160],  # parking
        [230, 150, 140],  # rail track
        [70, 70, 70],     # building
        [102, 102, 156],  # wall
        [190, 153, 153],  # fence
        [180, 165, 180],  # guard rail
        [150, 100, 100],  # bridge
        [150, 120, 90],   # tunnel
        [153, 153, 153],  # pole
        [153, 153, 153],  # polegroup (same colour as pole)
        [250, 170, 30],   # traffic light
        [220, 220, 0],    # traffic sign
        [107, 142, 35],   # vegetation
        [152, 251, 152],  # terrain
        [70, 130, 180],   # sky
        [220, 20, 60],    # person
        [255, 0, 0],      # rider
        [0, 0, 142],      # car
        [0, 0, 70],       # truck
        [0, 60, 100],     # bus
        [0, 0, 90],       # caravan
        [0, 0, 110],      # trailer
        [0, 80, 100],     # train
        [0, 0, 230],      # motorcycle
        [119, 11, 32],    # bicycle
        [0, 0, 0],        # no entry in cls_names (see docstring)
        [220, 220, 220],  # no entry in cls_names (see docstring)
    ])
    return cls_names, colormap
|
|
|
class CityscapesDataset(torchvision.datasets.Cityscapes):
    """Cityscapes dataset fixed to semantic targets that also yields the image path.

    Compared with the base class, this wrapper:

    * always passes ``target_type="semantic"`` to the parent constructor,
    * applies one joint ``transforms(image, target)`` callable, and
    * returns the image path as a third element of every sample.
    """

    def __init__(self,
                 transforms: Optional[Callable],
                 *args: Any,
                 **kwargs: Any):
        """Build the dataset.

        Args:
            transforms: Callable invoked as ``transforms(image, target)``; it
                must return the transformed ``(image, target)`` pair.  Pass
                ``None`` to return samples untransformed.
            *args: Positional arguments forwarded to the parent constructor.
            **kwargs: Keyword arguments forwarded to the parent constructor.
                ``target_type`` must not be supplied here: it is always
                passed as ``"semantic"``, so a second value raises
                ``TypeError``.
        """
        super().__init__(*args, **kwargs, target_type="semantic")
        # Set after the parent constructor has run, so this assignment is the
        # one that sticks.
        self.transforms = transforms
        # Class names as plain strings.
        # NOTE(review): the torchvision base class also exposes a ``classes``
        # attribute; this instance attribute hides it -- confirm intended.
        self.classes = [
            'road', 'sidewalk', 'parking', 'rail track', 'building',
            'wall', 'fence', 'guard rail', 'bridge', 'tunnel',
            'pole', 'polegroup', 'traffic light', 'traffic sign', 'vegetation',
            'terrain', 'sky', 'person', 'rider', 'car',
            'truck', 'bus', 'caravan', 'trailer', 'train',
            'motorcycle', 'bicycle',
        ]

    def __getitem__(self, index: int) -> Tuple[Any, Any, Any]:
        """Load one sample.

        Args:
            index (int): Index of the sample.

        Returns:
            tuple: ``(image, target, img_pth)``.

            * ``image`` is the image at ``self.images[index]`` converted to
              RGB (after ``self.transforms`` when that is not ``None``).
            * ``target`` is a tuple with one entry per item of
              ``self.target_type`` when there is more than one, otherwise the
              single entry itself.  A ``'polygon'`` entry is loaded with
              ``self._load_json``; any other entry is opened as a PIL image.
            * ``img_pth`` is ``self.images[index]``, unchanged.
        """
        img_pth = self.images[index]
        image = Image.open(img_pth).convert('RGB')

        targets: Any = []
        for i, t in enumerate(self.target_type):
            if t == 'polygon':
                target = self._load_json(self.targets[index][i])
            else:
                target = Image.open(self.targets[index][i])
            targets.append(target)

        # Collapse a one-element list to the bare target.
        target = tuple(targets) if len(targets) > 1 else targets[0]

        if self.transforms is not None:
            # Joint transform: image and target go through the same call.
            image, target = self.transforms(image, target)

        return image, target, img_pth
|
|
|
def cityscapes(root: str,
               split: str,
               transforms: Optional[Callable]) -> "CityscapesDataset":
    """Create a ``CityscapesDataset`` for one split.

    Args:
        root: Passed to ``CityscapesDataset`` as ``root``.
        split: Passed to ``CityscapesDataset`` as ``split``.
        transforms: Joint callable applied as ``transforms(image, target)``
            to every sample, or ``None`` for no transformation.

    Returns:
        CityscapesDataset: The constructed dataset.
    """
    return CityscapesDataset(root=root, split=split, transforms=transforms)
|
|
|
# Record type describing a single Cityscapes label.
# NOTE(review): nothing in the visible code constructs or reads this type;
# it is presumably kept for importers -- confirm before removing.
CityscapesClass = namedtuple(
    'CityscapesClass',
    [
        'name',
        'id',
        'train_id',
        'category',
        'category_id',
        'has_instances',
        'ignore_in_eval',
        'color',
    ],
)
|
|
|
# Module-level list of the 27 Cityscapes class names, in the same order as
# the lists built in get_cs_labeldata() and CityscapesDataset.__init__.
classes = [
    'road', 'sidewalk', 'parking', 'rail track', 'building',
    'wall', 'fence', 'guard rail', 'bridge', 'tunnel',
    'pole', 'polegroup', 'traffic light', 'traffic sign', 'vegetation',
    'terrain', 'sky', 'person', 'rider', 'car',
    'truck', 'bus', 'caravan', 'trailer', 'train',
    'motorcycle', 'bicycle',
]
|
|