File size: 3,425 Bytes
da2e2ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import numpy as np
import PIL.Image as Image

class LiDAR2Depth(object):
    """Build sparse ground-truth depth maps by projecting LiDAR into cameras.

    Used as a pipeline transform: ``__call__(features, targets)`` projects the
    last LiDAR sweep of ``features['lidars_warped']`` into every camera and
    stores the stacked per-camera depth maps in ``features['gt_depth']``.
    """

    def __init__(self,
                 grid_config,
                 ):
        """Store the grid ranges.

        Args:
            grid_config: mapping with keys ``'x'``, ``'y'``, ``'z'`` and
                ``'depth'``. Each value is either a Python expression string
                (evaluated with ``eval``) or an already-parsed sequence, which
                is used as-is. Only ``self.depth[0]`` (inclusive lower bound)
                and ``self.depth[1]`` (exclusive upper bound) are read by
                this class.
        """
        self.x = self._parse_config_value(grid_config['x'])
        self.y = self._parse_config_value(grid_config['y'])
        self.z = self._parse_config_value(grid_config['z'])
        self.depth = self._parse_config_value(grid_config['depth'])

    @staticmethod
    def _parse_config_value(value):
        """Evaluate *value* if it is a string; otherwise return it unchanged."""
        if isinstance(value, str):
            # NOTE(security): eval executes arbitrary code. grid_config must
            # come from a trusted config file, never from external input.
            return eval(value)
        return value

    def points2depthmap(self, points, height, width):
        """Rasterize projected points into a (height, width) depth map.

        Args:
            points: tensor of shape (P, >=3); columns 0, 1 and 2 are the image
                column (u), image row (v) and depth of each point.
            height: number of rows of the output map.
            width: number of columns of the output map.

        Returns:
            float32 tensor of shape (height, width) on ``points.device``. Each
            pixel holds the smallest depth among the points whose rounded
            (u, v) lands on it, or 0 where no point lands. Points outside the
            image, with depth outside ``[self.depth[0], self.depth[1])`` or
            with non-positive depth are dropped.
        """
        depth_map = torch.zeros((height, width), dtype=torch.float32,
                                device=points.device)
        coor = torch.round(points[:, :2])
        depth = points[:, 2]
        # NaN/inf coordinates (e.g. points on the camera plane) fail every
        # comparison below and are therefore filtered out before the cast to
        # long. depth > 0 rejects points behind the camera: 0 is the map's
        # "empty" value, so non-positive depths cannot be stored anyway.
        kept = ((coor[:, 0] >= 0) & (coor[:, 0] < width)
                & (coor[:, 1] >= 0) & (coor[:, 1] < height)
                & (depth < self.depth[1]) & (depth >= self.depth[0])
                & (depth > 0))
        coor = coor[kept].to(torch.long)
        depth = depth[kept]
        if depth.numel() == 0:
            return depth_map

        # Exact lexicographic sort by (pixel rank, depth): order by depth
        # first, then stable-sort by integer rank so that points sharing a
        # pixel stay nearest-first. A combined float key such as
        # ``rank + depth / 100`` loses the depth tie-breaker to float32
        # rounding at large ranks and bleeds into the next pixel once
        # depth >= 100.
        ranks = coor[:, 0] + coor[:, 1] * width
        order = torch.argsort(depth)
        order = order[torch.sort(ranks[order], stable=True)[1]]
        coor, depth, ranks = coor[order], depth[order], ranks[order]

        # Keep only the first (nearest) point of each pixel so the final
        # index assignment never sees duplicate indices.
        first = torch.ones_like(ranks, dtype=torch.bool)
        first[1:] = ranks[1:] != ranks[:-1]
        coor, depth = coor[first], depth[first]
        depth_map[coor[:, 1], coor[:, 0]] = depth.to(depth_map.dtype)
        return depth_map

    def __call__(self, features, targets):
        """Attach per-camera sparse LiDAR depth maps as ``features['gt_depth']``.

        Keys read from ``features``:
            lidars_warped: per-frame sequence of point tensors; only the last
                frame's first three columns (x, y, z) are used.
            image: tensor unpacked as (T, N_CAMS, C, H, W); only N_CAMS is used.
            sensor2lidar_rotation, sensor2lidar_translation, intrinsics,
            post_rot, post_tran: per-frame, per-camera calibration indexed as
                ``[t, n]`` (3x3 matrices / 3-vectors).
            bda: 4x4 matrix whose effect on the points is undone here
                (rotation/scale block ``[:3, :3]``, translation ``[:3, 3]``).
            canvas: per-frame, per-camera images; ``canvas[t, n].shape[0]`` and
                ``shape[1]`` give each output map's height and width.

        Returns:
            ``(features, targets)`` with ``features['gt_depth']`` set to a
            tensor of shape (N_CAMS, height, width).
        """
        t = -1  # only the most recent frame is projected
        num_cams = features['image'].shape[1]
        rots = features['sensor2lidar_rotation']
        trans = features['sensor2lidar_translation']
        intrinsics = features['intrinsics']
        post_rot = features['post_rot']
        post_tran = features['post_tran']
        bda = features['bda']

        # Undo the bda transform: p_orig = bda_R^-1 (p - bda_t).
        lidar_t = features['lidars_warped'][t][:, :3]
        lidar_t = lidar_t - bda[:3, 3].view(1, 3)
        lidar_t = lidar_t.matmul(torch.inverse(bda[:3, :3]).T)

        depth_t = []
        for n in range(num_cams):
            # lidar -> camera: invert the sensor2lidar transform.
            points_img = lidar_t - trans[t, n:n + 1, :]
            points_img = points_img.matmul(torch.inverse(rots[t, n]).T)
            # camera -> image plane, keeping depth as the third column.
            points_img = points_img.matmul(intrinsics[t, n].T)
            points_img = torch.cat(
                [points_img[:, :2] / points_img[:, 2:3], points_img[:, 2:3]],
                1)
            # Apply the image augmentation (post_rot / post_tran).
            points_img = (points_img.matmul(post_rot[t, n].T)
                          + post_tran[t, n:n + 1, :])
            canvas = features['canvas'][t, n]
            depth_t.append(
                self.points2depthmap(points_img, canvas.shape[0],
                                     canvas.shape[1]))
        features['gt_depth'] = torch.stack(depth_t)
        return features, targets