stylegan2-flax-tpu / checkpoint.py
akhaliq's picture
akhaliq HF staff
add files
81170fd
import flax
import dill as pickle
import os
import builtins
from jax._src.lib import xla_client
import tensorflow as tf
# HACK: expose the bfloat16 type as an attribute of `builtins` so that objects
# containing bfloat16 values can be pickled/unpickled.
# NOTE(review): per the linked issue the bfloat16 type reports `builtins` as
# its module, so the pickler presumably resolves the name there -- confirm
# this is still required for the JAX version in use.
# https://github.com/google/jax/issues/8505
builtins.bfloat16 = xla_client.bfloat16
def pickle_dump(obj, filename):
    """
    Serializes an object and writes it to a file.

    The object is serialized BEFORE the target file is opened, so that a
    serialization failure does not leave an empty or truncated file behind
    (opening in "wb" mode would otherwise already have replaced any existing
    file at that path).

    Args:
        obj (object): Object to serialize.
        filename (str): Path of the file to write (opened through tf.io.gfile).
    """
    # Serialize first: if this raises, the destination is left untouched.
    payload = pickle.dumps(obj)
    with tf.io.gfile.GFile(filename, "wb") as f:
        f.write(payload)
def pickle_load(filename):
    """
    Reads a file and deserializes the object stored in it.

    Note: unpickling can execute arbitrary code, so only load files that
    come from a trusted source.

    Args:
        filename (str): Path of the pickled file (opened through tf.io.gfile).

    Returns:
        (object): The deserialized object.
    """
    with tf.io.gfile.GFile(filename, 'rb') as handle:
        return pickle.loads(handle.read())
def save_checkpoint(ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step, epoch, fid_score=None, keep=2):
    """
    Saves a checkpoint and prunes old ones.

    The checkpoint is written to '<ckpt_dir>/ckpt_<step>.pickle'. Afterwards,
    if the directory holds more than `keep` '*.pickle' files, the oldest ones
    (by modification time) are removed until only `keep` remain.

    Args:
        ckpt_dir (str): Path to the directory, where checkpoints are saved.
        state_G (train_state.TrainState): Generator state (unreplicated before saving).
        state_D (train_state.TrainState): Discriminator state (unreplicated before saving).
        params_ema_G (frozen_dict.FrozenDict): Parameters of the ema generator (stored as passed).
        pl_mean (array): Moving average of the path length (generator regularization).
        config (argparse.Namespace): Configuration.
        step (int): Current step.
        epoch (int): Current epoch.
        fid_score (float): FID score corresponding to the checkpoint.
        keep (int): Number of checkpoints to keep.
    """
    state_dict = {'state_G': flax.jax_utils.unreplicate(state_G),
                  'state_D': flax.jax_utils.unreplicate(state_D),
                  'params_ema_G': params_ema_G,
                  'pl_mean': flax.jax_utils.unreplicate(pl_mean),
                  'config': config,
                  'fid_score': fid_score,
                  'step': step,
                  'epoch': epoch}

    pickle_dump(state_dict, os.path.join(ckpt_dir, f'ckpt_{step}.pickle'))

    # Prune every checkpoint beyond `keep`, not just a single one, so that a
    # directory that already contains extra files converges to `keep` entries.
    ckpts = tf.io.gfile.glob(os.path.join(ckpt_dir, '*.pickle'))
    num_excess = len(ckpts) - keep
    if num_excess > 0:
        modified_times = {ckpt: tf.io.gfile.stat(ckpt).mtime_nsec for ckpt in ckpts}
        # Oldest first; the newest `keep` files are at the end and survive.
        oldest_first = sorted(modified_times, key=modified_times.get)
        for stale_ckpt in oldest_first[:num_excess]:
            tf.io.gfile.remove(stale_ckpt)
def load_checkpoint(filename):
    """
    Loads a checkpoint from disk.

    Thin wrapper around `pickle_load`.

    Args:
        filename (str): Path to the checkpoint file.

    Returns:
        (dict): Checkpoint, as deserialized from the file.
    """
    return pickle_load(filename)
def get_latest_checkpoint(ckpt_dir):
    """
    Finds the most recently modified checkpoint in a directory.

    All '*.pickle' files directly inside `ckpt_dir` are considered and
    ranked by their modification time.

    Args:
        ckpt_dir (str): Path to the directory, where checkpoints are saved.

    Returns:
        (str): Path to the latest checkpoint, or None if there is none.
    """
    candidates = tf.io.gfile.glob(os.path.join(ckpt_dir, '*.pickle'))
    if not candidates:
        return None

    mtime_by_path = {path: tf.io.gfile.stat(path).mtime_nsec for path in candidates}
    # Ascending by modification time; the newest file ends up last.
    ranked = sorted(mtime_by_path, key=mtime_by_path.get)
    return ranked[-1]