import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR


def configure_optimizer(
    model: nn.Module,
    lr: float = 1e-3,
    weight_decay: float = 0.01,
    total_steps: int = 100,
):
    """Return an AdamW optimizer paired with a OneCycleLR scheduler."""
    # AdamW applies decoupled weight decay; ``lr`` doubles as the
    # scheduler's peak learning rate.
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # OneCycleLR ramps the learning rate up to ``max_lr`` and anneals it
    # back down over ``total_steps`` optimizer steps.
    scheduler = OneCycleLR(optimizer, max_lr=lr, total_steps=total_steps)
    return optimizer, scheduler


def adjust_learning_rate(optimizer: torch.optim.Optimizer, factor: float) -> float:
    """Scale the learning rate of all param groups by ``factor``.

    Parameters
    ----------
    optimizer:
        The optimizer whose learning rate will be adjusted.
    factor:
        Multiplicative factor applied to the current learning rate.

    Returns
    -------
    float
        The updated learning rate of the first parameter group.
    """
    # Scale every parameter group in place so the change takes effect on
    # the next optimizer step.
    for param_group in optimizer.param_groups:
        param_group["lr"] *= factor
    return optimizer.param_groups[0]["lr"]
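

# --- Usage sketch (illustrative, not part of the original module) ----------
# A minimal example of how the two helpers above could fit together, using a
# hypothetical stand-in ``nn.Linear`` model; BitTransformerLM's actual training
# loop may wire them up differently.
if __name__ == "__main__":
    model = nn.Linear(8, 8)  # toy model for demonstration only
    optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=10)

    for _ in range(10):
        optimizer.zero_grad()
        loss = model(torch.randn(4, 8)).sum()
        loss.backward()
        optimizer.step()
        scheduler.step()  # advance OneCycleLR once per optimizer step

    # Manually halve the learning rate, e.g. if validation loss plateaus.
    new_lr = adjust_learning_rate(optimizer, 0.5)
    print(f"learning rate after adjustment: {new_lr}")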