# Source snapshot: commit e1b51e5 (file size 6,940 bytes).
print("Importing standard...")
import subprocess
import shutil
from pathlib import Path
print("Importing external...")
import torch
import numpy as np
from PIL import Image
# Dimensionality-reduction method that log_input_output uses by default to
# turn multi-channel predictions into 3-channel images: "umap", "tsne" or "pca".
REDUCTION = "pca"
# Only the backend selected above is imported, so the packages for the other
# choices are not required at import time.
if REDUCTION == "umap":
    from umap import UMAP
elif REDUCTION == "tsne":
    from sklearn.manifold import TSNE
elif REDUCTION == "pca":
    from sklearn.decomposition import PCA
def symlog(x):
    """Apply the symmetric log transform ``sign(x) * log(|x| + 1)``.

    The sign of every element is kept and zero maps to zero.

    Args:
        x: a torch tensor.

    Returns:
        A tensor with the same shape as ``x``.
    """
    shifted_magnitude = torch.abs(x) + 1
    return torch.sign(x) * torch.log(shifted_magnitude)
def preprocess_masks_features(masks, features, validate=False):
    """Flatten masks and features into a layout that broadcasts against each other.

    Args:
        masks: tensor of shape (B, M, H, W), expected to be ``torch.bool``.
        features: tensor of shape (B, F, H, W) sharing B, H and W with ``masks``.
        validate: if True, check the preconditions (matching B/H/W, boolean
            masks, no empty mask) and raise ``ValueError`` on violation. Off by
            default because the checks cost time on every call.

    Returns:
        Tuple ``(masks, features, M, B, H, W, F)`` where ``masks`` has shape
        (B, M, 1, H*W) and ``features`` has shape (B, 1, F, H*W).

    Raises:
        ValueError: only when ``validate`` is True and a precondition fails.
    """
    B, M, H, W = masks.shape
    Bf, F, Hf, Wf = features.shape
    if validate:
        if (B, H, W) != (Bf, Hf, Wf):
            raise ValueError(
                f"masks {tuple(masks.shape)} and features {tuple(features.shape)} "
                "must share batch and spatial dimensions"
            )
        if masks.dtype != torch.bool:
            raise ValueError(f"masks must be torch.bool, got {masks.dtype}")
        # Every mask must cover at least one pixel.
        if not masks.reshape(B, M, H * W).any(dim=2).all():
            raise ValueError("masks must not be empty")
    # The per-mask area is no longer computed unconditionally: it was never
    # used and cost a full reduction over B*M*H*W elements on every call.
    masks = masks.reshape(B, M, 1, H * W)
    features = features.reshape(B, 1, F, H * W)
    return masks, features, M, B, H, W, F
def get_row_col(H, W, device):
    """Return evenly spaced pixel coordinates normalised to [0, 1].

    Args:
        H: number of rows.
        W: number of columns.
        device: torch device on which the coordinate tensors are allocated.

    Returns:
        ``(row, col)``: 1-D tensors of length ``H`` and ``W`` that run from
        0 to 1 inclusive.
    """
    row, col = (torch.linspace(0, 1, steps, device=device) for steps in (H, W))
    return row, col
def get_current_git_commit():
    """Return the ``HEAD`` commit hash of the repository in the working directory.

    Returns:
        The hash as a string, or ``None`` if it cannot be determined, for
        example when the directory is not a Git repository or the ``git``
        executable is not available.
    """
    try:
        commit_hash = subprocess.check_output(["git", "rev-parse", "HEAD"])
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git ran but failed (e.g. not a repository).
        # OSError (incl. FileNotFoundError): git is not installed / not on PATH;
        # previously this escaped instead of falling back to None.
        print("An error occurred while trying to retrieve the git commit hash.")
        return None
    # check_output returns bytes with a trailing newline.
    return commit_hash.strip().decode("utf-8")
def clean_dir(dirname):
    """Create ``dirname`` if needed and delete every unfinished subdirectory.

    A subdirectory is considered finished when it contains a ``done.txt``
    file. Every other subdirectory is removed recursively. Plain files
    directly inside ``dirname`` are left untouched.

    Args:
        dirname: path (str or ``Path``) of the directory to clean.
    """
    root = Path(dirname)
    root.mkdir(parents=True, exist_ok=True)
    unfinished = [
        entry
        for entry in root.iterdir()
        if entry.is_dir() and not (entry / "done.txt").exists()
    ]
    for entry in unfinished:
        shutil.rmtree(entry)
def save_tensor_as_image(tensor, dstfile, global_step):
    """Save ``tensor`` as the JPEG file ``<dstfile>_<global_step>.jpg``.

    Args:
        tensor: image tensor, passed on to ``save``.
        dstfile: destination path without the step suffix. A trailing image
            extension such as ``.png`` is dropped; any other dots are kept as
            part of the name.
        global_step: value appended to the file name.
    """
    image_suffixes = {
        ".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tif", ".tiff", ".webp",
    }
    dstfile = Path(dstfile)
    # Path.stem / with_suffix treat everything after the last dot as an
    # extension. A dotted name like "input_lr0.01_00" therefore collapsed to
    # "input_lr0_<step>.jpg", so images overwrote each other, and a dotted
    # global_step was truncated. Only strip genuine image extensions.
    if dstfile.suffix.lower() in image_suffixes:
        base = dstfile.stem
    else:
        base = dstfile.name
    save(tensor, str(dstfile.parent / f"{base}_{global_step}.jpg"))
def minmaxnorm(x):
    """Linearly rescale ``x`` so that its minimum maps to 0 and its maximum to 1.

    Args:
        x: tensor or array with at least one element.

    Returns:
        Rescaled values with the same shape as ``x``. A constant input yields
        all zeros instead of the NaNs that ``0 / 0`` would produce.
    """
    lo = x.min()
    value_range = x.max() - lo
    if value_range == 0:
        # Constant input: avoid 0/0 = NaN. Downstream, to_img casts the result
        # to uint8, which is not well-defined for NaN. Multiplying by 0.0 keeps
        # the shape and gives a floating-point result, like the division path.
        return (x - lo) * 0.0
    return (x - lo) / value_range
def save(tensor, name, channel_offset=0):
    """Write ``tensor`` to the file ``name`` as an image.

    The tensor is first converted by ``to_img``, which normalises it and
    selects which channels to show, starting at ``channel_offset`` when there
    are three or more.
    """
    pixels = to_img(tensor, channel_offset=channel_offset)
    Image.fromarray(pixels).save(name)
def to_img(tensor, channel_offset=0):
    """Convert a (C, H, W) tensor into a uint8 numpy array suitable for PIL.

    The whole tensor is min-max normalised with ``minmaxnorm`` and scaled to
    0..255. The output layout then depends on the channel count:

    * C == 1: a 2-D (H, W) array;
    * C == 2: an (H, W, 3) array whose channels are input channel 0, zeros,
      and input channel 1;
    * C >= 3: channels ``channel_offset`` to ``channel_offset + 2`` as
      (H, W, 3) (fewer channels if that slice runs past the end).

    Args:
        tensor: 3-D tensor; any other rank fails when unpacking its shape.
        channel_offset: first channel to show when C >= 3.

    Returns:
        A numpy array on the CPU.
    """
    scaled = (minmaxnorm(tensor) * 255).to(torch.uint8)
    n_channels, _, _ = scaled.shape
    if n_channels == 1:
        image = scaled[0]
    elif n_channels == 2:
        first, second = scaled[0], scaled[1]
        image = torch.stack([first, torch.zeros_like(first), second]).permute(1, 2, 0)
    elif n_channels >= 3:
        image = scaled[channel_offset : channel_offset + 3].permute(1, 2, 0)
    else:
        image = scaled
    return image.cpu().numpy()
def _build_reducer(reduction):
    """Instantiate a 3-component dimensionality reducer selected by name."""
    # Imported lazily so that any supported ``reduction`` works, not only the
    # one picked by the module-level REDUCTION switch at import time.
    if reduction == "umap":
        from umap import UMAP

        return UMAP(n_components=3)
    if reduction == "tsne":
        from sklearn.manifold import TSNE

        return TSNE(n_components=3)
    if reduction == "pca":
        from sklearn.decomposition import PCA

        return PCA(n_components=3)
    raise ValueError(
        f"unknown reduction {reduction!r}; expected 'umap', 'tsne' or 'pca'"
    )


def _reduce_to_three_channels(y_hat, reduction, resample_size):
    """Project the F channels of a (B, F, H, W) tensor to 3 channels per pixel.

    Returns a (B, 3, H, W) tensor on ``y_hat``'s device.
    """
    B, F, H, W = y_hat.shape
    reducer = _build_reducer(reduction)
    # (B, F, H, W) -> (F, B, H, W) -> (F, B*H*W) -> (B*H*W, F): one row per pixel.
    pixels = y_hat.detach().cpu().permute(1, 0, 2, 3).numpy().reshape(F, -1).T
    if not hasattr(reducer, "transform"):
        # sklearn's TSNE has only fit_transform: it cannot be fitted on a
        # subsample and then applied to new points, so every pixel is embedded
        # at once (this can be slow for large inputs).
        print("dim reduction fit_transform..." + " " * 30, end="\r")
        reduced = reducer.fit_transform(pixels)
    else:
        # Fit on a strided subsample of roughly ``resample_size`` pixels.
        # max(1, ...) avoids a zero slice step (ValueError) when there are
        # fewer pixels than ``resample_size``.
        step = max(1, pixels.shape[0] // resample_size)
        print("dim reduction fit..." + " " * 30, end="\r")
        reducer = reducer.fit(pixels[::step])
        print("dim reduction transform..." + " " * 30, end="\r")
        if reduction == "umap":
            reducer.transform(pixels[:10])  # to numba compile the function
        reduced = reducer.transform(pixels)  # (B*H*W, 3)
    # (B*H*W, 3) -> (3, B, H, W) -> (B, 3, H, W)
    y_hat2 = (
        torch.from_numpy(reduced.T.reshape(3, B, H, W))
        .to(y_hat.device)
        .permute(1, 0, 2, 3)
    )
    print("done" + " " * 30, end="\r")
    return y_hat2


def log_input_output(
    name,
    x,
    y_hat,
    global_step,
    img_dstdir,
    out_dstdir,
    reduce_dim=True,
    reduction=REDUCTION,
    resample_size=20000,
    max_samples=8,
):
    """Save inputs and predictions of the first samples of a batch as images.

    For each of the first ``max_samples`` samples ``i`` (zero-padded to two
    digits as ``ii``), this calls ``save_tensor_as_image`` to write:

    * ``img_dstdir/input_{name}_{ii}`` from ``x[i]``;
    * ``out_dstdir/pred_channel_{name}_{ii}_{c}`` for every prediction channel ``c``;
    * ``out_dstdir/pred_reduced_{name}_{ii}`` (only if ``reduce_dim``): the first
      3 channels of the reduced prediction, or of the raw prediction when it
      has fewer than 3 channels;
    * ``out_dstdir/pred_colorchs_{name}_{ii}``: the first 3 raw prediction channels.

    Args:
        name: tag inserted into every file name.
        x: batch of inputs, indexed per sample.
        y_hat: predictions whose axis 1 must be a singleton, i.e.
            (B, 1, F, H, W); it is reshaped to (B, F, H, W).
        global_step: step appended to every file name.
        img_dstdir: directory for input images (``Path`` or str).
        out_dstdir: directory for prediction images (``Path`` or str).
        reduce_dim: if True and F >= 3, reduce the F channels to 3 with
            ``reduction`` before saving the ``pred_reduced`` images.
        reduction: "umap", "tsne" or "pca".
        resample_size: approximate number of pixels used to fit the reducer.
        max_samples: maximum number of samples to log (default 8).

    Raises:
        ValueError: if a reduction is needed and ``reduction`` is unknown.
    """
    # Accept plain strings as well as path objects supporting ``/``.
    if isinstance(img_dstdir, str):
        img_dstdir = Path(img_dstdir)
    if isinstance(out_dstdir, str):
        out_dstdir = Path(out_dstdir)
    y_hat = y_hat.reshape(
        y_hat.shape[0], y_hat.shape[2], y_hat.shape[3], y_hat.shape[4]
    )  # B, F, H, W
    if reduce_dim and y_hat.shape[1] >= 3:
        y_hat2 = _reduce_to_three_channels(y_hat, reduction, resample_size)
    else:
        y_hat2 = y_hat
    for i in range(min(len(x), max_samples)):
        idx = str(i).zfill(2)
        save_tensor_as_image(
            x[i], img_dstdir / f"input_{name}_{idx}", global_step=global_step
        )
        for c in range(y_hat.shape[1]):
            save_tensor_as_image(
                y_hat[i, c : c + 1],
                out_dstdir / f"pred_channel_{name}_{idx}_{c}",
                global_step=global_step,
            )
        # log color image
        if reduce_dim:
            save_tensor_as_image(
                y_hat2[i][:3],
                out_dstdir / f"pred_reduced_{name}_{idx}",
                global_step=global_step,
            )
        save_tensor_as_image(
            y_hat[i][:3],
            out_dstdir / f"pred_colorchs_{name}_{idx}",
            global_step=global_step,
        )
def check_for_nan(loss, model, batch):
    """Raise if ``loss`` contains NaN, printing debugging information first.

    When a NaN is found, this reports whether ``batch[0]`` or ``batch[1]``,
    any parameter of ``model``, or the model output on ``batch[0]`` contains
    NaN, and then raises.

    Args:
        loss: loss value (tensor or Python number).
        model: module exposing ``named_parameters`` and callable on ``batch[0]``.
        batch: sequence whose items 0 and 1 are the image and mask tensors.

    Raises:
        AssertionError: if ``loss`` contains NaN. This is the same type the
            previous ``assert``-based check raised, so existing handlers keep
            working.
    """
    # An explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would silently disable this guard. ``.any()`` also
    # handles non-scalar losses instead of failing on an ambiguous truth value.
    if not torch.isnan(torch.as_tensor(loss)).any():
        return
    # print things useful to debug
    print("img batch contains nan?", torch.isnan(batch[0]).any())
    print("mask batch contains nan?", torch.isnan(batch[1]).any())
    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            print(name, "contains nan")
    # Diagnostic forward pass only: no autograd graph is needed.
    with torch.no_grad():
        print("output contains nan?", torch.isnan(model(batch[0])).any())
    raise AssertionError("loss is NaN")
def calculate_iou(pred, label):
    """Intersection-over-union of the positions where ``pred`` and ``label`` equal 1.

    Args:
        pred: tensor or array of predictions.
        label: tensor or array of ground-truth values, same shape as ``pred``.

    Returns:
        ``intersection / union`` as a float, or ``0`` when neither input has
        any element equal to 1.
    """
    label_on = label == 1
    pred_on = pred == 1
    overlap = (label_on & pred_on).sum()
    combined = (label_on | pred_on).sum()
    if combined:
        return overlap.item() / combined.item()
    return 0
def load_from_ckpt(net, ckpt_path, strict=True):
    """Load network weights from ``ckpt_path`` into ``net`` if the file exists.

    The checkpoint may be a bare state dict, or a dict that wraps it under
    ``"MODEL_STATE"`` or (checked second) ``"state_dict"``.

    Args:
        net: module whose ``load_state_dict`` receives the weights.
        ckpt_path: checkpoint path; a falsy value skips loading.
        strict: forwarded to ``net.load_state_dict``.

    Returns:
        ``net``, returned for convenience.
    """
    if not ckpt_path:
        return net
    if not Path(ckpt_path).exists():
        # Not an error (callers may rely on "load if present"), but make it
        # visible: a mistyped path would otherwise silently leave the network
        # at its initial weights.
        print("Checkpoint not found, skipping load:", ckpt_path)
        return net
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version / weights_only default; only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    for wrapper_key in ("MODEL_STATE", "state_dict"):
        if wrapper_key in ckpt:
            ckpt = ckpt[wrapper_key]
            break
    net.load_state_dict(ckpt, strict=strict)
    print("Loaded checkpoint from", ckpt_path)
    return net
|