|
"""
PyTorch's built-in grid_sample does not support second-order derivatives,
so this module implements custom sampling functions built from
differentiable tensor operations instead.
"""
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import numpy as np |
|
|
|
|
|
def grid_sample_2d(image, optical):
    """Bilinearly sample ``image`` at the normalized locations in ``optical``.

    Produces the same values as ``F.grid_sample(image, optical,
    mode='bilinear', padding_mode='border', align_corners=True)``, but is
    built only from elementwise ops and ``torch.gather`` so that
    second-order derivatives with respect to ``optical`` are available.

    :param image: [N, C, IH, IW] input tensor (need not be contiguous).
    :param optical: [N, H, W, 2] sampling grid; ``optical[..., 0]`` is the
        x (width) coordinate and ``optical[..., 1]`` the y (height)
        coordinate, both normalized to [-1, 1].
    :return: [N, C, H, W] sampled values.
    """
    N, C, IH, IW = image.shape
    _, H, W, _ = optical.shape

    # Map normalized [-1, 1] coordinates to pixel space (align_corners=True).
    ix = ((optical[..., 0] + 1) / 2) * (IW - 1)
    iy = ((optical[..., 1] + 1) / 2) * (IH - 1)

    # Integer corner coordinates are constants; gradients w.r.t. ``optical``
    # flow only through the interpolation weights below.
    with torch.no_grad():
        ix_w = torch.floor(ix)  # west column
        iy_n = torch.floor(iy)  # north row
        ix_e = ix_w + 1         # east column
        iy_s = iy_n + 1         # south row

    # Bilinear weights from the *unclamped* corners so they always sum to one;
    # after the indices are clamped below, out-of-range samples therefore
    # reproduce 'border' padding.
    nw = (ix_e - ix) * (iy_s - iy)
    ne = (ix - ix_w) * (iy_s - iy)
    sw = (ix_e - ix) * (iy - iy_n)
    se = (ix - ix_w) * (iy - iy_n)

    with torch.no_grad():
        ix_w = ix_w.clamp(0, IW - 1)
        ix_e = ix_e.clamp(0, IW - 1)
        iy_n = iy_n.clamp(0, IH - 1)
        iy_s = iy_s.clamp(0, IH - 1)

    # reshape (not view) so that non-contiguous inputs are accepted.
    flat_image = image.reshape(N, C, IH * IW)

    def corner(rows, cols):
        # Gather image[n, :, rows, cols] for every channel -> [N, C, H, W].
        # The flat index is formed in integer arithmetic to stay exact for
        # large images.
        idx = rows.long() * IW + cols.long()
        idx = idx.reshape(N, 1, H * W).expand(N, C, H * W)
        return torch.gather(flat_image, 2, idx).view(N, C, H, W)

    return (corner(iy_n, ix_w) * nw.reshape(N, 1, H, W) +
            corner(iy_n, ix_e) * ne.reshape(N, 1, H, W) +
            corner(iy_s, ix_w) * sw.reshape(N, 1, H, W) +
            corner(iy_s, ix_e) * se.reshape(N, 1, H, W))
|
|
|
|
|
|
|
def grid_sample_3d(volume, optical):
    """Trilinearly sample ``volume`` at the normalized locations in ``optical``.

    For samples inside the volume this matches ``F.grid_sample(volume,
    optical, mode='bilinear', align_corners=True)`` (trilinear for 5-D
    input). It is built only from elementwise ops and ``torch.gather``, so
    second-order derivatives with respect to ``optical`` are available.
    Trilinear interpolation does not guarantee a continuous first-order
    gradient across voxel boundaries.

    Samples whose voxel coordinate lies outside ``[0, size - 1]`` along any
    axis are set to zero (rather than being partially zero-padded).

    :param volume: [N, C, ID, IH, IW] input tensor (need not be contiguous).
    :param optical: [N, D, H, W, 3] sampling grid; the last axis holds
        (x, y, z) normalized to [-1, 1], indexing IW, IH and ID respectively.
    :return: [N, C, D, H, W] sampled values.
    """
    N, C, ID, IH, IW = volume.shape
    _, D, H, W, _ = optical.shape

    # Map normalized [-1, 1] coordinates to voxel space (align_corners=True).
    ix = ((optical[..., 0] + 1) / 2) * (IW - 1)
    iy = ((optical[..., 1] + 1) / 2) * (IH - 1)
    iz = ((optical[..., 2] + 1) / 2) * (ID - 1)

    # Inclusive bounds, so samples exactly at -1 / +1 (first / last voxel)
    # are kept and anything beyond is rejected symmetrically.
    inside = ((ix >= 0) & (ix <= IW - 1) &
              (iy >= 0) & (iy <= IH - 1) &
              (iz >= 0) & (iz <= ID - 1))

    # Integer lower-corner coordinates are constants for autograd.
    with torch.no_grad():
        ix0 = torch.floor(ix)
        iy0 = torch.floor(iy)
        iz0 = torch.floor(iz)

    # (lower, upper) linear weights per axis, from unclamped corners so the
    # eight corner weights always sum to one.
    wx = (ix0 + 1 - ix, ix - ix0)
    wy = (iy0 + 1 - iy, iy - iy0)
    wz = (iz0 + 1 - iz, iz - iz0)

    # reshape (not view) so that non-contiguous inputs are accepted.
    flat_volume = volume.reshape(N, C, ID * IH * IW)
    num = D * H * W

    # Accumulate the 8 corners in the order bnw, bne, bsw, bse, fnw, fne,
    # fsw, fse ("b" = floor(z), "f" = floor(z) + 1).
    out_val = 0
    for dz in (0, 1):
        cz = (iz0 + dz).clamp(0, ID - 1).long()
        for dy in (0, 1):
            cy = (iy0 + dy).clamp(0, IH - 1).long()
            for dx in (0, 1):
                cx = (ix0 + dx).clamp(0, IW - 1).long()
                # Row-major flat index into an [ID, IH, IW] volume.
                idx = cz * (IH * IW) + cy * IW + cx
                idx = idx.reshape(N, 1, num).expand(N, C, num)
                values = torch.gather(flat_volume, 2, idx).view(N, C, D, H, W)
                weight = (wx[dx] * wy[dy] * wz[dz]).reshape(N, 1, D, H, W)
                out_val = out_val + values * weight

    return torch.where(inside[:, None], out_val, torch.zeros_like(out_val))
|
|
|
|
|
|
|
def get_weight(s, a=-0.5):
    """Evaluate the Keys cubic-convolution kernel elementwise at offsets ``s``.

    With ``t = |s|``::

        (a + 2) t^3 - (a + 3) t^2 + 1      for t <= 1
        a t^3 - 5a t^2 + 8a t - 4a         for 1 < t <= 2
        0                                  otherwise (including NaN)

    :param s: tensor of (signed) distances from the sample point.
    :param a: kernel sharpness parameter.
    :return: tensor of kernel weights with the same shape, dtype and device as ``s``.
    """
    t = torch.abs(s)
    near = t <= 1
    far = (t > 1) & (t <= 2)

    weight = torch.zeros_like(s)
    weight = torch.where(near, (a + 2) * t ** 3 - (a + 3) * t ** 2 + 1, weight)
    weight = torch.where(far, a * t ** 3 - (5 * a) * t ** 2 + (8 * a) * t - 4 * a, weight)
    return weight
|
|
|
|
|
def cubic_interpolate(p, x):
    """Catmull-Rom cubic interpolation along one axis.

    Interpolates between ``p[:, 1]`` (at x = 0) and ``p[:, 2]`` (at x = 1),
    with ``p[:, 0]`` and ``p[:, 3]`` acting as the outer neighbours.

    :param p: [N, 4] samples at consecutive positions, in order.
    :param x: [N] fractional offset from ``p[:, 1]``.
    :return: [N] interpolated values.
    """
    p0, p1, p2, p3 = p[:, 0], p[:, 1], p[:, 2], p[:, 3]

    # Horner evaluation, innermost (cubic) coefficient first.
    cubic_part = 3.0 * (p1 - p2) + p3 - p0
    quadratic_part = 2.0 * p0 - 5.0 * p1 + 4.0 * p2 - p3 + x * cubic_part
    linear_part = p2 - p0 + x * quadratic_part
    return p1 + 0.5 * x * linear_part
|
|
|
|
|
def bicubic_interpolate(p, x, y, if_batch=True):
    """Bicubic interpolation on a 4x4 patch via ``cubic_interpolate``.

    Every row ``p[:, r, :]`` is interpolated along its last axis at ``x``;
    the four row results are then interpolated at ``y``.

    :param p: [N, 4, 4] patch, ``p[n, row, col]``.
    :param x: [N] fractional offset along the column (last) axis.
    :param y: [N] fractional offset along the row axis.
    :param if_batch: when True, all rows are interpolated in one flattened
        call; otherwise one call per row. Both paths compute the same values.
    :return: [N] interpolated values.
    """
    count = p.shape[0]

    if if_batch:
        # Give every one of the 4 rows of sample n its own copy of x[n].
        x_per_row = x.unsqueeze(1).repeat(1, 4).flatten()
        row_values = cubic_interpolate(p.reshape(count * 4, 4), x_per_row)
        return cubic_interpolate(row_values.view(count, 4), y)

    row_values = torch.stack([cubic_interpolate(p[:, row, :], x) for row in range(4)], dim=-1)
    return cubic_interpolate(row_values, y)
|
|
|
|
|
def tricubic_interpolate(p, x, y, z):
    """Tricubic interpolation on a 4x4x4 patch.

    Each slice ``p[:, k]`` is interpolated bicubically at (x, y), and the
    four slice results are then interpolated along the first patch axis at z.

    :param p: [N, 4, 4, 4] patch indexed as ``p[n, z, y, x]``.
    :param x: [N] fractional offset along the last patch axis.
    :param y: [N] fractional offset along the middle patch axis.
    :param z: [N] fractional offset along the first patch axis.
    :return: [N] interpolated values.
    """
    slice_values = [bicubic_interpolate(p[:, k, :, :], x, y) for k in range(4)]
    return cubic_interpolate(torch.stack(slice_values, dim=-1), z)
|
|
|
|
|
def cubic_interpolate_batch(p, x):
    """Batched Catmull-Rom cubic interpolation along the last axis.

    Same formula as ``cubic_interpolate`` with an extra leading batch axis:
    interpolates between ``p[:, :, 1]`` (x = 0) and ``p[:, :, 2]`` (x = 1).

    :param p: [B, N, 4] samples at consecutive positions, in order.
    :param x: [B, N] fractional offset from ``p[:, :, 1]``.
    :return: [B, N] interpolated values.
    """
    p0, p1, p2, p3 = (p[:, :, k] for k in range(4))

    cubic_part = 3.0 * (p1 - p2) + p3 - p0
    quadratic_part = 2.0 * p0 - 5.0 * p1 + 4.0 * p2 - p3 + x * cubic_part
    linear_part = p2 - p0 + x * quadratic_part
    return p1 + 0.5 * x * linear_part
|
|
|
|
|
def bicubic_interpolate_batch(p, x, y):
    """Batched bicubic interpolation on 4x4 patches.

    Rows ``p[b, n, r, :]`` are interpolated along the last axis at
    ``x[b, n]``; the four row results are then interpolated at ``y[b, n]``.

    :param p: [B, N, 4, 4] patches, ``p[b, n, row, col]``.
    :param x: [B, N] fractional offset along the column axis.
    :param y: [B, N] fractional offset along the row axis.
    :return: [B, N] interpolated values.
    """
    B, N, _, _ = p.shape

    # One copy of x per row, flattened alongside the rows of p.
    x_per_row = x.unsqueeze(2).repeat(1, 1, 4).reshape(B, N * 4)
    row_values = cubic_interpolate_batch(p.reshape(B, N * 4, 4), x_per_row)
    return cubic_interpolate_batch(row_values.view(B, N, 4), y)
|
|
|
|
|
|
|
def tricubic_interpolate_batch(p, x, y, z):
    """Tricubic interpolation on 4x4x4 patches, batching the four z-slices.

    The four slices ``p[:, k]`` are treated as the batch axis of
    ``bicubic_interpolate_batch``, so all of them are interpolated at (x, y)
    in one call; the slice results are then interpolated at ``z``.

    :param p: [N, 4, 4, 4] patches indexed as ``p[n, z, y, x]``.
    :param x: [N] fractional offset along the last patch axis.
    :param y: [N] fractional offset along the middle patch axis.
    :param z: [N] fractional offset along the first patch axis.
    :return: [N] interpolated values.
    """
    slices = p.permute(1, 0, 2, 3).contiguous()           # [4, N, 4, 4]
    xs = x.unsqueeze(0).repeat(4, 1)                       # [4, N]
    ys = y.unsqueeze(0).repeat(4, 1)                       # [4, N]

    slice_values = bicubic_interpolate_batch(slices, xs, ys)  # [4, N]
    return cubic_interpolate(slice_values.t().contiguous(), z)
|
|
|
|
|
def tricubic_sample_3d(volume, optical):
    """Tricubic (Catmull-Rom) sampling of ``volume`` at locations in ``optical``.

    Unlike trilinear sampling, the cubic interpolant keeps the first-order
    gradient w.r.t. ``optical`` continuous across voxel boundaries. The 4x4x4
    neighbour indices are clamped to the volume, so samples near or beyond
    the border reuse edge voxels; no zero-masking is applied.

    :param volume: [N, C, ID, IH, IW] input tensor (need not be contiguous).
    :param optical: [N, D, H, W, 3] sampling grid; the last axis holds
        (x, y, z) normalized to [-1, 1] with the align_corners=True
        convention, indexing IW, IH and ID respectively.
    :return: [N, C, D, H, W] sampled values.
    """
    N, C, ID, IH, IW = volume.shape
    _, D, H, W, _ = optical.shape
    num = D * H * W

    # Normalized [-1, 1] -> voxel coordinates, flattened to [N * D * H * W].
    ix = (((optical[..., 0] + 1) / 2) * (IW - 1)).reshape(-1)
    iy = (((optical[..., 1] + 1) / 2) * (IH - 1)).reshape(-1)
    iz = (((optical[..., 2] + 1) / 2) * (ID - 1)).reshape(-1)

    with torch.no_grad():
        # Neighbours floor-1 .. floor+2 along each axis, built from an exact
        # integer floor so float round-off cannot shift them by one.
        offsets = torch.arange(-1, 3, device=ix.device)
        tx = (torch.floor(ix).long()[:, None] + offsets).clamp(0, IW - 1)  # [M, 4]
        ty = (torch.floor(iy).long()[:, None] + offsets).clamp(0, IH - 1)
        tz = (torch.floor(iz).long()[:, None] + offsets).clamp(0, ID - 1)

        # Row-major flat indices into [ID, IH, IW], laid out as [M, z, y, x]
        # to match the p[n, z, y, x] convention of tricubic_interpolate.
        idx = (tz[:, :, None, None] * (IH * IW) +
               ty[:, None, :, None] * IW +
               tx[:, None, None, :])

    def per_channel(t):
        # [N * D * H * W] -> [N * C * D * H * W], repeating each value for every channel.
        return t.reshape(N, 1, num).expand(N, C, num).reshape(-1)

    # Fractional offset from the (clamped) floor neighbour; this is the only
    # path through which gradients reach ``optical``.
    local_x = per_channel(ix - tx[:, 1])
    local_y = per_channel(iy - ty[:, 1])
    local_z = per_channel(iz - tz[:, 1])

    flat_volume = volume.reshape(N, C, ID * IH * IW)
    idx = idx.reshape(N, 1, num * 64).expand(N, C, num * 64)
    patches = torch.gather(flat_volume, 2, idx).view(N * C * num, 4, 4, 4)

    return tricubic_interpolate(patches, local_x, local_y, local_z).view(N, C, D, H, W)
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual sanity check: compare the custom samplers with F.grid_sample
    # on the same two sample points.
    from ops.generate_grids import generate_grid

    vsize = 9
    volume = generate_grid([vsize, vsize, vsize], 1)

    # Two sample points laid out as [N=1, D=1, H=1, W=2, (x, y, z)].
    optical = torch.Tensor([-0.6, -0.7, 0.5, 0.3, 0.5, 0.5]).view(1, 1, 1, 2, 3)

    print(F.grid_sample(volume, optical, padding_mode='border', align_corners=True))
    print(grid_sample_3d(volume, optical))
    print(tricubic_sample_3d(volume, optical))
|
|
|
|
|
|