import math
from typing import Tuple

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR


def configure_optimizer(
    model: nn.Module,
    lr: float = 1e-3,
    weight_decay: float = 0.01,
    total_steps: int = 100,
) -> Tuple[AdamW, OneCycleLR]:
    """Return an AdamW optimizer paired with a OneCycleLR scheduler.

    Parameters
    ----------
    model:
        Module whose ``parameters()`` are optimized.
    lr:
        Learning rate given to AdamW and used as the scheduler's ``max_lr``
        (the peak of the one-cycle schedule).
    weight_decay:
        Weight-decay coefficient passed to AdamW.
    total_steps:
        Number of ``scheduler.step()`` calls that make up the full cycle.

    Returns
    -------
    tuple of (AdamW, OneCycleLR)
        The optimizer and the scheduler attached to it.

    Notes
    -----
    ``OneCycleLR`` rewrites each param group's ``"lr"`` when it is
    constructed, so immediately after this call the optimizer's learning
    rate is the schedule's starting value (``lr / div_factor``; per the
    PyTorch docs ``div_factor`` defaults to 25), not ``lr`` itself.
    """
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = OneCycleLR(optimizer, max_lr=lr, total_steps=total_steps)
    return optimizer, scheduler


def adjust_learning_rate(optimizer: torch.optim.Optimizer, factor: float) -> float:
    """Scale the learning rate of all param groups by ``factor``.

    Parameters
    ----------
    optimizer:
        The optimizer whose learning rate will be adjusted in place.
    factor:
        Finite, non-negative multiplicative factor applied to the current
        learning rate of every parameter group.

    Returns
    -------
    float
        The updated learning rate of the first parameter group.

    Raises
    ------
    ValueError
        If ``factor`` is negative, NaN, or infinite. No param group is
        modified in that case.

    Notes
    -----
    NOTE(review): LR schedulers such as ``OneCycleLR`` recompute
    ``group["lr"]`` on each ``step()``, so scaling applied here is likely
    overwritten when a scheduler is attached — confirm intended usage.
    """
    # Validate before touching any group so a bad factor can never leave
    # the optimizer partially rescaled. Torch optimizers refuse a negative
    # (or NaN) lr at construction; apply the same rule to scaling here.
    if not (math.isfinite(factor) and factor >= 0):
        raise ValueError(
            f"factor must be a finite, non-negative number, got {factor!r}"
        )

    for param_group in optimizer.param_groups:
        param_group["lr"] *= factor

    # ``lr`` can be stored as a tensor in some torch configurations; coerce
    # so the return value actually matches the ``float`` annotation.
    return float(optimizer.param_groups[0]["lr"])