File size: 3,927 Bytes
03e384b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
from collections import namedtuple
from typing import Any, Callable, List, Optional, Tuple

import numpy as np
import torchvision
from PIL import Image
def get_cs_labeldata():
    """Return the Cityscapes label names together with an RGB palette.

    Returns:
        tuple: ``(cls_names, colormap)``. ``cls_names`` is a list of 27
        class-name strings. ``colormap`` is a NumPy integer array of shape
        ``(29, 3)`` holding one RGB triple per row; its first 27 rows appear
        to follow the same order as ``cls_names``.

    NOTE(review): the palette has two more rows than there are names. The
    trailing ``(0, 0, 0)`` and ``(220, 220, 220)`` entries have no matching
    name; presumably they color background/ignore indices. Confirm this
    before relying on ``len(cls_names) == len(colormap)``.
    """
    names = [
        'road', 'sidewalk', 'parking', 'rail track', 'building',
        'wall', 'fence', 'guard rail', 'bridge', 'tunnel',
        'pole', 'polegroup', 'traffic light', 'traffic sign', 'vegetation',
        'terrain', 'sky', 'person', 'rider', 'car',
        'truck', 'bus', 'caravan', 'trailer', 'train',
        'motorcycle', 'bicycle',
    ]
    palette = (
        (128, 64, 128), (244, 35, 232), (250, 170, 160), (230, 150, 140),
        (70, 70, 70), (102, 102, 156), (190, 153, 153), (180, 165, 180),
        (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 153),
        (250, 170, 30), (220, 220, 0), (107, 142, 35), (152, 251, 152),
        (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142),
        (0, 0, 70), (0, 60, 100), (0, 0, 90), (0, 0, 110),
        (0, 80, 100), (0, 0, 230), (119, 11, 32),
        # Extra rows with no corresponding entry in ``names``.
        (0, 0, 0), (220, 220, 220),
    )
    colormap = np.array(palette)
    return names, colormap
class CityscapesDataset(torchvision.datasets.Cityscapes):
    """Cityscapes semantic-segmentation dataset with a joint image/target transform.

    Wraps ``torchvision.datasets.Cityscapes`` with ``target_type`` fixed to
    ``"semantic"``. ``__getitem__`` additionally returns the image file path.

    Args:
        transforms: Callable invoked as ``transforms(image, target)`` that
            returns the transformed ``(image, target)`` pair, or ``None`` to
            return the loaded images untransformed.
        *args, **kwargs: Forwarded to ``torchvision.datasets.Cityscapes``
            (e.g. ``root``, ``split``). ``target_type`` must not be supplied;
            it is always passed as ``"semantic"``, so supplying it raises
            ``TypeError`` (duplicate keyword argument).
    """

    def __init__(self,
                 transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]],
                 *args: Any,
                 **kwargs: Any):
        super(CityscapesDataset, self).__init__(*args,
                                                **kwargs,
                                                target_type="semantic")
        # Assigned after the base initializer so this callable replaces
        # whatever the base class stored under ``transforms``.
        self.transforms = transforms
        # Plain class-name strings; this instance attribute takes precedence
        # over any class-level ``classes`` defined by the base class.
        self.classes = ['road', 'sidewalk', 'parking', 'rail track', 'building',
                        'wall', 'fence', 'guard rail', 'bridge', 'tunnel',
                        'pole', 'polegroup', 'traffic light', 'traffic sign', 'vegetation',
                        'terrain', 'sky', 'person', 'rider', 'car',
                        'truck', 'bus', 'caravan', 'trailer', 'train',
                        'motorcycle', 'bicycle']

    def __getitem__(self, index: int) -> Tuple[Any, Any, str]:
        """Load one sample.

        Args:
            index (int): Position in ``self.images``.

        Returns:
            tuple: ``(image, target, img_pth)``.
                - ``image`` is the RGB image.
                - ``target`` is a tuple of all targets if ``self.target_type``
                  has more than one entry. Otherwise it is a single target: a
                  JSON object for ``"polygon"``, else the segmentation image.
                - ``image`` and ``target`` are passed through
                  ``self.transforms`` when it is set.
                - ``img_pth`` is the path the image was loaded from.
        """
        img_pth = self.images[index]
        image = Image.open(img_pth).convert('RGB')
        targets: List[Any] = []
        for i, t in enumerate(self.target_type):
            if t == 'polygon':
                target = self._load_json(self.targets[index][i])
            else:
                target = Image.open(self.targets[index][i])
            targets.append(target)
        target = tuple(targets) if len(targets) > 1 else targets[0]
        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target, img_pth
def cityscapes(root: str,
               split: str,
               transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]],
               **kwargs: Any) -> "CityscapesDataset":
    """Build a ``CityscapesDataset`` for ``split`` under ``root``.

    Args:
        root: Cityscapes dataset root directory.
        split: Split name, forwarded to torchvision's Cityscapes.
        transforms: Joint transform called as ``transforms(image, target)``,
            or ``None``.
        **kwargs: Extra keyword arguments forwarded to ``CityscapesDataset``
            (and from there to torchvision's Cityscapes, e.g. ``mode``).
            ``target_type`` is not accepted; the dataset always uses
            ``"semantic"``.

    Returns:
        CityscapesDataset: The constructed dataset.
    """
    return CityscapesDataset(root=root,
                             split=split,
                             transforms=transforms,
                             **kwargs)
# Record type describing one Cityscapes label; one field per label attribute.
CityscapesClass = namedtuple(
    'CityscapesClass',
    [
        'name', 'id', 'train_id', 'category', 'category_id',
        'has_instances', 'ignore_in_eval', 'color',
    ],
)

# Module-level list of class names, in the same order as the names returned
# by get_cs_labeldata().
classes = [
    'road', 'sidewalk', 'parking', 'rail track', 'building',
    'wall', 'fence', 'guard rail', 'bridge', 'tunnel',
    'pole', 'polegroup', 'traffic light', 'traffic sign', 'vegetation',
    'terrain', 'sky', 'person', 'rider', 'car',
    'truck', 'bus', 'caravan', 'trailer', 'train',
    'motorcycle', 'bicycle',
]
|