Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn.functional as F | |
import einops | |
import matplotlib.pyplot as plt | |
import numpy as np | |
def overlay_heatmap_on_image( | |
image, | |
heatmap: torch.Tensor, | |
save_path="results/heatmap_overlay.pdf", | |
): | |
""" | |
Overlay the given heatmap on the image | |
""" | |
if isinstance(heatmap, torch.Tensor): | |
heatmap = heatmap.to(torch.float32).detach().cpu().numpy() | |
assert len(heatmap.shape) == 2, "Heatmap should be 2D" | |
plt.figure() | |
plt.imshow(image) | |
# Upscale heatmap to image | |
heatmap = F.interpolate( | |
heatmap.unsqueeze(0).unsqueeze(0), | |
size=image.shape[:2], | |
mode="bilinear", | |
align_corners=False | |
) | |
heatmap = heatmap.squeeze(0).squeeze(0).numpy() | |
plt.imshow(heatmap, cmap="jet", alpha=0.5) | |
plt.axis("off") | |
plt.savefig(save_path, dpi=300) | |
def plot_concept_heatmaps( | |
image, | |
concept_basis: torch.Tensor, | |
concept_list: list[str], | |
image_patch_vectors: torch.Tensor, | |
softmax=True, | |
normalize_maps=True | |
): | |
""" | |
Plot the concept heatmaps to ensure that the concept basis is | |
reasonable for the given image. | |
""" | |
assert len(image_patch_vectors.shape) in [4, 5], "Image patch vectors should be 4D or 5D, make sure you include layers and timesteps." | |
fig, axs = plt.subplots(1, len(concept_list) + 1, figsize=(4 * len(concept_list) + 4, 4)) | |
# Normalize the concept basis | |
# concept_basis = concept_basis / concept_basis.norm(dim=-1, keepdim=True) | |
if len(image_patch_vectors.shape) == 5: | |
image_patch_projections = einops.einsum( | |
image_patch_vectors, | |
concept_basis, | |
"layers time heads patches d, layers time heads concepts d -> layers time heads concepts patches", | |
) | |
if softmax: | |
image_patch_projections = torch.softmax(image_patch_projections, dim=-2) | |
image_patch_projections = einops.reduce( | |
image_patch_projections, | |
"layers time heads concepts patches -> concepts patches", | |
reduction="mean" | |
) | |
image_patch_projections = einops.rearrange( | |
image_patch_projections, | |
"concepts (h w) -> concepts h w", | |
h=64, | |
w=64 | |
) | |
else: | |
image_patch_projections = einops.einsum( | |
image_patch_vectors, | |
concept_basis, | |
"layers time patches d, layers time concepts d -> layers time concepts patches", | |
) | |
if softmax: | |
image_patch_projections = torch.softmax(image_patch_projections, dim=-2) | |
image_patch_projections = einops.reduce( | |
image_patch_projections, | |
"layers time concepts patches -> concepts patches", | |
reduction="mean" | |
) | |
image_patch_projections = einops.rearrange( | |
image_patch_projections, | |
"concepts (w h) -> concepts w h", | |
h=64, | |
w=64 | |
) | |
image_patch_projections = image_patch_projections.to(torch.float32).detach().cpu().numpy() | |
# Get min and max values | |
min_val = image_patch_projections.min() | |
max_val = image_patch_projections.max() | |
if len(concept_list) > 30: | |
for concept in concept_list: | |
plt.figure() | |
if normalize_maps: | |
plt.imshow( | |
image_patch_projections[concept_list.index(concept)], | |
cmap="plasma", | |
vmin=min_val, | |
vmax=max_val | |
) | |
else: | |
plt.imshow( | |
image_patch_projections[concept_list.index(concept)], | |
cmap="plasma" | |
) | |
plt.title(concept) | |
plt.savefig(f"results/concept_heatmaps/{concept}.png") | |
plt.close() | |
else: | |
# Plot the image | |
axs[0].imshow(image) | |
axs[0].set_title("Image") | |
axs[0].axis("off") | |
# Plot the concept heatmaps | |
for i, concept in enumerate(concept_list): | |
if normalize_maps: | |
axs[i + 1].imshow( | |
image_patch_projections[i], | |
cmap="plasma", | |
vmin=min_val, | |
vmax=max_val | |
) | |
else: | |
axs[i + 1].imshow( | |
image_patch_projections[i], | |
cmap="plasma" | |
) | |
axs[i + 1].set_title(concept) | |
axs[i + 1].axis("off") | |
# Save the figure | |
plt.savefig("results/concept_heatmaps.png") | |
plt.close() | |
def plot_coefficients_heatmap( | |
coefficients: torch.Tensor, | |
concepts: list[str], | |
save_path="results/group_coding_heatmaps.png" | |
): | |
# Convert the coefficients to a dictionary | |
coefficients = coefficients.detach().cpu().numpy() | |
coefficients = coefficients.T | |
dictionaries = [] | |
for i in range(coefficients.shape[0]): | |
dictionary = {} | |
for j, concept in enumerate(concepts): | |
dictionary[concept] = coefficients[i, j] | |
dictionaries.append(dictionary) | |
# Convert dictionaries to numpy arrays | |
dictionaries = [np.array([dictionary[concept] for concept in concepts]) for dictionary in dictionaries] | |
dictionaries = np.stack(dictionaries, axis=0) | |
dictionaries = einops.rearrange( | |
dictionaries, | |
"(w h) concepts -> concepts w h", | |
w=64, | |
h=64 | |
) | |
# Get min and max | |
min_val = dictionaries.min() | |
max_val = dictionaries.max() | |
# Plot the coeffients of each dictioanry for each patch | |
fig, axs = plt.subplots(1, len(concepts), figsize=(4 * len(concepts), 4)) | |
for concept_index, concept in enumerate(concepts): | |
axs[concept_index].imshow( | |
dictionaries[concept_index], | |
cmap="plasma", | |
# vmin=min_val, | |
# vmax=max_val | |
) | |
axs[concept_index].set_title(concept) | |
axs[concept_index].set_xticks([]) | |
axs[concept_index].set_yticks([]) | |
axs[concept_index].axis("off") | |
plt.savefig(save_path) | |
plt.close() |