# helblazer811's picture
# "Orphan branch commit with a readme"
# 55866f4
import torch
import torch.nn.functional as F
import einops
import matplotlib.pyplot as plt
import numpy as np
def overlay_heatmap_on_image(
    image,
    heatmap: torch.Tensor,
    save_path="results/heatmap_overlay.pdf",
):
    """
    Overlay a 2D heatmap on an image and save the figure to ``save_path``.

    The heatmap is bilinearly resized to the image's height and width and
    drawn over the image with the ``jet`` colormap at 50% opacity.

    Args:
        image: Image drawn underneath. Its first two dimensions are used as
            (height, width), so an (H, W) or (H, W, C) array-like is expected.
            Objects without a ``.shape`` (e.g. PIL images) are sized via
            ``np.shape``.
        heatmap: 2D heatmap, either a ``torch.Tensor`` or anything
            ``torch.as_tensor`` accepts (such as a NumPy array).
        save_path: Output file; matplotlib picks the format from the extension.

    Raises:
        AssertionError: If ``heatmap`` is not 2D.
    """
    # F.interpolate needs a tensor, so keep the heatmap in torch until after
    # resizing. Use float32 on the CPU so that low-precision dtypes (e.g.
    # bfloat16) can later be converted to NumPy.
    heatmap = torch.as_tensor(heatmap).detach().to(device="cpu", dtype=torch.float32)
    assert heatmap.dim() == 2, "Heatmap should be 2D"

    height, width = np.shape(image)[:2]
    # F.interpolate expects (batch, channels, H, W): add two leading unit
    # dims, then index them away again.
    upscaled = F.interpolate(
        heatmap[None, None],
        size=(height, width),
        mode="bilinear",
        align_corners=False,
    )[0, 0].numpy()

    fig = plt.figure()
    plt.imshow(image)
    plt.imshow(upscaled, cmap="jet", alpha=0.5)
    plt.axis("off")
    plt.savefig(save_path, dpi=300)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def _project_patches_onto_concepts(
    image_patch_vectors: torch.Tensor,
    concept_basis: torch.Tensor,
    softmax: bool,
) -> np.ndarray:
    """Score each patch against each concept and average into square per-concept maps.

    Returns a float32 NumPy array of shape (concepts, side, side), where
    side * side equals the number of patches.
    """
    # The einsum patterns below fix the layouts. Vectors are
    # (layers, time, [heads,] patches, d) and the basis is
    # (layers, time, [heads,] concepts, d). Rank 5 adds the heads axis.
    leading = "layers time heads" if image_patch_vectors.dim() == 5 else "layers time"
    projections = einops.einsum(
        image_patch_vectors,
        concept_basis,
        f"{leading} patches d, {leading} concepts d -> {leading} concepts patches",
    )
    if softmax:
        # dim=-2 is the concepts axis: normalize across concepts for each patch.
        projections = torch.softmax(projections, dim=-2)
    # Average over every leading axis (layers, time and, if present, heads).
    projections = einops.reduce(
        projections,
        f"{leading} concepts patches -> concepts patches",
        reduction="mean",
    )
    num_patches = projections.shape[-1]
    side = int(round(num_patches ** 0.5))
    if side * side != num_patches:
        raise ValueError(
            f"Expected a square patch grid, got {num_patches} patches"
        )
    projections = einops.rearrange(
        projections, "concepts (h w) -> concepts h w", h=side, w=side
    )
    return projections.to(torch.float32).detach().cpu().numpy()


def plot_concept_heatmaps(
    image,
    concept_basis: torch.Tensor,
    concept_list: list[str],
    image_patch_vectors: torch.Tensor,
    softmax=True,
    normalize_maps=True
):
    """
    Plot concept heatmaps to check that the concept basis is reasonable for the image.

    Each patch vector is dotted with each concept vector. The scores are
    optionally softmaxed across concepts, then averaged over layers, time and
    heads (if present). The result is reshaped to a square grid per concept.

    With 30 or fewer concepts, the image and all maps go into one figure,
    saved to ``results/concept_heatmaps.png``. With more than 30, each concept
    gets its own file, ``results/concept_heatmaps/{concept}.png``.

    Args:
        image: Image shown in the first panel (single-figure mode only).
        concept_basis: Concept vectors laid out as
            (layers, time, [heads,] concepts, d), matching ``image_patch_vectors``.
        concept_list: Concept names. Entry ``i`` labels map ``i``.
        image_patch_vectors: Patch vectors laid out as
            (layers, time, [heads,] patches, d), i.e. 4D or 5D. The number of
            patches must be a perfect square.
        softmax: Softmax the scores across concepts before averaging.
        normalize_maps: Use one shared color range for all maps. If False,
            each map is autoscaled independently.

    Raises:
        AssertionError: If ``image_patch_vectors`` is not 4D or 5D.
        ValueError: If the number of patches is not a perfect square.
    """
    assert len(image_patch_vectors.shape) in [4, 5], "Image patch vectors should be 4D or 5D, make sure you include layers and timesteps."
    maps = _project_patches_onto_concepts(image_patch_vectors, concept_basis, softmax)
    # vmin/vmax of None are imshow's defaults, i.e. per-map autoscaling.
    vmin, vmax = (maps.min(), maps.max()) if normalize_maps else (None, None)

    if len(concept_list) > 30:
        # Too many panels for a single figure: save one file per concept.
        for i, concept in enumerate(concept_list):
            fig = plt.figure()
            plt.imshow(maps[i], cmap="plasma", vmin=vmin, vmax=vmax)
            plt.title(concept)
            plt.savefig(f"results/concept_heatmaps/{concept}.png")
            plt.close(fig)
        return

    # squeeze=False always returns a 2D array of Axes, even for a single
    # panel (when concept_list is empty).
    fig, axs = plt.subplots(
        1,
        len(concept_list) + 1,
        figsize=(4 * len(concept_list) + 4, 4),
        squeeze=False,
    )
    axs = axs[0]
    axs[0].imshow(image)
    axs[0].set_title("Image")
    axs[0].axis("off")
    for i, concept in enumerate(concept_list):
        ax = axs[i + 1]
        ax.imshow(maps[i], cmap="plasma", vmin=vmin, vmax=vmax)
        ax.set_title(concept)
        ax.axis("off")
    plt.savefig("results/concept_heatmaps.png")
    plt.close(fig)
def plot_coefficients_heatmap(
    coefficients: torch.Tensor,
    concepts: list[str],
    save_path="results/group_coding_heatmaps.png",
    *,
    normalize_maps=False,
):
    """
    Plot one spatial map of coefficients per concept, side by side, and save it.

    Args:
        coefficients: 2D tensor indexed as ``coefficients[concept, patch]``.
            Row ``j`` is plotted under ``concepts[j]``, and rows beyond
            ``len(concepts)`` are ignored. The number of patches must be a
            perfect square; the patch axis is reshaped to that grid.
        concepts: Panel titles, one per plotted row.
        save_path: Output file for the figure.
        normalize_maps: If True, all panels share one color range. The default
            (False) autoscales each panel independently.

    Raises:
        ValueError: If ``concepts`` is empty, if ``coefficients`` has fewer
            rows than there are concepts, or if the number of patches is not
            a perfect square.
    """
    if not concepts:
        raise ValueError("At least one concept is required to plot coefficients")
    # Go through float32 so low-precision dtypes (e.g. bfloat16) can be
    # converted to NumPy.
    values = coefficients.to(torch.float32).detach().cpu().numpy()
    if values.shape[0] < len(concepts):
        raise ValueError(
            f"Got {len(concepts)} concepts but only {values.shape[0]} coefficient rows"
        )
    values = values[: len(concepts)]

    num_patches = values.shape[1]
    side = int(round(num_patches ** 0.5))
    if side * side != num_patches:
        raise ValueError(
            f"Expected a square patch grid, got {num_patches} patches"
        )
    maps = einops.rearrange(values, "concepts (w h) -> concepts w h", w=side, h=side)

    # vmin/vmax of None are imshow's defaults, i.e. per-panel autoscaling.
    vmin, vmax = (maps.min(), maps.max()) if normalize_maps else (None, None)

    # squeeze=False keeps axs indexable even when there is only one concept.
    fig, axs = plt.subplots(1, len(concepts), figsize=(4 * len(concepts), 4), squeeze=False)
    for ax, concept, concept_map in zip(axs[0], concepts, maps):
        ax.imshow(concept_map, cmap="plasma", vmin=vmin, vmax=vmax)
        ax.set_title(concept)
        ax.axis("off")
    plt.savefig(save_path)
    plt.close(fig)