File size: 452 Bytes
da2e2ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import numpy as np
from typing import Tuple

import torch


class PointShuffle:
    """Randomly reorder the points stored under ``features['lidar']``.

    The permutation is applied along the first dimension. When the transform
    is not in training mode, the inputs pass through untouched.
    """

    def __init__(self, is_train):
        # Shuffling is performed only when this flag is truthy.
        self.is_train = is_train

    def __call__(self, features, targets):
        """Shuffle lidar points in place (training only).

        Args:
            features: dict holding the point data under the ``'lidar'`` key.
                Presumably a ``torch.Tensor``, since ``.device`` and
                tensor-index selection are used. TODO confirm with callers.
            targets: passed through unchanged.

        Returns:
            The same ``(features, targets)`` objects that were given. In
            training mode ``features['lidar']`` has been replaced by a
            row-permuted copy.
        """
        if not self.is_train:
            return features, targets

        lidar = features['lidar']
        # Build the permutation on the points' own device so indexing stays local.
        order = torch.randperm(lidar.shape[0], device=lidar.device)
        features['lidar'] = lidar[order]
        return features, targets