File size: 6,220 Bytes
412c852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import shutil
from collections import OrderedDict
from typing import Dict, Optional, Sequence

try:

    import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval  # noqa
    import cityscapesscripts.helpers.labels as CSLabels
except ImportError:
    CSLabels = None
    CSEval = None

import numpy as np
from mmengine.dist import is_main_process, master_only
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger, print_log
from mmengine.utils import mkdir_or_exist
from PIL import Image

from mmseg.registry import METRICS


@METRICS.register_module()
class CityscapesMetric(BaseMetric):
    """Cityscapes evaluation metric.

    Each prediction is converted from train ids to Cityscapes label ids,
    written to ``output_dir`` as a PNG file and, unless ``format_only`` is
    set, scored with the pixel-level evaluation of ``cityscapesscripts``.

    Args:
        output_dir (str): The directory for output prediction
        ignore_index (int): Index that will be ignored in evaluation.
            Default: 255.
        format_only (bool): Only format result for results commit without
            perform evaluation. It is useful when you want to format the result
            to a specific format and submit it to the test server.
            Defaults to False.
        keep_results (bool): Whether to keep the results. When ``format_only``
            is True, ``keep_results`` must be True. Defaults to False.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.

    Note:
        When ``keep_results`` is False, ``output_dir`` is removed once the
        metric object is finalized (see ``__del__``).
    """

    def __init__(self,
                 output_dir: str,
                 ignore_index: int = 255,
                 format_only: bool = False,
                 keep_results: bool = False,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None,
                 **kwargs) -> None:
        """Store the configuration and create ``output_dir``.

        Extra keyword arguments are accepted and ignored.

        Raises:
            ImportError: If ``cityscapesscripts`` could not be imported.
            AssertionError: If ``format_only`` is True while ``keep_results``
                is False.
        """
        super().__init__(collect_device=collect_device, prefix=prefix)
        if CSEval is None:
            raise ImportError('Please run "pip install cityscapesscripts" to '
                              'install cityscapesscripts first.')
        self.output_dir = output_dir
        self.ignore_index = ignore_index

        self.format_only = format_only
        if format_only:
            assert keep_results, (
                'When format_only is True, the results must be keep, please '
                f'set keep_results as True, but got {keep_results}')
        self.keep_results = keep_results
        # Only override the prefix resolved by the base class when one was
        # actually given, so that a ``default_prefix`` is not replaced by
        # ``None`` (see the ``prefix`` argument documentation above).
        if prefix is not None:
            self.prefix = prefix
        # Only one process needs to create the directory.
        if is_main_process():
            mkdir_or_exist(self.output_dir)

    @master_only
    def __del__(self) -> None:
        """Remove ``output_dir`` unless the results must be kept.

        The finalizer also runs for objects whose ``__init__`` raised before
        all attributes were assigned, so the attributes are read defensively
        and a directory that is already gone is not treated as an error.
        """
        if getattr(self, 'keep_results', True):
            return
        output_dir = getattr(self, 'output_dir', None)
        if output_dir is not None:
            shutil.rmtree(output_dir, ignore_errors=True)

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data and data_samples.

        Every prediction is saved as ``<image basename>.png`` under
        ``self.output_dir`` and a ``(prediction path, ground truth path)``
        tuple is appended to ``self.results``, which will be used to compute
        the metrics when all batches have been processed.

        Args:
            data_batch (dict): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        mkdir_or_exist(self.output_dir)

        for data_sample in data_samples:
            pred_label = data_sample['pred_sem_seg']['data'][0].cpu().numpy()
            # when evaluating with official cityscapesscripts,
            # labelIds should be used
            pred_label = self._convert_to_label_id(pred_label)
            basename = osp.splitext(osp.basename(data_sample['img_path']))[0]
            png_filename = osp.abspath(
                osp.join(self.output_dir, f'{basename}.png'))
            output = Image.fromarray(pred_label.astype(np.uint8)).convert('P')
            output.save(png_filename)
            if self.format_only:
                # format_only always for test dataset without ground truth
                gt_filename = ''
            else:
                # when evaluating with official cityscapesscripts,
                # **_gtFine_labelIds.png is used
                gt_filename = data_sample['seg_map_path'].replace(
                    'labelTrainIds.png', 'labelIds.png')
            self.results.append((png_filename, gt_filename))

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): Testing results of the dataset, i.e. the
                ``(prediction path, ground truth path)`` tuples collected by
                :meth:`process`.

        Returns:
            dict[str: float]: Cityscapes evaluation results with the keys
            ``averageScoreCategories`` and ``averageScoreInstCategories``.
            An empty ``OrderedDict`` is returned when ``format_only`` is True.

        Raises:
            ValueError: If ``results`` is empty and evaluation is requested.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        if self.format_only:
            # The PNG files are written into ``output_dir`` itself, so that is
            # the location to report (not its parent directory).
            logger.info(f'results are saved to {self.output_dir}')
            return OrderedDict()

        if not results:
            # Fail with an explicit message instead of the opaque unpacking
            # error that ``zip(*results)`` would produce for an empty list.
            raise ValueError(
                'CityscapesMetric got no results to evaluate; make sure '
                '`process` has been called on at least one sample.')

        msg = 'Evaluating in Cityscapes style'
        if logger is None:
            msg = '\n' + msg
        print_log(msg, logger=logger)

        print_log(
            f'Evaluating results under {self.output_dir} ...', logger=logger)

        # NOTE: these are module-level options of cityscapesscripts, so they
        # stay set after this call.
        CSEval.args.evalInstLevelScore = True
        CSEval.args.predictionPath = osp.abspath(self.output_dir)
        CSEval.args.evalPixelAccuracy = True
        CSEval.args.JSONOutput = False

        pred_list, gt_list = zip(*results)
        eval_results = CSEval.evaluateImgLists(pred_list, gt_list,
                                               CSEval.args)
        # Only the two average scores are reported as metrics.
        metric = dict()
        metric['averageScoreCategories'] = eval_results[
            'averageScoreCategories']
        metric['averageScoreInstCategories'] = eval_results[
            'averageScoreInstCategories']
        return metric

    @staticmethod
    def _convert_to_label_id(result):
        """Convert trainId to id for cityscapes.

        Args:
            result (np.ndarray | str): Array of train ids, or the path of a
                file that ``np.load`` can read into such an array.

        Returns:
            np.ndarray: A copy of ``result`` in which every value that is a
            key of ``CSLabels.trainId2label`` is replaced by the ``id`` of the
            mapped label. Other values are left unchanged.
        """
        if isinstance(result, str):
            result = np.load(result)
        result_copy = result.copy()
        # Masks are computed on the untouched input so that an id written for
        # one label is never re-mapped by a later iteration.
        for trainId, label in CSLabels.trainId2label.items():
            result_copy[result == trainId] = label.id

        return result_copy