Commit be61cf2 · Mehdi Cherti committed · 1 parent: ae26d48
update

Changed files:
- EMA.py +0 -1
- clip_encoder.py +64 -0
- encoder.py +9 -0
- run.py +103 -3
- scripts/init.sh +15 -0
- scripts/run_hdfml.sh +25 -0
- scripts/run_jurecadc_ddp.sh +4 -1
- test_ddgan.py +280 -64
- train_ddgan.py +158 -60
- utils.py +2 -1
EMA.py CHANGED
@@ -39,7 +39,6 @@ class EMA(Optimizer):
             # State initialization
             if 'ema' not in state:
                 state['ema'] = p.data.clone()
-
             if p.shape not in params:
                 params[p.shape] = {'idx': 0, 'data': []}
                 ema[p.shape] = []
clip_encoder.py ADDED
@@ -0,0 +1,64 @@
import torch
import torch.nn as nn
import open_clip
from einops import rearrange


def exists(val):
    return val is not None

class CLIPEncoder(nn.Module):

    def __init__(self, model, pretrained):
        super().__init__()
        self.model = model
        self.pretrained = pretrained
        self.model, _, _ = open_clip.create_model_and_transforms(model, pretrained=pretrained)
        self.output_size = self.model.transformer.width

    def forward(self, texts, return_only_pooled=True):
        device = next(self.parameters()).device
        toks = open_clip.tokenize(texts).to(device)
        x = self.model.token_embedding(toks)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.model.transformer(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        mask = (toks!=0)
        pooled = x[torch.arange(x.shape[0]), toks.argmax(dim=-1)] @ self.model.text_projection
        if return_only_pooled:
            return pooled
        else:
            return pooled, x, mask




class CLIPImageEncoder(nn.Module):

    def __init__(self, model_type="ViT-B/32"):
        super().__init__()
        import clip
        self.model, preprocess = clip.load(model_type, device="cpu", jit=False)
        CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
        CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
        mean = torch.tensor(CLIP_MEAN).view(1, 3, 1, 1)
        std = torch.tensor(CLIP_STD).view(1, 3, 1, 1)
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)
        self.output_size = 512

    def forward_image(self, x):
        x = torch.nn.functional.interpolate(x, mode='bicubic', size=(224, 224))
        x = (x-self.mean)/self.std
        return self.model.encode_image(x)

    def forward_text(self, texts):
        import clip
        toks = clip.tokenize(texts, truncate=True).to(self.mean.device)
        return self.model.encode_text(toks)
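A minimal usage sketch for the two new encoder classes (my addition for illustration, not part of the commit; it assumes open_clip_torch and OpenAI's clip package are installed and that the model/pretrained tags below are available locally):

# Illustration only (assumes open_clip and clip are installed).
import torch
from clip_encoder import CLIPEncoder, CLIPImageEncoder

# Text encoder built on open_clip; the tag mirrors "openclip/ViT-H-14/laion2b_s32b_b79k" used later in run.py.
text_enc = CLIPEncoder("ViT-H-14", "laion2b_s32b_b79k")
pooled, tokens, mask = text_enc(["a photo of a cat"], return_only_pooled=False)
print(pooled.shape, tokens.shape, mask.shape)  # pooled: (1, proj_dim); tokens: (1, 77, width)

# Image/text encoder built on OpenAI CLIP, used below for CLIP score and CLIP guidance.
img_enc = CLIPImageEncoder(model_type="ViT-B/32")
images = torch.rand(2, 3, 256, 256)        # values in [0, 1]
img_feat = img_enc.forward_image(images)    # resized to 224x224 and normalized internally
txt_feat = img_enc.forward_text(["a cat", "a dog"])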
encoder.py ADDED
@@ -0,0 +1,9 @@
import t5
import clip_encoder

def build_encoder(name, **kwargs):
    if name.startswith("google"):
        return t5.T5Encoder(name=name, **kwargs)
    elif name.startswith("openclip"):
        _, model, pretrained = name.split("/")
        return clip_encoder.CLIPEncoder(model, pretrained)
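For illustration (not part of the commit), the encoder-name convention build_encoder expects: names starting with "google" are routed to the T5 encoder, while "openclip/<model>/<pretrained>" names are split on "/" and passed to CLIPEncoder. A minimal sketch, assuming t5.py and clip_encoder.py are importable:

# Hypothetical usage of the name convention handled by build_encoder.
from encoder import build_encoder

t5_enc = build_encoder("google/t5-v1_1-xl")                      # -> t5.T5Encoder(name="google/t5-v1_1-xl")
clip_enc = build_encoder("openclip/ViT-H-14/laion2b_s32b_b79k")  # -> CLIPEncoder("ViT-H-14", "laion2b_s32b_b79k")
print(t5_enc.output_size, clip_enc.output_size)                  # both expose .output_size for the generator's cond_size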
run.py CHANGED
@@ -132,6 +132,8 @@ def ddgan_laion_aesthetic_v2():
 def ddgan_laion_aesthetic_v3():
     cfg = ddgan_laion_aesthetic_v1()
     cfg['model']['text_encoder'] = "google/t5-v1_1-xl"
+    cfg['model']['mismatch_loss'] = ''
+    cfg['model']['grad_penalty_cond'] = ''
     return cfg
 
 def ddgan_laion_aesthetic_v4():
@@ -146,6 +148,85 @@ def ddgan_laion_aesthetic_v5():
     cfg['model']['grad_penalty_cond'] = ''
     return cfg
 
+
+
+def ddgan_laion2b_v1():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['mismatch_loss'] = ''
+    cfg['model']['grad_penalty_cond'] = ''
+    cfg['model']['num_channels_dae'] = 224
+    cfg['model']['batch_size'] = 2
+    cfg['model']['discr_type'] = "large_cond_attn"
+    cfg['model']['preprocessing'] = 'random_resized_crop_v1'
+    return cfg
+
+def ddgan_laion_aesthetic_v6():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['no_lr_decay'] = ''
+    return cfg
+
+
+
+def ddgan_laion_aesthetic_v7():
+    cfg = ddgan_laion_aesthetic_v6()
+    cfg['model']['r1_gamma'] = 5
+    return cfg
+
+
+def ddgan_laion_aesthetic_v8():
+    cfg = ddgan_laion_aesthetic_v6()
+    cfg['model']['num_timesteps'] = 8
+    return cfg
+
+def ddgan_laion_aesthetic_v9():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['num_channels_dae'] = 384
+    return cfg
+
+def ddgan_sd_v1():
+    cfg = ddgan_laion_aesthetic_v3()
+    return cfg
+def ddgan_sd_v2():
+    cfg = ddgan_laion_aesthetic_v3()
+    return cfg
+def ddgan_sd_v3():
+    cfg = ddgan_laion_aesthetic_v3()
+    return cfg
+def ddgan_sd_v4():
+    cfg = ddgan_laion_aesthetic_v3()
+    return cfg
+def ddgan_sd_v5():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['num_timesteps'] = 8
+    return cfg
+def ddgan_sd_v6():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['num_channels_dae'] = 192
+    return cfg
+def ddgan_sd_v7():
+    cfg = ddgan_laion_aesthetic_v3()
+    return cfg
+def ddgan_sd_v8():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['image_size'] = 512
+    return cfg
+def ddgan_laion_aesthetic_v12():
+    cfg = ddgan_laion_aesthetic_v3()
+    return cfg
+def ddgan_laion_aesthetic_v13():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['text_encoder'] = "openclip/ViT-H-14/laion2b_s32b_b79k"
+    return cfg
+
+def ddgan_laion_aesthetic_v14():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['text_encoder'] = "openclip/ViT-H-14/laion2b_s32b_b79k"
+    return cfg
+def ddgan_sd_v9():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['text_encoder'] = "openclip/ViT-H-14/laion2b_s32b_b79k"
+    return cfg
+
 models = [
     ddgan_cifar10_cond17, # cifar10, cross attn for discr
     ddgan_cifar10_cond18, # cifar10, xl encoder
@@ -166,6 +247,23 @@ models = [
     ddgan_laion_aesthetic_v3, # like ddgan_laion_aesthetic_v1 but trained from scratch with T5-XL (continue from 23aug with mismatch and grad penalty and random_resized_crop_v1)
     ddgan_laion_aesthetic_v4, # like ddgan_laion_aesthetic_v1 but trained from scratch with OpenAI's ClipEncoder
     ddgan_laion_aesthetic_v5, # fine-tune ddgan_laion_aesthetic_v1 with mismatch and cond grad penalty losses
+    ddgan_laion_aesthetic_v6, # like v3 but without lr decay
+    ddgan_laion_aesthetic_v7, # like v6 but with r1 gamma of 5 instead of 1, trying to constrain the discr more.
+    ddgan_laion_aesthetic_v8, # like v6 but with 8 timesteps
+    ddgan_laion_aesthetic_v9,
+    ddgan_laion_aesthetic_v12,
+    ddgan_laion_aesthetic_v13,
+    ddgan_laion_aesthetic_v14,
+    ddgan_laion2b_v1,
+    ddgan_sd_v1,
+    ddgan_sd_v2,
+    ddgan_sd_v3,
+    ddgan_sd_v4,
+    ddgan_sd_v5,
+    ddgan_sd_v6,
+    ddgan_sd_v7,
+    ddgan_sd_v8,
+    ddgan_sd_v9,
 ]
 
 def get_model(model_name):
@@ -174,7 +272,7 @@ def get_model(model_name):
     return model()
 
 
-def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0, scale_factor_h=1, scale_factor_w=1, compute_clip_score=False):
+def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0, scale_factor_h=1, scale_factor_w=1, compute_clip_score=False, eval_name="", scale_method="convolutional"):
 
     cfg = get_model(model_name)
     model = cfg['model']
@@ -204,13 +302,15 @@ def test(...):
     args['scale_factor_h'] = scale_factor_h
     args['scale_factor_w'] = scale_factor_w
     args['n_mlp'] = model.get("n_mlp")
+    args['scale_method'] = scale_method
     if fid:
         args['compute_fid'] = ''
         args['real_img_dir'] = real_img_dir
         args['nb_images_for_fid'] = nb_images_for_fid
     if compute_clip_score:
         args['compute_clip_score'] = ""
-
+    if eval_name:
+        args["eval_name"] = eval_name
     cmd = "python -u test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
     print(cmd)
     call(cmd, shell=True)
@@ -234,4 +334,4 @@ def eval_results(model_name):
 
 if __name__ == "__main__":
     from clize import run
-    run([test, eval_results])
+    run([test, eval_results])
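A side note on the cmd string built in test(): entries whose value is None are skipped, and entries whose value is an empty string become bare boolean flags, which is how --compute_fid and --compute_clip_score are forwarded to test_ddgan.py. A standalone sketch of that convention (illustration only, with made-up values):

# Sketch of the flag-serialization convention used by run.py's test().
args = {"compute_fid": "", "real_img_dir": "inception_statistics.npz",
        "epoch_id": 550, "n_mlp": None, "scale_method": "convolutional"}
cmd = "python -u test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
print(cmd)
# python -u test_ddgan.py --compute_fid  --real_img_dir inception_statistics.npz --epoch_id 550 --scale_method convolutional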
scripts/init.sh CHANGED
@@ -32,6 +32,21 @@ if [[ "$machine" == juwelsbooster ]]; then
     ml torchvision/0.12.0
     source /p/project/covidnetx/environments/juwels_booster_2022/bin/activate
 fi
+if [[ "$machine" == hdfml ]]; then
+    export CUDA_VISIBLE_DEVICES=0,1,2,3
+    ml purge
+    ml use $OTHERSTAGES
+    ml Stages/2022
+    ml GCC/11.2.0
+    ml OpenMPI/4.1.2
+    ml CUDA/11.5
+    ml cuDNN/8.3.1.22-CUDA-11.5
+    ml NCCL/2.12.7-1-CUDA-11.5
+    ml PyTorch/1.11-CUDA-11.5
+    ml Horovod/0.24
+    ml torchvision/0.12.0
+    source envs/hdfml/bin/activate
+fi
 if [[ "$machine" == jusuf ]]; then
     echo not supported
 fi
scripts/run_hdfml.sh ADDED
@@ -0,0 +1,25 @@
#!/bin/bash -x
#SBATCH --account=cstdl
#SBATCH --nodes=8
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=8
#SBATCH --time=06:00:00
#SBATCH --gres=gpu
#SBATCH --partition=batch
ml purge
ml use $OTHERSTAGES
ml Stages/2022
ml GCC/11.2.0
ml OpenMPI/4.1.2
ml CUDA/11.5
ml cuDNN/8.3.1.22-CUDA-11.5
ml NCCL/2.12.7-1-CUDA-11.5
ml PyTorch/1.11-CUDA-11.5
ml Horovod/0.24
ml torchvision/0.12.0
source envs/hdfml/bin/activate
export CUDA_VISIBLE_DEVICES=0,1,2,3
echo "Job id: $SLURM_JOB_ID"
export TOKENIZERS_PARALLELISM=false
export NCCL_ASYNC_ERROR_HANDLING=1
srun python -u $*
scripts/run_jurecadc_ddp.sh CHANGED
@@ -13,5 +13,8 @@ source scripts/init.sh
 export CUDA_VISIBLE_DEVICES=0,1,2,3
 echo "Job id: $SLURM_JOB_ID"
 export TOKENIZERS_PARALLELISM=false
-export NCCL_ASYNC_ERROR_HANDLING=1
+#export NCCL_ASYNC_ERROR_HANDLING=1
+export NCCL_IB_TIMEOUT=50
+export UCX_RC_TIMEOUT=4s
+export NCCL_IB_RETRY_CNT=10
 srun python -u $*
test_ddgan.py CHANGED
@@ -86,7 +86,18 @@ class Posterior_Coefficients():
         self.posterior_mean_coef2 = ((1 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1 - self.alphas_cumprod))
 
         self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(min=1e-20))
-
+
+def predict_q_posterior(coefficients, x_0, x_t, t):
+    mean = (
+        extract(coefficients.posterior_mean_coef1, t, x_t.shape) * x_0
+        + extract(coefficients.posterior_mean_coef2, t, x_t.shape) * x_t
+    )
+    var = extract(coefficients.posterior_variance, t, x_t.shape)
+    log_var_clipped = extract(coefficients.posterior_log_variance_clipped, t, x_t.shape)
+    return mean, var, log_var_clipped
+
+
+
 def sample_posterior(coefficients, x_0,x_t, t):
 
     def q_posterior(x_0, x_t, t):
@@ -150,10 +161,10 @@ def sample_from_model_classifier_free_guidance(...):
             # eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
             eps = eps_uncond * (1 - guidance_scale) + eps_cond * guidance_scale
             x_0 = (1/torch.sqrt(coefficients.alphas_cumprod[i])) * (x - torch.sqrt(1 - coefficients.alphas_cumprod[i]) * eps)
-
+            #x_0 = x_0_uncond * (1 - guidance_scale) + x_0_cond * guidance_scale
 
             # Dynamic thresholding
-            q =
+            q = opt.dynamic_thresholding_quantile
             #print("Before", x_0.min(), x_0.max())
             if q:
                 shape = x_0.shape
@@ -180,9 +191,174 @@ def sample_from_model_classifier_free_guidance(...):
     return x
 
 
+def sample_from_model_classifier_free_guidance_convolutional(coefficients, generator, n_time, x_init, T, opt, text_encoder, cond=None, guidance_scale=0, split_input_params=None):
+    x = x_init
+    null = text_encoder([""] * len(x_init), return_only_pooled=False)
+    #latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
+    ks = split_input_params["ks"]  # eg. (128, 128)
+    stride = split_input_params["stride"]  # eg. (64, 64)
+    uf = split_input_params["vqf"]
+    with torch.no_grad():
+        for i in reversed(range(n_time)):
+            t = torch.full((x.size(0),), i, dtype=torch.int64).to(x.device)
+            t_time = t
+            latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
+
+            fold, unfold, normalization, weighting = get_fold_unfold(x, ks, stride, split_input_params, uf=uf)
+            x = unfold(x)
+            x = x.view((x.shape[0], -1, ks[0], ks[1], x.shape[-1]))
+            x_new_list = []
+            for j in range(x.shape[-1]):
+                x_0_uncond = generator(x[:,:,:,:,j], t_time, latent_z, cond=null)
+                x_0_cond = generator(x[:,:,:,:,j], t_time, latent_z, cond=cond)
+
+                eps_uncond = (x[:,:,:,:,j] - torch.sqrt(coefficients.alphas_cumprod[i]) * x_0_uncond) / torch.sqrt(1 - coefficients.alphas_cumprod[i])
+                eps_cond = (x[:,:,:,:,j] - torch.sqrt(coefficients.alphas_cumprod[i]) * x_0_cond) / torch.sqrt(1 - coefficients.alphas_cumprod[i])
+
+                eps = eps_uncond * (1 - guidance_scale) + eps_cond * guidance_scale
+                x_0 = (1/torch.sqrt(coefficients.alphas_cumprod[i])) * (x[:,:,:,:,j] - torch.sqrt(1 - coefficients.alphas_cumprod[i]) * eps)
+                q = args.dynamic_thresholding_quantile
+                if q:
+                    shape = x_0.shape
+                    x_0_v = x_0.view(shape[0], -1)
+                    d = torch.quantile(torch.abs(x_0_v), q, dim=1, keepdim=True)
+                    d.clamp_(min=1)
+                    x_0_v = x_0_v.clamp(-d, d) / d
+                    x_0 = x_0_v.view(shape)
+                x_new = sample_posterior(coefficients, x_0, x[:,:,:,:,j], t)
+                x_new_list.append(x_new)
+
+            o = torch.stack(x_new_list, axis=-1)
+            #o = o * weighting
+            o = o.view((o.shape[0], -1, o.shape[-1]))
+            decoded = fold(o)
+            decoded = decoded / normalization
+            x = decoded.detach()
+
+    return x
+
+def sample_from_model_clip_guidance(coefficients, generator, clip_model, n_time, x_init, T, opt, texts, cond=None, guidance_scale=0):
+    x = x_init
+    text_features = torch.nn.functional.normalize(clip_model.forward_text(texts), dim=1)
+    n_time = 16
+    for i in reversed(range(n_time)):
+        t = torch.full((x.size(0),), i%4, dtype=torch.int64).to(x.device)
+        t_time = t
+        latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
+        x.requires_grad = True
+        x_0 = generator(x, t_time, latent_z, cond=cond)
+        x_new = sample_posterior(coefficients, x_0, x, t)
+        x_new_n = (x_new + 1) / 2
+        image_features = torch.nn.functional.normalize(clip_model.forward_image(x_new_n), dim=1)
+        loss = (image_features*text_features).sum(dim=1).mean()
+        x_grad, = torch.autograd.grad(loss, x)
+        lr = 3000
+        x = x.detach()
+        print(x.min(),x.max(), lr*x_grad.min(), lr*x_grad.max())
+        x += x_grad * lr
+
+        with torch.no_grad():
+            x_0 = generator(x, t_time, latent_z, cond=cond)
+            x_new = sample_posterior(coefficients, x_0, x, t)
+
+        x = x_new.detach()
+        print(i)
+    return x
+
+def meshgrid(h, w):
+    y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+    x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+    arr = torch.cat([y, x], dim=-1)
+    return arr
+def delta_border(h, w):
+    """
+    :param h: height
+    :param w: width
+    :return: normalized distance to image border,
+     wtith min distance = 0 at border and max dist = 0.5 at image center
+    """
+    lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+    arr = meshgrid(h, w) / lower_right_corner
+    dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+    dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+    edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+    return edge_dist
+
+def get_weighting(h, w, Ly, Lx, device, split_input_params):
+    weighting = delta_border(h, w)
+    weighting = torch.clip(weighting, split_input_params["clip_min_weight"],
+                           split_input_params["clip_max_weight"], )
+    weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+    if split_input_params["tie_braker"]:
+        L_weighting = delta_border(Ly, Lx)
+        L_weighting = torch.clip(L_weighting,
+                                 split_input_params["clip_min_tie_weight"],
+                                 split_input_params["clip_max_tie_weight"])
+
+        L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+        weighting = weighting * L_weighting
+    return weighting
+
+def get_fold_unfold(x, kernel_size, stride, split_input_params, uf=1, df=1):  # todo load once not every time, shorten code
+    """
+    :param x: img of size (bs, c, h, w)
+    :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+    """
+    bs, nc, h, w = x.shape
+
+    # number of crops in image
+    Ly = (h - kernel_size[0]) // stride[0] + 1
+    Lx = (w - kernel_size[1]) // stride[1] + 1
+
+    if uf == 1 and df == 1:
+        fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+        unfold = torch.nn.Unfold(**fold_params)
+
+        fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+        weighting = get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device, split_input_params).to(x.dtype)
+        normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
+        weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+    elif uf > 1 and df == 1:
+        fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+        unfold = torch.nn.Unfold(**fold_params)
+
+        fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
+                            dilation=1, padding=0,
+                            stride=(stride[0] * uf, stride[1] * uf))
+        fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+        weighting = get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device, split_input_params).to(x.dtype)
+        normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
+        weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+    elif df > 1 and uf == 1:
+        fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+        unfold = torch.nn.Unfold(**fold_params)
+
+        fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
+                            dilation=1, padding=0,
+                            stride=(stride[0] // df, stride[1] // df))
+        fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+        weighting = get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device, split_input_params).to(x.dtype)
+        normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
+        weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+    else:
+        raise NotImplementedError
+
+    return fold, unfold, normalization, weighting
+
+
+
 #%%
 def sample_and_test(args):
     torch.manual_seed(args.seed)
     device = 'cuda:0'
     text_encoder =build_encoder(name=args.text_encoder, masked_mean=args.masked_mean).to(device)
     args.cond_size = text_encoder.output_size
@@ -197,10 +373,9 @@ def sample_and_test(args):
 
     to_range_0_1 = lambda x: (x + 1.) / 2.
 
-
+    print(vars(args))
     netG = NCSNpp(args).to(device)
-
-
+
     if args.epoch_id == -1:
         epochs = range(1000)
     else:
@@ -209,17 +384,27 @@ def sample_and_test(args):
     for epoch in epochs:
         args.epoch_id = epoch
         path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(args.dataset, args.exp, args.epoch_id)
+        next_path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(args.dataset, args.exp, args.epoch_id+1)
         if not os.path.exists(path):
            continue
+        print(path)
+
+        #if not os.path.exists(next_path):
+        #    print(f"STOP at {epoch}")
+        #    break
         ckpt = torch.load(path, map_location=device)
-
+        suffix = '_' + args.eval_name if args.eval_name else ""
+        dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(args.dataset, args.exp, args.epoch_id, suffix)
+        next_dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(args.dataset, args.exp, args.epoch_id+1, suffix)
 
-        if args.compute_fid and
+        if (args.compute_fid or args.compute_clip_score) and os.path.exists(dest):
            continue
        print("Eval Epoch", args.epoch_id)
        #loading weights from ddp in single gpu
+        #print(ckpt.keys())
        for key in list(ckpt.keys()):
-
+            if key.startswith("module"):
+                ckpt[key[7:]] = ckpt.pop(key)
        netG.load_state_dict(ckpt)
        netG.eval()
 
@@ -234,7 +419,7 @@ def sample_and_test(args):
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
 
-        if args.compute_fid:
+        if args.compute_fid or args.compute_clip_score:
            from torch.nn.functional import adaptive_avg_pool2d
            from pytorch_fid.fid_score import calculate_activation_statistics, calculate_fid_given_paths, ImagePathDataset, compute_statistics_of_path, calculate_frechet_distance
            from pytorch_fid.inception import InceptionV3
@@ -252,9 +437,11 @@ def sample_and_test(args):
            print("Text size:", len(texts))
            #print("Iters:", iters_needed)
            i = 0
-
-
-
+
+            if args.compute_fid:
+                dims = 2048
+                block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+                inceptionv3 = InceptionV3([block_idx]).to(device)
 
            if args.compute_clip_score:
                import clip
@@ -264,19 +451,20 @@ def sample_and_test(args):
                clip_mean = torch.Tensor(CLIP_MEAN).view(1,-1,1,1).to(device)
                clip_std = torch.Tensor(CLIP_STD).view(1,-1,1,1).to(device)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if args.compute_fid:
+                if not args.real_img_dir.endswith("npz"):
+                    real_mu, real_sigma = compute_statistics_of_path(
+                        args.real_img_dir, inceptionv3, args.batch_size, dims, device,
+                        resize=args.image_size,
+                    )
+                    np.savez("inception_statistics.npz", mu=real_mu, sigma=real_sigma)
+                else:
+                    stats = np.load(args.real_img_dir)
+                    real_mu = stats['mu']
+                    real_sigma = stats['sigma']
+
+                fake_features = []
+
            if args.compute_clip_score:
                clip_scores = []
 
@@ -287,7 +475,6 @@ def sample_and_test(args):
                bs = len(text)
                t0 = time.time()
                x_t_1 = torch.randn(bs, args.num_channels,args.image_size, args.image_size).to(device)
-                #print(x_t_1.shape)
                if args.guidance_scale:
                    fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
                else:
@@ -298,45 +485,39 @@ def sample_and_test(args):
                    index = i * args.batch_size + j
                    torchvision.utils.save_image(x, './generated_samples/{}/{}.jpg'.format(args.dataset, index))
                """
-
-
-
-
-
-
-
-
+
+                if args.compute_fid:
+                    with torch.no_grad():
+                        pred = inceptionv3(fake_sample)[0]
+                        # If model output is not scalar, apply global spatial average pooling.
+                        # This happens if you choose a dimensionality not equal 2048.
+                        if pred.size(2) != 1 or pred.size(3) != 1:
+                            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+                        pred = pred.squeeze(3).squeeze(2).cpu().numpy()
+                        fake_features.append(pred)
 
                if args.compute_clip_score:
                    with torch.no_grad():
                        clip_ims = torch.nn.functional.interpolate(fake_sample, (224, 224), mode="bicubic")
-
+                        clip_ims = (clip_ims - clip_mean) / clip_std
+                        clip_txt = clip.tokenize(text, truncate=True).to(device)
                        imf = clip_model.encode_image(clip_ims)
                        txtf = clip_model.encode_text(clip_txt)
                        imf = torch.nn.functional.normalize(imf, dim=1)
                        txtf = torch.nn.functional.normalize(txtf, dim=1)
                        clip_scores.append(((imf * txtf).sum(dim=1)).cpu())
-
+
                if i % 10 == 0:
-                    print('
+                    print('evaluating batch ', i, time.time() - t0)
-                """
-                if i % 10 == 0:
-                    ff = np.concatenate(fake_features)
-                    fake_mu = np.mean(ff, axis=0)
-                    fake_sigma = np.cov(ff, rowvar=False)
-                    fid = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)
-                    print("FID", fid)
-                """
                i += 1
 
-
-
-
-
-
-
-
-                }
+            results = {}
+            if args.compute_fid:
+                fake_features = np.concatenate(fake_features)
+                fake_mu = np.mean(fake_features, axis=0)
+                fake_sigma = np.cov(fake_features, rowvar=False)
+                fid = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)
+                results['fid'] = fid
            if args.compute_clip_score:
                clip_score = torch.cat(clip_scores).mean().item()
                results['clip_score'] = clip_score
@@ -344,22 +525,54 @@ def sample_and_test(args):
            with open(dest, "w") as fd:
                json.dump(results, fd)
            print(results)
-        else:
+        else:
            if args.cond_text.endswith(".txt"):
                texts = open(args.cond_text).readlines()
                texts = [t.strip() for t in texts]
            else:
                texts = [args.cond_text] * args.batch_size
-
-
-
-
-
-            else:
-
-
-
-
+            clip_guidance = False
+            if clip_guidance:
+                from clip_encoder import CLIPImageEncoder
+                cond = text_encoder(texts, return_only_pooled=False)
+                clip_image_model = CLIPImageEncoder().to(device)
+                x_t_1 = torch.randn(len(texts), args.num_channels,args.image_size*args.scale_factor_h, args.image_size*args.scale_factor_w).to(device)
+                fake_sample = sample_from_model_clip_guidance(pos_coeff, netG, clip_image_model, args.num_timesteps, x_t_1,T, args, texts, cond=cond, guidance_scale=args.guidance_scale)
+                fake_sample = to_range_0_1(fake_sample)
+                torchvision.utils.save_image(fake_sample, './samples_{}.jpg'.format(args.dataset))
+
+            else:
+                cond = text_encoder(texts, return_only_pooled=False)
+                x_t_1 = torch.randn(len(texts), args.num_channels,args.image_size*args.scale_factor_h, args.image_size*args.scale_factor_w).to(device)
+                t0 = time.time()
+                if args.guidance_scale:
+                    if args.scale_factor_h > 1 or args.scale_factor_w > 1:
+                        if args.scale_method == "convolutional":
+                            split_input_params = {
+                                "ks": (args.image_size, args.image_size),
+                                "stride": (150, 150),
+                                "clip_max_tie_weight": 0.5,
+                                "clip_min_tie_weight": 0.01,
+                                "clip_max_weight": 0.5,
+                                "clip_min_weight": 0.01,
+
+                                "tie_braker": True,
+                                'vqf': 1,
+                            }
+                            fake_sample = sample_from_model_classifier_free_guidance_convolutional(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale, split_input_params=split_input_params)
+                        elif args.scale_method == "larger_input":
+                            netG.attn_resolutions = [r * args.scale_factor_w for r in netG.attn_resolutions]
+                            fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
+                    else:
+                        fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
+                else:
+                    fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
+
+                print(time.time() - t0)
+                fake_sample = to_range_0_1(fake_sample)
+                torchvision.utils.save_image(fake_sample, './samples_{}.jpg'.format(args.dataset))
+
+
 
 
 
@@ -374,6 +587,7 @@ if __name__ == '__main__':
    parser.add_argument('--compute_clip_score', action='store_true', default=False,
                            help='whether or not compute CLIP score')
    parser.add_argument('--clip_model', type=str,default="ViT-L/14")
+    parser.add_argument('--eval_name', type=str,default="")
 
    parser.add_argument('--epoch_id', type=int,default=1000)
    parser.add_argument('--guidance_scale', type=float,default=0)
@@ -381,6 +595,8 @@ if __name__ == '__main__':
    parser.add_argument('--cond_text', type=str,default="0")
    parser.add_argument('--scale_factor_h', type=int,default=1)
    parser.add_argument('--scale_factor_w', type=int,default=1)
+    parser.add_argument('--scale_method', type=str,default="convolutional")
+
    parser.add_argument('--cross_attention', action='store_true',default=False)
 
 
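The dynamic-thresholding step used in both classifier-free-guidance samplers is compact enough to show on its own. A self-contained sketch of the same clamp-and-rescale logic (illustration; the quantile value here is arbitrary, the real code reads it from --dynamic_thresholding_quantile):

import torch

def dynamic_threshold(x_0, q=0.995):
    # Clamp each sample to its q-quantile of |x_0| (floored at 1) and rescale into [-1, 1].
    shape = x_0.shape
    x_0_v = x_0.view(shape[0], -1)
    d = torch.quantile(torch.abs(x_0_v), q, dim=1, keepdim=True)
    d.clamp_(min=1)
    x_0_v = x_0_v.clamp(-d, d) / d
    return x_0_v.view(shape)

x0 = 3 * torch.randn(4, 3, 64, 64)
print(dynamic_threshold(x0).abs().max())  # always <= 1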
train_ddgan.py
CHANGED
|
@@ -5,7 +5,7 @@
|
|
| 5 |
# for Denoising Diffusion GAN. To view a copy of this license, see the LICENSE file.
|
| 6 |
# ---------------------------------------------------------------
|
| 7 |
|
| 8 |
-
|
| 9 |
import argparse
|
| 10 |
import torch
|
| 11 |
import numpy as np
|
|
@@ -30,6 +30,7 @@ import shutil
|
|
| 30 |
import logging
|
| 31 |
from encoder import build_encoder
|
| 32 |
from utils import ResampledShards2
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
def log_and_continue(exn):
|
|
@@ -194,23 +195,29 @@ def sample_from_model(coefficients, generator, n_time, x_init, T, opt, cond=None
|
|
| 194 |
|
| 195 |
return x
|
| 196 |
|
| 197 |
-
|
| 198 |
|
| 199 |
def filter_no_caption(sample):
|
| 200 |
return 'txt' in sample
|
| 201 |
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
def train(rank, gpu, args):
|
| 205 |
from score_sde.models.discriminator import Discriminator_small, Discriminator_large, CondAttnDiscriminator, SmallCondAttnDiscriminator
|
| 206 |
from score_sde.models.ncsnpp_generator_adagn import NCSNpp
|
| 207 |
from EMA import EMA
|
| 208 |
|
| 209 |
-
torch.manual_seed(args.seed + rank)
|
| 210 |
-
torch.cuda.manual_seed(args.seed + rank)
|
| 211 |
-
torch.cuda.manual_seed_all(args.seed + rank)
|
| 212 |
device = "cuda"
|
| 213 |
-
|
| 214 |
batch_size = args.batch_size
|
| 215 |
|
| 216 |
nz = args.nz #latent dimension
|
|
@@ -270,11 +277,12 @@ def train(rank, gpu, args):
|
|
| 270 |
])
|
| 271 |
elif args.preprocessing == "random_resized_crop_v1":
|
| 272 |
train_transform = transforms.Compose([
|
| 273 |
-
transforms.RandomResizedCrop(
|
| 274 |
transforms.ToTensor(),
|
| 275 |
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
|
| 276 |
])
|
| 277 |
-
|
|
|
|
| 278 |
pipeline.extend([
|
| 279 |
wds.split_by_node,
|
| 280 |
wds.split_by_worker,
|
|
@@ -339,6 +347,13 @@ def train(rank, gpu, args):
|
|
| 339 |
t_emb_dim = args.t_emb_dim,
|
| 340 |
cond_size=text_encoder.output_size,
|
| 341 |
act=nn.LeakyReLU(0.2)).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
elif args.discr_type == "large_cond_attn":
|
| 343 |
netD = CondAttnDiscriminator(
|
| 344 |
nc = 2*args.num_channels,
|
|
@@ -350,6 +365,15 @@ def train(rank, gpu, args):
|
|
| 350 |
broadcast_params(netG.parameters())
|
| 351 |
broadcast_params(netD.parameters())
|
| 352 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
optimizerD = optim.Adam(netD.parameters(), lr=args.lr_d, betas = (args.beta1, args.beta2))
|
| 354 |
optimizerG = optim.Adam(netG.parameters(), lr=args.lr_g, betas = (args.beta1, args.beta2))
|
| 355 |
|
|
@@ -358,9 +382,16 @@ def train(rank, gpu, args):
|
|
| 358 |
|
| 359 |
schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, args.num_epoch, eta_min=1e-5)
|
| 360 |
schedulerD = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerD, args.num_epoch, eta_min=1e-5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
|
| 362 |
-
|
| 363 |
-
|
|
|
|
| 364 |
|
| 365 |
exp = args.exp
|
| 366 |
parent_dir = "./saved_info/dd_gan/{}".format(args.dataset)
|
|
@@ -377,6 +408,10 @@ def train(rank, gpu, args):
|
|
| 377 |
T = get_time_schedule(args, device)
|
| 378 |
|
| 379 |
checkpoint_file = os.path.join(exp_path, 'content.pth')
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
if args.resume and os.path.exists(checkpoint_file):
|
| 381 |
checkpoint = torch.load(checkpoint_file, map_location="cpu")
|
| 382 |
init_epoch = checkpoint['epoch']
|
|
@@ -395,7 +430,7 @@ def train(rank, gpu, args):
|
|
| 395 |
.format(checkpoint['epoch']))
|
| 396 |
else:
|
| 397 |
global_step, epoch, init_epoch = 0, 0, 0
|
| 398 |
-
use_cond_attn_discr = args.discr_type in ("large_cond_attn", "small_cond_attn")
|
| 399 |
for epoch in range(init_epoch, args.num_epoch+1):
|
| 400 |
if args.dataset == "wds":
|
| 401 |
os.environ["WDS_EPOCH"] = str(epoch)
|
|
@@ -403,6 +438,7 @@ def train(rank, gpu, args):
|
|
| 403 |
train_sampler.set_epoch(epoch)
|
| 404 |
|
| 405 |
for iteration, (x, y) in enumerate(data_loader):
|
|
|
|
| 406 |
if args.dataset != "wds":
|
| 407 |
y = [str(yi) for yi in y.tolist()]
|
| 408 |
|
|
@@ -437,15 +473,15 @@ def train(rank, gpu, args):
|
|
| 437 |
cond_for_discr.requires_grad = True
|
| 438 |
|
| 439 |
# train with real
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
|
| 445 |
|
| 446 |
errD_real.backward(retain_graph=True)
|
| 447 |
|
| 448 |
-
|
| 449 |
if args.lazy_reg is None:
|
| 450 |
if args.grad_penalty_cond:
|
| 451 |
inputs = (x_t,) + (cond,) if use_cond_attn_discr else (cond_for_discr,)
|
|
@@ -491,26 +527,36 @@ def train(rank, gpu, args):
|
|
| 491 |
|
| 492 |
# train with fake
|
| 493 |
latent_z = torch.randn(batch_size, nz, device=device)
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
|
|
|
|
|
|
| 503 |
|
| 504 |
if args.mismatch_loss:
|
| 505 |
# following https://github.com/tobran/DF-GAN/blob/bc38a4f795c294b09b4ef5579cd4ff78807e5b96/code/lib/modules.py,
|
| 506 |
# we add a discr loss for (real image, non matching text)
|
| 507 |
#inds = torch.flip(torch.arange(len(x_t)), dims=(0,))
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
|
|
|
| 514 |
|
| 515 |
errD_fake.backward()
|
| 516 |
|
|
@@ -534,58 +580,106 @@ def train(rank, gpu, args):
|
|
| 534 |
|
| 535 |
latent_z = torch.randn(batch_size, nz,device=device)
|
| 536 |
|
| 537 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 538 |
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
errG = F.softplus(-output)
|
| 546 |
-
errG = errG.mean()
|
| 547 |
|
| 548 |
errG.backward()
|
| 549 |
optimizerG.step()
|
| 550 |
|
| 551 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 552 |
|
| 553 |
global_step += 1
|
|
|
|
|
|
|
| 554 |
if iteration % 100 == 0:
|
| 555 |
if rank == 0:
|
| 556 |
print('epoch {} iteration{}, G Loss: {}, D Loss: {}'.format(epoch,iteration, errG.item(), errD.item()))
|
|
|
|
| 557 |
if iteration % 1000 == 0:
|
| 558 |
x_t_1 = torch.randn_like(real_data)
|
| 559 |
-
|
|
|
|
| 560 |
if rank == 0:
|
| 561 |
torchvision.utils.save_image(fake_sample, os.path.join(exp_path, 'sample_discrete_epoch_{}_iteration_{}.png'.format(epoch, iteration)), normalize=True)
|
| 562 |
-
|
| 563 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 564 |
content = {'epoch': epoch + 1, 'global_step': global_step, 'args': args,
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
|
| 578 |
if not args.no_lr_decay:
|
| 579 |
|
| 580 |
schedulerG.step()
|
| 581 |
schedulerD.step()
|
| 582 |
-
|
| 583 |
if rank == 0:
|
| 584 |
if epoch % 10 == 0:
|
| 585 |
torchvision.utils.save_image(x_pos_sample, os.path.join(exp_path, 'xpos_epoch_{}.png'.format(epoch)), normalize=True)
|
| 586 |
|
| 587 |
x_t_1 = torch.randn_like(real_data)
|
| 588 |
-
|
|
|
|
| 589 |
torchvision.utils.save_image(fake_sample, os.path.join(exp_path, 'sample_discrete_epoch_{}.png'.format(epoch)), normalize=True)
|
| 590 |
|
| 591 |
if args.save_content:
|
|
@@ -606,7 +700,8 @@ def train(rank, gpu, args):
|
|
| 606 |
torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
|
| 607 |
if args.use_ema:
|
| 608 |
optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
|
| 609 |
-
|
|
|
|
| 610 |
|
| 611 |
|
| 612 |
def init_processes(rank, size, fn, args):
|
|
@@ -641,6 +736,8 @@ if __name__ == '__main__':
|
|
| 641 |
parser.add_argument('--mismatch_loss', action='store_true',default=False)
|
| 642 |
parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
|
| 643 |
parser.add_argument('--cross_attention', action='store_true',default=False)
|
|
|
|
|
|
|
| 644 |
|
| 645 |
parser.add_argument('--image_size', type=int, default=32,
|
| 646 |
help='size of image')
|
|
@@ -728,6 +825,7 @@ if __name__ == '__main__':
|
|
| 728 |
parser.add_argument('--save_ckpt_every', type=int, default=25, help='save ckpt every x epochs')
|
| 729 |
parser.add_argument('--discr_type', type=str, default="large")
|
| 730 |
parser.add_argument('--preprocessing', type=str, default="resize")
|
|
|
|
| 731 |
|
| 732 |
###ddp
|
| 733 |
parser.add_argument('--num_proc_node', type=int, default=1,
|
|
@@ -746,4 +844,4 @@ if __name__ == '__main__':
|
|
| 746 |
args.world_size = int(os.getenv("SLURM_NTASKS"))
|
| 747 |
args.rank = int(os.environ['SLURM_PROCID'])
|
| 748 |
# size = args.num_process_per_node
|
| 749 |
-
init_processes(args.rank, args.world_size, train, args)
|
|
|
|
| 5 |
# for Denoising Diffusion GAN. To view a copy of this license, see the LICENSE file.
|
| 6 |
# ---------------------------------------------------------------
|
| 7 |
|
| 8 |
+
from glob import glob
|
| 9 |
import argparse
|
| 10 |
import torch
|
| 11 |
import numpy as np
|
|
|
|
| 30 |
import logging
|
| 31 |
from encoder import build_encoder
|
| 32 |
from utils import ResampledShards2
|
| 33 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 34 |
|
| 35 |
|
| 36 |
def log_and_continue(exn):
|
|
|
|
| 195 |
|
| 196 |
return x
|
| 197 |
|
| 198 |
+
from contextlib import suppress
|
| 199 |
|
| 200 |
def filter_no_caption(sample):
|
| 201 |
return 'txt' in sample
|
| 202 |
|
| 203 |
+
def get_autocast(precision):
|
| 204 |
+
if precision == 'amp':
|
| 205 |
+
return torch.cuda.amp.autocast
|
| 206 |
+
elif precision == 'amp_bfloat16':
|
| 207 |
+
return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
|
| 208 |
+
else:
|
| 209 |
+
return suppress
|
| 210 |
|
| 211 |
def train(rank, gpu, args):
|
| 212 |
from score_sde.models.discriminator import Discriminator_small, Discriminator_large, CondAttnDiscriminator, SmallCondAttnDiscriminator
|
| 213 |
from score_sde.models.ncsnpp_generator_adagn import NCSNpp
|
| 214 |
from EMA import EMA
|
| 215 |
|
| 216 |
+
#torch.manual_seed(args.seed + rank)
|
| 217 |
+
#torch.cuda.manual_seed(args.seed + rank)
|
| 218 |
+
#torch.cuda.manual_seed_all(args.seed + rank)
|
| 219 |
device = "cuda"
|
| 220 |
+
autocast = get_autocast(args.precision)
|
| 221 |
batch_size = args.batch_size
|
| 222 |
|
| 223 |
nz = args.nz #latent dimension
|
|
|
|
| 277 |
])
|
| 278 |
elif args.preprocessing == "random_resized_crop_v1":
|
| 279 |
train_transform = transforms.Compose([
|
| 280 |
+
transforms.RandomResizedCrop(args.image_size, scale=(0.95, 1.0), interpolation=3),
|
| 281 |
transforms.ToTensor(),
|
| 282 |
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
|
| 283 |
])
|
| 284 |
+
shards = glob(os.path.join(args.dataset_root, "*.tar")) if os.path.isdir(args.dataset_root) else args.dataset_root
|
| 285 |
+
pipeline = [ResampledShards2(shards)]
|
| 286 |
pipeline.extend([
|
| 287 |
wds.split_by_node,
|
| 288 |
wds.split_by_worker,
|
|
|
|
| 347 |
t_emb_dim = args.t_emb_dim,
|
| 348 |
cond_size=text_encoder.output_size,
|
| 349 |
act=nn.LeakyReLU(0.2)).to(device)
|
| 350 |
+
elif args.discr_type == "large_attn_pool":
|
| 351 |
+
netD = Discriminator_large(nc = 2*args.num_channels, ngf = args.ngf,
|
| 352 |
+
t_emb_dim = args.t_emb_dim,
|
| 353 |
+
cond_size=text_encoder.output_size,
|
| 354 |
+
attn_pool=True,
|
| 355 |
+
act=nn.LeakyReLU(0.2)).to(device)
|
| 356 |
+
|
| 357 |
elif args.discr_type == "large_cond_attn":
|
| 358 |
netD = CondAttnDiscriminator(
|
| 359 |
nc = 2*args.num_channels,
|
|
|
|
| 365 |
broadcast_params(netG.parameters())
|
| 366 |
broadcast_params(netD.parameters())
|
| 367 |
|
| 368 |
+
if args.fsdp:
|
| 369 |
+
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
| 370 |
+
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
|
| 371 |
+
netG = FSDP(
|
| 372 |
+
netG,
|
| 373 |
+
flatten_parameters=True,
|
| 374 |
+
verbose=True,
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
optimizerD = optim.Adam(netD.parameters(), lr=args.lr_d, betas = (args.beta1, args.beta2))
|
| 378 |
optimizerG = optim.Adam(netG.parameters(), lr=args.lr_g, betas = (args.beta1, args.beta2))
|
| 379 |
|
|
|
|
| 382 |
|
| 383 |
schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, args.num_epoch, eta_min=1e-5)
|
| 384 |
schedulerD = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerD, args.num_epoch, eta_min=1e-5)
|
| 385 |
+
|
| 386 |
+
if args.fsdp:
|
| 387 |
+
netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])
|
| 388 |
+
else:
|
| 389 |
+
netG = nn.parallel.DistributedDataParallel(netG, device_ids=[gpu])
|
| 390 |
+
netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])
|
| 391 |
|
| 392 |
+
if args.grad_checkpointing:
|
| 393 |
+
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
| 394 |
+
netG = checkpoint_wrapper(netG)
|
| 395 |
|
| 396 |
exp = args.exp
|
| 397 |
parent_dir = "./saved_info/dd_gan/{}".format(args.dataset)
|
|
|
|
| 408 |
T = get_time_schedule(args, device)
|
| 409 |
|
| 410 |
checkpoint_file = os.path.join(exp_path, 'content.pth')
|
| 411 |
+
|
| 412 |
+
if rank == 0:
|
| 413 |
+
log_writer = SummaryWriter(exp_path)
|
| 414 |
+
|
| 415 |
if args.resume and os.path.exists(checkpoint_file):
|
| 416 |
checkpoint = torch.load(checkpoint_file, map_location="cpu")
|
| 417 |
init_epoch = checkpoint['epoch']
|
|
|
|
| 430 |               .format(checkpoint['epoch']))
| 431 |     else:
| 432 |         global_step, epoch, init_epoch = 0, 0, 0
| 433 | +   use_cond_attn_discr = args.discr_type in ("large_cond_attn", "small_cond_attn", "large_attn_pool")
| 434 |     for epoch in range(init_epoch, args.num_epoch+1):
| 435 |         if args.dataset == "wds":
| 436 |             os.environ["WDS_EPOCH"] = str(epoch)

| 438 |             train_sampler.set_epoch(epoch)
| 439 |
| 440 |         for iteration, (x, y) in enumerate(data_loader):
| 441 | +           #print(x.shape)
| 442 |             if args.dataset != "wds":
| 443 |                 y = [str(yi) for yi in y.tolist()]
| 444 |
| 473 |             cond_for_discr.requires_grad = True
| 474 |
| 475 |             # train with real
| 476 | +           with autocast():
| 477 | +               D_real = netD(x_t, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
| 478 | +               errD_real = F.softplus(-D_real)
| 479 | +               errD_real = errD_real.mean()
| 480 |
| 481 |
| 482 |             errD_real.backward(retain_graph=True)
| 483 |
| 484 | +           grad_penalty = None
| 485 |             if args.lazy_reg is None:
| 486 |                 if args.grad_penalty_cond:
| 487 |                     inputs = (x_t,) + (cond,) if use_cond_attn_discr else (cond_for_discr,)
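Note: the real-sample loss above is the softplus (non-saturating) form, and the penalty computed from `inputs` is elided from this hunk. DDGAN-style training normally applies an R1 gradient penalty on the real batch, every step when --lazy_reg is unset and only every lazy_reg steps otherwise; the sketch below is written under that assumption, and the function name and r1_gamma value are hypothetical, not the repo's code.

# Sketch only: an R1-style penalty of the kind usually applied here (exact code elided).
import torch

def r1_penalty(D_real, inputs, r1_gamma=0.05):
    # gradients of the real logits w.r.t. the requires_grad inputs (image and/or condition)
    grads = torch.autograd.grad(outputs=D_real.sum(), inputs=inputs, create_graph=True)
    penalty = sum(g.pow(2).reshape(g.size(0), -1).sum(1).mean() for g in grads)
    return (r1_gamma / 2) * penalty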
| 527 |
| 528 |             # train with fake
| 529 |             latent_z = torch.randn(batch_size, nz, device=device)
| 530 | +           with autocast():
| 531 | +               if args.grad_checkpointing:
| 532 | +                   ginp = x_tp1.detach()
| 533 | +                   ginp.requires_grad = True
| 534 | +                   latent_z.requires_grad = True
| 535 | +                   cond_pooled.requires_grad = True
| 536 | +                   cond.requires_grad = True
| 537 | +                   #cond_mask.requires_grad = True
| 538 | +                   x_0_predict = netG(ginp, t, latent_z, cond=(cond_pooled, cond, cond_mask))
| 539 | +               else:
| 540 | +                   x_0_predict = netG(x_tp1.detach(), t, latent_z, cond=(cond_pooled, cond, cond_mask))
| 541 | +               x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
| 542 |
| 543 | +               output = netD(x_pos_sample, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
| 544 | +
| 545 | +
| 546 | +               errD_fake = F.softplus(output)
| 547 | +               errD_fake = errD_fake.mean()
| 548 |
| 549 |             if args.mismatch_loss:
| 550 |                 # following https://github.com/tobran/DF-GAN/blob/bc38a4f795c294b09b4ef5579cd4ff78807e5b96/code/lib/modules.py,
| 551 |                 # we add a discr loss for (real image, non matching text)
| 552 |                 #inds = torch.flip(torch.arange(len(x_t)), dims=(0,))
| 553 | +               with autocast():
| 554 | +                   inds = torch.cat([torch.arange(1,len(x_t)),torch.arange(1)])
| 555 | +                   cond_for_discr_mis = (cond_pooled[inds], cond[inds], cond_mask[inds]) if use_cond_attn_discr else cond_pooled[inds]
| 556 | +                   D_real_mis = netD(x_t, t, x_tp1.detach(), cond=cond_for_discr_mis).view(-1)
| 557 | +                   errD_real_mis = F.softplus(D_real_mis)
| 558 | +                   errD_real_mis = errD_real_mis.mean()
| 559 | +                   errD_fake = errD_fake * 0.5 + errD_real_mis * 0.5
| 560 |
| 561 |             errD_fake.backward()
| 562 |
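Note: the `inds` tensor rolls the batch by one position, so every real image is scored against the caption of the next sample; that gives the (real image, non-matching text) negatives borrowed from DF-GAN, averaged 50/50 with the fake-sample term. A worked example of the index shuffle:

# Sketch only: what the roll-by-one index produces for a batch of 4.
import torch

n = 4
inds = torch.cat([torch.arange(1, n), torch.arange(1)])
print(inds)  # tensor([1, 2, 3, 0]): image i is paired with caption (i + 1) % n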
| 580 |
| 581 |             latent_z = torch.randn(batch_size, nz,device=device)
| 582 |
| 583 | +           with autocast():
| 584 | +               if args.grad_checkpointing:
| 585 | +                   ginp = x_tp1.detach()
| 586 | +                   ginp.requires_grad = True
| 587 | +                   latent_z.requires_grad = True
| 588 | +                   cond_pooled.requires_grad = True
| 589 | +                   cond.requires_grad = True
| 590 | +                   #cond_mask.requires_grad = True
| 591 | +                   x_0_predict = netG(ginp, t, latent_z, cond=(cond_pooled, cond, cond_mask))
| 592 | +               else:
| 593 | +                   x_0_predict = netG(x_tp1.detach(), t, latent_z, cond=(cond_pooled, cond, cond_mask))
| 594 | +               x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
| 595 |
| 596 | +               output = netD(x_pos_sample, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
| 597 | +
| 598 | +
| 599 | +               errG = F.softplus(-output)
| 600 | +               errG = errG.mean()
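Note: under --grad_checkpointing the generator inputs (x_tp1, latent_z and the conditioning tensors) are explicitly given requires_grad=True before the forward call. Activation checkpointing recomputes the forward pass during backward and needs at least one grad-requiring input for gradients to flow; the sketch below shows the same effect with PyTorch's stock torch.utils.checkpoint (the commit uses fairscale's checkpoint_wrapper, but the requirement is the same).

# Sketch only: why checkpointed modules want inputs with requires_grad=True.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

net = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1))
x = torch.randn(4, 16)
x.requires_grad = True                   # without this, checkpointing warns and grads can be lost
out = checkpoint(net, x)                 # activations are recomputed during backward
out.sum().backward()
print(net[0].weight.grad is not None)    # True: parameter gradients still arrive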
| 601 |
| 602 |             errG.backward()
| 603 |             optimizerG.step()
| 604 |
| 605 | +           if (iteration % 10 == 0) and (rank == 0):
| 606 | +               log_writer.add_scalar('g_loss', errG.item(), global_step)
| 607 | +               log_writer.add_scalar('d_loss', errD.item(), global_step)
| 608 | +               if grad_penalty is not None:
| 609 | +                   log_writer.add_scalar('grad_penalty', grad_penalty.item(), global_step)
| 610 |
| 611 |             global_step += 1
| 612 | +
| 613 | +
| 614 |             if iteration % 100 == 0:
| 615 |                 if rank == 0:
| 616 |                     print('epoch {} iteration{}, G Loss: {}, D Loss: {}'.format(epoch,iteration, errG.item(), errD.item()))
| 617 | +                   print('Global step:', global_step)
| 618 |             if iteration % 1000 == 0:
| 619 |                 x_t_1 = torch.randn_like(real_data)
| 620 | +               with autocast():
| 621 | +                   fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1, T, args, cond=(cond_pooled, cond, cond_mask))
| 622 |                 if rank == 0:
| 623 |                     torchvision.utils.save_image(fake_sample, os.path.join(exp_path, 'sample_discrete_epoch_{}_iteration_{}.png'.format(epoch, iteration)), normalize=True)
| 624 | +
| 625 | +               if args.save_content:
| 626 | +                   dist.barrier()
| 627 | +                   print('Saving content.')
| 628 | +                   def to_cpu(d):
| 629 | +                       for k, v in d.items():
| 630 | +                           d[k] = v.cpu()
| 631 | +                       return d
| 632 | +
| 633 | +                   if args.fsdp:
| 634 | +                       netG_state_dict = to_cpu(netG.state_dict())
| 635 | +                       netD_state_dict = to_cpu(netD.state_dict())
| 636 | +                       #netG_optim_state_dict = (netG.gather_full_optim_state_dict(optimizerG))
| 637 | +                       netG_optim_state_dict = optimizerG.state_dict()
| 638 | +                       #print(netG_optim_state_dict)
| 639 | +                       netD_optim_state_dict = (optimizerD.state_dict())
| 640 |                         content = {'epoch': epoch + 1, 'global_step': global_step, 'args': args,
| 641 | +                                   'netG_dict': netG_state_dict, 'optimizerG': netG_optim_state_dict,
| 642 | +                                   'schedulerG': schedulerG.state_dict(), 'netD_dict': netD_state_dict,
| 643 | +                                   'optimizerD': netD_optim_state_dict, 'schedulerD': schedulerD.state_dict()}
| 644 | +                       if rank == 0:
| 645 | +                           torch.save(content, os.path.join(exp_path, 'content.pth'))
| 646 | +                           torch.save(content, os.path.join(exp_path, 'content_backup.pth'))
| 647 | +                       if args.use_ema:
| 648 | +                           optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
| 649 | +                       if args.use_ema and rank == 0:
| 650 | +                           torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
| 651 | +                       if args.use_ema:
| 652 | +                           optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
| 653 | +                       #if args.use_ema:
| 654 | +                       #    dist.barrier()
| 655 | +                       print("Saved content")
| 656 | +                   else:
| 657 | +                       if rank == 0:
| 658 | +                           content = {'epoch': epoch + 1, 'global_step': global_step, 'args': args,
| 659 | +                                      'netG_dict': netG.state_dict(), 'optimizerG': optimizerG.state_dict(),
| 660 | +                                      'schedulerG': schedulerG.state_dict(), 'netD_dict': netD.state_dict(),
| 661 | +                                      'optimizerD': optimizerD.state_dict(), 'schedulerD': schedulerD.state_dict()}
| 662 | +                           torch.save(content, os.path.join(exp_path, 'content.pth'))
| 663 | +                           torch.save(content, os.path.join(exp_path, 'content_backup.pth'))
| 664 | +                           if args.use_ema:
| 665 | +                               optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
| 666 | +                               torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
| 667 | +                           if args.use_ema:
| 668 | +                               optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
| 669 | +
| 670 |
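Note: the save path brackets torch.save with two swap_parameters_with_ema calls from the EMA optimizer wrapper (EMA.py): swap the live generator weights for their exponential moving average, write netG_{epoch}.pth, then swap back so training continues on the raw weights. The same pattern in isolation (the helper name is hypothetical):

# Sketch only: the swap -> save -> swap-back pattern used above.
import torch

def save_ema_generator(netG, optimizerG, path, use_ema=True):
    if use_ema:
        optimizerG.swap_parameters_with_ema(store_params_in_ema=True)  # live -> EMA weights
    torch.save(netG.state_dict(), path)                                # checkpoint the EMA weights
    if use_ema:
        optimizerG.swap_parameters_with_ema(store_params_in_ema=True)  # back to live weights
    return path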
| 671 |         if not args.no_lr_decay:
| 672 |
| 673 |             schedulerG.step()
| 674 |             schedulerD.step()
| 675 | +       """
| 676 |         if rank == 0:
| 677 |             if epoch % 10 == 0:
| 678 |                 torchvision.utils.save_image(x_pos_sample, os.path.join(exp_path, 'xpos_epoch_{}.png'.format(epoch)), normalize=True)
| 679 |
| 680 |             x_t_1 = torch.randn_like(real_data)
| 681 | +           with autocast():
| 682 | +               fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1, T, args, cond=(cond_pooled, cond, cond_mask))
| 683 |             torchvision.utils.save_image(fake_sample, os.path.join(exp_path, 'sample_discrete_epoch_{}.png'.format(epoch)), normalize=True)
| 684 |
| 685 |         if args.save_content:

| 700 |             torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
| 701 |             if args.use_ema:
| 702 |                 optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
| 703 | +       dist.barrier()
| 704 | +       """
| 705 |
| 706 |
| 707 | def init_processes(rank, size, fn, args):
| 736 |     parser.add_argument('--mismatch_loss', action='store_true',default=False)
| 737 |     parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
| 738 |     parser.add_argument('--cross_attention', action='store_true',default=False)
| 739 | +   parser.add_argument('--fsdp', action='store_true',default=False)
| 740 | +   parser.add_argument('--grad_checkpointing', action='store_true',default=False)
| 741 |
| 742 |     parser.add_argument('--image_size', type=int, default=32,
| 743 |                             help='size of image')

| 825 |     parser.add_argument('--save_ckpt_every', type=int, default=25, help='save ckpt every x epochs')
| 826 |     parser.add_argument('--discr_type', type=str, default="large")
| 827 |     parser.add_argument('--preprocessing', type=str, default="resize")
| 828 | +   parser.add_argument('--precision', type=str, default="fp32")
| 829 |
| 830 |     ###ddp
| 831 |     parser.add_argument('--num_proc_node', type=int, default=1,

| 844 |         args.world_size = int(os.getenv("SLURM_NTASKS"))
| 845 |         args.rank = int(os.environ['SLURM_PROCID'])
| 846 |         # size = args.num_process_per_node
| 847 | +       init_processes(args.rank, args.world_size, train, args)
utils.py
CHANGED
@@ -41,7 +41,8 @@ class ResampledShards2(IterableDataset):
| 41 |         """
| 42 |         super().__init__()
| 43 |         #urls = wds.shardlists.expand_urls(urls)
| 44 | -       urls
| 45 |         self.urls = urls
| 46 |         assert isinstance(self.urls[0], str)
| 47 |         self.nshards = nshards

| 41 |         """
| 42 |         super().__init__()
| 43 |         #urls = wds.shardlists.expand_urls(urls)
| 44 | +       if type(urls) != list:
| 45 | +           urls = list(braceexpand.braceexpand(urls))
| 46 |         self.urls = urls
| 47 |         assert isinstance(self.urls[0], str)
| 48 |         self.nshards = nshards
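Note: the added fallback lets ResampledShards2 accept a single brace-pattern string instead of a pre-expanded list, expanding it with the braceexpand package (the same package webdataset relies on for expand_urls). Example of what it produces; the shard names here are illustrative:

# Sketch only: expanding a shard pattern string into a list of shard names.
import braceexpand

urls = "dataset-{0000..0002}.tar"
if type(urls) != list:
    urls = list(braceexpand.braceexpand(urls))
print(urls)  # ['dataset-0000.tar', 'dataset-0001.tar', 'dataset-0002.tar']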