# franchesoni's picture
# v0
# e1b51e5
print("Importing standard...")
import subprocess
import shutil
from pathlib import Path
print("Importing external...")
import torch
import numpy as np
from PIL import Image
REDUCTION = "pca"
if REDUCTION == "umap":
from umap import UMAP
elif REDUCTION == "tsne":
from sklearn.manifold import TSNE
elif REDUCTION == "pca":
from sklearn.decomposition import PCA
def symlog(x):
    """Apply the symmetric logarithm ``sign(x) * log(1 + |x|)`` element-wise.

    Args:
        x: input tensor.

    Returns:
        Tensor with the same shape as ``x``; zero maps to zero and the sign of
        every element is preserved.
    """
    # log1p keeps precision for |x| << 1, where log(|x| + 1) would round the
    # argument to exactly 1 and return 0.
    return torch.sign(x) * torch.log1p(torch.abs(x))
def preprocess_masks_features(masks, features):
    """Flatten the spatial axes of masks and features into broadcastable shapes.

    Args:
        masks: 4-D tensor unpacked as (B, M, H, W).
        features: 4-D tensor unpacked as (B, F, H, W); it is reshaped using the
            B, H and W taken from ``masks``, so those sizes must agree.

    Returns:
        Tuple ``(masks, features, M, B, H, W, F)`` where ``masks`` has shape
        (B, M, 1, H*W) and ``features`` has shape (B, 1, F, H*W).
    """
    B, M, H, W = masks.shape
    # Unpacking still enforces that ``features`` is 4-D; only F is needed.
    _, F, _, _ = features.shape

    # The following assertions should hold; they are disabled for speed.
    # assert masks.dtype == torch.bool
    # assert features.shape[0] == B and features.shape[2:] == (H, W)

    masks = masks.reshape(B, M, 1, H * W)
    features = features.reshape(B, 1, F, H * W)
    return masks, features, M, B, H, W, F
def get_row_col(H, W, device):
    """Return row and column coordinates evenly spaced over [0, 1].

    Args:
        H: number of rows (length of the first returned vector).
        W: number of columns (length of the second returned vector).
        device: device on which both vectors are created.

    Returns:
        Tuple ``(row, col)`` of 1-D tensors with ``H`` and ``W`` entries.
    """
    # One coordinate vector per spatial axis, both spanning 0..1 inclusive.
    row, col = (torch.linspace(0, 1, steps, device=device) for steps in (H, W))
    return row, col
def get_current_git_commit():
    """Return the hash of the current git ``HEAD`` commit, or ``None``.

    Returns:
        The commit hash as a string, or ``None`` when it cannot be determined:
        either ``git rev-parse HEAD`` exits with an error (e.g. not inside a
        repository) or the ``git`` executable cannot be launched at all.
    """
    try:
        # Run the git command to get the current commit hash
        commit_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git ran but failed (e.g., not a Git repository).
        # OSError: git could not be started (not installed, not executable).
        print("An error occurred while trying to retrieve the git commit hash.")
        return None
    # Decode from bytes to a string
    return commit_hash.decode("utf-8")
def clean_dir(dirname):
    """Delete every sub-directory of ``dirname`` that lacks a ``done.txt`` file.

    ``dirname`` itself is created (including parents) when it does not exist.
    Entries that are not directories are left untouched.
    """
    root = Path(dirname)
    root.mkdir(parents=True, exist_ok=True)
    for entry in root.iterdir():
        if not entry.is_dir():
            continue
        # A done.txt marker means the directory must be kept.
        if (entry / "done.txt").exists():
            continue
        shutil.rmtree(entry)
def save_tensor_as_image(tensor, dstfile, global_step):
    """Save ``tensor`` as ``<stem>_<global_step>.jpg`` next to ``dstfile``.

    Args:
        tensor: image tensor forwarded to ``save``.
        dstfile: destination path; only its parent directory and stem are used.
        global_step: value appended to the stem (converted with ``str``).
    """
    dstfile = Path(dstfile)
    # Build the name directly rather than via with_suffix(): with_suffix()
    # treats everything after the last dot as an extension, so a dotted stem
    # or a float step such as 1.5 would be truncated ("img_1.5" -> "img_1.jpg").
    dstfile = dstfile.parent / f"{dstfile.stem}_{global_step}.jpg"
    save(tensor, str(dstfile))
def minmaxnorm(x):
    """Rescale ``x`` linearly so its minimum maps to 0 and its maximum to 1.

    Args:
        x: array-like exposing ``min()`` and ``max()`` (e.g. a torch tensor).

    Returns:
        ``(x - min) / (max - min)``. When every element is equal the range is
        zero; the result is then all zeros instead of NaN from 0/0.
    """
    lo = x.min()
    span = x.max() - lo
    if span == 0:
        # Constant input: avoid 0/0 and keep the numerator (all zeros) as is.
        span = 1
    return (x - lo) / span
def save(tensor, name, channel_offset=0):
    """Convert ``tensor`` with ``to_img`` and write the result to ``name``.

    Args:
        tensor: image tensor accepted by ``to_img``.
        name: destination file name handed to PIL's ``Image.save``.
        channel_offset: forwarded to ``to_img``.
    """
    pixels = to_img(tensor, channel_offset=channel_offset)
    image = Image.fromarray(pixels)
    image.save(name)
def to_img(tensor, channel_offset=0):
    """Convert a 3-D (C, H, W) tensor to a uint8 NumPy image array.

    The tensor is min-max normalised, scaled by 255 and cast to ``uint8``.
    The channel count then decides the layout:

    * 1 channel: a 2-D (H, W) array;
    * 2 channels: an (H, W, 3) array holding the two channels first and last
      with a zero channel in between;
    * 3 or more: an (H, W, k) array of the channels in
      ``[channel_offset, channel_offset + 3)``;
    * 0 channels: the (0, H, W) array is returned unchanged.

    Args:
        tensor: 3-D tensor; unpacking its shape raises otherwise.
        channel_offset: first channel used when there are at least 3 channels.

    Returns:
        NumPy array produced by ``.cpu().numpy()``.
    """
    scaled = (minmaxnorm(tensor) * 255).to(torch.uint8)
    # Unpacking also enforces that the input has exactly three dimensions.
    num_channels, _, _ = scaled.shape
    if num_channels == 1:
        scaled = scaled[0]
    elif num_channels == 2:
        first, second = scaled[0], scaled[1]
        padded = torch.stack([first, torch.zeros_like(first), second], dim=0)
        scaled = padded.permute(1, 2, 0)
    elif num_channels >= 3:
        window = scaled[channel_offset : channel_offset + 3]
        scaled = window.permute(1, 2, 0)
    return scaled.cpu().numpy()
def log_input_output(
    name,
    x,
    y_hat,
    global_step,
    img_dstdir,
    out_dstdir,
    reduce_dim=True,
    reduction=REDUCTION,
    resample_size=20000,
):
    """Save input images and prediction visualisations for up to 8 samples.

    For each of the first ``min(len(x), 8)`` samples this writes, through
    ``save_tensor_as_image``:

    * the input ``x[i]`` into ``img_dstdir``;
    * every prediction channel as a single-channel image into ``out_dstdir``;
    * when ``reduce_dim`` is true, the first three channels of the
      (possibly dimensionality-reduced) prediction;
    * the first three raw prediction channels.

    Args:
        name: tag embedded in every output file name.
        x: batch of inputs, indexed along its first axis.
        y_hat: 5-D prediction tensor; its second axis is dropped by a reshape
            to (B, F, H, W), which only succeeds when that axis has size 1.
        global_step: step value appended to every file name.
        img_dstdir: directory for input images (``str`` or ``Path``).
        out_dstdir: directory for prediction images (``str`` or ``Path``).
        reduce_dim: when true and F >= 3, project the F channels to 3.
        reduction: one of ``"umap"``, ``"tsne"`` or ``"pca"``.
        resample_size: approximate number of pixels used to fit the reducer.

    Raises:
        ValueError: if a reduction is requested with an unknown ``reduction``.
    """
    # Accept plain strings as well as Path objects for the destinations.
    img_dstdir = Path(img_dstdir)
    out_dstdir = Path(out_dstdir)

    y_hat = y_hat.reshape(
        y_hat.shape[0], y_hat.shape[2], y_hat.shape[3], y_hat.shape[4]
    )
    if reduce_dim and y_hat.shape[1] >= 3:
        # Import lazily so that a ``reduction`` different from the module-level
        # REDUCTION constant does not hit a class that was never imported.
        if reduction == "umap":
            from umap import UMAP

            reducer = UMAP(n_components=3)
        elif reduction == "tsne":
            # NOTE(review): scikit-learn's TSNE does not appear to provide a
            # transform() method, so the calls below would fail here - confirm.
            from sklearn.manifold import TSNE

            reducer = TSNE(n_components=3)
        elif reduction == "pca":
            from sklearn.decomposition import PCA

            reducer = PCA(n_components=3)
        else:
            raise ValueError(
                f"unknown reduction {reduction!r}; expected 'umap', 'tsne' or 'pca'"
            )
        np_y_hat = y_hat.detach().cpu().permute(1, 0, 2, 3).numpy()  # F, B, H, W
        np_y_hat = np_y_hat.reshape(np_y_hat.shape[0], -1)  # F, BHW
        np_y_hat = np_y_hat.T  # BHW, F
        # With fewer pixels than resample_size the integer division yields 0,
        # and a slice step of 0 is an error; fall back to using every pixel.
        stride = max(1, np_y_hat.shape[0] // resample_size)
        sampled_pixels = np_y_hat[::stride]
        print("dim reduction fit..." + " " * 30, end="\r")
        reducer = reducer.fit(sampled_pixels)
        print("dim reduction transform..." + " " * 30, end="\r")
        reducer.transform(np_y_hat[:10])  # to numba compile the function
        np_y_hat = reducer.transform(np_y_hat)  # BHW, 3
        # revert back to original shape
        y_hat2 = (
            torch.from_numpy(
                np_y_hat.T.reshape(3, y_hat.shape[0], y_hat.shape[2], y_hat.shape[3])
            )
            .to(y_hat.device)
            .permute(1, 0, 2, 3)
        )
        print("done" + " " * 30, end="\r")
    else:
        y_hat2 = y_hat

    for i in range(min(len(x), 8)):
        sample_tag = f"{name}_{str(i).zfill(2)}"
        save_tensor_as_image(
            x[i],
            img_dstdir / f"input_{sample_tag}",
            global_step=global_step,
        )
        for c in range(y_hat.shape[1]):
            save_tensor_as_image(
                y_hat[i, c : c + 1],
                out_dstdir / f"pred_channel_{sample_tag}_{c}",
                global_step=global_step,
            )
        # log color image
        assert len(y_hat2.shape) == 4, "should be B, F, H, W"
        if reduce_dim:
            save_tensor_as_image(
                y_hat2[i][:3],
                out_dstdir / f"pred_reduced_{sample_tag}",
                global_step=global_step,
            )
        save_tensor_as_image(
            y_hat[i][:3],
            out_dstdir / f"pred_colorchs_{sample_tag}",
            global_step=global_step,
        )
def check_for_nan(loss, model, batch):
    """Raise ``AssertionError`` if ``loss`` contains NaN, after printing diagnostics.

    The check is an explicit test rather than an ``assert`` statement so that
    it still runs when Python is started with ``-O``.

    Args:
        loss: loss tensor (scalar or multi-element).
        model: module whose parameters and output are inspected on failure.
        batch: indexable whose items 0 and 1 are tensors; item 0 is also fed
            to ``model`` for the output check.

    Raises:
        AssertionError: if any element of ``loss`` is NaN.
    """
    if not torch.isnan(loss).any():
        return
    # print things useful to debug
    # does the batch contain nan?
    print("img batch contains nan?", torch.isnan(batch[0]).any())
    print("mask batch contains nan?", torch.isnan(batch[1]).any())
    # does the model weights contain nan?
    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            print(name, "contains nan")
    # does the output contain nan? (diagnostic only, so no autograd graph)
    with torch.no_grad():
        print("output contains nan?", torch.isnan(model(batch[0])).any())
    # now raise the error
    raise AssertionError("loss is NaN")
def calculate_iou(pred, label):
    """Compute intersection-over-union of the entries equal to 1.

    Args:
        pred: predicted labels.
        label: reference labels, compared element-wise with ``pred``.

    Returns:
        ``0`` when neither input contains a 1, otherwise the ratio of
        positions where both are 1 to positions where at least one is 1.
    """
    pred_hits = pred == 1
    label_hits = label == 1
    union = (label_hits | pred_hits).sum()
    # Guard against dividing by zero when there are no positives at all.
    if not union:
        return 0
    intersection = (label_hits & pred_hits).sum()
    return intersection.item() / union.item()
def load_from_ckpt(net, ckpt_path, strict=True):
    """Load network weights from a checkpoint file, if one is available.

    Args:
        net: module receiving the weights through ``load_state_dict``.
        ckpt_path: checkpoint path; a falsy value skips loading silently.
        strict: forwarded to ``load_state_dict``.

    Returns:
        ``net``. It is returned unmodified when ``ckpt_path`` is falsy or does
        not exist; the latter case is reported on stdout so that a mistyped
        path does not go unnoticed.
    """
    if not ckpt_path:
        return net
    if not Path(ckpt_path).exists():
        print("Checkpoint not found, skipping load:", ckpt_path)
        return net
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # The state dict may be wrapped under one of these keys; the first match wins.
    for wrapper_key in ("MODEL_STATE", "state_dict"):
        if wrapper_key in ckpt:
            ckpt = ckpt[wrapper_key]
            break
    net.load_state_dict(ckpt, strict=strict)
    print("Loaded checkpoint from", ckpt_path)
    return net