onethousand committed · verified · Commit 594b244 · Parent(s): dd29c80

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/controlnet.png filter=lfs diff=lfs merge=lfs -text
+assets/prompt.png filter=lfs diff=lfs merge=lfs -text
assets/controlnet.png ADDED

Git LFS Details

  • SHA256: 042fe3599be2694a3da27d8b1d221dd63c3e9b3a0afaca3d725697ff51185f00
  • Pointer size: 132 Bytes
  • Size of remote file: 1.28 MB
assets/face_normal.png ADDED
assets/face_seg.png ADDED
assets/left_eye_normal.png ADDED
assets/left_eye_seg.png ADDED
assets/mouth_normal.png ADDED
assets/mouth_seg.png ADDED
assets/prompt.png ADDED

Git LFS Details

  • SHA256: a6015444a0dbea964342011ea32b0c70166320cb45ac3c8b678de7c5fa6332cd
  • Pointer size: 131 Bytes
  • Size of remote file: 525 kB
config.json ADDED
@@ -0,0 +1,52 @@
+{
+  "_class_name": "ControlNetModel",
+  "_diffusers_version": "0.31.0.dev0",
+  "_name_or_path": "/media/yiqian/data/datasets/controlnet-training-runs/all/checkpoint-32000/controlnet",
+  "act_fn": "silu",
+  "addition_embed_type": null,
+  "addition_embed_type_num_heads": 64,
+  "addition_time_embed_dim": null,
+  "attention_head_dim": 8,
+  "block_out_channels": [
+    320,
+    640,
+    1280,
+    1280
+  ],
+  "class_embed_type": null,
+  "conditioning_channels": 4,
+  "conditioning_embedding_out_channels": [
+    16,
+    32,
+    96,
+    256
+  ],
+  "controlnet_conditioning_channel_order": "rgb",
+  "cross_attention_dim": 768,
+  "down_block_types": [
+    "CrossAttnDownBlock2D",
+    "CrossAttnDownBlock2D",
+    "CrossAttnDownBlock2D",
+    "DownBlock2D"
+  ],
+  "downsample_padding": 1,
+  "encoder_hid_dim": null,
+  "encoder_hid_dim_type": null,
+  "flip_sin_to_cos": true,
+  "freq_shift": 0,
+  "global_pool_conditions": false,
+  "in_channels": 4,
+  "layers_per_block": 2,
+  "mid_block_scale_factor": 1,
+  "mid_block_type": "UNetMidBlock2DCrossAttn",
+  "norm_eps": 1e-05,
+  "norm_num_groups": 32,
+  "num_attention_heads": null,
+  "num_class_embeds": null,
+  "only_cross_attention": false,
+  "projection_class_embeddings_input_dim": null,
+  "resnet_time_scale_shift": "default",
+  "transformer_layers_per_block": 1,
+  "upcast_attention": false,
+  "use_linear_projection": false
+}
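
The config above is a standard diffusers ControlNetModel; the one notable non-default setting is "conditioning_channels": 4, because the conditioning input is a surface-normal RGB image concatenated with a one-channel segmentation mask (see the dataset script below). A minimal loading sketch, assuming a local clone of this repo in the working directory and diffusers >= 0.31; the path and dtype are illustrative, not from the repo:

# Loading sketch (path and dtype are illustrative assumptions).
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained(".", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE",  # base model used in script/run.sh
    controlnet=controlnet,
    safety_checker=None,
    torch_dtype=torch.float16,
).to("cuda")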
diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7ddf77299c5ef0f4eb3e1a2e2955553a5b4196821f397b94c48c6683549bfcd4
+size 1445157696
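
The three lines above are the Git LFS pointer that stands in for the 1.4 GB weights file. A quick integrity-check sketch against the pointer's oid, assuming the real file has already replaced the pointer in a local clone:

# Compare the downloaded weights against the LFS pointer's sha256 oid.
import hashlib

EXPECTED = "7ddf77299c5ef0f4eb3e1a2e2955553a5b4196821f397b94c48c6683549bfcd4"

h = hashlib.sha256()
with open("diffusion_pytorch_model.safetensors", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
        h.update(chunk)
assert h.hexdigest() == EXPECTED, "checksum mismatch -- re-download the file"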
script/dataset_AnimPortrait3D_controlnet.py ADDED
@@ -0,0 +1,118 @@
+import json
+import os
+
+import torch
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as F
+from PIL import Image
+from torch.utils.data.dataset import Dataset
+
+
+def tokenize_captions(caption, tokenizer):
+    captions = [caption]
+    inputs = tokenizer(
+        captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+    )
+    return inputs.input_ids
+
+
+class SquarePad:
+    """Pad a PIL image with white borders so it becomes square."""
+
+    def __call__(self, image):
+        w, h = image.size
+        max_wh = max(w, h)
+        hp = int((max_wh - w) / 2)
+        vp = int((max_wh - h) / 2)
+        padding = (hp, vp, hp, vp)
+        return F.pad(image, padding, (255, 255, 255), "constant")
+
+
+class NormalSegDataset(Dataset):
+    """Pairs RGB portraits with a 4-channel conditioning image (normal map + segmentation mask)."""
+
+    def __init__(self, args, path, tokenizer, cfg_prob):
+        self.image_transforms = transforms.Compose(
+            [
+                transforms.RandomResizedCrop(
+                    args.resolution, scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BILINEAR
+                ),
+                transforms.ToTensor(),
+            ]
+        )
+        # Applied to the target image only: map [0, 1] to [-1, 1] for the VAE.
+        self.additional_image_transforms = transforms.Compose(
+            [transforms.Normalize([0.5], [0.5])]
+        )
+
+        meta_path = os.path.join(path, "meta_train_seg.json")
+        with open(meta_path, "r") as f:
+            self.meta = json.load(f)
+
+        self.keys = self.meta["keys"]
+        self.meta = self.meta["data"]
+
+        self.tokenizer = tokenizer
+        # Probability of dropping the caption / the conditioning images,
+        # to support classifier-free guidance at inference time.
+        self.cfg_prob = cfg_prob
+
+    def __len__(self):
+        return len(self.keys)
+
+    def __getitem__(self, index):
+        meta_data = self.meta[self.keys[index]]
+
+        rgb_path = meta_data["rgb"]
+        normal_path = meta_data["normal"]
+        seg_path = meta_data["seg"]
+        text_prompt = meta_data["caption"][0]
+
+        # Randomly drop the caption.
+        if torch.rand(1).item() < self.cfg_prob:
+            text_prompt = ""
+
+        image = Image.open(rgb_path).convert("RGB")
+        # Save the RNG state so the same random crop is applied to the
+        # target image and to both conditioning images.
+        state = torch.get_rng_state()
+        image = self.image_transforms(image)
+
+        # Randomly drop the conditioning: a white normal map and an empty mask.
+        if torch.rand(1).item() < self.cfg_prob:
+            # image is a (C, H, W) tensor; PIL sizes are (width, height).
+            normal_image = Image.new("RGB", (image.shape[2], image.shape[1]), (255, 255, 255))
+            seg_image = Image.new("L", (image.shape[2], image.shape[1]), 0)
+        else:
+            normal_image = Image.open(normal_path).convert("RGB")
+            seg_image = Image.open(seg_path).convert("L")
+
+        torch.set_rng_state(state)
+        normal_image = self.image_transforms(normal_image)
+
+        torch.set_rng_state(state)
+        seg_image = self.image_transforms(seg_image)
+
+        # 4-channel conditioning: normal RGB (3) + segmentation mask (1).
+        conditioning_image = torch.cat([normal_image, seg_image], dim=0)
+
+        image = self.additional_image_transforms(image)
+
+        prompt = tokenize_captions(text_prompt, self.tokenizer)
+
+        return image, conditioning_image, prompt, text_prompt
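
A minimal usage sketch for the dataset above. The argparse namespace and tokenizer source are assumptions (the tokenizer subfolder matches what the training script loads); `meta_train_seg.json` must exist under the data root with a `keys` list and a `data` dict mapping each key to `rgb`/`normal`/`seg` paths and a `caption` list:

# Usage sketch -- argument values are illustrative, not from the repo.
from types import SimpleNamespace
from transformers import AutoTokenizer
from dataset_AnimPortrait3D_controlnet import NormalSegDataset

tokenizer = AutoTokenizer.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", subfolder="tokenizer", use_fast=False
)
args = SimpleNamespace(resolution=512)
dataset = NormalSegDataset(args, "/path/to/dataset", tokenizer, cfg_prob=0.1)
image, conditioning, input_ids, caption = dataset[0]
print(image.shape, conditioning.shape)  # (3, 512, 512), (4, 512, 512)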
script/face_normal.png ADDED
script/face_seg.png ADDED
script/run.sh ADDED
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+accelerate launch train.py \
+    --pretrained_model_name_or_path="SG161222/Realistic_Vision_V5.1_noVAE" \
+    --output_dir="./controlnet-training-runs" \
+    --train_data_dir=/path/to/dataset \
+    --cfg_prob=0.1 \
+    --resolution=512 \
+    --learning_rate=1e-5 \
+    --num_validation_images=3 \
+    --validation_image "./face_normal.png" "./face_seg.png" \
+    --validation_prompt "a Teen boy, pensive look, dark hair. Preppy sweater, collared shirt, moody room, 80s memorabilia" \
+    --train_batch_size=4 \
+    --num_train_epochs=40 \
+    --validation_steps=500 \
+    --checkpointing_steps=2000
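
For inference with the trained weights, the validation logic in the training script below stacks the normal map and the segmentation mask into a single 4-channel array in [0, 1] and feeds it to the pipeline. A condensed sketch of those same steps; the checkpoint path is a placeholder and the prompt is illustrative:

# Inference sketch, mirroring log_validation() in the training script below.
import numpy as np
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler

controlnet = ControlNetModel.from_pretrained("/path/to/controlnet", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE",
    controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16,
).to("cuda")
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

normal = np.array(Image.open("face_normal.png").convert("RGB").resize((512, 512)))  # H, W, 3
seg = np.array(Image.open("face_seg.png").resize((512, 512)))[:, :, :1]             # H, W, 1
cond = np.concatenate([normal, seg], axis=2)[None, ...] / 255.0                      # 1, H, W, 4

image = pipe(
    "a Teen boy, pensive look, dark hair",
    cond, num_inference_steps=20, guidance_scale=7.5,
).images[0]
image.save("out.png")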
script/train_normal_seg_controlnet_all_in_one.py ADDED
@@ -0,0 +1,1328 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import cv2
+import time
+import argparse
+import contextlib
+import gc
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from dataset_AnimPortrait3D_controlnet import NormalSegDataset
+from datasets import load_dataset  # used by make_train_dataset below
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+    AutoencoderKL,
+    ControlNetModel,
+    DDPMScheduler,
+    StableDiffusionControlNetPipeline,
+    UNet2DConditionModel,
+    UniPCMultistepScheduler,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+    import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def image_grid(imgs, rows, cols):
+    assert len(imgs) == rows * cols
+
+    w, h = imgs[0].size
+    grid = Image.new("RGB", size=(cols * w, rows * h))
+
+    for i, img in enumerate(imgs):
+        grid.paste(img, box=(i % cols * w, i // cols * h))
+    return grid
+
+
+def log_validation(
+    vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False, train_batch=None
+):
+    logger.info("Running validation... ")
+
+    if not is_final_validation:
+        controlnet = accelerator.unwrap_model(controlnet)
+    else:
+        controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
+
+    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        vae=vae,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        unet=unet,
+        controlnet=controlnet,
+        safety_checker=None,
+        revision=args.revision,
+        variant=args.variant,
+        torch_dtype=weight_dtype,
+    )
+    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    if args.enable_xformers_memory_efficient_attention:
+        pipeline.enable_xformers_memory_efficient_attention()
+
+    if args.seed is None:
+        generator = None
+    else:
+        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+    validation_images = args.validation_image.copy()
+    validation_nums = len(validation_images) // 2  # each sample is a normal/seg image pair
+    validation_prompt = args.validation_prompt.copy()
+
+    inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
+
+    assert len(validation_prompt) == validation_nums
+    validation_prompts = validation_prompt
+
+    gt_images = [None] * validation_nums
+
+    logger.info(f"[info] validation_nums {validation_nums}")
+
+    # During training, also visualize a few samples from the current batch
+    # (assumes train_batch_size >= 4; skipped in the final validation, where
+    # no train batch is available).
+    if train_batch is not None and len(validation_images) < 12:
+        conditioning_train = train_batch["conditioning_pixel_values"]  # b, c, h, w
+        gt_train = train_batch["pixel_values"]  # b, c, h, w
+
+        for i in range(4):
+            validation_prompts.append(train_batch["text_prompts"][i])
+            logger.info(f'[info] append prompt {train_batch["text_prompts"][i]}')
+
+            conditioning_image = conditioning_train[i]  # c, h, w
+            conditioning_image = conditioning_image.permute(1, 2, 0).cpu().numpy()
+            normal_image = conditioning_image[:, :, :3] * 255
+            seg_image = conditioning_image[:, :, 3:].repeat(3, axis=2) * 255
+            gt_image = gt_train[i] / 2 + 0.5  # c, h, w
+            gt_image = gt_image.permute(1, 2, 0).cpu().numpy() * 255
+
+            validation_images.append(Image.fromarray(normal_image.astype(np.uint8)))
+            validation_images.append(Image.fromarray(seg_image.astype(np.uint8)))
+
+            gt_images.append(gt_image.astype(np.uint8))
+
+    logger.info(f"[info] new len(validation_images) {len(validation_images)}")
+    save_dir_path = os.path.join(args.output_dir, "eval_img")
+    if not os.path.exists(save_dir_path):
+        os.makedirs(save_dir_path)
+    for i in range(len(validation_images) // 2):
+        if isinstance(validation_images[i * 2], str):
+            normal_image = Image.open(validation_images[i * 2]).resize((args.resolution, args.resolution))
+        else:
+            normal_image = validation_images[i * 2]
+
+        if isinstance(validation_images[i * 2 + 1], str):
+            seg_image = Image.open(validation_images[i * 2 + 1]).resize((args.resolution, args.resolution))
+        else:
+            seg_image = validation_images[i * 2 + 1]
+
+        seg_image = np.array(seg_image)[:, :, :1]
+
+        gt_image = gt_images[i]
+
+        # PIL images are 0-255; the pipeline expects a float array in 0-1.
+        validation_image = np.concatenate([np.array(normal_image), seg_image], axis=2)[None, ...] / 255.0
+
+        validation_prompt = validation_prompts[i]
+        print("validation_prompt: ", validation_prompt)
+        images = []
+        for _ in range(args.num_validation_images):
+            with inference_ctx:
+                image = pipeline(
+                    validation_prompt, validation_image, num_inference_steps=20, generator=generator, guidance_scale=7.5
+                ).images[0]
+
+            images.append(image)
+
+        validation_image = validation_image[0] * 255.0
+
+        normal = np.array(validation_image)[:, :, :3]
+        seg = np.array(validation_image)[:, :, 3:]
+        seg = np.concatenate([seg, seg, seg], axis=2)
+
+        if gt_image is not None:
+            gt_image = cv2.resize(gt_image, images[0].size)
+            formatted_images = [gt_image, normal, seg]
+        else:
+            formatted_images = [normal, seg]
+        for image in images:
+            formatted_images.append(np.asarray(image))
+
+        formatted_images = np.concatenate(formatted_images, 1).astype(np.uint8)
+
+        file_path = os.path.join(save_dir_path, "{}_{}_{}.png".format(step, time.time(), validation_prompt.replace(" ", "-")))
+        formatted_images = cv2.cvtColor(formatted_images, cv2.COLOR_BGR2RGB)
+        print("Save images to:", file_path)
+        cv2.imwrite(file_path, formatted_images)
+
+    del pipeline
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+    text_encoder_config = PretrainedConfig.from_pretrained(
+        pretrained_model_name_or_path,
+        subfolder="text_encoder",
+        revision=revision,
+    )
+    model_class = text_encoder_config.architectures[0]
+
+    if model_class == "CLIPTextModel":
+        from transformers import CLIPTextModel
+
+        return CLIPTextModel
+    elif model_class == "RobertaSeriesModelWithTransformation":
+        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+        return RobertaSeriesModelWithTransformation
+    else:
+        raise ValueError(f"{model_class} is not supported.")
+
+
+def save_model_card(repo_id: str, image_logs=None, base_model=None, repo_folder=None):
+    img_str = ""
+    if image_logs is not None:
+        img_str = "You can find some example images below.\n\n"
+        for i, log in enumerate(image_logs):
+            images = log["images"]
+            validation_prompt = log["validation_prompt"]
+            validation_image = log["validation_image"]
+            validation_image.save(os.path.join(repo_folder, "image_control.png"))
+            img_str += f"prompt: {validation_prompt}\n"
+            images = [validation_image] + images
+            image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+            img_str += f"![images_{i}](./images_{i}.png)\n"
+
+    model_description = f"""
+# controlnet-{repo_id}
+
+These are controlnet weights trained on {base_model} with a new type of conditioning.
+{img_str}
+"""
+    model_card = load_or_create_model_card(
+        repo_id_or_path=repo_id,
+        from_training=True,
+        license="creativeml-openrail-m",
+        base_model=base_model,
+        model_description=model_description,
+        inference=True,
+    )
+
+    tags = [
+        "stable-diffusion",
+        "stable-diffusion-diffusers",
+        "text-to-image",
+        "diffusers",
+        "controlnet",
+        "diffusers-training",
+    ]
+    model_card = populate_model_card(model_card, tags=tags)
+
+    model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def parse_args(input_args=None):
+    parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--controlnet_model_name_or_path",
+        type=str,
+        default=None,
+        help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
+        " If not specified controlnet weights are initialized from unet.",
+    )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+    )
+    parser.add_argument(
+        "--tokenizer_name",
+        type=str,
+        default=None,
+        help="Pretrained tokenizer name or path if not the same as model_name",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="controlnet-model",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument(
+        "--cache_dir",
+        type=str,
+        default=None,
+        help="The directory where the downloaded models and datasets will be stored.",
+    )
+    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=512,
+        help=(
+            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
+            " resolution."
+        ),
+    )
+    parser.add_argument(
+        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument("--num_train_epochs", type=int, default=1)
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+    )
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=500,
+        help=(
+            "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training"
+            " via `--resume_from_checkpoint`. In the case that the checkpoint is better than the final trained"
+            " model, the checkpoint can also be used for inference. Using a checkpoint for inference requires"
+            " separate loading of the original pipeline and the individual checkpointed model components. See"
+            " https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint"
+            " for step-by-step instructions."
+        ),
+    )
+    parser.add_argument(
+        "--checkpoints_total_limit",
+        type=int,
+        default=None,
+        help=("Max number of checkpoints to store."),
+    )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+        ),
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=5e-6,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--scale_lr",
+        action="store_true",
+        default=False,
+        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="constant",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--lr_num_cycles",
+        type=int,
+        default=1,
+        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+    )
+    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+    parser.add_argument(
+        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+        ),
+    )
+    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument(
+        "--hub_model_id",
+        type=str,
+        default=None,
+        help="The name of the repository to keep in sync with the local `output_dir`.",
+    )
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default="logs",
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
+    )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+    parser.add_argument(
+        "--report_to",
+        type=str,
+        default="tensorboard",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+        ),
+    )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default=None,
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or"
+            " the flag passed with the `accelerate.launch` command. Use this argument to override the accelerate"
+            " config."
+        ),
+    )
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+    )
+    parser.add_argument(
+        "--set_grads_to_none",
+        action="store_true",
+        help=(
+            "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
+            " behaviors, so disable this argument if it causes any problems. More info:"
+            " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+        ),
+    )
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default=None,
+        help=(
+            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+            " or to a folder containing files that 🤗 Datasets can understand."
+        ),
+    )
+    parser.add_argument(
+        "--dataset_config_name",
+        type=str,
+        default=None,
+        help="The config of the Dataset, leave as None if there's only one config.",
+    )
+    parser.add_argument(
+        "--train_data_dir",
+        type=str,
+        default=None,
+        help=(
+            "A folder containing the training data. Folder contents must follow the structure described in"
+            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+        ),
+    )
+    parser.add_argument(
+        "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+    )
+    parser.add_argument(
+        "--conditioning_image_column",
+        type=str,
+        default="conditioning_image",
+        help="The column of the dataset containing the controlnet conditioning image.",
+    )
+    parser.add_argument(
+        "--caption_column",
+        type=str,
+        default="text",
+        help="The column of the dataset containing a caption or a list of captions.",
+    )
+    parser.add_argument(
+        "--max_train_samples",
+        type=int,
+        default=None,
+        help=(
+            "For debugging purposes or quicker training, truncate the number of training examples to this "
+            "value if set."
+        ),
+    )
+    parser.add_argument(
+        "--proportion_empty_prompts",
+        type=float,
+        default=0,
+        help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+    )
+
+    parser.add_argument(
+        "--cfg_prob",
+        type=float,
+        default=0,
+        help="Probability of dropping the caption and the conditioning images during training, to support"
+        " classifier-free guidance. Defaults to 0 (no dropout).",
+    )
+    parser.add_argument(
+        "--validation_prompt",
+        type=str,
+        default=None,
+        nargs="+",
+        help=(
+            "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+            " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+            " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+        ),
+    )
+    parser.add_argument(
+        "--validation_image",
+        type=str,
+        default=None,
+        nargs="+",
+        help=(
+            "A set of paths to the controlnet conditioning images to be evaluated every `--validation_steps`"
+            " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+            " single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+            " `--validation_image` that will be used with all `--validation_prompt`s."
+        ),
+    )
+    parser.add_argument(
+        "--num_validation_images",
+        type=int,
+        default=4,
+        help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
+    )
+    parser.add_argument(
+        "--validation_steps",
+        type=int,
+        default=100,
+        help=(
+            "Run validation every X steps. Validation consists of running the prompt"
+            " `args.validation_prompt` multiple times: `args.num_validation_images`"
+            " and logging the images."
+        ),
+    )
+    parser.add_argument(
+        "--tracker_project_name",
+        type=str,
+        default="train_controlnet",
+        help=(
+            "The `project_name` argument passed to Accelerator.init_trackers; for"
+            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+        ),
+    )
+
+    if input_args is not None:
+        args = parser.parse_args(input_args)
+    else:
+        args = parser.parse_args()
+
+    if args.dataset_name is None and args.train_data_dir is None:
+        raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
+
+    if args.dataset_name is not None and args.train_data_dir is not None:
+        raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
+
+    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+        raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+    if args.validation_prompt is not None and args.validation_image is None:
+        raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+    if args.validation_prompt is None and args.validation_image is not None:
+        raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+    if args.resolution % 8 != 0:
+        raise ValueError(
+            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+        )
+
+    return args
+
+
+def make_train_dataset(args, tokenizer, accelerator):
+    # Get the datasets: you can either provide your own training and evaluation files (see below)
+    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+    # download the dataset.
+    if args.dataset_name is not None:
+        # Downloading and loading a dataset from the hub.
+        dataset = load_dataset(
+            args.dataset_name,
+            args.dataset_config_name,
+            cache_dir=args.cache_dir,
+        )
+    else:
+        if args.train_data_dir is not None:
+            dataset = load_dataset(
+                args.train_data_dir,
+                cache_dir=args.cache_dir,
+            )
+        # See more about loading custom images at
+        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+
+    # Preprocessing the datasets.
+    # We need to tokenize inputs and targets.
+    column_names = dataset["train"].column_names
+
+    # 6. Get the column names for input/target.
+    if args.image_column is None:
+        image_column = column_names[0]
+        logger.info(f"image column defaulting to {image_column}")
+    else:
+        image_column = args.image_column
+        if image_column not in column_names:
+            raise ValueError(
+                f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+            )
+
+    if args.caption_column is None:
+        caption_column = column_names[1]
+        logger.info(f"caption column defaulting to {caption_column}")
+    else:
+        caption_column = args.caption_column
+        if caption_column not in column_names:
+            raise ValueError(
+                f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+            )
+
+    if args.conditioning_image_column is None:
+        conditioning_image_column = column_names[2]
+        logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+    else:
+        conditioning_image_column = args.conditioning_image_column
+        if conditioning_image_column not in column_names:
+            raise ValueError(
+                f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+            )
+
+    def tokenize_captions(examples, is_train=True):
+        captions = []
+        for caption in examples[caption_column]:
+            if random.random() < args.proportion_empty_prompts:
+                captions.append("")
+            elif isinstance(caption, str):
+                captions.append(caption)
+            elif isinstance(caption, (list, np.ndarray)):
+                # take a random caption if there are multiple
+                captions.append(random.choice(caption) if is_train else caption[0])
+            else:
+                raise ValueError(
+                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
+                )
+        inputs = tokenizer(
+            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+        )
+        return inputs.input_ids
+
+    image_transforms = transforms.Compose(
+        [
+            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),
+        ]
+    )
+
+    conditioning_image_transforms = transforms.Compose(
+        [
+            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution),
+            transforms.ToTensor(),
+        ]
+    )
+
+    def preprocess_train(examples):
+        images = [image.convert("RGB") for image in examples[image_column]]
+        images = [image_transforms(image) for image in images]
+
+        conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
+        conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
+
+        examples["pixel_values"] = images
+        examples["conditioning_pixel_values"] = conditioning_images
+        examples["input_ids"] = tokenize_captions(examples)
+
+        return examples
+
+    with accelerator.main_process_first():
+        if args.max_train_samples is not None:
+            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+        # Set the training transforms
+        train_dataset = dataset["train"].with_transform(preprocess_train)
+
+    return train_dataset
+
+
+def collate_fn(examples):
+    pixel_values = []
+    conditioning_pixel_values = []
+    input_ids = []
+    text_prompts = []
+    for example in examples:
+        pixel_value, conditioning_pixel_value, input_id, text_prompt = example
+        pixel_values.append(pixel_value)
+        conditioning_pixel_values.append(conditioning_pixel_value)
+        input_ids.append(input_id)
+        text_prompts.append(text_prompt)
+
+    pixel_values = torch.stack(pixel_values)
+    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+    conditioning_pixel_values = torch.stack(conditioning_pixel_values)
+    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+
+    input_ids = torch.stack(input_ids)
+
+    return {
+        "pixel_values": pixel_values,
+        "conditioning_pixel_values": conditioning_pixel_values,
+        "input_ids": input_ids,
+        "text_prompts": text_prompts,
+    }
+
+
+def main(args):
+    if args.report_to == "wandb" and args.hub_token is not None:
+        raise ValueError(
+            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+            " Please use `huggingface-cli login` to authenticate with the Hub."
+        )
+
+    logging_dir = Path(args.output_dir, args.logging_dir)
+
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+    accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision,
+        log_with=args.report_to,
+        project_config=accelerator_project_config,
+    )
+
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+
+    # If passed along, set the training seed now.
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    # Handle the repository creation
+    if accelerator.is_main_process:
+        if args.output_dir is not None:
+            os.makedirs(args.output_dir, exist_ok=True)
+
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
+    # Load the tokenizer
+    if args.tokenizer_name:
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+    elif args.pretrained_model_name_or_path:
+        tokenizer = AutoTokenizer.from_pretrained(
+            args.pretrained_model_name_or_path,
+            subfolder="tokenizer",
+            revision=args.revision,
+            use_fast=False,
+        )
+
+    # import correct text encoder class
+    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+
+    # Load scheduler and models
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+    text_encoder = text_encoder_cls.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+    )
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+    )
+    unet = UNet2DConditionModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+    )
+
+    if args.controlnet_model_name_or_path:
+        logger.info("Loading existing controlnet weights")
+        controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path, in_channels=4)
+    else:
+        logger.info("Initializing controlnet weights from unet")
+        # 4-channel conditioning: normal map (3) + segmentation mask (1).
+        controlnet = ControlNetModel.from_unet(unet, conditioning_channels=4)
+
+    # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
+    # `accelerate` 0.16.0 will have better support for customized saving
+    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+        def save_model_hook(models, weights, output_dir):
+            if accelerator.is_main_process:
+                i = len(weights) - 1
+
+                while len(weights) > 0:
+                    weights.pop()
+                    model = models[i]
+
+                    sub_dir = "controlnet"
+                    model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+                    i -= 1
+
+        def load_model_hook(models, input_dir):
+            while len(models) > 0:
+                # pop models so that they are not loaded again
+                model = models.pop()
+
+                # load diffusers style into model
+                load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
+                model.register_to_config(**load_model.config)
+
+                model.load_state_dict(load_model.state_dict())
+                del load_model
+
+        accelerator.register_save_state_pre_hook(save_model_hook)
+        accelerator.register_load_state_pre_hook(load_model_hook)
+
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+    controlnet.train()
+
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            import xformers
+
+            xformers_version = version.parse(xformers.__version__)
+            if xformers_version == version.parse("0.0.16"):
+                logger.warning(
+                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+                )
+            unet.enable_xformers_memory_efficient_attention()
+            controlnet.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+    if args.gradient_checkpointing:
+        controlnet.enable_gradient_checkpointing()
+
+    # Check that all trainable models are in full precision
+    low_precision_error_string = (
+        " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+        " doing mixed precision training, copy of the weights should still be float32."
+    )
+
+    if unwrap_model(controlnet).dtype != torch.float32:
+        raise ValueError(
+            f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}"
+        )
+
+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
+    if args.scale_lr:
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+        )
+
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+            )
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    # Optimizer creation
+    params_to_optimize = controlnet.parameters()
+    optimizer = optimizer_class(
+        params_to_optimize,
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+    )
+
+    train_dataset = NormalSegDataset(args, args.train_data_dir, tokenizer, cfg_prob=args.cfg_prob)
+    print("======================== size of train_dataset:", len(train_dataset))
+
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset,
+        shuffle=True,
+        collate_fn=collate_fn,
+        batch_size=args.train_batch_size,
+        num_workers=args.dataloader_num_workers,
+    )
+
+    # Scheduler and math around the number of training steps.
+    overrode_max_train_steps = False
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        overrode_max_train_steps = True
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_cycles=args.lr_num_cycles,
+        power=args.lr_power,
+    )
+
+    # Prepare everything with our `accelerator`.
+    controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        controlnet, optimizer, train_dataloader, lr_scheduler
+    )
+
+    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # as these models are only used for inference, keeping weights in full precision is not required.
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    # Move vae, unet and text_encoder to device and cast to weight_dtype
+    vae.to(accelerator.device, dtype=weight_dtype)
+    unet.to(accelerator.device, dtype=weight_dtype)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if overrode_max_train_steps:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    # Afterwards we recalculate our number of training epochs
+    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+    if accelerator.is_main_process:
+        tracker_config = dict(vars(args))
+
+        # tensorboard cannot handle list types for config
+        tracker_config.pop("validation_prompt")
+        tracker_config.pop("validation_image")
+
+        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+    # Train!
+    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    # Potentially load in the weights and states from a previous save
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1] if len(dirs) > 0 else None
+
+        if path is None:
+            accelerator.print(
+                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+            )
+            args.resume_from_checkpoint = None
+            initial_global_step = 0
+        else:
+            accelerator.print(f"Resuming from checkpoint {path}")
+            accelerator.load_state(os.path.join(args.output_dir, path))
+            global_step = int(path.split("-")[1])
+
+            initial_global_step = global_step
+            first_epoch = global_step // num_update_steps_per_epoch
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
+
+    image_logs = None
+    for epoch in range(first_epoch, args.num_train_epochs):
+        for step, batch in enumerate(train_dataloader):
+            with accelerator.accumulate(controlnet):
+                # Convert images to latent space
+                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+                latents = latents * vae.config.scaling_factor
+
+                # Sample noise that we'll add to the latents
+                noise = torch.randn_like(latents)
+                bsz = latents.shape[0]
+                # Sample a random timestep for each image
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = timesteps.long()
+
+                # Add noise to the latents according to the noise magnitude at each timestep
+                # (this is the forward diffusion process)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                # Get the text embedding for conditioning
+                encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
+
+                controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
+
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    noisy_latents,
+                    timesteps,
+                    encoder_hidden_states=encoder_hidden_states,
+                    controlnet_cond=controlnet_image,
+                    return_dict=False,
+                )
+
+                # Predict the noise residual
+                model_pred = unet(
+                    noisy_latents,
+                    timesteps,
+                    encoder_hidden_states=encoder_hidden_states,
+                    down_block_additional_residuals=[
+                        sample.to(dtype=weight_dtype) for sample in down_block_res_samples
+                    ],
+                    mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
+                    return_dict=False,
+                )[0]
+
+                # Get the target for loss depending on the prediction type
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                accelerator.backward(loss)
+                if accelerator.sync_gradients:
+                    params_to_clip = controlnet.parameters()
+                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
+                if accelerator.is_main_process:
+                    if global_step % args.checkpointing_steps == 0:
+                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+                        if args.checkpoints_total_limit is not None:
+                            checkpoints = os.listdir(args.output_dir)
+                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+                            if len(checkpoints) >= args.checkpoints_total_limit:
+                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+                                removing_checkpoints = checkpoints[0:num_to_remove]
+
+                                logger.info(
+                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+                                )
+                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+                                for removing_checkpoint in removing_checkpoints:
+                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+                                    shutil.rmtree(removing_checkpoint)
+
+                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                        accelerator.save_state(save_path)
+                        logger.info(f"Saved state to {save_path}")
+
+                    if global_step % args.validation_steps == 0 or global_step == 1:
+                        image_logs = log_validation(
+                            vae,
+                            text_encoder,
+                            tokenizer,
+                            unet,
+                            controlnet,
+                            args,
+                            accelerator,
+                            weight_dtype,
+                            global_step,
+                            train_batch=batch,
+                        )
+
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)
+            accelerator.log(logs, step=global_step)
+
+            if global_step >= args.max_train_steps:
+                break
+
+    # Create the pipeline using the trained modules and save it.
+    accelerator.wait_for_everyone()
+    if accelerator.is_main_process:
+        controlnet = unwrap_model(controlnet)
+        controlnet.save_pretrained(args.output_dir)
+
+        # Run a final round of validation.
+        image_logs = None
+        if args.validation_prompt is not None:
+            image_logs = log_validation(
+                vae=vae,
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                unet=unet,
+                controlnet=None,
+                args=args,
+                accelerator=accelerator,
+                weight_dtype=weight_dtype,
+                step=global_step,
+                is_final_validation=True,
+            )
+
+        if args.push_to_hub:
+            save_model_card(
+                repo_id,
+                image_logs=image_logs,
+                base_model=args.pretrained_model_name_or_path,
+                repo_folder=args.output_dir,
+            )
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
+
+    accelerator.end_training()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)