# SPT_GridNet-HD_baseline / inference.py
# Uploaded via huggingface_hub (commit c93b496, verified) — author: Shanci
import os
import glob
import argparse
import numpy as np
import laspy
import torch
import h5py
import hydra
import torch.nn.functional as F
from src.utils import init_config
from src.transforms import (
instantiate_datamodule_transforms,
SampleRecursiveMainXYAxisTiling,
NAGRemoveKeys
)
from src.datasets.gridnet import read_gridnet_tile
def run_inference(model, cfg, transforms_dict, root_dir, split, scale, pc_tiling):
    """Run full-resolution semantic inference on every LAS tile of a split.

    For each ``*.las`` file under ``root_dir/<split>/*/lidar/``, the tile is
    recursively split into ``2**pc_tiling`` spatial chunks (to bound GPU
    memory), each chunk is pushed through the model, and the per-point class
    predictions are merged back in the original point order and written next
    to the input as ``<name>_classified.las``.

    Args:
        model: Trained SPT model, already in eval mode and on the GPU.
        cfg: Hydra config; only the datamodule load-key lists are read here.
        transforms_dict: Dict exposing 'pre_transform' and
            'on_device_test_transform' callables.
        root_dir: Dataset root directory.
        split: Split sub-directory to process (e.g. 'test' or 'val').
        scale: Per-axis LAS coordinate scale, e.g. ``[0.001, 0.001, 0.001]``.
        pc_tiling: Number of recursive XY tiling steps (2**pc_tiling chunks).
    """
    split_dir = os.path.join(root_dir, split)
    las_files = glob.glob(os.path.join(split_dir, "*", "lidar", "*.las"))
    for filepath in las_files:
        print(f"\n[Inference] Processing: {filepath}")
        data_las = laspy.read(filepath)
        # The LAS header offset is needed to restore absolute coordinates later.
        offset_initial_las = np.array(data_las.header.offset, dtype=np.float64)
        data_las = read_gridnet_tile(
            filepath, xyz=True, intensity=True, rgb=True, semantic=False, instance=False, remap=True
        )
        # Remember the original point order so predictions can be re-sorted
        # after the chunks are concatenated.
        data_las.initial_index = torch.arange(data_las.pos.shape[0])
        pos_list = []
        pred_list = []
        indices_list = []
        pos_offset_init = None
        for tile_idx in range(2 ** pc_tiling):
            data = SampleRecursiveMainXYAxisTiling(x=tile_idx, steps=pc_tiling)(data_las)
            nag = transforms_dict['pre_transform'](data)
            # Drop keys the model was not trained to load.
            nag = NAGRemoveKeys(level=0, keys=[k for k in nag[0].keys if k not in cfg.datamodule.point_load_keys])(nag)
            nag = NAGRemoveKeys(level='1+', keys=[k for k in nag[1].keys if k not in cfg.datamodule.segment_load_keys])(nag)
            nag = nag.cuda()
            nag = transforms_dict['on_device_test_transform'](nag)
            with torch.no_grad():
                output = model(nag)
            # For voxel level
            #semantic_pred = output.voxel_semantic_pred(super_index=nag[0].super_index)
            #pos_list.append(nag[0].pos.cpu())
            # For full resolution level
            semantic_pred = output.full_res_semantic_pred(
                super_index_level0_to_level1=nag[0].super_index, sub_level0_to_raw=nag[0].sub
            )
            pos_list.append(data.pos.cpu())
            indices_list.append(data.initial_index.cpu())
            pred_list.append(semantic_pred.cpu())
            if pos_offset_init is None:
                # Presumably every chunk shares the same level-0 offset, so it
                # is recorded once from the first chunk — TODO confirm.
                pos_offset_init = nag[0].pos_offset.cpu()
        merged_pos = torch.cat(pos_list, dim=0)
        merged_pred = torch.cat(pred_list, dim=0)
        merged_pos_offset = pos_offset_init + offset_initial_las
        # Only meaningful for full-res output: restore the initial point order.
        merged_indices = torch.cat(indices_list, dim=0)
        sorted_indices = torch.argsort(merged_indices)
        merged_pos = merged_pos[sorted_indices]
        merged_pred = merged_pred[sorted_indices]
        # Quantize to integer LAS coordinates. Rounding (not bare astype(int),
        # which truncates toward zero) avoids a systematic bias of up to one
        # scale unit per coordinate.
        pos_data = np.round(merged_pos.numpy() / scale).astype(np.int64)
        x, y, z = pos_data[:, 0], pos_data[:, 1], pos_data[:, 2]
        header = laspy.LasHeader(point_format=3, version="1.2")
        header.scales = scale
        header.offsets = merged_pos_offset
        las = laspy.LasData(header)
        las.X, las.Y, las.Z = x, y, z
        las.add_extra_dim(
            laspy.ExtraBytesParams(name="classif", type=np.uint8, description="Predicted class")
        )
        las.classif = merged_pred.numpy().astype(np.uint8)
        # splitext is safer than str.replace, which would also rewrite a
        # '.las' substring occurring earlier in the path.
        base, ext = os.path.splitext(filepath)
        output_las = base + '_classified' + ext
        las.write(output_las)
        print(f"[Inference] Saved classified LAS to: {output_las}")
def export_logits(model, cfg, transforms_dict, root_dir, scale, pc_tiling):
    """Export per-class softmax scores for every LAS tile under ``root_dir``.

    Each tile is split into ``2**pc_tiling`` spatial chunks, run through the
    model, and the voxel-level logits are converted to softmax probabilities,
    rescaled to uint8 (0-255) and written as one extra dimension per class
    (``sof_log<i>``) into ``<name>_with_softmax.las``. Note the output carries
    voxel-level positions (``nag[0].pos``), not the full-resolution cloud.

    Args:
        model: Trained SPT model, already in eval mode and on the GPU.
        cfg: Hydra config; only the datamodule load-key lists are read here.
        transforms_dict: Dict exposing 'pre_transform' and
            'on_device_test_transform' callables.
        root_dir: Dataset root directory (searched three levels deep for LAS).
        scale: Per-axis LAS coordinate scale, e.g. ``[0.001, 0.001, 0.001]``.
        pc_tiling: Number of recursive XY tiling steps (2**pc_tiling chunks).
    """
    las_files = glob.glob(os.path.join(root_dir, "*", "*", "*", "*.las"))
    for filepath in las_files:
        print(f"\n[Export Logits] Processing: {filepath}")
        data_las = laspy.read(filepath)
        # The LAS header offset is needed to restore absolute coordinates later.
        offset_initial_las = np.array(data_las.header.offset, dtype=np.float64)
        data_las = read_gridnet_tile(
            filepath, xyz=True, intensity=True, rgb=True,
            semantic=False, instance=False, remap=True
        )
        pos_list = []
        logits_list = []
        pos_offset_init = None
        for tile_idx in range(2 ** pc_tiling):
            data = SampleRecursiveMainXYAxisTiling(x=tile_idx, steps=pc_tiling)(data_las)
            nag = transforms_dict['pre_transform'](data)
            # Drop keys the model was not trained to load.
            nag = NAGRemoveKeys(level=0, keys=[
                k for k in nag[0].keys if k not in cfg.datamodule.point_load_keys
            ])(nag)
            nag = NAGRemoveKeys(level='1+', keys=[
                k for k in nag[1].keys if k not in cfg.datamodule.segment_load_keys
            ])(nag)
            nag = nag.cuda()
            nag = transforms_dict['on_device_test_transform'](nag)
            with torch.no_grad():
                output = model(nag)
            logits = output.voxel_logits_pred(super_index=nag[0].super_index)
            pos_list.append(nag[0].pos.cpu())
            logits_list.append(logits.cpu())
            if pos_offset_init is None:
                # Presumably every chunk shares the same level-0 offset, so it
                # is recorded once from the first chunk — TODO confirm.
                pos_offset_init = nag[0].pos_offset.cpu()
        merged_pos = torch.cat(pos_list, dim=0)
        merged_logits = torch.cat(logits_list, dim=0)
        merged_pos_offset = pos_offset_init + offset_initial_las
        # Quantize to integer LAS coordinates. Rounding (not bare astype(int),
        # which truncates toward zero) avoids a systematic bias of up to one
        # scale unit per coordinate.
        pos_data = np.round(merged_pos.numpy() / scale).astype(np.int64)
        x, y, z = pos_data[:, 0], pos_data[:, 1], pos_data[:, 2]
        header = laspy.LasHeader(point_format=3, version="1.2")
        header.scales = scale
        header.offsets = merged_pos_offset
        las = laspy.LasData(header)
        las.X, las.Y, las.Z = x, y, z
        # Softmax directly on the merged tensor — no need for the previous
        # torch -> numpy -> torch round-trip.
        soft_logits = F.softmax(merged_logits, dim=1).numpy()
        for i in range(soft_logits.shape[1]):
            scaled_logits = (255 * soft_logits[:, i]).clip(0, 255).astype(np.uint8)
            las.add_extra_dim(
                laspy.ExtraBytesParams(name=f"sof_log{i}", type=np.uint8, description=f"Logit {i}")
            )
            setattr(las, f"sof_log{i}", scaled_logits[:])
        # splitext is safer than str.replace, which would also rewrite a
        # '.las' substring occurring earlier in the path.
        base, ext = os.path.splitext(filepath)
        output_las = base + '_with_softmax' + ext
        las.write(output_las)
        print(f"[Export Logits] Saved softmax LAS to: {output_las}")
def main():
    """CLI entry point: load the checkpointed model and dispatch to either
    full-resolution inference or softmax-logit export.

    Command-line arguments:
        --mode      'inference' (full-res class predictions) or 'export_log'
                    (per-class softmax extra dims).
        --split     Split to process in inference mode ('test' or 'val').
        --weights   Path to the model checkpoint.
        --root_dir  Dataset root directory.
        --pc_tiling Recursive XY tiling steps (point cloud split into
                    2**pc_tiling chunks to fit GPU memory).
    """
    parser = argparse.ArgumentParser(description="SPT Inference and Logits Export")
    parser.add_argument('--mode', choices=['inference', 'export_log'], required=True,
                        help="Choose between full-resolution inference or export logits")
    parser.add_argument('--split', type=str, default='test',
                        help="Data split to process (only used in inference mode) test or val split")
    parser.add_argument('--weights', type=str, required=True, help="Path to model checkpoint")
    parser.add_argument('--root_dir', type=str, required=True, help="Root directory of the dataset")
    # type=int lets argparse validate the value and report a proper CLI error,
    # instead of parsing a string and converting by hand afterwards.
    parser.add_argument('--pc_tiling', type=int, default=3,
                        help="PC tiling for point cloud sampling")
    args = parser.parse_args()

    cfg = init_config(overrides=["experiment=semantic/gridnet"])
    transforms_dict = instantiate_datamodule_transforms(cfg.datamodule)
    model = hydra.utils.instantiate(cfg.model)
    model = model._load_from_checkpoint(args.weights)
    model = model.eval().cuda()

    # Millimetre LAS coordinate resolution on all three axes.
    SCALE = [0.001, 0.001, 0.001]
    pc_tiling = args.pc_tiling

    if args.mode == 'inference':
        run_inference(model, cfg, transforms_dict, args.root_dir, args.split, SCALE, pc_tiling)
    elif args.mode == 'export_log':
        export_logits(model, cfg, transforms_dict, args.root_dir, SCALE, pc_tiling)


if __name__ == '__main__':
    main()