# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

from torch import Tensor, nn

from mmseg.registry import MODELS
from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
                         OptSampleList, SampleList, add_prefix)
from .encoder_decoder import EncoderDecoder


@MODELS.register_module()
class CascadeEncoderDecoder(EncoderDecoder):
    """Cascade Encoder Decoder segmentors.

    CascadeEncoderDecoder is almost the same as EncoderDecoder, except that
    its decode heads are cascaded: the output of the previous decode_head is
    fed as input to the next decode_head.

    Args:
        num_stages (int): How many stages will be cascaded.
        backbone (ConfigType): The config for the backbone of segmentor.
        decode_head (ConfigType): The config for the decode head of segmentor.
        neck (OptConfigType): The config for the neck of segmentor.
            Defaults to None.
        auxiliary_head (OptConfigType): The config for the auxiliary head of
            segmentor. Defaults to None.
        train_cfg (OptConfigType): The config for training. Defaults to None.
        test_cfg (OptConfigType): The config for testing. Defaults to None.
        data_preprocessor (dict, optional): The pre-process config of
            :class:`BaseDataPreprocessor`.
        pretrained (str, optional): The path for pretrained model.
            Defaults to None.
        init_cfg (dict, optional): The weight initialized config for
            :class:`BaseModule`.
    """

    def __init__(self,
                 num_stages: int,
                 backbone: ConfigType,
                 decode_head: ConfigType,
                 neck: OptConfigType = None,
                 auxiliary_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 pretrained: Optional[str] = None,
                 init_cfg: OptMultiConfig = None):
        self.num_stages = num_stages
        super().__init__(
            backbone=backbone,
            decode_head=decode_head,
            neck=neck,
            auxiliary_head=auxiliary_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            pretrained=pretrained,
            init_cfg=init_cfg)

    def _init_decode_head(self, decode_head: ConfigType) -> None:
        """Initialize ``decode_head``"""
        assert isinstance(decode_head, list)
        assert len(decode_head) == self.num_stages
        self.decode_head = nn.ModuleList()
        for i in range(self.num_stages):
            self.decode_head.append(MODELS.build(decode_head[i]))
        self.align_corners = self.decode_head[-1].align_corners
        self.num_classes = self.decode_head[-1].num_classes
        self.out_channels = self.decode_head[-1].out_channels

    def encode_decode(self, inputs: Tensor,
                      batch_img_metas: List[dict]) -> Tensor:
        """Encode images with backbone and decode into a semantic
        segmentation map of the same size as input."""
        x = self.extract_feat(inputs)
        out = self.decode_head[0].forward(x)
        for i in range(1, self.num_stages - 1):
            out = self.decode_head[i].forward(x, out)
        seg_logits_list = self.decode_head[-1].predict(x, out,
                                                       batch_img_metas,
                                                       self.test_cfg)

        return seg_logits_list

    def _decode_head_forward_train(self, inputs: Tensor,
                                   data_samples: SampleList) -> dict:
        """Run forward function and calculate loss for decode head in
        training."""
        losses = dict()

        loss_decode = self.decode_head[0].loss(inputs, data_samples,
                                               self.train_cfg)
        losses.update(add_prefix(loss_decode, 'decode_0'))

        # get batch_img_metas
        batch_size = len(data_samples)
        batch_img_metas = []
        for batch_index in range(batch_size):
            metainfo = data_samples[batch_index].metainfo
            batch_img_metas.append(metainfo)

        for i in range(1, self.num_stages):
            # forward test again, maybe unnecessary for most methods.
            if i == 1:
                prev_outputs = self.decode_head[0].forward(inputs)
            else:
                prev_outputs = self.decode_head[i - 1].forward(
                    inputs, prev_outputs)
            loss_decode = self.decode_head[i].loss(inputs, prev_outputs,
                                                   data_samples,
                                                   self.train_cfg)
            losses.update(add_prefix(loss_decode, f'decode_{i}'))

        return losses

    def _forward(self,
                 inputs: Tensor,
                 data_samples: OptSampleList = None) -> Tensor:
        """Network forward process.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`SegDataSample`]): The seg data samples.
                It usually includes information such as `metainfo` and
                `gt_semantic_seg`.

        Returns:
            Tensor: Forward output of model without any post-processes.
        """
        x = self.extract_feat(inputs)

        out = self.decode_head[0].forward(x)
        for i in range(1, self.num_stages):
            # TODO support PointRend tensor mode
            out = self.decode_head[i].forward(x, out)

        return out
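

# --------------------------------------------------------------------------
# Usage sketch (illustrative comment, not part of the upstream module):
# cascade segmentors such as OCRNet and PointRend instantiate this class from
# a config dict in which ``decode_head`` is a list whose length must equal
# ``num_stages``. The backbone/head entries below are abbreviated
# placeholders, not a complete, buildable config.
#
#     model = dict(
#         type='CascadeEncoderDecoder',
#         num_stages=2,
#         backbone=dict(type='ResNetV1c', depth=50),  # truncated placeholder
#         decode_head=[
#             dict(type='FCNHead'),   # stage 0: coarse prediction
#             dict(type='OCRHead'),   # stage 1: refines the stage-0 output
#         ],
#         train_cfg=dict(),
#         test_cfg=dict(mode='whole'))
#
# Building the model via ``MODELS.build(model)`` invokes ``_init_decode_head``
# above, which wraps the per-stage heads in an ``nn.ModuleList`` and exposes
# ``align_corners``/``num_classes``/``out_channels`` of the final stage.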