|
""" |
|
MIT License |
|
|
|
Copyright (c) 2019 Ildoo Kim |
|
|
|
Permission is hereby granted, free of charge, to any person obtaining a copy |
|
of this software and associated documentation files (the "Software"), to deal |
|
in the Software without restriction, including without limitation the rights |
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
|
copies of the Software, and to permit persons to whom the Software is |
|
furnished to do so, subject to the following conditions: |
|
|
|
The above copyright notice and this permission notice shall be included in all |
|
copies or substantial portions of the Software. |
|
|
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
|
SOFTWARE. |
|
""" |
|
from torch.optim.lr_scheduler import _LRScheduler |
|
from torch.optim.lr_scheduler import ReduceLROnPlateau |
|
|
|
|
|
class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm up (increase) the learning rate in the wrapped optimizer.

    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier (float): target learning rate = base lr * multiplier if
            multiplier > 1.0. If multiplier == 1.0, lr starts from 0 and ends
            up at base_lr.
        total_epoch (int): the target learning rate is reached at total_epoch,
            gradually (linear ramp).
        after_scheduler: scheduler to use after total_epoch
            (e.g. ReduceLROnPlateau). May be None.
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            # BUGFIX: corrected typo in the error message ('thant' -> 'than').
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        # Becomes True once warmup has handed its scaled base lrs to after_scheduler.
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        """Return the learning rate for each param group at the current epoch."""
        if self.last_epoch > self.total_epoch:
            # Warmup is over: delegate to the follow-up scheduler, if any.
            if self.after_scheduler:
                if not self.finished:
                    # Hand over the post-warmup base lrs exactly once.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            # No follow-up scheduler: hold the lr at the warmup target.
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            # Linear ramp from 0 up to base_lr.
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            # Linear ramp from base_lr up to base_lr * multiplier.
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        """Step variant used when after_scheduler is a ReduceLROnPlateau.

        During warmup the lr is written into the optimizer directly; afterwards
        the metrics/epoch are forwarded to the plateau scheduler.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        # ReduceLROnPlateau is called at the end of an epoch, whereas the other
        # schedulers are called at the beginning; keep last_epoch >= 1 here.
        self.last_epoch = epoch if epoch != 0 else 1
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            # BUGFIX: `epoch` can never be None at this point (it was defaulted
            # above), so the former dead `if epoch is None` branch was removed.
            self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        """Advance the scheduler; dispatches to the plateau-specific step when needed."""
        # BUGFIX: use isinstance instead of an exact type comparison so
        # subclasses of ReduceLROnPlateau are dispatched correctly.
        if not isinstance(self.after_scheduler, ReduceLROnPlateau):
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    # Follow-up scheduler counts epochs from the end of warmup.
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
|