import copy
import warnings

import numpy as np
import torch
import torch.nn as nn
from torch_geometric.data import Data, Batch

from .warmup import GradualWarmupScheduler


# customize exp lr scheduler with min lr
class ExponentialLR_with_minLr(torch.optim.lr_scheduler.ExponentialLR):
    """ExponentialLR whose decayed learning rate is floored at ``min_lr``.

    Args:
        optimizer: Wrapped optimizer.
        gamma: Multiplicative decay factor applied on each ``step()``.
        min_lr: Lower bound applied after every decay.
        last_epoch: Index of the last step; ``-1`` starts fresh.
        verbose: Forwarded to the parent scheduler only when truthy.
    """

    def __init__(self, optimizer, gamma, min_lr=1e-4, last_epoch=-1, verbose=False):
        self.gamma = gamma
        self.min_lr = min_lr
        if verbose:
            super(ExponentialLR_with_minLr, self).__init__(optimizer, gamma, last_epoch, verbose)
        else:
            # PyTorch >= 2.2 deprecates ``verbose`` and warns whenever an
            # explicit value is passed; omitting it keeps the previous default
            # (no printing) without emitting the deprecation warning.
            super(ExponentialLR_with_minLr, self).__init__(optimizer, gamma, last_epoch)

    def get_lr(self):
        """Return the next lrs: current lr * gamma, floored at ``min_lr``.

        At ``last_epoch == 0`` the unclamped ``base_lrs`` are returned.
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return self.base_lrs
        return [max(group['lr'] * self.gamma, self.min_lr)
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        """Return ``base_lr * gamma ** last_epoch`` floored at ``min_lr``.

        Unlike ``get_lr``, this also applies the floor at ``last_epoch == 0``.
        """
        return [max(base_lr * self.gamma ** self.last_epoch, self.min_lr)
                for base_lr in self.base_lrs]


def repeat_data(data: Data, num_repeat) -> Batch:
    """Return a ``Batch`` made of ``num_repeat`` independent deep copies of ``data``."""
    datas = [copy.deepcopy(data) for _ in range(num_repeat)]
    return Batch.from_data_list(datas)


def repeat_batch(batch: Batch, num_repeat) -> Batch:
    """Return a ``Batch`` containing the graphs of ``batch`` repeated ``num_repeat`` times.

    The result is ordered as whole-batch repetitions: ``[g0..gk, g0..gk, ...]``.
    Each repetition is a separate deep copy.
    """
    datas = batch.to_data_list()
    new_data = []
    for _ in range(num_repeat):
        new_data += copy.deepcopy(datas)
    return Batch.from_data_list(new_data)


def inf_iterator(iterable):
    """Yield items from ``iterable`` forever, restarting it after each pass.

    Raises:
        ValueError: if a fresh pass over ``iterable`` yields no items (an empty
            iterable, or a one-shot iterator/generator that is already
            exhausted). Without this check the generator would spin forever
            re-creating an empty iterator.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            raise ValueError('inf_iterator: iterable produced no items')


def get_optimizer(cfg, model):
    """Build an optimizer for ``model.parameters()`` from config ``cfg``.

    Only ``cfg.type == 'adam'`` is supported; it reads ``cfg.lr``,
    ``cfg.weight_decay``, ``cfg.beta1`` and ``cfg.beta2``.

    Raises:
        NotImplementedError: for any other ``cfg.type``.
    """
    if cfg.type == 'adam':
        return torch.optim.Adam(
            model.parameters(),
            lr=cfg.lr,
            weight_decay=cfg.weight_decay,
            betas=(cfg.beta1, cfg.beta2, )
        )
    else:
        raise NotImplementedError('Optimizer not supported: %s' % cfg.type)


def get_scheduler(cfg, optimizer):
    """Build a learning-rate scheduler for ``optimizer`` from config ``cfg``.

    Supported ``cfg.type`` values:
      * ``'plateau'``: ReduceLROnPlateau(factor, patience, min_lr).
      * ``'warmup_plateau'``: GradualWarmupScheduler(multiplier, total_epoch)
        whose ``after_scheduler`` is a ReduceLROnPlateau as above.
      * ``'expmin'``: ExponentialLR_with_minLr with ``gamma = cfg.factor``.
      * ``'expmin_milestone'``: ExponentialLR_with_minLr with gamma chosen so
        that ``gamma ** cfg.milestone == cfg.factor``, i.e. the lr shrinks by
        ``cfg.factor`` every ``cfg.milestone`` calls to ``step()``.

    Raises:
        NotImplementedError: for any other ``cfg.type``.
    """
    if cfg.type == 'plateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=cfg.factor,
            patience=cfg.patience,
            min_lr=cfg.min_lr
        )
    elif cfg.type == 'warmup_plateau':
        return GradualWarmupScheduler(
            optimizer,
            multiplier=cfg.multiplier,
            total_epoch=cfg.total_epoch,
            after_scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=cfg.factor,
                patience=cfg.patience,
                min_lr=cfg.min_lr
            )
        )
    elif cfg.type == 'expmin':
        return ExponentialLR_with_minLr(
            optimizer,
            gamma=cfg.factor,
            min_lr=cfg.min_lr,
        )
    elif cfg.type == 'expmin_milestone':
        # gamma ** milestone == factor
        gamma = np.exp(np.log(cfg.factor) / cfg.milestone)
        return ExponentialLR_with_minLr(
            optimizer,
            gamma=gamma,
            min_lr=cfg.min_lr,
        )
    else:
        raise NotImplementedError('Scheduler not supported: %s' % cfg.type)