|
import torch |
|
from torch import nn |
|
import numpy as np |
|
from cv2 import resize |
|
import cv2 |
|
from pathlib import Path |
|
|
|
from network import EfficientViT_l1_r224 |
|
from losses import IISLoss, activate |
|
from utils import minmaxnorm, load_from_ckpt |
|
|
|
|
|
class Busam:
    """Click-based segmentation and edge extraction on top of an EfficientViT net.

    The network maps an image to a dense feature map ``pred`` of shape
    ``(B, F, pH, pW)``. A mask is obtained by querying that map at a clicked
    location (:meth:`get_mask`); edge maps are derived from Sobel gradients of
    the features (:meth:`sobel_from_pred`, :meth:`canny_from_pred`).
    """

    def __init__(self, checkpoint, device, side=224):
        """Build the network, load its weights and put it in eval mode.

        Args:
            checkpoint: forwarded to ``load_from_ckpt``; presumably a path to
                a weights file (TODO confirm).
            device: torch device for the network and every prepared input.
            side: images are resized to ``side x side`` before inference.
        """
        out_channels = 16  # number of feature channels the network outputs
        use_norm_params = False
        net = EfficientViT_l1_r224(
            out_channels=out_channels, use_norm_params=use_norm_params, pretrained=False
        )
        net = load_from_ckpt(net, checkpoint)
        net = net.to(device)
        net.eval()  # inference only
        self.net = net
        self.device = device
        self.side = side

    def prepare_img(self, img):
        """Turn an ``H x W x 3`` uint8 image into a network input tensor.

        The image is resized to ``side x side``, scaled from ``[0, 255]`` to
        ``[-0.5, 0.5]`` and returned as a float32 tensor of shape
        ``(1, 3, side, side)`` on ``self.device``. Channel order is passed
        through unchanged.

        Raises:
            AssertionError: if *img* is not a 3-channel uint8 array in
                ``[0, 255]`` (note: asserts are stripped under ``python -O``).
        """
        assert len(img.shape) == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.shape[2] == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.min() >= 0, "min should be more than 0 but is " + str(img.min())
        assert img.max() <= 255, "max should be less than 255 but is " + str(img.max())
        assert img.dtype == np.uint8, "dtype should be np.uint8 but is " + str(
            img.dtype
        )
        # cv2.resize takes dsize as (width, height); square here, so order is moot.
        nimg = resize(img, (self.side, self.side))
        tensorimg = (
            (torch.from_numpy(nimg / 255).permute(2, 0, 1) - 0.5)
            .float()[None]
            .to(self.device)
        )
        return tensorimg

    def process_image(self, img, do_activate=False):
        """Run the network on *img* without tracking gradients.

        Args:
            img: ``H x W x 3`` uint8 image (see :meth:`prepare_img`).
            do_activate: if true, pass the features through ``activate`` with
                the ``"symlog"`` mode before returning them.

        Returns:
            ``(pred, (H, W))``: the ``(B, F, pH, pW)`` feature map and the
            original image size. This tuple is the ``aux`` argument expected by
            :meth:`get_mask`.
        """
        with torch.no_grad():
            x = self.prepare_img(img)
            pred = self.net(x)
            if do_activate:
                B, F, pH, pW = pred.shape
                # The (F, pH*pW) view drops the batch axis, so this relies on
                # B == 1, which prepare_img guarantees.
                features, _, _, _ = activate(
                    pred.view(F, pH * pW), None, "symlog", False, False, False
                )
                pred = features.view(B, F, pH, pW)
        H, W = img.shape[:2]
        return pred, (H, W)

    @staticmethod
    def _click_to_index(click, orig_size, feat_size):
        """Map a ``(row, col)`` click in original-image pixels to a flat feature index.

        Coordinates are rescaled to the ``feat_size`` grid, truncated to ints and
        clamped into the grid, so float clicks and clicks on/past the image
        border still select a valid feature (instead of wrapping into the next
        row or indexing out of range).
        """
        oH, oW = orig_size
        H, W = feat_size
        row = min(max(int(click[0] * H // oH), 0), H - 1)
        col = min(max(int(click[1] * W // oW), 0), W - 1)
        return row * W + col

    def get_mask(self, aux, click):
        """Return a boolean ``(oH, oW)`` mask for the object under *click*.

        Args:
            aux: the ``(pred, (oH, oW))`` tuple returned by :meth:`process_image`;
                only the first batch element of ``pred`` is used.
            click: ``(row, col)`` in original-image pixel coordinates.
        """
        pred = aux[0][0]
        oH, oW = aux[1]
        F, H, W = pred.shape
        features = pred.view(F, H * W)
        sindex = self._click_to_index(click, (oH, oW), (H, W))
        mask = IISLoss.get_mask_from_query(features, sindex)
        mask = mask.reshape(H, W)
        # Upsample through uint8 so cv2.resize can interpolate, then threshold.
        # cv2.resize takes dsize as (width, height).
        mask_u8 = (mask.cpu().numpy() * 255).astype(np.uint8)
        return resize(mask_u8, (oW, oH)) > 100

    def get_gradients(self, pred, size):
        """Per-channel Sobel responses of *pred*, reduced by an L2 norm over channels.

        Args:
            pred: feature map of shape ``(1, F, H, W)``; the final ``view``
                requires a batch size of 1.
            size: unused; kept for interface compatibility.

        Returns:
            ``(edge_x, edge_y)``, each a non-negative ``(H, W)`` tensor. Borders
            use zero padding, so edge pixels also respond to the padding.
        """
        F, H, W = pred[0].shape
        # Match pred's dtype/device so conv2d does not fail on a dtype mismatch.
        sobel_x = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=pred.dtype, device=pred.device
        )
        sobel_y = sobel_x.T
        # Grouped conv weight shape is (F, 1, 3, 3): one kernel per channel.
        sobel_x = sobel_x.repeat(F, 1, 1, 1)
        sobel_y = sobel_y.repeat(F, 1, 1, 1)
        edge_x = torch.nn.functional.conv2d(pred, sobel_x, padding=1, groups=F).view(
            F, H, W
        )
        edge_y = torch.nn.functional.conv2d(pred, sobel_y, padding=1, groups=F).view(
            F, H, W
        )
        edge_x = torch.norm(edge_x, dim=0, p=2)
        edge_y = torch.norm(edge_y, dim=0, p=2)
        return edge_x, edge_y

    def sobel_from_pred(self, pred, size):
        """Sobel gradient magnitude ``sqrt(edge_x**2 + edge_y**2)`` of *pred* as ``(H, W)``."""
        edge_x, edge_y = self.get_gradients(pred, size)
        edge = torch.sqrt(edge_x**2 + edge_y**2)
        return edge

    @staticmethod
    def _joint_minmax(a, b):
        """Min-max normalize *a* and *b* to ``[0, 1]`` using one shared range.

        A zero range (both inputs constant and equal) maps to all zeros instead
        of producing NaN from ``0 / 0``.
        """
        amin = torch.minimum(a.min(), b.min())
        amax = torch.maximum(a.max(), b.max())
        span = amax - amin
        if span <= 0:
            return torch.zeros_like(a), torch.zeros_like(b)
        return (a - amin) / span, (b - amin) / span

    def canny_from_pred(self, pred, size, th_low=10000, th_high=20000):
        """Canny edge map of *pred* computed from its Sobel feature gradients.

        The gradients are jointly normalized to ``[0, 1]``, scaled to int16 and
        fed to ``cv2.Canny`` as precomputed derivatives.

        Args:
            pred: ``(1, F, H, W)`` feature map (see :meth:`get_gradients`).
            size: unused; forwarded to :meth:`get_gradients`.
            th_low, th_high: hysteresis thresholds on the int16 gradient scale.
                If one is falsy, the other is used for both.

        Returns:
            The ``uint8`` edge image returned by ``cv2.Canny``.
        """
        th_low = th_low or th_high
        th_high = th_high or th_low

        edge_x, edge_y = self.get_gradients(pred, size)
        # NOTE(review): get_gradients returns channel norms, which are never
        # negative, so the gradient direction Canny derives from (dx, dy) is
        # confined to one quadrant; exact non-maximum suppression would need
        # signed gradients.
        edge_x, edge_y = self._joint_minmax(edge_x, edge_y)
        canny = cv2.Canny(cast_to_int16(edge_x), cast_to_int16(edge_y), th_low, th_high)
        return canny
|
|
|
|
|
def cast_to_int16(x):
    """Scale values in ``[-1, 1]`` to the int16 range ``[-32767, 32767]``.

    Accepts a torch tensor (any device, with or without ``requires_grad``) or
    anything ``np.asarray`` understands. NaN becomes 0, and values outside
    ``[-1, 1]`` are clipped first so they saturate instead of wrapping around
    in int16. Fractional parts are truncated toward zero by ``astype``.

    Returns:
        ``np.ndarray`` of dtype ``np.int16`` with the input's shape.
    """
    if isinstance(x, torch.Tensor):
        # detach() so tensors that require grad can be converted as well.
        x = x.detach().cpu().numpy()
    x = np.clip(np.nan_to_num(np.asarray(x), nan=0.0), -1.0, 1.0)
    return (x * 32767).astype(np.int16)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|