from __future__ import annotations from typing import Dict import numpy as np import torch from det_map.data.datasets.dataclasses import AgentInput, Camera from det_map.data.datasets.lidar_utils import transform_points, render_image from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder from mmcv.parallel import DataContainer as DC class LiDARCameraFeatureBuilder(AbstractFeatureBuilder): def __init__(self, pipelines): super().__init__() self.pipelines = pipelines def compute_features(self, agent_input: AgentInput) -> Dict[str, torch.Tensor]: img_pipeline = self.pipelines['img'] timestamps_ori = agent_input.timestamps timestamps = [(timestamps_ori[-1] - tmp) / 1e6 for tmp in timestamps_ori] lidars = [np.copy(tmp.lidar_pc) for tmp in agent_input.lidars] ego2globals = [tmp for tmp in agent_input.ego2globals] # last frame is the key frame global2ego_key = np.linalg.inv(ego2globals[-1]) # ego2global, global2ego key frame lidars_warped = [transform_points(transform_points(pts, mat), global2ego_key) for pts, mat in zip(lidars[:-1], ego2globals[:-1])] lidars_warped.append(lidars[-1]) for i, l in enumerate(lidars_warped): # x,y,z,intensity,timestamp l[4] = timestamps[i] lidars_warped[i] = torch.from_numpy(l[:5]).t() # debug visualize lidar pc # for idx, lidar in enumerate(lidars_warped): # render_image(lidar, str('warped'+ str(idx))) # for idx, lidar in enumerate([tmp.lidar_pc for tmp in agent_input.lidars]): # render_image(lidar, str('ori'+ str(idx))) cams_all_frames = [[ tmp.cam_f0, # tmp.cam_l0, # tmp.cam_l1, # tmp.cam_l2, # tmp.cam_r0, # tmp.cam_r1, # tmp.cam_r2, tmp.cam_b0 ] for tmp in agent_input.cameras] image, canvas, sensor2lidar_rotation, sensor2lidar_translation, intrinsics, distortion, post_rot, post_tran = [], [], [], [], [], [], [], [] for cams_frame_t in cams_all_frames: image_t, canvas_t, sensor2lidar_rotation_t, sensor2lidar_translation_t, intrinsics_t, distortion_t, post_rot_t, post_tran_t = [], [], [], [], [], [], [], [] for cam in cams_frame_t: cam_processed: Camera = img_pipeline(cam) image_t.append(cam_processed.image) canvas_t.append(cam_processed.canvas) sensor2lidar_rotation_t.append(cam_processed.sensor2lidar_rotation) sensor2lidar_translation_t.append(cam_processed.sensor2lidar_translation) intrinsics_t.append(cam_processed.intrinsics) distortion_t.append(cam_processed.distortion) post_rot_t.append(cam_processed.post_rot) post_tran_t.append(cam_processed.post_tran) image.append(torch.stack(image_t)) canvas.append(torch.stack(canvas_t)) sensor2lidar_rotation.append(torch.stack(sensor2lidar_rotation_t)) sensor2lidar_translation.append(torch.stack(sensor2lidar_translation_t)) intrinsics.append(torch.stack(intrinsics_t)) distortion.append(torch.stack(distortion_t)) post_rot.append(torch.stack(post_rot_t)) post_tran.append(torch.stack(post_tran_t)) # img: T, N_CAM, C, H, W # imgs = DC(torch.stack(image), cpu_only=False, stack=True) #combine = torch.matmul(sensor2lidar_rotation, torch.inverse(intrinsics)) #coords = torch.matmul(combine, coords) #coords += sensor2lidar_translation imgs = torch.stack(image) return { "image": imgs, 'canvas': torch.stack(canvas).to(imgs), 'sensor2lidar_rotation': torch.stack(sensor2lidar_rotation).to(imgs), 'sensor2lidar_translation': torch.stack(sensor2lidar_translation).to(imgs), 'intrinsics': torch.stack(intrinsics).to(imgs), 'distortion': torch.stack(distortion).to(imgs), 'post_rot': torch.stack(post_rot).to(imgs), 'post_tran': torch.stack(post_tran).to(imgs), "lidars_warped": lidars_warped }