File size: 452 Bytes
da2e2ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import numpy as np
from typing import Tuple
import torch
class PointShuffle:
    """Transform that randomly reorders the entries of ``features['lidar']``.

    The reordering is applied only when the transform was built with a truthy
    ``is_train``; otherwise the inputs pass through untouched.
    """

    def __init__(self, is_train: bool) -> None:
        """Store the flag that enables shuffling.

        Args:
            is_train: When truthy, ``__call__`` shuffles the lidar entry.
        """
        self.is_train = is_train

    def __call__(self, features: dict, targets) -> Tuple[dict, object]:
        """Shuffle ``features['lidar']`` along its first dimension.

        Args:
            features: Mapping holding a ``'lidar'`` entry. The entry must
                expose ``.shape`` and ``.device`` and accept indexing by an
                index tensor (presumably a ``torch.Tensor`` of points, one
                point per row -- confirm against the caller).
            targets: Passed through unchanged.

        Returns:
            The same ``features`` and ``targets`` objects that were passed
            in. ``features`` is modified in place: its ``'lidar'`` entry is
            rebound to the reordered data.

        Raises:
            KeyError: If shuffling is enabled and ``features`` has no
                ``'lidar'`` key.
        """
        # Guard clause: with shuffling disabled the inputs are returned as-is,
        # so the transform always yields a (features, targets) pair.
        if not self.is_train:
            return features, targets

        points = features['lidar']
        num_points = points.shape[0]
        # Build the permutation on the same device as the data so the
        # indexing below does not mix devices.
        order = torch.randperm(num_points, device=points.device)
        features['lidar'] = points[order]
        return features, targets
|