File size: 3,927 Bytes
03e384b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torchvision
import numpy as np
from PIL import Image
from typing import List, Any, Callable, Tuple
from collections import namedtuple

def get_cs_labeldata():
    cls_names = ['road', 'sidewalk', 'parking', 'rail track', 'building',
           'wall', 'fence', 'guard rail', 'bridge', 'tunnel',
           'pole', 'polegroup', 'traffic light', 'traffic sign', 'vegetation',
           'terrain', 'sky', 'person', 'rider', 'car',
           'truck', 'bus', 'caravan', 'trailer', 'train',
           'motorcycle', 'bicycle']
    colormap = np.array([
            [128, 64, 128],
            [244, 35, 232],
            [250, 170, 160],
            [230, 150, 140],
            [70, 70, 70],
            [102, 102, 156],
            [190, 153, 153],
            [180, 165, 180],
            [150, 100, 100],
            [150, 120, 90],
            [153, 153, 153],
            [153, 153, 153],
            [250, 170, 30],
            [220, 220, 0],
            [107, 142, 35],
            [152, 251, 152],
            [70, 130, 180],
            [220, 20, 60],
            [255, 0, 0],
            [0, 0, 142],
            [0, 0, 70],
            [0, 60, 100],
            [0, 0, 90],
            [0, 0, 110],
            [0, 80, 100],
            [0, 0, 230],
            [119, 11, 32],
            [0, 0, 0],
            [220, 220, 220]])
    return cls_names, colormap

class CityscapesDataset(torchvision.datasets.Cityscapes):
    """Cityscapes dataset that requests semantic targets and also yields the image path.

    The parent constructor is always called with ``target_type="semantic"``,
    and every sample is returned as ``(image, target, img_pth)``.
    """

    def __init__(self,
                 transforms: Callable,
                 *args: Any,
                 **kwargs: Any):
        """
        Args:
            transforms: A single callable invoked as ``transforms(image, target)``
                that returns the transformed ``(image, target)`` pair. ``None``
                disables transformation.
            *args: Positional arguments forwarded to ``torchvision.datasets.Cityscapes``.
            **kwargs: Keyword arguments forwarded to ``torchvision.datasets.Cityscapes``.
                ``target_type`` must not be supplied here; it is fixed to
                ``"semantic"`` and passing it again raises ``TypeError``.
        """
        super(CityscapesDataset, self).__init__(*args,
                                                **kwargs,
                                                target_type="semantic")
        # Assigned after the parent constructor runs, so this is the value
        # __getitem__ sees regardless of what the parent stored.
        self.transforms = transforms
        self.classes = ['road', 'sidewalk', 'parking', 'rail track', 'building',
           'wall', 'fence', 'guard rail', 'bridge', 'tunnel',
           'pole', 'polegroup', 'traffic light', 'traffic sign', 'vegetation',
           'terrain', 'sky', 'person', 'rider', 'car',
           'truck', 'bus', 'caravan', 'trailer', 'train',
           'motorcycle', 'bicycle']

    def __getitem__(self, index: int) -> Tuple[Any, Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target, img_pth). ``image`` is the RGB image,
            ``target`` is the loaded target (a tuple of targets if more than
            one target type is configured; a json object for a "polygon"
            target, otherwise the opened segmentation image), and ``img_pth``
            is the entry of ``self.images`` the image was opened from.
        """
        img_pth = self.images[index]
        image = Image.open(img_pth).convert('RGB')

        targets: Any = []
        for i, t in enumerate(self.target_type):
            if t == 'polygon':
                target = self._load_json(self.targets[index][i])
            else:
                target = Image.open(self.targets[index][i])

            targets.append(target)

        # A single target is unwrapped; several are returned together as a tuple.
        target = tuple(targets) if len(targets) > 1 else targets[0]

        if self.transforms is not None:
            # The transform is applied jointly to the image and its target.
            image, target = self.transforms(image, target)

        return image, target, img_pth

def cityscapes(root: str,
               split: str,
               transforms: Callable):
    """Build a ``CityscapesDataset`` for the given root directory and split.

    Args:
        root: Dataset root, forwarded as the ``root`` keyword argument.
        split: Dataset split name, forwarded as the ``split`` keyword argument.
        transforms: A single callable that the dataset invokes as
            ``transforms(image, target)``; forwarded unchanged.

    Returns:
        The constructed ``CityscapesDataset`` instance.
    """
    return CityscapesDataset(root=root,
                             split=split,
                             transforms=transforms)

# Record type describing one Cityscapes label entry.
CityscapesClass = namedtuple(
    'CityscapesClass',
    [
        'name',
        'id',
        'train_id',
        'category',
        'category_id',
        'has_instances',
        'ignore_in_eval',
        'color',
    ],
)

# Module-level list of the 27 label names, in the same order as the list
# stored on CityscapesDataset.classes.
classes = [
    'road', 'sidewalk', 'parking', 'rail track', 'building',
    'wall', 'fence', 'guard rail', 'bridge', 'tunnel',
    'pole', 'polegroup', 'traffic light', 'traffic sign', 'vegetation',
    'terrain', 'sky', 'person', 'rider', 'car',
    'truck', 'bus', 'caravan', 'trailer', 'train',
    'motorcycle', 'bicycle',
]