# Copyright (c) Facebook, Inc. and its affiliates.
from typing import Tuple

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange

from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
from detectron2.modeling.backbone import Backbone
from detectron2.modeling.postprocessing import sem_seg_postprocess
from detectron2.structures import ImageList
from detectron2.utils.memory import _ignore_torch_cuda_oom

from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry


@META_ARCH_REGISTRY.register()
class CATSeg(nn.Module):
    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        sem_seg_head: nn.Module,
        size_divisibility: int,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
        clip_pixel_mean: Tuple[float],
        clip_pixel_std: Tuple[float],
        train_class_json: str,
        test_class_json: str,
        sliding_window: bool,
        clip_finetune: str,
        backbone_multiplier: float,
        clip_pretrained: str,
    ):
| """ | |
| Args: | |
| backbone: a backbone module, must follow detectron2's backbone interface | |
| sem_seg_head: a module that predicts semantic segmentation from backbone features | |
| """ | |
        super().__init__()
        self.backbone = backbone
        self.sem_seg_head = sem_seg_head
        if size_divisibility < 0:
            size_divisibility = self.backbone.size_divisibility
        self.size_divisibility = size_divisibility

        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
        self.register_buffer("clip_pixel_mean", torch.Tensor(clip_pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("clip_pixel_std", torch.Tensor(clip_pixel_std).view(-1, 1, 1), False)

        self.train_class_json = train_class_json
        self.test_class_json = test_class_json
        self.clip_finetune = clip_finetune
        # Freeze CLIP except for the parts selected by `clip_finetune`.
        for name, params in self.sem_seg_head.predictor.clip_model.named_parameters():
            if "visual" in name:
                if clip_finetune == "prompt":
                    params.requires_grad = "prompt" in name
                elif clip_finetune == "attention":
                    params.requires_grad = "attn" in name or "position" in name
                elif clip_finetune == "full":
                    params.requires_grad = True
                else:
                    params.requires_grad = False
            else:
                params.requires_grad = False
        finetune_backbone = backbone_multiplier > 0.
        for name, params in self.backbone.named_parameters():
            if "norm0" in name:
                params.requires_grad = False
            else:
                params.requires_grad = finetune_backbone

        self.sliding_window = sliding_window
        self.clip_resolution = (384, 384) if clip_pretrained == "ViT-B/16" else (336, 336)
        self.sequential = False
        self.use_sam = False
        # SAM is loaded unconditionally; the ViT-H checkpoint is expected in the working directory.
        self.sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").to(self.device)
        amg_kwargs = {
            "points_per_side": 32,
            "points_per_batch": None,
            # "pred_iou_thresh": 0.0,
            # "stability_score_thresh": 0.0,
            "stability_score_offset": None,
            "box_nms_thresh": None,
            "crop_n_layers": None,
            "crop_nms_thresh": None,
            "crop_overlap_ratio": None,
            "crop_n_points_downscale_factor": None,
            "min_mask_region_area": None,
        }
        amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None}
        self.mask = SamAutomaticMaskGenerator(self.sam, output_mode="binary_mask", **amg_kwargs)
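        # After the None filtering above, only points_per_side (and output_mode) are passed
        # explicitly; every other SamAutomaticMaskGenerator argument keeps its library default.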

        self.overlap_threshold = 0.8
        self.panoptic_on = False

    @classmethod
    def from_config(cls, cfg):
        backbone = build_backbone(cfg)
        sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
        return {
            "backbone": backbone,
            "sem_seg_head": sem_seg_head,
            "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
            "clip_pixel_mean": cfg.MODEL.CLIP_PIXEL_MEAN,
            "clip_pixel_std": cfg.MODEL.CLIP_PIXEL_STD,
            "train_class_json": cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON,
            "test_class_json": cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON,
            "sliding_window": cfg.TEST.SLIDING_WINDOW,
            "clip_finetune": cfg.MODEL.SEM_SEG_HEAD.CLIP_FINETUNE,
            "backbone_multiplier": cfg.SOLVER.BACKBONE_MULTIPLIER,
            "clip_pretrained": cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED,
        }

    @property
    def device(self):
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:
                   * "image": Tensor, image in (C, H, W) format.
                   * "instances": per-region ground truth
                   * Other information that's included in the original dicts, such as
                     "height", "width" (int): the output resolution of the model (may be different
                     from the input resolution), used in inference.
        Returns:
            list[dict]:
                Each dict has the results for one image. The dict contains the following keys:
                * "sem_seg": a Tensor with the per-pixel segmentation predicted by the head,
                  of shape (K, H, W), holding the logits of each class for each pixel.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        sam_images = images

        if not self.training and self.sliding_window:
            if not self.sequential:
                with _ignore_torch_cuda_oom():
                    return self.inference_sliding_window(batched_inputs)
                self.sequential = True
            return self.inference_sliding_window(batched_inputs)

        clip_images = [(x - self.clip_pixel_mean) / self.clip_pixel_std for x in images]
        clip_images = ImageList.from_tensors(clip_images, self.size_divisibility)

        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)

        clip_images = F.interpolate(clip_images.tensor, size=self.clip_resolution, mode="bilinear", align_corners=False)
        clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True)

        images_resized = F.interpolate(images.tensor, size=(384, 384), mode="bilinear", align_corners=False)
        features = self.backbone(images_resized)
        outputs = self.sem_seg_head(clip_features, features)

        if self.training:
            targets = torch.stack([x["sem_seg"].to(self.device) for x in batched_inputs], dim=0)
            outputs = F.interpolate(outputs, size=(targets.shape[-2], targets.shape[-1]), mode="bilinear", align_corners=False)

            num_classes = outputs.shape[1]
            mask = targets != self.sem_seg_head.ignore_value

            outputs = outputs.permute(0, 2, 3, 1)
            _targets = torch.zeros(outputs.shape, device=self.device)
            _onehot = F.one_hot(targets[mask], num_classes=num_classes).float()
            _targets[mask] = _onehot

            # Per-pixel BCE against one-hot targets; pixels equal to ignore_value keep an all-zero target vector.
            loss = F.binary_cross_entropy_with_logits(outputs, _targets)
            losses = {"loss_sem_seg": loss}
            return losses
        else:
            # outputs = outputs.sigmoid()
            image_size = images.image_sizes[0]
            if self.use_sam:
                masks = self.mask.generate(np.uint8(sam_images[0].permute(1, 2, 0).cpu().numpy()))
                outputs, sam_cls = self.discrete_semantic_inference(outputs, masks, image_size)
                # outputs, sam_cls = self.continuous_semantic_inference(outputs, masks, image_size)
                # outputs, sam_cls = self.continuous_semantic_inference2(outputs, masks, image_size, img=img, text=text)
            height = batched_inputs[0].get("height", image_size[0])
            width = batched_inputs[0].get("width", image_size[1])

            output = sem_seg_postprocess(outputs[0], image_size, height, width)
            processed_results = [{"sem_seg": output}]
            return processed_results

    def inference_sliding_window(self, batched_inputs, kernel=384, overlap=0.333, out_res=[640, 640]):
        images = [x["image"].to(self.device, dtype=torch.float32) for x in batched_inputs]
        stride = int(kernel * (1 - overlap))
        unfold = nn.Unfold(kernel_size=kernel, stride=stride)
        fold = nn.Fold(out_res, kernel_size=kernel, stride=stride)

        image = F.interpolate(images[0].unsqueeze(0), size=out_res, mode="bilinear", align_corners=False).squeeze()
        sam_images = [image]
        image = rearrange(unfold(image), "(C H W) L -> L C H W", C=3, H=kernel)
        global_image = F.interpolate(images[0].unsqueeze(0), size=(kernel, kernel), mode="bilinear", align_corners=False)
        image = torch.cat((image, global_image), dim=0)

        images = (image - self.pixel_mean) / self.pixel_std
        clip_images = (image - self.clip_pixel_mean) / self.clip_pixel_std
        clip_images = F.interpolate(clip_images, size=self.clip_resolution, mode="bilinear", align_corners=False)
        clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True)

        if self.sequential:
            outputs = []
            for clip_feat, image in zip(clip_features, images):
                feature = self.backbone(image.unsqueeze(0))
                output = self.sem_seg_head(clip_feat.unsqueeze(0), feature)
                outputs.append(output[0])
            outputs = torch.stack(outputs, dim=0)
        else:
            features = self.backbone(images)
            outputs = self.sem_seg_head(clip_features, features)

        outputs = F.interpolate(outputs, size=kernel, mode="bilinear", align_corners=False)
        outputs = outputs.sigmoid()

        global_output = outputs[-1:]
        global_output = F.interpolate(global_output, size=out_res, mode="bilinear", align_corners=False)
        outputs = outputs[:-1]
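        # Overlapping window predictions are averaged: folding an all-ones tensor through the same
        # unfold/fold pair counts how many windows cover each pixel, and dividing the summed fold
        # output by that count turns it into a per-pixel mean. A standalone sketch of this trick is
        # in the __main__ block at the bottom of the file.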
        outputs = fold(outputs.flatten(1).T) / fold(unfold(torch.ones([1] + out_res, device=self.device)))
        outputs = (outputs + global_output) / 2.

        height = batched_inputs[0].get("height", out_res[0])
        width = batched_inputs[0].get("width", out_res[1])

        catseg_outputs = sem_seg_postprocess(outputs[0], out_res, height, width)
        # catseg_outputs = catseg_outputs.argmax(dim=1)[0].cpu()
        masks = self.mask.generate(np.uint8(sam_images[0].permute(1, 2, 0).cpu().numpy()))
        if self.use_sam:
            outputs, sam_cls = self.discrete_semantic_inference(outputs, masks, out_res)
            # outputs, sam_cls = self.continuous_semantic_inference(outputs, masks, out_res)
        output = sem_seg_postprocess(outputs[0], out_res, height, width)
        ret = [{"sem_seg": output}]
        if self.panoptic_on:
            # Note: this branch uses sam_cls, which is only defined when self.use_sam is True.
            panoptic_r = self.panoptic_inference(catseg_outputs, masks, sam_cls, size=output.shape[-2:])
            ret[0]["panoptic_seg"] = panoptic_r
        return ret

    def discrete_semantic_inference(self, outputs, masks, image_size):
        catseg_outputs = F.interpolate(outputs, size=image_size, mode="bilinear", align_corners=True)
        sam_outputs = torch.zeros_like(catseg_outputs).cpu()
        catseg_outputs = catseg_outputs.argmax(dim=1)[0].cpu()
        sam_classes = torch.zeros(len(masks))
        for i in range(len(masks)):
            m = masks[i]["segmentation"]
            s = masks[i]["stability_score"]
            idx = catseg_outputs[m].bincount().argmax()
            sam_outputs[0, idx][m] = s
            sam_classes[i] = idx
        return sam_outputs, sam_classes
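    # discrete_semantic_inference assigns each SAM mask a single label by majority vote: it takes the
    # argmax CAT-Seg class of every pixel inside the mask, picks the most frequent one via
    # bincount().argmax(), and writes the mask's stability score into that class channel. For example
    # (illustrative numbers), a mask whose pixels argmax to class 7 in 900 places and class 2 in 100
    # places is labelled class 7 as a whole.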

    def continuous_semantic_inference(self, outputs, masks, image_size, scale=100 / 7.):
        catseg_outputs = F.interpolate(outputs, size=image_size, mode="bilinear", align_corners=True)[0].cpu()
        sam_outputs = torch.zeros_like(catseg_outputs)
        # catseg_outputs = catseg_outputs.argmax(dim=1)[0].cpu()
        sam_classes = torch.zeros(len(masks))
        mask_pred = torch.tensor(np.asarray([x["segmentation"] for x in masks]), dtype=torch.float32)  # N H W
        mask_score = torch.tensor(np.asarray([x["predicted_iou"] for x in masks]), dtype=torch.float32)  # N

        mask_cls = torch.einsum("nhw, chw -> nc", mask_pred, catseg_outputs)
        mask_norm = mask_pred.sum(-1).sum(-1)
        mask_cls = mask_cls / mask_norm[:, None]
        mask_cls = mask_cls / mask_cls.norm(p=1, dim=1)[:, None]

        mask_logits = mask_pred * mask_score[:, None, None]
        output = torch.einsum("nhw, nc -> chw", mask_logits, mask_cls)
        return output.unsqueeze(0), mask_cls

    def continuous_semantic_inference2(self, outputs, masks, image_size, scale=100 / 7., img=None, text=None):
        assert img is not None and text is not None
        # catseg_outputs = F.interpolate(outputs, size=image_size, mode="bilinear", align_corners=True)[0].cpu()
        img = F.interpolate(img, size=image_size, mode="bilinear", align_corners=True)[0].cpu()
        img = img.permute(1, 2, 0)
        # sam_outputs = torch.zeros_like(catseg_outputs)
        # catseg_outputs = catseg_outputs.argmax(dim=1)[0].cpu()
        sam_classes = torch.zeros(len(masks))
        mask_pred = torch.tensor(np.asarray([x["segmentation"] for x in masks]), dtype=torch.float32)  # N H W
        mask_score = torch.tensor(np.asarray([x["predicted_iou"] for x in masks]), dtype=torch.float32)  # N

        mask_pool = torch.einsum("nhw, hwd -> nd", mask_pred, img)
        mask_pool = mask_pool / mask_pool.norm(dim=1, keepdim=True)
        mask_cls = torch.einsum("nd, cd -> nc", 100 * mask_pool, text.cpu())
        mask_cls = mask_cls.softmax(dim=1)
        # mask_cls = torch.einsum("nhw, chw -> nc", mask_pred, catseg_outputs)
        mask_norm = mask_pred.sum(-1).sum(-1)
        mask_cls = mask_cls / mask_norm[:, None]
        mask_cls = mask_cls / mask_cls.norm(p=1, dim=1)[:, None]

        mask_logits = mask_pred * mask_score[:, None, None]
        output = torch.einsum("nhw, nc -> chw", mask_logits, mask_cls)
        return output.unsqueeze(0), sam_classes

    def panoptic_inference(self, outputs, masks, sam_classes, size=None):
        scores = np.asarray([x["predicted_iou"] for x in masks])
        mask_pred = np.asarray([x["segmentation"] for x in masks])
        # keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)

        cur_scores = torch.tensor(scores)
        cur_masks = torch.tensor(mask_pred)
        cur_masks = F.interpolate(cur_masks.unsqueeze(0).float(), size=outputs.shape[-2:], mode="nearest")[0]
        # sam_classes holds one class index per mask when it comes from discrete_semantic_inference,
        # or per-mask class scores (N, C) from the continuous variant.
        cur_classes = sam_classes if sam_classes.dim() == 1 else sam_classes.argmax(dim=-1)
        # cur_mask_cls = mask_cls  # [keep]
        # cur_mask_cls = cur_mask_cls[:, :-1]

        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks

        h, w = cur_masks.shape[-2:]
        panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
        segments_info = []
        current_segment_id = 0

        if cur_masks.shape[0] == 0:
            # We didn't detect any mask :(
            return panoptic_seg, segments_info
        else:
            # take argmax
            cur_mask_ids = cur_prob_masks.argmax(0)
            stuff_memory_list = {}
            for k in range(cur_classes.shape[0]):
                pred_class = cur_classes[k].item()
                # isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
                isthing = pred_class in [3, 6]  # hard-coded "thing" class ids
                mask = cur_mask_ids == k
                mask_area = mask.sum().item()
                original_area = (cur_masks[k] >= 0.5).sum().item()

                if mask_area > 0 and original_area > 0:
                    if mask_area / original_area < self.overlap_threshold:
                        continue

                    # merge stuff regions
                    if not isthing:
                        if int(pred_class) in stuff_memory_list.keys():
                            panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
                            continue
                        else:
                            stuff_memory_list[int(pred_class)] = current_segment_id + 1

                    current_segment_id += 1
                    panoptic_seg[mask] = current_segment_id

                    segments_info.append(
                        {
                            "id": current_segment_id,
                            "isthing": bool(isthing),
                            "category_id": int(pred_class),
                        }
                    )

            return panoptic_seg, segments_info
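

if __name__ == "__main__":
    # Minimal, self-contained sketch (not part of the model) of the overlap-averaging trick used in
    # inference_sliding_window: fold() sums overlapping patches, and dividing by the folded all-ones
    # tensor turns that sum into a per-pixel mean. Kernel/stride/resolution here are arbitrary toy values.
    toy_kernel, toy_stride, toy_res = 4, 2, (8, 8)
    unfold = nn.Unfold(kernel_size=toy_kernel, stride=toy_stride)
    fold = nn.Fold(toy_res, kernel_size=toy_kernel, stride=toy_stride)

    x = torch.arange(64, dtype=torch.float32).reshape(1, 1, *toy_res)  # fake single-channel "logits"
    patches = unfold(x)                                  # (1, k*k, L): one column per sliding window
    summed = fold(patches)                               # overlapping pixels are summed, not averaged
    counts = fold(unfold(torch.ones(1, 1, *toy_res)))    # how many windows cover each pixel
    averaged = summed / counts                           # per-pixel mean over windows

    # Every window sees the same underlying values here, so averaging must reproduce the input exactly.
    assert torch.allclose(averaged, x)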