|
import torch |
|
import sys |
|
import os |
|
import torch.nn.functional as F |
|
|
|
sys.path.append(os.getcwd()) |
|
from modules.crf import dense_crf |
|
from modules.median_pool import MedianPool2d |
|
|
|
|
|
class PriMaPs:
    """Derive a pseudo-label map from a feature map by iterative mask extraction.

    Masks are peeled off one at a time: the first principal direction of the
    still-unassigned features picks a seed pixel, and every unassigned pixel
    whose cosine similarity to that seed is high enough forms the next mask.
    The stacked masks are median-filtered, given a background channel, refined
    with ``dense_crf`` and finally converted to labels.
    """

    def __init__(self,
                 threshold=0.4,
                 ignore_id=27):
        """Store the hyper-parameters and build the 3x3 median filter.

        Args:
            threshold: Fraction of a mask's peak similarity below which
                similarities are zeroed, i.e. the pixel is left out of the mask.
            ignore_id: Label written to pixels that end up in no mask.
        """
        super().__init__()
        self.threshold = threshold
        self.ignore_id = ignore_id
        self.medianfilter = MedianPool2d(kernel_size=3, stride=1, padding=1)

    def _extract_masks(self, feat):
        """Return the list of soft ``(H, W)`` masks peeled off ``feat``.

        Args:
            feat: Channel-first rank-3 feature tensor ``(C, H, W)``.

        Returns:
            A list of tensors, each scaled so that its maximum is 1 and zero
            outside the mask.
        """
        # The unit-norm features never change, so normalise once up front
        # instead of on every similarity computation.
        unit_feat = F.normalize(feat, dim=0)
        # True where a pixel has not been claimed by any mask yet.
        available = torch.ones(feat.shape[-2:], dtype=torch.bool,
                               device=feat.device)
        history = []
        masks = []

        # Continue until at least 95% of the pixels are claimed.
        while (~available).sum() / available.numel() < 0.95:
            _, _, v = torch.pca_lowrank(feat[:, available].permute(1, 0),
                                        q=3, niter=100)

            # Project every pixel onto the first principal direction.
            sim = torch.einsum("c,cij->ij", v[:, 0], unit_feat)
            sim[~available] = 0

            # Replace the direction by the best-matching unclaimed pixel's own
            # feature and recompute the similarity against that seed.
            seed = unit_feat[:, sim == sim.max()][:, 0]
            sim = torch.einsum("c,cij->ij", seed, unit_feat)
            sim[~available] = 0

            # Drop weak responses, then rescale so the peak is 1.
            sim[sim < self.threshold * sim.max()] = 0
            sim = sim / sim.max()
            masks.append(sim.clone())

            available[sim > 0] = False
            history.insert(0, available.clone())
            if len(history) > 3:
                history.pop()
                # Only test for stagnation once a full window of three
                # snapshots exists; a shorter history would compare trivially
                # equal and stop the loop after the first mask.
                if all(torch.equal(history[0], past) for past in history):
                    break
        return masks

    @staticmethod
    def _with_background(pseudo_masks):
        """Prepend a background channel to a ``(K, H, W)`` or ``(H, W)`` stack.

        The background value of a pixel is the mean of all non-zero mask
        entries minus the sum of the masks at that pixel, clamped to [0, 1].

        Args:
            pseudo_masks: Mask tensor; a 2-D input is treated as one mask.

        Returns:
            Tensor of shape ``(K + 1, H, W)`` with the background at index 0.
        """
        # Restore the mask axis *before* summing over it. A single mask that
        # was squeezed to (H, W) would otherwise be summed over image rows and
        # broadcast into a wrong background.
        if pseudo_masks.dim() == 2:
            pseudo_masks = pseudo_masks.unsqueeze(0)
        level = torch.mean(pseudo_masks[pseudo_masks != 0])
        canvas = torch.ones(pseudo_masks.shape[-2:], device=pseudo_masks.device)
        bg = (level * canvas - pseudo_masks.sum(dim=0)).unsqueeze(0).clamp(0, 1)
        return torch.cat([bg, pseudo_masks], dim=0)

    def _get_pseudo(self, img, feat, cls_prior):
        """Compute the pseudo-label map for a single sample.

        Args:
            img: Image tensor; it is squeezed and handed to ``dense_crf`` and
                its last two dimensions give the size of the label map.
            feat: Channel-first rank-3 feature tensor ``(C, H, W)``.
            cls_prior: Per-pixel class prior, indexed with masks of the label
                map's size. If it is zero everywhere, mask indices are
                returned instead of class ids.

        Returns:
            Floating-point tensor of labels on ``feat.device``. Pixels outside
            every mask carry ``ignore_id``.
        """
        masks = self._extract_masks(feat)

        # Median-filter the stack, then binarise softly: anything >= 0.1
        # saturates at 1.
        stacked = torch.stack(masks, dim=0).unsqueeze(0)
        pseudo_masks = (self.medianfilter(stacked).squeeze() * 10).clamp(0, 1)
        pseudo_masks = self._with_background(pseudo_masks)
        pseudo_masks = F.log_softmax(pseudo_masks, dim=0)

        # NOTE(review): dense_crf is assumed to return an array-like of
        # per-channel scores at image resolution -- confirm against modules.crf.
        pseudo_masks = dense_crf(img.squeeze(), pseudo_masks).argmax(0)
        pseudo_masks = torch.Tensor(pseudo_masks).to(feat.device)

        if (cls_prior == 0).all():
            # No prior available: keep the mask indices, background -> ignore.
            pseudolabel = pseudo_masks
            pseudolabel[pseudolabel == 0] = self.ignore_id
        else:
            pseudolabel = torch.ones(img.shape[-2:]).to(feat.device) * self.ignore_id
            mask_ids = pseudo_masks.unique()
            for i in mask_ids[mask_ids != 0]:
                # Only touch pixels that are still unassigned, and give the
                # whole mask the most frequent prior class inside it.
                region = (pseudolabel == self.ignore_id) * (pseudo_masks == i)
                pseudolabel[region] = int(torch.mode(cls_prior[region])[0])
        return pseudolabel

    def _apply_batched_decompose(self, tup):
        """Unpack an ``(img, feat, cls_prior)`` tuple and call ``_get_pseudo``."""
        return self._get_pseudo(tup[0], tup[1], tup[2])

    def __call__(self, pool, imgs, features, cls_prior):
        """Compute pseudo labels for a batch, one sample per ``pool.map`` task.

        Args:
            pool: Object exposing ``map(func, iterable)``; presumably a
                multiprocessing/thread pool -- verify against the caller.
            imgs: Iterable of per-sample images.
            features: Iterable of per-sample feature tensors.
            cls_prior: Iterable of per-sample class priors.

        Returns:
            The per-sample label maps stacked along a new leading dimension.
        """
        outs = pool.map(self._apply_batched_decompose,
                        zip(imgs, features, cls_prior))
        return torch.stack(outs, dim=0)