File size: 2,083 Bytes
412c852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence

from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmseg.registry import HOOKS
import torch
import json
import os

def group_subnets_by_flops(data, flops_step=10):
    """Bucket subnet config ids by their FLOPs.

    Entries are visited in ascending order of their raw FLOPs value. Each
    value is converted to GFLOPs (divided by ``1e9``) and compared with the
    current group's anchor. If it differs from the anchor by more than
    ``flops_step``, the current group is closed and a new group is started
    with this entry as its anchor. Otherwise the entry joins the current
    group. The initial anchor is 0, so entries within ``flops_step`` GFLOPs
    of zero land in the first group.

    Args:
        data (dict): Mapping of config id (anything ``int()`` accepts,
            e.g. JSON string keys) to a raw FLOPs count.
        flops_step (float): Grouping tolerance in GFLOPs. Defaults to 10.

    Returns:
        list[list[int]]: Groups of integer config ids, ordered by FLOPs.
    """
    groups = []
    current = []
    anchor = 0
    for key, raw_flops in sorted(data.items(), key=lambda kv: kv[1]):
        gflops = raw_flops / 1e9
        starts_new_group = abs(anchor - gflops) > flops_step
        if not starts_new_group:
            current.append(int(key))
            continue
        # Close the running group (if any) and open a new one anchored here.
        if current:
            groups.append(current)
        current = [int(key)]
        anchor = gflops

    if current:
        groups.append(current)
    return groups



def initialize_model_stitching_layer(model, dataiter, total_samples=50):
    """Initialise the backbone's stitching layers from real input samples.

    Pulls batches from ``dataiter`` and runs each one through
    ``model.data_preprocessor(batch, True)``. Individual preprocessed images
    are collected until at least ``total_samples`` have been gathered. They
    are then stacked into a single tensor, moved to the current CUDA device
    and passed to ``model.backbone.initialize_stitching_weights``.

    Args:
        model: The unwrapped model. It must expose ``data_preprocessor``
            and ``backbone.initialize_stitching_weights``.
        dataiter: Iterator yielding raw batches accepted by
            ``model.data_preprocessor``. If it is exhausted before enough
            images are collected, ``StopIteration`` propagates.
        total_samples (int): Minimum number of images to collect. Whole
            batches are consumed, so the final count may exceed this by up
            to one batch. Defaults to 50.

    Raises:
        ValueError: If ``total_samples`` is smaller than 1. ``torch.stack``
            cannot stack an empty list.
    """
    if total_samples < 1:
        raise ValueError(
            f'total_samples must be a positive integer, got {total_samples}')

    images = []
    while len(images) < total_samples:
        # NOTE(review): the positional ``True`` is presumably the
        # preprocessor's ``training`` flag -- confirm against its signature.
        batch = model.data_preprocessor(next(dataiter), True)
        # ``extend`` iterates ``inputs`` along its first dimension, so
        # ``images`` counts individual samples rather than batches
        # (assumes ``inputs`` is a batched tensor or a list of images).
        images.extend(batch['inputs'])

    samples = torch.stack(images, dim=0).cuda()
    model.backbone.initialize_stitching_weights(samples)

@HOOKS.register_module()
class SNNetHook(Hook):
    """Prepare a stitchable (SN-Net style) backbone before training starts.

    On a fresh run (not resumed), the backbone's stitching layers are
    initialised from samples drawn from the training dataloader iterator.
    On every run, per-subnet FLOPs are loaded from
    ``<flops_dir>/snnet_flops_<cfg_name>.json``. The subnet config ids are
    grouped with ``group_subnets_by_flops`` and the result is stored on
    ``model.backbone.flops_grouped_cfgs``.

    Args:
        flops_dir (str): Directory holding the FLOPs JSON files.
            Defaults to ``'./model_flops'``.
        flops_step (float): Grouping tolerance in GFLOPs, forwarded to
            ``group_subnets_by_flops``. Defaults to 10.
    """

    def __init__(self, flops_dir='./model_flops', flops_step=10):
        super().__init__()
        self.flops_dir = flops_dir
        self.flops_step = flops_step

    def before_train(self, runner) -> None:
        """Initialise stitching layers and attach FLOPs-grouped subnets.

        Args:
            runner: The MMEngine runner driving training.

        Raises:
            FileNotFoundError: If the FLOPs JSON file for the current
                config does not exist.
        """
        if is_model_wrapper(runner.model):
            model = runner.model.module
        else:
            model = runner.model

        # Only initialise from data on a fresh run; on resume the weights
        # are presumably restored from the checkpoint instead.
        # NOTE(review): reads the private ``runner._resume`` and assumes an
        # iteration-based train loop exposing ``dataloader_iterator``.
        if not runner._resume:
            initialize_model_stitching_layer(
                model, runner.train_loop.dataloader_iterator)

        # FLOPs files are named after the config file's stem. Keep the
        # ``split('.')[0]`` convention so names match existing files;
        # ``basename`` handles platform path separators.
        cfg_name = os.path.basename(runner.cfg.filename).split('.')[0]
        flops_path = os.path.join(self.flops_dir,
                                  f'snnet_flops_{cfg_name}.json')
        try:
            with open(flops_path, 'r', encoding='utf-8') as f:
                flops_params = json.load(f)
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f'SNNetHook: FLOPs file not found at {flops_path!r}; '
                'generate it for this config before training.') from err

        model.backbone.flops_grouped_cfgs = group_subnets_by_flops(
            flops_params, self.flops_step)