Oliver Hahn
		
	commited on
		
		
					Commit 
							
							·
						
						04513cf
	
1
								Parent(s):
							
							3f97a79
								
add demo
Browse files- .DS_Store +0 -0
- README.md +1 -1
- app.py +2 -6
- assets/demo_examples/.DS_Store +0 -0
- assets/demo_examples/cityscapes_example.png +0 -3
- assets/demo_examples/potsdam_example.png +0 -3
- assets/{demo_examples/coco_example.jpg → example.jpg} +0 -0
- datasets/.DS_Store +0 -0
- datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- datasets/__pycache__/__init__.cpython-311.pyc +0 -0
- datasets/__pycache__/__init__.cpython-36.pyc +0 -0
- datasets/__pycache__/__init__.cpython-37.pyc +0 -0
- datasets/__pycache__/__init__.cpython-38.pyc +0 -0
- datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- datasets/__pycache__/cityscapes.cpython-310.pyc +0 -0
- datasets/__pycache__/cityscapes.cpython-311.pyc +0 -0
- datasets/__pycache__/cityscapes.cpython-36.pyc +0 -0
- datasets/__pycache__/cityscapes.cpython-37.pyc +0 -0
- datasets/__pycache__/cityscapes.cpython-38.pyc +0 -0
- datasets/__pycache__/cityscapes.cpython-39.pyc +0 -0
- datasets/__pycache__/cocostuff.cpython-310.pyc +0 -0
- datasets/__pycache__/cocostuff.cpython-311.pyc +0 -0
- datasets/__pycache__/cocostuff.cpython-36.pyc +0 -0
- datasets/__pycache__/cocostuff.cpython-37.pyc +0 -0
- datasets/__pycache__/cocostuff.cpython-38.pyc +0 -0
- datasets/__pycache__/cocostuff.cpython-39.pyc +0 -0
- datasets/__pycache__/potsdam.cpython-310.pyc +0 -0
- datasets/__pycache__/potsdam.cpython-311.pyc +0 -0
- datasets/__pycache__/potsdam.cpython-36.pyc +0 -0
- datasets/__pycache__/potsdam.cpython-37.pyc +0 -0
- datasets/__pycache__/potsdam.cpython-38.pyc +0 -0
- datasets/__pycache__/potsdam.cpython-39.pyc +0 -0
- datasets/__pycache__/precomputed.cpython-310.pyc +0 -0
- datasets/__pycache__/precomputed.cpython-311.pyc +0 -0
- datasets/__pycache__/precomputed.cpython-36.pyc +0 -0
- datasets/__pycache__/precomputed.cpython-37.pyc +0 -0
- datasets/__pycache__/precomputed.cpython-38.pyc +0 -0
- datasets/__pycache__/precomputed.cpython-39.pyc +0 -0
- datasets/cocostuff.py +0 -125
- datasets/potsdam.py +0 -121
- datasets/precomputed.py +0 -43
    	
        .DS_Store
    CHANGED
    
    | Binary files a/.DS_Store and b/.DS_Store differ | 
|  | 
    	
        README.md
    CHANGED
    
    | @@ -1,7 +1,7 @@ | |
| 1 | 
             
            ---
         | 
| 2 | 
             
            title: PriMaPs
         | 
| 3 | 
             
            emoji: 😻
         | 
| 4 | 
            -
            colorFrom:  | 
| 5 | 
             
            colorTo: gray
         | 
| 6 | 
             
            sdk: gradio
         | 
| 7 | 
             
            sdk_version: 5.6.0
         | 
|  | |
| 1 | 
             
            ---
         | 
| 2 | 
             
            title: PriMaPs
         | 
| 3 | 
             
            emoji: 😻
         | 
| 4 | 
            +
            colorFrom: black
         | 
| 5 | 
             
            colorTo: gray
         | 
| 6 | 
             
            sdk: gradio
         | 
| 7 | 
             
            sdk_version: 5.6.0
         | 
    	
        app.py
    CHANGED
    
    | @@ -50,9 +50,7 @@ def gradio_primaps(image_path, threshold, architecture): | |
| 50 | 
             
            if __name__ == '__main__':    
         | 
| 51 | 
             
                # Example image paths
         | 
| 52 | 
             
                example_images = [
         | 
| 53 | 
            -
                    "assets/demo_examples/ | 
| 54 | 
            -
                    "assets/demo_examples/coco_example.jpg",
         | 
| 55 | 
            -
                    "assets/demo_examples/potsdam_example.png"
         | 
| 56 | 
             
                ]
         | 
| 57 |  | 
| 58 | 
             
                # Gradio interface
         | 
| @@ -60,7 +58,7 @@ if __name__ == '__main__': | |
| 60 | 
             
                    fn=gradio_primaps,
         | 
| 61 | 
             
                    inputs=[
         | 
| 62 | 
             
                        gr.Image(type="filepath", label="Image"),
         | 
| 63 | 
            -
                        gr.Slider(0.0, 1.0, step=0.05, value=0. | 
| 64 | 
             
                        gr.Dropdown(choices=['dino_vits', 'dino_vitb', 'dinov2_vits', 'dinov2_vitb'], value='dino_vitb', label="SSL Features"),
         | 
| 65 | 
             
                    ],
         | 
| 66 | 
             
                    outputs=gr.Image(label="PriMaPs"),
         | 
| @@ -68,8 +66,6 @@ if __name__ == '__main__': | |
| 68 | 
             
                    description="Upload an image and adjust the threshold to visualize PriMaPs.",
         | 
| 69 | 
             
                    examples=[
         | 
| 70 | 
             
                        [example_images[0], 0.4, 'dino_vitb'],
         | 
| 71 | 
            -
                        [example_images[1], 0.4, 'dino_vitb'],
         | 
| 72 | 
            -
                        [example_images[2], 0.4, 'dino_vitb']
         | 
| 73 | 
             
                    ]
         | 
| 74 | 
             
                )
         | 
| 75 |  | 
|  | |
| 50 | 
             
            if __name__ == '__main__':    
         | 
| 51 | 
             
                # Example image paths
         | 
| 52 | 
             
                example_images = [
         | 
| 53 | 
            +
                    "assets/demo_examples/example.jpg",
         | 
|  | |
|  | |
| 54 | 
             
                ]
         | 
| 55 |  | 
| 56 | 
             
                # Gradio interface
         | 
|  | |
| 58 | 
             
                    fn=gradio_primaps,
         | 
| 59 | 
             
                    inputs=[
         | 
| 60 | 
             
                        gr.Image(type="filepath", label="Image"),
         | 
| 61 | 
            +
                        gr.Slider(0.0, 1.0, step=0.05, value=0.4, label="Threshold"),
         | 
| 62 | 
             
                        gr.Dropdown(choices=['dino_vits', 'dino_vitb', 'dinov2_vits', 'dinov2_vitb'], value='dino_vitb', label="SSL Features"),
         | 
| 63 | 
             
                    ],
         | 
| 64 | 
             
                    outputs=gr.Image(label="PriMaPs"),
         | 
|  | |
| 66 | 
             
                    description="Upload an image and adjust the threshold to visualize PriMaPs.",
         | 
| 67 | 
             
                    examples=[
         | 
| 68 | 
             
                        [example_images[0], 0.4, 'dino_vitb'],
         | 
|  | |
|  | |
| 69 | 
             
                    ]
         | 
| 70 | 
             
                )
         | 
| 71 |  | 
    	
        assets/demo_examples/.DS_Store
    DELETED
    
    | Binary file (6.15 kB) | 
|  | 
    	
        assets/demo_examples/cityscapes_example.png
    DELETED
    
    | Git LFS Details
 | 
    	
        assets/demo_examples/potsdam_example.png
    DELETED
    
    | Git LFS Details
 | 
    	
        assets/{demo_examples/coco_example.jpg → example.jpg}
    RENAMED
    
    | 
											File without changes
										 | 
    	
        datasets/.DS_Store
    CHANGED
    
    | Binary files a/datasets/.DS_Store and b/datasets/.DS_Store differ | 
|  | 
    	
        datasets/__pycache__/__init__.cpython-310.pyc
    DELETED
    
    | Binary file (289 Bytes) | 
|  | 
    	
        datasets/__pycache__/__init__.cpython-311.pyc
    DELETED
    
    | Binary file (343 Bytes) | 
|  | 
    	
        datasets/__pycache__/__init__.cpython-36.pyc
    DELETED
    
    | Binary file (279 Bytes) | 
|  | 
    	
        datasets/__pycache__/__init__.cpython-37.pyc
    DELETED
    
    | Binary file (283 Bytes) | 
|  | 
    	
        datasets/__pycache__/__init__.cpython-38.pyc
    DELETED
    
    | Binary file (287 Bytes) | 
|  | 
    	
        datasets/__pycache__/__init__.cpython-39.pyc
    DELETED
    
    | Binary file (287 Bytes) | 
|  | 
    	
        datasets/__pycache__/cityscapes.cpython-310.pyc
    DELETED
    
    | Binary file (3.2 kB) | 
|  | 
    	
        datasets/__pycache__/cityscapes.cpython-311.pyc
    DELETED
    
    | Binary file (4.66 kB) | 
|  | 
    	
        datasets/__pycache__/cityscapes.cpython-36.pyc
    DELETED
    
    | Binary file (3.44 kB) | 
|  | 
    	
        datasets/__pycache__/cityscapes.cpython-37.pyc
    DELETED
    
    | Binary file (3.44 kB) | 
|  | 
    	
        datasets/__pycache__/cityscapes.cpython-38.pyc
    DELETED
    
    | Binary file (5.42 kB) | 
|  | 
    	
        datasets/__pycache__/cityscapes.cpython-39.pyc
    DELETED
    
    | Binary file (3.17 kB) | 
|  | 
    	
        datasets/__pycache__/cocostuff.cpython-310.pyc
    DELETED
    
    | Binary file (6.38 kB) | 
|  | 
    	
        datasets/__pycache__/cocostuff.cpython-311.pyc
    DELETED
    
    | Binary file (10.9 kB) | 
|  | 
    	
        datasets/__pycache__/cocostuff.cpython-36.pyc
    DELETED
    
    | Binary file (5.23 kB) | 
|  | 
    	
        datasets/__pycache__/cocostuff.cpython-37.pyc
    DELETED
    
    | Binary file (5.18 kB) | 
|  | 
    	
        datasets/__pycache__/cocostuff.cpython-38.pyc
    DELETED
    
    | Binary file (5.53 kB) | 
|  | 
    	
        datasets/__pycache__/cocostuff.cpython-39.pyc
    DELETED
    
    | Binary file (5.2 kB) | 
|  | 
    	
        datasets/__pycache__/potsdam.cpython-310.pyc
    DELETED
    
    | Binary file (4.03 kB) | 
|  | 
    	
        datasets/__pycache__/potsdam.cpython-311.pyc
    DELETED
    
    | Binary file (8.22 kB) | 
|  | 
    	
        datasets/__pycache__/potsdam.cpython-36.pyc
    DELETED
    
    | Binary file (4.07 kB) | 
|  | 
    	
        datasets/__pycache__/potsdam.cpython-37.pyc
    DELETED
    
    | Binary file (4.08 kB) | 
|  | 
    	
        datasets/__pycache__/potsdam.cpython-38.pyc
    DELETED
    
    | Binary file (4.1 kB) | 
|  | 
    	
        datasets/__pycache__/potsdam.cpython-39.pyc
    DELETED
    
    | Binary file (4.07 kB) | 
|  | 
    	
        datasets/__pycache__/precomputed.cpython-310.pyc
    DELETED
    
    | Binary file (1.49 kB) | 
|  | 
    	
        datasets/__pycache__/precomputed.cpython-311.pyc
    DELETED
    
    | Binary file (3.07 kB) | 
|  | 
    	
        datasets/__pycache__/precomputed.cpython-36.pyc
    DELETED
    
    | Binary file (1.47 kB) | 
|  | 
    	
        datasets/__pycache__/precomputed.cpython-37.pyc
    DELETED
    
    | Binary file (1.48 kB) | 
|  | 
    	
        datasets/__pycache__/precomputed.cpython-38.pyc
    DELETED
    
    | Binary file (1.53 kB) | 
|  | 
    	
        datasets/__pycache__/precomputed.cpython-39.pyc
    DELETED
    
    | Binary file (1.49 kB) | 
|  | 
    	
        datasets/cocostuff.py
    DELETED
    
    | @@ -1,125 +0,0 @@ | |
| 1 | 
            -
            from os.path import join
         | 
| 2 | 
            -
            import numpy as np
         | 
| 3 | 
            -
            import torch.multiprocessing
         | 
| 4 | 
            -
            from PIL import Image
         | 
| 5 | 
            -
            from torch.utils.data import Dataset
         | 
| 6 | 
            -
             | 
| 7 | 
            -
            def bit_get(val, idx):
         | 
| 8 | 
            -
                """Gets the bit value.
         | 
| 9 | 
            -
                Args:
         | 
| 10 | 
            -
                  val: Input value, int or numpy int array.
         | 
| 11 | 
            -
                  idx: Which bit of the input val.
         | 
| 12 | 
            -
                Returns:
         | 
| 13 | 
            -
                  The "idx"-th bit of input val.
         | 
| 14 | 
            -
                """
         | 
| 15 | 
            -
                return (val >> idx) & 1
         | 
| 16 | 
            -
             | 
| 17 | 
            -
             | 
| 18 | 
            -
            def create_pascal_label_colormap():
         | 
| 19 | 
            -
                """Creates a label colormap used in PASCAL VOC segmentation benchmark.
         | 
| 20 | 
            -
                Returns:
         | 
| 21 | 
            -
                  A colormap for visualizing segmentation results.
         | 
| 22 | 
            -
                """
         | 
| 23 | 
            -
                colormap = np.zeros((512, 3), dtype=int)
         | 
| 24 | 
            -
                ind = np.arange(512, dtype=int)
         | 
| 25 | 
            -
             | 
| 26 | 
            -
                for shift in reversed(list(range(8))):
         | 
| 27 | 
            -
                    for channel in range(3):
         | 
| 28 | 
            -
                        colormap[:, channel] |= bit_get(ind, channel) << shift
         | 
| 29 | 
            -
                    ind >>= 3
         | 
| 30 | 
            -
             | 
| 31 | 
            -
                return colormap
         | 
| 32 | 
            -
             | 
| 33 | 
            -
            def get_coco_labeldata():
         | 
| 34 | 
            -
                cls_names = ["electronic", "appliance", "food", "furniture", "indoor", "kitchen", "accessory", "animal", "outdoor", "person", "sports", "vehicle", "ceiling", "floor", "food", "furniture", "rawmaterial", "textile", "wall", "window", "building", "ground", "plant", "sky", "solid", "structural", "water"]
         | 
| 35 | 
            -
                colormap = create_pascal_label_colormap()
         | 
| 36 | 
            -
                colormap[27] = np.array([0, 0, 0])
         | 
| 37 | 
            -
                return cls_names, colormap
         | 
| 38 | 
            -
             | 
| 39 | 
            -
            class cocostuff(Dataset):
         | 
| 40 | 
            -
                def __init__(self, root, split, transforms, #target_transform,
         | 
| 41 | 
            -
                             coarse_labels=None, exclude_things=None, subset=7):  #None):
         | 
| 42 | 
            -
                    super(cocostuff, self).__init__()
         | 
| 43 | 
            -
                    self.split = split
         | 
| 44 | 
            -
                    self.root = root
         | 
| 45 | 
            -
                    self.coarse_labels = coarse_labels
         | 
| 46 | 
            -
                    self.transforms = transforms
         | 
| 47 | 
            -
                    #self.label_transform = target_transform
         | 
| 48 | 
            -
                    self.subset = subset
         | 
| 49 | 
            -
                    self.exclude_things = exclude_things
         | 
| 50 | 
            -
             | 
| 51 | 
            -
                    if self.subset is None:
         | 
| 52 | 
            -
                        self.image_list = "Coco164kFull_Stuff_Coarse.txt"
         | 
| 53 | 
            -
                    elif self.subset == 6:  # IIC Coarse
         | 
| 54 | 
            -
                        self.image_list = "Coco164kFew_Stuff_6.txt"
         | 
| 55 | 
            -
                    elif self.subset == 7:  # IIC Fine
         | 
| 56 | 
            -
                        self.image_list = "Coco164kFull_Stuff_Coarse_7.txt"
         | 
| 57 | 
            -
             | 
| 58 | 
            -
                    assert self.split in ["train", "val", "train+val"]
         | 
| 59 | 
            -
                    split_dirs = {
         | 
| 60 | 
            -
                        "train": ["train2017"],
         | 
| 61 | 
            -
                        "val": ["val2017"],
         | 
| 62 | 
            -
                        "train+val": ["train2017", "val2017"]
         | 
| 63 | 
            -
                    }
         | 
| 64 | 
            -
             | 
| 65 | 
            -
                    self.image_files = []
         | 
| 66 | 
            -
                    self.label_files = []
         | 
| 67 | 
            -
                    for split_dir in split_dirs[self.split]:
         | 
| 68 | 
            -
                        with open(join(self.root, "curated", split_dir, self.image_list), "r") as f:
         | 
| 69 | 
            -
                            img_ids = [fn.rstrip() for fn in f.readlines()]
         | 
| 70 | 
            -
                            for img_id in img_ids:
         | 
| 71 | 
            -
                                self.image_files.append(join(self.root, "images", split_dir, img_id + ".jpg"))
         | 
| 72 | 
            -
                                self.label_files.append(join(self.root, "annotations", split_dir, img_id + ".png"))
         | 
| 73 | 
            -
             | 
| 74 | 
            -
                    self.fine_to_coarse = {0: 9, 1: 11, 2: 11, 3: 11, 4: 11, 5: 11, 6: 11, 7: 11, 8: 11, 9: 8, 10: 8, 11: 8, 12: 8,
         | 
| 75 | 
            -
                                           13: 8, 14: 8, 15: 7, 16: 7, 17: 7, 18: 7, 19: 7, 20: 7, 21: 7, 22: 7, 23: 7, 24: 7,
         | 
| 76 | 
            -
                                           25: 6, 26: 6, 27: 6, 28: 6, 29: 6, 30: 6, 31: 6, 32: 6, 33: 10, 34: 10, 35: 10, 36: 10,
         | 
| 77 | 
            -
                                           37: 10, 38: 10, 39: 10, 40: 10, 41: 10, 42: 10, 43: 5, 44: 5, 45: 5, 46: 5, 47: 5, 48: 5,
         | 
| 78 | 
            -
                                           49: 5, 50: 5, 51: 2, 52: 2, 53: 2, 54: 2, 55: 2, 56: 2, 57: 2, 58: 2, 59: 2, 60: 2,
         | 
| 79 | 
            -
                                           61: 3, 62: 3, 63: 3, 64: 3, 65: 3, 66: 3, 67: 3, 68: 3, 69: 3, 70: 3, 71: 0, 72: 0,
         | 
| 80 | 
            -
                                           73: 0, 74: 0, 75: 0, 76: 0, 77: 1, 78: 1, 79: 1, 80: 1, 81: 1, 82: 1, 83: 4, 84: 4,
         | 
| 81 | 
            -
                                           85: 4, 86: 4, 87: 4, 88: 4, 89: 4, 90: 4, 91: 17, 92: 17, 93: 22, 94: 20, 95: 20, 96: 22,
         | 
| 82 | 
            -
                                           97: 15, 98: 25, 99: 16, 100: 13, 101: 12, 102: 12, 103: 17, 104: 17, 105: 23, 106: 15,
         | 
| 83 | 
            -
                                           107: 15, 108: 17, 109: 15, 110: 21, 111: 15, 112: 25, 113: 13, 114: 13, 115: 13, 116: 13,
         | 
| 84 | 
            -
                                           117: 13, 118: 22, 119: 26, 120: 14, 121: 14, 122: 15, 123: 22, 124: 21, 125: 21, 126: 24,
         | 
| 85 | 
            -
                                           127: 20, 128: 22, 129: 15, 130: 17, 131: 16, 132: 15, 133: 22, 134: 24, 135: 21, 136: 17,
         | 
| 86 | 
            -
                                           137: 25, 138: 16, 139: 21, 140: 17, 141: 22, 142: 16, 143: 21, 144: 21, 145: 25, 146: 21,
         | 
| 87 | 
            -
                                           147: 26, 148: 21, 149: 24, 150: 20, 151: 17, 152: 14, 153: 21, 154: 26, 155: 15, 156: 23,
         | 
| 88 | 
            -
                                           157: 20, 158: 21, 159: 24, 160: 15, 161: 24, 162: 22, 163: 25, 164: 15, 165: 20, 166: 17,
         | 
| 89 | 
            -
                                           167: 17, 168: 22, 169: 14, 170: 18, 171: 18, 172: 18, 173: 18, 174: 18, 175: 18, 176: 18,
         | 
| 90 | 
            -
                                           177: 26, 178: 26, 179: 19, 180: 19, 181: 24}
         | 
| 91 | 
            -
             | 
| 92 | 
            -
                    self._label_names = [
         | 
| 93 | 
            -
                        "ground-stuff",
         | 
| 94 | 
            -
                        "plant-stuff",
         | 
| 95 | 
            -
                        "sky-stuff",
         | 
| 96 | 
            -
                    ]
         | 
| 97 | 
            -
                    self.cocostuff3_coarse_classes = [23, 22, 21]
         | 
| 98 | 
            -
                    self.first_stuff_index = 12
         | 
| 99 | 
            -
             | 
| 100 | 
            -
                def __getitem__(self, index):
         | 
| 101 | 
            -
                    image_path = self.image_files[index]
         | 
| 102 | 
            -
                    label_path = self.label_files[index]
         | 
| 103 | 
            -
                    seed = np.random.randint(2147483647)
         | 
| 104 | 
            -
             | 
| 105 | 
            -
                    img, label = self.transforms(Image.open(image_path).convert("RGB"), Image.open(label_path))
         | 
| 106 | 
            -
             | 
| 107 | 
            -
                    label[label == 255] = -1  # to be consistent with 10k
         | 
| 108 | 
            -
                    coarse_label = torch.zeros_like(label)
         | 
| 109 | 
            -
                    for fine, coarse in self.fine_to_coarse.items():
         | 
| 110 | 
            -
                        coarse_label[label == fine] = coarse
         | 
| 111 | 
            -
                    coarse_label[label == -1] = 255 #-1
         | 
| 112 | 
            -
             | 
| 113 | 
            -
                    if self.coarse_labels:
         | 
| 114 | 
            -
                        coarser_labels = -torch.ones_like(label)
         | 
| 115 | 
            -
                        for i, c in enumerate(self.cocostuff3_coarse_classes):
         | 
| 116 | 
            -
                            coarser_labels[coarse_label == c] = i
         | 
| 117 | 
            -
                        return img, coarser_labels, coarser_labels >= 0
         | 
| 118 | 
            -
                    else:
         | 
| 119 | 
            -
                        if self.exclude_things:
         | 
| 120 | 
            -
                            return img, coarse_label - self.first_stuff_index, (coarse_label >= self.first_stuff_index)
         | 
| 121 | 
            -
                        else:
         | 
| 122 | 
            -
                            return img, coarse_label, image_path
         | 
| 123 | 
            -
             | 
| 124 | 
            -
                def __len__(self):
         | 
| 125 | 
            -
                    return len(self.image_files)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        datasets/potsdam.py
    DELETED
    
    | @@ -1,121 +0,0 @@ | |
| 1 | 
            -
            import os
         | 
| 2 | 
            -
            import random
         | 
| 3 | 
            -
            from os.path import join
         | 
| 4 | 
            -
            import numpy as np
         | 
| 5 | 
            -
            import torch.multiprocessing
         | 
| 6 | 
            -
            from scipy.io import loadmat
         | 
| 7 | 
            -
            from torchvision.transforms.functional import to_pil_image
         | 
| 8 | 
            -
            from torch.utils.data import Dataset
         | 
| 9 | 
            -
             | 
| 10 | 
            -
            def get_pd_labeldata():
         | 
| 11 | 
            -
                cls_names = ['road', 'building', 'vegetation']
         | 
| 12 | 
            -
                colormap = np.array([
         | 
| 13 | 
            -
                        [58, 0, 68], #[158, 0, 0],[58, 0, 68],
         | 
| 14 | 
            -
                        [0, 130, 122], #[107, 130, 148], 
         | 
| 15 | 
            -
                        [255, 230, 0], #[101, 192, 0],[0, 130, 122],
         | 
| 16 | 
            -
                        [0, 0, 0]])
         | 
| 17 | 
            -
                return cls_names, colormap
         | 
| 18 | 
            -
             | 
| 19 | 
            -
            class potsdam(Dataset):
         | 
| 20 | 
            -
                def __init__(self, transforms, split, root):
         | 
| 21 | 
            -
                    super(potsdam, self).__init__()
         | 
| 22 | 
            -
                    self.split = split
         | 
| 23 | 
            -
                    self.root = root
         | 
| 24 | 
            -
                    self.transform = transforms
         | 
| 25 | 
            -
                    split_files = {
         | 
| 26 | 
            -
                        "train": ["labelled_train.txt"],
         | 
| 27 | 
            -
                        "unlabelled_train": ["unlabelled_train.txt"],
         | 
| 28 | 
            -
                        # "train": ["unlabelled_train.txt"],
         | 
| 29 | 
            -
                        "val": ["labelled_test.txt"],
         | 
| 30 | 
            -
                        "train+val": ["labelled_train.txt", "labelled_test.txt"],
         | 
| 31 | 
            -
                        "all": ["all.txt"]
         | 
| 32 | 
            -
                    }
         | 
| 33 | 
            -
                    assert self.split in split_files.keys()
         | 
| 34 | 
            -
             | 
| 35 | 
            -
                    self.files = []
         | 
| 36 | 
            -
                    for split_file in split_files[self.split]:
         | 
| 37 | 
            -
                        with open(join(self.root, split_file), "r") as f:
         | 
| 38 | 
            -
                            self.files.extend(fn.rstrip() for fn in f.readlines())
         | 
| 39 | 
            -
             | 
| 40 | 
            -
                    self.coarse_labels = True
         | 
| 41 | 
            -
                    self.fine_to_coarse = {0: 0, 4: 0,  # roads and cars
         | 
| 42 | 
            -
                                           1: 1, 5: 1,  # buildings and clutter
         | 
| 43 | 
            -
                                           2: 2, 3: 2,  # vegetation and trees
         | 
| 44 | 
            -
                                           }
         | 
| 45 | 
            -
             | 
| 46 | 
            -
                def __getitem__(self, index):
         | 
| 47 | 
            -
                    image_id = self.files[index]
         | 
| 48 | 
            -
                    img = loadmat(join(self.root, "imgs", image_id + ".mat"))["img"]
         | 
| 49 | 
            -
                    img = to_pil_image(torch.from_numpy(img).permute(2, 0, 1)[:3])  # TODO add ir channel back
         | 
| 50 | 
            -
                    try:
         | 
| 51 | 
            -
                        label = loadmat(join(self.root, "gt", image_id + ".mat"))["gt"]
         | 
| 52 | 
            -
                        label = to_pil_image(torch.from_numpy(label).unsqueeze(-1).permute(2, 0, 1))
         | 
| 53 | 
            -
                    except FileNotFoundError:
         | 
| 54 | 
            -
                        label = to_pil_image(torch.ones(1, img.height, img.width))
         | 
| 55 | 
            -
             | 
| 56 | 
            -
                    img, label = self.transform(img, label)
         | 
| 57 | 
            -
             | 
| 58 | 
            -
                    if self.coarse_labels:
         | 
| 59 | 
            -
                        new_label_map = torch.ones_like(label)*255
         | 
| 60 | 
            -
                        for fine, coarse in self.fine_to_coarse.items():
         | 
| 61 | 
            -
                            new_label_map[label == fine] = coarse
         | 
| 62 | 
            -
                        label = new_label_map
         | 
| 63 | 
            -
             | 
| 64 | 
            -
                    # mask = (label > 0).to(torch.float32)
         | 
| 65 | 
            -
                    return img, label, image_id
         | 
| 66 | 
            -
             | 
| 67 | 
            -
                def __len__(self):
         | 
| 68 | 
            -
                    return len(self.files)
         | 
| 69 | 
            -
                
         | 
| 70 | 
            -
            classes = ['road', 'building', 'vegetation']
         | 
| 71 | 
            -
             | 
| 72 | 
            -
             | 
| 73 | 
            -
            class PotsdamRaw(Dataset):
         | 
| 74 | 
            -
                def __init__(self, root, image_set, transform, target_transform, coarse_labels):
         | 
| 75 | 
            -
                    super(PotsdamRaw, self).__init__()
         | 
| 76 | 
            -
                    self.split = image_set
         | 
| 77 | 
            -
                    self.root = os.path.join(root, "potsdamraw", "processed")
         | 
| 78 | 
            -
                    self.transform = transform
         | 
| 79 | 
            -
                    self.target_transform = target_transform
         | 
| 80 | 
            -
                    self.files = []
         | 
| 81 | 
            -
                    for im_num in range(38):
         | 
| 82 | 
            -
                        for i_h in range(15):
         | 
| 83 | 
            -
                            for i_w in range(15):
         | 
| 84 | 
            -
                                self.files.append("{}_{}_{}.mat".format(im_num, i_h, i_w))
         | 
| 85 | 
            -
             | 
| 86 | 
            -
                    self.coarse_labels = coarse_labels
         | 
| 87 | 
            -
                    self.fine_to_coarse = {0: 0, 4: 0,  # roads and cars
         | 
| 88 | 
            -
                                           1: 1, 5: 1,  # buildings and clutter
         | 
| 89 | 
            -
                                           2: 2, 3: 2,  # vegetation and trees
         | 
| 90 | 
            -
                                           255: -1
         | 
| 91 | 
            -
                                           }
         | 
| 92 | 
            -
             | 
| 93 | 
            -
                def __getitem__(self, index):
         | 
| 94 | 
            -
                    image_id = self.files[index]
         | 
| 95 | 
            -
                    img = loadmat(join(self.root, "imgs", image_id))["img"]
         | 
| 96 | 
            -
                    img = to_pil_image(torch.from_numpy(img).permute(2, 0, 1)[:3])  # TODO add ir channel back
         | 
| 97 | 
            -
                    try:
         | 
| 98 | 
            -
                        label = loadmat(join(self.root, "gt", image_id))["gt"]
         | 
| 99 | 
            -
                        label = to_pil_image(torch.from_numpy(label).unsqueeze(-1).permute(2, 0, 1))
         | 
| 100 | 
            -
                    except FileNotFoundError:
         | 
| 101 | 
            -
                        label = to_pil_image(torch.ones(1, img.height, img.width))
         | 
| 102 | 
            -
             | 
| 103 | 
            -
                    seed = np.random.randint(2147483647)
         | 
| 104 | 
            -
                    random.seed(seed)
         | 
| 105 | 
            -
                    torch.manual_seed(seed)
         | 
| 106 | 
            -
                    img = self.transform(img)
         | 
| 107 | 
            -
             | 
| 108 | 
            -
                    random.seed(seed)
         | 
| 109 | 
            -
                    torch.manual_seed(seed)
         | 
| 110 | 
            -
                    label = self.target_transform(label).squeeze(0)
         | 
| 111 | 
            -
                    if self.coarse_labels:
         | 
| 112 | 
            -
                        new_label_map = torch.zeros_like(label)
         | 
| 113 | 
            -
                        for fine, coarse in self.fine_to_coarse.items():
         | 
| 114 | 
            -
                            new_label_map[label == fine] = coarse
         | 
| 115 | 
            -
                        label = new_label_map
         | 
| 116 | 
            -
             | 
| 117 | 
            -
                    mask = (label > 0).to(torch.float32)
         | 
| 118 | 
            -
                    return img, label, mask
         | 
| 119 | 
            -
             | 
| 120 | 
            -
                def __len__(self):
         | 
| 121 | 
            -
                    return len(self.files)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        datasets/precomputed.py
    DELETED
    
    | @@ -1,43 +0,0 @@ | |
| 1 | 
            -
            import os
         | 
| 2 | 
            -
            from PIL import Image
         | 
| 3 | 
            -
            from torch.utils.data import Dataset
         | 
| 4 | 
            -
             | 
| 5 | 
            -
             | 
| 6 | 
            -
            class PrecomputedDataset(Dataset):
         | 
| 7 | 
            -
                def __init__(self, 
         | 
| 8 | 
            -
                             root, 
         | 
| 9 | 
            -
                             transforms,
         | 
| 10 | 
            -
                             student_augs, 
         | 
| 11 | 
            -
                             ):
         | 
| 12 | 
            -
                    super(PrecomputedDataset, self).__init__()
         | 
| 13 | 
            -
                    self.root = root
         | 
| 14 | 
            -
                    self.transforms = transforms
         | 
| 15 | 
            -
                    self.student_augs = student_augs
         | 
| 16 | 
            -
             | 
| 17 | 
            -
                    self.image_files = []
         | 
| 18 | 
            -
                    self.label_files = []
         | 
| 19 | 
            -
                    self.pseudo_files = []
         | 
| 20 | 
            -
                    for file in os.listdir(os.path.join(self.root, 'imgs')):
         | 
| 21 | 
            -
                        self.image_files.append(os.path.join(self.root, 'imgs', file))
         | 
| 22 | 
            -
                        self.label_files.append(os.path.join(self.root, 'gts', file))
         | 
| 23 | 
            -
                        self.pseudo_files.append(os.path.join(self.root, 'pseudos', file))
         | 
| 24 | 
            -
             | 
| 25 | 
            -
             | 
| 26 | 
            -
                def __getitem__(self, index):
         | 
| 27 | 
            -
                    image_path = self.image_files[index]
         | 
| 28 | 
            -
                    label_path = self.label_files[index]
         | 
| 29 | 
            -
                    pseudo_path = self.pseudo_files[index]
         | 
| 30 | 
            -
             | 
| 31 | 
            -
                    img = Image.open(image_path).convert("RGB")
         | 
| 32 | 
            -
                    label = Image.open(label_path)
         | 
| 33 | 
            -
                    pseudo = Image.open(pseudo_path)
         | 
| 34 | 
            -
             | 
| 35 | 
            -
                    if self.student_augs:
         | 
| 36 | 
            -
                        img, label, aimg, pseudo = self.transforms(img, label, pseudo)
         | 
| 37 | 
            -
                        return img, label.long(), aimg, pseudo.long()
         | 
| 38 | 
            -
                    else:
         | 
| 39 | 
            -
                        img, label, pseudo = self.transforms(img, label, pseudo)
         | 
| 40 | 
            -
                        return img, label.long(), pseudo.long()
         | 
| 41 | 
            -
             | 
| 42 | 
            -
                def __len__(self):
         | 
| 43 | 
            -
                    return len(self.image_files)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
