smi08 committed (verified)
Commit d008243 · 1 Parent(s): 381fcd6

Upload folder using huggingface_hub

utils/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .criterion import *
+ from .distributed import *
+ from .init import *
+ from .lr_scheduler import *
+ from .metric import *
+ from .misc import *
+ from .profile import *
utils/criterion.py ADDED
@@ -0,0 +1,40 @@
+ import torch
+ import torch.nn.functional as F
+
+ __all__ = ["label_smooth", "CrossEntropyWithSoftTarget", "CrossEntropyWithLabelSmooth"]
+
+
+ def label_smooth(
+     target: torch.Tensor, n_classes: int, smooth_factor=0.1
+ ) -> torch.Tensor:
+     # convert to one-hot
+     batch_size = target.shape[0]
+     target = torch.unsqueeze(target, 1)
+     soft_target = torch.zeros((batch_size, n_classes), device=target.device)
+     soft_target.scatter_(1, target, 1)
+     # label smoothing
+     soft_target = torch.add(
+         soft_target * (1 - smooth_factor), smooth_factor / n_classes
+     )
+     return soft_target
+
+
+ class CrossEntropyWithSoftTarget:
+     @staticmethod
+     def get_loss(pred: torch.Tensor, soft_target: torch.Tensor) -> torch.Tensor:
+         return torch.mean(
+             torch.sum(-soft_target * F.log_softmax(pred, dim=-1, _stacklevel=5), 1)
+         )
+
+     def __call__(self, pred: torch.Tensor, soft_target: torch.Tensor) -> torch.Tensor:
+         return self.get_loss(pred, soft_target)
+
+
+ class CrossEntropyWithLabelSmooth:
+     def __init__(self, smooth_ratio=0.1):
+         super(CrossEntropyWithLabelSmooth, self).__init__()
+         self.smooth_ratio = smooth_ratio
+
+     def __call__(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         soft_target = label_smooth(target, pred.shape[1], self.smooth_ratio)
+         return CrossEntropyWithSoftTarget.get_loss(pred, soft_target)
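A minimal usage sketch for the criteria above (the logits and integer targets are random placeholders; CrossEntropyWithLabelSmooth smooths integer class indices internally via label_smooth):

import torch
from utils.criterion import CrossEntropyWithLabelSmooth, CrossEntropyWithSoftTarget, label_smooth

criterion = CrossEntropyWithLabelSmooth(smooth_ratio=0.1)
logits = torch.randn(4, 10)                # hypothetical predictions: batch of 4, 10 classes
targets = torch.randint(0, 10, (4,))       # integer class indices
loss = criterion(logits, targets)

# Equivalent path with an explicit soft target
soft = label_smooth(targets, n_classes=10, smooth_factor=0.1)
loss2 = CrossEntropyWithSoftTarget.get_loss(logits, soft)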
utils/distributed.py ADDED
@@ -0,0 +1,55 @@
+ from typing import List, Optional, Union
+
+ import torch
+ import torch.distributed
+ from torchpack import distributed
+
+ from utils.misc import list_mean, list_sum
+
+ __all__ = ["ddp_reduce_tensor", "DistributedMetric"]
+
+
+ def ddp_reduce_tensor(
+     tensor: torch.Tensor, reduce="mean"
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
+     tensor_list = [torch.empty_like(tensor) for _ in range(distributed.size())]
+     torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False)
+     if reduce == "mean":
+         return list_mean(tensor_list)
+     elif reduce == "sum":
+         return list_sum(tensor_list)
+     elif reduce == "cat":
+         return torch.cat(tensor_list, dim=0)
+     elif reduce == "root":
+         return tensor_list[0]
+     else:
+         return tensor_list
+
+
+ class DistributedMetric(object):
+     """Average metrics for distributed training."""
+
+     def __init__(self, name: Optional[str] = None, backend="ddp"):
+         self.name = name
+         self.sum = 0
+         self.count = 0
+         self.backend = backend
+
+     def update(self, val: Union[torch.Tensor, int, float], delta_n=1):
+         val *= delta_n
+         if type(val) in [int, float]:
+             val = torch.Tensor(1).fill_(val).cuda()
+         if self.backend == "ddp":
+             self.count += ddp_reduce_tensor(
+                 torch.Tensor(1).fill_(delta_n).cuda(), reduce="sum"
+             )
+             self.sum += ddp_reduce_tensor(val.detach(), reduce="sum")
+         else:
+             raise NotImplementedError
+
+     @property
+     def avg(self):
+         if self.count == 0:
+             return torch.Tensor(1).fill_(-1)
+         else:
+             return self.sum / self.count
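A hedged sketch of how DistributedMetric might be used in a DDP loop (it assumes the torchpack distributed environment and the default process group have already been initialized and a CUDA device is available; the loop and loss values are placeholders):

import torch
from utils.distributed import DistributedMetric

val_loss = DistributedMetric(name="val_loss")
for _ in range(10):                   # placeholder for a real data loader
    loss = torch.rand(1).cuda()       # placeholder per-batch loss on this rank
    val_loss.update(loss, delta_n=4)  # delta_n = local batch size
print(val_loss.avg.item())            # sample-weighted average across all ranks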
utils/init.py ADDED
@@ -0,0 +1,77 @@
+ import math
+ from typing import Dict, List, Union
+
+ import torch
+ import torch.nn as nn
+ from torch.nn.modules.batchnorm import _BatchNorm
+
+ __all__ = ["init_modules", "load_state_dict"]
+
+
+ def init_modules(
+     module: Union[nn.Module, List[nn.Module]], init_type="he_fout"
+ ) -> None:
+     init_params = init_type.split("@")
+     if len(init_params) > 1:
+         init_params = float(init_params[1])
+     else:
+         init_params = None
+
+     if isinstance(module, list):
+         for sub_module in module:
+             init_modules(sub_module, init_type)
+     else:
+         for m in module.modules():
+             if isinstance(m, nn.Conv2d):
+                 if init_type == "he_fout":
+                     n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                     m.weight.data.normal_(0, math.sqrt(2.0 / n))
+                 elif init_type.startswith("kaiming_uniform"):
+                     nn.init.kaiming_uniform_(m.weight, a=math.sqrt(init_params or 5))
+                 else:
+                     nn.init.kaiming_uniform_(m.weight, a=math.sqrt(init_params or 5))
+                 if m.bias is not None:
+                     m.bias.data.zero_()
+             elif isinstance(m, _BatchNorm):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+             elif isinstance(m, nn.Linear):
+                 nn.init.trunc_normal_(m.weight, std=0.02)
+                 if m.bias is not None:
+                     m.bias.data.zero_()
+             else:
+                 weight = getattr(m, "weight", None)
+                 bias = getattr(m, "bias", None)
+                 if isinstance(weight, torch.nn.Parameter):
+                     nn.init.kaiming_uniform_(m.weight, a=math.sqrt(init_params or 5))
+                 if isinstance(bias, torch.nn.Parameter):
+                     bias.data.zero_()
+
+
+ def load_state_dict(
+     model: nn.Module, state_dict: Dict[str, torch.Tensor], strict=True
+ ) -> None:
+     current_state_dict = model.state_dict()
+     for key in state_dict:
+         if current_state_dict[key].shape != state_dict[key].shape:
+             if strict:
+                 raise ValueError(
+                     "%s shape mismatch (src=%s, target=%s)"
+                     % (
+                         key,
+                         list(state_dict[key].shape),
+                         list(current_state_dict[key].shape),
+                     )
+                 )
+             else:
+                 print(
+                     "Skip loading %s due to shape mismatch (src=%s, target=%s)"
+                     % (
+                         key,
+                         list(state_dict[key].shape),
+                         list(current_state_dict[key].shape),
+                     )
+                 )
+         else:
+             current_state_dict[key].copy_(state_dict[key])
+     model.load_state_dict(current_state_dict)
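A short sketch showing how init_modules and load_state_dict could be combined with load_state_dict_from_file from utils.misc (the model is a toy stand-in and the commented checkpoint path is hypothetical):

import torch.nn as nn
from utils.init import init_modules, load_state_dict
from utils.misc import load_state_dict_from_file

model = nn.Sequential(
    nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10),
)
init_modules(model, init_type="he_fout")      # He fan-out for convs, trunc-normal for linears

# state_dict = load_state_dict_from_file("checkpoint.pt")   # hypothetical checkpoint path
# load_state_dict(model, state_dict, strict=False)          # skip shape-mismatched keys
load_state_dict(model, model.state_dict(), strict=True)     # trivial round trip to show the call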
utils/lr_scheduler.py ADDED
@@ -0,0 +1,38 @@
+ import math
+ from typing import List
+
+ import torch
+ from torch.optim import Optimizer
+
+ __all__ = ["CosineLRwithWarmup"]
+
+
+ class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
+     def __init__(
+         self,
+         optimizer: Optimizer,
+         warmup_steps: int,
+         warmup_lr: float,
+         decay_steps: int,
+         last_epoch: int = -1,
+     ) -> None:
+         self.warmup_steps = warmup_steps
+         self.warmup_lr = warmup_lr
+         self.decay_steps = decay_steps
+         super().__init__(optimizer, last_epoch)
+
+     def get_lr(self) -> List[float]:
+         if self.last_epoch < self.warmup_steps:
+             return [
+                 (base_lr - self.warmup_lr) * self.last_epoch / self.warmup_steps
+                 + self.warmup_lr
+                 for base_lr in self.base_lrs
+             ]
+         else:
+             current_steps = self.last_epoch - self.warmup_steps
+             return [
+                 0.5
+                 * base_lr
+                 * (1 + math.cos(math.pi * current_steps / self.decay_steps))
+                 for base_lr in self.base_lrs
+             ]
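A minimal sketch of driving CosineLRwithWarmup once per iteration (the model, learning rates, and step counts below are placeholder values):

import torch
from utils.lr_scheduler import CosineLRwithWarmup

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = CosineLRwithWarmup(
    optimizer, warmup_steps=100, warmup_lr=1e-4, decay_steps=900
)

for step in range(1000):
    optimizer.step()        # forward/backward omitted
    scheduler.step()        # linear warmup to 0.1, then cosine decay over 900 steps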
utils/metric.py ADDED
@@ -0,0 +1,49 @@
+ from typing import List, Union
+
+ import numpy as np
+ import torch
+
+ __all__ = ["accuracy", "AverageMeter"]
+
+
+ def accuracy(
+     output: torch.Tensor, target: torch.Tensor, topk=(1,)
+ ) -> List[torch.Tensor]:
+     """Computes the precision@k for the specified values of k."""
+     maxk = max(topk)
+     batch_size = target.shape[0]
+
+     _, pred = output.topk(maxk, 1, True, True)
+     pred = pred.t()
+     correct = pred.eq(target.reshape(1, -1).expand_as(pred))
+
+     res = []
+     for k in topk:
+         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+         res.append(correct_k.mul_(100.0 / batch_size))
+     return res
+
+
+ class AverageMeter(object):
+     """Computes and stores the average and current value.
+
+     Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
+     """
+
+     def __init__(self):
+         self.val = 0
+         self.avg = 0
+         self.sum = 0
+         self.count = 0
+
+     def reset(self):
+         self.val = 0
+         self.avg = 0
+         self.sum = 0
+         self.count = 0
+
+     def update(self, val: Union[torch.Tensor, np.ndarray, float, int], n=1):
+         self.val = val
+         self.sum += val * n
+         self.count += n
+         self.avg = self.sum / self.count
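A small sketch of accuracy and AverageMeter together, as they would typically appear in an evaluation loop (the logits and labels are random placeholders):

import torch
from utils.metric import AverageMeter, accuracy

top1 = AverageMeter()
top5 = AverageMeter()
for _ in range(3):                                   # placeholder for a real data loader
    logits = torch.randn(32, 100)                    # batch of 32, 100 classes
    labels = torch.randint(0, 100, (32,))
    acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
    top1.update(acc1.item(), n=labels.size(0))
    top5.update(acc5.item(), n=labels.size(0))
print(f"top1={top1.avg:.2f}%  top5={top5.avg:.2f}%")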
utils/misc.py ADDED
@@ -0,0 +1,122 @@
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import yaml
+ from torch.nn.modules.batchnorm import _BatchNorm
+
+ __all__ = [
+     "make_divisible",
+     "load_state_dict_from_file",
+     "list_mean",
+     "list_sum",
+     "parse_unknown_args",
+     "partial_update_config",
+     "remove_bn",
+     "get_same_padding",
+     "torch_random_choices",
+ ]
+
+
+ def make_divisible(
+     v: Union[int, float], divisor: Optional[int], min_val=None
+ ) -> Union[int, float]:
+     """This function is taken from the original tf repo.
+
+     It ensures that all layers have a channel number that is divisible by 8.
+     It can be seen here:
+     https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+     :param v:
+     :param divisor:
+     :param min_val:
+     :return:
+     """
+     if divisor is None:
+         return v
+
+     if min_val is None:
+         min_val = divisor
+     new_v = max(min_val, int(v + divisor / 2) // divisor * divisor)
+     # Make sure that round down does not go down by more than 10%.
+     if new_v < 0.9 * v:
+         new_v += divisor
+     return new_v
+
+
+ def load_state_dict_from_file(file: str) -> Dict[str, torch.Tensor]:
+     checkpoint = torch.load(file, map_location="cpu")
+     if "state_dict" in checkpoint:
+         checkpoint = checkpoint["state_dict"]
+     return checkpoint
+
+
+ def list_sum(x: List) -> Any:
+     return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])
+
+
+ def list_mean(x: List) -> Any:
+     return list_sum(x) / len(x)
+
+
+ def parse_unknown_args(unknown: List) -> Dict:
+     """Parse unknown args."""
+     index = 0
+     parsed_dict = {}
+     while index < len(unknown):
+         key, val = unknown[index], unknown[index + 1]
+         index += 2
+         if key.startswith("--"):
+             key = key[2:]
+         try:
+             # try parsing with yaml
+             if "{" in val and "}" in val and ":" in val:
+                 val = val.replace(":", ": ")  # add space manually for dict
+             out_val = yaml.safe_load(val)
+         except (yaml.YAMLError, ValueError):
+             # return raw string if parsing fails
+             out_val = val
+         parsed_dict[key] = out_val
+     return parsed_dict
+
+
+ def partial_update_config(config: Dict, partial_config: Dict):
+     for key in partial_config:
+         if (
+             key in config
+             and isinstance(partial_config[key], Dict)
+             and isinstance(config[key], Dict)
+         ):
+             partial_update_config(config[key], partial_config[key])
+         else:
+             config[key] = partial_config[key]
+
+
+ def remove_bn(model: nn.Module) -> None:
+     for m in model.modules():
+         if isinstance(m, _BatchNorm):
+             m.weight = m.bias = None
+             m.forward = lambda x: x
+
+
+ def get_same_padding(kernel_size: Union[int, Tuple[int, int]]) -> Union[int, tuple]:
+     if isinstance(kernel_size, tuple):
+         assert len(kernel_size) == 2, f"invalid kernel size: {kernel_size}"
+         p1 = get_same_padding(kernel_size[0])
+         p2 = get_same_padding(kernel_size[1])
+         return p1, p2
+     else:
+         assert isinstance(
+             kernel_size, int
+         ), "kernel size should be either `int` or `tuple`"
+         assert kernel_size % 2 > 0, "kernel size should be odd number"
+         return kernel_size // 2
+
+
+ def torch_random_choices(
+     src_list: List[Any],
+     generator: Optional[torch.Generator],
+     k=1,
+ ) -> Union[Any, List[Any]]:
+     rand_idx = torch.randint(low=0, high=len(src_list), generator=generator, size=(k,))
+     out_list = [src_list[i] for i in rand_idx]
+     return out_list[0] if k == 1 else out_list
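A few examples of the channel/padding helpers above (the numbers are arbitrary, chosen only to show the rounding and padding behavior):

from utils.misc import get_same_padding, list_mean, make_divisible

make_divisible(37, 8)        # -> 40, rounded to the nearest multiple of 8
make_divisible(40, 32)       # -> 64, since rounding down to 32 would lose more than 10%
get_same_padding(3)          # -> 1, keeps spatial size for a stride-1 conv
get_same_padding((5, 3))     # -> (2, 1)
list_mean([1.0, 2.0, 6.0])   # -> 3.0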
utils/profile.py ADDED
@@ -0,0 +1,40 @@
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ from torchprofile import profile_macs
+
+ __all__ = ["is_parallel", "get_module_device", "trainable_param_num", "inference_macs"]
+
+
+ def is_parallel(model: nn.Module) -> bool:
+     return isinstance(
+         model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+     )
+
+
+ def get_module_device(module: nn.Module) -> torch.device:
+     return module.parameters().__next__().device
+
+
+ def trainable_param_num(network: nn.Module, unit=1e6) -> float:
+     return sum(p.numel() for p in network.parameters() if p.requires_grad) / unit
+
+
+ def inference_macs(
+     network: nn.Module,
+     args: Tuple = (),
+     data_shape: Optional[Tuple] = None,
+     unit: float = 1e6,
+ ) -> float:
+     if is_parallel(network):
+         network = network.module
+     if data_shape is not None:
+         if len(args) > 0:
+             raise ValueError("Please provide either data_shape or args tuple.")
+         args = (torch.zeros(data_shape, device=get_module_device(network)),)
+     is_training = network.training
+     network.eval()
+     macs = profile_macs(network, args=args) / unit
+     network.train(is_training)
+     return macs
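A short sketch of profiling a model with the helpers above (assumes the torchprofile package is installed; the model and input shape are placeholders):

import torch.nn as nn
from utils.profile import inference_macs, trainable_param_num

model = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, 10),
)
print(f"params: {trainable_param_num(model):.3f} M")
print(f"macs:   {inference_macs(model, data_shape=(1, 3, 224, 224)):.1f} M")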