ljleb committed on
Commit 01f5cc2 · verified · 1 Parent(s): 83bd1f3

Upload 3 files

Files changed (3)
  1. download_checkpoints.py +139 -0
  2. joint_loss.py +510 -0
  3. prepare_dataset.py +320 -0
download_checkpoints.py ADDED
@@ -0,0 +1,139 @@
+import subprocess
+import sys
+import time
+import argparse
+import os
+import datetime
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Download checkpoint pair from remote runpod machine with unique filenames.")
+    parser.add_argument("--remote-ip", required=True, help="Remote machine IP address")
+    parser.add_argument("--remote-port", required=True, type=int, help="Remote SSH port")
+    parser.add_argument("--remote-user", required=True, help="Username for remote SSH")
+    parser.add_argument("--remote-base-path", default="/workspace", help="Directory on remote machine containing checkpoints and lock file")
+    parser.add_argument("--local-dest", required=True, help="Local directory where checkpoints should be saved")
+    parser.add_argument("--rsa-key", required=True, help="Path to your RSA private key for authentication")
+    parser.add_argument("--poll-interval", type=float, default=10, help="Polling interval in seconds")
+
+    args = parser.parse_args()
+
+    # Construct remote file paths.
+    remote_checkpoint_a = f"{args.remote_base_path}/grads_a.safetensors"
+    remote_checkpoint_b = f"{args.remote_base_path}/grads_b.safetensors"
+    remote_inv_log_scalars = f"{args.remote_base_path}/log_scalars.safetensors"
+    remote_thresholds = f"{args.remote_base_path}/thresholds.safetensors"
+    remote_lock_file = f"{args.remote_base_path}/safetensors.lock"
+
+    print("Starting remote checkpoint monitor...")
+    while True:
+        # Check if the lock file exists on the remote machine.
+        if remote_file_exists(args.remote_user, args.remote_ip, args.remote_port, remote_lock_file, args.rsa_key):
+            print("New checkpoints detected. Downloading...")
+
+            # Generate unique filenames for each model.
+            local_checkpoint_a = get_unique_filename(args.local_dest, "grads_a")
+            local_checkpoint_b = get_unique_filename(args.local_dest, "grads_b")
+            local_inv_log_scalars = get_unique_filename(args.local_dest, "log_scalars")
+            local_thresholds = get_unique_filename(args.local_dest, "thresholds")
+
+            try:
+                # Download both checkpoints with the unique filenames.
+                download_file(args.remote_user, args.remote_ip, args.remote_port, remote_checkpoint_a, local_checkpoint_a, args.rsa_key)
+                download_file(args.remote_user, args.remote_ip, args.remote_port, remote_checkpoint_b, local_checkpoint_b, args.rsa_key)
+                download_file(args.remote_user, args.remote_ip, args.remote_port, remote_inv_log_scalars, local_inv_log_scalars, args.rsa_key)
+                # download_file(args.remote_user, args.remote_ip, args.remote_port, remote_thresholds, local_thresholds, args.rsa_key)
+            except subprocess.CalledProcessError as e:
+                print(f"Download error: {e}")
+                time.sleep(args.poll_interval)
+                continue
+
+            # After successful download, delete only the lock file on the remote side.
+            try:
+                while not delete_remote_lock(args.remote_user, args.remote_ip, args.remote_port, remote_lock_file, args.rsa_key):
+                    continue
+
+                print("Download complete. Checkpoints saved as:")
+                print(f" {local_checkpoint_a}")
+                print(f" {local_checkpoint_b}")
+                print("Remote lock file deleted.")
+            except subprocess.CalledProcessError as e:
+                print(f"Error deleting remote lock file: {e}")
+        else:
+            print("No checkpoints found.")
+
+        time.sleep(args.poll_interval)
+
+
+def remote_file_exists(remote_user, remote_host, remote_port, remote_path, rsa_key, timeout=10):
+    """Check if a file exists on the remote machine."""
+    cmd = [
+        "ssh",
+        "-i", rsa_key,
+        "-p", str(remote_port),
+        "-o", "StrictHostKeyChecking=no",
+        "-o", "UserKnownHostsFile=/dev/null",
+        f"{remote_user}@{remote_host}",
+        f"test -f {remote_path}"
+    ]
+    try:
+        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=timeout)
+        if result.stdout: print("stdout", result.stdout.decode("utf-8"), end="")
+        if result.stderr: print("stderr", result.stderr.decode("utf-8"), file=sys.stderr, end="")
+        return result.returncode == 0
+    except subprocess.TimeoutExpired:
+        print(f"TimeoutExpired: SSH command to check {remote_path} on {remote_host} timed out after {timeout} seconds.")
+        return False
+
+
+def download_file(remote_user, remote_host, remote_port, remote_file, local_file, rsa_key, timeout=1200):
+    """Download a file from the remote machine using scp and save it with a specific name."""
+    cmd = [
+        "scp",
+        "-i", rsa_key,
+        "-P", str(remote_port),
+        "-o", "StrictHostKeyChecking=no",
+        "-o", "UserKnownHostsFile=/dev/null",
+        f"{remote_user}@{remote_host}:{remote_file}",
+        str(local_file)
+    ]
+    try:
+        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=timeout)
+        if result.stdout: print("stdout", result.stdout.decode("utf-8"), end="")
+        if result.stderr: print("stderr", result.stderr.decode("utf-8"), file=sys.stderr, end="")
+        return result.returncode == 0
+    except subprocess.TimeoutExpired:
+        print(f"TimeoutExpired: SSH command to download {remote_file} on {remote_host} timed out after {timeout} seconds.")
+        return False
+
+
+def delete_remote_lock(remote_user, remote_host, remote_port, remote_lock_file, rsa_key, timeout=10):
+    """Delete the lock file on the remote machine."""
+    cmd = [
+        "ssh",
+        "-i", rsa_key,
+        "-p", str(remote_port),
+        "-o", "StrictHostKeyChecking=no",
+        "-o", "UserKnownHostsFile=/dev/null",
+        f"{remote_user}@{remote_host}",
+        f"rm -f {remote_lock_file}"
+    ]
+    try:
+        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=timeout)
+        if result.stdout: print("stdout", result.stdout.decode("utf-8"), end="")
+        if result.stderr: print("stderr", result.stderr.decode("utf-8"), file=sys.stderr, end="")
+        return result.returncode == 0
+    except subprocess.TimeoutExpired:
+        print(f"TimeoutExpired: SSH command to delete {remote_lock_file} on {remote_host} timed out after {timeout} seconds.")
+        return False
+
+
+def get_unique_filename(local_dest, base_name):
+    """Generate a unique filename with a timestamp and return the full path."""
+    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+    filename = f"{base_name}_{timestamp}.safetensors"
+    return os.path.join(local_dest, filename)
+
+
+if __name__ == "__main__":
+    main()
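
For reference, a hypothetical invocation of the poller; every argument value below is a placeholder, not something recorded in this commit.

    # Hypothetical example of running download_checkpoints.py against a remote pod.
    # All values are placeholders; substitute your own host, port, paths and key.
    import subprocess

    subprocess.run([
        "python", "download_checkpoints.py",
        "--remote-ip", "203.0.113.10",        # placeholder IP
        "--remote-port", "22",                # placeholder SSH port
        "--remote-user", "root",
        "--remote-base-path", "/workspace",
        "--local-dest", "./checkpoints",
        "--rsa-key", "/path/to/id_rsa",       # placeholder key path
        "--poll-interval", "10",
    ], check=True)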
joint_loss.py ADDED
@@ -0,0 +1,510 @@
+import csv
+import dataclasses
+import subprocess
+from copy import deepcopy
+import itertools
+from concurrent.futures import ThreadPoolExecutor
+import pathlib
+from typing import List
+import diffusers
+import transformers
+import safetensors.torch
+import torch.utils.data
+from tqdm import tqdm
+from datetime import datetime
+import random
+import os
+import time
+from torch.utils.tensorboard import SummaryWriter
+
+
+torch.manual_seed(0)
+random.seed(0)
+
+
+LATENTS_OUTPUT_DIR = pathlib.Path("latents")
+CAPTIONS_OUTPUT_DIR = pathlib.Path("captions2")
+DANBOORU_ARTISTS_PATH = pathlib.Path("danbooru_artist.csv")
+E621_ARTISTS_PATH = pathlib.Path("e621_artist.csv")
+LOCK_FILE = "safetensors.lock"
+
+
+device = torch.device("cuda")
+dtype = torch.float16
+
+
+train_logger = SummaryWriter(f"logs/pony_scoreless_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
+
+
+def accumulate_grads():
+    batch_size = 1
+    epochs = 1
+
+    tokenizer = create_tokenizer(device)
+
+    model_a = diffusers.StableDiffusionXLPipeline.from_single_file(
+        "NoobAI-XL-v1.1.safetensors",
+        torch_dtype=dtype,
+    )
+    delattr(model_a, "vae")
+    model_a.unet.to(device=device)
+    # model_a.unet.enable_xformers_memory_efficient_attention()
+    model_a.unet.enable_gradient_checkpointing()
+    model_a.text_encoder.to(device=device)
+    model_a.text_encoder.gradient_checkpointing_enable()
+    model_a.text_encoder_2.to(device=device)
+    model_a.text_encoder_2.gradient_checkpointing_enable()
+    model_a.text_encoder_combined = CombinedCLIPTextEncoder(model_a.text_encoder, model_a.text_encoder_2, batch_size)
+
+    model_b = diffusers.StableDiffusionXLPipeline.from_single_file(
+        "animagine-xl-4.0.safetensors",
+        torch_dtype=dtype,
+    )
+    delattr(model_b, "vae")
+    model_b.unet.to(device=device)
+    # model_b.unet.enable_xformers_memory_efficient_attention()
+    model_b.unet.enable_gradient_checkpointing()
+    model_b.text_encoder.to(device=device)
+    model_b.text_encoder.gradient_checkpointing_enable()
+    model_b.text_encoder_2.to(device=device)
+    model_b.text_encoder_2.gradient_checkpointing_enable()
+    model_b.text_encoder_combined = CombinedCLIPTextEncoder(model_b.text_encoder, model_b.text_encoder_2, batch_size)
+
+    model_a.unet.eval()
+    model_a.text_encoder.eval()
+    model_a.text_encoder_2.eval()
+    model_b.unet.eval()
+    model_b.text_encoder.eval()
+    model_b.text_encoder_2.eval()
+
+    # shared_stats = {}
+    # stats_lock = threading.Lock()
+
+    # # Two barriers for synchronization between two threads.
+    # grad_barrier1 = threading.Barrier(2)
+    # grad_barrier2 = threading.Barrier(2)
+
+    # def scaling_hook_factory(key, branch_id, target_scale=1.0):
+    #     nonlocal shared_stats, stats_lock, grad_barrier1, grad_barrier2
+
+    #     def scaling_hook(_module, _grad_input, grad_output):
+    #         """
+    #         A full-backward hook that:
+    #         1. Computes, for each non-None tensor in grad_output, its maximum absolute value.
+    #            We store these in a dictionary (keyed by output index).
+    #         2. Waits once until both threads have stored their local max values.
+    #         3. Computes, for each output index, the global maximum from both models.
+    #         4. Waits a second time to ensure synchronization before clearing the shared stats.
+    #         5. Scales each non-None output tensor independently using its computed scaling factor.
+    #         Outputs that are None are passed through unchanged.
+    #         """
+    #         # Step 1: Compute and store local maximums per output index.
+    #         print(f"backprop for {key}")
+    #         local_maxes = {}
+    #         for i, g in enumerate(grad_output):
+    #             if g is not None:
+    #                 local_maxes[i] = g.detach().abs().max().cpu().item()
+
+    #         with stats_lock:
+    #             shared_stats[f"{key}_{branch_id}"] = local_maxes
+
+    #         # Step 2: Wait until both threads have stored their values.
+    #         grad_barrier1.wait()
+
+    #         # Step 3: Compute the global maximum for each output index.
+    #         with stats_lock:
+    #             stats_a = shared_stats.get(f"{key}_a", {})
+    #             stats_b = shared_stats.get(f"{key}_b", {})
+    #             # Build a dictionary for global max per output index.
+    #             global_maxes = {}
+    #             for i in local_maxes.keys():
+    #                 assert i in stats_a and i in stats_b, key
+    #                 global_maxes[i] = max(stats_a[i], stats_b[i])
+
+    #         # Step 4: Wait again to ensure both threads have computed the global values.
+    #         barrier_val = grad_barrier2.wait()
+    #         # Let only one thread clear the shared stats.
+    #         if barrier_val == 0:
+    #             with stats_lock:
+    #                 shared_stats.pop(f"{key}_a")
+    #                 shared_stats.pop(f"{key}_b")
+
+    #         # Step 5: For each output tensor, compute a scaling factor and apply it.
+    #         scaled_outputs = []
+    #         for i, g in enumerate(grad_output):
+    #             if g is not None:
+    #                 global_max = global_maxes[i]
+    #                 # Compute scaling factor only if global_max is positive and below target_scale.
+    #                 if 0 < global_max < target_scale:
+    #                     g = g * (target_scale / global_max)
+    #                 scaled_outputs.append(g)
+    #             else:
+    #                 scaled_outputs.append(None)
+
+    #         return tuple(scaled_outputs)
+
+    #     return scaling_hook
+
+    # for model, branch_id in zip((model_a, model_b), ("a", "b")):
+    #     for k, v in get_modules(model):
+    #         if k.endswith("transformer_blocks") or k.endswith("encoder.layers"):
+    #             for i, module in enumerate(v):
+    #                 module.register_full_backward_hook(scaling_hook_factory(f"{k}.{i}", branch_id))
+
+    scheduler = create_scheduler(device)
+    data_loader = get_data_loader(tokenizer, batch_size)
+    total_steps = 0
+
+    log_scalars_a = {}
+    log_scalars_b = {}
+    log_scalars_sync = {}
+
+    n1 = torch.tensor(-1, device=device, dtype=torch.long)
+    ldexp_offset = torch.tensor(20, device=device, dtype=torch.long)
+    def create_hook(param, k, log_scalars):
+        param.grad = torch.zeros_like(param)
+        log_scalars[k] = ldexp_offset.clone()
+
+        def hook(grad):
+            nonlocal param, log_scalars, k
+            while True:
+                new_grad = param.grad + grad.abs().ldexp(log_scalars[k])
+                if not new_grad.isfinite().all():  # overflow
+                    log_scalars[k] -= 1
+                    param.grad.ldexp_(n1)
+                else:
+                    break
+
+            param.grad.copy_(new_grad)
+            return param.grad
+
+        return hook
+
+    for model, log_scalars in ((model_a, log_scalars_a), (model_b, log_scalars_b)):
+        for k, v in get_params(model):
+            v.register_hook(create_hook(v, k, log_scalars))
+
+    # for model, path in ((model_a, "grads_a.safetensors"), (model_b, "grads_b.safetensors")):
+    #     with safetensors.safe_open(path, "pt") as f:
+    #         for k, v in get_params(model):
+    #             if k in f.keys():
+    #                 v.grad = f.get_tensor(k).to(v)
+
+    noisy_latents = timesteps = time_ids = None
+    def get_pred(args):
+        nonlocal noisy_latents, timesteps, time_ids
+        model, tokens = args
+        txt = model.text_encoder_combined(tokens[0])
+        return model.unet(
+            noisy_latents,
+            timesteps,
+            encoder_hidden_states=txt["conds"],
+            added_cond_kwargs={
+                "text_embeds": txt["pooled"],
+                "time_ids": time_ids,
+            },
+        ).sample
+
+    params = list(v for k, v in itertools.chain(get_params(model_a), get_params(model_b)))
+    with ThreadPoolExecutor(max_workers=2) as worker:
+        for epoch_i in range(epochs):
+            for step_i, (latent_infos, tokens_a, tokens_b, post_ids) in enumerate(tqdm(data_loader)):
+                latents = torch.cat([latent_info["latent"] for latent_info in latent_infos], dim=0).to(device=device, dtype=dtype)
+                crop_hw = torch.stack([latent_info["crop_hw"] for latent_info in latent_infos]).to(device=device)
+                orig_hw = torch.stack([latent_info["orig_hw"] for latent_info in latent_infos]).to(device=device)
+
+                noise, noisy_latents, timesteps = get_noise_noisy_latents_and_timesteps(scheduler, latents)
+                time_ids = get_add_time_ids(orig_hw, crop_hw)
+
+                # if step_i < 1000:
+                #     total_steps += batch_size
+                #     continue
+
+                pred_a, pred_b = worker.map(get_pred, ((model_a, tokens_a), (model_b, tokens_b)))
+
+                mse = torch.nn.functional.mse_loss(pred_a, pred_b, reduction="none").flatten(start_dim=1).mean(dim=-1)
+                loss = (mse / mse.detach()).mean()
+
+                train_logger.add_scalar("grads/loss", loss.item(), total_steps)
+                train_logger.add_scalar("grads/loss_raw", mse.mean().item(), total_steps)
+                train_logger.add_scalar("grads/timestep", timesteps[0].item(), total_steps)
+
+                torch.autograd.grad(loss, params, retain_graph=False, allow_unused=True)  # calls backward hooks
+
+                for (k, v_a), (k_b, v_b) in zip(get_params(model_a), get_params(model_b)):
+                    assert k == k_b
+                    if v_a.grad is not None and v_b.grad is not None:
+                        while log_scalars_a[k] > log_scalars_b[k]:
+                            log_scalars_a[k] -= 1
+                            v_a.grad.ldexp_(n1)
+                        while log_scalars_b[k] > log_scalars_a[k]:
+                            log_scalars_b[k] -= 1
+                            v_b.grad.ldexp_(n1)
+                        log_scalars_sync[k] = log_scalars_a[k]
+
+                if (step_i + 1) % 10 == 0:
+                    train_logger.add_scalar("grads/max_a", max(v.grad.max().item() for k, v in get_params(model_a) if v.grad is not None), total_steps)
+                    train_logger.add_scalar("grads/max_b", max(v.grad.max().item() for k, v in get_params(model_b) if v.grad is not None), total_steps)
+
+                if (step_i + 1) % 1000 == 0:
+                    save_grads(model_a, "grads_a.safetensors", first=True)
+                    safetensors.torch.save_file(log_scalars_sync, "log_scalars.safetensors")
+                    save_grads(model_b, "grads_b.safetensors", last=True)
+
+                total_steps += batch_size
+
+
+def get_modules(model):
+    return itertools.chain(
+        prefix_iter(model.unet.named_modules(), "unet."),
+        prefix_iter(model.text_encoder.named_modules(), "text_encoder."),
+        prefix_iter(model.text_encoder_2.named_modules(), "text_encoder_2."),
+    )
+
+
+def get_params(model):
+    return itertools.chain(
+        prefix_iter(model.unet.named_parameters(), "unet."),
+        prefix_iter(model.text_encoder.named_parameters(), "text_encoder."),
+        prefix_iter(model.text_encoder_2.named_parameters(), "text_encoder_2."),
+    )
+
+
+def prefix_iter(item_iter, prefix):
+    return ((prefix + k, v) for k, v in item_iter)
+
+
+def save_grads(model, path, first=False, last=False):
+    if first:
+        wait_for_lock_removal()
+
+    safetensors.torch.save_file(
+        {k: v.grad.cpu().contiguous() for k, v in get_params(model) if v.grad is not None},
+        path,
+    )
+
+    if last:
+        # Create a lock file to signal that new checkpoints have been saved
+        with open(LOCK_FILE, "w") as f:
+            f.write("pending download")
+        print("Checkpoint pair saved, lock file created.")
+
+
+def wait_for_lock_removal(poll_interval=5):
+    """Wait until the lock file is removed by the local download script."""
+    while os.path.exists(LOCK_FILE):
+        time.sleep(poll_interval)
+
+
+def create_scheduler(device: torch.device):
+    scheduler = diffusers.DDPMScheduler(
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        num_train_timesteps=1000,
+        clip_sample=False,
+    )
+
+    inv_snr = ((1-scheduler.alphas_cumprod) / scheduler.alphas_cumprod).to(device)
+    scheduler.inv_snr = inv_snr
+    scheduler.inv_snr_weights = inv_snr / inv_snr.sum()
+    return scheduler
+
+
+def debiased_loss_scaling(timesteps, noise_scheduler):
+    return noise_scheduler.inv_snr[timesteps]
+
+
+def get_noise_noisy_latents_and_timesteps(scheduler, latents):
+    batch_size = latents.shape[0]
+    noise = torch.randn_like(latents, device=latents.device)
+
+    timesteps = torch.multinomial(scheduler.inv_snr_weights, batch_size)
+    noisy_latents = scheduler.add_noise(latents, noise, timesteps)
+    return noise, noisy_latents, timesteps
+
+
+def get_add_time_ids(original_size, crops_coords_top_left):
+    add_time_ids = torch.cat([
+        original_size,
+        crops_coords_top_left,
+        torch.tensor([[1024]*2], device=original_size.device).expand(len(original_size), -1),
+    ], dim=1)
+
+    return add_time_ids
+
+
+def get_data_loader(tokenizer, batch_size: int):
+    return torch.utils.data.DataLoader(
+        PromptDataset(tokenizer),
+        batch_size=batch_size,
+        shuffle=True,
+        collate_fn=lambda x: zip(*x),
+    )
+
+
+@dataclasses.dataclass
+class ArtistScore:
+    artist_tag: str
+    count: int
+
+
+class PromptDataset(torch.utils.data.Dataset):
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+        self.latent_paths = list(LATENTS_OUTPUT_DIR.iterdir())
+        with open(DANBOORU_ARTISTS_PATH, "r", encoding='utf-8') as f:
+            reader = csv.DictReader(f)
+            self.b_artists = [ArtistScore(r["trigger"], int(r["count"])) for r in reader if r["artist"] != "banned_artist"]
+        self.b_artists.sort(key=lambda t: t.count, reverse=True)
+        self.b_artist_scores = torch.tensor(list(map(lambda t: t.count, self.b_artists)), device=device, dtype=torch.float32)
+        self.b_artist_scores /= self.b_artist_scores.sum()
+
+        with open(E621_ARTISTS_PATH, "r", encoding='utf-8') as f:
+            reader = csv.DictReader(f,)
+            self.a_artists = self.b_artists + [ArtistScore(r["trigger"], int(r["count"])) for r in reader if r["artist"] not in ["conditional_dnp", "avoid_posting", "unknown_artist", "third-party_edit", "sound_warning", "anonymous_artist"]]
+        self.a_artists.sort(key=lambda t: t.count, reverse=True)
+        self.a_artist_scores = torch.tensor(list(map(lambda t: t.count, self.a_artists)), device=device, dtype=torch.float32)
+        self.a_artist_scores /= self.a_artist_scores.sum()
+
+        self.a_prefix = "masterpiece, best quality, newest, absurdres, highres, safe, "
+        self.b_suffix = ", masterpiece, high score, great score, absurdres"
+
+    def __len__(self):
+        return len(self.latent_paths)
+
+    def __getitem__(self, item):
+        post_id = self.latent_paths[item].stem
+        latent = safetensors.torch.load_file(LATENTS_OUTPUT_DIR / f"{post_id}.safetensors", device=str(device))
+        caption = (CAPTIONS_OUTPUT_DIR / f"{post_id}.txt").read_text()
+
+        caption_a = self.a_prefix + caption
+        caption_b = caption + self.b_suffix
+
+        if item % 2 == 0:
+            artist_a = self.a_artists[torch.multinomial(self.a_artist_scores, 1).item()]
+            caption_a = artist_a.artist_tag + ", " + caption_a
+        else:
+            artist_b = self.b_artists[torch.multinomial(self.b_artist_scores, 1).item()]
+            caption_b = artist_b.artist_tag + ", " + caption_b
+
+        tokens_a = self.tokenizer.chunk_tokens(self.tokenizer([caption_a.replace("),", ") ,")]))
+        tokens_b = self.tokenizer.chunk_tokens(self.tokenizer([caption_b.replace("),", ") ,")]))
+        return latent, tokens_a, tokens_b, post_id
+
+
+class CombinedCLIPTextEncoder(torch.nn.Module):
+    def __init__(self, clip_l, clip_g, batch_size):
+        super().__init__()
+        assert batch_size == 1
+        self.clip_l = clip_l
+        self.clip_g = clip_g
+
+    def forward(self, tokens):
+        tokens_clip_l = tokens["clip_l"].copy()
+        del tokens_clip_l["prompt_starts"]
+
+        tokens_clip_g = tokens["clip_g"].copy()
+        clip_g_starts = tokens_clip_g.pop("prompt_starts")
+
+        clip_l_encoded = self.clip_l(**tokens_clip_l, output_hidden_states=True, return_dict=True)
+        clip_g_encoded = self.clip_g(**tokens_clip_g, output_hidden_states=True, return_dict=True)
+        combined_encoded = torch.cat([clip_l_encoded["hidden_states"][-2], clip_g_encoded["hidden_states"][-2]], dim=-1)
+        combined_encoded_reshape = combined_encoded.reshape(1, -1, 2048)
+
+        return {
+            "conds": combined_encoded_reshape,
+            "pooled": clip_g_encoded.text_embeds[clip_g_starts],
+        }
+
+
+def create_tokenizer(device: torch.device):
+    tokenizer_l = transformers.CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+    tokenizer_g = transformers.CLIPTokenizer.from_pretrained("laion/CLIP-ViT-g-14-laion2B-s34B-b88K")
+    return CombinedCLIPTokenizer(tokenizer_l, tokenizer_g, device)
+
+
+class CombinedCLIPTokenizer(torch.nn.Module):
+    comma_token = 267
+
+    def __init__(self, tokenizer_l, tokenizer_g, output_device: torch.device):
+        super().__init__()
+        self.tokenizer_l = tokenizer_l
+        self.tokenizer_g = tokenizer_g
+        self.output_device = output_device
+
+    def forward(self, prompts: List[str]) -> dict:
+        tokens_l = self.tokenizer_l(prompts, add_special_tokens=False)
+        return {
+            "clip_l": tokens_l,
+            "clip_g": deepcopy(tokens_l),
+        }
+
+    def chunk_tokens(self, tokens: dict):
+        return {
+            "clip_l": self._chunk_tokens_impl(self.tokenizer_l, tokens["clip_l"]),
+            "clip_g": self._chunk_tokens_impl(self.tokenizer_g, tokens["clip_g"]),
+        }
+
+    def _chunk_tokens_impl(self, tokenizer, tokens: dict):
+        input_ids = []
+        attention_masks = []
+        chunk_counts = []
+
+        for prompt, mask in zip(tokens["input_ids"], tokens["attention_mask"]):
+            last_comma = 0
+            current_chunk = []
+            chunks = []
+            chunks_attn = []
+
+            def next_chunk():
+                nonlocal current_chunk
+                current_chunk = [tokenizer.bos_token_id] + current_chunk + [tokenizer.eos_token_id]
+                num_tokens = len(current_chunk)
+
+                current_chunk.extend([tokenizer.pad_token_id] * (77 - num_tokens))
+                chunks.append(current_chunk)
+                current_chunk = []
+                chunks_attn.append([1] * num_tokens + [0] * (77 - num_tokens))
+
+            for token_i, token in enumerate(prompt):
+                is_last_token = token_i == len(prompt) - 1
+                seq_suffix = prompt[last_comma:token_i + int(is_last_token)]
+
+                if token == self.comma_token or is_last_token:
+                    if len(current_chunk) + len(seq_suffix) > 77 - 2:  # leave space for bos and eos
+                        next_chunk()
+                        seq_suffix = prompt[last_comma+1:token_i + int(is_last_token)]  # remove leading comma
+
+                    # can always append, sequences without commas will never be longer than 77 tokens
+                    current_chunk.extend(seq_suffix)
+                    last_comma = token_i
+
+            if current_chunk or not chunks:
+                next_chunk()
+
+            chunk_counts.append(len(chunks))
+            input_ids.extend(chunks)
+            attention_masks.extend(chunks_attn)
+
+        return {
+            "input_ids": torch.tensor(input_ids, device=self.output_device),
+            "attention_mask": torch.tensor(attention_masks, device=self.output_device),
+            "prompt_starts": torch.tensor([0] + chunk_counts[:-1], device=self.output_device).cumsum(dim=0),
+        }
+
+
+def shutdown_machine():
+    """Shutdown the machine. Adjust the command as necessary for your environment."""
+
+    wait_for_lock_removal()
+    print("All checkpoints have been downloaded. Shutting down the machine.")
+    try:
+        subprocess.run("runpodctl stop pod $RUNPOD_POD_ID", shell=True, check=True)
+    except Exception as e:
+        print(f"Error shutting down: {e}")
+
+
+if __name__ == "__main__":
+    accumulate_grads()
+    shutdown_machine()
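
The accumulators that joint_loss.py saves are sums of per-step |grad| values kept in fp16 and scaled by 2**log_scalar: the hook halves the running sum and lowers the exponent whenever an overflow is detected, and the per-key exponents are synced between the two models before saving. A minimal sketch, not part of the commit, of undoing that scaling after download, assuming the downloaded files sit in the working directory and share keys:

    import safetensors.torch

    grads_a = safetensors.torch.load_file("grads_a.safetensors")
    log_scalars = safetensors.torch.load_file("log_scalars.safetensors")

    # Each saved tensor is (sum of |grad|) * 2**log_scalar, so shift the exponent back out.
    unscaled_a = {k: v.float().ldexp(-log_scalars[k]) for k, v in grads_a.items()}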
prepare_dataset.py ADDED
@@ -0,0 +1,320 @@
+import pathlib
+import random
+from copy import deepcopy
+from typing import List
+import diffusers
+import torch
+import safetensors.torch
+import transformers
+from PIL import Image
+from diffusers import AutoencoderKL, StableDiffusionXLPipeline
+import torchvision.transforms as T
+from tqdm import tqdm
+from concurrent.futures import ThreadPoolExecutor, as_completed, wait, FIRST_COMPLETED
+import threading
+import dataclasses
+
+
+devices = [torch.device("cuda:0"), torch.device("cuda:1"), torch.device("cuda:2")]
+dtypes = [torch.bfloat16, torch.float32, torch.float32]
+
+
+VAE_PATH = "KBlueLeaf/EQ-SDXL-VAE"
+SDXL_PATH = "/home/ljleb/sd/models/Stable-diffusion/noobaiXLNAIXL_epsilonPred11Version.safetensors"
+IMAGES_DIR = pathlib.Path("/mnt/data/shared/danbooru")
+LATENT_DIR = pathlib.Path("/mnt/data/shared/danbooru-latent")
+
+
+@dataclasses.dataclass
+class Worker:
+    device: torch.device
+    dtype: torch.dtype
+    vae_w = None
+    sdxl = None
+    tokenizer = None
+
+    def __post_init__(self):
+        self.vae_w = AutoencoderKL.from_pretrained(VAE_PATH, torch_dtype=self.dtype).to(self.device)
+        self.vae_w.eval()
+
+        self.sdxl = StableDiffusionXLPipeline.from_single_file(SDXL_PATH, torch_dtype=self.dtype).to(self.device)
+        self.sdxl.unet.eval()
+        self.sdxl.vae.eval()
+        self.sdxl.text_encoder.eval()
+        self.sdxl.text_encoder_2.eval()
+
+        self.sdxl.text_encoder_combined = CombinedCLIPTextEncoder(self.sdxl.text_encoder, self.sdxl.text_encoder_2, self.device)
+        self.tokenizer = create_tokenizer(self.device)
+
+        self.scheduler = create_scheduler()
+
+
+def main():
+    images = list(IMAGES_DIR.iterdir())
+    LATENT_DIR.mkdir(exist_ok=True)
+    workers = [
+        Worker(device, dtype)
+        for device, dtype in zip(devices, dtypes)
+    ]
+    with ThreadPoolExecutor(max_workers=len(workers)) as executor:
+        futures = {}
+        for image in tqdm(images):
+            if len(futures) >= len(workers):
+                completed_futures, _ = wait(list(futures.values()), return_when=FIRST_COMPLETED)
+                for future in completed_futures:
+                    if future.exception() is not None:
+                        for future_to_cancel in futures.values():
+                            future_to_cancel.cancel()
+                        raise future.exception()
+                    else:
+                        future.result()
+                futures = {
+                    k: v for k, v in futures.items()
+                    if v not in completed_futures
+                }
+
+            for worker in workers:
+                if worker.device not in futures:
+                    futures[worker.device] = executor.submit(prepare_image, worker, image)
+                    break
+
+        for future in futures.values():
+            if future.exception() is not None:
+                for future_to_cancel in futures.values():
+                    future_to_cancel.cancel()
+                raise future.exception()
+            else:
+                future.result()
+
+
+@torch.no_grad()
+def prepare_image(worker: Worker, img_path: pathlib.Path):
+    # We'll define a transform to convert an image to a tensor
+    to_tensor = T.Compose([
+        T.ToTensor(),
+        T.Lambda(lambda t: t*2 - 1)
+    ])
+
+    # w_0_offset = torch.tensor([-3.8846, -1.3187, 0.8009, 0.9180], device=device, dtype=dtype)
+    # w_0_scale = torch.tensor([10.0298, 6.8674, 7.2104, 5.5948], device=device, dtype=dtype)
+
+    # Iterate over images in directory
+    if not img_path.is_file():
+        return
+    if img_path.suffix.lower() not in [".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff"]:
+        return
+
+    # Attempt to open image
+    try:
+        img = Image.open(img_path).convert("RGB")
+    except Exception as e:
+        print(f"Error loading image {img_path.name}: {e}")
+        return
+
+    # Read the caption from the matching .txt file (if it exists)
+    txt_path = img_path.with_suffix(img_path.suffix + ".txt")
+    if not txt_path.is_file():
+        print(f"No caption file for {img_path.name}, skipping.")
+        return
+
+    caption = txt_path.read_text(encoding="utf-8").strip()
+    if not caption:
+        print(f"Empty caption for {img_path.name}, skipping.")
+        return
+
+    out_path = LATENT_DIR / (img_path.stem + ".safetensors")
+    if out_path.exists():
+        return
+
+    caption = caption.replace("\n", " , ").replace("_", " ")
+
+    width, height = img.size
+    orig_pixels = width * height
+    target_pixels = 1024 * 1024
+    if orig_pixels > target_pixels:
+        scale = (target_pixels / float(orig_pixels)) ** 0.5
+        width = int(round(width * scale))
+        height = int(round(height * scale))
+        img = img.resize((width, height), Image.Resampling.LANCZOS)
+
+    tokens_raw = worker.tokenizer([caption])
+    tokens = worker.tokenizer.chunk_tokens(tokens_raw)
+
+    # Convert image to tensor on device
+    img_tensor = to_tensor(img).unsqueeze(0).to(device=worker.device, dtype=worker.dtype)
+
+    # Encode the image with each VAE
+    with torch.no_grad():
+        latents_w_unnorm = worker.vae_w.encode(img_tensor).latent_dist.sample()
+        latents_z = worker.sdxl.vae.encode(img_tensor).latent_dist.sample() * 0.13025
+
+    # Sample noise and a random timestep
+    noise, noisy_latents_z, timesteps = get_noise_noisy_latents_and_timesteps(worker.scheduler, latents_z)
+    time_ids = get_add_time_ids(height, width, worker.device)
+    embeds = worker.sdxl.text_encoder_combined(tokens)
+
+    epsilon_pred = get_pred(worker.sdxl, noisy_latents_z, embeds, timesteps, time_ids)
+
+    encoded = {
+        "timesteps": timesteps,
+        "hw": torch.tensor([[height, width]], dtype=torch.long),
+        "w_0_unnorm": latents_w_unnorm,
+        "z_0": latents_z,
+        "epsilon_pred": epsilon_pred,
+        "epsilon": noise,
+        "conds": embeds["conds"],
+        "pooled": embeds["pooled"],
+    }
+
+    safetensors.torch.save_file(encoded, str(out_path))
+
+
+def get_add_time_ids(width, height, device):
+    original_size = torch.tensor([[width, height]], device=device)
+    add_time_ids = torch.cat([
+        original_size,
+        torch.tensor([[0]*2], device=device).expand(len(original_size), -1),
+        original_size,
+    ], dim=1)
+    return add_time_ids
+
+
+def get_pred(sdxl, noisy_latents, embeds, timesteps, time_ids):
+    return sdxl.unet(
+        noisy_latents,
+        timesteps,
+        encoder_hidden_states=embeds["conds"],
+        added_cond_kwargs={
+            "text_embeds": embeds["pooled"],
+            "time_ids": time_ids,
+        },
+    ).sample
+
+
+def get_noise_noisy_latents_and_timesteps(scheduler, latents):
+    noise = torch.randn_like(latents, device=latents.device)
+    batch_size = latents.shape[0]
+    timesteps = torch.randint(0, 999, (batch_size,), device=latents.device)
+    noisy_latents = scheduler.add_noise(latents, noise, timesteps)
+    return noise, noisy_latents, timesteps
+
+
+def create_scheduler():
+    scheduler = diffusers.DDPMScheduler(
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        num_train_timesteps=1000,
+        clip_sample=False,
+    )
+
+    return scheduler
+
+
+class CombinedCLIPTextEncoder(torch.nn.Module):
+    def __init__(self, clip_l, clip_g, device):
+        super().__init__()
+        self.clip_l = clip_l.to(device=device)
+        self.clip_g = clip_g.to(device=device)
+        self.device = device
+
+    def forward(self, tokens_batch):
+        res = {
+            "conds": torch.tensor([], device=self.device).view(0, 1, 1),
+            "pooled": torch.tensor([], device=self.device).view(0, 1, 1),
+        }
+        tokens_clip_l = tokens_batch["clip_l"].copy()
+        del tokens_clip_l["prompt_starts"]
+
+        tokens_clip_g = tokens_batch["clip_g"].copy()
+        clip_g_starts = tokens_clip_g.pop("prompt_starts")
+
+        clip_l_encoded = self.clip_l(**tokens_clip_l, output_hidden_states=True, return_dict=True)
+        clip_g_encoded = self.clip_g(**tokens_clip_g, output_hidden_states=True, return_dict=True)
+        combined_encoded = torch.cat([clip_l_encoded["hidden_states"][-2], clip_g_encoded["hidden_states"][-2]], dim=-1)
+        combined_encoded_reshape = combined_encoded.reshape(1, -1, 2048)
+
+        res["conds"] = combined_encoded_reshape
+        res["pooled"] = clip_g_encoded.text_embeds[clip_g_starts]
+
+        return res
+
+
+def create_tokenizer(device: torch.device):
+    tokenizer_l = transformers.CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+    tokenizer_g = transformers.CLIPTokenizer.from_pretrained("laion/CLIP-ViT-g-14-laion2B-s34B-b88K")
+    return CombinedCLIPTokenizer(tokenizer_l, tokenizer_g, device)
+
+
+class CombinedCLIPTokenizer(torch.nn.Module):
+    comma_token = 267
+
+    def __init__(self, tokenizer_l, tokenizer_g, output_device: torch.device):
+        super().__init__()
+        self.tokenizer_l = tokenizer_l
+        self.tokenizer_g = tokenizer_g
+        self.output_device = output_device
+
+    def forward(self, prompts: List[str]) -> dict:
+        tokens_l = self.tokenizer_l(prompts, add_special_tokens=False)
+        return {
+            "clip_l": tokens_l,
+            "clip_g": deepcopy(tokens_l),
+        }
+
+    def chunk_tokens(self, tokens: dict):
+        return {
+            "clip_l": self._chunk_tokens_impl(self.tokenizer_l, tokens["clip_l"]),
+            "clip_g": self._chunk_tokens_impl(self.tokenizer_g, tokens["clip_g"]),
+        }
+
+    def _chunk_tokens_impl(self, tokenizer, tokens: dict):
+        input_ids = []
+        attention_masks = []
+        chunk_counts = []
+
+        for prompt, mask in zip(tokens["input_ids"], tokens["attention_mask"]):
+            last_comma = 0
+            current_chunk = []
+            chunks = []
+            chunks_attn = []
+
+            def next_chunk():
+                nonlocal current_chunk
+                current_chunk = [tokenizer.bos_token_id] + current_chunk + [tokenizer.eos_token_id]
+                num_tokens = len(current_chunk)
+
+                current_chunk.extend([tokenizer.pad_token_id] * (77 - num_tokens))
+                chunks.append(current_chunk)
+                current_chunk = []
+                chunks_attn.append([1] * num_tokens + [0] * (77 - num_tokens))
+
+            for token_i, token in enumerate(prompt):
+                is_last_token = token_i == len(prompt) - 1
+                seq_suffix = prompt[last_comma:token_i + int(is_last_token)]
+
+                if token == self.comma_token or is_last_token:
+                    if len(current_chunk) + len(seq_suffix) > 77 - 2:  # leave space for bos and eos
+                        next_chunk()
+                        seq_suffix = prompt[last_comma+1:token_i + int(is_last_token)]  # remove leading comma
+
+                    # can always append, sequences without commas will never be longer than 77 tokens
+                    current_chunk.extend(seq_suffix)
+                    last_comma = token_i
+
+            if current_chunk or not chunks:
+                next_chunk()
+
+            chunk_counts.append(len(chunks))
+            input_ids.extend(chunks)
+            attention_masks.extend(chunks_attn)
+
+        return {
+            "input_ids": torch.tensor(input_ids, device=self.output_device),
+            "attention_mask": torch.tensor(attention_masks, device=self.output_device),
+            "prompt_starts": torch.tensor([0] + chunk_counts[:-1], device=self.output_device).cumsum(dim=0),
+        }
+
+
+if __name__ == "__main__":
+    main()
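
Each output file written by prepare_dataset.py bundles the sampled timestep, image size, both VAE latents, the noise, the epsilon prediction from the reference SDXL model, and the text-encoder outputs under fixed keys. A minimal sketch, not part of the commit, for inspecting one of those files (the path is a placeholder):

    import safetensors.torch

    sample = safetensors.torch.load_file("12345.safetensors")  # placeholder post id
    for key in ("timesteps", "hw", "w_0_unnorm", "z_0", "epsilon_pred", "epsilon", "conds", "pooled"):
        print(key, tuple(sample[key].shape), sample[key].dtype)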