import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR


def configure_optimizer(
    model: nn.Module,
    lr: float = 1e-3,
    weight_decay: float = 0.01,
    total_steps: int = 100,
):
    """Return AdamW optimizer with OneCycleLR scheduler."""
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = OneCycleLR(optimizer, max_lr=lr, total_steps=total_steps)
    return optimizer, scheduler
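

# Minimal usage sketch (assumption: the hypothetical helper below, its toy
# nn.Sequential model, and the batch shapes are illustrative placeholders,
# not part of the original module). OneCycleLR expects scheduler.step() once
# per optimizer step, so total_steps should match the number of iterations.
def _example_one_cycle_training(total_steps: int = 100) -> None:
    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
    optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=total_steps)
    for _ in range(total_steps):
        inputs, targets = torch.randn(8, 10), torch.randn(8, 1)
        loss = nn.functional.mse_loss(model(inputs), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # advance the one-cycle LR schedule each optimizer step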


def adjust_learning_rate(optimizer: torch.optim.Optimizer, factor: float) -> float:
    """Scale the learning rate of all param groups by ``factor``.

    Parameters
    ----------
    optimizer:
        The optimizer whose learning rate will be adjusted.
    factor:
        Multiplicative factor applied to the current learning rate.

    Returns
    -------
    float
        The updated learning rate of the first parameter group.
    """
    for param_group in optimizer.param_groups:
        param_group["lr"] *= factor
    return optimizer.param_groups[0]["lr"]
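

# Quick sanity check (minimal sketch; the hypothetical helper, the SGD
# optimizer, and the 0.1 factor are arbitrary illustrative choices). Scaling
# the LR this way affects every param group, which is why the loop above
# iterates over all of them.
def _example_adjust_learning_rate() -> None:
    params = [nn.Parameter(torch.zeros(3))]
    optimizer = torch.optim.SGD(params, lr=0.1)
    new_lr = adjust_learning_rate(optimizer, factor=0.1)
    assert abs(new_lr - 0.01) < 1e-12  # 0.1 * 0.1 == 0.01 (up to float rounding)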