Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
from typing import Optional, Tuple | |
import numpy as np | |
import torch | |
from torch.fft import irfftn, rfftn | |
from torch.nn.functional import grid_sample, log_softmax, pad | |
from .metrics import angle_error | |
from .utils import make_grid, rotmat2d | |
from torchvision.transforms.functional import rotate | |
class UAVTemplateSamplerFast(torch.nn.Module):
    """Resample a feature map under a fixed set of rotations with ``grid_sample``.

    The rotated, normalized sampling grids are computed once in ``__init__`` and
    stored in the non-persistent buffer ``grid_norm``; ``forward`` only samples.

    Args:
        num_rotations: number of rotated copies returned by ``forward``.
        w: side length (in pixels) of the square sampling grid.
        optimize: if True, only angles in [0, 90) degrees are sampled and the
            other three quadrants are produced by exact ``torch.rot90`` calls in
            ``forward``; requires ``num_rotations`` to be divisible by 4.
        device: device on which the grids are built. Defaults to ``"cuda"``
            (the previously hard-coded behaviour); pass ``"cpu"`` to build the
            module without a GPU.
    """

    def __init__(self, num_rotations, w=128, optimize=True, device="cuda"):
        super().__init__()
        h = w  # the sampling grid is square
        grid_xy = make_grid(
            w=w,
            h=h,
            step_x=1,
            step_y=1,
            orig_y=-h // 2,
            orig_x=-w // 2,
            y_up=True,
        ).to(device)
        if optimize:
            assert (num_rotations % 4) == 0
            angles = torch.arange(0, 90, 90 / (num_rotations // 4), device=device)
        else:
            # Previously referenced an undefined `grid_xz_bev` (NameError).
            angles = torch.arange(0, 360, 360 / num_rotations, device=device)
        rotmats = rotmat2d(angles / 180 * np.pi)  # degrees -> radians
        grid_xy_rot = torch.einsum("...nij,...hwj->...nhwi", rotmats, grid_xy)
        # Offset by the first (unrotated) grid cell and flip the y axis.
        grid_ij_rot = (grid_xy_rot - grid_xy[..., :1, :1, :]) * grid_xy.new_tensor(
            [1, -1]
        )
        # Map pixel indices to [-1, 1] following the align_corners=False convention.
        grid_norm = (grid_ij_rot + 0.5) / grid_ij_rot.new_tensor([w, h]) * 2 - 1
        self.optimize = optimize
        self.num_rots = num_rotations
        self.register_buffer("angles", angles, persistent=False)
        self.register_buffer("grid_norm", grid_norm, persistent=False)

    def forward(self, image_bev):
        """Sample ``image_bev`` with every precomputed rotated grid.

        Args:
            image_bev: (B, C, H, W) tensor (4-D, as required by ``grid_sample``).

        Returns:
            (B, num_rotations, C, h, w) tensor where (h, w) is the grid size.
        """
        grid = self.grid_norm
        b, c = image_bev.shape[:2]
        n, h, w = grid.shape[:3]
        # Pair every image with every grid: grids tiled per image, images repeated per grid.
        grid = grid[None].repeat_interleave(b, 0).reshape(b * n, h, w, 2)
        image = (
            image_bev[:, None]
            .repeat_interleave(n, 1)
            .reshape(b * n, *image_bev.shape[1:])
        )
        kernels = grid_sample(image, grid.to(image.dtype), align_corners=False).reshape(
            b, n, c, h, w
        )
        if self.optimize:  # only the first quadrant was sampled
            kernels_quad234 = [torch.rot90(kernels, -i, (-2, -1)) for i in (1, 2, 3)]
            kernels = torch.cat([kernels] + kernels_quad234, 1)
        return kernels
class UAVTemplateSampler(torch.nn.Module):
    """Build a bank of rotated copies of a feature map with torchvision's ``rotate``."""

    def __init__(self, num_rotations):
        super().__init__()
        self.num_rotations = num_rotations

    def Template(self, input_features):
        """Stack ``num_rotations`` rotated copies of ``input_features`` along dim 1.

        Copy ``i`` is ``input_features`` rotated by ``360 - i * 360 / num_rotations``
        degrees (the value passed to ``rotate``); copy 0 therefore uses 360 degrees.
        Areas uncovered by the rotation are filled with 0.

        Args:
            input_features: tensor accepted by ``rotate``; presumably (B, C, H, W)
                — TODO confirm against callers.

        Returns:
            Tensor of shape ``(input_features.shape[0], num_rotations,
            *input_features.shape[1:])`` with the dtype/device of the input.
        """
        # Angles 360, 360 - step, ..., step: the trailing 0 endpoint is dropped.
        # The count was previously hard-coded to 64, which broke any other
        # num_rotations (IndexError below 64, all-zero templates above).
        rotation_angles = torch.linspace(360, 0, self.num_rotations + 1)[:-1]
        rotated = [
            rotate(input_features, angle.item(), fill=0) for angle in rotation_angles
        ]
        return torch.stack(rotated, dim=1)

    def forward(self, image_bev):
        """Return the rotated template bank for ``image_bev`` (see ``Template``)."""
        kernels = self.Template(image_bev)
        return kernels
class TemplateSampler(torch.nn.Module):
    """Resample a BEV feature map into a bank of rotated templates.

    The sampling grids for every rotation are computed once in ``__init__`` and
    kept in the non-persistent buffer ``grid_norm``; ``forward`` only calls
    ``grid_sample``.
    """

    def __init__(self, grid_xz_bev, ppm, num_rotations, optimize=True):
        """
        Args:
            grid_xz_bev: BEV coordinate grid. Only ``shape[:2]`` (read as height,
                width) and its first cell ``[..., :1, :1, :]`` are used; presumably
                an (H, W, 2) tensor — TODO confirm.
            ppm: scale factor; the template cell size is ``1 / ppm``.
            num_rotations: number of rotations returned by ``forward``.
            optimize: if True, only angles in [0, 90) degrees are sampled and the
                other quadrants come from exact ``rot90`` in ``forward``; requires
                ``num_rotations`` divisible by 4.
        """
        super().__init__()
        cell = 1 / ppm
        height, width = grid_xz_bev.shape[:2]
        side = max(width, 2 * height + 1)  # template side length in cells
        extent = side * cell
        start = (cell - extent) / 2
        grid_xy = make_grid(
            extent,
            extent,
            step_x=cell,
            step_y=cell,
            orig_y=start,
            orig_x=start,
            y_up=True,
        )

        if optimize:
            assert (num_rotations % 4) == 0
            span, count = 90, num_rotations // 4
        else:
            span, count = 360, num_rotations
        angles = torch.arange(0, span, span / count, device=grid_xz_bev.device)

        rotmats = rotmat2d(angles / 180 * np.pi)  # degrees -> radians
        rotated_xy = torch.einsum("...nij,...hwj->...nhwi", rotmats, grid_xy)
        # Express relative to the first BEV cell, flip y, then convert to cell units.
        offset = rotated_xy - grid_xz_bev[..., :1, :1, :]
        cells_ij = offset * grid_xy.new_tensor([1, -1]) / cell
        # Cell indices -> [-1, 1] following the align_corners=False convention.
        grid_norm = (cells_ij + 0.5) / cells_ij.new_tensor([width, height]) * 2 - 1

        self.optimize = optimize
        self.num_rots = num_rotations
        self.register_buffer("angles", angles, persistent=False)
        self.register_buffer("grid_norm", grid_norm, persistent=False)

    def forward(self, image_bev):
        """Sample ``image_bev`` with every rotated grid.

        Args:
            image_bev: (B, C, H, W) tensor (4-D, as required by ``grid_sample``).

        Returns:
            (B, num_rots, C, h, w) tensor where (h, w) is the template grid size.
        """
        batch, channels = image_bev.shape[:2]
        num_grids, out_h, out_w = self.grid_norm.shape[:3]
        # Each image is repeated once per grid; the grid stack is tiled once per image.
        images = image_bev.repeat_interleave(num_grids, dim=0)
        grids = self.grid_norm.repeat(batch, 1, 1, 1).to(images.dtype)
        sampled = grid_sample(images, grids, align_corners=False)
        kernels = sampled.reshape(batch, num_grids, channels, out_h, out_w)
        if not self.optimize:
            return kernels
        # Only the first quadrant was sampled; the rest are exact 90-degree turns.
        quadrants = [torch.rot90(kernels, -turns, (-2, -1)) for turns in range(4)]
        return torch.cat(quadrants, dim=1)
def conv2d_fft_batchwise(signal, kernel, padding="same", padding_mode="constant"):
    """Cross-correlate each batch element with its own kernel bank via FFTs.

    For every batch index ``b`` this matches
    ``torch.nn.functional.conv2d(padded_signal[b:b+1], kernel[b])`` (a
    cross-correlation, like ``conv2d``) but is computed in the frequency domain.

    Args:
        signal: (B, C, H, W) tensor.
        kernel: (B, D, C, kh, kw) tensor: D kernels per batch element.
        padding: ``"same"`` to pad each spatial side by ``k // 2``, or a pair
            ``(pad_h, pad_w)`` applied symmetrically.
        padding_mode: mode forwarded to ``torch.nn.functional.pad``.

    Returns:
        (B, D, Hp - kh + 1, Wp - kw + 1) tensor, where Hp/Wp are the padded sizes.
    """
    if padding == "same":
        padding = [size // 2 for size in kernel.shape[-2:]]
    # F.pad order is (left, right, top, bottom): W gets padding[1], H gets padding[0].
    padding_signal = [p for p in padding[::-1] for _ in range(2)]
    signal = pad(signal, padding_signal, mode=padding_mode)
    height, width = signal.shape[-2:]
    # Zero-pad the kernel on the bottom/right up to the padded signal size.
    padding_kernel = [0, width - kernel.size(-1), 0, height - kernel.size(-2)]
    kernel_padded = pad(kernel, padding_kernel)
    signal_fr = rfftn(signal, dim=(-1, -2))
    kernel_fr = rfftn(kernel_padded, dim=(-1, -2))
    kernel_fr.imag *= -1  # conjugate the kernel: correlation instead of convolution
    output_fr = torch.einsum("bc...,bdc...->bd...", signal_fr, kernel_fr)
    # rfftn keeps only half of the LAST axis listed in `dim` (here H). Without `s`,
    # irfftn assumes that axis had even length, which silently corrupted the result
    # for odd padded heights (the old assert checked W instead). Passing the true
    # sizes makes any H and W work.
    output = irfftn(output_fr, s=(width, height), dim=(-1, -2))
    out_h = height - kernel.size(-2) + 1
    out_w = width - kernel.size(-1) + 1
    # Keep only the offsets where the kernel does not wrap around circularly.
    return output[..., :out_h, :out_w].contiguous()
class SparseMapSampler(torch.nn.Module):
    """Sample a map at every rotated + translated copy of a set of BEV points.

    The original version could not execute: it read a non-existent ``self.conf``,
    used a malformed einsum subscript, broadcast the translations along the wrong
    axis, and added an int to a shape tuple. Those four issues are fixed here.
    """

    def __init__(self, num_rotations):
        super().__init__()
        angles = torch.arange(0, 360, 360 / num_rotations)
        rotmats = rotmat2d(angles / 180 * np.pi)  # degrees -> radians
        self.num_rotations = num_rotations
        self.register_buffer("rotmats", rotmats, persistent=False)

    def forward(self, image_map, p2d_bev):
        """Sample ``image_map`` at all rotations/translations of ``p2d_bev``.

        Args:
            image_map: (B, C, H, W) map tensor.
            p2d_bev: (B, N, 2) points in map pixel units, presumably (x, y)
                — TODO confirm.

        Returns:
            valid: (B, N, H, W, K) mask of candidates inside the map.
            value: (B, C, N, H, W, K) bilinearly sampled map values.
        """
        h, w = image_map.shape[-2:]
        # Per-pixel translations; presumably (h, w, 2) — TODO confirm make_grid layout.
        locations = make_grid(w, h, device=p2d_bev.device)
        # Rotate every point by every rotation matrix: (..., N, K, 2).
        p2d_candidates = torch.einsum(
            "kji,...i->...kj", self.rotmats.to(p2d_bev), p2d_bev
        )
        # Translate to every location: (..., N, 1, 1, K, 2) + (h, w, 1, 2)
        # -> (..., N, h, w, K, 2).
        p2d_candidates = p2d_candidates[..., None, None, :, :] + locations[..., None, :]
        p2d_norm = (p2d_candidates / (image_map.new_tensor([w, h]) - 1)) * 2 - 1
        valid = torch.all((p2d_norm >= -1) & (p2d_norm <= 1), -1)
        value = grid_sample(
            image_map, p2d_norm.flatten(-4, -2), align_corners=True, mode="bilinear"
        )
        # (B, C, N, h*w*K) -> (B, C, N, h, w, K)
        value = value.reshape(image_map.shape[:2] + valid.shape[-4:])
        return valid, value
def sample_xyr(volume, xy_grid, angle_grid, nearest_for_inf=False):
    """Trilinearly sample a score volume at (x, y, angle) locations.

    Args:
        volume: (B, C, H, W, N) tensor whose last axis holds N rotations evenly
            spaced over 360 degrees and is treated as circular.
        xy_grid: (B, D, H', W', 2) pixel coordinates (x, y): x indexes W, y indexes H.
        angle_grid: (B, D, H', W') angles in degrees (wrapped modulo 360).
        nearest_for_inf: if True, where bilinear interpolation yields a non-finite
            value at an in-bounds location, use nearest-neighbour sampling instead.

    Returns:
        value: (B, C, D, H', W') sampled values.
        valid: (B, D, H', W') True where the sample lies inside the volume.
    """
    # (B, C, H, W, N) -> (B, C, H, W, N+1): append rotation 0 after rotation N-1
    # so angles between them interpolate across the 360-degree seam.
    volume_padded = pad(volume, [0, 1, 0, 0, 0, 0], mode="circular")
    size = xy_grid.new_tensor(volume.shape[-3:-1][::-1])  # (W, H)
    xy_norm = xy_grid / (size - 1)  # align_corners=True
    angle_norm = (angle_grid / 360) % 1
    # grid_sample maps grid (x, y, z) to the input's (last, 2nd-last, 3rd-last)
    # axes, i.e. (rotation, W, H) here.
    grid = torch.concat([angle_norm.unsqueeze(-1), xy_norm], -1)
    grid_norm = grid * 2 - 1
    valid = torch.all((grid_norm >= -1) & (grid_norm <= 1), -1)
    value = grid_sample(volume_padded, grid_norm, align_corners=True, mode="bilinear")
    # If one of the values used for linear interpolation is infinite,
    # fall back to nearest to avoid propagating inf.
    if nearest_for_inf:
        value_nearest = grid_sample(
            volume_padded, grid_norm, align_corners=True, mode="nearest"
        )
        # `valid` lacks the channel axis of `value`; insert it so the mask aligns
        # with the batch axis instead of broadcasting against C (which produced
        # a (B, B, ...) result).
        value = torch.where(
            ~torch.isfinite(value) & valid.unsqueeze(1), value_nearest, value
        )
    return value, valid
def nll_loss_xyr(log_probs, xy, angle):
    """Per-sample negative log-likelihood of a ground-truth pose.

    ``log_probs`` is interpolated at (``xy``, ``angle``) through ``sample_xyr``
    after inserting a singleton channel axis.

    Args:
        log_probs: log-probability volume, presumably (B, H, W, N) as expected by
            ``sample_xyr`` once the channel axis is added — TODO confirm.
        xy: (B, 2) pixel coordinates.
        angle: (B,) angles in degrees.

    Returns:
        (B,) tensor of negated sampled log-probabilities.
    """
    to_grid = (slice(None), None, None, None)  # (B, ...) -> (B, 1, 1, 1, ...)
    sampled, _ = sample_xyr(log_probs.unsqueeze(1), xy[to_grid], angle[to_grid])
    return -sampled.reshape(-1)  # drop the C, D, H, W singleton axes
def nll_loss_xyr_smoothed(log_probs, xy, angle, sigma_xy, sigma_r, mask=None):
    """Cross-entropy between a Gaussian-smoothed pose target and ``log_probs``.

    The target is an unnormalized Gaussian centred at (``xy``, ``angle``) with
    widths ``sigma_xy`` (grid cells) and ``sigma_r`` (degrees, circular
    distance), evaluated on the score grid and normalized to sum to 1.

    Args:
        log_probs: (..., H, W, N) log-probabilities over rows y, columns x and N
            rotations evenly spaced over 360 degrees.
        xy: (..., 2) target coordinates (x, y) in grid cells.
        angle: (...) target angle in degrees.
        sigma_xy: spatial smoothing width.
        sigma_r: angular smoothing width.
        mask: optional (..., H, W) boolean mask; masked-out cells get zero target
            weight and their log-probabilities are ignored.

    Returns:
        (...) tensor of per-sample losses ``-sum(target * log_probs)``.
    """
    # The spatial axes are (H, W) = (y, x); the previous `nx, ny` unpack swapped
    # them, which failed or mis-weighted the target on non-square maps.
    *_, ny, nx, nr = log_probs.shape
    grid_x = torch.arange(nx, device=log_probs.device, dtype=torch.float)
    dx = (grid_x - xy[..., None, 0]) / sigma_xy
    grid_y = torch.arange(ny, device=log_probs.device, dtype=torch.float)
    dy = (grid_y - xy[..., None, 1]) / sigma_xy
    dr = (
        torch.arange(0, 360, 360 / nr, device=log_probs.device, dtype=torch.float)
        - angle[..., None]
    ) % 360
    dr = torch.minimum(dr, 360 - dr) / sigma_r  # shortest circular distance
    diff = (
        dx[..., None, :, None] ** 2
        + dy[..., :, None, None] ** 2
        + dr[..., None, None, :] ** 2
    )
    pdf = torch.exp(-diff / 2)
    if mask is not None:
        pdf.masked_fill_(~mask[..., None], 0)
        log_probs = log_probs.masked_fill(~mask[..., None], 0)
    pdf /= pdf.sum((-1, -2, -3), keepdim=True)
    return -torch.sum(pdf * log_probs.to(torch.float), dim=(-1, -2, -3))
def log_softmax_spatial(x, dims=3):
    """Apply log-softmax jointly over the last ``dims`` axes of ``x``.

    The trailing axes are flattened into one, normalized, and reshaped back, so
    the output has the same shape as ``x``.
    """
    flat = x.flatten(start_dim=-dims)
    normalized = log_softmax(flat, dim=-1)
    return normalized.reshape(x.shape)
def argmax_xy(scores: torch.Tensor) -> torch.Tensor:
    """Integer location of the maximum over the last two axes of ``scores``.

    ``scores`` is (..., H, W); the result is (..., 2) holding (column, row).
    """
    num_cols = scores.shape[-1]
    _, flat_idx = scores.flatten(-2).max(dim=-1)
    col = flat_idx % num_cols
    row = torch.div(flat_idx, num_cols, rounding_mode="floor")
    return torch.stack((col, row), dim=-1)
def expectation_xy(prob: torch.Tensor) -> torch.Tensor:
    """Probability-weighted mean of the ``make_grid`` coordinates.

    ``prob`` is (..., H, W); the grid from ``make_grid(W, H)`` is contracted as
    (H, W, D), giving a (..., D) result — presumably D = 2 holding (x, y),
    TODO confirm against make_grid.
    """
    rows, cols = prob.shape[-2:]
    pixel_grid = make_grid(float(cols), float(rows), device=prob.device)
    return torch.einsum("...hw,hwd->...d", prob, pixel_grid.to(prob))
def expectation_xyr(
    prob: torch.Tensor, covariance: bool = False
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Mean (x, y) position and circular-mean angle of a pose distribution.

    Args:
        prob: (..., H, W, N) probabilities over rows y, columns x and N rotations
            evenly spaced over 360 degrees starting at 0.
        covariance: if True, also return the covariance of (x, y).

    Returns:
        (xyr_mean, xy_cov): ``xyr_mean`` is (..., 3) holding (x, y, angle in
        degrees within [0, 360)); ``xy_cov`` is (..., 2, 2), or None when
        ``covariance`` is False.
    """
    height, width, n_rot = prob.shape[-3:]
    opts = dict(device=prob.device, dtype=prob.dtype)
    cols, rows = torch.meshgrid(
        torch.arange(width, **opts), torch.arange(height, **opts), indexing="xy"
    )
    pixel_xy = torch.stack((cols, rows), -1)  # (H, W, 2) holding (x, y)
    xy_mean = torch.einsum("...hwn,hwd->...d", prob, pixel_xy)

    # Circular mean: average the unit vectors of each rotation, take the direction.
    theta = torch.arange(0, 1, 1 / n_rot, **opts) * 2 * np.pi
    unit = torch.stack([torch.cos(theta), torch.sin(theta)], -1)
    mean_cs = torch.einsum("...hwn,nd->...d", prob, unit)
    angle_deg = (torch.atan2(mean_cs[..., 1], mean_cs[..., 0]) * 180 / np.pi) % 360

    xy_cov = None
    if covariance:
        second_moment = torch.einsum(
            "...hwn,...hwd,...hwk->...dk", prob, pixel_xy, pixel_xy
        )
        xy_cov = second_moment - torch.einsum("...d,...k->...dk", xy_mean, xy_mean)
    return torch.cat((xy_mean, angle_deg.unsqueeze(-1)), -1), xy_cov
def argmax_xyr(scores: torch.Tensor) -> torch.Tensor:
    """Location and angle of the maximum of a (..., H, W, N) score volume.

    Returns a (..., 3) tensor holding (x, y, angle) where the angle is
    ``rotation_index * 360 / N`` degrees.
    """
    width, n_rot = scores.shape[-2:]
    row_stride = width * n_rot  # flat-index distance between consecutive rows
    _, flat = scores.flatten(-3).max(dim=-1)
    row = torch.div(flat, row_stride, rounding_mode="floor")
    within_row = flat % row_stride
    col = torch.div(within_row, n_rot, rounding_mode="floor")
    rot_index = within_row % n_rot
    return torch.stack((col, row, rot_index * 360 / n_rot), -1)
def argmax_xyrh(scores: torch.Tensor) -> torch.Tensor:
    """Locate the maximum of a (B, H, W, N, Z) score volume per batch element.

    The trailing axis (named ``height`` below) is a fifth search dimension whose
    meaning is not visible here — presumably height hypotheses, TODO confirm.

    Returns:
        (B, 4) integer tensor holding (x, y, angle, z_index), with x the W index,
        y the H index and angle ``floor(rotation_index * 360 / N)`` degrees. The
        angle stays an integer (as before) and is therefore truncated when
        ``360 / N`` is not a whole number.
    """
    batch, h, w, num_rotations, height = scores.shape
    _, flat = torch.max(scores.reshape(batch, -1), dim=1)
    x_stride = num_rotations * height  # flat-index step of one column
    y_stride = w * x_stride  # flat-index step of one row
    # The row stride is W-based; it previously used H, which mixed up x and y
    # whenever H != W.
    x = torch.div(flat % y_stride, x_stride, rounding_mode="floor")
    y = torch.div(flat, y_stride, rounding_mode="floor")
    rot_index = torch.div(flat % x_stride, height, rounding_mode="floor")
    z = flat % height
    # Integer floor division reproduces the former float->long truncation.
    angle = torch.div(rot_index * 360, num_rotations, rounding_mode="floor")
    return torch.stack([x, y, angle, z], dim=1)
def mask_yaw_prior(
    scores: torch.Tensor, yaw_prior: torch.Tensor, num_rotations: int
) -> torch.Tensor:
    """Set to -inf, in place, the rotation bins outside a yaw prior.

    The rotation axis is sampled at angles ``step / 2, 3 * step / 2, ...`` with
    ``step = 360 / num_rotations``; a bin is kept when ``angle_error`` between
    its angle and ``yaw_init`` is below ``yaw_range``.

    Args:
        scores: score tensor, modified in place; presumably (B, H, W,
            num_rotations) given the ``[:, None, None]`` broadcast — TODO confirm.
        yaw_prior: (..., 2) tensor holding (yaw_init, yaw_range) in degrees.
        num_rotations: number of rotation bins.

    Returns:
        ``scores`` itself after masking.
    """
    bin_width = 360 / num_rotations
    half_bin = bin_width / 2
    bin_angles = torch.arange(
        half_bin, 360 + half_bin, bin_width, device=scores.device
    )
    center, tolerance = yaw_prior.chunk(2, dim=-1)
    keep = angle_error(bin_angles, center) < tolerance
    return scores.masked_fill_(~keep[:, None, None], -np.inf)
def fuse_gps(log_prob, uv_gps, ppm, sigma=10, gaussian=False):
    """Add a GPS position prior to a (..., H, W, N) log-probability volume.

    Args:
        log_prob: log-probabilities; the prior is broadcast over the last axis.
        uv_gps: GPS position in the coordinates produced by ``make_grid``.
        ppm: scale converting ``sigma`` to grid units (``sigma * ppm``),
            presumably pixels per meter.
        sigma: prior width before scaling.
        gaussian: if True, use a Gaussian log-prior with std ``sigma * ppm``;
            otherwise a hard disc of that radius (log-prior 1 inside, -inf outside).

    Returns:
        ``log_prob + prior`` renormalized with ``log_softmax_spatial``.
    """
    rows, cols = log_prob.shape[-3:-1]
    pixels = make_grid(cols, rows).to(log_prob)
    sq_dist = ((pixels - uv_gps) ** 2).sum(-1)
    radius_px = sigma * ppm
    if not gaussian:
        prior = torch.where(sq_dist < radius_px**2, 1, -np.inf)
    else:
        prior = -1 / 2 * sq_dist / radius_px**2
    return log_softmax_spatial(log_prob + prior.unsqueeze(-1))