Chao Xu committed
Commit 0e93edd · 1 Parent(s): c534da1

test rm taming
- taming-transformers/.gitignore +0 -2
- taming-transformers/License.txt +0 -19
- taming-transformers/README.md +0 -410
- taming-transformers/configs/coco_cond_stage.yaml +0 -49
- taming-transformers/configs/coco_scene_images_transformer.yaml +0 -80
- taming-transformers/configs/custom_vqgan.yaml +0 -43
- taming-transformers/configs/drin_transformer.yaml +0 -77
- taming-transformers/configs/faceshq_transformer.yaml +0 -61
- taming-transformers/configs/faceshq_vqgan.yaml +0 -42
- taming-transformers/configs/imagenet_vqgan.yaml +0 -42
- taming-transformers/configs/imagenetdepth_vqgan.yaml +0 -41
- taming-transformers/configs/open_images_scene_images_transformer.yaml +0 -86
- taming-transformers/configs/sflckr_cond_stage.yaml +0 -43
- taming-transformers/environment.yaml +0 -25
- taming-transformers/main.py +0 -585
- taming-transformers/scripts/extract_depth.py +0 -112
- taming-transformers/scripts/extract_segmentation.py +0 -130
- taming-transformers/scripts/extract_submodel.py +0 -17
- taming-transformers/scripts/make_samples.py +0 -292
- taming-transformers/scripts/make_scene_samples.py +0 -198
- taming-transformers/scripts/sample_conditional.py +0 -355
- taming-transformers/scripts/sample_fast.py +0 -260
- taming-transformers/setup.py +0 -13
- taming-transformers/taming/lr_scheduler.py +0 -34
- taming-transformers/taming/models/cond_transformer.py +0 -352
- taming-transformers/taming/models/dummy_cond_stage.py +0 -22
- taming-transformers/taming/models/vqgan.py +0 -404
- taming-transformers/taming/modules/diffusionmodules/model.py +0 -776
- taming-transformers/taming/modules/discriminator/model.py +0 -67
- taming-transformers/taming/modules/losses/__init__.py +0 -2
- taming-transformers/taming/modules/losses/lpips.py +0 -123
- taming-transformers/taming/modules/losses/segmentation.py +0 -22
- taming-transformers/taming/modules/losses/vqperceptual.py +0 -136
- taming-transformers/taming/modules/misc/coord.py +0 -31
- taming-transformers/taming/modules/transformer/mingpt.py +0 -415
- taming-transformers/taming/modules/transformer/permuter.py +0 -248
- taming-transformers/taming/modules/util.py +0 -130
- taming-transformers/taming/modules/vqvae/quantize.py +0 -445
- taming-transformers/taming/util.py +0 -157
- taming-transformers/taming_transformers.egg-info/PKG-INFO +0 -10
- taming-transformers/taming_transformers.egg-info/SOURCES.txt +0 -7
- taming-transformers/taming_transformers.egg-info/dependency_links.txt +0 -1
- taming-transformers/taming_transformers.egg-info/requires.txt +0 -3
- taming-transformers/taming_transformers.egg-info/top_level.txt +0 -1
taming-transformers/.gitignore
DELETED
@@ -1,2 +0,0 @@
-assets/
-data/
taming-transformers/License.txt
DELETED
@@ -1,19 +0,0 @@
-Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
-IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
-DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
-OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
-OR OTHER DEALINGS IN THE SOFTWARE./
taming-transformers/README.md
DELETED
@@ -1,410 +0,0 @@
-# Taming Transformers for High-Resolution Image Synthesis
-##### CVPR 2021 (Oral)
-![teaser](assets/mountain.jpeg)
-
-[**Taming Transformers for High-Resolution Image Synthesis**](https://compvis.github.io/taming-transformers/)<br/>
-[Patrick Esser](https://github.com/pesser)\*,
-[Robin Rombach](https://github.com/rromb)\*,
-[Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)<br/>
-\* equal contribution
-
-**tl;dr** We combine the efficiency of convolutional approaches with the expressivity of transformers by introducing a convolutional VQGAN, which learns a codebook of context-rich visual parts, whose composition is modeled with an autoregressive transformer.
-
-![teaser](assets/teaser.png)
-[arXiv](https://arxiv.org/abs/2012.09841) | [BibTeX](#bibtex) | [Project Page](https://compvis.github.io/taming-transformers/)
-
-
-### News
-#### 2022
-- More pretrained VQGANs (e.g. a f8-model with only 256 codebook entries) are available in our new work on [Latent Diffusion Models](https://github.com/CompVis/latent-diffusion).
-- Added scene synthesis models as proposed in the paper [High-Resolution Complex Scene Synthesis with Transformers](https://arxiv.org/abs/2105.06458), see [this section](#scene-image-synthesis).
-#### 2021
-- Thanks to [rom1504](https://github.com/rom1504) it is now easy to [train a VQGAN on your own datasets](#training-on-custom-data).
-- Included a bugfix for the quantizer. For backward compatibility it is
-disabled by default (which corresponds to always training with `beta=1.0`).
-Use `legacy=False` in the quantizer config to enable it.
-Thanks [richcmwang](https://github.com/richcmwang) and [wcshin-git](https://github.com/wcshin-git)!
-- Our paper received an update: See https://arxiv.org/abs/2012.09841v3 and the corresponding changelog.
-- Added a pretrained, [1.4B transformer model](https://k00.fr/s511rwcv) trained for class-conditional ImageNet synthesis, which obtains state-of-the-art FID scores among autoregressive approaches and outperforms BigGAN.
-- Added pretrained, unconditional models on [FFHQ](https://k00.fr/yndvfu95) and [CelebA-HQ](https://k00.fr/2xkmielf).
-- Added accelerated sampling via caching of keys/values in the self-attention operation, used in `scripts/sample_fast.py`.
-- Added a checkpoint of a [VQGAN](https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/) trained with f8 compression and Gumbel-Quantization.
-See also our updated [reconstruction notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb).
-- We added a [colab notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb) which compares two VQGANs and OpenAI's [DALL-E](https://github.com/openai/DALL-E). See also [this section](#more-resources).
-- We now include an overview of pretrained models in [Tab.1](#overview-of-pretrained-models). We added models for [COCO](#coco) and [ADE20k](#ade20k).
-- The streamlit demo now supports image completions.
-- We now include a couple of examples from the D-RIN dataset so you can run the
-[D-RIN demo](#d-rin) without preparing the dataset first.
-- You can now jump right into sampling with our [Colab quickstart notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/taming-transformers.ipynb).
-
-## Requirements
-A suitable [conda](https://conda.io/) environment named `taming` can be created
-and activated with:
-
-```
-conda env create -f environment.yaml
-conda activate taming
-```
-## Overview of pretrained models
-The following table provides an overview of all models that are currently available.
-FID scores were evaluated using [torch-fidelity](https://github.com/toshas/torch-fidelity).
-For reference, we also include a link to the recently released autoencoder of the [DALL-E](https://github.com/openai/DALL-E) model.
-See the corresponding [colab
-notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb)
-for a comparison and discussion of reconstruction capabilities.
-
-| Dataset | FID vs train | FID vs val | Link | Samples (256x256) | Comments
-| ------------- | ------------- | ------------- |------------- | ------------- |------------- |
-| FFHQ (f=16) | 9.6 | -- | [ffhq_transformer](https://k00.fr/yndvfu95) | [ffhq_samples](https://k00.fr/j626x093) |
-| CelebA-HQ (f=16) | 10.2 | -- | [celebahq_transformer](https://k00.fr/2xkmielf) | [celebahq_samples](https://k00.fr/j626x093) |
-| ADE20K (f=16) | -- | 35.5 | [ade20k_transformer](https://k00.fr/ot46cksa) | [ade20k_samples.zip](https://heibox.uni-heidelberg.de/f/70bb78cbaf844501b8fb/) [2k] | evaluated on val split (2k images)
-| COCO-Stuff (f=16) | -- | 20.4 | [coco_transformer](https://k00.fr/2zz6i2ce) | [coco_samples.zip](https://heibox.uni-heidelberg.de/f/a395a9be612f4a7a8054/) [5k] | evaluated on val split (5k images)
-| ImageNet (cIN) (f=16) | 15.98/15.78/6.59/5.88/5.20 | -- | [cin_transformer](https://k00.fr/s511rwcv) | [cin_samples](https://k00.fr/j626x093) | different decoding hyperparameters |
-| | | | || |
-| FacesHQ (f=16) | -- | -- | [faceshq_transformer](https://k00.fr/qqfl2do8)
-| S-FLCKR (f=16) | -- | -- | [sflckr](https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/)
-| D-RIN (f=16) | -- | -- | [drin_transformer](https://k00.fr/39jcugc5)
-| | | | | || |
-| VQGAN ImageNet (f=16), 1024 | 10.54 | 7.94 | [vqgan_imagenet_f16_1024](https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/) | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
-| VQGAN ImageNet (f=16), 16384 | 7.41 | 4.98 |[vqgan_imagenet_f16_16384](https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/) | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
-| VQGAN OpenImages (f=8), 256 | -- | 1.49 |https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip | --- | Reconstruction-FIDs. Available via [latent diffusion](https://github.com/CompVis/latent-diffusion).
-| VQGAN OpenImages (f=8), 16384 | -- | 1.14 |https://ommer-lab.com/files/latent-diffusion/vq-f8.zip | --- | Reconstruction-FIDs. Available via [latent diffusion](https://github.com/CompVis/latent-diffusion)
-| VQGAN OpenImages (f=8), 8192, GumbelQuantization | 3.24 | 1.49 |[vqgan_gumbel_f8](https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/) | --- | Reconstruction-FIDs.
-| | | | | || |
-| DALL-E dVAE (f=8), 8192, GumbelQuantization | 33.88 | 32.01 | https://github.com/openai/DALL-E | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
-
-
-## Running pretrained models
-
-The commands below will start a streamlit demo which supports sampling at
-different resolutions and image completions. To run a non-interactive version
-of the sampling process, replace `streamlit run scripts/sample_conditional.py --`
-by `python scripts/make_samples.py --outdir <path_to_write_samples_to>` and
-keep the remaining command line arguments.
-
-To sample from unconditional or class-conditional models,
-run `python scripts/sample_fast.py -r <path/to/config_and_checkpoint>`.
-We describe below how to use this script to sample from the ImageNet, FFHQ, and CelebA-HQ models,
-respectively.
-
-### S-FLCKR
-![teaser](assets/sunset_and_ocean.jpg)
-
-You can also [run this model in a Colab
-notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/taming-transformers.ipynb),
-which includes all necessary steps to start sampling.
-
-Download the
-[2020-11-09T13-31-51_sflckr](https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/)
-folder and place it into `logs`. Then, run
-```
-streamlit run scripts/sample_conditional.py -- -r logs/2020-11-09T13-31-51_sflckr/
-```
-
-### ImageNet
-![teaser](assets/imagenet.png)
-
-Download the [2021-04-03T19-39-50_cin_transformer](https://k00.fr/s511rwcv)
-folder and place it into logs. Sampling from the class-conditional ImageNet
-model does not require any data preparation. To produce 50 samples for each of
-the 1000 classes of ImageNet, with k=600 for top-k sampling, p=0.92 for nucleus
-sampling and temperature t=1.0, run
-
-```
-python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25
-```
-
-To restrict the model to certain classes, provide them via the `--classes` argument, separated by
-commas. For example, to sample 50 *ostriches*, *border collies* and *whiskey jugs*, run
-
-```
-python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25 --classes 9,232,901
-```
-We recommend experimenting with the autoregressive decoding parameters (top-k, top-p and temperature) for best results.
-
-### FFHQ/CelebA-HQ
-
-Download the [2021-04-23T18-19-01_ffhq_transformer](https://k00.fr/yndvfu95) and
-[2021-04-23T18-11-19_celebahq_transformer](https://k00.fr/2xkmielf)
-folders and place them into logs.
-Again, sampling from these unconditional models does not require any data preparation.
-To produce 50000 samples, with k=250 for top-k sampling,
-p=1.0 for nucleus sampling and temperature t=1.0, run
-
-```
-python scripts/sample_fast.py -r logs/2021-04-23T18-19-01_ffhq_transformer/
-```
-for FFHQ and
-
-```
-python scripts/sample_fast.py -r logs/2021-04-23T18-11-19_celebahq_transformer/
-```
-to sample from the CelebA-HQ model.
-For both models it can be advantageous to vary the top-k/top-p parameters for sampling.
-
-### FacesHQ
-![teaser](assets/faceshq.jpg)
-
-Download [2020-11-13T21-41-45_faceshq_transformer](https://k00.fr/qqfl2do8) and
-place it into `logs`. Follow the data preparation steps for
-[CelebA-HQ](#celeba-hq) and [FFHQ](#ffhq). Run
-```
-streamlit run scripts/sample_conditional.py -- -r logs/2020-11-13T21-41-45_faceshq_transformer/
-```
-
-### D-RIN
-![teaser](assets/drin.jpg)
-
-Download [2020-11-20T12-54-32_drin_transformer](https://k00.fr/39jcugc5) and
-place it into `logs`. To run the demo on a couple of example depth maps
-included in the repository, run
-
-```
-streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.imagenet.DRINExamples}}}"
-```
-
-To run the demo on the complete validation set, first follow the data preparation steps for
-[ImageNet](#imagenet) and then run
-```
-streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/
-```
-
-### COCO
-Download [2021-01-20T16-04-20_coco_transformer](https://k00.fr/2zz6i2ce) and
-place it into `logs`. To run the demo on a couple of example segmentation maps
-included in the repository, run
-
-```
-streamlit run scripts/sample_conditional.py -- -r logs/2021-01-20T16-04-20_coco_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.coco.Examples}}}"
-```
-
-### ADE20k
-Download [2020-11-20T21-45-44_ade20k_transformer](https://k00.fr/ot46cksa) and
-place it into `logs`. To run the demo on a couple of example segmentation maps
-included in the repository, run
-
-```
-streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T21-45-44_ade20k_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.ade20k.Examples}}}"
-```
-
-## Scene Image Synthesis
-![teaser](assets/scene_images_samples.svg)
-Scene image generation based on bounding box conditionals as done in our CVPR2021 AI4CC workshop paper [High-Resolution Complex Scene Synthesis with Transformers](https://arxiv.org/abs/2105.06458) (see talk on [workshop page](https://visual.cs.brown.edu/workshops/aicc2021/#awards)). Supporting the datasets COCO and Open Images.
-
-### Training
-Download first-stage models [COCO-8k-VQGAN](https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/) for COCO or [COCO/Open-Images-8k-VQGAN](https://heibox.uni-heidelberg.de/f/461d9a9f4fcf48ab84f4/) for Open Images.
-Change `ckpt_path` in `data/coco_scene_images_transformer.yaml` and `data/open_images_scene_images_transformer.yaml` to point to the downloaded first-stage models.
-Download the full COCO/OI datasets and adapt `data_path` in the same files, unless working with the 100 files provided for training and validation suits your needs already.
-
-Code can be run with
-`python main.py --base configs/coco_scene_images_transformer.yaml -t True --gpus 0,`
-or
-`python main.py --base configs/open_images_scene_images_transformer.yaml -t True --gpus 0,`
-
-### Sampling
-Train a model as described above or download a pre-trained model:
-- [Open Images 1 billion parameter model](https://drive.google.com/file/d/1FEK-Z7hyWJBvFWQF50pzSK9y1W_CJEig/view?usp=sharing) available that trained 100 epochs. On 256x256 pixels, FID 41.48±0.21, SceneFID 14.60±0.15, Inception Score 18.47±0.27. The model was trained with 2d crops of images and is thus well-prepared for the task of generating high-resolution images, e.g. 512x512.
-- [Open Images distilled version of the above model with 125 million parameters](https://drive.google.com/file/d/1xf89g0mc78J3d8Bx5YhbK4tNRNlOoYaO) allows for sampling on smaller GPUs (4 GB is enough for sampling 256x256 px images). Model was trained for 60 epochs with 10% soft loss, 90% hard loss. On 256x256 pixels, FID 43.07±0.40, SceneFID 15.93±0.19, Inception Score 17.23±0.11.
-- [COCO 30 epochs](https://heibox.uni-heidelberg.de/f/0d0b2594e9074c7e9a33/)
-- [COCO 60 epochs](https://drive.google.com/file/d/1bInd49g2YulTJBjU32Awyt5qnzxxG5U9/) (find model statistics for both COCO versions in `assets/coco_scene_images_training.svg`)
-
-When downloading a pre-trained model, remember to change `ckpt_path` in `configs/*project.yaml` to point to your downloaded first-stage model (see ->Training).
-
-Scene image generation can be run with
-`python scripts/make_scene_samples.py --outdir=/some/outdir -r /path/to/pretrained/model --resolution=512,512`
-
-
-## Training on custom data
-
-Training on your own dataset can be beneficial to get better tokens and hence better images for your domain.
-Those are the steps to follow to make this work:
-1. install the repo with `conda env create -f environment.yaml`, `conda activate taming` and `pip install -e .`
-1. put your .jpg files in a folder `your_folder`
-2. create 2 text files a `xx_train.txt` and `xx_test.txt` that point to the files in your training and test set respectively (for example `find $(pwd)/your_folder -name "*.jpg" > train.txt`)
-3. adapt `configs/custom_vqgan.yaml` to point to these 2 files
-4. run `python main.py --base configs/custom_vqgan.yaml -t True --gpus 0,1` to
-train on two GPUs. Use `--gpus 0,` (with a trailing comma) to train on a single GPU.
-
-## Data Preparation
-
-### ImageNet
-The code will try to download (through [Academic
-Torrents](http://academictorrents.com/)) and prepare ImageNet the first time it
-is used. However, since ImageNet is quite large, this requires a lot of disk
-space and time. If you already have ImageNet on your disk, you can speed things
-up by putting the data into
-`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` (which defaults to
-`~/.cache/autoencoders/data/ILSVRC2012_{split}/data/`), where `{split}` is one
-of `train`/`validation`. It should have the following structure:
-
-```
-${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
-├── n01440764
-│ ├── n01440764_10026.JPEG
-│ ├── n01440764_10027.JPEG
-│ ├── ...
-├── n01443537
-│ ├── n01443537_10007.JPEG
-│ ├── n01443537_10014.JPEG
-│ ├── ...
-├── ...
-```
-
-If you haven't extracted the data, you can also place
-`ILSVRC2012_img_train.tar`/`ILSVRC2012_img_val.tar` (or symlinks to them) into
-`${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` /
-`${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, which will then be
-extracted into above structure without downloading it again. Note that this
-will only happen if neither a folder
-`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor a file
-`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exist. Remove them
-if you want to force running the dataset preparation again.
-
-You will then need to prepare the depth data using
-[MiDaS](https://github.com/intel-isl/MiDaS). Create a symlink
-`data/imagenet_depth` pointing to a folder with two subfolders `train` and
-`val`, each mirroring the structure of the corresponding ImageNet folder
-described above and containing a `png` file for each of ImageNet's `JPEG`
-files. The `png` encodes `float32` depth values obtained from MiDaS as RGBA
-images. We provide the script `scripts/extract_depth.py` to generate this data.
-**Please note** that this script uses [MiDaS via PyTorch
-Hub](https://pytorch.org/hub/intelisl_midas_v2/). When we prepared the data,
-the hub provided the [MiDaS
-v2.0](https://github.com/intel-isl/MiDaS/releases/tag/v2) version, but now it
-provides a v2.1 version. We haven't tested our models with depth maps obtained
-via v2.1 and if you want to make sure that things work as expected, you must
-adjust the script to make sure it explicitly uses
-[v2.0](https://github.com/intel-isl/MiDaS/releases/tag/v2)!
-
-### CelebA-HQ
-Create a symlink `data/celebahq` pointing to a folder containing the `.npy`
-files of CelebA-HQ (instructions to obtain them can be found in the [PGGAN
-repository](https://github.com/tkarras/progressive_growing_of_gans)).
-
-### FFHQ
-Create a symlink `data/ffhq` pointing to the `images1024x1024` folder obtained
-from the [FFHQ repository](https://github.com/NVlabs/ffhq-dataset).
-
-### S-FLCKR
-Unfortunately, we are not allowed to distribute the images we collected for the
-S-FLCKR dataset and can therefore only give a description how it was produced.
-There are many resources on [collecting images from the
-web](https://github.com/adrianmrit/flickrdatasets) to get started.
-We collected sufficiently large images from [flickr](https://www.flickr.com)
-(see `data/flickr_tags.txt` for a full list of tags used to find images)
-and various [subreddits](https://www.reddit.com/r/sfwpornnetwork/wiki/network)
-(see `data/subreddits.txt` for all subreddits that were used).
-Overall, we collected 107625 images, and split them randomly into 96861
-training images and 10764 validation images. We then obtained segmentation
-masks for each image using [DeepLab v2](https://arxiv.org/abs/1606.00915)
-trained on [COCO-Stuff](https://arxiv.org/abs/1612.03716). We used a [PyTorch
-reimplementation](https://github.com/kazuto1011/deeplab-pytorch) and include an
-example script for this process in `scripts/extract_segmentation.py`.
-
-### COCO
-Create a symlink `data/coco` containing the images from the 2017 split in
-`train2017` and `val2017`, and their annotations in `annotations`. Files can be
-obtained from the [COCO webpage](https://cocodataset.org/). In addition, we use
-the [Stuff+thing PNG-style annotations on COCO 2017
-trainval](http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip)
-annotations from [COCO-Stuff](https://github.com/nightrome/cocostuff), which
-should be placed under `data/cocostuffthings`.
-
-### ADE20k
-Create a symlink `data/ade20k_root` containing the contents of
-[ADEChallengeData2016.zip](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip)
-from the [MIT Scene Parsing Benchmark](http://sceneparsing.csail.mit.edu/).
-
-## Training models
-
-### FacesHQ
-
-Train a VQGAN with
-```
-python main.py --base configs/faceshq_vqgan.yaml -t True --gpus 0,
-```
-
-Then, adjust the checkpoint path of the config key
-`model.params.first_stage_config.params.ckpt_path` in
-`configs/faceshq_transformer.yaml` (or download
-[2020-11-09T13-33-36_faceshq_vqgan](https://k00.fr/uxy5usa9) and place into `logs`, which
-corresponds to the preconfigured checkpoint path), then run
-```
-python main.py --base configs/faceshq_transformer.yaml -t True --gpus 0,
-```
-
-### D-RIN
-
-Train a VQGAN on ImageNet with
-```
-python main.py --base configs/imagenet_vqgan.yaml -t True --gpus 0,
-```
-
-or download a pretrained one from [2020-09-23T17-56-33_imagenet_vqgan](https://k00.fr/u0j2dtac)
-and place under `logs`. If you trained your own, adjust the path in the config
-key `model.params.first_stage_config.params.ckpt_path` of
-`configs/drin_transformer.yaml`.
-
-Train a VQGAN on Depth Maps of ImageNet with
-```
-python main.py --base configs/imagenetdepth_vqgan.yaml -t True --gpus 0,
-```
-
-or download a pretrained one from [2020-11-03T15-34-24_imagenetdepth_vqgan](https://k00.fr/55rlxs6i)
-and place under `logs`. If you trained your own, adjust the path in the config
-key `model.params.cond_stage_config.params.ckpt_path` of
-`configs/drin_transformer.yaml`.
-
-To train the transformer, run
-```
-python main.py --base configs/drin_transformer.yaml -t True --gpus 0,
-```
-
-## More Resources
-### Comparing Different First Stage Models
-The reconstruction and compression capabilities of different first stage models can be analyzed in this [colab notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb).
-In particular, the notebook compares two VQGANs with a downsampling factor of f=16 for each and codebook dimensionality of 1024 and 16384,
-a VQGAN with f=8 and 8192 codebook entries and the discrete autoencoder of OpenAI's [DALL-E](https://github.com/openai/DALL-E) (which has f=8 and 8192
-codebook entries).
-![firststages1](assets/first_stage_squirrels.png)
-![firststages2](assets/first_stage_mushrooms.png)
-
-### Other
-- A [video summary](https://www.youtube.com/watch?v=o7dqGcLDf0A&feature=emb_imp_woyt) by [Two Minute Papers](https://www.youtube.com/channel/UCbfYPyITQ-7l4upoX8nvctg).
-- A [video summary](https://www.youtube.com/watch?v=-wDSDtIAyWQ) by [Gradient Dude](https://www.youtube.com/c/GradientDude/about).
-- A [weights and biases report summarizing the paper](https://wandb.ai/ayush-thakur/taming-transformer/reports/-Overview-Taming-Transformers-for-High-Resolution-Image-Synthesis---Vmlldzo0NjEyMTY)
-by [ayulockin](https://github.com/ayulockin).
-- A [video summary](https://www.youtube.com/watch?v=JfUTd8fjtX8&feature=emb_imp_woyt) by [What's AI](https://www.youtube.com/channel/UCUzGQrN-lyyc0BWTYoJM_Sg).
-- Take a look at [ak9250's notebook](https://github.com/ak9250/taming-transformers/blob/master/tamingtransformerscolab.ipynb) if you want to run the streamlit demos on Colab.
-
-### Text-to-Image Optimization via CLIP
-VQGAN has been successfully used as an image generator guided by the [CLIP](https://github.com/openai/CLIP) model, both for pure image generation
-from scratch and image-to-image translation. We recommend the following notebooks/videos/resources:
-
-- [Advadnouns](https://twitter.com/advadnoun/status/1389316507134357506) Patreon and corresponding LatentVision notebooks: https://www.patreon.com/patronizeme
-- The [notebook]( https://colab.research.google.com/drive/1L8oL-vLJXVcRzCFbPwOoMkPKJ8-aYdPN) of [Rivers Have Wings](https://twitter.com/RiversHaveWings).
-- A [video](https://www.youtube.com/watch?v=90QDe6DQXF4&t=12s) explanation by [Dot CSV](https://www.youtube.com/channel/UCy5znSnfMsDwaLlROnZ7Qbg) (in Spanish, but English subtitles are available)
-
-![txt2img](assets/birddrawnbyachild.png)
-
-Text prompt: *'A bird drawn by a child'*
-
-## Shout-outs
-Thanks to everyone who makes their code and models available. In particular,
-
-- The architecture of our VQGAN is inspired by [Denoising Diffusion Probabilistic Models](https://github.com/hojonathanho/diffusion)
-- The very hackable transformer implementation [minGPT](https://github.com/karpathy/minGPT)
-- The good ol' [PatchGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) and [Learned Perceptual Similarity (LPIPS)](https://github.com/richzhang/PerceptualSimilarity)
-
-## BibTeX
-
-```
-@misc{esser2020taming,
-  title={Taming Transformers for High-Resolution Image Synthesis},
-  author={Patrick Esser and Robin Rombach and Björn Ommer},
-  year={2020},
-  eprint={2012.09841},
-  archivePrefix={arXiv},
-  primaryClass={cs.CV}
-}
-```
taming-transformers/configs/coco_cond_stage.yaml
DELETED
@@ -1,49 +0,0 @@
-model:
-  base_learning_rate: 4.5e-06
-  target: taming.models.vqgan.VQSegmentationModel
-  params:
-    embed_dim: 256
-    n_embed: 1024
-    image_key: "segmentation"
-    n_labels: 183
-    ddconfig:
-      double_z: false
-      z_channels: 256
-      resolution: 256
-      in_channels: 183
-      out_ch: 183
-      ch: 128
-      ch_mult:
-      - 1
-      - 1
-      - 2
-      - 2
-      - 4
-      num_res_blocks: 2
-      attn_resolutions:
-      - 16
-      dropout: 0.0
-
-    lossconfig:
-      target: taming.modules.losses.segmentation.BCELossWithQuant
-      params:
-        codebook_weight: 1.0
-
-data:
-  target: main.DataModuleFromConfig
-  params:
-    batch_size: 12
-    train:
-      target: taming.data.coco.CocoImagesAndCaptionsTrain
-      params:
-        size: 296
-        crop_size: 256
-        onehot_segmentation: true
-        use_stuffthing: true
-    validation:
-      target: taming.data.coco.CocoImagesAndCaptionsValidation
-      params:
-        size: 256
-        crop_size: 256
-        onehot_segmentation: true
-        use_stuffthing: true
taming-transformers/configs/coco_scene_images_transformer.yaml
DELETED
@@ -1,80 +0,0 @@
-model:
-  base_learning_rate: 4.5e-06
-  target: taming.models.cond_transformer.Net2NetTransformer
-  params:
-    cond_stage_key: objects_bbox
-    transformer_config:
-      target: taming.modules.transformer.mingpt.GPT
-      params:
-        vocab_size: 8192
-        block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim)
-        n_layer: 40
-        n_head: 16
-        n_embd: 1408
-        embd_pdrop: 0.1
-        resid_pdrop: 0.1
-        attn_pdrop: 0.1
-    first_stage_config:
-      target: taming.models.vqgan.VQModel
-      params:
-        ckpt_path: /path/to/coco_epoch117.ckpt # https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/
-        embed_dim: 256
-        n_embed: 8192
-        ddconfig:
-          double_z: false
-          z_channels: 256
-          resolution: 256
-          in_channels: 3
-          out_ch: 3
-          ch: 128
-          ch_mult:
-          - 1
-          - 1
-          - 2
-          - 2
-          - 4
-          num_res_blocks: 2
-          attn_resolutions:
-          - 16
-          dropout: 0.0
-        lossconfig:
-          target: taming.modules.losses.DummyLoss
-    cond_stage_config:
-      target: taming.models.dummy_cond_stage.DummyCondStage
-      params:
-        conditional_key: objects_bbox
-
-data:
-  target: main.DataModuleFromConfig
-  params:
-    batch_size: 6
-    train:
-      target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
-      params:
-        data_path: data/coco_annotations_100 # substitute with path to full dataset
-        split: train
-        keys: [image, objects_bbox, file_name, annotations]
-        no_tokens: 8192
-        target_image_size: 256
-        min_object_area: 0.00001
-        min_objects_per_image: 2
-        max_objects_per_image: 30
-        crop_method: random-1d
-        random_flip: true
-        use_group_parameter: true
-        encode_crop: true
-    validation:
-      target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
-      params:
-        data_path: data/coco_annotations_100 # substitute with path to full dataset
-        split: validation
-        keys: [image, objects_bbox, file_name, annotations]
-        no_tokens: 8192
-        target_image_size: 256
-        min_object_area: 0.00001
-        min_objects_per_image: 2
-        max_objects_per_image: 30
-        crop_method: center
-        random_flip: false
-        use_group_parameter: true
-        encode_crop: true
taming-transformers/configs/custom_vqgan.yaml
DELETED
@@ -1,43 +0,0 @@
-model:
-  base_learning_rate: 4.5e-6
-  target: taming.models.vqgan.VQModel
-  params:
-    embed_dim: 256
-    n_embed: 1024
-    ddconfig:
-      double_z: False
-      z_channels: 256
-      resolution: 256
-      in_channels: 3
-      out_ch: 3
-      ch: 128
-      ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
-      num_res_blocks: 2
-      attn_resolutions: [16]
-      dropout: 0.0
-
-    lossconfig:
-      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
-      params:
-        disc_conditional: False
-        disc_in_channels: 3
-        disc_start: 10000
-        disc_weight: 0.8
-        codebook_weight: 1.0
-
-data:
-  target: main.DataModuleFromConfig
-  params:
-    batch_size: 5
-    num_workers: 8
-    train:
-      target: taming.data.custom.CustomTrain
-      params:
-        training_images_list_file: some/training.txt
-        size: 256
-    validation:
-      target: taming.data.custom.CustomTest
-      params:
-        test_images_list_file: some/test.txt
-        size: 256
-
taming-transformers/configs/drin_transformer.yaml
DELETED
@@ -1,77 +0,0 @@
-model:
-  base_learning_rate: 4.5e-06
-  target: taming.models.cond_transformer.Net2NetTransformer
-  params:
-    cond_stage_key: depth
-    transformer_config:
-      target: taming.modules.transformer.mingpt.GPT
-      params:
-        vocab_size: 1024
-        block_size: 512
-        n_layer: 24
-        n_head: 16
-        n_embd: 1024
-    first_stage_config:
-      target: taming.models.vqgan.VQModel
-      params:
-        ckpt_path: logs/2020-09-23T17-56-33_imagenet_vqgan/checkpoints/last.ckpt
-        embed_dim: 256
-        n_embed: 1024
-        ddconfig:
-          double_z: false
-          z_channels: 256
-          resolution: 256
-          in_channels: 3
-          out_ch: 3
-          ch: 128
-          ch_mult:
-          - 1
-          - 1
-          - 2
-          - 2
-          - 4
-          num_res_blocks: 2
-          attn_resolutions:
-          - 16
-          dropout: 0.0
-        lossconfig:
-          target: taming.modules.losses.DummyLoss
-    cond_stage_config:
-      target: taming.models.vqgan.VQModel
-      params:
-        ckpt_path: logs/2020-11-03T15-34-24_imagenetdepth_vqgan/checkpoints/last.ckpt
-        embed_dim: 256
-        n_embed: 1024
-        ddconfig:
-          double_z: false
-          z_channels: 256
-          resolution: 256
-          in_channels: 1
-          out_ch: 1
-          ch: 128
-          ch_mult:
-          - 1
-          - 1
-          - 2
-          - 2
-          - 4
-          num_res_blocks: 2
-          attn_resolutions:
-          - 16
-          dropout: 0.0
-        lossconfig:
-          target: taming.modules.losses.DummyLoss
-
-data:
-  target: main.DataModuleFromConfig
-  params:
-    batch_size: 2
-    num_workers: 8
-    train:
-      target: taming.data.imagenet.RINTrainWithDepth
-      params:
-        size: 256
-    validation:
-      target: taming.data.imagenet.RINValidationWithDepth
-      params:
-        size: 256
taming-transformers/configs/faceshq_transformer.yaml
DELETED
@@ -1,61 +0,0 @@
-model:
-  base_learning_rate: 4.5e-06
-  target: taming.models.cond_transformer.Net2NetTransformer
-  params:
-    cond_stage_key: coord
-    transformer_config:
-      target: taming.modules.transformer.mingpt.GPT
-      params:
-        vocab_size: 1024
-        block_size: 512
-        n_layer: 24
-        n_head: 16
-        n_embd: 1024
-    first_stage_config:
-      target: taming.models.vqgan.VQModel
-      params:
-        ckpt_path: logs/2020-11-09T13-33-36_faceshq_vqgan/checkpoints/last.ckpt
-        embed_dim: 256
-        n_embed: 1024
-        ddconfig:
-          double_z: false
-          z_channels: 256
-          resolution: 256
-          in_channels: 3
-          out_ch: 3
-          ch: 128
-          ch_mult:
-          - 1
-          - 1
-          - 2
-          - 2
-          - 4
-          num_res_blocks: 2
-          attn_resolutions:
-          - 16
-          dropout: 0.0
-        lossconfig:
-          target: taming.modules.losses.DummyLoss
-    cond_stage_config:
-      target: taming.modules.misc.coord.CoordStage
-      params:
-        n_embed: 1024
-        down_factor: 16
-
-data:
-  target: main.DataModuleFromConfig
-  params:
-    batch_size: 2
-    num_workers: 8
-    train:
-      target: taming.data.faceshq.FacesHQTrain
-      params:
-        size: 256
-        crop_size: 256
-        coord: True
-    validation:
-      target: taming.data.faceshq.FacesHQValidation
-      params:
-        size: 256
-        crop_size: 256
-        coord: True
taming-transformers/configs/faceshq_vqgan.yaml
DELETED
@@ -1,42 +0,0 @@
-model:
-  base_learning_rate: 4.5e-6
-  target: taming.models.vqgan.VQModel
-  params:
-    embed_dim: 256
-    n_embed: 1024
-    ddconfig:
-      double_z: False
-      z_channels: 256
-      resolution: 256
-      in_channels: 3
-      out_ch: 3
-      ch: 128
-      ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
-      num_res_blocks: 2
-      attn_resolutions: [16]
-      dropout: 0.0
-
-    lossconfig:
-      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
-      params:
-        disc_conditional: False
-        disc_in_channels: 3
-        disc_start: 30001
-        disc_weight: 0.8
-        codebook_weight: 1.0
-
-data:
-  target: main.DataModuleFromConfig
-  params:
-    batch_size: 3
-    num_workers: 8
-    train:
-      target: taming.data.faceshq.FacesHQTrain
-      params:
-        size: 256
-        crop_size: 256
-    validation:
-      target: taming.data.faceshq.FacesHQValidation
-      params:
-        size: 256
-        crop_size: 256
taming-transformers/configs/imagenet_vqgan.yaml
DELETED
@@ -1,42 +0,0 @@
-model:
-  base_learning_rate: 4.5e-6
-  target: taming.models.vqgan.VQModel
-  params:
-    embed_dim: 256
-    n_embed: 1024
-    ddconfig:
-      double_z: False
-      z_channels: 256
-      resolution: 256
-      in_channels: 3
-      out_ch: 3
-      ch: 128
-      ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
-      num_res_blocks: 2
-      attn_resolutions: [16]
-      dropout: 0.0
-
-    lossconfig:
-      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
-      params:
-        disc_conditional: False
-        disc_in_channels: 3
-        disc_start: 250001
-        disc_weight: 0.8
-        codebook_weight: 1.0
-
-data:
-  target: main.DataModuleFromConfig
-  params:
-    batch_size: 12
-    num_workers: 24
-    train:
-      target: taming.data.imagenet.ImageNetTrain
-      params:
-        config:
-          size: 256
-    validation:
-      target: taming.data.imagenet.ImageNetValidation
-      params:
-        config:
-          size: 256
taming-transformers/configs/imagenetdepth_vqgan.yaml
DELETED
@@ -1,41 +0,0 @@
-model:
-  base_learning_rate: 4.5e-6
-  target: taming.models.vqgan.VQModel
-  params:
-    embed_dim: 256
-    n_embed: 1024
-    image_key: depth
-    ddconfig:
-      double_z: False
-      z_channels: 256
-      resolution: 256
-      in_channels: 1
-      out_ch: 1
-      ch: 128
-      ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
-      num_res_blocks: 2
-      attn_resolutions: [16]
-      dropout: 0.0
-
-    lossconfig:
-      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
-      params:
-        disc_conditional: False
-        disc_in_channels: 1
-        disc_start: 50001
-        disc_weight: 0.75
-        codebook_weight: 1.0
-
-data:
-  target: main.DataModuleFromConfig
-  params:
-    batch_size: 3
-    num_workers: 8
-    train:
-      target: taming.data.imagenet.ImageNetTrainWithDepth
-      params:
-        size: 256
-    validation:
-      target: taming.data.imagenet.ImageNetValidationWithDepth
-      params:
-        size: 256
taming-transformers/configs/open_images_scene_images_transformer.yaml
DELETED
@@ -1,86 +0,0 @@
-model:
-  base_learning_rate: 4.5e-06
-  target: taming.models.cond_transformer.Net2NetTransformer
-  params:
-    cond_stage_key: objects_bbox
-    transformer_config:
-      target: taming.modules.transformer.mingpt.GPT
-      params:
-        vocab_size: 8192
-        block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim)
-        n_layer: 36
-        n_head: 16
-        n_embd: 1536
-        embd_pdrop: 0.1
-        resid_pdrop: 0.1
-        attn_pdrop: 0.1
-    first_stage_config:
-      target: taming.models.vqgan.VQModel
-      params:
-        ckpt_path: /path/to/coco_oi_epoch12.ckpt # https://heibox.uni-heidelberg.de/f/461d9a9f4fcf48ab84f4/
-        embed_dim: 256
-        n_embed: 8192
-        ddconfig:
-          double_z: false
-          z_channels: 256
-          resolution: 256
-          in_channels: 3
-          out_ch: 3
-          ch: 128
-          ch_mult:
-          - 1
-          - 1
-          - 2
-          - 2
-          - 4
-          num_res_blocks: 2
-          attn_resolutions:
-          - 16
-          dropout: 0.0
-        lossconfig:
-          target: taming.modules.losses.DummyLoss
-    cond_stage_config:
-      target: taming.models.dummy_cond_stage.DummyCondStage
-      params:
-        conditional_key: objects_bbox
-
-data:
-  target: main.DataModuleFromConfig
-  params:
-    batch_size: 6
-    train:
-      target: taming.data.annotated_objects_open_images.AnnotatedObjectsOpenImages
-      params:
-        data_path: data/open_images_annotations_100 # substitute with path to full dataset
-        split: train
-        keys: [image, objects_bbox, file_name, annotations]
-        no_tokens: 8192
-        target_image_size: 256
-        category_allow_list_target: taming.data.open_images_helper.top_300_classes_plus_coco_compatibility
-        category_mapping_target: taming.data.open_images_helper.open_images_unify_categories_for_coco
-        min_object_area: 0.0001
-        min_objects_per_image: 2
-        max_objects_per_image: 30
-        crop_method: random-2d
-        random_flip: true
-        use_group_parameter: true
-        use_additional_parameters: true
-        encode_crop: true
-    validation:
-      target: taming.data.annotated_objects_open_images.AnnotatedObjectsOpenImages
-      params:
-        data_path: data/open_images_annotations_100 # substitute with path to full dataset
-        split: validation
-        keys: [image, objects_bbox, file_name, annotations]
-        no_tokens: 8192
-        target_image_size: 256
-        category_allow_list_target: taming.data.open_images_helper.top_300_classes_plus_coco_compatibility
-        category_mapping_target: taming.data.open_images_helper.open_images_unify_categories_for_coco
-        min_object_area: 0.0001
-        min_objects_per_image: 2
-        max_objects_per_image: 30
-        crop_method: center
-        random_flip: false
-        use_group_parameter: true
-        use_additional_parameters: true
-        encode_crop: true
taming-transformers/configs/sflckr_cond_stage.yaml
DELETED
@@ -1,43 +0,0 @@
-model:
-  base_learning_rate: 4.5e-06
-  target: taming.models.vqgan.VQSegmentationModel
-  params:
-    embed_dim: 256
-    n_embed: 1024
-    image_key: "segmentation"
-    n_labels: 182
-    ddconfig:
-      double_z: false
-      z_channels: 256
-      resolution: 256
-      in_channels: 182
-      out_ch: 182
-      ch: 128
-      ch_mult:
-      - 1
-      - 1
-      - 2
-      - 2
-      - 4
-      num_res_blocks: 2
-      attn_resolutions:
-      - 16
-      dropout: 0.0
-
-    lossconfig:
-      target: taming.modules.losses.segmentation.BCELossWithQuant
-      params:
-        codebook_weight: 1.0
-
-data:
-  target: cutlit.DataModuleFromConfig
-  params:
-    batch_size: 12
-    train:
-      target: taming.data.sflckr.Examples # adjust
-      params:
-        size: 256
-    validation:
-      target: taming.data.sflckr.Examples # adjust
-      params:
-        size: 256
taming-transformers/environment.yaml
DELETED
@@ -1,25 +0,0 @@
-name: taming
-channels:
-  - pytorch
-  - defaults
-dependencies:
-  - python=3.8.5
-  - pip=20.3
-  - cudatoolkit=10.2
-  - pytorch=1.7.0
-  - torchvision=0.8.1
-  - numpy=1.19.2
-  - pip:
-    - albumentations==0.4.3
-    - opencv-python==4.1.2.30
-    - pudb==2019.2
-    - imageio==2.9.0
-    - imageio-ffmpeg==0.4.2
-    - pytorch-lightning==1.0.8
-    - omegaconf==2.0.0
-    - test-tube>=0.7.5
-    - streamlit>=0.73.1
-    - einops==0.3.0
-    - more-itertools>=8.0.0
-    - transformers==4.3.1
-    - -e .
taming-transformers/main.py
DELETED
@@ -1,585 +0,0 @@
-import argparse, os, sys, datetime, glob, importlib
-from omegaconf import OmegaConf
-import numpy as np
-from PIL import Image
-import torch
-import torchvision
-from torch.utils.data import random_split, DataLoader, Dataset
-import pytorch_lightning as pl
-from pytorch_lightning import seed_everything
-from pytorch_lightning.trainer import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
-from pytorch_lightning.utilities import rank_zero_only
-
-from taming.data.utils import custom_collate
-
-
-def get_obj_from_str(string, reload=False):
-    module, cls = string.rsplit(".", 1)
-    if reload:
-        module_imp = importlib.import_module(module)
-        importlib.reload(module_imp)
-    return getattr(importlib.import_module(module, package=None), cls)
-
-
-def get_parser(**parser_kwargs):
-    def str2bool(v):
-        if isinstance(v, bool):
-            return v
-        if v.lower() in ("yes", "true", "t", "y", "1"):
-            return True
-        elif v.lower() in ("no", "false", "f", "n", "0"):
-            return False
-        else:
-            raise argparse.ArgumentTypeError("Boolean value expected.")
-
-    parser = argparse.ArgumentParser(**parser_kwargs)
-    parser.add_argument(
-        "-n",
-        "--name",
-        type=str,
-        const=True,
-        default="",
-        nargs="?",
-        help="postfix for logdir",
-    )
-    parser.add_argument(
-        "-r",
-        "--resume",
-        type=str,
-        const=True,
-        default="",
-        nargs="?",
-        help="resume from logdir or checkpoint in logdir",
-    )
-    parser.add_argument(
-        "-b",
-        "--base",
-        nargs="*",
-        metavar="base_config.yaml",
-        help="paths to base configs. Loaded from left-to-right. "
-        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
-        default=list(),
-    )
-    parser.add_argument(
-        "-t",
-        "--train",
-        type=str2bool,
-        const=True,
-        default=False,
-        nargs="?",
-        help="train",
-    )
-    parser.add_argument(
-        "--no-test",
-        type=str2bool,
-        const=True,
-        default=False,
-        nargs="?",
-        help="disable test",
-    )
-    parser.add_argument("-p", "--project", help="name of new or path to existing project")
-    parser.add_argument(
-        "-d",
-        "--debug",
-        type=str2bool,
-        nargs="?",
-        const=True,
-        default=False,
-        help="enable post-mortem debugging",
-    )
-    parser.add_argument(
-        "-s",
-        "--seed",
-        type=int,
-        default=23,
-        help="seed for seed_everything",
-    )
-    parser.add_argument(
-        "-f",
-        "--postfix",
-        type=str,
-        default="",
-        help="post-postfix for default name",
-    )
-
-    return parser
-
-
-def nondefault_trainer_args(opt):
-    parser = argparse.ArgumentParser()
-    parser = Trainer.add_argparse_args(parser)
-    args = parser.parse_args([])
-    return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
-
-
-def instantiate_from_config(config):
-    if not "target" in config:
-        raise KeyError("Expected key `target` to instantiate.")
-    return get_obj_from_str(config["target"])(**config.get("params", dict()))
-
-
-class WrappedDataset(Dataset):
-    """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
-    def __init__(self, dataset):
-        self.data = dataset
-
-    def __len__(self):
-        return len(self.data)
-
-    def __getitem__(self, idx):
-        return self.data[idx]
-
-
-class DataModuleFromConfig(pl.LightningDataModule):
-    def __init__(self, batch_size, train=None, validation=None, test=None,
-                 wrap=False, num_workers=None):
-        super().__init__()
-        self.batch_size = batch_size
-        self.dataset_configs = dict()
-        self.num_workers = num_workers if num_workers is not None else batch_size*2
-        if train is not None:
-            self.dataset_configs["train"] = train
-            self.train_dataloader = self._train_dataloader
-        if validation is not None:
-            self.dataset_configs["validation"] = validation
-            self.val_dataloader = self._val_dataloader
-        if test is not None:
-            self.dataset_configs["test"] = test
-            self.test_dataloader = self._test_dataloader
-        self.wrap = wrap
-
-    def prepare_data(self):
-        for data_cfg in self.dataset_configs.values():
-            instantiate_from_config(data_cfg)
-
-    def setup(self, stage=None):
-        self.datasets = dict(
-            (k, instantiate_from_config(self.dataset_configs[k]))
-            for k in self.dataset_configs)
-        if self.wrap:
-            for k in self.datasets:
-                self.datasets[k] = WrappedDataset(self.datasets[k])
-
-    def _train_dataloader(self):
-        return DataLoader(self.datasets["train"], batch_size=self.batch_size,
-                          num_workers=self.num_workers, shuffle=True, collate_fn=custom_collate)
-
-    def _val_dataloader(self):
-        return DataLoader(self.datasets["validation"],
-                          batch_size=self.batch_size,
-                          num_workers=self.num_workers, collate_fn=custom_collate)
-
-    def _test_dataloader(self):
-        return DataLoader(self.datasets["test"], batch_size=self.batch_size,
-                          num_workers=self.num_workers, collate_fn=custom_collate)
-
-
-class SetupCallback(Callback):
-    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
-        super().__init__()
-        self.resume = resume
-        self.now = now
-        self.logdir = logdir
-        self.ckptdir = ckptdir
-        self.cfgdir = cfgdir
-        self.config = config
-        self.lightning_config = lightning_config
-
-    def on_pretrain_routine_start(self, trainer, pl_module):
-        if trainer.global_rank == 0:
-            # Create logdirs and save configs
-            os.makedirs(self.logdir, exist_ok=True)
-            os.makedirs(self.ckptdir, exist_ok=True)
-            os.makedirs(self.cfgdir, exist_ok=True)
-
-            print("Project config")
-            print(self.config.pretty())
-            OmegaConf.save(self.config,
-                           os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
-
|
201 |
-
print("Lightning config")
|
202 |
-
print(self.lightning_config.pretty())
|
203 |
-
OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
|
204 |
-
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
|
205 |
-
|
206 |
-
else:
|
207 |
-
# ModelCheckpoint callback created log directory --- remove it
|
208 |
-
if not self.resume and os.path.exists(self.logdir):
|
209 |
-
dst, name = os.path.split(self.logdir)
|
210 |
-
dst = os.path.join(dst, "child_runs", name)
|
211 |
-
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
212 |
-
try:
|
213 |
-
os.rename(self.logdir, dst)
|
214 |
-
except FileNotFoundError:
|
215 |
-
pass
|
216 |
-
|
217 |
-
|
218 |
-
class ImageLogger(Callback):
|
219 |
-
def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True):
|
220 |
-
super().__init__()
|
221 |
-
self.batch_freq = batch_frequency
|
222 |
-
self.max_images = max_images
|
223 |
-
self.logger_log_images = {
|
224 |
-
pl.loggers.WandbLogger: self._wandb,
|
225 |
-
pl.loggers.TestTubeLogger: self._testtube,
|
226 |
-
}
|
227 |
-
self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
|
228 |
-
if not increase_log_steps:
|
229 |
-
self.log_steps = [self.batch_freq]
|
230 |
-
self.clamp = clamp
|
231 |
-
|
232 |
-
@rank_zero_only
|
233 |
-
def _wandb(self, pl_module, images, batch_idx, split):
|
234 |
-
raise ValueError("No way wandb")
|
235 |
-
grids = dict()
|
236 |
-
for k in images:
|
237 |
-
grid = torchvision.utils.make_grid(images[k])
|
238 |
-
grids[f"{split}/{k}"] = wandb.Image(grid)
|
239 |
-
pl_module.logger.experiment.log(grids)
|
240 |
-
|
241 |
-
@rank_zero_only
|
242 |
-
def _testtube(self, pl_module, images, batch_idx, split):
|
243 |
-
for k in images:
|
244 |
-
grid = torchvision.utils.make_grid(images[k])
|
245 |
-
grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
|
246 |
-
|
247 |
-
tag = f"{split}/{k}"
|
248 |
-
pl_module.logger.experiment.add_image(
|
249 |
-
tag, grid,
|
250 |
-
global_step=pl_module.global_step)
|
251 |
-
|
252 |
-
@rank_zero_only
|
253 |
-
def log_local(self, save_dir, split, images,
|
254 |
-
global_step, current_epoch, batch_idx):
|
255 |
-
root = os.path.join(save_dir, "images", split)
|
256 |
-
for k in images:
|
257 |
-
grid = torchvision.utils.make_grid(images[k], nrow=4)
|
258 |
-
|
259 |
-
grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
|
260 |
-
grid = grid.transpose(0,1).transpose(1,2).squeeze(-1)
|
261 |
-
grid = grid.numpy()
|
262 |
-
grid = (grid*255).astype(np.uint8)
|
263 |
-
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
|
264 |
-
k,
|
265 |
-
global_step,
|
266 |
-
current_epoch,
|
267 |
-
batch_idx)
|
268 |
-
path = os.path.join(root, filename)
|
269 |
-
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
270 |
-
Image.fromarray(grid).save(path)
|
271 |
-
|
272 |
-
def log_img(self, pl_module, batch, batch_idx, split="train"):
|
273 |
-
if (self.check_frequency(batch_idx) and # batch_idx % self.batch_freq == 0
|
274 |
-
hasattr(pl_module, "log_images") and
|
275 |
-
callable(pl_module.log_images) and
|
276 |
-
self.max_images > 0):
|
277 |
-
logger = type(pl_module.logger)
|
278 |
-
|
279 |
-
is_train = pl_module.training
|
280 |
-
if is_train:
|
281 |
-
pl_module.eval()
|
282 |
-
|
283 |
-
with torch.no_grad():
|
284 |
-
images = pl_module.log_images(batch, split=split, pl_module=pl_module)
|
285 |
-
|
286 |
-
for k in images:
|
287 |
-
N = min(images[k].shape[0], self.max_images)
|
288 |
-
images[k] = images[k][:N]
|
289 |
-
if isinstance(images[k], torch.Tensor):
|
290 |
-
images[k] = images[k].detach().cpu()
|
291 |
-
if self.clamp:
|
292 |
-
images[k] = torch.clamp(images[k], -1., 1.)
|
293 |
-
|
294 |
-
self.log_local(pl_module.logger.save_dir, split, images,
|
295 |
-
pl_module.global_step, pl_module.current_epoch, batch_idx)
|
296 |
-
|
297 |
-
logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
|
298 |
-
logger_log_images(pl_module, images, pl_module.global_step, split)
|
299 |
-
|
300 |
-
if is_train:
|
301 |
-
pl_module.train()
|
302 |
-
|
303 |
-
def check_frequency(self, batch_idx):
|
304 |
-
if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
|
305 |
-
try:
|
306 |
-
self.log_steps.pop(0)
|
307 |
-
except IndexError:
|
308 |
-
pass
|
309 |
-
return True
|
310 |
-
return False
|
311 |
-
|
312 |
-
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
313 |
-
self.log_img(pl_module, batch, batch_idx, split="train")
|
314 |
-
|
315 |
-
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
316 |
-
self.log_img(pl_module, batch, batch_idx, split="val")
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
if __name__ == "__main__":
|
321 |
-
# custom parser to specify config files, train, test and debug mode,
|
322 |
-
# postfix, resume.
|
323 |
-
# `--key value` arguments are interpreted as arguments to the trainer.
|
324 |
-
# `nested.key=value` arguments are interpreted as config parameters.
|
325 |
-
# configs are merged from left-to-right followed by command line parameters.
|
326 |
-
|
327 |
-
# model:
|
328 |
-
# base_learning_rate: float
|
329 |
-
# target: path to lightning module
|
330 |
-
# params:
|
331 |
-
# key: value
|
332 |
-
# data:
|
333 |
-
# target: main.DataModuleFromConfig
|
334 |
-
# params:
|
335 |
-
# batch_size: int
|
336 |
-
# wrap: bool
|
337 |
-
# train:
|
338 |
-
# target: path to train dataset
|
339 |
-
# params:
|
340 |
-
# key: value
|
341 |
-
# validation:
|
342 |
-
# target: path to validation dataset
|
343 |
-
# params:
|
344 |
-
# key: value
|
345 |
-
# test:
|
346 |
-
# target: path to test dataset
|
347 |
-
# params:
|
348 |
-
# key: value
|
349 |
-
# lightning: (optional, has sane defaults and can be specified on cmdline)
|
350 |
-
# trainer:
|
351 |
-
# additional arguments to trainer
|
352 |
-
# logger:
|
353 |
-
# logger to instantiate
|
354 |
-
# modelcheckpoint:
|
355 |
-
# modelcheckpoint to instantiate
|
356 |
-
# callbacks:
|
357 |
-
# callback1:
|
358 |
-
# target: importpath
|
359 |
-
# params:
|
360 |
-
# key: value
|
361 |
-
|
362 |
-
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
363 |
-
|
364 |
-
# add cwd for convenience and to make classes in this file available when
|
365 |
-
# running as `python main.py`
|
366 |
-
# (in particular `main.DataModuleFromConfig`)
|
367 |
-
sys.path.append(os.getcwd())
|
368 |
-
|
369 |
-
parser = get_parser()
|
370 |
-
parser = Trainer.add_argparse_args(parser)
|
371 |
-
|
372 |
-
opt, unknown = parser.parse_known_args()
|
373 |
-
if opt.name and opt.resume:
|
374 |
-
raise ValueError(
|
375 |
-
"-n/--name and -r/--resume cannot be specified both."
|
376 |
-
"If you want to resume training in a new log folder, "
|
377 |
-
"use -n/--name in combination with --resume_from_checkpoint"
|
378 |
-
)
|
379 |
-
if opt.resume:
|
380 |
-
if not os.path.exists(opt.resume):
|
381 |
-
raise ValueError("Cannot find {}".format(opt.resume))
|
382 |
-
if os.path.isfile(opt.resume):
|
383 |
-
paths = opt.resume.split("/")
|
384 |
-
idx = len(paths)-paths[::-1].index("logs")+1
|
385 |
-
logdir = "/".join(paths[:idx])
|
386 |
-
ckpt = opt.resume
|
387 |
-
else:
|
388 |
-
assert os.path.isdir(opt.resume), opt.resume
|
389 |
-
logdir = opt.resume.rstrip("/")
|
390 |
-
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
391 |
-
|
392 |
-
opt.resume_from_checkpoint = ckpt
|
393 |
-
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
|
394 |
-
opt.base = base_configs+opt.base
|
395 |
-
_tmp = logdir.split("/")
|
396 |
-
nowname = _tmp[_tmp.index("logs")+1]
|
397 |
-
else:
|
398 |
-
if opt.name:
|
399 |
-
name = "_"+opt.name
|
400 |
-
elif opt.base:
|
401 |
-
cfg_fname = os.path.split(opt.base[0])[-1]
|
402 |
-
cfg_name = os.path.splitext(cfg_fname)[0]
|
403 |
-
name = "_"+cfg_name
|
404 |
-
else:
|
405 |
-
name = ""
|
406 |
-
nowname = now+name+opt.postfix
|
407 |
-
logdir = os.path.join("logs", nowname)
|
408 |
-
|
409 |
-
ckptdir = os.path.join(logdir, "checkpoints")
|
410 |
-
cfgdir = os.path.join(logdir, "configs")
|
411 |
-
seed_everything(opt.seed)
|
412 |
-
|
413 |
-
try:
|
414 |
-
# init and save configs
|
415 |
-
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
416 |
-
cli = OmegaConf.from_dotlist(unknown)
|
417 |
-
config = OmegaConf.merge(*configs, cli)
|
418 |
-
lightning_config = config.pop("lightning", OmegaConf.create())
|
419 |
-
# merge trainer cli with config
|
420 |
-
trainer_config = lightning_config.get("trainer", OmegaConf.create())
|
421 |
-
# default to ddp
|
422 |
-
trainer_config["distributed_backend"] = "ddp"
|
423 |
-
for k in nondefault_trainer_args(opt):
|
424 |
-
trainer_config[k] = getattr(opt, k)
|
425 |
-
if not "gpus" in trainer_config:
|
426 |
-
del trainer_config["distributed_backend"]
|
427 |
-
cpu = True
|
428 |
-
else:
|
429 |
-
gpuinfo = trainer_config["gpus"]
|
430 |
-
print(f"Running on GPUs {gpuinfo}")
|
431 |
-
cpu = False
|
432 |
-
trainer_opt = argparse.Namespace(**trainer_config)
|
433 |
-
lightning_config.trainer = trainer_config
|
434 |
-
|
435 |
-
# model
|
436 |
-
model = instantiate_from_config(config.model)
|
437 |
-
|
438 |
-
# trainer and callbacks
|
439 |
-
trainer_kwargs = dict()
|
440 |
-
|
441 |
-
# default logger configs
|
442 |
-
# NOTE wandb < 0.10.0 interferes with shutdown
|
443 |
-
# wandb >= 0.10.0 seems to fix it but still interferes with pudb
|
444 |
-
# debugging (wrongly sized pudb ui)
|
445 |
-
# thus prefer testtube for now
|
446 |
-
default_logger_cfgs = {
|
447 |
-
"wandb": {
|
448 |
-
"target": "pytorch_lightning.loggers.WandbLogger",
|
449 |
-
"params": {
|
450 |
-
"name": nowname,
|
451 |
-
"save_dir": logdir,
|
452 |
-
"offline": opt.debug,
|
453 |
-
"id": nowname,
|
454 |
-
}
|
455 |
-
},
|
456 |
-
"testtube": {
|
457 |
-
"target": "pytorch_lightning.loggers.TestTubeLogger",
|
458 |
-
"params": {
|
459 |
-
"name": "testtube",
|
460 |
-
"save_dir": logdir,
|
461 |
-
}
|
462 |
-
},
|
463 |
-
}
|
464 |
-
default_logger_cfg = default_logger_cfgs["testtube"]
|
465 |
-
logger_cfg = lightning_config.logger or OmegaConf.create()
|
466 |
-
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
|
467 |
-
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
|
468 |
-
|
469 |
-
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
|
470 |
-
# specify which metric is used to determine best models
|
471 |
-
default_modelckpt_cfg = {
|
472 |
-
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
|
473 |
-
"params": {
|
474 |
-
"dirpath": ckptdir,
|
475 |
-
"filename": "{epoch:06}",
|
476 |
-
"verbose": True,
|
477 |
-
"save_last": True,
|
478 |
-
}
|
479 |
-
}
|
480 |
-
if hasattr(model, "monitor"):
|
481 |
-
print(f"Monitoring {model.monitor} as checkpoint metric.")
|
482 |
-
default_modelckpt_cfg["params"]["monitor"] = model.monitor
|
483 |
-
default_modelckpt_cfg["params"]["save_top_k"] = 3
|
484 |
-
|
485 |
-
modelckpt_cfg = lightning_config.modelcheckpoint or OmegaConf.create()
|
486 |
-
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
|
487 |
-
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
|
488 |
-
|
489 |
-
# add callback which sets up log directory
|
490 |
-
default_callbacks_cfg = {
|
491 |
-
"setup_callback": {
|
492 |
-
"target": "main.SetupCallback",
|
493 |
-
"params": {
|
494 |
-
"resume": opt.resume,
|
495 |
-
"now": now,
|
496 |
-
"logdir": logdir,
|
497 |
-
"ckptdir": ckptdir,
|
498 |
-
"cfgdir": cfgdir,
|
499 |
-
"config": config,
|
500 |
-
"lightning_config": lightning_config,
|
501 |
-
}
|
502 |
-
},
|
503 |
-
"image_logger": {
|
504 |
-
"target": "main.ImageLogger",
|
505 |
-
"params": {
|
506 |
-
"batch_frequency": 750,
|
507 |
-
"max_images": 4,
|
508 |
-
"clamp": True
|
509 |
-
}
|
510 |
-
},
|
511 |
-
"learning_rate_logger": {
|
512 |
-
"target": "main.LearningRateMonitor",
|
513 |
-
"params": {
|
514 |
-
"logging_interval": "step",
|
515 |
-
#"log_momentum": True
|
516 |
-
}
|
517 |
-
},
|
518 |
-
}
|
519 |
-
callbacks_cfg = lightning_config.callbacks or OmegaConf.create()
|
520 |
-
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
|
521 |
-
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
|
522 |
-
|
523 |
-
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
|
524 |
-
|
525 |
-
# data
|
526 |
-
data = instantiate_from_config(config.data)
|
527 |
-
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
|
528 |
-
# calling these ourselves should not be necessary but it is.
|
529 |
-
# lightning still takes care of proper multiprocessing though
|
530 |
-
data.prepare_data()
|
531 |
-
data.setup()
|
532 |
-
|
533 |
-
# configure learning rate
|
534 |
-
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
|
535 |
-
if not cpu:
|
536 |
-
ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
|
537 |
-
else:
|
538 |
-
ngpu = 1
|
539 |
-
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1
|
540 |
-
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
|
541 |
-
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
|
542 |
-
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
|
543 |
-
print("Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
|
544 |
-
model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
|
545 |
-
|
546 |
-
# allow checkpointing via USR1
|
547 |
-
def melk(*args, **kwargs):
|
548 |
-
# run all checkpoint hooks
|
549 |
-
if trainer.global_rank == 0:
|
550 |
-
print("Summoning checkpoint.")
|
551 |
-
ckpt_path = os.path.join(ckptdir, "last.ckpt")
|
552 |
-
trainer.save_checkpoint(ckpt_path)
|
553 |
-
|
554 |
-
def divein(*args, **kwargs):
|
555 |
-
if trainer.global_rank == 0:
|
556 |
-
import pudb; pudb.set_trace()
|
557 |
-
|
558 |
-
import signal
|
559 |
-
signal.signal(signal.SIGUSR1, melk)
|
560 |
-
signal.signal(signal.SIGUSR2, divein)
|
561 |
-
|
562 |
-
# run
|
563 |
-
if opt.train:
|
564 |
-
try:
|
565 |
-
trainer.fit(model, data)
|
566 |
-
except Exception:
|
567 |
-
melk()
|
568 |
-
raise
|
569 |
-
if not opt.no_test and not trainer.interrupted:
|
570 |
-
trainer.test(model, data)
|
571 |
-
except Exception:
|
572 |
-
if opt.debug and trainer.global_rank==0:
|
573 |
-
try:
|
574 |
-
import pudb as debugger
|
575 |
-
except ImportError:
|
576 |
-
import pdb as debugger
|
577 |
-
debugger.post_mortem()
|
578 |
-
raise
|
579 |
-
finally:
|
580 |
-
# move newly created debug project to debug_runs
|
581 |
-
if opt.debug and not opt.resume and trainer.global_rank==0:
|
582 |
-
dst, name = os.path.split(logdir)
|
583 |
-
dst = os.path.join(dst, "debug_runs", name)
|
584 |
-
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
585 |
-
os.rename(logdir, dst)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
taming-transformers/scripts/extract_depth.py
DELETED
@@ -1,112 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import torch
|
3 |
-
import numpy as np
|
4 |
-
from tqdm import trange
|
5 |
-
from PIL import Image
|
6 |
-
|
7 |
-
|
8 |
-
def get_state(gpu):
|
9 |
-
import torch
|
10 |
-
midas = torch.hub.load("intel-isl/MiDaS", "MiDaS")
|
11 |
-
if gpu:
|
12 |
-
midas.cuda()
|
13 |
-
midas.eval()
|
14 |
-
|
15 |
-
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
|
16 |
-
transform = midas_transforms.default_transform
|
17 |
-
|
18 |
-
state = {"model": midas,
|
19 |
-
"transform": transform}
|
20 |
-
return state
|
21 |
-
|
22 |
-
|
23 |
-
def depth_to_rgba(x):
|
24 |
-
assert x.dtype == np.float32
|
25 |
-
assert len(x.shape) == 2
|
26 |
-
y = x.copy()
|
27 |
-
y.dtype = np.uint8
|
28 |
-
y = y.reshape(x.shape+(4,))
|
29 |
-
return np.ascontiguousarray(y)
|
30 |
-
|
31 |
-
|
32 |
-
def rgba_to_depth(x):
|
33 |
-
assert x.dtype == np.uint8
|
34 |
-
assert len(x.shape) == 3 and x.shape[2] == 4
|
35 |
-
y = x.copy()
|
36 |
-
y.dtype = np.float32
|
37 |
-
y = y.reshape(x.shape[:2])
|
38 |
-
return np.ascontiguousarray(y)
|
39 |
-
|
40 |
-
|
41 |
-
def run(x, state):
|
42 |
-
model = state["model"]
|
43 |
-
transform = state["transform"]
|
44 |
-
hw = x.shape[:2]
|
45 |
-
with torch.no_grad():
|
46 |
-
prediction = model(transform((x + 1.0) * 127.5).cuda())
|
47 |
-
prediction = torch.nn.functional.interpolate(
|
48 |
-
prediction.unsqueeze(1),
|
49 |
-
size=hw,
|
50 |
-
mode="bicubic",
|
51 |
-
align_corners=False,
|
52 |
-
).squeeze()
|
53 |
-
output = prediction.cpu().numpy()
|
54 |
-
return output
|
55 |
-
|
56 |
-
|
57 |
-
def get_filename(relpath, level=-2):
|
58 |
-
# save class folder structure and filename:
|
59 |
-
fn = relpath.split(os.sep)[level:]
|
60 |
-
folder = fn[-2]
|
61 |
-
file = fn[-1].split('.')[0]
|
62 |
-
return folder, file
|
63 |
-
|
64 |
-
|
65 |
-
def save_depth(dataset, path, debug=False):
|
66 |
-
os.makedirs(path)
|
67 |
-
N = len(dset)
|
68 |
-
if debug:
|
69 |
-
N = 10
|
70 |
-
state = get_state(gpu=True)
|
71 |
-
for idx in trange(N, desc="Data"):
|
72 |
-
ex = dataset[idx]
|
73 |
-
image, relpath = ex["image"], ex["relpath"]
|
74 |
-
folder, filename = get_filename(relpath)
|
75 |
-
# prepare
|
76 |
-
folderabspath = os.path.join(path, folder)
|
77 |
-
os.makedirs(folderabspath, exist_ok=True)
|
78 |
-
savepath = os.path.join(folderabspath, filename)
|
79 |
-
# run model
|
80 |
-
xout = run(image, state)
|
81 |
-
I = depth_to_rgba(xout)
|
82 |
-
Image.fromarray(I).save("{}.png".format(savepath))
|
83 |
-
|
84 |
-
|
85 |
-
if __name__ == "__main__":
|
86 |
-
from taming.data.imagenet import ImageNetTrain, ImageNetValidation
|
87 |
-
out = "data/imagenet_depth"
|
88 |
-
if not os.path.exists(out):
|
89 |
-
print("Please create a folder or symlink '{}' to extract depth data ".format(out) +
|
90 |
-
"(be prepared that the output size will be larger than ImageNet itself).")
|
91 |
-
exit(1)
|
92 |
-
|
93 |
-
# go
|
94 |
-
dset = ImageNetValidation()
|
95 |
-
abspath = os.path.join(out, "val")
|
96 |
-
if os.path.exists(abspath):
|
97 |
-
print("{} exists - not doing anything.".format(abspath))
|
98 |
-
else:
|
99 |
-
print("preparing {}".format(abspath))
|
100 |
-
save_depth(dset, abspath)
|
101 |
-
print("done with validation split")
|
102 |
-
|
103 |
-
dset = ImageNetTrain()
|
104 |
-
abspath = os.path.join(out, "train")
|
105 |
-
if os.path.exists(abspath):
|
106 |
-
print("{} exists - not doing anything.".format(abspath))
|
107 |
-
else:
|
108 |
-
print("preparing {}".format(abspath))
|
109 |
-
save_depth(dset, abspath)
|
110 |
-
print("done with train split")
|
111 |
-
|
112 |
-
print("done done.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
taming-transformers/scripts/extract_segmentation.py
DELETED
@@ -1,130 +0,0 @@
|
|
1 |
-
import sys, os
|
2 |
-
import numpy as np
|
3 |
-
import scipy
|
4 |
-
import torch
|
5 |
-
import torch.nn as nn
|
6 |
-
from scipy import ndimage
|
7 |
-
from tqdm import tqdm, trange
|
8 |
-
from PIL import Image
|
9 |
-
import torch.hub
|
10 |
-
import torchvision
|
11 |
-
import torch.nn.functional as F
|
12 |
-
|
13 |
-
# download deeplabv2_resnet101_msc-cocostuff164k-100000.pth from
|
14 |
-
# https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth
|
15 |
-
# and put the path here
|
16 |
-
CKPT_PATH = "TODO"
|
17 |
-
|
18 |
-
rescale = lambda x: (x + 1.) / 2.
|
19 |
-
|
20 |
-
def rescale_bgr(x):
|
21 |
-
x = (x+1)*127.5
|
22 |
-
x = torch.flip(x, dims=[0])
|
23 |
-
return x
|
24 |
-
|
25 |
-
|
26 |
-
class COCOStuffSegmenter(nn.Module):
|
27 |
-
def __init__(self, config):
|
28 |
-
super().__init__()
|
29 |
-
self.config = config
|
30 |
-
self.n_labels = 182
|
31 |
-
model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=self.n_labels)
|
32 |
-
ckpt_path = CKPT_PATH
|
33 |
-
model.load_state_dict(torch.load(ckpt_path))
|
34 |
-
self.model = model
|
35 |
-
|
36 |
-
normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std)
|
37 |
-
self.image_transform = torchvision.transforms.Compose([
|
38 |
-
torchvision.transforms.Lambda(lambda image: torch.stack(
|
39 |
-
[normalize(rescale_bgr(x)) for x in image]))
|
40 |
-
])
|
41 |
-
|
42 |
-
def forward(self, x, upsample=None):
|
43 |
-
x = self._pre_process(x)
|
44 |
-
x = self.model(x)
|
45 |
-
if upsample is not None:
|
46 |
-
x = torch.nn.functional.upsample_bilinear(x, size=upsample)
|
47 |
-
return x
|
48 |
-
|
49 |
-
def _pre_process(self, x):
|
50 |
-
x = self.image_transform(x)
|
51 |
-
return x
|
52 |
-
|
53 |
-
@property
|
54 |
-
def mean(self):
|
55 |
-
# bgr
|
56 |
-
return [104.008, 116.669, 122.675]
|
57 |
-
|
58 |
-
@property
|
59 |
-
def std(self):
|
60 |
-
return [1.0, 1.0, 1.0]
|
61 |
-
|
62 |
-
@property
|
63 |
-
def input_size(self):
|
64 |
-
return [3, 224, 224]
|
65 |
-
|
66 |
-
|
67 |
-
def run_model(img, model):
|
68 |
-
model = model.eval()
|
69 |
-
with torch.no_grad():
|
70 |
-
segmentation = model(img, upsample=(img.shape[2], img.shape[3]))
|
71 |
-
segmentation = torch.argmax(segmentation, dim=1, keepdim=True)
|
72 |
-
return segmentation.detach().cpu()
|
73 |
-
|
74 |
-
|
75 |
-
def get_input(batch, k):
|
76 |
-
x = batch[k]
|
77 |
-
if len(x.shape) == 3:
|
78 |
-
x = x[..., None]
|
79 |
-
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
80 |
-
return x.float()
|
81 |
-
|
82 |
-
|
83 |
-
def save_segmentation(segmentation, path):
|
84 |
-
# --> class label to uint8, save as png
|
85 |
-
os.makedirs(os.path.dirname(path), exist_ok=True)
|
86 |
-
assert len(segmentation.shape)==4
|
87 |
-
assert segmentation.shape[0]==1
|
88 |
-
for seg in segmentation:
|
89 |
-
seg = seg.permute(1,2,0).numpy().squeeze().astype(np.uint8)
|
90 |
-
seg = Image.fromarray(seg)
|
91 |
-
seg.save(path)
|
92 |
-
|
93 |
-
|
94 |
-
def iterate_dataset(dataloader, destpath, model):
|
95 |
-
os.makedirs(destpath, exist_ok=True)
|
96 |
-
num_processed = 0
|
97 |
-
for i, batch in tqdm(enumerate(dataloader), desc="Data"):
|
98 |
-
try:
|
99 |
-
img = get_input(batch, "image")
|
100 |
-
img = img.cuda()
|
101 |
-
seg = run_model(img, model)
|
102 |
-
|
103 |
-
path = batch["relative_file_path_"][0]
|
104 |
-
path = os.path.splitext(path)[0]
|
105 |
-
|
106 |
-
path = os.path.join(destpath, path + ".png")
|
107 |
-
save_segmentation(seg, path)
|
108 |
-
num_processed += 1
|
109 |
-
except Exception as e:
|
110 |
-
print(e)
|
111 |
-
print("but anyhow..")
|
112 |
-
|
113 |
-
print("Processed {} files. Bye.".format(num_processed))
|
114 |
-
|
115 |
-
|
116 |
-
from taming.data.sflckr import Examples
|
117 |
-
from torch.utils.data import DataLoader
|
118 |
-
|
119 |
-
if __name__ == "__main__":
|
120 |
-
dest = sys.argv[1]
|
121 |
-
batchsize = 1
|
122 |
-
print("Running with batch-size {}, saving to {}...".format(batchsize, dest))
|
123 |
-
|
124 |
-
model = COCOStuffSegmenter({}).cuda()
|
125 |
-
print("Instantiated model.")
|
126 |
-
|
127 |
-
dataset = Examples()
|
128 |
-
dloader = DataLoader(dataset, batch_size=batchsize)
|
129 |
-
iterate_dataset(dataloader=dloader, destpath=dest, model=model)
|
130 |
-
print("done.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
taming-transformers/scripts/extract_submodel.py
DELETED
@@ -1,17 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import sys
|
3 |
-
|
4 |
-
if __name__ == "__main__":
|
5 |
-
inpath = sys.argv[1]
|
6 |
-
outpath = sys.argv[2]
|
7 |
-
submodel = "cond_stage_model"
|
8 |
-
if len(sys.argv) > 3:
|
9 |
-
submodel = sys.argv[3]
|
10 |
-
|
11 |
-
print("Extracting {} from {} to {}.".format(submodel, inpath, outpath))
|
12 |
-
|
13 |
-
sd = torch.load(inpath, map_location="cpu")
|
14 |
-
new_sd = {"state_dict": dict((k.split(".", 1)[-1],v)
|
15 |
-
for k,v in sd["state_dict"].items()
|
16 |
-
if k.startswith("cond_stage_model"))}
|
17 |
-
torch.save(new_sd, outpath)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
taming-transformers/scripts/make_samples.py
DELETED
@@ -1,292 +0,0 @@
|
|
1 |
-
import argparse, os, sys, glob, math, time
|
2 |
-
import torch
|
3 |
-
import numpy as np
|
4 |
-
from omegaconf import OmegaConf
|
5 |
-
from PIL import Image
|
6 |
-
from main import instantiate_from_config, DataModuleFromConfig
|
7 |
-
from torch.utils.data import DataLoader
|
8 |
-
from torch.utils.data.dataloader import default_collate
|
9 |
-
from tqdm import trange
|
10 |
-
|
11 |
-
|
12 |
-
def save_image(x, path):
|
13 |
-
c,h,w = x.shape
|
14 |
-
assert c==3
|
15 |
-
x = ((x.detach().cpu().numpy().transpose(1,2,0)+1.0)*127.5).clip(0,255).astype(np.uint8)
|
16 |
-
Image.fromarray(x).save(path)
|
17 |
-
|
18 |
-
|
19 |
-
@torch.no_grad()
|
20 |
-
def run_conditional(model, dsets, outdir, top_k, temperature, batch_size=1):
|
21 |
-
if len(dsets.datasets) > 1:
|
22 |
-
split = sorted(dsets.datasets.keys())[0]
|
23 |
-
dset = dsets.datasets[split]
|
24 |
-
else:
|
25 |
-
dset = next(iter(dsets.datasets.values()))
|
26 |
-
print("Dataset: ", dset.__class__.__name__)
|
27 |
-
for start_idx in trange(0,len(dset)-batch_size+1,batch_size):
|
28 |
-
indices = list(range(start_idx, start_idx+batch_size))
|
29 |
-
example = default_collate([dset[i] for i in indices])
|
30 |
-
|
31 |
-
x = model.get_input("image", example).to(model.device)
|
32 |
-
for i in range(x.shape[0]):
|
33 |
-
save_image(x[i], os.path.join(outdir, "originals",
|
34 |
-
"{:06}.png".format(indices[i])))
|
35 |
-
|
36 |
-
cond_key = model.cond_stage_key
|
37 |
-
c = model.get_input(cond_key, example).to(model.device)
|
38 |
-
|
39 |
-
scale_factor = 1.0
|
40 |
-
quant_z, z_indices = model.encode_to_z(x)
|
41 |
-
quant_c, c_indices = model.encode_to_c(c)
|
42 |
-
|
43 |
-
cshape = quant_z.shape
|
44 |
-
|
45 |
-
xrec = model.first_stage_model.decode(quant_z)
|
46 |
-
for i in range(xrec.shape[0]):
|
47 |
-
save_image(xrec[i], os.path.join(outdir, "reconstructions",
|
48 |
-
"{:06}.png".format(indices[i])))
|
49 |
-
|
50 |
-
if cond_key == "segmentation":
|
51 |
-
# get image from segmentation mask
|
52 |
-
num_classes = c.shape[1]
|
53 |
-
c = torch.argmax(c, dim=1, keepdim=True)
|
54 |
-
c = torch.nn.functional.one_hot(c, num_classes=num_classes)
|
55 |
-
c = c.squeeze(1).permute(0, 3, 1, 2).float()
|
56 |
-
c = model.cond_stage_model.to_rgb(c)
|
57 |
-
|
58 |
-
idx = z_indices
|
59 |
-
|
60 |
-
half_sample = False
|
61 |
-
if half_sample:
|
62 |
-
start = idx.shape[1]//2
|
63 |
-
else:
|
64 |
-
start = 0
|
65 |
-
|
66 |
-
idx[:,start:] = 0
|
67 |
-
idx = idx.reshape(cshape[0],cshape[2],cshape[3])
|
68 |
-
start_i = start//cshape[3]
|
69 |
-
start_j = start %cshape[3]
|
70 |
-
|
71 |
-
cidx = c_indices
|
72 |
-
cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
|
73 |
-
|
74 |
-
sample = True
|
75 |
-
|
76 |
-
for i in range(start_i,cshape[2]-0):
|
77 |
-
if i <= 8:
|
78 |
-
local_i = i
|
79 |
-
elif cshape[2]-i < 8:
|
80 |
-
local_i = 16-(cshape[2]-i)
|
81 |
-
else:
|
82 |
-
local_i = 8
|
83 |
-
for j in range(start_j,cshape[3]-0):
|
84 |
-
if j <= 8:
|
85 |
-
local_j = j
|
86 |
-
elif cshape[3]-j < 8:
|
87 |
-
local_j = 16-(cshape[3]-j)
|
88 |
-
else:
|
89 |
-
local_j = 8
|
90 |
-
|
91 |
-
i_start = i-local_i
|
92 |
-
i_end = i_start+16
|
93 |
-
j_start = j-local_j
|
94 |
-
j_end = j_start+16
|
95 |
-
patch = idx[:,i_start:i_end,j_start:j_end]
|
96 |
-
patch = patch.reshape(patch.shape[0],-1)
|
97 |
-
cpatch = cidx[:, i_start:i_end, j_start:j_end]
|
98 |
-
cpatch = cpatch.reshape(cpatch.shape[0], -1)
|
99 |
-
patch = torch.cat((cpatch, patch), dim=1)
|
100 |
-
logits,_ = model.transformer(patch[:,:-1])
|
101 |
-
logits = logits[:, -256:, :]
|
102 |
-
logits = logits.reshape(cshape[0],16,16,-1)
|
103 |
-
logits = logits[:,local_i,local_j,:]
|
104 |
-
|
105 |
-
logits = logits/temperature
|
106 |
-
|
107 |
-
if top_k is not None:
|
108 |
-
logits = model.top_k_logits(logits, top_k)
|
109 |
-
# apply softmax to convert to probabilities
|
110 |
-
probs = torch.nn.functional.softmax(logits, dim=-1)
|
111 |
-
# sample from the distribution or take the most likely
|
112 |
-
if sample:
|
113 |
-
ix = torch.multinomial(probs, num_samples=1)
|
114 |
-
else:
|
115 |
-
_, ix = torch.topk(probs, k=1, dim=-1)
|
116 |
-
idx[:,i,j] = ix
|
117 |
-
|
118 |
-
xsample = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
|
119 |
-
for i in range(xsample.shape[0]):
|
120 |
-
save_image(xsample[i], os.path.join(outdir, "samples",
|
121 |
-
"{:06}.png".format(indices[i])))
|
122 |
-
|
123 |
-
|
124 |
-
def get_parser():
|
125 |
-
parser = argparse.ArgumentParser()
|
126 |
-
parser.add_argument(
|
127 |
-
"-r",
|
128 |
-
"--resume",
|
129 |
-
type=str,
|
130 |
-
nargs="?",
|
131 |
-
help="load from logdir or checkpoint in logdir",
|
132 |
-
)
|
133 |
-
parser.add_argument(
|
134 |
-
"-b",
|
135 |
-
"--base",
|
136 |
-
nargs="*",
|
137 |
-
metavar="base_config.yaml",
|
138 |
-
help="paths to base configs. Loaded from left-to-right. "
|
139 |
-
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
|
140 |
-
default=list(),
|
141 |
-
)
|
142 |
-
parser.add_argument(
|
143 |
-
"-c",
|
144 |
-
"--config",
|
145 |
-
nargs="?",
|
146 |
-
metavar="single_config.yaml",
|
147 |
-
help="path to single config. If specified, base configs will be ignored "
|
148 |
-
"(except for the last one if left unspecified).",
|
149 |
-
const=True,
|
150 |
-
default="",
|
151 |
-
)
|
152 |
-
parser.add_argument(
|
153 |
-
"--ignore_base_data",
|
154 |
-
action="store_true",
|
155 |
-
help="Ignore data specification from base configs. Useful if you want "
|
156 |
-
"to specify a custom datasets on the command line.",
|
157 |
-
)
|
158 |
-
parser.add_argument(
|
159 |
-
"--outdir",
|
160 |
-
required=True,
|
161 |
-
type=str,
|
162 |
-
help="Where to write outputs to.",
|
163 |
-
)
|
164 |
-
parser.add_argument(
|
165 |
-
"--top_k",
|
166 |
-
type=int,
|
167 |
-
default=100,
|
168 |
-
help="Sample from among top-k predictions.",
|
169 |
-
)
|
170 |
-
parser.add_argument(
|
171 |
-
"--temperature",
|
172 |
-
type=float,
|
173 |
-
default=1.0,
|
174 |
-
help="Sampling temperature.",
|
175 |
-
)
|
176 |
-
return parser
|
177 |
-
|
178 |
-
|
179 |
-
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
|
180 |
-
if "ckpt_path" in config.params:
|
181 |
-
print("Deleting the restore-ckpt path from the config...")
|
182 |
-
config.params.ckpt_path = None
|
183 |
-
if "downsample_cond_size" in config.params:
|
184 |
-
print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
|
185 |
-
config.params.downsample_cond_size = -1
|
186 |
-
config.params["downsample_cond_factor"] = 0.5
|
187 |
-
try:
|
188 |
-
if "ckpt_path" in config.params.first_stage_config.params:
|
189 |
-
config.params.first_stage_config.params.ckpt_path = None
|
190 |
-
print("Deleting the first-stage restore-ckpt path from the config...")
|
191 |
-
if "ckpt_path" in config.params.cond_stage_config.params:
|
192 |
-
config.params.cond_stage_config.params.ckpt_path = None
|
193 |
-
print("Deleting the cond-stage restore-ckpt path from the config...")
|
194 |
-
except:
|
195 |
-
pass
|
196 |
-
|
197 |
-
model = instantiate_from_config(config)
|
198 |
-
if sd is not None:
|
199 |
-
missing, unexpected = model.load_state_dict(sd, strict=False)
|
200 |
-
print(f"Missing Keys in State Dict: {missing}")
|
201 |
-
print(f"Unexpected Keys in State Dict: {unexpected}")
|
202 |
-
if gpu:
|
203 |
-
model.cuda()
|
204 |
-
if eval_mode:
|
205 |
-
model.eval()
|
206 |
-
return {"model": model}
|
207 |
-
|
208 |
-
|
209 |
-
def get_data(config):
|
210 |
-
# get data
|
211 |
-
data = instantiate_from_config(config.data)
|
212 |
-
data.prepare_data()
|
213 |
-
data.setup()
|
214 |
-
return data
|
215 |
-
|
216 |
-
|
217 |
-
def load_model_and_dset(config, ckpt, gpu, eval_mode):
|
218 |
-
# get data
|
219 |
-
dsets = get_data(config) # calls data.config ...
|
220 |
-
|
221 |
-
# now load the specified checkpoint
|
222 |
-
if ckpt:
|
223 |
-
pl_sd = torch.load(ckpt, map_location="cpu")
|
224 |
-
global_step = pl_sd["global_step"]
|
225 |
-
else:
|
226 |
-
pl_sd = {"state_dict": None}
|
227 |
-
global_step = None
|
228 |
-
model = load_model_from_config(config.model,
|
229 |
-
pl_sd["state_dict"],
|
230 |
-
gpu=gpu,
|
231 |
-
eval_mode=eval_mode)["model"]
|
232 |
-
return dsets, model, global_step
|
233 |
-
|
234 |
-
|
235 |
-
if __name__ == "__main__":
|
236 |
-
sys.path.append(os.getcwd())
|
237 |
-
|
238 |
-
parser = get_parser()
|
239 |
-
|
240 |
-
opt, unknown = parser.parse_known_args()
|
241 |
-
|
242 |
-
ckpt = None
|
243 |
-
if opt.resume:
|
244 |
-
if not os.path.exists(opt.resume):
|
245 |
-
raise ValueError("Cannot find {}".format(opt.resume))
|
246 |
-
if os.path.isfile(opt.resume):
|
247 |
-
paths = opt.resume.split("/")
|
248 |
-
try:
|
249 |
-
idx = len(paths)-paths[::-1].index("logs")+1
|
250 |
-
except ValueError:
|
251 |
-
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
|
252 |
-
logdir = "/".join(paths[:idx])
|
253 |
-
ckpt = opt.resume
|
254 |
-
else:
|
255 |
-
assert os.path.isdir(opt.resume), opt.resume
|
256 |
-
logdir = opt.resume.rstrip("/")
|
257 |
-
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
258 |
-
print(f"logdir:{logdir}")
|
259 |
-
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
|
260 |
-
opt.base = base_configs+opt.base
|
261 |
-
|
262 |
-
if opt.config:
|
263 |
-
if type(opt.config) == str:
|
264 |
-
opt.base = [opt.config]
|
265 |
-
else:
|
266 |
-
opt.base = [opt.base[-1]]
|
267 |
-
|
268 |
-
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
269 |
-
cli = OmegaConf.from_dotlist(unknown)
|
270 |
-
if opt.ignore_base_data:
|
271 |
-
for config in configs:
|
272 |
-
if hasattr(config, "data"): del config["data"]
|
273 |
-
config = OmegaConf.merge(*configs, cli)
|
274 |
-
|
275 |
-
print(ckpt)
|
276 |
-
gpu = True
|
277 |
-
eval_mode = True
|
278 |
-
show_config = False
|
279 |
-
if show_config:
|
280 |
-
print(OmegaConf.to_container(config))
|
281 |
-
|
282 |
-
dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
|
283 |
-
print(f"Global step: {global_step}")
|
284 |
-
|
285 |
-
outdir = os.path.join(opt.outdir, "{:06}_{}_{}".format(global_step,
|
286 |
-
opt.top_k,
|
287 |
-
opt.temperature))
|
288 |
-
os.makedirs(outdir, exist_ok=True)
|
289 |
-
print("Writing samples to ", outdir)
|
290 |
-
for k in ["originals", "reconstructions", "samples"]:
|
291 |
-
os.makedirs(os.path.join(outdir, k), exist_ok=True)
|
292 |
-
run_conditional(model, dsets, outdir, opt.top_k, opt.temperature)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
taming-transformers/scripts/make_scene_samples.py
DELETED
@@ -1,198 +0,0 @@
|
|
1 |
-
import glob
|
2 |
-
import os
|
3 |
-
import sys
|
4 |
-
from itertools import product
|
5 |
-
from pathlib import Path
|
6 |
-
from typing import Literal, List, Optional, Tuple
|
7 |
-
|
8 |
-
import numpy as np
|
9 |
-
import torch
|
10 |
-
from omegaconf import OmegaConf
|
11 |
-
from pytorch_lightning import seed_everything
|
12 |
-
from torch import Tensor
|
13 |
-
from torchvision.utils import save_image
|
14 |
-
from tqdm import tqdm
|
15 |
-
|
16 |
-
from scripts.make_samples import get_parser, load_model_and_dset
|
17 |
-
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
|
18 |
-
from taming.data.helper_types import BoundingBox, Annotation
|
19 |
-
from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
|
20 |
-
from taming.models.cond_transformer import Net2NetTransformer
|
21 |
-
|
22 |
-
seed_everything(42424242)
|
23 |
-
device: Literal['cuda', 'cpu'] = 'cuda'
|
24 |
-
first_stage_factor = 16
|
25 |
-
trained_on_res = 256
|
26 |
-
|
27 |
-
|
28 |
-
def _helper(coord: int, coord_max: int, coord_window: int) -> (int, int):
|
29 |
-
assert 0 <= coord < coord_max
|
30 |
-
coord_desired_center = (coord_window - 1) // 2
|
31 |
-
return np.clip(coord - coord_desired_center, 0, coord_max - coord_window)
|
32 |
-
|
33 |
-
|
34 |
-
def get_crop_coordinates(x: int, y: int) -> BoundingBox:
|
35 |
-
WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
|
36 |
-
x0 = _helper(x, WIDTH, first_stage_factor) / WIDTH
|
37 |
-
y0 = _helper(y, HEIGHT, first_stage_factor) / HEIGHT
|
38 |
-
w = first_stage_factor / WIDTH
|
39 |
-
h = first_stage_factor / HEIGHT
|
40 |
-
return x0, y0, w, h
|
41 |
-
|
42 |
-
|
43 |
-
def get_z_indices_crop_out(z_indices: Tensor, predict_x: int, predict_y: int) -> Tensor:
|
44 |
-
WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
|
45 |
-
x0 = _helper(predict_x, WIDTH, first_stage_factor)
|
46 |
-
y0 = _helper(predict_y, HEIGHT, first_stage_factor)
|
47 |
-
no_images = z_indices.shape[0]
|
48 |
-
cut_out_1 = z_indices[:, y0:predict_y, x0:x0+first_stage_factor].reshape((no_images, -1))
|
49 |
-
cut_out_2 = z_indices[:, predict_y, x0:predict_x]
|
50 |
-
return torch.cat((cut_out_1, cut_out_2), dim=1)
|
51 |
-
|
52 |
-
|
53 |
-
@torch.no_grad()
|
54 |
-
def sample(model: Net2NetTransformer, annotations: List[Annotation], dataset: AnnotatedObjectsDataset,
|
55 |
-
conditional_builder: ObjectsCenterPointsConditionalBuilder, no_samples: int,
|
56 |
-
temperature: float, top_k: int) -> Tensor:
|
57 |
-
x_max, y_max = desired_z_shape[1], desired_z_shape[0]
|
58 |
-
|
59 |
-
annotations = [a._replace(category_no=dataset.get_category_number(a.category_id)) for a in annotations]
|
60 |
-
|
61 |
-
recompute_conditional = any((desired_resolution[0] > trained_on_res, desired_resolution[1] > trained_on_res))
|
62 |
-
if not recompute_conditional:
|
63 |
-
crop_coordinates = get_crop_coordinates(0, 0)
|
64 |
-
conditional_indices = conditional_builder.build(annotations, crop_coordinates)
|
65 |
-
c_indices = conditional_indices.to(device).repeat(no_samples, 1)
|
66 |
-
z_indices = torch.zeros((no_samples, 0), device=device).long()
|
67 |
-
output_indices = model.sample(z_indices, c_indices, steps=x_max*y_max, temperature=temperature,
|
68 |
-
sample=True, top_k=top_k)
|
69 |
-
else:
|
70 |
-
output_indices = torch.zeros((no_samples, y_max, x_max), device=device).long()
|
71 |
-
for predict_y, predict_x in tqdm(product(range(y_max), range(x_max)), desc='sampling_image', total=x_max*y_max):
|
72 |
-
crop_coordinates = get_crop_coordinates(predict_x, predict_y)
|
73 |
-
z_indices = get_z_indices_crop_out(output_indices, predict_x, predict_y)
|
74 |
-
conditional_indices = conditional_builder.build(annotations, crop_coordinates)
|
75 |
-
c_indices = conditional_indices.to(device).repeat(no_samples, 1)
|
76 |
-
new_index = model.sample(z_indices, c_indices, steps=1, temperature=temperature, sample=True, top_k=top_k)
|
77 |
-
output_indices[:, predict_y, predict_x] = new_index[:, -1]
|
78 |
-
z_shape = (
|
79 |
-
no_samples,
|
80 |
-
model.first_stage_model.quantize.e_dim, # codebook embed_dim
|
81 |
-
desired_z_shape[0], # z_height
|
82 |
-
desired_z_shape[1] # z_width
|
83 |
-
)
|
84 |
-
x_sample = model.decode_to_img(output_indices, z_shape) * 0.5 + 0.5
|
85 |
-
x_sample = x_sample.to('cpu')
|
86 |
-
|
87 |
-
plotter = conditional_builder.plot
|
88 |
-
figure_size = (x_sample.shape[2], x_sample.shape[3])
|
89 |
-
scene_graph = conditional_builder.build(annotations, (0., 0., 1., 1.))
|
90 |
-
plot = plotter(scene_graph, dataset.get_textual_label_for_category_no, figure_size)
|
91 |
-
return torch.cat((x_sample, plot.unsqueeze(0)))
|
92 |
-
|
93 |
-
|
94 |
-
def get_resolution(resolution_str: str) -> (Tuple[int, int], Tuple[int, int]):
|
95 |
-
if not resolution_str.count(',') == 1:
|
96 |
-
raise ValueError("Give resolution as in 'height,width'")
|
97 |
-
res_h, res_w = resolution_str.split(',')
|
98 |
-
res_h = max(int(res_h), trained_on_res)
|
99 |
-
res_w = max(int(res_w), trained_on_res)
|
100 |
-
z_h = int(round(res_h/first_stage_factor))
|
101 |
-
z_w = int(round(res_w/first_stage_factor))
|
102 |
-
return (z_h, z_w), (z_h*first_stage_factor, z_w*first_stage_factor)
|
103 |
-
|
104 |
-
|
105 |
-
def add_arg_to_parser(parser):
|
106 |
-
parser.add_argument(
|
107 |
-
"-R",
|
108 |
-
"--resolution",
|
109 |
-
type=str,
|
110 |
-
default='256,256',
|
111 |
-
help=f"give resolution in multiples of {first_stage_factor}, default is '256,256'",
|
112 |
-
)
|
113 |
-
parser.add_argument(
|
114 |
-
"-C",
|
115 |
-
"--conditional",
|
116 |
-
type=str,
|
117 |
-
default='objects_bbox',
|
118 |
-
help=f"objects_bbox or objects_center_points",
|
119 |
-
)
|
120 |
-
parser.add_argument(
|
121 |
-
"-N",
|
122 |
-
"--n_samples_per_layout",
|
123 |
-
type=int,
|
124 |
-
default=4,
|
125 |
-
help=f"how many samples to generate per layout",
|
126 |
-
)
|
127 |
-
return parser
|
128 |
-
|
129 |
-
|
130 |
-
if __name__ == "__main__":
|
131 |
-
sys.path.append(os.getcwd())
|
132 |
-
|
133 |
-
parser = get_parser()
|
134 |
-
parser = add_arg_to_parser(parser)
|
135 |
-
|
136 |
-
opt, unknown = parser.parse_known_args()
|
137 |
-
|
138 |
-
ckpt = None
|
139 |
-
if opt.resume:
|
140 |
-
if not os.path.exists(opt.resume):
|
141 |
-
raise ValueError("Cannot find {}".format(opt.resume))
|
142 |
-
if os.path.isfile(opt.resume):
|
143 |
-
paths = opt.resume.split("/")
|
144 |
-
try:
|
145 |
-
idx = len(paths)-paths[::-1].index("logs")+1
|
146 |
-
except ValueError:
|
147 |
-
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
|
148 |
-
logdir = "/".join(paths[:idx])
|
149 |
-
ckpt = opt.resume
|
150 |
-
else:
|
151 |
-
assert os.path.isdir(opt.resume), opt.resume
|
152 |
-
logdir = opt.resume.rstrip("/")
|
153 |
-
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
154 |
-
print(f"logdir:{logdir}")
|
155 |
-
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
|
156 |
-
opt.base = base_configs+opt.base
|
157 |
-
|
158 |
-
if opt.config:
|
159 |
-
if type(opt.config) == str:
|
160 |
-
opt.base = [opt.config]
|
161 |
-
else:
|
162 |
-
opt.base = [opt.base[-1]]
|
163 |
-
|
164 |
-
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
165 |
-
cli = OmegaConf.from_dotlist(unknown)
|
166 |
-
if opt.ignore_base_data:
|
167 |
-
for config in configs:
|
168 |
-
if hasattr(config, "data"):
|
169 |
-
del config["data"]
|
170 |
-
config = OmegaConf.merge(*configs, cli)
|
171 |
-
desired_z_shape, desired_resolution = get_resolution(opt.resolution)
|
172 |
-
conditional = opt.conditional
|
173 |
-
|
174 |
-
print(ckpt)
|
175 |
-
gpu = True
|
176 |
-
eval_mode = True
|
177 |
-
show_config = False
|
178 |
-
if show_config:
|
179 |
-
print(OmegaConf.to_container(config))
|
180 |
-
|
181 |
-
dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
|
182 |
-
print(f"Global step: {global_step}")
|
183 |
-
|
184 |
-
data_loader = dsets.val_dataloader()
|
185 |
-
print(dsets.datasets["validation"].conditional_builders)
|
186 |
-
conditional_builder = dsets.datasets["validation"].conditional_builders[conditional]
|
187 |
-
|
188 |
-
outdir = Path(opt.outdir).joinpath(f"{global_step:06}_{opt.top_k}_{opt.temperature}")
|
189 |
-
outdir.mkdir(exist_ok=True, parents=True)
|
190 |
-
print("Writing samples to ", outdir)
|
191 |
-
|
192 |
-
p_bar_1 = tqdm(enumerate(iter(data_loader)), desc='batch', total=len(data_loader))
|
193 |
-
for batch_no, batch in p_bar_1:
|
194 |
-
save_img: Optional[Tensor] = None
|
195 |
-
for i, annotations in tqdm(enumerate(batch['annotations']), desc='within_batch', total=data_loader.batch_size):
|
196 |
-
imgs = sample(model, annotations, dsets.datasets["validation"], conditional_builder,
|
197 |
-
opt.n_samples_per_layout, opt.temperature, opt.top_k)
|
198 |
-
save_image(imgs, outdir.joinpath(f'{batch_no:04}_{i:02}.png'), n_row=opt.n_samples_per_layout+1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
taming-transformers/scripts/sample_conditional.py
DELETED
@@ -1,355 +0,0 @@
|
|
1 |
-
import argparse, os, sys, glob, math, time
|
2 |
-
import torch
|
3 |
-
import numpy as np
|
4 |
-
from omegaconf import OmegaConf
|
5 |
-
import streamlit as st
|
6 |
-
from streamlit import caching
|
7 |
-
from PIL import Image
|
8 |
-
from main import instantiate_from_config, DataModuleFromConfig
|
9 |
-
from torch.utils.data import DataLoader
|
10 |
-
from torch.utils.data.dataloader import default_collate
|
11 |
-
|
12 |
-
|
13 |
-
rescale = lambda x: (x + 1.) / 2.
|
14 |
-
|
15 |
-
|
16 |
-
def bchw_to_st(x):
|
17 |
-
return rescale(x.detach().cpu().numpy().transpose(0,2,3,1))
|
18 |
-
|
19 |
-
def save_img(xstart, fname):
|
20 |
-
I = (xstart.clip(0,1)[0]*255).astype(np.uint8)
|
21 |
-
Image.fromarray(I).save(fname)
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
def get_interactive_image(resize=False):
|
26 |
-
-    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
-    if image is not None:
-        image = Image.open(image)
-        if not image.mode == "RGB":
-            image = image.convert("RGB")
-        image = np.array(image).astype(np.uint8)
-        print("upload image shape: {}".format(image.shape))
-        img = Image.fromarray(image)
-        if resize:
-            img = img.resize((256, 256))
-        image = np.array(img)
-        return image
-
-
-def single_image_to_torch(x, permute=True):
-    assert x is not None, "Please provide an image through the upload function"
-    x = np.array(x)
-    x = torch.FloatTensor(x/255.*2. - 1.)[None,...]
-    if permute:
-        x = x.permute(0, 3, 1, 2)
-    return x
-
-
-def pad_to_M(x, M):
-    hp = math.ceil(x.shape[2]/M)*M-x.shape[2]
-    wp = math.ceil(x.shape[3]/M)*M-x.shape[3]
-    x = torch.nn.functional.pad(x, (0,wp,0,hp,0,0,0,0))
-    return x
-
-@torch.no_grad()
-def run_conditional(model, dsets):
-    if len(dsets.datasets) > 1:
-        split = st.sidebar.radio("Split", sorted(dsets.datasets.keys()))
-        dset = dsets.datasets[split]
-    else:
-        dset = next(iter(dsets.datasets.values()))
-    batch_size = 1
-    start_index = st.sidebar.number_input("Example Index (Size: {})".format(len(dset)), value=0,
-                                          min_value=0,
-                                          max_value=len(dset)-batch_size)
-    indices = list(range(start_index, start_index+batch_size))
-
-    example = default_collate([dset[i] for i in indices])
-
-    x = model.get_input("image", example).to(model.device)
-
-    cond_key = model.cond_stage_key
-    c = model.get_input(cond_key, example).to(model.device)
-
-    scale_factor = st.sidebar.slider("Scale Factor", min_value=0.5, max_value=4.0, step=0.25, value=1.00)
-    if scale_factor != 1.0:
-        x = torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="bicubic")
-        c = torch.nn.functional.interpolate(c, scale_factor=scale_factor, mode="bicubic")
-
-    quant_z, z_indices = model.encode_to_z(x)
-    quant_c, c_indices = model.encode_to_c(c)
-
-    cshape = quant_z.shape
-
-    xrec = model.first_stage_model.decode(quant_z)
-    st.write("image: {}".format(x.shape))
-    st.image(bchw_to_st(x), clamp=True, output_format="PNG")
-    st.write("image reconstruction: {}".format(xrec.shape))
-    st.image(bchw_to_st(xrec), clamp=True, output_format="PNG")
-
-    if cond_key == "segmentation":
-        # get image from segmentation mask
-        num_classes = c.shape[1]
-        c = torch.argmax(c, dim=1, keepdim=True)
-        c = torch.nn.functional.one_hot(c, num_classes=num_classes)
-        c = c.squeeze(1).permute(0, 3, 1, 2).float()
-        c = model.cond_stage_model.to_rgb(c)
-
-    st.write(f"{cond_key}: {tuple(c.shape)}")
-    st.image(bchw_to_st(c), clamp=True, output_format="PNG")
-
-    idx = z_indices
-
-    half_sample = st.sidebar.checkbox("Image Completion", value=False)
-    if half_sample:
-        start = idx.shape[1]//2
-    else:
-        start = 0
-
-    idx[:,start:] = 0
-    idx = idx.reshape(cshape[0],cshape[2],cshape[3])
-    start_i = start//cshape[3]
-    start_j = start %cshape[3]
-
-    if not half_sample and quant_z.shape == quant_c.shape:
-        st.info("Setting idx to c_indices")
-        idx = c_indices.clone().reshape(cshape[0],cshape[2],cshape[3])
-
-    cidx = c_indices
-    cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
-
-    xstart = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
-    st.image(bchw_to_st(xstart), clamp=True, output_format="PNG")
-
-    temperature = st.number_input("Temperature", value=1.0)
-    top_k = st.number_input("Top k", value=100)
-    sample = st.checkbox("Sample", value=True)
-    update_every = st.number_input("Update every", value=75)
-
-    st.text(f"Sampling shape ({cshape[2]},{cshape[3]})")
-
-    animate = st.checkbox("animate")
-    if animate:
-        import imageio
-        outvid = "sampling.mp4"
-        writer = imageio.get_writer(outvid, fps=25)
-    elapsed_t = st.empty()
-    info = st.empty()
-    st.text("Sampled")
-    if st.button("Sample"):
-        output = st.empty()
-        start_t = time.time()
-        for i in range(start_i,cshape[2]-0):
-            if i <= 8:
-                local_i = i
-            elif cshape[2]-i < 8:
-                local_i = 16-(cshape[2]-i)
-            else:
-                local_i = 8
-            for j in range(start_j,cshape[3]-0):
-                if j <= 8:
-                    local_j = j
-                elif cshape[3]-j < 8:
-                    local_j = 16-(cshape[3]-j)
-                else:
-                    local_j = 8
-
-                i_start = i-local_i
-                i_end = i_start+16
-                j_start = j-local_j
-                j_end = j_start+16
-                elapsed_t.text(f"Time: {time.time() - start_t} seconds")
-                info.text(f"Step: ({i},{j}) | Local: ({local_i},{local_j}) | Crop: ({i_start}:{i_end},{j_start}:{j_end})")
-                patch = idx[:,i_start:i_end,j_start:j_end]
-                patch = patch.reshape(patch.shape[0],-1)
-                cpatch = cidx[:, i_start:i_end, j_start:j_end]
-                cpatch = cpatch.reshape(cpatch.shape[0], -1)
-                patch = torch.cat((cpatch, patch), dim=1)
-                logits,_ = model.transformer(patch[:,:-1])
-                logits = logits[:, -256:, :]
-                logits = logits.reshape(cshape[0],16,16,-1)
-                logits = logits[:,local_i,local_j,:]
-
-                logits = logits/temperature
-
-                if top_k is not None:
-                    logits = model.top_k_logits(logits, top_k)
-                # apply softmax to convert to probabilities
-                probs = torch.nn.functional.softmax(logits, dim=-1)
-                # sample from the distribution or take the most likely
-                if sample:
-                    ix = torch.multinomial(probs, num_samples=1)
-                else:
-                    _, ix = torch.topk(probs, k=1, dim=-1)
-                idx[:,i,j] = ix
-
-                if (i*cshape[3]+j)%update_every==0:
-                    xstart = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape,)
-
-                    xstart = bchw_to_st(xstart)
-                    output.image(xstart, clamp=True, output_format="PNG")
-
-                    if animate:
-                        writer.append_data((xstart[0]*255).clip(0, 255).astype(np.uint8))
-
-        xstart = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
-        xstart = bchw_to_st(xstart)
-        output.image(xstart, clamp=True, output_format="PNG")
-        #save_img(xstart, "full_res_sample.png")
-        if animate:
-            writer.close()
-            st.video(outvid)
-
-
-def get_parser():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "-r",
-        "--resume",
-        type=str,
-        nargs="?",
-        help="load from logdir or checkpoint in logdir",
-    )
-    parser.add_argument(
-        "-b",
-        "--base",
-        nargs="*",
-        metavar="base_config.yaml",
-        help="paths to base configs. Loaded from left-to-right. "
-        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
-        default=list(),
-    )
-    parser.add_argument(
-        "-c",
-        "--config",
-        nargs="?",
-        metavar="single_config.yaml",
-        help="path to single config. If specified, base configs will be ignored "
-        "(except for the last one if left unspecified).",
-        const=True,
-        default="",
-    )
-    parser.add_argument(
-        "--ignore_base_data",
-        action="store_true",
-        help="Ignore data specification from base configs. Useful if you want "
-        "to specify a custom datasets on the command line.",
-    )
-    return parser
-
-
-def load_model_from_config(config, sd, gpu=True, eval_mode=True):
-    if "ckpt_path" in config.params:
-        st.warning("Deleting the restore-ckpt path from the config...")
-        config.params.ckpt_path = None
-    if "downsample_cond_size" in config.params:
-        st.warning("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
-        config.params.downsample_cond_size = -1
-        config.params["downsample_cond_factor"] = 0.5
-    try:
-        if "ckpt_path" in config.params.first_stage_config.params:
-            config.params.first_stage_config.params.ckpt_path = None
-            st.warning("Deleting the first-stage restore-ckpt path from the config...")
-        if "ckpt_path" in config.params.cond_stage_config.params:
-            config.params.cond_stage_config.params.ckpt_path = None
-            st.warning("Deleting the cond-stage restore-ckpt path from the config...")
-    except:
-        pass
-
-    model = instantiate_from_config(config)
-    if sd is not None:
-        missing, unexpected = model.load_state_dict(sd, strict=False)
-        st.info(f"Missing Keys in State Dict: {missing}")
-        st.info(f"Unexpected Keys in State Dict: {unexpected}")
-    if gpu:
-        model.cuda()
-    if eval_mode:
-        model.eval()
-    return {"model": model}
-
-
-def get_data(config):
-    # get data
-    data = instantiate_from_config(config.data)
-    data.prepare_data()
-    data.setup()
-    return data
-
-
-@st.cache(allow_output_mutation=True, suppress_st_warning=True)
-def load_model_and_dset(config, ckpt, gpu, eval_mode):
-    # get data
-    dsets = get_data(config)   # calls data.config ...
-
-    # now load the specified checkpoint
-    if ckpt:
-        pl_sd = torch.load(ckpt, map_location="cpu")
-        global_step = pl_sd["global_step"]
-    else:
-        pl_sd = {"state_dict": None}
-        global_step = None
-    model = load_model_from_config(config.model,
-                                   pl_sd["state_dict"],
-                                   gpu=gpu,
-                                   eval_mode=eval_mode)["model"]
-    return dsets, model, global_step
-
-
-if __name__ == "__main__":
-    sys.path.append(os.getcwd())
-
-    parser = get_parser()
-
-    opt, unknown = parser.parse_known_args()
-
-    ckpt = None
-    if opt.resume:
-        if not os.path.exists(opt.resume):
-            raise ValueError("Cannot find {}".format(opt.resume))
-        if os.path.isfile(opt.resume):
-            paths = opt.resume.split("/")
-            try:
-                idx = len(paths)-paths[::-1].index("logs")+1
-            except ValueError:
-                idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
-            logdir = "/".join(paths[:idx])
-            ckpt = opt.resume
-        else:
-            assert os.path.isdir(opt.resume), opt.resume
-            logdir = opt.resume.rstrip("/")
-            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
-        print(f"logdir:{logdir}")
-        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
-        opt.base = base_configs+opt.base
-
-    if opt.config:
-        if type(opt.config) == str:
-            opt.base = [opt.config]
-        else:
-            opt.base = [opt.base[-1]]
-
-    configs = [OmegaConf.load(cfg) for cfg in opt.base]
-    cli = OmegaConf.from_dotlist(unknown)
-    if opt.ignore_base_data:
-        for config in configs:
-            if hasattr(config, "data"): del config["data"]
-    config = OmegaConf.merge(*configs, cli)
-
-    st.sidebar.text(ckpt)
-    gs = st.sidebar.empty()
-    gs.text(f"Global step: ?")
-    st.sidebar.text("Options")
-    #gpu = st.sidebar.checkbox("GPU", value=True)
-    gpu = True
-    #eval_mode = st.sidebar.checkbox("Eval Mode", value=True)
-    eval_mode = True
-    #show_config = st.sidebar.checkbox("Show Config", value=False)
-    show_config = False
-    if show_config:
-        st.info("Checkpoint: {}".format(ckpt))
-        st.json(OmegaConf.to_container(config))
-
-    dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
-    gs.text(f"Global step: {global_step}")
-    run_conditional(model, dsets)
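
Note: the sliding-window loop in run_conditional above samples one codebook token at a time (temperature scaling, top-k filtering, softmax, multinomial draw). The following standalone sketch, which is not part of the deleted script, condenses that per-token step; the helper name sample_token and its default values are illustrative assumptions, not code from the repository.

import torch

def sample_token(logits, temperature=1.0, top_k=100, greedy=False):
    # scale by temperature, then optionally keep only the top-k logits
    logits = logits / temperature                        # (B, vocab_size)
    if top_k is not None:
        v, _ = torch.topk(logits, top_k)
        logits = logits.masked_fill(logits < v[..., [-1]], -float("inf"))
    probs = torch.nn.functional.softmax(logits, dim=-1)
    if greedy:
        return torch.argmax(probs, dim=-1, keepdim=True)  # most likely code index
    return torch.multinomial(probs, num_samples=1)        # sampled code index, shape (B, 1)
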
taming-transformers/scripts/sample_fast.py
DELETED
@@ -1,260 +0,0 @@
|
|
1 |
-
import argparse, os, sys, glob
|
2 |
-
import torch
|
3 |
-
import time
|
4 |
-
import numpy as np
|
5 |
-
from omegaconf import OmegaConf
|
6 |
-
from PIL import Image
|
7 |
-
from tqdm import tqdm, trange
|
8 |
-
from einops import repeat
|
9 |
-
|
10 |
-
from main import instantiate_from_config
|
11 |
-
from taming.modules.transformer.mingpt import sample_with_past
|
12 |
-
|
13 |
-
|
14 |
-
rescale = lambda x: (x + 1.) / 2.
|
15 |
-
|
16 |
-
|
17 |
-
def chw_to_pillow(x):
|
18 |
-
return Image.fromarray((255*rescale(x.detach().cpu().numpy().transpose(1,2,0))).clip(0,255).astype(np.uint8))
|
19 |
-
|
20 |
-
|
21 |
-
@torch.no_grad()
|
22 |
-
def sample_classconditional(model, batch_size, class_label, steps=256, temperature=None, top_k=None, callback=None,
|
23 |
-
dim_z=256, h=16, w=16, verbose_time=False, top_p=None):
|
24 |
-
log = dict()
|
25 |
-
assert type(class_label) == int, f'expecting type int but type is {type(class_label)}'
|
26 |
-
qzshape = [batch_size, dim_z, h, w]
|
27 |
-
assert not model.be_unconditional, 'Expecting a class-conditional Net2NetTransformer.'
|
28 |
-
c_indices = repeat(torch.tensor([class_label]), '1 -> b 1', b=batch_size).to(model.device) # class token
|
29 |
-
t1 = time.time()
|
30 |
-
index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
|
31 |
-
sample_logits=True, top_k=top_k, callback=callback,
|
32 |
-
temperature=temperature, top_p=top_p)
|
33 |
-
if verbose_time:
|
34 |
-
sampling_time = time.time() - t1
|
35 |
-
print(f"Full sampling takes about {sampling_time:.2f} seconds.")
|
36 |
-
x_sample = model.decode_to_img(index_sample, qzshape)
|
37 |
-
log["samples"] = x_sample
|
38 |
-
log["class_label"] = c_indices
|
39 |
-
return log
|
40 |
-
|
41 |
-
|
42 |
-
@torch.no_grad()
|
43 |
-
def sample_unconditional(model, batch_size, steps=256, temperature=None, top_k=None, top_p=None, callback=None,
|
44 |
-
dim_z=256, h=16, w=16, verbose_time=False):
|
45 |
-
log = dict()
|
46 |
-
qzshape = [batch_size, dim_z, h, w]
|
47 |
-
assert model.be_unconditional, 'Expecting an unconditional model.'
|
48 |
-
c_indices = repeat(torch.tensor([model.sos_token]), '1 -> b 1', b=batch_size).to(model.device) # sos token
|
49 |
-
t1 = time.time()
|
50 |
-
index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
|
51 |
-
sample_logits=True, top_k=top_k, callback=callback,
|
52 |
-
temperature=temperature, top_p=top_p)
|
53 |
-
if verbose_time:
|
54 |
-
sampling_time = time.time() - t1
|
55 |
-
print(f"Full sampling takes about {sampling_time:.2f} seconds.")
|
56 |
-
x_sample = model.decode_to_img(index_sample, qzshape)
|
57 |
-
log["samples"] = x_sample
|
58 |
-
return log
|
59 |
-
|
60 |
-
|
61 |
-
@torch.no_grad()
|
62 |
-
def run(logdir, model, batch_size, temperature, top_k, unconditional=True, num_samples=50000,
|
63 |
-
given_classes=None, top_p=None):
|
64 |
-
batches = [batch_size for _ in range(num_samples//batch_size)] + [num_samples % batch_size]
|
65 |
-
if not unconditional:
|
66 |
-
assert given_classes is not None
|
67 |
-
print("Running in pure class-conditional sampling mode. I will produce "
|
68 |
-
f"{num_samples} samples for each of the {len(given_classes)} classes, "
|
69 |
-
f"i.e. {num_samples*len(given_classes)} in total.")
|
70 |
-
for class_label in tqdm(given_classes, desc="Classes"):
|
71 |
-
for n, bs in tqdm(enumerate(batches), desc="Sampling Class"):
|
72 |
-
if bs == 0: break
|
73 |
-
logs = sample_classconditional(model, batch_size=bs, class_label=class_label,
|
74 |
-
temperature=temperature, top_k=top_k, top_p=top_p)
|
75 |
-
save_from_logs(logs, logdir, base_count=n * batch_size, cond_key=logs["class_label"])
|
76 |
-
else:
|
77 |
-
print(f"Running in unconditional sampling mode, producing {num_samples} samples.")
|
78 |
-
for n, bs in tqdm(enumerate(batches), desc="Sampling"):
|
79 |
-
if bs == 0: break
|
80 |
-
logs = sample_unconditional(model, batch_size=bs, temperature=temperature, top_k=top_k, top_p=top_p)
|
81 |
-
save_from_logs(logs, logdir, base_count=n * batch_size)
|
82 |
-
|
83 |
-
|
84 |
-
def save_from_logs(logs, logdir, base_count, key="samples", cond_key=None):
|
85 |
-
xx = logs[key]
|
86 |
-
for i, x in enumerate(xx):
|
87 |
-
x = chw_to_pillow(x)
|
88 |
-
count = base_count + i
|
89 |
-
if cond_key is None:
|
90 |
-
x.save(os.path.join(logdir, f"{count:06}.png"))
|
91 |
-
else:
|
92 |
-
condlabel = cond_key[i]
|
93 |
-
if type(condlabel) == torch.Tensor: condlabel = condlabel.item()
|
94 |
-
os.makedirs(os.path.join(logdir, str(condlabel)), exist_ok=True)
|
95 |
-
x.save(os.path.join(logdir, str(condlabel), f"{count:06}.png"))
|
96 |
-
|
97 |
-
|
98 |
-
def get_parser():
|
99 |
-
def str2bool(v):
|
100 |
-
if isinstance(v, bool):
|
101 |
-
return v
|
102 |
-
if v.lower() in ("yes", "true", "t", "y", "1"):
|
103 |
-
return True
|
104 |
-
elif v.lower() in ("no", "false", "f", "n", "0"):
|
105 |
-
return False
|
106 |
-
else:
|
107 |
-
raise argparse.ArgumentTypeError("Boolean value expected.")
|
108 |
-
|
109 |
-
parser = argparse.ArgumentParser()
|
110 |
-
parser.add_argument(
|
111 |
-
"-r",
|
112 |
-
"--resume",
|
113 |
-
type=str,
|
114 |
-
nargs="?",
|
115 |
-
help="load from logdir or checkpoint in logdir",
|
116 |
-
)
|
117 |
-
parser.add_argument(
|
118 |
-
"-o",
|
119 |
-
"--outdir",
|
120 |
-
type=str,
|
121 |
-
nargs="?",
|
122 |
-
help="path where the samples will be logged to.",
|
123 |
-
default=""
|
124 |
-
)
|
125 |
-
parser.add_argument(
|
126 |
-
"-b",
|
127 |
-
"--base",
|
128 |
-
nargs="*",
|
129 |
-
metavar="base_config.yaml",
|
130 |
-
help="paths to base configs. Loaded from left-to-right. "
|
131 |
-
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
|
132 |
-
default=list(),
|
133 |
-
)
|
134 |
-
parser.add_argument(
|
135 |
-
"-n",
|
136 |
-
"--num_samples",
|
137 |
-
type=int,
|
138 |
-
nargs="?",
|
139 |
-
help="num_samples to draw",
|
140 |
-
default=50000
|
141 |
-
)
|
142 |
-
parser.add_argument(
|
143 |
-
"--batch_size",
|
144 |
-
type=int,
|
145 |
-
nargs="?",
|
146 |
-
help="the batch size",
|
147 |
-
default=25
|
148 |
-
)
|
149 |
-
parser.add_argument(
|
150 |
-
"-k",
|
151 |
-
"--top_k",
|
152 |
-
type=int,
|
153 |
-
nargs="?",
|
154 |
-
help="top-k value to sample with",
|
155 |
-
default=250,
|
156 |
-
)
|
157 |
-
parser.add_argument(
|
158 |
-
"-t",
|
159 |
-
"--temperature",
|
160 |
-
type=float,
|
161 |
-
nargs="?",
|
162 |
-
help="temperature value to sample with",
|
163 |
-
default=1.0
|
164 |
-
)
|
165 |
-
parser.add_argument(
|
166 |
-
"-p",
|
167 |
-
"--top_p",
|
168 |
-
type=float,
|
169 |
-
nargs="?",
|
170 |
-
help="top-p value to sample with",
|
171 |
-
default=1.0
|
172 |
-
)
|
173 |
-
parser.add_argument(
|
174 |
-
"--classes",
|
175 |
-
type=str,
|
176 |
-
nargs="?",
|
177 |
-
help="specify comma-separated classes to sample from. Uses 1000 classes per default.",
|
178 |
-
default="imagenet"
|
179 |
-
)
|
180 |
-
return parser
|
181 |
-
|
182 |
-
|
183 |
-
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
|
184 |
-
model = instantiate_from_config(config)
|
185 |
-
if sd is not None:
|
186 |
-
model.load_state_dict(sd)
|
187 |
-
if gpu:
|
188 |
-
model.cuda()
|
189 |
-
if eval_mode:
|
190 |
-
model.eval()
|
191 |
-
return {"model": model}
|
192 |
-
|
193 |
-
|
194 |
-
def load_model(config, ckpt, gpu, eval_mode):
|
195 |
-
# load the specified checkpoint
|
196 |
-
if ckpt:
|
197 |
-
pl_sd = torch.load(ckpt, map_location="cpu")
|
198 |
-
global_step = pl_sd["global_step"]
|
199 |
-
print(f"loaded model from global step {global_step}.")
|
200 |
-
else:
|
201 |
-
pl_sd = {"state_dict": None}
|
202 |
-
global_step = None
|
203 |
-
model = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"]
|
204 |
-
return model, global_step
|
205 |
-
|
206 |
-
|
207 |
-
if __name__ == "__main__":
|
208 |
-
sys.path.append(os.getcwd())
|
209 |
-
parser = get_parser()
|
210 |
-
|
211 |
-
opt, unknown = parser.parse_known_args()
|
212 |
-
assert opt.resume
|
213 |
-
|
214 |
-
ckpt = None
|
215 |
-
|
216 |
-
if not os.path.exists(opt.resume):
|
217 |
-
raise ValueError("Cannot find {}".format(opt.resume))
|
218 |
-
if os.path.isfile(opt.resume):
|
219 |
-
paths = opt.resume.split("/")
|
220 |
-
try:
|
221 |
-
idx = len(paths)-paths[::-1].index("logs")+1
|
222 |
-
except ValueError:
|
223 |
-
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
|
224 |
-
logdir = "/".join(paths[:idx])
|
225 |
-
ckpt = opt.resume
|
226 |
-
else:
|
227 |
-
assert os.path.isdir(opt.resume), opt.resume
|
228 |
-
logdir = opt.resume.rstrip("/")
|
229 |
-
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
230 |
-
|
231 |
-
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
|
232 |
-
opt.base = base_configs+opt.base
|
233 |
-
|
234 |
-
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
235 |
-
cli = OmegaConf.from_dotlist(unknown)
|
236 |
-
config = OmegaConf.merge(*configs, cli)
|
237 |
-
|
238 |
-
model, global_step = load_model(config, ckpt, gpu=True, eval_mode=True)
|
239 |
-
|
240 |
-
if opt.outdir:
|
241 |
-
print(f"Switching logdir from '{logdir}' to '{opt.outdir}'")
|
242 |
-
logdir = opt.outdir
|
243 |
-
|
244 |
-
if opt.classes == "imagenet":
|
245 |
-
given_classes = [i for i in range(1000)]
|
246 |
-
else:
|
247 |
-
cls_str = opt.classes
|
248 |
-
assert not cls_str.endswith(","), 'class string should not end with a ","'
|
249 |
-
given_classes = [int(c) for c in cls_str.split(",")]
|
250 |
-
|
251 |
-
logdir = os.path.join(logdir, "samples", f"top_k_{opt.top_k}_temp_{opt.temperature:.2f}_top_p_{opt.top_p}",
|
252 |
-
f"{global_step}")
|
253 |
-
|
254 |
-
print(f"Logging to {logdir}")
|
255 |
-
os.makedirs(logdir, exist_ok=True)
|
256 |
-
|
257 |
-
run(logdir, model, opt.batch_size, opt.temperature, opt.top_k, unconditional=model.be_unconditional,
|
258 |
-
given_classes=given_classes, num_samples=opt.num_samples, top_p=opt.top_p)
|
259 |
-
|
260 |
-
print("done.")
|
taming-transformers/setup.py
DELETED
@@ -1,13 +0,0 @@
-from setuptools import setup, find_packages
-
-setup(
-    name='taming-transformers',
-    version='0.0.1',
-    description='Taming Transformers for High-Resolution Image Synthesis',
-    packages=find_packages(),
-    install_requires=[
-        'torch',
-        'numpy',
-        'tqdm',
-    ],
-)
taming-transformers/taming/lr_scheduler.py
DELETED
@@ -1,34 +0,0 @@
-import numpy as np
-
-
-class LambdaWarmUpCosineScheduler:
-    """
-    note: use with a base_lr of 1.0
-    """
-    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
-        self.lr_warm_up_steps = warm_up_steps
-        self.lr_start = lr_start
-        self.lr_min = lr_min
-        self.lr_max = lr_max
-        self.lr_max_decay_steps = max_decay_steps
-        self.last_lr = 0.
-        self.verbosity_interval = verbosity_interval
-
-    def schedule(self, n):
-        if self.verbosity_interval > 0:
-            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
-        if n < self.lr_warm_up_steps:
-            lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
-            self.last_lr = lr
-            return lr
-        else:
-            t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
-            t = min(t, 1.0)
-            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
-                    1 + np.cos(t * np.pi))
-            self.last_lr = lr
-            return lr
-
-    def __call__(self, n):
-        return self.schedule(n)
-
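
Note: the docstring of LambdaWarmUpCosineScheduler says to use it with a base_lr of 1.0, because schedule(n) returns the learning-rate value itself (linear warm-up, then cosine decay). A minimal usage sketch, not taken from the repository and with placeholder model, optimizer and step counts, would wrap it in torch.optim.lr_scheduler.LambdaLR so that the returned value becomes the effective learning rate:

import torch

model = torch.nn.Linear(8, 8)                                   # stand-in model for illustration
optimizer = torch.optim.Adam(model.parameters(), lr=1.0)        # base lr of 1.0, as the docstring suggests
schedule = LambdaWarmUpCosineScheduler(warm_up_steps=1000, lr_min=1e-6, lr_max=5e-5,
                                       lr_start=1e-6, max_decay_steps=100000)  # illustrative values
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

for step in range(2000):
    optimizer.step()                                            # loss.backward() omitted in this sketch
    scheduler.step()                                            # lr = 1.0 * schedule(step)
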
taming-transformers/taming/models/cond_transformer.py
DELETED
@@ -1,352 +0,0 @@
|
|
1 |
-
import os, math
|
2 |
-
import torch
|
3 |
-
import torch.nn.functional as F
|
4 |
-
import pytorch_lightning as pl
|
5 |
-
|
6 |
-
from main import instantiate_from_config
|
7 |
-
from taming.modules.util import SOSProvider
|
8 |
-
|
9 |
-
|
10 |
-
def disabled_train(self, mode=True):
|
11 |
-
"""Overwrite model.train with this function to make sure train/eval mode
|
12 |
-
does not change anymore."""
|
13 |
-
return self
|
14 |
-
|
15 |
-
|
16 |
-
class Net2NetTransformer(pl.LightningModule):
|
17 |
-
def __init__(self,
|
18 |
-
transformer_config,
|
19 |
-
first_stage_config,
|
20 |
-
cond_stage_config,
|
21 |
-
permuter_config=None,
|
22 |
-
ckpt_path=None,
|
23 |
-
ignore_keys=[],
|
24 |
-
first_stage_key="image",
|
25 |
-
cond_stage_key="depth",
|
26 |
-
downsample_cond_size=-1,
|
27 |
-
pkeep=1.0,
|
28 |
-
sos_token=0,
|
29 |
-
unconditional=False,
|
30 |
-
):
|
31 |
-
super().__init__()
|
32 |
-
self.be_unconditional = unconditional
|
33 |
-
self.sos_token = sos_token
|
34 |
-
self.first_stage_key = first_stage_key
|
35 |
-
self.cond_stage_key = cond_stage_key
|
36 |
-
self.init_first_stage_from_ckpt(first_stage_config)
|
37 |
-
self.init_cond_stage_from_ckpt(cond_stage_config)
|
38 |
-
if permuter_config is None:
|
39 |
-
permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
|
40 |
-
self.permuter = instantiate_from_config(config=permuter_config)
|
41 |
-
self.transformer = instantiate_from_config(config=transformer_config)
|
42 |
-
|
43 |
-
if ckpt_path is not None:
|
44 |
-
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
45 |
-
self.downsample_cond_size = downsample_cond_size
|
46 |
-
self.pkeep = pkeep
|
47 |
-
|
48 |
-
def init_from_ckpt(self, path, ignore_keys=list()):
|
49 |
-
sd = torch.load(path, map_location="cpu")["state_dict"]
|
50 |
-
for k in sd.keys():
|
51 |
-
for ik in ignore_keys:
|
52 |
-
if k.startswith(ik):
|
53 |
-
self.print("Deleting key {} from state_dict.".format(k))
|
54 |
-
del sd[k]
|
55 |
-
self.load_state_dict(sd, strict=False)
|
56 |
-
print(f"Restored from {path}")
|
57 |
-
|
58 |
-
def init_first_stage_from_ckpt(self, config):
|
59 |
-
model = instantiate_from_config(config)
|
60 |
-
model = model.eval()
|
61 |
-
model.train = disabled_train
|
62 |
-
self.first_stage_model = model
|
63 |
-
|
64 |
-
def init_cond_stage_from_ckpt(self, config):
|
65 |
-
if config == "__is_first_stage__":
|
66 |
-
print("Using first stage also as cond stage.")
|
67 |
-
self.cond_stage_model = self.first_stage_model
|
68 |
-
elif config == "__is_unconditional__" or self.be_unconditional:
|
69 |
-
print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
|
70 |
-
f"Prepending {self.sos_token} as a sos token.")
|
71 |
-
self.be_unconditional = True
|
72 |
-
self.cond_stage_key = self.first_stage_key
|
73 |
-
self.cond_stage_model = SOSProvider(self.sos_token)
|
74 |
-
else:
|
75 |
-
model = instantiate_from_config(config)
|
76 |
-
model = model.eval()
|
77 |
-
model.train = disabled_train
|
78 |
-
self.cond_stage_model = model
|
79 |
-
|
80 |
-
def forward(self, x, c):
|
81 |
-
# one step to produce the logits
|
82 |
-
_, z_indices = self.encode_to_z(x)
|
83 |
-
_, c_indices = self.encode_to_c(c)
|
84 |
-
|
85 |
-
if self.training and self.pkeep < 1.0:
|
86 |
-
mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
|
87 |
-
device=z_indices.device))
|
88 |
-
mask = mask.round().to(dtype=torch.int64)
|
89 |
-
r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
|
90 |
-
a_indices = mask*z_indices+(1-mask)*r_indices
|
91 |
-
else:
|
92 |
-
a_indices = z_indices
|
93 |
-
|
94 |
-
cz_indices = torch.cat((c_indices, a_indices), dim=1)
|
95 |
-
|
96 |
-
# target includes all sequence elements (no need to handle first one
|
97 |
-
# differently because we are conditioning)
|
98 |
-
target = z_indices
|
99 |
-
# make the prediction
|
100 |
-
logits, _ = self.transformer(cz_indices[:, :-1])
|
101 |
-
# cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
|
102 |
-
logits = logits[:, c_indices.shape[1]-1:]
|
103 |
-
|
104 |
-
return logits, target
|
105 |
-
|
106 |
-
def top_k_logits(self, logits, k):
|
107 |
-
v, ix = torch.topk(logits, k)
|
108 |
-
out = logits.clone()
|
109 |
-
out[out < v[..., [-1]]] = -float('Inf')
|
110 |
-
return out
|
111 |
-
|
112 |
-
@torch.no_grad()
|
113 |
-
def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
|
114 |
-
callback=lambda k: None):
|
115 |
-
x = torch.cat((c,x),dim=1)
|
116 |
-
block_size = self.transformer.get_block_size()
|
117 |
-
assert not self.transformer.training
|
118 |
-
if self.pkeep <= 0.0:
|
119 |
-
# one pass suffices since input is pure noise anyway
|
120 |
-
assert len(x.shape)==2
|
121 |
-
noise_shape = (x.shape[0], steps-1)
|
122 |
-
#noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
|
123 |
-
noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
|
124 |
-
x = torch.cat((x,noise),dim=1)
|
125 |
-
logits, _ = self.transformer(x)
|
126 |
-
# take all logits for now and scale by temp
|
127 |
-
logits = logits / temperature
|
128 |
-
# optionally crop probabilities to only the top k options
|
129 |
-
if top_k is not None:
|
130 |
-
logits = self.top_k_logits(logits, top_k)
|
131 |
-
# apply softmax to convert to probabilities
|
132 |
-
probs = F.softmax(logits, dim=-1)
|
133 |
-
# sample from the distribution or take the most likely
|
134 |
-
if sample:
|
135 |
-
shape = probs.shape
|
136 |
-
probs = probs.reshape(shape[0]*shape[1],shape[2])
|
137 |
-
ix = torch.multinomial(probs, num_samples=1)
|
138 |
-
probs = probs.reshape(shape[0],shape[1],shape[2])
|
139 |
-
ix = ix.reshape(shape[0],shape[1])
|
140 |
-
else:
|
141 |
-
_, ix = torch.topk(probs, k=1, dim=-1)
|
142 |
-
# cut off conditioning
|
143 |
-
x = ix[:, c.shape[1]-1:]
|
144 |
-
else:
|
145 |
-
for k in range(steps):
|
146 |
-
callback(k)
|
147 |
-
assert x.size(1) <= block_size # make sure model can see conditioning
|
148 |
-
x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
|
149 |
-
logits, _ = self.transformer(x_cond)
|
150 |
-
# pluck the logits at the final step and scale by temperature
|
151 |
-
logits = logits[:, -1, :] / temperature
|
152 |
-
# optionally crop probabilities to only the top k options
|
153 |
-
if top_k is not None:
|
154 |
-
logits = self.top_k_logits(logits, top_k)
|
155 |
-
# apply softmax to convert to probabilities
|
156 |
-
probs = F.softmax(logits, dim=-1)
|
157 |
-
# sample from the distribution or take the most likely
|
158 |
-
if sample:
|
159 |
-
ix = torch.multinomial(probs, num_samples=1)
|
160 |
-
else:
|
161 |
-
_, ix = torch.topk(probs, k=1, dim=-1)
|
162 |
-
# append to the sequence and continue
|
163 |
-
x = torch.cat((x, ix), dim=1)
|
164 |
-
# cut off conditioning
|
165 |
-
x = x[:, c.shape[1]:]
|
166 |
-
return x
|
167 |
-
|
168 |
-
@torch.no_grad()
|
169 |
-
def encode_to_z(self, x):
|
170 |
-
quant_z, _, info = self.first_stage_model.encode(x)
|
171 |
-
indices = info[2].view(quant_z.shape[0], -1)
|
172 |
-
indices = self.permuter(indices)
|
173 |
-
return quant_z, indices
|
174 |
-
|
175 |
-
@torch.no_grad()
|
176 |
-
def encode_to_c(self, c):
|
177 |
-
if self.downsample_cond_size > -1:
|
178 |
-
c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
|
179 |
-
quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
|
180 |
-
if len(indices.shape) > 2:
|
181 |
-
indices = indices.view(c.shape[0], -1)
|
182 |
-
return quant_c, indices
|
183 |
-
|
184 |
-
@torch.no_grad()
|
185 |
-
def decode_to_img(self, index, zshape):
|
186 |
-
index = self.permuter(index, reverse=True)
|
187 |
-
bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
|
188 |
-
quant_z = self.first_stage_model.quantize.get_codebook_entry(
|
189 |
-
index.reshape(-1), shape=bhwc)
|
190 |
-
x = self.first_stage_model.decode(quant_z)
|
191 |
-
return x
|
192 |
-
|
193 |
-
@torch.no_grad()
|
194 |
-
def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
|
195 |
-
log = dict()
|
196 |
-
|
197 |
-
N = 4
|
198 |
-
if lr_interface:
|
199 |
-
x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
|
200 |
-
else:
|
201 |
-
x, c = self.get_xc(batch, N)
|
202 |
-
x = x.to(device=self.device)
|
203 |
-
c = c.to(device=self.device)
|
204 |
-
|
205 |
-
quant_z, z_indices = self.encode_to_z(x)
|
206 |
-
quant_c, c_indices = self.encode_to_c(c)
|
207 |
-
|
208 |
-
# create a "half"" sample
|
209 |
-
z_start_indices = z_indices[:,:z_indices.shape[1]//2]
|
210 |
-
index_sample = self.sample(z_start_indices, c_indices,
|
211 |
-
steps=z_indices.shape[1]-z_start_indices.shape[1],
|
212 |
-
temperature=temperature if temperature is not None else 1.0,
|
213 |
-
sample=True,
|
214 |
-
top_k=top_k if top_k is not None else 100,
|
215 |
-
callback=callback if callback is not None else lambda k: None)
|
216 |
-
x_sample = self.decode_to_img(index_sample, quant_z.shape)
|
217 |
-
|
218 |
-
# sample
|
219 |
-
z_start_indices = z_indices[:, :0]
|
220 |
-
index_sample = self.sample(z_start_indices, c_indices,
|
221 |
-
steps=z_indices.shape[1],
|
222 |
-
temperature=temperature if temperature is not None else 1.0,
|
223 |
-
sample=True,
|
224 |
-
top_k=top_k if top_k is not None else 100,
|
225 |
-
callback=callback if callback is not None else lambda k: None)
|
226 |
-
x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
|
227 |
-
|
228 |
-
# det sample
|
229 |
-
z_start_indices = z_indices[:, :0]
|
230 |
-
index_sample = self.sample(z_start_indices, c_indices,
|
231 |
-
steps=z_indices.shape[1],
|
232 |
-
sample=False,
|
233 |
-
callback=callback if callback is not None else lambda k: None)
|
234 |
-
x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
|
235 |
-
|
236 |
-
# reconstruction
|
237 |
-
x_rec = self.decode_to_img(z_indices, quant_z.shape)
|
238 |
-
|
239 |
-
log["inputs"] = x
|
240 |
-
log["reconstructions"] = x_rec
|
241 |
-
|
242 |
-
if self.cond_stage_key in ["objects_bbox", "objects_center_points"]:
|
243 |
-
figure_size = (x_rec.shape[2], x_rec.shape[3])
|
244 |
-
dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"]
|
245 |
-
label_for_category_no = dataset.get_textual_label_for_category_no
|
246 |
-
plotter = dataset.conditional_builders[self.cond_stage_key].plot
|
247 |
-
log["conditioning"] = torch.zeros_like(log["reconstructions"])
|
248 |
-
for i in range(quant_c.shape[0]):
|
249 |
-
log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size)
|
250 |
-
log["conditioning_rec"] = log["conditioning"]
|
251 |
-
elif self.cond_stage_key != "image":
|
252 |
-
cond_rec = self.cond_stage_model.decode(quant_c)
|
253 |
-
if self.cond_stage_key == "segmentation":
|
254 |
-
# get image from segmentation mask
|
255 |
-
num_classes = cond_rec.shape[1]
|
256 |
-
|
257 |
-
c = torch.argmax(c, dim=1, keepdim=True)
|
258 |
-
c = F.one_hot(c, num_classes=num_classes)
|
259 |
-
c = c.squeeze(1).permute(0, 3, 1, 2).float()
|
260 |
-
c = self.cond_stage_model.to_rgb(c)
|
261 |
-
|
262 |
-
cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
|
263 |
-
cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
|
264 |
-
cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
|
265 |
-
cond_rec = self.cond_stage_model.to_rgb(cond_rec)
|
266 |
-
log["conditioning_rec"] = cond_rec
|
267 |
-
log["conditioning"] = c
|
268 |
-
|
269 |
-
log["samples_half"] = x_sample
|
270 |
-
log["samples_nopix"] = x_sample_nopix
|
271 |
-
log["samples_det"] = x_sample_det
|
272 |
-
return log
|
273 |
-
|
274 |
-
def get_input(self, key, batch):
|
275 |
-
x = batch[key]
|
276 |
-
if len(x.shape) == 3:
|
277 |
-
x = x[..., None]
|
278 |
-
if len(x.shape) == 4:
|
279 |
-
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
280 |
-
if x.dtype == torch.double:
|
281 |
-
x = x.float()
|
282 |
-
return x
|
283 |
-
|
284 |
-
def get_xc(self, batch, N=None):
|
285 |
-
x = self.get_input(self.first_stage_key, batch)
|
286 |
-
c = self.get_input(self.cond_stage_key, batch)
|
287 |
-
if N is not None:
|
288 |
-
x = x[:N]
|
289 |
-
c = c[:N]
|
290 |
-
return x, c
|
291 |
-
|
292 |
-
def shared_step(self, batch, batch_idx):
|
293 |
-
x, c = self.get_xc(batch)
|
294 |
-
logits, target = self(x, c)
|
295 |
-
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
|
296 |
-
return loss
|
297 |
-
|
298 |
-
def training_step(self, batch, batch_idx):
|
299 |
-
loss = self.shared_step(batch, batch_idx)
|
300 |
-
self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
301 |
-
return loss
|
302 |
-
|
303 |
-
def validation_step(self, batch, batch_idx):
|
304 |
-
loss = self.shared_step(batch, batch_idx)
|
305 |
-
self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
306 |
-
return loss
|
307 |
-
|
308 |
-
def configure_optimizers(self):
|
309 |
-
"""
|
310 |
-
Following minGPT:
|
311 |
-
This long function is unfortunately doing something very simple and is being very defensive:
|
312 |
-
We are separating out all parameters of the model into two buckets: those that will experience
|
313 |
-
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
|
314 |
-
We are then returning the PyTorch optimizer object.
|
315 |
-
"""
|
316 |
-
# separate out all parameters to those that will and won't experience regularizing weight decay
|
317 |
-
decay = set()
|
318 |
-
no_decay = set()
|
319 |
-
whitelist_weight_modules = (torch.nn.Linear, )
|
320 |
-
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
321 |
-
for mn, m in self.transformer.named_modules():
|
322 |
-
for pn, p in m.named_parameters():
|
323 |
-
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
|
324 |
-
|
325 |
-
if pn.endswith('bias'):
|
326 |
-
# all biases will not be decayed
|
327 |
-
no_decay.add(fpn)
|
328 |
-
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
|
329 |
-
# weights of whitelist modules will be weight decayed
|
330 |
-
decay.add(fpn)
|
331 |
-
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
|
332 |
-
# weights of blacklist modules will NOT be weight decayed
|
333 |
-
no_decay.add(fpn)
|
334 |
-
|
335 |
-
# special case the position embedding parameter in the root GPT module as not decayed
|
336 |
-
no_decay.add('pos_emb')
|
337 |
-
|
338 |
-
# validate that we considered every parameter
|
339 |
-
param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
|
340 |
-
inter_params = decay & no_decay
|
341 |
-
union_params = decay | no_decay
|
342 |
-
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
|
343 |
-
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
|
344 |
-
% (str(param_dict.keys() - union_params), )
|
345 |
-
|
346 |
-
# create the pytorch optimizer object
|
347 |
-
optim_groups = [
|
348 |
-
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
|
349 |
-
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
|
350 |
-
]
|
351 |
-
optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
|
352 |
-
return optimizer
|
taming-transformers/taming/models/dummy_cond_stage.py
DELETED
@@ -1,22 +0,0 @@
-from torch import Tensor
-
-
-class DummyCondStage:
-    def __init__(self, conditional_key):
-        self.conditional_key = conditional_key
-        self.train = None
-
-    def eval(self):
-        return self
-
-    @staticmethod
-    def encode(c: Tensor):
-        return c, None, (None, None, c)
-
-    @staticmethod
-    def decode(c: Tensor):
-        return c
-
-    @staticmethod
-    def to_rgb(c: Tensor):
-        return c
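
Note: DummyCondStage simply passes the conditioning tensor through while mimicking the (quant, loss, (..., ..., indices)) interface of a real VQ cond stage. A quick illustrative check, not part of the repository and using an example key and tensor shape, makes that pass-through behaviour explicit:

import torch

stage = DummyCondStage(conditional_key="coord")   # "coord" is just an example key
c = torch.zeros(1, 1, 16, 16)
quant_c, _, (_, _, c_indices) = stage.encode(c)
assert quant_c is c and c_indices is c            # encode is the identity on its input
assert stage.decode(c) is c and stage.to_rgb(c) is c
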
taming-transformers/taming/models/vqgan.py
DELETED
@@ -1,404 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn.functional as F
|
3 |
-
import pytorch_lightning as pl
|
4 |
-
|
5 |
-
from main import instantiate_from_config
|
6 |
-
|
7 |
-
from taming.modules.diffusionmodules.model import Encoder, Decoder
|
8 |
-
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
9 |
-
from taming.modules.vqvae.quantize import GumbelQuantize
|
10 |
-
from taming.modules.vqvae.quantize import EMAVectorQuantizer
|
11 |
-
|
12 |
-
class VQModel(pl.LightningModule):
|
13 |
-
def __init__(self,
|
14 |
-
ddconfig,
|
15 |
-
lossconfig,
|
16 |
-
n_embed,
|
17 |
-
embed_dim,
|
18 |
-
ckpt_path=None,
|
19 |
-
ignore_keys=[],
|
20 |
-
image_key="image",
|
21 |
-
colorize_nlabels=None,
|
22 |
-
monitor=None,
|
23 |
-
remap=None,
|
24 |
-
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
25 |
-
):
|
26 |
-
super().__init__()
|
27 |
-
self.image_key = image_key
|
28 |
-
self.encoder = Encoder(**ddconfig)
|
29 |
-
self.decoder = Decoder(**ddconfig)
|
30 |
-
self.loss = instantiate_from_config(lossconfig)
|
31 |
-
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
32 |
-
remap=remap, sane_index_shape=sane_index_shape)
|
33 |
-
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
34 |
-
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
35 |
-
if ckpt_path is not None:
|
36 |
-
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
37 |
-
self.image_key = image_key
|
38 |
-
if colorize_nlabels is not None:
|
39 |
-
assert type(colorize_nlabels)==int
|
40 |
-
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
41 |
-
if monitor is not None:
|
42 |
-
self.monitor = monitor
|
43 |
-
|
44 |
-
def init_from_ckpt(self, path, ignore_keys=list()):
|
45 |
-
sd = torch.load(path, map_location="cpu")["state_dict"]
|
46 |
-
keys = list(sd.keys())
|
47 |
-
for k in keys:
|
48 |
-
for ik in ignore_keys:
|
49 |
-
if k.startswith(ik):
|
50 |
-
print("Deleting key {} from state_dict.".format(k))
|
51 |
-
del sd[k]
|
52 |
-
self.load_state_dict(sd, strict=False)
|
53 |
-
print(f"Restored from {path}")
|
54 |
-
|
55 |
-
def encode(self, x):
|
56 |
-
h = self.encoder(x)
|
57 |
-
h = self.quant_conv(h)
|
58 |
-
quant, emb_loss, info = self.quantize(h)
|
59 |
-
return quant, emb_loss, info
|
60 |
-
|
61 |
-
def decode(self, quant):
|
62 |
-
quant = self.post_quant_conv(quant)
|
63 |
-
dec = self.decoder(quant)
|
64 |
-
return dec
|
65 |
-
|
66 |
-
def decode_code(self, code_b):
|
67 |
-
quant_b = self.quantize.embed_code(code_b)
|
68 |
-
dec = self.decode(quant_b)
|
69 |
-
return dec
|
70 |
-
|
71 |
-
def forward(self, input):
|
72 |
-
quant, diff, _ = self.encode(input)
|
73 |
-
dec = self.decode(quant)
|
74 |
-
return dec, diff
|
75 |
-
|
76 |
-
def get_input(self, batch, k):
|
77 |
-
x = batch[k]
|
78 |
-
if len(x.shape) == 3:
|
79 |
-
x = x[..., None]
|
80 |
-
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
81 |
-
return x.float()
|
82 |
-
|
83 |
-
def training_step(self, batch, batch_idx, optimizer_idx):
|
84 |
-
x = self.get_input(batch, self.image_key)
|
85 |
-
xrec, qloss = self(x)
|
86 |
-
|
87 |
-
if optimizer_idx == 0:
|
88 |
-
# autoencode
|
89 |
-
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
90 |
-
last_layer=self.get_last_layer(), split="train")
|
91 |
-
|
92 |
-
self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
93 |
-
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
94 |
-
return aeloss
|
95 |
-
|
96 |
-
if optimizer_idx == 1:
|
97 |
-
# discriminator
|
98 |
-
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
99 |
-
last_layer=self.get_last_layer(), split="train")
|
100 |
-
self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
101 |
-
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
102 |
-
return discloss
|
103 |
-
|
104 |
-
def validation_step(self, batch, batch_idx):
|
105 |
-
x = self.get_input(batch, self.image_key)
|
106 |
-
xrec, qloss = self(x)
|
107 |
-
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
|
108 |
-
last_layer=self.get_last_layer(), split="val")
|
109 |
-
|
110 |
-
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
|
111 |
-
last_layer=self.get_last_layer(), split="val")
|
112 |
-
rec_loss = log_dict_ae["val/rec_loss"]
|
113 |
-
self.log("val/rec_loss", rec_loss,
|
114 |
-
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
|
115 |
-
self.log("val/aeloss", aeloss,
|
116 |
-
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
|
117 |
-
self.log_dict(log_dict_ae)
|
118 |
-
self.log_dict(log_dict_disc)
|
119 |
-
return self.log_dict
|
120 |
-
|
121 |
-
def configure_optimizers(self):
|
122 |
-
lr = self.learning_rate
|
123 |
-
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
124 |
-
list(self.decoder.parameters())+
|
125 |
-
list(self.quantize.parameters())+
|
126 |
-
list(self.quant_conv.parameters())+
|
127 |
-
list(self.post_quant_conv.parameters()),
|
128 |
-
lr=lr, betas=(0.5, 0.9))
|
129 |
-
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
130 |
-
lr=lr, betas=(0.5, 0.9))
|
131 |
-
return [opt_ae, opt_disc], []
|
132 |
-
|
133 |
-
def get_last_layer(self):
|
134 |
-
return self.decoder.conv_out.weight
|
135 |
-
|
136 |
-
def log_images(self, batch, **kwargs):
|
137 |
-
log = dict()
|
138 |
-
x = self.get_input(batch, self.image_key)
|
139 |
-
x = x.to(self.device)
|
140 |
-
xrec, _ = self(x)
|
141 |
-
if x.shape[1] > 3:
|
142 |
-
# colorize with random projection
|
143 |
-
assert xrec.shape[1] > 3
|
144 |
-
x = self.to_rgb(x)
|
145 |
-
xrec = self.to_rgb(xrec)
|
146 |
-
log["inputs"] = x
|
147 |
-
log["reconstructions"] = xrec
|
148 |
-
return log
|
149 |
-
|
150 |
-
def to_rgb(self, x):
|
151 |
-
assert self.image_key == "segmentation"
|
152 |
-
if not hasattr(self, "colorize"):
|
153 |
-
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
154 |
-
x = F.conv2d(x, weight=self.colorize)
|
155 |
-
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
156 |
-
return x
|
157 |
-
|
158 |
-
|
159 |
-
class VQSegmentationModel(VQModel):
|
160 |
-
def __init__(self, n_labels, *args, **kwargs):
|
161 |
-
super().__init__(*args, **kwargs)
|
162 |
-
self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
|
163 |
-
|
164 |
-
def configure_optimizers(self):
|
165 |
-
lr = self.learning_rate
|
166 |
-
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
167 |
-
list(self.decoder.parameters())+
|
168 |
-
list(self.quantize.parameters())+
|
169 |
-
list(self.quant_conv.parameters())+
|
170 |
-
list(self.post_quant_conv.parameters()),
|
171 |
-
lr=lr, betas=(0.5, 0.9))
|
172 |
-
return opt_ae
|
173 |
-
|
174 |
-
def training_step(self, batch, batch_idx):
|
175 |
-
x = self.get_input(batch, self.image_key)
|
176 |
-
xrec, qloss = self(x)
|
177 |
-
aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
|
178 |
-
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
179 |
-
return aeloss
|
180 |
-
|
181 |
-
def validation_step(self, batch, batch_idx):
|
182 |
-
x = self.get_input(batch, self.image_key)
|
183 |
-
xrec, qloss = self(x)
|
184 |
-
aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
|
185 |
-
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
186 |
-
total_loss = log_dict_ae["val/total_loss"]
|
187 |
-
self.log("val/total_loss", total_loss,
|
188 |
-
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
|
189 |
-
return aeloss
|
190 |
-
|
191 |
-
@torch.no_grad()
|
192 |
-
def log_images(self, batch, **kwargs):
|
193 |
-
log = dict()
|
194 |
-
x = self.get_input(batch, self.image_key)
|
195 |
-
x = x.to(self.device)
|
196 |
-
xrec, _ = self(x)
|
197 |
-
if x.shape[1] > 3:
|
198 |
-
# colorize with random projection
|
199 |
-
assert xrec.shape[1] > 3
|
200 |
-
# convert logits to indices
|
201 |
-
xrec = torch.argmax(xrec, dim=1, keepdim=True)
|
202 |
-
xrec = F.one_hot(xrec, num_classes=x.shape[1])
|
203 |
-
xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
|
204 |
-
x = self.to_rgb(x)
|
205 |
-
xrec = self.to_rgb(xrec)
|
206 |
-
log["inputs"] = x
|
207 |
-
log["reconstructions"] = xrec
|
208 |
-
return log
|
209 |
-
|
210 |
-
|
211 |
-
class VQNoDiscModel(VQModel):
|
212 |
-
def __init__(self,
|
213 |
-
ddconfig,
|
214 |
-
lossconfig,
|
215 |
-
n_embed,
|
216 |
-
embed_dim,
|
217 |
-
ckpt_path=None,
|
218 |
-
ignore_keys=[],
|
219 |
-
image_key="image",
|
220 |
-
colorize_nlabels=None
|
221 |
-
):
|
222 |
-
super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
|
223 |
-
ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
|
224 |
-
colorize_nlabels=colorize_nlabels)
|
225 |
-
|
226 |
-
def training_step(self, batch, batch_idx):
|
227 |
-
x = self.get_input(batch, self.image_key)
|
228 |
-
xrec, qloss = self(x)
|
229 |
-
# autoencode
|
230 |
-
aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
|
231 |
-
output = pl.TrainResult(minimize=aeloss)
|
232 |
-
output.log("train/aeloss", aeloss,
|
233 |
-
prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
234 |
-
output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
235 |
-
return output
|
236 |
-
|
237 |
-
def validation_step(self, batch, batch_idx):
|
238 |
-
x = self.get_input(batch, self.image_key)
|
239 |
-
xrec, qloss = self(x)
|
240 |
-
aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
|
241 |
-
rec_loss = log_dict_ae["val/rec_loss"]
|
242 |
-
output = pl.EvalResult(checkpoint_on=rec_loss)
|
243 |
-
output.log("val/rec_loss", rec_loss,
|
244 |
-
prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
245 |
-
output.log("val/aeloss", aeloss,
|
246 |
-
prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
247 |
-
output.log_dict(log_dict_ae)
|
248 |
-
|
249 |
-
return output
|
250 |
-
|
251 |
-
def configure_optimizers(self):
|
252 |
-
optimizer = torch.optim.Adam(list(self.encoder.parameters())+
|
253 |
-
list(self.decoder.parameters())+
|
254 |
-
list(self.quantize.parameters())+
|
255 |
-
list(self.quant_conv.parameters())+
|
256 |
-
list(self.post_quant_conv.parameters()),
|
257 |
-
lr=self.learning_rate, betas=(0.5, 0.9))
|
258 |
-
return optimizer
|
259 |
-
|
260 |
-
|
261 |
-
class GumbelVQ(VQModel):
|
262 |
-
def __init__(self,
|
263 |
-
ddconfig,
|
264 |
-
lossconfig,
|
265 |
-
n_embed,
|
266 |
-
embed_dim,
|
267 |
-
temperature_scheduler_config,
|
268 |
-
ckpt_path=None,
|
269 |
-
ignore_keys=[],
|
270 |
-
image_key="image",
|
271 |
-
colorize_nlabels=None,
|
272 |
-
monitor=None,
|
273 |
-
kl_weight=1e-8,
|
274 |
-
remap=None,
|
275 |
-
):
|
276 |
-
|
277 |
-
z_channels = ddconfig["z_channels"]
|
278 |
-
super().__init__(ddconfig,
|
279 |
-
lossconfig,
|
280 |
-
n_embed,
|
281 |
-
embed_dim,
|
282 |
-
ckpt_path=None,
|
283 |
-
ignore_keys=ignore_keys,
|
284 |
-
image_key=image_key,
|
285 |
-
colorize_nlabels=colorize_nlabels,
|
286 |
-
monitor=monitor,
|
287 |
-
)
|
288 |
-
|
289 |
-
self.loss.n_classes = n_embed
|
290 |
-
self.vocab_size = n_embed
|
291 |
-
|
292 |
-
self.quantize = GumbelQuantize(z_channels, embed_dim,
|
293 |
-
n_embed=n_embed,
|
294 |
-
kl_weight=kl_weight, temp_init=1.0,
|
295 |
-
remap=remap)
|
296 |
-
|
297 |
-
self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp
|
298 |
-
|
299 |
-
if ckpt_path is not None:
|
300 |
-
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
301 |
-
|
302 |
-
def temperature_scheduling(self):
|
303 |
-
self.quantize.temperature = self.temperature_scheduler(self.global_step)
|
304 |
-
|
305 |
-
    def encode_to_prequant(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode_code(self, code_b):
        raise NotImplementedError

    def training_step(self, batch, batch_idx, optimizer_idx):
        self.temperature_scheduling()
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")
        rec_loss = log_dict_ae["val/rec_loss"]
        self.log("val/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def log_images(self, batch, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        # encode
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, _, _ = self.quantize(h)
        # decode
        x_rec = self.decode(quant)
        log["inputs"] = x
        log["reconstructions"] = x_rec
        return log


class EMAVQ(VQModel):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 ):
        super().__init__(ddconfig,
                         lossconfig,
                         n_embed,
                         embed_dim,
                         ckpt_path=None,
                         ignore_keys=ignore_keys,
                         image_key=image_key,
                         colorize_nlabels=colorize_nlabels,
                         monitor=monitor,
                         )
        self.quantize = EMAVectorQuantizer(n_embed=n_embed,
                                           embedding_dim=embed_dim,
                                           beta=0.25,
                                           remap=remap)

    def configure_optimizers(self):
        lr = self.learning_rate
        # Remove self.quantize from parameter list since it is updated via EMA
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []
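A minimal sketch of how the two optimizers returned above are consumed during training: Lightning calls training_step once per optimizer, passing optimizer_idx 0 for the autoencoder phase and 1 for the discriminator phase. The toy loop below uses plain PyTorch and made-up stand-in modules (not from this repository) purely to illustrate that alternation.

import torch

# Toy stand-ins for the autoencoder and discriminator parameter groups (illustrative only).
autoencoder = torch.nn.Linear(8, 8)
discriminator = torch.nn.Linear(8, 1)

opt_ae = torch.optim.Adam(autoencoder.parameters(), lr=4.5e-6, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(discriminator.parameters(), lr=4.5e-6, betas=(0.5, 0.9))

x = torch.randn(4, 8)
for optimizer_idx, opt in enumerate([opt_ae, opt_disc]):
    opt.zero_grad()
    if optimizer_idx == 0:
        # "autoencode" phase: reconstruction-style loss, updates generator parameters
        loss = (autoencoder(x) - x).abs().mean()
    else:
        # "discriminator" phase: loss computed on the critic parameters only
        loss = discriminator(x.detach()).mean()
    loss.backward()
    opt.step()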
taming-transformers/taming/modules/diffusionmodules/model.py
DELETED
@@ -1,776 +0,0 @@
|
|
1 |
-
# pytorch_diffusion + derived encoder decoder
|
2 |
-
import math
|
3 |
-
import torch
|
4 |
-
import torch.nn as nn
|
5 |
-
import numpy as np
|
6 |
-
|
7 |
-
|
8 |
-
def get_timestep_embedding(timesteps, embedding_dim):
|
9 |
-
"""
|
10 |
-
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
11 |
-
From Fairseq.
|
12 |
-
Build sinusoidal embeddings.
|
13 |
-
This matches the implementation in tensor2tensor, but differs slightly
|
14 |
-
from the description in Section 3.5 of "Attention Is All You Need".
|
15 |
-
"""
|
16 |
-
assert len(timesteps.shape) == 1
|
17 |
-
|
18 |
-
half_dim = embedding_dim // 2
|
19 |
-
emb = math.log(10000) / (half_dim - 1)
|
20 |
-
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
21 |
-
emb = emb.to(device=timesteps.device)
|
22 |
-
emb = timesteps.float()[:, None] * emb[None, :]
|
23 |
-
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
24 |
-
if embedding_dim % 2 == 1: # zero pad
|
25 |
-
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
26 |
-
return emb
|
27 |
-
|
28 |
-
|
29 |
-
def nonlinearity(x):
|
30 |
-
# swish
|
31 |
-
return x*torch.sigmoid(x)
|
32 |
-
|
33 |
-
|
34 |
-
def Normalize(in_channels):
|
35 |
-
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
36 |
-
|
37 |
-
|
38 |
-
class Upsample(nn.Module):
|
39 |
-
def __init__(self, in_channels, with_conv):
|
40 |
-
super().__init__()
|
41 |
-
self.with_conv = with_conv
|
42 |
-
if self.with_conv:
|
43 |
-
self.conv = torch.nn.Conv2d(in_channels,
|
44 |
-
in_channels,
|
45 |
-
kernel_size=3,
|
46 |
-
stride=1,
|
47 |
-
padding=1)
|
48 |
-
|
49 |
-
def forward(self, x):
|
50 |
-
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
51 |
-
if self.with_conv:
|
52 |
-
x = self.conv(x)
|
53 |
-
return x
|
54 |
-
|
55 |
-
|
56 |
-
class Downsample(nn.Module):
|
57 |
-
def __init__(self, in_channels, with_conv):
|
58 |
-
super().__init__()
|
59 |
-
self.with_conv = with_conv
|
60 |
-
if self.with_conv:
|
61 |
-
# no asymmetric padding in torch conv, must do it ourselves
|
62 |
-
self.conv = torch.nn.Conv2d(in_channels,
|
63 |
-
in_channels,
|
64 |
-
kernel_size=3,
|
65 |
-
stride=2,
|
66 |
-
padding=0)
|
67 |
-
|
68 |
-
def forward(self, x):
|
69 |
-
if self.with_conv:
|
70 |
-
pad = (0,1,0,1)
|
71 |
-
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
72 |
-
x = self.conv(x)
|
73 |
-
else:
|
74 |
-
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
75 |
-
return x
|
76 |
-
|
77 |
-
|
78 |
-
class ResnetBlock(nn.Module):
|
79 |
-
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
80 |
-
dropout, temb_channels=512):
|
81 |
-
super().__init__()
|
82 |
-
self.in_channels = in_channels
|
83 |
-
out_channels = in_channels if out_channels is None else out_channels
|
84 |
-
self.out_channels = out_channels
|
85 |
-
self.use_conv_shortcut = conv_shortcut
|
86 |
-
|
87 |
-
self.norm1 = Normalize(in_channels)
|
88 |
-
self.conv1 = torch.nn.Conv2d(in_channels,
|
89 |
-
out_channels,
|
90 |
-
kernel_size=3,
|
91 |
-
stride=1,
|
92 |
-
padding=1)
|
93 |
-
if temb_channels > 0:
|
94 |
-
self.temb_proj = torch.nn.Linear(temb_channels,
|
95 |
-
out_channels)
|
96 |
-
self.norm2 = Normalize(out_channels)
|
97 |
-
self.dropout = torch.nn.Dropout(dropout)
|
98 |
-
self.conv2 = torch.nn.Conv2d(out_channels,
|
99 |
-
out_channels,
|
100 |
-
kernel_size=3,
|
101 |
-
stride=1,
|
102 |
-
padding=1)
|
103 |
-
if self.in_channels != self.out_channels:
|
104 |
-
if self.use_conv_shortcut:
|
105 |
-
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
106 |
-
out_channels,
|
107 |
-
kernel_size=3,
|
108 |
-
stride=1,
|
109 |
-
padding=1)
|
110 |
-
else:
|
111 |
-
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
112 |
-
out_channels,
|
113 |
-
kernel_size=1,
|
114 |
-
stride=1,
|
115 |
-
padding=0)
|
116 |
-
|
117 |
-
def forward(self, x, temb):
|
118 |
-
h = x
|
119 |
-
h = self.norm1(h)
|
120 |
-
h = nonlinearity(h)
|
121 |
-
h = self.conv1(h)
|
122 |
-
|
123 |
-
if temb is not None:
|
124 |
-
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
125 |
-
|
126 |
-
h = self.norm2(h)
|
127 |
-
h = nonlinearity(h)
|
128 |
-
h = self.dropout(h)
|
129 |
-
h = self.conv2(h)
|
130 |
-
|
131 |
-
if self.in_channels != self.out_channels:
|
132 |
-
if self.use_conv_shortcut:
|
133 |
-
x = self.conv_shortcut(x)
|
134 |
-
else:
|
135 |
-
x = self.nin_shortcut(x)
|
136 |
-
|
137 |
-
return x+h
|
138 |
-
|
139 |
-
|
140 |
-
class AttnBlock(nn.Module):
|
141 |
-
def __init__(self, in_channels):
|
142 |
-
super().__init__()
|
143 |
-
self.in_channels = in_channels
|
144 |
-
|
145 |
-
self.norm = Normalize(in_channels)
|
146 |
-
self.q = torch.nn.Conv2d(in_channels,
|
147 |
-
in_channels,
|
148 |
-
kernel_size=1,
|
149 |
-
stride=1,
|
150 |
-
padding=0)
|
151 |
-
self.k = torch.nn.Conv2d(in_channels,
|
152 |
-
in_channels,
|
153 |
-
kernel_size=1,
|
154 |
-
stride=1,
|
155 |
-
padding=0)
|
156 |
-
self.v = torch.nn.Conv2d(in_channels,
|
157 |
-
in_channels,
|
158 |
-
kernel_size=1,
|
159 |
-
stride=1,
|
160 |
-
padding=0)
|
161 |
-
self.proj_out = torch.nn.Conv2d(in_channels,
|
162 |
-
in_channels,
|
163 |
-
kernel_size=1,
|
164 |
-
stride=1,
|
165 |
-
padding=0)
|
166 |
-
|
167 |
-
|
168 |
-
def forward(self, x):
|
169 |
-
h_ = x
|
170 |
-
h_ = self.norm(h_)
|
171 |
-
q = self.q(h_)
|
172 |
-
k = self.k(h_)
|
173 |
-
v = self.v(h_)
|
174 |
-
|
175 |
-
# compute attention
|
176 |
-
b,c,h,w = q.shape
|
177 |
-
q = q.reshape(b,c,h*w)
|
178 |
-
q = q.permute(0,2,1) # b,hw,c
|
179 |
-
k = k.reshape(b,c,h*w) # b,c,hw
|
180 |
-
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
181 |
-
w_ = w_ * (int(c)**(-0.5))
|
182 |
-
w_ = torch.nn.functional.softmax(w_, dim=2)
|
183 |
-
|
184 |
-
# attend to values
|
185 |
-
v = v.reshape(b,c,h*w)
|
186 |
-
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
187 |
-
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
188 |
-
h_ = h_.reshape(b,c,h,w)
|
189 |
-
|
190 |
-
h_ = self.proj_out(h_)
|
191 |
-
|
192 |
-
return x+h_
|
193 |
-
|
194 |
-
|
195 |
-
class Model(nn.Module):
|
196 |
-
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
197 |
-
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
198 |
-
resolution, use_timestep=True):
|
199 |
-
super().__init__()
|
200 |
-
self.ch = ch
|
201 |
-
self.temb_ch = self.ch*4
|
202 |
-
self.num_resolutions = len(ch_mult)
|
203 |
-
self.num_res_blocks = num_res_blocks
|
204 |
-
self.resolution = resolution
|
205 |
-
self.in_channels = in_channels
|
206 |
-
|
207 |
-
self.use_timestep = use_timestep
|
208 |
-
if self.use_timestep:
|
209 |
-
# timestep embedding
|
210 |
-
self.temb = nn.Module()
|
211 |
-
self.temb.dense = nn.ModuleList([
|
212 |
-
torch.nn.Linear(self.ch,
|
213 |
-
self.temb_ch),
|
214 |
-
torch.nn.Linear(self.temb_ch,
|
215 |
-
self.temb_ch),
|
216 |
-
])
|
217 |
-
|
218 |
-
# downsampling
|
219 |
-
self.conv_in = torch.nn.Conv2d(in_channels,
|
220 |
-
self.ch,
|
221 |
-
kernel_size=3,
|
222 |
-
stride=1,
|
223 |
-
padding=1)
|
224 |
-
|
225 |
-
curr_res = resolution
|
226 |
-
in_ch_mult = (1,)+tuple(ch_mult)
|
227 |
-
self.down = nn.ModuleList()
|
228 |
-
for i_level in range(self.num_resolutions):
|
229 |
-
block = nn.ModuleList()
|
230 |
-
attn = nn.ModuleList()
|
231 |
-
block_in = ch*in_ch_mult[i_level]
|
232 |
-
block_out = ch*ch_mult[i_level]
|
233 |
-
for i_block in range(self.num_res_blocks):
|
234 |
-
block.append(ResnetBlock(in_channels=block_in,
|
235 |
-
out_channels=block_out,
|
236 |
-
temb_channels=self.temb_ch,
|
237 |
-
dropout=dropout))
|
238 |
-
block_in = block_out
|
239 |
-
if curr_res in attn_resolutions:
|
240 |
-
attn.append(AttnBlock(block_in))
|
241 |
-
down = nn.Module()
|
242 |
-
down.block = block
|
243 |
-
down.attn = attn
|
244 |
-
if i_level != self.num_resolutions-1:
|
245 |
-
down.downsample = Downsample(block_in, resamp_with_conv)
|
246 |
-
curr_res = curr_res // 2
|
247 |
-
self.down.append(down)
|
248 |
-
|
249 |
-
# middle
|
250 |
-
self.mid = nn.Module()
|
251 |
-
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
252 |
-
out_channels=block_in,
|
253 |
-
temb_channels=self.temb_ch,
|
254 |
-
dropout=dropout)
|
255 |
-
self.mid.attn_1 = AttnBlock(block_in)
|
256 |
-
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
257 |
-
out_channels=block_in,
|
258 |
-
temb_channels=self.temb_ch,
|
259 |
-
dropout=dropout)
|
260 |
-
|
261 |
-
# upsampling
|
262 |
-
self.up = nn.ModuleList()
|
263 |
-
for i_level in reversed(range(self.num_resolutions)):
|
264 |
-
block = nn.ModuleList()
|
265 |
-
attn = nn.ModuleList()
|
266 |
-
block_out = ch*ch_mult[i_level]
|
267 |
-
skip_in = ch*ch_mult[i_level]
|
268 |
-
for i_block in range(self.num_res_blocks+1):
|
269 |
-
if i_block == self.num_res_blocks:
|
270 |
-
skip_in = ch*in_ch_mult[i_level]
|
271 |
-
block.append(ResnetBlock(in_channels=block_in+skip_in,
|
272 |
-
out_channels=block_out,
|
273 |
-
temb_channels=self.temb_ch,
|
274 |
-
dropout=dropout))
|
275 |
-
block_in = block_out
|
276 |
-
if curr_res in attn_resolutions:
|
277 |
-
attn.append(AttnBlock(block_in))
|
278 |
-
up = nn.Module()
|
279 |
-
up.block = block
|
280 |
-
up.attn = attn
|
281 |
-
if i_level != 0:
|
282 |
-
up.upsample = Upsample(block_in, resamp_with_conv)
|
283 |
-
curr_res = curr_res * 2
|
284 |
-
self.up.insert(0, up) # prepend to get consistent order
|
285 |
-
|
286 |
-
# end
|
287 |
-
self.norm_out = Normalize(block_in)
|
288 |
-
self.conv_out = torch.nn.Conv2d(block_in,
|
289 |
-
out_ch,
|
290 |
-
kernel_size=3,
|
291 |
-
stride=1,
|
292 |
-
padding=1)
|
293 |
-
|
294 |
-
|
295 |
-
def forward(self, x, t=None):
|
296 |
-
#assert x.shape[2] == x.shape[3] == self.resolution
|
297 |
-
|
298 |
-
if self.use_timestep:
|
299 |
-
# timestep embedding
|
300 |
-
assert t is not None
|
301 |
-
temb = get_timestep_embedding(t, self.ch)
|
302 |
-
temb = self.temb.dense[0](temb)
|
303 |
-
temb = nonlinearity(temb)
|
304 |
-
temb = self.temb.dense[1](temb)
|
305 |
-
else:
|
306 |
-
temb = None
|
307 |
-
|
308 |
-
# downsampling
|
309 |
-
hs = [self.conv_in(x)]
|
310 |
-
for i_level in range(self.num_resolutions):
|
311 |
-
for i_block in range(self.num_res_blocks):
|
312 |
-
h = self.down[i_level].block[i_block](hs[-1], temb)
|
313 |
-
if len(self.down[i_level].attn) > 0:
|
314 |
-
h = self.down[i_level].attn[i_block](h)
|
315 |
-
hs.append(h)
|
316 |
-
if i_level != self.num_resolutions-1:
|
317 |
-
hs.append(self.down[i_level].downsample(hs[-1]))
|
318 |
-
|
319 |
-
# middle
|
320 |
-
h = hs[-1]
|
321 |
-
h = self.mid.block_1(h, temb)
|
322 |
-
h = self.mid.attn_1(h)
|
323 |
-
h = self.mid.block_2(h, temb)
|
324 |
-
|
325 |
-
# upsampling
|
326 |
-
for i_level in reversed(range(self.num_resolutions)):
|
327 |
-
for i_block in range(self.num_res_blocks+1):
|
328 |
-
h = self.up[i_level].block[i_block](
|
329 |
-
torch.cat([h, hs.pop()], dim=1), temb)
|
330 |
-
if len(self.up[i_level].attn) > 0:
|
331 |
-
h = self.up[i_level].attn[i_block](h)
|
332 |
-
if i_level != 0:
|
333 |
-
h = self.up[i_level].upsample(h)
|
334 |
-
|
335 |
-
# end
|
336 |
-
h = self.norm_out(h)
|
337 |
-
h = nonlinearity(h)
|
338 |
-
h = self.conv_out(h)
|
339 |
-
return h
|
340 |
-
|
341 |
-
|
342 |
-
class Encoder(nn.Module):
|
343 |
-
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
344 |
-
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
345 |
-
resolution, z_channels, double_z=True, **ignore_kwargs):
|
346 |
-
super().__init__()
|
347 |
-
self.ch = ch
|
348 |
-
self.temb_ch = 0
|
349 |
-
self.num_resolutions = len(ch_mult)
|
350 |
-
self.num_res_blocks = num_res_blocks
|
351 |
-
self.resolution = resolution
|
352 |
-
self.in_channels = in_channels
|
353 |
-
|
354 |
-
# downsampling
|
355 |
-
self.conv_in = torch.nn.Conv2d(in_channels,
|
356 |
-
self.ch,
|
357 |
-
kernel_size=3,
|
358 |
-
stride=1,
|
359 |
-
padding=1)
|
360 |
-
|
361 |
-
curr_res = resolution
|
362 |
-
in_ch_mult = (1,)+tuple(ch_mult)
|
363 |
-
self.down = nn.ModuleList()
|
364 |
-
for i_level in range(self.num_resolutions):
|
365 |
-
block = nn.ModuleList()
|
366 |
-
attn = nn.ModuleList()
|
367 |
-
block_in = ch*in_ch_mult[i_level]
|
368 |
-
block_out = ch*ch_mult[i_level]
|
369 |
-
for i_block in range(self.num_res_blocks):
|
370 |
-
block.append(ResnetBlock(in_channels=block_in,
|
371 |
-
out_channels=block_out,
|
372 |
-
temb_channels=self.temb_ch,
|
373 |
-
dropout=dropout))
|
374 |
-
block_in = block_out
|
375 |
-
if curr_res in attn_resolutions:
|
376 |
-
attn.append(AttnBlock(block_in))
|
377 |
-
down = nn.Module()
|
378 |
-
down.block = block
|
379 |
-
down.attn = attn
|
380 |
-
if i_level != self.num_resolutions-1:
|
381 |
-
down.downsample = Downsample(block_in, resamp_with_conv)
|
382 |
-
curr_res = curr_res // 2
|
383 |
-
self.down.append(down)
|
384 |
-
|
385 |
-
# middle
|
386 |
-
self.mid = nn.Module()
|
387 |
-
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
388 |
-
out_channels=block_in,
|
389 |
-
temb_channels=self.temb_ch,
|
390 |
-
dropout=dropout)
|
391 |
-
self.mid.attn_1 = AttnBlock(block_in)
|
392 |
-
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
393 |
-
out_channels=block_in,
|
394 |
-
temb_channels=self.temb_ch,
|
395 |
-
dropout=dropout)
|
396 |
-
|
397 |
-
# end
|
398 |
-
self.norm_out = Normalize(block_in)
|
399 |
-
self.conv_out = torch.nn.Conv2d(block_in,
|
400 |
-
2*z_channels if double_z else z_channels,
|
401 |
-
kernel_size=3,
|
402 |
-
stride=1,
|
403 |
-
padding=1)
|
404 |
-
|
405 |
-
|
406 |
-
def forward(self, x):
|
407 |
-
#assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
|
408 |
-
|
409 |
-
# timestep embedding
|
410 |
-
temb = None
|
411 |
-
|
412 |
-
# downsampling
|
413 |
-
hs = [self.conv_in(x)]
|
414 |
-
for i_level in range(self.num_resolutions):
|
415 |
-
for i_block in range(self.num_res_blocks):
|
416 |
-
h = self.down[i_level].block[i_block](hs[-1], temb)
|
417 |
-
if len(self.down[i_level].attn) > 0:
|
418 |
-
h = self.down[i_level].attn[i_block](h)
|
419 |
-
hs.append(h)
|
420 |
-
if i_level != self.num_resolutions-1:
|
421 |
-
hs.append(self.down[i_level].downsample(hs[-1]))
|
422 |
-
|
423 |
-
# middle
|
424 |
-
h = hs[-1]
|
425 |
-
h = self.mid.block_1(h, temb)
|
426 |
-
h = self.mid.attn_1(h)
|
427 |
-
h = self.mid.block_2(h, temb)
|
428 |
-
|
429 |
-
# end
|
430 |
-
h = self.norm_out(h)
|
431 |
-
h = nonlinearity(h)
|
432 |
-
h = self.conv_out(h)
|
433 |
-
return h
|
434 |
-
|
435 |
-
|
436 |
-
class Decoder(nn.Module):
|
437 |
-
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
438 |
-
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
439 |
-
resolution, z_channels, give_pre_end=False, **ignorekwargs):
|
440 |
-
super().__init__()
|
441 |
-
self.ch = ch
|
442 |
-
self.temb_ch = 0
|
443 |
-
self.num_resolutions = len(ch_mult)
|
444 |
-
self.num_res_blocks = num_res_blocks
|
445 |
-
self.resolution = resolution
|
446 |
-
self.in_channels = in_channels
|
447 |
-
self.give_pre_end = give_pre_end
|
448 |
-
|
449 |
-
# compute in_ch_mult, block_in and curr_res at lowest res
|
450 |
-
in_ch_mult = (1,)+tuple(ch_mult)
|
451 |
-
block_in = ch*ch_mult[self.num_resolutions-1]
|
452 |
-
curr_res = resolution // 2**(self.num_resolutions-1)
|
453 |
-
self.z_shape = (1,z_channels,curr_res,curr_res)
|
454 |
-
print("Working with z of shape {} = {} dimensions.".format(
|
455 |
-
self.z_shape, np.prod(self.z_shape)))
|
456 |
-
|
457 |
-
# z to block_in
|
458 |
-
self.conv_in = torch.nn.Conv2d(z_channels,
|
459 |
-
block_in,
|
460 |
-
kernel_size=3,
|
461 |
-
stride=1,
|
462 |
-
padding=1)
|
463 |
-
|
464 |
-
# middle
|
465 |
-
self.mid = nn.Module()
|
466 |
-
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
467 |
-
out_channels=block_in,
|
468 |
-
temb_channels=self.temb_ch,
|
469 |
-
dropout=dropout)
|
470 |
-
self.mid.attn_1 = AttnBlock(block_in)
|
471 |
-
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
472 |
-
out_channels=block_in,
|
473 |
-
temb_channels=self.temb_ch,
|
474 |
-
dropout=dropout)
|
475 |
-
|
476 |
-
# upsampling
|
477 |
-
self.up = nn.ModuleList()
|
478 |
-
for i_level in reversed(range(self.num_resolutions)):
|
479 |
-
block = nn.ModuleList()
|
480 |
-
attn = nn.ModuleList()
|
481 |
-
block_out = ch*ch_mult[i_level]
|
482 |
-
for i_block in range(self.num_res_blocks+1):
|
483 |
-
block.append(ResnetBlock(in_channels=block_in,
|
484 |
-
out_channels=block_out,
|
485 |
-
temb_channels=self.temb_ch,
|
486 |
-
dropout=dropout))
|
487 |
-
block_in = block_out
|
488 |
-
if curr_res in attn_resolutions:
|
489 |
-
attn.append(AttnBlock(block_in))
|
490 |
-
up = nn.Module()
|
491 |
-
up.block = block
|
492 |
-
up.attn = attn
|
493 |
-
if i_level != 0:
|
494 |
-
up.upsample = Upsample(block_in, resamp_with_conv)
|
495 |
-
curr_res = curr_res * 2
|
496 |
-
self.up.insert(0, up) # prepend to get consistent order
|
497 |
-
|
498 |
-
# end
|
499 |
-
self.norm_out = Normalize(block_in)
|
500 |
-
self.conv_out = torch.nn.Conv2d(block_in,
|
501 |
-
out_ch,
|
502 |
-
kernel_size=3,
|
503 |
-
stride=1,
|
504 |
-
padding=1)
|
505 |
-
|
506 |
-
def forward(self, z):
|
507 |
-
#assert z.shape[1:] == self.z_shape[1:]
|
508 |
-
self.last_z_shape = z.shape
|
509 |
-
|
510 |
-
# timestep embedding
|
511 |
-
temb = None
|
512 |
-
|
513 |
-
# z to block_in
|
514 |
-
h = self.conv_in(z)
|
515 |
-
|
516 |
-
# middle
|
517 |
-
h = self.mid.block_1(h, temb)
|
518 |
-
h = self.mid.attn_1(h)
|
519 |
-
h = self.mid.block_2(h, temb)
|
520 |
-
|
521 |
-
# upsampling
|
522 |
-
for i_level in reversed(range(self.num_resolutions)):
|
523 |
-
for i_block in range(self.num_res_blocks+1):
|
524 |
-
h = self.up[i_level].block[i_block](h, temb)
|
525 |
-
if len(self.up[i_level].attn) > 0:
|
526 |
-
h = self.up[i_level].attn[i_block](h)
|
527 |
-
if i_level != 0:
|
528 |
-
h = self.up[i_level].upsample(h)
|
529 |
-
|
530 |
-
# end
|
531 |
-
if self.give_pre_end:
|
532 |
-
return h
|
533 |
-
|
534 |
-
h = self.norm_out(h)
|
535 |
-
h = nonlinearity(h)
|
536 |
-
h = self.conv_out(h)
|
537 |
-
return h
|
538 |
-
|
539 |
-
|
540 |
-
class VUNet(nn.Module):
|
541 |
-
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
542 |
-
attn_resolutions, dropout=0.0, resamp_with_conv=True,
|
543 |
-
in_channels, c_channels,
|
544 |
-
resolution, z_channels, use_timestep=False, **ignore_kwargs):
|
545 |
-
super().__init__()
|
546 |
-
self.ch = ch
|
547 |
-
self.temb_ch = self.ch*4
|
548 |
-
self.num_resolutions = len(ch_mult)
|
549 |
-
self.num_res_blocks = num_res_blocks
|
550 |
-
self.resolution = resolution
|
551 |
-
|
552 |
-
self.use_timestep = use_timestep
|
553 |
-
if self.use_timestep:
|
554 |
-
# timestep embedding
|
555 |
-
self.temb = nn.Module()
|
556 |
-
self.temb.dense = nn.ModuleList([
|
557 |
-
torch.nn.Linear(self.ch,
|
558 |
-
self.temb_ch),
|
559 |
-
torch.nn.Linear(self.temb_ch,
|
560 |
-
self.temb_ch),
|
561 |
-
])
|
562 |
-
|
563 |
-
# downsampling
|
564 |
-
self.conv_in = torch.nn.Conv2d(c_channels,
|
565 |
-
self.ch,
|
566 |
-
kernel_size=3,
|
567 |
-
stride=1,
|
568 |
-
padding=1)
|
569 |
-
|
570 |
-
curr_res = resolution
|
571 |
-
in_ch_mult = (1,)+tuple(ch_mult)
|
572 |
-
self.down = nn.ModuleList()
|
573 |
-
for i_level in range(self.num_resolutions):
|
574 |
-
block = nn.ModuleList()
|
575 |
-
attn = nn.ModuleList()
|
576 |
-
block_in = ch*in_ch_mult[i_level]
|
577 |
-
block_out = ch*ch_mult[i_level]
|
578 |
-
for i_block in range(self.num_res_blocks):
|
579 |
-
block.append(ResnetBlock(in_channels=block_in,
|
580 |
-
out_channels=block_out,
|
581 |
-
temb_channels=self.temb_ch,
|
582 |
-
dropout=dropout))
|
583 |
-
block_in = block_out
|
584 |
-
if curr_res in attn_resolutions:
|
585 |
-
attn.append(AttnBlock(block_in))
|
586 |
-
down = nn.Module()
|
587 |
-
down.block = block
|
588 |
-
down.attn = attn
|
589 |
-
if i_level != self.num_resolutions-1:
|
590 |
-
down.downsample = Downsample(block_in, resamp_with_conv)
|
591 |
-
curr_res = curr_res // 2
|
592 |
-
self.down.append(down)
|
593 |
-
|
594 |
-
self.z_in = torch.nn.Conv2d(z_channels,
|
595 |
-
block_in,
|
596 |
-
kernel_size=1,
|
597 |
-
stride=1,
|
598 |
-
padding=0)
|
599 |
-
# middle
|
600 |
-
self.mid = nn.Module()
|
601 |
-
self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
|
602 |
-
out_channels=block_in,
|
603 |
-
temb_channels=self.temb_ch,
|
604 |
-
dropout=dropout)
|
605 |
-
self.mid.attn_1 = AttnBlock(block_in)
|
606 |
-
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
607 |
-
out_channels=block_in,
|
608 |
-
temb_channels=self.temb_ch,
|
609 |
-
dropout=dropout)
|
610 |
-
|
611 |
-
# upsampling
|
612 |
-
self.up = nn.ModuleList()
|
613 |
-
for i_level in reversed(range(self.num_resolutions)):
|
614 |
-
block = nn.ModuleList()
|
615 |
-
attn = nn.ModuleList()
|
616 |
-
block_out = ch*ch_mult[i_level]
|
617 |
-
skip_in = ch*ch_mult[i_level]
|
618 |
-
for i_block in range(self.num_res_blocks+1):
|
619 |
-
if i_block == self.num_res_blocks:
|
620 |
-
skip_in = ch*in_ch_mult[i_level]
|
621 |
-
block.append(ResnetBlock(in_channels=block_in+skip_in,
|
622 |
-
out_channels=block_out,
|
623 |
-
temb_channels=self.temb_ch,
|
624 |
-
dropout=dropout))
|
625 |
-
block_in = block_out
|
626 |
-
if curr_res in attn_resolutions:
|
627 |
-
attn.append(AttnBlock(block_in))
|
628 |
-
up = nn.Module()
|
629 |
-
up.block = block
|
630 |
-
up.attn = attn
|
631 |
-
if i_level != 0:
|
632 |
-
up.upsample = Upsample(block_in, resamp_with_conv)
|
633 |
-
curr_res = curr_res * 2
|
634 |
-
self.up.insert(0, up) # prepend to get consistent order
|
635 |
-
|
636 |
-
# end
|
637 |
-
self.norm_out = Normalize(block_in)
|
638 |
-
self.conv_out = torch.nn.Conv2d(block_in,
|
639 |
-
out_ch,
|
640 |
-
kernel_size=3,
|
641 |
-
stride=1,
|
642 |
-
padding=1)
|
643 |
-
|
644 |
-
|
645 |
-
def forward(self, x, z):
|
646 |
-
#assert x.shape[2] == x.shape[3] == self.resolution
|
647 |
-
|
648 |
-
if self.use_timestep:
|
649 |
-
# timestep embedding
|
650 |
-
assert t is not None
|
651 |
-
temb = get_timestep_embedding(t, self.ch)
|
652 |
-
temb = self.temb.dense[0](temb)
|
653 |
-
temb = nonlinearity(temb)
|
654 |
-
temb = self.temb.dense[1](temb)
|
655 |
-
else:
|
656 |
-
temb = None
|
657 |
-
|
658 |
-
# downsampling
|
659 |
-
hs = [self.conv_in(x)]
|
660 |
-
for i_level in range(self.num_resolutions):
|
661 |
-
for i_block in range(self.num_res_blocks):
|
662 |
-
h = self.down[i_level].block[i_block](hs[-1], temb)
|
663 |
-
if len(self.down[i_level].attn) > 0:
|
664 |
-
h = self.down[i_level].attn[i_block](h)
|
665 |
-
hs.append(h)
|
666 |
-
if i_level != self.num_resolutions-1:
|
667 |
-
hs.append(self.down[i_level].downsample(hs[-1]))
|
668 |
-
|
669 |
-
# middle
|
670 |
-
h = hs[-1]
|
671 |
-
z = self.z_in(z)
|
672 |
-
h = torch.cat((h,z),dim=1)
|
673 |
-
h = self.mid.block_1(h, temb)
|
674 |
-
h = self.mid.attn_1(h)
|
675 |
-
h = self.mid.block_2(h, temb)
|
676 |
-
|
677 |
-
# upsampling
|
678 |
-
for i_level in reversed(range(self.num_resolutions)):
|
679 |
-
for i_block in range(self.num_res_blocks+1):
|
680 |
-
h = self.up[i_level].block[i_block](
|
681 |
-
torch.cat([h, hs.pop()], dim=1), temb)
|
682 |
-
if len(self.up[i_level].attn) > 0:
|
683 |
-
h = self.up[i_level].attn[i_block](h)
|
684 |
-
if i_level != 0:
|
685 |
-
h = self.up[i_level].upsample(h)
|
686 |
-
|
687 |
-
# end
|
688 |
-
h = self.norm_out(h)
|
689 |
-
h = nonlinearity(h)
|
690 |
-
h = self.conv_out(h)
|
691 |
-
return h
|
692 |
-
|
693 |
-
|
694 |
-
class SimpleDecoder(nn.Module):
|
695 |
-
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
696 |
-
super().__init__()
|
697 |
-
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
|
698 |
-
ResnetBlock(in_channels=in_channels,
|
699 |
-
out_channels=2 * in_channels,
|
700 |
-
temb_channels=0, dropout=0.0),
|
701 |
-
ResnetBlock(in_channels=2 * in_channels,
|
702 |
-
out_channels=4 * in_channels,
|
703 |
-
temb_channels=0, dropout=0.0),
|
704 |
-
ResnetBlock(in_channels=4 * in_channels,
|
705 |
-
out_channels=2 * in_channels,
|
706 |
-
temb_channels=0, dropout=0.0),
|
707 |
-
nn.Conv2d(2*in_channels, in_channels, 1),
|
708 |
-
Upsample(in_channels, with_conv=True)])
|
709 |
-
# end
|
710 |
-
self.norm_out = Normalize(in_channels)
|
711 |
-
self.conv_out = torch.nn.Conv2d(in_channels,
|
712 |
-
out_channels,
|
713 |
-
kernel_size=3,
|
714 |
-
stride=1,
|
715 |
-
padding=1)
|
716 |
-
|
717 |
-
def forward(self, x):
|
718 |
-
for i, layer in enumerate(self.model):
|
719 |
-
if i in [1,2,3]:
|
720 |
-
x = layer(x, None)
|
721 |
-
else:
|
722 |
-
x = layer(x)
|
723 |
-
|
724 |
-
h = self.norm_out(x)
|
725 |
-
h = nonlinearity(h)
|
726 |
-
x = self.conv_out(h)
|
727 |
-
return x
|
728 |
-
|
729 |
-
|
730 |
-
class UpsampleDecoder(nn.Module):
|
731 |
-
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
|
732 |
-
ch_mult=(2,2), dropout=0.0):
|
733 |
-
super().__init__()
|
734 |
-
# upsampling
|
735 |
-
self.temb_ch = 0
|
736 |
-
self.num_resolutions = len(ch_mult)
|
737 |
-
self.num_res_blocks = num_res_blocks
|
738 |
-
block_in = in_channels
|
739 |
-
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
740 |
-
self.res_blocks = nn.ModuleList()
|
741 |
-
self.upsample_blocks = nn.ModuleList()
|
742 |
-
for i_level in range(self.num_resolutions):
|
743 |
-
res_block = []
|
744 |
-
block_out = ch * ch_mult[i_level]
|
745 |
-
for i_block in range(self.num_res_blocks + 1):
|
746 |
-
res_block.append(ResnetBlock(in_channels=block_in,
|
747 |
-
out_channels=block_out,
|
748 |
-
temb_channels=self.temb_ch,
|
749 |
-
dropout=dropout))
|
750 |
-
block_in = block_out
|
751 |
-
self.res_blocks.append(nn.ModuleList(res_block))
|
752 |
-
if i_level != self.num_resolutions - 1:
|
753 |
-
self.upsample_blocks.append(Upsample(block_in, True))
|
754 |
-
curr_res = curr_res * 2
|
755 |
-
|
756 |
-
# end
|
757 |
-
self.norm_out = Normalize(block_in)
|
758 |
-
self.conv_out = torch.nn.Conv2d(block_in,
|
759 |
-
out_channels,
|
760 |
-
kernel_size=3,
|
761 |
-
stride=1,
|
762 |
-
padding=1)
|
763 |
-
|
764 |
-
def forward(self, x):
|
765 |
-
# upsampling
|
766 |
-
h = x
|
767 |
-
for k, i_level in enumerate(range(self.num_resolutions)):
|
768 |
-
for i_block in range(self.num_res_blocks + 1):
|
769 |
-
h = self.res_blocks[i_level][i_block](h, None)
|
770 |
-
if i_level != self.num_resolutions - 1:
|
771 |
-
h = self.upsample_blocks[k](h)
|
772 |
-
h = self.norm_out(h)
|
773 |
-
h = nonlinearity(h)
|
774 |
-
h = self.conv_out(h)
|
775 |
-
return h
|
776 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
taming-transformers/taming/modules/discriminator/model.py
DELETED
@@ -1,67 +0,0 @@
import functools
import torch.nn as nn


from taming.modules.util import ActNorm


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """
    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if not use_actnorm:
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = ActNorm
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)
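As a rough usage sketch (assuming the taming package is importable; the input size is arbitrary), the PatchGAN discriminator returns a grid of per-patch real/fake logits rather than a single scalar:

import torch
from taming.modules.discriminator.model import NLayerDiscriminator

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
logits = disc(torch.randn(1, 3, 256, 256))
print(logits.shape)  # one logit per receptive-field patch, e.g. torch.Size([1, 1, 30, 30]) for a 256x256 input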
taming-transformers/taming/modules/losses/__init__.py
DELETED
@@ -1,2 +0,0 @@
from taming.modules.losses.vqperceptual import DummyLoss
taming-transformers/taming/modules/losses/lpips.py
DELETED
@@ -1,123 +0,0 @@
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""

import torch
import torch.nn as nn
from torchvision import models
from collections import namedtuple

from taming.util import get_ckpt_path


class LPIPS(nn.Module):
    # Learned perceptual metric
    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vg16 features
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self, name="vgg_lpips"):
        ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
        self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        ckpt = get_ckpt_path(name)
        model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        return model

    def forward(self, input, target):
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        for kk in range(len(self.chns)):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
        val = res[0]
        for l in range(1, len(self.chns)):
            val += res[l]
        return val


class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    """ A single linear layer which does a 1x1 conv """
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = [nn.Dropout(), ] if (use_dropout) else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
        self.model = nn.Sequential(*layers)


class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out


def normalize_tensor(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
    return x/(norm_factor+eps)


def spatial_average(x, keepdim=True):
    return x.mean([2, 3], keepdim=keepdim)
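A possible usage sketch, assuming the taming package is installed and the vgg_lpips checkpoint can be fetched by get_ckpt_path; inputs are expected roughly in the [-1, 1] range:

import torch
from taming.modules.losses.lpips import LPIPS

lpips = LPIPS().eval()
img_a = torch.rand(1, 3, 256, 256) * 2 - 1
img_b = torch.rand(1, 3, 256, 256) * 2 - 1
with torch.no_grad():
    dist = lpips(img_a, img_b)  # shape (1, 1, 1, 1); larger means perceptually more different
print(dist.item())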
taming-transformers/taming/modules/losses/segmentation.py
DELETED
@@ -1,22 +0,0 @@
import torch.nn as nn
import torch.nn.functional as F


class BCELoss(nn.Module):
    def forward(self, prediction, target):
        loss = F.binary_cross_entropy_with_logits(prediction, target)
        return loss, {}


class BCELossWithQuant(nn.Module):
    def __init__(self, codebook_weight=1.):
        super().__init__()
        self.codebook_weight = codebook_weight

    def forward(self, qloss, target, prediction, split):
        bce_loss = F.binary_cross_entropy_with_logits(prediction, target)
        loss = bce_loss + self.codebook_weight*qloss
        return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
                      "{}/bce_loss".format(split): bce_loss.detach().mean(),
                      "{}/quant_loss".format(split): qloss.detach().mean()
                      }
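An illustrative call with random tensors, purely to show the argument order (qloss, target, prediction, split) and the returned (loss, log) pair:

import torch
from taming.modules.losses.segmentation import BCELossWithQuant

criterion = BCELossWithQuant(codebook_weight=1.0)
logits = torch.randn(2, 5, 64, 64)                       # raw segmentation logits
target = torch.randint(0, 2, (2, 5, 64, 64)).float()     # binary per-class masks
qloss = torch.tensor(0.1)                                # codebook/commitment loss from the quantizer
loss, log = criterion(qloss, target, logits, split="train")
print(loss, log["train/bce_loss"])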
taming-transformers/taming/modules/losses/vqperceptual.py
DELETED
@@ -1,136 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from taming.modules.losses.lpips import LPIPS
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init


class DummyLoss(nn.Module):
    def __init__(self):
        super().__init__()


def adopt_weight(weight, global_step, threshold=0, value=0.):
    if global_step < threshold:
        weight = value
    return weight


def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def vanilla_d_loss(logits_real, logits_fake):
    d_loss = 0.5 * (
        torch.mean(torch.nn.functional.softplus(-logits_real)) +
        torch.mean(torch.nn.functional.softplus(logits_fake)))
    return d_loss


class VQLPIPSWithDiscriminator(nn.Module):
    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_ndf=64, disc_loss="hinge"):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ndf=disc_ndf
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train"):
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        else:
            p_loss = torch.tensor([0.0])

        nll_loss = rec_loss
        #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        nll_loss = torch.mean(nll_loss)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            except RuntimeError:
                assert not self.training
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/quant_loss".format(split): codebook_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/p_loss".format(split): p_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
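A small self-contained check of the helper functions above (the numbers are arbitrary): hinge_d_loss averages the hinge terms for real and fake logits, and adopt_weight zeroes the discriminator term until global_step reaches the disc_start threshold.

import torch
from taming.modules.losses.vqperceptual import hinge_d_loss, adopt_weight

logits_real = torch.tensor([0.8, 1.2])
logits_fake = torch.tensor([-0.5, 0.3])
print(hinge_d_loss(logits_real, logits_fake))  # 0.5 * (mean relu(1 - real) + mean relu(1 + fake))

# The discriminator term is switched off before disc_start steps:
print(adopt_weight(1.0, global_step=100, threshold=250))  # 0.0
print(adopt_weight(1.0, global_step=500, threshold=250))  # 1.0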
taming-transformers/taming/modules/misc/coord.py
DELETED
@@ -1,31 +0,0 @@
import torch

class CoordStage(object):
    def __init__(self, n_embed, down_factor):
        self.n_embed = n_embed
        self.down_factor = down_factor

    def eval(self):
        return self

    def encode(self, c):
        """fake vqmodel interface"""
        assert 0.0 <= c.min() and c.max() <= 1.0
        b, ch, h, w = c.shape
        assert ch == 1

        c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
                                            mode="area")
        c = c.clamp(0.0, 1.0)
        c = self.n_embed*c
        c_quant = c.round()
        c_ind = c_quant.to(dtype=torch.long)

        info = None, None, c_ind
        return c_quant, None, info

    def decode(self, c):
        c = c/self.n_embed
        c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
                                            mode="nearest")
        return c
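A quick illustration of the fake-VQ interface above (sizes are arbitrary; encode expects a single-channel map with values in [0, 1]):

import torch
from taming.modules.misc.coord import CoordStage

stage = CoordStage(n_embed=1024, down_factor=16)
c = torch.rand(1, 1, 256, 256)                # e.g. a normalized coordinate map
c_quant, _, (_, _, c_ind) = stage.encode(c)   # downsampled and rounded to integer bins
c_back = stage.decode(c_quant)                # upsampled back to the input resolution
print(c_quant.shape, c_ind.dtype, c_back.shape)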
taming-transformers/taming/modules/transformer/mingpt.py
DELETED
@@ -1,415 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
taken from: https://github.com/karpathy/minGPT/
|
3 |
-
GPT model:
|
4 |
-
- the initial stem consists of a combination of token encoding and a positional encoding
|
5 |
-
- the meat of it is a uniform sequence of Transformer blocks
|
6 |
-
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
|
7 |
-
- all blocks feed into a central residual pathway similar to resnets
|
8 |
-
- the final decoder is a linear projection into a vanilla Softmax classifier
|
9 |
-
"""
|
10 |
-
|
11 |
-
import math
|
12 |
-
import logging
|
13 |
-
|
14 |
-
import torch
|
15 |
-
import torch.nn as nn
|
16 |
-
from torch.nn import functional as F
|
17 |
-
from transformers import top_k_top_p_filtering
|
18 |
-
|
19 |
-
logger = logging.getLogger(__name__)
|
20 |
-
|
21 |
-
|
22 |
-
class GPTConfig:
|
23 |
-
""" base GPT config, params common to all GPT versions """
|
24 |
-
embd_pdrop = 0.1
|
25 |
-
resid_pdrop = 0.1
|
26 |
-
attn_pdrop = 0.1
|
27 |
-
|
28 |
-
def __init__(self, vocab_size, block_size, **kwargs):
|
29 |
-
self.vocab_size = vocab_size
|
30 |
-
self.block_size = block_size
|
31 |
-
for k,v in kwargs.items():
|
32 |
-
setattr(self, k, v)
|
33 |
-
|
34 |
-
|
35 |
-
class GPT1Config(GPTConfig):
|
36 |
-
""" GPT-1 like network roughly 125M params """
|
37 |
-
n_layer = 12
|
38 |
-
n_head = 12
|
39 |
-
n_embd = 768
|
40 |
-
|
41 |
-
|
42 |
-
class CausalSelfAttention(nn.Module):
|
43 |
-
"""
|
44 |
-
A vanilla multi-head masked self-attention layer with a projection at the end.
|
45 |
-
It is possible to use torch.nn.MultiheadAttention here but I am including an
|
46 |
-
explicit implementation here to show that there is nothing too scary here.
|
47 |
-
"""
|
48 |
-
|
49 |
-
def __init__(self, config):
|
50 |
-
super().__init__()
|
51 |
-
assert config.n_embd % config.n_head == 0
|
52 |
-
# key, query, value projections for all heads
|
53 |
-
self.key = nn.Linear(config.n_embd, config.n_embd)
|
54 |
-
self.query = nn.Linear(config.n_embd, config.n_embd)
|
55 |
-
self.value = nn.Linear(config.n_embd, config.n_embd)
|
56 |
-
# regularization
|
57 |
-
self.attn_drop = nn.Dropout(config.attn_pdrop)
|
58 |
-
self.resid_drop = nn.Dropout(config.resid_pdrop)
|
59 |
-
# output projection
|
60 |
-
self.proj = nn.Linear(config.n_embd, config.n_embd)
|
61 |
-
# causal mask to ensure that attention is only applied to the left in the input sequence
|
62 |
-
mask = torch.tril(torch.ones(config.block_size,
|
63 |
-
config.block_size))
|
64 |
-
if hasattr(config, "n_unmasked"):
|
65 |
-
mask[:config.n_unmasked, :config.n_unmasked] = 1
|
66 |
-
self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
|
67 |
-
self.n_head = config.n_head
|
68 |
-
|
69 |
-
def forward(self, x, layer_past=None):
|
70 |
-
B, T, C = x.size()
|
71 |
-
|
72 |
-
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
73 |
-
k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
74 |
-
q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
75 |
-
v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
76 |
-
|
77 |
-
present = torch.stack((k, v))
|
78 |
-
if layer_past is not None:
|
79 |
-
past_key, past_value = layer_past
|
80 |
-
k = torch.cat((past_key, k), dim=-2)
|
81 |
-
v = torch.cat((past_value, v), dim=-2)
|
82 |
-
|
83 |
-
        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        if layer_past is None:
            att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))

        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y, present  # TODO: check that this does not break anything


class Block(nn.Module):
    """ an unassuming Transformer block """
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),  # nice
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x, layer_past=None, return_present=False):
        # TODO: check that training still works
        if return_present: assert not self.training
        # layer past: tuple of length two with B, nh, T, hs
        attn, present = self.attn(self.ln1(x), layer_past=layer_past)

        x = x + attn
        x = x + self.mlp(self.ln2(x))
        if layer_past is not None or return_present:
            return x, present
        return x


class GPT(nn.Module):
    """ the full GPT language model, with a context size of block_size """
    def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
                 embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
        super().__init__()
        config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
                           embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
                           n_layer=n_layer, n_head=n_head, n_embd=n_embd,
                           n_unmasked=n_unmasked)
        # input embedding stem
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.block_size = config.block_size
        self.apply(self._init_weights)
        self.config = config
        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, idx, embeddings=None, targets=None):
        # forward the GPT model
        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector

        if embeddings is not None:  # prepend explicit embeddings
            token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)

        t = token_embeddings.shape[1]
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
        position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) vector
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None):
        # inference only
        assert not self.training
        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector
        if embeddings is not None:  # prepend explicit embeddings
            token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)

        if past is not None:
            assert past_length is not None
            past = torch.cat(past, dim=-2)  # n_layer, 2, b, nh, len_past, dim_head
            past_shape = list(past.shape)
            expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
            assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
            position_embeddings = self.pos_emb[:, past_length, :]  # each position maps to a (learnable) vector
        else:
            position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]

        x = self.drop(token_embeddings + position_embeddings)
        presents = []  # accumulate over layers
        for i, block in enumerate(self.blocks):
            x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True)
            presents.append(present)

        x = self.ln_f(x)
        logits = self.head(x)
        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss, torch.stack(presents)  # _, _, n_layer, 2, b, nh, 1, dim_head


class DummyGPT(nn.Module):
    # for debugging
    def __init__(self, add_value=1):
        super().__init__()
        self.add_value = add_value

    def forward(self, idx):
        return idx + self.add_value, None


class CodeGPT(nn.Module):
    """Takes in semi-embeddings"""
    def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
                 embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
        super().__init__()
        config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
                           embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
                           n_layer=n_layer, n_head=n_head, n_embd=n_embd,
                           n_unmasked=n_unmasked)
        # input embedding stem
        self.tok_emb = nn.Linear(in_channels, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.block_size = config.block_size
        self.apply(self._init_weights)
        self.config = config
        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, idx, embeddings=None, targets=None):
        # forward the GPT model
        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector

        if embeddings is not None:  # prepend explicit embeddings
            token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)

        t = token_embeddings.shape[1]
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
        position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) vector
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)  # final layer norm before the head
        logits = self.head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss


#### sampling utils

def top_k_logits(logits, k):
    v, ix = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[:, [-1]]] = -float('Inf')
    return out

@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
    """
    take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
    the sequence, feeding the predictions back into the model each time. Clearly the sampling
    has quadratic complexity unlike an RNN that is only linear, and has a finite context window
    of block_size, unlike an RNN that has an infinite context window.
    """
    block_size = model.get_block_size()
    model.eval()
    for k in range(steps):
        x_cond = x if x.size(1) <= block_size else x[:, -block_size:]  # crop context if needed
        logits, _ = model(x_cond)
        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop probabilities to only the top k options
        if top_k is not None:
            logits = top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution or take the most likely
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)
        # append to the sequence and continue
        x = torch.cat((x, ix), dim=1)

    return x


@torch.no_grad()
def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
                     top_k=None, top_p=None, callback=None):
    # x is conditioning
    sample = x
    cond_len = x.shape[1]
    past = None
    for n in range(steps):
        if callback is not None:
            callback(n)
        logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1))
        if past is None:
            past = [present]
        else:
            past.append(present)
        logits = logits[:, -1, :] / temperature
        if top_k is not None:
            logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

        probs = F.softmax(logits, dim=-1)
        if not sample_logits:
            _, x = torch.topk(probs, k=1, dim=-1)
        else:
            x = torch.multinomial(probs, num_samples=1)
        # append to the sequence and continue
        sample = torch.cat((sample, x), dim=1)
    del past
    sample = sample[:, cond_len:]  # cut conditioning off
    return sample


#### clustering utils

class KMeans(nn.Module):
    def __init__(self, ncluster=512, nc=3, niter=10):
        super().__init__()
        self.ncluster = ncluster
        self.nc = nc
        self.niter = niter
        self.shape = (3, 32, 32)
        self.register_buffer("C", torch.zeros(self.ncluster, nc))
        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    def is_initialized(self):
        return self.initialized.item() == 1

    @torch.no_grad()
    def initialize(self, x):
        N, D = x.shape
        assert D == self.nc, D
        c = x[torch.randperm(N)[:self.ncluster]]  # init clusters at random
        for i in range(self.niter):
            # assign all pixels to the closest codebook element
            a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
            # move each codebook element to be the mean of the pixels that assigned to it
            c = torch.stack([x[a == k].mean(0) for k in range(self.ncluster)])
            # re-assign any poorly positioned codebook elements
            nanix = torch.any(torch.isnan(c), dim=1)
            ndead = nanix.sum().item()
            print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
            c[nanix] = x[torch.randperm(N)[:ndead]]  # re-init dead clusters

        self.C.copy_(c)
        self.initialized.fill_(1)

    def forward(self, x, reverse=False, shape=None):
        if not reverse:
            # flatten
            bs, c, h, w = x.shape
            assert c == self.nc
            x = x.reshape(bs, c, h*w, 1)
            C = self.C.permute(1, 0)
            C = C.reshape(1, c, 1, self.ncluster)
            a = ((x - C)**2).sum(1).argmin(-1)  # bs, h*w indices
            return a
        else:
            # flatten
            bs, HW = x.shape
            """
            c = self.C.reshape(1, self.nc, 1, self.ncluster)
            c = c[bs*[0],:,:,:]
            c = c[:,:,HW*[0],:]
            x = x.reshape(bs, 1, HW, 1)
            x = x[:,3*[0],:,:]
            x = torch.gather(c, dim=3, index=x)
            """
            x = self.C[x]
            x = x.permute(0, 2, 1)
            shape = shape if shape is not None else self.shape
            x = x.reshape(bs, *shape)

            return x
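A minimal usage sketch of the cached sampler above (not part of the deleted file; the model sizes and token counts are illustrative, and top-k filtering is left off so the sketch only relies on the helpers shown here):

import torch

model = GPT(vocab_size=1024, block_size=512, n_layer=4, n_head=4, n_embd=256)
model.eval()  # forward_with_past asserts eval mode

cond = torch.randint(0, 1024, (2, 1))  # e.g. one conditioning token per sample
with torch.no_grad():
    tokens = sample_with_past(cond, model, steps=16, temperature=1.0,
                              sample_logits=True, top_k=None)
print(tokens.shape)  # (2, 16): the conditioning prefix is cut off before returning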
taming-transformers/taming/modules/transformer/permuter.py
DELETED
@@ -1,248 +0,0 @@
import torch
import torch.nn as nn
import numpy as np


class AbstractPermuter(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
    def forward(self, x, reverse=False):
        raise NotImplementedError


class Identity(AbstractPermuter):
    def __init__(self):
        super().__init__()

    def forward(self, x, reverse=False):
        return x


class Subsample(AbstractPermuter):
    def __init__(self, H, W):
        super().__init__()
        C = 1
        indices = np.arange(H*W).reshape(C, H, W)
        while min(H, W) > 1:
            indices = indices.reshape(C, H//2, 2, W//2, 2)
            indices = indices.transpose(0, 2, 4, 1, 3)
            indices = indices.reshape(C*4, H//2, W//2)
            H = H//2
            W = W//2
            C = C*4
        assert H == W == 1
        idx = torch.tensor(indices.ravel())
        self.register_buffer('forward_shuffle_idx',
                             nn.Parameter(idx, requires_grad=False))
        self.register_buffer('backward_shuffle_idx',
                             nn.Parameter(torch.argsort(idx), requires_grad=False))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]


def mortonify(i, j):
    """(i,j) index to linear morton code"""
    i = np.uint64(i)
    j = np.uint64(j)

    z = np.uint(0)

    for pos in range(32):
        z = (z |
             ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
             ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
             )
    return z


class ZCurve(AbstractPermuter):
    def __init__(self, H, W):
        super().__init__()
        reverseidx = [np.int64(mortonify(i, j)) for i in range(H) for j in range(W)]
        idx = np.argsort(reverseidx)
        idx = torch.tensor(idx)
        reverseidx = torch.tensor(reverseidx)
        self.register_buffer('forward_shuffle_idx',
                             idx)
        self.register_buffer('backward_shuffle_idx',
                             reverseidx)

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]


class SpiralOut(AbstractPermuter):
    def __init__(self, H, W):
        super().__init__()
        assert H == W
        size = W
        indices = np.arange(size*size).reshape(size, size)

        i0 = size//2
        j0 = size//2-1

        i = i0
        j = j0

        idx = [indices[i0, j0]]
        step_mult = 0
        for c in range(1, size//2+1):
            step_mult += 1
            # steps left
            for k in range(step_mult):
                i = i - 1
                j = j
                idx.append(indices[i, j])

            # step down
            for k in range(step_mult):
                i = i
                j = j + 1
                idx.append(indices[i, j])

            step_mult += 1
            if c < size//2:
                # step right
                for k in range(step_mult):
                    i = i + 1
                    j = j
                    idx.append(indices[i, j])

                # step up
                for k in range(step_mult):
                    i = i
                    j = j - 1
                    idx.append(indices[i, j])
            else:
                # end reached
                for k in range(step_mult-1):
                    i = i + 1
                    idx.append(indices[i, j])

        assert len(idx) == size*size
        idx = torch.tensor(idx)
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]


class SpiralIn(AbstractPermuter):
    def __init__(self, H, W):
        super().__init__()
        assert H == W
        size = W
        indices = np.arange(size*size).reshape(size, size)

        i0 = size//2
        j0 = size//2-1

        i = i0
        j = j0

        idx = [indices[i0, j0]]
        step_mult = 0
        for c in range(1, size//2+1):
            step_mult += 1
            # steps left
            for k in range(step_mult):
                i = i - 1
                j = j
                idx.append(indices[i, j])

            # step down
            for k in range(step_mult):
                i = i
                j = j + 1
                idx.append(indices[i, j])

            step_mult += 1
            if c < size//2:
                # step right
                for k in range(step_mult):
                    i = i + 1
                    j = j
                    idx.append(indices[i, j])

                # step up
                for k in range(step_mult):
                    i = i
                    j = j - 1
                    idx.append(indices[i, j])
            else:
                # end reached
                for k in range(step_mult-1):
                    i = i + 1
                    idx.append(indices[i, j])

        assert len(idx) == size*size
        idx = idx[::-1]
        idx = torch.tensor(idx)
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]


class Random(nn.Module):
    def __init__(self, H, W):
        super().__init__()
        indices = np.random.RandomState(1).permutation(H*W)
        idx = torch.tensor(indices.ravel())
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]


class AlternateParsing(AbstractPermuter):
    def __init__(self, H, W):
        super().__init__()
        indices = np.arange(W*H).reshape(H, W)
        for i in range(1, H, 2):
            indices[i, :] = indices[i, ::-1]
        idx = indices.flatten()
        assert len(idx) == H*W
        idx = torch.tensor(idx)
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]


if __name__ == "__main__":
    p0 = AlternateParsing(16, 16)
    print(p0.forward_shuffle_idx)
    print(p0.backward_shuffle_idx)

    x = torch.randint(0, 768, size=(11, 256))
    y = p0(x)
    xre = p0(y, reverse=True)
    assert torch.equal(x, xre)

    p1 = SpiralOut(2, 2)
    print(p1.forward_shuffle_idx)
    print(p1.backward_shuffle_idx)
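A short sketch of how one of these permuters is applied to a flattened grid of codebook indices, complementing the __main__ self-test above (the 16x16 grid and codebook size are illustrative; ZCurve is only a round trip on square power-of-two grids like this one):

import torch

perm = ZCurve(16, 16)
codes = torch.randint(0, 1024, size=(4, 256))  # batch of flattened index maps
reordered = perm(codes)                        # raster order -> Z-curve order
restored = perm(reordered, reverse=True)       # back to raster order
assert torch.equal(codes, restored)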
taming-transformers/taming/modules/util.py
DELETED
@@ -1,130 +0,0 @@
import torch
import torch.nn as nn


def count_params(model):
    total_params = sum(p.numel() for p in model.parameters())
    return total_params


class ActNorm(nn.Module):
    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            input = input[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height*width*torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class Labelator(AbstractEncoder):
    """Net2Net Interface for Class-Conditional Model"""
    def __init__(self, n_classes, quantize_interface=True):
        super().__init__()
        self.n_classes = n_classes
        self.quantize_interface = quantize_interface

    def encode(self, c):
        c = c[:, None]
        if self.quantize_interface:
            return c, None, [None, None, c.long()]
        return c


class SOSProvider(AbstractEncoder):
    # for unconditional training
    def __init__(self, sos_token, quantize_interface=True):
        super().__init__()
        self.sos_token = sos_token
        self.quantize_interface = quantize_interface

    def encode(self, x):
        # get batch size from data and replicate sos_token
        c = torch.ones(x.shape[0], 1)*self.sos_token
        c = c.long().to(x.device)
        if self.quantize_interface:
            return c, None, [None, None, c]
        return c
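A small sketch of ActNorm's data-dependent initialisation, for orientation (shapes and values are illustrative, not from the deleted file): on the first forward pass in training mode, loc and scale are set from the batch statistics, so that first batch comes out roughly zero-mean and unit-variance per channel, and reverse() undoes the affine map.

import torch

norm = ActNorm(num_features=8, logdet=False)
norm.train()
x = torch.randn(4, 8, 16, 16) * 3.0 + 1.5
y = norm(x)                                # triggers initialisation from this batch
print(y.mean().item(), y.std().item())     # approximately 0 and 1
z = norm(y, reverse=True)
print(torch.allclose(x, z, atol=1e-4))     # True: the affine transform is inverted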
taming-transformers/taming/modules/vqvae/quantize.py
DELETED
@@ -1,445 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import einsum
from einops import rearrange


class VectorQuantizer(nn.Module):
    """
    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
    ____________________________________________
    Discretization bottleneck part of the VQ-VAE.
    Inputs:
    - n_e : number of embeddings
    - e_dim : dimension of embedding
    - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
    _____________________________________________
    """

    # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
    # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
    # used wherever VectorQuantizer has been used before and is additionally
    # more efficient.
    def __init__(self, n_e, e_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def forward(self, z):
        """
        Inputs the output of the encoder network z and maps it to a discrete
        one-hot vector that is the index of the closest embedding vector e_j
        z (continuous) -> z_q (discrete)
        z.shape = (batch, channel, height, width)
        quantization pipeline:
            1. get encoder input (B,C,H,W)
            2. flatten input to (B*H*W,C)
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
            torch.matmul(z_flattened, self.embedding.weight.t())

        ## could possible replace this here
        # #\start...
        # find closest encodings
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)

        min_encodings = torch.zeros(
            min_encoding_indices.shape[0], self.n_e).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # dtype min encodings: torch.float32
        # min_encodings shape: torch.Size([2048, 512])
        # min_encoding_indices.shape: torch.Size([2048, 1])

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        #.........\end

        # with:
        # .........\start
        #min_encoding_indices = torch.argmin(d, dim=1)
        #z_q = self.embedding(min_encoding_indices)
        # ......\end......... (TODO)

        # compute loss for embedding
        loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
            torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # perplexity
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        # TODO: check for more easy handling with nn.Embedding
        min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
        min_encodings.scatter_(1, indices[:, None], 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)

        if shape is not None:
            z_q = z_q.view(shape)

            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


class GumbelQuantize(nn.Module):
    """
    credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
    Gumbel Softmax trick quantizer
    Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
    https://arxiv.org/abs/1611.01144
    """
    def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
                 kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
                 remap=None, unknown_index="random"):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.n_embed = n_embed

        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight

        self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
        self.embed = nn.Embedding(n_embed, embedding_dim)

        self.use_vqinterface = use_vqinterface

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed+1
            print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_embed

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0]*[0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, return_logits=False):
        # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
        hard = self.straight_through if self.training else True
        temp = self.temperature if temp is None else temp

        logits = self.proj(z)
        if self.remap is not None:
            # continue only with used logits
            full_zeros = torch.zeros_like(logits)
            logits = logits[:, self.used, ...]

        soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
        if self.remap is not None:
            # go back to all entries but unused set to zero
            full_zeros[:, self.used, ...] = soft_one_hot
            soft_one_hot = full_zeros
        z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)

        # + kl divergence to the prior loss
        qy = F.softmax(logits, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()

        ind = soft_one_hot.argmax(dim=1)
        if self.remap is not None:
            ind = self.remap_to_used(ind)
        if self.use_vqinterface:
            if return_logits:
                return z_q, diff, (None, None, ind), logits
            return z_q, diff, (None, None, ind)
        return z_q, diff, ind

    def get_codebook_entry(self, indices, shape):
        b, h, w, c = shape
        assert b*h*w == indices.shape[0]
        indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
        if self.remap is not None:
            indices = self.unmap_to_all(indices)
        one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
        z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
        return z_q


class VectorQuantizer2(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """
    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
                 sane_index_shape=False, legacy=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed+1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0]*[0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits == False, "Only for interface compatible with Gumbel"
        assert return_logits == False, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
                   torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
                   torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(
                z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


class EmbeddingEMA(nn.Module):
    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
        super().__init__()
        self.decay = decay
        self.eps = eps
        weight = torch.randn(num_tokens, codebook_dim)
        self.weight = nn.Parameter(weight, requires_grad=False)
        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
        self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
        self.update = True

    def forward(self, embed_id):
        return F.embedding(embed_id, self.weight)

    def cluster_size_ema_update(self, new_cluster_size):
        self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)

    def embed_avg_ema_update(self, new_embed_avg):
        self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)

    def weight_update(self, num_tokens):
        n = self.cluster_size.sum()
        smoothed_cluster_size = (
            (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
        )
        # normalize embedding average with smoothed cluster size
        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
        self.weight.data.copy_(embed_normalized)


class EMAVectorQuantizer(nn.Module):
    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
                 remap=None, unknown_index="random"):
        super().__init__()
        self.codebook_dim = embedding_dim  # dimensionality of each codebook vector
        self.num_tokens = n_embed          # number of codebook entries
        self.n_embed = n_embed
        self.beta = beta
        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed+1
            print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_embed

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0]*[0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        # z, 'b c h w -> b h w c'
        z = rearrange(z, 'b c h w -> b h w c')
        z_flattened = z.reshape(-1, self.codebook_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
            self.embedding.weight.pow(2).sum(dim=1) - 2 * \
            torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)  # 'n d -> d n'

        encoding_indices = torch.argmin(d, dim=1)

        z_q = self.embedding(encoding_indices).view(z.shape)
        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        if self.training and self.embedding.update:
            # EMA cluster size
            encodings_sum = encodings.sum(0)
            self.embedding.cluster_size_ema_update(encodings_sum)
            # EMA embedding average
            embed_sum = encodings.transpose(0, 1) @ z_flattened
            self.embedding.embed_avg_ema_update(embed_sum)
            # normalize embed_avg and update weight
            self.embedding.weight_update(self.num_tokens)

        # compute loss for embedding
        loss = self.beta * F.mse_loss(z_q.detach(), z)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        # z_q, 'b h w c -> b c h w'
        z_q = rearrange(z_q, 'b h w c -> b c h w')
        return z_q, loss, (perplexity, encodings, encoding_indices)
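For orientation, a hedged sketch of the straight-through quantisation interface shared by the classes above (the codebook size, embedding dimension, and feature-map shape are illustrative; legacy=False applies the corrected beta weighting mentioned in the class comment):

import torch

quantizer = VectorQuantizer2(n_e=1024, e_dim=256, beta=0.25,
                             legacy=False, sane_index_shape=True)
z = torch.randn(2, 256, 16, 16)                    # (B, C, H, W) encoder output
z_q, codebook_loss, (_, _, indices) = quantizer(z)
print(z_q.shape)       # (2, 256, 16, 16); gradients flow straight through to z
print(indices.shape)   # (2, 16, 16) thanks to sane_index_shape=True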
taming-transformers/taming/util.py
DELETED
@@ -1,157 +0,0 @@
import os, hashlib
import requests
from tqdm import tqdm

URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}

MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}


def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(chunk_size)


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path


class KeyNotFoundError(Exception):
    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited
        messages = list()
        if keys is not None:
            messages.append("Key not found: {}".format(keys))
        if visited is not None:
            messages.append("Visited: {}".format(visited))
        messages.append("Cause:\n{}".format(cause))
        message = "\n".join(messages)
        super().__init__(message)


def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
        list_or_dict : list or dict
            Possibly nested list or dictionary.
        key : str
            key/to/value, path like string describing all keys necessary to
            consider to get to the desired value. List indices can also be
            passed here.
        splitval : str
            String that defines the delimiter between keys of the
            different depth levels in `key`.
        default : obj
            Value returned if :attr:`key` is not found.
        expand : bool
            Whether to expand callable nodes on the path or not.

    Returns
    -------
        The desired value or if :attr:`default` is not ``None`` and the
        :attr:`key` is not found returns ``default``.

    Raises
    ------
        Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
        ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                parent[last_key] = list_or_dict

            last_key = key
            parent = list_or_dict

            try:
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[key]
                else:
                    list_or_dict = list_or_dict[int(key)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited)

            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError as e:
        if default is None:
            raise e
        else:
            list_or_dict = default
            success = False

    if not pass_success:
        return list_or_dict
    else:
        return list_or_dict, success


if __name__ == "__main__":
    config = {"keya": "a",
              "keyb": "b",
              "keyc":
                  {"cc1": 1,
                   "cc2": 2,
                   }
              }
    from omegaconf import OmegaConf
    config = OmegaConf.create(config)
    print(config)
    retrieve(config, "keya")
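A brief sketch of retrieve() on a nested config, complementing the __main__ example above (the keys and values here are illustrative, not from the original file):

config = {"model": {"params": {"embed_dim": 256, "n_embed": 1024}}}

print(retrieve(config, "model/params/embed_dim"))            # 256; "/" separates depth levels
print(retrieve(config, "model/params/missing", default=0))   # falls back to the given default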
taming-transformers/taming_transformers.egg-info/PKG-INFO
DELETED
@@ -1,10 +0,0 @@
Metadata-Version: 2.1
Name: taming-transformers
Version: 0.0.1
Summary: Taming Transformers for High-Resolution Image Synthesis
Home-page: UNKNOWN
License: UNKNOWN
Platform: UNKNOWN

UNKNOWN
taming-transformers/taming_transformers.egg-info/SOURCES.txt
DELETED
@@ -1,7 +0,0 @@
README.md
setup.py
taming_transformers.egg-info/PKG-INFO
taming_transformers.egg-info/SOURCES.txt
taming_transformers.egg-info/dependency_links.txt
taming_transformers.egg-info/requires.txt
taming_transformers.egg-info/top_level.txt
taming-transformers/taming_transformers.egg-info/dependency_links.txt
DELETED
@@ -1 +0,0 @@
taming-transformers/taming_transformers.egg-info/requires.txt
DELETED
@@ -1,3 +0,0 @@
torch
numpy
tqdm
taming-transformers/taming_transformers.egg-info/top_level.txt
DELETED
@@ -1 +0,0 @@