recoilme committed (verified)
Commit c16e8a4 · 1 Parent(s): f2ef28e

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ 123456789.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,14 @@
+ # Jupyter Notebook
+ __pycache__/
+ *.pyc
+ .ipynb_checkpoints/
+ *.ipynb_checkpoints/*
+ .ipynb_checkpoints/*
+ src/samples
+ # cache
+ cache
+ datasets
+ test
+ wandb
+ nohup.out
+
123456789.jpg ADDED

Git LFS Details

  • SHA256: 131522c2f1db361170fb7f8819138893ccec8c1be544509b03aee277c3118e31
  • Pointer size: 131 Bytes
  • Size of remote file: 215 kB
down.sh ADDED
@@ -0,0 +1,25 @@
+ #!/bin/bash
+
+ TARGET_DIR="/workspace/d23"
+ mkdir -p "$TARGET_DIR"
+
+ BASE_URL="https://huggingface.co/datasets/AI-Art-Collab/dtasettar23/resolve/main/d23.tar."
+
+ (
+ # Enable `set -e` inside the subshell so it exits on the first curl error
+ set -e
+ # Try 'a' through 'z' for the first character of the suffix
+ for c1 in {a..z}; do
+ # Try 'a' through 'z' for the second character of the suffix
+ for c2 in {a..z}; do
+ suffix="${c1}${c2}"
+ url="${BASE_URL}${suffix}"
+ echo "Fetching: $url" >&2
+ # Download one archive part. --fail makes curl exit with an error if the file does not exist.
+ curl -LsS --fail "$url"
+ done
+ done
+ ) 2>/dev/null | tar -xv -C "$TARGET_DIR" --wildcards '*.png'
+ # 1: the subshell's concatenated output  2: suppress error output  3: stream-extract only *.png files
+
+ echo "Extraction of PNG files finished. Check $TARGET_DIR"
samples/sample_0.jpg ADDED
samples/sample_1.jpg ADDED
samples/sample_2.jpg ADDED
samples/sample_decoded.jpg ADDED
samples/sample_real.jpg ADDED
test.ipynb ADDED
@@ -0,0 +1,192 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 7,
6
+ "id": "6ca10d55-03ed-4c8b-b32b-8d2f94d77162",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "The config attributes {'block_out_channels': [128, 256, 512, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n"
14
+ ]
15
+ },
16
+ {
17
+ "name": "stdout",
18
+ "output_type": "stream",
19
+ "text": [
20
+ "test log-variance: 0.065\n",
21
+ "Готово\n"
22
+ ]
23
+ }
24
+ ],
25
+ "source": [
26
+ "import torch\n",
27
+ "from PIL import Image\n",
28
+ "from diffusers import AutoencoderKL,AsymmetricAutoencoderKL\n",
29
+ "from torchvision.transforms.functional import to_pil_image\n",
30
+ "import matplotlib.pyplot as plt\n",
31
+ "import os\n",
32
+ "from torchvision.transforms import ToTensor, Normalize, CenterCrop\n",
33
+ "\n",
34
+ "# путь к вашей картинке\n",
35
+ "IMG_PATH = \"123456789.jpg\"\n",
36
+ "OUT_DIR = \"test\"\n",
37
+ "device = \"cuda\"\n",
38
+ "dtype = torch.float16 \n",
39
+ "os.makedirs(OUT_DIR, exist_ok=True)\n",
40
+ "\n",
41
+ "# список VAE\n",
42
+ "VAES = {\n",
43
+ " \"test\": \"/workspace/simple_vae2x\",\n",
44
+ "}\n",
45
+ "\n",
46
+ "def load_image(path):\n",
47
+ " img = Image.open(path).convert('RGB')\n",
48
+ " # обрезаем до кратности 8\n",
49
+ " w, h = img.size\n",
50
+ " img = CenterCrop((h // 8 * 8, w // 8 * 8))(img)\n",
51
+ " tensor = ToTensor()(img).unsqueeze(0) # [0,1]\n",
52
+ " tensor = Normalize(mean=[0.5]*3, std=[0.5]*3)(tensor) # [-1,1]\n",
53
+ " return img, tensor.to(device, dtype=dtype)\n",
54
+ "\n",
55
+ "# обратно в PIL\n",
56
+ "def tensor_to_img(t):\n",
57
+ " t = (t * 0.5 + 0.5).clamp(0, 1)\n",
58
+ " return to_pil_image(t[0])\n",
59
+ "\n",
60
+ "def logvariance(latents):\n",
61
+ " \"\"\"Возвращает лог-дисперсию по всем элементам.\"\"\"\n",
62
+ " return torch.log(latents.var() + 1e-8).item()\n",
63
+ "\n",
64
+ "def plot_latent_distribution(latents, title, save_path):\n",
65
+ " \"\"\"Гистограмма + QQ-plot.\"\"\"\n",
66
+ " lat = latents.detach().cpu().numpy().flatten()\n",
67
+ " plt.figure(figsize=(10, 4))\n",
68
+ "\n",
69
+ " # гистограмма\n",
70
+ " plt.subplot(1, 2, 1)\n",
71
+ " plt.hist(lat, bins=100, density=True, alpha=0.7, color='steelblue')\n",
72
+ " plt.title(f\"{title} histogram\")\n",
73
+ " plt.xlabel(\"latent value\")\n",
74
+ " plt.ylabel(\"density\")\n",
75
+ "\n",
76
+ " # QQ-plot\n",
77
+ " from scipy.stats import probplot\n",
78
+ " plt.subplot(1, 2, 2)\n",
79
+ " probplot(lat, dist=\"norm\", plot=plt)\n",
80
+ " plt.title(f\"{title} QQ-plot\")\n",
81
+ "\n",
82
+ " plt.tight_layout()\n",
83
+ " plt.savefig(save_path)\n",
84
+ " plt.close()\n",
85
+ "\n",
86
+ "for name, repo in VAES.items():\n",
87
+ " if name==\"test\":\n",
88
+ " vae = AsymmetricAutoencoderKL.from_pretrained(repo, subfolder=\"vae\", torch_dtype=dtype).to(device)\n",
89
+ " else:\n",
90
+ " vae = AutoencoderKL.from_pretrained(repo, torch_dtype=dtype).to(device)#, subfolder=\"vae\", variant=\"fp16\"\n",
91
+ "\n",
92
+ " cfg = vae.config\n",
93
+ " scale = getattr(cfg, \"scaling_factor\", 1.)\n",
94
+ " shift = getattr(cfg, \"shift_factor\", 0.0)\n",
95
+ " mean = getattr(cfg, \"latents_mean\", None)\n",
96
+ " std = getattr(cfg, \"latents_std\", None)\n",
97
+ "\n",
98
+ " C = 16 # 4 для SDXL\n",
99
+ " if mean is not None:\n",
100
+ " mean = torch.tensor(mean, device=device, dtype=dtype).view(1, C, 1, 1)\n",
101
+ " if std is not None:\n",
102
+ " std = torch.tensor(std, device=device, dtype=dtype).view(1, C, 1, 1)\n",
103
+ " if shift is not None:\n",
104
+ " shift = torch.tensor(shift, device=device, dtype=dtype)\n",
105
+ " else:\n",
106
+ " shift = 0.0 \n",
107
+ "\n",
108
+ " scale = torch.tensor(scale, device=device, dtype=dtype)\n",
109
+ "\n",
110
+ " img, x = load_image(IMG_PATH)\n",
111
+ " img.save(os.path.join(OUT_DIR, f\"original.jpg\"))\n",
112
+ "\n",
113
+ " with torch.no_grad():\n",
114
+ " # encode\n",
115
+ " latents = vae.encode(x).latent_dist.sample().to(dtype)\n",
116
+ " if mean is not None and std is not None:\n",
117
+ " latents = (latents - mean) / std\n",
118
+ " latents = latents * scale + shift\n",
119
+ "\n",
120
+ " lv = logvariance(latents)\n",
121
+ " print(f\"{name} log-variance: {lv:.3f}\")\n",
122
+ "\n",
123
+ " # график\n",
124
+ " plot_latent_distribution(latents, f\"{name}_latents\",\n",
125
+ " os.path.join(OUT_DIR, f\"dist_{name}.png\"))\n",
126
+ "\n",
127
+ " # decode\n",
128
+ " latents = (latents - shift) / scale\n",
129
+ " if mean is not None and std is not None:\n",
130
+ " latents = latents * std + mean\n",
131
+ " rec = vae.decode(latents).sample\n",
132
+ "\n",
133
+ " tensor_to_img(rec).save(os.path.join(OUT_DIR, f\"decoded_{name}.png\"))\n",
134
+ "\n",
135
+ "print(\"Готово\")\n"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": 5,
141
+ "id": "5e930fd3-0aa5-4ed6-beab-e871df009125",
142
+ "metadata": {},
143
+ "outputs": [
144
+ {
145
+ "name": "stdout",
146
+ "output_type": "stream",
147
+ "text": [
148
+ "Collecting scipy\n",
149
+ " Downloading scipy-1.16.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (62 kB)\n",
150
+ "Requirement already satisfied: numpy<2.6,>=1.25.2 in /usr/local/lib/python3.12/dist-packages (from scipy) (2.1.2)\n",
151
+ "Downloading scipy-1.16.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (35.7 MB)\n",
152
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.7/35.7 MB\u001b[0m \u001b[31m58.9 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0mm0:00:01\u001b[0m00:01\u001b[0m\n",
153
+ "\u001b[?25hInstalling collected packages: scipy\n",
154
+ "Successfully installed scipy-1.16.2\n"
155
+ ]
156
+ }
157
+ ],
158
+ "source": [
159
+ "!pip install scipy"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "id": "72785e98-5dad-48a3-809b-3ab9755ac9db",
166
+ "metadata": {},
167
+ "outputs": [],
168
+ "source": []
169
+ }
170
+ ],
171
+ "metadata": {
172
+ "kernelspec": {
173
+ "display_name": "Python 3 (ipykernel)",
174
+ "language": "python",
175
+ "name": "python3"
176
+ },
177
+ "language_info": {
178
+ "codemirror_mode": {
179
+ "name": "ipython",
180
+ "version": 3
181
+ },
182
+ "file_extension": ".py",
183
+ "mimetype": "text/x-python",
184
+ "name": "python",
185
+ "nbconvert_exporter": "python",
186
+ "pygments_lexer": "ipython3",
187
+ "version": "3.12.3"
188
+ }
189
+ },
190
+ "nbformat": 4,
191
+ "nbformat_minor": 5
192
+ }
train_vae.py ADDED
@@ -0,0 +1,569 @@
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import math
4
+ import re
5
+ import torch
6
+ import numpy as np
7
+ import random
8
+ import gc
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+
12
+ import torchvision.transforms as transforms
13
+ import torch.nn.functional as F
14
+ from torch.utils.data import DataLoader, Dataset
15
+ from torch.optim.lr_scheduler import LambdaLR
16
+ from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
17
+ # QWEN: import the class
18
+ from diffusers import AutoencoderKLQwenImage
19
+ from diffusers import AutoencoderKLWan
20
+
21
+ from accelerate import Accelerator
22
+ from PIL import Image, UnidentifiedImageError
23
+ from tqdm import tqdm
24
+ import bitsandbytes as bnb
25
+ import wandb
26
+ import lpips # pip install lpips
27
+ from collections import deque
28
+
29
+ # --------------------------- Parameters ---------------------------
30
+ ds_path = "/workspace/d23"
31
+ project = "vae"
32
+ batch_size = 4
33
+ base_learning_rate = 6e-6
34
+ min_learning_rate = 9e-7
35
+ num_epochs = 50
36
+ sample_interval_share = 10
37
+ use_wandb = True
38
+ save_model = True
39
+ use_decay = True
40
+ optimizer_type = "adam8bit"
41
+ dtype = torch.float32
42
+
43
+ model_resolution = 256
44
+ high_resolution = 512
45
+ limit = 0
46
+ save_barrier = 1.03
47
+ warmup_percent = 0.01
48
+ percentile_clipping = 99
49
+ beta2 = 0.997
50
+ eps = 1e-8
51
+ clip_grad_norm = 1.0
52
+ mixed_precision = "no"
53
+ gradient_accumulation_steps = 4
54
+ generated_folder = "samples"
55
+ save_as = "vae"
56
+ num_workers = 0
57
+ device = None
58
+
59
+ # --- Training modes ---
60
+ # QWEN: train only the decoder
61
+ train_decoder_only = True
62
+ full_training = False # if True, train the whole VAE and add the KL term (below)
63
+ kl_ratio = 0.00
64
+
65
+ # Loss shares
66
+ loss_ratios = {
67
+ "lpips": 0.75,
68
+ "edge": 0.05,
69
+ "mse": 0.10,
70
+ "mae": 0.10,
71
+ "kl": 0.00, # активируем при full_training=True
72
+ }
73
+ median_coeff_steps = 256
74
+
75
+ resize_long_side = 1280 # resize the long side of the source images
76
+
77
+ # QWEN: model loading config
78
+ vae_kind = "kl" # "qwen" или "kl" (обычный)
79
+
80
+ Path(generated_folder).mkdir(parents=True, exist_ok=True)
81
+
82
+ accelerator = Accelerator(
83
+ mixed_precision=mixed_precision,
84
+ gradient_accumulation_steps=gradient_accumulation_steps
85
+ )
86
+ device = accelerator.device
87
+
88
+ # reproducibility
89
+ seed = int(datetime.now().strftime("%Y%m%d"))
90
+ torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
91
+ torch.backends.cudnn.benchmark = False
92
+
93
+ # --------------------------- WandB ---------------------------
94
+ if use_wandb and accelerator.is_main_process:
95
+ wandb.init(project=project, config={
96
+ "batch_size": batch_size,
97
+ "base_learning_rate": base_learning_rate,
98
+ "num_epochs": num_epochs,
99
+ "optimizer_type": optimizer_type,
100
+ "model_resolution": model_resolution,
101
+ "high_resolution": high_resolution,
102
+ "gradient_accumulation_steps": gradient_accumulation_steps,
103
+ "train_decoder_only": train_decoder_only,
104
+ "full_training": full_training,
105
+ "kl_ratio": kl_ratio,
106
+ "vae_kind": vae_kind,
107
+ })
108
+
109
+ # --------------------------- VAE ---------------------------
110
+ def get_core_model(model):
111
+ m = model
112
+ # if the model is already wrapped by torch.compile
113
+ if hasattr(m, "_orig_mod"):
114
+ m = m._orig_mod
115
+ return m
116
+
117
+ def is_video_vae(model) -> bool:
118
+ # WAN/Qwen are video VAEs
119
+ if vae_kind in ("wan", "qwen"):
120
+ return True
121
+ # structural fallback (if ever needed)
122
+ try:
123
+ core = get_core_model(model)
124
+ enc = getattr(core, "encoder", None)
125
+ conv_in = getattr(enc, "conv_in", None)
126
+ w = getattr(conv_in, "weight", None)
127
+ if isinstance(w, torch.nn.Parameter):
128
+ return w.ndim == 5
129
+ except Exception:
130
+ pass
131
+ return False
132
+
133
+ # load the model
134
+ if vae_kind == "qwen":
135
+ vae = AutoencoderKLQwenImage.from_pretrained("Qwen/Qwen-Image", subfolder="vae")
136
+ else:
137
+ if vae_kind == "wan":
138
+ vae = AutoencoderKLWan.from_pretrained(project)
139
+ else:
140
+ # legacy behavior (example)
141
+ if model_resolution==high_resolution:
142
+ vae = AutoencoderKL.from_pretrained(project)
143
+ else:
144
+ vae = AsymmetricAutoencoderKL.from_pretrained(project)
145
+
146
+ vae = vae.to(dtype)
147
+
148
+ # torch.compile (optional)
149
+ if hasattr(torch, "compile"):
150
+ try:
151
+ vae = torch.compile(vae)
152
+ except Exception as e:
153
+ print(f"[WARN] torch.compile failed: {e}")
154
+
155
+ # --------------------------- Freeze/Unfreeze ---------------------------
156
+ core = get_core_model(vae)
157
+
158
+ for p in core.parameters():
159
+ p.requires_grad = False
160
+
161
+ unfrozen_param_names = []
162
+
163
+ if full_training and not train_decoder_only:
164
+ for name, p in core.named_parameters():
165
+ p.requires_grad = True
166
+ unfrozen_param_names.append(name)
167
+ loss_ratios["kl"] = float(kl_ratio)
168
+ trainable_module = core
169
+ else:
170
+ # train only the decoder + post_quant_conv on the model "core"
171
+ if hasattr(core, "decoder"):
172
+ for name, p in core.decoder.named_parameters():
173
+ p.requires_grad = True
174
+ unfrozen_param_names.append(f"decoder.{name}")
175
+ if hasattr(core, "post_quant_conv"):
176
+ for name, p in core.post_quant_conv.named_parameters():
177
+ p.requires_grad = True
178
+ unfrozen_param_names.append(f"post_quant_conv.{name}")
179
+ trainable_module = core.decoder if hasattr(core, "decoder") else core
180
+
181
+ print(f"[INFO] Разморожено параметров: {len(unfrozen_param_names)}. Первые 200 имён:")
182
+ for nm in unfrozen_param_names[:200]:
183
+ print(" ", nm)
184
+
185
+ # --------------------------- Dataset ---------------------------
186
+ class PngFolderDataset(Dataset):
187
+ def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
188
+ self.root_dir = root_dir
189
+ self.resolution = resolution
190
+ self.paths = []
191
+ for root, _, files in os.walk(root_dir):
192
+ for fname in files:
193
+ if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
194
+ self.paths.append(os.path.join(root, fname))
195
+ if limit:
196
+ self.paths = self.paths[:limit]
197
+ valid = []
198
+ for p in self.paths:
199
+ try:
200
+ with Image.open(p) as im:
201
+ im.verify()
202
+ valid.append(p)
203
+ except (OSError, UnidentifiedImageError):
204
+ continue
205
+ self.paths = valid
206
+ if len(self.paths) == 0:
207
+ raise RuntimeError(f"No valid PNG images found under {root_dir}")
208
+ random.shuffle(self.paths)
209
+
210
+ def __len__(self):
211
+ return len(self.paths)
212
+
213
+ def __getitem__(self, idx):
214
+ p = self.paths[idx % len(self.paths)]
215
+ with Image.open(p) as img:
216
+ img = img.convert("RGB")
217
+ if not resize_long_side or resize_long_side <= 0:
218
+ return img
219
+ w, h = img.size
220
+ long = max(w, h)
221
+ if long <= resize_long_side:
222
+ return img
223
+ scale = resize_long_side / float(long)
224
+ new_w = int(round(w * scale))
225
+ new_h = int(round(h * scale))
226
+ return img.resize((new_w, new_h), Image.LANCZOS)
227
+
228
+ def random_crop(img, sz):
229
+ w, h = img.size
230
+ if w < sz or h < sz:
231
+ img = img.resize((max(sz, w), max(sz, h)), Image.LANCZOS)
232
+ x = random.randint(0, max(1, img.width - sz))
233
+ y = random.randint(0, max(1, img.height - sz))
234
+ return img.crop((x, y, x + sz, y + sz))
235
+
236
+ tfm = transforms.Compose([
237
+ transforms.ToTensor(),
238
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
239
+ ])
240
+
241
+ dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
242
+ if len(dataset) < batch_size:
243
+ raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
244
+
245
+ def collate_fn(batch):
246
+ imgs = []
247
+ for img in batch:
248
+ img = random_crop(img, high_resolution)
249
+ imgs.append(tfm(img))
250
+ return torch.stack(imgs)
251
+
252
+ dataloader = DataLoader(
253
+ dataset,
254
+ batch_size=batch_size,
255
+ shuffle=True,
256
+ collate_fn=collate_fn,
257
+ num_workers=num_workers,
258
+ pin_memory=True,
259
+ drop_last=True
260
+ )
261
+
262
+ # --------------------------- Optimizer ---------------------------
263
+ def get_param_groups(module, weight_decay=0.001):
264
+ no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
265
+ decay_params, no_decay_params = [], []
266
+ for n, p in vae.named_parameters(): # iterates the global vae, filtering by requires_grad (this definition is shadowed by the one below)
267
+ if not p.requires_grad:
268
+ continue
269
+ if any(nd in n for nd in no_decay):
270
+ no_decay_params.append(p)
271
+ else:
272
+ decay_params.append(p)
273
+ return [
274
+ {"params": decay_params, "weight_decay": weight_decay},
275
+ {"params": no_decay_params, "weight_decay": 0.0},
276
+ ]
277
+
278
+ def get_param_groups(module, weight_decay=0.001):
279
+ no_decay_tokens = ("bias", "norm", "rms", "layernorm")
280
+ decay_params, no_decay_params = [], []
281
+ for n, p in module.named_parameters():
282
+ if not p.requires_grad:
283
+ continue
284
+ n_l = n.lower()
285
+ if any(t in n_l for t in no_decay_tokens):
286
+ no_decay_params.append(p)
287
+ else:
288
+ decay_params.append(p)
289
+ return [
290
+ {"params": decay_params, "weight_decay": weight_decay},
291
+ {"params": no_decay_params, "weight_decay": 0.0},
292
+ ]
293
+
294
+ def create_optimizer(name, param_groups):
295
+ if name == "adam8bit":
296
+ return bnb.optim.AdamW8bit(param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps)
297
+ raise ValueError(name)
298
+
299
+ param_groups = get_param_groups(get_core_model(vae), weight_decay=0.001)
300
+ optimizer = create_optimizer(optimizer_type, param_groups)
301
+
302
+ # --------------------------- LR schedule ---------------------------
303
+ batches_per_epoch = len(dataloader)
304
+ steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))
305
+ total_steps = steps_per_epoch * num_epochs
306
+
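+ # LR schedule: linear warmup from min_learning_rate up to the base LR over the first warmup_percent of total steps, then cosine decay back down to min_learning_rate.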
307
+ def lr_lambda(step):
308
+ if not use_decay:
309
+ return 1.0
310
+ x = float(step) / float(max(1, total_steps))
311
+ warmup = float(warmup_percent)
312
+ min_ratio = float(min_learning_rate) / float(base_learning_rate)
313
+ if x < warmup:
314
+ return min_ratio + (1.0 - min_ratio) * (x / warmup)
315
+ decay_ratio = (x - warmup) / (1.0 - warmup)
316
+ return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
317
+
318
+ scheduler = LambdaLR(optimizer, lr_lambda)
319
+
320
+ # Preparation
321
+ dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
322
+ trainable_params = [p for p in vae.parameters() if p.requires_grad]
323
+
324
+ # --------------------------- LPIPS and helpers ---------------------------
325
+ _lpips_net = None
326
+ def _get_lpips():
327
+ global _lpips_net
328
+ if _lpips_net is None:
329
+ _lpips_net = lpips.LPIPS(net='vgg', verbose=False).to(accelerator.device).eval()
330
+ return _lpips_net
331
+
332
+ _sobel_kx = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]], dtype=torch.float32)
333
+ _sobel_ky = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]], dtype=torch.float32)
334
+ def sobel_edges(x: torch.Tensor) -> torch.Tensor:
335
+ C = x.shape[1]
336
+ kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
337
+ ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
338
+ gx = F.conv2d(x, kx, padding=1, groups=C)
339
+ gy = F.conv2d(x, ky, padding=1, groups=C)
340
+ return torch.sqrt(gx * gx + gy * gy + 1e-12)
341
+
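+ # MedianLossNormalizer keeps a sliding window of each loss term's recent absolute values and scales every term by ratio / median, so the weighted contributions approximately follow loss_ratios.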
342
+ class MedianLossNormalizer:
343
+ def __init__(self, desired_ratios: dict, window_steps: int):
344
+ s = sum(desired_ratios.values())
345
+ self.ratios = {k: (v / s) if s > 0 else 0.0 for k, v in desired_ratios.items()}
346
+ self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
347
+ self.window = window_steps
348
+
349
+ def update_and_total(self, abs_losses: dict):
350
+ for k, v in abs_losses.items():
351
+ if k in self.buffers:
352
+ self.buffers[k].append(float(v.detach().abs().cpu()))
353
+ meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
354
+ coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
355
+ total = sum(coeffs[k] * abs_losses[k] for k in abs_losses if k in coeffs)
356
+ return total, coeffs, meds
357
+
358
+ if full_training and not train_decoder_only:
359
+ loss_ratios["kl"] = float(kl_ratio)
360
+ normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
361
+
362
+ # --------------------------- Samples ---------------------------
363
+ @torch.no_grad()
364
+ def get_fixed_samples(n=3):
365
+ idx = random.sample(range(len(dataset)), min(n, len(dataset)))
366
+ pil_imgs = [dataset[i] for i in idx]
367
+ tensors = []
368
+ for img in pil_imgs:
369
+ img = random_crop(img, high_resolution)
370
+ tensors.append(tfm(img))
371
+ return torch.stack(tensors).to(accelerator.device, dtype)
372
+
373
+ fixed_samples = get_fixed_samples()
374
+
375
+ @torch.no_grad()
376
+ def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
377
+ arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
378
+ return Image.fromarray(arr)
379
+
380
+ @torch.no_grad()
381
+ def generate_and_save_samples(step=None):
382
+ try:
383
+ temp_vae = accelerator.unwrap_model(vae).eval()
384
+ lpips_net = _get_lpips()
385
+ with torch.no_grad():
386
+ orig_high = fixed_samples
387
+ orig_low = F.interpolate(orig_high, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
388
+ model_dtype = next(temp_vae.parameters()).dtype
389
+ orig_low = orig_low.to(dtype=model_dtype)
390
+
391
+ # QWEN: add T=1 for encode/decode and squeeze it out before comparison
392
+ if is_video_vae(temp_vae):
393
+ x_in = orig_low.unsqueeze(2) # [B,3,1,H,W]
394
+ enc = temp_vae.encode(x_in)
395
+ latents_mean = enc.latent_dist.mean
396
+ dec = temp_vae.decode(latents_mean).sample # [B,3,1,H,W]
397
+ rec = dec.squeeze(2) # [B,3,H,W]
398
+ else:
399
+ enc = temp_vae.encode(orig_low)
400
+ latents_mean = enc.latent_dist.mean
401
+ rec = temp_vae.decode(latents_mean).sample
402
+
403
+ if rec.shape[-2:] != orig_high.shape[-2:]:
404
+ rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
405
+
406
+ first_real = _to_pil_uint8(orig_high[0])
407
+ first_dec = _to_pil_uint8(rec[0])
408
+ first_real.save(f"{generated_folder}/sample_real.jpg", quality=95)
409
+ first_dec.save(f"{generated_folder}/sample_decoded.jpg", quality=95)
410
+
411
+ for i in range(rec.shape[0]):
412
+ _to_pil_uint8(rec[i]).save(f"{generated_folder}/sample_{i}.jpg", quality=95)
413
+
414
+ lpips_scores = []
415
+ for i in range(rec.shape[0]):
416
+ orig_full = orig_high[i:i+1].to(torch.float32)
417
+ rec_full = rec[i:i+1].to(torch.float32)
418
+ if rec_full.shape[-2:] != orig_full.shape[-2:]:
419
+ rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
420
+ lpips_val = lpips_net(orig_full, rec_full).item()
421
+ lpips_scores.append(lpips_val)
422
+ avg_lpips = float(np.mean(lpips_scores))
423
+
424
+ if use_wandb and accelerator.is_main_process:
425
+ wandb.log({"lpips_mean": avg_lpips}, step=step)
426
+ wandb.log({
427
+ "sample/real": wandb.Image(first_real, caption="real"),
428
+ "sample/decoded": wandb.Image(first_dec, caption="decoded"),
429
+ }, step=step)
430
+ finally:
431
+ gc.collect()
432
+ torch.cuda.empty_cache()
433
+
434
+ if accelerator.is_main_process and save_model:
435
+ print("Генерация сэмплов до старта обучения...")
436
+ generate_and_save_samples(0)
437
+
438
+ accelerator.wait_for_everyone()
439
+
440
+ # --------------------------- Training ---------------------------
441
+ progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
442
+ global_step = 0
443
+ min_loss = float("inf")
444
+ sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
445
+
446
+ for epoch in range(num_epochs):
447
+ vae.train()
448
+ batch_losses, batch_grads = [], []
449
+ track_losses = {k: [] for k in loss_ratios.keys()}
450
+
451
+ for imgs in dataloader:
452
+ with accelerator.accumulate(vae):
453
+ imgs = imgs.to(accelerator.device)
454
+
455
+ if high_resolution != model_resolution:
456
+ imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
457
+ else:
458
+ imgs_low = imgs
459
+
460
+ model_dtype = next(vae.parameters()).dtype
461
+ imgs_low_model = imgs_low.to(dtype=model_dtype) if imgs_low.dtype != model_dtype else imgs_low
462
+
463
+ # QWEN: encode/decode with T=1
464
+ if is_video_vae(vae):
465
+ x_in = imgs_low_model.unsqueeze(2) # [B,3,1,H,W]
466
+ enc = vae.encode(x_in)
467
+ latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
468
+ dec = vae.decode(latents).sample # [B,3,1,H,W]
469
+ rec = dec.squeeze(2) # [B,3,H,W]
470
+ else:
471
+ enc = vae.encode(imgs_low_model)
472
+ latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
473
+ rec = vae.decode(latents).sample
474
+
475
+ if rec.shape[-2:] != imgs.shape[-2:]:
476
+ rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
477
+
478
+ rec_f32 = rec.to(torch.float32)
479
+ imgs_f32 = imgs.to(torch.float32)
480
+
481
+ abs_losses = {
482
+ "mae": F.l1_loss(rec_f32, imgs_f32),
483
+ "mse": F.mse_loss(rec_f32, imgs_f32),
484
+ "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
485
+ "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
486
+ }
487
+
488
+ if full_training and not train_decoder_only:
489
+ mean = enc.latent_dist.mean
490
+ logvar = enc.latent_dist.logvar
491
+ kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
492
+ abs_losses["kl"] = kl
493
+ else:
494
+ abs_losses["kl"] = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)
495
+
496
+ total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)
497
+
498
+ if torch.isnan(total_loss) or torch.isinf(total_loss):
499
+ raise RuntimeError("NaN/Inf loss")
500
+
501
+ accelerator.backward(total_loss)
502
+
503
+ grad_norm = torch.tensor(0.0, device=accelerator.device)
504
+ if accelerator.sync_gradients:
505
+ grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
506
+ optimizer.step()
507
+ scheduler.step()
508
+ optimizer.zero_grad(set_to_none=True)
509
+ global_step += 1
510
+ progress.update(1)
511
+
512
+ if accelerator.is_main_process:
513
+ try:
514
+ current_lr = optimizer.param_groups[0]["lr"]
515
+ except Exception:
516
+ current_lr = scheduler.get_last_lr()[0]
517
+
518
+ batch_losses.append(total_loss.detach().item())
519
+ batch_grads.append(float(grad_norm.detach().cpu().item()) if isinstance(grad_norm, torch.Tensor) else float(grad_norm))
520
+ for k, v in abs_losses.items():
521
+ track_losses[k].append(float(v.detach().item()))
522
+
523
+ if use_wandb and accelerator.sync_gradients:
524
+ log_dict = {
525
+ "total_loss": float(total_loss.detach().item()),
526
+ "learning_rate": current_lr,
527
+ "epoch": epoch,
528
+ "grad_norm": batch_grads[-1],
529
+ }
530
+ for k, v in abs_losses.items():
531
+ log_dict[f"loss_{k}"] = float(v.detach().item())
532
+ for k in coeffs:
533
+ log_dict[f"coeff_{k}"] = float(coeffs[k])
534
+ log_dict[f"median_{k}"] = float(meds[k])
535
+ wandb.log(log_dict, step=global_step)
536
+
537
+ if global_step > 0 and global_step % sample_interval == 0:
538
+ if accelerator.is_main_process:
539
+ generate_and_save_samples(global_step)
540
+ accelerator.wait_for_everyone()
541
+
542
+ n_micro = sample_interval * gradient_accumulation_steps
543
+ avg_loss = float(np.mean(batch_losses[-n_micro:])) if len(batch_losses) >= n_micro else float(np.mean(batch_losses)) if batch_losses else float("nan")
544
+ avg_grad = float(np.mean(batch_grads[-n_micro:])) if len(batch_grads) >= 1 else float(np.mean(batch_grads)) if batch_grads else 0.0
545
+
546
+ if accelerator.is_main_process:
547
+ print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
548
+ if save_model and avg_loss < min_loss * save_barrier:
549
+ min_loss = avg_loss
550
+ accelerator.unwrap_model(vae).save_pretrained(save_as)
551
+ if use_wandb:
552
+ wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)
553
+
554
+ if accelerator.is_main_process:
555
+ epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
556
+ print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
557
+ if use_wandb:
558
+ wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
559
+
560
+ # --------------------------- Final save ---------------------------
561
+ if accelerator.is_main_process:
562
+ print("Training finished – saving final model")
563
+ if save_model:
564
+ accelerator.unwrap_model(vae).save_pretrained(save_as)
565
+
566
+ accelerator.free_memory()
567
+ if torch.distributed.is_initialized():
568
+ torch.distributed.destroy_process_group()
569
+ print("Готово!")
transfer_simplevae.ipynb ADDED
@@ -0,0 +1,240 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "c15deb04-94a0-4073-a174-adcd22af10b8",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "✅ Создана новая модель: <class 'diffusers.models.autoencoders.autoencoder_asym_kl.AsymmetricAutoencoderKL'>\n"
14
+ ]
15
+ },
16
+ {
17
+ "name": "stderr",
18
+ "output_type": "stream",
19
+ "text": [
20
+ "The config attributes {'block_out_channels': [128, 256, 512, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n"
21
+ ]
22
+ },
23
+ {
24
+ "name": "stdout",
25
+ "output_type": "stream",
26
+ "text": [
27
+ "\n",
28
+ "--- Перенос весов ---\n"
29
+ ]
30
+ },
31
+ {
32
+ "name": "stderr",
33
+ "output_type": "stream",
34
+ "text": [
35
+ "100%|██████████| 248/248 [00:00<00:00, 142199.23it/s]\n"
36
+ ]
37
+ },
38
+ {
39
+ "name": "stdout",
40
+ "output_type": "stream",
41
+ "text": [
42
+ "\n",
43
+ "✅ Перенос завершён.\n",
44
+ "Статистика:\n",
45
+ " перенесено: 142\n",
46
+ " дублировано: 26\n",
47
+ " сдвинуто: 106\n",
48
+ " пропущено: 0\n",
49
+ "\n",
50
+ "Неперенесённые ключи (первые 20):\n",
51
+ " decoder.condition_encoder.layers.0.weight\n",
52
+ " decoder.condition_encoder.layers.0.bias\n",
53
+ " decoder.condition_encoder.layers.1.weight\n",
54
+ " decoder.condition_encoder.layers.1.bias\n",
55
+ " decoder.condition_encoder.layers.2.weight\n",
56
+ " decoder.condition_encoder.layers.2.bias\n",
57
+ " decoder.condition_encoder.layers.3.weight\n",
58
+ " decoder.condition_encoder.layers.3.bias\n",
59
+ " decoder.condition_encoder.layers.4.weight\n",
60
+ " decoder.condition_encoder.layers.4.bias\n"
61
+ ]
62
+ }
63
+ ],
64
+ "source": [
65
+ "from diffusers.models import AsymmetricAutoencoderKL, AutoencoderKL\n",
66
+ "import torch\n",
67
+ "from tqdm import tqdm\n",
68
+ "\n",
69
+ "# ---- Конфиг новой модели ----\n",
70
+ "config = {\n",
71
+ " \"_class_name\": \"AsymmetricAutoencoderKL\",\n",
72
+ " \"act_fn\": \"silu\",\n",
73
+ " \"in_channels\": 3,\n",
74
+ " \"out_channels\": 3,\n",
75
+ " \"scaling_factor\": 1.0,\n",
76
+ " \"norm_num_groups\": 32,\n",
77
+ " \"down_block_out_channels\": [128, 256, 512, 512],\n",
78
+ " \"down_block_types\": [\n",
79
+ " \"DownEncoderBlock2D\",\n",
80
+ " \"DownEncoderBlock2D\",\n",
81
+ " \"DownEncoderBlock2D\",\n",
82
+ " \"DownEncoderBlock2D\",\n",
83
+ " ],\n",
84
+ " \"latent_channels\": 16,\n",
85
+ " \"up_block_out_channels\": [128, 256, 512, 512, 512], # +1 блок\n",
86
+ " \"up_block_types\": [\n",
87
+ " \"UpDecoderBlock2D\",\n",
88
+ " \"UpDecoderBlock2D\",\n",
89
+ " \"UpDecoderBlock2D\",\n",
90
+ " \"UpDecoderBlock2D\",\n",
91
+ " \"UpDecoderBlock2D\",\n",
92
+ " ],\n",
93
+ "}\n",
94
+ "\n",
95
+ "# ---- Создание пустой асимметричной модели ----\n",
96
+ "vae = AsymmetricAutoencoderKL(\n",
97
+ " act_fn=config[\"act_fn\"],\n",
98
+ " down_block_out_channels=config[\"down_block_out_channels\"],\n",
99
+ " down_block_types=config[\"down_block_types\"],\n",
100
+ " latent_channels=config[\"latent_channels\"],\n",
101
+ " up_block_out_channels=config[\"up_block_out_channels\"],\n",
102
+ " up_block_types=config[\"up_block_types\"],\n",
103
+ " in_channels=config[\"in_channels\"],\n",
104
+ " out_channels=config[\"out_channels\"],\n",
105
+ " scaling_factor=config[\"scaling_factor\"],\n",
106
+ " norm_num_groups=config[\"norm_num_groups\"],\n",
107
+ " layers_per_down_block=2,\n",
108
+ " layers_per_up_block = 2,\n",
109
+ " sample_size = 1024\n",
110
+ ")\n",
111
+ "\n",
112
+ "vae.save_pretrained(\"asymmetric_vae_empty\")\n",
113
+ "print(\"✅ Создана новая модель:\", type(vae))\n",
114
+ "\n",
115
+ "# ---- Функция переноса весов ----\n",
116
+ "def transfer_weights_with_duplication(old_path, new_path, save_path=\"asymmetric_vae\", device=\"cuda\", dtype=torch.float16):\n",
117
+ " old_vae = AutoencoderKL.from_pretrained(old_path,subfolder=\"vae\").to(device, dtype=dtype)\n",
118
+ " new_vae = AsymmetricAutoencoderKL.from_pretrained(new_path).to(device, dtype=dtype)\n",
119
+ "\n",
120
+ " old_sd = old_vae.state_dict()\n",
121
+ " new_sd = new_vae.state_dict()\n",
122
+ "\n",
123
+ " transferred_keys = set()\n",
124
+ " transfer_stats = {\"перенесено\": 0, \"дублировано\": 0, \"сдвинуто\": 0, \"пропущено\": 0}\n",
125
+ "\n",
126
+ " print(\"\\n--- Перенос весов ---\")\n",
127
+ "\n",
128
+ " for k, v in tqdm(old_sd.items()):\n",
129
+ " # === Копирование энкодера ===\n",
130
+ " if \"encoder\" in k or \"quant_conv\" in k or \"post_quant_conv\" in k:\n",
131
+ " if k in new_sd and new_sd[k].shape == v.shape:\n",
132
+ " new_sd[k] = v.clone()\n",
133
+ " transferred_keys.add(k)\n",
134
+ " transfer_stats[\"перенесено\"] += 1\n",
135
+ " continue\n",
136
+ "\n",
137
+ " # === Перенос декодера ===\n",
138
+ " if \"decoder.up_blocks\" in k:\n",
139
+ " parts = k.split(\".\")\n",
140
+ " idx = int(parts[2])\n",
141
+ "\n",
142
+ " # сдвигаем индекс на +1 (так как добавлен новый блок в начало)\n",
143
+ " new_idx = idx + 1\n",
144
+ " new_k = \".\".join([parts[0], parts[1], str(new_idx)] + parts[3:])\n",
145
+ " if new_k in new_sd and new_sd[new_k].shape == v.shape:\n",
146
+ " new_sd[new_k] = v.clone()\n",
147
+ " transferred_keys.add(new_k)\n",
148
+ " transfer_stats[\"сдвинуто\"] += 1\n",
149
+ " continue\n",
150
+ "\n",
151
+ " # === Перенос прочих совпадающих ключей ===\n",
152
+ " if k in new_sd and new_sd[k].shape == v.shape:\n",
153
+ " new_sd[k] = v.clone()\n",
154
+ " transferred_keys.add(k)\n",
155
+ " transfer_stats[\"перенесено\"] += 1\n",
156
+ "\n",
157
+ " # === Дублирование весов старого 512→512 блока в новый ===\n",
158
+ " ref_prefix = \"decoder.up_blocks.1\" # старый первый up-блок (512→512)\n",
159
+ " new_prefix = \"decoder.up_blocks.0\" # новый добавленный блок\n",
160
+ " for k, v in old_sd.items():\n",
161
+ " if k.startswith(ref_prefix):\n",
162
+ " new_k = k.replace(ref_prefix, new_prefix)\n",
163
+ " if new_k in new_sd and new_sd[new_k].shape == v.shape:\n",
164
+ " new_sd[new_k] = v.clone()\n",
165
+ " transferred_keys.add(new_k)\n",
166
+ " transfer_stats[\"дублировано\"] += 1\n",
167
+ "\n",
168
+ " # === Загрузка и сохранение ===\n",
169
+ " new_vae.load_state_dict(new_sd, strict=False)\n",
170
+ " new_vae.save_pretrained(save_path)\n",
171
+ "\n",
172
+ " print(\"\\n✅ Перенос завершён.\")\n",
173
+ " print(\"Статистика:\")\n",
174
+ " for k, v in transfer_stats.items():\n",
175
+ " print(f\" {k}: {v}\")\n",
176
+ "\n",
177
+ " missing = [k for k in new_sd.keys() if k not in transferred_keys]\n",
178
+ " if missing:\n",
179
+ " print(\"\\nНеперенесённые ключи (первые 20):\")\n",
180
+ " for k in missing[:20]:\n",
181
+ " print(\" \", k)\n",
182
+ "\n",
183
+ "# ---- Запуск ----\n",
184
+ "transfer_weights_with_duplication(\"AiArtLab/simplevae\", \"asymmetric_vae_empty\", save_path=\"vae\")\n"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": 8,
190
+ "id": "65653a65-e7c2-4b67-bc17-62c21cfd1db8",
191
+ "metadata": {},
192
+ "outputs": [
193
+ {
194
+ "name": "stdout",
195
+ "output_type": "stream",
196
+ "text": [
197
+ "Collecting hf_transfer\n",
198
+ " Downloading hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.7 kB)\n",
199
+ "Downloading hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)\n",
200
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.6/3.6 MB\u001b[0m \u001b[31m34.5 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n",
201
+ "\u001b[?25hInstalling collected packages: hf_transfer\n",
202
+ "Successfully installed hf_transfer-0.1.9\n"
203
+ ]
204
+ }
205
+ ],
206
+ "source": [
207
+ "!pip install hf_transfer"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": null,
213
+ "id": "59fcafb9-6d89-49b4-8362-b4891f591687",
214
+ "metadata": {},
215
+ "outputs": [],
216
+ "source": []
217
+ }
218
+ ],
219
+ "metadata": {
220
+ "kernelspec": {
221
+ "display_name": "Python 3 (ipykernel)",
222
+ "language": "python",
223
+ "name": "python3"
224
+ },
225
+ "language_info": {
226
+ "codemirror_mode": {
227
+ "name": "ipython",
228
+ "version": 3
229
+ },
230
+ "file_extension": ".py",
231
+ "mimetype": "text/x-python",
232
+ "name": "python",
233
+ "nbconvert_exporter": "python",
234
+ "pygments_lexer": "ipython3",
235
+ "version": "3.12.3"
236
+ }
237
+ },
238
+ "nbformat": 4,
239
+ "nbformat_minor": 5
240
+ }
untitled.txt ADDED
File without changes
vae/config.json ADDED
@@ -0,0 +1,48 @@
+ {
+ "_class_name": "AsymmetricAutoencoderKL",
+ "_diffusers_version": "0.35.1",
+ "_name_or_path": "vae",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512,
+ 512
+ ],
+ "down_block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D"
+ ],
+ "force_upcast": false,
+ "in_channels": 3,
+ "latent_channels": 16,
+ "layers_per_down_block": 2,
+ "layers_per_up_block": 2,
+ "norm_num_groups": 32,
+ "out_channels": 3,
+ "sample_size": 1024,
+ "scaling_factor": 1.0,
+ "up_block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512,
+ 512
+ ],
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D"
+ ]
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8aef462535483be8d283418a62726921310dd8adcd60ee8418ebea5836316627
+ size 444559412