import os

# set before the torch/numpy imports so the OpenMP workaround takes effect
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

from argparse import ArgumentParser, Namespace
from typing import Tuple

import yaml
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, normalize, resize
import gradio as gr

from utils import get_model
from bilateral_solver import bilateral_solver_output

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict: dict = torch.hub.load_state_dict_from_url(
    "https://github.com/NoelShin/selfmask/releases/download/v1.0.0/selfmask_nq20.pt",
    map_location=device
)["model"]
parser = ArgumentParser("SelfMask demo")
parser.add_argument(
    "--config",
    type=str,
    default="duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml"
)
# parser.add_argument(
#     "--p_state_dict",
#     type=str,
#     default="/users/gyungin/selfmask_bak/ckpt/nq20_ndl6_bc_sr10100_duts_pm_all_k2,3,4_md_seed0_final/eval/hku_is/best_model.pt",
# )
#
# parser.add_argument(
#     "--dataset_name", '-dn', type=str, default="duts",
#     choices=["dut_omron", "duts", "ecssd"]
# )

# independent variables
# parser.add_argument("--use_gpu", type=bool, default=True)
# parser.add_argument('--seed', default=0, type=int)
# parser.add_argument("--dir_root", type=str, default="..")
# parser.add_argument("--gpu_id", type=int, default=2)
# parser.add_argument("--suffix", type=str, default='')
args: Namespace = parser.parse_args()
with open(args.config, "r") as f:
    base_args = yaml.safe_load(f)
base_args.pop("dataset_name")
args: dict = vars(args)
args.update(base_args)
args: Namespace = Namespace(**args)

model = get_model(arch="maskformer", configs=args).to(device)
model.load_state_dict(state_dict)
model.eval()
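# load_state_dict runs in strict mode by default, so any mismatch between the
# config-built architecture and the downloaded checkpoint raises immediately here.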
@torch.no_grad()  # inference only: skip autograd bookkeeping and save memory
def main(
    image: Image.Image,
    size: int = 384,
    max_size: int = 512,
    mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
    std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
) -> np.ndarray:
    pil_image: Image.Image = resize(image, size=size, max_size=max_size)
    image: torch.Tensor = normalize(to_tensor(pil_image), mean=list(mean), std=list(std))  # 3 x H x W
    dict_outputs = model(image[None].to(device))

    batch_pred_masks: torch.Tensor = dict_outputs["mask_pred"]  # [0, 1]
    batch_objectness: torch.Tensor = dict_outputs.get("objectness", None)  # [0, 1]
    if len(batch_pred_masks.shape) == 5:
        # b x n_layers x n_queries x h x w -> b x n_queries x h x w
        batch_pred_masks = batch_pred_masks[:, -1, ...]  # extract the output from the last decoder layer

        if batch_objectness is not None:
            # b x n_layers x n_queries x 1 -> b x n_queries x 1
            batch_objectness = batch_objectness[:, -1, ...]
    # resize the prediction back to the (resized) input resolution
    # note: upsampling by 4 and then cutting the padded region gives a better result
    H, W = image.shape[-2:]
    batch_pred_masks = F.interpolate(
        batch_pred_masks, scale_factor=4, mode="bilinear", align_corners=False
    )[..., :H, :W]
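    # (assumption: the mask head predicts at 1/4 of the input resolution,
    # hence the fixed scale_factor of 4 before cropping back to H x W)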
    # iterate over the batch dimension (batch size is 1 in this demo)
    for batch_index, pred_masks in enumerate(batch_pred_masks):
        # rank the queries by objectness and keep the most salient mask
        # n_queries x 1 -> n_queries
        objectness: torch.Tensor = batch_objectness[batch_index].squeeze(dim=-1)
        ranks = torch.argsort(objectness, descending=True)  # n_queries

        pred_mask: torch.Tensor = pred_masks[ranks[0]]  # H x W
        pred_mask: np.ndarray = (pred_mask > 0.5).cpu().numpy().astype(np.uint8) * 255

        # refine the binary mask with a bilateral solver, guided by the input image
        pred_mask_bi, _ = bilateral_solver_output(img=pil_image, target=pred_mask)  # float64
        pred_mask_bi: np.ndarray = np.clip(pred_mask_bi, 0, 255).astype(np.uint8)

        # overlay the refined mask on the input image as a colour heatmap
        attn_map = cv2.cvtColor(cv2.applyColorMap(pred_mask_bi, cv2.COLORMAP_VIRIDIS), cv2.COLOR_BGR2RGB)
        super_imposed_img = cv2.addWeighted(attn_map, 0.5, np.array(pil_image), 0.5, 0)
        return super_imposed_img
        # return pred_mask_bi
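# Optional sanity check (not part of the original demo): run the pipeline once
# before wiring up the UI. Assumes "resources/0053.jpg" exists, as it is
# referenced in the Gradio examples below.
# overlay: np.ndarray = main(Image.open("resources/0053.jpg").convert("RGB"))
# Image.fromarray(overlay).save("overlay.jpg")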
demo = gr.Interface(
    fn=main,
    inputs=gr.Image(type="pil"),  # gr.inputs.Image was removed in newer Gradio releases
    outputs="image",
    examples=[f"resources/{fname}.jpg" for fname in [
        "0053",
        "0236",
        "0239",
        "0403",
        "0412",
        "ILSVRC2012_test_00005309",
        "ILSVRC2012_test_00012622",
        "ILSVRC2012_test_00022698",
        "ILSVRC2012_test_00040725",
        "ILSVRC2012_test_00075738",
        "ILSVRC2012_test_00080683",
        "ILSVRC2012_test_00085874",
        "im052",
        "sun_ainjbonxmervsvpv",
        "sun_alfntqzssslakmss",
        "sun_amnrcxhisjfrliwa",
        "sun_bvyxpvkouzlfwwod"
    ]],
    title="Unsupervised Salient Object Detection with Spectral Cluster Voting",
    allow_flagging="never",
    analytics_enabled=False
)
demo.launch(
    # share=True  # enable to expose a public link when running locally
)