Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
import copy | |
import matplotlib as mpl | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
def likelihood_overlay(
    prob, map_viz=None, p_rgb=0.2, p_alpha=1 / 15, thresh=None, cmap="jet"
):
    """Color-code a likelihood map, optionally blending it over a map image.

    Args:
        prob: array of scores, assumed non-negative (fractional powers of
            negative values give NaN). It is rescaled so that its maximum
            becomes 1; an all-zero map stays all-zero instead of becoming NaN.
        map_viz: optional background image with values in [0, 1]. It must
            broadcast against the (..., 3) colors and the (..., 1) opacity,
            e.g. shape (H, W, 1) or (H, W, 3).
        p_rgb: exponent applied to the normalized scores before color mapping.
        p_alpha: exponent applied to the normalized scores to get the opacity.
        thresh: if given, pixels whose normalized score is <= thresh get zero
            opacity.
        cmap: matplotlib colormap name (anything `plt.get_cmap` accepts).

    Returns:
        Without `map_viz`: an RGBA array of shape `prob.shape + (4,)` whose
        alpha channel encodes the likelihood. With `map_viz`: an RGB blend of
        the colored likelihood over a lightened `map_viz`, clipped to [0, 1].
    """
    peak = prob.max()
    # An all-zero map would otherwise be divided by 0 and become all-NaN.
    prob = prob / peak if peak > 0 else prob * 0.0
    colormap = plt.get_cmap(cmap)
    rgb = colormap(prob**p_rgb)
    alpha = prob[..., None] ** p_alpha
    if thresh is not None:
        alpha[prob <= thresh] = 0
    if map_viz is not None:
        # Move the background halfway towards white so the overlay stands out.
        faded = map_viz + (1 - map_viz) * 0.5
        rgb = rgb[..., :3] * alpha + faded * (1 - alpha)
        rgb = np.clip(rgb, 0, 1)
    else:
        rgb[..., -1] = alpha.squeeze(-1)
    return rgb
def heatmap2rgb(scores, mask=None, clip_min=0.05, alpha=0.8, cmap="jet"):
    """Map a score array to RGBA colors with quantile-based contrast stretch.

    Scores below the `clip_min` quantile are clamped to it, and the range from
    that quantile to the maximum is spread over the whole colormap.

    Args:
        scores: array of scores (must support `.clip(min=...)`).
        mask: optional array broadcastable to `scores`. Where it is True (1)
            pixels are fully opaque; where it is False (0) they get opacity
            `alpha`.
        clip_min: lower quantile, in [0, 1], mapped to the bottom of the
            colormap.
        alpha: opacity of pixels outside `mask`. With `alpha == 0` those
            pixels are set to NaN in all four channels instead.
        cmap: matplotlib colormap name.

    Returns:
        An RGBA float array of shape `scores.shape + (4,)`.
    """
    min_, max_ = np.quantile(scores, [clip_min, 1])
    scores = scores.clip(min=min_)
    span = max_ - min_
    if span <= 0:
        # Constant input: avoid 0 / 0 = NaN and map everything to the low end.
        span = 1
    rgb = plt.get_cmap(cmap)((scores - min_) / span)
    if mask is not None:
        if alpha == 0:
            # Hide the pixels *outside* the mask, matching the branch below
            # where mask == False receives opacity `alpha` (here 0).
            rgb[np.logical_not(mask)] = np.nan
        else:
            rgb[..., -1] = 1 - (1 - 1.0 * mask) * (1 - alpha)
    return rgb
def plot_pose(axs, xy, yaw=None, s=1 / 35, c="r", a=1, w=0.015, dot=True, zorder=10):
    """Draw one or several 2D poses as dots and/or heading arrows.

    Args:
        axs: an Axes, an integer index into the current figure's axes, or a
            list / tuple / NumPy array of such items (e.g. the array returned
            by `plt.subplots`, of any shape).
        xy: (x, y) position, or a pair of coordinate arrays to draw several
            poses at once. 0.5 is added to every coordinate before drawing.
        yaw: optional heading in degrees, drawn as an arrow along
            (sin(yaw), -cos(yaw)): 0 points towards -y, 90 towards +x.
        s: `quiver` scale with `scale_units="xy"`; arrows are 1 / s data
            units long.
        c: color of the dot and the arrow.
        a: opacity of the dot and the arrow.
        w: arrow shaft width passed to `quiver`.
        dot: whether to draw a dot at the position.
        zorder: drawing order of the dot and the arrow.
    """
    if yaw is not None:
        rad = np.deg2rad(yaw)
        uv = np.array([np.sin(rad), -np.cos(rad)])
    xy = np.array(xy) + 0.5
    if isinstance(axs, np.ndarray):
        # e.g. a grid of axes from plt.subplots(): draw on every cell.
        axs = list(axs.ravel())
    elif not isinstance(axs, (list, tuple)):
        axs = [axs]
    for ax in axs:
        if isinstance(ax, (int, np.integer)):
            ax = plt.gcf().axes[ax]
        if dot:
            ax.scatter(*xy, c=c, s=70, zorder=zorder, linewidths=0, alpha=a)
        if yaw is not None:
            ax.quiver(
                *xy,
                *uv,
                scale=s,
                scale_units="xy",
                angles="xy",
                color=c,
                zorder=zorder,
                alpha=a,
                width=w,
            )
def plot_dense_rotations(
    ax, prob, thresh=0.01, skip=10, s=1 / 15, k=3, c="k", w=None, **kwargs
):
    """Draw heading arrows at the local maxima of a dense pose score volume.

    Args:
        ax: target Axes (or anything `plot_pose` accepts).
        prob: torch tensor indexed as (y, x, rotation bin). The last axis is
            arg-maxed into a yaw in degrees (bin / num_bins * 360). The first
            two axes are treated as an image for non-maximum suppression. The
            tensor may live on any device or require grad.
        thresh: minimum per-location score, relative to the global maximum.
        skip: not used by this function; kept for backward compatibility.
        s: arrow scale forwarded to `plot_pose`.
        k: non-maximum-suppression window size. It must be odd so that the
            pooled map keeps the input's spatial size.
        c, w, **kwargs: forwarded to `plot_pose`.
    """
    # Tensor.numpy() only works on CPU tensors that do not require grad.
    prob = prob.detach().cpu()
    yaws = torch.argmax(prob, -1).numpy() / prob.shape[-1] * 360
    scores = prob.max(-1).values / prob.max()
    valid = scores > thresh
    masked = scores.masked_fill(~valid, 0).float()
    # Non-maximum suppression: keep locations equal to their k x k window max.
    pooled = torch.nn.functional.max_pool2d(
        masked[None, None], k, stride=1, padding=k // 2
    )
    keep = (pooled[0, 0] == masked) & valid
    indices = np.where(keep.numpy())
    plot_pose(
        ax,
        indices[::-1],  # np.where yields (rows, cols); plot_pose wants (x, y)
        yaws[indices],
        s=s,
        c=c,
        dot=False,
        zorder=0.1,
        w=w,
        **kwargs,
    )
def copy_image(im, ax):
    """Re-draw an existing image artist on another Axes.

    The pixel data and every settable property of `im` (colormap, norm,
    extent, interpolation, alpha, ...) are copied. Properties tied to the
    source figure/Axes are left out.

    Args:
        im: source image artist, e.g. an element of `some_ax.images`.
        ax: destination Axes.

    Returns:
        The new image artist returned by `ax.imshow`.
    """
    # Bound to the source figure / Axes; the copy must use the target's own.
    excluded = {"figure", "transform"}
    # properties() reports every getter, including read-only ones such as
    # children, size, tightbbox or window_extent, and the exact set may vary
    # between matplotlib versions. imshow rejects keys without a matching
    # set_* method, so forward only settable properties instead of popping a
    # hard-coded list (which raised KeyError for missing keys).
    prop = {
        key: value
        for key, value in im.properties().items()
        if key not in excluded and callable(getattr(im, f"set_{key}", None))
    }
    return ax.imshow(im.get_array(), **prop)
def add_circle_inset(
    ax,
    center,
    corner=None,
    radius_px=10,
    inset_size=0.4,
    inset_offset=0.005,
    color="red",
):
    """Add a zoomed-in circular inset of the region around `center`.

    A circle is drawn around `center` on `ax`. A new inset Axes in one corner
    of `ax` shows the images of `ax`, clipped to the same circle and zoomed
    onto it.

    Args:
        ax: Axes whose images are zoomed into.
        center: (x, y) center of the zoomed region, in data coordinates.
        corner: (cx, cy), each 0 or 1, selecting the corner of `ax` that holds
            the inset in axes coordinates (0 = left/bottom, 1 = right/top).
            By default the corner opposite to `center` is used.
        radius_px: circle radius in data units.
        inset_size: side of the inset, as a fraction of `ax`.
        inset_offset: gap between the inset and the border of `ax`, as a
            fraction of `ax`.
        color: edge color of both circles.

    Returns:
        The inset Axes.
    """
    if corner is None:
        # data -> display -> axes coordinates ([0, 1] inside the visible area).
        data_to_axes = ax.transData + ax.transAxes.inverted()
        center_axes = np.asarray(data_to_axes.transform(center))
        # Opposite corner to the center. Clipping keeps a valid 0/1 corner
        # even when the center lies outside the visible area.
        corner = 1 - np.clip(np.round(center_axes), 0, 1).astype(int)
    corner = np.array(corner)
    bottom_left = corner * (1 - inset_size - inset_offset) + (1 - corner) * inset_offset
    axins = ax.inset_axes([*bottom_left, inset_size, inset_size])
    if ax.yaxis_inverted():
        axins.invert_yaxis()
    axins.set_axis_off()

    # Two separate patches are needed: a matplotlib artist can only belong to
    # one Axes.
    outline = mpl.patches.Circle(center, radius_px, fill=False, color=color)
    ax.add_patch(outline)
    clip_circle = mpl.patches.Circle(center, radius_px, fill=False, color=color)
    axins.add_patch(clip_circle)

    # Zoom the inset onto the circle, with a 1-unit margin.
    radius_inset = radius_px + 1
    axins.set_xlim([center[0] - radius_inset, center[0] + radius_inset])
    ylim = center[1] - radius_inset, center[1] + radius_inset
    if axins.yaxis_inverted():
        ylim = ylim[::-1]
    axins.set_ylim(ylim)

    for im in ax.images:
        im_copy = copy_image(im, axins)
        im_copy.set_clip_path(clip_circle)
    return axins
def plot_bev(bev, uv, yaw, ax=None, zorder=10, **kwargs):
    """Draw a bird's-eye-view image at a given position and heading.

    In the image frame, the point (w / 2, h) is moved to the origin. The frame
    is then rotated by `yaw` degrees and translated to `uv + 0.5`. The
    triangle through (0, 0), (w - 1, 0) and (w / 2, h - 0.5) of the image
    frame is outlined in black (presumably the view frustum).

    Args:
        bev: image array of shape (h, w, ...).
        uv: (x, y) anchor position in data coordinates (list, tuple, array or
            tensor of two values).
        yaw: rotation in degrees.
        ax: target Axes; defaults to the current Axes.
        zorder: zorder of the image; the outline is drawn at zorder + 1.
        **kwargs: forwarded to `ax.imshow`.
    """
    if ax is None:
        ax = plt.gca()
    h, w = bev.shape[:2]
    # Element-wise conversion accepts plain lists/tuples as well as arrays.
    x0, y0 = (float(v) + 0.5 for v in uv)
    tfm = mpl.transforms.Affine2D().translate(-w / 2, -h)
    tfm = tfm.rotate_deg(yaw).translate(x0, y0)
    # Compose with the *target* Axes' data transform, not plt.gca()'s, which
    # may be a different Axes when `ax` is passed explicitly.
    tfm = tfm + ax.transData
    ax.imshow(bev, transform=tfm, zorder=zorder, **kwargs)
    ax.plot(
        [0, w - 1, w / 2, 0],
        [0, 0, h - 0.5, 0],
        transform=tfm,
        c="k",
        lw=1,
        zorder=zorder + 1,
    )