import contextlib
import logging
from collections import defaultdict
from typing import List, Tuple

import torch
from torch import Tensor
from torch.optim import Optimizer
					
						
class BatchedOptimizer(Optimizer):
    """
    This class adds to class Optimizer the capability to optimize parameters in batches:
    it will stack the parameters and their grads for you so the optimizer can work
    on tensors with an extra leading dimension.  This is intended for speed with GPUs,
    as it reduces the number of kernels launched in the optimizer.

    Args:
      params: the parameters (or parameter groups) to optimize, as for torch.optim.Optimizer
      defaults: a dict of default hyperparameter values for the parameter groups
    """

    def __init__(self, params, defaults):
        super(BatchedOptimizer, self).__init__(params, defaults)
					
						
    @contextlib.contextmanager
    def batched_params(self, param_group, group_params_names):
        """
        This function returns (technically, yields) a list of
        tuples (p, state, p_names), where
          p is a `fake` parameter that is stacked (over axis 0) from real parameters
          that share the same shape, and its gradient is also stacked;
          `state` is the state corresponding to this batch of parameters
          (it will be physically located in the "state" for one of the real
          parameters, the first one that has any particular shape and dtype);
          `p_names` is the list of names of the real parameters in the batch.

        This function is decorated as a context manager so that it can
        write parameters back to their "real" locations.

        The idea is, instead of doing:
        <code>
          for p in group["params"]:
              state = self.state[p]
              ...
        </code>
        you can do:
        <code>
          with self.batched_params(group["params"], group_params_names) as batches:
              for p, state, p_names in batches:
                  ...
        </code>

        Args:
          param_group: a parameter group, which is a list of parameters; should be
               one of self.param_groups.
          group_params_names: name for each parameter in param_group,
               which is List[str].
        """
        # Group the parameters (and their names) by dtype and shape, so that
        # parameters sharing both can be stacked into a single tensor.
        batches = defaultdict(list)  # maps (dtype_as_str, *shape) -> list of parameters
        batches_names = defaultdict(list)  # maps (dtype_as_str, *shape) -> list of names

        assert len(param_group) == len(group_params_names)
        for p, named_p in zip(param_group, group_params_names):
            key = (str(p.dtype), *p.shape)
            batches[key].append(p)
            batches_names[key].append(named_p)

        # Sort the batches by key so the iteration order is deterministic.
        batches_names_keys = list(batches_names.keys())
        sorted_idx = sorted(
            range(len(batches_names)), key=lambda i: batches_names_keys[i])
        batches_names = [
            batches_names[batches_names_keys[idx]] for idx in sorted_idx
        ]
        batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]

        stacked_params_dict = dict()

        tuples = []  # will contain tuples of (stacked_param, state, param_names)

        for batch, batch_names in zip(batches, batches_names):
            p = batch[0]
            # We arbitrarily store the batch's state in the state dict of its
            # first real parameter; class Optimizer will take care of
            # saving/loading this state.
            state = self.state[p]
            p_stacked = torch.stack(batch)
            grad = torch.stack([
                torch.zeros_like(p) if p.grad is None else p.grad for p in batch
            ])
            p_stacked.grad = grad
            stacked_params_dict[(str(p.dtype), *p.shape)] = p_stacked
            tuples.append((p_stacked, state, batch_names))

        yield tuples  # <-- calling code does the actual optimization here.

        # Write the (possibly updated) stacked values back to the real parameters.
        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
            for i, p in enumerate(batch):
                p.copy_(stacked_params[i])
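

# Illustrative note (added; not part of the original file): batched_params
# groups a parameter group's parameters by (dtype, shape).  For example, if a
# hypothetical group held three parameters of shapes (256, 256), (256, 256)
# and (256,), it would yield two batches: one stacked tensor of shape
# (2, 256, 256) and one of shape (1, 256), each carrying a stacked .grad.
# When the `with` block exits, the updated slices are copied back into the
# real parameters.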
					
						
class ScaledAdam(BatchedOptimizer):
    """
    Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
    proportional to the norm of that parameter; and also learn the scale of the parameter,
    in log space, subject to upper and lower limits (as if we had factored each parameter as
    param = underlying_param * log_scale.exp()).

    Args:
      params:  The parameters or param_groups to optimize (like other Optimizer subclasses)
      lr:  The learning rate.  We will typically use a learning rate schedule that starts
           at 0.03 and decreases over time, i.e. much higher than other common
           optimizers.
      clipping_scale: (e.g. 2.0)
           A scale for gradient-clipping: if specified, the normalized gradients
           over the whole model will be clipped to have 2-norm equal to
           `clipping_scale` times the median 2-norm over the most recent period
           of `clipping_update_period` minibatches.  By "normalized gradients",
           we mean after multiplying by the rms parameter value for this tensor
           [for non-scalars]; this is appropriate because our update is scaled
           by this quantity.
      betas: beta1, beta2 are momentum constants for regular momentum, and moving sum-sq grad.
           Must satisfy 0 < beta1 <= beta2 < 1.
      scalar_lr_scale: A scaling factor on the learning rate that we use to update the
           scale of each parameter tensor and the scalar parameters of the model.
           If each parameter were decomposed
           as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
           would be the scaling factor on the learning rate of p_scale.
      eps:  A general-purpose epsilon to prevent division by zero
      param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
           learning the scale on the parameters (we'll constrain the rms of each non-scalar
           parameter tensor to be >= this value)
      param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
           learning the scale on the parameters (we'll constrain the rms of each non-scalar
           parameter tensor to be <= this value)
      scalar_max: Maximum absolute value for scalar parameters (applicable if your
           model has any parameters with numel() == 1).
      size_update_period: The periodicity, in steps, with which we update the size (scale)
           of the parameter tensor.  This is provided to save a little time
           in the update.
      clipping_update_period: if clipping_scale is specified, this is the period, in steps,
           with which we recompute the gradient-norm statistics used for clipping.
    """
					
						
    def __init__(
            self,
            params,
            lr=3e-02,
            clipping_scale=None,
            betas=(0.9, 0.98),
            scalar_lr_scale=0.1,
            eps=1.0e-08,
            param_min_rms=1.0e-05,
            param_max_rms=3.0,
            scalar_max=10.0,
            size_update_period=4,
            clipping_update_period=100,
            parameters_names=None,
            show_dominant_parameters=True,
    ):

        assert parameters_names is not None, (
            "Please prepare parameters_names, "
            "which is a List[List[str]]. Each List[str] is for a group "
            "and each str is for a parameter")
        defaults = dict(
            lr=lr,
            clipping_scale=clipping_scale,
            betas=betas,
            scalar_lr_scale=scalar_lr_scale,
            eps=eps,
            param_min_rms=param_min_rms,
            param_max_rms=param_max_rms,
            scalar_max=scalar_max,
            size_update_period=size_update_period,
            clipping_update_period=clipping_update_period,
        )

        super(ScaledAdam, self).__init__(params, defaults)
        assert len(self.param_groups) == len(parameters_names)
        self.parameters_names = parameters_names
        self.show_dominant_parameters = show_dominant_parameters

    def __setstate__(self, state):
        super(ScaledAdam, self).__setstate__(state)
					
						
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group, group_params_names in zip(self.param_groups,
                                             self.parameters_names):

            with self.batched_params(group["params"],
                                     group_params_names) as batches:

                # batches is a list of (stacked_param, state, param_names).
                # If the state of the first batch is empty, this is the first
                # step and we have no gradient-norm statistics yet, so skip
                # clipping on this step.
                if len(batches[0][1]) == 0:
                    clipping_scale = 1
                else:
                    clipping_scale = self._get_clipping_scale(group, batches)

                for p, state, _ in batches:
                    # grad is not going to be None; we handled that when
                    # stacking the grads in batched_params().
                    grad = p.grad
                    if grad.is_sparse:
                        raise RuntimeError(
                            "ScaledAdam optimizer does not support sparse gradients"
                        )

                    if len(state) == 0:
                        self._init_state(group, p, state)

                    self._step_one_batch(group, p, state, clipping_scale)

        return loss
					
						
    def _init_state(self, group: dict, p: Tensor, state: dict):
        """
        Initializes state dict for parameter 'p'.  Assumes that dim 0 of tensor p
        is actually the batch dimension, corresponding to batched-together
        parameters of a given shape.

        Args:
          group:   Dict to look up configuration values.
          p: The parameter that we are initializing the state for
          state: Dict from string to whatever state we are initializing
        """
        size_update_period = group["size_update_period"]

        state["step"] = 0

        kwargs = {"device": p.device, "dtype": p.dtype}

        # 'delta' implements conventional momentum.  There are several
        # different kinds of update going on, so rather than computing an
        # "exp_avg" as in Adam, we store and decay a parameter-change "delta",
        # which combines all forms of update.
        state["delta"] = torch.zeros_like(
            p, memory_format=torch.preserve_format)

        batch_size = p.shape[0]
        numel = p.numel() // batch_size

        if numel > 1:
            # "param_rms" periodically records the scalar root-mean-square
            # value of each parameter in the batch; it has a shape like
            # (batch_size, 1, 1, ...).
            param_rms = (
                (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
            state["param_rms"] = param_rms

            state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
            state["scale_grads"] = torch.zeros(size_update_period,
                                               *param_rms.shape, **kwargs)

        # exp_avg_sq is the weighted sum of squared gradients, as in Adam.
        state["exp_avg_sq"] = torch.zeros_like(
            p, memory_format=torch.preserve_format)
					
						
    def _get_clipping_scale(self,
                            group: dict,
                            tuples: List[Tuple[Tensor, dict, List[str]]]
                            ) -> float:
        """
        Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
        by this amount before applying the rest of the update.

        Args:
          group: the parameter group, an item in self.param_groups
          tuples: a list of tuples of (param, state, param_names)
               where param is a batched set of parameters,
               with a .grad (1st dim is batch dim)
               and state is the state-dict where optimization parameters are kept.
               param_names is a List[str] where each str is the name of a parameter
               in the batched set of parameters "param".
        """
        assert len(tuples) >= 1
        clipping_scale = group["clipping_scale"]
        (first_p, first_state, _) = tuples[0]
        step = first_state["step"]
        if clipping_scale is None or step == 0:
            # No clipping.  We return early on step == 0 because the other
            # parameters' state won't have been initialized yet.
            return 1.0
        clipping_update_period = group["clipping_update_period"]

        # The gradient norm is computed after normalizing by the parameter
        # rms (for non-scalars), because that is the scale at which the
        # update will be applied.
        tot_sumsq = torch.tensor(0.0, device=first_p.device)
        for (p, state, param_names) in tuples:
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError(
                    "ScaledAdam optimizer does not support sparse gradients")
            if p.numel() == p.shape[0]:  # a batch of scalars
                tot_sumsq += (grad**2).sum()
            else:
                tot_sumsq += ((grad * state["param_rms"])**2).sum()

        tot_norm = tot_sumsq.sqrt()
        if "model_norms" not in first_state:
            first_state["model_norms"] = torch.zeros(
                clipping_update_period, device=p.device)
        first_state["model_norms"][step % clipping_update_period] = tot_norm

        if step % clipping_update_period == 0:
            # Recompute the clipping threshold from the recorded norms and
            # print some stats.  (We don't reach here when step == 0.)
            sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
            quartiles = []
            for n in range(0, 5):
                index = min(
                    clipping_update_period - 1,
                    (clipping_update_period // 4) * n,
                )
                quartiles.append(sorted_norms[index].item())

            median = quartiles[2]
            threshold = clipping_scale * median
            first_state["model_norm_threshold"] = threshold
            percent_clipped = (first_state["num_clipped"] * 100.0 /
                               clipping_update_period
                               if "num_clipped" in first_state else 0.0)
            first_state["num_clipped"] = 0
            quartiles = " ".join(["%.3e" % x for x in quartiles])
            logging.info(
                f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
                f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
            )

        if step < clipping_update_period:
            return 1.0  # We have not yet estimated a norm to clip to.
        else:
            try:
                model_norm_threshold = first_state["model_norm_threshold"]
            except KeyError:
                logging.info(
                    "Warning: model_norm_threshold not in state: possibly "
                    "you changed config when restarting, adding clipping_scale option?"
                )
                return 1.0
            ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
            if ans < 1.0:
                first_state["num_clipped"] += 1
            if ans < 0.1:
                logging.warning(
                    f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
                )
                if self.show_dominant_parameters:
                    assert p.shape[0] == len(param_names)
                    self._show_gradient_dominating_parameter(tuples, tot_sumsq)
            return ans
					
						
    def _show_gradient_dominating_parameter(
            self, tuples: List[Tuple[Tensor, dict, List[str]]],
            tot_sumsq: Tensor):
        """
        Show information about the parameter which dominates tot_sumsq.

        Args:
          tuples: a list of tuples of (param, state, param_names)
               where param is a batched set of parameters,
               with a .grad (1st dim is batch dim)
               and state is the state-dict where optimization parameters are kept.
               param_names is a List[str] where each str is the name of a parameter
               in the batched set of parameters "param".
          tot_sumsq: sumsq of all parameters.  Though it could be computed
               from tuples, we pass it in to save some time.
        """
        all_sumsq_orig = {}
        for (p, state, batch_param_names) in tuples:
            # p is a stacked batch of parameters.
            batch_grad = p.grad
            if p.numel() == p.shape[0]:  # a batch of scalars
                batch_sumsq_orig = batch_grad**2
                # Dummy values used by the following `zip` statement.
                batch_rms_orig = torch.ones(p.shape[0])
            else:
                batch_rms_orig = state["param_rms"]
                batch_sumsq_orig = ((batch_grad * batch_rms_orig)**2).sum(
                    dim=list(range(1, batch_grad.ndim)))

            for name, sumsq_orig, rms, grad in zip(batch_param_names,
                                                   batch_sumsq_orig,
                                                   batch_rms_orig, batch_grad):

                proportion_orig = sumsq_orig / tot_sumsq
                all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)

        assert torch.isclose(
            sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
            torch.tensor(1.0),
        )
        sorted_by_proportion = {
            k: v
            for k, v in sorted(
                all_sumsq_orig.items(),
                key=lambda item: item[1][0],
                reverse=True,
            )
        }
        dominant_param_name = next(iter(sorted_by_proportion))
        (dominant_proportion, dominant_sumsq, dominant_rms,
         dominant_grad) = sorted_by_proportion[dominant_param_name]
        logging.info(f"Parameter dominating tot_sumsq {dominant_param_name}"
                     f" with proportion {dominant_proportion:.2f},"
                     f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
                     f"={dominant_sumsq:.3e},"
                     f" grad_sumsq={(dominant_grad**2).sum():.3e},"
                     f" orig_rms_sq={(dominant_rms**2).item():.3e}")
					
						
    def _step_one_batch(self,
                        group: dict,
                        p: Tensor,
                        state: dict,
                        clipping_scale: float):
        """
        Do the step for one parameter, which is actually going to be a batch of
        `real` parameters, with dim 0 as the batch dim.
        Args:
          group:  dict to look up configuration values
          p: parameter to update (actually multiple parameters stacked together
             as a batch)
          state: state-dict for p, to look up the optimizer state
          clipping_scale: scale factor (<= 1.0) from gradient clipping, to
             multiply the gradient by before the update
        """
        lr = group["lr"]
        size_update_period = group["size_update_period"]
        beta1 = group["betas"][0]

        grad = p.grad
        if clipping_scale != 1.0:
            grad = grad * clipping_scale
        step = state["step"]
        delta = state["delta"]

        delta.mul_(beta1)
        batch_size = p.shape[0]
        numel = p.numel() // batch_size
        if numel > 1:
            # Update the size/scale of p, and set param_rms.
            scale_grads = state["scale_grads"]
            scale_grads[step % size_update_period] = (p * grad).sum(
                dim=list(range(1, p.ndim)), keepdim=True)
            if step % size_update_period == size_update_period - 1:
                param_rms = state["param_rms"]  # shape: (batch_size, 1, 1, ..)
                param_rms.copy_((p**2)
                                .mean(dim=list(range(1, p.ndim)), keepdim=True)
                                .sqrt())
                if step > 0:
                    # _size_update() learns the overall scale on each
                    # parameter, by shrinking or expanding it.
                    self._size_update(group, scale_grads, p, state)

        if numel == 1:
            # For parameters with one element we just use a form of Adam
            # that does not need the parameter rms.  Updates `delta`.
            self._step_scalar(group, p, state)
        else:
            self._step(group, p, state)

        state["step"] = step + 1
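    # Note (added; not in the original file): with the factorization
    # p = underlying_param * scale, where the scale is learned in log space,
    # d(loss)/d(log_scale) = sum_i grad_i * d(p_i)/d(log_scale) = sum_i grad_i * p_i,
    # which is the quantity accumulated into state["scale_grads"] above as
    # (p * grad).sum(...).  _size_update() below then applies an Adam-style
    # update to the log of the scale, realized as a change to p itself.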
					
						
    def _size_update(self,
                     group: dict,
                     scale_grads: Tensor,
                     p: Tensor,
                     state: dict) -> None:
        """
        Called only where p.numel() > 1, this updates the scale of the parameter.
        If we imagine: p = underlying_param * scale.exp(), and we are doing
        gradient descent on underlying param and on scale, this function does the update
        on `scale`.

        Args:
          group: dict to look up configuration values
          scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
             grads w.r.t. the scales.
          p:  The parameter to update
          state: The state-dict of p
        """

        param_rms = state["param_rms"]
        beta1, beta2 = group["betas"]
        size_lr = group["lr"] * group["scalar_lr_scale"]
        param_min_rms = group["param_min_rms"]
        param_max_rms = group["param_max_rms"]
        eps = group["eps"]
        step = state["step"]
        batch_size = p.shape[0]

        size_update_period = scale_grads.shape[0]
        # Correct beta2 for the size update period: the scale statistics are
        # only updated once per `size_update_period` steps.
        beta2_corr = beta2**size_update_period

        scale_exp_avg_sq = state["scale_exp_avg_sq"]  # shape: (batch_size, 1, 1, ..)
        scale_exp_avg_sq.mul_(beta2_corr).add_(
            (scale_grads**2).mean(dim=0),  # mean over the size_update_period dim
            alpha=1 - beta2_corr,
        )

        # The first time we reach here, size_step == 1.
        size_step = (step + 1) // size_update_period
        bias_correction2 = 1 - beta2_corr**size_step

        denom = scale_exp_avg_sq.sqrt() + eps

        scale_step = (-size_lr * (bias_correction2**0.5) *
                      scale_grads.sum(dim=0) / denom)

        is_too_small = param_rms < param_min_rms
        is_too_large = param_rms > param_max_rms

        # If the parameter rms is below the minimum, leave its scale alone
        # (in particular, don't shrink it any further).
        scale_step.masked_fill_(is_too_small, 0.0)
        # If it is above the maximum, force the scale to shrink.
        scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
        delta = state["delta"]
        # The factor of (1 - beta1) relates to momentum.
        delta.add_(p * scale_step, alpha=(1 - beta1))
					
						
    def _step(self, group: dict, p: Tensor, state: dict):
        """
        This function does the core update of self.step(), in the case where the members of
        the batch have more than 1 element.

        Args:
          group: A dict which will be used to look up configuration values
          p: The parameter to be updated; p.grad holds its gradient
          state: The state-dict corresponding to parameter p

        This function modifies p.
        """
        grad = p.grad
        lr = group["lr"]
        beta1, beta2 = group["betas"]
        eps = group["eps"]
        param_min_rms = group["param_min_rms"]
        step = state["step"]

        exp_avg_sq = state["exp_avg_sq"]
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))

        this_step = state["step"] - (state["zero_step"]
                                     if "zero_step" in state else 0)
        bias_correction2 = 1 - beta2**(this_step + 1)
        if bias_correction2 < 0.99:
            # Note: not in-place; only applied in the early steps.
            exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)

        denom = exp_avg_sq.sqrt()
        denom += eps
        grad = grad / denom

        alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)

        delta = state["delta"]
        delta.add_(grad * alpha)
        p.add_(delta)
					
						
    def _step_scalar(self, group: dict, p: Tensor, state: dict):
        """
        A simplified form of the core update for scalar tensors, where we cannot get a good
        estimate of the parameter rms.
        """
        beta1, beta2 = group["betas"]
        scalar_max = group["scalar_max"]
        eps = group["eps"]
        lr = group["lr"] * group["scalar_lr_scale"]
        grad = p.grad

        exp_avg_sq = state["exp_avg_sq"]  # shape: (batch_size,)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # bias_correction2 is like in Adam; we don't bother with
        # bias_correction1, since a slower update at the start helps stability.
        bias_correction2 = 1 - beta2**(state["step"] + 1)
        denom = (exp_avg_sq / bias_correction2).sqrt() + eps

        delta = state["delta"]
        delta.add_(grad / denom, alpha=-lr * (1 - beta1))
        p.clamp_(min=-scalar_max, max=scalar_max)
        p.add_(delta)
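

# A minimal usage sketch (added for illustration; not part of the original
# file).  It shows how `parameters_names`, which ScaledAdam requires, can be
# built from named_parameters(); the toy model and hyperparameters here are
# arbitrary choices.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(8, 4)
    # One List[str] per parameter group; we use a single group here, in the
    # same order that model.parameters() yields the parameters.
    parameters_names = [[name for name, _ in model.named_parameters()]]
    optimizer = ScaledAdam(
        model.parameters(),
        lr=3e-02,
        clipping_scale=2.0,
        parameters_names=parameters_names,
    )

    x = torch.randn(16, 8)
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()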