Text-to-3D
image-to-3d
Chao Xu committed
Commit 0e93edd · 1 Parent(s): c534da1

test rm taming

Files changed (44)
  1. taming-transformers/.gitignore +0 -2
  2. taming-transformers/License.txt +0 -19
  3. taming-transformers/README.md +0 -410
  4. taming-transformers/configs/coco_cond_stage.yaml +0 -49
  5. taming-transformers/configs/coco_scene_images_transformer.yaml +0 -80
  6. taming-transformers/configs/custom_vqgan.yaml +0 -43
  7. taming-transformers/configs/drin_transformer.yaml +0 -77
  8. taming-transformers/configs/faceshq_transformer.yaml +0 -61
  9. taming-transformers/configs/faceshq_vqgan.yaml +0 -42
  10. taming-transformers/configs/imagenet_vqgan.yaml +0 -42
  11. taming-transformers/configs/imagenetdepth_vqgan.yaml +0 -41
  12. taming-transformers/configs/open_images_scene_images_transformer.yaml +0 -86
  13. taming-transformers/configs/sflckr_cond_stage.yaml +0 -43
  14. taming-transformers/environment.yaml +0 -25
  15. taming-transformers/main.py +0 -585
  16. taming-transformers/scripts/extract_depth.py +0 -112
  17. taming-transformers/scripts/extract_segmentation.py +0 -130
  18. taming-transformers/scripts/extract_submodel.py +0 -17
  19. taming-transformers/scripts/make_samples.py +0 -292
  20. taming-transformers/scripts/make_scene_samples.py +0 -198
  21. taming-transformers/scripts/sample_conditional.py +0 -355
  22. taming-transformers/scripts/sample_fast.py +0 -260
  23. taming-transformers/setup.py +0 -13
  24. taming-transformers/taming/lr_scheduler.py +0 -34
  25. taming-transformers/taming/models/cond_transformer.py +0 -352
  26. taming-transformers/taming/models/dummy_cond_stage.py +0 -22
  27. taming-transformers/taming/models/vqgan.py +0 -404
  28. taming-transformers/taming/modules/diffusionmodules/model.py +0 -776
  29. taming-transformers/taming/modules/discriminator/model.py +0 -67
  30. taming-transformers/taming/modules/losses/__init__.py +0 -2
  31. taming-transformers/taming/modules/losses/lpips.py +0 -123
  32. taming-transformers/taming/modules/losses/segmentation.py +0 -22
  33. taming-transformers/taming/modules/losses/vqperceptual.py +0 -136
  34. taming-transformers/taming/modules/misc/coord.py +0 -31
  35. taming-transformers/taming/modules/transformer/mingpt.py +0 -415
  36. taming-transformers/taming/modules/transformer/permuter.py +0 -248
  37. taming-transformers/taming/modules/util.py +0 -130
  38. taming-transformers/taming/modules/vqvae/quantize.py +0 -445
  39. taming-transformers/taming/util.py +0 -157
  40. taming-transformers/taming_transformers.egg-info/PKG-INFO +0 -10
  41. taming-transformers/taming_transformers.egg-info/SOURCES.txt +0 -7
  42. taming-transformers/taming_transformers.egg-info/dependency_links.txt +0 -1
  43. taming-transformers/taming_transformers.egg-info/requires.txt +0 -3
  44. taming-transformers/taming_transformers.egg-info/top_level.txt +0 -1
taming-transformers/.gitignore DELETED
@@ -1,2 +0,0 @@
1
- assets/
2
- data/
 
taming-transformers/License.txt DELETED
@@ -1,19 +0,0 @@
1
- Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
2
-
3
- Permission is hereby granted, free of charge, to any person obtaining a copy
4
- of this software and associated documentation files (the "Software"), to deal
5
- in the Software without restriction, including without limitation the rights
6
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
- copies of the Software, and to permit persons to whom the Software is
8
- furnished to do so, subject to the following conditions:
9
-
10
- The above copyright notice and this permission notice shall be included in all
11
- copies or substantial portions of the Software.
12
-
13
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
14
- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
15
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16
- IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
17
- DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
18
- OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
19
- OR OTHER DEALINGS IN THE SOFTWARE./
 
taming-transformers/README.md DELETED
@@ -1,410 +0,0 @@
1
- # Taming Transformers for High-Resolution Image Synthesis
2
- ##### CVPR 2021 (Oral)
3
- ![teaser](assets/mountain.jpeg)
4
-
5
- [**Taming Transformers for High-Resolution Image Synthesis**](https://compvis.github.io/taming-transformers/)<br/>
6
- [Patrick Esser](https://github.com/pesser)\*,
7
- [Robin Rombach](https://github.com/rromb)\*,
8
- [Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)<br/>
9
- \* equal contribution
10
-
11
- **tl;dr** We combine the efficiency of convolutional approaches with the expressivity of transformers by introducing a convolutional VQGAN, which learns a codebook of context-rich visual parts, whose composition is modeled with an autoregressive transformer.
12
-
13
- ![teaser](assets/teaser.png)
14
- [arXiv](https://arxiv.org/abs/2012.09841) | [BibTeX](#bibtex) | [Project Page](https://compvis.github.io/taming-transformers/)
15
-
16
-
17
- ### News
18
- #### 2022
19
- - More pretrained VQGANs (e.g. a f8-model with only 256 codebook entries) are available in our new work on [Latent Diffusion Models](https://github.com/CompVis/latent-diffusion).
20
- - Added scene synthesis models as proposed in the paper [High-Resolution Complex Scene Synthesis with Transformers](https://arxiv.org/abs/2105.06458), see [this section](#scene-image-synthesis).
21
- #### 2021
22
- - Thanks to [rom1504](https://github.com/rom1504) it is now easy to [train a VQGAN on your own datasets](#training-on-custom-data).
23
- - Included a bugfix for the quantizer. For backward compatibility it is
24
- disabled by default (which corresponds to always training with `beta=1.0`).
25
- Use `legacy=False` in the quantizer config to enable it.
26
- Thanks [richcmwang](https://github.com/richcmwang) and [wcshin-git](https://github.com/wcshin-git)!
27
- - Our paper received an update: See https://arxiv.org/abs/2012.09841v3 and the corresponding changelog.
28
- - Added a pretrained, [1.4B transformer model](https://k00.fr/s511rwcv) trained for class-conditional ImageNet synthesis, which obtains state-of-the-art FID scores among autoregressive approaches and outperforms BigGAN.
29
- - Added pretrained, unconditional models on [FFHQ](https://k00.fr/yndvfu95) and [CelebA-HQ](https://k00.fr/2xkmielf).
30
- - Added accelerated sampling via caching of keys/values in the self-attention operation, used in `scripts/sample_fast.py`.
31
- - Added a checkpoint of a [VQGAN](https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/) trained with f8 compression and Gumbel-Quantization.
32
- See also our updated [reconstruction notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb).
33
- - We added a [colab notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb) which compares two VQGANs and OpenAI's [DALL-E](https://github.com/openai/DALL-E). See also [this section](#more-resources).
34
- - We now include an overview of pretrained models in [Tab.1](#overview-of-pretrained-models). We added models for [COCO](#coco) and [ADE20k](#ade20k).
35
- - The streamlit demo now supports image completions.
36
- - We now include a couple of examples from the D-RIN dataset so you can run the
37
- [D-RIN demo](#d-rin) without preparing the dataset first.
38
- - You can now jump right into sampling with our [Colab quickstart notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/taming-transformers.ipynb).
39
-
40
- ## Requirements
41
- A suitable [conda](https://conda.io/) environment named `taming` can be created
42
- and activated with:
43
-
44
- ```
45
- conda env create -f environment.yaml
46
- conda activate taming
47
- ```
48
- ## Overview of pretrained models
49
- The following table provides an overview of all models that are currently available.
50
- FID scores were evaluated using [torch-fidelity](https://github.com/toshas/torch-fidelity).
51
- For reference, we also include a link to the recently released autoencoder of the [DALL-E](https://github.com/openai/DALL-E) model.
52
- See the corresponding [colab
53
- notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb)
54
- for a comparison and discussion of reconstruction capabilities.
55
-
56
- | Dataset | FID vs train | FID vs val | Link | Samples (256x256) | Comments
57
- | ------------- | ------------- | ------------- |------------- | ------------- |------------- |
58
- | FFHQ (f=16) | 9.6 | -- | [ffhq_transformer](https://k00.fr/yndvfu95) | [ffhq_samples](https://k00.fr/j626x093) |
59
- | CelebA-HQ (f=16) | 10.2 | -- | [celebahq_transformer](https://k00.fr/2xkmielf) | [celebahq_samples](https://k00.fr/j626x093) |
60
- | ADE20K (f=16) | -- | 35.5 | [ade20k_transformer](https://k00.fr/ot46cksa) | [ade20k_samples.zip](https://heibox.uni-heidelberg.de/f/70bb78cbaf844501b8fb/) [2k] | evaluated on val split (2k images)
61
- | COCO-Stuff (f=16) | -- | 20.4 | [coco_transformer](https://k00.fr/2zz6i2ce) | [coco_samples.zip](https://heibox.uni-heidelberg.de/f/a395a9be612f4a7a8054/) [5k] | evaluated on val split (5k images)
62
- | ImageNet (cIN) (f=16) | 15.98/15.78/6.59/5.88/5.20 | -- | [cin_transformer](https://k00.fr/s511rwcv) | [cin_samples](https://k00.fr/j626x093) | different decoding hyperparameters |
63
- | | | | || |
64
- | FacesHQ (f=16) | -- | -- | [faceshq_transformer](https://k00.fr/qqfl2do8)
65
- | S-FLCKR (f=16) | -- | -- | [sflckr](https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/)
66
- | D-RIN (f=16) | -- | -- | [drin_transformer](https://k00.fr/39jcugc5)
67
- | | | | | || |
68
- | VQGAN ImageNet (f=16), 1024 | 10.54 | 7.94 | [vqgan_imagenet_f16_1024](https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/) | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
69
- | VQGAN ImageNet (f=16), 16384 | 7.41 | 4.98 |[vqgan_imagenet_f16_16384](https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/) | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
70
- | VQGAN OpenImages (f=8), 256 | -- | 1.49 |https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip | --- | Reconstruction-FIDs. Available via [latent diffusion](https://github.com/CompVis/latent-diffusion).
71
- | VQGAN OpenImages (f=8), 16384 | -- | 1.14 |https://ommer-lab.com/files/latent-diffusion/vq-f8.zip | --- | Reconstruction-FIDs. Available via [latent diffusion](https://github.com/CompVis/latent-diffusion)
72
- | VQGAN OpenImages (f=8), 8192, GumbelQuantization | 3.24 | 1.49 |[vqgan_gumbel_f8](https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/) | --- | Reconstruction-FIDs.
73
- | | | | | || |
74
- | DALL-E dVAE (f=8), 8192, GumbelQuantization | 33.88 | 32.01 | https://github.com/openai/DALL-E | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
75
-
76
-
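For orientation, a minimal sketch of loading one of the VQGAN first-stage checkpoints from the table in Python. The paths are placeholders for a downloaded config/checkpoint pair, and the `state_dict` key assumes a PyTorch Lightning checkpoint as used throughout this repository; see `taming/models/vqgan.py` for the exact signatures.

```
# Hedged sketch: load a pretrained first-stage VQGAN from a config/checkpoint pair.
import torch
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel

config = OmegaConf.load("logs/vqgan_imagenet_f16_16384/configs/model.yaml")  # placeholder path
model = VQModel(**config.model.params)
ckpt = torch.load("logs/vqgan_imagenet_f16_16384/checkpoints/last.ckpt", map_location="cpu")
model.load_state_dict(ckpt["state_dict"], strict=False)  # assumes a Lightning-style checkpoint
model.eval()

# Round-trip an image tensor scaled to [-1, 1].
with torch.no_grad():
    x = torch.zeros(1, 3, 256, 256)
    quant, _, _ = model.encode(x)
    x_rec = model.decode(quant)
```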
77
- ## Running pretrained models
78
-
79
- The commands below will start a streamlit demo which supports sampling at
80
- different resolutions and image completions. To run a non-interactive version
81
- of the sampling process, replace `streamlit run scripts/sample_conditional.py --`
82
- with `python scripts/make_samples.py --outdir <path_to_write_samples_to>` and
83
- keep the remaining command line arguments.
84
-
85
- To sample from unconditional or class-conditional models,
86
- run `python scripts/sample_fast.py -r <path/to/config_and_checkpoint>`.
87
- We describe below how to use this script to sample from the ImageNet, FFHQ, and CelebA-HQ models,
88
- respectively.
89
-
90
- ### S-FLCKR
91
- ![teaser](assets/sunset_and_ocean.jpg)
92
-
93
- You can also [run this model in a Colab
94
- notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/taming-transformers.ipynb),
95
- which includes all necessary steps to start sampling.
96
-
97
- Download the
98
- [2020-11-09T13-31-51_sflckr](https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/)
99
- folder and place it into `logs`. Then, run
100
- ```
101
- streamlit run scripts/sample_conditional.py -- -r logs/2020-11-09T13-31-51_sflckr/
102
- ```
103
-
104
- ### ImageNet
105
- ![teaser](assets/imagenet.png)
106
-
107
- Download the [2021-04-03T19-39-50_cin_transformer](https://k00.fr/s511rwcv)
108
- folder and place it into logs. Sampling from the class-conditional ImageNet
109
- model does not require any data preparation. To produce 50 samples for each of
110
- the 1000 classes of ImageNet, with k=600 for top-k sampling, p=0.92 for nucleus
111
- sampling and temperature t=1.0, run
112
-
113
- ```
114
- python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25
115
- ```
116
-
117
- To restrict the model to certain classes, provide them via the `--classes` argument, separated by
118
- commas. For example, to sample 50 *ostriches*, *border collies* and *whiskey jugs*, run
119
-
120
- ```
121
- python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25 --classes 9,232,901
122
- ```
123
- We recommend experimenting with the autoregressive decoding parameters (top-k, top-p and temperature) for best results.
124
-
125
- ### FFHQ/CelebA-HQ
126
-
127
- Download the [2021-04-23T18-19-01_ffhq_transformer](https://k00.fr/yndvfu95) and
128
- [2021-04-23T18-11-19_celebahq_transformer](https://k00.fr/2xkmielf)
129
- folders and place them into logs.
130
- Again, sampling from these unconditional models does not require any data preparation.
131
- To produce 50000 samples, with k=250 for top-k sampling,
132
- p=1.0 for nucleus sampling and temperature t=1.0, run
133
-
134
- ```
135
- python scripts/sample_fast.py -r logs/2021-04-23T18-19-01_ffhq_transformer/
136
- ```
137
- for FFHQ and
138
-
139
- ```
140
- python scripts/sample_fast.py -r logs/2021-04-23T18-11-19_celebahq_transformer/
141
- ```
142
- to sample from the CelebA-HQ model.
143
- For both models it can be advantageous to vary the top-k/top-p parameters for sampling.
144
-
145
- ### FacesHQ
146
- ![teaser](assets/faceshq.jpg)
147
-
148
- Download [2020-11-13T21-41-45_faceshq_transformer](https://k00.fr/qqfl2do8) and
149
- place it into `logs`. Follow the data preparation steps for
150
- [CelebA-HQ](#celeba-hq) and [FFHQ](#ffhq). Run
151
- ```
152
- streamlit run scripts/sample_conditional.py -- -r logs/2020-11-13T21-41-45_faceshq_transformer/
153
- ```
154
-
155
- ### D-RIN
156
- ![teaser](assets/drin.jpg)
157
-
158
- Download [2020-11-20T12-54-32_drin_transformer](https://k00.fr/39jcugc5) and
159
- place it into `logs`. To run the demo on a couple of example depth maps
160
- included in the repository, run
161
-
162
- ```
163
- streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.imagenet.DRINExamples}}}"
164
- ```
165
-
166
- To run the demo on the complete validation set, first follow the data preparation steps for
167
- [ImageNet](#imagenet) and then run
168
- ```
169
- streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/
170
- ```
171
-
172
- ### COCO
173
- Download [2021-01-20T16-04-20_coco_transformer](https://k00.fr/2zz6i2ce) and
174
- place it into `logs`. To run the demo on a couple of example segmentation maps
175
- included in the repository, run
176
-
177
- ```
178
- streamlit run scripts/sample_conditional.py -- -r logs/2021-01-20T16-04-20_coco_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.coco.Examples}}}"
179
- ```
180
-
181
- ### ADE20k
182
- Download [2020-11-20T21-45-44_ade20k_transformer](https://k00.fr/ot46cksa) and
183
- place it into `logs`. To run the demo on a couple of example segmentation maps
184
- included in the repository, run
185
-
186
- ```
187
- streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T21-45-44_ade20k_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.ade20k.Examples}}}"
188
- ```
189
-
190
- ## Scene Image Synthesis
191
- ![teaser](assets/scene_images_samples.svg)
192
- Scene image generation based on bounding-box conditionals, as done in our CVPR 2021 AI4CC workshop paper [High-Resolution Complex Scene Synthesis with Transformers](https://arxiv.org/abs/2105.06458) (see the talk on the [workshop page](https://visual.cs.brown.edu/workshops/aicc2021/#awards)). The COCO and Open Images datasets are supported.
193
-
194
- ### Training
195
- Download first-stage models [COCO-8k-VQGAN](https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/) for COCO or [COCO/Open-Images-8k-VQGAN](https://heibox.uni-heidelberg.de/f/461d9a9f4fcf48ab84f4/) for Open Images.
196
- Change `ckpt_path` in `data/coco_scene_images_transformer.yaml` and `data/open_images_scene_images_transformer.yaml` to point to the downloaded first-stage models.
197
- Download the full COCO/OI datasets and adapt `data_path` in the same files, unless the 100 files provided for training and validation already suit your needs.
198
-
199
- Code can be run with
200
- `python main.py --base configs/coco_scene_images_transformer.yaml -t True --gpus 0,`
201
- or
202
- `python main.py --base configs/open_images_scene_images_transformer.yaml -t True --gpus 0,`
203
-
204
- ### Sampling
205
- Train a model as described above or download a pre-trained model:
206
- [Open Images 1 billion parameter model](https://drive.google.com/file/d/1FEK-Z7hyWJBvFWQF50pzSK9y1W_CJEig/view?usp=sharing), trained for 100 epochs. On 256x256 pixels: FID 41.48±0.21, SceneFID 14.60±0.15, Inception Score 18.47±0.27. The model was trained with 2d crops of images and is thus well-prepared for generating high-resolution images, e.g. 512x512.
207
- - [Open Images distilled version of the above model with 125 million parameters](https://drive.google.com/file/d/1xf89g0mc78J3d8Bx5YhbK4tNRNlOoYaO) allows for sampling on smaller GPUs (4 GB is enough for sampling 256x256 px images). Model was trained for 60 epochs with 10% soft loss, 90% hard loss. On 256x256 pixels, FID 43.07±0.40, SceneFID 15.93±0.19, Inception Score 17.23±0.11.
208
- - [COCO 30 epochs](https://heibox.uni-heidelberg.de/f/0d0b2594e9074c7e9a33/)
209
- - [COCO 60 epochs](https://drive.google.com/file/d/1bInd49g2YulTJBjU32Awyt5qnzxxG5U9/) (find model statistics for both COCO versions in `assets/coco_scene_images_training.svg`)
210
-
211
- When downloading a pre-trained model, remember to change `ckpt_path` in `configs/*project.yaml` to point to your downloaded first-stage model (see the Training section above).
212
-
213
- Scene image generation can be run with
214
- `python scripts/make_scene_samples.py --outdir=/some/outdir -r /path/to/pretrained/model --resolution=512,512`
215
-
216
-
217
- ## Training on custom data
218
-
219
- Training on your own dataset can be beneficial to get better tokens and hence better images for your domain.
220
- These are the steps to follow:
221
- 1. install the repo with `conda env create -f environment.yaml`, `conda activate taming` and `pip install -e .`
222
- 2. put your .jpg files in a folder `your_folder`
223
- 3. create two text files, `xx_train.txt` and `xx_test.txt`, that point to the files in your training and test set respectively (for example `find $(pwd)/your_folder -name "*.jpg" > train.txt`; see the Python sketch after this list for an alternative)
224
- 4. adapt `configs/custom_vqgan.yaml` to point to these two files
225
- 5. run `python main.py --base configs/custom_vqgan.yaml -t True --gpus 0,1` to
226
- train on two GPUs. Use `--gpus 0,` (with a trailing comma) to train on a single GPU.
227
-
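A small Python alternative to the `find` command in step 3, sketched under the assumption that `your_folder` holds the `.jpg` files and that a simple 90/10 train/test split is acceptable (the file names and split ratio are placeholders):

```
# Hedged sketch: write absolute .jpg paths into the two list files referenced by
# configs/custom_vqgan.yaml (training_images_list_file / test_images_list_file).
import random
from pathlib import Path

paths = sorted(str(p.resolve()) for p in Path("your_folder").glob("**/*.jpg"))
random.seed(0)
random.shuffle(paths)
split = int(0.9 * len(paths))  # assumed 90/10 split

Path("xx_train.txt").write_text("\n".join(paths[:split]) + "\n")
Path("xx_test.txt").write_text("\n".join(paths[split:]) + "\n")
```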
228
- ## Data Preparation
229
-
230
- ### ImageNet
231
- The code will try to download (through [Academic
232
- Torrents](http://academictorrents.com/)) and prepare ImageNet the first time it
233
- is used. However, since ImageNet is quite large, this requires a lot of disk
234
- space and time. If you already have ImageNet on your disk, you can speed things
235
- up by putting the data into
236
- `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` (which defaults to
237
- `~/.cache/autoencoders/data/ILSVRC2012_{split}/data/`), where `{split}` is one
238
- of `train`/`validation`. It should have the following structure:
239
-
240
- ```
241
- ${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
242
- ├── n01440764
243
- │ ├── n01440764_10026.JPEG
244
- │ ├── n01440764_10027.JPEG
245
- │ ├── ...
246
- ├── n01443537
247
- │ ├── n01443537_10007.JPEG
248
- │ ├── n01443537_10014.JPEG
249
- │ ├── ...
250
- ├── ...
251
- ```
252
-
253
- If you haven't extracted the data, you can also place
254
- `ILSVRC2012_img_train.tar`/`ILSVRC2012_img_val.tar` (or symlinks to them) into
255
- `${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` /
256
- `${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, which will then be
257
- extracted into above structure without downloading it again. Note that this
258
- will only happen if neither a folder
259
- `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor a file
260
- `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exist. Remove them
261
- if you want to force running the dataset preparation again.
262
-
263
- You will then need to prepare the depth data using
264
- [MiDaS](https://github.com/intel-isl/MiDaS). Create a symlink
265
- `data/imagenet_depth` pointing to a folder with two subfolders `train` and
266
- `val`, each mirroring the structure of the corresponding ImageNet folder
267
- described above and containing a `png` file for each of ImageNet's `JPEG`
268
- files. The `png` encodes `float32` depth values obtained from MiDaS as RGBA
269
- images. We provide the script `scripts/extract_depth.py` to generate this data.
270
- **Please note** that this script uses [MiDaS via PyTorch
271
- Hub](https://pytorch.org/hub/intelisl_midas_v2/). When we prepared the data,
272
- the hub provided the [MiDaS
273
- v2.0](https://github.com/intel-isl/MiDaS/releases/tag/v2) version, but now it
274
- provides a v2.1 version. We haven't tested our models with depth maps obtained
275
- via v2.1 and if you want to make sure that things work as expected, you must
276
- adjust the script to make sure it explicitly uses
277
- [v2.0](https://github.com/intel-isl/MiDaS/releases/tag/v2)!
278
-
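To make the `float32`-as-RGBA encoding above concrete, here is a hedged sketch that packs a depth map into a lossless RGBA PNG and reads it back; the exact byte layout used by `scripts/extract_depth.py` may differ, so treat this as an illustration rather than a description of that script:

```
# Hedged sketch: reinterpret the 4 bytes of each float32 depth value as RGBA
# channels so the map survives a round trip through a lossless PNG.
import numpy as np
from PIL import Image

def save_depth_rgba(depth, path):
    d = np.ascontiguousarray(depth.astype(np.float32))
    rgba = d.view(np.uint8).reshape(d.shape[0], d.shape[1], 4)
    Image.fromarray(rgba, mode="RGBA").save(path)

def load_depth_rgba(path):
    rgba = np.array(Image.open(path))  # (H, W, 4) uint8
    return rgba.view(np.float32).reshape(rgba.shape[0], rgba.shape[1])

depth = np.random.rand(256, 256).astype(np.float32)
save_depth_rgba(depth, "example_depth.png")
assert np.array_equal(load_depth_rgba("example_depth.png"), depth)
```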
279
- ### CelebA-HQ
280
- Create a symlink `data/celebahq` pointing to a folder containing the `.npy`
281
- files of CelebA-HQ (instructions to obtain them can be found in the [PGGAN
282
- repository](https://github.com/tkarras/progressive_growing_of_gans)).
283
-
284
- ### FFHQ
285
- Create a symlink `data/ffhq` pointing to the `images1024x1024` folder obtained
286
- from the [FFHQ repository](https://github.com/NVlabs/ffhq-dataset).
287
-
288
- ### S-FLCKR
289
- Unfortunately, we are not allowed to distribute the images we collected for the
290
- S-FLCKR dataset and can therefore only give a description of how it was produced.
291
- There are many resources on [collecting images from the
292
- web](https://github.com/adrianmrit/flickrdatasets) to get started.
293
- We collected sufficiently large images from [flickr](https://www.flickr.com)
294
- (see `data/flickr_tags.txt` for a full list of tags used to find images)
295
- and various [subreddits](https://www.reddit.com/r/sfwpornnetwork/wiki/network)
296
- (see `data/subreddits.txt` for all subreddits that were used).
297
- Overall, we collected 107625 images, and split them randomly into 96861
298
- training images and 10764 validation images. We then obtained segmentation
299
- masks for each image using [DeepLab v2](https://arxiv.org/abs/1606.00915)
300
- trained on [COCO-Stuff](https://arxiv.org/abs/1612.03716). We used a [PyTorch
301
- reimplementation](https://github.com/kazuto1011/deeplab-pytorch) and include an
302
- example script for this process in `scripts/extract_segmentation.py`.
303
-
304
- ### COCO
305
- Create a symlink `data/coco` containing the images from the 2017 split in
306
- `train2017` and `val2017`, and their annotations in `annotations`. Files can be
307
- obtained from the [COCO webpage](https://cocodataset.org/). In addition, we use
308
- the [Stuff+thing PNG-style annotations on COCO 2017
309
- trainval](http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip)
310
- annotations from [COCO-Stuff](https://github.com/nightrome/cocostuff), which
311
- should be placed under `data/cocostuffthings`.
312
-
313
- ### ADE20k
314
- Create a symlink `data/ade20k_root` containing the contents of
315
- [ADEChallengeData2016.zip](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip)
316
- from the [MIT Scene Parsing Benchmark](http://sceneparsing.csail.mit.edu/).
317
-
318
- ## Training models
319
-
320
- ### FacesHQ
321
-
322
- Train a VQGAN with
323
- ```
324
- python main.py --base configs/faceshq_vqgan.yaml -t True --gpus 0,
325
- ```
326
-
327
- Then, adjust the checkpoint path of the config key
328
- `model.params.first_stage_config.params.ckpt_path` in
329
- `configs/faceshq_transformer.yaml` (or download
330
- [2020-11-09T13-33-36_faceshq_vqgan](https://k00.fr/uxy5usa9) and place into `logs`, which
331
- corresponds to the preconfigured checkpoint path), then run
332
- ```
333
- python main.py --base configs/faceshq_transformer.yaml -t True --gpus 0,
334
- ```
335
-
336
- ### D-RIN
337
-
338
- Train a VQGAN on ImageNet with
339
- ```
340
- python main.py --base configs/imagenet_vqgan.yaml -t True --gpus 0,
341
- ```
342
-
343
- or download a pretrained one from [2020-09-23T17-56-33_imagenet_vqgan](https://k00.fr/u0j2dtac)
344
- and place under `logs`. If you trained your own, adjust the path in the config
345
- key `model.params.first_stage_config.params.ckpt_path` of
346
- `configs/drin_transformer.yaml`.
347
-
348
- Train a VQGAN on Depth Maps of ImageNet with
349
- ```
350
- python main.py --base configs/imagenetdepth_vqgan.yaml -t True --gpus 0,
351
- ```
352
-
353
- or download a pretrained one from [2020-11-03T15-34-24_imagenetdepth_vqgan](https://k00.fr/55rlxs6i)
354
- and place under `logs`. If you trained your own, adjust the path in the config
355
- key `model.params.cond_stage_config.params.ckpt_path` of
356
- `configs/drin_transformer.yaml`.
357
-
358
- To train the transformer, run
359
- ```
360
- python main.py --base configs/drin_transformer.yaml -t True --gpus 0,
361
- ```
362
-
363
- ## More Resources
364
- ### Comparing Different First Stage Models
365
- The reconstruction and compression capabilities of different first stage models can be analyzed in this [colab notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb).
366
- In particular, the notebook compares two VQGANs, each with a downsampling factor of f=16 but with codebook sizes of 1024 and 16384 respectively,
367
- a VQGAN with f=8 and 8192 codebook entries and the discrete autoencoder of OpenAI's [DALL-E](https://github.com/openai/DALL-E) (which has f=8 and 8192
368
- codebook entries).
369
- ![firststages1](assets/first_stage_squirrels.png)
370
- ![firststages2](assets/first_stage_mushrooms.png)
371
-
372
- ### Other
373
- - A [video summary](https://www.youtube.com/watch?v=o7dqGcLDf0A&feature=emb_imp_woyt) by [Two Minute Papers](https://www.youtube.com/channel/UCbfYPyITQ-7l4upoX8nvctg).
374
- - A [video summary](https://www.youtube.com/watch?v=-wDSDtIAyWQ) by [Gradient Dude](https://www.youtube.com/c/GradientDude/about).
375
- - A [weights and biases report summarizing the paper](https://wandb.ai/ayush-thakur/taming-transformer/reports/-Overview-Taming-Transformers-for-High-Resolution-Image-Synthesis---Vmlldzo0NjEyMTY)
376
- by [ayulockin](https://github.com/ayulockin).
377
- - A [video summary](https://www.youtube.com/watch?v=JfUTd8fjtX8&feature=emb_imp_woyt) by [What's AI](https://www.youtube.com/channel/UCUzGQrN-lyyc0BWTYoJM_Sg).
378
- - Take a look at [ak9250's notebook](https://github.com/ak9250/taming-transformers/blob/master/tamingtransformerscolab.ipynb) if you want to run the streamlit demos on Colab.
379
-
380
- ### Text-to-Image Optimization via CLIP
381
- VQGAN has been successfully used as an image generator guided by the [CLIP](https://github.com/openai/CLIP) model, both for pure image generation
382
- from scratch and image-to-image translation. We recommend the following notebooks/videos/resources:
383
-
384
- - [Advadnouns](https://twitter.com/advadnoun/status/1389316507134357506) Patreon and corresponding LatentVision notebooks: https://www.patreon.com/patronizeme
385
- - The [notebook]( https://colab.research.google.com/drive/1L8oL-vLJXVcRzCFbPwOoMkPKJ8-aYdPN) of [Rivers Have Wings](https://twitter.com/RiversHaveWings).
386
- - A [video](https://www.youtube.com/watch?v=90QDe6DQXF4&t=12s) explanation by [Dot CSV](https://www.youtube.com/channel/UCy5znSnfMsDwaLlROnZ7Qbg) (in Spanish, but English subtitles are available)
387
-
388
- ![txt2img](assets/birddrawnbyachild.png)
389
-
390
- Text prompt: *'A bird drawn by a child'*
391
-
392
- ## Shout-outs
393
- Thanks to everyone who makes their code and models available. In particular,
394
-
395
- - The architecture of our VQGAN is inspired by [Denoising Diffusion Probabilistic Models](https://github.com/hojonathanho/diffusion)
396
- - The very hackable transformer implementation [minGPT](https://github.com/karpathy/minGPT)
397
- - The good ol' [PatchGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) and [Learned Perceptual Similarity (LPIPS)](https://github.com/richzhang/PerceptualSimilarity)
398
-
399
- ## BibTeX
400
-
401
- ```
402
- @misc{esser2020taming,
403
- title={Taming Transformers for High-Resolution Image Synthesis},
404
- author={Patrick Esser and Robin Rombach and Björn Ommer},
405
- year={2020},
406
- eprint={2012.09841},
407
- archivePrefix={arXiv},
408
- primaryClass={cs.CV}
409
- }
410
- ```
 
taming-transformers/configs/coco_cond_stage.yaml DELETED
@@ -1,49 +0,0 @@
1
- model:
2
- base_learning_rate: 4.5e-06
3
- target: taming.models.vqgan.VQSegmentationModel
4
- params:
5
- embed_dim: 256
6
- n_embed: 1024
7
- image_key: "segmentation"
8
- n_labels: 183
9
- ddconfig:
10
- double_z: false
11
- z_channels: 256
12
- resolution: 256
13
- in_channels: 183
14
- out_ch: 183
15
- ch: 128
16
- ch_mult:
17
- - 1
18
- - 1
19
- - 2
20
- - 2
21
- - 4
22
- num_res_blocks: 2
23
- attn_resolutions:
24
- - 16
25
- dropout: 0.0
26
-
27
- lossconfig:
28
- target: taming.modules.losses.segmentation.BCELossWithQuant
29
- params:
30
- codebook_weight: 1.0
31
-
32
- data:
33
- target: main.DataModuleFromConfig
34
- params:
35
- batch_size: 12
36
- train:
37
- target: taming.data.coco.CocoImagesAndCaptionsTrain
38
- params:
39
- size: 296
40
- crop_size: 256
41
- onehot_segmentation: true
42
- use_stuffthing: true
43
- validation:
44
- target: taming.data.coco.CocoImagesAndCaptionsValidation
45
- params:
46
- size: 256
47
- crop_size: 256
48
- onehot_segmentation: true
49
- use_stuffthing: true
 
taming-transformers/configs/coco_scene_images_transformer.yaml DELETED
@@ -1,80 +0,0 @@
1
- model:
2
- base_learning_rate: 4.5e-06
3
- target: taming.models.cond_transformer.Net2NetTransformer
4
- params:
5
- cond_stage_key: objects_bbox
6
- transformer_config:
7
- target: taming.modules.transformer.mingpt.GPT
8
- params:
9
- vocab_size: 8192
10
- block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim)
11
- n_layer: 40
12
- n_head: 16
13
- n_embd: 1408
14
- embd_pdrop: 0.1
15
- resid_pdrop: 0.1
16
- attn_pdrop: 0.1
17
- first_stage_config:
18
- target: taming.models.vqgan.VQModel
19
- params:
20
- ckpt_path: /path/to/coco_epoch117.ckpt # https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/
21
- embed_dim: 256
22
- n_embed: 8192
23
- ddconfig:
24
- double_z: false
25
- z_channels: 256
26
- resolution: 256
27
- in_channels: 3
28
- out_ch: 3
29
- ch: 128
30
- ch_mult:
31
- - 1
32
- - 1
33
- - 2
34
- - 2
35
- - 4
36
- num_res_blocks: 2
37
- attn_resolutions:
38
- - 16
39
- dropout: 0.0
40
- lossconfig:
41
- target: taming.modules.losses.DummyLoss
42
- cond_stage_config:
43
- target: taming.models.dummy_cond_stage.DummyCondStage
44
- params:
45
- conditional_key: objects_bbox
46
-
47
- data:
48
- target: main.DataModuleFromConfig
49
- params:
50
- batch_size: 6
51
- train:
52
- target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
53
- params:
54
- data_path: data/coco_annotations_100 # substitute with path to full dataset
55
- split: train
56
- keys: [image, objects_bbox, file_name, annotations]
57
- no_tokens: 8192
58
- target_image_size: 256
59
- min_object_area: 0.00001
60
- min_objects_per_image: 2
61
- max_objects_per_image: 30
62
- crop_method: random-1d
63
- random_flip: true
64
- use_group_parameter: true
65
- encode_crop: true
66
- validation:
67
- target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
68
- params:
69
- data_path: data/coco_annotations_100 # substitute with path to full dataset
70
- split: validation
71
- keys: [image, objects_bbox, file_name, annotations]
72
- no_tokens: 8192
73
- target_image_size: 256
74
- min_object_area: 0.00001
75
- min_objects_per_image: 2
76
- max_objects_per_image: 30
77
- crop_method: center
78
- random_flip: false
79
- use_group_parameter: true
80
- encode_crop: true
 
taming-transformers/configs/custom_vqgan.yaml DELETED
@@ -1,43 +0,0 @@
1
- model:
2
- base_learning_rate: 4.5e-6
3
- target: taming.models.vqgan.VQModel
4
- params:
5
- embed_dim: 256
6
- n_embed: 1024
7
- ddconfig:
8
- double_z: False
9
- z_channels: 256
10
- resolution: 256
11
- in_channels: 3
12
- out_ch: 3
13
- ch: 128
14
- ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
15
- num_res_blocks: 2
16
- attn_resolutions: [16]
17
- dropout: 0.0
18
-
19
- lossconfig:
20
- target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
21
- params:
22
- disc_conditional: False
23
- disc_in_channels: 3
24
- disc_start: 10000
25
- disc_weight: 0.8
26
- codebook_weight: 1.0
27
-
28
- data:
29
- target: main.DataModuleFromConfig
30
- params:
31
- batch_size: 5
32
- num_workers: 8
33
- train:
34
- target: taming.data.custom.CustomTrain
35
- params:
36
- training_images_list_file: some/training.txt
37
- size: 256
38
- validation:
39
- target: taming.data.custom.CustomTest
40
- params:
41
- test_images_list_file: some/test.txt
42
- size: 256
43
-
 
taming-transformers/configs/drin_transformer.yaml DELETED
@@ -1,77 +0,0 @@
1
- model:
2
- base_learning_rate: 4.5e-06
3
- target: taming.models.cond_transformer.Net2NetTransformer
4
- params:
5
- cond_stage_key: depth
6
- transformer_config:
7
- target: taming.modules.transformer.mingpt.GPT
8
- params:
9
- vocab_size: 1024
10
- block_size: 512
11
- n_layer: 24
12
- n_head: 16
13
- n_embd: 1024
14
- first_stage_config:
15
- target: taming.models.vqgan.VQModel
16
- params:
17
- ckpt_path: logs/2020-09-23T17-56-33_imagenet_vqgan/checkpoints/last.ckpt
18
- embed_dim: 256
19
- n_embed: 1024
20
- ddconfig:
21
- double_z: false
22
- z_channels: 256
23
- resolution: 256
24
- in_channels: 3
25
- out_ch: 3
26
- ch: 128
27
- ch_mult:
28
- - 1
29
- - 1
30
- - 2
31
- - 2
32
- - 4
33
- num_res_blocks: 2
34
- attn_resolutions:
35
- - 16
36
- dropout: 0.0
37
- lossconfig:
38
- target: taming.modules.losses.DummyLoss
39
- cond_stage_config:
40
- target: taming.models.vqgan.VQModel
41
- params:
42
- ckpt_path: logs/2020-11-03T15-34-24_imagenetdepth_vqgan/checkpoints/last.ckpt
43
- embed_dim: 256
44
- n_embed: 1024
45
- ddconfig:
46
- double_z: false
47
- z_channels: 256
48
- resolution: 256
49
- in_channels: 1
50
- out_ch: 1
51
- ch: 128
52
- ch_mult:
53
- - 1
54
- - 1
55
- - 2
56
- - 2
57
- - 4
58
- num_res_blocks: 2
59
- attn_resolutions:
60
- - 16
61
- dropout: 0.0
62
- lossconfig:
63
- target: taming.modules.losses.DummyLoss
64
-
65
- data:
66
- target: main.DataModuleFromConfig
67
- params:
68
- batch_size: 2
69
- num_workers: 8
70
- train:
71
- target: taming.data.imagenet.RINTrainWithDepth
72
- params:
73
- size: 256
74
- validation:
75
- target: taming.data.imagenet.RINValidationWithDepth
76
- params:
77
- size: 256
 
taming-transformers/configs/faceshq_transformer.yaml DELETED
@@ -1,61 +0,0 @@
1
- model:
2
- base_learning_rate: 4.5e-06
3
- target: taming.models.cond_transformer.Net2NetTransformer
4
- params:
5
- cond_stage_key: coord
6
- transformer_config:
7
- target: taming.modules.transformer.mingpt.GPT
8
- params:
9
- vocab_size: 1024
10
- block_size: 512
11
- n_layer: 24
12
- n_head: 16
13
- n_embd: 1024
14
- first_stage_config:
15
- target: taming.models.vqgan.VQModel
16
- params:
17
- ckpt_path: logs/2020-11-09T13-33-36_faceshq_vqgan/checkpoints/last.ckpt
18
- embed_dim: 256
19
- n_embed: 1024
20
- ddconfig:
21
- double_z: false
22
- z_channels: 256
23
- resolution: 256
24
- in_channels: 3
25
- out_ch: 3
26
- ch: 128
27
- ch_mult:
28
- - 1
29
- - 1
30
- - 2
31
- - 2
32
- - 4
33
- num_res_blocks: 2
34
- attn_resolutions:
35
- - 16
36
- dropout: 0.0
37
- lossconfig:
38
- target: taming.modules.losses.DummyLoss
39
- cond_stage_config:
40
- target: taming.modules.misc.coord.CoordStage
41
- params:
42
- n_embed: 1024
43
- down_factor: 16
44
-
45
- data:
46
- target: main.DataModuleFromConfig
47
- params:
48
- batch_size: 2
49
- num_workers: 8
50
- train:
51
- target: taming.data.faceshq.FacesHQTrain
52
- params:
53
- size: 256
54
- crop_size: 256
55
- coord: True
56
- validation:
57
- target: taming.data.faceshq.FacesHQValidation
58
- params:
59
- size: 256
60
- crop_size: 256
61
- coord: True
 
taming-transformers/configs/faceshq_vqgan.yaml DELETED
@@ -1,42 +0,0 @@
1
- model:
2
- base_learning_rate: 4.5e-6
3
- target: taming.models.vqgan.VQModel
4
- params:
5
- embed_dim: 256
6
- n_embed: 1024
7
- ddconfig:
8
- double_z: False
9
- z_channels: 256
10
- resolution: 256
11
- in_channels: 3
12
- out_ch: 3
13
- ch: 128
14
- ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
15
- num_res_blocks: 2
16
- attn_resolutions: [16]
17
- dropout: 0.0
18
-
19
- lossconfig:
20
- target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
21
- params:
22
- disc_conditional: False
23
- disc_in_channels: 3
24
- disc_start: 30001
25
- disc_weight: 0.8
26
- codebook_weight: 1.0
27
-
28
- data:
29
- target: main.DataModuleFromConfig
30
- params:
31
- batch_size: 3
32
- num_workers: 8
33
- train:
34
- target: taming.data.faceshq.FacesHQTrain
35
- params:
36
- size: 256
37
- crop_size: 256
38
- validation:
39
- target: taming.data.faceshq.FacesHQValidation
40
- params:
41
- size: 256
42
- crop_size: 256
 
taming-transformers/configs/imagenet_vqgan.yaml DELETED
@@ -1,42 +0,0 @@
1
- model:
2
- base_learning_rate: 4.5e-6
3
- target: taming.models.vqgan.VQModel
4
- params:
5
- embed_dim: 256
6
- n_embed: 1024
7
- ddconfig:
8
- double_z: False
9
- z_channels: 256
10
- resolution: 256
11
- in_channels: 3
12
- out_ch: 3
13
- ch: 128
14
- ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
15
- num_res_blocks: 2
16
- attn_resolutions: [16]
17
- dropout: 0.0
18
-
19
- lossconfig:
20
- target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
21
- params:
22
- disc_conditional: False
23
- disc_in_channels: 3
24
- disc_start: 250001
25
- disc_weight: 0.8
26
- codebook_weight: 1.0
27
-
28
- data:
29
- target: main.DataModuleFromConfig
30
- params:
31
- batch_size: 12
32
- num_workers: 24
33
- train:
34
- target: taming.data.imagenet.ImageNetTrain
35
- params:
36
- config:
37
- size: 256
38
- validation:
39
- target: taming.data.imagenet.ImageNetValidation
40
- params:
41
- config:
42
- size: 256
 
taming-transformers/configs/imagenetdepth_vqgan.yaml DELETED
@@ -1,41 +0,0 @@
1
- model:
2
- base_learning_rate: 4.5e-6
3
- target: taming.models.vqgan.VQModel
4
- params:
5
- embed_dim: 256
6
- n_embed: 1024
7
- image_key: depth
8
- ddconfig:
9
- double_z: False
10
- z_channels: 256
11
- resolution: 256
12
- in_channels: 1
13
- out_ch: 1
14
- ch: 128
15
- ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
16
- num_res_blocks: 2
17
- attn_resolutions: [16]
18
- dropout: 0.0
19
-
20
- lossconfig:
21
- target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
22
- params:
23
- disc_conditional: False
24
- disc_in_channels: 1
25
- disc_start: 50001
26
- disc_weight: 0.75
27
- codebook_weight: 1.0
28
-
29
- data:
30
- target: main.DataModuleFromConfig
31
- params:
32
- batch_size: 3
33
- num_workers: 8
34
- train:
35
- target: taming.data.imagenet.ImageNetTrainWithDepth
36
- params:
37
- size: 256
38
- validation:
39
- target: taming.data.imagenet.ImageNetValidationWithDepth
40
- params:
41
- size: 256
 
taming-transformers/configs/open_images_scene_images_transformer.yaml DELETED
@@ -1,86 +0,0 @@
1
- model:
2
- base_learning_rate: 4.5e-06
3
- target: taming.models.cond_transformer.Net2NetTransformer
4
- params:
5
- cond_stage_key: objects_bbox
6
- transformer_config:
7
- target: taming.modules.transformer.mingpt.GPT
8
- params:
9
- vocab_size: 8192
10
- block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim)
11
- n_layer: 36
12
- n_head: 16
13
- n_embd: 1536
14
- embd_pdrop: 0.1
15
- resid_pdrop: 0.1
16
- attn_pdrop: 0.1
17
- first_stage_config:
18
- target: taming.models.vqgan.VQModel
19
- params:
20
- ckpt_path: /path/to/coco_oi_epoch12.ckpt # https://heibox.uni-heidelberg.de/f/461d9a9f4fcf48ab84f4/
21
- embed_dim: 256
22
- n_embed: 8192
23
- ddconfig:
24
- double_z: false
25
- z_channels: 256
26
- resolution: 256
27
- in_channels: 3
28
- out_ch: 3
29
- ch: 128
30
- ch_mult:
31
- - 1
32
- - 1
33
- - 2
34
- - 2
35
- - 4
36
- num_res_blocks: 2
37
- attn_resolutions:
38
- - 16
39
- dropout: 0.0
40
- lossconfig:
41
- target: taming.modules.losses.DummyLoss
42
- cond_stage_config:
43
- target: taming.models.dummy_cond_stage.DummyCondStage
44
- params:
45
- conditional_key: objects_bbox
46
-
47
- data:
48
- target: main.DataModuleFromConfig
49
- params:
50
- batch_size: 6
51
- train:
52
- target: taming.data.annotated_objects_open_images.AnnotatedObjectsOpenImages
53
- params:
54
- data_path: data/open_images_annotations_100 # substitute with path to full dataset
55
- split: train
56
- keys: [image, objects_bbox, file_name, annotations]
57
- no_tokens: 8192
58
- target_image_size: 256
59
- category_allow_list_target: taming.data.open_images_helper.top_300_classes_plus_coco_compatibility
60
- category_mapping_target: taming.data.open_images_helper.open_images_unify_categories_for_coco
61
- min_object_area: 0.0001
62
- min_objects_per_image: 2
63
- max_objects_per_image: 30
64
- crop_method: random-2d
65
- random_flip: true
66
- use_group_parameter: true
67
- use_additional_parameters: true
68
- encode_crop: true
69
- validation:
70
- target: taming.data.annotated_objects_open_images.AnnotatedObjectsOpenImages
71
- params:
72
- data_path: data/open_images_annotations_100 # substitute with path to full dataset
73
- split: validation
74
- keys: [image, objects_bbox, file_name, annotations]
75
- no_tokens: 8192
76
- target_image_size: 256
77
- category_allow_list_target: taming.data.open_images_helper.top_300_classes_plus_coco_compatibility
78
- category_mapping_target: taming.data.open_images_helper.open_images_unify_categories_for_coco
79
- min_object_area: 0.0001
80
- min_objects_per_image: 2
81
- max_objects_per_image: 30
82
- crop_method: center
83
- random_flip: false
84
- use_group_parameter: true
85
- use_additional_parameters: true
86
- encode_crop: true
 
taming-transformers/configs/sflckr_cond_stage.yaml DELETED
@@ -1,43 +0,0 @@
1
- model:
2
- base_learning_rate: 4.5e-06
3
- target: taming.models.vqgan.VQSegmentationModel
4
- params:
5
- embed_dim: 256
6
- n_embed: 1024
7
- image_key: "segmentation"
8
- n_labels: 182
9
- ddconfig:
10
- double_z: false
11
- z_channels: 256
12
- resolution: 256
13
- in_channels: 182
14
- out_ch: 182
15
- ch: 128
16
- ch_mult:
17
- - 1
18
- - 1
19
- - 2
20
- - 2
21
- - 4
22
- num_res_blocks: 2
23
- attn_resolutions:
24
- - 16
25
- dropout: 0.0
26
-
27
- lossconfig:
28
- target: taming.modules.losses.segmentation.BCELossWithQuant
29
- params:
30
- codebook_weight: 1.0
31
-
32
- data:
33
- target: cutlit.DataModuleFromConfig
34
- params:
35
- batch_size: 12
36
- train:
37
- target: taming.data.sflckr.Examples # adjust
38
- params:
39
- size: 256
40
- validation:
41
- target: taming.data.sflckr.Examples # adjust
42
- params:
43
- size: 256
 
taming-transformers/environment.yaml DELETED
@@ -1,25 +0,0 @@
1
- name: taming
2
- channels:
3
- - pytorch
4
- - defaults
5
- dependencies:
6
- - python=3.8.5
7
- - pip=20.3
8
- - cudatoolkit=10.2
9
- - pytorch=1.7.0
10
- - torchvision=0.8.1
11
- - numpy=1.19.2
12
- - pip:
13
- - albumentations==0.4.3
14
- - opencv-python==4.1.2.30
15
- - pudb==2019.2
16
- - imageio==2.9.0
17
- - imageio-ffmpeg==0.4.2
18
- - pytorch-lightning==1.0.8
19
- - omegaconf==2.0.0
20
- - test-tube>=0.7.5
21
- - streamlit>=0.73.1
22
- - einops==0.3.0
23
- - more-itertools>=8.0.0
24
- - transformers==4.3.1
25
- - -e .
 
taming-transformers/main.py DELETED
@@ -1,585 +0,0 @@
1
- import argparse, os, sys, datetime, glob, importlib
2
- from omegaconf import OmegaConf
3
- import numpy as np
4
- from PIL import Image
5
- import torch
6
- import torchvision
7
- from torch.utils.data import random_split, DataLoader, Dataset
8
- import pytorch_lightning as pl
9
- from pytorch_lightning import seed_everything
10
- from pytorch_lightning.trainer import Trainer
11
- from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
12
- from pytorch_lightning.utilities import rank_zero_only
13
-
14
- from taming.data.utils import custom_collate
15
-
16
-
17
- def get_obj_from_str(string, reload=False):
18
- module, cls = string.rsplit(".", 1)
19
- if reload:
20
- module_imp = importlib.import_module(module)
21
- importlib.reload(module_imp)
22
- return getattr(importlib.import_module(module, package=None), cls)
23
-
24
-
25
- def get_parser(**parser_kwargs):
26
- def str2bool(v):
27
- if isinstance(v, bool):
28
- return v
29
- if v.lower() in ("yes", "true", "t", "y", "1"):
30
- return True
31
- elif v.lower() in ("no", "false", "f", "n", "0"):
32
- return False
33
- else:
34
- raise argparse.ArgumentTypeError("Boolean value expected.")
35
-
36
- parser = argparse.ArgumentParser(**parser_kwargs)
37
- parser.add_argument(
38
- "-n",
39
- "--name",
40
- type=str,
41
- const=True,
42
- default="",
43
- nargs="?",
44
- help="postfix for logdir",
45
- )
46
- parser.add_argument(
47
- "-r",
48
- "--resume",
49
- type=str,
50
- const=True,
51
- default="",
52
- nargs="?",
53
- help="resume from logdir or checkpoint in logdir",
54
- )
55
- parser.add_argument(
56
- "-b",
57
- "--base",
58
- nargs="*",
59
- metavar="base_config.yaml",
60
- help="paths to base configs. Loaded from left-to-right. "
61
- "Parameters can be overwritten or added with command-line options of the form `--key value`.",
62
- default=list(),
63
- )
64
- parser.add_argument(
65
- "-t",
66
- "--train",
67
- type=str2bool,
68
- const=True,
69
- default=False,
70
- nargs="?",
71
- help="train",
72
- )
73
- parser.add_argument(
74
- "--no-test",
75
- type=str2bool,
76
- const=True,
77
- default=False,
78
- nargs="?",
79
- help="disable test",
80
- )
81
- parser.add_argument("-p", "--project", help="name of new or path to existing project")
82
- parser.add_argument(
83
- "-d",
84
- "--debug",
85
- type=str2bool,
86
- nargs="?",
87
- const=True,
88
- default=False,
89
- help="enable post-mortem debugging",
90
- )
91
- parser.add_argument(
92
- "-s",
93
- "--seed",
94
- type=int,
95
- default=23,
96
- help="seed for seed_everything",
97
- )
98
- parser.add_argument(
99
- "-f",
100
- "--postfix",
101
- type=str,
102
- default="",
103
- help="post-postfix for default name",
104
- )
105
-
106
- return parser
107
-
108
-
109
- def nondefault_trainer_args(opt):
110
- parser = argparse.ArgumentParser()
111
- parser = Trainer.add_argparse_args(parser)
112
- args = parser.parse_args([])
113
- return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
114
-
115
-
116
- def instantiate_from_config(config):
117
- if not "target" in config:
118
- raise KeyError("Expected key `target` to instantiate.")
119
- return get_obj_from_str(config["target"])(**config.get("params", dict()))
120
-
121
-
122
- class WrappedDataset(Dataset):
123
- """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
124
- def __init__(self, dataset):
125
- self.data = dataset
126
-
127
- def __len__(self):
128
- return len(self.data)
129
-
130
- def __getitem__(self, idx):
131
- return self.data[idx]
132
-
133
-
134
- class DataModuleFromConfig(pl.LightningDataModule):
135
- def __init__(self, batch_size, train=None, validation=None, test=None,
136
- wrap=False, num_workers=None):
137
- super().__init__()
138
- self.batch_size = batch_size
139
- self.dataset_configs = dict()
140
- self.num_workers = num_workers if num_workers is not None else batch_size*2
141
- if train is not None:
142
- self.dataset_configs["train"] = train
143
- self.train_dataloader = self._train_dataloader
144
- if validation is not None:
145
- self.dataset_configs["validation"] = validation
146
- self.val_dataloader = self._val_dataloader
147
- if test is not None:
148
- self.dataset_configs["test"] = test
149
- self.test_dataloader = self._test_dataloader
150
- self.wrap = wrap
151
-
152
- def prepare_data(self):
153
- for data_cfg in self.dataset_configs.values():
154
- instantiate_from_config(data_cfg)
155
-
156
- def setup(self, stage=None):
157
- self.datasets = dict(
158
- (k, instantiate_from_config(self.dataset_configs[k]))
159
- for k in self.dataset_configs)
160
- if self.wrap:
161
- for k in self.datasets:
162
- self.datasets[k] = WrappedDataset(self.datasets[k])
163
-
164
- def _train_dataloader(self):
165
- return DataLoader(self.datasets["train"], batch_size=self.batch_size,
166
- num_workers=self.num_workers, shuffle=True, collate_fn=custom_collate)
167
-
168
- def _val_dataloader(self):
169
- return DataLoader(self.datasets["validation"],
170
- batch_size=self.batch_size,
171
- num_workers=self.num_workers, collate_fn=custom_collate)
172
-
173
- def _test_dataloader(self):
174
- return DataLoader(self.datasets["test"], batch_size=self.batch_size,
175
- num_workers=self.num_workers, collate_fn=custom_collate)
176
-
177
-
178
- class SetupCallback(Callback):
179
- def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
180
- super().__init__()
181
- self.resume = resume
182
- self.now = now
183
- self.logdir = logdir
184
- self.ckptdir = ckptdir
185
- self.cfgdir = cfgdir
186
- self.config = config
187
- self.lightning_config = lightning_config
188
-
189
- def on_pretrain_routine_start(self, trainer, pl_module):
190
- if trainer.global_rank == 0:
191
- # Create logdirs and save configs
192
- os.makedirs(self.logdir, exist_ok=True)
193
- os.makedirs(self.ckptdir, exist_ok=True)
194
- os.makedirs(self.cfgdir, exist_ok=True)
195
-
196
- print("Project config")
197
- print(self.config.pretty())
198
- OmegaConf.save(self.config,
199
- os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
200
-
201
- print("Lightning config")
202
- print(self.lightning_config.pretty())
203
- OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
204
- os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
205
-
206
- else:
207
- # ModelCheckpoint callback created log directory --- remove it
208
- if not self.resume and os.path.exists(self.logdir):
209
- dst, name = os.path.split(self.logdir)
210
- dst = os.path.join(dst, "child_runs", name)
211
- os.makedirs(os.path.split(dst)[0], exist_ok=True)
212
- try:
213
- os.rename(self.logdir, dst)
214
- except FileNotFoundError:
215
- pass
216
-
217
-
218
- class ImageLogger(Callback):
219
- def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True):
220
- super().__init__()
221
- self.batch_freq = batch_frequency
222
- self.max_images = max_images
223
- self.logger_log_images = {
224
- pl.loggers.WandbLogger: self._wandb,
225
- pl.loggers.TestTubeLogger: self._testtube,
226
- }
227
- self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
228
- if not increase_log_steps:
229
- self.log_steps = [self.batch_freq]
230
- self.clamp = clamp
231
-
232
- @rank_zero_only
233
- def _wandb(self, pl_module, images, batch_idx, split):
234
- raise ValueError("No way wandb")
235
- grids = dict()
236
- for k in images:
237
- grid = torchvision.utils.make_grid(images[k])
238
- grids[f"{split}/{k}"] = wandb.Image(grid)
239
- pl_module.logger.experiment.log(grids)
240
-
241
- @rank_zero_only
242
- def _testtube(self, pl_module, images, batch_idx, split):
243
- for k in images:
244
- grid = torchvision.utils.make_grid(images[k])
245
- grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
246
-
247
- tag = f"{split}/{k}"
248
- pl_module.logger.experiment.add_image(
249
- tag, grid,
250
- global_step=pl_module.global_step)
251
-
252
- @rank_zero_only
253
- def log_local(self, save_dir, split, images,
254
- global_step, current_epoch, batch_idx):
255
- root = os.path.join(save_dir, "images", split)
256
- for k in images:
257
- grid = torchvision.utils.make_grid(images[k], nrow=4)
258
-
259
- grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
260
- grid = grid.transpose(0,1).transpose(1,2).squeeze(-1)
261
- grid = grid.numpy()
262
- grid = (grid*255).astype(np.uint8)
263
- filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
264
- k,
265
- global_step,
266
- current_epoch,
267
- batch_idx)
268
- path = os.path.join(root, filename)
269
- os.makedirs(os.path.split(path)[0], exist_ok=True)
270
- Image.fromarray(grid).save(path)
271
-
272
- def log_img(self, pl_module, batch, batch_idx, split="train"):
273
- if (self.check_frequency(batch_idx) and # batch_idx % self.batch_freq == 0
274
- hasattr(pl_module, "log_images") and
275
- callable(pl_module.log_images) and
276
- self.max_images > 0):
277
- logger = type(pl_module.logger)
278
-
279
- is_train = pl_module.training
280
- if is_train:
281
- pl_module.eval()
282
-
283
- with torch.no_grad():
284
- images = pl_module.log_images(batch, split=split, pl_module=pl_module)
285
-
286
- for k in images:
287
- N = min(images[k].shape[0], self.max_images)
288
- images[k] = images[k][:N]
289
- if isinstance(images[k], torch.Tensor):
290
- images[k] = images[k].detach().cpu()
291
- if self.clamp:
292
- images[k] = torch.clamp(images[k], -1., 1.)
293
-
294
- self.log_local(pl_module.logger.save_dir, split, images,
295
- pl_module.global_step, pl_module.current_epoch, batch_idx)
296
-
297
- logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
298
- logger_log_images(pl_module, images, pl_module.global_step, split)
299
-
300
- if is_train:
301
- pl_module.train()
302
-
303
- def check_frequency(self, batch_idx):
304
- if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
305
- try:
306
- self.log_steps.pop(0)
307
- except IndexError:
308
- pass
309
- return True
310
- return False
311
-
312
- def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
313
- self.log_img(pl_module, batch, batch_idx, split="train")
314
-
315
- def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
316
- self.log_img(pl_module, batch, batch_idx, split="val")
317
-
318
-
319
-
320
- if __name__ == "__main__":
321
- # custom parser to specify config files, train, test and debug mode,
322
- # postfix, resume.
323
- # `--key value` arguments are interpreted as arguments to the trainer.
324
- # `nested.key=value` arguments are interpreted as config parameters.
325
- # configs are merged from left-to-right followed by command line parameters.
326
-
327
- # model:
328
- # base_learning_rate: float
329
- # target: path to lightning module
330
- # params:
331
- # key: value
332
- # data:
333
- # target: main.DataModuleFromConfig
334
- # params:
335
- # batch_size: int
336
- # wrap: bool
337
- # train:
338
- # target: path to train dataset
339
- # params:
340
- # key: value
341
- # validation:
342
- # target: path to validation dataset
343
- # params:
344
- # key: value
345
- # test:
346
- # target: path to test dataset
347
- # params:
348
- # key: value
349
- # lightning: (optional, has sane defaults and can be specified on cmdline)
350
- # trainer:
351
- # additional arguments to trainer
352
- # logger:
353
- # logger to instantiate
354
- # modelcheckpoint:
355
- # modelcheckpoint to instantiate
356
- # callbacks:
357
- # callback1:
358
- # target: importpath
359
- # params:
360
- # key: value
361
-
362
- now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
363
-
364
- # add cwd for convenience and to make classes in this file available when
365
- # running as `python main.py`
366
- # (in particular `main.DataModuleFromConfig`)
367
- sys.path.append(os.getcwd())
368
-
369
- parser = get_parser()
370
- parser = Trainer.add_argparse_args(parser)
371
-
372
- opt, unknown = parser.parse_known_args()
373
- if opt.name and opt.resume:
374
- raise ValueError(
375
- "-n/--name and -r/--resume cannot be specified both."
376
- "If you want to resume training in a new log folder, "
377
- "use -n/--name in combination with --resume_from_checkpoint"
378
- )
379
- if opt.resume:
380
- if not os.path.exists(opt.resume):
381
- raise ValueError("Cannot find {}".format(opt.resume))
382
- if os.path.isfile(opt.resume):
383
- paths = opt.resume.split("/")
384
- idx = len(paths)-paths[::-1].index("logs")+1
385
- logdir = "/".join(paths[:idx])
386
- ckpt = opt.resume
387
- else:
388
- assert os.path.isdir(opt.resume), opt.resume
389
- logdir = opt.resume.rstrip("/")
390
- ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
391
-
392
- opt.resume_from_checkpoint = ckpt
393
- base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
394
- opt.base = base_configs+opt.base
395
- _tmp = logdir.split("/")
396
- nowname = _tmp[_tmp.index("logs")+1]
397
- else:
398
- if opt.name:
399
- name = "_"+opt.name
400
- elif opt.base:
401
- cfg_fname = os.path.split(opt.base[0])[-1]
402
- cfg_name = os.path.splitext(cfg_fname)[0]
403
- name = "_"+cfg_name
404
- else:
405
- name = ""
406
- nowname = now+name+opt.postfix
407
- logdir = os.path.join("logs", nowname)
408
-
409
- ckptdir = os.path.join(logdir, "checkpoints")
410
- cfgdir = os.path.join(logdir, "configs")
411
- seed_everything(opt.seed)
412
-
413
- try:
414
- # init and save configs
415
- configs = [OmegaConf.load(cfg) for cfg in opt.base]
416
- cli = OmegaConf.from_dotlist(unknown)
417
- config = OmegaConf.merge(*configs, cli)
418
- lightning_config = config.pop("lightning", OmegaConf.create())
419
- # merge trainer cli with config
420
- trainer_config = lightning_config.get("trainer", OmegaConf.create())
421
- # default to ddp
422
- trainer_config["distributed_backend"] = "ddp"
423
- for k in nondefault_trainer_args(opt):
424
- trainer_config[k] = getattr(opt, k)
425
- if not "gpus" in trainer_config:
426
- del trainer_config["distributed_backend"]
427
- cpu = True
428
- else:
429
- gpuinfo = trainer_config["gpus"]
430
- print(f"Running on GPUs {gpuinfo}")
431
- cpu = False
432
- trainer_opt = argparse.Namespace(**trainer_config)
433
- lightning_config.trainer = trainer_config
434
-
435
- # model
436
- model = instantiate_from_config(config.model)
437
-
438
- # trainer and callbacks
439
- trainer_kwargs = dict()
440
-
441
- # default logger configs
442
- # NOTE wandb < 0.10.0 interferes with shutdown
443
- # wandb >= 0.10.0 seems to fix it but still interferes with pudb
444
- # debugging (wrongly sized pudb ui)
445
- # thus prefer testtube for now
446
- default_logger_cfgs = {
447
- "wandb": {
448
- "target": "pytorch_lightning.loggers.WandbLogger",
449
- "params": {
450
- "name": nowname,
451
- "save_dir": logdir,
452
- "offline": opt.debug,
453
- "id": nowname,
454
- }
455
- },
456
- "testtube": {
457
- "target": "pytorch_lightning.loggers.TestTubeLogger",
458
- "params": {
459
- "name": "testtube",
460
- "save_dir": logdir,
461
- }
462
- },
463
- }
464
- default_logger_cfg = default_logger_cfgs["testtube"]
465
- logger_cfg = lightning_config.logger or OmegaConf.create()
466
- logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
467
- trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
468
-
469
- # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
470
- # specify which metric is used to determine best models
471
- default_modelckpt_cfg = {
472
- "target": "pytorch_lightning.callbacks.ModelCheckpoint",
473
- "params": {
474
- "dirpath": ckptdir,
475
- "filename": "{epoch:06}",
476
- "verbose": True,
477
- "save_last": True,
478
- }
479
- }
480
- if hasattr(model, "monitor"):
481
- print(f"Monitoring {model.monitor} as checkpoint metric.")
482
- default_modelckpt_cfg["params"]["monitor"] = model.monitor
483
- default_modelckpt_cfg["params"]["save_top_k"] = 3
484
-
485
- modelckpt_cfg = lightning_config.modelcheckpoint or OmegaConf.create()
486
- modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
487
- trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
488
-
489
- # add callback which sets up log directory
490
- default_callbacks_cfg = {
491
- "setup_callback": {
492
- "target": "main.SetupCallback",
493
- "params": {
494
- "resume": opt.resume,
495
- "now": now,
496
- "logdir": logdir,
497
- "ckptdir": ckptdir,
498
- "cfgdir": cfgdir,
499
- "config": config,
500
- "lightning_config": lightning_config,
501
- }
502
- },
503
- "image_logger": {
504
- "target": "main.ImageLogger",
505
- "params": {
506
- "batch_frequency": 750,
507
- "max_images": 4,
508
- "clamp": True
509
- }
510
- },
511
- "learning_rate_logger": {
512
- "target": "main.LearningRateMonitor",
513
- "params": {
514
- "logging_interval": "step",
515
- #"log_momentum": True
516
- }
517
- },
518
- }
519
- callbacks_cfg = lightning_config.callbacks or OmegaConf.create()
520
- callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
521
- trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
522
-
523
- trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
524
-
525
- # data
526
- data = instantiate_from_config(config.data)
527
- # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
528
- # calling these ourselves should not be necessary but it is.
529
- # lightning still takes care of proper multiprocessing though
530
- data.prepare_data()
531
- data.setup()
532
-
533
- # configure learning rate
534
- bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
535
- if not cpu:
536
- ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
537
- else:
538
- ngpu = 1
539
- accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1
540
- print(f"accumulate_grad_batches = {accumulate_grad_batches}")
541
- lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
542
- model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
543
- print("Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
544
- model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
545
-
546
- # allow checkpointing via USR1
547
- def melk(*args, **kwargs):
548
- # run all checkpoint hooks
549
- if trainer.global_rank == 0:
550
- print("Summoning checkpoint.")
551
- ckpt_path = os.path.join(ckptdir, "last.ckpt")
552
- trainer.save_checkpoint(ckpt_path)
553
-
554
- def divein(*args, **kwargs):
555
- if trainer.global_rank == 0:
556
- import pudb; pudb.set_trace()
557
-
558
- import signal
559
- signal.signal(signal.SIGUSR1, melk)
560
- signal.signal(signal.SIGUSR2, divein)
561
-
562
- # run
563
- if opt.train:
564
- try:
565
- trainer.fit(model, data)
566
- except Exception:
567
- melk()
568
- raise
569
- if not opt.no_test and not trainer.interrupted:
570
- trainer.test(model, data)
571
- except Exception:
572
- if opt.debug and trainer.global_rank==0:
573
- try:
574
- import pudb as debugger
575
- except ImportError:
576
- import pdb as debugger
577
- debugger.post_mortem()
578
- raise
579
- finally:
580
- # move newly created debug project to debug_runs
581
- if opt.debug and not opt.resume and trainer.global_rank==0:
582
- dst, name = os.path.split(logdir)
583
- dst = os.path.join(dst, "debug_runs", name)
584
- os.makedirs(os.path.split(dst)[0], exist_ok=True)
585
- os.rename(logdir, dst)
 
taming-transformers/scripts/extract_depth.py DELETED
@@ -1,112 +0,0 @@
1
- import os
2
- import torch
3
- import numpy as np
4
- from tqdm import trange
5
- from PIL import Image
6
-
7
-
8
- def get_state(gpu):
9
- import torch
10
- midas = torch.hub.load("intel-isl/MiDaS", "MiDaS")
11
- if gpu:
12
- midas.cuda()
13
- midas.eval()
14
-
15
- midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
16
- transform = midas_transforms.default_transform
17
-
18
- state = {"model": midas,
19
- "transform": transform}
20
- return state
21
-
22
-
23
- def depth_to_rgba(x):
24
- assert x.dtype == np.float32
25
- assert len(x.shape) == 2
26
- y = x.copy()
27
- y.dtype = np.uint8
28
- y = y.reshape(x.shape+(4,))
29
- return np.ascontiguousarray(y)
30
-
31
-
32
- def rgba_to_depth(x):
33
- assert x.dtype == np.uint8
34
- assert len(x.shape) == 3 and x.shape[2] == 4
35
- y = x.copy()
36
- y.dtype = np.float32
37
- y = y.reshape(x.shape[:2])
38
- return np.ascontiguousarray(y)
39
-
40
-
41
- def run(x, state):
42
- model = state["model"]
43
- transform = state["transform"]
44
- hw = x.shape[:2]
45
- with torch.no_grad():
46
- prediction = model(transform((x + 1.0) * 127.5).cuda())
47
- prediction = torch.nn.functional.interpolate(
48
- prediction.unsqueeze(1),
49
- size=hw,
50
- mode="bicubic",
51
- align_corners=False,
52
- ).squeeze()
53
- output = prediction.cpu().numpy()
54
- return output
55
-
56
-
57
- def get_filename(relpath, level=-2):
58
- # save class folder structure and filename:
59
- fn = relpath.split(os.sep)[level:]
60
- folder = fn[-2]
61
- file = fn[-1].split('.')[0]
62
- return folder, file
63
-
64
-
65
- def save_depth(dataset, path, debug=False):
66
- os.makedirs(path)
67
- N = len(dset)
68
- if debug:
69
- N = 10
70
- state = get_state(gpu=True)
71
- for idx in trange(N, desc="Data"):
72
- ex = dataset[idx]
73
- image, relpath = ex["image"], ex["relpath"]
74
- folder, filename = get_filename(relpath)
75
- # prepare
76
- folderabspath = os.path.join(path, folder)
77
- os.makedirs(folderabspath, exist_ok=True)
78
- savepath = os.path.join(folderabspath, filename)
79
- # run model
80
- xout = run(image, state)
81
- I = depth_to_rgba(xout)
82
- Image.fromarray(I).save("{}.png".format(savepath))
83
-
84
-
85
- if __name__ == "__main__":
86
- from taming.data.imagenet import ImageNetTrain, ImageNetValidation
87
- out = "data/imagenet_depth"
88
- if not os.path.exists(out):
89
- print("Please create a folder or symlink '{}' to extract depth data ".format(out) +
90
- "(be prepared that the output size will be larger than ImageNet itself).")
91
- exit(1)
92
-
93
- # go
94
- dset = ImageNetValidation()
95
- abspath = os.path.join(out, "val")
96
- if os.path.exists(abspath):
97
- print("{} exists - not doing anything.".format(abspath))
98
- else:
99
- print("preparing {}".format(abspath))
100
- save_depth(dset, abspath)
101
- print("done with validation split")
102
-
103
- dset = ImageNetTrain()
104
- abspath = os.path.join(out, "train")
105
- if os.path.exists(abspath):
106
- print("{} exists - not doing anything.".format(abspath))
107
- else:
108
- print("preparing {}".format(abspath))
109
- save_depth(dset, abspath)
110
- print("done with train split")
111
-
112
- print("done done.")
 
taming-transformers/scripts/extract_segmentation.py DELETED
@@ -1,130 +0,0 @@
1
- import sys, os
2
- import numpy as np
3
- import scipy
4
- import torch
5
- import torch.nn as nn
6
- from scipy import ndimage
7
- from tqdm import tqdm, trange
8
- from PIL import Image
9
- import torch.hub
10
- import torchvision
11
- import torch.nn.functional as F
12
-
13
- # download deeplabv2_resnet101_msc-cocostuff164k-100000.pth from
14
- # https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth
15
- # and put the path here
16
- CKPT_PATH = "TODO"
17
-
18
- rescale = lambda x: (x + 1.) / 2.
19
-
20
- def rescale_bgr(x):
21
- x = (x+1)*127.5
22
- x = torch.flip(x, dims=[0])
23
- return x
24
-
25
-
26
- class COCOStuffSegmenter(nn.Module):
27
- def __init__(self, config):
28
- super().__init__()
29
- self.config = config
30
- self.n_labels = 182
31
- model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=self.n_labels)
32
- ckpt_path = CKPT_PATH
33
- model.load_state_dict(torch.load(ckpt_path))
34
- self.model = model
35
-
36
- normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std)
37
- self.image_transform = torchvision.transforms.Compose([
38
- torchvision.transforms.Lambda(lambda image: torch.stack(
39
- [normalize(rescale_bgr(x)) for x in image]))
40
- ])
41
-
42
- def forward(self, x, upsample=None):
43
- x = self._pre_process(x)
44
- x = self.model(x)
45
- if upsample is not None:
46
- x = torch.nn.functional.upsample_bilinear(x, size=upsample)
47
- return x
48
-
49
- def _pre_process(self, x):
50
- x = self.image_transform(x)
51
- return x
52
-
53
- @property
54
- def mean(self):
55
- # bgr
56
- return [104.008, 116.669, 122.675]
57
-
58
- @property
59
- def std(self):
60
- return [1.0, 1.0, 1.0]
61
-
62
- @property
63
- def input_size(self):
64
- return [3, 224, 224]
65
-
66
-
67
- def run_model(img, model):
68
- model = model.eval()
69
- with torch.no_grad():
70
- segmentation = model(img, upsample=(img.shape[2], img.shape[3]))
71
- segmentation = torch.argmax(segmentation, dim=1, keepdim=True)
72
- return segmentation.detach().cpu()
73
-
74
-
75
- def get_input(batch, k):
76
- x = batch[k]
77
- if len(x.shape) == 3:
78
- x = x[..., None]
79
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
80
- return x.float()
81
-
82
-
83
- def save_segmentation(segmentation, path):
84
- # --> class label to uint8, save as png
85
- os.makedirs(os.path.dirname(path), exist_ok=True)
86
- assert len(segmentation.shape)==4
87
- assert segmentation.shape[0]==1
88
- for seg in segmentation:
89
- seg = seg.permute(1,2,0).numpy().squeeze().astype(np.uint8)
90
- seg = Image.fromarray(seg)
91
- seg.save(path)
92
-
93
-
94
- def iterate_dataset(dataloader, destpath, model):
95
- os.makedirs(destpath, exist_ok=True)
96
- num_processed = 0
97
- for i, batch in tqdm(enumerate(dataloader), desc="Data"):
98
- try:
99
- img = get_input(batch, "image")
100
- img = img.cuda()
101
- seg = run_model(img, model)
102
-
103
- path = batch["relative_file_path_"][0]
104
- path = os.path.splitext(path)[0]
105
-
106
- path = os.path.join(destpath, path + ".png")
107
- save_segmentation(seg, path)
108
- num_processed += 1
109
- except Exception as e:
110
- print(e)
111
- print("but anyhow..")
112
-
113
- print("Processed {} files. Bye.".format(num_processed))
114
-
115
-
116
- from taming.data.sflckr import Examples
117
- from torch.utils.data import DataLoader
118
-
119
- if __name__ == "__main__":
120
- dest = sys.argv[1]
121
- batchsize = 1
122
- print("Running with batch-size {}, saving to {}...".format(batchsize, dest))
123
-
124
- model = COCOStuffSegmenter({}).cuda()
125
- print("Instantiated model.")
126
-
127
- dataset = Examples()
128
- dloader = DataLoader(dataset, batch_size=batchsize)
129
- iterate_dataset(dataloader=dloader, destpath=dest, model=model)
130
- print("done.")
 
taming-transformers/scripts/extract_submodel.py DELETED
@@ -1,17 +0,0 @@
1
- import torch
2
- import sys
3
-
4
- if __name__ == "__main__":
5
- inpath = sys.argv[1]
6
- outpath = sys.argv[2]
7
- submodel = "cond_stage_model"
8
- if len(sys.argv) > 3:
9
- submodel = sys.argv[3]
10
-
11
- print("Extracting {} from {} to {}.".format(submodel, inpath, outpath))
12
-
13
- sd = torch.load(inpath, map_location="cpu")
14
- new_sd = {"state_dict": dict((k.split(".", 1)[-1],v)
15
- for k,v in sd["state_dict"].items()
16
- if k.startswith("cond_stage_model"))}
17
- torch.save(new_sd, outpath)
 
taming-transformers/scripts/make_samples.py DELETED
@@ -1,292 +0,0 @@
1
- import argparse, os, sys, glob, math, time
2
- import torch
3
- import numpy as np
4
- from omegaconf import OmegaConf
5
- from PIL import Image
6
- from main import instantiate_from_config, DataModuleFromConfig
7
- from torch.utils.data import DataLoader
8
- from torch.utils.data.dataloader import default_collate
9
- from tqdm import trange
10
-
11
-
12
- def save_image(x, path):
13
- c,h,w = x.shape
14
- assert c==3
15
- x = ((x.detach().cpu().numpy().transpose(1,2,0)+1.0)*127.5).clip(0,255).astype(np.uint8)
16
- Image.fromarray(x).save(path)
17
-
18
-
19
- @torch.no_grad()
20
- def run_conditional(model, dsets, outdir, top_k, temperature, batch_size=1):
21
- if len(dsets.datasets) > 1:
22
- split = sorted(dsets.datasets.keys())[0]
23
- dset = dsets.datasets[split]
24
- else:
25
- dset = next(iter(dsets.datasets.values()))
26
- print("Dataset: ", dset.__class__.__name__)
27
- for start_idx in trange(0,len(dset)-batch_size+1,batch_size):
28
- indices = list(range(start_idx, start_idx+batch_size))
29
- example = default_collate([dset[i] for i in indices])
30
-
31
- x = model.get_input("image", example).to(model.device)
32
- for i in range(x.shape[0]):
33
- save_image(x[i], os.path.join(outdir, "originals",
34
- "{:06}.png".format(indices[i])))
35
-
36
- cond_key = model.cond_stage_key
37
- c = model.get_input(cond_key, example).to(model.device)
38
-
39
- scale_factor = 1.0
40
- quant_z, z_indices = model.encode_to_z(x)
41
- quant_c, c_indices = model.encode_to_c(c)
42
-
43
- cshape = quant_z.shape
44
-
45
- xrec = model.first_stage_model.decode(quant_z)
46
- for i in range(xrec.shape[0]):
47
- save_image(xrec[i], os.path.join(outdir, "reconstructions",
48
- "{:06}.png".format(indices[i])))
49
-
50
- if cond_key == "segmentation":
51
- # get image from segmentation mask
52
- num_classes = c.shape[1]
53
- c = torch.argmax(c, dim=1, keepdim=True)
54
- c = torch.nn.functional.one_hot(c, num_classes=num_classes)
55
- c = c.squeeze(1).permute(0, 3, 1, 2).float()
56
- c = model.cond_stage_model.to_rgb(c)
57
-
58
- idx = z_indices
59
-
60
- half_sample = False
61
- if half_sample:
62
- start = idx.shape[1]//2
63
- else:
64
- start = 0
65
-
66
- idx[:,start:] = 0
67
- idx = idx.reshape(cshape[0],cshape[2],cshape[3])
68
- start_i = start//cshape[3]
69
- start_j = start %cshape[3]
70
-
71
- cidx = c_indices
72
- cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
73
-
74
- sample = True
75
-
76
- for i in range(start_i,cshape[2]-0):
77
- if i <= 8:
78
- local_i = i
79
- elif cshape[2]-i < 8:
80
- local_i = 16-(cshape[2]-i)
81
- else:
82
- local_i = 8
83
- for j in range(start_j,cshape[3]-0):
84
- if j <= 8:
85
- local_j = j
86
- elif cshape[3]-j < 8:
87
- local_j = 16-(cshape[3]-j)
88
- else:
89
- local_j = 8
90
-
91
- i_start = i-local_i
92
- i_end = i_start+16
93
- j_start = j-local_j
94
- j_end = j_start+16
95
- patch = idx[:,i_start:i_end,j_start:j_end]
96
- patch = patch.reshape(patch.shape[0],-1)
97
- cpatch = cidx[:, i_start:i_end, j_start:j_end]
98
- cpatch = cpatch.reshape(cpatch.shape[0], -1)
99
- patch = torch.cat((cpatch, patch), dim=1)
100
- logits,_ = model.transformer(patch[:,:-1])
101
- logits = logits[:, -256:, :]
102
- logits = logits.reshape(cshape[0],16,16,-1)
103
- logits = logits[:,local_i,local_j,:]
104
-
105
- logits = logits/temperature
106
-
107
- if top_k is not None:
108
- logits = model.top_k_logits(logits, top_k)
109
- # apply softmax to convert to probabilities
110
- probs = torch.nn.functional.softmax(logits, dim=-1)
111
- # sample from the distribution or take the most likely
112
- if sample:
113
- ix = torch.multinomial(probs, num_samples=1)
114
- else:
115
- _, ix = torch.topk(probs, k=1, dim=-1)
116
- idx[:,i,j] = ix
117
-
118
- xsample = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
119
- for i in range(xsample.shape[0]):
120
- save_image(xsample[i], os.path.join(outdir, "samples",
121
- "{:06}.png".format(indices[i])))
122
-
123
-
124
- def get_parser():
125
- parser = argparse.ArgumentParser()
126
- parser.add_argument(
127
- "-r",
128
- "--resume",
129
- type=str,
130
- nargs="?",
131
- help="load from logdir or checkpoint in logdir",
132
- )
133
- parser.add_argument(
134
- "-b",
135
- "--base",
136
- nargs="*",
137
- metavar="base_config.yaml",
138
- help="paths to base configs. Loaded from left-to-right. "
139
- "Parameters can be overwritten or added with command-line options of the form `--key value`.",
140
- default=list(),
141
- )
142
- parser.add_argument(
143
- "-c",
144
- "--config",
145
- nargs="?",
146
- metavar="single_config.yaml",
147
- help="path to single config. If specified, base configs will be ignored "
148
- "(except for the last one if left unspecified).",
149
- const=True,
150
- default="",
151
- )
152
- parser.add_argument(
153
- "--ignore_base_data",
154
- action="store_true",
155
- help="Ignore data specification from base configs. Useful if you want "
156
- "to specify a custom datasets on the command line.",
157
- )
158
- parser.add_argument(
159
- "--outdir",
160
- required=True,
161
- type=str,
162
- help="Where to write outputs to.",
163
- )
164
- parser.add_argument(
165
- "--top_k",
166
- type=int,
167
- default=100,
168
- help="Sample from among top-k predictions.",
169
- )
170
- parser.add_argument(
171
- "--temperature",
172
- type=float,
173
- default=1.0,
174
- help="Sampling temperature.",
175
- )
176
- return parser
177
-
178
-
179
- def load_model_from_config(config, sd, gpu=True, eval_mode=True):
180
- if "ckpt_path" in config.params:
181
- print("Deleting the restore-ckpt path from the config...")
182
- config.params.ckpt_path = None
183
- if "downsample_cond_size" in config.params:
184
- print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
185
- config.params.downsample_cond_size = -1
186
- config.params["downsample_cond_factor"] = 0.5
187
- try:
188
- if "ckpt_path" in config.params.first_stage_config.params:
189
- config.params.first_stage_config.params.ckpt_path = None
190
- print("Deleting the first-stage restore-ckpt path from the config...")
191
- if "ckpt_path" in config.params.cond_stage_config.params:
192
- config.params.cond_stage_config.params.ckpt_path = None
193
- print("Deleting the cond-stage restore-ckpt path from the config...")
194
- except:
195
- pass
196
-
197
- model = instantiate_from_config(config)
198
- if sd is not None:
199
- missing, unexpected = model.load_state_dict(sd, strict=False)
200
- print(f"Missing Keys in State Dict: {missing}")
201
- print(f"Unexpected Keys in State Dict: {unexpected}")
202
- if gpu:
203
- model.cuda()
204
- if eval_mode:
205
- model.eval()
206
- return {"model": model}
207
-
208
-
209
- def get_data(config):
210
- # get data
211
- data = instantiate_from_config(config.data)
212
- data.prepare_data()
213
- data.setup()
214
- return data
215
-
216
-
217
- def load_model_and_dset(config, ckpt, gpu, eval_mode):
218
- # get data
219
- dsets = get_data(config) # calls data.config ...
220
-
221
- # now load the specified checkpoint
222
- if ckpt:
223
- pl_sd = torch.load(ckpt, map_location="cpu")
224
- global_step = pl_sd["global_step"]
225
- else:
226
- pl_sd = {"state_dict": None}
227
- global_step = None
228
- model = load_model_from_config(config.model,
229
- pl_sd["state_dict"],
230
- gpu=gpu,
231
- eval_mode=eval_mode)["model"]
232
- return dsets, model, global_step
233
-
234
-
235
- if __name__ == "__main__":
236
- sys.path.append(os.getcwd())
237
-
238
- parser = get_parser()
239
-
240
- opt, unknown = parser.parse_known_args()
241
-
242
- ckpt = None
243
- if opt.resume:
244
- if not os.path.exists(opt.resume):
245
- raise ValueError("Cannot find {}".format(opt.resume))
246
- if os.path.isfile(opt.resume):
247
- paths = opt.resume.split("/")
248
- try:
249
- idx = len(paths)-paths[::-1].index("logs")+1
250
- except ValueError:
251
- idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
252
- logdir = "/".join(paths[:idx])
253
- ckpt = opt.resume
254
- else:
255
- assert os.path.isdir(opt.resume), opt.resume
256
- logdir = opt.resume.rstrip("/")
257
- ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
258
- print(f"logdir:{logdir}")
259
- base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
260
- opt.base = base_configs+opt.base
261
-
262
- if opt.config:
263
- if type(opt.config) == str:
264
- opt.base = [opt.config]
265
- else:
266
- opt.base = [opt.base[-1]]
267
-
268
- configs = [OmegaConf.load(cfg) for cfg in opt.base]
269
- cli = OmegaConf.from_dotlist(unknown)
270
- if opt.ignore_base_data:
271
- for config in configs:
272
- if hasattr(config, "data"): del config["data"]
273
- config = OmegaConf.merge(*configs, cli)
274
-
275
- print(ckpt)
276
- gpu = True
277
- eval_mode = True
278
- show_config = False
279
- if show_config:
280
- print(OmegaConf.to_container(config))
281
-
282
- dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
283
- print(f"Global step: {global_step}")
284
-
285
- outdir = os.path.join(opt.outdir, "{:06}_{}_{}".format(global_step,
286
- opt.top_k,
287
- opt.temperature))
288
- os.makedirs(outdir, exist_ok=True)
289
- print("Writing samples to ", outdir)
290
- for k in ["originals", "reconstructions", "samples"]:
291
- os.makedirs(os.path.join(outdir, k), exist_ok=True)
292
- run_conditional(model, dsets, outdir, opt.top_k, opt.temperature)
 
taming-transformers/scripts/make_scene_samples.py DELETED
@@ -1,198 +0,0 @@
1
- import glob
2
- import os
3
- import sys
4
- from itertools import product
5
- from pathlib import Path
6
- from typing import Literal, List, Optional, Tuple
7
-
8
- import numpy as np
9
- import torch
10
- from omegaconf import OmegaConf
11
- from pytorch_lightning import seed_everything
12
- from torch import Tensor
13
- from torchvision.utils import save_image
14
- from tqdm import tqdm
15
-
16
- from scripts.make_samples import get_parser, load_model_and_dset
17
- from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
18
- from taming.data.helper_types import BoundingBox, Annotation
19
- from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
20
- from taming.models.cond_transformer import Net2NetTransformer
21
-
22
- seed_everything(42424242)
23
- device: Literal['cuda', 'cpu'] = 'cuda'
24
- first_stage_factor = 16
25
- trained_on_res = 256
26
-
27
-
28
- def _helper(coord: int, coord_max: int, coord_window: int) -> (int, int):
29
- assert 0 <= coord < coord_max
30
- coord_desired_center = (coord_window - 1) // 2
31
- return np.clip(coord - coord_desired_center, 0, coord_max - coord_window)
32
-
33
-
34
- def get_crop_coordinates(x: int, y: int) -> BoundingBox:
35
- WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
36
- x0 = _helper(x, WIDTH, first_stage_factor) / WIDTH
37
- y0 = _helper(y, HEIGHT, first_stage_factor) / HEIGHT
38
- w = first_stage_factor / WIDTH
39
- h = first_stage_factor / HEIGHT
40
- return x0, y0, w, h
41
-
42
-
43
- def get_z_indices_crop_out(z_indices: Tensor, predict_x: int, predict_y: int) -> Tensor:
44
- WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
45
- x0 = _helper(predict_x, WIDTH, first_stage_factor)
46
- y0 = _helper(predict_y, HEIGHT, first_stage_factor)
47
- no_images = z_indices.shape[0]
48
- cut_out_1 = z_indices[:, y0:predict_y, x0:x0+first_stage_factor].reshape((no_images, -1))
49
- cut_out_2 = z_indices[:, predict_y, x0:predict_x]
50
- return torch.cat((cut_out_1, cut_out_2), dim=1)
51
-
52
-
53
- @torch.no_grad()
54
- def sample(model: Net2NetTransformer, annotations: List[Annotation], dataset: AnnotatedObjectsDataset,
55
- conditional_builder: ObjectsCenterPointsConditionalBuilder, no_samples: int,
56
- temperature: float, top_k: int) -> Tensor:
57
- x_max, y_max = desired_z_shape[1], desired_z_shape[0]
58
-
59
- annotations = [a._replace(category_no=dataset.get_category_number(a.category_id)) for a in annotations]
60
-
61
- recompute_conditional = any((desired_resolution[0] > trained_on_res, desired_resolution[1] > trained_on_res))
62
- if not recompute_conditional:
63
- crop_coordinates = get_crop_coordinates(0, 0)
64
- conditional_indices = conditional_builder.build(annotations, crop_coordinates)
65
- c_indices = conditional_indices.to(device).repeat(no_samples, 1)
66
- z_indices = torch.zeros((no_samples, 0), device=device).long()
67
- output_indices = model.sample(z_indices, c_indices, steps=x_max*y_max, temperature=temperature,
68
- sample=True, top_k=top_k)
69
- else:
70
- output_indices = torch.zeros((no_samples, y_max, x_max), device=device).long()
71
- for predict_y, predict_x in tqdm(product(range(y_max), range(x_max)), desc='sampling_image', total=x_max*y_max):
72
- crop_coordinates = get_crop_coordinates(predict_x, predict_y)
73
- z_indices = get_z_indices_crop_out(output_indices, predict_x, predict_y)
74
- conditional_indices = conditional_builder.build(annotations, crop_coordinates)
75
- c_indices = conditional_indices.to(device).repeat(no_samples, 1)
76
- new_index = model.sample(z_indices, c_indices, steps=1, temperature=temperature, sample=True, top_k=top_k)
77
- output_indices[:, predict_y, predict_x] = new_index[:, -1]
78
- z_shape = (
79
- no_samples,
80
- model.first_stage_model.quantize.e_dim, # codebook embed_dim
81
- desired_z_shape[0], # z_height
82
- desired_z_shape[1] # z_width
83
- )
84
- x_sample = model.decode_to_img(output_indices, z_shape) * 0.5 + 0.5
85
- x_sample = x_sample.to('cpu')
86
-
87
- plotter = conditional_builder.plot
88
- figure_size = (x_sample.shape[2], x_sample.shape[3])
89
- scene_graph = conditional_builder.build(annotations, (0., 0., 1., 1.))
90
- plot = plotter(scene_graph, dataset.get_textual_label_for_category_no, figure_size)
91
- return torch.cat((x_sample, plot.unsqueeze(0)))
92
-
93
-
94
- def get_resolution(resolution_str: str) -> (Tuple[int, int], Tuple[int, int]):
95
- if not resolution_str.count(',') == 1:
96
- raise ValueError("Give resolution as in 'height,width'")
97
- res_h, res_w = resolution_str.split(',')
98
- res_h = max(int(res_h), trained_on_res)
99
- res_w = max(int(res_w), trained_on_res)
100
- z_h = int(round(res_h/first_stage_factor))
101
- z_w = int(round(res_w/first_stage_factor))
102
- return (z_h, z_w), (z_h*first_stage_factor, z_w*first_stage_factor)
103
-
104
-
105
- def add_arg_to_parser(parser):
106
- parser.add_argument(
107
- "-R",
108
- "--resolution",
109
- type=str,
110
- default='256,256',
111
- help=f"give resolution in multiples of {first_stage_factor}, default is '256,256'",
112
- )
113
- parser.add_argument(
114
- "-C",
115
- "--conditional",
116
- type=str,
117
- default='objects_bbox',
118
- help=f"objects_bbox or objects_center_points",
119
- )
120
- parser.add_argument(
121
- "-N",
122
- "--n_samples_per_layout",
123
- type=int,
124
- default=4,
125
- help=f"how many samples to generate per layout",
126
- )
127
- return parser
128
-
129
-
130
- if __name__ == "__main__":
131
- sys.path.append(os.getcwd())
132
-
133
- parser = get_parser()
134
- parser = add_arg_to_parser(parser)
135
-
136
- opt, unknown = parser.parse_known_args()
137
-
138
- ckpt = None
139
- if opt.resume:
140
- if not os.path.exists(opt.resume):
141
- raise ValueError("Cannot find {}".format(opt.resume))
142
- if os.path.isfile(opt.resume):
143
- paths = opt.resume.split("/")
144
- try:
145
- idx = len(paths)-paths[::-1].index("logs")+1
146
- except ValueError:
147
- idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
148
- logdir = "/".join(paths[:idx])
149
- ckpt = opt.resume
150
- else:
151
- assert os.path.isdir(opt.resume), opt.resume
152
- logdir = opt.resume.rstrip("/")
153
- ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
154
- print(f"logdir:{logdir}")
155
- base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
156
- opt.base = base_configs+opt.base
157
-
158
- if opt.config:
159
- if type(opt.config) == str:
160
- opt.base = [opt.config]
161
- else:
162
- opt.base = [opt.base[-1]]
163
-
164
- configs = [OmegaConf.load(cfg) for cfg in opt.base]
165
- cli = OmegaConf.from_dotlist(unknown)
166
- if opt.ignore_base_data:
167
- for config in configs:
168
- if hasattr(config, "data"):
169
- del config["data"]
170
- config = OmegaConf.merge(*configs, cli)
171
- desired_z_shape, desired_resolution = get_resolution(opt.resolution)
172
- conditional = opt.conditional
173
-
174
- print(ckpt)
175
- gpu = True
176
- eval_mode = True
177
- show_config = False
178
- if show_config:
179
- print(OmegaConf.to_container(config))
180
-
181
- dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
182
- print(f"Global step: {global_step}")
183
-
184
- data_loader = dsets.val_dataloader()
185
- print(dsets.datasets["validation"].conditional_builders)
186
- conditional_builder = dsets.datasets["validation"].conditional_builders[conditional]
187
-
188
- outdir = Path(opt.outdir).joinpath(f"{global_step:06}_{opt.top_k}_{opt.temperature}")
189
- outdir.mkdir(exist_ok=True, parents=True)
190
- print("Writing samples to ", outdir)
191
-
192
- p_bar_1 = tqdm(enumerate(iter(data_loader)), desc='batch', total=len(data_loader))
193
- for batch_no, batch in p_bar_1:
194
- save_img: Optional[Tensor] = None
195
- for i, annotations in tqdm(enumerate(batch['annotations']), desc='within_batch', total=data_loader.batch_size):
196
- imgs = sample(model, annotations, dsets.datasets["validation"], conditional_builder,
197
- opt.n_samples_per_layout, opt.temperature, opt.top_k)
198
- save_image(imgs, outdir.joinpath(f'{batch_no:04}_{i:02}.png'), n_row=opt.n_samples_per_layout+1)
 
taming-transformers/scripts/sample_conditional.py DELETED
@@ -1,355 +0,0 @@
1
- import argparse, os, sys, glob, math, time
2
- import torch
3
- import numpy as np
4
- from omegaconf import OmegaConf
5
- import streamlit as st
6
- from streamlit import caching
7
- from PIL import Image
8
- from main import instantiate_from_config, DataModuleFromConfig
9
- from torch.utils.data import DataLoader
10
- from torch.utils.data.dataloader import default_collate
11
-
12
-
13
- rescale = lambda x: (x + 1.) / 2.
14
-
15
-
16
- def bchw_to_st(x):
17
- return rescale(x.detach().cpu().numpy().transpose(0,2,3,1))
18
-
19
- def save_img(xstart, fname):
20
- I = (xstart.clip(0,1)[0]*255).astype(np.uint8)
21
- Image.fromarray(I).save(fname)
22
-
23
-
24
-
25
- def get_interactive_image(resize=False):
26
- image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
27
- if image is not None:
28
- image = Image.open(image)
29
- if not image.mode == "RGB":
30
- image = image.convert("RGB")
31
- image = np.array(image).astype(np.uint8)
32
- print("upload image shape: {}".format(image.shape))
33
- img = Image.fromarray(image)
34
- if resize:
35
- img = img.resize((256, 256))
36
- image = np.array(img)
37
- return image
38
-
39
-
40
- def single_image_to_torch(x, permute=True):
41
- assert x is not None, "Please provide an image through the upload function"
42
- x = np.array(x)
43
- x = torch.FloatTensor(x/255.*2. - 1.)[None,...]
44
- if permute:
45
- x = x.permute(0, 3, 1, 2)
46
- return x
47
-
48
-
49
- def pad_to_M(x, M):
50
- hp = math.ceil(x.shape[2]/M)*M-x.shape[2]
51
- wp = math.ceil(x.shape[3]/M)*M-x.shape[3]
52
- x = torch.nn.functional.pad(x, (0,wp,0,hp,0,0,0,0))
53
- return x
54
-
55
- @torch.no_grad()
56
- def run_conditional(model, dsets):
57
- if len(dsets.datasets) > 1:
58
- split = st.sidebar.radio("Split", sorted(dsets.datasets.keys()))
59
- dset = dsets.datasets[split]
60
- else:
61
- dset = next(iter(dsets.datasets.values()))
62
- batch_size = 1
63
- start_index = st.sidebar.number_input("Example Index (Size: {})".format(len(dset)), value=0,
64
- min_value=0,
65
- max_value=len(dset)-batch_size)
66
- indices = list(range(start_index, start_index+batch_size))
67
-
68
- example = default_collate([dset[i] for i in indices])
69
-
70
- x = model.get_input("image", example).to(model.device)
71
-
72
- cond_key = model.cond_stage_key
73
- c = model.get_input(cond_key, example).to(model.device)
74
-
75
- scale_factor = st.sidebar.slider("Scale Factor", min_value=0.5, max_value=4.0, step=0.25, value=1.00)
76
- if scale_factor != 1.0:
77
- x = torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="bicubic")
78
- c = torch.nn.functional.interpolate(c, scale_factor=scale_factor, mode="bicubic")
79
-
80
- quant_z, z_indices = model.encode_to_z(x)
81
- quant_c, c_indices = model.encode_to_c(c)
82
-
83
- cshape = quant_z.shape
84
-
85
- xrec = model.first_stage_model.decode(quant_z)
86
- st.write("image: {}".format(x.shape))
87
- st.image(bchw_to_st(x), clamp=True, output_format="PNG")
88
- st.write("image reconstruction: {}".format(xrec.shape))
89
- st.image(bchw_to_st(xrec), clamp=True, output_format="PNG")
90
-
91
- if cond_key == "segmentation":
92
- # get image from segmentation mask
93
- num_classes = c.shape[1]
94
- c = torch.argmax(c, dim=1, keepdim=True)
95
- c = torch.nn.functional.one_hot(c, num_classes=num_classes)
96
- c = c.squeeze(1).permute(0, 3, 1, 2).float()
97
- c = model.cond_stage_model.to_rgb(c)
98
-
99
- st.write(f"{cond_key}: {tuple(c.shape)}")
100
- st.image(bchw_to_st(c), clamp=True, output_format="PNG")
101
-
102
- idx = z_indices
103
-
104
- half_sample = st.sidebar.checkbox("Image Completion", value=False)
105
- if half_sample:
106
- start = idx.shape[1]//2
107
- else:
108
- start = 0
109
-
110
- idx[:,start:] = 0
111
- idx = idx.reshape(cshape[0],cshape[2],cshape[3])
112
- start_i = start//cshape[3]
113
- start_j = start %cshape[3]
114
-
115
- if not half_sample and quant_z.shape == quant_c.shape:
116
- st.info("Setting idx to c_indices")
117
- idx = c_indices.clone().reshape(cshape[0],cshape[2],cshape[3])
118
-
119
- cidx = c_indices
120
- cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
121
-
122
- xstart = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
123
- st.image(bchw_to_st(xstart), clamp=True, output_format="PNG")
124
-
125
- temperature = st.number_input("Temperature", value=1.0)
126
- top_k = st.number_input("Top k", value=100)
127
- sample = st.checkbox("Sample", value=True)
128
- update_every = st.number_input("Update every", value=75)
129
-
130
- st.text(f"Sampling shape ({cshape[2]},{cshape[3]})")
131
-
132
- animate = st.checkbox("animate")
133
- if animate:
134
- import imageio
135
- outvid = "sampling.mp4"
136
- writer = imageio.get_writer(outvid, fps=25)
137
- elapsed_t = st.empty()
138
- info = st.empty()
139
- st.text("Sampled")
140
- if st.button("Sample"):
141
- output = st.empty()
142
- start_t = time.time()
143
- for i in range(start_i,cshape[2]-0):
144
- if i <= 8:
145
- local_i = i
146
- elif cshape[2]-i < 8:
147
- local_i = 16-(cshape[2]-i)
148
- else:
149
- local_i = 8
150
- for j in range(start_j,cshape[3]-0):
151
- if j <= 8:
152
- local_j = j
153
- elif cshape[3]-j < 8:
154
- local_j = 16-(cshape[3]-j)
155
- else:
156
- local_j = 8
157
-
158
- i_start = i-local_i
159
- i_end = i_start+16
160
- j_start = j-local_j
161
- j_end = j_start+16
162
- elapsed_t.text(f"Time: {time.time() - start_t} seconds")
163
- info.text(f"Step: ({i},{j}) | Local: ({local_i},{local_j}) | Crop: ({i_start}:{i_end},{j_start}:{j_end})")
164
- patch = idx[:,i_start:i_end,j_start:j_end]
165
- patch = patch.reshape(patch.shape[0],-1)
166
- cpatch = cidx[:, i_start:i_end, j_start:j_end]
167
- cpatch = cpatch.reshape(cpatch.shape[0], -1)
168
- patch = torch.cat((cpatch, patch), dim=1)
169
- logits,_ = model.transformer(patch[:,:-1])
170
- logits = logits[:, -256:, :]
171
- logits = logits.reshape(cshape[0],16,16,-1)
172
- logits = logits[:,local_i,local_j,:]
173
-
174
- logits = logits/temperature
175
-
176
- if top_k is not None:
177
- logits = model.top_k_logits(logits, top_k)
178
- # apply softmax to convert to probabilities
179
- probs = torch.nn.functional.softmax(logits, dim=-1)
180
- # sample from the distribution or take the most likely
181
- if sample:
182
- ix = torch.multinomial(probs, num_samples=1)
183
- else:
184
- _, ix = torch.topk(probs, k=1, dim=-1)
185
- idx[:,i,j] = ix
186
-
187
- if (i*cshape[3]+j)%update_every==0:
188
- xstart = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape,)
189
-
190
- xstart = bchw_to_st(xstart)
191
- output.image(xstart, clamp=True, output_format="PNG")
192
-
193
- if animate:
194
- writer.append_data((xstart[0]*255).clip(0, 255).astype(np.uint8))
195
-
196
- xstart = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
197
- xstart = bchw_to_st(xstart)
198
- output.image(xstart, clamp=True, output_format="PNG")
199
- #save_img(xstart, "full_res_sample.png")
200
- if animate:
201
- writer.close()
202
- st.video(outvid)
203
-
204
-
205
- def get_parser():
206
- parser = argparse.ArgumentParser()
207
- parser.add_argument(
208
- "-r",
209
- "--resume",
210
- type=str,
211
- nargs="?",
212
- help="load from logdir or checkpoint in logdir",
213
- )
214
- parser.add_argument(
215
- "-b",
216
- "--base",
217
- nargs="*",
218
- metavar="base_config.yaml",
219
- help="paths to base configs. Loaded from left-to-right. "
220
- "Parameters can be overwritten or added with command-line options of the form `--key value`.",
221
- default=list(),
222
- )
223
- parser.add_argument(
224
- "-c",
225
- "--config",
226
- nargs="?",
227
- metavar="single_config.yaml",
228
- help="path to single config. If specified, base configs will be ignored "
229
- "(except for the last one if left unspecified).",
230
- const=True,
231
- default="",
232
- )
233
- parser.add_argument(
234
- "--ignore_base_data",
235
- action="store_true",
236
- help="Ignore data specification from base configs. Useful if you want "
237
- "to specify a custom datasets on the command line.",
238
- )
239
- return parser
240
-
241
-
242
- def load_model_from_config(config, sd, gpu=True, eval_mode=True):
243
- if "ckpt_path" in config.params:
244
- st.warning("Deleting the restore-ckpt path from the config...")
245
- config.params.ckpt_path = None
246
- if "downsample_cond_size" in config.params:
247
- st.warning("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
248
- config.params.downsample_cond_size = -1
249
- config.params["downsample_cond_factor"] = 0.5
250
- try:
251
- if "ckpt_path" in config.params.first_stage_config.params:
252
- config.params.first_stage_config.params.ckpt_path = None
253
- st.warning("Deleting the first-stage restore-ckpt path from the config...")
254
- if "ckpt_path" in config.params.cond_stage_config.params:
255
- config.params.cond_stage_config.params.ckpt_path = None
256
- st.warning("Deleting the cond-stage restore-ckpt path from the config...")
257
- except:
258
- pass
259
-
260
- model = instantiate_from_config(config)
261
- if sd is not None:
262
- missing, unexpected = model.load_state_dict(sd, strict=False)
263
- st.info(f"Missing Keys in State Dict: {missing}")
264
- st.info(f"Unexpected Keys in State Dict: {unexpected}")
265
- if gpu:
266
- model.cuda()
267
- if eval_mode:
268
- model.eval()
269
- return {"model": model}
270
-
271
-
272
- def get_data(config):
273
- # get data
274
- data = instantiate_from_config(config.data)
275
- data.prepare_data()
276
- data.setup()
277
- return data
278
-
279
-
280
- @st.cache(allow_output_mutation=True, suppress_st_warning=True)
281
- def load_model_and_dset(config, ckpt, gpu, eval_mode):
282
- # get data
283
- dsets = get_data(config) # calls data.config ...
284
-
285
- # now load the specified checkpoint
286
- if ckpt:
287
- pl_sd = torch.load(ckpt, map_location="cpu")
288
- global_step = pl_sd["global_step"]
289
- else:
290
- pl_sd = {"state_dict": None}
291
- global_step = None
292
- model = load_model_from_config(config.model,
293
- pl_sd["state_dict"],
294
- gpu=gpu,
295
- eval_mode=eval_mode)["model"]
296
- return dsets, model, global_step
297
-
298
-
299
- if __name__ == "__main__":
300
- sys.path.append(os.getcwd())
301
-
302
- parser = get_parser()
303
-
304
- opt, unknown = parser.parse_known_args()
305
-
306
- ckpt = None
307
- if opt.resume:
308
- if not os.path.exists(opt.resume):
309
- raise ValueError("Cannot find {}".format(opt.resume))
310
- if os.path.isfile(opt.resume):
311
- paths = opt.resume.split("/")
312
- try:
313
- idx = len(paths)-paths[::-1].index("logs")+1
314
- except ValueError:
315
- idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
316
- logdir = "/".join(paths[:idx])
317
- ckpt = opt.resume
318
- else:
319
- assert os.path.isdir(opt.resume), opt.resume
320
- logdir = opt.resume.rstrip("/")
321
- ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
322
- print(f"logdir:{logdir}")
323
- base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
324
- opt.base = base_configs+opt.base
325
-
326
- if opt.config:
327
- if type(opt.config) == str:
328
- opt.base = [opt.config]
329
- else:
330
- opt.base = [opt.base[-1]]
331
-
332
- configs = [OmegaConf.load(cfg) for cfg in opt.base]
333
- cli = OmegaConf.from_dotlist(unknown)
334
- if opt.ignore_base_data:
335
- for config in configs:
336
- if hasattr(config, "data"): del config["data"]
337
- config = OmegaConf.merge(*configs, cli)
338
-
339
- st.sidebar.text(ckpt)
340
- gs = st.sidebar.empty()
341
- gs.text(f"Global step: ?")
342
- st.sidebar.text("Options")
343
- #gpu = st.sidebar.checkbox("GPU", value=True)
344
- gpu = True
345
- #eval_mode = st.sidebar.checkbox("Eval Mode", value=True)
346
- eval_mode = True
347
- #show_config = st.sidebar.checkbox("Show Config", value=False)
348
- show_config = False
349
- if show_config:
350
- st.info("Checkpoint: {}".format(ckpt))
351
- st.json(OmegaConf.to_container(config))
352
-
353
- dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
354
- gs.text(f"Global step: {global_step}")
355
- run_conditional(model, dsets)
 
taming-transformers/scripts/sample_fast.py DELETED
@@ -1,260 +0,0 @@
1
- import argparse, os, sys, glob
2
- import torch
3
- import time
4
- import numpy as np
5
- from omegaconf import OmegaConf
6
- from PIL import Image
7
- from tqdm import tqdm, trange
8
- from einops import repeat
9
-
10
- from main import instantiate_from_config
11
- from taming.modules.transformer.mingpt import sample_with_past
12
-
13
-
14
- rescale = lambda x: (x + 1.) / 2.
15
-
16
-
17
- def chw_to_pillow(x):
18
- return Image.fromarray((255*rescale(x.detach().cpu().numpy().transpose(1,2,0))).clip(0,255).astype(np.uint8))
19
-
20
-
21
- @torch.no_grad()
22
- def sample_classconditional(model, batch_size, class_label, steps=256, temperature=None, top_k=None, callback=None,
23
- dim_z=256, h=16, w=16, verbose_time=False, top_p=None):
24
- log = dict()
25
- assert type(class_label) == int, f'expecting type int but type is {type(class_label)}'
26
- qzshape = [batch_size, dim_z, h, w]
27
- assert not model.be_unconditional, 'Expecting a class-conditional Net2NetTransformer.'
28
- c_indices = repeat(torch.tensor([class_label]), '1 -> b 1', b=batch_size).to(model.device) # class token
29
- t1 = time.time()
30
- index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
31
- sample_logits=True, top_k=top_k, callback=callback,
32
- temperature=temperature, top_p=top_p)
33
- if verbose_time:
34
- sampling_time = time.time() - t1
35
- print(f"Full sampling takes about {sampling_time:.2f} seconds.")
36
- x_sample = model.decode_to_img(index_sample, qzshape)
37
- log["samples"] = x_sample
38
- log["class_label"] = c_indices
39
- return log
40
-
41
-
42
- @torch.no_grad()
43
- def sample_unconditional(model, batch_size, steps=256, temperature=None, top_k=None, top_p=None, callback=None,
44
- dim_z=256, h=16, w=16, verbose_time=False):
45
- log = dict()
46
- qzshape = [batch_size, dim_z, h, w]
47
- assert model.be_unconditional, 'Expecting an unconditional model.'
48
- c_indices = repeat(torch.tensor([model.sos_token]), '1 -> b 1', b=batch_size).to(model.device) # sos token
49
- t1 = time.time()
50
- index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
51
- sample_logits=True, top_k=top_k, callback=callback,
52
- temperature=temperature, top_p=top_p)
53
- if verbose_time:
54
- sampling_time = time.time() - t1
55
- print(f"Full sampling takes about {sampling_time:.2f} seconds.")
56
- x_sample = model.decode_to_img(index_sample, qzshape)
57
- log["samples"] = x_sample
58
- return log
59
-
60
-
61
- @torch.no_grad()
62
- def run(logdir, model, batch_size, temperature, top_k, unconditional=True, num_samples=50000,
63
- given_classes=None, top_p=None):
64
- batches = [batch_size for _ in range(num_samples//batch_size)] + [num_samples % batch_size]
65
- if not unconditional:
66
- assert given_classes is not None
67
- print("Running in pure class-conditional sampling mode. I will produce "
68
- f"{num_samples} samples for each of the {len(given_classes)} classes, "
69
- f"i.e. {num_samples*len(given_classes)} in total.")
70
- for class_label in tqdm(given_classes, desc="Classes"):
71
- for n, bs in tqdm(enumerate(batches), desc="Sampling Class"):
72
- if bs == 0: break
73
- logs = sample_classconditional(model, batch_size=bs, class_label=class_label,
74
- temperature=temperature, top_k=top_k, top_p=top_p)
75
- save_from_logs(logs, logdir, base_count=n * batch_size, cond_key=logs["class_label"])
76
- else:
77
- print(f"Running in unconditional sampling mode, producing {num_samples} samples.")
78
- for n, bs in tqdm(enumerate(batches), desc="Sampling"):
79
- if bs == 0: break
80
- logs = sample_unconditional(model, batch_size=bs, temperature=temperature, top_k=top_k, top_p=top_p)
81
- save_from_logs(logs, logdir, base_count=n * batch_size)
82
-
83
-
84
- def save_from_logs(logs, logdir, base_count, key="samples", cond_key=None):
85
- xx = logs[key]
86
- for i, x in enumerate(xx):
87
- x = chw_to_pillow(x)
88
- count = base_count + i
89
- if cond_key is None:
90
- x.save(os.path.join(logdir, f"{count:06}.png"))
91
- else:
92
- condlabel = cond_key[i]
93
- if type(condlabel) == torch.Tensor: condlabel = condlabel.item()
94
- os.makedirs(os.path.join(logdir, str(condlabel)), exist_ok=True)
95
- x.save(os.path.join(logdir, str(condlabel), f"{count:06}.png"))
96
-
97
-
98
- def get_parser():
99
- def str2bool(v):
100
- if isinstance(v, bool):
101
- return v
102
- if v.lower() in ("yes", "true", "t", "y", "1"):
103
- return True
104
- elif v.lower() in ("no", "false", "f", "n", "0"):
105
- return False
106
- else:
107
- raise argparse.ArgumentTypeError("Boolean value expected.")
108
-
109
- parser = argparse.ArgumentParser()
110
- parser.add_argument(
111
- "-r",
112
- "--resume",
113
- type=str,
114
- nargs="?",
115
- help="load from logdir or checkpoint in logdir",
116
- )
117
- parser.add_argument(
118
- "-o",
119
- "--outdir",
120
- type=str,
121
- nargs="?",
122
- help="path where the samples will be logged to.",
123
- default=""
124
- )
125
- parser.add_argument(
126
- "-b",
127
- "--base",
128
- nargs="*",
129
- metavar="base_config.yaml",
130
- help="paths to base configs. Loaded from left-to-right. "
131
- "Parameters can be overwritten or added with command-line options of the form `--key value`.",
132
- default=list(),
133
- )
134
- parser.add_argument(
135
- "-n",
136
- "--num_samples",
137
- type=int,
138
- nargs="?",
139
- help="num_samples to draw",
140
- default=50000
141
- )
142
- parser.add_argument(
143
- "--batch_size",
144
- type=int,
145
- nargs="?",
146
- help="the batch size",
147
- default=25
148
- )
149
- parser.add_argument(
150
- "-k",
151
- "--top_k",
152
- type=int,
153
- nargs="?",
154
- help="top-k value to sample with",
155
- default=250,
156
- )
157
- parser.add_argument(
158
- "-t",
159
- "--temperature",
160
- type=float,
161
- nargs="?",
162
- help="temperature value to sample with",
163
- default=1.0
164
- )
165
- parser.add_argument(
166
- "-p",
167
- "--top_p",
168
- type=float,
169
- nargs="?",
170
- help="top-p value to sample with",
171
- default=1.0
172
- )
173
- parser.add_argument(
174
- "--classes",
175
- type=str,
176
- nargs="?",
177
- help="specify comma-separated classes to sample from. Uses 1000 classes per default.",
178
- default="imagenet"
179
- )
180
- return parser
181
-
182
-
183
- def load_model_from_config(config, sd, gpu=True, eval_mode=True):
184
- model = instantiate_from_config(config)
185
- if sd is not None:
186
- model.load_state_dict(sd)
187
- if gpu:
188
- model.cuda()
189
- if eval_mode:
190
- model.eval()
191
- return {"model": model}
192
-
193
-
194
- def load_model(config, ckpt, gpu, eval_mode):
195
- # load the specified checkpoint
196
- if ckpt:
197
- pl_sd = torch.load(ckpt, map_location="cpu")
198
- global_step = pl_sd["global_step"]
199
- print(f"loaded model from global step {global_step}.")
200
- else:
201
- pl_sd = {"state_dict": None}
202
- global_step = None
203
- model = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"]
204
- return model, global_step
205
-
206
-
207
- if __name__ == "__main__":
208
- sys.path.append(os.getcwd())
209
- parser = get_parser()
210
-
211
- opt, unknown = parser.parse_known_args()
212
- assert opt.resume
213
-
214
- ckpt = None
215
-
216
- if not os.path.exists(opt.resume):
217
- raise ValueError("Cannot find {}".format(opt.resume))
218
- if os.path.isfile(opt.resume):
219
- paths = opt.resume.split("/")
220
- try:
221
- idx = len(paths)-paths[::-1].index("logs")+1
222
- except ValueError:
223
- idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
224
- logdir = "/".join(paths[:idx])
225
- ckpt = opt.resume
226
- else:
227
- assert os.path.isdir(opt.resume), opt.resume
228
- logdir = opt.resume.rstrip("/")
229
- ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
230
-
231
- base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
232
- opt.base = base_configs+opt.base
233
-
234
- configs = [OmegaConf.load(cfg) for cfg in opt.base]
235
- cli = OmegaConf.from_dotlist(unknown)
236
- config = OmegaConf.merge(*configs, cli)
237
-
238
- model, global_step = load_model(config, ckpt, gpu=True, eval_mode=True)
239
-
240
- if opt.outdir:
241
- print(f"Switching logdir from '{logdir}' to '{opt.outdir}'")
242
- logdir = opt.outdir
243
-
244
- if opt.classes == "imagenet":
245
- given_classes = [i for i in range(1000)]
246
- else:
247
- cls_str = opt.classes
248
- assert not cls_str.endswith(","), 'class string should not end with a ","'
249
- given_classes = [int(c) for c in cls_str.split(",")]
250
-
251
- logdir = os.path.join(logdir, "samples", f"top_k_{opt.top_k}_temp_{opt.temperature:.2f}_top_p_{opt.top_p}",
252
- f"{global_step}")
253
-
254
- print(f"Logging to {logdir}")
255
- os.makedirs(logdir, exist_ok=True)
256
-
257
- run(logdir, model, opt.batch_size, opt.temperature, opt.top_k, unconditional=model.be_unconditional,
258
- given_classes=given_classes, num_samples=opt.num_samples, top_p=opt.top_p)
259
-
260
- print("done.")
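
The -k/--top_k and -p/--top_p flags above truncate the next-token distribution before sampling. For reference, here is a minimal self-contained sketch of that filtering step in plain PyTorch; it is not the repo's sample_with_past helper, and the function name and default values are illustrative only:

import torch
import torch.nn.functional as F

def filter_logits(logits, top_k=250, top_p=1.0):
    # logits: (batch, vocab) next-token scores; returns a filtered copy.
    logits = logits.clone()
    if top_k is not None:
        top_k = min(top_k, logits.size(-1))
        kth = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth, -float("inf"))   # keep only the k largest
    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()                  # shift so the token crossing the threshold is kept
        remove[..., 0] = False                                      # always keep the most likely token
        remove = torch.zeros_like(remove).scatter(-1, sorted_idx, remove)
        logits = logits.masked_fill(remove, -float("inf"))
    return logits

# usage sketch: scale by temperature, filter, then sample
# probs = F.softmax(filter_logits(logits / temperature, top_k=250, top_p=0.92), dim=-1)
# next_token = torch.multinomial(probs, num_samples=1)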
 
taming-transformers/setup.py DELETED
@@ -1,13 +0,0 @@
- from setuptools import setup, find_packages
-
- setup(
-     name='taming-transformers',
-     version='0.0.1',
-     description='Taming Transformers for High-Resolution Image Synthesis',
-     packages=find_packages(),
-     install_requires=[
-         'torch',
-         'numpy',
-         'tqdm',
-     ],
- )
 
taming-transformers/taming/lr_scheduler.py DELETED
@@ -1,34 +0,0 @@
- import numpy as np
-
-
- class LambdaWarmUpCosineScheduler:
-     """
-     note: use with a base_lr of 1.0
-     """
-     def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
-         self.lr_warm_up_steps = warm_up_steps
-         self.lr_start = lr_start
-         self.lr_min = lr_min
-         self.lr_max = lr_max
-         self.lr_max_decay_steps = max_decay_steps
-         self.last_lr = 0.
-         self.verbosity_interval = verbosity_interval
-
-     def schedule(self, n):
-         if self.verbosity_interval > 0:
-             if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
-         if n < self.lr_warm_up_steps:
-             lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
-             self.last_lr = lr
-             return lr
-         else:
-             t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
-             t = min(t, 1.0)
-             lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
-                     1 + np.cos(t * np.pi))
-             self.last_lr = lr
-             return lr
-
-     def __call__(self, n):
-         return self.schedule(n)
-
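
Because schedule(n) returns the learning rate itself (hence the "use with a base_lr of 1.0" note), the class is naturally wrapped in torch.optim.lr_scheduler.LambdaLR. A minimal usage sketch, assuming the class above is still importable and with purely illustrative hyperparameter values:

import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(8, 8)                              # stand-in module
optimizer = torch.optim.Adam(model.parameters(), lr=1.0)   # base_lr of 1.0, per the docstring

# linear warm-up to lr_max, then half-cosine decay towards lr_min (values are examples only)
schedule = LambdaWarmUpCosineScheduler(warm_up_steps=10_000, lr_min=1e-6,
                                       lr_max=4.5e-6, lr_start=1e-7,
                                       max_decay_steps=500_000)
scheduler = LambdaLR(optimizer, lr_lambda=schedule)        # schedule(step) becomes the effective lr

for step in range(3):                                      # training-loop stub
    optimizer.step()
    scheduler.step()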
 
taming-transformers/taming/models/cond_transformer.py DELETED
@@ -1,352 +0,0 @@
1
- import os, math
2
- import torch
3
- import torch.nn.functional as F
4
- import pytorch_lightning as pl
5
-
6
- from main import instantiate_from_config
7
- from taming.modules.util import SOSProvider
8
-
9
-
10
- def disabled_train(self, mode=True):
11
- """Overwrite model.train with this function to make sure train/eval mode
12
- does not change anymore."""
13
- return self
14
-
15
-
16
- class Net2NetTransformer(pl.LightningModule):
17
- def __init__(self,
18
- transformer_config,
19
- first_stage_config,
20
- cond_stage_config,
21
- permuter_config=None,
22
- ckpt_path=None,
23
- ignore_keys=[],
24
- first_stage_key="image",
25
- cond_stage_key="depth",
26
- downsample_cond_size=-1,
27
- pkeep=1.0,
28
- sos_token=0,
29
- unconditional=False,
30
- ):
31
- super().__init__()
32
- self.be_unconditional = unconditional
33
- self.sos_token = sos_token
34
- self.first_stage_key = first_stage_key
35
- self.cond_stage_key = cond_stage_key
36
- self.init_first_stage_from_ckpt(first_stage_config)
37
- self.init_cond_stage_from_ckpt(cond_stage_config)
38
- if permuter_config is None:
39
- permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
40
- self.permuter = instantiate_from_config(config=permuter_config)
41
- self.transformer = instantiate_from_config(config=transformer_config)
42
-
43
- if ckpt_path is not None:
44
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
45
- self.downsample_cond_size = downsample_cond_size
46
- self.pkeep = pkeep
47
-
48
- def init_from_ckpt(self, path, ignore_keys=list()):
49
- sd = torch.load(path, map_location="cpu")["state_dict"]
50
- for k in sd.keys():
51
- for ik in ignore_keys:
52
- if k.startswith(ik):
53
- self.print("Deleting key {} from state_dict.".format(k))
54
- del sd[k]
55
- self.load_state_dict(sd, strict=False)
56
- print(f"Restored from {path}")
57
-
58
- def init_first_stage_from_ckpt(self, config):
59
- model = instantiate_from_config(config)
60
- model = model.eval()
61
- model.train = disabled_train
62
- self.first_stage_model = model
63
-
64
- def init_cond_stage_from_ckpt(self, config):
65
- if config == "__is_first_stage__":
66
- print("Using first stage also as cond stage.")
67
- self.cond_stage_model = self.first_stage_model
68
- elif config == "__is_unconditional__" or self.be_unconditional:
69
- print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
70
- f"Prepending {self.sos_token} as a sos token.")
71
- self.be_unconditional = True
72
- self.cond_stage_key = self.first_stage_key
73
- self.cond_stage_model = SOSProvider(self.sos_token)
74
- else:
75
- model = instantiate_from_config(config)
76
- model = model.eval()
77
- model.train = disabled_train
78
- self.cond_stage_model = model
79
-
80
- def forward(self, x, c):
81
- # one step to produce the logits
82
- _, z_indices = self.encode_to_z(x)
83
- _, c_indices = self.encode_to_c(c)
84
-
85
- if self.training and self.pkeep < 1.0:
86
- mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
87
- device=z_indices.device))
88
- mask = mask.round().to(dtype=torch.int64)
89
- r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
90
- a_indices = mask*z_indices+(1-mask)*r_indices
91
- else:
92
- a_indices = z_indices
93
-
94
- cz_indices = torch.cat((c_indices, a_indices), dim=1)
95
-
96
- # target includes all sequence elements (no need to handle first one
97
- # differently because we are conditioning)
98
- target = z_indices
99
- # make the prediction
100
- logits, _ = self.transformer(cz_indices[:, :-1])
101
- # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
102
- logits = logits[:, c_indices.shape[1]-1:]
103
-
104
- return logits, target
105
-
106
- def top_k_logits(self, logits, k):
107
- v, ix = torch.topk(logits, k)
108
- out = logits.clone()
109
- out[out < v[..., [-1]]] = -float('Inf')
110
- return out
111
-
112
- @torch.no_grad()
113
- def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
114
- callback=lambda k: None):
115
- x = torch.cat((c,x),dim=1)
116
- block_size = self.transformer.get_block_size()
117
- assert not self.transformer.training
118
- if self.pkeep <= 0.0:
119
- # one pass suffices since input is pure noise anyway
120
- assert len(x.shape)==2
121
- noise_shape = (x.shape[0], steps-1)
122
- #noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
123
- noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
124
- x = torch.cat((x,noise),dim=1)
125
- logits, _ = self.transformer(x)
126
- # take all logits for now and scale by temp
127
- logits = logits / temperature
128
- # optionally crop probabilities to only the top k options
129
- if top_k is not None:
130
- logits = self.top_k_logits(logits, top_k)
131
- # apply softmax to convert to probabilities
132
- probs = F.softmax(logits, dim=-1)
133
- # sample from the distribution or take the most likely
134
- if sample:
135
- shape = probs.shape
136
- probs = probs.reshape(shape[0]*shape[1],shape[2])
137
- ix = torch.multinomial(probs, num_samples=1)
138
- probs = probs.reshape(shape[0],shape[1],shape[2])
139
- ix = ix.reshape(shape[0],shape[1])
140
- else:
141
- _, ix = torch.topk(probs, k=1, dim=-1)
142
- # cut off conditioning
143
- x = ix[:, c.shape[1]-1:]
144
- else:
145
- for k in range(steps):
146
- callback(k)
147
- assert x.size(1) <= block_size # make sure model can see conditioning
148
- x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
149
- logits, _ = self.transformer(x_cond)
150
- # pluck the logits at the final step and scale by temperature
151
- logits = logits[:, -1, :] / temperature
152
- # optionally crop probabilities to only the top k options
153
- if top_k is not None:
154
- logits = self.top_k_logits(logits, top_k)
155
- # apply softmax to convert to probabilities
156
- probs = F.softmax(logits, dim=-1)
157
- # sample from the distribution or take the most likely
158
- if sample:
159
- ix = torch.multinomial(probs, num_samples=1)
160
- else:
161
- _, ix = torch.topk(probs, k=1, dim=-1)
162
- # append to the sequence and continue
163
- x = torch.cat((x, ix), dim=1)
164
- # cut off conditioning
165
- x = x[:, c.shape[1]:]
166
- return x
167
-
168
- @torch.no_grad()
169
- def encode_to_z(self, x):
170
- quant_z, _, info = self.first_stage_model.encode(x)
171
- indices = info[2].view(quant_z.shape[0], -1)
172
- indices = self.permuter(indices)
173
- return quant_z, indices
174
-
175
- @torch.no_grad()
176
- def encode_to_c(self, c):
177
- if self.downsample_cond_size > -1:
178
- c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
179
- quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
180
- if len(indices.shape) > 2:
181
- indices = indices.view(c.shape[0], -1)
182
- return quant_c, indices
183
-
184
- @torch.no_grad()
185
- def decode_to_img(self, index, zshape):
186
- index = self.permuter(index, reverse=True)
187
- bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
188
- quant_z = self.first_stage_model.quantize.get_codebook_entry(
189
- index.reshape(-1), shape=bhwc)
190
- x = self.first_stage_model.decode(quant_z)
191
- return x
192
-
193
- @torch.no_grad()
194
- def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
195
- log = dict()
196
-
197
- N = 4
198
- if lr_interface:
199
- x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
200
- else:
201
- x, c = self.get_xc(batch, N)
202
- x = x.to(device=self.device)
203
- c = c.to(device=self.device)
204
-
205
- quant_z, z_indices = self.encode_to_z(x)
206
- quant_c, c_indices = self.encode_to_c(c)
207
-
208
- # create a "half"" sample
209
- z_start_indices = z_indices[:,:z_indices.shape[1]//2]
210
- index_sample = self.sample(z_start_indices, c_indices,
211
- steps=z_indices.shape[1]-z_start_indices.shape[1],
212
- temperature=temperature if temperature is not None else 1.0,
213
- sample=True,
214
- top_k=top_k if top_k is not None else 100,
215
- callback=callback if callback is not None else lambda k: None)
216
- x_sample = self.decode_to_img(index_sample, quant_z.shape)
217
-
218
- # sample
219
- z_start_indices = z_indices[:, :0]
220
- index_sample = self.sample(z_start_indices, c_indices,
221
- steps=z_indices.shape[1],
222
- temperature=temperature if temperature is not None else 1.0,
223
- sample=True,
224
- top_k=top_k if top_k is not None else 100,
225
- callback=callback if callback is not None else lambda k: None)
226
- x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
227
-
228
- # det sample
229
- z_start_indices = z_indices[:, :0]
230
- index_sample = self.sample(z_start_indices, c_indices,
231
- steps=z_indices.shape[1],
232
- sample=False,
233
- callback=callback if callback is not None else lambda k: None)
234
- x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
235
-
236
- # reconstruction
237
- x_rec = self.decode_to_img(z_indices, quant_z.shape)
238
-
239
- log["inputs"] = x
240
- log["reconstructions"] = x_rec
241
-
242
- if self.cond_stage_key in ["objects_bbox", "objects_center_points"]:
243
- figure_size = (x_rec.shape[2], x_rec.shape[3])
244
- dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"]
245
- label_for_category_no = dataset.get_textual_label_for_category_no
246
- plotter = dataset.conditional_builders[self.cond_stage_key].plot
247
- log["conditioning"] = torch.zeros_like(log["reconstructions"])
248
- for i in range(quant_c.shape[0]):
249
- log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size)
250
- log["conditioning_rec"] = log["conditioning"]
251
- elif self.cond_stage_key != "image":
252
- cond_rec = self.cond_stage_model.decode(quant_c)
253
- if self.cond_stage_key == "segmentation":
254
- # get image from segmentation mask
255
- num_classes = cond_rec.shape[1]
256
-
257
- c = torch.argmax(c, dim=1, keepdim=True)
258
- c = F.one_hot(c, num_classes=num_classes)
259
- c = c.squeeze(1).permute(0, 3, 1, 2).float()
260
- c = self.cond_stage_model.to_rgb(c)
261
-
262
- cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
263
- cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
264
- cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
265
- cond_rec = self.cond_stage_model.to_rgb(cond_rec)
266
- log["conditioning_rec"] = cond_rec
267
- log["conditioning"] = c
268
-
269
- log["samples_half"] = x_sample
270
- log["samples_nopix"] = x_sample_nopix
271
- log["samples_det"] = x_sample_det
272
- return log
273
-
274
- def get_input(self, key, batch):
275
- x = batch[key]
276
- if len(x.shape) == 3:
277
- x = x[..., None]
278
- if len(x.shape) == 4:
279
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
280
- if x.dtype == torch.double:
281
- x = x.float()
282
- return x
283
-
284
- def get_xc(self, batch, N=None):
285
- x = self.get_input(self.first_stage_key, batch)
286
- c = self.get_input(self.cond_stage_key, batch)
287
- if N is not None:
288
- x = x[:N]
289
- c = c[:N]
290
- return x, c
291
-
292
- def shared_step(self, batch, batch_idx):
293
- x, c = self.get_xc(batch)
294
- logits, target = self(x, c)
295
- loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
296
- return loss
297
-
298
- def training_step(self, batch, batch_idx):
299
- loss = self.shared_step(batch, batch_idx)
300
- self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
301
- return loss
302
-
303
- def validation_step(self, batch, batch_idx):
304
- loss = self.shared_step(batch, batch_idx)
305
- self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
306
- return loss
307
-
308
- def configure_optimizers(self):
309
- """
310
- Following minGPT:
311
- This long function is unfortunately doing something very simple and is being very defensive:
312
- We are separating out all parameters of the model into two buckets: those that will experience
313
- weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
314
- We are then returning the PyTorch optimizer object.
315
- """
316
- # separate out all parameters to those that will and won't experience regularizing weight decay
317
- decay = set()
318
- no_decay = set()
319
- whitelist_weight_modules = (torch.nn.Linear, )
320
- blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
321
- for mn, m in self.transformer.named_modules():
322
- for pn, p in m.named_parameters():
323
- fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
324
-
325
- if pn.endswith('bias'):
326
- # all biases will not be decayed
327
- no_decay.add(fpn)
328
- elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
329
- # weights of whitelist modules will be weight decayed
330
- decay.add(fpn)
331
- elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
332
- # weights of blacklist modules will NOT be weight decayed
333
- no_decay.add(fpn)
334
-
335
- # special case the position embedding parameter in the root GPT module as not decayed
336
- no_decay.add('pos_emb')
337
-
338
- # validate that we considered every parameter
339
- param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
340
- inter_params = decay & no_decay
341
- union_params = decay | no_decay
342
- assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
343
- assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
344
- % (str(param_dict.keys() - union_params), )
345
-
346
- # create the pytorch optimizer object
347
- optim_groups = [
348
- {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
349
- {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
350
- ]
351
- optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
352
- return optimizer
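
For context, the pkeep augmentation in Net2NetTransformer.forward above boils down to the following standalone operation: each first-stage token is kept with probability pkeep and otherwise replaced by a random codebook index, while the prediction target stays uncorrupted. A minimal sketch, not the module itself:

import torch

def corrupt_indices(z_indices, pkeep, vocab_size):
    # z_indices: (batch, seq_len) integer codebook indices from the first stage.
    keep = torch.bernoulli(pkeep * torch.ones_like(z_indices, dtype=torch.float))
    keep = keep.round().to(dtype=torch.int64)
    random_indices = torch.randint_like(z_indices, vocab_size)
    return keep * z_indices + (1 - keep) * random_indices

# e.g. with pkeep=0.9, roughly 10% of the tokens the transformer conditions on are noise,
# while the cross-entropy target remains the original z_indices.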
 
taming-transformers/taming/models/dummy_cond_stage.py DELETED
@@ -1,22 +0,0 @@
- from torch import Tensor
-
-
- class DummyCondStage:
-     def __init__(self, conditional_key):
-         self.conditional_key = conditional_key
-         self.train = None
-
-     def eval(self):
-         return self
-
-     @staticmethod
-     def encode(c: Tensor):
-         return c, None, (None, None, c)
-
-     @staticmethod
-     def decode(c: Tensor):
-         return c
-
-     @staticmethod
-     def to_rgb(c: Tensor):
-         return c
 
taming-transformers/taming/models/vqgan.py DELETED
@@ -1,404 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import pytorch_lightning as pl
4
-
5
- from main import instantiate_from_config
6
-
7
- from taming.modules.diffusionmodules.model import Encoder, Decoder
8
- from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
9
- from taming.modules.vqvae.quantize import GumbelQuantize
10
- from taming.modules.vqvae.quantize import EMAVectorQuantizer
11
-
12
- class VQModel(pl.LightningModule):
13
- def __init__(self,
14
- ddconfig,
15
- lossconfig,
16
- n_embed,
17
- embed_dim,
18
- ckpt_path=None,
19
- ignore_keys=[],
20
- image_key="image",
21
- colorize_nlabels=None,
22
- monitor=None,
23
- remap=None,
24
- sane_index_shape=False, # tell vector quantizer to return indices as bhw
25
- ):
26
- super().__init__()
27
- self.image_key = image_key
28
- self.encoder = Encoder(**ddconfig)
29
- self.decoder = Decoder(**ddconfig)
30
- self.loss = instantiate_from_config(lossconfig)
31
- self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
32
- remap=remap, sane_index_shape=sane_index_shape)
33
- self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
34
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
35
- if ckpt_path is not None:
36
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
37
- self.image_key = image_key
38
- if colorize_nlabels is not None:
39
- assert type(colorize_nlabels)==int
40
- self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
41
- if monitor is not None:
42
- self.monitor = monitor
43
-
44
- def init_from_ckpt(self, path, ignore_keys=list()):
45
- sd = torch.load(path, map_location="cpu")["state_dict"]
46
- keys = list(sd.keys())
47
- for k in keys:
48
- for ik in ignore_keys:
49
- if k.startswith(ik):
50
- print("Deleting key {} from state_dict.".format(k))
51
- del sd[k]
52
- self.load_state_dict(sd, strict=False)
53
- print(f"Restored from {path}")
54
-
55
- def encode(self, x):
56
- h = self.encoder(x)
57
- h = self.quant_conv(h)
58
- quant, emb_loss, info = self.quantize(h)
59
- return quant, emb_loss, info
60
-
61
- def decode(self, quant):
62
- quant = self.post_quant_conv(quant)
63
- dec = self.decoder(quant)
64
- return dec
65
-
66
- def decode_code(self, code_b):
67
- quant_b = self.quantize.embed_code(code_b)
68
- dec = self.decode(quant_b)
69
- return dec
70
-
71
- def forward(self, input):
72
- quant, diff, _ = self.encode(input)
73
- dec = self.decode(quant)
74
- return dec, diff
75
-
76
- def get_input(self, batch, k):
77
- x = batch[k]
78
- if len(x.shape) == 3:
79
- x = x[..., None]
80
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
81
- return x.float()
82
-
83
- def training_step(self, batch, batch_idx, optimizer_idx):
84
- x = self.get_input(batch, self.image_key)
85
- xrec, qloss = self(x)
86
-
87
- if optimizer_idx == 0:
88
- # autoencode
89
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
90
- last_layer=self.get_last_layer(), split="train")
91
-
92
- self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
93
- self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
94
- return aeloss
95
-
96
- if optimizer_idx == 1:
97
- # discriminator
98
- discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
99
- last_layer=self.get_last_layer(), split="train")
100
- self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
101
- self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
102
- return discloss
103
-
104
- def validation_step(self, batch, batch_idx):
105
- x = self.get_input(batch, self.image_key)
106
- xrec, qloss = self(x)
107
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
108
- last_layer=self.get_last_layer(), split="val")
109
-
110
- discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
111
- last_layer=self.get_last_layer(), split="val")
112
- rec_loss = log_dict_ae["val/rec_loss"]
113
- self.log("val/rec_loss", rec_loss,
114
- prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
115
- self.log("val/aeloss", aeloss,
116
- prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
117
- self.log_dict(log_dict_ae)
118
- self.log_dict(log_dict_disc)
119
- return self.log_dict
120
-
121
- def configure_optimizers(self):
122
- lr = self.learning_rate
123
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
124
- list(self.decoder.parameters())+
125
- list(self.quantize.parameters())+
126
- list(self.quant_conv.parameters())+
127
- list(self.post_quant_conv.parameters()),
128
- lr=lr, betas=(0.5, 0.9))
129
- opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
130
- lr=lr, betas=(0.5, 0.9))
131
- return [opt_ae, opt_disc], []
132
-
133
- def get_last_layer(self):
134
- return self.decoder.conv_out.weight
135
-
136
- def log_images(self, batch, **kwargs):
137
- log = dict()
138
- x = self.get_input(batch, self.image_key)
139
- x = x.to(self.device)
140
- xrec, _ = self(x)
141
- if x.shape[1] > 3:
142
- # colorize with random projection
143
- assert xrec.shape[1] > 3
144
- x = self.to_rgb(x)
145
- xrec = self.to_rgb(xrec)
146
- log["inputs"] = x
147
- log["reconstructions"] = xrec
148
- return log
149
-
150
- def to_rgb(self, x):
151
- assert self.image_key == "segmentation"
152
- if not hasattr(self, "colorize"):
153
- self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
154
- x = F.conv2d(x, weight=self.colorize)
155
- x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
156
- return x
157
-
158
-
159
- class VQSegmentationModel(VQModel):
160
- def __init__(self, n_labels, *args, **kwargs):
161
- super().__init__(*args, **kwargs)
162
- self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
163
-
164
- def configure_optimizers(self):
165
- lr = self.learning_rate
166
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
167
- list(self.decoder.parameters())+
168
- list(self.quantize.parameters())+
169
- list(self.quant_conv.parameters())+
170
- list(self.post_quant_conv.parameters()),
171
- lr=lr, betas=(0.5, 0.9))
172
- return opt_ae
173
-
174
- def training_step(self, batch, batch_idx):
175
- x = self.get_input(batch, self.image_key)
176
- xrec, qloss = self(x)
177
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
178
- self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
179
- return aeloss
180
-
181
- def validation_step(self, batch, batch_idx):
182
- x = self.get_input(batch, self.image_key)
183
- xrec, qloss = self(x)
184
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
185
- self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
186
- total_loss = log_dict_ae["val/total_loss"]
187
- self.log("val/total_loss", total_loss,
188
- prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
189
- return aeloss
190
-
191
- @torch.no_grad()
192
- def log_images(self, batch, **kwargs):
193
- log = dict()
194
- x = self.get_input(batch, self.image_key)
195
- x = x.to(self.device)
196
- xrec, _ = self(x)
197
- if x.shape[1] > 3:
198
- # colorize with random projection
199
- assert xrec.shape[1] > 3
200
- # convert logits to indices
201
- xrec = torch.argmax(xrec, dim=1, keepdim=True)
202
- xrec = F.one_hot(xrec, num_classes=x.shape[1])
203
- xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
204
- x = self.to_rgb(x)
205
- xrec = self.to_rgb(xrec)
206
- log["inputs"] = x
207
- log["reconstructions"] = xrec
208
- return log
209
-
210
-
211
- class VQNoDiscModel(VQModel):
212
- def __init__(self,
213
- ddconfig,
214
- lossconfig,
215
- n_embed,
216
- embed_dim,
217
- ckpt_path=None,
218
- ignore_keys=[],
219
- image_key="image",
220
- colorize_nlabels=None
221
- ):
222
- super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
223
- ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
224
- colorize_nlabels=colorize_nlabels)
225
-
226
- def training_step(self, batch, batch_idx):
227
- x = self.get_input(batch, self.image_key)
228
- xrec, qloss = self(x)
229
- # autoencode
230
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
231
- output = pl.TrainResult(minimize=aeloss)
232
- output.log("train/aeloss", aeloss,
233
- prog_bar=True, logger=True, on_step=True, on_epoch=True)
234
- output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
235
- return output
236
-
237
- def validation_step(self, batch, batch_idx):
238
- x = self.get_input(batch, self.image_key)
239
- xrec, qloss = self(x)
240
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
241
- rec_loss = log_dict_ae["val/rec_loss"]
242
- output = pl.EvalResult(checkpoint_on=rec_loss)
243
- output.log("val/rec_loss", rec_loss,
244
- prog_bar=True, logger=True, on_step=True, on_epoch=True)
245
- output.log("val/aeloss", aeloss,
246
- prog_bar=True, logger=True, on_step=True, on_epoch=True)
247
- output.log_dict(log_dict_ae)
248
-
249
- return output
250
-
251
- def configure_optimizers(self):
252
- optimizer = torch.optim.Adam(list(self.encoder.parameters())+
253
- list(self.decoder.parameters())+
254
- list(self.quantize.parameters())+
255
- list(self.quant_conv.parameters())+
256
- list(self.post_quant_conv.parameters()),
257
- lr=self.learning_rate, betas=(0.5, 0.9))
258
- return optimizer
259
-
260
-
261
- class GumbelVQ(VQModel):
262
- def __init__(self,
263
- ddconfig,
264
- lossconfig,
265
- n_embed,
266
- embed_dim,
267
- temperature_scheduler_config,
268
- ckpt_path=None,
269
- ignore_keys=[],
270
- image_key="image",
271
- colorize_nlabels=None,
272
- monitor=None,
273
- kl_weight=1e-8,
274
- remap=None,
275
- ):
276
-
277
- z_channels = ddconfig["z_channels"]
278
- super().__init__(ddconfig,
279
- lossconfig,
280
- n_embed,
281
- embed_dim,
282
- ckpt_path=None,
283
- ignore_keys=ignore_keys,
284
- image_key=image_key,
285
- colorize_nlabels=colorize_nlabels,
286
- monitor=monitor,
287
- )
288
-
289
- self.loss.n_classes = n_embed
290
- self.vocab_size = n_embed
291
-
292
- self.quantize = GumbelQuantize(z_channels, embed_dim,
293
- n_embed=n_embed,
294
- kl_weight=kl_weight, temp_init=1.0,
295
- remap=remap)
296
-
297
- self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp
298
-
299
- if ckpt_path is not None:
300
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
301
-
302
- def temperature_scheduling(self):
303
- self.quantize.temperature = self.temperature_scheduler(self.global_step)
304
-
305
- def encode_to_prequant(self, x):
306
- h = self.encoder(x)
307
- h = self.quant_conv(h)
308
- return h
309
-
310
- def decode_code(self, code_b):
311
- raise NotImplementedError
312
-
313
- def training_step(self, batch, batch_idx, optimizer_idx):
314
- self.temperature_scheduling()
315
- x = self.get_input(batch, self.image_key)
316
- xrec, qloss = self(x)
317
-
318
- if optimizer_idx == 0:
319
- # autoencode
320
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
321
- last_layer=self.get_last_layer(), split="train")
322
-
323
- self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
324
- self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
325
- return aeloss
326
-
327
- if optimizer_idx == 1:
328
- # discriminator
329
- discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
330
- last_layer=self.get_last_layer(), split="train")
331
- self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
332
- return discloss
333
-
334
- def validation_step(self, batch, batch_idx):
335
- x = self.get_input(batch, self.image_key)
336
- xrec, qloss = self(x, return_pred_indices=True)
337
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
338
- last_layer=self.get_last_layer(), split="val")
339
-
340
- discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
341
- last_layer=self.get_last_layer(), split="val")
342
- rec_loss = log_dict_ae["val/rec_loss"]
343
- self.log("val/rec_loss", rec_loss,
344
- prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
345
- self.log("val/aeloss", aeloss,
346
- prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
347
- self.log_dict(log_dict_ae)
348
- self.log_dict(log_dict_disc)
349
- return self.log_dict
350
-
351
- def log_images(self, batch, **kwargs):
352
- log = dict()
353
- x = self.get_input(batch, self.image_key)
354
- x = x.to(self.device)
355
- # encode
356
- h = self.encoder(x)
357
- h = self.quant_conv(h)
358
- quant, _, _ = self.quantize(h)
359
- # decode
360
- x_rec = self.decode(quant)
361
- log["inputs"] = x
362
- log["reconstructions"] = x_rec
363
- return log
364
-
365
-
366
- class EMAVQ(VQModel):
367
- def __init__(self,
368
- ddconfig,
369
- lossconfig,
370
- n_embed,
371
- embed_dim,
372
- ckpt_path=None,
373
- ignore_keys=[],
374
- image_key="image",
375
- colorize_nlabels=None,
376
- monitor=None,
377
- remap=None,
378
- sane_index_shape=False, # tell vector quantizer to return indices as bhw
379
- ):
380
- super().__init__(ddconfig,
381
- lossconfig,
382
- n_embed,
383
- embed_dim,
384
- ckpt_path=None,
385
- ignore_keys=ignore_keys,
386
- image_key=image_key,
387
- colorize_nlabels=colorize_nlabels,
388
- monitor=monitor,
389
- )
390
- self.quantize = EMAVectorQuantizer(n_embed=n_embed,
391
- embedding_dim=embed_dim,
392
- beta=0.25,
393
- remap=remap)
394
- def configure_optimizers(self):
395
- lr = self.learning_rate
396
- #Remove self.quantize from parameter list since it is updated via EMA
397
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
398
- list(self.decoder.parameters())+
399
- list(self.quant_conv.parameters())+
400
- list(self.post_quant_conv.parameters()),
401
- lr=lr, betas=(0.5, 0.9))
402
- opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
403
- lr=lr, betas=(0.5, 0.9))
404
- return [opt_ae, opt_disc], []
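
The self.quantize module used in VQModel.encode above (VectorQuantizer2, GumbelQuantize or EMAVectorQuantizer from taming/modules/vqvae/quantize.py) reduces, in its plain form, to a nearest-neighbour codebook lookup with a straight-through gradient. Below is a minimal sketch of that core step under the standard VQ-VAE formulation; it is not the repo's implementation and omits remapping and index-shape handling:

import torch

def vector_quantize(z, codebook, beta=0.25):
    # z: (B, C, H, W) encoder output; codebook: (n_embed, C) embedding weight.
    b, c, h, w = z.shape
    z_flat = z.permute(0, 2, 3, 1).reshape(-1, c)                       # (B*H*W, C)
    distances = (z_flat.pow(2).sum(1, keepdim=True)
                 - 2 * z_flat @ codebook.t()
                 + codebook.pow(2).sum(1))                              # squared L2 to every code
    indices = distances.argmin(dim=1)                                   # nearest codebook entry
    z_q = codebook[indices].view(b, h, w, c).permute(0, 3, 1, 2)
    # commitment term (weighted by beta) plus codebook term, as in the VQ-VAE objective
    loss = beta * (z_q.detach() - z).pow(2).mean() + (z_q - z.detach()).pow(2).mean()
    z_q = z + (z_q - z).detach()                                        # straight-through estimator
    return z_q, loss, indices.view(b, h, w)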
 
taming-transformers/taming/modules/diffusionmodules/model.py DELETED
@@ -1,776 +0,0 @@
1
- # pytorch_diffusion + derived encoder decoder
2
- import math
3
- import torch
4
- import torch.nn as nn
5
- import numpy as np
6
-
7
-
8
- def get_timestep_embedding(timesteps, embedding_dim):
9
- """
10
- This matches the implementation in Denoising Diffusion Probabilistic Models:
11
- From Fairseq.
12
- Build sinusoidal embeddings.
13
- This matches the implementation in tensor2tensor, but differs slightly
14
- from the description in Section 3.5 of "Attention Is All You Need".
15
- """
16
- assert len(timesteps.shape) == 1
17
-
18
- half_dim = embedding_dim // 2
19
- emb = math.log(10000) / (half_dim - 1)
20
- emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
21
- emb = emb.to(device=timesteps.device)
22
- emb = timesteps.float()[:, None] * emb[None, :]
23
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
24
- if embedding_dim % 2 == 1: # zero pad
25
- emb = torch.nn.functional.pad(emb, (0,1,0,0))
26
- return emb
27
-
28
-
29
- def nonlinearity(x):
30
- # swish
31
- return x*torch.sigmoid(x)
32
-
33
-
34
- def Normalize(in_channels):
35
- return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
36
-
37
-
38
- class Upsample(nn.Module):
39
- def __init__(self, in_channels, with_conv):
40
- super().__init__()
41
- self.with_conv = with_conv
42
- if self.with_conv:
43
- self.conv = torch.nn.Conv2d(in_channels,
44
- in_channels,
45
- kernel_size=3,
46
- stride=1,
47
- padding=1)
48
-
49
- def forward(self, x):
50
- x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
51
- if self.with_conv:
52
- x = self.conv(x)
53
- return x
54
-
55
-
56
- class Downsample(nn.Module):
57
- def __init__(self, in_channels, with_conv):
58
- super().__init__()
59
- self.with_conv = with_conv
60
- if self.with_conv:
61
- # no asymmetric padding in torch conv, must do it ourselves
62
- self.conv = torch.nn.Conv2d(in_channels,
63
- in_channels,
64
- kernel_size=3,
65
- stride=2,
66
- padding=0)
67
-
68
- def forward(self, x):
69
- if self.with_conv:
70
- pad = (0,1,0,1)
71
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
72
- x = self.conv(x)
73
- else:
74
- x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
75
- return x
76
-
77
-
78
- class ResnetBlock(nn.Module):
79
- def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
80
- dropout, temb_channels=512):
81
- super().__init__()
82
- self.in_channels = in_channels
83
- out_channels = in_channels if out_channels is None else out_channels
84
- self.out_channels = out_channels
85
- self.use_conv_shortcut = conv_shortcut
86
-
87
- self.norm1 = Normalize(in_channels)
88
- self.conv1 = torch.nn.Conv2d(in_channels,
89
- out_channels,
90
- kernel_size=3,
91
- stride=1,
92
- padding=1)
93
- if temb_channels > 0:
94
- self.temb_proj = torch.nn.Linear(temb_channels,
95
- out_channels)
96
- self.norm2 = Normalize(out_channels)
97
- self.dropout = torch.nn.Dropout(dropout)
98
- self.conv2 = torch.nn.Conv2d(out_channels,
99
- out_channels,
100
- kernel_size=3,
101
- stride=1,
102
- padding=1)
103
- if self.in_channels != self.out_channels:
104
- if self.use_conv_shortcut:
105
- self.conv_shortcut = torch.nn.Conv2d(in_channels,
106
- out_channels,
107
- kernel_size=3,
108
- stride=1,
109
- padding=1)
110
- else:
111
- self.nin_shortcut = torch.nn.Conv2d(in_channels,
112
- out_channels,
113
- kernel_size=1,
114
- stride=1,
115
- padding=0)
116
-
117
- def forward(self, x, temb):
118
- h = x
119
- h = self.norm1(h)
120
- h = nonlinearity(h)
121
- h = self.conv1(h)
122
-
123
- if temb is not None:
124
- h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
125
-
126
- h = self.norm2(h)
127
- h = nonlinearity(h)
128
- h = self.dropout(h)
129
- h = self.conv2(h)
130
-
131
- if self.in_channels != self.out_channels:
132
- if self.use_conv_shortcut:
133
- x = self.conv_shortcut(x)
134
- else:
135
- x = self.nin_shortcut(x)
136
-
137
- return x+h
138
-
139
-
140
- class AttnBlock(nn.Module):
141
- def __init__(self, in_channels):
142
- super().__init__()
143
- self.in_channels = in_channels
144
-
145
- self.norm = Normalize(in_channels)
146
- self.q = torch.nn.Conv2d(in_channels,
147
- in_channels,
148
- kernel_size=1,
149
- stride=1,
150
- padding=0)
151
- self.k = torch.nn.Conv2d(in_channels,
152
- in_channels,
153
- kernel_size=1,
154
- stride=1,
155
- padding=0)
156
- self.v = torch.nn.Conv2d(in_channels,
157
- in_channels,
158
- kernel_size=1,
159
- stride=1,
160
- padding=0)
161
- self.proj_out = torch.nn.Conv2d(in_channels,
162
- in_channels,
163
- kernel_size=1,
164
- stride=1,
165
- padding=0)
166
-
167
-
168
- def forward(self, x):
169
- h_ = x
170
- h_ = self.norm(h_)
171
- q = self.q(h_)
172
- k = self.k(h_)
173
- v = self.v(h_)
174
-
175
- # compute attention
176
- b,c,h,w = q.shape
177
- q = q.reshape(b,c,h*w)
178
- q = q.permute(0,2,1) # b,hw,c
179
- k = k.reshape(b,c,h*w) # b,c,hw
180
- w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
181
- w_ = w_ * (int(c)**(-0.5))
182
- w_ = torch.nn.functional.softmax(w_, dim=2)
183
-
184
- # attend to values
185
- v = v.reshape(b,c,h*w)
186
- w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
187
- h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
188
- h_ = h_.reshape(b,c,h,w)
189
-
190
- h_ = self.proj_out(h_)
191
-
192
- return x+h_
193
-
194
-
195
- class Model(nn.Module):
196
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
197
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
198
- resolution, use_timestep=True):
199
- super().__init__()
200
- self.ch = ch
201
- self.temb_ch = self.ch*4
202
- self.num_resolutions = len(ch_mult)
203
- self.num_res_blocks = num_res_blocks
204
- self.resolution = resolution
205
- self.in_channels = in_channels
206
-
207
- self.use_timestep = use_timestep
208
- if self.use_timestep:
209
- # timestep embedding
210
- self.temb = nn.Module()
211
- self.temb.dense = nn.ModuleList([
212
- torch.nn.Linear(self.ch,
213
- self.temb_ch),
214
- torch.nn.Linear(self.temb_ch,
215
- self.temb_ch),
216
- ])
217
-
218
- # downsampling
219
- self.conv_in = torch.nn.Conv2d(in_channels,
220
- self.ch,
221
- kernel_size=3,
222
- stride=1,
223
- padding=1)
224
-
225
- curr_res = resolution
226
- in_ch_mult = (1,)+tuple(ch_mult)
227
- self.down = nn.ModuleList()
228
- for i_level in range(self.num_resolutions):
229
- block = nn.ModuleList()
230
- attn = nn.ModuleList()
231
- block_in = ch*in_ch_mult[i_level]
232
- block_out = ch*ch_mult[i_level]
233
- for i_block in range(self.num_res_blocks):
234
- block.append(ResnetBlock(in_channels=block_in,
235
- out_channels=block_out,
236
- temb_channels=self.temb_ch,
237
- dropout=dropout))
238
- block_in = block_out
239
- if curr_res in attn_resolutions:
240
- attn.append(AttnBlock(block_in))
241
- down = nn.Module()
242
- down.block = block
243
- down.attn = attn
244
- if i_level != self.num_resolutions-1:
245
- down.downsample = Downsample(block_in, resamp_with_conv)
246
- curr_res = curr_res // 2
247
- self.down.append(down)
248
-
249
- # middle
250
- self.mid = nn.Module()
251
- self.mid.block_1 = ResnetBlock(in_channels=block_in,
252
- out_channels=block_in,
253
- temb_channels=self.temb_ch,
254
- dropout=dropout)
255
- self.mid.attn_1 = AttnBlock(block_in)
256
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
257
- out_channels=block_in,
258
- temb_channels=self.temb_ch,
259
- dropout=dropout)
260
-
261
- # upsampling
262
- self.up = nn.ModuleList()
263
- for i_level in reversed(range(self.num_resolutions)):
264
- block = nn.ModuleList()
265
- attn = nn.ModuleList()
266
- block_out = ch*ch_mult[i_level]
267
- skip_in = ch*ch_mult[i_level]
268
- for i_block in range(self.num_res_blocks+1):
269
- if i_block == self.num_res_blocks:
270
- skip_in = ch*in_ch_mult[i_level]
271
- block.append(ResnetBlock(in_channels=block_in+skip_in,
272
- out_channels=block_out,
273
- temb_channels=self.temb_ch,
274
- dropout=dropout))
275
- block_in = block_out
276
- if curr_res in attn_resolutions:
277
- attn.append(AttnBlock(block_in))
278
- up = nn.Module()
279
- up.block = block
280
- up.attn = attn
281
- if i_level != 0:
282
- up.upsample = Upsample(block_in, resamp_with_conv)
283
- curr_res = curr_res * 2
284
- self.up.insert(0, up) # prepend to get consistent order
285
-
286
- # end
287
- self.norm_out = Normalize(block_in)
288
- self.conv_out = torch.nn.Conv2d(block_in,
289
- out_ch,
290
- kernel_size=3,
291
- stride=1,
292
- padding=1)
293
-
294
-
295
- def forward(self, x, t=None):
296
- #assert x.shape[2] == x.shape[3] == self.resolution
297
-
298
- if self.use_timestep:
299
- # timestep embedding
300
- assert t is not None
301
- temb = get_timestep_embedding(t, self.ch)
302
- temb = self.temb.dense[0](temb)
303
- temb = nonlinearity(temb)
304
- temb = self.temb.dense[1](temb)
305
- else:
306
- temb = None
307
-
308
- # downsampling
309
- hs = [self.conv_in(x)]
310
- for i_level in range(self.num_resolutions):
311
- for i_block in range(self.num_res_blocks):
312
- h = self.down[i_level].block[i_block](hs[-1], temb)
313
- if len(self.down[i_level].attn) > 0:
314
- h = self.down[i_level].attn[i_block](h)
315
- hs.append(h)
316
- if i_level != self.num_resolutions-1:
317
- hs.append(self.down[i_level].downsample(hs[-1]))
318
-
319
- # middle
320
- h = hs[-1]
321
- h = self.mid.block_1(h, temb)
322
- h = self.mid.attn_1(h)
323
- h = self.mid.block_2(h, temb)
324
-
325
- # upsampling
326
- for i_level in reversed(range(self.num_resolutions)):
327
- for i_block in range(self.num_res_blocks+1):
328
- h = self.up[i_level].block[i_block](
329
- torch.cat([h, hs.pop()], dim=1), temb)
330
- if len(self.up[i_level].attn) > 0:
331
- h = self.up[i_level].attn[i_block](h)
332
- if i_level != 0:
333
- h = self.up[i_level].upsample(h)
334
-
335
- # end
336
- h = self.norm_out(h)
337
- h = nonlinearity(h)
338
- h = self.conv_out(h)
339
- return h
340
-
341
-
342
- class Encoder(nn.Module):
343
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
344
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
345
- resolution, z_channels, double_z=True, **ignore_kwargs):
346
- super().__init__()
347
- self.ch = ch
348
- self.temb_ch = 0
349
- self.num_resolutions = len(ch_mult)
350
- self.num_res_blocks = num_res_blocks
351
- self.resolution = resolution
352
- self.in_channels = in_channels
353
-
354
- # downsampling
355
- self.conv_in = torch.nn.Conv2d(in_channels,
356
- self.ch,
357
- kernel_size=3,
358
- stride=1,
359
- padding=1)
360
-
361
- curr_res = resolution
362
- in_ch_mult = (1,)+tuple(ch_mult)
363
- self.down = nn.ModuleList()
364
- for i_level in range(self.num_resolutions):
365
- block = nn.ModuleList()
366
- attn = nn.ModuleList()
367
- block_in = ch*in_ch_mult[i_level]
368
- block_out = ch*ch_mult[i_level]
369
- for i_block in range(self.num_res_blocks):
370
- block.append(ResnetBlock(in_channels=block_in,
371
- out_channels=block_out,
372
- temb_channels=self.temb_ch,
373
- dropout=dropout))
374
- block_in = block_out
375
- if curr_res in attn_resolutions:
376
- attn.append(AttnBlock(block_in))
377
- down = nn.Module()
378
- down.block = block
379
- down.attn = attn
380
- if i_level != self.num_resolutions-1:
381
- down.downsample = Downsample(block_in, resamp_with_conv)
382
- curr_res = curr_res // 2
383
- self.down.append(down)
384
-
385
- # middle
386
- self.mid = nn.Module()
387
- self.mid.block_1 = ResnetBlock(in_channels=block_in,
388
- out_channels=block_in,
389
- temb_channels=self.temb_ch,
390
- dropout=dropout)
391
- self.mid.attn_1 = AttnBlock(block_in)
392
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
393
- out_channels=block_in,
394
- temb_channels=self.temb_ch,
395
- dropout=dropout)
396
-
397
- # end
398
- self.norm_out = Normalize(block_in)
399
- self.conv_out = torch.nn.Conv2d(block_in,
400
- 2*z_channels if double_z else z_channels,
401
- kernel_size=3,
402
- stride=1,
403
- padding=1)
404
-
405
-
406
- def forward(self, x):
407
- #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
408
-
409
- # timestep embedding
410
- temb = None
411
-
412
- # downsampling
413
- hs = [self.conv_in(x)]
414
- for i_level in range(self.num_resolutions):
415
- for i_block in range(self.num_res_blocks):
416
- h = self.down[i_level].block[i_block](hs[-1], temb)
417
- if len(self.down[i_level].attn) > 0:
418
- h = self.down[i_level].attn[i_block](h)
419
- hs.append(h)
420
- if i_level != self.num_resolutions-1:
421
- hs.append(self.down[i_level].downsample(hs[-1]))
422
-
423
- # middle
424
- h = hs[-1]
425
- h = self.mid.block_1(h, temb)
426
- h = self.mid.attn_1(h)
427
- h = self.mid.block_2(h, temb)
428
-
429
- # end
430
- h = self.norm_out(h)
431
- h = nonlinearity(h)
432
- h = self.conv_out(h)
433
- return h
434
-
435
-
436
- class Decoder(nn.Module):
437
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
438
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
439
- resolution, z_channels, give_pre_end=False, **ignorekwargs):
440
- super().__init__()
441
- self.ch = ch
442
- self.temb_ch = 0
443
- self.num_resolutions = len(ch_mult)
444
- self.num_res_blocks = num_res_blocks
445
- self.resolution = resolution
446
- self.in_channels = in_channels
447
- self.give_pre_end = give_pre_end
448
-
449
- # compute in_ch_mult, block_in and curr_res at lowest res
450
- in_ch_mult = (1,)+tuple(ch_mult)
451
- block_in = ch*ch_mult[self.num_resolutions-1]
452
- curr_res = resolution // 2**(self.num_resolutions-1)
453
- self.z_shape = (1,z_channels,curr_res,curr_res)
454
- print("Working with z of shape {} = {} dimensions.".format(
455
- self.z_shape, np.prod(self.z_shape)))
456
-
457
- # z to block_in
458
- self.conv_in = torch.nn.Conv2d(z_channels,
459
- block_in,
460
- kernel_size=3,
461
- stride=1,
462
- padding=1)
463
-
464
- # middle
465
- self.mid = nn.Module()
466
- self.mid.block_1 = ResnetBlock(in_channels=block_in,
467
- out_channels=block_in,
468
- temb_channels=self.temb_ch,
469
- dropout=dropout)
470
- self.mid.attn_1 = AttnBlock(block_in)
471
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
472
- out_channels=block_in,
473
- temb_channels=self.temb_ch,
474
- dropout=dropout)
475
-
476
- # upsampling
477
- self.up = nn.ModuleList()
478
- for i_level in reversed(range(self.num_resolutions)):
479
- block = nn.ModuleList()
480
- attn = nn.ModuleList()
481
- block_out = ch*ch_mult[i_level]
482
- for i_block in range(self.num_res_blocks+1):
483
- block.append(ResnetBlock(in_channels=block_in,
484
- out_channels=block_out,
485
- temb_channels=self.temb_ch,
486
- dropout=dropout))
487
- block_in = block_out
488
- if curr_res in attn_resolutions:
489
- attn.append(AttnBlock(block_in))
490
- up = nn.Module()
491
- up.block = block
492
- up.attn = attn
493
- if i_level != 0:
494
- up.upsample = Upsample(block_in, resamp_with_conv)
495
- curr_res = curr_res * 2
496
- self.up.insert(0, up) # prepend to get consistent order
497
-
498
- # end
499
- self.norm_out = Normalize(block_in)
500
- self.conv_out = torch.nn.Conv2d(block_in,
501
- out_ch,
502
- kernel_size=3,
503
- stride=1,
504
- padding=1)
505
-
506
- def forward(self, z):
507
- #assert z.shape[1:] == self.z_shape[1:]
508
- self.last_z_shape = z.shape
509
-
510
- # timestep embedding
511
- temb = None
512
-
513
- # z to block_in
514
- h = self.conv_in(z)
515
-
516
- # middle
517
- h = self.mid.block_1(h, temb)
518
- h = self.mid.attn_1(h)
519
- h = self.mid.block_2(h, temb)
520
-
521
- # upsampling
522
- for i_level in reversed(range(self.num_resolutions)):
523
- for i_block in range(self.num_res_blocks+1):
524
- h = self.up[i_level].block[i_block](h, temb)
525
- if len(self.up[i_level].attn) > 0:
526
- h = self.up[i_level].attn[i_block](h)
527
- if i_level != 0:
528
- h = self.up[i_level].upsample(h)
529
-
530
- # end
531
- if self.give_pre_end:
532
- return h
533
-
534
- h = self.norm_out(h)
535
- h = nonlinearity(h)
536
- h = self.conv_out(h)
537
- return h
538
-
539
-
540
- class VUNet(nn.Module):
541
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
542
- attn_resolutions, dropout=0.0, resamp_with_conv=True,
543
- in_channels, c_channels,
544
- resolution, z_channels, use_timestep=False, **ignore_kwargs):
545
- super().__init__()
546
- self.ch = ch
547
- self.temb_ch = self.ch*4
548
- self.num_resolutions = len(ch_mult)
549
- self.num_res_blocks = num_res_blocks
550
- self.resolution = resolution
551
-
552
- self.use_timestep = use_timestep
553
- if self.use_timestep:
554
- # timestep embedding
555
- self.temb = nn.Module()
556
- self.temb.dense = nn.ModuleList([
557
- torch.nn.Linear(self.ch,
558
- self.temb_ch),
559
- torch.nn.Linear(self.temb_ch,
560
- self.temb_ch),
561
- ])
562
-
563
- # downsampling
564
- self.conv_in = torch.nn.Conv2d(c_channels,
565
- self.ch,
566
- kernel_size=3,
567
- stride=1,
568
- padding=1)
569
-
570
- curr_res = resolution
571
- in_ch_mult = (1,)+tuple(ch_mult)
572
- self.down = nn.ModuleList()
573
- for i_level in range(self.num_resolutions):
574
- block = nn.ModuleList()
575
- attn = nn.ModuleList()
576
- block_in = ch*in_ch_mult[i_level]
577
- block_out = ch*ch_mult[i_level]
578
- for i_block in range(self.num_res_blocks):
579
- block.append(ResnetBlock(in_channels=block_in,
580
- out_channels=block_out,
581
- temb_channels=self.temb_ch,
582
- dropout=dropout))
583
- block_in = block_out
584
- if curr_res in attn_resolutions:
585
- attn.append(AttnBlock(block_in))
586
- down = nn.Module()
587
- down.block = block
588
- down.attn = attn
589
- if i_level != self.num_resolutions-1:
590
- down.downsample = Downsample(block_in, resamp_with_conv)
591
- curr_res = curr_res // 2
592
- self.down.append(down)
593
-
594
- self.z_in = torch.nn.Conv2d(z_channels,
595
- block_in,
596
- kernel_size=1,
597
- stride=1,
598
- padding=0)
599
- # middle
600
- self.mid = nn.Module()
601
- self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
602
- out_channels=block_in,
603
- temb_channels=self.temb_ch,
604
- dropout=dropout)
605
- self.mid.attn_1 = AttnBlock(block_in)
606
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
607
- out_channels=block_in,
608
- temb_channels=self.temb_ch,
609
- dropout=dropout)
610
-
611
- # upsampling
612
- self.up = nn.ModuleList()
613
- for i_level in reversed(range(self.num_resolutions)):
614
- block = nn.ModuleList()
615
- attn = nn.ModuleList()
616
- block_out = ch*ch_mult[i_level]
617
- skip_in = ch*ch_mult[i_level]
618
- for i_block in range(self.num_res_blocks+1):
619
- if i_block == self.num_res_blocks:
620
- skip_in = ch*in_ch_mult[i_level]
621
- block.append(ResnetBlock(in_channels=block_in+skip_in,
622
- out_channels=block_out,
623
- temb_channels=self.temb_ch,
624
- dropout=dropout))
625
- block_in = block_out
626
- if curr_res in attn_resolutions:
627
- attn.append(AttnBlock(block_in))
628
- up = nn.Module()
629
- up.block = block
630
- up.attn = attn
631
- if i_level != 0:
632
- up.upsample = Upsample(block_in, resamp_with_conv)
633
- curr_res = curr_res * 2
634
- self.up.insert(0, up) # prepend to get consistent order
635
-
636
- # end
637
- self.norm_out = Normalize(block_in)
638
- self.conv_out = torch.nn.Conv2d(block_in,
639
- out_ch,
640
- kernel_size=3,
641
- stride=1,
642
- padding=1)
643
-
644
-
645
- def forward(self, x, z):
646
- #assert x.shape[2] == x.shape[3] == self.resolution
647
-
648
- if self.use_timestep:
649
- # timestep embedding
650
- assert t is not None
651
- temb = get_timestep_embedding(t, self.ch)
652
- temb = self.temb.dense[0](temb)
653
- temb = nonlinearity(temb)
654
- temb = self.temb.dense[1](temb)
655
- else:
656
- temb = None
657
-
658
- # downsampling
659
- hs = [self.conv_in(x)]
660
- for i_level in range(self.num_resolutions):
661
- for i_block in range(self.num_res_blocks):
662
- h = self.down[i_level].block[i_block](hs[-1], temb)
663
- if len(self.down[i_level].attn) > 0:
664
- h = self.down[i_level].attn[i_block](h)
665
- hs.append(h)
666
- if i_level != self.num_resolutions-1:
667
- hs.append(self.down[i_level].downsample(hs[-1]))
668
-
669
- # middle
670
- h = hs[-1]
671
- z = self.z_in(z)
672
- h = torch.cat((h,z),dim=1)
673
- h = self.mid.block_1(h, temb)
674
- h = self.mid.attn_1(h)
675
- h = self.mid.block_2(h, temb)
676
-
677
- # upsampling
678
- for i_level in reversed(range(self.num_resolutions)):
679
- for i_block in range(self.num_res_blocks+1):
680
- h = self.up[i_level].block[i_block](
681
- torch.cat([h, hs.pop()], dim=1), temb)
682
- if len(self.up[i_level].attn) > 0:
683
- h = self.up[i_level].attn[i_block](h)
684
- if i_level != 0:
685
- h = self.up[i_level].upsample(h)
686
-
687
- # end
688
- h = self.norm_out(h)
689
- h = nonlinearity(h)
690
- h = self.conv_out(h)
691
- return h
692
-
693
-
694
- class SimpleDecoder(nn.Module):
695
- def __init__(self, in_channels, out_channels, *args, **kwargs):
696
- super().__init__()
697
- self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
698
- ResnetBlock(in_channels=in_channels,
699
- out_channels=2 * in_channels,
700
- temb_channels=0, dropout=0.0),
701
- ResnetBlock(in_channels=2 * in_channels,
702
- out_channels=4 * in_channels,
703
- temb_channels=0, dropout=0.0),
704
- ResnetBlock(in_channels=4 * in_channels,
705
- out_channels=2 * in_channels,
706
- temb_channels=0, dropout=0.0),
707
- nn.Conv2d(2*in_channels, in_channels, 1),
708
- Upsample(in_channels, with_conv=True)])
709
- # end
710
- self.norm_out = Normalize(in_channels)
711
- self.conv_out = torch.nn.Conv2d(in_channels,
712
- out_channels,
713
- kernel_size=3,
714
- stride=1,
715
- padding=1)
716
-
717
- def forward(self, x):
718
- for i, layer in enumerate(self.model):
719
- if i in [1,2,3]:
720
- x = layer(x, None)
721
- else:
722
- x = layer(x)
723
-
724
- h = self.norm_out(x)
725
- h = nonlinearity(h)
726
- x = self.conv_out(h)
727
- return x
728
-
729
-
730
- class UpsampleDecoder(nn.Module):
731
- def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
732
- ch_mult=(2,2), dropout=0.0):
733
- super().__init__()
734
- # upsampling
735
- self.temb_ch = 0
736
- self.num_resolutions = len(ch_mult)
737
- self.num_res_blocks = num_res_blocks
738
- block_in = in_channels
739
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
740
- self.res_blocks = nn.ModuleList()
741
- self.upsample_blocks = nn.ModuleList()
742
- for i_level in range(self.num_resolutions):
743
- res_block = []
744
- block_out = ch * ch_mult[i_level]
745
- for i_block in range(self.num_res_blocks + 1):
746
- res_block.append(ResnetBlock(in_channels=block_in,
747
- out_channels=block_out,
748
- temb_channels=self.temb_ch,
749
- dropout=dropout))
750
- block_in = block_out
751
- self.res_blocks.append(nn.ModuleList(res_block))
752
- if i_level != self.num_resolutions - 1:
753
- self.upsample_blocks.append(Upsample(block_in, True))
754
- curr_res = curr_res * 2
755
-
756
- # end
757
- self.norm_out = Normalize(block_in)
758
- self.conv_out = torch.nn.Conv2d(block_in,
759
- out_channels,
760
- kernel_size=3,
761
- stride=1,
762
- padding=1)
763
-
764
- def forward(self, x):
765
- # upsampling
766
- h = x
767
- for k, i_level in enumerate(range(self.num_resolutions)):
768
- for i_block in range(self.num_res_blocks + 1):
769
- h = self.res_blocks[i_level][i_block](h, None)
770
- if i_level != self.num_resolutions - 1:
771
- h = self.upsample_blocks[k](h)
772
- h = self.norm_out(h)
773
- h = nonlinearity(h)
774
- h = self.conv_out(h)
775
- return h
776
-
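
Note: the decoder above follows the usual U-Net bookkeeping. The downsampling path pushes every intermediate feature map onto the hs stack, and the upsampling path pops them back off and concatenates them along channels before each block (with self.up.insert(0, up) keeping level indices consistent). Below is a minimal standalone sketch of that push/pop pattern only, with plain convolutions and made-up channel counts standing in for ResnetBlock, AttnBlock and Downsample; it is illustrative, not the repository's code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySkipDecoderSketch(nn.Module):
    # Minimal sketch: downsampling pushes feature maps onto a stack,
    # upsampling pops them and concatenates on channels, mirroring the
    # hs.append(...) / hs.pop() bookkeeping in the code above.
    def __init__(self, ch=32):
        super().__init__()
        self.conv_in = nn.Conv2d(3, ch, 3, padding=1)
        self.down = nn.ModuleList([nn.Conv2d(ch, ch, 3, stride=2, padding=1)
                                   for _ in range(2)])
        self.up = nn.ModuleList([nn.Conv2d(2 * ch, ch, 3, padding=1)
                                 for _ in range(2)])
        self.conv_out = nn.Conv2d(ch, 3, 3, padding=1)

    def forward(self, x):
        hs = [self.conv_in(x)]              # skip stack
        for down in self.down:
            hs.append(down(hs[-1]))         # push one tensor per level
        h = hs.pop()                        # coarsest features
        for up in self.up:
            h = F.interpolate(h, scale_factor=2, mode="nearest")
            h = up(torch.cat([h, hs.pop()], dim=1))  # pop the matching skip
        return self.conv_out(h)

print(TinySkipDecoderSketch()(torch.randn(1, 3, 64, 64)).shape)  # (1, 3, 64, 64)
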
taming-transformers/taming/modules/discriminator/model.py DELETED
@@ -1,67 +0,0 @@
1
- import functools
2
- import torch.nn as nn
3
-
4
-
5
- from taming.modules.util import ActNorm
6
-
7
-
8
- def weights_init(m):
9
- classname = m.__class__.__name__
10
- if classname.find('Conv') != -1:
11
- nn.init.normal_(m.weight.data, 0.0, 0.02)
12
- elif classname.find('BatchNorm') != -1:
13
- nn.init.normal_(m.weight.data, 1.0, 0.02)
14
- nn.init.constant_(m.bias.data, 0)
15
-
16
-
17
- class NLayerDiscriminator(nn.Module):
18
- """Defines a PatchGAN discriminator as in Pix2Pix
19
- --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20
- """
21
- def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
22
- """Construct a PatchGAN discriminator
23
- Parameters:
24
- input_nc (int) -- the number of channels in input images
25
- ndf (int) -- the number of filters in the last conv layer
26
- n_layers (int) -- the number of conv layers in the discriminator
27
- norm_layer -- normalization layer
28
- """
29
- super(NLayerDiscriminator, self).__init__()
30
- if not use_actnorm:
31
- norm_layer = nn.BatchNorm2d
32
- else:
33
- norm_layer = ActNorm
34
- if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
35
- use_bias = norm_layer.func != nn.BatchNorm2d
36
- else:
37
- use_bias = norm_layer != nn.BatchNorm2d
38
-
39
- kw = 4
40
- padw = 1
41
- sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
42
- nf_mult = 1
43
- nf_mult_prev = 1
44
- for n in range(1, n_layers): # gradually increase the number of filters
45
- nf_mult_prev = nf_mult
46
- nf_mult = min(2 ** n, 8)
47
- sequence += [
48
- nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
49
- norm_layer(ndf * nf_mult),
50
- nn.LeakyReLU(0.2, True)
51
- ]
52
-
53
- nf_mult_prev = nf_mult
54
- nf_mult = min(2 ** n_layers, 8)
55
- sequence += [
56
- nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
57
- norm_layer(ndf * nf_mult),
58
- nn.LeakyReLU(0.2, True)
59
- ]
60
-
61
- sequence += [
62
- nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
63
- self.main = nn.Sequential(*sequence)
64
-
65
- def forward(self, input):
66
- """Standard forward."""
67
- return self.main(input)
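
For orientation, this PatchGAN returns a grid of per-patch logits rather than one scalar per image: with the defaults (input_nc=3, ndf=64, n_layers=3, BatchNorm) a 256x256 input maps to roughly a 30x30 prediction map. A standalone sketch of the same layer layout (not importing the package, since it is being removed in this commit) confirms the output shape:

import torch
import torch.nn as nn

# Standalone sketch of the default NLayerDiscriminator layout
# (input_nc=3, ndf=64, n_layers=3, BatchNorm): three stride-2 convs,
# one stride-1 conv, then a 1-channel prediction map.
ndf, kw, padw = 64, 4, 1
layers = [nn.Conv2d(3, ndf, kw, 2, padw), nn.LeakyReLU(0.2, True)]
mult_prev, mult = 1, 1
for n in range(1, 3):                        # the two inner stride-2 blocks
    mult_prev, mult = mult, min(2 ** n, 8)
    layers += [nn.Conv2d(ndf * mult_prev, ndf * mult, kw, 2, padw, bias=False),
               nn.BatchNorm2d(ndf * mult), nn.LeakyReLU(0.2, True)]
mult_prev, mult = mult, min(2 ** 3, 8)
layers += [nn.Conv2d(ndf * mult_prev, ndf * mult, kw, 1, padw, bias=False),
           nn.BatchNorm2d(ndf * mult), nn.LeakyReLU(0.2, True),
           nn.Conv2d(ndf * mult, 1, kw, 1, padw)]
disc = nn.Sequential(*layers)

logits = disc(torch.randn(1, 3, 256, 256))
print(logits.shape)   # torch.Size([1, 1, 30, 30]): one logit per image patch
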
taming-transformers/taming/modules/losses/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from taming.modules.losses.vqperceptual import DummyLoss
2
-
taming-transformers/taming/modules/losses/lpips.py DELETED
@@ -1,123 +0,0 @@
1
- """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2
-
3
- import torch
4
- import torch.nn as nn
5
- from torchvision import models
6
- from collections import namedtuple
7
-
8
- from taming.util import get_ckpt_path
9
-
10
-
11
- class LPIPS(nn.Module):
12
- # Learned perceptual metric
13
- def __init__(self, use_dropout=True):
14
- super().__init__()
15
- self.scaling_layer = ScalingLayer()
16
- self.chns = [64, 128, 256, 512, 512] # vg16 features
17
- self.net = vgg16(pretrained=True, requires_grad=False)
18
- self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
19
- self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
20
- self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
21
- self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
22
- self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
23
- self.load_from_pretrained()
24
- for param in self.parameters():
25
- param.requires_grad = False
26
-
27
- def load_from_pretrained(self, name="vgg_lpips"):
28
- ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
29
- self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
30
- print("loaded pretrained LPIPS loss from {}".format(ckpt))
31
-
32
- @classmethod
33
- def from_pretrained(cls, name="vgg_lpips"):
34
- if name != "vgg_lpips":
35
- raise NotImplementedError
36
- model = cls()
37
- ckpt = get_ckpt_path(name)
38
- model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
39
- return model
40
-
41
- def forward(self, input, target):
42
- in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
43
- outs0, outs1 = self.net(in0_input), self.net(in1_input)
44
- feats0, feats1, diffs = {}, {}, {}
45
- lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
46
- for kk in range(len(self.chns)):
47
- feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
48
- diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
49
-
50
- res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
51
- val = res[0]
52
- for l in range(1, len(self.chns)):
53
- val += res[l]
54
- return val
55
-
56
-
57
- class ScalingLayer(nn.Module):
58
- def __init__(self):
59
- super(ScalingLayer, self).__init__()
60
- self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
61
- self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
62
-
63
- def forward(self, inp):
64
- return (inp - self.shift) / self.scale
65
-
66
-
67
- class NetLinLayer(nn.Module):
68
- """ A single linear layer which does a 1x1 conv """
69
- def __init__(self, chn_in, chn_out=1, use_dropout=False):
70
- super(NetLinLayer, self).__init__()
71
- layers = [nn.Dropout(), ] if (use_dropout) else []
72
- layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
73
- self.model = nn.Sequential(*layers)
74
-
75
-
76
- class vgg16(torch.nn.Module):
77
- def __init__(self, requires_grad=False, pretrained=True):
78
- super(vgg16, self).__init__()
79
- vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
80
- self.slice1 = torch.nn.Sequential()
81
- self.slice2 = torch.nn.Sequential()
82
- self.slice3 = torch.nn.Sequential()
83
- self.slice4 = torch.nn.Sequential()
84
- self.slice5 = torch.nn.Sequential()
85
- self.N_slices = 5
86
- for x in range(4):
87
- self.slice1.add_module(str(x), vgg_pretrained_features[x])
88
- for x in range(4, 9):
89
- self.slice2.add_module(str(x), vgg_pretrained_features[x])
90
- for x in range(9, 16):
91
- self.slice3.add_module(str(x), vgg_pretrained_features[x])
92
- for x in range(16, 23):
93
- self.slice4.add_module(str(x), vgg_pretrained_features[x])
94
- for x in range(23, 30):
95
- self.slice5.add_module(str(x), vgg_pretrained_features[x])
96
- if not requires_grad:
97
- for param in self.parameters():
98
- param.requires_grad = False
99
-
100
- def forward(self, X):
101
- h = self.slice1(X)
102
- h_relu1_2 = h
103
- h = self.slice2(h)
104
- h_relu2_2 = h
105
- h = self.slice3(h)
106
- h_relu3_3 = h
107
- h = self.slice4(h)
108
- h_relu4_3 = h
109
- h = self.slice5(h)
110
- h_relu5_3 = h
111
- vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
112
- out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
113
- return out
114
-
115
-
116
- def normalize_tensor(x,eps=1e-10):
117
- norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
118
- return x/(norm_factor+eps)
119
-
120
-
121
- def spatial_average(x, keepdim=True):
122
- return x.mean([2,3],keepdim=keepdim)
123
-
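
The LPIPS forward above amounts to: unit-normalize each of the five VGG feature maps along channels, square the differences, weight them with the learned 1x1 lin layers, spatially average, and sum across levels. That reduction can be exercised on random stand-in features without downloading the pretrained weights; the plain channel mean below is only a stand-in for the learned NetLinLayer weighting.

import torch

def normalize_tensor(x, eps=1e-10):
    # scale each spatial position to unit norm across channels
    return x / (torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) + eps)

def spatial_average(x, keepdim=True):
    return x.mean([2, 3], keepdim=keepdim)

# stand-ins for two images' VGG feature pyramids (channel counts as in self.chns)
chns = [64, 128, 256, 512, 512]
feats0 = [torch.randn(1, c, 32, 32) for c in chns]
feats1 = [torch.randn(1, c, 32, 32) for c in chns]

val = 0.0
for f0, f1 in zip(feats0, feats1):
    diff = (normalize_tensor(f0) - normalize_tensor(f1)) ** 2
    # the real model applies a learned 1x1 conv (NetLinLayer) here;
    # a plain channel mean is used in this sketch instead
    val = val + spatial_average(diff.mean(dim=1, keepdim=True))
print(val.shape, float(val))   # a single scalar distance per image pair
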
taming-transformers/taming/modules/losses/segmentation.py DELETED
@@ -1,22 +0,0 @@
1
- import torch.nn as nn
2
- import torch.nn.functional as F
3
-
4
-
5
- class BCELoss(nn.Module):
6
- def forward(self, prediction, target):
7
- loss = F.binary_cross_entropy_with_logits(prediction,target)
8
- return loss, {}
9
-
10
-
11
- class BCELossWithQuant(nn.Module):
12
- def __init__(self, codebook_weight=1.):
13
- super().__init__()
14
- self.codebook_weight = codebook_weight
15
-
16
- def forward(self, qloss, target, prediction, split):
17
- bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
18
- loss = bce_loss + self.codebook_weight*qloss
19
- return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
20
- "{}/bce_loss".format(split): bce_loss.detach().mean(),
21
- "{}/quant_loss".format(split): qloss.detach().mean()
22
- }
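
BCELossWithQuant simply adds the VQ codebook loss, scaled by codebook_weight, on top of a pixel-wise binary cross-entropy over the segmentation logits. A toy call with made-up shapes, re-implemented inline rather than imported, looks like this:

import torch
import torch.nn.functional as F

codebook_weight = 1.0
prediction = torch.randn(2, 3, 16, 16)                        # (B, classes, H, W) logits, made-up sizes
target = torch.randint(0, 2, prediction.shape).float()        # binary segmentation targets
qloss = torch.tensor(0.05)                                    # codebook loss from the VQ stage

bce_loss = F.binary_cross_entropy_with_logits(prediction, target)
loss = bce_loss + codebook_weight * qloss
log = {"train/total_loss": loss.detach(),
       "train/bce_loss": bce_loss.detach(),
       "train/quant_loss": qloss.detach()}
print({k: float(v) for k, v in log.items()})
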
taming-transformers/taming/modules/losses/vqperceptual.py DELETED
@@ -1,136 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- from taming.modules.losses.lpips import LPIPS
6
- from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7
-
8
-
9
- class DummyLoss(nn.Module):
10
- def __init__(self):
11
- super().__init__()
12
-
13
-
14
- def adopt_weight(weight, global_step, threshold=0, value=0.):
15
- if global_step < threshold:
16
- weight = value
17
- return weight
18
-
19
-
20
- def hinge_d_loss(logits_real, logits_fake):
21
- loss_real = torch.mean(F.relu(1. - logits_real))
22
- loss_fake = torch.mean(F.relu(1. + logits_fake))
23
- d_loss = 0.5 * (loss_real + loss_fake)
24
- return d_loss
25
-
26
-
27
- def vanilla_d_loss(logits_real, logits_fake):
28
- d_loss = 0.5 * (
29
- torch.mean(torch.nn.functional.softplus(-logits_real)) +
30
- torch.mean(torch.nn.functional.softplus(logits_fake)))
31
- return d_loss
32
-
33
-
34
- class VQLPIPSWithDiscriminator(nn.Module):
35
- def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
36
- disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
37
- perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
38
- disc_ndf=64, disc_loss="hinge"):
39
- super().__init__()
40
- assert disc_loss in ["hinge", "vanilla"]
41
- self.codebook_weight = codebook_weight
42
- self.pixel_weight = pixelloss_weight
43
- self.perceptual_loss = LPIPS().eval()
44
- self.perceptual_weight = perceptual_weight
45
-
46
- self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
47
- n_layers=disc_num_layers,
48
- use_actnorm=use_actnorm,
49
- ndf=disc_ndf
50
- ).apply(weights_init)
51
- self.discriminator_iter_start = disc_start
52
- if disc_loss == "hinge":
53
- self.disc_loss = hinge_d_loss
54
- elif disc_loss == "vanilla":
55
- self.disc_loss = vanilla_d_loss
56
- else:
57
- raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
58
- print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
59
- self.disc_factor = disc_factor
60
- self.discriminator_weight = disc_weight
61
- self.disc_conditional = disc_conditional
62
-
63
- def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
64
- if last_layer is not None:
65
- nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
66
- g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
67
- else:
68
- nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
69
- g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
70
-
71
- d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
72
- d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
73
- d_weight = d_weight * self.discriminator_weight
74
- return d_weight
75
-
76
- def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
77
- global_step, last_layer=None, cond=None, split="train"):
78
- rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
79
- if self.perceptual_weight > 0:
80
- p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
81
- rec_loss = rec_loss + self.perceptual_weight * p_loss
82
- else:
83
- p_loss = torch.tensor([0.0])
84
-
85
- nll_loss = rec_loss
86
- #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
87
- nll_loss = torch.mean(nll_loss)
88
-
89
- # now the GAN part
90
- if optimizer_idx == 0:
91
- # generator update
92
- if cond is None:
93
- assert not self.disc_conditional
94
- logits_fake = self.discriminator(reconstructions.contiguous())
95
- else:
96
- assert self.disc_conditional
97
- logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
98
- g_loss = -torch.mean(logits_fake)
99
-
100
- try:
101
- d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
102
- except RuntimeError:
103
- assert not self.training
104
- d_weight = torch.tensor(0.0)
105
-
106
- disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
107
- loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
108
-
109
- log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
110
- "{}/quant_loss".format(split): codebook_loss.detach().mean(),
111
- "{}/nll_loss".format(split): nll_loss.detach().mean(),
112
- "{}/rec_loss".format(split): rec_loss.detach().mean(),
113
- "{}/p_loss".format(split): p_loss.detach().mean(),
114
- "{}/d_weight".format(split): d_weight.detach(),
115
- "{}/disc_factor".format(split): torch.tensor(disc_factor),
116
- "{}/g_loss".format(split): g_loss.detach().mean(),
117
- }
118
- return loss, log
119
-
120
- if optimizer_idx == 1:
121
- # second pass for discriminator update
122
- if cond is None:
123
- logits_real = self.discriminator(inputs.contiguous().detach())
124
- logits_fake = self.discriminator(reconstructions.contiguous().detach())
125
- else:
126
- logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
127
- logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
128
-
129
- disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
130
- d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
131
-
132
- log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
133
- "{}/logits_real".format(split): logits_real.detach().mean(),
134
- "{}/logits_fake".format(split): logits_fake.detach().mean()
135
- }
136
- return d_loss, log
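
The key trick in calculate_adaptive_weight is to balance the reconstruction and adversarial objectives by the ratio of their gradient norms at the generator's last layer, clamped and scaled by disc_weight. The same computation on a toy linear "generator" (made-up losses and shapes, purely illustrative) is sketched below:

import torch

# toy "generator" whose last layer we probe for gradient norms
last_layer = torch.nn.Parameter(torch.randn(8, 8))
x = torch.randn(4, 8)
out = x @ last_layer

nll_loss = (out - 1.0).abs().mean()      # stands in for the rec/perceptual loss
g_loss = -out.mean()                     # stands in for -E[D(fake)]

nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
print(float(d_weight))   # multiplies the generator's adversarial term
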
taming-transformers/taming/modules/misc/coord.py DELETED
@@ -1,31 +0,0 @@
1
- import torch
2
-
3
- class CoordStage(object):
4
- def __init__(self, n_embed, down_factor):
5
- self.n_embed = n_embed
6
- self.down_factor = down_factor
7
-
8
- def eval(self):
9
- return self
10
-
11
- def encode(self, c):
12
- """fake vqmodel interface"""
13
- assert 0.0 <= c.min() and c.max() <= 1.0
14
- b,ch,h,w = c.shape
15
- assert ch == 1
16
-
17
- c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
18
- mode="area")
19
- c = c.clamp(0.0, 1.0)
20
- c = self.n_embed*c
21
- c_quant = c.round()
22
- c_ind = c_quant.to(dtype=torch.long)
23
-
24
- info = None, None, c_ind
25
- return c_quant, None, info
26
-
27
- def decode(self, c):
28
- c = c/self.n_embed
29
- c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
30
- mode="nearest")
31
- return c
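
CoordStage fakes the VQ interface for a coordinate channel in [0,1]: encode area-downsamples the map, scales it by n_embed and rounds to integer codes; decode divides back and nearest-upsamples. A quick round trip on a toy ramp (made-up sizes) illustrates the idea:

import torch
import torch.nn.functional as F

n_embed, down_factor = 1024, 16
c = torch.linspace(0, 1, steps=64).view(1, 1, 1, 64).repeat(1, 1, 64, 1)  # (1,1,64,64) ramp in [0,1]

# encode: downsample, then quantize to integer bins
z = F.interpolate(c, scale_factor=1 / down_factor, mode="area").clamp(0.0, 1.0)
c_ind = (n_embed * z).round().to(torch.long)          # discrete "codes", shape (1,1,4,4)

# decode: back to [0,1] and upsample with nearest neighbour
rec = F.interpolate(c_ind.float() / n_embed, scale_factor=down_factor, mode="nearest")
print(c_ind.shape, rec.shape, float((rec - c).abs().max()))
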
taming-transformers/taming/modules/transformer/mingpt.py DELETED
@@ -1,415 +0,0 @@
1
- """
2
- taken from: https://github.com/karpathy/minGPT/
3
- GPT model:
4
- - the initial stem consists of a combination of token encoding and a positional encoding
5
- - the meat of it is a uniform sequence of Transformer blocks
6
- - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
7
- - all blocks feed into a central residual pathway similar to resnets
8
- - the final decoder is a linear projection into a vanilla Softmax classifier
9
- """
10
-
11
- import math
12
- import logging
13
-
14
- import torch
15
- import torch.nn as nn
16
- from torch.nn import functional as F
17
- from transformers import top_k_top_p_filtering
18
-
19
- logger = logging.getLogger(__name__)
20
-
21
-
22
- class GPTConfig:
23
- """ base GPT config, params common to all GPT versions """
24
- embd_pdrop = 0.1
25
- resid_pdrop = 0.1
26
- attn_pdrop = 0.1
27
-
28
- def __init__(self, vocab_size, block_size, **kwargs):
29
- self.vocab_size = vocab_size
30
- self.block_size = block_size
31
- for k,v in kwargs.items():
32
- setattr(self, k, v)
33
-
34
-
35
- class GPT1Config(GPTConfig):
36
- """ GPT-1 like network roughly 125M params """
37
- n_layer = 12
38
- n_head = 12
39
- n_embd = 768
40
-
41
-
42
- class CausalSelfAttention(nn.Module):
43
- """
44
- A vanilla multi-head masked self-attention layer with a projection at the end.
45
- It is possible to use torch.nn.MultiheadAttention here but I am including an
46
- explicit implementation here to show that there is nothing too scary here.
47
- """
48
-
49
- def __init__(self, config):
50
- super().__init__()
51
- assert config.n_embd % config.n_head == 0
52
- # key, query, value projections for all heads
53
- self.key = nn.Linear(config.n_embd, config.n_embd)
54
- self.query = nn.Linear(config.n_embd, config.n_embd)
55
- self.value = nn.Linear(config.n_embd, config.n_embd)
56
- # regularization
57
- self.attn_drop = nn.Dropout(config.attn_pdrop)
58
- self.resid_drop = nn.Dropout(config.resid_pdrop)
59
- # output projection
60
- self.proj = nn.Linear(config.n_embd, config.n_embd)
61
- # causal mask to ensure that attention is only applied to the left in the input sequence
62
- mask = torch.tril(torch.ones(config.block_size,
63
- config.block_size))
64
- if hasattr(config, "n_unmasked"):
65
- mask[:config.n_unmasked, :config.n_unmasked] = 1
66
- self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
67
- self.n_head = config.n_head
68
-
69
- def forward(self, x, layer_past=None):
70
- B, T, C = x.size()
71
-
72
- # calculate query, key, values for all heads in batch and move head forward to be the batch dim
73
- k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
74
- q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
75
- v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
76
-
77
- present = torch.stack((k, v))
78
- if layer_past is not None:
79
- past_key, past_value = layer_past
80
- k = torch.cat((past_key, k), dim=-2)
81
- v = torch.cat((past_value, v), dim=-2)
82
-
83
- # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
84
- att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
85
- if layer_past is None:
86
- att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
87
-
88
- att = F.softmax(att, dim=-1)
89
- att = self.attn_drop(att)
90
- y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
91
- y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
92
-
93
- # output projection
94
- y = self.resid_drop(self.proj(y))
95
- return y, present # TODO: check that this does not break anything
96
-
97
-
98
- class Block(nn.Module):
99
- """ an unassuming Transformer block """
100
- def __init__(self, config):
101
- super().__init__()
102
- self.ln1 = nn.LayerNorm(config.n_embd)
103
- self.ln2 = nn.LayerNorm(config.n_embd)
104
- self.attn = CausalSelfAttention(config)
105
- self.mlp = nn.Sequential(
106
- nn.Linear(config.n_embd, 4 * config.n_embd),
107
- nn.GELU(), # nice
108
- nn.Linear(4 * config.n_embd, config.n_embd),
109
- nn.Dropout(config.resid_pdrop),
110
- )
111
-
112
- def forward(self, x, layer_past=None, return_present=False):
113
- # TODO: check that training still works
114
- if return_present: assert not self.training
115
- # layer past: tuple of length two with B, nh, T, hs
116
- attn, present = self.attn(self.ln1(x), layer_past=layer_past)
117
-
118
- x = x + attn
119
- x = x + self.mlp(self.ln2(x))
120
- if layer_past is not None or return_present:
121
- return x, present
122
- return x
123
-
124
-
125
- class GPT(nn.Module):
126
- """ the full GPT language model, with a context size of block_size """
127
- def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
128
- embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
129
- super().__init__()
130
- config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
131
- embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
132
- n_layer=n_layer, n_head=n_head, n_embd=n_embd,
133
- n_unmasked=n_unmasked)
134
- # input embedding stem
135
- self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
136
- self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
137
- self.drop = nn.Dropout(config.embd_pdrop)
138
- # transformer
139
- self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
140
- # decoder head
141
- self.ln_f = nn.LayerNorm(config.n_embd)
142
- self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
143
- self.block_size = config.block_size
144
- self.apply(self._init_weights)
145
- self.config = config
146
- logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
147
-
148
- def get_block_size(self):
149
- return self.block_size
150
-
151
- def _init_weights(self, module):
152
- if isinstance(module, (nn.Linear, nn.Embedding)):
153
- module.weight.data.normal_(mean=0.0, std=0.02)
154
- if isinstance(module, nn.Linear) and module.bias is not None:
155
- module.bias.data.zero_()
156
- elif isinstance(module, nn.LayerNorm):
157
- module.bias.data.zero_()
158
- module.weight.data.fill_(1.0)
159
-
160
- def forward(self, idx, embeddings=None, targets=None):
161
- # forward the GPT model
162
- token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
163
-
164
- if embeddings is not None: # prepend explicit embeddings
165
- token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
166
-
167
- t = token_embeddings.shape[1]
168
- assert t <= self.block_size, "Cannot forward, model block size is exhausted."
169
- position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
170
- x = self.drop(token_embeddings + position_embeddings)
171
- x = self.blocks(x)
172
- x = self.ln_f(x)
173
- logits = self.head(x)
174
-
175
- # if we are given some desired targets also calculate the loss
176
- loss = None
177
- if targets is not None:
178
- loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
179
-
180
- return logits, loss
181
-
182
- def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None):
183
- # inference only
184
- assert not self.training
185
- token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
186
- if embeddings is not None: # prepend explicit embeddings
187
- token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
188
-
189
- if past is not None:
190
- assert past_length is not None
191
- past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head
192
- past_shape = list(past.shape)
193
- expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
194
- assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
195
- position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector
196
- else:
197
- position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]
198
-
199
- x = self.drop(token_embeddings + position_embeddings)
200
- presents = [] # accumulate over layers
201
- for i, block in enumerate(self.blocks):
202
- x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True)
203
- presents.append(present)
204
-
205
- x = self.ln_f(x)
206
- logits = self.head(x)
207
- # if we are given some desired targets also calculate the loss
208
- loss = None
209
- if targets is not None:
210
- loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
211
-
212
- return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head
213
-
214
-
215
- class DummyGPT(nn.Module):
216
- # for debugging
217
- def __init__(self, add_value=1):
218
- super().__init__()
219
- self.add_value = add_value
220
-
221
- def forward(self, idx):
222
- return idx + self.add_value, None
223
-
224
-
225
- class CodeGPT(nn.Module):
226
- """Takes in semi-embeddings"""
227
- def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
228
- embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
229
- super().__init__()
230
- config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
231
- embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
232
- n_layer=n_layer, n_head=n_head, n_embd=n_embd,
233
- n_unmasked=n_unmasked)
234
- # input embedding stem
235
- self.tok_emb = nn.Linear(in_channels, config.n_embd)
236
- self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
237
- self.drop = nn.Dropout(config.embd_pdrop)
238
- # transformer
239
- self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
240
- # decoder head
241
- self.ln_f = nn.LayerNorm(config.n_embd)
242
- self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
243
- self.block_size = config.block_size
244
- self.apply(self._init_weights)
245
- self.config = config
246
- logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
247
-
248
- def get_block_size(self):
249
- return self.block_size
250
-
251
- def _init_weights(self, module):
252
- if isinstance(module, (nn.Linear, nn.Embedding)):
253
- module.weight.data.normal_(mean=0.0, std=0.02)
254
- if isinstance(module, nn.Linear) and module.bias is not None:
255
- module.bias.data.zero_()
256
- elif isinstance(module, nn.LayerNorm):
257
- module.bias.data.zero_()
258
- module.weight.data.fill_(1.0)
259
-
260
- def forward(self, idx, embeddings=None, targets=None):
261
- # forward the GPT model
262
- token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
263
-
264
- if embeddings is not None: # prepend explicit embeddings
265
- token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
266
-
267
- t = token_embeddings.shape[1]
268
- assert t <= self.block_size, "Cannot forward, model block size is exhausted."
269
- position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
270
- x = self.drop(token_embeddings + position_embeddings)
271
- x = self.blocks(x)
272
- x = self.taming_cinln_f(x)
273
- logits = self.head(x)
274
-
275
- # if we are given some desired targets also calculate the loss
276
- loss = None
277
- if targets is not None:
278
- loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
279
-
280
- return logits, loss
281
-
282
-
283
-
284
- #### sampling utils
285
-
286
- def top_k_logits(logits, k):
287
- v, ix = torch.topk(logits, k)
288
- out = logits.clone()
289
- out[out < v[:, [-1]]] = -float('Inf')
290
- return out
291
-
292
- @torch.no_grad()
293
- def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
294
- """
295
- take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
296
- the sequence, feeding the predictions back into the model each time. Clearly the sampling
297
- has quadratic complexity unlike an RNN that is only linear, and has a finite context window
298
- of block_size, unlike an RNN that has an infinite context window.
299
- """
300
- block_size = model.get_block_size()
301
- model.eval()
302
- for k in range(steps):
303
- x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
304
- logits, _ = model(x_cond)
305
- # pluck the logits at the final step and scale by temperature
306
- logits = logits[:, -1, :] / temperature
307
- # optionally crop probabilities to only the top k options
308
- if top_k is not None:
309
- logits = top_k_logits(logits, top_k)
310
- # apply softmax to convert to probabilities
311
- probs = F.softmax(logits, dim=-1)
312
- # sample from the distribution or take the most likely
313
- if sample:
314
- ix = torch.multinomial(probs, num_samples=1)
315
- else:
316
- _, ix = torch.topk(probs, k=1, dim=-1)
317
- # append to the sequence and continue
318
- x = torch.cat((x, ix), dim=1)
319
-
320
- return x
321
-
322
-
323
- @torch.no_grad()
324
- def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
325
- top_k=None, top_p=None, callback=None):
326
- # x is conditioning
327
- sample = x
328
- cond_len = x.shape[1]
329
- past = None
330
- for n in range(steps):
331
- if callback is not None:
332
- callback(n)
333
- logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1))
334
- if past is None:
335
- past = [present]
336
- else:
337
- past.append(present)
338
- logits = logits[:, -1, :] / temperature
339
- if top_k is not None:
340
- logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
341
-
342
- probs = F.softmax(logits, dim=-1)
343
- if not sample_logits:
344
- _, x = torch.topk(probs, k=1, dim=-1)
345
- else:
346
- x = torch.multinomial(probs, num_samples=1)
347
- # append to the sequence and continue
348
- sample = torch.cat((sample, x), dim=1)
349
- del past
350
- sample = sample[:, cond_len:] # cut conditioning off
351
- return sample
352
-
353
-
354
- #### clustering utils
355
-
356
- class KMeans(nn.Module):
357
- def __init__(self, ncluster=512, nc=3, niter=10):
358
- super().__init__()
359
- self.ncluster = ncluster
360
- self.nc = nc
361
- self.niter = niter
362
- self.shape = (3,32,32)
363
- self.register_buffer("C", torch.zeros(self.ncluster,nc))
364
- self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
365
-
366
- def is_initialized(self):
367
- return self.initialized.item() == 1
368
-
369
- @torch.no_grad()
370
- def initialize(self, x):
371
- N, D = x.shape
372
- assert D == self.nc, D
373
- c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random
374
- for i in range(self.niter):
375
- # assign all pixels to the closest codebook element
376
- a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
377
- # move each codebook element to be the mean of the pixels that assigned to it
378
- c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)])
379
- # re-assign any poorly positioned codebook elements
380
- nanix = torch.any(torch.isnan(c), dim=1)
381
- ndead = nanix.sum().item()
382
- print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
383
- c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters
384
-
385
- self.C.copy_(c)
386
- self.initialized.fill_(1)
387
-
388
-
389
- def forward(self, x, reverse=False, shape=None):
390
- if not reverse:
391
- # flatten
392
- bs,c,h,w = x.shape
393
- assert c == self.nc
394
- x = x.reshape(bs,c,h*w,1)
395
- C = self.C.permute(1,0)
396
- C = C.reshape(1,c,1,self.ncluster)
397
- a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices
398
- return a
399
- else:
400
- # flatten
401
- bs, HW = x.shape
402
- """
403
- c = self.C.reshape( 1, self.nc, 1, self.ncluster)
404
- c = c[bs*[0],:,:,:]
405
- c = c[:,:,HW*[0],:]
406
- x = x.reshape(bs, 1, HW, 1)
407
- x = x[:,3*[0],:,:]
408
- x = torch.gather(c, dim=3, index=x)
409
- """
410
- x = self.C[x]
411
- x = x.permute(0,2,1)
412
- shape = shape if shape is not None else self.shape
413
- x = x.reshape(bs, *shape)
414
-
415
- return x
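
Both sampling helpers reduce to the same per-step core: divide the last position's logits by the temperature, mask everything outside the top k to -inf, softmax, and draw from the resulting categorical distribution. That single step, on random logits, is sketched here:

import torch
import torch.nn.functional as F

def top_k_logits(logits, k):
    # keep only the k largest logits per row, set the rest to -inf
    v, _ = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[:, [-1]]] = -float('Inf')
    return out

torch.manual_seed(0)
logits = torch.randn(2, 1024)            # (batch, vocab) logits at the last position
temperature, top_k = 1.0, 50

logits = top_k_logits(logits / temperature, top_k)
probs = F.softmax(logits, dim=-1)
ix = torch.multinomial(probs, num_samples=1)     # next token per sequence
print(ix.shape, int((probs > 0).sum(dim=-1)[0])) # (2, 1), typically 50 nonzero probs
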
taming-transformers/taming/modules/transformer/permuter.py DELETED
@@ -1,248 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import numpy as np
4
-
5
-
6
- class AbstractPermuter(nn.Module):
7
- def __init__(self, *args, **kwargs):
8
- super().__init__()
9
- def forward(self, x, reverse=False):
10
- raise NotImplementedError
11
-
12
-
13
- class Identity(AbstractPermuter):
14
- def __init__(self):
15
- super().__init__()
16
-
17
- def forward(self, x, reverse=False):
18
- return x
19
-
20
-
21
- class Subsample(AbstractPermuter):
22
- def __init__(self, H, W):
23
- super().__init__()
24
- C = 1
25
- indices = np.arange(H*W).reshape(C,H,W)
26
- while min(H, W) > 1:
27
- indices = indices.reshape(C,H//2,2,W//2,2)
28
- indices = indices.transpose(0,2,4,1,3)
29
- indices = indices.reshape(C*4,H//2, W//2)
30
- H = H//2
31
- W = W//2
32
- C = C*4
33
- assert H == W == 1
34
- idx = torch.tensor(indices.ravel())
35
- self.register_buffer('forward_shuffle_idx',
36
- nn.Parameter(idx, requires_grad=False))
37
- self.register_buffer('backward_shuffle_idx',
38
- nn.Parameter(torch.argsort(idx), requires_grad=False))
39
-
40
- def forward(self, x, reverse=False):
41
- if not reverse:
42
- return x[:, self.forward_shuffle_idx]
43
- else:
44
- return x[:, self.backward_shuffle_idx]
45
-
46
-
47
- def mortonify(i, j):
48
- """(i,j) index to linear morton code"""
49
- i = np.uint64(i)
50
- j = np.uint64(j)
51
-
52
- z = np.uint(0)
53
-
54
- for pos in range(32):
55
- z = (z |
56
- ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
57
- ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
58
- )
59
- return z
60
-
61
-
62
- class ZCurve(AbstractPermuter):
63
- def __init__(self, H, W):
64
- super().__init__()
65
- reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
66
- idx = np.argsort(reverseidx)
67
- idx = torch.tensor(idx)
68
- reverseidx = torch.tensor(reverseidx)
69
- self.register_buffer('forward_shuffle_idx',
70
- idx)
71
- self.register_buffer('backward_shuffle_idx',
72
- reverseidx)
73
-
74
- def forward(self, x, reverse=False):
75
- if not reverse:
76
- return x[:, self.forward_shuffle_idx]
77
- else:
78
- return x[:, self.backward_shuffle_idx]
79
-
80
-
81
- class SpiralOut(AbstractPermuter):
82
- def __init__(self, H, W):
83
- super().__init__()
84
- assert H == W
85
- size = W
86
- indices = np.arange(size*size).reshape(size,size)
87
-
88
- i0 = size//2
89
- j0 = size//2-1
90
-
91
- i = i0
92
- j = j0
93
-
94
- idx = [indices[i0, j0]]
95
- step_mult = 0
96
- for c in range(1, size//2+1):
97
- step_mult += 1
98
- # steps left
99
- for k in range(step_mult):
100
- i = i - 1
101
- j = j
102
- idx.append(indices[i, j])
103
-
104
- # step down
105
- for k in range(step_mult):
106
- i = i
107
- j = j + 1
108
- idx.append(indices[i, j])
109
-
110
- step_mult += 1
111
- if c < size//2:
112
- # step right
113
- for k in range(step_mult):
114
- i = i + 1
115
- j = j
116
- idx.append(indices[i, j])
117
-
118
- # step up
119
- for k in range(step_mult):
120
- i = i
121
- j = j - 1
122
- idx.append(indices[i, j])
123
- else:
124
- # end reached
125
- for k in range(step_mult-1):
126
- i = i + 1
127
- idx.append(indices[i, j])
128
-
129
- assert len(idx) == size*size
130
- idx = torch.tensor(idx)
131
- self.register_buffer('forward_shuffle_idx', idx)
132
- self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
133
-
134
- def forward(self, x, reverse=False):
135
- if not reverse:
136
- return x[:, self.forward_shuffle_idx]
137
- else:
138
- return x[:, self.backward_shuffle_idx]
139
-
140
-
141
- class SpiralIn(AbstractPermuter):
142
- def __init__(self, H, W):
143
- super().__init__()
144
- assert H == W
145
- size = W
146
- indices = np.arange(size*size).reshape(size,size)
147
-
148
- i0 = size//2
149
- j0 = size//2-1
150
-
151
- i = i0
152
- j = j0
153
-
154
- idx = [indices[i0, j0]]
155
- step_mult = 0
156
- for c in range(1, size//2+1):
157
- step_mult += 1
158
- # steps left
159
- for k in range(step_mult):
160
- i = i - 1
161
- j = j
162
- idx.append(indices[i, j])
163
-
164
- # step down
165
- for k in range(step_mult):
166
- i = i
167
- j = j + 1
168
- idx.append(indices[i, j])
169
-
170
- step_mult += 1
171
- if c < size//2:
172
- # step right
173
- for k in range(step_mult):
174
- i = i + 1
175
- j = j
176
- idx.append(indices[i, j])
177
-
178
- # step up
179
- for k in range(step_mult):
180
- i = i
181
- j = j - 1
182
- idx.append(indices[i, j])
183
- else:
184
- # end reached
185
- for k in range(step_mult-1):
186
- i = i + 1
187
- idx.append(indices[i, j])
188
-
189
- assert len(idx) == size*size
190
- idx = idx[::-1]
191
- idx = torch.tensor(idx)
192
- self.register_buffer('forward_shuffle_idx', idx)
193
- self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
194
-
195
- def forward(self, x, reverse=False):
196
- if not reverse:
197
- return x[:, self.forward_shuffle_idx]
198
- else:
199
- return x[:, self.backward_shuffle_idx]
200
-
201
-
202
- class Random(nn.Module):
203
- def __init__(self, H, W):
204
- super().__init__()
205
- indices = np.random.RandomState(1).permutation(H*W)
206
- idx = torch.tensor(indices.ravel())
207
- self.register_buffer('forward_shuffle_idx', idx)
208
- self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
209
-
210
- def forward(self, x, reverse=False):
211
- if not reverse:
212
- return x[:, self.forward_shuffle_idx]
213
- else:
214
- return x[:, self.backward_shuffle_idx]
215
-
216
-
217
- class AlternateParsing(AbstractPermuter):
218
- def __init__(self, H, W):
219
- super().__init__()
220
- indices = np.arange(W*H).reshape(H,W)
221
- for i in range(1, H, 2):
222
- indices[i, :] = indices[i, ::-1]
223
- idx = indices.flatten()
224
- assert len(idx) == H*W
225
- idx = torch.tensor(idx)
226
- self.register_buffer('forward_shuffle_idx', idx)
227
- self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
228
-
229
- def forward(self, x, reverse=False):
230
- if not reverse:
231
- return x[:, self.forward_shuffle_idx]
232
- else:
233
- return x[:, self.backward_shuffle_idx]
234
-
235
-
236
- if __name__ == "__main__":
237
- p0 = AlternateParsing(16, 16)
238
- print(p0.forward_shuffle_idx)
239
- print(p0.backward_shuffle_idx)
240
-
241
- x = torch.randint(0, 768, size=(11, 256))
242
- y = p0(x)
243
- xre = p0(y, reverse=True)
244
- assert torch.equal(x, xre)
245
-
246
- p1 = SpiralOut(2, 2)
247
- print(p1.forward_shuffle_idx)
248
- print(p1.backward_shuffle_idx)
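
All permuters above share one contract: forward_shuffle_idx reorders the flattened code sequence and backward_shuffle_idx = argsort(forward_shuffle_idx) undoes it exactly. The boustrophedon AlternateParsing order on a 4x4 grid, together with the round-trip check, can be reproduced standalone:

import numpy as np
import torch

H = W = 4
indices = np.arange(H * W).reshape(H, W)
for i in range(1, H, 2):                 # reverse every second row
    indices[i, :] = indices[i, ::-1]
idx = torch.tensor(indices.flatten())
inv = torch.argsort(idx)                 # the "backward" permutation

print(idx.tolist())   # [0, 1, 2, 3, 7, 6, 5, 4, 8, 9, 10, 11, 15, 14, 13, 12]
x = torch.randint(0, 1024, size=(2, H * W))
assert torch.equal(x[:, idx][:, inv], x)   # permute then un-permute is identity
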
taming-transformers/taming/modules/util.py DELETED
@@ -1,130 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
-
5
- def count_params(model):
6
- total_params = sum(p.numel() for p in model.parameters())
7
- return total_params
8
-
9
-
10
- class ActNorm(nn.Module):
11
- def __init__(self, num_features, logdet=False, affine=True,
12
- allow_reverse_init=False):
13
- assert affine
14
- super().__init__()
15
- self.logdet = logdet
16
- self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
17
- self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
18
- self.allow_reverse_init = allow_reverse_init
19
-
20
- self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
21
-
22
- def initialize(self, input):
23
- with torch.no_grad():
24
- flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
25
- mean = (
26
- flatten.mean(1)
27
- .unsqueeze(1)
28
- .unsqueeze(2)
29
- .unsqueeze(3)
30
- .permute(1, 0, 2, 3)
31
- )
32
- std = (
33
- flatten.std(1)
34
- .unsqueeze(1)
35
- .unsqueeze(2)
36
- .unsqueeze(3)
37
- .permute(1, 0, 2, 3)
38
- )
39
-
40
- self.loc.data.copy_(-mean)
41
- self.scale.data.copy_(1 / (std + 1e-6))
42
-
43
- def forward(self, input, reverse=False):
44
- if reverse:
45
- return self.reverse(input)
46
- if len(input.shape) == 2:
47
- input = input[:,:,None,None]
48
- squeeze = True
49
- else:
50
- squeeze = False
51
-
52
- _, _, height, width = input.shape
53
-
54
- if self.training and self.initialized.item() == 0:
55
- self.initialize(input)
56
- self.initialized.fill_(1)
57
-
58
- h = self.scale * (input + self.loc)
59
-
60
- if squeeze:
61
- h = h.squeeze(-1).squeeze(-1)
62
-
63
- if self.logdet:
64
- log_abs = torch.log(torch.abs(self.scale))
65
- logdet = height*width*torch.sum(log_abs)
66
- logdet = logdet * torch.ones(input.shape[0]).to(input)
67
- return h, logdet
68
-
69
- return h
70
-
71
- def reverse(self, output):
72
- if self.training and self.initialized.item() == 0:
73
- if not self.allow_reverse_init:
74
- raise RuntimeError(
75
- "Initializing ActNorm in reverse direction is "
76
- "disabled by default. Use allow_reverse_init=True to enable."
77
- )
78
- else:
79
- self.initialize(output)
80
- self.initialized.fill_(1)
81
-
82
- if len(output.shape) == 2:
83
- output = output[:,:,None,None]
84
- squeeze = True
85
- else:
86
- squeeze = False
87
-
88
- h = output / self.scale - self.loc
89
-
90
- if squeeze:
91
- h = h.squeeze(-1).squeeze(-1)
92
- return h
93
-
94
-
95
- class AbstractEncoder(nn.Module):
96
- def __init__(self):
97
- super().__init__()
98
-
99
- def encode(self, *args, **kwargs):
100
- raise NotImplementedError
101
-
102
-
103
- class Labelator(AbstractEncoder):
104
- """Net2Net Interface for Class-Conditional Model"""
105
- def __init__(self, n_classes, quantize_interface=True):
106
- super().__init__()
107
- self.n_classes = n_classes
108
- self.quantize_interface = quantize_interface
109
-
110
- def encode(self, c):
111
- c = c[:,None]
112
- if self.quantize_interface:
113
- return c, None, [None, None, c.long()]
114
- return c
115
-
116
-
117
- class SOSProvider(AbstractEncoder):
118
- # for unconditional training
119
- def __init__(self, sos_token, quantize_interface=True):
120
- super().__init__()
121
- self.sos_token = sos_token
122
- self.quantize_interface = quantize_interface
123
-
124
- def encode(self, x):
125
- # get batch size from data and replicate sos_token
126
- c = torch.ones(x.shape[0], 1)*self.sos_token
127
- c = c.long().to(x.device)
128
- if self.quantize_interface:
129
- return c, None, [None, None, c]
130
- return c
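
ActNorm's data-dependent initialization sets loc to the negative per-channel mean and scale to the reciprocal per-channel std of the first batch it sees, so that batch comes out roughly zero-mean and unit-variance per channel. Just the statistics step, on a random batch, is sketched below:

import torch

x = torch.randn(8, 16, 32, 32) * 3.0 + 1.5            # first batch, arbitrary scale/shift

flatten = x.permute(1, 0, 2, 3).contiguous().view(x.shape[1], -1)
mean = flatten.mean(1).view(1, -1, 1, 1)               # per-channel mean
std = flatten.std(1).view(1, -1, 1, 1)                 # per-channel std

loc, scale = -mean, 1.0 / (std + 1e-6)                 # the initialized parameters
h = scale * (x + loc)                                  # ActNorm forward on that batch
print(h.mean().abs().item(), h.std().item())           # ~0 and ~1
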
taming-transformers/taming/modules/vqvae/quantize.py DELETED
@@ -1,445 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import numpy as np
5
- from torch import einsum
6
- from einops import rearrange
7
-
8
-
9
- class VectorQuantizer(nn.Module):
10
- """
11
- see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
12
- ____________________________________________
13
- Discretization bottleneck part of the VQ-VAE.
14
- Inputs:
15
- - n_e : number of embeddings
16
- - e_dim : dimension of embedding
17
- - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
18
- _____________________________________________
19
- """
20
-
21
- # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
22
- # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
23
- # used wherever VectorQuantizer has been used before and is additionally
24
- # more efficient.
25
- def __init__(self, n_e, e_dim, beta):
26
- super(VectorQuantizer, self).__init__()
27
- self.n_e = n_e
28
- self.e_dim = e_dim
29
- self.beta = beta
30
-
31
- self.embedding = nn.Embedding(self.n_e, self.e_dim)
32
- self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
33
-
34
- def forward(self, z):
35
- """
36
- Inputs the output of the encoder network z and maps it to a discrete
37
- one-hot vector that is the index of the closest embedding vector e_j
38
- z (continuous) -> z_q (discrete)
39
- z.shape = (batch, channel, height, width)
40
- quantization pipeline:
41
- 1. get encoder input (B,C,H,W)
42
- 2. flatten input to (B*H*W,C)
43
- """
44
- # reshape z -> (batch, height, width, channel) and flatten
45
- z = z.permute(0, 2, 3, 1).contiguous()
46
- z_flattened = z.view(-1, self.e_dim)
47
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
48
-
49
- d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
50
- torch.sum(self.embedding.weight**2, dim=1) - 2 * \
51
- torch.matmul(z_flattened, self.embedding.weight.t())
52
-
53
- ## could possible replace this here
54
- # #\start...
55
- # find closest encodings
56
- min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
57
-
58
- min_encodings = torch.zeros(
59
- min_encoding_indices.shape[0], self.n_e).to(z)
60
- min_encodings.scatter_(1, min_encoding_indices, 1)
61
-
62
- # dtype min encodings: torch.float32
63
- # min_encodings shape: torch.Size([2048, 512])
64
- # min_encoding_indices.shape: torch.Size([2048, 1])
65
-
66
- # get quantized latent vectors
67
- z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
68
- #.........\end
69
-
70
- # with:
71
- # .........\start
72
- #min_encoding_indices = torch.argmin(d, dim=1)
73
- #z_q = self.embedding(min_encoding_indices)
74
- # ......\end......... (TODO)
75
-
76
- # compute loss for embedding
77
- loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
78
- torch.mean((z_q - z.detach()) ** 2)
79
-
80
- # preserve gradients
81
- z_q = z + (z_q - z).detach()
82
-
83
- # perplexity
84
- e_mean = torch.mean(min_encodings, dim=0)
85
- perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
86
-
87
- # reshape back to match original input shape
88
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
89
-
90
- return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
91
-
92
- def get_codebook_entry(self, indices, shape):
93
- # shape specifying (batch, height, width, channel)
94
- # TODO: check for more easy handling with nn.Embedding
95
- min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
96
- min_encodings.scatter_(1, indices[:,None], 1)
97
-
98
- # get quantized latent vectors
99
- z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
100
-
101
- if shape is not None:
102
- z_q = z_q.view(shape)
103
-
104
- # reshape back to match original input shape
105
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
106
-
107
- return z_q
108
-
109
-
110
- class GumbelQuantize(nn.Module):
111
- """
112
- credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
113
- Gumbel Softmax trick quantizer
114
- Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
115
- https://arxiv.org/abs/1611.01144
116
- """
117
- def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
118
- kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
119
- remap=None, unknown_index="random"):
120
- super().__init__()
121
-
122
- self.embedding_dim = embedding_dim
123
- self.n_embed = n_embed
124
-
125
- self.straight_through = straight_through
126
- self.temperature = temp_init
127
- self.kl_weight = kl_weight
128
-
129
- self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
130
- self.embed = nn.Embedding(n_embed, embedding_dim)
131
-
132
- self.use_vqinterface = use_vqinterface
133
-
134
- self.remap = remap
135
- if self.remap is not None:
136
- self.register_buffer("used", torch.tensor(np.load(self.remap)))
137
- self.re_embed = self.used.shape[0]
138
- self.unknown_index = unknown_index # "random" or "extra" or integer
139
- if self.unknown_index == "extra":
140
- self.unknown_index = self.re_embed
141
- self.re_embed = self.re_embed+1
142
- print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
143
- f"Using {self.unknown_index} for unknown indices.")
144
- else:
145
- self.re_embed = n_embed
146
-
147
- def remap_to_used(self, inds):
148
- ishape = inds.shape
149
- assert len(ishape)>1
150
- inds = inds.reshape(ishape[0],-1)
151
- used = self.used.to(inds)
152
- match = (inds[:,:,None]==used[None,None,...]).long()
153
- new = match.argmax(-1)
154
- unknown = match.sum(2)<1
155
- if self.unknown_index == "random":
156
- new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
157
- else:
158
- new[unknown] = self.unknown_index
159
- return new.reshape(ishape)
160
-
161
- def unmap_to_all(self, inds):
162
- ishape = inds.shape
163
- assert len(ishape)>1
164
- inds = inds.reshape(ishape[0],-1)
165
- used = self.used.to(inds)
166
- if self.re_embed > self.used.shape[0]: # extra token
167
- inds[inds>=self.used.shape[0]] = 0 # simply set to zero
168
- back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
169
- return back.reshape(ishape)
170
-
171
- def forward(self, z, temp=None, return_logits=False):
172
- # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
173
- hard = self.straight_through if self.training else True
174
- temp = self.temperature if temp is None else temp
175
-
176
- logits = self.proj(z)
177
- if self.remap is not None:
178
- # continue only with used logits
179
- full_zeros = torch.zeros_like(logits)
180
- logits = logits[:,self.used,...]
181
-
182
- soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
183
- if self.remap is not None:
184
- # go back to all entries but unused set to zero
185
- full_zeros[:,self.used,...] = soft_one_hot
186
- soft_one_hot = full_zeros
187
- z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
188
-
189
- # + kl divergence to the prior loss
190
- qy = F.softmax(logits, dim=1)
191
- diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
192
-
193
- ind = soft_one_hot.argmax(dim=1)
194
- if self.remap is not None:
195
- ind = self.remap_to_used(ind)
196
- if self.use_vqinterface:
197
- if return_logits:
198
- return z_q, diff, (None, None, ind), logits
199
- return z_q, diff, (None, None, ind)
200
- return z_q, diff, ind
201
-
202
- def get_codebook_entry(self, indices, shape):
203
- b, h, w, c = shape
204
- assert b*h*w == indices.shape[0]
205
- indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
206
- if self.remap is not None:
207
- indices = self.unmap_to_all(indices)
208
- one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
209
- z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
210
- return z_q
211
-
212
-
- class VectorQuantizer2(nn.Module):
-     """
-     Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
-     avoids costly matrix multiplications and allows for post-hoc remapping of indices.
-     """
-     # NOTE: due to a bug the beta term was applied to the wrong term. for
-     # backwards compatibility we use the buggy version by default, but you can
-     # specify legacy=False to fix it.
-     def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
-                  sane_index_shape=False, legacy=True):
-         super().__init__()
-         self.n_e = n_e
-         self.e_dim = e_dim
-         self.beta = beta
-         self.legacy = legacy
-
-         self.embedding = nn.Embedding(self.n_e, self.e_dim)
-         self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
-
-         self.remap = remap
-         if self.remap is not None:
-             self.register_buffer("used", torch.tensor(np.load(self.remap)))
-             self.re_embed = self.used.shape[0]
-             self.unknown_index = unknown_index # "random" or "extra" or integer
-             if self.unknown_index == "extra":
-                 self.unknown_index = self.re_embed
-                 self.re_embed = self.re_embed+1
-             print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
-                   f"Using {self.unknown_index} for unknown indices.")
-         else:
-             self.re_embed = n_e
-
-         self.sane_index_shape = sane_index_shape
-
-     def remap_to_used(self, inds):
-         ishape = inds.shape
-         assert len(ishape)>1
-         inds = inds.reshape(ishape[0],-1)
-         used = self.used.to(inds)
-         match = (inds[:,:,None]==used[None,None,...]).long()
-         new = match.argmax(-1)
-         unknown = match.sum(2)<1
-         if self.unknown_index == "random":
-             new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
-         else:
-             new[unknown] = self.unknown_index
-         return new.reshape(ishape)
-
-     def unmap_to_all(self, inds):
-         ishape = inds.shape
-         assert len(ishape)>1
-         inds = inds.reshape(ishape[0],-1)
-         used = self.used.to(inds)
-         if self.re_embed > self.used.shape[0]: # extra token
-             inds[inds>=self.used.shape[0]] = 0 # simply set to zero
-         back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
-         return back.reshape(ishape)
-
-     def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
-         assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
-         assert rescale_logits==False, "Only for interface compatible with Gumbel"
-         assert return_logits==False, "Only for interface compatible with Gumbel"
-         # reshape z -> (batch, height, width, channel) and flatten
-         z = rearrange(z, 'b c h w -> b h w c').contiguous()
-         z_flattened = z.view(-1, self.e_dim)
-         # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
-
-         d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
-             torch.sum(self.embedding.weight**2, dim=1) - 2 * \
-             torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
-
-         min_encoding_indices = torch.argmin(d, dim=1)
-         z_q = self.embedding(min_encoding_indices).view(z.shape)
-         perplexity = None
-         min_encodings = None
-
-         # compute loss for embedding
-         if not self.legacy:
-             loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
-                    torch.mean((z_q - z.detach()) ** 2)
-         else:
-             loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
-                    torch.mean((z_q - z.detach()) ** 2)
-
-         # preserve gradients
-         z_q = z + (z_q - z).detach()
-
-         # reshape back to match original input shape
-         z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
-
-         if self.remap is not None:
-             min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
-             min_encoding_indices = self.remap_to_used(min_encoding_indices)
-             min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
-
-         if self.sane_index_shape:
-             min_encoding_indices = min_encoding_indices.reshape(
-                 z_q.shape[0], z_q.shape[2], z_q.shape[3])
-
-         return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
-
-     def get_codebook_entry(self, indices, shape):
-         # shape specifying (batch, height, width, channel)
-         if self.remap is not None:
-             indices = indices.reshape(shape[0],-1) # add batch axis
-             indices = self.unmap_to_all(indices)
-             indices = indices.reshape(-1) # flatten again
-
-         # get quantized latent vectors
-         z_q = self.embedding(indices)
-
-         if shape is not None:
-             z_q = z_q.view(shape)
-             # reshape back to match original input shape
-             z_q = z_q.permute(0, 3, 1, 2).contiguous()
-
-         return z_q
-
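Both the constructor and the forward signature of VectorQuantizer2 are visible in this hunk, so its intended use can be sketched directly; the import path mirrors this repository layout and the sizes are illustrative:

import torch
from taming.modules.vqvae.quantize import VectorQuantizer2

# legacy=False applies beta to the commitment term, as the NOTE above recommends
quantizer = VectorQuantizer2(n_e=1024, e_dim=256, beta=0.25,
                             legacy=False, sane_index_shape=True)
z = torch.randn(2, 256, 16, 16)       # encoder output: (batch, e_dim, H, W)
z_q, loss, (perplexity, min_encodings, indices) = quantizer(z)

assert z_q.shape == z.shape           # straight-through: z + (z_q - z).detach()
assert indices.shape == (2, 16, 16)   # per-position code ids, thanks to sane_index_shape=True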
- class EmbeddingEMA(nn.Module):
-     def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
-         super().__init__()
-         self.decay = decay
-         self.eps = eps
-         weight = torch.randn(num_tokens, codebook_dim)
-         self.weight = nn.Parameter(weight, requires_grad = False)
-         self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
-         self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
-         self.update = True
-
-     def forward(self, embed_id):
-         return F.embedding(embed_id, self.weight)
-
-     def cluster_size_ema_update(self, new_cluster_size):
-         self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
-
-     def embed_avg_ema_update(self, new_embed_avg):
-         self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
-
-     def weight_update(self, num_tokens):
-         n = self.cluster_size.sum()
-         smoothed_cluster_size = (
-             (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
-         )
-         # normalize embedding average with smoothed cluster size
-         embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
-         self.weight.data.copy_(embed_normalized)
-
-
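Written out, the exponential-moving-average bookkeeping implemented by EmbeddingEMA is (γ = decay, ε = eps, K = num_tokens, n_i = number of encoder outputs assigned to code i in the current batch):

N_i \leftarrow \gamma N_i + (1-\gamma)\, n_i,
\qquad
m_i \leftarrow \gamma m_i + (1-\gamma) \sum_{j:\ \mathrm{enc}(z_j)=i} z_j,
\qquad
e_i = \frac{m_i}{\tilde N_i},
\quad
\tilde N_i = \frac{N_i + \varepsilon}{\sum_k N_k + K\varepsilon} \sum_k N_k

The Laplace smoothing in the last step keeps rarely used codes from producing a division by zero when their EMA cluster size decays towards zero.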
- class EMAVectorQuantizer(nn.Module):
-     def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
-                  remap=None, unknown_index="random"):
-         super().__init__()
-         # map constructor arguments onto the names used internally and by EmbeddingEMA
-         self.codebook_dim = embedding_dim
-         self.num_tokens = n_embed
-         self.beta = beta
-         self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
-
-         self.remap = remap
-         if self.remap is not None:
-             self.register_buffer("used", torch.tensor(np.load(self.remap)))
-             self.re_embed = self.used.shape[0]
-             self.unknown_index = unknown_index # "random" or "extra" or integer
-             if self.unknown_index == "extra":
-                 self.unknown_index = self.re_embed
-                 self.re_embed = self.re_embed+1
-             print(f"Remapping {self.num_tokens} indices to {self.re_embed} indices. "
-                   f"Using {self.unknown_index} for unknown indices.")
-         else:
-             self.re_embed = n_embed
-
-     def remap_to_used(self, inds):
-         ishape = inds.shape
-         assert len(ishape)>1
-         inds = inds.reshape(ishape[0],-1)
-         used = self.used.to(inds)
-         match = (inds[:,:,None]==used[None,None,...]).long()
-         new = match.argmax(-1)
-         unknown = match.sum(2)<1
-         if self.unknown_index == "random":
-             new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
-         else:
-             new[unknown] = self.unknown_index
-         return new.reshape(ishape)
-
-     def unmap_to_all(self, inds):
-         ishape = inds.shape
-         assert len(ishape)>1
-         inds = inds.reshape(ishape[0],-1)
-         used = self.used.to(inds)
-         if self.re_embed > self.used.shape[0]: # extra token
-             inds[inds>=self.used.shape[0]] = 0 # simply set to zero
-         back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
-         return back.reshape(ishape)
-
-     def forward(self, z):
-         # reshape z -> (batch, height, width, channel) and flatten
-         # z, 'b c h w -> b h w c'
-         z = rearrange(z, 'b c h w -> b h w c')
-         z_flattened = z.reshape(-1, self.codebook_dim)
-
-         # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
-         d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
-             self.embedding.weight.pow(2).sum(dim=1) - 2 * \
-             torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
-
-         encoding_indices = torch.argmin(d, dim=1)
-
-         z_q = self.embedding(encoding_indices).view(z.shape)
-         encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
-         avg_probs = torch.mean(encodings, dim=0)
-         perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
-
-         if self.training and self.embedding.update:
-             # EMA cluster size
-             encodings_sum = encodings.sum(0)
-             self.embedding.cluster_size_ema_update(encodings_sum)
-             # EMA embedding average
-             embed_sum = encodings.transpose(0,1) @ z_flattened
-             self.embedding.embed_avg_ema_update(embed_sum)
-             # normalize embed_avg and update weight
-             self.embedding.weight_update(self.num_tokens)
-
-         # compute loss for embedding
-         loss = self.beta * F.mse_loss(z_q.detach(), z)
-
-         # preserve gradients
-         z_q = z + (z_q - z).detach()
-
-         # reshape back to match original input shape
-         # z_q, 'b h w c -> b c h w'
-         z_q = rearrange(z_q, 'b h w c -> b c h w')
-         return z_q, loss, (perplexity, encodings, encoding_indices)
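A short training-mode sketch of the EMA quantizer, assuming the constructor mapping fixed above (n_embed → num_tokens, embedding_dim → codebook_dim); all sizes are illustrative:

import torch
from taming.modules.vqvae.quantize import EMAVectorQuantizer

quantizer = EMAVectorQuantizer(n_embed=512, embedding_dim=64, beta=0.25, decay=0.99)
quantizer.train()                                   # EMA updates only run when self.training is True

z = torch.randn(4, 64, 8, 8, requires_grad=True)    # encoder output: (batch, embedding_dim, H, W)
z_q, commit_loss, (perplexity, encodings, indices) = quantizer(z)

# the codebook is updated in-place by EMA (its weights have requires_grad=False);
# only the commitment loss sends gradients back to the encoder through the straight-through z_q
commit_loss.backward()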
taming-transformers/taming/util.py DELETED
@@ -1,157 +0,0 @@
- import os, hashlib
- import requests
- from tqdm import tqdm
-
- URL_MAP = {
-     "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
- }
-
- CKPT_MAP = {
-     "vgg_lpips": "vgg.pth"
- }
-
- MD5_MAP = {
-     "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
- }
-
-
- def download(url, local_path, chunk_size=1024):
-     os.makedirs(os.path.split(local_path)[0], exist_ok=True)
-     with requests.get(url, stream=True) as r:
-         total_size = int(r.headers.get("content-length", 0))
-         with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
-             with open(local_path, "wb") as f:
-                 for data in r.iter_content(chunk_size=chunk_size):
-                     if data:
-                         f.write(data)
-                         pbar.update(chunk_size)
-
-
- def md5_hash(path):
-     with open(path, "rb") as f:
-         content = f.read()
-     return hashlib.md5(content).hexdigest()
-
-
- def get_ckpt_path(name, root, check=False):
-     assert name in URL_MAP
-     path = os.path.join(root, CKPT_MAP[name])
-     if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
-         print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
-         download(URL_MAP[name], path)
-         md5 = md5_hash(path)
-         assert md5 == MD5_MAP[name], md5
-     return path
-
-
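A minimal sketch of fetching the vgg_lpips checkpoint with these helpers; the target directory is an arbitrary example, not a path mandated by the code above:

from taming.util import get_ckpt_path

# downloads vgg.pth into the given directory if it is missing, and re-downloads on an MD5 mismatch
ckpt = get_ckpt_path("vgg_lpips", "cache/lpips", check=True)
print(ckpt)  # e.g. cache/lpips/vgg.pth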
- class KeyNotFoundError(Exception):
-     def __init__(self, cause, keys=None, visited=None):
-         self.cause = cause
-         self.keys = keys
-         self.visited = visited
-         messages = list()
-         if keys is not None:
-             messages.append("Key not found: {}".format(keys))
-         if visited is not None:
-             messages.append("Visited: {}".format(visited))
-         messages.append("Cause:\n{}".format(cause))
-         message = "\n".join(messages)
-         super().__init__(message)
-
-
- def retrieve(
-     list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
- ):
-     """Given a nested list or dict return the desired value at key expanding
-     callable nodes if necessary and :attr:`expand` is ``True``. The expansion
-     is done in-place.
-
-     Parameters
-     ----------
-     list_or_dict : list or dict
-         Possibly nested list or dictionary.
-     key : str
-         key/to/value, path like string describing all keys necessary to
-         consider to get to the desired value. List indices can also be
-         passed here.
-     splitval : str
-         String that defines the delimiter between keys of the
-         different depth levels in `key`.
-     default : obj
-         Value returned if :attr:`key` is not found.
-     expand : bool
-         Whether to expand callable nodes on the path or not.
-
-     Returns
-     -------
-     The desired value or if :attr:`default` is not ``None`` and the
-     :attr:`key` is not found returns ``default``.
-
-     Raises
-     ------
-     Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
-     ``None``.
-     """
-
-     keys = key.split(splitval)
-
-     success = True
-     try:
-         visited = []
-         parent = None
-         last_key = None
-         for key in keys:
-             if callable(list_or_dict):
-                 if not expand:
-                     raise KeyNotFoundError(
-                         ValueError(
-                             "Trying to get past callable node with expand=False."
-                         ),
-                         keys=keys,
-                         visited=visited,
-                     )
-                 list_or_dict = list_or_dict()
-                 parent[last_key] = list_or_dict
-
-             last_key = key
-             parent = list_or_dict
-
-             try:
-                 if isinstance(list_or_dict, dict):
-                     list_or_dict = list_or_dict[key]
-                 else:
-                     list_or_dict = list_or_dict[int(key)]
-             except (KeyError, IndexError, ValueError) as e:
-                 raise KeyNotFoundError(e, keys=keys, visited=visited)
-
-             visited += [key]
-         # final expansion of retrieved value
-         if expand and callable(list_or_dict):
-             list_or_dict = list_or_dict()
-             parent[last_key] = list_or_dict
-     except KeyNotFoundError as e:
-         if default is None:
-             raise e
-         else:
-             list_or_dict = default
-             success = False
-
-     if not pass_success:
-         return list_or_dict
-     else:
-         return list_or_dict, success
-
-
- if __name__ == "__main__":
-     config = {"keya": "a",
-               "keyb": "b",
-               "keyc":
-                   {"cc1": 1,
-                    "cc2": 2,
-                    }
-               }
-     from omegaconf import OmegaConf
-     config = OmegaConf.create(config)
-     print(config)
-     retrieve(config, "keya")
-
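Grounded in the docstring and the __main__ example above, retrieve resolves slash-separated paths through nested dicts and lists; a short sketch with illustrative values:

from taming.util import retrieve

config = {"model": {"params": {"embed_dim": 256}},
          "data": ["train", "val"]}

retrieve(config, "model/params/embed_dim")              # -> 256
retrieve(config, "data/1")                              # list indices in the path -> "val"
retrieve(config, "model/params/missing", default=0)     # -> 0 instead of raising
value, found = retrieve(config, "data/5", default="n/a",
                        pass_success=True)              # -> ("n/a", False)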
taming-transformers/taming_transformers.egg-info/PKG-INFO DELETED
@@ -1,10 +0,0 @@
- Metadata-Version: 2.1
- Name: taming-transformers
- Version: 0.0.1
- Summary: Taming Transformers for High-Resolution Image Synthesis
- Home-page: UNKNOWN
- License: UNKNOWN
- Platform: UNKNOWN
-
- UNKNOWN
-
 
 
 
 
 
 
 
 
 
 
 
taming-transformers/taming_transformers.egg-info/SOURCES.txt DELETED
@@ -1,7 +0,0 @@
- README.md
- setup.py
- taming_transformers.egg-info/PKG-INFO
- taming_transformers.egg-info/SOURCES.txt
- taming_transformers.egg-info/dependency_links.txt
- taming_transformers.egg-info/requires.txt
- taming_transformers.egg-info/top_level.txt
 
 
 
 
 
 
 
 
taming-transformers/taming_transformers.egg-info/dependency_links.txt DELETED
@@ -1 +0,0 @@
-
 
 
taming-transformers/taming_transformers.egg-info/requires.txt DELETED
@@ -1,3 +0,0 @@
- torch
- numpy
- tqdm
 
 
 
 
taming-transformers/taming_transformers.egg-info/top_level.txt DELETED
@@ -1 +0,0 @@
-