import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR


def configure_optimizer(
    model: nn.Module,
    lr: float = 1e-3,
    weight_decay: float = 0.01,
    total_steps: int = 100,
):
    """Return an AdamW optimizer paired with a OneCycleLR scheduler.

    ``lr`` serves both as the base learning rate for AdamW and as ``max_lr``
    for the one-cycle schedule; ``total_steps`` should equal the total number
    of optimizer steps in the run.
    """
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = OneCycleLR(optimizer, max_lr=lr, total_steps=total_steps)
    return optimizer, scheduler
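
# Usage sketch (hypothetical ``model``, ``train_loader``, and ``loss_fn`` objects,
# single epoch): OneCycleLR is meant to be stepped once per optimizer step, i.e.
# once per batch, not once per epoch.
#
#     optimizer, scheduler = configure_optimizer(model, total_steps=len(train_loader))
#     for inputs, targets in train_loader:
#         optimizer.zero_grad()
#         loss = loss_fn(model(inputs), targets)
#         loss.backward()
#         optimizer.step()
#         scheduler.step()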

def adjust_learning_rate(optimizer: torch.optim.Optimizer, factor: float) -> float:
    """Scale the learning rate of all param groups by ``factor``.

    Parameters
    ----------
    optimizer:
        The optimizer whose learning rate will be adjusted.
    factor:
        Multiplicative factor applied to the current learning rate.

    Returns
    -------
    float
        The updated learning rate of the first parameter group.
    """
    for param_group in optimizer.param_groups:
        param_group["lr"] *= factor
    return optimizer.param_groups[0]["lr"]
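
# Minimal smoke-test sketch: wires a throwaway ``nn.Linear`` model through
# ``configure_optimizer`` and then halves the learning rate manually with
# ``adjust_learning_rate``. The sizes, step count, and loss below are arbitrary
# illustrative choices.
if __name__ == "__main__":
    model = nn.Linear(8, 1)
    optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=10)

    for _ in range(10):
        optimizer.zero_grad()
        loss = model(torch.randn(4, 8)).pow(2).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()  # one scheduler step per optimizer step

    new_lr = adjust_learning_rate(optimizer, 0.5)  # e.g. manual decay after the cycle
    print(f"learning rate after manual adjustment: {new_lr:.2e}")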