add files

- checkpoint.py +96 -0
- data_pipeline.py +85 -0
- dataset_utils/crop_image_borders.py +57 -0
- dataset_utils/images_to_tfrecords.py +145 -0
- fid/__init__.py +1 -0
- fid/core.py +150 -0
- fid/inception.py +655 -0
- fid/utils.py +59 -0
- generate_images.py +61 -0
- main.py +102 -0
- requirements.txt +14 -0
- stylegan2/__init__.py +5 -0
- stylegan2/discriminator.py +451 -0
- stylegan2/generator.py +713 -0
- stylegan2/ops.py +674 -0
- stylegan2/utils.py +37 -0
- training.py +382 -0
- training_steps.py +219 -0
- training_utils.py +174 -0
checkpoint.py
ADDED
@@ -0,0 +1,96 @@
import flax
import dill as pickle
import os
import builtins
from jax._src.lib import xla_client
import tensorflow as tf


# Hack: this is the module reported by this object.
# https://github.com/google/jax/issues/8505
builtins.bfloat16 = xla_client.bfloat16


def pickle_dump(obj, filename):
    """Wrapper to dump an object to a file."""
    with tf.io.gfile.GFile(filename, 'wb') as f:
        f.write(pickle.dumps(obj))


def pickle_load(filename):
    """Wrapper to load an object from a file."""
    with tf.io.gfile.GFile(filename, 'rb') as f:
        pickled = pickle.loads(f.read())
    return pickled


def save_checkpoint(ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step, epoch, fid_score=None, keep=2):
    """
    Saves a checkpoint.

    Args:
        ckpt_dir (str): Path to the directory where checkpoints are saved.
        state_G (train_state.TrainState): Generator state.
        state_D (train_state.TrainState): Discriminator state.
        params_ema_G (frozen_dict.FrozenDict): Parameters of the EMA generator.
        pl_mean (array): Moving average of the path length (generator regularization).
        config (argparse.Namespace): Configuration.
        step (int): Current step.
        epoch (int): Current epoch.
        fid_score (float): FID score corresponding to the checkpoint.
        keep (int): Number of checkpoints to keep.
    """
    state_dict = {'state_G': flax.jax_utils.unreplicate(state_G),
                  'state_D': flax.jax_utils.unreplicate(state_D),
                  'params_ema_G': params_ema_G,
                  'pl_mean': flax.jax_utils.unreplicate(pl_mean),
                  'config': config,
                  'fid_score': fid_score,
                  'step': step,
                  'epoch': epoch}

    pickle_dump(state_dict, os.path.join(ckpt_dir, f'ckpt_{step}.pickle'))
    ckpts = tf.io.gfile.glob(os.path.join(ckpt_dir, '*.pickle'))
    if len(ckpts) > keep:
        # Delete the oldest checkpoint by modification time.
        modified_times = {}
        for ckpt in ckpts:
            stats = tf.io.gfile.stat(ckpt)
            modified_times[ckpt] = stats.mtime_nsec
        oldest_ckpt = sorted(modified_times, key=modified_times.get)[0]
        tf.io.gfile.remove(oldest_ckpt)


def load_checkpoint(filename):
    """
    Loads a checkpoint.

    Args:
        filename (str): Path to the checkpoint file.

    Returns:
        (dict): Checkpoint.
    """
    state_dict = pickle_load(filename)
    return state_dict


def get_latest_checkpoint(ckpt_dir):
    """
    Returns the path of the latest checkpoint.

    Args:
        ckpt_dir (str): Path to the directory where checkpoints are saved.

    Returns:
        (str): Path to latest checkpoint (if it exists).
    """
    ckpts = tf.io.gfile.glob(os.path.join(ckpt_dir, '*.pickle'))
    if len(ckpts) == 0:
        return None

    modified_times = {}
    for ckpt in ckpts:
        stats = tf.io.gfile.stat(ckpt)
        modified_times[ckpt] = stats.mtime_nsec
    latest_ckpt = sorted(modified_times, key=modified_times.get)[-1]
    return latest_ckpt
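
A minimal usage sketch (not part of the commit) showing how these helpers would typically be wired into a training loop. The checkpoint directory is a placeholder, and state_G, state_D, params_ema_G, and pl_mean are assumed to be the replicated training objects that save_checkpoint unreplicates internally:

import checkpoint

ckpt_dir = 'checkpoints/run0'  # any path tf.io.gfile understands, local or gs://

# Resume if a checkpoint exists.
latest = checkpoint.get_latest_checkpoint(ckpt_dir)
if latest is not None:
    ckpt = checkpoint.load_checkpoint(latest)  # plain dict, keys as in save_checkpoint
    step, epoch = ckpt['step'], ckpt['epoch']
    params_ema_G = ckpt['params_ema_G']

# Inside the training loop, e.g. every N steps:
# checkpoint.save_checkpoint(ckpt_dir, state_G, state_D, params_ema_G, pl_mean,
#                            config, step, epoch, fid_score=fid, keep=2)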
data_pipeline.py
ADDED
@@ -0,0 +1,85 @@
import tensorflow as tf
import tensorflow_datasets as tfds
import jax
import flax
import numpy as np
from PIL import Image
import os
from typing import Sequence
from tqdm import tqdm
import json
import logging

logger = logging.getLogger(__name__)


def prefetch(dataset, n_prefetch):
    # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py
    ds_iter = iter(dataset)
    ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x),
                  ds_iter)
    if n_prefetch:
        ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
    return ds_iter


def get_data(data_dir, img_size, img_channels, num_classes, num_local_devices, batch_size, shuffle_buffer=1000):
    """
    Loads a TFRecords dataset and prepares it for multi-device training.

    Args:
        data_dir (str): Root directory of the dataset.
        img_size (int): Image size for training.
        img_channels (int): Number of image channels.
        num_classes (int): Number of classes, 0 for no classes.
        num_local_devices (int): Number of devices.
        batch_size (int): Batch size (per device).
        shuffle_buffer (int): Buffer used for shuffling the dataset.

    Returns:
        (tf.data.Dataset, dict): Dataset and dataset info.
    """

    def pre_process(serialized_example):
        feature = {'height': tf.io.FixedLenFeature([], tf.int64),
                   'width': tf.io.FixedLenFeature([], tf.int64),
                   'channels': tf.io.FixedLenFeature([], tf.int64),
                   'image': tf.io.FixedLenFeature([], tf.string),
                   'label': tf.io.FixedLenFeature([], tf.int64)}
        example = tf.io.parse_single_example(serialized_example, feature)

        height = tf.cast(example['height'], dtype=tf.int64)
        width = tf.cast(example['width'], dtype=tf.int64)
        channels = tf.cast(example['channels'], dtype=tf.int64)

        image = tf.io.decode_raw(example['image'], out_type=tf.uint8)
        image = tf.reshape(image, shape=[height, width, channels])

        image = tf.cast(image, dtype='float32')
        image = tf.image.resize(image, size=[img_size, img_size], method='bicubic', antialias=True)
        image = tf.image.random_flip_left_right(image)

        image = (image - 127.5) / 127.5

        label = tf.one_hot(example['label'], num_classes)
        return {'image': image, 'label': label}

    def shard(data):
        # Reshape images from [num_devices * batch_size, H, W, C] to [num_devices, batch_size, H, W, C]
        # because the first dimension will be mapped across devices using jax.pmap.
        data['image'] = tf.reshape(data['image'], [num_local_devices, -1, img_size, img_size, img_channels])
        data['label'] = tf.reshape(data['label'], [num_local_devices, -1, num_classes])
        return data

    logger.info('Loading TFRecord...')
    with tf.io.gfile.GFile(os.path.join(data_dir, 'dataset_info.json'), 'r') as fin:
        dataset_info = json.load(fin)

    ds = tf.data.TFRecordDataset(filenames=os.path.join(data_dir, 'dataset.tfrecords'))
    ds = ds.shard(jax.process_count(), jax.process_index())
    ds = ds.shuffle(min(dataset_info['num_examples'], shuffle_buffer))
    ds = ds.map(pre_process, tf.data.AUTOTUNE)
    ds = ds.batch(batch_size * num_local_devices, drop_remainder=True)  # uses per-worker batch size
    ds = ds.map(shard, tf.data.AUTOTUNE)
    ds = ds.prefetch(1)  # prefetches the next batch
    return ds, dataset_info
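
A sketch of how this pipeline would be driven from a training script, assuming the TFRecords file was produced by dataset_utils/images_to_tfrecords.py below; the data directory and sizes are placeholder values:

import jax
from data_pipeline import get_data, prefetch

num_local_devices = jax.local_device_count()
ds, dataset_info = get_data(data_dir='data/my_dataset',  # hypothetical path
                            img_size=256, img_channels=3, num_classes=0,
                            num_local_devices=num_local_devices, batch_size=8)

# Each batch now carries images of shape [num_local_devices, 8, 256, 256, 3],
# staged on the devices and ready for a jax.pmap-ed training step.
for batch in prefetch(ds, n_prefetch=2):
    images, labels = batch['image'], batch['label']
    # ... feed into a pmapped train step here ...
    break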
dataset_utils/crop_image_borders.py
ADDED
@@ -0,0 +1,57 @@
import numpy as np
from PIL import Image
import os
from tqdm import tqdm
import argparse
import logging

logger = logging.getLogger(__name__)

"""
Crops the black borders around images.
"""


def crop_border(x, constant=0.0):
    top = 0
    while True:
        if np.sum(x[top] != constant) != 0.0:
            break
        top += 1
    bottom = x.shape[0] - 1
    while True:
        if np.sum(x[bottom] != constant) != 0.0:
            bottom += 1  # make the bound exclusive so the last content row is kept
            break
        bottom -= 1
    left = 0
    while True:
        if np.sum(x[:, left] != constant) != 0.0:
            break
        left += 1
    right = x.shape[1] - 1
    while True:
        if np.sum(x[:, right] != constant) != 0.0:
            right += 1  # make the bound exclusive so the last content column is kept
            break
        right -= 1
    return x[top:bottom, left:right]


def crop_images(path, constant_value):
    logger.info('Crop image borders...')
    for f in tqdm(os.listdir(path)):
        img = Image.open(os.path.join(path, f))
        img = crop_border(np.array(img), constant=constant_value)
        img = Image.fromarray(img)
        img.save(os.path.join(path, f))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', type=str, help='Path to the image directory.')
    parser.add_argument('--constant_value', type=float, default=0.0, help='Value of the border that should be cropped.')

    args = parser.parse_args()

    crop_images(args.image_dir, args.constant_value)
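
To make the exclusive-bound arithmetic in crop_border concrete, here is a tiny self-contained check on a synthetic image with a one-pixel black border (the array values are made up for illustration):

import numpy as np
from dataset_utils.crop_image_borders import crop_border

img = np.zeros((6, 8, 3), dtype=np.uint8)  # all-black canvas
img[1:5, 1:7] = 255                        # white interior, 1-px black border

cropped = crop_border(img, constant=0.0)
assert cropped.shape == (4, 6, 3)  # the border rows and columns are gone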
dataset_utils/images_to_tfrecords.py
ADDED
@@ -0,0 +1,145 @@
import tensorflow as tf
import numpy as np
from PIL import Image
from typing import Sequence
from tqdm import tqdm
import argparse
import json
import os
import logging

logger = logging.getLogger(__name__)


def images_to_tfrecords(image_dir, data_dir, has_labels):
    """
    Converts a folder of images to a TFRecord file.

    The image directory should have one of the following structures:

    If has_labels = False, image_dir should look like this:

        path/to/image_dir/
            0.jpg
            1.jpg
            2.jpg
            4.jpg
            ...

    If has_labels = True, image_dir should look like this:

        path/to/image_dir/
            label0/
                0.jpg
                1.jpg
                ...
            label1/
                a.jpg
                b.jpg
                c.jpg
                ...
            ...

    The labels will be label0 -> 0, label1 -> 1.

    Args:
        image_dir (str): Path to images.
        data_dir (str): Path where the TFRecords dataset is stored.
        has_labels (bool): If True, 'image_dir' contains label directories.

    Returns:
        (dict): Dataset info.
    """

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    os.makedirs(data_dir, exist_ok=True)
    writer = tf.io.TFRecordWriter(os.path.join(data_dir, 'dataset.tfrecords'))

    num_examples = 0
    num_classes = 0

    if has_labels:
        for label_dir in os.listdir(image_dir):
            if not os.path.isdir(os.path.join(image_dir, label_dir)):
                logger.warning('The image directory should contain one directory for each label.')
                logger.warning('These label directories should contain the image files.')
                if os.path.exists(os.path.join(data_dir, 'dataset.tfrecords')):
                    os.remove(os.path.join(data_dir, 'dataset.tfrecords'))
                return

            for img_file in tqdm(os.listdir(os.path.join(image_dir, label_dir))):
                file_format = img_file[img_file.rfind('.') + 1:]
                if file_format not in ['png', 'jpg', 'jpeg']:
                    continue

                #img = Image.open(os.path.join(image_dir, label_dir, img_file)).resize(img_size)
                img = Image.open(os.path.join(image_dir, label_dir, img_file))
                img = np.array(img, dtype=np.uint8)

                height = img.shape[0]
                width = img.shape[1]
                channels = img.shape[2]

                img_encoded = img.tobytes()

                example = tf.train.Example(features=tf.train.Features(feature={
                    'height': _int64_feature(height),
                    'width': _int64_feature(width),
                    'channels': _int64_feature(channels),
                    'image': _bytes_feature(img_encoded),
                    'label': _int64_feature(num_classes)}))

                writer.write(example.SerializeToString())
                num_examples += 1

            num_classes += 1
    else:
        for img_file in tqdm(os.listdir(os.path.join(image_dir))):
            file_format = img_file[img_file.rfind('.') + 1:]
            if file_format not in ['png', 'jpg', 'jpeg']:
                continue

            #img = Image.open(os.path.join(image_dir, label_dir, img_file)).resize(img_size)
            img = Image.open(os.path.join(image_dir, img_file))
            img = np.array(img, dtype=np.uint8)

            height = img.shape[0]
            width = img.shape[1]
            channels = img.shape[2]

            img_encoded = img.tobytes()

            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(height),
                'width': _int64_feature(width),
                'channels': _int64_feature(channels),
                'image': _bytes_feature(img_encoded),
                'label': _int64_feature(num_classes)}))  # dummy label

            writer.write(example.SerializeToString())
            num_examples += 1

    writer.close()

    dataset_info = {'num_examples': num_examples, 'num_classes': num_classes}
    with open(os.path.join(data_dir, 'dataset_info.json'), 'w') as fout:
        json.dump(dataset_info, fout)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', type=str, help='Path to the image directory.')
    parser.add_argument('--data_dir', type=str, help='Path where the TFRecords dataset is stored.')
    parser.add_argument('--has_labels', action='store_true', help='If True, image_dir contains label directories.')

    args = parser.parse_args()

    images_to_tfrecords(args.image_dir, args.data_dir, args.has_labels)
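
For reference, a hedged sketch of driving the conversion from Python instead of the CLI and inspecting the resulting metadata; the directory names are hypothetical:

import json
import os
from dataset_utils.images_to_tfrecords import images_to_tfrecords

# A flat, unlabeled image folder -> data/my_dataset/{dataset.tfrecords, dataset_info.json}
images_to_tfrecords(image_dir='images/my_dataset', data_dir='data/my_dataset', has_labels=False)

with open(os.path.join('data/my_dataset', 'dataset_info.json')) as f:
    info = json.load(f)
print(info['num_examples'], info['num_classes'])  # num_classes stays 0 without labels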
fid/__init__.py
ADDED
@@ -0,0 +1 @@
from .core import FID
fid/core.py
ADDED
@@ -0,0 +1,150 @@
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import numpy as np
import os
import functools
import argparse
import scipy
from tqdm import tqdm
import logging

from . import inception
from . import utils

logger = logging.getLogger(__name__)


class FID:

    def __init__(self, generator, dataset, config, use_cache=True, truncation_psi=1.0):
        """
        Evaluates the FID score for a given generator and a given dataset.
        Implementation mostly taken from https://github.com/matthias-wright/jax-fid

        Reference: https://arxiv.org/abs/1706.08500

        Args:
            generator (nn.Module): Generator network.
            dataset (tf.data.Dataset): Dataset containing the real images.
            config (argparse.Namespace): Configuration.
            use_cache (bool): If True, only compute the activation stats once for the real images and store them.
            truncation_psi (float): Controls truncation (trading off variation for quality). If 1, truncation is disabled.
        """
        self.num_images = config.num_fid_images
        self.batch_size = config.batch_size
        self.c_dim = config.c_dim
        self.z_dim = config.z_dim
        self.dataset = dataset
        self.num_devices = jax.device_count()
        self.num_local_devices = jax.local_device_count()
        self.use_cache = use_cache

        if self.use_cache:
            self.cache = {}

        rng = jax.random.PRNGKey(0)
        inception_net = inception.InceptionV3(pretrained=True)
        self.inception_params = inception_net.init(rng, jnp.ones((1, config.resolution, config.resolution, 3)))
        self.inception_params = flax.jax_utils.replicate(self.inception_params)
        #self.inception = jax.jit(functools.partial(model.apply, train=False))
        self.inception_apply = jax.pmap(functools.partial(inception_net.apply, train=False), axis_name='batch')

        self.generator_apply = jax.pmap(functools.partial(generator.apply, truncation_psi=truncation_psi, train=False, noise_mode='const'), axis_name='batch')

    def compute_fid(self, generator_params, seed_offset=0):
        generator_params = flax.jax_utils.replicate(generator_params)
        mu_real, sigma_real = self.compute_stats_for_dataset()
        mu_fake, sigma_fake = self.compute_stats_for_generator(generator_params, seed_offset)
        fid_score = self.compute_frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake, eps=1e-6)
        return fid_score

    def compute_frechet_distance(self, mu1, mu2, sigma1, sigma2, eps=1e-6):
        # Taken from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)
        sigma1 = np.atleast_1d(sigma1)
        sigma2 = np.atleast_1d(sigma2)

        assert mu1.shape == mu2.shape
        assert sigma1.shape == sigma2.shape

        diff = mu1 - mu2

        covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            msg = ('fid calculation produces singular product; '
                   'adding %s to diagonal of cov estimates') % eps
            logger.info(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError('Imaginary component {}'.format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)
        return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)

    def compute_stats_for_dataset(self):
        if self.use_cache and 'mu' in self.cache and 'sigma' in self.cache:
            logger.info('Use cached statistics for dataset...')
            return self.cache['mu'], self.cache['sigma']

        print()
        logger.info('Compute statistics for dataset...')
        image_count = 0

        activations = []
        for batch in utils.prefetch(self.dataset, n_prefetch=2):
            act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(batch['image']))
            act = jnp.reshape(act, (self.num_local_devices * self.batch_size, -1))
            activations.append(act)

            image_count += self.num_local_devices * self.batch_size
            if image_count >= self.num_images:
                break

        activations = jnp.concatenate(activations, axis=0)
        activations = activations[:self.num_images]
        mu = np.mean(activations, axis=0)
        sigma = np.cov(activations, rowvar=False)
        if self.use_cache:
            # Only write the cache when caching is enabled; self.cache
            # does not exist otherwise.
            self.cache['mu'] = mu
            self.cache['sigma'] = sigma
        return mu, sigma

    def compute_stats_for_generator(self, generator_params, seed_offset):
        print()
        logger.info('Compute statistics for generator...')
        num_batches = int(np.ceil(self.num_images / (self.batch_size * self.num_local_devices)))

        activations = []

        for i in range(num_batches):
            rng = jax.random.PRNGKey(seed_offset + i)
            z_latent = jax.random.normal(rng, shape=(self.num_local_devices, self.batch_size, self.z_dim))

            labels = None
            if self.c_dim > 0:
                labels = jax.random.randint(rng, shape=(self.num_local_devices * self.batch_size,), minval=0, maxval=self.c_dim)
                labels = jax.nn.one_hot(labels, num_classes=self.c_dim)
                labels = jnp.reshape(labels, (self.num_local_devices, self.batch_size, self.c_dim))

            image = self.generator_apply(generator_params, jax.lax.stop_gradient(z_latent), labels)
            image = (image - jnp.min(image)) / (jnp.max(image) - jnp.min(image))

            image = 2 * image - 1
            act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(image))
            act = jnp.reshape(act, (self.num_local_devices * self.batch_size, -1))
            activations.append(act)

        activations = jnp.concatenate(activations, axis=0)
        activations = activations[:self.num_images]
        mu = np.mean(activations, axis=0)
        sigma = np.cov(activations, rowvar=False)
        return mu, sigma
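
A sketch of the intended call pattern, assuming a generator whose apply accepts (params, z, labels) plus the truncation_psi/train/noise_mode keywords bound above, and a config namespace with the fields read in __init__; every name below is a placeholder:

import argparse
from fid import FID

config = argparse.Namespace(num_fid_images=10000, batch_size=8,
                            c_dim=0, z_dim=512, resolution=256)

# `generator` is the Flax generator module and `ds` the dataset returned by
# data_pipeline.get_data; both are assumed to exist at this point.
fid_metric = FID(generator, ds, config, use_cache=True)

# `params_ema_G` are the unreplicated EMA generator parameters; compute_fid
# replicates them before the pmapped apply. With use_cache=True the dataset
# statistics are computed once and reused for later evaluations.
score = fid_metric.compute_fid(params_ema_G, seed_offset=0)
print(f'FID: {score:.2f}')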
fid/inception.py
ADDED
|
@@ -0,0 +1,655 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import jax
|
| 2 |
+
from jax import lax
|
| 3 |
+
from jax.nn import initializers
|
| 4 |
+
import jax.numpy as jnp
|
| 5 |
+
import flax
|
| 6 |
+
from flax.linen.module import merge_param
|
| 7 |
+
import flax.linen as nn
|
| 8 |
+
from typing import Callable, Iterable, Optional, Tuple, Union, Any
|
| 9 |
+
import functools
|
| 10 |
+
import pickle
|
| 11 |
+
from . import utils
|
| 12 |
+
|
| 13 |
+
PRNGKey = Any
|
| 14 |
+
Array = Any
|
| 15 |
+
Shape = Tuple[int]
|
| 16 |
+
Dtype = Any
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class InceptionV3(nn.Module):
|
| 20 |
+
"""
|
| 21 |
+
InceptionV3 network.
|
| 22 |
+
Reference: https://arxiv.org/abs/1512.00567
|
| 23 |
+
Ported mostly from: https://github.com/pytorch/vision/blob/master/torchvision/models/inception.py
|
| 24 |
+
|
| 25 |
+
Attributes:
|
| 26 |
+
include_head (bool): If True, include classifier head.
|
| 27 |
+
num_classes (int): Number of classes.
|
| 28 |
+
pretrained (bool): If True, use pretrained weights.
|
| 29 |
+
transform_input (bool): If True, preprocesses the input according to the method with which it
|
| 30 |
+
was trained on ImageNet.
|
| 31 |
+
aux_logits (bool): If True, add an auxiliary branch that can improve training.
|
| 32 |
+
dtype (str): Data type.
|
| 33 |
+
"""
|
| 34 |
+
include_head: bool=False
|
| 35 |
+
num_classes: int=1000
|
| 36 |
+
pretrained: bool=False
|
| 37 |
+
transform_input: bool=False
|
| 38 |
+
aux_logits: bool=False
|
| 39 |
+
ckpt_path: str='https://www.dropbox.com/s/0zo4pd6cfwgzem7/inception_v3_weights_fid.pickle?dl=1'
|
| 40 |
+
dtype: str='float32'
|
| 41 |
+
|
| 42 |
+
def setup(self):
|
| 43 |
+
if self.pretrained:
|
| 44 |
+
ckpt_file = utils.download(self.ckpt_path)
|
| 45 |
+
self.params_dict = pickle.load(open(ckpt_file, 'rb'))
|
| 46 |
+
self.num_classes_ = 1000
|
| 47 |
+
else:
|
| 48 |
+
self.params_dict = None
|
| 49 |
+
self.num_classes_ = self.num_classes
|
| 50 |
+
|
| 51 |
+
@nn.compact
|
| 52 |
+
def __call__(self, x, train=True, rng=jax.random.PRNGKey(0)):
|
| 53 |
+
"""
|
| 54 |
+
Args:
|
| 55 |
+
x (tensor): Input image, shape [B, H, W, C].
|
| 56 |
+
train (bool): If True, training mode.
|
| 57 |
+
rng (jax.random.PRNGKey): Random seed.
|
| 58 |
+
"""
|
| 59 |
+
x = self._transform_input(x)
|
| 60 |
+
x = BasicConv2d(out_channels=32,
|
| 61 |
+
kernel_size=(3, 3),
|
| 62 |
+
strides=(2, 2),
|
| 63 |
+
params_dict=utils.get(self.params_dict, 'Conv2d_1a_3x3'),
|
| 64 |
+
dtype=self.dtype)(x, train)
|
| 65 |
+
x = BasicConv2d(out_channels=32,
|
| 66 |
+
kernel_size=(3, 3),
|
| 67 |
+
params_dict=utils.get(self.params_dict, 'Conv2d_2a_3x3'),
|
| 68 |
+
dtype=self.dtype)(x, train)
|
| 69 |
+
x = BasicConv2d(out_channels=64,
|
| 70 |
+
kernel_size=(3, 3),
|
| 71 |
+
padding=((1, 1), (1, 1)),
|
| 72 |
+
params_dict=utils.get(self.params_dict, 'Conv2d_2b_3x3'),
|
| 73 |
+
dtype=self.dtype)(x, train)
|
| 74 |
+
x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
|
| 75 |
+
x = BasicConv2d(out_channels=80,
|
| 76 |
+
kernel_size=(1, 1),
|
| 77 |
+
params_dict=utils.get(self.params_dict, 'Conv2d_3b_1x1'),
|
| 78 |
+
dtype=self.dtype)(x, train)
|
| 79 |
+
x = BasicConv2d(out_channels=192,
|
| 80 |
+
kernel_size=(3, 3),
|
| 81 |
+
params_dict=utils.get(self.params_dict, 'Conv2d_4a_3x3'),
|
| 82 |
+
dtype=self.dtype)(x, train)
|
| 83 |
+
x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
|
| 84 |
+
x = InceptionA(pool_features=32,
|
| 85 |
+
params_dict=utils.get(self.params_dict, 'Mixed_5b'),
|
| 86 |
+
dtype=self.dtype)(x, train)
|
| 87 |
+
x = InceptionA(pool_features=64,
|
| 88 |
+
params_dict=utils.get(self.params_dict, 'Mixed_5c'),
|
| 89 |
+
dtype=self.dtype)(x, train)
|
| 90 |
+
x = InceptionA(pool_features=64,
|
| 91 |
+
params_dict=utils.get(self.params_dict, 'Mixed_5d'),
|
| 92 |
+
dtype=self.dtype)(x, train)
|
| 93 |
+
x = InceptionB(params_dict=utils.get(self.params_dict, 'Mixed_6a'),
|
| 94 |
+
dtype=self.dtype)(x, train)
|
| 95 |
+
x = InceptionC(channels_7x7=128,
|
| 96 |
+
params_dict=utils.get(self.params_dict, 'Mixed_6b'),
|
| 97 |
+
dtype=self.dtype)(x, train)
|
| 98 |
+
x = InceptionC(channels_7x7=160,
|
| 99 |
+
params_dict=utils.get(self.params_dict, 'Mixed_6c'),
|
| 100 |
+
dtype=self.dtype)(x, train)
|
| 101 |
+
x = InceptionC(channels_7x7=160,
|
| 102 |
+
params_dict=utils.get(self.params_dict, 'Mixed_6d'),
|
| 103 |
+
dtype=self.dtype)(x, train)
|
| 104 |
+
x = InceptionC(channels_7x7=192,
|
| 105 |
+
params_dict=utils.get(self.params_dict, 'Mixed_6e'),
|
| 106 |
+
dtype=self.dtype)(x, train)
|
| 107 |
+
aux = None
|
| 108 |
+
if self.aux_logits and train:
|
| 109 |
+
aux = InceptionAux(num_classes=self.num_classes_,
|
| 110 |
+
params_dict=utils.get(self.params_dict, 'AuxLogits'),
|
| 111 |
+
dtype=self.dtype)(x, train)
|
| 112 |
+
x = InceptionD(params_dict=utils.get(self.params_dict, 'Mixed_7a'),
|
| 113 |
+
dtype=self.dtype)(x, train)
|
| 114 |
+
x = InceptionE(avg_pool, params_dict=utils.get(self.params_dict, 'Mixed_7b'),
|
| 115 |
+
dtype=self.dtype)(x, train)
|
| 116 |
+
# Following the implementation by @mseitzer, we use max pooling instead
|
| 117 |
+
# of average pooling here.
|
| 118 |
+
# See: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py#L320
|
| 119 |
+
x = InceptionE(nn.max_pool, params_dict=utils.get(self.params_dict, 'Mixed_7c'),
|
| 120 |
+
dtype=self.dtype)(x, train)
|
| 121 |
+
x = jnp.mean(x, axis=(1, 2), keepdims=True)
|
| 122 |
+
if not self.include_head:
|
| 123 |
+
return x
|
| 124 |
+
x = nn.Dropout(rate=0.5)(x, deterministic=not train, rng=rng)
|
| 125 |
+
x = jnp.reshape(x, newshape=(x.shape[0], -1))
|
| 126 |
+
x = Dense(features=self.num_classes_,
|
| 127 |
+
params_dict=utils.get(self.params_dict, 'fc'),
|
| 128 |
+
dtype=self.dtype)(x)
|
| 129 |
+
if self.aux_logits:
|
| 130 |
+
return x, aux
|
| 131 |
+
return x
|
| 132 |
+
|
| 133 |
+
def _transform_input(self, x):
|
| 134 |
+
if self.transform_input:
|
| 135 |
+
x_ch0 = jnp.expand_dims(x[..., 0], axis=-1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
|
| 136 |
+
x_ch1 = jnp.expand_dims(x[..., 1], axis=-1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
|
| 137 |
+
x_ch2 = jnp.expand_dims(x[..., 2], axis=-1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
|
| 138 |
+
x = jnp.concatenate((x_ch0, x_ch1, x_ch2), axis=-1)
|
| 139 |
+
return x
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class Dense(nn.Module):
|
| 143 |
+
features: int
|
| 144 |
+
kernel_init: functools.partial=nn.initializers.lecun_normal()
|
| 145 |
+
bias_init: functools.partial=nn.initializers.zeros
|
| 146 |
+
params_dict: dict=None
|
| 147 |
+
dtype: str='float32'
|
| 148 |
+
|
| 149 |
+
@nn.compact
|
| 150 |
+
def __call__(self, x):
|
| 151 |
+
x = nn.Dense(features=self.features,
|
| 152 |
+
kernel_init=self.kernel_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['kernel']),
|
| 153 |
+
bias_init=self.bias_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['bias']))(x)
|
| 154 |
+
return x
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class BasicConv2d(nn.Module):
|
| 158 |
+
out_channels: int
|
| 159 |
+
kernel_size: Union[int, Iterable[int]]=(3, 3)
|
| 160 |
+
strides: Optional[Iterable[int]]=(1, 1)
|
| 161 |
+
padding: Union[str, Iterable[Tuple[int, int]]]='valid'
|
| 162 |
+
use_bias: bool=False
|
| 163 |
+
kernel_init: functools.partial=nn.initializers.lecun_normal()
|
| 164 |
+
bias_init: functools.partial=nn.initializers.zeros
|
| 165 |
+
params_dict: dict=None
|
| 166 |
+
dtype: str='float32'
|
| 167 |
+
|
| 168 |
+
@nn.compact
|
| 169 |
+
def __call__(self, x, train=True):
|
| 170 |
+
x = nn.Conv(features=self.out_channels,
|
| 171 |
+
kernel_size=self.kernel_size,
|
| 172 |
+
strides=self.strides,
|
| 173 |
+
padding=self.padding,
|
| 174 |
+
use_bias=self.use_bias,
|
| 175 |
+
kernel_init=self.kernel_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['conv']['kernel']),
|
| 176 |
+
bias_init=self.bias_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['conv']['bias']),
|
| 177 |
+
dtype=self.dtype)(x)
|
| 178 |
+
if self.params_dict is None:
|
| 179 |
+
x = BatchNorm(epsilon=0.001,
|
| 180 |
+
momentum=0.1,
|
| 181 |
+
use_running_average=not train,
|
| 182 |
+
dtype=self.dtype)(x)
|
| 183 |
+
else:
|
| 184 |
+
x = BatchNorm(epsilon=0.001,
|
| 185 |
+
momentum=0.1,
|
| 186 |
+
bias_init=lambda *_ : jnp.array(self.params_dict['bn']['bias']),
|
| 187 |
+
scale_init=lambda *_ : jnp.array(self.params_dict['bn']['scale']),
|
| 188 |
+
mean_init=lambda *_ : jnp.array(self.params_dict['bn']['mean']),
|
| 189 |
+
var_init=lambda *_ : jnp.array(self.params_dict['bn']['var']),
|
| 190 |
+
use_running_average=not train,
|
| 191 |
+
dtype=self.dtype)(x)
|
| 192 |
+
x = jax.nn.relu(x)
|
| 193 |
+
return x
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class InceptionA(nn.Module):
|
| 197 |
+
pool_features: int
|
| 198 |
+
params_dict: dict=None
|
| 199 |
+
dtype: str='float32'
|
| 200 |
+
|
| 201 |
+
@nn.compact
|
| 202 |
+
def __call__(self, x, train=True):
|
| 203 |
+
branch1x1 = BasicConv2d(out_channels=64,
|
| 204 |
+
kernel_size=(1, 1),
|
| 205 |
+
params_dict=utils.get(self.params_dict, 'branch1x1'),
|
| 206 |
+
dtype=self.dtype)(x, train)
|
| 207 |
+
branch5x5 = BasicConv2d(out_channels=48,
|
| 208 |
+
kernel_size=(1, 1),
|
| 209 |
+
params_dict=utils.get(self.params_dict, 'branch5x5_1'),
|
| 210 |
+
dtype=self.dtype)(x, train)
|
| 211 |
+
branch5x5 = BasicConv2d(out_channels=64,
|
| 212 |
+
kernel_size=(5, 5),
|
| 213 |
+
padding=((2, 2), (2, 2)),
|
| 214 |
+
params_dict=utils.get(self.params_dict, 'branch5x5_2'),
|
| 215 |
+
dtype=self.dtype)(branch5x5, train)
|
| 216 |
+
|
| 217 |
+
branch3x3dbl = BasicConv2d(out_channels=64,
|
| 218 |
+
kernel_size=(1, 1),
|
| 219 |
+
params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
|
| 220 |
+
dtype=self.dtype)(x, train)
|
| 221 |
+
branch3x3dbl = BasicConv2d(out_channels=96,
|
| 222 |
+
kernel_size=(3, 3),
|
| 223 |
+
padding=((1, 1), (1, 1)),
|
| 224 |
+
params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
|
| 225 |
+
dtype=self.dtype)(branch3x3dbl, train)
|
| 226 |
+
branch3x3dbl = BasicConv2d(out_channels=96,
|
| 227 |
+
kernel_size=(3, 3),
|
| 228 |
+
padding=((1, 1), (1, 1)),
|
| 229 |
+
params_dict=utils.get(self.params_dict, 'branch3x3dbl_3'),
|
| 230 |
+
dtype=self.dtype)(branch3x3dbl, train)
|
| 231 |
+
|
| 232 |
+
branch_pool = avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
|
| 233 |
+
branch_pool = BasicConv2d(out_channels=self.pool_features,
|
| 234 |
+
kernel_size=(1, 1),
|
| 235 |
+
params_dict=utils.get(self.params_dict, 'branch_pool'),
|
| 236 |
+
dtype=self.dtype)(branch_pool, train)
|
| 237 |
+
|
| 238 |
+
output = jnp.concatenate((branch1x1, branch5x5, branch3x3dbl, branch_pool), axis=-1)
|
| 239 |
+
return output
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class InceptionB(nn.Module):
|
| 243 |
+
params_dict: dict=None
|
| 244 |
+
dtype: str='float32'
|
| 245 |
+
|
| 246 |
+
@nn.compact
|
| 247 |
+
def __call__(self, x, train=True):
|
| 248 |
+
branch3x3 = BasicConv2d(out_channels=384,
|
| 249 |
+
kernel_size=(3, 3),
|
| 250 |
+
strides=(2, 2),
|
| 251 |
+
params_dict=utils.get(self.params_dict, 'branch3x3'),
|
| 252 |
+
dtype=self.dtype)(x, train)
|
| 253 |
+
|
| 254 |
+
branch3x3dbl = BasicConv2d(out_channels=64,
|
| 255 |
+
kernel_size=(1, 1),
|
| 256 |
+
params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
|
| 257 |
+
dtype=self.dtype)(x, train)
|
| 258 |
+
branch3x3dbl = BasicConv2d(out_channels=96,
|
| 259 |
+
kernel_size=(3, 3),
|
| 260 |
+
padding=((1, 1), (1, 1)),
|
| 261 |
+
params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
|
| 262 |
+
dtype=self.dtype)(branch3x3dbl, train)
|
| 263 |
+
branch3x3dbl = BasicConv2d(out_channels=96,
|
| 264 |
+
kernel_size=(3, 3),
|
| 265 |
+
strides=(2, 2),
|
| 266 |
+
params_dict=utils.get(self.params_dict, 'branch3x3dbl_3'),
|
| 267 |
+
dtype=self.dtype)(branch3x3dbl, train)
|
| 268 |
+
|
| 269 |
+
branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
|
| 270 |
+
|
| 271 |
+
output = jnp.concatenate((branch3x3, branch3x3dbl, branch_pool), axis=-1)
|
| 272 |
+
return output
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class InceptionC(nn.Module):
|
| 276 |
+
channels_7x7: int
|
| 277 |
+
params_dict: dict=None
|
| 278 |
+
dtype: str='float32'
|
| 279 |
+
|
| 280 |
+
@nn.compact
|
| 281 |
+
def __call__(self, x, train=True):
|
| 282 |
+
branch1x1 = BasicConv2d(out_channels=192,
|
| 283 |
+
kernel_size=(1, 1),
|
| 284 |
+
params_dict=utils.get(self.params_dict, 'branch1x1'),
|
| 285 |
+
dtype=self.dtype)(x, train)
|
| 286 |
+
|
| 287 |
+
branch7x7 = BasicConv2d(out_channels=self.channels_7x7,
|
| 288 |
+
kernel_size=(1, 1),
|
| 289 |
+
params_dict=utils.get(self.params_dict, 'branch7x7_1'),
|
| 290 |
+
dtype=self.dtype)(x, train)
|
| 291 |
+
branch7x7 = BasicConv2d(out_channels=self.channels_7x7,
|
| 292 |
+
kernel_size=(1, 7),
|
| 293 |
+
padding=((0, 0), (3, 3)),
|
| 294 |
+
params_dict=utils.get(self.params_dict, 'branch7x7_2'),
|
| 295 |
+
dtype=self.dtype)(branch7x7, train)
|
| 296 |
+
branch7x7 = BasicConv2d(out_channels=192,
|
| 297 |
+
kernel_size=(7, 1),
|
| 298 |
+
padding=((3, 3), (0, 0)),
|
| 299 |
+
params_dict=utils.get(self.params_dict, 'branch7x7_3'),
|
| 300 |
+
dtype=self.dtype)(branch7x7, train)
|
| 301 |
+
|
| 302 |
+
branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
|
| 303 |
+
kernel_size=(1, 1),
|
| 304 |
+
params_dict=utils.get(self.params_dict, 'branch7x7dbl_1'),
|
| 305 |
+
dtype=self.dtype)(x, train)
|
| 306 |
+
branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
|
| 307 |
+
kernel_size=(7, 1),
|
| 308 |
+
padding=((3, 3), (0, 0)),
|
| 309 |
+
params_dict=utils.get(self.params_dict, 'branch7x7dbl_2'),
|
| 310 |
+
dtype=self.dtype)(branch7x7dbl, train)
|
| 311 |
+
branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
|
| 312 |
+
kernel_size=(1, 7),
|
| 313 |
+
padding=((0, 0), (3, 3)),
|
| 314 |
+
params_dict=utils.get(self.params_dict, 'branch7x7dbl_3'),
|
| 315 |
+
dtype=self.dtype)(branch7x7dbl, train)
|
| 316 |
+
branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
|
| 317 |
+
kernel_size=(7, 1),
|
| 318 |
+
padding=((3, 3), (0, 0)),
|
| 319 |
+
params_dict=utils.get(self.params_dict, 'branch7x7dbl_4'),
|
| 320 |
+
dtype=self.dtype)(branch7x7dbl, train)
|
| 321 |
+
branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
|
| 322 |
+
kernel_size=(1, 7),
|
| 323 |
+
padding=((0, 0), (3, 3)),
|
| 324 |
+
params_dict=utils.get(self.params_dict, 'branch7x7dbl_5'),
|
| 325 |
+
dtype=self.dtype)(branch7x7dbl, train)
|
| 326 |
+
|
| 327 |
+
branch_pool = avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
|
| 328 |
+
branch_pool = BasicConv2d(out_channels=192,
|
| 329 |
+
kernel_size=(1, 1),
|
| 330 |
+
params_dict=utils.get(self.params_dict, 'branch_pool'),
|
| 331 |
+
dtype=self.dtype)(branch_pool, train)
|
| 332 |
+
|
| 333 |
+
output = jnp.concatenate((branch1x1, branch7x7, branch7x7dbl, branch_pool), axis=-1)
|
| 334 |
+
return output
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
class InceptionD(nn.Module):
|
| 338 |
+
params_dict: dict=None
|
| 339 |
+
dtype: str='float32'
|
| 340 |
+
|
| 341 |
+
@nn.compact
|
| 342 |
+
def __call__(self, x, train=True):
|
| 343 |
+
branch3x3 = BasicConv2d(out_channels=192,
|
| 344 |
+
kernel_size=(1, 1),
|
| 345 |
+
params_dict=utils.get(self.params_dict, 'branch3x3_1'),
|
| 346 |
+
dtype=self.dtype)(x, train)
|
| 347 |
+
branch3x3 = BasicConv2d(out_channels=320,
|
| 348 |
+
kernel_size=(3, 3),
|
| 349 |
+
strides=(2, 2),
|
| 350 |
+
params_dict=utils.get(self.params_dict, 'branch3x3_2'),
|
| 351 |
+
dtype=self.dtype)(branch3x3, train)
|
| 352 |
+
|
| 353 |
+
branch7x7x3 = BasicConv2d(out_channels=192,
|
| 354 |
+
kernel_size=(1, 1),
|
| 355 |
+
params_dict=utils.get(self.params_dict, 'branch7x7x3_1'),
|
| 356 |
+
dtype=self.dtype)(x, train)
|
| 357 |
+
branch7x7x3 = BasicConv2d(out_channels=192,
|
| 358 |
+
kernel_size=(1, 7),
|
| 359 |
+
padding=((0, 0), (3, 3)),
|
| 360 |
+
params_dict=utils.get(self.params_dict, 'branch7x7x3_2'),
|
| 361 |
+
dtype=self.dtype)(branch7x7x3, train)
|
| 362 |
+
branch7x7x3 = BasicConv2d(out_channels=192,
|
| 363 |
+
kernel_size=(7, 1),
|
| 364 |
+
padding=((3, 3), (0, 0)),
|
| 365 |
+
params_dict=utils.get(self.params_dict, 'branch7x7x3_3'),
|
| 366 |
+
dtype=self.dtype)(branch7x7x3, train)
|
| 367 |
+
branch7x7x3 = BasicConv2d(out_channels=192,
|
| 368 |
+
kernel_size=(3, 3),
|
| 369 |
+
strides=(2, 2),
|
| 370 |
+
params_dict=utils.get(self.params_dict, 'branch7x7x3_4'),
|
| 371 |
+
dtype=self.dtype)(branch7x7x3, train)
|
| 372 |
+
|
| 373 |
+
branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
|
| 374 |
+
|
| 375 |
+
output = jnp.concatenate((branch3x3, branch7x7x3, branch_pool), axis=-1)
|
| 376 |
+
return output
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
class InceptionE(nn.Module):
|
| 380 |
+
pooling: Callable
|
| 381 |
+
params_dict: dict=None
|
| 382 |
+
dtype: str='float32'
|
| 383 |
+
|
| 384 |
+
@nn.compact
|
| 385 |
+
def __call__(self, x, train=True):
|
| 386 |
+
branch1x1 = BasicConv2d(out_channels=320,
|
| 387 |
+
kernel_size=(1, 1),
|
| 388 |
+
params_dict=utils.get(self.params_dict, 'branch1x1'),
|
| 389 |
+
dtype=self.dtype)(x, train)
|
| 390 |
+
|
| 391 |
+
branch3x3 = BasicConv2d(out_channels=384,
|
| 392 |
+
kernel_size=(1, 1),
|
| 393 |
+
params_dict=utils.get(self.params_dict, 'branch3x3_1'),
|
| 394 |
+
dtype=self.dtype)(x, train)
|
| 395 |
+
branch3x3_a = BasicConv2d(out_channels=384,
|
| 396 |
+
kernel_size=(1, 3),
|
| 397 |
+
padding=((0, 0), (1, 1)),
|
| 398 |
+
params_dict=utils.get(self.params_dict, 'branch3x3_2a'),
|
| 399 |
+
dtype=self.dtype)(branch3x3, train)
|
| 400 |
+
branch3x3_b = BasicConv2d(out_channels=384,
|
| 401 |
+
kernel_size=(3, 1),
|
| 402 |
+
padding=((1, 1), (0, 0)),
|
| 403 |
+
params_dict=utils.get(self.params_dict, 'branch3x3_2b'),
|
| 404 |
+
dtype=self.dtype)(branch3x3, train)
|
| 405 |
+
branch3x3 = jnp.concatenate((branch3x3_a, branch3x3_b), axis=-1)
|
| 406 |
+
|
| 407 |
+
branch3x3dbl = BasicConv2d(out_channels=448,
|
| 408 |
+
kernel_size=(1, 1),
|
| 409 |
+
params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
|
| 410 |
+
dtype=self.dtype)(x, train)
|
| 411 |
+
branch3x3dbl = BasicConv2d(out_channels=384,
|
| 412 |
+
kernel_size=(3, 3),
|
| 413 |
+
padding=((1, 1), (1, 1)),
|
| 414 |
+
params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
|
| 415 |
+
dtype=self.dtype)(branch3x3dbl, train)
|
| 416 |
+
branch3x3dbl_a = BasicConv2d(out_channels=384,
|
| 417 |
+
kernel_size=(1, 3),
|
| 418 |
+
padding=((0, 0), (1, 1)),
|
| 419 |
+
params_dict=utils.get(self.params_dict, 'branch3x3dbl_3a'),
|
| 420 |
+
dtype=self.dtype)(branch3x3dbl, train)
|
| 421 |
+
branch3x3dbl_b = BasicConv2d(out_channels=384,
|
| 422 |
+
kernel_size=(3, 1),
|
| 423 |
+
padding=((1, 1), (0, 0)),
|
| 424 |
+
params_dict=utils.get(self.params_dict, 'branch3x3dbl_3b'),
|
| 425 |
+
dtype=self.dtype)(branch3x3dbl, train)
|
| 426 |
+
branch3x3dbl = jnp.concatenate((branch3x3dbl_a, branch3x3dbl_b), axis=-1)
|
| 427 |
+
|
| 428 |
+
branch_pool = self.pooling(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
|
| 429 |
+
branch_pool = BasicConv2d(out_channels=192,
|
| 430 |
+
kernel_size=(1, 1),
|
| 431 |
+
params_dict=utils.get(self.params_dict, 'branch_pool'),
|
| 432 |
+
dtype=self.dtype)(branch_pool, train)
|
| 433 |
+
|
| 434 |
+
output = jnp.concatenate((branch1x1, branch3x3, branch3x3dbl, branch_pool), axis=-1)
|
| 435 |
+
return output
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
class InceptionAux(nn.Module):
|
| 439 |
+
num_classes: int
|
| 440 |
+
kernel_init: functools.partial=nn.initializers.lecun_normal()
|
| 441 |
+
bias_init: functools.partial=nn.initializers.zeros
|
| 442 |
+
params_dict: dict=None
|
| 443 |
+
dtype: str='float32'
|
| 444 |
+
|
| 445 |
+
@nn.compact
|
| 446 |
+
def __call__(self, x, train=True):
|
| 447 |
+
x = avg_pool(x, window_shape=(5, 5), strides=(3, 3))
|
| 448 |
+
x = BasicConv2d(out_channels=128,
|
| 449 |
+
kernel_size=(1, 1),
|
| 450 |
+
params_dict=utils.get(self.params_dict, 'conv0'),
|
| 451 |
+
dtype=self.dtype)(x, train)
|
| 452 |
+
x = BasicConv2d(out_channels=768,
|
| 453 |
+
kernel_size=(5, 5),
|
| 454 |
+
params_dict=utils.get(self.params_dict, 'conv1'),
|
| 455 |
+
dtype=self.dtype)(x, train)
|
| 456 |
+
x = jnp.mean(x, axis=(1, 2))
|
| 457 |
+
x = jnp.reshape(x, newshape=(x.shape[0], -1))
|
| 458 |
+
x = Dense(features=self.num_classes,
|
| 459 |
+
params_dict=utils.get(self.params_dict, 'fc'),
|
| 460 |
+
dtype=self.dtype)(x)
|
| 461 |
+
return x
|
| 462 |
+
|
| 463 |
+
def _absolute_dims(rank, dims):
|
| 464 |
+
return tuple([rank + dim if dim < 0 else dim for dim in dims])
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
class BatchNorm(nn.Module):
|
| 468 |
+
"""BatchNorm Module.
|
| 469 |
+
Taken from: https://github.com/google/flax/blob/master/flax/linen/normalization.py
|
| 470 |
+
Attributes:
|
| 471 |
+
use_running_average: if True, the statistics stored in batch_stats
|
| 472 |
+
will be used instead of computing the batch statistics on the input.
|
| 473 |
+
axis: the feature or non-batch axis of the input.
|
| 474 |
+
momentum: decay rate for the exponential moving average of the batch statistics.
|
| 475 |
+
epsilon: a small float added to variance to avoid dividing by zero.
|
| 476 |
+
dtype: the dtype of the computation (default: float32).
|
| 477 |
+
use_bias: if True, bias (beta) is added.
|
| 478 |
+
use_scale: if True, multiply by scale (gamma).
|
| 479 |
+
When the next layer is linear (also e.g. nn.relu), this can be disabled
|
| 480 |
+
since the scaling will be done by the next layer.
|
| 481 |
+
bias_init: initializer for bias, by default, zero.
|
| 482 |
+
scale_init: initializer for scale, by default, one.
|
| 483 |
+
axis_name: the axis name used to combine batch statistics from multiple
|
| 484 |
+
devices. See `jax.pmap` for a description of axis names (default: None).
|
| 485 |
+
axis_index_groups: groups of axis indices within that named axis
|
| 486 |
+
representing subsets of devices to reduce over (default: None). For
|
| 487 |
+
example, `[[0, 1], [2, 3]]` would independently batch-normalize over
|
| 488 |
+
the examples on the first two and last two devices. See `jax.lax.psum`
|
| 489 |
+
for more details.
|
| 490 |
+
"""
|
| 491 |
+
use_running_average: Optional[bool] = None
|
| 492 |
+
axis: int = -1
|
| 493 |
+
momentum: float = 0.99
|
| 494 |
+
epsilon: float = 1e-5
|
| 495 |
+
dtype: Dtype = jnp.float32
|
| 496 |
+
use_bias: bool = True
|
| 497 |
+
use_scale: bool = True
|
| 498 |
+
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
|
| 499 |
+
scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
|
| 500 |
+
mean_init: Callable[[Shape], Array] = lambda s: jnp.zeros(s, jnp.float32)
|
| 501 |
+
var_init: Callable[[Shape], Array] = lambda s: jnp.ones(s, jnp.float32)
|
| 502 |
+
axis_name: Optional[str] = None
|
| 503 |
+
axis_index_groups: Any = None
|
| 504 |
+
|
| 505 |
+
@nn.compact
|
| 506 |
+
def __call__(self, x, use_running_average: Optional[bool] = None):
|
| 507 |
+
"""Normalizes the input using batch statistics.
|
| 508 |
+
|
| 509 |
+
NOTE:
|
| 510 |
+
During initialization (when parameters are mutable) the running average
|
| 511 |
+
    of the batch statistics will not be updated. Therefore, the inputs
    fed during initialization don't need to match that of the actual input
    distribution and the reduction axis (set with `axis_name`) does not have
    to exist.

    Args:
        x: the input to be normalized.
        use_running_average: if true, the statistics stored in batch_stats
            will be used instead of computing the batch statistics on the input.

    Returns:
        Normalized inputs (the same shape as inputs).
    """
    use_running_average = merge_param(
        'use_running_average', self.use_running_average, use_running_average)
    x = jnp.asarray(x, jnp.float32)
    axis = self.axis if isinstance(self.axis, tuple) else (self.axis,)
    axis = _absolute_dims(x.ndim, axis)
    feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
    reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
    reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

    # See NOTE above on initialization behavior.
    initializing = self.is_mutable_collection('params')

    ra_mean = self.variable('batch_stats', 'mean',
                            self.mean_init,
                            reduced_feature_shape)
    ra_var = self.variable('batch_stats', 'var',
                           self.var_init,
                           reduced_feature_shape)

    if use_running_average:
        mean, var = ra_mean.value, ra_var.value
    else:
        mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
        mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
        if self.axis_name is not None and not initializing:
            # Average the statistics across devices in a single collective by
            # concatenating the mean and mean-of-squares before the pmean.
            concatenated_mean = jnp.concatenate([mean, mean2])
            mean, mean2 = jnp.split(
                lax.pmean(
                    concatenated_mean,
                    axis_name=self.axis_name,
                    axis_index_groups=self.axis_index_groups), 2)
        var = mean2 - lax.square(mean)

        if not initializing:
            ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean
            ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var

    y = x - mean.reshape(feature_shape)
    mul = lax.rsqrt(var + self.epsilon)
    if self.use_scale:
        scale = self.param('scale',
                           self.scale_init,
                           reduced_feature_shape).reshape(feature_shape)
        mul = mul * scale
    y = y * mul
    if self.use_bias:
        bias = self.param('bias',
                          self.bias_init,
                          reduced_feature_shape).reshape(feature_shape)
        y = y + bias
    return jnp.asarray(y, self.dtype)


def pool(inputs, init, reduce_fn, window_shape, strides, padding):
    """
    Taken from: https://github.com/google/flax/blob/main/flax/linen/pooling.py

    Helper function to define pooling functions.
    Pooling functions are implemented using the ReduceWindow XLA op.
    NOTE: Be aware that pooling is not generally differentiable.
    That means providing a reduce_fn that is differentiable does not imply
    that pool is differentiable.

    Args:
        inputs: input data with dimensions (batch, window dims..., features).
        init: the initial value for the reduction.
        reduce_fn: a reduce function of the form `(T, T) -> T`.
        window_shape: a shape tuple defining the window to reduce over.
        strides: a sequence of `n` integers, representing the inter-window
            strides.
        padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
            of `n` `(low, high)` integer pairs that give the padding to apply before
            and after each spatial dimension.

    Returns:
        The output of the reduction for each window slice.
    """
    strides = strides or (1,) * len(window_shape)
    assert len(window_shape) == len(strides), (
        f"len({window_shape}) == len({strides})")
    strides = (1,) + strides + (1,)
    dims = (1,) + window_shape + (1,)

    is_single_input = False
    if inputs.ndim == len(dims) - 1:
        # Add a singleton batch dimension because lax.reduce_window always
        # needs a batch dimension.
        inputs = inputs[None]
        is_single_input = True

    assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})"
    if not isinstance(padding, str):
        padding = tuple(map(tuple, padding))
        assert len(padding) == len(window_shape), (
            f"padding {padding} must specify pads for same number of dims as "
            f"window_shape {window_shape}")
        assert all([len(x) == 2 for x in padding]), (
            f"each entry in padding {padding} must be length 2")
        padding = ((0, 0),) + padding + ((0, 0),)
    y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
    if is_single_input:
        y = jnp.squeeze(y, axis=0)
    return y


def avg_pool(inputs, window_shape, strides=None, padding='VALID'):
    """
    Pools the input by taking the average over a window.

    In comparison to flax.linen.avg_pool, this pooling operation does not
    consider the padded zeros for the average computation.

    Args:
        inputs: input data with dimensions (batch, window dims..., features).
        window_shape: a shape tuple defining the window to reduce over.
        strides: a sequence of `n` integers, representing the inter-window
            strides (default: `(1, ..., 1)`).
        padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
            of `n` `(low, high)` integer pairs that give the padding to apply before
            and after each spatial dimension (default: `'VALID'`).

    Returns:
        The average for each window slice.
    """
    assert inputs.ndim == 4
    assert len(window_shape) == 2

    y = pool(inputs, 0., jax.lax.add, window_shape, strides, padding)
    # Count how many real (unpadded) input pixels contribute to each window by
    # convolving an all-ones tensor with an all-ones kernel. The (1, 1) padding
    # corresponds to a 3x3 window with 'SAME' padding.
    ones = jnp.ones(shape=(1, inputs.shape[1], inputs.shape[2], 1)).astype(inputs.dtype)
    counts = jax.lax.conv_general_dilated(ones,
                                          jnp.expand_dims(jnp.ones(window_shape).astype(inputs.dtype), axis=(-2, -1)),
                                          window_strides=(1, 1),
                                          padding=((1, 1), (1, 1)),
                                          dimension_numbers=nn.linear._conv_dimension_numbers(ones.shape),
                                          feature_group_count=1)
    y = y / counts
    return y
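For reference, a minimal usage sketch of the padding-aware average pooling above (not part of the diff; the input shape is an illustrative assumption):

import jax.numpy as jnp

x = jnp.ones((1, 8, 8, 3))  # NHWC input
y = avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding='SAME')
# Border windows are averaged over the pixels that actually exist instead of
# the zero padding, so a constant input stays constant: y is all ones.
print(y.shape)  # (1, 8, 8, 3)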
fid/utils.py
ADDED
@@ -0,0 +1,59 @@
import jax
import flax
import numpy as np
from tqdm import tqdm
import requests
import os
import tempfile
import logging

logger = logging.getLogger(__name__)


def download(url, ckpt_dir=None):
    name = url[url.rfind('/') + 1 : url.rfind('?')]
    if ckpt_dir is None:
        ckpt_dir = tempfile.gettempdir()
    ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')
    ckpt_file = os.path.join(ckpt_dir, name)
    if not os.path.exists(ckpt_file):
        logger.info(f'Downloading: \"{url[:url.rfind("?")]}\" to {ckpt_file}')
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)

        response = requests.get(url, stream=True)
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)

        # First write to a temp file, in case the download fails.
        ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
        with open(ckpt_file_temp, 'wb') as file:
            for data in response.iter_content(chunk_size=1024):
                progress_bar.update(len(data))
                file.write(data)
        progress_bar.close()

        if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
            logger.error('An error occurred while downloading, please try again.')
            if os.path.exists(ckpt_file_temp):
                os.remove(ckpt_file_temp)
        else:
            # If the download was successful, rename the temp file.
            os.rename(ckpt_file_temp, ckpt_file)
    return ckpt_file


def get(dictionary, key):
    if dictionary is None or key not in dictionary:
        return None
    return dictionary[key]


def prefetch(dataset, n_prefetch):
    # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py
    ds_iter = iter(dataset)
    ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x),
                  ds_iter)
    if n_prefetch:
        ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
    return ds_iter
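A minimal sketch of how prefetch is meant to be wired in (the training dataset here is a placeholder assumption; it stands for a tf.data pipeline whose batches are already sharded for the local devices):

ds_iter = prefetch(train_ds, n_prefetch=2)  # converts TF tensors to numpy and stages batches on device
batch = next(ds_iter)                       # ready-to-use arrays for the next training step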
generate_images.py
ADDED
@@ -0,0 +1,61 @@
import argparse
import functools
import logging
import os

import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
from tqdm import tqdm

import checkpoint
from stylegan2.generator import Generator

logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)-5.5s] [%(name)-12.12s]: %(message)s', force=True)
logger = logging.getLogger(__name__)


def generate_images(args):
    logger.info(f"Loading checkpoint '{args.checkpoint}'...")
    ckpt = checkpoint.load_checkpoint(args.checkpoint)
    config = ckpt['config']
    params_ema_G = ckpt['params_ema_G']

    generator_ema = Generator(
        resolution=config.resolution,
        num_channels=config.img_channels,
        z_dim=config.z_dim,
        c_dim=config.c_dim,
        w_dim=config.w_dim,
        num_ws=int(np.log2(config.resolution)) * 2 - 3,
        num_mapping_layers=8,
        fmap_base=config.fmap_base,
        dtype=jnp.float32
    )

    generator_apply = jax.jit(
        functools.partial(generator_ema.apply, truncation_psi=args.truncation_psi, train=False, noise_mode='const')
    )

    logger.info(f"Generating {len(args.seeds)} images with truncation {args.truncation_psi}...")
    for seed in tqdm(args.seeds):
        rng = jax.random.PRNGKey(seed)
        z_latent = jax.random.normal(rng, shape=(1, config.z_dim))
        image = generator_apply(params_ema_G, jax.lax.stop_gradient(z_latent), None)
        # Normalize to [0, 1] before converting to uint8.
        image = (image - jnp.min(image)) / (jnp.max(image) - jnp.min(image))

        Image.fromarray(np.uint8(np.clip(image[0] * 255, 0, 255))).save(os.path.join(args.out_path, f'{seed}.png'))
    logger.info(f"Images saved in '{args.out_path}/'")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, help='Path to the checkpoint.', required=True)
    parser.add_argument('--out_path', type=str, default='generated_images', help='Path where the generated images are stored.')
    parser.add_argument('--truncation_psi', type=float, default=0.5, help='Controls truncation (trading off variation for quality). If 1, truncation is disabled.')
    parser.add_argument('--seeds', type=int, nargs='*', default=[0], help='List of random seeds.')
    args = parser.parse_args()
    os.makedirs(args.out_path, exist_ok=True)

    generate_images(args)
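A hypothetical invocation (the checkpoint filename is an assumption; the flags match the argparse definitions above):

python generate_images.py --checkpoint ckpt_20000.pickle --seeds 0 1 2 --truncation_psi 0.7 --out_path generated_images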
main.py
ADDED
@@ -0,0 +1,102 @@
import argparse
import os
import jax
import wandb
import training
import logging
import json


logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)-5.5s] [%(name)-12.12s]: %(message)s', force=True)
logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser()
    # Paths
    parser.add_argument('--data_dir', type=str, required=True, help='Directory of the dataset.')
    parser.add_argument('--save_dir', type=str, default='gs://ig-standard-usc1/sg2-flax/checkpoints/', help='Directory where checkpoints will be written to. A subfolder with run_id will be created.')
    parser.add_argument('--load_from_pkl', type=str, help='If provided, start training from an existing checkpoint pickle file.')
    parser.add_argument('--resume_run_id', type=str, help='If provided, resume existing training run. If --wandb is enabled W&B will also resume.')
    parser.add_argument('--project', type=str, default='sg2-flax', help='Name of this project.')
    # Training
    parser.add_argument('--num_epochs', type=int, default=10000, help='Number of epochs.')
    parser.add_argument('--learning_rate', type=float, default=0.002, help='Learning rate.')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size.')
    parser.add_argument('--num_prefetch', type=int, default=2, help='Number of prefetched examples for the data pipeline.')
    parser.add_argument('--resolution', type=int, default=128, help='Image resolution. Must be a power of 2.')
    parser.add_argument('--img_channels', type=int, default=3, help='Number of image channels.')
    parser.add_argument('--mixed_precision', action='store_true', help='Use mixed precision training.')
    parser.add_argument('--random_seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--bf16', action='store_true', help='Use bf16 dtype (this is still WIP).')
    # Generator
    parser.add_argument('--fmap_base', type=int, default=16384, help='Overall multiplier for the number of feature maps.')
    # Discriminator
    parser.add_argument('--mbstd_group_size', type=int, help='Group size for the minibatch standard deviation layer, None = entire minibatch.')
    # Exponential moving average of generator weights
    parser.add_argument('--ema_kimg', type=float, default=20.0, help='Controls the ema of the generator weights (larger value -> larger beta).')
    # Losses
    parser.add_argument('--pl_decay', type=float, default=0.01, help='Exponential decay for the mean of the path length (path length regularization).')
    parser.add_argument('--pl_weight', type=float, default=2, help='Weight for path length regularization.')
    # Regularization
    parser.add_argument('--mixing_prob', type=float, default=0.9, help='Probability for style mixing.')
    parser.add_argument('--G_reg_interval', type=int, default=4, help='How often to perform regularization for G.')
    parser.add_argument('--D_reg_interval', type=int, default=16, help='How often to perform regularization for D.')
    parser.add_argument('--r1_gamma', type=float, default=10.0, help='Weight for R1 regularization.')
    # Model
    parser.add_argument('--z_dim', type=int, default=512, help='Input latent (Z) dimensionality.')
    parser.add_argument('--c_dim', type=int, default=0, help='Conditioning label (C) dimensionality, 0 = no label.')
    parser.add_argument('--w_dim', type=int, default=512, help='Intermediate latent (W) dimensionality.')
    # Logging
    parser.add_argument('--log_every', type=int, default=100, help='Log every log_every steps.')
    parser.add_argument('--save_every', type=int, default=2000, help='Save every save_every steps. Will be ignored if FID evaluation is enabled.')
    parser.add_argument('--generate_samples_every', type=int, default=10000, help='Generate samples every generate_samples_every steps.')
    parser.add_argument('--debug', action='store_true', help='Show debug log.')
    # FID
    parser.add_argument('--eval_fid_every', type=int, default=1000, help='Compute FID score every eval_fid_every steps.')
    parser.add_argument('--num_fid_images', type=int, default=10000, help='Number of images to use for FID computation.')
    parser.add_argument('--disable_fid', action='store_true', help='Disable FID evaluation.')
    # W&B
    parser.add_argument('--wandb', action='store_true', help='Log to Weights&Biases.')
    parser.add_argument('--name', type=str, default=None, help='Name of this experiment in Weights&Biases.')
    parser.add_argument('--entity', type=str, default='nyxai', help='Entity for this experiment in Weights&Biases.')
    parser.add_argument('--group', type=str, default=None, help='Group name of this experiment for Weights&Biases.')

    args = parser.parse_args()

    # Debug mode.
    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    # Some validation.
    if args.resume_run_id is not None:
        assert args.load_from_pkl is None, 'When resuming a run one cannot also specify --load_from_pkl'

    # Set unique run ID.
    if args.resume_run_id:
        resume = 'must'  # throw an error if the run id to be resumed cannot be found
        args.run_id = args.resume_run_id
    else:
        resume = None  # default
        args.run_id = wandb.util.generate_id()
    args.ckpt_dir = os.path.join(args.save_dir, args.run_id)

    if jax.process_index() == 0:
        if not args.ckpt_dir.startswith('gs://') and not os.path.exists(args.ckpt_dir):
            os.makedirs(args.ckpt_dir)
        if args.wandb:
            wandb.init(id=args.run_id,
                       project=args.project,
                       group=args.group,
                       config=args,
                       name=args.name,
                       entity=args.entity,
                       resume=resume)
        logger.info('Starting new run with config:')
        print(json.dumps(vars(args), indent=4))

    training.train_and_evaluate(args)


if __name__ == '__main__':
    main()
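A hypothetical training launch (the dataset path is an assumption; the flags match the parser above):

python main.py --data_dir gs://my-bucket/my-dataset --resolution 128 --batch_size 8 --wandb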
requirements.txt
ADDED
@@ -0,0 +1,14 @@
flaxmodels==0.1.1
flax==0.4.1
jax==0.3.14
tensorflow==2.4.1
optax==0.0.9
numpy
tensorflow-datasets
argparse
wandb
tqdm
dill
h5py
dataclasses
tqdm
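The pinned dependencies above install with standard pip (the command itself is not part of the diff):

pip install -r requirements.txt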
stylegan2/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .generator import SynthesisNetwork
from .generator import MappingNetwork
from .generator import Generator
from .discriminator import Discriminator
stylegan2/discriminator.py
ADDED
@@ -0,0 +1,451 @@
import numpy as np
import jax
from jax import random
import jax.numpy as jnp
import flax.linen as nn
from typing import Any, Tuple, List, Callable
import h5py
from . import ops
from stylegan2 import utils


URLS = {'afhqcat': 'https://www.dropbox.com/s/qygbjkefyqyu9k9/stylegan2_discriminator_afhqcat.h5?dl=1',
        'afhqdog': 'https://www.dropbox.com/s/kmoxbp33qswz64p/stylegan2_discriminator_afhqdog.h5?dl=1',
        'afhqwild': 'https://www.dropbox.com/s/jz1hpsyt3isj6e7/stylegan2_discriminator_afhqwild.h5?dl=1',
        'brecahad': 'https://www.dropbox.com/s/h0cb89hruo6pmyj/stylegan2_discriminator_brecahad.h5?dl=1',
        'car': 'https://www.dropbox.com/s/2ghjrmxih7cic76/stylegan2_discriminator_car.h5?dl=1',
        'cat': 'https://www.dropbox.com/s/zfhjsvlsny5qixd/stylegan2_discriminator_cat.h5?dl=1',
        'church': 'https://www.dropbox.com/s/jlno7zeivkjtk8g/stylegan2_discriminator_church.h5?dl=1',
        'cifar10': 'https://www.dropbox.com/s/eldpubfkl4c6rur/stylegan2_discriminator_cifar10.h5?dl=1',
        'ffhq': 'https://www.dropbox.com/s/m42qy9951b7lq1s/stylegan2_discriminator_ffhq.h5?dl=1',
        'horse': 'https://www.dropbox.com/s/19f5pxrcdh2g8cw/stylegan2_discriminator_horse.h5?dl=1',
        'metfaces': 'https://www.dropbox.com/s/xnokaunql12glkd/stylegan2_discriminator_metfaces.h5?dl=1'}

RESOLUTION = {'metfaces': 1024,
              'ffhq': 1024,
              'church': 256,
              'cat': 256,
              'horse': 256,
              'car': 512,
              'brecahad': 512,
              'afhqwild': 512,
              'afhqdog': 512,
              'afhqcat': 512,
              'cifar10': 32}

C_DIM = {'metfaces': 0,
         'ffhq': 0,
         'church': 0,
         'cat': 0,
         'horse': 0,
         'car': 0,
         'brecahad': 0,
         'afhqwild': 0,
         'afhqdog': 0,
         'afhqcat': 0,
         'cifar10': 10}

ARCHITECTURE = {'metfaces': 'resnet',
                'ffhq': 'resnet',
                'church': 'resnet',
                'cat': 'resnet',
                'horse': 'resnet',
                'car': 'resnet',
                'brecahad': 'resnet',
                'afhqwild': 'resnet',
                'afhqdog': 'resnet',
                'afhqcat': 'resnet',
                'cifar10': 'orig'}

MBSTD_GROUP_SIZE = {'metfaces': None,
                    'ffhq': None,
                    'church': None,
                    'cat': None,
                    'horse': None,
                    'car': None,
                    'brecahad': None,
                    'afhqwild': None,
                    'afhqdog': None,
                    'afhqcat': None,
                    'cifar10': 32}


class FromRGBLayer(nn.Module):
    """
    From RGB Layer.

    Attributes:
        fmaps (int): Number of output channels of the convolution.
        kernel (int): Kernel size of the convolution.
        lr_multiplier (float): Learning rate multiplier.
        activation (str): Activation function: 'relu', 'lrelu', etc.
        param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        dtype (str): Data dtype.
        rng (jax.random.PRNGKey): PRNG for initialization.
    """
    fmaps: int
    kernel: int=1
    lr_multiplier: float=1
    activation: str='leaky_relu'
    param_dict: h5py.Group=None
    clip_conv: float=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    @nn.compact
    def __call__(self, x, y):
        """
        Run From RGB Layer.

        Args:
            x (tensor): Input image of shape [N, H, W, num_channels].
            y (tensor): Input tensor of shape [N, H, W, out_channels].

        Returns:
            (tensor): Output tensor of shape [N, H, W, out_channels].
        """
        w_shape = [self.kernel, self.kernel, x.shape[3], self.fmaps]
        w, b = ops.get_weight(w_shape, self.lr_multiplier, True, self.param_dict, 'fromrgb', self.rng)

        w = self.param(name='weight', init_fn=lambda *_ : w)
        b = self.param(name='bias', init_fn=lambda *_ : b)
        w = ops.equalize_lr_weight(w, self.lr_multiplier)
        b = ops.equalize_lr_bias(b, self.lr_multiplier)

        x = x.astype(self.dtype)
        x = ops.conv2d(x, w.astype(x.dtype))
        x += b.astype(x.dtype)
        x = ops.apply_activation(x, activation=self.activation)
        if self.clip_conv is not None:
            x = jnp.clip(x, -self.clip_conv, self.clip_conv)
        if y is not None:
            x += y
        return x


class DiscriminatorLayer(nn.Module):
    """
    Discriminator Layer.

    Attributes:
        fmaps (int): Number of output channels of the convolution.
        kernel (int): Kernel size of the convolution.
        use_bias (bool): If True, use bias.
        down (bool): If True, downsample the spatial resolution.
        resample_kernel (Tuple): Kernel that is used for FIR filter.
        activation (str): Activation function: 'relu', 'lrelu', etc.
        layer_name (str): Layer name.
        param_dict (h5py.Group): Parameter dict with pretrained parameters.
        lr_multiplier (float): Learning rate multiplier.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        dtype (str): Data dtype.
        rng (jax.random.PRNGKey): PRNG for initialization.
    """
    fmaps: int
    kernel: int=3
    use_bias: bool=True
    down: bool=False
    resample_kernel: Tuple=None
    activation: str='leaky_relu'
    layer_name: str=None
    param_dict: h5py.Group=None
    lr_multiplier: float=1
    clip_conv: float=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    @nn.compact
    def __call__(self, x):
        """
        Run Discriminator Layer.

        Args:
            x (tensor): Input tensor of shape [N, H, W, C].

        Returns:
            (tensor): Output tensor of shape [N, H, W, fmaps].
        """
        w_shape = [self.kernel, self.kernel, x.shape[3], self.fmaps]
        if self.use_bias:
            w, b = ops.get_weight(w_shape, self.lr_multiplier, self.use_bias, self.param_dict, self.layer_name, self.rng)
        else:
            w = ops.get_weight(w_shape, self.lr_multiplier, self.use_bias, self.param_dict, self.layer_name, self.rng)

        w = self.param(name='weight', init_fn=lambda *_ : w)
        w = ops.equalize_lr_weight(w, self.lr_multiplier)
        if self.use_bias:
            b = self.param(name='bias', init_fn=lambda *_ : b)
            b = ops.equalize_lr_bias(b, self.lr_multiplier)

        x = x.astype(self.dtype)
        x = ops.conv2d(x, w, down=self.down, resample_kernel=self.resample_kernel)
        if self.use_bias: x += b.astype(x.dtype)
        x = ops.apply_activation(x, activation=self.activation)
        if self.clip_conv is not None:
            x = jnp.clip(x, -self.clip_conv, self.clip_conv)
        return x


class DiscriminatorBlock(nn.Module):
    """
    Discriminator Block.

    Attributes:
        res (int): Resolution (log2) of the current block.
        kernel (int): Kernel size of the convolution.
        resample_kernel (Tuple): Kernel that is used for FIR filter.
        activation (str): Activation function: 'relu', 'lrelu', etc.
        param_dict (h5py.Group): Parameter dict with pretrained parameters.
        lr_multiplier (float): Learning rate multiplier.
        architecture (str): Architecture: 'orig', 'resnet'.
        nf (Callable): Callable that returns the number of feature maps for a given layer.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        dtype (str): Data dtype.
        rng (jax.random.PRNGKey): Random seed for initialization.
    """
    res: int
    kernel: int=3
    resample_kernel: Tuple=(1, 3, 3, 1)
    activation: str='leaky_relu'
    param_dict: Any=None
    lr_multiplier: float=1
    architecture: str='resnet'
    nf: Callable=None
    clip_conv: float=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    @nn.compact
    def __call__(self, x):
        """
        Run Discriminator Block.

        Args:
            x (tensor): Input tensor of shape [N, H, W, C].

        Returns:
            (tensor): Output tensor of shape [N, H, W, fmaps].
        """
        init_rng = self.rng
        x = x.astype(self.dtype)
        residual = x
        for i in range(2):
            init_rng, init_key = random.split(init_rng)
            x = DiscriminatorLayer(fmaps=self.nf(self.res - (i + 1)),
                                   kernel=self.kernel,
                                   down=i == 1,
                                   resample_kernel=self.resample_kernel if i == 1 else None,
                                   activation=self.activation,
                                   layer_name=f'conv{i}',
                                   param_dict=self.param_dict,
                                   lr_multiplier=self.lr_multiplier,
                                   clip_conv=self.clip_conv,
                                   dtype=self.dtype,
                                   rng=init_key)(x)

        if self.architecture == 'resnet':
            init_rng, init_key = random.split(init_rng)
            residual = DiscriminatorLayer(fmaps=self.nf(self.res - 2),
                                          kernel=1,
                                          use_bias=False,
                                          down=True,
                                          resample_kernel=self.resample_kernel,
                                          activation='linear',
                                          layer_name='skip',
                                          param_dict=self.param_dict,
                                          lr_multiplier=self.lr_multiplier,
                                          dtype=self.dtype,
                                          rng=init_key)(residual)

            x = (x + residual) * np.sqrt(0.5, dtype=x.dtype)
        return x


class Discriminator(nn.Module):
    """
    Discriminator.

    Attributes:
        resolution (int): Input resolution. Overridden based on dataset.
        num_channels (int): Number of input color channels. Overridden based on dataset.
        c_dim (int): Dimensionality of the labels (c), 0 if no labels. Overridden based on dataset.
        fmap_base (int): Overall multiplier for the number of feature maps.
        fmap_decay (int): Log2 feature map reduction when doubling the resolution.
        fmap_min (int): Minimum number of feature maps in any layer.
        fmap_max (int): Maximum number of feature maps in any layer.
        mapping_layers (int): Number of additional mapping layers for the conditioning labels.
        mapping_fmaps (int): Number of activations in the mapping layers, None = default.
        mapping_lr_multiplier (float): Learning rate multiplier for the mapping layers.
        architecture (str): Architecture: 'orig', 'resnet'.
        activation (str): Activation function: 'relu', 'leaky_relu', etc.
        mbstd_group_size (int): Group size for the minibatch standard deviation layer, None = entire minibatch.
        mbstd_num_features (int): Number of features for the minibatch standard deviation layer, 0 = disable.
        resample_kernel (Tuple): Low-pass filter to apply when resampling activations, None = box filter.
        num_fp16_res (int): Use float16 for the 'num_fp16_res' highest resolutions.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        pretrained (str): Use pretrained model, None for random initialization.
        ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
        dtype (str): Data type.
        rng (jax.random.PRNGKey): PRNG for initialization.
    """
    # Input dimensions.
    resolution: int=1024
    num_channels: int=3
    c_dim: int=0

    # Capacity.
    fmap_base: int=16384
    fmap_decay: int=1
    fmap_min: int=1
    fmap_max: int=512

    # Internal details.
    mapping_layers: int=0
    mapping_fmaps: int=None
    mapping_lr_multiplier: float=0.1
    architecture: str='resnet'
    activation: str='leaky_relu'
    mbstd_group_size: int=None
    mbstd_num_features: int=1
    resample_kernel: Tuple=(1, 3, 3, 1)
    num_fp16_res: int=0
    clip_conv: float=None

    # Pretraining
    pretrained: str=None
    ckpt_dir: str=None

    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    def setup(self):
        self.resolution_ = self.resolution
        self.c_dim_ = self.c_dim
        self.architecture_ = self.architecture
        self.mbstd_group_size_ = self.mbstd_group_size
        self.param_dict = None
        if self.pretrained is not None:
            assert self.pretrained in URLS.keys(), f'Pretrained model not available: {self.pretrained}'
            ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
            self.param_dict = h5py.File(ckpt_file, 'r')['discriminator']
            self.resolution_ = RESOLUTION[self.pretrained]
            self.architecture_ = ARCHITECTURE[self.pretrained]
            self.mbstd_group_size_ = MBSTD_GROUP_SIZE[self.pretrained]
            self.c_dim_ = C_DIM[self.pretrained]

        assert self.architecture in ['orig', 'resnet']

    @nn.compact
    def __call__(self, x, c=None):
        """
        Run Discriminator.

        Args:
            x (tensor): Input image of shape [N, H, W, num_channels].
            c (tensor): Input labels, shape [N, c_dim].

        Returns:
            (tensor): Output tensor of shape [N, 1].
        """
        resolution_log2 = int(np.log2(self.resolution_))
        assert self.resolution_ == 2**resolution_log2 and self.resolution_ >= 4
        def nf(stage): return np.clip(int(self.fmap_base / (2.0 ** (stage * self.fmap_decay))), self.fmap_min, self.fmap_max)
        if self.mapping_fmaps is None:
            mapping_fmaps = nf(0)
        else:
            mapping_fmaps = self.mapping_fmaps

        init_rng = self.rng
        # Label embedding and mapping.
        if self.c_dim_ > 0:
            c = ops.LinearLayer(in_features=self.c_dim_,
                                out_features=mapping_fmaps,
                                lr_multiplier=self.mapping_lr_multiplier,
                                param_dict=self.param_dict,
                                layer_name='label_embedding',
                                dtype=self.dtype,
                                rng=init_rng)(c)

            c = ops.normalize_2nd_moment(c)
            for i in range(self.mapping_layers):
                init_rng, init_key = random.split(init_rng)
                c = ops.LinearLayer(in_features=self.c_dim_,
                                    out_features=mapping_fmaps,
                                    lr_multiplier=self.mapping_lr_multiplier,
                                    param_dict=self.param_dict,
                                    layer_name=f'fc{i}',
                                    dtype=self.dtype,
                                    rng=init_key)(c)

        # Layers for >=8x8 resolutions.
        y = None
        for res in range(resolution_log2, 2, -1):
            res_str = f'block_{2**res}x{2**res}'
            if res == resolution_log2:
                init_rng, init_key = random.split(init_rng)
                x = FromRGBLayer(fmaps=nf(res - 1),
                                 kernel=1,
                                 activation=self.activation,
                                 param_dict=self.param_dict[res_str] if self.param_dict is not None else None,
                                 clip_conv=self.clip_conv,
                                 dtype=self.dtype if res >= resolution_log2 + 1 - self.num_fp16_res else 'float32',
                                 rng=init_key)(x, y)

            init_rng, init_key = random.split(init_rng)
            x = DiscriminatorBlock(res=res,
                                   kernel=3,
                                   resample_kernel=self.resample_kernel,
                                   activation=self.activation,
                                   param_dict=self.param_dict[res_str] if self.param_dict is not None else None,
                                   architecture=self.architecture_,
                                   nf=nf,
                                   clip_conv=self.clip_conv,
                                   dtype=self.dtype if res >= resolution_log2 + 1 - self.num_fp16_res else 'float32',
                                   rng=init_key)(x)

        # Layers for 4x4 resolution.
        dtype = jnp.float32
        x = x.astype(dtype)
        if self.mbstd_num_features > 0:
            x = ops.minibatch_stddev_layer(x, self.mbstd_group_size_, self.mbstd_num_features)
        init_rng, init_key = random.split(init_rng)
        x = DiscriminatorLayer(fmaps=nf(1),
                               kernel=3,
                               use_bias=True,
                               activation=self.activation,
                               layer_name='conv0',
                               param_dict=self.param_dict['block_4x4'] if self.param_dict is not None else None,
                               clip_conv=self.clip_conv,
                               dtype=dtype,
                               rng=init_rng)(x)

        # Switch to NCHW so that the pretrained weights still work after reshaping.
        x = jnp.transpose(x, axes=(0, 3, 1, 2))
        x = jnp.reshape(x, newshape=(-1, x.shape[1] * x.shape[2] * x.shape[3]))

        init_rng, init_key = random.split(init_rng)
        x = ops.LinearLayer(in_features=x.shape[1],
                            out_features=nf(0),
                            activation=self.activation,
                            param_dict=self.param_dict['block_4x4'] if self.param_dict is not None else None,
                            layer_name='fc0',
                            dtype=dtype,
                            rng=init_key)(x)

        # Output layer.
        init_rng, init_key = random.split(init_rng)
        x = ops.LinearLayer(in_features=x.shape[1],
                            out_features=1 if self.c_dim_ == 0 else mapping_fmaps,
                            param_dict=self.param_dict,
                            layer_name='output',
                            dtype=dtype,
                            rng=init_key)(x)

        if self.c_dim_ > 0:
            x = jnp.sum(x * c, axis=1, keepdims=True) / jnp.sqrt(mapping_fmaps)
        return x
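A hedged usage sketch of scoring images with a pretrained discriminator (the dataset name, input values, and init/apply pattern are assumptions; the 32x32 resolution and 10-class conditioning come from the RESOLUTION and C_DIM tables above):

import jax
import jax.numpy as jnp
from stylegan2 import Discriminator

disc = Discriminator(pretrained='cifar10')               # downloads weights on first use
images = jnp.zeros((1, 32, 32, 3))                       # [N, H, W, C] placeholder input
labels = jax.nn.one_hot(jnp.array([3]), num_classes=10)  # [N, c_dim] conditioning label
params = disc.init(jax.random.PRNGKey(0), images, labels)
scores = disc.apply(params, images, labels)              # realness logits, shape [1, 1]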
stylegan2/generator.py
ADDED
@@ -0,0 +1,713 @@
import numpy as np
import jax
from jax import random
import jax.numpy as jnp
import flax.linen as nn
from typing import Any, Tuple, List
import h5py
from . import ops
from stylegan2 import utils


URLS = {'afhqcat': 'https://www.dropbox.com/s/lv1r0bwvg5ta51f/stylegan2_generator_afhqcat.h5?dl=1',
        'afhqdog': 'https://www.dropbox.com/s/px6ply9hv0vdwen/stylegan2_generator_afhqdog.h5?dl=1',
        'afhqwild': 'https://www.dropbox.com/s/p1slbtmzhcnw9q8/stylegan2_generator_afhqwild.h5?dl=1',
        'brecahad': 'https://www.dropbox.com/s/28uykhj0ku6hwg2/stylegan2_generator_brecahad.h5?dl=1',
        'car': 'https://www.dropbox.com/s/67o834b6xfg9x1q/stylegan2_generator_car.h5?dl=1',
        'cat': 'https://www.dropbox.com/s/cu9egc4e74e1nig/stylegan2_generator_cat.h5?dl=1',
        'church': 'https://www.dropbox.com/s/kwvokfwbrhsn58m/stylegan2_generator_church.h5?dl=1',
        'cifar10': 'https://www.dropbox.com/s/h1kmymjzfwwkftk/stylegan2_generator_cifar10.h5?dl=1',
        'ffhq': 'https://www.dropbox.com/s/e8de1peq7p8gq9d/stylegan2_generator_ffhq.h5?dl=1',
        'horse': 'https://www.dropbox.com/s/3e5bimv2d41bc13/stylegan2_generator_horse.h5?dl=1',
        'metfaces': 'https://www.dropbox.com/s/75klr5k6mgm7qdy/stylegan2_generator_metfaces.h5?dl=1'}

RESOLUTION = {'metfaces': 1024,
              'ffhq': 1024,
              'church': 256,
              'cat': 256,
              'horse': 256,
              'car': 512,
              'brecahad': 512,
              'afhqwild': 512,
              'afhqdog': 512,
              'afhqcat': 512,
              'cifar10': 32}

C_DIM = {'metfaces': 0,
         'ffhq': 0,
         'church': 0,
         'cat': 0,
         'horse': 0,
         'car': 0,
         'brecahad': 0,
         'afhqwild': 0,
         'afhqdog': 0,
         'afhqcat': 0,
         'cifar10': 10}

NUM_MAPPING_LAYERS = {'metfaces': 8,
                      'ffhq': 8,
                      'church': 8,
                      'cat': 8,
                      'horse': 8,
                      'car': 8,
                      'brecahad': 8,
                      'afhqwild': 8,
                      'afhqdog': 8,
                      'afhqcat': 8,
                      'cifar10': 2}


class MappingNetwork(nn.Module):
    """
    Mapping Network.

    Attributes:
        z_dim (int): Input latent (Z) dimensionality.
        c_dim (int): Conditioning label (C) dimensionality, 0 = no label.
        w_dim (int): Intermediate latent (W) dimensionality.
        embed_features (int): Label embedding dimensionality, None = same as w_dim.
        layer_features (int): Number of intermediate features in the mapping layers, None = same as w_dim.
        num_ws (int): Number of intermediate latents to output, None = do not broadcast.
        num_layers (int): Number of mapping layers.
        pretrained (str): Which pretrained model to use, None for random initialization.
        param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
        ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
        activation (str): Activation function: 'relu', 'lrelu', etc.
        lr_multiplier (float): Learning rate multiplier for the mapping layers.
        w_avg_beta (float): Decay for tracking the moving average of W during training, None = do not track.
        dtype (str): Data type.
        rng (jax.random.PRNGKey): PRNG for initialization.
    """
    # Dimensionality
    z_dim: int=512
    c_dim: int=0
    w_dim: int=512
    embed_features: int=None
    layer_features: int=512

    # Layers
    num_ws: int=18
    num_layers: int=8

    # Pretrained
    pretrained: str=None
    param_dict: h5py.Group=None
    ckpt_dir: str=None

    # Internal details
    activation: str='leaky_relu'
    lr_multiplier: float=0.01
    w_avg_beta: float=0.995
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    def setup(self):
        self.embed_features_ = self.embed_features
        self.c_dim_ = self.c_dim
        self.layer_features_ = self.layer_features
        self.num_layers_ = self.num_layers
        self.param_dict_ = self.param_dict

        if self.pretrained is not None and self.param_dict is None:
            assert self.pretrained in URLS.keys(), f'Pretrained model not available: {self.pretrained}'
            ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
            self.param_dict_ = h5py.File(ckpt_file, 'r')['mapping_network']
            self.c_dim_ = C_DIM[self.pretrained]
            self.num_layers_ = NUM_MAPPING_LAYERS[self.pretrained]

        if self.embed_features_ is None:
            self.embed_features_ = self.w_dim
        if self.c_dim_ == 0:
            self.embed_features_ = 0
        if self.layer_features_ is None:
            self.layer_features_ = self.w_dim

        if self.param_dict_ is not None and 'w_avg' in self.param_dict_:
            self.w_avg = self.variable('moving_stats', 'w_avg', lambda *_ : jnp.array(self.param_dict_['w_avg']), [self.w_dim])
        else:
            self.w_avg = self.variable('moving_stats', 'w_avg', jnp.zeros, [self.w_dim])

    @nn.compact
    def __call__(self, z, c=None, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False, train=True):
        """
        Run Mapping Network.

        Args:
            z (tensor): Input noise, shape [N, z_dim].
            c (tensor): Input labels, shape [N, c_dim].
            truncation_psi (float): Controls truncation (trading off variation for quality). If 1, truncation is disabled.
            truncation_cutoff (int): Controls truncation. None = disable.
            skip_w_avg_update (bool): If True, skip updating the exponential moving average of W.
            train (bool): Training mode.

        Returns:
            (tensor): Intermediate latent W.
        """
        init_rng = self.rng
        # Embed, normalize, and concat inputs.
        x = None
        if self.z_dim > 0:
            x = ops.normalize_2nd_moment(z.astype(jnp.float32))
        if self.c_dim_ > 0:
            # Conditioning label
            y = ops.LinearLayer(in_features=self.c_dim_,
                                out_features=self.embed_features_,
                                use_bias=True,
                                lr_multiplier=self.lr_multiplier,
                                activation='linear',
                                param_dict=self.param_dict_,
                                layer_name='label_embedding',
                                dtype=self.dtype,
                                rng=init_rng)(c.astype(jnp.float32))

            y = ops.normalize_2nd_moment(y)
            x = jnp.concatenate((x, y), axis=1) if x is not None else y

        # Main layers.
        for i in range(self.num_layers_):
            init_rng, init_key = random.split(init_rng)
            x = ops.LinearLayer(in_features=x.shape[1],
                                out_features=self.layer_features_,
                                use_bias=True,
                                lr_multiplier=self.lr_multiplier,
                                activation=self.activation,
                                param_dict=self.param_dict_,
                                layer_name=f'fc{i}',
                                dtype=self.dtype,
                                rng=init_key)(x)

        # Update moving average of W.
        if self.w_avg_beta is not None and train and not skip_w_avg_update:
            self.w_avg.value = self.w_avg_beta * self.w_avg.value + (1 - self.w_avg_beta) * jnp.mean(x, axis=0)

        # Broadcast.
        if self.num_ws is not None:
            x = jnp.repeat(jnp.expand_dims(x, axis=-2), repeats=self.num_ws, axis=-2)

        # Apply truncation.
        if truncation_psi != 1:
            assert self.w_avg_beta is not None
            if self.num_ws is None or truncation_cutoff is None:
                x = truncation_psi * x + (1 - truncation_psi) * self.w_avg.value
            else:
                # JAX arrays are immutable, so the sliced update uses .at[].set()
                # instead of in-place item assignment.
                x = x.at[:, :truncation_cutoff].set(
                    truncation_psi * x[:, :truncation_cutoff] + (1 - truncation_psi) * self.w_avg.value)

        return x

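For reference, a minimal sketch of the truncation step applied above (shapes and the psi value are illustrative assumptions, not part of the diff):

import jax.numpy as jnp

w = jnp.zeros((4, 16, 512))            # [N, num_ws, w_dim] from the mapping network
w_avg = jnp.zeros((512,))              # tracked moving average of W
psi = 0.7                              # truncation strength; 1.0 disables truncation
w_trunc = psi * w + (1 - psi) * w_avg  # blend each latent toward the average latent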
class SynthesisLayer(nn.Module):
    """
    Synthesis Layer.

    Attributes:
        fmaps (int): Number of output channels of the modulated convolution.
        kernel (int): Kernel size of the modulated convolution.
        layer_idx (int): Layer index. Used to access the latent code for a specific layer.
        res (int): Resolution (log2) of the current layer.
        lr_multiplier (float): Learning rate multiplier.
        up (bool): If True, upsample the spatial resolution.
        activation (str): Activation function: 'relu', 'lrelu', etc.
        use_noise (bool): If True, add spatial-specific noise.
        resample_kernel (Tuple): Kernel that is used for FIR filter.
        fused_modconv (bool): If True, perform modulation, convolution, and demodulation as a single fused operation.
        param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        dtype (str): Data dtype.
        rng (jax.random.PRNGKey): PRNG for initialization.
    """
    fmaps: int
    kernel: int
    layer_idx: int
    res: int
    lr_multiplier: float=1
    up: bool=False
    activation: str='leaky_relu'
    use_noise: bool=True
    resample_kernel: Tuple=(1, 3, 3, 1)
    fused_modconv: bool=False
    param_dict: h5py.Group=None
    clip_conv: float=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    def setup(self):
        if self.param_dict is not None:
            noise_const = jnp.array(self.param_dict['noise_const'], dtype=self.dtype)
        else:
            noise_const = random.normal(self.rng, shape=(1, 2 ** self.res, 2 ** self.res, 1), dtype=self.dtype)
        self.noise_const = self.variable('noise_consts', 'noise_const', lambda *_: noise_const)

    @nn.compact
    def __call__(self, x, dlatents, noise_mode='random', rng=random.PRNGKey(0)):
        """
        Run Synthesis Layer.

        Args:
            x (tensor): Input tensor of the shape [N, H, W, C].
            dlatents (tensor): Intermediate latents (W) of shape [N, num_ws, w_dim].
            noise_mode (str): Noise type.
                              - 'const': Constant noise.
                              - 'random': Random noise.
                              - 'none': No noise.
            rng (jax.random.PRNGKey): PRNG for spatial noise.

        Returns:
            (tensor): Output tensor of shape [N, H', W', fmaps].
        """
        assert noise_mode in ['const', 'random', 'none']

        linear_rng, conv_rng = random.split(self.rng)
        # Affine transformation to obtain style variable.
        s = ops.LinearLayer(in_features=dlatents[:, self.layer_idx].shape[1],
                            out_features=x.shape[3],
                            use_bias=True,
                            bias_init=1,
                            lr_multiplier=self.lr_multiplier,
                            param_dict=self.param_dict,
                            layer_name='affine',
                            dtype=self.dtype,
                            rng=linear_rng)(dlatents[:, self.layer_idx])

        # Noise variables.
        if self.param_dict is None:
            noise_strength = jnp.zeros(())
        else:
            noise_strength = jnp.array(self.param_dict['noise_strength'])
        noise_strength = self.param(name='noise_strength', init_fn=lambda *_ : noise_strength)

        # Weight and bias for convolution operation.
        w_shape = [self.kernel, self.kernel, x.shape[3], self.fmaps]
        w, b = ops.get_weight(w_shape, self.lr_multiplier, True, self.param_dict, 'conv', conv_rng)
        w = self.param(name='weight', init_fn=lambda *_ : w)
        b = self.param(name='bias', init_fn=lambda *_ : b)
        w = ops.equalize_lr_weight(w, self.lr_multiplier)
        b = ops.equalize_lr_bias(b, self.lr_multiplier)

        x = ops.modulated_conv2d_layer(x=x,
                                       w=w,
                                       s=s,
                                       fmaps=self.fmaps,
                                       kernel=self.kernel,
                                       up=self.up,
                                       resample_kernel=self.resample_kernel,
                                       fused_modconv=self.fused_modconv)

        if self.use_noise and noise_mode != 'none':
            if noise_mode == 'const':
                noise = self.noise_const.value
            elif noise_mode == 'random':
                noise = random.normal(rng, shape=(x.shape[0], x.shape[1], x.shape[2], 1), dtype=self.dtype)
            x += noise * noise_strength.astype(self.dtype)
        x += b.astype(x.dtype)
        x = ops.apply_activation(x, activation=self.activation)
        if self.clip_conv is not None:
            x = jnp.clip(x, -self.clip_conv, self.clip_conv)
        return x


class ToRGBLayer(nn.Module):
    """
    To RGB Layer.

    Attributes:
        fmaps (int): Number of output channels of the modulated convolution.
        layer_idx (int): Layer index. Used to access the latent code for a specific layer.
        kernel (int): Kernel size of the modulated convolution.
        lr_multiplier (float): Learning rate multiplier.
        fused_modconv (bool): If True, perform modulation, convolution, and demodulation as a single fused operation.
        param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        dtype (str): Data dtype.
        rng (jax.random.PRNGKey): PRNG for initialization.
    """
    fmaps: int
    layer_idx: int
    kernel: int=1
    lr_multiplier: float=1
    fused_modconv: bool=False
    param_dict: h5py.Group=None
    clip_conv: float=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    @nn.compact
    def __call__(self, x, y, dlatents):
        """
        Run To RGB Layer.

        Args:
            x (tensor): Input tensor of shape [N, H, W, C].
            y (tensor): Image of shape [N, H', W', fmaps].
            dlatents (tensor): Intermediate latents (W) of shape [N, num_ws, w_dim].

        Returns:
            (tensor): Output tensor of shape [N, H', W', fmaps].
        """
        # Affine transformation to obtain style variable.
        s = ops.LinearLayer(in_features=dlatents[:, self.layer_idx].shape[1],
                            out_features=x.shape[3],
                            use_bias=True,
                            bias_init=1,
                            lr_multiplier=self.lr_multiplier,
                            param_dict=self.param_dict,
                            layer_name='affine',
                            dtype=self.dtype,
                            rng=self.rng)(dlatents[:, self.layer_idx])

        # Weight and bias for convolution operation.
        w_shape = [self.kernel, self.kernel, x.shape[3], self.fmaps]
        w, b = ops.get_weight(w_shape, self.lr_multiplier, True, self.param_dict, 'conv', self.rng)
        w = self.param(name='weight', init_fn=lambda *_ : w)
        b = self.param(name='bias', init_fn=lambda *_ : b)
        w = ops.equalize_lr_weight(w, self.lr_multiplier)
        b = ops.equalize_lr_bias(b, self.lr_multiplier)

        x = ops.modulated_conv2d_layer(x, w, s, fmaps=self.fmaps, kernel=self.kernel, demodulate=False, fused_modconv=self.fused_modconv)
        x += b.astype(x.dtype)
        x = ops.apply_activation(x, activation='linear')
        if self.clip_conv is not None:
            x = jnp.clip(x, -self.clip_conv, self.clip_conv)
        if y is not None:
            x += y.astype(jnp.float32)
        return x


class SynthesisBlock(nn.Module):
    """
+
Synthesis Block.
|
| 379 |
+
|
| 380 |
+
Attributes:
|
| 381 |
+
fmaps (int): Number of output channels of the modulated convolution.
|
| 382 |
+
res (int): Resolution (log2) of the current block.
|
| 383 |
+
num_layers (int): Number of layers in the current block.
|
| 384 |
+
num_channels (int): Number of output color channels.
|
| 385 |
+
lr_multiplier (float): Learning rate multiplier.
|
| 386 |
+
activation (str): Activation function: 'relu', 'lrelu', etc.
|
| 387 |
+
use_noise (bool): If True, add spatial-specific noise.
|
| 388 |
+
resample_kernel (Tuple): Kernel that is used for FIR filter.
|
| 389 |
+
fused_modconv (bool): If True, Perform modulation, convolution, and demodulation as a single fused operation.
|
| 390 |
+
param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
|
| 391 |
+
clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
|
| 392 |
+
dtype (str): Data dtype.
|
| 393 |
+
rng (jax.random.PRNGKey): PRNG for initialization.
|
| 394 |
+
"""
|
| 395 |
+
fmaps: int
|
| 396 |
+
res: int
|
| 397 |
+
num_layers: int=2
|
| 398 |
+
num_channels: int=3
|
| 399 |
+
lr_multiplier: float=1
|
| 400 |
+
activation: str='leaky_relu'
|
| 401 |
+
use_noise: bool=True
|
| 402 |
+
resample_kernel: Tuple=(1, 3, 3, 1)
|
| 403 |
+
fused_modconv: bool=False
|
| 404 |
+
param_dict: h5py.Group=None
|
| 405 |
+
clip_conv: float=None
|
| 406 |
+
dtype: str='float32'
|
| 407 |
+
rng: Any=random.PRNGKey(0)
|
| 408 |
+
|
| 409 |
+
@nn.compact
|
| 410 |
+
def __call__(self, x, y, dlatents_in, noise_mode='random', rng=random.PRNGKey(0)):
|
| 411 |
+
"""
|
| 412 |
+
Run Synthesis Block.
|
| 413 |
+
|
| 414 |
+
Args:
|
| 415 |
+
x (tensor): Input tensor of shape [N, H, W, C].
|
| 416 |
+
y (tensor): Image of shape [N, H', W', fmaps].
|
| 417 |
+
dlatents (tensor): Intermediate latents (W) of shape [N, num_ws, w_dim].
|
| 418 |
+
noise_mode (str): Noise type.
|
| 419 |
+
- 'const': Constant noise.
|
| 420 |
+
- 'random': Random noise.
|
| 421 |
+
- 'none': No noise.
|
| 422 |
+
rng (jax.random.PRNGKey): PRNG for spatialwise noise.
|
| 423 |
+
|
| 424 |
+
Returns:
|
| 425 |
+
(tensor): Output tensor of shape [N, H', W', fmaps].
|
| 426 |
+
"""
|
| 427 |
+
x = x.astype(self.dtype)
|
| 428 |
+
init_rng = self.rng
|
| 429 |
+
for i in range(self.num_layers):
|
| 430 |
+
init_rng, init_key = random.split(init_rng)
|
| 431 |
+
x = SynthesisLayer(fmaps=self.fmaps,
|
| 432 |
+
kernel=3,
|
| 433 |
+
layer_idx=self.res * 2 - (5 - i) if self.res > 2 else 0,
|
| 434 |
+
res=self.res,
|
| 435 |
+
lr_multiplier=self.lr_multiplier,
|
| 436 |
+
up=i == 0 and self.res != 2,
|
| 437 |
+
activation=self.activation,
|
| 438 |
+
use_noise=self.use_noise,
|
| 439 |
+
resample_kernel=self.resample_kernel,
|
| 440 |
+
fused_modconv=self.fused_modconv,
|
| 441 |
+
param_dict=self.param_dict[f'layer{i}'] if self.param_dict is not None else None,
|
| 442 |
+
dtype=self.dtype,
|
| 443 |
+
rng=init_key)(x, dlatents_in, noise_mode, rng)
|
| 444 |
+
|
| 445 |
+
if self.num_layers == 2:
|
| 446 |
+
k = ops.setup_filter(self.resample_kernel)
|
| 447 |
+
y = ops.upsample2d(y, f=k, up=2)
|
| 448 |
+
|
| 449 |
+
init_rng, init_key = random.split(init_rng)
|
| 450 |
+
y = ToRGBLayer(fmaps=self.num_channels,
|
| 451 |
+
layer_idx=self.res * 2 - 3,
|
| 452 |
+
lr_multiplier=self.lr_multiplier,
|
| 453 |
+
param_dict=self.param_dict['torgb'] if self.param_dict is not None else None,
|
| 454 |
+
dtype=self.dtype,
|
| 455 |
+
rng=init_key)(x, y, dlatents_in)
|
| 456 |
+
return x, y
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
class SynthesisNetwork(nn.Module):
|
| 460 |
+
"""
|
| 461 |
+
Synthesis Network.
|
| 462 |
+
|
| 463 |
+
Attributes:
|
| 464 |
+
resolution (int): Output resolution.
|
| 465 |
+
num_channels (int): Number of output color channels.
|
| 466 |
+
w_dim (int): Input latent (Z) dimensionality.
|
| 467 |
+
fmap_base (int): Overall multiplier for the number of feature maps.
|
| 468 |
+
fmap_decay (int): Log2 feature map reduction when doubling the resolution.
|
| 469 |
+
fmap_min (int): Minimum number of feature maps in any layer.
|
| 470 |
+
fmap_max (int): Maximum number of feature maps in any layer.
|
| 471 |
+
fmap_const (int): Number of feature maps in the constant input layer. None = default.
|
| 472 |
+
pretrained (str): Which pretrained model to use, None for random initialization.
|
| 473 |
+
param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
|
| 474 |
+
ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
|
| 475 |
+
activation (str): Activation function: 'relu', 'lrelu', etc.
|
| 476 |
+
use_noise (bool): If True, add spatial-specific noise.
|
| 477 |
+
resample_kernel (Tuple): Kernel that is used for FIR filter.
|
| 478 |
+
fused_modconv (bool): If True, Perform modulation, convolution, and demodulation as a single fused operation.
|
| 479 |
+
num_fp16_res (int): Use float16 for the 'num_fp16_res' highest resolutions.
|
| 480 |
+
clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
|
| 481 |
+
dtype (str): Data type.
|
| 482 |
+
rng (jax.random.PRNGKey): PRNG for initialization.
|
| 483 |
+
"""
|
| 484 |
+
# Dimensionality
|
| 485 |
+
resolution: int=1024
|
| 486 |
+
num_channels: int=3
|
| 487 |
+
w_dim: int=512
|
| 488 |
+
|
| 489 |
+
# Capacity
|
| 490 |
+
fmap_base: int=16384
|
| 491 |
+
fmap_decay: int=1
|
| 492 |
+
fmap_min: int=1
|
| 493 |
+
fmap_max: int=512
|
| 494 |
+
fmap_const: int=None
|
| 495 |
+
|
| 496 |
+
# Pretraining
|
| 497 |
+
pretrained: str=None
|
| 498 |
+
param_dict: h5py.Group=None
|
| 499 |
+
ckpt_dir: str=None
|
| 500 |
+
|
| 501 |
+
# Internal details
|
| 502 |
+
activation: str='leaky_relu'
|
| 503 |
+
use_noise: bool=True
|
| 504 |
+
resample_kernel: Tuple=(1, 3, 3, 1)
|
| 505 |
+
fused_modconv: bool=False
|
| 506 |
+
num_fp16_res: int=0
|
| 507 |
+
clip_conv: float=None
|
| 508 |
+
dtype: str='float32'
|
| 509 |
+
rng: Any=random.PRNGKey(0)
|
| 510 |
+
|
| 511 |
+
def setup(self):
|
| 512 |
+
self.resolution_ = self.resolution
|
| 513 |
+
self.param_dict_ = self.param_dict
|
| 514 |
+
if self.pretrained is not None and self.param_dict is None:
|
| 515 |
+
assert self.pretrained in URLS.keys(), f'Pretrained model not available: {self.pretrained}'
|
| 516 |
+
ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
|
| 517 |
+
self.param_dict_ = h5py.File(ckpt_file, 'r')['synthesis_network']
|
| 518 |
+
self.resolution_ = RESOLUTION[self.pretrained]
|
| 519 |
+
|
| 520 |
+
@nn.compact
|
| 521 |
+
def __call__(self, dlatents_in, noise_mode='random', rng=random.PRNGKey(0)):
|
| 522 |
+
"""
|
| 523 |
+
Run Synthesis Network.
|
| 524 |
+
|
| 525 |
+
Args:
|
| 526 |
+
dlatents_in (tensor): Intermediate latents (W) of shape [N, num_ws, w_dim].
|
| 527 |
+
noise_mode (str): Noise type.
|
| 528 |
+
- 'const': Constant noise.
|
| 529 |
+
- 'random': Random noise.
|
| 530 |
+
- 'none': No noise.
|
| 531 |
+
rng (jax.random.PRNGKey): PRNG for spatialwise noise.
|
| 532 |
+
|
| 533 |
+
Returns:
|
| 534 |
+
(tensor): Image of shape [N, H, W, num_channels].
|
| 535 |
+
"""
|
| 536 |
+
resolution_log2 = int(np.log2(self.resolution_))
|
| 537 |
+
assert self.resolution_ == 2 ** resolution_log2 and self.resolution_ >= 4
|
| 538 |
+
|
| 539 |
+
def nf(stage): return np.clip(int(self.fmap_base / (2.0 ** (stage * self.fmap_decay))), self.fmap_min, self.fmap_max)
|
| 540 |
+
num_layers = resolution_log2 * 2 - 2
|
| 541 |
+
|
| 542 |
+
fmaps = self.fmap_const if self.fmap_const is not None else nf(1)
|
| 543 |
+
|
| 544 |
+
if self.param_dict_ is None:
|
| 545 |
+
const = random.normal(self.rng, (1, 4, 4, fmaps), dtype=self.dtype)
|
| 546 |
+
else:
|
| 547 |
+
const = jnp.array(self.param_dict_['const'], dtype=self.dtype)
|
| 548 |
+
x = self.param(name='const', init_fn=lambda *_ : const)
|
| 549 |
+
x = jnp.repeat(x, repeats=dlatents_in.shape[0], axis=0)
|
| 550 |
+
|
| 551 |
+
y = None
|
| 552 |
+
|
| 553 |
+
dlatents_in = dlatents_in.astype(jnp.float32)
|
| 554 |
+
|
| 555 |
+
init_rng = self.rng
|
| 556 |
+
for res in range(2, resolution_log2 + 1):
|
| 557 |
+
init_rng, init_key = random.split(init_rng)
|
| 558 |
+
x, y = SynthesisBlock(fmaps=nf(res - 1),
|
| 559 |
+
res=res,
|
| 560 |
+
num_layers=1 if res == 2 else 2,
|
| 561 |
+
num_channels=self.num_channels,
|
| 562 |
+
activation=self.activation,
|
| 563 |
+
use_noise=self.use_noise,
|
| 564 |
+
resample_kernel=self.resample_kernel,
|
| 565 |
+
fused_modconv=self.fused_modconv,
|
| 566 |
+
param_dict=self.param_dict_[f'block_{2 ** res}x{2 ** res}'] if self.param_dict_ is not None else None,
|
| 567 |
+
clip_conv=self.clip_conv,
|
| 568 |
+
dtype=self.dtype if res > resolution_log2 - self.num_fp16_res else 'float32',
|
| 569 |
+
rng=init_key)(x, y, dlatents_in, noise_mode, rng)
|
| 570 |
+
|
| 571 |
+
return y
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
class Generator(nn.Module):
|
| 575 |
+
"""
|
| 576 |
+
Generator.
|
| 577 |
+
|
| 578 |
+
Attributes:
|
| 579 |
+
resolution (int): Output resolution.
|
| 580 |
+
num_channels (int): Number of output color channels.
|
| 581 |
+
z_dim (int): Input latent (Z) dimensionality.
|
| 582 |
+
c_dim (int): Conditioning label (C) dimensionality, 0 = no label.
|
| 583 |
+
w_dim (int): Intermediate latent (W) dimensionality.
|
| 584 |
+
mapping_layer_features (int): Number of intermediate features in the mapping layers, None = same as w_dim.
|
| 585 |
+
mapping_embed_features (int): Label embedding dimensionality, None = same as w_dim.
|
| 586 |
+
num_ws (int): Number of intermediate latents to output, None = do not broadcast.
|
| 587 |
+
num_mapping_layers (int): Number of mapping layers.
|
| 588 |
+
fmap_base (int): Overall multiplier for the number of feature maps.
|
| 589 |
+
fmap_decay (int): Log2 feature map reduction when doubling the resolution.
|
| 590 |
+
fmap_min (int): Minimum number of feature maps in any layer.
|
| 591 |
+
fmap_max (int): Maximum number of feature maps in any layer.
|
| 592 |
+
fmap_const (int): Number of feature maps in the constant input layer. None = default.
|
| 593 |
+
pretrained (str): Which pretrained model to use, None for random initialization.
|
| 594 |
+
ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
|
| 595 |
+
use_noise (bool): If True, add spatial-specific noise.
|
| 596 |
+
activation (str): Activation function: 'relu', 'lrelu', etc.
|
| 597 |
+
w_avg_beta (float): Decay for tracking the moving average of W during training, None = do not track.
|
| 598 |
+
mapping_lr_multiplier (float): Learning rate multiplier for the mapping network.
|
| 599 |
+
resample_kernel (Tuple): Kernel that is used for FIR filter.
|
| 600 |
+
fused_modconv (bool): If True, Perform modulation, convolution, and demodulation as a single fused operation.
|
| 601 |
+
num_fp16_res (int): Use float16 for the 'num_fp16_res' highest resolutions.
|
| 602 |
+
clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
|
| 603 |
+
dtype (str): Data type.
|
| 604 |
+
rng (jax.random.PRNGKey): PRNG for initialization.
|
| 605 |
+
"""
|
| 606 |
+
# Dimensionality
|
| 607 |
+
resolution: int=1024
|
| 608 |
+
num_channels: int=3
|
| 609 |
+
z_dim: int=512
|
| 610 |
+
c_dim: int=0
|
| 611 |
+
w_dim: int=512
|
| 612 |
+
mapping_layer_features: int=512
|
| 613 |
+
mapping_embed_features: int=None
|
| 614 |
+
|
| 615 |
+
# Layers
|
| 616 |
+
num_ws: int=18
|
| 617 |
+
num_mapping_layers: int=8
|
| 618 |
+
|
| 619 |
+
# Capacity
|
| 620 |
+
fmap_base: int=16384
|
| 621 |
+
fmap_decay: int=1
|
| 622 |
+
fmap_min: int=1
|
| 623 |
+
fmap_max: int=512
|
| 624 |
+
fmap_const: int=None
|
| 625 |
+
|
| 626 |
+
# Pretraining
|
| 627 |
+
pretrained: str=None
|
| 628 |
+
ckpt_dir: str=None
|
| 629 |
+
|
| 630 |
+
# Internal details
|
| 631 |
+
use_noise: bool=True
|
| 632 |
+
activation: str='leaky_relu'
|
| 633 |
+
w_avg_beta: float=0.995
|
| 634 |
+
mapping_lr_multiplier: float=0.01
|
| 635 |
+
resample_kernel: Tuple=(1, 3, 3, 1)
|
| 636 |
+
fused_modconv: bool=False
|
| 637 |
+
num_fp16_res: int=0
|
| 638 |
+
clip_conv: float=None
|
| 639 |
+
dtype: str='float32'
|
| 640 |
+
rng: Any=random.PRNGKey(0)
|
| 641 |
+
|
| 642 |
+
def setup(self):
|
| 643 |
+
self.resolution_ = self.resolution
|
| 644 |
+
self.c_dim_ = self.c_dim
|
| 645 |
+
self.num_mapping_layers_ = self.num_mapping_layers
|
| 646 |
+
if self.pretrained is not None:
|
| 647 |
+
assert self.pretrained in URLS.keys(), f'Pretrained model not available: {self.pretrained}'
|
| 648 |
+
ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
|
| 649 |
+
self.param_dict = h5py.File(ckpt_file, 'r')
|
| 650 |
+
self.resolution_ = RESOLUTION[self.pretrained]
|
| 651 |
+
self.c_dim_ = C_DIM[self.pretrained]
|
| 652 |
+
self.num_mapping_layers_ = NUM_MAPPING_LAYERS[self.pretrained]
|
| 653 |
+
else:
|
| 654 |
+
self.param_dict = None
|
| 655 |
+
self.init_rng_mapping, self.init_rng_synthesis = random.split(self.rng)
|
| 656 |
+
|
| 657 |
+
@nn.compact
|
| 658 |
+
def __call__(self, z, c=None, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False, train=True, noise_mode='random', rng=random.PRNGKey(0)):
|
| 659 |
+
"""
|
| 660 |
+
Run Generator.
|
| 661 |
+
|
| 662 |
+
Args:
|
| 663 |
+
z (tensor): Input noise, shape [N, z_dim].
|
| 664 |
+
c (tensor): Input labels, shape [N, c_dim].
|
| 665 |
+
truncation_psi (float): Controls truncation (trading off variation for quality). If 1, truncation is disabled.
|
| 666 |
+
truncation_cutoff (int): Controls truncation. None = disable.
|
| 667 |
+
skip_w_avg_update (bool): If True, updates the exponential moving average of W.
|
| 668 |
+
train (bool): Training mode.
|
| 669 |
+
noise_mode (str): Noise type.
|
| 670 |
+
- 'const': Constant noise.
|
| 671 |
+
- 'random': Random noise.
|
| 672 |
+
- 'none': No noise.
|
| 673 |
+
rng (jax.random.PRNGKey): PRNG for spatialwise noise.
|
| 674 |
+
|
| 675 |
+
Returns:
|
| 676 |
+
(tensor): Image of shape [N, H, W, num_channels].
|
| 677 |
+
"""
|
| 678 |
+
dlatents_in = MappingNetwork(z_dim=self.z_dim,
|
| 679 |
+
c_dim=self.c_dim_,
|
| 680 |
+
w_dim=self.w_dim,
|
| 681 |
+
num_ws=self.num_ws,
|
| 682 |
+
num_layers=self.num_mapping_layers_,
|
| 683 |
+
embed_features=self.mapping_embed_features,
|
| 684 |
+
layer_features=self.mapping_layer_features,
|
| 685 |
+
activation=self.activation,
|
| 686 |
+
lr_multiplier=self.mapping_lr_multiplier,
|
| 687 |
+
w_avg_beta=self.w_avg_beta,
|
| 688 |
+
param_dict=self.param_dict['mapping_network'] if self.param_dict is not None else None,
|
| 689 |
+
dtype=self.dtype,
|
| 690 |
+
rng=self.init_rng_mapping,
|
| 691 |
+
name='mapping_network')(z, c, truncation_psi, truncation_cutoff, skip_w_avg_update, train)
|
| 692 |
+
|
| 693 |
+
x = SynthesisNetwork(resolution=self.resolution_,
|
| 694 |
+
num_channels=self.num_channels,
|
| 695 |
+
w_dim=self.w_dim,
|
| 696 |
+
fmap_base=self.fmap_base,
|
| 697 |
+
fmap_decay=self.fmap_decay,
|
| 698 |
+
fmap_min=self.fmap_min,
|
| 699 |
+
fmap_max=self.fmap_max,
|
| 700 |
+
fmap_const=self.fmap_const,
|
| 701 |
+
param_dict=self.param_dict['synthesis_network'] if self.param_dict is not None else None,
|
| 702 |
+
activation=self.activation,
|
| 703 |
+
use_noise=self.use_noise,
|
| 704 |
+
resample_kernel=self.resample_kernel,
|
| 705 |
+
fused_modconv=self.fused_modconv,
|
| 706 |
+
num_fp16_res=self.num_fp16_res,
|
| 707 |
+
clip_conv=self.clip_conv,
|
| 708 |
+
dtype=self.dtype,
|
| 709 |
+
rng=self.init_rng_synthesis,
|
| 710 |
+
name='synthesis_network')(dlatents_in, noise_mode, rng)
|
| 711 |
+
|
| 712 |
+
return x
|
| 713 |
+
|
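
# --- Usage sketch (added for illustration; not part of the original file) ---
# A minimal example of sampling from the Generator with random weights,
# assuming the package __init__ re-exports Generator. The exact init/apply
# arguments depend on the variable collections used by the mapping and
# synthesis networks (e.g. 'moving_stats', 'noise_consts'), so treat this as
# a sketch rather than a tested snippet:
#
#   import jax
#   from stylegan2 import Generator
#
#   generator = Generator(resolution=256)
#   key = jax.random.PRNGKey(0)
#   z = jax.random.normal(key, (1, generator.z_dim))
#   variables = generator.init(key, z, train=False)
#   images = generator.apply(variables, z, train=False)   # [1, 256, 256, 3]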
stylegan2/ops.py
ADDED
@@ -0,0 +1,674 @@
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
from jax import jit
import numpy as np
from functools import partial
from typing import Any
import h5py


#------------------------------------------------------
# Other
#------------------------------------------------------
def minibatch_stddev_layer(x, group_size=None, num_new_features=1):
    if group_size is None:
        group_size = x.shape[0]
    else:
        # Minibatch must be divisible by (or smaller than) group_size.
        group_size = min(group_size, x.shape[0])

    G = group_size
    F = num_new_features
    _, H, W, C = x.shape
    c = C // F

    # [NHWC] Cast to FP32.
    y = x.astype(jnp.float32)
    # [GnHWFc] Split minibatch N into n groups of size G, and channels C into F groups of size c.
    y = jnp.reshape(y, newshape=(G, -1, H, W, F, c))
    # [GnHWFc] Subtract mean over group.
    y -= jnp.mean(y, axis=0)
    # [nHWFc] Calc variance over group.
    y = jnp.mean(jnp.square(y), axis=0)
    # [nHWFc] Calc stddev over group.
    y = jnp.sqrt(y + 1e-8)
    # [nF] Take average over channels and pixels.
    y = jnp.mean(y, axis=(1, 2, 4))
    # [nF] Cast back to original data type.
    y = y.astype(x.dtype)
    # [n11F] Add missing dimensions.
    y = jnp.reshape(y, newshape=(-1, 1, 1, F))
    # [NHWC] Replicate over group and pixels.
    y = jnp.tile(y, (G, H, W, 1))
    return jnp.concatenate((x, y), axis=3)
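
# Shape walk-through for minibatch_stddev_layer (illustrative comment, not
# part of the original file): with a batch of 8 feature maps and group_size=4,
# one stddev feature map is appended per example.
#
#   x = jax.random.normal(random.PRNGKey(0), (8, 4, 4, 64))
#   y = minibatch_stddev_layer(x, group_size=4, num_new_features=1)
#   # y.shape == (8, 4, 4, 65): 64 input channels + 1 stddev channel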


#------------------------------------------------------
# Activation
#------------------------------------------------------
def apply_activation(x, activation='linear', alpha=0.2, gain=np.sqrt(2)):
    gain = jnp.array(gain, dtype=x.dtype)
    if activation == 'relu':
        return jax.nn.relu(x) * gain
    if activation == 'leaky_relu':
        return jax.nn.leaky_relu(x, negative_slope=alpha) * gain
    return x


#------------------------------------------------------
# Weights
#------------------------------------------------------
def get_weight(shape, lr_multiplier=1, bias=True, param_dict=None, layer_name='', key=None):
    if param_dict is None:
        w = random.normal(key, shape=shape, dtype=jnp.float32) / lr_multiplier
        if bias: b = jnp.zeros(shape=(shape[-1],), dtype=jnp.float32)
    else:
        w = jnp.array(param_dict[layer_name]['weight']).astype(jnp.float32)
        if bias: b = jnp.array(param_dict[layer_name]['bias']).astype(jnp.float32)

    if bias: return w, b
    return w


def equalize_lr_weight(w, lr_multiplier=1):
    """
    Equalized learning rate, see: https://arxiv.org/pdf/1710.10196.pdf.

    Args:
        w (tensor): Weight parameter. Shape [kernel, kernel, fmaps_in, fmaps_out]
                    for convolutions and shape [in, out] for MLPs.
        lr_multiplier (float): Learning rate multiplier.

    Returns:
        (tensor): Scaled weight parameter.
    """
    in_features = np.prod(w.shape[:-1])
    gain = lr_multiplier / np.sqrt(in_features)
    w *= gain
    return w


def equalize_lr_bias(b, lr_multiplier=1):
    """
    Equalized learning rate, see: https://arxiv.org/pdf/1710.10196.pdf.

    Args:
        b (tensor): Bias parameter.
        lr_multiplier (float): Learning rate multiplier.

    Returns:
        (tensor): Scaled bias parameter.
    """
    gain = lr_multiplier
    b *= gain
    return b
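
# Worked example of the equalized learning rate scaling above (illustrative
# comment, not part of the original file): parameters are stored at unit scale
# and rescaled on every forward pass, so the effective weight is
# w_stored * lr_multiplier / sqrt(fan_in).
#
#   w = jnp.ones((3, 3, 64, 128))                  # fan_in = 3 * 3 * 64 = 576
#   w_eff = equalize_lr_weight(w, lr_multiplier=1)
#   # every entry becomes 1 / sqrt(576) = 1 / 24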
#------------------------------------------------------
# Normalization
#------------------------------------------------------
def normalize_2nd_moment(x, eps=1e-8):
    return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=1, keepdims=True) + eps)


#------------------------------------------------------
# Upsampling
#------------------------------------------------------
def setup_filter(f, normalize=True, flip_filter=False, gain=1, separable=None):
    """
    Convenience function to setup 2D FIR filter for `upfirdn2d()`.

    Args:
        f (tensor): Tensor or python list of the shape [filter_height, filter_width] or [filter_taps].
        normalize (bool): Normalize the filter so that it retains the magnitude
                          for constant input signal (DC)? (default: True).
        flip_filter (bool): Flip the filter? (default: False).
        gain (int): Overall scaling factor for signal magnitude (default: 1).
        separable: Return a separable filter? (default: select automatically).

    Returns:
        (tensor): Output filter of shape [filter_height, filter_width] or [filter_taps].
    """
    # Validate.
    if f is None:
        f = 1
    f = jnp.array(f, dtype=jnp.float32)
    assert f.ndim in [0, 1, 2]
    assert f.size > 0
    if f.ndim == 0:
        f = f[jnp.newaxis]

    # Separable?
    if separable is None:
        separable = (f.ndim == 1 and f.size >= 8)
    if f.ndim == 1 and not separable:
        f = jnp.outer(f, f)
    assert f.ndim == (1 if separable else 2)

    # Apply normalize, flip, and gain.
    if normalize:
        f /= jnp.sum(f)
    if flip_filter:
        for i in range(f.ndim):
            f = jnp.flip(f, axis=i)
    f = f * (gain ** (f.ndim / 2))
    return f


def upfirdn2d(x, f, padding=(2, 1, 2, 1), up=1, down=1, strides=(1, 1), flip_filter=False, gain=1):

    if f is None:
        f = jnp.ones((1, 1), dtype=jnp.float32)

    B, H, W, C = x.shape
    padx0, padx1, pady0, pady1 = padding

    # Upsample by inserting zeros.
    x = jnp.reshape(x, newshape=(B, H, 1, W, 1, C))
    x = jnp.pad(x, pad_width=((0, 0), (0, 0), (0, up - 1), (0, 0), (0, up - 1), (0, 0)))
    x = jnp.reshape(x, newshape=(B, H * up, W * up, C))

    # Padding.
    x = jnp.pad(x, pad_width=((0, 0), (max(pady0, 0), max(pady1, 0)), (max(padx0, 0), max(padx1, 0)), (0, 0)))
    x = x[:, max(-pady0, 0) : x.shape[1] - max(-pady1, 0), max(-padx0, 0) : x.shape[2] - max(-padx1, 0)]

    # Setup filter.
    f = f * (gain ** (f.ndim / 2))
    if not flip_filter:
        for i in range(f.ndim):
            f = jnp.flip(f, axis=i)

    # Convolve with the filter.
    f = jnp.repeat(jnp.expand_dims(f, axis=(-2, -1)), repeats=C, axis=-1)
    if f.ndim == 4:
        x = jax.lax.conv_general_dilated(x,
                                         f.astype(x.dtype),
                                         window_strides=strides or (1,) * (x.ndim - 2),
                                         padding='valid',
                                         dimension_numbers=nn.linear._conv_dimension_numbers(x.shape),
                                         feature_group_count=C)
    else:
        x = jax.lax.conv_general_dilated(x,
                                         jnp.expand_dims(f, axis=0).astype(x.dtype),
                                         window_strides=strides or (1,) * (x.ndim - 2),
                                         padding='valid',
                                         dimension_numbers=nn.linear._conv_dimension_numbers(x.shape),
                                         feature_group_count=C)
        x = jax.lax.conv_general_dilated(x,
                                         jnp.expand_dims(f, axis=1).astype(x.dtype),
                                         window_strides=strides or (1,) * (x.ndim - 2),
                                         padding='valid',
                                         dimension_numbers=nn.linear._conv_dimension_numbers(x.shape),
                                         feature_group_count=C)
    x = x[:, ::down, ::down]
    return x


def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1):
    if f.ndim == 1:
        fh, fw = f.shape[0], f.shape[0]
    elif f.ndim == 2:
        fh, fw = f.shape[0], f.shape[1]
    else:
        raise ValueError('Invalid filter shape:', f.shape)
    padx0 = padding + (fw + up - 1) // 2
    padx1 = padding + (fw - up) // 2
    pady0 = padding + (fh + up - 1) // 2
    pady1 = padding + (fh - up) // 2
    return upfirdn2d(x, f=f, up=up, padding=(padx0, padx1, pady0, pady1), flip_filter=flip_filter, gain=gain * up * up)
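
# Usage sketch for the FIR upsampling ops above (illustrative comment, not
# part of the original file): the (1, 3, 3, 1) tap vector becomes a normalized
# 4x4 separable filter, and upsample2d doubles the spatial resolution.
#
#   k = setup_filter((1, 3, 3, 1))                 # shape (4, 4), entries sum to 1
#   x = jax.random.normal(random.PRNGKey(0), (2, 16, 16, 3))
#   y = upsample2d(x, f=k, up=2)                   # y.shape == (2, 32, 32, 3)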
#------------------------------------------------------
# Linear
#------------------------------------------------------
class LinearLayer(nn.Module):
    """
    Linear Layer.

    Attributes:
        in_features (int): Input dimension.
        out_features (int): Output dimension.
        use_bias (bool): If True, use bias.
        bias_init (int): Bias init.
        lr_multiplier (float): Learning rate multiplier.
        activation (str): Activation function: 'relu', 'lrelu', etc.
        param_dict (h5py.Group): Parameter dict with pretrained parameters.
        layer_name (str): Layer name.
        dtype (str): Data type.
        rng (jax.random.PRNGKey): Random seed for initialization.
    """
    in_features: int
    out_features: int
    use_bias: bool=True
    bias_init: int=0
    lr_multiplier: float=1
    activation: str='linear'
    param_dict: h5py.Group=None
    layer_name: str=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    @nn.compact
    def __call__(self, x):
        """
        Run Linear Layer.

        Args:
            x (tensor): Input tensor of shape [N, in_features].

        Returns:
            (tensor): Output tensor of shape [N, out_features].
        """
        w_shape = [self.in_features, self.out_features]
        params = get_weight(w_shape, self.lr_multiplier, self.use_bias, self.param_dict, self.layer_name, self.rng)

        if self.use_bias:
            w, b = params
        else:
            w = params

        w = self.param(name='weight', init_fn=lambda *_ : w)
        w = equalize_lr_weight(w, self.lr_multiplier)
        x = jnp.matmul(x, w.astype(x.dtype))

        if self.use_bias:
            b = self.param(name='bias', init_fn=lambda *_ : b)
            b = equalize_lr_bias(b, self.lr_multiplier)
            x += b.astype(x.dtype)
            x += self.bias_init

        x = apply_activation(x, activation=self.activation)
        return x


#------------------------------------------------------
# Convolution
#------------------------------------------------------
def conv_downsample_2d(x, w, k=None, factor=2, gain=1, padding=0):
    """
    Fused downsample convolution.

    Padding is performed only once at the beginning, not between the operations.
    The fused op is considerably more efficient than performing the same calculation
    using standard TensorFlow ops. It supports gradients of arbitrary order.

    Args:
        x (tensor): Input tensor of the shape [N, H, W, C].
        w (tensor): Weight tensor of the shape [filterH, filterW, inChannels, outChannels].
                    Grouped convolution can be performed by inChannels = x.shape[0] // numGroups.
        k (tensor): FIR filter of the shape [firH, firW] or [firN].
                    The default is `[1] * factor`, which corresponds to average pooling.
        factor (int): Downsampling factor (default: 2).
        gain (float): Scaling factor for signal magnitude (default: 1.0).
        padding (int): Number of pixels to pad or crop the output on each side (default: 0).

    Returns:
        (tensor): Output of the shape [N, H // factor, W // factor, C].
    """
    assert isinstance(factor, int) and factor >= 1
    assert isinstance(padding, int)

    # Check weight shape.
    ch, cw, _inC, _outC = w.shape
    assert cw == ch

    # Setup filter kernel.
    k = setup_filter(k, gain=gain)
    assert k.shape[0] == k.shape[1]

    # Execute.
    pad0 = (k.shape[0] - factor + cw) // 2 + padding * factor
    pad1 = (k.shape[0] - factor + cw - 1) // 2 + padding * factor
    x = upfirdn2d(x=x, f=k, padding=(pad0, pad0, pad1, pad1))

    x = jax.lax.conv_general_dilated(x,
                                     w,
                                     window_strides=(factor, factor),
                                     padding='VALID',
                                     dimension_numbers=nn.linear._conv_dimension_numbers(x.shape))
    return x


def upsample_conv_2d(x, w, k=None, factor=2, gain=1, padding=0):
    """
    Fused upsample convolution.

    Padding is performed only once at the beginning, not between the operations.
    The fused op is considerably more efficient than performing the same calculation
    using standard TensorFlow ops. It supports gradients of arbitrary order.

    Args:
        x (tensor): Input tensor of the shape [N, H, W, C].
        w (tensor): Weight tensor of the shape [filterH, filterW, inChannels, outChannels].
                    Grouped convolution can be performed by inChannels = x.shape[0] // numGroups.
        k (tensor): FIR filter of the shape [firH, firW] or [firN].
                    The default is [1] * factor, which corresponds to nearest-neighbor upsampling.
        factor (int): Integer upsampling factor (default: 2).
        gain (float): Scaling factor for signal magnitude (default: 1.0).
        padding (int): Number of pixels to pad or crop the output on each side (default: 0).

    Returns:
        (tensor): Output of the shape [N, H * factor, W * factor, C].
    """
    assert isinstance(factor, int) and factor >= 1
    assert isinstance(padding, int)

    # Check weight shape.
    ch, cw, _inC, _outC = w.shape
    inC = w.shape[2]
    outC = w.shape[3]
    assert cw == ch

    # Fast path for 1x1 convolution.
    if cw == 1 and ch == 1:
        x = jax.lax.conv_general_dilated(x,
                                         w,
                                         window_strides=(1, 1),
                                         padding='VALID',
                                         dimension_numbers=nn.linear._conv_dimension_numbers(x.shape))
        k = setup_filter(k, gain=gain * (factor ** 2))
        pad0 = (k.shape[0] + factor - cw) // 2 + padding
        pad1 = (k.shape[0] - factor) // 2 + padding
        x = upfirdn2d(x, f=k, up=factor, padding=(pad0, pad1, pad0, pad1))
        return x

    # Setup filter kernel.
    k = setup_filter(k, gain=gain * (factor ** 2))
    assert k.shape[0] == k.shape[1]

    # Determine data dimensions.
    stride = (factor, factor)
    output_shape = ((x.shape[1] - 1) * factor + ch, (x.shape[2] - 1) * factor + cw)
    num_groups = x.shape[3] // inC

    # Transpose weights.
    w = jnp.reshape(w, (ch, cw, inC, num_groups, -1))
    w = jnp.transpose(w[::-1, ::-1], (0, 1, 4, 3, 2))
    w = jnp.reshape(w, (ch, cw, -1, num_groups * inC))

    # Execute.
    x = gradient_based_conv_transpose(lhs=x,
                                      rhs=w,
                                      strides=stride,
                                      padding='VALID',
                                      output_padding=(0, 0, 0, 0),
                                      output_shape=output_shape)

    pad0 = (k.shape[0] + factor - cw) // 2 + padding
    pad1 = (k.shape[0] - factor - cw + 3) // 2 + padding
    x = upfirdn2d(x=x, f=k, padding=(pad0, pad1, pad0, pad1))
    return x


def conv2d(x, w, up=False, down=False, resample_kernel=None, padding=0):
    assert not (up and down)
    kernel = w.shape[0]
    assert w.shape[1] == kernel
    assert kernel >= 1 and kernel % 2 == 1

    num_groups = x.shape[3] // w.shape[2]

    w = w.astype(x.dtype)
    if up:
        x = upsample_conv_2d(x, w, k=resample_kernel, padding=padding)
    elif down:
        x = conv_downsample_2d(x, w, k=resample_kernel, padding=padding)
    else:
        padding_mode = {0: 'SAME', -(kernel // 2): 'VALID'}[padding]
        x = jax.lax.conv_general_dilated(x,
                                         w,
                                         window_strides=(1, 1),
                                         padding=padding_mode,
                                         dimension_numbers=nn.linear._conv_dimension_numbers(x.shape),
                                         feature_group_count=num_groups)
    return x


def modulated_conv2d_layer(x, w, s, fmaps, kernel, up=False, down=False, demodulate=True, resample_kernel=None, fused_modconv=False):
    assert not (up and down)
    assert kernel >= 1 and kernel % 2 == 1

    # Get weight.
    wshape = (kernel, kernel, x.shape[3], fmaps)
    if x.dtype.name == 'float16' and not fused_modconv and demodulate:
        w *= jnp.sqrt(1 / np.prod(wshape[:-1])) / jnp.max(jnp.abs(w), axis=(0, 1, 2))  # Pre-normalize to avoid float16 overflow.
    ww = w[jnp.newaxis]  # [BkkIO] Introduce minibatch dimension.

    # Modulate.
    if x.dtype.name == 'float16' and not fused_modconv and demodulate:
        s *= 1 / jnp.max(jnp.abs(s))  # Pre-normalize to avoid float16 overflow.
    ww *= s[:, jnp.newaxis, jnp.newaxis, :, jnp.newaxis].astype(w.dtype)  # [BkkIO] Scale input feature maps.

    # Demodulate.
    if demodulate:
        d = jax.lax.rsqrt(jnp.sum(jnp.square(ww), axis=(1, 2, 3)) + 1e-8)  # [BO] Scaling factor.
        ww *= d[:, jnp.newaxis, jnp.newaxis, jnp.newaxis, :]  # [BkkIO] Scale output feature maps.

    # Reshape/scale input.
    if fused_modconv:
        x = jnp.transpose(x, axes=(0, 3, 1, 2))
        x = jnp.reshape(x, (1, -1, x.shape[2], x.shape[3]))  # Fused => reshape minibatch to convolution groups.
        x = jnp.transpose(x, axes=(0, 2, 3, 1))
        w = jnp.reshape(jnp.transpose(ww, (1, 2, 3, 0, 4)), (ww.shape[1], ww.shape[2], ww.shape[3], -1))
    else:
        x *= s[:, jnp.newaxis, jnp.newaxis].astype(x.dtype)  # [BIhw] Not fused => scale input activations.

    # 2D convolution.
    x = conv2d(x, w.astype(x.dtype), up=up, down=down, resample_kernel=resample_kernel)

    # Reshape/scale output.
    if fused_modconv:
        x = jnp.transpose(x, axes=(0, 3, 1, 2))
        x = jnp.reshape(x, (-1, fmaps, x.shape[2], x.shape[3]))  # Fused => reshape convolution groups back to minibatch.
        x = jnp.transpose(x, axes=(0, 2, 3, 1))
    elif demodulate:
        x *= d[:, jnp.newaxis, jnp.newaxis].astype(x.dtype)  # [BOhw] Not fused => scale output activations.

    return x
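
# Shape sketch for modulated_conv2d_layer (illustrative comment, not part of
# the original file): per-sample styles s modulate the input channels of a
# shared kernel, and demodulation renormalizes each output feature map.
#
#   x = jax.random.normal(random.PRNGKey(0), (4, 8, 8, 64))    # [N, H, W, C]
#   w = jax.random.normal(random.PRNGKey(1), (3, 3, 64, 128))  # [k, k, in, out]
#   s = jnp.ones((4, 64))                                      # one style vector per sample
#   y = modulated_conv2d_layer(x, w, s, fmaps=128, kernel=3)   # y.shape == (4, 8, 8, 128)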
def _deconv_output_length(input_length, filter_size, padding, output_padding=None, stride=0, dilation=1):
    """
    Taken from: https://github.com/google/jax/pull/5772/commits

    Determines the output length of a transposed convolution given the input length.
    Function modified from Keras.

    Arguments:
        input_length: Integer.
        filter_size: Integer.
        padding: one of `"SAME"`, `"VALID"`, or a 2-integer tuple.
        output_padding: Integer, amount of padding along the output dimension. Can
            be set to `None` in which case the output length is inferred.
        stride: Integer.
        dilation: Integer.

    Returns:
        The output length (integer).
    """
    if input_length is None:
        return None

    # Get the dilated kernel size.
    filter_size = filter_size + (filter_size - 1) * (dilation - 1)

    # Infer length if output padding is None, else compute the exact length.
    if output_padding is None:
        if padding == 'VALID':
            length = input_length * stride + max(filter_size - stride, 0)
        elif padding == 'SAME':
            length = input_length * stride
        else:
            length = ((input_length - 1) * stride + filter_size - padding[0] - padding[1])

    else:
        if padding == 'SAME':
            pad = filter_size // 2
            total_pad = pad * 2
        elif padding == 'VALID':
            total_pad = 0
        else:
            total_pad = padding[0] + padding[1]

        length = ((input_length - 1) * stride + filter_size - total_pad + output_padding)
    return length


def _compute_adjusted_padding(input_size, output_size, kernel_size, stride, padding, dilation=1):
    """
    Taken from: https://github.com/google/jax/pull/5772/commits

    Computes adjusted padding for desired ConvTranspose `output_size`.
    Ported from DeepMind Haiku.
    """
    kernel_size = (kernel_size - 1) * dilation + 1
    if padding == 'VALID':
        expected_input_size = (output_size - kernel_size + stride) // stride
        if input_size != expected_input_size:
            raise ValueError(f'The expected input size with the current set of input '
                             f'parameters is {expected_input_size} which doesn\'t '
                             f'match the actual input size {input_size}.')
        padding_before = 0
    elif padding == 'SAME':
        expected_input_size = (output_size + stride - 1) // stride
        if input_size != expected_input_size:
            raise ValueError(f'The expected input size with the current set of input '
                             f'parameters is {expected_input_size} which doesn\'t '
                             f'match the actual input size {input_size}.')
        padding_needed = max(0, (input_size - 1) * stride + kernel_size - output_size)
        padding_before = padding_needed // 2
    else:
        padding_before = padding[0]  # type: ignore[assignment]

    expanded_input_size = (input_size - 1) * stride + 1
    padded_out_size = output_size + kernel_size - 1
    pad_before = kernel_size - 1 - padding_before
    pad_after = padded_out_size - expanded_input_size - pad_before
    return (pad_before, pad_after)


def _flip_axes(x, axes):
    """
    Taken from: https://github.com/google/jax/blob/master/jax/_src/lax/lax.py

    Flip ndarray 'x' along each axis specified in axes tuple.
    """
    for axis in axes:
        x = jnp.flip(x, axis)
    return x


def gradient_based_conv_transpose(lhs,
                                  rhs,
                                  strides,
                                  padding,
                                  output_padding,
                                  output_shape=None,
                                  dilation=None,
                                  dimension_numbers=None,
                                  transpose_kernel=True,
                                  feature_group_count=1,
                                  precision=None):
    """
    Taken from: https://github.com/google/jax/pull/5772/commits

    Convenience wrapper for calculating the N-d transposed convolution.
    Much like `conv_transpose`, this function calculates transposed convolutions
    via fractionally strided convolution rather than calculating the gradient
    (transpose) of a forward convolution. However, the latter is more common
    among deep learning frameworks, such as TensorFlow, PyTorch, and Keras.
    This function provides the same set of APIs to help reproduce results in these frameworks.

    Args:
        lhs: a rank `n+2` dimensional input array.
        rhs: a rank `n+2` dimensional array of kernel weights.
        strides: sequence of `n` integers, amounts to strides of the corresponding forward convolution.
        padding: `"SAME"`, `"VALID"`, or a sequence of `n` integer 2-tuples that controls
            the before-and-after padding for each `n` spatial dimension of
            the corresponding forward convolution.
        output_padding: A sequence of integers specifying the amount of padding along
            each spatial dimension of the output tensor, used to disambiguate the output shape of
            transposed convolutions when the stride is larger than 1.
            (see a detailed description at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html)
            The amount of output padding along a given dimension must
            be lower than the stride along that same dimension.
            If set to `None` (default), the output shape is inferred.
            If both `output_padding` and `output_shape` are specified, they have to be mutually compatible.
        output_shape: Output shape of the spatial dimensions of a transpose
            convolution. Can be `None` or an iterable of `n` integers. If a `None` value is given (default),
            the shape is automatically calculated.
            Similar to `output_padding`, `output_shape` is also for disambiguating the output shape
            when stride > 1 (see also
            https://www.tensorflow.org/api_docs/python/tf/nn/conv2d_transpose)
            If both `output_padding` and `output_shape` are specified, they have to be mutually compatible.
        dilation: `None`, or a sequence of `n` integers, giving the
            dilation factor to apply in each spatial dimension of `rhs`. Dilated convolution
            is also known as atrous convolution.
        dimension_numbers: tuple of dimension descriptors as in lax.conv_general_dilated. Defaults to tensorflow convention.
        transpose_kernel: if `True` flips spatial axes and swaps the input/output
            channel axes of the kernel. This makes the output of this function identical
            to the gradient-derived functions like keras.layers.Conv2DTranspose and
            torch.nn.ConvTranspose2d applied to the same kernel.
            Although for typical use in neural nets this is unnecessary
            and makes input/output channel specification confusing, you need to set this to `True`
            in order to match the behavior in many deep learning frameworks, such as TensorFlow, Keras, and PyTorch.
        precision: Optional. Either ``None``, which means the default precision for
            the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
            ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
            ``lax.Precision`` enums indicating precision of ``lhs`` and ``rhs``.

    Returns:
        Transposed N-d convolution.
    """
    assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) >= 2
    ndims = len(lhs.shape)
    one = (1,) * (ndims - 2)
    # Set dimensional layout defaults if not specified.
    if dimension_numbers is None:
        if ndims == 2:
            dimension_numbers = ('NC', 'IO', 'NC')
        elif ndims == 3:
            dimension_numbers = ('NHC', 'HIO', 'NHC')
        elif ndims == 4:
            dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
        elif ndims == 5:
            dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC')
        else:
            raise ValueError('No 4+ dimensional dimension_number defaults.')
    dn = jax.lax.conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
    k_shape = np.take(rhs.shape, dn.rhs_spec)
    k_sdims = k_shape[2:]  # type: ignore[index]
    i_shape = np.take(lhs.shape, dn.lhs_spec)
    i_sdims = i_shape[2:]  # type: ignore[index]

    # Calculate correct output shape given padding and strides.
    if dilation is None:
        dilation = (1,) * (rhs.ndim - 2)

    if output_padding is None:
        output_padding = [None] * (rhs.ndim - 2)  # type: ignore[list-item]

    if isinstance(padding, str):
        if padding in {'SAME', 'VALID'}:
            padding = [padding] * (rhs.ndim - 2)  # type: ignore[list-item]
        else:
            raise ValueError(f"`padding` must be 'VALID' or 'SAME'. Passed: {padding}.")

    inferred_output_shape = tuple(map(_deconv_output_length, i_sdims, k_sdims, padding, output_padding, strides, dilation))

    if output_shape is None:
        output_shape = inferred_output_shape  # type: ignore[assignment]
    else:
        if not output_shape == inferred_output_shape:
            raise ValueError(f'`output_padding` and `output_shape` are not compatible. '
                             f'Inferred output shape from `output_padding`: {inferred_output_shape}, '
                             f'but got `output_shape` {output_shape}')

    pads = tuple(map(_compute_adjusted_padding, i_sdims, output_shape, k_sdims, strides, padding, dilation))

    if transpose_kernel:
        # Flip spatial dims and swap input / output channel axes.
        rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:])
        rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1])
    return jax.lax.conv_general_dilated(lhs, rhs, one, pads, strides, dilation, dn, feature_group_count, precision=precision)
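
# Quick check for gradient_based_conv_transpose (illustrative comment, not
# part of the original file): a stride-2 transposed convolution with a 4x4
# kernel on a VALID-padded 8x8 input yields (8 - 1) * 2 + 4 = 18.
#
#   lhs = jax.random.normal(random.PRNGKey(0), (1, 8, 8, 16))   # NHWC
#   rhs = jax.random.normal(random.PRNGKey(1), (4, 4, 32, 16))  # HWIO; I/O are swapped when transpose_kernel=True
#   out = gradient_based_conv_transpose(lhs, rhs, strides=(2, 2), padding='VALID', output_padding=None)
#   # out.shape == (1, 18, 18, 32)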
stylegan2/utils.py
ADDED
@@ -0,0 +1,37 @@
from tqdm import tqdm
import requests
import os
import tempfile


def download(ckpt_dir, url):
    name = url[url.rfind('/') + 1 : url.rfind('?')]
    if ckpt_dir is None:
        ckpt_dir = tempfile.gettempdir()
    ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')
    ckpt_file = os.path.join(ckpt_dir, name)
    if not os.path.exists(ckpt_file):
        print(f'Downloading: \"{url[:url.rfind("?")]}\" to {ckpt_file}')
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)

        response = requests.get(url, stream=True)
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)

        # first create temp file, in case the download fails
        ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
        with open(ckpt_file_temp, 'wb') as file:
            for data in response.iter_content(chunk_size=1024):
                progress_bar.update(len(data))
                file.write(data)
        progress_bar.close()

        if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
            print('An error occurred while downloading, please try again.')
            if os.path.exists(ckpt_file_temp):
                os.remove(ckpt_file_temp)
        else:
            # if download was successful, rename the temp file
            os.rename(ckpt_file_temp, ckpt_file)
    return ckpt_file
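
# Usage sketch (illustrative comment, not part of the original file; the URL
# is hypothetical): checkpoints are downloaded once and cached under
# <ckpt_dir>/flaxmodels, then the cached file is reused on later calls.
#
#   path = download(ckpt_dir=None, url='https://example.com/stylegan2_ffhq.h5?dl=1')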
training.py
ADDED
@@ -0,0 +1,382 @@
import jax
import jax.numpy as jnp
import flax
from flax.optim import dynamic_scale as dynamic_scale_lib
from flax.core import frozen_dict
import optax
import numpy as np
import functools
import wandb
import time

import stylegan2
import data_pipeline
import checkpoint
import training_utils
import training_steps
from fid import FID

import logging

logger = logging.getLogger(__name__)


def tree_shape(item):
    return jax.tree_map(lambda c: c.shape, item)


def train_and_evaluate(config):
    num_devices = jax.device_count()              # total devices across all hosts
    num_local_devices = jax.local_device_count()  # devices on this host
    num_workers = jax.process_count()

    # --------------------------------------
    # Data
    # --------------------------------------
    ds_train, dataset_info = data_pipeline.get_data(data_dir=config.data_dir,
                                                    img_size=config.resolution,
                                                    img_channels=config.img_channels,
                                                    num_classes=config.c_dim,
                                                    num_local_devices=num_local_devices,
                                                    batch_size=config.batch_size)

    # --------------------------------------
    # Seeding and Precision
    # --------------------------------------
    rng = jax.random.PRNGKey(config.random_seed)

    if config.mixed_precision:
        dtype = jnp.float16
    elif config.bf16:
        dtype = jnp.bfloat16
    else:
        dtype = jnp.float32
    logger.info(f'Running on dtype {dtype}')

    platform = jax.local_devices()[0].platform
    if config.mixed_precision and platform == 'gpu':
        dynamic_scale_G_main = dynamic_scale_lib.DynamicScale()
        dynamic_scale_D_main = dynamic_scale_lib.DynamicScale()
        dynamic_scale_G_reg = dynamic_scale_lib.DynamicScale()
        dynamic_scale_D_reg = dynamic_scale_lib.DynamicScale()
        clip_conv = 256
        num_fp16_res = 4
    else:
        dynamic_scale_G_main = None
        dynamic_scale_D_main = None
        dynamic_scale_G_reg = None
        dynamic_scale_D_reg = None
        clip_conv = None
        num_fp16_res = 0

    # --------------------------------------
    # Initialize Models
    # --------------------------------------
    logger.info('Initialize models...')

    rng, init_rng = jax.random.split(rng)

    # Generator initialization for training
    start_mn = time.time()
    logger.info("Creating MappingNetwork...")
    mapping_net = stylegan2.MappingNetwork(z_dim=config.z_dim,
                                           c_dim=config.c_dim,
                                           w_dim=config.w_dim,
                                           num_ws=int(np.log2(config.resolution)) * 2 - 3,
                                           num_layers=8,
                                           dtype=dtype)

    mapping_net_vars = mapping_net.init(init_rng,
                                        jnp.ones((1, config.z_dim)),
                                        jnp.ones((1, config.c_dim)))

    mapping_net_params, moving_stats = mapping_net_vars['params'], mapping_net_vars['moving_stats']

    logger.info(f"MappingNetwork took {time.time() - start_mn:.2f}s")

    logger.info("Creating SynthesisNetwork...")
    start_sn = time.time()
    synthesis_net = stylegan2.SynthesisNetwork(resolution=config.resolution,
                                               num_channels=config.img_channels,
                                               w_dim=config.w_dim,
                                               fmap_base=config.fmap_base,
                                               num_fp16_res=num_fp16_res,
                                               clip_conv=clip_conv,
                                               dtype=dtype)

    synthesis_net_vars = synthesis_net.init(init_rng,
                                            jnp.ones((1, mapping_net.num_ws, config.w_dim)))
    synthesis_net_params, noise_consts = synthesis_net_vars['params'], synthesis_net_vars['noise_consts']

    logger.info(f"SynthesisNetwork took {time.time() - start_sn:.2f}s")

    params_G = frozen_dict.FrozenDict(
        {'mapping': mapping_net_params,
         'synthesis': synthesis_net_params}
    )

    # Discriminator initialization for training
    logger.info("Creating Discriminator...")
    start_d = time.time()
    discriminator = stylegan2.Discriminator(resolution=config.resolution,
                                            num_channels=config.img_channels,
                                            c_dim=config.c_dim,
                                            mbstd_group_size=config.mbstd_group_size,
                                            num_fp16_res=num_fp16_res,
                                            clip_conv=clip_conv,
                                            dtype=dtype)
    rng, init_rng = jax.random.split(rng)
    params_D = discriminator.init(init_rng,
                                  jnp.ones((1, config.resolution, config.resolution, config.img_channels)),
                                  jnp.ones((1, config.c_dim)))
    logger.info(f"Discriminator took {time.time() - start_d:.2f}s")

    # Exponential moving average (EMA) Generator initialization
    logger.info("Creating Generator EMA...")
    start_g = time.time()
    generator_ema = stylegan2.Generator(resolution=config.resolution,
                                        num_channels=config.img_channels,
                                        z_dim=config.z_dim,
                                        c_dim=config.c_dim,
                                        w_dim=config.w_dim,
                                        num_ws=int(np.log2(config.resolution)) * 2 - 3,
                                        num_mapping_layers=8,
                                        fmap_base=config.fmap_base,
                                        num_fp16_res=num_fp16_res,
                                        clip_conv=clip_conv,
                                        dtype=dtype)

    params_ema_G = generator_ema.init(init_rng,
                                      jnp.ones((1, config.z_dim)),
                                      jnp.ones((1, config.c_dim)))
    logger.info(f"Generator EMA took {time.time() - start_g:.2f}s")

    # --------------------------------------
    # Initialize States and Optimizers
    # --------------------------------------
    logger.info('Initialize states...')
    tx_G = optax.adam(learning_rate=config.learning_rate, b1=0.0, b2=0.99)
    tx_D = optax.adam(learning_rate=config.learning_rate, b1=0.0, b2=0.99)

    state_G = training_utils.TrainStateG.create(apply_fn=None,
                                                apply_mapping=mapping_net.apply,
                                                apply_synthesis=synthesis_net.apply,
                                                params=params_G,
                                                moving_stats=moving_stats,
                                                noise_consts=noise_consts,
                                                tx=tx_G,
                                                dynamic_scale_main=dynamic_scale_G_main,
                                                dynamic_scale_reg=dynamic_scale_G_reg,
                                                epoch=0)

    state_D = training_utils.TrainStateD.create(apply_fn=discriminator.apply,
                                                params=params_D,
                                                tx=tx_D,
                                                dynamic_scale_main=dynamic_scale_D_main,
                                                dynamic_scale_reg=dynamic_scale_D_reg,
                                                epoch=0)

    # Copy over the parameters from the training generator to the ema generator
    params_ema_G = training_utils.update_generator_ema(state_G, params_ema_G, config, ema_beta=0)

    # Running mean of path length for path length regularization
    pl_mean = jnp.zeros((), dtype=dtype)

    step = 0
    epoch_offset = 0
    best_fid_score = np.inf
    ckpt_path = None

    if config.resume_run_id is not None:
        # Resume training from existing checkpoint
        ckpt_path = checkpoint.get_latest_checkpoint(config.ckpt_dir)
        logger.info(f'Resume training from checkpoint: {ckpt_path}')
        ckpt = checkpoint.load_checkpoint(ckpt_path)
        step = ckpt['step']
        epoch_offset = ckpt['epoch']
        best_fid_score = ckpt['fid_score']
        pl_mean = ckpt['pl_mean']
        state_G = ckpt['state_G']
        state_D = ckpt['state_D']
        params_ema_G = ckpt['params_ema_G']
        config = ckpt['config']
    elif config.load_from_pkl is not None:
        # Load checkpoint and start new run
        ckpt_path = config.load_from_pkl
        logger.info(f'Load model state from: {ckpt_path}')
        ckpt = checkpoint.load_checkpoint(ckpt_path)
        pl_mean = ckpt['pl_mean']
        state_G = ckpt['state_G']
        state_D = ckpt['state_D']
        params_ema_G = ckpt['params_ema_G']

    # Replicate states across devices
    pl_mean = flax.jax_utils.replicate(pl_mean)
    state_G = flax.jax_utils.replicate(state_G)
    state_D = flax.jax_utils.replicate(state_D)

    # --------------------------------------
    # Precompile train and eval steps
    # --------------------------------------
    logger.info('Precompile training steps...')
    p_main_step_G = jax.pmap(training_steps.main_step_G, axis_name='batch')
    p_regul_step_G = jax.pmap(functools.partial(training_steps.regul_step_G, config=config), axis_name='batch')

    p_main_step_D = jax.pmap(training_steps.main_step_D, axis_name='batch')
    p_regul_step_D = jax.pmap(functools.partial(training_steps.regul_step_D, config=config), axis_name='batch')

    # --------------------------------------
    # Training
    # --------------------------------------
    logger.info('Start training...')
    fid_metric = FID(generator_ema, ds_train, config)

    # Dict to collect training statistics / losses
    metrics = {}
    num_imgs_processed = 0
    num_steps_per_epoch = dataset_info['num_examples'] // (config.batch_size * num_devices)
    effective_batch_size = config.batch_size * num_devices
    if config.wandb and jax.process_index() == 0:
        # Log run metadata
        wandb.config.effective_batch_size = effective_batch_size
        wandb.config.num_steps_per_epoch = num_steps_per_epoch
        wandb.config.num_workers = num_workers
        wandb.config.device_count = num_devices
        wandb.config.num_examples = dataset_info['num_examples']
        wandb.config.vm_name = training_utils.get_vm_name()

    for epoch in range(epoch_offset, config.num_epochs):
        if config.wandb and jax.process_index() == 0:
            wandb.log({'training/epochs': epoch}, step=step)

        for batch in data_pipeline.prefetch(ds_train, config.num_prefetch):
            assert batch['image'].shape[1] == config.batch_size, f"Mismatched batch (batch size: {config.batch_size}, this batch: {batch['image'].shape[1]})"

            iteration_start_time = time.time()

            if config.c_dim == 0:
                # No labels in the dataset
                batch['label'] = None

            # Create two latent noise vectors and combine them for the style mixing regularization
            rng, key = jax.random.split(rng)
            z_latent1 = jax.random.normal(key, (num_local_devices, config.batch_size, config.z_dim), dtype)
            rng, key = jax.random.split(rng)
            z_latent2 = jax.random.normal(key, (num_local_devices, config.batch_size, config.z_dim), dtype)

            # Split PRNGs across devices
            rkey = jax.random.split(key, num=num_local_devices)
            mixing_prob = flax.jax_utils.replicate(config.mixing_prob)

            # --------------------------------------
            # Update Discriminator
            # --------------------------------------
            time_d_start = time.time()
            state_D, metrics = p_main_step_D(state_G, state_D, batch, z_latent1, z_latent2, metrics, mixing_prob, rkey)
            time_d_end = time.time()
            if step % config.D_reg_interval == 0:
                state_D, metrics = p_regul_step_D(state_D, batch, metrics)

            # --------------------------------------
            # Update Generator
            # --------------------------------------
            time_g_start = time.time()
            state_G, metrics = p_main_step_G(state_G, state_D, batch, z_latent1, z_latent2, metrics, mixing_prob, rkey)
            if step % config.G_reg_interval == 0:
                H, W = batch['image'].shape[-3], batch['image'].shape[-2]
                rng, key = jax.random.split(rng)
                pl_noise = jax.random.normal(key, batch['image'].shape, dtype=dtype) / np.sqrt(H * W)
                state_G, metrics, pl_mean = p_regul_step_G(state_G, batch, z_latent1, pl_noise, pl_mean, metrics,
                                                           rng=rkey)

            params_ema_G = training_utils.update_generator_ema(flax.jax_utils.unreplicate(state_G),
                                                               params_ema_G,
                                                               config)
            time_g_end = time.time()

            # --------------------------------------
            # Logging and Checkpointing
            # --------------------------------------
            if step % config.save_every == 0 and config.disable_fid:
                # If FID evaluation is disabled, a checkpoint will be saved every 'save_every' steps.
                if jax.process_index() == 0:
                    logger.info('Saving checkpoint...')
                    checkpoint.save_checkpoint(config.ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step,
                                               epoch)

            num_imgs_processed += num_devices * config.batch_size
            if step % config.eval_fid_every == 0 and not config.disable_fid:
                # If FID evaluation is enabled, only save a checkpoint if the FID score improved.
                if jax.process_index() == 0:
                    logger.info('Computing FID...')
                    fid_score = fid_metric.compute_fid(params_ema_G).item()
                    if config.wandb:
                        wandb.log({'training/gen/fid': fid_score}, step=step)
                    logger.info(f'Computed FID: {fid_score:.2f}')
                    if fid_score < best_fid_score:
                        best_fid_score = fid_score
                        logger.info(f'New best FID score ({best_fid_score:.3f}). Saving checkpoint...')
                        ts = time.time()
                        checkpoint.save_checkpoint(config.ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step, epoch, fid_score=fid_score)
                        te = time.time()
                        logger.info(f'... successfully saved checkpoint in {(te - ts) / 60:.1f}min')

            sec_per_kimg = (time.time() - iteration_start_time) / (num_devices * config.batch_size / 1000.0)
            time_taken_g = time_g_end - time_g_start
            time_taken_d = time_d_end - time_d_start
            time_taken_per_step = time.time() - iteration_start_time
            g_loss = jnp.mean(metrics['G_loss']).item()
            d_loss = jnp.mean(metrics['D_loss']).item()

            if config.wandb and jax.process_index() == 0:
                # wandb logging - happens every step
                wandb.log({'training/gen/loss': g_loss}, step=step, commit=False)
                wandb.log({'training/dis/loss': d_loss}, step=step, commit=False)
                wandb.log({'training/dis/fake_logits': jnp.mean(metrics['fake_logits']).item()}, step=step, commit=False)
                wandb.log({'training/dis/real_logits': jnp.mean(metrics['real_logits']).item()}, step=step, commit=False)
                wandb.log({'training/time_taken_g': time_taken_g, 'training/time_taken_d': time_taken_d}, step=step, commit=False)
                wandb.log({'training/time_taken_per_step': time_taken_per_step}, step=step, commit=False)
                wandb.log({'training/num_imgs_trained': num_imgs_processed}, step=step, commit=False)
                wandb.log({'training/sec_per_kimg': sec_per_kimg}, step=step)

            if step % config.log_every == 0:
                # Console logging - happens every log_every steps
                logger.info(f'Total steps: {step:>6,} - epoch {epoch:>3,}/{config.num_epochs} @ {step % num_steps_per_epoch:>6,}/{num_steps_per_epoch:,} - G loss: {g_loss:.5f} - D loss: {d_loss:.5f} - sec/kimg: {sec_per_kimg:.2f}s - time per step: {time_taken_per_step:.3f}s')

            if step % config.generate_samples_every == 0 and config.wandb and jax.process_index() == 0:
                # Generate training images
                train_snapshot = training_utils.get_training_snapshot(
                    image_real=flax.jax_utils.unreplicate(batch['image']),
                    image_gen=flax.jax_utils.unreplicate(metrics['image_gen']),
                    max_num=10
                )
                wandb.log({'training/snapshot': wandb.Image(train_snapshot)}, commit=False, step=step)

                # Generate evaluation images
                labels = None if config.c_dim == 0 else batch['label'][0]
                image_gen_eval = training_steps.eval_step_G(
                    generator_ema, params=params_ema_G,
                    z_latent=z_latent1[0],
                    labels=labels,
                    truncation=1
                )
                image_gen_eval_trunc = training_steps.eval_step_G(
                    generator_ema,
                    params=params_ema_G,
                    z_latent=z_latent1[0],
                    labels=labels,
                    truncation=0.5
                )
                eval_snapshot = training_utils.get_eval_snapshot(image=image_gen_eval, max_num=10)
                eval_snapshot_trunc = training_utils.get_eval_snapshot(image=image_gen_eval_trunc, max_num=10)
                wandb.log({'eval/snapshot': wandb.Image(eval_snapshot)}, commit=False, step=step)
                wandb.log({'eval/snapshot_trunc': wandb.Image(eval_snapshot_trunc)}, step=step)

            step += 1

        # Sync moving stats across devices
        state_G = training_utils.sync_moving_stats(state_G)

        # Sync moving average of path length mean (Generator regularization)
        pl_mean = jax.pmap(lambda x: jax.lax.pmean(x, axis_name='batch'), axis_name='batch')(pl_mean)

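The loop above relies on the standard flax replicate/pmap pattern: states are replicated so every pytree leaf gains a leading device axis, latents are sampled with that axis already in place, and the pmapped steps average across devices with `pmean`. A self-contained sketch of the pattern (toy state and step, not the trainer above):

import functools
import jax
import jax.numpy as jnp
import flax

params = {'w': jnp.arange(3.0)}
params = flax.jax_utils.replicate(params)        # leaves gain a leading device axis

@functools.partial(jax.pmap, axis_name='batch')
def step(p, x):
    y = x * p['w']                               # runs once per device
    return jax.lax.pmean(y, axis_name='batch')   # cross-device average, like the gradients above

x = jnp.ones((jax.local_device_count(), 3))      # leading axis = local device count
y = step(params, x)
params = flax.jax_utils.unreplicate(params)      # drop the device axis, e.g. before checkpointing
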
training_steps.py
ADDED
@@ -0,0 +1,219 @@
import jax
import jax.numpy as jnp
import functools


def main_step_G(state_G, state_D, batch, z_latent1, z_latent2, metrics, mixing_prob, rng):

    def loss_fn(params):
        w_latent1, new_state_G = state_G.apply_mapping({'params': params['mapping'], 'moving_stats': state_G.moving_stats},
                                                       z_latent1,
                                                       batch['label'],
                                                       mutable=['moving_stats'])
        w_latent2 = state_G.apply_mapping({'params': params['mapping'], 'moving_stats': state_G.moving_stats},
                                          z_latent2,
                                          batch['label'],
                                          skip_w_avg_update=True)

        # style mixing
        cutoff_rng, layer_select_rng, synth_rng = jax.random.split(rng, num=3)
        num_layers = w_latent1.shape[1]
        layer_idx = jnp.arange(num_layers)[jnp.newaxis, :, jnp.newaxis]
        mixing_cutoff = jax.lax.cond(jax.random.uniform(cutoff_rng, (), minval=0.0, maxval=1.0) < mixing_prob,
                                     lambda _: jax.random.randint(layer_select_rng, (), 1, num_layers, dtype=jnp.int32),
                                     lambda _: num_layers,
                                     operand=None)
        mixing_cond = jnp.broadcast_to(layer_idx < mixing_cutoff, w_latent1.shape)
        w_latent = jnp.where(mixing_cond, w_latent1, w_latent2)

        image_gen = state_G.apply_synthesis({'params': params['synthesis'], 'noise_consts': state_G.noise_consts},
                                            w_latent,
                                            rng=synth_rng)

        fake_logits = state_D.apply_fn(state_D.params, image_gen, batch['label'])
        loss = jnp.mean(jax.nn.softplus(-fake_logits))
        return loss, (fake_logits, image_gen, new_state_G)

    dynamic_scale = state_G.dynamic_scale_main

    if dynamic_scale:
        grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True, axis_name='batch')
        dynamic_scale, is_fin, aux, grads = grad_fn(state_G.params)
    else:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        aux, grads = grad_fn(state_G.params)
        grads = jax.lax.pmean(grads, axis_name='batch')

    loss = aux[0]
    _, image_gen, new_state = aux[1]
    metrics['G_loss'] = loss
    metrics['image_gen'] = image_gen

    new_state_G = state_G.apply_gradients(grads=grads, moving_stats=new_state['moving_stats'])

    if dynamic_scale:
        new_state_G = new_state_G.replace(opt_state=jax.tree_multimap(functools.partial(jnp.where, is_fin),
                                                                      new_state_G.opt_state,
                                                                      state_G.opt_state),
                                          params=jax.tree_multimap(functools.partial(jnp.where, is_fin),
                                                                   new_state_G.params,
                                                                   state_G.params))
        metrics['G_scale'] = dynamic_scale.scale

    return new_state_G, metrics


def regul_step_G(state_G, batch, z_latent, pl_noise, pl_mean, metrics, config, rng):

    def loss_fn(params):
        w_latent, new_state_G = state_G.apply_mapping({'params': params['mapping'], 'moving_stats': state_G.moving_stats},
                                                      z_latent,
                                                      batch['label'],
                                                      mutable=['moving_stats'])

        pl_grads = jax.grad(lambda *args: jnp.sum(state_G.apply_synthesis(*args) * pl_noise), argnums=1)({'params': params['synthesis'],
                                                                                                          'noise_consts': state_G.noise_consts},
                                                                                                         w_latent,
                                                                                                         'random',
                                                                                                         rng)
        pl_lengths = jnp.sqrt(jnp.mean(jnp.sum(jnp.square(pl_grads), axis=2), axis=1))
        pl_mean_new = pl_mean + config.pl_decay * (jnp.mean(pl_lengths) - pl_mean)
        pl_penalty = jnp.square(pl_lengths - pl_mean_new) * config.pl_weight
        loss = jnp.mean(pl_penalty) * config.G_reg_interval

        return loss, pl_mean_new

    dynamic_scale = state_G.dynamic_scale_reg

    if dynamic_scale:
        grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True)
        dynamic_scale, is_fin, aux, grads = grad_fn(state_G.params)
    else:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        aux, grads = grad_fn(state_G.params)
        grads = jax.lax.pmean(grads, axis_name='batch')

    loss = aux[0]
    pl_mean_new = aux[1]

    metrics['G_regul_loss'] = loss
    new_state_G = state_G.apply_gradients(grads=grads)

    if dynamic_scale:
        new_state_G = new_state_G.replace(opt_state=jax.tree_multimap(functools.partial(jnp.where, is_fin),
                                                                      new_state_G.opt_state,
                                                                      state_G.opt_state),
                                          params=jax.tree_multimap(functools.partial(jnp.where, is_fin),
                                                                   new_state_G.params,
                                                                   state_G.params))
        metrics['G_regul_scale'] = dynamic_scale.scale

    return new_state_G, metrics, pl_mean_new


def main_step_D(state_G, state_D, batch, z_latent1, z_latent2, metrics, mixing_prob, rng):

    def loss_fn(params):
        w_latent1 = state_G.apply_mapping({'params': state_G.params['mapping'], 'moving_stats': state_G.moving_stats},
                                          z_latent1,
                                          batch['label'],
                                          train=False)

        w_latent2 = state_G.apply_mapping({'params': state_G.params['mapping'], 'moving_stats': state_G.moving_stats},
                                          z_latent2,
                                          batch['label'],
                                          train=False)

        # style mixing
        cutoff_rng, layer_select_rng, synth_rng = jax.random.split(rng, num=3)
        num_layers = w_latent1.shape[1]
        layer_idx = jnp.arange(num_layers)[jnp.newaxis, :, jnp.newaxis]
        mixing_cutoff = jax.lax.cond(jax.random.uniform(cutoff_rng, (), minval=0.0, maxval=1.0) < mixing_prob,
                                     lambda _: jax.random.randint(layer_select_rng, (), 1, num_layers, dtype=jnp.int32),
                                     lambda _: num_layers,
                                     operand=None)
        mixing_cond = jnp.broadcast_to(layer_idx < mixing_cutoff, w_latent1.shape)
        w_latent = jnp.where(mixing_cond, w_latent1, w_latent2)

        image_gen = state_G.apply_synthesis({'params': state_G.params['synthesis'], 'noise_consts': state_G.noise_consts},
                                            w_latent,
                                            rng=synth_rng)

        fake_logits = state_D.apply_fn(params, image_gen, batch['label'])
        real_logits = state_D.apply_fn(params, batch['image'], batch['label'])

        loss_fake = jax.nn.softplus(fake_logits)
        loss_real = jax.nn.softplus(-real_logits)
        loss = jnp.mean(loss_fake + loss_real)

        return loss, (fake_logits, real_logits)

    dynamic_scale = state_D.dynamic_scale_main

    if dynamic_scale:
        grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True)
        dynamic_scale, is_fin, aux, grads = grad_fn(state_D.params)
    else:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        aux, grads = grad_fn(state_D.params)
        grads = jax.lax.pmean(grads, axis_name='batch')

    loss = aux[0]
    fake_logits, real_logits = aux[1]
    metrics['D_loss'] = loss
    metrics['fake_logits'] = jnp.mean(fake_logits)
    metrics['real_logits'] = jnp.mean(real_logits)

    new_state_D = state_D.apply_gradients(grads=grads)

    if dynamic_scale:
        new_state_D = new_state_D.replace(opt_state=jax.tree_multimap(functools.partial(jnp.where, is_fin),
                                                                      new_state_D.opt_state,
                                                                      state_D.opt_state),
                                          params=jax.tree_multimap(functools.partial(jnp.where, is_fin),
                                                                   new_state_D.params,
                                                                   state_D.params))
        metrics['D_scale'] = dynamic_scale.scale

    return new_state_D, metrics


def regul_step_D(state_D, batch, metrics, config):

    def loss_fn(params):
        r1_grads = jax.grad(lambda *args: jnp.sum(state_D.apply_fn(*args)), argnums=1)(params, batch['image'], batch['label'])
        r1_penalty = jnp.sum(jnp.square(r1_grads), axis=(1, 2, 3)) * (config.r1_gamma / 2) * config.D_reg_interval
        loss = jnp.mean(r1_penalty)
        return loss, None

    dynamic_scale = state_D.dynamic_scale_reg

    if dynamic_scale:
        grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True)
        dynamic_scale, is_fin, aux, grads = grad_fn(state_D.params)
    else:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        aux, grads = grad_fn(state_D.params)
        grads = jax.lax.pmean(grads, axis_name='batch')

    loss = aux[0]
    metrics['D_regul_loss'] = loss

    new_state_D = state_D.apply_gradients(grads=grads)

    if dynamic_scale:
        new_state_D = new_state_D.replace(opt_state=jax.tree_multimap(functools.partial(jnp.where, is_fin),
                                                                      new_state_D.opt_state,
                                                                      state_D.opt_state),
                                          params=jax.tree_multimap(functools.partial(jnp.where, is_fin),
                                                                   new_state_D.params,
                                                                   state_D.params))
        metrics['D_regul_scale'] = dynamic_scale.scale

    return new_state_D, metrics


def eval_step_G(generator, params, z_latent, labels, truncation):
    image_gen = generator.apply(params, z_latent, labels, truncation_psi=truncation, train=False, noise_mode='const')
    return image_gen

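The steps above use the non-saturating logistic GAN losses with lazy regularization: the R1 and path-length penalties are scaled by `D_reg_interval` / `G_reg_interval` because they run only every N steps. Stripped of the train-state and mixed-precision plumbing, the objectives reduce to the following (same formulas as in the code above):

import jax
import jax.numpy as jnp

def g_loss(fake_logits):
    # Generator: push D(G(z)) up via softplus(-logits).
    return jnp.mean(jax.nn.softplus(-fake_logits))

def d_loss(fake_logits, real_logits):
    # Discriminator: push fake logits down, real logits up.
    return jnp.mean(jax.nn.softplus(fake_logits) + jax.nn.softplus(-real_logits))

def r1_penalty(grad_real, r1_gamma, d_reg_interval):
    # grad_real: dD(x)/dx on real images, shape [B, H, W, C].
    penalty = jnp.sum(jnp.square(grad_real), axis=(1, 2, 3)) * (r1_gamma / 2) * d_reg_interval
    return jnp.mean(penalty)
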
training_utils.py
ADDED
@@ -0,0 +1,174 @@
import jax
import jax.numpy as jnp
from jaxlib.xla_extension import DeviceArray
import flax
from flax.optim import dynamic_scale as dynamic_scale_lib
from flax.core import frozen_dict
from flax.training import train_state
from flax import struct
import numpy as np
from PIL import Image
from urllib.request import Request, urlopen
import urllib.error
from typing import Any, Callable


def sync_moving_stats(state):
    """
    Sync moving statistics across devices.

    Args:
        state (train_state.TrainState): Training state.

    Returns:
        (train_state.TrainState): Updated training state.
    """
    cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, 'x'), 'x')
    return state.replace(moving_stats=cross_replica_mean(state.moving_stats))


def update_generator_ema(state_G, params_ema_G, config, ema_beta=None):
    """
    Update the exponential moving average of the generator weights.
    Moving stats and noise constants are copied over unchanged.

    Args:
        state_G (train_state.TrainState): Generator state.
        params_ema_G (frozen_dict.FrozenDict): Parameters of the ema generator.
        config (Any): Config object.
        ema_beta (float): Beta parameter of the ema. If None, it is computed
                          from 'ema_kimg' and 'batch_size'.

    Returns:
        (frozen_dict.FrozenDict): Updated parameters of the ema generator.
    """
    def _update_ema(src, trg, beta):
        for name, src_child in src.items():
            if isinstance(src_child, DeviceArray):
                trg[name] = src[name] + beta * (trg[name] - src[name])
            else:
                _update_ema(src_child, trg[name], beta)

    if ema_beta is None:
        ema_nimg = config.ema_kimg * 1000
        ema_beta = 0.5 ** (config.batch_size / max(ema_nimg, 1e-8))

    params_ema_G = params_ema_G.unfreeze()

    # Copy over moving stats
    params_ema_G['moving_stats']['mapping_network'] = state_G.moving_stats
    params_ema_G['noise_consts']['synthesis_network'] = state_G.noise_consts

    # Update exponentially moving average of the trainable parameters
    _update_ema(state_G.params['mapping'], params_ema_G['params']['mapping_network'], ema_beta)
    _update_ema(state_G.params['synthesis'], params_ema_G['params']['synthesis_network'], ema_beta)

    params_ema_G = frozen_dict.freeze(params_ema_G)
    return params_ema_G


class TrainStateG(train_state.TrainState):
    """
    Generator train state for a single Optax optimizer.

    Attributes:
        apply_mapping (Callable): Apply function of the Mapping Network.
        apply_synthesis (Callable): Apply function of the Synthesis Network.
        dynamic_scale_main (dynamic_scale_lib.DynamicScale): Dynamic loss scaling for the main loss (mixed precision).
        dynamic_scale_reg (dynamic_scale_lib.DynamicScale): Dynamic loss scaling for the regularization loss (mixed precision).
        epoch (int): Current epoch.
        moving_stats (Any): Moving average of the latent W.
        noise_consts (Any): Noise constants from synthesis layers.
    """
    apply_mapping: Callable = struct.field(pytree_node=False)
    apply_synthesis: Callable = struct.field(pytree_node=False)
    dynamic_scale_main: dynamic_scale_lib.DynamicScale
    dynamic_scale_reg: dynamic_scale_lib.DynamicScale
    epoch: int
    moving_stats: Any = None
    noise_consts: Any = None


class TrainStateD(train_state.TrainState):
    """
    Discriminator train state for a single Optax optimizer.

    Attributes:
        dynamic_scale_main (dynamic_scale_lib.DynamicScale): Dynamic loss scaling for the main loss (mixed precision).
        dynamic_scale_reg (dynamic_scale_lib.DynamicScale): Dynamic loss scaling for the regularization loss (mixed precision).
        epoch (int): Current epoch.
    """
    dynamic_scale_main: dynamic_scale_lib.DynamicScale
    dynamic_scale_reg: dynamic_scale_lib.DynamicScale
    epoch: int


def get_training_snapshot(image_real, image_gen, max_num=10):
    """
    Creates a snapshot of generated images and real images.

    Args:
        image_real (DeviceArray): Batch of real images, shape [B, H, W, C].
        image_gen (DeviceArray): Batch of generated images, shape [B, H, W, C].
        max_num (int): Maximum number of images used for the snapshot.

    Returns:
        (PIL.Image): Training snapshot. Top row: generated images, bottom row: real images.
    """
    if image_real.shape[0] > max_num:
        image_real = image_real[:max_num]
    if image_gen.shape[0] > max_num:
        image_gen = image_gen[:max_num]

    image_real = jnp.split(image_real, image_real.shape[0], axis=0)
    image_gen = jnp.split(image_gen, image_gen.shape[0], axis=0)

    image_real = [jnp.squeeze(x, axis=0) for x in image_real]
    image_gen = [jnp.squeeze(x, axis=0) for x in image_gen]

    image_real = jnp.concatenate(image_real, axis=1)
    image_gen = jnp.concatenate(image_gen, axis=1)

    image_gen = (image_gen - np.min(image_gen)) / (np.max(image_gen) - np.min(image_gen))
    image_real = (image_real - np.min(image_real)) / (np.max(image_real) - np.min(image_real))
    image = jnp.concatenate((image_gen, image_real), axis=0)

    image = np.uint8(image * 255)
    if image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    return Image.fromarray(image)


def get_eval_snapshot(image, max_num=10):
    """
    Creates a snapshot of generated images.

    Args:
        image (DeviceArray): Generated images, shape [B, H, W, C].
        max_num (int): Maximum number of images used for the snapshot.

    Returns:
        (PIL.Image): Eval snapshot.
    """
    if image.shape[0] > max_num:
        image = image[:max_num]

    image = jnp.split(image, image.shape[0], axis=0)
    image = [jnp.squeeze(x, axis=0) for x in image]
    image = jnp.concatenate(image, axis=1)
    image = (image - np.min(image)) / (np.max(image) - np.min(image))
    image = np.uint8(image * 255)
    if image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    return Image.fromarray(image)


def get_vm_name():
    gcp_metadata_url = "http://metadata.google.internal/computeMetadata/v1/instance/attributes/instance-id"
    req = Request(gcp_metadata_url)
    req.add_header('Metadata-Flavor', 'Google')
    instance_id = None
    try:
        with urlopen(req) as url:
            instance_id = url.read().decode()
    except urllib.error.URLError:
        # metadata.google.internal not reachable: not running on a GCP VM
        pass
    return instance_id

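The EMA half-life in `update_generator_ema` is expressed in images rather than steps, so the decay rate is independent of batch size. A quick numeric sketch (the config values below are illustrative):

ema_kimg = 10                     # illustrative config value
batch_size = 32
ema_nimg = ema_kimg * 1000
ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
# beta ** (ema_nimg / batch_size) == 0.5: after `ema_nimg` images, the
# contribution of the old EMA weights has decayed by exactly one half.
print(ema_beta)                   # ~0.99778
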