# PriMaPs: datasets/cityscapes.py
# Commit 03e384b by Oliver Hahn: "add demo" (3.93 kB)
from collections import namedtuple
from typing import List, Any, Callable, Optional, Tuple

import numpy as np
import torchvision
from PIL import Image
def get_cs_labeldata():
    """Return the Cityscapes class names together with an RGB colormap.

    Returns:
        tuple: ``(cls_names, colormap)``.

        - ``cls_names`` is a freshly built list of 27 class-name strings.
        - ``colormap`` is a ``(29, 3)`` integer numpy array of RGB triples.
          Its first 27 rows line up by index with ``cls_names``. The last two
          rows have no class name in this table.
    """
    # Each name sits next to the colormap row that ends up at the same index.
    named_colors = [
        ("road", (128, 64, 128)),
        ("sidewalk", (244, 35, 232)),
        ("parking", (250, 170, 160)),
        ("rail track", (230, 150, 140)),
        ("building", (70, 70, 70)),
        ("wall", (102, 102, 156)),
        ("fence", (190, 153, 153)),
        ("guard rail", (180, 165, 180)),
        ("bridge", (150, 100, 100)),
        ("tunnel", (150, 120, 90)),
        ("pole", (153, 153, 153)),
        ("polegroup", (153, 153, 153)),
        ("traffic light", (250, 170, 30)),
        ("traffic sign", (220, 220, 0)),
        ("vegetation", (107, 142, 35)),
        ("terrain", (152, 251, 152)),
        ("sky", (70, 130, 180)),
        ("person", (220, 20, 60)),
        ("rider", (255, 0, 0)),
        ("car", (0, 0, 142)),
        ("truck", (0, 0, 70)),
        ("bus", (0, 60, 100)),
        ("caravan", (0, 0, 90)),
        ("trailer", (0, 0, 110)),
        ("train", (0, 80, 100)),
        ("motorcycle", (0, 0, 230)),
        ("bicycle", (119, 11, 32)),
    ]
    # Two trailing colormap rows with no class name.
    # NOTE(review): presumably background / ignore colors; confirm with callers.
    unnamed_colors = [(0, 0, 0), (220, 220, 220)]

    cls_names = [name for name, _ in named_colors]
    colormap = np.array([rgb for _, rgb in named_colors] + unnamed_colors)
    return cls_names, colormap
class CityscapesDataset(torchvision.datasets.Cityscapes):
    """``torchvision.datasets.Cityscapes`` variant with a joint transform and image paths.

    Differences from the base class:

    - targets default to ``target_type="semantic"``, unless the caller passes
      ``target_type`` explicitly;
    - ``self.transforms`` is called jointly on ``(image, target)``;
    - ``__getitem__`` also returns the path of the loaded image;
    - ``self.classes`` is a plain list of 27 class-name strings.
    """

    def __init__(self,
                 transforms: Optional[Callable],
                 *args: Any,
                 **kwargs: Any):
        """
        Args:
            transforms: Callable invoked as ``transforms(image, target)``. It
                must return the transformed ``(image, target)`` pair. Pass
                ``None`` to get the untransformed images.
            *args, **kwargs: Forwarded to ``torchvision.datasets.Cityscapes``
                (e.g. ``root``, ``split``). ``target_type`` defaults to
                ``"semantic"``.
        """
        # Default to semantic targets but let callers override. Previously an
        # explicit ``target_type`` keyword raised TypeError ("got multiple
        # values for keyword argument").
        kwargs.setdefault("target_type", "semantic")
        super().__init__(*args, **kwargs)
        # Assigned after super().__init__() so that this callable is the one
        # used by __getitem__, whatever the parent constructor stored.
        self.transforms = transforms
        self.classes = ['road', 'sidewalk', 'parking', 'rail track', 'building',
                        'wall', 'fence', 'guard rail', 'bridge', 'tunnel',
                        'pole', 'polegroup', 'traffic light', 'traffic sign', 'vegetation',
                        'terrain', 'sky', 'person', 'rider', 'car',
                        'truck', 'bus', 'caravan', 'trailer', 'train',
                        'motorcycle', 'bicycle']

    def __getitem__(self, index: int) -> Tuple[Any, Any, Any]:
        """Load one sample.

        Args:
            index (int): Index

        Returns:
            tuple: ``(image, target, img_pth)``.

            - ``image`` is the sample image converted to RGB.
            - ``target`` is a tuple of all target types if ``target_type`` is
              a list with more than one item. Otherwise it is a single target:
              a JSON object (via ``self._load_json``) for ``"polygon"``, else
              the segmentation image.
            - ``image`` and ``target`` are passed through ``self.transforms``
              when it is not ``None``.
            - ``img_pth`` is ``self.images[index]``, the path the image was
              loaded from.
        """
        img_pth = self.images[index]
        image = Image.open(img_pth).convert('RGB')

        targets: Any = []
        for i, t in enumerate(self.target_type):
            if t == 'polygon':
                target = self._load_json(self.targets[index][i])
            else:
                target = Image.open(self.targets[index][i])
            targets.append(target)
        target = tuple(targets) if len(targets) > 1 else targets[0]

        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target, img_pth
def cityscapes(root: str,
               split: str,
               transforms: Optional[Callable],
               **kwargs: Any):
    """Build a ``CityscapesDataset`` for the given root directory and split.

    Args:
        root: Dataset root directory, forwarded to ``CityscapesDataset``.
        split: Split name, forwarded to ``CityscapesDataset``.
        transforms: Joint transform called as ``transforms(image, target)``,
            or ``None`` for no transform.
        **kwargs: Additional keyword arguments forwarded unchanged to
            ``CityscapesDataset`` and from there to the torchvision base
            class (e.g. ``mode``).

    Returns:
        CityscapesDataset: The constructed dataset.
    """
    return CityscapesDataset(root=root,
                             split=split,
                             transforms=transforms,
                             **kwargs)
# Lightweight immutable record describing one Cityscapes label. Fields: the
# label name, two integer-style ids, its category and category id, two flags,
# and a color.
CityscapesClass = namedtuple(
    'CityscapesClass',
    'name id train_id category category_id has_instances ignore_in_eval color',
)
# Flat list of the 27 Cityscapes class names. Same names, in the same order,
# as those returned by get_cs_labeldata() and set on CityscapesDataset.classes.
classes = [
    "road",
    "sidewalk",
    "parking",
    "rail track",
    "building",
    "wall",
    "fence",
    "guard rail",
    "bridge",
    "tunnel",
    "pole",
    "polegroup",
    "traffic light",
    "traffic sign",
    "vegetation",
    "terrain",
    "sky",
    "person",
    "rider",
    "car",
    "truck",
    "bus",
    "caravan",
    "trailer",
    "train",
    "motorcycle",
    "bicycle",
]