Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
from mmseg.registry import MODELS | |
from .fcn_head import FCNHead | |
try: | |
from mmcv.ops import CrissCrossAttention | |
except ModuleNotFoundError: | |
CrissCrossAttention = None | |
class CCHead(FCNHead):
    """Decode head that applies criss-cross attention repeatedly (CCNet).

    Implementation of the head from `CCNet
    <https://arxiv.org/abs/1811.11721>`_. The head is an ``FCNHead`` built
    with two convolutions; a single ``CrissCrossAttention`` module is applied
    ``recurrence`` times between the first and the second convolution.

    Args:
        recurrence (int): Number of times the criss-cross attention module
            is applied in sequence. Default: 2.
        **kwargs: Passed through unchanged to ``FCNHead.__init__`` (which is
            always called with ``num_convs=2``).

    Raises:
        RuntimeError: If ``CrissCrossAttention`` could not be imported from
            ``mmcv.ops`` (the module-level fallback left it as ``None``).
    """

    def __init__(self, recurrence=2, **kwargs):
        # Fail early with an actionable message when the op is unavailable,
        # rather than crashing later with "'NoneType' is not callable".
        if CrissCrossAttention is None:
            missing_op_msg = ('Please install mmcv-full for '
                              'CrissCrossAttention ops')
            raise RuntimeError(missing_op_msg)

        # The forward pass indexes convs[0] and convs[1], so the parent is
        # always built with exactly two convolutions.
        super().__init__(num_convs=2, **kwargs)

        self.recurrence = recurrence
        # One attention module, reused for every recurrence step.
        # NOTE(review): self.channels is presumably set by the parent
        # constructor -- confirm against FCNHead.
        self.cca = CrissCrossAttention(self.channels)

    def forward(self, inputs):
        """Run the head on ``inputs`` and return the ``cls_seg`` result.

        Pipeline: transform inputs -> first conv -> criss-cross attention
        applied ``self.recurrence`` times -> second conv -> optional fusion
        with the transformed input (when ``self.concat_input`` is truthy)
        -> ``cls_seg``.
        """
        x = self._transform_inputs(inputs)
        first_conv, second_conv = self.convs[0], self.convs[1]

        feats = first_conv(x)
        # Each pass feeds the previous attention output back into the same
        # attention module.
        for _ in range(self.recurrence):
            feats = self.cca(feats)
        feats = second_conv(feats)

        if self.concat_input:
            # Fuse the transformed input with the attended features by
            # concatenating along dim 1 before the fusion conv.
            feats = self.conv_cat(torch.cat([x, feats], dim=1))

        return self.cls_seg(feats)