|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | import torch.nn as nn | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class SeedBinRegressor(nn.Module): | 
					
						
						|  | def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10): | 
					
						
						|  | """Bin center regressor network. Bin centers are bounded on (min_depth, max_depth) interval. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | in_features (int): input channels | 
					
						
						|  | n_bins (int, optional): Number of bin centers. Defaults to 16. | 
					
						
						|  | mlp_dim (int, optional): Hidden dimension. Defaults to 256. | 
					
						
						|  | min_depth (float, optional): Min depth value. Defaults to 1e-3. | 
					
						
						|  | max_depth (float, optional): Max depth value. Defaults to 10. | 
					
						
						|  | """ | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.version = "1_1" | 
					
						
						|  | self.min_depth = min_depth | 
					
						
						|  | self.max_depth = max_depth | 
					
						
						|  |  | 
					
						
						|  | self._net = nn.Sequential( | 
					
						
						|  | nn.Conv2d(in_features, mlp_dim, 1, 1, 0), | 
					
						
						|  | nn.ReLU(inplace=True), | 
					
						
						|  | nn.Conv2d(mlp_dim, n_bins, 1, 1, 0), | 
					
						
						|  | nn.ReLU(inplace=True) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x): | 
					
						
						|  | """ | 
					
						
						|  | Returns tensor of bin_width vectors (centers). One vector b for every pixel | 
					
						
						|  | """ | 
					
						
						|  | B = self._net(x) | 
					
						
						|  | eps = 1e-3 | 
					
						
						|  | B = B + eps | 
					
						
						|  | B_widths_normed = B / B.sum(dim=1, keepdim=True) | 
					
						
						|  | B_widths = (self.max_depth - self.min_depth) * \ | 
					
						
						|  | B_widths_normed | 
					
						
						|  |  | 
					
						
						|  | B_widths = nn.functional.pad( | 
					
						
						|  | B_widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth) | 
					
						
						|  | B_edges = torch.cumsum(B_widths, dim=1) | 
					
						
						|  |  | 
					
						
						|  | B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...]) | 
					
						
						|  | return B_widths_normed, B_centers | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class SeedBinRegressorUnnormed(nn.Module): | 
					
						
						|  | def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10): | 
					
						
						|  | """Bin center regressor network. Bin centers are unbounded | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | in_features (int): input channels | 
					
						
						|  | n_bins (int, optional): Number of bin centers. Defaults to 16. | 
					
						
						|  | mlp_dim (int, optional): Hidden dimension. Defaults to 256. | 
					
						
						|  | min_depth (float, optional): Not used. (for compatibility with SeedBinRegressor) | 
					
						
						|  | max_depth (float, optional): Not used. (for compatibility with SeedBinRegressor) | 
					
						
						|  | """ | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.version = "1_1" | 
					
						
						|  | self._net = nn.Sequential( | 
					
						
						|  | nn.Conv2d(in_features, mlp_dim, 1, 1, 0), | 
					
						
						|  | nn.ReLU(inplace=True), | 
					
						
						|  | nn.Conv2d(mlp_dim, n_bins, 1, 1, 0), | 
					
						
						|  | nn.Softplus() | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x): | 
					
						
						|  | """ | 
					
						
						|  | Returns tensor of bin_width vectors (centers). One vector b for every pixel | 
					
						
						|  | """ | 
					
						
						|  | B_centers = self._net(x) | 
					
						
						|  | return B_centers, B_centers | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class Projector(nn.Module): | 
					
						
						|  | def __init__(self, in_features, out_features, mlp_dim=128): | 
					
						
						|  | """Projector MLP | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | in_features (int): input channels | 
					
						
						|  | out_features (int): output channels | 
					
						
						|  | mlp_dim (int, optional): hidden dimension. Defaults to 128. | 
					
						
						|  | """ | 
					
						
						|  | super().__init__() | 
					
						
						|  |  | 
					
						
						|  | self._net = nn.Sequential( | 
					
						
						|  | nn.Conv2d(in_features, mlp_dim, 1, 1, 0), | 
					
						
						|  | nn.ReLU(inplace=True), | 
					
						
						|  | nn.Conv2d(mlp_dim, out_features, 1, 1, 0), | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x): | 
					
						
						|  | return self._net(x) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class LinearSplitter(nn.Module): | 
					
						
						|  | def __init__(self, in_features, prev_nbins, split_factor=2, mlp_dim=128, min_depth=1e-3, max_depth=10): | 
					
						
						|  | super().__init__() | 
					
						
						|  |  | 
					
						
						|  | self.prev_nbins = prev_nbins | 
					
						
						|  | self.split_factor = split_factor | 
					
						
						|  | self.min_depth = min_depth | 
					
						
						|  | self.max_depth = max_depth | 
					
						
						|  |  | 
					
						
						|  | self._net = nn.Sequential( | 
					
						
						|  | nn.Conv2d(in_features, mlp_dim, 1, 1, 0), | 
					
						
						|  | nn.GELU(), | 
					
						
						|  | nn.Conv2d(mlp_dim, prev_nbins * split_factor, 1, 1, 0), | 
					
						
						|  | nn.ReLU() | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): | 
					
						
						|  | """ | 
					
						
						|  | x : feature block; shape - n, c, h, w | 
					
						
						|  | b_prev : previous bin widths normed; shape - n, prev_nbins, h, w | 
					
						
						|  | """ | 
					
						
						|  | if prev_b_embedding is not None: | 
					
						
						|  | if interpolate: | 
					
						
						|  | prev_b_embedding = nn.functional.interpolate(prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) | 
					
						
						|  | x = x + prev_b_embedding | 
					
						
						|  | S = self._net(x) | 
					
						
						|  | eps = 1e-3 | 
					
						
						|  | S = S + eps | 
					
						
						|  | n, c, h, w = S.shape | 
					
						
						|  | S = S.view(n, self.prev_nbins, self.split_factor, h, w) | 
					
						
						|  | S_normed = S / S.sum(dim=2, keepdim=True) | 
					
						
						|  |  | 
					
						
						|  | b_prev = nn.functional.interpolate(b_prev, (h,w), mode='bilinear', align_corners=True) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | b_prev = b_prev / b_prev.sum(dim=1, keepdim=True) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | b = b_prev.unsqueeze(2) * S_normed | 
					
						
						|  | b = b.flatten(1,2) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | B_widths = (self.max_depth - self.min_depth) * b | 
					
						
						|  |  | 
					
						
						|  | B_widths = nn.functional.pad(B_widths, (0,0,0,0,1,0), mode='constant', value=self.min_depth) | 
					
						
						|  | B_edges = torch.cumsum(B_widths, dim=1) | 
					
						
						|  |  | 
					
						
						|  | B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:,1:,...]) | 
					
						
						|  | return b, B_centers |