import os
import glob
import argparse

import numpy as np
import laspy
import torch
import h5py
import hydra
import torch.nn.functional as F

from src.utils import init_config
from src.transforms import (
    instantiate_datamodule_transforms,
    SampleRecursiveMainXYAxisTiling,
    NAGRemoveKeys,
)
from src.datasets.gridnet import read_gridnet_tile


# Suffixes of the files this script writes next to its inputs. They also end
# in ".las", so the input globs below would otherwise pick them up again.
_CLASSIFIED_SUFFIX = '_classified.las'
_SOFTMAX_SUFFIX = '_with_softmax.las'


def _is_generated_output(filepath):
    """Return True if *filepath* is named like a file written by this script."""
    name = os.path.basename(filepath)
    return name.endswith((_CLASSIFIED_SUFFIX, _SOFTMAX_SUFFIX))


def _output_path(filepath, suffix):
    """Return *filepath* with its extension replaced by *suffix*.

    Only the final extension is touched, so a ".las" occurring elsewhere in
    the path (e.g. in a directory name) is left alone and the result can
    never be equal to the input path.
    """
    return os.path.splitext(filepath)[0] + suffix


def _read_tile(filepath):
    """Load one LAS tile and the offset stored in its LAS header.

    Returns:
        A ``(data, las_offset)`` pair: the object returned by
        ``read_gridnet_tile`` (xyz, intensity and rgb requested, no labels)
        and the header offset as a float64 NumPy array.
    """
    # This read is only used to recover the header offset of the source file.
    las_offset = np.array(laspy.read(filepath).header.offset, dtype=np.float64)
    data = read_gridnet_tile(
        filepath,
        xyz=True,
        intensity=True,
        rgb=True,
        semantic=False,
        instance=False,
        remap=True,
    )
    return data, las_offset


def _forward_tile(model, cfg, transforms_dict, data):
    """Run the pre-transform, key filtering, device transform and model on one tile.

    Returns:
        A ``(nag, output)`` pair: the transformed NAG that was fed to the
        model (on GPU) and the model output.
    """
    nag = transforms_dict['pre_transform'](data)

    # Keep only the attributes the datamodule config lists as loadable.
    point_keys = [
        k for k in nag[0].keys if k not in cfg.datamodule.point_load_keys
    ]
    nag = NAGRemoveKeys(level=0, keys=point_keys)(nag)
    segment_keys = [
        k for k in nag[1].keys if k not in cfg.datamodule.segment_load_keys
    ]
    nag = NAGRemoveKeys(level='1+', keys=segment_keys)(nag)

    nag = nag.cuda()
    nag = transforms_dict['on_device_test_transform'](nag)
    with torch.no_grad():
        output = model(nag)
    return nag, output


def _build_las(positions, scale, offsets):
    """Create a LAS 1.2 / point-format-3 container holding *positions*.

    Args:
        positions: CPU tensor of point coordinates, one row per point with
            x, y, z in the first three columns.
        scale: Per-axis LAS scale factors.
        offsets: Per-axis LAS offsets written to the header.

    Returns:
        A ``laspy.LasData`` whose integer X/Y/Z fields are set.
    """
    # Round to the nearest grid step. A plain integer cast truncates toward
    # zero, so float noise such as 1233.99997 would land one step too low.
    grid = np.round(positions.numpy() / scale).astype(int)

    header = laspy.LasHeader(point_format=3, version="1.2")
    header.scales = scale
    header.offsets = offsets

    las = laspy.LasData(header)
    las.X, las.Y, las.Z = grid[:, 0], grid[:, 1], grid[:, 2]
    return las


def run_inference(model, cfg, transforms_dict, root_dir, split, scale, pc_tiling):
    """Predict a semantic class for every point and save it as a LAS file.

    Every ``<root_dir>/<split>/*/lidar/*.las`` file is processed in
    ``2**pc_tiling`` tiles. Full-resolution predictions of all tiles are
    merged, put back in the point order of the loaded file and written next
    to the input as ``<name>_classified.las`` with an extra ``classif``
    uint8 dimension.

    Args:
        model: Model called on each transformed tile.
        cfg: Config providing ``datamodule.point_load_keys`` and
            ``datamodule.segment_load_keys``.
        transforms_dict: Mapping with the ``'pre_transform'`` and
            ``'on_device_test_transform'`` callables.
        root_dir: Dataset root directory.
        split: Name of the split sub-directory to process.
        scale: Per-axis LAS scale factors for the output file.
        pc_tiling: Number of recursive tiling steps.
    """
    split_dir = os.path.join(root_dir, split)
    las_files = glob.glob(os.path.join(split_dir, "*", "lidar", "*.las"))

    for filepath in las_files:
        if _is_generated_output(filepath):
            # Written by a previous run of this script, not a dataset tile.
            print(f"[Inference] Skipping generated file: {filepath}")
            continue

        print(f"\n[Inference] Processing: {filepath}")
        data_las, offset_initial_las = _read_tile(filepath)
        # Remember each point's position in the file so the initial order can
        # be restored once the tiles are merged.
        data_las.initial_index = torch.arange(data_las.pos.shape[0])

        pos_list = []
        pred_list = []
        indices_list = []
        pos_offset_init = None

        for x in range(2 ** pc_tiling):
            data = SampleRecursiveMainXYAxisTiling(x=x, steps=pc_tiling)(data_las)
            nag, output = _forward_tile(model, cfg, transforms_dict, data)

            # Full-resolution prediction. The voxel-level alternative is
            # output.voxel_semantic_pred(super_index=nag[0].super_index),
            # paired with nag[0].pos instead of data.pos.
            semantic_pred = output.full_res_semantic_pred(
                super_index_level0_to_level1=nag[0].super_index,
                sub_level0_to_raw=nag[0].sub,
            )

            pos_list.append(data.pos.cpu())
            indices_list.append(data.initial_index.cpu())
            pred_list.append(semantic_pred.cpu())

            if pos_offset_init is None:
                pos_offset_init = nag[0].pos_offset.cpu()

        merged_pos = torch.cat(pos_list, dim=0)
        merged_pred = torch.cat(pred_list, dim=0)
        merged_pos_offset = pos_offset_init + offset_initial_las

        # Only valid for the full-resolution cloud: restore the initial
        # order of the points.
        merged_indices = torch.cat(indices_list, dim=0)
        sorted_indices = torch.argsort(merged_indices)
        merged_pos = merged_pos[sorted_indices]
        merged_pred = merged_pred[sorted_indices]

        las = _build_las(merged_pos, scale, merged_pos_offset)
        las.add_extra_dim(
            laspy.ExtraBytesParams(
                name="classif", type=np.uint8, description="Predicted class"
            )
        )
        las.classif = merged_pred.numpy().astype(np.uint8)

        output_las = _output_path(filepath, _CLASSIFIED_SUFFIX)
        las.write(output_las)
        print(f"[Inference] Saved classified LAS to: {output_las}")


def export_logits(model, cfg, transforms_dict, root_dir, scale, pc_tiling):
    """Export voxel-level softmax scores of every tile as a LAS file.

    Every ``<root_dir>/*/*/*/*.las`` file is processed in ``2**pc_tiling``
    tiles. Voxel-level logits of all tiles are merged, passed through a
    softmax over the class axis, scaled to 0-255 and stored as one uint8
    extra dimension per class (``sof_log0``, ``sof_log1``, ...) in
    ``<name>_with_softmax.las`` next to the input.

    Args:
        model: Model called on each transformed tile.
        cfg: Config providing ``datamodule.point_load_keys`` and
            ``datamodule.segment_load_keys``.
        transforms_dict: Mapping with the ``'pre_transform'`` and
            ``'on_device_test_transform'`` callables.
        root_dir: Dataset root directory.
        scale: Per-axis LAS scale factors for the output file.
        pc_tiling: Number of recursive tiling steps.
    """
    las_files = glob.glob(os.path.join(root_dir, "*", "*", "*", "*.las"))

    for filepath in las_files:
        if _is_generated_output(filepath):
            # Written by a previous run of this script, not a dataset tile.
            print(f"[Export Logits] Skipping generated file: {filepath}")
            continue

        print(f"\n[Export Logits] Processing: {filepath}")
        data_las, offset_initial_las = _read_tile(filepath)

        pos_list = []
        logits_list = []
        pos_offset_init = None

        for x in range(2 ** pc_tiling):
            data = SampleRecursiveMainXYAxisTiling(x=x, steps=pc_tiling)(data_las)
            nag, output = _forward_tile(model, cfg, transforms_dict, data)

            logits = output.voxel_logits_pred(super_index=nag[0].super_index)
            pos_list.append(nag[0].pos.cpu())
            logits_list.append(logits.cpu())

            if pos_offset_init is None:
                pos_offset_init = nag[0].pos_offset.cpu()

        merged_pos = torch.cat(pos_list, dim=0)
        merged_logits = torch.cat(logits_list, dim=0)
        merged_pos_offset = pos_offset_init + offset_initial_las

        las = _build_las(merged_pos, scale, merged_pos_offset)

        soft_logits = F.softmax(merged_logits, dim=1).numpy()
        for i in range(soft_logits.shape[1]):
            scaled_logits = (255 * soft_logits[:, i]).clip(0, 255).astype(np.uint8)
            las.add_extra_dim(
                laspy.ExtraBytesParams(
                    name=f"sof_log{i}", type=np.uint8, description=f"Logit {i}"
                )
            )
            setattr(las, f"sof_log{i}", scaled_logits[:])

        output_las = _output_path(filepath, _SOFTMAX_SUFFIX)
        las.write(output_las)
        print(f"[Export Logits] Saved softmax LAS to: {output_las}")


def main():
    """Parse the command line, load the model and run the selected mode."""
    parser = argparse.ArgumentParser(description="SPT Inference and Logits Export")
    parser.add_argument('--mode', choices=['inference', 'export_log'], required=True,
                        help="Choose between full-resolution inference or export logits")
    parser.add_argument('--split', type=str, default='test',
                        help="Data split to process (only used in inference mode) test or val split")
    parser.add_argument('--weights', type=str, required=True,
                        help="Path to model checkpoint")
    parser.add_argument('--root_dir', type=str, required=True,
                        help="Root directory of the dataset")
    # Parsed as int by argparse so that a non-numeric value is reported as a
    # usage error instead of a traceback later on.
    parser.add_argument('--pc_tiling', type=int, default=3,
                        help="PC tiling for point cloud sampling")
    args = parser.parse_args()

    cfg = init_config(overrides=["experiment=semantic/gridnet"])
    transforms_dict = instantiate_datamodule_transforms(cfg.datamodule)

    model = hydra.utils.instantiate(cfg.model)
    model = model._load_from_checkpoint(args.weights)
    model = model.eval().cuda()

    # LAS scale factors used for every written file (x, y, z).
    SCALE = [0.001, 0.001, 0.001]
    pc_tiling = args.pc_tiling

    if args.mode == 'inference':
        run_inference(model, cfg, transforms_dict, args.root_dir, args.split, SCALE, pc_tiling)
    elif args.mode == 'export_log':
        export_logits(model, cfg, transforms_dict, args.root_dir, SCALE, pc_tiling)


if __name__ == '__main__':
    main()