Upload 3 files
- download_checkpoints.py +139 -0
- joint_loss.py +510 -0
- prepare_dataset.py +320 -0
download_checkpoints.py
ADDED
@@ -0,0 +1,139 @@
import subprocess
import sys
import time
import argparse
import os
import datetime
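
# Usage sketch (all values below are illustrative placeholders, not a real pod):
#
#   python download_checkpoints.py \
#       --remote-ip 203.0.113.10 --remote-port 22022 --remote-user root \
#       --local-dest ./checkpoints --rsa-key ~/.ssh/id_rsa --poll-interval 10
#
# The script polls --remote-base-path (default /workspace) over SSH for
# safetensors.lock. joint_loss.py creates that lock after writing grads_a,
# grads_b and log_scalars; once the files are copied here the lock is removed,
# which lets the trainer overwrite the checkpoints on its next save.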


def main():
    parser = argparse.ArgumentParser(description="Download checkpoint pair from remote runpod machine with unique filenames.")
    parser.add_argument("--remote-ip", required=True, help="Remote machine IP address")
    parser.add_argument("--remote-port", required=True, type=int, help="Remote SSH port")
    parser.add_argument("--remote-user", required=True, help="Username for remote SSH")
    parser.add_argument("--remote-base-path", default="/workspace", help="Directory on remote machine containing checkpoints and lock file")
    parser.add_argument("--local-dest", required=True, help="Local directory where checkpoints should be saved")
    parser.add_argument("--rsa-key", required=True, help="Path to your RSA private key for authentication")
    parser.add_argument("--poll-interval", type=float, default=10, help="Polling interval in seconds")

    args = parser.parse_args()

    # Construct remote file paths.
    remote_checkpoint_a = f"{args.remote_base_path}/grads_a.safetensors"
    remote_checkpoint_b = f"{args.remote_base_path}/grads_b.safetensors"
    remote_inv_log_scalars = f"{args.remote_base_path}/log_scalars.safetensors"
    remote_thresholds = f"{args.remote_base_path}/thresholds.safetensors"
    remote_lock_file = f"{args.remote_base_path}/safetensors.lock"

    print("Starting remote checkpoint monitor...")
    while True:
        # Check if the lock file exists on the remote machine.
        if remote_file_exists(args.remote_user, args.remote_ip, args.remote_port, remote_lock_file, args.rsa_key):
            print("New checkpoints detected. Downloading...")

            # Generate unique filenames for each model.
            local_checkpoint_a = get_unique_filename(args.local_dest, "grads_a")
            local_checkpoint_b = get_unique_filename(args.local_dest, "grads_b")
            local_inv_log_scalars = get_unique_filename(args.local_dest, "log_scalars")
            local_thresholds = get_unique_filename(args.local_dest, "thresholds")

            try:
                # Download both checkpoints with the unique filenames.
                download_file(args.remote_user, args.remote_ip, args.remote_port, remote_checkpoint_a, local_checkpoint_a, args.rsa_key)
                download_file(args.remote_user, args.remote_ip, args.remote_port, remote_checkpoint_b, local_checkpoint_b, args.rsa_key)
                download_file(args.remote_user, args.remote_ip, args.remote_port, remote_inv_log_scalars, local_inv_log_scalars, args.rsa_key)
                # download_file(args.remote_user, args.remote_ip, args.remote_port, remote_thresholds, local_thresholds, args.rsa_key)
            except subprocess.CalledProcessError as e:
                print(f"Download error: {e}")
                time.sleep(args.poll_interval)
                continue

            # After successful download, delete only the lock file on the remote side.
            try:
                while not delete_remote_lock(args.remote_user, args.remote_ip, args.remote_port, remote_lock_file, args.rsa_key):
                    continue

                print("Download complete. Checkpoints saved as:")
                print(f" {local_checkpoint_a}")
                print(f" {local_checkpoint_b}")
                print("Remote lock file deleted.")
            except subprocess.CalledProcessError as e:
                print(f"Error deleting remote lock file: {e}")
        else:
            print("No checkpoints found.")

        time.sleep(args.poll_interval)


def remote_file_exists(remote_user, remote_host, remote_port, remote_path, rsa_key, timeout=10):
    """Check if a file exists on the remote machine."""
    cmd = [
        "ssh",
        "-i", rsa_key,
        "-p", str(remote_port),
        "-o", "StrictHostKeyChecking=no",
        "-o", "UserKnownHostsFile=/dev/null",
        f"{remote_user}@{remote_host}",
        f"test -f {remote_path}"
    ]
    try:
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=timeout)
        if result.stdout: print("stdout", result.stdout.decode("utf-8"), end="")
        if result.stderr: print("stderr", result.stderr.decode("utf-8"), file=sys.stderr, end="")
        return result.returncode == 0
    except subprocess.TimeoutExpired:
        print(f"TimeoutExpired: SSH command to check {remote_path} on {remote_host} timed out after {timeout} seconds.")
        return False


def download_file(remote_user, remote_host, remote_port, remote_file, local_file, rsa_key, timeout=1200):
    """Download a file from the remote machine using scp and save it with a specific name."""
    cmd = [
        "scp",
        "-i", rsa_key,
        "-P", str(remote_port),
        "-o", "StrictHostKeyChecking=no",
        "-o", "UserKnownHostsFile=/dev/null",
        f"{remote_user}@{remote_host}:{remote_file}",
        str(local_file)
    ]
    try:
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=timeout)
        if result.stdout: print("stdout", result.stdout.decode("utf-8"), end="")
        if result.stderr: print("stderr", result.stderr.decode("utf-8"), file=sys.stderr, end="")
        return result.returncode == 0
    except subprocess.TimeoutExpired:
        print(f"TimeoutExpired: SSH command to download {remote_file} on {remote_host} timed out after {timeout} seconds.")
        return False


def delete_remote_lock(remote_user, remote_host, remote_port, remote_lock_file, rsa_key, timeout=10):
    """Delete the lock file on the remote machine."""
    cmd = [
        "ssh",
        "-i", rsa_key,
        "-p", str(remote_port),
        "-o", "StrictHostKeyChecking=no",
        "-o", "UserKnownHostsFile=/dev/null",
        f"{remote_user}@{remote_host}",
        f"rm -f {remote_lock_file}"
    ]
    try:
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=timeout)
        if result.stdout: print("stdout", result.stdout.decode("utf-8"), end="")
        if result.stderr: print("stderr", result.stderr.decode("utf-8"), file=sys.stderr, end="")
        return result.returncode == 0
    except subprocess.TimeoutExpired:
        print(f"TimeoutExpired: SSH command to delete {remote_lock_file} on {remote_host} timed out after {timeout} seconds.")
        return False


def get_unique_filename(local_dest, base_name):
    """Generate a unique filename with a timestamp and return the full path."""
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    filename = f"{base_name}_{timestamp}.safetensors"
    return os.path.join(local_dest, filename)


if __name__ == "__main__":
    main()
joint_loss.py
ADDED
@@ -0,0 +1,510 @@
import csv
import dataclasses
import subprocess
from copy import deepcopy
import itertools
from concurrent.futures import ThreadPoolExecutor
import pathlib
from typing import List
import diffusers
import transformers
import safetensors.torch
import torch.utils.data
from tqdm import tqdm
from datetime import datetime
import random
import os
import time
from torch.utils.tensorboard import SummaryWriter


torch.manual_seed(0)
random.seed(0)


LATENTS_OUTPUT_DIR = pathlib.Path("latents")
CAPTIONS_OUTPUT_DIR = pathlib.Path("captions2")
DANBOORU_ARTISTS_PATH = pathlib.Path("danbooru_artist.csv")
E621_ARTISTS_PATH = pathlib.Path("e621_artist.csv")
LOCK_FILE = "safetensors.lock"


device = torch.device("cuda")
dtype = torch.float16


train_logger = SummaryWriter(f"logs/pony_scoreless_{datetime.now().strftime('%Y%m%d_%H%M%S')}")


def accumulate_grads():
    batch_size = 1
    epochs = 1

    tokenizer = create_tokenizer(device)

    model_a = diffusers.StableDiffusionXLPipeline.from_single_file(
        "NoobAI-XL-v1.1.safetensors",
        torch_dtype=dtype,
    )
    delattr(model_a, "vae")
    model_a.unet.to(device=device)
    # model_a.unet.enable_xformers_memory_efficient_attention()
    model_a.unet.enable_gradient_checkpointing()
    model_a.text_encoder.to(device=device)
    model_a.text_encoder.gradient_checkpointing_enable()
    model_a.text_encoder_2.to(device=device)
    model_a.text_encoder_2.gradient_checkpointing_enable()
    model_a.text_encoder_combined = CombinedCLIPTextEncoder(model_a.text_encoder, model_a.text_encoder_2, batch_size)

    model_b = diffusers.StableDiffusionXLPipeline.from_single_file(
        "animagine-xl-4.0.safetensors",
        torch_dtype=dtype,
    )
    delattr(model_b, "vae")
    model_b.unet.to(device=device)
    # model_b.unet.enable_xformers_memory_efficient_attention()
    model_b.unet.enable_gradient_checkpointing()
    model_b.text_encoder.to(device=device)
    model_b.text_encoder.gradient_checkpointing_enable()
    model_b.text_encoder_2.to(device=device)
    model_b.text_encoder_2.gradient_checkpointing_enable()
    model_b.text_encoder_combined = CombinedCLIPTextEncoder(model_b.text_encoder, model_b.text_encoder_2, batch_size)

    model_a.unet.eval()
    model_a.text_encoder.eval()
    model_a.text_encoder_2.eval()
    model_b.unet.eval()
    model_b.text_encoder.eval()
    model_b.text_encoder_2.eval()

    # shared_stats = {}
    # stats_lock = threading.Lock()

    # # Two barriers for synchronization between two threads.
    # grad_barrier1 = threading.Barrier(2)
    # grad_barrier2 = threading.Barrier(2)

    # def scaling_hook_factory(key, branch_id, target_scale=1.0):
    #     nonlocal shared_stats, stats_lock, grad_barrier1, grad_barrier2

    #     def scaling_hook(_module, _grad_input, grad_output):
    #         """
    #         A full-backward hook that:
    #         1. Computes, for each non-None tensor in grad_output, its maximum absolute value.
    #            We store these in a dictionary (keyed by output index).
    #         2. Waits once until both threads have stored their local max values.
    #         3. Computes, for each output index, the global maximum from both models.
    #         4. Waits a second time to ensure synchronization before clearing the shared stats.
    #         5. Scales each non-None output tensor independently using its computed scaling factor.
    #         Outputs that are None are passed through unchanged.
    #         """
    #         # Step 1: Compute and store local maximums per output index.
    #         print(f"backprop for {key}")
    #         local_maxes = {}
    #         for i, g in enumerate(grad_output):
    #             if g is not None:
    #                 local_maxes[i] = g.detach().abs().max().cpu().item()

    #         with stats_lock:
    #             shared_stats[f"{key}_{branch_id}"] = local_maxes

    #         # Step 2: Wait until both threads have stored their values.
    #         grad_barrier1.wait()

    #         # Step 3: Compute the global maximum for each output index.
    #         with stats_lock:
    #             stats_a = shared_stats.get(f"{key}_a", {})
    #             stats_b = shared_stats.get(f"{key}_b", {})
    #             # Build a dictionary for global max per output index.
    #             global_maxes = {}
    #             for i in local_maxes.keys():
    #                 assert i in stats_a and i in stats_b, key
    #                 global_maxes[i] = max(stats_a[i], stats_b[i])

    #         # Step 4: Wait again to ensure both threads have computed the global values.
    #         barrier_val = grad_barrier2.wait()
    #         # Let only one thread clear the shared stats.
    #         if barrier_val == 0:
    #             with stats_lock:
    #                 shared_stats.pop(f"{key}_a")
    #                 shared_stats.pop(f"{key}_b")

    #         # Step 5: For each output tensor, compute a scaling factor and apply it.
    #         scaled_outputs = []
    #         for i, g in enumerate(grad_output):
    #             if g is not None:
    #                 global_max = global_maxes[i]
    #                 # Compute scaling factor only if global_max is positive and below target_scale.
    #                 if 0 < global_max < target_scale:
    #                     g = g * (target_scale / global_max)
    #                 scaled_outputs.append(g)
    #             else:
    #                 scaled_outputs.append(None)

    #         return tuple(scaled_outputs)

    #     return scaling_hook

    # for model, branch_id in zip((model_a, model_b), ("a", "b")):
    #     for k, v in get_modules(model):
    #         if k.endswith("transformer_blocks") or k.endswith("encoder.layers"):
    #             for i, module in enumerate(v):
    #                 module.register_full_backward_hook(scaling_hook_factory(f"{k}.{i}", branch_id))

    scheduler = create_scheduler(device)
    data_loader = get_data_loader(tokenizer, batch_size)
    total_steps = 0

    log_scalars_a = {}
    log_scalars_b = {}
    log_scalars_sync = {}

    n1 = torch.tensor(-1, device=device, dtype=torch.long)
    ldexp_offset = torch.tensor(20, device=device, dtype=torch.long)
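
    # Accumulation hook below: gradients are accumulated in fp16 as running
    # sums of |grad|, scaled by 2**log_scalars[k] per parameter (starting at
    # 2**20, presumably to keep small fp16 values from underflowing). On
    # overflow the exponent is decremented and the running sum is halved with
    # ldexp, so every contribution stays on one consistent power-of-two scale;
    # the final exponents are written to log_scalars.safetensors so the
    # scaling can be undone after download.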
    def create_hook(param, k, log_scalars):
        param.grad = torch.zeros_like(param)
        log_scalars[k] = ldexp_offset.clone()

        def hook(grad):
            nonlocal param, log_scalars, k
            while True:
                new_grad = param.grad + grad.abs().ldexp(log_scalars[k])
                if not new_grad.isfinite().all():  # overflow
                    log_scalars[k] -= 1
                    param.grad.ldexp_(n1)
                else:
                    break

            param.grad.copy_(new_grad)
            return param.grad

        return hook

    for model, log_scalars in ((model_a, log_scalars_a), (model_b, log_scalars_b)):
        for k, v in get_params(model):
            v.register_hook(create_hook(v, k, log_scalars))

    # for model, path in ((model_a, "grads_a.safetensors"), (model_b, "grads_b.safetensors")):
    #     with safetensors.safe_open(path, "pt") as f:
    #         for k, v in get_params(model):
    #             if k in f.keys():
    #                 v.grad = f.get_tensor(k).to(v)

    noisy_latents = timesteps = time_ids = None
    def get_pred(args):
        nonlocal noisy_latents, timesteps, time_ids
        model, tokens = args
        txt = model.text_encoder_combined(tokens[0])
        return model.unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=txt["conds"],
            added_cond_kwargs={
                "text_embeds": txt["pooled"],
                "time_ids": time_ids,
            },
        ).sample

    params = list(v for k, v in itertools.chain(get_params(model_a), get_params(model_b)))
    with ThreadPoolExecutor(max_workers=2) as worker:
        for epoch_i in range(epochs):
            for step_i, (latent_infos, tokens_a, tokens_b, post_ids) in enumerate(tqdm(data_loader)):
                latents = torch.cat([latent_info["latent"] for latent_info in latent_infos], dim=0).to(device=device, dtype=dtype)
                crop_hw = torch.stack([latent_info["crop_hw"] for latent_info in latent_infos]).to(device=device)
                orig_hw = torch.stack([latent_info["orig_hw"] for latent_info in latent_infos]).to(device=device)

                noise, noisy_latents, timesteps = get_noise_noisy_latents_and_timesteps(scheduler, latents)
                time_ids = get_add_time_ids(orig_hw, crop_hw)

                # if step_i < 1000:
                #     total_steps += batch_size
                #     continue

                pred_a, pred_b = worker.map(get_pred, ((model_a, tokens_a), (model_b, tokens_b)))

                mse = torch.nn.functional.mse_loss(pred_a, pred_b, reduction="none").flatten(start_dim=1).mean(dim=-1)
                loss = (mse / mse.detach()).mean()
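                # mse / mse.detach() equals 1 in value but backpropagates
                # grad(mse) / mse, i.e. each step's gradient is normalized by
                # its own loss magnitude before the hooks above accumulate it.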

                train_logger.add_scalar("grads/loss", loss.item(), total_steps)
                train_logger.add_scalar("grads/loss_raw", mse.mean().item(), total_steps)
                train_logger.add_scalar("grads/timestep", timesteps[0].item(), total_steps)

                torch.autograd.grad(loss, params, retain_graph=False, allow_unused=True)  # calls backward hooks

                for (k, v_a), (k_b, v_b) in zip(get_params(model_a), get_params(model_b)):
                    assert k == k_b
                    if v_a.grad is not None and v_b.grad is not None:
                        while log_scalars_a[k] > log_scalars_b[k]:
                            log_scalars_a[k] -= 1
                            v_a.grad.ldexp_(n1)
                        while log_scalars_b[k] > log_scalars_a[k]:
                            log_scalars_b[k] -= 1
                            v_b.grad.ldexp_(n1)
                        log_scalars_sync[k] = log_scalars_a[k]

                if (step_i + 1) % 10 == 0:
                    train_logger.add_scalar("grads/max_a", max(v.grad.max().item() for k, v in get_params(model_a) if v.grad is not None), total_steps)
                    train_logger.add_scalar("grads/max_b", max(v.grad.max().item() for k, v in get_params(model_b) if v.grad is not None), total_steps)

                if (step_i + 1) % 1000 == 0:
                    save_grads(model_a, "grads_a.safetensors", first=True)
                    safetensors.torch.save_file(log_scalars_sync, "log_scalars.safetensors")
                    save_grads(model_b, "grads_b.safetensors", last=True)

                total_steps += batch_size


def get_modules(model):
    return itertools.chain(
        prefix_iter(model.unet.named_modules(), "unet."),
        prefix_iter(model.text_encoder.named_modules(), "text_encoder."),
        prefix_iter(model.text_encoder_2.named_modules(), "text_encoder_2."),
    )


def get_params(model):
    return itertools.chain(
        prefix_iter(model.unet.named_parameters(), "unet."),
        prefix_iter(model.text_encoder.named_parameters(), "text_encoder."),
        prefix_iter(model.text_encoder_2.named_parameters(), "text_encoder_2."),
    )


def prefix_iter(item_iter, prefix):
    return ((prefix + k, v) for k, v in item_iter)


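# Handshake with download_checkpoints.py: save_grads(first=True) waits for any
# previous safetensors.lock to disappear before overwriting grads_a, the
# synchronized log scalars are written next, and save_grads(last=True) creates
# the lock again so the local script knows a complete checkpoint set is ready.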
def save_grads(model, path, first=False, last=False):
    if first:
        wait_for_lock_removal()

    safetensors.torch.save_file(
        {k: v.grad.cpu().contiguous() for k, v in get_params(model) if v.grad is not None},
        path,
    )

    if last:
        # Create a lock file to signal that new checkpoints have been saved
        with open(LOCK_FILE, "w") as f:
            f.write("pending download")
        print("Checkpoint pair saved, lock file created.")


def wait_for_lock_removal(poll_interval=5):
    """Wait until the lock file is removed by the local download script."""
    while os.path.exists(LOCK_FILE):
        time.sleep(poll_interval)


def create_scheduler(device: torch.device):
    scheduler = diffusers.DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        clip_sample=False,
    )

    inv_snr = ((1-scheduler.alphas_cumprod) / scheduler.alphas_cumprod).to(device)
    scheduler.inv_snr = inv_snr
    scheduler.inv_snr_weights = inv_snr / inv_snr.sum()
    return scheduler


def debiased_loss_scaling(timesteps, noise_scheduler):
    return noise_scheduler.inv_snr[timesteps]


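# Timesteps are drawn from a multinomial over inv_snr_weights (the normalized
# (1 - alpha_bar) / alpha_bar set up in create_scheduler), so noisier timesteps
# are sampled far more often than near-clean ones; debiased_loss_scaling looks
# intended as the matching per-timestep loss weight, though it is unused here.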
def get_noise_noisy_latents_and_timesteps(scheduler, latents):
    batch_size = latents.shape[0]
    noise = torch.randn_like(latents, device=latents.device)

    timesteps = torch.multinomial(scheduler.inv_snr_weights, batch_size)
    noisy_latents = scheduler.add_noise(latents, noise, timesteps)
    return noise, noisy_latents, timesteps


def get_add_time_ids(original_size, crops_coords_top_left):
    add_time_ids = torch.cat([
        original_size,
        crops_coords_top_left,
        torch.tensor([[1024]*2], device=original_size.device).expand(len(original_size), -1),
    ], dim=1)

    return add_time_ids


def get_data_loader(tokenizer, batch_size: int):
    return torch.utils.data.DataLoader(
        PromptDataset(tokenizer),
        batch_size=batch_size,
        shuffle=True,
        collate_fn=lambda x: zip(*x),
    )


@dataclasses.dataclass
class ArtistScore:
    artist_tag: str
    count: int


class PromptDataset(torch.utils.data.Dataset):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.latent_paths = list(LATENTS_OUTPUT_DIR.iterdir())
        with open(DANBOORU_ARTISTS_PATH, "r", encoding='utf-8') as f:
            reader = csv.DictReader(f)
            self.b_artists = [ArtistScore(r["trigger"], int(r["count"])) for r in reader if r["artist"] != "banned_artist"]
        self.b_artists.sort(key=lambda t: t.count, reverse=True)
        self.b_artist_scores = torch.tensor(list(map(lambda t: t.count, self.b_artists)), device=device, dtype=torch.float32)
        self.b_artist_scores /= self.b_artist_scores.sum()

        with open(E621_ARTISTS_PATH, "r", encoding='utf-8') as f:
            reader = csv.DictReader(f)
            self.a_artists = self.b_artists + [ArtistScore(r["trigger"], int(r["count"])) for r in reader if r["artist"] not in ["conditional_dnp", "avoid_posting", "unknown_artist", "third-party_edit", "sound_warning", "anonymous_artist"]]
        self.a_artists.sort(key=lambda t: t.count, reverse=True)
        self.a_artist_scores = torch.tensor(list(map(lambda t: t.count, self.a_artists)), device=device, dtype=torch.float32)
        self.a_artist_scores /= self.a_artist_scores.sum()

        self.a_prefix = "masterpiece, best quality, newest, absurdres, highres, safe, "
        self.b_suffix = ", masterpiece, high score, great score, absurdres"

    def __len__(self):
        return len(self.latent_paths)

    def __getitem__(self, item):
        post_id = self.latent_paths[item].stem
        latent = safetensors.torch.load_file(LATENTS_OUTPUT_DIR / f"{post_id}.safetensors", device=str(device))
        caption = (CAPTIONS_OUTPUT_DIR / f"{post_id}.txt").read_text()

        caption_a = self.a_prefix + caption
        caption_b = caption + self.b_suffix

        if item % 2 == 0:
            artist_a = self.a_artists[torch.multinomial(self.a_artist_scores, 1).item()]
            caption_a = artist_a.artist_tag + ", " + caption_a
        else:
            artist_b = self.b_artists[torch.multinomial(self.b_artist_scores, 1).item()]
            caption_b = artist_b.artist_tag + ", " + caption_b

        tokens_a = self.tokenizer.chunk_tokens(self.tokenizer([caption_a.replace("),", ") ,")]))
        tokens_b = self.tokenizer.chunk_tokens(self.tokenizer([caption_b.replace("),", ") ,")]))
        return latent, tokens_a, tokens_b, post_id


class CombinedCLIPTextEncoder(torch.nn.Module):
    def __init__(self, clip_l, clip_g, batch_size):
        super().__init__()
        assert batch_size == 1
        self.clip_l = clip_l
        self.clip_g = clip_g

    def forward(self, tokens):
        tokens_clip_l = tokens["clip_l"].copy()
        del tokens_clip_l["prompt_starts"]

        tokens_clip_g = tokens["clip_g"].copy()
        clip_g_starts = tokens_clip_g.pop("prompt_starts")

        clip_l_encoded = self.clip_l(**tokens_clip_l, output_hidden_states=True, return_dict=True)
        clip_g_encoded = self.clip_g(**tokens_clip_g, output_hidden_states=True, return_dict=True)
        combined_encoded = torch.cat([clip_l_encoded["hidden_states"][-2], clip_g_encoded["hidden_states"][-2]], dim=-1)
        combined_encoded_reshape = combined_encoded.reshape(1, -1, 2048)

        return {
            "conds": combined_encoded_reshape,
            "pooled": clip_g_encoded.text_embeds[clip_g_starts],
        }


def create_tokenizer(device: torch.device):
    tokenizer_l = transformers.CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    tokenizer_g = transformers.CLIPTokenizer.from_pretrained("laion/CLIP-ViT-g-14-laion2B-s34B-b88K")
    return CombinedCLIPTokenizer(tokenizer_l, tokenizer_g, device)


class CombinedCLIPTokenizer(torch.nn.Module):
    comma_token = 267

    def __init__(self, tokenizer_l, tokenizer_g, output_device: torch.device):
        super().__init__()
        self.tokenizer_l = tokenizer_l
        self.tokenizer_g = tokenizer_g
        self.output_device = output_device

    def forward(self, prompts: List[str]) -> dict:
        tokens_l = self.tokenizer_l(prompts, add_special_tokens=False)
        return {
            "clip_l": tokens_l,
            "clip_g": deepcopy(tokens_l),
        }

    def chunk_tokens(self, tokens: dict):
        return {
            "clip_l": self._chunk_tokens_impl(self.tokenizer_l, tokens["clip_l"]),
            "clip_g": self._chunk_tokens_impl(self.tokenizer_g, tokens["clip_g"]),
        }

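    # Prompts longer than CLIP's 77-token window are split into several
    # 77-token chunks, breaking only at commas (token id 267) so individual
    # tags are never cut in half; each chunk gets its own BOS/EOS plus padding,
    # and "prompt_starts" records the index of every prompt's first chunk,
    # later used to pick the pooled text embedding.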
    def _chunk_tokens_impl(self, tokenizer, tokens: dict):
        input_ids = []
        attention_masks = []
        chunk_counts = []

        for prompt, mask in zip(tokens["input_ids"], tokens["attention_mask"]):
            last_comma = 0
            current_chunk = []
            chunks = []
            chunks_attn = []

            def next_chunk():
                nonlocal current_chunk
                current_chunk = [tokenizer.bos_token_id] + current_chunk + [tokenizer.eos_token_id]
                num_tokens = len(current_chunk)

                current_chunk.extend([tokenizer.pad_token_id] * (77 - num_tokens))
                chunks.append(current_chunk)
                current_chunk = []
                chunks_attn.append([1] * num_tokens + [0] * (77 - num_tokens))

            for token_i, token in enumerate(prompt):
                is_last_token = token_i == len(prompt) - 1
                seq_suffix = prompt[last_comma:token_i + int(is_last_token)]

                if token == self.comma_token or is_last_token:
                    if len(current_chunk) + len(seq_suffix) > 77 - 2:  # leave space for bos and eos
                        next_chunk()
                        seq_suffix = prompt[last_comma+1:token_i + int(is_last_token)]  # remove leading comma

                    # can always append, sequences without commas will never be longer than 77 tokens
                    current_chunk.extend(seq_suffix)
                    last_comma = token_i

            if current_chunk or not chunks:
                next_chunk()

            chunk_counts.append(len(chunks))
            input_ids.extend(chunks)
            attention_masks.extend(chunks_attn)

        return {
            "input_ids": torch.tensor(input_ids, device=self.output_device),
            "attention_mask": torch.tensor(attention_masks, device=self.output_device),
            "prompt_starts": torch.tensor([0] + chunk_counts[:-1], device=self.output_device).cumsum(dim=0),
        }


def shutdown_machine():
    """Shutdown the machine. Adjust the command as necessary for your environment."""

    wait_for_lock_removal()
    print("All checkpoints have been downloaded. Shutting down the machine.")
    try:
        subprocess.run("runpodctl stop pod $RUNPOD_POD_ID", shell=True, check=True)
    except Exception as e:
        print(f"Error shutting down: {e}")


if __name__ == "__main__":
    accumulate_grads()
    shutdown_machine()
prepare_dataset.py
ADDED
@@ -0,0 +1,320 @@
import pathlib
import random
from copy import deepcopy
from typing import List
import diffusers
import torch
import safetensors.torch
import transformers
from PIL import Image
from diffusers import AutoencoderKL, StableDiffusionXLPipeline
import torchvision.transforms as T
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed, wait, FIRST_COMPLETED
import threading
import dataclasses


devices = [torch.device("cuda:0"), torch.device("cuda:1"), torch.device("cuda:2")]
dtypes = [torch.bfloat16, torch.float32, torch.float32]


VAE_PATH = "KBlueLeaf/EQ-SDXL-VAE"
SDXL_PATH = "/home/ljleb/sd/models/Stable-diffusion/noobaiXLNAIXL_epsilonPred11Version.safetensors"
IMAGES_DIR = pathlib.Path("/mnt/data/shared/danbooru")
LATENT_DIR = pathlib.Path("/mnt/data/shared/danbooru-latent")


@dataclasses.dataclass
class Worker:
    device: torch.device
    dtype: torch.dtype
    vae_w = None
    sdxl = None
    tokenizer = None

    def __post_init__(self):
        self.vae_w = AutoencoderKL.from_pretrained(VAE_PATH, torch_dtype=self.dtype).to(self.device)
        self.vae_w.eval()

        self.sdxl = StableDiffusionXLPipeline.from_single_file(SDXL_PATH, torch_dtype=self.dtype).to(self.device)
        self.sdxl.unet.eval()
        self.sdxl.vae.eval()
        self.sdxl.text_encoder.eval()
        self.sdxl.text_encoder_2.eval()

        self.sdxl.text_encoder_combined = CombinedCLIPTextEncoder(self.sdxl.text_encoder, self.sdxl.text_encoder_2, self.device)
        self.tokenizer = create_tokenizer(self.device)

        self.scheduler = create_scheduler()


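# One Worker (its own VAE, SDXL pipeline, tokenizer and scheduler) is built per
# CUDA device; main() keeps at most one image in flight per device and uses
# wait(..., FIRST_COMPLETED) to free a device before submitting the next image.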
def main():
    images = list(IMAGES_DIR.iterdir())
    LATENT_DIR.mkdir(exist_ok=True)
    workers = [
        Worker(device, dtype)
        for device, dtype in zip(devices, dtypes)
    ]
    with ThreadPoolExecutor(max_workers=len(workers)) as executor:
        futures = {}
        for image in tqdm(images):
            if len(futures) >= len(workers):
                completed_futures, _ = wait(list(futures.values()), return_when=FIRST_COMPLETED)
                for future in completed_futures:
                    if future.exception() is not None:
                        for future_to_cancel in futures.values():
                            future_to_cancel.cancel()
                        raise future.exception()
                    else:
                        future.result()
                futures = {
                    k: v for k, v in futures.items()
                    if v not in completed_futures
                }

            for worker in workers:
                if worker.device not in futures:
                    futures[worker.device] = executor.submit(prepare_image, worker, image)
                    break

        for future in futures.values():
            if future.exception() is not None:
                for future_to_cancel in futures.values():
                    future_to_cancel.cancel()
                raise future.exception()
            else:
                future.result()


@torch.no_grad()
def prepare_image(worker: Worker, img_path: pathlib.Path):
    # We'll define a transform to convert an image to a tensor
    to_tensor = T.Compose([
        T.ToTensor(),
        T.Lambda(lambda t: t*2 - 1)
    ])

    # w_0_offset = torch.tensor([-3.8846, -1.3187, 0.8009, 0.9180], device=device, dtype=dtype)
    # w_0_scale = torch.tensor([10.0298, 6.8674, 7.2104, 5.5948], device=device, dtype=dtype)

    # Iterate over images in directory
    if not img_path.is_file():
        return
    if img_path.suffix.lower() not in [".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff"]:
        return

    # Attempt to open image
    try:
        img = Image.open(img_path).convert("RGB")
    except Exception as e:
        print(f"Error loading image {img_path.name}: {e}")
        return

    # Read the caption from the matching .txt file (if it exists)
    txt_path = img_path.with_suffix(img_path.suffix + ".txt")
    if not txt_path.is_file():
        print(f"No caption file for {img_path.name}, skipping.")
        return

    caption = txt_path.read_text(encoding="utf-8").strip()
    if not caption:
        print(f"Empty caption for {img_path.name}, skipping.")
        return

    out_path = LATENT_DIR / (img_path.stem + ".safetensors")
    if out_path.exists():
        return

    caption = caption.replace("\n", " , ").replace("_", " ")

    width, height = img.size
    orig_pixels = width * height
    target_pixels = 1024 * 1024
    if orig_pixels > target_pixels:
        scale = (target_pixels / float(orig_pixels)) ** 0.5
        width = int(round(width * scale))
        height = int(round(height * scale))
        img = img.resize((width, height), Image.Resampling.LANCZOS)

    tokens_raw = worker.tokenizer([caption])
    tokens = worker.tokenizer.chunk_tokens(tokens_raw)

    # Convert image to tensor on device
    img_tensor = to_tensor(img).unsqueeze(0).to(device=worker.device, dtype=worker.dtype)

    # Encode the image with each VAE
    with torch.no_grad():
        latents_w_unnorm = worker.vae_w.encode(img_tensor).latent_dist.sample()
        latents_z = worker.sdxl.vae.encode(img_tensor).latent_dist.sample() * 0.13025

    # Sample noise and a random timestep
    noise, noisy_latents_z, timesteps = get_noise_noisy_latents_and_timesteps(worker.scheduler, latents_z)
    time_ids = get_add_time_ids(height, width, worker.device)
    embeds = worker.sdxl.text_encoder_combined(tokens)

    epsilon_pred = get_pred(worker.sdxl, noisy_latents_z, embeds, timesteps, time_ids)

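    # Cached per image for later training: the unnormalized EQ-SDXL-VAE latent,
    # the standard SDXL VAE latent (scaled by 0.13025), the teacher's epsilon
    # prediction at the sampled timestep, the noise itself, and the text-encoder
    # outputs, presumably so downstream runs never need the VAEs or UNet again.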
    encoded = {
        "timesteps": timesteps,
        "hw": torch.tensor([[height, width]], dtype=torch.long),
        "w_0_unnorm": latents_w_unnorm,
        "z_0": latents_z,
        "epsilon_pred": epsilon_pred,
        "epsilon": noise,
        "conds": embeds["conds"],
        "pooled": embeds["pooled"],
    }

    safetensors.torch.save_file(encoded, str(out_path))


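# SDXL micro-conditioning: add_time_ids is (original_h, original_w, crop_top,
# crop_left, target_h, target_w); crops are fixed at (0, 0) and the target size
# equals the (possibly downscaled) original size.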
def get_add_time_ids(height, width, device):
    original_size = torch.tensor([[height, width]], device=device)
    add_time_ids = torch.cat([
        original_size,
        torch.tensor([[0]*2], device=device).expand(len(original_size), -1),
        original_size,
    ], dim=1)
    return add_time_ids


def get_pred(sdxl, noisy_latents, embeds, timesteps, time_ids):
    return sdxl.unet(
        noisy_latents,
        timesteps,
        encoder_hidden_states=embeds["conds"],
        added_cond_kwargs={
            "text_embeds": embeds["pooled"],
            "time_ids": time_ids,
        },
    ).sample


def get_noise_noisy_latents_and_timesteps(scheduler, latents):
    noise = torch.randn_like(latents, device=latents.device)
    batch_size = latents.shape[0]
    timesteps = torch.randint(0, 999, (batch_size,), device=latents.device)
    noisy_latents = scheduler.add_noise(latents, noise, timesteps)
    return noise, noisy_latents, timesteps


def create_scheduler():
    scheduler = diffusers.DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        clip_sample=False,
    )

    return scheduler


class CombinedCLIPTextEncoder(torch.nn.Module):
    def __init__(self, clip_l, clip_g, device):
        super().__init__()
        self.clip_l = clip_l.to(device=device)
        self.clip_g = clip_g.to(device=device)
        self.device = device

    def forward(self, tokens_batch):
        res = {
            "conds": torch.tensor([], device=self.device).view(0, 1, 1),
            "pooled": torch.tensor([], device=self.device).view(0, 1, 1),
        }
        tokens_clip_l = tokens_batch["clip_l"].copy()
        del tokens_clip_l["prompt_starts"]

        tokens_clip_g = tokens_batch["clip_g"].copy()
        clip_g_starts = tokens_clip_g.pop("prompt_starts")

        clip_l_encoded = self.clip_l(**tokens_clip_l, output_hidden_states=True, return_dict=True)
        clip_g_encoded = self.clip_g(**tokens_clip_g, output_hidden_states=True, return_dict=True)
        combined_encoded = torch.cat([clip_l_encoded["hidden_states"][-2], clip_g_encoded["hidden_states"][-2]], dim=-1)
        combined_encoded_reshape = combined_encoded.reshape(1, -1, 2048)

        res["conds"] = combined_encoded_reshape
        res["pooled"] = clip_g_encoded.text_embeds[clip_g_starts]

        return res


def create_tokenizer(device: torch.device):
    tokenizer_l = transformers.CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    tokenizer_g = transformers.CLIPTokenizer.from_pretrained("laion/CLIP-ViT-g-14-laion2B-s34B-b88K")
    return CombinedCLIPTokenizer(tokenizer_l, tokenizer_g, device)


class CombinedCLIPTokenizer(torch.nn.Module):
    comma_token = 267

    def __init__(self, tokenizer_l, tokenizer_g, output_device: torch.device):
        super().__init__()
        self.tokenizer_l = tokenizer_l
        self.tokenizer_g = tokenizer_g
        self.output_device = output_device

    def forward(self, prompts: List[str]) -> dict:
        tokens_l = self.tokenizer_l(prompts, add_special_tokens=False)
        return {
            "clip_l": tokens_l,
            "clip_g": deepcopy(tokens_l),
        }

    def chunk_tokens(self, tokens: dict):
        return {
            "clip_l": self._chunk_tokens_impl(self.tokenizer_l, tokens["clip_l"]),
            "clip_g": self._chunk_tokens_impl(self.tokenizer_g, tokens["clip_g"]),
        }

    def _chunk_tokens_impl(self, tokenizer, tokens: dict):
        input_ids = []
        attention_masks = []
        chunk_counts = []

        for prompt, mask in zip(tokens["input_ids"], tokens["attention_mask"]):
            last_comma = 0
            current_chunk = []
            chunks = []
            chunks_attn = []

            def next_chunk():
                nonlocal current_chunk
                current_chunk = [tokenizer.bos_token_id] + current_chunk + [tokenizer.eos_token_id]
                num_tokens = len(current_chunk)

                current_chunk.extend([tokenizer.pad_token_id] * (77 - num_tokens))
                chunks.append(current_chunk)
                current_chunk = []
                chunks_attn.append([1] * num_tokens + [0] * (77 - num_tokens))

            for token_i, token in enumerate(prompt):
                is_last_token = token_i == len(prompt) - 1
                seq_suffix = prompt[last_comma:token_i + int(is_last_token)]

                if token == self.comma_token or is_last_token:
                    if len(current_chunk) + len(seq_suffix) > 77 - 2:  # leave space for bos and eos
                        next_chunk()
                        seq_suffix = prompt[last_comma+1:token_i + int(is_last_token)]  # remove leading comma

                    # can always append, sequences without commas will never be longer than 77 tokens
                    current_chunk.extend(seq_suffix)
                    last_comma = token_i

            if current_chunk or not chunks:
                next_chunk()

            chunk_counts.append(len(chunks))
            input_ids.extend(chunks)
            attention_masks.extend(chunks_attn)

        return {
            "input_ids": torch.tensor(input_ids, device=self.output_device),
            "attention_mask": torch.tensor(attention_masks, device=self.output_device),
            "prompt_starts": torch.tensor([0] + chunk_counts[:-1], device=self.output_device).cumsum(dim=0),
        }


if __name__ == "__main__":
    main()