hank1996 committed
Commit 35c8a1c · 1 Parent(s): 167fc92

Create new file

Files changed (1)
  1. lib/utils/utils.py +166 -0
lib/utils/utils.py ADDED
@@ -0,0 +1,166 @@
+
+
+import os
+import logging
+import time
+from collections import namedtuple
+from pathlib import Path
+
+import torch
+import torch.optim as optim
+import torch.nn as nn
+import numpy as np
+from torch.utils.data import DataLoader
+from prefetch_generator import BackgroundGenerator
+from contextlib import contextmanager
+import re
+
+def clean_str(s):
+    # Cleans a string by replacing special characters with underscore _
+    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
+
+def create_logger(cfg, cfg_path, phase='train', rank=-1):
+    # set up logger dir
+    dataset = cfg.DATASET.DATASET
+    dataset = dataset.replace(':', '_')
+    model = cfg.MODEL.NAME
+    cfg_path = os.path.basename(cfg_path).split('.')[0]
+
+    if rank in [-1, 0]:
+        time_str = time.strftime('%Y-%m-%d-%H-%M')
+        log_file = '{}_{}_{}.log'.format(cfg_path, time_str, phase)
+        # set up tensorboard_log_dir
+        tensorboard_log_dir = Path(cfg.LOG_DIR) / dataset / model / \
+            (cfg_path + '_' + time_str)
+        final_output_dir = tensorboard_log_dir
+        if not tensorboard_log_dir.exists():
+            print('=> creating {}'.format(tensorboard_log_dir))
+            tensorboard_log_dir.mkdir(parents=True)
+
+        final_log_file = tensorboard_log_dir / log_file
+        head = '%(asctime)-15s %(message)s'
+        logging.basicConfig(filename=str(final_log_file),
+                            format=head)
+        logger = logging.getLogger()
+        logger.setLevel(logging.INFO)
+        console = logging.StreamHandler()
+        logging.getLogger('').addHandler(console)
+
+        return logger, str(final_output_dir), str(tensorboard_log_dir)
+    else:
+        return None, None, None
+
+
+def select_device(logger=None, device='', batch_size=None):
+    # device = 'cpu' or '0' or '0,1,2,3'
+    cpu_request = device.lower() == 'cpu'
+    if device and not cpu_request:  # if a device other than 'cpu' is requested
+        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
+        assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device  # check availability
+
+    cuda = False if cpu_request else torch.cuda.is_available()
+    if cuda:
+        c = 1024 ** 2  # bytes to MB
+        ng = torch.cuda.device_count()
+        if ng > 1 and batch_size:  # check that batch_size is compatible with device_count
+            assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
+        x = [torch.cuda.get_device_properties(i) for i in range(ng)]
+        s = f'Using torch {torch.__version__} '
+        for i in range(0, ng):
+            if i == 1:
+                s = ' ' * len(s)
+            if logger:
+                logger.info("%sCUDA:%g (%s, %dMB)" % (s, i, x[i].name, x[i].total_memory / c))
+    else:
+        if logger:
+            logger.info(f'Using torch {torch.__version__} CPU')
+
+    if logger:
+        logger.info('')  # skip a line
+    return torch.device('cuda:0' if cuda else 'cpu')
+
+
+def get_optimizer(cfg, model):
+    optimizer = None
+    if cfg.TRAIN.OPTIMIZER == 'sgd':
+        optimizer = optim.SGD(
+            filter(lambda p: p.requires_grad, model.parameters()),
+            lr=cfg.TRAIN.LR0,
+            momentum=cfg.TRAIN.MOMENTUM,
+            weight_decay=cfg.TRAIN.WD,
+            nesterov=cfg.TRAIN.NESTEROV
+        )
+    elif cfg.TRAIN.OPTIMIZER == 'adam':
+        optimizer = optim.Adam(
+            filter(lambda p: p.requires_grad, model.parameters()),
+            # model.parameters(),
+            lr=cfg.TRAIN.LR0,
+            betas=(cfg.TRAIN.MOMENTUM, 0.999)
+        )
+
+    return optimizer
+
+
+def save_checkpoint(epoch, name, model, optimizer, output_dir, filename, is_best=False):
+    model_state = model.module.state_dict() if is_parallel(model) else model.state_dict()
+    checkpoint = {
+        'epoch': epoch,
+        'model': name,
+        'state_dict': model_state,
+        # 'best_state_dict': model.module.state_dict(),
+        # 'perf': perf_indicator,
+        'optimizer': optimizer.state_dict(),
+    }
+    torch.save(checkpoint, os.path.join(output_dir, filename))
+    if is_best and 'state_dict' in checkpoint:
+        # also save the current weights as the best model so far
+        torch.save(checkpoint['state_dict'],
+                   os.path.join(output_dir, 'model_best.pth'))
+
+
+def initialize_weights(model):
+    for m in model.modules():
+        t = type(m)
+        if t is nn.Conv2d:
+            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+        elif t is nn.BatchNorm2d:
+            m.eps = 1e-3
+            m.momentum = 0.03
+        elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
+            # elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
+            m.inplace = True
+
+
+def xyxy2xywh(x):
+    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
+    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
+    y[:, 2] = x[:, 2] - x[:, 0]  # width
+    y[:, 3] = x[:, 3] - x[:, 1]  # height
+    return y
+
+
+def is_parallel(model):
+    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+
+
+def time_synchronized():
+    # wait for queued CUDA kernels to finish so the timestamp is accurate
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    return time.time()
+
+
+class DataLoaderX(DataLoader):
+    """DataLoader that prefetches batches in a background thread."""
+    def __iter__(self):
+        return BackgroundGenerator(super().__iter__())
+
+
+@contextmanager
+def torch_distributed_zero_first(local_rank: int):
+    """
+    Context manager that makes all processes in distributed training wait for the local master to do something first.
+    """
+    if local_rank not in [-1, 0]:
+        torch.distributed.barrier()
+    yield
+    if local_rank == 0:
+        torch.distributed.barrier()
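
For context, a minimal usage sketch of the helpers added in this file. It assumes the module is importable as lib.utils.utils and that prefetch_generator is installed; the SimpleNamespace cfg, the nn.Linear stand-in model, and the random TensorDataset are placeholders introduced here for illustration, not part of the commit.

# Usage sketch only; toy cfg/model/dataset are placeholders.
from types import SimpleNamespace

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

from lib.utils.utils import (create_logger, select_device, get_optimizer,
                             save_checkpoint, DataLoaderX)

cfg = SimpleNamespace(
    LOG_DIR='runs',
    DATASET=SimpleNamespace(DATASET='toy'),
    MODEL=SimpleNamespace(NAME='toy_model'),
    TRAIN=SimpleNamespace(OPTIMIZER='adam', LR0=1e-3, MOMENTUM=0.9,
                          WD=5e-4, NESTEROV=True),
)

logger, output_dir, tb_dir = create_logger(cfg, 'toy.yaml', phase='train', rank=-1)
device = select_device(logger, device='')      # '' -> GPU 0 if available, else CPU
model = nn.Linear(4, 2).to(device)             # stand-in for the real network
optimizer = get_optimizer(cfg, model)          # Adam per cfg.TRAIN.OPTIMIZER

data = TensorDataset(torch.randn(32, 4), torch.randn(32, 2))
loader = DataLoaderX(data, batch_size=8, shuffle=True)  # prefetching DataLoader

for x, y in loader:
    loss = ((model(x.to(device)) - y.to(device)) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

save_checkpoint(0, cfg.MODEL.NAME, model, optimizer, output_dir,
                'checkpoint.pth', is_best=True)

With rank=-1 the logger and output directories are created locally; in a distributed run only rank 0 would create them, and torch_distributed_zero_first could wrap dataset preparation so non-master processes wait for the local master.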