|
import os |
|
import math |
|
from typing import Any, Mapping, Text, Tuple, Union, NamedTuple |
|
from functools import partial |
|
import re |
|
import dataclasses |
|
import random |
|
from ml_collections import ConfigDict |
|
from ml_collections.config_dict.config_dict import placeholder |
|
|
|
import flax |
|
import jax |
|
import jax.numpy as jnp |
|
from jax.sharding import PartitionSpec as PS |
|
from jax.sharding import Mesh |
|
from jax.experimental import mesh_utils |
|
from jax.experimental.pjit import with_sharding_constraint as _with_sharding_constraint |
|
from jax.experimental.pjit import pjit |
|
from jax.interpreters import pxla |
|
import numpy as np |
|
from transformers import FlaxLogitsWarper |
|
|
|
|
|
class JaxRNG(object):
    """ Stateful wrapper around a JAX PRNG key.

    Each call splits the held key, keeps one piece as the new internal state
    and hands out the remaining piece(s). This lets code draw fresh keys
    without threading the key through every function by hand.
    """

    @classmethod
    def from_seed(cls, seed):
        """ Build a wrapper around ``jax.random.PRNGKey(seed)``. """
        return cls(jax.random.PRNGKey(seed))

    def __init__(self, rng):
        self.rng = rng

    def __call__(self, keys=None):
        """ Draw fresh keys and advance the internal state.

        Args:
            keys: ``None`` for a single key; an int ``n`` for a tuple of ``n``
                keys; otherwise a sized iterable of names, giving a dict that
                maps each name to its own key.
        """
        if keys is None:
            self.rng, subkey = jax.random.split(self.rng)
            return subkey

        count = keys if isinstance(keys, int) else len(keys)
        pieces = jax.random.split(self.rng, num=count + 1)
        # Piece 0 becomes the carried state; the rest are handed out.
        self.rng, fresh = pieces[0], pieces[1:]
        if isinstance(keys, int):
            return tuple(fresh)
        return dict(zip(keys, fresh))
|
|
|
|
|
class JaxDistributedConfig(object):
    """ Helper that optionally calls ``jax.distributed.initialize`` from a config. """

    @staticmethod
    def get_default_config(updates=None):
        """ Return the default ConfigDict, overlaid with ``updates`` when given.

        ``coordinator_address``, ``num_processes``, ``process_id`` and
        ``local_device_ids`` are typed placeholders. ``local_device_ids`` is
        expected to be a comma-separated string of integers.
        """
        config = ConfigDict()
        config.initialize_jax_distributed = False
        config.coordinator_address = placeholder(str)
        config.num_processes = placeholder(int)
        config.process_id = placeholder(int)
        config.local_device_ids = placeholder(str)

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def initialize(cls, config):
        """ Merge ``config`` into the defaults and, if
        ``initialize_jax_distributed`` is set, call ``jax.distributed.initialize``.
        """
        config = cls.get_default_config(config)
        if not config.initialize_jax_distributed:
            return

        device_ids = config.local_device_ids
        if device_ids is not None:
            # "0,1,2" -> [0, 1, 2]
            device_ids = [int(part) for part in device_ids.split(',')]

        jax.distributed.initialize(
            coordinator_address=config.coordinator_address,
            num_processes=config.num_processes,
            process_id=config.process_id,
            local_device_ids=device_ids,
        )
|
|
|
|
|
class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
    """ Temperature-scaling logits warper written with jnp ops so it is JIT traceable.

    ``temperature`` is clamped from below at 1e-8 before dividing, so a zero
    temperature does not produce a division by zero.
    """

    def __init__(self, temperature):
        self.temperature = temperature

    def __call__(self, input_ids, scores, cur_len):
        safe_temperature = jnp.clip(self.temperature, a_min=1e-8)
        return scores / safe_temperature
|
|
|
|
|
def make_shard_and_gather_fns(partition_specs, dtype_specs=None):
    """ Build matching pytrees of shard and gather functions.

    Args:
        partition_specs: pytree of PartitionSpec leaves.
        dtype_specs: ``None`` (no casting); a single float dtype (every float
            leaf is cast to it); or a pytree matching ``partition_specs``
            whose leaves carry a ``.dtype`` to cast the corresponding leaf to.

    Returns:
        ``(shard_fns, gather_fns)``.
        - Each shard fn runs a pjit whose outputs use the leaf's partition
          spec and waits on the result.
        - Each gather fn runs a pjit whose inputs use that spec and copies
          the result to host with ``jax.device_get``.
    """
    float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
    # A single float dtype means "cast every float tensor to this dtype".
    uniform_float_cast = dtype_specs in float_dtypes

    def make_to_dtype_fn(dtype_spec):
        def to_dtype(tensor):
            if uniform_float_cast and getattr(tensor, 'dtype', None) in float_dtypes:
                # Convert all float tensors to the same dtype.
                return tensor.astype(dtype_specs)
            if hasattr(dtype_spec, 'dtype') and hasattr(tensor, 'dtype'):
                return tensor.astype(dtype_spec.dtype)
            return tensor
        return to_dtype

    def make_shard_fn(partition_spec, dtype_spec=None):
        sharded_cast = pjit(
            make_to_dtype_fn(dtype_spec),
            in_shardings=None,
            out_shardings=partition_spec,
        )

        def shard_fn(tensor):
            return sharded_cast(tensor).block_until_ready()
        return shard_fn

    def make_gather_fn(partition_spec, dtype_spec=None):
        gathered_cast = pjit(
            make_to_dtype_fn(dtype_spec),
            in_shardings=partition_spec,
            out_shardings=None,
        )

        def gather_fn(tensor):
            return jax.device_get(gathered_cast(tensor))
        return gather_fn

    # Only a per-leaf dtype pytree is mapped alongside the partition specs.
    if dtype_specs is None or uniform_float_cast:
        trees = (partition_specs,)
    else:
        trees = (partition_specs, dtype_specs)

    shard_fns = jax.tree_util.tree_map(make_shard_fn, *trees)
    gather_fns = jax.tree_util.tree_map(make_gather_fn, *trees)
    return shard_fns, gather_fns
|
|
|
|
|
def set_random_seed(seed):
    """ Seed NumPy's global RNG, Python's ``random`` module and the
    module-level JAX RNG (via ``init_rng``) with the same ``seed``.
    """
    for seed_fn in (np.random.seed, random.seed, init_rng):
        seed_fn(seed)
|
|
|
|
|
def get_jax_mesh(axis_dims, names):
    """ Build a ``jax.sharding.Mesh`` from a textual axis specification.

    Args:
        axis_dims: comma-separated axis sizes.
            - Positional form ("1,-1,1"): one entry per name, in the order
              of ``names``.
            - Named form ("dp:1,fsdp:-1,mp:1"): any order, each name
              exactly once.
            - A size may be -1; it is resolved by ``numpy.reshape`` against
              ``jax.device_count()``.
            - A leading "!" reshapes ``jax.devices()`` in plain order instead
              of calling ``mesh_utils.create_device_mesh``.
        names: the mesh axis names that must be covered.

    Returns:
        A Mesh whose axis order follows ``axis_dims``.

    Raises:
        ValueError: if the specification does not match ``names`` or the
            sizes cannot be reshaped onto the available devices.
    """
    mesh_axis_splitting = axis_dims.startswith('!')
    if mesh_axis_splitting:
        axis_dims = axis_dims[1:]

    if ':' in axis_dims:
        dims = []
        dim_names = []
        for axis in axis_dims.split(','):
            name, dim = axis.split(':')
            # Tolerate whitespace such as "dp:1, mp:-1".
            name = name.strip()
            if name not in names:
                raise ValueError(
                    f'Unknown mesh axis {name!r} in {axis_dims!r}; '
                    f'expected names {tuple(names)}'
                )
            dims.append(int(dim))
            dim_names.append(name)
        # Exact multiset match: every name once, no duplicates.
        if sorted(dim_names) != sorted(names):
            raise ValueError(
                f'Mesh spec {axis_dims!r} must name each of {tuple(names)} exactly once'
            )
    else:
        dims = [int(x) for x in axis_dims.split(',')]
        dim_names = list(names)
        if len(dims) != len(dim_names):
            raise ValueError(
                f'Mesh spec {axis_dims!r} has {len(dims)} sizes but '
                f'{len(dim_names)} axis names {tuple(dim_names)}'
            )

    # Reshaping a dummy array resolves a -1 entry and checks that the sizes
    # multiply to the device count (numpy raises ValueError otherwise).
    mesh_shape = np.arange(jax.device_count()).reshape(dims).shape
    if mesh_axis_splitting:
        physical_mesh = np.array(jax.devices()).reshape(mesh_shape)
    else:
        physical_mesh = mesh_utils.create_device_mesh(mesh_shape)
    return Mesh(physical_mesh, tuple(dim_names))
|
|
|
|
|
def names_in_current_mesh(*names):
    """ Return True when every given axis name exists in the active physical mesh.

    With no names the answer is trivially True.
    """
    current_axes = set(pxla.thread_resources.env.physical_mesh.axis_names)
    return all(name in current_axes for name in names)
|
|
|
|
|
def get_names_from_parition_spec(partition_specs):
    """ Collect the distinct axis names referenced by partition specs.

    Accepts a (possibly nested) sequence of axis entries or a dict whose
    values are such sequences. ``None`` entries are skipped, strings are
    collected, and anything else is recursed into. The result is an
    unordered list.
    """
    if isinstance(partition_specs, dict):
        partition_specs = partition_specs.values()

    collected = set()
    for entry in partition_specs:
        if isinstance(entry, str):
            collected.add(entry)
        elif entry is not None:
            collected.update(get_names_from_parition_spec(entry))
    return list(collected)
|
|
|
|
|
def with_sharding_constraint(x, partition_specs):
    """ Apply a sharding constraint only when the current mesh has every axis
    named in ``partition_specs``; otherwise return ``x`` unchanged.
    """
    required_axes = get_names_from_parition_spec(partition_specs)
    if not names_in_current_mesh(*required_axes):
        return x
    return _with_sharding_constraint(x, partition_specs)
|
|
|
|
|
def wrap_function_with_rng(rng):
    """ Decorator factory that feeds an automatically advanced RNG key to the
    wrapped function.

    On every call the stored key is split. One half is kept for the next call
    and the other is passed as the first positional argument:
    ``function(split_rng, *args, **kwargs)``.

    The key state lives in this factory's closure. Every function decorated
    through the same returned decorator therefore draws from one shared
    sequence.

    Args:
        rng: the initial JAX PRNG key.
    """
    from functools import wraps

    def wrap_function(function):
        # Preserve __name__/__doc__ of the wrapped function for debugging
        # and introspection.
        @wraps(function)
        def wrapped(*args, **kwargs):
            nonlocal rng
            rng, split_rng = jax.random.split(rng)
            return function(split_rng, *args, **kwargs)
        return wrapped
    return wrap_function
|
|
|
|
|
def init_rng(seed):
    """ Seed (or re-seed) the module-level JaxRNG that ``next_rng`` draws from. """
    global jax_utils_rng
    fresh_rng = JaxRNG.from_seed(seed)
    jax_utils_rng = fresh_rng
|
|
|
|
|
def next_rng(*args, **kwargs):
    """ Draw key(s) from the module-level JaxRNG installed by ``init_rng``.

    Arguments are forwarded unchanged to ``JaxRNG.__call__``. ``init_rng``
    must have been called first.
    """
    # Read-only access to the module global, so no ``global`` declaration is needed.
    return jax_utils_rng(*args, **kwargs)
|
|
|
|
|
def get_metrics(metrics, unreplicate=False, stack=False):
    """ Bring metrics back to host memory as NumPy / Python values.

    Args:
        metrics: a dict of scalar metrics, or, with ``stack=True``, a sequence
            of identically structured metric pytrees.
        unreplicate: if True, pass ``metrics`` through
            ``flax.jax_utils.unreplicate`` first (presumably to drop a
            pmap-replicated leading device axis).
        stack: if True, stack the sequence of pytrees leaf-wise with
            ``np.stack``. Otherwise convert each dict value to a Python float.

    Returns:
        A pytree of stacked NumPy arrays, or a dict of floats.
    """
    if unreplicate:
        metrics = flax.jax_utils.unreplicate(metrics)
    metrics = jax.device_get(metrics)
    if stack:
        # jax.tree_map is a deprecated alias that newer JAX releases no longer
        # provide; jax.tree_util.tree_map is the supported spelling.
        return jax.tree_util.tree_map(lambda *args: np.stack(args), *metrics)
    return {key: float(val) for key, val in metrics.items()}
|
|
|
|
|
def mse_loss(val, target, valid=None):
    """ Masked mean squared error.

    Squared errors at positions where ``valid <= 0`` are zeroed. The mean is
    still taken over ALL broadcast elements, so masked positions dilute the
    loss rather than being excluded from the denominator.

    Args:
        val: predictions.
        target: targets, broadcast-compatible with ``val``.
        valid: optional mask; defaults to ones of shape
            ``(*target.shape[:2], 1)``. That default assumes ``target`` has
            at least 3 dims; TODO confirm callers.
    """
    mask = jnp.ones((*target.shape[:2], 1)) if valid is None else valid
    mask = mask.astype(jnp.float32)
    squared_error = jnp.square(val - target)
    masked_error = jnp.where(mask > 0.0, squared_error, 0.0)
    return jnp.mean(masked_error)
|
|
|
|
|
def cross_entropy_loss_and_accuracy(logits, tokens, valid=None):
    """ Masked token-level cross entropy and accuracy.

    For each sequence (last axis of ``tokens``), two quantities are computed
    over valid positions: the summed target log-probability and the count of
    argmax hits. Each is divided by that sequence's valid-token count
    (floored at 1e-10). The results are then averaged over the remaining
    axes.

    Args:
        logits: unnormalized scores; the last axis indexes the vocabulary.
            Cast to float32.
        tokens: integer target ids, shaped like ``logits`` without its last
            axis.
        valid: optional mask shaped like ``tokens``; positions > 0 count.
            Defaults to ones over ``tokens.shape[:2]`` (presumably
            (batch, seq); TODO confirm).

    Returns:
        ``(loss, accuracy)`` as scalar arrays.
    """
    mask = jnp.ones(tokens.shape[:2]) if valid is None else valid
    mask = mask.astype(jnp.float32)
    is_valid = mask > 0.0
    # Flooring the denominator makes fully masked rows contribute 0, not NaN.
    seq_lengths = jnp.maximum(jnp.sum(mask, axis=-1), 1e-10)

    logits = logits.astype(jnp.float32)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    picked = jnp.take_along_axis(log_probs, jnp.expand_dims(tokens, -1), axis=-1)
    target_log_probs = jnp.where(is_valid, picked[..., 0], jnp.array(0.0))
    loss = -jnp.mean(jnp.sum(target_log_probs, axis=-1) / seq_lengths)

    hits = jnp.where(is_valid, jnp.argmax(logits, axis=-1) == tokens, jnp.array(False))
    accuracy = jnp.mean(jnp.sum(hits, axis=-1) / seq_lengths)
    return loss, accuracy
|
|
|
|
|
def global_norm(tree):
    """ Return the global L2 norm of a pytree.

    This is the square root of the sum of squares of every element of every
    leaf. Per-leaf sums of squares are combined with ``ravel_pytree``, which
    promotes them to a common dtype.
    """
    # ``jax.flatten_util`` is a submodule that a bare ``import jax`` may not
    # load, so import it explicitly instead of relying on import order.
    from jax.flatten_util import ravel_pytree

    squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree)
    flattened, _ = ravel_pytree(squared)
    return jnp.sqrt(jnp.sum(flattened))
|
|
|
|
|
def average_metrics(metrics):
    """ Average a sequence of identically structured metric pytrees leaf-wise.

    Each leaf of the result is the ``jnp.mean`` of the corresponding leaves
    stacked along a new leading axis.
    """
    # jax.tree_map is a deprecated alias that newer JAX releases no longer
    # provide; jax.tree_util.tree_map is the supported spelling.
    return jax.tree_util.tree_map(
        lambda *args: jnp.mean(jnp.stack(args)),
        *metrics
    )
|
|
|
|
|
def get_float_dtype_by_name(dtype):
    """ Map a float dtype name or short alias (e.g. 'bf16', 'float32') to the
    jnp dtype. Unknown names raise KeyError.
    """
    aliases = (
        (('bf16', 'bfloat16'), jnp.bfloat16),
        (('fp16', 'float16'), jnp.float16),
        (('fp32', 'float32'), jnp.float32),
        (('fp64', 'float64'), jnp.float64),
    )
    lookup = {alias: jnp_dtype for names, jnp_dtype in aliases for alias in names}
    return lookup[dtype]
|
|
|
|
|
def float_tensor_to_dtype(tensor, dtype):
    """ Cast ``tensor`` to ``dtype`` only if it is a floating-point tensor.

    ``dtype`` may be a jnp dtype or a name accepted by
    ``get_float_dtype_by_name``. ``None`` or ``''`` means "leave unchanged".
    Non-float tensors and objects without a ``dtype`` are returned as-is.
    """
    # Explicit comparisons: ``not dtype`` would misfire on numpy dtype
    # objects, whose truthiness follows their (zero) field count.
    if dtype is None or dtype == '':
        return tensor
    target = get_float_dtype_by_name(dtype) if isinstance(dtype, str) else dtype
    is_float = getattr(tensor, 'dtype', None) in (
        jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64
    )
    return tensor.astype(target) if is_float else tensor
|
|
|
|
|
def float_to_dtype(tree, dtype):
    """ Cast every floating-point leaf of ``tree`` to ``dtype`` (see
    ``float_tensor_to_dtype``); other leaves pass through unchanged.
    """
    def cast_leaf(leaf):
        return float_tensor_to_dtype(leaf, dtype=dtype)
    return jax.tree_util.tree_map(cast_leaf, tree)
|
|
|
|
|
def get_gradient_checkpoint_policy(name):
    """ Look up a ``jax.checkpoint_policies`` entry by name.

    Only the four names listed below are accepted; anything else raises
    KeyError.
    """
    policies = jax.checkpoint_policies
    supported = (
        'everything_saveable',
        'nothing_saveable',
        'checkpoint_dots',
        'checkpoint_dots_with_no_batch_dims',
    )
    by_name = {policy: getattr(policies, policy) for policy in supported}
    return by_name[name]
|
|
|
|
|
def tree_path_to_string(path, sep=None):
    """ Convert a JAX key path into plain strings.

    Known key types contribute their index / key / attribute name; anything
    else is passed through ``str``.

    Returns:
        A tuple of strings when ``sep`` is None, otherwise the parts joined
        by ``sep``.
    """
    attr_by_type = (
        (jax.tree_util.SequenceKey, 'idx'),
        (jax.tree_util.DictKey, 'key'),
        (jax.tree_util.GetAttrKey, 'name'),
        (jax.tree_util.FlattenedIndexKey, 'key'),
    )

    def entry_to_str(entry):
        for key_type, attr in attr_by_type:
            if isinstance(entry, key_type):
                return str(getattr(entry, attr))
        return str(entry)

    parts = tuple(entry_to_str(entry) for entry in path)
    return parts if sep is None else sep.join(parts)
|
|
|
|
|
def flatten_tree(xs, is_leaf=None, sep=None):
    """ Flatten a pytree into a dict keyed by leaf path.

    Keys come from ``tree_path_to_string``: tuples of strings, or
    ``sep``-joined strings when ``sep`` is given. Values are the leaves.
    """
    path_leaf_pairs, _ = jax.tree_util.tree_flatten_with_path(xs, is_leaf=is_leaf)
    return {
        tree_path_to_string(path, sep=sep): leaf
        for path, leaf in path_leaf_pairs
    }
|
|
|
|
|
def named_tree_map(f, tree, *rest, is_leaf=None, sep=None):
    """ Like ``jax.tree_util.tree_map``, but ``f`` also receives each leaf's
    path name (from ``tree_path_to_string``) as its first argument:
    ``f(name, leaf, *matching_leaves_of_rest)``.
    """
    def call_with_name(path, leaf, *other_leaves):
        return f(tree_path_to_string(path, sep=sep), leaf, *other_leaves)

    return jax.tree_util.tree_map_with_path(
        call_with_name, tree, *rest, is_leaf=is_leaf
    )
|
|
|
|
|
def match_partition_rules(rules, params):
    """ Returns a pytree of PartitionSpec according to rules. Supports handling
    Flax TrainState and Optax optimizer state.

    Args:
        rules: sequence of ``(regex, PartitionSpec)`` pairs, tried in order.
            The first pattern that ``re.search``-matches the '/'-joined leaf
            path wins.
        params: pytree whose leaves are arrays or shape-carrying objects.

    Returns:
        A pytree matching ``params`` with a PartitionSpec per leaf. Leaves
        with no dimensions or a single element get ``PS()``.

    Raises:
        ValueError: if no rule matches a non-scalar leaf.
    """
    def get_partition_spec(name, leaf):
        # np.shape reads ``leaf.shape`` when present and also handles leaves
        # without one (e.g. plain Python numbers), treating them as scalars
        # instead of raising AttributeError.
        shape = np.shape(leaf)
        if len(shape) == 0 or np.prod(shape) == 1:
            # Don't partition scalar values.
            return PS()
        for rule, ps in rules:
            if re.search(rule, name) is not None:
                return ps
        raise ValueError(f'Partition rule not found for param: {name}')
    return named_tree_map(get_partition_spec, params, sep='/')
|
|
|
|
|
def get_weight_decay_mask(exclusions):
    """ Build a function mapping a params pytree to a boolean decay mask.

    A leaf is marked False (no weight decay) when any regex in ``exclusions``
    ``re.search``-matches its '/'-joined path, and True otherwise.
    """
    def decay(name, _):
        excluded = any(re.search(rule, name) is not None for rule in exclusions)
        return not excluded

    def weight_decay_mask(params):
        return named_tree_map(decay, params, sep='/')

    return weight_decay_mask
|
|
|
|
|
def tree_apply(fns, tree):
    """ Call each function leaf of ``fns`` on the matching leaf of ``tree``.

    ``fns`` and ``tree`` must share the same structure. The result mirrors
    that structure.
    """
    def apply_leaf(fn, leaf):
        return fn(leaf)
    return jax.tree_util.tree_map(apply_leaf, fns, tree)
|
|