hank1996 commited on
Commit
1b25981
·
1 Parent(s): 55e7537

Create new file

Browse files
Files changed (1) hide show
  1. lib/dataset/bdd.py +85 -0
lib/dataset/bdd.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import numpy as np
4
+ import json
5
+
6
+ from .AutoDriveDataset import AutoDriveDataset
7
+ from .convert import convert, id_dict, id_dict_single
8
+ from tqdm import tqdm
9
+
10
+ single_cls = True # just detect vehicle
11
+
12
+ class BddDataset(AutoDriveDataset):
13
+ def __init__(self, cfg, is_train, inputsize, transform=None):
14
+ super().__init__(cfg, is_train, inputsize, transform)
15
+ self.db = self._get_db()
16
+ self.cfg = cfg
17
+
18
+ def _get_db(self):
19
+ """
20
+ get database from the annotation file
21
+ Inputs:
22
+ Returns:
23
+ gt_db: (list)database [a,b,c,...]
24
+ a: (dictionary){'image':, 'information':, ......}
25
+ image: image path
26
+ mask: path of the segmetation label
27
+ label: [cls_id, center_x//256, center_y//256, w//256, h//256] 256=IMAGE_SIZE
28
+ """
29
+ print('building database...')
30
+ gt_db = []
31
+ height, width = self.shapes
32
+ for mask in tqdm(list(self.mask_list)):
33
+ mask_path = str(mask)
34
+ label_path = mask_path.replace(str(self.mask_root), str(self.label_root)).replace(".png", ".json")
35
+ image_path = mask_path.replace(str(self.mask_root), str(self.img_root)).replace(".png", ".jpg")
36
+ lane_path = mask_path.replace(str(self.mask_root), str(self.lane_root))
37
+ with open(label_path, 'r') as f:
38
+ label = json.load(f)
39
+ data = label['frames'][0]['objects']
40
+ data = self.filter_data(data)
41
+ gt = np.zeros((len(data), 5))
42
+ for idx, obj in enumerate(data):
43
+ category = obj['category']
44
+ if category == "traffic light":
45
+ color = obj['attributes']['trafficLightColor']
46
+ category = "tl_" + color
47
+ if category in id_dict.keys():
48
+ x1 = float(obj['box2d']['x1'])
49
+ y1 = float(obj['box2d']['y1'])
50
+ x2 = float(obj['box2d']['x2'])
51
+ y2 = float(obj['box2d']['y2'])
52
+ cls_id = id_dict[category]
53
+ if single_cls:
54
+ cls_id=0
55
+ gt[idx][0] = cls_id
56
+ box = convert((width, height), (x1, x2, y1, y2))
57
+ gt[idx][1:] = list(box)
58
+
59
+
60
+ rec = [{
61
+ 'image': image_path,
62
+ 'label': gt,
63
+ 'mask': mask_path,
64
+ 'lane': lane_path
65
+ }]
66
+
67
+ gt_db += rec
68
+ print('database build finish')
69
+ return gt_db
70
+
71
+ def filter_data(self, data):
72
+ remain = []
73
+ for obj in data:
74
+ if 'box2d' in obj.keys(): # obj.has_key('box2d'):
75
+ if single_cls:
76
+ if obj['category'] in id_dict_single.keys():
77
+ remain.append(obj)
78
+ else:
79
+ remain.append(obj)
80
+ return remain
81
+
82
+ def evaluate(self, cfg, preds, output_dir, *args, **kwargs):
83
+ """
84
+ """
85
+ pass