|
|
|
|
|
|
|
|
|
import sys |
|
import os |
|
import numpy as np |
|
import pydensecrf.densecrf as dcrf |
|
import pydensecrf.utils as utils |
|
import torch |
|
import torch.nn.functional as F |
|
import torchvision.transforms.functional as VF |
|
sys.path.append(os.getcwd()) |
|
from modules.transforms import UnNormalize as unnorm |
|
|
|
|
|
# Hyper-parameters of the fully connected CRF built in ``dense_crf``.

# Iteration count handed to ``DenseCRF2D.inference``.
MAX_ITER = 10

# ``compat`` and ``sxy`` of the Gaussian (position-only) pairwise term.
POS_W = 3
POS_XY_STD = 1

# ``compat``, ``sxy`` and ``srgb`` of the bilateral (position + colour)
# pairwise term.
Bi_W = 4
Bi_XY_STD = 67
Bi_RGB_STD = 3

# NOTE(review): not referenced by the code visible in this file; presumably
# a per-channel image mean in BGR order -- confirm against callers.
BGR_MEAN = np.array([104.008, 116.669, 122.675])
|
|
|
|
|
def dense_crf(image_tensor: torch.FloatTensor, output_logits: torch.FloatTensor):
    """Refine per-class logits of one image with a fully connected CRF.

    Args:
        image_tensor: Normalised image tensor for a single sample. It is
            un-normalised with ``UnNormalize`` and converted to a PIL image,
            so it is assumed to be laid out as (C, H, W) -- TODO confirm.
        output_logits: Class logits for the same sample, shaped
            (num_classes, h, w). They are bilinearly resized to the image
            resolution before the softmax is taken.

    Returns:
        A NumPy array of shape (num_classes, H, W) holding the result of
        ``MAX_ITER`` CRF inference iterations.
    """
    # NOTE(review): some UnNormalize implementations modify their input in
    # place; confirm whether the caller's tensor must be protected.
    rgb = np.array(VF.to_pil_image(unnorm()(image_tensor)))
    # Reverse the channel axis and make the buffer C-contiguous, which the
    # pydensecrf bindings require.
    image = np.ascontiguousarray(rgb[:, :, ::-1])
    height, width = image.shape[:2]

    resized = F.interpolate(
        output_logits.unsqueeze(0),
        size=(height, width),
        mode="bilinear",
        align_corners=False,
    )
    # Drop only the batch axis added above. A bare ``squeeze()`` would also
    # remove a size-1 class or spatial axis and break the indexing below.
    resized = resized.squeeze(0)
    # ``detach`` so tensors that still require grad can be exported to NumPy.
    output_probs = F.softmax(resized, dim=0).detach().cpu().numpy()

    num_classes = output_probs.shape[0]

    unary = np.ascontiguousarray(utils.unary_from_softmax(output_probs))

    crf = dcrf.DenseCRF2D(width, height, num_classes)
    crf.setUnaryEnergy(unary)
    crf.addPairwiseGaussian(sxy=POS_XY_STD, compat=POS_W)
    crf.addPairwiseBilateral(
        sxy=Bi_XY_STD, srgb=Bi_RGB_STD, rgbim=image, compat=Bi_W
    )

    marginals = crf.inference(MAX_ITER)
    return np.array(marginals).reshape((num_classes, height, width))
|
|
|
|
|
|
|
def _apply_crf(tup):
    """Call ``dense_crf`` on an ``(image_tensor, output_logits)`` pair.

    ``batched_crf`` passes this function to ``pool.map``, which hands each
    worker a single argument, so the pair arrives packed in ``tup``.
    """
    image_tensor = tup[0]
    output_logits = tup[1]
    return dense_crf(image_tensor, output_logits)
|
|
|
|
|
def batched_crf(pool, img_tensor, prob_tensor):
    """Apply ``dense_crf`` to every sample of a batch via ``pool.map``.

    Args:
        pool: Object exposing ``map(func, iterable)``; presumably a
            ``multiprocessing`` pool -- confirm against the caller.
        img_tensor: Batch of image tensors, iterated along the first axis.
        prob_tensor: Batch of per-class logits, iterated along the first
            axis in lockstep with ``img_tensor``.

    Returns:
        A tensor stacking the per-sample CRF outputs along a new leading
        batch axis.
    """
    # Detach and move to CPU before the samples are handed to the pool.
    images = img_tensor.detach().cpu()
    logits = prob_tensor.detach().cpu()
    refined = pool.map(_apply_crf, zip(images, logits))
    return torch.stack([torch.from_numpy(arr) for arr in refined], dim=0)
|
|
|
|