import os
import glob
import argparse

import numpy as np
import laspy
import torch
import torch.nn.functional as F
import hydra

from src.utils import init_config
from src.transforms import (
    instantiate_datamodule_transforms,
    SampleRecursiveMainXYAxisTiling,
    NAGRemoveKeys
)
from src.datasets.gridnet import read_gridnet_tile


def run_inference(model, cfg, transforms_dict, root_dir, split, scale, pc_tiling):
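    """Run full-resolution semantic segmentation on every LAS tile of a split
    and write a copy of each tile with a `classif` extra dimension holding the
    predicted class of each point.
    """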
    split_dir = os.path.join(root_dir, split)
    las_files = glob.glob(os.path.join(split_dir, "*", "lidar", "*.las"))
    for filepath in las_files:
        print(f"\n[Inference] Processing: {filepath}")

        # Keep the original LAS header offset so the output can be written
        # back in the tile's native coordinate frame
        data_las = laspy.read(filepath)
        offset_initial_las = np.array(data_las.header.offset, dtype=np.float64)

        data_las = read_gridnet_tile(
            filepath, xyz=True, intensity=True, rgb=True,
            semantic=False, instance=False, remap=True
        )

        # Remember each point's original index so tiled predictions can be
        # restored to the input point order after inference
        data_las.initial_index = torch.arange(data_las.pos.shape[0])

        pos_list = []
        pred_list = []
        indices_list = []
        pos_offset_init = None

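        # Recursively split the tile into 2**pc_tiling chunks along its main
        # XY axis so each chunk can be processed on the GPU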
        for x in range(2**pc_tiling):
            data = SampleRecursiveMainXYAxisTiling(x=x, steps=pc_tiling)(data_las)
            nag = transforms_dict['pre_transform'](data)
            nag = NAGRemoveKeys(level=0, keys=[
                k for k in nag[0].keys if k not in cfg.datamodule.point_load_keys
            ])(nag)
            nag = NAGRemoveKeys(level='1+', keys=[
                k for k in nag[1].keys if k not in cfg.datamodule.segment_load_keys
            ])(nag)
            nag = nag.cuda()
            nag = transforms_dict['on_device_test_transform'](nag)
            with torch.no_grad():
                output = model(nag)

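            # Map the superpoint-level predictions back to every raw input
            # point (full resolution)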
            semantic_pred = output.full_res_semantic_pred(
                super_index_level0_to_level1=nag[0].super_index,
                sub_level0_to_raw=nag[0].sub
            )
            pos_list.append(data.pos.cpu())
            indices_list.append(data.initial_index.cpu())
            pred_list.append(semantic_pred.cpu())

            if pos_offset_init is None:
                pos_offset_init = nag[0].pos_offset.cpu()

        merged_pos = torch.cat(pos_list, dim=0)
        merged_pred = torch.cat(pred_list, dim=0)
        merged_pos_offset = pos_offset_init + offset_initial_las

        # Restore the original point order across the tiles
        merged_indices = torch.cat(indices_list, dim=0)
        sorted_indices = torch.argsort(merged_indices)
        merged_pos = merged_pos[sorted_indices]
        merged_pred = merged_pred[sorted_indices]

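        # LAS stores coordinates as scaled integers: X = (pos - offset) / scale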
        pos_data = (merged_pos.numpy() / scale).astype(int)
        x, y, z = pos_data[:, 0], pos_data[:, 1], pos_data[:, 2]

        header = laspy.LasHeader(point_format=3, version="1.2")
        header.scales = scale
        header.offsets = merged_pos_offset
        las = laspy.LasData(header)
        las.X, las.Y, las.Z = x, y, z

        las.add_extra_dim(
            laspy.ExtraBytesParams(name="classif", type=np.uint8, description="Predicted class")
        )
        las.classif = merged_pred.numpy().astype(np.uint8)

        output_las = filepath.replace('.las', '_classified.las')
        las.write(output_las)
        print(f"[Inference] Saved classified LAS to: {output_las}")


def export_logits(model, cfg, transforms_dict, root_dir, scale, pc_tiling):
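    """Export per-class softmax scores for every LAS tile under `root_dir`,
    writing one `sof_log{i}` uint8 extra dimension per class (softmax scaled
    to 0-255) at voxel resolution.
    """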
    las_files = glob.glob(os.path.join(root_dir, "*", "*", "*", "*.las"))
    for filepath in las_files:
        print(f"\n[Export Logits] Processing: {filepath}")
        data_las = laspy.read(filepath)
        offset_initial_las = np.array(data_las.header.offset, dtype=np.float64)

        data_las = read_gridnet_tile(
            filepath, xyz=True, intensity=True, rgb=True,
            semantic=False, instance=False, remap=True
        )

        pos_list = []
        logits_list = []
        pos_offset_init = None

        for x in range(2**pc_tiling):
            data = SampleRecursiveMainXYAxisTiling(x=x, steps=pc_tiling)(data_las)
            nag = transforms_dict['pre_transform'](data)
            nag = NAGRemoveKeys(level=0, keys=[
                k for k in nag[0].keys if k not in cfg.datamodule.point_load_keys
            ])(nag)
            nag = NAGRemoveKeys(level='1+', keys=[
                k for k in nag[1].keys if k not in cfg.datamodule.segment_load_keys
            ])(nag)
            nag = nag.cuda()
            nag = transforms_dict['on_device_test_transform'](nag)

            with torch.no_grad():
                output = model(nag)

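            # Class logits at voxel (level-0) resolution, one row per point in
            # nag[0]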
            logits = output.voxel_logits_pred(super_index=nag[0].super_index)

            pos_list.append(nag[0].pos.cpu())
            logits_list.append(logits.cpu())

            if pos_offset_init is None:
                pos_offset_init = nag[0].pos_offset.cpu()

        merged_pos = torch.cat(pos_list, dim=0)
        merged_logits = torch.cat(logits_list, dim=0)
        merged_pos_offset = pos_offset_init + offset_initial_las

        pos_data = (merged_pos.numpy() / scale).astype(int)
        x, y, z = pos_data[:, 0], pos_data[:, 1], pos_data[:, 2]
        logits = merged_logits.numpy()

        header = laspy.LasHeader(point_format=3, version="1.2")
        header.scales = scale
        header.offsets = merged_pos_offset
        las = laspy.LasData(header)
        las.X, las.Y, las.Z = x, y, z

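        # Convert logits to softmax probabilities and store each class as a
        # uint8 extra dimension scaled to [0, 255]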
        soft_logits = F.softmax(torch.from_numpy(logits), dim=1).numpy()
        for i in range(soft_logits.shape[1]):
            scaled_logits = (255 * soft_logits[:, i]).clip(0, 255).astype(np.uint8)
            las.add_extra_dim(
                laspy.ExtraBytesParams(name=f"sof_log{i}", type=np.uint8, description=f"Softmax class {i}")
            )
            setattr(las, f"sof_log{i}", scaled_logits)

        output_las = filepath.replace('.las', '_with_softmax.las')
        las.write(output_las)
        print(f"[Export Logits] Saved softmax LAS to: {output_las}")


def main():
    parser = argparse.ArgumentParser(description="SPT Inference and Logits Export")
    parser.add_argument('--mode', choices=['inference', 'export_log'], required=True,
                        help="Choose between full-resolution inference and logits export")
    parser.add_argument('--split', type=str, default='test',
                        help="Data split to process, 'test' or 'val' (only used in inference mode)")
    parser.add_argument('--weights', type=str, required=True, help="Path to model checkpoint")
    parser.add_argument('--root_dir', type=str, required=True, help="Root directory of the dataset")
    parser.add_argument('--pc_tiling', type=int, default=3,
                        help="Number of recursive tiling steps for point cloud sampling")
    args = parser.parse_args()

    cfg = init_config(overrides=["experiment=semantic/gridnet"])
    transforms_dict = instantiate_datamodule_transforms(cfg.datamodule)

    model = hydra.utils.instantiate(cfg.model)
    model = model._load_from_checkpoint(args.weights)
    model = model.eval().cuda()

    # LAS coordinate scale of 0.001 m, i.e. millimeter precision
    SCALE = [0.001, 0.001, 0.001]
    pc_tiling = args.pc_tiling

    if args.mode == 'inference':
        run_inference(model, cfg, transforms_dict, args.root_dir, args.split, SCALE, pc_tiling)
    elif args.mode == 'export_log':
        export_logits(model, cfg, transforms_dict, args.root_dir, SCALE, pc_tiling)


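# Example invocations (script name, paths, and checkpoint are illustrative placeholders):
#   python spt_inference.py --mode inference --split test \
#       --weights /path/to/checkpoint.ckpt --root_dir /path/to/gridnet_dataset --pc_tiling 3
#   python spt_inference.py --mode export_log \
#       --weights /path/to/checkpoint.ckpt --root_dir /path/to/gridnet_dataset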
if __name__ == '__main__':
    main()