Oliver Hahn
commited on
Commit
·
04513cf
1
Parent(s):
3f97a79
add demo
Browse files- .DS_Store +0 -0
- README.md +1 -1
- app.py +2 -6
- assets/demo_examples/.DS_Store +0 -0
- assets/demo_examples/cityscapes_example.png +0 -3
- assets/demo_examples/potsdam_example.png +0 -3
- assets/{demo_examples/coco_example.jpg → example.jpg} +0 -0
- datasets/.DS_Store +0 -0
- datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- datasets/__pycache__/__init__.cpython-311.pyc +0 -0
- datasets/__pycache__/__init__.cpython-36.pyc +0 -0
- datasets/__pycache__/__init__.cpython-37.pyc +0 -0
- datasets/__pycache__/__init__.cpython-38.pyc +0 -0
- datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- datasets/__pycache__/cityscapes.cpython-310.pyc +0 -0
- datasets/__pycache__/cityscapes.cpython-311.pyc +0 -0
- datasets/__pycache__/cityscapes.cpython-36.pyc +0 -0
- datasets/__pycache__/cityscapes.cpython-37.pyc +0 -0
- datasets/__pycache__/cityscapes.cpython-38.pyc +0 -0
- datasets/__pycache__/cityscapes.cpython-39.pyc +0 -0
- datasets/__pycache__/cocostuff.cpython-310.pyc +0 -0
- datasets/__pycache__/cocostuff.cpython-311.pyc +0 -0
- datasets/__pycache__/cocostuff.cpython-36.pyc +0 -0
- datasets/__pycache__/cocostuff.cpython-37.pyc +0 -0
- datasets/__pycache__/cocostuff.cpython-38.pyc +0 -0
- datasets/__pycache__/cocostuff.cpython-39.pyc +0 -0
- datasets/__pycache__/potsdam.cpython-310.pyc +0 -0
- datasets/__pycache__/potsdam.cpython-311.pyc +0 -0
- datasets/__pycache__/potsdam.cpython-36.pyc +0 -0
- datasets/__pycache__/potsdam.cpython-37.pyc +0 -0
- datasets/__pycache__/potsdam.cpython-38.pyc +0 -0
- datasets/__pycache__/potsdam.cpython-39.pyc +0 -0
- datasets/__pycache__/precomputed.cpython-310.pyc +0 -0
- datasets/__pycache__/precomputed.cpython-311.pyc +0 -0
- datasets/__pycache__/precomputed.cpython-36.pyc +0 -0
- datasets/__pycache__/precomputed.cpython-37.pyc +0 -0
- datasets/__pycache__/precomputed.cpython-38.pyc +0 -0
- datasets/__pycache__/precomputed.cpython-39.pyc +0 -0
- datasets/cocostuff.py +0 -125
- datasets/potsdam.py +0 -121
- datasets/precomputed.py +0 -43
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
|
|
README.md
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
---
|
2 |
title: PriMaPs
|
3 |
emoji: 😻
|
4 |
-
colorFrom:
|
5 |
colorTo: gray
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.6.0
|
|
|
1 |
---
|
2 |
title: PriMaPs
|
3 |
emoji: 😻
|
4 |
+
colorFrom: black
|
5 |
colorTo: gray
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.6.0
|
app.py
CHANGED
@@ -50,9 +50,7 @@ def gradio_primaps(image_path, threshold, architecture):
|
|
50 |
if __name__ == '__main__':
|
51 |
# Example image paths
|
52 |
example_images = [
|
53 |
-
"assets/demo_examples/
|
54 |
-
"assets/demo_examples/coco_example.jpg",
|
55 |
-
"assets/demo_examples/potsdam_example.png"
|
56 |
]
|
57 |
|
58 |
# Gradio interface
|
@@ -60,7 +58,7 @@ if __name__ == '__main__':
|
|
60 |
fn=gradio_primaps,
|
61 |
inputs=[
|
62 |
gr.Image(type="filepath", label="Image"),
|
63 |
-
gr.Slider(0.0, 1.0, step=0.05, value=0.
|
64 |
gr.Dropdown(choices=['dino_vits', 'dino_vitb', 'dinov2_vits', 'dinov2_vitb'], value='dino_vitb', label="SSL Features"),
|
65 |
],
|
66 |
outputs=gr.Image(label="PriMaPs"),
|
@@ -68,8 +66,6 @@ if __name__ == '__main__':
|
|
68 |
description="Upload an image and adjust the threshold to visualize PriMaPs.",
|
69 |
examples=[
|
70 |
[example_images[0], 0.4, 'dino_vitb'],
|
71 |
-
[example_images[1], 0.4, 'dino_vitb'],
|
72 |
-
[example_images[2], 0.4, 'dino_vitb']
|
73 |
]
|
74 |
)
|
75 |
|
|
|
50 |
if __name__ == '__main__':
|
51 |
# Example image paths
|
52 |
example_images = [
|
53 |
+
"assets/demo_examples/example.jpg",
|
|
|
|
|
54 |
]
|
55 |
|
56 |
# Gradio interface
|
|
|
58 |
fn=gradio_primaps,
|
59 |
inputs=[
|
60 |
gr.Image(type="filepath", label="Image"),
|
61 |
+
gr.Slider(0.0, 1.0, step=0.05, value=0.4, label="Threshold"),
|
62 |
gr.Dropdown(choices=['dino_vits', 'dino_vitb', 'dinov2_vits', 'dinov2_vitb'], value='dino_vitb', label="SSL Features"),
|
63 |
],
|
64 |
outputs=gr.Image(label="PriMaPs"),
|
|
|
66 |
description="Upload an image and adjust the threshold to visualize PriMaPs.",
|
67 |
examples=[
|
68 |
[example_images[0], 0.4, 'dino_vitb'],
|
|
|
|
|
69 |
]
|
70 |
)
|
71 |
|
assets/demo_examples/.DS_Store
DELETED
Binary file (6.15 kB)
|
|
assets/demo_examples/cityscapes_example.png
DELETED
Git LFS Details
|
assets/demo_examples/potsdam_example.png
DELETED
Git LFS Details
|
assets/{demo_examples/coco_example.jpg → example.jpg}
RENAMED
File without changes
|
datasets/.DS_Store
CHANGED
Binary files a/datasets/.DS_Store and b/datasets/.DS_Store differ
|
|
datasets/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (289 Bytes)
|
|
datasets/__pycache__/__init__.cpython-311.pyc
DELETED
Binary file (343 Bytes)
|
|
datasets/__pycache__/__init__.cpython-36.pyc
DELETED
Binary file (279 Bytes)
|
|
datasets/__pycache__/__init__.cpython-37.pyc
DELETED
Binary file (283 Bytes)
|
|
datasets/__pycache__/__init__.cpython-38.pyc
DELETED
Binary file (287 Bytes)
|
|
datasets/__pycache__/__init__.cpython-39.pyc
DELETED
Binary file (287 Bytes)
|
|
datasets/__pycache__/cityscapes.cpython-310.pyc
DELETED
Binary file (3.2 kB)
|
|
datasets/__pycache__/cityscapes.cpython-311.pyc
DELETED
Binary file (4.66 kB)
|
|
datasets/__pycache__/cityscapes.cpython-36.pyc
DELETED
Binary file (3.44 kB)
|
|
datasets/__pycache__/cityscapes.cpython-37.pyc
DELETED
Binary file (3.44 kB)
|
|
datasets/__pycache__/cityscapes.cpython-38.pyc
DELETED
Binary file (5.42 kB)
|
|
datasets/__pycache__/cityscapes.cpython-39.pyc
DELETED
Binary file (3.17 kB)
|
|
datasets/__pycache__/cocostuff.cpython-310.pyc
DELETED
Binary file (6.38 kB)
|
|
datasets/__pycache__/cocostuff.cpython-311.pyc
DELETED
Binary file (10.9 kB)
|
|
datasets/__pycache__/cocostuff.cpython-36.pyc
DELETED
Binary file (5.23 kB)
|
|
datasets/__pycache__/cocostuff.cpython-37.pyc
DELETED
Binary file (5.18 kB)
|
|
datasets/__pycache__/cocostuff.cpython-38.pyc
DELETED
Binary file (5.53 kB)
|
|
datasets/__pycache__/cocostuff.cpython-39.pyc
DELETED
Binary file (5.2 kB)
|
|
datasets/__pycache__/potsdam.cpython-310.pyc
DELETED
Binary file (4.03 kB)
|
|
datasets/__pycache__/potsdam.cpython-311.pyc
DELETED
Binary file (8.22 kB)
|
|
datasets/__pycache__/potsdam.cpython-36.pyc
DELETED
Binary file (4.07 kB)
|
|
datasets/__pycache__/potsdam.cpython-37.pyc
DELETED
Binary file (4.08 kB)
|
|
datasets/__pycache__/potsdam.cpython-38.pyc
DELETED
Binary file (4.1 kB)
|
|
datasets/__pycache__/potsdam.cpython-39.pyc
DELETED
Binary file (4.07 kB)
|
|
datasets/__pycache__/precomputed.cpython-310.pyc
DELETED
Binary file (1.49 kB)
|
|
datasets/__pycache__/precomputed.cpython-311.pyc
DELETED
Binary file (3.07 kB)
|
|
datasets/__pycache__/precomputed.cpython-36.pyc
DELETED
Binary file (1.47 kB)
|
|
datasets/__pycache__/precomputed.cpython-37.pyc
DELETED
Binary file (1.48 kB)
|
|
datasets/__pycache__/precomputed.cpython-38.pyc
DELETED
Binary file (1.53 kB)
|
|
datasets/__pycache__/precomputed.cpython-39.pyc
DELETED
Binary file (1.49 kB)
|
|
datasets/cocostuff.py
DELETED
@@ -1,125 +0,0 @@
|
|
1 |
-
from os.path import join
|
2 |
-
import numpy as np
|
3 |
-
import torch.multiprocessing
|
4 |
-
from PIL import Image
|
5 |
-
from torch.utils.data import Dataset
|
6 |
-
|
7 |
-
def bit_get(val, idx):
|
8 |
-
"""Gets the bit value.
|
9 |
-
Args:
|
10 |
-
val: Input value, int or numpy int array.
|
11 |
-
idx: Which bit of the input val.
|
12 |
-
Returns:
|
13 |
-
The "idx"-th bit of input val.
|
14 |
-
"""
|
15 |
-
return (val >> idx) & 1
|
16 |
-
|
17 |
-
|
18 |
-
def create_pascal_label_colormap():
|
19 |
-
"""Creates a label colormap used in PASCAL VOC segmentation benchmark.
|
20 |
-
Returns:
|
21 |
-
A colormap for visualizing segmentation results.
|
22 |
-
"""
|
23 |
-
colormap = np.zeros((512, 3), dtype=int)
|
24 |
-
ind = np.arange(512, dtype=int)
|
25 |
-
|
26 |
-
for shift in reversed(list(range(8))):
|
27 |
-
for channel in range(3):
|
28 |
-
colormap[:, channel] |= bit_get(ind, channel) << shift
|
29 |
-
ind >>= 3
|
30 |
-
|
31 |
-
return colormap
|
32 |
-
|
33 |
-
def get_coco_labeldata():
|
34 |
-
cls_names = ["electronic", "appliance", "food", "furniture", "indoor", "kitchen", "accessory", "animal", "outdoor", "person", "sports", "vehicle", "ceiling", "floor", "food", "furniture", "rawmaterial", "textile", "wall", "window", "building", "ground", "plant", "sky", "solid", "structural", "water"]
|
35 |
-
colormap = create_pascal_label_colormap()
|
36 |
-
colormap[27] = np.array([0, 0, 0])
|
37 |
-
return cls_names, colormap
|
38 |
-
|
39 |
-
class cocostuff(Dataset):
|
40 |
-
def __init__(self, root, split, transforms, #target_transform,
|
41 |
-
coarse_labels=None, exclude_things=None, subset=7): #None):
|
42 |
-
super(cocostuff, self).__init__()
|
43 |
-
self.split = split
|
44 |
-
self.root = root
|
45 |
-
self.coarse_labels = coarse_labels
|
46 |
-
self.transforms = transforms
|
47 |
-
#self.label_transform = target_transform
|
48 |
-
self.subset = subset
|
49 |
-
self.exclude_things = exclude_things
|
50 |
-
|
51 |
-
if self.subset is None:
|
52 |
-
self.image_list = "Coco164kFull_Stuff_Coarse.txt"
|
53 |
-
elif self.subset == 6: # IIC Coarse
|
54 |
-
self.image_list = "Coco164kFew_Stuff_6.txt"
|
55 |
-
elif self.subset == 7: # IIC Fine
|
56 |
-
self.image_list = "Coco164kFull_Stuff_Coarse_7.txt"
|
57 |
-
|
58 |
-
assert self.split in ["train", "val", "train+val"]
|
59 |
-
split_dirs = {
|
60 |
-
"train": ["train2017"],
|
61 |
-
"val": ["val2017"],
|
62 |
-
"train+val": ["train2017", "val2017"]
|
63 |
-
}
|
64 |
-
|
65 |
-
self.image_files = []
|
66 |
-
self.label_files = []
|
67 |
-
for split_dir in split_dirs[self.split]:
|
68 |
-
with open(join(self.root, "curated", split_dir, self.image_list), "r") as f:
|
69 |
-
img_ids = [fn.rstrip() for fn in f.readlines()]
|
70 |
-
for img_id in img_ids:
|
71 |
-
self.image_files.append(join(self.root, "images", split_dir, img_id + ".jpg"))
|
72 |
-
self.label_files.append(join(self.root, "annotations", split_dir, img_id + ".png"))
|
73 |
-
|
74 |
-
self.fine_to_coarse = {0: 9, 1: 11, 2: 11, 3: 11, 4: 11, 5: 11, 6: 11, 7: 11, 8: 11, 9: 8, 10: 8, 11: 8, 12: 8,
|
75 |
-
13: 8, 14: 8, 15: 7, 16: 7, 17: 7, 18: 7, 19: 7, 20: 7, 21: 7, 22: 7, 23: 7, 24: 7,
|
76 |
-
25: 6, 26: 6, 27: 6, 28: 6, 29: 6, 30: 6, 31: 6, 32: 6, 33: 10, 34: 10, 35: 10, 36: 10,
|
77 |
-
37: 10, 38: 10, 39: 10, 40: 10, 41: 10, 42: 10, 43: 5, 44: 5, 45: 5, 46: 5, 47: 5, 48: 5,
|
78 |
-
49: 5, 50: 5, 51: 2, 52: 2, 53: 2, 54: 2, 55: 2, 56: 2, 57: 2, 58: 2, 59: 2, 60: 2,
|
79 |
-
61: 3, 62: 3, 63: 3, 64: 3, 65: 3, 66: 3, 67: 3, 68: 3, 69: 3, 70: 3, 71: 0, 72: 0,
|
80 |
-
73: 0, 74: 0, 75: 0, 76: 0, 77: 1, 78: 1, 79: 1, 80: 1, 81: 1, 82: 1, 83: 4, 84: 4,
|
81 |
-
85: 4, 86: 4, 87: 4, 88: 4, 89: 4, 90: 4, 91: 17, 92: 17, 93: 22, 94: 20, 95: 20, 96: 22,
|
82 |
-
97: 15, 98: 25, 99: 16, 100: 13, 101: 12, 102: 12, 103: 17, 104: 17, 105: 23, 106: 15,
|
83 |
-
107: 15, 108: 17, 109: 15, 110: 21, 111: 15, 112: 25, 113: 13, 114: 13, 115: 13, 116: 13,
|
84 |
-
117: 13, 118: 22, 119: 26, 120: 14, 121: 14, 122: 15, 123: 22, 124: 21, 125: 21, 126: 24,
|
85 |
-
127: 20, 128: 22, 129: 15, 130: 17, 131: 16, 132: 15, 133: 22, 134: 24, 135: 21, 136: 17,
|
86 |
-
137: 25, 138: 16, 139: 21, 140: 17, 141: 22, 142: 16, 143: 21, 144: 21, 145: 25, 146: 21,
|
87 |
-
147: 26, 148: 21, 149: 24, 150: 20, 151: 17, 152: 14, 153: 21, 154: 26, 155: 15, 156: 23,
|
88 |
-
157: 20, 158: 21, 159: 24, 160: 15, 161: 24, 162: 22, 163: 25, 164: 15, 165: 20, 166: 17,
|
89 |
-
167: 17, 168: 22, 169: 14, 170: 18, 171: 18, 172: 18, 173: 18, 174: 18, 175: 18, 176: 18,
|
90 |
-
177: 26, 178: 26, 179: 19, 180: 19, 181: 24}
|
91 |
-
|
92 |
-
self._label_names = [
|
93 |
-
"ground-stuff",
|
94 |
-
"plant-stuff",
|
95 |
-
"sky-stuff",
|
96 |
-
]
|
97 |
-
self.cocostuff3_coarse_classes = [23, 22, 21]
|
98 |
-
self.first_stuff_index = 12
|
99 |
-
|
100 |
-
def __getitem__(self, index):
|
101 |
-
image_path = self.image_files[index]
|
102 |
-
label_path = self.label_files[index]
|
103 |
-
seed = np.random.randint(2147483647)
|
104 |
-
|
105 |
-
img, label = self.transforms(Image.open(image_path).convert("RGB"), Image.open(label_path))
|
106 |
-
|
107 |
-
label[label == 255] = -1 # to be consistent with 10k
|
108 |
-
coarse_label = torch.zeros_like(label)
|
109 |
-
for fine, coarse in self.fine_to_coarse.items():
|
110 |
-
coarse_label[label == fine] = coarse
|
111 |
-
coarse_label[label == -1] = 255 #-1
|
112 |
-
|
113 |
-
if self.coarse_labels:
|
114 |
-
coarser_labels = -torch.ones_like(label)
|
115 |
-
for i, c in enumerate(self.cocostuff3_coarse_classes):
|
116 |
-
coarser_labels[coarse_label == c] = i
|
117 |
-
return img, coarser_labels, coarser_labels >= 0
|
118 |
-
else:
|
119 |
-
if self.exclude_things:
|
120 |
-
return img, coarse_label - self.first_stuff_index, (coarse_label >= self.first_stuff_index)
|
121 |
-
else:
|
122 |
-
return img, coarse_label, image_path
|
123 |
-
|
124 |
-
def __len__(self):
|
125 |
-
return len(self.image_files)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
datasets/potsdam.py
DELETED
@@ -1,121 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import random
|
3 |
-
from os.path import join
|
4 |
-
import numpy as np
|
5 |
-
import torch.multiprocessing
|
6 |
-
from scipy.io import loadmat
|
7 |
-
from torchvision.transforms.functional import to_pil_image
|
8 |
-
from torch.utils.data import Dataset
|
9 |
-
|
10 |
-
def get_pd_labeldata():
|
11 |
-
cls_names = ['road', 'building', 'vegetation']
|
12 |
-
colormap = np.array([
|
13 |
-
[58, 0, 68], #[158, 0, 0],[58, 0, 68],
|
14 |
-
[0, 130, 122], #[107, 130, 148],
|
15 |
-
[255, 230, 0], #[101, 192, 0],[0, 130, 122],
|
16 |
-
[0, 0, 0]])
|
17 |
-
return cls_names, colormap
|
18 |
-
|
19 |
-
class potsdam(Dataset):
|
20 |
-
def __init__(self, transforms, split, root):
|
21 |
-
super(potsdam, self).__init__()
|
22 |
-
self.split = split
|
23 |
-
self.root = root
|
24 |
-
self.transform = transforms
|
25 |
-
split_files = {
|
26 |
-
"train": ["labelled_train.txt"],
|
27 |
-
"unlabelled_train": ["unlabelled_train.txt"],
|
28 |
-
# "train": ["unlabelled_train.txt"],
|
29 |
-
"val": ["labelled_test.txt"],
|
30 |
-
"train+val": ["labelled_train.txt", "labelled_test.txt"],
|
31 |
-
"all": ["all.txt"]
|
32 |
-
}
|
33 |
-
assert self.split in split_files.keys()
|
34 |
-
|
35 |
-
self.files = []
|
36 |
-
for split_file in split_files[self.split]:
|
37 |
-
with open(join(self.root, split_file), "r") as f:
|
38 |
-
self.files.extend(fn.rstrip() for fn in f.readlines())
|
39 |
-
|
40 |
-
self.coarse_labels = True
|
41 |
-
self.fine_to_coarse = {0: 0, 4: 0, # roads and cars
|
42 |
-
1: 1, 5: 1, # buildings and clutter
|
43 |
-
2: 2, 3: 2, # vegetation and trees
|
44 |
-
}
|
45 |
-
|
46 |
-
def __getitem__(self, index):
|
47 |
-
image_id = self.files[index]
|
48 |
-
img = loadmat(join(self.root, "imgs", image_id + ".mat"))["img"]
|
49 |
-
img = to_pil_image(torch.from_numpy(img).permute(2, 0, 1)[:3]) # TODO add ir channel back
|
50 |
-
try:
|
51 |
-
label = loadmat(join(self.root, "gt", image_id + ".mat"))["gt"]
|
52 |
-
label = to_pil_image(torch.from_numpy(label).unsqueeze(-1).permute(2, 0, 1))
|
53 |
-
except FileNotFoundError:
|
54 |
-
label = to_pil_image(torch.ones(1, img.height, img.width))
|
55 |
-
|
56 |
-
img, label = self.transform(img, label)
|
57 |
-
|
58 |
-
if self.coarse_labels:
|
59 |
-
new_label_map = torch.ones_like(label)*255
|
60 |
-
for fine, coarse in self.fine_to_coarse.items():
|
61 |
-
new_label_map[label == fine] = coarse
|
62 |
-
label = new_label_map
|
63 |
-
|
64 |
-
# mask = (label > 0).to(torch.float32)
|
65 |
-
return img, label, image_id
|
66 |
-
|
67 |
-
def __len__(self):
|
68 |
-
return len(self.files)
|
69 |
-
|
70 |
-
classes = ['road', 'building', 'vegetation']
|
71 |
-
|
72 |
-
|
73 |
-
class PotsdamRaw(Dataset):
|
74 |
-
def __init__(self, root, image_set, transform, target_transform, coarse_labels):
|
75 |
-
super(PotsdamRaw, self).__init__()
|
76 |
-
self.split = image_set
|
77 |
-
self.root = os.path.join(root, "potsdamraw", "processed")
|
78 |
-
self.transform = transform
|
79 |
-
self.target_transform = target_transform
|
80 |
-
self.files = []
|
81 |
-
for im_num in range(38):
|
82 |
-
for i_h in range(15):
|
83 |
-
for i_w in range(15):
|
84 |
-
self.files.append("{}_{}_{}.mat".format(im_num, i_h, i_w))
|
85 |
-
|
86 |
-
self.coarse_labels = coarse_labels
|
87 |
-
self.fine_to_coarse = {0: 0, 4: 0, # roads and cars
|
88 |
-
1: 1, 5: 1, # buildings and clutter
|
89 |
-
2: 2, 3: 2, # vegetation and trees
|
90 |
-
255: -1
|
91 |
-
}
|
92 |
-
|
93 |
-
def __getitem__(self, index):
|
94 |
-
image_id = self.files[index]
|
95 |
-
img = loadmat(join(self.root, "imgs", image_id))["img"]
|
96 |
-
img = to_pil_image(torch.from_numpy(img).permute(2, 0, 1)[:3]) # TODO add ir channel back
|
97 |
-
try:
|
98 |
-
label = loadmat(join(self.root, "gt", image_id))["gt"]
|
99 |
-
label = to_pil_image(torch.from_numpy(label).unsqueeze(-1).permute(2, 0, 1))
|
100 |
-
except FileNotFoundError:
|
101 |
-
label = to_pil_image(torch.ones(1, img.height, img.width))
|
102 |
-
|
103 |
-
seed = np.random.randint(2147483647)
|
104 |
-
random.seed(seed)
|
105 |
-
torch.manual_seed(seed)
|
106 |
-
img = self.transform(img)
|
107 |
-
|
108 |
-
random.seed(seed)
|
109 |
-
torch.manual_seed(seed)
|
110 |
-
label = self.target_transform(label).squeeze(0)
|
111 |
-
if self.coarse_labels:
|
112 |
-
new_label_map = torch.zeros_like(label)
|
113 |
-
for fine, coarse in self.fine_to_coarse.items():
|
114 |
-
new_label_map[label == fine] = coarse
|
115 |
-
label = new_label_map
|
116 |
-
|
117 |
-
mask = (label > 0).to(torch.float32)
|
118 |
-
return img, label, mask
|
119 |
-
|
120 |
-
def __len__(self):
|
121 |
-
return len(self.files)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
datasets/precomputed.py
DELETED
@@ -1,43 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from PIL import Image
|
3 |
-
from torch.utils.data import Dataset
|
4 |
-
|
5 |
-
|
6 |
-
class PrecomputedDataset(Dataset):
|
7 |
-
def __init__(self,
|
8 |
-
root,
|
9 |
-
transforms,
|
10 |
-
student_augs,
|
11 |
-
):
|
12 |
-
super(PrecomputedDataset, self).__init__()
|
13 |
-
self.root = root
|
14 |
-
self.transforms = transforms
|
15 |
-
self.student_augs = student_augs
|
16 |
-
|
17 |
-
self.image_files = []
|
18 |
-
self.label_files = []
|
19 |
-
self.pseudo_files = []
|
20 |
-
for file in os.listdir(os.path.join(self.root, 'imgs')):
|
21 |
-
self.image_files.append(os.path.join(self.root, 'imgs', file))
|
22 |
-
self.label_files.append(os.path.join(self.root, 'gts', file))
|
23 |
-
self.pseudo_files.append(os.path.join(self.root, 'pseudos', file))
|
24 |
-
|
25 |
-
|
26 |
-
def __getitem__(self, index):
|
27 |
-
image_path = self.image_files[index]
|
28 |
-
label_path = self.label_files[index]
|
29 |
-
pseudo_path = self.pseudo_files[index]
|
30 |
-
|
31 |
-
img = Image.open(image_path).convert("RGB")
|
32 |
-
label = Image.open(label_path)
|
33 |
-
pseudo = Image.open(pseudo_path)
|
34 |
-
|
35 |
-
if self.student_augs:
|
36 |
-
img, label, aimg, pseudo = self.transforms(img, label, pseudo)
|
37 |
-
return img, label.long(), aimg, pseudo.long()
|
38 |
-
else:
|
39 |
-
img, label, pseudo = self.transforms(img, label, pseudo)
|
40 |
-
return img, label.long(), pseudo.long()
|
41 |
-
|
42 |
-
def __len__(self):
|
43 |
-
return len(self.image_files)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|