Spaces: Running on Zero
Initial commit: track binaries with LFS
This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +1 -0
- .gitignore +177 -0
- DIV2k_mask.npy +3 -0
- LCFM/SD3.5-large_MSE_DIV2k.npy +3 -0
- README.md +70 -14
- app.py +774 -0
- assets/teaser3.svg +0 -0
- configs/inpainting.yaml +45 -0
- configs/inpainting_gradio.yaml +45 -0
- configs/motion_deblur.yaml +43 -0
- configs/x12.yaml +48 -0
- configs/x12_gradio.yaml +48 -0
- demo_images/demo_0_image.png +3 -0
- demo_images/demo_0_meta.json +7 -0
- demo_images/demo_1_image.png +3 -0
- demo_images/demo_1_mask.png +3 -0
- demo_images/demo_1_meta.json +7 -0
- demo_images/demo_2_image.png +3 -0
- demo_images/demo_2_mask.png +3 -0
- demo_images/demo_2_meta.json +7 -0
- demo_images/demo_3_image.png +3 -0
- demo_images/demo_3_mask.png +3 -0
- demo_images/demo_3_meta.json +7 -0
- examples/girl.png +3 -0
- examples/sunflowers.png +3 -0
- inference_scripts/run_image_inv.py +159 -0
- requirements.txt +18 -0
- scripts/compute_metrics.py +144 -0
- scripts/generate_caption.py +107 -0
- setup.py +14 -0
- src/flair/__init__.py +0 -0
- src/flair/degradations.py +198 -0
- src/flair/functions/__init__.py +0 -0
- src/flair/functions/ckpt_util.py +72 -0
- src/flair/functions/conjugate_gradient.py +66 -0
- src/flair/functions/degradation.py +211 -0
- src/flair/functions/jpeg.py +392 -0
- src/flair/functions/measurements.py +429 -0
- src/flair/functions/nonuniform/kernels/000001.npy +3 -0
- src/flair/functions/svd_ddnm.py +206 -0
- src/flair/functions/svd_operators.py +1308 -0
- src/flair/helper_functions.py +31 -0
- src/flair/pipelines/__init__.py +0 -0
- src/flair/pipelines/model_loader.py +97 -0
- src/flair/pipelines/sd3.py +111 -0
- src/flair/pipelines/utils.py +114 -0
- src/flair/scheduling.py +882 -0
- src/flair/utils.py +406 -0
- src/flair/utils/__init__.py +0 -0
- src/flair/utils/blur_util.py +48 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,177 @@
SD3_*
*.pdf
sd3_*
wandb/*
output/*
data/*
vscode/*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
media/*
pose/*
wandb/*
thirdparty/*
# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
slurm*.out
prompt_cache/*
.vscode/
cache/
notebooks/
paper_vis_scripts/
sweeps/
DIV2k_mask.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:774e6caa5f557de4bd3bd7e0df1de2b770974f96d703a320941ab7fc3cf66ad1
size 589952
LCFM/SD3.5-large_MSE_DIV2k.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d4ee67869328e73b1989ec52556c21f2aed2c6c6f9e90ef33e047b2514d22d31
size 528
README.md
CHANGED
@@ -1,14 +1,70 @@
<div align="center">

# FLAIR: Flow-Based Latent Alignment for Image Restoration

**Julius Erbach<sup>1</sup>, Dominik Narnhofer<sup>1</sup>, Andreas Dombos<sup>1</sup>, Jan Eric Lenssen<sup>1</sup>, Bernt Schiele<sup>2</sup>, Konrad Schindler<sup>1</sup>**
<br>
<sup>1</sup> Photogrammetry and Remote Sensing, ETH Zurich
<sup>2</sup> Max Planck Institute for Informatics, Saarbrücken

[](link)
[](inverseFLAIR.github.io)
[](link)
</div>

<p align="center">
<img src="assets/teaser3.svg" alt="teaser" width="98%"/>
</p>
<p align="center">
<em>FLAIR</em> is a novel approach for solving inverse imaging problems using flow-based posterior sampling.
</p>

## Installation

1. Clone the repository:
```bash
git clone <your-repo-url>
cd <your-repo-name>
```

2. Create a virtual environment (recommended):
```bash
python3 -m venv venv
source venv/bin/activate
```

3. Install the required dependencies from `requirements.txt`:
```bash
pip install -r requirements.txt
pip install .
```

## Running Inference

To run inference, use the Python script `run_image_inv.py` together with the corresponding config file.
An example from the FFHQ dataset:
```bash
python inference_scripts/run_image_inv.py --config configs/inpainting.yaml --target_file examples/girl.png --result_folder output --prompt="a high quality photo of a face"
```
Or an example from the DIV2K dataset with captions provided by DAPE on the degraded input. Masks can be defined as rectangle coordinates in the config file or provided as a .npy file in which true pixels are observed and false pixels are masked out.

```bash
python inference_scripts/run_image_inv.py --config configs/inpainting.yaml --target_file examples/sunflowers.png --result_folder output --prompt="a high quality photo of bloom, blue, field, flower, sky, sunflower, sunflower field, yellow" --mask_file DIV2k_mask.npy
```

```bash
python inference_scripts/run_image_inv.py --config configs/x12.yaml --target_file examples/sunflowers.png --result_folder output --prompt="a high quality photo of bloom, blue, field, flower, sky, sunflower, sunflower field, yellow"
```

## Citation

If you find this work useful in your research, please cite our paper:

```bibtex
@article{er2025solving,
  title={Solving Inverse Problems with FLAIR},
  author={Erbach, Julius and Narnhofer, Dominik and Dombos, Andreas and Lenssen, Jan Eric and Schiele, Bernt and Schindler, Konrad},
  journal={arXiv},
  year={2025}
}
```
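The README above states that an inpainting mask can be passed as a `.npy` file in which true pixels are observed and false pixels are masked out. Below is a minimal sketch for producing such a file; the 768x768 shape is assumed to match the `resolution` in the shipped configs, and the boolean-array convention follows the README description rather than the repository's loader code, so both are assumptions. The file name `my_mask.npy` is hypothetical.

```python
import numpy as np

# Hedged sketch: build a boolean mask where True = observed pixel and
# False = masked-out pixel, matching the convention described in the README.
# Shape 768x768 is assumed to match the config's `resolution`.
H, W = 768, 768
mask = np.ones((H, W), dtype=bool)   # start with everything observed
mask[200:400, 300:500] = False       # hide an arbitrary rectangle to be inpainted
np.save("my_mask.npy", mask)         # then pass: --mask_file my_mask.npy
```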
app.py
ADDED
@@ -0,0 +1,774 @@
import gradio as gr
from gradio_imageslider import ImageSlider  # Replaces gr.ImageCompare
import torch
import yaml
import numpy as np
from PIL import Image
import torchvision.transforms.functional as TF
import random
import os
import sys
import json  # Added import
import copy
# Add project root to sys.path to allow direct import of var_post_samp
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from flair.pipelines import model_loader
from flair import var_post_samp, degradations


CONFIG_FILE_PATH = "./configs/inpainting_gradio.yaml"
DTYPE = torch.bfloat16

# Global variables to hold the model and config
MODEL = None
POSTERIOR_MODEL = None
BASE_CONFIG = None
DEVICES = None
PRIMARY_DEVICE = None
# project_root is already defined globally, will be used by save_configuration

SR_CONFIG_FILE_PATH = "./configs/x12_gradio.yaml"

# Function to save the current configuration for demo examples
def save_configuration(image_editor_data, image_input, prompt, seed_val, task, random_seed_bool, steps_val):
    global project_root  # Ensure access to the globally defined project_root
    if task == "Super Resolution":
        if image_input is None:
            return gr.Markdown("""<p style='color:red;'>Error: No low-resolution image loaded.</p>""")
        # For Super Resolution, we don't need a mask, just the image
        input_image = image_input
        mask_image = None
    else:  # Inpainting task
        if image_editor_data is None or image_editor_data['background'] is None:
            return gr.Markdown("""<p style='color:red;'>Error: No background image loaded.</p>""")

        # Check if layers exist and the first layer (mask) is not None
        if not image_editor_data['layers'] or image_editor_data['layers'][0] is None:
            return gr.Markdown("""<p style='color:red;'>Error: No mask drawn. Please use the brush tool to draw a mask.</p>""")

        input_image = image_editor_data['background']
        mask_image = image_editor_data['layers'][0]

    metadata = {
        "prompt": prompt,
        "seed_on_slider": int(seed_val),
        "use_random_seed_checkbox": bool(random_seed_bool),
        "num_steps": int(steps_val),
        "task_type": task
    }

    demo_images_dir = os.path.join(project_root, "demo_images")
    try:
        os.makedirs(demo_images_dir, exist_ok=True)
    except Exception as e:
        return gr.Markdown(f"""<p style='color:red;'>Error creating directory {demo_images_dir}: {str(e)}</p>""")

    i = 0
    while True:
        base_filename = f"demo_{i}"
        meta_check_path = os.path.join(demo_images_dir, f"{base_filename}_meta.json")
        if not os.path.exists(meta_check_path):
            break
        i += 1

    image_save_path = os.path.join(demo_images_dir, f"{base_filename}_image.png")
    mask_save_path = os.path.join(demo_images_dir, f"{base_filename}_mask.png")
    meta_save_path = os.path.join(demo_images_dir, f"{base_filename}_meta.json")

    try:
        input_image.save(image_save_path)
        if mask_image is not None:
            # Ensure mask is saved in a usable format, e.g., 'L' mode for grayscale, or 'RGBA' if it has transparency
            if mask_image.mode != 'L' and mask_image.mode != '1':  # If not already grayscale or binary
                mask_image = mask_image.convert('RGBA')  # Preserve transparency if drawn, or convert to L
            mask_image.save(mask_save_path)

        with open(meta_save_path, 'w') as f:
            json.dump(metadata, f, indent=4)
        return gr.Markdown(f"""<p style='color:green;'>Configuration saved as {base_filename} in demo_images folder.</p>""")
    except Exception as e:
        return gr.Markdown(f"""<p style='color:red;'>Error saving configuration: {str(e)}</p>""")

def embed_prompt(prompt, device):
    print(f"Generating prompt embeddings for: {prompt}")
    with torch.no_grad():  # Add torch.no_grad() here
        POSTERIOR_MODEL.model.text_encoder.to(device).to(torch.bfloat16)
        POSTERIOR_MODEL.model.text_encoder_2.to(device).to(torch.bfloat16)
        POSTERIOR_MODEL.model.text_encoder_3.to(device).to(torch.bfloat16)
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = POSTERIOR_MODEL.model.encode_prompt(
            prompt=prompt,
            prompt_2=prompt,
            prompt_3=prompt,
            negative_prompt="",
            negative_prompt_2="",
            negative_prompt_3="",
            do_classifier_free_guidance=POSTERIOR_MODEL.model.do_classifier_free_guidance,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            pooled_prompt_embeds=None,
            negative_pooled_prompt_embeds=None,
            device=device,
            clip_skip=None,
            num_images_per_prompt=1,
            max_sequence_length=256,
            lora_scale=None,
        )
    # POSTERIOR_MODEL.model.text_encoder.to("cpu").to(torch.bfloat16)
    # POSTERIOR_MODEL.model.text_encoder_2.to("cpu").to(torch.bfloat16)
    # POSTERIOR_MODEL.model.text_encoder_3.to("cpu").to(torch.bfloat16)
    torch.cuda.empty_cache()  # Clear GPU memory after embedding generation
    return {
        "prompt_embeds": prompt_embeds.to(device, dtype=DTYPE),
        "negative_prompt_embeds": negative_prompt_embeds.to(device, dtype=DTYPE) if negative_prompt_embeds is not None else None,
        "pooled_prompt_embeds": pooled_prompt_embeds.to(device, dtype=DTYPE),
        "negative_pooled_prompt_embeds": negative_pooled_prompt_embeds.to(device, dtype=DTYPE) if negative_pooled_prompt_embeds is not None else None
    }

def initialize_globals():
    global MODEL, POSTERIOR_MODEL, BASE_CONFIG, DEVICES, PRIMARY_DEVICE

    print("Global initialization started...")
    # Setup device (run once)
    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        DEVICES = [f"cuda:{i}" for i in range(num_gpus)]
        PRIMARY_DEVICE = DEVICES[0]
        print(f"Initializing with devices: {DEVICES}, Primary: {PRIMARY_DEVICE}")
    else:
        DEVICES = ["cpu"]
        PRIMARY_DEVICE = "cpu"
        print("No CUDA devices found. Initializing with CPU.")

    # Load base configuration (once)
    with open(CONFIG_FILE_PATH, "r") as f:
        BASE_CONFIG = yaml.safe_load(f)

    # Prepare a temporary config for the initial model and posterior_model loading
    init_config = BASE_CONFIG.copy()

    # Ensure prompt/caption settings are valid for model_loader for initialization
    # Forcing prompt mode for initial load.
    init_config["prompt"] = [BASE_CONFIG.get("prompt", "Initialization prompt")]
    init_config["caption_file"] = None

    # Default values that might be needed by model_loader or utils called within
    init_config.setdefault("target_file", "dummy_target.png")
    init_config.setdefault("result_file", "dummy_results/")
    init_config.setdefault("seed", random.randint(0, 2**32 - 1))  # Init with a random seed

    print("Loading base model and variational posterior model once...")
    # MODEL is the main diffusion model, loaded once.
    # inp_kwargs_for_init are based on init_config, not directly used for subsequent inferences.
    model_obj, _ = model_loader.load_model(init_config, device=DEVICES)
    MODEL = model_obj

    # Initialize VariationalPosterior once with the loaded MODEL and init_config.
    # Its internal forward_operator will be based on init_config's degradation settings,
    # but will be replaced in each inpaint_image call.
    POSTERIOR_MODEL = var_post_samp.VariationalPosterior(MODEL, init_config)
    print("Global initialization complete.")


def load_config_for_inference(prompt_text, seed=None):
    # This function is now for creating a temporary config for each inference call,
    # primarily to get up-to-date inp_kwargs via model_loader.
    # It starts from BASE_CONFIG and applies current overrides.
    if BASE_CONFIG is None:
        raise RuntimeError("Base config not initialized. Call initialize_globals().")

    current_config = BASE_CONFIG.copy()

    current_config["prompt"] = [prompt_text]  # Override with user's prompt
    current_config["caption_file"] = None  # Ensure we are in prompt mode

    if seed is None:
        seed = current_config.get("seed", random.randint(0, 2**32 - 1))
    current_config["seed"] = seed
    # Set global seeds for reproducibility for the current call
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    print(f"Using seed for current inference: {seed}")

    # Ensure other necessary fields are in 'current_config' if model_loader needs them
    current_config.setdefault("target_file", "dummy_target.png")
    current_config.setdefault("result_file", "dummy_results/")

    return current_config

def preprocess_image(pil_image, resolution, is_mask=False):
    img = pil_image.convert("RGB") if not is_mask else pil_image.convert("L")

    # Calculate new dimensions to maintain aspect ratio, making shorter edge 'resolution'
    original_width, original_height = img.size
    if original_width < original_height:
        new_short_edge = resolution
        new_long_edge = int(resolution * (original_height / original_width))
        new_width = new_short_edge
        new_height = new_long_edge
    else:
        new_short_edge = resolution
        new_long_edge = int(resolution * (original_width / original_height))
        new_height = new_short_edge
        new_width = new_long_edge

    # TF.resize expects [height, width]
    img = TF.resize(img, [new_height, new_width], interpolation=TF.InterpolationMode.LANCZOS)

    # Center crop to the target square resolution
    img = TF.center_crop(img, [resolution, resolution])

    img_tensor = TF.to_tensor(img)  # Scales to [0, 1]
    if is_mask:
        # Ensure mask is binary (0 or 1), 1 for region to inpaint
        # The mask from ImageEditor is RGBA, convert to L first.
        img = img.convert('L')
        img_tensor = TF.to_tensor(img)  # Recalculate tensor after convert
        img_tensor = (img_tensor == 0.)  # Threshold for mask (drawn parts are usually non-black)
        img_tensor = img_tensor.repeat(3, 1, 1)  # Repeat mask across 3 channels
    else:
        # Normalize image to [-1, 1]
        img_tensor = img_tensor * 2 - 1
    return img_tensor.unsqueeze(0)  # Add batch dimension

def preprocess_lr_image(pil_image, resolution, device, dtype):
    if pil_image is None:
        raise ValueError("Input PIL image cannot be None.")
    img = pil_image.convert("RGB")

    # Center crop to the target square resolution (no resizing)
    img = TF.center_crop(img, [resolution, resolution])

    img_tensor = TF.to_tensor(img)  # Scales to [0, 1]
    # Normalize image to [-1, 1]
    img_tensor = img_tensor * 2 - 1
    return img_tensor.unsqueeze(0).to(device, dtype=dtype)  # Add batch dimension and move to device


def postprocess_image(image_tensor):
    # Remove batch dimension, move to CPU, convert to float
    image_tensor = image_tensor.squeeze(0).cpu().float()
    # Denormalize from [-1, 1] to [0, 1]
    image_tensor = image_tensor * 0.5 + 0.5
    # Clip values to [0, 1]
    image_tensor = torch.clamp(image_tensor, 0, 1)
    # Convert to PIL Image
    pil_image = TF.to_pil_image(image_tensor)
    return pil_image

def inpaint_image(image_editor_output, prompt_text, fixed_seed_value, use_random_seed, guidance_scale, num_steps):  # MODIFIED: seed_input changed to fixed_seed_value, use_random_seed
    try:
        if image_editor_output is None:
            raise gr.Error("Please upload an image and draw a mask.")

        input_pil = image_editor_output['background']

        if not image_editor_output['layers'] or image_editor_output['layers'][0] is None:
            raise gr.Error("Please draw a mask on the image using the brush tool.")
        mask_pil = image_editor_output['layers'][0]

        if input_pil is None:
            raise gr.Error("Please upload an image.")
        if mask_pil is None:
            raise gr.Error("Please draw a mask on the image.")

        current_seed = None
        if use_random_seed:
            current_seed = None  # load_config_for_inference will generate a random seed
        else:
            try:
                current_seed = int(fixed_seed_value)
            except ValueError:
                # This should ideally not happen with a slider, but good for robustness
                raise gr.Error("Seed must be an integer.")

        # Prepare config for current inference (gets prompt, seed)
        current_config = load_config_for_inference(prompt_text, current_seed)
        resolution = current_config["resolution"]

        # MODIFIED: Set num_steps from slider into the current_config
        # Assuming 'num_steps' is a key POSTERIOR_MODEL will use from its config.
        # Common alternatives could be current_config['solver_kwargs']['n_steps'] = num_steps
        current_config['n_steps'] = int(num_steps)
        print(f"Using num_steps: {current_config['n_steps']}")

        # Preprocess image and mask
        guidance_img_tensor = preprocess_image(input_pil, resolution, is_mask=False).to(PRIMARY_DEVICE, dtype=DTYPE)
        # Mask from ImageEditor is RGBA, preprocess_image will handle conversion to L and then binary
        mask_tensor = preprocess_image(mask_pil, resolution, is_mask=True).to(PRIMARY_DEVICE, dtype=DTYPE)

        # Get inp_kwargs for the CURRENT prompt and config.
        print("Preparing inference inputs (e.g., prompt embeddings)...")
        prompt_embeds = embed_prompt(prompt_text, device=PRIMARY_DEVICE)  # Embed the prompt for the current inference
        current_inp_kwargs = prompt_embeds
        # MODIFIED: Use guidance_scale from slider
        current_inp_kwargs['guidance'] = float(guidance_scale)
        print(f"Using guidance_scale: {current_inp_kwargs['guidance']}")

        # Update the global POSTERIOR_MODEL's config for this call.
        # This ensures its methods use the latest settings (like num_steps) if they access self.config.
        POSTERIOR_MODEL.config = current_config
        POSTERIOR_MODEL.model._guidance_scale = guidance_scale
        print("Applying forward operator (masking)...")
        # Directly set the forward_operator on the global POSTERIOR_MODEL instance
        # H and W are height and width of the guidance image tensor
        POSTERIOR_MODEL.forward_operator = degradations.Inpainting(
            mask=mask_tensor.bool()[0],  # Inpainting often expects a boolean mask
            H=guidance_img_tensor.shape[2],
            W=guidance_img_tensor.shape[3],
            noise_std=0,
        )
        y = POSTERIOR_MODEL.forward_operator(guidance_img_tensor)

        print("Running inference...")
        with torch.no_grad():
            # Use the global POSTERIOR_MODEL instance
            result_dict = POSTERIOR_MODEL.forward(y, current_inp_kwargs)

        x_hat = result_dict["x_hat"]

        print("Postprocessing result...")
        output_pil = postprocess_image(x_hat)

        # Convert mask tensor to PIL image for display
        # Mask tensor is [0, 1], take one channel, convert to PIL
        mask_display_tensor = mask_tensor.squeeze(0).cpu().float()  # Remove batch, move to CPU
        # If mask_tensor was (B, 3, H, W) and binary 0 or 1 (after repeat)
        # We can take any channel, e.g., mask_display_tensor[0]
        # Ensure it's (H, W) or (1, H, W) for to_pil_image
        if mask_display_tensor.ndim == 3 and mask_display_tensor.shape[0] == 3:  # (C, H, W)
            mask_display_tensor = mask_display_tensor[0]  # Take one channel (H, W)

        # Ensure it's in the range [0, 1] and suitable for PIL conversion
        # If it was 0. for masked and 1. for unmasked (or vice-versa depending on logic)
        # TF.to_pil_image expects [0,1] for single channel float
        mask_pil_display = TF.to_pil_image(mask_display_tensor)

        return output_pil, [output_pil, output_pil], current_config["seed"]  # MODIFIED: Removed mask_pil_display
    except gr.Error as e:  # Handle Gradio-specific errors first
        raise
    except Exception as e:
        print(f"Error during inpainting: {e}")
        import traceback  # Ensure traceback is imported here if not globally
        traceback.print_exc()
        # Return a more user-friendly error message to Gradio
        raise gr.Error(f"An error occurred: {str(e)}. Check console for details.")

def super_resolution_image(lr_image, prompt_text, fixed_seed_value, use_random_seed, guidance_scale, num_steps, sr_scale_factor, downscale_input):
    try:
        if lr_image is None:
            raise gr.Error("Please upload a low-resolution image.")

        current_seed = None
        if use_random_seed:
            current_seed = random.randint(0, 2**32 - 1)
        else:
            try:
                current_seed = int(fixed_seed_value)
            except ValueError:
                raise gr.Error("Seed must be an integer.")

        # Load Super-Resolution specific configuration
        if not os.path.exists(SR_CONFIG_FILE_PATH):
            raise gr.Error(f"Super-resolution config file not found: {SR_CONFIG_FILE_PATH}")
        with open(SR_CONFIG_FILE_PATH, "r") as f:
            sr_base_config = yaml.safe_load(f)

        current_sr_config = copy.deepcopy(sr_base_config)  # Start with a copy of the base SR config
        current_sr_config["prompt"] = [prompt_text]
        current_sr_config["caption_file"] = None  # Ensure prompt mode
        current_sr_config["seed"] = current_seed

        torch.manual_seed(current_seed)
        np.random.seed(current_seed)
        random.seed(current_seed)
        print(f"Using seed for SR inference: {current_seed}")

        current_sr_config['n_steps'] = int(num_steps)
        current_sr_config["degradation"]["kwargs"]["scale"] = sr_scale_factor
        current_sr_config["optimizer_dataterm"]["kwargs"]["lr"] = sr_base_config.get("optimizer_dataterm", {}).get("kwargs", {}).get("lr") * sr_scale_factor**2 / (sr_base_config.get("degradation", {}).get("kwargs", {}).get("scale")**2)
        print(f"Using num_steps for SR: {current_sr_config['n_steps']}")

        # Determine target HR resolution for the output
        hr_resolution = current_sr_config.get("degradation", {}).get("kwargs", {}).get("img_size")
        # Calculate target LR dimensions based on the chosen scale factor
        target_lr_width = int(hr_resolution / sr_scale_factor)
        target_lr_height = int(hr_resolution / sr_scale_factor)
        print(f"Target LR dimensions for SR: {target_lr_width}x{target_lr_height} for scale x{sr_scale_factor}")

        print("Preparing SR inference inputs (prompt embeddings)...")
        prompt_embeds = embed_prompt(prompt_text, device=PRIMARY_DEVICE)
        current_inp_kwargs = prompt_embeds
        current_inp_kwargs['guidance'] = float(guidance_scale)
        print(f"Using guidance_scale for SR: {current_inp_kwargs['guidance']}")

        POSTERIOR_MODEL.config = current_sr_config
        POSTERIOR_MODEL.model._guidance_scale = float(guidance_scale)

        print("Applying SR forward operator...")

        POSTERIOR_MODEL.forward_operator = degradations.SuperResGradio(
            **current_sr_config["degradation"]["kwargs"]
        )

        if downscale_input:
            y_tensor = preprocess_lr_image(lr_image, hr_resolution, PRIMARY_DEVICE, DTYPE)
            # y_tensor = POSTERIOR_MODEL.forward_operator(y_tensor)
            y_tensor = torch.nn.functional.interpolate(y_tensor, scale_factor=1/sr_scale_factor, mode='bilinear', align_corners=False, antialias=True)
            # simulate 8bit input by quantizing to 8-bit
            y_tensor = ((y_tensor * 127.5 + 127.5).clamp(0, 255).to(torch.uint8) / 127.5 - 1.0).to(DTYPE)
        else:
            # check if the input image has the correct dimensions
            if lr_image.size[0] != target_lr_width or lr_image.size[1] != target_lr_height:
                raise gr.Error(f"Input image must be {target_lr_width}x{target_lr_height} pixels for the selected scale factor of {sr_scale_factor}.")
            y_tensor = preprocess_lr_image(lr_image, target_lr_width, PRIMARY_DEVICE, DTYPE)
            # add some noise to the input image
            noise_std = current_sr_config.get("degradation", {}).get("kwargs", {}).get("noise_std", 0.0)
            y_tensor += torch.randn_like(y_tensor) * noise_std
        # save for debugging purposes
        # first convert to PIL
        pil_y = postprocess_image(y_tensor)  # Remove batch dimension and convert to PIL
        pil_y.save("debug_input_image.png")  # Save the input image for debugging

        print("Running SR inference...")
        with torch.no_grad():
            result_dict = POSTERIOR_MODEL.forward(y_tensor, current_inp_kwargs)

        x_hat = result_dict["x_hat"]

        print("Postprocessing SR result...")
        output_pil = postprocess_image(x_hat)

        # Upscale input image with nearest neighbor for comparison
        upscaled_input = y_tensor.reshape(1, 3, target_lr_height, target_lr_width)
        upscaled_input = POSTERIOR_MODEL.forward_operator.nn(upscaled_input)  # Use nearest neighbor upscaling
        upscaled_input = postprocess_image(upscaled_input)
        # save for debugging purposes
        upscaled_input.save("debug_upscaled_input.png")  # Save the upscaled input image for debugging
        # upscaled_input = upscaled_input.resize((hr_resolution, hr_resolution), resample=Image.NEAREST)
        return (upscaled_input, output_pil), current_sr_config["seed"]

    except gr.Error as e:
        raise
    except Exception as e:
        print(f"Error during super-resolution: {e}")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"An error occurred during super-resolution: {str(e)}. Check console for details.")


# Input for seed, allowing users to set it or leave it blank for random/config default
# Determine default num_steps from BASE_CONFIG if available
default_num_steps = 50  # Fallback default
if BASE_CONFIG is not None:  # Check if BASE_CONFIG has been initialized
    default_num_steps = BASE_CONFIG.get("num_steps", BASE_CONFIG.get("solver_kwargs", {}).get("num_steps", 50))

def superres_preview_preprocess(pil_image, resolution=768):
    if pil_image is None:
        return None
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    # check if image is smaller than resolution
    original_width, original_height = pil_image.size
    if original_width < resolution or original_height < resolution:
        return pil_image  # No resizing needed, return original image
    else:
        pil_image = TF.center_crop(pil_image, [resolution, resolution])
    return pil_image


# Dynamically load examples from demo_images directory
example_list_inp = []
example_list_sr = []
demo_images_dir = os.path.join(project_root, "static/demo_images")

if os.path.exists(demo_images_dir):
    filenames = sorted(os.listdir(demo_images_dir))
    processed_bases = set()
    for filename in filenames:
        if filename.startswith("demo_") and filename.endswith("_meta.json"):
            base_name = filename[:-len("_meta.json")]  # e.g., "demo_0"
            if base_name in processed_bases:
                continue

            meta_path = os.path.join(demo_images_dir, filename)
            image_filename = f"{base_name}_image.png"
            image_path = os.path.join(demo_images_dir, image_filename)
            mask_filename = f"{base_name}_mask.png"
            mask_path = os.path.join(demo_images_dir, mask_filename)

            if os.path.exists(image_path):
                try:
                    with open(meta_path, 'r') as f:
                        metadata = json.load(f)
                    task = metadata.get("task_type")
                    prompt = metadata.get("prompt", "")
                    if task == "Super Resolution":
                        example_list_sr.append([image_path, prompt, task])
                    else:
                        image_editor_input = {
                            "background": image_path,
                            "layers": [mask_path],
                            "composite": None  # Add this key to satisfy ImageEditor's as_example processing
                        }
                        example_list_inp.append([image_editor_input, prompt, task])

                    # Structure for ImageEditor: { "background": filepath, "layers": [filepath], "composite": None }

                except json.JSONDecodeError:
                    print(f"Warning: Could not decode JSON from {meta_path}. Skipping example {base_name}.")
                except Exception as e:
                    print(f"Warning: Error processing example {base_name}: {e}. Skipping.")
            else:
                missing_files = []
                if not os.path.exists(image_path):
                    missing_files.append(image_filename)
                if not os.path.exists(mask_path):
                    missing_files.append(mask_filename)
                print(f"Warning: Missing files for example {base_name} ({', '.join(missing_files)}). Skipping.")
else:
    print(f"Info: 'demo_images' directory not found at {demo_images_dir}. No dynamic examples will be loaded.")


if __name__ == "__main__":
    if not os.path.exists(CONFIG_FILE_PATH):
        print(f"ERROR: Configuration file not found at {CONFIG_FILE_PATH}")
        sys.exit(1)

    initialize_globals()

    if MODEL is None or POSTERIOR_MODEL is None:
        print("ERROR: Global model initialization failed.")
        sys.exit(1)

    # --- Define Gradio UI using gr.Blocks after globals are initialized ---
    title_str = "Solving Inverse Problems with FLAIR: Inpainting Demo"
    description_str = """
    Select a task (Inpainting or Super Resolution) and upload an image.
    For Inpainting, draw a mask on the image to specify the area to be filled. We observed that our model can even solve simple editing tasks if provided with an appropriate prompt. For large masks, the number of steps might need to be increased, e.g. to 80.
    For Super Resolution, upload a low-resolution image and select the upscaling factor. Images are always upscaled to 768x768 pixels. Therefore, for x12 super-resolution, the input image must be 64x64 pixels. You can also upload a high-resolution image, which will be downscaled to the correct input size.
    Use the slider to compare the low-resolution input image with the super-resolved output.
    """

    # Determine default values now that BASE_CONFIG is initialized
    default_num_steps = BASE_CONFIG.get("num_steps", BASE_CONFIG.get("solver_kwargs", {}).get("num_steps", 50))
    default_guidance_scale = BASE_CONFIG.get("guidance", 2.0)

    with gr.Blocks() as iface:
        gr.Markdown(f"## {title_str}")
        gr.Markdown(description_str)

        task_selector = gr.Dropdown(
            choices=["Inpainting", "Super Resolution"],
            value="Inpainting",
            label="Task"
        )

        with gr.Row():
            with gr.Column(scale=1):  # Input column
                # Inpainting Inputs
                image_editor = gr.ImageEditor(
                    type="pil",
                    label="Upload Image & Draw Mask (for Inpainting)",
                    sources=["upload"],
                    height=512,
                    width=512,
                    visible=True
                )

                # Super Resolution Inputs
                image_input = gr.Image(
                    type="pil",
                    label="Upload Low-Resolution Image (for Super Resolution)",
                    visible=False
                )

                sr_scale_slider = gr.Dropdown(
                    choices=[2, 4, 8, 12, 24],
                    value=12,
                    label="Upscaling Factor (Super Resolution)",
                    interactive=True,
                    visible=False  # Initially hidden
                )
                downscale_input = gr.Checkbox(
                    label="Downscale the provided image.",
                    value=True,
                    interactive=True,
                    visible=False  # Initially hidden
                )

                # Common Inputs
                prompt_text = gr.Textbox(
                    label="Prompt",
                    placeholder="E.g., a beautiful landscape, a detailed portrait"
                )
                seed_slider = gr.Slider(
                    minimum=0,
                    maximum=2**32 - 1,  # Max for torch.manual_seed
                    step=1,
                    label="Seed (if not random)",
                    value=42,
                    interactive=True
                )

                use_random_seed_checkbox = gr.Checkbox(
                    label="Use Random Seed",
                    value=True,
                    interactive=True
                )
                guidance_scale_slider = gr.Slider(
                    minimum=1.0,
                    maximum=15.0,
                    step=0.5,
                    value=default_guidance_scale,
                    label="Guidance Scale"
                )
                num_steps_slider = gr.Slider(
                    minimum=28,
                    maximum=150,
                    step=1,
                    value=default_num_steps,
                    label="Number of Steps"
                )
                submit_button = gr.Button("Submit")

                # # Add Save Configuration button and status text
                # gr.Markdown("---")  # Separator
                # save_button = gr.Button("Save Current Configuration for Demo")
                # save_status_text = gr.Markdown()

            with gr.Column(scale=1):  # Output column
                output_image_display = gr.Image(type="pil", label="Result")
                sr_compare_display = ImageSlider(label="Super-Resolution: Input vs Output", visible=False, position=0.5)

        # --- Task routing and visibility logic ---
        def update_visibility(task):
            is_inpainting = task == "Inpainting"
            is_super_resolution = task == "Super Resolution"
            return {
                image_editor: gr.update(visible=is_inpainting),
                image_input: gr.update(visible=is_super_resolution),
                sr_scale_slider: gr.update(visible=is_super_resolution),
                downscale_input: gr.update(visible=is_super_resolution),
                output_image_display: gr.update(visible=is_inpainting),
                sr_compare_display: gr.update(visible=is_super_resolution, position=0.5),
            }

        task_selector.change(
            fn=update_visibility,
            inputs=[task_selector],
            outputs=[image_editor, image_input, sr_scale_slider, downscale_input, output_image_display, sr_compare_display]
        )

        # MODIFIED route_task to accept sr_scale_factor
        def route_task(task, image_editor_data, lr_image_for_sr, prompt_text, fixed_seed_value, use_random_seed, guidance_scale, num_steps, sr_scale_factor_value, downscale_input):
            if task == "Inpainting":
                return inpaint_image(image_editor_data, prompt_text, fixed_seed_value, use_random_seed, guidance_scale, num_steps)
            elif task == "Super Resolution":
                result_images, seed_val = super_resolution_image(
                    lr_image_for_sr, prompt_text, fixed_seed_value, use_random_seed,
                    guidance_scale, num_steps, sr_scale_factor_value, downscale_input
                )
                return result_images[1], gr.update(value=result_images, position=0.5), seed_val
            else:
                raise gr.Error("Unsupported task.")

        submit_button.click(
            fn=route_task,
            inputs=[
                task_selector,
                image_editor,
                image_input,
                prompt_text,
                seed_slider,
                use_random_seed_checkbox,
                guidance_scale_slider,
                num_steps_slider,
                sr_scale_slider,
                downscale_input,
            ],
            outputs=[
                output_image_display,
                sr_compare_display,
                seed_slider
            ]
        )

        # Wire up the save button
        # save_button.click(
        #     fn=save_configuration,
        #     inputs=[
        #         image_editor,
        #         image_input,
        #         prompt_text,
        #         seed_slider,
        #         task_selector,
        #         use_random_seed_checkbox,
        #         num_steps_slider,
        #     ],
        #     outputs=[save_status_text]
        # )

        gr.Markdown("---")  # Separator
        gr.Markdown("### Click an example to load:")
        def load_example(example_data, prompt, task):
            image_editor_input = example_data[0]
            prompt_value = example_data[1]
            if task == "Inpainting":
                image_editor.clear()  # Clear current image and mask
                if image_editor_input and image_editor_input.get("background"):
                    image_editor.upload_image(image_editor_input["background"])
                if image_editor_input and image_editor_input.get("layers"):
                    for layer in image_editor_input["layers"]:
                        image_editor.upload_mask(layer)
            elif task == "Super Resolution":
                image_input.clear()
                image_input.upload_image(image_editor_input)

            # Set the prompt
            prompt_text.value = prompt_value
            # Optionally, set a random seed and guidance scale
            seed_slider.value = random.randint(0, 2**32 - 1)
            guidance_scale_slider.value = default_guidance_scale
            # Set the task selector from the example
            task_selector.set_value(task)
            update_visibility(task)  # Update visibility based on task

        with gr.Row():
            gr.Examples(
                examples=example_list_sr,
                inputs=[image_input, prompt_text, task_selector],
                label="Super Resolution Examples",
                fn=load_example,
            )
        with gr.Row():
            gr.Examples(
                examples=example_list_inp,
                inputs=[image_editor, prompt_text, task_selector],
                label="Inpainting Examples",
                fn=load_example,
            )

        # --- End of Gradio UI definition ---

    print("Launching Gradio demo...")
    iface.launch()
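For reference, each demo example committed further below pairs a `demo_N_meta.json` with `demo_N_image.png` and, for inpainting, `demo_N_mask.png`; `app.py`'s example loader and `save_configuration` use exactly these file names and JSON keys. Note that the loader above scans `static/demo_images` while this commit places the files under `demo_images/`, so the two paths may need to agree for the examples to appear. A minimal sketch of reading one committed example (the local path is hypothetical):

```python
import json
from PIL import Image

# Sketch: read one committed demo example in the layout app.py expects.
base = "demo_images/demo_1"                 # hypothetical path inside a checkout of this repo
with open(f"{base}_meta.json") as f:
    meta = json.load(f)                     # keys written by save_configuration()

image = Image.open(f"{base}_image.png")
mask = Image.open(f"{base}_mask.png") if meta["task_type"] == "Inpainting" else None
print(meta["prompt"], meta["num_steps"], meta["task_type"])
```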
assets/teaser3.svg
ADDED
configs/inpainting.yaml
ADDED
@@ -0,0 +1,45 @@
# === Model & Data Settings ===
model: "SD3"
resolution: 768
lambda_func: v
optimized_reg_weight: ./LCFM/SD3_loss_v_MSE_DIV2k_neg_prompt.npy
regularizer_weight: 0.5
likelihood_weight_mode: reg_weight
likelihood_steps: 15
early_stopping: 1.e-4
epochs: 1
guidance: 2
quantize: False
use_tiny_ae: True
negative_prompt: ""
reg-shift: 0.0

# === Optimization Settings ===
optimizer:
  name: SGD
  kwargs:
    lr: 1
optimizer_dataterm:
  name: SGD
  kwargs:
    lr: 0.1

# === Sampling & Misc ===
t_sampling: descending
n_steps: 50
inv_alpha: 1-t
ts_min: 0.18
projection: False
seed: 42

# === Experiment-Specific Settings ===
prompt: A high quality photo of a girl with a pirate eye-patch.
degradation:
  name: Inpainting
  kwargs:
    mask: [128, 640, 384, 640]
    H: 768
    W: 768
    noise_std: 0.01
    likelihood_weight: 1
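`app.py` loads these YAML files with `yaml.safe_load` and overrides a few fields (prompt, seed, `n_steps`) per request. A small sketch of that same pattern, using the config above, is given below; the field names come from this file and from `app.py`, while the concrete values are illustrative.

```python
import random
import numpy as np
import torch
import yaml

# Sketch of the config-override pattern used in app.py: load the shipped YAML,
# then adjust prompt / seed / number of steps before handing it to the sampler.
with open("configs/inpainting.yaml", "r") as f:
    cfg = yaml.safe_load(f)

cfg["prompt"] = ["a high quality photo of a face"]        # app.py wraps the prompt in a list
cfg["seed"] = cfg.get("seed", random.randint(0, 2**32 - 1))
cfg["n_steps"] = 80                                        # e.g. more steps for large masks

torch.manual_seed(cfg["seed"])
np.random.seed(cfg["seed"])
print(cfg["degradation"]["name"], cfg["n_steps"])
```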
configs/inpainting_gradio.yaml
ADDED
@@ -0,0 +1,45 @@
# === Model & Data Settings ===
model: "SD3"
resolution: 512
lambda_func: v
optimized_reg_weight: ./LCFM/SD3_loss_v_MSE_DIV2k_neg_prompt.npy
regularizer_weight: 0.5
likelihood_weight_mode: reg_weight
likelihood_steps: 15
early_stopping: 1.e-4
epochs: 1
guidance: 2
quantize: False
use_tiny_ae: True
negative_prompt: ""
reg-shift: 0.0

# === Optimization Settings ===
optimizer:
  name: SGD
  kwargs:
    lr: 1
optimizer_dataterm:
  name: SGD
  kwargs:
    lr: 0.1

# === Sampling & Misc ===
t_sampling: descending
n_steps: 50
inv_alpha: 1-t
ts_min: 0.18
projection: False
seed: 42

# === Experiment-Specific Settings ===
prompt: A high quality photo of a girl with a pirate eye-patch.
degradation:
  name: Inpainting
  kwargs:
    mask: [128, 640, 384, 640]
    H: 512
    W: 512
    noise_std: 0.0
    likelihood_weight: 1
configs/motion_deblur.yaml
ADDED
@@ -0,0 +1,43 @@
# === Model & Data Settings ===
model: "SD3"
resolution: 768
lambda_func: v
optimized_reg_weight: SD3_loss_v_norm_div2k_squared.npy
regularizer_weight: 0.5
likelihood_weight_mode: reg_weight
likelihood_steps: 15
early_stopping: 1.e-4
epochs: 1
guidance: 2
quantize: False
use_tiny_ae: True
negative_prompt: ""
reg-shift: 0.0

# === Optimization Settings ===
optimizer:
  name: SGD
  kwargs:
    lr: 1
optimizer_dataterm:
  name: SGD
  kwargs:
    lr: 0.1

# === Sampling & Misc ===
t_sampling: descending
n_steps: 50
inv_alpha: 1-t
ts_min: 0.18
projection: False
seed: 42

# === Experiment-Specific Settings ===
prompt: a high quality photo of a face
degradation:
  name: MotionBlur
  kwargs:
    kernel_size: 61
    noise_std: 0.01
    img_size: 768
    likelihood_weight: 1
configs/x12.yaml
ADDED
@@ -0,0 +1,48 @@
# OmegaConf base config for neurips experiments
# Place common settings here. Each experiment config will override as needed.

# === Model & Data Settings ===
model: "SD3"
resolution: 768
lambda_func: v
optimized_reg_weight: ./LCFM/SD3_loss_v_MSE_DIV2k_neg_prompt.npy
regularizer_weight: 0.5
likelihood_weight_mode: reg_weight
likelihood_steps: 15
early_stopping: 1.e-4
epochs: 1
guidance: 2
quantize: False
use_tiny_ae: True
negative_prompt: ""
reg-shift: 0.0

# === Optimization Settings ===
optimizer:
  name: SGD
  kwargs:
    lr: 1
optimizer_dataterm:
  name: SGD
  kwargs:
    lr: 12

# === Sampling & Misc ===
t_sampling: descending
n_steps: 50
inv_alpha: 1-t
quantize: False
ts_min: 0.18
projection: False
seed: 3

# === Experiment-Specific Settings ===
prompt: a high quality photo of
# Super-resolution x12 for FFHQ
degradation:
  name: SuperRes
  kwargs:
    scale: 12
    noise_std: 0.01
    img_size: 768
    likelihood_weight: 1
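With `img_size: 768` and `scale: 12` above, the expected low-resolution input is 64x64 pixels; this is how `app.py` derives the required input size and, in "downscale input" mode, simulates the degraded observation. A small sketch of that relationship, mirroring the corresponding lines in `app.py` with a stand-in input tensor:

```python
import torch
import torch.nn.functional as F

# Mirrors app.py: the LR size is img_size / scale, and the "downscale input" path
# bilinearly downsamples and quantizes to 8 bit before sampling.
img_size, scale = 768, 12
lr_size = int(img_size / scale)                       # 64 for the x12 config
x = torch.rand(1, 3, img_size, img_size) * 2 - 1      # stand-in HR image in [-1, 1]
y = F.interpolate(x, scale_factor=1 / scale, mode="bilinear",
                  align_corners=False, antialias=True)
y = ((y * 127.5 + 127.5).clamp(0, 255).to(torch.uint8) / 127.5 - 1.0)
print(lr_size, y.shape)  # 64, torch.Size([1, 3, 64, 64])
```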
configs/x12_gradio.yaml
ADDED
@@ -0,0 +1,48 @@
# OmegaConf base config for neurips experiments
# Place common settings here. Each experiment config will override as needed.

# === Model & Data Settings ===
model: "SD3"
resolution: 768
lambda_func: v
optimized_reg_weight: ./LCFM/SD3_loss_v_MSE_DIV2k_neg_prompt.npy
regularizer_weight: 0.5
likelihood_weight_mode: reg_weight
likelihood_steps: 15
early_stopping: 1.e-4
epochs: 1
guidance: 2
quantize: False
use_tiny_ae: True
negative_prompt: ""
reg-shift: 0.0

# === Optimization Settings ===
optimizer:
  name: SGD
  kwargs:
    lr: 1
optimizer_dataterm:
  name: SGD
  kwargs:
    lr: 12

# === Sampling & Misc ===
t_sampling: descending
n_steps: 50
inv_alpha: 1-t
quantize: False
ts_min: 0.18
projection: False
seed: 3

# === Experiment-Specific Settings ===
prompt: a high quality photo of
# Super-resolution x12 for FFHQ
degradation:
  name: SuperRes
  kwargs:
    scale: 12
    noise_std: 0.0
    img_size: 768
likelihood_weight: 1
demo_images/demo_0_image.png
ADDED
(binary image stored with Git LFS)
demo_images/demo_0_meta.json
ADDED
@@ -0,0 +1,7 @@
{
    "prompt": "a high quality image of a face.",
    "seed_on_slider": 2646030500,
    "use_random_seed_checkbox": true,
    "num_steps": 50,
    "task_type": "Super Resolution"
}
demo_images/demo_1_image.png
ADDED
(binary image stored with Git LFS)
demo_images/demo_1_mask.png
ADDED
(binary image stored with Git LFS)
demo_images/demo_1_meta.json
ADDED
@@ -0,0 +1,7 @@
{
    "prompt": "a house",
    "seed_on_slider": 434714511,
    "use_random_seed_checkbox": false,
    "num_steps": 50,
    "task_type": "Inpainting"
}
demo_images/demo_2_image.png
ADDED
(binary image stored with Git LFS)
demo_images/demo_2_mask.png
ADDED
(binary image stored with Git LFS)
demo_images/demo_2_meta.json
ADDED
@@ -0,0 +1,7 @@
{
    "prompt": "a high quality image of a face.",
    "seed_on_slider": 3211750901,
    "use_random_seed_checkbox": false,
    "num_steps": 50,
    "task_type": "Inpainting"
}
demo_images/demo_3_image.png
ADDED
(binary image stored with Git LFS)
demo_images/demo_3_mask.png
ADDED
(binary image stored with Git LFS)
demo_images/demo_3_meta.json
ADDED
@@ -0,0 +1,7 @@
{
    "prompt": "a girl with bright blue eyes ",
    "seed_on_slider": 4257757956,
    "use_random_seed_checkbox": true,
    "num_steps": 50,
    "task_type": "Inpainting"
}
examples/girl.png
ADDED
(binary image stored with Git LFS)
examples/sunflowers.png
ADDED
(binary image stored with Git LFS)
inference_scripts/run_image_inv.py
ADDED
@@ -0,0 +1,159 @@
# Copyright (c) 2024 <Julius Erbach ETH Zurich>
#
# This file is part of the var_post_samp project and is licensed under the MIT License.
# See the LICENSE file in the project root for more information.
"""
Usage:
    python run_image_inv.py --config <config.yaml>
"""

import os
import sys
import time
import csv
import yaml
import torch
import random
import click
import numpy as np
import tqdm
import datetime
import torchvision

from flair.helper_functions import parse_click_context
from flair.pipelines import model_loader
from flair.utils import data_utils
from flair import var_post_samp


dtype = torch.bfloat16

num_gpus = torch.cuda.device_count()
if num_gpus > 0:
    devices = [f"cuda:{i}" for i in range(num_gpus)]
    primary_device = devices[0]
    print(f"Using devices: {devices}")
    print(f"Primary device for operations: {primary_device}")
else:
    print("No CUDA devices found. Using CPU.")
    devices = ["cpu"]
    primary_device = "cpu"


@click.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True})
@click.option("--config", "config_file_arg", type=click.Path(exists=True), help="Path to the config file")
@click.option("--target_file", type=click.Path(exists=True), help="Path to the target file or folder")
@click.option("--result_folder", type=click.Path(file_okay=False, dir_okay=True, writable=True, resolve_path=True), help="Path to the output folder. It will be created if it doesn't exist.")
@click.option("--mask_file", type=click.Path(exists=True), default=None, help="Path to the mask file npy. Optional used for image inpainting. True pixels are observed.")
@click.pass_context
def main(ctx, config_file_arg, target_file, result_folder, mask_file=None):
    """Main entry point for image inversion and sampling.

    The user must provide either a caption_file (with per-image captions) OR a single prompt for all images in the config YAML file.
    """
    with open(config_file_arg, "r") as f:
        config = yaml.safe_load(f)
    ctx = parse_click_context(ctx)
    config.update(ctx)
    # Read caption_file and prompt from config
    caption_file = config.get("caption_file", None)
    prompt = config.get("prompt", None)

    # Enforce mutually exclusive caption_file or prompt
    if (not caption_file and not prompt) or (caption_file and prompt):
        raise ValueError("You must provide either 'caption_file' OR 'prompt' (not both) in the config file. See documentation.")

    # wandb removed, so config_dict is just a copy
    config_dict = dict(config)
    torch.manual_seed(config["seed"])
    np.random.seed(config["seed"])
    random.seed(config["seed"])

    # Use config values as-is (no to_absolute_path)
    caption_file = caption_file if caption_file else None

    guidance_img_iterator = data_utils.yield_images(
        target_file, size=config["resolution"]
    )
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    counter = 1
    name = f'results_{config["model"]}_{config["degradation"]["name"]}_resolution_{config["resolution"]}_noise_{config["degradation"]["kwargs"]["noise_std"]}_{timestamp}'
    candidate = os.path.join(name)
    while os.path.exists(candidate):
        candidate = os.path.join(f"{name}_{counter}")
        counter += 1
    output_folders = data_utils.generate_output_structure(
        result_folder,
        [
            candidate,
            f'input_{config["degradation"]["name"]}_resolution_{config["resolution"]}_noise_{config["degradation"]["kwargs"]["noise_std"]}',
            f'target_{config["degradation"]["name"]}_resolution_{config["resolution"]}_noise_{config["degradation"]["kwargs"]["noise_std"]}',
        ],
    )
    config_out = os.path.join(os.path.split(output_folders[0])[0], "config.yaml")
    with open(config_out, "w") as f:
        yaml.safe_dump(config_dict, f)

    source_files = list(data_utils.find_files(target_file, ext="png"))
    num_images = len(source_files)
    print(f"Found {num_images} images.")

    # Load captions
    if caption_file:
        captions = data_utils.load_captions_from_file(caption_file, user_prompt="")
        if not captions:
            sys.exit("Error: No captions were loaded from the provided caption file.")
        if len(captions) != num_images:
            print("Warning: Number of captions does not match number of images.")
        prompts_in_order = [captions.get(os.path.basename(f), "") for f in source_files]
    else:
        # Use the single prompt for all images
        prompts_in_order = [prompt for _ in range(num_images)]

    if any(p == "" for p in prompts_in_order):
        print("Warning: Some images might not have corresponding captions or prompt is empty.")
    config["prompt"] = prompts_in_order

    model, inp_kwargs = model_loader.load_model(config, device=devices)
    if mask_file and config["degradation"]["name"] == "Inpainting":
        config["degradation"]["kwargs"]["mask"] = mask_file
    posterior_model = var_post_samp.VariationalPosterior(model, config)
    guidance_img_iterator = data_utils.yield_images(
        target_file, size=config["resolution"]
    )
    for idx, guidance_img in tqdm.tqdm(enumerate(guidance_img_iterator), total=num_images):
        guidance_img = guidance_img.to(dtype).cuda()
        y = posterior_model.forward_operator(guidance_img)
        tic = time.time()
        with torch.no_grad():
            result_dict = posterior_model.forward(y, inp_kwargs[idx])
        x_hat = result_dict["x_hat"]
        toc = time.time()
        print(f"Runtime: {toc - tic}")
        guidance_img = guidance_img.cuda()
        result_file = output_folders[0].format(idx)
        input_file = output_folders[1].format(idx)
        ground_truth_file = output_folders[2].format(idx)
        x_hat_pil = torchvision.transforms.ToPILImage()(
            x_hat.float()[0].clip(-1, 1) * 0.5 + 0.5
        )
        x_hat_pil.save(result_file)
        try:
            if config["degradation"]["name"] == "SuperRes":
                input_img = posterior_model.forward_operator.nn(y)
            else:
                input_img = posterior_model.forward_operator.pseudo_inv(y)
            input_img_pil = torchvision.transforms.ToPILImage()(
                input_img.float()[0].clip(-1, 1) * 0.5 + 0.5
            )
            input_img_pil.save(input_file)
        except Exception:
            print("Error in pseudo-inverse operation. Skipping input image save.")
        guidance_img_pil = torchvision.transforms.ToPILImage()(
            guidance_img.float()[0] * 0.5 + 0.5
        )
        guidance_img_pil.save(ground_truth_file)


if __name__ == "__main__":
    main()
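The script accepts either a single prompt or a caption_file, and looks prompts up by image basename. Based on scripts/generate_caption.py below, which writes one "filename: caption" line per image, such a file can be parsed into the expected {filename: caption} mapping roughly as follows (an illustrative sketch, not the flair.utils.data_utils implementation):

```python
# Illustrative parser for a caption file in the "filename: caption" format
# produced by scripts/generate_caption.py; run_image_inv.py looks prompts up
# by image basename, so a plain dict keyed by filename is sufficient here.
def parse_caption_file(path):
    captions = {}
    with open(path, "r") as f:
        for line in f:
            if ":" in line:
                filename, caption = line.split(":", 1)
                captions[filename.strip()] = caption.strip()
    return captions

# Example usage (assuming the file exists):
# captions = parse_caption_file("data/val_captions.txt")
# prompt = captions.get("0001.png", "")
```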
requirements.txt
ADDED
@@ -0,0 +1,18 @@
tqdm
torch
click
pyyaml
numpy
torchvision
diffusers
Pillow
scipy
munch
transformers
torchmetrics
scikit-image
opencv-python
sentencepiece
protobuf
accelerate
gradio
scripts/compute_metrics.py
ADDED
@@ -0,0 +1,144 @@
import torchmetrics
import torch
from PIL import Image
import argparse
from flair.utils import data_utils
import os
import tqdm
import torch.nn.functional as F
from torchmetrics.image.kid import KernelInceptionDistance


MAX_BATCH_SIZE = None

@torch.no_grad()
def main(args):
    # Determine device
    if args.device == "cuda" and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # load images
    gt_iterator = data_utils.yield_images(os.path.abspath(args.gt), size=args.resolution)
    pred_iterator = data_utils.yield_images(os.path.abspath(args.pred), size=args.resolution)
    fid_metric = torchmetrics.image.fid.FrechetInceptionDistance(normalize=True).to(device)
    # kid_metric = KernelInceptionDistance(subset_size=args.kid_subset_size, normalize=True).to(device)
    lpips_metric = torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(
        net_type="alex", normalize=False, reduction="mean"
    ).to(device)
    if args.patch_size:
        patch_fid_metric = torchmetrics.image.fid.FrechetInceptionDistance(normalize=True).to(device)
        # patch_kid_metric = KernelInceptionDistance(subset_size=args.kid_subset_size, normalize=True).to(device)
    psnr_list = []
    lpips_list = []
    ssim_list = []
    # iterate over images
    for gt, pred in tqdm.tqdm(zip(gt_iterator, pred_iterator)):
        # Move tensors to the selected device
        gt = gt.to(device)
        pred = pred.to(device)

        # resize gt to pred size
        if gt.shape[-2:] != (args.resolution, args.resolution):
            gt = F.interpolate(gt, size=args.resolution, mode="area")
        if pred.shape[-2:] != (args.resolution, args.resolution):
            pred = F.interpolate(pred, size=args.resolution, mode="area")
        # to range [0,1]
        gt_norm = gt * 0.5 + 0.5
        pred_norm = pred * 0.5 + 0.5
        # compute PSNR
        psnr = torchmetrics.functional.image.peak_signal_noise_ratio(
            pred_norm, gt_norm, data_range=1.0
        )
        psnr_list.append(psnr.cpu())  # Move result to CPU
        # compute LPIPS
        lpips_score = lpips_metric(pred.clip(-1, 1), gt.clip(-1, 1))
        lpips_list.append(lpips_score.cpu())  # Move result to CPU
        # compute SSIM
        ssim = torchmetrics.functional.image.structural_similarity_index_measure(
            pred_norm, gt_norm, data_range=1.0
        )
        ssim_list.append(ssim.cpu())  # Move result to CPU
        print(f"PSNR: {psnr}, LPIPS: {lpips_score}, SSIM: {ssim}")
        # compute FID
        # Ensure inputs are on the correct device (already handled by moving gt/pred earlier)
        fid_metric.update(gt_norm, real=False)
        fid_metric.update(pred_norm, real=True)
        # compute KID
        # kid_metric.update(pred, real=False)
        # kid_metric.update(gt, real=True)
        # compute Patchwise FID/KID if patch_size is specified
        if args.patch_size:
            # Extract patches
            patch_size = args.patch_size
            gt_patches = F.unfold(gt_norm, kernel_size=patch_size, stride=patch_size)
            pred_patches = F.unfold(pred_norm, kernel_size=patch_size, stride=patch_size)
            # Reshape patches: (B, C*P*P, N_patches) -> (B*N_patches, C, P, P)
            B, C, H, W = gt.shape
            N_patches = gt_patches.shape[-1]
            gt_patches = gt_patches.permute(0, 2, 1).reshape(B * N_patches, C, patch_size, patch_size)
            pred_patches = pred_patches.permute(0, 2, 1).reshape(B * N_patches, C, patch_size, patch_size)
            # Update patch FID metric (inputs are already on the correct device)
            # Update patch KID metric
            # process mini batches of patches
            if MAX_BATCH_SIZE is None:
                patch_fid_metric.update(pred_patches, real=False)
                patch_fid_metric.update(gt_patches, real=True)
                # patch_kid_metric.update(pred_patches, real=False)
                # patch_kid_metric.update(gt_patches, real=True)
            else:
                for i in range(0, N_patches, MAX_BATCH_SIZE):
                    patch_fid_metric.update(pred_patches[i:i + MAX_BATCH_SIZE], real=False)
                    patch_fid_metric.update(gt_patches[i:i + MAX_BATCH_SIZE], real=True)
                    # patch_kid_metric.update(pred_patches[i:i + MAX_BATCH_SIZE], real=False)
                    # patch_kid_metric.update(gt_patches[i:i + MAX_BATCH_SIZE], real=True)

    # compute FID
    fid = fid_metric.compute()
    # compute KID
    # kid_mean, kid_std = kid_metric.compute()
    if args.patch_size:
        patch_fid = patch_fid_metric.compute()
        # patch_kid_mean, patch_kid_std = patch_kid_metric.compute()
    # compute average metrics (on CPU)
    avg_psnr = torch.mean(torch.stack(psnr_list))
    avg_lpips = torch.mean(torch.stack(lpips_list))
    avg_ssim = torch.mean(torch.stack(ssim_list))
    # compute standard deviation (on CPU)
    std_psnr = torch.std(torch.stack(psnr_list))
    std_lpips = torch.std(torch.stack(lpips_list))
    std_ssim = torch.std(torch.stack(ssim_list))
    print(f"PSNR: {avg_psnr} +/- {std_psnr}")
    print(f"LPIPS: {avg_lpips} +/- {std_lpips}")
    print(f"SSIM: {avg_ssim} +/- {std_ssim}")
    print(f"FID: {fid}")  # FID is computed on the selected device, print directly
    # print(f"KID: {kid_mean} +/- {kid_std}")  # KID is computed on the selected device, print directly
    if args.patch_size:
        print(f"Patch FID ({args.patch_size}x{args.patch_size}): {patch_fid}")  # Patch FID is computed on the selected device, print directly
        # print(f"Patch KID ({args.patch_size}x{args.patch_size}): {patch_kid_mean} +/- {patch_kid_std}")  # Patch KID is computed on the selected device, print directly
    # save to prediction folder
    out_file = os.path.join(args.pred, "fid_metrics.txt")
    with open(out_file, "w") as f:
        f.write(f"PSNR: {avg_psnr.item()} +/- {std_psnr.item()}\n")  # Use .item() for scalar tensors
        f.write(f"LPIPS: {avg_lpips.item()} +/- {std_lpips.item()}\n")
        f.write(f"SSIM: {avg_ssim.item()} +/- {std_ssim.item()}\n")
        f.write(f"FID: {fid.item()}\n")  # Use .item() for scalar tensors
        # f.write(f"KID: {kid_mean.item()} +/- {kid_std.item()}\n")  # Use .item() for scalar tensors
        if args.patch_size:
            f.write(f"Patch FID ({args.patch_size}x{args.patch_size}): {patch_fid.item()}\n")  # Use .item() for scalar tensors
            # f.write(f"Patch KID ({args.patch_size}x{args.patch_size}): {patch_kid_mean.item()} +/- {patch_kid_std.item()}\n")  # Use .item() for scalar tensors


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compute metrics")
    parser.add_argument("--gt", type=str, help="Path to ground truth image")
    parser.add_argument("--pred", type=str, help="Path to predicted image")
    parser.add_argument("--resolution", type=int, default=768, help="resolution at which to evaluate")
    parser.add_argument("--patch_size", type=int, default=None, help="Patch size for Patchwise FID/KID computation (e.g., 12). If None, skip.")
    parser.add_argument("--kid_subset_size", type=int, default=1000, help="Subset size for KID computation.")
    parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "cuda"], help="Device to run computation on (cpu or cuda)")
    args = parser.parse_args()

    main(args)
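The patch-FID branch relies on F.unfold returning patches as (B, C*P*P, N_patches) and reshaping them into individual (C, P, P) images before feeding them to the metric. A small self-contained sketch of that reshape on a dummy tensor:

```python
import torch
import torch.nn.functional as F

# Dummy-tensor check of the patch extraction used for patch-FID above:
# unfold gives (B, C*P*P, N_patches); permute + reshape yields one
# (C, P, P) image per patch.
B, C, H, W = 2, 3, 768, 768
P = 12
x = torch.rand(B, C, H, W)

patches = F.unfold(x, kernel_size=P, stride=P)              # (2, 432, 4096)
N = patches.shape[-1]                                       # (768 // 12) ** 2 = 4096
patches = patches.permute(0, 2, 1).reshape(B * N, C, P, P)  # (8192, 3, 12, 12)
print(patches.shape)
```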
scripts/generate_caption.py
ADDED
@@ -0,0 +1,107 @@
import sys
import os
import argparse
from PIL import Image

# Add the path to the thirdparty/SeeSR directory to the Python path
sys.path.append(os.path.abspath("./thirdparty/SeeSR"))

import torch
from torchvision import transforms
from ram.models.ram_lora import ram
from ram import inference_ram as inference

def load_ram_model(ram_model_path: str, dape_model_path: str):
    """
    Load the RAM model with the given paths.

    Args:
        ram_model_path (str): Path to the pretrained RAM model.
        dape_model_path (str): Path to the pretrained DAPE model.

    Returns:
        torch.nn.Module: Loaded RAM model.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the RAM model
    tag_model = ram(pretrained=ram_model_path, pretrained_condition=dape_model_path, image_size=384, vit="swin_l")
    tag_model.eval()
    return tag_model.to(device)

def generate_caption(image_path: str, tag_model) -> str:
    """
    Generate a caption for a degraded image using the RAM model.

    Args:
        image_path (str): Path to the degraded input image.
        tag_model (torch.nn.Module): Preloaded RAM model.

    Returns:
        str: Generated caption for the image.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Define image transformations
    tensor_transforms = transforms.Compose([
        transforms.ToTensor(),
    ])
    ram_transforms = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Load and preprocess the image
    image = Image.open(image_path).convert("RGB")
    image_tensor = tensor_transforms(image).unsqueeze(0).to(device)
    image_tensor = ram_transforms(image_tensor)

    # Generate caption using the RAM model
    caption = inference(image_tensor, tag_model)

    return caption[0]

def process_images_in_directory(input_dir: str, output_file: str, tag_model):
    """
    Process all images in a directory, generate captions using the RAM model,
    and save the captions to a file.

    Args:
        input_dir (str): Path to the directory containing input images.
        output_file (str): Path to the file where captions will be saved.
        tag_model (torch.nn.Module): Preloaded RAM model.
    """
    # Open the output file for writing captions
    with open(output_file, "w") as f:
        # Iterate through all files in the input directory
        for filename in os.listdir(input_dir):
            # Construct the full path to the image file
            image_path = os.path.join(input_dir, filename)

            # Check if the file is an image
            if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                try:
                    # Generate a caption for the image
                    caption = generate_caption(image_path, tag_model)
                    print(f"Generated caption for {filename}: {caption}")
                    # Write the caption to the output file
                    f.write(f"{filename}: {caption}\n")

                    print(f"Processed {filename}: {caption}")
                except Exception as e:
                    print(f"Error processing {filename}: {e}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate captions for images using RAM and DAPE models.")
    parser.add_argument("--input_dir", type=str, default="data/val", help="Path to the directory containing input images.")
    parser.add_argument("--output_file", type=str, default="data/val_captions.txt", help="Path to the file where captions will be saved.")
    parser.add_argument("--ram_model", type=str, default="/home/erbachj/scratch2/projects/var_post_samp/thirdparty/SeeSR/preset/model/ram_swin_large_14m.pth", help="Path to the pretrained RAM model.")
    parser.add_argument("--dape_model", type=str, default="/home/erbachj/scratch2/projects/var_post_samp/thirdparty/SeeSR/preset/model/DAPE.pth", help="Path to the pretrained DAPE model.")

    args = parser.parse_args()

    # Load the RAM model once
    tag_model = load_ram_model(args.ram_model, args.dape_model)

    # Process images in the directory
    process_images_in_directory(args.input_dir, args.output_file, tag_model)
setup.py
ADDED
@@ -0,0 +1,14 @@
from setuptools import setup, find_packages

setup(
    name="flair",
    version="0.1.0",
    author="Julius Erbach",
    author_email="[email protected]",
    description="Solving Inverse Problems with FLAIR",
    package_dir={"": "src"},
    packages=find_packages(where="src"),
    include_package_data=True,
    python_requires=">=3.7",
    keywords="pytorch variational posterior sampling deep learning",
)
src/flair/__init__.py
ADDED
File without changes
src/flair/degradations.py
ADDED
@@ -0,0 +1,198 @@
import torch
import numpy as np
from munch import munchify
from scipy.ndimage import distance_transform_edt
from flair.functions.degradation import get_degradation
import torchvision

class BaseDegradation(torch.nn.Module):
    def __init__(self, noise_std=0.0):
        super().__init__()
        self.noise_std = noise_std

    def forward(self, x):
        x = x + self.noise_std * torch.randn_like(x)
        return x

    def pseudo_inv(self, y):
        return y


def zero_filler(x, scale):
    B, C, H, W = x.shape
    scale = int(scale)
    H_new, W_new = H * scale, W * scale
    out = torch.zeros(B, C, H_new, W_new, dtype=x.dtype, device=x.device)
    out[:, :, ::scale, ::scale] = x
    return out


class SuperRes(BaseDegradation):
    def __init__(self, scale, noise_std=0.0, img_size=256):
        super().__init__(noise_std=noise_std)
        self.scale = scale
        deg_config = munchify({
            'channels': 3,
            'image_size': img_size,
            'deg_scale': scale
        })
        self.img_size = img_size
        self.deg = get_degradation("sr_bicubic", deg_config, device="cuda")

    def forward(self, x, noise=True):
        dtype = x.dtype
        y = self.deg.A(x.float())
        # add noise
        if noise:
            y = super().forward(y)

        return y.to(dtype)

    def pseudo_inv(self, y):
        x = self.deg.At(y.float()).reshape(1, 3, self.img_size, self.img_size) * self.scale**2
        return x.to(y.dtype)

    def nn(self, y):
        x = torch.nn.functional.interpolate(
            y.reshape(1, 3, self.img_size // self.scale, self.img_size // self.scale), scale_factor=self.scale, mode="nearest"
        )
        return x.to(y.dtype)

class SuperResGradio(BaseDegradation):
    def __init__(self, scale, noise_std=0.0, img_size=256):
        super().__init__(noise_std=noise_std)
        self.scale = scale
        self.downscaler = lambda x: torch.nn.functional.interpolate(
            x.float(), scale_factor=1 / self.scale, mode="bilinear", align_corners=False, antialias=True
        )
        self.upscaler = lambda x: torch.nn.functional.interpolate(
            x.float(), scale_factor=self.scale, mode="bilinear", align_corners=False, antialias=True
        )
        self.img_size = img_size

    def forward(self, x, noise=True):
        dtype = x.dtype
        y = self.downscaler(x.float())
        # add noise
        if noise:
            y = super().forward(y)

        return y.to(dtype)

    def pseudo_inv(self, y):
        x = self.upscaler(y.float())
        return x.to(y.dtype)

    def nn(self, y):
        x = torch.nn.functional.interpolate(
            y.reshape(1, 3, self.img_size // self.scale, self.img_size // self.scale), scale_factor=self.scale, mode="nearest-exact"
        )
        return x.to(y.dtype)


class Inpainting(BaseDegradation):
    def __init__(self, mask, H, W, noise_std=0.0):
        """
        mask: torch.Tensor, shape (H, W), dtype bool
        function assumes 3 channels
        """
        super().__init__(noise_std=noise_std)
        if isinstance(mask, list):
            # generate box from left, right, lower upper list
            # observed region is True
            mask_ = torch.ones(H, W, dtype=torch.bool)
            mask_[slice(*mask[0:2]), slice(*mask[2:])] = False
            # repeat for 3 channels
            mask_ = mask_.repeat(3, 1, 1)
        elif isinstance(mask, str):
            # load mask file
            mask_ = torch.tensor(np.load(mask), dtype=torch.bool)
            mask_ = mask_.repeat(3, 1, 1)
        elif isinstance(mask, torch.Tensor):
            if mask.ndim == 2:
                # assume mask is for one channel, repeat for 3 channels
                mask_ = mask[None].repeat(3, 1, 1)
            elif mask.ndim == 3 and mask.shape[0] == 1:
                # assume mask is for one channel, repeat for 3 channels
                mask_ = mask.repeat(3, 1, 1)
            else:
                mask_ = mask
        else:
            raise ValueError("Mask must be a list, string (file path), or torch.Tensor.")
        self.mask = mask_
        self.H, self.W = H, W

    def forward(self, x, noise=True):
        B = x.shape[0]
        y = x[self.mask[None]].view(B, -1)
        # add noise
        if noise:
            y = super().forward(y)
        return y

    def pseudo_inv(self, y):
        x = torch.zeros(y.shape[0], 3 * self.H * self.W, dtype=y.dtype, device=y.device)
        x[:, self.mask.view(-1)] = y
        x = x.view(y.shape[0], 3, self.H, self.W)
        # x = inpaint_nearest(x[0], self.mask[0])[None]
        return x


def inpaint_nearest(image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Fill missing pixels in an image using the nearest observed pixel value.

    Args:
        image: A tensor of shape [C, H, W] representing the image.
        mask: A tensor of shape [H, W] with 1 for observed pixels and 0 for missing pixels.

    Returns:
        A tensor of shape [C, H, W] where missing pixels have been filled.
    """
    # Move tensors to CPU and convert to numpy arrays.
    image_np = image.cpu().float().numpy()
    # Convert mask to boolean: True for observed, False for missing.
    mask_np = mask.cpu().numpy().astype(bool)

    # Compute the distance transform of the inverse mask (~mask_np).
    # The function returns:
    #   - distances: distance to the nearest True pixel in mask_np
    #   - indices: the indices of that nearest True pixel for each pixel.
    # indices has shape (2, H, W): first row is the row index, second row is the column index.
    _, indices = distance_transform_edt(~mask_np, return_indices=True)

    # Create a copy of the image to hold the filled values.
    filled_image_np = np.empty_like(image_np)

    # For each channel, replace every pixel with the value of the nearest observed pixel.
    for c in range(image_np.shape[0]):
        filled_image_np[c] = image_np[c, indices[0], indices[1]]

    # Convert back to a torch tensor and send to the original device.
    return torch.from_numpy(filled_image_np).to(image.device).to(image.dtype)

class MotionBlur(BaseDegradation):
    def __init__(self, kernel_size=5, img_size=256, noise_std=0.0):
        super().__init__(noise_std=noise_std)
        deg_config = munchify({
            'channels': 3,
            'image_size': img_size,
            'deg_scale': kernel_size
        })
        self.img_size = img_size
        self.deg = get_degradation("deblur_motion", deg_config, device="cuda")

    def forward(self, x, noise=True):
        dtype = x.dtype
        y = self.deg.A(x.float())
        # add noise
        if noise:
            y = super().forward(y)
        return y.to(dtype)

    def pseudo_inv(self, y):
        dtype = y.dtype
        x = self.deg.At(y.float()).reshape(1, 3, self.img_size, self.img_size)
        return x.to(dtype)
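A minimal usage sketch of the Inpainting operator above, configured like the inpainting config earlier in this commit (box mask on a 512x512 image); noise_std is set to 0 so the forward pass is deterministic, and no GPU is required for this particular operator:

```python
import torch
from flair.degradations import Inpainting

# Sketch: the Inpainting operator with the box mask from the inpainting config
# (512x512 image, box [128, 640, 384, 640] hidden, no measurement noise).
deg = Inpainting(mask=[128, 640, 384, 640], H=512, W=512, noise_std=0.0)

x = torch.rand(1, 3, 512, 512) * 2 - 1   # image in [-1, 1]
y = deg(x)                               # flat vector of observed pixels, shape (1, N_obs)
x_zero_filled = deg.pseudo_inv(y)        # (1, 3, 512, 512), zeros where masked

print(y.shape, x_zero_filled.shape)
```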
src/flair/functions/__init__.py
ADDED
File without changes
src/flair/functions/ckpt_util.py
ADDED
@@ -0,0 +1,72 @@
import os, hashlib
import requests
from tqdm import tqdm

URL_MAP = {
    "cifar10": "https://heibox.uni-heidelberg.de/f/869980b53bf5416c8a28/?dl=1",
    "ema_cifar10": "https://heibox.uni-heidelberg.de/f/2e4f01e2d9ee49bab1d5/?dl=1",
    "lsun_bedroom": "https://heibox.uni-heidelberg.de/f/f179d4f21ebc4d43bbfe/?dl=1",
    "ema_lsun_bedroom": "https://heibox.uni-heidelberg.de/f/b95206528f384185889b/?dl=1",
    "lsun_cat": "https://heibox.uni-heidelberg.de/f/fac870bd988348eab88e/?dl=1",
    "ema_lsun_cat": "https://heibox.uni-heidelberg.de/f/0701aac3aa69457bbe34/?dl=1",
    "lsun_church": "https://heibox.uni-heidelberg.de/f/2711a6f712e34b06b9d8/?dl=1",
    "ema_lsun_church": "https://heibox.uni-heidelberg.de/f/44ccb50ef3c6436db52e/?dl=1",
}
CKPT_MAP = {
    "cifar10": "diffusion_cifar10_model/model-790000.ckpt",
    "ema_cifar10": "ema_diffusion_cifar10_model/model-790000.ckpt",
    "lsun_bedroom": "diffusion_lsun_bedroom_model/model-2388000.ckpt",
    "ema_lsun_bedroom": "ema_diffusion_lsun_bedroom_model/model-2388000.ckpt",
    "lsun_cat": "diffusion_lsun_cat_model/model-1761000.ckpt",
    "ema_lsun_cat": "ema_diffusion_lsun_cat_model/model-1761000.ckpt",
    "lsun_church": "diffusion_lsun_church_model/model-4432000.ckpt",
    "ema_lsun_church": "ema_diffusion_lsun_church_model/model-4432000.ckpt",
}
MD5_MAP = {
    "cifar10": "82ed3067fd1002f5cf4c339fb80c4669",
    "ema_cifar10": "1fa350b952534ae442b1d5235cce5cd3",
    "lsun_bedroom": "f70280ac0e08b8e696f42cb8e948ff1c",
    "ema_lsun_bedroom": "1921fa46b66a3665e450e42f36c2720f",
    "lsun_cat": "bbee0e7c3d7abfb6e2539eaf2fb9987b",
    "ema_lsun_cat": "646f23f4821f2459b8bafc57fd824558",
    "lsun_church": "eb619b8a5ab95ef80f94ce8a5488dae3",
    "ema_lsun_church": "fdc68a23938c2397caba4a260bc2445f",
}


def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(chunk_size)


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root=None, check=False, prefix='exp'):
    if 'church_outdoor' in name:
        name = name.replace('church_outdoor', 'church')
    assert name in URL_MAP
    # Modify the path when necessary
    cachedir = os.environ.get("XDG_CACHE_HOME", os.path.join(prefix, "logs/"))
    root = (
        root
        if root is not None
        else os.path.join(cachedir, "diffusion_models_converted")
    )
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path
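A short usage sketch, assuming network access to the heibox URLs above: get_ckpt_path resolves the cache location, downloads the checkpoint on first use, and verifies the MD5 after downloading.

```python
from flair.functions.ckpt_util import get_ckpt_path

# Sketch: resolve (and download if missing) the EMA CIFAR-10 checkpoint from
# URL_MAP; with check=True an already-cached file is re-verified against MD5_MAP.
# The cache root defaults to $XDG_CACHE_HOME or exp/logs/.
ckpt_path = get_ckpt_path("ema_cifar10", check=True)
print(ckpt_path)
```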
src/flair/functions/conjugate_gradient.py
ADDED
@@ -0,0 +1,66 @@
# From wikipedia. MATLAB code,
# function x = conjgrad(A, b, x)
#     r = b - A * x;
#     p = r;
#     rsold = r' * r;
#
#     for i = 1:length(b)
#         Ap = A * p;
#         alpha = rsold / (p' * Ap);
#         x = x + alpha * p;
#         r = r - alpha * Ap;
#         rsnew = r' * r;
#         if sqrt(rsnew) < 1e-10
#               break
#         end
#         p = r + (rsnew / rsold) * p;
#         rsold = rsnew;
#     end
# end

from typing import Callable, Optional

import torch


def CG(A: Callable,
       b: torch.Tensor,
       x: torch.Tensor,
       m: Optional[int] = 5,
       eps: Optional[float] = 1e-4,
       damping: float = 0.0,
       use_mm: bool = False) -> torch.Tensor:

    if use_mm:
        mm_fn = lambda x, y: torch.mm(x.view(1, -1), y.view(1, -1).T)
    else:
        mm_fn = lambda x, y: (x * y).flatten().sum()

    orig_shape = x.shape
    x = x.view(x.shape[0], -1)

    r = b - A(x)
    p = r.clone()

    rsold = mm_fn(r, r)
    assert not (rsold != rsold).any(), print(f'NaN detected 1')

    for i in range(m):
        Ap = A(p) + damping * p
        alpha = rsold / mm_fn(p, Ap)
        assert not (alpha != alpha).any(), print(f'NaN detected 2')

        x = x + alpha * p
        r = r - alpha * Ap

        rsnew = mm_fn(r, r)
        assert not (rsnew != rsnew).any(), print('NaN detected 3')

        if rsnew.sqrt().abs() < eps:
            break

        p = r + (rsnew / rsold) * p
        rsold = rsnew.clone()

    return x.reshape(orig_shape)
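A quick sanity-check sketch for CG on a small symmetric positive-definite system, comparing against a direct solve; the operator is passed as a callable acting on the flattened iterate, matching the signature above:

```python
import torch
from flair.functions.conjugate_gradient import CG

# Sketch: solve M x = b for a small SPD matrix M with the CG routine above and
# compare against torch.linalg.solve. The operator is a callable on (1, n) iterates.
torch.manual_seed(0)
n = 8
Q = torch.randn(n, n)
M = Q @ Q.T + n * torch.eye(n)   # symmetric positive definite
b = torch.randn(1, n)

x0 = torch.zeros(1, n)
x_cg = CG(lambda v: v @ M.T, b, x0, m=50, eps=1e-8)
x_ref = torch.linalg.solve(M, b.squeeze(0))

print(torch.allclose(x_cg.squeeze(0), x_ref, atol=1e-4))  # expected: True
```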
src/flair/functions/degradation.py
ADDED
@@ -0,0 +1,211 @@
import numpy as np
import torch
from munch import Munch

from flair.functions import svd_operators as svd_op
from flair.functions import measurements
from flair.utils.inpaint_util import MaskGenerator

__DEGRADATION__ = {}

def register_degradation(name: str):
    def wrapper(fn):
        if __DEGRADATION__.get(name) is not None:
            raise NameError(f'DEGRADATION {name} is already registered')
        __DEGRADATION__[name] = fn
        return fn
    return wrapper

def get_degradation(name: str,
                    deg_config: Munch,
                    device: torch.device):
    if __DEGRADATION__.get(name) is None:
        raise NameError(f'DEGRADATION {name} does not exist.')
    return __DEGRADATION__[name](deg_config, device)

@register_degradation(name='cs_walshhadamard')
def deg_cs_walshhadamard(deg_config, device):
    compressed_size = round(1/deg_config.deg_scale)
    A_funcs = svd_op.WalshHadamardCS(deg_config.channels,
                                     deg_config.image_size,
                                     compressed_size,
                                     torch.randperm(deg_config.image_size**2),
                                     device)
    return A_funcs

@register_degradation(name='cs_blockbased')
def deg_cs_blockbased(deg_config, device):
    cs_ratio = deg_config.deg_scale
    A_funcs = svd_op.CS(deg_config.channels,
                        deg_config.image_size,
                        cs_ratio,
                        device)
    return A_funcs

@register_degradation(name='inpainting')
def deg_inpainting(deg_config, device):
    # TODO: generate mask rather than load
    loaded = np.load("exp/inp_masks/mask_768_half.npy")  # block
    # loaded = np.load("lip_mask_4.npy")
    mask = torch.from_numpy(loaded).to(device).reshape(-1)
    missing_r = torch.nonzero(mask == 0).long().reshape(-1) * 3
    missing_g = missing_r + 1
    missing_b = missing_g + 1
    missing = torch.cat([missing_r, missing_g, missing_b], dim=0)
    A_funcs = svd_op.Inpainting(deg_config.channels,
                                deg_config.image_size,
                                missing,
                                device)
    return A_funcs

@register_degradation(name='denoising')
def deg_denoise(deg_config, device):
    A_funcs = svd_op.Denoising(deg_config.channels,
                               deg_config.image_size,
                               device)
    return A_funcs

@register_degradation(name='colorization')
def deg_colorization(deg_config, device):
    A_funcs = svd_op.Colorization(deg_config.image_size,
                                  device)
    return A_funcs


@register_degradation(name='sr_avgpool')
def deg_sr_avgpool(deg_config, device):
    blur_by = int(deg_config.deg_scale)
    A_funcs = svd_op.SuperResolution(deg_config.channels,
                                     deg_config.image_size,
                                     blur_by,
                                     device)
    return A_funcs

@register_degradation(name='sr_bicubic')
def deg_sr_bicubic(deg_config, device):
    def bicubic_kernel(x, a=-0.5):
        if abs(x) <= 1:
            return (a + 2) * abs(x) ** 3 - (a + 3) * abs(x) ** 2 + 1
        elif 1 < abs(x) and abs(x) < 2:
            return a * abs(x) ** 3 - 5 * a * abs(x) ** 2 + 8 * a * abs(x) - 4 * a
        else:
            return 0

    factor = int(deg_config.deg_scale)
    k = np.zeros((factor * 4))
    for i in range(factor * 4):
        x = (1 / factor) * (i - np.floor(factor * 4 / 2) + 0.5)
        k[i] = bicubic_kernel(x)
    k = k / np.sum(k)
    kernel = torch.from_numpy(k).float().to(device)
    A_funcs = svd_op.SRConv(kernel / kernel.sum(),
                            deg_config.channels,
                            deg_config.image_size,
                            device,
                            stride=factor)
    return A_funcs

@register_degradation(name='deblur_uni')
def deg_deblur_uni(deg_config, device):
    A_funcs = svd_op.Deblurring(torch.tensor([1/deg_config.deg_scale]*deg_config.deg_scale).to(device),
                                deg_config.channels,
                                deg_config.image_size,
                                device)
    return A_funcs

@register_degradation(name='deblur_gauss')
def deg_deblur_gauss(deg_config, device):
    sigma = 3.0
    pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2]))
    size = deg_config.deg_scale
    ker = []
    for k in range(-size//2, size//2):
        ker.append(pdf(k))
    kernel = torch.Tensor(ker).to(device)
    A_funcs = svd_op.Deblurring(kernel / kernel.sum(),
                                deg_config.channels,
                                deg_config.image_size,
                                device)
    return A_funcs

@register_degradation(name='deblur_aniso')
def deg_deblur_aniso(deg_config, device):
    sigma = 20
    pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2]))
    kernel2 = torch.Tensor([pdf(-4), pdf(-3), pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2), pdf(3), pdf(4)]).to(device)

    sigma = 1
    pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2]))
    kernel1 = torch.Tensor([pdf(-4), pdf(-3), pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2), pdf(3), pdf(4)]).to(device)

    A_funcs = svd_op.Deblurring2D(kernel1 / kernel1.sum(),
                                  kernel2 / kernel2.sum(),
                                  deg_config.channels,
                                  deg_config.image_size,
                                  device)
    return A_funcs

@register_degradation(name='deblur_motion')
def deg_deblur_motion(deg_config, device):
    A_funcs = measurements.MotionBlurOperator(
        kernel_size=deg_config.deg_scale,
        intensity=0.5,
        device=device
    )
    return A_funcs

@register_degradation(name='deblur_nonuniform')
def deg_deblur_motion(deg_config, device, kernels=None, masks=None):
    A_funcs = measurements.NonuniformBlurOperator(
        deg_config.image_size,
        deg_config.deg_scale,
        device,
        kernels=kernels,
        masks=masks,
    )
    return A_funcs


# ======= FOR arbitraty image size =======
@register_degradation(name='sr_avgpool_gen')
def deg_sr_avgpool_general(deg_config, device):
    blur_by = int(deg_config.deg_scale)
    A_funcs = svd_op.SuperResolutionGeneral(deg_config.channels,
                                            deg_config.imgH,
                                            deg_config.imgW,
                                            blur_by,
                                            device)
    return A_funcs

@register_degradation(name='deblur_gauss_gen')
def deg_deblur_guass_general(deg_config, device):
    A_funcs = measurements.GaussialBlurOperator(
        kernel_size=deg_config.deg_scale,
        intensity=3.0,
        device=device
    )
    return A_funcs


from flair.functions.jpeg import jpeg_encode, jpeg_decode

class JPEGOperator():
    def __init__(self, qf: int, device):
        self.qf = qf
        self.device = device

    def A(self, img):
        x_luma, x_chroma = jpeg_encode(img, self.qf)
        return x_luma, x_chroma

    def At(self, encoded):
        return jpeg_decode(encoded, self.qf)


@register_degradation(name='jpeg')
def deg_jpeg(deg_config, device):
    A_funcs = JPEGOperator(
        qf=deg_config.deg_scale,
        device=device
    )
    return A_funcs
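A minimal sketch of the registry mechanism above, assuming the flair package is installed: register a toy operator under a new name and rebuild it through get_degradation with a Munch config, which mirrors how SuperRes asks for "sr_bicubic" and MotionBlur for "deblur_motion" in src/flair/degradations.py.

```python
import torch
from munch import munchify
from flair.functions.degradation import register_degradation, get_degradation

# Sketch (illustration only): register a hypothetical identity operator under a
# new name, then build it via the registry with a Munch config.
@register_degradation(name="identity_demo")
def deg_identity(deg_config, device):
    class Identity:
        def A(self, x):
            return x
        def At(self, y):
            return y
    return Identity()

cfg = munchify({"channels": 3, "image_size": 64, "deg_scale": 1})
op = get_degradation("identity_demo", cfg, device="cpu")
x = torch.rand(1, 3, 64, 64)
print(torch.equal(op.At(op.A(x)), x))  # True
```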
src/flair/functions/jpeg.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ---------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# This file has been modified from ddrm-jpeg.
|
5 |
+
#
|
6 |
+
# Source:
|
7 |
+
# https://github.com/bahjat-kawar/ddrm-jpeg/blob/master/functions/jpeg_torch.py
|
8 |
+
#
|
9 |
+
# The license for the original version of this file can be
|
10 |
+
# found in this directory (LICENSE_DDRM_JPEG).
|
11 |
+
# The modifications to this file are subject to the same license.
|
12 |
+
# ---------------------------------------------------------------
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
|
18 |
+
def dct1(x):
|
19 |
+
"""
|
20 |
+
Discrete Cosine Transform, Type I
|
21 |
+
:param x: the input signal
|
22 |
+
:return: the DCT-I of the signal over the last dimension
|
23 |
+
"""
|
24 |
+
x_shape = x.shape
|
25 |
+
x = x.view(-1, x_shape[-1])
|
26 |
+
|
27 |
+
return torch.fft.rfft(torch.cat([x, x.flip([1])[:, 1:-1]], dim=1))[:, :, 0].view(*x_shape)
|
28 |
+
|
29 |
+
|
30 |
+
def idct1(X):
|
31 |
+
"""
|
32 |
+
The inverse of DCT-I, which is just a scaled DCT-I
|
33 |
+
Our definition if idct1 is such that idct1(dct1(x)) == x
|
34 |
+
:param X: the input signal
|
35 |
+
:return: the inverse DCT-I of the signal over the last dimension
|
36 |
+
"""
|
37 |
+
n = X.shape[-1]
|
38 |
+
return dct1(X) / (2 * (n - 1))
|
39 |
+
|
40 |
+
|
41 |
+
def dct(x, norm=None):
|
42 |
+
"""
|
43 |
+
Discrete Cosine Transform, Type II (a.k.a. the DCT)
|
44 |
+
For the meaning of the parameter `norm`, see:
|
45 |
+
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
|
46 |
+
:param x: the input signal
|
47 |
+
:param norm: the normalization, None or 'ortho'
|
48 |
+
:return: the DCT-II of the signal over the last dimension
|
49 |
+
"""
|
50 |
+
x_shape = x.shape
|
51 |
+
N = x_shape[-1]
|
52 |
+
x = x.contiguous().view(-1, N)
|
53 |
+
|
54 |
+
v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
|
55 |
+
|
56 |
+
Vc = torch.view_as_real(torch.fft.fft(v, dim=1))
|
57 |
+
|
58 |
+
k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
|
59 |
+
W_r = torch.cos(k)
|
60 |
+
W_i = torch.sin(k)
|
61 |
+
|
62 |
+
V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
|
63 |
+
|
64 |
+
if norm == 'ortho':
|
65 |
+
V[:, 0] /= np.sqrt(N) * 2
|
66 |
+
V[:, 1:] /= np.sqrt(N / 2) * 2
|
67 |
+
|
68 |
+
V = 2 * V.view(*x_shape)
|
69 |
+
|
70 |
+
return V
|
71 |
+
|
72 |
+
|
73 |
+
def idct(X, norm=None):
|
74 |
+
"""
|
75 |
+
The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
|
76 |
+
Our definition of idct is that idct(dct(x)) == x
|
77 |
+
For the meaning of the parameter `norm`, see:
|
78 |
+
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
|
79 |
+
:param X: the input signal
|
80 |
+
:param norm: the normalization, None or 'ortho'
|
81 |
+
:return: the inverse DCT-II of the signal over the last dimension
|
82 |
+
"""
|
83 |
+
|
84 |
+
x_shape = X.shape
|
85 |
+
N = x_shape[-1]
|
86 |
+
|
87 |
+
X_v = X.contiguous().view(-1, x_shape[-1]) / 2
|
88 |
+
|
89 |
+
if norm == 'ortho':
|
90 |
+
X_v[:, 0] *= np.sqrt(N) * 2
|
91 |
+
X_v[:, 1:] *= np.sqrt(N / 2) * 2
|
92 |
+
|
93 |
+
k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
|
94 |
+
W_r = torch.cos(k)
|
95 |
+
W_i = torch.sin(k)
|
96 |
+
|
97 |
+
V_t_r = X_v
|
98 |
+
V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
|
99 |
+
|
100 |
+
V_r = V_t_r * W_r - V_t_i * W_i
|
101 |
+
V_i = V_t_r * W_i + V_t_i * W_r
|
102 |
+
|
103 |
+
V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
|
104 |
+
|
105 |
+
v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
|
106 |
+
x = v.new_zeros(v.shape)
|
107 |
+
x[:, ::2] += v[:, :N - (N // 2)]
|
108 |
+
x[:, 1::2] += v.flip([1])[:, :N // 2]
|
109 |
+
|
110 |
+
return x.view(*x_shape)
|
111 |
+
|
112 |
+
|
113 |
+
def dct_2d(x, norm=None):
|
114 |
+
"""
|
115 |
+
2-dimentional Discrete Cosine Transform, Type II (a.k.a. the DCT)
|
116 |
+
For the meaning of the parameter `norm`, see:
|
117 |
+
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
|
118 |
+
:param x: the input signal
|
119 |
+
:param norm: the normalization, None or 'ortho'
|
120 |
+
:return: the DCT-II of the signal over the last 2 dimensions
|
121 |
+
"""
|
122 |
+
X1 = dct(x, norm=norm)
|
123 |
+
X2 = dct(X1.transpose(-1, -2), norm=norm)
|
124 |
+
return X2.transpose(-1, -2)
|
125 |
+
|
126 |
+
|
127 |
+
def idct_2d(X, norm=None):
|
128 |
+
"""
|
129 |
+
The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III
|
130 |
+
Our definition of idct is that idct_2d(dct_2d(x)) == x
|
131 |
+
For the meaning of the parameter `norm`, see:
|
132 |
+
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
|
133 |
+
:param X: the input signal
|
134 |
+
:param norm: the normalization, None or 'ortho'
|
135 |
+
:return: the DCT-II of the signal over the last 2 dimensions
|
136 |
+
"""
|
137 |
+
x1 = idct(X, norm=norm)
|
138 |
+
x2 = idct(x1.transpose(-1, -2), norm=norm)
|
139 |
+
return x2.transpose(-1, -2)
|
140 |
+
|
141 |
+
|
142 |
+
def dct_3d(x, norm=None):
|
143 |
+
"""
|
144 |
+
3-dimentional Discrete Cosine Transform, Type II (a.k.a. the DCT)
|
145 |
+
For the meaning of the parameter `norm`, see:
|
146 |
+
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
|
147 |
+
:param x: the input signal
|
148 |
+
:param norm: the normalization, None or 'ortho'
|
149 |
+
:return: the DCT-II of the signal over the last 3 dimensions
|
150 |
+
"""
|
151 |
+
X1 = dct(x, norm=norm)
|
152 |
+
X2 = dct(X1.transpose(-1, -2), norm=norm)
|
153 |
+
X3 = dct(X2.transpose(-1, -3), norm=norm)
|
154 |
+
return X3.transpose(-1, -3).transpose(-1, -2)
|
155 |
+
|
156 |
+
|
157 |
+
def idct_3d(X, norm=None):
|
158 |
+
"""
|
159 |
+
The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III
|
160 |
+
Our definition of idct is that idct_3d(dct_3d(x)) == x
|
161 |
+
For the meaning of the parameter `norm`, see:
|
162 |
+
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
|
163 |
+
:param X: the input signal
|
164 |
+
:param norm: the normalization, None or 'ortho'
|
165 |
+
:return: the inverse DCT-II of the signal over the last 3 dimensions
|
166 |
+
"""
|
167 |
+
x1 = idct(X, norm=norm)
|
168 |
+
x2 = idct(x1.transpose(-1, -2), norm=norm)
|
169 |
+
x3 = idct(x2.transpose(-1, -3), norm=norm)
|
170 |
+
return x3.transpose(-1, -3).transpose(-1, -2)
|
171 |
+
|
172 |
+
|
173 |
+
class LinearDCT(nn.Linear):
|
174 |
+
"""Implement any DCT as a linear layer; in practice this executes around
|
175 |
+
50x faster on GPU. Unfortunately, the DCT matrix is stored, which will
|
176 |
+
increase memory usage.
|
177 |
+
:param in_features: size of expected input
|
178 |
+
:param type: which dct function in this file to use"""
|
179 |
+
def __init__(self, in_features, type, norm=None, bias=False):
|
180 |
+
self.type = type
|
181 |
+
self.N = in_features
|
182 |
+
self.norm = norm
|
183 |
+
super(LinearDCT, self).__init__(in_features, in_features, bias=bias)
|
184 |
+
|
185 |
+
def reset_parameters(self):
|
186 |
+
# initialise using dct function
|
187 |
+
I = torch.eye(self.N)
|
188 |
+
if self.type == 'dct1':
|
189 |
+
self.weight.data = dct1(I).data.t()
|
190 |
+
elif self.type == 'idct1':
|
191 |
+
self.weight.data = idct1(I).data.t()
|
192 |
+
elif self.type == 'dct':
|
193 |
+
self.weight.data = dct(I, norm=self.norm).data.t()
|
194 |
+
elif self.type == 'idct':
|
195 |
+
self.weight.data = idct(I, norm=self.norm).data.t()
|
196 |
+
self.weight.requires_grad = False # don't learn this!
|
197 |
+
|
198 |
+
|
199 |
+
def apply_linear_2d(x, linear_layer):
|
200 |
+
"""Can be used with a LinearDCT layer to do a 2D DCT.
|
201 |
+
:param x: the input signal
|
202 |
+
:param linear_layer: any PyTorch Linear layer
|
203 |
+
:return: result of linear layer applied to last 2 dimensions
|
204 |
+
"""
|
205 |
+
X1 = linear_layer(x)
|
206 |
+
X2 = linear_layer(X1.transpose(-1, -2))
|
207 |
+
return X2.transpose(-1, -2)
|
208 |
+
|
209 |
+
|
210 |
+
def apply_linear_3d(x, linear_layer):
|
211 |
+
"""Can be used with a LinearDCT layer to do a 3D DCT.
|
212 |
+
:param x: the input signal
|
213 |
+
:param linear_layer: any PyTorch Linear layer
|
214 |
+
:return: result of linear layer applied to last 3 dimensions
|
215 |
+
"""
|
216 |
+
X1 = linear_layer(x)
|
217 |
+
X2 = linear_layer(X1.transpose(-1, -2))
|
218 |
+
X3 = linear_layer(X2.transpose(-1, -3))
|
219 |
+
return X3.transpose(-1, -3).transpose(-1, -2)
|
220 |
+
|
221 |
+
|
222 |
+
def torch_rgb2ycbcr(x):
|
223 |
+
# Assume x is a batch of size (N x C x H x W)
|
224 |
+
v = torch.tensor([[.299, .587, .114], [-.1687, -.3313, .5], [.5, -.4187, -.0813]]).to(x.device)
|
225 |
+
ycbcr = torch.tensordot(x, v, dims=([1], [1])).transpose(3, 2).transpose(2, 1)
|
226 |
+
ycbcr[:,1:] += 128
|
227 |
+
return ycbcr
|
228 |
+
|
229 |
+
|
230 |
+
def torch_ycbcr2rgb(x):
|
231 |
+
# Assume x is a batch of size (N x C x H x W)
|
232 |
+
v = torch.tensor([[ 1.00000000e+00, -3.68199903e-05, 1.40198758e+00],
|
233 |
+
[ 1.00000000e+00, -3.44113281e-01, -7.14103821e-01],
|
234 |
+
[ 1.00000000e+00, 1.77197812e+00, -1.34583413e-04]]).to(x.device)
|
235 |
+
x[:, 1:] -= 128
|
236 |
+
rgb = torch.tensordot(x, v, dims=([1], [1])).transpose(3, 2).transpose(2, 1)
|
237 |
+
return rgb
|
238 |
+
|
239 |
+
def chroma_subsample(x):
|
240 |
+
return x[:, 0:1, :, :], x[:, 1:, ::2, ::2]
|
241 |
+
|
242 |
+
|
243 |
+
def general_quant_matrix(qf = 10):
|
244 |
+
q1 = torch.tensor([
|
245 |
+
16, 11, 10, 16, 24, 40, 51, 61,
|
246 |
+
12, 12, 14, 19, 26, 58, 60, 55,
|
247 |
+
14, 13, 16, 24, 40, 57, 69, 56,
|
248 |
+
14, 17, 22, 29, 51, 87, 80, 62,
|
249 |
+
18, 22, 37, 56, 68, 109, 103, 77,
|
250 |
+
24, 35, 55, 64, 81, 104, 113, 92,
|
251 |
+
49, 64, 78, 87, 103, 121, 120, 101,
|
252 |
+
72, 92, 95, 98, 112, 100, 103, 99
|
253 |
+
])
|
254 |
+
q2 = torch.tensor([
|
255 |
+
17, 18, 24, 47, 99, 99, 99, 99,
|
256 |
+
18, 21, 26, 66, 99, 99, 99, 99,
|
257 |
+
24, 26, 56, 99, 99, 99, 99, 99,
|
258 |
+
47, 66, 99, 99, 99, 99, 99, 99,
|
259 |
+
99, 99, 99, 99, 99, 99, 99, 99,
|
260 |
+
99, 99, 99, 99, 99, 99, 99, 99,
|
261 |
+
99, 99, 99, 99, 99, 99, 99, 99,
|
262 |
+
99, 99, 99, 99, 99, 99, 99, 99
|
263 |
+
])
|
264 |
+
s = (5000 / qf) if qf < 50 else (200 - 2 * qf)
|
265 |
+
q1 = torch.floor((s * q1 + 50) / 100)
|
266 |
+
q1[q1 <= 0] = 1
|
267 |
+
q1[q1 > 255] = 255
|
268 |
+
q2 = torch.floor((s * q2 + 50) / 100)
|
269 |
+
q2[q2 <= 0] = 1
|
270 |
+
q2[q2 > 255] = 255
|
271 |
+
return q1, q2
|
272 |
+
|
273 |
+
|
274 |
+
def quantization_matrix(qf):
|
275 |
+
return general_quant_matrix(qf)
|
276 |
+
# q1 = torch.tensor([[ 80, 55, 50, 80, 120, 200, 255, 255],
|
277 |
+
# [ 60, 60, 70, 95, 130, 255, 255, 255],
|
278 |
+
# [ 70, 65, 80, 120, 200, 255, 255, 255],
|
279 |
+
# [ 70, 85, 110, 145, 255, 255, 255, 255],
|
280 |
+
# [ 90, 110, 185, 255, 255, 255, 255, 255],
|
281 |
+
# [120, 175, 255, 255, 255, 255, 255, 255],
|
282 |
+
# [245, 255, 255, 255, 255, 255, 255, 255],
|
283 |
+
# [255, 255, 255, 255, 255, 255, 255, 255]])
|
284 |
+
# q2 = torch.tensor([[ 85, 90, 120, 235, 255, 255, 255, 255],
|
285 |
+
# [ 90, 105, 130, 255, 255, 255, 255, 255],
|
286 |
+
# [120, 130, 255, 255, 255, 255, 255, 255],
|
287 |
+
# [235, 255, 255, 255, 255, 255, 255, 255],
|
288 |
+
# [255, 255, 255, 255, 255, 255, 255, 255],
|
289 |
+
# [255, 255, 255, 255, 255, 255, 255, 255],
|
290 |
+
# [255, 255, 255, 255, 255, 255, 255, 255],
|
291 |
+
# [255, 255, 255, 255, 255, 255, 255, 255]])
|
292 |
+
# return q1, q2
|
293 |
+
|
294 |
+
def jpeg_encode(x, qf):
|
295 |
+
# Assume x is a batch of size (N x C x H x W)
|
296 |
+
# [-1, 1] to [0, 255]
|
297 |
+
x = (x + 1) / 2 * 255
|
298 |
+
n_batch, _, n_size, _ = x.shape
|
299 |
+
|
300 |
+
x = torch_rgb2ycbcr(x)
|
301 |
+
x_luma, x_chroma = chroma_subsample(x)
|
302 |
+
unfold = nn.Unfold(kernel_size=(8, 8), stride=(8, 8))
|
303 |
+
x_luma = unfold(x_luma).transpose(2, 1)
|
304 |
+
x_chroma = unfold(x_chroma).transpose(2, 1)
|
305 |
+
|
306 |
+
x_luma = x_luma.reshape(-1, 8, 8) - 128
|
307 |
+
x_chroma = x_chroma.reshape(-1, 8, 8) - 128
|
308 |
+
|
309 |
+
dct_layer = LinearDCT(8, 'dct', norm='ortho')
|
310 |
+
dct_layer.to(x_luma.device)
|
311 |
+
x_luma = apply_linear_2d(x_luma, dct_layer)
|
312 |
+
x_chroma = apply_linear_2d(x_chroma, dct_layer)
|
313 |
+
|
314 |
+
x_luma = x_luma.view(-1, 1, 8, 8)
|
315 |
+
x_chroma = x_chroma.view(-1, 2, 8, 8)
|
316 |
+
|
317 |
+
q1, q2 = quantization_matrix(qf)
|
318 |
+
q1 = q1.to(x_luma.device)
|
319 |
+
q2 = q2.to(x_luma.device)
|
320 |
+
x_luma /= q1.view(1, 8, 8)
|
321 |
+
x_chroma /= q2.view(1, 8, 8)
|
322 |
+
|
323 |
+
x_luma = x_luma.round()
|
324 |
+
x_chroma = x_chroma.round()
|
325 |
+
|
326 |
+
x_luma = x_luma.reshape(n_batch, (n_size // 8) ** 2, 64).transpose(2, 1)
|
327 |
+
x_chroma = x_chroma.reshape(n_batch, (n_size // 16) ** 2, 64 * 2).transpose(2, 1)
|
328 |
+
|
329 |
+
fold = nn.Fold(output_size=(n_size, n_size), kernel_size=(8, 8), stride=(8, 8))
|
330 |
+
x_luma = fold(x_luma)
|
331 |
+
fold = nn.Fold(output_size=(n_size // 2, n_size // 2), kernel_size=(8, 8), stride=(8, 8))
|
332 |
+
x_chroma = fold(x_chroma)
|
333 |
+
|
334 |
+
return [x_luma, x_chroma]
|
335 |
+
|
336 |
+
|
337 |
+
|
338 |
+
def jpeg_decode(x, qf):
|
339 |
+
# Assume x[0] is a batch of size (N x 1 x H x W) (luma)
|
340 |
+
# Assume x[1:] is a batch of size (N x 2 x H/2 x W/2) (chroma)
|
341 |
+
x_luma, x_chroma = x
|
342 |
+
n_batch, _, n_size, _ = x_luma.shape
|
343 |
+
unfold = nn.Unfold(kernel_size=(8, 8), stride=(8, 8))
|
344 |
+
x_luma = unfold(x_luma).transpose(2, 1)
|
345 |
+
x_luma = x_luma.reshape(-1, 1, 8, 8)
|
346 |
+
x_chroma = unfold(x_chroma).transpose(2, 1)
|
347 |
+
x_chroma = x_chroma.reshape(-1, 2, 8, 8)
|
348 |
+
|
349 |
+
q1, q2 = quantization_matrix(qf)
|
350 |
+
q1 = q1.to(x_luma.device)
|
351 |
+
q2 = q2.to(x_luma.device)
|
352 |
+
x_luma *= q1.view(1, 8, 8)
|
353 |
+
x_chroma *= q2.view(1, 8, 8)
|
354 |
+
|
355 |
+
x_luma = x_luma.reshape(-1, 8, 8)
|
356 |
+
x_chroma = x_chroma.reshape(-1, 8, 8)
|
357 |
+
|
358 |
+
dct_layer = LinearDCT(8, 'idct', norm='ortho')
|
359 |
+
dct_layer.to(x_luma.device)
|
360 |
+
x_luma = apply_linear_2d(x_luma, dct_layer)
|
361 |
+
x_chroma = apply_linear_2d(x_chroma, dct_layer)
|
362 |
+
|
363 |
+
x_luma = (x_luma + 128).reshape(n_batch, (n_size // 8) ** 2, 64).transpose(2, 1)
|
364 |
+
x_chroma = (x_chroma + 128).reshape(n_batch, (n_size // 16) ** 2, 64 * 2).transpose(2, 1)
|
365 |
+
|
366 |
+
fold = nn.Fold(output_size=(n_size, n_size), kernel_size=(8, 8), stride=(8, 8))
|
367 |
+
x_luma = fold(x_luma)
|
368 |
+
fold = nn.Fold(output_size=(n_size // 2, n_size // 2), kernel_size=(8, 8), stride=(8, 8))
|
369 |
+
x_chroma = fold(x_chroma)
|
370 |
+
|
371 |
+
x_chroma_repeated = torch.zeros(n_batch, 2, n_size, n_size, device = x_luma.device)
|
372 |
+
x_chroma_repeated[:, :, 0::2, 0::2] = x_chroma
|
373 |
+
x_chroma_repeated[:, :, 0::2, 1::2] = x_chroma
|
374 |
+
x_chroma_repeated[:, :, 1::2, 0::2] = x_chroma
|
375 |
+
x_chroma_repeated[:, :, 1::2, 1::2] = x_chroma
|
376 |
+
|
377 |
+
x = torch.cat([x_luma, x_chroma_repeated], dim=1)
|
378 |
+
|
379 |
+
x = torch_ycbcr2rgb(x)
|
380 |
+
|
381 |
+
# [0, 255] to [-1, 1]
|
382 |
+
x = x / 255 * 2 - 1
|
383 |
+
|
384 |
+
return x
|
385 |
+
|
386 |
+
|
387 |
+
def build_jpeg(qf):
|
388 |
+
# log.info(f"[Corrupt] JPEG restoration: {qf=} ...")
|
389 |
+
def jpeg(img):
|
390 |
+
encoded = jpeg_encode(img, qf)
|
391 |
+
return jpeg_decode(encoded, qf), encoded
|
392 |
+
return jpeg
|
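Note: the file above implements a differentiable JPEG pipeline: RGB -> YCbCr, 4:2:0 chroma subsampling, blockwise 8x8 DCT via LinearDCT, quantization from general_quant_matrix, and the inverse path in jpeg_decode. A minimal round-trip sketch, assuming the package is importable as flair.functions.jpeg and that the input is a square RGB batch in [-1, 1] whose side is a multiple of 16:

import torch
from flair.functions.jpeg import build_jpeg  # assumed import path

# toy image in [-1, 1]; 256 is divisible by 16, as the 8x8 blocks plus chroma subsampling require
x = torch.rand(1, 3, 256, 256) * 2 - 1
jpeg = build_jpeg(qf=10)                     # quality factor 10 -> strong compression
x_degraded, encoded = jpeg(x)                # decoded image and [luma, chroma] DCT coefficients
print(x_degraded.shape, encoded[0].shape, encoded[1].shape)
# torch.Size([1, 3, 256, 256]) torch.Size([1, 1, 256, 256]) torch.Size([1, 2, 128, 128])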
src/flair/functions/measurements.py
ADDED
@@ -0,0 +1,429 @@
1 |
+
'''This module handles task-dependent operations (A) and noises (n) to simulate a measurement y=Ax+n.'''
|
2 |
+
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
from torch.nn import functional as F
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from flair.utils.blur_util import Blurkernel
|
10 |
+
from flair.utils.img_util import fft2d
|
11 |
+
import numpy as np
|
12 |
+
from flair.utils.resizer import Resizer
|
13 |
+
from flair.utils.utils_sisr import pre_calculate_FK, pre_calculate_nonuniform
|
14 |
+
from torch.fft import fft2, ifft2
|
15 |
+
|
16 |
+
|
17 |
+
from flair.motionblur.motionblur import Kernel
|
18 |
+
|
19 |
+
# =================
|
20 |
+
# Operation classes
|
21 |
+
# =================
|
22 |
+
|
23 |
+
__OPERATOR__ = {}
|
24 |
+
_GAMMA_FACTOR = 2.2
|
25 |
+
|
26 |
+
def register_operator(name: str):
|
27 |
+
def wrapper(cls):
|
28 |
+
if __OPERATOR__.get(name, None):
|
29 |
+
raise NameError(f"Name {name} is already registered!")
|
30 |
+
__OPERATOR__[name] = cls
|
31 |
+
return cls
|
32 |
+
return wrapper
|
33 |
+
|
34 |
+
|
35 |
+
def get_operator(name: str, **kwargs):
|
36 |
+
if __OPERATOR__.get(name, None) is None:
|
37 |
+
raise NameError(f"Name {name} is not defined.")
|
38 |
+
return __OPERATOR__[name](**kwargs)
|
39 |
+
|
40 |
+
|
41 |
+
class LinearOperator(ABC):
|
42 |
+
@abstractmethod
|
43 |
+
def forward(self, data, **kwargs):
|
44 |
+
# calculate A * X
|
45 |
+
pass
|
46 |
+
|
47 |
+
@abstractmethod
|
48 |
+
def noisy_forward(self, data, **kwargs):
|
49 |
+
# calculate A * X + n
|
50 |
+
pass
|
51 |
+
|
52 |
+
@abstractmethod
|
53 |
+
def transpose(self, data, **kwargs):
|
54 |
+
# calculate A^T * X
|
55 |
+
pass
|
56 |
+
|
57 |
+
def ortho_project(self, data, **kwargs):
|
58 |
+
# calculate (I - A^T * A)X
|
59 |
+
return data - self.transpose(self.forward(data, **kwargs), **kwargs)
|
60 |
+
|
61 |
+
def project(self, data, measurement, **kwargs):
|
62 |
+
# calculate (I - A^T * A)Y - AX
|
63 |
+
return self.ortho_project(measurement, **kwargs) - self.forward(data, **kwargs)
|
64 |
+
|
65 |
+
|
66 |
+
@register_operator(name='noise')
|
67 |
+
class DenoiseOperator(LinearOperator):
|
68 |
+
def __init__(self, device):
|
69 |
+
self.device = device
|
70 |
+
|
71 |
+
def forward(self, data):
|
72 |
+
return data
|
73 |
+
|
74 |
+
def noisy_forward(self, data):
|
75 |
+
return data
|
76 |
+
|
77 |
+
def transpose(self, data):
|
78 |
+
return data
|
79 |
+
|
80 |
+
def ortho_project(self, data):
|
81 |
+
return data
|
82 |
+
|
83 |
+
def project(self, data):
|
84 |
+
return data
|
85 |
+
|
86 |
+
|
87 |
+
@register_operator(name='sr_bicubic')
|
88 |
+
class SuperResolutionOperator(LinearOperator):
|
89 |
+
def __init__(self,
|
90 |
+
in_shape,
|
91 |
+
scale_factor,
|
92 |
+
noise,
|
93 |
+
noise_scale,
|
94 |
+
device):
|
95 |
+
self.device = device
|
96 |
+
self.up_sample = partial(F.interpolate, scale_factor=scale_factor)
|
97 |
+
self.down_sample = Resizer(in_shape, 1/scale_factor).to(device)
|
98 |
+
self.noise = get_noise(name=noise, scale=noise_scale)
|
99 |
+
|
100 |
+
def A(self, data, **kwargs):
|
101 |
+
return self.forward(data, **kwargs)
|
102 |
+
|
103 |
+
def forward(self, data, **kwargs):
|
104 |
+
return self.down_sample(data)
|
105 |
+
|
106 |
+
def noisy_forward(self, data, **kwargs):
|
107 |
+
return self.noise.forward(self.down_sample(data))
|
108 |
+
|
109 |
+
def transpose(self, data, **kwargs):
|
110 |
+
return self.up_sample(data)
|
111 |
+
|
112 |
+
def project(self, data, measurement, **kwargs):
|
113 |
+
return data - self.transpose(self.forward(data)) + self.transpose(measurement)
|
114 |
+
|
115 |
+
@register_operator(name='deblur_motion')
|
116 |
+
class MotionBlurOperator(LinearOperator):
|
117 |
+
def __init__(self,
|
118 |
+
kernel_size,
|
119 |
+
intensity,
|
120 |
+
device):
|
121 |
+
self.device = device
|
122 |
+
self.kernel_size = kernel_size
|
123 |
+
self.conv = Blurkernel(blur_type='motion',
|
124 |
+
kernel_size=kernel_size,
|
125 |
+
std=intensity,
|
126 |
+
device=device).to(device) # should we keep this device term?
|
127 |
+
|
128 |
+
self.kernel_size = kernel_size
|
129 |
+
self.intensity = intensity
|
130 |
+
self.kernel = Kernel(size=(kernel_size, kernel_size), intensity=intensity)
|
131 |
+
kernel = torch.tensor(self.kernel.kernelMatrix, dtype=torch.float32)
|
132 |
+
self.conv.update_weights(kernel)
|
133 |
+
|
134 |
+
def forward(self, data, **kwargs):
|
135 |
+
# calculate A * X
|
136 |
+
return self.conv(data)
|
137 |
+
|
138 |
+
def noisy_forward(self, data, **kwargs):
|
139 |
+
pass
|
140 |
+
|
141 |
+
def transpose(self, data, **kwargs):
|
142 |
+
return data
|
143 |
+
|
144 |
+
def change_kernel(self):
|
145 |
+
self.kernel = Kernel(size=(self.kernel_size, self.kernel_size), intensity=self.intensity)
|
146 |
+
kernel = torch.tensor(self.kernel.kernelMatrix, dtype=torch.float32)
|
147 |
+
self.conv.update_weights(kernel)
|
148 |
+
|
149 |
+
def get_kernel(self):
|
150 |
+
kernel = self.kernel.kernelMatrix.type(torch.float32).to(self.device)
|
151 |
+
return kernel.view(1, 1, self.kernel_size, self.kernel_size)
|
152 |
+
|
153 |
+
def A(self, data):
|
154 |
+
return self.forward(data)
|
155 |
+
|
156 |
+
def At(self, data):
|
157 |
+
return self.transpose(data)
|
158 |
+
|
159 |
+
@register_operator(name='deblur_gauss')
|
160 |
+
class GaussialBlurOperator(LinearOperator):
|
161 |
+
def __init__(self,
|
162 |
+
kernel_size,
|
163 |
+
intensity,
|
164 |
+
device):
|
165 |
+
self.device = device
|
166 |
+
self.kernel_size = kernel_size
|
167 |
+
self.conv = Blurkernel(blur_type='gaussian',
|
168 |
+
kernel_size=kernel_size,
|
169 |
+
std=intensity,
|
170 |
+
device=device).to(device)
|
171 |
+
self.kernel = self.conv.get_kernel()
|
172 |
+
self.conv.update_weights(self.kernel.type(torch.float32))
|
173 |
+
|
174 |
+
def forward(self, data, **kwargs):
|
175 |
+
return self.conv(data)
|
176 |
+
|
177 |
+
def noisy_forward(self, data, **kwargs):
|
178 |
+
pass
|
179 |
+
|
180 |
+
def transpose(self, data, **kwargs):
|
181 |
+
return data
|
182 |
+
|
183 |
+
def get_kernel(self):
|
184 |
+
return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)
|
185 |
+
|
186 |
+
def apply_kernel(self, data, kernel):
|
187 |
+
self.conv.update_weights(kernel.type(torch.float32))
|
188 |
+
return self.conv(data)
|
189 |
+
|
190 |
+
def A(self, data):
|
191 |
+
return self.forward(data)
|
192 |
+
|
193 |
+
def At(self, data):
|
194 |
+
return self.transpose(data)
|
195 |
+
|
196 |
+
@register_operator(name='inpainting')
|
197 |
+
class InpaintingOperator(LinearOperator):
|
198 |
+
'''This operator gets a pre-defined mask and returns the masked image.'''
|
199 |
+
def __init__(self,
|
200 |
+
noise,
|
201 |
+
noise_scale,
|
202 |
+
device):
|
203 |
+
self.device = device
|
204 |
+
self.noise = get_noise(name=noise, scale=noise_scale)
|
205 |
+
|
206 |
+
def forward(self, data, **kwargs):
|
207 |
+
try:
|
208 |
+
return data * kwargs.get('mask', None).to(self.device)
|
209 |
+
except:
|
210 |
+
raise ValueError("Require mask")
|
211 |
+
|
212 |
+
def noisy_forward(self, data, **kwargs):
|
213 |
+
return self.noise.forward(self.forward(data, **kwargs))
|
214 |
+
|
215 |
+
def transpose(self, data, **kwargs):
|
216 |
+
return data
|
217 |
+
|
218 |
+
def ortho_project(self, data, **kwargs):
|
219 |
+
return data - self.forward(data, **kwargs)
|
220 |
+
|
221 |
+
# Operator for BlindDPS.
|
222 |
+
@register_operator(name='blind_blur')
|
223 |
+
class BlindBlurOperator(LinearOperator):
|
224 |
+
def __init__(self, device, **kwargs) -> None:
|
225 |
+
self.device = device
|
226 |
+
|
227 |
+
def forward(self, data, kernel, **kwargs):
|
228 |
+
return self.apply_kernel(data, kernel)
|
229 |
+
|
230 |
+
def transpose(self, data, **kwargs):
|
231 |
+
return data
|
232 |
+
|
233 |
+
def apply_kernel(self, data, kernel):
|
234 |
+
#TODO: faster way to apply conv?
|
235 |
+
|
236 |
+
b_img = torch.zeros_like(data).to(self.device)
|
237 |
+
for i in range(3):
|
238 |
+
b_img[:, i, :, :] = F.conv2d(data[:, i:i+1, :, :], kernel, padding='same')
|
239 |
+
return b_img
|
240 |
+
|
241 |
+
|
242 |
+
class NonLinearOperator(ABC):
|
243 |
+
@abstractmethod
|
244 |
+
def forward(self, data, **kwargs):
|
245 |
+
pass
|
246 |
+
|
247 |
+
@abstractmethod
|
248 |
+
def noisy_forward(self, data, **kwargs):
|
249 |
+
pass
|
250 |
+
|
251 |
+
def project(self, data, measurement, **kwargs):
|
252 |
+
return data + measurement - self.forward(data)
|
253 |
+
|
254 |
+
@register_operator(name='phase_retrieval')
|
255 |
+
class PhaseRetrievalOperator(NonLinearOperator):
|
256 |
+
def __init__(self,
|
257 |
+
oversample,
|
258 |
+
noise,
|
259 |
+
noise_scale,
|
260 |
+
device):
|
261 |
+
self.pad = int((oversample / 8.0) * 256)
|
262 |
+
self.device = device
|
263 |
+
self.noise = get_noise(name=noise, scale=noise_scale)
|
264 |
+
|
265 |
+
def forward(self, data, **kwargs):
|
266 |
+
padded = F.pad(data, (self.pad, self.pad, self.pad, self.pad))
|
267 |
+
amplitude = fft2d(padded).abs()
|
268 |
+
return amplitude
|
269 |
+
|
270 |
+
def noisy_forward(self, data, **kwargs):
|
271 |
+
return self.noise.forward(self.forward(data, **kwargs))
|
272 |
+
|
273 |
+
@register_operator(name='nonuniform_blur')
|
274 |
+
class NonuniformBlurOperator(LinearOperator):
|
275 |
+
def __init__(self, in_shape, kernel_size, device,
|
276 |
+
kernels=None, masks=None):
|
277 |
+
self.device = device
|
278 |
+
self.kernel_size = kernel_size
|
279 |
+
self.in_shape = in_shape
|
280 |
+
|
281 |
+
# TODO: generalize
|
282 |
+
if kernels is None and masks is None:
|
283 |
+
self.kernels = np.load('./functions/nonuniform/kernels/000001.npy')
|
284 |
+
self.masks = np.load('./functions/nonuniform/masks/000001.npy')
|
285 |
+
self.kernels = torch.tensor(self.kernels).to(device)
|
286 |
+
self.masks = torch.tensor(self.masks).to(device)
|
287 |
+
|
288 |
+
# approximate in image space
|
289 |
+
def forward_img(self, data):
|
290 |
+
K = self.kernel_size
|
291 |
+
data = F.pad(data, (K//2, K//2, K//2, K//2), mode="reflect")
|
292 |
+
kernels = self.kernels.transpose(0, 1)
|
293 |
+
data_rgb_batch = data.transpose(0, 1)
|
294 |
+
conv_rgb_batch = F.conv2d(data_rgb_batch, kernels)
|
295 |
+
y_rgb_batch = conv_rgb_batch * self.masks
|
296 |
+
y_rgb_batch = torch.sum(y_rgb_batch, dim=1, keepdim=True)
|
297 |
+
y = y_rgb_batch.transpose(0, 1)
|
298 |
+
return y
|
299 |
+
|
300 |
+
# NOTE: Only using this operator will make the problem nonlinear (gamma-correction)
|
301 |
+
def forward_nonlinear(self, data, flatten=False, noiseless=False):
|
302 |
+
# 1. Usual nonuniform blurring degradation pipeline
|
303 |
+
kernels = self.kernels.transpose(0, 1)
|
304 |
+
FK, FKC = pre_calculate_FK(kernels)
|
305 |
+
y = ifft2(FK * fft2(data)).real
|
306 |
+
y = y.transpose(0, 1)
|
307 |
+
y_rgb_batch = self.masks * y
|
308 |
+
y_rgb_batch = torch.sum(y_rgb_batch, dim=1, keepdim=True)
|
309 |
+
y = y_rgb_batch.transpose(0, 1)
|
310 |
+
F2KM, FKFMy = pre_calculate_nonuniform(data, y, FK, FKC, self.masks)
|
311 |
+
self.pre_calculated = (FK, FKC, F2KM, FKFMy)
|
312 |
+
# 2. Gamma-correction
|
313 |
+
y = (y + 1) / 2
|
314 |
+
y = y ** (1 / _GAMMA_FACTOR)
|
315 |
+
y = (y - 0.5) / 0.5
|
316 |
+
return y
|
317 |
+
|
318 |
+
def noisy_forward(self, data, **kwargs):
|
319 |
+
return self.noise.forward(self.forward(data))
|
320 |
+
|
321 |
+
# exact in Fourier
|
322 |
+
def forward(self, data, flatten=False, noiseless=False):
|
323 |
+
# [1, 25, 33, 33] -> [25, 1, 33, 33]
|
324 |
+
kernels = self.kernels.transpose(0, 1)
|
325 |
+
# [25, 1, 512, 512]
|
326 |
+
FK, FKC = pre_calculate_FK(kernels)
|
327 |
+
# [25, 3, 512, 512]
|
328 |
+
y = ifft2(FK * fft2(data)).real
|
329 |
+
# [3, 25, 512, 512]
|
330 |
+
y = y.transpose(0, 1)
|
331 |
+
y_rgb_batch = self.masks * y
|
332 |
+
# [3, 1, 512, 512]
|
333 |
+
y_rgb_batch = torch.sum(y_rgb_batch, dim=1, keepdim=True)
|
334 |
+
# [1, 3, 512, 512]
|
335 |
+
y = y_rgb_batch.transpose(0, 1)
|
336 |
+
F2KM, FKFMy = pre_calculate_nonuniform(data, y, FK, FKC, self.masks)
|
337 |
+
self.pre_calculated = (FK, FKC, F2KM, FKFMy)
|
338 |
+
return y
|
339 |
+
|
340 |
+
def transpose(self, y, flatten=False):
|
341 |
+
kernels = self.kernels.transpose(0, 1)
|
342 |
+
FK, FKC = pre_calculate_FK(kernels)
|
343 |
+
# 1. broadcast and multiply
|
344 |
+
# [3, 1, 512, 512]
|
345 |
+
y_rgb_batch = y.transpose(0, 1)
|
346 |
+
# [3, 25, 512, 512]
|
347 |
+
y_rgb_batch = y_rgb_batch.repeat(1, 25, 1, 1)
|
348 |
+
y = self.masks * y_rgb_batch
|
349 |
+
# 2. transpose of convolution in Fourier
|
350 |
+
# [25, 3, 512, 512]
|
351 |
+
y = y.transpose(0, 1)
|
352 |
+
ATy_broadcast = ifft2(FKC * fft2(y)).real
|
353 |
+
# [1, 3, 512, 512]
|
354 |
+
ATy = torch.sum(ATy_broadcast, dim=0, keepdim=True)
|
355 |
+
return ATy
|
356 |
+
|
357 |
+
def A(self, data):
|
358 |
+
return self.forward(data)
|
359 |
+
|
360 |
+
def At(self, data):
|
361 |
+
return self.transpose(data)
|
362 |
+
|
363 |
+
# =============
|
364 |
+
# Noise classes
|
365 |
+
# =============
|
366 |
+
|
367 |
+
|
368 |
+
__NOISE__ = {}
|
369 |
+
|
370 |
+
def register_noise(name: str):
|
371 |
+
def wrapper(cls):
|
372 |
+
if __NOISE__.get(name, None):
|
373 |
+
raise NameError(f"Name {name} is already defined!")
|
374 |
+
__NOISE__[name] = cls
|
375 |
+
return cls
|
376 |
+
return wrapper
|
377 |
+
|
378 |
+
def get_noise(name: str, **kwargs):
|
379 |
+
if __NOISE__.get(name, None) is None:
|
380 |
+
raise NameError(f"Name {name} is not defined.")
|
381 |
+
noiser = __NOISE__[name](**kwargs)
|
382 |
+
noiser.__name__ = name
|
383 |
+
return noiser
|
384 |
+
|
385 |
+
class Noise(ABC):
|
386 |
+
def __call__(self, data):
|
387 |
+
return self.forward(data)
|
388 |
+
|
389 |
+
@abstractmethod
|
390 |
+
def forward(self, data):
|
391 |
+
pass
|
392 |
+
|
393 |
+
@register_noise(name='clean')
|
394 |
+
class Clean(Noise):
|
395 |
+
def __init__(self, **kwargs):
|
396 |
+
pass
|
397 |
+
|
398 |
+
def forward(self, data):
|
399 |
+
return data
|
400 |
+
|
401 |
+
@register_noise(name='gaussian')
|
402 |
+
class GaussianNoise(Noise):
|
403 |
+
def __init__(self, scale):
|
404 |
+
self.scale = scale
|
405 |
+
|
406 |
+
def forward(self, data):
|
407 |
+
return data + torch.randn_like(data, device=data.device) * self.scale
|
408 |
+
|
409 |
+
|
410 |
+
@register_noise(name='poisson')
|
411 |
+
class PoissonNoise(Noise):
|
412 |
+
def __init__(self, scale):
|
413 |
+
self.scale = scale
|
414 |
+
|
415 |
+
def forward(self, data):
|
416 |
+
'''
|
417 |
+
Follow skimage.util.random_noise.
|
418 |
+
'''
|
419 |
+
|
420 |
+
# version 3 (stack-overflow)
|
421 |
+
import numpy as np
|
422 |
+
data = (data + 1.0) / 2.0
|
423 |
+
data = data.clamp(0, 1)
|
424 |
+
device = data.device
|
425 |
+
data = data.detach().cpu()
|
426 |
+
data = torch.from_numpy(np.random.poisson(data * 255.0 * self.scale) / 255.0 / self.scale)
|
427 |
+
data = data * 2.0 - 1.0
|
428 |
+
data = data.clamp(-1, 1)
|
429 |
+
return data.to(device)
|
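The __OPERATOR__ and __NOISE__ registries above let configs refer to forward models and noise models by name. A small usage sketch (assuming the module is importable as flair.functions.measurements, and that flair.utils.resizer.Resizer is available, since SuperResolutionOperator depends on it):

import torch
from flair.functions.measurements import get_operator, get_noise  # assumed import path

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# build A (4x bicubic downsampling) by its registered name 'sr_bicubic'
operator = get_operator(name='sr_bicubic',
                        in_shape=(1, 3, 512, 512),
                        scale_factor=4,
                        noise='gaussian', noise_scale=0.01,
                        device=device)
x = torch.rand(1, 3, 512, 512, device=device) * 2 - 1
y = operator.noisy_forward(x)                  # y = A x + n, a noisy 128x128 measurement

# noise models can also be built standalone
noiser = get_noise(name='gaussian', scale=0.05)
y2 = noiser(operator.forward(x))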
src/flair/functions/nonuniform/kernels/000001.npy
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:538b2e852fdfd3966628fbf22b53ed75343c43084608aa12c05c7dbbd0db6728
|
3 |
+
size 109028
|
src/flair/functions/svd_ddnm.py
ADDED
@@ -0,0 +1,206 @@
1 |
+
import torch
|
2 |
+
from tqdm import tqdm
|
3 |
+
import torchvision.utils as tvu
|
4 |
+
import torchvision
|
5 |
+
import os
|
6 |
+
|
7 |
+
class_num = 951
|
8 |
+
|
9 |
+
|
10 |
+
def compute_alpha(beta, t):
|
11 |
+
beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
|
12 |
+
a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
|
13 |
+
return a
|
14 |
+
|
15 |
+
def inverse_data_transform(x):
|
16 |
+
x = (x + 1.0) / 2.0
|
17 |
+
return torch.clamp(x, 0.0, 1.0)
|
18 |
+
|
19 |
+
def ddnm_diffusion(x, model, b, eta, A_funcs, y, cls_fn=None, classes=None, config=None):
|
20 |
+
with torch.no_grad():
|
21 |
+
|
22 |
+
# setup iteration variables
|
23 |
+
skip = config.diffusion.num_diffusion_timesteps//config.time_travel.T_sampling
|
24 |
+
n = x.size(0)
|
25 |
+
x0_preds = []
|
26 |
+
xs = [x]
|
27 |
+
|
28 |
+
# generate time schedule
|
29 |
+
times = get_schedule_jump(config.time_travel.T_sampling,
|
30 |
+
config.time_travel.travel_length,
|
31 |
+
config.time_travel.travel_repeat,
|
32 |
+
)
|
33 |
+
time_pairs = list(zip(times[:-1], times[1:]))
|
34 |
+
|
35 |
+
# reverse diffusion sampling
|
36 |
+
for i, j in tqdm(time_pairs):
|
37 |
+
i, j = i*skip, j*skip
|
38 |
+
if j<0: j=-1
|
39 |
+
|
40 |
+
if j < i: # normal sampling
|
41 |
+
t = (torch.ones(n) * i).to(x.device)
|
42 |
+
next_t = (torch.ones(n) * j).to(x.device)
|
43 |
+
at = compute_alpha(b, t.long())
|
44 |
+
at_next = compute_alpha(b, next_t.long())
|
45 |
+
xt = xs[-1].to('cuda')
|
46 |
+
if cls_fn == None:
|
47 |
+
et = model(xt, t)
|
48 |
+
else:
|
49 |
+
classes = torch.ones(xt.size(0), dtype=torch.long, device=torch.device("cuda"))*class_num
|
50 |
+
et = model(xt, t, classes)
|
51 |
+
et = et[:, :3]
|
52 |
+
et = et - (1 - at).sqrt()[0, 0, 0, 0] * cls_fn(x, t, classes)
|
53 |
+
|
54 |
+
if et.size(1) == 6:
|
55 |
+
et = et[:, :3]
|
56 |
+
|
57 |
+
x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
|
58 |
+
|
59 |
+
x0_t_hat = x0_t - A_funcs.A_pinv(
|
60 |
+
A_funcs.A(x0_t.reshape(x0_t.size(0), -1)) - y.reshape(y.size(0), -1)
|
61 |
+
).reshape(*x0_t.size())
|
62 |
+
|
63 |
+
c1 = (1 - at_next).sqrt() * eta
|
64 |
+
c2 = (1 - at_next).sqrt() * ((1 - eta ** 2) ** 0.5)
|
65 |
+
xt_next = at_next.sqrt() * x0_t_hat + c1 * torch.randn_like(x0_t) + c2 * et
|
66 |
+
|
67 |
+
x0_preds.append(x0_t.to('cpu'))
|
68 |
+
xs.append(xt_next.to('cpu'))
|
69 |
+
else: # time-travel back
|
70 |
+
next_t = (torch.ones(n) * j).to(x.device)
|
71 |
+
at_next = compute_alpha(b, next_t.long())
|
72 |
+
x0_t = x0_preds[-1].to('cuda')
|
73 |
+
|
74 |
+
xt_next = at_next.sqrt() * x0_t + torch.randn_like(x0_t) * (1 - at_next).sqrt()
|
75 |
+
|
76 |
+
xs.append(xt_next.to('cpu'))
|
77 |
+
|
78 |
+
return [xs[-1]], [x0_preds[-1]]
|
79 |
+
|
80 |
+
def ddnm_plus_diffusion(x, model, b, eta, A_funcs, y, sigma_y, cls_fn=None, classes=None, config=None):
|
81 |
+
with torch.no_grad():
|
82 |
+
|
83 |
+
# setup iteration variables
|
84 |
+
skip = config.diffusion.num_diffusion_timesteps//config.time_travel.T_sampling
|
85 |
+
n = x.size(0)
|
86 |
+
x0_preds = []
|
87 |
+
xs = [x]
|
88 |
+
|
89 |
+
# generate time schedule
|
90 |
+
times = get_schedule_jump(config.time_travel.T_sampling,
|
91 |
+
config.time_travel.travel_length,
|
92 |
+
config.time_travel.travel_repeat,
|
93 |
+
)
|
94 |
+
time_pairs = list(zip(times[:-1], times[1:]))
|
95 |
+
|
96 |
+
# reverse diffusion sampling
|
97 |
+
for i, j in tqdm(time_pairs):
|
98 |
+
i, j = i*skip, j*skip
|
99 |
+
if j<0: j=-1
|
100 |
+
|
101 |
+
if j < i: # normal sampling
|
102 |
+
t = (torch.ones(n) * i).to(x.device)
|
103 |
+
next_t = (torch.ones(n) * j).to(x.device)
|
104 |
+
at = compute_alpha(b, t.long())
|
105 |
+
at_next = compute_alpha(b, next_t.long())
|
106 |
+
xt = xs[-1].to('cuda')
|
107 |
+
if cls_fn == None:
|
108 |
+
et = model(xt, t)
|
109 |
+
else:
|
110 |
+
classes = torch.ones(xt.size(0), dtype=torch.long, device=torch.device("cuda"))*class_num
|
111 |
+
et = model(xt, t, classes)
|
112 |
+
et = et[:, :3]
|
113 |
+
et = et - (1 - at).sqrt()[0, 0, 0, 0] * cls_fn(x, t, classes)
|
114 |
+
|
115 |
+
if et.size(1) == 6:
|
116 |
+
et = et[:, :3]
|
117 |
+
|
118 |
+
# Eq. 12
|
119 |
+
x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
|
120 |
+
|
121 |
+
sigma_t = (1 - at_next).sqrt()[0, 0, 0, 0]
|
122 |
+
|
123 |
+
# Eq. 17
|
124 |
+
x0_t_hat = x0_t - A_funcs.Lambda(A_funcs.A_pinv(
|
125 |
+
A_funcs.A(x0_t.reshape(x0_t.size(0), -1)) - y.reshape(y.size(0), -1)
|
126 |
+
).reshape(x0_t.size(0), -1), at_next.sqrt()[0, 0, 0, 0], sigma_y, sigma_t, eta).reshape(*x0_t.size())
|
127 |
+
|
128 |
+
# Eq. 51
|
129 |
+
xt_next = at_next.sqrt() * x0_t_hat + A_funcs.Lambda_noise(
|
130 |
+
torch.randn_like(x0_t).reshape(x0_t.size(0), -1),
|
131 |
+
at_next.sqrt()[0, 0, 0, 0], sigma_y, sigma_t, eta, et.reshape(et.size(0), -1)).reshape(*x0_t.size())
|
132 |
+
|
133 |
+
x0_preds.append(x0_t.to('cpu'))
|
134 |
+
xs.append(xt_next.to('cpu'))
|
135 |
+
else: # time-travel back
|
136 |
+
next_t = (torch.ones(n) * j).to(x.device)
|
137 |
+
at_next = compute_alpha(b, next_t.long())
|
138 |
+
x0_t = x0_preds[-1].to('cuda')
|
139 |
+
|
140 |
+
xt_next = at_next.sqrt() * x0_t + torch.randn_like(x0_t) * (1 - at_next).sqrt()
|
141 |
+
|
142 |
+
xs.append(xt_next.to('cpu'))
|
143 |
+
|
144 |
+
# #ablation
|
145 |
+
# if i%50==0:
|
146 |
+
# os.makedirs('/userhome/wyh/ddnm/debug/x0t', exist_ok=True)
|
147 |
+
# tvu.save_image(
|
148 |
+
# inverse_data_transform(x0_t[0]),
|
149 |
+
# os.path.join('/userhome/wyh/ddnm/debug/x0t', f"x0_t_{i}.png")
|
150 |
+
# )
|
151 |
+
|
152 |
+
# os.makedirs('/userhome/wyh/ddnm/debug/x0_t_hat', exist_ok=True)
|
153 |
+
# tvu.save_image(
|
154 |
+
# inverse_data_transform(x0_t_hat[0]),
|
155 |
+
# os.path.join('/userhome/wyh/ddnm/debug/x0_t_hat', f"x0_t_hat_{i}.png")
|
156 |
+
# )
|
157 |
+
|
158 |
+
# os.makedirs('/userhome/wyh/ddnm/debug/xt_next', exist_ok=True)
|
159 |
+
# tvu.save_image(
|
160 |
+
# inverse_data_transform(xt_next[0]),
|
161 |
+
# os.path.join('/userhome/wyh/ddnm/debug/xt_next', f"xt_next_{i}.png")
|
162 |
+
# )
|
163 |
+
|
164 |
+
return [xs[-1]], [x0_preds[-1]]
|
165 |
+
|
166 |
+
# from RePaint
|
167 |
+
def get_schedule_jump(T_sampling, travel_length, travel_repeat):
|
168 |
+
|
169 |
+
jumps = {}
|
170 |
+
for j in range(0, T_sampling - travel_length, travel_length):
|
171 |
+
jumps[j] = travel_repeat - 1
|
172 |
+
|
173 |
+
t = T_sampling
|
174 |
+
ts = []
|
175 |
+
|
176 |
+
while t >= 1:
|
177 |
+
t = t-1
|
178 |
+
ts.append(t)
|
179 |
+
|
180 |
+
if jumps.get(t, 0) > 0:
|
181 |
+
jumps[t] = jumps[t] - 1
|
182 |
+
for _ in range(travel_length):
|
183 |
+
t = t + 1
|
184 |
+
ts.append(t)
|
185 |
+
|
186 |
+
ts.append(-1)
|
187 |
+
|
188 |
+
_check_times(ts, -1, T_sampling)
|
189 |
+
|
190 |
+
return ts
|
191 |
+
|
192 |
+
def _check_times(times, t_0, T_sampling):
|
193 |
+
# Check end
|
194 |
+
assert times[0] > times[1], (times[0], times[1])
|
195 |
+
|
196 |
+
# Check beginning
|
197 |
+
assert times[-1] == -1, times[-1]
|
198 |
+
|
199 |
+
# Steplength = 1
|
200 |
+
for t_last, t_cur in zip(times[:-1], times[1:]):
|
201 |
+
assert abs(t_last - t_cur) == 1, (t_last, t_cur)
|
202 |
+
|
203 |
+
# Value range
|
204 |
+
for t in times:
|
205 |
+
assert t >= t_0, (t, t_0)
|
206 |
+
assert t <= T_sampling, (t, T_sampling)
|
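For intuition, get_schedule_jump above reproduces the RePaint-style time-travel schedule: the sampler walks from T_sampling down to -1, and whenever it reaches an anchor step (a multiple of travel_length) it climbs back up travel_length steps, repeating that detour travel_repeat - 1 times per anchor. A small sketch, assuming the module is importable as flair.functions.svd_ddnm:

from flair.functions.svd_ddnm import get_schedule_jump  # assumed import path

ts = get_schedule_jump(T_sampling=10, travel_length=3, travel_repeat=2)
print(ts)
# [9, 8, 7, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 4, 5, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 2, 1, 0, -1]
# consecutive entries always differ by 1 (which _check_times asserts), and the walk ends at -1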
src/flair/functions/svd_operators.py
ADDED
@@ -0,0 +1,1308 @@
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class A_functions:
|
5 |
+
"""
|
6 |
+
A class replacing the SVD of a matrix A, perhaps efficiently.
|
7 |
+
All input vectors are of shape (Batch, ...).
|
8 |
+
All output vectors are of shape (Batch, DataDimension).
|
9 |
+
"""
|
10 |
+
|
11 |
+
def V(self, vec):
|
12 |
+
"""
|
13 |
+
Multiplies the input vector by V
|
14 |
+
"""
|
15 |
+
raise NotImplementedError()
|
16 |
+
|
17 |
+
def Vt(self, vec):
|
18 |
+
"""
|
19 |
+
Multiplies the input vector by V transposed
|
20 |
+
"""
|
21 |
+
raise NotImplementedError()
|
22 |
+
|
23 |
+
def U(self, vec):
|
24 |
+
"""
|
25 |
+
Multiplies the input vector by U
|
26 |
+
"""
|
27 |
+
raise NotImplementedError()
|
28 |
+
|
29 |
+
def Ut(self, vec):
|
30 |
+
"""
|
31 |
+
Multiplies the input vector by U transposed
|
32 |
+
"""
|
33 |
+
raise NotImplementedError()
|
34 |
+
|
35 |
+
def singulars(self):
|
36 |
+
"""
|
37 |
+
Returns a vector containing the singular values. The shape of the vector should be the same as the smaller dimension (like U)
|
38 |
+
"""
|
39 |
+
raise NotImplementedError()
|
40 |
+
|
41 |
+
def add_zeros(self, vec):
|
42 |
+
"""
|
43 |
+
Adds trailing zeros to turn a vector from the small dimension (U) to the big dimension (V)
|
44 |
+
"""
|
45 |
+
raise NotImplementedError()
|
46 |
+
|
47 |
+
def A(self, vec):
|
48 |
+
"""
|
49 |
+
Multiplies the input vector by A
|
50 |
+
"""
|
51 |
+
temp = self.Vt(vec)
|
52 |
+
singulars = self.singulars()
|
53 |
+
return self.U(singulars * temp[:, :singulars.shape[0]])
|
54 |
+
|
55 |
+
def At(self, vec):
|
56 |
+
"""
|
57 |
+
Multiplies the input vector by A transposed
|
58 |
+
"""
|
59 |
+
temp = self.Ut(vec)
|
60 |
+
singulars = self.singulars()
|
61 |
+
return self.V(self.add_zeros(singulars * temp[:, :singulars.shape[0]]))
|
62 |
+
|
63 |
+
def A_pinv(self, vec):
|
64 |
+
"""
|
65 |
+
Multiplies the input vector by the pseudo inverse of A
|
66 |
+
"""
|
67 |
+
temp = self.Ut(vec)
|
68 |
+
singulars = self.singulars()
|
69 |
+
|
70 |
+
factors = 1. / singulars
|
71 |
+
factors[singulars == 0] = 0.
|
72 |
+
|
73 |
+
# temp[:, :singulars.shape[0]] = temp[:, :singulars.shape[0]] / singulars
|
74 |
+
temp[:, :singulars.shape[0]] = temp[:, :singulars.shape[0]] * factors
|
75 |
+
return self.V(self.add_zeros(temp))
|
76 |
+
|
77 |
+
def A_pinv_eta(self, vec, eta):
|
78 |
+
"""
|
79 |
+
Multiplies the input vector by the pseudo inverse of A with factor eta
|
80 |
+
"""
|
81 |
+
temp = self.Ut(vec)
|
82 |
+
singulars = self.singulars()
|
83 |
+
factors = singulars / (singulars*singulars+eta)
|
84 |
+
# print(temp.size(), factors.size(), singulars.size())
|
85 |
+
temp[:, :singulars.shape[0]] = temp[:, :singulars.shape[0]] * factors
|
86 |
+
return self.V(self.add_zeros(temp))
|
87 |
+
|
88 |
+
def Lambda(self, vec, a, sigma_y, sigma_t, eta):
|
89 |
+
raise NotImplementedError()
|
90 |
+
|
91 |
+
def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
|
92 |
+
raise NotImplementedError()
|
93 |
+
|
94 |
+
|
95 |
+
# block-wise CS
|
96 |
+
class CS(A_functions):
|
97 |
+
def __init__(self, channels, img_dim, ratio, device): #ratio = 2 or 4
|
98 |
+
self.img_dim = img_dim
|
99 |
+
self.channels = channels
|
100 |
+
self.y_dim = img_dim // 32
|
101 |
+
self.ratio = 32
|
102 |
+
A = torch.randn(32**2, 32**2).to(device)
|
103 |
+
_, _, self.V_small = torch.svd(A, some=False)
|
104 |
+
self.Vt_small = self.V_small.transpose(0, 1)
|
105 |
+
self.singulars_small = torch.ones(int(32 * 32 * ratio), device=device)
|
106 |
+
self.cs_size = self.singulars_small.size(0)
|
107 |
+
|
108 |
+
def V(self, vec):
|
109 |
+
#reorder the vector back into patches (because singulars are ordered descendingly)
|
110 |
+
|
111 |
+
temp = vec.clone().reshape(vec.shape[0], -1)
|
112 |
+
patches = torch.zeros(vec.size(0), self.channels * self.y_dim ** 2, self.ratio ** 2, device=vec.device)
|
113 |
+
patches[:, :, :self.cs_size] = temp[:, :self.channels * self.y_dim ** 2 * self.cs_size].contiguous().reshape(
|
114 |
+
vec.size(0), -1, self.cs_size)
|
115 |
+
patches[:, :, self.cs_size:] = temp[:, self.channels * self.y_dim ** 2 * self.cs_size:].contiguous().reshape(
|
116 |
+
vec.size(0), self.channels * self.y_dim ** 2, -1)
|
117 |
+
|
118 |
+
#multiply each patch by the small V
|
119 |
+
patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2)
|
120 |
+
#repatch the patches into an image
|
121 |
+
patches_orig = patches.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio)
|
122 |
+
recon = patches_orig.permute(0, 1, 2, 4, 3, 5).contiguous()
|
123 |
+
recon = recon.reshape(vec.shape[0], self.channels * self.img_dim ** 2)
|
124 |
+
return recon
|
125 |
+
|
126 |
+
def Vt(self, vec):
|
127 |
+
#extract flattened patches
|
128 |
+
patches = vec.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim)
|
129 |
+
patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
|
130 |
+
patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio**2)
|
131 |
+
#multiply each by the small V transposed
|
132 |
+
patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2)
|
133 |
+
#reorder the vector to have the first entry first (because singulars are ordered descendingly)
|
134 |
+
recon = torch.zeros(vec.shape[0], self.channels * self.img_dim**2, device=vec.device)
|
135 |
+
recon[:, :self.channels * self.y_dim ** 2 * self.cs_size] = patches[:, :, :, :self.cs_size].contiguous().reshape(
|
136 |
+
vec.shape[0], -1)
|
137 |
+
recon[:, self.channels * self.y_dim ** 2 * self.cs_size:] = patches[:, :, :, self.cs_size:].contiguous().reshape(
|
138 |
+
vec.shape[0], -1)
|
139 |
+
return recon
|
140 |
+
|
141 |
+
def U(self, vec):
|
142 |
+
return vec.clone().reshape(vec.shape[0], -1)
|
143 |
+
|
144 |
+
def Ut(self, vec): #U is 1x1, so U^T = U
|
145 |
+
return vec.clone().reshape(vec.shape[0], -1)
|
146 |
+
|
147 |
+
def singulars(self):
|
148 |
+
return self.singulars_small.repeat(self.channels * self.y_dim**2)
|
149 |
+
|
150 |
+
def add_zeros(self, vec):
|
151 |
+
reshaped = vec.clone().reshape(vec.shape[0], -1)
|
152 |
+
temp = torch.zeros((vec.shape[0], self.channels * self.img_dim**2), device=vec.device)
|
153 |
+
temp[:, :reshaped.shape[1]] = reshaped
|
154 |
+
return temp
|
155 |
+
|
156 |
+
|
157 |
+
def color2gray(x):
|
158 |
+
x = x[:, 0:1, :, :] * 0.3333 + x[:, 1:2, :, :] * 0.3334 + x[:, 2:, :, :] * 0.3333
|
159 |
+
return x
|
160 |
+
|
161 |
+
|
162 |
+
def gray2color(x):
|
163 |
+
base = 0.3333 ** 2 + 0.3334 ** 2 + 0.3333 ** 2
|
164 |
+
return torch.stack((x * 0.3333 / base, x * 0.3334 / base, x * 0.3333 / base), 1)
|
165 |
+
|
166 |
+
|
167 |
+
#a memory inefficient implementation for any general degradation A
|
168 |
+
class GeneralA(A_functions):
|
169 |
+
def mat_by_vec(self, M, v):
|
170 |
+
vshape = v.shape[1]
|
171 |
+
if len(v.shape) > 2: vshape = vshape * v.shape[2]
|
172 |
+
if len(v.shape) > 3: vshape = vshape * v.shape[3]
|
173 |
+
return torch.matmul(M, v.view(v.shape[0], vshape,
|
174 |
+
1)).view(v.shape[0], M.shape[0])
|
175 |
+
|
176 |
+
def __init__(self, A):
|
177 |
+
self._U, self._singulars, self._V = torch.svd(A, some=False)
|
178 |
+
self._Vt = self._V.transpose(0, 1)
|
179 |
+
self._Ut = self._U.transpose(0, 1)
|
180 |
+
|
181 |
+
ZERO = 1e-3
|
182 |
+
self._singulars[self._singulars < ZERO] = 0
|
183 |
+
print(len([x.item() for x in self._singulars if x == 0]))
|
184 |
+
|
185 |
+
def V(self, vec):
|
186 |
+
return self.mat_by_vec(self._V, vec.clone())
|
187 |
+
|
188 |
+
def Vt(self, vec):
|
189 |
+
return self.mat_by_vec(self._Vt, vec.clone())
|
190 |
+
|
191 |
+
def U(self, vec):
|
192 |
+
return self.mat_by_vec(self._U, vec.clone())
|
193 |
+
|
194 |
+
def Ut(self, vec):
|
195 |
+
return self.mat_by_vec(self._Ut, vec.clone())
|
196 |
+
|
197 |
+
def singulars(self):
|
198 |
+
return self._singulars
|
199 |
+
|
200 |
+
def add_zeros(self, vec):
|
201 |
+
out = torch.zeros(vec.shape[0], self._V.shape[0], device=vec.device)
|
202 |
+
out[:, :self._U.shape[0]] = vec.clone().reshape(vec.shape[0], -1)
|
203 |
+
return out
|
204 |
+
|
205 |
+
#Walsh-Hadamard Compressive Sensing
|
206 |
+
class WalshHadamardCS(A_functions):
|
207 |
+
def fwht(self, vec): #the Fast Walsh Hadamard Transform is the same as its inverse
|
208 |
+
a = vec.reshape(vec.shape[0], self.channels, self.img_dim**2)
|
209 |
+
h = 1
|
210 |
+
while h < self.img_dim**2:
|
211 |
+
a = a.reshape(vec.shape[0], self.channels, -1, h * 2)
|
212 |
+
b = a.clone()
|
213 |
+
a[:, :, :, :h] = b[:, :, :, :h] + b[:, :, :, h:2*h]
|
214 |
+
a[:, :, :, h:2*h] = b[:, :, :, :h] - b[:, :, :, h:2*h]
|
215 |
+
h *= 2
|
216 |
+
a = a.reshape(vec.shape[0], self.channels, self.img_dim**2) / self.img_dim
|
217 |
+
return a
|
218 |
+
|
219 |
+
def __init__(self, channels, img_dim, ratio, perm, device):
|
220 |
+
self.channels = channels
|
221 |
+
self.img_dim = img_dim
|
222 |
+
self.ratio = ratio
|
223 |
+
self.perm = perm
|
224 |
+
self._singulars = torch.ones(channels * img_dim**2 // ratio, device=device)
|
225 |
+
|
226 |
+
def V(self, vec):
|
227 |
+
temp = torch.zeros(vec.shape[0], self.channels, self.img_dim**2, device=vec.device)
|
228 |
+
temp[:, :, self.perm] = vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1)
|
229 |
+
return self.fwht(temp).reshape(vec.shape[0], -1)
|
230 |
+
|
231 |
+
def Vt(self, vec):
|
232 |
+
return self.fwht(vec.clone())[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1)
|
233 |
+
|
234 |
+
def U(self, vec):
|
235 |
+
return vec.clone().reshape(vec.shape[0], -1)
|
236 |
+
|
237 |
+
def Ut(self, vec):
|
238 |
+
return vec.clone().reshape(vec.shape[0], -1)
|
239 |
+
|
240 |
+
def singulars(self):
|
241 |
+
return self._singulars
|
242 |
+
|
243 |
+
def add_zeros(self, vec):
|
244 |
+
out = torch.zeros(vec.shape[0], self.channels * self.img_dim**2, device=vec.device)
|
245 |
+
out[:, :self.channels * self.img_dim**2 // self.ratio] = vec.clone().reshape(vec.shape[0], -1)
|
246 |
+
return out
|
247 |
+
|
248 |
+
def Lambda(self, vec, a, sigma_y, sigma_t, eta):
|
249 |
+
temp_vec = self.fwht(vec.clone())[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1)
|
250 |
+
|
251 |
+
singulars = self._singulars
|
252 |
+
lambda_t = torch.ones(self.channels * self.img_dim ** 2, device=vec.device)
|
253 |
+
temp = torch.zeros(self.channels * self.img_dim ** 2, device=vec.device)
|
254 |
+
temp[:singulars.size(0)] = singulars
|
255 |
+
singulars = temp
|
256 |
+
inverse_singulars = 1. / singulars
|
257 |
+
inverse_singulars[singulars == 0] = 0.
|
258 |
+
|
259 |
+
if a != 0 and sigma_y != 0:
|
260 |
+
change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
|
261 |
+
lambda_t = lambda_t * (-change_index + 1.0) + change_index * (
|
262 |
+
singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y)
|
263 |
+
|
264 |
+
lambda_t = lambda_t.reshape(1, -1)
|
265 |
+
temp_vec = temp_vec * lambda_t
|
266 |
+
|
267 |
+
temp_out = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device)
|
268 |
+
temp_out[:, :, self.perm] = temp_vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1)
|
269 |
+
return self.fwht(temp_out).reshape(vec.shape[0], -1)
|
270 |
+
|
271 |
+
def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
|
272 |
+
temp_vec = vec.clone().reshape(
|
273 |
+
vec.shape[0], self.channels, self.img_dim ** 2)[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1)
|
274 |
+
temp_eps = epsilon.clone().reshape(
|
275 |
+
vec.shape[0], self.channels, self.img_dim ** 2)[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1)
|
276 |
+
|
277 |
+
d1_t = torch.ones(self.channels * self.img_dim ** 2, device=vec.device) * sigma_t * eta
|
278 |
+
d2_t = torch.ones(self.channels * self.img_dim ** 2, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5
|
279 |
+
|
280 |
+
singulars = self._singulars
|
281 |
+
temp = torch.zeros(self.channels * self.img_dim ** 2, device=vec.device)
|
282 |
+
temp[:singulars.size(0)] = singulars
|
283 |
+
singulars = temp
|
284 |
+
inverse_singulars = 1. / singulars
|
285 |
+
inverse_singulars[singulars == 0] = 0.
|
286 |
+
|
287 |
+
if a != 0 and sigma_y != 0:
|
288 |
+
change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
|
289 |
+
d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
|
290 |
+
d2_t = d2_t * (-change_index + 1.0)
|
291 |
+
|
292 |
+
change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0
|
293 |
+
d1_t = d1_t * (-change_index + 1.0) + torch.sqrt(
|
294 |
+
change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2))
|
295 |
+
d2_t = d2_t * (-change_index + 1.0)
|
296 |
+
|
297 |
+
change_index = (singulars == 0) * 1.0
|
298 |
+
d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
|
299 |
+
d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5
|
300 |
+
|
301 |
+
d1_t = d1_t.reshape(1, -1)
|
302 |
+
d2_t = d2_t.reshape(1, -1)
|
303 |
+
|
304 |
+
temp_vec = temp_vec * d1_t
|
305 |
+
temp_eps = temp_eps * d2_t
|
306 |
+
|
307 |
+
temp_out_vec = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device)
|
308 |
+
temp_out_vec[:, :, self.perm] = temp_vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1)
|
309 |
+
temp_out_vec = self.fwht(temp_out_vec).reshape(vec.shape[0], -1)
|
310 |
+
|
311 |
+
temp_out_eps = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device)
|
312 |
+
temp_out_eps[:, :, self.perm] = temp_eps.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1)
|
313 |
+
temp_out_eps = self.fwht(temp_out_eps).reshape(vec.shape[0], -1)
|
314 |
+
|
315 |
+
return temp_out_vec + temp_out_eps
|
316 |
+
|
317 |
+
|
318 |
+
#Inpainting
|
319 |
+
class Inpainting(A_functions):
|
320 |
+
def __init__(self, channels, img_dim, missing_indices, device):
|
321 |
+
self.channels = channels
|
322 |
+
self.img_dim = img_dim
|
323 |
+
self._singulars = torch.ones(channels * img_dim**2 - missing_indices.shape[0]).to(device)
|
324 |
+
self.missing_indices = missing_indices
|
325 |
+
self.kept_indices = torch.Tensor([i for i in range(channels * img_dim**2) if i not in missing_indices]).to(device).long()
|
326 |
+
|
327 |
+
def V(self, vec):
|
328 |
+
temp = vec.clone().reshape(vec.shape[0], -1)
|
329 |
+
out = torch.zeros_like(temp)
|
330 |
+
out[:, self.kept_indices] = temp[:, :self.kept_indices.shape[0]]
|
331 |
+
out[:, self.missing_indices] = temp[:, self.kept_indices.shape[0]:]
|
332 |
+
return out.reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1).reshape(vec.shape[0], -1)
|
333 |
+
|
334 |
+
def Vt(self, vec):
|
335 |
+
temp = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1).reshape(vec.shape[0], -1)
|
336 |
+
out = torch.zeros_like(temp)
|
337 |
+
out[:, :self.kept_indices.shape[0]] = temp[:, self.kept_indices]
|
338 |
+
out[:, self.kept_indices.shape[0]:] = temp[:, self.missing_indices]
|
339 |
+
return out
|
340 |
+
|
341 |
+
def U(self, vec):
|
342 |
+
return vec.clone().reshape(vec.shape[0], -1)
|
343 |
+
|
344 |
+
def Ut(self, vec):
|
345 |
+
return vec.clone().reshape(vec.shape[0], -1)
|
346 |
+
|
347 |
+
def singulars(self):
|
348 |
+
return self._singulars
|
349 |
+
|
350 |
+
def add_zeros(self, vec):
|
351 |
+
temp = torch.zeros((vec.shape[0], self.channels * self.img_dim**2), device=vec.device)
|
352 |
+
reshaped = vec.clone().reshape(vec.shape[0], -1)
|
353 |
+
temp[:, :reshaped.shape[1]] = reshaped
|
354 |
+
return temp
|
355 |
+
|
356 |
+
def Lambda(self, vec, a, sigma_y, sigma_t, eta):
|
357 |
+
|
358 |
+
temp = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1).reshape(vec.shape[0], -1)
|
359 |
+
out = torch.zeros_like(temp)
|
360 |
+
out[:, :self.kept_indices.shape[0]] = temp[:, self.kept_indices]
|
361 |
+
out[:, self.kept_indices.shape[0]:] = temp[:, self.missing_indices]
|
362 |
+
|
363 |
+
singulars = self._singulars
|
364 |
+
lambda_t = torch.ones(temp.size(1), device=vec.device)
|
365 |
+
temp_singulars = torch.zeros(temp.size(1), device=vec.device)
|
366 |
+
temp_singulars[:singulars.size(0)] = singulars
|
367 |
+
singulars = temp_singulars
|
368 |
+
inverse_singulars = 1. / singulars
|
369 |
+
inverse_singulars[singulars == 0] = 0.
|
370 |
+
|
371 |
+
if a != 0 and sigma_y != 0:
|
372 |
+
change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
|
373 |
+
lambda_t = lambda_t * (-change_index + 1.0) + change_index * (
|
374 |
+
singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y)
|
375 |
+
|
376 |
+
lambda_t = lambda_t.reshape(1, -1)
|
377 |
+
out = out * lambda_t
|
378 |
+
|
379 |
+
result = torch.zeros_like(temp)
|
380 |
+
result[:, self.kept_indices] = out[:, :self.kept_indices.shape[0]]
|
381 |
+
result[:, self.missing_indices] = out[:, self.kept_indices.shape[0]:]
|
382 |
+
return result.reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1).reshape(vec.shape[0], -1)
|
383 |
+
|
384 |
+
def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
|
385 |
+
temp_vec = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1).reshape(vec.shape[0], -1)
|
386 |
+
out_vec = torch.zeros_like(temp_vec)
|
387 |
+
out_vec[:, :self.kept_indices.shape[0]] = temp_vec[:, self.kept_indices]
|
388 |
+
out_vec[:, self.kept_indices.shape[0]:] = temp_vec[:, self.missing_indices]
|
389 |
+
|
390 |
+
temp_eps = epsilon.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1).reshape(vec.shape[0], -1)
|
391 |
+
out_eps = torch.zeros_like(temp_eps)
|
392 |
+
out_eps[:, :self.kept_indices.shape[0]] = temp_eps[:, self.kept_indices]
|
393 |
+
out_eps[:, self.kept_indices.shape[0]:] = temp_eps[:, self.missing_indices]
|
394 |
+
|
395 |
+
singulars = self._singulars
|
396 |
+
d1_t = torch.ones(temp_vec.size(1), device=vec.device) * sigma_t * eta
|
397 |
+
d2_t = torch.ones(temp_vec.size(1), device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5
|
398 |
+
|
399 |
+
temp_singulars = torch.zeros(temp_vec.size(1), device=vec.device)
|
400 |
+
temp_singulars[:singulars.size(0)] = singulars
|
401 |
+
singulars = temp_singulars
|
402 |
+
inverse_singulars = 1. / singulars
|
403 |
+
inverse_singulars[singulars == 0] = 0.
|
404 |
+
|
405 |
+
if a != 0 and sigma_y != 0:
|
406 |
+
change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
|
407 |
+
d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
|
408 |
+
d2_t = d2_t * (-change_index + 1.0)
|
409 |
+
|
410 |
+
change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0
|
411 |
+
d1_t = d1_t * (-change_index + 1.0) + torch.sqrt(
|
412 |
+
change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2))
|
413 |
+
d2_t = d2_t * (-change_index + 1.0)
|
414 |
+
|
415 |
+
change_index = (singulars == 0) * 1.0
|
416 |
+
d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
|
417 |
+
d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5
|
418 |
+
|
419 |
+
d1_t = d1_t.reshape(1, -1)
|
420 |
+
d2_t = d2_t.reshape(1, -1)
|
421 |
+
out_vec = out_vec * d1_t
|
422 |
+
out_eps = out_eps * d2_t
|
423 |
+
|
424 |
+
result_vec = torch.zeros_like(temp_vec)
|
425 |
+
result_vec[:, self.kept_indices] = out_vec[:, :self.kept_indices.shape[0]]
|
426 |
+
result_vec[:, self.missing_indices] = out_vec[:, self.kept_indices.shape[0]:]
|
427 |
+
result_vec = result_vec.reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1).reshape(vec.shape[0], -1)
|
428 |
+
|
429 |
+
result_eps = torch.zeros_like(temp_eps)
|
430 |
+
result_eps[:, self.kept_indices] = out_eps[:, :self.kept_indices.shape[0]]
|
431 |
+
result_eps[:, self.missing_indices] = out_eps[:, self.kept_indices.shape[0]:]
|
432 |
+
result_eps = result_eps.reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1).reshape(vec.shape[0], -1)
|
433 |
+
|
434 |
+
return result_vec + result_eps
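A minimal, illustrative usage sketch for the Inpainting operator above, using only the methods shown here; the image size and the set of missing indices are hypothetical.

import torch

device = "cpu"
channels, img_dim = 3, 8
n = channels * img_dim ** 2
# hypothetical set of missing entries in the operator's flattened ordering
missing = torch.arange(3 * n // 4, n, device=device)
A_inp = Inpainting(channels, img_dim, missing, device)

x = torch.randn(2, n, device=device)   # batch of flattened images
coeffs = A_inp.Vt(x)                   # kept entries come first, missing entries last
x_back = A_inp.V(coeffs)               # V undoes the reordering exactly
assert torch.allclose(x, x_back)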
|
435 |
+
|
436 |
+
#Denoising
|
437 |
+
class Denoising(A_functions):
|
438 |
+
def __init__(self, channels, img_dim, device):
|
439 |
+
self._singulars = torch.ones(channels * img_dim**2, device=device)
|
440 |
+
|
441 |
+
def V(self, vec):
|
442 |
+
return vec.clone().reshape(vec.shape[0], -1)
|
443 |
+
|
444 |
+
def Vt(self, vec):
|
445 |
+
return vec.clone().reshape(vec.shape[0], -1)
|
446 |
+
|
447 |
+
def U(self, vec):
|
448 |
+
return vec.clone().reshape(vec.shape[0], -1)
|
449 |
+
|
450 |
+
def Ut(self, vec):
|
451 |
+
return vec.clone().reshape(vec.shape[0], -1)
|
452 |
+
|
453 |
+
def singulars(self):
|
454 |
+
return self._singulars
|
455 |
+
|
456 |
+
def add_zeros(self, vec):
|
457 |
+
return vec.clone().reshape(vec.shape[0], -1)
|
458 |
+
|
459 |
+
def Lambda(self, vec, a, sigma_y, sigma_t, eta):
|
460 |
+
if sigma_t < a * sigma_y:
|
461 |
+
factor = (sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y).item()
|
462 |
+
return vec * factor
|
463 |
+
else:
|
464 |
+
return vec
|
465 |
+
|
466 |
+
def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
|
467 |
+
if sigma_t >= a * sigma_y:
|
468 |
+
factor = torch.sqrt(sigma_t ** 2 - a ** 2 * sigma_y ** 2).item()
|
469 |
+
return vec * factor
|
470 |
+
else:
|
471 |
+
return vec * sigma_t * eta
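For denoising the degradation is the identity: U, V and their transposes are plain reshapes and every singular value equals one. A tiny illustrative check with made-up shapes:

import torch

A_den = Denoising(channels=3, img_dim=8, device="cpu")
x = torch.randn(2, 3 * 8 ** 2)
assert torch.equal(A_den.V(A_den.Vt(x)), x)
assert torch.all(A_den.singulars() == 1.0)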
|
472 |
+
|
473 |
+
#Super Resolution
|
474 |
+
class SuperResolution(A_functions):
|
475 |
+
def __init__(self, channels, img_dim, ratio, device): #ratio = 2 or 4
|
476 |
+
assert img_dim % ratio == 0
|
477 |
+
self.img_dim = img_dim
|
478 |
+
self.channels = channels
|
479 |
+
self.y_dim = img_dim // ratio
|
480 |
+
self.ratio = ratio
|
481 |
+
A = torch.Tensor([[1 / ratio**2] * ratio**2]).to(device)
|
482 |
+
self.U_small, self.singulars_small, self.V_small = torch.svd(A, some=False)
|
483 |
+
self.Vt_small = self.V_small.transpose(0, 1)
|
484 |
+
|
485 |
+
def V(self, vec):
|
486 |
+
#reorder the vector back into patches (because singulars are ordered descendingly)
|
487 |
+
temp = vec.clone().reshape(vec.shape[0], -1)
|
488 |
+
patches = torch.zeros(vec.shape[0], self.channels, self.y_dim**2, self.ratio**2, device=vec.device)
|
489 |
+
patches[:, :, :, 0] = temp[:, :self.channels * self.y_dim**2].view(vec.shape[0], self.channels, -1)
|
490 |
+
for idx in range(self.ratio**2-1):
|
491 |
+
patches[:, :, :, idx+1] = temp[:, (self.channels*self.y_dim**2+idx)::self.ratio**2-1].view(vec.shape[0], self.channels, -1)
|
492 |
+
#multiply each patch by the small V
|
493 |
+
patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2)
|
494 |
+
#repatch the patches into an image
|
495 |
+
patches_orig = patches.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio)
|
496 |
+
recon = patches_orig.permute(0, 1, 2, 4, 3, 5).contiguous()
|
497 |
+
recon = recon.reshape(vec.shape[0], self.channels * self.img_dim ** 2)
|
498 |
+
return recon
|
499 |
+
|
500 |
+
def Vt(self, vec):
|
501 |
+
#extract flattened patches
|
502 |
+
patches = vec.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim)
|
503 |
+
patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
|
504 |
+
unfold_shape = patches.shape
|
505 |
+
patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio**2)
|
506 |
+
#multiply each by the small V transposed
|
507 |
+
patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2)
|
508 |
+
#reorder the vector to have the first entry first (because singulars are ordered descendingly)
|
509 |
+
recon = torch.zeros(vec.shape[0], self.channels * self.img_dim**2, device=vec.device)
|
510 |
+
recon[:, :self.channels * self.y_dim**2] = patches[:, :, :, 0].view(vec.shape[0], self.channels * self.y_dim**2)
|
511 |
+
for idx in range(self.ratio**2-1):
|
512 |
+
recon[:, (self.channels*self.y_dim**2+idx)::self.ratio**2-1] = patches[:, :, :, idx+1].view(vec.shape[0], self.channels * self.y_dim**2)
|
513 |
+
return recon
|
514 |
+
|
515 |
+
def U(self, vec):
|
516 |
+
return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1)
|
517 |
+
|
518 |
+
def Ut(self, vec): #U is 1x1, so U^T = U
|
519 |
+
return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1)
|
520 |
+
|
521 |
+
def singulars(self):
|
522 |
+
return self.singulars_small.repeat(self.channels * self.y_dim**2)
|
523 |
+
|
524 |
+
def add_zeros(self, vec):
|
525 |
+
reshaped = vec.clone().reshape(vec.shape[0], -1)
|
526 |
+
temp = torch.zeros((vec.shape[0], reshaped.shape[1] * self.ratio**2), device=vec.device)
|
527 |
+
temp[:, :reshaped.shape[1]] = reshaped
|
528 |
+
return temp
|
529 |
+
|
530 |
+
def Lambda(self, vec, a, sigma_y, sigma_t, eta):
|
531 |
+
singulars = self.singulars_small
|
532 |
+
|
533 |
+
patches = vec.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim)
|
534 |
+
patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
|
535 |
+
patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2)
|
536 |
+
|
537 |
+
patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2)
|
538 |
+
|
539 |
+
lambda_t = torch.ones(self.ratio ** 2, device=vec.device)
|
540 |
+
|
541 |
+
temp = torch.zeros(self.ratio ** 2, device=vec.device)
|
542 |
+
temp[:singulars.size(0)] = singulars
|
543 |
+
singulars = temp
|
544 |
+
inverse_singulars = 1. / singulars
|
545 |
+
inverse_singulars[singulars == 0] = 0.
|
546 |
+
|
547 |
+
if a != 0 and sigma_y != 0:
|
548 |
+
change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
|
549 |
+
lambda_t = lambda_t * (-change_index + 1.0) + change_index * (singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y)
|
550 |
+
|
551 |
+
lambda_t = lambda_t.reshape(1, 1, 1, -1)
|
552 |
+
# print("lambda_t:", lambda_t)
|
553 |
+
# print("V:", self.V_small)
|
554 |
+
# print(lambda_t.size(), self.V_small.size())
|
555 |
+
# print("Sigma_t:", torch.matmul(torch.matmul(self.V_small, torch.diag(lambda_t.reshape(-1))), self.Vt_small))
|
556 |
+
patches = patches * lambda_t
|
557 |
+
|
558 |
+
|
559 |
+
patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio**2, 1))
|
560 |
+
|
561 |
+
patches = patches.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio)
|
562 |
+
patches = patches.permute(0, 1, 2, 4, 3, 5).contiguous()
|
563 |
+
patches = patches.reshape(vec.shape[0], self.channels * self.img_dim ** 2)
|
564 |
+
|
565 |
+
return patches
|
566 |
+
|
567 |
+
def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
|
568 |
+
singulars = self.singulars_small
|
569 |
+
|
570 |
+
patches_vec = vec.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim)
|
571 |
+
patches_vec = patches_vec.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
|
572 |
+
patches_vec = patches_vec.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2)
|
573 |
+
|
574 |
+
patches_eps = epsilon.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim)
|
575 |
+
patches_eps = patches_eps.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
|
576 |
+
patches_eps = patches_eps.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2)
|
577 |
+
|
578 |
+
d1_t = torch.ones(self.ratio ** 2, device=vec.device) * sigma_t * eta
|
579 |
+
d2_t = torch.ones(self.ratio ** 2, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5
|
580 |
+
|
581 |
+
temp = torch.zeros(self.ratio ** 2, device=vec.device)
|
582 |
+
temp[:singulars.size(0)] = singulars
|
583 |
+
singulars = temp
|
584 |
+
inverse_singulars = 1. / singulars
|
585 |
+
inverse_singulars[singulars == 0] = 0.
|
586 |
+
|
587 |
+
if a != 0 and sigma_y != 0:
|
588 |
+
|
589 |
+
change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
|
590 |
+
d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
|
591 |
+
d2_t = d2_t * (-change_index + 1.0)
|
592 |
+
|
593 |
+
change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0
|
594 |
+
d1_t = d1_t * (-change_index + 1.0) + torch.sqrt(change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2))
|
595 |
+
d2_t = d2_t * (-change_index + 1.0)
|
596 |
+
|
597 |
+
change_index = (singulars == 0) * 1.0
|
598 |
+
d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
|
599 |
+
d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5
|
600 |
+
|
601 |
+
d1_t = d1_t.reshape(1, 1, 1, -1)
|
602 |
+
d2_t = d2_t.reshape(1, 1, 1, -1)
|
603 |
+
patches_vec = patches_vec * d1_t
|
604 |
+
patches_eps = patches_eps * d2_t
|
605 |
+
|
606 |
+
patches_vec = torch.matmul(self.V_small, patches_vec.reshape(-1, self.ratio**2, 1))
|
607 |
+
|
608 |
+
patches_vec = patches_vec.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio)
|
609 |
+
patches_vec = patches_vec.permute(0, 1, 2, 4, 3, 5).contiguous()
|
610 |
+
patches_vec = patches_vec.reshape(vec.shape[0], self.channels * self.img_dim ** 2)
|
611 |
+
|
612 |
+
patches_eps = torch.matmul(self.V_small, patches_eps.reshape(-1, self.ratio**2, 1))
|
613 |
+
|
614 |
+
patches_eps = patches_eps.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio)
|
615 |
+
patches_eps = patches_eps.permute(0, 1, 2, 4, 3, 5).contiguous()
|
616 |
+
patches_eps = patches_eps.reshape(vec.shape[0], self.channels * self.img_dim ** 2)
|
617 |
+
|
618 |
+
return patches_vec + patches_eps
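An illustrative construction of the block-averaging super-resolution operator above; the sizes are made up. The first channels * (img_dim // ratio)**2 entries of Vt(x) correspond (up to scale and sign) to the per-patch averages that the low-resolution measurement retains, and V inverts Vt up to floating-point error.

import torch

device = "cpu"
channels, img_dim, ratio = 3, 16, 4
A_sr = SuperResolution(channels, img_dim, ratio, device)

x = torch.randn(1, channels * img_dim ** 2, device=device)
coeffs = A_sr.Vt(x)
lowres_part = coeffs[:, : channels * (img_dim // ratio) ** 2]   # retained component
assert torch.allclose(A_sr.V(coeffs), x, atol=1e-4)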
|
619 |
+
|
620 |
+
class SuperResolutionGeneral(SuperResolution):
|
621 |
+
def __init__(self, channels, imgH, imgW, ratio, device): #ratio = 2 or 4
|
622 |
+
assert imgH % ratio == 0 and imgW % ratio == 0
|
623 |
+
self.imgH = imgH
|
624 |
+
self.imgW = imgW
|
625 |
+
self.channels = channels
|
626 |
+
self.yH = imgH // ratio
|
627 |
+
self.yW = imgW // ratio
|
628 |
+
self.ratio = ratio
|
629 |
+
A = torch.Tensor([[1 / ratio**2] * ratio**2]).to(device)
|
630 |
+
self.U_small, self.singulars_small, self.V_small = torch.svd(A, some=False)
|
631 |
+
self.Vt_small = self.V_small.transpose(0, 1)
|
632 |
+
|
633 |
+
def V(self, vec):
|
634 |
+
#reorder the vector back into patches (because singulars are ordered descendingly)
|
635 |
+
temp = vec.clone().reshape(vec.shape[0], -1)
|
636 |
+
patches = torch.zeros(vec.shape[0], self.channels, self.yH*self.yW, self.ratio**2, device=vec.device)
|
637 |
+
patches[:, :, :, 0] = temp[:, :self.channels * self.yH*self.yW].view(vec.shape[0], self.channels, -1)
|
638 |
+
for idx in range(self.ratio**2-1):
|
639 |
+
patches[:, :, :, idx+1] = temp[:, (self.channels*self.yH*self.yW+idx)::self.ratio**2-1].view(vec.shape[0], self.channels, -1)
|
640 |
+
#multiply each patch by the small V
|
641 |
+
patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2)
|
642 |
+
#repatch the patches into an image
|
643 |
+
patches_orig = patches.reshape(vec.shape[0], self.channels, self.yH, self.yW, self.ratio, self.ratio)
|
644 |
+
recon = patches_orig.permute(0, 1, 2, 4, 3, 5).contiguous()
|
645 |
+
recon = recon.reshape(vec.shape[0], self.channels * self.imgH * self.imgW)
|
646 |
+
return recon
|
647 |
+
|
648 |
+
def Vt(self, vec):
|
649 |
+
#extract flattened patches
|
650 |
+
patches = vec.clone().reshape(vec.shape[0], self.channels, self.imgH, self.imgW)
|
651 |
+
patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
|
652 |
+
unfold_shape = patches.shape
|
653 |
+
patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio**2)
|
654 |
+
#multiply each by the small V transposed
|
655 |
+
patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2)
|
656 |
+
#reorder the vector to have the first entry first (because singulars are ordered descendingly)
|
657 |
+
recon = torch.zeros(vec.shape[0], self.channels * self.imgH*self.imgW, device=vec.device)
|
658 |
+
recon[:, :self.channels * self.yH*self.yW] = patches[:, :, :, 0].view(vec.shape[0], self.channels * self.yH*self.yW)
|
659 |
+
for idx in range(self.ratio**2-1):
|
660 |
+
recon[:, (self.channels*self.yH*self.yW+idx)::self.ratio**2-1] = patches[:, :, :, idx+1].view(vec.shape[0], self.channels * self.yH*self.yW)
|
661 |
+
return recon
|
662 |
+
|
663 |
+
def U(self, vec):
|
664 |
+
return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1)
|
665 |
+
|
666 |
+
def Ut(self, vec): #U is 1x1, so U^T = U
|
667 |
+
return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1)
|
668 |
+
|
669 |
+
def singulars(self):
|
670 |
+
return self.singulars_small.repeat(self.channels * self.yH*self.yW)
|
671 |
+
|
672 |
+
def add_zeros(self, vec):
|
673 |
+
reshaped = vec.clone().reshape(vec.shape[0], -1)
|
674 |
+
temp = torch.zeros((vec.shape[0], reshaped.shape[1] * self.ratio**2), device=vec.device)
|
675 |
+
temp[:, :reshaped.shape[1]] = reshaped
|
676 |
+
return temp
|
677 |
+
|
678 |
+
def Lambda(self, vec, a, sigma_y, sigma_t, eta):
|
679 |
+
singulars = self.singulars_small
|
680 |
+
|
681 |
+
patches = vec.clone().reshape(vec.shape[0], self.channels, self.imgH, self.imgW)
|
682 |
+
patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
|
683 |
+
patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2)
|
684 |
+
|
685 |
+
patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2)
|
686 |
+
|
687 |
+
lambda_t = torch.ones(self.ratio ** 2, device=vec.device)
|
688 |
+
|
689 |
+
temp = torch.zeros(self.ratio ** 2, device=vec.device)
|
690 |
+
temp[:singulars.size(0)] = singulars
|
691 |
+
singulars = temp
|
692 |
+
inverse_singulars = 1. / singulars
|
693 |
+
inverse_singulars[singulars == 0] = 0.
|
694 |
+
|
695 |
+
if a != 0 and sigma_y != 0:
|
696 |
+
change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
|
697 |
+
lambda_t = lambda_t * (-change_index + 1.0) + change_index * (singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y)
|
698 |
+
|
699 |
+
lambda_t = lambda_t.reshape(1, 1, 1, -1)
|
700 |
+
# print("lambda_t:", lambda_t)
|
701 |
+
# print("V:", self.V_small)
|
702 |
+
# print(lambda_t.size(), self.V_small.size())
|
703 |
+
# print("Sigma_t:", torch.matmul(torch.matmul(self.V_small, torch.diag(lambda_t.reshape(-1))), self.Vt_small))
|
704 |
+
patches = patches * lambda_t
|
705 |
+
|
706 |
+
|
707 |
+
patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio**2, 1))
|
708 |
+
|
709 |
+
patches = patches.reshape(vec.shape[0], self.channels, self.yH, self.yW, self.ratio, self.ratio)
|
710 |
+
patches = patches.permute(0, 1, 2, 4, 3, 5).contiguous()
|
711 |
+
patches = patches.reshape(vec.shape[0], self.channels * self.imgH * self.imgW)
|
712 |
+
|
713 |
+
return patches
|
714 |
+
|
715 |
+
def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
|
716 |
+
singulars = self.singulars_small
|
717 |
+
|
718 |
+
patches_vec = vec.clone().reshape(vec.shape[0], self.channels, self.imgH, self.imgW)
|
719 |
+
patches_vec = patches_vec.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
|
720 |
+
patches_vec = patches_vec.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2)
|
721 |
+
|
722 |
+
patches_eps = epsilon.clone().reshape(vec.shape[0], self.channels, self.imgH, self.imgW)
|
723 |
+
patches_eps = patches_eps.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
|
724 |
+
patches_eps = patches_eps.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2)
|
725 |
+
|
726 |
+
d1_t = torch.ones(self.ratio ** 2, device=vec.device) * sigma_t * eta
|
727 |
+
d2_t = torch.ones(self.ratio ** 2, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5
|
728 |
+
|
729 |
+
temp = torch.zeros(self.ratio ** 2, device=vec.device)
|
730 |
+
temp[:singulars.size(0)] = singulars
|
731 |
+
singulars = temp
|
732 |
+
inverse_singulars = 1. / singulars
|
733 |
+
inverse_singulars[singulars == 0] = 0.
|
734 |
+
|
735 |
+
if a != 0 and sigma_y != 0:
|
736 |
+
|
737 |
+
change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
|
738 |
+
d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
|
739 |
+
d2_t = d2_t * (-change_index + 1.0)
|
740 |
+
|
741 |
+
change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0
|
742 |
+
d1_t = d1_t * (-change_index + 1.0) + torch.sqrt(change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2))
|
743 |
+
d2_t = d2_t * (-change_index + 1.0)
|
744 |
+
|
745 |
+
change_index = (singulars == 0) * 1.0
|
746 |
+
d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
|
747 |
+
d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5
|
748 |
+
|
749 |
+
d1_t = d1_t.reshape(1, 1, 1, -1)
|
750 |
+
d2_t = d2_t.reshape(1, 1, 1, -1)
|
751 |
+
patches_vec = patches_vec * d1_t
|
752 |
+
patches_eps = patches_eps * d2_t
|
753 |
+
|
754 |
+
patches_vec = torch.matmul(self.V_small, patches_vec.reshape(-1, self.ratio**2, 1))
|
755 |
+
|
756 |
+
patches_vec = patches_vec.reshape(vec.shape[0], self.channels, self.yH, self.yW, self.ratio, self.ratio)
|
757 |
+
patches_vec = patches_vec.permute(0, 1, 2, 4, 3, 5).contiguous()
|
758 |
+
patches_vec = patches_vec.reshape(vec.shape[0], self.channels * self.imgH * self.imgW)
|
759 |
+
|
760 |
+
patches_eps = torch.matmul(self.V_small, patches_eps.reshape(-1, self.ratio**2, 1))
|
761 |
+
|
762 |
+
patches_eps = patches_eps.reshape(vec.shape[0], self.channels, self.yH, self.yW, self.ratio, self.ratio)
|
763 |
+
patches_eps = patches_eps.permute(0, 1, 2, 4, 3, 5).contiguous()
|
764 |
+
patches_eps = patches_eps.reshape(vec.shape[0], self.channels * self.imgH * self.imgW)
|
765 |
+
|
766 |
+
return patches_vec + patches_eps
|
767 |
+
|
768 |
+
#Colorization
|
769 |
+
class Colorization(A_functions):
|
770 |
+
def __init__(self, img_dim, device):
|
771 |
+
self.channels = 3
|
772 |
+
self.img_dim = img_dim
|
773 |
+
#Do the SVD for the per-pixel matrix
|
774 |
+
A = torch.Tensor([[0.3333, 0.3334, 0.3333]]).to(device)
|
775 |
+
self.U_small, self.singulars_small, self.V_small = torch.svd(A, some=False)
|
776 |
+
self.Vt_small = self.V_small.transpose(0, 1)
|
777 |
+
|
778 |
+
def V(self, vec):
|
779 |
+
#get the needles
|
780 |
+
needles = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1) #shape: B, H*W, C'
|
781 |
+
#multiply each needle by the small V
|
782 |
+
needles = torch.matmul(self.V_small, needles.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1, self.channels) #shape: B, WA, C
|
783 |
+
#permute back to vector representation
|
784 |
+
recon = needles.permute(0, 2, 1) #shape: B, C, H*W
|
785 |
+
return recon.reshape(vec.shape[0], -1)
|
786 |
+
|
787 |
+
def Vt(self, vec):
|
788 |
+
#get the needles
|
789 |
+
needles = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1) #shape: B, H*W, C
|
790 |
+
#multiply each needle by the small V transposed
|
791 |
+
needles = torch.matmul(self.Vt_small, needles.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1, self.channels) #shape: B, H*W, C'
|
792 |
+
#reorder the vector so that the first entry of each needle is at the top
|
793 |
+
recon = needles.permute(0, 2, 1).reshape(vec.shape[0], -1)
|
794 |
+
return recon
|
795 |
+
|
796 |
+
def U(self, vec):
|
797 |
+
return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1)
|
798 |
+
|
799 |
+
def Ut(self, vec): #U is 1x1, so U^T = U
|
800 |
+
return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1)
|
801 |
+
|
802 |
+
def singulars(self):
|
803 |
+
return self.singulars_small.repeat(self.img_dim**2)
|
804 |
+
|
805 |
+
def add_zeros(self, vec):
|
806 |
+
reshaped = vec.clone().reshape(vec.shape[0], -1)
|
807 |
+
temp = torch.zeros((vec.shape[0], self.channels * self.img_dim**2), device=vec.device)
|
808 |
+
temp[:, :self.img_dim**2] = reshaped
|
809 |
+
return temp
|
810 |
+
|
811 |
+
def Lambda(self, vec, a, sigma_y, sigma_t, eta):
|
812 |
+
needles = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1)
|
813 |
+
|
814 |
+
needles = torch.matmul(self.Vt_small, needles.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1, self.channels)
|
815 |
+
|
816 |
+
singulars = self.singulars_small
|
817 |
+
lambda_t = torch.ones(self.channels, device=vec.device)
|
818 |
+
temp = torch.zeros(self.channels, device=vec.device)
|
819 |
+
temp[:singulars.size(0)] = singulars
|
820 |
+
singulars = temp
|
821 |
+
inverse_singulars = 1. / singulars
|
822 |
+
inverse_singulars[singulars == 0] = 0.
|
823 |
+
|
824 |
+
if a != 0 and sigma_y != 0:
|
825 |
+
change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
|
826 |
+
lambda_t = lambda_t * (-change_index + 1.0) + change_index * (
|
827 |
+
singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y)
|
828 |
+
|
829 |
+
lambda_t = lambda_t.reshape(1, 1, self.channels)
|
830 |
+
needles = needles * lambda_t
|
831 |
+
|
832 |
+
needles = torch.matmul(self.V_small, needles.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1, self.channels)
|
833 |
+
|
834 |
+
recon = needles.permute(0, 2, 1).reshape(vec.shape[0], -1)
|
835 |
+
return recon
|
836 |
+
|
837 |
+
def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
|
838 |
+
singulars = self.singulars_small
|
839 |
+
|
840 |
+
needles_vec = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1)
|
841 |
+
needles_epsilon = epsilon.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1)
|
842 |
+
|
843 |
+
d1_t = torch.ones(self.channels, device=vec.device) * sigma_t * eta
|
844 |
+
d2_t = torch.ones(self.channels, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5
|
845 |
+
|
846 |
+
temp = torch.zeros(self.channels, device=vec.device)
|
847 |
+
temp[:singulars.size(0)] = singulars
|
848 |
+
singulars = temp
|
849 |
+
inverse_singulars = 1. / singulars
|
850 |
+
inverse_singulars[singulars == 0] = 0.
|
851 |
+
|
852 |
+
if a != 0 and sigma_y != 0:
|
853 |
+
change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
|
854 |
+
d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
|
855 |
+
d2_t = d2_t * (-change_index + 1.0)
|
856 |
+
|
857 |
+
change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0
|
858 |
+
d1_t = d1_t * (-change_index + 1.0) + torch.sqrt(
|
859 |
+
change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2))
|
860 |
+
d2_t = d2_t * (-change_index + 1.0)
|
861 |
+
|
862 |
+
change_index = (singulars == 0) * 1.0
|
863 |
+
d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
|
864 |
+
d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5
|
865 |
+
|
866 |
+
d1_t = d1_t.reshape(1, 1, self.channels)
|
867 |
+
d2_t = d2_t.reshape(1, 1, self.channels)
|
868 |
+
|
869 |
+
needles_vec = needles_vec * d1_t
|
870 |
+
needles_epsilon = needles_epsilon * d2_t
|
871 |
+
|
872 |
+
needles_vec = torch.matmul(self.V_small, needles_vec.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1, self.channels)
|
873 |
+
recon_vec = needles_vec.permute(0, 2, 1).reshape(vec.shape[0], -1)
|
874 |
+
|
875 |
+
needles_epsilon = torch.matmul(self.V_small, needles_epsilon.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1,self.channels)
|
876 |
+
recon_epsilon = needles_epsilon.permute(0, 2, 1).reshape(vec.shape[0], -1)
|
877 |
+
|
878 |
+
return recon_vec + recon_epsilon
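An illustrative sketch for the Colorization operator above, with made-up sizes. Each pixel's RGB values form a "needle"; Vt rotates it so that the component kept by the grayscale measurement (a near-equal-weight channel average) comes first.

import torch

A_col = Colorization(img_dim=8, device="cpu")
x = torch.rand(1, 3 * 8 ** 2)
coeffs = A_col.Vt(x)
assert torch.allclose(A_col.V(coeffs), x, atol=1e-4)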
|
879 |
+
|
880 |
+
#Walsh-Hadamard Compressive Sensing
|
881 |
+
class WalshAadamardCS(A_functions):
|
882 |
+
def fwht(self, vec): #the Fast Walsh-Hadamard Transform is the same as its inverse
|
883 |
+
a = vec.reshape(vec.shape[0], self.channels, self.img_dim**2)
|
884 |
+
h = 1
|
885 |
+
while h < self.img_dim**2:
|
886 |
+
a = a.reshape(vec.shape[0], self.channels, -1, h * 2)
|
887 |
+
b = a.clone()
|
888 |
+
a[:, :, :, :h] = b[:, :, :, :h] + b[:, :, :, h:2*h]
|
889 |
+
a[:, :, :, h:2*h] = b[:, :, :, :h] - b[:, :, :, h:2*h]
|
890 |
+
h *= 2
|
891 |
+
a = a.reshape(vec.shape[0], self.channels, self.img_dim**2) / self.img_dim
|
892 |
+
return a
|
893 |
+
|
894 |
+
def __init__(self, channels, img_dim, ratio, perm, device):
|
895 |
+
self.channels = channels
|
896 |
+
self.img_dim = img_dim
|
897 |
+
self.ratio = ratio
|
898 |
+
self.perm = perm
|
899 |
+
self._singulars = torch.ones(channels * img_dim**2 // ratio, device=device)
|
900 |
+
|
901 |
+
def V(self, vec):
|
902 |
+
temp = torch.zeros(vec.shape[0], self.channels, self.img_dim**2, device=vec.device)
|
903 |
+
temp[:, :, self.perm] = vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1)
|
904 |
+
return self.fwht(temp).reshape(vec.shape[0], -1)
|
905 |
+
|
906 |
+
def Vt(self, vec):
|
907 |
+
return self.fwht(vec.clone())[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1)
|
908 |
+
|
909 |
+
def U(self, vec):
|
910 |
+
return vec.clone().reshape(vec.shape[0], -1)
|
911 |
+
|
912 |
+
def Ut(self, vec):
|
913 |
+
return vec.clone().reshape(vec.shape[0], -1)
|
914 |
+
|
915 |
+
def singulars(self):
|
916 |
+
return self._singulars
|
917 |
+
|
918 |
+
def add_zeros(self, vec):
|
919 |
+
out = torch.zeros(vec.shape[0], self.channels * self.img_dim**2, device=vec.device)
|
920 |
+
out[:, :self.channels * self.img_dim**2 // self.ratio] = vec.clone().reshape(vec.shape[0], -1)
|
921 |
+
return out
|
922 |
+
|
923 |
+
def Lambda(self, vec, a, sigma_y, sigma_t, eta):
|
924 |
+
temp_vec = self.fwht(vec.clone())[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1)
|
925 |
+
|
926 |
+
singulars = self._singulars
|
927 |
+
lambda_t = torch.ones(self.channels * self.img_dim ** 2, device=vec.device)
|
928 |
+
temp = torch.zeros(self.channels * self.img_dim ** 2, device=vec.device)
|
929 |
+
temp[:singulars.size(0)] = singulars
|
930 |
+
singulars = temp
|
931 |
+
inverse_singulars = 1. / singulars
|
932 |
+
inverse_singulars[singulars == 0] = 0.
|
933 |
+
|
934 |
+
if a != 0 and sigma_y != 0:
|
935 |
+
change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
|
936 |
+
lambda_t = lambda_t * (-change_index + 1.0) + change_index * (
|
937 |
+
singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y)
|
938 |
+
|
939 |
+
lambda_t = lambda_t.reshape(1, -1)
|
940 |
+
temp_vec = temp_vec * lambda_t
|
941 |
+
|
942 |
+
temp_out = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device)
|
943 |
+
temp_out[:, :, self.perm] = temp_vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1)
|
944 |
+
return self.fwht(temp_out).reshape(vec.shape[0], -1)
|
945 |
+
|
946 |
+
def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
|
947 |
+
temp_vec = vec.clone().reshape(
|
948 |
+
vec.shape[0], self.channels, self.img_dim ** 2)[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1)
|
949 |
+
temp_eps = epsilon.clone().reshape(
|
950 |
+
vec.shape[0], self.channels, self.img_dim ** 2)[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1)
|
951 |
+
|
952 |
+
d1_t = torch.ones(self.channels * self.img_dim ** 2, device=vec.device) * sigma_t * eta
|
953 |
+
d2_t = torch.ones(self.channels * self.img_dim ** 2, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5
|
954 |
+
|
955 |
+
singulars = self._singulars
|
956 |
+
temp = torch.zeros(self.channels * self.img_dim ** 2, device=vec.device)
|
957 |
+
temp[:singulars.size(0)] = singulars
|
958 |
+
singulars = temp
|
959 |
+
inverse_singulars = 1. / singulars
|
960 |
+
inverse_singulars[singulars == 0] = 0.
|
961 |
+
|
962 |
+
if a != 0 and sigma_y != 0:
|
963 |
+
change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
|
964 |
+
d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
|
965 |
+
d2_t = d2_t * (-change_index + 1.0)
|
966 |
+
|
967 |
+
change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0
|
968 |
+
d1_t = d1_t * (-change_index + 1.0) + torch.sqrt(
|
969 |
+
change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2))
|
970 |
+
d2_t = d2_t * (-change_index + 1.0)
|
971 |
+
|
972 |
+
change_index = (singulars == 0) * 1.0
|
973 |
+
d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
|
974 |
+
d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5
|
975 |
+
|
976 |
+
d1_t = d1_t.reshape(1, -1)
|
977 |
+
d2_t = d2_t.reshape(1, -1)
|
978 |
+
|
979 |
+
temp_vec = temp_vec * d1_t
|
980 |
+
temp_eps = temp_eps * d2_t
|
981 |
+
|
982 |
+
temp_out_vec = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device)
|
983 |
+
temp_out_vec[:, :, self.perm] = temp_vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1)
|
984 |
+
temp_out_vec = self.fwht(temp_out_vec).reshape(vec.shape[0], -1)
|
985 |
+
|
986 |
+
temp_out_eps = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device)
|
987 |
+
temp_out_eps[:, :, self.perm] = temp_eps.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1)
|
988 |
+
temp_out_eps = self.fwht(temp_out_eps).reshape(vec.shape[0], -1)
|
989 |
+
|
990 |
+
return temp_out_vec + temp_out_eps
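An illustrative construction of the compressive-sensing operator above, with hypothetical sizes. It takes a permutation of the img_dim**2 Hadamard coefficients; keeping a 1/ratio fraction of the permuted coefficients defines the measurement.

import torch

device = "cpu"
channels, img_dim, ratio = 3, 8, 4        # img_dim**2 must be a power of two for fwht
perm = torch.randperm(img_dim ** 2, device=device)
A_cs = WalshAadamardCS(channels, img_dim, ratio, perm, device)

x = torch.randn(1, channels * img_dim ** 2, device=device)
coeffs = A_cs.Vt(x)                       # permuted Walsh-Hadamard coefficients
assert torch.allclose(A_cs.V(coeffs), x, atol=1e-4)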
|
991 |
+
|
992 |
+
#Convolution-based super-resolution
|
993 |
+
class SRConv(A_functions):
|
994 |
+
def mat_by_img(self, M, v, dim):
|
995 |
+
return torch.matmul(M, v.reshape(v.shape[0] * self.channels, dim,
|
996 |
+
dim)).reshape(v.shape[0], self.channels, M.shape[0], dim)
|
997 |
+
|
998 |
+
def img_by_mat(self, v, M, dim):
|
999 |
+
return torch.matmul(v.reshape(v.shape[0] * self.channels, dim,
|
1000 |
+
dim), M).reshape(v.shape[0], self.channels, dim, M.shape[1])
|
1001 |
+
|
1002 |
+
def __init__(self, kernel, channels, img_dim, device, stride = 1):
|
1003 |
+
self.img_dim = img_dim
|
1004 |
+
self.channels = channels
|
1005 |
+
self.ratio = stride
|
1006 |
+
small_dim = img_dim // stride
|
1007 |
+
self.small_dim = small_dim
|
1008 |
+
#build 1D conv matrix
|
1009 |
+
A_small = torch.zeros(small_dim, img_dim, device=device)
|
1010 |
+
for i in range(stride//2, img_dim + stride//2, stride):
|
1011 |
+
for j in range(i - kernel.shape[0]//2, i + kernel.shape[0]//2):
|
1012 |
+
j_effective = j
|
1013 |
+
#reflective padding
|
1014 |
+
if j_effective < 0: j_effective = -j_effective-1
|
1015 |
+
if j_effective >= img_dim: j_effective = (img_dim - 1) - (j_effective - img_dim)
|
1016 |
+
#matrix building
|
1017 |
+
A_small[i // stride, j_effective] += kernel[j - i + kernel.shape[0]//2]
|
1018 |
+
#get the svd of the 1D conv
|
1019 |
+
self.U_small, self.singulars_small, self.V_small = torch.svd(A_small, some=False)
|
1020 |
+
ZERO = 3e-2
|
1021 |
+
self.singulars_small[self.singulars_small < ZERO] = 0
|
1022 |
+
#calculate the singular values of the big matrix
|
1023 |
+
self._singulars = torch.matmul(self.singulars_small.reshape(small_dim, 1), self.singulars_small.reshape(1, small_dim)).reshape(small_dim**2)
|
1024 |
+
#permutation for matching the singular values. See P_1 in Appendix D.5.
|
1025 |
+
self._perm = torch.Tensor([self.img_dim * i + j for i in range(self.small_dim) for j in range(self.small_dim)] + \
|
1026 |
+
[self.img_dim * i + j for i in range(self.small_dim) for j in range(self.small_dim, self.img_dim)]).to(device).long()
|
1027 |
+
|
1028 |
+
def V(self, vec):
|
1029 |
+
#invert the permutation
|
1030 |
+
temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
|
1031 |
+
temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)[:, :self._perm.shape[0], :]
|
1032 |
+
temp[:, self._perm.shape[0]:, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)[:, self._perm.shape[0]:, :]
|
1033 |
+
temp = temp.permute(0, 2, 1)
|
1034 |
+
#multiply the image by V from the left and by V^T from the right
|
1035 |
+
out = self.mat_by_img(self.V_small, temp, self.img_dim)
|
1036 |
+
out = self.img_by_mat(out, self.V_small.transpose(0, 1), self.img_dim).reshape(vec.shape[0], -1)
|
1037 |
+
return out
|
1038 |
+
|
1039 |
+
def Vt(self, vec):
|
1040 |
+
#multiply the image by V^T from the left and by V from the right
|
1041 |
+
temp = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone(), self.img_dim)
|
1042 |
+
temp = self.img_by_mat(temp, self.V_small, self.img_dim).reshape(vec.shape[0], self.channels, -1)
|
1043 |
+
#permute the entries
|
1044 |
+
temp[:, :, :self._perm.shape[0]] = temp[:, :, self._perm]
|
1045 |
+
temp = temp.permute(0, 2, 1)
|
1046 |
+
return temp.reshape(vec.shape[0], -1)
|
1047 |
+
|
1048 |
+
def U(self, vec):
|
1049 |
+
#invert the permutation
|
1050 |
+
temp = torch.zeros(vec.shape[0], self.small_dim**2, self.channels, device=vec.device)
|
1051 |
+
temp[:, :self.small_dim**2, :] = vec.clone().reshape(vec.shape[0], self.small_dim**2, self.channels)
|
1052 |
+
temp = temp.permute(0, 2, 1)
|
1053 |
+
#multiply the image by U from the left and by U^T from the right
|
1054 |
+
out = self.mat_by_img(self.U_small, temp, self.small_dim)
|
1055 |
+
out = self.img_by_mat(out, self.U_small.transpose(0, 1), self.small_dim).reshape(vec.shape[0], -1)
|
1056 |
+
return out
|
1057 |
+
|
1058 |
+
def Ut(self, vec):
|
1059 |
+
#multiply the image by U^T from the left and by U from the right
|
1060 |
+
temp = self.mat_by_img(self.U_small.transpose(0, 1), vec.clone(), self.small_dim)
|
1061 |
+
temp = self.img_by_mat(temp, self.U_small, self.small_dim).reshape(vec.shape[0], self.channels, -1)
|
1062 |
+
#permute the entries
|
1063 |
+
temp = temp.permute(0, 2, 1)
|
1064 |
+
return temp.reshape(vec.shape[0], -1)
|
1065 |
+
|
1066 |
+
def singulars(self):
|
1067 |
+
return self._singulars.repeat_interleave(3).reshape(-1)
|
1068 |
+
|
1069 |
+
def add_zeros(self, vec):
|
1070 |
+
reshaped = vec.clone().reshape(vec.shape[0], -1)
|
1071 |
+
temp = torch.zeros((vec.shape[0], reshaped.shape[1] * self.ratio**2), device=vec.device)
|
1072 |
+
temp[:, :reshaped.shape[1]] = reshaped
|
1073 |
+
return temp
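An illustrative construction of the strided convolutional downsampling operator above. The 1-D kernel below is a simple box filter chosen only so the sketch runs; a real setup might supply a bicubic or Gaussian kernel instead.

import torch

device = "cpu"
img_dim, stride = 16, 2
kernel = torch.ones(4, device=device) / 4     # hypothetical 1-D kernel
A_srconv = SRConv(kernel, channels=3, img_dim=img_dim, device=device, stride=stride)

x = torch.randn(1, 3 * img_dim ** 2, device=device)
coeffs = A_srconv.Vt(x)                       # spectral coefficients, measured block first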
|
1074 |
+
|
1075 |
+
#Deblurring
|
1076 |
+
class Deblurring(A_functions):
|
1077 |
+
def mat_by_img(self, M, v):
|
1078 |
+
return torch.matmul(M, v.reshape(v.shape[0] * self.channels, self.img_dim,
|
1079 |
+
self.img_dim)).reshape(v.shape[0], self.channels, M.shape[0], self.img_dim)
|
1080 |
+
|
1081 |
+
def img_by_mat(self, v, M):
|
1082 |
+
return torch.matmul(v.reshape(v.shape[0] * self.channels, self.img_dim,
|
1083 |
+
self.img_dim), M).reshape(v.shape[0], self.channels, self.img_dim, M.shape[1])
|
1084 |
+
|
1085 |
+
def __init__(self, kernel, channels, img_dim, device, ZERO = 3e-2):
|
1086 |
+
self.img_dim = img_dim
|
1087 |
+
self.channels = channels
|
1088 |
+
#build 1D conv matrix
|
1089 |
+
A_small = torch.zeros(img_dim, img_dim, device=device)
|
1090 |
+
for i in range(img_dim):
|
1091 |
+
for j in range(i - kernel.shape[0]//2, i + kernel.shape[0]//2):
|
1092 |
+
if j < 0 or j >= img_dim: continue
|
1093 |
+
A_small[i, j] = kernel[j - i + kernel.shape[0]//2]
|
1094 |
+
#get the svd of the 1D conv
|
1095 |
+
self.U_small, self.singulars_small, self.V_small = torch.svd(A_small, some=False)
|
1096 |
+
#ZERO = 3e-2
|
1097 |
+
self.singulars_small_orig = self.singulars_small.clone()
|
1098 |
+
self.singulars_small[self.singulars_small < ZERO] = 0
|
1099 |
+
#calculate the singular values of the big matrix
|
1100 |
+
self._singulars_orig = torch.matmul(self.singulars_small_orig.reshape(img_dim, 1), self.singulars_small_orig.reshape(1, img_dim)).reshape(img_dim**2)
|
1101 |
+
self._singulars = torch.matmul(self.singulars_small.reshape(img_dim, 1), self.singulars_small.reshape(1, img_dim)).reshape(img_dim**2)
|
1102 |
+
#sort the big matrix singulars and save the permutation
|
1103 |
+
self._singulars, self._perm = self._singulars.sort(descending=True) #, stable=True)
|
1104 |
+
self._singulars_orig = self._singulars_orig[self._perm]
|
1105 |
+
|
1106 |
+
def V(self, vec):
|
1107 |
+
#invert the permutation
|
1108 |
+
temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
|
1109 |
+
temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)
|
1110 |
+
temp = temp.permute(0, 2, 1)
|
1111 |
+
#multiply the image by V from the left and by V^T from the right
|
1112 |
+
out = self.mat_by_img(self.V_small, temp)
|
1113 |
+
out = self.img_by_mat(out, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)
|
1114 |
+
return out
|
1115 |
+
|
1116 |
+
def Vt(self, vec):
|
1117 |
+
#multiply the image by V^T from the left and by V from the right
|
1118 |
+
temp = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone())
|
1119 |
+
temp = self.img_by_mat(temp, self.V_small).reshape(vec.shape[0], self.channels, -1)
|
1120 |
+
#permute the entries according to the singular values
|
1121 |
+
temp = temp[:, :, self._perm].permute(0, 2, 1)
|
1122 |
+
return temp.reshape(vec.shape[0], -1)
|
1123 |
+
|
1124 |
+
def U(self, vec):
|
1125 |
+
#invert the permutation
|
1126 |
+
temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
|
1127 |
+
temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)
|
1128 |
+
temp = temp.permute(0, 2, 1)
|
1129 |
+
#multiply the image by U from the left and by U^T from the right
|
1130 |
+
out = self.mat_by_img(self.U_small, temp)
|
1131 |
+
out = self.img_by_mat(out, self.U_small.transpose(0, 1)).reshape(vec.shape[0], -1)
|
1132 |
+
return out
|
1133 |
+
|
1134 |
+
def Ut(self, vec):
|
1135 |
+
#multiply the image by U^T from the left and by U from the right
|
1136 |
+
temp = self.mat_by_img(self.U_small.transpose(0, 1), vec.clone())
|
1137 |
+
temp = self.img_by_mat(temp, self.U_small).reshape(vec.shape[0], self.channels, -1)
|
1138 |
+
#permute the entries according to the singular values
|
1139 |
+
temp = temp[:, :, self._perm].permute(0, 2, 1)
|
1140 |
+
return temp.reshape(vec.shape[0], -1)
|
1141 |
+
|
1142 |
+
def singulars(self):
|
1143 |
+
return self._singulars.repeat(1, 3).reshape(-1)
|
1144 |
+
|
1145 |
+
def add_zeros(self, vec):
|
1146 |
+
return vec.clone().reshape(vec.shape[0], -1)
|
1147 |
+
|
1148 |
+
def A_pinv(self, vec):
|
1149 |
+
temp = self.Ut(vec)
|
1150 |
+
singulars = self._singulars.repeat(1, 3).reshape(-1)
|
1151 |
+
|
1152 |
+
factors = 1. / singulars
|
1153 |
+
factors[singulars == 0] = 0.
|
1154 |
+
|
1155 |
+
temp[:, :singulars.shape[0]] = temp[:, :singulars.shape[0]] * factors
|
1156 |
+
return self.V(self.add_zeros(temp))
|
1157 |
+
|
1158 |
+
def Lambda(self, vec, a, sigma_y, sigma_t, eta):
|
1159 |
+
temp_vec = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone())
|
1160 |
+
temp_vec = self.img_by_mat(temp_vec, self.V_small).reshape(vec.shape[0], self.channels, -1)
|
1161 |
+
temp_vec = temp_vec[:, :, self._perm].permute(0, 2, 1)
|
1162 |
+
|
1163 |
+
singulars = self._singulars_orig
|
1164 |
+
lambda_t = torch.ones(self.img_dim ** 2, device=vec.device)
|
1165 |
+
temp_singulars = torch.zeros(self.img_dim ** 2, device=vec.device)
|
1166 |
+
temp_singulars[:singulars.size(0)] = singulars
|
1167 |
+
singulars = temp_singulars
|
1168 |
+
inverse_singulars = 1. / singulars
|
1169 |
+
inverse_singulars[singulars == 0] = 0.
|
1170 |
+
|
1171 |
+
if a != 0 and sigma_y != 0:
|
1172 |
+
change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
|
1173 |
+
lambda_t = lambda_t * (-change_index + 1.0) + change_index * (
|
1174 |
+
singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y)
|
1175 |
+
|
1176 |
+
lambda_t = lambda_t.reshape(1, -1, 1)
|
1177 |
+
temp_vec = temp_vec * lambda_t
|
1178 |
+
|
1179 |
+
temp = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device)
|
1180 |
+
temp[:, self._perm, :] = temp_vec.clone().reshape(vec.shape[0], self.img_dim ** 2, self.channels)
|
1181 |
+
temp = temp.permute(0, 2, 1)
|
1182 |
+
out = self.mat_by_img(self.V_small, temp)
|
1183 |
+
out = self.img_by_mat(out, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)
|
1184 |
+
return out
|
1185 |
+
|
1186 |
+
def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
|
1187 |
+
temp_vec = vec.clone().reshape(vec.shape[0], self.channels, -1)
|
1188 |
+
temp_vec = temp_vec[:, :, self._perm].permute(0, 2, 1)
|
1189 |
+
|
1190 |
+
temp_eps = epsilon.clone().reshape(vec.shape[0], self.channels, -1)
|
1191 |
+
temp_eps = temp_eps[:, :, self._perm].permute(0, 2, 1)
|
1192 |
+
|
1193 |
+
singulars = self._singulars_orig
|
1194 |
+
d1_t = torch.ones(self.img_dim ** 2, device=vec.device) * sigma_t * eta
|
1195 |
+
d2_t = torch.ones(self.img_dim ** 2, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5
|
1196 |
+
|
1197 |
+
temp_singulars = torch.zeros(self.img_dim ** 2, device=vec.device)
|
1198 |
+
temp_singulars[:singulars.size(0)] = singulars
|
1199 |
+
singulars = temp_singulars
|
1200 |
+
inverse_singulars = 1. / singulars
|
1201 |
+
inverse_singulars[singulars == 0] = 0.
|
1202 |
+
|
1203 |
+
if a != 0 and sigma_y != 0:
|
1204 |
+
change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
|
1205 |
+
d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
|
1206 |
+
d2_t = d2_t * (-change_index + 1.0)
|
1207 |
+
|
1208 |
+
change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0
|
1209 |
+
d1_t = d1_t * (-change_index + 1.0) + torch.sqrt(
|
1210 |
+
change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2))
|
1211 |
+
d2_t = d2_t * (-change_index + 1.0)
|
1212 |
+
|
1213 |
+
change_index = (singulars == 0) * 1.0
|
1214 |
+
d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
|
1215 |
+
d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5
|
1216 |
+
|
1217 |
+
d1_t = d1_t.reshape(1, -1, 1)
|
1218 |
+
d2_t = d2_t.reshape(1, -1, 1)
|
1219 |
+
|
1220 |
+
temp_vec = temp_vec * d1_t
|
1221 |
+
temp_eps = temp_eps * d2_t
|
1222 |
+
|
1223 |
+
temp_vec_new = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device)
|
1224 |
+
temp_vec_new[:, self._perm, :] = temp_vec
|
1225 |
+
out_vec = self.mat_by_img(self.V_small, temp_vec_new.permute(0, 2, 1))
|
1226 |
+
out_vec = self.img_by_mat(out_vec, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)
|
1227 |
+
|
1228 |
+
temp_eps_new = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device)
|
1229 |
+
temp_eps_new[:, self._perm, :] = temp_eps
|
1230 |
+
out_eps = self.mat_by_img(self.V_small, temp_eps_new.permute(0, 2, 1))
|
1231 |
+
out_eps = self.img_by_mat(out_eps, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)
|
1232 |
+
|
1233 |
+
return out_vec + out_eps
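An illustrative sketch for the separable deblurring operator above, using a hypothetical 1-D Gaussian kernel (the 2-D blur is its outer product). A_pinv, defined above, applies the pseudo-inverse while leaving the thresholded (zeroed) singular directions untouched.

import torch

device = "cpu"
img_dim = 16
kernel = torch.exp(-0.5 * (torch.arange(9, device=device, dtype=torch.float32) - 4.0) ** 2 / 2.0)
kernel = kernel / kernel.sum()                # hypothetical 1-D Gaussian, sigma ~ 1.4 px
A_blur = Deblurring(kernel, channels=3, img_dim=img_dim, device=device)

x = torch.rand(1, 3 * img_dim ** 2, device=device)
assert torch.allclose(A_blur.V(A_blur.Vt(x)), x, atol=1e-4)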
|
1234 |
+
|
1235 |
+
#Anisotropic Deblurring
|
1236 |
+
class Deblurring2D(A_functions):
|
1237 |
+
def mat_by_img(self, M, v):
|
1238 |
+
return torch.matmul(M, v.reshape(v.shape[0] * self.channels, self.img_dim,
|
1239 |
+
self.img_dim)).reshape(v.shape[0], self.channels, M.shape[0], self.img_dim)
|
1240 |
+
|
1241 |
+
def img_by_mat(self, v, M):
|
1242 |
+
return torch.matmul(v.reshape(v.shape[0] * self.channels, self.img_dim,
|
1243 |
+
self.img_dim), M).reshape(v.shape[0], self.channels, self.img_dim, M.shape[1])
|
1244 |
+
|
1245 |
+
def __init__(self, kernel1, kernel2, channels, img_dim, device):
|
1246 |
+
self.img_dim = img_dim
|
1247 |
+
self.channels = channels
|
1248 |
+
A_small1 = torch.zeros(img_dim, img_dim, device=device)
|
1249 |
+
for i in range(img_dim):
|
1250 |
+
for j in range(i - kernel1.shape[0]//2, i + kernel1.shape[0]//2):
|
1251 |
+
if j < 0 or j >= img_dim: continue
|
1252 |
+
A_small1[i, j] = kernel1[j - i + kernel1.shape[0]//2]
|
1253 |
+
A_small2 = torch.zeros(img_dim, img_dim, device=device)
|
1254 |
+
for i in range(img_dim):
|
1255 |
+
for j in range(i - kernel2.shape[0]//2, i + kernel2.shape[0]//2):
|
1256 |
+
if j < 0 or j >= img_dim: continue
|
1257 |
+
A_small2[i, j] = kernel2[j - i + kernel2.shape[0]//2]
|
1258 |
+
self.U_small1, self.singulars_small1, self.V_small1 = torch.svd(A_small1, some=False)
|
1259 |
+
self.U_small2, self.singulars_small2, self.V_small2 = torch.svd(A_small2, some=False)
|
1260 |
+
ZERO = 3e-2
|
1261 |
+
self.singulars_small1[self.singulars_small1 < ZERO] = 0
|
1262 |
+
self.singulars_small2[self.singulars_small2 < ZERO] = 0
|
1263 |
+
#calculate the singular values of the big matrix
|
1264 |
+
self._singulars = torch.matmul(self.singulars_small1.reshape(img_dim, 1), self.singulars_small2.reshape(1, img_dim)).reshape(img_dim**2)
|
1265 |
+
#sort the big matrix singulars and save the permutation
|
1266 |
+
self._singulars, self._perm = self._singulars.sort(descending=True) #, stable=True)
|
1267 |
+
|
1268 |
+
def V(self, vec):
|
1269 |
+
#invert the permutation
|
1270 |
+
temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
|
1271 |
+
temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)
|
1272 |
+
temp = temp.permute(0, 2, 1)
|
1273 |
+
#multiply the image by V from the left and by V^T from the right
|
1274 |
+
out = self.mat_by_img(self.V_small1, temp)
|
1275 |
+
out = self.img_by_mat(out, self.V_small2.transpose(0, 1)).reshape(vec.shape[0], -1)
|
1276 |
+
return out
|
1277 |
+
|
1278 |
+
def Vt(self, vec):
|
1279 |
+
#multiply the image by V^T from the left and by V from the right
|
1280 |
+
temp = self.mat_by_img(self.V_small1.transpose(0, 1), vec.clone())
|
1281 |
+
temp = self.img_by_mat(temp, self.V_small2).reshape(vec.shape[0], self.channels, -1)
|
1282 |
+
#permute the entries according to the singular values
|
1283 |
+
temp = temp[:, :, self._perm].permute(0, 2, 1)
|
1284 |
+
return temp.reshape(vec.shape[0], -1)
|
1285 |
+
|
1286 |
+
def U(self, vec):
|
1287 |
+
#invert the permutation
|
1288 |
+
temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
|
1289 |
+
temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)
|
1290 |
+
temp = temp.permute(0, 2, 1)
|
1291 |
+
#multiply the image by U from the left and by U^T from the right
|
1292 |
+
out = self.mat_by_img(self.U_small1, temp)
|
1293 |
+
out = self.img_by_mat(out, self.U_small2.transpose(0, 1)).reshape(vec.shape[0], -1)
|
1294 |
+
return out
|
1295 |
+
|
1296 |
+
def Ut(self, vec):
|
1297 |
+
#multiply the image by U^T from the left and by U from the right
|
1298 |
+
temp = self.mat_by_img(self.U_small1.transpose(0, 1), vec.clone())
|
1299 |
+
temp = self.img_by_mat(temp, self.U_small2).reshape(vec.shape[0], self.channels, -1)
|
1300 |
+
#permute the entries according to the singular values
|
1301 |
+
temp = temp[:, :, self._perm].permute(0, 2, 1)
|
1302 |
+
return temp.reshape(vec.shape[0], -1)
|
1303 |
+
|
1304 |
+
def singulars(self):
|
1305 |
+
return self._singulars.repeat(1, 3).reshape(-1)
|
1306 |
+
|
1307 |
+
def add_zeros(self, vec):
|
1308 |
+
return vec.clone().reshape(vec.shape[0], -1)
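An illustrative sketch for the anisotropic variant above, which takes two different 1-D kernels for the two image axes; the kernels and sizes are made up.

import torch

device = "cpu"
img_dim = 16
k1 = torch.ones(5, device=device) / 5         # hypothetical kernel along one axis
k2 = torch.ones(3, device=device) / 3         # hypothetical kernel along the other axis
A_blur2d = Deblurring2D(k1, k2, channels=3, img_dim=img_dim, device=device)

x = torch.rand(1, 3 * img_dim ** 2, device=device)
assert torch.allclose(A_blur2d.V(A_blur2d.Vt(x)), x, atol=1e-4)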
|
src/flair/helper_functions.py
ADDED
@@ -0,0 +1,31 @@
1 |
+
import os
|
2 |
+
import yaml
|
3 |
+
import click
|
4 |
+
import subprocess
|
5 |
+
|
6 |
+
def parse_click_context(ctx):
|
7 |
+
"""Parse additional arguments passed via Click context."""
|
8 |
+
extra_args = {}
|
9 |
+
for arg in ctx.args:
|
10 |
+
if "=" in arg:
|
11 |
+
key, value = arg.split("=", 1)
|
12 |
+
if key.startswith("--"):
|
13 |
+
key = key[2:] # Remove leading "--"
|
14 |
+
extra_args[key] = yaml.safe_load(value)
|
15 |
+
return extra_args
|
16 |
+
|
17 |
+
def generate_captions_with_seesr(pseudo_inv_dir, output_caption_file):
|
18 |
+
"""Generate captions using the SEESR model."""
|
19 |
+
command = [
|
20 |
+
"conda",
|
21 |
+
"run",
|
22 |
+
"-n",
|
23 |
+
"seesr",
|
24 |
+
"python",
|
25 |
+
"/home/erbachj/scratch2/projects/var_post_samp/scripts/generate_caption.py",
|
26 |
+
"--input_dir",
|
27 |
+
pseudo_inv_dir,
|
28 |
+
"--output_file",
|
29 |
+
output_caption_file, # Corrected argument name
|
30 |
+
]
|
31 |
+
subprocess.run(command, check=True)
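A small illustrative call of parse_click_context; click normally builds the Context from the command line, so a hypothetical stand-in object with an .args attribute is used here.

class _FakeCtx:                      # hypothetical stand-in for click.Context
    args = ["--sigma_y=0.05", "--task=sr", "--steps=28"]

print(parse_click_context(_FakeCtx()))
# -> {'sigma_y': 0.05, 'task': 'sr', 'steps': 28}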
|
src/flair/pipelines/__init__.py
ADDED
File without changes
|
src/flair/pipelines/model_loader.py
ADDED
@@ -0,0 +1,97 @@
1 |
+
# functions to load models from config
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import re
|
5 |
+
import os
|
6 |
+
import math
|
7 |
+
from diffusers import (
|
8 |
+
BitsAndBytesConfig
|
9 |
+
)
|
10 |
+
from diffusers import AutoencoderTiny
|
11 |
+
|
12 |
+
|
13 |
+
from flair.pipelines import sd3
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
@torch.no_grad()
|
18 |
+
def load_sd3(config, device):
|
19 |
+
if isinstance(device, list):
|
20 |
+
device = device[0]
|
21 |
+
if config["quantize"]:
|
22 |
+
nf4_config = BitsAndBytesConfig(
|
23 |
+
load_in_4bit=True,
|
24 |
+
bnb_4bit_quant_type="nf4",
|
25 |
+
bnb_4bit_compute_dtype=torch.bfloat16
|
26 |
+
)
|
27 |
+
else:
|
28 |
+
nf4_config = None
|
29 |
+
if config["model"] == "SD3.5-large":
|
30 |
+
pipe = sd3.SD3Wrapper.from_pretrained(
|
31 |
+
"stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16, quantization_config=nf4_config
|
32 |
+
)
|
33 |
+
elif config["model"] == "SD3.5-large-turbo":
|
34 |
+
pipe = sd3.SD3Wrapper.from_pretrained(
|
35 |
+
"stabilityai/stable-diffusion-3.5-large-turbo", torch_dtype=torch.bfloat16, quantization_config=nf4_config,
|
36 |
+
)
|
37 |
+
else:
|
38 |
+
pipe = sd3.SD3Wrapper.from_pretrained("stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.float16, quantization_config=nf4_config)
|
39 |
+
# maybe use tiny autoencoder
|
40 |
+
if config["use_tiny_ae"]:
|
41 |
+
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd3", torch_dtype=torch.float16)
|
42 |
+
|
43 |
+
# encode prompts
|
44 |
+
inp_kwargs_list = []
|
45 |
+
prompts = config["prompt"]
|
46 |
+
pipe._guidance_scale = config["guidance"]
|
47 |
+
pipe._joint_attention_kwargs = {"ip_adapter_image_embeds": None}
|
48 |
+
for prompt in prompts:
|
49 |
+
print(f"Generating prompt embeddings for: {prompt}")
|
50 |
+
pipe.text_encoder.to(device).to(torch.bfloat16)
|
51 |
+
pipe.text_encoder_2.to(device).to(torch.bfloat16)
|
52 |
+
pipe.text_encoder_3.to(device).to(torch.bfloat16)
|
53 |
+
# encode
|
54 |
+
(
|
55 |
+
prompt_embeds,
|
56 |
+
negative_prompt_embeds,
|
57 |
+
pooled_prompt_embeds,
|
58 |
+
negative_pooled_prompt_embeds,
|
59 |
+
) = pipe.encode_prompt(
|
60 |
+
prompt=prompt,
|
61 |
+
prompt_2=prompt,
|
62 |
+
prompt_3=prompt,
|
63 |
+
negative_prompt=config["negative_prompt"],
|
64 |
+
negative_prompt_2=config["negative_prompt"],
|
65 |
+
negative_prompt_3=config["negative_prompt"],
|
66 |
+
do_classifier_free_guidance=pipe.do_classifier_free_guidance,
|
67 |
+
prompt_embeds=None,
|
68 |
+
negative_prompt_embeds=None,
|
69 |
+
pooled_prompt_embeds=None,
|
70 |
+
negative_pooled_prompt_embeds=None,
|
71 |
+
device=device,
|
72 |
+
clip_skip=None,
|
73 |
+
num_images_per_prompt=1,
|
74 |
+
max_sequence_length=256,
|
75 |
+
lora_scale=None,
|
76 |
+
)
|
77 |
+
|
78 |
+
if pipe.do_classifier_free_guidance:
|
79 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
80 |
+
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
81 |
+
inp_kwargs = {
|
82 |
+
"prompt_embeds": prompt_embeds,
|
83 |
+
"pooled_prompt_embeds": pooled_prompt_embeds,
|
84 |
+
"guidance": config["guidance"],
|
85 |
+
}
|
86 |
+
inp_kwargs_list.append(inp_kwargs)
|
87 |
+
pipe.vae.to(device).to(torch.bfloat16)
|
88 |
+
pipe.transformer.to(device).to(torch.bfloat16)
|
89 |
+
|
90 |
+
|
91 |
+
return pipe, inp_kwargs_list
|
92 |
+
|
93 |
+
def load_model(config, device=["cuda"]):
|
94 |
+
if re.match(r"SD3*", config["model"]):
|
95 |
+
return load_sd3(config, device)
|
96 |
+
else:
|
97 |
+
raise ValueError(f"Unknown model type {config['model']}")
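An illustrative config for load_model / load_sd3 above; only keys actually read in this file are listed, and the values are placeholders.

config = {
    "model": "SD3.5-large",        # or "SD3.5-large-turbo"; anything else falls back to the medium model
    "quantize": True,              # load the transformer in 4-bit NF4 via BitsAndBytesConfig
    "use_tiny_ae": False,          # optionally swap in the tiny taesd3 autoencoder
    "prompt": ["a photo of a sunflower field"],
    "negative_prompt": "",
    "guidance": 2.0,
}
# pipe, inp_kwargs_list = load_model(config, device=["cuda"])  # downloads SD3.5 weights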
|
src/flair/pipelines/sd3.py
ADDED
@@ -0,0 +1,111 @@
import torch
from typing import Dict, Any
from diffusers.pipelines.stable_diffusion_3 import pipeline_stable_diffusion_3
from flair.pipelines import utils
import tqdm


class SD3Wrapper(pipeline_stable_diffusion_3.StableDiffusion3Pipeline):
    def to(self, device, kwargs):
        self.transformer.to(device)
        self.vae.to(device)
        return self

    def get_timesteps(self, n_steps, device, ts_min=0):
        # Create a linear schedule for timesteps
        timesteps = torch.linspace(1, ts_min, n_steps+2, device=device, dtype=torch.float32)
        return timesteps[1:-1]  # Exclude the first and last timesteps

    def single_step(
        self,
        img_latent: torch.Tensor,
        t: torch.Tensor,
        kwargs: Dict[str, Any],
        is_noised_latent=False,
    ):
        if "noise" in kwargs:
            noise = kwargs["noise"].detach()
            alpha = kwargs["inv_alpha"]
            if alpha == "tsqrt":
                alpha = t**0.5  # * 0.75
            elif alpha == "t":
                alpha = t
            elif alpha == "sine":
                alpha = torch.sin(t * 3.141592653589793/2)
            elif alpha == "1-t":
                alpha = 1 - t
            elif alpha == "1-t*0.5":
                alpha = (1 - t)*0.5
            elif alpha == "1-t*0.9":
                alpha = (1 - t)*0.9
            elif alpha == "t**1/3":
                alpha = t**(1/3)
            elif alpha == "(1-t)**0.5":
                alpha = (1-t)**0.5
            elif alpha == "((1-t)*0.8)**0.5":
                alpha = (1-t*0.8)**0.5
            elif alpha == "(1-t)**2":
                alpha = (1-t)**2
            # alpha = t * kwargs["inv_alpha"]
            noise = (alpha) * noise + (1-alpha**2)**0.5 * torch.randn_like(img_latent)
            # noise = noise / noise.std()
            # noise = noise / (1- 2*alpha*(1-alpha))**0.5
            # noise = noise + alpha * torch.randn_like(img_latent)
        else:
            noise = torch.randn_like(img_latent)
        if is_noised_latent:
            noised_latent = img_latent
        else:
            noised_latent = t * noise + (1 - t) * img_latent
        latent_model_input = torch.cat([noised_latent] * 2) if self.do_classifier_free_guidance else noised_latent
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = t.expand(latent_model_input.shape[0])
        noise_pred = self.transformer(
            hidden_states=latent_model_input.to(img_latent.dtype),
            timestep=(timestep*1000).to(img_latent.dtype),
            encoder_hidden_states=kwargs["prompt_embeds"].repeat(img_latent.shape[0], 1, 1),
            pooled_projections=kwargs["pooled_prompt_embeds"].repeat(img_latent.shape[0], 1),
            joint_attention_kwargs=None,
            return_dict=False,
        )[0]
        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

        eps = utils.v_to_eps(noise_pred, t, noised_latent)
        return eps, noise, (1 - t), t, noise_pred

    def encode(self, img):
        # Encode the image into latent space
        img_latent = self.vae.encode(img, return_dict=False)[0]
        if hasattr(img_latent, "sample"):
            img_latent = img_latent.sample()
        img_latent = (img_latent - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        return img_latent

    def decode(self, img_latent):
        # Decode the latent representation back to image space
        img = self.vae.decode(img_latent / self.vae.config.scaling_factor + self.vae.config.shift_factor, return_dict=False)[0]
        return img

    def denoise(self, pseudo_inv, kwargs, inverse=False):
        # get timesteps
        timesteps = torch.linspace(1, 0, kwargs["n_steps"], device=pseudo_inv.device, dtype=pseudo_inv.dtype)
        sigmas = timesteps
        if inverse:
            timesteps = timesteps.flip(0)
            sigmas = sigmas.flip(0)

        # make a single step
        for i, t in tqdm.tqdm(enumerate(timesteps[:-1]), desc="Denoising", total=len(timesteps)-1):
            eps, noise, _, t, v = self.single_step(
                pseudo_inv,
                t.to("cuda")*1000,
                kwargs,
                is_noised_latent=True,
            )
            # step
            sigma_next = sigmas[i+1]
            sigma_t = sigmas[i]
            pseudo_inv = pseudo_inv + v * (sigma_next - sigma_t)
        return pseudo_inv
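Note (not part of the commit): single_step turns the transformer's velocity prediction into an epsilon prediction through utils.v_to_eps. A minimal numerical check of that identity, assuming the interpolation x_t = t * eps + (1 - t) * x0 used above and the flow-matching convention v = eps - x0:

import torch

t = torch.rand(())            # noise level in (0, 1)
x0 = torch.randn(16, 64, 64)  # clean latent, illustrative shape
eps = torch.randn_like(x0)    # Gaussian noise

x_t = t * eps + (1 - t) * x0  # noised latent, as in single_step
v = eps - x0                  # assumed velocity convention
eps_rec = (1 - t) * v + x_t   # same formula as v_to_eps(v, t, x_t)

assert torch.allclose(eps_rec, eps, atol=1e-5)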
src/flair/pipelines/utils.py
ADDED
@@ -0,0 +1,114 @@
import torch
import numpy as np


def eps_from_v(z_0, z_t, sigma_t):
    return (z_t - z_0) / sigma_t


def v_to_eps(v, t, x_t):
    """
    Compute the epsilon parametrization from the velocity field,
    for x_t = t * x_0 + (1 - t) * x_1 with x_0 ~ N(0, I).
    """
    eps_t = (1-t)*v + x_t
    return eps_t


def clip_gradients(gradients, clip_value):
    grad_norm = gradients.norm(dim=2)
    mask = grad_norm > clip_value
    mask_exp = mask[:, :, None].expand_as(gradients)
    gradients[mask_exp] = (
        gradients[mask_exp]
        / grad_norm[:, :, None].expand_as(gradients)[mask_exp]
        * clip_value
    )
    return gradients


class Adam:

    def __init__(self, parameters, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.t = 0
        self.m = torch.zeros_like(parameters)
        self.v = torch.zeros_like(parameters)

    def step(self, params, grad) -> torch.Tensor:
        self.t += 1

        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad**2

        m_hat = self.m / (1 - self.beta1**self.t)
        v_hat = self.v / (1 - self.beta2**self.t)

        # check if self.lr is callable
        if callable(self.lr):
            lr = self.lr(self.t - 1)
        else:
            lr = self.lr
        update = lr * m_hat / (torch.sqrt(v_hat) + self.epsilon)

        return params - update


def make_cosine_decay_schedule(
    init_value: float,
    total_steps: int,
    alpha: float = 0.0,
    exponent: float = 1.0,
    warmup_steps=0,
):
    def schedule(count):
        if count < warmup_steps:
            # linear up
            return (init_value / warmup_steps) * count
        else:
            # half cosine down
            decay_steps = total_steps - warmup_steps
            count = min(count - warmup_steps, decay_steps)
            cosine_decay = 0.5 * (1 + np.cos(np.pi * count / decay_steps))
            decayed = (1 - alpha) * cosine_decay**exponent + alpha
            return init_value * decayed

    return schedule


def make_linear_decay_schedule(
    init_value: float, total_steps: int, final_value: float = 0, warmup_steps=0
):
    def schedule(count):
        if count < warmup_steps:
            # linear up
            return (init_value / warmup_steps) * count
        else:
            # linear down
            decay_steps = total_steps - warmup_steps
            count = min(count - warmup_steps, decay_steps)
            return init_value - (init_value - final_value) * count / decay_steps

    return schedule


def clip_norm_(tensor, max_norm):
    norm = tensor.norm()
    if norm > max_norm:
        tensor.mul_(max_norm / norm)


def lr_warmup(step, warmup_steps):
    return min(1.0, step / max(warmup_steps, 1))


def linear_decay_lambda(step, warmup_steps, decay_steps, total_steps):
    if step < warmup_steps:
        return min(1.0, step / max(warmup_steps, 1))
    else:
        # linear down
        # decay_steps = total_steps - warmup_steps
        count = min(step - warmup_steps, decay_steps)
        return 1 - count / decay_steps
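Note (not part of the commit): the lightweight Adam class above operates directly on a tensor, returns the updated tensor from step(), and accepts a callable learning rate, so a schedule from make_cosine_decay_schedule can be plugged in as-is. A small usage sketch with placeholder shapes and values:

import torch
from flair.pipelines.utils import Adam, make_cosine_decay_schedule

latent = torch.randn(1, 16, 64, 64)   # tensor being optimized, illustrative shape
lr_fn = make_cosine_decay_schedule(init_value=0.1, total_steps=100, warmup_steps=10)
opt = Adam(latent, lr=lr_fn)

for _ in range(100):
    grad = torch.randn_like(latent)   # placeholder gradient
    latent = opt.step(latent, grad)   # returns the updated parameters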
src/flair/scheduling.py
ADDED
@@ -0,0 +1,882 @@
1 |
+
# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
|
22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
23 |
+
from diffusers.utils import BaseOutput, logging
|
24 |
+
from diffusers.utils.torch_utils import randn_tensor
|
25 |
+
from diffusers.schedulers.scheduling_utils import (
|
26 |
+
KarrasDiffusionSchedulers,
|
27 |
+
SchedulerMixin,
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
32 |
+
|
33 |
+
|
34 |
+
@dataclass
|
35 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
|
36 |
+
class EulerDiscreteSchedulerOutput(BaseOutput):
|
37 |
+
"""
|
38 |
+
Output class for the scheduler's `step` function output.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
42 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
43 |
+
denoising loop.
|
44 |
+
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
45 |
+
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
46 |
+
`pred_original_sample` can be used to preview progress or for guidance.
|
47 |
+
"""
|
48 |
+
|
49 |
+
prev_sample: torch.FloatTensor
|
50 |
+
pred_original_sample: Optional[torch.FloatTensor] = None
|
51 |
+
|
52 |
+
|
53 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
54 |
+
def betas_for_alpha_bar(
|
55 |
+
num_diffusion_timesteps,
|
56 |
+
max_beta=0.999,
|
57 |
+
alpha_transform_type="cosine",
|
58 |
+
):
|
59 |
+
"""
|
60 |
+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
61 |
+
(1-beta) over time from t = [0,1].
|
62 |
+
|
63 |
+
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
64 |
+
to that part of the diffusion process.
|
65 |
+
|
66 |
+
|
67 |
+
Args:
|
68 |
+
num_diffusion_timesteps (`int`): the number of betas to produce.
|
69 |
+
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
70 |
+
prevent singularities.
|
71 |
+
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
72 |
+
Choose from `cosine` or `exp`
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
76 |
+
"""
|
77 |
+
if alpha_transform_type == "cosine":
|
78 |
+
|
79 |
+
def alpha_bar_fn(t):
|
80 |
+
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
81 |
+
|
82 |
+
elif alpha_transform_type == "exp":
|
83 |
+
|
84 |
+
def alpha_bar_fn(t):
|
85 |
+
return math.exp(t * -12.0)
|
86 |
+
|
87 |
+
else:
|
88 |
+
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
89 |
+
|
90 |
+
betas = []
|
91 |
+
for i in range(num_diffusion_timesteps):
|
92 |
+
t1 = i / num_diffusion_timesteps
|
93 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
94 |
+
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
95 |
+
return torch.tensor(betas, dtype=torch.float32)
|
96 |
+
|
97 |
+
|
98 |
+
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
99 |
+
def rescale_zero_terminal_snr(betas):
|
100 |
+
"""
|
101 |
+
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
|
102 |
+
|
103 |
+
|
104 |
+
Args:
|
105 |
+
betas (`torch.FloatTensor`):
|
106 |
+
the betas that the scheduler is being initialized with.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
`torch.FloatTensor`: rescaled betas with zero terminal SNR
|
110 |
+
"""
|
111 |
+
# Convert betas to alphas_bar_sqrt
|
112 |
+
alphas = 1.0 - betas
|
113 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
114 |
+
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
115 |
+
|
116 |
+
# Store old values.
|
117 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
118 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
119 |
+
|
120 |
+
# Shift so the last timestep is zero.
|
121 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
122 |
+
|
123 |
+
# Scale so the first timestep is back to the old value.
|
124 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
125 |
+
|
126 |
+
# Convert alphas_bar_sqrt to betas
|
127 |
+
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
128 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
|
129 |
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
130 |
+
betas = 1 - alphas
|
131 |
+
|
132 |
+
return betas
|
133 |
+
|
134 |
+
|
135 |
+
class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
136 |
+
"""
|
137 |
+
Euler scheduler.
|
138 |
+
|
139 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
140 |
+
methods the library implements for all schedulers such as loading and saving.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
num_train_timesteps (`int`, defaults to 1000):
|
144 |
+
The number of diffusion steps to train the model.
|
145 |
+
beta_start (`float`, defaults to 0.0001):
|
146 |
+
The starting `beta` value of inference.
|
147 |
+
beta_end (`float`, defaults to 0.02):
|
148 |
+
The final `beta` value.
|
149 |
+
beta_schedule (`str`, defaults to `"linear"`):
|
150 |
+
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
151 |
+
`linear` or `scaled_linear`.
|
152 |
+
trained_betas (`np.ndarray`, *optional*):
|
153 |
+
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
154 |
+
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
155 |
+
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
156 |
+
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
157 |
+
Video](https://imagen.research.google/video/paper.pdf) paper).
|
158 |
+
interpolation_type(`str`, defaults to `"linear"`, *optional*):
|
159 |
+
The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be on of
|
160 |
+
`"linear"` or `"log_linear"`.
|
161 |
+
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
162 |
+
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
163 |
+
the sigmas are determined according to a sequence of noise levels {σi}.
|
164 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
165 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
166 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
167 |
+
steps_offset (`int`, defaults to 0):
|
168 |
+
An offset added to the inference steps. You can use a combination of `offset=1` and
|
169 |
+
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
170 |
+
Diffusion.
|
171 |
+
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
172 |
+
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
173 |
+
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
174 |
+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
175 |
+
"""
|
176 |
+
|
177 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
178 |
+
order = 1
|
179 |
+
|
180 |
+
@register_to_config
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
num_train_timesteps: int = 1000,
|
184 |
+
beta_start: float = 0.0001,
|
185 |
+
beta_end: float = 0.02,
|
186 |
+
beta_schedule: str = "linear",
|
187 |
+
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
188 |
+
prediction_type: str = "epsilon",
|
189 |
+
interpolation_type: str = "linear",
|
190 |
+
use_karras_sigmas: Optional[bool] = False,
|
191 |
+
sigma_min: Optional[float] = None,
|
192 |
+
sigma_max: Optional[float] = None,
|
193 |
+
timestep_spacing: str = "linspace",
|
194 |
+
timestep_type: str = "discrete", # can be "discrete" or "continuous"
|
195 |
+
steps_offset: int = 0,
|
196 |
+
rescale_betas_zero_snr: bool = False,
|
197 |
+
):
|
198 |
+
if trained_betas is not None:
|
199 |
+
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
200 |
+
elif beta_schedule == "linear":
|
201 |
+
self.betas = torch.linspace(
|
202 |
+
beta_start, beta_end, num_train_timesteps, dtype=torch.float32
|
203 |
+
)
|
204 |
+
elif beta_schedule == "scaled_linear":
|
205 |
+
# this schedule is very specific to the latent diffusion model.
|
206 |
+
self.betas = (
|
207 |
+
torch.linspace(
|
208 |
+
beta_start**0.5,
|
209 |
+
beta_end**0.5,
|
210 |
+
num_train_timesteps,
|
211 |
+
dtype=torch.float32,
|
212 |
+
)
|
213 |
+
** 2
|
214 |
+
)
|
215 |
+
elif beta_schedule == "squaredcos_cap_v2":
|
216 |
+
# Glide cosine schedule
|
217 |
+
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
218 |
+
else:
|
219 |
+
raise NotImplementedError(
|
220 |
+
f"{beta_schedule} does is not implemented for {self.__class__}"
|
221 |
+
)
|
222 |
+
|
223 |
+
if rescale_betas_zero_snr:
|
224 |
+
self.betas = rescale_zero_terminal_snr(self.betas)
|
225 |
+
|
226 |
+
self.alphas = 1.0 - self.betas
|
227 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
228 |
+
|
229 |
+
if rescale_betas_zero_snr:
|
230 |
+
# Close to 0 without being 0 so first sigma is not inf
|
231 |
+
# FP16 smallest positive subnormal works well here
|
232 |
+
self.alphas_cumprod[-1] = 2**-24
|
233 |
+
|
234 |
+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
235 |
+
timesteps = np.linspace(
|
236 |
+
0, num_train_timesteps - 1, num_train_timesteps, dtype=float
|
237 |
+
)[::-1].copy()
|
238 |
+
|
239 |
+
sigmas = torch.from_numpy(sigmas[::-1].copy()).to(dtype=torch.float32)
|
240 |
+
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
|
241 |
+
|
242 |
+
# setable values
|
243 |
+
self.num_inference_steps = None
|
244 |
+
|
245 |
+
# TODO: Support the full EDM scalings for all prediction types and timestep types
|
246 |
+
if timestep_type == "continuous" and prediction_type == "v_prediction":
|
247 |
+
self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
|
248 |
+
else:
|
249 |
+
self.timesteps = timesteps
|
250 |
+
|
251 |
+
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
252 |
+
|
253 |
+
self.is_scale_input_called = False
|
254 |
+
self.use_karras_sigmas = use_karras_sigmas
|
255 |
+
|
256 |
+
self._step_index = None
|
257 |
+
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
258 |
+
self.sigmas_for_inference = False
|
259 |
+
|
260 |
+
@property
|
261 |
+
def init_noise_sigma(self):
|
262 |
+
# standard deviation of the initial noise distribution
|
263 |
+
max_sigma = (
|
264 |
+
max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
|
265 |
+
)
|
266 |
+
if self.config.timestep_spacing in ["linspace", "trailing"]:
|
267 |
+
return max_sigma
|
268 |
+
|
269 |
+
return (max_sigma**2 + 1) ** 0.5
|
270 |
+
|
271 |
+
@property
|
272 |
+
def step_index(self):
|
273 |
+
"""
|
274 |
+
The index counter for current timestep. It will increae 1 after each scheduler step.
|
275 |
+
"""
|
276 |
+
return self._step_index
|
277 |
+
|
278 |
+
def scale_model_input(
|
279 |
+
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
280 |
+
) -> torch.FloatTensor:
|
281 |
+
"""
|
282 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
283 |
+
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
|
284 |
+
|
285 |
+
Args:
|
286 |
+
sample (`torch.FloatTensor`):
|
287 |
+
The input sample.
|
288 |
+
timestep (`int`, *optional*):
|
289 |
+
The current timestep in the diffusion chain.
|
290 |
+
|
291 |
+
Returns:
|
292 |
+
`torch.FloatTensor`:
|
293 |
+
A scaled input sample.
|
294 |
+
"""
|
295 |
+
if self.step_index is None:
|
296 |
+
self._init_step_index(timestep)
|
297 |
+
|
298 |
+
sigma = self.sigmas[self.step_index]
|
299 |
+
sample = sample / ((sigma**2 + 1) ** 0.5)
|
300 |
+
|
301 |
+
self.is_scale_input_called = True
|
302 |
+
return sample
|
303 |
+
|
304 |
+
def set_timesteps(
|
305 |
+
self,
|
306 |
+
num_inference_steps: int,
|
307 |
+
timesteps=None,
|
308 |
+
device: Union[str, torch.device] = None,
|
309 |
+
):
|
310 |
+
"""
|
311 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
312 |
+
|
313 |
+
Args:
|
314 |
+
num_inference_steps (`int`):
|
315 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
316 |
+
device (`str` or `torch.device`, *optional*):
|
317 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
318 |
+
"""
|
319 |
+
self.num_inference_steps = num_inference_steps
|
320 |
+
|
321 |
+
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
|
322 |
+
if self.config.timestep_spacing == "linspace":
|
323 |
+
timesteps = np.linspace(
|
324 |
+
0,
|
325 |
+
self.config.num_train_timesteps - 1,
|
326 |
+
num_inference_steps,
|
327 |
+
dtype=np.float32,
|
328 |
+
)[::-1].copy()
|
329 |
+
elif self.config.timestep_spacing == "leading":
|
330 |
+
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
331 |
+
# creates integer timesteps by multiplying by ratio
|
332 |
+
# casting to int to avoid issues when num_inference_step is power of 3
|
333 |
+
timesteps = (
|
334 |
+
(np.arange(0, num_inference_steps) * step_ratio)
|
335 |
+
.round()[::-1]
|
336 |
+
.copy()
|
337 |
+
.astype(np.float32)
|
338 |
+
)
|
339 |
+
timesteps += self.config.steps_offset
|
340 |
+
elif self.config.timestep_spacing == "trailing":
|
341 |
+
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
|
342 |
+
# creates integer timesteps by multiplying by ratio
|
343 |
+
# casting to int to avoid issues when num_inference_step is power of 3
|
344 |
+
timesteps = (
|
345 |
+
(np.arange(self.config.num_train_timesteps, 0, -step_ratio))
|
346 |
+
.round()
|
347 |
+
.copy()
|
348 |
+
.astype(np.float32)
|
349 |
+
)
|
350 |
+
timesteps -= 1
|
351 |
+
else:
|
352 |
+
raise ValueError(
|
353 |
+
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
|
354 |
+
)
|
355 |
+
|
356 |
+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
357 |
+
log_sigmas = np.log(sigmas)
|
358 |
+
|
359 |
+
if self.config.interpolation_type == "linear":
|
360 |
+
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
361 |
+
elif self.config.interpolation_type == "log_linear":
|
362 |
+
sigmas = (
|
363 |
+
torch.linspace(
|
364 |
+
np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1
|
365 |
+
)
|
366 |
+
.exp()
|
367 |
+
.numpy()
|
368 |
+
)
|
369 |
+
else:
|
370 |
+
raise ValueError(
|
371 |
+
f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
|
372 |
+
" 'linear' or 'log_linear'"
|
373 |
+
)
|
374 |
+
|
375 |
+
if self.use_karras_sigmas:
|
376 |
+
sigmas = self._convert_to_karras(
|
377 |
+
in_sigmas=sigmas, num_inference_steps=self.num_inference_steps
|
378 |
+
)
|
379 |
+
timesteps = np.array(
|
380 |
+
[self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]
|
381 |
+
)
|
382 |
+
|
383 |
+
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
384 |
+
|
385 |
+
# TODO: Support the full EDM scalings for all prediction types and timestep types
|
386 |
+
if (
|
387 |
+
self.config.timestep_type == "continuous"
|
388 |
+
and self.config.prediction_type == "v_prediction"
|
389 |
+
):
|
390 |
+
self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(
|
391 |
+
device=device
|
392 |
+
)
|
393 |
+
else:
|
394 |
+
self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(
|
395 |
+
device=device
|
396 |
+
)
|
397 |
+
|
398 |
+
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
399 |
+
self._step_index = None
|
400 |
+
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
401 |
+
self.sigmas_for_inference = True
|
402 |
+
|
403 |
+
def _sigma_to_t(self, sigma, log_sigmas):
|
404 |
+
# get log sigma
|
405 |
+
log_sigma = np.log(np.maximum(sigma, 1e-10))
|
406 |
+
|
407 |
+
# get distribution
|
408 |
+
dists = log_sigma - log_sigmas[:, np.newaxis]
|
409 |
+
|
410 |
+
# get sigmas range
|
411 |
+
low_idx = (
|
412 |
+
np.cumsum((dists >= 0), axis=0)
|
413 |
+
.argmax(axis=0)
|
414 |
+
.clip(max=log_sigmas.shape[0] - 2)
|
415 |
+
)
|
416 |
+
high_idx = low_idx + 1
|
417 |
+
|
418 |
+
low = log_sigmas[low_idx]
|
419 |
+
high = log_sigmas[high_idx]
|
420 |
+
|
421 |
+
# interpolate sigmas
|
422 |
+
w = (low - log_sigma) / (low - high)
|
423 |
+
w = np.clip(w, 0, 1)
|
424 |
+
|
425 |
+
# transform interpolation to time range
|
426 |
+
t = (1 - w) * low_idx + w * high_idx
|
427 |
+
t = t.reshape(sigma.shape)
|
428 |
+
return t
|
429 |
+
|
430 |
+
# Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
|
431 |
+
def _convert_to_karras(
|
432 |
+
self, in_sigmas: torch.FloatTensor, num_inference_steps
|
433 |
+
) -> torch.FloatTensor:
|
434 |
+
"""Constructs the noise schedule of Karras et al. (2022)."""
|
435 |
+
|
436 |
+
# Hack to make sure that other schedulers which copy this function don't break
|
437 |
+
# TODO: Add this logic to the other schedulers
|
438 |
+
if hasattr(self.config, "sigma_min"):
|
439 |
+
sigma_min = self.config.sigma_min
|
440 |
+
else:
|
441 |
+
sigma_min = None
|
442 |
+
|
443 |
+
if hasattr(self.config, "sigma_max"):
|
444 |
+
sigma_max = self.config.sigma_max
|
445 |
+
else:
|
446 |
+
sigma_max = None
|
447 |
+
|
448 |
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
449 |
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
450 |
+
|
451 |
+
rho = 7.0 # 7.0 is the value used in the paper
|
452 |
+
ramp = np.linspace(0, 1, num_inference_steps)
|
453 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
454 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
455 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
456 |
+
return sigmas
|
457 |
+
|
458 |
+
def derivative_sigma_t(self, step):
|
459 |
+
rho = 7.0
|
460 |
+
sigma_max = self.sigmas[0]
|
461 |
+
sigma_min = self.sigmas[-1]
|
462 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
463 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
464 |
+
derivative_sigma_t = (
|
465 |
+
-7
|
466 |
+
* (max_inv_rho + (step) / 1000 * (min_inv_rho - max_inv_rho)) ** 6
|
467 |
+
* (min_inv_rho - max_inv_rho)
|
468 |
+
)
|
469 |
+
return derivative_sigma_t
|
470 |
+
|
471 |
+
def _init_step_index(self, timestep):
|
472 |
+
if isinstance(timestep, torch.Tensor):
|
473 |
+
timestep = timestep.to(self.timesteps.device)
|
474 |
+
|
475 |
+
index_candidates = (self.timesteps == timestep).nonzero()
|
476 |
+
|
477 |
+
# The sigma index that is taken for the **very** first `step`
|
478 |
+
# is always the second index (or the last index if there is only 1)
|
479 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
480 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
481 |
+
if len(index_candidates) > 1:
|
482 |
+
step_index = index_candidates[1]
|
483 |
+
else:
|
484 |
+
step_index = index_candidates[0]
|
485 |
+
|
486 |
+
self._step_index = step_index.item()
|
487 |
+
|
488 |
+
def step(
|
489 |
+
self,
|
490 |
+
model_output: torch.FloatTensor,
|
491 |
+
timestep: Union[float, torch.FloatTensor],
|
492 |
+
sample: torch.FloatTensor,
|
493 |
+
s_churn: float = 0.0,
|
494 |
+
s_tmin: float = 0.0,
|
495 |
+
s_tmax: float = float("inf"),
|
496 |
+
s_noise: float = 1.0,
|
497 |
+
generator: Optional[torch.Generator] = None,
|
498 |
+
return_dict: bool = True,
|
499 |
+
likelihood_grad=None,
|
500 |
+
likelihood_weight=None,
|
501 |
+
regularizer_weight=None,
|
502 |
+
step_index: Optional[int] = None,
|
503 |
+
) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
|
504 |
+
"""
|
505 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
506 |
+
process from the learned model outputs (most often the predicted noise).
|
507 |
+
|
508 |
+
Args:
|
509 |
+
model_output (`torch.FloatTensor`):
|
510 |
+
The direct output from learned diffusion model.
|
511 |
+
timestep (`float`):
|
512 |
+
The current discrete timestep in the diffusion chain.
|
513 |
+
sample (`torch.FloatTensor`):
|
514 |
+
A current instance of a sample created by the diffusion process.
|
515 |
+
s_churn (`float`):
|
516 |
+
s_tmin (`float`):
|
517 |
+
s_tmax (`float`):
|
518 |
+
s_noise (`float`, defaults to 1.0):
|
519 |
+
Scaling factor for noise added to the sample.
|
520 |
+
generator (`torch.Generator`, *optional*):
|
521 |
+
A random number generator.
|
522 |
+
return_dict (`bool`):
|
523 |
+
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
|
524 |
+
tuple.
|
525 |
+
|
526 |
+
Returns:
|
527 |
+
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
|
528 |
+
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
|
529 |
+
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
530 |
+
"""
|
531 |
+
|
532 |
+
if (
|
533 |
+
isinstance(timestep, int)
|
534 |
+
or isinstance(timestep, torch.IntTensor)
|
535 |
+
or isinstance(timestep, torch.LongTensor)
|
536 |
+
):
|
537 |
+
raise ValueError(
|
538 |
+
(
|
539 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
540 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
541 |
+
" one of the `scheduler.timesteps` as a timestep."
|
542 |
+
),
|
543 |
+
)
|
544 |
+
|
545 |
+
if not self.is_scale_input_called:
|
546 |
+
logger.warning(
|
547 |
+
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
|
548 |
+
"See `StableDiffusionPipeline` for a usage example."
|
549 |
+
)
|
550 |
+
|
551 |
+
# if self.step_index is None:
|
552 |
+
# self._init_step_index(timestep)
|
553 |
+
if step_index is None:
|
554 |
+
if self.step_index is None:
|
555 |
+
self._init_step_index(timestep)
|
556 |
+
step_index = torch.tensor(self.step_index).unsqueeze(0)
|
557 |
+
|
558 |
+
# Upcast to avoid precision issues when computing prev_sample
|
559 |
+
sample = sample.to(torch.float32)
|
560 |
+
if likelihood_grad is not None:
|
561 |
+
likelihood_grad = likelihood_grad.to(torch.float32)
|
562 |
+
|
563 |
+
sigma = self.sigmas[step_index][:, None, None, None].to(model_output.device)
|
564 |
+
|
565 |
+
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
|
566 |
+
condition = torch.logical_and(s_tmin <= sigma, sigma <= s_tmax)
|
567 |
+
gamma = torch.where(condition, gamma, torch.tensor(0))
|
568 |
+
|
569 |
+
noise = randn_tensor(
|
570 |
+
model_output.shape,
|
571 |
+
dtype=model_output.dtype,
|
572 |
+
device=model_output.device,
|
573 |
+
generator=generator,
|
574 |
+
)
|
575 |
+
|
576 |
+
eps = noise * s_noise
|
577 |
+
sigma_hat = sigma * (gamma + 1)
|
578 |
+
|
579 |
+
# if gamma > 0:
|
580 |
+
sample = torch.where(
|
581 |
+
gamma > 0, sample + eps * (sigma_hat**2 - sigma**2) ** 0.5, sample
|
582 |
+
)
|
583 |
+
# sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
584 |
+
|
585 |
+
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
586 |
+
# NOTE: "original_sample" should not be an expected prediction_type but is left in for
|
587 |
+
# backwards compatibility
|
588 |
+
if (
|
589 |
+
self.config.prediction_type == "original_sample"
|
590 |
+
or self.config.prediction_type == "sample"
|
591 |
+
):
|
592 |
+
pred_original_sample = model_output
|
593 |
+
elif self.config.prediction_type == "epsilon":
|
594 |
+
pred_original_sample = sample - sigma_hat * model_output
|
595 |
+
elif self.config.prediction_type == "v_prediction":
|
596 |
+
# denoised = model_output * c_out + input * c_skip
|
597 |
+
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
|
598 |
+
sample / (sigma**2 + 1)
|
599 |
+
)
|
600 |
+
else:
|
601 |
+
raise ValueError(
|
602 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
603 |
+
)
|
604 |
+
|
605 |
+
# 2. Convert to an ODE derivative
|
606 |
+
derivative = (sample - pred_original_sample) / sigma_hat
|
607 |
+
if likelihood_grad is not None:
|
608 |
+
likelihood_grad = likelihood_grad * (
|
609 |
+
torch.norm(derivative) / torch.norm(likelihood_grad)
|
610 |
+
)
|
611 |
+
derivative = (
|
612 |
+
regularizer_weight * derivative + likelihood_weight * likelihood_grad
|
613 |
+
)
|
614 |
+
|
615 |
+
dt = (
|
616 |
+
self.sigmas[step_index + 1].to(sigma_hat.device)[:, None, None, None]
|
617 |
+
- sigma_hat
|
618 |
+
)
|
619 |
+
|
620 |
+
prev_sample = sample + derivative * dt
|
621 |
+
|
622 |
+
# Cast sample back to model compatible dtype
|
623 |
+
prev_sample = prev_sample.to(model_output.dtype)
|
624 |
+
|
625 |
+
# upon completion increase step index by one
|
626 |
+
if self._step_index is not None:
|
627 |
+
self._step_index += 1
|
628 |
+
|
629 |
+
if not return_dict:
|
630 |
+
return (prev_sample,)
|
631 |
+
|
632 |
+
return EulerDiscreteSchedulerOutput(
|
633 |
+
prev_sample=prev_sample, pred_original_sample=pred_original_sample
|
634 |
+
)
|
635 |
+
|
636 |
+
def add_noise(
|
637 |
+
self,
|
638 |
+
original_samples: torch.FloatTensor,
|
639 |
+
noise: torch.FloatTensor,
|
640 |
+
timesteps: torch.FloatTensor,
|
641 |
+
) -> torch.FloatTensor:
|
642 |
+
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
643 |
+
sigmas = self.sigmas.to(
|
644 |
+
device=original_samples.device, dtype=original_samples.dtype
|
645 |
+
)
|
646 |
+
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
647 |
+
# mps does not support float64
|
648 |
+
schedule_timesteps = self.timesteps.to(
|
649 |
+
original_samples.device, dtype=torch.float32
|
650 |
+
)
|
651 |
+
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
652 |
+
else:
|
653 |
+
schedule_timesteps = self.timesteps.to(original_samples.device)
|
654 |
+
timesteps = timesteps.to(original_samples.device)
|
655 |
+
|
656 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
657 |
+
|
658 |
+
sigma = sigmas[step_indices].flatten()
|
659 |
+
while len(sigma.shape) < len(original_samples.shape):
|
660 |
+
sigma = sigma.unsqueeze(-1)
|
661 |
+
|
662 |
+
noisy_samples = original_samples + noise * sigma
|
663 |
+
return noisy_samples
|
664 |
+
|
665 |
+
def __len__(self):
|
666 |
+
return self.config.num_train_timesteps
|
667 |
+
|
668 |
+
|
669 |
+
class HeunDiscreteScheduler(EulerDiscreteScheduler):
|
670 |
+
|
671 |
+
def set_timesteps(
|
672 |
+
self,
|
673 |
+
num_inference_steps: int,
|
674 |
+
timesteps=None,
|
675 |
+
device: Union[str, torch.device] = None,
|
676 |
+
):
|
677 |
+
"""
|
678 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
679 |
+
|
680 |
+
Args:
|
681 |
+
num_inference_steps (`int`):
|
682 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
683 |
+
device (`str` or `torch.device`, *optional*):
|
684 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
685 |
+
"""
|
686 |
+
self.num_inference_steps = num_inference_steps
|
687 |
+
|
688 |
+
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
|
689 |
+
if self.config.timestep_spacing == "linspace":
|
690 |
+
timesteps = np.linspace(
|
691 |
+
0,
|
692 |
+
self.config.num_train_timesteps - 1,
|
693 |
+
num_inference_steps,
|
694 |
+
dtype=np.float32,
|
695 |
+
)[::-1].copy()
|
696 |
+
elif self.config.timestep_spacing == "leading":
|
697 |
+
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
698 |
+
# creates integer timesteps by multiplying by ratio
|
699 |
+
# casting to int to avoid issues when num_inference_step is power of 3
|
700 |
+
timesteps = (
|
701 |
+
(np.arange(0, num_inference_steps) * step_ratio)
|
702 |
+
.round()[::-1]
|
703 |
+
.copy()
|
704 |
+
.astype(np.float32)
|
705 |
+
)
|
706 |
+
timesteps += self.config.steps_offset
|
707 |
+
elif self.config.timestep_spacing == "trailing":
|
708 |
+
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
|
709 |
+
# creates integer timesteps by multiplying by ratio
|
710 |
+
# casting to int to avoid issues when num_inference_step is power of 3
|
711 |
+
timesteps = (
|
712 |
+
(np.arange(self.config.num_train_timesteps, 0, -step_ratio))
|
713 |
+
.round()
|
714 |
+
.copy()
|
715 |
+
.astype(np.float32)
|
716 |
+
)
|
717 |
+
timesteps -= 1
|
718 |
+
else:
|
719 |
+
raise ValueError(
|
720 |
+
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
|
721 |
+
)
|
722 |
+
|
723 |
+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
724 |
+
log_sigmas = np.log(sigmas)
|
725 |
+
|
726 |
+
if self.config.interpolation_type == "linear":
|
727 |
+
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
728 |
+
elif self.config.interpolation_type == "log_linear":
|
729 |
+
sigmas = (
|
730 |
+
torch.linspace(
|
731 |
+
np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1
|
732 |
+
)
|
733 |
+
.exp()
|
734 |
+
.numpy()
|
735 |
+
)
|
736 |
+
else:
|
737 |
+
raise ValueError(
|
738 |
+
f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
|
739 |
+
" 'linear' or 'log_linear'"
|
740 |
+
)
|
741 |
+
|
742 |
+
if self.use_karras_sigmas:
|
743 |
+
sigmas = self._convert_to_karras(
|
744 |
+
in_sigmas=sigmas, num_inference_steps=self.num_inference_steps
|
745 |
+
)
|
746 |
+
timesteps = np.array(
|
747 |
+
[self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]
|
748 |
+
)
|
749 |
+
|
750 |
+
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
751 |
+
|
752 |
+
sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2)])
|
753 |
+
|
754 |
+
# TODO: Support the full EDM scalings for all prediction types and timestep types
|
755 |
+
if (
|
756 |
+
self.config.timestep_type == "continuous"
|
757 |
+
and self.config.prediction_type == "v_prediction"
|
758 |
+
):
|
759 |
+
self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(
|
760 |
+
device=device
|
761 |
+
)
|
762 |
+
else:
|
763 |
+
self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(
|
764 |
+
device=device
|
765 |
+
)
|
766 |
+
|
767 |
+
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
768 |
+
self._step_index = None
|
769 |
+
self.sigmas = self.sigmas.to("cpu")
|
770 |
+
|
771 |
+
# empty dt and derivative
|
772 |
+
self.prev_derivative = None
|
773 |
+
self.dt = None
|
774 |
+
|
775 |
+
def step(
|
776 |
+
self,
|
777 |
+
model_output: Union[torch.Tensor, np.ndarray],
|
778 |
+
timestep: Union[float, torch.Tensor],
|
779 |
+
sample: Union[torch.Tensor, np.ndarray],
|
780 |
+
return_dict: bool = True,
|
781 |
+
) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
|
782 |
+
"""
|
783 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
784 |
+
process from the learned model outputs (most often the predicted noise).
|
785 |
+
|
786 |
+
Args:
|
787 |
+
model_output (`torch.Tensor`):
|
788 |
+
The direct output from learned diffusion model.
|
789 |
+
timestep (`float`):
|
790 |
+
The current discrete timestep in the diffusion chain.
|
791 |
+
sample (`torch.Tensor`):
|
792 |
+
A current instance of a sample created by the diffusion process.
|
793 |
+
return_dict (`bool`):
|
794 |
+
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
|
795 |
+
|
796 |
+
Returns:
|
797 |
+
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
798 |
+
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
799 |
+
tuple is returned where the first element is the sample tensor.
|
800 |
+
"""
|
801 |
+
if self.step_index is None:
|
802 |
+
self._init_step_index(timestep)
|
803 |
+
|
804 |
+
if self.state_in_first_order:
|
805 |
+
sigma = self.sigmas[self.step_index]
|
806 |
+
sigma_next = self.sigmas[self.step_index + 1]
|
807 |
+
else:
|
808 |
+
# 2nd order / Heun's method
|
809 |
+
sigma = self.sigmas[self.step_index - 1]
|
810 |
+
sigma_next = self.sigmas[self.step_index]
|
811 |
+
|
812 |
+
# currently only gamma=0 is supported. This usually works best anyways.
|
813 |
+
# We can support gamma in the future but then need to scale the timestep before
|
814 |
+
# passing it to the model which requires a change in API
|
815 |
+
gamma = 0
|
816 |
+
sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
|
817 |
+
|
818 |
+
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
819 |
+
if self.config.prediction_type == "epsilon":
|
820 |
+
sigma_input = sigma_hat if self.state_in_first_order else sigma_next
|
821 |
+
pred_original_sample = sample - sigma_input * model_output
|
822 |
+
elif self.config.prediction_type == "v_prediction":
|
823 |
+
sigma_input = sigma_hat if self.state_in_first_order else sigma_next
|
824 |
+
pred_original_sample = model_output * (
|
825 |
+
-sigma_input / (sigma_input**2 + 1) ** 0.5
|
826 |
+
) + (sample / (sigma_input**2 + 1))
|
827 |
+
elif self.config.prediction_type == "sample":
|
828 |
+
pred_original_sample = model_output
|
829 |
+
else:
|
830 |
+
raise ValueError(
|
831 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
832 |
+
)
|
833 |
+
|
834 |
+
if self.config.clip_sample:
|
835 |
+
pred_original_sample = pred_original_sample.clamp(
|
836 |
+
-self.config.clip_sample_range, self.config.clip_sample_range
|
837 |
+
)
|
838 |
+
|
839 |
+
if self.state_in_first_order:
|
840 |
+
# 2. Convert to an ODE derivative for 1st order
|
841 |
+
derivative = (sample - pred_original_sample) / sigma_hat
|
842 |
+
# 3. delta timestep
|
843 |
+
dt = sigma_next - sigma_hat
|
844 |
+
|
845 |
+
# store for 2nd order step
|
846 |
+
self.prev_derivative = derivative
|
847 |
+
self.dt = dt
|
848 |
+
self.sample = sample
|
849 |
+
else:
|
850 |
+
# 2. 2nd order / Heun's method
|
851 |
+
derivative = (sample - pred_original_sample) / sigma_next
|
852 |
+
derivative = (self.prev_derivative + derivative) / 2
|
853 |
+
|
854 |
+
# 3. take prev timestep & sample
|
855 |
+
dt = self.dt
|
856 |
+
sample = self.sample
|
857 |
+
|
858 |
+
# free dt and derivative
|
859 |
+
# Note, this puts the scheduler in "first order mode"
|
860 |
+
self.prev_derivative = None
|
861 |
+
self.dt = None
|
862 |
+
self.sample = None
|
863 |
+
|
864 |
+
prev_sample = sample + derivative * dt
|
865 |
+
|
866 |
+
# upon completion increase step index by one
|
867 |
+
self._step_index += 1
|
868 |
+
|
869 |
+
if not return_dict:
|
870 |
+
return (prev_sample,)
|
871 |
+
|
872 |
+
return EulerDiscreteSchedulerOutput(
|
873 |
+
prev_sample=prev_sample, pred_original_sample=pred_original_sample
|
874 |
+
)
|
875 |
+
|
876 |
+
@property
|
877 |
+
def state_in_first_order(self):
|
878 |
+
return self.dt is None
|
879 |
+
|
880 |
+
@property
|
881 |
+
def order(self):
|
882 |
+
return 2
|
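Note (not part of the commit): _convert_to_karras above spaces the inference noise levels following Karras et al. (2022) with rho = 7.0. A standalone sketch of the same computation, with example sigma_min / sigma_max values:

import numpy as np

def karras_sigmas(sigma_min, sigma_max, n, rho=7.0):
    ramp = np.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(karras_sigmas(0.03, 14.6, 5))   # monotonically decreasing sigmas from sigma_max down to sigma_min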
src/flair/utils.py
ADDED
@@ -0,0 +1,406 @@
1 |
+
import random
|
2 |
+
from pathlib import Path
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
from torchmetrics.functional.image import (
|
6 |
+
peak_signal_noise_ratio,
|
7 |
+
learned_perceptual_image_patch_similarity,
|
8 |
+
)
|
9 |
+
from PIL import Image
|
10 |
+
from skimage.color import rgb2lab, lab2rgb
|
11 |
+
import numpy as np
|
12 |
+
import cv2
|
13 |
+
from torchvision import transforms
|
14 |
+
|
15 |
+
|
16 |
+
RESAMPLE_MODE = Image.BICUBIC
|
17 |
+
|
18 |
+
|
19 |
+
def skip_iterator(iterator, skip):
|
20 |
+
for i, item in enumerate(iterator):
|
21 |
+
if i % skip == 0:
|
22 |
+
yield item
|
23 |
+
|
24 |
+
|
25 |
+
def generate_output_structure(output_dir, subfolders=[]):
|
26 |
+
"""
|
27 |
+
Generate a directory structure for the output. and return the paths to the subfolders. as template
|
28 |
+
"""
|
29 |
+
output_dir = Path(output_dir)
|
30 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
31 |
+
output_paths = []
|
32 |
+
for subfolder in subfolders:
|
33 |
+
output_paths.append(Path(os.path.join(output_dir, subfolder)))
|
34 |
+
(output_paths[-1]).mkdir(exist_ok=True, parents=True)
|
35 |
+
output_paths[-1] = os.path.join(output_paths[-1], "{}.png")
|
36 |
+
return output_paths
|
37 |
+
|
38 |
+
|
39 |
+
def find_files(path, ext="png"):
|
40 |
+
if os.path.isdir(path):
|
41 |
+
path = Path(path)
|
42 |
+
sorted_files = sorted(list(path.glob(f"*.{ext}")))
|
43 |
+
return sorted_files
|
44 |
+
else:
|
45 |
+
return [path]
|
46 |
+
|
47 |
+
|
48 |
+
def load_guidance_image(path, size=None):
|
49 |
+
"""
|
50 |
+
Load an image and convert it to a tensor.
|
51 |
+
Args: path to the image
|
52 |
+
returns: tensor of the image of shape (1, 3, H, W)
|
53 |
+
"""
|
54 |
+
img = Image.open(path)
|
55 |
+
img = img.convert("RGB")
|
56 |
+
tf = transforms.Compose([
|
57 |
+
transforms.Resize(size),
|
58 |
+
transforms.CenterCrop(size),
|
59 |
+
transforms.ToTensor()
|
60 |
+
])
|
61 |
+
img = tf(img) * 2 - 1
|
62 |
+
return img.unsqueeze(0)
|
63 |
+
|
64 |
+
|
65 |
+
def yield_images(path, ext="png", size=None):
|
66 |
+
files = find_files(path, ext)
|
67 |
+
for file in files:
|
68 |
+
yield load_guidance_image(file, size)
|
69 |
+
|
70 |
+
|
71 |
+
def yield_videos(paths, ext="png", H=None, W=None, n_frames=61):
|
72 |
+
for path in paths:
|
73 |
+
yield read_video(path, H, W, n_frames)
|
74 |
+
|
75 |
+
|
76 |
+
def read_video(path, H=None, W=None, n_frames=61) -> list[Image]:
|
77 |
+
path = Path(path)
|
78 |
+
frames = []
|
79 |
+
|
80 |
+
if Path(path).is_dir():
|
81 |
+
files = sorted(list(path.glob("*.png")))
|
82 |
+
for file in files[:n_frames]:
|
83 |
+
image = Image.open(file)
|
84 |
+
image.load()
|
85 |
+
if H is not None and W is not None:
|
86 |
+
image = image.resize((W, H), resample=Image.BICUBIC)
|
87 |
+
# to tensor
|
88 |
+
image = (
|
89 |
+
torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1)
|
90 |
+
/ 255.0
|
91 |
+
* 2
|
92 |
+
- 1
|
93 |
+
)
|
94 |
+
frames.append(image)
|
95 |
+
H, W = frames[0].size()[-2:]
|
96 |
+
frames = torch.stack(frames).unsqueeze(0)
|
97 |
+
return frames, (10, H, W)
|
98 |
+
|
99 |
+
capture = cv2.VideoCapture(str(path))
|
100 |
+
fps = capture.get(cv2.CAP_PROP_FPS)
|
101 |
+
|
102 |
+
while True:
|
103 |
+
        success, frame = capture.read()
        if not success:
            break

        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = Image.fromarray(frame)
        if H is not None and W is not None:
            frame = frame.resize((W, H), resample=Image.BICUBIC)
        # to tensor
        frame = (
            torch.tensor(np.array(frame), dtype=torch.float32).permute(2, 0, 1)
            / 255.0
            * 2
            - 1
        )
        frames.append(frame)

    capture.release()
    # to torch
    frames = torch.stack(frames).unsqueeze(0)
    return frames, (fps, W, H)


def resize_video(
    video: list[Image], width, height, resample_mode=RESAMPLE_MODE
) -> list[Image]:
    frames_lr = []
    for frame in video:
        frame_lr = frame.resize((width, height), resample_mode)
        frames_lr.append(frame_lr)
    return frames_lr


def export_to_video(
    video_frames,
    output_video_path=None,
    fps=8,
    put_numbers=False,
    annotations=None,
    fourcc="mp4v",
):
    fourcc = cv2.VideoWriter_fourcc(*fourcc)  # codec
    writer = cv2.VideoWriter(output_video_path, fourcc, fps, video_frames[0].size)
    for i, frame in enumerate(video_frames):
        frame = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)

        if put_numbers:
            text_position = (frame.shape[1] - 60, 30)
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 1
            font_color = (255, 255, 255)
            line_type = 2
            cv2.putText(
                frame,
                f"{i + 1}",
                text_position,
                font,
                font_scale,
                font_color,
                line_type,
            )

        if annotations:
            annotation = annotations[i]
            frame = draw_bodypose(
                frame, annotation["candidates"], annotation["subsets"]
            )

        writer.write(frame)
    writer.release()


def export_images(frames: list[Image], dir_name):
    dir_name = Path(dir_name)
    dir_name.mkdir(exist_ok=True, parents=True)
    for i, frame in enumerate(frames):
        frame.save(dir_name / f"{i:05d}.png")


def vid2tensor(images: list[Image]) -> torch.Tensor:
    # PIL to numpy
    if not isinstance(images, list):
        raise ValueError()
    images = [np.array(image).astype(np.float32) / 255.0 for image in images]
    images = np.stack(images, axis=0)

    if images.ndim == 3:
        # L mode, add luminance channel
        images = np.expand_dims(images, -1)

    # numpy to torch
    images = torch.from_numpy(images.transpose(0, 3, 1, 2))
    return images


def compute_metrics(
    source: list[Image],
    output: list[Image],
    output_lq: list[Image],
    target: list[Image],
) -> dict:
    psnr_ab = torch.tensor(
        np.mean(compute_color_metrics(output, target)["psnr_ab"])
    ).float()

    source = vid2tensor(source)
    output = vid2tensor(output)
    output_lq = vid2tensor(output_lq)
    target = vid2tensor(target)

    mse = ((output - target) ** 2).mean()
    psnr = peak_signal_noise_ratio(output, target, data_range=1.0, dim=(1, 2, 3))
    # lpips = learned_perceptual_image_patch_similarity(output, target)

    mse_source = ((output_lq - source) ** 2).mean()
    psnr_source = peak_signal_noise_ratio(
        output_lq, source, data_range=1.0, dim=(1, 2, 3)
    )

    return {
        "mse": mse.detach().cpu().item(),
        "psnr": psnr.detach().cpu().item(),
        "psnr_ab": psnr_ab.detach().cpu().item(),
        "mse_source": mse_source.detach().cpu().item(),
        "psnr_source": psnr_source.detach().cpu().item(),
    }

def compute_psnr_ab(x, y_gt, pp_max=202.33542248):
    """Computes the PSNR of the a,b color channels.

    Note that the CIE-Lab space is asymmetric: the peak-to-peak range of the
    two a,b channels is approximately 202.3354.
    pp_max: approximated maximum swing of the a,b channels of the CIE-Lab color space,
        max_{x in CIE-Lab} (x_a, x_b) - min_{x in CIE-Lab} (x_a, x_b)
    """
    assert (
        len(x.shape) == 3
    ), f"Expecting data of size HW2 but found {x.shape}; this should be the a,b channels of CIE-Lab space"
    assert (
        len(y_gt.shape) == 3
    ), f"Expecting data of size HW2 but found {y_gt.shape}; this should be the a,b channels of CIE-Lab space"
    assert (
        x.shape == y_gt.shape
    ), f"Expecting data to have identical shapes but found {y_gt.shape} != {x.shape}"

    H, W, C = x.shape
    assert (
        C == 2
    ), "This function assumes that x and y_gt are the a,b channels of the CIE-Lab space"

    MSE = np.sum((x - y_gt) ** 2) / (H * W * C)  # C=2, two channels
    MSE = np.clip(MSE, 1e-12, np.inf)

    PSNR_ab = 10 * np.log10(pp_max**2) - 10 * np.log10(MSE)

    return PSNR_ab

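# A quick worked instance of the formula above (illustrative only, not used by
# the metric itself): with pp_max ~ 202.335 and a mean squared a,b error of 1.0,
#   PSNR_ab = 10 * log10(202.335**2) - 10 * log10(1.0)
#           = 20 * log10(202.335) ~ 46.1 dB

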
def compute_color_metrics(out: list[Image], target: list[Image]):
    if len(out) != len(target):
        raise ValueError("Videos do not have same length")

    metrics = {"psnr_ab": []}

    for out_frame, target_frame in zip(out, target):
        out_frame, target_frame = np.asarray(out_frame), np.asarray(target_frame)
        out_frame_lab, target_frame_lab = rgb2lab(out_frame), rgb2lab(target_frame)

        psnr_ab = compute_psnr_ab(out_frame_lab[..., 1:3], target_frame_lab[..., 1:3])
        metrics["psnr_ab"].append(psnr_ab.item())

    return metrics

def to_device(sample, device):
    result = {}
    for key, val in sample.items():
        if isinstance(val, torch.Tensor):
            result[key] = val.to(device)
        elif isinstance(val, list):
            new_val = []
            for e in val:
                if isinstance(e, torch.Tensor):
                    new_val.append(e.to(device))
                else:
                    # keep non-tensor elements unchanged
                    new_val.append(e)
            result[key] = new_val
        else:
            result[key] = val
    return result

def seed_all(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def adapt_unet(unet, lora_rank=None, in_conv_mode="zeros"):
    # adapt conv_in
    kernel = unet.conv_in.weight.data
    if in_conv_mode == "zeros":
        new_kernel = torch.zeros(320, 4, 3, 3, dtype=kernel.dtype, device=kernel.device)
    elif in_conv_mode == "reflect":
        new_kernel = kernel[:, 4:].clone()
    else:
        raise NotImplementedError
    unet.conv_in.weight.data = torch.cat([kernel, new_kernel], dim=1)
    if in_conv_mode == "reflect":
        unet.conv_in.weight.data *= 2.0 / 3.0
    unet.conv_in.in_channels = 12

    if lora_rank is not None:
        from peft import LoraConfig

        types = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Linear, torch.nn.Embedding)
        target_modules = [
            (n, m) for n, m in unet.named_modules() if isinstance(m, types)
        ]
        # identify parameters (not modules) that will not be lora'd
        for _, m in target_modules:
            m.requires_grad_(False)
        not_adapted = [p for p in unet.parameters() if p.requires_grad]

        unet_lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_rank,
            init_lora_weights="gaussian",
            # target_modules=["to_k", "to_q", "to_v", "to_out.0"],
            target_modules=[n for n, _ in target_modules],
        )
        # the following line sets all parameters except the loras to non-trainable
        unet.add_adapter(unet_lora_config)

        unet.conv_in.requires_grad_()
        for p in not_adapted:
            p.requires_grad_()

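# Illustrative call pattern (a sketch, not exercised anywhere in this module):
# assuming `unet` is a diffusion UNet whose conv_in weight already has 8 input
# channels, the concatenation above yields the 12 input channels set on conv_in,
# and only the LoRA weights plus conv_in remain trainable afterwards:
#   adapt_unet(unet, lora_rank=16, in_conv_mode="zeros")
#   trainable = [p for p in unet.parameters() if p.requires_grad]

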
def repeat_infinite(iterable):
    def repeated():
        while True:
            yield from iterable

    return repeated


class CPUAdam:
    def __init__(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        self.params = list(params)
        self.lr = lr
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay
        # keep this in main memory to save VRAM
        self.state = {
            param: {
                "step": 0,
                "exp_avg": torch.zeros_like(param, device="cpu"),
                "exp_avg_sq": torch.zeros_like(param, device="cpu"),
            }
            for param in self.params
        }

    def step(self):
        for param in self.params:
            if param.grad is None:
                continue

            grad = param.grad.data.cpu()
            if self.weight_decay != 0:
                # move the parameter to CPU so the decay term matches grad's device
                grad.add_(param.data.cpu(), alpha=self.weight_decay)

            state = self.state[param]
            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
            beta1, beta2 = self.betas

            state["step"] += 1

            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

            denom = exp_avg_sq.sqrt().add_(self.eps)

            step_size = (
                self.lr
                * (1 - beta2 ** state["step"]) ** 0.5
                / (1 - beta1 ** state["step"])
            )

            # param.data.add_((-step_size * (exp_avg / denom)).cuda())

            param.data.addcdiv_(exp_avg.cuda(), denom.cuda(), value=-step_size)

    def zero_grad(self):
        for param in self.params:
            param.grad = None

    def state_dict(self):
        return self.state

    def set_lr(self, lr):
        self.lr = lr
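For context, a minimal sketch of how the CPUAdam optimizer defined above might be driven (illustrative only; it assumes a CUDA device, since the update step moves the Adam moments back to the GPU, and the tensor shape, learning rate, and objective are made up for the example):

    import torch

    # any GPU-resident leaf tensor with gradients; CPUAdam as defined above
    param = torch.randn(4, 64, 64, device="cuda", requires_grad=True)
    opt = CPUAdam([param], lr=1e-3, weight_decay=0.0)

    for _ in range(10):
        loss = (param ** 2).mean()   # any differentiable objective
        loss.backward()
        opt.step()                   # moments stay in CPU RAM; the update runs on the GPU
        opt.zero_grad()
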
src/flair/utils/__init__.py
ADDED
File without changes
src/flair/utils/blur_util.py
ADDED
@@ -0,0 +1,48 @@
import torch
from torch import nn
import numpy as np
import scipy

from flair.utils.motionblur import Kernel as MotionKernel

class Blurkernel(nn.Module):
    def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None):
        super().__init__()
        self.blur_type = blur_type
        self.kernel_size = kernel_size
        self.std = std
        self.device = device
        self.seq = nn.Sequential(
            nn.ReflectionPad2d(self.kernel_size // 2),
            nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3)
        )

        self.weights_init()

    def forward(self, x):
        return self.seq(x)

    def weights_init(self):
        if self.blur_type == "gaussian":
            n = np.zeros((self.kernel_size, self.kernel_size))
            n[self.kernel_size // 2, self.kernel_size // 2] = 1
            k = scipy.ndimage.gaussian_filter(n, sigma=self.std)
            k = torch.from_numpy(k)
            self.k = k
            for name, f in self.named_parameters():
                f.data.copy_(k)
        elif self.blur_type == "motion":
            k = MotionKernel(size=(self.kernel_size, self.kernel_size), intensity=self.std).kernelMatrix
            k = torch.from_numpy(k)
            self.k = k
            for name, f in self.named_parameters():
                f.data.copy_(k)

    def update_weights(self, k):
        if not torch.is_tensor(k):
            k = torch.from_numpy(k).to(self.device)
        for name, f in self.named_parameters():
            f.data.copy_(k)

    def get_kernel(self):
        return self.k
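For reference, a minimal sketch of how the Blurkernel module above might be applied to a batch of RGB images (illustrative only; the batch shape and std value are assumptions, not taken from the repository's configs):

    import torch
    from flair.utils.blur_util import Blurkernel

    # 31x31 Gaussian blur applied depthwise to the 3 RGB channels
    blur = Blurkernel(blur_type="gaussian", kernel_size=31, std=3.0)

    x = torch.rand(1, 3, 256, 256)   # batch of images in [0, 1]
    with torch.no_grad():
        y = blur(x)                  # blurred batch; reflection padding keeps the spatial size

    kernel = blur.get_kernel()       # the underlying 31x31 kernel tensor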