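"""
Swin3D U-Net backbone.

A sparse-voxel Swin-Transformer encoder/decoder built on MinkowskiEngine,
used as a point cloud segmentation backbone in Pointcept.
"""
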
import torch
import torch.nn as nn
import MinkowskiEngine as ME
from MinkowskiEngine import SparseTensor
from timm.models.layers import trunc_normal_

from .mink_layers import MinkConvBNRelu, MinkResBlock
from .swin3d_layers import GridDownsample, GridKNNDownsample, BasicLayer, Upsample
from pointcept.models.builder import MODELS
from pointcept.models.utils import offset2batch, batch2offset
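

# Swin3DUNet is a sparse-voxel U-Net: a convolutional stem, a stack of
# window-attention encoder stages (BasicLayer) that downsample between stages,
# and a mirrored stack of Upsample blocks that fuse encoder skips back to full
# resolution before a per-point classifier.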
class Swin3DUNet(nn.Module):
    def __init__(
        self,
        in_channels,
        num_classes,
        base_grid_size,
        depths,
        channels,
        num_heads,
        window_sizes,
        quant_size,
        drop_path_rate=0.2,
        up_k=3,
        num_layers=5,
        stem_transformer=True,
        down_stride=2,
        upsample="linear",
        knn_down=True,
        cRSE="XYZ_RGB",
        fp16_mode=0,
    ):
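        """
        Args:
            in_channels: per-point input feature channels.
            num_classes: output classes of the segmentation head.
            base_grid_size: voxel size used to normalize raw coordinates.
            depths, channels, num_heads, window_sizes: per-stage encoder settings.
            quant_size: quantization granularity forwarded to the attention layers.
            drop_path_rate: maximum stochastic-depth rate, decayed linearly per block.
            up_k: number of neighbors used by the decoder's kNN upsampling.
            num_layers: total number of resolution stages.
            stem_transformer: if True, stage 0 is a transformer stage; otherwise a
                conv stem plus residual block is used and an extra downsample runs
                before the first transformer stage.
            down_stride: stride of the first downsample (later stages use 2).
            upsample: decoder mode; any name containing "attn" enables attention.
            knn_down: choose GridKNNDownsample over GridDownsample.
            cRSE: which contextual relative signals (e.g. "XYZ_RGB") the
                attention layers embed.
            fp16_mode: mixed-precision mode forwarded to the attention layers.
        """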
        super().__init__()
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule
        if knn_down:
            downsample = GridKNNDownsample
        else:
            downsample = GridDownsample

        self.cRSE = cRSE
        if stem_transformer:
            self.stem_layer = MinkConvBNRelu(
                in_channels=in_channels,
                out_channels=channels[0],
                kernel_size=3,
                stride=1,
            )
            self.layer_start = 0
        else:
            self.stem_layer = nn.Sequential(
                MinkConvBNRelu(
                    in_channels=in_channels,
                    out_channels=channels[0],
                    kernel_size=3,
                    stride=1,
                ),
                MinkResBlock(in_channels=channels[0], out_channels=channels[0]),
            )
            self.downsample = downsample(
                channels[0], channels[1], kernel_size=down_stride, stride=down_stride
            )
            self.layer_start = 1
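
        # Encoder: one BasicLayer per stage; each stage consumes its slice of the
        # stochastic-depth schedule and (except the last stage) downsamples into
        # the next stage's channel width.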
        self.layers = nn.ModuleList(
            [
                BasicLayer(
                    dim=channels[i],
                    depth=depths[i],
                    num_heads=num_heads[i],
                    window_size=window_sizes[i],
                    quant_size=quant_size,
                    drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
                    downsample=downsample if i < num_layers - 1 else None,
                    down_stride=down_stride if i == 0 else 2,
                    out_channels=channels[i + 1] if i < num_layers - 1 else None,
                    cRSE=cRSE,
                    fp16_mode=fp16_mode,
                )
                for i in range(self.layer_start, num_layers)
            ]
        )
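
        # Decoder: Upsample blocks mirror the encoder stages in reverse; a mode
        # name containing "attn" switches on attention-based upsampling.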
| if "attn" in upsample: | |
| up_attn = True | |
| else: | |
| up_attn = False | |
| self.upsamples = nn.ModuleList( | |
| [ | |
| Upsample( | |
| channels[i], | |
| channels[i - 1], | |
| num_heads[i - 1], | |
| window_sizes[i - 1], | |
| quant_size, | |
| attn=up_attn, | |
| up_k=up_k, | |
| cRSE=cRSE, | |
| fp16_mode=fp16_mode, | |
| ) | |
| for i in range(num_layers - 1, 0, -1) | |
| ] | |
| ) | |
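
        # Per-point head: two-layer MLP over the full-resolution decoded features.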
        self.classifier = nn.Sequential(
            nn.Linear(channels[0], channels[0]),
            nn.BatchNorm1d(channels[0]),
            nn.ReLU(inplace=True),
            nn.Linear(channels[0], num_classes),
        )
        self.num_classes = num_classes
        self.base_grid_size = base_grid_size
        self.init_weights()

    def forward(self, data_dict):
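        """Expects a Pointcept-style data_dict with keys: grid_coord (integer
        voxel coordinates), coord (raw continuous coordinates), coord_feat
        (extra per-point signals consumed by cRSE, e.g. colors), feat (input
        features), and offset (batch offsets, converted via offset2batch).
        """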
        grid_coord = data_dict["grid_coord"]
        feat = data_dict["feat"]
        coord_feat = data_dict["coord_feat"]
        coord = data_dict["coord"]
        offset = data_dict["offset"]
        batch = offset2batch(offset)
        in_field = ME.TensorField(
            features=torch.cat(
                [
                    batch.unsqueeze(-1),
                    coord / self.base_grid_size,
                    coord_feat / 1.001,
                    feat,
                ],
                dim=1,
            ),
            coordinates=torch.cat(
                [batch.unsqueeze(-1).int(), grid_coord.int()], dim=1
            ),
            quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
            minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED,
            device=feat.device,
        )
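        # Voxelize by averaging, then split the feature columns: the first
        # 4 + coord_feat.shape[-1] columns (batch index, normalized xyz,
        # coord_feat) become a parallel coords_sp tensor carrying geometry for
        # the attention layers; the rest stay in sp as network features.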
        sp = in_field.sparse()
        coords_sp = SparseTensor(
            features=sp.F[:, : coord_feat.shape[-1] + 4],
            coordinate_map_key=sp.coordinate_map_key,
            coordinate_manager=sp.coordinate_manager,
        )
        sp = SparseTensor(
            features=sp.F[:, coord_feat.shape[-1] + 4 :],
            coordinate_map_key=sp.coordinate_map_key,
            coordinate_manager=sp.coordinate_manager,
        )
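
        # Encoder pass: run the stem, then each stage; stash per-stage outputs
        # (and matching coordinate tensors) as skip connections for the decoder.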
        sp_stack = []
        coords_sp_stack = []
        sp = self.stem_layer(sp)
        if self.layer_start > 0:
            sp_stack.append(sp)
            coords_sp_stack.append(coords_sp)
            sp, coords_sp = self.downsample(sp, coords_sp)
        for i, layer in enumerate(self.layers):
            coords_sp_stack.append(coords_sp)
            sp, sp_down, coords_sp = layer(sp, coords_sp)
            sp_stack.append(sp)
            assert (coords_sp.C == sp_down.C).all()
            sp = sp_down

        # Decoder pass: walk back up the stacks fusing skip features, then
        # slice the voxel features back onto the original points for the head.
        sp = sp_stack.pop()
        coords_sp = coords_sp_stack.pop()
        for i, upsample in enumerate(self.upsamples):
            sp_i = sp_stack.pop()
            coords_sp_i = coords_sp_stack.pop()
            sp = upsample(sp, coords_sp, sp_i, coords_sp_i)
            coords_sp = coords_sp_i
        output = self.classifier(sp.slice(in_field).F)
        return output

    def init_weights(self):
        """Initialize the weights in backbone."""

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)
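

# Minimal usage sketch (hedged): the hyperparameters below are illustrative
# placeholders, not a published Swin3D configuration; real values come from a
# Pointcept model config.
#
#   model = Swin3DUNet(
#       in_channels=6,            # e.g. colors + normals (assumption)
#       num_classes=13,
#       base_grid_size=0.02,      # voxel size; units depend on the dataset
#       depths=(2, 2, 2, 2, 2),
#       channels=(48, 96, 192, 384, 384),
#       num_heads=(6, 6, 12, 24, 24),
#       window_sizes=(5, 7, 7, 7, 7),
#       quant_size=4,
#   )
#   logits = model(data_dict)     # data_dict as documented in forward()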