BiomedParse / modeling /utils /interactive.py
kernel-luso-comfort's picture
Add initial module structure and entry points for modeling and utilities
6ba63c9
raw
history blame
1.65 kB
import os
import copy
import math
import torch
from torch import nn, Tensor
import torch.nn.functional as F
def rand_sample(x, divisor, max_len):
# non_zero_pos_point = [rand_sample((m.nonzero()/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
if len(x.nonzero()) == 0:
return x.nonzero().t()
non_zero_point_index = (x.nonzero()/divisor).t()
mask_ids = non_zero_point_index[0].unique().long()
# compute probability for each samle
probs = torch.zeros_like(non_zero_point_index[0])
for idx in mask_ids:
prob = 1./(len(mask_ids)*((non_zero_point_index[0:1]==idx).sum()))
probs[non_zero_point_index[0]==idx] = prob
indices = torch.multinomial(probs, num_samples=min(max_len, len(probs)), replacement=False).sort()[0]
non_zero_point_index = non_zero_point_index[:,indices]
return non_zero_point_index # [n, 512]
def rand_sample_plain(x, max_len):
if x.shape[1] <= max_len:
return x
else:
rand_idx = torch.randperm(x.shape[1])[:max_len]
return x[:,rand_idx]
def prepare_features(x, num_feature_levels, pe_layer, input_proj, level_embed):
src = []
pos = []
size_list = []
# disable mask, it does not affect performance
for i in range(num_feature_levels):
size_list.append(x[i].shape[-2:])
pos.append(pe_layer(x[i], None).flatten(2))
src.append(input_proj[i](x[i]).flatten(2) + level_embed.weight[i][None, :, None])
# flatten NxCxHxW to HWxNxC
pos[-1] = pos[-1].permute(2, 0, 1)
src[-1] = src[-1].permute(2, 0, 1)
return src, pos, size_list